SpatialQuery 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1339 @@
1
+ from collections import Counter
2
+ from itertools import combinations
3
+ from typing import List, Union
4
+
5
+ import matplotlib.pyplot as plt
6
+ import statsmodels.stats.multitest as mt
7
+ import numpy as np
8
+ import pandas as pd
9
+ import seaborn as sns
10
+ from anndata import AnnData
11
+ from mlxtend.frequent_patterns import fpgrowth
12
+ from sklearn.preprocessing import MultiLabelBinarizer
13
+ from pandas import DataFrame
14
+ from scipy.spatial import KDTree
15
+ from scipy.stats import hypergeom
16
+ from sklearn.preprocessing import LabelEncoder
17
+ import time
18
+
19
+ class spatial_query:
20
def __init__(self,
             adata: AnnData,
             dataset: str = 'ST',
             spatial_key: str = 'X_spatial',
             label_key: str = 'predicted_label',
             leaf_size: int = 10,
             max_radius: float = 500,
             n_split: int = 10
             ):
    """
    Index the spatial coordinates and labels of an AnnData object.

    Parameter
    ---------
    adata:
        Object providing `obsm[spatial_key]` (coordinates) and `obs[label_key]` (labels).
    dataset:
        Stored unchanged on the instance as `self.dataset`.
    spatial_key:
        Key of the coordinates in `adata.obsm`.
    label_key:
        Key of the labels in `adata.obs`.
    leaf_size:
        Passed to the KD-tree as `leafsize`.
    max_radius:
        Upper bound for query radii; also used as the overlap margin of the grid tiles.
    n_split:
        Number of tiles per axis built by `_initialize_grids`.

    Raise
    -----
    ValueError if `spatial_key` or `label_key` is missing from `adata`.
    """
    # De Morgan form of "missing coordinates or missing labels"; still short-circuits.
    if not (spatial_key in adata.obsm.keys() and label_key in adata.obs.keys()):
        raise ValueError(f"The Anndata object must contain {spatial_key} in obsm and {label_key} in obs.")

    # Plain configuration values.
    self.dataset = dataset
    self.spatial_key = spatial_key
    self.label_key = label_key
    self.n_split = n_split
    self.max_radius = max_radius
    # Tiles overlap by the largest admissible radius so that a neighbourhood
    # is never cut by a tile border.
    self.overlap_radius = max_radius

    # Coordinates, categorical labels and the spatial index built on the coordinates.
    self.spatial_pos = np.array(adata.obsm[self.spatial_key])
    self.labels = adata.obs[self.label_key].astype('category')
    self.kd_tree = KDTree(self.spatial_pos, leafsize=leaf_size)

    # Per-tile label sets and member indices.
    self.grid_cell_types, self.grid_indices = self._initialize_grids()
43
+
44
def _initialize_grids(self):
    """
    Split the bounding box of the coordinates into n_split x n_split overlapping tiles.

    Every tile is extended by `self.overlap_radius` on each side that borders
    another tile; the outer edges of the bounding box are not extended.

    Return
    ------
    A tuple (grid_cell_types, grid_indices) of dicts keyed by (i, j):
    the set of labels present in the tile and the positional indices of its cells.
    """
    coords = self.spatial_pos
    xs, ys = coords[:, 0], coords[:, 1]
    xmin, ymin = np.min(coords, axis=0)
    xmax, ymax = np.max(coords, axis=0)

    parts = self.n_split
    last = parts - 1
    pad = self.overlap_radius
    x_step = (xmax - xmin) / parts  # tile width along x
    y_step = (ymax - ymin) / parts  # tile height along y

    grid_cell_types = {}
    grid_indices = {}
    for i in range(parts):
        x_start = xmin + i * x_step - (pad if i > 0 else 0)
        x_end = xmin + (i + 1) * x_step + (pad if i < last else 0)
        in_x = (xs >= x_start) & (xs <= x_end)
        for j in range(parts):
            y_start = ymin + j * y_step - (pad if j > 0 else 0)
            y_end = ymin + (j + 1) * y_step + (pad if j < last else 0)
            cell_mask = in_x & (ys >= y_start) & (ys <= y_end)

            key = (i, j)
            grid_indices[key] = np.where(cell_mask)[0]
            grid_cell_types[key] = set(self.labels[cell_mask])

    return grid_cell_types, grid_indices
67
+
68
def _query_pattern(self, pattern):
    """
    Select the grid tiles whose label set contains every cell type of `pattern`.

    Parameter
    ---------
    pattern:
        Iterable of cell type names.

    Return
    ------
    A tuple (matching_grids, matching_cells_indices): the list of matching tile keys
    and a dict mapping each of those keys to the tile's entry in `self.grid_indices`.
    """
    matching_cells_indices = {
        grid: self.grid_indices[grid]
        for grid, present_types in self.grid_cell_types.items()
        if all(cell_type in present_types for cell_type in pattern)
    }
    # Dict keys keep insertion order, so this is the tile order of grid_cell_types.
    return list(matching_cells_indices), matching_cells_indices
77
+
78
+ @staticmethod
79
+ def has_motif(neighbors: List[str], labels: List[str]) -> bool:
80
+ """
81
+ Determines whether all elements in 'neighbors' are present in 'labels'.
82
+ If all elements are present, returns True. Otherwise, returns False.
83
+
84
+ Parameter
85
+ ---------
86
+ neighbors:
87
+ List of elements to check.
88
+ labels:
89
+ List in which to check for elements from 'neighbors'.
90
+
91
+ Return
92
+ ------
93
+ True if all elements of 'neighbors' are in 'labels', False otherwise.
94
+ """
95
+ # Set elements in neighbors and labels to be unique.
96
+ # neighbors = set(neighbors)
97
+ # labels = set(labels)
98
+ freq_neighbors = Counter(neighbors)
99
+ freq_labels = Counter(labels)
100
+ for element, count in freq_neighbors.items():
101
+ if freq_labels[element] < count:
102
+ return False
103
+
104
+ return True
105
+ # if len(neighbors) <= len(labels):
106
+ # for n in neighbors:
107
+ # if n in labels:
108
+ # pass
109
+ # else:
110
+ # return False
111
+ # return True
112
+ # return False
113
+
114
@staticmethod
def _distinguish_duplicates(transaction: List[str]):
    """
    Make repeated items of a transaction distinct by appending a running suffix.

    Each item occurring n times is emitted as `item_0` ... `item_{n-1}`; items are
    grouped in order of first appearance.
    """
    suffixed = []
    for item, occurrences in Counter(transaction).items():
        suffixed.extend(f"{item}_{i}" for i in range(occurrences))
    return suffixed
130
+
131
@staticmethod
def _remove_suffix(fp: pd.DataFrame):
    """
    Remove the trailing `_<n>` suffix from every item of the frequent patterns.

    Only the part after the LAST underscore is dropped, so names that themselves
    contain underscores (e.g. `T_cell_0` -> `T_cell`) are preserved.

    Parameter
    ---------
    fp:
        Frequent patterns dataframe with an 'itemsets' column of iterables of str.

    Return
    ------
    The same dataframe object, with 'itemsets' replaced by lists of un-suffixed items.
    """
    stripped = [[item.rsplit('_', 1)[0] for item in itemset]
                for itemset in fp['itemsets'].values]
    # Build an object Series on fp's own index so each list lands in one cell.
    fp['itemsets'] = pd.Series(stripped, index=fp.index, dtype=object)
    return fp
141
+
142
+ @staticmethod
143
+ def find_maximal_patterns(fp: pd.DataFrame) -> pd.DataFrame:
144
+ """
145
+ Find the maximal frequent patterns
146
+
147
+ Parameter
148
+ ---------
149
+ fp: Frequent patterns dataframe with support values and itemsets.
150
+
151
+ Return
152
+ ------
153
+ Maximal frequent patterns with support and itemsets.
154
+ """
155
+ # Convert itemsets to frozensets for set operations
156
+ itemsets = fp['itemsets'].apply(frozenset)
157
+
158
+ # Find all subsets of each itemset
159
+ subsets = set()
160
+ for itemset in itemsets:
161
+ for r in range(1, len(itemset)):
162
+ subsets.update(frozenset(s) for s in combinations(itemset, r))
163
+
164
+ # Identify maximal patterns (itemsets that are not subsets of any other)
165
+ maximal_patterns = [itemset for itemset in itemsets if itemset not in subsets]
166
+ # maximal_patterns_ = [list(p) for p in maximal_patterns]
167
+
168
+ # Filter the original DataFrame to keep only the maximal patterns
169
+ return fp[fp['itemsets'].isin(maximal_patterns)].reset_index(drop=True)
170
+
171
def find_fp_knn(self,
                ct: str,
                k: int = 30,
                min_support: float = 0.5,
                ) -> pd.DataFrame:
    """
    Find frequent patterns among the k nearest neighbors of cells of one cell type.

    Parameter
    ---------
    ct:
        Cell type name of the center cells.
    k:
        Number of nearest neighbors.
    min_support:
        Threshold of frequency to consider a pattern as a frequent pattern.

    Return
    ------
    Frequent patterns returned by `build_fptree_knn` for the positions of `ct` cells.

    Raise
    -----
    ValueError if `ct` does not occur in the labels.
    """
    if ct not in self.labels.unique():
        raise ValueError(f"Found no {ct} in {self.label_key}!")

    # Positional indices of the center cells.
    center_rows = [pos for pos, label in enumerate(self.labels) if label == ct]

    patterns, _, _ = self.build_fptree_knn(
        cell_pos=self.spatial_pos[center_rows],
        k=k,
        min_support=min_support,
    )
    return patterns
203
+
204
def find_fp_dist(self,
                 ct: str,
                 max_dist: float = 100,
                 min_size: int = 0,
                 min_support: float = 0.5,
                 ):
    """
    Find frequent patterns within a radius around cells of one cell type.

    Parameter
    ---------
    ct:
        Cell type name of the center cells.
    max_dist:
        Maximum distance for considering a cell as a neighbor; capped at `self.max_radius`.
    min_size:
        Minimum neighborhood size for each point to consider.
    min_support:
        Threshold of frequency to consider a pattern as a frequent pattern.

    Return
    ------
    Frequent patterns returned by `build_fptree_dist` for the positions of `ct` cells.

    Raise
    -----
    ValueError if `ct` does not occur in the labels.
    """
    if ct not in self.labels.unique():
        raise ValueError(f"Found no {ct} in {self.label_key}!")

    # Positional indices of the center cells; also handed on so that each
    # center can be removed from its own neighborhood.
    center_rows = [pos for pos, label in enumerate(self.labels) if label == ct]
    radius = min(max_dist, self.max_radius)

    patterns, _, _ = self.build_fptree_dist(
        cell_pos=self.spatial_pos[center_rows],
        max_dist=radius,
        min_size=min_size,
        min_support=min_support,
        cinds=center_rows,
    )
    return patterns
242
+
243
def motif_enrichment_knn(self,
                         ct: str,
                         motifs: Union[str, List[str]] = None,
                         k: int = 30,
                         min_support: float = 0.5,
                         max_dist: float = 200,
                         return_cellID: bool = False
                         ) -> pd.DataFrame:
    """
    Perform motif enrichment analysis using k-nearest neighbors (KNN).

    Parameter
    ---------
    ct:
        The cell type of the center cell.
    motifs:
        Specified motif to be tested: one cell type name or a list of names that
        together form a single motif. Names absent from the labels are reported and ignored.
        If motifs=None, the frequent patterns found by `find_fp_knn` are tested.
    k:
        Number of nearest neighbors to consider.
    min_support:
        Threshold of frequency to consider a pattern as a frequent pattern.
    max_dist:
        Maximum distance for neighbors (default: 200); capped at `self.max_radius`.
    return_cellID:
        If True, add a 'cell_id' column holding, per motif, the sorted indices of
        motif-typed cells among the k nearest neighbors of the center cells whose
        neighborhood contains the motif. By default no cell IDs are returned.

    Return
    ------
    pd.DataFrame with one row per motif: 'center', 'motifs', 'n_center_motif',
    'n_center', 'n_motif', 'expectation', 'p-values' and 'if_significant'.
    When more than one motif is tested, 'corrected p-values' (FDR) is added and
    rows are sorted by it.

    Raise
    -----
    ValueError if `ct` is not a known label, if no frequent pattern is found,
    or if none of the specified motif names is a known label.
    """
    labels_unique = self.labels.unique()
    if ct not in labels_unique:
        raise ValueError(f"Found no {ct} in {self.label_key}!")

    max_dist = min(max_dist, self.max_radius)

    # Resolve the motifs first so that invalid input fails before any KD-tree work.
    if motifs is None:
        fp = self.find_fp_knn(ct=ct, k=k, min_support=min_support)
        motifs = fp['itemsets']
        if len(motifs) == 0:
            raise ValueError("No frequent patterns were found. Please lower min_support value.")
    else:
        if isinstance(motifs, str):
            motifs = [motifs]
        motifs_exc = [m for m in motifs if m not in labels_unique]
        if len(motifs_exc) != 0:
            print(f"Found no {motifs_exc} in {self.label_key}. Ignoring them.")
        motifs = [m for m in motifs if m not in motifs_exc]
        if len(motifs) == 0:
            # An empty motif cannot be tested; fail with a clear message.
            raise ValueError(f"None of the specified motifs were found in {self.label_key}.")
        motifs = [motifs]

    # k + 1 neighbors because the nearest hit of every cell is the cell itself.
    dists, idxs = self.kd_tree.query(self.spatial_pos, k=k + 1, workers=-1)

    label_encoder = LabelEncoder()
    int_labels = label_encoder.fit_transform(self.labels)
    int_ct = label_encoder.transform(np.array(ct, dtype=object, ndmin=1))

    num_cells = idxs.shape[0]
    num_types = len(label_encoder.classes_)

    # neighbor_counts[i, t] = number of the k neighbors of cell i within max_dist
    # that carry label t. Accumulated directly, without a (num_cells * k, num_types) matrix.
    neighbor_idxs = idxs[:, 1:]
    rows, cols = np.nonzero(dists[:, 1:] <= max_dist)
    neighbor_counts = np.zeros((num_cells, num_types), dtype=int)
    np.add.at(neighbor_counts, (rows, int_labels[neighbor_idxs[rows, cols]]), 1)

    mask = int_labels == int_ct
    cinds = np.flatnonzero(mask)

    out = []
    for motif in motifs:
        motif = list(motif) if not isinstance(motif, list) else motif
        sort_motif = sorted(motif)

        int_motifs = label_encoder.transform(np.array(motif))

        # Cells whose neighborhood contains every cell type of the motif.
        cell_has_motif = np.all(neighbor_counts[:, int_motifs] > 0, axis=1)
        center_has_motif = cell_has_motif[mask]
        n_motif_ct = np.sum(center_has_motif)
        n_motif_labels = np.sum(cell_has_motif)

        n_ct = len(cinds)
        if ct in motif:
            n_ct = round(n_ct / motif.count(ct))

        # M is number of total, N is number of drawn without replacement, n is number of success in total
        hyge = hypergeom(M=len(self.labels), n=n_ct, N=n_motif_labels)
        # NOTE(review): sf(x) is P(X > x); an enrichment p-value is usually P(X >= x),
        # i.e. sf(x - 1). Kept as is - confirm the intended definition.
        motif_out = {'center': ct, 'motifs': sort_motif, 'n_center_motif': n_motif_ct,
                     'n_center': n_ct, 'n_motif': n_motif_labels, 'expectation': hyge.mean(),
                     'p-values': hyge.sf(n_motif_ct)}

        if return_cellID:
            cind_with_motif = cinds[center_has_motif]
            motif_mask = np.isin(np.array(self.labels), motif)
            # ravel() also works when no center cell has the motif.
            neighbors = idxs[cind_with_motif].ravel()
            # The KD-tree pads missing neighbors with an out-of-range index; drop those.
            neighbors = neighbors[neighbors < num_cells]
            # NOTE(review): these neighbors are not filtered by max_dist, unlike the
            # counts above - confirm whether that is intended.
            exclude_self_mask = ~np.isin(neighbors, cind_with_motif)
            valid_neighbors = neighbors[motif_mask[neighbors] & exclude_self_mask]
            motif_out['cell_id'] = np.unique(valid_neighbors)

        out.append(motif_out)

    out_pd = pd.DataFrame(out)

    if len(out_pd) == 1:
        # A single test needs no multiple-testing correction.
        out_pd['if_significant'] = bool(out_pd['p-values'][0] < 0.05)
        return out_pd

    p_values = out_pd['p-values'].tolist()
    if_rejected, corrected_p_values = mt.fdrcorrection(p_values,
                                                       alpha=0.05,
                                                       method='poscorr')
    out_pd['corrected p-values'] = corrected_p_values
    out_pd['if_significant'] = if_rejected
    out_pd = out_pd.sort_values(by='corrected p-values', ignore_index=True)
    return out_pd
372
+
373
def motif_enrichment_dist(self,
                          ct: str,
                          motifs: Union[str, List[str]] = None,
                          max_dist: float = 100,
                          min_size: int = 0,
                          min_support: float = 0.5,
                          max_ns: int = 100,
                          return_cellID: bool = False,
                          ) -> DataFrame:
    """
    Perform motif enrichment analysis within a specified radius-based neighborhood.

    Parameter
    ---------
    ct:
        Cell type as the center cells.
    motifs:
        Specified motif to be tested: one cell type name or a list of names that
        together form a single motif. Names absent from the labels are reported and ignored.
        If motifs=None, the frequent patterns found by `find_fp_dist` are tested.
    max_dist:
        Maximum distance for considering a cell as a neighbor; capped at `self.max_radius`.
    min_size:
        Minimum neighborhood size for each point to consider.
    min_support:
        Threshold of frequency to consider a pattern as a frequent pattern.
    max_ns:
        Maximum number of neighborhood size for each point.
    return_cellID:
        If True, add a 'cell_id' column holding, per motif, the sorted indices of
        motif-typed cells within max_dist of the center cells whose neighborhood
        contains the motif. By default no cell IDs are returned.

    Returns
    -------
    DataFrame with one row per tested motif: 'center', 'motifs', 'n_center_motif',
    'n_center', 'n_motif', 'expectation', 'p-values' and 'if_significant'.
    When more than one row is produced, 'corrected p-values' (FDR) is added and
    rows are sorted by it. Motifs whose cell types never share a grid tile are
    skipped; if every motif is skipped an empty DataFrame with these columns is returned.

    Raise
    -----
    ValueError if `ct` is not a known label, if no frequent pattern is found,
    or if none of the specified motif names is a known label.
    """
    labels_unique = self.labels.unique()
    if ct not in labels_unique:
        raise ValueError(f"Found no {ct} in {self.label_key}!")

    max_dist = min(max_dist, self.max_radius)
    if motifs is None:
        fp = self.find_fp_dist(ct=ct,
                               max_dist=max_dist, min_size=min_size,
                               min_support=min_support)
        motifs = fp['itemsets']
        if len(motifs) == 0:
            raise ValueError("No frequent patterns were found. Please lower min_support value.")
    else:
        if isinstance(motifs, str):
            motifs = [motifs]
        motifs_exc = [m for m in motifs if m not in labels_unique]
        if len(motifs_exc) != 0:
            print(f"Found no {motifs_exc} in {self.label_key}. Ignoring them.")
        motifs = [m for m in motifs if m not in motifs_exc]
        if len(motifs) == 0:
            # An empty motif cannot be tested; fail with a clear message.
            raise ValueError(f"None of the specified motifs were found in {self.label_key}.")
        motifs = [motifs]

    label_encoder = LabelEncoder()
    int_labels = label_encoder.fit_transform(np.array(self.labels))
    int_ct = label_encoder.transform(np.array(ct, dtype=object, ndmin=1))

    cinds = np.where(self.labels == ct)[0]

    num_cells = len(self.spatial_pos)
    num_types = len(label_encoder.classes_)

    if return_cellID:
        # Neighborhoods of ALL cells; independent of the motif, so computed once.
        idxs_all = self.kd_tree.query_ball_point(
            self.spatial_pos,
            r=max_dist,
            return_sorted=False,
            workers=-1,
        )
        idxs_all_filter = [np.array(ids)[np.array(ids) != i] for i, ids in enumerate(idxs_all)]
        flat_neighbors_all = np.concatenate(idxs_all_filter)
        row_indices_all = np.repeat(np.arange(num_cells), [len(neigh) for neigh in idxs_all_filter])
        neighbor_matrix_all = np.zeros((num_cells, num_types), dtype=int)
        np.add.at(neighbor_matrix_all, (row_indices_all, int_labels[flat_neighbors_all]), 1)
        # Per-type neighbor counts restricted to the center cells (same order as cinds).
        center_counts_all = neighbor_matrix_all[int_labels == int_ct]

    out = []
    for motif in motifs:
        motif = list(motif) if not isinstance(motif, list) else motif
        sort_motif = sorted(motif)

        _, matching_cells_indices = self._query_pattern(motif)
        if not matching_cells_indices:
            # if matching_cells_indices is empty, it indicates no motif are grouped together within upper limit of radius (500)
            continue
        matching_cells_indices = np.concatenate([t for t in matching_cells_indices.values()])
        matching_cells_indices = np.unique(matching_cells_indices)
        # Report what the label says: cells NOT searched for this motif.
        print(f"number of cells skipped: {num_cells - len(matching_cells_indices)}")
        print(f"proportion of cells searched: {len(matching_cells_indices) / len(self.spatial_pos)}")
        idxs_in_grids = self.kd_tree.query_ball_point(
            self.spatial_pos[matching_cells_indices],
            r=max_dist,
            return_sorted=True,
            workers=-1
        )

        int_motifs = label_encoder.transform(np.array(motif))

        # filter center out of neighbors and cap each neighborhood at max_ns entries
        idxs_filter = [np.array(ids)[np.array(ids) != i][:min(max_ns, len(ids))]
                       for i, ids in zip(matching_cells_indices, idxs_in_grids)]

        flat_neighbors = np.concatenate(idxs_filter)
        row_indices = np.repeat(np.arange(len(matching_cells_indices)), [len(neigh) for neigh in idxs_filter])
        neighbor_labels = int_labels[flat_neighbors]

        # neighbor_matrix[r, t] = number of neighbors of searched cell r with label t
        neighbor_matrix = np.zeros((len(matching_cells_indices), num_types), dtype=int)
        np.add.at(neighbor_matrix, (row_indices, neighbor_labels), 1)

        mask = int_labels[matching_cells_indices] == int_ct
        cell_has_motif = np.all(neighbor_matrix[:, int_motifs] > 0, axis=1)
        n_motif_ct = np.sum(cell_has_motif[mask])
        n_motif_labels = np.sum(cell_has_motif)

        n_ct = len(cinds)
        if ct in motif:
            n_ct = round(n_ct / motif.count(ct))

        hyge = hypergeom(M=len(self.labels), n=n_ct, N=n_motif_labels)
        # NOTE(review): sf(x) is P(X > x); an enrichment p-value is usually P(X >= x),
        # i.e. sf(x - 1). Kept as is - confirm the intended definition.
        motif_out = {'center': ct, 'motifs': sort_motif, 'n_center_motif': n_motif_ct,
                     'n_center': n_ct, 'n_motif': n_motif_labels, 'expectation': hyge.mean(),
                     'p-values': hyge.sf(n_motif_ct)}

        if return_cellID:
            inds_all = np.where(np.all(center_counts_all[:, int_motifs] > 0, axis=1))[0]
            cind_with_motif = cinds[inds_all]
            motif_mask = np.isin(np.array(self.labels), motif)
            if len(cind_with_motif) > 0:
                all_neighbors = np.concatenate(idxs_all[cind_with_motif])
            else:
                # np.concatenate rejects an empty sequence.
                all_neighbors = np.array([], dtype=int)
            exclude_self_mask = ~np.isin(all_neighbors, cind_with_motif)
            valid_neighbors = all_neighbors[motif_mask[all_neighbors] & exclude_self_mask]
            motif_out['cell_id'] = np.unique(valid_neighbors)

        out.append(motif_out)

    if not out:
        # Every motif was skipped: return an empty, correctly shaped table.
        columns = ['center', 'motifs', 'n_center_motif', 'n_center', 'n_motif',
                   'expectation', 'p-values']
        if return_cellID:
            columns.append('cell_id')
        columns += ['corrected p-values', 'if_significant']
        return pd.DataFrame(columns=columns)

    out_pd = pd.DataFrame(out)

    if len(out_pd) == 1:
        # A single test needs no multiple-testing correction.
        out_pd['if_significant'] = bool(out_pd['p-values'][0] < 0.05)
        return out_pd

    p_values = out_pd['p-values'].tolist()
    if_rejected, corrected_p_values = mt.fdrcorrection(p_values,
                                                       alpha=0.05,
                                                       method='poscorr')
    out_pd['corrected p-values'] = corrected_p_values
    out_pd['if_significant'] = if_rejected
    out_pd = out_pd.sort_values(by='corrected p-values', ignore_index=True)
    return out_pd
522
+
523
+
524
def build_fptree_dist(self,
                      cell_pos: np.ndarray = None,
                      max_dist: float = 100,
                      min_support: float = 0.5,
                      if_max: bool = True,
                      min_size: int = 0,
                      cinds: List[int] = None,
                      max_ns: int = 100) -> tuple:
    """
    Build a frequency pattern tree based on the distance of cell types.

    Parameter
    ---------
    cell_pos:
        Spatial coordinates of input points.
        If cell_pos is None, use all spots in fov to compute frequent patterns.
    max_dist:
        Maximum distance to consider a cell as a neighbor; capped at `self.max_radius`.
    min_support:
        Threshold of frequency to consider a pattern as a frequent pattern.
    if_max:
        By default return the maximum set of frequent patterns without the subsets. If if_max=False, return all
        patterns whose support values are greater than min_support.
    min_size:
        A neighborhood is used only if it holds more than `min_size` cells.
    cinds:
        For each row of `cell_pos`, the index of the cell located there; that cell is
        removed from the row's neighborhood. If None, self-removal is applied only when
        the query points are the data cells themselves; arbitrary query points
        (e.g. grid points) keep all their neighbors.
    max_ns:
        Maximum number of neighborhood size for each point.

    Return
    ------
    A tuple containing the FPs, the transactions table and the neighbors index
    (one array of cell indices per transaction row).
    """
    # Query points are the data cells when none are given, or when the caller
    # passes the stored coordinates explicitly.
    queries_are_cells = cell_pos is None or cell_pos is self.spatial_pos or (
        np.shape(cell_pos) == np.shape(self.spatial_pos)
        and np.array_equal(cell_pos, self.spatial_pos)
    )
    if cell_pos is None:
        cell_pos = self.spatial_pos

    max_dist = min(max_dist, self.max_radius)

    idxs = self.kd_tree.query_ball_point(cell_pos, r=max_dist, return_sorted=False, workers=-1)
    if cinds is None:
        if queries_are_cells:
            cinds = range(len(idxs))
        else:
            # -1 never equals a cell index, so nothing is removed for such points.
            cinds = [-1] * len(idxs)

    # Prepare data for FP-Tree construction
    transactions = []
    valid_idxs = []
    labels = np.array(self.labels)
    for i_idx, idx in zip(cinds, idxs):
        if len(idx) == 0:
            continue
        idx_array = np.array(idx)
        valid_indices = idx_array[idx_array != i_idx][:max_ns]

        transaction = labels[valid_indices]
        if len(transaction) > min_size:
            transactions.append(transaction.tolist())
            valid_idxs.append(valid_indices)

    if not transactions:
        # Nothing to mine: skip the encoder and fpgrowth on an empty table.
        return pd.DataFrame(columns=['support', 'itemsets']), pd.DataFrame(), valid_idxs

    # Convert transactions to a boolean one-hot DataFrame suitable for fpgrowth
    mlb = MultiLabelBinarizer()
    encoded_data = mlb.fit_transform(transactions)
    df = pd.DataFrame(encoded_data.astype(bool), columns=mlb.classes_)

    # Construct FP-Tree using fpgrowth
    fp_tree = fpgrowth(df, min_support=min_support, use_colnames=True)
    if if_max:
        fp_tree = self.find_maximal_patterns(fp=fp_tree)

    if len(fp_tree) == 0:
        return pd.DataFrame(columns=['support', 'itemsets']), df, valid_idxs

    # Canonical (sorted) itemsets so duplicates can be dropped, then back to lists.
    fp_tree['itemsets'] = fp_tree['itemsets'].apply(lambda x: tuple(sorted(x)))
    fp_tree = fp_tree.drop_duplicates().reset_index(drop=True)
    fp_tree['itemsets'] = fp_tree['itemsets'].apply(lambda x: list(x))
    fp_tree = fp_tree.sort_values(by='support', ignore_index=True, ascending=False)
    return fp_tree, df, valid_idxs
614
+
615
def build_fptree_knn(self,
                     cell_pos: np.ndarray = None,
                     k: int = 30,
                     min_support: float = 0.5,
                     max_dist: float = 200,
                     if_max: bool = True
                     ) -> tuple:
    """
    Build a frequency pattern tree based on knn

    Parameter
    ---------
    cell_pos:
        Spatial coordinates of input points.
        If cell_pos is None, use all spots in fov to compute frequent patterns.
    k:
        Number of neighborhood size for each point.
    min_support:
        Threshold of frequency to consider a pattern as a frequent pattern
    max_dist:
        Neighbors at a distance of max_dist or more are ignored; capped at `self.max_radius`.
    if_max:
        By default return the maximum set of frequent patterns without the subsets. If if_max=False, return all
        patterns whose support values are greater than min_support.

    Return
    ------
    A tuple containing the FPs, the transactions table, and the neighbors index
    (the full index array returned by the KD-tree query).
    """
    query_pos = self.spatial_pos if cell_pos is None else cell_pos
    radius = min(max_dist, self.max_radius)

    # k + 1 hits per point; the first in-range hit is dropped below.
    dists, idxs = self.kd_tree.query(query_pos, k=k + 1, workers=-1)
    idxs = np.array(idxs)
    dists = np.array(dists)
    all_labels = np.array(self.labels)

    # One transaction (NumPy array of labels) per query point that has at least
    # one hit closer than the radius.
    within = dists < radius
    transactions = [
        all_labels[row[keep][1:]]
        for row, keep in zip(idxs, within)
        if keep.any()
    ]

    # One-hot encode the transactions into a boolean table for fpgrowth
    mlb = MultiLabelBinarizer()
    encoded = mlb.fit_transform(transactions)
    df = pd.DataFrame(encoded.astype(bool), columns=mlb.classes_)

    fp_tree = fpgrowth(df, min_support=min_support, use_colnames=True)
    if if_max:
        fp_tree = self.find_maximal_patterns(fp_tree)

    if len(fp_tree) == 0:
        return pd.DataFrame(columns=['support', 'itemsets']), df, idxs

    # Sorted tuples make equal itemsets comparable for de-duplication; lists are restored afterwards.
    fp_tree['itemsets'] = fp_tree['itemsets'].apply(lambda items: tuple(sorted(items)))
    fp_tree = fp_tree.drop_duplicates().reset_index(drop=True)
    fp_tree['itemsets'] = fp_tree['itemsets'].apply(list)
    fp_tree = fp_tree.sort_values(by='support', ignore_index=True, ascending=False)
    return fp_tree, df, idxs
708
+
709
def find_patterns_grid(self,
                       max_dist: float = 100,
                       min_size: int = 0,
                       min_support: float = 0.5,
                       if_display: bool = True,
                       fig_size: tuple = (10, 5),
                       return_cellID: bool = False,
                       return_grid: bool = False,
                       ):
    """
    Create a grid and use it to find surrounding patterns in spatial data.

    Parameter
    ---------
    max_dist:
        Maximum distance to consider a cell as a neighbor; also the grid spacing.
        Capped at `self.max_radius`.
    min_size:
        Forwarded to `build_fptree_dist` as the minimum neighborhood size.
    min_support:
        Threshold of frequency to consider a pattern as a frequent pattern
    if_display:
        Display the grid points with nearby frequent patterns if if_display=True.
    fig_size:
        Tuple of figure size.
    return_cellID:
        Indicate whether return cell IDs for each frequent pattern within the neighborhood of grid points.
        By defaults do not return cell ID.
    return_grid:
        Indicate whether return the grid points. By default, do not return grid points.
        If true, will return a tuple (fp_tree, grid)

    Return
    ------
    fp_tree:
        Frequent patterns sorted by descending support (with a 'cell_id' column of
        index sets when return_cellID=True), or the tuple (fp_tree, grid).
    """
    max_dist = min(max_dist, self.max_radius)
    xmax, ymax = np.max(self.spatial_pos, axis=0)
    xmin, ymin = np.min(self.spatial_pos, axis=0)
    x_grid = np.arange(xmin - max_dist, xmax + max_dist, max_dist)
    y_grid = np.arange(ymin - max_dist, ymax + max_dist, max_dist)
    grid = np.array(np.meshgrid(x_grid, y_grid)).T.reshape(-1, 2)

    fp, trans_df, idxs = self.build_fptree_dist(
        cell_pos=grid,
        max_dist=max_dist,
        min_size=min_size,
        min_support=min_support,
        if_max=True,
        # Grid points are not cells, so no "self" may be removed from their
        # neighborhoods; -1 never equals a cell index.
        cinds=[-1] * len(grid),
    )

    # Positional label lookup; avoids the positional fallback of Series.__getitem__.
    labels = np.asarray(self.labels)

    # For each frequent pattern/motif, locate the cell IDs in the neighborhood of the above grid points
    # as well as labelled with cell types in motif.
    id_neighbor_motifs = []
    if if_display or return_cellID:
        for motif in fp['itemsets']:
            motif = list(motif)
            motif_types = set(motif)
            # Rows of the transaction table whose neighborhood contains the whole motif.
            rows = trans_df[trans_df[motif].all(axis=1)].index.to_list()
            if isinstance(idxs, list):
                neighborhoods = [idxs[r] for r in rows]
            else:
                neighborhoods = idxs[rows]
            id_neighbor_motifs.append(
                {i for hood in neighborhoods for i in hood if labels[i] in motif_types}
            )
    if return_cellID:
        fp['cell_id'] = id_neighbor_motifs

    if if_display:
        fp_cts = sorted(set(t for items in fp['itemsets'] for t in list(items)))
        colors = sns.color_palette('hsv', len(fp_cts))
        color_map = dict(zip(fp_cts, colors))

        # Union of the cells of all motifs, as a positional index array.
        spot_ids = np.array(sorted(set().union(*id_neighbor_motifs)), dtype=int)
        fp_spot_pos = self.spatial_pos[spot_ids, :]
        fp_spot_label = labels[spot_ids]

        fig, ax = plt.subplots(figsize=fig_size)
        # Plotting the grid lines
        for x in x_grid:
            ax.axvline(x, color='lightgray', linestyle='--', lw=0.5)
        for y in y_grid:
            ax.axhline(y, color='lightgray', linestyle='--', lw=0.5)

        for ct in fp_cts:
            ct_ind = fp_spot_label == ct
            ax.scatter(fp_spot_pos[ct_ind, 0], fp_spot_pos[ct_ind, 1],
                       label=ct, color=color_map[ct], s=1)
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), markerscale=4)
        plt.xlabel('Spatial X')
        plt.ylabel('Spatial Y')
        plt.title('Spatial distribution of frequent patterns')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_xticks([])
        ax.set_yticks([])
        plt.tight_layout(rect=[0, 0, 1.1, 1])
        plt.show()

    fp_sorted = fp.sort_values(by='support', ignore_index=True, ascending=False)
    if return_grid:
        return fp_sorted, grid
    return fp_sorted
830
+
831
def find_patterns_rand(self,
                       max_dist: float = 100,
                       n_points: int = 1000,
                       min_support: float = 0.5,
                       min_size: int = 0,
                       if_display: bool = True,
                       fig_size: tuple = (10, 5),
                       return_cellID: bool = False,
                       seed: int = 2023) -> DataFrame:
    """
    Randomly generate points and use them to find surrounding patterns in spatial data.

    Points are drawn uniformly inside the bounding box of the cell coordinates and
    handed to build_fptree_dist, which mines the frequent patterns around them.

    Parameter
    ---------
    max_dist:
        Maximum distance to consider a cell as a neighbor of a random point.
        Capped at self.max_radius.
    n_points:
        Number of random points to generate.
    min_support:
        Threshold of frequency to consider a pattern as a frequent pattern.
    min_size:
        Forwarded unchanged to build_fptree_dist.
    if_display:
        Display the cells that belong to the frequent patterns if if_display=True.
    fig_size:
        Tuple of figure size.
    return_cellID:
        If True, add a 'cell_id' column holding, for each frequent pattern, the set of
        positional cell indices (rows of the spatial coordinates) that take part in it.
        By default cell IDs are not returned.
    seed:
        Random seed, set for reproducibility.

    Return
    ------
    The frequent patterns returned by build_fptree_dist, sorted by decreasing support.
    """
    max_dist = min(max_dist, self.max_radius)
    xmax, ymax = np.max(self.spatial_pos, axis=0)
    xmin, ymin = np.min(self.spatial_pos, axis=0)
    # The global seed is kept on purpose: plot_motif_rand re-draws the very same
    # points from the same seed with the same two np.random.rand calls.
    np.random.seed(seed)
    pos = np.column_stack((np.random.rand(n_points) * (xmax - xmin) + xmin,
                           np.random.rand(n_points) * (ymax - ymin) + ymin))

    fp, trans_df, idxs = self.build_fptree_dist(
        cell_pos=pos,
        max_dist=max_dist,
        min_size=min_size,
        min_support=min_support,
        if_max=True,
    )

    # Neighbor indices are positions into spatial_pos, so labels must be looked up
    # by position as well. Indexing the pandas Series with integers is label-based
    # (or a deprecated positional fallback) and breaks for non 0..n-1 indexes.
    labels_arr = np.asarray(self.labels)

    id_neighbor_motifs = []
    if if_display or return_cellID:
        for itemset in fp['itemsets']:
            motif = list(itemset)
            motif_types = set(motif)
            # Rows of trans_df in which every cell type of the motif is present.
            ids = trans_df[trans_df[motif].all(axis=1)].index.to_list()
            if isinstance(idxs, list):
                neighborhoods = [idxs[t] for t in ids]
            else:
                neighborhoods = idxs[ids]
            id_neighbor_motifs.append(
                {i for neighborhood in neighborhoods for i in neighborhood
                 if labels_arr[i] in motif_types}
            )
    if return_cellID:
        fp['cell_id'] = id_neighbor_motifs

    if if_display:
        fp_cts = sorted(set(t for items in fp['itemsets'] for t in list(items)))
        n_colors = len(fp_cts)
        colors = sns.color_palette('hsv', n_colors)
        color_map = {ct: col for ct, col in zip(fp_cts, colors)}

        # Union of all cells that take part in at least one frequent pattern.
        fp_spots_index = set()
        for cell_id in id_neighbor_motifs:
            fp_spots_index.update(cell_id)
        shown = list(fp_spots_index)

        fp_spot_pos = self.spatial_pos[shown, :]
        fp_spot_label = labels_arr[shown]
        fig, ax = plt.subplots(figsize=fig_size)
        for ct in fp_cts:
            ct_ind = fp_spot_label == ct
            ax.scatter(fp_spot_pos[ct_ind, 0], fp_spot_pos[ct_ind, 1],
                       label=ct, color=color_map[ct], s=1)
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), markerscale=4)
        plt.xlabel('Spatial X')
        plt.ylabel('Spatial Y')
        plt.title('Spatial distribution of frequent patterns')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_xticks([])
        ax.set_yticks([])
        plt.tight_layout(rect=[0, 0, 1.1, 1])
        plt.show()

    return fp.sort_values(by='support', ignore_index=True, ascending=False)
944
+
945
def plot_fov(self,
             min_cells_label: int = 50,
             title: str = 'Spatial distribution of cell types',
             fig_size: tuple = (10, 5)):
    """
    Scatter-plot every cell of a single field of view, colored by cell type.

    Cell types with at least min_cells_label cells get their own color and a
    legend entry annotated with the cell count; rarer types are drawn in grey
    without a legend entry.

    Parameter
    --------
    min_cells_label:
        Minimum number of cells a cell type needs to be colored and listed in the legend.
    title:
        Figure title.
    fig_size:
        Figure size parameters.

    Return
    ------
    Nothing; the figure is shown with plt.show().
    """
    counts_per_type = self.labels.value_counts()
    # One palette entry per sufficiently abundant cell type, consumed in sorted order.
    palette = iter(sns.color_palette('hsv', sum(counts_per_type >= min_cells_label)))

    fig, ax = plt.subplots(figsize=fig_size)

    for cell_type in sorted(self.labels.unique()):
        rows = np.where(self.labels == cell_type)[0]
        xs = self.spatial_pos[rows, 0]
        ys = self.spatial_pos[rows, 1]
        if counts_per_type[cell_type] >= min_cells_label:
            ax.scatter(xs, ys, label=cell_type, color=next(palette), s=1)
        else:
            # Rare types stay visible but are kept out of the legend.
            ax.scatter(xs, ys, color='grey', s=1)

    # Rebuild the legend so each entry also shows how many cells it covers.
    handles, labels = ax.get_legend_handles_labels()
    new_labels = [f'{label} ({counts_per_type[label]})' for label in labels]
    ax.legend(handles, new_labels, loc='center left', bbox_to_anchor=(1, 0.5), markerscale=4)

    ax.set_xlabel('Spatial X')
    ax.set_ylabel('Spatial Y')
    ax.set_title(title)
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_xticks([])
    ax.set_yticks([])

    # Widen the layout rectangle so the legend placed outside the axes is not clipped.
    plt.tight_layout(rect=[0, 0, 1.1, 1])
    plt.show()
1012
+
1013
def plot_motif_grid(self,
                    motif: Union[str, List[str]],
                    fp: pd.DataFrame,
                    fig_size: tuple = (10, 5),
                    max_dist: float = 100,
                    ):
    """
    Display the distribution of a motif around grid points. To make sure the input
    motif can be found in the results obtained by find_patterns_grid, use the same
    arguments as those passed to find_patterns_grid (with return_cellID=True).

    Parameter
    ---------
    motif:
        Motif (names of cell types) to be colored.
    fp:
        Frequent patterns identified by find_patterns_grid. Must contain the
        'itemsets' and 'cell_id' columns.
    fig_size:
        Figure size.
    max_dist:
        Spacing distance for building the grid, capped at self.max_radius.
        Make sure to use the same value as in find_patterns_grid.

    Raise
    -----
    KeyError if fp has no 'cell_id' column; IndexError if motif is not one of the
    itemsets in fp.
    """
    if isinstance(motif, str):
        motif = [motif]

    max_dist = min(max_dist, self.max_radius)

    labels_unique = self.labels.unique()
    motif_exc = [m for m in motif if m not in labels_unique]
    if len(motif_exc) != 0:
        print(f"Found no {motif_exc} in {self.label_key}. Ignoring them.")
    motif = [m for m in motif if m not in motif_exc]

    # Resolve the cells of the motif first so that a bad `fp` fails fast with a
    # readable message instead of deep inside pandas after the KD-tree query.
    if 'cell_id' not in fp.columns:
        raise KeyError("fp has no 'cell_id' column; "
                       "run find_patterns_grid with return_cellID=True.")
    motif_set = set(motif)
    is_motif = fp['itemsets'].apply(lambda p: set(p) == motif_set).astype(bool)
    matched = fp.loc[is_motif, 'cell_id']
    if matched.empty:
        raise IndexError(f"Motif {motif} was not found in fp['itemsets'].")
    id_motif_celltype = set(matched.iloc[0])
    motif_cells = list(id_motif_celltype)

    # Build mesh
    xmax, ymax = np.max(self.spatial_pos, axis=0)
    xmin, ymin = np.min(self.spatial_pos, axis=0)
    x_grid = np.arange(xmin - max_dist, xmax + max_dist, max_dist)
    y_grid = np.arange(ymin - max_dist, ymax + max_dist, max_dist)
    grid = np.array(np.meshgrid(x_grid, y_grid)).T.reshape(-1, 2)

    # Neighborhoods are recomputed here for every grid point with the KD-tree.
    idxs = self.kd_tree.query_ball_point(grid, r=max_dist, return_sorted=False, workers=-1)

    # KD-tree results are positions into spatial_pos, so labels are read by
    # position; integer indexing on the pandas Series would be label-based.
    labels_arr = np.asarray(self.labels)

    # Locate the index of grid points acting as centers with motif nearby
    id_center = [i for i, idx in enumerate(idxs)
                 if self.has_motif(neighbors=motif, labels=labels_arr[idx].tolist())]

    # Set color map as in find_patterns_grid
    fp_cts = sorted(set(t for items in fp['itemsets'] for t in list(items)))
    n_colors = len(fp_cts)
    colors = sns.color_palette('hsv', n_colors)
    color_map = {ct: col for ct, col in zip(fp_cts, colors)}

    motif_spot_pos = self.spatial_pos[motif_cells, :]
    motif_spot_label = labels_arr[motif_cells]
    fig, ax = plt.subplots(figsize=fig_size)
    # Plotting the grid lines
    for x in x_grid:
        ax.axvline(x, color='lightgray', linestyle='--', lw=0.5)

    for y in y_grid:
        ax.axhline(y, color='lightgray', linestyle='--', lw=0.5)
    ax.scatter(grid[id_center, 0], grid[id_center, 1], label='Grid Points',
               edgecolors='red', facecolors='none', s=8)

    # Plotting other spots as background
    bg_index = [i for i in range(len(labels_arr)) if i not in id_motif_celltype]
    bg_pos = self.spatial_pos[bg_index, :]
    ax.scatter(bg_pos[:, 0],
               bg_pos[:, 1],
               color='darkgrey', s=1)

    # Sorted so the legend order does not depend on set iteration order.
    for ct in sorted(motif_set):
        ct_ind = motif_spot_label == ct
        ax.scatter(motif_spot_pos[ct_ind, 0],
                   motif_spot_pos[ct_ind, 1],
                   label=ct, color=color_map[ct], s=1)

    ax.set_xlim([xmin - max_dist, xmax + max_dist])
    ax.set_ylim([ymin - max_dist, ymax + max_dist])
    ax.legend(title='motif', loc='center left', bbox_to_anchor=(1, 0.5), markerscale=4)
    plt.xlabel('Spatial X')
    plt.ylabel('Spatial Y')
    plt.title('Spatial distribution of frequent patterns')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_xticks([])
    ax.set_yticks([])
    plt.tight_layout(rect=[0, 0, 1.1, 1])
    plt.show()
1118
+
1119
def plot_motif_rand(self,
                    motif: Union[str, List[str]],
                    fp: pd.DataFrame,
                    max_dist: float = 100,
                    n_points: int = 1000,
                    fig_size: tuple = (10, 5),
                    seed: int = 2023,
                    ):
    """
    Display the randomly sampled points with motif in their radius-based neighborhood,
    and the cells of the motif around these points. To make sure the input motif can be
    found in the results obtained by find_patterns_rand, use the same arguments as those
    passed to find_patterns_rand (with return_cellID=True).

    Parameter
    ---------
    motif:
        Motif (names of cell types) to be colored.
    fp:
        Frequent patterns identified by find_patterns_rand. Must contain the
        'itemsets' and 'cell_id' columns.
    max_dist:
        Radius of the neighborhood around each random point, capped at self.max_radius.
        Make sure to use the same value as in find_patterns_rand.
    n_points:
        Number of random points to generate.
    fig_size:
        Figure size.
    seed:
        Random seed, set for reproducibility. Use the same value as in find_patterns_rand
        so the same points are drawn.

    Raise
    -----
    KeyError if fp has no 'cell_id' column; IndexError if motif is not one of the
    itemsets in fp.
    """
    if isinstance(motif, str):
        motif = [motif]

    max_dist = min(max_dist, self.max_radius)

    labels_unique = self.labels.unique()
    motif_exc = [m for m in motif if m not in labels_unique]
    if len(motif_exc) != 0:
        print(f"Found no {motif_exc} in {self.label_key}. Ignoring them.")
    motif = [m for m in motif if m not in motif_exc]

    # Resolve the cells of the motif first so that a bad `fp` fails fast with a
    # readable message instead of deep inside pandas after the KD-tree query.
    if 'cell_id' not in fp.columns:
        raise KeyError("fp has no 'cell_id' column; "
                       "run find_patterns_rand with return_cellID=True.")
    motif_set = set(motif)
    is_motif = fp['itemsets'].apply(lambda p: set(p) == motif_set).astype(bool)
    matched = fp.loc[is_motif, 'cell_id']
    if matched.empty:
        raise IndexError(f"Motif {motif} was not found in fp['itemsets'].")
    id_motif_celltype = set(matched.iloc[0])
    motif_cells = list(id_motif_celltype)

    # Random sample points (same seed and draw order as find_patterns_rand)
    xmax, ymax = np.max(self.spatial_pos, axis=0)
    xmin, ymin = np.min(self.spatial_pos, axis=0)
    np.random.seed(seed)
    pos = np.column_stack((np.random.rand(n_points) * (xmax - xmin) + xmin,
                           np.random.rand(n_points) * (ymax - ymin) + ymin))

    idxs = self.kd_tree.query_ball_point(pos, r=max_dist, return_sorted=False, workers=-1)

    # KD-tree results are positions into spatial_pos, so labels are read by
    # position; integer indexing on the pandas Series would be label-based.
    labels_arr = np.asarray(self.labels)

    # Locate the index of random points acting as centers with motif nearby
    id_center = [i for i, idx in enumerate(idxs)
                 if self.has_motif(neighbors=motif, labels=labels_arr[idx].tolist())]

    # Set color map from all cell types present in the frequent patterns
    fp_cts = sorted(set(t for items in fp['itemsets'] for t in list(items)))
    n_colors = len(fp_cts)
    colors = sns.color_palette('hsv', n_colors)
    color_map = {ct: col for ct, col in zip(fp_cts, colors)}

    motif_spot_pos = self.spatial_pos[motif_cells, :]
    motif_spot_label = labels_arr[motif_cells]
    fig, ax = plt.subplots(figsize=fig_size)
    ax.scatter(pos[id_center, 0], pos[id_center, 1], label='Random Sampling Points',
               edgecolors='red', facecolors='none', s=8)

    # Plotting other spots as background
    bg_index = [i for i in range(len(labels_arr)) if i not in id_motif_celltype]
    bg_pos = self.spatial_pos[bg_index, :]
    ax.scatter(bg_pos[:, 0],
               bg_pos[:, 1],
               color='darkgrey', s=1)

    # Sorted so the legend order does not depend on set iteration order.
    for ct in sorted(motif_set):
        ct_ind = motif_spot_label == ct
        ax.scatter(motif_spot_pos[ct_ind, 0],
                   motif_spot_pos[ct_ind, 1],
                   label=ct, color=color_map[ct], s=1)

    ax.set_xlim([xmin - max_dist, xmax + max_dist])
    ax.set_ylim([ymin - max_dist, ymax + max_dist])
    ax.legend(title='motif', loc='center left', bbox_to_anchor=(1, 0.5), markerscale=4)
    plt.xlabel('Spatial X')
    plt.ylabel('Spatial Y')
    plt.title('Spatial distribution of frequent patterns')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_xticks([])
    ax.set_yticks([])
    plt.tight_layout(rect=[0, 0, 1.1, 1])
    plt.show()
1220
+
1221
def plot_motif_celltype(self,
                        ct: str,
                        motif: Union[str, List[str]],
                        max_dist: float = 100,
                        fig_size: tuple = (10, 5)
                        ):
    """
    Display the distribution of interested motifs in the radius-based neighborhood of certain cell type.
    This function is mainly used to visualize the results of motif_enrichment_dist. Make sure the input parameters
    are consistent with those of motif_enrichment_dist.

    Parameter
    ---------
    ct:
        Cell type as the center cells.
    motif:
        Motif (names of cell types) to be colored. Names absent from the labels are ignored.
    max_dist:
        Radius of the neighborhood around each center cell, capped at self.max_radius.
    fig_size:
        Figure size.

    Raise
    -----
    ValueError if ct is not a known cell type, or if no cell type of motif is known.
    """
    if isinstance(motif, str):
        motif = [motif]

    max_dist = min(max_dist, self.max_radius)

    labels_unique = self.labels.unique()  # computed once instead of once per motif entry
    motif_exc = [m for m in motif if m not in labels_unique]
    if len(motif_exc) != 0:
        print(f"Found no {motif_exc} in {self.label_key}. Ignoring them.")
    motif = [m for m in motif if m not in motif_exc]

    if ct not in labels_unique:
        raise ValueError(f"Found no {ct} in {self.label_key}!")
    if not motif:
        # An empty motif would otherwise fail later with an obscure indexing error.
        raise ValueError(f"None of the given motif cell types exist in {self.label_key}!")

    # Positional view of the labels: every index below is a row of spatial_pos.
    labels_arr = np.asarray(self.labels)

    idxs = self.kd_tree.query_ball_point(self.spatial_pos, r=max_dist, return_sorted=False, workers=-1)

    # Encode cell types as integers so neighborhoods can be counted in a matrix.
    label_encoder = LabelEncoder()
    int_labels = label_encoder.fit_transform(labels_arr)
    int_ct = label_encoder.transform(np.array(ct, dtype=object, ndmin=1))
    int_motifs = label_encoder.transform(np.array(motif))

    num_cells = len(idxs)
    num_types = len(label_encoder.classes_)
    # Drop each cell from its own neighborhood.
    idxs_filter = [np.array(ids)[np.array(ids) != i] for i, ids in enumerate(idxs)]

    flat_neighbors = np.concatenate(idxs_filter)
    row_indices = np.repeat(np.arange(num_cells), [len(neigh) for neigh in idxs_filter])
    neighbor_labels = int_labels[flat_neighbors]

    # neighbor_matrix[i, t] = number of neighbors of cell i that have type t.
    neighbor_matrix = np.zeros((num_cells, num_types), dtype=int)
    np.add.at(neighbor_matrix, (row_indices, neighbor_labels), 1)

    # Center cells of type ct whose neighborhood contains every motif cell type.
    mask = int_labels == int_ct
    cinds = np.flatnonzero(mask)
    inds = np.where(np.all(neighbor_matrix[mask][:, int_motifs] > 0, axis=1))[0]
    cind_with_motif = [cinds[i] for i in inds]

    # Locate the motif cells in the neighborhood of those center cells.
    # np.concatenate raises on an empty sequence, so the "no center has the
    # motif nearby" case is handled explicitly and simply highlights nothing.
    if cind_with_motif:
        motif_mask = np.isin(labels_arr, motif)
        all_neighbors = np.concatenate(idxs[cind_with_motif])
        exclude_self_mask = ~np.isin(all_neighbors, cind_with_motif)
        valid_neighbors = all_neighbors[motif_mask[all_neighbors] & exclude_self_mask]
        id_motif_celltype = set(valid_neighbors.tolist())
    else:
        id_motif_celltype = set()
    motif_cells = list(id_motif_celltype)

    # Plot figures
    # Sorted so that color assignment does not depend on set iteration order.
    motif_unique = sorted(set(motif))
    n_colors = len(motif_unique)
    colors = sns.color_palette('hsv', n_colors)
    color_map = {ct_m: col for ct_m, col in zip(motif_unique, colors)}
    motif_spot_pos = self.spatial_pos[motif_cells, :]
    motif_spot_label = labels_arr[motif_cells]
    fig, ax = plt.subplots(figsize=fig_size)
    # Plotting other spots as background
    cind_with_motif_set = set(cind_with_motif)
    bg_index = [i for i in range(len(labels_arr))
                if i not in id_motif_celltype and i not in cind_with_motif_set]
    bg_pos = self.spatial_pos[bg_index, :]
    ax.scatter(bg_pos[:, 0],
               bg_pos[:, 1],
               color='darkgrey', s=1)
    # Plot the center cells whose neighborhood contains the motif
    ax.scatter(self.spatial_pos[cind_with_motif, 0],
               self.spatial_pos[cind_with_motif, 1],
               label=ct, edgecolors='red', facecolors='none', s=3,
               )
    for ct_m in motif_unique:
        ct_ind = motif_spot_label == ct_m
        ax.scatter(motif_spot_pos[ct_ind, 0],
                   motif_spot_pos[ct_ind, 1],
                   label=ct_m, color=color_map[ct_m], s=1)

    ax.legend(title='motif', loc='center left', bbox_to_anchor=(1, 0.5), markerscale=4)
    plt.xlabel('Spatial X')
    plt.ylabel('Spatial Y')
    plt.title(f"Spatial distribution of motif around {ct}")
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_xticks([])
    ax.set_yticks([])
    plt.tight_layout(rect=[0, 0, 1.1, 1])
    plt.show()