SpatialQuery 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- SpatialQuery/__init__.py +3 -0
- SpatialQuery/spatial_query.py +1339 -0
- SpatialQuery/spatial_query_multiple_fov.py +1116 -0
- SpatialQuery/utils.py +130 -0
- spatialquery-0.0.1.dist-info/METADATA +56 -0
- spatialquery-0.0.1.dist-info/RECORD +9 -0
- spatialquery-0.0.1.dist-info/WHEEL +5 -0
- spatialquery-0.0.1.dist-info/licenses/LICENSE +21 -0
- spatialquery-0.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1116 @@
|
|
|
1
|
+
from typing import List, Union
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pandas as pd
|
|
5
|
+
import scipy.stats as stats
|
|
6
|
+
import statsmodels.stats.multitest as mt
|
|
7
|
+
from anndata import AnnData
|
|
8
|
+
from mlxtend.frequent_patterns import fpgrowth
|
|
9
|
+
from sklearn.preprocessing import MultiLabelBinarizer
|
|
10
|
+
from pandas import DataFrame
|
|
11
|
+
from scipy.stats import hypergeom
|
|
12
|
+
import time
|
|
13
|
+
import matplotlib.pyplot as plt
|
|
14
|
+
from collections import defaultdict
|
|
15
|
+
from sklearn.preprocessing import LabelEncoder
|
|
16
|
+
|
|
17
|
+
from .spatial_query import spatial_query
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class spatial_query_multi:
|
|
21
|
+
def __init__(self,
|
|
22
|
+
adatas: List[AnnData],
|
|
23
|
+
datasets: List[str],
|
|
24
|
+
spatial_key: str,
|
|
25
|
+
label_key: str,
|
|
26
|
+
leaf_size: int,
|
|
27
|
+
max_radius: float = 500,
|
|
28
|
+
n_split: int = 10,
|
|
29
|
+
):
|
|
30
|
+
"""
|
|
31
|
+
Initiate models, including setting attributes and building kd-tree for each field of view.
|
|
32
|
+
|
|
33
|
+
Parameter
|
|
34
|
+
---------
|
|
35
|
+
adatas:
|
|
36
|
+
List of adata
|
|
37
|
+
datasets:
|
|
38
|
+
List of dataset names
|
|
39
|
+
spatial_key:
|
|
40
|
+
Spatial coordination name in AnnData.obsm object
|
|
41
|
+
label_key:
|
|
42
|
+
Label name in AnnData.obs object
|
|
43
|
+
leaf_size:
|
|
44
|
+
The largest number of points stored in each leaf node.
|
|
45
|
+
max_radius:
|
|
46
|
+
The upper limit of neighborhood radius.
|
|
47
|
+
"""
|
|
48
|
+
# Each element in self.spatial_queries stores a spatial_query object
|
|
49
|
+
self.spatial_key = spatial_key
|
|
50
|
+
self.label_key = label_key
|
|
51
|
+
self.max_radius = max_radius
|
|
52
|
+
# Modify dataset names by d_0, d_2, ... for duplicates in datasets
|
|
53
|
+
count_dict = {}
|
|
54
|
+
modified_datasets = []
|
|
55
|
+
for dataset in datasets:
|
|
56
|
+
if '_' in dataset:
|
|
57
|
+
print(f"Warning: Misusage of underscore in '{dataset}'. Replacing with hyphen.")
|
|
58
|
+
dataset = dataset.replace('_', '-')
|
|
59
|
+
|
|
60
|
+
if dataset in count_dict:
|
|
61
|
+
count_dict[dataset] += 1
|
|
62
|
+
else:
|
|
63
|
+
count_dict[dataset] = 0
|
|
64
|
+
|
|
65
|
+
mod_dataset = f"{dataset}_{count_dict[dataset]}"
|
|
66
|
+
modified_datasets.append(mod_dataset)
|
|
67
|
+
|
|
68
|
+
self.datasets = modified_datasets
|
|
69
|
+
|
|
70
|
+
self.spatial_queries = [spatial_query(
|
|
71
|
+
adata=adata,
|
|
72
|
+
dataset=self.datasets[i],
|
|
73
|
+
spatial_key=spatial_key,
|
|
74
|
+
label_key=label_key,
|
|
75
|
+
leaf_size=leaf_size,
|
|
76
|
+
max_radius=self.max_radius,
|
|
77
|
+
n_split=n_split,
|
|
78
|
+
) for i, adata in enumerate(adatas)]
|
|
79
|
+
|
|
80
|
+
def find_fp_knn(self,
|
|
81
|
+
ct: str,
|
|
82
|
+
dataset: Union[str, List[str]] = None,
|
|
83
|
+
k: int = 30,
|
|
84
|
+
min_support: float = 0.5,
|
|
85
|
+
max_dist: float = 500
|
|
86
|
+
) -> pd.DataFrame:
|
|
87
|
+
"""
|
|
88
|
+
Find frequent patterns within the KNNs of certain cell type in multiple fields of view.
|
|
89
|
+
|
|
90
|
+
Parameter
|
|
91
|
+
---------
|
|
92
|
+
ct:
|
|
93
|
+
Cell type name.
|
|
94
|
+
dataset:
|
|
95
|
+
Datasets for searching for frequent patterns.
|
|
96
|
+
Use all datasets if dataset=None.
|
|
97
|
+
k:
|
|
98
|
+
Number of nearest neighbors.
|
|
99
|
+
min_support:
|
|
100
|
+
Threshold of frequency to consider a pattern as a frequent pattern.
|
|
101
|
+
max_dist:
|
|
102
|
+
The maximum distance at which points are considered neighbors.
|
|
103
|
+
|
|
104
|
+
Return
|
|
105
|
+
------
|
|
106
|
+
Frequent patterns in the neighborhood of certain cell type.
|
|
107
|
+
"""
|
|
108
|
+
# Search transactions for each field of view, find the frequent patterns of integrated transactions
|
|
109
|
+
# start = time.time()
|
|
110
|
+
if_exist_label = [ct in s.labels.unique() for s in self.spatial_queries]
|
|
111
|
+
if not any(if_exist_label):
|
|
112
|
+
raise ValueError(f"Found no {self.label_key} in all datasets!")
|
|
113
|
+
|
|
114
|
+
if dataset is None:
|
|
115
|
+
# Use all datasets if dataset is not provided
|
|
116
|
+
dataset = [s.dataset.split('_')[0] for s in self.spatial_queries]
|
|
117
|
+
|
|
118
|
+
# Make sure dataset is a list
|
|
119
|
+
if isinstance(dataset, str):
|
|
120
|
+
dataset = [dataset]
|
|
121
|
+
|
|
122
|
+
# test if the input dataset name is valid
|
|
123
|
+
valid_ds_names = [s.dataset.split('_')[0] for s in self.spatial_queries]
|
|
124
|
+
for ds in dataset:
|
|
125
|
+
if ds not in valid_ds_names:
|
|
126
|
+
raise ValueError(f"Invalid input dataset name: {ds}.\n "
|
|
127
|
+
f"Valid dataset names are: {set(valid_ds_names)}")
|
|
128
|
+
|
|
129
|
+
max_dist = min(max_dist, self.max_radius)
|
|
130
|
+
# end = time.time()
|
|
131
|
+
# print(f"time for checking validation of inputs: {end-start} seconds")
|
|
132
|
+
|
|
133
|
+
# start = time.time()
|
|
134
|
+
transactions = []
|
|
135
|
+
for s in self.spatial_queries:
|
|
136
|
+
if s.dataset.split('_')[0] not in dataset:
|
|
137
|
+
continue
|
|
138
|
+
cell_pos = s.spatial_pos
|
|
139
|
+
labels = np.array(s.labels)
|
|
140
|
+
if ct not in np.unique(labels):
|
|
141
|
+
continue
|
|
142
|
+
|
|
143
|
+
ct_pos = cell_pos[labels == ct]
|
|
144
|
+
dists, idxs = s.kd_tree.query(ct_pos, k=k + 1, workers=-1)
|
|
145
|
+
mask = dists < max_dist
|
|
146
|
+
for i, idx in enumerate(idxs):
|
|
147
|
+
inds = idx[mask[i]]
|
|
148
|
+
transaction = labels[inds[1:]]
|
|
149
|
+
transactions.append(transaction)
|
|
150
|
+
# end = time.time()
|
|
151
|
+
# print(f"time for building {len(transactions)} transactions: {end-start} seconds.")
|
|
152
|
+
|
|
153
|
+
# start = time.time()
|
|
154
|
+
mlb = MultiLabelBinarizer()
|
|
155
|
+
encoded_data = mlb.fit_transform(transactions)
|
|
156
|
+
df = pd.DataFrame(encoded_data.astype(bool), columns=mlb.classes_)
|
|
157
|
+
# end = time.time()
|
|
158
|
+
# print(f"time for building df for fpgrowth: {end-start} seconds")
|
|
159
|
+
|
|
160
|
+
# start = time.time()
|
|
161
|
+
fp = fpgrowth(df, min_support=min_support, use_colnames=True)
|
|
162
|
+
if len(fp) == 0:
|
|
163
|
+
return pd.DataFrame(columns=['support', 'itemsets'])
|
|
164
|
+
# end = time.time()
|
|
165
|
+
# print(f"time for find fp_growth: {end-start} seconds, {len(fp)} frequent patterns.")
|
|
166
|
+
# start = time.time()
|
|
167
|
+
fp = spatial_query.find_maximal_patterns(fp=fp)
|
|
168
|
+
# end = time.time()
|
|
169
|
+
# print(f"time for identify maximal patterns: {end - start} seconds")
|
|
170
|
+
|
|
171
|
+
fp['itemsets'] = fp['itemsets'].apply(lambda x: list(sorted(x)))
|
|
172
|
+
fp.sort_values(by='support', ascending=False, inplace=True, ignore_index=True)
|
|
173
|
+
|
|
174
|
+
return fp
|
|
175
|
+
|
|
176
|
+
def find_fp_dist(self,
|
|
177
|
+
ct: str,
|
|
178
|
+
dataset: Union[str, List[str]] = None,
|
|
179
|
+
max_dist: float = 100,
|
|
180
|
+
min_size: int = 0,
|
|
181
|
+
min_support: float = 0.5,
|
|
182
|
+
max_ns: int = 100
|
|
183
|
+
):
|
|
184
|
+
"""
|
|
185
|
+
Find frequent patterns within the radius of certain cell type in multiple fields of view.
|
|
186
|
+
|
|
187
|
+
Parameter
|
|
188
|
+
---------
|
|
189
|
+
ct:
|
|
190
|
+
Cell type name.
|
|
191
|
+
dataset:
|
|
192
|
+
Datasets for searching for frequent patterns.
|
|
193
|
+
Use all datasets if dataset=None.
|
|
194
|
+
max_dist:
|
|
195
|
+
Maximum distance for considering a cell as a neighbor.
|
|
196
|
+
min_size:
|
|
197
|
+
Minimum neighborhood size for each point to consider.
|
|
198
|
+
min_support:
|
|
199
|
+
Threshold of frequency to consider a pattern as a frequent pattern.
|
|
200
|
+
max_ns:
|
|
201
|
+
Upper limit of neighbors for each point.
|
|
202
|
+
|
|
203
|
+
Return
|
|
204
|
+
------
|
|
205
|
+
Frequent patterns in the neighborhood of certain cell type.
|
|
206
|
+
"""
|
|
207
|
+
# Search transactions for each field of view, find the frequent patterns of integrated transactions
|
|
208
|
+
if_exist_label = [ct in s.labels.unique() for s in self.spatial_queries]
|
|
209
|
+
if not any(if_exist_label):
|
|
210
|
+
raise ValueError(f"Found no {self.label_key} in any datasets!")
|
|
211
|
+
|
|
212
|
+
if dataset is None:
|
|
213
|
+
dataset = [s.dataset.split('_')[0] for s in self.spatial_queries]
|
|
214
|
+
if isinstance(dataset, str):
|
|
215
|
+
dataset = [dataset]
|
|
216
|
+
|
|
217
|
+
# test if the input dataset name is valid
|
|
218
|
+
valid_ds_names = [s.dataset.split('_')[0] for s in self.spatial_queries]
|
|
219
|
+
for ds in dataset:
|
|
220
|
+
if ds not in valid_ds_names:
|
|
221
|
+
raise ValueError(f"Invalid input dataset name: {ds}.\n "
|
|
222
|
+
f"Valid dataset names are: {set(valid_ds_names)}")
|
|
223
|
+
|
|
224
|
+
max_dist = min(max_dist, self.max_radius)
|
|
225
|
+
# start = time.time()
|
|
226
|
+
transactions = []
|
|
227
|
+
for s in self.spatial_queries:
|
|
228
|
+
if s.dataset.split('_')[0] not in dataset:
|
|
229
|
+
continue
|
|
230
|
+
cell_pos = s.spatial_pos
|
|
231
|
+
labels = np.array(s.labels)
|
|
232
|
+
if ct not in np.unique(labels):
|
|
233
|
+
continue
|
|
234
|
+
|
|
235
|
+
cinds = [id for id, l in enumerate(labels) if l == ct]
|
|
236
|
+
ct_pos = cell_pos[cinds]
|
|
237
|
+
|
|
238
|
+
idxs = s.kd_tree.query_ball_point(ct_pos, r=max_dist, return_sorted=False, workers=-1)
|
|
239
|
+
|
|
240
|
+
for i_id, idx in zip(cinds, idxs):
|
|
241
|
+
transaction = [labels[i] for i in idx[:min(max_ns, len(idx))] if i != i_id]
|
|
242
|
+
if len(transaction) > min_size:
|
|
243
|
+
transactions.append(transaction)
|
|
244
|
+
|
|
245
|
+
# end = time.time()
|
|
246
|
+
# print(f"time for building {len(transactions)} transactions: {end-start} seconds.")
|
|
247
|
+
|
|
248
|
+
# start = time.time()
|
|
249
|
+
mlb = MultiLabelBinarizer()
|
|
250
|
+
encoded_data = mlb.fit_transform(transactions)
|
|
251
|
+
df = pd.DataFrame(encoded_data.astype(bool), columns=mlb.classes_)
|
|
252
|
+
# end = time.time()
|
|
253
|
+
# print(f"time for building df for fpgrowth: {end - start} seconds")
|
|
254
|
+
|
|
255
|
+
# start = time.time()
|
|
256
|
+
fp = fpgrowth(df, min_support=min_support, use_colnames=True)
|
|
257
|
+
if len(fp) == 0:
|
|
258
|
+
return pd.DataFrame(columns=['support', 'itemsets'])
|
|
259
|
+
# end = time.time()
|
|
260
|
+
# print(f"time for find fp_growth: {end - start} seconds, {len(fp)} frequent patterns.")
|
|
261
|
+
|
|
262
|
+
# start = time.time()
|
|
263
|
+
fp = spatial_query.find_maximal_patterns(fp=fp)
|
|
264
|
+
# end = time.time()
|
|
265
|
+
# print(f"time for identify maximal patterns: {end - start} seconds")
|
|
266
|
+
|
|
267
|
+
fp['itemsets'] = fp['itemsets'].apply(lambda x: list(sorted(x)))
|
|
268
|
+
fp.sort_values(by='support', ascending=False, inplace=True, ignore_index=True)
|
|
269
|
+
|
|
270
|
+
return fp
|
|
271
|
+
|
|
272
|
+
def motif_enrichment_knn(self,
|
|
273
|
+
ct: str,
|
|
274
|
+
motifs: Union[str, List[str]] = None,
|
|
275
|
+
dataset: Union[str, List[str]] = None,
|
|
276
|
+
k: int = 30,
|
|
277
|
+
min_support: float = 0.5,
|
|
278
|
+
max_dist: float = 500.0,
|
|
279
|
+
) -> pd.DataFrame:
|
|
280
|
+
"""
|
|
281
|
+
Perform motif enrichment analysis using k-nearest neighbors (KNN) in multiple fields of view.
|
|
282
|
+
|
|
283
|
+
Parameter
|
|
284
|
+
---------
|
|
285
|
+
ct:
|
|
286
|
+
The cell type of the center cell.
|
|
287
|
+
motifs:
|
|
288
|
+
Specified motifs to be tested.
|
|
289
|
+
If motifs=None, find the frequent patterns as motifs within
|
|
290
|
+
the neighborhood of center cell type in each fov.
|
|
291
|
+
dataset:
|
|
292
|
+
Datasets for searching for frequent patterns and performing enrichment analysis.
|
|
293
|
+
Use all datasets if dataset=None.
|
|
294
|
+
k:
|
|
295
|
+
Number of nearest neighbors to consider.
|
|
296
|
+
min_support:
|
|
297
|
+
Threshold of frequency to consider a pattern as a frequent pattern.
|
|
298
|
+
dis_duplicates:
|
|
299
|
+
Distinguish duplicates in patterns if dis_duplicates=True. This will consider transactions within duplicates
|
|
300
|
+
like (A, A, A, B, C) otherwise only patterns with unique cell types will be considered like (A, B, C).
|
|
301
|
+
max_dist:
|
|
302
|
+
Maximum distance for neighbors (default: 500).
|
|
303
|
+
|
|
304
|
+
Return
|
|
305
|
+
------
|
|
306
|
+
pd.Dataframe containing the cell type name, motifs, number of motifs nearby given cell type,
|
|
307
|
+
number of spots of cell type, number of motifs in single FOV, p value of hypergeometric distribution.
|
|
308
|
+
"""
|
|
309
|
+
if dataset is None:
|
|
310
|
+
dataset = [s.dataset.split('_')[0] for s in self.spatial_queries]
|
|
311
|
+
if isinstance(dataset, str):
|
|
312
|
+
dataset = [dataset]
|
|
313
|
+
|
|
314
|
+
max_dist = min(max_dist, self.max_radius)
|
|
315
|
+
|
|
316
|
+
out = []
|
|
317
|
+
if_exist_label = [ct in s.labels.unique() for s in self.spatial_queries]
|
|
318
|
+
if not any(if_exist_label):
|
|
319
|
+
raise ValueError(f"Found no {self.label_key} in any datasets!")
|
|
320
|
+
|
|
321
|
+
# Check whether specify motifs. If not, search frequent patterns among specified datasets
|
|
322
|
+
# and use them as interested motifs
|
|
323
|
+
all_labels = pd.concat([s.labels for s in self.spatial_queries])
|
|
324
|
+
labels_unique_all = set(all_labels.unique())
|
|
325
|
+
if motifs is None:
|
|
326
|
+
fp = self.find_fp_knn(ct=ct, k=k, dataset=dataset,
|
|
327
|
+
min_support=min_support, max_dist=max_dist)
|
|
328
|
+
motifs = fp['itemsets'].tolist()
|
|
329
|
+
else:
|
|
330
|
+
if isinstance(motifs, str):
|
|
331
|
+
motifs = [motifs]
|
|
332
|
+
|
|
333
|
+
motifs_exc = [m for m in motifs if m not in labels_unique_all]
|
|
334
|
+
if len(motifs_exc) != 0:
|
|
335
|
+
print(f"Found no {motifs_exc} in {dataset}. Ignoring them.")
|
|
336
|
+
motifs = [m for m in motifs if m not in motifs_exc]
|
|
337
|
+
if len(motifs) == 0:
|
|
338
|
+
raise ValueError(f"All cell types in motifs are missed in {self.label_key}.")
|
|
339
|
+
motifs = [motifs]
|
|
340
|
+
|
|
341
|
+
for motif in motifs:
|
|
342
|
+
n_labels = 0
|
|
343
|
+
n_ct = 0
|
|
344
|
+
n_motif_labels = 0
|
|
345
|
+
n_motif_ct = 0
|
|
346
|
+
|
|
347
|
+
motif = list(motif) if not isinstance(motif, list) else motif
|
|
348
|
+
sort_motif = sorted(motif)
|
|
349
|
+
|
|
350
|
+
# Calculate statistics of each dataset
|
|
351
|
+
for fov, s in enumerate(self.spatial_queries):
|
|
352
|
+
if s.dataset.split('_')[0] not in dataset:
|
|
353
|
+
continue
|
|
354
|
+
|
|
355
|
+
cell_pos = s.spatial_pos
|
|
356
|
+
labels = np.array(s.labels)
|
|
357
|
+
labels_unique = np.unique(labels)
|
|
358
|
+
|
|
359
|
+
contain_motif = [m in labels_unique for m in motif]
|
|
360
|
+
if not np.all(contain_motif):
|
|
361
|
+
n_labels += labels.shape[0]
|
|
362
|
+
n_ct += np.sum(labels == ct)
|
|
363
|
+
continue
|
|
364
|
+
else:
|
|
365
|
+
n_labels += labels.shape[0]
|
|
366
|
+
label_encoder = LabelEncoder()
|
|
367
|
+
int_labels = label_encoder.fit_transform(labels)
|
|
368
|
+
int_motifs = label_encoder.transform(np.array(motif))
|
|
369
|
+
|
|
370
|
+
dists, idxs = s.kd_tree.query(cell_pos, k=k + 1, workers=-1)
|
|
371
|
+
num_cells = idxs.shape[0]
|
|
372
|
+
num_types = len(label_encoder.classes_)
|
|
373
|
+
|
|
374
|
+
valid_neighbors = dists[:, 1:] <= max_dist
|
|
375
|
+
filtered_idxs = np.where(valid_neighbors, idxs[:, 1:], -1)
|
|
376
|
+
flat_neighbors = filtered_idxs.flatten()
|
|
377
|
+
valid_neighbors_flat = valid_neighbors.flatten()
|
|
378
|
+
neighbor_labels = np.where(valid_neighbors_flat, int_labels[flat_neighbors], -1)
|
|
379
|
+
valid_mask = neighbor_labels != -1
|
|
380
|
+
|
|
381
|
+
neighbor_matrix = np.zeros((num_cells * k, num_types), dtype=int)
|
|
382
|
+
neighbor_matrix[np.arange(len(neighbor_labels))[valid_mask], neighbor_labels[valid_mask]] = 1
|
|
383
|
+
neighbor_counts = neighbor_matrix.reshape(num_cells, k, num_types).sum(axis=1)
|
|
384
|
+
|
|
385
|
+
n_motif_labels += np.sum(np.all(neighbor_counts[:, int_motifs] > 0, axis=1))
|
|
386
|
+
|
|
387
|
+
if ct in np.unique(labels):
|
|
388
|
+
int_ct = label_encoder.transform(np.array(ct, dtype=object, ndmin=1))
|
|
389
|
+
mask = int_labels == int_ct
|
|
390
|
+
n_motif_ct += np.sum(np.all(neighbor_counts[mask][:, int_motifs] > 0, axis=1))
|
|
391
|
+
n_ct += np.sum(mask)
|
|
392
|
+
|
|
393
|
+
# for i in range(len(labels)):
|
|
394
|
+
# if spatial_query.has_motif(sort_motif, [labels[idx] for idx in idxs[i][1:]]):
|
|
395
|
+
# n_motif_labels += 1
|
|
396
|
+
# n_labels += len(labels)
|
|
397
|
+
|
|
398
|
+
# if ct not in labels.unique():
|
|
399
|
+
# continue
|
|
400
|
+
# cinds = [i for i, l in enumerate(labels) if l == ct]
|
|
401
|
+
#
|
|
402
|
+
# for i in cinds:
|
|
403
|
+
# inds = [ind for ind, d in enumerate(dists[i]) if d < max_dist]
|
|
404
|
+
# if len(inds) > 1:
|
|
405
|
+
# if spatial_query.has_motif(sort_motif, [labels[idx] for idx in idxs[i][inds[1:]]]):
|
|
406
|
+
# n_motif_ct += 1
|
|
407
|
+
|
|
408
|
+
# n_ct += len(cinds)
|
|
409
|
+
|
|
410
|
+
if ct in motif:
|
|
411
|
+
n_ct = round(n_ct / motif.count(ct))
|
|
412
|
+
|
|
413
|
+
hyge = hypergeom(M=n_labels, n=n_ct, N=n_motif_labels)
|
|
414
|
+
motif_out = {'center': ct, 'motifs': sort_motif, 'n_center_motif': n_motif_ct,
|
|
415
|
+
'n_center': n_ct, 'n_motif': n_motif_labels, 'expectation': hyge.mean(), 'p-values': hyge.sf(n_motif_ct)}
|
|
416
|
+
out.append(motif_out)
|
|
417
|
+
|
|
418
|
+
out_pd = pd.DataFrame(out)
|
|
419
|
+
|
|
420
|
+
if len(out_pd) == 1:
|
|
421
|
+
out_pd['if_significant'] = True if out_pd['p-values'][0] < 0.05 else False
|
|
422
|
+
return out_pd
|
|
423
|
+
else:
|
|
424
|
+
p_values = out_pd['p-values'].tolist()
|
|
425
|
+
if_rejected, corrected_p_values = mt.fdrcorrection(p_values,
|
|
426
|
+
alpha=0.05,
|
|
427
|
+
method='poscorr')
|
|
428
|
+
out_pd['corrected p-values'] = corrected_p_values
|
|
429
|
+
out_pd['if_significant'] = if_rejected
|
|
430
|
+
out_pd = out_pd.sort_values(by='corrected p-values', ignore_index=True)
|
|
431
|
+
return out_pd
|
|
432
|
+
|
|
433
|
+
def motif_enrichment_dist(self,
|
|
434
|
+
ct: str,
|
|
435
|
+
motifs: Union[str, List[str]] = None,
|
|
436
|
+
dataset: Union[str, List[str]] = None,
|
|
437
|
+
max_dist: float = 100,
|
|
438
|
+
min_size: int = 0,
|
|
439
|
+
min_support: float = 0.5,
|
|
440
|
+
max_ns: int = 100) -> DataFrame:
|
|
441
|
+
"""
|
|
442
|
+
Perform motif enrichment analysis within a specified radius-based neighborhood in multiple fields of view.
|
|
443
|
+
|
|
444
|
+
Parameter
|
|
445
|
+
---------
|
|
446
|
+
ct:
|
|
447
|
+
Cell type of the center cell.
|
|
448
|
+
motifs:
|
|
449
|
+
Specified motifs to be tested.
|
|
450
|
+
If motifs=None, find the frequent patterns as motifs within the neighborhood of center cell type.
|
|
451
|
+
dataset:
|
|
452
|
+
Datasets for searching for frequent patterns and performing enrichment analysis.
|
|
453
|
+
Use all datasets if dataset=None.
|
|
454
|
+
max_dist:
|
|
455
|
+
Maximum distance for considering a cell as a neighbor.
|
|
456
|
+
min_size:
|
|
457
|
+
Minimum neighborhood size for each point to consider.
|
|
458
|
+
min_support:
|
|
459
|
+
Threshold of frequency to consider a pattern as a frequent pattern.
|
|
460
|
+
dis_duplicates:
|
|
461
|
+
Distinguish duplicates in patterns if dis_duplicates=True. This will consider transactions within duplicates
|
|
462
|
+
like (A, A, A, B, C) otherwise only patterns with unique cell types will be considered like (A, B, C).
|
|
463
|
+
max_ns:
|
|
464
|
+
Maximum number of neighborhood size for each point.
|
|
465
|
+
Returns
|
|
466
|
+
-------
|
|
467
|
+
Tuple containing counts and statistical measures.
|
|
468
|
+
"""
|
|
469
|
+
if dataset is None:
|
|
470
|
+
dataset = [s.dataset.split('_')[0] for s in self.spatial_queries]
|
|
471
|
+
if isinstance(dataset, str):
|
|
472
|
+
dataset = [dataset]
|
|
473
|
+
|
|
474
|
+
out = []
|
|
475
|
+
if_exist_label = [ct in s.labels.unique() for s in self.spatial_queries]
|
|
476
|
+
if not any(if_exist_label):
|
|
477
|
+
raise ValueError(f"Found no {self.label_key} in any datasets!")
|
|
478
|
+
|
|
479
|
+
max_dist = min(max_dist, self.max_radius)
|
|
480
|
+
|
|
481
|
+
# Check whether specify motifs. If not, search frequent patterns among specified datasets
|
|
482
|
+
# and use them as interested motifs
|
|
483
|
+
if motifs is None:
|
|
484
|
+
fp = self.find_fp_dist(ct=ct, dataset=dataset, max_dist=max_dist, min_size=min_size,
|
|
485
|
+
min_support=min_support, max_ns=max_ns)
|
|
486
|
+
motifs = fp['itemsets'].tolist()
|
|
487
|
+
else:
|
|
488
|
+
if isinstance(motifs, str):
|
|
489
|
+
motifs = [motifs]
|
|
490
|
+
|
|
491
|
+
all_labels = pd.concat([s.labels for s in self.spatial_queries])
|
|
492
|
+
labels_unique_all = set(all_labels.unique())
|
|
493
|
+
motifs_exc = [m for m in motifs if m not in labels_unique_all]
|
|
494
|
+
if len(motifs_exc) != 0:
|
|
495
|
+
print(f"Found no {motifs_exc} in {dataset}! Ignoring them.")
|
|
496
|
+
motifs = [m for m in motifs if m not in motifs_exc]
|
|
497
|
+
if len(motifs) == 0:
|
|
498
|
+
raise ValueError(f"All cell types in motifs are missed in {self.label_key}.")
|
|
499
|
+
motifs = [motifs]
|
|
500
|
+
|
|
501
|
+
for motif in motifs:
|
|
502
|
+
n_labels = 0
|
|
503
|
+
n_ct = 0
|
|
504
|
+
n_motif_labels = 0
|
|
505
|
+
n_motif_ct = 0
|
|
506
|
+
|
|
507
|
+
motif = list(motif) if not isinstance(motif, list) else motif
|
|
508
|
+
sort_motif = sorted(motif)
|
|
509
|
+
|
|
510
|
+
for s in self.spatial_queries:
|
|
511
|
+
if s.dataset.split('_')[0] not in dataset:
|
|
512
|
+
continue
|
|
513
|
+
cell_pos = s.spatial_pos
|
|
514
|
+
labels = np.array(s.labels)
|
|
515
|
+
labels_unique = np.unique(labels)
|
|
516
|
+
|
|
517
|
+
contain_motif = [m in labels_unique for m in motif]
|
|
518
|
+
if not np.all(contain_motif):
|
|
519
|
+
n_labels += labels.shape[0]
|
|
520
|
+
n_ct += np.sum(labels == ct)
|
|
521
|
+
continue
|
|
522
|
+
else:
|
|
523
|
+
n_labels += labels.shape[0]
|
|
524
|
+
_, matching_cells_indices = s._query_pattern(motif)
|
|
525
|
+
if not matching_cells_indices:
|
|
526
|
+
# if matching_cells_indices is empty, it indicates no motif are grouped together within upper limit of radius (500)
|
|
527
|
+
continue
|
|
528
|
+
matching_cells_indices = np.concatenate([t for t in matching_cells_indices.values()])
|
|
529
|
+
matching_cells_indices = np.unique(matching_cells_indices)
|
|
530
|
+
matching_cells_indices.sort()
|
|
531
|
+
# print(f"number of cells skipped: {len(matching_cells_indices)}")
|
|
532
|
+
# print(f"proportion of cells searched: {len(matching_cells_indices) / len(s.spatial_pos)}")
|
|
533
|
+
idxs_in_grids = s.kd_tree.query_ball_point(
|
|
534
|
+
s.spatial_pos[matching_cells_indices],
|
|
535
|
+
r=max_dist,
|
|
536
|
+
return_sorted=False,
|
|
537
|
+
workers=-1
|
|
538
|
+
)
|
|
539
|
+
|
|
540
|
+
# using numppy
|
|
541
|
+
label_encoder = LabelEncoder()
|
|
542
|
+
int_labels = label_encoder.fit_transform(labels)
|
|
543
|
+
int_motifs = label_encoder.transform(np.array(motif))
|
|
544
|
+
|
|
545
|
+
num_cells = len(s.spatial_pos)
|
|
546
|
+
num_types = len(label_encoder.classes_)
|
|
547
|
+
# filter center out of neighbors
|
|
548
|
+
idxs_filter = [np.array(ids)[np.array(ids) != i][:min(max_ns, len(ids))] for i, ids in
|
|
549
|
+
zip(matching_cells_indices, idxs_in_grids)]
|
|
550
|
+
|
|
551
|
+
num_matching_cells = len(matching_cells_indices)
|
|
552
|
+
flat_neighbors = np.concatenate(idxs_filter)
|
|
553
|
+
row_indices = np.repeat(np.arange(num_matching_cells), [len(neigh) for neigh in idxs_filter])
|
|
554
|
+
neighbor_labels = int_labels[flat_neighbors]
|
|
555
|
+
|
|
556
|
+
neighbor_matrix = np.zeros((num_matching_cells, num_types), dtype=int)
|
|
557
|
+
np.add.at(neighbor_matrix, (row_indices, neighbor_labels), 1)
|
|
558
|
+
|
|
559
|
+
n_motif_labels += np.sum(np.all(neighbor_matrix[:, int_motifs] > 0, axis=1))
|
|
560
|
+
|
|
561
|
+
if ct in np.unique(labels):
|
|
562
|
+
int_ct = label_encoder.transform(np.array(ct, dtype=object, ndmin=1))
|
|
563
|
+
mask = int_labels[matching_cells_indices] == int_ct
|
|
564
|
+
n_motif_ct += np.sum(np.all(neighbor_matrix[mask][:, int_motifs] > 0, axis=1))
|
|
565
|
+
n_ct += np.sum(s.labels == ct)
|
|
566
|
+
|
|
567
|
+
# ~10s using C++ codes
|
|
568
|
+
# idxs = idxs.tolist()
|
|
569
|
+
# cinds = [i for i, label in enumerate(labels) if label == ct]
|
|
570
|
+
# n_motif_ct_s, n_motif_labels_s = spatial_module_utils.search_motif_dist(
|
|
571
|
+
# motif, idxs, labels, cinds, max_ns
|
|
572
|
+
# )
|
|
573
|
+
# n_motif_ct += n_motif_ct_s
|
|
574
|
+
# n_motif_labels += n_motif_labels_s
|
|
575
|
+
# n_ct += len(cinds)
|
|
576
|
+
|
|
577
|
+
# original codes, ~ minutes
|
|
578
|
+
# for i in range(len(idxs)):
|
|
579
|
+
# e = min(len(idxs[i]), max_ns)
|
|
580
|
+
# if spatial_query.has_motif(sort_motif, [labels[idx] for idx in idxs[i][:e] if idx != i]):
|
|
581
|
+
# n_motif_labels += 1
|
|
582
|
+
#
|
|
583
|
+
# if ct not in labels.unique():
|
|
584
|
+
# continue
|
|
585
|
+
#
|
|
586
|
+
# cinds = [i for i, label in enumerate(labels) if label == ct]
|
|
587
|
+
#
|
|
588
|
+
# for i in cinds:
|
|
589
|
+
# e = min(len(idxs[i]), max_ns)
|
|
590
|
+
# if spatial_query.has_motif(sort_motif, [labels[idx] for idx in idxs[i][:e] if idx != i]):
|
|
591
|
+
# n_motif_ct += 1
|
|
592
|
+
#
|
|
593
|
+
# n_ct += len(cinds)
|
|
594
|
+
|
|
595
|
+
if ct in motif:
|
|
596
|
+
n_ct = round(n_ct / motif.count(ct))
|
|
597
|
+
hyge = hypergeom(M=n_labels, n=n_ct, N=n_motif_labels)
|
|
598
|
+
motif_out = {'center': ct, 'motifs': sort_motif, 'n_center_motif': n_motif_ct,
|
|
599
|
+
'n_center': n_ct, 'n_motif': n_motif_labels, 'expectation': hyge.mean(), 'p-values': hyge.sf(n_motif_ct)}
|
|
600
|
+
out.append(motif_out)
|
|
601
|
+
|
|
602
|
+
out_pd = pd.DataFrame(out)
|
|
603
|
+
|
|
604
|
+
if len(out_pd) == 1:
|
|
605
|
+
out_pd['if_significant'] = True if out_pd['p-values'][0] < 0.05 else False
|
|
606
|
+
return out_pd
|
|
607
|
+
else:
|
|
608
|
+
p_values = out_pd['p-values'].tolist()
|
|
609
|
+
if_rejected, corrected_p_values = mt.fdrcorrection(p_values,
|
|
610
|
+
alpha=0.05,
|
|
611
|
+
method='poscorr')
|
|
612
|
+
out_pd['corrected p-values'] = corrected_p_values
|
|
613
|
+
out_pd['if_significant'] = if_rejected
|
|
614
|
+
out_pd = out_pd.sort_values(by='corrected p-values', ignore_index=True)
|
|
615
|
+
return out_pd
|
|
616
|
+
|
|
617
|
+
def find_fp_knn_fov(self,
|
|
618
|
+
ct: str,
|
|
619
|
+
dataset_i: str,
|
|
620
|
+
k: int = 30,
|
|
621
|
+
min_support: float = 0.5,
|
|
622
|
+
max_dist: float = 500.0
|
|
623
|
+
) -> pd.DataFrame:
|
|
624
|
+
"""
|
|
625
|
+
Find frequent patterns within the KNNs of specific cell type of interest in single field of view.
|
|
626
|
+
|
|
627
|
+
Parameter
|
|
628
|
+
---------
|
|
629
|
+
ct:
|
|
630
|
+
Cell type name.
|
|
631
|
+
dataset_i:
|
|
632
|
+
Datasets for searching for frequent patterns in dataset_i format.
|
|
633
|
+
k:
|
|
634
|
+
Number of nearest neighbors.
|
|
635
|
+
min_support:
|
|
636
|
+
Threshold of frequency to consider a pattern as a frequent pattern.
|
|
637
|
+
max_dist:
|
|
638
|
+
Maximum distance for considering a cell as a neighbor.
|
|
639
|
+
|
|
640
|
+
Return
|
|
641
|
+
------
|
|
642
|
+
Frequent patterns in the neighborhood of certain cell type.
|
|
643
|
+
"""
|
|
644
|
+
if dataset_i not in self.datasets:
|
|
645
|
+
raise ValueError(f"Found no {dataset_i.split('_')[0]} in any datasets.")
|
|
646
|
+
|
|
647
|
+
max_dist = min(max_dist, self.max_radius)
|
|
648
|
+
|
|
649
|
+
sp_object = self.spatial_queries[self.datasets.index(dataset_i)]
|
|
650
|
+
cell_pos = sp_object.spatial_pos
|
|
651
|
+
labels = np.array(sp_object.labels)
|
|
652
|
+
if ct not in np.unique(labels):
|
|
653
|
+
return pd.DataFrame(columns=['support', 'itemsets'])
|
|
654
|
+
|
|
655
|
+
ct_pos = cell_pos[labels == ct]
|
|
656
|
+
|
|
657
|
+
# Identify frequent patterns of cell types, including those subsets of patterns
|
|
658
|
+
# whose support value exceeds min_support. Focus solely on the multiplicity
|
|
659
|
+
# of cell types, rather than their frequency.
|
|
660
|
+
fp, _, _ = sp_object.build_fptree_knn(
|
|
661
|
+
cell_pos=ct_pos,
|
|
662
|
+
k=k,
|
|
663
|
+
min_support=min_support,
|
|
664
|
+
if_max=False,
|
|
665
|
+
max_dist=max_dist,
|
|
666
|
+
)
|
|
667
|
+
return fp
|
|
668
|
+
|
|
669
|
+
def find_fp_dist_fov(self,
|
|
670
|
+
ct: str,
|
|
671
|
+
dataset_i: str,
|
|
672
|
+
max_dist: float = 100,
|
|
673
|
+
min_size: int = 0,
|
|
674
|
+
min_support: float = 0.5,
|
|
675
|
+
max_ns: int = 100
|
|
676
|
+
):
|
|
677
|
+
"""
|
|
678
|
+
Find frequent patterns within the radius-based neighborhood of specific cell type of interest
|
|
679
|
+
in single field of view.
|
|
680
|
+
|
|
681
|
+
Parameter
|
|
682
|
+
---------
|
|
683
|
+
ct:
|
|
684
|
+
Cell type name.
|
|
685
|
+
dataset_i:
|
|
686
|
+
Datasets for searching for frequent patterns in dataset_i format.
|
|
687
|
+
max_dist:
|
|
688
|
+
Maximum distance for considering a cell as a neighbor.
|
|
689
|
+
min_size:
|
|
690
|
+
Minimum neighborhood size for each point to consider.
|
|
691
|
+
min_support:
|
|
692
|
+
Threshold of frequency to consider a pattern as a frequent pattern.
|
|
693
|
+
max_ns:
|
|
694
|
+
Maximum number of neighborhood size for each point.
|
|
695
|
+
|
|
696
|
+
Return
|
|
697
|
+
------
|
|
698
|
+
Frequent patterns in the neighborhood of certain cell type.
|
|
699
|
+
"""
|
|
700
|
+
if dataset_i not in self.datasets:
|
|
701
|
+
raise ValueError(f"Found no {dataset_i.split('_')[0]} in any datasets.")
|
|
702
|
+
|
|
703
|
+
max_dist = min(max_dist, self.max_radius)
|
|
704
|
+
|
|
705
|
+
sp_object = self.spatial_queries[self.datasets.index(dataset_i)]
|
|
706
|
+
cell_pos = sp_object.spatial_pos
|
|
707
|
+
labels = sp_object.labels
|
|
708
|
+
if ct not in labels.unique():
|
|
709
|
+
return pd.DataFrame(columns=['support, itemsets'])
|
|
710
|
+
|
|
711
|
+
cinds = [id for id, l in enumerate(labels) if l == ct]
|
|
712
|
+
ct_pos = cell_pos[cinds]
|
|
713
|
+
|
|
714
|
+
fp, _, _ = sp_object.build_fptree_dist(cell_pos=ct_pos,
|
|
715
|
+
max_dist=max_dist,
|
|
716
|
+
min_support=min_support,
|
|
717
|
+
min_size=min_size,
|
|
718
|
+
if_max=False,
|
|
719
|
+
cinds=cinds,
|
|
720
|
+
max_ns=max_ns,
|
|
721
|
+
)
|
|
722
|
+
return fp
|
|
723
|
+
|
|
724
|
+
def differential_analysis_knn(self,
                              ct: str,
                              datasets: List[str],
                              k: int = 30,
                              min_support: float = 0.5,
                              max_dist: float = 500,
                              ):
    """
    Explore the differences in cell types and frequent patterns of cell types in spatial KNN neighborhood of cell
    type of interest. Perform differential analysis of frequent patterns in specified datasets.

    Frequent patterns are searched in every field of view (FOV) belonging to each dataset. Within a dataset only
    the patterns shared by all FOVs that returned at least one pattern are kept. For every pattern, the per-FOV
    supports of the two datasets are compared with a two-sided Mann-Whitney U test, and the p-values are corrected
    with the Benjamini-Hochberg procedure (alpha=0.05).

    Parameter
    ---------
    ct:
        Cell type of interest as center point.
    datasets:
        Dataset names used to perform differential analysis. Exactly two names are required.
    k:
        Number of nearest neighbors.
    min_support:
        Threshold of frequency to consider a pattern as a frequent pattern.
    max_dist:
        Maximum distance for considering a cell as a neighbor. Capped at `self.max_radius`.

    Return
    ------
    Dictionary keyed by the two dataset names. Each value is a dataframe with the columns `itemsets` and
    `corrected_p_values` holding the significant patterns whose support ranks higher in that dataset,
    sorted by corrected p-value.

    Raise
    -----
    ValueError
        If `datasets` does not hold exactly two names, if a name is unknown, or if no frequent pattern is
        found in any FOV of one of the datasets.
    """
    if len(datasets) != 2:
        raise ValueError("Require 2 datasets for differential analysis.")

    max_dist = min(max_dist, self.max_radius)

    # Check if the two datasets are valid
    valid_ds_names = [s.dataset.split('_')[0] for s in self.spatial_queries]
    for ds in datasets:
        if ds not in valid_ds_names:
            raise ValueError(f"Invalid input dataset name: {ds}.\n"
                             f"Valid dataset names are: {set(valid_ds_names)}")

    def _normalize(itemset):
        # Order-independent, hashable key so the same pattern always merges onto the same row.
        return tuple(sorted(itemset))

    fp_datasets = None
    # Identify frequent patterns in each dataset
    for d in datasets:
        fp_d = {}
        for d_i in [ds for ds in self.datasets if ds.split('_')[0] == d]:
            fp_fov = self.find_fp_knn_fov(ct=ct,
                                          dataset_i=d_i,
                                          k=k,
                                          min_support=min_support,
                                          max_dist=max_dist)
            if len(fp_fov) > 0:
                fp_d[d_i] = fp_fov

        if not fp_d:
            # Without this guard set.intersection() below fails with an opaque TypeError.
            raise ValueError(f"Found no frequent patterns in the neighborhood of {ct} in dataset {d}.")

        if len(fp_d) == 1:
            d_i, fp_fov = next(iter(fp_d.items()))
            common_patterns = fp_fov.rename(columns={'support': f"support_{d_i}"})
            # Normalize exactly like the multi-FOV branch, otherwise the merge may miss matches.
            common_patterns['itemsets'] = common_patterns['itemsets'].apply(_normalize)
        else:
            support_by_fov = {
                data_name: dict(zip(df['itemsets'].apply(_normalize), df['support']))
                for data_name, df in fp_d.items()
            }
            # Keep only the patterns that are frequent in every FOV of this dataset.
            comm_fps = list(set.intersection(*[set(supports) for supports in support_by_fov.values()]))
            common_patterns = pd.DataFrame({'itemsets': pd.Series(comm_fps, dtype=object)})
            for data_name, supports in support_by_fov.items():
                common_patterns[f"support_{data_name}"] = [supports[items] for items in comm_fps]

        if fp_datasets is None:
            fp_datasets = common_patterns
        else:
            # A pattern missing from one dataset gets support 0 in all of its FOVs.
            fp_datasets = fp_datasets.merge(common_patterns, how='outer', on='itemsets').fillna(0)

    result_columns = ['itemsets', 'corrected_p_values']
    if fp_datasets.empty:
        return {name: pd.DataFrame(columns=result_columns) for name in datasets}

    # Match support columns on the exact dataset name so that e.g. "A" does not capture "AB_0".
    prefix = 'support_'
    match_ind_datasets = [
        [col for col in fp_datasets.columns
         if col.startswith(prefix) and col[len(prefix):].split('_')[0] == dataset]
        for dataset in datasets]

    p_values = []
    dataset_higher_ranks = []
    for _, row in fp_datasets.iterrows():
        group1 = pd.to_numeric(row[match_ind_datasets[0]].values)
        group2 = pd.to_numeric(row[match_ind_datasets[1]].values)

        # Perform the Mann-Whitney U test
        _, p = stats.mannwhitneyu(group1, group2, alternative='two-sided', method='auto')
        p_values.append(p)

        # Label the dataset with higher frequency of patterns based on rank median (ties go to datasets[1])
        support_rank = stats.rankdata(np.concatenate([group1, group2]))  # ascending, average ranks
        median_rank1 = np.median(support_rank[:len(group1)])
        median_rank2 = np.median(support_rank[len(group1):])
        if median_rank1 > median_rank2:
            dataset_higher_ranks.append(datasets[0])
        else:
            dataset_higher_ranks.append(datasets[1])

    fp_datasets['dataset_higher_frequency'] = dataset_higher_ranks
    # Apply Benjamini-Hochberg correction for multiple testing problems
    if_rejected, corrected_p_values = mt.fdrcorrection(p_values,
                                                       alpha=0.05,
                                                       method='poscorr')
    fp_datasets['corrected_p_values'] = corrected_p_values
    fp_datasets['if_significant'] = if_rejected

    # Return the significant patterns in each dataset; sort first, then renumber the index.
    results = {}
    for name in datasets:
        mask = (fp_datasets['dataset_higher_frequency'] == name) & (fp_datasets['if_significant'])
        significant = fp_datasets[mask][result_columns]
        significant = significant.sort_values(by='corrected_p_values', ascending=True)
        results[name] = significant.reset_index(drop=True)
    return results
|
|
846
|
+
|
|
847
|
+
def differential_analysis_dist(self,
                               ct: str,
                               datasets: List[str],
                               max_dist: float = 100,
                               min_support: float = 0.5,
                               min_size: int = 0,
                               max_ns: int = 100,
                               ):
    """
    Explore the differences in cell types and frequent patterns of cell types in spatial radius-based neighborhood
    of cell type of interest. Perform differential analysis of frequent patterns in specified datasets.

    Frequent patterns are searched in every field of view (FOV) belonging to each dataset. Within a dataset only
    the patterns shared by all FOVs that returned at least one pattern are kept. For every pattern, the per-FOV
    supports of the two datasets are compared with a two-sided Mann-Whitney U test, and the p-values are corrected
    with the Benjamini-Hochberg procedure (alpha=0.05).

    Parameter
    ---------
    ct:
        Cell type of interest as center point.
    datasets:
        Dataset names used to perform differential analysis. Exactly two names are required.
    max_dist:
        Maximum distance for considering a cell as a neighbor. Capped at `self.max_radius`.
    min_support:
        Threshold of frequency to consider a pattern as a frequent pattern.
    min_size:
        Minimum neighborhood size for each point to consider.
    max_ns:
        Upper limit of neighbors for each point.

    Return
    ------
    Dictionary keyed by the two dataset names. Each value is a dataframe with the columns `itemsets` and
    `corrected_p_values` holding the significant patterns whose support ranks higher in that dataset,
    sorted by corrected p-value.

    Raise
    -----
    ValueError
        If `datasets` does not hold exactly two names, if a name is unknown, or if no frequent pattern is
        found in any FOV of one of the datasets.
    """
    if len(datasets) != 2:
        raise ValueError("Require 2 datasets for differential analysis.")

    # Check if the two datasets are valid
    valid_ds_names = [s.dataset.split('_')[0] for s in self.spatial_queries]
    for ds in datasets:
        if ds not in valid_ds_names:
            raise ValueError(f"Invalid input dataset name: {ds}.\n"
                             f"Valid dataset names are: {set(valid_ds_names)}")

    max_dist = min(max_dist, self.max_radius)

    def _normalize(itemset):
        # Order-independent, hashable key so the same pattern always merges onto the same row.
        return tuple(sorted(itemset))

    fp_datasets = None
    # Identify frequent patterns in each dataset
    for d in datasets:
        fp_d = {}
        for d_i in [ds for ds in self.datasets if ds.split('_')[0] == d]:
            fp_fov = self.find_fp_dist_fov(ct=ct,
                                           dataset_i=d_i,
                                           max_dist=max_dist,
                                           min_size=min_size,
                                           min_support=min_support,
                                           max_ns=max_ns)
            if len(fp_fov) > 0:
                fp_d[d_i] = fp_fov

        if not fp_d:
            # Without this guard set.intersection() below fails with an opaque TypeError.
            raise ValueError(f"Found no frequent patterns in the neighborhood of {ct} in dataset {d}.")

        if len(fp_d) == 1:
            d_i, fp_fov = next(iter(fp_d.items()))
            common_patterns = fp_fov.rename(columns={'support': f"support_{d_i}"})
            # Normalize exactly like the multi-FOV branch, otherwise the merge may miss matches.
            common_patterns['itemsets'] = common_patterns['itemsets'].apply(_normalize)
        else:
            support_by_fov = {
                data_name: dict(zip(df['itemsets'].apply(_normalize), df['support']))
                for data_name, df in fp_d.items()
            }
            # Keep only the patterns that are frequent in every FOV of this dataset.
            comm_fps = list(set.intersection(*[set(supports) for supports in support_by_fov.values()]))
            common_patterns = pd.DataFrame({'itemsets': pd.Series(comm_fps, dtype=object)})
            for data_name, supports in support_by_fov.items():
                common_patterns[f"support_{data_name}"] = [supports[items] for items in comm_fps]

        if fp_datasets is None:
            fp_datasets = common_patterns
        else:
            # A pattern missing from one dataset gets support 0 in all of its FOVs.
            fp_datasets = fp_datasets.merge(common_patterns, how='outer', on='itemsets').fillna(0)

    result_columns = ['itemsets', 'corrected_p_values']
    if fp_datasets.empty:
        return {name: pd.DataFrame(columns=result_columns) for name in datasets}

    # Match support columns on the exact dataset name so that e.g. "A" does not capture "AB_0".
    prefix = 'support_'
    match_ind_datasets = [
        [col for col in fp_datasets.columns
         if col.startswith(prefix) and col[len(prefix):].split('_')[0] == dataset]
        for dataset in datasets]

    p_values = []
    dataset_higher_ranks = []
    for _, row in fp_datasets.iterrows():
        group1 = pd.to_numeric(row[match_ind_datasets[0]].values)
        group2 = pd.to_numeric(row[match_ind_datasets[1]].values)

        # Perform the Mann-Whitney U test
        _, p = stats.mannwhitneyu(group1, group2, alternative='two-sided', method='auto')
        p_values.append(p)

        # Label the dataset with higher frequency of patterns based on rank median (ties go to datasets[1])
        support_rank = stats.rankdata(np.concatenate([group1, group2]))  # ascending, average ranks
        median_rank1 = np.median(support_rank[:len(group1)])
        median_rank2 = np.median(support_rank[len(group1):])
        if median_rank1 > median_rank2:
            dataset_higher_ranks.append(datasets[0])
        else:
            dataset_higher_ranks.append(datasets[1])

    fp_datasets['dataset_higher_frequency'] = dataset_higher_ranks
    # Apply Benjamini-Hochberg correction for multiple testing problems
    if_rejected, corrected_p_values = mt.fdrcorrection(p_values,
                                                       alpha=0.05,
                                                       method='poscorr')
    fp_datasets['corrected_p_values'] = corrected_p_values
    fp_datasets['if_significant'] = if_rejected

    # Return the significant patterns in each dataset, most significant first.
    results = {}
    for name in datasets:
        mask = (fp_datasets['dataset_higher_frequency'] == name) & (fp_datasets['if_significant'])
        significant = fp_datasets[mask][result_columns]
        significant = significant.sort_values(by='corrected_p_values', ascending=True)
        results[name] = significant.reset_index(drop=True)
    return results
|
|
969
|
+
|
|
970
|
+
def cell_type_distribution(self,
                           dataset: Union[str, List[str]] = None,
                           data_type: str = 'number',
                           ):
    """
    Visualize the distribution of cell types across datasets using a stacked bar plot.

    Cells are counted per cell type in every field of view, the counts are summed per dataset
    (the part of the FOV name before the first underscore), and one stacked bar per cell type is drawn.

    Parameter
    ---------
    dataset:
        Datasets for searching. A single name, a list of names, or None to use every dataset.
    data_type:
        Plot bar plot by number of cells or by the proportions of datasets in each cell type.
        Default is 'number' otherwise 'proportion' is used.

    Returns
    -------
    Stacked bar plot

    Raise
    -----
    ValueError
        If `data_type` is neither 'number' nor 'proportion', or if a dataset name is unknown.
    """
    if data_type not in ['number', 'proportion']:
        raise ValueError("Invalild data_type. It should be one of 'number' or 'proportion'.")

    valid_ds_names = [s.dataset.split('_')[0] for s in self.spatial_queries]
    if dataset is None:
        dataset = valid_ds_names
    if isinstance(dataset, str):
        dataset = [dataset]

    for ds in dataset:
        if ds not in valid_ds_names:
            raise ValueError(f"Invalid input dataset name: {ds}.\n "
                             f"Valid dataset names are: {set(valid_ds_names)}")

    valid_queries = [s for s in self.spatial_queries if s.dataset.split('_')[0] in dataset]
    cell_types = {ct for s in valid_queries for ct in s.labels.unique()}

    # Count every cell type in every FOV; types absent from a FOV are recorded as 0.
    summary = defaultdict(lambda: defaultdict(int))
    for s in valid_queries:
        for cell_type in cell_types:
            summary[s.dataset][cell_type] += np.sum(s.labels == cell_type)

    df = pd.DataFrame([(fov_name, cell_type, count)
                       for fov_name, fov_counts in summary.items()
                       for cell_type, count in fov_counts.items()],
                      columns=['Dataset', 'Cell Type', 'Count'])

    # Collapse the FOVs of one dataset into a single column.
    df['dataset'] = df['Dataset'].str.split('_').str[0]
    totals = df.groupby(['dataset', 'Cell Type'])['Count'].sum().reset_index()
    plot_data = totals.pivot(index='Cell Type', columns='dataset', values='Count').fillna(0)

    # Sort the cell types by total count over all datasets (descending)
    order = plot_data.sum(axis=1).sort_values(ascending=False, kind='stable').index
    plot_data = plot_data.loc[order]

    if data_type != 'number':
        # Share of each dataset within a cell type, so every bar sums to 1.
        plot_data = plot_data.div(plot_data.sum(axis=1), axis=0)

    # Create the stacked bar plot
    plot_data.plot(kind='bar', stacked=True,
                   figsize=(plot_data.shape[0], plot_data.shape[0] * 0.6),
                   edgecolor='black')

    # Customize the plot
    plt.title("Distribution of Cell Types Across Datasets", fontsize=16)
    plt.xlabel('Cell Types', fontsize=12)
    if data_type == 'number':
        plt.ylabel('Number of Cells', fontsize=12)
    else:
        plt.ylabel('Proportion of Cells', fontsize=12)

    plt.xticks(rotation=90, ha='right', fontsize=10)

    plt.legend(title='Datasets', loc='center left', bbox_to_anchor=(1, 0.5), fontsize=10)

    # Leave room on the right for the legend placed outside the axes.
    plt.tight_layout(rect=[0, 0, 0.85, 1])
    plt.show()
|
|
1045
|
+
|
|
1046
|
+
def cell_type_distribution_fov(self,
                               dataset: str,
                               data_type: str = 'number',
                               ):
    """
    Draw a stacked bar plot of the cell type composition of every FOV in one dataset.

    Parameter
    ---------
    dataset:
        Name of the dataset whose FOVs are plotted.
    data_type:
        'number' (default) plots cell counts; 'proportion' plots the share of each cell type within a FOV.

    Returns
    -------
    Stacked bar plot
    """
    if data_type != 'number' and data_type != 'proportion':
        raise ValueError("Invalild data_type. It should be one of 'number' or 'proportion'.")

    known_names = [query.dataset.split('_')[0] for query in self.spatial_queries]
    if dataset not in known_names:
        raise ValueError(f"Invalid input dataset name: {dataset}. \n"
                         f"Valid dataset names are: {set(known_names)}")

    fov_queries = [query for query in self.spatial_queries if query.dataset.split('_')[0] == dataset]
    all_types = {label for query in fov_queries for label in query.labels.unique()}

    # One (FOV name, cell type, count) record per FOV and cell type; absent types get a count of 0.
    counts = defaultdict(lambda: defaultdict(int))
    for query in fov_queries:
        per_fov = counts[query.dataset]
        for label in all_types:
            per_fov[label] = np.sum(query.labels == label)

    records = []
    for fov_name, type_counts in counts.items():
        for label, n_cells in type_counts.items():
            records.append((fov_name, label, n_cells))
    df = pd.DataFrame(records, columns=['Dataset', 'Cell Type', 'Count'])

    # The FOV identifier is the second underscore-separated token of the stored name.
    df['FOV'] = df['Dataset'].str.split('_').str[1]

    grouped = df.groupby(['FOV', 'Cell Type'])['Count'].sum().reset_index()
    plot_data = grouped.pivot(columns='Cell Type', index='FOV', values='Count').fillna(0)

    # Order the FOVs by their total number of cells (descending)
    fov_order = plot_data.sum(axis=1).sort_values(ascending=False).index
    plot_data_sorted = plot_data.loc[fov_order]

    if data_type != 'number':
        fov_totals = plot_data_sorted.sum(axis=1)
        plot_data_sorted = plot_data_sorted.div(fov_totals, axis=0)

    # Create the stacked bar plot
    n_fovs = plot_data.shape[0]
    ax = plot_data_sorted.plot(kind='bar', stacked=True,
                               figsize=(n_fovs * 0.6, n_fovs * 0.3),
                               edgecolor='black')

    # Customize the plot
    y_label = 'Number of Cells' if data_type == 'number' else 'Proportion of Cells'
    plt.title(f"Distribution of FOVs in {dataset} dataset", fontsize=20)
    plt.xlabel('FOV', fontsize=12)
    plt.ylabel(y_label, fontsize=12)

    plt.xticks(rotation=90, ha='right', fontsize=10)

    plt.legend(title='Cell Type', bbox_to_anchor=(1, 1.05), loc='center left', fontsize=12)

    plt.tight_layout()
    plt.show()
|
|
1114
|
+
|
|
1115
|
+
|
|
1116
|
+
|