spacr 0.0.81__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +4 -0
- spacr/annotate_app.py +4 -0
- spacr/annotate_app_v2.py +511 -0
- spacr/core.py +258 -177
- spacr/deep_spacr.py +137 -50
- spacr/graph_learning.py +28 -8
- spacr/io.py +332 -142
- spacr/measure.py +2 -1
- spacr/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model +0 -0
- spacr/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model_settings.csv +23 -0
- spacr/plot.py +102 -6
- spacr/sequencing.py +849 -129
- spacr/settings.py +477 -0
- spacr/timelapse.py +0 -3
- spacr/utils.py +312 -275
- {spacr-0.0.81.dist-info → spacr-0.1.0.dist-info}/METADATA +1 -1
- spacr-0.1.0.dist-info/RECORD +40 -0
- spacr-0.0.81.dist-info/RECORD +0 -36
- {spacr-0.0.81.dist-info → spacr-0.1.0.dist-info}/LICENSE +0 -0
- {spacr-0.0.81.dist-info → spacr-0.1.0.dist-info}/WHEEL +0 -0
- {spacr-0.0.81.dist-info → spacr-0.1.0.dist-info}/entry_points.txt +0 -0
- {spacr-0.0.81.dist-info → spacr-0.1.0.dist-info}/top_level.txt +0 -0
spacr/sequencing.py
CHANGED
@@ -7,10 +7,19 @@ import matplotlib.pyplot as plt
|
|
7
7
|
import seaborn as sns
|
8
8
|
from Bio import pairwise2
|
9
9
|
import statsmodels.api as sm
|
10
|
-
|
10
|
+
from statsmodels.regression.mixed_linear_model import MixedLM
|
11
|
+
from statsmodels.stats.outliers_influence import variance_inflation_factor
|
11
12
|
from scipy.stats import gmean
|
13
|
+
from scipy import stats
|
12
14
|
from difflib import SequenceMatcher
|
13
15
|
from collections import Counter
|
16
|
+
from IPython.display import display
|
17
|
+
|
18
|
+
from sklearn.linear_model import LinearRegression, Lasso, Ridge
|
19
|
+
from sklearn.preprocessing import FunctionTransformer, MinMaxScaler
|
20
|
+
|
21
|
+
from scipy.stats import shapiro
|
22
|
+
from patsy import dmatrices
|
14
23
|
|
15
24
|
def analyze_reads(settings):
|
16
25
|
"""
|
@@ -28,7 +37,7 @@ def analyze_reads(settings):
|
|
28
37
|
None
|
29
38
|
"""
|
30
39
|
|
31
|
-
def
|
40
|
+
def save_chunk_to_hdf5_v1(output_file_path, data_chunk, chunk_counter):
|
32
41
|
"""
|
33
42
|
Save a data chunk to an HDF5 file.
|
34
43
|
|
@@ -44,6 +53,28 @@ def analyze_reads(settings):
|
|
44
53
|
with pd.HDFStore(output_file_path, mode='a', complevel=5, complib='blosc') as store:
|
45
54
|
store.put(f'reads/chunk_{chunk_counter}', df, format='table', append=True)
|
46
55
|
|
56
|
+
def save_chunk_to_hdf5(output_file_path, data_chunk, chunk_counter):
    """
    Append one chunk of combined reads to an HDF5 store.

    Parameters:
    - output_file_path (str): Destination HDF5 file; opened in append mode with blosc compression (level 5).
    - data_chunk (list): Row tuples ordered as (combined_read, grna, plate_row, column, sample).
    - chunk_counter (int): Number used to name the node ``reads/chunk_<chunk_counter>``.

    Returns:
    None
    """
    column_names = ['combined_read', 'grna', 'plate_row', 'column', 'sample']
    # Minimum string widths for the 'table' format. NOTE(review): presumably set so
    # that longer strings in later appends do not exceed widths fixed by the first write.
    string_widths = {'combined_read': 300, 'grna': 50, 'plate_row': 20, 'column': 20, 'sample': 50}

    chunk_frame = pd.DataFrame(data_chunk, columns=column_names)
    node_key = f'reads/chunk_{chunk_counter}'

    with pd.HDFStore(output_file_path, mode='a', complevel=5, complib='blosc') as hdf:
        hdf.put(node_key, chunk_frame, format='table', append=True, min_itemsize=string_widths)
|
77
|
+
|
47
78
|
def reverse_complement(seq):
|
48
79
|
"""
|
49
80
|
Returns the reverse complement of a DNA sequence.
|
@@ -139,7 +170,7 @@ def analyze_reads(settings):
|
|
139
170
|
best_alignment = alignments[0]
|
140
171
|
return best_alignment
|
141
172
|
|
142
|
-
def combine_reads(samples_dict, src, chunk_size,
|
173
|
+
def combine_reads(samples_dict, src, chunk_size, barecode_length_1, barecode_length_2, upstream, downstream):
|
143
174
|
"""
|
144
175
|
Combine reads from paired-end sequencing files and save the combined reads to a new file.
|
145
176
|
|
@@ -186,7 +217,7 @@ def analyze_reads(settings):
|
|
186
217
|
r1_size_est = os.path.getsize(r1_path) // (avg_read_length * 4) if r1_path else 0
|
187
218
|
r2_size_est = os.path.getsize(r2_path) // (avg_read_length * 4) if r2_path else 0
|
188
219
|
max_size = max(r1_size_est, r2_size_est) * 10
|
189
|
-
|
220
|
+
test10 =0
|
190
221
|
with tqdm(total=max_size, desc=f"Processing {sample}") as pbar:
|
191
222
|
total_length_processed = 0
|
192
223
|
read_count = 0
|
@@ -229,12 +260,26 @@ def analyze_reads(settings):
|
|
229
260
|
combo_split_index_1 = read_combo.find(upstream)
|
230
261
|
combo_split_index_2 = read_combo.find(downstream)
|
231
262
|
|
232
|
-
barcode_1 = read_combo[combo_split_index_1 -
|
263
|
+
barcode_1 = read_combo[combo_split_index_1 - barecode_length_1:combo_split_index_1]
|
233
264
|
grna = read_combo[combo_split_index_1 + len(upstream):combo_split_index_2]
|
234
|
-
barcode_2 = read_combo[combo_split_index_2 + len(downstream):combo_split_index_2 + len(downstream) +
|
265
|
+
barcode_2 = read_combo[combo_split_index_2 + len(downstream):combo_split_index_2 + len(downstream) + barecode_length_2]
|
235
266
|
barcode_2 = reverse_complement(barcode_2)
|
236
267
|
data_chunk.append((read_combo, grna, barcode_1, barcode_2, sample))
|
237
268
|
|
269
|
+
if settings['test']:
|
270
|
+
if read_count % 1000 == 0:
|
271
|
+
print(f"Read count: {read_count}")
|
272
|
+
print(f"Read 1: {r1_read_rc}")
|
273
|
+
print(f"Read 2: {r2_read}")
|
274
|
+
print(f"Read combo: {read_combo}")
|
275
|
+
print(f"Barcode 1: {barcode_1}")
|
276
|
+
print(f"gRNA: {grna}")
|
277
|
+
print(f"Barcode 2: {barcode_2}")
|
278
|
+
print()
|
279
|
+
test10 += 1
|
280
|
+
if test10 == 10:
|
281
|
+
break
|
282
|
+
|
238
283
|
read_count += 1
|
239
284
|
total_length_processed += len(r1_read) + len(r2_read)
|
240
285
|
|
@@ -261,13 +306,12 @@ def analyze_reads(settings):
|
|
261
306
|
qc_df = pd.DataFrame([qc])
|
262
307
|
qc_df.to_csv(qc_file_path, index=False)
|
263
308
|
|
264
|
-
|
265
|
-
|
266
|
-
settings
|
267
|
-
|
268
|
-
|
309
|
+
from .utils import get_analyze_reads_default_settings
|
310
|
+
|
311
|
+
settings = get_analyze_reads_default_settings(settings)
|
312
|
+
|
269
313
|
samples_dict = parse_gz_files(settings['src'])
|
270
|
-
combine_reads(samples_dict, settings['src'], settings['chunk_size'], settings['
|
314
|
+
combine_reads(samples_dict, settings['src'], settings['chunk_size'], settings['barecode_length_1'], settings['barecode_length_2'], settings['upstream'], settings['downstream'])
|
271
315
|
|
272
316
|
def map_barcodes(h5_file_path, settings={}):
|
273
317
|
"""
|
@@ -280,27 +324,20 @@ def map_barcodes(h5_file_path, settings={}):
|
|
280
324
|
Returns:
|
281
325
|
None
|
282
326
|
"""
|
283
|
-
def get_read_qc(df,
|
327
|
+
def get_read_qc(df, settings):
|
284
328
|
"""
|
285
329
|
Calculate quality control metrics for sequencing reads.
|
286
330
|
|
287
331
|
Parameters:
|
288
332
|
- df: DataFrame containing the sequencing reads.
|
289
|
-
- df_cleaned: DataFrame containing the cleaned sequencing reads.
|
290
333
|
|
291
334
|
Returns:
|
292
|
-
-
|
293
|
-
|
294
|
-
- 'cleaned_reads': Total number of cleaned reads.
|
295
|
-
- 'NaN_grna': Number of reads with missing 'grna_metadata'.
|
296
|
-
- 'NaN_plate_row': Number of reads with missing 'plate_row_metadata'.
|
297
|
-
- 'NaN_column': Number of reads with missing 'column_metadata'.
|
298
|
-
- 'NaN_plate': Number of reads with missing 'plate_metadata'.
|
299
|
-
- 'unique_grna': Counter object containing the count of unique 'grna_metadata' values.
|
300
|
-
- 'unique_plate_row': Counter object containing the count of unique 'plate_row_metadata' values.
|
301
|
-
- 'unique_column': Counter object containing the count of unique 'column_metadata' values.
|
302
|
-
- 'unique_plate': Counter object containing the count of unique 'plate_metadata' values.
|
335
|
+
- df_cleaned: DataFrame containing the cleaned sequencing reads.
|
336
|
+
- qc_dict: Dictionary containing the quality control metrics.
|
303
337
|
"""
|
338
|
+
|
339
|
+
df_cleaned = df.dropna()
|
340
|
+
|
304
341
|
qc_dict = {}
|
305
342
|
qc_dict['reads'] = len(df)
|
306
343
|
qc_dict['cleaned_reads'] = len(df_cleaned)
|
@@ -312,9 +349,56 @@ def map_barcodes(h5_file_path, settings={}):
|
|
312
349
|
qc_dict['unique_plate_row'] = Counter(df['plate_row_metadata'].dropna().tolist())
|
313
350
|
qc_dict['unique_column'] = Counter(df['column_metadata'].dropna().tolist())
|
314
351
|
qc_dict['unique_plate'] = Counter(df['plate_metadata'].dropna().tolist())
|
352
|
+
|
353
|
+
# Calculate control error rates using cleaned DataFrame
|
354
|
+
total_pc_non_nan = df_cleaned[(df_cleaned['column_metadata'] == settings['pc_loc'])].shape[0]
|
355
|
+
total_nc_non_nan = df_cleaned[(df_cleaned['column_metadata'] == settings['nc_loc'])].shape[0]
|
315
356
|
|
316
|
-
|
317
|
-
|
357
|
+
pc_count_pc = df_cleaned[(df_cleaned['column_metadata'] == settings['pc_loc']) & (df_cleaned['grna_metadata'] == settings['pc'])].shape[0]
|
358
|
+
nc_count_nc = df_cleaned[(df_cleaned['column_metadata'] == settings['nc_loc']) & (df_cleaned['grna_metadata'] == settings['nc'])].shape[0]
|
359
|
+
|
360
|
+
pc_error_count = df_cleaned[(df_cleaned['column_metadata'] == settings['pc_loc']) & (df_cleaned['grna_metadata'] != settings['pc'])].shape[0]
|
361
|
+
nc_error_count = df_cleaned[(df_cleaned['column_metadata'] == settings['nc_loc']) & (df_cleaned['grna_metadata'] != settings['nc'])].shape[0]
|
362
|
+
|
363
|
+
pc_in_nc_loc_count = df_cleaned[(df_cleaned['column_metadata'] == settings['nc_loc']) & (df_cleaned['grna_metadata'] == settings['pc'])].shape[0]
|
364
|
+
nc_in_pc_loc_count = df_cleaned[(df_cleaned['column_metadata'] == settings['pc_loc']) & (df_cleaned['grna_metadata'] == settings['nc'])].shape[0]
|
365
|
+
|
366
|
+
# Collect QC metrics into a dictionary
|
367
|
+
# PC
|
368
|
+
qc_dict['pc_total_count'] = total_pc_non_nan
|
369
|
+
qc_dict['pc_count_pc'] = pc_count_pc
|
370
|
+
qc_dict['nc_count_pc'] = pc_in_nc_loc_count
|
371
|
+
qc_dict['pc_error_count'] = pc_error_count
|
372
|
+
# NC
|
373
|
+
qc_dict['nc_total_count'] = total_nc_non_nan
|
374
|
+
qc_dict['nc_count_nc'] = nc_count_nc
|
375
|
+
qc_dict['pc_count_nc'] = nc_in_pc_loc_count
|
376
|
+
qc_dict['nc_error_count'] = nc_error_count
|
377
|
+
|
378
|
+
return df_cleaned, qc_dict
|
379
|
+
|
380
|
+
def get_per_row_qc(df, settings):
    """
    Calculate quality control metrics separately for each plate row.

    Parameters:
    - df: DataFrame of mapped reads. It must contain a 'plate_row_metadata' column
      plus whatever columns get_read_qc reads.
    - settings: Dictionary of control settings, forwarded unchanged to get_read_qc.

    Returns:
    - dict: Maps each non-null 'plate_row_metadata' value to the QC dictionary that
      get_read_qc returns for that row's reads. Keys are in order of first appearance.
    """
    qc_dict_per_row = {}
    # One groupby pass visits each distinct non-null row once, in first-appearance order.
    # This replaces a full boolean-mask scan per row. The previous set()-based
    # de-duplication was redundant after .unique() and made the iteration order
    # depend on the string hash seed.
    for row, df_row in df.groupby('plate_row_metadata', sort=False, observed=True):
        _, qc_dict_row = get_read_qc(df_row, settings)
        qc_dict_per_row[row] = qc_dict_row

    return qc_dict_per_row
|
401
|
+
|
318
402
|
def mapping_dicts(df, settings):
|
319
403
|
"""
|
320
404
|
Maps the values in the DataFrame columns to corresponding metadata using dictionaries.
|
@@ -339,22 +423,87 @@ def map_barcodes(h5_file_path, settings={}):
|
|
339
423
|
df['plate_row_metadata'] = df['plate_row'].map(plate_row_dict)
|
340
424
|
df['column_metadata'] = df['column'].map(column_dict)
|
341
425
|
df['plate_metadata'] = df['sample'].map(plate_dict)
|
342
|
-
|
426
|
+
|
343
427
|
return df
|
344
428
|
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
429
|
+
def filter_combinations(df, settings):
    """
    Filter a combination-counts table using contamination seen in the control columns.

    A noise floor is derived from two maxima:
    - the highest count of any gRNA other than settings['nc'] in column settings['nc_loc'];
    - the highest count of any gRNA other than settings['pc'] in column settings['pc_loc'].

    The larger of the two is the floor, and rows with a count below it are dropped.
    A control column with no off-target gRNA contributes 0 rather than NaN, so an
    empty selection cannot wipe out the whole table. Read fractions are then computed
    within each (plate_row, column) well.

    Args:
        df (pd.DataFrame): Combination counts with 'plate_row', 'column', 'grna' and 'count' columns.
        settings (dict): Must provide 'pc', 'nc', 'pc_loc', 'nc_loc' and 'verbose'.

    Returns:
        pd.DataFrame: A new DataFrame with the retained rows plus 'total_reads' and
        'read_fraction' columns. The input ``df`` is not modified.
    """
    pc = settings['pc']
    nc = settings['nc']
    pc_loc = settings['pc_loc']
    nc_loc = settings['nc_loc']

    def _max_off_target_count(column_value, expected_grna):
        # Highest count of an unexpected gRNA in a control column; 0 if there is none.
        counts = df.loc[(df['column'] == column_value) & (df['grna'] != expected_grna), 'count']
        return counts.max() if not counts.empty else 0

    max_count_c1 = _max_off_target_count(nc_loc, nc)
    max_count_c2 = _max_off_target_count(pc_loc, pc)

    highest_max_count = max(max_count_c1, max_count_c2)

    # .copy() so the column assignments below act on an independent frame, not a slice of df.
    filtered_df = df[df['count'] >= highest_max_count].copy()

    # Total reads per well, and each row's share of that total.
    filtered_df['total_reads'] = filtered_df.groupby(['plate_row', 'column'])['count'].transform('sum')
    filtered_df['read_fraction'] = filtered_df['count'] / filtered_df['total_reads']

    if settings['verbose']:
        print(f"Max count for non {nc} in {nc_loc}: {max_count_c1}")
        print(f"Max count for non {pc} in {pc_loc}: {max_count_c2}")

    return filtered_df
|
478
|
+
|
479
|
+
from .settings import get_map_barcodes_default_settings
|
480
|
+
|
481
|
+
settings = get_map_barcodes_default_settings(settings)
|
482
|
+
|
483
|
+
fldr = os.path.splitext(h5_file_path)[0]
|
484
|
+
file_name = os.path.basename(fldr)
|
485
|
+
|
486
|
+
if settings['test']:
|
487
|
+
fldr = os.path.join(fldr, 'test')
|
488
|
+
os.makedirs(fldr, exist_ok=True)
|
489
|
+
|
490
|
+
qc_file_path = os.path.join(fldr, f'{file_name}_qc_step_2.csv')
|
491
|
+
unique_grna_file_path = os.path.join(fldr, f'{file_name}_unique_grna.csv')
|
492
|
+
unique_plate_row_file_path = os.path.join(fldr, f'{file_name}_unique_plate_row.csv')
|
493
|
+
unique_column_file_path = os.path.join(fldr, f'{file_name}_unique_column.csv')
|
494
|
+
unique_plate_file_path = os.path.join(fldr, f'{file_name}_unique_plate.csv')
|
495
|
+
new_h5_file_path = os.path.join(fldr, f'{file_name}_cleaned.h5')
|
496
|
+
combination_counts_file_path = os.path.join(fldr, f'{file_name}_combination_counts.csv')
|
497
|
+
combination_counts_file_path_cleaned = os.path.join(fldr, f'{file_name}_combination_counts_cleaned.csv')
|
498
|
+
|
499
|
+
#qc_file_path = os.path.splitext(h5_file_path)[0] + '_qc_step_2.csv'
|
500
|
+
#unique_grna_file_path = os.path.splitext(h5_file_path)[0] + '_unique_grna.csv'
|
501
|
+
#unique_plate_row_file_path = os.path.splitext(h5_file_path)[0] + '_unique_plate_row.csv'
|
502
|
+
#unique_column_file_path = os.path.splitext(h5_file_path)[0] + '_unique_column.csv'
|
503
|
+
#unique_plate_file_path = os.path.splitext(h5_file_path)[0] + '_unique_plate.csv'
|
504
|
+
#new_h5_file_path = os.path.splitext(h5_file_path)[0] + '_cleaned.h5'
|
505
|
+
#combination_counts_file_path = os.path.splitext(h5_file_path)[0] + '_combination_counts.csv'
|
506
|
+
#combination_counts_file_path_cleaned = os.path.splitext(h5_file_path)[0] + '_combination_counts_cleaned.csv'
|
358
507
|
|
359
508
|
# Initialize the HDF5 store for cleaned data
|
360
509
|
store_cleaned = pd.HDFStore(new_h5_file_path, mode='a', complevel=5, complib='blosc')
|
@@ -370,38 +519,89 @@ def map_barcodes(h5_file_path, settings={}):
|
|
370
519
|
'unique_grna': Counter(),
|
371
520
|
'unique_plate_row': Counter(),
|
372
521
|
'unique_column': Counter(),
|
373
|
-
'unique_plate': Counter()
|
522
|
+
'unique_plate': Counter(),
|
523
|
+
'pc_total_count': 0,
|
524
|
+
'pc_count_pc': 0,
|
525
|
+
'nc_total_count': 0,
|
526
|
+
'nc_count_nc': 0,
|
527
|
+
'pc_count_nc': 0,
|
528
|
+
'nc_count_pc': 0,
|
529
|
+
'pc_error_count': 0,
|
530
|
+
'nc_error_count': 0,
|
531
|
+
'pc_fraction_pc': 0,
|
532
|
+
'nc_fraction_nc': 0,
|
533
|
+
'pc_fraction_nc': 0,
|
534
|
+
'nc_fraction_pc': 0
|
374
535
|
}
|
375
536
|
|
537
|
+
per_row_qc = {}
|
538
|
+
combination_counts = Counter()
|
539
|
+
|
376
540
|
with pd.HDFStore(h5_file_path, mode='r') as store:
|
377
541
|
keys = [key for key in store.keys() if key.startswith('/reads/chunk_')]
|
378
|
-
|
542
|
+
|
543
|
+
if settings['test']:
|
544
|
+
keys = keys[:3] # Only read the first chunks if in test mode
|
545
|
+
|
379
546
|
for key in keys:
|
380
547
|
df = store.get(key)
|
381
548
|
df = mapping_dicts(df, settings)
|
382
|
-
df_cleaned = df
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
overall_qc['reads'] += qc_dict['reads']
|
387
|
-
overall_qc['cleaned_reads'] += qc_dict['cleaned_reads']
|
388
|
-
overall_qc['NaN_grna'] += qc_dict['NaN_grna']
|
389
|
-
overall_qc['NaN_plate_row'] += qc_dict['NaN_plate_row']
|
390
|
-
overall_qc['NaN_column'] += qc_dict['NaN_column']
|
391
|
-
overall_qc['NaN_plate'] += qc_dict['NaN_plate']
|
392
|
-
overall_qc['unique_grna'].update(qc_dict['unique_grna'])
|
393
|
-
overall_qc['unique_plate_row'].update(qc_dict['unique_plate_row'])
|
394
|
-
overall_qc['unique_column'].update(qc_dict['unique_column'])
|
395
|
-
overall_qc['unique_plate'].update(qc_dict['unique_plate'])
|
549
|
+
df_cleaned, qc_dict = get_read_qc(df, settings)
|
550
|
+
|
551
|
+
# Accumulate counts for unique combinations
|
552
|
+
combinations = df_cleaned[['plate_row_metadata', 'column_metadata', 'grna_metadata']].apply(tuple, axis=1)
|
396
553
|
|
397
|
-
|
398
|
-
|
554
|
+
combination_counts.update(combinations)
|
555
|
+
|
556
|
+
if settings['test'] and settings['verbose']:
|
557
|
+
os.makedirs(os.path.join(os.path.splitext(h5_file_path)[0],'test'), exist_ok=True)
|
558
|
+
df.to_csv(os.path.join(os.path.splitext(h5_file_path)[0],'test','chunk_1_df.csv'), index=False)
|
559
|
+
df_cleaned.to_csv(os.path.join(os.path.splitext(h5_file_path)[0],'test','chunk_1_df_cleaned.csv'), index=False)
|
560
|
+
|
561
|
+
# Accumulate QC metrics for all rows
|
562
|
+
for metric in qc_dict:
|
563
|
+
if isinstance(overall_qc[metric], Counter):
|
564
|
+
overall_qc[metric].update(qc_dict[metric])
|
565
|
+
else:
|
566
|
+
overall_qc[metric] += qc_dict[metric]
|
567
|
+
|
568
|
+
# Update per_row_qc dictionary
|
569
|
+
chunk_per_row_qc = get_per_row_qc(df, settings)
|
570
|
+
for row in chunk_per_row_qc:
|
571
|
+
if row not in per_row_qc:
|
572
|
+
per_row_qc[row] = chunk_per_row_qc[row]
|
573
|
+
else:
|
574
|
+
for metric in chunk_per_row_qc[row]:
|
575
|
+
if isinstance(per_row_qc[row][metric], Counter):
|
576
|
+
per_row_qc[row][metric].update(chunk_per_row_qc[row][metric])
|
577
|
+
else:
|
578
|
+
per_row_qc[row][metric] += chunk_per_row_qc[row][metric]
|
579
|
+
|
580
|
+
# Ensure the DataFrame columns are in the desired order
|
581
|
+
df_cleaned = df_cleaned[['grna', 'plate_row', 'column', 'sample', 'grna_metadata', 'plate_row_metadata', 'column_metadata', 'plate_metadata']]
|
582
|
+
|
399
583
|
# Save cleaned data to the new HDF5 store
|
400
584
|
store_cleaned.put('reads/cleaned_data', df_cleaned, format='table', append=True)
|
401
|
-
|
585
|
+
|
402
586
|
del df_cleaned, df
|
403
587
|
gc.collect()
|
404
588
|
|
589
|
+
# Calculate overall fractions after accumulating all metrics
|
590
|
+
overall_qc['pc_fraction_pc'] = overall_qc['pc_count_pc'] / overall_qc['pc_total_count'] if overall_qc['pc_total_count'] else 0
|
591
|
+
overall_qc['nc_fraction_nc'] = overall_qc['nc_count_nc'] / overall_qc['nc_total_count'] if overall_qc['nc_total_count'] else 0
|
592
|
+
overall_qc['pc_fraction_nc'] = overall_qc['pc_count_nc'] / overall_qc['nc_total_count'] if overall_qc['nc_total_count'] else 0
|
593
|
+
overall_qc['nc_fraction_pc'] = overall_qc['nc_count_pc'] / overall_qc['pc_total_count'] if overall_qc['pc_total_count'] else 0
|
594
|
+
|
595
|
+
for row in per_row_qc:
|
596
|
+
if row != 'all_rows':
|
597
|
+
per_row_qc[row]['pc_fraction_pc'] = per_row_qc[row]['pc_count_pc'] / per_row_qc[row]['pc_total_count'] if per_row_qc[row]['pc_total_count'] else 0
|
598
|
+
per_row_qc[row]['nc_fraction_nc'] = per_row_qc[row]['nc_count_nc'] / per_row_qc[row]['nc_total_count'] if per_row_qc[row]['nc_total_count'] else 0
|
599
|
+
per_row_qc[row]['pc_fraction_nc'] = per_row_qc[row]['pc_count_nc'] / per_row_qc[row]['nc_total_count'] if per_row_qc[row]['nc_total_count'] else 0
|
600
|
+
per_row_qc[row]['nc_fraction_pc'] = per_row_qc[row]['nc_count_pc'] / per_row_qc[row]['pc_total_count'] if per_row_qc[row]['pc_total_count'] else 0
|
601
|
+
|
602
|
+
# Add overall_qc to per_row_qc with the key 'all_rows'
|
603
|
+
per_row_qc['all_rows'] = overall_qc
|
604
|
+
|
405
605
|
# Convert the Counter objects to DataFrames and save them to CSV files
|
406
606
|
unique_grna_df = pd.DataFrame(overall_qc['unique_grna'].items(), columns=['key', 'value'])
|
407
607
|
unique_plate_row_df = pd.DataFrame(overall_qc['unique_plate_row'].items(), columns=['key', 'value'])
|
@@ -422,89 +622,128 @@ def map_barcodes(h5_file_path, settings={}):
|
|
422
622
|
# Combine all remaining QC metrics into a single DataFrame and save it to CSV
|
423
623
|
qc_df = pd.DataFrame([overall_qc])
|
424
624
|
qc_df.to_csv(qc_file_path, index=False)
|
625
|
+
|
626
|
+
# Convert per_row_qc to a DataFrame and save it to CSV
|
627
|
+
per_row_qc_df = pd.DataFrame.from_dict(per_row_qc, orient='index')
|
628
|
+
per_row_qc_df = per_row_qc_df.sort_values(by='reads', ascending=False)
|
629
|
+
per_row_qc_df = per_row_qc_df.drop(['unique_grna', 'unique_plate_row', 'unique_column', 'unique_plate'], axis=1, errors='ignore')
|
630
|
+
per_row_qc_df = per_row_qc_df.dropna(subset=['reads'])
|
631
|
+
per_row_qc_df.to_csv(os.path.splitext(h5_file_path)[0] + '_per_row_qc.csv', index=True)
|
632
|
+
|
633
|
+
if settings['verbose']:
|
634
|
+
display(per_row_qc_df)
|
635
|
+
|
636
|
+
# Save the combination counts to a CSV file
|
637
|
+
try:
|
638
|
+
combination_counts_df = pd.DataFrame(combination_counts.items(), columns=['combination', 'count'])
|
639
|
+
combination_counts_df[['plate_row', 'column', 'grna']] = pd.DataFrame(combination_counts_df['combination'].tolist(), index=combination_counts_df.index)
|
640
|
+
combination_counts_df = combination_counts_df.drop('combination', axis=1)
|
641
|
+
combination_counts_df.to_csv(combination_counts_file_path, index=False)
|
642
|
+
|
643
|
+
grna_plate_heatmap(combination_counts_file_path, specific_grna=None)
|
644
|
+
grna_plate_heatmap(combination_counts_file_path, specific_grna=settings['pc'])
|
645
|
+
grna_plate_heatmap(combination_counts_file_path, specific_grna=settings['nc'])
|
646
|
+
|
647
|
+
combination_counts_df_cleaned = filter_combinations(combination_counts_df, settings)
|
648
|
+
combination_counts_df_cleaned.to_csv(combination_counts_file_path_cleaned, index=False)
|
649
|
+
|
650
|
+
grna_plate_heatmap(combination_counts_file_path_cleaned, specific_grna=None)
|
651
|
+
grna_plate_heatmap(combination_counts_file_path_cleaned, specific_grna=settings['pc'])
|
652
|
+
grna_plate_heatmap(combination_counts_file_path_cleaned, specific_grna=settings['nc'])
|
653
|
+
except Exception as e:
|
654
|
+
print(e)
|
425
655
|
|
426
656
|
# Close the HDF5 store
|
427
657
|
store_cleaned.close()
|
428
|
-
|
429
658
|
gc.collect()
|
430
659
|
return
|
431
660
|
|
432
|
-
def
|
661
|
+
def grna_plate_heatmap(path, specific_grna=None, min_max='all', cmap='viridis', min_count=0, save=True):
    """
    Generate one heatmap of gRNA read counts per plate.

    Args:
        path (str or pd.DataFrame): Path to a CSV file, or an already-loaded DataFrame.
            It needs 'plate_row' ('<plate>_<row>'), 'column', 'grna' and 'count' columns.
        specific_grna (str, optional): Only plot counts for this gRNA. Defaults to None (all gRNAs).
        min_max (str or list or tuple, optional): Colour-scale range, resolved per plate.
            'all'          -> minimum and maximum of the plate map.
            'allq'         -> 2nd and 98th percentiles of the plate map.
            (float, float) -> those quantiles of the plate map.
            (int, int)     -> those absolute values.
            Defaults to 'all'.
        cmap (str, optional): Colormap for the heatmap. Defaults to 'viridis'.
        min_count (int, optional): Rows whose 'count' is below this value are ignored
            before aggregation. Defaults to 0 (keep everything).
        save (bool, optional): Save the figure as a PDF named after the CSV path.
            This is only possible when ``path`` is a file path; for DataFrame input it is
            skipped with a message. Defaults to True.

    Returns:
        matplotlib.figure.Figure: The generated heatmap figure.

    Raises:
        ValueError: If no plate identifiers can be derived from 'plate_row'.
    """
    row_order = [f'r{i}' for i in range(1, 17)]
    col_order = [f'c{i}' for i in range(1, 28)]

    def generate_grna_plate_heatmap(df, plate_number, min_max, min_count, specific_grna=None):
        # Split '<plate>_<row>' once. n=1 keeps any further underscores in the row part
        # instead of failing. assign() returns a new frame, so the caller's df is untouched.
        parts = df['plate_row'].str.split('_', n=1)
        df = df.assign(plate=parts.str[0], row=parts.str[1])

        # Exact plate match: a prefix match would make plate 'p1' also pick up 'p10', 'p11', ...
        df = df[df['plate'] == plate_number]
        if specific_grna:
            df = df[df['grna'] == specific_grna]
        if min_count:
            df = df[df['count'] >= min_count]

        # Ordered categoricals so the heatmap axes follow r1..r16 / c1..c27 rather than lexical order.
        df = df.assign(
            row=pd.Categorical(df['row'], categories=row_order, ordered=True),
            column=pd.Categorical(df['column'], categories=col_order, ordered=True),
        )

        grouped = df.groupby(['row', 'column'], observed=True)['count'].sum().reset_index()
        plate_map = pd.pivot_table(grouped, values='count', index='row', columns='column').fillna(0)

        if isinstance(min_max, str) and min_max == 'all':
            min_max = [plate_map.min().min(), plate_map.max().max()]
        elif isinstance(min_max, str) and min_max == 'allq':
            min_max = np.quantile(plate_map.values, [0.02, 0.98])
        elif isinstance(min_max, (list, tuple)) and len(min_max) == 2:
            low, high = min_max
            if isinstance(low, float) and isinstance(high, float):
                min_max = np.quantile(plate_map.values, [low, high])
            elif isinstance(low, int) and isinstance(high, int):
                min_max = [low, high]

        return plate_map, min_max

    df = path if isinstance(path, pd.DataFrame) else pd.read_csv(path)

    plates = df['plate_row'].dropna().str.split('_', n=1).str[0].unique()
    if len(plates) == 0:
        raise ValueError("No plates found in the 'plate_row' column.")

    n_cols = 4
    n_rows = (len(plates) + n_cols - 1) // n_cols
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(40, 5 * n_rows))
    ax = ax.flatten()

    for index, plate in enumerate(plates):
        plate_map, min_max_values = generate_grna_plate_heatmap(df, plate, min_max, min_count, specific_grna)
        sns.heatmap(plate_map, cmap=cmap, vmin=min_max_values[0], vmax=min_max_values[1], ax=ax[index])
        ax[index].set_title(plate)

    # Remove the unused axes left over in the last grid row.
    for i in range(len(plates), n_rows * n_cols):
        fig.delaxes(ax[i])

    plt.subplots_adjust(wspace=0.1, hspace=0.4)

    if save:
        if isinstance(path, pd.DataFrame):
            print('save=True ignored: a DataFrame was passed, so there is no file path to name the PDF after.')
        else:
            filename = str(path)
            if filename.endswith('.csv'):
                filename = filename[:-len('.csv')]
            if specific_grna:
                filename += f'_{specific_grna}'
            filename += '.pdf'
            plt.savefig(filename)
            print(f'saved {filename}')
    plt.show()

    return fig
|
508
747
|
|
509
748
|
def map_barcodes_folder(src, settings={}):
|
510
749
|
for file in os.listdir(src):
|
@@ -1144,4 +1383,485 @@ def generate_fraction_map(df, gene_column, min_=10, plates=['p1','p2','p3','p4']
|
|
1144
1383
|
independent_variables = independent_variables.drop('sum', axis=1)
|
1145
1384
|
independent_variables.index.name = 'prc'
|
1146
1385
|
independent_variables = independent_variables.loc[:, (independent_variables.sum() != 0)]
|
1147
|
-
return independent_variables
|
1386
|
+
return independent_variables
|
1387
|
+
|
1388
|
+
def precess_reads(csv_path, fraction_threshold, plate):
    """
    Load per-well gRNA read counts and convert them to within-well read fractions.

    Parameters:
    - csv_path (str): CSV with at least 'grna', 'count' and 'column' columns. Plate/row
      information comes from either a 'plate_row' column ('<plate>_<row>') or explicit
      'row' (and 'plate') columns.
    - fraction_threshold (float or None): If given, drop gRNA/well pairs whose fraction of
      the well's reads is below this value.
    - plate (str or None): If given, overrides the plate identifier for every row.

    Returns:
    - pd.DataFrame with these columns:
      - 'prc': '<plate>_<row>_<column>';
      - 'grna';
      - 'fraction';
      - 'gene': present only when every 'grna' splits into exactly '<org>_<gene>_<grna>'.
        In that case 'grna' is rewritten as '<gene>_<grna>'.

    Raises:
    - ValueError: If required columns or plate/row information are missing.
    """
    csv_df = pd.read_csv(csv_path)

    missing = [col for col in ['grna', 'count', 'column'] if col not in csv_df.columns]
    if missing:
        raise ValueError(f"The CSV file must contain 'grna', 'count' and 'column' columns (missing: {missing}).")

    if 'plate_row' in csv_df.columns:
        # n=1 keeps any further underscores in the row part instead of failing.
        parts = csv_df['plate_row'].str.split('_', n=1)
        csv_df['plate'] = parts.str[0]
        csv_df['row'] = parts.str[1]
    if plate is not None:
        csv_df['plate'] = plate

    if 'plate' not in csv_df.columns or 'row' not in csv_df.columns:
        raise ValueError("Plate and row information is missing: provide a 'plate_row' column, "
                         "or 'row' and 'plate' columns (or the plate argument).")

    # Well identifier: plate_row_column
    csv_df['prc'] = csv_df['plate'] + '_' + csv_df['row'] + '_' + csv_df['column']

    # Total reads per well, then each gRNA's share of that total
    totals = csv_df.groupby('prc')['count'].sum().rename('total_counts').reset_index()
    merged_df = pd.merge(csv_df, totals, on='prc')
    merged_df['fraction'] = merged_df['count'] / merged_df['total_counts']

    if fraction_threshold is not None:
        observations_before = len(merged_df)
        merged_df = merged_df[merged_df['fraction'] >= fraction_threshold]
        removed = observations_before - len(merged_df)
        print(f'Removed {removed} observation below fraction threshold: {fraction_threshold}')

    # Copy so the column assignments below do not act on a slice of the merged frame.
    merged_df = merged_df[['prc', 'grna', 'fraction']].copy()

    try:
        # ValueError: not exactly three '_'-separated parts; AttributeError: non-string grna values.
        merged_df[['org', 'gene', 'grna']] = merged_df['grna'].str.split('_', expand=True)
    except (ValueError, AttributeError):
        print('Error splitting grna into org, gene, grna.')
    else:
        merged_df = merged_df.drop(columns=['org'])
        merged_df['grna'] = merged_df['gene'] + '_' + merged_df['grna']

    return merged_df
|
1433
|
+
|
1434
|
+
def apply_transformation(X, transform):
    """Build an (unfitted) scikit-learn transformer for a named transform.

    Args:
        X: Not used; retained so existing call sites keep working.
        transform: ``'log'`` (``np.log1p``), ``'sqrt'`` (``np.sqrt``) or
            ``'square'`` (``np.square``).

    Returns:
        A ``FunctionTransformer`` wrapping the selected numpy function with
        ``validate=True``, or ``None`` when *transform* matches none of the
        supported names.
    """
    candidates = (
        ('log', np.log1p),
        ('sqrt', np.sqrt),
        ('square', np.square),
    )
    for name, func in candidates:
        if transform == name:
            return FunctionTransformer(func, validate=True)
    return None
|
1444
|
+
|
1445
|
+
def check_normality(data, variable_name, verbose=False, alpha=0.05):
    """Check whether *data* looks normally distributed (Shapiro-Wilk test).

    NaN values are removed before testing. Otherwise a single NaN makes the
    p-value NaN and the data would silently be reported as non-normal.

    Args:
        data: 1-D array-like of numeric observations (e.g. a pandas Series).
        variable_name: Label used in printed messages.
        verbose: If True, print the test statistic, p-value and verdict.
        alpha: Significance level. Normality is accepted when the p-value is
            strictly greater than ``alpha`` (default 0.05, as before).

    Returns:
        bool: True if normality is not rejected. False otherwise, including
        when fewer than 3 non-NaN observations remain (the minimum sample
        size accepted by ``shapiro``).
    """
    values = np.asarray(data, dtype=float).ravel()
    values = values[~np.isnan(values)]
    if values.size < 3:
        # shapiro() raises for n < 3; report instead of crashing the pipeline.
        print(f"Not enough observations ({values.size}) to test normality of {variable_name}.")
        return False

    stat, p_value = shapiro(values)
    if verbose:
        print(f"Shapiro-Wilk Test for {variable_name}:\nStatistic: {stat}, P-value: {p_value}")
    if p_value > alpha:
        if verbose:
            print(f"The data for {variable_name} is normally distributed.")
        return True
    if verbose:
        print(f"The data for {variable_name} is not normally distributed.")
    return False
|
1458
|
+
|
1459
|
+
def process_scores(df, dependent_variable, plate, min_cell_count=25, agg_type='mean', transform=None, regression_type='ols'):
    """Aggregate per-cell scores into one value per well (``prc``).

    The well key ``prc`` is ``plate + '_' + row + '_' + col``. *df* is
    modified in place: ``plate`` (when given), ``col`` (copied from
    ``column`` if absent) and ``prc`` are written to it.

    Args:
        df: Per-cell DataFrame with ``row``, ``column`` or ``col``, ``plate``
            (unless *plate* is given) and *dependent_variable* columns.
        dependent_variable: Name of the score column to aggregate.
        plate: If not None, assigned to the ``plate`` column for all rows.
        min_cell_count: Wells with fewer rows than this are dropped.
        agg_type: ``'mean'``, ``'median'``, ``'quantile'`` (75th percentile)
            or None (keep per-cell rows). Ignored when *regression_type* is
            ``'poisson'``.
        transform: Optional ``'log'``, ``'sqrt'`` or ``'square'``. It is
            applied to the aggregated scores and stored as a new column.
        regression_type: When ``'poisson'``, the per-well sum is used.

    Returns:
        tuple: ``(dependent_df, dependent_variable)``.
        *dependent_df* contains at least ``prc``, the score column(s) and
        ``cell_count``. The returned *dependent_variable* is the name of the
        column to model: the transformed column's name when *transform* is
        set.

    Raises:
        ValueError: For an unsupported *agg_type* or *transform*.
    """
    if plate is not None:
        df['plate'] = plate

    if 'col' not in df.columns:
        df['col'] = df['column']

    df['prc'] = df['plate'] + '_' + df['row'] + '_' + df['col']
    df = df[['prc', dependent_variable]]

    # Group by well and aggregate the dependent variable
    grouped = df.groupby('prc')[dependent_variable]

    if regression_type != 'poisson':
        print(f'Using agg_type: {agg_type}')
        if agg_type == 'median':
            dependent_df = grouped.median().reset_index()
        elif agg_type == 'mean':
            dependent_df = grouped.mean().reset_index()
        elif agg_type == 'quantile':
            dependent_df = grouped.quantile(0.75).reset_index()
        elif agg_type is None:
            # Keep one row per cell; cell_count is attached per well below.
            dependent_df = df.reset_index()
        else:
            raise ValueError(f"Unsupported aggregation type {agg_type}")
    else:
        agg_type = 'count'
        print(f'Using agg_type: {agg_type} for poisson regression')
        dependent_df = grouped.sum().reset_index()

    # Attach the number of cells per well by key (not by row position).
    cell_count = grouped.size().reset_index(name='cell_count')
    dependent_df = pd.merge(dependent_df, cell_count, on='prc')

    # .copy() so that adding columns below does not write into a slice.
    dependent_df = dependent_df[dependent_df['cell_count'] >= min_cell_count].copy()

    is_normal = check_normality(dependent_df[dependent_variable], dependent_variable)

    if transform is not None:
        transformer = apply_transformation(dependent_df[dependent_variable], transform=transform)
        if transformer is None:
            raise ValueError(f"Unsupported transform {transform}; expected 'log', 'sqrt' or 'square'")
        transformed_var = f'{transform}_{dependent_variable}'
        # fit_transform returns a 2-D (n, 1) array; flatten it for a single column.
        transformed = transformer.fit_transform(dependent_df[[dependent_variable]])
        dependent_df[transformed_var] = np.asarray(transformed).ravel()
        dependent_variable = transformed_var
        is_normal = check_normality(dependent_df[transformed_var], transformed_var)

    if not is_normal:
        print(f'{dependent_variable} is not normally distributed')
    else:
        print(f'{dependent_variable} is normally distributed')

    return dependent_df, dependent_variable
|
1520
|
+
|
1521
|
+
def perform_mixed_model(y, X, groups, alpha=1.0):
    """Fit a linear mixed-effects model (statsmodels ``MixedLM``) with *groups*.

    The variance inflation factor (VIF) of every column of *X* is printed.
    If any non-constant column has VIF > 10, the fixed-effect design is
    first rescaled by Ridge coefficients before fitting (see NOTE below).

    Args:
        y: Dependent variable, length n.
        X: Fixed-effects design matrix as a pandas DataFrame (``.values`` is used).
        groups: Group labels (length n) passed to ``MixedLM``. Required.
        alpha: Ridge penalty used only on the multicollinearity path.

    Returns:
        The result of ``MixedLM(...).fit()``.

    Raises:
        ValueError: If *groups* is None.
    """
    if groups is None:
        raise ValueError("Groups must be defined for mixed model regression")

    # Check for multicollinearity by calculating the VIF for each feature
    X_np = X.values
    vif = [variance_inflation_factor(X_np, i) for i in range(X_np.shape[1])]
    print(f"VIF: {vif}")

    # Constant columns (e.g. the patsy 'Intercept') have no meaningful VIF: it
    # can be huge or NaN without signalling collinearity, so skip them here.
    is_constant = X_np.max(axis=0) == X_np.min(axis=0)
    if any(v > 10 for v, const in zip(vif, is_constant) if not const):
        print(f"Multicollinearity detected with VIF: {vif}. Applying Ridge regression to the fixed effects.")
        ridge = Ridge(alpha=alpha)
        ridge.fit(X, y)
        # NOTE(review): scaling each column by its Ridge coefficient is not a
        # standard multicollinearity remedy. Columns with ~0 coefficients
        # become ~0, and MixedLM coefficients refer to rescaled features.
        # Consider dropping or combining collinear features instead.
        # ravel() keeps the coefficients 1-D when y was 2-D (coef_ is (1, p)).
        X_ridge = np.ravel(ridge.coef_) * X
        model = MixedLM(y, X_ridge, groups=groups)
    else:
        model = MixedLM(y, X, groups=groups)

    return model.fit()
|
1541
|
+
|
1542
|
+
def regression_model(X, y, regression_type='ols', groups=None, alpha=1.0, remove_row_column_effect=True):
    """Fit a regression model of the requested type.

    Args:
        X: Design matrix as a pandas DataFrame (``.iloc`` is used for 'wls'
            and for plotting).
        y: Dependent variable.
        regression_type: One of 'ols', 'gls', 'wls', 'rlm', 'glm', 'mixed',
            'quantile', 'logit', 'probit', 'poisson', 'lasso', 'ridge'.
        groups: Group labels, required for 'mixed'.
        alpha: Penalty for 'lasso', 'ridge' and 'mixed'. For 'quantile' it is
            reused as the quantile ``q``.
        remove_row_column_effect: Unused here; accepted for call compatibility.

    Returns:
        A fitted statsmodels results object, or a fitted scikit-learn
        estimator for 'lasso'/'ridge'. For those two types a scatter plot of
        the data against the second design column, with the fitted values,
        is shown.

    Raises:
        ValueError: For an unsupported *regression_type*.
    """
    if regression_type == 'ols':
        model = sm.OLS(y, X).fit()

    elif regression_type == 'gls':
        model = sm.GLS(y, X).fit()

    elif regression_type == 'wls':
        # NOTE(review): weights come from the second design column. Zero or
        # negative values there give inf/NaN weights.
        weights = 1 / np.sqrt(X.iloc[:, 1])
        model = sm.WLS(y, X, weights=weights).fit()

    elif regression_type == 'rlm':
        # Huber's T norm. Other sm.robust.norms (TukeyBiweight, Hampel,
        # LeastSquares, RamsayE, TrimmedMean) are drop-in alternatives.
        model = sm.RLM(y, X, M=sm.robust.norms.HuberT()).fit()

    elif regression_type == 'glm':
        # Gaussian family (continuous data, equivalent to OLS). Other
        # sm.families (Binomial, Poisson, Gamma, InverseGaussian,
        # NegativeBinomial, Tweedie) can be substituted for other data types.
        model = sm.GLM(y, X, family=sm.families.Gaussian()).fit()

    elif regression_type == 'mixed':
        model = perform_mixed_model(y, X, groups, alpha=alpha)

    elif regression_type == 'quantile':
        model = sm.QuantReg(y, X).fit(q=alpha)

    elif regression_type == 'logit':
        model = sm.Logit(y, X).fit()

    elif regression_type == 'probit':
        model = sm.Probit(y, X).fit()

    elif regression_type == 'poisson':
        model = sm.Poisson(y, X).fit()

    elif regression_type == 'lasso':
        model = Lasso(alpha=alpha).fit(X, y)

    elif regression_type == 'ridge':
        model = Ridge(alpha=alpha).fit(X, y)

    else:
        raise ValueError(f"Unsupported regression type {regression_type}")

    if regression_type in ['lasso', 'ridge']:
        y_pred = model.predict(X)
        x_vals = np.asarray(X.iloc[:, 1])
        # Sort by x so the fitted values are drawn as a line, not a zigzag.
        order = np.argsort(x_vals, kind='stable')
        plt.scatter(X.iloc[:, 1], y, color='blue', label='Data')
        plt.plot(x_vals[order], np.ravel(y_pred)[order], color='red', label='Regression line')
        plt.xlabel('Features')
        plt.ylabel('Dependent Variable')
        plt.legend()
        plt.show()

    return model
|
1605
|
+
|
1606
|
+
def clean_controls(df, pc, nc, other):
    """Drop rows whose ``column`` value is one of the given control columns.

    If *df* has a ``col`` column, it is copied to ``column`` in a new
    DataFrame. The caller's DataFrame is not modified.

    Args:
        df: DataFrame with a ``column`` or ``col`` column.
        pc: Positive-control column value to remove, or None.
        nc: Negative-control column value to remove, or None.
        other: Additional column value to remove, or None.

    Returns:
        The filtered DataFrame.
    """
    if 'col' in df.columns:
        # assign() returns a new frame instead of writing into the caller's df.
        df = df.assign(column=df['col'])
    controls = [value for value in (nc, pc, other) if value is not None]
    if controls:
        df = df[~df['column'].isin(controls)]
    print(f'Removed data from {nc, pc, other}')
    return df
|
1617
|
+
|
1618
|
+
# Cap (winsorize) numeric columns at the given lower/upper quantiles (1st/99th by default).
def remove_outliers(df, low=0.01, high=0.99):
    """Return a copy of *df* with numeric columns clipped to quantile bounds.

    Non-numeric columns are left unchanged. The input DataFrame is not
    modified.

    Args:
        df: Input DataFrame.
        low: Lower quantile used as the floor for each numeric column.
        high: Upper quantile used as the ceiling for each numeric column.

    Returns:
        A new DataFrame with values outside ``[q_low, q_high]`` clipped.
    """
    capped = df.copy()
    numerical_cols = capped.select_dtypes(include=[np.number]).columns
    quantiles = capped[numerical_cols].quantile([low, high])
    for col in numerical_cols:
        capped[col] = np.clip(capped[col], quantiles.loc[low, col], quantiles.loc[high, col])
    return capped
|
1625
|
+
|
1626
|
+
def calculate_p_values(X, y, model):
    """Approximate two-sided t-test p-values for a fitted sklearn linear model.

    Uses classical OLS formulas on the model's residuals. The covariance is
    ``sigma^2 * pinv(A'A)``, where ``A`` is *X* with an intercept column
    prepended. For penalized models (Ridge/Lasso) these values are only a
    heuristic.

    Args:
        X: Design matrix of shape (n_samples, n_features), DataFrame or ndarray.
        y: Observed targets, shape (n_samples,) or (n_samples, 1).
        model: Fitted estimator with ``predict`` and ``coef_`` (one
            coefficient per column of *X*).

    Returns:
        1-D ndarray with one p-value per column of *X*. All values are NaN
        when there are no residual degrees of freedom.
    """
    n_samples, n_features = X.shape[0], X.shape[1]
    coefs = np.ravel(model.coef_)

    # Flatten both sides: (n, 1) - (n,) would otherwise broadcast to (n, n).
    residuals = np.ravel(y) - np.ravel(model.predict(X))

    dof = n_samples - n_features - 1  # one extra parameter for the intercept
    if dof <= 0:
        return np.full(coefs.shape, np.nan)

    residual_var = np.sum(residuals ** 2) / dof

    # Use pseudoinverse instead of inverse to handle singular matrices
    X_design = np.hstack((np.ones((n_samples, 1)), np.asarray(X, dtype=float)))
    coef_var_covar = residual_var * np.linalg.pinv(X_design.T @ X_design)
    coef_standard_errors = np.sqrt(np.diag(coef_var_covar))[1:]  # skip intercept

    t_stats = coefs / coef_standard_errors
    # sf() equals 1 - cdf() but does not round to 0 for large |t|.
    return 2 * stats.t.sf(np.abs(t_stats), dof)
|
1651
|
+
|
1652
|
+
def regression(df, csv_path, dependent_variable='predictions', regression_type=None, alpha=1.0, remove_row_column_effect=False):
    """Regress a per-well score on gRNA/gene fractions and draw a volcano plot.

    Args:
        df: Merged per-well DataFrame with *dependent_variable*, ``fraction``,
            ``gene``, ``grna``, ``row``, ``column`` and ``prc`` columns.
        csv_path: Path whose directory and base name are used to name the
            volcano-plot PDF.
        dependent_variable: Column to model.
        regression_type: Model type understood by ``regression_model``. When
            None, 'ols' is used if the dependent variable passes the
            Shapiro-Wilk normality check, otherwise 'glm'.
        alpha: Penalty, or the quantile for 'quantile' regression.
        remove_row_column_effect: If True, first regress the score on
            ``row + column``. The residuals are then modelled against the
            fractions (two-stage fit). Otherwise row and column are included
            as covariates.

    Returns:
        tuple: ``(model, coef_df)``. *coef_df* has ``feature``,
        ``coefficient``, ``p_value``, ``-log10(p_value)`` and ``highlight``
        columns; row/column terms are excluded.
    """
    from .plot import volcano_plot, plot_histogram

    # Resolve the default model type first: it is part of the output file name.
    is_normal = check_normality(df[dependent_variable], dependent_variable)
    if regression_type is None:
        regression_type = 'ols' if is_normal else 'glm'

    volcano_filename = os.path.splitext(os.path.basename(csv_path))[0] + '_volcano_plot.pdf'
    volcano_filename = regression_type + '_' + volcano_filename
    if regression_type == 'quantile':
        volcano_filename = str(alpha) + '_' + volcano_filename
    volcano_path = os.path.join(os.path.dirname(csv_path), volcano_filename)

    if remove_row_column_effect:
        # Two-stage fit. Stage 1 estimates plate row/column effects, which can
        # introduce systematic bias and collinearity. Stage 2 models the
        # residuals (the part not explained by row/column) against the
        # fraction:gene and fraction:grna terms.
        df = df.copy()  # do not write 'residuals' into the caller's frame
        formula_with_row_col = f'{dependent_variable} ~ row + column'
        y_with_row_col, X_with_row_col = dmatrices(formula_with_row_col, data=df, return_type='dataframe')

        initial_model = regression_model(X_with_row_col, y_with_row_col, regression_type=regression_type, alpha=alpha)

        if hasattr(initial_model, 'resid'):
            df['residuals'] = initial_model.resid
        else:
            # sklearn models: flatten the prediction so (n,) - (n, 1) does not broadcast to (n, n).
            df['residuals'] = y_with_row_col.values.ravel() - np.ravel(initial_model.predict(X_with_row_col))

        formula_without_row_col = 'residuals ~ fraction:gene + fraction:grna'
        y, X = dmatrices(formula_without_row_col, data=df, return_type='dataframe')

        plot_histogram(df, 'residuals')
    else:
        formula = f'{dependent_variable} ~ fraction:gene + fraction:grna + row + column'
        y, X = dmatrices(formula, data=df, return_type='dataframe')

        plot_histogram(y, dependent_variable)

    X, y = _minmax_scale_design(X, y)

    groups = df['prc'] if regression_type == 'mixed' else None
    print(f'performing {regression_type} regression')
    model = regression_model(X, y, regression_type=regression_type, groups=groups, alpha=alpha, remove_row_column_effect=remove_row_column_effect)

    statsmodels_types = ['ols', 'gls', 'wls', 'rlm', 'glm', 'mixed', 'quantile', 'logit', 'probit', 'poisson']
    if regression_type in statsmodels_types:
        coefs = model.params
        p_values = model.pvalues
        coef_df = pd.DataFrame({
            'feature': coefs.index,
            'coefficient': coefs.values,
            'p_value': p_values.values
        })
    elif regression_type in ['ridge', 'lasso']:
        # sklearn does not report p-values; approximate them from the residuals.
        coefs = np.ravel(model.coef_)
        p_values = np.ravel(calculate_p_values(X, y, model))
        coef_df = pd.DataFrame({
            'feature': X.columns,
            'coefficient': coefs,
            'p_value': p_values})
    else:
        raise ValueError(f"Unsupported regression type {regression_type}")

    coef_df['-log10(p_value)'] = -np.log10(coef_df['p_value'])

    # Flag features mentioning this hard-coded gene identifier for the volcano plot.
    coef_df['highlight'] = coef_df['feature'].apply(lambda x: '220950' in x)
    coef_df = coef_df[~coef_df['feature'].str.contains('row|column')]
    volcano_plot(coef_df, volcano_path)

    return model, coef_df


def _minmax_scale_design(X, y):
    """Min-max scale a patsy design matrix and target; keep the intercept at 1.

    MinMaxScaler maps constant columns to 0. Left alone, that would zero the
    'Intercept' column and silently drop the intercept from statsmodels fits.
    """
    scaler_X = MinMaxScaler()
    scaler_y = MinMaxScaler()
    X_scaled = pd.DataFrame(scaler_X.fit_transform(X), columns=X.columns)
    if 'Intercept' in X_scaled.columns:
        X_scaled['Intercept'] = 1.0
    return X_scaled, scaler_y.fit_transform(y)
|
1772
|
+
|
1773
|
+
def perform_regression(df, settings):
    """Run the screen regression pipeline and write its results next to the reads CSV.

    Steps:
        1. Apply default settings.
        2. Validate the regression type and load *df*.
        3. Remove control columns.
        4. Aggregate per-cell scores per well.
        5. Read gRNA fractions and merge them with the scores.
        6. Fit the regression and save all coefficients plus the significant
           subset as CSV.
        7. Optionally annotate the results with gene metadata.

    Args:
        df: Per-cell DataFrame, or a path to a CSV file with the same content.
        settings: Dict of options. Missing keys are filled by
            ``get_perform_regression_default_settings``.

    Returns:
        DataFrame of regression coefficients and p-values.

    Raises:
        ValueError: For an unsupported regression type, an unsupported *df*
            type, or a missing dependent-variable column.
    """
    from spacr.plot import plot_plates
    from .utils import merge_regression_res_with_metadata
    from .settings import get_perform_regression_default_settings

    # Fill defaults before any key is read, so omitted settings do not raise KeyError.
    settings = get_perform_regression_default_settings(settings)

    reg_types = ['ols', 'gls', 'wls', 'rlm', 'glm', 'mixed', 'quantile', 'logit', 'probit', 'poisson', 'lasso', 'ridge']
    if settings['regression_type'] not in reg_types:
        print(f'Possible regression types: {reg_types}')
        raise ValueError(f"Unsupported regression type {settings['regression_type']}")

    if isinstance(df, str):
        df = pd.read_csv(df)
    elif not isinstance(df, pd.DataFrame):
        raise ValueError("Data must be a DataFrame or a path to a CSV file")

    if settings['dependent_variable'] not in df.columns:
        print('Columns in DataFrame:')
        for col in df.columns:
            print(col)
        raise ValueError(f"Dependent variable {settings['dependent_variable']} not found in the DataFrame")

    base_name = os.path.splitext(os.path.basename(settings['gene_weights_csv']))[0]
    results_filename = settings['regression_type'] + '_' + base_name + '_results.csv'
    hits_filename = settings['regression_type'] + '_' + base_name + '_results_significant.csv'
    if settings['regression_type'] == 'quantile':
        results_filename = str(settings['alpha']) + '_' + results_filename
        hits_filename = str(settings['alpha']) + '_' + hits_filename
    results_path = os.path.join(os.path.dirname(settings['gene_weights_csv']), results_filename)
    hits_path = os.path.join(os.path.dirname(settings['gene_weights_csv']), hits_filename)

    settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
    settings_dir = os.path.dirname(settings['gene_weights_csv'])
    settings_csv = os.path.join(settings_dir, f"{settings['regression_type']}_regression_settings.csv")
    settings_df.to_csv(settings_csv, index=False)
    display(settings_df)

    df = clean_controls(df, settings['pc'], settings['nc'], settings['other'])

    if 'prediction_probability_class_1' in df.columns:
        if settings['class_1_threshold'] is not None:
            df['predictions'] = (df['prediction_probability_class_1'] >= settings['class_1_threshold']).astype(int)

    # Forward regression_type so process_scores can use per-well sums for Poisson models.
    dependent_df, dependent_variable = process_scores(
        df, settings['dependent_variable'], settings['plate'], settings['min_cell_count'],
        settings['agg_type'], settings['transform'], regression_type=settings['regression_type'])

    display(dependent_df)

    independent_df = precess_reads(settings['gene_weights_csv'], settings['fraction_threshold'], settings['plate'])
    display(independent_df)

    merged_df = pd.merge(independent_df, dependent_df, on='prc')
    merged_df[['plate', 'row', 'column']] = merged_df['prc'].str.split('_', expand=True)

    if settings['transform'] is None:
        _ = plot_plates(df, variable=dependent_variable, grouping='mean', min_max='allq', cmap='viridis', min_count=settings['min_cell_count'])

    model, coef_df = regression(merged_df, settings['gene_weights_csv'], dependent_variable, settings['regression_type'], settings['alpha'], settings['remove_row_column_effect'])

    coef_df.to_csv(results_path, index=False)

    if settings['regression_type'] == 'lasso':
        # Lasso has no reliable p-values; keep features with a positive coefficient.
        significant = coef_df[coef_df['coefficient'] > 0]
    else:
        significant = coef_df[coef_df['p_value'] <= 0.05]
    # Reassign rather than sort in place: the filtered frame may be a view.
    significant = significant.sort_values(by='coefficient', ascending=False)
    significant = significant[~significant['feature'].str.contains('row|column')]

    if settings['regression_type'] == 'ols':
        print(model.summary())

    significant.to_csv(hits_path, index=False)

    # NOTE(review): these gene-annotation tables point at one workstation.
    # They are merged only when present, so the pipeline still finishes elsewhere.
    me49 = '/home/carruthers/Documents/TGME49_Summary.csv'
    gt1 = '/home/carruthers/Documents/TGGT1_Summary.csv'
    metadata_sources = ((me49, '_me49_metadata'), (gt1, '_gt1_metadata'))
    for target_path in (hits_path, results_path):
        for metadata_path, suffix in metadata_sources:
            if not os.path.exists(metadata_path):
                print(f'Metadata file not found, skipping: {metadata_path}')
                continue
            _ = merge_regression_res_with_metadata(target_path, metadata_path, name=suffix)

    print('Significant Genes')
    display(significant)
    return coef_df
|
1866
|
+
|
1867
|
+
|