pylocuszoom 0.8.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -11,6 +11,21 @@ from plotly.subplots import make_subplots
11
11
 
12
12
  from . import convert_latex_to_unicode, register_backend
13
13
 
14
+ # Style mappings (matplotlib -> Plotly)
15
+ _MARKER_SYMBOLS = {
16
+ "o": "circle",
17
+ "D": "diamond",
18
+ "s": "square",
19
+ "^": "triangle-up",
20
+ "v": "triangle-down",
21
+ }
22
+ _DASH_MAP = {
23
+ "-": "solid",
24
+ "--": "dash",
25
+ ":": "dot",
26
+ "-.": "dashdot",
27
+ }
28
+
14
29
 
15
30
  @register_backend("plotly")
16
31
  class PlotlyBackend:
@@ -23,21 +38,6 @@ class PlotlyBackend:
23
38
  - Nearest gene
24
39
  """
25
40
 
26
- # Class constants for style mappings
27
- _MARKER_SYMBOLS = {
28
- "o": "circle",
29
- "D": "diamond",
30
- "s": "square",
31
- "^": "triangle-up",
32
- "v": "triangle-down",
33
- }
34
- _DASH_MAP = {
35
- "-": "solid",
36
- "--": "dash",
37
- ":": "dot",
38
- "-.": "dashdot",
39
- }
40
-
41
41
  @property
42
42
  def supports_snp_labels(self) -> bool:
43
43
  """Plotly uses hover tooltips instead of labels."""
@@ -113,6 +113,101 @@ class PlotlyBackend:
113
113
  panel_refs = [(fig, row) for row in range(1, n_panels + 1)]
114
114
  return fig, panel_refs
115
115
 
116
def create_figure_grid(
    self,
    n_rows: int,
    n_cols: int,
    width_ratios: Optional[List[float]] = None,
    height_ratios: Optional[List[float]] = None,
    figsize: Tuple[float, float] = (12.0, 8.0),
) -> Tuple[go.Figure, List[Any]]:
    """Create a figure with a grid of subplots.

    Args:
        n_rows: Number of rows.
        n_cols: Number of columns.
        width_ratios: Relative widths for columns, one entry per column.
            Normalized to fractions summing to 1 before being passed to Plotly.
        height_ratios: Relative heights for rows, one entry per row.
            Normalized to fractions summing to 1 before being passed to Plotly.
        figsize: Figure size as (width, height); multiplied by 100 to get
            the Plotly layout size in pixels.

    Returns:
        Tuple of (figure, flattened row-major list of (fig, row, col) tuples).

    Raises:
        ValueError: If a ratio list has the wrong number of entries or does
            not sum to a positive number.
    """

    def _normalize(ratios, expected, what):
        # Validate up front so a bad list fails with a clear message rather
        # than a ZeroDivisionError or an error from deep inside make_subplots.
        if ratios is None:
            return None
        if len(ratios) != expected:
            raise ValueError(
                f"{what} must have {expected} entries, got {len(ratios)}"
            )
        total = sum(ratios)
        if total <= 0:
            raise ValueError(f"{what} must sum to a positive number, got {total}")
        return [r / total for r in ratios]

    fig = make_subplots(
        rows=n_rows,
        cols=n_cols,
        column_widths=_normalize(width_ratios, n_cols, "width_ratios"),
        row_heights=_normalize(height_ratios, n_rows, "height_ratios"),
        horizontal_spacing=0.08,
        vertical_spacing=0.08,
    )

    fig.update_layout(
        width=int(figsize[0] * 100),
        height=int(figsize[1] * 100),
        showlegend=True,
        template="plotly_white",
    )

    # _n_cols is read back by _extract_row_col to map (row, col) to Plotly's
    # linear axis numbering; _n_rows is stored alongside for completeness.
    fig._n_cols = n_cols
    fig._n_rows = n_rows

    # Style every subplot axis at once; at this point the figure contains
    # only the grid's own axes.
    axis_style = dict(
        showgrid=False,
        showline=True,
        linecolor="black",
        ticks="outside",
        zeroline=False,
    )
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)

    # Row-major (fig, row, col) references, 1-indexed like make_subplots.
    panel_refs = [
        (fig, row, col)
        for row in range(1, n_rows + 1)
        for col in range(1, n_cols + 1)
    ]
    return fig, panel_refs
193
+
194
def _extract_row_col(self, ax: Any) -> Tuple[go.Figure, int, int, int]:
    """Extract figure, row, col, and n_cols from a panel reference.

    Handles both (fig, row) for create_figure and (fig, row, col) for
    create_figure_grid. A two-element reference always means column 1.

    Args:
        ax: Panel reference tuple of length 2 or 3.

    Returns:
        Tuple of (figure, row, col, n_cols). ``n_cols`` is read from the
        figure's ``_n_cols`` attribute and defaults to 1 when the figure does
        not carry one (single-column layouts).

    Raises:
        ValueError: If ``ax`` does not have 2 or 3 elements.
    """
    if len(ax) == 2:
        fig, row = ax
        col = 1
    elif len(ax) == 3:
        fig, row, col = ax
    else:
        # Explicit message instead of a generic "too many values to unpack".
        raise ValueError(
            "Expected panel reference (fig, row) or (fig, row, col), "
            f"got {len(ax)} elements"
        )
    n_cols = getattr(fig, "_n_cols", 1)
    return fig, row, col, n_cols
210
+
116
211
  def scatter(
117
212
  self,
118
213
  ax: Tuple[go.Figure, int],
@@ -129,12 +224,12 @@ class PlotlyBackend:
129
224
  ) -> Any:
130
225
  """Create a scatter plot on the given panel.
131
226
 
132
- For plotly, ax is a tuple of (figure, row_number).
227
+ For plotly, ax is a tuple of (figure, row_number) or (figure, row, col).
133
228
  """
134
- fig, row = ax
229
+ fig, row, col, _ = self._extract_row_col(ax)
135
230
 
136
231
  # Convert matplotlib marker to plotly symbol
137
- symbol = self._MARKER_SYMBOLS.get(marker, "circle")
232
+ symbol = _MARKER_SYMBOLS.get(marker, "circle")
138
233
 
139
234
  # Convert size (matplotlib uses area, plotly uses diameter)
140
235
  if isinstance(sizes, (int, float)):
@@ -147,16 +242,16 @@ class PlotlyBackend:
147
242
  customdata = hover_data.values
148
243
  hover_cols = hover_data.columns.tolist()
149
244
  hovertemplate = "<b>%{customdata[0]}</b><br>"
150
- for i, col in enumerate(hover_cols[1:], 1):
151
- col_lower = col.lower()
245
+ for i, col_name in enumerate(hover_cols[1:], 1):
246
+ col_lower = col_name.lower()
152
247
  if col_lower in ("p-value", "pval", "p_value"):
153
- hovertemplate += f"{col}: %{{customdata[{i}]:.2e}}<br>"
248
+ hovertemplate += f"{col_name}: %{{customdata[{i}]:.2e}}<br>"
154
249
  elif any(x in col_lower for x in ("r2", "r²", "ld")):
155
- hovertemplate += f"{col}: %{{customdata[{i}]:.3f}}<br>"
250
+ hovertemplate += f"{col_name}: %{{customdata[{i}]:.3f}}<br>"
156
251
  elif "pos" in col_lower:
157
- hovertemplate += f"{col}: %{{customdata[{i}]:,.0f}}<br>"
252
+ hovertemplate += f"{col_name}: %{{customdata[{i}]:,.0f}}<br>"
158
253
  else:
159
- hovertemplate += f"{col}: %{{customdata[{i}]}}<br>"
254
+ hovertemplate += f"{col_name}: %{{customdata[{i}]}}<br>"
160
255
  hovertemplate += "<extra></extra>"
161
256
  else:
162
257
  customdata = None
@@ -184,7 +279,7 @@ class PlotlyBackend:
184
279
  showlegend=label is not None,
185
280
  )
186
281
 
187
- fig.add_trace(trace, row=row, col=1)
282
+ fig.add_trace(trace, row=row, col=col)
188
283
  return trace
189
284
 
190
285
  def line(
@@ -200,8 +295,8 @@ class PlotlyBackend:
200
295
  label: Optional[str] = None,
201
296
  ) -> Any:
202
297
  """Create a line plot on the given panel."""
203
- fig, row = ax
204
- dash = self._DASH_MAP.get(linestyle, "solid")
298
+ fig, row, col, _ = self._extract_row_col(ax)
299
+ dash = _DASH_MAP.get(linestyle, "solid")
205
300
 
206
301
  trace = go.Scatter(
207
302
  x=x,
@@ -213,7 +308,7 @@ class PlotlyBackend:
213
308
  showlegend=label is not None,
214
309
  )
215
310
 
216
- fig.add_trace(trace, row=row, col=1)
311
+ fig.add_trace(trace, row=row, col=col)
217
312
  return trace
218
313
 
219
314
  def fill_between(
@@ -227,7 +322,7 @@ class PlotlyBackend:
227
322
  zorder: int = 0,
228
323
  ) -> Any:
229
324
  """Fill area between two y-values."""
230
- fig, row = ax
325
+ fig, row, col, _ = self._extract_row_col(ax)
231
326
 
232
327
  # Convert y1 to series if scalar
233
328
  if isinstance(y1, (int, float)):
@@ -244,7 +339,7 @@ class PlotlyBackend:
244
339
  hoverinfo="skip",
245
340
  )
246
341
 
247
- fig.add_trace(trace, row=row, col=1)
342
+ fig.add_trace(trace, row=row, col=col)
248
343
  return trace
249
344
 
250
345
  def axhline(
@@ -258,8 +353,8 @@ class PlotlyBackend:
258
353
  zorder: int = 1,
259
354
  ) -> Any:
260
355
  """Add a horizontal line across the panel."""
261
- fig, row = ax
262
- dash = self._DASH_MAP.get(linestyle, "dash")
356
+ fig, row, col, _ = self._extract_row_col(ax)
357
+ dash = _DASH_MAP.get(linestyle, "dash")
263
358
 
264
359
  fig.add_hline(
265
360
  y=y,
@@ -268,7 +363,7 @@ class PlotlyBackend:
268
363
  line_width=linewidth,
269
364
  opacity=alpha,
270
365
  row=row,
271
- col=1,
366
+ col=col,
272
367
  )
273
368
 
274
369
  def add_text(
@@ -284,7 +379,7 @@ class PlotlyBackend:
284
379
  color: str = "black",
285
380
  ) -> Any:
286
381
  """Add text annotation to panel."""
287
- fig, row = ax
382
+ fig, row, col, _ = self._extract_row_col(ax)
288
383
 
289
384
  # Map alignment
290
385
  xanchor_map = {"center": "center", "left": "left", "right": "right"}
@@ -300,7 +395,7 @@ class PlotlyBackend:
300
395
  textangle=-rotation,
301
396
  showarrow=False,
302
397
  row=row,
303
- col=1,
398
+ col=col,
304
399
  )
305
400
 
306
401
  def add_rectangle(
@@ -315,7 +410,7 @@ class PlotlyBackend:
315
410
  zorder: int = 2,
316
411
  ) -> Any:
317
412
  """Add a rectangle to the panel."""
318
- fig, row = ax
413
+ fig, row, col, _ = self._extract_row_col(ax)
319
414
 
320
415
  x0, y0 = xy
321
416
  x1, y1 = x0 + width, y0 + height
@@ -329,7 +424,7 @@ class PlotlyBackend:
329
424
  fillcolor=facecolor,
330
425
  line=dict(color=edgecolor, width=linewidth),
331
426
  row=row,
332
- col=1,
427
+ col=col,
333
428
  )
334
429
 
335
430
  def add_polygon(
@@ -342,7 +437,7 @@ class PlotlyBackend:
342
437
  zorder: int = 2,
343
438
  ) -> Any:
344
439
  """Add a polygon (e.g., triangle for strand arrows) to the panel."""
345
- fig, row = ax
440
+ fig, row, col, _ = self._extract_row_col(ax)
346
441
 
347
442
  # Build SVG path from points
348
443
  path = f"M {points[0][0]} {points[0][1]}"
@@ -356,28 +451,32 @@ class PlotlyBackend:
356
451
  fillcolor=facecolor,
357
452
  line=dict(color=edgecolor, width=linewidth),
358
453
  row=row,
359
- col=1,
454
+ col=col,
360
455
  )
361
456
 
362
457
def set_xlim(self, ax: Tuple[go.Figure, int], left: float, right: float) -> None:
    """Set x-axis limits.

    ``ax`` may be ``(fig, row)`` or ``(fig, row, col)``; the x-axis of that
    subplot gets its ``range`` set to ``[left, right]``.
    """
    fig, row, col, n_cols = self._extract_row_col(ax)
    axis_key = self._axis_name("xaxis", row, col, n_cols)
    limits = {"range": [left, right]}
    fig.update_layout(**{axis_key: limits})
366
463
 
367
464
def set_ylim(self, ax: Tuple[go.Figure, int], bottom: float, top: float) -> None:
    """Set y-axis limits.

    ``ax`` may be ``(fig, row)`` or ``(fig, row, col)``; the y-axis of that
    subplot gets its ``range`` set to ``[bottom, top]``.
    """
    fig, row, col, n_cols = self._extract_row_col(ax)
    axis_key = self._axis_name("yaxis", row, col, n_cols)
    limits = {"range": [bottom, top]}
    fig.update_layout(**{axis_key: limits})
371
470
 
372
471
  def set_xlabel(
373
472
  self, ax: Tuple[go.Figure, int], label: str, fontsize: int = 12
374
473
  ) -> None:
375
474
  """Set x-axis label."""
376
- fig, row = ax
475
+ fig, row, col, n_cols = self._extract_row_col(ax)
377
476
  label = self._convert_label(label)
378
477
  fig.update_layout(
379
478
  **{
380
- self._axis_name("xaxis", row): dict(
479
+ self._axis_name("xaxis", row, col, n_cols): dict(
381
480
  title=dict(text=label, font=dict(size=fontsize))
382
481
  )
383
482
  }
@@ -387,11 +486,11 @@ class PlotlyBackend:
387
486
  self, ax: Tuple[go.Figure, int], label: str, fontsize: int = 12
388
487
  ) -> None:
389
488
  """Set y-axis label."""
390
- fig, row = ax
489
+ fig, row, col, n_cols = self._extract_row_col(ax)
391
490
  label = self._convert_label(label)
392
491
  fig.update_layout(
393
492
  **{
394
- self._axis_name("yaxis", row): dict(
493
+ self._axis_name("yaxis", row, col, n_cols): dict(
395
494
  title=dict(text=label, font=dict(size=fontsize))
396
495
  )
397
496
  }
@@ -405,10 +504,10 @@ class PlotlyBackend:
405
504
  fontsize: int = 10,
406
505
  ) -> None:
407
506
  """Set y-axis tick positions and labels."""
408
- fig, row = ax
507
+ fig, row, col, n_cols = self._extract_row_col(ax)
409
508
  fig.update_layout(
410
509
  **{
411
- self._axis_name("yaxis", row): dict(
510
+ self._axis_name("yaxis", row, col, n_cols): dict(
412
511
  tickmode="array",
413
512
  tickvals=positions,
414
513
  ticktext=labels,
@@ -417,13 +516,48 @@ class PlotlyBackend:
417
516
  }
418
517
  )
419
518
 
420
- def _axis_name(self, axis: str, row: int) -> str:
421
- """Get Plotly axis name for a given row.
519
def set_xticks(
    self,
    ax: Tuple[go.Figure, int],
    positions: List[float],
    labels: List[str],
    fontsize: int = 10,
    rotation: int = 0,
    ha: str = "center",
) -> None:
    """Set x-axis tick positions and labels.

    Args:
        ax: Panel reference, ``(fig, row)`` or ``(fig, row, col)``.
        positions: Tick locations in data coordinates.
        labels: Tick label text, one per position.
        fontsize: Tick label font size.
        rotation: Label rotation in degrees; negated before being handed to
            Plotly's ``tickangle``. A falsy value yields an angle of 0.
        ha: Accepted for signature compatibility with the matplotlib
            backend; not used by this implementation.
    """
    fig, row, col, n_cols = self._extract_row_col(ax)
    angle = -rotation if rotation else 0
    tick_settings = {
        "tickmode": "array",
        "tickvals": positions,
        "ticktext": labels,
        "tickfont": {"size": fontsize},
        "tickangle": angle,
    }
    axis_key = self._axis_name("xaxis", row, col, n_cols)
    fig.update_layout(**{axis_key: tick_settings})
541
+
542
+ def _axis_name(self, axis: str, row: int, col: int = 1, n_cols: int = 1) -> str:
543
+ """Get Plotly axis name for a given row and column.
422
544
 
423
- Plotly names axes as 'xaxis', 'yaxis' for row 1, and
424
- 'xaxis2', 'yaxis2', etc. for subsequent rows.
545
+ Plotly names axes using a linear subplot index:
546
+ - subplot (1,1) uses 'xaxis', 'yaxis'
547
+ - subplot (1,2) uses 'xaxis2', 'yaxis2'
548
+ - subplot (2,1) uses 'xaxis3', 'yaxis3' (for 2-column grid)
549
+
550
+ Args:
551
+ axis: Base axis name ('xaxis' or 'yaxis').
552
+ row: Row number (1-indexed).
553
+ col: Column number (1-indexed).
554
+ n_cols: Total number of columns in the grid.
555
+
556
+ Returns:
557
+ Plotly axis name string.
425
558
  """
426
- return f"{axis}{row}" if row > 1 else axis
559
+ subplot_idx = (row - 1) * n_cols + col
560
+ return f"{axis}{subplot_idx}" if subplot_idx > 1 else axis
427
561
 
428
562
  def _get_legend_position(self, loc: str) -> dict:
429
563
  """Map matplotlib-style legend location to Plotly position dict."""
@@ -442,26 +576,81 @@ class PlotlyBackend:
442
576
def set_title(
    self, ax: Tuple[go.Figure, int], title: str, fontsize: int = 14
) -> None:
    """Set a panel title.

    The first panel of a single-column layout sets the figure-level
    ``layout.title``. Every other panel (any panel of a multi-column grid,
    or a lower row of a single-column layout) gets a bold annotation
    centred just above its plotting area.
    """
    fig, row, col, n_cols = self._extract_row_col(ax)

    if n_cols == 1 and row == 1:
        fig.update_layout(title=dict(text=title, font=dict(size=fontsize)))
        return

    # Annotations reference axes by their short ids ("x", "x2", ...), not the
    # layout keys ("xaxis2"); the " domain" suffix makes x/y fractions of
    # the subplot area, so y slightly above 1 sits just over the panel.
    x_id = self._axis_name("x", row, col, n_cols)
    y_id = self._axis_name("y", row, col, n_cols)
    fig.add_annotation(
        text=f"<b>{title}</b>",
        xref=f"{x_id} domain",
        yref=f"{y_id} domain",
        x=0.5,
        y=1.05,
        showarrow=False,
        font=dict(size=fontsize),
        xanchor="center",
        yanchor="bottom",
    )
608
+
609
def set_suptitle(self, fig: go.Figure, title: str, fontsize: int = 14) -> None:
    """Set the overall figure title (super title), centred horizontally.

    Writes ``layout.title`` on ``fig`` with the given text and font size,
    anchored at its horizontal centre (x = 0.5).
    """
    title_spec = {
        "text": title,
        "font": {"size": fontsize},
        "x": 0.5,
        "xanchor": "center",
    }
    fig.update_layout(title=title_spec)
449
619
 
450
620
  def create_twin_axis(self, ax: Tuple[go.Figure, int]) -> Tuple[go.Figure, int, str]:
451
621
  """Create a secondary y-axis.
452
622
 
453
623
  Returns tuple of (figure, row, secondary_yaxis_name).
624
+
625
+ For Plotly subplots, we need unique axis names that don't conflict
626
+ with the subplot axes. We use a high number suffix to avoid conflicts.
454
627
  """
455
- fig, row = ax
456
- secondary_y = f"y{row}2" if row > 1 else "y2"
628
+ fig, row, col, n_cols = self._extract_row_col(ax)
629
+
630
+ # Calculate subplot index for proper axis naming
631
+ subplot_idx = (row - 1) * n_cols + col
632
+
633
+ # Use a unique suffix that won't conflict with subplot axis numbering
634
+ # yaxis10, yaxis11, etc. are unlikely to conflict with typical subplot counts
635
+ secondary_suffix = 10 + subplot_idx - 1
636
+ secondary_y = f"y{secondary_suffix}"
637
+ yaxis_name = f"yaxis{secondary_suffix}"
457
638
 
458
- # Configure secondary y-axis
459
- yaxis_name = f"yaxis{row}2" if row > 1 else "yaxis2"
639
+ # Get the primary y-axis name for this subplot
640
+ primary_y = f"y{subplot_idx}" if subplot_idx > 1 else "y"
641
+ # Get the x-axis name for this subplot
642
+ xaxis_ref = f"x{subplot_idx}" if subplot_idx > 1 else "x"
643
+
644
+ # Configure secondary y-axis to overlay the primary axis of this row
460
645
  fig.update_layout(
461
646
  **{
462
647
  yaxis_name: dict(
463
- overlaying=f"y{row}" if row > 1 else "y",
648
+ overlaying=primary_y,
464
649
  side="right",
650
+ anchor=xaxis_ref,
651
+ showgrid=False,
652
+ showline=False,
653
+ zeroline=False,
465
654
  )
466
655
  }
467
656
  )
@@ -481,8 +670,13 @@ class PlotlyBackend:
481
670
  yaxis_name: str = "y2",
482
671
  ) -> Any:
483
672
  """Create a line plot on secondary y-axis."""
484
- fig, row = ax
485
- dash = self._DASH_MAP.get(linestyle, "solid")
673
+ fig, row, col, n_cols = self._extract_row_col(ax)
674
+ dash = _DASH_MAP.get(linestyle, "solid")
675
+
676
+ # For secondary axes, we need to set both xaxis and yaxis explicitly
677
+ # and NOT use row/col which would override these references
678
+ subplot_idx = (row - 1) * n_cols + col
679
+ xaxis_ref = f"x{subplot_idx}" if subplot_idx > 1 else "x"
486
680
 
487
681
  trace = go.Scatter(
488
682
  x=x,
@@ -492,11 +686,13 @@ class PlotlyBackend:
492
686
  opacity=alpha,
493
687
  name=label or "",
494
688
  showlegend=label is not None,
689
+ xaxis=xaxis_ref,
495
690
  yaxis=yaxis_name,
496
691
  hoverinfo="skip",
497
692
  )
498
693
 
499
- fig.add_trace(trace, row=row, col=1)
694
+ # Add trace directly without row/col to preserve axis references
695
+ fig.add_trace(trace)
500
696
  return trace
501
697
 
502
698
  def fill_between_secondary(
@@ -510,11 +706,16 @@ class PlotlyBackend:
510
706
  yaxis_name: str = "y2",
511
707
  ) -> Any:
512
708
  """Fill area between two y-values on secondary y-axis."""
513
- fig, row = ax
709
+ fig, row, col, n_cols = self._extract_row_col(ax)
514
710
 
515
711
  if isinstance(y1, (int, float)):
516
712
  y1 = pd.Series([y1] * len(x))
517
713
 
714
+ # For secondary axes, we need to set both xaxis and yaxis explicitly
715
+ # and NOT use row/col which would override these references
716
+ subplot_idx = (row - 1) * n_cols + col
717
+ xaxis_ref = f"x{subplot_idx}" if subplot_idx > 1 else "x"
718
+
518
719
  trace = go.Scatter(
519
720
  x=pd.concat([x, x[::-1]]),
520
721
  y=pd.concat([y2, y1[::-1]]),
@@ -524,10 +725,12 @@ class PlotlyBackend:
524
725
  line=dict(width=0),
525
726
  showlegend=False,
526
727
  hoverinfo="skip",
728
+ xaxis=xaxis_ref,
527
729
  yaxis=yaxis_name,
528
730
  )
529
731
 
530
- fig.add_trace(trace, row=row, col=1)
732
+ # Add trace directly without row/col to preserve axis references
733
+ fig.add_trace(trace)
531
734
  return trace
532
735
 
533
736
  def set_secondary_ylim(
@@ -538,7 +741,7 @@ class PlotlyBackend:
538
741
  yaxis_name: str = "y2",
539
742
  ) -> None:
540
743
  """Set secondary y-axis limits."""
541
- fig, row = ax
744
+ fig, row, col, _ = self._extract_row_col(ax)
542
745
  yaxis_key = (
543
746
  "yaxis" + yaxis_name[1:] if yaxis_name.startswith("y") else yaxis_name
544
747
  )
@@ -553,7 +756,7 @@ class PlotlyBackend:
553
756
  yaxis_name: str = "y2",
554
757
  ) -> None:
555
758
  """Set secondary y-axis label."""
556
- fig, row = ax
759
+ fig, row, col, _ = self._extract_row_col(ax)
557
760
  label = self._convert_label(label)
558
761
  yaxis_key = (
559
762
  "yaxis" + yaxis_name[1:] if yaxis_name.startswith("y") else yaxis_name
@@ -650,7 +853,7 @@ class PlotlyBackend:
650
853
  y_frac: float = 0.95,
651
854
  ) -> None:
652
855
  """Add label text at fractional position in panel."""
653
- fig, row = ax
856
+ fig, row, col, _ = self._extract_row_col(ax)
654
857
  fig.add_annotation(
655
858
  text=f"<b>{label}</b>",
656
859
  xref="x domain",
@@ -660,7 +863,7 @@ class PlotlyBackend:
660
863
  showarrow=False,
661
864
  font=dict(size=12),
662
865
  row=row,
663
- col=1,
866
+ col=col,
664
867
  )
665
868
 
666
869
  def add_ld_legend(
@@ -674,7 +877,7 @@ class PlotlyBackend:
674
877
  Uses Plotly's separate legend feature (legend="legend") so LD legend
675
878
  can be positioned independently from eQTL and fine-mapping legends.
676
879
  """
677
- fig, row = ax
880
+ fig, row, col, _ = self._extract_row_col(ax)
678
881
 
679
882
  self._add_legend_item(
680
883
  fig, row, "Lead SNP", lead_snp_color, "diamond", 12, "legend"
@@ -720,10 +923,10 @@ class PlotlyBackend:
720
923
 
721
924
  def hide_yaxis(self, ax: Tuple[go.Figure, int]) -> None:
722
925
  """Hide y-axis ticks, labels, line, and grid for gene track panels."""
723
- fig, row = ax
926
+ fig, row, col, n_cols = self._extract_row_col(ax)
724
927
  fig.update_layout(
725
928
  **{
726
- self._axis_name("yaxis", row): dict(
929
+ self._axis_name("yaxis", row, col, n_cols): dict(
727
930
  showticklabels=False,
728
931
  showline=False,
729
932
  showgrid=False,
@@ -735,13 +938,14 @@ class PlotlyBackend:
735
938
  def format_xaxis_mb(self, ax: Tuple[go.Figure, int]) -> None:
736
939
  """Format x-axis to show megabase values.
737
940
 
738
- Stores the row for later tick formatting in finalize_layout.
941
+ Stores the subplot info for later tick formatting in finalize_layout.
739
942
  """
740
- fig, row = ax
943
+ fig, row, col, n_cols = self._extract_row_col(ax)
741
944
  # Store that this axis needs Mb formatting
742
945
  if not hasattr(fig, "_mb_format_rows"):
743
946
  fig._mb_format_rows = []
744
- fig._mb_format_rows.append(row)
947
+ # Store (row, col, n_cols) tuple for proper axis naming later
948
+ fig._mb_format_rows.append((row, col, n_cols))
745
949
 
746
950
  def save(
747
951
  self,
@@ -780,7 +984,7 @@ class PlotlyBackend:
780
984
  Uses Plotly's separate legend feature (legend="legend2") so eQTL legend
781
985
  is positioned independently below the LD legend.
782
986
  """
783
- fig, row = ax
987
+ fig, row, col, _ = self._extract_row_col(ax)
784
988
 
785
989
  for _, _, label, color in eqtl_positive_bins:
786
990
  self._add_legend_item(fig, row, label, color, "triangle-up", 10, "legend2")
@@ -805,7 +1009,7 @@ class PlotlyBackend:
805
1009
  if not credible_sets:
806
1010
  return
807
1011
 
808
- fig, row = ax
1012
+ fig, row, col, _ = self._extract_row_col(ax)
809
1013
 
810
1014
  for cs_id in credible_sets:
811
1015
  self._add_legend_item(
@@ -847,8 +1051,8 @@ class PlotlyBackend:
847
1051
  zorder: int = 1,
848
1052
  ) -> Any:
849
1053
  """Add a vertical line across the panel."""
850
- fig, row = ax
851
- dash = self._DASH_MAP.get(linestyle, "dash")
1054
+ fig, row, col, _ = self._extract_row_col(ax)
1055
+ dash = _DASH_MAP.get(linestyle, "dash")
852
1056
 
853
1057
  fig.add_vline(
854
1058
  x=x,
@@ -857,7 +1061,7 @@ class PlotlyBackend:
857
1061
  line_width=linewidth,
858
1062
  opacity=alpha,
859
1063
  row=row,
860
- col=1,
1064
+ col=col,
861
1065
  )
862
1066
 
863
1067
  def hbar(
@@ -873,7 +1077,7 @@ class PlotlyBackend:
873
1077
  zorder: int = 2,
874
1078
  ) -> Any:
875
1079
  """Create horizontal bar chart."""
876
- fig, row = ax
1080
+ fig, row, col, _ = self._extract_row_col(ax)
877
1081
 
878
1082
  # Convert left to array if scalar
879
1083
  if isinstance(left, (int, float)):
@@ -893,7 +1097,7 @@ class PlotlyBackend:
893
1097
  showlegend=False,
894
1098
  )
895
1099
 
896
- fig.add_trace(trace, row=row, col=1)
1100
+ fig.add_trace(trace, row=row, col=col)
897
1101
  return trace
898
1102
 
899
1103
  def errorbar_h(
@@ -909,7 +1113,7 @@ class PlotlyBackend:
909
1113
  zorder: int = 3,
910
1114
  ) -> Any:
911
1115
  """Add horizontal error bars."""
912
- fig, row = ax
1116
+ fig, row, col, _ = self._extract_row_col(ax)
913
1117
 
914
1118
  trace = go.Scatter(
915
1119
  x=x,
@@ -928,7 +1132,7 @@ class PlotlyBackend:
928
1132
  showlegend=False,
929
1133
  )
930
1134
 
931
- fig.add_trace(trace, row=row, col=1)
1135
+ fig.add_trace(trace, row=row, col=col)
932
1136
  return trace
933
1137
 
934
1138
  def finalize_layout(
@@ -960,8 +1164,13 @@ class PlotlyBackend:
960
1164
  if hasattr(fig, "_mb_format_rows"):
961
1165
  import numpy as np
962
1166
 
963
- for row in fig._mb_format_rows:
964
- xaxis_name = self._axis_name("xaxis", row)
1167
+ for item in fig._mb_format_rows:
1168
+ # Handle both old format (row) and new format (row, col, n_cols)
1169
+ if isinstance(item, tuple):
1170
+ row, col, n_cols = item
1171
+ else:
1172
+ row, col, n_cols = item, 1, 1
1173
+ xaxis_name = self._axis_name("xaxis", row, col, n_cols)
965
1174
  xaxis = getattr(fig.layout, xaxis_name, None)
966
1175
 
967
1176
  # Get x-range from the axis or compute from data