spacr 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +37 -0
- spacr/__main__.py +15 -0
- spacr/annotate_app.py +495 -0
- spacr/cli.py +203 -0
- spacr/core.py +2250 -0
- spacr/gui_mask_app.py +247 -0
- spacr/gui_measure_app.py +214 -0
- spacr/gui_utils.py +488 -0
- spacr/io.py +2271 -0
- spacr/logger.py +20 -0
- spacr/mask_app.py +818 -0
- spacr/measure.py +1014 -0
- spacr/old_code.py +104 -0
- spacr/plot.py +1273 -0
- spacr/sim.py +1187 -0
- spacr/timelapse.py +576 -0
- spacr/train.py +494 -0
- spacr/umap.py +689 -0
- spacr/utils.py +2726 -0
- spacr/version.py +19 -0
- spacr-0.0.1.dist-info/LICENSE +21 -0
- spacr-0.0.1.dist-info/METADATA +64 -0
- spacr-0.0.1.dist-info/RECORD +26 -0
- spacr-0.0.1.dist-info/WHEEL +5 -0
- spacr-0.0.1.dist-info/entry_points.txt +5 -0
- spacr-0.0.1.dist-info/top_level.txt +1 -0
spacr/sim.py
ADDED
@@ -0,0 +1,1187 @@
|
|
1
|
+
|
2
|
+
import os, gc, random, warnings, traceback, itertools, matplotlib, sqlite3
|
3
|
+
import time as tm
|
4
|
+
from time import time, sleep
|
5
|
+
from datetime import datetime
|
6
|
+
import numpy as np
|
7
|
+
import pandas as pd
|
8
|
+
import matplotlib.pyplot as plt
|
9
|
+
import seaborn as sns
|
10
|
+
import sklearn.metrics as metrics
|
11
|
+
from sklearn.metrics import roc_curve, auc, roc_auc_score, confusion_matrix, precision_recall_curve
|
12
|
+
import statsmodels.api as sm
|
13
|
+
from multiprocessing import cpu_count, Value, Array, Lock, Pool, Manager
|
14
|
+
|
15
|
+
from .logger import log_function_call
|
16
|
+
|
17
|
+
warnings.filterwarnings("ignore")
|
18
|
+
warnings.filterwarnings("ignore", category=RuntimeWarning) # Ignore RuntimeWarning
|
19
|
+
|
20
|
+
def generate_gene_list(number_of_genes, number_of_all_genes):
    """
    Randomly pick gene identifiers from a pool of integer ids.

    The pool is ``0 .. number_of_all_genes - 1``. It is shuffled with the
    stdlib ``random`` module and the first ``number_of_genes`` ids are kept.

    NOTE(review): ids start at 0, whereas run_experiment numbers its genes
    from 1, so id 0 can never match a simulated gene — confirm intended.

    Args:
        number_of_genes (int): How many ids to return (fewer if the pool is smaller).
        number_of_all_genes (int): Size of the id pool.

    Returns:
        list: The selected ids, in shuffled order.
    """
    candidate_ids = [*range(number_of_all_genes)]
    random.shuffle(candidate_ids)
    return candidate_ids[:number_of_genes]
|
35
|
+
|
36
|
+
# plate_map is a table with a row for each well, containing well metadata: plate_id, row_id, and column_id
|
37
|
+
def generate_plate_map(nr_plates, nr_rows=16, nr_columns=24):
    """
    Generate a plate map with one row per well.

    Each well is identified by a ``"<plate>_<row>_<column>"`` string (all
    1-based). The defaults describe a 384-well plate (16 rows x 24 columns).

    Parameters:
    nr_plates (int): The number of plates to generate the map for.
    nr_rows (int, optional): Rows per plate. Defaults to 16.
    nr_columns (int, optional): Columns per plate. Defaults to 24.

    Returns:
    pandas.DataFrame: Columns ``plate_row_column``, ``plate_id``, ``row_id``
    and ``column_id``; the id columns hold strings (e.g. ``"3"``). When there
    are no wells an empty frame with these columns is returned.
    """
    wells = [
        (f"{plate}_{row}_{column}", str(plate), str(row), str(column))
        for plate in range(1, nr_plates + 1)
        for row in range(1, nr_rows + 1)
        for column in range(1, nr_columns + 1)
    ]
    # Build the id columns directly instead of splitting the combined id
    # afterwards: unpacking zip(*[]) fails when there are no wells.
    return pd.DataFrame(wells, columns=['plate_row_column', 'plate_id', 'row_id', 'column_id'])
|
51
|
+
|
52
|
+
def gini_coefficient(x):
    """
    Return the Gini coefficient of ``x``.

    Uses the mean-absolute-difference form
    ``sum_ij |x_i - x_j| / (2 * n**2 * mean(x))``. An n x n matrix of
    pairwise differences is built, so time and memory are O(n**2).

    Parameters:
    x (array-like): Values (intended to be 1-D).

    Returns:
    float: Gini coefficient.
    """
    n = len(x)
    pairwise_gaps = np.abs(np.subtract.outer(x, x))
    return np.sum(pairwise_gaps) / (2 * n ** 2 * np.mean(x))
|
65
|
+
|
66
|
+
def gini(x):
    """
    Calculate the Gini coefficient as half the relative mean absolute difference.

    NOTE: ``gini`` is defined a second time further down in this module; that
    later definition replaces this one when the module is imported, so callers
    of ``gini`` get the later (rank-based) version.

    Parameters:
    x (array-like): Input array of values.

    Returns:
    float: The Gini coefficient.

    Notes:
    An n x n matrix of pairwise differences is built, so time and memory are
    O(n**2) in the length of x; avoid passing very large samples.
    """
    pairwise = np.abs(np.subtract.outer(x, x))
    relative_mad = pairwise.mean() / np.mean(x)
    return 0.5 * relative_mad
|
87
|
+
|
88
|
+
def gini_gene_well(x):
    """
    Calculate the Gini coefficient of a distribution of values.

    In this module it is applied to per-well gene counts and per-gene well
    counts. 0 means all values are equal; larger values mean more inequality.

    The sum of ``|x_i - x_j|`` over all unordered pairs (i < j) is divided by
    ``n**2 * mean(x)``. This equals the usual
    ``sum_ij |x_i - x_j| / (2 * n**2 * mean(x))`` over ordered pairs.

    Parameters:
    x (array-like): 1-D values; a list or a NumPy array.

    Returns:
    float: The Gini coefficient, or NaN for empty input.
    """
    # Convert first: the elementwise ``xi - values[i:]`` below fails on a plain list.
    values = np.asarray(x)
    n = len(values)
    if n == 0:
        return float('nan')
    total = 0
    for i, xi in enumerate(values[:-1], 1):
        # Compare each element only with the elements after it (upper triangle).
        total += np.sum(np.abs(xi - values[i:]))
    return total / (n ** 2 * np.mean(values))
|
106
|
+
|
107
|
+
def gini(x):
    """
    Calculate the Gini coefficient from the ranks of the values.

    Values are ranked in descending order (the largest value gets rank 0).
    The coefficient is ``1 - (2 * sum(rank_i * x_i) + sum(x)) / (n * sum(x))``.

    Parameters:
    x (array-like): 1-D input values; all values are treated equally.

    Returns:
    float: The Gini coefficient.

    References:
    - Based on bottom eq: http://www.statsdirect.com/help/content/image/stat0206_wmf.gif
    - From: http://www.statsdirect.com/help/default.htm#nonparametric_methods/gini.htm
    """
    values = np.array(x, dtype=np.float64)
    total = values.sum()
    descending_rank = np.argsort(np.argsort(-values))
    rank_weighted = 2 * (descending_rank * values).sum()
    return 1 - (rank_weighted + total) / (len(values) * total)
|
127
|
+
|
128
|
+
def dist_gen(mean, sd, df):
    """
    Draw one count per element of ``df`` from a gamma-Poisson mixture.

    For each element a rate is drawn from a gamma distribution with
    shape ``(mean / sd) ** 2`` and scale ``sd ** 2 / mean``, which gives the
    requested mean and standard deviation. Each rate is then used as the mean
    of a Poisson draw. NumPy's global random state is used.

    Parameters:
    mean (float): Mean of the gamma distribution.
    sd (float): Standard deviation of the gamma distribution.
    df (sized): Only ``len(df)`` is used; it sets the number of draws.

    Returns:
    tuple: (array of Poisson counts, number of draws).
    """
    n_draws = len(df)
    gamma_shape = (mean / sd) ** 2
    gamma_scale = sd ** 2 / mean
    rates = np.random.gamma(gamma_shape, gamma_scale, size=n_draws)
    counts = np.random.poisson(rates)
    return counts, n_draws
|
146
|
+
|
147
|
+
def generate_gene_weights(positive_mean, positive_variance, df):
    """
    Draw one weight per element of ``df`` from a beta distribution.

    The beta parameters come from the method of moments, using the requested
    mean m and variance v:
    ``alpha = m * (m * (1 - m) / v - 1)`` and ``beta = alpha * (1 - m) / m``.
    alpha is positive only when v < m * (1 - m).

    Parameters:
    - positive_mean (float): Mean of the beta distribution.
    - positive_variance (float): Variance of the beta distribution.
    - df (sized): Only ``len(df)`` is used; it sets the number of weights.

    Returns:
    - weights (numpy.ndarray): ``len(df)`` beta-distributed weights.
    """
    m = positive_mean
    v = positive_variance
    alpha = m * (m * (1 - m) / v - 1)
    beta = alpha * (1 - m) / m
    return np.random.beta(alpha, beta, len(df))
|
164
|
+
|
165
|
+
def normalize_array(arr):
    """
    Min-max scale an array to the range [0, 1].

    Parameters:
    arr (array-like): Values to scale.

    Returns:
    numpy.ndarray: Float array with the same shape as ``arr``. The minimum
    maps to 0 and the maximum to 1. A constant array (zero range) maps to all
    zeros instead of NaN, and empty input returns an empty float array.
    """
    values = np.asarray(arr, dtype=float)
    if values.size == 0:
        # np.min/np.max raise on empty input.
        return values
    min_value = values.min()
    value_range = values.max() - min_value
    if value_range == 0:
        # Avoid 0/0 -> NaN when every element is identical.
        return np.zeros_like(values)
    return (values - min_value) / value_range
|
180
|
+
|
181
|
+
def generate_power_law_distribution(num_elements, coeff):
    """
    Generate normalised power-law weights.

    Element k (1-based) gets weight ``k ** -coeff``. The weights are then
    divided by their sum so that they add up to 1.

    Parameters:
    - num_elements (int): The number of elements in the distribution.
    - coeff (float or int): The power-law exponent. For positive values, a
      larger exponent puts more weight on the first elements.

    Returns:
    - normalized_distribution (ndarray): float weights summing to 1.
    """
    # Use a float base: NumPy refuses to raise an integer array to a negative
    # integer power, so integer coefficients such as 1 used to fail.
    ranks = np.arange(1, num_elements + 1, dtype=np.float64)
    weights = ranks ** -coeff
    return weights / np.sum(weights)
|
196
|
+
|
197
|
+
# distribution generator function
|
198
|
+
def power_law_dist_gen(df, avg, well_ineq_coeff):
    """
    Draw one power-law-derived value per element of ``df``.

    First, ``generate_power_law_distribution`` builds a normalised power-law
    weight vector of length ``len(df)``. Then ``len(df)`` values are sampled
    from that vector uniformly with replacement (NumPy global random state),
    and the samples are multiplied by ``avg``.

    Parameters:
    - df: sized object; only ``len(df)`` is used.
    - avg (float): Scale factor applied to the sampled weights.
    - well_ineq_coeff (float): Power-law exponent.

    Returns:
    - dist (ndarray): ``len(df)`` scaled samples.
    """
    n_wells = len(df)
    weights = generate_power_law_distribution(n_wells, well_ineq_coeff)
    sampled = np.random.choice(weights, n_wells)
    return sampled * avg
|
218
|
+
|
219
|
+
# plates is a table with one row for each cell in the experiment, with columns [plate_id, row_id, column_id, gene_id, is_active]
|
220
|
+
def run_experiment(plate_map, number_of_genes, active_gene_list, avg_genes_per_well, sd_genes_per_well, avg_cells_per_well, sd_cells_per_well, well_ineq_coeff, gene_ineq_coeff):
    """
    Run a simulation experiment.

    The simulation proceeds in three steps:
    1. Each gene is placed into randomly drawn wells, using power-law well weights.
    2. For each well, genes are drawn from those present, using power-law gene weights.
    3. The well is filled with cells, each carrying one of the drawn genes.

    Args:
        plate_map (DataFrame): One row per well with columns 'plate_row_column',
            'plate_id', 'row_id' and 'column_id'. Rows are accessed by position,
            so a filtered frame (non-contiguous index) is accepted.
        number_of_genes (int): Genes are numbered 1..number_of_genes.
            NOTE(review): generate_gene_list returns ids starting at 0, so an
            id of 0 in active_gene_list never matches — confirm intended.
        active_gene_list (list): Gene ids whose cells get is_active = 1.
        avg_genes_per_well (float): Mean of the genes-per-well distribution.
        sd_genes_per_well (float): Standard deviation of genes per well.
        avg_cells_per_well (float): Mean of the cells-per-well distribution.
        sd_cells_per_well (float): Standard deviation of cells per well.
        well_ineq_coeff (float): Power-law exponent for the well weights.
        gene_ineq_coeff (float): Power-law exponent for the gene weights.

    Returns:
        tuple: A tuple containing the following:
            - cell_df (DataFrame): One row per cell with columns plate_row_column,
              plate_id, row_id, column_id, genes_in_well, gene_id, is_active.
            - genes_per_well_df (DataFrame): Sorted distinct-gene counts per well
              with a 1-based 'rank' column.
            - wells_per_gene_df (DataFrame): Sorted distinct-well counts per gene
              with a 1-based 'rank' column.
            - df_ls (list): [gene counts per well, well counts per gene, Gini
              coefficients per well, Gini coefficients per gene, gene weights
              array, well weights].
    """
    # Primary distributions: one draw per well each.
    cpw, _ = dist_gen(avg_cells_per_well, sd_cells_per_well, plate_map)
    gpw, _ = dist_gen(avg_genes_per_well, sd_genes_per_well, plate_map)
    genes = list(range(1, number_of_genes + 1))

    gene_weight_values = generate_power_law_distribution(number_of_genes, gene_ineq_coeff)
    gene_weights = dict(zip(genes, gene_weight_values))
    gene_weights_array = np.array(list(gene_weights.values()))

    well_weights = generate_power_law_distribution(len(plate_map), well_ineq_coeff)

    # Positional column views, so gpw[i] and the well metadata always line up.
    well_ids = plate_map['plate_row_column'].to_numpy()
    plate_ids = plate_map['plate_id'].to_numpy()
    row_ids = plate_map['row_id'].to_numpy()
    column_ids = plate_map['column_id'].to_numpy()

    # For each gene, draw the wells it is placed in according to well_weights.
    # NOTE: gpw holds one draw per well; when there are more genes than wells
    # the draws are reused cyclically (indexing gpw[gene - 1] raised IndexError).
    gene_to_well_mapping = {}
    for gene in genes:
        n_wells_for_gene = int(gpw[(gene - 1) % len(gpw)])
        gene_to_well_mapping[gene] = np.random.choice(well_ids, size=n_wells_for_gene, p=well_weights)

    # Keep only genes that were placed at least twice.
    gene_to_well_mapping = {gene: wells for gene, wells in gene_to_well_mapping.items() if len(wells) >= 2}

    # Invert once (well -> genes, in ascending gene order) instead of scanning
    # every gene for every well.
    well_to_genes = {}
    for gene, wells in gene_to_well_mapping.items():
        for well in dict.fromkeys(wells):
            well_to_genes.setdefault(well, []).append(gene)

    active_genes = set(active_gene_list)
    cells = []
    for i, well_id in enumerate(well_ids):
        # Number of cells: drawn at random from the whole cpw sample (not cpw[i]).
        ciw = random.choice(cpw)
        present_genes = well_to_genes.get(well_id, [])
        if not present_genes:
            continue
        present_gene_weights = np.array([gene_weights[gene] for gene in present_genes])
        present_gene_weights /= np.sum(present_gene_weights)
        giw = np.random.choice(present_genes, int(gpw[i]), p=present_gene_weights)
        if len(giw) == 0:
            continue
        for _ in range(int(ciw)):
            gene_nr = random.choice(giw)
            cells.append({
                'plate_row_column': well_id,
                'plate_id': plate_ids[i],
                'row_id': row_ids[i],
                'column_id': column_ids[i],
                'genes_in_well': len(giw),
                'gene_id': gene_nr,
                'is_active': int(gene_nr in active_genes),
            })

    cell_columns = ['plate_row_column', 'plate_id', 'row_id', 'column_id', 'genes_in_well', 'gene_id', 'is_active']
    # Naming the columns keeps them present even when no cells were generated.
    cell_df = pd.DataFrame(cells, columns=cell_columns)
    cell_df = cell_df.dropna()

    # Distinct gene counts per well and distinct well counts per gene.
    gene_counts_per_well = cell_df.groupby('plate_row_column')['gene_id'].nunique().sort_values().tolist()
    well_counts_per_gene = cell_df.groupby('gene_id')['plate_row_column'].nunique().sort_values().tolist()

    genes_per_well_df = pd.DataFrame(gene_counts_per_well, columns=['genes_per_well'])
    genes_per_well_df['rank'] = range(1, len(genes_per_well_df) + 1)
    wells_per_gene_df = pd.DataFrame(well_counts_per_gene, columns=['wells_per_gene'])
    wells_per_gene_df['rank'] = range(1, len(wells_per_gene_df) + 1)

    gini_well = _gini_per_group(cell_df, 'plate_row_column', 'gene_id')
    gini_gene = _gini_per_group(cell_df, 'gene_id', 'plate_row_column')

    df_ls = [gene_counts_per_well, well_counts_per_gene, gini_well, gini_gene, gene_weights_array, well_weights]
    return cell_df, genes_per_well_df, wells_per_gene_df, df_ls


def _gini_per_group(cell_df, group_column, value_column):
    """
    Return one Gini coefficient per ``group_column`` group.

    Each coefficient is computed over the occurrence counts of
    ``value_column`` values within the group. Groups are ordered by their first
    appearance in ``cell_df``.
    """
    return np.array([
        gini_gene_well(group.value_counts().to_numpy())
        for _, group in cell_df.groupby(group_column, sort=False)[value_column]
    ])
|
317
|
+
|
318
|
+
# classifier is a function that takes a cell state (active=1/inactive=0) and produces a score in [0, 1]
|
319
|
+
# For the input cell, it checks if it is active or inactive, and then samples from an appropriate beta distribution to give a score
|
320
|
+
def classifier(positive_mean, positive_variance, negative_mean, negative_variance, df):
    """
    Assign each row a random classifier score drawn from a beta distribution.

    Rows with a truthy 'is_active' draw from the "positive" beta distribution;
    all other rows draw from the "negative" one. Each beta distribution is
    parameterised by the method of moments from its mean and variance. Draws
    use NumPy's global random state, one per row, in row order.

    Args:
        positive_mean (float): The mean of the positive distribution.
        positive_variance (float): The variance of the positive distribution.
        negative_mean (float): The mean of the negative distribution.
        negative_variance (float): The variance of the negative distribution.
        df (pandas.DataFrame): Must contain an 'is_active' column.
            It is modified in place.

    Returns:
        pandas.DataFrame: The same ``df`` with a 'score' column added.
    """
    def _beta_params(mean, variance):
        alpha = mean * (mean * (1 - mean) / variance - 1)
        return alpha, alpha * (1 - mean) / mean

    pos_alpha, pos_beta = _beta_params(positive_mean, positive_variance)
    neg_alpha, neg_beta = _beta_params(negative_mean, negative_variance)

    def _draw(is_active):
        if is_active:
            return np.random.beta(pos_alpha, pos_beta)
        return np.random.beta(neg_alpha, neg_beta)

    df['score'] = df['is_active'].apply(_draw)
    return df
|
342
|
+
|
343
|
+
def compute_roc_auc(cell_scores):
    """
    Compute the ROC curve and its area for cell-level scores.

    Parameters:
    - cell_scores (DataFrame): Must have columns 'is_active' (true labels,
      positive label 1) and 'score'.

    Returns:
    - cell_roc_dict (dict): 'threshold', 'tpr' and 'fpr' arrays from
      ``roc_curve``, plus the scalar 'roc_auc'.
    """
    labels = cell_scores['is_active']
    scores = cell_scores['score']
    fpr, tpr, thresholds = roc_curve(labels, scores, pos_label=1)
    return {'threshold': thresholds, 'tpr': tpr, 'fpr': fpr, 'roc_auc': auc(fpr, tpr)}
|
359
|
+
|
360
|
+
def compute_precision_recall(cell_scores):
    """
    Compute precision, recall, F1 score, and PR AUC for a given set of cell scores.

    ``precision_recall_curve`` returns one more precision/recall value than
    thresholds. Element i belongs to ``thresholds[i]``, and the final element
    (precision 1, recall 0) has no threshold. ``np.inf`` is appended as that
    final threshold, so every row pairs a threshold with its own
    precision/recall. Prepending a value would shift every threshold by one row.

    Parameters:
    - cell_scores (DataFrame): Columns 'is_active' (0/1 labels) and 'score'.

    Returns:
    - cell_pr_dict (dict): Equal-length 'threshold', 'precision', 'recall'
      and 'f1_score' arrays, plus the scalar 'pr_auc'. F1 is 0 where
      precision and recall are both 0.
    """
    precision, recall, thresholds = precision_recall_curve(cell_scores['is_active'], cell_scores['score'])
    thresholds = np.append(thresholds, np.inf)
    denominator = precision + recall
    # Define F1 as 0 instead of NaN where precision + recall == 0.
    f1_score = np.divide(2 * (precision * recall), denominator,
                         out=np.zeros_like(denominator, dtype=float), where=denominator > 0)
    pr_auc = auc(recall, precision)
    return {'threshold': thresholds, 'precision': precision, 'recall': recall, 'f1_score': f1_score, 'pr_auc': pr_auc}
|
376
|
+
|
377
|
+
def get_optimum_threshold(cell_pr_dict):
    """
    Return the threshold whose F1 score is highest.

    Parameters:
    cell_pr_dict (dict): Must contain equal-length 'threshold' and 'f1_score'
        sequences. Scalar entries are broadcast when the dict is turned into
        a DataFrame. On ties, the first maximum wins.

    Returns:
    float: The threshold value on the row with the maximal f1_score.
    """
    table = pd.DataFrame(cell_pr_dict)
    best_row = table['f1_score'].idxmax()
    return float(table.at[best_row, 'threshold'])
|
391
|
+
|
392
|
+
def update_scores_and_get_cm(cell_scores, optimum):
    """
    Binarise the scores at ``optimum`` and build the confusion matrix.

    A new column, named by the threshold value itself (``cell_scores[optimum]``),
    is added in place. It holds 1 where ``score >= optimum`` and 0 otherwise.

    Args:
        cell_scores (DataFrame): Must have 'score' and 'is_active' columns.
        optimum (float): Classification threshold.

    Returns:
        tuple: (the same ``cell_scores`` DataFrame with the new column, the
        confusion matrix of 'is_active' against the binarised predictions).
    """
    def _at_or_above(value):
        return 1 if value >= optimum else 0

    cell_scores[optimum] = cell_scores['score'].map(_at_or_above)
    confusion = metrics.confusion_matrix(cell_scores['is_active'], cell_scores[optimum])
    return cell_scores, confusion
|
406
|
+
|
407
|
+
def cell_level_roc_auc(cell_scores):
    """
    Compute cell-level ROC and precision-recall metrics and the confusion matrix.

    The F1-optimal threshold from the precision-recall curve is used to
    binarise the scores, and it is also stored in an 'optimum' column of the
    precision-recall table.

    Args:
        cell_scores (DataFrame): Must have 'is_active' and 'score' columns.

    Returns:
        cell_roc_dict_df (DataFrame): ROC curve (threshold, tpr, fpr) and roc_auc.
        cell_pr_dict_df (DataFrame): Precision-recall curve, F1, pr_auc and optimum.
        cell_scores (DataFrame): Input frame with the binarised-prediction column added.
        cell_cm (array): Confusion matrix for the cell-level classification.
    """
    roc_metrics = compute_roc_auc(cell_scores)
    pr_metrics = compute_precision_recall(cell_scores)
    optimum = get_optimum_threshold(pr_metrics)
    cell_scores, cell_cm = update_scores_and_get_cm(cell_scores, optimum)
    pr_metrics['optimum'] = optimum
    return pd.DataFrame(roc_metrics), pd.DataFrame(pr_metrics), cell_scores, cell_cm
|
428
|
+
|
429
|
+
def generate_well_score(cell_scores):
    """
    Aggregate cell-level data into one row per well.

    Args:
        cell_scores (DataFrame): Must have 'plate_row_column', 'is_active' and
            'gene_id' columns.

    Returns:
        DataFrame: Indexed by 'plate_row_column', with columns:
            - 'average_active_score': mean of 'is_active' in the well.
            - 'gene_list': sorted unique gene ids in the well.
            - 'score': log10(average_active_score + 1).
    """
    def _sorted_unique(gene_ids):
        return np.unique(gene_ids).tolist()

    per_well = cell_scores.groupby(['plate_row_column'])
    well_score = per_well.agg(
        average_active_score=('is_active', 'mean'),
        gene_list=('gene_id', _sorted_unique))
    well_score['score'] = np.log10(1 + well_score['average_active_score'])
    return well_score
|
446
|
+
|
447
|
+
def sequence_plates(well_score, number_of_genes, avg_reads_per_gene, sd_reads_per_gene, sequencing_error=0.01):
    """
    Simulate sequencing of the wells and compute per-well gene fractions.

    For every gene in each well's 'gene_list', a read count is chosen at random
    (stdlib ``random``) from a gamma-Poisson sample of size ``len(well_score)``.
    With probability ``sequencing_error``, those reads are credited to a well
    chosen at random from all wells (possibly the same well) instead of the
    gene's own well.

    Parameters:
    well_score (pd.DataFrame): Indexed by well, with a 'gene_list' column of gene ids.
    number_of_genes (int): Highest gene id. Columns gene_0 .. gene_<number_of_genes>
        are created.
    avg_reads_per_gene (float): Average number of reads per gene.
    sd_reads_per_gene (float): Standard deviation of reads per gene.
    sequencing_error (float, optional): Probability of misassigning a gene's
        reads. Defaults to 0.01.

    Returns:
    gene_fraction_map (pd.DataFrame): Per-well read fraction of each gene
        (0 for wells without reads).
    metadata (pd.DataFrame): Per well: 'genes_in_well' (number of genes with a
        non-zero fraction), 'sum_fractions', and 'sum_reads' (total reads
        credited to the well).
    """
    reads, _ = dist_gen(avg_reads_per_gene, sd_reads_per_gene, well_score)
    gene_names = [f'gene_{v}' for v in range(number_of_genes + 1)]
    all_wells = well_score.index

    gene_counts_map = pd.DataFrame(np.zeros((len(all_wells), number_of_genes + 1)), columns=gene_names, index=all_wells)

    for well, row in well_score.iterrows():
        gene_list = row['gene_list']
        if not gene_list:
            continue
        for gene in gene_list:
            gene_count = int(random.choice(reads))
            # Decide whether this gene's reads are misassigned to a random well.
            error = np.random.binomial(1, sequencing_error)
            target_well = np.random.choice(all_wells) if error else well
            gene_counts_map.loc[target_well, f'gene_{int(gene)}'] += gene_count

    gene_fraction_map = gene_counts_map.div(gene_counts_map.sum(axis=1), axis=0)
    gene_fraction_map = gene_fraction_map.fillna(0)

    metadata = pd.DataFrame(index=well_score.index)
    metadata['genes_in_well'] = gene_fraction_map.astype(bool).sum(axis=1)
    metadata['sum_fractions'] = gene_fraction_map.sum(axis=1)
    # Take totals from the final counts. A per-iteration snapshot would miss
    # reads misassigned into a well by wells processed after it.
    metadata['sum_reads'] = gene_counts_map.sum(axis=1)

    return gene_fraction_map, metadata
|
496
|
+
|
497
|
+
#metadata['sum_reads'] = metadata['sum_fractions'].div(metadata['genes_in_well'])
|
498
|
+
def regression_roc_auc(results_df, active_gene_list, control_gene_list, alpha = 0.05, optimal=False):
    """
    Calculate regression ROC AUC and other statistics.

    A gene is called a hit (score 1) when |coef| is at least a cutoff derived
    from the control genes and its p-value is <= ``alpha``. Hits are then
    compared with the known active genes.

    Parameters:
    results_df (DataFrame): Regression results with columns 'gene', 'coef' and
        'P>|t|' (renamed to 'p'). The input frame is not modified.
    active_gene_list (list): List of active gene IDs.
    control_gene_list (list): List of control gene IDs.
    alpha (float, optional): Significance level for determining hits. Default is 0.05.
    optimal (bool, optional): If True, classify with the F1-optimal threshold from
        the precision-recall curve; otherwise use 0.5. Default is False.

    Returns:
    tuple: A tuple containing the following:
        - results_df (DataFrame): Copy with added columns 'active', 'color', 'size',
          'logp', 'score' and the binarised prediction column (named by its threshold).
        - reg_roc_dict_df (DataFrame): Regression ROC curve data.
        - reg_pr_dict_df (DataFrame): Precision-recall curve data; each threshold is
          aligned with its own precision/recall (the last row's threshold is inf).
        - reg_cm (ndarray): 2 x 2 confusion matrix (labels 0, 1).
        - sim_stats (DataFrame): One-row DataFrame of simulation statistics;
          'optimal_threshold' is the F1-optimal threshold value.
    """
    results_df = results_df.rename(columns={"P>|t|": "p"})

    # Assign active genes a value of 1 and all other genes a value of 0.
    actives_set = {'gene_' + str(i) for i in active_gene_list}
    results_df['active'] = results_df['gene'].apply(lambda x: 1 if x in actives_set else 0)

    # Category column used to colour control, active and inactive genes.
    controls_set = {'gene_' + str(i) for i in control_gene_list}
    results_df['color'] = results_df['gene'].apply(
        lambda x: 'control' if x in controls_set else ('active' if x in actives_set else 'inactive'))

    # Marker size column; clip p-values before -log10 so p == 0 stays finite.
    results_df['size'] = results_df['active']
    results_df['p'] = results_df['p'].clip(lower=0.0001)
    results_df['logp'] = -np.log10(results_df['p'])

    # Cutoff for hits, derived from the randomly chosen 'control' genes.
    control_df = results_df[results_df['color'] == 'control']
    control_mean = control_df['coef'].mean()
    control_var = control_df['coef'].var()
    # NOTE(review): this uses 3 x variance rather than 3 x standard deviation;
    # confirm which spread measure is intended.
    cutoff = abs(control_mean) + (3 * control_var)

    # Descriptive statistics for active genes.
    active_df = results_df[results_df['color'] == 'active']
    active_mean = active_df['coef'].mean()
    active_std = active_df['coef'].std()
    active_var = active_df['coef'].var()

    # Descriptive statistics for inactive genes.
    inactive_df = results_df[results_df['color'] == 'inactive']
    inactive_mean = inactive_df['coef'].mean()
    inactive_std = inactive_df['coef'].std()
    inactive_var = inactive_df['coef'].var()

    # Score column: 1 for hits, 0 for non-hits.
    results_df['score'] = np.where(((results_df['coef'] >= cutoff) | (results_df['coef'] <= -cutoff)) & (results_df['p'] <= alpha), 1, 0)

    # Regression ROC based on the control-derived cutoff.
    fpr, tpr, roc_thresholds = roc_curve(results_df['active'], results_df['score'])
    roc_auc = auc(fpr, tpr)
    reg_roc_dict_df = pd.DataFrame({'threshold': roc_thresholds, 'tpr': tpr, 'fpr': fpr, 'roc_auc': roc_auc})

    precision, recall, pr_thresholds = precision_recall_curve(results_df['active'], results_df['score'])
    # precision/recall have one more entry than thresholds. The last entry
    # (precision 1, recall 0) has no threshold, so append inf at the end;
    # prepending would shift every threshold one row away from its metrics.
    pr_thresholds = np.append(pr_thresholds, np.inf)
    denominator = precision + recall
    f1_score = np.divide(2 * (precision * recall), denominator,
                         out=np.zeros_like(denominator, dtype=float), where=denominator > 0)
    pr_auc = auc(recall, precision)
    reg_pr_dict_df = pd.DataFrame({'threshold': pr_thresholds, 'precision': precision, 'recall': recall, 'f1_score': f1_score, 'pr_auc': pr_auc})

    # The threshold VALUE on the best-F1 row (idxmax alone is a row position).
    optimal_threshold = float(reg_pr_dict_df.loc[reg_pr_dict_df['f1_score'].idxmax(), 'threshold'])
    if optimal:
        results_df[optimal_threshold] = results_df.score.apply(lambda x: 1 if x >= optimal_threshold else 0)
        predicted = results_df[optimal_threshold]
    else:
        results_df[0.5] = results_df.score.apply(lambda x: 1 if x >= 0.5 else 0)
        predicted = results_df[0.5]

    # labels=[0, 1] keeps the matrix 2 x 2 even when only one class occurs.
    reg_cm = confusion_matrix(results_df.active, predicted, labels=[0, 1])
    TN, FP, FN, TP = reg_cm.ravel()

    accuracy = (TP + TN) / (TP + FP + FN + TN)
    sim_stats = {'optimal_threshold': optimal_threshold,
                 'accuracy': accuracy,
                 'prauc': pr_auc,
                 'roc_auc': roc_auc,
                 'inactive_mean': inactive_mean,
                 'inactive_std': inactive_std,
                 'inactive_var': inactive_var,
                 'active_mean': active_mean,
                 'active_std': active_std,
                 'active_var': active_var,
                 'cutoff': cutoff,
                 'TP': TP,
                 'FP': FP,
                 'TN': TN,
                 'FN': FN}

    return results_df, reg_roc_dict_df, reg_pr_dict_df, reg_cm, pd.DataFrame([sim_stats])
|
597
|
+
|
598
|
+
def plot_histogram(data, x_label, ax, color, title, binwidth=0.01, log=False):
    """
    Draw a step-style density histogram of one column onto ``ax``.

    Parameters:
    - data: Data source passed to ``sns.histplot`` (e.g. a DataFrame).
    - x_label: Column to plot; also used as the x-axis label.
    - ax: The matplotlib axis object to plot on.
    - color: Colour of the histogram.
    - title: The title of the plot.
    - binwidth: The width of each histogram bin (default 0.01).
    - log: If True, use a logarithmic scale for the y-axis.

    NOTE(review): no ``hue`` is passed, so ``palette='dark'`` presumably has
    no effect — confirm.

    Returns:
    None
    """
    histplot_options = dict(
        data=data, x=x_label, ax=ax, color=color, binwidth=binwidth,
        kde=False, stat='density', legend=False, fill=True, element='step',
        palette='dark',
    )
    sns.histplot(**histplot_options)
    if log:
        ax.set_yscale('log')
    ax.set_title(title)
    ax.set_xlabel(x_label)
|
620
|
+
|
621
|
+
def plot_roc_pr(data, ax, title, x_label, y_label):
    """
    Plot a ROC or precision-recall curve with a dashed diagonal reference line.

    Parameters:
    - data: Mapping/DataFrame with columns ``x_label`` and ``y_label``.
    - ax: The matplotlib axes object to plot on.
    - title: The title of the plot.
    - x_label: Column plotted on the x-axis; also the x-axis label.
    - y_label: Column plotted on the y-axis; also the y-axis label.

    NOTE(review): the dashed (0, 0)-(1, 1) line is labelled 'random classifier',
    which is the usual chance line for ROC curves; for precision-recall curves
    it is presumably not a meaningful baseline — confirm.
    """
    line_style = {'color': 'black', 'lw': 0.5}
    ax.plot(data[x_label], data[y_label], **line_style)
    ax.plot([0, 1], [0, 1], linestyle="--", label='random classifier', **line_style)
    ax.set_title(title)
    ax.set_ylabel(y_label)
    ax.set_xlabel(x_label)
    ax.legend(loc="lower right")
|
638
|
+
|
639
|
+
def plot_confusion_matrix(data, ax, title):
    """
    Draw a 2 x 2 confusion matrix as an annotated heatmap.

    Each cell is labelled with its name (True Neg, False Pos, False Neg,
    True Pos, in row-major order), its count, and its share of the total.

    Parameters:
    data (numpy.ndarray): The confusion matrix; the labelling assumes 2 x 2.
    ax (matplotlib.axes.Axes): The axes object to plot the heatmap on.
    title (str): The title of the plot.

    Returns:
    None
    """
    cell_names = ['True Neg', 'False Pos', 'False Neg', 'True Pos']
    flat_counts = data.flatten()
    count_text = ["{0:0.0f}".format(count) for count in flat_counts]
    share_text = ["{0:.2%}".format(share) for share in flat_counts / np.sum(data)]

    sns.heatmap(data, cmap='Blues', ax=ax)
    n_rows = data.shape[0]
    n_cols = data.shape[1]
    for row in range(n_rows):
        for col in range(n_cols):
            cell = row * 2 + col
            label = f'{cell_names[cell]}\n{count_text[cell]}\n{share_text[cell]}'
            ax.text(col + 0.5, row + 0.5, label, ha="center", va="center", color="black")

    ax.set_title(title)
    ax.set_xlabel('\nPredicted Values')
    ax.set_ylabel('Actual Values ')
    ax.xaxis.set_ticklabels(['False', 'True'])
    ax.yaxis.set_ticklabels(['False', 'True'])
|
666
|
+
|
667
|
+
|
668
|
+
def run_simulation(settings):
    """
    Run one end-to-end pooled-screen simulation based on the given settings.

    Pipeline: pick active/control genes -> build the plate map (dropping
    columns c1, c2, c3, c23, c24) -> simulate the experiment -> score cells
    with the simulated classifier -> aggregate to well scores -> simulate
    sequencing -> regress log10(well score + 1) on per-well gene fractions
    with OLS -> evaluate the regression against the known active genes.

    Args:
        settings (dict): Simulation settings. Keys read here:
            number_of_active_genes, number_of_control_genes, number_of_genes,
            nr_plates, avg_genes_per_well, sd_genes_per_well,
            avg_cells_per_well, sd_cells_per_well, well_ineq_coeff,
            gene_ineq_coeff, positive_mean, positive_variance, negative_mean,
            negative_variance, avg_reads_per_gene, sd_reads_per_gene,
            sequencing_error.

    Returns:
        tuple: ``(results, dists)`` where ``results`` is the list
            [cell_scores, cell_roc_dict_df, cell_pr_dict_df, cell_cm,
             well_score, gene_fraction_map, metadata, results_df,
             reg_roc_dict_df, reg_pr_dict_df, reg_cm, sim_stats,
             genes_per_well_df, wells_per_gene_df]
            and ``dists`` is the list of distributions returned by
            ``run_experiment``.
    """
    active_gene_list = generate_gene_list(settings['number_of_active_genes'], settings['number_of_genes'])
    control_gene_list = generate_gene_list(settings['number_of_control_genes'], settings['number_of_genes'])
    plate_map = generate_plate_map(settings['nr_plates'])

    # Exclude the columns reserved for controls; only the remaining wells are simulated.
    plate_map = plate_map[~plate_map['column_id'].isin(['c1', 'c2', 'c3', 'c23', 'c24'])]

    cell_level, genes_per_well_df, wells_per_gene_df, dists = run_experiment(
        plate_map, settings['number_of_genes'], active_gene_list,
        settings['avg_genes_per_well'], settings['sd_genes_per_well'],
        settings['avg_cells_per_well'], settings['sd_cells_per_well'],
        settings['well_ineq_coeff'], settings['gene_ineq_coeff'])
    cell_scores = classifier(settings['positive_mean'], settings['positive_variance'],
                             settings['negative_mean'], settings['negative_variance'], df=cell_level)
    cell_roc_dict_df, cell_pr_dict_df, cell_scores, cell_cm = cell_level_roc_auc(cell_scores)
    well_score = generate_well_score(cell_scores)
    gene_fraction_map, metadata = sequence_plates(
        well_score, settings['number_of_genes'], settings['avg_reads_per_gene'],
        settings['sd_reads_per_gene'], sequencing_error=settings['sequencing_error'])

    # OLS: log-transformed well score explained by the gene fractions in each well.
    x = sm.add_constant(gene_fraction_map)
    y = np.log10(well_score['score'] + 1)
    model = sm.OLS(y, x).fit()
    results_df = _ols_coefficient_table(model)

    results_df, reg_roc_dict_df, reg_pr_dict_df, reg_cm, sim_stats = regression_roc_auc(
        results_df, active_gene_list, control_gene_list, alpha=0.05, optimal=False)

    return [cell_scores, cell_roc_dict_df, cell_pr_dict_df, cell_cm, well_score, gene_fraction_map,
            metadata, results_df, reg_roc_dict_df, reg_pr_dict_df, reg_cm, sim_stats,
            genes_per_well_df, wells_per_gene_df], dists


def _ols_coefficient_table(model):
    """
    Build the per-regressor coefficient table of a fitted statsmodels OLS model.

    The column names match statsmodels' ``summary().tables[1]`` layout
    ('coef', 'std err', 't', 'P>|t|', '[0.025', '0.975]'), so code keyed on
    them keeps working. The values, however, come straight from the fitted
    results at full precision.

    The summary table prints values with a fixed number of decimals (three
    for p-values), so any p below 0.0005 reads back as exactly 0.000. Parsing
    the table back also needed ``pd.read_html`` and an HTML parser backend.

    Args:
        model: A fitted statsmodels regression results object.

    Returns:
        DataFrame: Columns 'gene' (regressor name as str) plus the six columns
            above; one row per regressor, excluding the intercept ('const').
    """
    conf_int = np.asarray(model.conf_int())  # default alpha=0.05, the same bounds the summary shows
    table = pd.DataFrame({
        'gene': [str(name) for name in model.model.exog_names],
        'coef': np.asarray(model.params, dtype=float),
        'std err': np.asarray(model.bse, dtype=float),
        't': np.asarray(model.tvalues, dtype=float),
        'P>|t|': np.asarray(model.pvalues, dtype=float),
        '[0.025': conf_int[:, 0],
        '0.975]': conf_int[:, 1],
    })
    # Drop the intercept by name rather than by position. sm.add_constant skips
    # adding it when a constant column already exists, and a positional drop
    # would then discard a real gene.
    return table[table['gene'] != 'const']
|
723
|
+
|
724
|
+
def vis_dists(dists, src, v, i):
    """
    Plot the simulated distributions as step histograms and save them as a PDF.

    Args:
        dists (sequence): Up to six 1-D collections, plotted in this order as
            'genes/well', 'wells/gene', 'genes/well gini', 'wells/gene gini',
            'gene_weights', 'well_weights'.
        src (str): Directory under which ``save_plot`` writes the figure
            (into its 'dists' sub-folder).
        v: Name of the swept variable. Currently unused; kept for call
            compatibility.
        i (int): Simulation index, used in the output file name.

    Returns:
        None
    """
    names = ['genes/well', 'wells/gene', 'genes/well gini', 'wells/gene gini', 'gene_weights', 'well_weights']
    n_graphs = 6
    height_graphs = 4
    fig2, ax = plt.subplots(1, n_graphs, figsize=(height_graphs * n_graphs, height_graphs))
    try:
        for index, dist in enumerate(dists):
            name = names[index]
            temp = pd.DataFrame(dist, columns=[name])
            sns.histplot(data=temp, x=name, kde=False, binwidth=None, stat='count', element="step",
                         ax=ax[index], color='teal', log_scale=False)
        save_plot(fig2, src, 'dists', i)
    finally:
        # Release the figure. This runs once per simulation, and unclosed
        # pyplot figures accumulate in memory.
        plt.close(fig2)
    return
|
750
|
+
|
751
|
+
def visualize_all(output):
    """
    Draw a 13-panel summary figure of one simulation run.

    Panels, left to right:
     1. genes/well histogram (with Gini)
     2. wells/gene histogram (with Gini)
     3. cell scores, active vs inactive
     4. 'Well scores', active vs inactive (see NOTE below)
     5. ROC for cell classification
     6. precision-recall for cell classification
     7. cell confusion matrix
     8. well score histogram
     9. regression volcano-style scatter (coef vs logp)
    10. ranked effect sizes with a std-err band
    11. ROC for gene classification
    12. precision-recall for gene classification
    13. regression confusion matrix

    Args:
        output (list): The results list returned by ``run_simulation``:
            [cell_scores, cell_roc_dict_df, cell_pr_dict_df, cell_cm,
             well_score, gene_fraction_map, metadata, results_df,
             reg_roc_dict_df, reg_pr_dict_df, reg_cm, sim_stats,
             genes_per_well_df, wells_per_gene_df].
            ``cell_scores`` needs 'is_active' and 'score' columns.
            ``results_df`` needs 'gene', 'coef', 'std err', 'p', 'logp' and
            'color' columns, with 'color' taking the values inactive, control
            or active. These are presumably added by ``regression_roc_auc``
            (TODO confirm).

    Returns:
        fig (matplotlib.figure.Figure): The generated figure object.
    """
    cell_scores = output[0]
    cell_roc_dict_df = output[1]
    cell_pr_dict_df = output[2]
    cell_cm = output[3]
    well_score = output[4]
    results_df = output[7]
    reg_roc_dict_df = output[8]
    reg_pr_dict_df = output[9]
    reg_cm = output[10]
    genes_per_well_df = output[12]
    wells_per_gene_df = output[13]

    hline = -np.log10(0.05)  # significance line at p = 0.05
    n_graphs = 13
    height_graphs = 4
    fig, ax = plt.subplots(1, n_graphs, figsize=(height_graphs * n_graphs, height_graphs))
    n = 0

    # Genes per well.
    gini_genes_per_well = gini(genes_per_well_df['genes_per_well'].tolist())
    plot_histogram(genes_per_well_df, "genes_per_well", ax[n], 'slategray',
                   f'gene/well (gini = {gini_genes_per_well:.2f})', binwidth=None, log=False)
    n += 1

    # Wells per gene.
    gini_wells_per_gene = gini(wells_per_gene_df['wells_per_gene'].tolist())
    plot_histogram(wells_per_gene_df, "wells_per_gene", ax[n], 'slategray',
                   f'well/gene (Gini = {gini_wells_per_gene:.2f})', binwidth=None, log=False)
    n += 1

    # Cell classification score, active vs inactive cells.
    active_distribution = cell_scores[cell_scores['is_active'] == 1]
    inactive_distribution = cell_scores[cell_scores['is_active'] == 0]
    plot_histogram(active_distribution, "score", ax[n], 'slategray', 'Cell scores', binwidth=0.01, log=False)
    plot_histogram(inactive_distribution, "score", ax[n], 'teal', 'Cell scores', binwidth=0.01, log=False)
    ax[n].set_xlim([0, 1])
    n += 1

    # NOTE(review): titled 'Well scores', but this plots the same cell-level
    # distributions as the previous panel. The per-well averaging (groupby
    # 'plate_row_column') was commented out in the original code.
    plot_histogram(active_distribution, "score", ax[n], 'slategray', 'Well scores', binwidth=0.01, log=False)
    plot_histogram(inactive_distribution, "score", ax[n], 'teal', 'Well scores', binwidth=0.01, log=False)
    ax[n].set_xlim([0, 1])
    n += 1

    # ROC (cell classification). plot_roc_pr already draws the random-classifier diagonal.
    plot_roc_pr(cell_roc_dict_df, ax[n], 'ROC (Cell)', 'fpr', 'tpr')
    n += 1

    # Precision-recall (cell classification).
    plot_roc_pr(cell_pr_dict_df, ax[n], 'Precision recall (Cell)', 'recall', 'precision')
    ax[n].set_ylim([-0.1, 1.1])
    ax[n].set_xlim([-0.1, 1.1])
    n += 1

    # Confusion matrix (cells).
    plot_confusion_matrix(cell_cm, ax[n], 'Confusion Matrix Cell')
    n += 1

    # Well score.
    plot_histogram(well_score, "score", ax[n], 'teal', 'Well score', binwidth=0.005, log=False)
    ax[n].set_xlim([0, 1])
    n += 1

    # Regression scatter. The vertical cutoff is |mean| + 3 standard
    # deviations of the control-gene coefficients.
    control_df = results_df[results_df['color'] == 'control']
    control_mean = control_df['coef'].mean()
    control_sd = control_df['coef'].std()
    cutoff = abs(control_mean) + (3 * control_sd)
    categories = ['inactive', 'control', 'active']
    colors = ['lightgrey', 'black', 'purple']
    for category, color in zip(categories, colors):
        category_df = results_df[results_df['color'] == category]
        ax[n].scatter(category_df['coef'], category_df['logp'], c=color, alpha=0.7, label=category)
    ax[n].legend(title='', frameon=False, prop={'size': 10})
    ax[n].axhline(hline, zorder=0, c='k', lw=0.5, ls='--')
    ax[n].axvline(-cutoff, zorder=0, c='k', lw=0.5, ls='--')
    ax[n].axvline(cutoff, zorder=0, c='k', lw=0.5, ls='--')
    ax[n].set_title(f'Regression, threshold {cutoff:.3f}')
    ax[n].set_xlim([-1, 1.1])
    n += 1

    # Effect sizes ranked ascending, with a +/- std err band.
    err_df = results_df[['gene', 'coef', 'std err', 'p']].sort_values(
        by=['coef', 'p'], ascending=[True, False], na_position='first').copy()
    err_df['rank'] = np.arange(len(err_df))
    ax[n].plot(err_df['rank'], err_df['coef'], '-', color='black')
    ax[n].fill_between(err_df['rank'], err_df['coef'] - abs(err_df['std err']),
                       err_df['coef'] + abs(err_df['std err']), alpha=0.4, color='slategray')
    ax[n].set_title('Effect score error')
    ax[n].set_xlabel('rank')
    ax[n].set_ylabel('Effect size')
    n += 1

    # ROC (gene classification). plot_roc_pr already places the legend.
    plot_roc_pr(reg_roc_dict_df, ax[n], 'ROC (gene)', 'fpr', 'tpr')
    n += 1

    # Precision-recall (gene classification).
    plot_roc_pr(reg_pr_dict_df, ax[n], 'Precision recall (gene)', 'recall', 'precision')
    n += 1

    # Confusion matrix (regression).
    plot_confusion_matrix(reg_cm, ax[n], 'Confusion Matrix Reg')

    for axis in ax:
        axis.spines['top'].set_visible(False)
        axis.spines['right'].set_visible(False)

    plt.tight_layout()
    plt.show()
    return fig
|
914
|
+
|
915
|
+
def create_database(db_path):
    """
    Create (or open) an SQLite database at ``db_path`` and close it again.

    Any exception raised while connecting is printed instead of propagated,
    so the call is best effort.

    Args:
        db_path (str): Location of the database file.

    Returns:
        None
    """
    try:
        connection = sqlite3.connect(db_path)
    except Exception as err:
        # Best effort: report the failure and let the caller continue.
        print(err)
        return None
    connection.close()
    return None
|
934
|
+
|
935
|
+
def append_database(src, table, table_name):
    """
    Append a pandas DataFrame to a table in ``{src}/simulations.db``.

    Rows are written with ``to_sql(..., if_exists='append', index=False)``,
    so the table is created on first use. ``sqlite3.OperationalError``
    (e.g. the file cannot be opened, or a lock times out) is printed and
    swallowed; any other exception propagates. The connection is always
    closed.

    Parameters:
    src (str): The directory where the database file is located.
    table (pandas.DataFrame): The DataFrame to be appended to the database table.
    table_name (str): The name of the database table.

    Returns:
    None
    """
    conn = None
    try:
        # Long timeout: several simulation worker processes append to the
        # same database file, so a writer may have to wait for the lock.
        conn = sqlite3.connect(f'{src}/simulations.db', timeout=3600)
        table.to_sql(table_name, conn, if_exists='append', index=False)
    except sqlite3.OperationalError as e:
        print("SQLite error:", e)
    finally:
        # conn stays None when connect() itself failed. Calling close() on it
        # would raise UnboundLocalError and mask the error handled above.
        if conn is not None:
            conn.close()
    return
|
955
|
+
|
956
|
+
def save_data(src, output, settings, save_all=False, i=0, variable='all'):
    """
    Save simulation data to the SQLite database under ``src``.

    The settings are turned into a one-row DataFrame and prepended to
    ``output``. What is written depends on ``save_all``:

    - ``save_all=False``: a single row goes to the 'simulations' table. It
      holds the settings and the sim_stats columns, the Gini coefficients of
      genes/well and wells/gene, the current timestamp ('date') and
      ``variable_{variable}_sim_nr = i``.
    - ``save_all=True``: each table is appended to its own database table
      (settings, cell_scores, cell_roc, ..., wells_per_gene). The
      'gene_list' column of well_score is stored as str. This is done on a
      copy, so the caller's DataFrame is not modified.

    Any exception is printed and swallowed (best-effort saving).

    Args:
        src (str): The directory path where the data will be saved.
        output (list): The results list returned by ``run_simulation``.
        settings (dict): A dictionary containing simulation settings.
        save_all (bool, optional): Save every table instead of the summary row. Defaults to False.
        i (int, optional): The simulation number. Defaults to 0.
        variable (str, optional): The variable name. Defaults to 'all'.

    Returns:
        None
    """
    table_names = ['settings', 'cell_scores', 'cell_roc', 'cell_precision_recall', 'cell_confusion_matrix',
                   'well_score', 'gene_fraction_map', 'metadata', 'regression_results', 'regression_roc',
                   'regression_precision_recall', 'regression_confusion_matrix', 'sim_stats',
                   'genes_per_well', 'wells_per_gene']
    try:
        os.makedirs(src, exist_ok=True)

        settings_df = pd.DataFrame({key: [value] for key, value in settings.items()})
        tables = [settings_df] + list(output)

        if not save_all:
            gini_genes_per_well = gini(tables[13]['genes_per_well'].tolist())
            gini_wells_per_gene = gini(tables[14]['wells_per_gene'].tolist())
            # Summary row = settings (index 0) side by side with sim_stats (index 12).
            df_concat = pd.concat([tables[0], tables[12]], axis=1)
            df_concat['genes_per_well_gini'] = gini_genes_per_well
            df_concat['wells_per_gene_gini'] = gini_wells_per_gene
            df_concat['date'] = datetime.now()
            df_concat[f'variable_{variable}_sim_nr'] = i
            append_database(src, df_concat, 'simulations')
        else:
            for table_name, df in zip(table_names, tables):
                if table_name == 'well_score':
                    # The list-valued column must be stringified for SQLite.
                    # Use assign() so the caller's DataFrame is left untouched.
                    df = df.assign(gene_list=df['gene_list'].astype(str))
                if not isinstance(df, pd.DataFrame):
                    df = pd.DataFrame(df)
                append_database(src, df, table_name)
    except Exception as e:
        print(f"An error occurred while saving data: {e}")
    return
|
1006
|
+
|
1007
|
+
def save_plot(fig, src, variable, i):
    """
    Write a matplotlib figure to ``{src}/{variable}/{i}_figure.pdf``.

    The target folder is created if missing. The figure is saved as a PDF at
    600 dpi with a tight bounding box.

    Parameters:
    - fig: The matplotlib figure to be saved.
    - src: Base output directory.
    - variable: Sub-folder name (the variable being plotted).
    - i: Figure index, used as the file-name prefix.

    Returns:
    None
    """
    out_dir = f'{src}/{variable}'
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(f'{out_dir}/{i}_figure.pdf', dpi=600, format='pdf', bbox_inches='tight')
|
1024
|
+
|
1025
|
+
def run_and_save(i, settings, time_ls, total_sims):
    """
    Run one simulation, store its results and optionally plot them.

    Results go under ``{settings['src']}/{yymmdd}/{settings['name']}``. The
    elapsed time is written into ``settings['sim_time']`` (before saving) and
    appended to ``time_ls``, which the caller uses to report progress.

    Args:
        i (int): The simulation index.
        settings (dict): The simulation settings. Besides the keys used by
            ``run_simulation`` it reads 'random_seed', 'src', 'plot',
            'variable' and 'name'. The optional 'vis_dists' key (default
            True) controls the distribution plots.
        time_ls (list): Shared list that receives the elapsed seconds.
        total_sims (int): The total number of simulations (currently unused).

    Returns:
        tuple: ``(i, sim_time, None)``.
    """
    if settings['random_seed']:
        # Re-seeds Python's `random` module identically for every simulation,
        # so seeded runs start from the same state. NOTE(review): numpy's RNG
        # is not seeded here; confirm whether run_simulation draws from it.
        random.seed(42)
    src = settings['src']
    plot = settings['plot']
    v = settings['variable']
    start_time = time()
    date_string = datetime.now().strftime("%y%m%d")  # 'yymmdd'

    output, dists = run_simulation(settings)
    sim_time = time() - start_time
    settings['sim_time'] = sim_time

    src = os.path.join(f'{src}/{date_string}', settings['name'])
    save_data(src, output, settings, save_all=False, i=i, variable=v)

    # The original tested `if vis_dists:`, i.e. the function object, which is
    # always true. Default to True to keep that behaviour, but allow opting out.
    if settings.get('vis_dists', True):
        vis_dists(dists, src, v, i)
    if plot:
        fig = visualize_all(output)
        save_plot(fig, src, v, i)
        # Closing is enough. The former plt.figure().clear() created, and
        # leaked, a brand-new figure on every call.
        plt.close(fig)
        del fig
    del output, dists

    time_ls.append(sim_time)
    return i, sim_time, None
|
1069
|
+
|
1070
|
+
def generate_paramiters(settings):
    """
    Expand a sweep specification into one settings dict per simulation.

    Each sweep key below must map to a list of candidate values. Their
    Cartesian product is enumerated in nested-loop order (the last key varies
    fastest), and each combination is emitted ``settings['replicates']``
    times. Dependent values are derived from the swept ones:

    - sd_genes_per_well = int(avg_genes_per_well / 2)
    - sd_cells_per_well = int(avg_cells_per_well / 2)
    - negative_mean = 1 - positive_mean
    - positive_variance = negative_variance = (1 - positive_mean) / 2
    - avg_reads_per_gene is cast to int; sd_reads_per_gene = int(avg_reads_per_gene / 2)

    Args:
        settings (dict): Simulation settings. The sweep keys hold lists; all
            other keys are copied through unchanged.

    Returns:
        list: One independent dict per simulation. Each is a shallow copy of
            ``settings`` with the swept keys replaced by scalar values.
    """
    sweep_keys = ('avg_genes_per_well', 'avg_cells_per_well', 'positive_mean', 'avg_reads_per_gene',
                  'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff', 'nr_plates',
                  'number_of_genes', 'number_of_active_genes')
    replicates = settings['replicates']
    sim_ls = []
    for values in itertools.product(*(settings[key] for key in sweep_keys)):
        (avg_genes_per_well, avg_cells_per_well, positive_mean, avg_reads_per_gene, sequencing_error,
         well_ineq_coeff, gene_ineq_coeff, nr_plates, number_of_genes, number_of_active_genes) = values
        sett = settings.copy()
        sett.update({
            'avg_genes_per_well': avg_genes_per_well,
            'sd_genes_per_well': int(avg_genes_per_well / 2),
            'avg_cells_per_well': avg_cells_per_well,
            'sd_cells_per_well': int(avg_cells_per_well / 2),
            'positive_mean': positive_mean,
            'negative_mean': 1 - positive_mean,
            'positive_variance': (1 - positive_mean) / 2,
            'negative_variance': (1 - positive_mean) / 2,
            'avg_reads_per_gene': int(avg_reads_per_gene),
            'sd_reads_per_gene': int(avg_reads_per_gene / 2),
            'sequencing_error': sequencing_error,
            'well_ineq_coeff': well_ineq_coeff,
            'gene_ineq_coeff': gene_ineq_coeff,
            'nr_plates': nr_plates,
            'number_of_genes': number_of_genes,
            'number_of_active_genes': number_of_active_genes,
        })
        # A fresh dict per entry. Appending one shared, mutated dict made every
        # simulation run with the last combination of the inner sweeps.
        sim_ls.extend(sett.copy() for _ in range(replicates))
    return sim_ls
|
1114
|
+
|
1115
|
+
# Variant of generate_paramiters with fixed negative_mean and variances (see below).
def generate_paramiters_single(settings):
    """
    Expand a sweep specification into one settings dict per simulation.

    Same as ``generate_paramiters`` except that the classifier parameters do
    not depend on positive_mean: negative_mean is fixed at 0.2, and
    positive_variance and negative_variance are both fixed at 0.13.

    Each sweep key below must map to a list of candidate values. Their
    Cartesian product is enumerated in nested-loop order (the last key varies
    fastest), and each combination is emitted ``settings['replicates']``
    times. Derived values:

    - sd_genes_per_well = int(avg_genes_per_well / 2)
    - sd_cells_per_well = int(avg_cells_per_well / 2)
    - avg_reads_per_gene is cast to int; sd_reads_per_gene = int(avg_reads_per_gene / 2)

    Args:
        settings (dict): Simulation settings. The sweep keys hold lists; all
            other keys are copied through unchanged.

    Returns:
        list: One independent dict per simulation. Each is a shallow copy of
            ``settings`` with the swept keys replaced by scalar values.
    """
    sweep_keys = ('avg_genes_per_well', 'avg_cells_per_well', 'positive_mean', 'avg_reads_per_gene',
                  'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff', 'nr_plates',
                  'number_of_genes', 'number_of_active_genes')
    replicates = settings['replicates']
    sim_ls = []
    for values in itertools.product(*(settings[key] for key in sweep_keys)):
        (avg_genes_per_well, avg_cells_per_well, positive_mean, avg_reads_per_gene, sequencing_error,
         well_ineq_coeff, gene_ineq_coeff, nr_plates, number_of_genes, number_of_active_genes) = values
        sett = settings.copy()
        sett.update({
            'avg_genes_per_well': avg_genes_per_well,
            'sd_genes_per_well': int(avg_genes_per_well / 2),
            'avg_cells_per_well': avg_cells_per_well,
            'sd_cells_per_well': int(avg_cells_per_well / 2),
            'positive_mean': positive_mean,
            'negative_mean': 0.2,
            'positive_variance': 0.13,
            'negative_variance': 0.13,
            'avg_reads_per_gene': int(avg_reads_per_gene),
            'sd_reads_per_gene': int(avg_reads_per_gene / 2),
            'sequencing_error': sequencing_error,
            'well_ineq_coeff': well_ineq_coeff,
            'gene_ineq_coeff': gene_ineq_coeff,
            'nr_plates': nr_plates,
            'number_of_genes': number_of_genes,
            'number_of_active_genes': number_of_active_genes,
        })
        # A fresh dict per entry. Appending one shared, mutated dict made every
        # simulation run with the last combination of the inner sweeps.
        sim_ls.extend(sett.copy() for _ in range(replicates))
    return sim_ls
|
1160
|
+
|
1161
|
+
def run_multiple_simulations(settings):
    """
    Run every parameter combination from ``generate_paramiters`` in a process pool.

    Each simulation is executed by ``run_and_save``. While the pool works, a
    single progress line (count, mean time per simulation, estimated minutes
    left) is refreshed in place. Exceptions raised in workers are re-raised
    by ``result.get()``.

    Args:
        settings (dict): The sweep settings. 'max_workers' may be falsy, in
            which case ``cpu_count() - 4`` workers (at least 1) are used.

    Returns:
        None
    """
    sim_ls = generate_paramiters(settings)
    print(f'Running {len(sim_ls)} simulations. Standard deviations for each variable are variable / 2')

    # Leave 4 cores free by default, but never ask Pool for fewer than one
    # process: on machines with 4 or fewer cores cpu_count() - 4 is <= 0, and
    # Pool raises on that.
    max_workers = settings['max_workers'] or max(1, cpu_count() - 4)
    total_sims = len(sim_ls)
    with Manager() as manager:
        time_ls = manager.list()
        jobs = [(index, sim_settings, time_ls, total_sims) for index, sim_settings in enumerate(sim_ls)]
        with Pool(max_workers) as pool:
            result = pool.starmap_async(run_and_save, jobs)
            while not result.ready():
                sleep(0.01)
                # Snapshot the proxy once. This avoids one IPC round-trip per
                # element and keeps the count and the mean consistent.
                times = list(time_ls)
                sims_processed = len(times)
                average_time = np.mean(times) if times else 0
                time_left = (((total_sims - sims_processed) * average_time) / max_workers) / 60
                print(f'Progress: {sims_processed}/{total_sims} Time/simulation {average_time:.3f}sec Time Remaining {time_left:.3f} min.', end='\r', flush=True)
            result.get()
|