sclab 0.2.5__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (50)
  1. sclab/__init__.py +1 -1
  2. sclab/dataset/_dataset.py +1 -1
  3. sclab/examples/processor_steps/__init__.py +2 -0
  4. sclab/examples/processor_steps/_doublet_detection.py +68 -0
  5. sclab/examples/processor_steps/_integration.py +37 -4
  6. sclab/examples/processor_steps/_neighbors.py +24 -4
  7. sclab/examples/processor_steps/_pca.py +5 -5
  8. sclab/examples/processor_steps/_preprocess.py +14 -1
  9. sclab/examples/processor_steps/_qc.py +22 -6
  10. sclab/gui/__init__.py +0 -0
  11. sclab/gui/components/__init__.py +5 -0
  12. sclab/gui/components/_guided_pseudotime.py +482 -0
  13. sclab/methods/__init__.py +25 -1
  14. sclab/preprocess/__init__.py +18 -0
  15. sclab/preprocess/_cca.py +154 -0
  16. sclab/preprocess/_cca_integrate.py +77 -0
  17. sclab/preprocess/_filter_obs.py +42 -0
  18. sclab/preprocess/_harmony.py +421 -0
  19. sclab/preprocess/_harmony_integrate.py +50 -0
  20. sclab/preprocess/_normalize_weighted.py +61 -0
  21. sclab/preprocess/_subset.py +208 -0
  22. sclab/preprocess/_transfer_metadata.py +137 -0
  23. sclab/preprocess/_transform.py +82 -0
  24. sclab/preprocess/_utils.py +96 -0
  25. sclab/tools/__init__.py +0 -0
  26. sclab/tools/cellflow/__init__.py +0 -0
  27. sclab/tools/cellflow/density_dynamics/__init__.py +0 -0
  28. sclab/tools/cellflow/density_dynamics/_density_dynamics.py +349 -0
  29. sclab/tools/cellflow/pseudotime/__init__.py +0 -0
  30. sclab/tools/cellflow/pseudotime/_pseudotime.py +332 -0
  31. sclab/tools/cellflow/pseudotime/timeseries.py +226 -0
  32. sclab/tools/cellflow/utils/__init__.py +0 -0
  33. sclab/tools/cellflow/utils/density_nd.py +136 -0
  34. sclab/tools/cellflow/utils/interpolate.py +334 -0
  35. sclab/tools/cellflow/utils/smoothen.py +124 -0
  36. sclab/tools/cellflow/utils/times.py +55 -0
  37. sclab/tools/differential_expression/__init__.py +5 -0
  38. sclab/tools/differential_expression/_pseudobulk_edger.py +304 -0
  39. sclab/tools/differential_expression/_pseudobulk_helpers.py +277 -0
  40. sclab/tools/doublet_detection/__init__.py +5 -0
  41. sclab/tools/doublet_detection/_scrublet.py +64 -0
  42. sclab/tools/labeling/__init__.py +6 -0
  43. sclab/tools/labeling/sctype.py +233 -0
  44. sclab/utils/__init__.py +5 -0
  45. sclab/utils/_write_excel.py +510 -0
  46. {sclab-0.2.5.dist-info → sclab-0.3.0.dist-info}/METADATA +6 -2
  47. sclab-0.3.0.dist-info/RECORD +81 -0
  48. sclab-0.2.5.dist-info/RECORD +0 -45
  49. {sclab-0.2.5.dist-info → sclab-0.3.0.dist-info}/WHEEL +0 -0
  50. {sclab-0.2.5.dist-info → sclab-0.3.0.dist-info}/licenses/LICENSE +0 -0
sclab/preprocess/_harmony_integrate.py
@@ -0,0 +1,50 @@
+"""Use harmony to integrate cells from different experiments.
+
+Note: code adapted from scanpy to use a custom version of harmonypy
+
+Harmony:
+Korsunsky, I., Millard, N., Fan, J. et al. Fast, sensitive and accurate integration of single-cell data with Harmony.
+Nat Methods 16, 1289-1296 (2019). https://doi.org/10.1038/s41592-019-0619-0
+
+Scanpy:
+Wolf, F., Angerer, P. & Theis, F. SCANPY: large-scale single-cell gene expression data analysis.
+Genome Biol 19, 15 (2018). https://doi.org/10.1186/s13059-017-1382-0
+
+Scverse:
+Virshup, I., Bredikhin, D., Heumos, L. et al. The scverse project provides a computational ecosystem for single-cell omics data analysis.
+Nat Biotechnol 41, 604-606 (2023). https://doi.org/10.1038/s41587-023-01733-8
+"""
+
+from collections.abc import Sequence
+
+from anndata import AnnData
+import numpy as np
+
+from ._harmony import run_harmony
+
+
+def harmony_integrate(
+    adata: AnnData,
+    key: str | Sequence[str],
+    *,
+    basis: str = "X_pca",
+    adjusted_basis: str = "X_pca_harmony",
+    reference_batch: str | list[str] | None = None,
+    **kwargs,
+):
+    """Use harmonypy :cite:p:`Korsunsky2019` to integrate different experiments."""
+
+    if isinstance(reference_batch, str):
+        reference_batch = [reference_batch]
+
+    if reference_batch is not None:
+        reference_values = np.zeros(adata.n_obs, dtype=bool)
+        for batch in reference_batch:
+            reference_values |= adata.obs[key].values == batch
+        kwargs["reference_values"] = reference_values
+
+    X = adata.obsm[basis].astype(np.float64)
+
+    harmony_out = run_harmony(X, adata.obs, key, **kwargs)
+
+    adata.obsm[adjusted_basis] = harmony_out.Z_corr.T
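
For orientation, a minimal usage sketch (not part of the diff; it assumes PCA has been computed and that `harmony_integrate` is re-exported from `sclab.preprocess`, as the new `__init__.py` suggests):

    import scanpy as sc
    from sclab.preprocess import harmony_integrate  # assumed re-export

    sc.pp.pca(adata)                    # fills adata.obsm["X_pca"]
    harmony_integrate(adata, "batch")   # writes adata.obsm["X_pca_harmony"]

    # Pin the correction to one or more reference batches; internally this
    # builds the boolean mask that is passed to run_harmony as reference_values
    harmony_integrate(adata, "batch", reference_batch="control")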
sclab/preprocess/_normalize_weighted.py
@@ -0,0 +1,61 @@
+import warnings
+
+import numpy as np
+from anndata import AnnData, ImplicitModificationWarning
+from scipy.sparse import csr_matrix, issparse
+
+
+def normalize_weighted(
+    adata: AnnData,
+    target_scale: float | None = None,
+    batch_key: str | None = None,
+) -> None:
+    if batch_key is not None:
+        for _, idx in adata.obs.groupby(batch_key, observed=True).groups.items():
+            with warnings.catch_warnings():
+                warnings.filterwarnings(
+                    "ignore",
+                    category=ImplicitModificationWarning,
+                    message="Modifying `X` on a view results in data being overridden",
+                )
+                normalize_weighted(adata[idx], target_scale, None)
+
+        return
+
+    X: csr_matrix
+    Y: csr_matrix
+    Z: csr_matrix
+
+    X = adata.X
+    assert issparse(X)
+
+    with warnings.catch_warnings():
+        warnings.filterwarnings(
+            "ignore", category=RuntimeWarning, message="divide by zero"
+        )
+        Y = X.multiply(1 / X.sum(axis=0))
+    Y = Y.tocsr()
+    Y.eliminate_zeros()
+    Y.data = -Y.data * np.log(Y.data)
+    entropy = Y.sum(axis=0)
+
+    Z = X.multiply(entropy)
+    Z = Z.tocsr()
+    Z.eliminate_zeros()
+
+    with warnings.catch_warnings():
+        warnings.filterwarnings(
+            "ignore", category=RuntimeWarning, message="divide by zero"
+        )
+        scale = Z.sum(axis=1)
+        Z = Z.multiply(1 / scale)
+    Z = Z.tocsr()
+
+    if target_scale is None:
+        target_scale = np.median(scale.A1[scale.A1 > 0])
+
+    Z = Z * target_scale
+
+    adata.X = Z
+
+    return
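
The function weights each gene by the Shannon entropy of its expression across cells, then renormalizes each cell to a common scale. A dense re-derivation of the same arithmetic, for intuition only (the shipped code works on sparse matrices in place):

    import numpy as np

    def entropy_normalize_dense(X: np.ndarray) -> np.ndarray:
        P = X / X.sum(axis=0, keepdims=True)       # per-gene distribution over cells
        with np.errstate(divide="ignore", invalid="ignore"):
            H = np.nansum(-P * np.log(P), axis=0)  # 0 * log(0) terms drop out
        Z = X * H                                  # entropy-weight the counts
        scale = Z.sum(axis=1, keepdims=True)       # per-cell total after weighting
        return Z / scale * np.median(scale[scale > 0])  # rescale to the median cell

    X = np.array([[10.0, 0.0], [5.0, 5.0], [0.0, 10.0]])
    print(entropy_normalize_dense(X))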
sclab/preprocess/_subset.py
@@ -0,0 +1,208 @@
+from typing import Sequence
+
+import numpy as np
+import pandas as pd
+from anndata import AnnData
+
+
+def subset_obs(
+    adata: AnnData,
+    subset: pd.Index | Sequence[str | int | bool] | str,
+) -> None:
+    """Subset observations (rows) in an AnnData object.
+
+    This function modifies the AnnData object in-place by selecting a subset of observations
+    based on the provided subset parameter. The subsetting can be done using observation
+    names, integer indices, a boolean mask, a query string, or a pandas Index.
+
+    Parameters
+    ----------
+    adata : AnnData
+        The annotated data matrix to subset. Will be modified in-place.
+    subset : pd.Index | Sequence[str | int | bool] | str
+        The subset specification. Can be one of:
+        * A pandas Index containing observation names
+        * A sequence of observation names (strings)
+        * A sequence of integer indices
+        * A boolean mask of length `adata.n_obs`
+        * A query string to match observations by their metadata columns
+
+    Examples
+    --------
+    >>> # Create an example AnnData object
+    >>> import anndata
+    >>> import pandas as pd
+    >>> import numpy as np
+    >>>
+    >>> obs = pd.DataFrame(
+    ...     index=['A', 'B', 'C'],
+    ...     data={'cell_type': ['type1', 'type2', 'type2']})
+    >>> adata_ = anndata.AnnData(obs=obs)
+    >>>
+    >>> # Subset using pandas Index
+    >>> adata = adata_.copy()
+    >>> subset_obs(adata, pd.Index(['B', 'C']))
+    >>> adata.obs_names.tolist()
+    ['B', 'C']
+    >>>
+    >>> # Subset using observation names
+    >>> adata = adata_.copy()
+    >>> subset_obs(adata, ['A', 'B'])
+    >>> adata.obs_names.tolist()
+    ['A', 'B']
+    >>>
+    >>> # Subset using integer indices
+    >>> adata = adata_.copy()
+    >>> subset_obs(adata, [0, 1])
+    >>> adata.obs_names.tolist()
+    ['A', 'B']
+    >>>
+    >>> # Subset using boolean mask
+    >>> adata = adata_.copy()
+    >>> subset_obs(adata, [True, False, True])
+    >>> adata.obs_names.tolist()
+    ['A', 'C']
+    >>>
+    >>> # Subset using query string
+    >>> adata = adata_.copy()
+    >>> subset_obs(adata, 'cell_type == "type2"')
+    >>> adata.obs_names.tolist()
+    ['B', 'C']
+
+    Notes
+    -----
+    - The function modifies the AnnData object in-place
+    - When using a boolean mask, its length must match the number of observations
+    - When using integer indices, they must be valid indices for the observations
+    - Invalid observation names or indices will raise KeyError or IndexError respectively
+    - The order of observations in the output will match the order in the subset parameter
+    """
+    if isinstance(subset, str):
+        subset = adata.obs.query(subset).index
+
+    if not isinstance(subset, pd.Index):
+        subset = np.asarray(subset)
+
+    # Handle boolean mask
+    if subset.dtype.kind == "b":
+        if len(subset) != adata.n_obs:
+            raise IndexError(
+                f"Boolean mask length ({len(subset)}) does not match number of "
+                f"observations ({adata.n_obs})"
+            )
+        subset = adata.obs_names[subset]
+
+    # Handle integer indices
+    elif subset.dtype.kind in "iu":
+        if np.any(subset < 0) or np.any(subset >= adata.n_obs):
+            raise IndexError(f"Integer indices must be between 0 and {adata.n_obs - 1}")
+        subset = adata.obs_names[subset]
+
+    if adata.n_obs == subset.size and (subset == adata.obs_names).all():
+        # No need to subset, avoid making a copy. Useful for large AnnData objects
+        return
+
+    adata._inplace_subset_obs(subset)
+
+
+def subset_var(
+    adata: AnnData,
+    subset: pd.Index | Sequence[str | int | bool] | str,
+) -> None:
+    """Subset variables (columns) in an AnnData object.
+
+    This function modifies the AnnData object in-place by selecting a subset of variables
+    based on the provided subset parameter. The subsetting can be done using variable
+    names, integer indices, a boolean mask, a query string, or a pandas Index.
+
+    Parameters
+    ----------
+    adata : AnnData
+        The annotated data matrix to subset. Will be modified in-place.
+    subset : pd.Index | Sequence[str | int | bool] | str
+        The subset specification. Can be one of:
+        * A pandas Index containing variable names
+        * A sequence of variable names (strings)
+        * A sequence of integer indices
+        * A boolean mask of length `adata.n_vars`
+        * A query string to match variables by their metadata columns
+
+    Examples
+    --------
+    >>> # Create an example AnnData object
+    >>> import anndata
+    >>> import pandas as pd
+    >>> import numpy as np
+    >>>
+    >>> var = pd.DataFrame(
+    ...     index=['gene1', 'gene2', 'gene3'],
+    ...     data={'gene_type': ['type1', 'type2', 'type1']})
+    >>> adata_ = anndata.AnnData(var=var)
+    >>>
+    >>> # Subset using pandas Index
+    >>> adata = adata_.copy()
+    >>> subset_var(adata, pd.Index(['gene2', 'gene3']))
+    >>> adata.var_names.tolist()
+    ['gene2', 'gene3']
+    >>>
+    >>> # Subset using variable names
+    >>> adata = adata_.copy()
+    >>> subset_var(adata, ['gene1', 'gene2'])
+    >>> adata.var_names.tolist()
+    ['gene1', 'gene2']
+    >>>
+    >>> # Subset using integer indices
+    >>> adata = adata_.copy()
+    >>> subset_var(adata, [0, 1])
+    >>> adata.var_names.tolist()
+    ['gene1', 'gene2']
+    >>>
+    >>> # Subset using boolean mask
+    >>> adata = adata_.copy()
+    >>> subset_var(adata, [True, False, True])
+    >>> adata.var_names.tolist()
+    ['gene1', 'gene3']
+    >>>
+    >>> # Subset using query string
+    >>> adata = adata_.copy()
+    >>> subset_var(adata, 'gene_type == "type1"')
+    >>> adata.var_names.tolist()
+    ['gene1', 'gene3']
+
+    Notes
+    -----
+    - The function modifies the AnnData object in-place
+    - When using a boolean mask, its length must match the number of variables
+    - When using integer indices, they must be valid indices for the variables
+    - Invalid variable names or indices will raise KeyError or IndexError respectively
+    - The order of variables in the output will match the order in the subset parameter
+    """
+
+    if isinstance(subset, str):
+        subset = adata.var.query(subset).index
+
+    if not isinstance(subset, pd.Index):
+        subset = np.asarray(subset)
+
+    # Handle boolean mask
+    if subset.dtype.kind == "b":
+        if len(subset) != adata.n_vars:
+            raise IndexError(
+                f"Boolean mask length ({len(subset)}) does not match number of "
+                f"variables ({adata.n_vars})"
+            )
+        subset = adata.var_names[subset]
+
+    # Handle integer indices
+    elif subset.dtype.kind in "iu":
+        if np.any(subset < 0) or np.any(subset >= adata.n_vars):
+            raise IndexError(
+                f"Integer indices must be between 0 and {adata.n_vars - 1}"
+            )
+        subset = adata.var_names[subset]
+
+    if adata.n_vars == subset.size and (subset == adata.var_names).all():
+        # No need to subset, avoid making a copy. Useful for large AnnData objects
+        return
+
+    adata._inplace_subset_var(subset)
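
One design point worth noting: when the requested subset equals the existing index in the same order, both functions return before calling `_inplace_subset_*`, so nothing is copied. A quick check of that fast path (hypothetical snippet; assumes the `sclab.preprocess` re-export):

    import anndata
    import pandas as pd
    from sclab.preprocess import subset_obs  # assumed re-export

    adata = anndata.AnnData(obs=pd.DataFrame(index=["A", "B", "C"]))
    obs_before = adata.obs
    subset_obs(adata, adata.obs_names)  # identical index: early return
    assert adata.obs is obs_before      # no copy was made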
sclab/preprocess/_transfer_metadata.py
@@ -0,0 +1,137 @@
+from collections import Counter
+from functools import partial
+from typing import Literal
+
+import numpy as np
+import pandas as pd
+from anndata import AnnData
+from numpy.typing import NDArray
+from pandas.api.types import is_bool_dtype, is_numeric_dtype
+from scipy.sparse import csr_matrix
+from scipy.special import gamma
+from tqdm.auto import tqdm
+
+
+def transfer_metadata(
+    adata: AnnData,
+    group_key: str,
+    source_group: str,
+    column: str,
+    periodic: bool = False,
+    vmin: float = 0,
+    vmax: float = 1,
+    min_neighs: int = 5,
+    weight_by: Literal["connectivity", "distance", "constant"] = "connectivity",
+):
+    D: csr_matrix = adata.obsp["distances"]
+    C: csr_matrix = adata.obsp["connectivities"]
+    D = D.tocsr()
+
+    match weight_by:
+        case "connectivity":
+            W = C.tocsr()
+        case "distance":
+            W = D.tocsr()
+            W.data = 1.0 / W.data
+        case "constant":
+            W = D.tocsr()
+            W.data[:] = 1.0
+        case _:
+            raise ValueError(f"Unsupported weight_by {weight_by}")
+
+    meta_values: pd.Series
+    new_values: pd.Series
+
+    series = adata.obs[column]
+    if isinstance(series.dtype, pd.CategoricalDtype) or is_bool_dtype(series.dtype):
+        assign_value_fn = _assign_categorical
+        new_column = f"transferred_{column}"
+        new_column_err = f"transferred_{column}_proportion"
+    elif is_numeric_dtype(series.dtype) and periodic:
+        assign_value_fn = partial(_assign_numerical_periodic, vmin=vmin, vmax=vmax)
+        new_column = f"transferred_{column}"
+        new_column_err = f"transferred_{column}_error"
+    elif is_numeric_dtype(series.dtype):
+        assign_value_fn = _assign_numerical
+        new_column = f"transferred_{column}"
+        new_column_err = f"transferred_{column}_error"
+    else:
+        raise ValueError(f"Unsupported dtype {series.dtype} for column {column}")
+
+    meta_values = series.copy()
+    meta_values[adata.obs[group_key] != source_group] = np.nan
+    new_values = pd.Series(index=series.index, dtype=series.dtype, name=new_column)
+    new_values_err = pd.Series(index=series.index, dtype=float, name=new_column_err)
+
+    for i, (d, w) in tqdm(enumerate(zip(D, W)), total=D.shape[0]):
+        if not pd.isna(meta_values.iloc[i]):
+            continue
+
+        d = d.tocoo()
+        w = w.toarray().ravel()
+        neighs = d.coords[1]
+
+        values: pd.Series = meta_values.iloc[neighs]
+        msk = pd.notna(values)
+        if msk.sum() < min_neighs:
+            continue
+
+        values = values.loc[msk]
+        weights = w[neighs][msk]
+
+        if np.allclose(weights, 0):
+            continue
+
+        assigned_value, assigned_value_err = assign_value_fn(values, weights)
+        new_values.iloc[i] = assigned_value
+        new_values_err.iloc[i] = assigned_value_err
+
+    adata.obs[new_column] = new_values.copy()
+    adata.obs[new_column_err] = new_values_err.copy()
+
+
+def _assign_categorical(values: pd.Series, weights: NDArray):
+    # weighted majority and proportion of votes
+    tally = Counter()
+    for v, w in zip(values, weights):
+        tally[v] += w
+
+    winner, shares = tally.most_common()[0]
+    return winner, shares / weights.sum()
+
+
+def _assign_numerical(values: pd.Series, weights: NDArray):
+    # weighted mean and standard error
+    sum_w: float = weights.sum()
+    sum2_w: float = weights.sum() ** 2
+    sum_w2: float = (weights**2).sum()
+    n_eff: float = sum2_w / sum_w2
+
+    mean_x: float = (values * weights).sum() / sum_w
+    var_x: float = ((values - mean_x) ** 2 * weights).sum() * sum_w / (sum2_w - sum_w2)
+    err_x: float = np.sqrt(var_x / n_eff)
+
+    return mean_x, err_x
+
+
+def _assign_numerical_periodic(
+    values: pd.Series, weights: NDArray, vmin: float, vmax: float
+):
+    vspan = vmax - vmin
+
+    values = values - vmin
+    offset = np.median(values)
+    values = values - offset + vspan / 2
+    values = values % vspan
+    assigned_value, assigned_value_err = _assign_numerical(values, weights)
+    assigned_value = assigned_value + offset - vspan / 2
+    assigned_value = assigned_value % vspan
+    assigned_value = assigned_value + vmin
+
+    return assigned_value, assigned_value_err
+
+
+def _c4(n: float):
+    # correct for bias
+    nm1 = n - 1
+    return np.sqrt(2 / nm1) * gamma(n / 2) / gamma(nm1 / 2)
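
A hedged usage sketch: with a neighbors graph computed over the combined object, labels and (optionally periodic) numeric columns are propagated from a source group to the remaining cells. The `dataset`/`reference` names below are illustrative, not from the diff:

    import scanpy as sc
    from sclab.preprocess import transfer_metadata  # assumed re-export

    sc.pp.neighbors(adata)  # fills adata.obsp["distances"] / ["connectivities"]

    # Categorical column: weighted majority vote among annotated neighbors
    transfer_metadata(adata, group_key="dataset", source_group="reference",
                      column="cell_type")
    # -> adata.obs["transferred_cell_type"] and ..._proportion

    # Periodic numeric column on [0, 1): values are recentred around their
    # median, averaged, then wrapped back into range
    transfer_metadata(adata, group_key="dataset", source_group="reference",
                      column="pseudotime", periodic=True, vmin=0, vmax=1)
    # -> adata.obs["transferred_pseudotime"] and ..._error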
sclab/preprocess/_transform.py
@@ -0,0 +1,82 @@
+from typing import Optional
+
+from anndata import AnnData
+from numpy import ndarray
+from scipy.sparse import csr_matrix
+
+from ._utils import get_neighbors_adjacency_matrix
+
+
+def pool_neighbors(
+    adata: AnnData,
+    *,
+    layer: Optional[str] = None,
+    n_neighbors: Optional[int] = None,
+    neighbors_key: Optional[str] = None,
+    weighted: bool = False,
+    directed: bool = True,
+    key_added: Optional[str] = None,
+    copy: bool = False,
+) -> csr_matrix | ndarray | None:
+    """
+    Given an adjacency matrix, pool cell features using a weighted sum of feature counts
+    from neighboring cells. The weights are the normalized connectivities from the
+    adjacency matrix.
+
+    Parameters
+    ----------
+    adata : AnnData
+        Annotated data matrix.
+    layer : str, optional
+        Layer in AnnData object to use for pooling. Defaults to None.
+    n_neighbors : int, optional
+        Number of neighbors to consider. Defaults to None.
+    neighbors_key : str, optional
+        Key in AnnData object to use for neighbors. Defaults to None.
+    weighted : bool, optional
+        Whether to weight neighbors by their connectivities in the adjacency matrix.
+        Defaults to False.
+    directed : bool, optional
+        Whether to use directed or undirected neighbors. Defaults to True.
+    key_added : str, optional
+        Key to use in AnnData object for the pooled features. Defaults to None.
+    copy : bool, optional
+        Whether to return a copy of the pooled features instead of modifying the
+        original AnnData object. Defaults to False.
+
+    Returns
+    -------
+    csr_matrix | ndarray | None
+        The pooled features if copy is True, otherwise None.
+    """
+    if layer is None or layer == "X":
+        X = adata.X
+    else:
+        X = adata.layers[layer]
+
+    adjacency = get_neighbors_adjacency_matrix(
+        adata,
+        key=neighbors_key,
+        n_neighbors=n_neighbors,
+        weighted=weighted,
+        directed=directed,
+    )
+
+    W = adjacency.tolil()
+    W.setdiag(1)
+
+    W = W / W.sum(axis=1)
+
+    pooled = W.dot(X)
+
+    if copy:
+        return pooled
+
+    if key_added is not None:
+        adata.layers[key_added] = pooled
+        return
+
+    if layer is None or layer == "X":
+        adata.X = pooled
+    else:
+        adata.layers[layer] = pooled
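
Because `setdiag(1)` runs before row-normalization, each cell's own profile always contributes to the weighted average. A minimal usage sketch (assumes a `counts` layer, a prior `sc.pp.neighbors` call, and the `sclab.preprocess` re-export):

    import scanpy as sc
    from sclab.preprocess import pool_neighbors  # assumed re-export

    sc.pp.neighbors(adata, n_neighbors=15)

    # Smooth raw counts over each neighborhood into a new layer, leaving X intact
    pool_neighbors(adata, layer="counts", key_added="counts_pooled",
                   weighted=True, directed=False)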
sclab/preprocess/_utils.py
@@ -0,0 +1,96 @@
+from typing import Literal, Optional
+
+import numpy as np
+from anndata import AnnData
+from scanpy import Neighbors
+from scipy.sparse import coo_matrix, csr_matrix
+
+
+def get_neighbors_adjacency_matrix(
+    adata: AnnData,
+    *,
+    key: Optional[str] = "neighbors",
+    n_neighbors: Optional[int] = None,
+    weighted: bool = False,
+    directed: bool = True,
+) -> csr_matrix:
+    # get the current neighbors; fall back to scanpy's default key so that
+    # callers passing key=None (e.g. pool_neighbors) do not index adata.uns[None]
+    key = key or "neighbors"
+    neigh = Neighbors(adata, neighbors_key=key)
+    params = adata.uns[key]["params"]
+
+    if n_neighbors is None:
+        n_neighbors = neigh.n_neighbors
+
+    if n_neighbors < neigh.n_neighbors and not weighted:
+        distances = _filter_knn_matrix(
+            neigh.distances, n_neighbors=n_neighbors, mode="distances"
+        )
+
+    elif n_neighbors != neigh.n_neighbors:
+        neigh.compute_neighbors(**{**params, "n_neighbors": n_neighbors})
+        distances = neigh.distances
+
+    else:
+        distances = neigh.distances
+
+    adjacency = distances.copy()
+    adjacency.data = np.ones_like(adjacency.data)
+
+    if not directed:
+        # make the adjacency matrix symmetric
+        adjacency = _symmetrize_sparse_matrix(adjacency)
+
+    if weighted:
+        # use the connectivities to assign weights
+        adjacency = adjacency.multiply(neigh.connectivities)
+
+    return adjacency
+
+
+def _filter_knn_matrix(
+    matrix: csr_matrix, *, n_neighbors: int, mode: Literal["distances", "weights"]
+) -> csr_matrix:
+    assert mode in ["distances", "weights"]
+    nrows, _ = matrix.shape
+
+    # Initialize arrays for new sparse matrix with pre-allocated size
+    indptr = np.arange(0, (n_neighbors - 1) * (nrows + 1), n_neighbors - 1)
+    data = np.zeros(nrows * (n_neighbors - 1), dtype=float)
+    indices = np.zeros(nrows * (n_neighbors - 1), dtype=int)
+
+    # Process each row to keep top n_neighbors-1 connections
+    for i in range(nrows):
+        start, end = matrix.indptr[i : i + 2]
+        idxs = matrix.indices[start:end]
+        vals = matrix.data[start:end]
+
+        # Sort by values and keep top n_neighbors-1
+        if mode == "weights":
+            # Sort in descending order (keep largest weights)
+            o = np.argsort(-vals)[: n_neighbors - 1]
+        else:
+            # Sort in ascending order (keep smallest distances)
+            o = np.argsort(vals)[: n_neighbors - 1]
+
+        # Maintain original order within top neighbors
+        oo = np.argsort(idxs[o])
+        start, end = indptr[i : i + 2]
+        indices[start:end] = idxs[o][oo]
+        data[start:end] = vals[o][oo]
+
+    # keep the original square shape even if trailing columns end up empty
+    return csr_matrix((data, indices, indptr), shape=matrix.shape)
+
+
+def _symmetrize_sparse_matrix(matrix: csr_matrix) -> csr_matrix:
+    A = matrix.tocoo()
+
+    # Make matrix symmetric by duplicating entries in both directions
+    coords = np.array([[*A.row, *A.col], [*A.col, *A.row]])
+    data = np.array([*A.data, *A.data])
+
+    # Remove duplicate entries that might occur in symmetrization
+    idxs = np.unique(coords, axis=1, return_index=True)[1]
+    coords, data = coords[:, idxs], data[idxs]
+    A = coo_matrix((data, coords), shape=matrix.shape)
+
+    return A.tocsr()
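
`_symmetrize_sparse_matrix` mirrors every stored entry and de-duplicates the resulting coordinates. A standalone toy check of that behavior (illustrative only):

    import numpy as np
    from scipy.sparse import coo_matrix, csr_matrix

    A = csr_matrix(np.array([[0.0, 1.0], [0.0, 0.0]])).tocoo()  # edge 0 -> 1 only

    coords = np.array([[*A.row, *A.col], [*A.col, *A.row]])  # both directions
    data = np.array([*A.data, *A.data])
    idxs = np.unique(coords, axis=1, return_index=True)[1]   # drop duplicates
    S = coo_matrix((data[idxs], coords[:, idxs]), shape=A.shape).tocsr()

    print(S.toarray())  # [[0. 1.] [1. 0.]] -- the edge is now bidirectional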