pylocuszoom 0.8.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pylocuszoom/__init__.py +27 -7
- pylocuszoom/_plotter_utils.py +66 -0
- pylocuszoom/backends/base.py +56 -0
- pylocuszoom/backends/bokeh_backend.py +141 -29
- pylocuszoom/backends/matplotlib_backend.py +60 -0
- pylocuszoom/backends/plotly_backend.py +297 -88
- pylocuszoom/config.py +365 -0
- pylocuszoom/ensembl.py +6 -11
- pylocuszoom/eqtl.py +3 -7
- pylocuszoom/exceptions.py +33 -0
- pylocuszoom/finemapping.py +2 -7
- pylocuszoom/forest.py +1 -0
- pylocuszoom/gene_track.py +10 -31
- pylocuszoom/labels.py +6 -2
- pylocuszoom/manhattan.py +246 -0
- pylocuszoom/manhattan_plotter.py +760 -0
- pylocuszoom/plotter.py +401 -327
- pylocuszoom/qq.py +123 -0
- pylocuszoom/recombination.py +7 -7
- pylocuszoom/schemas.py +1 -6
- pylocuszoom/stats_plotter.py +319 -0
- pylocuszoom/utils.py +2 -4
- pylocuszoom/validation.py +51 -0
- {pylocuszoom-0.8.0.dist-info → pylocuszoom-1.1.0.dist-info}/METADATA +159 -25
- pylocuszoom-1.1.0.dist-info/RECORD +36 -0
- pylocuszoom-0.8.0.dist-info/RECORD +0 -29
- {pylocuszoom-0.8.0.dist-info → pylocuszoom-1.1.0.dist-info}/WHEEL +0 -0
- {pylocuszoom-0.8.0.dist-info → pylocuszoom-1.1.0.dist-info}/licenses/LICENSE.md +0 -0
|
@@ -11,6 +11,21 @@ from plotly.subplots import make_subplots
|
|
|
11
11
|
|
|
12
12
|
from . import convert_latex_to_unicode, register_backend
|
|
13
13
|
|
|
14
|
+
# Style mappings (matplotlib -> Plotly)
|
|
15
|
+
_MARKER_SYMBOLS = {
|
|
16
|
+
"o": "circle",
|
|
17
|
+
"D": "diamond",
|
|
18
|
+
"s": "square",
|
|
19
|
+
"^": "triangle-up",
|
|
20
|
+
"v": "triangle-down",
|
|
21
|
+
}
|
|
22
|
+
_DASH_MAP = {
|
|
23
|
+
"-": "solid",
|
|
24
|
+
"--": "dash",
|
|
25
|
+
":": "dot",
|
|
26
|
+
"-.": "dashdot",
|
|
27
|
+
}
|
|
28
|
+
|
|
14
29
|
|
|
15
30
|
@register_backend("plotly")
|
|
16
31
|
class PlotlyBackend:
|
|
@@ -23,21 +38,6 @@ class PlotlyBackend:
|
|
|
23
38
|
- Nearest gene
|
|
24
39
|
"""
|
|
25
40
|
|
|
26
|
-
# Class constants for style mappings
|
|
27
|
-
_MARKER_SYMBOLS = {
|
|
28
|
-
"o": "circle",
|
|
29
|
-
"D": "diamond",
|
|
30
|
-
"s": "square",
|
|
31
|
-
"^": "triangle-up",
|
|
32
|
-
"v": "triangle-down",
|
|
33
|
-
}
|
|
34
|
-
_DASH_MAP = {
|
|
35
|
-
"-": "solid",
|
|
36
|
-
"--": "dash",
|
|
37
|
-
":": "dot",
|
|
38
|
-
"-.": "dashdot",
|
|
39
|
-
}
|
|
40
|
-
|
|
41
41
|
@property
|
|
42
42
|
def supports_snp_labels(self) -> bool:
|
|
43
43
|
"""Plotly uses hover tooltips instead of labels."""
|
|
@@ -113,6 +113,101 @@ class PlotlyBackend:
|
|
|
113
113
|
panel_refs = [(fig, row) for row in range(1, n_panels + 1)]
|
|
114
114
|
return fig, panel_refs
|
|
115
115
|
|
|
116
|
+
def create_figure_grid(
|
|
117
|
+
self,
|
|
118
|
+
n_rows: int,
|
|
119
|
+
n_cols: int,
|
|
120
|
+
width_ratios: Optional[List[float]] = None,
|
|
121
|
+
height_ratios: Optional[List[float]] = None,
|
|
122
|
+
figsize: Tuple[float, float] = (12.0, 8.0),
|
|
123
|
+
) -> Tuple[go.Figure, List[Any]]:
|
|
124
|
+
"""Create a figure with a grid of subplots.
|
|
125
|
+
|
|
126
|
+
Args:
|
|
127
|
+
n_rows: Number of rows.
|
|
128
|
+
n_cols: Number of columns.
|
|
129
|
+
width_ratios: Relative widths for columns.
|
|
130
|
+
height_ratios: Relative heights for rows.
|
|
131
|
+
figsize: Figure size as (width, height).
|
|
132
|
+
|
|
133
|
+
Returns:
|
|
134
|
+
Tuple of (figure, flattened list of (fig, row, col) tuples).
|
|
135
|
+
"""
|
|
136
|
+
width_px = int(figsize[0] * 100)
|
|
137
|
+
height_px = int(figsize[1] * 100)
|
|
138
|
+
|
|
139
|
+
# Normalize ratios
|
|
140
|
+
if width_ratios is not None:
|
|
141
|
+
total = sum(width_ratios)
|
|
142
|
+
column_widths = [w / total for w in width_ratios]
|
|
143
|
+
else:
|
|
144
|
+
column_widths = None
|
|
145
|
+
|
|
146
|
+
if height_ratios is not None:
|
|
147
|
+
total = sum(height_ratios)
|
|
148
|
+
row_heights_norm = [h / total for h in height_ratios]
|
|
149
|
+
else:
|
|
150
|
+
row_heights_norm = None
|
|
151
|
+
|
|
152
|
+
fig = make_subplots(
|
|
153
|
+
rows=n_rows,
|
|
154
|
+
cols=n_cols,
|
|
155
|
+
column_widths=column_widths,
|
|
156
|
+
row_heights=row_heights_norm,
|
|
157
|
+
horizontal_spacing=0.08,
|
|
158
|
+
vertical_spacing=0.08,
|
|
159
|
+
)
|
|
160
|
+
|
|
161
|
+
fig.update_layout(
|
|
162
|
+
width=width_px,
|
|
163
|
+
height=height_px,
|
|
164
|
+
showlegend=True,
|
|
165
|
+
template="plotly_white",
|
|
166
|
+
)
|
|
167
|
+
|
|
168
|
+
# Store grid dimensions for axis naming
|
|
169
|
+
fig._n_cols = n_cols
|
|
170
|
+
fig._n_rows = n_rows
|
|
171
|
+
|
|
172
|
+
# Style all panels
|
|
173
|
+
axis_style = dict(
|
|
174
|
+
showgrid=False,
|
|
175
|
+
showline=True,
|
|
176
|
+
linecolor="black",
|
|
177
|
+
ticks="outside",
|
|
178
|
+
zeroline=False,
|
|
179
|
+
)
|
|
180
|
+
for row in range(1, n_rows + 1):
|
|
181
|
+
for col in range(1, n_cols + 1):
|
|
182
|
+
subplot_idx = (row - 1) * n_cols + col
|
|
183
|
+
xaxis = f"xaxis{subplot_idx}" if subplot_idx > 1 else "xaxis"
|
|
184
|
+
yaxis = f"yaxis{subplot_idx}" if subplot_idx > 1 else "yaxis"
|
|
185
|
+
fig.update_layout(**{xaxis: axis_style, yaxis: axis_style})
|
|
186
|
+
|
|
187
|
+
# Return flattened list of (fig, row, col) tuples
|
|
188
|
+
panel_refs = []
|
|
189
|
+
for row in range(1, n_rows + 1):
|
|
190
|
+
for col in range(1, n_cols + 1):
|
|
191
|
+
panel_refs.append((fig, row, col))
|
|
192
|
+
return fig, panel_refs
|
|
193
|
+
|
|
194
|
+
def _extract_row_col(self, ax: Any) -> Tuple[go.Figure, int, int, int]:
|
|
195
|
+
"""Extract figure, row, col, and n_cols from ax tuple.
|
|
196
|
+
|
|
197
|
+
Handles both (fig, row) for create_figure and (fig, row, col) for create_figure_grid.
|
|
198
|
+
|
|
199
|
+
Returns:
|
|
200
|
+
Tuple of (figure, row, col, n_cols).
|
|
201
|
+
"""
|
|
202
|
+
if len(ax) == 2:
|
|
203
|
+
fig, row = ax
|
|
204
|
+
col = 1
|
|
205
|
+
else:
|
|
206
|
+
fig, row, col = ax
|
|
207
|
+
# Get n_cols from figure if stored, default to 1 for single-column layouts
|
|
208
|
+
n_cols = getattr(fig, "_n_cols", 1)
|
|
209
|
+
return fig, row, col, n_cols
|
|
210
|
+
|
|
116
211
|
def scatter(
|
|
117
212
|
self,
|
|
118
213
|
ax: Tuple[go.Figure, int],
|
|
@@ -129,12 +224,12 @@ class PlotlyBackend:
|
|
|
129
224
|
) -> Any:
|
|
130
225
|
"""Create a scatter plot on the given panel.
|
|
131
226
|
|
|
132
|
-
For plotly, ax is a tuple of (figure, row_number).
|
|
227
|
+
For plotly, ax is a tuple of (figure, row_number) or (figure, row, col).
|
|
133
228
|
"""
|
|
134
|
-
fig, row = ax
|
|
229
|
+
fig, row, col, _ = self._extract_row_col(ax)
|
|
135
230
|
|
|
136
231
|
# Convert matplotlib marker to plotly symbol
|
|
137
|
-
symbol =
|
|
232
|
+
symbol = _MARKER_SYMBOLS.get(marker, "circle")
|
|
138
233
|
|
|
139
234
|
# Convert size (matplotlib uses area, plotly uses diameter)
|
|
140
235
|
if isinstance(sizes, (int, float)):
|
|
@@ -147,16 +242,16 @@ class PlotlyBackend:
|
|
|
147
242
|
customdata = hover_data.values
|
|
148
243
|
hover_cols = hover_data.columns.tolist()
|
|
149
244
|
hovertemplate = "<b>%{customdata[0]}</b><br>"
|
|
150
|
-
for i,
|
|
151
|
-
col_lower =
|
|
245
|
+
for i, col_name in enumerate(hover_cols[1:], 1):
|
|
246
|
+
col_lower = col_name.lower()
|
|
152
247
|
if col_lower in ("p-value", "pval", "p_value"):
|
|
153
|
-
hovertemplate += f"{
|
|
248
|
+
hovertemplate += f"{col_name}: %{{customdata[{i}]:.2e}}<br>"
|
|
154
249
|
elif any(x in col_lower for x in ("r2", "r²", "ld")):
|
|
155
|
-
hovertemplate += f"{
|
|
250
|
+
hovertemplate += f"{col_name}: %{{customdata[{i}]:.3f}}<br>"
|
|
156
251
|
elif "pos" in col_lower:
|
|
157
|
-
hovertemplate += f"{
|
|
252
|
+
hovertemplate += f"{col_name}: %{{customdata[{i}]:,.0f}}<br>"
|
|
158
253
|
else:
|
|
159
|
-
hovertemplate += f"{
|
|
254
|
+
hovertemplate += f"{col_name}: %{{customdata[{i}]}}<br>"
|
|
160
255
|
hovertemplate += "<extra></extra>"
|
|
161
256
|
else:
|
|
162
257
|
customdata = None
|
|
@@ -184,7 +279,7 @@ class PlotlyBackend:
|
|
|
184
279
|
showlegend=label is not None,
|
|
185
280
|
)
|
|
186
281
|
|
|
187
|
-
fig.add_trace(trace, row=row, col=
|
|
282
|
+
fig.add_trace(trace, row=row, col=col)
|
|
188
283
|
return trace
|
|
189
284
|
|
|
190
285
|
def line(
|
|
@@ -200,8 +295,8 @@ class PlotlyBackend:
|
|
|
200
295
|
label: Optional[str] = None,
|
|
201
296
|
) -> Any:
|
|
202
297
|
"""Create a line plot on the given panel."""
|
|
203
|
-
fig, row = ax
|
|
204
|
-
dash =
|
|
298
|
+
fig, row, col, _ = self._extract_row_col(ax)
|
|
299
|
+
dash = _DASH_MAP.get(linestyle, "solid")
|
|
205
300
|
|
|
206
301
|
trace = go.Scatter(
|
|
207
302
|
x=x,
|
|
@@ -213,7 +308,7 @@ class PlotlyBackend:
|
|
|
213
308
|
showlegend=label is not None,
|
|
214
309
|
)
|
|
215
310
|
|
|
216
|
-
fig.add_trace(trace, row=row, col=
|
|
311
|
+
fig.add_trace(trace, row=row, col=col)
|
|
217
312
|
return trace
|
|
218
313
|
|
|
219
314
|
def fill_between(
|
|
@@ -227,7 +322,7 @@ class PlotlyBackend:
|
|
|
227
322
|
zorder: int = 0,
|
|
228
323
|
) -> Any:
|
|
229
324
|
"""Fill area between two y-values."""
|
|
230
|
-
fig, row = ax
|
|
325
|
+
fig, row, col, _ = self._extract_row_col(ax)
|
|
231
326
|
|
|
232
327
|
# Convert y1 to series if scalar
|
|
233
328
|
if isinstance(y1, (int, float)):
|
|
@@ -244,7 +339,7 @@ class PlotlyBackend:
|
|
|
244
339
|
hoverinfo="skip",
|
|
245
340
|
)
|
|
246
341
|
|
|
247
|
-
fig.add_trace(trace, row=row, col=
|
|
342
|
+
fig.add_trace(trace, row=row, col=col)
|
|
248
343
|
return trace
|
|
249
344
|
|
|
250
345
|
def axhline(
|
|
@@ -258,8 +353,8 @@ class PlotlyBackend:
|
|
|
258
353
|
zorder: int = 1,
|
|
259
354
|
) -> Any:
|
|
260
355
|
"""Add a horizontal line across the panel."""
|
|
261
|
-
fig, row = ax
|
|
262
|
-
dash =
|
|
356
|
+
fig, row, col, _ = self._extract_row_col(ax)
|
|
357
|
+
dash = _DASH_MAP.get(linestyle, "dash")
|
|
263
358
|
|
|
264
359
|
fig.add_hline(
|
|
265
360
|
y=y,
|
|
@@ -268,7 +363,7 @@ class PlotlyBackend:
|
|
|
268
363
|
line_width=linewidth,
|
|
269
364
|
opacity=alpha,
|
|
270
365
|
row=row,
|
|
271
|
-
col=
|
|
366
|
+
col=col,
|
|
272
367
|
)
|
|
273
368
|
|
|
274
369
|
def add_text(
|
|
@@ -284,7 +379,7 @@ class PlotlyBackend:
|
|
|
284
379
|
color: str = "black",
|
|
285
380
|
) -> Any:
|
|
286
381
|
"""Add text annotation to panel."""
|
|
287
|
-
fig, row = ax
|
|
382
|
+
fig, row, col, _ = self._extract_row_col(ax)
|
|
288
383
|
|
|
289
384
|
# Map alignment
|
|
290
385
|
xanchor_map = {"center": "center", "left": "left", "right": "right"}
|
|
@@ -300,7 +395,7 @@ class PlotlyBackend:
|
|
|
300
395
|
textangle=-rotation,
|
|
301
396
|
showarrow=False,
|
|
302
397
|
row=row,
|
|
303
|
-
col=
|
|
398
|
+
col=col,
|
|
304
399
|
)
|
|
305
400
|
|
|
306
401
|
def add_rectangle(
|
|
@@ -315,7 +410,7 @@ class PlotlyBackend:
|
|
|
315
410
|
zorder: int = 2,
|
|
316
411
|
) -> Any:
|
|
317
412
|
"""Add a rectangle to the panel."""
|
|
318
|
-
fig, row = ax
|
|
413
|
+
fig, row, col, _ = self._extract_row_col(ax)
|
|
319
414
|
|
|
320
415
|
x0, y0 = xy
|
|
321
416
|
x1, y1 = x0 + width, y0 + height
|
|
@@ -329,7 +424,7 @@ class PlotlyBackend:
|
|
|
329
424
|
fillcolor=facecolor,
|
|
330
425
|
line=dict(color=edgecolor, width=linewidth),
|
|
331
426
|
row=row,
|
|
332
|
-
col=
|
|
427
|
+
col=col,
|
|
333
428
|
)
|
|
334
429
|
|
|
335
430
|
def add_polygon(
|
|
@@ -342,7 +437,7 @@ class PlotlyBackend:
|
|
|
342
437
|
zorder: int = 2,
|
|
343
438
|
) -> Any:
|
|
344
439
|
"""Add a polygon (e.g., triangle for strand arrows) to the panel."""
|
|
345
|
-
fig, row = ax
|
|
440
|
+
fig, row, col, _ = self._extract_row_col(ax)
|
|
346
441
|
|
|
347
442
|
# Build SVG path from points
|
|
348
443
|
path = f"M {points[0][0]} {points[0][1]}"
|
|
@@ -356,28 +451,32 @@ class PlotlyBackend:
|
|
|
356
451
|
fillcolor=facecolor,
|
|
357
452
|
line=dict(color=edgecolor, width=linewidth),
|
|
358
453
|
row=row,
|
|
359
|
-
col=
|
|
454
|
+
col=col,
|
|
360
455
|
)
|
|
361
456
|
|
|
362
457
|
def set_xlim(self, ax: Tuple[go.Figure, int], left: float, right: float) -> None:
|
|
363
458
|
"""Set x-axis limits."""
|
|
364
|
-
fig, row = ax
|
|
365
|
-
fig.update_layout(
|
|
459
|
+
fig, row, col, n_cols = self._extract_row_col(ax)
|
|
460
|
+
fig.update_layout(
|
|
461
|
+
**{self._axis_name("xaxis", row, col, n_cols): dict(range=[left, right])}
|
|
462
|
+
)
|
|
366
463
|
|
|
367
464
|
def set_ylim(self, ax: Tuple[go.Figure, int], bottom: float, top: float) -> None:
|
|
368
465
|
"""Set y-axis limits."""
|
|
369
|
-
fig, row = ax
|
|
370
|
-
fig.update_layout(
|
|
466
|
+
fig, row, col, n_cols = self._extract_row_col(ax)
|
|
467
|
+
fig.update_layout(
|
|
468
|
+
**{self._axis_name("yaxis", row, col, n_cols): dict(range=[bottom, top])}
|
|
469
|
+
)
|
|
371
470
|
|
|
372
471
|
def set_xlabel(
|
|
373
472
|
self, ax: Tuple[go.Figure, int], label: str, fontsize: int = 12
|
|
374
473
|
) -> None:
|
|
375
474
|
"""Set x-axis label."""
|
|
376
|
-
fig, row = ax
|
|
475
|
+
fig, row, col, n_cols = self._extract_row_col(ax)
|
|
377
476
|
label = self._convert_label(label)
|
|
378
477
|
fig.update_layout(
|
|
379
478
|
**{
|
|
380
|
-
self._axis_name("xaxis", row): dict(
|
|
479
|
+
self._axis_name("xaxis", row, col, n_cols): dict(
|
|
381
480
|
title=dict(text=label, font=dict(size=fontsize))
|
|
382
481
|
)
|
|
383
482
|
}
|
|
@@ -387,11 +486,11 @@ class PlotlyBackend:
|
|
|
387
486
|
self, ax: Tuple[go.Figure, int], label: str, fontsize: int = 12
|
|
388
487
|
) -> None:
|
|
389
488
|
"""Set y-axis label."""
|
|
390
|
-
fig, row = ax
|
|
489
|
+
fig, row, col, n_cols = self._extract_row_col(ax)
|
|
391
490
|
label = self._convert_label(label)
|
|
392
491
|
fig.update_layout(
|
|
393
492
|
**{
|
|
394
|
-
self._axis_name("yaxis", row): dict(
|
|
493
|
+
self._axis_name("yaxis", row, col, n_cols): dict(
|
|
395
494
|
title=dict(text=label, font=dict(size=fontsize))
|
|
396
495
|
)
|
|
397
496
|
}
|
|
@@ -405,10 +504,10 @@ class PlotlyBackend:
|
|
|
405
504
|
fontsize: int = 10,
|
|
406
505
|
) -> None:
|
|
407
506
|
"""Set y-axis tick positions and labels."""
|
|
408
|
-
fig, row = ax
|
|
507
|
+
fig, row, col, n_cols = self._extract_row_col(ax)
|
|
409
508
|
fig.update_layout(
|
|
410
509
|
**{
|
|
411
|
-
self._axis_name("yaxis", row): dict(
|
|
510
|
+
self._axis_name("yaxis", row, col, n_cols): dict(
|
|
412
511
|
tickmode="array",
|
|
413
512
|
tickvals=positions,
|
|
414
513
|
ticktext=labels,
|
|
@@ -417,13 +516,48 @@ class PlotlyBackend:
|
|
|
417
516
|
}
|
|
418
517
|
)
|
|
419
518
|
|
|
420
|
-
def
|
|
421
|
-
|
|
519
|
+
def set_xticks(
|
|
520
|
+
self,
|
|
521
|
+
ax: Tuple[go.Figure, int],
|
|
522
|
+
positions: List[float],
|
|
523
|
+
labels: List[str],
|
|
524
|
+
fontsize: int = 10,
|
|
525
|
+
rotation: int = 0,
|
|
526
|
+
ha: str = "center",
|
|
527
|
+
) -> None:
|
|
528
|
+
"""Set x-axis tick positions and labels."""
|
|
529
|
+
fig, row, col, n_cols = self._extract_row_col(ax)
|
|
530
|
+
fig.update_layout(
|
|
531
|
+
**{
|
|
532
|
+
self._axis_name("xaxis", row, col, n_cols): dict(
|
|
533
|
+
tickmode="array",
|
|
534
|
+
tickvals=positions,
|
|
535
|
+
ticktext=labels,
|
|
536
|
+
tickfont=dict(size=fontsize),
|
|
537
|
+
tickangle=-rotation if rotation else 0,
|
|
538
|
+
)
|
|
539
|
+
}
|
|
540
|
+
)
|
|
541
|
+
|
|
542
|
+
def _axis_name(self, axis: str, row: int, col: int = 1, n_cols: int = 1) -> str:
|
|
543
|
+
"""Get Plotly axis name for a given row and column.
|
|
422
544
|
|
|
423
|
-
Plotly names axes
|
|
424
|
-
'
|
|
545
|
+
Plotly names axes using a linear subplot index:
|
|
546
|
+
- subplot (1,1) uses 'xaxis', 'yaxis'
|
|
547
|
+
- subplot (1,2) uses 'xaxis2', 'yaxis2'
|
|
548
|
+
- subplot (2,1) uses 'xaxis3', 'yaxis3' (for 2-column grid)
|
|
549
|
+
|
|
550
|
+
Args:
|
|
551
|
+
axis: Base axis name ('xaxis' or 'yaxis').
|
|
552
|
+
row: Row number (1-indexed).
|
|
553
|
+
col: Column number (1-indexed).
|
|
554
|
+
n_cols: Total number of columns in the grid.
|
|
555
|
+
|
|
556
|
+
Returns:
|
|
557
|
+
Plotly axis name string.
|
|
425
558
|
"""
|
|
426
|
-
|
|
559
|
+
subplot_idx = (row - 1) * n_cols + col
|
|
560
|
+
return f"{axis}{subplot_idx}" if subplot_idx > 1 else axis
|
|
427
561
|
|
|
428
562
|
def _get_legend_position(self, loc: str) -> dict:
|
|
429
563
|
"""Map matplotlib-style legend location to Plotly position dict."""
|
|
@@ -442,26 +576,81 @@ class PlotlyBackend:
|
|
|
442
576
|
def set_title(
|
|
443
577
|
self, ax: Tuple[go.Figure, int], title: str, fontsize: int = 14
|
|
444
578
|
) -> None:
|
|
445
|
-
"""Set
|
|
446
|
-
|
|
447
|
-
|
|
579
|
+
"""Set subplot title using annotation.
|
|
580
|
+
|
|
581
|
+
For grid layouts, this adds an annotation above the subplot.
|
|
582
|
+
For single-column layouts, sets the global figure title for the first panel.
|
|
583
|
+
"""
|
|
584
|
+
fig, row, col, n_cols = self._extract_row_col(ax)
|
|
585
|
+
|
|
586
|
+
if n_cols == 1 and row == 1:
|
|
587
|
+
# Single-column layout: use global figure title
|
|
448
588
|
fig.update_layout(title=dict(text=title, font=dict(size=fontsize)))
|
|
589
|
+
else:
|
|
590
|
+
# Grid layout: add annotation above the subplot
|
|
591
|
+
# Use subplot's axis domain for positioning
|
|
592
|
+
# Plotly uses "x" or "x2", "x3", etc. (not "xaxis")
|
|
593
|
+
subplot_idx = (row - 1) * n_cols + col
|
|
594
|
+
xref = f"x{subplot_idx} domain" if subplot_idx > 1 else "x domain"
|
|
595
|
+
yref = f"y{subplot_idx} domain" if subplot_idx > 1 else "y domain"
|
|
596
|
+
|
|
597
|
+
fig.add_annotation(
|
|
598
|
+
text=f"<b>{title}</b>",
|
|
599
|
+
xref=xref,
|
|
600
|
+
yref=yref,
|
|
601
|
+
x=0.5,
|
|
602
|
+
y=1.05,
|
|
603
|
+
showarrow=False,
|
|
604
|
+
font=dict(size=fontsize),
|
|
605
|
+
xanchor="center",
|
|
606
|
+
yanchor="bottom",
|
|
607
|
+
)
|
|
608
|
+
|
|
609
|
+
def set_suptitle(self, fig: go.Figure, title: str, fontsize: int = 14) -> None:
|
|
610
|
+
"""Set overall figure title (super title)."""
|
|
611
|
+
fig.update_layout(
|
|
612
|
+
title=dict(
|
|
613
|
+
text=title,
|
|
614
|
+
font=dict(size=fontsize),
|
|
615
|
+
x=0.5,
|
|
616
|
+
xanchor="center",
|
|
617
|
+
)
|
|
618
|
+
)
|
|
449
619
|
|
|
450
620
|
def create_twin_axis(self, ax: Tuple[go.Figure, int]) -> Tuple[go.Figure, int, str]:
|
|
451
621
|
"""Create a secondary y-axis.
|
|
452
622
|
|
|
453
623
|
Returns tuple of (figure, row, secondary_yaxis_name).
|
|
624
|
+
|
|
625
|
+
For Plotly subplots, we need unique axis names that don't conflict
|
|
626
|
+
with the subplot axes. We use a high number suffix to avoid conflicts.
|
|
454
627
|
"""
|
|
455
|
-
fig, row = ax
|
|
456
|
-
|
|
628
|
+
fig, row, col, n_cols = self._extract_row_col(ax)
|
|
629
|
+
|
|
630
|
+
# Calculate subplot index for proper axis naming
|
|
631
|
+
subplot_idx = (row - 1) * n_cols + col
|
|
632
|
+
|
|
633
|
+
# Use a unique suffix that won't conflict with subplot axis numbering
|
|
634
|
+
# yaxis10, yaxis11, etc. are unlikely to conflict with typical subplot counts
|
|
635
|
+
secondary_suffix = 10 + subplot_idx - 1
|
|
636
|
+
secondary_y = f"y{secondary_suffix}"
|
|
637
|
+
yaxis_name = f"yaxis{secondary_suffix}"
|
|
457
638
|
|
|
458
|
-
#
|
|
459
|
-
|
|
639
|
+
# Get the primary y-axis name for this subplot
|
|
640
|
+
primary_y = f"y{subplot_idx}" if subplot_idx > 1 else "y"
|
|
641
|
+
# Get the x-axis name for this subplot
|
|
642
|
+
xaxis_ref = f"x{subplot_idx}" if subplot_idx > 1 else "x"
|
|
643
|
+
|
|
644
|
+
# Configure secondary y-axis to overlay the primary axis of this row
|
|
460
645
|
fig.update_layout(
|
|
461
646
|
**{
|
|
462
647
|
yaxis_name: dict(
|
|
463
|
-
overlaying=
|
|
648
|
+
overlaying=primary_y,
|
|
464
649
|
side="right",
|
|
650
|
+
anchor=xaxis_ref,
|
|
651
|
+
showgrid=False,
|
|
652
|
+
showline=False,
|
|
653
|
+
zeroline=False,
|
|
465
654
|
)
|
|
466
655
|
}
|
|
467
656
|
)
|
|
@@ -481,8 +670,13 @@ class PlotlyBackend:
|
|
|
481
670
|
yaxis_name: str = "y2",
|
|
482
671
|
) -> Any:
|
|
483
672
|
"""Create a line plot on secondary y-axis."""
|
|
484
|
-
fig, row = ax
|
|
485
|
-
dash =
|
|
673
|
+
fig, row, col, n_cols = self._extract_row_col(ax)
|
|
674
|
+
dash = _DASH_MAP.get(linestyle, "solid")
|
|
675
|
+
|
|
676
|
+
# For secondary axes, we need to set both xaxis and yaxis explicitly
|
|
677
|
+
# and NOT use row/col which would override these references
|
|
678
|
+
subplot_idx = (row - 1) * n_cols + col
|
|
679
|
+
xaxis_ref = f"x{subplot_idx}" if subplot_idx > 1 else "x"
|
|
486
680
|
|
|
487
681
|
trace = go.Scatter(
|
|
488
682
|
x=x,
|
|
@@ -492,11 +686,13 @@ class PlotlyBackend:
|
|
|
492
686
|
opacity=alpha,
|
|
493
687
|
name=label or "",
|
|
494
688
|
showlegend=label is not None,
|
|
689
|
+
xaxis=xaxis_ref,
|
|
495
690
|
yaxis=yaxis_name,
|
|
496
691
|
hoverinfo="skip",
|
|
497
692
|
)
|
|
498
693
|
|
|
499
|
-
|
|
694
|
+
# Add trace directly without row/col to preserve axis references
|
|
695
|
+
fig.add_trace(trace)
|
|
500
696
|
return trace
|
|
501
697
|
|
|
502
698
|
def fill_between_secondary(
|
|
@@ -510,11 +706,16 @@ class PlotlyBackend:
|
|
|
510
706
|
yaxis_name: str = "y2",
|
|
511
707
|
) -> Any:
|
|
512
708
|
"""Fill area between two y-values on secondary y-axis."""
|
|
513
|
-
fig, row = ax
|
|
709
|
+
fig, row, col, n_cols = self._extract_row_col(ax)
|
|
514
710
|
|
|
515
711
|
if isinstance(y1, (int, float)):
|
|
516
712
|
y1 = pd.Series([y1] * len(x))
|
|
517
713
|
|
|
714
|
+
# For secondary axes, we need to set both xaxis and yaxis explicitly
|
|
715
|
+
# and NOT use row/col which would override these references
|
|
716
|
+
subplot_idx = (row - 1) * n_cols + col
|
|
717
|
+
xaxis_ref = f"x{subplot_idx}" if subplot_idx > 1 else "x"
|
|
718
|
+
|
|
518
719
|
trace = go.Scatter(
|
|
519
720
|
x=pd.concat([x, x[::-1]]),
|
|
520
721
|
y=pd.concat([y2, y1[::-1]]),
|
|
@@ -524,10 +725,12 @@ class PlotlyBackend:
|
|
|
524
725
|
line=dict(width=0),
|
|
525
726
|
showlegend=False,
|
|
526
727
|
hoverinfo="skip",
|
|
728
|
+
xaxis=xaxis_ref,
|
|
527
729
|
yaxis=yaxis_name,
|
|
528
730
|
)
|
|
529
731
|
|
|
530
|
-
|
|
732
|
+
# Add trace directly without row/col to preserve axis references
|
|
733
|
+
fig.add_trace(trace)
|
|
531
734
|
return trace
|
|
532
735
|
|
|
533
736
|
def set_secondary_ylim(
|
|
@@ -538,7 +741,7 @@ class PlotlyBackend:
|
|
|
538
741
|
yaxis_name: str = "y2",
|
|
539
742
|
) -> None:
|
|
540
743
|
"""Set secondary y-axis limits."""
|
|
541
|
-
fig, row = ax
|
|
744
|
+
fig, row, col, _ = self._extract_row_col(ax)
|
|
542
745
|
yaxis_key = (
|
|
543
746
|
"yaxis" + yaxis_name[1:] if yaxis_name.startswith("y") else yaxis_name
|
|
544
747
|
)
|
|
@@ -553,7 +756,7 @@ class PlotlyBackend:
|
|
|
553
756
|
yaxis_name: str = "y2",
|
|
554
757
|
) -> None:
|
|
555
758
|
"""Set secondary y-axis label."""
|
|
556
|
-
fig, row = ax
|
|
759
|
+
fig, row, col, _ = self._extract_row_col(ax)
|
|
557
760
|
label = self._convert_label(label)
|
|
558
761
|
yaxis_key = (
|
|
559
762
|
"yaxis" + yaxis_name[1:] if yaxis_name.startswith("y") else yaxis_name
|
|
@@ -650,7 +853,7 @@ class PlotlyBackend:
|
|
|
650
853
|
y_frac: float = 0.95,
|
|
651
854
|
) -> None:
|
|
652
855
|
"""Add label text at fractional position in panel."""
|
|
653
|
-
fig, row = ax
|
|
856
|
+
fig, row, col, _ = self._extract_row_col(ax)
|
|
654
857
|
fig.add_annotation(
|
|
655
858
|
text=f"<b>{label}</b>",
|
|
656
859
|
xref="x domain",
|
|
@@ -660,7 +863,7 @@ class PlotlyBackend:
|
|
|
660
863
|
showarrow=False,
|
|
661
864
|
font=dict(size=12),
|
|
662
865
|
row=row,
|
|
663
|
-
col=
|
|
866
|
+
col=col,
|
|
664
867
|
)
|
|
665
868
|
|
|
666
869
|
def add_ld_legend(
|
|
@@ -674,7 +877,7 @@ class PlotlyBackend:
|
|
|
674
877
|
Uses Plotly's separate legend feature (legend="legend") so LD legend
|
|
675
878
|
can be positioned independently from eQTL and fine-mapping legends.
|
|
676
879
|
"""
|
|
677
|
-
fig, row = ax
|
|
880
|
+
fig, row, col, _ = self._extract_row_col(ax)
|
|
678
881
|
|
|
679
882
|
self._add_legend_item(
|
|
680
883
|
fig, row, "Lead SNP", lead_snp_color, "diamond", 12, "legend"
|
|
@@ -720,10 +923,10 @@ class PlotlyBackend:
|
|
|
720
923
|
|
|
721
924
|
def hide_yaxis(self, ax: Tuple[go.Figure, int]) -> None:
|
|
722
925
|
"""Hide y-axis ticks, labels, line, and grid for gene track panels."""
|
|
723
|
-
fig, row = ax
|
|
926
|
+
fig, row, col, n_cols = self._extract_row_col(ax)
|
|
724
927
|
fig.update_layout(
|
|
725
928
|
**{
|
|
726
|
-
self._axis_name("yaxis", row): dict(
|
|
929
|
+
self._axis_name("yaxis", row, col, n_cols): dict(
|
|
727
930
|
showticklabels=False,
|
|
728
931
|
showline=False,
|
|
729
932
|
showgrid=False,
|
|
@@ -735,13 +938,14 @@ class PlotlyBackend:
|
|
|
735
938
|
def format_xaxis_mb(self, ax: Tuple[go.Figure, int]) -> None:
|
|
736
939
|
"""Format x-axis to show megabase values.
|
|
737
940
|
|
|
738
|
-
Stores the
|
|
941
|
+
Stores the subplot info for later tick formatting in finalize_layout.
|
|
739
942
|
"""
|
|
740
|
-
fig, row = ax
|
|
943
|
+
fig, row, col, n_cols = self._extract_row_col(ax)
|
|
741
944
|
# Store that this axis needs Mb formatting
|
|
742
945
|
if not hasattr(fig, "_mb_format_rows"):
|
|
743
946
|
fig._mb_format_rows = []
|
|
744
|
-
|
|
947
|
+
# Store (row, col, n_cols) tuple for proper axis naming later
|
|
948
|
+
fig._mb_format_rows.append((row, col, n_cols))
|
|
745
949
|
|
|
746
950
|
def save(
|
|
747
951
|
self,
|
|
@@ -780,7 +984,7 @@ class PlotlyBackend:
|
|
|
780
984
|
Uses Plotly's separate legend feature (legend="legend2") so eQTL legend
|
|
781
985
|
is positioned independently below the LD legend.
|
|
782
986
|
"""
|
|
783
|
-
fig, row = ax
|
|
987
|
+
fig, row, col, _ = self._extract_row_col(ax)
|
|
784
988
|
|
|
785
989
|
for _, _, label, color in eqtl_positive_bins:
|
|
786
990
|
self._add_legend_item(fig, row, label, color, "triangle-up", 10, "legend2")
|
|
@@ -805,7 +1009,7 @@ class PlotlyBackend:
|
|
|
805
1009
|
if not credible_sets:
|
|
806
1010
|
return
|
|
807
1011
|
|
|
808
|
-
fig, row = ax
|
|
1012
|
+
fig, row, col, _ = self._extract_row_col(ax)
|
|
809
1013
|
|
|
810
1014
|
for cs_id in credible_sets:
|
|
811
1015
|
self._add_legend_item(
|
|
@@ -847,8 +1051,8 @@ class PlotlyBackend:
|
|
|
847
1051
|
zorder: int = 1,
|
|
848
1052
|
) -> Any:
|
|
849
1053
|
"""Add a vertical line across the panel."""
|
|
850
|
-
fig, row = ax
|
|
851
|
-
dash =
|
|
1054
|
+
fig, row, col, _ = self._extract_row_col(ax)
|
|
1055
|
+
dash = _DASH_MAP.get(linestyle, "dash")
|
|
852
1056
|
|
|
853
1057
|
fig.add_vline(
|
|
854
1058
|
x=x,
|
|
@@ -857,7 +1061,7 @@ class PlotlyBackend:
|
|
|
857
1061
|
line_width=linewidth,
|
|
858
1062
|
opacity=alpha,
|
|
859
1063
|
row=row,
|
|
860
|
-
col=
|
|
1064
|
+
col=col,
|
|
861
1065
|
)
|
|
862
1066
|
|
|
863
1067
|
def hbar(
|
|
@@ -873,7 +1077,7 @@ class PlotlyBackend:
|
|
|
873
1077
|
zorder: int = 2,
|
|
874
1078
|
) -> Any:
|
|
875
1079
|
"""Create horizontal bar chart."""
|
|
876
|
-
fig, row = ax
|
|
1080
|
+
fig, row, col, _ = self._extract_row_col(ax)
|
|
877
1081
|
|
|
878
1082
|
# Convert left to array if scalar
|
|
879
1083
|
if isinstance(left, (int, float)):
|
|
@@ -893,7 +1097,7 @@ class PlotlyBackend:
|
|
|
893
1097
|
showlegend=False,
|
|
894
1098
|
)
|
|
895
1099
|
|
|
896
|
-
fig.add_trace(trace, row=row, col=
|
|
1100
|
+
fig.add_trace(trace, row=row, col=col)
|
|
897
1101
|
return trace
|
|
898
1102
|
|
|
899
1103
|
def errorbar_h(
|
|
@@ -909,7 +1113,7 @@ class PlotlyBackend:
|
|
|
909
1113
|
zorder: int = 3,
|
|
910
1114
|
) -> Any:
|
|
911
1115
|
"""Add horizontal error bars."""
|
|
912
|
-
fig, row = ax
|
|
1116
|
+
fig, row, col, _ = self._extract_row_col(ax)
|
|
913
1117
|
|
|
914
1118
|
trace = go.Scatter(
|
|
915
1119
|
x=x,
|
|
@@ -928,7 +1132,7 @@ class PlotlyBackend:
|
|
|
928
1132
|
showlegend=False,
|
|
929
1133
|
)
|
|
930
1134
|
|
|
931
|
-
fig.add_trace(trace, row=row, col=
|
|
1135
|
+
fig.add_trace(trace, row=row, col=col)
|
|
932
1136
|
return trace
|
|
933
1137
|
|
|
934
1138
|
def finalize_layout(
|
|
@@ -960,8 +1164,13 @@ class PlotlyBackend:
|
|
|
960
1164
|
if hasattr(fig, "_mb_format_rows"):
|
|
961
1165
|
import numpy as np
|
|
962
1166
|
|
|
963
|
-
for
|
|
964
|
-
|
|
1167
|
+
for item in fig._mb_format_rows:
|
|
1168
|
+
# Handle both old format (row) and new format (row, col, n_cols)
|
|
1169
|
+
if isinstance(item, tuple):
|
|
1170
|
+
row, col, n_cols = item
|
|
1171
|
+
else:
|
|
1172
|
+
row, col, n_cols = item, 1, 1
|
|
1173
|
+
xaxis_name = self._axis_name("xaxis", row, col, n_cols)
|
|
965
1174
|
xaxis = getattr(fig.layout, xaxis_name, None)
|
|
966
1175
|
|
|
967
1176
|
# Get x-range from the axis or compute from data
|