smftools 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. smftools/_version.py +1 -1
  2. smftools/cli/chimeric_adata.py +1563 -0
  3. smftools/cli/helpers.py +49 -7
  4. smftools/cli/hmm_adata.py +250 -32
  5. smftools/cli/latent_adata.py +773 -0
  6. smftools/cli/load_adata.py +78 -74
  7. smftools/cli/preprocess_adata.py +122 -58
  8. smftools/cli/recipes.py +26 -0
  9. smftools/cli/spatial_adata.py +74 -112
  10. smftools/cli/variant_adata.py +423 -0
  11. smftools/cli_entry.py +52 -4
  12. smftools/config/conversion.yaml +1 -1
  13. smftools/config/deaminase.yaml +3 -0
  14. smftools/config/default.yaml +85 -12
  15. smftools/config/experiment_config.py +146 -1
  16. smftools/constants.py +69 -0
  17. smftools/hmm/HMM.py +88 -0
  18. smftools/hmm/call_hmm_peaks.py +1 -1
  19. smftools/informatics/__init__.py +6 -0
  20. smftools/informatics/bam_functions.py +358 -8
  21. smftools/informatics/binarize_converted_base_identities.py +2 -89
  22. smftools/informatics/converted_BAM_to_adata.py +636 -175
  23. smftools/informatics/h5ad_functions.py +198 -2
  24. smftools/informatics/modkit_extract_to_adata.py +1007 -425
  25. smftools/informatics/sequence_encoding.py +72 -0
  26. smftools/logging_utils.py +21 -2
  27. smftools/metadata.py +1 -1
  28. smftools/plotting/__init__.py +26 -3
  29. smftools/plotting/autocorrelation_plotting.py +22 -4
  30. smftools/plotting/chimeric_plotting.py +1893 -0
  31. smftools/plotting/classifiers.py +28 -14
  32. smftools/plotting/general_plotting.py +62 -1583
  33. smftools/plotting/hmm_plotting.py +1670 -8
  34. smftools/plotting/latent_plotting.py +804 -0
  35. smftools/plotting/plotting_utils.py +243 -0
  36. smftools/plotting/position_stats.py +16 -8
  37. smftools/plotting/preprocess_plotting.py +281 -0
  38. smftools/plotting/qc_plotting.py +8 -3
  39. smftools/plotting/spatial_plotting.py +1134 -0
  40. smftools/plotting/variant_plotting.py +1231 -0
  41. smftools/preprocessing/__init__.py +4 -0
  42. smftools/preprocessing/append_base_context.py +18 -18
  43. smftools/preprocessing/append_mismatch_frequency_sites.py +187 -0
  44. smftools/preprocessing/append_sequence_mismatch_annotations.py +171 -0
  45. smftools/preprocessing/append_variant_call_layer.py +480 -0
  46. smftools/preprocessing/calculate_consensus.py +1 -1
  47. smftools/preprocessing/calculate_read_modification_stats.py +6 -1
  48. smftools/preprocessing/flag_duplicate_reads.py +4 -4
  49. smftools/preprocessing/invert_adata.py +1 -0
  50. smftools/readwrite.py +159 -99
  51. smftools/schema/anndata_schema_v1.yaml +15 -1
  52. smftools/tools/__init__.py +10 -0
  53. smftools/tools/calculate_knn.py +121 -0
  54. smftools/tools/calculate_leiden.py +57 -0
  55. smftools/tools/calculate_nmf.py +130 -0
  56. smftools/tools/calculate_pca.py +180 -0
  57. smftools/tools/calculate_umap.py +79 -80
  58. smftools/tools/position_stats.py +4 -4
  59. smftools/tools/rolling_nn_distance.py +872 -0
  60. smftools/tools/sequence_alignment.py +140 -0
  61. smftools/tools/tensor_factorization.py +217 -0
  62. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/METADATA +9 -5
  63. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/RECORD +66 -45
  64. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/WHEEL +0 -0
  65. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/entry_points.txt +0 -0
  66. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/licenses/LICENSE +0 -0
smftools/cli/helpers.py CHANGED
@@ -5,6 +5,17 @@ from pathlib import Path
 
 import anndata as ad
 
+from smftools.constants import (
+    CHIMERIC_DIR,
+    H5_DIR,
+    HMM_DIR,
+    LATENT_DIR,
+    LOAD_DIR,
+    PREPROCESS_DIR,
+    SPATIAL_DIR,
+    VARIANT_DIR,
+)
+
 from ..metadata import write_runtime_schema_yaml
 from ..readwrite import safe_write_h5ad
 
@@ -16,28 +27,40 @@ class AdataPaths:
     pp_dedup: Path
     spatial: Path
     hmm: Path
+    latent: Path
+    variant: Path
+    chimeric: Path
 
 
 def get_adata_paths(cfg) -> AdataPaths:
     """
     Central helper: given cfg, compute all standard AnnData paths.
     """
-    h5_dir = Path(cfg.output_directory) / "h5ads"
-
-    raw = h5_dir / f"{cfg.experiment_name}.h5ad.gz"
+    output_directory = Path(cfg.output_directory)
 
-    pp = h5_dir / f"{cfg.experiment_name}_preprocessed.h5ad.gz"
+    # Raw and Preprocessed adata file pathes will have set names.
+    raw = output_directory / LOAD_DIR / H5_DIR / f"{cfg.experiment_name}.h5ad.gz"
+    pp = output_directory / PREPROCESS_DIR / H5_DIR / f"{cfg.experiment_name}_preprocessed.h5ad.gz"
 
     if cfg.smf_modality == "direct":
         # direct SMF: duplicate-removed path is just preprocessed path
         pp_dedup = pp
     else:
-        pp_dedup = h5_dir / f"{cfg.experiment_name}_preprocessed_duplicates_removed.h5ad.gz"
+        pp_dedup = (
+            output_directory
+            / PREPROCESS_DIR
+            / H5_DIR
+            / f"{cfg.experiment_name}_preprocessed_duplicates_removed.h5ad.gz"
+        )
 
     pp_dedup_base = pp_dedup.name.removesuffix(".h5ad.gz")
 
-    spatial = h5_dir / f"{pp_dedup_base}_spatial.h5ad.gz"
-    hmm = h5_dir / f"{pp_dedup_base}_spatial_hmm.h5ad.gz"
+    # All of the following just append a new suffix to the preprocessesed_deduplicated base name
+    spatial = output_directory / SPATIAL_DIR / H5_DIR / f"{pp_dedup_base}_spatial.h5ad.gz"
+    hmm = output_directory / HMM_DIR / H5_DIR / f"{pp_dedup_base}_hmm.h5ad.gz"
+    latent = output_directory / LATENT_DIR / H5_DIR / f"{pp_dedup_base}_latent.h5ad.gz"
+    variant = output_directory / VARIANT_DIR / H5_DIR / f"{pp_dedup_base}_variant.h5ad.gz"
+    chimeric = output_directory / CHIMERIC_DIR / H5_DIR / f"{pp_dedup_base}_chimeric.h5ad.gz"
 
     return AdataPaths(
         raw=raw,
@@ -45,7 +68,26 @@ def get_adata_paths(cfg) -> AdataPaths:
         pp_dedup=pp_dedup,
         spatial=spatial,
         hmm=hmm,
+        latent=latent,
+        variant=variant,
+        chimeric=chimeric,
+    )
+
+
+def load_experiment_config(config_path: str):
+    """Load ExperimentConfig without invoking any pipeline stages."""
+    from datetime import datetime
+    from importlib import resources
+
+    from ..config import ExperimentConfig, LoadExperimentConfig
+
+    date_str = datetime.today().strftime("%y%m%d")
+    loader = LoadExperimentConfig(config_path)
+    defaults_dir = resources.files("smftools").joinpath("config")
+    cfg, _ = ExperimentConfig.from_var_dict(
+        loader.var_dict, date_str=date_str, defaults_dir=defaults_dir
     )
+    return cfg
 
 
 def write_gz_h5ad(adata: ad.AnnData, path: Path) -> Path:
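With these changes, each pipeline stage writes its AnnData under its own stage directory (load, preprocess, spatial, hmm, latent, variant, chimeric) instead of a single shared h5ads folder, and the HMM output name drops the _spatial infix (…_hmm.h5ad.gz rather than …_spatial_hmm.h5ad.gz). A minimal sketch of the resulting layout, using placeholder strings for the directory constants (the real values live in the new smftools/constants.py, which this diff does not show):

# Sketch only: the *_DIR / H5_DIR strings below are assumed placeholders,
# not the actual values defined in smftools.constants.
from pathlib import Path

LOAD_DIR, PREPROCESS_DIR, SPATIAL_DIR, HMM_DIR, H5_DIR = (
    "load", "preprocess", "spatial", "hmm", "h5ads"
)

output_directory = Path("/runs/demo")
experiment_name = "exp01"

raw = output_directory / LOAD_DIR / H5_DIR / f"{experiment_name}.h5ad.gz"
pp = output_directory / PREPROCESS_DIR / H5_DIR / f"{experiment_name}_preprocessed.h5ad.gz"

# "direct" modality case: the deduplicated path is simply the preprocessed path
pp_dedup_base = pp.name.removesuffix(".h5ad.gz")
spatial = output_directory / SPATIAL_DIR / H5_DIR / f"{pp_dedup_base}_spatial.h5ad.gz"
hmm = output_directory / HMM_DIR / H5_DIR / f"{pp_dedup_base}_hmm.h5ad.gz"

print(raw)  # /runs/demo/load/h5ads/exp01.h5ad.gz
print(hmm)  # /runs/demo/hmm/h5ads/exp01_preprocessed_hmm.h5ad.gz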
smftools/cli/hmm_adata.py CHANGED
@@ -1,13 +1,15 @@
 from __future__ import annotations
 
 import copy
+import logging
 from dataclasses import dataclass
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
 
-from smftools.logging_utils import get_logger
+from smftools.constants import HMM_DIR, LOGGING_DIR
+from smftools.logging_utils import get_logger, setup_logging
 from smftools.optional_imports import require
 
 # FIX: import _to_dense_np to avoid NameError
@@ -16,15 +18,135 @@ from ..hmm.HMM import _safe_int_coords, _to_dense_np, create_hmm, normalize_hmm_
 logger = get_logger(__name__)
 
 if TYPE_CHECKING:
-    import torch as torch_types
+    import torch
 
 torch = require("torch", extra="torch", purpose="HMM CLI")
+mpl = require("matplotlib", extra="plotting", purpose="HMM plotting")
+mpl_colors = require("matplotlib.colors", extra="plotting", purpose="HMM plotting")
 
 # =============================================================================
 # Helpers: extracting training arrays
 # =============================================================================
 
 
+def _strip_hmm_layer_prefix(layer: str) -> str:
+    """Strip methbase prefixes and length suffixes from an HMM layer name.
+
+    Args:
+        layer: Full layer name (e.g., "GpC_small_accessible_patch_lengths").
+
+    Returns:
+        The base layer name without methbase prefixes or length suffixes.
+    """
+    base = layer
+    for prefix in ("Combined_", "GpC_", "CpG_", "C_", "A_"):
+        if base.startswith(prefix):
+            base = base[len(prefix) :]
+            break
+    if base.endswith("_lengths"):
+        base = base[: -len("_lengths")]
+    if base.endswith("_merged"):
+        base = base[: -len("_merged")]
+    return base
+
+
+def _resolve_feature_colormap(layer: str, cfg, default_cmap: str) -> Any:
+    """Resolve a colormap for a given HMM layer.
+
+    Args:
+        layer: Full layer name.
+        cfg: Experiment config.
+        default_cmap: Fallback colormap name.
+
+    Returns:
+        A matplotlib colormap or colormap name.
+    """
+    feature_maps = getattr(cfg, "hmm_feature_colormaps", {}) or {}
+    if not isinstance(feature_maps, dict):
+        feature_maps = {}
+
+    base = _strip_hmm_layer_prefix(layer)
+    value = feature_maps.get(layer, feature_maps.get(base))
+    if value is None:
+        return default_cmap
+
+    if isinstance(value, (list, tuple)):
+        return mpl_colors.ListedColormap(list(value))
+
+    if isinstance(value, str):
+        try:
+            mpl.colormaps.get_cmap(value)
+            return value
+        except Exception:
+            return mpl_colors.LinearSegmentedColormap.from_list(
+                f"hmm_{base}_cmap", ["#ffffff", value]
+            )
+
+    return default_cmap
+
+
+def _resolve_feature_color(layer: str, cfg, fallback_cmap: str, idx: int, total: int) -> Any:
+    """Resolve a line color for a given HMM layer."""
+    feature_maps = getattr(cfg, "hmm_feature_colormaps", {}) or {}
+    if not isinstance(feature_maps, dict):
+        feature_maps = {}
+
+    base = _strip_hmm_layer_prefix(layer)
+    value = feature_maps.get(layer, feature_maps.get(base))
+    if isinstance(value, str):
+        try:
+            mpl.colormaps.get_cmap(value)
+        except Exception:
+            return value
+        return mpl.colormaps.get_cmap(value)(0.75)
+    if isinstance(value, (list, tuple)) and value:
+        return value[-1]
+
+    cmap_obj = mpl.colormaps.get_cmap(fallback_cmap)
+    if total <= 1:
+        return cmap_obj(0.5)
+    return cmap_obj(idx / (total - 1))
+
+
+def _resolve_length_feature_ranges(
+    layer: str, cfg, default_cmap: str
+) -> List[Tuple[int, int, Any]]:
+    """Resolve length-based feature ranges to colors for size contour overlays."""
+    base = _strip_hmm_layer_prefix(layer)
+    feature_sets = getattr(cfg, "hmm_feature_sets", {}) or {}
+    if not isinstance(feature_sets, dict):
+        return []
+
+    feature_key = None
+    if "accessible" in base:
+        feature_key = "accessible"
+    elif "footprint" in base:
+        feature_key = "footprint"
+
+    if feature_key is None:
+        return []
+
+    features = feature_sets.get(feature_key, {}).get("features", {})
+    if not isinstance(features, dict):
+        return []
+
+    ranges: List[Tuple[int, int, Any]] = []
+    for feature_name, bounds in features.items():
+        if not isinstance(bounds, (list, tuple)) or len(bounds) != 2:
+            continue
+        min_len, max_len = bounds
+        if max_len is None or (isinstance(max_len, (float, int)) and np.isinf(max_len)):
+            max_len = int(1e9)
+        try:
+            min_len_int = int(min_len)
+            max_len_int = int(max_len)
+        except (TypeError, ValueError):
+            continue
+        color = _resolve_feature_color(feature_name, cfg, default_cmap, 0, 1)
+        ranges.append((min_len_int, max_len_int, color))
+    return ranges
+
+
 def _get_training_matrix(
     subset, cols_mask: np.ndarray, smf_modality: Optional[str], cfg
 ) -> Tuple[np.ndarray, Optional[str]]:
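The new colormap helpers let a single cfg.hmm_feature_colormaps entry cover all methbase variants of a layer: _strip_hmm_layer_prefix reduces, for example, "GpC_accessible_patch_lengths" and "Combined_accessible_patch_lengths" to the shared base "accessible_patch", and the looked-up value may be a registered colormap name, a bare color, or a list of colors. A rough sketch of how the three value shapes are interpreted, using matplotlib directly (the config keys and colors below are hypothetical):

# Hypothetical hmm_feature_colormaps entries; mirrors the lookup logic shown above.
import matplotlib as mpl
import matplotlib.colors as mpl_colors

hmm_feature_colormaps = {
    "accessible_patch": "Blues",                     # valid colormap name -> used as-is
    "footprint": "#d95f02",                          # bare color -> white-to-color gradient
    "cpg_patch": ["#ffffff", "#bdbdbd", "#252525"],  # list -> discrete ListedColormap
}

def resolve(base, value):
    if isinstance(value, (list, tuple)):
        return mpl_colors.ListedColormap(list(value))
    try:
        mpl.colormaps.get_cmap(value)  # raises if not a registered colormap name
        return value
    except Exception:
        return mpl_colors.LinearSegmentedColormap.from_list(
            f"hmm_{base}_cmap", ["#ffffff", value]
        )

for base, value in hmm_feature_colormaps.items():
    print(base, type(resolve(base, value)).__name__)
# accessible_patch str
# footprint LinearSegmentedColormap
# cpg_patch ListedColormap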
@@ -445,34 +567,37 @@ def hmm_adata(config_path: str):
     - Call hmm_adata_core(cfg, adata, paths)
     """
     from ..readwrite import safe_read_h5ad
-    from .helpers import get_adata_paths
-    from .load_adata import load_adata
-    from .preprocess_adata import preprocess_adata
-    from .spatial_adata import spatial_adata
+    from .helpers import get_adata_paths, load_experiment_config
 
     # 1) load cfg / stage paths
-    _, _, cfg = load_adata(config_path)
-    paths = get_adata_paths(cfg)
+    cfg = load_experiment_config(config_path)
 
-    # 2) make sure upstream stages are run (they have their own skipping logic)
-    preprocess_adata(config_path)
-    spatial_ad, spatial_path = spatial_adata(config_path)
+    paths = get_adata_paths(cfg)
 
-    # 3) choose starting AnnData
+    # 2) choose starting AnnData
     # Prefer:
     #   - existing HMM h5ad if not forcing redo
     #   - in-memory spatial_ad from wrapper call
     #   - saved spatial / pp_dedup / pp / raw on disk
     if paths.hmm.exists() and not (cfg.force_redo_hmm_fit or cfg.force_redo_hmm_apply):
-        adata, _ = safe_read_h5ad(paths.hmm)
-        return adata, paths.hmm
+        logger.debug(f"Skipping hmm. HMM AnnData found: {paths.hmm}")
+        return None
 
-    if spatial_ad is not None:
-        adata = spatial_ad
-        source_path = spatial_path
+    if paths.hmm.exists():
+        adata, _ = safe_read_h5ad(paths.hmm)
+        source_path = paths.hmm
+    elif paths.latent.exists():
+        adata, _ = safe_read_h5ad(paths.latent)
+        source_path = paths.latent
     elif paths.spatial.exists():
         adata, _ = safe_read_h5ad(paths.spatial)
         source_path = paths.spatial
+    elif paths.chimeric.exists():
+        adata, _ = safe_read_h5ad(paths.chimeric)
+        source_path = paths.chimeric
+    elif paths.variant.exists():
+        adata, _ = safe_read_h5ad(paths.variant)
+        source_path = paths.variant
     elif paths.pp_dedup.exists():
         adata, _ = safe_read_h5ad(paths.pp_dedup)
         source_path = paths.pp_dedup
@@ -516,11 +641,14 @@ def hmm_adata_core(
     Does NOT decide which h5ad to start from – that is the wrapper's job.
     """
 
+    from datetime import datetime
+
     import numpy as np
 
     from ..hmm import call_hmm_peaks
     from ..metadata import record_smftools_metadata
     from ..plotting import (
+        combined_hmm_length_clustermap,
         combined_hmm_raw_clustermap,
         plot_hmm_layers_rolling_by_sample_ref,
         plot_hmm_size_contours,
@@ -528,18 +656,33 @@
     from ..readwrite import make_dirs
     from .helpers import write_gz_h5ad
 
+    date_str = datetime.today().strftime("%y%m%d")
+    now = datetime.now()
+    time_str = now.strftime("%H%M%S")
+
+    log_level = getattr(logging, cfg.log_level.upper(), logging.INFO)
+
     smf_modality = cfg.smf_modality
     deaminase = smf_modality == "deaminase"
 
     output_directory = Path(cfg.output_directory)
-    make_dirs([output_directory])
+    hmm_directory = output_directory / HMM_DIR
+    logging_directory = hmm_directory / LOGGING_DIR
+
+    make_dirs([output_directory, hmm_directory])
+
+    if cfg.emit_log_file:
+        log_file = logging_directory / f"{date_str}_{time_str}_log.log"
+        make_dirs([logging_directory])
+    else:
+        log_file = None
 
-    pp_dir = output_directory / "preprocessed" / "deduplicated"
+    setup_logging(level=log_level, log_file=log_file, reconfigure=log_file is not None)
 
     # ---------------------------- HMM annotate stage ----------------------------
     if not (cfg.bypass_hmm_fit and cfg.bypass_hmm_apply):
-        hmm_models_dir = pp_dir / "10_hmm_models"
-        make_dirs([pp_dir, hmm_models_dir])
+        hmm_models_dir = hmm_directory / "10_hmm_models"
+        make_dirs([hmm_directory, hmm_models_dir])
 
         # Standard bookkeeping
         uns_key = "hmm_appended_layers"
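hmm_adata_core now configures stage-local logging: when cfg.emit_log_file is set, a timestamped log file is created under the HMM stage's logging directory before fitting or plotting runs. A small sketch of the file naming (the HMM_DIR and LOGGING_DIR strings below are placeholders for the values in smftools.constants):

# Sketch only: HMM_DIR / LOGGING_DIR values are assumed placeholders.
from datetime import datetime
from pathlib import Path

HMM_DIR, LOGGING_DIR = "hmm", "logging"

output_directory = Path("/runs/demo")
now = datetime.now()
date_str = now.strftime("%y%m%d")   # e.g. "250115"
time_str = now.strftime("%H%M%S")   # e.g. "093042"

log_file = output_directory / HMM_DIR / LOGGING_DIR / f"{date_str}_{time_str}_log.log"
print(log_file)  # /runs/demo/hmm/logging/250115_093042_log.log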
@@ -743,6 +886,8 @@ def hmm_adata_core(
             uns_key=uns_key,
             uns_flag="hmm_annotated_combined",
             force_redo=force_apply,
+            mask_to_read_span=True,
+            mask_use_original_var_names=True,
         )
 
         for core_layer, dist in (
@@ -855,11 +1000,11 @@ def hmm_adata_core(
     logger.info(f"HMM appended layers: {hmm_layers}")
 
     # ---------------------------- HMM peak calling stage ----------------------------
-    hmm_dir = pp_dir / "11_hmm_peak_calling"
+    hmm_dir = hmm_directory / "11_hmm_peak_calling"
     if hmm_dir.is_dir():
         pass
     else:
-        make_dirs([pp_dir, hmm_dir])
+        make_dirs([hmm_directory, hmm_dir])
 
     call_hmm_peaks(
         adata,
@@ -888,8 +1033,8 @@
 
     ############################################### HMM based feature plotting ###############################################
 
-    hmm_dir = pp_dir / "12_hmm_clustermaps"
-    make_dirs([pp_dir, hmm_dir])
+    hmm_dir = hmm_directory / "12_hmm_clustermaps"
+    make_dirs([hmm_directory, hmm_dir])
 
     layers: list[str] = []
 
@@ -914,6 +1059,7 @@
             pass
         else:
             make_dirs([hmm_cluster_save_dir])
+        hmm_cmap = _resolve_feature_colormap(layer, cfg, cfg.clustermap_cmap_hmm)
 
         combined_hmm_raw_clustermap(
             adata,
@@ -924,7 +1070,7 @@
             layer_cpg=cfg.layer_for_clustermap_plotting,
             layer_c=cfg.layer_for_clustermap_plotting,
             layer_a=cfg.layer_for_clustermap_plotting,
-            cmap_hmm=cfg.clustermap_cmap_hmm,
+            cmap_hmm=hmm_cmap,
             cmap_gpc=cfg.clustermap_cmap_gpc,
             cmap_cpg=cfg.clustermap_cmap_cpg,
             cmap_c=cfg.clustermap_cmap_c,
@@ -935,7 +1081,7 @@
                 0
             ],
             min_position_valid_fraction=1 - cfg.position_max_nan_threshold,
-            demux_types=("double", "already"),
+            demux_types=cfg.clustermap_demux_types_to_plot,
            save_path=hmm_cluster_save_dir,
            normalize_hmm=False,
            sort_by=cfg.hmm_clustermap_sortby,  # options: 'gpc', 'cpg', 'gpc_cpg', 'none', or 'obs:<column>'
@@ -943,14 +1089,78 @@
             deaminase=deaminase,
             min_signal=0,
             index_col_suffix=cfg.reindexed_var_suffix,
+            overlay_variant_calls=getattr(cfg, "overlay_variant_calls", False),
+            variant_overlay_seq1_color=getattr(cfg, "variant_overlay_seq1_color", "white"),
+            variant_overlay_seq2_color=getattr(cfg, "variant_overlay_seq2_color", "black"),
+            variant_overlay_marker_size=getattr(cfg, "variant_overlay_marker_size", 4.0),
+        )
+
+    hmm_length_dir = hmm_directory / "12b_hmm_length_clustermaps"
+    make_dirs([hmm_directory, hmm_length_dir])
+
+    length_layers: list[str] = []
+    length_layer_roots = list(
+        getattr(cfg, "hmm_clustermap_length_layers", cfg.hmm_clustermap_feature_layers)
+    )
+
+    for base in cfg.hmm_methbases:
+        length_layers.extend([f"{base}_{layer}_lengths" for layer in length_layer_roots])
+
+    if getattr(cfg, "hmm_run_multichannel", True) and len(cfg.hmm_methbases) >= 2:
+        length_layers.extend([f"Combined_{layer}_lengths" for layer in length_layer_roots])
+
+    if cfg.cpg:
+        length_layers.extend(["CpG_cpg_patch_lengths"])
+
+    for layer in length_layers:
+        hmm_cluster_save_dir = hmm_length_dir / layer
+        if hmm_cluster_save_dir.is_dir():
+            pass
+        else:
+            make_dirs([hmm_cluster_save_dir])
+        length_cmap = _resolve_feature_colormap(layer, cfg, "Greens")
+        length_feature_ranges = _resolve_length_feature_ranges(layer, cfg, "Greens")
+
+        combined_hmm_length_clustermap(
+            adata,
+            sample_col=cfg.sample_name_col_for_plotting,
+            reference_col=cfg.reference_column,
+            length_layer=layer,
+            layer_gpc=cfg.layer_for_clustermap_plotting,
+            layer_cpg=cfg.layer_for_clustermap_plotting,
+            layer_c=cfg.layer_for_clustermap_plotting,
+            layer_a=cfg.layer_for_clustermap_plotting,
+            cmap_lengths=length_cmap,
+            cmap_gpc=cfg.clustermap_cmap_gpc,
+            cmap_cpg=cfg.clustermap_cmap_cpg,
+            cmap_c=cfg.clustermap_cmap_c,
+            cmap_a=cfg.clustermap_cmap_a,
+            min_quality=cfg.read_quality_filter_thresholds[0],
+            min_length=cfg.read_len_filter_thresholds[0],
+            min_mapped_length_to_reference_length_ratio=cfg.read_len_to_ref_ratio_filter_thresholds[
+                0
+            ],
+            min_position_valid_fraction=1 - cfg.position_max_nan_threshold,
+            demux_types=cfg.clustermap_demux_types_to_plot,
+            save_path=hmm_cluster_save_dir,
+            sort_by=cfg.hmm_clustermap_sortby,
+            bins=None,
+            deaminase=deaminase,
+            min_signal=0,
+            index_col_suffix=cfg.reindexed_var_suffix,
+            length_feature_ranges=length_feature_ranges,
+            overlay_variant_calls=getattr(cfg, "overlay_variant_calls", False),
+            variant_overlay_seq1_color=getattr(cfg, "variant_overlay_seq1_color", "white"),
+            variant_overlay_seq2_color=getattr(cfg, "variant_overlay_seq2_color", "black"),
+            variant_overlay_marker_size=getattr(cfg, "variant_overlay_marker_size", 4.0),
         )
 
-    hmm_dir = pp_dir / "13_hmm_bulk_traces"
+    hmm_dir = hmm_directory / "13_hmm_bulk_traces"
 
     if hmm_dir.is_dir():
         logger.debug(f"{hmm_dir} already exists.")
     else:
-        make_dirs([pp_dir, hmm_dir])
+        make_dirs([hmm_directory, hmm_dir])
         from ..plotting import plot_hmm_layers_rolling_by_sample_ref
 
         bulk_hmm_layers = [
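The new 12b_hmm_length_clustermaps stage builds one clustermap per length layer, with layer names assembled from the configured methbases and feature roots, plus Combined_ variants when multichannel HMMs are enabled. A sketch of that name expansion with hypothetical cfg values:

# Hypothetical cfg values; shows how the per-layer output folders are named.
hmm_methbases = ["GpC", "CpG"]
length_layer_roots = ["accessible_patch", "footprint"]
cpg = True
hmm_run_multichannel = True

length_layers = []
for base in hmm_methbases:
    length_layers.extend([f"{base}_{root}_lengths" for root in length_layer_roots])
if hmm_run_multichannel and len(hmm_methbases) >= 2:
    length_layers.extend([f"Combined_{root}_lengths" for root in length_layer_roots])
if cpg:
    length_layers.append("CpG_cpg_patch_lengths")

print(length_layers)
# ['GpC_accessible_patch_lengths', 'GpC_footprint_lengths',
#  'CpG_accessible_patch_lengths', 'CpG_footprint_lengths',
#  'Combined_accessible_patch_lengths', 'Combined_footprint_lengths',
#  'CpG_cpg_patch_lengths']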
@@ -958,6 +1168,10 @@
             for layer in hmm_layers
             if not any(s in layer for s in ("_lengths", "_states", "_posterior"))
         ]
+        layer_colors = {
+            layer: _resolve_feature_color(layer, cfg, "tab20", idx, len(bulk_hmm_layers))
+            for idx, layer in enumerate(bulk_hmm_layers)
+        }
         saved = plot_hmm_layers_rolling_by_sample_ref(
             adata,
             layers=bulk_hmm_layers,
@@ -969,14 +1183,15 @@
             output_dir=hmm_dir,
             save=True,
             show_raw=False,
+            layer_colors=layer_colors,
         )
 
-    hmm_dir = pp_dir / "14_hmm_fragment_distributions"
+    hmm_dir = hmm_directory / "14_hmm_fragment_distributions"
 
     if hmm_dir.is_dir():
         logger.debug(f"{hmm_dir} already exists.")
     else:
-        make_dirs([pp_dir, hmm_dir])
+        make_dirs([hmm_directory, hmm_dir])
         from ..plotting import plot_hmm_size_contours
 
         if smf_modality == "deaminase":
@@ -1001,6 +1216,8 @@
         for layer, max in fragments:
             save_path = hmm_dir / layer
             make_dirs([save_path])
+            layer_cmap = _resolve_feature_colormap(layer, cfg, "Greens")
+            feature_ranges = _resolve_length_feature_ranges(layer, cfg, "Greens")
 
             figs = plot_hmm_size_contours(
                 adata,
@@ -1016,8 +1233,9 @@
                 dpi=200,
                 smoothing_sigma=(10, 10),
                 normalize_after_smoothing=True,
-                cmap="Greens",
+                cmap=layer_cmap,
                 log_scale_z=True,
+                feature_ranges=tuple(feature_ranges),
             )
     ########################################################################################################################
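The fragment-distribution contours now take a feature_ranges overlay derived from cfg.hmm_feature_sets: each named length feature becomes a (min_len, max_len, color) tuple, with None or infinite upper bounds capped at 1e9. A sketch of that mapping with hypothetical feature names and bounds (color resolution simplified to a fixed color):

# Hypothetical hmm_feature_sets entry; mirrors the _resolve_length_feature_ranges logic.
import numpy as np

hmm_feature_sets = {
    "accessible": {
        "features": {
            "small_accessible_patch": [1, 80],
            "nucleosome_free_region": [80, None],  # open-ended upper bound
        }
    }
}

ranges = []
for name, bounds in hmm_feature_sets["accessible"]["features"].items():
    min_len, max_len = bounds
    if max_len is None or (isinstance(max_len, (float, int)) and np.isinf(max_len)):
        max_len = int(1e9)
    ranges.append((int(min_len), int(max_len), "green"))

print(tuple(ranges))  # ((1, 80, 'green'), (80, 1000000000, 'green'))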