smftools 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. smftools/_version.py +1 -1
  2. smftools/cli/chimeric_adata.py +1563 -0
  3. smftools/cli/helpers.py +49 -7
  4. smftools/cli/hmm_adata.py +250 -32
  5. smftools/cli/latent_adata.py +773 -0
  6. smftools/cli/load_adata.py +78 -74
  7. smftools/cli/preprocess_adata.py +122 -58
  8. smftools/cli/recipes.py +26 -0
  9. smftools/cli/spatial_adata.py +74 -112
  10. smftools/cli/variant_adata.py +423 -0
  11. smftools/cli_entry.py +52 -4
  12. smftools/config/conversion.yaml +1 -1
  13. smftools/config/deaminase.yaml +3 -0
  14. smftools/config/default.yaml +85 -12
  15. smftools/config/experiment_config.py +146 -1
  16. smftools/constants.py +69 -0
  17. smftools/hmm/HMM.py +88 -0
  18. smftools/hmm/call_hmm_peaks.py +1 -1
  19. smftools/informatics/__init__.py +6 -0
  20. smftools/informatics/bam_functions.py +358 -8
  21. smftools/informatics/binarize_converted_base_identities.py +2 -89
  22. smftools/informatics/converted_BAM_to_adata.py +636 -175
  23. smftools/informatics/h5ad_functions.py +198 -2
  24. smftools/informatics/modkit_extract_to_adata.py +1007 -425
  25. smftools/informatics/sequence_encoding.py +72 -0
  26. smftools/logging_utils.py +21 -2
  27. smftools/metadata.py +1 -1
  28. smftools/plotting/__init__.py +26 -3
  29. smftools/plotting/autocorrelation_plotting.py +22 -4
  30. smftools/plotting/chimeric_plotting.py +1893 -0
  31. smftools/plotting/classifiers.py +28 -14
  32. smftools/plotting/general_plotting.py +62 -1583
  33. smftools/plotting/hmm_plotting.py +1670 -8
  34. smftools/plotting/latent_plotting.py +804 -0
  35. smftools/plotting/plotting_utils.py +243 -0
  36. smftools/plotting/position_stats.py +16 -8
  37. smftools/plotting/preprocess_plotting.py +281 -0
  38. smftools/plotting/qc_plotting.py +8 -3
  39. smftools/plotting/spatial_plotting.py +1134 -0
  40. smftools/plotting/variant_plotting.py +1231 -0
  41. smftools/preprocessing/__init__.py +4 -0
  42. smftools/preprocessing/append_base_context.py +18 -18
  43. smftools/preprocessing/append_mismatch_frequency_sites.py +187 -0
  44. smftools/preprocessing/append_sequence_mismatch_annotations.py +171 -0
  45. smftools/preprocessing/append_variant_call_layer.py +480 -0
  46. smftools/preprocessing/calculate_consensus.py +1 -1
  47. smftools/preprocessing/calculate_read_modification_stats.py +6 -1
  48. smftools/preprocessing/flag_duplicate_reads.py +4 -4
  49. smftools/preprocessing/invert_adata.py +1 -0
  50. smftools/readwrite.py +159 -99
  51. smftools/schema/anndata_schema_v1.yaml +15 -1
  52. smftools/tools/__init__.py +10 -0
  53. smftools/tools/calculate_knn.py +121 -0
  54. smftools/tools/calculate_leiden.py +57 -0
  55. smftools/tools/calculate_nmf.py +130 -0
  56. smftools/tools/calculate_pca.py +180 -0
  57. smftools/tools/calculate_umap.py +79 -80
  58. smftools/tools/position_stats.py +4 -4
  59. smftools/tools/rolling_nn_distance.py +872 -0
  60. smftools/tools/sequence_alignment.py +140 -0
  61. smftools/tools/tensor_factorization.py +217 -0
  62. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/METADATA +9 -5
  63. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/RECORD +66 -45
  64. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/WHEEL +0 -0
  65. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/entry_points.txt +0 -0
  66. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/licenses/LICENSE +0 -0
smftools/schema/anndata_schema_v1.yaml
@@ -60,6 +60,20 @@ stages:
       notes: "Mapping quality score."
       requires: []
       optional_inputs: []
+    reference_start:
+      dtype: "float"
+      created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
+      modified_by: []
+      notes: "0-based reference start position for the alignment."
+      requires: []
+      optional_inputs: []
+    reference_end:
+      dtype: "float"
+      created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
+      modified_by: []
+      notes: "0-based reference end position (exclusive) for the alignment."
+      requires: []
+      optional_inputs: []
     read_length_to_reference_length_ratio:
       dtype: "float"
       created_by: "smftools.informatics.h5ad_functions.add_read_length_and_mapping_qc"
@@ -179,7 +193,7 @@ stages:
   obs:
     leiden:
       dtype: "category"
-      created_by: "smftools.tools.calculate_umap"
+      created_by: "smftools.tools.calculate_leiden"
       modified_by: []
       notes: "Leiden cluster assignments."
       requires: [["obsm.X_umap"]]
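
The new `reference_start` and `reference_end` columns use 0-based, half-open coordinates, so a read's reference span is `reference_end - reference_start`. A minimal sketch of filtering reads on these `obs` columns (the file name and region coordinates here are hypothetical):

import anndata as ad

adata = ad.read_h5ad("smf_experiment.h5ad")  # hypothetical pipeline output

# Half-open interval: span = end - start; both columns are stored as float.
adata.obs["reference_span"] = adata.obs["reference_end"] - adata.obs["reference_start"]

# Keep reads whose alignment fully covers a region of interest.
roi_reads = adata[(adata.obs["reference_start"] <= 100) & (adata.obs["reference_end"] >= 900)]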
smftools/tools/__init__.py
@@ -3,6 +3,11 @@ from __future__ import annotations
 from importlib import import_module
 
 _LAZY_ATTRS = {
+    "calculate_leiden": "smftools.tools.calculate_leiden",
+    "calculate_nmf": "smftools.tools.calculate_nmf",
+    "calculate_sequence_cp_decomposition": "smftools.tools.tensor_factorization",
+    "calculate_pca": "smftools.tools.calculate_pca",
+    "calculate_knn": "smftools.tools.calculate_knn",
     "calculate_umap": "smftools.tools.calculate_umap",
     "cluster_adata_on_methylation": "smftools.tools.cluster_adata_on_methylation",
     "combine_layers": "smftools.tools.general_tools",
@@ -11,6 +16,11 @@ _LAZY_ATTRS = {
     "calculate_relative_risk_on_activity": "smftools.tools.position_stats",
     "compute_positionwise_statistics": "smftools.tools.position_stats",
     "calculate_row_entropy": "smftools.tools.read_stats",
+    "align_sequences_with_mismatches": "smftools.tools.sequence_alignment",
+    "rolling_window_nn_distance": "smftools.tools.rolling_nn_distance",
+    "annotate_zero_hamming_segments": "smftools.tools.rolling_nn_distance",
+    "assign_per_read_segments_layer": "smftools.tools.rolling_nn_distance",
+    "select_top_segments_per_read": "smftools.tools.rolling_nn_distance",
     "subset_adata": "smftools.tools.subset_adata",
 }
 
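In `smftools/tools/__init__.py`, `_LAZY_ATTRS` maps each public name to the submodule that defines it, and names are resolved on first attribute access via a module-level `__getattr__` (PEP 562), keeping `import smftools.tools` cheap. A minimal sketch of such a hook, assuming the usual cache-in-globals pattern (the hunk does not show the actual implementation):

def __getattr__(name: str):
    # Look up the owning module for a lazily exported attribute.
    module_path = _LAZY_ATTRS.get(name)
    if module_path is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    # Import the submodule only now, then cache the attribute so
    # subsequent lookups bypass this hook entirely.
    attr = getattr(import_module(module_path), name)
    globals()[name] = attr
    return attr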
smftools/tools/calculate_knn.py
@@ -0,0 +1,121 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from smftools.logging_utils import get_logger
+from smftools.optional_imports import require
+
+if TYPE_CHECKING:
+    import anndata as ad
+
+logger = get_logger(__name__)
+
+
+def calculate_knn(
+    adata: "ad.AnnData",
+    obsm: str = "X_pca",
+    knn_neighbors: int = 100,
+    overwrite: bool = True,
+    threads: int = 8,
+    random_state: int | None = 0,
+    symmetrize: bool = True,
+) -> "ad.AnnData":
+    """Compute a KNN distance graph on an embedding in `adata.obsm[obsm]`.
+
+    Stores:
+        - adata.obsp[f"knn_distances_{obsm}"] : CSR sparse matrix of distances
+        - adata.uns[f"knn_distances_{obsm}"]["params"] : metadata
+
+    Args:
+        adata: AnnData object to update.
+        obsm: Key in `adata.obsm` to use as the embedding.
+        knn_neighbors: Target number of neighbors (will be clipped to n_obs - 1).
+        overwrite: If False and graph exists, do nothing.
+        threads: Parallel jobs for pynndescent.
+        random_state: Seed for pynndescent.
+        symmetrize: If True, make distance graph symmetric via min(A, A.T).
+
+    Returns:
+        Updated AnnData.
+    """
+    import numpy as np
+    import scipy.sparse as sp
+
+    if obsm not in adata.obsm:
+        raise KeyError(f"`{obsm}` not found in adata.obsm. Available: {list(adata.obsm.keys())}")
+
+    out_key = f"knn_distances_{obsm}"
+    if not overwrite and out_key in adata.obsp:
+        logger.info("KNN graph %r already exists and overwrite=False; skipping.", out_key)
+        return adata
+
+    data = adata.obsm[obsm]
+
+    if sp.issparse(data):
+        # Convert to float32 for pynndescent/numba friendliness if needed
+        data = data.astype(np.float32)
+        logger.info(
+            "Sparse embedding detected (%s). Proceeding without NaN check.", type(data).__name__
+        )
+    else:
+        data = np.asarray(data)
+        if np.isnan(data).any():
+            logger.warning("NaNs detected in %s; filling NaNs with 0.5 before KNN.", obsm)
+            data = np.nan_to_num(data, nan=0.5)
+        data = data.astype(np.float32, copy=False)
+
+    pynndescent = require("pynndescent", extra="umap", purpose="KNN graph computation")
+
+    n_obs = data.shape[0]
+    if n_obs < 2:
+        raise ValueError(f"Need at least 2 observations for KNN; got n_obs={n_obs}")
+
+    n_neighbors = min(int(knn_neighbors), n_obs - 1)
+    if n_neighbors < 1:
+        raise ValueError(f"Computed n_neighbors={n_neighbors}; check knn_neighbors and n_obs.")
+
+    logger.info(
+        "Running pynndescent KNN (obsm=%s, n_neighbors=%d, metric=euclidean, n_jobs=%d)",
+        obsm,
+        n_neighbors,
+        threads,
+    )
+
+    nn_index = pynndescent.NNDescent(
+        data,
+        n_neighbors=n_neighbors,
+        metric="euclidean",
+        random_state=random_state,
+        n_jobs=threads,
+    )
+    knn_indices, knn_dists = nn_index.neighbor_graph  # shapes: (n_obs, n_neighbors)
+
+    rows = np.repeat(np.arange(n_obs, dtype=np.int64), n_neighbors)
+    cols = knn_indices.reshape(-1).astype(np.int64, copy=False)
+    vals = knn_dists.reshape(-1).astype(np.float32, copy=False)
+
+    distances = sp.coo_matrix((vals, (rows, cols)), shape=(n_obs, n_obs)).tocsr()
+
+    # Optional: ensure diagonal is 0 and (optionally) symmetrize
+    distances.setdiag(0.0)
+    distances.eliminate_zeros()
+
+    if symmetrize:
+        # Keep the smaller directed distance for each undirected edge
+        distances = distances.minimum(distances.T)
+
+    adata.obsp[out_key] = distances
+    adata.uns[out_key] = {
+        "params": {
+            "obsm": obsm,
+            "n_neighbors_requested": int(knn_neighbors),
+            "n_neighbors_used": int(n_neighbors),
+            "method": "pynndescent",
+            "metric": "euclidean",
+            "random_state": random_state,
+            "n_jobs": int(threads),
+            "symmetrize": bool(symmetrize),
+        }
+    }
+
+    return adata
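
A usage sketch for the new KNN helper on synthetic data (the `pynndescent` optional dependency must be installed; the random embedding stands in for real PCA scores):

import anndata as ad
import numpy as np

from smftools.tools import calculate_knn  # resolved lazily via _LAZY_ATTRS

rng = np.random.default_rng(0)
adata = ad.AnnData(X=rng.random((500, 40)).astype(np.float32))
adata.obsm["X_pca"] = rng.normal(size=(500, 15)).astype(np.float32)

adata = calculate_knn(adata, obsm="X_pca", knn_neighbors=30, threads=4)
dist = adata.obsp["knn_distances_X_pca"]  # symmetric CSR distance graph
print(dist.shape, adata.uns["knn_distances_X_pca"]["params"]["n_neighbors_used"])

Because `symmetrize=True` by default, each undirected edge keeps the smaller of the two directed distances, so the stored graph is a valid symmetric distance matrix.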
smftools/tools/calculate_leiden.py
@@ -0,0 +1,57 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import numpy as np
+import pandas as pd
+
+from smftools.logging_utils import get_logger
+from smftools.optional_imports import require
+
+if TYPE_CHECKING:
+    import anndata as ad
+
+logger = get_logger(__name__)
+
+
+def calculate_leiden(
+    adata: "ad.AnnData",
+    *,
+    resolution: float = 0.1,
+    key_added: str = "leiden",
+    connectivities_key: str = "connectivities",
+) -> "ad.AnnData":
+    """Compute Leiden clusters from a connectivity graph.
+
+    Args:
+        adata: AnnData object with ``obsp[connectivities_key]`` set.
+        resolution: Resolution parameter for Leiden clustering.
+        key_added: Column name to store cluster assignments in ``adata.obs``.
+        connectivities_key: Key in ``adata.obsp`` containing a sparse adjacency matrix.
+
+    Returns:
+        Updated AnnData object with Leiden labels in ``adata.obs``.
+    """
+    if connectivities_key not in adata.obsp:
+        raise KeyError(f"Missing connectivities '{connectivities_key}' in adata.obsp.")
+
+    igraph = require("igraph", extra="cluster", purpose="Leiden clustering")
+    leidenalg = require("leidenalg", extra="cluster", purpose="Leiden clustering")
+
+    connectivities = adata.obsp[connectivities_key]
+    coo = connectivities.tocoo()
+    edges = list(zip(coo.row.tolist(), coo.col.tolist()))
+    graph = igraph.Graph(n=connectivities.shape[0], edges=edges, directed=False)
+    graph.es["weight"] = coo.data.tolist()
+
+    partition = leidenalg.find_partition(
+        graph,
+        leidenalg.RBConfigurationVertexPartition,
+        weights=graph.es["weight"],
+        resolution_parameter=resolution,
+    )
+
+    labels = np.array(partition.membership, dtype=str)
+    adata.obs[key_added] = pd.Categorical(labels)
+    logger.info("Stored Leiden clusters in adata.obs['%s'].", key_added)
+    return adata
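
`calculate_leiden` consumes a connectivity graph rather than raw distances, so something must populate `adata.obsp["connectivities"]` first. One common source is scanpy's `sc.pp.neighbors`, used below as an assumption (smftools may build the graph elsewhere in its own pipeline):

import scanpy as sc

from smftools.tools import calculate_leiden

# sc.pp.neighbors writes obsp["connectivities"] and obsp["distances"].
sc.pp.neighbors(adata, n_neighbors=15, use_rep="X_pca")

adata = calculate_leiden(adata, resolution=0.5, key_added="leiden")
print(adata.obs["leiden"].value_counts())

The weighted RBConfigurationVertexPartition mirrors scanpy's Leiden defaults, so resolution values should behave comparably; labels land in `adata.obs["leiden"]` as string categories.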
smftools/tools/calculate_nmf.py
@@ -0,0 +1,130 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Sequence
+
+import numpy as np
+
+from smftools.logging_utils import get_logger
+from smftools.optional_imports import require
+
+if TYPE_CHECKING:
+    import anndata as ad
+    import numpy as np
+
+logger = get_logger(__name__)
+
+
+def calculate_nmf(
+    adata: "ad.AnnData",
+    layer: str | None = "nan_half",
+    var_mask: "np.ndarray | Sequence[bool] | None" = None,
+    n_components: int = 2,
+    max_iter: int = 200,
+    random_state: int = 0,
+    overwrite: bool = True,
+    embedding_key: str = "X_nmf",
+    components_key: str = "H_nmf",
+    uns_key: str = "nmf",
+    suffix: str | None = None,
+) -> "ad.AnnData":
+    """Compute a low-dimensional NMF embedding.
+
+    Args:
+        adata: AnnData object to update.
+        layer: Layer name to use for NMF (``None`` uses ``adata.X``).
+        var_mask: Optional boolean mask to subset features.
+        n_components: Number of NMF components to compute.
+        max_iter: Maximum number of NMF iterations.
+        random_state: Random seed for the NMF initializer.
+        overwrite: Whether to recompute if the embedding already exists.
+        embedding_key: Key for the embedding in ``adata.obsm``.
+        components_key: Key for the components matrix in ``adata.varm``.
+        uns_key: Key for metadata stored in ``adata.uns``.
+
+    Returns:
+        anndata.AnnData: Updated AnnData object.
+    """
+    from scipy.sparse import issparse
+
+    require("sklearn", extra="ml-base", purpose="NMF calculation")
+    from sklearn.decomposition import NMF
+
+    if suffix:
+        embedding_key = f"{embedding_key}_{suffix}"
+        components_key = f"{components_key}_{suffix}"
+        uns_key = f"{uns_key}_{suffix}"
+
+    has_embedding = embedding_key in adata.obsm
+    has_components = components_key in adata.varm
+    if has_embedding and has_components and not overwrite:
+        logger.info("NMF embedding and components already present; skipping recomputation.")
+        return adata
+    if has_embedding and not has_components and not overwrite:
+        logger.info("NMF embedding present without components; recomputing to store components.")
+
+    subset_mask = None
+    if var_mask is not None:
+        subset_mask = np.asarray(var_mask, dtype=bool)
+        if subset_mask.ndim != 1 or subset_mask.shape[0] != adata.n_vars:
+            raise ValueError(
+                "var_mask must be a 1D boolean array with length matching adata.n_vars."
+            )
+        adata_subset = adata[:, subset_mask].copy()
+        logger.info(
+            "Subsetting adata: retained %s features based on filters %s",
+            adata_subset.shape[1],
+            "var_mask",
+        )
+    else:
+        adata_subset = adata.copy()
+        logger.info("No var_mask provided. Using all features.")
+
+    data = adata_subset.layers[layer] if layer else adata_subset.X
+    if issparse(data):
+        data = data.copy()
+        if data.data.size and np.isnan(data.data).any():
+            logger.warning("NaNs detected in sparse data, filling with 0.5 before NMF.")
+            data.data = np.nan_to_num(data.data, nan=0.5)
+        if data.data.size and (data.data < 0).any():
+            logger.warning("Negative values detected in sparse data, clipping to 0 for NMF.")
+            data.data[data.data < 0] = 0
+    else:
+        if np.isnan(data).any():
+            logger.warning("NaNs detected, filling with 0.5 before NMF.")
+            data = np.nan_to_num(data, nan=0.5)
+        if (data < 0).any():
+            logger.warning("Negative values detected, clipping to 0 for NMF.")
+            data = np.clip(data, a_min=0, a_max=None)
+
+    model = NMF(
+        n_components=n_components,
+        init="nndsvda",
+        max_iter=max_iter,
+        random_state=random_state,
+    )
+    embedding = model.fit_transform(data)
+    components = model.components_.T
+
+    if subset_mask is not None:
+        components_matrix = np.zeros((adata.shape[1], components.shape[1]))
+        components_matrix[subset_mask, :] = components
+    else:
+        components_matrix = components
+
+    adata.obsm[embedding_key] = embedding
+    adata.varm[components_key] = components_matrix
+    adata.uns[uns_key] = {
+        "n_components": n_components,
+        "max_iter": max_iter,
+        "random_state": random_state,
+        "layer": layer,
+        "var_mask_provided": var_mask is not None,
+        "components_key": components_key,
+    }
+
+    logger.info(
+        "Stored: adata.obsm['%s'] and adata.varm['%s']",
+        embedding_key,
+        components_key,
+    )
+    return adata
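
A usage sketch for `calculate_nmf` on synthetic non-negative data. The default layer name `"nan_half"` and the 0.5 fill value suggest a layer where missing methylation calls are encoded as 0.5; the layer and mask below are fabricated for illustration, and the `gpc` suffix is arbitrary:

import anndata as ad
import numpy as np

from smftools.tools import calculate_nmf

rng = np.random.default_rng(0)
adata = ad.AnnData(X=rng.random((200, 300)).astype(np.float32))
adata.layers["nan_half"] = adata.X.copy()

# Restrict the factorization to an informative subset of positions.
mask = adata.X.var(axis=0) > 0.05
adata = calculate_nmf(adata, layer="nan_half", var_mask=mask, n_components=4, suffix="gpc")

embedding = adata.obsm["X_nmf_gpc"]   # (n_obs, 4) per-read usage weights
components = adata.varm["H_nmf_gpc"]  # (n_vars, 4), zeros outside the mask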
smftools/tools/calculate_pca.py
@@ -0,0 +1,180 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Sequence
+
+from smftools.logging_utils import get_logger
+from smftools.optional_imports import require
+
+if TYPE_CHECKING:
+    import anndata as ad
+    import numpy as np
+
+logger = get_logger(__name__)
+
+
+def calculate_pca(
+    adata: "ad.AnnData",
+    layer: str | None = "nan_half",
+    var_mask: "np.ndarray | Sequence[bool] | None" = None,
+    n_pcs: int = 15,
+    overwrite: bool = True,
+    output_suffix: str | None = None,
+    fill_nan: float | None = 0.5,
+) -> "ad.AnnData":
+    """Compute PCA and store scores in `.obsm` and loadings in `.varm`."""
+
+    import numpy as np
+    import scipy.sparse as sp
+
+    obsm_output = f"X_pca_{output_suffix}" if output_suffix else "X_pca"
+    varm_output = f"PCs_{output_suffix}" if output_suffix else "PCs"
+
+    if not overwrite and obsm_output in adata.obsm and varm_output in adata.varm:
+        logger.info(
+            "PCA outputs already exist and overwrite=False; skipping (%s, %s).",
+            obsm_output,
+            varm_output,
+        )
+        return adata
+
+    # --- Build feature subset mask (over vars) ---
+    if var_mask is not None:
+        subset_mask = np.asarray(var_mask, dtype=bool)
+        if subset_mask.ndim != 1 or subset_mask.shape[0] != adata.n_vars:
+            raise ValueError(
+                "var_mask must be a 1D boolean array with length matching adata.n_vars."
+            )
+        n_vars_used = int(subset_mask.sum())
+        if n_vars_used == 0:
+            raise ValueError("var_mask retained 0 features.")
+        logger.info(
+            "Subsetting vars: retained %d / %d features from var_mask",
+            n_vars_used,
+            adata.n_vars,
+        )
+    else:
+        subset_mask = slice(None)
+        n_vars_used = adata.n_vars
+        logger.info("No var_mask provided; using all %d features.", adata.n_vars)
+
+    # --- Pull matrix view ---
+    if layer is None:
+        matrix = adata.X
+        layer_used = None
+    else:
+        if layer not in adata.layers:
+            raise KeyError(
+                f"Layer {layer!r} not found in adata.layers. Available: {list(adata.layers.keys())}"
+            )
+        matrix = adata.layers[layer]
+        layer_used = layer
+
+    matrix = matrix[:, subset_mask]  # slice view (sparse OK)
+
+    n_obs = matrix.shape[0]
+    if n_obs < 2:
+        raise ValueError(f"PCA requires at least 2 observations; got n_obs={n_obs}")
+    if n_vars_used < 1:
+        raise ValueError("PCA requires at least 1 feature.")
+
+    n_pcs_requested = int(n_pcs)
+    n_pcs_used = min(n_pcs_requested, n_obs, n_vars_used)
+    if n_pcs_used < 1:
+        raise ValueError(f"n_pcs_used became {n_pcs_used}; check inputs.")
+
+    # --- NaN handling (dense only; sparse usually won't store NaNs) ---
+    if not sp.issparse(matrix):
+        X = np.asarray(matrix, dtype=np.float32)
+        if fill_nan is not None and np.isnan(X).any():
+            logger.warning("NaNs detected; filling NaNs with %s before PCA.", fill_nan)
+            X = np.nan_to_num(X, nan=float(fill_nan))
+    else:
+        X = matrix  # keep sparse
+
+    # --- PCA ---
+    # Prefer sklearn's randomized PCA for speed on big matrices.
+    used_sklearn = False
+    try:
+        sklearn = require("sklearn", extra="ml", purpose="PCA computation")
+        from sklearn.decomposition import PCA, TruncatedSVD
+
+        if sp.issparse(X):
+            # TruncatedSVD works on sparse without centering; good approximation.
+            # If you *need* centered PCA on sparse, you'd need different machinery.
+            logger.info("Running TruncatedSVD (sparse) with n_components=%d", n_pcs_used)
+            model = TruncatedSVD(n_components=n_pcs_used, random_state=0)
+            scores = model.fit_transform(X)  # (n_obs, n_pcs)
+            loadings = model.components_.T  # (n_vars_used, n_pcs)
+            mean = None
+            explained_variance_ratio = getattr(model, "explained_variance_ratio_", None)
+        else:
+            logger.info(
+                "Running sklearn PCA with n_components=%d (svd_solver=randomized)", n_pcs_used
+            )
+            model = PCA(n_components=n_pcs_used, svd_solver="randomized", random_state=0)
+            scores = model.fit_transform(X)  # (n_obs, n_pcs)
+            loadings = model.components_.T  # (n_vars_used, n_pcs)
+            mean = model.mean_
+            explained_variance_ratio = model.explained_variance_ratio_
+
+        used_sklearn = True
+
+    except Exception as e:
+        # Fall back to a manual SVD (dense only)
+        if sp.issparse(X):
+            raise RuntimeError(
+                "Sparse input PCA fallback is not implemented without sklearn. "
+                "Install scikit-learn (extra 'ml') or densify upstream."
+            ) from e
+
+        import scipy.linalg as spla
+
+        logger.warning(
+            "sklearn PCA unavailable; falling back to full SVD (can be slow). Reason: %s", e
+        )
+        Xd = np.asarray(X, dtype=np.float64)
+        mean = Xd.mean(axis=0)
+        centered = Xd - mean
+        u, s, vt = spla.svd(centered, full_matrices=False)
+        u = u[:, :n_pcs_used]
+        s = s[:n_pcs_used]
+        vt = vt[:n_pcs_used]
+        scores = u * s
+        loadings = vt.T
+        explained_variance_ratio = None
+
+    # --- Store scores (obsm) ---
+    adata.obsm[obsm_output] = scores
+
+    # --- Store loadings (varm) with original var dimension ---
+    pc_matrix = np.zeros((adata.n_vars, n_pcs_used), dtype=np.float32)
+    if isinstance(subset_mask, slice):
+        pc_matrix[:, :] = loadings
+    else:
+        pc_matrix[subset_mask, :] = loadings.astype(np.float32, copy=False)
+
+    adata.varm[varm_output] = pc_matrix
+
+    # --- Metadata ---
+    adata.uns[obsm_output] = {
+        "params": {
+            "layer": layer_used,
+            "var_mask_provided": var_mask is not None,
+            "n_pcs_requested": n_pcs_requested,
+            "n_pcs_used": int(n_pcs_used),
+            "used_sklearn": used_sklearn,
+            "fill_nan": fill_nan,
+            "note_sparse": bool(sp.issparse(matrix)),
+        },
+        "explained_variance_ratio": explained_variance_ratio,
+        "mean": mean.tolist() if (mean is not None and isinstance(mean, np.ndarray)) else None,
+    }
+
+    logger.info(
+        "Stored PCA: adata.obsm[%s] (%s) and adata.varm[%s] (%s)",
+        obsm_output,
+        scores.shape,
+        varm_output,
+        pc_matrix.shape,
+    )
+    return adata
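
And a matching sketch for `calculate_pca` (dense input takes scikit-learn's randomized solver; without scikit-learn the SciPy full-SVD fallback runs instead; the `gpc` suffix is again arbitrary):

import anndata as ad
import numpy as np

from smftools.tools import calculate_pca

rng = np.random.default_rng(0)
adata = ad.AnnData(X=rng.random((300, 120)).astype(np.float32))
adata.layers["nan_half"] = adata.X.copy()

adata = calculate_pca(adata, layer="nan_half", n_pcs=10, output_suffix="gpc")

scores = adata.obsm["X_pca_gpc"]  # (300, 10) per-read scores
loadings = adata.varm["PCs_gpc"]  # (120, 10) per-position loadings
evr = adata.uns["X_pca_gpc"]["explained_variance_ratio"]
print(scores.shape, loadings.shape, evr[:3])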