smftools 0.2.1__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (114) hide show
  1. smftools/__init__.py +2 -6
  2. smftools/_version.py +1 -1
  3. smftools/cli/__init__.py +0 -0
  4. smftools/cli/archived/cli_flows.py +94 -0
  5. smftools/cli/helpers.py +48 -0
  6. smftools/cli/hmm_adata.py +361 -0
  7. smftools/cli/load_adata.py +637 -0
  8. smftools/cli/preprocess_adata.py +455 -0
  9. smftools/cli/spatial_adata.py +697 -0
  10. smftools/cli_entry.py +434 -0
  11. smftools/config/conversion.yaml +18 -6
  12. smftools/config/deaminase.yaml +18 -11
  13. smftools/config/default.yaml +151 -36
  14. smftools/config/direct.yaml +28 -1
  15. smftools/config/discover_input_files.py +115 -0
  16. smftools/config/experiment_config.py +225 -27
  17. smftools/hmm/HMM.py +12 -1
  18. smftools/hmm/__init__.py +0 -6
  19. smftools/hmm/archived/call_hmm_peaks.py +106 -0
  20. smftools/hmm/call_hmm_peaks.py +318 -90
  21. smftools/informatics/__init__.py +13 -7
  22. smftools/informatics/archived/fast5_to_pod5.py +43 -0
  23. smftools/informatics/archived/helpers/archived/__init__.py +71 -0
  24. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +126 -0
  25. smftools/informatics/{helpers → archived/helpers/archived}/aligned_BAM_to_bed.py +6 -4
  26. smftools/informatics/archived/helpers/archived/bam_qc.py +213 -0
  27. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +90 -0
  28. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +259 -0
  29. smftools/informatics/{helpers → archived/helpers/archived}/count_aligned_reads.py +2 -2
  30. smftools/informatics/{helpers → archived/helpers/archived}/demux_and_index_BAM.py +8 -10
  31. smftools/informatics/{helpers → archived/helpers/archived}/extract_base_identities.py +1 -1
  32. smftools/informatics/{helpers → archived/helpers/archived}/extract_mods.py +15 -13
  33. smftools/informatics/{helpers → archived/helpers/archived}/generate_converted_FASTA.py +2 -0
  34. smftools/informatics/{helpers → archived/helpers/archived}/get_chromosome_lengths.py +9 -8
  35. smftools/informatics/archived/helpers/archived/index_fasta.py +24 -0
  36. smftools/informatics/{helpers → archived/helpers/archived}/make_modbed.py +1 -2
  37. smftools/informatics/{helpers → archived/helpers/archived}/modQC.py +2 -2
  38. smftools/informatics/{helpers → archived/helpers/archived}/plot_bed_histograms.py +0 -19
  39. smftools/informatics/{helpers → archived/helpers/archived}/separate_bam_by_bc.py +6 -5
  40. smftools/informatics/{helpers → archived/helpers/archived}/split_and_index_BAM.py +7 -7
  41. smftools/informatics/archived/subsample_fasta_from_bed.py +49 -0
  42. smftools/informatics/bam_functions.py +811 -0
  43. smftools/informatics/basecalling.py +67 -0
  44. smftools/informatics/bed_functions.py +366 -0
  45. smftools/informatics/{helpers/converted_BAM_to_adata_II.py → converted_BAM_to_adata.py} +42 -30
  46. smftools/informatics/fasta_functions.py +255 -0
  47. smftools/informatics/h5ad_functions.py +197 -0
  48. smftools/informatics/{helpers/modkit_extract_to_adata.py → modkit_extract_to_adata.py} +142 -59
  49. smftools/informatics/modkit_functions.py +129 -0
  50. smftools/informatics/ohe.py +160 -0
  51. smftools/informatics/pod5_functions.py +224 -0
  52. smftools/informatics/{helpers/run_multiqc.py → run_multiqc.py} +5 -2
  53. smftools/plotting/autocorrelation_plotting.py +1 -3
  54. smftools/plotting/general_plotting.py +1084 -363
  55. smftools/plotting/position_stats.py +3 -3
  56. smftools/preprocessing/__init__.py +4 -4
  57. smftools/preprocessing/append_base_context.py +35 -26
  58. smftools/preprocessing/append_binary_layer_by_base_context.py +6 -6
  59. smftools/preprocessing/binarize.py +17 -0
  60. smftools/preprocessing/binarize_on_Youden.py +11 -9
  61. smftools/preprocessing/calculate_complexity_II.py +1 -1
  62. smftools/preprocessing/calculate_coverage.py +16 -13
  63. smftools/preprocessing/calculate_position_Youden.py +42 -26
  64. smftools/preprocessing/calculate_read_modification_stats.py +2 -2
  65. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +1 -1
  66. smftools/preprocessing/filter_reads_on_modification_thresholds.py +20 -20
  67. smftools/preprocessing/flag_duplicate_reads.py +2 -2
  68. smftools/preprocessing/invert_adata.py +1 -1
  69. smftools/preprocessing/load_sample_sheet.py +1 -1
  70. smftools/preprocessing/reindex_references_adata.py +37 -0
  71. smftools/readwrite.py +360 -140
  72. {smftools-0.2.1.dist-info → smftools-0.2.4.dist-info}/METADATA +26 -19
  73. smftools-0.2.4.dist-info/RECORD +176 -0
  74. smftools-0.2.4.dist-info/entry_points.txt +2 -0
  75. smftools/cli.py +0 -184
  76. smftools/informatics/fast5_to_pod5.py +0 -24
  77. smftools/informatics/helpers/__init__.py +0 -73
  78. smftools/informatics/helpers/align_and_sort_BAM.py +0 -86
  79. smftools/informatics/helpers/bam_qc.py +0 -66
  80. smftools/informatics/helpers/bed_to_bigwig.py +0 -39
  81. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +0 -378
  82. smftools/informatics/helpers/discover_input_files.py +0 -100
  83. smftools/informatics/helpers/index_fasta.py +0 -12
  84. smftools/informatics/helpers/make_dirs.py +0 -21
  85. smftools/informatics/readwrite.py +0 -106
  86. smftools/informatics/subsample_fasta_from_bed.py +0 -47
  87. smftools/load_adata.py +0 -1346
  88. smftools-0.2.1.dist-info/RECORD +0 -161
  89. smftools-0.2.1.dist-info/entry_points.txt +0 -2
  90. /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
  91. /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
  92. /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
  93. /smftools/informatics/{basecall_pod5s.py → archived/basecall_pod5s.py} +0 -0
  94. /smftools/informatics/{helpers → archived/helpers/archived}/canoncall.py +0 -0
  95. /smftools/informatics/{helpers → archived/helpers/archived}/converted_BAM_to_adata.py +0 -0
  96. /smftools/informatics/{helpers → archived/helpers/archived}/extract_read_features_from_bam.py +0 -0
  97. /smftools/informatics/{helpers → archived/helpers/archived}/extract_read_lengths_from_bed.py +0 -0
  98. /smftools/informatics/{helpers → archived/helpers/archived}/extract_readnames_from_BAM.py +0 -0
  99. /smftools/informatics/{helpers → archived/helpers/archived}/find_conversion_sites.py +0 -0
  100. /smftools/informatics/{helpers → archived/helpers/archived}/get_native_references.py +0 -0
  101. /smftools/informatics/{helpers → archived/helpers}/archived/informatics.py +0 -0
  102. /smftools/informatics/{helpers → archived/helpers}/archived/load_adata.py +0 -0
  103. /smftools/informatics/{helpers → archived/helpers/archived}/modcall.py +0 -0
  104. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_batching.py +0 -0
  105. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_layers_decode.py +0 -0
  106. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_decode.py +0 -0
  107. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_encode.py +0 -0
  108. /smftools/informatics/{subsample_pod5.py → archived/subsample_pod5.py} +0 -0
  109. /smftools/informatics/{helpers/binarize_converted_base_identities.py → binarize_converted_base_identities.py} +0 -0
  110. /smftools/informatics/{helpers/complement_base_list.py → complement_base_list.py} +0 -0
  111. /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archives/add_read_length_and_mapping_qc.py} +0 -0
  112. /smftools/preprocessing/{calculate_complexity.py → archives/calculate_complexity.py} +0 -0
  113. {smftools-0.2.1.dist-info → smftools-0.2.4.dist-info}/WHEEL +0 -0
  114. {smftools-0.2.1.dist-info → smftools-0.2.4.dist-info}/licenses/LICENSE +0 -0
@@ -6,6 +6,7 @@ import warnings
6
6
  from dataclasses import dataclass, field, asdict
7
7
  from pathlib import Path
8
8
  from typing import Any, Dict, List, Optional, Tuple, Union, IO, Sequence
9
+ from .discover_input_files import discover_input_files
9
10
 
10
11
  # Optional dependency for YAML handling
11
12
  try:
@@ -213,7 +214,7 @@ def resolve_aligner_args(
213
214
  return list(default_by_aligner.get(key_align, []))
214
215
 
215
216
 
216
- # HMM default params and hepler functions
217
+ # HMM default params and helper functions
217
218
  def normalize_hmm_feature_sets(raw: Any) -> Dict[str, dict]:
218
219
  """
219
220
  Normalize user-provided `hmm_feature_sets` into canonical structure:
@@ -274,6 +275,58 @@ def normalize_hmm_feature_sets(raw: Any) -> Dict[str, dict]:
274
275
  canonical[grp] = {"features": feats, "state": state}
275
276
  return canonical
276
277
 
278
+ def normalize_peak_feature_configs(raw: Any) -> Dict[str, dict]:
279
+ """
280
+ Normalize user-provided `hmm_peak_feature_configs` into:
281
+ {
282
+ layer_name: {
283
+ "min_distance": int,
284
+ "peak_width": int,
285
+ "peak_prominence": float,
286
+ "peak_threshold": float,
287
+ "rolling_window": int,
288
+ },
289
+ ...
290
+ }
291
+
292
+ Accepts dict, JSON/string, None. Returns {} for empty input.
293
+ """
294
+ if raw is None:
295
+ return {}
296
+
297
+ parsed = raw
298
+ if isinstance(raw, str):
299
+ parsed = _try_json_or_literal(raw)
300
+ if not isinstance(parsed, dict):
301
+ return {}
302
+
303
+ defaults = {
304
+ "min_distance": 200,
305
+ "peak_width": 200,
306
+ "peak_prominence": 0.2,
307
+ "peak_threshold": 0.8,
308
+ "rolling_window": 1,
309
+ }
310
+
311
+ out: Dict[str, dict] = {}
312
+ for layer, conf in parsed.items():
313
+ if conf is None:
314
+ conf = {}
315
+ if not isinstance(conf, dict):
316
+ # allow shorthand like 300 -> interpreted as peak_width
317
+ conf = {"peak_width": conf}
318
+
319
+ full = defaults.copy()
320
+ full.update(conf)
321
+ out[str(layer)] = {
322
+ "min_distance": int(full["min_distance"]),
323
+ "peak_width": int(full["peak_width"]),
324
+ "peak_prominence": float(full["peak_prominence"]),
325
+ "peak_threshold": float(full["peak_threshold"]),
326
+ "rolling_window": int(full["rolling_window"]),
327
+ }
328
+ return out
329
+
277
330
 
278
331
  # -------------------------
279
332
  # LoadExperimentConfig
@@ -593,7 +646,10 @@ class ExperimentConfig:
593
646
  fasta: Optional[str] = None
594
647
  bam_suffix: str = ".bam"
595
648
  recursive_input_search: bool = True
649
+ input_type: Optional[str] = None
650
+ input_files: Optional[List[Path]] = None
596
651
  split_dir: str = "demultiplexed_BAMs"
652
+ split_path: Optional[str] = None
597
653
  strands: List[str] = field(default_factory=lambda: ["bottom", "top"])
598
654
  conversions: List[str] = field(default_factory=lambda: ["unconverted"])
599
655
  fasta_regions_of_interest: Optional[str] = None
@@ -601,11 +657,16 @@ class ExperimentConfig:
601
657
  sample_sheet_mapping_column: Optional[str] = 'Barcode'
602
658
  experiment_name: Optional[str] = None
603
659
  input_already_demuxed: bool = False
660
+ summary_file: Optional[Path] = None
604
661
 
605
662
  # FASTQ input specific
606
663
  fastq_barcode_map: Optional[Dict[str, str]] = None
607
664
  fastq_auto_pairing: bool = True
608
665
 
666
+ # Remove intermediate file options
667
+ delete_intermediate_bams: bool = False
668
+ delete_intermediate_tsvs: bool = True
669
+
609
670
  # Conversion/Deamination file handling
610
671
  delete_intermediate_hdfs: bool = True
611
672
 
@@ -638,13 +699,16 @@ class ExperimentConfig:
638
699
  m5C_threshold: float = 0.7
639
700
  hm5C_threshold: float = 0.7
640
701
  thresholds: List[float] = field(default_factory=list)
641
- mod_list: List[str] = field(default_factory=lambda: ["5mC_5hmC", "6mA"])
702
+ mod_list: List[str] = field(default_factory=lambda: ["5mC_5hmC", "6mA"]) # Dorado modified basecalling codes
703
+ mod_map: Dict[str, str] = field(default_factory=lambda: {'6mA': '6mA', '5mC_5hmC': '5mC'}) # Map from dorado modified basecalling codes to codes used in modkit_extract_to_adata function
642
704
 
643
705
  # Alignment params
644
706
  mapping_threshold: float = 0.01 # Min threshold for fraction of reads in a sample mapping to a reference in order to include the reference in the anndata
645
- aligner: str = "minimap2"
707
+ align_from_bam: bool = False # Whether minimap2 should align from a bam file as input. If False, aligns from FASTQ
708
+ aligner: str = "dorado"
646
709
  aligner_args: Optional[List[str]] = None
647
710
  make_bigwigs: bool = False
711
+ make_beds: bool = False
648
712
 
649
713
  # Anndata structure
650
714
  reference_column: Optional[str] = 'Reference_strand'
@@ -656,23 +720,40 @@ class ExperimentConfig:
656
720
 
657
721
  # Preprocessing - Read length and quality filter params
658
722
  read_coord_filter: Optional[Sequence[float]] = field(default_factory=lambda: [None, None])
659
- read_len_filter_thresholds: Optional[Sequence[float]] = field(default_factory=lambda: [200, None])
660
- read_len_to_ref_ratio_filter_thresholds: Optional[Sequence[float]] = field(default_factory=lambda: [0.4, 1.1])
661
- read_quality_filter_thresholds: Optional[Sequence[float]] = field(default_factory=lambda: [20, None])
723
+ read_len_filter_thresholds: Optional[Sequence[float]] = field(default_factory=lambda: [100, None])
724
+ read_len_to_ref_ratio_filter_thresholds: Optional[Sequence[float]] = field(default_factory=lambda: [0.4, 1.5])
725
+ read_quality_filter_thresholds: Optional[Sequence[float]] = field(default_factory=lambda: [15, None])
662
726
  read_mapping_quality_filter_thresholds: Optional[Sequence[float]] = field(default_factory=lambda: [None, None])
663
727
 
728
+ # Preprocessing - Optional reindexing params
729
+ reindexing_offsets: Dict[str, int] = field(default_factory=dict)
730
+ reindexed_var_suffix: Optional[str] = "reindexed"
731
+
732
+ # Preprocessing - Direct mod detection binarization params
733
+ fit_position_methylation_thresholds: Optional[bool] = False # Whether to use Youden J-stat to determine position by positions thresholds for modification binarization.
734
+ binarize_on_fixed_methlyation_threshold: Optional[float] = 0.7 # The threshold used to binarize the anndata using a fixed value if fitting parameter above is False.
735
+ positive_control_sample_methylation_fitting: Optional[str] = None # A positive control Sample_name to use for fully modified template data
736
+ negative_control_sample_methylation_fitting: Optional[str] = None # A negative control Sample_name to use for fully unmodified template data
737
+ infer_on_percentile_sample_methylation_fitting: Optional[int] = 10 # If a positive/negative control are not provided and fitting the data is requested, use the indicated percentile windows from the top and bottom of the dataset.
738
+ inference_variable_sample_methylation_fitting: Optional[str] = "Raw_modification_signal" # The obs column value used for the percentile metric above.
739
+ fit_j_threshold: Optional[float] = 0.5 # The J-statistic threhold to use for determining which positions pass qc for mod detection thresholding
740
+ output_binary_layer_name: Optional[str] = "binarized_methylation"
741
+
664
742
  # Preprocessing - Read modification filter params
665
743
  read_mod_filtering_gpc_thresholds: List[float] = field(default_factory=lambda: [0.025, 0.975])
666
744
  read_mod_filtering_cpg_thresholds: List[float] = field(default_factory=lambda: [0.00, 1])
667
- read_mod_filtering_any_c_thresholds: List[float] = field(default_factory=lambda: [0.025, 0.975])
745
+ read_mod_filtering_c_thresholds: List[float] = field(default_factory=lambda: [0.025, 0.975])
668
746
  read_mod_filtering_a_thresholds: List[float] = field(default_factory=lambda: [0.025, 0.975])
669
747
  read_mod_filtering_use_other_c_as_background: bool = True
670
748
  min_valid_fraction_positions_in_read_vs_ref: float = 0.2
671
749
 
750
+ # Preprocessing - plotting params
751
+ obs_to_plot_pp_qc: List[str] = field(default_factory=lambda: ['read_length', 'mapped_length','read_quality', 'mapping_quality','mapped_length_to_reference_length_ratio', 'mapped_length_to_read_length_ratio', 'Raw_modification_signal'])
752
+
672
753
  # Preprocessing - Duplicate detection params
673
754
  duplicate_detection_site_types: List[str] = field(default_factory=lambda: ['GpC', 'CpG', 'ambiguous_GpC_CpG'])
674
755
  duplicate_detection_distance_threshold: float = 0.07
675
- hamming_vs_metric_keys: List[str] = field(default_factory=lambda: ['Fraction_any_C_site_modified'])
756
+ hamming_vs_metric_keys: List[str] = field(default_factory=lambda: ['Fraction_C_site_modified'])
676
757
  duplicate_detection_keep_best_metric: str ='read_quality'
677
758
  duplicate_detection_window_size_for_hamming_neighbors: int = 50
678
759
  duplicate_detection_min_overlapping_positions: int = 20
@@ -680,22 +761,28 @@ class ExperimentConfig:
680
761
  duplicate_detection_hierarchical_linkage: str = "average"
681
762
  duplicate_detection_do_pca: bool = False
682
763
 
683
- # Preprocessing - Complexity analysis params
764
+ # Preprocessing - Position QC
765
+ position_max_nan_threshold: float = 0.1
684
766
 
685
- # Basic Analysis - Clustermap params
767
+ # Spatial Analysis - Clustermap params
686
768
  layer_for_clustermap_plotting: Optional[str] = 'nan0_0minus1'
769
+ clustermap_cmap_c: Optional[str] = 'coolwarm'
770
+ clustermap_cmap_gpc: Optional[str] = 'coolwarm'
771
+ clustermap_cmap_cpg: Optional[str] = 'coolwarm'
772
+ clustermap_cmap_a: Optional[str] = 'coolwarm'
773
+ spatial_clustermap_sortby: Optional[str] = 'gpc'
687
774
 
688
- # Basic Analysis - UMAP/Leiden params
775
+ # Spatial Analysis - UMAP/Leiden params
689
776
  layer_for_umap_plotting: Optional[str] = 'nan_half'
690
777
  umap_layers_to_plot: List[str] = field(default_factory=lambda: ["mapped_length", "Raw_modification_signal"])
691
778
 
692
- # Basic Analysis - Spatial Autocorrelation params
779
+ # Spatial Analysis - Spatial Autocorrelation params
693
780
  rows_per_qc_autocorr_grid: int = 12
694
781
  autocorr_rolling_window_size: int = 25
695
782
  autocorr_max_lag: int = 800
696
- autocorr_site_types: List[str] = field(default_factory=lambda: ['GpC', 'CpG', 'any_C'])
783
+ autocorr_site_types: List[str] = field(default_factory=lambda: ['GpC', 'CpG', 'C'])
697
784
 
698
- # Basic Analysis - Correlation Matrix params
785
+ # Spatial Analysis - Correlation Matrix params
699
786
  correlation_matrix_types: List[str] = field(default_factory=lambda: ["pearson", "binary_covariance"])
700
787
  correlation_matrix_cmaps: List[str] = field(default_factory=lambda: ["seismic", "viridis"])
701
788
  correlation_matrix_site_types: List[str] = field(default_factory=lambda: ["GpC_site"])
@@ -717,6 +804,13 @@ class ExperimentConfig:
717
804
  cpg: Optional[bool] = False
718
805
  hmm_feature_sets: Dict[str, Any] = field(default_factory=dict)
719
806
  hmm_merge_layer_features: Optional[List[Tuple]] = field(default_factory=lambda: [(None, 80)])
807
+ clustermap_cmap_hmm: Optional[str] = 'coolwarm'
808
+ hmm_clustermap_feature_layers: List[str] = field(default_factory=lambda: ["all_accessible_features"])
809
+ hmm_clustermap_sortby: Optional[str] = 'hmm'
810
+ hmm_peak_feature_configs: Dict[str, Any] = field(default_factory=dict)
811
+
812
+ # Pipeline control flow - load adata
813
+ force_redo_load_adata: bool = False
720
814
 
721
815
  # Pipeline control flow - preprocessing and QC
722
816
  force_redo_preprocessing: bool = False
@@ -739,8 +833,8 @@ class ExperimentConfig:
739
833
  bypass_complexity_analysis: bool = False
740
834
  force_redo_complexity_analysis: bool = False
741
835
 
742
- # Pipeline control flow - Basic Analyses
743
- force_redo_basic_analyses: bool = False
836
+ # Pipeline control flow - Spatial Analyses
837
+ force_redo_spatial_analyses: bool = False
744
838
  bypass_basic_clustermaps: bool = False
745
839
  force_redo_basic_clustermaps: bool = False
746
840
  bypass_basic_umap: bool = False
@@ -860,6 +954,70 @@ class ExperimentConfig:
860
954
  if merged.get("experiment_name") is None and date_str:
861
955
  merged["experiment_name"] = f"{date_str}_SMF_experiment"
862
956
 
957
+ # Input file types and path handling
958
+ input_data_path = Path(merged['input_data_path'])
959
+
960
+ # Detect the input filetype
961
+ if input_data_path.is_file():
962
+ suffix = input_data_path.suffix.lower()
963
+ suffixes = [s.lower() for s in input_data_path.suffixes] # handles multi-part extensions
964
+
965
+ # recognize multi-suffix cases like .fastq.gz or .fq.gz
966
+ if any(s in ['.pod5', '.p5'] for s in suffixes):
967
+ input_type = "pod5"
968
+ input_files = [Path(input_data_path)]
969
+ elif any(s in ['.fast5', '.f5'] for s in suffixes):
970
+ input_type = "fast5"
971
+ input_files = [Path(input_data_path)]
972
+ elif any(s in ['.fastq', '.fq'] for s in suffixes):
973
+ input_type = "fastq"
974
+ input_files = [Path(input_data_path)]
975
+ elif any(s in ['.bam'] for s in suffixes):
976
+ input_type = "bam"
977
+ input_files = [Path(input_data_path)]
978
+ elif any(s in ['.h5ad', ".h5"] for s in suffixes):
979
+ input_type = "h5ad"
980
+ input_files = [Path(input_data_path)]
981
+ else:
982
+ print("Error detecting input file type")
983
+
984
+ elif input_data_path.is_dir():
985
+ found = discover_input_files(input_data_path, bam_suffix=merged["bam_suffix"], recursive=merged["recursive_input_search"])
986
+
987
+ if found["input_is_pod5"]:
988
+ input_type = "pod5"
989
+ input_files = found["pod5_paths"]
990
+ elif found["input_is_fast5"]:
991
+ input_type = "fast5"
992
+ input_files = found["fast5_paths"]
993
+ elif found["input_is_fastq"]:
994
+ input_type = "fastq"
995
+ input_files = found["fastq_paths"]
996
+ elif found["input_is_bam"]:
997
+ input_type = "bam"
998
+ input_files = found["bam_paths"]
999
+ elif found["input_is_h5ad"]:
1000
+ input_type = "h5ad"
1001
+ input_files = found["h5ad_paths"]
1002
+
1003
+ print(
1004
+ f"Found {found['all_files_searched']} files; "
1005
+ f"fastq={len(found['fastq_paths'])}, "
1006
+ f"bam={len(found['bam_paths'])}, "
1007
+ f"pod5={len(found['pod5_paths'])}, "
1008
+ f"fast5={len(found['fast5_paths'])}, "
1009
+ f"h5ad={len(found['h5ad_paths'])}"
1010
+ )
1011
+
1012
+ # summary file output path
1013
+ output_dir = Path(merged['output_directory'])
1014
+ summary_file_basename = merged["experiment_name"] + '_output_summary.csv'
1015
+ summary_file = output_dir / summary_file_basename
1016
+
1017
+ # Demultiplexing output path
1018
+ split_dir = merged.get("split_dir", "demultiplexed_BAMs")
1019
+ split_path = output_dir / split_dir
1020
+
863
1021
  # final normalization
864
1022
  if "strands" in merged:
865
1023
  merged["strands"] = _parse_list(merged["strands"])
@@ -900,6 +1058,9 @@ class ExperimentConfig:
900
1058
  if "mod_list" in merged:
901
1059
  merged["mod_list"] = _parse_list(merged.get("mod_list"))
902
1060
 
1061
+ # Preprocessing args
1062
+ obs_to_plot_pp_qc = _parse_list(merged.get("obs_to_plot_pp_qc", None))
1063
+
903
1064
  # HMM feature set handling
904
1065
  if "hmm_feature_sets" in merged:
905
1066
  merged["hmm_feature_sets"] = normalize_hmm_feature_sets(merged["hmm_feature_sets"])
@@ -935,14 +1096,23 @@ class ExperimentConfig:
935
1096
  hmm_methbases = ['C']
936
1097
  hmm_methbases = list(hmm_methbases)
937
1098
  hmm_merge_layer_features = _parse_list(merged.get("hmm_merge_layer_features", None))
1099
+ hmm_clustermap_feature_layers = _parse_list(merged.get("hmm_clustermap_feature_layers", "all_accessible_features"))
938
1100
 
1101
+ # HMM peak feature configs (for call_hmm_peaks)
1102
+ merged["hmm_peak_feature_configs"] = normalize_peak_feature_configs(
1103
+ merged.get("hmm_peak_feature_configs", {})
1104
+ )
1105
+ hmm_peak_feature_configs = merged.get("hmm_peak_feature_configs", {})
939
1106
 
940
1107
  # instantiate dataclass
941
1108
  instance = cls(
942
1109
  smf_modality = merged.get("smf_modality"),
943
- input_data_path = merged.get("input_data_path"),
1110
+ input_data_path = input_data_path,
944
1111
  recursive_input_search = merged.get("recursive_input_search"),
945
- output_directory = merged.get("output_directory"),
1112
+ input_type = input_type,
1113
+ input_files = input_files,
1114
+ output_directory = output_dir,
1115
+ summary_file = summary_file,
946
1116
  fasta = merged.get("fasta"),
947
1117
  sequencer = merged.get("sequencer"),
948
1118
  model_dir = merged.get("model_dir"),
@@ -950,7 +1120,8 @@ class ExperimentConfig:
950
1120
  fastq_barcode_map = merged.get("fastq_barcode_map"),
951
1121
  fastq_auto_pairing = merged.get("fastq_auto_pairing"),
952
1122
  bam_suffix = merged.get("bam_suffix", ".bam"),
953
- split_dir = merged.get("split_dir", "demultiplexed_BAMs"),
1123
+ split_dir = split_dir,
1124
+ split_path = split_path,
954
1125
  strands = merged.get("strands", ["bottom","top"]),
955
1126
  conversions = merged.get("conversions", ["unconverted"]),
956
1127
  fasta_regions_of_interest = merged.get("fasta_regions_of_interest"),
@@ -963,14 +1134,18 @@ class ExperimentConfig:
963
1134
  threads = merged.get("threads"),
964
1135
  sample_sheet_path = merged.get("sample_sheet_path"),
965
1136
  sample_sheet_mapping_column = merged.get("sample_sheet_mapping_column"),
1137
+ delete_intermediate_bams = merged.get("delete_intermediate_bams", False),
1138
+ delete_intermediate_tsvs = merged.get("delete_intermediate_tsvs", True),
1139
+ align_from_bam = merged.get("align_from_bam", False),
966
1140
  aligner = merged.get("aligner", "minimap2"),
967
1141
  aligner_args = merged.get("aligner_args", None),
968
1142
  device = merged.get("device", "auto"),
969
1143
  make_bigwigs = merged.get("make_bigwigs", False),
1144
+ make_beds = merged.get("make_beds", False),
970
1145
  delete_intermediate_hdfs = merged.get("delete_intermediate_hdfs", True),
971
1146
  mod_target_bases = merged.get("mod_target_bases", ["GpC","CpG"]),
972
1147
  enzyme_target_bases = merged.get("enzyme_target_bases", ["GpC"]),
973
- conversion_types = merged.get("conversion_types", ["5mC"]),
1148
+ conversion_types = merged.get("conversions", ["unconverted"]) + merged.get("conversion_types", ["5mC"]),
974
1149
  filter_threshold = merged.get("filter_threshold", 0.8),
975
1150
  m6A_threshold = merged.get("m6A_threshold", 0.7),
976
1151
  m5C_threshold = merged.get("m5C_threshold", 0.7),
@@ -983,14 +1158,30 @@ class ExperimentConfig:
983
1158
  reference_column = merged.get("reference_column", 'Reference_strand'),
984
1159
  sample_column = merged.get("sample_column", 'Barcode'),
985
1160
  sample_name_col_for_plotting = merged.get("sample_name_col_for_plotting", 'Barcode'),
1161
+ obs_to_plot_pp_qc = obs_to_plot_pp_qc,
1162
+ fit_position_methylation_thresholds = merged.get("fit_position_methylation_thresholds", False),
1163
+ binarize_on_fixed_methlyation_threshold = merged.get("binarize_on_fixed_methlyation_threshold", 0.7),
1164
+ positive_control_sample_methylation_fitting = merged.get("positive_control_sample_methylation_fitting", None),
1165
+ negative_control_sample_methylation_fitting = merged.get("negative_control_sample_methylation_fitting", None),
1166
+ infer_on_percentile_sample_methylation_fitting = merged.get("infer_on_percentile_sample_methylation_fitting", 10),
1167
+ inference_variable_sample_methylation_fitting = merged.get("inference_variable_sample_methylation_fitting", "Raw_modification_signal"),
1168
+ fit_j_threshold = merged.get("fit_j_threshold", 0.5),
1169
+ output_binary_layer_name = merged.get("output_binary_layer_name", "binarized_methylation"),
1170
+ reindexing_offsets = merged.get("reindexing_offsets", {None: None}),
1171
+ reindexed_var_suffix = merged.get("reindexed_var_suffix", "reindexed"),
986
1172
  layer_for_clustermap_plotting = merged.get("layer_for_clustermap_plotting", 'nan0_0minus1'),
1173
+ clustermap_cmap_c = merged.get("clustermap_cmap_c", 'coolwarm'),
1174
+ clustermap_cmap_gpc = merged.get("clustermap_cmap_gpc", 'coolwarm'),
1175
+ clustermap_cmap_cpg = merged.get("clustermap_cmap_cpg", 'coolwarm'),
1176
+ clustermap_cmap_a = merged.get("clustermap_cmap_a", 'coolwarm'),
1177
+ spatial_clustermap_sortby = merged.get("spatial_clustermap_sortby", 'gpc'),
987
1178
  layer_for_umap_plotting = merged.get("layer_for_umap_plotting", 'nan_half'),
988
1179
  umap_layers_to_plot = merged.get("umap_layers_to_plot",["mapped_length", 'Raw_modification_signal']),
989
1180
  rows_per_qc_histogram_grid = merged.get("rows_per_qc_histogram_grid", 12),
990
1181
  rows_per_qc_autocorr_grid = merged.get("rows_per_qc_autocorr_grid", 12),
991
1182
  autocorr_rolling_window_size = merged.get("autocorr_rolling_window_size", 25),
992
1183
  autocorr_max_lag = merged.get("autocorr_max_lag", 800),
993
- autocorr_site_types = merged.get("autocorr_site_types", ['GpC', 'CpG', 'any_C']),
1184
+ autocorr_site_types = merged.get("autocorr_site_types", ['GpC', 'CpG', 'C']),
994
1185
  hmm_n_states = merged.get("hmm_n_states", 2),
995
1186
  hmm_init_emission_probs = merged.get("hmm_init_emission_probs",[[0.8, 0.2], [0.2, 0.8]]),
996
1187
  hmm_init_transition_probs = merged.get("hmm_init_transition_probs",[[0.9, 0.1], [0.1, 0.9]]),
@@ -1004,17 +1195,21 @@ class ExperimentConfig:
1004
1195
  hmm_methbases = hmm_methbases,
1005
1196
  hmm_device = hmm_device,
1006
1197
  hmm_merge_layer_features = hmm_merge_layer_features,
1198
+ clustermap_cmap_hmm = merged.get("clustermap_cmap_hmm", 'coolwarm'),
1199
+ hmm_clustermap_feature_layers = hmm_clustermap_feature_layers,
1200
+ hmm_clustermap_sortby = merged.get("hmm_clustermap_sortby", 'hmm'),
1201
+ hmm_peak_feature_configs = hmm_peak_feature_configs,
1007
1202
  footprints = merged.get("footprints", None),
1008
1203
  accessible_patches = merged.get("accessible_patches", None),
1009
1204
  cpg = merged.get("cpg", None),
1010
1205
  read_coord_filter = merged.get("read_coord_filter", [None, None]),
1011
- read_len_filter_thresholds = merged.get("read_len_filter_thresholds", [200, None]),
1012
- read_len_to_ref_ratio_filter_thresholds = merged.get("read_len_to_ref_ratio_filter_thresholds", [0.4, 1.1]),
1013
- read_quality_filter_thresholds = merged.get("read_quality_filter_thresholds", [20, None]),
1206
+ read_len_filter_thresholds = merged.get("read_len_filter_thresholds", [100, None]),
1207
+ read_len_to_ref_ratio_filter_thresholds = merged.get("read_len_to_ref_ratio_filter_thresholds", [0.3, None]),
1208
+ read_quality_filter_thresholds = merged.get("read_quality_filter_thresholds", [15, None]),
1014
1209
  read_mapping_quality_filter_thresholds = merged.get("read_mapping_quality_filter_thresholds", [None, None]),
1015
1210
  read_mod_filtering_gpc_thresholds = merged.get("read_mod_filtering_gpc_thresholds", [0.025, 0.975]),
1016
1211
  read_mod_filtering_cpg_thresholds = merged.get("read_mod_filtering_cpg_thresholds", [0.0, 1.0]),
1017
- read_mod_filtering_any_c_thresholds = merged.get("read_mod_filtering_any_c_thresholds", [0.025, 0.975]),
1212
+ read_mod_filtering_c_thresholds = merged.get("read_mod_filtering_c_thresholds", [0.025, 0.975]),
1018
1213
  read_mod_filtering_a_thresholds = merged.get("read_mod_filtering_a_thresholds", [0.025, 0.975]),
1019
1214
  read_mod_filtering_use_other_c_as_background = merged.get("read_mod_filtering_use_other_c_as_background", True),
1020
1215
  min_valid_fraction_positions_in_read_vs_ref = merged.get("min_valid_fraction_positions_in_read_vs_ref", 0.2),
@@ -1026,10 +1221,12 @@ class ExperimentConfig:
1026
1221
  duplicate_detection_do_hierarchical = merged.get("duplicate_detection_do_hierarchical", True),
1027
1222
  duplicate_detection_hierarchical_linkage = merged.get("duplicate_detection_hierarchical_linkage", "average"),
1028
1223
  duplicate_detection_do_pca = merged.get("duplicate_detection_do_pca", False),
1224
+ position_max_nan_threshold = merged.get("position_max_nan_threshold", 0.1),
1029
1225
  correlation_matrix_types = merged.get("correlation_matrix_types", ["pearson", "binary_covariance"]),
1030
1226
  correlation_matrix_cmaps = merged.get("correlation_matrix_cmaps", ["seismic", "viridis"]),
1031
1227
  correlation_matrix_site_types = merged.get("correlation_matrix_site_types", ["GpC_site"]),
1032
- hamming_vs_metric_keys = merged.get("hamming_vs_metric_keys", ['Fraction_any_C_site_modified']),
1228
+ hamming_vs_metric_keys = merged.get("hamming_vs_metric_keys", ['Fraction_C_site_modified']),
1229
+ force_redo_load_adata = merged.get("force_redo_load_adata", False),
1033
1230
  force_redo_preprocessing = merged.get("force_redo_preprocessing", False),
1034
1231
  force_reload_sample_sheet = merged.get("force_reload_sample_sheet", True),
1035
1232
  bypass_add_read_length_and_mapping_qc = merged.get("bypass_add_read_length_and_mapping_qc", False),
@@ -1049,7 +1246,7 @@ class ExperimentConfig:
1049
1246
  force_redo_flag_duplicate_reads = merged.get("force_redo_flag_duplicate_reads", False),
1050
1247
  bypass_complexity_analysis = merged.get("bypass_complexity_analysis", False),
1051
1248
  force_redo_complexity_analysis = merged.get("force_redo_complexity_analysis", False),
1052
- force_redo_basic_analyses = merged.get("force_redo_basic_analyses", False),
1249
+ force_redo_spatial_analyses = merged.get("force_redo_spatial_analyses", False),
1053
1250
  bypass_basic_clustermaps = merged.get("bypass_basic_clustermaps", False),
1054
1251
  force_redo_basic_clustermaps = merged.get("force_redo_basic_clustermaps", False),
1055
1252
  bypass_basic_umap = merged.get("bypass_basic_umap", False),
@@ -1101,6 +1298,7 @@ class ExperimentConfig:
1101
1298
  # -------------------------
1102
1299
  # validation & serialization
1103
1300
  # -------------------------
1301
+ @staticmethod
1104
1302
  def _validate_hmm_features_structure(hfs: dict) -> List[str]:
1105
1303
  errs = []
1106
1304
  if not isinstance(hfs, dict):
smftools/hmm/HMM.py CHANGED
@@ -1,5 +1,5 @@
1
1
  import math
2
- from typing import List, Optional, Tuple, Union, Any, Dict
2
+ from typing import List, Optional, Tuple, Union, Any, Dict, Sequence
3
3
  import ast
4
4
  import json
5
5
 
@@ -772,6 +772,8 @@ class HMM(nn.Module):
772
772
  verbose: bool = True,
773
773
  uns_key: str = "hmm_appended_layers",
774
774
  config: Optional[Union[dict, "ExperimentConfig"]] = None, # NEW: config/dict accepted
775
+ uns_flag: str = "hmm_annotated",
776
+ force_redo: bool = False
775
777
  ):
776
778
  """
777
779
  Annotate an AnnData with HMM-derived features (in adata.obs and adata.layers).
@@ -793,6 +795,12 @@ class HMM(nn.Module):
793
795
  import torch as _torch
794
796
  from tqdm import trange, tqdm as _tqdm
795
797
 
798
+ # Only run if not already performed
799
+ already = bool(adata.uns.get(uns_flag, False))
800
+ if (already and not force_redo):
801
+ # QC already performed; nothing to do
802
+ return None if in_place else adata
803
+
796
804
  # small helpers
797
805
  def _try_json_or_literal(s):
798
806
  if s is None:
@@ -1298,6 +1306,9 @@ class HMM(nn.Module):
1298
1306
  new_list = existing + [l for l in appended_layers if l not in existing]
1299
1307
  adata.uns[uns_key] = new_list
1300
1308
 
1309
+ # Mark that the annotation has been completed
1310
+ adata.uns[uns_flag] = True
1311
+
1301
1312
  return None if in_place else adata
1302
1313
 
1303
1314
  def merge_intervals_in_layer(
smftools/hmm/__init__.py CHANGED
@@ -1,20 +1,14 @@
1
- from .apply_hmm_batched import apply_hmm_batched
2
- from .calculate_distances import calculate_distances
3
1
  from .call_hmm_peaks import call_hmm_peaks
4
2
  from .display_hmm import display_hmm
5
3
  from .hmm_readwrite import load_hmm, save_hmm
6
4
  from .nucleosome_hmm_refinement import refine_nucleosome_calls, infer_nucleosomes_in_large_bound
7
- from .train_hmm import train_hmm
8
5
 
9
6
 
10
7
  __all__ = [
11
- "apply_hmm_batched",
12
- "calculate_distances",
13
8
  "call_hmm_peaks",
14
9
  "display_hmm",
15
10
  "load_hmm",
16
11
  "refine_nucleosome_calls",
17
12
  "infer_nucleosomes_in_large_bound",
18
13
  "save_hmm",
19
- "train_hmm"
20
14
  ]
@@ -0,0 +1,106 @@
1
def call_hmm_peaks(
    adata,
    feature_configs,
    obs_column='Reference_strand',
    site_types=['GpC_site', 'CpG_site'],
    save_plot=False,
    output_dir=None,
    date_tag=None,
    inplace=False
):
    """
    Call peaks on the per-position mean of HMM-derived layers and annotate an AnnData.

    For each layer named in ``feature_configs``, peaks are detected on the
    column-wise mean signal with ``scipy.signal.find_peaks``. A fixed-width
    window around each peak center is then used to:

    * store boolean membership masks in ``adata.var`` (one column per peak,
      plus ``is_in_any_<layer>_peak`` and a global ``is_in_any_peak``),
    * add per-read mean/sum/presence summaries of the layer to ``adata.obs``,
    * add per-read sum/mean of raw ``adata.X`` restricted to each requested
      site type (via per-reference masks named ``f"{ref}_{site_type}"`` in
      ``adata.var``) around each peak.

    Parameters
    ----------
    adata : AnnData
        Annotated matrix; ``adata.var_names`` must be castable to integer
        genomic coordinates, and each layer in ``feature_configs`` must exist
        in ``adata.layers``.
    feature_configs : dict
        Maps layer name -> config dict with optional keys ``min_distance``
        (default 200), ``peak_width`` (200), ``peak_prominence`` (0.2) and
        ``peak_threshold`` (0.8).
    obs_column : str
        Obs column holding reference/strand labels; converted to categorical
        if needed. NOTE(review): assumed to exist in ``adata.obs``.
    site_types : sequence of str
        Site-type suffixes (e.g. ``'GpC_site'``) to summarize around peaks.
        Read-only; the mutable default is never mutated.
    save_plot : bool
        If True (and ``output_dir`` is set), save each peak plot to disk
        instead of calling ``plt.show()``.
    output_dir : str or None
        Directory for saved plots.
    date_tag : str or None
        Optional filename prefix for saved plots (defaults to ``'output'``).
    inplace : bool
        If True, modify ``adata`` in place and return None; otherwise work on
        and return a copy.

    Returns
    -------
    AnnData or None
        The annotated copy when ``inplace`` is False, otherwise None.
    """
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    from scipy.signal import find_peaks

    if not inplace:
        adata = adata.copy()

    # Ensure obs_column is categorical so .cat.categories is available below.
    if not isinstance(adata.obs[obs_column].dtype, pd.CategoricalDtype):
        adata.obs[obs_column] = pd.Categorical(adata.obs[obs_column])

    coordinates = adata.var_names.astype(int).values
    peak_columns = []

    # Collect new obs columns and concat once at the end to avoid repeatedly
    # inserting (and fragmenting) adata.obs.
    obs_updates = {}

    for feature_layer, config in feature_configs.items():
        min_distance = config.get('min_distance', 200)
        peak_width = config.get('peak_width', 200)
        peak_prominence = config.get('peak_prominence', 0.2)
        peak_threshold = config.get('peak_threshold', 0.8)

        matrix = adata.layers[feature_layer]
        means = np.mean(matrix, axis=0)
        peak_indices, _ = find_peaks(means, prominence=peak_prominence, distance=min_distance)
        peak_centers = coordinates[peak_indices]
        adata.uns[f'{feature_layer} peak_centers'] = peak_centers.tolist()

        # Plot the mean signal with the called peak windows overlaid.
        plt.figure(figsize=(6, 3))
        plt.plot(coordinates, means)
        plt.title(f"{feature_layer} with peak calls")
        plt.xlabel("Genomic position")
        plt.ylabel("Mean intensity")
        for i, center in enumerate(peak_centers):
            start, end = center - peak_width // 2, center + peak_width // 2
            plt.axvspan(start, end, color='purple', alpha=0.2)
            plt.axvline(center, color='red', linestyle='--')
            # Alternate label anchor side so adjacent peak labels don't overlap.
            aligned = [end if i % 2 else start, 'left' if i % 2 else 'right']
            plt.text(aligned[0], 0, f"Peak {i}\n{center}", color='red', ha=aligned[1])
        if save_plot and output_dir:
            filename = f"{output_dir}/{date_tag or 'output'}_{feature_layer}_peaks.png"
            plt.savefig(filename, bbox_inches='tight')
            # BUG FIX: report the actual saved path (was a hard-coded placeholder).
            print(f"Saved plot to {filename}")
            # Free the figure when not displayed interactively, so batch runs
            # over many layers don't accumulate open figures.
            plt.close()
        else:
            plt.show()

        feature_peak_columns = []
        for center in peak_centers:
            start, end = center - peak_width // 2, center + peak_width // 2
            colname = f'{feature_layer}_peak_{center}'
            peak_columns.append(colname)
            feature_peak_columns.append(colname)

            peak_mask = (coordinates >= start) & (coordinates <= end)
            adata.var[colname] = peak_mask

            # Per-read summaries of the HMM layer within the peak window.
            region = matrix[:, peak_mask]
            obs_updates[f'mean_{feature_layer}_around_{center}'] = np.mean(region, axis=1)
            obs_updates[f'sum_{feature_layer}_around_{center}'] = np.sum(region, axis=1)
            obs_updates[f'{feature_layer}_present_at_{center}'] = np.mean(region, axis=1) > peak_threshold

            # Initialize per-site-type summary columns before filling them
            # reference-by-reference below.
            for site_type in site_types:
                adata.obs[f'{site_type}_sum_around_{center}'] = 0
                adata.obs[f'{site_type}_mean_around_{center}'] = np.nan

            for ref in adata.obs[obs_column].cat.categories:
                ref_idx = adata.obs[obs_column] == ref
                for site_type in site_types:
                    # BUG FIX: mask_key must be built per site_type. It was
                    # previously computed outside this loop, reusing the stale
                    # `site_type` leaked from the initialization loop above, so
                    # every site type looked up the same (last) mask.
                    mask_key = f"{ref}_{site_type}"
                    if mask_key not in adata.var:
                        continue
                    site_mask = adata.var[mask_key].values
                    site_coords = coordinates[site_mask]
                    region_mask = (site_coords >= start) & (site_coords <= end)
                    if not region_mask.any():
                        continue
                    # Restrict the site mask to positions inside the peak window.
                    full_mask = site_mask.copy()
                    full_mask[site_mask] = region_mask
                    site_region = adata[ref_idx, full_mask].X
                    if hasattr(site_region, "A"):
                        # Densify sparse matrices before nan-aware reductions.
                        site_region = site_region.A
                    if site_region.shape[1] > 0:
                        adata.obs.loc[ref_idx, f'{site_type}_sum_around_{center}'] = np.nansum(site_region, axis=1)
                        adata.obs.loc[ref_idx, f'{site_type}_mean_around_{center}'] = np.nanmean(site_region, axis=1)

        adata.var[f'is_in_any_{feature_layer}_peak'] = adata.var[feature_peak_columns].any(axis=1)
        print(f"Annotated {len(peak_centers)} peaks for {feature_layer}")

    adata.var['is_in_any_peak'] = adata.var[peak_columns].any(axis=1)
    adata.obs = pd.concat([adata.obs, pd.DataFrame(obs_updates, index=adata.obs.index)], axis=1)

    return adata if not inplace else None