gsMap3D 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsMap/__init__.py +13 -0
- gsMap/__main__.py +4 -0
- gsMap/cauchy_combination_test.py +342 -0
- gsMap/cli.py +355 -0
- gsMap/config/__init__.py +72 -0
- gsMap/config/base.py +296 -0
- gsMap/config/cauchy_config.py +79 -0
- gsMap/config/dataclasses.py +235 -0
- gsMap/config/decorators.py +302 -0
- gsMap/config/find_latent_config.py +276 -0
- gsMap/config/format_sumstats_config.py +54 -0
- gsMap/config/latent2gene_config.py +461 -0
- gsMap/config/ldscore_config.py +261 -0
- gsMap/config/quick_mode_config.py +242 -0
- gsMap/config/report_config.py +81 -0
- gsMap/config/spatial_ldsc_config.py +334 -0
- gsMap/config/utils.py +286 -0
- gsMap/find_latent/__init__.py +3 -0
- gsMap/find_latent/find_latent_representation.py +312 -0
- gsMap/find_latent/gnn/distribution.py +498 -0
- gsMap/find_latent/gnn/encoder_decoder.py +186 -0
- gsMap/find_latent/gnn/gcn.py +85 -0
- gsMap/find_latent/gnn/gene_former.py +164 -0
- gsMap/find_latent/gnn/loss.py +18 -0
- gsMap/find_latent/gnn/st_model.py +125 -0
- gsMap/find_latent/gnn/train_step.py +177 -0
- gsMap/find_latent/st_process.py +781 -0
- gsMap/format_sumstats.py +446 -0
- gsMap/generate_ldscore.py +1018 -0
- gsMap/latent2gene/__init__.py +18 -0
- gsMap/latent2gene/connectivity.py +781 -0
- gsMap/latent2gene/entry_point.py +141 -0
- gsMap/latent2gene/marker_scores.py +1265 -0
- gsMap/latent2gene/memmap_io.py +766 -0
- gsMap/latent2gene/rank_calculator.py +590 -0
- gsMap/latent2gene/row_ordering.py +182 -0
- gsMap/latent2gene/row_ordering_jax.py +159 -0
- gsMap/ldscore/__init__.py +1 -0
- gsMap/ldscore/batch_construction.py +163 -0
- gsMap/ldscore/compute.py +126 -0
- gsMap/ldscore/constants.py +70 -0
- gsMap/ldscore/io.py +262 -0
- gsMap/ldscore/mapping.py +262 -0
- gsMap/ldscore/pipeline.py +615 -0
- gsMap/pipeline/quick_mode.py +134 -0
- gsMap/report/__init__.py +2 -0
- gsMap/report/diagnosis.py +375 -0
- gsMap/report/report.py +100 -0
- gsMap/report/report_data.py +1832 -0
- gsMap/report/static/js_lib/alpine.min.js +5 -0
- gsMap/report/static/js_lib/tailwindcss.js +83 -0
- gsMap/report/static/template.html +2242 -0
- gsMap/report/three_d_combine.py +312 -0
- gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
- gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
- gsMap/report/three_d_plot/three_d_plots.py +425 -0
- gsMap/report/visualize.py +1409 -0
- gsMap/setup.py +5 -0
- gsMap/spatial_ldsc/__init__.py +0 -0
- gsMap/spatial_ldsc/io.py +656 -0
- gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
- gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
- gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
- gsMap/utils/__init__.py +0 -0
- gsMap/utils/generate_r2_matrix.py +610 -0
- gsMap/utils/jackknife.py +518 -0
- gsMap/utils/manhattan_plot.py +643 -0
- gsMap/utils/regression_read.py +177 -0
- gsMap/utils/torch_utils.py +23 -0
- gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
- gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
- gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
- gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
- gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,461 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Configuration for latent to gene mapping.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from collections import OrderedDict
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from enum import Enum
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Annotated
|
|
11
|
+
|
|
12
|
+
import typer
|
|
13
|
+
import yaml
|
|
14
|
+
|
|
15
|
+
from gsMap.config.base import ConfigWithAutoPaths
|
|
16
|
+
from gsMap.config.utils import configure_jax_platform, process_h5ad_inputs, validate_h5ad_structure
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger("gsMap.config")
|
|
19
|
+
|
|
20
|
+
class DatasetType(str, Enum):
    """Kinds of input dataset accepted by the latent-to-gene step.

    Members are also ``str`` instances, so they compare equal to their
    plain-string values.
    """

    SCRNA_SEQ = "scRNA"
    SPATIAL_2D = "spatial2D"
    SPATIAL_3D = "spatial3D"
|
|
24
|
+
|
|
25
|
+
class MarkerScoreCrossSliceStrategy(str, Enum):
    """Named strategies for combining marker scores across slices.

    Members are also ``str`` instances, so they compare equal to their
    plain-string values.
    """

    GLOBAL_POOL = "global_pool"
    PER_SLICE_POOL = "per_slice_pool"
    HIERARCHICAL_POOL = "hierarchical_pool"
|
|
29
|
+
|
|
30
|
+
@dataclass
class LatentToGeneComputeConfig:
    """Compute configuration for latent-to-gene step."""

    # Un-annotated class attribute, so it is NOT a dataclass field.
    # NOTE(review): presumably read by the quick-mode CLI builder — confirm.
    __display_in_quick_mode_cli__ = True

    # ---- Device selection
    use_gpu: Annotated[
        bool,
        typer.Option(
            "--use-gpu/--no-gpu",
            help="Use GPU for JAX computations (requires sufficient GPU memory)",
        ),
    ] = True

    # ---- Optional scratch directory for memory-mapped files
    memmap_tmp_dir: Annotated[
        Path | None,
        typer.Option(
            help="Temporary directory for memory-mapped files to improve I/O performance on slow filesystems. "
                 "If provided, memory maps will be copied to this directory for faster random access during computation.",
            exists=True,
            file_okay=False,
            dir_okay=True,
            resolve_path=True,
        ),
    ] = None

    # ---- Batch sizes
    rank_batch_size: int = 500

    mkscore_batch_size: Annotated[
        int,
        typer.Option(
            help="Number of cells per batch for marker score calculation. Reduce this value (e.g., 50) if encountering GPU OOM errors.",
            min=10,
            max=1000,
        ),
    ] = 500

    find_homogeneous_batch_size: int = 100
    rank_write_interval: int = 10

    # ---- Worker configurations
    rank_read_workers: Annotated[
        int,
        typer.Option(
            help="Number of parallel reader threads for rank memory map",
            min=1,
            max=50,
        ),
    ] = 16

    mkscore_compute_workers: Annotated[
        int,
        typer.Option(
            help="Number of parallel compute threads for marker score calculation",
            min=1,
            max=16,
        ),
    ] = 4

    mkscore_write_workers: Annotated[
        int,
        typer.Option(
            help="Number of parallel writer threads for marker scores",
            min=1,
            max=50,
        ),
    ] = 4

    # ---- Queue sizing
    compute_input_queue_size: Annotated[
        int,
        typer.Option(
            help="Maximum size of compute input queue (multiplier of mkscore_compute_workers)",
            min=1,
            max=10,
        ),
    ] = 5

    writer_queue_size: Annotated[
        int,
        typer.Option(
            help="Maximum size of writer input queue",
            min=10,
            max=500,
        ),
    ] = 100
|
|
91
|
+
|
|
92
|
+
# Core (non-compute) options of the latent-to-gene step: dataset type, h5ad
# inputs, the obs/obsm/layer keys to read, and neighbor-search parameters.
# Documented with comments rather than a class docstring so the dataclass's
# auto-generated ``__doc__`` stays as it was.
@dataclass
class LatentToGeneCoreConfig:

    # Default is kept as the plain string; DatasetType members are str
    # subclasses, so equality checks against the enum still succeed.
    dataset_type: Annotated[
        DatasetType,
        typer.Option(
            help="Type of dataset: scRNA (uses KNN on latent space), spatial2D (2D spatial), or spatial3D (multi-slice)",
            case_sensitive=False,
        ),
    ] = 'spatial2D'

    # -------- input h5ad file paths which have the latent representations
    h5ad_path: Annotated[
        list[Path] | None,
        typer.Option(
            help="Space-separated list of h5ad file paths. Sample names are derived from file names without suffix.",
            exists=True,
            file_okay=True,
        ),
    ] = None

    h5ad_yaml: Annotated[
        Path | None,
        typer.Option(
            help="YAML file with sample names and h5ad paths",
            exists=True,
            file_okay=True,
            dir_okay=False,
        ),
    ] = None

    h5ad_list_file: Annotated[
        Path | None,
        typer.Option(
            help="Each row is a h5ad file path, sample name is the file name without suffix",
            exists=True,
            file_okay=True,
            dir_okay=False,
        ),
    ] = None

    # Mapping of sample name -> h5ad path; carries no CLI option of its own.
    sample_h5ad_dict: OrderedDict | None = None

    # -------- input h5ad obs, obsm, layers keys
    annotation: Annotated[
        str | None,
        typer.Option(
            help="Cell type annotation in adata.obs to use. This would constrain finding homogeneous spots within each cell type",
        ),
    ] = None

    data_layer: Annotated[
        str,
        typer.Option(
            help="Gene expression raw counts data layer in h5ad layers, e.g., 'count', 'counts'. Other wise use 'X' for adata.X",
        ),
    ] = "X"

    latent_representation_niche: Annotated[
        str | None,
        typer.Option(
            help="Key for spatial niche embedding in obsm",
        ),
    ] = None

    latent_representation_cell: Annotated[
        str,
        typer.Option(
            help="Key for cell identity embedding in obsm",
        ),
    ] = "emb_cell"

    spatial_key: Annotated[
        str,
        typer.Option(
            help="Spatial key in adata.obsm",
        ),
    ] = "spatial"

    # -------- parameters for finding homogeneous spots
    spatial_neighbors: Annotated[
        int,
        typer.Option(
            help="k1: Number of spatial neighbors in it's own slice for spatial dataset",
            min=10,
            max=5000,
        ),
    ] = 301

    homogeneous_neighbors: Annotated[
        int,
        typer.Option(
            help="k3: Number of homogeneous neighbors per cell (for spatial) or KNN neighbors (for scRNA-seq)",
            min=1,
            max=200,
        ),
    ] = 21

    cell_embedding_similarity_threshold: Annotated[
        float,
        typer.Option(
            help="Minimum similarity threshold for cell embedding.",
            min=0.0,
            max=1.0,
        ),
    ] = 0.0

    spatial_domain_similarity_threshold: Annotated[
        float,
        typer.Option(
            help="Minimum similarity threshold for spatial domain embedding.",
            min=0.0,
            max=1.0,
        ),
    ] = 0.6

    no_expression_fraction: Annotated[
        bool,
        typer.Option(
            "--no-expression-fraction",
            help="Skip expression fraction filtering",
        ),
    ] = False

    # -------- 3D slice-aware neighbor search parameters
    adjacent_slice_spatial_neighbors: Annotated[
        int,
        typer.Option(
            help="Number of spatial neighbors to find on each adjacent slice for 3D data",
            min=10,
            max=2000,
        ),
    ] = 200

    n_adjacent_slices: Annotated[
        int,
        typer.Option(
            help="Number of adjacent slices to search above and below (± n_adjacent_slices) in 3D space for each focal spot. Padding will be applied automatically.",
            min=0,
            max=5,
        ),
    ] = 1

    cross_slice_marker_score_strategy: Annotated[
        MarkerScoreCrossSliceStrategy,
        typer.Option(
            help="Strategy for computing marker scores across slices in spatial3D datasets. "
                 "'global_pool': Select the top K most similar neighbors globally across all slices combined. "
                 "'per_slice_pool': Select a fixed number of neighbors (K) from each slice independently, then compute a single weighted average score from all selected neighbors. "
                 "'hierarchical_pool': Compute an independent marker score for each slice using its top K neighbors, then take the average of these per-slice scores.",
            case_sensitive=False,
        ),
    ] = MarkerScoreCrossSliceStrategy.HIERARCHICAL_POOL

    high_quality_neighbor_filter: Annotated[
        bool,
        typer.Option(
            "--high-quality-neighbor-filter/--no-high-quality-filter",
            help="Only find neighbors within high quality cells (requires High_quality column in obs)",
        ),
    ] = False

    fix_cross_slice_homogenous_neighbors: bool = False

    # -------- Performance options
    # Minimum number of cells per cell type in that dataset to be used for
    # finding homogeneous neighbors.
    min_cells_per_type: int | None = None

    enable_profiling: bool = False
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
@dataclass
class LatentToGeneConfig(LatentToGeneComputeConfig, LatentToGeneCoreConfig, ConfigWithAutoPaths):
    """Latent to Gene Configuration"""

    @property
    def total_homogeneous_neighbor_per_cell(self):
        """Return ``homogeneous_neighbors * (1 + 2 * n_adjacent_slices)``."""
        return self.homogeneous_neighbors * (1 + 2 * self.n_adjacent_slices)

    def __post_init__(self):
        """Initialize and validate configuration.

        Runs, in order: the parent ``__post_init__``, JAX platform selection,
        h5ad input resolution, dataset-type specific parameter adjustment,
        h5ad structure validation, and finally ``show_config``.
        """
        super().__post_init__()

        # Step 1: Configure JAX platform
        configure_jax_platform(self.use_gpu)

        # Step 2: Process and validate h5ad inputs
        self._resolve_h5ad_inputs()

        # Step 3: Configure dataset-specific parameters first
        self._configure_dataset_parameters()

        # Step 4: Set up validation fields and validate structure (after dataset config)
        self._setup_and_validate_fields()

        self.show_config(LatentToGeneConfig)

    def _resolve_h5ad_inputs(self):
        """Resolve h5ad inputs from various sources, prioritizing auto-detection.

        Auto-detection runs first and always overwrites ``sample_h5ad_dict``;
        the explicit options (``h5ad_yaml``, ``h5ad_path``, ``h5ad_list_file``)
        are consulted only when auto-detection found nothing.

        Raises:
            ValueError: If no sample could be resolved from any source.
            FileNotFoundError: Propagated from auto-detection when a latent
                file listed in the find_latent metadata is missing.
        """
        # Step 1: Try auto-detection first
        self._auto_detect_h5ad_files()

        # Step 2: If auto-detection didn't find anything, try explicit inputs
        if not self.sample_h5ad_dict:
            input_options = {
                'h5ad_yaml': ('h5ad_yaml', 'yaml'),
                'h5ad_path': ('h5ad_path', 'list'),
                'h5ad_list_file': ('h5ad_list_file', 'file'),
            }
            self.sample_h5ad_dict = process_h5ad_inputs(self, input_options)

        # Step 3: Validate at least one sample exists (covers None and empty)
        if not self.sample_h5ad_dict:
            raise ValueError(
                "No valid h5ad files found. Please provide one of: h5ad_yaml, h5ad_path, or h5ad_list_file, "
                "or ensure find_latent representation has been run to allow auto-detection."
            )

        logger.info(f"Loaded and validated {len(self.sample_h5ad_dict)} samples")

    def _auto_detect_h5ad_files(self):
        """Auto-detect h5ad files from latent directory.

        If ``find_latent_metadata_path`` exists, sample names and latent file
        paths are read from its ``outputs.latent_files`` mapping and every
        listed file must exist. Otherwise ``latent_dir`` is globbed for
        ``*_latent_adata.h5ad`` (falling back to ``*_add_latent.h5ad``) and
        the result is sorted by sample name. ``sample_h5ad_dict`` is replaced
        in both cases and is left empty when nothing is found.

        Raises:
            FileNotFoundError: If the metadata lists a latent file that does
                not exist on disk.
        """
        if self.find_latent_metadata_path.exists():
            # ``yaml`` is imported at module level; no local import needed.
            with open(self.find_latent_metadata_path) as f:
                find_latent_metadata = yaml.safe_load(f)

            self.sample_h5ad_dict = OrderedDict(
                (sample_name, Path(latent_file))
                for sample_name, latent_file in
                find_latent_metadata['outputs']['latent_files'].items()
            )

            # Every file referenced by the metadata must be present.
            for sample_name, latent_file in self.sample_h5ad_dict.items():
                if not latent_file.exists():
                    raise FileNotFoundError(f"Latent file not found for sample '{sample_name}': {latent_file}")

            logger.info(
                f"Auto-detected {len(self.sample_h5ad_dict)} samples from find_latent_metadata_path: {self.find_latent_metadata_path}")
        else:
            self.sample_h5ad_dict = OrderedDict()
            latent_dir = self.latent_dir
            logger.info(f"Auto-detecting h5ad files from latent directory: {latent_dir}")

            # Look for latent files with different naming patterns
            latent_files = list(latent_dir.glob("*_latent_adata.h5ad"))
            if not latent_files:
                latent_files = list(latent_dir.glob("*_add_latent.h5ad"))

            if not latent_files:
                return

            # Extract sample names from file names
            for latent_file in latent_files:
                sample_name = self._extract_sample_name(latent_file)
                self.sample_h5ad_dict[sample_name] = latent_file

            # sort by sample name
            self.sample_h5ad_dict = OrderedDict(sorted(self.sample_h5ad_dict.items()))

            logger.info(f"Auto-detected {len(self.sample_h5ad_dict)} samples from latent directory")

    def _extract_sample_name(self, latent_file):
        """Extract sample name from latent file path.

        Returns the file stem with the first matching known suffix
        (``_latent_adata`` or ``_add_latent``) removed, or the unchanged stem
        when neither suffix matches.
        """
        filename = latent_file.stem

        # Remove known suffixes
        for suffix in ("_latent_adata", "_add_latent"):
            if filename.endswith(suffix):
                return filename.removesuffix(suffix)

        return filename

    def _setup_and_validate_fields(self):
        """Set up required fields and validate h5ad structure.

        Builds a mapping of ``name -> (container, key, description)`` and
        hands it, together with ``sample_h5ad_dict``, to
        ``validate_h5ad_structure``.
        """
        # NOTE(review): for scRNA datasets ``spatial_key`` has already been set
        # to None by ``_configure_scrna_seq`` when this runs — confirm that
        # ``validate_h5ad_structure`` tolerates a None key.
        required_fields = {
            'latent_representation_cell': ('obsm', self.latent_representation_cell,
                                           'Latent representation of cell identity'),
            'spatial_key': ('obsm', self.spatial_key, 'Spatial key'),
        }

        # Add annotation as required if provided
        if self.annotation:
            required_fields['annotation'] = ('obs', self.annotation, 'Annotation')

        # Add niche representation as required if provided
        if self.latent_representation_niche:
            required_fields['latent_representation_niche'] = (
                'obsm',
                self.latent_representation_niche,
                'Latent representation of spatial niche'
            )

        # Add High_quality as required if high_quality_neighbor_filter is enabled
        if self.high_quality_neighbor_filter:
            required_fields['High_quality'] = ('obs', 'High_quality', 'High quality cell indicator')

        # Validate h5ad structure
        validate_h5ad_structure(self.sample_h5ad_dict, required_fields)

    def _configure_dataset_parameters(self):
        """Configure parameters based on dataset type.

        ``min_cells_per_type`` defaults to ``homogeneous_neighbors`` and is
        otherwise capped at it; the dataset-type specific hook is then run.
        A ``dataset_type`` matching none of the three known types is left
        without further configuration.
        """
        if self.min_cells_per_type is None:
            self.min_cells_per_type = self.homogeneous_neighbors
        else:
            self.min_cells_per_type = min(self.min_cells_per_type, self.homogeneous_neighbors)

        if self.dataset_type == DatasetType.SPATIAL_2D:
            self._configure_spatial_2d()
        elif self.dataset_type == DatasetType.SPATIAL_3D:
            self._configure_spatial_3d()
        elif self.dataset_type == DatasetType.SCRNA_SEQ:
            self._configure_scrna_seq()

    def _configure_spatial_2d(self):
        """Configure parameters for spatial 2D datasets.

        Raises:
            AssertionError: If ``homogeneous_neighbors`` exceeds
                ``spatial_neighbors``.
        """
        # spatial2D can have multiple slices but doesn't search across them
        if self.n_adjacent_slices != 0:
            self.n_adjacent_slices = 0
            self.adjacent_slice_spatial_neighbors = 0
            logger.info(
                "Dataset type is spatial2D. This will only search homogeneous neighbors within each 2D slice (no cross-slice search). Setting adjacent_slices=0.")

        if self.latent_representation_niche is None:
            logger.warning("latent_representation_niche is not provided. Spatial domain similarity will not be used.")

        # Raised explicitly (not via ``assert``) so the check also runs under
        # ``python -O``; AssertionError is kept for backward compatibility.
        if self.homogeneous_neighbors > self.spatial_neighbors:
            raise AssertionError(
                f"homogeneous_neighbors ({self.homogeneous_neighbors}) must be <= spatial_neighbors ({self.spatial_neighbors}) for spatial2D datasets"
            )

    def _configure_spatial_3d(self):
        """Configure parameters for spatial 3D datasets.

        Raises:
            ValueError: If ``n_adjacent_slices`` is 0.
            AssertionError: If the neighbor counts are inconsistent or fewer
                samples than ``1 + n_adjacent_slices`` were provided.
        """
        if self.n_adjacent_slices == 0:
            raise ValueError(
                "Dataset type is spatial3D, but adjacent_slices=0. "
                "You must set adjacent_slices to 1 or higher to enable cross-slice search. "
                "If you don't want cross-slice search, use dataset_type='spatial2D' instead."
            )

        if self.latent_representation_niche is None:
            logger.warning("latent_representation_niche is not provided. Spatial domain similarity will not be used.")

        # The checks below are raised explicitly (not via ``assert``) so they
        # also run under ``python -O``; AssertionError is kept for backward
        # compatibility with existing callers.
        if self.adjacent_slice_spatial_neighbors > self.spatial_neighbors:
            raise AssertionError(
                f"adjacent_slice_neighbors ({self.adjacent_slice_spatial_neighbors}) must be <= spatial_neighbors ({self.spatial_neighbors})"
            )
        if self.homogeneous_neighbors > self.adjacent_slice_spatial_neighbors:
            raise AssertionError(
                f"homogeneous_neighbors ({self.homogeneous_neighbors}) must be <= adjacent_slice_neighbors ({self.adjacent_slice_spatial_neighbors})"
            )

        n_slices = 1 + self.n_adjacent_slices  # only focal + above slices
        if n_slices > len(self.sample_h5ad_dict):
            raise AssertionError(
                f"3D Cross slice search requires at least {n_slices} slices (1 focal + {self.n_adjacent_slices} above or {self.n_adjacent_slices} below). "
                f"Only {len(self.sample_h5ad_dict)} samples provided. Please provide more slices or reduce adjacent_slices."
            )

        logger.info(
            f"Dataset type is spatial3D, using adjacent_slices={self.n_adjacent_slices} for cross-slice search")
        logger.info("The Z axis order of slices is determined by the h5ad input order. Currently, the order is: ")
        logger.info(f"{' -> '.join(list(self.sample_h5ad_dict.keys()))}")

        # These local names appear verbatim in the ``=``-style f-string below,
        # so they must not be renamed.
        homogeneous_neighbors = self.homogeneous_neighbors
        n_adjacent_slices = self.n_adjacent_slices

        # Check if we should use fix number of homogeneous neighbors per slice
        if self.cross_slice_marker_score_strategy in [
            MarkerScoreCrossSliceStrategy.PER_SLICE_POOL,
            MarkerScoreCrossSliceStrategy.HIERARCHICAL_POOL
        ]:
            self.fix_cross_slice_homogenous_neighbors = True
            # Coerce through the enum for logging only, so a plain-string
            # strategy (equal to a member because the enum subclasses str)
            # does not fail on ``.value``. The stored attribute is untouched.
            strategy_value = MarkerScoreCrossSliceStrategy(self.cross_slice_marker_score_strategy).value
            logger.info(
                f"Using {strategy_value} strategy with fixed number of homogeneous neighbors per adjacent slice: {self.homogeneous_neighbors} per slice.")

        elif self.cross_slice_marker_score_strategy == MarkerScoreCrossSliceStrategy.GLOBAL_POOL:
            logger.info(
                "Using global_pool strategy, will select top homogeneous neighbors from all adjacent slices based on similarity scores. Each adjacent slice can contribute variable number of homogeneous neighbors.")

        logger.info(
            f"Each focal cell will select {homogeneous_neighbors * (1 + 2 * n_adjacent_slices) = } total homogeneous neighbors across {(1 + 2 * n_adjacent_slices) = } slices.")

    def _configure_scrna_seq(self):
        """Configure parameters for scRNA-seq datasets.

        Disables cross-slice search and clears the spatial key and niche
        embedding key.
        """
        self.n_adjacent_slices = 0
        self.spatial_key = None
        self.latent_representation_niche = None
|
|
420
|
+
|
|
421
|
+
def check_latent2gene_done(config: LatentToGeneConfig) -> bool:
    """
    Report whether the latent2gene step has already produced complete outputs.

    The h5ad inputs are resolved first (outside the ``try`` block, so errors
    from that step propagate to the caller). The step counts as done only when
    every expected output path exists, both memmaps pass
    ``MemMapDense.check_complete``, and the metadata file contains an
    ``'outputs'`` entry. Any exception raised during the memmap or metadata
    checks is logged as a warning and reported as "not done".
    """
    from gsMap.latent2gene.memmap_io import MemMapDense

    config._resolve_h5ad_inputs()

    expected_outputs = {
        "concatenated_latent_adata": Path(config.concatenated_latent_adata_path),
        "rank_memmap": Path(config.rank_memmap_path),
        "mean_frac": Path(config.mean_frac_path),
        "marker_scores": Path(config.marker_scores_memmap_path),
        "metadata": Path(config.latent2gene_metadata_path),
    }

    # A single missing file means the step has not finished.
    if any(not output_path.exists() for output_path in expected_outputs.values()):
        return False

    try:
        # Rank memmap first, then marker scores; stop at the first incomplete one.
        for memmap_key in ("rank_memmap", "marker_scores"):
            is_complete, _ = MemMapDense.check_complete(expected_outputs[memmap_key])
            if not is_complete:
                return False

        # NOTE(review): yaml.unsafe_load can construct arbitrary Python objects.
        # Keep it only while this metadata file is always written by gsMap
        # itself; switch to safe_load if it may come from untrusted sources.
        with open(expected_outputs["metadata"]) as f:
            metadata = yaml.unsafe_load(f)

        return 'outputs' in metadata
    except Exception as e:
        logger.warning(f"Error checking latent2gene results: {e}")
        return False
|