gsMap3D 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsMap/__init__.py +13 -0
- gsMap/__main__.py +4 -0
- gsMap/cauchy_combination_test.py +342 -0
- gsMap/cli.py +355 -0
- gsMap/config/__init__.py +72 -0
- gsMap/config/base.py +296 -0
- gsMap/config/cauchy_config.py +79 -0
- gsMap/config/dataclasses.py +235 -0
- gsMap/config/decorators.py +302 -0
- gsMap/config/find_latent_config.py +276 -0
- gsMap/config/format_sumstats_config.py +54 -0
- gsMap/config/latent2gene_config.py +461 -0
- gsMap/config/ldscore_config.py +261 -0
- gsMap/config/quick_mode_config.py +242 -0
- gsMap/config/report_config.py +81 -0
- gsMap/config/spatial_ldsc_config.py +334 -0
- gsMap/config/utils.py +286 -0
- gsMap/find_latent/__init__.py +3 -0
- gsMap/find_latent/find_latent_representation.py +312 -0
- gsMap/find_latent/gnn/distribution.py +498 -0
- gsMap/find_latent/gnn/encoder_decoder.py +186 -0
- gsMap/find_latent/gnn/gcn.py +85 -0
- gsMap/find_latent/gnn/gene_former.py +164 -0
- gsMap/find_latent/gnn/loss.py +18 -0
- gsMap/find_latent/gnn/st_model.py +125 -0
- gsMap/find_latent/gnn/train_step.py +177 -0
- gsMap/find_latent/st_process.py +781 -0
- gsMap/format_sumstats.py +446 -0
- gsMap/generate_ldscore.py +1018 -0
- gsMap/latent2gene/__init__.py +18 -0
- gsMap/latent2gene/connectivity.py +781 -0
- gsMap/latent2gene/entry_point.py +141 -0
- gsMap/latent2gene/marker_scores.py +1265 -0
- gsMap/latent2gene/memmap_io.py +766 -0
- gsMap/latent2gene/rank_calculator.py +590 -0
- gsMap/latent2gene/row_ordering.py +182 -0
- gsMap/latent2gene/row_ordering_jax.py +159 -0
- gsMap/ldscore/__init__.py +1 -0
- gsMap/ldscore/batch_construction.py +163 -0
- gsMap/ldscore/compute.py +126 -0
- gsMap/ldscore/constants.py +70 -0
- gsMap/ldscore/io.py +262 -0
- gsMap/ldscore/mapping.py +262 -0
- gsMap/ldscore/pipeline.py +615 -0
- gsMap/pipeline/quick_mode.py +134 -0
- gsMap/report/__init__.py +2 -0
- gsMap/report/diagnosis.py +375 -0
- gsMap/report/report.py +100 -0
- gsMap/report/report_data.py +1832 -0
- gsMap/report/static/js_lib/alpine.min.js +5 -0
- gsMap/report/static/js_lib/tailwindcss.js +83 -0
- gsMap/report/static/template.html +2242 -0
- gsMap/report/three_d_combine.py +312 -0
- gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
- gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
- gsMap/report/three_d_plot/three_d_plots.py +425 -0
- gsMap/report/visualize.py +1409 -0
- gsMap/setup.py +5 -0
- gsMap/spatial_ldsc/__init__.py +0 -0
- gsMap/spatial_ldsc/io.py +656 -0
- gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
- gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
- gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
- gsMap/utils/__init__.py +0 -0
- gsMap/utils/generate_r2_matrix.py +610 -0
- gsMap/utils/jackknife.py +518 -0
- gsMap/utils/manhattan_plot.py +643 -0
- gsMap/utils/regression_read.py +177 -0
- gsMap/utils/torch_utils.py +23 -0
- gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
- gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
- gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
- gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
- gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,302 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Decorators for CLI and resource tracking.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import functools
|
|
6
|
+
import inspect
|
|
7
|
+
import logging
|
|
8
|
+
import os
|
|
9
|
+
import re
|
|
10
|
+
import subprocess
|
|
11
|
+
import sys
|
|
12
|
+
import threading
|
|
13
|
+
import time
|
|
14
|
+
from dataclasses import fields
|
|
15
|
+
from functools import wraps
|
|
16
|
+
from typing import Annotated, Any, get_args, get_origin
|
|
17
|
+
|
|
18
|
+
import psutil
|
|
19
|
+
import pyfiglet
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger("gsMap")
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def process_cpu_time(proc: psutil.Process):
    """Return the total CPU time (user + system) consumed by *proc*."""
    times = proc.cpu_times()
    return times.user + times.system
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def track_resource_usage(func):
    """
    Decorator that logs resource usage of the wrapped function.

    A daemon thread samples the current process every 0.5 s while the
    function runs, tracking peak RSS memory and CPU utilization. The
    summary (wall time, CPU time, average CPU %, peak memory) is logged
    in a ``finally`` block, so it is emitted even if the function raises.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        process = psutil.Process(os.getpid())

        # Shared state updated by the sampling thread.
        peak_memory = 0
        cpu_percent_samples = []
        stop_thread = False

        def resource_monitor():
            nonlocal peak_memory
            while not stop_thread:
                try:
                    # Resident set size, in MB.
                    current_memory = process.memory_info().rss / (1024 * 1024)
                    peak_memory = max(peak_memory, current_memory)

                    # cpu_percent(interval=None) measures since the previous
                    # call; initial zero readings are skipped.
                    cpu_percent = process.cpu_percent(interval=None)
                    if cpu_percent > 0:
                        cpu_percent_samples.append(cpu_percent)
                except Exception:
                    # Sampling is best-effort (e.g. process shutting down).
                    pass
                # Sleep outside the try so a repeatedly failing sample
                # cannot turn into a busy loop.
                time.sleep(0.5)

        monitor_thread = threading.Thread(target=resource_monitor)
        monitor_thread.daemon = True
        monitor_thread.start()

        start_wall_time = time.time()
        start_cpu_time = process_cpu_time(process)

        try:
            return func(*args, **kwargs)
        finally:
            # Stop the monitoring thread before reading the results.
            stop_thread = True
            monitor_thread.join(timeout=1.0)

            wall_time = time.time() - start_wall_time
            cpu_time = process_cpu_time(process) - start_cpu_time

            avg_cpu_percent = (
                sum(cpu_percent_samples) / len(cpu_percent_samples) if cpu_percent_samples else 0
            )

            # psutil CPU times need a timebase correction on macOS; use the
            # helper defined later in this module (gsMap.utils is an empty
            # package and does not export it).
            if sys.platform == "darwin":
                factor = macos_timebase_factor()
                cpu_time *= factor
                avg_cpu_percent *= factor

            # Human-readable peak memory.
            if peak_memory < 1024:
                memory_str = f"{peak_memory:.2f} MB"
            else:
                memory_str = f"{peak_memory / 1024:.2f} GB"

            def _format_duration(seconds: float) -> str:
                """Render a duration in seconds/minutes/hours."""
                if seconds < 60:
                    return f"{seconds:.2f} seconds"
                if seconds < 3600:
                    return f"{seconds / 60:.2f} minutes"
                return f"{seconds / 3600:.2f} hours"

            logger.info("Resource usage summary:")
            logger.info(f"  • Wall clock time: {_format_duration(wall_time)}")
            logger.info(f"  • CPU time: {_format_duration(cpu_time)}")
            logger.info(f"  • Average CPU utilization: {avg_cpu_percent:.1f}%")
            logger.info(f"  • Peak memory usage: {memory_str}")

    return wrapper
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def show_banner(command_name: str, version: str = "1.73.5"):
    """Print the gsMap ASCII banner, the version line, and a start-time log entry."""
    display_name = command_name.replace("_", " ")

    # Centered figlet logo, 80 columns wide.
    logo = pyfiglet.figlet_format(
        "gsMap", font="doom", width=80, justify="center"
    ).rstrip()
    print(logo, flush=True)
    print(("Version: " + version).center(80), flush=True)
    print("=" * 80, flush=True)

    logger.info(f"Running {display_name}...")
    logger.info(f"Started at: {time.strftime('%Y-%m-%d %H:%M:%S')}")
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def dataclass_typer(func):
    """
    Convert a function taking a single dataclass config argument into a
    Typer command whose CLI options mirror the dataclass fields.

    Only fields with ``Annotated[...]`` type hints become CLI options, so
    plain fields stay internal. When the config class sets ``__core_only__``,
    fields are additionally filtered through ``__display_in_quick_mode_cli__``
    markers — either class-level attributes or per-field dicts placed in the
    Annotated metadata.
    """
    sig = inspect.signature(func)

    # The wrapped function must take the config dataclass as its first parameter.
    config_param = list(sig.parameters.values())[0]
    config_class = config_param.annotation

    @wraps(func)
    @track_resource_usage  # Log memory/CPU/wall-time when the command finishes
    def wrapper(**kwargs):
        # Show the banner with the installed version (or a dev fallback).
        try:
            from gsMap import __version__

            version = __version__
        except ImportError:
            version = "development"
        show_banner(func.__name__, version)

        # Build the config from the CLI options and run the command.
        config = config_class(**kwargs)
        result = func(config)

        logger.info(f"Finished at: {time.strftime('%Y-%m-%d %H:%M:%S')}")
        return result

    # Build new keyword-only parameters from the dataclass fields.
    from dataclasses import MISSING

    params = []

    core_only = getattr(config_class, "__core_only__", False)

    def is_core_field(field_name: str, cls: type[Any]) -> bool:
        """Check if a field originates from a 'Core' config class."""
        for base in cls.__mro__:
            if field_name in getattr(base, "__annotations__", {}):
                # If ANY class in the inheritance chain declaring this field
                # explicitly disables quick mode, honor it.
                if getattr(base, "__display_in_quick_mode_cli__", True) is False:
                    return False
        return True

    for field in fields(config_class):
        # Only Annotated fields are exposed on the CLI; this lets internal
        # fields (e.g. derived state) stay out of the command signature.
        if get_origin(field.type) != Annotated:
            continue

        annotated_args = get_args(field.type)

        # Per-field display override, e.g.
        # Annotated[int, typer.Option(...), {"__display_in_quick_mode_cli__": True}]
        display_override = None
        for arg in annotated_args:
            if isinstance(arg, dict) and "__display_in_quick_mode_cli__" in arg:
                display_override = arg["__display_in_quick_mode_cli__"]
                break

        if core_only:
            # An explicit per-field True wins over a class-level False;
            # an explicit False always hides the field; otherwise fall back
            # to the class-level logic.
            if display_override is True:
                pass
            elif display_override is False:
                continue
            elif not is_core_field(field.name, config_class):
                continue

        # Determine the CLI default: field default, then default_factory,
        # otherwise the option is required (Parameter.empty).
        if field.default is not MISSING:
            default_value = field.default
        elif field.default_factory is not MISSING:
            default_value = field.default_factory()
        else:
            default_value = inspect.Parameter.empty

        # inspect.Parameter treats an explicit Parameter.empty default the
        # same as no default, so one construction covers both cases.
        params.append(
            inspect.Parameter(
                field.name,
                inspect.Parameter.KEYWORD_ONLY,
                annotation=field.type,  # Keep the full Annotated type for Typer
                default=default_value,
            )
        )

    # Expose the rebuilt signature and the original docstring to Typer.
    wrapper.__signature__ = inspect.Signature(params)
    wrapper.__doc__ = func.__doc__

    return wrapper
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
@functools.cache
def macos_timebase_factor():
    """
    Return the correction factor for psutil CPU times on macOS.

    On Apple Silicon, ``psutil.Process.cpu_times()`` is not accurate
    (disagrees with Activity Monitor); the reported values must be scaled by
    ``1e9 / timebase-frequency`` read from ioreg (typically 1000/24).
    Returns 1 whenever the frequency cannot be determined.

    see: https://github.com/giampaolo/psutil/issues/2411#issuecomment-2274682289
    """
    default_factor = 1

    try:
        result = subprocess.run(
            ["ioreg", "-p", "IODeviceTree", "-c", "IOPlatformDevice"],
            capture_output=True,
            text=True,
            check=True,
        )
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError also covers a missing ioreg binary (e.g. non-macOS hosts).
        print(f"Command failed: {e}")
        return default_factor

    for line in result.stdout.splitlines():
        if "timebase-frequency" in line:
            # The value is a little-endian hex byte string, e.g. <00366e01>.
            match = re.search(r"<([0-9a-fA-F]+)>", line)
            if not match:
                return default_factor
            byte_data = bytes.fromhex(match.group(1))
            timebase_freq = int.from_bytes(byte_data, byteorder="little")
            if not timebase_freq:
                # Guard against a malformed zero frequency.
                return default_factor
            # Typically, it should be 1000/24.
            return pow(10, 9) / timebase_freq
    return default_factor
|
|
@@ -0,0 +1,276 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Configuration for finding latent representations.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from collections import OrderedDict
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Annotated
|
|
10
|
+
|
|
11
|
+
import typer
|
|
12
|
+
import yaml
|
|
13
|
+
|
|
14
|
+
from gsMap.config.base import ConfigWithAutoPaths
|
|
15
|
+
from gsMap.config.utils import (
|
|
16
|
+
process_h5ad_inputs,
|
|
17
|
+
validate_h5ad_structure,
|
|
18
|
+
verify_homolog_file_format,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger("gsMap.config")
|
|
22
|
+
|
|
23
|
+
@dataclass
class FindLatentModelConfig:
    """Model and training hyper-parameters for finding latent representations.

    Annotated fields are exposed as CLI options by ``dataclass_typer``.
    """

    # Class-level flag read by dataclass_typer: hide these fields from the
    # quick-mode CLI unless a field opts back in via a
    # {"__display_in_quick_mode_cli__": True} dict in its Annotated metadata.
    __display_in_quick_mode_cli__ = False

    # Number of highly variable genes retained before training.
    feat_cell: Annotated[int, typer.Option(
        help="Number of top variable features to retain",
        min=100,
        max=10000
    )] = 2000

    # Feature extraction parameters
    n_neighbors: Annotated[int, typer.Option(
        help="Number of neighbors for LGCN",
        min=1,
        max=50
    )] = 10

    K: Annotated[int, typer.Option(
        help="Graph convolution depth for LGCN",
        min=1,
        max=10
    )] = 3

    # Model parameters
    hidden_size: Annotated[int, typer.Option(
        help="Units in the first hidden layer",
        min=32,
        max=512
    )] = 128

    embedding_size: Annotated[int, typer.Option(
        help="Size of the latent embedding layer",
        min=8,
        max=128
    )] = 32

    # Transformer parameters
    use_tf: Annotated[bool, typer.Option(
        "--use-tf",
        help="Enable transformer module"
    )] = False

    module_dim: Annotated[int, typer.Option(
        help="Dimensionality of transformer modules",
        min=10,
        max=100
    )] = 30

    hidden_gmf: Annotated[int, typer.Option(
        help="Hidden units for global mean feature extractor",
        min=32,
        max=512
    )] = 128

    n_modules: Annotated[int, typer.Option(
        help="Number of transformer modules",
        min=4,
        max=64
    )] = 16

    nhead: Annotated[int, typer.Option(
        help="Number of attention heads in transformer",
        min=1,
        max=16
    )] = 4

    n_enc_layer: Annotated[int, typer.Option(
        help="Number of transformer encoder layers",
        min=1,
        max=8
    )] = 2

    # Training parameters
    distribution: Annotated[str, typer.Option(
        help="Distribution type for loss calculation",
        case_sensitive=False
    )] = "nb"

    batch_size: Annotated[int, typer.Option(
        help="Batch size for training",
        min=32,
        max=4096
    )] = 1024

    itermax: Annotated[int, typer.Option(
        help="Maximum number of training iterations",
        min=10,
        max=1000
    )] = 100

    patience: Annotated[int, typer.Option(
        help="Early stopping patience",
        min=1,
        max=50
    )] = 10

    # Shown in quick-mode CLI via the per-field override dict.
    two_stage: Annotated[bool, typer.Option(
        "--two-stage/--single-stage",
        help="Tune the cell embeddings based on the provided annotation"
    ),{"__display_in_quick_mode_cli__": True}] = False

    do_sampling: Annotated[bool, typer.Option(
        "--do-sampling/--no-sampling",
        help="Down-sampling cells in training"
    )] = True

    # Shown in quick-mode CLI via the per-field override dict.
    n_cell_training: Annotated[int, typer.Option(
        help="Number of cells used for training",
        min=1000,
        max=1000000
    ), {"__display_in_quick_mode_cli__": True}] = 100000
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
@dataclass
class FindLatentCoreConfig:
    """Input and data-access options for finding latent representations.

    One of ``h5ad_path``, ``h5ad_yaml`` or ``h5ad_list_file`` supplies the
    input samples; ``sample_h5ad_dict`` is populated from them during
    post-init of the combined config.
    """

    h5ad_path: Annotated[list[Path] | None, typer.Option(
        help="Space-separated list of h5ad file paths. Sample names are derived from file names without suffix.",
        exists=True,
        file_okay=True,
    )] = None

    h5ad_yaml: Annotated[Path | None, typer.Option(
        help="YAML file with sample names and h5ad paths",
        exists=True,
        file_okay=True,
        dir_okay=False,
    )] = None

    h5ad_list_file: Annotated[Path | None, typer.Option(
        help="Each row is a h5ad file path, sample name is the file name without suffix",
        exists=True,
        file_okay=True,
        dir_okay=False,
    )] = None

    # Not Annotated, so not exposed on the CLI; filled in during post-init.
    sample_h5ad_dict: OrderedDict | None = None

    # NOTE: fixed help-string typo ("Other wise" -> "Otherwise").
    data_layer: Annotated[str, typer.Option(
        help="Gene expression raw counts data layer in h5ad layers, e.g., 'count', 'counts'. Otherwise use 'X' for adata.X"
    )] = "X"

    spatial_key: Annotated[str, typer.Option(
        help="Spatial key in adata.obsm storing spatial coordinates"
    )] = "spatial"

    annotation: Annotated[str | None, typer.Option(
        help="Annotation of cell type in adata.obs to use"
    )] = None

    homolog_file: Annotated[Path | None, typer.Option(
        help="Path to homologous gene conversion file",
        exists=True,
        file_okay=True,
        dir_okay=False
    )] = None

    # Not exposed on the CLI; presumably set when the homolog file is
    # processed — TODO confirm against verify_homolog_file_format.
    species: str | None = None

    latent_representation_niche: Annotated[str, typer.Option(
        help="Key for spatial niche embedding in obsm"
    )] = "emb_niche"

    latent_representation_cell: Annotated[str, typer.Option(
        help="Key for cell identity embedding in obsm"
    )] = "emb_cell"

    high_quality_cell_qc: Annotated[bool, typer.Option(
        "--high-quality-cell-qc/--no-high-quality-cell-qc",
        help="Enable/disable high quality cell QC based on module scores. If enabled, it will compute DEG and module scores."
    )] = True
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
@dataclass
class FindLatentRepresentationsConfig(FindLatentModelConfig, FindLatentCoreConfig, ConfigWithAutoPaths):
    """Find Latent Configuration"""

    def __post_init__(self):
        """Resolve h5ad inputs, validate their structure, and log the result.

        Raises:
            ValueError: if no h5ad input source is provided.
        """
        super().__post_init__()

        # Input option name -> (attribute name, input kind) understood by
        # process_h5ad_inputs.
        input_options = {
            'h5ad_yaml': ('h5ad_yaml', 'yaml'),
            'h5ad_path': ('h5ad_path', 'list'),
            'h5ad_list_file': ('h5ad_list_file', 'file'),
        }

        # Build the ordered {sample_name: h5ad_path} mapping from whichever
        # input source was supplied.
        self.sample_h5ad_dict = process_h5ad_inputs(self, input_options)

        # Error message lists only the options this config actually accepts.
        if not self.sample_h5ad_dict:
            raise ValueError(
                "At least one of h5ad_yaml, h5ad_path, or h5ad_list_file must be provided"
            )

        # Each h5ad must contain these: (adata section, key, display label).
        required_fields = {
            'data_layer': ('layers', self.data_layer, 'Data layer'),
            'spatial_key': ('obsm', self.spatial_key, 'Spatial key'),
        }

        # The annotation column is only required when one is requested.
        if self.annotation:
            required_fields['annotation'] = ('obs', self.annotation, 'Annotation')

        # Validate h5ad structure across all samples.
        validate_h5ad_structure(self.sample_h5ad_dict, required_fields)

        # Log final sample count. (The dict is guaranteed non-empty here,
        # so no further emptiness check is needed.)
        logger.info(f"Loaded and validated {len(self.sample_h5ad_dict)} samples")

        # Verify homolog file format if provided.
        verify_homolog_file_format(self)

        self.show_config(FindLatentRepresentationsConfig)
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def check_find_latent_done(config: "FindLatentRepresentationsConfig") -> bool:
    """
    Check whether the find_latent step has completed.

    Returns True only if the step's metadata YAML exists, lists at least one
    latent output file under ``outputs.latent_files``, and every listed file
    is still present on disk.
    """
    metadata_path = config.find_latent_metadata_path
    if not metadata_path.exists():
        return False

    try:
        with open(metadata_path) as f:
            metadata = yaml.safe_load(f)

        # safe_load returns None for an empty file; treat it as missing data.
        outputs = (metadata or {}).get('outputs') or {}
        latent_files = outputs.get('latent_files')
        if not latent_files:
            return False

        # Done only when every recorded output file still exists.
        return all(Path(path_str).exists() for path_str in latent_files.values())
    except Exception as e:
        # Unreadable/corrupt metadata means "not done" rather than a crash.
        logger.warning(f"Error checking find_latent metadata: {e}")
        return False
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Configuration for formatting GWAS summary statistics.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Annotated, Literal
|
|
7
|
+
|
|
8
|
+
import typer
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class FormatSumstatsConfig:
    """Configuration for formatting GWAS summary statistics.

    Column-name options (snp, a1, a2, ...) name the corresponding columns in
    the input sumstats file; unset options are auto-detected or skipped by
    the formatting step. Help-string typos fixed ("standar" -> "standard",
    "ferquency" -> "frequency", "dnsnp" -> "dbsnp", "none-effect" ->
    "non-effect").
    """

    sumstats: Annotated[str, typer.Option(help="Path to gwas summary data")]
    out: Annotated[str, typer.Option(help="Path to save the formatted gwas data")]

    # Arguments for specify column name
    snp: Annotated[str | None, typer.Option(help="Name of snp column")] = None
    a1: Annotated[str | None, typer.Option(help="Name of effect allele column")] = None
    a2: Annotated[str | None, typer.Option(help="Name of non-effect allele column")] = None
    info: Annotated[str | None, typer.Option(help="Name of info column")] = None
    beta: Annotated[str | None, typer.Option(help="Name of gwas beta column.")] = None
    se: Annotated[str | None, typer.Option(help="Name of gwas standard error of beta column")] = None
    p: Annotated[str | None, typer.Option(help="Name of p-value column")] = None
    frq: Annotated[str | None, typer.Option(help="Name of A1 frequency column")] = None
    n: Annotated[str | None, typer.Option(help="Name of sample size column")] = None
    z: Annotated[str | None, typer.Option(help="Name of gwas Z-statistics column")] = None
    OR: Annotated[str | None, typer.Option(help="Name of gwas OR column")] = None
    se_OR: Annotated[str | None, typer.Option(help="Name of standard error of OR column")] = None

    # Arguments for convert SNP (chr, pos) to rsid
    chr: Annotated[str, typer.Option(help="Name of SNP chromosome column")] = "Chr"
    pos: Annotated[str, typer.Option(help="Name of SNP positions column")] = "Pos"
    dbsnp: Annotated[str | None, typer.Option(help="Path to reference dbsnp file")] = None
    chunksize: Annotated[int, typer.Option(help="Chunk size for loading dbsnp file")] = 1000000

    # Arguments for output format and quality
    format: Annotated[Literal["gsMap", "COJO"], typer.Option(help="Format of output data", case_sensitive=False)] = "gsMap"
    info_min: Annotated[float, typer.Option(help="Minimum INFO score.")] = 0.9
    maf_min: Annotated[float, typer.Option(help="Minimum MAF.")] = 0.01
    keep_chr_pos: Annotated[bool, typer.Option(help="Keep SNP chromosome and position columns in the output data")] = False

    def __post_init__(self):
        # The CLI may pass --n as a numeric string (a fixed sample size);
        # convert it to int/float in that case, otherwise keep it as a
        # column name.
        if isinstance(self.n, str):
            try:
                if "." in self.n:
                    self.n = float(self.n)
                else:
                    self.n = int(self.n)
            except ValueError:
                # Leave as string if it's a column name
                pass
|