smftools 0.1.7__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (174)
  1. smftools/__init__.py +7 -6
  2. smftools/_version.py +1 -1
  3. smftools/cli/cli_flows.py +94 -0
  4. smftools/cli/hmm_adata.py +338 -0
  5. smftools/cli/load_adata.py +577 -0
  6. smftools/cli/preprocess_adata.py +363 -0
  7. smftools/cli/spatial_adata.py +564 -0
  8. smftools/cli_entry.py +435 -0
  9. smftools/config/__init__.py +1 -0
  10. smftools/config/conversion.yaml +38 -0
  11. smftools/config/deaminase.yaml +61 -0
  12. smftools/config/default.yaml +264 -0
  13. smftools/config/direct.yaml +41 -0
  14. smftools/config/discover_input_files.py +115 -0
  15. smftools/config/experiment_config.py +1288 -0
  16. smftools/hmm/HMM.py +1576 -0
  17. smftools/hmm/__init__.py +20 -0
  18. smftools/{tools → hmm}/apply_hmm_batched.py +8 -7
  19. smftools/hmm/call_hmm_peaks.py +106 -0
  20. smftools/{tools → hmm}/display_hmm.py +3 -3
  21. smftools/{tools → hmm}/nucleosome_hmm_refinement.py +2 -2
  22. smftools/{tools → hmm}/train_hmm.py +1 -1
  23. smftools/informatics/__init__.py +13 -9
  24. smftools/informatics/archived/deaminase_smf.py +132 -0
  25. smftools/informatics/archived/fast5_to_pod5.py +43 -0
  26. smftools/informatics/archived/helpers/archived/__init__.py +71 -0
  27. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +126 -0
  28. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +87 -0
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +213 -0
  30. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +90 -0
  31. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +259 -0
  32. smftools/informatics/{helpers → archived/helpers/archived}/count_aligned_reads.py +2 -2
  33. smftools/informatics/{helpers → archived/helpers/archived}/demux_and_index_BAM.py +8 -10
  34. smftools/informatics/{helpers → archived/helpers/archived}/extract_base_identities.py +30 -4
  35. smftools/informatics/{helpers → archived/helpers/archived}/extract_mods.py +15 -13
  36. smftools/informatics/{helpers → archived/helpers/archived}/extract_read_features_from_bam.py +4 -2
  37. smftools/informatics/{helpers → archived/helpers/archived}/find_conversion_sites.py +5 -4
  38. smftools/informatics/{helpers → archived/helpers/archived}/generate_converted_FASTA.py +2 -0
  39. smftools/informatics/{helpers → archived/helpers/archived}/get_chromosome_lengths.py +9 -8
  40. smftools/informatics/archived/helpers/archived/index_fasta.py +24 -0
  41. smftools/informatics/{helpers → archived/helpers/archived}/make_modbed.py +1 -2
  42. smftools/informatics/{helpers → archived/helpers/archived}/modQC.py +2 -2
  43. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +250 -0
  44. smftools/informatics/{helpers → archived/helpers/archived}/separate_bam_by_bc.py +8 -7
  45. smftools/informatics/{helpers → archived/helpers/archived}/split_and_index_BAM.py +8 -12
  46. smftools/informatics/archived/subsample_fasta_from_bed.py +49 -0
  47. smftools/informatics/bam_functions.py +812 -0
  48. smftools/informatics/basecalling.py +67 -0
  49. smftools/informatics/bed_functions.py +366 -0
  50. smftools/informatics/binarize_converted_base_identities.py +172 -0
  51. smftools/informatics/{helpers/converted_BAM_to_adata_II.py → converted_BAM_to_adata.py} +198 -50
  52. smftools/informatics/fasta_functions.py +255 -0
  53. smftools/informatics/h5ad_functions.py +197 -0
  54. smftools/informatics/{helpers/modkit_extract_to_adata.py → modkit_extract_to_adata.py} +147 -61
  55. smftools/informatics/modkit_functions.py +129 -0
  56. smftools/informatics/ohe.py +160 -0
  57. smftools/informatics/pod5_functions.py +224 -0
  58. smftools/informatics/{helpers/run_multiqc.py → run_multiqc.py} +5 -2
  59. smftools/machine_learning/__init__.py +12 -0
  60. smftools/machine_learning/data/__init__.py +2 -0
  61. smftools/machine_learning/data/anndata_data_module.py +234 -0
  62. smftools/machine_learning/evaluation/__init__.py +2 -0
  63. smftools/machine_learning/evaluation/eval_utils.py +31 -0
  64. smftools/machine_learning/evaluation/evaluators.py +223 -0
  65. smftools/machine_learning/inference/__init__.py +3 -0
  66. smftools/machine_learning/inference/inference_utils.py +27 -0
  67. smftools/machine_learning/inference/lightning_inference.py +68 -0
  68. smftools/machine_learning/inference/sklearn_inference.py +55 -0
  69. smftools/machine_learning/inference/sliding_window_inference.py +114 -0
  70. smftools/machine_learning/models/base.py +295 -0
  71. smftools/machine_learning/models/cnn.py +138 -0
  72. smftools/machine_learning/models/lightning_base.py +345 -0
  73. smftools/machine_learning/models/mlp.py +26 -0
  74. smftools/{tools → machine_learning}/models/positional.py +3 -2
  75. smftools/{tools → machine_learning}/models/rnn.py +2 -1
  76. smftools/machine_learning/models/sklearn_models.py +273 -0
  77. smftools/machine_learning/models/transformer.py +303 -0
  78. smftools/machine_learning/training/__init__.py +2 -0
  79. smftools/machine_learning/training/train_lightning_model.py +135 -0
  80. smftools/machine_learning/training/train_sklearn_model.py +114 -0
  81. smftools/plotting/__init__.py +4 -1
  82. smftools/plotting/autocorrelation_plotting.py +609 -0
  83. smftools/plotting/general_plotting.py +1292 -140
  84. smftools/plotting/hmm_plotting.py +260 -0
  85. smftools/plotting/qc_plotting.py +270 -0
  86. smftools/preprocessing/__init__.py +15 -8
  87. smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
  88. smftools/preprocessing/append_base_context.py +122 -0
  89. smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
  90. smftools/preprocessing/binarize.py +17 -0
  91. smftools/preprocessing/binarize_on_Youden.py +2 -2
  92. smftools/preprocessing/calculate_complexity_II.py +248 -0
  93. smftools/preprocessing/calculate_coverage.py +10 -1
  94. smftools/preprocessing/calculate_position_Youden.py +1 -1
  95. smftools/preprocessing/calculate_read_modification_stats.py +101 -0
  96. smftools/preprocessing/clean_NaN.py +17 -1
  97. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
  98. smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
  99. smftools/preprocessing/flag_duplicate_reads.py +1326 -124
  100. smftools/preprocessing/invert_adata.py +12 -5
  101. smftools/preprocessing/load_sample_sheet.py +19 -4
  102. smftools/readwrite.py +1021 -89
  103. smftools/tools/__init__.py +3 -32
  104. smftools/tools/calculate_umap.py +5 -5
  105. smftools/tools/general_tools.py +3 -3
  106. smftools/tools/position_stats.py +468 -106
  107. smftools/tools/read_stats.py +115 -1
  108. smftools/tools/spatial_autocorrelation.py +562 -0
  109. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/METADATA +14 -9
  110. smftools-0.2.3.dist-info/RECORD +173 -0
  111. smftools-0.2.3.dist-info/entry_points.txt +2 -0
  112. smftools/informatics/fast5_to_pod5.py +0 -21
  113. smftools/informatics/helpers/LoadExperimentConfig.py +0 -75
  114. smftools/informatics/helpers/__init__.py +0 -74
  115. smftools/informatics/helpers/align_and_sort_BAM.py +0 -59
  116. smftools/informatics/helpers/aligned_BAM_to_bed.py +0 -74
  117. smftools/informatics/helpers/bam_qc.py +0 -66
  118. smftools/informatics/helpers/bed_to_bigwig.py +0 -39
  119. smftools/informatics/helpers/binarize_converted_base_identities.py +0 -79
  120. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +0 -55
  121. smftools/informatics/helpers/index_fasta.py +0 -12
  122. smftools/informatics/helpers/make_dirs.py +0 -21
  123. smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +0 -53
  124. smftools/informatics/load_adata.py +0 -182
  125. smftools/informatics/readwrite.py +0 -106
  126. smftools/informatics/subsample_fasta_from_bed.py +0 -47
  127. smftools/preprocessing/append_C_context.py +0 -82
  128. smftools/preprocessing/calculate_converted_read_methylation_stats.py +0 -94
  129. smftools/preprocessing/filter_converted_reads_on_methylation.py +0 -44
  130. smftools/preprocessing/filter_reads_on_length.py +0 -51
  131. smftools/tools/call_hmm_peaks.py +0 -105
  132. smftools/tools/data/__init__.py +0 -2
  133. smftools/tools/data/anndata_data_module.py +0 -90
  134. smftools/tools/inference/__init__.py +0 -1
  135. smftools/tools/inference/lightning_inference.py +0 -41
  136. smftools/tools/models/base.py +0 -14
  137. smftools/tools/models/cnn.py +0 -34
  138. smftools/tools/models/lightning_base.py +0 -41
  139. smftools/tools/models/mlp.py +0 -17
  140. smftools/tools/models/sklearn_models.py +0 -40
  141. smftools/tools/models/transformer.py +0 -133
  142. smftools/tools/training/__init__.py +0 -1
  143. smftools/tools/training/train_lightning_model.py +0 -47
  144. smftools-0.1.7.dist-info/RECORD +0 -136
  145. /smftools/{tools/evaluation → cli}/__init__.py +0 -0
  146. /smftools/{tools → hmm}/calculate_distances.py +0 -0
  147. /smftools/{tools → hmm}/hmm_readwrite.py +0 -0
  148. /smftools/informatics/{basecall_pod5s.py → archived/basecall_pod5s.py} +0 -0
  149. /smftools/informatics/{conversion_smf.py → archived/conversion_smf.py} +0 -0
  150. /smftools/informatics/{direct_smf.py → archived/direct_smf.py} +0 -0
  151. /smftools/informatics/{helpers → archived/helpers/archived}/canoncall.py +0 -0
  152. /smftools/informatics/{helpers → archived/helpers/archived}/converted_BAM_to_adata.py +0 -0
  153. /smftools/informatics/{helpers → archived/helpers/archived}/extract_read_lengths_from_bed.py +0 -0
  154. /smftools/informatics/{helpers → archived/helpers/archived}/extract_readnames_from_BAM.py +0 -0
  155. /smftools/informatics/{helpers → archived/helpers/archived}/get_native_references.py +0 -0
  156. /smftools/informatics/{helpers → archived/helpers}/archived/informatics.py +0 -0
  157. /smftools/informatics/{helpers → archived/helpers}/archived/load_adata.py +0 -0
  158. /smftools/informatics/{helpers → archived/helpers/archived}/modcall.py +0 -0
  159. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_batching.py +0 -0
  160. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_layers_decode.py +0 -0
  161. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_decode.py +0 -0
  162. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_encode.py +0 -0
  163. /smftools/informatics/{subsample_pod5.py → archived/subsample_pod5.py} +0 -0
  164. /smftools/informatics/{helpers/complement_base_list.py → complement_base_list.py} +0 -0
  165. /smftools/{tools → machine_learning}/data/preprocessing.py +0 -0
  166. /smftools/{tools → machine_learning}/models/__init__.py +0 -0
  167. /smftools/{tools → machine_learning}/models/wrappers.py +0 -0
  168. /smftools/{tools → machine_learning}/utils/__init__.py +0 -0
  169. /smftools/{tools → machine_learning}/utils/device.py +0 -0
  170. /smftools/{tools → machine_learning}/utils/grl.py +0 -0
  171. /smftools/tools/{apply_hmm.py → archived/apply_hmm.py} +0 -0
  172. /smftools/tools/{classifiers.py → archived/classifiers.py} +0 -0
  173. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/WHEEL +0 -0
  174. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/licenses/LICENSE +0 -0
smftools/preprocessing/add_read_length_and_mapping_qc.py
@@ -0,0 +1,129 @@
+ import numpy as np
+ import pandas as pd
+ import scipy.sparse as sp
+ from typing import Optional, List, Dict, Union
+
+ def add_read_length_and_mapping_qc(
+     adata,
+     bam_files: Optional[List[str]] = None,
+     read_metrics: Optional[Dict[str, Union[list, tuple]]] = None,
+     uns_flag: str = "read_lenth_and_mapping_qc_performed",
+     extract_read_features_from_bam_callable=None,
+     bypass: bool = False,
+     force_redo: bool = True
+ ):
+     """
+     Populate adata.obs with read/mapping QC columns.
+
+     Parameters
+     ----------
+     adata
+         AnnData to annotate (modified in-place).
+     bam_files
+         Optional list of BAM files to extract metrics from. Ignored if read_metrics is supplied.
+     read_metrics
+         Optional dict mapping obs_name -> [read_length, read_quality, reference_length, mapped_length, mapping_quality].
+         If provided, it is used directly and bam_files is ignored.
+     uns_flag
+         Key in adata.uns used to record that QC was performed (name keeps the original misspelling).
+     extract_read_features_from_bam_callable
+         Optional callable(bam_path) -> dict mapping read_name -> list/tuple of metrics.
+         If not provided and bam_files is given, the function attempts to call `extract_read_features_from_bam`
+         from the global namespace (the existing helper).
+     Returns
+     -------
+     None (mutates adata in-place)
+     """
+
+     # Only run if not already performed
+     already = bool(adata.uns.get(uns_flag, False))
+     if (already and not force_redo) or bypass:
+         # QC already performed; nothing to do
+         return
+
+     # Build read_metrics dict either from the provided arg or by extracting from BAM files
+     if read_metrics is None:
+         read_metrics = {}
+         if bam_files:
+             extractor = extract_read_features_from_bam_callable or globals().get("extract_read_features_from_bam")
+             if extractor is None:
+                 raise ValueError("No `read_metrics` provided and `extract_read_features_from_bam` not found.")
+             for bam in bam_files:
+                 bam_read_metrics = extractor(bam)
+                 if not isinstance(bam_read_metrics, dict):
+                     raise ValueError(f"extract_read_features_from_bam returned non-dict for {bam}")
+                 read_metrics.update(bam_read_metrics)
+         else:
+             # nothing to do
+             read_metrics = {}
+
+     # Convert read_metrics dict -> DataFrame (rows = read id)
+     # Values may be lists/tuples or scalars; prefer lists/tuples with 5 entries.
+     if len(read_metrics) == 0:
+         # fill with NaNs
+         n = adata.n_obs
+         adata.obs['read_length'] = np.full(n, np.nan)
+         adata.obs['mapped_length'] = np.full(n, np.nan)
+         adata.obs['reference_length'] = np.full(n, np.nan)
+         adata.obs['read_quality'] = np.full(n, np.nan)
+         adata.obs['mapping_quality'] = np.full(n, np.nan)
+     else:
+         # Build the DataFrame robustly:
+         # convert values to lists where possible, else replicate the scalar across columns
+         max_cols = 5
+         rows = {}
+         for k, v in read_metrics.items():
+             if isinstance(v, (list, tuple, np.ndarray)):
+                 vals = list(v)
+             else:
+                 # scalar -> replicate into 5 columns to preserve original behavior
+                 vals = [v] * max_cols
+             # Ensure length >= 5
+             if len(vals) < max_cols:
+                 vals = vals + [np.nan] * (max_cols - len(vals))
+             rows[k] = vals[:max_cols]
+
+         df = pd.DataFrame.from_dict(rows, orient='index', columns=[
+             'read_length', 'read_quality', 'reference_length', 'mapped_length', 'mapping_quality'
+         ])
+
+         # Reindex to adata.obs_names so order matches adata.
+         # Obs_names that are not present as keys in df become NaN.
+         df_reindexed = df.reindex(adata.obs_names).astype(float)
+
+         adata.obs['read_length'] = df_reindexed['read_length'].values
+         adata.obs['mapped_length'] = df_reindexed['mapped_length'].values
+         adata.obs['reference_length'] = df_reindexed['reference_length'].values
+         adata.obs['read_quality'] = df_reindexed['read_quality'].values
+         adata.obs['mapping_quality'] = df_reindexed['mapping_quality'].values
+
+     # Compute ratio columns safely (avoid divide-by-zero and preserve NaN)
+     rl = pd.to_numeric(adata.obs['read_length'], errors='coerce').to_numpy(dtype=float)
+     ref_len = pd.to_numeric(adata.obs['reference_length'], errors='coerce').to_numpy(dtype=float)
+     mapped_len = pd.to_numeric(adata.obs['mapped_length'], errors='coerce').to_numpy(dtype=float)
+
+     # safe divisions: use np.where to avoid warnings and replace inf with nan
+     with np.errstate(divide='ignore', invalid='ignore'):
+         rl_to_ref = np.where((ref_len != 0) & np.isfinite(ref_len), rl / ref_len, np.nan)
+         mapped_to_ref = np.where((ref_len != 0) & np.isfinite(ref_len), mapped_len / ref_len, np.nan)
+         mapped_to_read = np.where((rl != 0) & np.isfinite(rl), mapped_len / rl, np.nan)
+
+     adata.obs['read_length_to_reference_length_ratio'] = rl_to_ref
+     adata.obs['mapped_length_to_reference_length_ratio'] = mapped_to_ref
+     adata.obs['mapped_length_to_read_length_ratio'] = mapped_to_read
+
+     # Add read-level raw modification signal: sum over rows of X
+     X = adata.X
+     if sp.issparse(X):
+         # sum returns an (n_obs, 1) matrix; convert to a 1-D array
+         raw_sig = np.asarray(X.sum(axis=1)).ravel()
+     else:
+         raw_sig = np.asarray(X.sum(axis=1)).ravel()
+
+     adata.obs['Raw_modification_signal'] = raw_sig
+
+     # mark as done
+     adata.uns[uns_flag] = True
+
+     return None
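A minimal usage sketch for the new QC helper, with the function above in scope; the AnnData, read names, and metric values below are invented purely for illustration:

    import anndata as ad
    import numpy as np

    adata = ad.AnnData(np.random.rand(2, 4).astype(np.float32))
    adata.obs_names = ["read_1", "read_2"]

    # read_name -> [read_length, read_quality, reference_length, mapped_length, mapping_quality]
    read_metrics = {
        "read_1": [1500, 30.2, 1600, 1480, 60],
        "read_2": [900, 21.7, 1600, 850, 42],
    }

    add_read_length_and_mapping_qc(adata, read_metrics=read_metrics)
    print(adata.obs[["read_length", "mapped_length_to_read_length_ratio"]])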
smftools/preprocessing/append_base_context.py
@@ -0,0 +1,122 @@
+ def append_base_context(adata,
+                         obs_column='Reference_strand',
+                         use_consensus=False,
+                         native=False,
+                         mod_target_bases=['GpC', 'CpG'],
+                         bypass=False,
+                         force_redo=False,
+                         uns_flag='base_context_added'
+                         ):
+     """
+     Adds nucleobase context to each position within the given category. When use_consensus is True, the consensus sequence is used; otherwise the reference FASTA sequence is used.
+
+     Parameters:
+         adata (AnnData): The input adata object.
+         obs_column (str): The observation column to stratify on. Default is 'Reference_strand', which should not be changed for most purposes.
+         use_consensus (bool): Whether to use the consensus sequence from the reads mapped to the reference. If False, the reference FASTA is used instead.
+         native (bool): If False, apply conversion SMF assumptions. If True, apply native SMF assumptions.
+         mod_target_bases (list): Base contexts that may be modified.
+
+     Returns:
+         None
+     """
+     import numpy as np
+     import anndata as ad
+
+     # Only run if not already performed
+     already = bool(adata.uns.get(uns_flag, False))
+     if (already and not force_redo) or bypass:
+         # Base context already added; nothing to do
+         return
+
+     print('Adding base context based on reference FASTA sequence for sample')
+     categories = adata.obs[obs_column].cat.categories
+     site_types = []
+
+     if any(base in mod_target_bases for base in ['GpC', 'CpG', 'C']):
+         site_types += ['GpC_site', 'CpG_site', 'ambiguous_GpC_CpG_site', 'other_C_site', 'C_site']
+
+     if 'A' in mod_target_bases:
+         site_types += ['A_site']
+
+     for cat in categories:
+         # Assess whether the strand is the top- or bottom-strand converted
+         if 'top' in cat:
+             strand = 'top'
+         elif 'bottom' in cat:
+             strand = 'bottom'
+
+         if native:
+             basename = cat.split(f"_{strand}")[0]
+             if use_consensus:
+                 sequence = adata.uns[f'{basename}_consensus_sequence']
+             else:
+                 # This sequence is the unconverted FASTA sequence of the original input FASTA for the locus
+                 sequence = adata.uns[f'{basename}_FASTA_sequence']
+         else:
+             basename = cat.split(f"_{strand}")[0]
+             if use_consensus:
+                 sequence = adata.uns[f'{basename}_consensus_sequence']
+             else:
+                 # This sequence is the unconverted FASTA sequence of the original input FASTA for the locus
+                 sequence = adata.uns[f'{basename}_FASTA_sequence']
+         # Init a dict keyed by reference site type that points to a bool of whether the position is that site type.
+         boolean_dict = {}
+         for site_type in site_types:
+             boolean_dict[f'{cat}_{site_type}'] = np.full(len(sequence), False, dtype=bool)
+
+         if any(base in mod_target_bases for base in ['GpC', 'CpG', 'C']):
+             if strand == 'top':
+                 # Iterate through the sequence and apply the criteria
+                 for i in range(1, len(sequence) - 1):
+                     if sequence[i] == 'C':
+                         boolean_dict[f'{cat}_C_site'][i] = True
+                         if sequence[i - 1] == 'G' and sequence[i + 1] != 'G':
+                             boolean_dict[f'{cat}_GpC_site'][i] = True
+                         elif sequence[i - 1] == 'G' and sequence[i + 1] == 'G':
+                             boolean_dict[f'{cat}_ambiguous_GpC_CpG_site'][i] = True
+                         elif sequence[i - 1] != 'G' and sequence[i + 1] == 'G':
+                             boolean_dict[f'{cat}_CpG_site'][i] = True
+                         elif sequence[i - 1] != 'G' and sequence[i + 1] != 'G':
+                             boolean_dict[f'{cat}_other_C_site'][i] = True
+             elif strand == 'bottom':
+                 # Iterate through the sequence and apply the criteria
+                 for i in range(1, len(sequence) - 1):
+                     if sequence[i] == 'G':
+                         boolean_dict[f'{cat}_C_site'][i] = True
+                         if sequence[i + 1] == 'C' and sequence[i - 1] != 'C':
+                             boolean_dict[f'{cat}_GpC_site'][i] = True
+                         elif sequence[i - 1] == 'C' and sequence[i + 1] == 'C':
+                             boolean_dict[f'{cat}_ambiguous_GpC_CpG_site'][i] = True
+                         elif sequence[i - 1] == 'C' and sequence[i + 1] != 'C':
+                             boolean_dict[f'{cat}_CpG_site'][i] = True
+                         elif sequence[i - 1] != 'C' and sequence[i + 1] != 'C':
+                             boolean_dict[f'{cat}_other_C_site'][i] = True
+             else:
+                 print('Error: top or bottom strand of conversion could not be determined. Ensure this value is in the Reference name.')
+
+         if 'A' in mod_target_bases:
+             if strand == 'top':
+                 # Iterate through the sequence and apply the criteria
+                 for i in range(1, len(sequence) - 1):
+                     if sequence[i] == 'A':
+                         boolean_dict[f'{cat}_A_site'][i] = True
+             elif strand == 'bottom':
+                 # Iterate through the sequence and apply the criteria
+                 for i in range(1, len(sequence) - 1):
+                     if sequence[i] == 'T':
+                         boolean_dict[f'{cat}_A_site'][i] = True
+             else:
+                 print('Error: top or bottom strand of conversion could not be determined. Ensure this value is in the Reference name.')
+
+         for site_type in site_types:
+             adata.var[f'{cat}_{site_type}'] = boolean_dict[f'{cat}_{site_type}'].astype(bool)
+             if native:
+                 adata.obsm[f'{cat}_{site_type}'] = adata[:, adata.var[f'{cat}_{site_type}'] == True].layers['binarized_methylation']
+             else:
+                 adata.obsm[f'{cat}_{site_type}'] = adata[:, adata.var[f'{cat}_{site_type}'] == True].X
+
+     # mark as done
+     adata.uns[uns_flag] = True
+
+     return None
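To make the top-strand classification above concrete, here is a small self-contained sketch on an invented 10 bp sequence; it mirrors the loop logic (the first and last positions are never classified):

    # Classify each top-strand C in a toy sequence, following the criteria above.
    sequence = "AGCGACCGCA"
    for i in range(1, len(sequence) - 1):
        if sequence[i] == 'C':
            prev_base, next_base = sequence[i - 1], sequence[i + 1]
            if prev_base == 'G' and next_base != 'G':
                label = 'GpC_site'
            elif prev_base == 'G' and next_base == 'G':
                label = 'ambiguous_GpC_CpG_site'
            elif next_base == 'G':
                label = 'CpG_site'
            else:
                label = 'other_C_site'
            print(i, label)
    # Prints: 2 ambiguous_GpC_CpG_site, 5 other_C_site, 6 CpG_site, 8 GpC_site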
smftools/preprocessing/append_binary_layer_by_base_context.py
@@ -0,0 +1,143 @@
+ import numpy as np
+ import scipy.sparse as sp
+
+ def append_binary_layer_by_base_context(
+     adata,
+     reference_column: str,
+     smf_modality: str = "conversion",
+     verbose: bool = True,
+     uns_flag: str = "binary_layers_by_base_context_added",
+     bypass: bool = False,
+     force_redo: bool = False
+ ):
+     """
+     Build per-reference C/G-site masked layers:
+       - GpC_site_binary
+       - CpG_site_binary
+       - GpC_CpG_combined_site_binary (numeric sum where present; NaN where neither present)
+       - C_site_binary
+       - other_C_site_binary
+
+     Behavior:
+       - If X is sparse it is converted to dense for these layers (the original adata.X is left untouched).
+       - Missing var columns are warned about but do not crash.
+       - Masked positions are filled with np.nan to make masked vs unmasked explicit.
+       - Requires append_base_context to be run first.
+     """
+
+     # Only run if not already performed
+     already = bool(adata.uns.get(uns_flag, False))
+     if (already and not force_redo) or bypass or ("base_context_added" not in adata.uns):
+         # Layers already added, or base context not yet available; nothing to do
+         return adata
+
+     # check inputs
+     if reference_column not in adata.obs.columns:
+         raise KeyError(f"reference_column '{reference_column}' not found in adata.obs")
+
+     # modality flag (kept for potential downstream use)
+     if smf_modality != "direct":
+         if smf_modality == "conversion":
+             deaminase = False
+         else:
+             deaminase = True
+     else:
+         deaminase = None  # unused but preserved
+
+     # expected per-reference var column names
+     references = adata.obs[reference_column].astype("category").cat.categories
+     reference_to_gpc_column = {ref: f"{ref}_GpC_site" for ref in references}
+     reference_to_cpg_column = {ref: f"{ref}_CpG_site" for ref in references}
+     reference_to_c_column = {ref: f"{ref}_C_site" for ref in references}
+     reference_to_other_c_column = {ref: f"{ref}_other_C_site" for ref in references}
+
+     # verify var columns exist and build boolean masks per ref (len = n_vars)
+     n_obs, n_vars = adata.shape
+     def _col_mask_or_warn(colname):
+         if colname not in adata.var.columns:
+             if verbose:
+                 print(f"Warning: var column '{colname}' not found; treating as all-False mask.")
+             return np.zeros(n_vars, dtype=bool)
+         vals = adata.var[colname].values
+         # coerce truthiness
+         try:
+             return vals.astype(bool)
+         except Exception:
+             return np.array([bool(v) for v in vals], dtype=bool)
+
+     gpc_var_masks = {ref: _col_mask_or_warn(col) for ref, col in reference_to_gpc_column.items()}
+     cpg_var_masks = {ref: _col_mask_or_warn(col) for ref, col in reference_to_cpg_column.items()}
+     c_var_masks = {ref: _col_mask_or_warn(col) for ref, col in reference_to_c_column.items()}
+     other_c_var_masks = {ref: _col_mask_or_warn(col) for ref, col in reference_to_other_c_column.items()}
+
+     # prepare X as dense float32 for layer filling (adata.X itself is left untouched)
+     X = adata.X
+     if sp.issparse(X):
+         if verbose:
+             print("Converting sparse X to dense array for layer construction (temporary).")
+         X = X.toarray()
+     X = np.asarray(X, dtype=np.float32)
+
+     # initialize masked arrays filled with NaN
+     masked_gpc = np.full((n_obs, n_vars), np.nan, dtype=np.float32)
+     masked_cpg = np.full((n_obs, n_vars), np.nan, dtype=np.float32)
+     masked_any_c = np.full((n_obs, n_vars), np.nan, dtype=np.float32)
+     masked_other_c = np.full((n_obs, n_vars), np.nan, dtype=np.float32)
+
+     # fill row-blocks per reference (this avoids creating a full row×var boolean mask)
+     obs_ref_series = adata.obs[reference_column]
+     for ref in references:
+         rows_mask = (obs_ref_series.values == ref)
+         if not rows_mask.any():
+             continue
+         row_idx = np.nonzero(rows_mask)[0]  # integer indices of rows for this ref
+
+         # column masks for this ref
+         gpc_cols = gpc_var_masks.get(ref, np.zeros(n_vars, dtype=bool))
+         cpg_cols = cpg_var_masks.get(ref, np.zeros(n_vars, dtype=bool))
+         c_cols = c_var_masks.get(ref, np.zeros(n_vars, dtype=bool))
+         other_c_cols = other_c_var_masks.get(ref, np.zeros(n_vars, dtype=bool))
+
+         if gpc_cols.any():
+             # assign only the submatrix (rows x selected cols)
+             masked_gpc[np.ix_(row_idx, gpc_cols)] = X[np.ix_(row_idx, gpc_cols)]
+         if cpg_cols.any():
+             masked_cpg[np.ix_(row_idx, cpg_cols)] = X[np.ix_(row_idx, cpg_cols)]
+         if c_cols.any():
+             masked_any_c[np.ix_(row_idx, c_cols)] = X[np.ix_(row_idx, c_cols)]
+         if other_c_cols.any():
+             masked_other_c[np.ix_(row_idx, other_c_cols)] = X[np.ix_(row_idx, other_c_cols)]
+
+     # Build combined layer:
+     #   numeric sum where either site type exists, NaN where neither exists
+     gpc_nan = np.isnan(masked_gpc)
+     cpg_nan = np.isnan(masked_cpg)
+     combined_sum = np.nan_to_num(masked_gpc, nan=0.0) + np.nan_to_num(masked_cpg, nan=0.0)
+     both_nan = gpc_nan & cpg_nan
+     combined_sum[both_nan] = np.nan
+
+     # Alternative: for a boolean OR combined layer, uncomment:
+     # combined_bool = (~gpc_nan & (masked_gpc != 0)) | (~cpg_nan & (masked_cpg != 0))
+     # combined_layer = combined_bool.astype(np.float32)
+
+     adata.layers['GpC_site_binary'] = masked_gpc
+     adata.layers['CpG_site_binary'] = masked_cpg
+     adata.layers['GpC_CpG_combined_site_binary'] = combined_sum
+     adata.layers['C_site_binary'] = masked_any_c
+     adata.layers['other_C_site_binary'] = masked_other_c
+
+     if verbose:
+         def _filled_positions(arr):
+             return int(np.sum(~np.isnan(arr)))
+         print("Layer build summary (non-NaN cell counts):")
+         print(f"  GpC: {_filled_positions(masked_gpc)}")
+         print(f"  CpG: {_filled_positions(masked_cpg)}")
+         print(f"  GpC+CpG combined: {_filled_positions(combined_sum)}")
+         print(f"  C: {_filled_positions(masked_any_c)}")
+         print(f"  other_C: {_filled_positions(masked_other_c)}")
+
+     # mark as done
+     adata.uns[uns_flag] = True
+
+     return adata
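The combined layer above uses sum-with-NaN-preservation semantics; a tiny sketch with invented values shows the rule (NaN only where both inputs are NaN):

    import numpy as np

    gpc = np.array([[1.0, np.nan, np.nan]])
    cpg = np.array([[np.nan, 0.0, np.nan]])

    combined = np.nan_to_num(gpc, nan=0.0) + np.nan_to_num(cpg, nan=0.0)
    combined[np.isnan(gpc) & np.isnan(cpg)] = np.nan
    print(combined)  # [[ 1.  0. nan]]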
smftools/preprocessing/binarize.py
@@ -0,0 +1,17 @@
+ import numpy as np
+
+ def binarize_adata(adata, source="X", target_layer="binary", threshold=0.8):
+     """
+     Binarize a dense matrix and preserve NaN.
+     source: "X" or layer name
+     """
+     X = adata.X if source == "X" else adata.layers[source]
+
+     # Copy to avoid modifying original in-place
+     X_bin = X.copy()
+
+     # Where not NaN: apply threshold
+     mask = ~np.isnan(X_bin)
+     X_bin[mask] = (X_bin[mask] > threshold).astype(np.int8)
+
+     adata.layers[target_layer] = X_bin
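A short sketch of the thresholding behavior, assuming a dense float matrix (this helper does not handle sparse X); the values are illustrative:

    import numpy as np
    import anndata as ad

    adata = ad.AnnData(np.array([[0.9, 0.1], [np.nan, 0.85]], dtype=np.float32))
    binarize_adata(adata, source="X", target_layer="binary", threshold=0.8)
    print(adata.layers["binary"])  # approximately [[1. 0.] [nan 1.]]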
smftools/preprocessing/binarize_on_Youden.py
@@ -1,4 +1,4 @@
- def binarize_on_Youden(adata, obs_column='Reference'):
+ def binarize_on_Youden(adata, obs_column='Reference', output_layer_name='binarized_methylation'):
      """
      Binarize SMF values based on position thresholds determined by calculate_position_Youden.

@@ -42,4 +42,4 @@ def binarize_on_Youden(adata, obs_column='Reference'):
          binarized_methylation[cat_mask, :] = binarized_matrix

      # Store the binarized matrix in a new layer
-     adata.layers['binarized_methylation'] = binarized_methylation
+     adata.layers[output_layer_name] = binarized_methylation
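The new output_layer_name argument lets callers keep several binarizations side by side; a one-line sketch, assuming calculate_position_Youden has already stored the per-position thresholds:

    binarize_on_Youden(adata, obs_column='Reference', output_layer_name='binarized_methylation_youden')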
smftools/preprocessing/calculate_complexity_II.py
@@ -0,0 +1,248 @@
+ from typing import Optional
+ def calculate_complexity_II(
+     adata,
+     output_directory='',
+     sample_col='Sample_names',
+     ref_col: Optional[str] = 'Reference_strand',
+     cluster_col='sequence__merged_cluster_id',
+     plot=True,
+     save_plot=False,
+     n_boot=30,
+     n_depths=12,
+     random_state=0,
+     csv_summary=True,
+     uns_flag='complexity_analysis_complete',
+     force_redo=False,
+     bypass=False
+ ):
+     """
+     Estimate and plot library complexity.
+
+     If ref_col is None, behaves as before: one calculation per sample.
+     If ref_col is provided (the default is 'Reference_strand'), computes complexity for each (sample, ref) pair.
+
+     Results:
+       - adata.uns['Library_complexity_results'] : dict keyed by (sample,) or (sample, ref) -> dict with fields
+         C0, n_reads, n_unique, depths, mean_unique, ci_low, ci_high
+       - Also stores a per-group record in adata.uns[f'Library_complexity_{sanitized_name}'] (backwards compatible)
+       - Optionally saves PNGs and CSVs (curve points + fit summary)
+     """
+     import os
+     import numpy as np
+     import pandas as pd
+     import matplotlib.pyplot as plt
+     from scipy.optimize import curve_fit
+     from datetime import datetime
+
+     # early exits
+     already = bool(adata.uns.get(uns_flag, False))
+     if (already and not force_redo):
+         return None
+     if bypass:
+         return None
+
+     rng = np.random.default_rng(random_state)
+
+     def lw(x, C0):
+         return C0 * (1.0 - np.exp(-x / C0))
+
+     def sanitize(name: str) -> str:
+         return "".join(c if c.isalnum() or c in "-._" else "_" for c in str(name))
+
+     # checks
+     for col in (sample_col, cluster_col):
+         if col not in adata.obs.columns:
+             raise KeyError(f"Required column '{col}' not found in adata.obs")
+     if ref_col is not None and ref_col not in adata.obs.columns:
+         raise KeyError(f"ref_col '{ref_col}' not found in adata.obs")
+
+     if save_plot or csv_summary:
+         os.makedirs(output_directory or ".", exist_ok=True)
+
+     # containers to collect CSV rows across all groups
+     fit_records = []
+     curve_records = []
+
+     # output dict stored centrally
+     results = {}
+
+     # build list of groups: either samples only, or (sample, ref) pairs
+     sseries = adata.obs[sample_col].astype("category")
+     samples = list(sseries.cat.categories)
+     if ref_col is None:
+         group_keys = [(s,) for s in samples]
+     else:
+         rseries = adata.obs[ref_col].astype("category")
+         references = list(rseries.cat.categories)
+         group_keys = []
+         # iterate only pairs that exist in the data to avoid empty processing
+         for s in samples:
+             mask_s = (adata.obs[sample_col] == s)
+             # find references present for this sample
+             ref_present = pd.Categorical(adata.obs.loc[mask_s, ref_col]).categories
+             # Use the intersection of known reference categories and those present for the sample
+             for r in ref_present:
+                 group_keys.append((s, r))
+
+     # iterate groups
+     for g in group_keys:
+         if ref_col is None:
+             sample = g[0]
+             # filter mask
+             mask = (adata.obs[sample_col] == sample).values
+             group_label = f"{sample}"
+         else:
+             sample, ref = g
+             mask = (adata.obs[sample_col] == sample) & (adata.obs[ref_col] == ref)
+             group_label = f"{sample}__{ref}"
+
+         n_reads = int(mask.sum())
+         if n_reads < 2:
+             # store empty placeholders and continue
+             results[g] = {
+                 "C0": np.nan,
+                 "n_reads": int(n_reads),
+                 "n_unique": 0,
+                 "depths": np.array([], dtype=int),
+                 "mean_unique": np.array([], dtype=float),
+                 "ci_low": np.array([], dtype=float),
+                 "ci_high": np.array([], dtype=float),
+             }
+             # also store back-compat key
+             adata.uns[f'Library_complexity_{sanitize(group_label)}'] = results[g]
+             continue
+
+         # cluster ids array for this group
+         clusters = adata.obs.loc[mask, cluster_col].to_numpy()
+         # observed unique molecules at full depth
+         observed_unique = int(pd.unique(clusters).size)
+
+         # choose subsampling depths
+         if n_depths < 2:
+             depths = np.array([n_reads], dtype=int)
+         else:
+             lo = max(10, int(0.05 * n_reads))
+             depths = np.unique(np.linspace(lo, n_reads, n_depths, dtype=int))
+         depths = depths[depths > 0]
+         depths = depths.astype(int)
+         if depths.size == 0:
+             depths = np.array([n_reads], dtype=int)
+
+         # bootstrap sampling: for each depth, sample without replacement (if possible)
+         idx_all = np.arange(n_reads)
+         boot_unique = np.zeros((len(depths), n_boot), dtype=float)
+         for di, d in enumerate(depths):
+             d_use = int(min(d, n_reads))
+             # if d_use == n_reads we can short-circuit and set boot results to the full observed uniques
+             if d_use == n_reads:
+                 # bootstraps are deterministic in this special case
+                 uniq_val = float(observed_unique)
+                 boot_unique[di, :] = uniq_val
+                 continue
+             # otherwise run bootstraps
+             for b in range(n_boot):
+                 take = rng.choice(idx_all, size=d_use, replace=False)
+                 boot_unique[di, b] = np.unique(clusters[take]).size
+
+         mean_unique = boot_unique.mean(axis=1)
+         lo_ci = np.percentile(boot_unique, 2.5, axis=1)
+         hi_ci = np.percentile(boot_unique, 97.5, axis=1)
+
+         # fit Lander-Waterman to the mean curve (safe bounds)
+         C0_init = max(observed_unique, mean_unique[-1] if mean_unique.size else observed_unique)
+         try:
+             popt, _ = curve_fit(
+                 lw,
+                 xdata=depths.astype(float),
+                 ydata=mean_unique.astype(float),
+                 p0=[C0_init],
+                 bounds=(1.0, 1e12),
+                 maxfev=10000,
+             )
+             C0 = float(popt[0])
+         except Exception:
+             C0 = float(observed_unique)
+
+         # store results
+         results[g] = {
+             "C0": C0,
+             "n_reads": int(n_reads),
+             "n_unique": int(observed_unique),
+             "depths": depths,
+             "mean_unique": mean_unique,
+             "ci_low": lo_ci,
+             "ci_high": hi_ci,
+         }
+
+         # save per-group in adata.uns for backward compatibility
+         adata.uns[f'Library_complexity_{sanitize(group_label)}'] = results[g]
+
+         # prepare curve and fit records for CSV
+         fit_records.append({
+             "sample": sample,
+             "reference": ref if ref_col is not None else "",
+             "C0": float(C0),
+             "n_reads": int(n_reads),
+             "n_unique_observed": int(observed_unique),
+         })
+
+         x_fit = np.linspace(0, max(n_reads, int(depths[-1]) if depths.size else n_reads), 200)
+         y_fit = lw(x_fit, C0)
+         for d, mu, lo, hi in zip(depths, mean_unique, lo_ci, hi_ci):
+             curve_records.append({
+                 "sample": sample,
+                 "reference": ref if ref_col is not None else "",
+                 "type": "bootstrap",
+                 "depth": int(d),
+                 "mean_unique": float(mu),
+                 "ci_low": float(lo),
+                 "ci_high": float(hi),
+             })
+         for xf, yf in zip(x_fit, y_fit):
+             curve_records.append({
+                 "sample": sample,
+                 "reference": ref if ref_col is not None else "",
+                 "type": "fit",
+                 "depth": float(xf),
+                 "mean_unique": float(yf),
+                 "ci_low": np.nan,
+                 "ci_high": np.nan,
+             })
+
+         # plotting for this group
+         if plot:
+             plt.figure(figsize=(6.5, 4.5))
+             plt.fill_between(depths, lo_ci, hi_ci, alpha=0.25, label="Bootstrap 95% CI")
+             plt.plot(depths, mean_unique, "o", label="Bootstrap mean")
+             plt.plot([n_reads], [observed_unique], "s", label="Observed (full)")
+             plt.plot(x_fit, y_fit, "-", label=f"LW fit C0≈{C0:,.0f}")
+             plt.xlabel("Total reads (subsampled depth)")
+             plt.ylabel("Unique molecules (clusters)")
+             title = f"Library Complexity — {sample}" + (f" / {ref}" if ref_col is not None else "")
+             plt.title(title)
+             plt.grid(True, alpha=0.3)
+             plt.legend()
+             plt.tight_layout()
+
+             if save_plot:
+                 fname = f"complexity_{sanitize(group_label)}.png"
+                 plt.savefig(os.path.join(output_directory or ".", fname), dpi=160, bbox_inches="tight")
+                 plt.close()
+             else:
+                 plt.show()
+
+     # store central results dict
+     adata.uns["Library_complexity_results"] = results
+
+     # mark complexity analysis as complete
+     adata.uns[uns_flag] = True
+
+     # CSV outputs
+     if csv_summary and (fit_records or curve_records):
+         fit_df = pd.DataFrame(fit_records)
+         curve_df = pd.DataFrame(curve_records)
+         base = output_directory or "."
+         fit_df.to_csv(os.path.join(base, "complexity_fit_summary.csv"), index=False)
+         curve_df.to_csv(os.path.join(base, "complexity_curves.csv"), index=False)
+
+     return results
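The curve fit above is the one-parameter Lander-Waterman saturation model, unique(x) = C0 * (1 - exp(-x / C0)). A quick numeric check with an assumed C0 illustrates its behavior:

    import numpy as np

    def lw(x, C0):
        return C0 * (1.0 - np.exp(-x / C0))

    C0 = 50_000.0  # assumed number of unique molecules in the library
    for depth in (10_000, 50_000, 200_000):
        print(depth, round(lw(depth, C0)))
    # Nearly linear while depth is well below C0, then saturates toward C0 as depth grows.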