spacr 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/sim.py ADDED
@@ -0,0 +1,1187 @@
1
+
2
+ import os, gc, random, warnings, traceback, itertools, matplotlib, sqlite3
3
+ import time as tm
4
+ from time import time, sleep
5
+ from datetime import datetime
6
+ import numpy as np
7
+ import pandas as pd
8
+ import matplotlib.pyplot as plt
9
+ import seaborn as sns
10
+ import sklearn.metrics as metrics
11
+ from sklearn.metrics import roc_curve, auc, roc_auc_score, confusion_matrix, precision_recall_curve
12
+ import statsmodels.api as sm
13
+ from multiprocessing import cpu_count, Value, Array, Lock, Pool, Manager
14
+
15
+ from .logger import log_function_call
16
+
17
# Importing this module silences warnings process-wide: a blanket "ignore"
# filter, plus an explicit one for RuntimeWarning (e.g. numpy division warnings).
warnings.filterwarnings(action="ignore")
warnings.filterwarnings(action="ignore", category=RuntimeWarning)  # Ignore RuntimeWarning
19
+
20
def generate_gene_list(number_of_genes, number_of_all_genes):
    """
    Pick ``number_of_genes`` distinct gene ids at random.

    The candidate ids are ``0 .. number_of_all_genes - 1``. The full id range
    is shuffled with :mod:`random` and its first ``number_of_genes`` entries
    are kept, so the result never contains duplicates and is never longer
    than ``number_of_all_genes``.

    Args:
        number_of_genes (int): How many ids to return.
        number_of_all_genes (int): Size of the id range to draw from.

    Returns:
        list: The selected gene ids, in shuffled order.
    """
    pool = [*range(number_of_all_genes)]
    random.shuffle(pool)
    return pool[:number_of_genes]
35
+
36
# plate_map is a table with a row for each well, containing well metadata: plate_id, row_id, and column_id
def generate_plate_map(nr_plates, nr_rows=16, nr_columns=24):
    """
    Generate a plate map with one row per well.

    Wells are enumerated plate by plate, then row by row, then column by
    column; all numbers are 1-based. The defaults describe a 384-well plate
    (16 rows x 24 columns).

    Parameters:
        nr_plates (int): The number of plates to generate the map for.
        nr_rows (int, optional): Rows per plate. Defaults to 16.
        nr_columns (int, optional): Columns per plate. Defaults to 24.

    Returns:
        pandas.DataFrame: Columns ``plate_row_column`` (``"<plate>_<row>_<column>"``),
        ``plate_id``, ``row_id`` and ``column_id``. The three id columns hold
        the numbers as strings (e.g. ``"1"``), not integers. Zero plates give
        an empty frame with the same columns.
    """
    records = [
        (f"{plate}_{row}_{col}", str(plate), str(row), str(col))
        for plate in range(1, nr_plates + 1)
        for row in range(1, nr_rows + 1)
        for col in range(1, nr_columns + 1)
    ]
    return pd.DataFrame(records, columns=['plate_row_column', 'plate_id', 'row_id', 'column_id'])
51
+
52
def gini_coefficient(x):
    """
    Return the Gini coefficient of ``x`` from its mean absolute difference.

    Evaluates ``sum_ij |x_i - x_j| / (2 * n**2 * mean(x))`` by building the
    full ``n x n`` matrix of pairwise differences, so time and memory grow
    quadratically with ``len(x)``.

    Parameters:
        x (array-like): 1-d array of values.

    Returns:
        float: Gini coefficient (0 when all values are equal).
    """
    pairwise_gaps = np.abs(np.subtract.outer(x, x))
    n = len(x)
    return pairwise_gaps.sum() / (2 * n ** 2 * np.mean(x))
65
+
66
def gini(x):
    """
    Gini coefficient as half the relative mean absolute difference.

    NOTE: this module defines ``gini`` a second time further down (a
    rank-based version). That later definition rebinds the name at import
    time, so this implementation is not what ``gini`` resolves to.

    Parameters:
        x (array-like): 1-d array of values.

    Returns:
        float: The Gini coefficient.

    Notes:
        Builds the full ``n x n`` difference matrix, so time and memory are
        O(n**2); avoid very large inputs.
    """
    abs_diffs = np.abs(np.subtract.outer(x, x))
    relative_mad = abs_diffs.mean() / np.mean(x)
    return 0.5 * relative_mad
87
+
88
def gini_gene_well(x):
    """
    Calculate the Gini coefficient of a 1-d set of non-negative counts.

    In this module it is applied to per-well gene counts and per-gene well
    counts. Absolute differences are summed over unordered pairs ``i < j``
    and divided by ``n**2 * mean(x)``. This equals the usual
    ``sum over all ordered pairs / (2 * n**2 * mean)`` form. A value of 0
    means all counts are equal; values approach 1 as a single entry holds
    everything.

    Parameters:
        x (array-like): 1-d values (numpy array, list, ...).

    Returns:
        float: The Gini coefficient. NaN for empty input or when the mean is 0.
    """
    values = np.asarray(x, dtype=float)
    n = len(values)
    if n == 0:
        return float('nan')
    total = 0.0
    # Pairwise loop keeps memory O(n), unlike an n x n difference matrix.
    for i in range(n - 1):
        total += np.sum(np.abs(values[i] - values[i + 1:]))
    return total / (n ** 2 * values.mean())
106
+
107
def gini(x):
    """
    Rank-based Gini coefficient of a 1-d array.

    Values are ranked in descending order (rank 0 = largest). The result is
    ``1 - (2 * sum(rank * x) + sum(x)) / (n * sum(x))``. All values are
    weighted equally.

    Parameters:
        x (array-like): The input values; must be 1-d.

    Returns:
        float: The Gini coefficient.

    References:
        - Based on bottom eq: http://www.statsdirect.com/help/content/image/stat0206_wmf.gif
        - From: http://www.statsdirect.com/help/default.htm#nonparametric_methods/gini.htm
    """
    values = np.array(x, dtype=np.float64)
    descending_rank = np.argsort(np.argsort(-values))
    grand_total = values.sum()
    rank_weighted = (descending_rank * values).sum()
    return 1 - (2 * rank_weighted + grand_total) / (len(values) * grand_total)
127
+
128
def dist_gen(mean, sd, df):
    """
    Draw one over-dispersed count per row of ``df`` (gamma-Poisson mixture).

    A rate is drawn for every row from a gamma distribution with shape
    ``(mean / sd) ** 2`` and scale ``sd ** 2 / mean``, which gives it the
    requested mean and standard deviation. A Poisson count is then drawn at
    each rate. Uses numpy's global random state.

    Parameters:
        mean (float): Target mean of the gamma rates.
        sd (float): Target standard deviation of the gamma rates.
        df (sized): Only ``len(df)`` is used, to set the sample size.

    Returns:
        tuple: ``(counts, n)`` where ``counts`` is an integer ndarray of length
        ``n = len(df)``.
    """
    n_rows = len(df)
    gamma_shape = (mean / sd) ** 2
    gamma_scale = sd ** 2 / mean
    rates = np.random.gamma(gamma_shape, gamma_scale, size=n_rows)
    return np.random.poisson(rates), n_rows
146
+
147
def generate_gene_weights(positive_mean, positive_variance, df):
    """
    Draw one beta-distributed weight per row of ``df``.

    The beta parameters come from the method of moments, so the draws have
    mean ``positive_mean`` and variance ``positive_variance``. The variance
    must be below ``mean * (1 - mean)`` for the parameters to be positive.

    Parameters:
        - positive_mean (float): Target mean, in (0, 1).
        - positive_variance (float): Target variance.
        - df (sized): Only ``len(df)`` is used, to set the number of draws.

    Returns:
        - numpy.ndarray: ``len(df)`` weights in (0, 1).
    """
    concentration = positive_mean * (1 - positive_mean) / positive_variance - 1
    alpha = positive_mean * concentration
    beta = alpha * (1 - positive_mean) / positive_mean
    return np.random.beta(alpha, beta, len(df))
164
+
165
def normalize_array(arr):
    """
    Min-max scale values into the range [0, 1].

    Parameters:
        arr (array-like): Values to normalize; lists are accepted.

    Returns:
        numpy.ndarray: ``(arr - min) / (max - min)``. If every value is the
        same (zero range), an array of zeros is returned instead of the NaNs
        that ``0 / 0`` would produce.

    Raises:
        ValueError: If ``arr`` is empty (raised by numpy's ``min``).
    """
    values = np.asarray(arr)
    min_value = values.min()
    value_range = values.max() - min_value
    if value_range == 0:
        # Constant input: nothing to scale, and dividing would give 0/0 = NaN.
        return np.zeros(values.shape)
    return (values - min_value) / value_range
180
+
181
def generate_power_law_distribution(num_elements, coeff):
    """
    Generate normalized power-law weights.

    Weight ``k`` (for ``k = 1 .. num_elements``) is proportional to
    ``k ** -coeff``, and the weights sum to 1.

    Parameters:
        - num_elements (int): The number of elements in the distribution.
        - coeff (float or int): The power-law exponent (0 gives uniform weights).

    Returns:
        - numpy.ndarray: The normalized power-law distribution.
    """
    # Float ranks: numpy refuses to raise integer arrays to negative integer
    # powers, so an int ``coeff`` (e.g. 1) would otherwise raise ValueError.
    ranks = np.arange(1, num_elements + 1, dtype=float)
    powered = ranks ** -coeff
    return powered / powered.sum()
196
+
197
# distribution generator function
def power_law_dist_gen(df, avg, well_ineq_coeff):
    """
    Sample one power-law-derived value per row of ``df``, scaled by ``avg``.

    The candidate values are the normalized weights returned by
    ``generate_power_law_distribution(len(df), well_ineq_coeff)``. ``len(df)``
    of them are drawn uniformly with replacement (``np.random.choice``) and
    each is multiplied by ``avg``.

    Parameters:
        - df (sized): Only ``len(df)`` is used.
        - avg (float): Scale factor applied to every sampled weight.
        - well_ineq_coeff (float): Exponent passed to the power-law generator.

    Returns:
        - numpy.ndarray: ``len(df)`` scaled samples.
    """
    n_rows = len(df)
    candidates = generate_power_law_distribution(n_rows, well_ineq_coeff)
    samples = np.random.choice(candidates, n_rows)
    return samples * avg
218
+
219
def _gini_per_group(cell_df, group_col, count_col):
    """Gini of ``count_col`` occurrence counts within each ``group_col`` group (groups in first-appearance order)."""
    ginis = []
    for _, group in cell_df.groupby(group_col, sort=False):
        counts = group[count_col].value_counts().to_numpy()
        ginis.append(gini_gene_well(counts))
    return np.array(ginis)

# plates is a table with for each cell in the experiment with columns [plate_id, row_id, column_id, gene_id, is_active]
def run_experiment(plate_map, number_of_genes, active_gene_list, avg_genes_per_well, sd_genes_per_well, avg_cells_per_well, sd_cells_per_well, well_ineq_coeff, gene_ineq_coeff):
    """
    Run a simulated pooled-screen experiment and return one row per cell.

    Steps:
        1. Draw per-well cell counts and gene counts with ``dist_gen``.
        2. Assign each gene (ids ``1 .. number_of_genes``) to wells, sampled
           according to power-law well weights. Genes that land in fewer than
           two wells are discarded.
        3. For every well, sample the genes present in it (weighted by
           power-law gene weights), then create cells, each carrying one of
           those genes.

    Args:
        plate_map (DataFrame): One row per well with columns ``plate_row_column``,
            ``plate_id``, ``row_id`` and ``column_id``. Rows are used
            positionally, so any index (e.g. after filtering) is accepted.
        number_of_genes (int): The total number of genes.
        active_gene_list (list): Gene ids that count as active.
        avg_genes_per_well (float): The average number of genes per well.
        sd_genes_per_well (float): The standard deviation of genes per well.
        avg_cells_per_well (float): The average number of cells per well.
        sd_cells_per_well (float): The standard deviation of cells per well.
        well_ineq_coeff (float): Power-law exponent for well weights.
        gene_ineq_coeff (float): Power-law exponent for gene weights.

    Returns:
        tuple: ``(cell_df, genes_per_well_df, wells_per_gene_df, df_ls)`` where
            - cell_df (DataFrame): One row per cell, with columns plate_row_column,
              plate_id, row_id, column_id, genes_in_well, gene_id and is_active.
            - genes_per_well_df (DataFrame): Sorted unique-gene counts per well, with a rank.
            - wells_per_gene_df (DataFrame): Sorted unique-well counts per gene, with a rank.
            - df_ls (list): [gene counts per well, well counts per gene, per-well
              Gini values, per-gene Gini values, gene weights array, well weights].
    """
    # Primary distributions: one draw per plate_map row.
    cpw, _ = dist_gen(avg_cells_per_well, sd_cells_per_well, plate_map)
    gpw, _ = dist_gen(avg_genes_per_well, sd_genes_per_well, plate_map)
    genes = list(range(1, number_of_genes + 1))

    gene_weights_array = generate_power_law_distribution(number_of_genes, gene_ineq_coeff)
    gene_weights = dict(zip(genes, gene_weights_array))
    well_weights = generate_power_law_distribution(len(plate_map), well_ineq_coeff)

    # Positional arrays, so the loops below do not depend on plate_map's index.
    well_ids = plate_map['plate_row_column'].to_numpy()
    plate_ids = plate_map['plate_id'].to_numpy()
    row_ids = plate_map['row_id'].to_numpy()
    column_ids = plate_map['column_id'].to_numpy()

    # NOTE(review): gpw has one entry per well but is indexed by gene here, so this
    # raises IndexError when number_of_genes > len(plate_map) — confirm intent.
    gene_to_well_mapping = {
        gene: np.random.choice(well_ids, size=int(gpw[gene - 1]), p=well_weights)
        for gene in genes
    }
    # Keep only genes that were placed in at least two wells.
    gene_to_well_mapping = {gene: wells for gene, wells in gene_to_well_mapping.items() if len(wells) >= 2}

    # Invert once into well -> genes present, in ascending gene order. This is
    # the order a per-well scan over gene_to_well_mapping would produce.
    genes_by_well = {}
    for gene, wells in gene_to_well_mapping.items():
        for well in set(wells.tolist()):
            genes_by_well.setdefault(well, []).append(gene)

    active_genes = set(active_gene_list)
    cells = []
    for position, well in enumerate(well_ids):
        # Drawn for every well, including empty ones, to keep the random stream fixed.
        ciw = random.choice(cpw)
        present_genes = genes_by_well.get(well, [])
        if not present_genes:
            continue
        present_gene_weights = np.array([gene_weights[gene] for gene in present_genes])
        present_gene_weights = present_gene_weights / np.sum(present_gene_weights)
        giw = np.random.choice(present_genes, int(gpw[position]), p=present_gene_weights)
        if len(giw) == 0:
            continue
        for _ in range(int(ciw)):
            gene_nr = random.choice(giw)
            cells.append({
                'plate_row_column': well,
                'plate_id': plate_ids[position],
                'row_id': row_ids[position],
                'column_id': column_ids[position],
                'genes_in_well': len(giw),
                'gene_id': gene_nr,
                'is_active': int(gene_nr in active_genes),
            })

    # Explicit columns keep the groupbys below valid even when no cells were generated.
    cell_columns = ['plate_row_column', 'plate_id', 'row_id', 'column_id', 'genes_in_well', 'gene_id', 'is_active']
    cell_df = pd.DataFrame(cells, columns=cell_columns).dropna()

    # Unique gene count per well and unique well count per gene, sorted ascending.
    gene_counts_per_well = cell_df.groupby('plate_row_column')['gene_id'].nunique().sort_values().tolist()
    well_counts_per_gene = cell_df.groupby('gene_id')['plate_row_column'].nunique().sort_values().tolist()

    genes_per_well_df = pd.DataFrame(gene_counts_per_well, columns=['genes_per_well'])
    genes_per_well_df['rank'] = range(1, len(genes_per_well_df) + 1)
    wells_per_gene_df = pd.DataFrame(well_counts_per_gene, columns=['wells_per_gene'])
    wells_per_gene_df['rank'] = range(1, len(wells_per_gene_df) + 1)

    gini_well = _gini_per_group(cell_df, 'plate_row_column', 'gene_id')
    gini_gene = _gini_per_group(cell_df, 'gene_id', 'plate_row_column')

    df_ls = [gene_counts_per_well, well_counts_per_gene, gini_well, gini_gene, gene_weights_array, well_weights]
    return cell_df, genes_per_well_df, wells_per_gene_df, df_ls
317
+
318
# Simulated classifier: every cell gets a score in (0, 1) drawn from one of two
# beta distributions, chosen by whether the cell is active (1) or inactive (0).
def classifier(positive_mean, positive_variance, negative_mean, negative_variance, df):
    """
    Attach a simulated classifier score to every row of ``df``.

    Active rows draw from a beta with the positive mean/variance, and inactive
    rows from a beta with the negative mean/variance. The beta parameters come
    from the method of moments. ``df`` is modified in place (a ``score``
    column is added) and also returned.

    Args:
        positive_mean (float): Mean score for active cells.
        positive_variance (float): Score variance for active cells.
        negative_mean (float): Mean score for inactive cells.
        negative_variance (float): Score variance for inactive cells.
        df (pandas.DataFrame): Must contain an ``is_active`` column.

    Returns:
        pandas.DataFrame: The same ``df``, now with a ``score`` column.
    """
    def _beta_params(mean, variance):
        # Method-of-moments conversion of (mean, variance) to beta(alpha, beta).
        alpha = mean * (mean * (1 - mean) / variance - 1)
        return alpha, alpha * (1 - mean) / mean

    a_pos, b_pos = _beta_params(positive_mean, positive_variance)
    a_neg, b_neg = _beta_params(negative_mean, negative_variance)

    def _draw(is_active):
        if is_active:
            return np.random.beta(a_pos, b_pos)
        return np.random.beta(a_neg, b_neg)

    df['score'] = df['is_active'].apply(_draw)
    return df
342
+
343
def compute_roc_auc(cell_scores):
    """
    ROC curve and its area for cell-level scores.

    Parameters:
        - cell_scores (DataFrame): Needs ``is_active`` (true label, 1 = positive)
          and ``score`` columns.

    Returns:
        - dict: ``threshold``, ``tpr`` and ``fpr`` arrays from ``roc_curve``, plus
          the scalar ``roc_auc``.
    """
    labels = cell_scores['is_active']
    scores = cell_scores['score']
    fpr, tpr, thresholds = roc_curve(labels, scores, pos_label=1)
    return {'threshold': thresholds, 'tpr': tpr, 'fpr': fpr, 'roc_auc': auc(fpr, tpr)}
359
+
360
def compute_precision_recall(cell_scores):
    """
    Compute precision, recall, F1 score, and PR AUC for a given set of cell scores.

    ``precision_recall_curve`` returns one more precision/recall point than
    thresholds. Point ``i`` corresponds to predicting positive when
    ``score >= thresholds[i]``. The final point (precision 1, recall 0) has no
    threshold. It is padded with ``+inf`` at the end, so every row of the
    result pairs a threshold with its own precision/recall.

    Parameters:
        - cell_scores (DataFrame): Needs ``is_active`` (true label) and ``score`` columns.

    Returns:
        - dict: Equal-length ``threshold``, ``precision``, ``recall`` and ``f1_score``
          arrays, plus the scalar ``pr_auc``. F1 is 0 where precision and recall
          are both 0.
    """
    precision, recall, thresholds = precision_recall_curve(cell_scores['is_active'], cell_scores['score'])
    # Pad at the end: the last point (recall 0) means "predict nothing positive".
    thresholds = np.append(thresholds, np.inf)
    denom = precision + recall
    f1_score = np.divide(2 * (precision * recall), denom, out=np.zeros_like(denom, dtype=float), where=denom > 0)
    pr_auc = auc(recall, precision)
    return {'threshold': thresholds, 'precision': precision, 'recall': recall, 'f1_score': f1_score, 'pr_auc': pr_auc}
376
+
377
def get_optimum_threshold(cell_pr_dict):
    """
    Return the threshold of the row with the highest F1 score.

    Parameters:
        cell_pr_dict (dict): Equal-length ``threshold`` and ``f1_score`` entries
            (extra keys are allowed). Ties go to the first maximal row, and NaN F1
            values are skipped by ``idxmax``.

    Returns:
        float: The threshold at maximal F1.
    """
    curve = pd.DataFrame(cell_pr_dict)
    best_row = curve['f1_score'].idxmax()
    return float(curve.loc[best_row, 'threshold'])
391
+
392
def update_scores_and_get_cm(cell_scores, optimum):
    """
    Binarize cell scores at ``optimum`` and build the confusion matrix.

    A new column is added to ``cell_scores`` in place. Its column *name* is the
    threshold value itself; it holds 1 where ``score >= optimum`` and 0
    otherwise.

    Args:
        cell_scores (DataFrame): Needs ``score`` and ``is_active`` columns.
        optimum (float): Decision threshold.

    Returns:
        tuple: ``(cell_scores, confusion_matrix)`` with the matrix computed
        from ``is_active`` vs the binarized column.
    """
    def _at_or_above(value):
        return 1 if value >= optimum else 0

    cell_scores[optimum] = cell_scores.score.map(_at_or_above)
    confusion = metrics.confusion_matrix(cell_scores.is_active, cell_scores[optimum])
    return cell_scores, confusion
406
+
407
def cell_level_roc_auc(cell_scores):
    """
    Cell-level ROC and precision-recall evaluation.

    Computes the ROC and PR curves and picks the F1-optimal threshold. It then
    binarizes the scores at that threshold (adding a column to
    ``cell_scores``) and builds the confusion matrix.

    Args:
        cell_scores (DataFrame): Needs ``is_active`` and ``score`` columns.

    Returns:
        tuple:
            cell_roc_dict_df (DataFrame): ROC curve (threshold, tpr, fpr) and roc_auc.
            cell_pr_dict_df (DataFrame): PR curve, F1, pr_auc and the chosen ``optimum``.
            cell_scores (DataFrame): Input frame with the binarized column added.
            cell_cm (array): Confusion matrix at the optimum threshold.
    """
    roc = compute_roc_auc(cell_scores)
    pr = compute_precision_recall(cell_scores)
    best_threshold = get_optimum_threshold(pr)
    cell_scores, cell_cm = update_scores_and_get_cm(cell_scores, best_threshold)
    pr['optimum'] = best_threshold
    return pd.DataFrame(roc), pd.DataFrame(pr), cell_scores, cell_cm
428
+
429
def generate_well_score(cell_scores):
    """
    Aggregate cell-level rows into one row per well.

    Args:
        cell_scores (DataFrame): Needs ``plate_row_column``, ``is_active`` and
            ``gene_id`` columns.

    Returns:
        DataFrame: Indexed by ``plate_row_column``, with
            - ``average_active_score``: fraction of active cells in the well,
            - ``gene_list``: sorted unique gene ids in the well (Python list),
            - ``score``: ``log10(average_active_score + 1)``.
    """
    def _sorted_unique_genes(gene_ids):
        return np.unique(gene_ids).tolist()

    per_well = cell_scores.groupby(['plate_row_column'])
    well_score = per_well.agg(
        average_active_score=('is_active', 'mean'),
        gene_list=('gene_id', _sorted_unique_genes),
    )
    # log10(x + 1) keeps wells without active cells at a score of exactly 0.
    well_score['score'] = np.log10(1 + well_score['average_active_score'])
    return well_score
446
+
447
def sequence_plates(well_score, number_of_genes, avg_reads_per_gene, sd_reads_per_gene, sequencing_error=0.01):
    """
    Simulate sequencing of the wells and compute per-well gene fractions.

    A pool of read counts (one per well) is drawn with ``dist_gen``. For every
    gene in a well's ``gene_list``, a read count is picked from that pool at
    random. With probability ``sequencing_error`` those reads are credited to
    a randomly drawn well (which can be the same well); otherwise they go to
    the well itself.

    Parameters:
        well_score (pd.DataFrame): Indexed by well, with a ``gene_list`` column.
        number_of_genes (int): Highest gene id; columns ``gene_0 .. gene_<number_of_genes>``
            are created.
        avg_reads_per_gene (float): Average number of reads per gene.
        sd_reads_per_gene (float): Standard deviation of reads per gene.
        sequencing_error (float, optional): Probability of mis-assigning a gene's
            reads. Defaults to 0.01.

    Returns:
        gene_fraction_map (pd.DataFrame): Per-well read fractions per gene (rows sum
            to 1; wells without reads are all 0).
        metadata (pd.DataFrame): Per well: ``genes_in_well`` (non-zero fractions),
            ``sum_fractions`` and ``sum_reads`` (total reads credited to the well
            after all mis-assignments).
    """
    reads, _ = dist_gen(avg_reads_per_gene, sd_reads_per_gene, well_score)
    gene_names = [f'gene_{v}' for v in range(number_of_genes + 1)]
    all_wells = well_score.index

    gene_counts_map = pd.DataFrame(np.zeros((len(all_wells), number_of_genes + 1)), columns=gene_names, index=all_wells)

    for well, row in well_score.iterrows():
        gene_list = row['gene_list']
        if not gene_list:
            continue
        for gene in gene_list:
            gene_count = int(random.choice(reads))
            # Decide whether this gene's reads end up in a random (possibly the same) well.
            if np.random.binomial(1, sequencing_error):
                target_well = np.random.choice(all_wells)
            else:
                target_well = well
            gene_counts_map.loc[target_well, f'gene_{int(gene)}'] += gene_count

    gene_fraction_map = gene_counts_map.div(gene_counts_map.sum(axis=1), axis=0).fillna(0)

    metadata = pd.DataFrame(index=well_score.index)
    metadata['genes_in_well'] = gene_fraction_map.astype(bool).sum(axis=1)
    metadata['sum_fractions'] = gene_fraction_map.sum(axis=1)
    # Computed after the loop so reads mis-assigned by later wells are included.
    metadata['sum_reads'] = gene_counts_map.sum(axis=1)

    return gene_fraction_map, metadata
496
+
497
def regression_roc_auc(results_df, active_gene_list, control_gene_list, alpha=0.05, optimal=False):
    """
    Score regression hits against the known active genes.

    A gene is called a hit when its coefficient lies outside +/-``cutoff``
    (derived from the control genes) and its p-value is <= ``alpha``. The
    resulting 0/1 hit score is evaluated against the true active labels with
    ROC and precision-recall curves and a confusion matrix.

    Parameters:
        results_df (DataFrame): Regression coefficient table with ``gene``, ``coef``
            and ``P>|t|`` columns; gene names look like ``gene_<id>``.
        active_gene_list (list): Ids of the truly active genes.
        control_gene_list (list): Ids of the control genes.
        alpha (float, optional): Maximum p-value for a hit. Default is 0.05.
        optimal (bool, optional): Binarize the hit score at the F1-optimal
            threshold instead of 0.5. Default is False.

    Returns:
        tuple:
            - results_df (DataFrame): Input with ``p``, ``active``, ``color``, ``size``,
              ``logp``, ``score`` and a binarized column named by the threshold used.
            - reg_roc_dict_df (DataFrame): ROC curve data and roc_auc.
            - reg_pr_dict_df (DataFrame): PR curve, F1 and pr_auc (thresholds aligned
              with their precision/recall point; the final point is padded with +inf).
            - reg_cm (ndarray): 2x2 confusion matrix (labels 0, 1).
            - sim_stats (DataFrame): One-row summary statistics.
    """
    results_df = results_df.rename(columns={"P>|t|": "p"})

    # Label genes: 1 for the simulated active genes, 0 otherwise.
    actives = {'gene_' + str(i) for i in active_gene_list}
    results_df['active'] = results_df['gene'].apply(lambda gene: 1 if gene in actives else 0)

    # Color category used for plotting: control, active or inactive.
    controls = {'gene_' + str(i) for i in control_gene_list}
    results_df['color'] = results_df['gene'].apply(
        lambda gene: 'control' if gene in controls else ('active' if gene in actives else 'inactive'))

    # Marker size follows the active flag; p-values are floored at 1e-4 so -log10(p) stays finite.
    results_df['size'] = results_df['active']
    results_df['p'] = results_df['p'].clip(lower=0.0001)
    results_df['logp'] = -np.log10(results_df['p'])

    # Hit cutoff from the randomly chosen control genes.
    control_coef = results_df.loc[results_df['color'] == 'control', 'coef']
    control_mean = control_coef.mean()
    control_var = control_coef.var()
    # NOTE(review): this adds 3 * variance; a conventional 3-sigma rule would use
    # the standard deviation — confirm which was intended.
    cutoff = abs(control_mean) + (3 * control_var)

    # Descriptive statistics for active genes.
    active_coef = results_df.loc[results_df['color'] == 'active', 'coef']
    active_mean = active_coef.mean()
    active_std = active_coef.std()
    active_var = active_coef.var()

    # Descriptive statistics for inactive genes.
    inactive_coef = results_df.loc[results_df['color'] == 'inactive', 'coef']
    inactive_mean = inactive_coef.mean()
    inactive_std = inactive_coef.std()
    inactive_var = inactive_coef.var()

    # 1 for hits (|coef| >= cutoff and p <= alpha), 0 otherwise.
    is_hit = ((results_df['coef'] >= cutoff) | (results_df['coef'] <= -cutoff)) & (results_df['p'] <= alpha)
    results_df['score'] = np.where(is_hit, 1, 0)

    # ROC of the hit score against the true labels.
    fpr, tpr, roc_thresholds = roc_curve(results_df['active'], results_df['score'])
    roc_auc = auc(fpr, tpr)
    reg_roc_dict_df = pd.DataFrame({'threshold': roc_thresholds, 'tpr': tpr, 'fpr': fpr, 'roc_auc': roc_auc})

    precision, recall, pr_thresholds = precision_recall_curve(results_df['active'], results_df['score'])
    # The last precision/recall point has no threshold; pad at the end so rows stay aligned.
    pr_thresholds = np.append(pr_thresholds, np.inf)
    denom = precision + recall
    f1_score = np.divide(2 * (precision * recall), denom, out=np.zeros_like(denom, dtype=float), where=denom > 0)
    pr_auc = auc(recall, precision)
    reg_pr_dict_df = pd.DataFrame({'threshold': pr_thresholds, 'precision': precision, 'recall': recall, 'f1_score': f1_score, 'pr_auc': pr_auc})

    # Threshold *value* at maximal F1 (idxmax alone only gives the row label).
    best_row = reg_pr_dict_df['f1_score'].idxmax()
    optimal_threshold = float(reg_pr_dict_df.loc[best_row, 'threshold'])

    decision_threshold = optimal_threshold if optimal else 0.5
    results_df[decision_threshold] = results_df.score.apply(lambda x: 1 if x >= decision_threshold else 0)
    # labels=[0, 1] guarantees a 2x2 matrix even if only one class is present.
    reg_cm = confusion_matrix(results_df.active, results_df[decision_threshold], labels=[0, 1])

    TN, FP, FN, TP = reg_cm.ravel()
    accuracy = (TP + TN) / (TP + FP + FN + TN)
    sim_stats = {'optimal_threshold': optimal_threshold,
                 'accuracy': accuracy,
                 'prauc': pr_auc,
                 'roc_auc': roc_auc,
                 'inactive_mean': inactive_mean,
                 'inactive_std': inactive_std,
                 'inactive_var': inactive_var,
                 'active_mean': active_mean,
                 'active_std': active_std,
                 'active_var': active_var,
                 'cutoff': cutoff,
                 'TP': TP,
                 'FP': FP,
                 'TN': TN,
                 'FN': FN}

    return results_df, reg_roc_dict_df, reg_pr_dict_df, reg_cm, pd.DataFrame([sim_stats])
597
+
598
def plot_histogram(data, x_label, ax, color, title, binwidth=0.01, log=False):
    """
    Draw a step-style density histogram of ``data[x_label]`` on ``ax``.

    Parameters:
        - data: DataFrame holding the column to plot.
        - x_label: Column name; also used as the x-axis label.
        - ax: Matplotlib axes to draw on.
        - color: Fill/line color.
        - title: Axes title.
        - binwidth: Histogram bin width (default 0.01).
        - log: If True, use a logarithmic y-axis.

    Returns:
        None
    """
    hist_kwargs = dict(
        data=data, x=x_label, ax=ax, color=color, binwidth=binwidth,
        kde=False, stat='density', legend=False, fill=True,
        element='step', palette='dark',
    )
    sns.histplot(**hist_kwargs)
    if log:
        ax.set_yscale('log')
    ax.set_title(title)
    ax.set_xlabel(x_label)
620
+
621
def plot_roc_pr(data, ax, title, x_label, y_label):
    """
    Plot a ROC or PR curve with a dashed diagonal reference line.

    Parameters:
        - data: Mapping/DataFrame with ``x_label`` and ``y_label`` columns.
        - ax: Matplotlib axes to draw on.
        - title: Axes title.
        - x_label: Column for the x values, also used as the x-axis label.
        - y_label: Column for the y values, also used as the y-axis label.
    """
    line_style = {'color': 'black', 'lw': 0.5}
    ax.plot(data[x_label], data[y_label], **line_style)
    ax.plot([0, 1], [0, 1], linestyle="--", label='random classifier', **line_style)
    ax.set_title(title)
    ax.set_ylabel(y_label)
    ax.set_xlabel(x_label)
    ax.legend(loc="lower right")
638
+
639
def plot_confusion_matrix(data, ax, title):
    """
    Draw a 2x2 confusion matrix as an annotated heatmap.

    Each cell shows its label (True Neg / False Pos / False Neg / True Pos),
    its count and its share of the total.

    Parameters:
        data (numpy.ndarray): 2x2 confusion matrix (rows = actual, columns = predicted).
        ax (matplotlib.axes.Axes): Axes to draw on.
        title (str): Axes title.

    Returns:
        None
    """
    labels = ['True Neg', 'False Pos', 'False Neg', 'True Pos']
    flat = data.flatten()
    counts = [f"{value:0.0f}" for value in flat]
    shares = [f"{value:.2%}" for value in flat / np.sum(data)]

    sns.heatmap(data, cmap='Blues', ax=ax)
    for r in range(data.shape[0]):
        for c in range(data.shape[1]):
            k = r * 2 + c  # flat index into the 2x2 annotation lists
            ax.text(c + 0.5, r + 0.5, f'{labels[k]}\n{counts[k]}\n{shares[k]}',
                    ha="center", va="center", color="black")

    ax.set_title(title)
    ax.set_xlabel('\nPredicted Values')
    ax.set_ylabel('Actual Values ')
    ax.xaxis.set_ticklabels(['False', 'True'])
    ax.yaxis.set_ticklabels(['False', 'True'])
666
+
667
+
668
def run_simulation(settings):
    """
    Run the full simulation pipeline for one settings dictionary.

    Steps: pick active/control genes, build the plate map (minus control
    columns 1, 2, 3, 23 and 24), simulate cells, score them, aggregate per
    well, simulate sequencing, regress well score on gene fractions, and
    evaluate the regression hits.

    Args:
        settings (dict): Simulation settings (number_of_genes, number_of_active_genes,
            number_of_control_genes, nr_plates, avg/sd genes/cells per well,
            well/gene inequality coefficients, classifier means/variances,
            avg/sd reads per gene, sequencing_error).

    Returns:
        tuple: ``(results, dists)`` where ``results`` is the list
            [cell_scores, cell_roc_dict_df, cell_pr_dict_df, cell_cm, well_score,
            gene_fraction_map, metadata, results_df, reg_roc_dict_df,
            reg_pr_dict_df, reg_cm, sim_stats, genes_per_well_df,
            wells_per_gene_df] and ``dists`` is the distribution list returned
            by run_experiment.
    """
    from io import StringIO

    # generate_gene_list returns ids 0..N-1 while run_experiment numbers genes
    # 1..N, so shift by one to make every active/control id a simulated gene.
    active_gene_list = [gene + 1 for gene in generate_gene_list(settings['number_of_active_genes'], settings['number_of_genes'])]
    control_gene_list = [gene + 1 for gene in generate_gene_list(settings['number_of_control_genes'], settings['number_of_genes'])]
    plate_map = generate_plate_map(settings['nr_plates'])

    # Exclude control columns 1, 2, 3, 23 and 24. generate_plate_map stores column
    # ids as plain digit strings ("1".."24"); 'c'-prefixed spellings are accepted too.
    control_columns = ['1', '2', '3', '23', '24', 'c1', 'c2', 'c3', 'c23', 'c24']
    plate_map = plate_map[~plate_map['column_id'].astype(str).isin(control_columns)]
    # Contiguous 0..n-1 index for downstream code that addresses rows by position.
    plate_map = plate_map.reset_index(drop=True)

    cell_level, genes_per_well_df, wells_per_gene_df, dists = run_experiment(
        plate_map, settings['number_of_genes'], active_gene_list,
        settings['avg_genes_per_well'], settings['sd_genes_per_well'],
        settings['avg_cells_per_well'], settings['sd_cells_per_well'],
        settings['well_ineq_coeff'], settings['gene_ineq_coeff'])
    cell_scores = classifier(settings['positive_mean'], settings['positive_variance'],
                             settings['negative_mean'], settings['negative_variance'], df=cell_level)
    cell_roc_dict_df, cell_pr_dict_df, cell_scores, cell_cm = cell_level_roc_auc(cell_scores)
    well_score = generate_well_score(cell_scores)
    gene_fraction_map, metadata = sequence_plates(
        well_score, settings['number_of_genes'], settings['avg_reads_per_gene'],
        settings['sd_reads_per_gene'], sequencing_error=settings['sequencing_error'])

    # OLS of the transformed well score on per-well gene fractions (plus intercept).
    x = sm.add_constant(gene_fraction_map)
    # NOTE(review): well_score['score'] is already log10(mean + 1) (see
    # generate_well_score), so this applies a second log transform — confirm intent.
    y = np.log10(well_score['score'] + 1)
    model = sm.OLS(y, x).fit()

    # Parse the coefficient table; wrapping in StringIO avoids pandas' deprecation
    # of passing literal HTML to read_html.
    coef_table_html = model.summary().tables[1].as_html()
    results_df = pd.read_html(StringIO(coef_table_html), header=0, index_col=0)[0]
    results_df = results_df.rename_axis("gene").reset_index()
    # Drop the first row (the intercept added by add_constant).
    results_df = results_df.iloc[1:, :]

    results_df, reg_roc_dict_df, reg_pr_dict_df, reg_cm, sim_stats = regression_roc_auc(
        results_df, active_gene_list, control_gene_list, alpha=0.05, optimal=False)
    return [cell_scores, cell_roc_dict_df, cell_pr_dict_df, cell_cm, well_score, gene_fraction_map, metadata, results_df, reg_roc_dict_df, reg_pr_dict_df, reg_cm, sim_stats, genes_per_well_df, wells_per_gene_df], dists
723
+
724
def vis_dists(dists, src, v, i):
    """
    Plot a histogram of each simulated distribution and save the figure as a PDF.

    Args:
        dists (list): Distributions produced by the simulation (the ``dists`` value
            returned by ``run_simulation``). Element k is drawn on panel k under the
            k-th name of ['genes/well', 'wells/gene', 'genes/well gini',
            'wells/gene gini', 'gene_weights', 'well_weights']; at most six elements
            are supported (a seventh raises IndexError).
        src (str): Output directory; the figure is written by ``save_plot`` to
            ``{src}/dists/{i}_figure.pdf``.
        v: Unused. Kept so existing callers that pass the variable name keep working.
        i (int): Simulation index used in the output file name.

    Returns:
        None
    """
    n_graphs = 6
    height_graphs = 4
    width_graphs = height_graphs * n_graphs
    names = ['genes/well', 'wells/gene', 'genes/well gini', 'wells/gene gini', 'gene_weights', 'well_weights']
    fig2, ax = plt.subplots(1, n_graphs, figsize=(width_graphs, height_graphs))
    try:
        for index, dist in enumerate(dists):
            name = names[index]
            temp = pd.DataFrame(dist, columns=[name])
            sns.histplot(data=temp, x=name, kde=False, binwidth=None, stat='count',
                         element="step", ax=ax[index], color='teal', log_scale=False)
        save_plot(fig2, src, 'dists', i)
    finally:
        # This runs once per simulation (see run_and_save); without closing, every
        # call would leave another open figure in the worker process.
        plt.close(fig2)
750
+
751
def visualize_all(output):
    """
    Draw a 13-panel summary figure for one simulation run, show it, and return it.

    Args:
        output (list): The 14-element result list returned by ``run_simulation``, in order:
            cell_scores, cell_roc_dict_df, cell_pr_dict_df, cell_cm, well_score,
            gene_fraction_map, metadata, results_df, reg_roc_dict_df, reg_pr_dict_df,
            reg_cm, sim_stats, genes_per_well_df, wells_per_gene_df.
            gene_fraction_map, metadata and sim_stats are not used here.
            Columns read: cell_scores['is_active', 'score'], well_score['score'],
            results_df['color', 'coef', 'logp', 'gene', 'std err', 'p'],
            genes_per_well_df['genes_per_well'], wells_per_gene_df['wells_per_gene'].

    Panels, left to right:
        0 genes/well histogram, 1 wells/gene histogram, 2 cell scores (active vs inactive),
        3 "Well scores", 4 cell ROC, 5 cell precision-recall, 6 cell confusion matrix,
        7 well score histogram, 8 regression coefficient vs logp, 9 ranked effect sizes
        with +/- std err band, 10 gene ROC, 11 gene precision-recall,
        12 regression confusion matrix.

    Returns:
        matplotlib.figure.Figure: The generated figure (also displayed via ``plt.show()``).
    """
    (cell_scores, cell_roc_dict_df, cell_pr_dict_df, cell_cm, well_score,
     _gene_fraction_map, _metadata, results_df, reg_roc_dict_df, reg_pr_dict_df,
     reg_cm, _sim_stats, genes_per_well_df, wells_per_gene_df) = (output[k] for k in range(14))

    significance_line = -np.log10(0.05)
    n_graphs = 13
    height_graphs = 4
    fig, ax = plt.subplots(1, n_graphs, figsize=(height_graphs * n_graphs, height_graphs))
    (ax_genes_per_well, ax_wells_per_gene, ax_cell_scores, ax_well_scores,
     ax_cell_roc, ax_cell_pr, ax_cell_cm, ax_well_score, ax_volcano,
     ax_effect_error, ax_gene_roc, ax_gene_pr, ax_gene_cm) = ax

    # Library-design distributions, titled with their Gini coefficient.
    genes_gini = gini(genes_per_well_df['genes_per_well'].tolist())
    plot_histogram(genes_per_well_df, "genes_per_well", ax_genes_per_well, 'slategray',
                   f'gene/well (gini = {genes_gini:.2f})', binwidth=None, log=False)
    wells_gini = gini(wells_per_gene_df['wells_per_gene'].tolist())
    plot_histogram(wells_per_gene_df, "wells_per_gene", ax_wells_per_gene, 'slategray',
                   f'well/gene (Gini = {wells_gini:.2f})', binwidth=None, log=False)

    active_cells = cell_scores[cell_scores['is_active'] == 1]
    inactive_cells = cell_scores[cell_scores['is_active'] == 0]
    # NOTE(review): the 'Well scores' panel draws exactly the same cell-level rows as
    # the 'Cell scores' panel; a per-well aggregate (e.g. mean per 'plate_row_column')
    # may have been intended — TODO confirm before changing.
    for target_ax, title in ((ax_cell_scores, 'Cell scores'), (ax_well_scores, 'Well scores')):
        plot_histogram(active_cells, "score", target_ax, 'slategray', title, binwidth=0.01, log=False)
        plot_histogram(inactive_cells, "score", target_ax, 'teal', title, binwidth=0.01, log=False)
        target_ax.set_xlim([0, 1])

    # Cell-level classifier performance.
    plot_roc_pr(cell_roc_dict_df, ax_cell_roc, 'ROC (Cell)', 'fpr', 'tpr')
    ax_cell_roc.plot([0, 1], [0, 1], color='black', lw=0.5, linestyle="--", label='random classifier')

    plot_roc_pr(cell_pr_dict_df, ax_cell_pr, 'Precision recall (Cell)', 'recall', 'precision')
    ax_cell_pr.set_ylim([-0.1, 1.1])
    ax_cell_pr.set_xlim([-0.1, 1.1])

    plot_confusion_matrix(cell_cm, ax_cell_cm, 'Confusion Matrix Cell')

    plot_histogram(well_score, "score", ax_well_score, 'teal', 'Well score', binwidth=0.005, log=False)
    ax_well_score.set_xlim([0, 1])

    # Coefficient vs logp scatter; the vertical cutoff is |mean| + 3 * std of the control coefficients.
    control_coefs = results_df.loc[results_df['color'] == 'control', 'coef']
    cutoff = abs(control_coefs.mean()) + 3 * control_coefs.std()
    for category, color in (('inactive', 'lightgrey'), ('control', 'black'), ('active', 'purple')):
        subset = results_df[results_df['color'] == category]
        ax_volcano.scatter(subset['coef'], subset['logp'], c=color, alpha=0.7, label=category)
    legend = ax_volcano.legend(title='', frameon=False, prop={'size': 10})
    ax_volcano.add_artist(legend)
    ax_volcano.axhline(significance_line, zorder=0, c='k', lw=0.5, ls='--')
    for x_pos in (-cutoff, cutoff):
        ax_volcano.axvline(x_pos, zorder=0, c='k', lw=0.5, ls='--')
    ax_volcano.set_title(f'Regression, threshold {cutoff:.3f}')
    ax_volcano.set_xlim([-1, 1.1])

    # Effect sizes sorted ascending (NaNs first) with a +/- |std err| band.
    ranked = results_df[['gene', 'coef', 'std err', 'p']].sort_values(
        by=['coef', 'p'], ascending=[True, False], na_position='first')
    ranked['rank'] = list(range(len(ranked)))
    spread = abs(ranked['std err'])
    ax_effect_error.plot(ranked['rank'], ranked['coef'], '-', color='black')
    ax_effect_error.fill_between(ranked['rank'], ranked['coef'] - spread, ranked['coef'] + spread,
                                 alpha=0.4, color='slategray')
    ax_effect_error.set_title('Effect score error')
    ax_effect_error.set_xlabel('rank')
    ax_effect_error.set_ylabel('Effect size')

    # Gene-level (regression) classification performance.
    plot_roc_pr(reg_roc_dict_df, ax_gene_roc, 'ROC (gene)', 'fpr', 'tpr')
    ax_gene_roc.legend(loc="lower right")
    plot_roc_pr(reg_pr_dict_df, ax_gene_pr, 'Precision recall (gene)', 'recall', 'precision')
    ax_gene_pr.legend(loc="lower right")
    plot_confusion_matrix(reg_cm, ax_gene_cm, 'Confusion Matrix Reg')

    for panel in ax:
        for side in ('top', 'right'):
            panel.spines[side].set_visible(False)

    plt.tight_layout()
    plt.show()
    return fig
914
+
915
def create_database(db_path):
    """
    Open an SQLite connection to ``db_path`` and close it again right away.

    ``sqlite3.connect`` creates the database file when it does not exist yet, so this
    is used to make sure the file is there. Errors raised while connecting are printed
    rather than propagated.

    Args:
        db_path (str): Location of the database file.

    Returns:
        None
    """
    try:
        connection = sqlite3.connect(db_path)
    except Exception as err:
        print(err)
        return
    connection.close()
934
+
935
def append_database(src, table, table_name):
    """
    Append a pandas DataFrame to a table in ``{src}/simulations.db``.

    The table is created on first use (``if_exists='append'``) and the DataFrame
    index is not written. The connection uses a long lock timeout (3600 s) so that
    parallel simulation workers writing to the same file wait instead of failing.

    Args:
        src (str): Directory containing (or to contain) ``simulations.db``.
        table (pandas.DataFrame): Rows to append.
        table_name (str): Name of the destination table.

    Returns:
        None. ``sqlite3.OperationalError`` (including failure to open the file) is
        printed, not raised; other exceptions propagate.
    """
    conn = None  # bound up-front so `finally` is safe when connect() itself fails
    try:
        conn = sqlite3.connect(f'{src}/simulations.db', timeout=3600)
        table.to_sql(table_name, conn, if_exists='append', index=False)
    except sqlite3.OperationalError as e:
        print("SQLite error:", e)
    finally:
        if conn is not None:
            conn.close()
955
+
956
def save_data(src, output, settings, save_all=False, i=0, variable='all'):
    """
    Persist one simulation's results to ``{src}/simulations.db``.

    Args:
        src (str): Output directory (created if missing); passed to ``append_database``.
        output (list): The 14-element result list returned by ``run_simulation``
            (cell_scores, ..., sim_stats, genes_per_well_df, wells_per_gene_df).
        settings (dict): Settings used for the simulation; converted to a one-row table.
        save_all (bool, optional): If False (default), append a single summary row to the
            ``simulations`` table: the settings row joined column-wise with sim_stats plus
            the two Gini coefficients, the current date, and the simulation number.
            If True, append the settings and each of the 14 outputs to their own tables.
        i (int, optional): Simulation number, stored in column ``variable_{variable}_sim_nr``.
        variable (str, optional): Name of the swept variable, used in that column name.

    Returns:
        None. Any exception is printed rather than raised, so a failed save does not
        abort the simulation run.
    """
    table_names = ['settings', 'cell_scores', 'cell_roc', 'cell_precision_recall', 'cell_confusion_matrix',
                   'well_score', 'gene_fraction_map', 'metadata', 'regression_results', 'regression_roc',
                   'regression_precision_recall', 'regression_confusion_matrix', 'sim_stats',
                   'genes_per_well', 'wells_per_gene']
    try:
        os.makedirs(src, exist_ok=True)

        settings_df = pd.DataFrame({key: [value] for key, value in settings.items()})
        tables = [settings_df] + list(output)

        if not save_all:
            # With the settings row prepended: 12 = sim_stats, 13 = genes_per_well, 14 = wells_per_gene.
            gini_genes_per_well = gini(tables[13]['genes_per_well'].tolist())
            gini_wells_per_gene = gini(tables[14]['wells_per_gene'].tolist())
            summary = pd.concat([tables[0], tables[12]], axis=1)
            summary['genes_per_well_gini'] = gini_genes_per_well
            summary['wells_per_gene_gini'] = gini_wells_per_gene
            summary['date'] = datetime.now()
            summary[f'variable_{variable}_sim_nr'] = i
            append_database(src, summary, 'simulations')
        else:
            for idx, table in enumerate(tables):
                name = table_names[idx]
                if name == 'well_score':
                    # SQLite cannot store list cells; convert on a copy so the caller's
                    # DataFrame is left untouched.
                    table = table.assign(gene_list=table['gene_list'].astype(str))
                if not isinstance(table, pd.DataFrame):
                    table = pd.DataFrame(table)
                append_database(src, table, name)
    except Exception as e:
        print(f"An error occurred while saving data: {e}")
1006
+
1007
def save_plot(fig, src, variable, i):
    """
    Write ``fig`` to ``{src}/{variable}/{i}_figure.pdf`` at 600 dpi.

    The ``{src}/{variable}`` directory is created when missing.

    Parameters:
    - fig: Matplotlib figure (anything with a compatible ``savefig`` method).
    - src: Base output directory.
    - variable: Sub-directory name, typically the swept variable.
    - i: Figure index used as the file name prefix.

    Returns:
        None
    """
    out_dir = f'{src}/{variable}'
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(f'{out_dir}/{i!s}_figure.pdf', dpi=600, format='pdf', bbox_inches='tight')
1024
+
1025
def run_and_save(i, settings, time_ls, total_sims):
    """
    Run one simulation, save its summary, and write its figures.

    Results are written below ``{settings['src']}/{yymmdd}/{settings['name']}``.

    Args:
        i (int): The simulation index.
        settings (dict): Simulation settings. Keys read here: 'random_seed', 'src',
            'plot', 'variable', 'name', and optionally 'vis_dists' (default True) to skip
            the distribution figure. ``settings['sim_time']`` is set to the elapsed time.
        time_ls (list): Shared list (a ``Manager().list()`` when called from
            ``run_multiple_simulations``); the elapsed time is appended on success.
        total_sims (int): Total number of simulations. Unused; kept for the caller's signature.

    Returns:
        tuple: ``(i, sim_time, None)``.
    """
    if settings['random_seed']:
        # Only the stdlib `random` module is reseeded, with a fixed value, so seeded
        # simulations share the same stream and can end up very similar.
        random.seed(42)
    src = settings['src']
    plot = settings['plot']
    v = settings['variable']
    start_time = time()
    date_string = datetime.now().strftime("%y%m%d")  # 'yymmdd'
    output, dists = run_simulation(settings)
    sim_time = time() - start_time
    settings['sim_time'] = sim_time
    src = os.path.join(f'{src}/{date_string}', settings['name'])
    save_data(src, output, settings, save_all=False, i=i, variable=v)
    # Distribution plots stay on by default; set settings['vis_dists'] = False to skip them.
    if settings.get('vis_dists', True):
        vis_dists(dists, src, v, i)
    if plot:
        fig = visualize_all(output)
        save_plot(fig, src, v, i)
        # Closing this figure is enough; calling plt.figure() here would open a new,
        # never-closed figure on every simulation.
        plt.close(fig)
        del fig
    del output, dists
    time_ls.append(sim_time)
    return i, sim_time, None
1069
+
1070
def generate_paramiters(settings):
    """
    Expand a settings dict of value lists into one settings dict per simulation.

    Every combination of the swept keys below is produced (the last key varies
    fastest), and each combination is repeated ``settings['replicates']`` times.
    Swept keys, each holding an iterable of values: avg_genes_per_well,
    avg_cells_per_well, positive_mean, avg_reads_per_gene, sequencing_error,
    well_ineq_coeff, gene_ineq_coeff, nr_plates, number_of_genes, number_of_active_genes.

    Derived values per combination:
        sd_genes_per_well = int(avg_genes_per_well / 2)
        sd_cells_per_well = int(avg_cells_per_well / 2)
        negative_mean = 1 - positive_mean
        positive_variance = negative_variance = (1 - positive_mean) / 2
        avg_reads_per_gene = int(avg_reads_per_gene); sd_reads_per_gene = int(avg_reads_per_gene / 2)

    Args:
        settings (dict): Simulation settings; not modified.

    Returns:
        list[dict]: One independent settings dict per simulation. Entries never share
        an object, so mutating one (``run_and_save`` sets 'sim_time') leaves the others intact.
    """
    sweep = itertools.product(
        settings['avg_genes_per_well'],
        settings['avg_cells_per_well'],
        settings['positive_mean'],
        settings['avg_reads_per_gene'],
        settings['sequencing_error'],
        settings['well_ineq_coeff'],
        settings['gene_ineq_coeff'],
        settings['nr_plates'],
        settings['number_of_genes'],
        settings['number_of_active_genes'],
    )
    sim_ls = []
    for (avg_genes_per_well, avg_cells_per_well, positive_mean, avg_reads_per_gene,
         sequencing_error, well_ineq_coeff, gene_ineq_coeff, nr_plates,
         number_of_genes, number_of_active_genes) in sweep:
        replicates = settings['replicates']
        # Keys are assigned in a fixed order because save_data builds the settings
        # table's columns from this dict's key order.
        sett = settings.copy()
        sett['avg_genes_per_well'] = avg_genes_per_well
        sett['sd_genes_per_well'] = int(avg_genes_per_well / 2)
        sett['avg_cells_per_well'] = avg_cells_per_well
        sett['sd_cells_per_well'] = int(avg_cells_per_well / 2)
        sett['positive_mean'] = positive_mean
        sett['negative_mean'] = 1 - positive_mean
        sett['positive_variance'] = (1 - positive_mean) / 2
        sett['negative_variance'] = (1 - positive_mean) / 2
        sett['avg_reads_per_gene'] = int(avg_reads_per_gene)
        sett['sd_reads_per_gene'] = int(avg_reads_per_gene / 2)
        sett['sequencing_error'] = sequencing_error
        sett['well_ineq_coeff'] = well_ineq_coeff
        sett['gene_ineq_coeff'] = gene_ineq_coeff
        sett['nr_plates'] = nr_plates
        sett['number_of_genes'] = number_of_genes
        sett['number_of_active_genes'] = number_of_active_genes
        # A fresh copy per replicate, so no two entries alias the same dict.
        sim_ls.extend(sett.copy() for _ in range(replicates))
    return sim_ls
1114
+
1115
# Variant of generate_paramiters with fixed classifier parameters: negative_mean is
# always 0.2 and both variances are 0.13; only positive_mean is swept.
def generate_paramiters_single(settings):
    """
    Expand a settings dict of value lists into one settings dict per simulation,
    using fixed negative-class and variance parameters.

    Every combination of the swept keys below is produced (the last key varies
    fastest), and each combination is repeated ``settings['replicates']`` times.
    Swept keys, each holding an iterable of values: avg_genes_per_well,
    avg_cells_per_well, positive_mean, avg_reads_per_gene, sequencing_error,
    well_ineq_coeff, gene_ineq_coeff, nr_plates, number_of_genes, number_of_active_genes.

    Derived / fixed values per combination:
        sd_genes_per_well = int(avg_genes_per_well / 2)
        sd_cells_per_well = int(avg_cells_per_well / 2)
        negative_mean = 0.2; positive_variance = negative_variance = 0.13
        avg_reads_per_gene = int(avg_reads_per_gene); sd_reads_per_gene = int(avg_reads_per_gene / 2)

    Args:
        settings (dict): Simulation settings; not modified.

    Returns:
        list[dict]: One independent settings dict per simulation (entries never share an object).
    """
    sweep = itertools.product(
        settings['avg_genes_per_well'],
        settings['avg_cells_per_well'],
        settings['positive_mean'],
        settings['avg_reads_per_gene'],
        settings['sequencing_error'],
        settings['well_ineq_coeff'],
        settings['gene_ineq_coeff'],
        settings['nr_plates'],
        settings['number_of_genes'],
        settings['number_of_active_genes'],
    )
    sim_ls = []
    for (avg_genes_per_well, avg_cells_per_well, positive_mean, avg_reads_per_gene,
         sequencing_error, well_ineq_coeff, gene_ineq_coeff, nr_plates,
         number_of_genes, number_of_active_genes) in sweep:
        replicates = settings['replicates']
        # Fixed assignment order: save_data derives the settings table's columns from it.
        sett = settings.copy()
        sett['avg_genes_per_well'] = avg_genes_per_well
        sett['sd_genes_per_well'] = int(avg_genes_per_well / 2)
        sett['avg_cells_per_well'] = avg_cells_per_well
        sett['sd_cells_per_well'] = int(avg_cells_per_well / 2)
        sett['positive_mean'] = positive_mean
        sett['negative_mean'] = 0.2
        sett['positive_variance'] = 0.13
        sett['negative_variance'] = 0.13
        sett['avg_reads_per_gene'] = int(avg_reads_per_gene)
        sett['sd_reads_per_gene'] = int(avg_reads_per_gene / 2)
        sett['sequencing_error'] = sequencing_error
        sett['well_ineq_coeff'] = well_ineq_coeff
        sett['gene_ineq_coeff'] = gene_ineq_coeff
        sett['nr_plates'] = nr_plates
        sett['number_of_genes'] = number_of_genes
        sett['number_of_active_genes'] = number_of_active_genes
        # A fresh copy per replicate, so no two entries alias the same dict.
        sim_ls.extend(sett.copy() for _ in range(replicates))
    return sim_ls
1160
+
1161
def run_multiple_simulations(settings):
    """
    Run every parameter combination from ``generate_paramiters`` in a process pool.

    Each simulation is executed by ``run_and_save`` in a worker. While the pool runs,
    a one-line progress report (done/total, mean time per simulation, estimated
    minutes remaining) is rewritten in place on stdout.

    Args:
        settings (dict): Simulation settings. ``settings['max_workers']`` sets the pool
            size; when falsy, ``cpu_count() - 4`` is used, but never fewer than 1.

    Returns:
        None. An exception raised in a worker is re-raised by ``result.get()``.
    """
    sim_ls = generate_paramiters(settings)
    print(f'Running {len(sim_ls)} simulations. Standard deviations for each variable are variable / 2')

    # On machines with <= 4 cores cpu_count() - 4 is <= 0, which Pool() rejects.
    max_workers = settings['max_workers'] or max(1, cpu_count() - 4)
    total_sims = len(sim_ls)
    with Manager() as manager:
        time_ls = manager.list()
        jobs = [(index, sim_settings, time_ls, total_sims) for index, sim_settings in enumerate(sim_ls)]
        with Pool(max_workers) as pool:
            result = pool.starmap_async(run_and_save, jobs)
            while not result.ready():
                sleep(0.01)
                # Take one snapshot so the count and the mean describe the same data.
                finished = time_ls[:]
                sims_processed = len(finished)
                average_time = np.mean(finished) if finished else 0
                time_left = (((total_sims - sims_processed) * average_time) / max_workers) / 60
                print(f'Progress: {sims_processed}/{total_sims} Time/simulation {average_time:.3f}sec Time Remaining {time_left:.3f} min.', end='\r', flush=True)
            result.get()