gsMap3D 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsMap/__init__.py +13 -0
- gsMap/__main__.py +4 -0
- gsMap/cauchy_combination_test.py +342 -0
- gsMap/cli.py +355 -0
- gsMap/config/__init__.py +72 -0
- gsMap/config/base.py +296 -0
- gsMap/config/cauchy_config.py +79 -0
- gsMap/config/dataclasses.py +235 -0
- gsMap/config/decorators.py +302 -0
- gsMap/config/find_latent_config.py +276 -0
- gsMap/config/format_sumstats_config.py +54 -0
- gsMap/config/latent2gene_config.py +461 -0
- gsMap/config/ldscore_config.py +261 -0
- gsMap/config/quick_mode_config.py +242 -0
- gsMap/config/report_config.py +81 -0
- gsMap/config/spatial_ldsc_config.py +334 -0
- gsMap/config/utils.py +286 -0
- gsMap/find_latent/__init__.py +3 -0
- gsMap/find_latent/find_latent_representation.py +312 -0
- gsMap/find_latent/gnn/distribution.py +498 -0
- gsMap/find_latent/gnn/encoder_decoder.py +186 -0
- gsMap/find_latent/gnn/gcn.py +85 -0
- gsMap/find_latent/gnn/gene_former.py +164 -0
- gsMap/find_latent/gnn/loss.py +18 -0
- gsMap/find_latent/gnn/st_model.py +125 -0
- gsMap/find_latent/gnn/train_step.py +177 -0
- gsMap/find_latent/st_process.py +781 -0
- gsMap/format_sumstats.py +446 -0
- gsMap/generate_ldscore.py +1018 -0
- gsMap/latent2gene/__init__.py +18 -0
- gsMap/latent2gene/connectivity.py +781 -0
- gsMap/latent2gene/entry_point.py +141 -0
- gsMap/latent2gene/marker_scores.py +1265 -0
- gsMap/latent2gene/memmap_io.py +766 -0
- gsMap/latent2gene/rank_calculator.py +590 -0
- gsMap/latent2gene/row_ordering.py +182 -0
- gsMap/latent2gene/row_ordering_jax.py +159 -0
- gsMap/ldscore/__init__.py +1 -0
- gsMap/ldscore/batch_construction.py +163 -0
- gsMap/ldscore/compute.py +126 -0
- gsMap/ldscore/constants.py +70 -0
- gsMap/ldscore/io.py +262 -0
- gsMap/ldscore/mapping.py +262 -0
- gsMap/ldscore/pipeline.py +615 -0
- gsMap/pipeline/quick_mode.py +134 -0
- gsMap/report/__init__.py +2 -0
- gsMap/report/diagnosis.py +375 -0
- gsMap/report/report.py +100 -0
- gsMap/report/report_data.py +1832 -0
- gsMap/report/static/js_lib/alpine.min.js +5 -0
- gsMap/report/static/js_lib/tailwindcss.js +83 -0
- gsMap/report/static/template.html +2242 -0
- gsMap/report/three_d_combine.py +312 -0
- gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
- gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
- gsMap/report/three_d_plot/three_d_plots.py +425 -0
- gsMap/report/visualize.py +1409 -0
- gsMap/setup.py +5 -0
- gsMap/spatial_ldsc/__init__.py +0 -0
- gsMap/spatial_ldsc/io.py +656 -0
- gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
- gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
- gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
- gsMap/utils/__init__.py +0 -0
- gsMap/utils/generate_r2_matrix.py +610 -0
- gsMap/utils/jackknife.py +518 -0
- gsMap/utils/manhattan_plot.py +643 -0
- gsMap/utils/regression_read.py +177 -0
- gsMap/utils/torch_utils.py +23 -0
- gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
- gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
- gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
- gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
- gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,781 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Connectivity matrix building for homogeneous spot identification
|
|
3
|
+
Implements the spatial → anchor → homogeneous neighbor finding algorithm
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
from functools import partial
|
|
8
|
+
|
|
9
|
+
import anndata as ad
|
|
10
|
+
import jax
|
|
11
|
+
import jax.numpy as jnp
|
|
12
|
+
import numpy as np
|
|
13
|
+
import scanpy as sc
|
|
14
|
+
from jax import jit
|
|
15
|
+
from rich.progress import (
|
|
16
|
+
BarColumn,
|
|
17
|
+
MofNCompleteColumn,
|
|
18
|
+
Progress,
|
|
19
|
+
SpinnerColumn,
|
|
20
|
+
TaskProgressColumn,
|
|
21
|
+
TextColumn,
|
|
22
|
+
TimeElapsedColumn,
|
|
23
|
+
TimeRemainingColumn,
|
|
24
|
+
track,
|
|
25
|
+
)
|
|
26
|
+
from scipy.spatial import cKDTree
|
|
27
|
+
|
|
28
|
+
from gsMap.config import DatasetType, LatentToGeneConfig
|
|
29
|
+
|
|
30
|
+
logger = logging.getLogger(__name__)

# Keep JAX in 32-bit mode (float32 by default) for speed.
jax.config.update("jax_enable_x64", False)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def find_spatial_neighbors_with_slices(
|
|
37
|
+
coords: np.ndarray,
|
|
38
|
+
slice_ids: np.ndarray | None = None,
|
|
39
|
+
query_cell_mask: np.ndarray | None = None,
|
|
40
|
+
high_quality_cell_mask: np.ndarray = None,
|
|
41
|
+
k_central: int = 101,
|
|
42
|
+
k_adjacent: int = 50,
|
|
43
|
+
n_adjacent_slices: int = 1
|
|
44
|
+
) -> tuple[np.ndarray, dict[int, np.ndarray]]:
|
|
45
|
+
"""
|
|
46
|
+
Find spatial neighbors with slice-aware search for 3D data.
|
|
47
|
+
|
|
48
|
+
For 2D data (slice_ids is None), performs standard KNN.
|
|
49
|
+
For 3D data, implements slice-aware neighbor search:
|
|
50
|
+
- Finds k_central neighbors on the same slice
|
|
51
|
+
- Finds k_adjacent neighbors on each of n_adjacent_slices above and below
|
|
52
|
+
|
|
53
|
+
KDTrees are built using only high quality cells, and neighbors are searched
|
|
54
|
+
within high quality cells only.
|
|
55
|
+
|
|
56
|
+
Args:
|
|
57
|
+
coords: Spatial coordinates (n_cells, 2) - only x,y coordinates
|
|
58
|
+
slice_ids: Slice/z-coordinate indices (n_cells,) - sequential integers
|
|
59
|
+
query_cell_mask: Boolean mask for cells to find neighbors for
|
|
60
|
+
high_quality_cell_mask: Boolean mask for high quality cells (defines the neighbor search pool)
|
|
61
|
+
k_central: Number of neighbors to find on the central slice
|
|
62
|
+
k_adjacent: Number of neighbors to find on each adjacent slice
|
|
63
|
+
n_adjacent_slices: Number of slices to search above and below
|
|
64
|
+
|
|
65
|
+
Returns:
|
|
66
|
+
Tuple of:
|
|
67
|
+
- spatial_neighbors: Array of neighbor indices (n_query_cells, k_central + 2*n_adjacent_slices*k_adjacent)
|
|
68
|
+
- neighbor_pool_per_slice: Dict mapping slice_id to local indices of neighbor pool cells on that slice
|
|
69
|
+
"""
|
|
70
|
+
n_cells = len(coords)
|
|
71
|
+
if query_cell_mask is None:
|
|
72
|
+
query_cell_mask = np.ones(n_cells, dtype=bool)
|
|
73
|
+
|
|
74
|
+
if high_quality_cell_mask is None:
|
|
75
|
+
high_quality_cell_mask = np.ones(n_cells, dtype=bool)
|
|
76
|
+
|
|
77
|
+
logger.debug(f"Finding neighbors: {high_quality_cell_mask.sum()}/{n_cells} cells are high quality")
|
|
78
|
+
|
|
79
|
+
query_cell_indices = np.where(query_cell_mask)[0]
|
|
80
|
+
n_query_cells = len(query_cell_indices)
|
|
81
|
+
|
|
82
|
+
# Neighbor pool: high quality cells that will be used to build KDTrees
|
|
83
|
+
# (intersection of query_cell_mask and high_quality_cell_mask)
|
|
84
|
+
neighbor_pool_mask = query_cell_mask & high_quality_cell_mask
|
|
85
|
+
neighbor_pool_indices = np.where(neighbor_pool_mask)[0]
|
|
86
|
+
|
|
87
|
+
# If no slice_ids provided, perform standard 2D KNN
|
|
88
|
+
if slice_ids is None:
|
|
89
|
+
logger.info(f"No slice IDs provided, performing standard 2D KNN with k={k_central}")
|
|
90
|
+
# Build tree with neighbor pool cells only
|
|
91
|
+
kdtree = cKDTree(coords[neighbor_pool_mask])
|
|
92
|
+
# Query for all query cells (including non-HQ ones)
|
|
93
|
+
_, neighbor_local_indices = kdtree.query(
|
|
94
|
+
coords[query_cell_mask],
|
|
95
|
+
k=min(k_central, len(neighbor_pool_indices)),
|
|
96
|
+
workers=-1 # Use all available cores
|
|
97
|
+
)
|
|
98
|
+
# Convert local indices to global indices
|
|
99
|
+
spatial_neighbors = neighbor_pool_indices[neighbor_local_indices]
|
|
100
|
+
return spatial_neighbors, {}
|
|
101
|
+
|
|
102
|
+
# Slice-aware neighbor search with fixed-size arrays
|
|
103
|
+
logger.info(f"Performing slice-aware neighbor search: k_central={k_central}, "
|
|
104
|
+
f"k_adjacent={k_adjacent}, n_adjacent_slices={n_adjacent_slices}")
|
|
105
|
+
|
|
106
|
+
query_cell_slice_ids = slice_ids[query_cell_mask]
|
|
107
|
+
query_cell_coords = coords[query_cell_mask]
|
|
108
|
+
|
|
109
|
+
# Pre-allocate output with fixed size, initialized with -1 (invalid)
|
|
110
|
+
total_neighbors_per_cell = k_central + 2 * n_adjacent_slices * k_adjacent
|
|
111
|
+
# Always use int32 for spatial neighbors since they contain global cell indices
|
|
112
|
+
# which can easily exceed int16 range in large spatial datasets
|
|
113
|
+
spatial_neighbors = np.full((n_query_cells, total_neighbors_per_cell), -1, dtype=np.int32)
|
|
114
|
+
|
|
115
|
+
# Get unique slices and create mapping for all query cells
|
|
116
|
+
unique_slice_ids = np.unique(query_cell_slice_ids)
|
|
117
|
+
slice_to_query_cell_indices = {s: np.where(query_cell_slice_ids == s)[0] for s in unique_slice_ids}
|
|
118
|
+
|
|
119
|
+
# Create mapping for neighbor pool cells per slice (for building KDTrees)
|
|
120
|
+
neighbor_pool_per_slice = {}
|
|
121
|
+
|
|
122
|
+
# Get slice IDs and coordinates for neighbor pool cells
|
|
123
|
+
neighbor_pool_slice_ids = slice_ids[neighbor_pool_mask]
|
|
124
|
+
neighbor_pool_coords = coords[neighbor_pool_mask]
|
|
125
|
+
|
|
126
|
+
for slice_id in unique_slice_ids:
|
|
127
|
+
# Find neighbor pool cells on this slice (in the neighbor_pool array)
|
|
128
|
+
slice_neighbor_pool_mask = neighbor_pool_slice_ids == slice_id
|
|
129
|
+
slice_neighbor_pool_local_indices = np.where(slice_neighbor_pool_mask)[0]
|
|
130
|
+
neighbor_pool_per_slice[slice_id] = slice_neighbor_pool_local_indices
|
|
131
|
+
logger.debug(f"Slice {slice_id}: {len(slice_neighbor_pool_local_indices)} high quality cells")
|
|
132
|
+
|
|
133
|
+
# Pre-compute KDTree for each slice using neighbor pool cells only
|
|
134
|
+
# Note: KDTree can handle cases where k > number of points in tree
|
|
135
|
+
# It will return valid neighbors first, then fill remaining slots with invalid indices
|
|
136
|
+
slice_kdtrees = {}
|
|
137
|
+
for slice_id in unique_slice_ids:
|
|
138
|
+
# Use neighbor pool cells to build tree
|
|
139
|
+
slice_neighbor_pool_local = neighbor_pool_per_slice.get(slice_id, np.array([]))
|
|
140
|
+
if len(slice_neighbor_pool_local) > 0: # Build tree as long as there's at least 1 hq cell
|
|
141
|
+
# Get coordinates of neighbor pool cells on this slice
|
|
142
|
+
slice_neighbor_pool_coords = neighbor_pool_coords[slice_neighbor_pool_local]
|
|
143
|
+
kdtree = cKDTree(slice_neighbor_pool_coords)
|
|
144
|
+
# Store tree with global indices of neighbor pool cells
|
|
145
|
+
slice_kdtrees[slice_id] = (kdtree, neighbor_pool_indices[slice_neighbor_pool_local])
|
|
146
|
+
|
|
147
|
+
if len(slice_neighbor_pool_local) < max(k_central, k_adjacent):
|
|
148
|
+
logger.warning(f"Slice {slice_id} has only {len(slice_neighbor_pool_local)} high quality cells, "
|
|
149
|
+
f"which is less than required k={max(k_central, k_adjacent)}. "
|
|
150
|
+
# f"Some neighbors will be invalid (-1)."
|
|
151
|
+
)
|
|
152
|
+
|
|
153
|
+
# Build a mapping of slice_id -> list of adjacent slice_ids to search
|
|
154
|
+
# This ensures all slices search the same total number of adjacent slices
|
|
155
|
+
slice_adjacent_mapping = {}
|
|
156
|
+
min_slice_id = min(unique_slice_ids)
|
|
157
|
+
max_slice_id = max(unique_slice_ids)
|
|
158
|
+
|
|
159
|
+
for slice_id in unique_slice_ids:
|
|
160
|
+
adjacent_slices = []
|
|
161
|
+
|
|
162
|
+
# Collect slices in both directions up to n_adjacent_slices
|
|
163
|
+
for offset in range(1, n_adjacent_slices + 1):
|
|
164
|
+
neg_slice = slice_id - offset
|
|
165
|
+
pos_slice = slice_id + offset
|
|
166
|
+
if neg_slice in slice_kdtrees:
|
|
167
|
+
adjacent_slices.append((neg_slice, offset, -1)) # (slice_id, offset, direction)
|
|
168
|
+
if pos_slice in slice_kdtrees:
|
|
169
|
+
adjacent_slices.append((pos_slice, offset, 1))
|
|
170
|
+
|
|
171
|
+
# If we don't have enough adjacent slices (2*n_adjacent_slices total), compensate
|
|
172
|
+
target_count = 2 * n_adjacent_slices
|
|
173
|
+
if len(adjacent_slices) < target_count:
|
|
174
|
+
# Search deeper in available directions
|
|
175
|
+
extra_offset = n_adjacent_slices + 1
|
|
176
|
+
while len(adjacent_slices) < target_count and extra_offset <= max(max_slice_id - min_slice_id, 10):
|
|
177
|
+
neg_slice = slice_id - extra_offset
|
|
178
|
+
pos_slice = slice_id + extra_offset
|
|
179
|
+
|
|
180
|
+
if neg_slice >= min_slice_id and neg_slice in slice_kdtrees:
|
|
181
|
+
if not any(s[0] == neg_slice for s in adjacent_slices):
|
|
182
|
+
adjacent_slices.append((neg_slice, extra_offset, -1))
|
|
183
|
+
if len(adjacent_slices) >= target_count:
|
|
184
|
+
break
|
|
185
|
+
|
|
186
|
+
if pos_slice <= max_slice_id and pos_slice in slice_kdtrees:
|
|
187
|
+
if not any(s[0] == pos_slice for s in adjacent_slices):
|
|
188
|
+
adjacent_slices.append((pos_slice, extra_offset, 1))
|
|
189
|
+
if len(adjacent_slices) >= target_count:
|
|
190
|
+
break
|
|
191
|
+
|
|
192
|
+
extra_offset += 1
|
|
193
|
+
|
|
194
|
+
# Sort adjacent slices: first by offset (closer slices first), then by direction
|
|
195
|
+
# This maintains consistency: [offset=1,dir=-1], [offset=1,dir=1], [offset=2,dir=-1], [offset=2,dir=1], ...
|
|
196
|
+
adjacent_slices.sort(key=lambda x: (x[1], -x[2])) # Sort by offset, then direction (-1 before 1)
|
|
197
|
+
|
|
198
|
+
slice_adjacent_mapping[slice_id] = adjacent_slices
|
|
199
|
+
|
|
200
|
+
if len(adjacent_slices) < target_count:
|
|
201
|
+
logger.warning(f"Slice {slice_id} only has {len(adjacent_slices)} adjacent slices available "
|
|
202
|
+
f"(target: {target_count}). Some neighbor slots will remain empty.")
|
|
203
|
+
|
|
204
|
+
# Log the adjacent slice mapping for verification
|
|
205
|
+
for slice_id in sorted(slice_adjacent_mapping.keys()):
|
|
206
|
+
adj_slice_ids = [s[0] for s in slice_adjacent_mapping[slice_id]]
|
|
207
|
+
logger.debug(f"Slice {slice_id} will search in adjacent slices: {adj_slice_ids}")
|
|
208
|
+
|
|
209
|
+
# Batch process all query cells (including edge cases with few neighbor pool cells)
|
|
210
|
+
for slice_id in unique_slice_ids:
|
|
211
|
+
if slice_id in slice_kdtrees:
|
|
212
|
+
# Get all query cells on this slice
|
|
213
|
+
query_cells_on_slice = slice_to_query_cell_indices[slice_id]
|
|
214
|
+
query_coords = query_cell_coords[query_cells_on_slice]
|
|
215
|
+
|
|
216
|
+
# Central slice neighbors (fixed k_central)
|
|
217
|
+
central_kdtree, central_neighbor_pool_global_indices = slice_kdtrees[slice_id]
|
|
218
|
+
_, central_neighbor_local_indices = central_kdtree.query(query_coords, k=k_central, workers=-1)
|
|
219
|
+
|
|
220
|
+
# Handle invalid indices returned by KDTree when k > number of points in tree
|
|
221
|
+
# KDTree returns index >= tree_size for invalid slots
|
|
222
|
+
n_central_pool_cells = len(central_neighbor_pool_global_indices)
|
|
223
|
+
# Initialize with -1 (invalid)
|
|
224
|
+
central_neighbor_global_indices = np.full_like(central_neighbor_local_indices, -1, dtype=np.int32)
|
|
225
|
+
# Create mask for valid indices (< tree size)
|
|
226
|
+
central_valid_mask = central_neighbor_local_indices < n_central_pool_cells
|
|
227
|
+
# Map valid local indices to global indices
|
|
228
|
+
central_neighbor_global_indices[central_valid_mask] = central_neighbor_pool_global_indices[
|
|
229
|
+
central_neighbor_local_indices[central_valid_mask]
|
|
230
|
+
]
|
|
231
|
+
spatial_neighbors[query_cells_on_slice, :k_central] = central_neighbor_global_indices
|
|
232
|
+
|
|
233
|
+
# Adjacent slices - use the pre-computed mapping
|
|
234
|
+
# Get the list of adjacent slices to search for this slice_id
|
|
235
|
+
adjacent_slices_to_search = slice_adjacent_mapping.get(slice_id, [])
|
|
236
|
+
|
|
237
|
+
neighbor_column_offset = k_central
|
|
238
|
+
|
|
239
|
+
# We need to fill exactly 2*n_adjacent_slices slots in the output array
|
|
240
|
+
# The array structure is: [offset=1,dir=-1], [offset=1,dir=+1], [offset=2,dir=-1], [offset=2,dir=+1], ...
|
|
241
|
+
for slot_idx in range(2 * n_adjacent_slices):
|
|
242
|
+
if slot_idx < len(adjacent_slices_to_search):
|
|
243
|
+
# We have a slice to search for this slot
|
|
244
|
+
adjacent_slice_id, _, _ = adjacent_slices_to_search[slot_idx]
|
|
245
|
+
|
|
246
|
+
adjacent_kdtree, adjacent_neighbor_pool_global_indices = slice_kdtrees[adjacent_slice_id]
|
|
247
|
+
_, adjacent_neighbor_local_indices = adjacent_kdtree.query(query_coords, k=k_adjacent, workers=-1)
|
|
248
|
+
|
|
249
|
+
# Handle invalid indices for adjacent slices
|
|
250
|
+
n_adjacent_pool_cells = len(adjacent_neighbor_pool_global_indices)
|
|
251
|
+
adjacent_neighbor_global_indices = np.full_like(adjacent_neighbor_local_indices, -1, dtype=np.int32)
|
|
252
|
+
adjacent_valid_mask = adjacent_neighbor_local_indices < n_adjacent_pool_cells
|
|
253
|
+
adjacent_neighbor_global_indices[adjacent_valid_mask] = adjacent_neighbor_pool_global_indices[
|
|
254
|
+
adjacent_neighbor_local_indices[adjacent_valid_mask]
|
|
255
|
+
]
|
|
256
|
+
spatial_neighbors[query_cells_on_slice, neighbor_column_offset:neighbor_column_offset+k_adjacent] = adjacent_neighbor_global_indices
|
|
257
|
+
# else: slot remains as -1 (already initialized) if no slice available
|
|
258
|
+
|
|
259
|
+
neighbor_column_offset += k_adjacent
|
|
260
|
+
|
|
261
|
+
# Note: Slices with 0 neighbor pool cells will not be in slice_kdtrees
|
|
262
|
+
# Query cells on such slices will have all neighbors remain as -1 (already initialized)
|
|
263
|
+
|
|
264
|
+
return spatial_neighbors, neighbor_pool_per_slice
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
@partial(jit, static_argnums=(5, 6))
|
|
268
|
+
def _find_homogeneous_spot_dual_embedding_batch_jit(
|
|
269
|
+
emb_niche_batch_norm: jnp.ndarray, # (batch_size, d1)
|
|
270
|
+
emb_indv_batch_norm: jnp.ndarray, # (batch_size, d2)
|
|
271
|
+
spatial_neighbors: jnp.ndarray, # (batch_size, k1)
|
|
272
|
+
all_emb_niche_norm: jnp.ndarray, # (n_all, d1)
|
|
273
|
+
all_emb_indv_norm: jnp.ndarray, # (n_all, d2)
|
|
274
|
+
homogeneous_neighbors: int,
|
|
275
|
+
cell_embedding_similarity_threshold: float = 0.0,
|
|
276
|
+
spatial_domain_similarity_threshold: float = 0.5
|
|
277
|
+
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
|
|
278
|
+
"""
|
|
279
|
+
Finds homogeneous neighbors using a soft priority (Bonus) system.
|
|
280
|
+
|
|
281
|
+
Returns:
|
|
282
|
+
homogeneous_neighbors_array: Indices of selected neighbors (batch, homogeneous_neighbors)
|
|
283
|
+
selected_cell_sims: Cell similarity scores for selected neighbors (batch, homogeneous_neighbors)
|
|
284
|
+
selected_niche_sims: Niche/Spatial similarity scores for selected neighbors (batch, homogeneous_neighbors)
|
|
285
|
+
"""
|
|
286
|
+
batch_size = emb_indv_batch_norm.shape[0]
|
|
287
|
+
|
|
288
|
+
# Step 1: Extract spatial neighbors' embeddings
|
|
289
|
+
# Handle padding: map -1 to 0 temporarily for extraction
|
|
290
|
+
safe_neighbors = jnp.where(spatial_neighbors >= 0, spatial_neighbors, 0)
|
|
291
|
+
spatial_emb_indv_norm = all_emb_indv_norm[safe_neighbors]
|
|
292
|
+
spatial_emb_niche_norm = all_emb_niche_norm[safe_neighbors]
|
|
293
|
+
|
|
294
|
+
# Step 2: Compute Similarities
|
|
295
|
+
# 2a. Cell Similarity
|
|
296
|
+
raw_cell_sims = jnp.einsum('bd,bkd->bk', emb_indv_batch_norm, spatial_emb_indv_norm)
|
|
297
|
+
# Apply cell threshold: values below threshold become 0.0
|
|
298
|
+
cell_sims = jnp.where(raw_cell_sims >= cell_embedding_similarity_threshold, raw_cell_sims, 0.0)
|
|
299
|
+
|
|
300
|
+
# 2b. Niche (Spatial) Similarity
|
|
301
|
+
niche_sims = jnp.einsum('bd,bkd->bk', emb_niche_batch_norm, spatial_emb_niche_norm)
|
|
302
|
+
|
|
303
|
+
# Step 3: Compute Ranking Score with Bonus
|
|
304
|
+
# Logic:
|
|
305
|
+
# If niche_sim >= threshold: Score = cell_sim + 2.0
|
|
306
|
+
# Else: Score = cell_sim
|
|
307
|
+
# This prioritizes niche matches first, then falls back to highest cell sim.
|
|
308
|
+
BONUS = 2.0
|
|
309
|
+
pass_spatial_mask = niche_sims >= spatial_domain_similarity_threshold
|
|
310
|
+
ranking_score = cell_sims + jnp.where(pass_spatial_mask, BONUS, 0.0)
|
|
311
|
+
|
|
312
|
+
# Step 4: Mask Invalid Neighbors in Ranking
|
|
313
|
+
# Force padding indices (original -1s) to the bottom of the list
|
|
314
|
+
valid_neighbor_mask = spatial_neighbors >= 0
|
|
315
|
+
ranking_score = jnp.where(valid_neighbor_mask, ranking_score, -100.0)
|
|
316
|
+
|
|
317
|
+
# Step 5: Sort and Select
|
|
318
|
+
# Get indices of top scores
|
|
319
|
+
top_homo_idx = jnp.argsort(-ranking_score, axis=1)[:, :homogeneous_neighbors]
|
|
320
|
+
batch_idx = jnp.arange(batch_size)[:, None]
|
|
321
|
+
|
|
322
|
+
# Gather the results
|
|
323
|
+
homogeneous_neighbors_array = spatial_neighbors[batch_idx, top_homo_idx]
|
|
324
|
+
selected_cell_sims = cell_sims[batch_idx, top_homo_idx]
|
|
325
|
+
selected_niche_sims = niche_sims[batch_idx, top_homo_idx]
|
|
326
|
+
|
|
327
|
+
# Step 6: Final Masking
|
|
328
|
+
# If the selected neighbor is invalid (which happens if k1 < homogeneous_neighbors
|
|
329
|
+
# or all valid neighbors were exhausted), ensure returned similarities are 0.0.
|
|
330
|
+
final_valid_mask = homogeneous_neighbors_array >= 0
|
|
331
|
+
|
|
332
|
+
selected_cell_sims = jnp.where(final_valid_mask, selected_cell_sims, 0.0)
|
|
333
|
+
selected_niche_sims = jnp.where(final_valid_mask, selected_niche_sims, 0.0)
|
|
334
|
+
|
|
335
|
+
return homogeneous_neighbors_array, selected_cell_sims, selected_niche_sims
|
|
336
|
+
def _find_homogeneous_3d_memory_efficient(
    emb_niche_masked_jax: jnp.ndarray,
    emb_indv_masked_jax: jnp.ndarray,
    spatial_neighbors: np.ndarray,
    all_emb_niche_norm_jax: jnp.ndarray,
    all_emb_indv_norm_jax: jnp.ndarray,
    num_homogeneous_per_slice: int,
    k_central: int,
    k_adjacent: int,
    n_adjacent_slices: int,
    cell_embedding_similarity_threshold: float,
    spatial_domain_similarity_threshold: float,
    find_homogeneous_batch_size: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Memory-efficient version of 3D homogeneous neighbor finding.

    The columns of ``spatial_neighbors`` are split into groups: the central-slice
    block (first k_central columns) followed by one block of k_adjacent columns per
    adjacent-slice slot. Each group is moved to the device and processed on its
    own, in row batches, so only one group's neighbor block is converted at a time.
    Per-group selections are concatenated along the column axis.

    Args:
        emb_niche_masked_jax: Niche embeddings for masked cells (JAX array, float16, always exists, may be dummy ones)
        emb_indv_masked_jax: Individual embeddings for masked cells (JAX array, float16)
        spatial_neighbors: Spatial neighbors array with structure [central | adj1 | adj2 | ...] (numpy array)
        all_emb_niche_norm_jax: All normalized niche embeddings (JAX array, float16, always exists, may be dummy ones)
        all_emb_indv_norm_jax: All normalized individual embeddings (JAX array, float16)
        num_homogeneous_per_slice: Number of neighbors to select per slice group
        k_central: Number of neighbors in central slice
        k_adjacent: Number of neighbors per adjacent slice
        n_adjacent_slices: Number of adjacent slices above and below
        cell_embedding_similarity_threshold: Minimum similarity threshold for cell embedding
        spatial_domain_similarity_threshold: Minimum similarity threshold for spatial domain embedding
        find_homogeneous_batch_size: Batch size (rows) for processing

    Returns:
        homogeneous_neighbors: Selected neighbors, groups concatenated column-wise
        homogeneous_cell_sims: Cell similarity scores for selected neighbors
        homogeneous_niche_sims: Niche/Spatial similarity scores for selected neighbors
    """
    n_masked = emb_indv_masked_jax.shape[0]
    n_adjacent_slots = 2 * n_adjacent_slices
    total_slices = 1 + n_adjacent_slots

    # (label, first column, end column) for each neighbor group. Adjacent slots are
    # labelled by position only: find_spatial_neighbors_with_slices orders slots by
    # offset and may substitute more distant slices near the stack boundaries, so a
    # fixed "+k"/"-k" label per column block would be misleading.
    column_groups = [("central slice", 0, k_central)]
    for adj_idx in range(n_adjacent_slots):
        group_start = k_central + adj_idx * k_adjacent
        column_groups.append(
            (f"adjacent slice slot {adj_idx + 1}/{n_adjacent_slots}", group_start, group_start + k_adjacent)
        )

    homogeneous_neighbors_all_slices = []
    homogeneous_cell_sims_all_slices = []
    homogeneous_niche_sims_all_slices = []

    with Progress(
        SpinnerColumn(),
        TextColumn("[bold blue]{task.description}"),
        BarColumn(),
        MofNCompleteColumn(),
        TaskProgressColumn(),
        TimeRemainingColumn(),
        TimeElapsedColumn(),
        refresh_per_second=1
    ) as slice_progress:
        slice_task = slice_progress.add_task(
            "Finding homogeneous neighbors (3D cross-slice)...",
            total=total_slices
        )

        for slice_name, col_start, col_end in column_groups:
            # Convert this group's neighbor block to a JAX array once
            spatial_neighbors_slice = jnp.asarray(spatial_neighbors[:, col_start:col_end])

            neighbors_batches = []
            cell_sims_batches = []
            niche_sims_batches = []

            for batch_start in track(range(0, n_masked, find_homogeneous_batch_size),
                                     description=f"Finding homogeneous neighbors ({slice_name})",
                                     transient=True):
                rows = slice(batch_start, min(batch_start + find_homogeneous_batch_size, n_masked))

                homo_neighbors_batch, cell_sims_batch, niche_sims_batch = _find_homogeneous_spot_dual_embedding_batch_jit(
                    emb_niche_batch_norm=emb_niche_masked_jax[rows],
                    emb_indv_batch_norm=emb_indv_masked_jax[rows],
                    spatial_neighbors=spatial_neighbors_slice[rows, :],
                    all_emb_niche_norm=all_emb_niche_norm_jax,
                    all_emb_indv_norm=all_emb_indv_norm_jax,
                    homogeneous_neighbors=num_homogeneous_per_slice,
                    cell_embedding_similarity_threshold=cell_embedding_similarity_threshold,
                    spatial_domain_similarity_threshold=spatial_domain_similarity_threshold
                )

                neighbors_batches.append(np.asarray(homo_neighbors_batch))
                cell_sims_batches.append(np.asarray(cell_sims_batch))
                niche_sims_batches.append(np.asarray(niche_sims_batch))

            # Stack this group's row batches back into full-height arrays
            homogeneous_neighbors_all_slices.append(np.concatenate(neighbors_batches, axis=0))
            homogeneous_cell_sims_all_slices.append(np.concatenate(cell_sims_batches, axis=0))
            homogeneous_niche_sims_all_slices.append(np.concatenate(niche_sims_batches, axis=0))

            slice_progress.update(slice_task, advance=1)

    # Concatenate all groups along the column axis
    homogeneous_neighbors = np.concatenate(homogeneous_neighbors_all_slices, axis=1)
    homogeneous_cell_sims = np.concatenate(homogeneous_cell_sims_all_slices, axis=1)
    homogeneous_niche_sims = np.concatenate(homogeneous_niche_sims_all_slices, axis=1)

    return homogeneous_neighbors, homogeneous_cell_sims, homogeneous_niche_sims
|
|
471
|
+
|
|
472
|
+
|
|
473
|
+
def build_scrna_connectivity(
    emb_cell: np.ndarray,
    cell_mask: np.ndarray | None = None,
    n_neighbors: int = 21,
    metric: str = 'euclidean'
) -> tuple[np.ndarray, np.ndarray, None]:
    """
    Build connectivity for scRNA-seq data using KNN on cell embeddings.

    Runs scanpy's UMAP-style neighbor graph on the masked embeddings, keeps the
    ``n_neighbors`` strongest connections of each cell, and normalizes their
    weights to sum to 1. Rows with fewer connections are padded with the cell's
    own global index and weight 0.

    Args:
        emb_cell: Cell embeddings (n_cells, d)
        cell_mask: Boolean mask for cells to process (default: all cells)
        n_neighbors: Number of nearest neighbors
        metric: Distance metric for KNN

    Returns:
        neighbor_indices: (n_masked, n_neighbors) array of global neighbor indices
            (int16 when all indices fit, otherwise int32)
        neighbor_weights: (n_masked, n_neighbors) float16 array of normalized weights
        niche_sims: None (scRNA-seq has no niche embedding)
    """
    n_cells = len(emb_cell)
    if cell_mask is None:
        cell_mask = np.ones(n_cells, dtype=bool)

    cell_indices = np.where(cell_mask)[0]
    n_masked = len(cell_indices)

    logger.info(f"Building scRNA-seq connectivity using KNN with k={n_neighbors}")

    # Temporary AnnData so scanpy's neighbor graph can be reused
    masked_emb = emb_cell[cell_mask]
    adata_temp = ad.AnnData(X=masked_emb)
    adata_temp.obsm['X_emb'] = masked_emb

    sc.pp.neighbors(
        adata_temp,
        n_neighbors=n_neighbors,
        use_rep='X_emb',
        metric=metric,
        method='umap'
    )

    connectivities = adata_temp.obsp['connectivities'].tocsr()

    # Dense fixed-width output for consistency with the spatial methods; use the
    # smallest index dtype that can hold every global cell index.
    max_cell_idx = cell_indices.max() if len(cell_indices) > 0 else 0
    idx_dtype = np.int16 if max_cell_idx < 32768 else np.int32
    neighbor_indices = np.zeros((n_masked, n_neighbors), dtype=idx_dtype)
    neighbor_weights = np.zeros((n_masked, n_neighbors), dtype=np.float16)

    # Read CSR rows straight from indptr/indices/data: avoids building a sparse
    # object per row (getrow) and also works for scipy sparse arrays, which
    # lack the matrix-only getrow method.
    row_ptr = connectivities.indptr
    col_idx = connectivities.indices
    values = connectivities.data
    for i in range(n_masked):
        row_start, row_end = row_ptr[i], row_ptr[i + 1]
        row_neighbors = col_idx[row_start:row_end]
        row_weights = values[row_start:row_end]

        # Strongest connections first (descending weight)
        top = np.argsort(-row_weights)[:n_neighbors]
        n_found = len(top)

        neighbor_indices[i, :n_found] = cell_indices[row_neighbors[top]]
        neighbor_weights[i, :n_found] = row_weights[top]
        # Pad missing slots with the cell itself; their weights stay 0.0
        neighbor_indices[i, n_found:] = cell_indices[i]

    # Normalize weights to sum to 1 (rows with zero total weight are left as zeros)
    weight_sums = neighbor_weights.sum(axis=1, keepdims=True)
    weight_sums = np.where(weight_sums > 0, weight_sums, 1.0)
    neighbor_weights = neighbor_weights / weight_sums

    logger.info(f"scRNA-seq connectivity built: {n_masked} cells × {n_neighbors} neighbors")

    return neighbor_indices, neighbor_weights, None
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
class ConnectivityMatrixBuilder:
    """Build connectivity matrix using JAX-accelerated computation with GPU memory optimization.

    Dispatches on ``config.dataset_type``:

    * scRNA-seq: KNN graph on cell identity embeddings (``build_scrna_connectivity``).
    * spatial2D / spatial3D: spatial neighbor search (slice-aware when slice ids
      are given), followed by batched selection of homogeneous neighbors using
      niche and cell-identity embedding similarities.
    """

    def __init__(self, config: LatentToGeneConfig):
        """
        Initialize with configuration.

        Args:
            config: LatentToGeneConfig object. ``find_homogeneous_batch_size`` and
                ``dataset_type`` are cached on the instance; all other settings are
                read from ``self.config`` when a matrix is built.
        """
        self.config = config
        # Number of query cells sent to each JAX call when selecting homogeneous
        # neighbors (configured batch size for GPU processing)
        self.find_homogeneous_batch_size = config.find_homogeneous_batch_size
        self.dataset_type = config.dataset_type

    def build_connectivity_matrix(
        self,
        coords: np.ndarray | None = None,
        emb_niche: np.ndarray | None = None,
        emb_indv: np.ndarray | None = None,
        cell_mask: np.ndarray | None = None,
        high_quality_mask: np.ndarray | None = None,
        slice_ids: np.ndarray | None = None,
        k_central: int | None = None,
        k_adjacent: int | None = None,
        n_adjacent_slices: int | None = None,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray | None]:
        """
        Build connectivity matrix for a group of cells based on dataset type.

        For scRNA-seq: Uses KNN on cell embeddings (emb_indv)
        For spatial2D: Uses spatial anchors and homogeneous neighbors (no cross-slice search)
        For spatial3D: Uses slice-aware spatial anchors and homogeneous neighbors

        Args:
            coords: Spatial coordinates (n_cells, 2) - required for spatial datasets
            emb_niche: Niche embeddings (n_cells, d1) - required for spatial datasets
                (may be dummy ones); unused for scRNA-seq
            emb_indv: Cell identity embeddings (n_cells, d2) - required for all datasets
            cell_mask: Boolean mask for cells to process
            high_quality_mask: Boolean mask for high quality cells (used for neighbor search in spatial data)
            slice_ids: Optional slice/z-coordinate indices (n_cells,) for spatial3D
            k_central: Number of neighbors on central slice
                (defaults to ``config.spatial_neighbors``)
            k_adjacent: Number of neighbors on adjacent slices for spatial3D
                (defaults to ``config.adjacent_slice_spatial_neighbors``)
            n_adjacent_slices: Number of slices to search above/below for spatial3D
                (defaults to ``config.n_adjacent_slices``). Always forced to 0 for
                spatial2D, even if a value is passed.

        Returns:
            Tuple of (neighbor_indices, cell_sims, niche_sims) arrays.
            For scRNA-seq the second array holds normalized KNN weights and
            niche_sims is None.

        Raises:
            ValueError: If a required embedding/coordinate input is missing, or the
                dataset type is not recognized.
        """
        if self.dataset_type == DatasetType.SCRNA_SEQ:
            logger.info("Building connectivity for scRNA-seq dataset")
            if emb_indv is None:
                raise ValueError("emb_indv (cell embeddings) required for scRNA-seq dataset")

            return build_scrna_connectivity(
                emb_cell=emb_indv,
                cell_mask=cell_mask,
                n_neighbors=self.config.homogeneous_neighbors,
                metric='euclidean',
            )

        if self.dataset_type not in ('spatial2D', 'spatial3D'):
            raise ValueError(f"Unknown dataset type: {self.dataset_type}")

        logger.info(f"Building connectivity for {self.dataset_type} dataset")

        # Validate required inputs for spatial datasets
        if coords is None or emb_indv is None:
            raise ValueError("coords and emb_indv required for spatial datasets")

        # Use config defaults if not provided
        if k_central is None:
            k_central = self.config.spatial_neighbors
        if k_adjacent is None:
            k_adjacent = self.config.adjacent_slice_spatial_neighbors

        if self.dataset_type == DatasetType.SPATIAL_2D:
            # 2D data never searches across slices, regardless of the argument;
            # slice_ids are still forwarded if provided.
            n_adjacent_slices = 0
        elif n_adjacent_slices is None:
            n_adjacent_slices = self.config.n_adjacent_slices

        return self._build_spatial_connectivity(
            coords=coords,
            emb_niche=emb_niche,
            emb_indv=emb_indv,
            cell_mask=cell_mask,
            high_quality_mask=high_quality_mask,
            slice_ids=slice_ids,
            k_central=k_central,
            k_adjacent=k_adjacent,
            n_adjacent_slices=n_adjacent_slices,
        )

    def _build_spatial_connectivity(
        self,
        coords: np.ndarray,
        emb_indv: np.ndarray,
        emb_niche: np.ndarray | None = None,
        cell_mask: np.ndarray | None = None,
        high_quality_mask: np.ndarray | None = None,
        slice_ids: np.ndarray | None = None,
        k_central: int = 101,
        k_adjacent: int = 50,
        n_adjacent_slices: int = 1,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Internal method for building spatial connectivity matrix.

        Args:
            coords: Spatial coordinates (n_cells, 2) - only x,y coordinates
            emb_indv: Cell identity embeddings (n_cells, d2)
            emb_niche: Niche embeddings (n_cells, d1) - required (may be dummy ones)
            cell_mask: Boolean mask for cells to process (all cells if None)
            high_quality_mask: Boolean mask for high quality cells (used for neighbor search)
            slice_ids: Optional slice/z-coordinate indices (n_cells,) for 3D data
            k_central: Number of neighbors on central slice
            k_adjacent: Number of neighbors on adjacent slices for 3D data
            n_adjacent_slices: Number of slices to search above/below for 3D data

        Returns:
            Tuple of (neighbor_indices, cell_sims, niche_sims) arrays, one row per
            masked cell.

        Raises:
            ValueError: If ``emb_niche`` is None. This is checked before the
                spatial neighbor search so the failure is immediate.
        """
        if emb_niche is None:
            raise ValueError(
                "emb_niche is required for spatial datasets "
                "(pass dummy ones if no niche embedding is available)"
            )

        n_cells = len(coords)
        if cell_mask is None:
            cell_mask = np.ones(n_cells, dtype=bool)

        # Positions (into the full cell arrays) of the cells being processed
        cell_indices = np.where(cell_mask)[0]
        n_masked = len(cell_indices)

        # Step 1: Find spatial neighbors (slice-aware if slice_ids provided)
        spatial_neighbors, neighbor_pool_per_slice = find_spatial_neighbors_with_slices(
            coords=coords,
            slice_ids=slice_ids,
            query_cell_mask=cell_mask,
            high_quality_cell_mask=high_quality_mask,
            k_central=k_central,
            k_adjacent=k_adjacent,
            n_adjacent_slices=n_adjacent_slices,
        )

        # Log statistics about neighbor pool cells per slice if available
        if neighbor_pool_per_slice:
            for slice_id, neighbor_pool_indices in neighbor_pool_per_slice.items():
                logger.debug(f"Slice {slice_id}: {len(neighbor_pool_indices)} high quality cells available for neighbor search")

        # Step 2 & 3: Find anchors and homogeneous neighbors in batches
        logger.info(f"Finding anchors and homogeneous neighbors (batch size: {self.find_homogeneous_batch_size})...")

        # Convert embeddings to JAX arrays once (shared by both selection paths).
        # float16 halves device memory. NOTE(review): the "_norm" names suggest the
        # embeddings are expected to be pre-normalized by the caller; nothing here
        # normalizes them — confirm upstream.
        all_emb_niche_norm_jax = jnp.array(emb_niche, dtype=jnp.float16)
        all_emb_indv_norm_jax = jnp.array(emb_indv, dtype=jnp.float16)

        # Embeddings of the query (masked) cells only
        emb_niche_masked_jax = all_emb_niche_norm_jax[cell_indices]
        emb_indv_masked_jax = all_emb_indv_norm_jax[cell_indices]

        if self.config.fix_cross_slice_homogenous_neighbors:
            logger.info(f"Using 3D constrained selection (ensuring {self.config.homogeneous_neighbors} neighbors per slice)")

            # Memory-efficient variant that processes slices separately
            homogeneous_neighbors, homogeneous_cell_sims, homogeneous_niche_sims = _find_homogeneous_3d_memory_efficient(
                emb_niche_masked_jax=emb_niche_masked_jax,
                emb_indv_masked_jax=emb_indv_masked_jax,
                spatial_neighbors=spatial_neighbors,
                all_emb_niche_norm_jax=all_emb_niche_norm_jax,
                all_emb_indv_norm_jax=all_emb_indv_norm_jax,
                num_homogeneous_per_slice=self.config.homogeneous_neighbors,
                k_central=k_central,
                k_adjacent=k_adjacent,
                n_adjacent_slices=n_adjacent_slices,
                cell_embedding_similarity_threshold=self.config.cell_embedding_similarity_threshold,
                spatial_domain_similarity_threshold=self.config.spatial_domain_similarity_threshold,
                find_homogeneous_batch_size=self.find_homogeneous_batch_size,
            )
            return homogeneous_neighbors, homogeneous_cell_sims, homogeneous_niche_sims

        # Standard path (2D, or 3D without fix_cross_slice_homogenous_neighbors):
        # process query cells in fixed-size batches through a single JIT function.
        spatial_neighbors_jax = jnp.array(spatial_neighbors, dtype=jnp.int32)
        batch_size = self.find_homogeneous_batch_size
        neighbors_parts = []
        cell_sims_parts = []
        niche_sims_parts = []

        for batch_start in track(range(0, n_masked, batch_size), description="Finding homogeneous neighbors", transient=True):
            batch = slice(batch_start, min(batch_start + batch_size, n_masked))

            # Slices of existing JAX arrays; no host/device transfer here
            homo_neighbors_batch, cell_sims_batch, niche_sims_batch = _find_homogeneous_spot_dual_embedding_batch_jit(
                emb_niche_batch_norm=emb_niche_masked_jax[batch],
                emb_indv_batch_norm=emb_indv_masked_jax[batch],
                spatial_neighbors=spatial_neighbors_jax[batch],
                all_emb_niche_norm=all_emb_niche_norm_jax,
                all_emb_indv_norm=all_emb_indv_norm_jax,
                homogeneous_neighbors=self.config.total_homogeneous_neighbor_per_cell,
                cell_embedding_similarity_threshold=self.config.cell_embedding_similarity_threshold,
                spatial_domain_similarity_threshold=self.config.spatial_domain_similarity_threshold,
            )

            # Bring each batch back to host memory before stacking
            neighbors_parts.append(np.array(homo_neighbors_batch))
            cell_sims_parts.append(np.array(cell_sims_batch))
            niche_sims_parts.append(np.array(niche_sims_batch))

        # Return dense format: (n_masked, num_homogeneous) arrays
        return np.vstack(neighbors_parts), np.vstack(cell_sims_parts), np.vstack(niche_sims_parts)
|