smftools 0.2.5__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. smftools/__init__.py +39 -7
  2. smftools/_settings.py +2 -0
  3. smftools/_version.py +3 -1
  4. smftools/cli/__init__.py +1 -0
  5. smftools/cli/archived/cli_flows.py +2 -0
  6. smftools/cli/helpers.py +34 -6
  7. smftools/cli/hmm_adata.py +239 -33
  8. smftools/cli/latent_adata.py +318 -0
  9. smftools/cli/load_adata.py +167 -131
  10. smftools/cli/preprocess_adata.py +180 -53
  11. smftools/cli/spatial_adata.py +152 -100
  12. smftools/cli_entry.py +38 -1
  13. smftools/config/__init__.py +2 -0
  14. smftools/config/conversion.yaml +11 -1
  15. smftools/config/default.yaml +42 -2
  16. smftools/config/experiment_config.py +59 -1
  17. smftools/constants.py +65 -0
  18. smftools/datasets/__init__.py +2 -0
  19. smftools/hmm/HMM.py +97 -3
  20. smftools/hmm/__init__.py +24 -13
  21. smftools/hmm/archived/apply_hmm_batched.py +2 -0
  22. smftools/hmm/archived/calculate_distances.py +2 -0
  23. smftools/hmm/archived/call_hmm_peaks.py +2 -0
  24. smftools/hmm/archived/train_hmm.py +2 -0
  25. smftools/hmm/call_hmm_peaks.py +5 -2
  26. smftools/hmm/display_hmm.py +4 -1
  27. smftools/hmm/hmm_readwrite.py +7 -2
  28. smftools/hmm/nucleosome_hmm_refinement.py +2 -0
  29. smftools/informatics/__init__.py +59 -34
  30. smftools/informatics/archived/bam_conversion.py +2 -0
  31. smftools/informatics/archived/bam_direct.py +2 -0
  32. smftools/informatics/archived/basecall_pod5s.py +2 -0
  33. smftools/informatics/archived/basecalls_to_adata.py +2 -0
  34. smftools/informatics/archived/conversion_smf.py +2 -0
  35. smftools/informatics/archived/deaminase_smf.py +1 -0
  36. smftools/informatics/archived/direct_smf.py +2 -0
  37. smftools/informatics/archived/fast5_to_pod5.py +2 -0
  38. smftools/informatics/archived/helpers/archived/__init__.py +2 -0
  39. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +2 -0
  40. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
  41. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
  42. smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
  43. smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
  44. smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
  45. smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
  46. smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
  47. smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
  48. smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
  49. smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
  50. smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
  51. smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
  52. smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
  53. smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
  54. smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
  55. smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
  56. smftools/informatics/archived/helpers/archived/informatics.py +2 -0
  57. smftools/informatics/archived/helpers/archived/load_adata.py +2 -0
  58. smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
  59. smftools/informatics/archived/helpers/archived/modQC.py +2 -0
  60. smftools/informatics/archived/helpers/archived/modcall.py +2 -0
  61. smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
  62. smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
  63. smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
  64. smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
  65. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +2 -0
  66. smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
  67. smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
  68. smftools/informatics/archived/print_bam_query_seq.py +2 -0
  69. smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
  70. smftools/informatics/archived/subsample_pod5.py +2 -0
  71. smftools/informatics/bam_functions.py +1093 -176
  72. smftools/informatics/basecalling.py +2 -0
  73. smftools/informatics/bed_functions.py +271 -61
  74. smftools/informatics/binarize_converted_base_identities.py +3 -0
  75. smftools/informatics/complement_base_list.py +2 -0
  76. smftools/informatics/converted_BAM_to_adata.py +641 -176
  77. smftools/informatics/fasta_functions.py +94 -10
  78. smftools/informatics/h5ad_functions.py +123 -4
  79. smftools/informatics/modkit_extract_to_adata.py +1019 -431
  80. smftools/informatics/modkit_functions.py +2 -0
  81. smftools/informatics/ohe.py +2 -0
  82. smftools/informatics/pod5_functions.py +3 -2
  83. smftools/informatics/sequence_encoding.py +72 -0
  84. smftools/logging_utils.py +21 -2
  85. smftools/machine_learning/__init__.py +22 -6
  86. smftools/machine_learning/data/__init__.py +2 -0
  87. smftools/machine_learning/data/anndata_data_module.py +18 -4
  88. smftools/machine_learning/data/preprocessing.py +2 -0
  89. smftools/machine_learning/evaluation/__init__.py +2 -0
  90. smftools/machine_learning/evaluation/eval_utils.py +2 -0
  91. smftools/machine_learning/evaluation/evaluators.py +14 -9
  92. smftools/machine_learning/inference/__init__.py +2 -0
  93. smftools/machine_learning/inference/inference_utils.py +2 -0
  94. smftools/machine_learning/inference/lightning_inference.py +6 -1
  95. smftools/machine_learning/inference/sklearn_inference.py +2 -0
  96. smftools/machine_learning/inference/sliding_window_inference.py +2 -0
  97. smftools/machine_learning/models/__init__.py +2 -0
  98. smftools/machine_learning/models/base.py +7 -2
  99. smftools/machine_learning/models/cnn.py +7 -2
  100. smftools/machine_learning/models/lightning_base.py +16 -11
  101. smftools/machine_learning/models/mlp.py +5 -1
  102. smftools/machine_learning/models/positional.py +7 -2
  103. smftools/machine_learning/models/rnn.py +5 -1
  104. smftools/machine_learning/models/sklearn_models.py +14 -9
  105. smftools/machine_learning/models/transformer.py +7 -2
  106. smftools/machine_learning/models/wrappers.py +6 -2
  107. smftools/machine_learning/training/__init__.py +2 -0
  108. smftools/machine_learning/training/train_lightning_model.py +13 -3
  109. smftools/machine_learning/training/train_sklearn_model.py +2 -0
  110. smftools/machine_learning/utils/__init__.py +2 -0
  111. smftools/machine_learning/utils/device.py +5 -1
  112. smftools/machine_learning/utils/grl.py +5 -1
  113. smftools/metadata.py +1 -1
  114. smftools/optional_imports.py +31 -0
  115. smftools/plotting/__init__.py +41 -31
  116. smftools/plotting/autocorrelation_plotting.py +9 -5
  117. smftools/plotting/classifiers.py +16 -4
  118. smftools/plotting/general_plotting.py +2415 -629
  119. smftools/plotting/hmm_plotting.py +97 -9
  120. smftools/plotting/position_stats.py +15 -7
  121. smftools/plotting/qc_plotting.py +6 -1
  122. smftools/preprocessing/__init__.py +36 -37
  123. smftools/preprocessing/append_base_context.py +17 -17
  124. smftools/preprocessing/append_mismatch_frequency_sites.py +158 -0
  125. smftools/preprocessing/archived/add_read_length_and_mapping_qc.py +2 -0
  126. smftools/preprocessing/archived/calculate_complexity.py +2 -0
  127. smftools/preprocessing/archived/mark_duplicates.py +2 -0
  128. smftools/preprocessing/archived/preprocessing.py +2 -0
  129. smftools/preprocessing/archived/remove_duplicates.py +2 -0
  130. smftools/preprocessing/binary_layers_to_ohe.py +2 -1
  131. smftools/preprocessing/calculate_complexity_II.py +4 -1
  132. smftools/preprocessing/calculate_consensus.py +1 -1
  133. smftools/preprocessing/calculate_pairwise_differences.py +2 -0
  134. smftools/preprocessing/calculate_pairwise_hamming_distances.py +3 -0
  135. smftools/preprocessing/calculate_position_Youden.py +9 -2
  136. smftools/preprocessing/calculate_read_modification_stats.py +6 -1
  137. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +2 -0
  138. smftools/preprocessing/filter_reads_on_modification_thresholds.py +2 -0
  139. smftools/preprocessing/flag_duplicate_reads.py +42 -54
  140. smftools/preprocessing/make_dirs.py +2 -1
  141. smftools/preprocessing/min_non_diagonal.py +2 -0
  142. smftools/preprocessing/recipes.py +2 -0
  143. smftools/readwrite.py +53 -17
  144. smftools/schema/anndata_schema_v1.yaml +15 -1
  145. smftools/tools/__init__.py +30 -18
  146. smftools/tools/archived/apply_hmm.py +2 -0
  147. smftools/tools/archived/classifiers.py +2 -0
  148. smftools/tools/archived/classify_methylated_features.py +2 -0
  149. smftools/tools/archived/classify_non_methylated_features.py +2 -0
  150. smftools/tools/archived/subset_adata_v1.py +2 -0
  151. smftools/tools/archived/subset_adata_v2.py +2 -0
  152. smftools/tools/calculate_leiden.py +57 -0
  153. smftools/tools/calculate_nmf.py +119 -0
  154. smftools/tools/calculate_umap.py +93 -8
  155. smftools/tools/cluster_adata_on_methylation.py +7 -1
  156. smftools/tools/position_stats.py +17 -27
  157. smftools/tools/rolling_nn_distance.py +235 -0
  158. smftools/tools/tensor_factorization.py +169 -0
  159. {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/METADATA +69 -33
  160. smftools-0.3.1.dist-info/RECORD +189 -0
  161. smftools-0.2.5.dist-info/RECORD +0 -181
  162. {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/WHEEL +0 -0
  163. {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/entry_points.txt +0 -0
  164. {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/licenses/LICENSE +0 -0
smftools/config/experiment_config.py CHANGED
@@ -12,6 +12,7 @@ from smftools.constants import (
12
12
  BAM_SUFFIX,
13
13
  BARCODE_BOTH_ENDS,
14
14
  CONVERSIONS,
15
+ LOAD_DIR,
15
16
  MOD_LIST,
16
17
  MOD_MAP,
17
18
  REF_COL,
@@ -664,6 +665,8 @@ class ExperimentConfig:
664
665
  # General I/O
665
666
  input_data_path: Optional[str] = None
666
667
  output_directory: Optional[str] = None
668
+ emit_log_file: Optional[bool] = True
669
+ log_level: Optional[str] = "INFO"
667
670
  fasta: Optional[str] = None
668
671
  bam_suffix: str = BAM_SUFFIX
669
672
  recursive_input_search: bool = True
@@ -736,6 +739,10 @@ class ExperimentConfig:
736
739
  aligner_args: Optional[List[str]] = None
737
740
  make_bigwigs: bool = False
738
741
  make_beds: bool = False
742
+ annotate_secondary_supplementary: bool = True
743
+ samtools_backend: str = "auto"
744
+ bedtools_backend: str = "auto"
745
+ bigwig_backend: str = "auto"
739
746
 
740
747
  # Anndata structure
741
748
  reference_column: Optional[str] = REF_COL
@@ -744,6 +751,9 @@ class ExperimentConfig:
744
751
  # General Plotting
745
752
  sample_name_col_for_plotting: Optional[str] = "Barcode"
746
753
  rows_per_qc_histogram_grid: int = 12
754
+ clustermap_demux_types_to_plot: List[str] = field(
755
+ default_factory=lambda: ["single", "double", "already"]
756
+ )
747
757
 
748
758
  # Preprocessing - Read length and quality filter params
749
759
  read_coord_filter: Optional[Sequence[float]] = field(default_factory=lambda: [None, None])
@@ -813,6 +823,9 @@ class ExperimentConfig:
813
823
  duplicate_detection_site_types: List[str] = field(
814
824
  default_factory=lambda: ["GpC", "CpG", "ambiguous_GpC_CpG"]
815
825
  )
826
+ duplicate_detection_demux_types_to_use: List[str] = field(
827
+ default_factory=lambda: ["single", "double", "already"]
828
+ )
816
829
  duplicate_detection_distance_threshold: float = 0.07
817
830
  hamming_vs_metric_keys: List[str] = field(default_factory=lambda: ["Fraction_C_site_modified"])
818
831
  duplicate_detection_keep_best_metric: str = "read_quality"
@@ -824,6 +837,9 @@ class ExperimentConfig:
824
837
 
825
838
  # Preprocessing - Position QC
826
839
  position_max_nan_threshold: float = 0.1
840
+ mismatch_frequency_range: Sequence[float] = field(default_factory=lambda: [0.05, 0.95])
841
+ mismatch_frequency_layer: str = "mismatch_integer_encoding"
842
+ mismatch_frequency_read_span_layer: str = "read_span_mask"
827
843
 
828
844
  # Spatial Analysis - Clustermap params
829
845
  layer_for_clustermap_plotting: Optional[str] = "nan0_0minus1"
@@ -832,6 +848,14 @@ class ExperimentConfig:
832
848
  clustermap_cmap_cpg: Optional[str] = "coolwarm"
833
849
  clustermap_cmap_a: Optional[str] = "coolwarm"
834
850
  spatial_clustermap_sortby: Optional[str] = "gpc"
851
+ rolling_nn_layer: Optional[str] = "nan0_0minus1"
852
+ rolling_nn_plot_layer: Optional[str] = "nan0_0minus1"
853
+ rolling_nn_window: int = 15
854
+ rolling_nn_step: int = 2
855
+ rolling_nn_min_overlap: int = 10
856
+ rolling_nn_return_fraction: bool = True
857
+ rolling_nn_obsm_key: str = "rolling_nn_dist"
858
+ rolling_nn_site_types: Optional[List[str]] = None
835
859
 
836
860
  # Spatial Analysis - UMAP/Leiden params
837
861
  layer_for_umap_plotting: Optional[str] = "nan_half"
@@ -880,11 +904,15 @@ class ExperimentConfig:
880
904
  accessible_patches: Optional[bool] = True
881
905
  cpg: Optional[bool] = False
882
906
  hmm_feature_sets: Dict[str, Any] = field(default_factory=dict)
907
+ hmm_feature_colormaps: Dict[str, Any] = field(default_factory=dict)
883
908
  hmm_merge_layer_features: Optional[List[Tuple]] = field(default_factory=lambda: [(None, 60)])
884
909
  clustermap_cmap_hmm: Optional[str] = "coolwarm"
885
910
  hmm_clustermap_feature_layers: List[str] = field(
886
911
  default_factory=lambda: ["all_accessible_features"]
887
912
  )
913
+ hmm_clustermap_length_layers: List[str] = field(
914
+ default_factory=lambda: ["all_accessible_features"]
915
+ )
888
916
  hmm_clustermap_sortby: Optional[str] = "hmm"
889
917
  hmm_peak_feature_configs: Dict[str, Any] = field(default_factory=dict)
890
918
 
@@ -903,6 +931,8 @@ class ExperimentConfig:
903
931
  invert_adata: bool = False
904
932
  bypass_append_binary_layer_by_base_context: bool = False
905
933
  force_redo_append_binary_layer_by_base_context: bool = False
934
+ bypass_append_mismatch_frequency_sites: bool = False
935
+ force_redo_append_mismatch_frequency_sites: bool = False
906
936
  bypass_calculate_read_modification_stats: bool = False
907
937
  force_redo_calculate_read_modification_stats: bool = False
908
938
  bypass_filter_reads_on_modification_thresholds: bool = False
@@ -1107,7 +1137,7 @@ class ExperimentConfig:
1107
1137
 
1108
1138
  # Demultiplexing output path
1109
1139
  split_dir = merged.get("split_dir", SPLIT_DIR)
1110
- split_path = output_dir / split_dir
1140
+ split_path = output_dir / LOAD_DIR / split_dir
1111
1141
 
1112
1142
  # final normalization
1113
1143
  if "strands" in merged:
@@ -1194,6 +1224,9 @@ class ExperimentConfig:
1194
1224
  # Final normalization of hmm_feature_sets and canonical local variables
1195
1225
  merged["hmm_feature_sets"] = normalize_hmm_feature_sets(merged.get("hmm_feature_sets", {}))
1196
1226
  hmm_feature_sets = merged.get("hmm_feature_sets", {})
1227
+ hmm_feature_colormaps = merged.get("hmm_feature_colormaps", {})
1228
+ if not isinstance(hmm_feature_colormaps, dict):
1229
+ hmm_feature_colormaps = {}
1197
1230
  hmm_annotation_threshold = merged.get("hmm_annotation_threshold", 0.5)
1198
1231
  hmm_batch_size = int(merged.get("hmm_batch_size", 1024))
1199
1232
  hmm_use_viterbi = bool(merged.get("hmm_use_viterbi", False))
@@ -1208,6 +1241,9 @@ class ExperimentConfig:
1208
1241
  hmm_clustermap_feature_layers = _parse_list(
1209
1242
  merged.get("hmm_clustermap_feature_layers", "all_accessible_features")
1210
1243
  )
1244
+ hmm_clustermap_length_layers = _parse_list(
1245
+ merged.get("hmm_clustermap_length_layers", hmm_clustermap_feature_layers)
1246
+ )
1211
1247
 
1212
1248
  hmm_fit_strategy = str(merged.get("hmm_fit_strategy", "per_group")).strip()
1213
1249
  hmm_shared_scope = _parse_list(merged.get("hmm_shared_scope", ["reference", "methbase"]))
@@ -1228,6 +1264,7 @@ class ExperimentConfig:
1228
1264
 
1229
1265
  # instantiate dataclass
1230
1266
  instance = cls(
1267
+ annotate_secondary_supplementary=merged.get("annotate_secondary_supplementary", True),
1231
1268
  smf_modality=merged.get("smf_modality"),
1232
1269
  input_data_path=input_data_path,
1233
1270
  recursive_input_search=merged.get("recursive_input_search"),
@@ -1254,6 +1291,8 @@ class ExperimentConfig:
1254
1291
  trim=merged.get("trim", TRIM),
1255
1292
  input_already_demuxed=merged.get("input_already_demuxed", False),
1256
1293
  threads=merged.get("threads"),
1294
+ emit_log_file=merged.get("emit_log_file", True),
1295
+ log_level=merged.get("log_level", "INFO"),
1257
1296
  sample_sheet_path=merged.get("sample_sheet_path"),
1258
1297
  sample_sheet_mapping_column=merged.get("sample_sheet_mapping_column"),
1259
1298
  delete_intermediate_bams=merged.get("delete_intermediate_bams", False),
@@ -1264,6 +1303,9 @@ class ExperimentConfig:
1264
1303
  device=merged.get("device", "auto"),
1265
1304
  make_bigwigs=merged.get("make_bigwigs", False),
1266
1305
  make_beds=merged.get("make_beds", False),
1306
+ samtools_backend=merged.get("samtools_backend", "auto"),
1307
+ bedtools_backend=merged.get("bedtools_backend", "auto"),
1308
+ bigwig_backend=merged.get("bigwig_backend", "auto"),
1267
1309
  delete_intermediate_hdfs=merged.get("delete_intermediate_hdfs", True),
1268
1310
  mod_target_bases=merged.get("mod_target_bases", ["GpC", "CpG"]),
1269
1311
  enzyme_target_bases=merged.get("enzyme_target_bases", ["GpC"]),
@@ -1307,6 +1349,9 @@ class ExperimentConfig:
1307
1349
  ),
1308
1350
  reindexing_offsets=merged.get("reindexing_offsets", {None: None}),
1309
1351
  reindexed_var_suffix=merged.get("reindexed_var_suffix", "reindexed"),
1352
+ clustermap_demux_types_to_plot=merged.get(
1353
+ "clustermap_demux_types_to_plot", ["single", "double", "already"]
1354
+ ),
1310
1355
  layer_for_clustermap_plotting=merged.get(
1311
1356
  "layer_for_clustermap_plotting", "nan0_0minus1"
1312
1357
  ),
@@ -1315,6 +1360,14 @@ class ExperimentConfig:
1315
1360
  clustermap_cmap_cpg=merged.get("clustermap_cmap_cpg", "coolwarm"),
1316
1361
  clustermap_cmap_a=merged.get("clustermap_cmap_a", "coolwarm"),
1317
1362
  spatial_clustermap_sortby=merged.get("spatial_clustermap_sortby", "gpc"),
1363
+ rolling_nn_layer=merged.get("rolling_nn_layer", "nan0_0minus1"),
1364
+ rolling_nn_plot_layer=merged.get("rolling_nn_plot_layer", "nan0_0minus1"),
1365
+ rolling_nn_window=merged.get("rolling_nn_window", 15),
1366
+ rolling_nn_step=merged.get("rolling_nn_step", 2),
1367
+ rolling_nn_min_overlap=merged.get("rolling_nn_min_overlap", 10),
1368
+ rolling_nn_return_fraction=merged.get("rolling_nn_return_fraction", True),
1369
+ rolling_nn_obsm_key=merged.get("rolling_nn_obsm_key", "rolling_nn_dist"),
1370
+ rolling_nn_site_types=merged.get("rolling_nn_site_types", None),
1318
1371
  layer_for_umap_plotting=merged.get("layer_for_umap_plotting", "nan_half"),
1319
1372
  umap_layers_to_plot=merged.get(
1320
1373
  "umap_layers_to_plot", ["mapped_length", "Raw_modification_signal"]
@@ -1341,6 +1394,7 @@ class ExperimentConfig:
1341
1394
  hmm_emission_adapt_tol=hmm_emission_adapt_tol,
1342
1395
  hmm_dtype=merged.get("hmm_dtype", "float64"),
1343
1396
  hmm_feature_sets=hmm_feature_sets,
1397
+ hmm_feature_colormaps=hmm_feature_colormaps,
1344
1398
  hmm_annotation_threshold=hmm_annotation_threshold,
1345
1399
  hmm_batch_size=hmm_batch_size,
1346
1400
  hmm_use_viterbi=hmm_use_viterbi,
@@ -1349,6 +1403,7 @@ class ExperimentConfig:
1349
1403
  hmm_merge_layer_features=hmm_merge_layer_features,
1350
1404
  clustermap_cmap_hmm=merged.get("clustermap_cmap_hmm", "coolwarm"),
1351
1405
  hmm_clustermap_feature_layers=hmm_clustermap_feature_layers,
1406
+ hmm_clustermap_length_layers=hmm_clustermap_length_layers,
1352
1407
  hmm_clustermap_sortby=merged.get("hmm_clustermap_sortby", "hmm"),
1353
1408
  hmm_peak_feature_configs=hmm_peak_feature_configs,
1354
1409
  footprints=merged.get("footprints", None),
@@ -1384,6 +1439,9 @@ class ExperimentConfig:
1384
1439
  duplicate_detection_site_types=merged.get(
1385
1440
  "duplicate_detection_site_types", ["GpC", "CpG", "ambiguous_GpC_CpG"]
1386
1441
  ),
1442
+ duplicate_detection_demux_types_to_use=merged.get(
1443
+ "duplicate_detection_demux_types_to_use", ["single", "double", "already"]
1444
+ ),
1387
1445
  duplicate_detection_distance_threshold=merged.get(
1388
1446
  "duplicate_detection_distance_threshold", 0.07
1389
1447
  ),
smftools/constants.py CHANGED
@@ -21,7 +21,30 @@ BAM_SUFFIX: Final[str] = ".bam"
21
21
  BARCODE_BOTH_ENDS: Final[bool] = False
22
22
  REF_COL: Final[str] = "Reference_strand"
23
23
  SAMPLE_COL: Final[str] = "Experiment_name_and_barcode"
24
+ SAMPLE: Final[str] = "Sample"
24
25
  SPLIT_DIR: Final[str] = "demultiplexed_BAMs"
26
+ H5_DIR: Final[str] = "h5ads"
27
+ DEMUX_TYPE: Final[str] = "demux_type"
28
+ BARCODE: Final[str] = "Barcode"
29
+ REFERENCE: Final[str] = "Reference"
30
+ REFERENCE_STRAND: Final[str] = "Reference_strand"
31
+ REFERENCE_DATASET_STRAND: Final[str] = "Reference_dataset_strand"
32
+ STRAND: Final[str] = "Strand"
33
+ DATASET: Final[str] = "Dataset"
34
+ READ_MISMATCH_TREND: Final[str] = "Read_mismatch_trend"
35
+ READ_MAPPING_DIRECTION: Final[str] = "Read_mapping_direction"
36
+ SEQUENCE_INTEGER_ENCODING: Final[str] = "sequence_integer_encoding"
37
+ SEQUENCE_INTEGER_DECODING: Final[str] = "sequence_integer_decoding"
38
+ MISMATCH_INTEGER_ENCODING: Final[str] = "mismatch_integer_encoding"
39
+ BASE_QUALITY_SCORES: Final[str] = "base_quality_scores"
40
+ READ_SPAN_MASK: Final[str] = "read_span_mask"
41
+
42
+ LOAD_DIR: Final[str] = "load_adata_outputs"
43
+ PREPROCESS_DIR: Final[str] = "preprocess_adata_outputs"
44
+ SPATIAL_DIR: Final[str] = "spatial_adata_outputs"
45
+ HMM_DIR: Final[str] = "hmm_adata_outputs"
46
+ LATENT_DIR: Final[str] = "latent_adata_outputs"
47
+ LOGGING_DIR: Final[str] = "logs"
25
48
  TRIM: Final[bool] = False
26
49
 
27
50
  _private_conversions = ["unconverted"]
@@ -35,3 +58,45 @@ MOD_MAP: Final[Mapping[str, str]] = _deep_freeze(_private_mod_map)
35
58
 
36
59
  _private_strands = ("bottom", "top")
37
60
  STRANDS: Final[tuple[str, ...]] = _deep_freeze(_private_strands)
61
+
62
+ MODKIT_EXTRACT_TSV_COLUMN_CHROM: Final[str] = "chrom"
63
+ MODKIT_EXTRACT_TSV_COLUMN_REF_POSITION: Final[str] = "ref_position"
64
+ MODKIT_EXTRACT_TSV_COLUMN_MODIFIED_PRIMARY_BASE: Final[str] = "modified_primary_base"
65
+ MODKIT_EXTRACT_TSV_COLUMN_REF_STRAND: Final[str] = "ref_strand"
66
+ MODKIT_EXTRACT_TSV_COLUMN_READ_ID: Final[str] = "read_id"
67
+ MODKIT_EXTRACT_TSV_COLUMN_CALL_CODE: Final[str] = "call_code"
68
+ MODKIT_EXTRACT_TSV_COLUMN_CALL_PROB: Final[str] = "call_prob"
69
+
70
+ MODKIT_EXTRACT_MODIFIED_BASE_A: Final[str] = "A"
71
+ MODKIT_EXTRACT_MODIFIED_BASE_C: Final[str] = "C"
72
+ MODKIT_EXTRACT_REF_STRAND_PLUS: Final[str] = "+"
73
+ MODKIT_EXTRACT_REF_STRAND_MINUS: Final[str] = "-"
74
+
75
+ _private_modkit_extract_call_code_modified = ("a", "h", "m")
76
+ MODKIT_EXTRACT_CALL_CODE_MODIFIED: Final[tuple[str, ...]] = _deep_freeze(
77
+ _private_modkit_extract_call_code_modified
78
+ )
79
+ _private_modkit_extract_call_code_canonical = ("-",)
80
+ MODKIT_EXTRACT_CALL_CODE_CANONICAL: Final[tuple[str, ...]] = _deep_freeze(
81
+ _private_modkit_extract_call_code_canonical
82
+ )
83
+
84
+ MODKIT_EXTRACT_SEQUENCE_BASES: Final[tuple[str, ...]] = _deep_freeze(("A", "C", "G", "T", "N"))
85
+ MODKIT_EXTRACT_SEQUENCE_PADDING_BASE: Final[str] = "PAD"
86
+ _private_modkit_extract_base_to_int: Dict[str, int] = {
87
+ "A": 0,
88
+ "C": 1,
89
+ "G": 2,
90
+ "T": 3,
91
+ "N": 4,
92
+ "PAD": 5,
93
+ }
94
+ MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT: Final[Mapping[str, int]] = _deep_freeze(
95
+ _private_modkit_extract_base_to_int
96
+ )
97
+ _private_modkit_extract_int_to_base: Dict[int, str] = {
98
+ value: key for key, value in _private_modkit_extract_base_to_int.items()
99
+ }
100
+ MODKIT_EXTRACT_SEQUENCE_INT_TO_BASE: Final[Mapping[int, str]] = _deep_freeze(
101
+ _private_modkit_extract_int_to_base
102
+ )
smftools/datasets/__init__.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from .datasets import Kissiov_and_McKenna_2025, dCas9_kinetics
2
4
 
3
5
  __all__ = ["dCas9_kinetics", "Kissiov_and_McKenna_2025"]
smftools/hmm/HMM.py CHANGED
@@ -3,14 +3,20 @@ from __future__ import annotations
3
3
  import ast
4
4
  import json
5
5
  from pathlib import Path
6
- from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
6
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
7
7
 
8
8
  import numpy as np
9
- import torch
10
- import torch.nn as nn
11
9
  from scipy.sparse import issparse
12
10
 
13
11
  from smftools.logging_utils import get_logger
12
+ from smftools.optional_imports import require
13
+
14
+ if TYPE_CHECKING:
15
+ import torch as torch_types
16
+ import torch.nn as nn_types
17
+
18
+ torch = require("torch", extra="torch", purpose="HMM modeling")
19
+ nn = torch.nn
14
20
 
15
21
  logger = get_logger(__name__)
16
22
  # =============================================================================
@@ -138,6 +144,83 @@ def _safe_int_coords(var_names) -> Tuple[np.ndarray, bool]:
138
144
  return np.arange(len(var_names), dtype=int), False
139
145
 
140
146
 
147
+ def mask_layers_outside_read_span(
148
+ adata,
149
+ layers: Sequence[str],
150
+ *,
151
+ start_key: str = "reference_start",
152
+ end_key: str = "reference_end",
153
+ use_original_var_names: bool = True,
154
+ ) -> List[str]:
155
+ """Mask layer values outside read reference spans with NaN.
156
+
157
+ This uses integer coordinate comparisons against either ``adata.var["Original_var_names"]``
158
+ (when present) or ``adata.var_names``. Values strictly less than ``start_key`` or greater
159
+ than ``end_key`` are set to NaN for each read.
160
+
161
+ Args:
162
+ adata: AnnData object to modify in-place.
163
+ layers: Layer names to mask.
164
+ start_key: obs column holding reference start positions.
165
+ end_key: obs column holding reference end positions.
166
+ use_original_var_names: Use ``adata.var["Original_var_names"]`` when available.
167
+
168
+ Returns:
169
+ List of layer names that were masked.
170
+ """
171
+ if not layers:
172
+ return []
173
+
174
+ if start_key not in adata.obs or end_key not in adata.obs:
175
+ raise KeyError(f"Missing {start_key!r} or {end_key!r} in adata.obs.")
176
+
177
+ coord_source = adata.var_names
178
+ if use_original_var_names and "Original_var_names" in adata.var:
179
+ orig = np.asarray(adata.var["Original_var_names"])
180
+ if orig.size == adata.n_vars:
181
+ try:
182
+ orig_numeric = np.asarray(orig, dtype=float)
183
+ except (TypeError, ValueError):
184
+ orig_numeric = None
185
+ if orig_numeric is not None and np.isfinite(orig_numeric).any():
186
+ coord_source = orig
187
+
188
+ coords, _ = _safe_int_coords(coord_source)
189
+ if coords.shape[0] != adata.n_vars:
190
+ raise ValueError("Coordinate source length does not match adata.n_vars.")
191
+
192
+ try:
193
+ starts = np.asarray(adata.obs[start_key], dtype=float)
194
+ ends = np.asarray(adata.obs[end_key], dtype=float)
195
+ except (TypeError, ValueError) as exc:
196
+ raise ValueError("Start/end positions must be numeric.") from exc
197
+
198
+ masked = []
199
+ for layer in layers:
200
+ if layer not in adata.layers:
201
+ raise KeyError(f"Layer {layer!r} not found in adata.layers.")
202
+
203
+ arr = np.asarray(adata.layers[layer])
204
+ if not np.issubdtype(arr.dtype, np.floating):
205
+ arr = arr.astype(float, copy=True)
206
+
207
+ for i in range(adata.n_obs):
208
+ start = starts[i]
209
+ end = ends[i]
210
+ if not np.isfinite(start) or not np.isfinite(end):
211
+ continue
212
+ start_i = int(start)
213
+ end_i = int(end)
214
+ row_mask = (coords < start_i) | (coords > end_i)
215
+ if row_mask.any():
216
+ arr[i, row_mask] = np.nan
217
+
218
+ adata.layers[layer] = arr
219
+ masked.append(layer)
220
+
221
+ return masked
222
+
223
+
141
224
  def _logsumexp(x: torch.Tensor, dim: int) -> torch.Tensor:
142
225
  """Compute log-sum-exp in a numerically stable way.
143
226
 
@@ -1058,6 +1141,8 @@ class BaseHMM(nn.Module):
1058
1141
  uns_key: str = "hmm_appended_layers",
1059
1142
  uns_flag: str = "hmm_annotated",
1060
1143
  force_redo: bool = False,
1144
+ mask_to_read_span: bool = True,
1145
+ mask_use_original_var_names: bool = True,
1061
1146
  device: Optional[Union[str, torch.device]] = None,
1062
1147
  **kwargs,
1063
1148
  ):
@@ -1079,6 +1164,8 @@ class BaseHMM(nn.Module):
1079
1164
  uns_key: .uns key to track appended layers.
1080
1165
  uns_flag: .uns flag to mark annotations.
1081
1166
  force_redo: Whether to overwrite existing layers.
1167
+ mask_to_read_span: Whether to mask appended layers outside read spans.
1168
+ mask_use_original_var_names: Use ``adata.var["Original_var_names"]`` when available.
1082
1169
  device: Device specifier.
1083
1170
  **kwargs: Additional parameters for specialized workflows.
1084
1171
 
@@ -1239,6 +1326,13 @@ class BaseHMM(nn.Module):
1239
1326
  np.asarray(adata.layers[nm])
1240
1327
  )
1241
1328
 
1329
+ if mask_to_read_span and appended:
1330
+ mask_layers_outside_read_span(
1331
+ adata,
1332
+ appended,
1333
+ use_original_var_names=mask_use_original_var_names,
1334
+ )
1335
+
1242
1336
  adata.uns[uns_key] = appended
1243
1337
  adata.uns[uns_flag] = True
1244
1338
  return None
smftools/hmm/__init__.py CHANGED
@@ -1,13 +1,24 @@
1
- from .call_hmm_peaks import call_hmm_peaks
2
- from .display_hmm import display_hmm
3
- from .hmm_readwrite import load_hmm, save_hmm
4
- from .nucleosome_hmm_refinement import infer_nucleosomes_in_large_bound, refine_nucleosome_calls
5
-
6
- __all__ = [
7
- "call_hmm_peaks",
8
- "display_hmm",
9
- "load_hmm",
10
- "refine_nucleosome_calls",
11
- "infer_nucleosomes_in_large_bound",
12
- "save_hmm",
13
- ]
1
+ from __future__ import annotations
2
+
3
+ from importlib import import_module
4
+
5
+ _LAZY_ATTRS = {
6
+ "call_hmm_peaks": "smftools.hmm.call_hmm_peaks",
7
+ "display_hmm": "smftools.hmm.display_hmm",
8
+ "load_hmm": "smftools.hmm.hmm_readwrite",
9
+ "save_hmm": "smftools.hmm.hmm_readwrite",
10
+ "infer_nucleosomes_in_large_bound": "smftools.hmm.nucleosome_hmm_refinement",
11
+ "refine_nucleosome_calls": "smftools.hmm.nucleosome_hmm_refinement",
12
+ }
13
+
14
+
15
+ def __getattr__(name: str):
16
+ if name in _LAZY_ATTRS:
17
+ module = import_module(_LAZY_ATTRS[name])
18
+ attr = getattr(module, name)
19
+ globals()[name] = attr
20
+ return attr
21
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
22
+
23
+
24
+ __all__ = list(_LAZY_ATTRS.keys())
smftools/hmm/archived/apply_hmm_batched.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import numpy as np
2
4
  import pandas as pd
3
5
  import torch
smftools/hmm/archived/calculate_distances.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  # calculate_distances
2
4
 
3
5
  def calculate_distances(intervals, threshold=0.9):
smftools/hmm/archived/call_hmm_peaks.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  def call_hmm_peaks(
2
4
  adata,
3
5
  feature_configs,
smftools/hmm/archived/train_hmm.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  def train_hmm(
2
4
  data,
3
5
  emission_probs=[[0.8, 0.2], [0.2, 0.8]],
smftools/hmm/call_hmm_peaks.py CHANGED
@@ -1,9 +1,11 @@
1
- # FILE: smftools/hmm/call_hmm_peaks.py
1
+ from __future__ import annotations
2
2
 
3
+ # FILE: smftools/hmm/call_hmm_peaks.py
3
4
  from pathlib import Path
4
5
  from typing import Any, Dict, Optional, Sequence, Union
5
6
 
6
7
  from smftools.logging_utils import get_logger
8
+ from smftools.optional_imports import require
7
9
 
8
10
  logger = get_logger(__name__)
9
11
 
@@ -35,12 +37,13 @@ def call_hmm_peaks(
35
37
  - adata.var["is_in_any_{layer}_peak_{ref}"]
36
38
  - adata.var["is_in_any_peak"] (global)
37
39
  """
38
- import matplotlib.pyplot as plt
39
40
  import numpy as np
40
41
  import pandas as pd
41
42
  from scipy.signal import find_peaks
42
43
  from scipy.sparse import issparse
43
44
 
45
+ plt = require("matplotlib.pyplot", extra="plotting", purpose="HMM peak plots")
46
+
44
47
  if not inplace:
45
48
  adata = adata.copy()
46
49
 
@@ -1,4 +1,7 @@
1
+ from __future__ import annotations
2
+
1
3
  from smftools.logging_utils import get_logger
4
+ from smftools.optional_imports import require
2
5
 
3
6
  logger = get_logger(__name__)
4
7
 
@@ -11,7 +14,7 @@ def display_hmm(hmm, state_labels=["Non-Methylated", "Methylated"], obs_labels=[
11
14
  state_labels: Optional labels for states.
12
15
  obs_labels: Optional labels for observations.
13
16
  """
14
- import torch
17
+ torch = require("torch", extra="torch", purpose="HMM display")
15
18
 
16
19
  logger.info("**HMM Model Overview**")
17
20
  logger.info("%s", hmm)
@@ -1,3 +1,8 @@
1
+ from __future__ import annotations
2
+
3
+ from smftools.optional_imports import require
4
+
5
+
1
6
  def load_hmm(model_path, device="cpu"):
2
7
  """
3
8
  Reads in a pretrained HMM.
@@ -5,7 +10,7 @@ def load_hmm(model_path, device="cpu"):
5
10
  Parameters:
6
11
  model_path (str): Path to a pretrained HMM
7
12
  """
8
- import torch
13
+ torch = require("torch", extra="torch", purpose="HMM read/write")
9
14
 
10
15
  # Load model using PyTorch
11
16
  hmm = torch.load(model_path)
@@ -20,6 +25,6 @@ def save_hmm(model, model_path):
20
25
  model: HMM model instance.
21
26
  model_path: Output path for the model.
22
27
  """
23
- import torch
28
+ torch = require("torch", extra="torch", purpose="HMM read/write")
24
29
 
25
30
  torch.save(model, model_path)
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from smftools.logging_utils import get_logger
2
4
 
3
5
  logger = get_logger(__name__)
@@ -1,41 +1,66 @@
1
- from .bam_functions import (
2
- align_and_sort_BAM,
3
- bam_qc,
4
- concatenate_fastqs_to_bam,
5
- count_aligned_reads,
6
- demux_and_index_BAM,
7
- extract_base_identities,
8
- extract_read_features_from_bam,
9
- extract_readnames_from_bam,
10
- separate_bam_by_bc,
11
- split_and_index_BAM,
12
- )
13
- from .basecalling import canoncall, modcall
14
- from .bed_functions import (
15
- _bed_to_bigwig,
16
- _plot_bed_histograms,
17
- aligned_BAM_to_bed,
18
- extract_read_lengths_from_bed,
19
- )
20
- from .converted_BAM_to_adata import converted_BAM_to_adata
21
- from .fasta_functions import (
22
- find_conversion_sites,
23
- generate_converted_FASTA,
24
- get_chromosome_lengths,
25
- get_native_references,
26
- index_fasta,
27
- subsample_fasta_from_bed,
28
- )
29
- from .h5ad_functions import add_demux_type_annotation, add_read_length_and_mapping_qc
30
- from .modkit_extract_to_adata import modkit_extract_to_adata
31
- from .modkit_functions import extract_mods, make_modbed, modQC
32
- from .ohe import ohe_batching, ohe_layers_decode, one_hot_decode, one_hot_encode
33
- from .pod5_functions import basecall_pod5s, fast5_to_pod5, subsample_pod5
34
- from .run_multiqc import run_multiqc
1
+ from __future__ import annotations
2
+
3
+ from importlib import import_module
4
+
5
# Public name -> fully qualified module that defines it. The modules are only
# imported when the name is first requested, which keeps
# ``import smftools.informatics`` from importing its submodules up front.
#
# NOTE(review): ``converted_BAM_to_adata``, ``modkit_extract_to_adata`` and
# ``run_multiqc`` share their name with the submodule that defines them. Python
# binds an imported submodule as an attribute of its parent package, so if one
# of those submodules is imported directly before the lazy name is first
# resolved, ``smftools.informatics.<name>`` refers to the module and
# ``__getattr__`` is never consulted for it -- confirm no caller relies on that
# import order.
_LAZY_ATTRS = {
    "_bed_to_bigwig": "smftools.informatics.bed_functions",
    "_plot_bed_histograms": "smftools.informatics.bed_functions",
    "add_demux_type_annotation": "smftools.informatics.h5ad_functions",
    "add_read_tag_annotations": "smftools.informatics.h5ad_functions",
    "add_read_length_and_mapping_qc": "smftools.informatics.h5ad_functions",
    "align_and_sort_BAM": "smftools.informatics.bam_functions",
    "bam_qc": "smftools.informatics.bam_functions",
    "basecall_pod5s": "smftools.informatics.pod5_functions",
    "canoncall": "smftools.informatics.basecalling",
    "concatenate_fastqs_to_bam": "smftools.informatics.bam_functions",
    "converted_BAM_to_adata": "smftools.informatics.converted_BAM_to_adata",
    "count_aligned_reads": "smftools.informatics.bam_functions",
    "demux_and_index_BAM": "smftools.informatics.bam_functions",
    "extract_base_identities": "smftools.informatics.bam_functions",
    "extract_mods": "smftools.informatics.modkit_functions",
    "extract_read_features_from_bam": "smftools.informatics.bam_functions",
    "extract_read_tags_from_bam": "smftools.informatics.bam_functions",
    "extract_read_lengths_from_bed": "smftools.informatics.bed_functions",
    "extract_readnames_from_bam": "smftools.informatics.bam_functions",
    "fast5_to_pod5": "smftools.informatics.pod5_functions",
    "find_conversion_sites": "smftools.informatics.fasta_functions",
    "generate_converted_FASTA": "smftools.informatics.fasta_functions",
    "get_chromosome_lengths": "smftools.informatics.fasta_functions",
    "get_native_references": "smftools.informatics.fasta_functions",
    "index_fasta": "smftools.informatics.fasta_functions",
    "make_modbed": "smftools.informatics.modkit_functions",
    "modQC": "smftools.informatics.modkit_functions",
    "modcall": "smftools.informatics.basecalling",
    "modkit_extract_to_adata": "smftools.informatics.modkit_extract_to_adata",
    "decode_int_sequence": "smftools.informatics.sequence_encoding",
    "encode_sequence_to_int": "smftools.informatics.sequence_encoding",
    "ohe_batching": "smftools.informatics.ohe",
    "ohe_layers_decode": "smftools.informatics.ohe",
    "one_hot_decode": "smftools.informatics.ohe",
    "one_hot_encode": "smftools.informatics.ohe",
    "run_multiqc": "smftools.informatics.run_multiqc",
    "separate_bam_by_bc": "smftools.informatics.bam_functions",
    "split_and_index_BAM": "smftools.informatics.bam_functions",
    "subsample_fasta_from_bed": "smftools.informatics.fasta_functions",
    "subsample_pod5": "smftools.informatics.pod5_functions",
    "aligned_BAM_to_bed": "smftools.informatics.bed_functions",
}


def __getattr__(name: str):
    """Resolve a lazily exported attribute on first access (PEP 562).

    Args:
        name: Attribute requested on this module that was not found through
            normal lookup.

    Returns:
        The object called ``name`` from the module listed in ``_LAZY_ATTRS``.

    Raises:
        AttributeError: If ``name`` is not one of the lazy exports.
    """
    try:
        module_path = _LAZY_ATTRS[name]
    except KeyError:
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'") from None

    attr = getattr(import_module(module_path), name)
    # Cache in the module namespace so later lookups bypass __getattr__.
    globals()[name] = attr
    return attr


def __dir__():
    """List module attributes, including lazy exports not yet resolved."""
    return sorted({*globals(), *_LAZY_ATTRS})
57
+
35
58
 
36
59
  __all__ = [
37
60
  "basecall_pod5s",
38
61
  "converted_BAM_to_adata",
62
+ "decode_int_sequence",
63
+ "encode_sequence_to_int",
39
64
  "subsample_fasta_from_bed",
40
65
  "subsample_pod5",
41
66
  "fast5_to_pod5",