gsMap: gsmap-1.71.2-py3-none-any.whl → gsmap-1.72.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsMap/GNN/adjacency_matrix.py +25 -27
- gsMap/GNN/model.py +9 -7
- gsMap/GNN/train.py +8 -11
- gsMap/__init__.py +3 -3
- gsMap/__main__.py +3 -2
- gsMap/cauchy_combination_test.py +75 -72
- gsMap/config.py +822 -316
- gsMap/create_slice_mean.py +154 -0
- gsMap/diagnosis.py +179 -101
- gsMap/find_latent_representation.py +28 -26
- gsMap/format_sumstats.py +233 -201
- gsMap/generate_ldscore.py +353 -209
- gsMap/latent_to_gene.py +92 -60
- gsMap/main.py +23 -14
- gsMap/report.py +39 -25
- gsMap/run_all_mode.py +86 -46
- gsMap/setup.py +1 -1
- gsMap/spatial_ldsc_multiple_sumstats.py +154 -80
- gsMap/utils/generate_r2_matrix.py +173 -140
- gsMap/utils/jackknife.py +84 -80
- gsMap/utils/manhattan_plot.py +180 -207
- gsMap/utils/regression_read.py +105 -122
- gsMap/visualize.py +82 -64
- {gsmap-1.71.2.dist-info → gsmap-1.72.3.dist-info}/METADATA +21 -6
- gsmap-1.72.3.dist-info/RECORD +31 -0
- {gsmap-1.71.2.dist-info → gsmap-1.72.3.dist-info}/WHEEL +1 -1
- gsMap/utils/make_annotations.py +0 -518
- gsmap-1.71.2.dist-info/RECORD +0 -31
- {gsmap-1.71.2.dist-info → gsmap-1.72.3.dist-info}/LICENSE +0 -0
- {gsmap-1.71.2.dist-info → gsmap-1.72.3.dist-info}/entry_points.txt +0 -0
gsMap/utils/regression_read.py
CHANGED
@@ -1,55 +1,55 @@
+import os
+
 import numpy as np
 import pandas as pd
-import os
 
 
 # Fun for reading gwas data
 def _read_sumstats(fh, alleles=False, dropna=False):
-    '''
+    """
     Parse gwas summary statistics.
-    '''
-    print( …
+    """
+    print(f"Reading summary statistics from {fh} ...")
     sumstats = ps_sumstats(fh, alleles=alleles, dropna=dropna)
-    print( …
+    print(f"Read summary statistics for {len(sumstats)} SNPs.")
 
     m = len(sumstats)
-    sumstats = sumstats.drop_duplicates(subset='SNP')
+    sumstats = sumstats.drop_duplicates(subset="SNP")
     if m > len(sumstats):
-        print( …
+        print(f"Dropped {m - len(sumstats)} SNPs with duplicated rs numbers.")
 
     return sumstats
 
 
 def ps_sumstats(fh, alleles=False, dropna=True):
-    '''
+    """
     Parses .sumstats files. See docs/file_formats_sumstats.txt.
-    '''
-
-    dtype_dict = {'SNP': str, 'Z': float, 'N': float, 'A1': str, 'A2': str}
+    """
+    dtype_dict = {"SNP": str, "Z": float, "N": float, "A1": str, "A2": str}
     compression = get_compression(fh)
-    usecols = ['SNP', 'Z', 'N']
+    usecols = ["SNP", "Z", "N"]
     if alleles:
-        usecols += ['A1', 'A2']
+        usecols += ["A1", "A2"]
 
     try:
         x = read_csv(fh, usecols=usecols, dtype=dtype_dict, compression=compression)
     except (AttributeError, ValueError) as e:
-        raise ValueError('Improperly formatted sumstats file: ' + str(e.args))
+        raise ValueError("Improperly formatted sumstats file: " + str(e.args)) from e
 
     if dropna:
-        x = x.dropna(how='any')
+        x = x.dropna(how="any")
 
     return x
 
 
 def get_compression(fh):
-    '''
+    """
     Determin the format of compression used with read_csv?
-    '''
-    if fh.endswith('gz'):
-        compression = 'gzip'
-    elif fh.endswith('bz2'):
-        compression = 'bz2'
+    """
+    if fh.endswith("gz"):
+        compression = "gzip"
+    elif fh.endswith("bz2"):
+        compression = "bz2"
     else:
         compression = None
     # -
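Aside: the hunk above is mostly a quote-style and f-string cleanup, but the reading path it touches is easy to restate: ps_sumstats delegates to pandas with a compression mode inferred from the file extension. Below is a minimal self-contained sketch of that flow; the column names and whitespace separator follow the LDSC .sumstats convention shown above, and infer_compression/read_sumstats are hypothetical stand-ins, not gsMap's API.

import pandas as pd


def infer_compression(path: str):
    """Hypothetical stand-in for get_compression above."""
    if path.endswith("gz"):
        return "gzip"
    if path.endswith("bz2"):
        return "bz2"
    return None


def read_sumstats(path: str, alleles: bool = False) -> pd.DataFrame:
    usecols = ["SNP", "Z", "N"] + (["A1", "A2"] if alleles else [])
    df = pd.read_csv(
        path,
        sep=r"\s+",  # .sumstats files are whitespace-delimited
        na_values=".",
        usecols=usecols,
        # dtype keys for columns excluded by usecols are ignored by pandas,
        # which is why the original can pass all five entries unconditionally
        dtype={"SNP": str, "Z": float, "N": float, "A1": str, "A2": str},
        compression=infer_compression(path),
    )
    # Same post-processing as _read_sumstats: drop NAs and duplicated rs numbers.
    return df.dropna(how="any").drop_duplicates(subset="SNP")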
@@ -57,60 +57,60 @@ def get_compression(fh):
 
 
 def read_csv(fh, **kwargs):
-    '''
+    """
     Read the csv data
-    '''
-    return pd.read_csv(fh, sep='\s+', na_values='.', **kwargs)
+    """
+    return pd.read_csv(fh, sep=r"\s+", na_values=".", **kwargs)
 
 
-# Fun for reading loading LD scores
+# Fun for reading loading LD scores
 def which_compression(fh):
-    '''
+    """
     Given a file prefix, figure out what sort of compression to use.
-    '''
-    if os.access(fh + '.bz2', 4):
-        suffix = '.bz2'
-        compression = 'bz2'
-    elif os.access(fh + '.gz', 4):
-        suffix = '.gz'
-        compression = 'gzip'
-    elif os.access(fh + '.parquet', 4):
-        suffix = '.parquet'
-        compression = 'parquet'
-    elif os.access(fh + '.feather', 4):
-        suffix = '.feather'
-        compression = 'feather'
+    """
+    if os.access(fh + ".bz2", 4):
+        suffix = ".bz2"
+        compression = "bz2"
+    elif os.access(fh + ".gz", 4):
+        suffix = ".gz"
+        compression = "gzip"
+    elif os.access(fh + ".parquet", 4):
+        suffix = ".parquet"
+        compression = "parquet"
+    elif os.access(fh + ".feather", 4):
+        suffix = ".feather"
+        compression = "feather"
     elif os.access(fh, 4):
-        suffix = ''
+        suffix = ""
         compression = None
     else:
-        raise IOError('Could not open {F}[./gz/bz2/parquet/feather]'.format(F=fh))
+        raise OSError(f"Could not open {fh}[./gz/bz2/parquet/feather]")
     # -
     return suffix, compression
 
 
 def _read_ref_ld(ld_file):
-    suffix = '.l2.ldscore'
+    suffix = ".l2.ldscore"
     file = ld_file
-    first_fh = f'{file}1{suffix}'
+    first_fh = f"{file}1{suffix}"
     s, compression = which_compression(first_fh)
     #
     ldscore_array = []
-    print(f'Reading ld score annotations from {file}[1-22]{suffix}.{compression}')
+    print(f"Reading ld score annotations from {file}[1-22]{suffix}.{compression}")
 
     for chr in range(1, 23):
-        file_chr = f'{file}{chr}{suffix}{s}'
+        file_chr = f"{file}{chr}{suffix}{s}"
         #
-        if compression == 'parquet':
+        if compression == "parquet":
            x = pd.read_parquet(file_chr)
-        elif compression == 'feather':
+        elif compression == "feather":
            x = pd.read_feather(file_chr)
         else:
-            x = pd.read_csv(file_chr, compression=compression, sep='\t')
+            x = pd.read_csv(file_chr, compression=compression, sep="\t")
 
-        x = x.sort_values(by=['CHR', 'BP'])  # SEs will be wrong unless sorted
+        x = x.sort_values(by=["CHR", "BP"])  # SEs will be wrong unless sorted
 
-        columns_to_drop = ['MAF', 'CM', 'Gene', 'TSS', 'CHR', 'BP']
+        columns_to_drop = ["MAF", "CM", "Gene", "TSS", "CHR", "BP"]
         columns_to_drop = [col for col in columns_to_drop if col in x.columns]
         x = x.drop(columns_to_drop, axis=1)
 
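Aside: which_compression probes for a readable sibling file; the bare 4 it passes to os.access is the value of os.R_OK. A sketch of that probing plus the per-chromosome concatenation done by _read_ref_ld, assuming the 1-22 naming scheme visible above (prefix + chromosome + .l2.ldscore + optional compression suffix); the function names here are illustrative, not gsMap's API.

import os

import pandas as pd


def probe_suffix(prefix: str):
    """Mirror of which_compression: try readable siblings in a fixed order."""
    candidates = [
        (".bz2", "bz2"),
        (".gz", "gzip"),
        (".parquet", "parquet"),
        (".feather", "feather"),
        ("", None),
    ]
    for suffix, compression in candidates:
        if os.access(prefix + suffix, os.R_OK):  # os.R_OK == 4
            return suffix, compression
    raise OSError(f"Could not open {prefix}[./gz/bz2/parquet/feather]")


def read_ldscores(file_prefix: str) -> pd.DataFrame:
    """Concatenate per-chromosome LD score tables, as _read_ref_ld does."""
    suffix, compression = probe_suffix(f"{file_prefix}1.l2.ldscore")
    frames = []
    for chrom in range(1, 23):
        path = f"{file_prefix}{chrom}.l2.ldscore{suffix}"
        if compression == "parquet":
            frames.append(pd.read_parquet(path))
        elif compression == "feather":
            frames.append(pd.read_feather(path))
        else:
            frames.append(pd.read_csv(path, sep="\t", compression=compression))
    return pd.concat(frames, axis=0)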
@@ -121,56 +121,63 @@ def _read_ref_ld(ld_file):
 
 
 def _read_ref_ld_v2(ld_file):
-    suffix = '.l2.ldscore'
+    suffix = ".l2.ldscore"
     file = ld_file
-    first_fh = f'{file}1{suffix}'
+    first_fh = f"{file}1{suffix}"
     s, compression = which_compression(first_fh)
-    print(f'Reading ld score annotations from {file}[1-22]{suffix}.{compression}')
+    print(f"Reading ld score annotations from {file}[1-22]{suffix}.{compression}")
     ref_ld = pd.concat(
-        [pd.read_feather(f'{file}{chr}{suffix}{s}') for chr in range(1, 23)], axis=0
+        [pd.read_feather(f"{file}{chr}{suffix}{s}") for chr in range(1, 23)], axis=0
     )
     # set first column as index
-    ref_ld.rename(columns={'index': 'SNP'}, inplace=True)
-    ref_ld.set_index('SNP', inplace=True)
+    ref_ld.rename(columns={"index": "SNP"}, inplace=True)
+    ref_ld.set_index("SNP", inplace=True)
     return ref_ld
 
+
 def _read_M_v2(ld_file, n_annot, not_M_5_50):
-    suffix = '.l2.M'
+    suffix = ".l2.M"
     if not not_M_5_50:
-        suffix += '_5_50'
-    M_annot= np.array(
+        suffix += "_5_50"
+    M_annot = np.array(
         [
-            np.loadtxt( …
-            …
-            …
+            np.loadtxt(
+                f"{ld_file}{chr}{suffix}",
+            )
+            for chr in range(1, 23)
+        ]
     )
     assert M_annot.shape == (22, n_annot)
     return M_annot.sum(axis=0).reshape((1, n_annot))
-
+
+
+# Fun for reading M annotations
 def _read_M(ld_file, n_annot, not_M_5_50):
-    '''
+    """
     Read M (--M, --M-file, etc).
-    '''
+    """
     M_annot = M(ld_file, common=(not not_M_5_50))
 
     try:
         M_annot = np.array(M_annot).reshape((1, n_annot))
     except ValueError as e:
-        raise ValueError('# terms in --M must match # of LD Scores in --ref-ld.\n' + str(e.args))
+        raise ValueError(
+            "# terms in --M must match # of LD Scores in --ref-ld.\n" + str(e.args)
+        ) from e
     return M_annot
 
 
 def M(fh, common=False):
-    '''
+    """
     Parses .l{N}.M files, split across num chromosomes.
-    '''
-    suffix = '.l2.M'
+    """
+    suffix = ".l2.M"
     if common:
-        suffix += '_5_50'
+        suffix += "_5_50"
     # -
     M_array = []
     for i in range(1, 23):
-        M_current = pd.read_csv(f'{fh}{i}' + suffix, header=None)
+        M_current = pd.read_csv(f"{fh}{i}" + suffix, header=None)
         M_array.append(M_current)
 
     M_array = pd.concat(M_array, axis=1).sum(axis=1)
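Aside, for readers unfamiliar with the .l2.M files handled above: each file holds one row of per-annotation SNP counts for one chromosome, and the regression needs their genome-wide sum as a (1, n_annot) array. A toy example of that aggregation, with hypothetical counts for three chromosomes and two annotations:

import numpy as np

# Hypothetical per-chromosome M rows (22 rows in practice, one per autosome).
per_chrom = np.array(
    [
        [1000.0, 250.0],  # chr1: SNP counts for two annotations
        [800.0, 190.0],   # chr2
        [600.0, 140.0],   # chr3
    ]
)
m_total = per_chrom.sum(axis=0).reshape(1, -1)
print(m_total)  # [[2400.  580.]] -- same shape contract as _read_M_v2's output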
@@ -179,100 +186,76 @@ def M(fh, common=False):
 
 
 def _check_variance(M_annot, ref_ld):
-    '''
+    """
     Remove zero-variance LD Scores.
-    '''
+    """
     ii = ref_ld.iloc[:, 1:].var() == 0  # NB there is a SNP column here
     if ii.all():
-        raise ValueError('All LD Scores have zero variance.')
+        raise ValueError("All LD Scores have zero variance.")
     else:
-        print('Removing partitioned LD Scores with zero variance.')
+        print("Removing partitioned LD Scores with zero variance.")
         ii_snp = np.array([True] + list(~ii))
         ii_m = np.array(~ii)
         ref_ld = ref_ld.iloc[:, ii_snp]
         M_annot = M_annot[:, ii_m]
     # -
     return M_annot, ref_ld, ii
+
+
 def _check_variance_v2(M_annot, ref_ld):
     ii = ref_ld.var() == 0
     if ii.all():
-        raise ValueError('All LD Scores have zero variance.')
+        raise ValueError("All LD Scores have zero variance.")
     elif not ii.any():
-        print('No partitioned LD Scores have zero variance.')
+        print("No partitioned LD Scores have zero variance.")
     else:
-        ii_snp= ii_m = np.array(~ii)
-        print(f'Removing {sum(ii)} partitioned LD Scores with zero variance.')
+        ii_snp = ii_m = np.array(~ii)
+        print(f"Removing {sum(ii)} partitioned LD Scores with zero variance.")
     ref_ld = ref_ld.iloc[:, ii_snp]
     M_annot = M_annot[:, ii_m]
     return M_annot, ref_ld
 
 
-# Fun for reading regression weights
-def which_compression(fh):
-    '''
-    Given a file prefix, figure out what sort of compression to use.
-    '''
-    if os.access(fh + '.bz2', 4):
-        suffix = '.bz2'
-        compression = 'bz2'
-    elif os.access(fh + '.gz', 4):
-        suffix = '.gz'
-        compression = 'gzip'
-    elif os.access(fh + '.parquet', 4):
-        suffix = '.parquet'
-        compression = 'parquet'
-    elif os.access(fh + '.feather', 4):
-        suffix = '.feather'
-        compression = 'feather'
-    elif os.access(fh, 4):
-        suffix = ''
-        compression = None
-    else:
-        raise IOError('Could not open {F}[./gz/bz2/parquet/feather]'.format(F=fh))
-    # -
-    return suffix, compression
-
-
 def _read_w_ld(w_file):
-    suffix = '.l2.ldscore'
+    suffix = ".l2.ldscore"
     file = w_file
-    first_fh = f'{file}1{suffix}'
+    first_fh = f"{file}1{suffix}"
     s, compression = which_compression(first_fh)
     #
     w_array = []
-    print(f'Reading ld score annotations from {file}[1-22]{suffix}.{compression}')
+    print(f"Reading ld score annotations from {file}[1-22]{suffix}.{compression}")
 
     for chr in range(1, 23):
-        file_chr = f'{file}{chr}{suffix}{s}'
+        file_chr = f"{file}{chr}{suffix}{s}"
         #
-        if compression == 'parquet':
+        if compression == "parquet":
            x = pd.read_parquet(file_chr)
-        elif compression == 'feather':
+        elif compression == "feather":
            x = pd.read_feather(file_chr)
         else:
-            x = pd.read_csv(file_chr, compression=compression, sep='\t')
+            x = pd.read_csv(file_chr, compression=compression, sep="\t")
 
-        x = x.sort_values(by=['CHR', 'BP'])
+        x = x.sort_values(by=["CHR", "BP"])
 
-        columns_to_drop = ['MAF', 'CM', 'Gene', 'TSS', 'CHR', 'BP']
+        columns_to_drop = ["MAF", "CM", "Gene", "TSS", "CHR", "BP"]
         columns_to_drop = [col for col in columns_to_drop if col in x.columns]
         x = x.drop(columns_to_drop, axis=1)
 
         w_array.append(x)
     #
     w_ld = pd.concat(w_array, axis=0)
-    w_ld.columns = ['SNP', 'LD_weights']
+    w_ld.columns = ["SNP", "LD_weights"]
 
     return w_ld
 
 
 # Fun for merging
 def _merge_and_log(ld, sumstats, noun):
-    '''
+    """
     Wrap smart merge with log messages about # of SNPs.
-    '''
+    """
     sumstats = smart_merge(ld, sumstats)
-    msg = 'After merging with {F}, {N} SNPs remain.'
+    msg = "After merging with {F}, {N} SNPs remain."
     if len(sumstats) == 0:
         raise ValueError(msg.format(N=len(sumstats), F=noun))
     else:
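A note on the variance checks above: an annotation whose LD scores are constant across SNPs carries no signal and would make the regression design matrix rank-deficient, so its column is dropped before fitting. A small sketch of the _check_variance_v2 logic on a toy table (standalone helper, not gsMap's API):

import numpy as np
import pandas as pd


def drop_zero_variance(ref_ld: pd.DataFrame, m_annot: np.ndarray):
    """Mirror of _check_variance_v2: ref_ld has one column per annotation."""
    zero_var = ref_ld.var() == 0
    if zero_var.all():
        raise ValueError("All LD Scores have zero variance.")
    keep = ~zero_var.to_numpy()
    return m_annot[:, keep], ref_ld.loc[:, keep]


ref_ld = pd.DataFrame({"base": [1.0, 2.0, 3.0], "flat": [5.0, 5.0, 5.0]})
m_annot = np.array([[1000.0, 200.0]])
m2, ld2 = drop_zero_variance(ref_ld, m_annot)
print(ld2.columns.tolist())  # ['base'] -- the constant column is removed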
@@ -282,13 +265,13 @@ def _merge_and_log(ld, sumstats, noun):
 
 
 def smart_merge(x, y):
-    '''
+    """
     Check if SNP columns are equal. If so, save time by using concat instead of merge.
-    '''
+    """
     if len(x) == len(y) and (x.index == y.index).all() and (x.SNP == y.SNP).all():
         x = x.reset_index(drop=True)
-        y = y.reset_index(drop=True).drop('SNP', 1)
+        y = y.reset_index(drop=True).drop("SNP", 1)
         out = pd.concat([x, y], axis=1)
     else:
-        out = pd.merge(x, y, how='inner', on='SNP')
+        out = pd.merge(x, y, how="inner", on="SNP")
     return out
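One behavioral note on smart_merge, which this release reformats but does not change: the fast path calls DataFrame.drop with a positional axis argument (.drop("SNP", 1)), a form pandas deprecated in 1.3 and removed in 2.0. A keyword-only sketch of the same decision rule, should that call ever need updating:

import pandas as pd


def smart_merge(x: pd.DataFrame, y: pd.DataFrame) -> pd.DataFrame:
    """Same rule as above, using drop(columns=...) instead of a positional axis."""
    same_snps = len(x) == len(y) and (x.index == y.index).all() and (x.SNP == y.SNP).all()
    if same_snps:
        # SNP columns already aligned row-for-row: concat is cheaper than merge.
        x = x.reset_index(drop=True)
        y = y.reset_index(drop=True).drop(columns="SNP")
        return pd.concat([x, y], axis=1)
    return pd.merge(x, y, how="inner", on="SNP")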
gsMap/visualize.py
CHANGED
@@ -11,35 +11,41 @@ from gsMap.config import VisualizeConfig
 
 
 def load_ldsc(ldsc_input_file):
-    ldsc = pd.read_csv( …
-    …
-    …
-    …
+    ldsc = pd.read_csv(
+        ldsc_input_file,
+        compression="gzip",
+        dtype={"spot": str, "p": float},
+        index_col="spot",
+        usecols=["spot", "p"],
+    )
+    ldsc["logp"] = -np.log10(ldsc.p)
     return ldsc
 
 
 # %%
 def load_st_coord(adata, feature_series: pd.Series, annotation):
     spot_name = adata.obs_names.to_list()
-    assert 'spatial' in adata.obsm.keys(), 'spatial coordinates are not found in adata.obsm'
+    assert "spatial" in adata.obsm.keys(), "spatial coordinates are not found in adata.obsm"
 
     # to DataFrame
-    space_coord = adata.obsm['spatial']
+    space_coord = adata.obsm["spatial"]
     if isinstance(space_coord, np.ndarray):
-        space_coord = pd.DataFrame(space_coord, columns=['sx', 'sy'], index=spot_name)
+        space_coord = pd.DataFrame(space_coord, columns=["sx", "sy"], index=spot_name)
     else:
-        space_coord = pd.DataFrame(space_coord.values, columns=['sx', 'sy'], index=spot_name)
+        space_coord = pd.DataFrame(space_coord.values, columns=["sx", "sy"], index=spot_name)
 
     space_coord = space_coord[space_coord.index.isin(feature_series.index)]
     space_coord_concat = pd.concat([space_coord.loc[feature_series.index], feature_series], axis=1)
     space_coord_concat.head()
     if annotation is not None:
-        annotation = pd.Series(adata.obs[annotation].values, index=adata.obs_names, name='annotation')
+        annotation = pd.Series(
+            adata.obs[annotation].values, index=adata.obs_names, name="annotation"
+        )
         space_coord_concat = pd.concat([space_coord_concat, annotation], axis=1)
     return space_coord_concat
 
 
-def estimate_point_size_for_plot(coordinates, DEFAULT_PIXEL_WIDTH = 1000):
+def estimate_point_size_for_plot(coordinates, DEFAULT_PIXEL_WIDTH=1000):
     tree = KDTree(coordinates)
     distances, _ = tree.query(coordinates, k=2)
     avg_min_distance = np.mean(distances[:, 1])
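Aside: the point-size estimate that begins here queries a KD-tree with k=2 because each spot's nearest neighbor within its own set is itself, at distance zero; the second hit is the true nearest neighbor. A sketch of that idea on toy coordinates (the final scaling constant is hypothetical, since most of the function body is unchanged context not shown in this diff):

import numpy as np
from scipy.spatial import KDTree

coords = np.random.default_rng(0).uniform(0, 100, size=(500, 2))  # toy spot coordinates

tree = KDTree(coords)
distances, _ = tree.query(coords, k=2)     # column 0 is the self-distance (0.0)
avg_min_distance = distances[:, 1].mean()  # mean spacing between neighboring spots

# Hypothetical conversion from spot spacing to a pixel marker size.
pixel_width = 1000
span = coords[:, 0].max() - coords[:, 0].min()
point_size = pixel_width * avg_min_distance / span
print(round(point_size, 2))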
@@ -55,34 +61,42 @@ def estimate_point_size_for_plot(coordinates, DEFAULT_PIXEL_WIDTH = 1000):
     return (pixel_width, pixel_height), point_size
 
 
-def draw_scatter(space_coord_concat, title=None, fig_style: Literal['dark', 'light'] = 'light', …
-    …
+def draw_scatter(
+    space_coord_concat,
+    title=None,
+    fig_style: Literal["dark", "light"] = "light",
+    point_size: int = None,
+    width=800,
+    height=600,
+    annotation=None,
+    color_by="logp",
+):
     # Set theme based on fig_style
-    if fig_style == 'dark':
+    if fig_style == "dark":
         px.defaults.template = "plotly_dark"
     else:
         px.defaults.template = "plotly_white"
 
     custom_color_scale = [
-        (1, '#d73027'),  # Red
-        (7 / 8, '#f46d43'),  # Red-Orange
-        (6 / 8, '#fdae61'),  # Orange
-        (5 / 8, '#fee090'),  # Light Orange
-        (4 / 8, '#e0f3f8'),  # Light Blue
-        (3 / 8, '#abd9e9'),  # Sky Blue
-        (2 / 8, '#74add1'),  # Medium Blue
-        (1 / 8, '#4575b4'),  # Dark Blue
-        (0, '#313695')  # Deep Blue
+        (1, "#d73027"),  # Red
+        (7 / 8, "#f46d43"),  # Red-Orange
+        (6 / 8, "#fdae61"),  # Orange
+        (5 / 8, "#fee090"),  # Light Orange
+        (4 / 8, "#e0f3f8"),  # Light Blue
+        (3 / 8, "#abd9e9"),  # Sky Blue
+        (2 / 8, "#74add1"),  # Medium Blue
+        (1 / 8, "#4575b4"),  # Dark Blue
+        (0, "#313695"),  # Deep Blue
     ]
     custom_color_scale.reverse()
 
     # Create the scatter plot
     fig = px.scatter(
         space_coord_concat,
-        x='sx',
-        y='sy',
+        x="sx",
+        y="sy",
         color=color_by,
-        symbol='annotation' if annotation is not None else None,
+        symbol="annotation" if annotation is not None else None,
         title=title,
         color_continuous_scale=custom_color_scale,
         range_color=[0, max(space_coord_concat[color_by])],
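Aside: the custom scale built above is a list of (position, color) stops, which Plotly Express accepts directly for continuous coloring; reversing it puts blue at low values and red at high. A reduced, self-contained example with toy data (the column names match load_st_coord's output; the three stops stand in for the nine-stop scale above):

import pandas as pd
import plotly.express as px

df = pd.DataFrame({"sx": [0, 1, 2, 3], "sy": [0, 1, 0, 1], "logp": [0.5, 1.5, 3.0, 4.5]})

# Three (position, color) stops standing in for the full RdYlBu-like scale.
scale = [(0.0, "#313695"), (0.5, "#e0f3f8"), (1.0, "#d73027")]

fig = px.scatter(
    df,
    x="sx",
    y="sy",
    color="logp",
    color_continuous_scale=scale,
    range_color=[0, df["logp"].max()],
)
fig.update_traces(marker=dict(size=12, symbol="circle"))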
@@ -90,7 +104,7 @@ def draw_scatter(space_coord_concat, title=None, fig_style: Literal['dark', 'lig
 
     # Update marker size if specified
     if point_size is not None:
-        fig.update_traces(marker=dict(size=point_size, symbol='circle'))
+        fig.update_traces(marker=dict(size=point_size, symbol="circle"))
 
     # Update layout for figure size
     fig.update_layout(
@@ -108,45 +122,45 @@ def draw_scatter(space_coord_concat, title=None, fig_style: Literal['dark', 'lig
             x=1.0,
             font=dict(
                 size=10,
-            )
+            ),
         )
     )
 
     # Update colorbar to be at the bottom and horizontal
     fig.update_layout(
         coloraxis_colorbar=dict(
-            orientation='h',  # Make the colorbar horizontal
+            orientation="h",  # Make the colorbar horizontal
             x=0.5,  # Center the colorbar horizontally
             y=-0.0,  # Position below the plot
-            xanchor='center',  # Anchor the colorbar at the center
-            yanchor='top',  # Anchor the colorbar at the top to keep it just below the plot
+            xanchor="center",  # Anchor the colorbar at the center
+            yanchor="top",  # Anchor the colorbar at the top to keep it just below the plot
             len=0.75,  # Length of the colorbar relative to the plot width
             title=dict(
-                text='-log10(p)' if color_by == 'logp' else color_by,  # Colorbar title
-                side='top'  # Place the title at the top of the colorbar
-            )
+                text="-log10(p)" if color_by == "logp" else color_by,  # Colorbar title
+                side="top",  # Place the title at the top of the colorbar
+            ),
         )
     )
     # Remove gridlines, axis labels, and ticks
     fig.update_xaxes(
-        showgrid=False,
-        zeroline=False,
+        showgrid=False,  # Hide x-axis gridlines
+        zeroline=False,  # Hide x-axis zero line
         showticklabels=False,  # Hide x-axis tick labels
-        title=None,
-        scaleanchor='y',  # Link the x-axis scale to the y-axis scale
+        title=None,  # Remove x-axis title
+        scaleanchor="y",  # Link the x-axis scale to the y-axis scale
     )
 
     fig.update_yaxes(
-        showgrid=False,
-        zeroline=False,
+        showgrid=False,  # Hide y-axis gridlines
+        zeroline=False,  # Hide y-axis zero line
         showticklabels=False,  # Hide y-axis tick labels
-        title=None
+        title=None,  # Remove y-axis title
     )
 
     # Adjust margins to ensure no clipping and equal axis ratio
     fig.update_layout(
         margin=dict(l=0, r=0, t=20, b=10),  # Adjust margins to prevent clipping
-        height=width  # Ensure the figure height matches the width for equal axis ratio
+        height=width,  # Ensure the figure height matches the width for equal axis ratio
     )
 
     # Adjust the title location and font size
@@ -154,46 +168,50 @@ def draw_scatter(space_coord_concat, title=None, fig_style: Literal['dark', 'lig
         title=dict(
             y=0.98,
             x=0.5,  # Center the title horizontally
-            xanchor='center',  # Anchor the title at the center
-            yanchor='top',  # Anchor the title at the top
+            xanchor="center",  # Anchor the title at the center
+            yanchor="top",  # Anchor the title at the top
             font=dict(
                 size=20  # Increase the title font size
-            )
-        )
+            ),
+        )
+    )
 
     return fig
 
 
-
 def run_Visualize(config: VisualizeConfig):
-    print(f'------Loading LDSC results of {config.ldsc_save_dir}...')
-    ldsc = load_ldsc(ldsc_input_file=Path(config.ldsc_save_dir) / f'{config.sample_name}_{config.trait_name}.csv.gz')
+    print(f"------Loading LDSC results of {config.ldsc_save_dir}...")
+    ldsc = load_ldsc(
+        ldsc_input_file=Path(config.ldsc_save_dir)
+        / f"{config.sample_name}_{config.trait_name}.csv.gz"
+    )
 
-    print(f'------Loading ST data of {config.sample_name}...')
-    adata = sc.read_h5ad(f'{config.hdf5_with_latent_path}')
+    print(f"------Loading ST data of {config.sample_name}...")
+    adata = sc.read_h5ad(f"{config.hdf5_with_latent_path}")
 
     space_coord_concat = load_st_coord(adata, ldsc, annotation=config.annotation)
-    fig = draw_scatter( …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
+    fig = draw_scatter(
+        space_coord_concat,
+        title=config.fig_title,
+        fig_style=config.fig_style,
+        point_size=config.point_size,
+        width=config.fig_width,
+        height=config.fig_height,
+        annotation=config.annotation,
+    )
 
     # Visualization
     output_dir = Path(config.output_dir)
     output_dir.mkdir(parents=True, exist_ok=True, mode=0o755)
-    output_file_html = output_dir / f'{config.sample_name}_{config.trait_name}.html'
-    output_file_pdf = output_dir / f'{config.sample_name}_{config.trait_name}.pdf'
-    output_file_csv = output_dir / f'{config.sample_name}_{config.trait_name}.csv'
+    output_file_html = output_dir / f"{config.sample_name}_{config.trait_name}.html"
+    output_file_pdf = output_dir / f"{config.sample_name}_{config.trait_name}.pdf"
+    output_file_csv = output_dir / f"{config.sample_name}_{config.trait_name}.csv"
 
     fig.write_html(str(output_file_html))
     fig.write_image(str(output_file_pdf))
     space_coord_concat.to_csv(str(output_file_csv))
 
     print(
-        f'------The visualization result is saved in a html file: {output_file_html} which can interactively viewed in a web browser and a pdf file: {output_file_pdf}.'
-        …
+        f"------The visualization result is saved in a html file: {output_file_html} which can interactively viewed in a web browser and a pdf file: {output_file_pdf}."
+    )
+    print(f"------The visualization data is saved in a csv file: {output_file_csv}.")