gsMap3D 0.1.0a1__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsMap/__init__.py +13 -0
- gsMap/__main__.py +4 -0
- gsMap/cauchy_combination_test.py +342 -0
- gsMap/cli.py +355 -0
- gsMap/config/__init__.py +72 -0
- gsMap/config/base.py +296 -0
- gsMap/config/cauchy_config.py +79 -0
- gsMap/config/dataclasses.py +235 -0
- gsMap/config/decorators.py +302 -0
- gsMap/config/find_latent_config.py +276 -0
- gsMap/config/format_sumstats_config.py +54 -0
- gsMap/config/latent2gene_config.py +461 -0
- gsMap/config/ldscore_config.py +261 -0
- gsMap/config/quick_mode_config.py +242 -0
- gsMap/config/report_config.py +81 -0
- gsMap/config/spatial_ldsc_config.py +334 -0
- gsMap/config/utils.py +286 -0
- gsMap/find_latent/__init__.py +3 -0
- gsMap/find_latent/find_latent_representation.py +312 -0
- gsMap/find_latent/gnn/distribution.py +498 -0
- gsMap/find_latent/gnn/encoder_decoder.py +186 -0
- gsMap/find_latent/gnn/gcn.py +85 -0
- gsMap/find_latent/gnn/gene_former.py +164 -0
- gsMap/find_latent/gnn/loss.py +18 -0
- gsMap/find_latent/gnn/st_model.py +125 -0
- gsMap/find_latent/gnn/train_step.py +177 -0
- gsMap/find_latent/st_process.py +781 -0
- gsMap/format_sumstats.py +446 -0
- gsMap/generate_ldscore.py +1018 -0
- gsMap/latent2gene/__init__.py +18 -0
- gsMap/latent2gene/connectivity.py +781 -0
- gsMap/latent2gene/entry_point.py +141 -0
- gsMap/latent2gene/marker_scores.py +1265 -0
- gsMap/latent2gene/memmap_io.py +766 -0
- gsMap/latent2gene/rank_calculator.py +590 -0
- gsMap/latent2gene/row_ordering.py +182 -0
- gsMap/latent2gene/row_ordering_jax.py +159 -0
- gsMap/ldscore/__init__.py +1 -0
- gsMap/ldscore/batch_construction.py +163 -0
- gsMap/ldscore/compute.py +126 -0
- gsMap/ldscore/constants.py +70 -0
- gsMap/ldscore/io.py +262 -0
- gsMap/ldscore/mapping.py +262 -0
- gsMap/ldscore/pipeline.py +615 -0
- gsMap/pipeline/quick_mode.py +134 -0
- gsMap/report/__init__.py +2 -0
- gsMap/report/diagnosis.py +375 -0
- gsMap/report/report.py +100 -0
- gsMap/report/report_data.py +1832 -0
- gsMap/report/static/js_lib/alpine.min.js +5 -0
- gsMap/report/static/js_lib/tailwindcss.js +83 -0
- gsMap/report/static/template.html +2242 -0
- gsMap/report/three_d_combine.py +312 -0
- gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
- gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
- gsMap/report/three_d_plot/three_d_plots.py +425 -0
- gsMap/report/visualize.py +1409 -0
- gsMap/setup.py +5 -0
- gsMap/spatial_ldsc/__init__.py +0 -0
- gsMap/spatial_ldsc/io.py +656 -0
- gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
- gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
- gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
- gsMap/utils/__init__.py +0 -0
- gsMap/utils/generate_r2_matrix.py +610 -0
- gsMap/utils/jackknife.py +518 -0
- gsMap/utils/manhattan_plot.py +643 -0
- gsMap/utils/regression_read.py +177 -0
- gsMap/utils/torch_utils.py +23 -0
- gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
- gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
- gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
- gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
- gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
gsMap/config/spatial_ldsc_config.py
ADDED
@@ -0,0 +1,334 @@
"""
Configuration for spatial LD score regression.
"""

import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Annotated, Literal

import typer
import yaml

from gsMap.config.base import ConfigWithAutoPaths

logger = logging.getLogger("gsMap.config")


@dataclass
class SpatialLDSCComputeConfig:
    """Compute configuration for the spatial LDSC step."""
    __display_in_quick_mode_cli__ = True

    use_gpu: Annotated[bool, typer.Option(
        "--use-gpu/--no-gpu",
        help="Use GPU for the JAX-accelerated spatial LDSC implementation"
    ), {"__display_in_quick_mode_cli__": True}] = True

    memmap_tmp_dir: Annotated[Path | None, typer.Option(
        help="Temporary directory for memory-mapped files to improve I/O performance on slow filesystems. "
             "If provided, memory maps will be copied to this directory for faster random access during computation.",
        exists=True,
        file_okay=False,
        dir_okay=True,
        resolve_path=True
    )] = None

    ldsc_read_workers: Annotated[int, typer.Option(
        help="Number of read workers",
        min=1
    )] = 10

    ldsc_compute_workers: Annotated[int, typer.Option(
        help="Number of compute workers for LDSC regression",
        min=1
    )] = 10

    spots_per_chunk_quick_mode: Annotated[int, typer.Option(
        help="Number of spots per chunk in quick mode",
        min=1
    )] = 50

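Each field above couples a dataclass attribute to a Typer option via `Annotated`, so one class drives both programmatic configuration and CLI flags. A minimal sketch of that pattern as a standalone Typer command, not gsMap's actual CLI wiring (the command name and echo output are illustrative):

    # sketch: the Annotated[..., typer.Option(...)] pattern used by the config classes
    from typing import Annotated

    import typer

    app = typer.Typer()

    @app.command()
    def run_ldsc(
        use_gpu: Annotated[bool, typer.Option(
            "--use-gpu/--no-gpu",
            help="Use GPU for the JAX-accelerated spatial LDSC implementation")] = True,
        ldsc_read_workers: Annotated[int, typer.Option(
            help="Number of read workers", min=1)] = 10,
    ):
        typer.echo(f"use_gpu={use_gpu}, ldsc_read_workers={ldsc_read_workers}")

    if __name__ == "__main__":
        app()

Invoking `run-ldsc --no-gpu --ldsc-read-workers 4` would then populate the same values the dataclass fields carry as defaults.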
@dataclass
class GWASSumstatsConfig:
    """Configuration for GWAS summary statistics."""
    trait_name: Annotated[str | None, typer.Option(
        help="Name of the trait for GWAS analysis"
    )] = None

    sumstats_file: Annotated[Path | None, typer.Option(
        help="Path to GWAS summary statistics file",
        exists=True,
        file_okay=True,
        dir_okay=False,
        resolve_path=True
    )] = None

    sumstats_config_file: Annotated[Path | None, typer.Option(
        help="Path to sumstats config file",
        exists=True,
        file_okay=True,
        dir_okay=False,
        resolve_path=True
    )] = None

    sumstats_config_dict: dict[str, Path] = field(default_factory=dict)

    @property
    def trait_name_list(self) -> list[str]:
        """Return the list of trait names to process."""
        return list(self.sumstats_config_dict.keys())

    def __post_init__(self):
        self._init_sumstats()

    def _init_sumstats(self):
        """
        Process the sumstats input options and populate sumstats_config_dict,
        mapping each trait name to its sumstats file path.

        Exactly one of sumstats_file (together with trait_name) or
        sumstats_config_file must be provided.
        """
        sumstats_config_dict = {}

        if self.sumstats_file is None and self.sumstats_config_file is None:
            raise ValueError("One of sumstats_file and sumstats_config_file must be provided.")
        if self.sumstats_file is not None and self.sumstats_config_file is not None:
            raise ValueError(
                "Only one of sumstats_file and sumstats_config_file may be provided."
            )
        if self.sumstats_file is not None and self.trait_name is None:
            raise ValueError("trait_name must be provided if sumstats_file is provided.")
        if self.sumstats_config_file is not None and self.trait_name is not None:
            raise ValueError(
                "trait_name must not be provided if sumstats_config_file is provided."
            )
        # load the sumstats config file
        if self.sumstats_config_file is not None:
            # get the directory of the config file to resolve relative paths
            config_dir = Path(self.sumstats_config_file).parent
            with open(self.sumstats_config_file) as f:
                config_loaded = yaml.load(f, Loader=yaml.FullLoader)
            for _trait_name, sumstats_file in config_loaded.items():
                s_path = Path(sumstats_file)
                if not s_path.is_absolute():
                    s_path = config_dir / s_path
                sumstats_config_dict[_trait_name] = s_path.resolve()
        # load the single sumstats file
        elif self.sumstats_file is not None:
            sumstats_config_dict[self.trait_name] = Path(self.sumstats_file)

        for sumstats_file in sumstats_config_dict.values():
            assert Path(sumstats_file).exists(), f"{sumstats_file} does not exist."

        self.sumstats_config_dict = sumstats_config_dict

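Both input modes can be exercised directly. A sketch, assuming gsMap is installed so the class imports from this module; the trait names, file names, and YAML contents are made up:

    import tempfile
    from pathlib import Path

    from gsMap.config.spatial_ldsc_config import GWASSumstatsConfig

    tmp = Path(tempfile.mkdtemp())
    (tmp / "height.sumstats.gz").touch()  # stand-ins for real sumstats files
    (tmp / "bmi.sumstats.gz").touch()

    # Multi-trait mode: a YAML mapping trait names to sumstats paths;
    # relative paths are resolved against the YAML file's own directory.
    cfg_yaml = tmp / "sumstats_config.yaml"
    cfg_yaml.write_text("Height: height.sumstats.gz\nBMI: bmi.sumstats.gz\n")
    multi = GWASSumstatsConfig(sumstats_config_file=cfg_yaml)
    print(multi.trait_name_list)  # ['Height', 'BMI']

    # Single-trait mode: one file plus an explicit trait name.
    single = GWASSumstatsConfig(trait_name="Height",
                                sumstats_file=tmp / "height.sumstats.gz")
    print(single.sumstats_config_dict)  # {'Height': .../height.sumstats.gz}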
@dataclass
class SpatialLDSCCoreConfig(GWASSumstatsConfig):
    """Core configuration for spatial LDSC."""
    w_ld_dir: Annotated[Path | None, typer.Option(
        help="Directory containing the weights files (w_ld)",
        exists=True,
        file_okay=False,
        dir_okay=True,
        resolve_path=True
    )] = None

    additional_baseline_h5ad_path_list: Annotated[list[Path], typer.Option(
        help="List of additional baseline h5ad paths",
        exists=True,
        file_okay=True,
        dir_okay=False,
        resolve_path=True
    )] = field(default_factory=list)

    chisq_max: Annotated[int | None, typer.Option(
        help="Maximum chi-square value"
    )] = None

    cell_indices_range: Annotated[tuple[int, int] | None, typer.Option(
        help="0-based range [start, end) of cell indices to process"
    )] = None

    sample_filter: Annotated[str | None, typer.Option(
        help="Filter processing to a specific sample"
    )] = None

    n_blocks: Annotated[int, typer.Option(
        help="Number of jackknife blocks",
        min=1
    )] = 200

    # spots_per_chunk_quick_mode is inherited from SpatialLDSCComputeConfig

    snp_gene_weight_adata_path: Annotated[Path | None, typer.Option(
        help="Path to the SNP-gene weight matrix (H5AD format)",
        exists=True,
        file_okay=True,
        dir_okay=False,
        resolve_path=True
    )] = None

    # use_gpu is inherited from SpatialLDSCComputeConfig

    marker_score_feather_path: Annotated[Path | None, typer.Option(
        help="Path to marker score feather file",
        exists=True,
        file_okay=True,
        dir_okay=False,
        resolve_path=True
    )] = None

    marker_score_h5ad_path: Annotated[Path | None, typer.Option(
        help="Path to marker score h5ad file",
        exists=True,
        file_okay=True,
        dir_okay=False,
        resolve_path=True
    )] = None

    marker_score_format: Annotated[Literal["memmap", "feather", "h5ad"] | None, typer.Option(
        help="Format of marker scores"
    )] = None

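`cell_indices_range` uses 0-based, half-open `[start, end)` indexing, which makes it natural to shard spots across array jobs. A small sketch of that convention (`make_ranges` is a hypothetical helper, not part of gsMap):

    def make_ranges(n_obs: int, chunk: int) -> list[tuple[int, int]]:
        """Split n_obs spots into disjoint half-open [start, end) chunks."""
        return [(s, min(s + chunk, n_obs)) for s in range(0, n_obs, chunk)]

    print(make_ranges(10, 4))  # [(0, 4), (4, 8), (8, 10)]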
@dataclass
class SpatialLDSCConfig(SpatialLDSCCoreConfig, SpatialLDSCComputeConfig, ConfigWithAutoPaths):
    """Spatial LDSC Configuration"""

    def __post_init__(self):
        super().__post_init__()

        # Import here to avoid circular imports
        from gsMap.config.utils import configure_jax_platform, get_anndata_shape

        # Configure JAX platform if use_gpu is enabled
        configure_jax_platform(self.use_gpu)

        # Auto-detect marker_score_format if not specified
        if self.marker_score_format is None:
            if self.marker_score_feather_path is not None:
                self.marker_score_format = "feather"
                logger.info("Auto-detected marker_score_format as 'feather' based on marker_score_feather_path")
            elif self.marker_score_h5ad_path is not None:
                self.marker_score_format = "h5ad"
                logger.info("Auto-detected marker_score_format as 'h5ad' based on marker_score_h5ad_path")
            else:
                self.marker_score_format = "memmap"
                logger.info("Using default marker_score_format 'memmap'")

        # Validate cell_indices_range is 0-based
        if self.cell_indices_range is not None:
            # Validate exclusivity between sample_filter and cell_indices_range
            if self.sample_filter is not None:
                raise ValueError(
                    "Only one of sample_filter or cell_indices_range can be provided, not both. "
                    "Use sample_filter to filter by sample, or cell_indices_range to process specific cell indices."
                )

            start, end = self.cell_indices_range

            # Check that indices are 0-based
            if start < 0:
                raise ValueError(f"cell_indices_range start must be >= 0, got {start}")
            if start == 1:
                logger.warning(
                    "cell_indices_range appears to be 1-based (start=1). "
                    "Please ensure indices are 0-based. Adjusting start to 0."
                )
                start = 0

            # Check that start < end
            if start >= end:
                raise ValueError(f"cell_indices_range start ({start}) must be less than end ({end})")

            # Validate against actual data shape based on marker score format
            if self.marker_score_format == "memmap":
                # For memmap format, check the concatenated latent adata
                adata_path = Path(self.workdir) / self.project_name / "latent2gene" / "concatenated_latent_adata.h5ad"
                shape = get_anndata_shape(str(adata_path))
                if shape is not None:
                    n_obs, _ = shape
                    if end > n_obs:
                        logger.warning(
                            f"cell_indices_range end ({end}) exceeds number of observations ({n_obs}). "
                            f"Setting end to {n_obs}."
                        )
                        end = n_obs
            elif self.marker_score_format == "h5ad":
                # For h5ad format, check the provided h5ad path
                adata_path = Path(self.marker_score_h5ad_path)
                assert adata_path.exists(), f"Marker score h5ad not found at {adata_path}."
                shape = get_anndata_shape(str(adata_path))
                if shape is not None:
                    n_obs, _ = shape
                    if end > n_obs:
                        logger.warning(
                            f"cell_indices_range end ({end}) exceeds number of observations ({n_obs}). "
                            f"Setting end to {n_obs}."
                        )
                        end = n_obs
            elif self.marker_score_format == "feather":
                # For feather format, validate the path exists
                feather_path = Path(self.marker_score_feather_path)
                assert feather_path.exists(), f"Marker score feather file not found at {feather_path}."

                # Use pyarrow to get the number of rows and validate end
                try:
                    import pyarrow.feather as feather
                    # Read metadata without loading full data
                    feather_table = feather.read_table(str(feather_path), memory_map=True, columns=[])
                    n_obs = feather_table.num_rows
                    if end > n_obs:
                        logger.warning(
                            f"cell_indices_range end ({end}) exceeds number of rows ({n_obs}) in feather file. "
                            f"Setting end to {n_obs}."
                        )
                        end = n_obs
                except ImportError:
                    logger.warning(
                        "pyarrow not available. Cannot validate cell_indices_range against feather file. "
                        "Install pyarrow to enable validation."
                    )
                except Exception as e:
                    logger.warning(f"Could not read feather file metadata: {e}")

            # Update cell_indices_range with validated values
            self.cell_indices_range = (start, end)
            logger.info(f"Processing cell_indices_range: [{start}, {end})")

        if self.snp_gene_weight_adata_path is None:
            raise ValueError("snp_gene_weight_adata_path must be provided.")

        # Handle w_ld_dir
        if self.w_ld_dir is None:
            w_ld_dir = Path(self.ldscore_save_dir) / "w_ld"
            if w_ld_dir.exists():
                self.w_ld_dir = w_ld_dir
                logger.info(f"Using weights directory generated in the generate_ldscore step: {self.w_ld_dir}")
            else:
                raise ValueError(
                    "No w_ld_dir provided and no weights directory found in generate_ldscore output. "
                    "Either provide --w-ld-dir or run generate_ldscore first."
                )
        else:
            logger.info(f"Using provided weights directory: {self.w_ld_dir}")

        self.show_config(SpatialLDSCConfig)


def check_spatial_ldsc_done(config: SpatialLDSCConfig, trait_name: str) -> bool:
    """Check if the spatial_ldsc step is done for a specific trait."""
    result_file = config.get_ldsc_result_file(trait_name)
    return result_file.exists()
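`SpatialLDSCConfig` mixes three dataclass bases, and Python only invokes the most-derived `__post_init__` automatically, so the `super().__post_init__()` call is what reaches `GWASSumstatsConfig._init_sumstats` up the MRO. A minimal sketch of that behavior (the class names here are illustrative):

    from dataclasses import dataclass

    @dataclass
    class Base:
        x: int = 1

        def __post_init__(self):
            print("Base validation")

    @dataclass
    class Derived(Base):
        y: int = 2

        def __post_init__(self):
            super().__post_init__()  # without this, Base's validation never runs
            print("Derived validation")

    Derived()  # prints "Base validation" then "Derived validation"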
gsMap/config/utils.py
ADDED
@@ -0,0 +1,286 @@
"""
Utility functions for gsMap configuration and data validation.
"""

import logging
from collections import OrderedDict
from pathlib import Path

import h5py
import yaml

logger = logging.getLogger("gsMap.config")


def configure_jax_platform(use_accelerator: bool = True):
    """Configure the JAX platform based on the availability of accelerators.

    Args:
        use_accelerator: If True, try to use GPU then TPU if available,
            otherwise fall back to CPU. If False, force CPU usage.

    Raises:
        ImportError: If JAX is not installed.
    """
    try:
        import os
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        logging.getLogger("jax._src.xla_bridge").setLevel(logging.ERROR)

        import jax
        from jax import config as jax_config

        if not use_accelerator:
            jax_config.update('jax_platform_name', 'cpu')
            logger.info("JAX configured to use CPU for computations (accelerators disabled)")
            return

        # Priority list for accelerators
        platforms = ['gpu', 'tpu']
        configured = False

        for platform in platforms:
            try:
                devices = jax.devices(platform)
                if len(devices) > 0:
                    jax_config.update('jax_platform_name', platform)
                    logger.info(f"JAX configured to use {platform.upper()} for computations.")
                    logger.info(f" ({len(devices)} device(s) detected). Device {devices[0]} will be used.")
                    configured = True
                    break
            except Exception:  # jax raises if no backend exists for this platform
                continue

        if not configured:
            jax_config.update('jax_platform_name', 'cpu')
            logger.info("No GPU or TPU detected, JAX configured to use CPU for computations")

    except ImportError:
        raise ImportError(
            "JAX is required but not installed. Please install JAX by running: "
            "pip install jax jaxlib (for CPU) or see the JAX documentation for GPU/TPU installation."
        ) from None

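Platform selection generally has to happen before JAX initializes its backends, which is why the config calls this helper in `__post_init__` before any computation starts. A quick check, assuming JAX and gsMap are installed (the printed device repr varies by JAX version):

    import jax

    from gsMap.config.utils import configure_jax_platform

    configure_jax_platform(use_accelerator=False)  # force CPU
    print(jax.devices())  # e.g. [CpuDevice(id=0)]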
def get_anndata_shape(h5ad_path: str):
    """Get the shape (n_obs, n_vars) of an AnnData file without loading it."""
    with h5py.File(h5ad_path, 'r') as f:
        # 1. Verify it's a valid AnnData file by checking metadata
        if f.attrs.get('encoding-type') != 'anndata':
            logger.error(f"File '{h5ad_path}' does not appear to be a valid AnnData file.")
            return None

        # 2. Determine n_obs and n_vars from the primary metadata sources
        if 'obs' not in f or 'var' not in f:
            logger.error("AnnData file is missing 'obs' or 'var' group.")
            return None

        # Get the name of the index column from attributes
        obs_index_key = f['obs'].attrs.get('_index', None)
        var_index_key = f['var'].attrs.get('_index', None)

        if not obs_index_key or obs_index_key not in f['obs']:
            logger.error("Could not determine index for 'obs'.")
            return None
        if not var_index_key or var_index_key not in f['var']:
            logger.error("Could not determine index for 'var'.")
            return None

        # The shape is the length of these index arrays. Categorical columns
        # are stored on disk as a group holding 'codes' (one entry per row)
        # and 'categories' (unique values), so the row count comes from 'codes'.
        obs_obj = f['obs'][obs_index_key]
        if isinstance(obs_obj, h5py.Group):
            obs_obj = obs_obj['codes']
        n_obs = obs_obj.shape[0]

        var_obj = f['var'][var_index_key]
        if isinstance(var_obj, h5py.Group):
            var_obj = var_obj['codes']
        n_vars = var_obj.shape[0]

        return n_obs, n_vars

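A quick round-trip, assuming `anndata` is installed; the file name is arbitrary:

    import anndata as ad
    import numpy as np

    from gsMap.config.utils import get_anndata_shape

    adata = ad.AnnData(X=np.zeros((5, 3), dtype=np.float32))
    adata.write_h5ad("tiny.h5ad")
    print(get_anndata_shape("tiny.h5ad"))  # (5, 3)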
def inspect_h5ad_structure(filename):
    """
    Inspect the structure of an h5ad file without loading data.

    Returns dict with keys present in each slot.
    """
    structure = {}

    with h5py.File(filename, 'r') as f:
        # Check main slots
        slots = ['obs', 'var', 'obsm', 'varm', 'obsp', 'varp', 'uns', 'layers', 'X', 'raw']

        for slot in slots:
            if slot in f:
                if slot in ['obsm', 'varm', 'obsp', 'varp', 'layers', 'uns']:
                    # These are groups containing multiple keys
                    structure[slot] = list(f[slot].keys())
                elif slot in ['obs', 'var']:
                    # These are dataframes - get column names
                    if 'column-order' in f[slot].attrs:
                        structure[slot] = list(f[slot].attrs['column-order'])
                    else:
                        structure[slot] = list(f[slot].keys())
                else:
                    # X, raw - just note they exist
                    structure[slot] = True

    return structure

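Continuing with a small file, assuming `anndata` is available (the exact keys reported depend on the anndata version):

    import anndata as ad
    import numpy as np

    from gsMap.config.utils import inspect_h5ad_structure

    adata = ad.AnnData(X=np.zeros((5, 3), dtype=np.float32))
    adata.obsm["spatial"] = np.random.rand(5, 2)
    adata.write_h5ad("tiny.h5ad")
    print(inspect_h5ad_structure("tiny.h5ad"))
    # e.g. {'obs': [], 'var': [], 'obsm': ['spatial'], ..., 'X': True}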
def validate_h5ad_structure(sample_h5ad_dict, required_fields, optional_fields=None):
    """
    Validate h5ad files have required structure.

    Args:
        sample_h5ad_dict: OrderedDict of {sample_name: h5ad_path}
        required_fields: Dict of {field_name: (slot, field_key, error_msg_template)}
            e.g., {'spatial': ('obsm', 'spatial', 'Spatial key')}
        optional_fields: Dict of {field_name: (slot, field_key)} for fields to warn about

    Returns:
        None, raises ValueError if required fields are missing
    """
    for sample_name, h5ad_path in sample_h5ad_dict.items():
        if not h5ad_path.exists():
            raise FileNotFoundError(f"H5AD file not found for sample '{sample_name}': {h5ad_path}")

        # Inspect h5ad structure
        structure = inspect_h5ad_structure(h5ad_path)

        # Check required fields
        for field_name, (slot, field_key, error_msg) in required_fields.items():
            if field_key is None:  # Skip if field not specified
                continue

            # Special handling for data_layer
            if field_name == 'data_layer' and field_key != 'X':
                if 'layers' not in structure or field_key not in structure.get('layers', []):
                    raise ValueError(
                        f"Data layer '{field_key}' not found in layers for sample '{sample_name}'. "
                        f"Available layers: {structure.get('layers', [])}"
                    )
            elif field_name == 'data_layer' and field_key == 'X':
                if 'X' not in structure:
                    raise ValueError(f"X matrix not found in h5ad file for sample '{sample_name}'")
            else:
                # Standard validation for obsm, obs, etc.
                if slot not in structure or field_key not in structure.get(slot, []):
                    available = structure.get(slot, [])
                    raise ValueError(
                        f"{error_msg} '{field_key}' not found in {slot} for sample '{sample_name}'. "
                        f"Available keys in {slot}: {available}"
                    )

        # Check optional fields (warn only)
        if optional_fields:
            for field_name, (slot, field_key) in optional_fields.items():
                if field_key is None:  # Skip if field not specified
                    continue

                if slot not in structure or field_key not in structure.get(slot, []):
                    available = structure.get(slot, [])
                    logger.warning(
                        f"Optional field '{field_key}' not found in {slot} for sample '{sample_name}'. "
                        f"Available keys in {slot}: {available}"
                    )

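A sketch of how a caller might describe its requirements against the `tiny.h5ad` written above; the field spec here is illustrative, not gsMap's actual configuration:

    from collections import OrderedDict
    from pathlib import Path

    from gsMap.config.utils import validate_h5ad_structure

    sample_h5ad_dict = OrderedDict({"sample1": Path("tiny.h5ad")})
    required_fields = {
        "spatial": ("obsm", "spatial", "Spatial key"),
        "data_layer": ("layers", "X", "Data layer"),  # 'X' means use adata.X itself
    }
    optional_fields = {"annotation": ("obs", "cell_type")}  # missing -> warning only

    validate_h5ad_structure(sample_h5ad_dict, required_fields, optional_fields)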
def process_h5ad_inputs(config, input_options):
    """
    Process h5ad input options and create sample_h5ad_dict.

    Args:
        config: Configuration object with h5ad input fields
        input_options: Dict mapping option names to (field_name, processing_type)
            e.g., {'h5ad_yaml': ('h5ad_yaml', 'yaml'),
                   'h5ad': ('h5ad', 'list'),
                   'h5ad_list_file': ('h5ad_list_file', 'file')}

    Returns:
        OrderedDict of {sample_name: h5ad_path}
    """
    if config.sample_h5ad_dict is not None and len(config.sample_h5ad_dict) > 0:
        logger.info("Using pre-defined sample_h5ad_dict from configuration")
        return OrderedDict(config.sample_h5ad_dict)

    sample_h5ad_dict = OrderedDict()

    # Check which options are provided
    options_provided = []
    for option_name, (field_name, _) in input_options.items():
        if hasattr(config, field_name) and getattr(config, field_name):
            options_provided.append(option_name)

    # Ensure at most one option is provided
    if len(options_provided) > 1:
        raise AssertionError(
            f"At most one input option can be provided. Got {len(options_provided)}: "
            f"{', '.join(options_provided)}. "
            f"Please provide only one of: {', '.join(input_options.keys())}"
        )

    # Process the provided input option
    for option_name, (field_name, processing_type) in input_options.items():
        field_value = getattr(config, field_name, None)
        if not field_value:
            continue

        if processing_type == 'yaml':
            logger.info(f"Using {option_name}: {field_value}")
            yaml_file_path = Path(field_value)
            yaml_parent_dir = yaml_file_path.parent
            with open(field_value) as f:
                h5ad_data = yaml.safe_load(f)
            for sample_name, h5ad_path in h5ad_data.items():
                h5ad_path = Path(h5ad_path)
                # Resolve relative paths relative to the yaml file location
                if not h5ad_path.is_absolute():
                    h5ad_path = yaml_parent_dir / h5ad_path
                sample_h5ad_dict[sample_name] = h5ad_path

        elif processing_type == 'list':
            logger.info(f"Using {option_name} with {len(field_value)} files")
            for h5ad_path in field_value:
                h5ad_path = Path(h5ad_path)
                sample_name = h5ad_path.stem
                if sample_name in sample_h5ad_dict:
                    logger.warning(f"Duplicate sample name: {sample_name}, will be overwritten")
                sample_h5ad_dict[sample_name] = h5ad_path

        elif processing_type == 'file':
            logger.info(f"Using {option_name}: {field_value}")
            list_file_path = Path(field_value)
            list_file_parent_dir = list_file_path.parent
            with open(field_value) as f:
                for line in f:
                    line = line.strip()
                    if line:  # Skip empty lines
                        h5ad_path = Path(line)
                        # Resolve relative paths relative to the list file location
                        if not h5ad_path.is_absolute():
                            h5ad_path = list_file_parent_dir / h5ad_path
                        sample_name = h5ad_path.stem
                        if sample_name in sample_h5ad_dict:
                            logger.warning(f"Duplicate sample name: {sample_name}, will be overwritten")
                        sample_h5ad_dict[sample_name] = h5ad_path

        break  # only the first provided option is processed

    return sample_h5ad_dict

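A sketch of the 'list' input style with a stand-in config object; the option and field names follow the docstring's example, and `DummyConfig` is hypothetical:

    from dataclasses import dataclass, field
    from pathlib import Path

    from gsMap.config.utils import process_h5ad_inputs

    @dataclass
    class DummyConfig:
        sample_h5ad_dict: dict = field(default_factory=dict)
        h5ad_yaml: Path | None = None
        h5ad: list[Path] = field(default_factory=list)
        h5ad_list_file: Path | None = None

    cfg = DummyConfig(h5ad=[Path("sampleA.h5ad"), Path("sampleB.h5ad")])
    input_options = {
        "h5ad_yaml": ("h5ad_yaml", "yaml"),
        "h5ad": ("h5ad", "list"),
        "h5ad_list_file": ("h5ad_list_file", "file"),
    }
    print(process_h5ad_inputs(cfg, input_options))
    # OrderedDict([('sampleA', PosixPath('sampleA.h5ad')), ('sampleB', ...)])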
def verify_homolog_file_format(config):
    if config.homolog_file is not None:
        logger.info(
            f"User provided homolog file to map gene names to human: {config.homolog_file}"
        )
        # check the format of the homolog file
        with open(config.homolog_file) as f:
            first_line = f.readline().strip()
            _n_col = len(first_line.split())
            if _n_col != 2:
                raise ValueError(
                    f"Invalid homolog file format. Expected 2 columns: the first column should be "
                    f"the other species' gene name and the second the human gene name. "
                    f"Got {_n_col} columns in the first line."
                )
            else:
                first_col_name, second_col_name = first_line.split()
                config.species = first_col_name
                logger.info(
                    f"Homolog file provided and will map gene names from column1:{first_col_name} to column2:{second_col_name}"
                )
    else:
        logger.info("No homolog file provided. Running in human mode.")