gsMap 1.71.2__py3-none-any.whl → 1.73.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,55 +1,55 @@
1
+ import os
2
+
1
3
  import numpy as np
2
4
  import pandas as pd
3
- import os
4
5
 
5
6
 
6
7
  # Fun for reading gwas data
7
8
  def _read_sumstats(fh, alleles=False, dropna=False):
8
- '''
9
+ """
9
10
  Parse gwas summary statistics.
10
- '''
11
- print('Reading summary statistics from {S} ...'.format(S=fh))
11
+ """
12
+ print(f"Reading summary statistics from {fh} ...")
12
13
  sumstats = ps_sumstats(fh, alleles=alleles, dropna=dropna)
13
- print('Read summary statistics for {N} SNPs.'.format(N=len(sumstats)))
14
+ print(f"Read summary statistics for {len(sumstats)} SNPs.")
14
15
 
15
16
  m = len(sumstats)
16
- sumstats = sumstats.drop_duplicates(subset='SNP')
17
+ sumstats = sumstats.drop_duplicates(subset="SNP")
17
18
  if m > len(sumstats):
18
- print('Dropped {M} SNPs with duplicated rs numbers.'.format(M=m - len(sumstats)))
19
+ print(f"Dropped {m - len(sumstats)} SNPs with duplicated rs numbers.")
19
20
 
20
21
  return sumstats
21
22
 
22
23
 
23
24
  def ps_sumstats(fh, alleles=False, dropna=True):
24
- '''
25
+ """
25
26
  Parses .sumstats files. See docs/file_formats_sumstats.txt.
26
- '''
27
-
28
- dtype_dict = {'SNP': str, 'Z': float, 'N': float, 'A1': str, 'A2': str}
27
+ """
28
+ dtype_dict = {"SNP": str, "Z": float, "N": float, "A1": str, "A2": str}
29
29
  compression = get_compression(fh)
30
- usecols = ['SNP', 'Z', 'N']
30
+ usecols = ["SNP", "Z", "N"]
31
31
  if alleles:
32
- usecols += ['A1', 'A2']
32
+ usecols += ["A1", "A2"]
33
33
 
34
34
  try:
35
35
  x = read_csv(fh, usecols=usecols, dtype=dtype_dict, compression=compression)
36
36
  except (AttributeError, ValueError) as e:
37
- raise ValueError('Improperly formatted sumstats file: ' + str(e.args))
37
+ raise ValueError("Improperly formatted sumstats file: " + str(e.args)) from e
38
38
 
39
39
  if dropna:
40
- x = x.dropna(how='any')
40
+ x = x.dropna(how="any")
41
41
 
42
42
  return x
43
43
 
44
44
 
45
45
  def get_compression(fh):
46
- '''
46
+ """
47
47
  Determin the format of compression used with read_csv?
48
- '''
49
- if fh.endswith('gz'):
50
- compression = 'gzip'
51
- elif fh.endswith('bz2'):
52
- compression = 'bz2'
48
+ """
49
+ if fh.endswith("gz"):
50
+ compression = "gzip"
51
+ elif fh.endswith("bz2"):
52
+ compression = "bz2"
53
53
  else:
54
54
  compression = None
55
55
  # -
@@ -57,120 +57,96 @@ def get_compression(fh):
57
57
 
58
58
 
59
59
  def read_csv(fh, **kwargs):
60
- '''
60
+ """
61
61
  Read the csv data
62
- '''
63
- return pd.read_csv(fh, sep='\s+', na_values='.', **kwargs)
62
+ """
63
+ return pd.read_csv(fh, sep=r"\s+", na_values=".", **kwargs)
64
64
 
65
65
 
66
- # Fun for reading loading LD scores
66
+ # Fun for reading loading LD scores
67
67
  def which_compression(fh):
68
- '''
68
+ """
69
69
  Given a file prefix, figure out what sort of compression to use.
70
- '''
71
- if os.access(fh + '.bz2', 4):
72
- suffix = '.bz2'
73
- compression = 'bz2'
74
- elif os.access(fh + '.gz', 4):
75
- suffix = '.gz'
76
- compression = 'gzip'
77
- elif os.access(fh + '.parquet', 4):
78
- suffix = '.parquet'
79
- compression = 'parquet'
80
- elif os.access(fh + '.feather', 4):
81
- suffix = '.feather'
82
- compression = 'feather'
70
+ """
71
+ if os.access(fh + ".bz2", 4):
72
+ suffix = ".bz2"
73
+ compression = "bz2"
74
+ elif os.access(fh + ".gz", 4):
75
+ suffix = ".gz"
76
+ compression = "gzip"
77
+ elif os.access(fh + ".parquet", 4):
78
+ suffix = ".parquet"
79
+ compression = "parquet"
80
+ elif os.access(fh + ".feather", 4):
81
+ suffix = ".feather"
82
+ compression = "feather"
83
83
  elif os.access(fh, 4):
84
- suffix = ''
84
+ suffix = ""
85
85
  compression = None
86
86
  else:
87
- raise IOError('Could not open {F}[./gz/bz2/parquet/feather]'.format(F=fh))
87
+ raise OSError(f"Could not open {fh}[./gz/bz2/parquet/feather]")
88
88
  # -
89
89
  return suffix, compression
90
90
 
91
91
 
92
- def _read_ref_ld(ld_file):
93
- suffix = '.l2.ldscore'
94
- file = ld_file
95
- first_fh = f'{file}1{suffix}'
96
- s, compression = which_compression(first_fh)
97
- #
98
- ldscore_array = []
99
- print(f'Reading ld score annotations from {file}[1-22]{suffix}.{compression}')
100
-
101
- for chr in range(1, 23):
102
- file_chr = f'{file}{chr}{suffix}{s}'
103
- #
104
- if compression == 'parquet':
105
- x = pd.read_parquet(file_chr)
106
- elif compression == 'feather':
107
- x = pd.read_feather(file_chr)
108
- else:
109
- x = pd.read_csv(file_chr, compression=compression, sep='\t')
110
-
111
- x = x.sort_values(by=['CHR', 'BP']) # SEs will be wrong unless sorted
112
-
113
- columns_to_drop = ['MAF', 'CM', 'Gene', 'TSS', 'CHR', 'BP']
114
- columns_to_drop = [col for col in columns_to_drop if col in x.columns]
115
- x = x.drop(columns_to_drop, axis=1)
116
-
117
- ldscore_array.append(x)
118
- #
119
- ref_ld = pd.concat(ldscore_array, axis=0)
120
- return ref_ld
121
-
122
-
123
92
  def _read_ref_ld_v2(ld_file):
124
- suffix = '.l2.ldscore'
93
+ suffix = ".l2.ldscore"
125
94
  file = ld_file
126
- first_fh = f'{file}1{suffix}'
95
+ first_fh = f"{file}1{suffix}"
127
96
  s, compression = which_compression(first_fh)
128
- print(f'Reading ld score annotations from {file}[1-22]{suffix}.{compression}')
97
+ print(f"Reading ld score annotations from {file}[1-22]{suffix}.{compression}")
129
98
  ref_ld = pd.concat(
130
- [pd.read_feather(f'{file}{chr}{suffix}{s}') for chr in range(1, 23)], axis=0
99
+ [pd.read_feather(f"{file}{chr}{suffix}{s}") for chr in range(1, 23)], axis=0
131
100
  )
132
101
  # set first column as index
133
- ref_ld.rename(columns={'index': 'SNP'}, inplace=True)
134
- ref_ld.set_index('SNP', inplace=True)
102
+ ref_ld.rename(columns={"index": "SNP"}, inplace=True)
103
+ ref_ld.set_index("SNP", inplace=True)
135
104
  return ref_ld
136
105
 
106
+
137
107
  def _read_M_v2(ld_file, n_annot, not_M_5_50):
138
- suffix = '.l2.M'
108
+ suffix = ".l2.M"
139
109
  if not not_M_5_50:
140
- suffix += '_5_50'
141
- M_annot= np.array(
110
+ suffix += "_5_50"
111
+ M_annot = np.array(
142
112
  [
143
- np.loadtxt(f'{ld_file}{chr}{suffix}', )
144
- for chr in range(1, 23)]
145
-
113
+ np.loadtxt(
114
+ f"{ld_file}{chr}{suffix}",
115
+ )
116
+ for chr in range(1, 23)
117
+ ]
146
118
  )
147
119
  assert M_annot.shape == (22, n_annot)
148
120
  return M_annot.sum(axis=0).reshape((1, n_annot))
149
- # Fun for reading M annotations
121
+
122
+
123
+ # Fun for reading M annotations
150
124
  def _read_M(ld_file, n_annot, not_M_5_50):
151
- '''
125
+ """
152
126
  Read M (--M, --M-file, etc).
153
- '''
127
+ """
154
128
  M_annot = M(ld_file, common=(not not_M_5_50))
155
129
 
156
130
  try:
157
131
  M_annot = np.array(M_annot).reshape((1, n_annot))
158
132
  except ValueError as e:
159
- raise ValueError('# terms in --M must match # of LD Scores in --ref-ld.\n' + str(e.args))
133
+ raise ValueError(
134
+ "# terms in --M must match # of LD Scores in --ref-ld.\n" + str(e.args)
135
+ ) from e
160
136
  return M_annot
161
137
 
162
138
 
163
139
  def M(fh, common=False):
164
- '''
140
+ """
165
141
  Parses .l{N}.M files, split across num chromosomes.
166
- '''
167
- suffix = '.l2.M'
142
+ """
143
+ suffix = ".l2.M"
168
144
  if common:
169
- suffix += '_5_50'
145
+ suffix += "_5_50"
170
146
  # -
171
147
  M_array = []
172
148
  for i in range(1, 23):
173
- M_current = pd.read_csv(f'{fh}{i}' + suffix, header=None)
149
+ M_current = pd.read_csv(f"{fh}{i}" + suffix, header=None)
174
150
  M_array.append(M_current)
175
151
 
176
152
  M_array = pd.concat(M_array, axis=1).sum(axis=1)
@@ -178,117 +154,48 @@ def M(fh, common=False):
178
154
  return np.array(M_array).reshape((1, len(M_array)))
179
155
 
180
156
 
181
- def _check_variance(M_annot, ref_ld):
182
- '''
183
- Remove zero-variance LD Scores.
184
- '''
185
- ii = ref_ld.iloc[:, 1:].var() == 0 # NB there is a SNP column here
186
- if ii.all():
187
- raise ValueError('All LD Scores have zero variance.')
188
- else:
189
- print('Removing partitioned LD Scores with zero variance.')
190
- ii_snp = np.array([True] + list(~ii))
191
- ii_m = np.array(~ii)
192
- ref_ld = ref_ld.iloc[:, ii_snp]
193
- M_annot = M_annot[:, ii_m]
194
- # -
195
- return M_annot, ref_ld, ii
196
157
  def _check_variance_v2(M_annot, ref_ld):
197
158
  ii = ref_ld.var() == 0
198
159
  if ii.all():
199
- raise ValueError('All LD Scores have zero variance.')
160
+ raise ValueError("All LD Scores have zero variance.")
200
161
  elif not ii.any():
201
- print('No partitioned LD Scores have zero variance.')
162
+ print("No partitioned LD Scores have zero variance.")
202
163
  else:
203
- ii_snp= ii_m = np.array(~ii)
204
- print(f'Removing {sum(ii)} partitioned LD Scores with zero variance.')
164
+ ii_snp = ii_m = np.array(~ii)
165
+ print(f"Removing {sum(ii)} partitioned LD Scores with zero variance.")
205
166
  ref_ld = ref_ld.iloc[:, ii_snp]
206
167
  M_annot = M_annot[:, ii_m]
207
168
  return M_annot, ref_ld
208
169
 
209
170
 
210
- # Fun for reading regression weights
211
- def which_compression(fh):
212
- '''
213
- Given a file prefix, figure out what sort of compression to use.
214
- '''
215
- if os.access(fh + '.bz2', 4):
216
- suffix = '.bz2'
217
- compression = 'bz2'
218
- elif os.access(fh + '.gz', 4):
219
- suffix = '.gz'
220
- compression = 'gzip'
221
- elif os.access(fh + '.parquet', 4):
222
- suffix = '.parquet'
223
- compression = 'parquet'
224
- elif os.access(fh + '.feather', 4):
225
- suffix = '.feather'
226
- compression = 'feather'
227
- elif os.access(fh, 4):
228
- suffix = ''
229
- compression = None
230
- else:
231
- raise IOError('Could not open {F}[./gz/bz2/parquet/feather]'.format(F=fh))
232
- # -
233
- return suffix, compression
234
-
235
-
236
171
  def _read_w_ld(w_file):
237
- suffix = '.l2.ldscore'
172
+ suffix = ".l2.ldscore"
238
173
  file = w_file
239
- first_fh = f'{file}1{suffix}'
174
+ first_fh = f"{file}1{suffix}"
240
175
  s, compression = which_compression(first_fh)
241
176
  #
242
177
  w_array = []
243
- print(f'Reading ld score annotations from {file}[1-22]{suffix}.{compression}')
178
+ print(f"Reading ld score annotations from {file}[1-22]{suffix}.{compression}")
244
179
 
245
180
  for chr in range(1, 23):
246
- file_chr = f'{file}{chr}{suffix}{s}'
181
+ file_chr = f"{file}{chr}{suffix}{s}"
247
182
  #
248
- if compression == 'parquet':
183
+ if compression == "parquet":
249
184
  x = pd.read_parquet(file_chr)
250
- elif compression == 'feather':
185
+ elif compression == "feather":
251
186
  x = pd.read_feather(file_chr)
252
187
  else:
253
- x = pd.read_csv(file_chr, compression=compression, sep='\t')
188
+ x = pd.read_csv(file_chr, compression=compression, sep="\t")
254
189
 
255
- x = x.sort_values(by=['CHR', 'BP'])
190
+ x = x.sort_values(by=["CHR", "BP"])
256
191
 
257
- columns_to_drop = ['MAF', 'CM', 'Gene', 'TSS', 'CHR', 'BP']
192
+ columns_to_drop = ["MAF", "CM", "Gene", "TSS", "CHR", "BP"]
258
193
  columns_to_drop = [col for col in columns_to_drop if col in x.columns]
259
194
  x = x.drop(columns_to_drop, axis=1)
260
195
 
261
196
  w_array.append(x)
262
197
  #
263
198
  w_ld = pd.concat(w_array, axis=0)
264
- w_ld.columns = ['SNP', 'LD_weights']
199
+ w_ld.columns = ["SNP", "LD_weights"]
265
200
 
266
201
  return w_ld
267
-
268
-
269
- # Fun for merging
270
- def _merge_and_log(ld, sumstats, noun):
271
- '''
272
- Wrap smart merge with log messages about # of SNPs.
273
- '''
274
- sumstats = smart_merge(ld, sumstats)
275
- msg = 'After merging with {F}, {N} SNPs remain.'
276
- if len(sumstats) == 0:
277
- raise ValueError(msg.format(N=len(sumstats), F=noun))
278
- else:
279
- print(msg.format(N=len(sumstats), F=noun))
280
- # -
281
- return sumstats
282
-
283
-
284
- def smart_merge(x, y):
285
- '''
286
- Check if SNP columns are equal. If so, save time by using concat instead of merge.
287
- '''
288
- if len(x) == len(y) and (x.index == y.index).all() and (x.SNP == y.SNP).all():
289
- x = x.reset_index(drop=True)
290
- y = y.reset_index(drop=True).drop('SNP', 1)
291
- out = pd.concat([x, y], axis=1)
292
- else:
293
- out = pd.merge(x, y, how='inner', on='SNP')
294
- return out
gsMap/visualize.py CHANGED
@@ -11,35 +11,41 @@ from gsMap.config import VisualizeConfig
11
11
 
12
12
 
13
13
  def load_ldsc(ldsc_input_file):
14
- ldsc = pd.read_csv(ldsc_input_file, compression='gzip')
15
- ldsc.spot = ldsc.spot.astype(str).replace('\.0', '', regex=True)
16
- ldsc.index = ldsc.spot
17
- ldsc['logp'] = -np.log10(ldsc.p)
14
+ ldsc = pd.read_csv(
15
+ ldsc_input_file,
16
+ compression="gzip",
17
+ dtype={"spot": str, "p": float},
18
+ index_col="spot",
19
+ usecols=["spot", "p"],
20
+ )
21
+ ldsc["logp"] = -np.log10(ldsc.p)
18
22
  return ldsc
19
23
 
20
24
 
21
25
  # %%
22
26
  def load_st_coord(adata, feature_series: pd.Series, annotation):
23
27
  spot_name = adata.obs_names.to_list()
24
- assert 'spatial' in adata.obsm.keys(), 'spatial coordinates are not found in adata.obsm'
28
+ assert "spatial" in adata.obsm.keys(), "spatial coordinates are not found in adata.obsm"
25
29
 
26
30
  # to DataFrame
27
- space_coord = adata.obsm['spatial']
31
+ space_coord = adata.obsm["spatial"]
28
32
  if isinstance(space_coord, np.ndarray):
29
- space_coord = pd.DataFrame(space_coord, columns=['sx', 'sy'], index=spot_name)
33
+ space_coord = pd.DataFrame(space_coord, columns=["sx", "sy"], index=spot_name)
30
34
  else:
31
- space_coord = pd.DataFrame(space_coord.values, columns=['sx', 'sy'], index=spot_name)
35
+ space_coord = pd.DataFrame(space_coord.values, columns=["sx", "sy"], index=spot_name)
32
36
 
33
37
  space_coord = space_coord[space_coord.index.isin(feature_series.index)]
34
38
  space_coord_concat = pd.concat([space_coord.loc[feature_series.index], feature_series], axis=1)
35
39
  space_coord_concat.head()
36
40
  if annotation is not None:
37
- annotation = pd.Series(adata.obs[annotation].values, index=adata.obs_names, name='annotation')
41
+ annotation = pd.Series(
42
+ adata.obs[annotation].values, index=adata.obs_names, name="annotation"
43
+ )
38
44
  space_coord_concat = pd.concat([space_coord_concat, annotation], axis=1)
39
45
  return space_coord_concat
40
46
 
41
47
 
42
- def estimate_point_size_for_plot(coordinates, DEFAULT_PIXEL_WIDTH = 1000):
48
+ def estimate_point_size_for_plot(coordinates, DEFAULT_PIXEL_WIDTH=1000):
43
49
  tree = KDTree(coordinates)
44
50
  distances, _ = tree.query(coordinates, k=2)
45
51
  avg_min_distance = np.mean(distances[:, 1])
@@ -55,34 +61,42 @@ def estimate_point_size_for_plot(coordinates, DEFAULT_PIXEL_WIDTH = 1000):
55
61
  return (pixel_width, pixel_height), point_size
56
62
 
57
63
 
58
- def draw_scatter(space_coord_concat, title=None, fig_style: Literal['dark', 'light'] = 'light',
59
- point_size: int = None, width=800, height=600, annotation=None, color_by='logp'):
64
+ def draw_scatter(
65
+ space_coord_concat,
66
+ title=None,
67
+ fig_style: Literal["dark", "light"] = "light",
68
+ point_size: int = None,
69
+ width=800,
70
+ height=600,
71
+ annotation=None,
72
+ color_by="logp",
73
+ ):
60
74
  # Set theme based on fig_style
61
- if fig_style == 'dark':
75
+ if fig_style == "dark":
62
76
  px.defaults.template = "plotly_dark"
63
77
  else:
64
78
  px.defaults.template = "plotly_white"
65
79
 
66
80
  custom_color_scale = [
67
- (1, '#d73027'), # Red
68
- (7 / 8, '#f46d43'), # Red-Orange
69
- (6 / 8, '#fdae61'), # Orange
70
- (5 / 8, '#fee090'), # Light Orange
71
- (4 / 8, '#e0f3f8'), # Light Blue
72
- (3 / 8, '#abd9e9'), # Sky Blue
73
- (2 / 8, '#74add1'), # Medium Blue
74
- (1 / 8, '#4575b4'), # Dark Blue
75
- (0, '#313695') # Deep Blue
81
+ (1, "#d73027"), # Red
82
+ (7 / 8, "#f46d43"), # Red-Orange
83
+ (6 / 8, "#fdae61"), # Orange
84
+ (5 / 8, "#fee090"), # Light Orange
85
+ (4 / 8, "#e0f3f8"), # Light Blue
86
+ (3 / 8, "#abd9e9"), # Sky Blue
87
+ (2 / 8, "#74add1"), # Medium Blue
88
+ (1 / 8, "#4575b4"), # Dark Blue
89
+ (0, "#313695"), # Deep Blue
76
90
  ]
77
91
  custom_color_scale.reverse()
78
92
 
79
93
  # Create the scatter plot
80
94
  fig = px.scatter(
81
95
  space_coord_concat,
82
- x='sx',
83
- y='sy',
96
+ x="sx",
97
+ y="sy",
84
98
  color=color_by,
85
- symbol='annotation' if annotation is not None else None,
99
+ symbol="annotation" if annotation is not None else None,
86
100
  title=title,
87
101
  color_continuous_scale=custom_color_scale,
88
102
  range_color=[0, max(space_coord_concat[color_by])],
@@ -90,7 +104,7 @@ def draw_scatter(space_coord_concat, title=None, fig_style: Literal['dark', 'lig
90
104
 
91
105
  # Update marker size if specified
92
106
  if point_size is not None:
93
- fig.update_traces(marker=dict(size=point_size, symbol='circle'))
107
+ fig.update_traces(marker=dict(size=point_size, symbol="circle"))
94
108
 
95
109
  # Update layout for figure size
96
110
  fig.update_layout(
@@ -108,45 +122,45 @@ def draw_scatter(space_coord_concat, title=None, fig_style: Literal['dark', 'lig
108
122
  x=1.0,
109
123
  font=dict(
110
124
  size=10,
111
- )
125
+ ),
112
126
  )
113
127
  )
114
128
 
115
129
  # Update colorbar to be at the bottom and horizontal
116
130
  fig.update_layout(
117
131
  coloraxis_colorbar=dict(
118
- orientation='h', # Make the colorbar horizontal
132
+ orientation="h", # Make the colorbar horizontal
119
133
  x=0.5, # Center the colorbar horizontally
120
134
  y=-0.0, # Position below the plot
121
- xanchor='center', # Anchor the colorbar at the center
122
- yanchor='top', # Anchor the colorbar at the top to keep it just below the plot
135
+ xanchor="center", # Anchor the colorbar at the center
136
+ yanchor="top", # Anchor the colorbar at the top to keep it just below the plot
123
137
  len=0.75, # Length of the colorbar relative to the plot width
124
138
  title=dict(
125
- text='-log10(p)' if color_by == 'logp' else color_by, # Colorbar title
126
- side='top' # Place the title at the top of the colorbar
127
- )
139
+ text="-log10(p)" if color_by == "logp" else color_by, # Colorbar title
140
+ side="top", # Place the title at the top of the colorbar
141
+ ),
128
142
  )
129
143
  )
130
144
  # Remove gridlines, axis labels, and ticks
131
145
  fig.update_xaxes(
132
- showgrid=False, # Hide x-axis gridlines
133
- zeroline=False, # Hide x-axis zero line
146
+ showgrid=False, # Hide x-axis gridlines
147
+ zeroline=False, # Hide x-axis zero line
134
148
  showticklabels=False, # Hide x-axis tick labels
135
- title=None, # Remove x-axis title
136
- scaleanchor='y', # Link the x-axis scale to the y-axis scale
149
+ title=None, # Remove x-axis title
150
+ scaleanchor="y", # Link the x-axis scale to the y-axis scale
137
151
  )
138
152
 
139
153
  fig.update_yaxes(
140
- showgrid=False, # Hide y-axis gridlines
141
- zeroline=False, # Hide y-axis zero line
154
+ showgrid=False, # Hide y-axis gridlines
155
+ zeroline=False, # Hide y-axis zero line
142
156
  showticklabels=False, # Hide y-axis tick labels
143
- title=None # Remove y-axis title
157
+ title=None, # Remove y-axis title
144
158
  )
145
159
 
146
160
  # Adjust margins to ensure no clipping and equal axis ratio
147
161
  fig.update_layout(
148
162
  margin=dict(l=0, r=0, t=20, b=10), # Adjust margins to prevent clipping
149
- height=width # Ensure the figure height matches the width for equal axis ratio
163
+ height=width, # Ensure the figure height matches the width for equal axis ratio
150
164
  )
151
165
 
152
166
  # Adjust the title location and font size
@@ -154,46 +168,50 @@ def draw_scatter(space_coord_concat, title=None, fig_style: Literal['dark', 'lig
154
168
  title=dict(
155
169
  y=0.98,
156
170
  x=0.5, # Center the title horizontally
157
- xanchor='center', # Anchor the title at the center
158
- yanchor='top', # Anchor the title at the top
171
+ xanchor="center", # Anchor the title at the center
172
+ yanchor="top", # Anchor the title at the top
159
173
  font=dict(
160
174
  size=20 # Increase the title font size
161
- )
162
- ))
175
+ ),
176
+ )
177
+ )
163
178
 
164
179
  return fig
165
180
 
166
181
 
167
-
168
182
  def run_Visualize(config: VisualizeConfig):
169
- print(f'------Loading LDSC results of {config.ldsc_save_dir}...')
170
- ldsc = load_ldsc(ldsc_input_file=Path(config.ldsc_save_dir) / f'{config.sample_name}_{config.trait_name}.csv.gz')
183
+ print(f"------Loading LDSC results of {config.ldsc_save_dir}...")
184
+ ldsc = load_ldsc(
185
+ ldsc_input_file=Path(config.ldsc_save_dir)
186
+ / f"{config.sample_name}_{config.trait_name}.csv.gz"
187
+ )
171
188
 
172
- print(f'------Loading ST data of {config.sample_name}...')
173
- adata = sc.read_h5ad(f'{config.hdf5_with_latent_path}')
189
+ print(f"------Loading ST data of {config.sample_name}...")
190
+ adata = sc.read_h5ad(f"{config.hdf5_with_latent_path}")
174
191
 
175
192
  space_coord_concat = load_st_coord(adata, ldsc, annotation=config.annotation)
176
- fig = draw_scatter(space_coord_concat,
177
- title=config.fig_title,
178
- fig_style=config.fig_style,
179
- point_size=config.point_size,
180
- width=config.fig_width,
181
- height=config.fig_height,
182
- annotation=config.annotation
183
- )
184
-
193
+ fig = draw_scatter(
194
+ space_coord_concat,
195
+ title=config.fig_title,
196
+ fig_style=config.fig_style,
197
+ point_size=config.point_size,
198
+ width=config.fig_width,
199
+ height=config.fig_height,
200
+ annotation=config.annotation,
201
+ )
185
202
 
186
203
  # Visualization
187
204
  output_dir = Path(config.output_dir)
188
205
  output_dir.mkdir(parents=True, exist_ok=True, mode=0o755)
189
- output_file_html = output_dir / f'{config.sample_name}_{config.trait_name}.html'
190
- output_file_pdf = output_dir / f'{config.sample_name}_{config.trait_name}.pdf'
191
- output_file_csv = output_dir / f'{config.sample_name}_{config.trait_name}.csv'
206
+ output_file_html = output_dir / f"{config.sample_name}_{config.trait_name}.html"
207
+ output_file_pdf = output_dir / f"{config.sample_name}_{config.trait_name}.pdf"
208
+ output_file_csv = output_dir / f"{config.sample_name}_{config.trait_name}.csv"
192
209
 
193
210
  fig.write_html(str(output_file_html))
194
211
  fig.write_image(str(output_file_pdf))
195
212
  space_coord_concat.to_csv(str(output_file_csv))
196
213
 
197
214
  print(
198
- f'------The visualization result is saved in a html file: {output_file_html} which can interactively viewed in a web browser and a pdf file: {output_file_pdf}.')
199
- print(f'------The visualization data is saved in a csv file: {output_file_csv}.')
215
+ f"------The visualization result is saved in a html file: {output_file_html} which can interactively viewed in a web browser and a pdf file: {output_file_pdf}."
216
+ )
217
+ print(f"------The visualization data is saved in a csv file: {output_file_csv}.")