smftools 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. smftools/__init__.py +6 -8
  2. smftools/_settings.py +4 -6
  3. smftools/_version.py +1 -1
  4. smftools/cli/helpers.py +54 -0
  5. smftools/cli/hmm_adata.py +937 -256
  6. smftools/cli/load_adata.py +448 -268
  7. smftools/cli/preprocess_adata.py +469 -263
  8. smftools/cli/spatial_adata.py +536 -319
  9. smftools/cli_entry.py +97 -182
  10. smftools/config/__init__.py +1 -1
  11. smftools/config/conversion.yaml +17 -6
  12. smftools/config/deaminase.yaml +12 -10
  13. smftools/config/default.yaml +142 -33
  14. smftools/config/direct.yaml +11 -3
  15. smftools/config/discover_input_files.py +19 -5
  16. smftools/config/experiment_config.py +594 -264
  17. smftools/constants.py +37 -0
  18. smftools/datasets/__init__.py +2 -8
  19. smftools/datasets/datasets.py +32 -18
  20. smftools/hmm/HMM.py +2128 -1418
  21. smftools/hmm/__init__.py +2 -9
  22. smftools/hmm/archived/call_hmm_peaks.py +121 -0
  23. smftools/hmm/call_hmm_peaks.py +299 -91
  24. smftools/hmm/display_hmm.py +19 -6
  25. smftools/hmm/hmm_readwrite.py +13 -4
  26. smftools/hmm/nucleosome_hmm_refinement.py +102 -14
  27. smftools/informatics/__init__.py +30 -7
  28. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  30. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  31. smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
  32. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
  33. smftools/informatics/archived/print_bam_query_seq.py +7 -1
  34. smftools/informatics/bam_functions.py +397 -175
  35. smftools/informatics/basecalling.py +51 -9
  36. smftools/informatics/bed_functions.py +90 -57
  37. smftools/informatics/binarize_converted_base_identities.py +18 -7
  38. smftools/informatics/complement_base_list.py +7 -6
  39. smftools/informatics/converted_BAM_to_adata.py +265 -122
  40. smftools/informatics/fasta_functions.py +161 -83
  41. smftools/informatics/h5ad_functions.py +196 -30
  42. smftools/informatics/modkit_extract_to_adata.py +609 -270
  43. smftools/informatics/modkit_functions.py +85 -44
  44. smftools/informatics/ohe.py +44 -21
  45. smftools/informatics/pod5_functions.py +112 -73
  46. smftools/informatics/run_multiqc.py +20 -14
  47. smftools/logging_utils.py +51 -0
  48. smftools/machine_learning/__init__.py +2 -7
  49. smftools/machine_learning/data/anndata_data_module.py +143 -50
  50. smftools/machine_learning/data/preprocessing.py +2 -1
  51. smftools/machine_learning/evaluation/__init__.py +1 -1
  52. smftools/machine_learning/evaluation/eval_utils.py +11 -14
  53. smftools/machine_learning/evaluation/evaluators.py +46 -33
  54. smftools/machine_learning/inference/__init__.py +1 -1
  55. smftools/machine_learning/inference/inference_utils.py +7 -4
  56. smftools/machine_learning/inference/lightning_inference.py +9 -13
  57. smftools/machine_learning/inference/sklearn_inference.py +6 -8
  58. smftools/machine_learning/inference/sliding_window_inference.py +35 -25
  59. smftools/machine_learning/models/__init__.py +10 -5
  60. smftools/machine_learning/models/base.py +28 -42
  61. smftools/machine_learning/models/cnn.py +15 -11
  62. smftools/machine_learning/models/lightning_base.py +71 -40
  63. smftools/machine_learning/models/mlp.py +13 -4
  64. smftools/machine_learning/models/positional.py +3 -2
  65. smftools/machine_learning/models/rnn.py +3 -2
  66. smftools/machine_learning/models/sklearn_models.py +39 -22
  67. smftools/machine_learning/models/transformer.py +68 -53
  68. smftools/machine_learning/models/wrappers.py +2 -1
  69. smftools/machine_learning/training/__init__.py +2 -2
  70. smftools/machine_learning/training/train_lightning_model.py +29 -20
  71. smftools/machine_learning/training/train_sklearn_model.py +9 -15
  72. smftools/machine_learning/utils/__init__.py +1 -1
  73. smftools/machine_learning/utils/device.py +7 -4
  74. smftools/machine_learning/utils/grl.py +3 -1
  75. smftools/metadata.py +443 -0
  76. smftools/plotting/__init__.py +19 -5
  77. smftools/plotting/autocorrelation_plotting.py +145 -44
  78. smftools/plotting/classifiers.py +162 -72
  79. smftools/plotting/general_plotting.py +422 -197
  80. smftools/plotting/hmm_plotting.py +42 -13
  81. smftools/plotting/position_stats.py +147 -87
  82. smftools/plotting/qc_plotting.py +20 -12
  83. smftools/preprocessing/__init__.py +10 -12
  84. smftools/preprocessing/append_base_context.py +115 -80
  85. smftools/preprocessing/append_binary_layer_by_base_context.py +77 -39
  86. smftools/preprocessing/{calculate_complexity.py → archived/calculate_complexity.py} +3 -1
  87. smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
  88. smftools/preprocessing/binarize.py +21 -4
  89. smftools/preprocessing/binarize_on_Youden.py +129 -31
  90. smftools/preprocessing/binary_layers_to_ohe.py +17 -11
  91. smftools/preprocessing/calculate_complexity_II.py +86 -59
  92. smftools/preprocessing/calculate_consensus.py +28 -19
  93. smftools/preprocessing/calculate_coverage.py +50 -25
  94. smftools/preprocessing/calculate_pairwise_differences.py +2 -1
  95. smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
  96. smftools/preprocessing/calculate_position_Youden.py +118 -54
  97. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  98. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  99. smftools/preprocessing/clean_NaN.py +38 -28
  100. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  101. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +71 -38
  102. smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
  103. smftools/preprocessing/flag_duplicate_reads.py +689 -272
  104. smftools/preprocessing/invert_adata.py +26 -11
  105. smftools/preprocessing/load_sample_sheet.py +40 -22
  106. smftools/preprocessing/make_dirs.py +8 -3
  107. smftools/preprocessing/min_non_diagonal.py +2 -1
  108. smftools/preprocessing/recipes.py +56 -23
  109. smftools/preprocessing/reindex_references_adata.py +103 -0
  110. smftools/preprocessing/subsample_adata.py +33 -16
  111. smftools/readwrite.py +331 -82
  112. smftools/schema/__init__.py +11 -0
  113. smftools/schema/anndata_schema_v1.yaml +227 -0
  114. smftools/tools/__init__.py +3 -4
  115. smftools/tools/archived/classifiers.py +163 -0
  116. smftools/tools/archived/subset_adata_v1.py +10 -1
  117. smftools/tools/archived/subset_adata_v2.py +12 -1
  118. smftools/tools/calculate_umap.py +54 -15
  119. smftools/tools/cluster_adata_on_methylation.py +115 -46
  120. smftools/tools/general_tools.py +70 -25
  121. smftools/tools/position_stats.py +229 -98
  122. smftools/tools/read_stats.py +50 -29
  123. smftools/tools/spatial_autocorrelation.py +365 -192
  124. smftools/tools/subset_adata.py +23 -21
  125. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/METADATA +17 -39
  126. smftools-0.2.5.dist-info/RECORD +181 -0
  127. smftools-0.2.3.dist-info/RECORD +0 -173
  128. /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
  129. /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
  130. /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
  131. /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
  132. /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archived/add_read_length_and_mapping_qc.py} +0 -0
  133. /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
  134. /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
  135. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
  136. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
  137. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
smftools/hmm/__init__.py CHANGED
@@ -1,20 +1,13 @@
-from .apply_hmm_batched import apply_hmm_batched
-from .calculate_distances import calculate_distances
 from .call_hmm_peaks import call_hmm_peaks
 from .display_hmm import display_hmm
 from .hmm_readwrite import load_hmm, save_hmm
-from .nucleosome_hmm_refinement import refine_nucleosome_calls, infer_nucleosomes_in_large_bound
-from .train_hmm import train_hmm
-
+from .nucleosome_hmm_refinement import infer_nucleosomes_in_large_bound, refine_nucleosome_calls
 
 __all__ = [
-    "apply_hmm_batched",
-    "calculate_distances",
     "call_hmm_peaks",
     "display_hmm",
    "load_hmm",
     "refine_nucleosome_calls",
     "infer_nucleosomes_in_large_bound",
     "save_hmm",
-    "train_hmm"
-]
+]
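
With 0.2.5, `apply_hmm_batched`, `calculate_distances`, and `train_hmm` move under `hmm/archived/` (see the file list above) and are no longer re-exported. A minimal sketch of the import surface that remains, taken directly from the `__all__` above:

```python
# Public names re-exported by smftools.hmm in 0.2.5 (per the __all__ above).
from smftools.hmm import (
    call_hmm_peaks,
    display_hmm,
    infer_nucleosomes_in_large_bound,
    load_hmm,
    refine_nucleosome_calls,
    save_hmm,
)
```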
smftools/hmm/archived/call_hmm_peaks.py ADDED
@@ -0,0 +1,121 @@
+def call_hmm_peaks(
+    adata,
+    feature_configs,
+    obs_column='Reference_strand',
+    site_types=['GpC_site', 'CpG_site'],
+    save_plot=False,
+    output_dir=None,
+    date_tag=None,
+    inplace=False
+):
+    """Call peaks from HMM feature layers and annotate AnnData.
+
+    Args:
+        adata: AnnData containing feature layers.
+        feature_configs: Mapping of layer name to peak config.
+        obs_column: Obs column for reference categories.
+        site_types: Site types to summarize around peaks.
+        save_plot: Whether to save peak plots.
+        output_dir: Output directory for plots.
+        date_tag: Optional tag for plot filenames.
+        inplace: Whether to modify AnnData in place.
+
+    Returns:
+        Annotated AnnData with peak masks and summary columns.
+    """
+    import numpy as np
+    import pandas as pd
+    import matplotlib.pyplot as plt
+    from scipy.signal import find_peaks
+
+    if not inplace:
+        adata = adata.copy()
+
+    # Ensure obs_column is categorical
+    if not isinstance(adata.obs[obs_column].dtype, pd.CategoricalDtype):
+        adata.obs[obs_column] = pd.Categorical(adata.obs[obs_column])
+
+    coordinates = adata.var_names.astype(int).values
+    peak_columns = []
+
+    obs_updates = {}
+
+    for feature_layer, config in feature_configs.items():
+        min_distance = config.get('min_distance', 200)
+        peak_width = config.get('peak_width', 200)
+        peak_prominence = config.get('peak_prominence', 0.2)
+        peak_threshold = config.get('peak_threshold', 0.8)
+
+        matrix = adata.layers[feature_layer]
+        means = np.mean(matrix, axis=0)
+        peak_indices, _ = find_peaks(means, prominence=peak_prominence, distance=min_distance)
+        peak_centers = coordinates[peak_indices]
+        adata.uns[f'{feature_layer} peak_centers'] = peak_centers.tolist()
+
+        # Plot
+        plt.figure(figsize=(6, 3))
+        plt.plot(coordinates, means)
+        plt.title(f"{feature_layer} with peak calls")
+        plt.xlabel("Genomic position")
+        plt.ylabel("Mean intensity")
+        for i, center in enumerate(peak_centers):
+            start, end = center - peak_width // 2, center + peak_width // 2
+            plt.axvspan(start, end, color='purple', alpha=0.2)
+            plt.axvline(center, color='red', linestyle='--')
+            aligned = [end if i % 2 else start, 'left' if i % 2 else 'right']
+            plt.text(aligned[0], 0, f"Peak {i}\n{center}", color='red', ha=aligned[1])
+        if save_plot and output_dir:
+            filename = f"{output_dir}/{date_tag or 'output'}_{feature_layer}_peaks.png"
+            plt.savefig(filename, bbox_inches='tight')
+            print(f"Saved plot to {filename}")
+        else:
+            plt.show()
+
+        feature_peak_columns = []
+        for center in peak_centers:
+            start, end = center - peak_width // 2, center + peak_width // 2
+            colname = f'{feature_layer}_peak_{center}'
+            peak_columns.append(colname)
+            feature_peak_columns.append(colname)
+
+            peak_mask = (coordinates >= start) & (coordinates <= end)
+            adata.var[colname] = peak_mask
+
+            region = matrix[:, peak_mask]
+            obs_updates[f'mean_{feature_layer}_around_{center}'] = np.mean(region, axis=1)
+            obs_updates[f'sum_{feature_layer}_around_{center}'] = np.sum(region, axis=1)
+            obs_updates[f'{feature_layer}_present_at_{center}'] = np.mean(region, axis=1) > peak_threshold
+
+            for site_type in site_types:
+                adata.obs[f'{site_type}_sum_around_{center}'] = 0
+                adata.obs[f'{site_type}_mean_around_{center}'] = np.nan
+
+            for ref in adata.obs[obs_column].cat.categories:
+                ref_idx = adata.obs[obs_column] == ref
+                mask_key = f"{ref}_{site_type}"
+                for site_type in site_types:
+                    if mask_key not in adata.var:
+                        continue
+                    site_mask = adata.var[mask_key].values
+                    site_coords = coordinates[site_mask]
+                    region_mask = (site_coords >= start) & (site_coords <= end)
+                    if not region_mask.any():
+                        continue
+                    full_mask = site_mask.copy()
+                    full_mask[site_mask] = region_mask
+                    site_region = adata[ref_idx, full_mask].X
+                    if hasattr(site_region, "A"):
+                        site_region = site_region.A
+                    if site_region.shape[1] > 0:
+                        adata.obs.loc[ref_idx, f'{site_type}_sum_around_{center}'] = np.nansum(site_region, axis=1)
+                        adata.obs.loc[ref_idx, f'{site_type}_mean_around_{center}'] = np.nanmean(site_region, axis=1)
+                    else:
+                        pass
+
+        adata.var[f'is_in_any_{feature_layer}_peak'] = adata.var[feature_peak_columns].any(axis=1)
+        print(f"Annotated {len(peak_centers)} peaks for {feature_layer}")
+
+    adata.var['is_in_any_peak'] = adata.var[peak_columns].any(axis=1)
+    adata.obs = pd.concat([adata.obs, pd.DataFrame(obs_updates, index=adata.obs.index)], axis=1)
+
+    return adata if not inplace else None
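
For comparison with the rewrite below, a 0.2.3-style call against this archived implementation looked like the following sketch; `adata` and the `hmm_accessible` layer name are assumptions for illustration, and the archived module path is inferred from the file list (archived modules may not be re-exported):

```python
# Old-style call (archived implementation): one global pass over all reads,
# 0.2.3 spellings of obs_column/site_types, and a returned copy by default.
from smftools.hmm.archived.call_hmm_peaks import call_hmm_peaks

annotated = call_hmm_peaks(
    adata,  # assumed: an AnnData with an "hmm_accessible" layer
    feature_configs={"hmm_accessible": {"peak_width": 200, "peak_prominence": 0.2}},
    obs_column="Reference_strand",
    site_types=["GpC_site", "CpG_site"],
    inplace=False,  # 0.2.3 default: work on a copy and return it
)
```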
smftools/hmm/call_hmm_peaks.py CHANGED
@@ -1,106 +1,314 @@
+# FILE: smftools/hmm/call_hmm_peaks.py
+
+from pathlib import Path
+from typing import Any, Dict, Optional, Sequence, Union
+
+from smftools.logging_utils import get_logger
+
+logger = get_logger(__name__)
+
+
 def call_hmm_peaks(
     adata,
-    feature_configs,
-    obs_column='Reference_strand',
-    site_types=['GpC_site', 'CpG_site'],
-    save_plot=False,
-    output_dir=None,
-    date_tag=None,
-    inplace=False
+    feature_configs: Dict[str, Dict[str, Any]],
+    ref_column: str = "Reference_strand",
+    site_types: Sequence[str] = ("GpC", "CpG"),
+    save_plot: bool = False,
+    output_dir: Optional[Union[str, "Path"]] = None,
+    date_tag: Optional[str] = None,
+    inplace: bool = True,
+    index_col_suffix: Optional[str] = None,
+    alternate_labels: bool = False,
 ):
+    """
+    Peak calling over HMM (or other) layers, per reference group and per layer.
+    Writes:
+      - adata.uns["{layer}_{ref}_peak_centers"] = list of centers
+      - adata.var["{layer}_{ref}_peak_{center}"] boolean window masks
+      - adata.obs per-read summaries for each peak window:
+          mean_{layer}_{ref}_around_{center}
+          sum_{layer}_{ref}_around_{center}
+          {layer}_{ref}_present_at_{center} (bool)
+        and per site-type:
+          sum_{layer}_{site}_{ref}_around_{center}
+          mean_{layer}_{site}_{ref}_around_{center}
+      - adata.var["is_in_any_{layer}_peak_{ref}"]
+      - adata.var["is_in_any_peak"] (global)
+    """
+    import matplotlib.pyplot as plt
     import numpy as np
     import pandas as pd
-    import matplotlib.pyplot as plt
     from scipy.signal import find_peaks
+    from scipy.sparse import issparse
 
     if not inplace:
         adata = adata.copy()
 
-    # Ensure obs_column is categorical
-    if not isinstance(adata.obs[obs_column].dtype, pd.CategoricalDtype):
-        adata.obs[obs_column] = pd.Categorical(adata.obs[obs_column])
-
-    coordinates = adata.var_names.astype(int).values
-    peak_columns = []
-
-    obs_updates = {}
-
-    for feature_layer, config in feature_configs.items():
-        min_distance = config.get('min_distance', 200)
-        peak_width = config.get('peak_width', 200)
-        peak_prominence = config.get('peak_prominence', 0.2)
-        peak_threshold = config.get('peak_threshold', 0.8)
-
-        matrix = adata.layers[feature_layer]
-        means = np.mean(matrix, axis=0)
-        peak_indices, _ = find_peaks(means, prominence=peak_prominence, distance=min_distance)
-        peak_centers = coordinates[peak_indices]
-        adata.uns[f'{feature_layer} peak_centers'] = peak_centers.tolist()
-
-        # Plot
-        plt.figure(figsize=(6, 3))
-        plt.plot(coordinates, means)
-        plt.title(f"{feature_layer} with peak calls")
-        plt.xlabel("Genomic position")
-        plt.ylabel("Mean intensity")
-        for i, center in enumerate(peak_centers):
-            start, end = center - peak_width // 2, center + peak_width // 2
-            plt.axvspan(start, end, color='purple', alpha=0.2)
-            plt.axvline(center, color='red', linestyle='--')
-            aligned = [end if i % 2 else start, 'left' if i % 2 else 'right']
-            plt.text(aligned[0], 0, f"Peak {i}\n{center}", color='red', ha=aligned[1])
-        if save_plot and output_dir:
-            filename = f"{output_dir}/{date_tag or 'output'}_{feature_layer}_peaks.png"
-            plt.savefig(filename, bbox_inches='tight')
-            print(f"Saved plot to {filename}")
+    if ref_column not in adata.obs:
+        raise KeyError(f"obs column '{ref_column}' not found")
+
+    # Ensure categorical for predictable ref iteration
+    if not pd.api.types.is_categorical_dtype(adata.obs[ref_column]):
+        adata.obs[ref_column] = adata.obs[ref_column].astype("category")
+
+    # Optional: drop duplicate obs columns once to avoid Pandas/AnnData view quirks
+    if getattr(adata.obs.columns, "duplicated", None) is not None:
+        if adata.obs.columns.duplicated().any():
+            adata.obs = adata.obs.loc[:, ~adata.obs.columns.duplicated(keep="first")].copy()
+
+    # Fallback coordinates from var_names
+    try:
+        base_coordinates = adata.var_names.astype(int).values
+    except Exception:
+        base_coordinates = np.arange(adata.n_vars, dtype=int)
+
+    # Output dir
+    if output_dir is not None:
+        output_dir = Path(output_dir)
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+    # Build search pool = union of declared HMM layers and actual layers; exclude helper suffixes
+    declared = list(adata.uns.get("hmm_appended_layers", []) or [])
+    search_pool = [
+        layer
+        for layer in declared
+        if not any(s in layer for s in ("_lengths", "_states", "_posterior"))
+    ]
+
+    all_peak_var_cols = []
+
+    # Iterate per reference
+    for ref in adata.obs[ref_column].cat.categories:
+        ref_mask = (adata.obs[ref_column] == ref).values
+        if not ref_mask.any():
+            continue
+
+        # Per-ref coordinate system
+        if index_col_suffix is not None:
+            coord_col = f"{ref}_{index_col_suffix}"
+            if coord_col not in adata.var:
+                raise KeyError(
+                    f"index_col_suffix='{index_col_suffix}' requested, missing var column '{coord_col}' for ref '{ref}'."
+                )
+            coord_vals = adata.var[coord_col].values
+            try:
+                coordinates = coord_vals.astype(int)
+            except Exception:
+                coordinates = np.asarray(coord_vals, dtype=float)
         else:
-            plt.show()
-
-        feature_peak_columns = []
-        for center in peak_centers:
-            start, end = center - peak_width // 2, center + peak_width // 2
-            colname = f'{feature_layer}_peak_{center}'
-            peak_columns.append(colname)
-            feature_peak_columns.append(colname)
-
-            peak_mask = (coordinates >= start) & (coordinates <= end)
-            adata.var[colname] = peak_mask
-
-            region = matrix[:, peak_mask]
-            obs_updates[f'mean_{feature_layer}_around_{center}'] = np.mean(region, axis=1)
-            obs_updates[f'sum_{feature_layer}_around_{center}'] = np.sum(region, axis=1)
-            obs_updates[f'{feature_layer}_present_at_{center}'] = np.mean(region, axis=1) > peak_threshold
-
-            for site_type in site_types:
-                adata.obs[f'{site_type}_sum_around_{center}'] = 0
-                adata.obs[f'{site_type}_mean_around_{center}'] = np.nan
-
-            for ref in adata.obs[obs_column].cat.categories:
-                ref_idx = adata.obs[obs_column] == ref
-                mask_key = f"{ref}_{site_type}"
-                for site_type in site_types:
-                    if mask_key not in adata.var:
-                        continue
-                    site_mask = adata.var[mask_key].values
-                    site_coords = coordinates[site_mask]
-                    region_mask = (site_coords >= start) & (site_coords <= end)
-                    if not region_mask.any():
-                        continue
-                    full_mask = site_mask.copy()
-                    full_mask[site_mask] = region_mask
-                    site_region = adata[ref_idx, full_mask].X
-                    if hasattr(site_region, "A"):
-                        site_region = site_region.A
-                    if site_region.shape[1] > 0:
-                        adata.obs.loc[ref_idx, f'{site_type}_sum_around_{center}'] = np.nansum(site_region, axis=1)
-                        adata.obs.loc[ref_idx, f'{site_type}_mean_around_{center}'] = np.nanmean(site_region, axis=1)
+            coordinates = base_coordinates
+
+        if coordinates.shape[0] != adata.n_vars:
+            raise ValueError(f"Coordinate length {coordinates.shape[0]} != n_vars {adata.n_vars}")
+
+        # Feature keys to consider
+        for feature_key, config in feature_configs.items():
+            # Resolve candidate layers: exact → suffix → direct present
+            candidates = [ln for ln in search_pool if ln == feature_key]
+            if not candidates:
+                candidates = [ln for ln in search_pool if str(ln).endswith(feature_key)]
+            if not candidates and feature_key in adata.layers:
+                candidates = [feature_key]
+
+            if not candidates:
+                logger.warning(
+                    "[call_hmm_peaks] No layers found matching '%s' in ref '%s'. Skipping.",
+                    feature_key,
+                    ref,
+                )
+                continue
+
+            # Hyperparams (sanitized)
+            min_distance = max(1, int(config.get("min_distance", 200)))
+            peak_width = max(1, int(config.get("peak_width", 200)))
+            peak_prom = float(config.get("peak_prominence", 0.2))
+            peak_threshold = float(config.get("peak_threshold", 0.8))
+            rolling_window = max(1, int(config.get("rolling_window", 1)))
+
+            for layer_name in candidates:
+                if layer_name not in adata.layers:
+                    logger.warning(
+                        "[call_hmm_peaks] Layer '%s' not in adata.layers; skipping.",
+                        layer_name,
+                    )
+                    continue
+
+                # Dense layer data
+                L = adata.layers[layer_name]
+                L = L.toarray() if issparse(L) else np.asarray(L)
+                if L.shape != (adata.n_obs, adata.n_vars):
+                    logger.warning(
+                        "[call_hmm_peaks] Layer '%s' has shape %s, expected (%s, %s); skipping.",
+                        layer_name,
+                        L.shape,
+                        adata.n_obs,
+                        adata.n_vars,
+                    )
+                    continue
+
+                # Ref subset
+                matrix = L[ref_mask, :]
+                if matrix.size == 0 or matrix.shape[0] == 0:
+                    continue
+
+                means = np.nanmean(matrix, axis=0)
+                means = np.nan_to_num(means, nan=0.0)
+
+                if rolling_window > 1:
+                    kernel = np.ones(rolling_window, dtype=float) / float(rolling_window)
+                    peak_metric = np.convolve(means, kernel, mode="same")
+                else:
+                    peak_metric = means
+
+                # Peak detection
+                peak_indices, _ = find_peaks(
+                    peak_metric, prominence=peak_prom, distance=min_distance
+                )
+                if peak_indices.size == 0:
+                    logger.info(
+                        "[call_hmm_peaks] No peaks for layer '%s' in ref '%s'.",
+                        layer_name,
+                        ref,
+                    )
+                    continue
+
+                peak_centers = coordinates[peak_indices]
+                adata.uns[f"{layer_name}_{ref}_peak_centers"] = peak_centers.tolist()
+
+                # Plot once per layer/ref
+                fig, ax = plt.subplots(figsize=(6, 3))
+                ax.plot(coordinates, peak_metric, linewidth=1)
+                ax.set_title(f"{layer_name} peaks in {ref}")
+                ax.set_xlabel("Coordinate")
+                ax.set_ylabel(f"Rolling Mean (win={rolling_window})")
+                for i, center in enumerate(peak_centers):
+                    start = center - peak_width // 2
+                    end = center + peak_width // 2
+                    height = peak_metric[peak_indices[i]]
+                    ax.axvspan(start, end, alpha=0.2)
+                    ax.axvline(center, linestyle="--", linewidth=0.8)
+                    x_text, ha = (
+                        (start, "right") if (not alternate_labels or i % 2 == 0) else (end, "left")
+                    )
+                    ax.text(
+                        x_text, height * 0.8, f"Peak {i}\n{center}", ha=ha, va="bottom", fontsize=8
+                    )
+
+                if save_plot and output_dir is not None:
+                    tag = date_tag or "output"
+                    safe_ref = str(ref).replace("/", "_")
+                    safe_layer = str(layer_name).replace("/", "_")
+                    fname = output_dir / f"{tag}_{safe_layer}_{safe_ref}_peaks.png"
+                    fig.savefig(fname, bbox_inches="tight", dpi=200)
+                    logger.info("[call_hmm_peaks] Saved plot to %s", fname)
+                    plt.close(fig)
+                else:
+                    fig.tight_layout()
+                    plt.show()
+
+                # Collect new obs columns; assign once per layer/ref
+                new_obs_cols: Dict[str, np.ndarray] = {}
+                feature_peak_cols = []
+
+                for center in np.asarray(peak_centers).tolist():
+                    start = center - peak_width // 2
+                    end = center + peak_width // 2
+
+                    # var window mask
+                    colname = f"{layer_name}_{ref}_peak_{center}"
+                    feature_peak_cols.append(colname)
+                    all_peak_var_cols.append(colname)
+                    peak_mask = (coordinates >= start) & (coordinates <= end)
+                    adata.var[colname] = peak_mask
+
+                    # feature-layer summaries for reads in this ref
+                    region = matrix[:, peak_mask]  # (n_ref, n_window)
+
+                    mean_col = f"mean_{layer_name}_{ref}_around_{center}"
+                    sum_col = f"sum_{layer_name}_{ref}_around_{center}"
+                    present_col = f"{layer_name}_{ref}_present_at_{center}"
+
+                    for nm, default, dt in (
+                        (mean_col, np.nan, float),
+                        (sum_col, 0.0, float),
+                        (present_col, False, bool),
+                    ):
+                        if nm not in new_obs_cols:
+                            new_obs_cols[nm] = np.full(adata.n_obs, default, dtype=dt)
+
+                    if region.shape[1] > 0:
+                        means_per_read = np.nanmean(region, axis=1)
+                        sums_per_read = np.nansum(region, axis=1)
                     else:
-                        pass
+                        means_per_read = np.full(matrix.shape[0], np.nan, dtype=float)
+                        sums_per_read = np.zeros(matrix.shape[0], dtype=float)
+
+                    new_obs_cols[mean_col][ref_mask] = means_per_read
+                    new_obs_cols[sum_col][ref_mask] = sums_per_read
+                    new_obs_cols[present_col][ref_mask] = (
+                        np.nan_to_num(means_per_read, nan=0.0) > peak_threshold
+                    )
+
+                    # site-type summaries from adata.X, not an AnnData view
+                    Xmat = adata.X
+                    for site_type in site_types:
+                        mask_key = f"{ref}_{site_type}_site"
+                        if mask_key not in adata.var:
+                            continue
+
+                        site_mask = adata.var[mask_key].values.astype(bool)
+                        if not site_mask.any():
+                            continue
+
+                        site_coords = coordinates[site_mask]
+                        site_region_mask = (site_coords >= start) & (site_coords <= end)
+                        sum_site_col = f"sum_{layer_name}_{site_type}_{ref}_around_{center}"
+                        mean_site_col = f"mean_{layer_name}_{site_type}_{ref}_around_{center}"
+
+                        if sum_site_col not in new_obs_cols:
+                            new_obs_cols[sum_site_col] = np.zeros(adata.n_obs, dtype=float)
+                        if mean_site_col not in new_obs_cols:
+                            new_obs_cols[mean_site_col] = np.full(adata.n_obs, np.nan, dtype=float)
+
+                        if not site_region_mask.any():
+                            continue
+
+                        full_mask = np.zeros_like(site_mask, dtype=bool)
+                        full_mask[site_mask] = site_region_mask
+
+                        if issparse(Xmat):
+                            site_region = Xmat[ref_mask][:, full_mask]
+                            site_region = site_region.toarray()
+                        else:
+                            Xnp = np.asarray(Xmat)
+                            site_region = Xnp[np.asarray(ref_mask), :][:, np.asarray(full_mask)]
+
+                        if site_region.shape[1] > 0:
+                            new_obs_cols[sum_site_col][ref_mask] = np.nansum(site_region, axis=1)
+                            new_obs_cols[mean_site_col][ref_mask] = np.nanmean(site_region, axis=1)
+
+                # one-shot assignment to avoid fragmentation
+                if new_obs_cols:
+                    adata.obs = adata.obs.assign(
+                        **{k: pd.Series(v, index=adata.obs.index) for k, v in new_obs_cols.items()}
+                    )
+
+                # per (layer, ref) any-peak
+                any_col = f"is_in_any_{layer_name}_peak_{ref}"
+                if feature_peak_cols:
+                    adata.var[any_col] = adata.var[feature_peak_cols].any(axis=1)
+                else:
+                    adata.var[any_col] = False
 
-        adata.var[f'is_in_any_{feature_layer}_peak'] = adata.var[feature_peak_columns].any(axis=1)
-        print(f"Annotated {len(peak_centers)} peaks for {feature_layer}")
+                logger.info(
+                    "[call_hmm_peaks] Annotated %s peaks for layer '%s' in ref '%s'.",
+                    len(peak_centers),
+                    layer_name,
+                    ref,
+                )
 
-    adata.var['is_in_any_peak'] = adata.var[peak_columns].any(axis=1)
-    adata.obs = pd.concat([adata.obs, pd.DataFrame(obs_updates, index=adata.obs.index)], axis=1)
+    # global any-peak across all layers/refs
+    if all_peak_var_cols:
+        adata.var["is_in_any_peak"] = adata.var[all_peak_var_cols].any(axis=1)
 
-    return adata if not inplace else None
+    return None if inplace else adata
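
A self-contained sketch of calling the reworked function, under stated assumptions: the `hmm_accessible` layer name and the toy AnnData are invented for illustration, and peak detection on random data may legitimately find nothing. Note the flipped defaults relative to 0.2.3: `inplace` now defaults to `True`, the in-place call returns `None`, and `obs_column` is renamed `ref_column`.

```python
import anndata as ad
import numpy as np

from smftools.hmm import call_hmm_peaks

# Toy AnnData: 20 reads x 400 positions, one binary "HMM" layer (illustrative name).
rng = np.random.default_rng(0)
X = rng.binomial(1, 0.3, size=(20, 400)).astype(float)
adata = ad.AnnData(X)
adata.var_names = [str(i) for i in range(400)]
adata.obs["Reference_strand"] = ["ref1_top"] * 20
adata.layers["hmm_accessible"] = X.copy()

call_hmm_peaks(
    adata,
    feature_configs={"hmm_accessible": {"min_distance": 100, "peak_width": 150, "peak_prominence": 0.05}},
    ref_column="Reference_strand",  # renamed from obs_column in 0.2.3
    inplace=True,                   # new default: annotate adata directly, return None
)
print([k for k in adata.uns if k.endswith("_peak_centers")])
```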
smftools/hmm/display_hmm.py CHANGED
@@ -1,18 +1,31 @@
+from smftools.logging_utils import get_logger
+
+logger = get_logger(__name__)
+
+
 def display_hmm(hmm, state_labels=["Non-Methylated", "Methylated"], obs_labels=["0", "1"]):
+    """Log a summary of HMM transition and emission parameters.
+
+    Args:
+        hmm: HMM object with edges and distributions.
+        state_labels: Optional labels for states.
+        obs_labels: Optional labels for observations.
+    """
     import torch
-    print("\n**HMM Model Overview**")
-    print(hmm)
 
-    print("\n**Transition Matrix**")
+    logger.info("**HMM Model Overview**")
+    logger.info("%s", hmm)
+
+    logger.info("**Transition Matrix**")
     transition_matrix = torch.exp(hmm.edges).detach().cpu().numpy()
     for i, row in enumerate(transition_matrix):
         label = state_labels[i] if state_labels else f"State {i}"
         formatted_row = ", ".join(f"{p:.6f}" for p in row)
-        print(f"{label}: [{formatted_row}]")
+        logger.info("%s: [%s]", label, formatted_row)
 
-    print("\n**Emission Probabilities**")
+    logger.info("**Emission Probabilities**")
     for i, dist in enumerate(hmm.distributions):
         label = state_labels[i] if state_labels else f"State {i}"
         probs = dist.probs.detach().cpu().numpy()
         formatted_emissions = {obs_labels[j]: probs[j] for j in range(len(probs))}
-        print(f"{label}: {formatted_emissions}")
+        logger.info("%s: %s", label, formatted_emissions)
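
Because `display_hmm` now logs instead of printing, its output only appears once logging is configured. A minimal sketch, assuming `get_logger` ultimately writes through the standard library root logger and that a saved model file exists (the path is illustrative):

```python
import logging

from smftools.hmm import display_hmm, load_hmm

# Surface smftools INFO messages in an interactive session.
logging.basicConfig(level=logging.INFO, format="%(name)s: %(message)s")

hmm = load_hmm("smf_hmm.pt")  # illustrative path to a previously saved model
display_hmm(hmm, state_labels=["Non-Methylated", "Methylated"], obs_labels=["0", "1"])
```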
smftools/hmm/hmm_readwrite.py CHANGED
@@ -1,16 +1,25 @@
-def load_hmm(model_path, device='cpu'):
+def load_hmm(model_path, device="cpu"):
     """
     Reads in a pretrained HMM.
-
+
     Parameters:
         model_path (str): Path to a pretrained HMM
     """
     import torch
+
     # Load model using PyTorch
     hmm = torch.load(model_path)
-    hmm.to(device)
+    hmm.to(device)
     return hmm
 
+
 def save_hmm(model, model_path):
+    """Save a pretrained HMM to disk.
+
+    Args:
+        model: HMM model instance.
+        model_path: Output path for the model.
+    """
     import torch
-    torch.save(model, model_path)
+
+    torch.save(model, model_path)
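
The save/load pair remains a thin wrapper over `torch.save`/`torch.load`. A round-trip sketch with a stand-in module (a real smftools HMM object would round-trip the same way); note that PyTorch >= 2.6 defaults `torch.load(weights_only=True)`, which can reject full-model pickles like this one:

```python
import torch

from smftools.hmm import load_hmm, save_hmm

# A Linear layer stands in for a trained HMM purely for illustration.
model = torch.nn.Linear(2, 2)
save_hmm(model, "model.pt")      # torch.save(model, model_path)
restored = load_hmm("model.pt")  # torch.load(model_path) followed by .to(device)
print(type(restored).__name__)
```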