gsMap 1.71.2 (py3-none-any.whl) → 1.73.0 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsMap/GNN/adjacency_matrix.py +25 -27
- gsMap/GNN/model.py +9 -7
- gsMap/GNN/train.py +8 -11
- gsMap/__init__.py +3 -3
- gsMap/__main__.py +3 -2
- gsMap/cauchy_combination_test.py +78 -75
- gsMap/config.py +948 -322
- gsMap/create_slice_mean.py +168 -0
- gsMap/diagnosis.py +179 -101
- gsMap/find_latent_representation.py +29 -27
- gsMap/format_sumstats.py +239 -201
- gsMap/generate_ldscore.py +334 -222
- gsMap/latent_to_gene.py +128 -68
- gsMap/main.py +23 -14
- gsMap/report.py +39 -25
- gsMap/run_all_mode.py +87 -46
- gsMap/setup.py +1 -1
- gsMap/spatial_ldsc_multiple_sumstats.py +154 -80
- gsMap/utils/generate_r2_matrix.py +100 -346
- gsMap/utils/jackknife.py +84 -80
- gsMap/utils/manhattan_plot.py +180 -207
- gsMap/utils/regression_read.py +83 -176
- gsMap/visualize.py +82 -64
- gsmap-1.73.0.dist-info/METADATA +169 -0
- gsmap-1.73.0.dist-info/RECORD +31 -0
- {gsmap-1.71.2.dist-info → gsmap-1.73.0.dist-info}/WHEEL +1 -1
- {gsmap-1.71.2.dist-info → gsmap-1.73.0.dist-info/licenses}/LICENSE +6 -6
- gsMap/utils/make_annotations.py +0 -518
- gsmap-1.71.2.dist-info/METADATA +0 -105
- gsmap-1.71.2.dist-info/RECORD +0 -31
- {gsmap-1.71.2.dist-info → gsmap-1.73.0.dist-info}/entry_points.txt +0 -0
gsMap/utils/regression_read.py
CHANGED
@@ -1,55 +1,55 @@
+import os
+
 import numpy as np
 import pandas as pd
-import os
 
 
 # Fun for reading gwas data
 def _read_sumstats(fh, alleles=False, dropna=False):
-    '''
+    """
     Parse gwas summary statistics.
-    '''
-    print(
+    """
+    print(f"Reading summary statistics from {fh} ...")
     sumstats = ps_sumstats(fh, alleles=alleles, dropna=dropna)
-    print(
+    print(f"Read summary statistics for {len(sumstats)} SNPs.")
 
     m = len(sumstats)
-    sumstats = sumstats.drop_duplicates(subset='SNP')
+    sumstats = sumstats.drop_duplicates(subset="SNP")
     if m > len(sumstats):
-        print(
+        print(f"Dropped {m - len(sumstats)} SNPs with duplicated rs numbers.")
 
     return sumstats
 
 
 def ps_sumstats(fh, alleles=False, dropna=True):
-    '''
+    """
     Parses .sumstats files. See docs/file_formats_sumstats.txt.
-    '''
-
-    dtype_dict = {'SNP': str, 'Z': float, 'N': float, 'A1': str, 'A2': str}
+    """
+    dtype_dict = {"SNP": str, "Z": float, "N": float, "A1": str, "A2": str}
     compression = get_compression(fh)
-    usecols = ['SNP', 'Z', 'N']
+    usecols = ["SNP", "Z", "N"]
     if alleles:
-        usecols += ['A1', 'A2']
+        usecols += ["A1", "A2"]
 
     try:
         x = read_csv(fh, usecols=usecols, dtype=dtype_dict, compression=compression)
     except (AttributeError, ValueError) as e:
-        raise ValueError(
+        raise ValueError("Improperly formatted sumstats file: " + str(e.args)) from e
 
     if dropna:
-        x = x.dropna(how='any')
+        x = x.dropna(how="any")
 
     return x
 
 
 def get_compression(fh):
-    '''
+    """
     Determin the format of compression used with read_csv?
-    '''
-    if fh.endswith('gz'):
-        compression = 'gzip'
-    elif fh.endswith('bz2'):
-        compression = 'bz2'
+    """
+    if fh.endswith("gz"):
+        compression = "gzip"
+    elif fh.endswith("bz2"):
+        compression = "bz2"
     else:
         compression = None
     # -
@@ -57,120 +57,96 @@ def get_compression(fh):
 
 
 def read_csv(fh, **kwargs):
-    '''
+    """
     Read the csv data
-    '''
-    return pd.read_csv(fh, sep=
+    """
+    return pd.read_csv(fh, sep=r"\s+", na_values=".", **kwargs)
 
 
-# Fun for reading loading LD scores
+# Fun for reading loading LD scores
 def which_compression(fh):
-    '''
+    """
     Given a file prefix, figure out what sort of compression to use.
-    '''
-    if os.access(fh + '.bz2', 4):
-        suffix = '.bz2'
-        compression = 'bz2'
-    elif os.access(fh + '.gz', 4):
-        suffix = '.gz'
-        compression = 'gzip'
-    elif os.access(fh + '.parquet', 4):
-        suffix = '.parquet'
-        compression = 'parquet'
-    elif os.access(fh + '.feather', 4):
-        suffix = '.feather'
-        compression = 'feather'
+    """
+    if os.access(fh + ".bz2", 4):
+        suffix = ".bz2"
+        compression = "bz2"
+    elif os.access(fh + ".gz", 4):
+        suffix = ".gz"
+        compression = "gzip"
+    elif os.access(fh + ".parquet", 4):
+        suffix = ".parquet"
+        compression = "parquet"
+    elif os.access(fh + ".feather", 4):
+        suffix = ".feather"
+        compression = "feather"
     elif os.access(fh, 4):
-        suffix = ''
+        suffix = ""
         compression = None
     else:
-        raise IOError('Could not open {F}[./gz/bz2/parquet/feather]'.format(F=fh))
+        raise OSError(f"Could not open {fh}[./gz/bz2/parquet/feather]")
     # -
     return suffix, compression
 
 
-def _read_ref_ld(ld_file):
-    suffix = '.l2.ldscore'
-    file = ld_file
-    first_fh = f'{file}1{suffix}'
-    s, compression = which_compression(first_fh)
-    #
-    ldscore_array = []
-    print(f'Reading ld score annotations from {file}[1-22]{suffix}.{compression}')
-
-    for chr in range(1, 23):
-        file_chr = f'{file}{chr}{suffix}{s}'
-        #
-        if compression == 'parquet':
-            x = pd.read_parquet(file_chr)
-        elif compression == 'feather':
-            x = pd.read_feather(file_chr)
-        else:
-            x = pd.read_csv(file_chr, compression=compression, sep='\t')
-
-        x = x.sort_values(by=['CHR', 'BP'])  # SEs will be wrong unless sorted
-
-        columns_to_drop = ['MAF', 'CM', 'Gene', 'TSS', 'CHR', 'BP']
-        columns_to_drop = [col for col in columns_to_drop if col in x.columns]
-        x = x.drop(columns_to_drop, axis=1)
-
-        ldscore_array.append(x)
-    #
-    ref_ld = pd.concat(ldscore_array, axis=0)
-    return ref_ld
-
-
 def _read_ref_ld_v2(ld_file):
-    suffix = '.l2.ldscore'
+    suffix = ".l2.ldscore"
     file = ld_file
-    first_fh = f'{file}1{suffix}'
+    first_fh = f"{file}1{suffix}"
     s, compression = which_compression(first_fh)
-    print(f'Reading ld score annotations from {file}[1-22]{suffix}.{compression}')
+    print(f"Reading ld score annotations from {file}[1-22]{suffix}.{compression}")
     ref_ld = pd.concat(
-        [pd.read_feather(f'{file}{chr}{suffix}{s}') for chr in range(1, 23)], axis=0
+        [pd.read_feather(f"{file}{chr}{suffix}{s}") for chr in range(1, 23)], axis=0
     )
     # set first column as index
-    ref_ld.rename(columns={'index': 'SNP'}, inplace=True)
-    ref_ld.set_index('SNP', inplace=True)
+    ref_ld.rename(columns={"index": "SNP"}, inplace=True)
+    ref_ld.set_index("SNP", inplace=True)
     return ref_ld
 
+
 def _read_M_v2(ld_file, n_annot, not_M_5_50):
-    suffix = '.l2.M'
+    suffix = ".l2.M"
    if not not_M_5_50:
-        suffix += '_5_50'
-    M_annot= np.array(
+        suffix += "_5_50"
+    M_annot = np.array(
         [
-            np.loadtxt(
-
-
+            np.loadtxt(
+                f"{ld_file}{chr}{suffix}",
+            )
+            for chr in range(1, 23)
+        ]
     )
     assert M_annot.shape == (22, n_annot)
     return M_annot.sum(axis=0).reshape((1, n_annot))
-
+
+
+# Fun for reading M annotations
 def _read_M(ld_file, n_annot, not_M_5_50):
-    '''
+    """
     Read M (--M, --M-file, etc).
-    '''
+    """
     M_annot = M(ld_file, common=(not not_M_5_50))
 
     try:
         M_annot = np.array(M_annot).reshape((1, n_annot))
     except ValueError as e:
-        raise ValueError(
+        raise ValueError(
+            "# terms in --M must match # of LD Scores in --ref-ld.\n" + str(e.args)
+        ) from e
     return M_annot
 
 
 def M(fh, common=False):
-    '''
+    """
     Parses .l{N}.M files, split across num chromosomes.
-    '''
-    suffix = '.l2.M'
+    """
+    suffix = ".l2.M"
     if common:
-        suffix += '_5_50'
+        suffix += "_5_50"
     # -
     M_array = []
     for i in range(1, 23):
-        M_current = pd.read_csv(f'{fh}{i}' + suffix, header=None)
+        M_current = pd.read_csv(f"{fh}{i}" + suffix, header=None)
         M_array.append(M_current)
 
     M_array = pd.concat(M_array, axis=1).sum(axis=1)
@@ -178,117 +154,48 @@ def M(fh, common=False):
     return np.array(M_array).reshape((1, len(M_array)))
 
 
-def _check_variance(M_annot, ref_ld):
-    '''
-    Remove zero-variance LD Scores.
-    '''
-    ii = ref_ld.iloc[:, 1:].var() == 0  # NB there is a SNP column here
-    if ii.all():
-        raise ValueError('All LD Scores have zero variance.')
-    else:
-        print('Removing partitioned LD Scores with zero variance.')
-        ii_snp = np.array([True] + list(~ii))
-        ii_m = np.array(~ii)
-        ref_ld = ref_ld.iloc[:, ii_snp]
-        M_annot = M_annot[:, ii_m]
-    # -
-    return M_annot, ref_ld, ii
 def _check_variance_v2(M_annot, ref_ld):
     ii = ref_ld.var() == 0
     if ii.all():
-        raise ValueError('All LD Scores have zero variance.')
+        raise ValueError("All LD Scores have zero variance.")
     elif not ii.any():
-        print(
+        print("No partitioned LD Scores have zero variance.")
     else:
-        ii_snp= ii_m = np.array(~ii)
-        print(f
+        ii_snp = ii_m = np.array(~ii)
+        print(f"Removing {sum(ii)} partitioned LD Scores with zero variance.")
         ref_ld = ref_ld.iloc[:, ii_snp]
         M_annot = M_annot[:, ii_m]
     return M_annot, ref_ld
 
 
-# Fun for reading regression weights
-def which_compression(fh):
-    '''
-    Given a file prefix, figure out what sort of compression to use.
-    '''
-    if os.access(fh + '.bz2', 4):
-        suffix = '.bz2'
-        compression = 'bz2'
-    elif os.access(fh + '.gz', 4):
-        suffix = '.gz'
-        compression = 'gzip'
-    elif os.access(fh + '.parquet', 4):
-        suffix = '.parquet'
-        compression = 'parquet'
-    elif os.access(fh + '.feather', 4):
-        suffix = '.feather'
-        compression = 'feather'
-    elif os.access(fh, 4):
-        suffix = ''
-        compression = None
-    else:
-        raise IOError('Could not open {F}[./gz/bz2/parquet/feather]'.format(F=fh))
-    # -
-    return suffix, compression
-
-
 def _read_w_ld(w_file):
-    suffix = '.l2.ldscore'
+    suffix = ".l2.ldscore"
     file = w_file
-    first_fh = f'{file}1{suffix}'
+    first_fh = f"{file}1{suffix}"
     s, compression = which_compression(first_fh)
     #
     w_array = []
-    print(f'Reading ld score annotations from {file}[1-22]{suffix}.{compression}')
+    print(f"Reading ld score annotations from {file}[1-22]{suffix}.{compression}")
 
     for chr in range(1, 23):
-        file_chr = f'{file}{chr}{suffix}{s}'
+        file_chr = f"{file}{chr}{suffix}{s}"
         #
-        if compression == 'parquet':
+        if compression == "parquet":
             x = pd.read_parquet(file_chr)
-        elif compression == 'feather':
+        elif compression == "feather":
             x = pd.read_feather(file_chr)
         else:
-            x = pd.read_csv(file_chr, compression=compression, sep='\t')
+            x = pd.read_csv(file_chr, compression=compression, sep="\t")
 
-        x = x.sort_values(by=['CHR', 'BP'])
+        x = x.sort_values(by=["CHR", "BP"])
 
-        columns_to_drop = ['MAF', 'CM', 'Gene', 'TSS', 'CHR', 'BP']
+        columns_to_drop = ["MAF", "CM", "Gene", "TSS", "CHR", "BP"]
         columns_to_drop = [col for col in columns_to_drop if col in x.columns]
         x = x.drop(columns_to_drop, axis=1)
 
         w_array.append(x)
     #
     w_ld = pd.concat(w_array, axis=0)
-    w_ld.columns = ['SNP', 'LD_weights']
+    w_ld.columns = ["SNP", "LD_weights"]
 
     return w_ld
-
-
-# Fun for merging
-def _merge_and_log(ld, sumstats, noun):
-    '''
-    Wrap smart merge with log messages about # of SNPs.
-    '''
-    sumstats = smart_merge(ld, sumstats)
-    msg = 'After merging with {F}, {N} SNPs remain.'
-    if len(sumstats) == 0:
-        raise ValueError(msg.format(N=len(sumstats), F=noun))
-    else:
-        print(msg.format(N=len(sumstats), F=noun))
-    # -
-    return sumstats
-
-
-def smart_merge(x, y):
-    '''
-    Check if SNP columns are equal. If so, save time by using concat instead of merge.
-    '''
-    if len(x) == len(y) and (x.index == y.index).all() and (x.SNP == y.SNP).all():
-        x = x.reset_index(drop=True)
-        y = y.reset_index(drop=True).drop('SNP', 1)
-        out = pd.concat([x, y], axis=1)
-    else:
-        out = pd.merge(x, y, how='inner', on='SNP')
-    return out
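For orientation, here is a minimal sketch of how the surviving readers compose after this cleanup. The resource/baseline. and resource/weights. prefixes and trait.sumstats.gz are hypothetical placeholders, not paths shipped with gsMap; note that _read_ref_ld_v2 loads the per-chromosome files via pd.read_feather, so the reference prefix must point at .l2.ldscore.feather files.

    # Hypothetical usage of the refactored readers; file stems are placeholders.
    from gsMap.utils.regression_read import (
        _read_M_v2,
        _read_ref_ld_v2,
        _read_sumstats,
        _read_w_ld,
    )

    # Summary statistics with SNP/Z/N columns (plus A1/A2 if alleles=True).
    sumstats = _read_sumstats("trait.sumstats.gz", alleles=False, dropna=False)

    # Reference LD scores: expects {prefix}{chr}.l2.ldscore.feather for chr 1-22,
    # concatenated into one SNP-indexed DataFrame.
    ref_ld = _read_ref_ld_v2("resource/baseline.")

    # Per-annotation SNP counts from {prefix}{chr}.l2.M_5_50, shape (1, n_annot).
    M_annot = _read_M_v2("resource/baseline.", n_annot=ref_ld.shape[1], not_M_5_50=False)

    # Regression weights: picks up .gz/.bz2/.parquet/.feather automatically and
    # returns a DataFrame with columns SNP and LD_weights.
    w_ld = _read_w_ld("resource/weights.")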
gsMap/visualize.py
CHANGED
@@ -11,35 +11,41 @@ from gsMap.config import VisualizeConfig
 
 
 def load_ldsc(ldsc_input_file):
-    ldsc = pd.read_csv(
-
-
-
+    ldsc = pd.read_csv(
+        ldsc_input_file,
+        compression="gzip",
+        dtype={"spot": str, "p": float},
+        index_col="spot",
+        usecols=["spot", "p"],
+    )
+    ldsc["logp"] = -np.log10(ldsc.p)
     return ldsc
 
 
 # %%
 def load_st_coord(adata, feature_series: pd.Series, annotation):
     spot_name = adata.obs_names.to_list()
-    assert
+    assert "spatial" in adata.obsm.keys(), "spatial coordinates are not found in adata.obsm"
 
     # to DataFrame
-    space_coord = adata.obsm['spatial']
+    space_coord = adata.obsm["spatial"]
     if isinstance(space_coord, np.ndarray):
-        space_coord = pd.DataFrame(space_coord, columns=['sx', 'sy'], index=spot_name)
+        space_coord = pd.DataFrame(space_coord, columns=["sx", "sy"], index=spot_name)
     else:
-        space_coord = pd.DataFrame(space_coord.values, columns=['sx', 'sy'], index=spot_name)
+        space_coord = pd.DataFrame(space_coord.values, columns=["sx", "sy"], index=spot_name)
 
     space_coord = space_coord[space_coord.index.isin(feature_series.index)]
     space_coord_concat = pd.concat([space_coord.loc[feature_series.index], feature_series], axis=1)
     space_coord_concat.head()
     if annotation is not None:
-        annotation = pd.Series(
+        annotation = pd.Series(
+            adata.obs[annotation].values, index=adata.obs_names, name="annotation"
+        )
         space_coord_concat = pd.concat([space_coord_concat, annotation], axis=1)
     return space_coord_concat
 
 
-def estimate_point_size_for_plot(coordinates, DEFAULT_PIXEL_WIDTH = 1000):
+def estimate_point_size_for_plot(coordinates, DEFAULT_PIXEL_WIDTH=1000):
     tree = KDTree(coordinates)
     distances, _ = tree.query(coordinates, k=2)
     avg_min_distance = np.mean(distances[:, 1])
@@ -55,34 +61,42 @@ def estimate_point_size_for_plot(coordinates, DEFAULT_PIXEL_WIDTH = 1000):
     return (pixel_width, pixel_height), point_size
 
 
-def draw_scatter(
-
+def draw_scatter(
+    space_coord_concat,
+    title=None,
+    fig_style: Literal["dark", "light"] = "light",
+    point_size: int = None,
+    width=800,
+    height=600,
+    annotation=None,
+    color_by="logp",
+):
     # Set theme based on fig_style
-    if fig_style == 'dark':
+    if fig_style == "dark":
         px.defaults.template = "plotly_dark"
     else:
         px.defaults.template = "plotly_white"
 
     custom_color_scale = [
-        (1,
-        (7 / 8,
-        (6 / 8,
-        (5 / 8,
-        (4 / 8,
-        (3 / 8,
-        (2 / 8,
-        (1 / 8,
-        (0,
+        (1, "#d73027"),  # Red
+        (7 / 8, "#f46d43"),  # Red-Orange
+        (6 / 8, "#fdae61"),  # Orange
+        (5 / 8, "#fee090"),  # Light Orange
+        (4 / 8, "#e0f3f8"),  # Light Blue
+        (3 / 8, "#abd9e9"),  # Sky Blue
+        (2 / 8, "#74add1"),  # Medium Blue
+        (1 / 8, "#4575b4"),  # Dark Blue
+        (0, "#313695"),  # Deep Blue
     ]
     custom_color_scale.reverse()
 
     # Create the scatter plot
     fig = px.scatter(
         space_coord_concat,
-        x='sx',
-        y='sy',
+        x="sx",
+        y="sy",
         color=color_by,
-        symbol='annotation' if annotation is not None else None,
+        symbol="annotation" if annotation is not None else None,
         title=title,
         color_continuous_scale=custom_color_scale,
         range_color=[0, max(space_coord_concat[color_by])],
@@ -90,7 +104,7 @@ def draw_scatter(space_coord_concat, title=None, fig_style: Literal['dark', 'lig
 
     # Update marker size if specified
     if point_size is not None:
-        fig.update_traces(marker=dict(size=point_size, symbol='circle'))
+        fig.update_traces(marker=dict(size=point_size, symbol="circle"))
 
     # Update layout for figure size
     fig.update_layout(
@@ -108,45 +122,45 @@ def draw_scatter(space_coord_concat, title=None, fig_style: Literal['dark', 'lig
             x=1.0,
             font=dict(
                 size=10,
-            )
+            ),
         )
     )
 
     # Update colorbar to be at the bottom and horizontal
     fig.update_layout(
         coloraxis_colorbar=dict(
-            orientation=
+            orientation="h",  # Make the colorbar horizontal
             x=0.5,  # Center the colorbar horizontally
             y=-0.0,  # Position below the plot
-            xanchor=
-            yanchor=
+            xanchor="center",  # Anchor the colorbar at the center
+            yanchor="top",  # Anchor the colorbar at the top to keep it just below the plot
             len=0.75,  # Length of the colorbar relative to the plot width
             title=dict(
-                text=
-                side=
-            )
+                text="-log10(p)" if color_by == "logp" else color_by,  # Colorbar title
+                side="top",  # Place the title at the top of the colorbar
+            ),
         )
     )
     # Remove gridlines, axis labels, and ticks
     fig.update_xaxes(
-        showgrid=False,
-        zeroline=False,
+        showgrid=False,  # Hide x-axis gridlines
+        zeroline=False,  # Hide x-axis zero line
         showticklabels=False,  # Hide x-axis tick labels
-        title=None,
-        scaleanchor=
+        title=None,  # Remove x-axis title
+        scaleanchor="y",  # Link the x-axis scale to the y-axis scale
     )
 
     fig.update_yaxes(
-        showgrid=False,
-        zeroline=False,
+        showgrid=False,  # Hide y-axis gridlines
+        zeroline=False,  # Hide y-axis zero line
         showticklabels=False,  # Hide y-axis tick labels
-        title=None
+        title=None,  # Remove y-axis title
    )
 
     # Adjust margins to ensure no clipping and equal axis ratio
     fig.update_layout(
         margin=dict(l=0, r=0, t=20, b=10),  # Adjust margins to prevent clipping
-        height=width  # Ensure the figure height matches the width for equal axis ratio
+        height=width,  # Ensure the figure height matches the width for equal axis ratio
     )
 
     # Adjust the title location and font size
@@ -154,46 +168,50 @@ def draw_scatter(space_coord_concat, title=None, fig_style: Literal['dark', 'lig
         title=dict(
             y=0.98,
             x=0.5,  # Center the title horizontally
-            xanchor=
-            yanchor=
+            xanchor="center",  # Anchor the title at the center
+            yanchor="top",  # Anchor the title at the top
             font=dict(
                 size=20  # Increase the title font size
-            )
-            )
+            ),
+        )
+    )
 
     return fig
 
 
-
 def run_Visualize(config: VisualizeConfig):
-    print(f
-    ldsc = load_ldsc(
+    print(f"------Loading LDSC results of {config.ldsc_save_dir}...")
+    ldsc = load_ldsc(
+        ldsc_input_file=Path(config.ldsc_save_dir)
+        / f"{config.sample_name}_{config.trait_name}.csv.gz"
+    )
 
-    print(f
-    adata = sc.read_h5ad(f
+    print(f"------Loading ST data of {config.sample_name}...")
+    adata = sc.read_h5ad(f"{config.hdf5_with_latent_path}")
 
     space_coord_concat = load_st_coord(adata, ldsc, annotation=config.annotation)
-    fig = draw_scatter(
-
-
-
-
-
-
-
-
+    fig = draw_scatter(
+        space_coord_concat,
+        title=config.fig_title,
+        fig_style=config.fig_style,
+        point_size=config.point_size,
+        width=config.fig_width,
+        height=config.fig_height,
+        annotation=config.annotation,
+    )
 
     # Visualization
     output_dir = Path(config.output_dir)
     output_dir.mkdir(parents=True, exist_ok=True, mode=0o755)
-    output_file_html = output_dir / f
-    output_file_pdf = output_dir / f
-    output_file_csv = output_dir / f
+    output_file_html = output_dir / f"{config.sample_name}_{config.trait_name}.html"
+    output_file_pdf = output_dir / f"{config.sample_name}_{config.trait_name}.pdf"
+    output_file_csv = output_dir / f"{config.sample_name}_{config.trait_name}.csv"
 
     fig.write_html(str(output_file_html))
     fig.write_image(str(output_file_pdf))
     space_coord_concat.to_csv(str(output_file_csv))
 
     print(
-        f
-
+        f"------The visualization result is saved in a html file: {output_file_html} which can interactively viewed in a web browser and a pdf file: {output_file_pdf}."
+    )
+    print(f"------The visualization data is saved in a csv file: {output_file_csv}.")
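The reworked draw_scatter keyword layout can be exercised on its own; below is a minimal sketch with synthetic spots, where every coordinate and p-value is invented for illustration.

    # Synthetic demo of the new draw_scatter signature; data are made up.
    import numpy as np
    import pandas as pd

    from gsMap.visualize import draw_scatter

    rng = np.random.default_rng(0)
    n_spots = 500
    space_coord_concat = pd.DataFrame(
        {
            "sx": rng.uniform(0, 100, n_spots),  # spatial x per spot
            "sy": rng.uniform(0, 100, n_spots),  # spatial y per spot
            "logp": -np.log10(rng.uniform(1e-8, 1.0, n_spots)),  # -log10(p) per spot
        }
    )

    fig = draw_scatter(
        space_coord_concat,
        title="synthetic trait",
        fig_style="light",
        point_size=3,
        width=800,
        height=600,
    )
    fig.write_html("synthetic_trait.html")  # same writer run_Visualize uses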