SpatialQuery 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1116 @@
1
+ from typing import List, Union
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ import scipy.stats as stats
6
+ import statsmodels.stats.multitest as mt
7
+ from anndata import AnnData
8
+ from mlxtend.frequent_patterns import fpgrowth
9
+ from sklearn.preprocessing import MultiLabelBinarizer
10
+ from pandas import DataFrame
11
+ from scipy.stats import hypergeom
12
+ import time
13
+ import matplotlib.pyplot as plt
14
+ from collections import defaultdict
15
+ from sklearn.preprocessing import LabelEncoder
16
+
17
+ from .spatial_query import spatial_query
18
+
19
+
20
+ class spatial_query_multi:
21
    def __init__(self,
                 adatas: List[AnnData],
                 datasets: List[str],
                 spatial_key: str,
                 label_key: str,
                 leaf_size: int,
                 max_radius: float = 500,
                 n_split: int = 10,
                 ):
        """
        Initiate models, including setting attributes and building kd-tree for each field of view.

        Parameter
        ---------
        adatas:
            List of adata
        datasets:
            List of dataset names
        spatial_key:
            Spatial coordination name in AnnData.obsm object
        label_key:
            Label name in AnnData.obs object
        leaf_size:
            The largest number of points stored in each leaf node.
        max_radius:
            The upper limit of neighborhood radius.
        n_split:
            Passed through to each spatial_query object; presumably the number of
            spatial grid splits used there — confirm against spatial_query.
        """
        # Each element in self.spatial_queries stores a spatial_query object
        self.spatial_key = spatial_key
        self.label_key = label_key
        self.max_radius = max_radius
        # Modify dataset names by d_0, d_1, ... for duplicates in datasets.
        # The underscore is reserved as the suffix separator (methods later use
        # dataset.split('_')[0] to recover the base name), so underscores inside
        # user-supplied names are replaced by hyphens first.
        count_dict = {}
        modified_datasets = []
        for dataset in datasets:
            if '_' in dataset:
                print(f"Warning: Misusage of underscore in '{dataset}'. Replacing with hyphen.")
                dataset = dataset.replace('_', '-')

            if dataset in count_dict:
                count_dict[dataset] += 1
            else:
                count_dict[dataset] = 0

            mod_dataset = f"{dataset}_{count_dict[dataset]}"
            modified_datasets.append(mod_dataset)

        self.datasets = modified_datasets

        # One spatial_query (kd-tree + labels) per field of view, in the same
        # order as self.datasets.
        self.spatial_queries = [spatial_query(
            adata=adata,
            dataset=self.datasets[i],
            spatial_key=spatial_key,
            label_key=label_key,
            leaf_size=leaf_size,
            max_radius=self.max_radius,
            n_split=n_split,
        ) for i, adata in enumerate(adatas)]
79
+
80
+ def find_fp_knn(self,
81
+ ct: str,
82
+ dataset: Union[str, List[str]] = None,
83
+ k: int = 30,
84
+ min_support: float = 0.5,
85
+ max_dist: float = 500
86
+ ) -> pd.DataFrame:
87
+ """
88
+ Find frequent patterns within the KNNs of certain cell type in multiple fields of view.
89
+
90
+ Parameter
91
+ ---------
92
+ ct:
93
+ Cell type name.
94
+ dataset:
95
+ Datasets for searching for frequent patterns.
96
+ Use all datasets if dataset=None.
97
+ k:
98
+ Number of nearest neighbors.
99
+ min_support:
100
+ Threshold of frequency to consider a pattern as a frequent pattern.
101
+ max_dist:
102
+ The maximum distance at which points are considered neighbors.
103
+
104
+ Return
105
+ ------
106
+ Frequent patterns in the neighborhood of certain cell type.
107
+ """
108
+ # Search transactions for each field of view, find the frequent patterns of integrated transactions
109
+ # start = time.time()
110
+ if_exist_label = [ct in s.labels.unique() for s in self.spatial_queries]
111
+ if not any(if_exist_label):
112
+ raise ValueError(f"Found no {self.label_key} in all datasets!")
113
+
114
+ if dataset is None:
115
+ # Use all datasets if dataset is not provided
116
+ dataset = [s.dataset.split('_')[0] for s in self.spatial_queries]
117
+
118
+ # Make sure dataset is a list
119
+ if isinstance(dataset, str):
120
+ dataset = [dataset]
121
+
122
+ # test if the input dataset name is valid
123
+ valid_ds_names = [s.dataset.split('_')[0] for s in self.spatial_queries]
124
+ for ds in dataset:
125
+ if ds not in valid_ds_names:
126
+ raise ValueError(f"Invalid input dataset name: {ds}.\n "
127
+ f"Valid dataset names are: {set(valid_ds_names)}")
128
+
129
+ max_dist = min(max_dist, self.max_radius)
130
+ # end = time.time()
131
+ # print(f"time for checking validation of inputs: {end-start} seconds")
132
+
133
+ # start = time.time()
134
+ transactions = []
135
+ for s in self.spatial_queries:
136
+ if s.dataset.split('_')[0] not in dataset:
137
+ continue
138
+ cell_pos = s.spatial_pos
139
+ labels = np.array(s.labels)
140
+ if ct not in np.unique(labels):
141
+ continue
142
+
143
+ ct_pos = cell_pos[labels == ct]
144
+ dists, idxs = s.kd_tree.query(ct_pos, k=k + 1, workers=-1)
145
+ mask = dists < max_dist
146
+ for i, idx in enumerate(idxs):
147
+ inds = idx[mask[i]]
148
+ transaction = labels[inds[1:]]
149
+ transactions.append(transaction)
150
+ # end = time.time()
151
+ # print(f"time for building {len(transactions)} transactions: {end-start} seconds.")
152
+
153
+ # start = time.time()
154
+ mlb = MultiLabelBinarizer()
155
+ encoded_data = mlb.fit_transform(transactions)
156
+ df = pd.DataFrame(encoded_data.astype(bool), columns=mlb.classes_)
157
+ # end = time.time()
158
+ # print(f"time for building df for fpgrowth: {end-start} seconds")
159
+
160
+ # start = time.time()
161
+ fp = fpgrowth(df, min_support=min_support, use_colnames=True)
162
+ if len(fp) == 0:
163
+ return pd.DataFrame(columns=['support', 'itemsets'])
164
+ # end = time.time()
165
+ # print(f"time for find fp_growth: {end-start} seconds, {len(fp)} frequent patterns.")
166
+ # start = time.time()
167
+ fp = spatial_query.find_maximal_patterns(fp=fp)
168
+ # end = time.time()
169
+ # print(f"time for identify maximal patterns: {end - start} seconds")
170
+
171
+ fp['itemsets'] = fp['itemsets'].apply(lambda x: list(sorted(x)))
172
+ fp.sort_values(by='support', ascending=False, inplace=True, ignore_index=True)
173
+
174
+ return fp
175
+
176
+ def find_fp_dist(self,
177
+ ct: str,
178
+ dataset: Union[str, List[str]] = None,
179
+ max_dist: float = 100,
180
+ min_size: int = 0,
181
+ min_support: float = 0.5,
182
+ max_ns: int = 100
183
+ ):
184
+ """
185
+ Find frequent patterns within the radius of certain cell type in multiple fields of view.
186
+
187
+ Parameter
188
+ ---------
189
+ ct:
190
+ Cell type name.
191
+ dataset:
192
+ Datasets for searching for frequent patterns.
193
+ Use all datasets if dataset=None.
194
+ max_dist:
195
+ Maximum distance for considering a cell as a neighbor.
196
+ min_size:
197
+ Minimum neighborhood size for each point to consider.
198
+ min_support:
199
+ Threshold of frequency to consider a pattern as a frequent pattern.
200
+ max_ns:
201
+ Upper limit of neighbors for each point.
202
+
203
+ Return
204
+ ------
205
+ Frequent patterns in the neighborhood of certain cell type.
206
+ """
207
+ # Search transactions for each field of view, find the frequent patterns of integrated transactions
208
+ if_exist_label = [ct in s.labels.unique() for s in self.spatial_queries]
209
+ if not any(if_exist_label):
210
+ raise ValueError(f"Found no {self.label_key} in any datasets!")
211
+
212
+ if dataset is None:
213
+ dataset = [s.dataset.split('_')[0] for s in self.spatial_queries]
214
+ if isinstance(dataset, str):
215
+ dataset = [dataset]
216
+
217
+ # test if the input dataset name is valid
218
+ valid_ds_names = [s.dataset.split('_')[0] for s in self.spatial_queries]
219
+ for ds in dataset:
220
+ if ds not in valid_ds_names:
221
+ raise ValueError(f"Invalid input dataset name: {ds}.\n "
222
+ f"Valid dataset names are: {set(valid_ds_names)}")
223
+
224
+ max_dist = min(max_dist, self.max_radius)
225
+ # start = time.time()
226
+ transactions = []
227
+ for s in self.spatial_queries:
228
+ if s.dataset.split('_')[0] not in dataset:
229
+ continue
230
+ cell_pos = s.spatial_pos
231
+ labels = np.array(s.labels)
232
+ if ct not in np.unique(labels):
233
+ continue
234
+
235
+ cinds = [id for id, l in enumerate(labels) if l == ct]
236
+ ct_pos = cell_pos[cinds]
237
+
238
+ idxs = s.kd_tree.query_ball_point(ct_pos, r=max_dist, return_sorted=False, workers=-1)
239
+
240
+ for i_id, idx in zip(cinds, idxs):
241
+ transaction = [labels[i] for i in idx[:min(max_ns, len(idx))] if i != i_id]
242
+ if len(transaction) > min_size:
243
+ transactions.append(transaction)
244
+
245
+ # end = time.time()
246
+ # print(f"time for building {len(transactions)} transactions: {end-start} seconds.")
247
+
248
+ # start = time.time()
249
+ mlb = MultiLabelBinarizer()
250
+ encoded_data = mlb.fit_transform(transactions)
251
+ df = pd.DataFrame(encoded_data.astype(bool), columns=mlb.classes_)
252
+ # end = time.time()
253
+ # print(f"time for building df for fpgrowth: {end - start} seconds")
254
+
255
+ # start = time.time()
256
+ fp = fpgrowth(df, min_support=min_support, use_colnames=True)
257
+ if len(fp) == 0:
258
+ return pd.DataFrame(columns=['support', 'itemsets'])
259
+ # end = time.time()
260
+ # print(f"time for find fp_growth: {end - start} seconds, {len(fp)} frequent patterns.")
261
+
262
+ # start = time.time()
263
+ fp = spatial_query.find_maximal_patterns(fp=fp)
264
+ # end = time.time()
265
+ # print(f"time for identify maximal patterns: {end - start} seconds")
266
+
267
+ fp['itemsets'] = fp['itemsets'].apply(lambda x: list(sorted(x)))
268
+ fp.sort_values(by='support', ascending=False, inplace=True, ignore_index=True)
269
+
270
+ return fp
271
+
272
    def motif_enrichment_knn(self,
                             ct: str,
                             motifs: Union[str, List[str]] = None,
                             dataset: Union[str, List[str]] = None,
                             k: int = 30,
                             min_support: float = 0.5,
                             max_dist: float = 500.0,
                             ) -> pd.DataFrame:
        """
        Perform motif enrichment analysis using k-nearest neighbors (KNN) in multiple fields of view.

        Parameter
        ---------
        ct:
            The cell type of the center cell.
        motifs:
            Specified motifs to be tested.
            If motifs=None, find the frequent patterns as motifs within
            the neighborhood of center cell type in each fov.
        dataset:
            Datasets for searching for frequent patterns and performing enrichment analysis.
            Use all datasets if dataset=None.
        k:
            Number of nearest neighbors to consider.
        min_support:
            Threshold of frequency to consider a pattern as a frequent pattern.
        max_dist:
            Maximum distance for neighbors (default: 500).

        Return
        ------
        pd.Dataframe containing the cell type name, motifs, number of motifs nearby given cell type,
        number of spots of cell type, number of motifs in single FOV, p value of hypergeometric distribution.
        With more than one motif, FDR-corrected p-values and significance flags are added.
        """
        # NOTE(review): unlike find_fp_knn/find_fp_dist, this method does not
        # validate the requested dataset names — invalid names silently match
        # nothing. Consider adding the same check.
        if dataset is None:
            dataset = [s.dataset.split('_')[0] for s in self.spatial_queries]
        if isinstance(dataset, str):
            dataset = [dataset]

        # Cap the search radius at the radius the kd-trees were built for.
        max_dist = min(max_dist, self.max_radius)

        out = []
        if_exist_label = [ct in s.labels.unique() for s in self.spatial_queries]
        if not any(if_exist_label):
            # NOTE(review): message interpolates label_key but the check is on
            # ct — the message likely should name the missing cell type.
            raise ValueError(f"Found no {self.label_key} in any datasets!")

        # Check whether specify motifs. If not, search frequent patterns among specified datasets
        # and use them as interested motifs
        all_labels = pd.concat([s.labels for s in self.spatial_queries])
        labels_unique_all = set(all_labels.unique())
        if motifs is None:
            fp = self.find_fp_knn(ct=ct, k=k, dataset=dataset,
                                  min_support=min_support, max_dist=max_dist)
            motifs = fp['itemsets'].tolist()
        else:
            if isinstance(motifs, str):
                motifs = [motifs]

            # Drop cell types that never appear in any dataset.
            motifs_exc = [m for m in motifs if m not in labels_unique_all]
            if len(motifs_exc) != 0:
                print(f"Found no {motifs_exc} in {dataset}. Ignoring them.")
            motifs = [m for m in motifs if m not in motifs_exc]
            if len(motifs) == 0:
                raise ValueError(f"All cell types in motifs are missed in {self.label_key}.")
            # A user-supplied motif is a single pattern: wrap it so the loop
            # below iterates over one motif.
            motifs = [motifs]

        for motif in motifs:
            # Per-motif tallies accumulated over all requested FOVs:
            # n_labels: total cells; n_ct: center-type cells;
            # n_motif_labels: cells whose kNN neighborhood contains the motif;
            # n_motif_ct: center-type cells whose neighborhood contains it.
            n_labels = 0
            n_ct = 0
            n_motif_labels = 0
            n_motif_ct = 0

            motif = list(motif) if not isinstance(motif, list) else motif
            sort_motif = sorted(motif)

            # Calculate statistics of each dataset
            for fov, s in enumerate(self.spatial_queries):
                if s.dataset.split('_')[0] not in dataset:
                    continue

                cell_pos = s.spatial_pos
                labels = np.array(s.labels)
                labels_unique = np.unique(labels)

                # FOVs lacking any motif member still contribute to the
                # background totals, but cannot contain the motif.
                contain_motif = [m in labels_unique for m in motif]
                if not np.all(contain_motif):
                    n_labels += labels.shape[0]
                    n_ct += np.sum(labels == ct)
                    continue
                else:
                    n_labels += labels.shape[0]
                    # Integer-encode labels so neighbor types can be counted
                    # with array indexing instead of Python loops.
                    label_encoder = LabelEncoder()
                    int_labels = label_encoder.fit_transform(labels)
                    int_motifs = label_encoder.transform(np.array(motif))

                    # k + 1: each point is its own nearest neighbor at index 0.
                    dists, idxs = s.kd_tree.query(cell_pos, k=k + 1, workers=-1)
                    num_cells = idxs.shape[0]
                    num_types = len(label_encoder.classes_)

                    # Mark neighbors beyond max_dist invalid (-1 sentinel).
                    valid_neighbors = dists[:, 1:] <= max_dist
                    filtered_idxs = np.where(valid_neighbors, idxs[:, 1:], -1)
                    flat_neighbors = filtered_idxs.flatten()
                    valid_neighbors_flat = valid_neighbors.flatten()
                    neighbor_labels = np.where(valid_neighbors_flat, int_labels[flat_neighbors], -1)
                    valid_mask = neighbor_labels != -1

                    # One row per (cell, neighbor-slot); summing over the k axis
                    # yields per-cell counts of each neighbor type.
                    neighbor_matrix = np.zeros((num_cells * k, num_types), dtype=int)
                    neighbor_matrix[np.arange(len(neighbor_labels))[valid_mask], neighbor_labels[valid_mask]] = 1
                    neighbor_counts = neighbor_matrix.reshape(num_cells, k, num_types).sum(axis=1)

                    # A cell "contains the motif" when every motif member
                    # appears at least once among its valid neighbors.
                    n_motif_labels += np.sum(np.all(neighbor_counts[:, int_motifs] > 0, axis=1))

                    if ct in np.unique(labels):
                        int_ct = label_encoder.transform(np.array(ct, dtype=object, ndmin=1))
                        mask = int_labels == int_ct
                        n_motif_ct += np.sum(np.all(neighbor_counts[mask][:, int_motifs] > 0, axis=1))
                        n_ct += np.sum(mask)

            # When the center type is itself part of the motif, de-duplicate
            # the center count by the motif's multiplicity of ct.
            if ct in motif:
                n_ct = round(n_ct / motif.count(ct))

            # Hypergeometric test: draws = centers (n_ct), successes in the
            # population = motif-containing cells (n_motif_labels).
            hyge = hypergeom(M=n_labels, n=n_ct, N=n_motif_labels)
            motif_out = {'center': ct, 'motifs': sort_motif, 'n_center_motif': n_motif_ct,
                         'n_center': n_ct, 'n_motif': n_motif_labels, 'expectation': hyge.mean(), 'p-values': hyge.sf(n_motif_ct)}
            out.append(motif_out)

        out_pd = pd.DataFrame(out)

        if len(out_pd) == 1:
            # Single motif: no multiple-testing correction needed.
            out_pd['if_significant'] = True if out_pd['p-values'][0] < 0.05 else False
            return out_pd
        else:
            # Benjamini-Hochberg FDR correction across the tested motifs.
            p_values = out_pd['p-values'].tolist()
            if_rejected, corrected_p_values = mt.fdrcorrection(p_values,
                                                               alpha=0.05,
                                                               method='poscorr')
            out_pd['corrected p-values'] = corrected_p_values
            out_pd['if_significant'] = if_rejected
            out_pd = out_pd.sort_values(by='corrected p-values', ignore_index=True)
            return out_pd
432
+
433
    def motif_enrichment_dist(self,
                              ct: str,
                              motifs: Union[str, List[str]] = None,
                              dataset: Union[str, List[str]] = None,
                              max_dist: float = 100,
                              min_size: int = 0,
                              min_support: float = 0.5,
                              max_ns: int = 100) -> DataFrame:
        """
        Perform motif enrichment analysis within a specified radius-based neighborhood in multiple fields of view.

        Parameter
        ---------
        ct:
            Cell type of the center cell.
        motifs:
            Specified motifs to be tested.
            If motifs=None, find the frequent patterns as motifs within the neighborhood of center cell type.
        dataset:
            Datasets for searching for frequent patterns and performing enrichment analysis.
            Use all datasets if dataset=None.
        max_dist:
            Maximum distance for considering a cell as a neighbor.
        min_size:
            Minimum neighborhood size for each point to consider.
        min_support:
            Threshold of frequency to consider a pattern as a frequent pattern.
        max_ns:
            Maximum number of neighborhood size for each point.

        Returns
        -------
        pd.DataFrame with per-motif counts, hypergeometric expectation and
        p-values; with more than one motif, FDR-corrected p-values and
        significance flags are added.
        """
        # NOTE(review): dataset names are not validated here, unlike
        # find_fp_knn/find_fp_dist — invalid names silently match nothing.
        if dataset is None:
            dataset = [s.dataset.split('_')[0] for s in self.spatial_queries]
        if isinstance(dataset, str):
            dataset = [dataset]

        out = []
        if_exist_label = [ct in s.labels.unique() for s in self.spatial_queries]
        if not any(if_exist_label):
            # NOTE(review): message interpolates label_key but the check is on
            # ct — the message likely should name the missing cell type.
            raise ValueError(f"Found no {self.label_key} in any datasets!")

        # Cap the search radius at the radius the kd-trees were built for.
        max_dist = min(max_dist, self.max_radius)

        # Check whether specify motifs. If not, search frequent patterns among specified datasets
        # and use them as interested motifs
        if motifs is None:
            fp = self.find_fp_dist(ct=ct, dataset=dataset, max_dist=max_dist, min_size=min_size,
                                   min_support=min_support, max_ns=max_ns)
            motifs = fp['itemsets'].tolist()
        else:
            if isinstance(motifs, str):
                motifs = [motifs]

            # Drop cell types that never appear in any dataset.
            all_labels = pd.concat([s.labels for s in self.spatial_queries])
            labels_unique_all = set(all_labels.unique())
            motifs_exc = [m for m in motifs if m not in labels_unique_all]
            if len(motifs_exc) != 0:
                print(f"Found no {motifs_exc} in {dataset}! Ignoring them.")
            motifs = [m for m in motifs if m not in motifs_exc]
            if len(motifs) == 0:
                raise ValueError(f"All cell types in motifs are missed in {self.label_key}.")
            # A user-supplied motif is a single pattern: wrap it so the loop
            # below iterates over one motif.
            motifs = [motifs]

        for motif in motifs:
            # Per-motif tallies accumulated over all requested FOVs (see
            # motif_enrichment_knn for their meaning).
            n_labels = 0
            n_ct = 0
            n_motif_labels = 0
            n_motif_ct = 0

            motif = list(motif) if not isinstance(motif, list) else motif
            sort_motif = sorted(motif)

            for s in self.spatial_queries:
                if s.dataset.split('_')[0] not in dataset:
                    continue
                cell_pos = s.spatial_pos
                labels = np.array(s.labels)
                labels_unique = np.unique(labels)

                # FOVs lacking any motif member still contribute to the
                # background totals, but cannot contain the motif.
                contain_motif = [m in labels_unique for m in motif]
                if not np.all(contain_motif):
                    n_labels += labels.shape[0]
                    n_ct += np.sum(labels == ct)
                    continue
                else:
                    n_labels += labels.shape[0]
                    # Pre-filter with the grid index: only cells near which all
                    # motif members co-occur need a radius query.
                    _, matching_cells_indices = s._query_pattern(motif)
                    if not matching_cells_indices:
                        # if matching_cells_indices is empty, it indicates no motif are grouped together within upper limit of radius (500)
                        continue
                    matching_cells_indices = np.concatenate([t for t in matching_cells_indices.values()])
                    matching_cells_indices = np.unique(matching_cells_indices)
                    matching_cells_indices.sort()
                    idxs_in_grids = s.kd_tree.query_ball_point(
                        s.spatial_pos[matching_cells_indices],
                        r=max_dist,
                        return_sorted=False,
                        workers=-1
                    )

                    # Integer-encode labels so neighbor types can be counted
                    # with vectorized indexing.
                    label_encoder = LabelEncoder()
                    int_labels = label_encoder.fit_transform(labels)
                    int_motifs = label_encoder.transform(np.array(motif))

                    num_cells = len(s.spatial_pos)  # NOTE(review): unused below
                    num_types = len(label_encoder.classes_)
                    # Filter each neighborhood: drop the center cell itself and
                    # cap at max_ns neighbors.
                    idxs_filter = [np.array(ids)[np.array(ids) != i][:min(max_ns, len(ids))] for i, ids in
                                   zip(matching_cells_indices, idxs_in_grids)]

                    # Scatter-add neighbor labels into a (cell, type) count
                    # matrix in one pass.
                    num_matching_cells = len(matching_cells_indices)
                    flat_neighbors = np.concatenate(idxs_filter)
                    row_indices = np.repeat(np.arange(num_matching_cells), [len(neigh) for neigh in idxs_filter])
                    neighbor_labels = int_labels[flat_neighbors]

                    neighbor_matrix = np.zeros((num_matching_cells, num_types), dtype=int)
                    np.add.at(neighbor_matrix, (row_indices, neighbor_labels), 1)

                    # A cell "contains the motif" when every motif member
                    # appears at least once within its radius neighborhood.
                    n_motif_labels += np.sum(np.all(neighbor_matrix[:, int_motifs] > 0, axis=1))

                    if ct in np.unique(labels):
                        int_ct = label_encoder.transform(np.array(ct, dtype=object, ndmin=1))
                        # Restrict to pre-filtered cells that are of type ct.
                        mask = int_labels[matching_cells_indices] == int_ct
                        n_motif_ct += np.sum(np.all(neighbor_matrix[mask][:, int_motifs] > 0, axis=1))
                        n_ct += np.sum(s.labels == ct)

            # When the center type is itself part of the motif, de-duplicate
            # the center count by the motif's multiplicity of ct.
            if ct in motif:
                n_ct = round(n_ct / motif.count(ct))
            # Hypergeometric test: draws = centers (n_ct), successes in the
            # population = motif-containing cells (n_motif_labels).
            hyge = hypergeom(M=n_labels, n=n_ct, N=n_motif_labels)
            motif_out = {'center': ct, 'motifs': sort_motif, 'n_center_motif': n_motif_ct,
                         'n_center': n_ct, 'n_motif': n_motif_labels, 'expectation': hyge.mean(), 'p-values': hyge.sf(n_motif_ct)}
            out.append(motif_out)

        out_pd = pd.DataFrame(out)

        if len(out_pd) == 1:
            # Single motif: no multiple-testing correction needed.
            out_pd['if_significant'] = True if out_pd['p-values'][0] < 0.05 else False
            return out_pd
        else:
            # Benjamini-Hochberg FDR correction across the tested motifs.
            p_values = out_pd['p-values'].tolist()
            if_rejected, corrected_p_values = mt.fdrcorrection(p_values,
                                                               alpha=0.05,
                                                               method='poscorr')
            out_pd['corrected p-values'] = corrected_p_values
            out_pd['if_significant'] = if_rejected
            out_pd = out_pd.sort_values(by='corrected p-values', ignore_index=True)
            return out_pd
616
+
617
+ def find_fp_knn_fov(self,
618
+ ct: str,
619
+ dataset_i: str,
620
+ k: int = 30,
621
+ min_support: float = 0.5,
622
+ max_dist: float = 500.0
623
+ ) -> pd.DataFrame:
624
+ """
625
+ Find frequent patterns within the KNNs of specific cell type of interest in single field of view.
626
+
627
+ Parameter
628
+ ---------
629
+ ct:
630
+ Cell type name.
631
+ dataset_i:
632
+ Datasets for searching for frequent patterns in dataset_i format.
633
+ k:
634
+ Number of nearest neighbors.
635
+ min_support:
636
+ Threshold of frequency to consider a pattern as a frequent pattern.
637
+ max_dist:
638
+ Maximum distance for considering a cell as a neighbor.
639
+
640
+ Return
641
+ ------
642
+ Frequent patterns in the neighborhood of certain cell type.
643
+ """
644
+ if dataset_i not in self.datasets:
645
+ raise ValueError(f"Found no {dataset_i.split('_')[0]} in any datasets.")
646
+
647
+ max_dist = min(max_dist, self.max_radius)
648
+
649
+ sp_object = self.spatial_queries[self.datasets.index(dataset_i)]
650
+ cell_pos = sp_object.spatial_pos
651
+ labels = np.array(sp_object.labels)
652
+ if ct not in np.unique(labels):
653
+ return pd.DataFrame(columns=['support', 'itemsets'])
654
+
655
+ ct_pos = cell_pos[labels == ct]
656
+
657
+ # Identify frequent patterns of cell types, including those subsets of patterns
658
+ # whose support value exceeds min_support. Focus solely on the multiplicity
659
+ # of cell types, rather than their frequency.
660
+ fp, _, _ = sp_object.build_fptree_knn(
661
+ cell_pos=ct_pos,
662
+ k=k,
663
+ min_support=min_support,
664
+ if_max=False,
665
+ max_dist=max_dist,
666
+ )
667
+ return fp
668
+
669
+ def find_fp_dist_fov(self,
670
+ ct: str,
671
+ dataset_i: str,
672
+ max_dist: float = 100,
673
+ min_size: int = 0,
674
+ min_support: float = 0.5,
675
+ max_ns: int = 100
676
+ ):
677
+ """
678
+ Find frequent patterns within the radius-based neighborhood of specific cell type of interest
679
+ in single field of view.
680
+
681
+ Parameter
682
+ ---------
683
+ ct:
684
+ Cell type name.
685
+ dataset_i:
686
+ Datasets for searching for frequent patterns in dataset_i format.
687
+ max_dist:
688
+ Maximum distance for considering a cell as a neighbor.
689
+ min_size:
690
+ Minimum neighborhood size for each point to consider.
691
+ min_support:
692
+ Threshold of frequency to consider a pattern as a frequent pattern.
693
+ max_ns:
694
+ Maximum number of neighborhood size for each point.
695
+
696
+ Return
697
+ ------
698
+ Frequent patterns in the neighborhood of certain cell type.
699
+ """
700
+ if dataset_i not in self.datasets:
701
+ raise ValueError(f"Found no {dataset_i.split('_')[0]} in any datasets.")
702
+
703
+ max_dist = min(max_dist, self.max_radius)
704
+
705
+ sp_object = self.spatial_queries[self.datasets.index(dataset_i)]
706
+ cell_pos = sp_object.spatial_pos
707
+ labels = sp_object.labels
708
+ if ct not in labels.unique():
709
+ return pd.DataFrame(columns=['support, itemsets'])
710
+
711
+ cinds = [id for id, l in enumerate(labels) if l == ct]
712
+ ct_pos = cell_pos[cinds]
713
+
714
+ fp, _, _ = sp_object.build_fptree_dist(cell_pos=ct_pos,
715
+ max_dist=max_dist,
716
+ min_support=min_support,
717
+ min_size=min_size,
718
+ if_max=False,
719
+ cinds=cinds,
720
+ max_ns=max_ns,
721
+ )
722
+ return fp
723
+
724
    def differential_analysis_knn(self,
                                  ct: str,
                                  datasets: List[str],
                                  k: int = 30,
                                  min_support: float = 0.5,
                                  max_dist: float = 500,
                                  ):
        """
        Explore the differences in cell types and frequent patterns of cell types in spatial KNN neighborhood of cell
        type of interest. Perform differential analysis of frequent patterns in specified datasets.

        Parameter
        ---------
        ct:
            Cell type of interest as center point.
        datasets:
            Dataset names used to perform differential analysis (exactly two).
        k:
            Number of nearest neighbors.
        min_support:
            Threshold of frequency to consider a pattern as a frequent pattern.
        max_dist:
            Maximum distance for considering a cell as a neighbor.

        Return
        ------
        Dataframes with significant enriched patterns in differential analysis
        (dict keyed by the two dataset names).

        Raises
        ------
        ValueError
            If datasets does not contain exactly 2 names, or a name is invalid.
        """
        if len(datasets) != 2:
            raise ValueError("Require 2 datasets for differential analysis.")

        # Never search beyond the radius the kd-trees were built for.
        max_dist = min(max_dist, self.max_radius)

        # Check if the two datasets are valid
        valid_ds_names = [s.dataset.split('_')[0] for s in self.spatial_queries]
        for ds in datasets:
            if ds not in valid_ds_names:
                raise ValueError(f"Invalid input dataset name: {ds}.\n"
                                 f"Valid dataset names are: {set(valid_ds_names)}")

        # flag marks whether fp_datasets has been initialized with the first
        # dataset's patterns.
        # NOTE(review): if no FOV of either dataset yields patterns, fp_datasets
        # is never assigned and the code below raises NameError — consider an
        # explicit empty-result early return.
        flag = 0
        # Identify frequent patterns in each dataset
        for d in datasets:
            fp_d = {}
            # All FOVs belonging to base dataset name d (suffix "_i" added in
            # __init__ disambiguates multiple FOVs of the same dataset).
            dataset_i = [ds for ds in self.datasets if ds.split('_')[0] == d]
            for d_i in dataset_i:
                fp_fov = self.find_fp_knn_fov(ct=ct,
                                              dataset_i=d_i,
                                              k=k,
                                              min_support=min_support,
                                              max_dist=max_dist)
                if len(fp_fov) > 0:
                    fp_d[d_i] = fp_fov

            if len(fp_d) == 1:
                # Single FOV: its patterns are the dataset's patterns.
                common_patterns = list(fp_d.values())[0]
                common_patterns = common_patterns.rename(columns={'support': f"support_{list(fp_d.keys())[0]}"})
            else:
                # Multiple FOVs: keep only patterns frequent in every FOV.
                # in comm_fps, duplicates items are not allowed by using set object
                comm_fps = set.intersection(
                    *[set(df['itemsets'].apply(lambda x: tuple(sorted(x)))) for df in
                      fp_d.values()])  # the items' order in patterns will not affect the returned intersection
                common_patterns = pd.DataFrame({'itemsets': [list(items) for items in comm_fps]})
                # Attach one support column per FOV for the common patterns.
                for data_name, df in fp_d.items():
                    support_dict = {itemset: support for itemset, support in
                                    df[['itemsets', 'support']].apply(
                                        lambda row: (tuple(sorted(row['itemsets'])), row['support']), axis=1)}
                    common_patterns[f"support_{data_name}"] = common_patterns['itemsets'].apply(
                        lambda x: support_dict.get(tuple(x), None))
            # Tuples are hashable, enabling the outer merge on 'itemsets'.
            common_patterns['itemsets'] = common_patterns['itemsets'].apply(tuple)
            if flag == 0:
                fp_datasets = common_patterns
                flag = 1
            else:
                # Outer merge keeps patterns unique to either dataset; missing
                # supports become 0.
                fp_datasets = fp_datasets.merge(common_patterns, how='outer', on='itemsets', ).fillna(0)

        # Per dataset, collect its "support_<name>..." columns.
        # NOTE(review): startswith-based matching can cross-match when one
        # dataset name is a prefix of the other (e.g. "A" and "AB").
        match_ind_datasets = [
            [col for ind, col in enumerate(fp_datasets.columns) if col.startswith(f"support_{dataset}")] for dataset in
            datasets]
        p_values = []
        dataset_higher_ranks = []
        for index, row in fp_datasets.iterrows():
            # FOV-level supports of this pattern in each dataset.
            group1 = pd.to_numeric(row[match_ind_datasets[0]].values)
            group2 = pd.to_numeric(row[match_ind_datasets[1]].values)

            # Perform the Mann-Whitney U test
            stat, p = stats.mannwhitneyu(group1, group2, alternative='two-sided', method='auto')
            p_values.append(p)

            # Label the dataset with higher frequency of patterns based on rank median
            support_rank = pd.concat([pd.DataFrame(group1), pd.DataFrame(group2)]).rank()  # ascending
            median_rank1 = support_rank[:len(group1)].median()[0]
            median_rank2 = support_rank[len(group1):].median()[0]
            if median_rank1 > median_rank2:
                dataset_higher_ranks.append(datasets[0])
            else:
                dataset_higher_ranks.append(datasets[1])

        fp_datasets['dataset_higher_frequency'] = dataset_higher_ranks
        # Apply Benjamini-Hochberg correction for multiple testing problems
        if_rejected, corrected_p_values = mt.fdrcorrection(p_values,
                                                           alpha=0.05,
                                                           method='poscorr')

        # Add the corrected p-values back to the DataFrame (optional)
        fp_datasets['corrected_p_values'] = corrected_p_values
        fp_datasets['if_significant'] = if_rejected

        # Return the significant patterns in each dataset
        fp_dataset0 = fp_datasets[
            (fp_datasets['dataset_higher_frequency'] == datasets[0]) & (fp_datasets['if_significant'])
        ][['itemsets', 'corrected_p_values']]
        fp_dataset1 = fp_datasets[
            (fp_datasets['dataset_higher_frequency'] == datasets[1]) & (fp_datasets['if_significant'])
        ][['itemsets', 'corrected_p_values']]
        fp_dataset0 = fp_dataset0.reset_index(drop=True)
        fp_dataset1 = fp_dataset1.reset_index(drop=True)
        fp_dataset0 = fp_dataset0.sort_values(by='corrected_p_values', ascending=True)
        fp_dataset1 = fp_dataset1.sort_values(by='corrected_p_values', ascending=True)
        return {datasets[0]: fp_dataset0, datasets[1]: fp_dataset1}
846
+
847
+ def differential_analysis_dist(self,
848
+ ct: str,
849
+ datasets: List[str],
850
+ max_dist: float = 100,
851
+ min_support: float = 0.5,
852
+ min_size: int = 0,
853
+ max_ns: int = 100,
854
+ ):
855
+ """
856
+ Explore the differences in cell types and frequent patterns of cell types in spatial radius-based neighborhood
857
+ of cell type of interest. Perform differential analysis of frequent patterns in specified datasets.
858
+
859
+ Parameter
860
+ ---------
861
+ ct:
862
+ Cell type of interest as center point.
863
+ datasets:
864
+ Dataset names used to perform differential analysis
865
+ max_dist:
866
+ Maximum distance for considering a cell as a neighbor.
867
+ min_support:
868
+ Threshold of frequency to consider a pattern as a frequent pattern.
869
+ min_size:
870
+ Minimum neighborhood size for each point to consider.
871
+ max_ns:
872
+ Upper limit of neighbors for each point.
873
+
874
+ Return
875
+ ------
876
+ Dataframes with significant enriched patterns in differential analysis
877
+ """
878
+ if len(datasets) != 2:
879
+ raise ValueError("Require 2 datasets for differential analysis.")
880
+
881
+ # Check if the two datasets are valid
882
+ valid_ds_names = [s.dataset.split('_')[0] for s in self.spatial_queries]
883
+ for ds in datasets:
884
+ if ds not in valid_ds_names:
885
+ raise ValueError(f"Invalid input dataset name: {ds}.\n"
886
+ f"Valid dataset names are: {set(valid_ds_names)}")
887
+
888
+ max_dist = min(max_dist, self.max_radius)
889
+
890
+ flag = 0
891
+ # Identify frequent patterns in each dataset
892
+ for d in datasets:
893
+ fp_d = {}
894
+ dataset_i = [ds for ds in self.datasets if ds.split('_')[0] == d]
895
+ for d_i in dataset_i:
896
+ fp_fov = self.find_fp_dist_fov(ct=ct,
897
+ dataset_i=d_i,
898
+ max_dist=max_dist,
899
+ min_size=min_size,
900
+ min_support=min_support,
901
+ max_ns=max_ns)
902
+ if len(fp_fov) > 0:
903
+ fp_d[d_i] = fp_fov
904
+
905
+ if len(fp_d) == 1:
906
+ common_patterns = list(fp_d.values())[0]
907
+ common_patterns = common_patterns.rename(columns={'support': f"support_{list(fp_d.keys())[0]}"})
908
+ else:
909
+ comm_fps = set.intersection(*[set(df['itemsets'].apply(lambda x: tuple(sorted(x)))) for df in
910
+ fp_d.values()]) # the items' order in patterns will not affect the returned intersection
911
+ common_patterns = pd.DataFrame({'itemsets': [list(items) for items in comm_fps]})
912
+ for data_name, df in fp_d.items():
913
+ support_dict = {itemset: support for itemset, support in df[['itemsets', 'support']].apply(
914
+ lambda row: (tuple(sorted(row['itemsets'])), row['support']), axis=1)}
915
+ common_patterns[f"support_{data_name}"] = common_patterns['itemsets'].apply(
916
+ lambda x: support_dict.get(tuple(x), None))
917
+ common_patterns['itemsets'] = common_patterns['itemsets'].apply(tuple)
918
+ if flag == 0:
919
+ fp_datasets = common_patterns
920
+ flag = 1
921
+ else:
922
+ fp_datasets = fp_datasets.merge(common_patterns, how='outer', on='itemsets', ).fillna(0)
923
+
924
+ match_ind_datasets = [
925
+ [col for ind, col in enumerate(fp_datasets.columns) if col.startswith(f"support_{dataset}")] for dataset in
926
+ datasets]
927
+ p_values = []
928
+ dataset_higher_ranks = []
929
+ for index, row in fp_datasets.iterrows():
930
+ group1 = pd.to_numeric(row[match_ind_datasets[0]].values)
931
+ group2 = pd.to_numeric(row[match_ind_datasets[1]].values)
932
+
933
+ # Perform the Mann-Whitney U test
934
+ stat, p = stats.mannwhitneyu(group1, group2, alternative='two-sided', method='auto')
935
+ p_values.append(p)
936
+
937
+ # Label the dataset with higher frequency of patterns based on rank sum
938
+ support_rank = pd.concat([pd.DataFrame(group1), pd.DataFrame(group2)]).rank() # ascending
939
+ median_rank1 = support_rank[:len(group1)].median()[0]
940
+ median_rank2 = support_rank[len(group1):].median()[0]
941
+ if median_rank1 > median_rank2:
942
+ dataset_higher_ranks.append(datasets[0])
943
+ else:
944
+ dataset_higher_ranks.append(datasets[1])
945
+
946
+ fp_datasets['dataset_higher_frequency'] = dataset_higher_ranks
947
+ # Apply Benjamini-Hochberg correction for multiple testing problems
948
+ if_rejected, corrected_p_values = mt.fdrcorrection(p_values,
949
+ alpha=0.05,
950
+ method='poscorr')
951
+
952
+ # Add the corrected p-values back to the DataFrame (optional)
953
+ fp_datasets['corrected_p_values'] = corrected_p_values
954
+ fp_datasets['if_significant'] = if_rejected
955
+
956
+ # Return the significant patterns in each dataset
957
+ fp_dataset0 = fp_datasets[
958
+ (fp_datasets['dataset_higher_frequency'] == datasets[0]) & (fp_datasets['if_significant'])
959
+ ][['itemsets', 'corrected_p_values']]
960
+ fp_dataset1 = fp_datasets[
961
+ (fp_datasets['dataset_higher_frequency'] == datasets[1]) & (fp_datasets['if_significant'])
962
+ ][['itemsets', 'corrected_p_values']]
963
+
964
+ fp_dataset0 = fp_dataset0.sort_values(by='corrected_p_values', ascending=True)
965
+ fp_dataset1 = fp_dataset1.sort_values(by='corrected_p_values', ascending=True)
966
+ fp_dataset0 = fp_dataset0.reset_index(drop=True)
967
+ fp_dataset1 = fp_dataset1.reset_index(drop=True)
968
+ return {datasets[0]: fp_dataset0, datasets[1]: fp_dataset1}
969
+
970
+ def cell_type_distribution(self,
971
+ dataset: Union[str, List[str]] = None,
972
+ data_type: str = 'number',
973
+ ):
974
+ """
975
+ Visualize the distribution of cell types across datasets using a stacked bar plot.
976
+
977
+ Parameter
978
+ ---------
979
+ dataset:
980
+ Datasets for searching.
981
+ data_type:
982
+ Plot bar plot by number of cells or by the proportions of datasets in each cell type.
983
+ Default is 'number' otherwise 'proportion' is used.
984
+ Returns
985
+ -------
986
+ Stacked bar plot
987
+ """
988
+ if data_type not in ['number', 'proportion']:
989
+ raise ValueError("Invalild data_type. It should be one of 'number' or 'proportion'.")
990
+
991
+ if dataset is None:
992
+ dataset = [s.dataset.split('_')[0] for s in self.spatial_queries]
993
+ if isinstance(dataset, str):
994
+ dataset = [dataset]
995
+
996
+ valid_ds_names = [s.dataset.split('_')[0] for s in self.spatial_queries]
997
+ for ds in dataset:
998
+ if ds not in valid_ds_names:
999
+ raise ValueError(f"Invalid input dataset name: {ds}.\n "
1000
+ f"Valid dataset names are: {set(valid_ds_names)}")
1001
+
1002
+ summary = defaultdict(lambda: defaultdict(int))
1003
+
1004
+ valid_queries = [s for s in self.spatial_queries if s.dataset.split('_')[0] in dataset]
1005
+ cell_types = set([ct for s in valid_queries for ct in s.labels.unique()])
1006
+ for s in valid_queries:
1007
+ for cell_type in cell_types:
1008
+ summary[s.dataset][cell_type] += np.sum(s.labels == cell_type)
1009
+
1010
+ df = pd.DataFrame([(dataset, cell_type, count)
1011
+ for dataset, cell_types in summary.items()
1012
+ for cell_type, count in cell_types.items()],
1013
+ columns=['Dataset', 'Cell Type', 'Count'])
1014
+
1015
+ df['dataset'] = df['Dataset'].str.split('_').str[0]
1016
+
1017
+ summary = df.groupby(['dataset', 'Cell Type'])['Count'].sum().reset_index()
1018
+ plot_data = summary.pivot(index='Cell Type', columns='dataset', values='Count').fillna(0)
1019
+
1020
+ # Sort the cell types by total count (descending)
1021
+ plot_data = plot_data.sort_values(by=plot_data.columns.tolist(), ascending=False, )
1022
+
1023
+ if data_type != 'number':
1024
+ plot_data = plot_data.div(plot_data.sum(axis=1), axis=0)
1025
+
1026
+ # Create the stacked bar plot
1027
+ ax = plot_data.plot(kind='bar', stacked=True,
1028
+ figsize=(plot_data.shape[0], plot_data.shape[0] * 0.6),
1029
+ edgecolor='black')
1030
+
1031
+ # Customize the plot
1032
+ plt.title(f"Distribution of Cell Types Across Datasets", fontsize=16)
1033
+ plt.xlabel('Cell Types', fontsize=12)
1034
+ if data_type == 'number':
1035
+ plt.ylabel('Number of Cells', fontsize=12)
1036
+ else:
1037
+ plt.ylabel('Proportion of Cells', fontsize=12)
1038
+
1039
+ plt.xticks(rotation=90, ha='right', fontsize=10)
1040
+
1041
+ plt.legend(title='Datasets', loc='center left', bbox_to_anchor=(1, 0.5), fontsize=10)
1042
+
1043
+ plt.tight_layout(rect=[0, 0, 0.85, 1])
1044
+ plt.show()
1045
+
1046
+ def cell_type_distribution_fov(self,
1047
+ dataset: str,
1048
+ data_type: str = 'number',
1049
+ ):
1050
+ """
1051
+ Visualize the distribution of cell types across FOVs in the dataset using a stacked bar plot.
1052
+ Parameter
1053
+ ---------
1054
+ dataset:
1055
+ Dataset of searching.
1056
+ data_type:
1057
+ Plot bar plot by number of cells or by the proportions of cell types in each FOV.
1058
+ Default is 'number' otherwise 'proportion' is used.
1059
+ Returns
1060
+ -------
1061
+ Stacked bar plot
1062
+ """
1063
+ if data_type not in ['number', 'proportion']:
1064
+ raise ValueError("Invalild data_type. It should be one of 'number' or 'proportion'.")
1065
+
1066
+ valid_ds_names = [s.dataset.split('_')[0] for s in self.spatial_queries]
1067
+ if dataset not in valid_ds_names:
1068
+ raise ValueError(f"Invalid input dataset name: {dataset}. \n"
1069
+ f"Valid dataset names are: {set(valid_ds_names)}")
1070
+ valid_queries = [s for s in self.spatial_queries if s.dataset.split('_')[0] == dataset]
1071
+ cell_types = set([ct for s in valid_queries for ct in s.labels.unique()])
1072
+
1073
+ summary = defaultdict(lambda: defaultdict(int))
1074
+ for s in valid_queries:
1075
+ for cell_type in cell_types:
1076
+ summary[s.dataset][cell_type] = np.sum(s.labels == cell_type)
1077
+
1078
+ df = pd.DataFrame([(dataset, cell_type, count)
1079
+ for dataset, cell_types in summary.items()
1080
+ for cell_type, count in cell_types.items()],
1081
+ columns=['Dataset', 'Cell Type', 'Count'])
1082
+
1083
+ df['FOV'] = df['Dataset'].str.split('_').str[1]
1084
+
1085
+ summary = df.groupby(['FOV', 'Cell Type'])['Count'].sum().reset_index()
1086
+ plot_data = summary.pivot(columns='Cell Type', index='FOV', values='Count').fillna(0)
1087
+
1088
+ # Sort the cell types by total count (descending)
1089
+ row_sums = plot_data.sum(axis=1)
1090
+ plot_data_sorted = plot_data.loc[row_sums.sort_values(ascending=False).index]
1091
+
1092
+ if data_type != 'number':
1093
+ plot_data_sorted = plot_data_sorted.div(plot_data_sorted.sum(axis=1), axis=0)
1094
+
1095
+ # Create the stacked bar plot
1096
+ ax = plot_data_sorted.plot(kind='bar', stacked=True,
1097
+ figsize=(plot_data.shape[0] * 0.6, plot_data.shape[0] * 0.3),
1098
+ edgecolor='black')
1099
+
1100
+ # Customize the plot
1101
+ plt.title(f"Distribution of FOVs in {dataset} dataset", fontsize=20)
1102
+ plt.xlabel('FOV', fontsize=12)
1103
+ if data_type == 'number':
1104
+ plt.ylabel('Number of Cells', fontsize=12)
1105
+ else:
1106
+ plt.ylabel('Proportion of Cells', fontsize=12)
1107
+
1108
+ plt.xticks(rotation=90, ha='right', fontsize=10)
1109
+
1110
+ plt.legend(title='Cell Type', bbox_to_anchor=(1, 1.05), loc='center left', fontsize=12)
1111
+
1112
+ plt.tight_layout()
1113
+ plt.show()
1114
+
1115
+
1116
+