gsMap3D 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74) hide show
  1. gsMap/__init__.py +13 -0
  2. gsMap/__main__.py +4 -0
  3. gsMap/cauchy_combination_test.py +342 -0
  4. gsMap/cli.py +355 -0
  5. gsMap/config/__init__.py +72 -0
  6. gsMap/config/base.py +296 -0
  7. gsMap/config/cauchy_config.py +79 -0
  8. gsMap/config/dataclasses.py +235 -0
  9. gsMap/config/decorators.py +302 -0
  10. gsMap/config/find_latent_config.py +276 -0
  11. gsMap/config/format_sumstats_config.py +54 -0
  12. gsMap/config/latent2gene_config.py +461 -0
  13. gsMap/config/ldscore_config.py +261 -0
  14. gsMap/config/quick_mode_config.py +242 -0
  15. gsMap/config/report_config.py +81 -0
  16. gsMap/config/spatial_ldsc_config.py +334 -0
  17. gsMap/config/utils.py +286 -0
  18. gsMap/find_latent/__init__.py +3 -0
  19. gsMap/find_latent/find_latent_representation.py +312 -0
  20. gsMap/find_latent/gnn/distribution.py +498 -0
  21. gsMap/find_latent/gnn/encoder_decoder.py +186 -0
  22. gsMap/find_latent/gnn/gcn.py +85 -0
  23. gsMap/find_latent/gnn/gene_former.py +164 -0
  24. gsMap/find_latent/gnn/loss.py +18 -0
  25. gsMap/find_latent/gnn/st_model.py +125 -0
  26. gsMap/find_latent/gnn/train_step.py +177 -0
  27. gsMap/find_latent/st_process.py +781 -0
  28. gsMap/format_sumstats.py +446 -0
  29. gsMap/generate_ldscore.py +1018 -0
  30. gsMap/latent2gene/__init__.py +18 -0
  31. gsMap/latent2gene/connectivity.py +781 -0
  32. gsMap/latent2gene/entry_point.py +141 -0
  33. gsMap/latent2gene/marker_scores.py +1265 -0
  34. gsMap/latent2gene/memmap_io.py +766 -0
  35. gsMap/latent2gene/rank_calculator.py +590 -0
  36. gsMap/latent2gene/row_ordering.py +182 -0
  37. gsMap/latent2gene/row_ordering_jax.py +159 -0
  38. gsMap/ldscore/__init__.py +1 -0
  39. gsMap/ldscore/batch_construction.py +163 -0
  40. gsMap/ldscore/compute.py +126 -0
  41. gsMap/ldscore/constants.py +70 -0
  42. gsMap/ldscore/io.py +262 -0
  43. gsMap/ldscore/mapping.py +262 -0
  44. gsMap/ldscore/pipeline.py +615 -0
  45. gsMap/pipeline/quick_mode.py +134 -0
  46. gsMap/report/__init__.py +2 -0
  47. gsMap/report/diagnosis.py +375 -0
  48. gsMap/report/report.py +100 -0
  49. gsMap/report/report_data.py +1832 -0
  50. gsMap/report/static/js_lib/alpine.min.js +5 -0
  51. gsMap/report/static/js_lib/tailwindcss.js +83 -0
  52. gsMap/report/static/template.html +2242 -0
  53. gsMap/report/three_d_combine.py +312 -0
  54. gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
  55. gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
  56. gsMap/report/three_d_plot/three_d_plots.py +425 -0
  57. gsMap/report/visualize.py +1409 -0
  58. gsMap/setup.py +5 -0
  59. gsMap/spatial_ldsc/__init__.py +0 -0
  60. gsMap/spatial_ldsc/io.py +656 -0
  61. gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
  62. gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
  63. gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
  64. gsMap/utils/__init__.py +0 -0
  65. gsMap/utils/generate_r2_matrix.py +610 -0
  66. gsMap/utils/jackknife.py +518 -0
  67. gsMap/utils/manhattan_plot.py +643 -0
  68. gsMap/utils/regression_read.py +177 -0
  69. gsMap/utils/torch_utils.py +23 -0
  70. gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
  71. gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
  72. gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
  73. gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
  74. gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,781 @@
1
+ """
2
+ Connectivity matrix building for homogeneous spot identification
3
+ Implements the spatial → anchor → homogeneous neighbor finding algorithm
4
+ """
5
+
6
+ import logging
7
+ from functools import partial
8
+
9
+ import anndata as ad
10
+ import jax
11
+ import jax.numpy as jnp
12
+ import numpy as np
13
+ import scanpy as sc
14
+ from jax import jit
15
+ from rich.progress import (
16
+ BarColumn,
17
+ MofNCompleteColumn,
18
+ Progress,
19
+ SpinnerColumn,
20
+ TaskProgressColumn,
21
+ TextColumn,
22
+ TimeElapsedColumn,
23
+ TimeRemainingColumn,
24
+ track,
25
+ )
26
+ from scipy.spatial import cKDTree
27
+
28
+ from gsMap.config import DatasetType, LatentToGeneConfig
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+ # Configure JAX
33
+ jax.config.update("jax_enable_x64", False) # Use float32 for speed
34
+
35
+
36
+ def find_spatial_neighbors_with_slices(
37
+ coords: np.ndarray,
38
+ slice_ids: np.ndarray | None = None,
39
+ query_cell_mask: np.ndarray | None = None,
40
+ high_quality_cell_mask: np.ndarray = None,
41
+ k_central: int = 101,
42
+ k_adjacent: int = 50,
43
+ n_adjacent_slices: int = 1
44
+ ) -> tuple[np.ndarray, dict[int, np.ndarray]]:
45
+ """
46
+ Find spatial neighbors with slice-aware search for 3D data.
47
+
48
+ For 2D data (slice_ids is None), performs standard KNN.
49
+ For 3D data, implements slice-aware neighbor search:
50
+ - Finds k_central neighbors on the same slice
51
+ - Finds k_adjacent neighbors on each of n_adjacent_slices above and below
52
+
53
+ KDTrees are built using only high quality cells, and neighbors are searched
54
+ within high quality cells only.
55
+
56
+ Args:
57
+ coords: Spatial coordinates (n_cells, 2) - only x,y coordinates
58
+ slice_ids: Slice/z-coordinate indices (n_cells,) - sequential integers
59
+ query_cell_mask: Boolean mask for cells to find neighbors for
60
+ high_quality_cell_mask: Boolean mask for high quality cells (defines the neighbor search pool)
61
+ k_central: Number of neighbors to find on the central slice
62
+ k_adjacent: Number of neighbors to find on each adjacent slice
63
+ n_adjacent_slices: Number of slices to search above and below
64
+
65
+ Returns:
66
+ Tuple of:
67
+ - spatial_neighbors: Array of neighbor indices (n_query_cells, k_central + 2*n_adjacent_slices*k_adjacent)
68
+ - neighbor_pool_per_slice: Dict mapping slice_id to local indices of neighbor pool cells on that slice
69
+ """
70
+ n_cells = len(coords)
71
+ if query_cell_mask is None:
72
+ query_cell_mask = np.ones(n_cells, dtype=bool)
73
+
74
+ if high_quality_cell_mask is None:
75
+ high_quality_cell_mask = np.ones(n_cells, dtype=bool)
76
+
77
+ logger.debug(f"Finding neighbors: {high_quality_cell_mask.sum()}/{n_cells} cells are high quality")
78
+
79
+ query_cell_indices = np.where(query_cell_mask)[0]
80
+ n_query_cells = len(query_cell_indices)
81
+
82
+ # Neighbor pool: high quality cells that will be used to build KDTrees
83
+ # (intersection of query_cell_mask and high_quality_cell_mask)
84
+ neighbor_pool_mask = query_cell_mask & high_quality_cell_mask
85
+ neighbor_pool_indices = np.where(neighbor_pool_mask)[0]
86
+
87
+ # If no slice_ids provided, perform standard 2D KNN
88
+ if slice_ids is None:
89
+ logger.info(f"No slice IDs provided, performing standard 2D KNN with k={k_central}")
90
+ # Build tree with neighbor pool cells only
91
+ kdtree = cKDTree(coords[neighbor_pool_mask])
92
+ # Query for all query cells (including non-HQ ones)
93
+ _, neighbor_local_indices = kdtree.query(
94
+ coords[query_cell_mask],
95
+ k=min(k_central, len(neighbor_pool_indices)),
96
+ workers=-1 # Use all available cores
97
+ )
98
+ # Convert local indices to global indices
99
+ spatial_neighbors = neighbor_pool_indices[neighbor_local_indices]
100
+ return spatial_neighbors, {}
101
+
102
+ # Slice-aware neighbor search with fixed-size arrays
103
+ logger.info(f"Performing slice-aware neighbor search: k_central={k_central}, "
104
+ f"k_adjacent={k_adjacent}, n_adjacent_slices={n_adjacent_slices}")
105
+
106
+ query_cell_slice_ids = slice_ids[query_cell_mask]
107
+ query_cell_coords = coords[query_cell_mask]
108
+
109
+ # Pre-allocate output with fixed size, initialized with -1 (invalid)
110
+ total_neighbors_per_cell = k_central + 2 * n_adjacent_slices * k_adjacent
111
+ # Always use int32 for spatial neighbors since they contain global cell indices
112
+ # which can easily exceed int16 range in large spatial datasets
113
+ spatial_neighbors = np.full((n_query_cells, total_neighbors_per_cell), -1, dtype=np.int32)
114
+
115
+ # Get unique slices and create mapping for all query cells
116
+ unique_slice_ids = np.unique(query_cell_slice_ids)
117
+ slice_to_query_cell_indices = {s: np.where(query_cell_slice_ids == s)[0] for s in unique_slice_ids}
118
+
119
+ # Create mapping for neighbor pool cells per slice (for building KDTrees)
120
+ neighbor_pool_per_slice = {}
121
+
122
+ # Get slice IDs and coordinates for neighbor pool cells
123
+ neighbor_pool_slice_ids = slice_ids[neighbor_pool_mask]
124
+ neighbor_pool_coords = coords[neighbor_pool_mask]
125
+
126
+ for slice_id in unique_slice_ids:
127
+ # Find neighbor pool cells on this slice (in the neighbor_pool array)
128
+ slice_neighbor_pool_mask = neighbor_pool_slice_ids == slice_id
129
+ slice_neighbor_pool_local_indices = np.where(slice_neighbor_pool_mask)[0]
130
+ neighbor_pool_per_slice[slice_id] = slice_neighbor_pool_local_indices
131
+ logger.debug(f"Slice {slice_id}: {len(slice_neighbor_pool_local_indices)} high quality cells")
132
+
133
+ # Pre-compute KDTree for each slice using neighbor pool cells only
134
+ # Note: KDTree can handle cases where k > number of points in tree
135
+ # It will return valid neighbors first, then fill remaining slots with invalid indices
136
+ slice_kdtrees = {}
137
+ for slice_id in unique_slice_ids:
138
+ # Use neighbor pool cells to build tree
139
+ slice_neighbor_pool_local = neighbor_pool_per_slice.get(slice_id, np.array([]))
140
+ if len(slice_neighbor_pool_local) > 0: # Build tree as long as there's at least 1 hq cell
141
+ # Get coordinates of neighbor pool cells on this slice
142
+ slice_neighbor_pool_coords = neighbor_pool_coords[slice_neighbor_pool_local]
143
+ kdtree = cKDTree(slice_neighbor_pool_coords)
144
+ # Store tree with global indices of neighbor pool cells
145
+ slice_kdtrees[slice_id] = (kdtree, neighbor_pool_indices[slice_neighbor_pool_local])
146
+
147
+ if len(slice_neighbor_pool_local) < max(k_central, k_adjacent):
148
+ logger.warning(f"Slice {slice_id} has only {len(slice_neighbor_pool_local)} high quality cells, "
149
+ f"which is less than required k={max(k_central, k_adjacent)}. "
150
+ # f"Some neighbors will be invalid (-1)."
151
+ )
152
+
153
+ # Build a mapping of slice_id -> list of adjacent slice_ids to search
154
+ # This ensures all slices search the same total number of adjacent slices
155
+ slice_adjacent_mapping = {}
156
+ min_slice_id = min(unique_slice_ids)
157
+ max_slice_id = max(unique_slice_ids)
158
+
159
+ for slice_id in unique_slice_ids:
160
+ adjacent_slices = []
161
+
162
+ # Collect slices in both directions up to n_adjacent_slices
163
+ for offset in range(1, n_adjacent_slices + 1):
164
+ neg_slice = slice_id - offset
165
+ pos_slice = slice_id + offset
166
+ if neg_slice in slice_kdtrees:
167
+ adjacent_slices.append((neg_slice, offset, -1)) # (slice_id, offset, direction)
168
+ if pos_slice in slice_kdtrees:
169
+ adjacent_slices.append((pos_slice, offset, 1))
170
+
171
+ # If we don't have enough adjacent slices (2*n_adjacent_slices total), compensate
172
+ target_count = 2 * n_adjacent_slices
173
+ if len(adjacent_slices) < target_count:
174
+ # Search deeper in available directions
175
+ extra_offset = n_adjacent_slices + 1
176
+ while len(adjacent_slices) < target_count and extra_offset <= max(max_slice_id - min_slice_id, 10):
177
+ neg_slice = slice_id - extra_offset
178
+ pos_slice = slice_id + extra_offset
179
+
180
+ if neg_slice >= min_slice_id and neg_slice in slice_kdtrees:
181
+ if not any(s[0] == neg_slice for s in adjacent_slices):
182
+ adjacent_slices.append((neg_slice, extra_offset, -1))
183
+ if len(adjacent_slices) >= target_count:
184
+ break
185
+
186
+ if pos_slice <= max_slice_id and pos_slice in slice_kdtrees:
187
+ if not any(s[0] == pos_slice for s in adjacent_slices):
188
+ adjacent_slices.append((pos_slice, extra_offset, 1))
189
+ if len(adjacent_slices) >= target_count:
190
+ break
191
+
192
+ extra_offset += 1
193
+
194
+ # Sort adjacent slices: first by offset (closer slices first), then by direction
195
+ # This maintains consistency: [offset=1,dir=-1], [offset=1,dir=1], [offset=2,dir=-1], [offset=2,dir=1], ...
196
+ adjacent_slices.sort(key=lambda x: (x[1], -x[2])) # Sort by offset, then direction (-1 before 1)
197
+
198
+ slice_adjacent_mapping[slice_id] = adjacent_slices
199
+
200
+ if len(adjacent_slices) < target_count:
201
+ logger.warning(f"Slice {slice_id} only has {len(adjacent_slices)} adjacent slices available "
202
+ f"(target: {target_count}). Some neighbor slots will remain empty.")
203
+
204
+ # Log the adjacent slice mapping for verification
205
+ for slice_id in sorted(slice_adjacent_mapping.keys()):
206
+ adj_slice_ids = [s[0] for s in slice_adjacent_mapping[slice_id]]
207
+ logger.debug(f"Slice {slice_id} will search in adjacent slices: {adj_slice_ids}")
208
+
209
+ # Batch process all query cells (including edge cases with few neighbor pool cells)
210
+ for slice_id in unique_slice_ids:
211
+ if slice_id in slice_kdtrees:
212
+ # Get all query cells on this slice
213
+ query_cells_on_slice = slice_to_query_cell_indices[slice_id]
214
+ query_coords = query_cell_coords[query_cells_on_slice]
215
+
216
+ # Central slice neighbors (fixed k_central)
217
+ central_kdtree, central_neighbor_pool_global_indices = slice_kdtrees[slice_id]
218
+ _, central_neighbor_local_indices = central_kdtree.query(query_coords, k=k_central, workers=-1)
219
+
220
+ # Handle invalid indices returned by KDTree when k > number of points in tree
221
+ # KDTree returns index >= tree_size for invalid slots
222
+ n_central_pool_cells = len(central_neighbor_pool_global_indices)
223
+ # Initialize with -1 (invalid)
224
+ central_neighbor_global_indices = np.full_like(central_neighbor_local_indices, -1, dtype=np.int32)
225
+ # Create mask for valid indices (< tree size)
226
+ central_valid_mask = central_neighbor_local_indices < n_central_pool_cells
227
+ # Map valid local indices to global indices
228
+ central_neighbor_global_indices[central_valid_mask] = central_neighbor_pool_global_indices[
229
+ central_neighbor_local_indices[central_valid_mask]
230
+ ]
231
+ spatial_neighbors[query_cells_on_slice, :k_central] = central_neighbor_global_indices
232
+
233
+ # Adjacent slices - use the pre-computed mapping
234
+ # Get the list of adjacent slices to search for this slice_id
235
+ adjacent_slices_to_search = slice_adjacent_mapping.get(slice_id, [])
236
+
237
+ neighbor_column_offset = k_central
238
+
239
+ # We need to fill exactly 2*n_adjacent_slices slots in the output array
240
+ # The array structure is: [offset=1,dir=-1], [offset=1,dir=+1], [offset=2,dir=-1], [offset=2,dir=+1], ...
241
+ for slot_idx in range(2 * n_adjacent_slices):
242
+ if slot_idx < len(adjacent_slices_to_search):
243
+ # We have a slice to search for this slot
244
+ adjacent_slice_id, _, _ = adjacent_slices_to_search[slot_idx]
245
+
246
+ adjacent_kdtree, adjacent_neighbor_pool_global_indices = slice_kdtrees[adjacent_slice_id]
247
+ _, adjacent_neighbor_local_indices = adjacent_kdtree.query(query_coords, k=k_adjacent, workers=-1)
248
+
249
+ # Handle invalid indices for adjacent slices
250
+ n_adjacent_pool_cells = len(adjacent_neighbor_pool_global_indices)
251
+ adjacent_neighbor_global_indices = np.full_like(adjacent_neighbor_local_indices, -1, dtype=np.int32)
252
+ adjacent_valid_mask = adjacent_neighbor_local_indices < n_adjacent_pool_cells
253
+ adjacent_neighbor_global_indices[adjacent_valid_mask] = adjacent_neighbor_pool_global_indices[
254
+ adjacent_neighbor_local_indices[adjacent_valid_mask]
255
+ ]
256
+ spatial_neighbors[query_cells_on_slice, neighbor_column_offset:neighbor_column_offset+k_adjacent] = adjacent_neighbor_global_indices
257
+ # else: slot remains as -1 (already initialized) if no slice available
258
+
259
+ neighbor_column_offset += k_adjacent
260
+
261
+ # Note: Slices with 0 neighbor pool cells will not be in slice_kdtrees
262
+ # Query cells on such slices will have all neighbors remain as -1 (already initialized)
263
+
264
+ return spatial_neighbors, neighbor_pool_per_slice
265
+
266
+
267
@partial(jit, static_argnums=(5, 6))
def _find_homogeneous_spot_dual_embedding_batch_jit(
    emb_niche_batch_norm: jnp.ndarray,  # (batch_size, d1)
    emb_indv_batch_norm: jnp.ndarray,  # (batch_size, d2)
    spatial_neighbors: jnp.ndarray,  # (batch_size, k1)
    all_emb_niche_norm: jnp.ndarray,  # (n_all, d1)
    all_emb_indv_norm: jnp.ndarray,  # (n_all, d2)
    homogeneous_neighbors: int,
    cell_embedding_similarity_threshold: float = 0.0,
    spatial_domain_similarity_threshold: float = 0.5
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    Select homogeneous neighbors among each cell's spatial neighbors.

    Candidates are ranked by thresholded cell similarity, with a fixed bonus
    for candidates whose niche similarity reaches the spatial-domain
    threshold. Niche-matching candidates are therefore preferred, and the
    remaining slots fall back to the highest cell similarity.

    Entries of spatial_neighbors equal to -1 are treated as padding: they are
    ranked last and, if selected anyway, reported with similarity 0.0.

    Returns:
        homogeneous_neighbors_array: Indices of selected neighbors (batch, homogeneous_neighbors)
        selected_cell_sims: Cell similarity scores for selected neighbors (batch, homogeneous_neighbors)
        selected_niche_sims: Niche/Spatial similarity scores for selected neighbors (batch, homogeneous_neighbors)
    """
    niche_bonus = 2.0
    padding_score = -100.0

    # Padding (-1) is redirected to row 0 for the gather only; it is masked out below
    is_real_neighbor = spatial_neighbors >= 0
    gather_rows = jnp.where(is_real_neighbor, spatial_neighbors, 0)
    candidate_indv = all_emb_indv_norm[gather_rows]
    candidate_niche = all_emb_niche_norm[gather_rows]

    # Dot products between each batch cell and its own candidates
    cell_dots = jnp.einsum('bd,bkd->bk', emb_indv_batch_norm, candidate_indv)
    niche_sims = jnp.einsum('bd,bkd->bk', emb_niche_batch_norm, candidate_niche)

    # Cell similarities below the threshold are flattened to 0.0
    cell_sims = jnp.where(cell_dots >= cell_embedding_similarity_threshold, cell_dots, 0.0)

    # Score = cell similarity (+ bonus when the niche similarity passes its threshold)
    in_same_domain = niche_sims >= spatial_domain_similarity_threshold
    score = cell_sims + jnp.where(in_same_domain, niche_bonus, 0.0)
    score = jnp.where(is_real_neighbor, score, padding_score)

    # Column positions of the best-scoring candidates, best first
    best_columns = jnp.argsort(-score, axis=1)[:, :homogeneous_neighbors]

    homogeneous_neighbors_array = jnp.take_along_axis(spatial_neighbors, best_columns, axis=1)
    picked_cell_sims = jnp.take_along_axis(cell_sims, best_columns, axis=1)
    picked_niche_sims = jnp.take_along_axis(niche_sims, best_columns, axis=1)

    # A padding entry can still be picked when valid candidates run out
    picked_is_real = homogeneous_neighbors_array >= 0
    selected_cell_sims = jnp.where(picked_is_real, picked_cell_sims, 0.0)
    selected_niche_sims = jnp.where(picked_is_real, picked_niche_sims, 0.0)

    return homogeneous_neighbors_array, selected_cell_sims, selected_niche_sims
336
def _find_homogeneous_3d_memory_efficient(
    emb_niche_masked_jax: jnp.ndarray,
    emb_indv_masked_jax: jnp.ndarray,
    spatial_neighbors: np.ndarray,
    all_emb_niche_norm_jax: jnp.ndarray,
    all_emb_indv_norm_jax: jnp.ndarray,
    num_homogeneous_per_slice: int,
    k_central: int,
    k_adjacent: int,
    n_adjacent_slices: int,
    cell_embedding_similarity_threshold: float,
    spatial_domain_similarity_threshold: float,
    find_homogeneous_batch_size: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Memory-efficient version of 3D homogeneous neighbor finding.

    The spatial neighbor array is split into its column windows (the central
    slice followed by 2 * n_adjacent_slices adjacent slots). Each window is
    processed on its own, in batches of cells, and the per-window selections
    are concatenated along axis 1.

    Args:
        emb_niche_masked_jax: Niche embeddings for masked cells (JAX array, may be dummy ones)
        emb_indv_masked_jax: Individual embeddings for masked cells (JAX array)
        spatial_neighbors: Spatial neighbors array with structure [central | adj1 | adj2 | ...] (numpy array)
        all_emb_niche_norm_jax: All normalized niche embeddings (JAX array, may be dummy ones)
        all_emb_indv_norm_jax: All normalized individual embeddings (JAX array)
        num_homogeneous_per_slice: Number of neighbors to select per slice
        k_central: Number of neighbors in central slice
        k_adjacent: Number of neighbors per adjacent slice
        n_adjacent_slices: Number of adjacent slices above and below
        cell_embedding_similarity_threshold: Minimum similarity threshold for cell embedding
        spatial_domain_similarity_threshold: Minimum similarity threshold for spatial domain embedding
        find_homogeneous_batch_size: Batch size for processing

    Returns:
        homogeneous_neighbors: Selected neighbors
        homogeneous_cell_sims: Cell similarity scores for selected neighbors
        homogeneous_niche_sims: Niche/Spatial similarity scores for selected neighbors
    """
    n_masked = emb_indv_masked_jax.shape[0]
    n_adjacent_slots = 2 * n_adjacent_slices
    total_slices = 1 + n_adjacent_slots

    # (label, first column, one-past-last column) for every window of spatial_neighbors.
    # Adjacent windows are labelled by slot number only: which physical slice a slot
    # refers to is decided by the code that built spatial_neighbors.
    column_windows = [("central slice", 0, k_central)]
    for adj_idx in range(n_adjacent_slots):
        window_start = k_central + adj_idx * k_adjacent
        column_windows.append(
            (f"adjacent slot {adj_idx + 1}/{n_adjacent_slots}", window_start, window_start + k_adjacent)
        )

    if n_masked == 0:
        # No cells to process: np.vstack below would fail on an empty batch list.
        # Return empty arrays with the column count the normal path would produce.
        n_columns = sum(min(end - start, num_homogeneous_per_slice) for _, start, end in column_windows)
        return (
            np.empty((0, n_columns), dtype=np.int32),
            np.empty((0, n_columns), dtype=np.dtype(emb_indv_masked_jax.dtype)),
            np.empty((0, n_columns), dtype=np.dtype(emb_niche_masked_jax.dtype)),
        )

    homogeneous_neighbors_all_slices = []
    homogeneous_cell_sims_all_slices = []
    homogeneous_niche_sims_all_slices = []

    batch_starts = range(0, n_masked, find_homogeneous_batch_size)

    with Progress(
        SpinnerColumn(),
        TextColumn("[bold blue]{task.description}"),
        BarColumn(),
        MofNCompleteColumn(),
        TaskProgressColumn(),
        TimeRemainingColumn(),
        TimeElapsedColumn(),
        refresh_per_second=1
    ) as slice_progress:
        # Overall progress across column windows
        slice_task = slice_progress.add_task(
            "Finding homogeneous neighbors (3D cross-slice)...",
            total=total_slices
        )

        for slice_name, slice_start, slice_end in column_windows:
            # Convert this window of neighbors to a JAX array once
            spatial_neighbors_slice = jnp.asarray(spatial_neighbors[:, slice_start:slice_end])

            homogeneous_neighbors_slice_list = []
            homogeneous_cell_sims_slice_list = []
            homogeneous_niche_sims_slice_list = []

            # Per-batch progress is a second task on the same Progress instance.
            # Starting a separate track() here would open a second live display,
            # which rich rejects while this one is active (on releases that allow
            # only one live display at a time).
            batch_task = slice_progress.add_task(
                f"Finding homogeneous neighbors ({slice_name})",
                total=len(batch_starts)
            )

            for batch_start in batch_starts:
                batch_end = min(batch_start + find_homogeneous_batch_size, n_masked)
                batch_indices = slice(batch_start, batch_end)

                homo_neighbors_batch, cell_sims_batch, niche_sims_batch = _find_homogeneous_spot_dual_embedding_batch_jit(
                    emb_niche_batch_norm=emb_niche_masked_jax[batch_indices],
                    emb_indv_batch_norm=emb_indv_masked_jax[batch_indices],
                    spatial_neighbors=spatial_neighbors_slice[batch_indices, :],
                    all_emb_niche_norm=all_emb_niche_norm_jax,
                    all_emb_indv_norm=all_emb_indv_norm_jax,
                    homogeneous_neighbors=num_homogeneous_per_slice,
                    cell_embedding_similarity_threshold=cell_embedding_similarity_threshold,
                    spatial_domain_similarity_threshold=spatial_domain_similarity_threshold
                )

                # Copy results off the device as numpy arrays
                homogeneous_neighbors_slice_list.append(np.array(homo_neighbors_batch))
                homogeneous_cell_sims_slice_list.append(np.array(cell_sims_batch))
                homogeneous_niche_sims_slice_list.append(np.array(niche_sims_batch))

                slice_progress.update(batch_task, advance=1)

            # Drop the finished per-batch row so only the overall bar remains
            slice_progress.remove_task(batch_task)

            # Stack this window's batches back into (n_masked, selected) arrays
            homogeneous_neighbors_all_slices.append(np.vstack(homogeneous_neighbors_slice_list))
            homogeneous_cell_sims_all_slices.append(np.vstack(homogeneous_cell_sims_slice_list))
            homogeneous_niche_sims_all_slices.append(np.vstack(homogeneous_niche_sims_slice_list))

            slice_progress.update(slice_task, advance=1)

    # Concatenate all windows along axis 1
    homogeneous_neighbors = np.concatenate(homogeneous_neighbors_all_slices, axis=1)
    homogeneous_cell_sims = np.concatenate(homogeneous_cell_sims_all_slices, axis=1)
    homogeneous_niche_sims = np.concatenate(homogeneous_niche_sims_all_slices, axis=1)

    return homogeneous_neighbors, homogeneous_cell_sims, homogeneous_niche_sims
471
+
472
+
473
def build_scrna_connectivity(
    emb_cell: np.ndarray,
    cell_mask: np.ndarray | None = None,
    n_neighbors: int = 21,
    metric: str = 'euclidean'
) -> tuple[np.ndarray, np.ndarray, None]:
    """
    Build a fixed-width neighbor table for scRNA-seq cells from a scanpy KNN graph.

    The masked embeddings are placed in a temporary AnnData object,
    ``sc.pp.neighbors`` is run on them, and for every cell the n_neighbors
    strongest entries of its row in the resulting connectivities matrix are
    kept. Rows with fewer entries are padded with the cell's own index and
    zero weight. Kept weights are then scaled so each row sums to 1 (rows
    whose weights sum to 0 are left as they are).

    Args:
        emb_cell: Cell embeddings (n_cells, d)
        cell_mask: Boolean mask for cells to process (default: all cells)
        n_neighbors: Number of nearest neighbors
        metric: Distance metric for KNN

    Returns:
        neighbor_indices: (n_masked, n_neighbors) array of neighbor indices into emb_cell
        neighbor_weights: (n_masked, n_neighbors) array of row-normalized graph weights
        niche_sims: None (scRNA-seq has no niche embedding)
    """
    if cell_mask is None:
        cell_mask = np.ones(len(emb_cell), dtype=bool)

    # Positions of the selected cells in the full embedding array
    cell_indices = np.where(cell_mask)[0]
    n_masked = len(cell_indices)

    logger.info(f"Building scRNA-seq connectivity using KNN with k={n_neighbors}")

    # scanpy works on AnnData, so wrap the selected embeddings in a throwaway object
    adata_temp = ad.AnnData(X=emb_cell[cell_mask])
    adata_temp.obsm['X_emb'] = emb_cell[cell_mask]
    sc.pp.neighbors(
        adata_temp,
        n_neighbors=n_neighbors,
        use_rep='X_emb',
        metric=metric,
        method='umap'
    )
    connectivities = adata_temp.obsp['connectivities'].tocsr()

    # Smallest integer type that can hold the largest global cell index
    largest_index = cell_indices.max() if len(cell_indices) > 0 else 0
    idx_dtype = np.int16 if largest_index < 32768 else np.int32
    neighbor_indices = np.zeros((n_masked, n_neighbors), dtype=idx_dtype)
    neighbor_weights = np.zeros((n_masked, n_neighbors), dtype=np.float16)

    # Walk the CSR rows directly: row r occupies [indptr[r], indptr[r + 1])
    row_bounds = connectivities.indptr
    graph_cols = connectivities.indices
    graph_vals = connectivities.data
    for row in range(n_masked):
        lo, hi = row_bounds[row], row_bounds[row + 1]
        row_cols = graph_cols[lo:hi]
        row_vals = graph_vals[lo:hi]

        # Strongest connections first, at most n_neighbors of them
        strongest = np.argsort(-row_vals)[:n_neighbors]
        n_found = len(strongest)

        neighbor_indices[row, :n_found] = cell_indices[row_cols[strongest]]
        neighbor_weights[row, :n_found] = row_vals[strongest]
        if n_found < n_neighbors:
            # Pad short rows with the cell itself at zero weight
            neighbor_indices[row, n_found:] = cell_indices[row]
            neighbor_weights[row, n_found:] = 0.0

    # Scale each row to sum to 1, leaving all-zero rows untouched
    row_totals = neighbor_weights.sum(axis=1, keepdims=True)
    row_totals = np.where(row_totals > 0, row_totals, 1.0)
    neighbor_weights = neighbor_weights / row_totals

    logger.info(f"scRNA-seq connectivity built: {n_masked} cells × {n_neighbors} neighbors")

    return neighbor_indices, neighbor_weights, None
552
+
553
+
554
+ class ConnectivityMatrixBuilder:
555
+ """Build connectivity matrix using JAX-accelerated computation with GPU memory optimization"""
556
+
557
def __init__(self, config: LatentToGeneConfig):
    """
    Keep the configuration and copy the settings used while building connectivity.

    Args:
        config: LatentToGeneConfig object; its ``dataset_type`` and
            ``find_homogeneous_batch_size`` attributes are copied onto the builder
    """
    self.config = config
    self.dataset_type = config.dataset_type
    # Batch size used when finding homogeneous neighbors (taken from the config)
    self.find_homogeneous_batch_size = config.find_homogeneous_batch_size
567
+ self.dataset_type = config.dataset_type
568
+
569
def build_connectivity_matrix(
    self,
    coords: np.ndarray | None = None,
    emb_niche: np.ndarray = None,
    emb_indv: np.ndarray | None = None,
    cell_mask: np.ndarray | None = None,
    high_quality_mask: np.ndarray = None,
    slice_ids: np.ndarray | None = None,
    k_central: int | None = None,
    k_adjacent: int | None = None,
    n_adjacent_slices: int | None = None
) -> tuple[np.ndarray, np.ndarray, np.ndarray | None]:
    """
    Dispatch connectivity building according to the configured dataset type.

    - scRNA-seq: KNN on the cell embeddings (emb_indv) via build_scrna_connectivity
    - spatial2D / spatial3D: delegated to _build_spatial_connectivity; for
      spatial2D the cross-slice search is always disabled (n_adjacent_slices = 0)

    Args:
        coords: Spatial coordinates (n_cells, 2) - required for spatial datasets
        emb_niche: Niche embeddings (n_cells, d1) - passed through for spatial datasets
        emb_indv: Cell identity embeddings (n_cells, d2) - required for all datasets
        cell_mask: Boolean mask for cells to process
        high_quality_mask: Boolean mask for high quality cells (passed through for spatial datasets)
        slice_ids: Optional slice/z-coordinate indices (n_cells,)
        k_central: Neighbors on the central slice (default: config.spatial_neighbors)
        k_adjacent: Neighbors per adjacent slice (default: config.adjacent_slice_spatial_neighbors)
        n_adjacent_slices: Slices to search above/below (default for spatial3D: config.n_adjacent_slices)

    Returns:
        Tuple of (neighbor_indices, cell_sims, niche_sims) arrays.
        niche_sims is None for scRNA-seq datasets.

    Raises:
        ValueError: if a required input is missing or the dataset type is unknown
    """
    dataset_type = self.dataset_type

    # --- scRNA-seq: plain KNN on cell embeddings ---
    if dataset_type == DatasetType.SCRNA_SEQ:
        logger.info("Building connectivity for scRNA-seq dataset")
        if emb_indv is None:
            raise ValueError("emb_indv (cell embeddings) required for scRNA-seq dataset")
        return build_scrna_connectivity(
            emb_cell=emb_indv,
            cell_mask=cell_mask,
            n_neighbors=self.config.homogeneous_neighbors,
            metric='euclidean'
        )

    # --- anything that is not a spatial dataset is rejected ---
    if dataset_type not in ['spatial2D', 'spatial3D']:
        raise ValueError(f"Unknown dataset type: {dataset_type}")

    # --- spatial datasets ---
    logger.info(f"Building connectivity for {dataset_type} dataset")
    if coords is None or emb_indv is None:
        raise ValueError("coords and emb_indv required for spatial datasets")

    # Fill unspecified neighbor counts from the configuration
    if k_central is None:
        k_central = self.config.spatial_neighbors
    if k_adjacent is None:
        k_adjacent = self.config.adjacent_slice_spatial_neighbors

    if dataset_type == DatasetType.SPATIAL_2D:
        # No cross-slice search for 2D data, whatever the caller passed;
        # slice_ids are still forwarded if provided
        n_adjacent_slices = 0
    elif n_adjacent_slices is None:
        n_adjacent_slices = self.config.n_adjacent_slices

    return self._build_spatial_connectivity(
        coords=coords,
        emb_niche=emb_niche,
        emb_indv=emb_indv,
        cell_mask=cell_mask,
        high_quality_mask=high_quality_mask,
        slice_ids=slice_ids,
        k_central=k_central,
        k_adjacent=k_adjacent,
        n_adjacent_slices=n_adjacent_slices
    )
652
+
653
def _build_spatial_connectivity(
    self,
    coords: np.ndarray,
    emb_indv: np.ndarray,
    emb_niche: np.ndarray | None = None,
    cell_mask: np.ndarray | None = None,
    high_quality_mask: np.ndarray | None = None,
    slice_ids: np.ndarray | None = None,
    k_central: int = 101,
    k_adjacent: int = 50,
    n_adjacent_slices: int = 1
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Internal method for building the spatial connectivity matrix.

    First finds spatial neighbors for every queried cell (slice-aware when
    ``slice_ids`` is given), then selects homogeneous neighbors among them
    using both the niche and the cell-identity embeddings.

    Args:
        coords: Spatial coordinates (n_cells, 2) - only x,y coordinates
        emb_indv: Cell identity embeddings (n_cells, d2)
        emb_niche: Niche embeddings (n_cells, d1). Required despite the
            ``None`` default (callers may pass dummy embeddings).
        cell_mask: Boolean mask for cells to process; all cells when None
        high_quality_mask: Boolean mask for high quality cells, forwarded to
            the spatial neighbor search
        slice_ids: Optional slice/z-coordinate indices (n_cells,) for 3D data
        k_central: Number of neighbors on central slice
        k_adjacent: Number of neighbors on adjacent slices for 3D data
        n_adjacent_slices: Number of slices to search above/below for 3D data

    Returns:
        Tuple of (neighbor_indices, cell_sims, niche_sims) arrays

    Raises:
        ValueError: If ``emb_niche`` is None.
    """
    # Fail fast with a clear message: the embeddings are converted to JAX
    # arrays below, and None cannot be converted. Checking here also avoids
    # running the spatial neighbor search before hitting that failure.
    if emb_niche is None:
        raise ValueError("emb_niche required for spatial connectivity (dummy embeddings may be used)")

    n_cells = len(coords)
    if cell_mask is None:
        cell_mask = np.ones(n_cells, dtype=bool)

    # Indices of the cells selected by the mask; computed once and reused
    # for both the batch count and the embedding lookup.
    masked_cell_indices = np.where(cell_mask)[0]
    n_masked = len(masked_cell_indices)

    # Step 1: Find spatial neighbors (slice-aware if slice_ids provided)
    spatial_neighbors, neighbor_pool_per_slice = find_spatial_neighbors_with_slices(
        coords=coords,
        slice_ids=slice_ids,
        query_cell_mask=cell_mask,
        high_quality_cell_mask=high_quality_mask,
        k_central=k_central,
        k_adjacent=k_adjacent,
        n_adjacent_slices=n_adjacent_slices
    )

    # Log statistics about neighbor pool cells per slice if available
    if neighbor_pool_per_slice:
        for slice_id, neighbor_pool_indices in neighbor_pool_per_slice.items():
            logger.debug(f"Slice {slice_id}: {len(neighbor_pool_indices)} high quality cells available for neighbor search")

    # Step 2 & 3: Find anchors and homogeneous neighbors in batches
    logger.info(f"Finding anchors and homogeneous neighbors (batch size: {self.find_homogeneous_batch_size})...")

    # Convert embeddings to JAX arrays once (shared by both code paths).
    # Original author's note: float16 is considered sufficient precision
    # for normalized embeddings.
    all_emb_niche_norm_jax = jnp.array(emb_niche, dtype=jnp.float16)
    all_emb_indv_norm_jax = jnp.array(emb_indv, dtype=jnp.float16)

    # Embeddings of the queried (masked) cells only
    emb_niche_masked_jax = all_emb_niche_norm_jax[masked_cell_indices]
    emb_indv_masked_jax = all_emb_indv_norm_jax[masked_cell_indices]

    if self.config.fix_cross_slice_homogenous_neighbors:
        logger.info(f"Using 3D constrained selection (ensuring {self.config.homogeneous_neighbors} neighbors per slice)")

        # Use memory-efficient version that processes slices separately
        homogeneous_neighbors, homogeneous_cell_sims, homogeneous_niche_sims = _find_homogeneous_3d_memory_efficient(
            emb_niche_masked_jax=emb_niche_masked_jax,
            emb_indv_masked_jax=emb_indv_masked_jax,
            spatial_neighbors=spatial_neighbors,
            all_emb_niche_norm_jax=all_emb_niche_norm_jax,
            all_emb_indv_norm_jax=all_emb_indv_norm_jax,
            num_homogeneous_per_slice=self.config.homogeneous_neighbors,
            k_central=k_central,
            k_adjacent=k_adjacent,
            n_adjacent_slices=n_adjacent_slices,
            cell_embedding_similarity_threshold=self.config.cell_embedding_similarity_threshold,
            spatial_domain_similarity_threshold=self.config.spatial_domain_similarity_threshold,
            find_homogeneous_batch_size=self.find_homogeneous_batch_size
        )

    else:
        # Standard path (2D, or 3D without fix_cross_slice_homogenous_neighbors):
        # convert spatial_neighbors to a JAX array and process in batches.
        spatial_neighbors_jax = jnp.array(spatial_neighbors, dtype=jnp.int32)
        homogeneous_neighbors_list = []
        homogeneous_cell_sims_list = []
        homogeneous_niche_sims_list = []

        batch_size = self.find_homogeneous_batch_size
        for batch_start in track(range(0, n_masked, self.find_homogeneous_batch_size), description="Finding homogeneous neighbors", transient=True):
            batch_end = min(batch_start + batch_size, n_masked)
            batch_indices = slice(batch_start, batch_end)

            # Process batch with single JIT-compiled function, slicing the
            # already-converted JAX arrays directly.
            homo_neighbors_batch, cell_sims_batch, niche_sims_batch = _find_homogeneous_spot_dual_embedding_batch_jit(
                emb_niche_batch_norm=emb_niche_masked_jax[batch_indices],
                emb_indv_batch_norm=emb_indv_masked_jax[batch_indices],
                spatial_neighbors=spatial_neighbors_jax[batch_indices],
                all_emb_niche_norm=all_emb_niche_norm_jax,
                all_emb_indv_norm=all_emb_indv_norm_jax,
                homogeneous_neighbors=self.config.total_homogeneous_neighbor_per_cell,
                cell_embedding_similarity_threshold=self.config.cell_embedding_similarity_threshold,
                spatial_domain_similarity_threshold=self.config.spatial_domain_similarity_threshold
            )

            # Convert back to numpy and append
            homogeneous_neighbors_list.append(np.array(homo_neighbors_batch))
            homogeneous_cell_sims_list.append(np.array(cell_sims_batch))
            homogeneous_niche_sims_list.append(np.array(niche_sims_batch))

        # Stack the per-batch results row-wise.
        # NOTE(review): if cell_mask selects no cells the lists are empty and
        # np.vstack raises ValueError - confirm callers never pass an empty mask.
        homogeneous_neighbors = np.vstack(homogeneous_neighbors_list)
        homogeneous_cell_sims = np.vstack(homogeneous_cell_sims_list)
        homogeneous_niche_sims = np.vstack(homogeneous_niche_sims_list)

    # Return dense format: (n_masked, num_homogeneous) arrays
    return homogeneous_neighbors, homogeneous_cell_sims, homogeneous_niche_sims