gsMap 1.71.2__py3-none-any.whl → 1.72.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,55 +1,55 @@
1
+ import os
2
+
1
3
  import numpy as np
2
4
  import pandas as pd
3
- import os
4
5
 
5
6
 
6
7
  # Fun for reading gwas data
7
8
  def _read_sumstats(fh, alleles=False, dropna=False):
8
- '''
9
+ """
9
10
  Parse gwas summary statistics.
10
- '''
11
- print('Reading summary statistics from {S} ...'.format(S=fh))
11
+ """
12
+ print(f"Reading summary statistics from {fh} ...")
12
13
  sumstats = ps_sumstats(fh, alleles=alleles, dropna=dropna)
13
- print('Read summary statistics for {N} SNPs.'.format(N=len(sumstats)))
14
+ print(f"Read summary statistics for {len(sumstats)} SNPs.")
14
15
 
15
16
  m = len(sumstats)
16
- sumstats = sumstats.drop_duplicates(subset='SNP')
17
+ sumstats = sumstats.drop_duplicates(subset="SNP")
17
18
  if m > len(sumstats):
18
- print('Dropped {M} SNPs with duplicated rs numbers.'.format(M=m - len(sumstats)))
19
+ print(f"Dropped {m - len(sumstats)} SNPs with duplicated rs numbers.")
19
20
 
20
21
  return sumstats
21
22
 
22
23
 
23
24
  def ps_sumstats(fh, alleles=False, dropna=True):
24
- '''
25
+ """
25
26
  Parses .sumstats files. See docs/file_formats_sumstats.txt.
26
- '''
27
-
28
- dtype_dict = {'SNP': str, 'Z': float, 'N': float, 'A1': str, 'A2': str}
27
+ """
28
+ dtype_dict = {"SNP": str, "Z": float, "N": float, "A1": str, "A2": str}
29
29
  compression = get_compression(fh)
30
- usecols = ['SNP', 'Z', 'N']
30
+ usecols = ["SNP", "Z", "N"]
31
31
  if alleles:
32
- usecols += ['A1', 'A2']
32
+ usecols += ["A1", "A2"]
33
33
 
34
34
  try:
35
35
  x = read_csv(fh, usecols=usecols, dtype=dtype_dict, compression=compression)
36
36
  except (AttributeError, ValueError) as e:
37
- raise ValueError('Improperly formatted sumstats file: ' + str(e.args))
37
+ raise ValueError("Improperly formatted sumstats file: " + str(e.args)) from e
38
38
 
39
39
  if dropna:
40
- x = x.dropna(how='any')
40
+ x = x.dropna(how="any")
41
41
 
42
42
  return x
43
43
 
44
44
 
45
45
  def get_compression(fh):
46
- '''
46
+ """
47
47
  Determin the format of compression used with read_csv?
48
- '''
49
- if fh.endswith('gz'):
50
- compression = 'gzip'
51
- elif fh.endswith('bz2'):
52
- compression = 'bz2'
48
+ """
49
+ if fh.endswith("gz"):
50
+ compression = "gzip"
51
+ elif fh.endswith("bz2"):
52
+ compression = "bz2"
53
53
  else:
54
54
  compression = None
55
55
  # -
@@ -57,60 +57,60 @@ def get_compression(fh):
57
57
 
58
58
 
59
59
  def read_csv(fh, **kwargs):
60
- '''
60
+ """
61
61
  Read the csv data
62
- '''
63
- return pd.read_csv(fh, sep='\s+', na_values='.', **kwargs)
62
+ """
63
+ return pd.read_csv(fh, sep=r"\s+", na_values=".", **kwargs)
64
64
 
65
65
 
66
- # Fun for reading loading LD scores
66
+ # Fun for reading loading LD scores
67
67
  def which_compression(fh):
68
- '''
68
+ """
69
69
  Given a file prefix, figure out what sort of compression to use.
70
- '''
71
- if os.access(fh + '.bz2', 4):
72
- suffix = '.bz2'
73
- compression = 'bz2'
74
- elif os.access(fh + '.gz', 4):
75
- suffix = '.gz'
76
- compression = 'gzip'
77
- elif os.access(fh + '.parquet', 4):
78
- suffix = '.parquet'
79
- compression = 'parquet'
80
- elif os.access(fh + '.feather', 4):
81
- suffix = '.feather'
82
- compression = 'feather'
70
+ """
71
+ if os.access(fh + ".bz2", 4):
72
+ suffix = ".bz2"
73
+ compression = "bz2"
74
+ elif os.access(fh + ".gz", 4):
75
+ suffix = ".gz"
76
+ compression = "gzip"
77
+ elif os.access(fh + ".parquet", 4):
78
+ suffix = ".parquet"
79
+ compression = "parquet"
80
+ elif os.access(fh + ".feather", 4):
81
+ suffix = ".feather"
82
+ compression = "feather"
83
83
  elif os.access(fh, 4):
84
- suffix = ''
84
+ suffix = ""
85
85
  compression = None
86
86
  else:
87
- raise IOError('Could not open {F}[./gz/bz2/parquet/feather]'.format(F=fh))
87
+ raise OSError(f"Could not open {fh}[./gz/bz2/parquet/feather]")
88
88
  # -
89
89
  return suffix, compression
90
90
 
91
91
 
92
92
  def _read_ref_ld(ld_file):
93
- suffix = '.l2.ldscore'
93
+ suffix = ".l2.ldscore"
94
94
  file = ld_file
95
- first_fh = f'{file}1{suffix}'
95
+ first_fh = f"{file}1{suffix}"
96
96
  s, compression = which_compression(first_fh)
97
97
  #
98
98
  ldscore_array = []
99
- print(f'Reading ld score annotations from {file}[1-22]{suffix}.{compression}')
99
+ print(f"Reading ld score annotations from {file}[1-22]{suffix}.{compression}")
100
100
 
101
101
  for chr in range(1, 23):
102
- file_chr = f'{file}{chr}{suffix}{s}'
102
+ file_chr = f"{file}{chr}{suffix}{s}"
103
103
  #
104
- if compression == 'parquet':
104
+ if compression == "parquet":
105
105
  x = pd.read_parquet(file_chr)
106
- elif compression == 'feather':
106
+ elif compression == "feather":
107
107
  x = pd.read_feather(file_chr)
108
108
  else:
109
- x = pd.read_csv(file_chr, compression=compression, sep='\t')
109
+ x = pd.read_csv(file_chr, compression=compression, sep="\t")
110
110
 
111
- x = x.sort_values(by=['CHR', 'BP']) # SEs will be wrong unless sorted
111
+ x = x.sort_values(by=["CHR", "BP"]) # SEs will be wrong unless sorted
112
112
 
113
- columns_to_drop = ['MAF', 'CM', 'Gene', 'TSS', 'CHR', 'BP']
113
+ columns_to_drop = ["MAF", "CM", "Gene", "TSS", "CHR", "BP"]
114
114
  columns_to_drop = [col for col in columns_to_drop if col in x.columns]
115
115
  x = x.drop(columns_to_drop, axis=1)
116
116
 
@@ -121,56 +121,63 @@ def _read_ref_ld(ld_file):
121
121
 
122
122
 
123
123
  def _read_ref_ld_v2(ld_file):
124
- suffix = '.l2.ldscore'
124
+ suffix = ".l2.ldscore"
125
125
  file = ld_file
126
- first_fh = f'{file}1{suffix}'
126
+ first_fh = f"{file}1{suffix}"
127
127
  s, compression = which_compression(first_fh)
128
- print(f'Reading ld score annotations from {file}[1-22]{suffix}.{compression}')
128
+ print(f"Reading ld score annotations from {file}[1-22]{suffix}.{compression}")
129
129
  ref_ld = pd.concat(
130
- [pd.read_feather(f'{file}{chr}{suffix}{s}') for chr in range(1, 23)], axis=0
130
+ [pd.read_feather(f"{file}{chr}{suffix}{s}") for chr in range(1, 23)], axis=0
131
131
  )
132
132
  # set first column as index
133
- ref_ld.rename(columns={'index': 'SNP'}, inplace=True)
134
- ref_ld.set_index('SNP', inplace=True)
133
+ ref_ld.rename(columns={"index": "SNP"}, inplace=True)
134
+ ref_ld.set_index("SNP", inplace=True)
135
135
  return ref_ld
136
136
 
137
+
137
138
  def _read_M_v2(ld_file, n_annot, not_M_5_50):
138
- suffix = '.l2.M'
139
+ suffix = ".l2.M"
139
140
  if not not_M_5_50:
140
- suffix += '_5_50'
141
- M_annot= np.array(
141
+ suffix += "_5_50"
142
+ M_annot = np.array(
142
143
  [
143
- np.loadtxt(f'{ld_file}{chr}{suffix}', )
144
- for chr in range(1, 23)]
145
-
144
+ np.loadtxt(
145
+ f"{ld_file}{chr}{suffix}",
146
+ )
147
+ for chr in range(1, 23)
148
+ ]
146
149
  )
147
150
  assert M_annot.shape == (22, n_annot)
148
151
  return M_annot.sum(axis=0).reshape((1, n_annot))
149
- # Fun for reading M annotations
152
+
153
+
154
+ # Fun for reading M annotations
150
155
  def _read_M(ld_file, n_annot, not_M_5_50):
151
- '''
156
+ """
152
157
  Read M (--M, --M-file, etc).
153
- '''
158
+ """
154
159
  M_annot = M(ld_file, common=(not not_M_5_50))
155
160
 
156
161
  try:
157
162
  M_annot = np.array(M_annot).reshape((1, n_annot))
158
163
  except ValueError as e:
159
- raise ValueError('# terms in --M must match # of LD Scores in --ref-ld.\n' + str(e.args))
164
+ raise ValueError(
165
+ "# terms in --M must match # of LD Scores in --ref-ld.\n" + str(e.args)
166
+ ) from e
160
167
  return M_annot
161
168
 
162
169
 
163
170
  def M(fh, common=False):
164
- '''
171
+ """
165
172
  Parses .l{N}.M files, split across num chromosomes.
166
- '''
167
- suffix = '.l2.M'
173
+ """
174
+ suffix = ".l2.M"
168
175
  if common:
169
- suffix += '_5_50'
176
+ suffix += "_5_50"
170
177
  # -
171
178
  M_array = []
172
179
  for i in range(1, 23):
173
- M_current = pd.read_csv(f'{fh}{i}' + suffix, header=None)
180
+ M_current = pd.read_csv(f"{fh}{i}" + suffix, header=None)
174
181
  M_array.append(M_current)
175
182
 
176
183
  M_array = pd.concat(M_array, axis=1).sum(axis=1)
@@ -179,100 +186,76 @@ def M(fh, common=False):
179
186
 
180
187
 
181
188
  def _check_variance(M_annot, ref_ld):
182
- '''
189
+ """
183
190
  Remove zero-variance LD Scores.
184
- '''
191
+ """
185
192
  ii = ref_ld.iloc[:, 1:].var() == 0 # NB there is a SNP column here
186
193
  if ii.all():
187
- raise ValueError('All LD Scores have zero variance.')
194
+ raise ValueError("All LD Scores have zero variance.")
188
195
  else:
189
- print('Removing partitioned LD Scores with zero variance.')
196
+ print("Removing partitioned LD Scores with zero variance.")
190
197
  ii_snp = np.array([True] + list(~ii))
191
198
  ii_m = np.array(~ii)
192
199
  ref_ld = ref_ld.iloc[:, ii_snp]
193
200
  M_annot = M_annot[:, ii_m]
194
201
  # -
195
202
  return M_annot, ref_ld, ii
203
+
204
+
196
205
  def _check_variance_v2(M_annot, ref_ld):
197
206
  ii = ref_ld.var() == 0
198
207
  if ii.all():
199
- raise ValueError('All LD Scores have zero variance.')
208
+ raise ValueError("All LD Scores have zero variance.")
200
209
  elif not ii.any():
201
- print('No partitioned LD Scores have zero variance.')
210
+ print("No partitioned LD Scores have zero variance.")
202
211
  else:
203
- ii_snp= ii_m = np.array(~ii)
204
- print(f'Removing {sum(ii)} partitioned LD Scores with zero variance.')
212
+ ii_snp = ii_m = np.array(~ii)
213
+ print(f"Removing {sum(ii)} partitioned LD Scores with zero variance.")
205
214
  ref_ld = ref_ld.iloc[:, ii_snp]
206
215
  M_annot = M_annot[:, ii_m]
207
216
  return M_annot, ref_ld
208
217
 
209
218
 
210
- # Fun for reading regression weights
211
- def which_compression(fh):
212
- '''
213
- Given a file prefix, figure out what sort of compression to use.
214
- '''
215
- if os.access(fh + '.bz2', 4):
216
- suffix = '.bz2'
217
- compression = 'bz2'
218
- elif os.access(fh + '.gz', 4):
219
- suffix = '.gz'
220
- compression = 'gzip'
221
- elif os.access(fh + '.parquet', 4):
222
- suffix = '.parquet'
223
- compression = 'parquet'
224
- elif os.access(fh + '.feather', 4):
225
- suffix = '.feather'
226
- compression = 'feather'
227
- elif os.access(fh, 4):
228
- suffix = ''
229
- compression = None
230
- else:
231
- raise IOError('Could not open {F}[./gz/bz2/parquet/feather]'.format(F=fh))
232
- # -
233
- return suffix, compression
234
-
235
-
236
219
  def _read_w_ld(w_file):
237
- suffix = '.l2.ldscore'
220
+ suffix = ".l2.ldscore"
238
221
  file = w_file
239
- first_fh = f'{file}1{suffix}'
222
+ first_fh = f"{file}1{suffix}"
240
223
  s, compression = which_compression(first_fh)
241
224
  #
242
225
  w_array = []
243
- print(f'Reading ld score annotations from {file}[1-22]{suffix}.{compression}')
226
+ print(f"Reading ld score annotations from {file}[1-22]{suffix}.{compression}")
244
227
 
245
228
  for chr in range(1, 23):
246
- file_chr = f'{file}{chr}{suffix}{s}'
229
+ file_chr = f"{file}{chr}{suffix}{s}"
247
230
  #
248
- if compression == 'parquet':
231
+ if compression == "parquet":
249
232
  x = pd.read_parquet(file_chr)
250
- elif compression == 'feather':
233
+ elif compression == "feather":
251
234
  x = pd.read_feather(file_chr)
252
235
  else:
253
- x = pd.read_csv(file_chr, compression=compression, sep='\t')
236
+ x = pd.read_csv(file_chr, compression=compression, sep="\t")
254
237
 
255
- x = x.sort_values(by=['CHR', 'BP'])
238
+ x = x.sort_values(by=["CHR", "BP"])
256
239
 
257
- columns_to_drop = ['MAF', 'CM', 'Gene', 'TSS', 'CHR', 'BP']
240
+ columns_to_drop = ["MAF", "CM", "Gene", "TSS", "CHR", "BP"]
258
241
  columns_to_drop = [col for col in columns_to_drop if col in x.columns]
259
242
  x = x.drop(columns_to_drop, axis=1)
260
243
 
261
244
  w_array.append(x)
262
245
  #
263
246
  w_ld = pd.concat(w_array, axis=0)
264
- w_ld.columns = ['SNP', 'LD_weights']
247
+ w_ld.columns = ["SNP", "LD_weights"]
265
248
 
266
249
  return w_ld
267
250
 
268
251
 
269
252
  # Fun for merging
270
253
  def _merge_and_log(ld, sumstats, noun):
271
- '''
254
+ """
272
255
  Wrap smart merge with log messages about # of SNPs.
273
- '''
256
+ """
274
257
  sumstats = smart_merge(ld, sumstats)
275
- msg = 'After merging with {F}, {N} SNPs remain.'
258
+ msg = "After merging with {F}, {N} SNPs remain."
276
259
  if len(sumstats) == 0:
277
260
  raise ValueError(msg.format(N=len(sumstats), F=noun))
278
261
  else:
@@ -282,13 +265,13 @@ def _merge_and_log(ld, sumstats, noun):
282
265
 
283
266
 
284
267
  def smart_merge(x, y):
285
- '''
268
+ """
286
269
  Check if SNP columns are equal. If so, save time by using concat instead of merge.
287
- '''
270
+ """
288
271
  if len(x) == len(y) and (x.index == y.index).all() and (x.SNP == y.SNP).all():
289
272
  x = x.reset_index(drop=True)
290
- y = y.reset_index(drop=True).drop('SNP', 1)
273
+ y = y.reset_index(drop=True).drop("SNP", 1)
291
274
  out = pd.concat([x, y], axis=1)
292
275
  else:
293
- out = pd.merge(x, y, how='inner', on='SNP')
276
+ out = pd.merge(x, y, how="inner", on="SNP")
294
277
  return out
gsMap/visualize.py CHANGED
@@ -11,35 +11,41 @@ from gsMap.config import VisualizeConfig
11
11
 
12
12
 
13
13
  def load_ldsc(ldsc_input_file):
14
- ldsc = pd.read_csv(ldsc_input_file, compression='gzip')
15
- ldsc.spot = ldsc.spot.astype(str).replace('\.0', '', regex=True)
16
- ldsc.index = ldsc.spot
17
- ldsc['logp'] = -np.log10(ldsc.p)
14
+ ldsc = pd.read_csv(
15
+ ldsc_input_file,
16
+ compression="gzip",
17
+ dtype={"spot": str, "p": float},
18
+ index_col="spot",
19
+ usecols=["spot", "p"],
20
+ )
21
+ ldsc["logp"] = -np.log10(ldsc.p)
18
22
  return ldsc
19
23
 
20
24
 
21
25
  # %%
22
26
  def load_st_coord(adata, feature_series: pd.Series, annotation):
23
27
  spot_name = adata.obs_names.to_list()
24
- assert 'spatial' in adata.obsm.keys(), 'spatial coordinates are not found in adata.obsm'
28
+ assert "spatial" in adata.obsm.keys(), "spatial coordinates are not found in adata.obsm"
25
29
 
26
30
  # to DataFrame
27
- space_coord = adata.obsm['spatial']
31
+ space_coord = adata.obsm["spatial"]
28
32
  if isinstance(space_coord, np.ndarray):
29
- space_coord = pd.DataFrame(space_coord, columns=['sx', 'sy'], index=spot_name)
33
+ space_coord = pd.DataFrame(space_coord, columns=["sx", "sy"], index=spot_name)
30
34
  else:
31
- space_coord = pd.DataFrame(space_coord.values, columns=['sx', 'sy'], index=spot_name)
35
+ space_coord = pd.DataFrame(space_coord.values, columns=["sx", "sy"], index=spot_name)
32
36
 
33
37
  space_coord = space_coord[space_coord.index.isin(feature_series.index)]
34
38
  space_coord_concat = pd.concat([space_coord.loc[feature_series.index], feature_series], axis=1)
35
39
  space_coord_concat.head()
36
40
  if annotation is not None:
37
- annotation = pd.Series(adata.obs[annotation].values, index=adata.obs_names, name='annotation')
41
+ annotation = pd.Series(
42
+ adata.obs[annotation].values, index=adata.obs_names, name="annotation"
43
+ )
38
44
  space_coord_concat = pd.concat([space_coord_concat, annotation], axis=1)
39
45
  return space_coord_concat
40
46
 
41
47
 
42
- def estimate_point_size_for_plot(coordinates, DEFAULT_PIXEL_WIDTH = 1000):
48
+ def estimate_point_size_for_plot(coordinates, DEFAULT_PIXEL_WIDTH=1000):
43
49
  tree = KDTree(coordinates)
44
50
  distances, _ = tree.query(coordinates, k=2)
45
51
  avg_min_distance = np.mean(distances[:, 1])
@@ -55,34 +61,42 @@ def estimate_point_size_for_plot(coordinates, DEFAULT_PIXEL_WIDTH = 1000):
55
61
  return (pixel_width, pixel_height), point_size
56
62
 
57
63
 
58
- def draw_scatter(space_coord_concat, title=None, fig_style: Literal['dark', 'light'] = 'light',
59
- point_size: int = None, width=800, height=600, annotation=None, color_by='logp'):
64
+ def draw_scatter(
65
+ space_coord_concat,
66
+ title=None,
67
+ fig_style: Literal["dark", "light"] = "light",
68
+ point_size: int = None,
69
+ width=800,
70
+ height=600,
71
+ annotation=None,
72
+ color_by="logp",
73
+ ):
60
74
  # Set theme based on fig_style
61
- if fig_style == 'dark':
75
+ if fig_style == "dark":
62
76
  px.defaults.template = "plotly_dark"
63
77
  else:
64
78
  px.defaults.template = "plotly_white"
65
79
 
66
80
  custom_color_scale = [
67
- (1, '#d73027'), # Red
68
- (7 / 8, '#f46d43'), # Red-Orange
69
- (6 / 8, '#fdae61'), # Orange
70
- (5 / 8, '#fee090'), # Light Orange
71
- (4 / 8, '#e0f3f8'), # Light Blue
72
- (3 / 8, '#abd9e9'), # Sky Blue
73
- (2 / 8, '#74add1'), # Medium Blue
74
- (1 / 8, '#4575b4'), # Dark Blue
75
- (0, '#313695') # Deep Blue
81
+ (1, "#d73027"), # Red
82
+ (7 / 8, "#f46d43"), # Red-Orange
83
+ (6 / 8, "#fdae61"), # Orange
84
+ (5 / 8, "#fee090"), # Light Orange
85
+ (4 / 8, "#e0f3f8"), # Light Blue
86
+ (3 / 8, "#abd9e9"), # Sky Blue
87
+ (2 / 8, "#74add1"), # Medium Blue
88
+ (1 / 8, "#4575b4"), # Dark Blue
89
+ (0, "#313695"), # Deep Blue
76
90
  ]
77
91
  custom_color_scale.reverse()
78
92
 
79
93
  # Create the scatter plot
80
94
  fig = px.scatter(
81
95
  space_coord_concat,
82
- x='sx',
83
- y='sy',
96
+ x="sx",
97
+ y="sy",
84
98
  color=color_by,
85
- symbol='annotation' if annotation is not None else None,
99
+ symbol="annotation" if annotation is not None else None,
86
100
  title=title,
87
101
  color_continuous_scale=custom_color_scale,
88
102
  range_color=[0, max(space_coord_concat[color_by])],
@@ -90,7 +104,7 @@ def draw_scatter(space_coord_concat, title=None, fig_style: Literal['dark', 'lig
90
104
 
91
105
  # Update marker size if specified
92
106
  if point_size is not None:
93
- fig.update_traces(marker=dict(size=point_size, symbol='circle'))
107
+ fig.update_traces(marker=dict(size=point_size, symbol="circle"))
94
108
 
95
109
  # Update layout for figure size
96
110
  fig.update_layout(
@@ -108,45 +122,45 @@ def draw_scatter(space_coord_concat, title=None, fig_style: Literal['dark', 'lig
108
122
  x=1.0,
109
123
  font=dict(
110
124
  size=10,
111
- )
125
+ ),
112
126
  )
113
127
  )
114
128
 
115
129
  # Update colorbar to be at the bottom and horizontal
116
130
  fig.update_layout(
117
131
  coloraxis_colorbar=dict(
118
- orientation='h', # Make the colorbar horizontal
132
+ orientation="h", # Make the colorbar horizontal
119
133
  x=0.5, # Center the colorbar horizontally
120
134
  y=-0.0, # Position below the plot
121
- xanchor='center', # Anchor the colorbar at the center
122
- yanchor='top', # Anchor the colorbar at the top to keep it just below the plot
135
+ xanchor="center", # Anchor the colorbar at the center
136
+ yanchor="top", # Anchor the colorbar at the top to keep it just below the plot
123
137
  len=0.75, # Length of the colorbar relative to the plot width
124
138
  title=dict(
125
- text='-log10(p)' if color_by == 'logp' else color_by, # Colorbar title
126
- side='top' # Place the title at the top of the colorbar
127
- )
139
+ text="-log10(p)" if color_by == "logp" else color_by, # Colorbar title
140
+ side="top", # Place the title at the top of the colorbar
141
+ ),
128
142
  )
129
143
  )
130
144
  # Remove gridlines, axis labels, and ticks
131
145
  fig.update_xaxes(
132
- showgrid=False, # Hide x-axis gridlines
133
- zeroline=False, # Hide x-axis zero line
146
+ showgrid=False, # Hide x-axis gridlines
147
+ zeroline=False, # Hide x-axis zero line
134
148
  showticklabels=False, # Hide x-axis tick labels
135
- title=None, # Remove x-axis title
136
- scaleanchor='y', # Link the x-axis scale to the y-axis scale
149
+ title=None, # Remove x-axis title
150
+ scaleanchor="y", # Link the x-axis scale to the y-axis scale
137
151
  )
138
152
 
139
153
  fig.update_yaxes(
140
- showgrid=False, # Hide y-axis gridlines
141
- zeroline=False, # Hide y-axis zero line
154
+ showgrid=False, # Hide y-axis gridlines
155
+ zeroline=False, # Hide y-axis zero line
142
156
  showticklabels=False, # Hide y-axis tick labels
143
- title=None # Remove y-axis title
157
+ title=None, # Remove y-axis title
144
158
  )
145
159
 
146
160
  # Adjust margins to ensure no clipping and equal axis ratio
147
161
  fig.update_layout(
148
162
  margin=dict(l=0, r=0, t=20, b=10), # Adjust margins to prevent clipping
149
- height=width # Ensure the figure height matches the width for equal axis ratio
163
+ height=width, # Ensure the figure height matches the width for equal axis ratio
150
164
  )
151
165
 
152
166
  # Adjust the title location and font size
@@ -154,46 +168,50 @@ def draw_scatter(space_coord_concat, title=None, fig_style: Literal['dark', 'lig
154
168
  title=dict(
155
169
  y=0.98,
156
170
  x=0.5, # Center the title horizontally
157
- xanchor='center', # Anchor the title at the center
158
- yanchor='top', # Anchor the title at the top
171
+ xanchor="center", # Anchor the title at the center
172
+ yanchor="top", # Anchor the title at the top
159
173
  font=dict(
160
174
  size=20 # Increase the title font size
161
- )
162
- ))
175
+ ),
176
+ )
177
+ )
163
178
 
164
179
  return fig
165
180
 
166
181
 
167
-
168
182
  def run_Visualize(config: VisualizeConfig):
169
- print(f'------Loading LDSC results of {config.ldsc_save_dir}...')
170
- ldsc = load_ldsc(ldsc_input_file=Path(config.ldsc_save_dir) / f'{config.sample_name}_{config.trait_name}.csv.gz')
183
+ print(f"------Loading LDSC results of {config.ldsc_save_dir}...")
184
+ ldsc = load_ldsc(
185
+ ldsc_input_file=Path(config.ldsc_save_dir)
186
+ / f"{config.sample_name}_{config.trait_name}.csv.gz"
187
+ )
171
188
 
172
- print(f'------Loading ST data of {config.sample_name}...')
173
- adata = sc.read_h5ad(f'{config.hdf5_with_latent_path}')
189
+ print(f"------Loading ST data of {config.sample_name}...")
190
+ adata = sc.read_h5ad(f"{config.hdf5_with_latent_path}")
174
191
 
175
192
  space_coord_concat = load_st_coord(adata, ldsc, annotation=config.annotation)
176
- fig = draw_scatter(space_coord_concat,
177
- title=config.fig_title,
178
- fig_style=config.fig_style,
179
- point_size=config.point_size,
180
- width=config.fig_width,
181
- height=config.fig_height,
182
- annotation=config.annotation
183
- )
184
-
193
+ fig = draw_scatter(
194
+ space_coord_concat,
195
+ title=config.fig_title,
196
+ fig_style=config.fig_style,
197
+ point_size=config.point_size,
198
+ width=config.fig_width,
199
+ height=config.fig_height,
200
+ annotation=config.annotation,
201
+ )
185
202
 
186
203
  # Visualization
187
204
  output_dir = Path(config.output_dir)
188
205
  output_dir.mkdir(parents=True, exist_ok=True, mode=0o755)
189
- output_file_html = output_dir / f'{config.sample_name}_{config.trait_name}.html'
190
- output_file_pdf = output_dir / f'{config.sample_name}_{config.trait_name}.pdf'
191
- output_file_csv = output_dir / f'{config.sample_name}_{config.trait_name}.csv'
206
+ output_file_html = output_dir / f"{config.sample_name}_{config.trait_name}.html"
207
+ output_file_pdf = output_dir / f"{config.sample_name}_{config.trait_name}.pdf"
208
+ output_file_csv = output_dir / f"{config.sample_name}_{config.trait_name}.csv"
192
209
 
193
210
  fig.write_html(str(output_file_html))
194
211
  fig.write_image(str(output_file_pdf))
195
212
  space_coord_concat.to_csv(str(output_file_csv))
196
213
 
197
214
  print(
198
- f'------The visualization result is saved in a html file: {output_file_html} which can interactively viewed in a web browser and a pdf file: {output_file_pdf}.')
199
- print(f'------The visualization data is saved in a csv file: {output_file_csv}.')
215
+ f"------The visualization result is saved in a html file: {output_file_html} which can interactively viewed in a web browser and a pdf file: {output_file_pdf}."
216
+ )
217
+ print(f"------The visualization data is saved in a csv file: {output_file_csv}.")