SpatialQuery 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- SpatialQuery/__init__.py +3 -0
- SpatialQuery/spatial_query.py +1339 -0
- SpatialQuery/spatial_query_multiple_fov.py +1116 -0
- SpatialQuery/utils.py +130 -0
- spatialquery-0.0.1.dist-info/METADATA +56 -0
- spatialquery-0.0.1.dist-info/RECORD +9 -0
- spatialquery-0.0.1.dist-info/WHEEL +5 -0
- spatialquery-0.0.1.dist-info/licenses/LICENSE +21 -0
- spatialquery-0.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1339 @@
|
|
|
1
|
+
from collections import Counter
|
|
2
|
+
from itertools import combinations
|
|
3
|
+
from typing import List, Union
|
|
4
|
+
|
|
5
|
+
import matplotlib.pyplot as plt
|
|
6
|
+
import statsmodels.stats.multitest as mt
|
|
7
|
+
import numpy as np
|
|
8
|
+
import pandas as pd
|
|
9
|
+
import seaborn as sns
|
|
10
|
+
from anndata import AnnData
|
|
11
|
+
from mlxtend.frequent_patterns import fpgrowth
|
|
12
|
+
from sklearn.preprocessing import MultiLabelBinarizer
|
|
13
|
+
from pandas import DataFrame
|
|
14
|
+
from scipy.spatial import KDTree
|
|
15
|
+
from scipy.stats import hypergeom
|
|
16
|
+
from sklearn.preprocessing import LabelEncoder
|
|
17
|
+
import time
|
|
18
|
+
|
|
19
|
+
class spatial_query:
|
|
20
|
+
def __init__(self,
             adata: AnnData,
             dataset: str = 'ST',
             spatial_key: str = 'X_spatial',
             label_key: str = 'predicted_label',
             leaf_size: int = 10,
             max_radius: float = 500,
             n_split: int = 10
             ):
    """
    Index an AnnData object for spatial pattern queries.

    Parameter
    ---------
    adata:
        AnnData object holding cell coordinates in ``obsm[spatial_key]`` and
        cell-type labels in ``obs[label_key]``.
    dataset:
        Dataset name; stored as ``self.dataset``.
    spatial_key:
        Key of the coordinate array in ``adata.obsm``.
    label_key:
        Key of the cell-type column in ``adata.obs``.
    leaf_size:
        Leaf size passed to ``scipy.spatial.KDTree``.
    max_radius:
        Upper bound applied to query radii elsewhere in the class; also used as
        the overlap margin of the grid partition.
    n_split:
        Number of splits per axis of the grid partition.

    Raises
    ------
    ValueError
        If ``spatial_key`` is missing from ``adata.obsm`` or ``label_key`` is
        missing from ``adata.obs``.
    """
    has_coords = spatial_key in adata.obsm.keys()
    has_labels = label_key in adata.obs.keys()
    if not (has_coords and has_labels):
        raise ValueError(f"The Anndata object must contain {spatial_key} in obsm and {label_key} in obs.")

    # Keys and configuration
    self.spatial_key = spatial_key
    self.label_key = label_key
    self.dataset = dataset
    self.max_radius = max_radius

    # Coordinates, categorical labels, and the KD-tree used for neighbor queries
    self.spatial_pos = np.array(adata.obsm[spatial_key])
    self.labels = adata.obs[label_key].astype('category')
    self.kd_tree = KDTree(self.spatial_pos, leafsize=leaf_size)

    # Grid partition: tiles are padded by max_radius so that radius queries up to
    # max_radius do not miss cells that sit in an adjacent tile.
    self.overlap_radius = max_radius
    self.n_split = n_split
    self.grid_cell_types, self.grid_indices = self._initialize_grids()
|
|
43
|
+
|
|
44
|
+
def _initialize_grids(self):
    """
    Partition the bounding box of the cells into ``n_split x n_split`` overlapping tiles.

    Each tile is padded by ``self.overlap_radius`` on every interior side. The
    outermost edges are pinned to the exact data bounds. Computing them as
    ``xmin + n_split * step`` can round to just below ``xmax`` (for example
    ``49 * (1 / 49) == 0.9999999999999999``), which would leave the extreme cells
    outside every tile.

    Return
    ------
    grid_cell_types:
        Dict mapping tile key ``(i, j)`` to the set of cell-type labels in that tile.
    grid_indices:
        Dict mapping tile key ``(i, j)`` to the integer indices of the cells in that tile.
    """
    xmax, ymax = np.max(self.spatial_pos, axis=0)
    xmin, ymin = np.min(self.spatial_pos, axis=0)
    x_step = (xmax - xmin) / self.n_split  # separate x axis into self.n_split parts
    y_step = (ymax - ymin) / self.n_split  # separate y axis into self.n_split parts
    last = self.n_split - 1
    x_coords = self.spatial_pos[:, 0]
    y_coords = self.spatial_pos[:, 1]

    grid_cell_types = {}
    grid_indices = {}

    for i in range(self.n_split):
        x_start = xmin if i == 0 else xmin + i * x_step - self.overlap_radius
        x_end = xmax if i == last else xmin + (i + 1) * x_step + self.overlap_radius
        for j in range(self.n_split):
            y_start = ymin if j == 0 else ymin + j * y_step - self.overlap_radius
            y_end = ymax if j == last else ymin + (j + 1) * y_step + self.overlap_radius

            cell_mask = (x_coords >= x_start) & (x_coords <= x_end) & \
                        (y_coords >= y_start) & (y_coords <= y_end)

            grid_indices[(i, j)] = np.where(cell_mask)[0]
            grid_cell_types[(i, j)] = set(self.labels[cell_mask])

    return grid_cell_types, grid_indices
|
|
67
|
+
|
|
68
|
+
def _query_pattern(self, pattern):
    """
    Find the grid tiles that contain every cell type in ``pattern``.

    Parameter
    ---------
    pattern:
        Iterable of cell-type labels.

    Return
    ------
    A tuple ``(matching_grids, matching_cells_indices)``. The first element is the
    list of matching tile keys, in ``self.grid_cell_types`` iteration order. The
    second is a dict mapping each of those keys to its entry in ``self.grid_indices``.
    """
    hits = {
        key: self.grid_indices[key]
        for key, present in self.grid_cell_types.items()
        if all(label in present for label in pattern)
    }
    return list(hits), hits
|
|
77
|
+
|
|
78
|
+
@staticmethod
def has_motif(neighbors: List[str], labels: List[str]) -> bool:
    """
    Check whether ``labels`` contains every element of ``neighbors`` with at least
    the same multiplicity (multiset containment).

    For example ``has_motif(['A', 'A'], ['A', 'B'])`` is False: 'A' occurs twice in
    ``neighbors`` but only once in ``labels``.

    Parameter
    ---------
    neighbors:
        Elements, possibly repeated, to look for.
    labels:
        Elements to search in.

    Return
    ------
    True if, for every element, its count in ``labels`` is at least its count in
    ``neighbors``; otherwise False. An empty ``neighbors`` always gives True.
    """
    # Counter subtraction keeps only positive differences, so anything left over is
    # an element that ``labels`` does not contain often enough.
    missing = Counter(neighbors) - Counter(labels)
    return not missing
|
|
113
|
+
|
|
114
|
+
@staticmethod
def _distinguish_duplicates(transaction: List[str]):
    """
    Make repeated items in a transaction distinct by appending an occurrence suffix.

    Each distinct item ``x`` that occurs ``n`` times becomes ``x_0, ..., x_{n-1}``.
    Items are grouped in order of first appearance, so the output order can differ
    from the input order, e.g. ``['A', 'B', 'A'] -> ['A_0', 'A_1', 'B_0']``.
    """
    suffixed = []
    for item, occurrences in Counter(transaction).items():
        suffixed.extend(f"{item}_{k}" for k in range(occurrences))
    return suffixed
|
|
130
|
+
|
|
131
|
+
@staticmethod
def _remove_suffix(fp: pd.DataFrame):
    """
    Strip the occurrence suffix added by ``_distinguish_duplicates`` from every item
    in the ``itemsets`` column.

    Only the final ``_<n>`` segment is removed, so labels that themselves contain
    underscores are preserved ('T_cell_0' -> 'T_cell'). Splitting on the first
    underscore would wrongly give 'T'.

    Parameter
    ---------
    fp:
        Frequent-pattern DataFrame with an ``itemsets`` column of iterables of
        strings. The frame is modified in place.

    Return
    ------
    The same DataFrame, whose ``itemsets`` column now holds lists of unsuffixed labels.
    """
    fp['itemsets'] = [[item.rsplit('_', 1)[0] for item in itemset]
                      for itemset in fp['itemsets'].values]
    return fp
|
|
141
|
+
|
|
142
|
+
@staticmethod
def find_maximal_patterns(fp: pd.DataFrame) -> pd.DataFrame:
    """
    Keep only the maximal frequent patterns: itemsets that are not a proper subset
    of another itemset in ``fp``.

    Parameter
    ---------
    fp: Frequent patterns dataframe with ``support`` and ``itemsets`` columns.

    Return
    ------
    The rows of ``fp`` whose itemset is maximal, re-indexed from 0.
    """
    candidate_sets = [frozenset(s) for s in fp['itemsets']]

    # Every proper, non-empty subset of any itemset is by definition not maximal.
    dominated = set()
    for itemset in candidate_sets:
        for size in range(1, len(itemset)):
            dominated.update(map(frozenset, combinations(itemset, size)))

    maximal = [s for s in candidate_sets if s not in dominated]
    keep = fp['itemsets'].isin(maximal)
    return fp[keep].reset_index(drop=True)
|
|
170
|
+
|
|
171
|
+
def find_fp_knn(self,
                ct: str,
                k: int = 30,
                min_support: float = 0.5,
                ) -> pd.DataFrame:
    """
    Mine frequent patterns from the k-nearest-neighbor sets of all cells of type ``ct``.

    Parameter
    ---------
    ct:
        Cell type name used as the center cells.
    k:
        Number of nearest neighbors per center cell.
    min_support:
        Threshold of frequency to consider a pattern as a frequent pattern.

    Return
    ------
    The frequent-pattern DataFrame produced by ``self.build_fptree_knn``.

    Raises
    ------
    ValueError
        If ``ct`` does not occur in the labels.
    """
    if ct not in self.labels.unique():
        raise ValueError(f"Found no {ct} in {self.label_key}!")

    is_center = [label == ct for label in self.labels]
    center_pos = self.spatial_pos[np.flatnonzero(is_center)]

    fp, _, _ = self.build_fptree_knn(cell_pos=center_pos,
                                     k=k,
                                     min_support=min_support)
    return fp
|
|
203
|
+
|
|
204
|
+
def find_fp_dist(self,
                 ct: str,
                 max_dist: float = 100,
                 min_size: int = 0,
                 min_support: float = 0.5,
                 ):
    """
    Mine frequent patterns from the radius neighborhoods of all cells of type ``ct``.

    Parameter
    ---------
    ct:
        Cell type name used as the center cells.
    max_dist:
        Neighborhood radius, capped at ``self.max_radius``.
    min_size:
        Minimum neighborhood size for each point to consider.
    min_support:
        Threshold of frequency to consider a pattern as a frequent pattern.

    Return
    ------
    The frequent-pattern DataFrame produced by ``self.build_fptree_dist``.

    Raises
    ------
    ValueError
        If ``ct`` does not occur in the labels.
    """
    if ct not in self.labels.unique():
        raise ValueError(f"Found no {ct} in {self.label_key}!")

    radius = min(max_dist, self.max_radius)
    center_idx = [i for i, hit in enumerate(self.labels == ct) if hit]

    fp, _, _ = self.build_fptree_dist(cell_pos=self.spatial_pos[center_idx],
                                      max_dist=radius,
                                      min_size=min_size,
                                      min_support=min_support,
                                      cinds=center_idx)
    return fp
|
|
242
|
+
|
|
243
|
+
def motif_enrichment_knn(self,
                         ct: str,
                         motifs: Union[str, List[str]] = None,
                         k: int = 30,
                         min_support: float = 0.5,
                         max_dist: float = 200,
                         return_cellID: bool = False
                         ) -> pd.DataFrame:
    """
    Perform motif enrichment analysis using k-nearest neighbors (KNN).

    Each cell's neighborhood is its k nearest neighbors within ``max_dist``. The
    first KD-tree hit is skipped because it is presumed to be the cell itself. A
    cell "has" a motif when every cell type of the motif appears at least once in
    its neighborhood. Enrichment of the motif around ``ct`` cells is assessed with
    a one-sided hypergeometric test.

    Parameter
    ---------
    ct:
        The cell type of the center cell.
    motifs:
        Specified motifs to be tested.
        If motifs=None, find the frequent patterns as motifs within the neighborhood of center cell type.
        A string or list of cell types is treated as ONE motif.
    k:
        Number of nearest neighbors to consider.
    min_support:
        Threshold of frequency to consider a pattern as a frequent pattern.
    max_dist:
        Maximum distance for neighbors (default: 200), capped at ``self.max_radius``.
        Used both for motif discovery (when motifs=None) and for counting.
    return_cellID:
        Indicate whether return cell IDs for each motif within the neighborhood of the center cell type.
        Only neighbors within ``max_dist`` that belong to a motif cell type are returned,
        excluding the center cells themselves. By default, cell IDs are not returned.

    Return
    ------
    pd.DataFrame with columns: center, motifs, n_center_motif (center cells with the motif),
    n_center, n_motif (all cells with the motif), expectation, p-values (P(X >= n_center_motif)),
    optional cell_id, and if_significant. When more than one motif is tested, it also has
    'corrected p-values' (FDR) and is sorted by that column.

    Raises
    ------
    ValueError
        If ``ct`` is not a label, or no motif is left to test.
    """
    if ct not in self.labels.unique():
        raise ValueError(f"Found no {ct} in {self.label_key}!")

    max_dist = min(max_dist, self.max_radius)

    # use k+1 to find the knn except for the points themselves
    dists, idxs = self.kd_tree.query(self.spatial_pos, k=k + 1, workers=-1)
    cinds = [i for i, l in enumerate(self.labels) if l == ct]

    out = []
    if motifs is None:
        # Discover motifs with the same neighborhood definition (k, max_dist) used below.
        fp, _, _ = self.build_fptree_knn(cell_pos=self.spatial_pos[cinds], k=k,
                                         min_support=min_support, max_dist=max_dist)
        motifs = fp['itemsets']
    else:
        if isinstance(motifs, str):
            motifs = [motifs]

        labels_unique = self.labels.unique()
        motifs_exc = [m for m in motifs if m not in labels_unique]
        if len(motifs_exc) != 0:
            print(f"Found no {motifs_exc} in {self.label_key}. Ignoring them.")
        motifs = [m for m in motifs if m not in motifs_exc]
        if len(motifs) == 0:
            raise ValueError(f"None of the specified motifs were found in {self.label_key}.")
        # The user-specified cell types together form a single motif.
        motifs = [motifs]

    if len(motifs) == 0:
        raise ValueError("No frequent patterns were found. Please lower min_support value.")

    label_encoder = LabelEncoder()
    int_labels = label_encoder.fit_transform(self.labels)
    int_ct = label_encoder.transform(np.array(ct, dtype=object, ndmin=1))

    num_cells = idxs.shape[0]
    num_types = len(label_encoder.classes_)

    # Drop column 0 (the cell itself); mark neighbors beyond max_dist with -1.
    valid_neighbors = dists[:, 1:] <= max_dist
    filtered_idxs = np.where(valid_neighbors, idxs[:, 1:], -1)

    # Per-cell counts of each cell type among valid neighbors. np.add.at avoids a
    # dense (num_cells * k, num_types) indicator matrix.
    rows, cols = np.nonzero(valid_neighbors)
    neighbor_counts = np.zeros((num_cells, num_types), dtype=int)
    np.add.at(neighbor_counts, (rows, int_labels[filtered_idxs[rows, cols]]), 1)

    mask = int_labels == int_ct
    # Row i corresponds to cinds[i] (both enumerate ct cells in ascending index order).
    center_counts = neighbor_counts[mask]
    n_total = len(self.labels)

    for motif in motifs:
        motif = list(motif) if not isinstance(motif, list) else motif
        sort_motif = sorted(motif)

        int_motifs = label_encoder.transform(np.array(motif))

        center_has_motif = np.all(center_counts[:, int_motifs] > 0, axis=1)
        n_motif_ct = np.sum(center_has_motif)
        n_motif_labels = np.sum(np.all(neighbor_counts[:, int_motifs] > 0, axis=1))

        n_ct = len(cinds)
        if ct in motif:
            n_ct = round(n_ct / motif.count(ct))

        # M is number of total, N is number of drawn without replacement, n is number of success in total
        hyge = hypergeom(M=n_total, n=n_ct, N=n_motif_labels)
        # One-sided enrichment p-value P(X >= observed); sf(x) alone is P(X > x).
        motif_out = {'center': ct, 'motifs': sort_motif, 'n_center_motif': n_motif_ct,
                     'n_center': n_ct, 'n_motif': n_motif_labels, 'expectation': hyge.mean(),
                     'p-values': hyge.sf(n_motif_ct - 1)}

        if return_cellID:
            cind_with_motif = [cinds[i] for i in np.flatnonzero(center_has_motif)]
            motif_mask = np.isin(np.array(self.labels), motif)
            # Use only the neighbors counted above (within max_dist, self excluded).
            # The result is an empty array when no center has the motif.
            neighbors = filtered_idxs[cind_with_motif].ravel()
            neighbors = neighbors[neighbors >= 0]
            exclude_self_mask = ~np.isin(neighbors, cind_with_motif)
            valid = neighbors[motif_mask[neighbors] & exclude_self_mask]
            motif_out['cell_id'] = np.array(list(set(valid)))

        out.append(motif_out)

    out_pd = pd.DataFrame(out)

    if len(out_pd) == 1:
        out_pd['if_significant'] = True if out_pd['p-values'][0] < 0.05 else False
        return out_pd

    p_values = out_pd['p-values'].tolist()
    if_rejected, corrected_p_values = mt.fdrcorrection(p_values,
                                                       alpha=0.05,
                                                       method='poscorr')
    out_pd['corrected p-values'] = corrected_p_values
    out_pd['if_significant'] = if_rejected
    out_pd = out_pd.sort_values(by='corrected p-values', ignore_index=True)
    return out_pd
|
|
372
|
+
|
|
373
|
+
def motif_enrichment_dist(self,
                          ct: str,
                          motifs: Union[str, List[str]] = None,
                          max_dist: float = 100,
                          min_size: int = 0,
                          min_support: float = 0.5,
                          max_ns: int = 100,
                          return_cellID: bool = False,
                          ) -> DataFrame:
    """
    Perform motif enrichment analysis within a specified radius-based neighborhood.

    A cell "has" a motif when every cell type of the motif occurs among its neighbors
    within ``max_dist``, excluding the cell itself and keeping at most ``max_ns``
    neighbors. Only cells in grid tiles that contain all motif cell types are searched.
    Enrichment around ``ct`` cells is assessed with a one-sided hypergeometric test.

    Parameter
    ---------
    ct:
        Cell type as the center cells.
    motifs:
        Specified motifs to be tested.
        If motifs=None, find the frequent patterns as motifs within the neighborhood of center cell type.
        A string or list of cell types is treated as ONE motif.
    max_dist:
        Maximum distance for considering a cell as a neighbor (capped at ``self.max_radius``).
    min_size:
        Minimum neighborhood size for each point to consider.
    min_support:
        Threshold of frequency to consider a pattern as a frequent pattern.
    max_ns:
        Maximum number of neighborhood size for each point. KD-tree results are
        sorted by index, so truncation keeps the lowest-index neighbors.
    return_cellID:
        Indicate whether return cell IDs for each motif within the neighborhood of central cell type.
        By defaults do not return cell ID.

    Returns
    -------
    pd.DataFrame with columns: center, motifs, n_center_motif, n_center, n_motif, expectation,
    p-values (P(X >= n_center_motif)), optional cell_id, and if_significant. When more than one
    motif is tested, it also has 'corrected p-values' (FDR) and is sorted by that column.
    Motifs whose cell types never co-occur in a grid tile are omitted. If all motifs are
    omitted, an empty DataFrame with these columns is returned.

    Raises
    ------
    ValueError
        If ``ct`` is not a label, or no motif is left to test.
    """
    if ct not in self.labels.unique():
        raise ValueError(f"Found no {ct} in {self.label_key}!")

    out = []
    max_dist = min(max_dist, self.max_radius)
    if motifs is None:
        fp = self.find_fp_dist(ct=ct,
                               max_dist=max_dist, min_size=min_size,
                               min_support=min_support)
        motifs = fp['itemsets']
        if len(motifs) == 0:
            raise ValueError("No frequent patterns were found. Please lower min_support value.")
    else:
        if isinstance(motifs, str):
            motifs = [motifs]

        labels_unique = self.labels.unique()
        motifs_exc = [m for m in motifs if m not in labels_unique]
        if len(motifs_exc) != 0:
            print(f"Found no {motifs_exc} in {self.label_key}. Ignoring them.")
        motifs = [m for m in motifs if m not in motifs_exc]
        if len(motifs) == 0:
            raise ValueError(f"None of the specified motifs were found in {self.label_key}.")
        # The user-specified cell types together form a single motif.
        motifs = [motifs]

    label_encoder = LabelEncoder()
    int_labels = label_encoder.fit_transform(np.array(self.labels))
    int_ct = label_encoder.transform(np.array(ct, dtype=object, ndmin=1))

    cinds = np.where(self.labels == ct)[0]

    num_cells = len(self.spatial_pos)
    num_types = len(label_encoder.classes_)

    if return_cellID:
        # Untruncated radius neighborhoods of every cell, with self excluded. These do not
        # depend on the motif, so they are built once rather than inside the loop.
        idxs_all = self.kd_tree.query_ball_point(
            self.spatial_pos,
            r=max_dist,
            return_sorted=False,
            workers=-1,
        )
        idxs_all_filter = [np.array(ids)[np.array(ids) != i] for i, ids in enumerate(idxs_all)]
        flat_neighbors_all = np.concatenate(idxs_all_filter)
        row_indices_all = np.repeat(np.arange(num_cells), [len(neigh) for neigh in idxs_all_filter])
        neighbor_matrix_all = np.zeros((num_cells, num_types), dtype=int)
        np.add.at(neighbor_matrix_all, (row_indices_all, int_labels[flat_neighbors_all]), 1)
        # Rows align with cinds: both list ct cells in ascending index order.
        center_matrix_all = neighbor_matrix_all[int_labels == int_ct]
        label_array = np.array(self.labels)

    for motif in motifs:
        motif = list(motif) if not isinstance(motif, list) else motif
        sort_motif = sorted(motif)

        _, matching_cells_indices = self._query_pattern(motif)
        if not matching_cells_indices:
            # No tile (padded by max_radius >= max_dist) holds all motif cell types,
            # so no cell can have this motif within max_dist.
            continue
        matching_cells_indices = np.unique(np.concatenate(list(matching_cells_indices.values())))
        print(f"number of cells skipped: {num_cells - len(matching_cells_indices)}")
        print(f"proportion of cells searched: {len(matching_cells_indices) / num_cells}")
        idxs_in_grids = self.kd_tree.query_ball_point(
            self.spatial_pos[matching_cells_indices],
            r=max_dist,
            return_sorted=True,
            workers=-1
        )

        int_motifs = label_encoder.transform(np.array(motif))

        # Filter the center out of its own neighbors and keep at most max_ns of them.
        idxs_filter = [np.array(ids)[np.array(ids) != i][:max_ns]
                       for i, ids in zip(matching_cells_indices, idxs_in_grids)]

        flat_neighbors = np.concatenate(idxs_filter)
        row_indices = np.repeat(np.arange(len(matching_cells_indices)), [len(neigh) for neigh in idxs_filter])
        neighbor_matrix = np.zeros((len(matching_cells_indices), num_types), dtype=int)
        np.add.at(neighbor_matrix, (row_indices, int_labels[flat_neighbors]), 1)

        mask = int_labels[matching_cells_indices] == int_ct
        searched_has_motif = np.all(neighbor_matrix[:, int_motifs] > 0, axis=1)
        n_motif_ct = np.sum(searched_has_motif[mask])
        n_motif_labels = np.sum(searched_has_motif)

        n_ct = len(cinds)
        if ct in motif:
            n_ct = round(n_ct / motif.count(ct))

        # M: total cells, n: center cells (successes), N: cells with the motif (draws).
        hyge = hypergeom(M=len(self.labels), n=n_ct, N=n_motif_labels)
        # One-sided enrichment p-value P(X >= observed); sf(x) alone is P(X > x).
        motif_out = {'center': ct, 'motifs': sort_motif, 'n_center_motif': n_motif_ct,
                     'n_center': n_ct, 'n_motif': n_motif_labels, 'expectation': hyge.mean(),
                     'p-values': hyge.sf(n_motif_ct - 1)}

        if return_cellID:
            center_has_motif = np.all(center_matrix_all[:, int_motifs] > 0, axis=1)
            cind_with_motif = cinds[center_has_motif]
            motif_mask = np.isin(label_array, motif)
            # The leading empty int array keeps concatenate valid when no center has the motif.
            all_neighbors = np.concatenate([np.empty(0, dtype=int)] +
                                           [idxs_all_filter[c] for c in cind_with_motif])
            exclude_self_mask = ~np.isin(all_neighbors, cind_with_motif)
            valid_neighbors = all_neighbors[motif_mask[all_neighbors] & exclude_self_mask]
            motif_out['cell_id'] = np.array(list(set(valid_neighbors)))

        out.append(motif_out)

    if not out:
        columns = ['center', 'motifs', 'n_center_motif', 'n_center', 'n_motif',
                   'expectation', 'p-values']
        if return_cellID:
            columns.append('cell_id')
        return pd.DataFrame(columns=columns + ['if_significant'])

    out_pd = pd.DataFrame(out)

    if len(out_pd) == 1:
        out_pd['if_significant'] = True if out_pd['p-values'][0] < 0.05 else False
        return out_pd

    p_values = out_pd['p-values'].tolist()
    if_rejected, corrected_p_values = mt.fdrcorrection(p_values,
                                                       alpha=0.05,
                                                       method='poscorr')
    out_pd['corrected p-values'] = corrected_p_values
    out_pd['if_significant'] = if_rejected
    out_pd = out_pd.sort_values(by='corrected p-values', ignore_index=True)
    return out_pd
|
|
522
|
+
|
|
523
|
+
|
|
524
|
+
def build_fptree_dist(self,
                      cell_pos: np.ndarray = None,
                      max_dist: float = 100,
                      min_support: float = 0.5,
                      if_max: bool = True,
                      min_size: int = 0,
                      cinds: List[int] = None,
                      max_ns: int = 100) -> tuple:
    """
    Build a frequency pattern tree based on the distance of cell types.

    Parameter
    ---------
    cell_pos:
        Spatial coordinates of input points.
        If cell_pos is None, use all spots in fov to compute frequent patterns.
    max_dist:
        Maximum distance to consider a cell as a neighbor (capped at ``self.max_radius``).
    min_support:
        Threshold of frequency to consider a pattern as a frequent pattern.
    if_max:
        By default return the maximum set of frequent patterns without the subsets. If if_max=False, return all
        patterns whose support values are greater than min_support.
    min_size:
        Neighborhoods with ``min_size`` or fewer cells are discarded.
    cinds:
        Cell indices of the points in ``cell_pos``, in the same order. Each point's own
        index is removed from its neighborhood. If None and the points are the cells
        themselves (``cell_pos`` is None or is ``self.spatial_pos``), ``range(n)`` is used.
        Otherwise the points are treated as arbitrary locations (e.g. grid points) and
        nothing is excluded.
    max_ns:
        Maximum number of neighborhood size for each point.

    Return
    ------
    A tuple containing the FPs (``support``/``itemsets`` with sorted lists, by descending
    support), the one-hot transactions table, and the kept neighbor indices per transaction.
    """
    query_cells = cell_pos is None or cell_pos is self.spatial_pos
    if cell_pos is None:
        cell_pos = self.spatial_pos

    max_dist = min(max_dist, self.max_radius)

    idxs = self.kd_tree.query_ball_point(cell_pos, r=max_dist, return_sorted=False, workers=-1)
    if cinds is None:
        # -1 never matches a cell index, so arbitrary query points exclude nothing.
        cinds = range(len(idxs)) if query_cells else [-1] * len(idxs)

    # Prepare data for FP-Tree construction
    transactions = []
    valid_idxs = []
    labels = np.array(self.labels)
    for center, neighbors in zip(cinds, idxs):
        if not neighbors:
            continue
        neighbors = np.array(neighbors)
        # Drop the center, then cap at max_ns. With return_sorted=False the order of
        # the neighbors (and therefore which ones survive the cap) is unspecified.
        kept = neighbors[neighbors != center][:max_ns]

        transaction = labels[kept]
        if len(transaction) > min_size:
            transactions.append(transaction.tolist())
            valid_idxs.append(kept)

    empty_fp = pd.DataFrame(columns=['support', 'itemsets'])
    if not transactions:
        # No neighborhood survived the filters: there is nothing to mine.
        return empty_fp, pd.DataFrame(dtype=bool), valid_idxs

    # Convert transactions to a one-hot DataFrame suitable for fpgrowth
    mlb = MultiLabelBinarizer()
    encoded_data = mlb.fit_transform(transactions)
    df = pd.DataFrame(encoded_data.astype(bool), columns=mlb.classes_)

    # Construct FP-Tree using fpgrowth
    fp_tree = fpgrowth(df, min_support=min_support, use_colnames=True)
    if if_max:
        fp_tree = self.find_maximal_patterns(fp=fp_tree)

    if len(fp_tree) == 0:
        return empty_fp, df, valid_idxs

    # Canonicalise itemsets as sorted tuples so duplicate rows can be dropped.
    fp_tree['itemsets'] = fp_tree['itemsets'].apply(lambda x: tuple(sorted(x)))
    fp_tree = fp_tree.drop_duplicates().reset_index(drop=True)
    fp_tree['itemsets'] = fp_tree['itemsets'].apply(lambda x: list(x))
    fp_tree = fp_tree.sort_values(by='support', ignore_index=True, ascending=False)
    return fp_tree, df, valid_idxs
|
|
614
|
+
|
|
615
|
+
def build_fptree_knn(self,
                     cell_pos: np.ndarray = None,
                     k: int = 30,
                     min_support: float = 0.5,
                     max_dist: float = 200,
                     if_max: bool = True
                     ) -> tuple:
    """
    Build a frequency pattern tree based on knn

    Parameter
    ---------
    cell_pos:
        Spatial coordinates of input points.
        If cell_pos is None, use all spots in fov to compute frequent patterns.
        The nearest hit for each point is dropped as "the point itself", so this
        assumes the points are cell coordinates — NOTE(review): confirm for new callers.
    k:
        Number of neighborhood size for each point.
    min_support:
        Threshold of frequency to consider a pattern as a frequent pattern
    max_dist:
        Only neighbors strictly closer than this distance are kept (capped at ``self.max_radius``).
    if_max:
        By default return the maximum set of frequent patterns without the subsets. If if_max=False, return all
        patterns whose support values are greater than min_support.

    Return
    ------
    A tuple containing the FPs (``support``/``itemsets`` with sorted lists, by descending
    support), the one-hot transactions table, and the raw (k+1)-NN index array.
    """
    if cell_pos is None:
        cell_pos = self.spatial_pos

    max_dist = min(max_dist, self.max_radius)
    dists, idxs = self.kd_tree.query(cell_pos, k=k + 1, workers=-1)
    idxs = np.array(idxs)
    dists = np.array(dists)

    # Prepare data for FP-Tree construction
    labels = np.array(self.labels)
    within = dists < max_dist
    transactions = []
    for row_idxs, row_mask in zip(idxs, within):
        inds = row_idxs[row_mask]
        if len(inds) == 0:
            continue
        # Skip the nearest hit (the query point itself) and store labels as a plain list.
        transactions.append(labels[inds[1:]].tolist())

    empty_fp = pd.DataFrame(columns=['support', 'itemsets'])
    if not transactions:
        # No point had any neighbor within max_dist: there is nothing to mine.
        return empty_fp, pd.DataFrame(dtype=bool), idxs

    # Convert transactions to a one-hot DataFrame suitable for fpgrowth
    mlb = MultiLabelBinarizer()
    encoded_data = mlb.fit_transform(transactions)
    df = pd.DataFrame(encoded_data.astype(bool), columns=mlb.classes_)

    # Construct FP-Tree using fpgrowth
    fp_tree = fpgrowth(df, min_support=min_support, use_colnames=True)
    if if_max:
        fp_tree = self.find_maximal_patterns(fp_tree)

    if len(fp_tree) == 0:
        return empty_fp, df, idxs

    # Canonicalise itemsets as sorted tuples so duplicate rows can be dropped.
    fp_tree['itemsets'] = fp_tree['itemsets'].apply(lambda x: tuple(sorted(x)))
    fp_tree = fp_tree.drop_duplicates().reset_index(drop=True)
    fp_tree['itemsets'] = fp_tree['itemsets'].apply(lambda x: list(x))
    fp_tree = fp_tree.sort_values(by='support', ignore_index=True, ascending=False)
    return fp_tree, df, idxs
|
|
708
|
+
|
|
709
|
+
def find_patterns_grid(self,
                       max_dist: float = 100,
                       min_size: int = 0,
                       min_support: float = 0.5,
                       if_display: bool = True,
                       fig_size: tuple = (10, 5),
                       return_cellID: bool = False,
                       return_grid: bool = False,
                       ):
    """
    Mine frequent cell-type patterns around the nodes of a regular grid.

    A square grid with spacing ``max_dist`` (capped at ``self.max_radius``) is laid over
    the bounding box of ``self.spatial_pos``, padded by one spacing on each side. The
    grid nodes are handed to ``build_fptree_dist`` as query centers (``if_max=True``).

    Parameter
    ---------
    max_dist:
        Grid spacing and neighborhood radius; values above ``self.max_radius`` are clipped.
    min_size:
        Forwarded unchanged to ``build_fptree_dist``.
    min_support:
        Threshold of frequency to consider a pattern as a frequent pattern.
    if_display:
        If True, draw the grid lines and every cell contributing to a frequent pattern,
        colored by cell type.
    fig_size:
        Tuple of figure size.
    return_cellID:
        If True, add a ``cell_id`` column holding, per pattern, the set of cell indices
        (restricted to the pattern's cell types) found around grid nodes whose
        transaction contains the whole pattern.
    return_grid:
        If True, return a tuple ``(fp_tree, grid)`` where ``grid`` is the (n_nodes, 2)
        array of grid coordinates.

    Return
    ------
    fp_tree:
        Frequent patterns sorted by descending support.
    """
    radius = min(max_dist, self.max_radius)
    xmin, ymin = np.min(self.spatial_pos, axis=0)
    xmax, ymax = np.max(self.spatial_pos, axis=0)

    x_grid = np.arange(xmin - radius, xmax + radius, radius)
    y_grid = np.arange(ymin - radius, ymax + radius, radius)
    grid_x, grid_y = np.meshgrid(x_grid, y_grid)
    grid = np.array([grid_x, grid_y]).T.reshape(-1, 2)

    fp, trans_df, idxs = self.build_fptree_dist(
        cell_pos=grid,
        max_dist=radius,
        min_size=min_size,
        min_support=min_support,
        if_max=True,
    )

    def _cells_of_motif(cts):
        # Rows (grid nodes) whose transaction contains every cell type of the motif.
        hit_rows = trans_df.index[trans_df[cts].all(axis=1).to_numpy()].to_list()
        # idxs is handled either as a Python list or as a fancy-indexable array.
        if isinstance(idxs, list):
            hoods = [idxs[row] for row in hit_rows]
        else:
            hoods = idxs[hit_rows]
        return {cell for hood in hoods for cell in hood if self.labels[cell] in cts}

    id_neighbor_motifs = []
    if if_display or return_cellID:
        id_neighbor_motifs = [_cells_of_motif(list(itemset)) for itemset in fp['itemsets']]
    if return_cellID:
        fp['cell_id'] = id_neighbor_motifs

    if if_display:
        pattern_cts = sorted({ct for itemset in fp['itemsets'] for ct in itemset})
        color_map = dict(zip(pattern_cts, sns.color_palette('hsv', len(pattern_cts))))

        covered = set()
        for cells in id_neighbor_motifs:
            covered.update(cells)
        covered_ids = list(covered)
        covered_pos = self.spatial_pos[covered_ids, :]
        covered_lab = self.labels[covered_ids]

        _, ax = plt.subplots(figsize=fig_size)
        # Dashed guide lines at every grid coordinate.
        for gx in x_grid:
            ax.axvline(gx, color='lightgray', linestyle='--', lw=0.5)
        for gy in y_grid:
            ax.axhline(gy, color='lightgray', linestyle='--', lw=0.5)

        for ct in pattern_cts:
            sel = covered_lab == ct
            ax.scatter(covered_pos[sel, 0], covered_pos[sel, 1],
                       label=ct, color=color_map[ct], s=1)
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), markerscale=4)
        plt.xlabel('Spatial X')
        plt.ylabel('Spatial Y')
        plt.title('Spatial distribution of frequent patterns')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_xticks([])
        ax.set_yticks([])
        plt.tight_layout(rect=[0, 0, 1.1, 1])
        plt.show()

    ranked = fp.sort_values(by='support', ignore_index=True, ascending=False)
    return (ranked, grid) if return_grid else ranked
|
|
830
|
+
|
|
831
|
+
def find_patterns_rand(self,
                       max_dist: float = 100,
                       n_points: int = 1000,
                       min_support: float = 0.5,
                       min_size: int = 0,
                       if_display: bool = True,
                       fig_size: tuple = (10, 5),
                       return_cellID: bool = False,
                       seed: int = 2023) -> DataFrame:
    """
    Randomly generate points and use them to find surrounding patterns in spatial data.

    ``n_points`` query points are drawn uniformly inside the bounding box of
    ``self.spatial_pos``. Their radius neighborhoods are mined by ``build_fptree_dist``
    (with ``if_max=True``).

    Parameter
    ---------
    max_dist:
        Radius of the neighborhood around each random point; clipped to ``self.max_radius``.
    n_points:
        Number of random points to generate.
    min_support:
        Threshold of frequency to consider a pattern as a frequent pattern.
    min_size:
        Forwarded unchanged to ``build_fptree_dist``.
    if_display:
        If True, display the cells making up the frequent patterns around the random points.
    fig_size:
        Tuple of figure size.
    return_cellID:
        If True, add a ``cell_id`` column holding, for each frequent pattern, the set of cell
        indices (restricted to the pattern's cell types) in the neighborhoods of random
        points whose transaction contains the whole pattern. By default no cell IDs are returned.
    seed:
        Seed passed to ``np.random.seed`` so the sampled points are reproducible. This
        reseeds NumPy's global random state. plot_motif_rand regenerates its points the
        same way from the same seed.

    Return
    ------
    Frequent patterns sorted by descending support.
    """
    max_dist = min(max_dist, self.max_radius)
    xmax, ymax = np.max(self.spatial_pos, axis=0)
    xmin, ymin = np.min(self.spatial_pos, axis=0)
    np.random.seed(seed)
    pos = np.column_stack((np.random.rand(n_points) * (xmax - xmin) + xmin,
                           np.random.rand(n_points) * (ymax - ymin) + ymin))

    fp, trans_df, idxs = self.build_fptree_dist(
        cell_pos=pos,
        max_dist=max_dist,
        min_size=min_size,
        min_support=min_support,
        if_max=True,
    )

    # For each frequent pattern, collect the cells of the pattern's cell types that lie
    # around random points whose transaction contains the whole pattern.
    id_neighbor_motifs = []
    if if_display or return_cellID:
        for motif in fp['itemsets']:
            motif = list(motif)
            rows = trans_df[trans_df[motif].all(axis=1)].index.to_list()
            # idxs is handled either as a Python list or as a fancy-indexable array.
            if isinstance(idxs, list):
                neighborhoods = [idxs[r] for r in rows]
            else:
                neighborhoods = idxs[rows]
            id_neighbor_motifs.append(
                {i for nbrs in neighborhoods for i in nbrs if self.labels[i] in motif})
    if return_cellID:
        fp['cell_id'] = id_neighbor_motifs

    if if_display:
        fp_cts = sorted(set(t for items in fp['itemsets'] for t in list(items)))
        colors = sns.color_palette('hsv', len(fp_cts))
        color_map = dict(zip(fp_cts, colors))

        fp_spots_index = set()
        for cell_ids in id_neighbor_motifs:
            fp_spots_index.update(cell_ids)

        fp_spot_pos = self.spatial_pos[list(fp_spots_index), :]
        fp_spot_label = self.labels[list(fp_spots_index)]
        fig, ax = plt.subplots(figsize=fig_size)
        for ct in fp_cts:
            ct_ind = fp_spot_label == ct
            ax.scatter(fp_spot_pos[ct_ind, 0], fp_spot_pos[ct_ind, 1],
                       label=ct, color=color_map[ct], s=1)
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), markerscale=4)
        plt.xlabel('Spatial X')
        plt.ylabel('Spatial Y')
        plt.title('Spatial distribution of frequent patterns')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_xticks([])
        ax.set_yticks([])
        plt.tight_layout(rect=[0, 0, 1.1, 1])
        plt.show()

    return fp.sort_values(by='support', ignore_index=True, ascending=False)
|
|
944
|
+
|
|
945
|
+
def plot_fov(self,
             min_cells_label: int = 50,
             title: str = 'Spatial distribution of cell types',
             fig_size: tuple = (10, 5)):
    """
    Plot the cell type distribution of single fov.

    Cell types with at least ``min_cells_label`` cells get their own color and a legend
    entry of the form ``"<cell type> (<count>)"``. Rarer cell types are drawn in grey
    without a legend entry.

    Parameter
    --------
    min_cells_label:
        Minimum number of points in each cell type to display in color.
    title:
        Figure title.
    fig_size:
        Figure size parameters.

    Return
    ------
    None. The figure is displayed with ``plt.show()``.
    """
    cell_type_counts = self.labels.value_counts()
    n_colors = sum(cell_type_counts >= min_cells_label)
    colors = sns.color_palette('hsv', n_colors)

    color_counter = 0
    fig, ax = plt.subplots(figsize=fig_size)

    # Collect legend entries straight from the scatter handles. Matplotlib stores artist
    # labels as strings, so looking counts up again by the labels returned from
    # ax.get_legend_handles_labels() fails for non-string (e.g. integer) cell types.
    legend_handles = []
    legend_labels = []
    for cell_type in sorted(self.labels.unique()):
        index = np.where(self.labels == cell_type)[0]
        count = cell_type_counts[cell_type]
        if count >= min_cells_label:
            handle = ax.scatter(self.spatial_pos[index, 0], self.spatial_pos[index, 1],
                                label=cell_type, color=colors[color_counter], s=1)
            color_counter += 1
            legend_handles.append(handle)
            legend_labels.append(f'{cell_type} ({count})')
        else:
            ax.scatter(self.spatial_pos[index, 0], self.spatial_pos[index, 1],
                       color='grey', s=1)

    ax.legend(legend_handles, legend_labels, loc='center left', bbox_to_anchor=(1, 0.5), markerscale=4)

    plt.xlabel('Spatial X')
    plt.ylabel('Spatial Y')
    plt.title(title)
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_xticks([])
    ax.set_yticks([])

    # Adjust layout to prevent clipping of ylabel and accommodate the legend
    plt.tight_layout(rect=[0, 0, 1.1, 1])

    plt.show()
|
|
1012
|
+
|
|
1013
|
+
def plot_motif_grid(self,
                    motif: Union[str, List[str]],
                    fp: pd.DataFrame,
                    fig_size: tuple = (10, 5),
                    max_dist: float = 100,
                    ):
    """
    Display the distribution of each motif around grid points. To make sure the input
    motif can be found in the results obtained by find_patterns_grid, use the same arguments
    as those in find_patterns_grid method.

    Grid nodes whose radius neighborhood contains the motif (as judged by ``self.has_motif``)
    are drawn as red circles. The cells listed in the motif's ``cell_id`` entry of ``fp`` are
    colored by cell type, and all other cells are drawn in dark grey.

    Parameter
    ---------
    motif:
        Motif (names of cell types) to be colored. Names absent from the data are ignored.
    fp:
        Frequent patterns identified by find_patterns_grid with ``return_cellID=True``
        (the ``cell_id`` column is required).
    fig_size:
        Figure size.
    max_dist:
        Spacing distance for building grid; clipped to ``self.max_radius``. Make sure
        using the same value as that in find_patterns_grid.

    Raises
    ------
    ValueError
        If none of the motif cell types occur in the data, if ``fp`` has no ``cell_id``
        column, or if the motif is not one of the patterns in ``fp``.
    """
    if isinstance(motif, str):
        motif = [motif]

    max_dist = min(max_dist, self.max_radius)

    labels_unique = self.labels.unique()
    motif_exc = [m for m in motif if m not in labels_unique]
    if len(motif_exc) != 0:
        print(f"Found no {motif_exc} in {self.label_key}. Ignoring them.")
    motif = [m for m in motif if m not in motif_exc]

    # Validate the request before the neighborhood queries, so a missing pattern gives a
    # clear error instead of a KeyError/IndexError after all the expensive work.
    if not motif:
        raise ValueError(f"None of the motif cell types were found in {self.label_key}.")
    if 'cell_id' not in fp.columns:
        raise ValueError("`fp` has no 'cell_id' column; run find_patterns_grid with return_cellID=True.")
    target = set(motif)
    for items, cells in zip(fp['itemsets'], fp['cell_id']):
        if set(items) == target:
            id_motif_celltype = cells
            break
    else:
        raise ValueError(f"Motif {motif} is not one of the frequent patterns in `fp`.")
    motif_cells = set(id_motif_celltype)  # O(1) membership for the background filter

    # Build mesh
    xmax, ymax = np.max(self.spatial_pos, axis=0)
    xmin, ymin = np.min(self.spatial_pos, axis=0)
    x_grid = np.arange(xmin - max_dist, xmax + max_dist, max_dist)
    y_grid = np.arange(ymin - max_dist, ymax + max_dist, max_dist)
    grid = np.array(np.meshgrid(x_grid, y_grid)).T.reshape(-1, 2)

    # The idxs returned by self.build_fptree_dist are not the full per-grid-point lists,
    # so recalculate them directly using self.kd_tree.query_ball_point
    idxs = self.kd_tree.query_ball_point(grid, r=max_dist, return_sorted=False, workers=-1)

    # Locate the index of grid points acting as centers with motif nearby
    id_center = []
    for i, idx in enumerate(idxs):
        ns = [self.labels[j] for j in idx]
        if self.has_motif(neighbors=motif, labels=ns):
            id_center.append(i)

    # Set color map as in find_patterns_grid
    fp_cts = sorted(set(t for items in fp['itemsets'] for t in list(items)))
    colors = sns.color_palette('hsv', len(fp_cts))
    color_map = dict(zip(fp_cts, colors))

    motif_spot_pos = self.spatial_pos[list(id_motif_celltype), :]
    motif_spot_label = self.labels[list(id_motif_celltype)]
    fig, ax = plt.subplots(figsize=fig_size)
    # Plotting the grid lines
    for x in x_grid:
        ax.axvline(x, color='lightgray', linestyle='--', lw=0.5)
    for y in y_grid:
        ax.axhline(y, color='lightgray', linestyle='--', lw=0.5)
    ax.scatter(grid[id_center, 0], grid[id_center, 1], label='Grid Points',
               edgecolors='red', facecolors='none', s=8)

    # Every cell that is not part of the motif neighborhoods is drawn as background
    bg_index = [i for i in range(len(self.labels)) if i not in motif_cells]
    bg_pos = self.spatial_pos[bg_index, :]
    ax.scatter(bg_pos[:, 0],
               bg_pos[:, 1],
               color='darkgrey', s=1)

    # dict.fromkeys keeps the user's order, so the legend order is deterministic
    for ct in dict.fromkeys(motif):
        ct_ind = motif_spot_label == ct
        ax.scatter(motif_spot_pos[ct_ind, 0],
                   motif_spot_pos[ct_ind, 1],
                   label=ct, color=color_map[ct], s=1)

    ax.set_xlim([xmin - max_dist, xmax + max_dist])
    ax.set_ylim([ymin - max_dist, ymax + max_dist])
    ax.legend(title='motif', loc='center left', bbox_to_anchor=(1, 0.5), markerscale=4)
    plt.xlabel('Spatial X')
    plt.ylabel('Spatial Y')
    plt.title('Spatial distribution of frequent patterns')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_xticks([])
    ax.set_yticks([])
    plt.tight_layout(rect=[0, 0, 1.1, 1])
    plt.show()
|
|
1118
|
+
|
|
1119
|
+
def plot_motif_rand(self,
                    motif: Union[str, List[str]],
                    fp: pd.DataFrame,
                    max_dist: float = 100,
                    n_points: int = 1000,
                    fig_size: tuple = (10, 5),
                    seed: int = 2023,
                    ):
    """
    Display the randomly sampled points with the motif in their radius-based neighborhood,
    and the motif's cell types in the neighborhood of these points. To make sure the input
    motif can be found in the results obtained by find_patterns_rand, use the same arguments
    (``max_dist``, ``n_points``, ``seed``) as those in find_patterns_rand.

    Random points whose neighborhood contains the motif (as judged by ``self.has_motif``)
    are drawn as red circles. The cells listed in the motif's ``cell_id`` entry of ``fp`` are
    colored by cell type, and all other cells are drawn in dark grey.

    Parameter
    ---------
    motif:
        Motif (names of cell types) to be colored. Names absent from the data are ignored.
    fp:
        Frequent patterns identified by find_patterns_rand with ``return_cellID=True``
        (the ``cell_id`` column is required).
    max_dist:
        Neighborhood radius around each random point; clipped to ``self.max_radius``.
    n_points:
        Number of random points to generate.
    fig_size:
        Figure size.
    seed:
        Seed passed to ``np.random.seed`` (reseeds NumPy's global random state) so the same
        points as in find_patterns_rand are regenerated.

    Raises
    ------
    ValueError
        If none of the motif cell types occur in the data, if ``fp`` has no ``cell_id``
        column, or if the motif is not one of the patterns in ``fp``.
    """
    if isinstance(motif, str):
        motif = [motif]

    max_dist = min(max_dist, self.max_radius)

    labels_unique = self.labels.unique()
    motif_exc = [m for m in motif if m not in labels_unique]
    if len(motif_exc) != 0:
        print(f"Found no {motif_exc} in {self.label_key}. Ignoring them.")
    motif = [m for m in motif if m not in motif_exc]

    # Validate the request before the neighborhood queries, so a missing pattern gives a
    # clear error instead of a KeyError/IndexError after all the expensive work.
    if not motif:
        raise ValueError(f"None of the motif cell types were found in {self.label_key}.")
    if 'cell_id' not in fp.columns:
        raise ValueError("`fp` has no 'cell_id' column; run find_patterns_rand with return_cellID=True.")
    target = set(motif)
    for items, cells in zip(fp['itemsets'], fp['cell_id']):
        if set(items) == target:
            id_motif_celltype = cells
            break
    else:
        raise ValueError(f"Motif {motif} is not one of the frequent patterns in `fp`.")
    motif_cells = set(id_motif_celltype)  # O(1) membership for the background filter

    # Random sample points (same recipe as find_patterns_rand)
    xmax, ymax = np.max(self.spatial_pos, axis=0)
    xmin, ymin = np.min(self.spatial_pos, axis=0)
    np.random.seed(seed)
    pos = np.column_stack((np.random.rand(n_points) * (xmax - xmin) + xmin,
                           np.random.rand(n_points) * (ymax - ymin) + ymin))

    idxs = self.kd_tree.query_ball_point(pos, r=max_dist, return_sorted=False, workers=-1)

    # Locate the index of random points acting as centers with motif nearby
    id_center = []
    for i, idx in enumerate(idxs):
        ns = [self.labels[j] for j in idx]
        if self.has_motif(neighbors=motif, labels=ns):
            id_center.append(i)

    # Set color map as in find_patterns_rand
    fp_cts = sorted(set(t for items in fp['itemsets'] for t in list(items)))
    colors = sns.color_palette('hsv', len(fp_cts))
    color_map = dict(zip(fp_cts, colors))

    motif_spot_pos = self.spatial_pos[list(id_motif_celltype), :]
    motif_spot_label = self.labels[list(id_motif_celltype)]
    fig, ax = plt.subplots(figsize=fig_size)
    ax.scatter(pos[id_center, 0], pos[id_center, 1], label='Random Sampling Points',
               edgecolors='red', facecolors='none', s=8)

    # Every cell that is not part of the motif neighborhoods is drawn as background
    bg_index = [i for i in range(len(self.labels)) if i not in motif_cells]
    bg_pos = self.spatial_pos[bg_index, :]
    ax.scatter(bg_pos[:, 0],
               bg_pos[:, 1],
               color='darkgrey', s=1)

    # dict.fromkeys keeps the user's order, so the legend order is deterministic
    for ct in dict.fromkeys(motif):
        ct_ind = motif_spot_label == ct
        ax.scatter(motif_spot_pos[ct_ind, 0],
                   motif_spot_pos[ct_ind, 1],
                   label=ct, color=color_map[ct], s=1)

    ax.set_xlim([xmin - max_dist, xmax + max_dist])
    ax.set_ylim([ymin - max_dist, ymax + max_dist])
    ax.legend(title='motif', loc='center left', bbox_to_anchor=(1, 0.5), markerscale=4)
    plt.xlabel('Spatial X')
    plt.ylabel('Spatial Y')
    plt.title('Spatial distribution of frequent patterns')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_xticks([])
    ax.set_yticks([])
    plt.tight_layout(rect=[0, 0, 1.1, 1])
    plt.show()
|
|
1220
|
+
|
|
1221
|
+
def plot_motif_celltype(self,
                        ct: str,
                        motif: Union[str, List[str]],
                        max_dist: float = 100,
                        fig_size: tuple = (10, 5)
                        ):
    """
    Display the distribution of interested motifs in the radius-based neighborhood of certain cell type.
    This function is mainly used to visualize the results of motif_enrichment_dist. Make sure the input parameters
    are consistent with those of motif_enrichment_dist.

    Cells of type ``ct`` qualify as centers when their neighborhood (radius ``max_dist``,
    the cell itself excluded) contains at least one cell of every type in ``motif``.
    Qualifying centers are drawn as red circles. The motif-type cells in their
    neighborhoods are colored by type, and all remaining cells are drawn in dark grey.

    Parameter
    ---------
    ct:
        Cell type as the center cells.
    motif:
        Motif (names of cell types) to be colored. Names absent from the data are ignored.
    max_dist:
        Neighborhood radius around each center cell; clipped to ``self.max_radius``.
        Use the same value as in motif_enrichment_dist.
    fig_size:
        Figure size.

    Raises
    ------
    ValueError
        If ``ct`` does not occur in the labels, or none of the motif cell types occur.
    """
    if isinstance(motif, str):
        motif = [motif]

    max_dist = min(max_dist, self.max_radius)

    labels_unique = self.labels.unique()
    motif_exc = [m for m in motif if m not in labels_unique]
    if len(motif_exc) != 0:
        print(f"Found no {motif_exc} in {self.label_key}. Ignoring them.")
    motif = [m for m in motif if m not in motif_exc]

    if ct not in labels_unique:
        raise ValueError(f"Found no {ct} in {self.label_key}!")
    if not motif:
        raise ValueError(f"Found none of the motif cell types in {self.label_key}!")

    # Labels are used positionally throughout (cell i <-> row i of spatial_pos).
    labels_arr = np.asarray(self.labels)
    idxs = self.kd_tree.query_ball_point(self.spatial_pos, r=max_dist, return_sorted=False, workers=-1)

    # Integer-encode labels so neighbor counts per cell type fit in a dense matrix.
    label_encoder = LabelEncoder()
    int_labels = label_encoder.fit_transform(labels_arr)
    int_ct = label_encoder.transform(np.array(ct, dtype=object, ndmin=1))[0]
    int_motifs = label_encoder.transform(np.array(motif))

    num_cells = len(idxs)
    num_types = len(label_encoder.classes_)
    # Neighbor lists with the query cell itself removed.
    idxs_filter = []
    for i, ids in enumerate(idxs):
        nbrs = np.asarray(ids, dtype=np.intp)
        idxs_filter.append(nbrs[nbrs != i])

    flat_neighbors = np.concatenate(idxs_filter)
    row_indices = np.repeat(np.arange(num_cells), [len(neigh) for neigh in idxs_filter])
    neighbor_matrix = np.zeros((num_cells, num_types), dtype=int)
    np.add.at(neighbor_matrix, (row_indices, int_labels[flat_neighbors]), 1)

    # Centers of type ct whose neighborhood holds at least one cell of each motif type.
    centers = np.flatnonzero(int_labels == int_ct)
    has_all = np.all(neighbor_matrix[centers][:, int_motifs] > 0, axis=1)
    cind_with_motif = centers[has_all].tolist()

    # Motif-type cells in those neighborhoods, excluding only each center itself, so that
    # a neighboring center of a motif type (when ct is in motif) is still colored.
    motif_mask = np.isin(labels_arr, motif)
    if cind_with_motif:
        hoods = [np.asarray(idxs[c], dtype=np.intp) for c in cind_with_motif]
        all_neighbors = np.concatenate(hoods)
        owners = np.repeat(cind_with_motif, [len(h) for h in hoods])
        keep = motif_mask[all_neighbors] & (all_neighbors != owners)
        id_motif_celltype = set(all_neighbors[keep].tolist())
    else:
        # np.concatenate needs at least one array; nothing to highlight in this case.
        print(f"No {ct} cell has {motif} within {max_dist} of it.")
        id_motif_celltype = set()

    # Plot figures; dict.fromkeys gives a deterministic color/legend order.
    motif_unique = list(dict.fromkeys(motif))
    colors = sns.color_palette('hsv', len(motif_unique))
    color_map = dict(zip(motif_unique, colors))
    motif_ids = list(id_motif_celltype)
    motif_spot_pos = self.spatial_pos[motif_ids, :]
    motif_spot_label = labels_arr[motif_ids]
    fig, ax = plt.subplots(figsize=fig_size)

    # Plotting other spots as background
    bg_mask = np.ones(len(labels_arr), dtype=bool)
    bg_mask[motif_ids] = False
    bg_mask[cind_with_motif] = False
    bg_pos = self.spatial_pos[bg_mask, :]
    ax.scatter(bg_pos[:, 0],
               bg_pos[:, 1],
               color='darkgrey', s=1)
    # Plot center the cell type whose neighborhood contains motif
    ax.scatter(self.spatial_pos[cind_with_motif, 0],
               self.spatial_pos[cind_with_motif, 1],
               label=ct, edgecolors='red', facecolors='none', s=3,
               )
    for ct_m in motif_unique:
        ct_ind = motif_spot_label == ct_m
        ax.scatter(motif_spot_pos[ct_ind, 0],
                   motif_spot_pos[ct_ind, 1],
                   label=ct_m, color=color_map[ct_m], s=1)

    ax.legend(title='motif', loc='center left', bbox_to_anchor=(1, 0.5), markerscale=4)
    plt.xlabel('Spatial X')
    plt.ylabel('Spatial Y')
    plt.title(f"Spatial distribution of motif around {ct}")
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_xticks([])
    ax.set_yticks([])
    plt.tight_layout(rect=[0, 0, 1.1, 1])
    plt.show()
|