sclab 0.3.2__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sclab/__init__.py CHANGED
@@ -6,4 +6,4 @@ __all__ = [
6
6
  "SCLabDashboard",
7
7
  ]
8
8
 
9
- __version__ = "0.3.2"
9
+ __version__ = "0.3.4"
@@ -63,7 +63,7 @@ class Integration(ProcessorStepBase):
63
63
 
64
64
  def function(
65
65
  self,
66
- use_rep: str,
66
+ use_rep: str | None,
67
67
  group_by: str,
68
68
  flavor: str,
69
69
  reference_batch: str | None,
@@ -71,6 +71,9 @@ class Integration(ProcessorStepBase):
71
71
  ):
72
72
  adata = self.parent.dataset.adata
73
73
 
74
+ if use_rep is None:
75
+ use_rep = "X"
76
+
74
77
  key_added = f"{use_rep}_{flavor}"
75
78
  kvargs = {
76
79
  "adata": adata,
@@ -106,25 +106,41 @@ class Preprocess(ProcessorStepBase):
106
106
  )
107
107
  pbar.update(10)
108
108
 
109
- sc.pp.highly_variable_genes(
110
- adata,
111
- layer=f"{layer}_log1p",
112
- flavor="seurat",
113
- batch_key=group_by,
114
- )
115
- hvg_seurat = adata.var["highly_variable"]
116
- sc.pp.highly_variable_genes(
117
- adata,
118
- layer=layer,
119
- flavor="seurat_v3_paper",
120
- batch_key=group_by,
121
- n_top_genes=hvg_seurat.sum(),
122
- )
123
- hvg_seurat_v3 = adata.var["highly_variable"]
109
+ if group_by is not None:
110
+ adata.var["highly_variable"] = False
111
+ for name, idx in adata.obs.groupby(group_by, observed=True).groups.items():
112
+ hvg_seurat = sc.pp.highly_variable_genes(
113
+ adata[idx],
114
+ layer=f"{layer}_log1p",
115
+ flavor="seurat",
116
+ inplace=False,
117
+ )["highly_variable"]
118
+
119
+ hvg_seurat_v3 = sc.pp.highly_variable_genes(
120
+ adata[idx],
121
+ layer=layer,
122
+ flavor="seurat_v3_paper",
123
+ n_top_genes=hvg_seurat.sum(),
124
+ inplace=False,
125
+ )["highly_variable"]
126
+
127
+ adata.var[f"highly_variable_{name}"] = hvg_seurat | hvg_seurat_v3
128
+ adata.var["highly_variable"] |= adata.var[f"highly_variable_{name}"]
129
+
130
+ else:
131
+ sc.pp.highly_variable_genes(adata, layer=f"{layer}_log1p", flavor="seurat")
132
+ hvg_seurat = adata.var["highly_variable"]
133
+
134
+ sc.pp.highly_variable_genes(
135
+ adata,
136
+ layer=layer,
137
+ flavor="seurat_v3_paper",
138
+ n_top_genes=hvg_seurat.sum(),
139
+ )
140
+ hvg_seurat_v3 = adata.var["highly_variable"]
141
+
142
+ adata.var["highly_variable"] = hvg_seurat | hvg_seurat_v3
124
143
 
125
- adata.var["highly_variable"] = hvg_seurat | hvg_seurat_v3
126
- adata.var["highly_variable_seurat"] = hvg_seurat
127
- adata.var["highly_variable_seurat_v3"] = hvg_seurat_v3
128
144
  pbar.update(10)
129
145
  pbar.update(10)
130
146
 
@@ -2,8 +2,11 @@ from ._cca_integrate import cca_integrate, cca_integrate_pair
2
2
  from ._filter_obs import filter_obs
3
3
  from ._harmony_integrate import harmony_integrate
4
4
  from ._normalize_weighted import normalize_weighted
5
+ from ._pca import pca
6
+ from ._preprocess import preprocess
7
+ from ._qc import qc
5
8
  from ._subset import subset_obs, subset_var
6
- from ._transfer_metadata import transfer_metadata
9
+ from ._transfer_metadata import propagate_metadata, transfer_metadata
7
10
  from ._transform import pool_neighbors
8
11
 
9
12
  __all__ = [
@@ -12,7 +15,11 @@ __all__ = [
12
15
  "filter_obs",
13
16
  "harmony_integrate",
14
17
  "normalize_weighted",
18
+ "pca",
15
19
  "pool_neighbors",
20
+ "preprocess",
21
+ "propagate_metadata",
22
+ "qc",
16
23
  "subset_obs",
17
24
  "subset_var",
18
25
  "transfer_metadata",
sclab/preprocess/_cca.py CHANGED
@@ -1,24 +1,31 @@
1
1
  import logging
2
+ import os
2
3
  from typing import Literal
3
4
 
4
5
  import numpy as np
6
+ from joblib import Parallel, delayed
5
7
  from numpy import matrix
6
8
  from numpy.typing import NDArray
7
9
  from scipy.linalg import svd
8
10
  from scipy.sparse import csc_matrix, csr_matrix, issparse
11
+ from scipy.sparse import vstack as sparse_vstack
9
12
  from scipy.sparse.linalg import svds
10
13
  from sklearn.utils.extmath import randomized_svd
11
14
 
12
15
  logger = logging.getLogger(__name__)
13
16
 
14
17
 
18
+ N_CPUS = os.cpu_count()
19
+
20
+
15
21
  def cca(
16
22
  X: NDArray | csr_matrix | csc_matrix,
17
23
  Y: NDArray | csr_matrix | csc_matrix,
18
24
  n_components=None,
19
- svd_solver: Literal["full", "partial", "randomized"] = "partial",
25
+ svd_solver: Literal["full", "partial", "randomized"] = "randomized",
20
26
  normalize: bool = False,
21
27
  random_state=42,
28
+ n_jobs: int = N_CPUS,
22
29
  ) -> tuple[NDArray, NDArray, NDArray]:
23
30
  """
24
31
  CCA-style integration for two single-cell matrices with unequal numbers of cells.
@@ -50,7 +57,7 @@ def cca(
50
57
  k = n_components or min(n1, n2)
51
58
 
52
59
  if issparse(X):
53
- C = _cross_covariance_sparse(X, Y)
60
+ C = _cross_covariance_sparse(X, Y, n_jobs=n_jobs)
54
61
  else:
55
62
  C = _cross_covariance_dense(X, Y)
56
63
 
@@ -103,7 +110,7 @@ def _svd_decomposition(
103
110
  return Uc, s, Vct
104
111
 
105
112
 
106
- def _cross_covariance_sparse(X: csr_matrix, Y: csr_matrix) -> NDArray:
113
+ def _cross_covariance_sparse(X: csr_matrix, Y: csr_matrix, n_jobs=N_CPUS) -> NDArray:
107
114
  _, p1 = X.shape
108
115
  _, p2 = Y.shape
109
116
  if p1 != p2:
@@ -118,7 +125,7 @@ def _cross_covariance_sparse(X: csr_matrix, Y: csr_matrix) -> NDArray:
118
125
  mux: matrix = X.mean(axis=0)
119
126
  muy: matrix = Y.mean(axis=0)
120
127
 
121
- XYt: csr_matrix = X.dot(Y.T)
128
+ XYt: csr_matrix = _spmm_parallel(X, Y.T, n_jobs=n_jobs)
122
129
  Xmuyt: matrix = X.dot(muy.T)
123
130
  muxYt: matrix = Y.dot(mux.T).T
124
131
  muxmuyt: float = (mux @ muy.T)[0, 0]
@@ -152,3 +159,18 @@ def _dense_scale(A: NDArray) -> NDArray:
152
159
  A = np.asarray(A)
153
160
  eps = np.finfo(A.dtype).eps
154
161
  return A / (A.std(axis=0, ddof=1, keepdims=True) + eps)
162
+
163
+
164
+ def _spmm_chunk(A_csr, X, start, stop):
165
+ return A_csr[start:stop, :] @ X
166
+
167
+
168
def _spmm_parallel(A_csr: csr_matrix, X_csc: csc_matrix, n_jobs=N_CPUS):
    """Compute ``A_csr @ X_csc`` by row blocks evaluated in parallel.

    The rows of ``A_csr`` are split into contiguous blocks, each block is
    multiplied by ``X_csc`` in a joblib worker, and the partial products are
    stacked back together in row order.

    Parameters
    ----------
    A_csr
        Left operand; sliced by rows.
    X_csc
        Right operand, passed unchanged to every worker.
    n_jobs
        Requested number of workers. Follows the joblib convention, so
        ``None`` and negative values (e.g. ``-1`` for all CPUs) are accepted.

    Returns
    -------
    The stacked product, as returned by ``scipy.sparse.vstack`` (a sparse
    matrix).
    """
    from joblib import effective_n_jobs

    n_rows = A_csr.shape[0]

    # Resolve None / negative requests to a concrete positive worker count;
    # the module default comes from os.cpu_count(), which may be None.
    n_workers = effective_n_jobs(n_jobs)

    # Never create more blocks than there are rows (empty blocks only add
    # process overhead), but always keep at least one so that an empty
    # input still yields a correctly shaped empty product.
    n_chunks = max(1, min(n_workers, n_rows))

    bounds = np.linspace(0, n_rows, n_chunks + 1, dtype=int)
    partial_products = Parallel(n_jobs=n_chunks, prefer="processes")(
        delayed(_spmm_chunk)(A_csr, X_csc, start, stop)
        for start, stop in zip(bounds[:-1], bounds[1:])
    )
    return sparse_vstack(partial_products)
@@ -13,8 +13,8 @@ def cca_integrate(
13
13
  reference_batch: str | list[str] | None = None,
14
14
  mask_var: str | None = None,
15
15
  n_components: int = 30,
16
- svd_solver: str = "partial",
17
- normalize: bool = False,
16
+ svd_solver: str = "randomized",
17
+ normalize: bool = True,
18
18
  random_state: int | None = None,
19
19
  ):
20
20
  n_groups = adata.obs[key].nunique()
@@ -46,8 +46,8 @@ def cca_integrate_pair(
46
46
  adjusted_basis: str | None = None,
47
47
  mask_var: str | None = None,
48
48
  n_components: int = 30,
49
- svd_solver: str = "partial",
50
- normalize: bool = False,
49
+ svd_solver: str = "randomized",
50
+ normalize: bool = True,
51
51
  random_state: int | None = None,
52
52
  ):
53
53
  if basis is None:
@@ -9,6 +9,7 @@ def normalize_weighted(
9
9
  adata: AnnData,
10
10
  target_scale: float | None = None,
11
11
  batch_key: str | None = None,
12
+ q: float = 0.99,
12
13
  ) -> None:
13
14
  if batch_key is not None:
14
15
  for _, idx in adata.obs.groupby(batch_key, observed=True).groups.items():
@@ -22,6 +23,8 @@ def normalize_weighted(
22
23
 
23
24
  return
24
25
 
26
+ target_scale = None
27
+
25
28
  X: csr_matrix
26
29
  Y: csr_matrix
27
30
  Z: csr_matrix
@@ -38,6 +41,7 @@ def normalize_weighted(
38
41
  Y.eliminate_zeros()
39
42
  Y.data = -Y.data * np.log(Y.data)
40
43
  entropy = Y.sum(axis=0)
44
+ entropy[:, entropy.A1 < np.quantile(entropy.A1, q)] *= 0.0
41
45
 
42
46
  Z = X.multiply(entropy)
43
47
  Z = Z.tocsr()
@@ -48,7 +52,7 @@ def normalize_weighted(
48
52
  "ignore", category=RuntimeWarning, message="divide by zero"
49
53
  )
50
54
  scale = Z.sum(axis=1)
51
- Z = Z.multiply(1 / scale)
55
+ Z = X.multiply(1 / scale)
52
56
  Z = Z.tocsr()
53
57
 
54
58
  if target_scale is None:
@@ -0,0 +1,51 @@
1
+ from anndata import AnnData
2
+
3
+
4
def pca(
    adata: AnnData,
    layer: str | None = None,
    n_comps: int = 30,
    mask_var: str | None = None,
    batch_key: str | None = None,
    reference_batch: str | None = None,
    zero_center: bool = False,
):
    """Run PCA with scanpy and store loadings-based scores in ``adata``.

    The principal components are fitted with ``sc.pp.pca`` (``arpack``
    solver), either on all observations or, when ``reference_batch`` is
    given, only on the observations of that batch. ``adata.obsm["X_pca"]``
    is then set to the product of the data matrix with the loadings.

    Parameters
    ----------
    adata
        Annotated data matrix; modified in place.
    layer
        Layer to fit and project. ``None`` uses ``adata.X``.
    n_comps
        Number of components requested from ``sc.pp.pca``.
    mask_var
        Column of ``adata.var`` forwarded to ``sc.pp.pca`` as ``mask_var``.
    batch_key
        Column of ``adata.obs`` holding batch labels. Required when
        ``reference_batch`` is given.
    reference_batch
        If set, fit the components on this batch only and project every
        observation onto them.
    zero_center
        If true, subtract the per-component mean from the stored scores.

    Raises
    ------
    ValueError
        If ``reference_batch`` is given without ``batch_key``.
    """
    import scanpy as sc

    if reference_batch and batch_key is None:
        raise ValueError("`batch_key` is required when `reference_batch` is given")

    pca_kwargs = dict(
        n_comps=n_comps,
        layer=layer,
        mask_var=mask_var,
        svd_solver="arpack",
    )

    if reference_batch:
        obs_mask = adata.obs[batch_key] == reference_batch
        adata_ref = adata[obs_mask].copy()

        if mask_var == "highly_variable":
            # Recompute the highly variable genes on the reference batch
            # alone: union of the "seurat" selection (log1p layer) and a
            # "seurat_v3_paper" selection of the same size.
            sc.pp.highly_variable_genes(
                adata_ref, layer=f"{layer if layer else 'X'}_log1p", flavor="seurat"
            )
            hvg_seurat = adata_ref.var["highly_variable"]
            sc.pp.highly_variable_genes(
                adata_ref,
                layer=layer,
                flavor="seurat_v3_paper",
                n_top_genes=hvg_seurat.sum(),
            )
            hvg_seurat_v3 = adata_ref.var["highly_variable"]
            adata_ref.var["highly_variable"] = hvg_seurat | hvg_seurat_v3

        sc.pp.pca(adata_ref, **pca_kwargs)
        uns_pca = adata_ref.uns["pca"]
        uns_pca["reference_batch"] = reference_batch
        PCs = adata_ref.varm["PCs"]

        adata.uns["pca"] = uns_pca
        adata.varm["PCs"] = PCs
    else:
        sc.pp.pca(adata, **pca_kwargs)
        PCs = adata.varm["PCs"]

    # Project the same matrix the components were fitted on; using adata.X
    # unconditionally would mix data sources whenever ``layer`` is set.
    data = adata.X if layer is None else adata.layers[layer]
    adata.obsm["X_pca"] = data.dot(PCs)

    if zero_center:
        adata.obsm["X_pca"] -= adata.obsm["X_pca"].mean(axis=0, keepdims=True)
@@ -0,0 +1,155 @@
1
+ import warnings
2
+ from typing import Literal
3
+
4
+ import numpy as np
5
+ from anndata import AnnData, ImplicitModificationWarning
6
+ from tqdm.auto import tqdm
7
+
8
+
9
def preprocess(
    adata: AnnData,
    counts_layer: str = "counts",
    group_by: str | None = None,
    min_cells: int = 5,
    min_genes: int = 5,
    compute_hvg: bool = True,
    regress_total_counts: bool = False,
    regress_n_genes: bool = False,
    normalization_method: Literal["library", "weighted", "none"] = "library",
    target_scale: float = 1e4,
    weighted_norm_quantile: float = 0.9,
    log1p: bool = True,
    scale: bool = True,
):
    """Run the standard preprocessing pipeline on ``adata`` in place.

    Steps, in order: make sure the raw counts and their log1p transform are
    available as layers, filter cells and genes, compute QC metrics,
    optionally flag highly variable genes, normalise, optionally
    log-transform, optionally regress out covariates and optionally scale.
    The processed matrix is left in ``adata.X`` (cast to ``float32``) and is
    also stored in a layer whose name is ``counts_layer`` followed by one
    suffix per applied step (``_normt``/``_normw``, ``_log1p``, ``_regr``,
    ``_scale``).

    Parameters
    ----------
    adata
        Annotated data matrix; modified in place.
    counts_layer
        Layer holding raw counts; created from ``adata.X`` if missing.
    group_by
        Optional ``adata.obs`` column. When given, highly variable genes and
        scaling are computed per group, and it is forwarded to
        ``normalize_weighted`` as ``batch_key``.
    min_cells, min_genes
        Thresholds forwarded to ``sc.pp.filter_genes`` / ``sc.pp.filter_cells``.
    compute_hvg
        Whether to compute ``adata.var["highly_variable"]``.
    regress_total_counts, regress_n_genes
        Covariates to pass to ``sc.pp.regress_out``.
    normalization_method
        ``"library"`` uses ``sc.pp.normalize_total``, ``"weighted"`` uses
        ``normalize_weighted``; any other value skips normalisation.
    target_scale
        Target sum / scale forwarded to the normalisation routine.
    weighted_norm_quantile
        Forwarded to ``normalize_weighted`` as ``q``.
    log1p
        Whether to apply ``sc.pp.log1p``.
    scale
        Whether to apply ``sc.pp.scale`` (without zero-centering).
    """
    import scanpy as sc

    from ._normalize_weighted import normalize_weighted

    log_counts_layer = f"{counts_layer}_log1p"
    qc_kwargs = dict(percent_top=None, log1p=False, inplace=True)
    # One entry per applied step; joined onto counts_layer to name the output.
    suffixes = []

    with tqdm(total=100, bar_format="{percentage:3.0f}%|{bar}|") as pbar:
        # --- make sure the counts and log1p(counts) layers exist ----------
        if counts_layer not in adata.layers:
            adata.layers[counts_layer] = adata.X.copy()

        if log_counts_layer not in adata.layers:
            counts_copy = adata.layers[counts_layer].copy()
            adata.layers[log_counts_layer] = sc.pp.log1p(counts_copy)
        pbar.update(10)

        # --- filtering and QC metrics, always starting from raw counts ----
        adata.X = adata.layers[counts_layer].copy()
        sc.pp.calculate_qc_metrics(adata, **qc_kwargs)
        sc.pp.filter_cells(adata, min_genes=min_genes)
        sc.pp.filter_genes(adata, min_cells=min_cells)
        pbar.update(10)

        # Recompute so the metrics describe the filtered data.
        sc.pp.calculate_qc_metrics(adata, **qc_kwargs)
        pbar.update(10)

        # --- highly variable genes ----------------------------------------
        if compute_hvg:
            if group_by is None:
                sc.pp.highly_variable_genes(
                    adata, layer=log_counts_layer, flavor="seurat"
                )
                hvg_seurat = adata.var["highly_variable"]

                sc.pp.highly_variable_genes(
                    adata,
                    layer=counts_layer,
                    flavor="seurat_v3_paper",
                    n_top_genes=hvg_seurat.sum(),
                )
                hvg_seurat_v3 = adata.var["highly_variable"]

                adata.var["highly_variable"] = hvg_seurat | hvg_seurat_v3
            else:
                # Per-group selection; the global flag is the union over groups.
                adata.var["highly_variable"] = False
                obs_groups = adata.obs.groupby(group_by, observed=True).groups
                for name, idx in obs_groups.items():
                    seurat_table = sc.pp.highly_variable_genes(
                        adata[idx],
                        layer=log_counts_layer,
                        flavor="seurat",
                        inplace=False,
                    )
                    hvg_seurat = seurat_table["highly_variable"]

                    seurat_v3_table = sc.pp.highly_variable_genes(
                        adata[idx],
                        layer=counts_layer,
                        flavor="seurat_v3_paper",
                        n_top_genes=hvg_seurat.sum(),
                        inplace=False,
                    )
                    hvg_seurat_v3 = seurat_v3_table["highly_variable"]

                    group_column = f"highly_variable_{name}"
                    adata.var[group_column] = hvg_seurat | hvg_seurat_v3
                    adata.var["highly_variable"] |= adata.var[group_column]

        pbar.update(10)
        pbar.update(10)

        # --- normalisation ------------------------------------------------
        if normalization_method == "library":
            suffixes.append("_normt")
            sc.pp.normalize_total(adata, target_sum=target_scale)
        elif normalization_method == "weighted":
            suffixes.append("_normw")
            normalize_weighted(
                adata,
                target_scale=target_scale,
                batch_key=group_by,
                q=weighted_norm_quantile,
            )

        pbar.update(10)
        pbar.update(10)

        # --- log transform ------------------------------------------------
        if log1p:
            suffixes.append("_log1p")
            # Drop any previous log1p record before transforming again.
            adata.uns.pop("log1p", None)
            sc.pp.log1p(adata)
        pbar.update(10)

        # --- regress out covariates ---------------------------------------
        covariates = []
        if regress_n_genes:
            covariates.append("n_genes_by_counts")

        if regress_total_counts:
            if log1p:
                adata.obs["log1p_total_counts"] = np.log1p(adata.obs["total_counts"])
                covariates.append("log1p_total_counts")
            else:
                covariates.append("total_counts")

        if covariates:
            suffixes.append("_regr")
            sc.pp.regress_out(adata, keys=covariates, n_jobs=1)
        pbar.update(10)

        # --- scaling ------------------------------------------------------
        if scale:
            suffixes.append("_scale")
            if group_by is None:
                sc.pp.scale(adata, zero_center=False)
            else:
                obs_groups = adata.obs.groupby(group_by, observed=True).groups
                for idx in obs_groups.values():
                    with warnings.catch_warnings():
                        # Writing through the view is intentional here.
                        warnings.filterwarnings(
                            "ignore",
                            category=ImplicitModificationWarning,
                            message="Modifying `X` on a view results in data being overridden",
                        )
                        adata[idx].X = sc.pp.scale(adata[idx].X, zero_center=False)

        new_layer = counts_layer + "".join(suffixes)
        adata.layers[new_layer] = adata.X.copy()

        pbar.update(10)

    adata.X = adata.X.astype(np.float32)
@@ -0,0 +1,38 @@
1
+ import numpy as np
2
+ from anndata import AnnData
3
+
4
+
5
def qc(
    adata: AnnData,
    counts_layer: str = "counts",
    min_counts: int = 50,
    min_genes: int = 5,
    min_cells: int = 5,
    max_rank: int = 0,
):
    """Filter ``adata`` in place on basic count-based quality criteria.

    Observations whose total count is below ``min_counts`` are removed, then
    ``sc.pp.filter_cells`` / ``sc.pp.filter_genes`` are applied, QC metrics
    are computed and ``adata.obs["barcode_rank"]`` is set to the descending
    rank of ``total_counts``. All of this is evaluated on the counts layer;
    the matrix that was in ``adata.X`` on entry is put back afterwards.

    Parameters
    ----------
    adata
        Annotated data matrix; modified in place.
    counts_layer
        Layer holding raw counts; created from ``adata.X`` if missing.
    min_counts
        Minimum total count for an observation to be kept.
    min_genes
        Forwarded to ``sc.pp.filter_cells``.
    min_cells
        Forwarded to ``sc.pp.filter_genes``.
    max_rank
        If positive, keep only observations whose barcode rank is at most
        ``max_rank``. ``0`` disables the rank filter.
    """
    import scanpy as sc

    if counts_layer not in adata.layers:
        adata.layers[counts_layer] = adata.X.copy()

    # Stash the current X in a temporary layer so it is subset together with
    # the data while the QC steps operate on the raw counts.
    adata.layers["qc_tmp_current_X"] = adata.X
    try:
        adata.X = adata.layers[counts_layer].copy()
        rowsums = np.asarray(adata.X.sum(axis=1)).squeeze()

        obs_idx = adata.obs_names[rowsums >= min_counts]
        adata._inplace_subset_obs(obs_idx)

        sc.pp.calculate_qc_metrics(adata, percent_top=None, log1p=False, inplace=True)

        sc.pp.filter_cells(adata, min_genes=min_genes)
        sc.pp.filter_genes(adata, min_cells=min_cells)
        sc.pp.calculate_qc_metrics(adata, percent_top=None, log1p=False, inplace=True)
        adata.obs["barcode_rank"] = adata.obs["total_counts"].rank(ascending=False)
    finally:
        # Restore the original X even if a QC step failed, so the temporary
        # layer never leaks and X is not left pointing at the counts.
        adata.X = adata.layers.pop("qc_tmp_current_X")

    if max_rank > 0:
        ranks = adata.obs["barcode_rank"]
        # Ranks start at 1, so the bound is inclusive: max_rank=N keeps the
        # N top-ranked barcodes (a strict "<" would drop the N-th one and
        # remove everything for max_rank=1).
        keep = ranks.loc[ranks <= max_rank].index
        adata._inplace_subset_obs(keep)
@@ -0,0 +1,116 @@
1
+ import numpy as np
2
+ from anndata import AnnData
3
+ from numpy.typing import NDArray
4
+
5
+
6
def rpca(
    adata: AnnData,
    key: str,
    *,
    basis: str = "X",
    adjusted_basis: str | None = None,
    reference_batch: str | list[str] | None = None,
    mask_var: str | None = None,
    n_components: int = 30,
    min_variance_ratio: float = 0.0005,
    svd_solver: str = "arpack",
    normalize: bool = True,
):
    """Project all observations onto the PCA space of each reference group.

    For every group of ``adata.obs[key]`` listed in ``reference_batch`` a PCA
    is fitted on that group via ``pca_projection`` and all observations are
    projected onto it. For a group ``gr`` the scores are stored in
    ``adata.obsm[f"{adjusted_basis}_{gr}"]``, the loadings (zero outside the
    variable mask) in ``adata.varm[f"{adjusted_basis}_{gr}_PCs"]``, and
    summary statistics in ``adata.uns[adjusted_basis][gr]``.

    Parameters
    ----------
    adata
        Annotated data matrix; modified in place.
    key
        Column of ``adata.obs`` defining the groups.
    basis
        ``"X"``, a layer name or an ``obsm`` key; ``None`` is treated as ``"X"``.
    adjusted_basis
        Prefix for the stored results; defaults to ``basis + "_rpca"``.
    reference_batch
        Group name or list of group names to use as references; ``None``
        uses every group.
    mask_var
        Boolean column of ``adata.var`` selecting the variables to use;
        ``None`` uses all variables.
    n_components, min_variance_ratio, svd_solver, normalize
        Forwarded to ``pca_projection``.
    """
    basis = "X" if basis is None else basis
    if adjusted_basis is None:
        adjusted_basis = f"{basis}_rpca"

    if mask_var is None:
        mask = np.ones(adata.n_vars, dtype=bool)
    else:
        mask = adata.var[mask_var].values

    X = _get_basis(adata[:, mask], basis)

    groups = adata.obs.groupby(key, observed=True).groups
    if reference_batch is None:
        reference_batch = list(groups.keys())
    elif isinstance(reference_batch, str):
        reference_batch = [reference_batch]

    # Keep the iteration order of the groups, restricted to the references.
    references = {gr: idx for gr, idx in groups.items() if gr in reference_batch}

    uns = {}
    for gr, idx in references.items():
        scores, loadings, variance_ratio, variance = pca_projection(
            X,
            _get_basis(adata[idx, mask], basis),
            n_components=n_components,
            min_variance_ratio=min_variance_ratio,
            svd_solver=svd_solver,
            normalize=normalize,
        )
        res_ncomps = scores.shape[1]

        # Expand the loadings back to the full variable axis.
        components = np.zeros((res_ncomps, adata.n_vars))
        components[:, mask] = loadings

        adata.obsm[f"{adjusted_basis}_{gr}"] = scores
        adata.varm[f"{adjusted_basis}_{gr}_PCs"] = components.T

        uns[gr] = {
            "n_components": res_ncomps,
            "explained_variance_ratio": variance_ratio,
            "explained_variance": variance,
        }

    adata.uns[adjusted_basis] = uns
69
+
70
+
71
def pca_projection(
    X: NDArray,
    X_reference: NDArray,
    n_components: int = 30,
    min_variance_ratio: float = 0.0005,
    svd_solver: str = "arpack",
    normalize: bool = False,
) -> tuple[NDArray, NDArray, NDArray, NDArray]:
    """Fit a PCA on ``X_reference`` and project ``X`` onto its components.

    Parameters
    ----------
    X
        Data to project; must have the same number of columns as
        ``X_reference``.
    X_reference
        Data the principal components are fitted on with ``sc.pp.pca``.
    n_components
        Number of components requested from ``sc.pp.pca``.
    min_variance_ratio
        Components whose explained variance ratio is not strictly greater
        than this value are discarded.
    svd_solver
        Solver forwarded to ``sc.pp.pca``.
    normalize
        If true, rescale every projected row to unit Euclidean norm.

    Returns
    -------
    ``(X_pca, components, explained_variance_ratio, explained_variance)``,
    all restricted to the retained components.
    """
    import scanpy as sc

    pca_kwargs = dict(
        n_comps=n_components,
        svd_solver=svd_solver,
        return_info=True,
    )

    pca_result = sc.pp.pca(X_reference, **pca_kwargs)
    _, components, explained_variance_ratio, explained_variance = pca_result

    components_mask = explained_variance_ratio > min_variance_ratio
    components = components[components_mask]
    explained_variance_ratio = explained_variance_ratio[components_mask]
    explained_variance = explained_variance[components_mask]

    X_pca = X.dot(components.T)

    if normalize:
        norms = np.linalg.norm(X_pca, axis=1, keepdims=True)
        # Rows that project to the origin have zero norm; divide them by 1
        # so they stay at zero instead of turning into NaN/inf.
        norms[norms == 0] = 1.0
        X_pca = X_pca / norms

    return X_pca, components, explained_variance_ratio, explained_variance
101
+
102
+
103
+ def _get_basis(adata: AnnData, basis: str):
104
+ if basis == "X":
105
+ X = adata.X
106
+
107
+ elif basis in adata.layers:
108
+ X = adata.layers[basis]
109
+
110
+ elif basis in adata.obsm:
111
+ X = adata.obsm[basis]
112
+
113
+ else:
114
+ raise ValueError(f"Unknown basis {basis}")
115
+
116
+ return X