gsMap3D 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsMap/__init__.py +13 -0
- gsMap/__main__.py +4 -0
- gsMap/cauchy_combination_test.py +342 -0
- gsMap/cli.py +355 -0
- gsMap/config/__init__.py +72 -0
- gsMap/config/base.py +296 -0
- gsMap/config/cauchy_config.py +79 -0
- gsMap/config/dataclasses.py +235 -0
- gsMap/config/decorators.py +302 -0
- gsMap/config/find_latent_config.py +276 -0
- gsMap/config/format_sumstats_config.py +54 -0
- gsMap/config/latent2gene_config.py +461 -0
- gsMap/config/ldscore_config.py +261 -0
- gsMap/config/quick_mode_config.py +242 -0
- gsMap/config/report_config.py +81 -0
- gsMap/config/spatial_ldsc_config.py +334 -0
- gsMap/config/utils.py +286 -0
- gsMap/find_latent/__init__.py +3 -0
- gsMap/find_latent/find_latent_representation.py +312 -0
- gsMap/find_latent/gnn/distribution.py +498 -0
- gsMap/find_latent/gnn/encoder_decoder.py +186 -0
- gsMap/find_latent/gnn/gcn.py +85 -0
- gsMap/find_latent/gnn/gene_former.py +164 -0
- gsMap/find_latent/gnn/loss.py +18 -0
- gsMap/find_latent/gnn/st_model.py +125 -0
- gsMap/find_latent/gnn/train_step.py +177 -0
- gsMap/find_latent/st_process.py +781 -0
- gsMap/format_sumstats.py +446 -0
- gsMap/generate_ldscore.py +1018 -0
- gsMap/latent2gene/__init__.py +18 -0
- gsMap/latent2gene/connectivity.py +781 -0
- gsMap/latent2gene/entry_point.py +141 -0
- gsMap/latent2gene/marker_scores.py +1265 -0
- gsMap/latent2gene/memmap_io.py +766 -0
- gsMap/latent2gene/rank_calculator.py +590 -0
- gsMap/latent2gene/row_ordering.py +182 -0
- gsMap/latent2gene/row_ordering_jax.py +159 -0
- gsMap/ldscore/__init__.py +1 -0
- gsMap/ldscore/batch_construction.py +163 -0
- gsMap/ldscore/compute.py +126 -0
- gsMap/ldscore/constants.py +70 -0
- gsMap/ldscore/io.py +262 -0
- gsMap/ldscore/mapping.py +262 -0
- gsMap/ldscore/pipeline.py +615 -0
- gsMap/pipeline/quick_mode.py +134 -0
- gsMap/report/__init__.py +2 -0
- gsMap/report/diagnosis.py +375 -0
- gsMap/report/report.py +100 -0
- gsMap/report/report_data.py +1832 -0
- gsMap/report/static/js_lib/alpine.min.js +5 -0
- gsMap/report/static/js_lib/tailwindcss.js +83 -0
- gsMap/report/static/template.html +2242 -0
- gsMap/report/three_d_combine.py +312 -0
- gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
- gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
- gsMap/report/three_d_plot/three_d_plots.py +425 -0
- gsMap/report/visualize.py +1409 -0
- gsMap/setup.py +5 -0
- gsMap/spatial_ldsc/__init__.py +0 -0
- gsMap/spatial_ldsc/io.py +656 -0
- gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
- gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
- gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
- gsMap/utils/__init__.py +0 -0
- gsMap/utils/generate_r2_matrix.py +610 -0
- gsMap/utils/jackknife.py +518 -0
- gsMap/utils/manhattan_plot.py +643 -0
- gsMap/utils/regression_read.py +177 -0
- gsMap/utils/torch_utils.py +23 -0
- gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
- gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
- gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
- gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
- gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,1409 @@
|
|
|
1
|
+
import gc
|
|
2
|
+
import logging
|
|
3
|
+
import warnings
|
|
4
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Literal
|
|
7
|
+
|
|
8
|
+
import distinctipy
|
|
9
|
+
import matplotlib
|
|
10
|
+
import matplotlib.axes
|
|
11
|
+
import matplotlib.colors as mcolors
|
|
12
|
+
import matplotlib.pyplot as plt
|
|
13
|
+
import numpy as np
|
|
14
|
+
import pandas as pd
|
|
15
|
+
import plotly.express as px
|
|
16
|
+
import plotly.graph_objects as go
|
|
17
|
+
import scipy.stats as stats
|
|
18
|
+
from rich import print
|
|
19
|
+
from scipy.cluster.hierarchy import leaves_list, linkage
|
|
20
|
+
from scipy.spatial import KDTree
|
|
21
|
+
from tqdm import tqdm
|
|
22
|
+
|
|
23
|
+
warnings.filterwarnings("ignore")
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def remove_outliers_MAD(data, threshold=3.5):
|
|
27
|
+
"""
|
|
28
|
+
Remove outliers based on Median Absolute Deviation (MAD).
|
|
29
|
+
"""
|
|
30
|
+
if isinstance(data, pd.Series | pd.DataFrame):
|
|
31
|
+
data_values = data.values.flatten()
|
|
32
|
+
else:
|
|
33
|
+
data_values = np.asarray(data).flatten()
|
|
34
|
+
|
|
35
|
+
if len(data_values) == 0:
|
|
36
|
+
return data, np.ones(len(data), dtype=bool)
|
|
37
|
+
|
|
38
|
+
median = np.nanmedian(data_values)
|
|
39
|
+
mad = np.nanmedian(np.abs(data_values - median))
|
|
40
|
+
if mad == 0:
|
|
41
|
+
return data, np.ones(len(data), dtype=bool)
|
|
42
|
+
|
|
43
|
+
modified_z_score = 0.6745 * (data_values - median) / mad
|
|
44
|
+
mask = np.abs(modified_z_score) <= threshold
|
|
45
|
+
|
|
46
|
+
if isinstance(data, pd.Series):
|
|
47
|
+
return data[mask], mask
|
|
48
|
+
elif isinstance(data, np.ndarray):
|
|
49
|
+
if len(data.shape) == 1:
|
|
50
|
+
return data[mask], mask
|
|
51
|
+
else:
|
|
52
|
+
return data.flatten()[mask], mask
|
|
53
|
+
return data[mask], mask
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def load_ldsc(ldsc_input_file):
    """Load a gzipped spatial-LDSC result table indexed by spot.

    Only the ``spot`` and ``p`` columns are read; a ``logp`` column holding
    ``-log10(p)`` is added for plotting.
    """
    result = pd.read_csv(
        ldsc_input_file,
        compression="gzip",
        dtype={"spot": str, "p": float},
        index_col="spot",
        usecols=["spot", "p"],
    )
    result["logp"] = -np.log10(result["p"])
    return result
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
# %%
|
|
69
|
+
def load_st_coord(adata, feature_series: pd.Series, annotation):
    """Build a plotting table with spatial coordinates, one feature column,
    and (optionally) an annotation column, indexed by spot name.

    Only spots present in ``feature_series`` are kept, in the series' order.
    """
    assert "spatial" in adata.obsm.keys(), "spatial coordinates are not found in adata.obsm"

    spots = adata.obs_names.to_list()
    raw = adata.obsm["spatial"]
    # `spatial` may already be an ndarray, or a DataFrame-like with `.values`.
    values = raw if isinstance(raw, np.ndarray) else raw.values
    coords = pd.DataFrame(values, columns=["sx", "sy"], index=spots)

    # Restrict to, and align with, the spots carried by the feature series.
    coords = coords[coords.index.isin(feature_series.index)]
    merged = pd.concat([coords.loc[feature_series.index], feature_series], axis=1)

    if annotation is not None:
        labels = pd.Series(
            adata.obs[annotation].values, index=adata.obs_names, name="annotation"
        )
        merged = pd.concat([merged, labels], axis=1)
    return merged
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def estimate_plotly_point_size(coordinates, DEFAULT_PIXEL_WIDTH=1000):
    """Estimate figure dimensions and a marker size (all in pixels) so that
    neighbouring points roughly touch.

    The marker size equals the median nearest-neighbour distance after
    rescaling the larger data extent to ``DEFAULT_PIXEL_WIDTH`` pixels.

    Returns:
        ((pixel_width, pixel_height), point_size)
    """
    # Accept DataFrames (or anything exposing `.values`) as well as arrays.
    if hasattr(coordinates, 'values'):
        coordinates = coordinates.values
    coordinates = np.asarray(coordinates)

    # k=2 nearest neighbours: the first hit is the point itself, so column 1
    # is the distance to the closest *other* point.
    nn_distances, _ = KDTree(coordinates).query(coordinates, k=2)
    median_nn_distance = np.median(nn_distances[:, 1])

    # Data extents in x and y.
    extent_x = np.max(coordinates[:, 0]) - np.min(coordinates[:, 0])
    extent_y = np.max(coordinates[:, 1]) - np.min(coordinates[:, 1])

    # Scale so the larger extent maps onto DEFAULT_PIXEL_WIDTH pixels.
    scale = DEFAULT_PIXEL_WIDTH / max(extent_x, extent_y)

    return (extent_x * scale, extent_y * scale), median_nn_distance * scale
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def estimate_matplotlib_scatter_marker_size(ax: matplotlib.axes.Axes, coordinates: np.ndarray,
                                            x_limits: tuple | None = None,
                                            y_limits: tuple | None = None) -> float:
    """
    Estimates the appropriate marker size to make adjacent markers touch.

    This function calculates the size 's' for a square marker (in points^2)
    such that its diameter in the plot corresponds to the average distance
    to the nearest neighbor for each point in the dataset. It accounts for
    the plot's aspect ratio and final rendered dimensions.

    Args:
        ax (matplotlib.axes.Axes): The subplot object. The function will
            temporarily set its limits and aspect ratio to ensure the
            transformation from data units to display units is accurate.
        coordinates (np.ndarray): A NumPy array of shape (n, 2)
            containing the (x, y) coordinates of the points.
        x_limits (Optional[tuple]): Optional (min, max) tuple to override automatic x-axis limits.
        y_limits (Optional[tuple]): Optional (min, max) tuple to override automatic y-axis limits.

    Returns:
        float: The estimated marker size 's' (in points^2) for use with
               ax.scatter().
    """
    # 1. Set up the axes' properties to ensure accurate transformations.
    # The aspect ratio and data limits must be set to correctly
    # calculate the relationship between data units and display units (inches/points).
    ax.set_aspect('equal')

    # Use provided limits if available, otherwise calculate from data
    if x_limits is not None:
        x_data_min, x_data_max = x_limits
    else:
        x_data_min, x_data_max = np.min(coordinates[:, 0]), np.max(coordinates[:, 0])

    # BUG FIX: the y_limits override was previously ignored — both branches
    # recomputed the limits from the data, so shared-coordinate plots got
    # inconsistent marker sizing.
    if y_limits is not None:
        y_data_min, y_data_max = y_limits
    else:
        y_data_min, y_data_max = np.min(coordinates[:, 1]), np.max(coordinates[:, 1])

    ax.set_xlim(x_data_min, x_data_max)
    ax.set_ylim(y_data_min, y_data_max)

    # Force a draw of the canvas to finalize the transformations.
    ax.figure.canvas.draw()

    # 2. Calculate the required marker radius in data units.
    # We find the average distance to the nearest neighbor for all points.
    # The desired radius is half of this distance.
    tree = KDTree(coordinates)
    distances, _ = tree.query(coordinates, k=2)
    radius_data = np.mean(distances[:, 1]) / 2

    # 3. Convert the data radius to display units (points).
    # This requires transforming the axes' bounding box from data coordinates
    # to display coordinates (pixels), then to physical units (inches).

    # Get the bounding box in display (pixel) coordinates
    x_display_min, _ = ax.transData.transform((x_data_min, y_data_min))
    x_display_max, _ = ax.transData.transform((x_data_max, y_data_max))

    # Convert the display coordinates to inches
    x_inch_min, _ = ax.figure.dpi_scale_trans.inverted().transform((x_display_min, 0))
    x_inch_max, _ = ax.figure.dpi_scale_trans.inverted().transform((x_display_max, 0))

    width_inch = x_inch_max - x_inch_min
    width_data = x_data_max - x_data_min

    # Calculate the radius in inches. This scales the data radius by the
    # ratio of the plot's physical width to its data width.
    # This works because the aspect ratio is 'equal'.
    radius_inch = (radius_data / width_data) * width_inch

    # Convert inches to points (1 inch = 72 points).
    radius_points = radius_inch * 72

    # 4. Calculate the marker size 's'.
    # For ax.scatter, 's' is the marker area in points^2.
    # For a square marker, the area is (side)^2, where side = 2 * radius.
    square_marker_size = (2 * radius_points) ** 2

    # 1.2 is an empirical fudge factor so round markers visually touch.
    return square_marker_size * 1.2
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def draw_scatter(
    space_coord_concat,
    title=None,
    fig_style: Literal["dark", "light"] = "light",
    point_size: int | None = None,
    width=800,
    height=600,
    annotation=None,
    color_by="logp",
    color_continuous_scale=None,
    plot_origin="upper",
):
    """Render an interactive plotly scatter of per-spot values.

    Args:
        space_coord_concat: DataFrame with ``sx``/``sy`` coordinate columns
            plus the ``color_by`` column (and an ``annotation`` column when
            ``annotation`` is given) — the shape produced by load_st_coord.
        title: Figure title.
        fig_style: "dark" or "light" plotly template.
        point_size: Marker size in pixels; plotly default when None.
        width, height: Figure dimensions in pixels.
        annotation: When not None, points are additionally distinguished by
            marker symbol using the ``annotation`` column.
        color_by: Column used for the continuous colour scale.
        color_continuous_scale: Optional plotly colour scale; defaults to a
            blue→red diverging palette.
        plot_origin: "upper" flips the y-axis (image-style origin).

    Returns:
        plotly.graph_objects.Figure
    """
    # Set theme based on fig_style
    if fig_style == "dark":
        px.defaults.template = "plotly_dark"
    else:
        px.defaults.template = "plotly_white"

    if color_continuous_scale is None:
        # RdYlBu-style stops, listed red-first then reversed so 0 maps to
        # deep blue and the maximum to red.
        custom_color_scale = [
            (1, "#d73027"),  # Red
            (7 / 8, "#f46d43"),  # Red-Orange
            (6 / 8, "#fdae61"),  # Orange
            (5 / 8, "#fee090"),  # Light Orange
            (4 / 8, "#e0f3f8"),  # Light Blue
            (3 / 8, "#abd9e9"),  # Sky Blue
            (2 / 8, "#74add1"),  # Medium Blue
            (1 / 8, "#4575b4"),  # Dark Blue
            (0, "#313695"),  # Deep Blue
        ]
        custom_color_scale.reverse()
        color_continuous_scale = custom_color_scale

    # Create the scatter plot
    # NOTE(review): range_color assumes the color_by column is numeric and
    # non-negative (true for "logp"); a categorical column would fail here.
    fig = px.scatter(
        space_coord_concat,
        x="sx",
        y="sy",
        color=color_by,
        symbol="annotation" if annotation is not None else None,
        title=title,
        color_continuous_scale=color_continuous_scale,
        range_color=[0, max(space_coord_concat[color_by])],
    )

    # Update marker size if specified
    if point_size is not None:
        fig.update_traces(marker=dict(size=point_size, symbol="circle"))

    # Update layout for figure size
    fig.update_layout(
        autosize=False,
        width=width,
        height=height,
    )

    # Adjusting the legend
    fig.update_layout(
        legend=dict(
            yanchor="top",
            y=0.95,
            xanchor="left",
            x=1.0,
            font=dict(
                size=10,
            ),
        )
    )

    # Update colorbar to be at the bottom and horizontal
    fig.update_layout(
        coloraxis_colorbar=dict(
            orientation="h",  # Make the colorbar horizontal
            x=0.5,  # Center the colorbar horizontally
            y=-0.0,  # Position below the plot
            xanchor="center",  # Anchor the colorbar at the center
            yanchor="top",  # Anchor the colorbar at the top to keep it just below the plot
            len=0.75,  # Length of the colorbar relative to the plot width
            title=dict(
                text="-log10(p)" if color_by == "logp" else color_by,  # Colorbar title
                side="top",  # Place the title at the top of the colorbar
            ),
        )
    )
    # Remove gridlines, axis labels, and ticks
    fig.update_xaxes(
        showgrid=False,  # Hide x-axis gridlines
        zeroline=False,  # Hide x-axis zero line
        showticklabels=False,  # Hide x-axis tick labels
        title=None,  # Remove x-axis title
        scaleanchor="y",  # Link the x-axis scale to the y-axis scale
    )

    fig.update_yaxes(
        showgrid=False,  # Hide y-axis gridlines
        zeroline=False,  # Hide y-axis zero line
        showticklabels=False,  # Hide y-axis tick labels
        title=None,  # Remove y-axis title
        autorange="reversed" if plot_origin == "upper" else True,
    )

    # Adjust margins to ensure no clipping and equal axis ratio
    fig.update_layout(
        margin=dict(l=0, r=0, t=20, b=10),  # Adjust margins to prevent clipping
    )

    # Adjust the title location and font size
    fig.update_layout(
        title=dict(
            y=0.98,
            x=0.5,  # Center the title horizontally
            xanchor="center",  # Anchor the title at the center
            yanchor="top",  # Anchor the title at the top
            font=dict(
                size=20  # Increase the title font size
            ),
        )
    )

    return fig
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def _create_color_map(category_list: list, hex=False, rng=42) -> dict[str, tuple]:
|
|
320
|
+
unique_categories = sorted(set(category_list), key=str)
|
|
321
|
+
|
|
322
|
+
# Check for 'NaN' or nan and handle separately
|
|
323
|
+
nan_values = [v for v in unique_categories if str(v).lower() in ['nan', 'none', 'null']]
|
|
324
|
+
other_categories = [v for v in unique_categories if v not in nan_values]
|
|
325
|
+
|
|
326
|
+
n_colors = len(other_categories)
|
|
327
|
+
|
|
328
|
+
# Generate N visually distinct colors for non-NaN categories
|
|
329
|
+
if n_colors > 0:
|
|
330
|
+
colors = distinctipy.get_colors(n_colors, rng=rng)
|
|
331
|
+
color_map = dict(zip(other_categories, colors, strict=False))
|
|
332
|
+
else:
|
|
333
|
+
color_map = {}
|
|
334
|
+
|
|
335
|
+
# Assign grey color to NaN values
|
|
336
|
+
grey_rgb = (0.827, 0.827, 0.827) # lightgrey
|
|
337
|
+
for v in nan_values:
|
|
338
|
+
color_map[v] = grey_rgb
|
|
339
|
+
|
|
340
|
+
if hex:
|
|
341
|
+
# Convert RGB tuples to hex format
|
|
342
|
+
color_map = {category: distinctipy.get_hex(color_map[category]) for category in color_map}
|
|
343
|
+
print("Generated color map in hex format")
|
|
344
|
+
return color_map
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
class VisualizeRunner:
|
|
348
|
+
def __init__(self, config):
    """Store the visualization configuration; all plotting methods read
    paths, trait lists, and layout options from ``self.config``."""
    self.config = config

# Diverging red→blue palette (RdYlBu-style). Used via self.custom_colors_list
# to build the trait colormap, which is reversed before use so low values
# plot blue and high values red.
custom_colors_list = [
    '#d73027', '#f46d43', '#fdae61', '#fee090', '#e0f3f8',
    '#abd9e9', '#74add1', '#4575b4', '#313695'
]
|
|
355
|
+
|
|
356
|
+
def _generate_visualizations(self, obs_ldsc_merged: pd.DataFrame):
    """Generate all visualizations.

    For every sample: one multi-trait static figure (JPG + optional PDF) and
    one interactive annotation scatter (PNG + HTML) per configured
    annotation. Afterwards, one multi-sample grid figure per annotation.

    Args:
        obs_ldsc_merged: Long-format table with one row per spot, carrying
            ``sample_name``, ``sx``/``sy`` coordinates, one column per trait,
            and one column per annotation in ``config.cauchy_annotations``.
    """

    # Create visualization directories
    # (single_sample_folder subdirectories are created below with parents=True)
    single_sample_folder = self.config.visualization_result_dir / 'single_sample_multi_trait_plot'
    annotation_folder = self.config.visualization_result_dir / 'annotation_distribution'
    annotation_folder.mkdir(exist_ok=True, parents=True)

    sample_names_list = sorted(obs_ldsc_merged['sample_name'].unique())

    for sample_name in tqdm(sample_names_list, desc='Generating visualizations'):
        # Multi-trait plot
        traits_png = single_sample_folder / 'static_png' / f'{sample_name}_gwas_traits_pvalues.jpg'
        traits_pdf = single_sample_folder / 'static_pdf' / f'{sample_name}_gwas_traits_pvalues.pdf'  # Added PDF output path

        # Create parent directories for the output files
        traits_png.parent.mkdir(exist_ok=True, parents=True)
        traits_pdf.parent.mkdir(exist_ok=True, parents=True)

        # Call the modified matplotlib-based plotting function.
        # This function saves files directly and does not return a figure object.
        self._create_single_sample_multi_trait_plots(
            obs_ldsc_merged=obs_ldsc_merged,
            trait_names=self.config.trait_name_list,
            sample_name=sample_name,
            output_png_path=traits_png,
            output_pdf_path=traits_pdf,
            max_cols=self.config.single_sample_multi_trait_max_cols,
            subsample_n_points=self.config.subsample_n_points,
            # Use new parameters from the updated VisualizationConfig
            subplot_width_inches=self.config.single_sample_multi_trait_subplot_width_inches,
            dpi=self.config.single_sample_multi_trait_dpi,
            enable_pdf_output=self.config.enable_pdf_output
        )

        # Annotation distribution plots
        # NOTE(review): query() interpolates sample_name into the expression —
        # assumes sample names contain no double quotes.
        sample_data = obs_ldsc_merged.query(f'sample_name == "{sample_name}"')
        (pixel_width, pixel_height), point_size = estimate_plotly_point_size(sample_data[['sx', 'sy']].values)

        for annotation in self.config.cauchy_annotations:
            annotation_dir = annotation_folder / annotation
            annotation_dir.mkdir(exist_ok=True)

            # Colours derived from ALL samples so the mapping stays consistent
            # across per-sample figures.
            annotation_color_map = _create_color_map(obs_ldsc_merged[annotation].unique(), hex=True)
            # NOTE(review): self._draw_scatter (taking hover_text_list/color_map)
            # is not the module-level draw_scatter — it is defined elsewhere in
            # this class; confirm its contract there.
            fig = self._draw_scatter(sample_data, title=f'{annotation}_{sample_name}',
                                     point_size=point_size, width=pixel_width, height=pixel_height,
                                     hover_text_list=self.config.hover_text_list,
                                     color_by=annotation, color_map=annotation_color_map)

            annotation_png = annotation_dir / f'{sample_name}_{annotation}.png'
            annotation_html = annotation_dir / f'{sample_name}_{annotation}.html'
            fig.write_image(annotation_png)
            fig.write_html(annotation_html)

    # Generate multi-sample annotation plots
    print("Generating multi-sample annotation plots...")
    sample_count = len(sample_names_list)
    n_rows, n_cols = self._calculate_optimal_grid_layout(
        item_count=sample_count,
        max_cols=self.config.single_sample_multi_trait_max_cols
    )

    for annotation in tqdm(self.config.cauchy_annotations,
                           desc='Generating multi-sample annotation plots'):
        annotation_dir = annotation_folder / annotation
        annotation_dir.mkdir(exist_ok=True)

        # NOTE(review): _create_multi_sample_annotation_plot is defined
        # elsewhere in this class.
        self._create_multi_sample_annotation_plot(
            obs_ldsc_merged=obs_ldsc_merged,
            annotation=annotation,
            sample_names_list=sample_names_list,
            output_dir=annotation_dir,
            n_rows=n_rows,
            n_cols=n_cols
        )
|
|
431
|
+
|
|
432
|
+
def _create_single_trait_multi_sample_plots(self, obs_ldsc_merged: pd.DataFrame):
    """For every GWAS trait, draw one figure containing a grid of per-sample
    spatial scatter plots (JPG always, PDF when enabled in the config)."""

    trait_names = self.config.trait_name_list

    # All per-trait figures share one folder under the visualization root.
    output_folder = self.config.visualization_result_dir / 'single_trait_multi_sample_plot'
    output_folder.mkdir(exist_ok=True, parents=True)

    # Defensive copy so the caller's DataFrame is never mutated.
    obs_ldsc_merged = obs_ldsc_merged.copy()

    # One subplot per sample: derive the shared grid shape once up front.
    sample_count = obs_ldsc_merged['sample_name'].nunique()
    n_rows, n_cols = self._calculate_optimal_grid_layout(
        item_count=sample_count,
        max_cols=self.config.single_trait_multi_sample_max_cols
    )

    print(f"Generating plots for {len(trait_names)} traits with {sample_count} samples in {n_rows}x{n_cols} grid")

    for trait in tqdm(trait_names, desc="Generating single trait multi-sample plots"):
        # Traits missing from the merged table are reported and skipped.
        if trait not in obs_ldsc_merged.columns:
            print(f"Warning: Trait {trait} not found in data. Skipping.")
            continue

        self._create_single_trait_multi_sample_matplotlib_plot(
            obs_ldsc_merged=obs_ldsc_merged,
            trait_abbreviation=trait,
            output_png_path=output_folder / f'{trait}_multi_sample_plot.jpg',
            output_pdf_path=output_folder / f'{trait}_multi_sample_plot.pdf',
            n_rows=n_rows,
            n_cols=n_cols,
            subplot_width_inches=self.config.single_trait_multi_sample_subplot_width_inches,
            scaling_factor=self.config.single_trait_multi_sample_scaling_factor,
            dpi=self.config.single_trait_multi_sample_dpi,
            enable_pdf_output=self.config.enable_pdf_output,
            share_coords=self.config.share_coords
        )
|
|
472
|
+
|
|
473
|
+
def _calculate_optimal_grid_layout(self, item_count: int, max_cols: int = 8) -> tuple[int, int]:
|
|
474
|
+
"""
|
|
475
|
+
Calculate optimal grid dimensions (rows, cols) for displaying items in a grid.
|
|
476
|
+
|
|
477
|
+
Args:
|
|
478
|
+
item_count: Number of items to display
|
|
479
|
+
max_cols: Maximum number of columns allowed
|
|
480
|
+
|
|
481
|
+
Returns:
|
|
482
|
+
tuple: (n_rows, n_cols) for optimal grid layout
|
|
483
|
+
"""
|
|
484
|
+
import math
|
|
485
|
+
|
|
486
|
+
if item_count <= 0:
|
|
487
|
+
return 1, 1
|
|
488
|
+
|
|
489
|
+
# For small counts, use simple layouts favoring horizontal arrangement
|
|
490
|
+
if item_count <= 3:
|
|
491
|
+
return 1, item_count
|
|
492
|
+
elif item_count <= 6:
|
|
493
|
+
return 2, math.ceil(item_count / 2)
|
|
494
|
+
elif item_count <= 12:
|
|
495
|
+
return 3, math.ceil(item_count / 3)
|
|
496
|
+
else:
|
|
497
|
+
# For larger counts, try to create a roughly square grid
|
|
498
|
+
# but respect the max_cols constraint
|
|
499
|
+
optimal_cols = min(math.ceil(math.sqrt(item_count)), max_cols)
|
|
500
|
+
optimal_rows = math.ceil(item_count / optimal_cols)
|
|
501
|
+
|
|
502
|
+
# If we hit the max_cols limit, recalculate rows
|
|
503
|
+
if optimal_cols >= max_cols:
|
|
504
|
+
n_cols = max_cols
|
|
505
|
+
n_rows = math.ceil(item_count / max_cols)
|
|
506
|
+
else:
|
|
507
|
+
n_rows = optimal_rows
|
|
508
|
+
n_cols = optimal_cols
|
|
509
|
+
|
|
510
|
+
print(f"Calculated grid layout: {n_rows} rows × {n_cols} cols for {item_count} items")
|
|
511
|
+
return n_rows, n_cols
|
|
512
|
+
|
|
513
|
+
def _create_single_trait_multi_sample_matplotlib_plot(self, obs_ldsc_merged: pd.DataFrame, trait_abbreviation: str,
                                                      sample_name_list: list[str] | None = None,
                                                      output_png_path: Path | None = None,
                                                      output_pdf_path: Path | None = None,
                                                      n_rows: int = 6, n_cols: int = 8,
                                                      subplot_width_inches: float = 4.0,
                                                      scaling_factor: float = 1.0, dpi: int = 300,
                                                      enable_pdf_output: bool = True,
                                                      show=False,
                                                      share_coords: bool = False
                                                      ):
    """
    Create and save a visualization for a specific trait showing all samples.

    Draws an n_rows x n_cols grid of per-sample scatter plots coloured by the
    trait's -log10(p) values (MAD-filtered, colour scale capped at the 99.9th
    percentile), with a shared vertical colorbar. Saves to output_png_path
    and (when enable_pdf_output) output_pdf_path, and returns the figure.

    Args:
        obs_ldsc_merged: Long-format table with 'sample_name', 'sx'/'sy'
            coordinates and one column per trait.
        trait_abbreviation: Trait column to plot; skipped with a warning if absent.
        sample_name_list: Explicit sample order; sorted unique names when None.
            Only the first n_rows * n_cols samples are drawn.
        output_png_path / output_pdf_path: Output files (parents created).
        n_rows, n_cols: Grid dimensions.
        subplot_width_inches: Width (and height) of each square subplot.
        scaling_factor: Multiplier applied to the estimated marker size.
        dpi: Figure / PNG resolution.
        enable_pdf_output: Whether the PDF is written (scatter is rasterized
            in that case to keep the PDF small).
        show: When True, display the figure interactively.
        share_coords: When True, all subplots share the global x/y extent.
    """

    matplotlib.rcParams['figure.dpi'] = dpi
    print(f"Creating visualization for {trait_abbreviation}")

    # Check if trait exists in the dataframe
    if trait_abbreviation not in obs_ldsc_merged.columns:
        print(f"Warning: Trait {trait_abbreviation} not found in the data. Skipping.")
        return

    # Set font to Arial with fallbacks to avoid warnings
    plt.rcParams['font.family'] = 'sans-serif'
    plt.rcParams['font.sans-serif'] = ['Arial', 'DejaVu Sans', 'Liberation Sans', 'Bitstream Vera Sans', 'sans-serif']

    # Reversed so low -log10(p) plots blue and high plots red.
    custom_cmap = mcolors.LinearSegmentedColormap.from_list('custom_cmap', self.custom_colors_list)
    custom_cmap = custom_cmap.reversed()

    # Calculate figure size based on subplot dimensions
    fig_width = n_cols * subplot_width_inches
    fig_height = n_rows * subplot_width_inches

    # Create figure with title
    fig = plt.figure(figsize=(fig_width, fig_height))

    # Add main title
    fig.suptitle(trait_abbreviation, fontsize=24, fontweight='bold', y=0.98)

    # Create grid of subplots
    grid_specs = fig.add_gridspec(nrows=n_rows, ncols=n_cols, wspace=0.1, hspace=0.1)

    # Drop MAD outliers before fixing the colour range across all subplots.
    _, pass_filter_mask = remove_outliers_MAD(obs_ldsc_merged[trait_abbreviation])
    obs_ldsc_merged_filtered = obs_ldsc_merged[pass_filter_mask]

    pd_min = 0
    pd_max = obs_ldsc_merged_filtered[trait_abbreviation].quantile(0.999)

    print(f"Color scale min: {pd_min}, max: {pd_max}")
    # Get list of sample names - use provided list or fallback to sorted unique
    if sample_name_list is None:
        sample_name_list = sorted(obs_ldsc_merged_filtered['sample_name'].unique())

    # get the x and y limit if share coordinates
    if share_coords:
        x_limits = (obs_ldsc_merged_filtered['sx'].min(), obs_ldsc_merged_filtered['sx'].max())
        y_limits = (obs_ldsc_merged_filtered['sy'].min(), obs_ldsc_merged_filtered['sy'].max())
    else:
        x_limits = None
        y_limits = None

    # Create a scatter plot for each sample
    for position_num, select_sample_name in enumerate(sample_name_list[:n_rows * n_cols], 1):
        # Calculate row and column in the grid
        row = (position_num - 1) // n_cols
        col = (position_num - 1) % n_cols

        # Create subplot
        ax = fig.add_subplot(grid_specs[row, col])

        # Get data for this sample
        sample_data = obs_ldsc_merged_filtered[obs_ldsc_merged_filtered['sample_name'] == select_sample_name]

        # NOTE(review): "estimate_matplitlib..." looks misspelled relative to the
        # module-level estimate_matplotlib_scatter_marker_size — confirm the
        # class actually defines a method under this name.
        point_size = self.estimate_matplitlib_scatter_marker_size(ax, sample_data[['sx', 'sy']].values,
                                                                  x_limits=x_limits, y_limits=y_limits)
        point_size *= scaling_factor  # Apply scaling factor
        # Create scatter plot
        scatter = ax.scatter(
            sample_data['sx'],
            sample_data['sy'],
            c=sample_data[trait_abbreviation],
            cmap=custom_cmap,
            s=point_size,
            vmin=pd_min,
            vmax=pd_max,
            marker='o',
            edgecolors='none',
            # Rasterize points when a PDF will be written, to bound file size.
            rasterized=True if output_pdf_path is not None and enable_pdf_output else False
        )

        if self.config.plot_origin == 'upper':
            ax.invert_yaxis()

        ax.axis('off')
        # Add sample label as title
        ax.set_title(select_sample_name, fontsize=12, pad=None)

    # Add colorbar to the right side
    # NOTE(review): `scatter` is the handle from the LAST subplot drawn; the
    # shared vmin/vmax make the colorbar valid for all subplots, but this
    # raises NameError if sample_name_list is empty.
    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
    cbar = fig.colorbar(scatter, cax=cbar_ax)
    cbar.set_label('$-\\log_{10}p$', fontsize=12, fontweight='bold')

    if output_png_path is not None:
        output_png_path.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(output_png_path, dpi=dpi, bbox_inches='tight', )

    if output_pdf_path is not None and enable_pdf_output:
        output_pdf_path.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(output_pdf_path, bbox_inches='tight', )

    # Close the figure to free memory only if not returning it
    if show:
        plt.show()

    gc.collect()
    return fig
|
|
630
|
+
|
|
631
|
+
def _create_single_sample_multi_trait_plots(self,
                                            obs_ldsc_merged: pd.DataFrame,
                                            trait_names: list[str],
                                            sample_name: str,
                                            # New arguments for output paths, as matplotlib saves directly
                                            output_png_path: Path | None,
                                            output_pdf_path: Path | None,
                                            # Arguments from original function signature, adapted for matplotlib
                                            max_cols: int = 5,
                                            subsample_n_points: int | None = None,
                                            # subplot_width is now interpreted as inches for figsize
                                            subplot_width_inches: float = 4.0,
                                            dpi: int = 300,
                                            enable_pdf_output: bool = True
                                            ):
    """Render one grid figure for a single sample, with one subplot per trait.

    Each subplot is a spatial scatter of the sample's spots (columns 'sx'/'sy')
    colored by that trait's -log10(p) values, with MAD-based outlier removal
    and a per-subplot horizontal colorbar. The figure is optionally saved to
    PNG and/or PDF and is always closed at the end to free memory.

    Args:
        obs_ldsc_merged: merged spot table; must contain 'sample_name', 'sx',
            'sy', and one column per trait in ``trait_names``.
        trait_names: trait columns to plot, one subplot each.
        sample_name: value of 'sample_name' to filter on.
        output_png_path: if not None, PNG is written here (parents created).
        output_pdf_path: if not None and ``enable_pdf_output``, PDF is written
            here; scatter points are rasterized in that case to keep the PDF small.
        max_cols: maximum number of subplot columns in the grid.
        subsample_n_points: if set and the sample has more spots, subsample to
            this many points (fixed random_state=42 for reproducibility).
        subplot_width_inches: width of each subplot; height follows the data's
            aspect ratio.
        dpi: resolution for the PNG output.
        enable_pdf_output: master switch for the PDF side effects.
    """

    print(f"Creating Matplotlib-based multi-trait visualization for sample: {sample_name}")

    # 1. Filter data for the specific sample and subsample if requested
    sample_plot_data = obs_ldsc_merged[obs_ldsc_merged['sample_name'] == sample_name].copy()
    if subsample_n_points and len(sample_plot_data) > subsample_n_points:
        print(f"Subsampling to {subsample_n_points} points for plotting.")
        sample_plot_data = sample_plot_data.sample(n=subsample_n_points, random_state=42)

    if sample_plot_data.empty:
        print(f"Warning: No data found for sample '{sample_name}'. Skipping plot generation.")
        return

    # 2. Calculate optimal grid layout for subplots
    n_traits = len(trait_names)
    n_rows, n_cols = self._calculate_optimal_grid_layout(item_count=n_traits, max_cols=max_cols)
    print(f"Plotting {n_traits} traits in a {n_rows}x{n_cols} grid.")

    # 3. Determine figure size and create figure and axes
    # Estimate subplot height based on data's aspect ratio to avoid distortion
    x_range = sample_plot_data['sx'].max() - sample_plot_data['sx'].min()
    y_range = sample_plot_data['sy'].max() - sample_plot_data['sy'].min()
    aspect_ratio = y_range / x_range if x_range > 0 else 1.0
    subplot_height_inches = subplot_width_inches * aspect_ratio

    # Calculate total figure size, adding padding for titles and colorbars
    fig_width = subplot_width_inches * n_cols
    fig_height = (subplot_height_inches * n_rows) * 1.2  # Add 20% vertical space for titles/colorbars

    # squeeze=False so `axes` is always 2-D, even for a 1x1 grid.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height), squeeze=False)
    fig.suptitle(f"Sample: {sample_name}", fontsize=16, fontweight='bold')

    # 4. Define custom colormap and font
    # Set font to Arial with fallbacks to avoid warnings
    plt.rcParams['font.family'] = 'sans-serif'
    plt.rcParams['font.sans-serif'] = ['Arial', 'DejaVu Sans', 'Liberation Sans', 'Bitstream Vera Sans', 'sans-serif']

    # Reversed so high -log10(p) values map to the "hot" end of the palette.
    custom_cmap = matplotlib.colors.LinearSegmentedColormap.from_list('custom_cmap',
                                                                      self.custom_colors_list).reversed()

    # 5. Iterate through traits and create each subplot
    axes_flat = axes.flatten()
    for i, trait in enumerate(trait_names):
        if i >= len(axes_flat):
            break  # Should not happen with correct grid calculation, but is a safe guard

        ax = axes_flat[i]

        # Estimate marker size to fill space without overlap
        point_size = estimate_matplotlib_scatter_marker_size(ax, sample_plot_data[['sx', 'sy']].values)

        # Drop NaNs, then remove outliers before deciding the color scale.
        # NOTE(review): assumes remove_outliers_MAD returns (filtered_values, boolean_mask)
        # aligned with the input — confirm against its definition.
        sample_trait_data = sample_plot_data[['sx', 'sy', trait]].dropna()
        trait_values, mask = remove_outliers_MAD(sample_trait_data[trait])
        sample_trait_data = sample_trait_data[mask]  # filter out outliers

        # Cap the color scale at the 99.9th percentile to handle outliers;
        # fall back to the max (or 1.0) when the quantile is NaN/zero.
        vmin = 0
        vmax = trait_values.quantile(0.999)
        if pd.isna(vmax) or vmax == 0:
            vmax = trait_values.max() if trait_values.max() > 0 else 1.0

        # Create the scatter plot
        scatter = ax.scatter(
            sample_trait_data['sx'],
            sample_trait_data['sy'],
            c=trait_values,
            cmap=custom_cmap,
            s=point_size,
            vmin=vmin,
            vmax=vmax,
            marker='o',
            edgecolors='none',
            # Rasterize points only when a PDF will be written, to keep its size down.
            rasterized=True if output_pdf_path is not None and enable_pdf_output else False
        )

        ax.set_title(trait, fontsize=16, pad=10, fontweight='bold')
        ax.set_aspect('equal', adjustable='box')

        # Flip the y-axis when the image origin is at the top-left.
        if self.config.plot_origin == 'upper':
            ax.invert_yaxis()

        ax.axis('off')

        # Add a colorbar to each subplot
        cbar = fig.colorbar(scatter, ax=ax, orientation='horizontal', pad=0.1, fraction=0.05)
        cbar.set_label('$-\\log_{10}p$', fontsize=8)
        cbar.ax.tick_params(labelsize=7)

    # Hide any unused axes in the grid
    for j in range(len(trait_names), len(axes_flat)):
        axes_flat[j].axis('off')
    #
    # # 6. Adjust layout and save the figure
    # fig.tight_layout(rect=[0, 0.03, 1, 0.95])  # Adjust for suptitle and bottom elements

    # Only proceed with saving if paths are provided
    if output_png_path is not None:
        output_png_path.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(output_png_path, dpi=dpi, bbox_inches='tight', facecolor='white')
        print(f"Saved multi-trait plot for '{sample_name}' to:\n - {output_png_path}")

    if output_pdf_path is not None and enable_pdf_output:
        output_pdf_path.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(output_pdf_path, bbox_inches='tight', facecolor='white')
        print(f"Saved multi-trait plot for '{sample_name}' to:\n - {output_pdf_path}")

    # Clean up to free memory
    plt.close(fig)
|
|
754
|
+
|
|
755
|
+
def _draw_scatter(self, space_coord_concat: pd.DataFrame, title: str | None = None,
                  fig_style: str = 'light', point_size: int | None = None,
                  hover_text_list: list[str] | None = None,
                  width: int = 800, height: int = 600, annotation: str | None = None,
                  color_by: str = 'logp', color_map: dict | None = None):
    """Create an interactive Plotly scatter of spots in spatial coordinates.

    Categorical ``color_by`` columns use a discrete color map; numeric columns
    use a continuous blue-to-red scale ranging from 0 to the column maximum.
    Axes are hidden, kept at equal scale, and optionally flipped so the origin
    matches ``self.config.plot_origin``.

    Args:
        space_coord_concat: table with 'sx'/'sy' coordinates plus the
            ``color_by`` column (and ``annotation`` column if used).
        title: figure title (centered at the top).
        fig_style: 'dark' selects the plotly_dark template, anything else light.
        point_size: fixed marker size; if None Plotly's default is kept.
        hover_text_list: extra columns shown in the hover tooltip.
        width / height: figure dimensions in pixels (but see note below).
        annotation: column used for marker symbols in the numeric branch.
        color_by: column to color by; 'logp' gets a '-log10(p)' colorbar title.
        color_map: discrete category->color mapping for categorical data.

    Returns:
        plotly.graph_objects.Figure (adapted from original draw_scatter function)
    """
    # Set theme based on fig_style
    if fig_style == 'dark':
        px.defaults.template = "plotly_dark"
    else:
        px.defaults.template = "plotly_white"

    # Diverging palette; reversed below so position 0 is red (high values).
    custom_color_scale = [
        (1, '#d73027'),  # Red
        (7 / 8, '#f46d43'),  # Red-Orange
        (6 / 8, '#fdae61'),  # Orange
        (5 / 8, '#fee090'),  # Light Orange
        (4 / 8, '#e0f3f8'),  # Light Blue
        (3 / 8, '#abd9e9'),  # Sky Blue
        (2 / 8, '#74add1'),  # Medium Blue
        (1 / 8, '#4575b4'),  # Dark Blue
        (0, '#313695')  # Deep Blue
    ]
    custom_color_scale.reverse()

    # if category data
    if not pd.api.types.is_numeric_dtype(space_coord_concat[color_by]):
        # Create the scatter plot
        fig = px.scatter(
            space_coord_concat,
            x='sx',
            y='sy',
            color=color_by,
            # symbol=annotation,
            title=title,
            color_discrete_map=color_map,
            hover_name=color_by,
            hover_data=hover_text_list,
            # color_continuous_scale=custom_color_scale,
            # range_color=[0, max(space_coord_concat[color_by])],
        )
    else:
        # Numeric data: continuous color scale anchored at zero.
        fig = px.scatter(
            space_coord_concat,
            x='sx',
            y='sy',
            color=color_by,
            symbol=annotation,
            title=title,
            hover_name=color_by,
            hover_data=hover_text_list,
            color_continuous_scale=custom_color_scale,
            range_color=[0, space_coord_concat[color_by].max()],
        )

    # Update marker size if specified
    if point_size is not None:
        fig.update_traces(marker=dict(size=point_size, symbol='circle'))

    # Update layout for figure size
    fig.update_layout(
        autosize=False,
        width=width,
        height=height,
    )

    # Adjusting the legend - Updated position and marker size
    fig.update_layout(
        legend=dict(
            yanchor="middle",  # Anchor point for y
            y=0.5,  # Center vertically
            xanchor="left",  # Anchor point for x
            x=1.02,  # Position just outside the plot
            font=dict(
                size=10,
            ),
            itemsizing='constant',  # Makes legend markers a constant size
            itemwidth=30,  # Adjust width of legend items
        )
    )

    # Update colorbar to be at the bottom and horizontal
    fig.update_layout(
        coloraxis_colorbar=dict(
            orientation='h',
            x=0.5,
            y=-0.0,
            xanchor='center',
            yanchor='top',
            len=0.75,
            title=dict(
                text='-log10(p)' if color_by == 'logp' else color_by,
                side='top'
            )
        )
    )

    # Remove gridlines, axis labels, and ticks
    fig.update_xaxes(
        showgrid=False,
        zeroline=False,
        showticklabels=False,
        title=None,
        scaleanchor='y',  # lock x to y scale so the spatial map is not distorted
    )

    fig.update_yaxes(
        showgrid=False,
        zeroline=False,
        showticklabels=False,
        title=None,
        # Flip y when the image origin is at the top-left.
        autorange='reversed' if self.config.plot_origin == 'upper' else True
    )

    # Adjust margins to ensure no clipping and equal axis ratio
    # NOTE(review): height is set to `width` here, overriding the earlier
    # `height=height` — looks intentional (square canvas) but verify it is not
    # a typo for `height=height`.
    fig.update_layout(
        margin=dict(l=0, r=100, t=20, b=10),  # Increased right margin to accommodate legend
        height=width
    )

    # Adjust the title location and font size
    fig.update_layout(
        title=dict(
            y=0.98,
            x=0.5,
            xanchor='center',
            yanchor='top',
            font=dict(
                size=20
            )
        ))

    return fig
|
|
888
|
+
|
|
889
|
+
@classmethod
def estimate_matplitlib_scatter_marker_size(cls, ax: matplotlib.axes.Axes, coordinates: np.ndarray,
                                            x_limits: tuple | None = None,
                                            y_limits: tuple | None = None) -> float:
    """Backward-compatible alias kept under its historical misspelled name
    ('matplitlib'); delegates to the module-level
    estimate_matplotlib_scatter_marker_size()."""
    marker_size = estimate_matplotlib_scatter_marker_size(ax, coordinates, x_limits, y_limits)
    return marker_size
|
|
895
|
+
|
|
896
|
+
@classmethod
def estimate_matplotlib_scatter_marker_size(cls, ax: matplotlib.axes.Axes, coordinates: np.ndarray,
                                            x_limits: tuple | None = None,
                                            y_limits: tuple | None = None) -> float:
    """Class-level convenience wrapper for backward compatibility; inside the
    method body the unqualified name resolves to the module-level
    estimate_matplotlib_scatter_marker_size() function, which does the work."""
    marker_size = estimate_matplotlib_scatter_marker_size(ax, coordinates, x_limits, y_limits)
    return marker_size
|
|
902
|
+
|
|
903
|
+
def _create_multi_sample_annotation_plot(self, obs_ldsc_merged: pd.DataFrame, annotation: str,
                                         sample_names_list: list, output_dir: Path,
                                         n_rows: int, n_cols: int,
                                         fig_width: float = 20, fig_height: float = 15,
                                         scaling_factor: float = 1.0, dpi: int = 300):
    """Create multi-sample annotation plot using matplotlib with subplots for each sample.

    One subplot per sample (laid out in an n_rows x n_cols grid), each a spatial
    scatter of that sample's spots colored by ``annotation``. Numeric annotations
    share a single colormap/normalization across all samples and get one shared
    colorbar; categorical annotations get a shared discrete color map and a
    figure-level legend.

    Args:
        obs_ldsc_merged: merged spot table with 'sample_name', 'sx', 'sy' and
            the ``annotation`` column.
        annotation: column name to color spots by.
        sample_names_list: samples to plot; only the first n_rows*n_cols are used.
        output_dir: if truthy, a PNG 'multi_sample_<annotation>.png' is saved here.
        n_rows / n_cols: grid shape for the subplots.
        fig_width / fig_height: overall figure size in inches.
        scaling_factor: multiplier applied to the auto-estimated marker size.
        dpi: resolution for the saved PNG.

    Returns:
        The matplotlib Figure (not closed here — caller owns it).
    """

    print(f"Creating multi-sample plot for annotation: {annotation}")

    # Create figure
    fig = plt.figure(figsize=(fig_width, fig_height))
    fig.suptitle(f'{annotation} - All Samples', fontsize=24, fontweight='bold', y=0.98)

    # Create grid of subplots
    grid_specs = fig.add_gridspec(nrows=n_rows, ncols=n_cols, wspace=0.1, hspace=0.1)

    # Get unique annotation values and create color map
    unique_annotations = obs_ldsc_merged[annotation].unique()
    if pd.api.types.is_numeric_dtype(obs_ldsc_merged[annotation]):
        # Shared continuous colormap + normalization across the whole dataset,
        # so colors are comparable between samples.
        custom_cmap = matplotlib.colors.LinearSegmentedColormap.from_list('custom_cmap', self.custom_colors_list)
        cmap = custom_cmap.reversed()
        norm = plt.Normalize(vmin=obs_ldsc_merged[annotation].min(),
                             vmax=obs_ldsc_merged[annotation].max())
    else:
        # For categorical annotations, use discrete colors
        # NOTE(review): assumes _create_color_map returns {category: color} —
        # confirm against its definition.
        color_map = _create_color_map(unique_annotations, hex=False)

    # Create scatter plot for each sample (grid positions are 1-based here)
    for position_num, sample_name in enumerate(sample_names_list[:n_rows * n_cols], 1):
        # Calculate row and column in the grid
        row = (position_num - 1) // n_cols
        col = (position_num - 1) % n_cols

        # Create subplot
        ax = fig.add_subplot(grid_specs[row, col])

        # Get data for this sample
        sample_data = obs_ldsc_merged[obs_ldsc_merged['sample_name'] == sample_name]

        # Estimate point size based on data density
        point_size = estimate_matplotlib_scatter_marker_size(ax, sample_data[['sx', 'sy']].values)
        point_size *= scaling_factor  # Apply scaling factor

        # Create scatter plot
        if pd.api.types.is_numeric_dtype(obs_ldsc_merged[annotation]):
            ax.scatter(sample_data['sx'], sample_data['sy'],
                       c=sample_data[annotation], cmap=cmap, norm=norm,
                       s=point_size, alpha=1.0, edgecolors='none')
        else:
            # For categorical data, plot each category separately (so each gets
            # its own label/color).
            for cat in unique_annotations:
                cat_data = sample_data[sample_data[annotation] == cat]
                if len(cat_data) > 0:
                    ax.scatter(cat_data['sx'], cat_data['sy'],
                               c=[color_map[cat]], s=point_size, alpha=1.0,
                               edgecolors='none', label=cat)

        # Set subplot title
        ax.set_title(sample_name, fontsize=10)
        ax.set_aspect('equal')
        # Flip y when the image origin is at the top-left.
        if self.config.plot_origin == 'upper':
            ax.invert_yaxis()
        ax.axis('off')

    # Add colorbar for numeric annotations or legend for categorical
    if pd.api.types.is_numeric_dtype(obs_ldsc_merged[annotation]):
        # Create a colorbar on the right side of the figure
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        fig.subplots_adjust(right=0.85)
        cbar_ax = fig.add_axes([0.88, 0.15, 0.03, 0.7])
        cbar = fig.colorbar(sm, cax=cbar_ax, orientation='vertical')
        cbar.set_label(annotation, fontsize=14)
    else:
        # Create a legend on the right side of the figure using proxy handles
        handles = [plt.Line2D([0], [0], marker='o', color='w', label=label,
                              markerfacecolor=color, markersize=getattr(self.config, 'legend_marker_size', 10))
                   for label, color in color_map.items()]
        fig.subplots_adjust(right=0.8)
        fig.legend(handles=handles, title=annotation, loc='center left', bbox_to_anchor=(0.85, 0.5))

    # Save the plot if output_dir is provided
    if output_dir:
        output_path = output_dir / f'multi_sample_{annotation}.png'
        plt.savefig(output_path, dpi=dpi, bbox_inches='tight', facecolor='white')
        print(f"Saved multi-sample annotation plot: {output_path}")

    return fig
|
|
992
|
+
|
|
993
|
+
def _run_cauchy_analysis(self, obs_ldsc_merged: pd.DataFrame):
    """Run the Cauchy combination analysis for each configured annotation column.

    For every column in ``self.config.cauchy_annotations`` this computes the
    per-annotation Cauchy combination results over all configured traits,
    writes them to a CSV in the visualization result directory, and renders
    the corresponding heatmaps.
    """
    traits = self.config.trait_name_list

    for annotation_col in self.config.cauchy_annotations:
        print(f"Running Cauchy analysis for {annotation_col}...")

        results = self._run_cauchy_combination_per_annotation(
            obs_ldsc_merged, annotation_col=annotation_col, trait_cols=traits)

        # Persist the raw results next to the other visualization outputs.
        csv_path = self.config.visualization_result_dir / f'cauchy_results_{annotation_col}.csv'
        self._save_cauchy_results_to_csv(results, csv_path)

        # Render the heatmap summaries for this annotation column.
        self._generate_cauchy_heatmaps(results, annotation_col)
|
|
1010
|
+
|
|
1011
|
+
def _run_cauchy_combination_per_annotation(self, df: pd.DataFrame, annotation_col: str,
                                           trait_cols: list[str], max_workers=None):
    """
    Runs the Cauchy combination on each annotation category for each given trait in parallel.
    Also calculates odds ratios with confidence intervals for significant spots in each annotation.

    Args:
        df: spot-level table containing ``annotation_col`` and one -log10(p)
            column per trait in ``trait_cols``.
        annotation_col: column whose unique values define the categories.
        trait_cols: trait columns (values are -log10(p), converted to p below).
        max_workers: thread pool size for per-trait processing (None = default).

    Returns:
        dict mapping each annotation category to a DataFrame of per-trait
        statistics, sorted ascending by 'p_cauchy'.
    """
    from functools import partial

    import statsmodels.api as sm
    from scipy.stats import fisher_exact

    results_dict = {}
    annotations = df[annotation_col].unique()

    # Helper function to process a single trait for a given annotation
    def process_trait(trait, anno_data, all_data, annotation):
        # Calculate significance threshold (Bonferroni correction over all spots)
        sig_threshold = 0.05 / len(all_data)

        # Get p-values for this annotation and trait.
        # NOTE(review): assumes remove_outliers_MAD returns (filtered_values, mask);
        # the mask result is unused here, and spot counts below are taken AFTER
        # outlier removal — confirm that is intended.
        log10p = anno_data[trait].values
        log10p, mask = remove_outliers_MAD(log10p, )
        p_values = 10 ** (-log10p)  # convert from log10(p) to p

        # Calculate Cauchy combination and median
        p_cauchy_val = self._acat_test(p_values)
        p_median_val = np.median(p_values)

        # Calculate significance statistics within this annotation
        sig_spots_in_anno = np.sum(p_values < sig_threshold)
        total_spots_in_anno = len(p_values)

        # Get p-values for all other annotations (no outlier removal here)
        other_annotations_mask = all_data[annotation_col] != annotation
        other_p_values = 10 ** (-all_data.loc[other_annotations_mask, trait].values)
        sig_spots_elsewhere = np.sum(other_p_values < sig_threshold)
        total_spots_elsewhere = len(other_p_values)

        # Odds ratio calculation using Fisher's exact test on a 2x2 table of
        # (significant vs not) x (inside vs outside this annotation)
        try:
            # Create contingency table
            contingency_table = np.array([
                [sig_spots_in_anno, total_spots_in_anno - sig_spots_in_anno],
                [sig_spots_elsewhere, total_spots_elsewhere - sig_spots_elsewhere]
            ])

            # Calculate odds ratio and p-value using Fisher's exact test
            odds_ratio, p_value = fisher_exact(contingency_table)

            # if odds_ratio is infinite, set it to a large number
            if odds_ratio == np.inf:
                odds_ratio = 1e4  # Set to a large number to avoid overflow

            # Calculate confidence intervals
            table = sm.stats.Table2x2(contingency_table)
            conf_int = table.oddsratio_confint()
            ci_low, ci_high = conf_int
        except Exception as e:
            # Handle calculation errors: fall back to neutral values and report,
            # rather than failing the whole annotation.
            odds_ratio = 0
            p_value = 1
            ci_low, ci_high = 0, 0
            print(f"Fisher's exact test failed for {trait} in {annotation}: {e}")

        return {
            'trait': trait,
            'p_cauchy': p_cauchy_val,
            'p_median': p_median_val,
            'odds_ratio': odds_ratio,
            'ci_low': ci_low,
            'ci_high': ci_high,
            'p_odds_ratio': p_value,
            'sig_spots': sig_spots_in_anno,
            'total_spots': total_spots_in_anno,
            'sig_ratio': sig_spots_in_anno / total_spots_in_anno if total_spots_in_anno > 0 else 0,
            'overall_sig_spots': sig_spots_in_anno + sig_spots_elsewhere,
            'overall_spots': total_spots_in_anno + total_spots_elsewhere
        }

    # Process each annotation (sequential); traits within an annotation run in parallel
    for anno in tqdm(annotations, desc="Processing annotations"):
        df_anno = df[df[annotation_col] == anno]

        # Create a partial function with fixed parameters
        process_trait_for_anno = partial(process_trait,
                                         anno_data=df_anno,
                                         all_data=df,
                                         annotation=anno)

        # Process traits in parallel with progress bar
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Create list for results and submit all tasks
            futures = list(tqdm(
                executor.map(process_trait_for_anno, trait_cols),
                total=len(trait_cols),
                desc=f"Processing traits for {anno}",
                leave=False
            ))
            trait_results = list(futures)

        # Create a DataFrame for this annotation
        anno_results_df = pd.DataFrame(trait_results).sort_values(by='p_cauchy')
        results_dict[anno] = anno_results_df

    return results_dict
|
|
1116
|
+
|
|
1117
|
+
def _acat_test(self, pvalues: np.ndarray, weights=None):
|
|
1118
|
+
logger = logging.getLogger('gsMap.post_analysis.cauchy')
|
|
1119
|
+
if np.any(np.isnan(pvalues)):
|
|
1120
|
+
raise ValueError("Cannot have NAs in the p-values.")
|
|
1121
|
+
if np.any((pvalues > 1) | (pvalues < 0)):
|
|
1122
|
+
raise ValueError("P-values must be between 0 and 1.")
|
|
1123
|
+
if np.any(pvalues == 0) and np.any(pvalues == 1):
|
|
1124
|
+
raise ValueError("Cannot have both 0 and 1 p-values.")
|
|
1125
|
+
if np.any(pvalues == 0):
|
|
1126
|
+
logger.info("Warn: p-values are exactly 0.")
|
|
1127
|
+
return 0
|
|
1128
|
+
if np.any(pvalues == 1):
|
|
1129
|
+
logger.info("Warn: p-values are exactly 1.")
|
|
1130
|
+
return 1
|
|
1131
|
+
|
|
1132
|
+
if weights is None:
|
|
1133
|
+
weights = np.full(len(pvalues), 1 / len(pvalues))
|
|
1134
|
+
else:
|
|
1135
|
+
if len(weights) != len(pvalues):
|
|
1136
|
+
raise Exception("Length of weights and p-values differs.")
|
|
1137
|
+
if any(weights < 0):
|
|
1138
|
+
raise Exception("All weights must be positive.")
|
|
1139
|
+
weights = np.array(weights) / np.sum(weights)
|
|
1140
|
+
|
|
1141
|
+
is_small = pvalues < 1e-16
|
|
1142
|
+
is_large = ~is_small
|
|
1143
|
+
|
|
1144
|
+
if not np.any(is_small):
|
|
1145
|
+
cct_stat = np.sum(weights * np.tan((0.5 - pvalues) * np.pi))
|
|
1146
|
+
else:
|
|
1147
|
+
cct_stat = np.sum((weights[is_small] / pvalues[is_small]) / np.pi) + \
|
|
1148
|
+
np.sum(weights[is_large] * np.tan((0.5 - pvalues[is_large]) * np.pi))
|
|
1149
|
+
|
|
1150
|
+
if cct_stat > 1e15:
|
|
1151
|
+
pval = (1 / cct_stat) / np.pi
|
|
1152
|
+
else:
|
|
1153
|
+
pval = 1 - stats.cauchy.cdf(cct_stat)
|
|
1154
|
+
|
|
1155
|
+
return pval
|
|
1156
|
+
|
|
1157
|
+
def _save_cauchy_results_to_csv(self, cauchy_results: dict, output_path: Path):
|
|
1158
|
+
"""Save Cauchy results to CSV"""
|
|
1159
|
+
all_results = []
|
|
1160
|
+
for annotation, df in cauchy_results.items():
|
|
1161
|
+
df_copy = df.copy()
|
|
1162
|
+
df_copy['annotation'] = annotation
|
|
1163
|
+
all_results.append(df_copy)
|
|
1164
|
+
|
|
1165
|
+
combined_results = pd.concat(all_results, ignore_index=True)
|
|
1166
|
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
|
1167
|
+
combined_results.to_csv(output_path, index=False)
|
|
1168
|
+
|
|
1169
|
+
return combined_results
|
|
1170
|
+
|
|
1171
|
+
def _generate_cauchy_heatmaps(self, cauchy_results: dict, annotation_col: str):
    """Generate multiple types of Cauchy combination heatmaps.

    Produces five heatmaps for ``annotation_col`` — Cauchy combination
    (raw and column-normalized), median -log10(p) (raw and column-normalized),
    and odds ratio — each saved both as a static PNG and an interactive HTML.

    Improvement: the five previously copy-pasted plot/save blocks are driven
    by a single spec table; output filenames and titles are unchanged.

    Args:
        cauchy_results: mapping of annotation category -> per-trait DataFrame
            (as produced by _run_cauchy_combination_per_annotation).
        annotation_col: name of the annotation column, used in titles/filenames.
    """
    # Convert results to pivot table format for different metrics
    table_cauchy = self._results_dict_to_log10_table(cauchy_results, value_col='p_cauchy', log10_transform=True)
    table_median = self._results_dict_to_log10_table(cauchy_results, value_col='p_median', log10_transform=True)
    table_odds_ratio = self._results_dict_to_log10_table(cauchy_results, value_col='odds_ratio',
                                                         log10_transform=False)

    # Create heatmap directories
    cauchy_heatmap_base = self.config.visualization_result_dir / 'cauchy_heatmap'
    static_folder = cauchy_heatmap_base / 'static_png'
    interactive_folder = cauchy_heatmap_base / 'interactive_html'
    for folder in (static_folder, interactive_folder):
        folder.mkdir(exist_ok=True, parents=True)

    # Calculate dimensions from the table shape (all tables share it)
    num_annotations, num_traits = table_cauchy.shape
    width = 50 * num_traits
    height = 50 * num_annotations

    # (table, title, normalize_axis, output file stem)
    heatmap_specs = [
        (table_cauchy, f"Cauchy Combination Heatmap -- By {annotation_col}",
         None, f'cauchy_combination_by_{annotation_col}'),
        (table_cauchy, f"Cauchy Combination Heatmap -- By {annotation_col}",
         'column', f'cauchy_combination_by_{annotation_col}_normalized'),
        (table_median, f"Median log 10 pvalue Heatmap -- By {annotation_col}",
         None, f'median_pvalue_{annotation_col}'),
        (table_median, f"Median log 10 pvalue Heatmap -- By {annotation_col}",
         'column', f'median_pvalue_{annotation_col}_normalized'),
        (table_odds_ratio, f"Odds Ratio Heatmap -- By {annotation_col}",
         None, f'odds_ratio_{annotation_col}'),
    ]

    for table, title, normalize_axis, stem in heatmap_specs:
        fig = self._plot_p_cauchy_heatmap(
            df=table,
            title=title,
            normalize_axis=normalize_axis,
            cluster_rows=True, cluster_cols=True,
            width=width, height=height,
            text_format=".2f", font_size=10, margin_pad=150
        )
        fig.write_image(static_folder / f'{stem}.png', scale=2)
        fig.write_html(interactive_folder / f'{stem}.html')
|
|
1248
|
+
|
|
1249
|
+
def _results_dict_to_log10_table(self, results_dict: dict, value_col: str = 'p_cauchy',
|
|
1250
|
+
log10_transform: bool = True, epsilon: float = 1e-300) -> pd.DataFrame:
|
|
1251
|
+
"""Convert results dict to pivot table"""
|
|
1252
|
+
all_data = []
|
|
1253
|
+
for anno, df in results_dict.items():
|
|
1254
|
+
temp = df.copy()
|
|
1255
|
+
temp['annotation'] = anno
|
|
1256
|
+
all_data.append(temp)
|
|
1257
|
+
|
|
1258
|
+
combined_df = pd.concat(all_data, ignore_index=True)
|
|
1259
|
+
|
|
1260
|
+
if log10_transform:
|
|
1261
|
+
combined_df.loc[combined_df[value_col] == 0, value_col] = epsilon
|
|
1262
|
+
combined_df['transformed'] = -np.log10(combined_df[value_col])
|
|
1263
|
+
else:
|
|
1264
|
+
combined_df['transformed'] = combined_df[value_col]
|
|
1265
|
+
|
|
1266
|
+
pivot_df = combined_df.pivot(index='annotation', columns='trait', values='transformed')
|
|
1267
|
+
return pivot_df
|
|
1268
|
+
|
|
1269
|
+
def _plot_p_cauchy_heatmap(self, df: pd.DataFrame, title: str = "Cauchy Combination Heatmap",
                           normalize_axis: Literal["row", "column"] | None = None,
                           cluster_rows: bool = False, cluster_cols: bool = False,
                           color_continuous_scale: str | list = "RdBu_r",
                           width: int | None = None, height: int | None = None,
                           text_format: str = ".2f",
                           show_text: bool = True, font_size: int = 10, margin_pad: int = 150) -> go.Figure:
    """
    Create an enhanced heatmap visualization for trait-annotation relationships.

    Parameters
    ----------
    df : pd.DataFrame
        Numeric matrix with annotations as the index and traits as the columns
        (as produced by ``_results_dict_to_log10_table``).
    title : str
        Figure title.
    normalize_axis : Literal["row", "column"] | None
        If set, min-max normalize each row or column for the colour scale;
        the text annotations still show the ORIGINAL (unnormalized) values.
    cluster_rows, cluster_cols : bool
        Reorder rows/columns by hierarchical clustering (average linkage,
        euclidean distance) on the (possibly normalized) data.
    color_continuous_scale : str | list
        Plotly continuous colour scale.
    width, height : int | None
        Figure size in pixels; derived from the matrix shape when None.
    text_format : str
        Python format spec used for cell text (e.g. ".2f").
    show_text : bool
        Whether to draw per-cell value labels.
    font_size : int
        Base font size; title/axis titles use ``font_size + 2``/``+ 4``.
    margin_pad : int
        Margin (pixels) applied on all four sides.

    Returns
    -------
    go.Figure
        The assembled plotly figure (not written to disk here).

    Raises
    ------
    TypeError
        If ``df`` is not a pandas DataFrame.
    ValueError
        If ``df`` is empty or non-numeric, or if normalization/clustering fails.
    """
    # Validate BEFORE copying: a non-DataFrame input without a .copy() method
    # would otherwise raise AttributeError and never reach the intended TypeError.
    if not isinstance(df, pd.DataFrame):
        raise TypeError("Input must be a pandas DataFrame")
    if df.empty:
        raise ValueError("Input DataFrame is empty")
    if not np.issubdtype(df.values.dtype, np.number):
        raise ValueError("DataFrame must contain numeric values")

    data = df.copy()
    n_rows, n_cols = data.shape

    # Dynamic sizing when not provided, to keep a reasonable aspect ratio.
    # (Earlier per-cell units of 50/30 px caused vertical stretching.)
    if width is None:
        width = max(600, n_cols * 150 + margin_pad * 2)
    if height is None:
        height = max(500, n_rows * 60 + margin_pad * 2)

    # Min-max normalization along the requested axis. The original values are
    # kept aside so text annotations can show the unnormalized numbers.
    if normalize_axis in ['row', 'column']:
        axis = 1 if normalize_axis == 'row' else 0
        try:
            original_data = data.copy()

            min_vals = data.min(axis=axis)
            max_vals = data.max(axis=axis)
            range_vals = max_vals - min_vals

            # A constant row/column has zero range; use 1 to avoid div-by-zero.
            range_vals = range_vals.replace(0, 1)

            if normalize_axis == 'row':
                data = data.sub(min_vals, axis=0).div(range_vals, axis=0)
            else:  # column
                data = data.sub(min_vals, axis=1).div(range_vals, axis=1)

            data = data.fillna(0)
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise ValueError(f"Normalization failed: {str(e)}") from e
    else:
        # No normalization: colour and text both come from the same data.
        original_data = data

    # Optional hierarchical clustering; the same ordering is applied to
    # original_data so text annotations stay aligned with the cells.
    try:
        if cluster_rows:
            row_linkage = linkage(data.fillna(0).values, method='average', metric='euclidean')
            row_order = leaves_list(row_linkage)
            data = data.iloc[row_order, :]
            original_data = original_data.iloc[row_order, :]

        if cluster_cols:
            col_linkage = linkage(data.fillna(0).values.T, method='average', metric='euclidean')
            col_order = leaves_list(col_linkage)
            data = data.iloc[:, col_order]
            original_data = original_data.iloc[:, col_order]
    except Exception as e:
        raise ValueError(f"Clustering failed: {str(e)}") from e

    if normalize_axis is None:
        # No normalization: let plotly generate cell text automatically (fast path).
        fig = px.imshow(
            data,
            color_continuous_scale=color_continuous_scale,
            aspect='auto',
            width=width,
            height=height,
            text_auto=text_format if show_text else False
        )
    else:
        # Normalized colours but original-value labels: disable auto text and
        # add manual annotations from original_data below.
        fig = px.imshow(
            data,
            color_continuous_scale=color_continuous_scale,
            aspect='auto',
            width=width,
            height=height,
            text_auto=False
        )

        if show_text:
            for i, row in enumerate(original_data.values):
                for j, value in enumerate(row):
                    fig.add_annotation(
                        x=j,
                        y=i,
                        text=f"{value:{text_format}}",
                        showarrow=False,
                        font=dict(size=font_size, color='black')
                    )

    # Layout: centered title, angled trait labels, white template, colourbar
    # label reflecting whether values were normalized.
    fig.update_layout(
        title={
            'text': title,
            'y': 0.98,
            'x': 0.5,
            'xanchor': 'center',
            'yanchor': 'bottom',
            'font': {'size': font_size + 4}
        },
        xaxis={
            'title': "Trait",
            'tickangle': 45,
            'side': 'bottom',
            'tickfont': {'size': font_size},
            'title_font': {'size': font_size + 2}
        },
        yaxis={
            'title': "Annotation",
            'tickfont': {'size': font_size},
            'title_font': {'size': font_size + 2}
        },
        width=width,
        height=height,
        template='plotly_white',
        margin=dict(l=margin_pad, r=margin_pad, t=margin_pad, b=margin_pad),
        coloraxis_colorbar={
            'title': '-log10(p)' if not normalize_axis else 'Normalized Value',
            'title_side': 'right',
            'title_font': {'size': font_size + 2},
            'tickfont': {'size': font_size}
        }
    )
    return fig
|