spacr 0.0.81__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/sequencing.py CHANGED
@@ -7,10 +7,19 @@ import matplotlib.pyplot as plt
  import seaborn as sns
  from Bio import pairwise2
  import statsmodels.api as sm
- import statsmodels.formula.api as smf
+ from statsmodels.regression.mixed_linear_model import MixedLM
+ from statsmodels.stats.outliers_influence import variance_inflation_factor
  from scipy.stats import gmean
+ from scipy import stats
  from difflib import SequenceMatcher
  from collections import Counter
+ from IPython.display import display
+
+ from sklearn.linear_model import LinearRegression, Lasso, Ridge
+ from sklearn.preprocessing import FunctionTransformer, MinMaxScaler
+
+ from scipy.stats import shapiro
+ from patsy import dmatrices

  def analyze_reads(settings):
  """
@@ -28,7 +37,7 @@ def analyze_reads(settings):
  None
  """

- def save_chunk_to_hdf5(output_file_path, data_chunk, chunk_counter):
+ def save_chunk_to_hdf5_v1(output_file_path, data_chunk, chunk_counter):
  """
  Save a data chunk to an HDF5 file.

@@ -44,6 +53,28 @@ def analyze_reads(settings):
  with pd.HDFStore(output_file_path, mode='a', complevel=5, complib='blosc') as store:
  store.put(f'reads/chunk_{chunk_counter}', df, format='table', append=True)

+ def save_chunk_to_hdf5(output_file_path, data_chunk, chunk_counter):
+ """
+ Save a data chunk to an HDF5 file.
+
+ Parameters:
+ - output_file_path (str): The path to the output HDF5 file.
+ - data_chunk (list): The data chunk to be saved.
+ - chunk_counter (int): The counter for the current chunk.
+
+ Returns:
+ None
+ """
+ df = pd.DataFrame(data_chunk, columns=['combined_read', 'grna', 'plate_row', 'column', 'sample'])
+ with pd.HDFStore(output_file_path, mode='a', complevel=5, complib='blosc') as store:
+ store.put(
+ f'reads/chunk_{chunk_counter}',
+ df,
+ format='table',
+ append=True,
+ min_itemsize={'combined_read': 300, 'grna': 50, 'plate_row': 20, 'column': 20, 'sample': 50}
+ )
+
  def reverse_complement(seq):
  """
  Returns the reverse complement of a DNA sequence.
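The new save_chunk_to_hdf5 pins min_itemsize so that later chunks containing longer strings than the first chunk can still be appended to the same table-format store. A minimal sketch of why that matters (file name, column names and sizes here are illustrative only):

    import pandas as pd

    chunk_1 = pd.DataFrame({'grna': ['ACGT'], 'sample': ['s1']})
    chunk_2 = pd.DataFrame({'grna': ['ACGTACGTACGTACGTACGT'], 'sample': ['s2']})

    with pd.HDFStore('reads_sketch.h5', mode='w', complevel=5, complib='blosc') as store:
        # Reserve 50 characters per string column up front; without min_itemsize the
        # second append would fail because its strings exceed the first chunk's width.
        store.put('reads/chunk_0', chunk_1, format='table', append=True,
                  min_itemsize={'grna': 50, 'sample': 50})
        store.put('reads/chunk_0', chunk_2, format='table', append=True)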
@@ -139,7 +170,7 @@ def analyze_reads(settings):
  best_alignment = alignments[0]
  return best_alignment

- def combine_reads(samples_dict, src, chunk_size, barecode_length, upstream, downstream):
+ def combine_reads(samples_dict, src, chunk_size, barecode_length_1, barecode_length_2, upstream, downstream):
  """
  Combine reads from paired-end sequencing files and save the combined reads to a new file.

@@ -186,7 +217,7 @@ def analyze_reads(settings):
  r1_size_est = os.path.getsize(r1_path) // (avg_read_length * 4) if r1_path else 0
  r2_size_est = os.path.getsize(r2_path) // (avg_read_length * 4) if r2_path else 0
  max_size = max(r1_size_est, r2_size_est) * 10
-
+ test10 =0
  with tqdm(total=max_size, desc=f"Processing {sample}") as pbar:
  total_length_processed = 0
  read_count = 0
@@ -229,12 +260,26 @@ def analyze_reads(settings):
  combo_split_index_1 = read_combo.find(upstream)
  combo_split_index_2 = read_combo.find(downstream)

- barcode_1 = read_combo[combo_split_index_1 - barecode_length:combo_split_index_1]
+ barcode_1 = read_combo[combo_split_index_1 - barecode_length_1:combo_split_index_1]
  grna = read_combo[combo_split_index_1 + len(upstream):combo_split_index_2]
- barcode_2 = read_combo[combo_split_index_2 + len(downstream):combo_split_index_2 + len(downstream) + barecode_length]
+ barcode_2 = read_combo[combo_split_index_2 + len(downstream):combo_split_index_2 + len(downstream) + barecode_length_2]
  barcode_2 = reverse_complement(barcode_2)
  data_chunk.append((read_combo, grna, barcode_1, barcode_2, sample))

+ if settings['test']:
+ if read_count % 1000 == 0:
+ print(f"Read count: {read_count}")
+ print(f"Read 1: {r1_read_rc}")
+ print(f"Read 2: {r2_read}")
+ print(f"Read combo: {read_combo}")
+ print(f"Barcode 1: {barcode_1}")
+ print(f"gRNA: {grna}")
+ print(f"Barcode 2: {barcode_2}")
+ print()
+ test10 += 1
+ if test10 == 10:
+ break
+
  read_count += 1
  total_length_processed += len(r1_read) + len(r2_read)
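The rewritten block slices two independently sized barcodes around the fixed upstream/downstream anchors and reverse-complements the trailing one. A small sketch of the same slicing on a made-up combined read (anchor sequences and lengths are illustrative, not the package defaults):

    def reverse_complement(seq):
        return seq.translate(str.maketrans('ACGTN', 'TGCAN'))[::-1]

    upstream, downstream = 'AAGTT', 'GTTTA'        # illustrative anchors
    read_combo = 'CCCCCCCCAAGTTACGTACGTGTTTAGGGG'  # barcode_1 + anchor + gRNA + anchor + barcode_2
    i1, i2 = read_combo.find(upstream), read_combo.find(downstream)

    barcode_1 = read_combo[i1 - 8:i1]                          # barecode_length_1 = 8
    grna = read_combo[i1 + len(upstream):i2]
    barcode_2 = reverse_complement(read_combo[i2 + len(downstream):i2 + len(downstream) + 4])  # barecode_length_2 = 4
    print(barcode_1, grna, barcode_2)   # CCCCCCCC ACGTACGT CCCC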

@@ -261,13 +306,12 @@ def analyze_reads(settings):
  qc_df = pd.DataFrame([qc])
  qc_df.to_csv(qc_file_path, index=False)

- settings.setdefault('upstream', 'CTTCTGGTAAATGGGGATGTCAAGTT')
- settings.setdefault('downstream', 'GTTTAAGAGCTATGCTGGAAACAGCA')
- settings.setdefault('barecode_length', 8)
- settings.setdefault('chunk_size', 1000000)
-
+ from .utils import get_analyze_reads_default_settings
+
+ settings = get_analyze_reads_default_settings(settings)
+
  samples_dict = parse_gz_files(settings['src'])
- combine_reads(samples_dict, settings['src'], settings['chunk_size'], settings['barecode_length'], settings['upstream'], settings['downstream'])
+ combine_reads(samples_dict, settings['src'], settings['chunk_size'], settings['barecode_length_1'], settings['barecode_length_2'], settings['upstream'], settings['downstream'])
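The hard-coded setdefault calls are replaced by a helper imported from .utils whose body is not part of this diff. A hypothetical sketch of its likely shape, assuming it simply centralises the same defaults (the two barcode-length values below are assumptions, not confirmed by the diff):

    def get_analyze_reads_default_settings(settings):
        # Hypothetical sketch; the real helper lives in spacr.utils and is not shown here.
        settings.setdefault('upstream', 'CTTCTGGTAAATGGGGATGTCAAGTT')
        settings.setdefault('downstream', 'GTTTAAGAGCTATGCTGGAAACAGCA')
        settings.setdefault('barecode_length_1', 8)   # assumed split of the old 'barecode_length'
        settings.setdefault('barecode_length_2', 7)   # assumed value
        settings.setdefault('chunk_size', 1000000)
        settings.setdefault('test', False)
        return settings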

  def map_barcodes(h5_file_path, settings={}):
  """
@@ -280,27 +324,20 @@ def map_barcodes(h5_file_path, settings={}):
  Returns:
  None
  """
- def get_read_qc(df, df_cleaned):
+ def get_read_qc(df, settings):
  """
  Calculate quality control metrics for sequencing reads.

  Parameters:
  - df: DataFrame containing the sequencing reads.
- - df_cleaned: DataFrame containing the cleaned sequencing reads.

  Returns:
- - qc_dict: Dictionary containing the following quality control metrics:
- - 'reads': Total number of reads.
- - 'cleaned_reads': Total number of cleaned reads.
- - 'NaN_grna': Number of reads with missing 'grna_metadata'.
- - 'NaN_plate_row': Number of reads with missing 'plate_row_metadata'.
- - 'NaN_column': Number of reads with missing 'column_metadata'.
- - 'NaN_plate': Number of reads with missing 'plate_metadata'.
- - 'unique_grna': Counter object containing the count of unique 'grna_metadata' values.
- - 'unique_plate_row': Counter object containing the count of unique 'plate_row_metadata' values.
- - 'unique_column': Counter object containing the count of unique 'column_metadata' values.
- - 'unique_plate': Counter object containing the count of unique 'plate_metadata' values.
+ - df_cleaned: DataFrame containing the cleaned sequencing reads.
+ - qc_dict: Dictionary containing the quality control metrics.
  """
+
+ df_cleaned = df.dropna()
+
  qc_dict = {}
  qc_dict['reads'] = len(df)
  qc_dict['cleaned_reads'] = len(df_cleaned)
@@ -312,9 +349,56 @@ def map_barcodes(h5_file_path, settings={}):
  qc_dict['unique_plate_row'] = Counter(df['plate_row_metadata'].dropna().tolist())
  qc_dict['unique_column'] = Counter(df['column_metadata'].dropna().tolist())
  qc_dict['unique_plate'] = Counter(df['plate_metadata'].dropna().tolist())
+
+ # Calculate control error rates using cleaned DataFrame
+ total_pc_non_nan = df_cleaned[(df_cleaned['column_metadata'] == settings['pc_loc'])].shape[0]
+ total_nc_non_nan = df_cleaned[(df_cleaned['column_metadata'] == settings['nc_loc'])].shape[0]

- return qc_dict
-
+ pc_count_pc = df_cleaned[(df_cleaned['column_metadata'] == settings['pc_loc']) & (df_cleaned['grna_metadata'] == settings['pc'])].shape[0]
+ nc_count_nc = df_cleaned[(df_cleaned['column_metadata'] == settings['nc_loc']) & (df_cleaned['grna_metadata'] == settings['nc'])].shape[0]
+
+ pc_error_count = df_cleaned[(df_cleaned['column_metadata'] == settings['pc_loc']) & (df_cleaned['grna_metadata'] != settings['pc'])].shape[0]
+ nc_error_count = df_cleaned[(df_cleaned['column_metadata'] == settings['nc_loc']) & (df_cleaned['grna_metadata'] != settings['nc'])].shape[0]
+
+ pc_in_nc_loc_count = df_cleaned[(df_cleaned['column_metadata'] == settings['nc_loc']) & (df_cleaned['grna_metadata'] == settings['pc'])].shape[0]
+ nc_in_pc_loc_count = df_cleaned[(df_cleaned['column_metadata'] == settings['pc_loc']) & (df_cleaned['grna_metadata'] == settings['nc'])].shape[0]
+
+ # Collect QC metrics into a dictionary
+ # PC
+ qc_dict['pc_total_count'] = total_pc_non_nan
+ qc_dict['pc_count_pc'] = pc_count_pc
+ qc_dict['nc_count_pc'] = pc_in_nc_loc_count
+ qc_dict['pc_error_count'] = pc_error_count
+ # NC
+ qc_dict['nc_total_count'] = total_nc_non_nan
+ qc_dict['nc_count_nc'] = nc_count_nc
+ qc_dict['pc_count_nc'] = nc_in_pc_loc_count
+ qc_dict['nc_error_count'] = nc_error_count
+
+ return df_cleaned, qc_dict
+
+ def get_per_row_qc(df, settings):
+ """
+ Calculate quality control metrics for each unique row in the control columns.
+
+ Parameters:
+ - df: DataFrame containing the sequencing reads.
+ - settings: Dictionary containing the settings for control values.
+
+ Returns:
+ - dict: Dictionary containing the quality control metrics for each unique row.
+ """
+ qc_dict_per_row = {}
+ unique_rows = df['plate_row_metadata'].dropna().unique().tolist()
+ unique_rows = list(set(unique_rows)) # Remove duplicates
+
+ for row in unique_rows:
+ df_row = df[(df['plate_row_metadata'] == row)]
+ _, qc_dict_row = get_read_qc(df_row, settings)
+ qc_dict_per_row[row] = qc_dict_row
+
+ return qc_dict_per_row
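get_read_qc now also counts how often the expected positive/negative-control gRNA appears in its own control column versus elsewhere, which later becomes the pc/nc fraction metrics. A toy illustration of that counting on a tiny cleaned DataFrame (the control names and column locations are invented):

    import pandas as pd

    settings = {'pc': 'PC_gRNA', 'nc': 'NC_gRNA', 'pc_loc': 'c1', 'nc_loc': 'c2'}
    df_cleaned = pd.DataFrame({
        'column_metadata': ['c1', 'c1', 'c2', 'c2', 'c3'],
        'grna_metadata':   ['PC_gRNA', 'NC_gRNA', 'NC_gRNA', 'NC_gRNA', 'PC_gRNA'],
    })

    pc_rows = df_cleaned[df_cleaned['column_metadata'] == settings['pc_loc']]
    pc_total = len(pc_rows)
    pc_count_pc = (pc_rows['grna_metadata'] == settings['pc']).sum()
    # Fraction of reads in the PC column that actually carry the PC gRNA (0.5 here)
    print(pc_count_pc / pc_total if pc_total else 0)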
+
  def mapping_dicts(df, settings):
  """
  Maps the values in the DataFrame columns to corresponding metadata using dictionaries.
@@ -339,22 +423,87 @@ def map_barcodes(h5_file_path, settings={}):
  df['plate_row_metadata'] = df['plate_row'].map(plate_row_dict)
  df['column_metadata'] = df['column'].map(column_dict)
  df['plate_metadata'] = df['sample'].map(plate_dict)
-
+
  return df

- settings.setdefault('grna', '/home/carruthers/Documents/grna_barcodes.csv')
- settings.setdefault('barcodes', '/home/carruthers/Documents/SCREEN_BARCODES.csv')
- settings.setdefault('plate_dict', {'EO1': 'plate1', 'EO2': 'plate2', 'EO3': 'plate3', 'EO4': 'plate4', 'EO5': 'plate5', 'EO6': 'plate6', 'EO7': 'plate7', 'EO8': 'plate8'})
- settings.setdefault('test', False)
- settings.setdefault('verbose', True)
- settings.setdefault('min_itemsize', 1000)
-
- qc_file_path = os.path.splitext(h5_file_path)[0] + '_qc_step_2.csv'
- unique_grna_file_path = os.path.splitext(h5_file_path)[0] + '_unique_grna.csv'
- unique_plate_row_file_path = os.path.splitext(h5_file_path)[0] + '_unique_plate_row.csv'
- unique_column_file_path = os.path.splitext(h5_file_path)[0] + '_unique_column.csv'
- unique_plate_file_path = os.path.splitext(h5_file_path)[0] + '_unique_plate.csv'
- new_h5_file_path = os.path.splitext(h5_file_path)[0] + '_cleaned.h5'
+ def filter_combinations(df, settings):
+ """
+ Takes the combination-counts DataFrame, filters its rows using the control columns,
+ and removes rows with a count lower than the highest value of max_count_c1 and max_count_c2.
+
+ Args:
+ df (pd.DataFrame): Combination counts with 'plate_row', 'column', 'grna' and 'count' columns.
+ settings (dict): Settings providing the control identities and locations ('pc', 'nc', 'pc_loc', 'nc_loc') and 'verbose'.
+
+ Returns:
+ pd.DataFrame: The filtered DataFrame with 'total_reads' and 'read_fraction' columns added.
+ """
+
+ pc = settings['pc']
+ nc = settings['nc']
+ pc_loc = settings['pc_loc']
+ nc_loc = settings['nc_loc']
+
+ filtered_c1 = df[(df['column'] == nc_loc) & (df['grna'] != nc)]
+ max_count_c1 = filtered_c1['count'].max()
+
+ filtered_c2 = df[(df['column'] == pc_loc) & (df['grna'] != pc)]
+ max_count_c2 = filtered_c2['count'].max()
+
+ #filtered_c3 = df[(df['column'] != nc_loc) & (df['grna'] == nc)]
+ #max_count_c3 = filtered_c3['count'].max()
+
+ #filtered_c4 = df[(df['column'] != pc_loc) & (df['grna'] == pc)]
+ #max_count_c4 = filtered_c4['count'].max()
+
+ # Find the highest value between max_count_c1 and max_count_c2
+ highest_max_count = max(max_count_c1, max_count_c2)
+
+ # Filter the DataFrame to remove rows with a count lower than the highest_max_count
+ filtered_df = df[df['count'] >= highest_max_count]
+
+ # Calculate total read counts for each unique combination of plate_row and column
+ filtered_df['total_reads'] = filtered_df.groupby(['plate_row', 'column'])['count'].transform('sum')
+
+ # Calculate read fraction for each row
+ filtered_df['read_fraction'] = filtered_df['count'] / filtered_df['total_reads']
+
+ if settings['verbose']:
+ print(f"Max count for non {nc} in {nc_loc}: {max_count_c1}")
+ print(f"Max count for non {pc} in {pc_loc}: {max_count_c2}")
+ #print(f"Max count for {nc} in other columns: {max_count_c3}")
+
+ return filtered_df
+
+ from .settings import get_map_barcodes_default_settings
+
+ settings = get_map_barcodes_default_settings(settings)
+
+ fldr = os.path.splitext(h5_file_path)[0]
+ file_name = os.path.basename(fldr)
+
+ if settings['test']:
+ fldr = os.path.join(fldr, 'test')
+ os.makedirs(fldr, exist_ok=True)
+
+ qc_file_path = os.path.join(fldr, f'{file_name}_qc_step_2.csv')
+ unique_grna_file_path = os.path.join(fldr, f'{file_name}_unique_grna.csv')
+ unique_plate_row_file_path = os.path.join(fldr, f'{file_name}_unique_plate_row.csv')
+ unique_column_file_path = os.path.join(fldr, f'{file_name}_unique_column.csv')
+ unique_plate_file_path = os.path.join(fldr, f'{file_name}_unique_plate.csv')
+ new_h5_file_path = os.path.join(fldr, f'{file_name}_cleaned.h5')
+ combination_counts_file_path = os.path.join(fldr, f'{file_name}_combination_counts.csv')
+ combination_counts_file_path_cleaned = os.path.join(fldr, f'{file_name}_combination_counts_cleaned.csv')
+
+ #qc_file_path = os.path.splitext(h5_file_path)[0] + '_qc_step_2.csv'
+ #unique_grna_file_path = os.path.splitext(h5_file_path)[0] + '_unique_grna.csv'
+ #unique_plate_row_file_path = os.path.splitext(h5_file_path)[0] + '_unique_plate_row.csv'
+ #unique_column_file_path = os.path.splitext(h5_file_path)[0] + '_unique_column.csv'
+ #unique_plate_file_path = os.path.splitext(h5_file_path)[0] + '_unique_plate.csv'
+ #new_h5_file_path = os.path.splitext(h5_file_path)[0] + '_cleaned.h5'
+ #combination_counts_file_path = os.path.splitext(h5_file_path)[0] + '_combination_counts.csv'
+ #combination_counts_file_path_cleaned = os.path.splitext(h5_file_path)[0] + '_combination_counts_cleaned.csv'

  # Initialize the HDF5 store for cleaned data
  store_cleaned = pd.HDFStore(new_h5_file_path, mode='a', complevel=5, complib='blosc')
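filter_combinations keeps only combinations whose count exceeds the worst contamination observed in the control columns, then derives a per-well read fraction. Because the new columns are assigned on a boolean-indexed slice, pandas may emit a SettingWithCopyWarning; a minimal sketch of the same fraction step on an explicit copy (toy data):

    import pandas as pd

    counts = pd.DataFrame({
        'plate_row': ['p1_r1', 'p1_r1', 'p1_r2'],
        'column':    ['c3', 'c3', 'c3'],
        'grna':      ['g1', 'g2', 'g1'],
        'count':     [90, 10, 50],
    })
    threshold = 5  # stand-in for highest_max_count derived from the control columns
    filtered = counts[counts['count'] >= threshold].copy()
    filtered['total_reads'] = filtered.groupby(['plate_row', 'column'])['count'].transform('sum')
    filtered['read_fraction'] = filtered['count'] / filtered['total_reads']
    print(filtered[['grna', 'read_fraction']])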
@@ -370,38 +519,89 @@ def map_barcodes(h5_file_path, settings={}):
  'unique_grna': Counter(),
  'unique_plate_row': Counter(),
  'unique_column': Counter(),
- 'unique_plate': Counter()
+ 'unique_plate': Counter(),
+ 'pc_total_count': 0,
+ 'pc_count_pc': 0,
+ 'nc_total_count': 0,
+ 'nc_count_nc': 0,
+ 'pc_count_nc': 0,
+ 'nc_count_pc': 0,
+ 'pc_error_count': 0,
+ 'nc_error_count': 0,
+ 'pc_fraction_pc': 0,
+ 'nc_fraction_nc': 0,
+ 'pc_fraction_nc': 0,
+ 'nc_fraction_pc': 0
  }

+ per_row_qc = {}
+ combination_counts = Counter()
+
  with pd.HDFStore(h5_file_path, mode='r') as store:
  keys = [key for key in store.keys() if key.startswith('/reads/chunk_')]
-
+
+ if settings['test']:
+ keys = keys[:3] # Only read the first chunks if in test mode
+
  for key in keys:
  df = store.get(key)
  df = mapping_dicts(df, settings)
- df_cleaned = df.dropna()
- qc_dict = get_read_qc(df, df_cleaned)
-
- # Accumulate QC metrics
- overall_qc['reads'] += qc_dict['reads']
- overall_qc['cleaned_reads'] += qc_dict['cleaned_reads']
- overall_qc['NaN_grna'] += qc_dict['NaN_grna']
- overall_qc['NaN_plate_row'] += qc_dict['NaN_plate_row']
- overall_qc['NaN_column'] += qc_dict['NaN_column']
- overall_qc['NaN_plate'] += qc_dict['NaN_plate']
- overall_qc['unique_grna'].update(qc_dict['unique_grna'])
- overall_qc['unique_plate_row'].update(qc_dict['unique_plate_row'])
- overall_qc['unique_column'].update(qc_dict['unique_column'])
- overall_qc['unique_plate'].update(qc_dict['unique_plate'])
+ df_cleaned, qc_dict = get_read_qc(df, settings)
+
+ # Accumulate counts for unique combinations
+ combinations = df_cleaned[['plate_row_metadata', 'column_metadata', 'grna_metadata']].apply(tuple, axis=1)

- df_cleaned = df_cleaned[df_cleaned['grna_length'] >= 30]
-
+ combination_counts.update(combinations)
+
+ if settings['test'] and settings['verbose']:
+ os.makedirs(os.path.join(os.path.splitext(h5_file_path)[0],'test'), exist_ok=True)
+ df.to_csv(os.path.join(os.path.splitext(h5_file_path)[0],'test','chunk_1_df.csv'), index=False)
+ df_cleaned.to_csv(os.path.join(os.path.splitext(h5_file_path)[0],'test','chunk_1_df_cleaned.csv'), index=False)
+
+ # Accumulate QC metrics for all rows
+ for metric in qc_dict:
+ if isinstance(overall_qc[metric], Counter):
+ overall_qc[metric].update(qc_dict[metric])
+ else:
+ overall_qc[metric] += qc_dict[metric]
+
+ # Update per_row_qc dictionary
+ chunk_per_row_qc = get_per_row_qc(df, settings)
+ for row in chunk_per_row_qc:
+ if row not in per_row_qc:
+ per_row_qc[row] = chunk_per_row_qc[row]
+ else:
+ for metric in chunk_per_row_qc[row]:
+ if isinstance(per_row_qc[row][metric], Counter):
+ per_row_qc[row][metric].update(chunk_per_row_qc[row][metric])
+ else:
+ per_row_qc[row][metric] += chunk_per_row_qc[row][metric]
+
+ # Ensure the DataFrame columns are in the desired order
+ df_cleaned = df_cleaned[['grna', 'plate_row', 'column', 'sample', 'grna_metadata', 'plate_row_metadata', 'column_metadata', 'plate_metadata']]
+
  # Save cleaned data to the new HDF5 store
  store_cleaned.put('reads/cleaned_data', df_cleaned, format='table', append=True)
-
+
  del df_cleaned, df
  gc.collect()

+ # Calculate overall fractions after accumulating all metrics
+ overall_qc['pc_fraction_pc'] = overall_qc['pc_count_pc'] / overall_qc['pc_total_count'] if overall_qc['pc_total_count'] else 0
+ overall_qc['nc_fraction_nc'] = overall_qc['nc_count_nc'] / overall_qc['nc_total_count'] if overall_qc['nc_total_count'] else 0
+ overall_qc['pc_fraction_nc'] = overall_qc['pc_count_nc'] / overall_qc['nc_total_count'] if overall_qc['nc_total_count'] else 0
+ overall_qc['nc_fraction_pc'] = overall_qc['nc_count_pc'] / overall_qc['pc_total_count'] if overall_qc['pc_total_count'] else 0
+
+ for row in per_row_qc:
+ if row != 'all_rows':
+ per_row_qc[row]['pc_fraction_pc'] = per_row_qc[row]['pc_count_pc'] / per_row_qc[row]['pc_total_count'] if per_row_qc[row]['pc_total_count'] else 0
+ per_row_qc[row]['nc_fraction_nc'] = per_row_qc[row]['nc_count_nc'] / per_row_qc[row]['nc_total_count'] if per_row_qc[row]['nc_total_count'] else 0
+ per_row_qc[row]['pc_fraction_nc'] = per_row_qc[row]['pc_count_nc'] / per_row_qc[row]['nc_total_count'] if per_row_qc[row]['nc_total_count'] else 0
+ per_row_qc[row]['nc_fraction_pc'] = per_row_qc[row]['nc_count_pc'] / per_row_qc[row]['pc_total_count'] if per_row_qc[row]['pc_total_count'] else 0
+
+ # Add overall_qc to per_row_qc with the key 'all_rows'
+ per_row_qc['all_rows'] = overall_qc
+
  # Convert the Counter objects to DataFrames and save them to CSV files
  unique_grna_df = pd.DataFrame(overall_qc['unique_grna'].items(), columns=['key', 'value'])
  unique_plate_row_df = pd.DataFrame(overall_qc['unique_plate_row'].items(), columns=['key', 'value'])
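The per-chunk QC dictionaries mix plain integer counts with collections.Counter objects, so the accumulation loop dispatches on the value type. A compact sketch of that merge pattern with made-up numbers:

    from collections import Counter

    overall = {'reads': 0, 'unique_grna': Counter()}
    chunk_qc = {'reads': 1000, 'unique_grna': Counter({'TGGT1_220950_1': 40, 'TGGT1_233460_4': 25})}

    for metric, value in chunk_qc.items():
        if isinstance(overall[metric], Counter):
            overall[metric].update(value)   # Counters merge by summing per-key counts
        else:
            overall[metric] += value        # plain metrics just add up
    print(overall)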
@@ -422,89 +622,128 @@ def map_barcodes(h5_file_path, settings={}):
  # Combine all remaining QC metrics into a single DataFrame and save it to CSV
  qc_df = pd.DataFrame([overall_qc])
  qc_df.to_csv(qc_file_path, index=False)
+
+ # Convert per_row_qc to a DataFrame and save it to CSV
+ per_row_qc_df = pd.DataFrame.from_dict(per_row_qc, orient='index')
+ per_row_qc_df = per_row_qc_df.sort_values(by='reads', ascending=False)
+ per_row_qc_df = per_row_qc_df.drop(['unique_grna', 'unique_plate_row', 'unique_column', 'unique_plate'], axis=1, errors='ignore')
+ per_row_qc_df = per_row_qc_df.dropna(subset=['reads'])
+ per_row_qc_df.to_csv(os.path.splitext(h5_file_path)[0] + '_per_row_qc.csv', index=True)
+
+ if settings['verbose']:
+ display(per_row_qc_df)
+
+ # Save the combination counts to a CSV file
+ try:
+ combination_counts_df = pd.DataFrame(combination_counts.items(), columns=['combination', 'count'])
+ combination_counts_df[['plate_row', 'column', 'grna']] = pd.DataFrame(combination_counts_df['combination'].tolist(), index=combination_counts_df.index)
+ combination_counts_df = combination_counts_df.drop('combination', axis=1)
+ combination_counts_df.to_csv(combination_counts_file_path, index=False)
+
+ grna_plate_heatmap(combination_counts_file_path, specific_grna=None)
+ grna_plate_heatmap(combination_counts_file_path, specific_grna=settings['pc'])
+ grna_plate_heatmap(combination_counts_file_path, specific_grna=settings['nc'])
+
+ combination_counts_df_cleaned = filter_combinations(combination_counts_df, settings)
+ combination_counts_df_cleaned.to_csv(combination_counts_file_path_cleaned, index=False)
+
+ grna_plate_heatmap(combination_counts_file_path_cleaned, specific_grna=None)
+ grna_plate_heatmap(combination_counts_file_path_cleaned, specific_grna=settings['pc'])
+ grna_plate_heatmap(combination_counts_file_path_cleaned, specific_grna=settings['nc'])
+ except Exception as e:
+ print(e)

  # Close the HDF5 store
  store_cleaned.close()
-
  gc.collect()
  return

- def map_barcodes_v1(h5_file_path, settings={}):
+ def grna_plate_heatmap(path, specific_grna=None, min_max='all', cmap='viridis', min_count=0, save=True):
+ """
+ Generate a heatmap of gRNA plate data.

- def get_read_qc(df, df_cleaned):
- qc_dict = {}
- qc_dict['reads'] = len(df)
- qc_dict['cleaned_reads'] = len(df_cleaned)
- qc_dict['NaN_grna'] = df['grna_metadata'].isna().sum()
- qc_dict['NaN_plate_row'] = df['plate_row_metadata'].isna().sum()
- qc_dict['NaN_column'] = df['column_metadata'].isna().sum()
- qc_dict['NaN_plate'] = df['plate_metadata'].isna().sum()
-
-
- qc_dict['unique_grna'] = len(df['grna_metadata'].dropna().unique().tolist())
- qc_dict['unique_plate_row'] = len(df['plate_row_metadata'].dropna().unique().tolist())
- qc_dict['unique_column'] = len(df['column_metadata'].dropna().unique().tolist())
- qc_dict['unique_plate'] = len(df['plate_metadata'].dropna().unique().tolist())
- qc_dict['value_counts_grna'] = df['grna_metadata'].value_counts(dropna=True)
- qc_dict['value_counts_plate_row'] = df['plate_row_metadata'].value_counts(dropna=True)
- qc_dict['value_counts_column'] = df['column_metadata'].value_counts(dropna=True)
+ Args:
+ path (str): The path to the CSV file containing the gRNA plate data.
+ specific_grna (str, optional): The specific gRNA to filter the data for. Defaults to None.
+ min_max (str or list or tuple, optional): The range of values to use for the color scale.
+ If 'all', the range will be determined by the minimum and maximum values in the data.
+ If 'allq', the range will be determined by the 2nd and 98th percentiles of the data.
+ If a list or tuple of two values, the range will be determined by those values.
+ Defaults to 'all'.
+ cmap (str, optional): The colormap to use for the heatmap. Defaults to 'viridis'.
+ min_count (int, optional): The minimum count threshold for including a gRNA in the heatmap.
+ Defaults to 0.
+ save (bool, optional): Whether to save the heatmap as a PDF file. Defaults to True.
+
+ Returns:
+ matplotlib.figure.Figure: The generated heatmap figure.
+ """
+ def generate_grna_plate_heatmap(df, plate_number, min_max, min_count, specific_grna=None):
+ df = df.copy() # Work on a copy to avoid SettingWithCopyWarning

- return qc_dict
+ # Filtering the dataframe based on the plate_number and specific gRNA if provided
+ df = df[df['plate_row'].str.startswith(plate_number)]
+ if specific_grna:
+ df = df[df['grna'] == specific_grna]
+
+ # Split plate_row into plate and row
+ df[['plate', 'row']] = df['plate_row'].str.split('_', expand=True)
+
+ # Ensure proper ordering
+ row_order = [f'r{i}' for i in range(1, 17)]
+ col_order = [f'c{i}' for i in range(1, 28)]
+
+ df['row'] = pd.Categorical(df['row'], categories=row_order, ordered=True)
+ df['column'] = pd.Categorical(df['column'], categories=col_order, ordered=True)
+
+ # Group by row and column, summing counts
+ grouped = df.groupby(['row', 'column'], observed=True)['count'].sum().reset_index()
+
+ plate_map = pd.pivot_table(grouped, values='count', index='row', columns='column').fillna(0)
+
+ if min_max == 'all':
+ min_max = [plate_map.min().min(), plate_map.max().max()]
+ elif min_max == 'allq':
+ min_max = np.quantile(plate_map.values, [0.02, 0.98])
+ elif isinstance(min_max, (list, tuple)) and len(min_max) == 2:
+ if isinstance(min_max[0], (float)) and isinstance(min_max[1], (float)):
+ min_max = np.quantile(plate_map.values, [min_max[0], min_max[1]])
+ if isinstance(min_max[0], (int)) and isinstance(min_max[1], (int)):
+ min_max = [min_max[0], min_max[1]]
+
+ return plate_map, min_max

- def mapping_dicts(df, settings):
- grna_df = pd.read_csv(settings['grna'])
- barcode_df = pd.read_csv(settings['barcodes'])
+ if isinstance(path, pd.DataFrame):
+ df = path
+ else:
+ df = pd.read_csv(path)

- grna_dict = {row['sequence']: row['name'] for _, row in grna_df.iterrows()}
- plate_row_dict = {row['sequence']: row['name'] for _, row in barcode_df.iterrows() if row['name'].startswith('p')}
- column_dict = {row['sequence']: row['name'] for _, row in barcode_df.iterrows() if row['name'].startswith('c')}
- plate_dict = settings['plate_dict']
+ plates = df['plate_row'].str.split('_', expand=True)[0].unique()
+ n_rows, n_cols = (len(plates) + 3) // 4, 4
+ fig, ax = plt.subplots(n_rows, n_cols, figsize=(40, 5 * n_rows))
+ ax = ax.flatten()

- df['grna_metadata'] = df['grna'].map(grna_dict)
- df['grna_length'] = df['grna'].apply(len)
- df['plate_row_metadata'] = df['plate_row'].map(plate_row_dict)
- df['column_metadata'] = df['column'].map(column_dict)
- df['plate_metadata'] = df['sample'].map(plate_dict)
+ for index, plate in enumerate(plates):
+ plate_map, min_max_values = generate_grna_plate_heatmap(df, plate, min_max, min_count, specific_grna)
+ sns.heatmap(plate_map, cmap=cmap, vmin=min_max_values[0], vmax=min_max_values[1], ax=ax[index])
+ ax[index].set_title(plate)

- return df
-
- settings.setdefault('grna', '/home/carruthers/Documents/grna_barcodes.csv')
- settings.setdefault('barcodes', '/home/carruthers/Documents/SCREEN_BARCODES.csv')
- settings.setdefault('plate_dict', {'EO1': 'plate1', 'EO2': 'plate2', 'EO3': 'plate3', 'EO4': 'plate4', 'EO5': 'plate5', 'EO6': 'plate6', 'EO7': 'plate7', 'EO8': 'plate8'})
- settings.setdefault('test', False)
- settings.setdefault('verbose', True)
- settings.setdefault('min_itemsize', 1000)
-
- qc_file_path = os.path.splitext(h5_file_path)[0] + '_qc_step_2.csv'
- new_h5_file_path = os.path.splitext(h5_file_path)[0] + '_cleaned.h5'
+ for i in range(len(plates), n_rows * n_cols):
+ fig.delaxes(ax[i])

- # Initialize the HDF5 store for cleaned data
- store_cleaned = pd.HDFStore(new_h5_file_path, mode='a', complevel=5, complib='blosc')
+ plt.subplots_adjust(wspace=0.1, hspace=0.4)

- # Initialize the DataFrame for QC metrics
- qc_df_list = []
-
- with pd.HDFStore(h5_file_path, mode='r') as store:
- keys = [key for key in store.keys() if key.startswith('/reads/chunk_')]
-
- for key in keys:
- df = store.get(key)
- df = mapping_dicts(df, settings)
- df_cleaned = df.dropna()
- qc_dict = get_read_qc(df, df_cleaned)
- qc_df_list.append(qc_dict)
- df_cleaned = df_cleaned[df_cleaned['grna_length'] >= 30]
-
- # Save cleaned data to the new HDF5 store
- store_cleaned.put('reads/cleaned_data', df_cleaned, format='table', append=True)
-
- # Combine all QC metrics into a single DataFrame and save it to CSV
- qc_df = pd.DataFrame(qc_df_list)
- qc_df.to_csv(qc_file_path, index=False)
+ # Save the figure
+ if save:
+ filename = path.replace('.csv', '')
+ if specific_grna:
+ filename += f'_{specific_grna}'
+ filename += '.pdf'
+ plt.savefig(filename)
+ print(f'saved {filename}')
+ plt.show()

- # Close the HDF5 store
- store_cleaned.close()
- return
+ return fig

  def map_barcodes_folder(src, settings={}):
  for file in os.listdir(src):
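grna_plate_heatmap pivots well counts into a row-by-column plate map and, with min_max='allq', clips the colour scale to the 2nd-98th percentiles so a few saturated wells do not flatten the rest of the plate. A short sketch of that scaling on an invented plate map:

    import numpy as np
    import pandas as pd
    import seaborn as sns
    import matplotlib.pyplot as plt

    rng = np.random.default_rng(0)
    plate_map = pd.DataFrame(rng.poisson(100, size=(16, 27)))   # 16 rows x 27 columns of counts
    plate_map.iloc[0, 0] = 10000                                # one saturated well

    vmin, vmax = np.quantile(plate_map.values, [0.02, 0.98])    # the 'allq' behaviour
    sns.heatmap(plate_map, cmap='viridis', vmin=vmin, vmax=vmax)
    plt.show()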
@@ -1144,4 +1383,485 @@ def generate_fraction_map(df, gene_column, min_=10, plates=['p1','p2','p3','p4']
  independent_variables = independent_variables.drop('sum', axis=1)
  independent_variables.index.name = 'prc'
  independent_variables = independent_variables.loc[:, (independent_variables.sum() != 0)]
- return independent_variables
+ return independent_variables
+
+ def precess_reads(csv_path, fraction_threshold, plate):
+ # Read the CSV file into a DataFrame
+ csv_df = pd.read_csv(csv_path)
+
+ # Ensure the necessary columns are present
+ if not all(col in csv_df.columns for col in ['grna', 'count', 'column']):
+ raise ValueError("The CSV file must contain 'grna', 'count', 'plate_row', and 'column' columns.")
+
+ if 'plate_row' in csv_df.columns:
+ csv_df[['plate', 'row']] = csv_df['plate_row'].str.split('_', expand=True)
+ if plate is not None:
+ csv_df = csv_df.drop(columns=['plate'])
+ csv_df['plate'] = plate
+
+ if plate is not None:
+ csv_df['plate'] = plate
+
+ # Create the prc column
+ csv_df['prc'] = csv_df['plate'] + '_' + csv_df['row'] + '_' + csv_df['column']
+
+ # Group by prc and calculate the sum of counts
+ grouped_df = csv_df.groupby('prc')['count'].sum().reset_index()
+ grouped_df = grouped_df.rename(columns={'count': 'total_counts'})
+ merged_df = pd.merge(csv_df, grouped_df, on='prc')
+ merged_df['fraction'] = merged_df['count'] / merged_df['total_counts']
+
+ # Filter rows with fraction under the threshold
+ if fraction_threshold is not None:
+ observations_before = len(merged_df)
+ merged_df = merged_df[merged_df['fraction'] >= fraction_threshold]
+ observations_after = len(merged_df)
+ removed = observations_before - observations_after
+ print(f'Removed {removed} observation below fraction threshold: {fraction_threshold}')
+
+ merged_df = merged_df[['prc', 'grna', 'fraction']]
+
+ if not all(col in merged_df.columns for col in ['grna', 'gene']):
+ try:
+ merged_df[['org', 'gene', 'grna']] = merged_df['grna'].str.split('_', expand=True)
+ merged_df = merged_df.drop(columns=['org'])
+ merged_df['grna'] = merged_df['gene'] + '_' + merged_df['grna']
+ except:
+ print('Error splitting grna into org, gene, grna.')
+
+ return merged_df
+
+ def apply_transformation(X, transform):
+ if transform == 'log':
+ transformer = FunctionTransformer(np.log1p, validate=True)
+ elif transform == 'sqrt':
+ transformer = FunctionTransformer(np.sqrt, validate=True)
+ elif transform == 'square':
+ transformer = FunctionTransformer(np.square, validate=True)
+ else:
+ transformer = None
+ return transformer
+
+ def check_normality(data, variable_name, verbose=False):
+ """Check if the data is normally distributed using the Shapiro-Wilk test."""
+ stat, p_value = shapiro(data)
+ if verbose:
+ print(f"Shapiro-Wilk Test for {variable_name}:\nStatistic: {stat}, P-value: {p_value}")
+ if p_value > 0.05:
+ if verbose:
+ print(f"The data for {variable_name} is normally distributed.")
+ return True
+ else:
+ if verbose:
+ print(f"The data for {variable_name} is not normally distributed.")
+ return False
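check_normality is a thin wrapper around scipy.stats.shapiro, and apply_transformation returns a scikit-learn FunctionTransformer for a log/sqrt/square transform. A short sketch combining the two on synthetic, right-skewed data:

    import numpy as np
    from scipy.stats import shapiro
    from sklearn.preprocessing import FunctionTransformer

    rng = np.random.default_rng(1)
    skewed = rng.lognormal(mean=0.0, sigma=1.0, size=200).reshape(-1, 1)

    print('raw p =', shapiro(skewed.ravel()).pvalue)             # typically << 0.05 (not normal)
    log_transformed = FunctionTransformer(np.log1p, validate=True).fit_transform(skewed)
    print('log1p p =', shapiro(log_transformed.ravel()).pvalue)  # usually much closer to normal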
+
+ def process_scores(df, dependent_variable, plate, min_cell_count=25, agg_type='mean', transform=None, regression_type='ols'):
+
+ if plate is not None:
+ df['plate'] = plate
+
+ if 'col' not in df.columns:
+ df['col'] = df['column']
+
+ df['prc'] = df['plate'] + '_' + df['row'] + '_' + df['col']
+ df = df[['prc', dependent_variable]]
+
+ # Group by prc and calculate the mean and count of the dependent_variable
+ grouped = df.groupby('prc')[dependent_variable]
+
+ if regression_type != 'poisson':
+
+ print(f'Using agg_type: {agg_type}')
+
+ if agg_type == 'median':
+ dependent_df = grouped.median().reset_index()
+ elif agg_type == 'mean':
+ dependent_df = grouped.mean().reset_index()
+ elif agg_type == 'quantile':
+ dependent_df = grouped.quantile(0.75).reset_index()
+ elif agg_type == None:
+ dependent_df = df.reset_index()
+ if 'prcfo' in dependent_df.columns:
+ dependent_df = dependent_df.drop(columns=['prcfo'])
+ else:
+ raise ValueError(f"Unsupported aggregation type {agg_type}")
+
+ if regression_type == 'poisson':
+ agg_type = 'count'
+ print(f'Using agg_type: {agg_type} for poisson regression')
+ dependent_df = grouped.sum().reset_index()
+
+ # Calculate cell_count for all cases
+ cell_count = grouped.size().reset_index(name='cell_count')
+
+ if agg_type is None:
+ dependent_df = pd.merge(dependent_df, cell_count, on='prc')
+ else:
+ dependent_df['cell_count'] = cell_count['cell_count']
+
+ dependent_df = dependent_df[dependent_df['cell_count'] >= min_cell_count]
+
+ is_normal = check_normality(dependent_df[dependent_variable], dependent_variable)
+
+ if not transform is None:
+ transformer = apply_transformation(dependent_df[dependent_variable], transform=transform)
+ transformed_var = f'{transform}_{dependent_variable}'
+ dependent_df[transformed_var] = transformer.fit_transform(dependent_df[[dependent_variable]])
+ dependent_variable = transformed_var
+ is_normal = check_normality(dependent_df[transformed_var], transformed_var)
+
+ if not is_normal:
+ print(f'{dependent_variable} is not normally distributed')
+ else:
+ print(f'{dependent_variable} is normally distributed')
+
+ return dependent_df, dependent_variable
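process_scores collapses single-cell scores to one value per well (the 'prc' key) and drops wells with too few cells. The same aggregation in isolation, on invented scores:

    import pandas as pd

    cells = pd.DataFrame({
        'prc':   ['p1_r1_c1'] * 30 + ['p1_r1_c2'] * 5,   # second well has only 5 cells
        'score': [0.8] * 30 + [0.2] * 5,
    })
    grouped = cells.groupby('prc')['score']
    per_well = grouped.mean().reset_index()
    per_well['cell_count'] = grouped.size().values
    per_well = per_well[per_well['cell_count'] >= 25]    # min_cell_count filter
    print(per_well)   # only p1_r1_c1 survives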
+
+ def perform_mixed_model(y, X, groups, alpha=1.0):
+ # Ensure groups are defined correctly and check for multicollinearity
+ if groups is None:
+ raise ValueError("Groups must be defined for mixed model regression")
+
+ # Check for multicollinearity by calculating the VIF for each feature
+ X_np = X.values
+ vif = [variance_inflation_factor(X_np, i) for i in range(X_np.shape[1])]
+ print(f"VIF: {vif}")
+ if any(v > 10 for v in vif):
+ print(f"Multicollinearity detected with VIF: {vif}. Applying Ridge regression to the fixed effects.")
+ ridge = Ridge(alpha=alpha)
+ ridge.fit(X, y)
+ X_ridge = ridge.coef_ * X # Adjust X with Ridge coefficients
+ model = MixedLM(y, X_ridge, groups=groups)
+ else:
+ model = MixedLM(y, X, groups=groups)
+
+ result = model.fit()
+ return result
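perform_mixed_model screens the fixed-effect design matrix for multicollinearity with variance inflation factors before fitting a MixedLM. A standalone VIF check on a deliberately collinear toy matrix:

    import numpy as np
    import pandas as pd
    from statsmodels.stats.outliers_influence import variance_inflation_factor

    rng = np.random.default_rng(2)
    x1 = rng.normal(size=100)
    X = pd.DataFrame({'const': 1.0, 'x1': x1, 'x2': x1 * 0.95 + rng.normal(scale=0.05, size=100)})

    vif = [variance_inflation_factor(X.values, i) for i in range(X.shape[1])]
    print(dict(zip(X.columns, vif)))   # x1 and x2 come out far above the usual threshold of ~10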
+
+ def regression_model(X, y, regression_type='ols', groups=None, alpha=1.0, remove_row_column_effect=True):
+
+ if regression_type == 'ols':
+ model = sm.OLS(y, X).fit()
+
+ elif regression_type == 'gls':
+ model = sm.GLS(y, X).fit()
+
+ elif regression_type == 'wls':
+ weights = 1 / np.sqrt(X.iloc[:, 1])
+ model = sm.WLS(y, X, weights=weights).fit()
+
+ elif regression_type == 'rlm':
+ model = sm.RLM(y, X, M=sm.robust.norms.HuberT()).fit()
+ #model = sm.RLM(y, X, M=sm.robust.norms.TukeyBiweight()).fit()
+ #model = sm.RLM(y, X, M=sm.robust.norms.Hampel()).fit()
+ #model = sm.RLM(y, X, M=sm.robust.norms.LeastSquares()).fit()
+ #model = sm.RLM(y, X, M=sm.robust.norms.RamsayE()).fit()
+ #model = sm.RLM(y, X, M=sm.robust.norms.TrimmedMean()).fit()
+
+ elif regression_type == 'glm':
+ model = sm.GLM(y, X, family=sm.families.Gaussian()).fit() # Gaussian: Used for continuous data, similar to OLS regression.
+ #model = sm.GLM(y, X, family=sm.families.Binomial()).fit() # Binomial: Used for binary data, modeling the probability of success.
+ #model = sm.GLM(y, X, family=sm.families.Poisson()).fit() # Poisson: Used for count data.
+ #model = sm.GLM(y, X, family=sm.families.Gamma()).fit() # Gamma: Used for continuous, positive data, often for modeling waiting times or life data.
+ #model = sm.GLM(y, X, family=sm.families.InverseGaussian()).fit() # Inverse Gaussian: Used for positive continuous data with a variance that increases with the
+ #model = sm.GLM(y, X, family=sm.families.NegativeBinomial()).fit() # Negative Binomial: Used for count data with overdispersion (variance greater than the mean).
+ #model = sm.GLM(y, X, family=sm.families.Tweedie()).fit() # Tweedie: Used for data that can take both positive continuous and count values, allowing for a mixture of distributions.
+
+ elif regression_type == 'mixed':
+ model = perform_mixed_model(y, X, groups, alpha=alpha)
+
+ elif regression_type == 'quantile':
+ model = sm.QuantReg(y, X).fit(q=alpha)
+
+ elif regression_type == 'logit':
+ model = sm.Logit(y, X).fit()
+
+ elif regression_type == 'probit':
+ model = sm.Probit(y, X).fit()
+
+ elif regression_type == 'poisson':
+ model = sm.Poisson(y, X).fit()
+
+ elif regression_type == 'lasso':
+ model = Lasso(alpha=alpha).fit(X, y)
+
+ elif regression_type == 'ridge':
+ model = Ridge(alpha=alpha).fit(X, y)
+
+ else:
+ raise ValueError(f"Unsupported regression type {regression_type}")
+
+ if regression_type in ['lasso', 'ridge']:
+ y_pred = model.predict(X)
+ plt.scatter(X.iloc[:, 1], y, color='blue', label='Data')
+ plt.plot(X.iloc[:, 1], y_pred, color='red', label='Regression line')
+ plt.xlabel('Features')
+ plt.ylabel('Dependent Variable')
+ plt.legend()
+ plt.show()
+
+ return model
+
+ def clean_controls(df,pc,nc,other):
+ if 'col' in df.columns:
+ df['column'] = df['col']
+ if nc != None:
+ df = df[~df['column'].isin([nc])]
+ if pc != None:
+ df = df[~df['column'].isin([pc])]
+ if other != None:
+ df = df[~df['column'].isin([other])]
+ print(f'Removed data from {nc, pc, other}')
+ return df
+
+ # Remove outliers by capping values at 1st and 99th percentiles for numerical columns only
+ def remove_outliers(df, low=0.01, high=0.99):
+ numerical_cols = df.select_dtypes(include=[np.number]).columns
+ quantiles = df[numerical_cols].quantile([low, high])
+ for col in numerical_cols:
+ df[col] = np.clip(df[col], quantiles.loc[low, col], quantiles.loc[high, col])
+ return df
+
+ def calculate_p_values(X, y, model):
+ # Predict y values
+ y_pred = model.predict(X)
+
+ # Calculate residuals
+ residuals = y - y_pred
+
+ # Calculate the standard error of the residuals
+ dof = X.shape[0] - X.shape[1] - 1
+ residual_std_error = np.sqrt(np.sum(residuals ** 2) / dof)
+
+ # Calculate the standard error of the coefficients
+ X_design = np.hstack((np.ones((X.shape[0], 1)), X)) # Add intercept
+
+ # Use pseudoinverse instead of inverse to handle singular matrices
+ coef_var_covar = residual_std_error ** 2 * np.linalg.pinv(X_design.T @ X_design)
+ coef_standard_errors = np.sqrt(np.diag(coef_var_covar))
+
+ # Calculate t-statistics
+ t_stats = model.coef_ / coef_standard_errors[1:] # Skip intercept error
+
+ # Calculate p-values
+ p_values = [2 * (1 - stats.t.cdf(np.abs(t), dof)) for t in t_stats]
+
+ return np.array(p_values) # Ensure p_values is a 1-dimensional array
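Because scikit-learn's Ridge and Lasso do not report p-values, calculate_p_values rebuilds approximate t-based p-values from the residual variance and the pseudo-inverse of the design matrix. A toy run of the same arithmetic, hedged: this approximation is only reasonable when the penalty is small and the model behaves like OLS, and all data below are synthetic:

    import numpy as np
    from scipy import stats
    from sklearn.linear_model import Ridge

    rng = np.random.default_rng(3)
    X = rng.normal(size=(200, 2))
    y = 2.0 * X[:, 0] + rng.normal(scale=0.5, size=200)

    model = Ridge(alpha=1e-6).fit(X, y)                   # near-OLS fit
    resid = y - model.predict(X)
    dof = X.shape[0] - X.shape[1] - 1
    sigma2 = np.sum(resid ** 2) / dof
    Xd = np.hstack([np.ones((len(y), 1)), X])             # add intercept, mirroring calculate_p_values
    se = np.sqrt(np.diag(sigma2 * np.linalg.pinv(Xd.T @ Xd)))[1:]
    p = 2 * (1 - stats.t.cdf(np.abs(model.coef_ / se), dof))
    print(p)                                              # first feature ~0, second feature large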
+
+ def regression(df, csv_path, dependent_variable='predictions', regression_type=None, alpha=1.0, remove_row_column_effect=False):
+
+ from .plot import volcano_plot, plot_histogram
+
+ volcano_filename = os.path.splitext(os.path.basename(csv_path))[0] + '_volcano_plot.pdf'
+ volcano_filename = regression_type+'_'+volcano_filename
+ if regression_type == 'quantile':
+ volcano_filename = str(alpha)+'_'+volcano_filename
+ volcano_path=os.path.join(os.path.dirname(csv_path), volcano_filename)
+
+ is_normal = check_normality(df[dependent_variable], dependent_variable)
+
+ if regression_type is None:
+ if is_normal:
+ regression_type = 'ols'
+ else:
+ regression_type = 'glm'
+
+ #df = remove_outliers(df)
+
+ if remove_row_column_effect:
+
+ ## 1. Fit the initial model with row and column to estimate their effects
+ ## 2. Fit the initial model using the specified regression type
+ ## 3. Calculate the residuals
+ ### Residual calculation: Residuals are the differences between the observed and predicted values. This step checks if the initial_model has an attribute resid (residuals). If it does, it directly uses them. Otherwise, it calculates residuals manually by subtracting the predicted values from the observed values (y_with_row_col).
+ ## 4. Use the residuals as the new dependent variable in the final regression model without row and column
+ ### Formula creation: A new regression formula is created, excluding row and column effects, with residuals as the new dependent variable.
+ ### Matrix creation: dmatrices is used again to create new design matrices (X for independent variables and y for the new dependent variable, residuals) based on the new formula and the dataframe df.
+ #### Remove Confounding Effects: Variables like row and column can introduce systematic biases or confounding effects that might obscure the relationships between the dependent variable and the variables of interest (fraction:gene and fraction:grna).
+ #### By first estimating the effects of row and column and then using the residuals (the part of the dependent variable that is not explained by row and column), we can focus the final regression model on the relationships of interest without the interference from row and column.
+
+ #### Reduce Multicollinearity: Including variables like row and column along with other predictors can sometimes lead to multicollinearity, where predictors are highly correlated with each other. This can make it difficult to determine the individual effect of each predictor.
+ #### By regressing out the effects of row and column first, we reduce potential multicollinearity issues in the final model.
+
+ # Fit the initial model with row and column to estimate their effects
+ formula_with_row_col = f'{dependent_variable} ~ row + column'
+ y_with_row_col, X_with_row_col = dmatrices(formula_with_row_col, data=df, return_type='dataframe')
+
+ # Fit the initial model using the specified regression type
+ initial_model = regression_model(X_with_row_col, y_with_row_col, regression_type=regression_type, alpha=alpha)
+
+ # Calculate the residuals manually
+ if hasattr(initial_model, 'resid'):
+ df['residuals'] = initial_model.resid
+ else:
+ df['residuals'] = y_with_row_col.values.ravel() - initial_model.predict(X_with_row_col)
+
+ # Use the residuals as the new dependent variable in the final regression model without row and column
+ formula_without_row_col = 'residuals ~ fraction:gene + fraction:grna'
+ y, X = dmatrices(formula_without_row_col, data=df, return_type='dataframe')
+
+ # Plot histogram of the residuals
+ plot_histogram(df, 'residuals')
+
+ # Scale the independent variables and residuals
+ scaler_X = MinMaxScaler()
+ scaler_y = MinMaxScaler()
+ X = pd.DataFrame(scaler_X.fit_transform(X), columns=X.columns)
+ y = scaler_y.fit_transform(y)
+
+ else:
+ formula = f'{dependent_variable} ~ fraction:gene + fraction:grna + row + column'
+ y, X = dmatrices(formula, data=df, return_type='dataframe')
+
+ plot_histogram(y, dependent_variable)
+
+ # Scale the independent variables and dependent variable
+ scaler_X = MinMaxScaler()
+ scaler_y = MinMaxScaler()
+ X = pd.DataFrame(scaler_X.fit_transform(X), columns=X.columns)
+ y = scaler_y.fit_transform(y)
+
+ groups = df['prc'] if regression_type == 'mixed' else None
+ print(f'performing {regression_type} regression')
+ model = regression_model(X, y, regression_type=regression_type, groups=groups, alpha=alpha, remove_row_column_effect=remove_row_column_effect)
+
+ # Get the model coefficients and p-values
+ if regression_type in ['ols','gls','wls','rlm','glm','mixed','quantile','logit','probit','poisson']:
+ coefs = model.params
+ p_values = model.pvalues
+
+ coef_df = pd.DataFrame({
+ 'feature': coefs.index,
+ 'coefficient': coefs.values,
+ 'p_value': p_values.values
+ })
+ elif regression_type in ['ridge', 'lasso']:
+ coefs = model.coef_
+ coefs = np.array(coefs).flatten()
+ # Calculate p-values
+ p_values = calculate_p_values(X, y, model)
+ p_values = np.array(p_values).flatten()
+
+ # Create a DataFrame for the coefficients and p-values
+ coef_df = pd.DataFrame({
+ 'feature': X.columns,
+ 'coefficient': coefs,
+ 'p_value': p_values})
+ else:
+ coefs = model.coef_
+ intercept = model.intercept_
+ feature_names = X.design_info.column_names
+
+ coef_df = pd.DataFrame({
+ 'feature': feature_names,
+ 'coefficient': coefs
+ })
+ coef_df.loc[0, 'coefficient'] += intercept
+ coef_df['p_value'] = np.nan # Placeholder since sklearn doesn't provide p-values
+
+ coef_df['-log10(p_value)'] = -np.log10(coef_df['p_value'])
+ coef_df_v = coef_df[coef_df['feature'] != 'Intercept']
+
+ # Create the highlight column
+ coef_df['highlight'] = coef_df['feature'].apply(lambda x: '220950' in x)
+ coef_df = coef_df[~coef_df['feature'].str.contains('row|column')]
+ volcano_plot(coef_df, volcano_path)
+
+ return model, coef_df
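The regression builds its design matrices with patsy.dmatrices, where fraction:gene and fraction:grna are pure interaction terms and row/column enter as categorical effects. A minimal illustration of the expanded column names on toy data (all values invented):

    import pandas as pd
    from patsy import dmatrices

    toy = pd.DataFrame({
        'score':    [0.1, 0.4, 0.3, 0.9],
        'fraction': [0.2, 0.8, 0.5, 0.5],
        'gene':     ['geneA', 'geneA', 'geneB', 'geneB'],
        'grna':     ['geneA_1', 'geneA_2', 'geneB_1', 'geneB_2'],
        'row':      ['r1', 'r1', 'r2', 'r2'],
        'column':   ['c1', 'c2', 'c1', 'c2'],
    })
    y, X = dmatrices('score ~ fraction:gene + fraction:grna + row + column', data=toy, return_type='dataframe')
    # Intercept, row/column dummies, plus fraction:gene[...] and fraction:grna[...] interaction columns
    print(X.columns.tolist())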
+
+ def perform_regression(df, settings):
+
+ from spacr.plot import plot_plates
+ from .utils import merge_regression_res_with_metadata
+ from .settings import get_perform_regression_default_settings
+
+ reg_types = ['ols','gls','wls','rlm','glm','mixed','quantile','logit','probit','poisson','lasso','ridge']
+ if settings['regression_type'] not in reg_types:
+ print(f'Possible regression types: {reg_types}')
+ raise ValueError(f"Unsupported regression type {settings['regression_type']}")
+
+ if isinstance(df, str):
+ df = pd.read_csv(df)
+ elif isinstance(df, pd.DataFrame):
+ pass
+ else:
+ raise ValueError("Data must be a DataFrame or a path to a CSV file")
+
+
+ if settings['dependent_variable'] not in df.columns:
+ print(f'Columns in DataFrame:')
+ for col in df.columns:
+ print(col)
+ raise ValueError(f"Dependent variable {settings['dependent_variable']} not found in the DataFrame")
+
+ results_filename = os.path.splitext(os.path.basename(settings['gene_weights_csv']))[0] + '_results.csv'
+ hits_filename = os.path.splitext(os.path.basename(settings['gene_weights_csv']))[0] + '_results_significant.csv'
+
+ results_filename = settings['regression_type']+'_'+results_filename
+ hits_filename = settings['regression_type']+'_'+hits_filename
+ if settings['regression_type'] == 'quantile':
+ results_filename = str(settings['alpha'])+'_'+results_filename
+ hits_filename = str(settings['alpha'])+'_'+hits_filename
+ results_path=os.path.join(os.path.dirname(settings['gene_weights_csv']), results_filename)
+ hits_path=os.path.join(os.path.dirname(settings['gene_weights_csv']), hits_filename)
+
+ settings = get_perform_regression_default_settings(settings)
+
+ settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
+ settings_dir = os.path.dirname(settings['gene_weights_csv'])
+ settings_csv = os.path.join(settings_dir,f"{settings['regression_type']}_regression_settings.csv")
+ settings_df.to_csv(settings_csv, index=False)
+ display(settings_df)
+
+ df = clean_controls(df,settings['pc'],settings['nc'],settings['other'])
+
+ if 'prediction_probability_class_1' in df.columns:
+ if not settings['class_1_threshold'] is None:
+ df['predictions'] = (df['prediction_probability_class_1'] >= settings['class_1_threshold']).astype(int)
+
+ dependent_df, dependent_variable = process_scores(df, settings['dependent_variable'], settings['plate'], settings['min_cell_count'], settings['agg_type'], settings['transform'])
+
+ display(dependent_df)
+
+ independent_df = precess_reads(settings['gene_weights_csv'], settings['fraction_threshold'], settings['plate'])
+ display(independent_df)
+
+ merged_df = pd.merge(independent_df, dependent_df, on='prc')
+
+ merged_df[['plate', 'row', 'column']] = merged_df['prc'].str.split('_', expand=True)
+
+ if settings['transform'] is None:
+ _ = plot_plates(df, variable=dependent_variable, grouping='mean', min_max='allq', cmap='viridis', min_count=settings['min_cell_count'])
+
+ model, coef_df = regression(merged_df, settings['gene_weights_csv'], dependent_variable, settings['regression_type'], settings['alpha'], settings['remove_row_column_effect'])
+
+ coef_df.to_csv(results_path, index=False)
+
+ if settings['regression_type'] == 'lasso':
+ significant = coef_df[coef_df['coefficient'] > 0]
+
+ else:
+ significant = coef_df[coef_df['p_value']<= 0.05]
+ #significant = significant[significant['coefficient'] > 0.1]
+ significant.sort_values(by='coefficient', ascending=False, inplace=True)
+ significant = significant[~significant['feature'].str.contains('row|column')]
+
+ if settings['regression_type'] == 'ols':
+ print(model.summary())
+
+ significant.to_csv(hits_path, index=False)
+
+ me49 = '/home/carruthers/Documents/TGME49_Summary.csv'
+ gt1 = '/home/carruthers/Documents/TGGT1_Summary.csv'
+
+ _ = merge_regression_res_with_metadata(hits_path, me49, name='_me49_metadata')
+ _ = merge_regression_res_with_metadata(hits_path, gt1, name='_gt1_metadata')
+ _ = merge_regression_res_with_metadata(results_path, me49, name='_me49_metadata')
+ _ = merge_regression_res_with_metadata(results_path, gt1, name='_gt1_metadata')
+
+ print('Significant Genes')
+ display(significant)
+ return coef_df
+