smftools 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. smftools/_version.py +1 -1
  2. smftools/cli/chimeric_adata.py +1563 -0
  3. smftools/cli/helpers.py +49 -7
  4. smftools/cli/hmm_adata.py +250 -32
  5. smftools/cli/latent_adata.py +773 -0
  6. smftools/cli/load_adata.py +78 -74
  7. smftools/cli/preprocess_adata.py +122 -58
  8. smftools/cli/recipes.py +26 -0
  9. smftools/cli/spatial_adata.py +74 -112
  10. smftools/cli/variant_adata.py +423 -0
  11. smftools/cli_entry.py +52 -4
  12. smftools/config/conversion.yaml +1 -1
  13. smftools/config/deaminase.yaml +3 -0
  14. smftools/config/default.yaml +85 -12
  15. smftools/config/experiment_config.py +146 -1
  16. smftools/constants.py +69 -0
  17. smftools/hmm/HMM.py +88 -0
  18. smftools/hmm/call_hmm_peaks.py +1 -1
  19. smftools/informatics/__init__.py +6 -0
  20. smftools/informatics/bam_functions.py +358 -8
  21. smftools/informatics/binarize_converted_base_identities.py +2 -89
  22. smftools/informatics/converted_BAM_to_adata.py +636 -175
  23. smftools/informatics/h5ad_functions.py +198 -2
  24. smftools/informatics/modkit_extract_to_adata.py +1007 -425
  25. smftools/informatics/sequence_encoding.py +72 -0
  26. smftools/logging_utils.py +21 -2
  27. smftools/metadata.py +1 -1
  28. smftools/plotting/__init__.py +26 -3
  29. smftools/plotting/autocorrelation_plotting.py +22 -4
  30. smftools/plotting/chimeric_plotting.py +1893 -0
  31. smftools/plotting/classifiers.py +28 -14
  32. smftools/plotting/general_plotting.py +62 -1583
  33. smftools/plotting/hmm_plotting.py +1670 -8
  34. smftools/plotting/latent_plotting.py +804 -0
  35. smftools/plotting/plotting_utils.py +243 -0
  36. smftools/plotting/position_stats.py +16 -8
  37. smftools/plotting/preprocess_plotting.py +281 -0
  38. smftools/plotting/qc_plotting.py +8 -3
  39. smftools/plotting/spatial_plotting.py +1134 -0
  40. smftools/plotting/variant_plotting.py +1231 -0
  41. smftools/preprocessing/__init__.py +4 -0
  42. smftools/preprocessing/append_base_context.py +18 -18
  43. smftools/preprocessing/append_mismatch_frequency_sites.py +187 -0
  44. smftools/preprocessing/append_sequence_mismatch_annotations.py +171 -0
  45. smftools/preprocessing/append_variant_call_layer.py +480 -0
  46. smftools/preprocessing/calculate_consensus.py +1 -1
  47. smftools/preprocessing/calculate_read_modification_stats.py +6 -1
  48. smftools/preprocessing/flag_duplicate_reads.py +4 -4
  49. smftools/preprocessing/invert_adata.py +1 -0
  50. smftools/readwrite.py +159 -99
  51. smftools/schema/anndata_schema_v1.yaml +15 -1
  52. smftools/tools/__init__.py +10 -0
  53. smftools/tools/calculate_knn.py +121 -0
  54. smftools/tools/calculate_leiden.py +57 -0
  55. smftools/tools/calculate_nmf.py +130 -0
  56. smftools/tools/calculate_pca.py +180 -0
  57. smftools/tools/calculate_umap.py +79 -80
  58. smftools/tools/position_stats.py +4 -4
  59. smftools/tools/rolling_nn_distance.py +872 -0
  60. smftools/tools/sequence_alignment.py +140 -0
  61. smftools/tools/tensor_factorization.py +217 -0
  62. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/METADATA +9 -5
  63. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/RECORD +66 -45
  64. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/WHEEL +0 -0
  65. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/entry_points.txt +0 -0
  66. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,773 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from pathlib import Path
5
+ from typing import Optional, Sequence, Tuple
6
+
7
+ import anndata as ad
8
+
9
+ from smftools.constants import LATENT_DIR, LOGGING_DIR, REFERENCE_STRAND, SEQUENCE_INTEGER_ENCODING
10
+ from smftools.logging_utils import get_logger, setup_logging
11
+
12
+ logger = get_logger(__name__)
13
+
14
+
15
def _build_mod_sites_var_filter_mask(
    adata: ad.AnnData,
    references: Sequence[str],
    cfg,
    smf_modality: str,
    deaminase: bool,
) -> "np.ndarray":
    """Return a boolean mask over ``adata.var`` selecting shared modification sites.

    For every reference, a position passes when it is flagged in at least one
    of that reference's ``{ref}_<base>_site`` columns *and* in its
    ``position_in_{ref}`` column. The per-reference masks are then AND-ed, so
    only positions that pass for all references survive.

    Parameters
    ----------
    adata
        Object whose ``var`` table holds the site and position columns.
    references
        Reference names used to build the column names.
    cfg
        Configuration object; only ``cfg.mod_target_bases`` is read.
    smf_modality
        Modality label; the value ``"direct"`` disables the C-site-only path.
    deaminase
        When true (and the modality is not ``"direct"``) only the
        ``{ref}_C_site`` column is consulted for each reference.

    Returns
    -------
    np.ndarray
        Boolean mask; all ``True`` with length ``adata.n_vars`` when
        ``references`` is empty.

    Raises
    ------
    KeyError
        If any required column is absent from ``adata.var``.
    """
    import numpy as np

    target_bases = _expand_mod_target_bases(cfg.mod_target_bases)
    c_sites_only = deaminase and smf_modality != "direct"

    per_reference = []
    for ref in references:
        if c_sites_only:
            site_columns = [f"{ref}_C_site"]
        else:
            site_columns = [f"{ref}_{base}_site" for base in target_bases]
        position_column = f"position_in_{ref}"

        # Validate every column up front so the error lists all that are absent.
        absent = [
            name for name in site_columns + [position_column] if name not in adata.var.columns
        ]
        if absent:
            raise KeyError(f"var_filters not found in adata.var: {absent}")

        site_masks = [np.asarray(adata.var[name].values, dtype=bool) for name in site_columns]
        if len(site_masks) == 1:
            is_mod_site = site_masks[0]
        else:
            # A position counts as a mod site if any target-base column flags it.
            is_mod_site = np.logical_or.reduce(site_masks)

        is_valid_position = np.asarray(adata.var[position_column].values, dtype=bool)
        per_reference.append(np.logical_and(is_mod_site, is_valid_position))

    if per_reference:
        return np.logical_and.reduce(per_reference)

    # No references to constrain on: keep every position.
    return np.ones(adata.n_vars, dtype=bool)
48
+
49
+
50
def _build_shared_valid_non_mod_sites_mask(
    adata: ad.AnnData,
    references: Sequence[str],
    cfg,
    smf_modality: str,
    deaminase: bool,
) -> "np.ndarray":
    """Return a boolean ``adata.var`` mask of shared valid positions that are not mod sites.

    A position is kept when it is flagged in ``position_in_{ref}`` for every
    reference and is not flagged in any ``{ref}_<base>_site`` column of any
    reference.

    Parameters
    ----------
    adata
        Object whose ``var`` table holds the site and position columns.
    references
        Reference names used to build the column names.
    cfg
        Configuration object; only ``cfg.mod_target_bases`` is read, and only
        when ``references`` is non-empty.
    smf_modality
        Modality label; the value ``"direct"`` disables the C-site-only path.
    deaminase
        When true (and the modality is not ``"direct"``) only the
        ``{ref}_C_site`` column is consulted for each reference.

    Returns
    -------
    np.ndarray
        Boolean mask over the variables.

    Raises
    ------
    KeyError
        If a required position or site column is absent from ``adata.var``.
    """
    import numpy as np

    valid_everywhere = _build_reference_position_mask(adata, references)
    if len(references) == 0:
        # Nothing to exclude when there are no references.
        return valid_everywhere

    target_bases = _expand_mod_target_bases(cfg.mod_target_bases)
    c_sites_only = deaminase and smf_modality != "direct"

    mod_site_masks = []
    for ref in references:
        if c_sites_only:
            site_columns = [f"{ref}_C_site"]
        else:
            site_columns = [f"{ref}_{base}_site" for base in target_bases]

        absent = [name for name in site_columns if name not in adata.var.columns]
        if absent:
            raise KeyError(f"var_filters not found in adata.var: {absent}")

        column_masks = [np.asarray(adata.var[name].values, dtype=bool) for name in site_columns]
        if len(column_masks) == 1:
            mod_site_masks.append(column_masks[0])
        else:
            mod_site_masks.append(np.logical_or.reduce(column_masks))

    if mod_site_masks:
        # Union across references: a mod site in any reference is excluded.
        is_mod_site = np.logical_or.reduce(mod_site_masks)
    else:
        is_mod_site = np.zeros(adata.n_vars, dtype=bool)

    return np.logical_and(valid_everywhere, np.logical_not(is_mod_site))
86
+
87
+
88
+ def _expand_mod_target_bases(mod_target_bases: Sequence[str]) -> list[str]:
89
+ """Ensure ambiguous GpC/CpG sites are included when requested."""
90
+ bases = list(mod_target_bases)
91
+ if any(base in {"GpC", "CpG"} for base in bases) and "ambiguous_GpC_CpG" not in bases:
92
+ bases.append("ambiguous_GpC_CpG")
93
+ return bases
94
+
95
+
96
+ def _build_reference_position_mask(
97
+ adata: ad.AnnData,
98
+ references: Sequence[str],
99
+ ) -> "np.ndarray":
100
+ """Build a boolean var mask for positions valid across references."""
101
+ import numpy as np
102
+
103
+ ref_masks = []
104
+ for ref in references:
105
+ position_col = f"position_in_{ref}"
106
+ if position_col not in adata.var.columns:
107
+ raise KeyError(f"var_filters not found in adata.var: {position_col}")
108
+ position_mask = np.asarray(adata.var[position_col].values, dtype=bool)
109
+ ref_masks.append(position_mask)
110
+
111
+ if not ref_masks:
112
+ return np.ones(adata.n_vars, dtype=bool)
113
+
114
+ return np.logical_and.reduce(ref_masks)
115
+
116
+
117
def latent_adata(
    config_path: str,
) -> Tuple[Optional[ad.AnnData], Optional[Path]]:
    """
    CLI entry point for the latent-representation stage.

    Invoked by: `smftools latent <config_path>`

    Steps:
        - Load the experiment config and resolve the per-stage AnnData paths.
        - Return early when the latent AnnData already exists and
          `force_redo_latent_analyses` is not set on the config.
        - Otherwise pick the first existing AnnData, in the order
          hmm, latent, spatial, chimeric, variant, pp_dedup, pp,
          and hand it to `latent_adata_core(...)`.

    Returns
    -------
    latent_adata : AnnData | None
        AnnData returned by `latent_adata_core`, or None when the stage was
        skipped or no input AnnData could be found.
    latent_adata_path : Path | None
        Path of the latent AnnData, or None when no input AnnData was found.
    """
    from ..readwrite import add_or_update_column_in_csv, safe_read_h5ad
    from .helpers import get_adata_paths, load_experiment_config

    cfg = load_experiment_config(config_path)
    paths = get_adata_paths(cfg)
    latent_path = paths.latent

    # Starting-point candidates in priority order; the first existing one wins.
    candidate_paths = (
        paths.hmm,
        latent_path,
        paths.spatial,
        paths.chimeric,
        paths.variant,
        paths.pp_dedup,
        paths.pp,
    )

    # An existing latent AnnData means the stage is done unless a redo is forced.
    redo_requested = getattr(cfg, "force_redo_latent_analyses", False)
    if not redo_requested and latent_path.exists():
        logger.info(f"Latent AnnData found: {latent_path}\nSkipping smftools latent")
        return None, latent_path

    source_path = next((path for path in candidate_paths if path.exists()), None)
    if source_path is None:
        logger.warning(
            "No suitable AnnData found for latent analyses (need at least preprocessed)."
        )
        return None, None

    start_adata, _ = safe_read_h5ad(source_path)

    adata_latent, written_path = latent_adata_core(
        adata=start_adata,
        cfg=cfg,
        paths=paths,
        source_adata_path=source_path,
        config_path=config_path,
    )
    return adata_latent, written_path
203
+
204
+
205
def latent_adata_core(
    adata: ad.AnnData,
    cfg,
    paths: AdataPaths,
    source_adata_path: Optional[Path] = None,
    config_path: Optional[str] = None,
) -> Tuple[ad.AnnData, Path]:
    """
    Core latent-representation analysis pipeline.

    Assumes:
        - `adata` is the AnnData chosen by the caller as the starting point.
        - `cfg` is the experiment configuration object.
        - `paths` exposes the latent output path as `paths.latent`.

    Does:
        - Logging setup (optionally to a timestamped log file).
        - Optional sample sheet load.
        - Optional inversion, then reindexing by reference.
        - PCA / KNN / UMAP / Leiden / NMF on two variable subsets.
        - CP (PARAFAC) sequence decompositions on three variable subsets,
          each with and without the non-negativity constraint.
        - Plots for every section; a section's plots are skipped when its
          output directory already exists and `force_redo_latent_analyses`
          is not set on `cfg`.
        - Save the latent AnnData to `paths.latent` if that file does not
          exist yet.

    Returns
    -------
    adata : AnnData
        The analyzed AnnData as returned by the last analysis step.
    adata_path : Path
        `paths.latent`, the location the AnnData is written to.
    """
    from datetime import datetime

    from ..metadata import record_smftools_metadata
    from ..plotting import (
        plot_cp_sequence_components,
        plot_embedding_grid,
        plot_nmf_components,
        plot_pca_components,
        plot_pca_explained_variance,
        plot_pca_grid,
        plot_umap_grid,
    )
    from ..preprocessing import (
        invert_adata,
        load_sample_sheet,
        reindex_references_adata,
    )
    from ..readwrite import make_dirs
    from ..tools import (
        calculate_knn,
        calculate_leiden,
        calculate_nmf,
        calculate_pca,
        calculate_sequence_cp_decomposition,
        calculate_umap,
    )
    from .helpers import write_gz_h5ad

    # -----------------------------
    # General setup
    # -----------------------------
    # One clock read so the date and time parts of the log name cannot disagree.
    timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
    log_level = getattr(logging, cfg.log_level.upper(), logging.INFO)

    latent_adata_path = paths.latent

    output_directory = Path(cfg.output_directory)
    latent_directory = output_directory / LATENT_DIR
    logging_directory = latent_directory / LOGGING_DIR

    make_dirs([output_directory, latent_directory])

    if cfg.emit_log_file:
        log_file = logging_directory / f"{timestamp}_log.log"
        make_dirs([logging_directory])
    else:
        log_file = None

    setup_logging(level=log_level, log_file=log_file, reconfigure=log_file is not None)

    smf_modality = cfg.smf_modality
    deaminase = smf_modality != "conversion"

    # The same flag gates plot regeneration for every section below.
    force_redo = getattr(cfg, "force_redo_latent_analyses", False)

    # -----------------------------
    # Optional sample sheet metadata
    # -----------------------------
    if getattr(cfg, "sample_sheet_path", None):
        load_sample_sheet(
            adata,
            cfg.sample_sheet_path,
            mapping_key_column=cfg.sample_sheet_mapping_column,
            as_category=True,
            force_reload=cfg.force_reload_sample_sheet,
        )

    # -----------------------------
    # Optional inversion along positions axis
    # -----------------------------
    if getattr(cfg, "invert_adata", False):
        adata = invert_adata(adata)

    # -----------------------------
    # Reindexing by reference
    # -----------------------------
    reindex_references_adata(
        adata,
        reference_col=cfg.reference_column,
        offsets=cfg.reindexing_offsets,
        new_col=cfg.reindexed_var_suffix,
    )

    references = adata.obs[cfg.reference_column].cat.categories

    latent_dir_dedup = latent_directory / "deduplicated"

    # Columns used to color every embedding plot.
    mod_site_layers = []
    for mod_base in cfg.mod_target_bases:
        mod_site_layers.extend(
            [f"Modified_{mod_base}_site_count", f"Fraction_{mod_base}_site_modified"]
        )
    plotting_layers = [cfg.sample_name_col_for_plotting, REFERENCE_STRAND] + mod_site_layers
    plotting_layers += cfg.umap_layers_to_plot

    mod_sites_mask = _build_mod_sites_var_filter_mask(
        adata=adata,
        references=references,
        cfg=cfg,
        smf_modality=smf_modality,
        deaminase=deaminase,
    )
    non_mod_sites_mask = _build_shared_valid_non_mod_sites_mask(
        adata=adata,
        references=references,
        cfg=cfg,
        smf_modality=smf_modality,
        deaminase=deaminase,
    )

    def _plots_needed(plot_dir, skip_message):
        """Return True (after creating plot_dir) unless existing plots are to be reused."""
        if plot_dir.is_dir() and not force_redo:
            logger.debug(skip_message)
            return False
        make_dirs([plot_dir])
        return True

    def _embedding_section(adata, *, prefix, subset, layer, var_mask):
        """Run PCA, KNN, UMAP, Leiden and NMF for one subset, then plot; returns adata."""
        pca_dir = latent_dir_dedup / f"{prefix}_pca_{subset}"
        umap_dir = latent_dir_dedup / f"{prefix}_umap_{subset}"
        nmf_dir = latent_dir_dedup / f"{prefix}_nmf_{subset}"

        adata = calculate_pca(
            adata,
            layer=layer,
            var_mask=var_mask,
            n_pcs=10,
            output_suffix=subset,
        )
        adata = calculate_knn(
            adata,
            obsm=f"X_pca_{subset}",
            knn_neighbors=15,
        )
        adata = calculate_umap(
            adata,
            obsm=f"X_pca_{subset}",
            output_suffix=subset,
        )
        calculate_leiden(
            adata, resolution=0.1, connectivities_key=f"connectivities_X_pca_{subset}"
        )
        adata = calculate_nmf(
            adata,
            layer=layer,
            var_mask=var_mask,
            n_components=2,
            suffix=subset,
        )

        if _plots_needed(
            pca_dir, f"{pca_dir} already exists. Skipping PCA calculation and plotting."
        ):
            plot_pca_grid(adata, subset=subset, color=plotting_layers, output_dir=pca_dir)
            plot_pca_explained_variance(adata, subset=subset, output_dir=pca_dir)
            plot_pca_components(adata, output_dir=pca_dir, suffix=subset)

        if _plots_needed(umap_dir, f"{umap_dir} already exists. Skipping UMAP plotting."):
            plot_umap_grid(adata, subset=subset, color=plotting_layers, output_dir=umap_dir)

        if _plots_needed(nmf_dir, f"{nmf_dir} already exists. Skipping NMF plotting."):
            plot_embedding_grid(
                adata, basis=f"nmf_{subset}", color=plotting_layers, output_dir=nmf_dir
            )
            plot_nmf_components(adata, output_dir=nmf_dir, suffix=subset)

        return adata

    def _cp_section(adata, *, prefix, subset, var_mask, var_mask_name, non_negative):
        """Run one CP decomposition of the integer-encoded sequences and plot; returns adata."""
        cp_sequence_dir = latent_dir_dedup / f"{prefix}_cp_{subset}"

        if SEQUENCE_INTEGER_ENCODING not in adata.layers:
            logger.warning(
                "Layer %s not found; skipping sequence integer encoding CP.",
                SEQUENCE_INTEGER_ENCODING,
            )
        else:
            adata = calculate_sequence_cp_decomposition(
                adata,
                layer=SEQUENCE_INTEGER_ENCODING,
                var_mask=var_mask,
                var_mask_name=var_mask_name,
                rank=2,
                embedding_key=f"X_cp_{subset}",
                components_key=f"H_cp_{subset}",
                uns_key=f"cp_{subset}",
                non_negative=non_negative,
            )

        # NOTE(review): plotting is attempted even when the layer was missing and
        # the decomposition was skipped; confirm the plot helpers tolerate that.
        if _plots_needed(
            cp_sequence_dir, f"{cp_sequence_dir} already exists. Skipping sequence CP plotting."
        ):
            plot_embedding_grid(
                adata,
                basis=f"cp_{subset}",
                color=plotting_layers,
                output_dir=cp_sequence_dir,
            )
            plot_cp_sequence_components(
                adata,
                output_dir=cp_sequence_dir,
                components_key=f"H_cp_{subset}",
                uns_key=f"cp_{subset}",
            )

        return adata

    # ============================================================
    # 1) PCA/UMAP/NMF at valid modified base site binary encodings shared across references
    # ============================================================
    adata = _embedding_section(
        adata,
        prefix="01",
        subset="shared_valid_mod_sites_binary_mod_arrays",
        layer=cfg.layer_for_umap_plotting,
        var_mask=mod_sites_mask,
    )

    # ============================================================
    # 2) PCA/UMAP/NMF at valid base site integer encodings shared across references
    # ============================================================
    valid_sites = _build_reference_position_mask(adata, references)
    adata = _embedding_section(
        adata,
        prefix="02",
        subset="shared_valid_ref_sites_integer_sequence_encodings",
        layer=SEQUENCE_INTEGER_ENCODING,
        var_mask=valid_sites,
    )

    # ============================================================
    # 3) CP PARAFAC factorization at shared mod sites
    # ============================================================
    adata = _cp_section(
        adata,
        prefix="03",
        subset="shared_valid_mod_sites_ohe_sequence_N_masked",
        var_mask=mod_sites_mask,
        var_mask_name="shared_reference_and_mod_site_positions",
        non_negative=False,
    )

    # ============================================================
    # 4) Non-negative CP PARAFAC factorization at shared mod sites
    # ============================================================
    # NOTE(review): this mask name differs from the one used in section 3 for
    # the same mask; kept as-is because it is a runtime string.
    adata = _cp_section(
        adata,
        prefix="04",
        subset="shared_valid_mod_sites_ohe_sequence_N_masked_non_negative",
        var_mask=mod_sites_mask,
        var_mask_name="shared_reference_mod_site_positions",
        non_negative=True,
    )

    # ============================================================
    # 5) CP PARAFAC factorization at shared non mod-site positions
    # ============================================================
    adata = _cp_section(
        adata,
        prefix="05",
        subset="non_mod_site_ohe_sequence_N_masked",
        var_mask=non_mod_sites_mask,
        var_mask_name="non_mod_site_reference_positions",
        non_negative=False,
    )

    # ============================================================
    # 6) Non-negative CP PARAFAC factorization at shared non mod-site positions
    # ============================================================
    adata = _cp_section(
        adata,
        prefix="06",
        subset="non_mod_site_ohe_sequence_N_masked_non_negative",
        var_mask=non_mod_sites_mask,
        var_mask_name="non_mod_site_reference_positions",
        non_negative=True,
    )

    # ============================================================
    # 7) CP PARAFAC factorization at all shared reference positions
    # ============================================================
    adata = _cp_section(
        adata,
        prefix="07",
        subset="full_ohe_sequence_N_masked",
        var_mask=_build_reference_position_mask(adata, references),
        var_mask_name="shared_reference_positions",
        non_negative=False,
    )

    # ============================================================
    # 8) Non-negative CP PARAFAC factorization at all shared reference positions
    # ============================================================
    adata = _cp_section(
        adata,
        prefix="08",
        subset="full_ohe_sequence_N_masked_non_negative",
        var_mask=_build_reference_position_mask(adata, references),
        var_mask_name="shared_reference_positions",
        non_negative=True,
    )

    # ============================================================
    # 9) Save latent AnnData
    # ============================================================
    # NOTE(review): an existing latent file is never overwritten, so a forced
    # redo recomputes and replots but does not persist the new AnnData.
    if not latent_adata_path.exists():
        logger.info("Saving latent analyzed AnnData")
        record_smftools_metadata(
            adata,
            step_name="latent",
            cfg=cfg,
            config_path=config_path,
            input_paths=[source_adata_path] if source_adata_path else None,
            output_path=latent_adata_path,
        )
        write_gz_h5ad(adata, latent_adata_path)

    return adata, latent_adata_path