gwaslab 3.4.41__py3-none-any.whl → 3.4.43__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -21,8 +21,10 @@ from gwaslab.qc_check_datatype import check_dataframe_shape
  from gwaslab.bd_common_data import get_number_to_chr
  from gwaslab.bd_common_data import get_chr_list
  from gwaslab.bd_common_data import get_chr_to_number
+ from gwaslab.bd_common_data import _maketrans
  from gwaslab.g_vchange_status import vchange_status
  from gwaslab.g_version import _get_version
+ from gwaslab.cache_manager import CacheManager, PALINDROMIC_INDEL, NON_PALINDROMIC
 
  #rsidtochrpos
  #checkref
@@ -30,6 +32,34 @@ from gwaslab.g_version import _get_version
  #inferstrand
  #parallelecheckaf
 
+ ### CONSTANTS AND MAPPINGS ###
+
+ PADDING_VALUE = 100
+
+ # chr(0) should not be used in the mapping dict because it is a reserved value.
+ # Instead of starting from chr(1), we start from chr(2) because this could be useful in the future
+ # to compute the complementary allele with a simple XOR operation (e.g. 2 ^ 1 = 3, 3 ^ 1 = 2, 4 ^ 1 = 5, 5 ^ 1 = 4, ...)
+ MAPPING = {
+     "A": chr(2),
+     "T": chr(3),
+     "C": chr(4),
+     "G": chr(5),
+     "N": chr(6),
+ }
+ assert all(value != chr(0) for value in MAPPING.values()), "No value in the mapping dict may be equal to chr(0); it is a reserved value"
+
+ _COMPLEMENTARY_MAPPING = {
+     "A": "T",
+     "C": "G",
+     "G": "C",
+     "T": "A",
+     "N": "N",
+ }
+ COMPLEMENTARY_MAPPING = {k: MAPPING[v] for k,v in _COMPLEMENTARY_MAPPING.items()}
+
+ TRANSLATE_TABLE = _maketrans(MAPPING)
+ TRANSLATE_TABLE_COMPL = _maketrans(COMPLEMENTARY_MAPPING)
+
  #20220808
  #################################################################################################################
 
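
The XOR property anticipated in the comment above is easy to verify in isolation. A minimal sketch (not package code; the builtin str.maketrans stands in for gwaslab's _maketrans helper), showing that flipping the lowest bit of a base's code yields the code of its complement:

    MAPPING = {"A": chr(2), "T": chr(3), "C": chr(4), "G": chr(5)}  # N (chr(6)) omitted: 6 ^ 1 = 7 is unused
    COMPLEMENT = {"A": "T", "T": "A", "C": "G", "G": "C"}
    for base, code in MAPPING.items():
        assert chr(ord(code) ^ 1) == MAPPING[COMPLEMENT[base]]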
@@ -44,7 +74,7 @@ def rsidtochrpos(sumstats,
      ##start function with col checking##########################################################
      _start_line = "assign CHR and POS using rsIDs"
      _end_line = "assigning CHR and POS using rsIDs"
-     _start_cols = [rsid,chrom,pos]
+     _start_cols = [rsid]
      _start_function = ".rsid_to_chrpos()"
      _must_args ={}
 
@@ -131,7 +161,7 @@ def parallelrsidtochrpos(sumstats, rsid="rsID", chrom="CHR",pos="POS", path=None
      ##start function with col checking##########################################################
      _start_line = "assign CHR and POS using rsIDs"
      _end_line = "assigning CHR and POS using rsIDs"
-     _start_cols = [rsid,chrom,pos]
+     _start_cols = [rsid]
      _start_function = ".rsid_to_chrpos2()"
      _must_args ={}
 
@@ -186,7 +216,7 @@ def parallelrsidtochrpos(sumstats, rsid="rsID", chrom="CHR",pos="POS", path=None
      pool = Pool(n_cores)
      if chrom not in input_columns:
          log.write(" -Initiating CHR ... ",verbose=verbose)
-         sumstats_rs[chrom]=pd.Series(dtype="Int32")
+         sumstats_rs[chrom]=pd.Series(dtype="Int64")
 
      if pos not in input_columns:
          log.write(" -Initiating POS ... ",verbose=verbose)
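
The move from "Int32" to "Int64" keeps the freshly initiated column in pandas' nullable integer family, which can hold missing values without falling back to float. A quick sketch of that behavior:

    import pandas as pd
    s = pd.Series([1, None], dtype="Int64")  # nullable integer extension dtype
    print(s.tolist())                        # [1, <NA>] -- integers preserved alongside missing values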
@@ -207,7 +237,7 @@ def parallelrsidtochrpos(sumstats, rsid="rsID", chrom="CHR",pos="POS", path=None
 
      # update CHR and POS using rsID with multiple threads
      sumstats_rs = pd.concat(pool.map(partial(merge_chrpos,all_groups_max=all_groups_max,path=path,build=build,status=status),df_split),ignore_index=True)
-     sumstats_rs.loc[:,["CHR","POS"]] = sumstats_rs.loc[:,["CHR","POS"]].astype("Int64")
+     sumstats_rs[["CHR","POS"]] = sumstats_rs[["CHR","POS"]].astype("Int64")
      del df_split
      gc.collect()
      log.write(" -Merging group data... ",verbose=verbose)
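
Dropping the .loc on the left-hand side matters here: in many pandas versions, .loc[:, [...]] assignment writes into the existing columns and silently casts back to their current dtype, whereas plain column assignment replaces the columns together with their new dtype. A sketch of the intended effect (behavior can vary across pandas versions):

    import pandas as pd
    df = pd.DataFrame({"CHR": [1.0, None], "POS": [100.0, 200.0]})  # float64 with NaN
    df[["CHR", "POS"]] = df[["CHR", "POS"]].astype("Int64")         # columns replaced, dtype sticks
    print(df.dtypes)                                                # CHR Int64, POS Int64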
@@ -234,8 +264,8 @@ def parallelrsidtochrpos(sumstats, rsid="rsID", chrom="CHR",pos="POS", path=None
      finished(log, verbose, _end_line)
      return sumstats
  ####################################################################################################################
- #20220426 check if non-effect allele is aligned with reference genome
- def check_status(row,record):
+ # old version
+ def _old_check_status(row,record):
      #pos,ea,nea
      # status
      #0 / -----> match
@@ -288,16 +318,14 @@ def check_status(row,record):
              return status_pre+"5"+status_end
      # ea !=ref
      return status_pre+"8"+status_end
-
 
- def checkref(sumstats,ref_path,chrom="CHR",pos="POS",ea="EA",nea="NEA",status="STATUS",chr_dict=get_chr_to_number(),remove=False,verbose=True,log=Log()):
+ def oldcheckref(sumstats,ref_seq,chrom="CHR",pos="POS",ea="EA",nea="NEA",status="STATUS",chr_dict=get_chr_to_number(),remove=False,verbose=True,log=Log()):
      ##start function with col checking##########################################################
      _start_line = "check if NEA is aligned with reference sequence"
      _end_line = "checking if NEA is aligned with reference sequence"
      _start_cols = [chrom,pos,ea,nea,status]
      _start_function = ".check_ref()"
      _must_args ={}
-
      is_enough_info = start_to(sumstats=sumstats,
                                log=log,
                                verbose=verbose,
@@ -308,10 +336,10 @@ def checkref(sumstats,ref_path,chrom="CHR",pos="POS",ea="EA",nea="NEA",status="S
                                **_must_args)
      if is_enough_info == False: return sumstats
      ############################################################################################
-     log.write(" -Reference genome FASTA file: "+ ref_path,verbose=verbose)
+     log.write(" -Reference genome FASTA file: "+ ref_seq,verbose=verbose)
      log.write(" -Checking records: ", end="",verbose=verbose)
      chromlist = get_chr_list(add_number=True)
-     records = SeqIO.parse(ref_path, "fasta")
+     records = SeqIO.parse(ref_seq, "fasta")
      for record in records:
          #record = next(records)
          if record is not None:
@@ -323,7 +351,7 @@ def checkref(sumstats,ref_path,chrom="CHR",pos="POS",ea="EA",nea="NEA",status="S
              if i in chromlist:
                  log.write(record_chr," ", end="",show_time=False,verbose=verbose)
                  to_check_ref = (sumstats[chrom]==i) & (~sumstats[pos].isna()) & (~sumstats[nea].isna()) & (~sumstats[ea].isna())
-                 sumstats.loc[to_check_ref,status] = sumstats.loc[to_check_ref,[pos,ea,nea,status]].apply(lambda x:check_status(x,record),axis=1)
+                 sumstats.loc[to_check_ref,status] = sumstats.loc[to_check_ref,[pos,ea,nea,status]].apply(lambda x:_old_check_status(x,record),axis=1)
 
      log.write("\n",end="",show_time=False,verbose=verbose)
 
@@ -360,6 +388,332 @@ def checkref(sumstats,ref_path,chrom="CHR",pos="POS",ea="EA",nea="NEA",status="S
      finished(log, verbose, _end_line)
      return sumstats
 
+ #20240320 check if non-effect allele is aligned with reference genome
+ def _fast_check_status(x: pd.DataFrame, record: np.array, starting_positions: np.array):
+     # status
+     #0 / -----> match
+     #1 / -----> Flipped Fixed
+     #2 / -----> Reverse_complementary Fixed
+     #3 / -----> flipped
+     #4 / -----> reverse_complementary
+     #5 / ------> reverse_complementary + flipped
+     #6 / -----> both alleles on genome + unable to distinguish
+     #7 / ----> reverse_complementary + both alleles on genome + unable to distinguish
+     #8 / -----> not on ref genome
+     #9 / ------> unchecked
+     if x.empty:
+         return np.array([])
+
+     # x is expected to be a DataFrame with these columns in this order: ['CHR', 'POS', 'EA', 'NEA', 'STATUS']
+     # This way, we don't need to specify the column names
+     _chrom = x.iloc[:, 0]
+     _pos = x.iloc[:, 1]
+     _ea = x.iloc[:, 2]
+     _nea = x.iloc[:, 3]
+     _status = x.iloc[:, 4]
+
+     # position of the digit within the status string (i.e. x['STATUS']) that will be modified
+     status_flip_idx = 5
+
+     pos = _pos.values.astype(np.int64) # convert to int64 because they could be of type 'object'
+
+     # Rebase the chromosome numbers to 0-based indexing
+     # e.g. ['1', '2', '4', '2'] -> [0, 1, 2, 1]
+     # This is needed because record is a single 1D array containing all the records for all the selected chromosomes,
+     # so for instance if record contains the records for chr1, chr2, chr4 ([...chr1...chr2...chr4...]), we need to
+     # rebase the chromosome numbers to 0-based indexing to index the correct record portion when we do starting_positions[chrom].
+     # Note that x only contains rows for the chromosomes for which we have records in record
+     # (i.e. we don't have rows for chr3 if we don't have the record for chr3). This filtering is done in the caller function.
+     _chrom = _chrom.values
+     unique_values, _ = np.unique(_chrom, return_inverse=True) # Get the sorted unique values and their indices
+     chrom = np.searchsorted(unique_values, _chrom) # Replace each value in '_chrom' with its corresponding index in the sorted unique values
+
+     max_len_nea = _nea.str.len().max()
+     max_len_ea = _ea.str.len().max()
+
+
+     # Apply the same magic used for the fasta records (see build_fasta_records() for details) to convert the NEA and EA to
+     # a numpy array of integers in a very fast way.
+     # In this case we start from a pd.Series so we can apply some built-in methods.
+     # Also, when doing nea.view('<u4'), each row will be automatically right-padded with zeros to reach max_len_nea.
+     # For this reason, we then replace the zeros with our padding value
+     # (and that's why the mapping dict can't have chr(0) as a value, otherwise we would have zeros for both padding and a character).
+     # Reshaping is needed because .view('<u4') will create a flattened array.
+     nea = _nea.str.translate(TRANSLATE_TABLE).to_numpy().astype(f'<U{max_len_nea}')
+     nea = nea.view('<u4').reshape(-1, max_len_nea).astype(np.uint8)
+     nea[nea == 0] = PADDING_VALUE # padding value
+
+     # Create a mask holding True at the positions of non-padding values
+     mask_nea = nea != PADDING_VALUE
+
+     # Create the reverse complement of NEA.
+     # In this case, we manually left-pad the translated string with the padding value, since the padding done by view('<u4') would be right-padded,
+     # and that would make the reverse operation hard (because we would have e.g. [2, 2, 4, 100, ..., 100], which would be hard to convert into [4, 2, 2, 100, ..., 100])
+     rev_nea = _nea.str.translate(TRANSLATE_TABLE_COMPL).str.pad(max_len_nea, 'left', chr(PADDING_VALUE)).to_numpy().astype(f'<U{max_len_nea}')
+     rev_nea = rev_nea.view('<u4').reshape(-1, max_len_nea).astype(np.uint8)
+     rev_nea = rev_nea[:, ::-1]
+
+
+     # Do everything again for EA
+     ea = _ea.str.translate(TRANSLATE_TABLE).to_numpy().astype(f'<U{max_len_ea}')
+     ea = ea.view('<u4').reshape(-1, max_len_ea).astype(np.uint8)
+     ea[ea == 0] = PADDING_VALUE # padding value
+
+     mask_ea = ea != PADDING_VALUE
+
+     rev_ea = _ea.str.translate(TRANSLATE_TABLE_COMPL).str.pad(max_len_ea, 'left', chr(PADDING_VALUE)).to_numpy().astype(f'<U{max_len_ea}')
+     rev_ea = rev_ea.view('<u4').reshape(-1, max_len_ea).astype(np.uint8)
+     rev_ea = rev_ea[:, ::-1]
+
+
+     # Convert the statuses (integers represented as strings) to a numpy array of integers.
+     # Again, use the same concept as before to do this in a very fast way.
+     # e.g. ["9999999", "9939999", "9929999"] -> [[9, 9, 9, 9, 9, 9, 9], [9, 9, 3, 9, 9, 9, 9], [9, 9, 2, 9, 9, 9, 9]]
+     assert _status.str.len().nunique() == 1 # all the status strings should have the same length, let's be sure of that
+     status_len = len(_status.iloc[0])
+     mapping_status = {str(v): chr(v) for v in range(10)}
+     table_stats = _maketrans(mapping_status)
+     status = _status.str.translate(table_stats).to_numpy().astype(f'<U{status_len}')
+     status = status.view('<u4').reshape(-1, status_len).astype(np.uint8)
+
+
+     # Expand the positions to a 2D array and subtract 1 to convert to 0-based indexing
+     # e.g. [2, 21, 46] -> [[1], [20], [45]]
+     pos = np.expand_dims(pos, axis=-1) - 1
+
+     # Create a modified indices array specifying the starting position of each chromosome in the concatenated record array
+     modified_indices = starting_positions[chrom]
+     modified_indices = modified_indices[:, np.newaxis] # Add a new axis to modified_indices to align with the dimensions of pos
+
+     # Create the range of indices: [0, ..., max_len_nea-1]
+     indices_range = np.arange(max_len_nea)
+
+     # Add the range of indices to the starting indices
+     # e.g. pos = [[1], [20], [45]], indices_range = [0, 1, 2], indices = [[1, 2, 3], [20, 21, 22], [45, 46, 47]]
+     indices = pos + indices_range
+
+     # Modify indices to select the correct absolute position in the concatenated record array
+     indices = indices + modified_indices
+
+     # Pad the fasta records array: if there were a (pos, chrom) for which pos + starting_positions[chrom] + max_len_nea > len(record), we would get an out-of-bounds error.
+     # This basically happens if there is a pos on the last chromosome for which pos + max_len_nea > len(record for that chrom).
+     # This is very unlikely to happen, but we should handle this case.
+     record = np.pad(record, (0, max_len_nea), constant_values=PADDING_VALUE)
+
+     # Index the record array using the computed indices.
+     # Since we use np.take, indices must all have the same length, and this is why we added the padding to NEA
+     # and we create the indices using max_len_nea (long story short, we can't obtain a scattered/ragged array)
+     output_nea = np.take(record, indices)
+
+     # Check if the NEA is equal to the reference sequence at the given position.
+     # In a non-matrix way, this is equivalent (for one single element) to:
+     # nea == record[pos-1: pos+len(nea)-1]
+     # where for example:
+     # a) nea = "AC", record = "ACTG", pos = 1 -> True
+     # b) nea = "T", record = "ACTG", pos = 3 -> True
+     # c) nea = "AG", record = "ACTG", pos = 1 -> False
+     # Since we want to do everything in a vectorized way, we compare the padded NEA with the output
+     # and then we use the mask to focus only on the non-padded elements.
+     # Pseudo example (X represents the padding value):
+     # nea = ['AC', 'T'], record = 'ACTGAAG', pos = [1, 3]
+     # -> nea = ['AC', 'TX'], indices = [[1, 2], [3, 4]], mask = [[True, True], [True, False]], output_nea = [['A', 'C'], ['T', 'G']]
+     # -> nea == output_nea: [[True, True], [True, False]], mask: [[True, True], [True, False]]
+     # -> (nea == output_nea) + ~mask: [[True, True], [True, True]]
+     # -> np.all((nea == output_nea) + ~mask, 1): [True, True]
+     nea_eq_ref = np.all((nea == output_nea) + ~mask_nea, 1)
+     rev_nea_eq_ref = np.all((rev_nea == output_nea) + ~mask_nea, 1)
+
+     # Do everything again for EA
+     indices_range = np.arange(max_len_ea)
+     indices = pos + indices_range
+     indices = indices + modified_indices
+     output_ea = np.take(record, indices)
+
+     ea_eq_ref = np.all((ea == output_ea) + ~mask_ea, 1)
+     rev_ea_eq_ref = np.all((rev_ea == output_ea) + ~mask_ea, 1)
+
+     masks_max_len = max(mask_nea.shape[1], mask_ea.shape[1])
+
+     len_nea_eq_len_ea = np.all(
+         np.pad(mask_nea, ((0,0),(0, masks_max_len-mask_nea.shape[1])), constant_values=False) ==
+         np.pad(mask_ea, ((0,0),(0, masks_max_len-mask_ea.shape[1])), constant_values=False)
+     , axis=1) # pad the masks with False to reach the same shape
+     len_rev_nea_eq_rev_len_ea = len_nea_eq_len_ea
+
+     # The following conditions replicate the if-else statements of the original check_status function:
+     # https://github.com/Cloufield/gwaslab/blob/f6b4c4e58a26e5d67d6587141cde27acf9ce2a11/src/gwaslab/hm_harmonize_sumstats.py#L238
+
+     # nea == ref && ea == ref && len(nea) != len(ea)
+     status[nea_eq_ref * ea_eq_ref * ~len_nea_eq_len_ea, status_flip_idx] = 6
+
+     # nea == ref && ea != ref
+     status[nea_eq_ref * ~ea_eq_ref, status_flip_idx] = 0
+
+     # nea != ref && ea == ref
+     status[~nea_eq_ref * ea_eq_ref, status_flip_idx] = 3
+
+     # nea != ref && ea != ref && rev_nea == ref && rev_ea == ref && len(rev_nea) != len(rev_ea)
+     status[~nea_eq_ref * ~ea_eq_ref * rev_nea_eq_ref * rev_ea_eq_ref * ~len_rev_nea_eq_rev_len_ea, status_flip_idx] = 8
+
+     # nea != ref && ea != ref && rev_nea == ref && rev_ea != ref
+     status[~nea_eq_ref * ~ea_eq_ref * rev_nea_eq_ref * ~rev_ea_eq_ref, status_flip_idx] = 4
+
+     # nea != ref && ea != ref && rev_nea != ref && rev_ea == ref
+     status[~nea_eq_ref * ~ea_eq_ref * ~rev_nea_eq_ref * rev_ea_eq_ref, status_flip_idx] = 5
+
+     # nea != ref && ea != ref && rev_nea != ref && rev_ea != ref
+     status[~nea_eq_ref * ~ea_eq_ref * ~rev_nea_eq_ref * ~rev_ea_eq_ref, status_flip_idx] = 8
+
+     # Convert the (now modified) 2D status array back to a numpy array of strings in a very fast way.
+     # Since 'status' is a 2D array of integers ranging from 0 to 9, we can build the integer representation
+     # of each row using the efficient operation below (e.g. [1, 2, 3, 4, 5] -> [12345]).
+     # Then we convert this integer to a string using the f'<U{status.shape[1]}' dtype (e.g. 12345 -> '12345').
+     # The "naive" way would be:
+     # status_str = [''.join(map(str, l)) for l in status]
+     # status_arr = np.array(status_str)
+     status_flat = np.sum(status * 10**np.arange(status.shape[1]-1, -1, -1), axis=1)
+     status_arr = status_flat.astype(f'<U{status.shape[1]}')
+
+     return status_arr
+
+
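
The translate-then-view conversion at the heart of _fast_check_status() can be exercised on its own. A self-contained sketch (the builtin str.maketrans stands in for gwaslab's _maketrans; constants as defined above):

    import numpy as np
    import pandas as pd

    MAPPING = {"A": chr(2), "T": chr(3), "C": chr(4), "G": chr(5), "N": chr(6)}
    TRANSLATE_TABLE = str.maketrans(MAPPING)
    PADDING_VALUE = 100

    alleles = pd.Series(["AC", "T", "ACTG"])
    max_len = alleles.str.len().max()                                     # 4
    arr = alleles.str.translate(TRANSLATE_TABLE).to_numpy().astype(f"<U{max_len}")
    mat = arr.view("<u4").reshape(-1, max_len).astype(np.uint8)           # unicode codepoints as integers
    mat[mat == 0] = PADDING_VALUE                                         # '<U4' right-pads short rows with '\x00'
    # mat -> [[2, 4, 100, 100], [3, 100, 100, 100], [2, 4, 3, 5]]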
+ def check_status(sumstats: pd.DataFrame, fasta_records_dict, log=Log(), verbose=True):
+
+     chrom,pos,ea,nea,status = sumstats.columns
+
+     # First, convert the fasta records to a single numpy array of integers
+     record, starting_positions_dict = build_fasta_records(fasta_records_dict, pos_as_dict=True, log=log, verbose=verbose)
+
+     # In _fast_check_status(), several 2D numpy arrays are created, and they are padded to have shape[1] == max_len_nea or max_len_ea.
+     # Since most of the NEA and EA strings are short, we perform the check first on the records having short NEA and EA strings,
+     # and then we perform the check on the records having long NEA and EA strings. In this way we can speed up the process (since the
+     # arrays are smaller) and save memory.
+     max_len = 4 # this is a chosen value; we could compute it using some stats about the length and count of the NEA and EA strings
+     condition = (sumstats[nea].str.len() <= max_len) * (sumstats[ea].str.len() <= max_len)
+
+     log.write(f" -Checking records for ( len(NEA) <= {max_len} and len(EA) <= {max_len} )", verbose=verbose)
+     sumstats_cond = sumstats[condition]
+     starting_pos_cond = np.array([starting_positions_dict[k] for k in sorted(sumstats_cond[chrom].unique())]) # sorted to match the searchsorted-based rebasing in _fast_check_status
+     sumstats.loc[condition, status] = _fast_check_status(sumstats_cond, record=record, starting_positions=starting_pos_cond)
+
+     log.write(f" -Checking records for ( len(NEA) > {max_len} or len(EA) > {max_len} )", verbose=verbose)
+     sumstats_not_cond = sumstats[~condition]
+     starting_not_pos_cond = np.array([starting_positions_dict[k] for k in sorted(sumstats_not_cond[chrom].unique())]) # sorted, as above
+     sumstats.loc[~condition, status] = _fast_check_status(sumstats_not_cond, record=record, starting_positions=starting_not_pos_cond)
+
+     return sumstats[status].values
+
+
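
The digit-packing used at the end of _fast_check_status() round-trips a row of digits through a single integer. A sketch (this relies on the leading status digit never being 0, which holds for gwaslab status strings since they start with a genome-build code):

    import numpy as np
    status = np.array([[9, 9, 9, 9, 9, 0, 9],
                       [9, 9, 9, 9, 9, 3, 9]], dtype=np.uint8)
    weights = 10 ** np.arange(status.shape[1] - 1, -1, -1)   # [10**6, ..., 10**0]
    packed = (status * weights).sum(axis=1)                  # [9999909, 9999939]
    print(packed.astype(f"<U{status.shape[1]}"))             # ['9999909' '9999939']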
+ def checkref(sumstats,ref_seq,chrom="CHR",pos="POS",ea="EA",nea="NEA",status="STATUS",chr_dict=get_chr_to_number(),remove=False,verbose=True,log=Log()):
+     ##start function with col checking##########################################################
+     _start_line = "check if NEA is aligned with reference sequence"
+     _end_line = "checking if NEA is aligned with reference sequence"
+     _start_cols = [chrom,pos,ea,nea,status]
+     _start_function = ".check_ref()"
+     _must_args ={}
+
+     is_enough_info = start_to(sumstats=sumstats,
+                               log=log,
+                               verbose=verbose,
+                               start_line=_start_line,
+                               end_line=_end_line,
+                               start_cols=_start_cols,
+                               start_function=_start_function,
+                               **_must_args)
+     if is_enough_info == False: return sumstats
+     ############################################################################################
+     log.write(" -Reference genome FASTA file: "+ ref_seq,verbose=verbose)
+     log.write(" -Loading fasta records:",end="", verbose=verbose)
+     chromlist = get_chr_list(add_number=True)
+     records = SeqIO.parse(ref_seq, "fasta")
+
+     all_records_dict = {}
+     chroms_in_sumstats = sumstats[chrom].unique() # load records from the FASTA file only for the chromosomes present in the sumstats
+     for record in records:
+         #record = next(records)
+         if record is not None:
+             record_chr = str(record.id).strip("chrCHR").upper()
+             if record_chr in chr_dict.keys():
+                 i = chr_dict[record_chr]
+             else:
+                 i = record_chr
+             if (i in chromlist) and (i in chroms_in_sumstats):
+                 log.write(record_chr," ", end="",show_time=False,verbose=verbose)
+                 all_records_dict.update({i: record})
+     log.write("",show_time=False,verbose=verbose)
+
+     if len(all_records_dict) > 0:
+         log.write(" -Checking records", verbose=verbose)
+         all_records_dict = dict(sorted(all_records_dict.items())) # sort by key in case the fasta records are not already ordered by chromosome
+         to_check_ref = (sumstats[chrom].isin(list(all_records_dict.keys()))) & (~sumstats[pos].isna()) & (~sumstats[nea].isna()) & (~sumstats[ea].isna())
+         sumstats_to_check = sumstats.loc[to_check_ref,[chrom,pos,ea,nea,status]]
+         sumstats.loc[to_check_ref,status] = check_status(sumstats_to_check, all_records_dict, log=log, verbose=verbose)
+         log.write(" -Finished checking records", verbose=verbose)
+
+     sumstats[status] = sumstats[status].astype("string")
+
+     available_to_check =sum( (~sumstats[pos].isna()) & (~sumstats[nea].isna()) & (~sumstats[ea].isna()))
+     status_0=sum(sumstats["STATUS"].str.match(r"\w\w\w\w\w[0]\w", case=False, flags=0, na=False))
+     status_3=sum(sumstats["STATUS"].str.match(r"\w\w\w\w\w[3]\w", case=False, flags=0, na=False))
+     status_4=sum(sumstats["STATUS"].str.match(r"\w\w\w\w\w[4]\w", case=False, flags=0, na=False))
+     status_5=sum(sumstats["STATUS"].str.match(r"\w\w\w\w\w[5]\w", case=False, flags=0, na=False))
+     status_6=sum(sumstats["STATUS"].str.match(r"\w\w\w\w\w[6]\w", case=False, flags=0, na=False))
+     #status_7=sum(sumstats["STATUS"].str.match(r"\w\w\w\w\w[7]\w", case=False, flags=0, na=False))
+     status_8=sum(sumstats["STATUS"].str.match(r"\w\w\w\w\w[8]\w", case=False, flags=0, na=False))
+
+     log.write(" -Variants allele on given reference sequence : ",status_0,verbose=verbose)
+     log.write(" -Variants flipped : ",status_3,verbose=verbose)
+     raw_matching_rate = (status_3+status_0)/available_to_check
+     flip_rate = status_3/available_to_check
+     log.write(" -Raw Matching rate : ","{:.2f}%".format(raw_matching_rate*100),verbose=verbose)
+     if raw_matching_rate <0.8:
+         log.warning("Matching rate is low, please check if the right reference genome is used.")
+     if flip_rate > 0.85 :
+         log.write(" -Flipping variants rate > 0.85, it is likely that the EA is aligned with REF in the original dataset.",verbose=verbose)
+
+     log.write(" -Variants inferred reverse_complement : ",status_4,verbose=verbose)
+     log.write(" -Variants inferred reverse_complement_flipped : ",status_5,verbose=verbose)
+     log.write(" -Both alleles on genome + unable to distinguish : ",status_6,verbose=verbose)
+     #log.write(" -Reverse_complementary + both alleles on genome + unable to distinguish: ",status_7)
+     log.write(" -Variants not on given reference sequence : ",status_8,verbose=verbose)
+
+     if remove is True:
+         sumstats = sumstats.loc[~sumstats["STATUS"].str.match(r"\w\w\w\w\w[8]\w"),:]
+         log.write(" -Variants not on given reference sequence were removed.",verbose=verbose)
+
+     finished(log, verbose, _end_line)
+     return sumstats
+
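
The tallies above all use one idiom: match a fixed digit at position 6 of the 7-character STATUS string. A standalone sketch of the counting:

    import pandas as pd
    status = pd.Series(["9999909", "9999939", "9999989", None])
    status_0 = status.str.match(r"\w\w\w\w\w[0]\w", na=False).sum()  # 1: the 6th digit is 0
    status_8 = status.str.match(r"\w\w\w\w\w[8]\w", na=False).sum()  # 1: not on the reference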
+ def build_fasta_records(fasta_records_dict, pos_as_dict=True, log=Log(), verbose=True):
+     log.write(" -Building numpy fasta records from dict", verbose=verbose)
+
+     # Do some magic to convert each fasta record to a numpy array of integers in a very fast way.
+     # fasta_record.seq._data is a byte-string, so we can apply a maketrans-style translation directly.
+     # Here we map the bytes to the unicode character representing the desired integer as defined in the mapping dict
+     # (i.e. b'A' -> '\x02', b'T' -> '\x03', b'C' -> '\x04', b'G' -> '\x05', b'N' -> '\x06').
+     # Then, using np.array(..., dtype='<U...') we convert the string to a numpy array of unicode characters.
+     # Then, the view('<u4') trick converts the unicode characters to 4-byte integers, so we obtain the actual integer representation of the characters.
+     # Lastly, we cast the array to np.uint8 to convert the 4-byte integers to 1-byte integers to save memory.
+     # Full example:
+     # fasta_record.seq._data = b'ACTGN' -> b'\x02\x04\x03\x05\x06' -> np.array(['\x02\x04\x03\x05\x06'], dtype='<U5') -> np.array([2, 4, 3, 5, 6], dtype=uint32) -> np.array([2, 4, 3, 5, 6], dtype=uint8)
+     all_r = []
+     for r in fasta_records_dict.values():
+         r = r.seq._data.translate(TRANSLATE_TABLE)
+         r = np.array([r], dtype=f'<U{len(r)}').view('<u4').astype(np.uint8)
+         all_r.append(r)
+
+     # We've just created a list of numpy arrays, so we can concatenate them to obtain a single numpy array.
+     # Then we keep track of the starting position of each record in the concatenated array. This will be useful later
+     # to index the record array depending on the position of the variant and the chromosome.
+     records_len = np.array([len(r) for r in all_r])
+     starting_positions = np.cumsum(records_len) - records_len
+     if pos_as_dict:
+         starting_positions = {k: v for k, v in zip(fasta_records_dict.keys(), starting_positions)}
+     record = np.concatenate(all_r)
+     del all_r # free memory
+
+     return record, starting_positions
+
  #######################################################################################################################
 
  #20220721
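
The offset bookkeeping returned by build_fasta_records() is a plain exclusive cumulative sum. A sketch with made-up record lengths:

    import numpy as np
    records_len = np.array([5, 3, 4])                          # e.g. loaded arrays for chr1, chr2, chr4
    starting_positions = np.cumsum(records_len) - records_len  # [0, 5, 8]
    # a variant at 1-based POS p on the i-th loaded chromosome sits at
    # index starting_positions[i] + (p - 1) of the concatenated record array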
@@ -559,6 +913,56 @@ def check_strand_status(chr,start,end,ref,alt,eaf,vcf_reader,alt_freq,status,chr
                  return status_pre+"5"+status_end
      return status_pre+"8"+status_end
 
+ def check_strand_status_cache(data,cache,ref_infer=None,ref_alt_freq=None,chr_dict=get_number_to_chr(),trust_cache=True,log=Log(),verbose=True):
+     if not trust_cache:
+         assert ref_infer is not None, "If trust_cache is False, ref_infer must be provided"
+         log.warning("You are not trusting the cache; this will slow down the process. Please consider building a complete cache.")
+
+     if ref_infer is not None and not trust_cache:
+         vcf_reader = VariantFile(ref_infer)
+
+     if isinstance(data, pd.DataFrame):
+         data = data.values
+
+     in_cache = 0
+     new_statuses = []
+
+     for i in range(data.shape[0]):
+         _chrom, pos, ref, alt, eaf, status = data[i]
+         chrom = _chrom
+         start = pos - 1
+         end = pos
+
+         if chr_dict is not None: chrom=chr_dict[chrom]
+
+         status_pre=status[:6]
+         status_end=""
+
+         new_status = status_pre+"8"+status_end # default value
+
+         cache_key = f"{chrom}:{pos}:{ref}:{alt}"
+         if cache_key in cache:
+             in_cache += 1
+             record = cache[cache_key]
+             if record is None:
+                 new_status = status_pre+"8"+status_end
+             else:
+                 if (record<0.5) and (eaf<0.5):
+                     new_status = status_pre+"1"+status_end
+                 elif (record>0.5) and (eaf>0.5):
+                     new_status = status_pre+"1"+status_end
+                 else:
+                     new_status = status_pre+"5"+status_end
+         else:
+             if not trust_cache:
+                 # If we don't fully trust the cache (it may be incomplete), perform the check by reading from the VCF file
+                 new_status = check_strand_status(_chrom, start, end, ref, alt, eaf, vcf_reader, ref_alt_freq, status, chr_dict)
+
+         new_statuses.append(new_status)
+
+     log.write(f" -Elements in cache: {in_cache}", verbose=verbose)
+     return new_statuses
+
 
  def check_unkonwn_indel(chr,start,end,ref,alt,eaf,vcf_reader,alt_freq,status,chr_dict=get_number_to_chr(),daf_tolerance=0.2):
      ### input : unknown indel, both on genome (xx1[45]x)
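
The cache contract assumed by check_strand_status_cache(): keys are "CHR:POS:REF:ALT" strings and values are the ALT allele frequency from the reference VCF (None when the variant is absent). A sketch with made-up values (ignoring the exact-0.5 edge case, which the function resolves to status 5):

    cache = {"1:12345:A:T": 0.12}  # hypothetical entry
    chrom, pos, ref, alt, eaf = "1", 12345, "A", "T", 0.08
    record = cache.get(f"{chrom}:{pos}:{ref}:{alt}")
    if record is not None:
        same_strand = (record < 0.5) == (eaf < 0.5)  # True -> status digit 1, False -> 5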
@@ -586,6 +990,65 @@ def check_unkonwn_indel(chr,start,end,ref,alt,eaf,vcf_reader,alt_freq,status,chr
 
      return status_pre+"8"+status_end
 
+
+ def check_unkonwn_indel_cache(data,cache,ref_infer=None,ref_alt_freq=None,chr_dict=get_number_to_chr(),daf_tolerance=0.2,trust_cache=True,log=Log(),verbose=True):
+     if not trust_cache:
+         assert ref_infer is not None, "If trust_cache is False, ref_infer must be provided"
+         log.warning("You are not trusting the cache; this will slow down the process. Please consider building a complete cache.")
+
+     if ref_infer is not None:
+         vcf_reader = VariantFile(ref_infer)
+
+     if isinstance(data, pd.DataFrame):
+         data = data.values
+
+     in_cache = 0
+     new_statuses = []
+
+     for i in range(data.shape[0]):
+         _chrom, pos, ref, alt, eaf, status = data[i]
+         chrom = _chrom
+
+         if chr_dict is not None: chrom=chr_dict[chrom]
+         start = pos - 1
+         end = pos
+
+         status_pre=status[:6]
+         status_end=""
+
+         new_status = status_pre+"8"+status_end # default value
+
+         cache_key_ref_alt = f"{chrom}:{pos}:{ref}:{alt}"
+         cache_key_alt_ref = f"{chrom}:{pos}:{alt}:{ref}"
+
+         if cache_key_ref_alt in cache:
+             in_cache += 1
+             record = cache[cache_key_ref_alt]
+             if record is None:
+                 new_status = status_pre+"8"+status_end
+             else:
+                 if abs(record - eaf)<daf_tolerance:
+                     new_status = status_pre+"3"+status_end
+
+         elif cache_key_alt_ref in cache:
+             in_cache += 1
+             record = cache[cache_key_alt_ref]
+             if record is None:
+                 new_status = status_pre+"8"+status_end
+             else:
+                 if abs(record - (1 - eaf))<daf_tolerance:
+                     new_status = status_pre+"6"+status_end
+
+         else:
+             if not trust_cache:
+                 # If we don't fully trust the cache (it may be incomplete), perform the check by reading from the VCF file
+                 new_status = check_unkonwn_indel(_chrom, start, end, ref, alt, eaf, vcf_reader, ref_alt_freq, status, chr_dict, daf_tolerance)
+
+         new_statuses.append(new_status)
+
+     log.write(f" -Elements in cache: {in_cache}", verbose=verbose)
+     return new_statuses
+
 
  def get_reverse_complementary_allele(a):
      dic = str.maketrans({
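
For indistinguishable indels the comparison is on allele frequency differences instead: a cached frequency within daf_tolerance of EAF confirms the orientation (digit 3), while a hit on the swapped ALT:REF key within tolerance of 1 - EAF marks it as flipped (digit 6). A worked check with made-up numbers:

    record, eaf, daf_tolerance = 0.31, 0.33, 0.2
    print(abs(record - eaf) < daf_tolerance)        # True  -> digit 3, orientation kept
    print(abs(record - (1 - eaf)) < daf_tolerance)  # False -> not treated as flipped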
@@ -610,16 +1073,40 @@ def check_strand(sumstats,ref_infer,ref_alt_freq=None,chr="CHR",pos="POS",ref="N
      status_part = sumstats.apply(lambda x:check_strand_status(x.iloc[0],x.iloc[1]-1,x.iloc[1],x.iloc[2],x.iloc[3],x.iloc[4],vcf_reader,ref_alt_freq,x.iloc[5],chr_dict),axis=1)
      return status_part
 
+ def check_strand_cache(sumstats,cache,ref_infer,ref_alt_freq=None,chr_dict=get_number_to_chr(),trust_cache=True,log=Log(),verbose=True):
+     assert cache is not None, "Cache must be provided"
+     status_part = check_strand_status_cache(sumstats,cache,ref_infer,ref_alt_freq,chr_dict,trust_cache,log,verbose)
+     return status_part
+
  def check_indel(sumstats,ref_infer,ref_alt_freq=None,chr="CHR",pos="POS",ref="NEA",alt="EA",eaf="EAF",chr_dict=get_number_to_chr(),status="STATUS",daf_tolerance=0.2):
      vcf_reader = VariantFile(ref_infer)
      status_part = sumstats.apply(lambda x:check_unkonwn_indel(x.iloc[0],x.iloc[1]-1,x.iloc[1],x.iloc[2],x.iloc[3],x.iloc[4],vcf_reader,ref_alt_freq,x.iloc[5],chr_dict,daf_tolerance),axis=1)
      return status_part
 
+ def check_indel_cache(sumstats,cache,ref_infer,ref_alt_freq=None,chr_dict=get_number_to_chr(),daf_tolerance=0.2,trust_cache=True,log=Log(),verbose=True):
+     assert cache is not None, "Cache must be provided"
+     status_part = check_unkonwn_indel_cache(sumstats,cache,ref_infer,ref_alt_freq,chr_dict,daf_tolerance,trust_cache,log,verbose)
+     return status_part
+
  ##################################################################################################################################################
 
  def parallelinferstrand(sumstats,ref_infer,ref_alt_freq=None,maf_threshold=0.40,daf_tolerance=0.20,remove_snp="",mode="pi",n_cores=1,remove_indel="",
                          chr="CHR",pos="POS",ref="NEA",alt="EA",eaf="EAF",status="STATUS",
-                         chr_dict=None,verbose=True,log=Log()):
+                         chr_dict=None,cache_options={},verbose=True,log=Log()):
+     '''
+     Args:
+         cache_options : A dictionary with the following keys:
+             - cache_manager: CacheManager object or None. If either cache_loader or cache_process is not None, or use_cache is True, a CacheManager object will be created automatically.
+             - trust_cache: bool (optional, default: True). Whether to completely trust the cache. Trusting the cache means that any key not found in the cache is treated as missing from the VCF file as well.
+             - cache_loader: Object with a get_cache() method, or None.
+             - cache_process: Object with an apply_fn() method, or None.
+             - use_cache: bool (optional, default: False). If any of cache_manager, cache_loader or cache_process is not None, this will be set to True automatically.
+               If set to True while cache_manager, cache_loader and cache_process are all None, the cache will be loaded (or built) on the spot.
+
+     Passing a cache_loader or cache_process object is useful to supply a custom object which already has the cache loaded, e.g. because the cache was loaded in the background in another thread/process while other operations were performed.
+     The cache_manager is a CacheManager object that exposes the API to interact with the cache.
+     '''
+
      ##start function with col checking##########################################################
      _start_line = "infer strand for palindromic SNPs/align indistinguishable indels"
      _end_line = "inferring strand for palindromic SNPs/align indistinguishable indels"
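
A hypothetical call wiring up the cache_options documented above (the VCF path is invented for illustration):

    mysumstats = parallelinferstrand(sumstats,
                                     ref_infer="1kg.EAS.norm.vcf.gz",  # made-up reference VCF
                                     ref_alt_freq="AF",
                                     mode="pi",
                                     n_cores=4,
                                     cache_options={"use_cache": True, "trust_cache": True})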
@@ -642,6 +1129,16 @@ def parallelinferstrand(sumstats,ref_infer,ref_alt_freq=None,maf_threshold=0.40,
 
      chr_dict = auto_check_vcf_chr_dict(ref_infer, chr_dict, verbose, log)
 
+     # Setup cache variables
+     cache_manager = cache_options.get("cache_manager", None)
+     if cache_manager is not None:
+         assert isinstance(cache_manager, CacheManager), "cache_manager must be a CacheManager object"
+     trust_cache = cache_options.get("trust_cache", True)
+     cache_loader = cache_options.get("cache_loader", None)
+     cache_process = cache_options.get("cache_process", None)
+     use_cache = any(c is not None for c in [cache_manager, cache_loader, cache_process]) or cache_options.get('use_cache', False)
+     _n_cores = n_cores # backup n_cores
+
 
      log.write(" -Field for alternative allele frequency in VCF INFO: {}".format(ref_alt_freq), verbose=verbose)
 
      if "p" in mode:
@@ -669,16 +1166,30 @@ def parallelinferstrand(sumstats,ref_infer,ref_alt_freq=None,maf_threshold=0.40,
      #########################################################################################
      if sum(unknow_palindromic_to_check)>0:
          if sum(unknow_palindromic_to_check)<10000:
-             n_cores=1
-
-         #df_split = np.array_split(sumstats.loc[unknow_palindromic_to_check,[chr,pos,ref,alt,eaf,status]], n_cores)
-         df_split = _df_split(sumstats.loc[unknow_palindromic_to_check,[chr,pos,ref,alt,eaf,status]], n_cores)
-         pool = Pool(n_cores)
-         map_func = partial(check_strand,chr=chr,pos=pos,ref=ref,alt=alt,eaf=eaf,status=status,ref_infer=ref_infer,ref_alt_freq=ref_alt_freq,chr_dict=chr_dict)
-         status_inferred = pd.concat(pool.map(map_func,df_split))
-         sumstats.loc[unknow_palindromic_to_check,status] = status_inferred.values
-         pool.close()
-         pool.join()
+             n_cores=1
+
+         if use_cache and cache_manager is None:
+             cache_manager = CacheManager(base_path=ref_infer, cache_loader=cache_loader, cache_process=cache_process,
+                                          ref_alt_freq=ref_alt_freq, category=PALINDROMIC_INDEL,
+                                          n_cores=_n_cores, log=log, verbose=verbose)
+
+         log.write(" -Starting strand inference for palindromic SNPs...",verbose=verbose)
+         df_to_check = sumstats.loc[unknow_palindromic_to_check,[chr,pos,ref,alt,eaf,status]]
+
+         if use_cache and cache_manager.cache_len > 0:
+             log.write(" -Using cache for strand inference",verbose=verbose)
+             status_inferred = cache_manager.apply_fn(check_strand_cache, sumstats=df_to_check, ref_infer=ref_infer, ref_alt_freq=ref_alt_freq, chr_dict=chr_dict, trust_cache=trust_cache, log=log, verbose=verbose)
+             sumstats.loc[unknow_palindromic_to_check,status] = status_inferred
+         else:
+             #df_split = np.array_split(df_to_check, n_cores)
+             df_split = _df_split(df_to_check, n_cores)
+             pool = Pool(n_cores)
+             map_func = partial(check_strand,chr=chr,pos=pos,ref=ref,alt=alt,eaf=eaf,status=status,ref_infer=ref_infer,ref_alt_freq=ref_alt_freq,chr_dict=chr_dict)
+             status_inferred = pd.concat(pool.map(map_func,df_split))
+             sumstats.loc[unknow_palindromic_to_check,status] = status_inferred.values
+             pool.close()
+             pool.join()
+         log.write(" -Finished strand inference.",verbose=verbose)
      else:
          log.warning("No palindromic variants available for checking.")
      #########################################################################################
@@ -729,15 +1240,30 @@ def parallelinferstrand(sumstats,ref_infer,ref_alt_freq=None,maf_threshold=0.40,
 
      if sum(unknow_indel)>0:
          if sum(unknow_indel)<10000:
-             n_cores=1
-         #df_split = np.array_split(sumstats.loc[unknow_indel, [chr,pos,ref,alt,eaf,status]], n_cores)
-         df_split = _df_split(sumstats.loc[unknow_indel, [chr,pos,ref,alt,eaf,status]], n_cores)
-         pool = Pool(n_cores)
-         map_func = partial(check_indel,chr=chr,pos=pos,ref=ref,alt=alt,eaf=eaf,status=status,ref_infer=ref_infer,ref_alt_freq=ref_alt_freq,chr_dict=chr_dict,daf_tolerance=daf_tolerance)
-         status_inferred = pd.concat(pool.map(map_func,df_split))
-         sumstats.loc[unknow_indel,status] = status_inferred.values
-         pool.close()
-         pool.join()
+             n_cores=1
+
+         if use_cache and cache_manager is None:
+             cache_manager = CacheManager(base_path=ref_infer, cache_loader=cache_loader, cache_process=cache_process,
+                                          ref_alt_freq=ref_alt_freq, category=PALINDROMIC_INDEL,
+                                          n_cores=_n_cores, log=log, verbose=verbose)
+
+         log.write(" -Starting indistinguishable indel inference...",verbose=verbose)
+         df_to_check = sumstats.loc[unknow_indel,[chr,pos,ref,alt,eaf,status]]
+
+         if use_cache and cache_manager.cache_len > 0:
+             log.write(" -Using cache for indel inference",verbose=verbose)
+             status_inferred = cache_manager.apply_fn(check_indel_cache, sumstats=df_to_check, ref_infer=ref_infer, ref_alt_freq=ref_alt_freq, chr_dict=chr_dict, daf_tolerance=daf_tolerance, trust_cache=trust_cache, log=log, verbose=verbose)
+             sumstats.loc[unknow_indel,status] = status_inferred
+         else:
+             #df_split = np.array_split(sumstats.loc[unknow_indel, [chr,pos,ref,alt,eaf,status]], n_cores)
+             df_split = _df_split(sumstats.loc[unknow_indel, [chr,pos,ref,alt,eaf,status]], n_cores)
+             pool = Pool(n_cores)
+             map_func = partial(check_indel,chr=chr,pos=pos,ref=ref,alt=alt,eaf=eaf,status=status,ref_infer=ref_infer,ref_alt_freq=ref_alt_freq,chr_dict=chr_dict,daf_tolerance=daf_tolerance)
+             status_inferred = pd.concat(pool.map(map_func,df_split))
+             sumstats.loc[unknow_indel,status] = status_inferred.values
+             pool.close()
+             pool.join()
+         log.write(" -Finished indistinguishable indel inference.",verbose=verbose)
 
  #########################################################################################