smftools 0.1.7__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (174)
  1. smftools/__init__.py +7 -6
  2. smftools/_version.py +1 -1
  3. smftools/cli/cli_flows.py +94 -0
  4. smftools/cli/hmm_adata.py +338 -0
  5. smftools/cli/load_adata.py +577 -0
  6. smftools/cli/preprocess_adata.py +363 -0
  7. smftools/cli/spatial_adata.py +564 -0
  8. smftools/cli_entry.py +435 -0
  9. smftools/config/__init__.py +1 -0
  10. smftools/config/conversion.yaml +38 -0
  11. smftools/config/deaminase.yaml +61 -0
  12. smftools/config/default.yaml +264 -0
  13. smftools/config/direct.yaml +41 -0
  14. smftools/config/discover_input_files.py +115 -0
  15. smftools/config/experiment_config.py +1288 -0
  16. smftools/hmm/HMM.py +1576 -0
  17. smftools/hmm/__init__.py +20 -0
  18. smftools/{tools → hmm}/apply_hmm_batched.py +8 -7
  19. smftools/hmm/call_hmm_peaks.py +106 -0
  20. smftools/{tools → hmm}/display_hmm.py +3 -3
  21. smftools/{tools → hmm}/nucleosome_hmm_refinement.py +2 -2
  22. smftools/{tools → hmm}/train_hmm.py +1 -1
  23. smftools/informatics/__init__.py +13 -9
  24. smftools/informatics/archived/deaminase_smf.py +132 -0
  25. smftools/informatics/archived/fast5_to_pod5.py +43 -0
  26. smftools/informatics/archived/helpers/archived/__init__.py +71 -0
  27. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +126 -0
  28. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +87 -0
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +213 -0
  30. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +90 -0
  31. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +259 -0
  32. smftools/informatics/{helpers → archived/helpers/archived}/count_aligned_reads.py +2 -2
  33. smftools/informatics/{helpers → archived/helpers/archived}/demux_and_index_BAM.py +8 -10
  34. smftools/informatics/{helpers → archived/helpers/archived}/extract_base_identities.py +30 -4
  35. smftools/informatics/{helpers → archived/helpers/archived}/extract_mods.py +15 -13
  36. smftools/informatics/{helpers → archived/helpers/archived}/extract_read_features_from_bam.py +4 -2
  37. smftools/informatics/{helpers → archived/helpers/archived}/find_conversion_sites.py +5 -4
  38. smftools/informatics/{helpers → archived/helpers/archived}/generate_converted_FASTA.py +2 -0
  39. smftools/informatics/{helpers → archived/helpers/archived}/get_chromosome_lengths.py +9 -8
  40. smftools/informatics/archived/helpers/archived/index_fasta.py +24 -0
  41. smftools/informatics/{helpers → archived/helpers/archived}/make_modbed.py +1 -2
  42. smftools/informatics/{helpers → archived/helpers/archived}/modQC.py +2 -2
  43. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +250 -0
  44. smftools/informatics/{helpers → archived/helpers/archived}/separate_bam_by_bc.py +8 -7
  45. smftools/informatics/{helpers → archived/helpers/archived}/split_and_index_BAM.py +8 -12
  46. smftools/informatics/archived/subsample_fasta_from_bed.py +49 -0
  47. smftools/informatics/bam_functions.py +812 -0
  48. smftools/informatics/basecalling.py +67 -0
  49. smftools/informatics/bed_functions.py +366 -0
  50. smftools/informatics/binarize_converted_base_identities.py +172 -0
  51. smftools/informatics/{helpers/converted_BAM_to_adata_II.py → converted_BAM_to_adata.py} +198 -50
  52. smftools/informatics/fasta_functions.py +255 -0
  53. smftools/informatics/h5ad_functions.py +197 -0
  54. smftools/informatics/{helpers/modkit_extract_to_adata.py → modkit_extract_to_adata.py} +147 -61
  55. smftools/informatics/modkit_functions.py +129 -0
  56. smftools/informatics/ohe.py +160 -0
  57. smftools/informatics/pod5_functions.py +224 -0
  58. smftools/informatics/{helpers/run_multiqc.py → run_multiqc.py} +5 -2
  59. smftools/machine_learning/__init__.py +12 -0
  60. smftools/machine_learning/data/__init__.py +2 -0
  61. smftools/machine_learning/data/anndata_data_module.py +234 -0
  62. smftools/machine_learning/evaluation/__init__.py +2 -0
  63. smftools/machine_learning/evaluation/eval_utils.py +31 -0
  64. smftools/machine_learning/evaluation/evaluators.py +223 -0
  65. smftools/machine_learning/inference/__init__.py +3 -0
  66. smftools/machine_learning/inference/inference_utils.py +27 -0
  67. smftools/machine_learning/inference/lightning_inference.py +68 -0
  68. smftools/machine_learning/inference/sklearn_inference.py +55 -0
  69. smftools/machine_learning/inference/sliding_window_inference.py +114 -0
  70. smftools/machine_learning/models/base.py +295 -0
  71. smftools/machine_learning/models/cnn.py +138 -0
  72. smftools/machine_learning/models/lightning_base.py +345 -0
  73. smftools/machine_learning/models/mlp.py +26 -0
  74. smftools/{tools → machine_learning}/models/positional.py +3 -2
  75. smftools/{tools → machine_learning}/models/rnn.py +2 -1
  76. smftools/machine_learning/models/sklearn_models.py +273 -0
  77. smftools/machine_learning/models/transformer.py +303 -0
  78. smftools/machine_learning/training/__init__.py +2 -0
  79. smftools/machine_learning/training/train_lightning_model.py +135 -0
  80. smftools/machine_learning/training/train_sklearn_model.py +114 -0
  81. smftools/plotting/__init__.py +4 -1
  82. smftools/plotting/autocorrelation_plotting.py +609 -0
  83. smftools/plotting/general_plotting.py +1292 -140
  84. smftools/plotting/hmm_plotting.py +260 -0
  85. smftools/plotting/qc_plotting.py +270 -0
  86. smftools/preprocessing/__init__.py +15 -8
  87. smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
  88. smftools/preprocessing/append_base_context.py +122 -0
  89. smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
  90. smftools/preprocessing/binarize.py +17 -0
  91. smftools/preprocessing/binarize_on_Youden.py +2 -2
  92. smftools/preprocessing/calculate_complexity_II.py +248 -0
  93. smftools/preprocessing/calculate_coverage.py +10 -1
  94. smftools/preprocessing/calculate_position_Youden.py +1 -1
  95. smftools/preprocessing/calculate_read_modification_stats.py +101 -0
  96. smftools/preprocessing/clean_NaN.py +17 -1
  97. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
  98. smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
  99. smftools/preprocessing/flag_duplicate_reads.py +1326 -124
  100. smftools/preprocessing/invert_adata.py +12 -5
  101. smftools/preprocessing/load_sample_sheet.py +19 -4
  102. smftools/readwrite.py +1021 -89
  103. smftools/tools/__init__.py +3 -32
  104. smftools/tools/calculate_umap.py +5 -5
  105. smftools/tools/general_tools.py +3 -3
  106. smftools/tools/position_stats.py +468 -106
  107. smftools/tools/read_stats.py +115 -1
  108. smftools/tools/spatial_autocorrelation.py +562 -0
  109. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/METADATA +14 -9
  110. smftools-0.2.3.dist-info/RECORD +173 -0
  111. smftools-0.2.3.dist-info/entry_points.txt +2 -0
  112. smftools/informatics/fast5_to_pod5.py +0 -21
  113. smftools/informatics/helpers/LoadExperimentConfig.py +0 -75
  114. smftools/informatics/helpers/__init__.py +0 -74
  115. smftools/informatics/helpers/align_and_sort_BAM.py +0 -59
  116. smftools/informatics/helpers/aligned_BAM_to_bed.py +0 -74
  117. smftools/informatics/helpers/bam_qc.py +0 -66
  118. smftools/informatics/helpers/bed_to_bigwig.py +0 -39
  119. smftools/informatics/helpers/binarize_converted_base_identities.py +0 -79
  120. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +0 -55
  121. smftools/informatics/helpers/index_fasta.py +0 -12
  122. smftools/informatics/helpers/make_dirs.py +0 -21
  123. smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +0 -53
  124. smftools/informatics/load_adata.py +0 -182
  125. smftools/informatics/readwrite.py +0 -106
  126. smftools/informatics/subsample_fasta_from_bed.py +0 -47
  127. smftools/preprocessing/append_C_context.py +0 -82
  128. smftools/preprocessing/calculate_converted_read_methylation_stats.py +0 -94
  129. smftools/preprocessing/filter_converted_reads_on_methylation.py +0 -44
  130. smftools/preprocessing/filter_reads_on_length.py +0 -51
  131. smftools/tools/call_hmm_peaks.py +0 -105
  132. smftools/tools/data/__init__.py +0 -2
  133. smftools/tools/data/anndata_data_module.py +0 -90
  134. smftools/tools/inference/__init__.py +0 -1
  135. smftools/tools/inference/lightning_inference.py +0 -41
  136. smftools/tools/models/base.py +0 -14
  137. smftools/tools/models/cnn.py +0 -34
  138. smftools/tools/models/lightning_base.py +0 -41
  139. smftools/tools/models/mlp.py +0 -17
  140. smftools/tools/models/sklearn_models.py +0 -40
  141. smftools/tools/models/transformer.py +0 -133
  142. smftools/tools/training/__init__.py +0 -1
  143. smftools/tools/training/train_lightning_model.py +0 -47
  144. smftools-0.1.7.dist-info/RECORD +0 -136
  145. /smftools/{tools/evaluation → cli}/__init__.py +0 -0
  146. /smftools/{tools → hmm}/calculate_distances.py +0 -0
  147. /smftools/{tools → hmm}/hmm_readwrite.py +0 -0
  148. /smftools/informatics/{basecall_pod5s.py → archived/basecall_pod5s.py} +0 -0
  149. /smftools/informatics/{conversion_smf.py → archived/conversion_smf.py} +0 -0
  150. /smftools/informatics/{direct_smf.py → archived/direct_smf.py} +0 -0
  151. /smftools/informatics/{helpers → archived/helpers/archived}/canoncall.py +0 -0
  152. /smftools/informatics/{helpers → archived/helpers/archived}/converted_BAM_to_adata.py +0 -0
  153. /smftools/informatics/{helpers → archived/helpers/archived}/extract_read_lengths_from_bed.py +0 -0
  154. /smftools/informatics/{helpers → archived/helpers/archived}/extract_readnames_from_BAM.py +0 -0
  155. /smftools/informatics/{helpers → archived/helpers/archived}/get_native_references.py +0 -0
  156. /smftools/informatics/{helpers → archived/helpers}/archived/informatics.py +0 -0
  157. /smftools/informatics/{helpers → archived/helpers}/archived/load_adata.py +0 -0
  158. /smftools/informatics/{helpers → archived/helpers/archived}/modcall.py +0 -0
  159. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_batching.py +0 -0
  160. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_layers_decode.py +0 -0
  161. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_decode.py +0 -0
  162. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_encode.py +0 -0
  163. /smftools/informatics/{subsample_pod5.py → archived/subsample_pod5.py} +0 -0
  164. /smftools/informatics/{helpers/complement_base_list.py → complement_base_list.py} +0 -0
  165. /smftools/{tools → machine_learning}/data/preprocessing.py +0 -0
  166. /smftools/{tools → machine_learning}/models/__init__.py +0 -0
  167. /smftools/{tools → machine_learning}/models/wrappers.py +0 -0
  168. /smftools/{tools → machine_learning}/utils/__init__.py +0 -0
  169. /smftools/{tools → machine_learning}/utils/device.py +0 -0
  170. /smftools/{tools → machine_learning}/utils/grl.py +0 -0
  171. /smftools/tools/{apply_hmm.py → archived/apply_hmm.py} +0 -0
  172. /smftools/tools/{classifiers.py → archived/classifiers.py} +0 -0
  173. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/WHEEL +0 -0
  174. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/licenses/LICENSE +0 -0
smftools/informatics/h5ad_functions.py
@@ -0,0 +1,197 @@
+ from pathlib import Path
+ import pandas as pd
+ import numpy as np
+ import scipy.sparse as sp
+ from typing import Optional, List, Dict, Union
+
+ def add_demux_type_annotation(
+     adata,
+     double_demux_source,
+     sep: str = "\t",
+     read_id_col: str = "read_id",
+     barcode_col: str = "barcode",
+ ):
+     """
+     Add adata.obs["demux_type"]:
+       - "double" if read_id appears in the *double demux* TSV
+       - "single" otherwise
+
+     Rows where barcode == "unclassified" in the demux TSV are ignored.
+
+     Parameters
+     ----------
+     adata : AnnData
+         AnnData object whose obs_names are read_ids.
+     double_demux_source : str | Path | list[str]
+         Either:
+           - path to a TSV/TXT of dorado demux results
+           - a list of read_ids
+     """
+
+     # -----------------------------
+     # If it's a file → load TSV
+     # -----------------------------
+     if isinstance(double_demux_source, (str, Path)):
+         file_path = Path(double_demux_source)
+         if not file_path.exists():
+             raise FileNotFoundError(f"File does not exist: {file_path}")
+
+         df = pd.read_csv(file_path, sep=sep, dtype=str)
+
+         # If the file has only one column → treat as a simple read list
+         if df.shape[1] == 1:
+             read_ids = df.iloc[:, 0].tolist()
+         else:
+             # Validate columns
+             if read_id_col not in df.columns:
+                 raise ValueError(f"TSV must contain a '{read_id_col}' column.")
+             if barcode_col not in df.columns:
+                 raise ValueError(f"TSV must contain a '{barcode_col}' column.")
+
+             # Drop unclassified reads
+             df = df[df[barcode_col].str.lower() != "unclassified"]
+
+             # Extract read_ids
+             read_ids = df[read_id_col].tolist()
+
+     # -----------------------------
+     # If user supplied list-of-ids
+     # -----------------------------
+     else:
+         read_ids = list(double_demux_source)
+
+     # Deduplicate for speed
+     double_set = set(read_ids)
+
+     # Boolean lookup in AnnData
+     is_double = adata.obs_names.isin(double_set)
+
+     adata.obs["demux_type"] = np.where(is_double, "double", "single")
+     adata.obs["demux_type"] = adata.obs["demux_type"].astype("category")
+
+     return adata
+
+ def add_read_length_and_mapping_qc(
+     adata,
+     bam_files: Optional[List[str]] = None,
+     read_metrics: Optional[Dict[str, Union[list, tuple]]] = None,
+     uns_flag: str = "read_lenth_and_mapping_qc_performed",
+     extract_read_features_from_bam_callable=None,
+     bypass: bool = False,
+     force_redo: bool = True
+ ):
+     """
+     Populate adata.obs with read/mapping QC columns.
+
+     Parameters
+     ----------
+     adata
+         AnnData to annotate (modified in-place).
+     bam_files
+         Optional list of BAM files to extract metrics from. Ignored if read_metrics is supplied.
+     read_metrics
+         Optional dict mapping obs_name -> [read_length, read_quality, reference_length, mapped_length, mapping_quality].
+         If provided, this is used directly and bam_files is ignored.
+     uns_flag
+         Key in adata.uns used to record that QC was performed (name kept with its original misspelling).
+     extract_read_features_from_bam_callable
+         Optional callable(bam_path) -> dict mapping read_name -> list/tuple of metrics.
+         If not provided and bam_files is given, the function will attempt to call `extract_read_features_from_bam`
+         from the global namespace (the existing helper).
+     Returns
+     -------
+     None (mutates adata in-place)
+     """
+
+     # Only run if not already performed
+     already = bool(adata.uns.get(uns_flag, False))
+     if (already and not force_redo) or bypass:
+         # QC already performed; nothing to do
+         return
+
+     # Build the read_metrics dict either from the provided arg or by extracting from BAM files
+     if read_metrics is None:
+         read_metrics = {}
+         if bam_files:
+             extractor = extract_read_features_from_bam_callable or globals().get("extract_read_features_from_bam")
+             if extractor is None:
+                 raise ValueError("No `read_metrics` provided and `extract_read_features_from_bam` not found.")
+             for bam in bam_files:
+                 bam_read_metrics = extractor(bam)
+                 if not isinstance(bam_read_metrics, dict):
+                     raise ValueError(f"extract_read_features_from_bam returned non-dict for {bam}")
+                 read_metrics.update(bam_read_metrics)
+         else:
+             # nothing to do
+             read_metrics = {}
+
+     # Convert the read_metrics dict -> DataFrame (rows = read id).
+     # Values may be lists/tuples or scalars; prefer lists/tuples with 5 entries.
+     if len(read_metrics) == 0:
+         # fill with NaNs
+         n = adata.n_obs
+         adata.obs['read_length'] = np.full(n, np.nan)
+         adata.obs['mapped_length'] = np.full(n, np.nan)
+         adata.obs['reference_length'] = np.full(n, np.nan)
+         adata.obs['read_quality'] = np.full(n, np.nan)
+         adata.obs['mapping_quality'] = np.full(n, np.nan)
+     else:
+         # Build the DataFrame robustly:
+         # convert values to lists where possible, else to [val, val, val...]
+         max_cols = 5
+         rows = {}
+         for k, v in read_metrics.items():
+             if isinstance(v, (list, tuple, np.ndarray)):
+                 vals = list(v)
+             else:
+                 # scalar -> replicate into 5 columns to preserve original behavior
+                 vals = [v] * max_cols
+             # Ensure length >= 5
+             if len(vals) < max_cols:
+                 vals = vals + [np.nan] * (max_cols - len(vals))
+             rows[k] = vals[:max_cols]
+
+         df = pd.DataFrame.from_dict(rows, orient='index', columns=[
+             'read_length', 'read_quality', 'reference_length', 'mapped_length', 'mapping_quality'
+         ])
+
+         # Reindex to adata.obs_names so the order matches adata.
+         # If obs_names are not present as keys in df, the results will be NaN.
+         df_reindexed = df.reindex(adata.obs_names).astype(float)
+
+         adata.obs['read_length'] = df_reindexed['read_length'].values
+         adata.obs['mapped_length'] = df_reindexed['mapped_length'].values
+         adata.obs['reference_length'] = df_reindexed['reference_length'].values
+         adata.obs['read_quality'] = df_reindexed['read_quality'].values
+         adata.obs['mapping_quality'] = df_reindexed['mapping_quality'].values
+
+     # Compute ratio columns safely (avoid divide-by-zero and preserve NaN)
+     # read_length_to_reference_length_ratio
+     rl = pd.to_numeric(adata.obs['read_length'], errors='coerce').to_numpy(dtype=float)
+     ref_len = pd.to_numeric(adata.obs['reference_length'], errors='coerce').to_numpy(dtype=float)
+     mapped_len = pd.to_numeric(adata.obs['mapped_length'], errors='coerce').to_numpy(dtype=float)
+
+     # safe divisions: use np.where to avoid warnings and replace inf with nan
+     with np.errstate(divide='ignore', invalid='ignore'):
+         rl_to_ref = np.where((ref_len != 0) & np.isfinite(ref_len), rl / ref_len, np.nan)
+         mapped_to_ref = np.where((ref_len != 0) & np.isfinite(ref_len), mapped_len / ref_len, np.nan)
+         mapped_to_read = np.where((rl != 0) & np.isfinite(rl), mapped_len / rl, np.nan)
+
+     adata.obs['read_length_to_reference_length_ratio'] = rl_to_ref
+     adata.obs['mapped_length_to_reference_length_ratio'] = mapped_to_ref
+     adata.obs['mapped_length_to_read_length_ratio'] = mapped_to_read
+
+     # Add the read-level raw modification signal: sum over rows of X
+     X = adata.X
+     if sp.issparse(X):
+         # sum returns an (n_obs, 1) sparse matrix; convert to a 1-D array
+         raw_sig = np.asarray(X.sum(axis=1)).ravel()
+     else:
+         raw_sig = np.asarray(X.sum(axis=1)).ravel()
+
+     adata.obs['Raw_modification_signal'] = raw_sig
+
+     # mark as done
+     adata.uns[uns_flag] = True
+
+     return None
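
A minimal usage sketch of the two helpers added above (not part of the diff itself); the input file name is hypothetical, the metric values are illustrative, and the import path assumes the new smftools/informatics/h5ad_functions.py module from the file list:

    import anndata as ad
    from smftools.informatics.h5ad_functions import add_demux_type_annotation, add_read_length_and_mapping_qc

    adata = ad.read_h5ad("experiment.h5ad")  # hypothetical AnnData whose obs_names are read_ids

    # Mark reads present in a dorado double-demux summary as "double", everything else as "single"
    add_demux_type_annotation(adata, "barcoding_summary.txt")

    # Populate read-length/mapping QC columns in adata.obs; pass precomputed metrics so no BAM helper is needed
    add_read_length_and_mapping_qc(
        adata,
        read_metrics={rid: [1200, 30.0, 1500, 1100, 60] for rid in adata.obs_names},  # illustrative values
    )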
smftools/informatics/{helpers/modkit_extract_to_adata.py → modkit_extract_to_adata.py}
@@ -1,11 +1,12 @@
- ## modkit_extract_to_adata
-
  import concurrent.futures
  import gc
- from .count_aligned_reads import count_aligned_reads
+ from .bam_functions import count_aligned_reads
  import pandas as pd
  from tqdm import tqdm
  import numpy as np
+ from pathlib import Path
+ from typing import Union, Iterable, Optional
+ import shutil

  def filter_bam_records(bam, mapping_threshold):
      """Processes a single BAM file, counts reads, and determines records to analyze."""
@@ -336,29 +337,122 @@ def parallel_extract_stranded_methylation(dict_list, dict_to_skip, max_reference
          dict_list[dict_index][record][sample] = processed_data
      return dict_list

- def modkit_extract_to_adata(fasta, bam_dir, mapping_threshold, experiment_name, mods, batch_size, mod_tsv_dir, delete_batch_hdfs=False, threads=None):
+ def delete_intermediate_h5ads_and_tmpdir(
+     h5_dir: Union[str, Path, Iterable[str], None],
+     tmp_dir: Optional[Union[str, Path]] = None,
+     *,
+     dry_run: bool = False,
+     verbose: bool = True,
+ ):
+     """
+     Delete intermediate .h5ad files and a temporary directory.
+
+     Parameters
+     ----------
+     h5_dir : str | Path | iterable[str] | None
+         If a directory path is given, all files directly inside it will be considered.
+         If an iterable of file paths is given, those files will be considered.
+         Only files ending with '.h5ad' (and not ending with '.gz') are removed.
+     tmp_dir : str | Path | None
+         Path to a directory to remove recursively (e.g. a temp dir created earlier).
+     dry_run : bool
+         If True, print what *would* be removed but do not actually delete.
+     verbose : bool
+         Print progress / warnings.
+     """
+     # Helper: remove a single file path (Path-like or string)
+     def _maybe_unlink(p: Path):
+         if not p.exists():
+             if verbose:
+                 print(f"[skip] not found: {p}")
+             return
+         if not p.is_file():
+             if verbose:
+                 print(f"[skip] not a file: {p}")
+             return
+         if dry_run:
+             print(f"[dry-run] would remove file: {p}")
+             return
+         try:
+             p.unlink()
+             if verbose:
+                 print(f"Removed file: {p}")
+         except Exception as e:
+             print(f"[error] failed to remove file {p}: {e}")
+
+     # Handle h5_dir input (directory OR iterable of file paths)
+     if h5_dir is not None:
+         # If it's a path to a directory, iterate its children
+         if isinstance(h5_dir, (str, Path)) and Path(h5_dir).is_dir():
+             dpath = Path(h5_dir)
+             for p in dpath.iterdir():
+                 # only target top-level files (not recursing); require '.h5ad' suffix and exclude gz
+                 name = p.name.lower()
+                 if "h5ad" in name:
+                     _maybe_unlink(p)
+                 else:
+                     if verbose:
+                         # optional: comment this out if too noisy
+                         print(f"[skip] not matching pattern: {p.name}")
+         else:
+             # treat as iterable of file paths
+             for f in h5_dir:
+                 p = Path(f)
+                 name = p.name.lower()
+                 if name.endswith(".h5ad") and not name.endswith(".gz"):
+                     _maybe_unlink(p)
+                 else:
+                     if verbose:
+                         print(f"[skip] not matching pattern or not a file: {p}")
+
+     # Remove tmp_dir recursively (if provided)
+     if tmp_dir is not None:
+         td = Path(tmp_dir)
+         if not td.exists():
+             if verbose:
+                 print(f"[skip] tmp_dir not found: {td}")
+         else:
+             if not td.is_dir():
+                 if verbose:
+                     print(f"[skip] tmp_dir is not a directory: {td}")
+             else:
+                 if dry_run:
+                     print(f"[dry-run] would remove directory tree: {td}")
+                 else:
+                     try:
+                         shutil.rmtree(td)
+                         if verbose:
+                             print(f"Removed directory tree: {td}")
+                     except Exception as e:
+                         print(f"[error] failed to remove tmp dir {td}: {e}")
+
+ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapping_threshold, experiment_name, mods, batch_size, mod_tsv_dir, delete_batch_hdfs=False, threads=None, double_barcoded_path=None):
      """
      Takes modkit extract outputs and organizes them into an adata object

      Parameters:
-         fasta (str): File path to the reference genome to align to.
-         bam_dir (str): File path to the directory containing the aligned_sorted split modified BAM files
+         fasta (Path): File path to the reference genome to align to.
+         bam_dir (Path): File path to the directory containing the aligned_sorted split modified BAM files
+         out_dir (Path): File path to the output directory
+         input_already_demuxed (bool): Whether input reads were originally demuxed
          mapping_threshold (float): A value between 0 and 1 to threshold the minimal fraction of aligned reads which map to the reference region. References with values above the threshold are included in the output adata.
          experiment_name (str): A string to provide an experiment name to the output adata file.
          mods (list): A list of strings of the modification types to use in the analysis.
          batch_size (int): An integer number of TSV files to analyze in memory at once while loading the final adata object.
-         mod_tsv_dir (str): String representing the path to the mod TSV directory
+         mod_tsv_dir (Path): Path to the mod TSV directory
          delete_batch_hdfs (bool): Whether to delete the batch hdfs after writing out the final concatenated hdf. Default is False
+         double_barcoded_path (Path): Path to the dorado demux summary file of double-ended barcodes

      Returns:
-         final_adata_path (str): Path to the final adata
+         final_adata_path (Path): Path to the final adata
      """
      ###################################################
      # Package imports
      from .. import readwrite
-     from .get_native_references import get_native_references
-     from .extract_base_identities import extract_base_identities
-     from .ohe_batching import ohe_batching
+     from ..readwrite import safe_write_h5ad, make_dirs
+     from .fasta_functions import get_native_references
+     from .bam_functions import extract_base_identities
+     from .ohe import ohe_batching
      import pandas as pd
      import anndata as ad
      import os
@@ -368,41 +462,34 @@ def modkit_extract_to_adata(fasta, bam_dir, mapping_threshold, experiment_name,
      from Bio.Seq import Seq
      from tqdm import tqdm
      import h5py
-     from .make_dirs import make_dirs
      ###################################################

      ################## Get input tsv and bam file names into a sorted list ################
-     # List all files in the directory
-     tsv_files = os.listdir(mod_tsv_dir)
-     bam_files = os.listdir(bam_dir)
-     # get current working directory
-     parent_dir = os.path.dirname(mod_tsv_dir)
-
      # Make output dirs
-     h5_dir = os.path.join(parent_dir, 'h5ads')
-     tmp_dir = os.path.join(parent_dir, 'tmp')
+     h5_dir = out_dir / 'h5ads'
+     tmp_dir = out_dir / 'tmp'
      make_dirs([h5_dir, tmp_dir])
-     existing_h5s = os.listdir(h5_dir)
-     existing_h5s = [h5 for h5 in existing_h5s if '.h5ad.gz' in h5]
-     final_hdf = f'{experiment_name}_final_experiment_hdf5.h5ad'
-     final_adata_path = os.path.join(h5_dir, final_hdf)
-
-     if os.path.exists(f"{final_adata_path}.gz"):
-         print(f'{final_adata_path}.gz already exists. Using existing adata')
-         return f"{final_adata_path}.gz"
+
+     existing_h5s = h5_dir.iterdir()
+     existing_h5s = [h5 for h5 in existing_h5s if '.h5ad.gz' in str(h5)]
+     final_hdf = f'{experiment_name}.h5ad.gz'
+     final_adata_path = h5_dir / final_hdf
+     final_adata = None

-     elif os.path.exists(f"{final_adata_path}"):
+     if final_adata_path.exists():
          print(f'{final_adata_path} already exists. Using existing adata')
-         return final_adata_path
+         return final_adata, final_adata_path

-     # Filter file names that contain the search string in their filename and keep them in a list
-     tsvs = [tsv for tsv in tsv_files if 'extract.tsv' in tsv and 'unclassified' not in tsv]
-     bams = [bam for bam in bam_files if '.bam' in bam and '.bai' not in bam and 'unclassified' not in bam]
-     # Sort file list by names and print the list of file names
-     tsvs.sort()
-     tsv_path_list = [os.path.join(mod_tsv_dir, tsv) for tsv in tsvs]
-     bams.sort()
-     bam_path_list = [os.path.join(bam_dir, bam) for bam in bams]
+     # List all files in the directory
+     tsvs = sorted(
+         p for p in mod_tsv_dir.iterdir()
+         if p.is_file() and 'unclassified' not in p.name and 'extract.tsv' in p.name)
+     bams = sorted(
+         p for p in bam_dir.iterdir()
+         if p.is_file() and p.suffix == '.bam' and 'unclassified' not in p.name and '.bai' not in p.name)
+
+     tsv_path_list = [mod_tsv_dir / tsv for tsv in tsvs]
+     bam_path_list = [bam_dir / bam for bam in bams]
      print(f'{len(tsvs)} sample tsv files found: {tsvs}')
      print(f'{len(bams)} sample bams found: {bams}')
      ##########################################################################################
@@ -416,7 +503,7 @@ def modkit_extract_to_adata(fasta, bam_dir, mapping_threshold, experiment_name,
      ########### Determine the maximum record length to analyze in the dataset ################
      # Get all references within the FASTA and indicate the length and identity of the record sequence
      max_reference_length = 0
-     reference_dict = get_native_references(fasta) # returns a dict keyed by record name. Points to a tuple of (reference length, reference sequence)
+     reference_dict = get_native_references(str(fasta)) # returns a dict keyed by record name. Points to a tuple of (reference length, reference sequence)
      # Get the max record length in the dataset.
      for record in records_to_analyze:
          if reference_dict[record][0] > max_reference_length:
@@ -430,11 +517,11 @@ def modkit_extract_to_adata(fasta, bam_dir, mapping_threshold, experiment_name,
      # One hot encode read sequences and write them out into the tmp_dir as h5ad files.
      # Save the file paths in the bam_record_ohe_files dict.
      bam_record_ohe_files = {}
-     bam_record_save = os.path.join(tmp_dir, 'tmp_file_dict.h5ad')
+     bam_record_save = tmp_dir / 'tmp_file_dict.h5ad'
      fwd_mapped_reads = set()
      rev_mapped_reads = set()
      # If this step has already been performed, read in the tmp_file_dict
-     if os.path.exists(bam_record_save):
+     if bam_record_save.exists():
          bam_record_ohe_files = ad.read_h5ad(bam_record_save).uns
          print('Found existing OHE reads, using these')
      else:
@@ -444,8 +531,9 @@ def modkit_extract_to_adata(fasta, bam_dir, mapping_threshold, experiment_name,
          for record in records_to_analyze:
              current_reference_length = reference_dict[record][0]
              positions = range(current_reference_length)
+             ref_seq = reference_dict[record][1]
              # Extract the base identities of reads aligned to the record
-             fwd_base_identities, rev_base_identities = extract_base_identities(bam, record, positions, max_reference_length)
+             fwd_base_identities, rev_base_identities, mismatch_counts_per_read, mismatch_trend_per_read = extract_base_identities(bam, record, positions, max_reference_length, ref_seq)
              # Store read names of fwd and rev mapped reads
              fwd_mapped_reads.update(fwd_base_identities.keys())
              rev_mapped_reads.update(rev_base_identities.keys())
@@ -487,7 +575,7 @@ def modkit_extract_to_adata(fasta, bam_dir, mapping_threshold, experiment_name,
          bam_path_list = bam_path_list[batch_size:]
          print('{0}: tsvs in batch {1} '.format(readwrite.time_string(), tsv_batch))

-         batch_already_processed = sum([1 for h5 in existing_h5s if f'_{batch}_' in h5])
+         batch_already_processed = sum([1 for h5 in existing_h5s if f'_{batch}_' in h5.name])
          ###################################################
          if batch_already_processed:
              print(f'Batch {batch} has already been processed into h5ads. Skipping batch and using existing files')
@@ -675,7 +763,6 @@ def modkit_extract_to_adata(fasta, bam_dir, mapping_threshold, experiment_name,


          # Save the sample files in the batch as gzipped hdf5 files
-         os.chdir(h5_dir)
          print('{0}: Converting batch {1} dictionaries to anndata objects'.format(readwrite.time_string(), batch))
          for dict_index, dict_type in enumerate(dict_list):
              if dict_index not in dict_to_skip:
@@ -708,6 +795,7 @@ def modkit_extract_to_adata(fasta, bam_dir, mapping_threshold, experiment_name,
                  temp_adata.var_names = temp_adata.var_names.astype(str)
                  print('{0}: Adding {1} anndata for sample {2}'.format(readwrite.time_string(), sample_types[dict_index], final_sample_index))
                  temp_adata.obs['Sample'] = [str(final_sample_index)] * len(temp_adata)
+                 temp_adata.obs['Barcode'] = [str(final_sample_index)] * len(temp_adata)
                  temp_adata.obs['Reference'] = [f'{record}'] * len(temp_adata)
                  temp_adata.obs['Strand'] = [strand] * len(temp_adata)
                  temp_adata.obs['Dataset'] = [dataset] * len(temp_adata)
@@ -804,7 +892,7 @@ def modkit_extract_to_adata(fasta, bam_dir, mapping_threshold, experiment_name,

              try:
                  print('{0}: Writing {1} anndata out as a hdf5 file'.format(readwrite.time_string(), sample_types[dict_index]))
-                 adata.write_h5ad('{0}_{1}_{2}_SMF_binarized_sample_hdf5.h5ad.gz'.format(readwrite.date_string(), batch, sample_types[dict_index]), compression='gzip')
+                 adata.write_h5ad(h5_dir / '{0}_{1}_{2}_SMF_binarized_sample_hdf5.h5ad.gz'.format(readwrite.date_string(), batch, sample_types[dict_index]), compression='gzip')
              except:
                  print(f"Skipping writing anndata for sample")

@@ -813,11 +901,10 @@ def modkit_extract_to_adata(fasta, bam_dir, mapping_threshold, experiment_name,
          gc.collect()

      # Iterate over all of the batched hdf5 files and concatenate them.
-     os.chdir(h5_dir)
-     files = os.listdir(h5_dir)
+     files = h5_dir.iterdir()
      # Filter file names that contain the search string in their filename and keep them in a list
-     hdfs = [hdf for hdf in files if 'hdf5.h5ad' in hdf and hdf != final_hdf]
-     combined_hdfs = [hdf for hdf in hdfs if "combined" in hdf]
+     hdfs = [hdf for hdf in files if 'hdf5.h5ad' in hdf.name and hdf != final_hdf]
+     combined_hdfs = [hdf for hdf in hdfs if "combined" in hdf.name]
      if len(combined_hdfs) > 0:
          hdfs = combined_hdfs
      else:
@@ -825,7 +912,7 @@ def modkit_extract_to_adata(fasta, bam_dir, mapping_threshold, experiment_name,
      # Sort file list by names and print the list of file names
      hdfs.sort()
      print('{0} sample files found: {1}'.format(len(hdfs), hdfs))
-     hdf_paths = [os.path.join(h5_dir, hd5) for hd5 in hdfs]
+     hdf_paths = [h5_dir / hd5 for hd5 in hdfs]
      final_adata = None
      for hdf_index, hdf in enumerate(hdf_paths):
          print('{0}: Reading in {1} hdf5 file'.format(readwrite.time_string(), hdfs[hdf_index]))
@@ -844,6 +931,7 @@ def modkit_extract_to_adata(fasta, bam_dir, mapping_threshold, experiment_name,

      ohe_bases = ['A', 'C', 'G', 'T'] # ignore N bases for consensus
      ohe_layers = [f"{ohe_base}_binary_encoding" for ohe_base in ohe_bases]
+     final_adata.uns['References'] = {}
      for record in records_to_analyze:
          # Add FASTA sequence to the object
          sequence = record_seq_dict[record][0]
@@ -851,6 +939,7 @@ def modkit_extract_to_adata(fasta, bam_dir, mapping_threshold, experiment_name,
          final_adata.var[f'{record}_top_strand_FASTA_base'] = list(sequence)
          final_adata.var[f'{record}_bottom_strand_FASTA_base'] = list(complement)
          final_adata.uns[f'{record}_FASTA_sequence'] = sequence
+         final_adata.uns['References'][f'{record}_FASTA_sequence'] = sequence
          # Add consensus sequence of samples mapped to the record to the object
          record_subset = final_adata[final_adata.obs['Reference'] == record]
          for strand in record_subset.obs['Strand'].cat.categories:
@@ -866,19 +955,16 @@ def modkit_extract_to_adata(fasta, bam_dir, mapping_threshold, experiment_name,
              consensus_sequence_list = [layer_map[i] for i in nucleotide_indexes]
              final_adata.var[f'{record}_{strand}_{mapping_dir}_consensus_sequence_from_all_samples'] = consensus_sequence_list

-     #final_adata.write_h5ad(final_adata_path)
+     if input_already_demuxed:
+         final_adata.obs["demux_type"] = ["already"] * final_adata.shape[0]
+         final_adata.obs["demux_type"] = final_adata.obs["demux_type"].astype("category")
+     else:
+         from .h5ad_functions import add_demux_type_annotation
+         double_barcoded_reads = double_barcoded_path / "barcoding_summary.txt"
+         add_demux_type_annotation(final_adata, double_barcoded_reads)

      # Delete the individual h5ad files and only keep the final concatenated file
      if delete_batch_hdfs:
-         files = os.listdir(h5_dir)
-         hdfs_to_delete = [hdf for hdf in files if 'hdf5.h5ad' in hdf and hdf != final_hdf]
-         hdf_paths_to_delete = [os.path.join(h5_dir, hdf) for hdf in hdfs_to_delete]
-         # Iterate over the files and delete them
-         for hdf in hdf_paths_to_delete:
-             try:
-                 os.remove(hdf)
-                 print(f"Deleted file: {hdf}")
-             except OSError as e:
-                 print(f"Error deleting file {hdf}: {e}")
+         delete_intermediate_h5ads_and_tmpdir(h5_dir, tmp_dir)

      return final_adata, final_adata_path
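
As a quick illustration (not part of the diff), the new cleanup helper can be exercised stand-alone in dry-run mode; the directory names are hypothetical and the import path assumes the renamed smftools/informatics/modkit_extract_to_adata.py module:

    from pathlib import Path
    from smftools.informatics.modkit_extract_to_adata import delete_intermediate_h5ads_and_tmpdir

    # Preview which intermediate .h5ad files and temp directory would be removed, without deleting anything
    delete_intermediate_h5ads_and_tmpdir(
        h5_dir=Path("out/h5ads"),   # hypothetical directory of batch h5ads
        tmp_dir=Path("out/tmp"),    # hypothetical temp directory
        dry_run=True,
    )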
smftools/informatics/modkit_functions.py
@@ -0,0 +1,129 @@
+ import os
+ import subprocess
+ import glob
+ import zipfile
+ from pathlib import Path
+
+ def extract_mods(thresholds, mod_tsv_dir, split_dir, bam_suffix, skip_unclassified=True, modkit_summary=False, threads=None):
+     """
+     Takes all of the aligned, sorted, split modified BAM files and runs Nanopore Modkit Extract to load the modification data into zipped TSV files
+
+     Parameters:
+         thresholds (list): A list of thresholds to use for marking each basecalled base as passing or failing on canonical and modification call status.
+         mod_tsv_dir (str): A string representing the file path to the directory to hold the modkit extract outputs.
+         split_dir (str): A string representing the file path to the directory containing the converted aligned_sorted_split BAM files.
+         bam_suffix (str): The suffix to use for the BAM file.
+         skip_unclassified (bool): Whether to skip the unclassified BAM file when running the modkit extract command
+         modkit_summary (bool): Whether to run and display modkit summary
+         threads (int): Number of threads to use
+
+     Returns:
+         None
+             Runs modkit extract on input aligned_sorted_split modified BAM files to output zipped TSVs containing modification calls.
+
+     """
+     filter_threshold, m6A_threshold, m5C_threshold, hm5C_threshold = thresholds
+     bam_files = sorted(p for p in split_dir.iterdir() if bam_suffix in p.name and '.bai' not in p.name)
+     if skip_unclassified:
+         bam_files = [p for p in bam_files if "unclassified" not in p.name]
+     print(f"Running modkit extract for the following bam files: {bam_files}")
+
+     if threads:
+         threads = str(threads)
+     else:
+         pass
+
+     for input_file in bam_files:
+         print(input_file)
+         # Construct the output TSV file path
+         output_tsv = mod_tsv_dir / (input_file.stem + "_extract.tsv")
+         output_tsv_gz = output_tsv.parent / (output_tsv.name + '.gz')
+         if output_tsv_gz.exists():
+             print(f"{output_tsv_gz} already exists, skipping modkit extract")
+         else:
+             print(f"Extracting modification data from {input_file}")
+             if modkit_summary:
+                 # Run modkit summary
+                 subprocess.run(["modkit", "summary", str(input_file)])
+             else:
+                 pass
+             # Run modkit extract
+             if threads:
+                 extract_command = [
+                     "modkit", "extract",
+                     "calls", "--mapped-only",
+                     "--filter-threshold", f'{filter_threshold}',
+                     "--mod-thresholds", f"m:{m5C_threshold}",
+                     "--mod-thresholds", f"a:{m6A_threshold}",
+                     "--mod-thresholds", f"h:{hm5C_threshold}",
+                     "-t", threads,
+                     str(input_file), str(output_tsv)
+                 ]
+             else:
+                 extract_command = [
+                     "modkit", "extract",
+                     "calls", "--mapped-only",
+                     "--filter-threshold", f'{filter_threshold}',
+                     "--mod-thresholds", f"m:{m5C_threshold}",
+                     "--mod-thresholds", f"a:{m6A_threshold}",
+                     "--mod-thresholds", f"h:{hm5C_threshold}",
+                     str(input_file), str(output_tsv)
+                 ]
+             subprocess.run(extract_command)
+             # Zip the output TSV
+             print(f'zipping {output_tsv}')
+             if threads:
+                 zip_command = ["pigz", "-f", "-p", threads, str(output_tsv)]
+             else:
+                 zip_command = ["pigz", "-f", str(output_tsv)]
+             subprocess.run(zip_command, check=True)
+     return
+
+ def make_modbed(aligned_sorted_output, thresholds, mod_bed_dir):
+     """
+     Generate position-level methylation summaries for each barcoded sample, starting from the overall BAM file that is the direct output of the dorado aligner.
+     Parameters:
+         aligned_sorted_output (str): A string representing the file path to the aligned_sorted non-split BAM file.
+
+     Returns:
+         None
+     """
+     import os
+     import subprocess
+
+     filter_threshold, m6A_threshold, m5C_threshold, hm5C_threshold = thresholds
+     command = [
+         "modkit", "pileup", str(aligned_sorted_output), str(mod_bed_dir),
+         "--partition-tag", "BC",
+         "--only-tabs",
+         "--filter-threshold", f'{filter_threshold}',
+         "--mod-thresholds", f"m:{m5C_threshold}",
+         "--mod-thresholds", f"a:{m6A_threshold}",
+         "--mod-thresholds", f"h:{hm5C_threshold}"
+     ]
+     subprocess.run(command)
+
+ def modQC(aligned_sorted_output, thresholds):
+     """
+     Output the percentile of bases falling at a call threshold (a probability between 0 and 1) for the overall BAM file.
+     It is generally good to look at these parameters on positive and negative controls.
+
+     Parameters:
+         aligned_sorted_output (str): A string representing the file path of the aligned_sorted non-split BAM file output by the dorado aligner.
+         thresholds (list): A list of floats to use as call thresholds.
+
+     Returns:
+         None
+     """
+     import subprocess
+
+     filter_threshold, m6A_threshold, m5C_threshold, hm5C_threshold = thresholds
+     subprocess.run(["modkit", "sample-probs", str(aligned_sorted_output)])
+     command = [
+         "modkit", "summary", str(aligned_sorted_output),
+         "--filter-threshold", f"{filter_threshold}",
+         "--mod-thresholds", f"m:{m5C_threshold}",
+         "--mod-thresholds", f"a:{m6A_threshold}",
+         "--mod-thresholds", f"h:{hm5C_threshold}"
+     ]
+     subprocess.run(command)
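
For context (not part of the diff), a sketch of how the three helpers above might be driven, assuming modkit and pigz are on PATH; all paths and threshold values are illustrative only:

    from pathlib import Path
    from smftools.informatics.modkit_functions import extract_mods, make_modbed, modQC

    # Threshold order: [filter_threshold, m6A_threshold, m5C_threshold, hm5C_threshold]
    thresholds = [0.7, 0.8, 0.8, 0.8]  # illustrative values

    bam = Path("out/aligned_sorted.bam")  # hypothetical non-split BAM from the dorado aligner
    modQC(bam, thresholds)                # print modkit sample-probs and summary
    make_modbed(bam, thresholds, Path("out/mod_beds"))

    extract_mods(
        thresholds,
        mod_tsv_dir=Path("out/mod_tsvs"),    # hypothetical output dir for extract TSVs
        split_dir=Path("out/split_bams"),    # hypothetical dir of aligned_sorted_split BAMs
        bam_suffix=".bam",
        threads=4,
    )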