smftools 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. smftools/_version.py +1 -1
  2. smftools/cli/helpers.py +32 -6
  3. smftools/cli/hmm_adata.py +232 -31
  4. smftools/cli/latent_adata.py +318 -0
  5. smftools/cli/load_adata.py +77 -73
  6. smftools/cli/preprocess_adata.py +178 -53
  7. smftools/cli/spatial_adata.py +149 -101
  8. smftools/cli_entry.py +12 -0
  9. smftools/config/conversion.yaml +11 -1
  10. smftools/config/default.yaml +38 -1
  11. smftools/config/experiment_config.py +53 -1
  12. smftools/constants.py +65 -0
  13. smftools/hmm/HMM.py +88 -0
  14. smftools/informatics/__init__.py +6 -0
  15. smftools/informatics/bam_functions.py +358 -8
  16. smftools/informatics/converted_BAM_to_adata.py +584 -163
  17. smftools/informatics/h5ad_functions.py +115 -2
  18. smftools/informatics/modkit_extract_to_adata.py +1003 -425
  19. smftools/informatics/sequence_encoding.py +72 -0
  20. smftools/logging_utils.py +21 -2
  21. smftools/metadata.py +1 -1
  22. smftools/plotting/__init__.py +9 -0
  23. smftools/plotting/general_plotting.py +2411 -628
  24. smftools/plotting/hmm_plotting.py +85 -7
  25. smftools/preprocessing/__init__.py +1 -0
  26. smftools/preprocessing/append_base_context.py +17 -17
  27. smftools/preprocessing/append_mismatch_frequency_sites.py +158 -0
  28. smftools/preprocessing/calculate_consensus.py +1 -1
  29. smftools/preprocessing/calculate_read_modification_stats.py +6 -1
  30. smftools/readwrite.py +53 -17
  31. smftools/schema/anndata_schema_v1.yaml +15 -1
  32. smftools/tools/__init__.py +4 -0
  33. smftools/tools/calculate_leiden.py +57 -0
  34. smftools/tools/calculate_nmf.py +119 -0
  35. smftools/tools/calculate_umap.py +91 -8
  36. smftools/tools/rolling_nn_distance.py +235 -0
  37. smftools/tools/tensor_factorization.py +169 -0
  38. {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/METADATA +8 -6
  39. {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/RECORD +42 -35
  40. {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/WHEEL +0 -0
  41. {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/entry_points.txt +0 -0
  42. {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/licenses/LICENSE +0 -0
smftools/plotting/hmm_plotting.py CHANGED
@@ -1,13 +1,15 @@
 from __future__ import annotations

 import math
-from typing import Optional, Tuple, Union
+from typing import Optional, Sequence, Tuple, Union

 import numpy as np
+import pandas as pd

 from smftools.optional_imports import require

 plt = require("matplotlib.pyplot", extra="plotting", purpose="HMM plots")
+mpl_colors = require("matplotlib.colors", extra="plotting", purpose="HMM plots")
 pdf_backend = require(
     "matplotlib.backends.backend_pdf",
     extra="plotting",
@@ -32,6 +34,9 @@ def plot_hmm_size_contours(
     dpi: int = 150,
     vmin: Optional[float] = None,
     vmax: Optional[float] = None,
+    feature_ranges: Optional[Sequence[Tuple[int, int, str]]] = None,
+    zero_color: str = "#f5f1e8",
+    nan_color: str = "#E6E6E6",
     # ---------------- smoothing params ----------------
     smoothing_sigma: Optional[Union[float, Tuple[float, float]]] = None,
     normalize_after_smoothing: bool = True,
@@ -40,6 +45,9 @@ def plot_hmm_size_contours(
     """
     Create contour/pcolormesh plots of P(length | position) using a length-encoded HMM layer.
     Optional Gaussian smoothing applied to the 2D probability grid before plotting.
+    When feature_ranges is provided, each length row is assigned a base color based
+    on the matching (min_len, max_len) range and the probability value modulates
+    the color intensity.

     smoothing_sigma: None or 0 -> no smoothing.
                      float -> same sigma applied to (length_axis, position_axis)
@@ -48,6 +56,51 @@ def plot_hmm_size_contours(

     Other args are the same as prior function.
     """
+    feature_ranges = tuple(feature_ranges or ())
+
+    def _resolve_length_color(length: int, fallback: str) -> Tuple[float, float, float, float]:
+        for min_len, max_len, color in feature_ranges:
+            if min_len <= length <= max_len:
+                return mpl_colors.to_rgba(color)
+        return mpl_colors.to_rgba(fallback)
+
+    def _build_length_facecolors(
+        Z_values: np.ndarray,
+        lengths: np.ndarray,
+        fallback_color: str,
+        *,
+        vmin_local: Optional[float],
+        vmax_local: Optional[float],
+    ) -> np.ndarray:
+        zero_rgba = np.array(mpl_colors.to_rgba(zero_color))
+        nan_rgba = np.array(mpl_colors.to_rgba(nan_color))
+        base_colors = np.array(
+            [_resolve_length_color(int(length), fallback_color) for length in lengths],
+            dtype=float,
+        )
+        base_colors[:, 3] = 1.0
+
+        scale = np.array(Z_values, copy=True, dtype=float)
+        finite_mask = np.isfinite(scale)
+        if not finite_mask.any():
+            facecolors = np.zeros(scale.shape + (4,), dtype=float)
+            facecolors[:] = nan_rgba
+            return facecolors.reshape(-1, 4)
+
+        vmin_use = np.nanmin(scale) if vmin_local is None else vmin_local
+        vmax_use = np.nanmax(scale) if vmax_local is None else vmax_local
+        denom = vmax_use - vmin_use
+        if denom <= 0:
+            norm = np.zeros_like(scale)
+        else:
+            norm = (scale - vmin_use) / denom
+        norm = np.clip(norm, 0, 1)
+
+        row_colors = base_colors[:, None, :]
+        facecolors = zero_rgba + norm[..., None] * (row_colors - zero_rgba)
+        facecolors[..., 3] = 1.0
+        facecolors[~finite_mask] = nan_rgba
+        return facecolors.reshape(-1, 4)

     # --- helper: gaussian smoothing (scipy fallback -> numpy separable conv) ---
     def _gaussian_1d_kernel(sigma: float, eps: float = 1e-12):
@@ -150,7 +203,8 @@ def plot_hmm_size_contours(
     figs = []

     # decide global max length to allocate y axis (cap to avoid huge memory)
-    observed_max_len = int(np.max(full_layer)) if full_layer.size > 0 else 0
+    finite_lengths = full_layer[np.isfinite(full_layer) & (full_layer > 0)]
+    observed_max_len = int(np.nanmax(finite_lengths)) if finite_lengths.size > 0 else 0
     if max_length_cap is None:
         max_len = observed_max_len
     else:
@@ -205,10 +259,15 @@ def plot_hmm_size_contours(
            ax.text(0.5, 0.5, "no data", ha="center", va="center")
            ax.set_title(f"{sample} / {ref}")
            continue
+        valid_lengths = sub[np.isfinite(sub) & (sub > 0)]
+        if valid_lengths.size == 0:
+            ax.text(0.5, 0.5, "no data", ha="center", va="center")
+            ax.set_title(f"{sample} / {ref}")
+            continue

        # compute counts per length per position
        n_positions = sub.shape[1]
-        max_len_local = int(sub.max()) if sub.size > 0 else 0
+        max_len_local = int(valid_lengths.max()) if valid_lengths.size > 0 else 0
        max_len_here = min(max_len, max_len_local)

        lengths_range = np.arange(1, max_len_here + 1, dtype=int)
@@ -219,7 +278,7 @@ def plot_hmm_size_contours(
        # fill Z by efficient bincount across columns
        for j in range(n_positions):
            col_vals = sub[:, j]
-            pos_vals = col_vals[col_vals > 0].astype(int)
+            pos_vals = col_vals[np.isfinite(col_vals) & (col_vals > 0)].astype(int)
            if pos_vals.size == 0:
                continue
            clipped = np.clip(pos_vals, 1, max_len_here)
@@ -258,9 +317,28 @@ def plot_hmm_size_contours(
        dy = 1.0
        y_edges = np.concatenate([y - 0.5, [y[-1] + 0.5]])

-        pcm = ax.pcolormesh(
-            x_edges, y_edges, Z_plot, cmap=cmap, shading="auto", vmin=vmin, vmax=vmax
-        )
+        if feature_ranges:
+            fallback_color = mpl_colors.to_rgba(plt.get_cmap(cmap)(1.0))
+            facecolors = _build_length_facecolors(
+                Z_plot,
+                lengths_range,
+                fallback_color,
+                vmin_local=vmin,
+                vmax_local=vmax,
+            )
+            pcm = ax.pcolormesh(
+                x_edges,
+                y_edges,
+                Z_plot,
+                shading="auto",
+                vmin=vmin,
+                vmax=vmax,
+                facecolors=facecolors,
+            )
+        else:
+            pcm = ax.pcolormesh(
+                x_edges, y_edges, Z_plot, cmap=cmap, shading="auto", vmin=vmin, vmax=vmax
+            )
        ax.set_title(f"{sample} / {ref}")
        ax.set_ylabel("length")
        if i_row == rows_on_page - 1:
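The new feature_ranges path replaces the single colormap with per-length-row colors: each footprint length picks a base color from the first matching (min_len, max_len, color) range, and each cell's probability interpolates between zero_color and that base color. A minimal standalone sketch of that interpolation, with hypothetical size ranges and colors (not smftools defaults):

import numpy as np
from matplotlib import colors as mpl_colors

feature_ranges = [(1, 80, "tab:blue"), (81, 200, "tab:orange")]  # hypothetical size classes
zero_color = "#f5f1e8"

def base_color_for_length(length, fallback="tab:gray"):
    # first matching range wins, otherwise fall back
    for min_len, max_len, color in feature_ranges:
        if min_len <= length <= max_len:
            return mpl_colors.to_rgba(color)
    return mpl_colors.to_rgba(fallback)

lengths = np.arange(1, 6)                                # toy length axis (rows)
probs = np.random.default_rng(0).random((5, 4))          # toy P(length | position) grid
base = np.array([base_color_for_length(int(l)) for l in lengths])  # (n_len, 4) RGBA
zero = np.array(mpl_colors.to_rgba(zero_color))

norm = (probs - probs.min()) / (probs.max() - probs.min())          # scale to [0, 1]
facecolors = zero + norm[..., None] * (base[:, None, :] - zero)     # lerp per cell
facecolors[..., 3] = 1.0
print(facecolors.shape)  # (5, 4, 4); flattened to (n_cells, 4) before pcolormesh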
smftools/preprocessing/__init__.py CHANGED
@@ -5,6 +5,7 @@ from importlib import import_module
 _LAZY_ATTRS = {
     "append_base_context": "smftools.preprocessing.append_base_context",
     "append_binary_layer_by_base_context": "smftools.preprocessing.append_binary_layer_by_base_context",
+    "append_mismatch_frequency_sites": "smftools.preprocessing.append_mismatch_frequency_sites",
     "binarize_adata": "smftools.preprocessing.binarize",
     "binarize_on_Youden": "smftools.preprocessing.binarize_on_Youden",
     "calculate_complexity_II": "smftools.preprocessing.calculate_complexity_II",
smftools/preprocessing/append_base_context.py CHANGED
@@ -133,23 +133,23 @@ def append_base_context(
             adata.var[f"{ref}_{site_type}_valid_coverage"] = (
                 (adata.var[f"{ref}_{site_type}"]) & (adata.var[f"position_in_{ref}"])
             )
-            if native:
-                adata.obsm[f"{ref}_{site_type}_valid_coverage"] = adata[
-                    :, adata.var[f"{ref}_{site_type}_valid_coverage"]
-                ].layers["binarized_methylation"]
-            else:
-                adata.obsm[f"{ref}_{site_type}_valid_coverage"] = adata[
-                    :, adata.var[f"{ref}_{site_type}_valid_coverage"]
-                ].X
-        else:
-            pass
-
-        if native:
-            adata.obsm[f"{ref}_{site_type}"] = adata[:, adata.var[f"{ref}_{site_type}"]].layers[
-                "binarized_methylation"
-            ]
-        else:
-            adata.obsm[f"{ref}_{site_type}"] = adata[:, adata.var[f"{ref}_{site_type}"]].X
+            # if native:
+            #     adata.obsm[f"{ref}_{site_type}_valid_coverage"] = adata[
+            #         :, adata.var[f"{ref}_{site_type}_valid_coverage"]
+            #     ].layers["binarized_methylation"]
+            # else:
+            #     adata.obsm[f"{ref}_{site_type}_valid_coverage"] = adata[
+            #         :, adata.var[f"{ref}_{site_type}_valid_coverage"]
+            #     ].X
+        # else:
+        #     pass
+
+        # if native:
+        #     adata.obsm[f"{ref}_{site_type}"] = adata[:, adata.var[f"{ref}_{site_type}"]].layers[
+        #         "binarized_methylation"
+        #     ]
+        # else:
+        #     adata.obsm[f"{ref}_{site_type}"] = adata[:, adata.var[f"{ref}_{site_type}"]].X

     # mark as done
     adata.uns[uns_flag] = True
smftools/preprocessing/append_mismatch_frequency_sites.py ADDED
@@ -0,0 +1,158 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Iterable, Sequence
+
+import numpy as np
+import pandas as pd
+
+from smftools.constants import MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT
+from smftools.logging_utils import get_logger
+
+if TYPE_CHECKING:
+    import anndata as ad
+
+logger = get_logger(__name__)
+
+
+def append_mismatch_frequency_sites(
+    adata: "ad.AnnData",
+    ref_column: str = "Reference_strand",
+    mismatch_layer: str = "mismatch_integer_encoding",
+    read_span_layer: str = "read_span_mask",
+    mismatch_frequency_range: Sequence[float] | None = (0.05, 0.95),
+    uns_flag: str = "append_mismatch_frequency_sites_performed",
+    force_redo: bool = False,
+    bypass: bool = False,
+) -> None:
+    """Append mismatch frequency metadata and variable-site flags per reference.
+
+    Args:
+        adata: AnnData object.
+        ref_column: Obs column defining reference categories.
+        mismatch_layer: Layer containing mismatch integer encodings.
+        read_span_layer: Layer containing read span masks (1=covered, 0=not covered).
+        mismatch_frequency_range: Lower/upper bounds (inclusive) for variable site flagging.
+        uns_flag: Flag in ``adata.uns`` indicating prior completion.
+        force_redo: Whether to rerun even if ``uns_flag`` is set.
+        bypass: Whether to skip running this step.
+    """
+    if bypass:
+        return
+
+    already = bool(adata.uns.get(uns_flag, False))
+    if already and not force_redo:
+        return
+
+    if mismatch_layer not in adata.layers:
+        logger.debug(
+            "Mismatch layer '%s' not found; skipping mismatch frequency step.", mismatch_layer
+        )
+        return
+
+    mismatch_map = adata.uns.get("mismatch_integer_encoding_map", {})
+    if not mismatch_map:
+        logger.debug("Mismatch encoding map not found; skipping mismatch frequency step.")
+        return
+
+    n_value = mismatch_map.get("N", MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT["N"])
+    pad_value = mismatch_map.get("PAD", MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT["PAD"])
+
+    base_int_to_label = {
+        int(value): str(base)
+        for base, value in mismatch_map.items()
+        if base not in {"N", "PAD"} and isinstance(value, (int, np.integer))
+    }
+    if not base_int_to_label:
+        logger.debug("Mismatch encoding map missing base labels; skipping mismatch frequency step.")
+        return
+
+    has_span_mask = read_span_layer in adata.layers
+    if not has_span_mask:
+        logger.debug(
+            "Read span mask '%s' not found; mismatch frequencies will be computed over all reads.",
+            read_span_layer,
+        )
+
+    references = adata.obs[ref_column].cat.categories
+    n_vars = adata.shape[1]
+
+    if mismatch_frequency_range is None:
+        mismatch_frequency_range = (0.0, 1.0)
+
+    lower_bound, upper_bound = mismatch_frequency_range
+
+    for ref in references:
+        ref_mask = adata.obs[ref_column] == ref
+        ref_position_mask = adata.var.get(f"position_in_{ref}")
+        if ref_position_mask is None:
+            ref_position_mask = pd.Series(np.ones(n_vars, dtype=bool), index=adata.var.index)
+        else:
+            ref_position_mask = ref_position_mask.astype(bool)
+
+        frequency_values = np.full(n_vars, np.nan, dtype=float)
+        variable_flags = np.zeros(n_vars, dtype=bool)
+        mismatch_base_frequencies: list[list[tuple[str, float]]] = [[] for _ in range(n_vars)]
+
+        if ref_mask.sum() == 0:
+            adata.var[f"{ref}_mismatch_frequency"] = pd.Series(
+                frequency_values, index=adata.var.index
+            )
+            adata.var[f"{ref}_variable_sequence_site"] = pd.Series(
+                variable_flags, index=adata.var.index
+            )
+            adata.var[f"{ref}_mismatch_base_frequencies"] = pd.Series(
+                mismatch_base_frequencies, index=adata.var.index
+            )
+            continue
+
+        mismatch_matrix = np.asarray(adata.layers[mismatch_layer][ref_mask])
+        if has_span_mask:
+            span_matrix = np.asarray(adata.layers[read_span_layer][ref_mask])
+            coverage_mask = span_matrix > 0
+            coverage_counts = coverage_mask.sum(axis=0).astype(float)
+        else:
+            coverage_mask = np.ones_like(mismatch_matrix, dtype=bool)
+            coverage_counts = np.full(n_vars, ref_mask.sum(), dtype=float)
+
+        mismatch_mask = (~np.isin(mismatch_matrix, [n_value, pad_value])) & coverage_mask
+        mismatch_counts = mismatch_mask.sum(axis=0)
+
+        frequency_values = np.divide(
+            mismatch_counts,
+            coverage_counts,
+            out=np.full(n_vars, np.nan, dtype=float),
+            where=coverage_counts > 0,
+        )
+        frequency_values = np.where(ref_position_mask.values, frequency_values, np.nan)
+
+        variable_flags = (
+            (frequency_values >= lower_bound)
+            & (frequency_values <= upper_bound)
+            & ref_position_mask.values
+        )
+
+        base_counts_by_int: dict[int, np.ndarray] = {}
+        for base_int in base_int_to_label:
+            base_counts_by_int[base_int] = ((mismatch_matrix == base_int) & coverage_mask).sum(
+                axis=0
+            )
+
+        for idx in range(n_vars):
+            if not ref_position_mask.iloc[idx] or coverage_counts[idx] == 0:
+                continue
+            base_freqs: list[tuple[str, float]] = []
+            for base_int, base_label in base_int_to_label.items():
+                count = base_counts_by_int[base_int][idx]
+                if count > 0:
+                    base_freqs.append((base_label, float(count / coverage_counts[idx])))
+            mismatch_base_frequencies[idx] = base_freqs
+
+        adata.var[f"{ref}_mismatch_frequency"] = pd.Series(frequency_values, index=adata.var.index)
+        adata.var[f"{ref}_variable_sequence_site"] = pd.Series(
+            variable_flags, index=adata.var.index
+        )
+        adata.var[f"{ref}_mismatch_base_frequencies"] = pd.Series(
+            mismatch_base_frequencies, index=adata.var.index
+        )
+
+    adata.uns[uns_flag] = True
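For orientation, a usage sketch of the new preprocessing step. It assumes an `adata` produced by the smftools pipeline, so that the mismatch layer, read-span mask, and `uns["mismatch_integer_encoding_map"]` are already present; otherwise the function logs a debug message and returns without changes:

import smftools.preprocessing as pp

pp.append_mismatch_frequency_sites(
    adata,                                   # assumed to exist (pipeline output)
    ref_column="Reference_strand",
    mismatch_frequency_range=(0.05, 0.95),   # flag sites mismatched in 5-95% of covered reads
)

ref = adata.obs["Reference_strand"].cat.categories[0]
variable_sites = adata.var.index[adata.var[f"{ref}_variable_sequence_site"]]
print(len(variable_sites), "variable sites for", ref)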
smftools/preprocessing/calculate_consensus.py CHANGED
@@ -53,4 +53,4 @@ def calculate_consensus(
     else:
         adata.var[f"{reference}_consensus_across_samples"] = consensus_sequence_list

-    adata.uns[f"{reference}_consensus_sequence"] = consensus_sequence_list
+    adata.uns[f"{reference}_consensus_sequence"] = str(consensus_sequence_list)
smftools/preprocessing/calculate_read_modification_stats.py CHANGED
@@ -20,6 +20,7 @@ def calculate_read_modification_stats(
     force_redo: bool = False,
     valid_sites_only: bool = False,
     valid_site_suffix: str = "_valid_coverage",
+    smf_modality: str = "conversion",
 ) -> None:
     """Add methylation/deamination statistics for each read.

@@ -80,8 +81,12 @@ def calculate_read_modification_stats(
     for ref in references:
         ref_subset = adata[adata.obs[reference_column] == ref]
         for site_type in site_types:
+            site_subset = ref_subset[:, ref_subset.var[f"{ref}_{site_type}{valid_site_suffix}"]]
             logger.info("Iterating over %s_%s", ref, site_type)
-            observation_matrix = ref_subset.obsm[f"{ref}_{site_type}{valid_site_suffix}"]
+            if smf_modality == "native":
+                observation_matrix = site_subset.layers["binarized_methylation"]
+            else:
+                observation_matrix = site_subset.X
             total_positions_in_read = np.nansum(~np.isnan(observation_matrix), axis=1)
             total_positions_in_reference = observation_matrix.shape[1]
             fraction_valid_positions_in_read_vs_ref = (
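Together with the append_base_context change above, the net effect is that per-site observation matrices are no longer materialized into `obsm`; they are sliced on demand from a per-reference view, with the source layer chosen by the new `smf_modality` argument. A rough sketch of that access pattern (the reference and site-type names below are placeholders, and `adata` is assumed to exist):

ref, site_type = "my_reference_top", "GpC_site"        # placeholder reference / site type
ref_subset = adata[adata.obs["Reference_strand"] == ref]
site_subset = ref_subset[:, ref_subset.var[f"{ref}_{site_type}_valid_coverage"]]

smf_modality = "conversion"                            # or "native"
if smf_modality == "native":
    observation_matrix = site_subset.layers["binarized_methylation"]
else:
    observation_matrix = site_subset.X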
smftools/readwrite.py CHANGED
@@ -431,6 +431,8 @@ def safe_write_h5ad(adata, path, compression="gzip", backup=False, backup_dir=No
         "layers_skipped": [],
         "obsm_converted": [],
         "obsm_skipped": [],
+        "varm_converted": [],
+        "varm_skipped": [],
         "X_replaced_or_converted": None,
         "errors": [],
     }
@@ -605,10 +607,16 @@ def safe_write_h5ad(adata, path, compression="gzip", backup=False, backup_dir=No

     def _sanitize_layers_obsm(src_dict, which: str):
         """
-        Ensure arrays in layers/obsm are numeric and non-object dtype.
+        Ensure arrays in layers/obsm/varm are numeric and non-object dtype.
         Returns a cleaned dict suitable to pass into AnnData(...)
         If an entry is not convertible, it is backed up & skipped.
         """
+        report_map = {
+            "layers": ("layers_converted", "layers_skipped"),
+            "obsm": ("obsm_converted", "obsm_skipped"),
+            "varm": ("varm_converted", "varm_skipped"),
+        }
+        converted_key, skipped_key = report_map[which]
         cleaned = {}
         for k, v in src_dict.items():
             try:
@@ -618,9 +626,7 @@ def safe_write_h5ad(adata, path, compression="gzip", backup=False, backup_dir=No
                     arr_f = arr.astype(float)
                     cleaned[k] = arr_f
                     report_key = f"{which}.{k}"
-                    report["layers_converted"].append(
-                        report_key
-                    ) if which == "layers" else report["obsm_converted"].append(report_key)
+                    report[converted_key].append(report_key)
                     if verbose:
                         print(f" {which}.{k} object array coerced to float.")
                 except Exception:
@@ -628,18 +634,13 @@ def safe_write_h5ad(adata, path, compression="gzip", backup=False, backup_dir=No
                         arr_i = arr.astype(int)
                         cleaned[k] = arr_i
                         report_key = f"{which}.{k}"
-                        report["layers_converted"].append(
-                            report_key
-                        ) if which == "layers" else report["obsm_converted"].append(report_key)
+                        report[converted_key].append(report_key)
                         if verbose:
                             print(f" {which}.{k} object array coerced to int.")
                     except Exception:
                         if backup:
                             _backup(v, f"{which}_{k}_backup")
-                        if which == "layers":
-                            report["layers_skipped"].append(k)
-                        else:
-                            report["obsm_skipped"].append(k)
+                        report[skipped_key].append(k)
                         if verbose:
                             print(
                                 f" SKIPPING {which}.{k} (object dtype not numeric). Backed up: {backup}"
@@ -650,10 +651,7 @@ def safe_write_h5ad(adata, path, compression="gzip", backup=False, backup_dir=No
             except Exception as e:
                 if backup:
                     _backup(v, f"{which}_{k}_backup")
-                if which == "layers":
-                    report["layers_skipped"].append(k)
-                else:
-                    report["obsm_skipped"].append(k)
+                report[skipped_key].append(k)
                 msg = f" SKIPPING {which}.{k} due to conversion error: {e}"
                 report["errors"].append(msg)
                 if verbose:
@@ -693,6 +691,7 @@ def safe_write_h5ad(adata, path, compression="gzip", backup=False, backup_dir=No
     # ---------- sanitize layers and obsm ----------
     layers_src = getattr(adata, "layers", {})
     obsm_src = getattr(adata, "obsm", {})
+    varm_src = getattr(adata, "varm", {})

     try:
         layers_clean = _sanitize_layers_obsm(layers_src, "layers")
@@ -712,6 +711,15 @@ def safe_write_h5ad(adata, path, compression="gzip", backup=False, backup_dir=No
            print(msg)
        obsm_clean = {}

+    try:
+        varm_clean = _sanitize_layers_obsm(varm_src, "varm")
+    except Exception as e:
+        msg = f"Failed to sanitize varm: {e}"
+        report["errors"].append(msg)
+        if verbose:
+            print(msg)
+        varm_clean = {}
+
     # ---------- handle X ----------
     X_to_use = adata.X
     try:
@@ -747,7 +755,7 @@ def safe_write_h5ad(adata, path, compression="gzip", backup=False, backup_dir=No
         layers=layers_clean,
         uns=uns_clean,
         obsm=obsm_clean,
-        varm=getattr(adata, "varm", None),
+        varm=varm_clean,
     )

     # preserve names (as strings)
@@ -872,6 +880,16 @@ def safe_write_h5ad(adata, path, compression="gzip", backup=False, backup_dir=No
            }
        )

+    # varm
+    for k, v in adata_copy.varm.items():
+        meta_rows.append(
+            {
+                "kind": "varm",
+                "name": k,
+                "dtype": str(np.asarray(v).dtype),
+            }
+        )
+
     # uns
     for k, v in adata_copy.uns.items():
         meta_rows.append(
@@ -977,6 +995,7 @@ def safe_read_h5ad(
         "parsed_uns_json_keys": [],
         "restored_layers": [],
         "restored_obsm": [],
+        "restored_varm": [],
         "recategorized_obs": [],
         "recategorized_var": [],
         "missing_backups": [],
@@ -1215,7 +1234,7 @@ def safe_read_h5ad(
                    print(f"[safe_read_h5ad] restored adata.uns['{key}'] from {full}")

     # 5) Restore layers and obsm from backups if present
-    # expected backup names: layers_<name>_backup.pkl, obsm_<name>_backup.pkl
+    # expected backup names: layers_<name>_backup.pkl, obsm_<name>_backup.pkl, varm_<name>_backup.pkl
     if os.path.isdir(backup_dir):
         for fname in os.listdir(backup_dir):
             if fname.startswith("layers_") and fname.endswith("_backup.pkl"):
@@ -1248,6 +1267,21 @@ def safe_read_h5ad(
                            f"Failed to restore obsm['{obsm_name}'] from {full}: {e}"
                        )

+            if fname.startswith("varm_") and fname.endswith("_backup.pkl"):
+                varm_name = fname[len("varm_") : -len("_backup.pkl")]
+                full = os.path.join(backup_dir, fname)
+                val = _load_pickle_if_exists(full)
+                if val is not None:
+                    try:
+                        adata.varm[varm_name] = np.asarray(val)
+                        report["restored_varm"].append((varm_name, full))
+                        if verbose:
+                            print(f"[safe_read_h5ad] restored varm['{varm_name}'] from {full}")
+                    except Exception as e:
+                        report["errors"].append(
+                            f"Failed to restore varm['{varm_name}'] from {full}: {e}"
+                        )
+
     # 6) If restore_backups True but some expected backups missing, note them
     if restore_backups and os.path.isdir(backup_dir):
         # detect common expected names from obs/var/uns/layers in adata
@@ -1297,6 +1331,8 @@ def safe_read_h5ad(
            print("Restored layers:", report["restored_layers"])
        if report["restored_obsm"]:
            print("Restored obsm:", report["restored_obsm"])
+        if report["restored_varm"]:
+            print("Restored varm:", report["restored_varm"])
        if report["recategorized_obs"] or report["recategorized_var"]:
            print(
                "Recategorized columns (obs/var):",
smftools/schema/anndata_schema_v1.yaml CHANGED
@@ -60,6 +60,20 @@ stages:
       notes: "Mapping quality score."
       requires: []
       optional_inputs: []
+    reference_start:
+      dtype: "float"
+      created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
+      modified_by: []
+      notes: "0-based reference start position for the alignment."
+      requires: []
+      optional_inputs: []
+    reference_end:
+      dtype: "float"
+      created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
+      modified_by: []
+      notes: "0-based reference end position (exclusive) for the alignment."
+      requires: []
+      optional_inputs: []
     read_length_to_reference_length_ratio:
       dtype: "float"
       created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
@@ -179,7 +193,7 @@ stages:
   obs:
     leiden:
       dtype: "category"
-      created_by: "smftools.tools.calculate_umap"
+      created_by: "smftools.tools.calculate_leiden"
       modified_by: []
       notes: "Leiden cluster assignments."
       requires: [["obsm.X_umap"]]
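With `reference_start` and `reference_end` recorded per read (end exclusive, per the schema notes), the aligned reference span becomes a one-liner downstream; the output column name here is illustrative:

# number of reference bases covered by each alignment
adata.obs["aligned_reference_span"] = adata.obs["reference_end"] - adata.obs["reference_start"]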
smftools/tools/__init__.py CHANGED
@@ -3,6 +3,9 @@ from __future__ import annotations
 from importlib import import_module

 _LAZY_ATTRS = {
+    "calculate_leiden": "smftools.tools.calculate_leiden",
+    "calculate_nmf": "smftools.tools.calculate_nmf",
+    "calculate_sequence_cp_decomposition": "smftools.tools.tensor_factorization",
     "calculate_umap": "smftools.tools.calculate_umap",
     "cluster_adata_on_methylation": "smftools.tools.cluster_adata_on_methylation",
     "combine_layers": "smftools.tools.general_tools",
@@ -11,6 +14,7 @@ _LAZY_ATTRS = {
     "calculate_relative_risk_on_activity": "smftools.tools.position_stats",
     "compute_positionwise_statistics": "smftools.tools.position_stats",
     "calculate_row_entropy": "smftools.tools.read_stats",
+    "rolling_window_nn_distance": "smftools.tools.rolling_nn_distance",
     "subset_adata": "smftools.tools.subset_adata",
 }
smftools/tools/calculate_leiden.py ADDED
@@ -0,0 +1,57 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import numpy as np
+import pandas as pd
+
+from smftools.logging_utils import get_logger
+from smftools.optional_imports import require
+
+if TYPE_CHECKING:
+    import anndata as ad
+
+logger = get_logger(__name__)
+
+
+def calculate_leiden(
+    adata: "ad.AnnData",
+    *,
+    resolution: float = 0.1,
+    key_added: str = "leiden",
+    connectivities_key: str = "connectivities",
+) -> "ad.AnnData":
+    """Compute Leiden clusters from a connectivity graph.
+
+    Args:
+        adata: AnnData object with ``obsp[connectivities_key]`` set.
+        resolution: Resolution parameter for Leiden clustering.
+        key_added: Column name to store cluster assignments in ``adata.obs``.
+        connectivities_key: Key in ``adata.obsp`` containing a sparse adjacency matrix.
+
+    Returns:
+        Updated AnnData object with Leiden labels in ``adata.obs``.
+    """
+    if connectivities_key not in adata.obsp:
+        raise KeyError(f"Missing connectivities '{connectivities_key}' in adata.obsp.")
+
+    igraph = require("igraph", extra="cluster", purpose="Leiden clustering")
+    leidenalg = require("leidenalg", extra="cluster", purpose="Leiden clustering")
+
+    connectivities = adata.obsp[connectivities_key]
+    coo = connectivities.tocoo()
+    edges = list(zip(coo.row.tolist(), coo.col.tolist()))
+    graph = igraph.Graph(n=connectivities.shape[0], edges=edges, directed=False)
+    graph.es["weight"] = coo.data.tolist()
+
+    partition = leidenalg.find_partition(
+        graph,
+        leidenalg.RBConfigurationVertexPartition,
+        weights=graph.es["weight"],
+        resolution_parameter=resolution,
+    )
+
+    labels = np.array(partition.membership, dtype=str)
+    adata.obs[key_added] = pd.Categorical(labels)
+    logger.info("Stored Leiden clusters in adata.obs['%s'].", key_added)
+    return adata
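A toy end-to-end sketch of the new clustering helper. The hand-built connectivity graph below stands in for the `obsp["connectivities"]` matrix that a neighbors/UMAP step would normally provide, and the optional igraph/leidenalg extras must be installed:

import numpy as np
import anndata as ad
from scipy import sparse
from smftools.tools import calculate_leiden

adata = ad.AnnData(X=np.random.rand(6, 4))
rows = [0, 1, 1, 2, 3, 4, 4, 5]
cols = [1, 0, 2, 1, 4, 3, 5, 4]
adata.obsp["connectivities"] = sparse.csr_matrix(
    (np.ones(len(rows)), (rows, cols)), shape=(6, 6)
)

calculate_leiden(adata, resolution=0.5, key_added="leiden")
print(adata.obs["leiden"].value_counts())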