gsMap3D 0.1.0a1 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. gsMap/__init__.py +13 -0
  2. gsMap/__main__.py +4 -0
  3. gsMap/cauchy_combination_test.py +342 -0
  4. gsMap/cli.py +355 -0
  5. gsMap/config/__init__.py +72 -0
  6. gsMap/config/base.py +296 -0
  7. gsMap/config/cauchy_config.py +79 -0
  8. gsMap/config/dataclasses.py +235 -0
  9. gsMap/config/decorators.py +302 -0
  10. gsMap/config/find_latent_config.py +276 -0
  11. gsMap/config/format_sumstats_config.py +54 -0
  12. gsMap/config/latent2gene_config.py +461 -0
  13. gsMap/config/ldscore_config.py +261 -0
  14. gsMap/config/quick_mode_config.py +242 -0
  15. gsMap/config/report_config.py +81 -0
  16. gsMap/config/spatial_ldsc_config.py +334 -0
  17. gsMap/config/utils.py +286 -0
  18. gsMap/find_latent/__init__.py +3 -0
  19. gsMap/find_latent/find_latent_representation.py +312 -0
  20. gsMap/find_latent/gnn/distribution.py +498 -0
  21. gsMap/find_latent/gnn/encoder_decoder.py +186 -0
  22. gsMap/find_latent/gnn/gcn.py +85 -0
  23. gsMap/find_latent/gnn/gene_former.py +164 -0
  24. gsMap/find_latent/gnn/loss.py +18 -0
  25. gsMap/find_latent/gnn/st_model.py +125 -0
  26. gsMap/find_latent/gnn/train_step.py +177 -0
  27. gsMap/find_latent/st_process.py +781 -0
  28. gsMap/format_sumstats.py +446 -0
  29. gsMap/generate_ldscore.py +1018 -0
  30. gsMap/latent2gene/__init__.py +18 -0
  31. gsMap/latent2gene/connectivity.py +781 -0
  32. gsMap/latent2gene/entry_point.py +141 -0
  33. gsMap/latent2gene/marker_scores.py +1265 -0
  34. gsMap/latent2gene/memmap_io.py +766 -0
  35. gsMap/latent2gene/rank_calculator.py +590 -0
  36. gsMap/latent2gene/row_ordering.py +182 -0
  37. gsMap/latent2gene/row_ordering_jax.py +159 -0
  38. gsMap/ldscore/__init__.py +1 -0
  39. gsMap/ldscore/batch_construction.py +163 -0
  40. gsMap/ldscore/compute.py +126 -0
  41. gsMap/ldscore/constants.py +70 -0
  42. gsMap/ldscore/io.py +262 -0
  43. gsMap/ldscore/mapping.py +262 -0
  44. gsMap/ldscore/pipeline.py +615 -0
  45. gsMap/pipeline/quick_mode.py +134 -0
  46. gsMap/report/__init__.py +2 -0
  47. gsMap/report/diagnosis.py +375 -0
  48. gsMap/report/report.py +100 -0
  49. gsMap/report/report_data.py +1832 -0
  50. gsMap/report/static/js_lib/alpine.min.js +5 -0
  51. gsMap/report/static/js_lib/tailwindcss.js +83 -0
  52. gsMap/report/static/template.html +2242 -0
  53. gsMap/report/three_d_combine.py +312 -0
  54. gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
  55. gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
  56. gsMap/report/three_d_plot/three_d_plots.py +425 -0
  57. gsMap/report/visualize.py +1409 -0
  58. gsMap/setup.py +5 -0
  59. gsMap/spatial_ldsc/__init__.py +0 -0
  60. gsMap/spatial_ldsc/io.py +656 -0
  61. gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
  62. gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
  63. gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
  64. gsMap/utils/__init__.py +0 -0
  65. gsMap/utils/generate_r2_matrix.py +610 -0
  66. gsMap/utils/jackknife.py +518 -0
  67. gsMap/utils/manhattan_plot.py +643 -0
  68. gsMap/utils/regression_read.py +177 -0
  69. gsMap/utils/torch_utils.py +23 -0
  70. gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
  71. gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
  72. gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
  73. gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
  74. gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
gsMap/config/spatial_ldsc_config.py ADDED
@@ -0,0 +1,334 @@
+ """
+ Configuration for spatial LD score regression.
+ """
+
+ import logging
+ from dataclasses import dataclass, field
+ from pathlib import Path
+ from typing import Annotated, Literal
+
+ import typer
+ import yaml
+
+ from gsMap.config.base import ConfigWithAutoPaths
+
+ logger = logging.getLogger("gsMap.config")
+
+
+ @dataclass
+ class SpatialLDSCComputeConfig:
+     """Compute configuration for spatial LDSC step."""
+     __display_in_quick_mode_cli__ = True
+
+     use_gpu: Annotated[bool, typer.Option(
+         "--use-gpu/--no-gpu",
+         help="Use GPU for JAX-accelerated spatial LDSC implementation"
+     ), {"__display_in_quick_mode_cli__": True}] = True
+
+     memmap_tmp_dir: Annotated[Path | None, typer.Option(
+         help="Temporary directory for memory-mapped files to improve I/O performance on slow filesystems. "
+              "If provided, memory maps will be copied to this directory for faster random access during computation.",
+         exists=True,
+         file_okay=False,
+         dir_okay=True,
+         resolve_path=True
+     )] = None
+
+     ldsc_read_workers: Annotated[int, typer.Option(
+         help="Number of read workers",
+         min=1
+     )] = 10
+
+     ldsc_compute_workers: Annotated[int, typer.Option(
+         help="Number of compute workers for LDSC regression",
+         min=1
+     )] = 10
+
+     spots_per_chunk_quick_mode: Annotated[int, typer.Option(
+         help="Number of spots per chunk in quick mode",
+         min=1
+     )] = 50
+
+
+ @dataclass
+ class GWASSumstatsConfig:
+     """Configuration for GWAS summary statistics."""
+     trait_name: Annotated[str | None, typer.Option(
+         help="Name of the trait for GWAS analysis"
+     )] = None
+
+     sumstats_file: Annotated[Path | None, typer.Option(
+         help="Path to GWAS summary statistics file",
+         exists=True,
+         file_okay=True,
+         dir_okay=False,
+         resolve_path=True
+     )] = None
+
+     sumstats_config_file: Annotated[Path | None, typer.Option(
+         help="Path to sumstats config file",
+         exists=True,
+         file_okay=True,
+         dir_okay=False,
+         resolve_path=True
+     )] = None
+
+     sumstats_config_dict: dict[str, Path] = field(default_factory=dict)
+
+     @property
+     def trait_name_list(self) -> list[str]:
+         """Return the list of trait names to process."""
+         return list(self.sumstats_config_dict.keys())
+
+     def __post_init__(self):
+         self._init_sumstats()
+
+     def _init_sumstats(self):
+         """
+         Validate the sumstats input options and populate self.sumstats_config_dict,
+         mapping each trait name to its sumstats file path.
+         """
+         sumstats_config_dict = {}
+
+         if self.sumstats_file is None and self.sumstats_config_file is None:
+             raise ValueError("One of sumstats_file and sumstats_config_file must be provided.")
+         if self.sumstats_file is not None and self.sumstats_config_file is not None:
+             raise ValueError(
+                 "Only one of sumstats_file and sumstats_config_file may be provided."
+             )
+         if self.sumstats_file is not None and self.trait_name is None:
+             raise ValueError("trait_name must be provided if sumstats_file is provided.")
+         if self.sumstats_config_file is not None and self.trait_name is not None:
+             raise ValueError(
+                 "trait_name must not be provided if sumstats_config_file is provided."
+             )
+         # load the sumstats config file
+         if self.sumstats_config_file is not None:
+             # get the directory of the config file to resolve relative paths
+             config_dir = Path(self.sumstats_config_file).parent
+             with open(self.sumstats_config_file) as f:
+                 config_loaded = yaml.load(f, Loader=yaml.FullLoader)
+                 for _trait_name, sumstats_file in config_loaded.items():
+                     s_path = Path(sumstats_file)
+                     if not s_path.is_absolute():
+                         s_path = config_dir / s_path
+                     sumstats_config_dict[_trait_name] = s_path.resolve()
+         # load the single sumstats file
+         elif self.sumstats_file is not None:
+             sumstats_config_dict[self.trait_name] = Path(self.sumstats_file)
+         else:
+             raise ValueError("One of sumstats_file and sumstats_config_file must be provided.")
+
+         for sumstats_file in sumstats_config_dict.values():
+             assert Path(sumstats_file).exists(), f"{sumstats_file} does not exist."
+
+         self.sumstats_config_dict = sumstats_config_dict
+
+
+ @dataclass
+ class SpatialLDSCCoreConfig(GWASSumstatsConfig):
+     """Core configuration for spatial LDSC."""
+     w_ld_dir: Annotated[Path | None, typer.Option(
+         help="Directory containing the weights files (w_ld)",
+         exists=True,
+         file_okay=False,
+         dir_okay=True,
+         resolve_path=True
+     )] = None
+
+     additional_baseline_h5ad_path_list: Annotated[list[Path], typer.Option(
+         help="List of additional baseline h5ad paths",
+         exists=True,
+         file_okay=True,
+         dir_okay=False,
+         resolve_path=True
+     )] = field(default_factory=list)
+
+     chisq_max: Annotated[int | None, typer.Option(
+         help="Maximum chi-square value"
+     )] = None
+
+     cell_indices_range: Annotated[tuple[int, int] | None, typer.Option(
+         help="0-based range [start, end) of cell indices to process"
+     )] = None
+
+     sample_filter: Annotated[str | None, typer.Option(
+         help="Filter processing to a specific sample"
+     )] = None
+
+     n_blocks: Annotated[int, typer.Option(
+         help="Number of jackknife blocks",
+         min=1
+     )] = 200
+
+     # spots_per_chunk_quick_mode is inherited from SpatialLDSCComputeConfig
+
+     snp_gene_weight_adata_path: Annotated[Path | None, typer.Option(
+         help="Path to the SNP-gene weight matrix (H5AD format)",
+         exists=True,
+         file_okay=True,
+         dir_okay=False,
+         resolve_path=True
+     )] = None
+
+     # use_gpu is inherited from SpatialLDSCComputeConfig
+
+     marker_score_feather_path: Annotated[Path | None, typer.Option(
+         help="Path to marker score feather file",
+         exists=True,
+         file_okay=True,
+         dir_okay=False,
+         resolve_path=True
+     )] = None
+
+     marker_score_h5ad_path: Annotated[Path | None, typer.Option(
+         help="Path to marker score h5ad file",
+         exists=True,
+         file_okay=True,
+         dir_okay=False,
+         resolve_path=True
+     )] = None
+
+     marker_score_format: Annotated[Literal["memmap", "feather", "h5ad"] | None, typer.Option(
+         help="Format of marker scores"
+     )] = None
+
+
+ @dataclass
+ class SpatialLDSCConfig(SpatialLDSCCoreConfig, SpatialLDSCComputeConfig, ConfigWithAutoPaths):
+     """Spatial LDSC Configuration"""
+
+     def __post_init__(self):
+         super().__post_init__()
+
+         # Import here to avoid circular imports
+         from gsMap.config.utils import configure_jax_platform, get_anndata_shape
+
+         # Configure JAX platform if use_gpu is enabled
+         configure_jax_platform(self.use_gpu)
+
+         # Auto-detect marker_score_format if not specified
+         if self.marker_score_format is None:
+             if self.marker_score_feather_path is not None:
+                 self.marker_score_format = "feather"
+                 logger.info("Auto-detected marker_score_format as 'feather' based on marker_score_feather_path")
+             elif self.marker_score_h5ad_path is not None:
+                 self.marker_score_format = "h5ad"
+                 logger.info("Auto-detected marker_score_format as 'h5ad' based on marker_score_h5ad_path")
+             else:
+                 self.marker_score_format = "memmap"
+                 logger.info("Using default marker_score_format 'memmap'")
+
+         # Validate cell_indices_range is 0-based
+         if self.cell_indices_range is not None:
+             # Validate exclusivity between sample_filter and cell_indices_range
+             if self.sample_filter is not None:
+                 raise ValueError(
+                     "Only one of sample_filter or cell_indices_range can be provided, not both. "
+                     "Use sample_filter to filter by sample, or cell_indices_range to process specific cell indices."
+                 )
+
+             start, end = self.cell_indices_range
+
+             # Check that indices are 0-based
+             if start < 0:
+                 raise ValueError(f"cell_indices_range start must be >= 0, got {start}")
+             if start == 1:
+                 logger.warning(
+                     "cell_indices_range appears to be 1-based (start=1). "
+                     "Please ensure indices are 0-based. Adjusting start to 0."
+                 )
+                 start = 0
+
+             # Check that start < end
+             if start >= end:
+                 raise ValueError(f"cell_indices_range start ({start}) must be less than end ({end})")
+
+             # Validate against actual data shape based on marker score format
+             if self.marker_score_format == "memmap":
+                 # For memmap format, check the concatenated latent adata
+                 adata_path = Path(self.workdir) / self.project_name / "latent2gene" / "concatenated_latent_adata.h5ad"
+                 shape = get_anndata_shape(str(adata_path))
+                 if shape is not None:
+                     n_obs, _ = shape
+                     if end > n_obs:
+                         logger.warning(
+                             f"cell_indices_range end ({end}) exceeds number of observations ({n_obs}). "
+                             f"Setting end to {n_obs}."
+                         )
+                         end = n_obs
+             elif self.marker_score_format == "h5ad":
+                 # For h5ad format, check the provided h5ad path
+                 adata_path = Path(self.marker_score_h5ad_path)
+                 assert adata_path.exists(), f"Marker score h5ad not found at {adata_path}."
+                 shape = get_anndata_shape(str(adata_path))
+                 if shape is not None:
+                     n_obs, _ = shape
+                     if end > n_obs:
+                         logger.warning(
+                             f"cell_indices_range end ({end}) exceeds number of observations ({n_obs}). "
+                             f"Setting end to {n_obs}."
+                         )
+                         end = n_obs
+             elif self.marker_score_format == "feather":
+                 # For feather format, validate the path exists
+                 feather_path = Path(self.marker_score_feather_path)
+                 assert feather_path.exists(), f"Marker score feather file not found at {feather_path}."
+
+                 # Use pyarrow to get the number of rows and validate end
+                 try:
+                     import pyarrow.feather as feather
+                     # Read metadata without loading full data
+                     feather_table = feather.read_table(str(feather_path), memory_map=True, columns=[])
+                     n_obs = feather_table.num_rows
+                     if end > n_obs:
+                         logger.warning(
+                             f"cell_indices_range end ({end}) exceeds number of rows ({n_obs}) in feather file. "
+                             f"Setting end to {n_obs}."
+                         )
+                         end = n_obs
+                 except ImportError:
+                     logger.warning(
+                         "pyarrow not available. Cannot validate cell_indices_range against feather file. "
+                         "Install pyarrow to enable validation."
+                     )
+                 except Exception as e:
+                     logger.warning(f"Could not read feather file metadata: {e}")
+
+             # Update cell_indices_range with validated values
+             self.cell_indices_range = (start, end)
+             logger.info(f"Processing cell_indices_range: [{start}, {end})")
+
+         if self.snp_gene_weight_adata_path is None:
+             raise ValueError("snp_gene_weight_adata_path must be provided.")
+
+         # Handle w_ld_dir
+         if self.w_ld_dir is None:
+             w_ld_dir = Path(self.ldscore_save_dir) / "w_ld"
+             if w_ld_dir.exists():
+                 self.w_ld_dir = w_ld_dir
+                 logger.info(f"Using weights directory generated in the generate_ldscore step: {self.w_ld_dir}")
+             else:
+                 raise ValueError(
+                     "No w_ld_dir provided and no weights directory found in generate_ldscore output. "
+                     "Either provide --w-ld-dir or run generate_ldscore first."
+                 )
+         else:
+             logger.info(f"Using provided weights directory: {self.w_ld_dir}")
+
+         self.show_config(SpatialLDSCConfig)
+
+
+ def check_spatial_ldsc_done(config: SpatialLDSCConfig, trait_name: str) -> bool:
+     """
+     Check if the spatial_ldsc step is done for a specific trait.
+     """
+     result_file = config.get_ldsc_result_file(trait_name)
+     return result_file.exists()
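Note on usage: GWASSumstatsConfig above accepts either a single sumstats_file paired with trait_name, or a sumstats_config_file YAML that maps trait names to sumstats paths, with relative paths resolved against the YAML's own directory. A minimal sketch, assuming illustrative trait names and file paths:

    # The YAML passed as sumstats_config_file maps trait name -> sumstats path, e.g.:
    #
    #   SCZ: sumstats/SCZ.sumstats.gz          # relative: resolved against the YAML's directory
    #   HEIGHT: /data/gwas/HEIGHT.sumstats.gz  # absolute: used as-is
    #
    from pathlib import Path
    from gsMap.config.spatial_ldsc_config import GWASSumstatsConfig

    cfg = GWASSumstatsConfig(sumstats_config_file=Path("sumstats_config.yaml"))
    print(cfg.trait_name_list)       # ['SCZ', 'HEIGHT']
    print(cfg.sumstats_config_dict)  # {'SCZ': <absolute Path>, 'HEIGHT': <absolute Path>}

Every resolved path is asserted to exist during __post_init__, so the sketch presumes the files are present.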
gsMap/config/utils.py ADDED
@@ -0,0 +1,286 @@
+ """
+ Utility functions for gsMap configuration and data validation.
+ """
+
+ import logging
+ from collections import OrderedDict
+ from pathlib import Path
+
+ import h5py
+ import yaml
+
+ logger = logging.getLogger("gsMap.config")
+
+
+ def configure_jax_platform(use_accelerator: bool = True):
+     """Configure JAX platform based on availability of accelerators.
+
+     Args:
+         use_accelerator: If True, try to use GPU then TPU if available,
+             otherwise fall back to CPU. If False, force CPU usage.
+
+     Raises:
+         ImportError: If JAX is not installed.
+     """
+     try:
+         import os
+         os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+         logging.getLogger("jax._src.xla_bridge").setLevel(logging.ERROR)
+
+         import jax
+         from jax import config as jax_config
+
+         if not use_accelerator:
+             jax_config.update('jax_platform_name', 'cpu')
+             logger.info("JAX configured to use CPU for computations (accelerators disabled)")
+             return
+
+         # Priority list for accelerators
+         platforms = ['gpu', 'tpu']
+         configured = False
+
+         for platform in platforms:
+             try:
+                 devices = jax.devices(platform)
+                 if len(devices) > 0:
+                     jax_config.update('jax_platform_name', platform)
+                     logger.info(f"JAX configured to use {platform.upper()} for computations.")
+                     logger.info(f"  ({len(devices)} device(s) detected). Device {devices[0]} will be used.")
+                     configured = True
+                     break
+             except Exception:
+                 continue
+
+         if not configured:
+             jax_config.update('jax_platform_name', 'cpu')
+             logger.info("No GPU or TPU detected, JAX configured to use CPU for computations")
+
+     except ImportError:
+         raise ImportError(
+             "JAX is required but not installed. Please install JAX by running: "
+             "pip install jax jaxlib (for CPU) or see JAX documentation for GPU/TPU installation."
+         ) from None
+
+
+ def get_anndata_shape(h5ad_path: str):
+     """Get the shape (n_obs, n_vars) of an AnnData file without loading it."""
+     with h5py.File(h5ad_path, 'r') as f:
+         # 1. Verify it's a valid AnnData file by checking metadata
+         if f.attrs.get('encoding-type') != 'anndata':
+             logger.error(f"File '{h5ad_path}' does not appear to be a valid AnnData file.")
+             return None
+
+         # 2. Determine n_obs and n_vars from the primary metadata sources
+         if 'obs' not in f or 'var' not in f:
+             logger.error("AnnData file is missing 'obs' or 'var' group.")
+             return None
+
+         # Get the name of the index column from attributes
+         obs_index_key = f['obs'].attrs.get('_index', None)
+         var_index_key = f['var'].attrs.get('_index', None)
+
+         if not obs_index_key or obs_index_key not in f['obs']:
+             logger.error("Could not determine index for 'obs'.")
+             return None
+         if not var_index_key or var_index_key not in f['var']:
+             logger.error("Could not determine index for 'var'.")
+             return None
+
+         # The shape is the length of these index arrays. A categorical index is
+         # stored as a group whose per-row array is 'codes' ('categories' holds
+         # only the unique values), so use 'codes' for the row count.
+         obs_obj = f['obs'][obs_index_key]
+         if isinstance(obs_obj, h5py.Group):
+             obs_obj = obs_obj['codes']
+         n_obs = obs_obj.shape[0]
+
+         var_obj = f['var'][var_index_key]
+         if isinstance(var_obj, h5py.Group):
+             var_obj = var_obj['codes']
+         n_vars = var_obj.shape[0]
+
+         return n_obs, n_vars
+
+
+ def inspect_h5ad_structure(filename):
+     """
+     Inspect the structure of an h5ad file without loading data.
+
+     Returns dict with keys present in each slot.
+     """
+     structure = {}
+
+     with h5py.File(filename, 'r') as f:
+         # Check main slots
+         slots = ['obs', 'var', 'obsm', 'varm', 'obsp', 'varp', 'uns', 'layers', 'X', 'raw']
+
+         for slot in slots:
+             if slot in f:
+                 if slot in ['obsm', 'varm', 'obsp', 'varp', 'layers', 'uns']:
+                     # These are groups containing multiple keys
+                     structure[slot] = list(f[slot].keys())
+                 elif slot in ['obs', 'var']:
+                     # These are dataframes - get column names
+                     if 'column-order' in f[slot].attrs:
+                         structure[slot] = list(f[slot].attrs['column-order'])
+                     else:
+                         structure[slot] = list(f[slot].keys())
+                 else:
+                     # X, raw - just note they exist
+                     structure[slot] = True
+
+     return structure
+
+
+ def validate_h5ad_structure(sample_h5ad_dict, required_fields, optional_fields=None):
+     """
+     Validate h5ad files have required structure.
+
+     Args:
+         sample_h5ad_dict: OrderedDict of {sample_name: h5ad_path}
+         required_fields: Dict of {field_name: (slot, field_key, error_msg_template)}
+             e.g., {'spatial': ('obsm', 'spatial', 'Spatial key')}
+         optional_fields: Dict of {field_name: (slot, field_key)} for fields to warn about
+
+     Returns:
+         None, raises ValueError if required fields are missing
+     """
+     for sample_name, h5ad_path in sample_h5ad_dict.items():
+         if not h5ad_path.exists():
+             raise FileNotFoundError(f"H5AD file not found for sample '{sample_name}': {h5ad_path}")
+
+         # Inspect h5ad structure
+         structure = inspect_h5ad_structure(h5ad_path)
+
+         # Check required fields
+         for field_name, (slot, field_key, error_msg) in required_fields.items():
+             if field_key is None:  # Skip if field not specified
+                 continue
+
+             # Special handling for data_layer
+             if field_name == 'data_layer' and field_key != 'X':
+                 if 'layers' not in structure or field_key not in structure.get('layers', []):
+                     raise ValueError(
+                         f"Data layer '{field_key}' not found in layers for sample '{sample_name}'. "
+                         f"Available layers: {structure.get('layers', [])}"
+                     )
+             elif field_name == 'data_layer' and field_key == 'X':
+                 if 'X' not in structure:
+                     raise ValueError(f"X matrix not found in h5ad file for sample '{sample_name}'")
+             else:
+                 # Standard validation for obsm, obs, etc.
+                 if slot not in structure or field_key not in structure.get(slot, []):
+                     available = structure.get(slot, [])
+                     raise ValueError(
+                         f"{error_msg} '{field_key}' not found in {slot} for sample '{sample_name}'. "
+                         f"Available keys in {slot}: {available}"
+                     )
+
+         # Check optional fields (warn only)
+         if optional_fields:
+             for field_name, (slot, field_key) in optional_fields.items():
+                 if field_key is None:  # Skip if field not specified
+                     continue
+
+                 if slot not in structure or field_key not in structure.get(slot, []):
+                     available = structure.get(slot, [])
+                     logger.warning(
+                         f"Optional field '{field_key}' not found in {slot} for sample '{sample_name}'. "
+                         f"Available keys in {slot}: {available}"
+                     )
+
+
+ def process_h5ad_inputs(config, input_options):
+     """
+     Process h5ad input options and create sample_h5ad_dict.
+
+     Args:
+         config: Configuration object with h5ad input fields
+         input_options: Dict mapping option names to (field_name, processing_type)
+             e.g., {'h5ad_yaml': ('h5ad_yaml', 'yaml'),
+                    'h5ad': ('h5ad', 'list'),
+                    'h5ad_list_file': ('h5ad_list_file', 'file')}
+
+     Returns:
+         OrderedDict of {sample_name: h5ad_path}
+     """
+     if config.sample_h5ad_dict is not None and len(config.sample_h5ad_dict) > 0:
+         logger.info("Using pre-defined sample_h5ad_dict from configuration")
+         return OrderedDict(config.sample_h5ad_dict)
+
+     sample_h5ad_dict = OrderedDict()
+
+     # Check which options are provided
+     options_provided = []
+     for option_name, (field_name, _) in input_options.items():
+         if hasattr(config, field_name) and getattr(config, field_name):
+             options_provided.append(option_name)
+
+     # Ensure at most one option is provided
+     if len(options_provided) > 1:
+         raise AssertionError(
+             f"At most one input option can be provided. "
+             f"Got {len(options_provided)}: {', '.join(options_provided)}. "
+             f"Please provide only one of: {', '.join(input_options.keys())}"
+         )
+
+     # Process the provided input option
+     for option_name, (field_name, processing_type) in input_options.items():
+         field_value = getattr(config, field_name, None)
+         if not field_value:
+             continue
+
+         if processing_type == 'yaml':
+             logger.info(f"Using {option_name}: {field_value}")
+             yaml_file_path = Path(field_value)
+             yaml_parent_dir = yaml_file_path.parent
+             with open(field_value) as f:
+                 h5ad_data = yaml.safe_load(f)
+                 for sample_name, h5ad_path in h5ad_data.items():
+                     h5ad_path = Path(h5ad_path)
+                     # Resolve relative paths relative to yaml file location
+                     if not h5ad_path.is_absolute():
+                         h5ad_path = yaml_parent_dir / h5ad_path
+                     sample_h5ad_dict[sample_name] = h5ad_path
+
+         elif processing_type == 'list':
+             logger.info(f"Using {option_name} with {len(field_value)} files")
+             for h5ad_path in field_value:
+                 h5ad_path = Path(h5ad_path)
+                 sample_name = h5ad_path.stem
+                 if sample_name in sample_h5ad_dict:
+                     logger.warning(f"Duplicate sample name: {sample_name}, will be overwritten")
+                 sample_h5ad_dict[sample_name] = h5ad_path
+
+         elif processing_type == 'file':
+             logger.info(f"Using {option_name}: {field_value}")
+             list_file_path = Path(field_value)
+             list_file_parent_dir = list_file_path.parent
+             with open(field_value) as f:
+                 for line in f:
+                     line = line.strip()
+                     if line:  # Skip empty lines
+                         h5ad_path = Path(line)
+                         # Resolve relative paths relative to list file location
+                         if not h5ad_path.is_absolute():
+                             h5ad_path = list_file_parent_dir / h5ad_path
+                         sample_name = h5ad_path.stem
+                         if sample_name in sample_h5ad_dict:
+                             logger.warning(f"Duplicate sample name: {sample_name}, will be overwritten")
+                         sample_h5ad_dict[sample_name] = h5ad_path
+
+         break
+
+     return sample_h5ad_dict
+
+
+ def verify_homolog_file_format(config):
+     if config.homolog_file is not None:
+         logger.info(
+             f"User provided homolog file to map gene names to human: {config.homolog_file}"
+         )
+         # check the format of the homolog file
+         with open(config.homolog_file) as f:
+             first_line = f.readline().strip()
+             _n_col = len(first_line.split())
+             if _n_col != 2:
+                 raise ValueError(
+                     "Invalid homolog file format. Expected 2 columns: the first column should be "
+                     "the gene name in the other species, and the second the human gene name. "
+                     f"Got {_n_col} columns in the first line."
+                 )
+             else:
+                 first_col_name, second_col_name = first_line.split()
+                 config.species = first_col_name
+                 logger.info(
+                     f"Homolog file provided and will map gene names from column1:{first_col_name} to column2:{second_col_name}"
+                 )
+     else:
+         logger.info("No homolog file provided. Running in human mode.")
gsMap/find_latent/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .find_latent_representation import run_find_latent_representation
+
+ __all__ = ['run_find_latent_representation']
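The subpackage exposes a single public entry point via the re-export above, so downstream code imports the runner directly:

    from gsMap.find_latent import run_find_latent_representation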