smftools-0.2.4-py3-none-any.whl → smftools-0.3.0-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (181)
  1. smftools/__init__.py +43 -13
  2. smftools/_settings.py +6 -6
  3. smftools/_version.py +3 -1
  4. smftools/cli/__init__.py +1 -0
  5. smftools/cli/archived/cli_flows.py +2 -0
  6. smftools/cli/helpers.py +9 -1
  7. smftools/cli/hmm_adata.py +905 -242
  8. smftools/cli/load_adata.py +432 -280
  9. smftools/cli/preprocess_adata.py +287 -171
  10. smftools/cli/spatial_adata.py +141 -53
  11. smftools/cli_entry.py +119 -178
  12. smftools/config/__init__.py +3 -1
  13. smftools/config/conversion.yaml +5 -1
  14. smftools/config/deaminase.yaml +1 -1
  15. smftools/config/default.yaml +26 -18
  16. smftools/config/direct.yaml +8 -3
  17. smftools/config/discover_input_files.py +19 -5
  18. smftools/config/experiment_config.py +511 -276
  19. smftools/constants.py +37 -0
  20. smftools/datasets/__init__.py +4 -8
  21. smftools/datasets/datasets.py +32 -18
  22. smftools/hmm/HMM.py +2133 -1428
  23. smftools/hmm/__init__.py +24 -14
  24. smftools/hmm/archived/apply_hmm_batched.py +2 -0
  25. smftools/hmm/archived/calculate_distances.py +2 -0
  26. smftools/hmm/archived/call_hmm_peaks.py +18 -1
  27. smftools/hmm/archived/train_hmm.py +2 -0
  28. smftools/hmm/call_hmm_peaks.py +176 -193
  29. smftools/hmm/display_hmm.py +23 -7
  30. smftools/hmm/hmm_readwrite.py +20 -6
  31. smftools/hmm/nucleosome_hmm_refinement.py +104 -14
  32. smftools/informatics/__init__.py +55 -13
  33. smftools/informatics/archived/bam_conversion.py +2 -0
  34. smftools/informatics/archived/bam_direct.py +2 -0
  35. smftools/informatics/archived/basecall_pod5s.py +2 -0
  36. smftools/informatics/archived/basecalls_to_adata.py +2 -0
  37. smftools/informatics/archived/conversion_smf.py +2 -0
  38. smftools/informatics/archived/deaminase_smf.py +1 -0
  39. smftools/informatics/archived/direct_smf.py +2 -0
  40. smftools/informatics/archived/fast5_to_pod5.py +2 -0
  41. smftools/informatics/archived/helpers/archived/__init__.py +2 -0
  42. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
  43. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
  44. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  45. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
  46. smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
  47. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  48. smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
  49. smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
  50. smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
  51. smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
  52. smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
  53. smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
  54. smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
  55. smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
  56. smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
  57. smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
  58. smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
  59. smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
  60. smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
  61. smftools/informatics/archived/helpers/archived/informatics.py +2 -0
  62. smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
  63. smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
  64. smftools/informatics/archived/helpers/archived/modQC.py +2 -0
  65. smftools/informatics/archived/helpers/archived/modcall.py +2 -0
  66. smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
  67. smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
  68. smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
  69. smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
  70. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
  71. smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
  72. smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
  73. smftools/informatics/archived/print_bam_query_seq.py +9 -1
  74. smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
  75. smftools/informatics/archived/subsample_pod5.py +2 -0
  76. smftools/informatics/bam_functions.py +1059 -269
  77. smftools/informatics/basecalling.py +53 -9
  78. smftools/informatics/bed_functions.py +357 -114
  79. smftools/informatics/binarize_converted_base_identities.py +21 -7
  80. smftools/informatics/complement_base_list.py +9 -6
  81. smftools/informatics/converted_BAM_to_adata.py +324 -137
  82. smftools/informatics/fasta_functions.py +251 -89
  83. smftools/informatics/h5ad_functions.py +202 -30
  84. smftools/informatics/modkit_extract_to_adata.py +623 -274
  85. smftools/informatics/modkit_functions.py +87 -44
  86. smftools/informatics/ohe.py +46 -21
  87. smftools/informatics/pod5_functions.py +114 -74
  88. smftools/informatics/run_multiqc.py +20 -14
  89. smftools/logging_utils.py +51 -0
  90. smftools/machine_learning/__init__.py +23 -12
  91. smftools/machine_learning/data/__init__.py +2 -0
  92. smftools/machine_learning/data/anndata_data_module.py +157 -50
  93. smftools/machine_learning/data/preprocessing.py +4 -1
  94. smftools/machine_learning/evaluation/__init__.py +3 -1
  95. smftools/machine_learning/evaluation/eval_utils.py +13 -14
  96. smftools/machine_learning/evaluation/evaluators.py +52 -34
  97. smftools/machine_learning/inference/__init__.py +3 -1
  98. smftools/machine_learning/inference/inference_utils.py +9 -4
  99. smftools/machine_learning/inference/lightning_inference.py +14 -13
  100. smftools/machine_learning/inference/sklearn_inference.py +8 -8
  101. smftools/machine_learning/inference/sliding_window_inference.py +37 -25
  102. smftools/machine_learning/models/__init__.py +12 -5
  103. smftools/machine_learning/models/base.py +34 -43
  104. smftools/machine_learning/models/cnn.py +22 -13
  105. smftools/machine_learning/models/lightning_base.py +78 -42
  106. smftools/machine_learning/models/mlp.py +18 -5
  107. smftools/machine_learning/models/positional.py +10 -4
  108. smftools/machine_learning/models/rnn.py +8 -3
  109. smftools/machine_learning/models/sklearn_models.py +46 -24
  110. smftools/machine_learning/models/transformer.py +75 -55
  111. smftools/machine_learning/models/wrappers.py +8 -3
  112. smftools/machine_learning/training/__init__.py +4 -2
  113. smftools/machine_learning/training/train_lightning_model.py +42 -23
  114. smftools/machine_learning/training/train_sklearn_model.py +11 -15
  115. smftools/machine_learning/utils/__init__.py +3 -1
  116. smftools/machine_learning/utils/device.py +12 -5
  117. smftools/machine_learning/utils/grl.py +8 -2
  118. smftools/metadata.py +443 -0
  119. smftools/optional_imports.py +31 -0
  120. smftools/plotting/__init__.py +32 -17
  121. smftools/plotting/autocorrelation_plotting.py +153 -48
  122. smftools/plotting/classifiers.py +175 -73
  123. smftools/plotting/general_plotting.py +350 -168
  124. smftools/plotting/hmm_plotting.py +53 -14
  125. smftools/plotting/position_stats.py +155 -87
  126. smftools/plotting/qc_plotting.py +25 -12
  127. smftools/preprocessing/__init__.py +35 -37
  128. smftools/preprocessing/append_base_context.py +105 -79
  129. smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
  130. smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
  131. smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
  132. smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
  133. smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
  134. smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
  135. smftools/preprocessing/binarize.py +21 -4
  136. smftools/preprocessing/binarize_on_Youden.py +127 -31
  137. smftools/preprocessing/binary_layers_to_ohe.py +18 -11
  138. smftools/preprocessing/calculate_complexity_II.py +89 -59
  139. smftools/preprocessing/calculate_consensus.py +28 -19
  140. smftools/preprocessing/calculate_coverage.py +44 -22
  141. smftools/preprocessing/calculate_pairwise_differences.py +4 -1
  142. smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
  143. smftools/preprocessing/calculate_position_Youden.py +110 -55
  144. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  145. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  146. smftools/preprocessing/clean_NaN.py +38 -28
  147. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  148. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
  149. smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
  150. smftools/preprocessing/flag_duplicate_reads.py +708 -303
  151. smftools/preprocessing/invert_adata.py +26 -11
  152. smftools/preprocessing/load_sample_sheet.py +40 -22
  153. smftools/preprocessing/make_dirs.py +9 -3
  154. smftools/preprocessing/min_non_diagonal.py +4 -1
  155. smftools/preprocessing/recipes.py +58 -23
  156. smftools/preprocessing/reindex_references_adata.py +93 -27
  157. smftools/preprocessing/subsample_adata.py +33 -16
  158. smftools/readwrite.py +264 -109
  159. smftools/schema/__init__.py +11 -0
  160. smftools/schema/anndata_schema_v1.yaml +227 -0
  161. smftools/tools/__init__.py +25 -18
  162. smftools/tools/archived/apply_hmm.py +2 -0
  163. smftools/tools/archived/classifiers.py +165 -0
  164. smftools/tools/archived/classify_methylated_features.py +2 -0
  165. smftools/tools/archived/classify_non_methylated_features.py +2 -0
  166. smftools/tools/archived/subset_adata_v1.py +12 -1
  167. smftools/tools/archived/subset_adata_v2.py +14 -1
  168. smftools/tools/calculate_umap.py +56 -15
  169. smftools/tools/cluster_adata_on_methylation.py +122 -47
  170. smftools/tools/general_tools.py +70 -25
  171. smftools/tools/position_stats.py +220 -99
  172. smftools/tools/read_stats.py +50 -29
  173. smftools/tools/spatial_autocorrelation.py +365 -192
  174. smftools/tools/subset_adata.py +23 -21
  175. smftools-0.3.0.dist-info/METADATA +147 -0
  176. smftools-0.3.0.dist-info/RECORD +182 -0
  177. smftools-0.2.4.dist-info/METADATA +0 -141
  178. smftools-0.2.4.dist-info/RECORD +0 -176
  179. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
  180. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
  181. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
smftools/schema/__init__.py
@@ -0,0 +1,11 @@
+from __future__ import annotations
+
+from importlib import resources
+from pathlib import Path
+
+SCHEMA_REGISTRY_VERSION = "1"
+SCHEMA_REGISTRY_RESOURCE = "anndata_schema_v1.yaml"
+
+
+def get_schema_registry_path() -> Path:
+    return resources.files(__package__).joinpath(SCHEMA_REGISTRY_RESOURCE)
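
For orientation, a minimal sketch of how this registry path can be consumed downstream (assuming PyYAML is installed; the `registry` variable name is ours, not part of the package):

    import yaml

    from smftools.schema import SCHEMA_REGISTRY_VERSION, get_schema_registry_path

    # The path points at the packaged YAML resource; parse it into a dict.
    registry = yaml.safe_load(get_schema_registry_path().read_text())
    assert registry["schema_version"] == SCHEMA_REGISTRY_VERSION
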

smftools/schema/anndata_schema_v1.yaml
@@ -0,0 +1,227 @@
+schema_version: "1"
+description: "smftools AnnData schema registry (v1)."
+stages:
+  raw:
+    stage_requires: []
+    obs:
+      Experiment_name:
+        dtype: "category"
+        created_by: "smftools.cli.load_adata"
+        modified_by: []
+        notes: "Experiment identifier applied to all reads."
+        requires: []
+        optional_inputs: []
+      Experiment_name_and_barcode:
+        dtype: "category"
+        created_by: "smftools.cli.load_adata"
+        modified_by: []
+        notes: "Concatenated experiment name and barcode."
+        requires: [["obs.Experiment_name", "obs.Barcode"]]
+        optional_inputs: []
+      Barcode:
+        dtype: "category"
+        created_by: "smftools.informatics.modkit_extract_to_adata"
+        modified_by: []
+        notes: "Barcode assigned during demultiplexing or extraction."
+        requires: []
+        optional_inputs: []
+      read_length:
+        dtype: "float"
+        created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
+        modified_by: []
+        notes: "Read length in bases."
+        requires: []
+        optional_inputs: []
+      mapped_length:
+        dtype: "float"
+        created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
+        modified_by: []
+        notes: "Aligned length in bases."
+        requires: []
+        optional_inputs: []
+      reference_length:
+        dtype: "float"
+        created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
+        modified_by: []
+        notes: "Reference length for alignment target."
+        requires: []
+        optional_inputs: []
+      read_quality:
+        dtype: "float"
+        created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
+        modified_by: []
+        notes: "Per-read quality score."
+        requires: []
+        optional_inputs: []
+      mapping_quality:
+        dtype: "float"
+        created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
+        modified_by: []
+        notes: "Mapping quality score."
+        requires: []
+        optional_inputs: []
+      read_length_to_reference_length_ratio:
+        dtype: "float"
+        created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
+        modified_by: []
+        notes: "Read length divided by reference length."
+        requires: [["obs.read_length", "obs.reference_length"]]
+        optional_inputs: []
+      mapped_length_to_reference_length_ratio:
+        dtype: "float"
+        created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
+        modified_by: []
+        notes: "Mapped length divided by reference length."
+        requires: [["obs.mapped_length", "obs.reference_length"]]
+        optional_inputs: []
+      mapped_length_to_read_length_ratio:
+        dtype: "float"
+        created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
+        modified_by: []
+        notes: "Mapped length divided by read length."
+        requires: [["obs.mapped_length", "obs.read_length"]]
+        optional_inputs: []
+      Raw_modification_signal:
+        dtype: "float"
+        created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
+        modified_by:
+          - "smftools.cli.load_adata"
+        notes: "Summed modification signal per read."
+        requires: [["X"], ["layers.raw_mods"]]
+        optional_inputs: []
+      pod5_origin:
+        dtype: "string"
+        created_by: "smftools.informatics.h5ad_functions.annotate_pod5_origin"
+        modified_by: []
+        notes: "POD5 filename source for each read."
+        requires: [["obs_names"]]
+        optional_inputs: []
+      demux_type:
+        dtype: "category"
+        created_by: "smftools.informatics.h5ad_functions.add_demux_type_annotation"
+        modified_by: []
+        notes: "Classification of demultiplexing status."
+        requires: [["obs_names"]]
+        optional_inputs: []
+    var:
+      reference_position:
+        dtype: "int"
+        created_by: "smftools.informatics.modkit_extract_to_adata"
+        modified_by: []
+        notes: "Reference coordinate for each column."
+        requires: []
+        optional_inputs: []
+      reference_id:
+        dtype: "category"
+        created_by: "smftools.informatics.modkit_extract_to_adata"
+        modified_by: []
+        notes: "Reference contig or sequence name."
+        requires: []
+        optional_inputs: []
+    layers:
+      raw_mods:
+        dtype: "float"
+        created_by: "smftools.informatics.modkit_extract_to_adata"
+        modified_by: []
+        notes: "Raw modification scores (modality-dependent)."
+        requires: []
+        optional_inputs: []
+    obsm: {}
+    varm: {}
+    obsp: {}
+    uns:
+      smftools:
+        dtype: "mapping"
+        created_by: "smftools.metadata.record_smftools_metadata"
+        modified_by: []
+        notes: "smftools metadata including history, environment, provenance, schema snapshot."
+        requires: []
+        optional_inputs: []
+  preprocess:
+    stage_requires: ["raw"]
+    obs:
+      sequence__merged_cluster_id:
+        dtype: "category"
+        created_by: "smftools.preprocessing.flag_duplicate_reads"
+        modified_by: []
+        notes: "Cluster identifier for duplicate detection."
+        requires: [["layers.nan0_0minus1"]]
+        optional_inputs: ["obs.demux_type"]
+    layers:
+      nan0_0minus1:
+        dtype: "float"
+        created_by: "smftools.preprocessing.binarize"
+        modified_by:
+          - "smftools.preprocessing.clean_NaN"
+        notes: "Binarized methylation matrix (nan=0, 0=-1)."
+        requires: [["X"]]
+        optional_inputs: []
+    obsm:
+      X_umap:
+        dtype: "float"
+        created_by: "smftools.tools.calculate_umap"
+        modified_by: []
+        notes: "UMAP embedding for preprocessed reads."
+        requires: [["X"]]
+        optional_inputs: []
+    varm: {}
+    obsp: {}
+    uns:
+      duplicate_read_groups:
+        dtype: "mapping"
+        created_by: "smftools.preprocessing.flag_duplicate_reads"
+        modified_by: []
+        notes: "Duplicate read group metadata."
+        requires: [["obs.sequence__merged_cluster_id"]]
+        optional_inputs: []
+  spatial:
+    stage_requires: ["raw", "preprocess"]
+    obs:
+      leiden:
+        dtype: "category"
+        created_by: "smftools.tools.calculate_umap"
+        modified_by: []
+        notes: "Leiden cluster assignments."
+        requires: [["obsm.X_umap"]]
+        optional_inputs: []
+    obsm:
+      X_umap:
+        dtype: "float"
+        created_by: "smftools.tools.calculate_umap"
+        modified_by: []
+        notes: "UMAP embedding for spatial analyses."
+        requires: [["X"]]
+        optional_inputs: []
+    layers: {}
+    varm: {}
+    obsp: {}
+    uns:
+      positionwise_result:
+        dtype: "mapping"
+        created_by: "smftools.tools.position_stats.compute_positionwise_statistics"
+        modified_by: []
+        notes: "Positionwise correlation statistics for spatial analyses."
+        requires: [["X"]]
+        optional_inputs: ["obs.reference_column"]
+  hmm:
+    stage_requires: ["raw", "preprocess", "spatial"]
+    layers:
+      hmm_state_calls:
+        dtype: "int"
+        created_by: "smftools.hmm.call_hmm_peaks"
+        modified_by: []
+        notes: "HMM-derived state calls per read/position."
+        requires: [["layers.nan0_0minus1"]]
+        optional_inputs: []
+    obsm: {}
+    varm: {}
+    obsp: {}
+    obs: {}
+    uns:
+      hmm_annotated:
+        dtype: "bool"
+        created_by: "smftools.cli.hmm_adata"
+        modified_by: []
+        notes: "Flag indicating HMM annotations are present."
+        requires: [["layers.hmm_state_calls"]]
+        optional_inputs: []
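
The registry is a plain nested mapping, so field metadata can be looked up directly once parsed. A hypothetical helper (not part of smftools) following the stage → section → field layout above, with `registry` parsed as in the earlier sketch:

    def field_info(registry: dict, stage: str, section: str, name: str) -> dict:
        """Return the metadata block for one schema field."""
        return registry["stages"][stage][section][name]

    info = field_info(registry, "preprocess", "layers", "nan0_0minus1")
    print(info["dtype"])        # "float"
    print(info["created_by"])   # "smftools.preprocessing.binarize"
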

smftools/tools/__init__.py
@@ -1,20 +1,27 @@
-from .position_stats import calculate_relative_risk_on_activity, compute_positionwise_statistics
-from .calculate_umap import calculate_umap
-from .cluster_adata_on_methylation import cluster_adata_on_methylation
-from .general_tools import create_nan_mask_from_X, combine_layers, create_nan_or_non_gpc_mask
-from .read_stats import calculate_row_entropy
-from .spatial_autocorrelation import *
-from .subset_adata import subset_adata
+from __future__ import annotations
 
+from importlib import import_module
 
-__all__ = [
-    "compute_positionwise_statistics",
-    "calculate_row_entropy",
-    "calculate_umap",
-    "calculate_relative_risk_on_activity",
-    "cluster_adata_on_methylation",
-    "create_nan_mask_from_X",
-    "create_nan_or_non_gpc_mask",
-    "combine_layers",
-    "subset_adata",
-]
+_LAZY_ATTRS = {
+    "calculate_umap": "smftools.tools.calculate_umap",
+    "cluster_adata_on_methylation": "smftools.tools.cluster_adata_on_methylation",
+    "combine_layers": "smftools.tools.general_tools",
+    "create_nan_mask_from_X": "smftools.tools.general_tools",
+    "create_nan_or_non_gpc_mask": "smftools.tools.general_tools",
+    "calculate_relative_risk_on_activity": "smftools.tools.position_stats",
+    "compute_positionwise_statistics": "smftools.tools.position_stats",
+    "calculate_row_entropy": "smftools.tools.read_stats",
+    "subset_adata": "smftools.tools.subset_adata",
+}
+
+
+def __getattr__(name: str):
+    if name in _LAZY_ATTRS:
+        module = import_module(_LAZY_ATTRS[name])
+        attr = getattr(module, name)
+        globals()[name] = attr
+        return attr
+    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+
+
+__all__ = list(_LAZY_ATTRS.keys())
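
This swaps eager submodule imports for a PEP 562 module-level `__getattr__`: a name listed in `_LAZY_ATTRS` is imported on first attribute access and then cached in the module globals, so `import smftools.tools` no longer pulls in heavy dependencies up front. Call sites are unchanged:

    import smftools.tools as tools

    # First access triggers import_module("smftools.tools.calculate_umap"),
    # caches the resolved function in the module globals, and returns it.
    umap_fn = tools.calculate_umap

    # Subsequent accesses hit the cached global directly.
    assert tools.calculate_umap is umap_fn
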

smftools/tools/archived/apply_hmm.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import numpy as np
 import pandas as pd
 import torch

smftools/tools/archived/classifiers.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 ## Train CNN, RNN, Random Forest models on double barcoded, low contamination datasets
 import torch
 import torch.nn as nn
@@ -21,13 +23,29 @@ device = (
 
 # ------------------------- Utilities -------------------------
 def random_fill_nans(X):
+    """Replace NaNs in an array with random values.
+
+    Args:
+        X: Input NumPy array.
+
+    Returns:
+        NumPy array with NaNs replaced.
+    """
     nan_mask = np.isnan(X)
     X[nan_mask] = np.random.rand(*X[nan_mask].shape)
     return X
 
 # ------------------------- Model Definitions -------------------------
 class CNNClassifier(nn.Module):
+    """Simple 1D CNN classifier for fixed-length inputs."""
+
     def __init__(self, input_size, num_classes):
+        """Initialize CNN classifier layers.
+
+        Args:
+            input_size: Length of the 1D input.
+            num_classes: Number of output classes.
+        """
        super().__init__()
        self.conv1 = nn.Conv1d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=3, padding=1)
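
A quick usage sketch of the NaN-filling utility documented above (array values are illustrative; assumes `random_fill_nans` is in scope):

    import numpy as np

    X = np.array([[0.0, np.nan], [np.nan, 1.0]])
    X = random_fill_nans(X)  # fills in place with uniform [0, 1) draws and returns X
    assert not np.isnan(X).any()
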
@@ -39,11 +57,13 @@ class CNNClassifier(nn.Module):
         self.fc2 = nn.Linear(64, num_classes)
 
     def _forward_conv(self, x):
+        """Apply convolutional layers and activation."""
         x = self.relu(self.conv1(x))
         x = self.relu(self.conv2(x))
         return x
 
     def forward(self, x):
+        """Run the forward pass."""
         x = x.unsqueeze(1)
         x = self._forward_conv(x)
         x = x.view(x.size(0), -1)
@@ -51,7 +71,15 @@ class CNNClassifier(nn.Module):
         return self.fc2(x)
 
 class MLPClassifier(nn.Module):
+    """Simple MLP classifier."""
+
     def __init__(self, input_dim, num_classes):
+        """Initialize MLP layers.
+
+        Args:
+            input_dim: Input feature dimension.
+            num_classes: Number of output classes.
+        """
         super().__init__()
         self.model = nn.Sequential(
             nn.Linear(input_dim, 128),
@@ -64,10 +92,20 @@ class MLPClassifier(nn.Module):
         )
 
     def forward(self, x):
+        """Run the forward pass."""
         return self.model(x)
 
 class RNNClassifier(nn.Module):
+    """LSTM-based classifier for sequential inputs."""
+
     def __init__(self, input_size, hidden_dim, num_classes):
+        """Initialize RNN classifier layers.
+
+        Args:
+            input_size: Input feature dimension.
+            hidden_dim: Hidden state dimension.
+            num_classes: Number of output classes.
+        """
         super().__init__()
         # Define LSTM layer
         self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_dim, batch_first=True)
@@ -75,18 +113,29 @@ class RNNClassifier(nn.Module):
         self.fc = nn.Linear(hidden_dim, num_classes)
 
     def forward(self, x):
+        """Run the forward pass."""
         x = x.unsqueeze(1)
         _, (h_n, _) = self.lstm(x)
         return self.fc(h_n.squeeze(0))
 
 class AttentionRNNClassifier(nn.Module):
+    """LSTM classifier with simple attention."""
+
     def __init__(self, input_size, hidden_dim, num_classes):
+        """Initialize attention-based RNN layers.
+
+        Args:
+            input_size: Input feature dimension.
+            hidden_dim: Hidden state dimension.
+            num_classes: Number of output classes.
+        """
         super().__init__()
         self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_dim, batch_first=True)
         self.attn = nn.Linear(hidden_dim, 1)  # Simple attention scores
         self.fc = nn.Linear(hidden_dim, num_classes)
 
     def forward(self, x):
+        """Run the forward pass."""
         x = x.unsqueeze(1)  # shape: (batch, 1, seq_len)
         lstm_out, _ = self.lstm(x)  # shape: (batch, 1, hidden_dim)
         attn_weights = torch.softmax(self.attn(lstm_out), dim=1)  # (batch, 1, 1)
@@ -94,7 +143,15 @@ class AttentionRNNClassifier(nn.Module):
         return self.fc(context)
 
 class PositionalEncoding(nn.Module):
+    """Positional encoding module for transformer models."""
+
     def __init__(self, d_model, max_len=5000):
+        """Initialize positional encoding buffer.
+
+        Args:
+            d_model: Model embedding dimension.
+            max_len: Maximum sequence length.
+        """
         super().__init__()
         pe = torch.zeros(max_len, d_model)
         position = torch.arange(0, max_len).unsqueeze(1).float()
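
The hunk cuts off before the `pe` buffer is filled; those lines are not shown here. For reference, the conventional sinusoidal fill (Vaswani et al., 2017) for a `(max_len, d_model)` buffer of this shape — our sketch, not necessarily the file's exact code — looks like:

    import math
    import torch

    d_model, max_len = 64, 5000
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len).unsqueeze(1).float()
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions: sine
    pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions: cosine
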
@@ -104,11 +161,23 @@ class PositionalEncoding(nn.Module):
         self.pe = pe.unsqueeze(0)  # (1, max_len, d_model)
 
     def forward(self, x):
+        """Add positional encoding to inputs."""
         x = x + self.pe[:, :x.size(1)].to(x.device)
         return x
 
 class TransformerClassifier(nn.Module):
+    """Transformer encoder-based classifier."""
+
     def __init__(self, input_dim, model_dim, num_classes, num_heads=4, num_layers=2):
+        """Initialize transformer classifier layers.
+
+        Args:
+            input_dim: Input feature dimension.
+            model_dim: Transformer model dimension.
+            num_classes: Number of output classes.
+            num_heads: Number of attention heads.
+            num_layers: Number of encoder layers.
+        """
         super().__init__()
         self.input_fc = nn.Linear(input_dim, model_dim)
         self.pos_encoder = PositionalEncoding(model_dim)
@@ -119,6 +188,7 @@ class TransformerClassifier(nn.Module):
         self.cls_head = nn.Linear(model_dim, num_classes)
 
     def forward(self, x):
+        """Run the forward pass."""
         # x: [batch_size, input_dim]
         x = self.input_fc(x).unsqueeze(1)  # -> [batch_size, 1, model_dim]
         x = self.pos_encoder(x)
@@ -128,6 +198,19 @@ class TransformerClassifier(nn.Module):
         return self.cls_head(pooled)
 
 def train_model(model, loader, optimizer, criterion, device, ref_name="", model_name="", epochs=20, patience=5):
+    """Train a model with early stopping.
+
+    Args:
+        model: PyTorch model.
+        loader: DataLoader for training data.
+        optimizer: Optimizer instance.
+        criterion: Loss function.
+        device: Torch device.
+        ref_name: Reference label for logging.
+        model_name: Model label for logging.
+        epochs: Maximum epochs.
+        patience: Early-stopping patience.
+    """
     model.train()
     best_loss = float('inf')
     trigger_times = 0
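
The patience logic the docstring refers to can be summarized in isolation. A self-contained sketch of the same pattern (the loss values below are made up):

    def early_stop_epoch(epoch_losses, patience=5):
        """Return the epoch index at which patience-based early stopping fires."""
        best_loss, trigger_times = float("inf"), 0
        for epoch, loss in enumerate(epoch_losses):
            if loss < best_loss:
                best_loss, trigger_times = loss, 0
            else:
                trigger_times += 1
                if trigger_times >= patience:
                    return epoch  # no improvement for `patience` consecutive epochs
        return len(epoch_losses) - 1

    print(early_stop_epoch([1.0, 0.8, 0.81, 0.82, 0.83, 0.84, 0.85], patience=5))  # 6
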
@@ -154,6 +237,17 @@ def train_model(model, loader, optimizer, criterion, device, ref_name="", model_
                 break
 
 def evaluate_model(model, X_tensor, y_encoded, device):
+    """Evaluate a trained model and compute metrics.
+
+    Args:
+        model: Trained model.
+        X_tensor: Input tensor.
+        y_encoded: Encoded labels.
+        device: Torch device.
+
+    Returns:
+        Tuple of metrics dict, predicted labels, and probabilities.
+    """
     model.eval()
     with torch.no_grad():
         outputs = model(X_tensor.to(device))
@@ -176,6 +270,18 @@ def evaluate_model(model, X_tensor, y_encoded, device):
     }, preds, probs
 
 def train_rf(X_tensor, y_tensor, train_indices, test_indices, n_estimators=500):
+    """Train a random forest classifier.
+
+    Args:
+        X_tensor: Input tensor.
+        y_tensor: Label tensor.
+        train_indices: Indices for training.
+        test_indices: Indices for testing.
+        n_estimators: Number of trees.
+
+    Returns:
+        Tuple of (model, preds, probs).
+    """
     model = RandomForestClassifier(n_estimators=n_estimators, random_state=42, class_weight='balanced')
     model.fit(X_tensor[train_indices].numpy(), y_tensor[train_indices].numpy())
     probs = model.predict_proba(X_tensor[test_indices].cpu().numpy())[:, 1]
@@ -186,6 +292,25 @@ def train_rf(X_tensor, y_tensor, train_indices, test_indices, n_estimators=500):
 def run_training_loop(adata, site_config, layer_name=None,
                       mlp=False, cnn=False, rnn=False, arnn=False, transformer=False, rf=False, nb=False, rr_bayes=False,
                       max_epochs=10, max_patience=5, n_estimators=500, training_split=0.5):
+    """Train one or more classifier types on AnnData.
+
+    Args:
+        adata: AnnData object containing data and labels.
+        site_config: Mapping of reference to site list.
+        layer_name: Optional layer to use as input.
+        mlp: Whether to train an MLP model.
+        cnn: Whether to train a CNN model.
+        rnn: Whether to train an RNN model.
+        arnn: Whether to train an attention RNN model.
+        transformer: Whether to train a transformer model.
+        rf: Whether to train a random forest model.
+        nb: Whether to train a Naive Bayes model.
+        rr_bayes: Whether to train a ridge regression model.
+        max_epochs: Maximum training epochs.
+        max_patience: Early stopping patience.
+        n_estimators: Random forest estimator count.
+        training_split: Fraction of data used for training.
+    """
     device = (
         torch.device('cuda') if torch.cuda.is_available() else
         torch.device('mps') if torch.backends.mps.is_available() else
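
A hedged usage sketch based on the signature above; the `site_config` contents and the model selection here are hypothetical, so consult the function body for the exact expectations:

    run_training_loop(
        adata,
        site_config={"ref1": ["GpC_site", "CpG_site"]},  # hypothetical reference -> site list
        layer_name=None,     # default: use adata.X rather than a named layer
        cnn=True,
        rf=True,
        max_epochs=10,
        max_patience=5,
        training_split=0.5,  # half the reads for training, half held out
    )
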
@@ -701,6 +826,20 @@ def evaluate_model_by_subgroups(
         label_col="activity_status",
         min_samples=10,
         exclude_training_data=True):
+    """Evaluate predictions within categorical subgroups.
+
+    Args:
+        adata: AnnData with prediction columns.
+        model_prefix: Prediction column prefix.
+        suffix: Prediction column suffix.
+        groupby_cols: Columns to group by.
+        label_col: Ground-truth label column.
+        min_samples: Minimum samples per group.
+        exclude_training_data: Whether to exclude training rows.
+
+    Returns:
+        DataFrame of subgroup-level metrics.
+    """
     import pandas as pd
     from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
 
@@ -745,6 +884,18 @@ def evaluate_model_by_subgroups(
     return pd.DataFrame(results)
 
 def evaluate_models_by_subgroup(adata, model_prefixes, groupby_cols, label_col, exclude_training_data=True):
+    """Evaluate multiple model prefixes across subgroups.
+
+    Args:
+        adata: AnnData with prediction columns.
+        model_prefixes: Iterable of model prefixes.
+        groupby_cols: Columns to group by.
+        label_col: Ground-truth label column.
+        exclude_training_data: Whether to exclude training rows.
+
+    Returns:
+        Concatenated DataFrame of subgroup-level metrics.
+    """
     import pandas as pd
     all_metrics = []
     for model_prefix in model_prefixes:
@@ -758,6 +909,20 @@ def evaluate_models_by_subgroup(adata, model_prefixes, groupby_cols, label_col,
     return final_df
 
 def prepare_melted_model_data(adata, outkey='melted_model_df', groupby=['Enhancer_Open', 'Promoter_Open'], label_col='activity_status', model_names=['cnn', 'mlp', 'rf'], suffix='GpC_site_CpG_site', omit_training=True):
+    """Prepare a long-format DataFrame for model performance plots.
+
+    Args:
+        adata: AnnData with prediction columns.
+        outkey: Key to store the melted DataFrame in ``adata.uns``.
+        groupby: Grouping columns to include.
+        label_col: Ground-truth label column.
+        model_names: Model prefixes to include.
+        suffix: Prediction column suffix.
+        omit_training: Whether to exclude training rows.
+
+    Returns:
+        Melted DataFrame of predictions.
+    """
     import pandas as pd
     import seaborn as sns
     import matplotlib.pyplot as plt

smftools/tools/archived/classify_methylated_features.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 # classify_methylated_features
 
 def classify_methylated_features(read, model, coordinates, classification_mapping={}):

smftools/tools/archived/classify_non_methylated_features.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 # classify_non_methylated_features
 
 def classify_non_methylated_features(read, model, coordinates, classification_mapping={}):

smftools/tools/archived/subset_adata_v1.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 # subset_adata
 
 def subset_adata(adata, obs_columns):
@@ -13,6 +15,15 @@ def subset_adata(adata, obs_columns):
     """
 
     def subset_recursive(adata_subset, columns):
+        """Recursively subset AnnData by categorical columns.
+
+        Args:
+            adata_subset: AnnData subset to split.
+            columns: Remaining columns to split on.
+
+        Returns:
+            Dictionary mapping category tuples to AnnData subsets.
+        """
         if not columns:
             return {(): adata_subset}
 
@@ -29,4 +40,4 @@ def subset_adata(adata, obs_columns):
     # Start the recursive subset process
     subsets_dict = subset_recursive(adata, obs_columns)
 
-    return subsets_dict
\ No newline at end of file
+    return subsets_dict

smftools/tools/archived/subset_adata_v2.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 # subset_adata
 
 def subset_adata(adata, columns, cat_type='obs'):
@@ -14,6 +16,17 @@ def subset_adata(adata, columns, cat_type='obs'):
     """
 
     def subset_recursive(adata_subset, columns, cat_type, key_prefix=()):
+        """Recursively subset AnnData by categorical columns.
+
+        Args:
+            adata_subset: AnnData subset to split.
+            columns: Remaining columns to split on.
+            cat_type: Whether to use obs or var categories.
+            key_prefix: Tuple of previous category keys.
+
+        Returns:
+            Dictionary mapping category tuples to AnnData subsets.
+        """
         # Returns when the bottom of the stack is reached
         if not columns:
             # If there's only one column, return the key as a single value, not a tuple
@@ -43,4 +56,4 @@ def subset_adata(adata, columns, cat_type='obs'):
     # Start the recursive subset process
     subsets_dict = subset_recursive(adata, columns, cat_type)
 
-    return subsets_dict
\ No newline at end of file
+    return subsets_dict
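
To close, a small usage sketch for this archived `subset_adata` (toy AnnData; the key shapes follow the docstring added above, and the column names are hypothetical):

    import anndata as ad
    import numpy as np
    import pandas as pd

    adata = ad.AnnData(
        X=np.zeros((4, 3)),
        obs=pd.DataFrame({
            "condition": pd.Categorical(["a", "a", "b", "b"]),
            "batch": pd.Categorical(["x", "y", "x", "y"]),
        }),
    )

    subsets = subset_adata(adata, ["condition", "batch"], cat_type="obs")
    # Keys are tuples of category values, e.g. ("a", "x") -> matching AnnData subset.
    print(sorted(subsets.keys()))
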