pylocuszoom 0.3.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -20,15 +20,24 @@ class PlotlyBackend:
20
20
  - Nearest gene
21
21
  """
22
22
 
23
+ # Class constants for style mappings
24
+ _MARKER_SYMBOLS = {
25
+ "o": "circle",
26
+ "D": "diamond",
27
+ "s": "square",
28
+ "^": "triangle-up",
29
+ "v": "triangle-down",
30
+ }
31
+ _DASH_MAP = {
32
+ "-": "solid",
33
+ "--": "dash",
34
+ ":": "dot",
35
+ "-.": "dashdot",
36
+ }
37
+
23
38
  def __init__(self) -> None:
24
39
  """Initialize the plotly backend."""
25
- self._marker_symbols = {
26
- "o": "circle",
27
- "D": "diamond",
28
- "s": "square",
29
- "^": "triangle-up",
30
- "v": "triangle-down",
31
- }
40
+ pass
32
41
 
33
42
  def create_figure(
34
43
  self,
@@ -71,6 +80,20 @@ class PlotlyBackend:
71
80
  template="plotly_white",
72
81
  )
73
82
 
83
+ # Style all panels for clean LocusZoom appearance
84
+ axis_style = dict(
85
+ showgrid=False,
86
+ showline=True,
87
+ linecolor="black",
88
+ ticks="outside",
89
+ minor_ticks="",
90
+ zeroline=False,
91
+ )
92
+ for row in range(1, n_panels + 1):
93
+ xaxis = self._axis_name("xaxis", row)
94
+ yaxis = self._axis_name("yaxis", row)
95
+ fig.update_layout(**{xaxis: axis_style, yaxis: axis_style})
96
+
74
97
  # Return (fig, row) tuples for each panel
75
98
  # This matches the expected ax parameter format for all methods
76
99
  panel_refs = [(fig, row) for row in range(1, n_panels + 1)]
@@ -97,7 +120,7 @@ class PlotlyBackend:
97
120
  fig, row = ax
98
121
 
99
122
  # Convert matplotlib marker to plotly symbol
100
- symbol = self._marker_symbols.get(marker, "circle")
123
+ symbol = self._MARKER_SYMBOLS.get(marker, "circle")
101
124
 
102
125
  # Convert size (matplotlib uses area, plotly uses diameter)
103
126
  if isinstance(sizes, (int, float)):
@@ -112,9 +135,9 @@ class PlotlyBackend:
112
135
  hovertemplate = "<b>%{customdata[0]}</b><br>"
113
136
  for i, col in enumerate(hover_cols[1:], 1):
114
137
  col_lower = col.lower()
115
- if col_lower == "p-value" or col_lower == "pval" or col_lower == "p_value":
138
+ if col_lower in ("p-value", "pval", "p_value"):
116
139
  hovertemplate += f"{col}: %{{customdata[{i}]:.2e}}<br>"
117
- elif "r2" in col_lower or "r²" in col_lower or "ld" in col_lower:
140
+ elif any(x in col_lower for x in ("r2", "r²", "ld")):
118
141
  hovertemplate += f"{col}: %{{customdata[{i}]:.3f}}<br>"
119
142
  elif "pos" in col_lower:
120
143
  hovertemplate += f"{col}: %{{customdata[{i}]:,.0f}}<br>"
@@ -164,15 +187,7 @@ class PlotlyBackend:
164
187
  ) -> Any:
165
188
  """Create a line plot on the given panel."""
166
189
  fig, row = ax
167
-
168
- # Convert linestyle
169
- dash_map = {
170
- "-": "solid",
171
- "--": "dash",
172
- ":": "dot",
173
- "-.": "dashdot",
174
- }
175
- dash = dash_map.get(linestyle, "solid")
190
+ dash = self._DASH_MAP.get(linestyle, "solid")
176
191
 
177
192
  trace = go.Scatter(
178
193
  x=x,
@@ -230,9 +245,7 @@ class PlotlyBackend:
230
245
  ) -> Any:
231
246
  """Add a horizontal line across the panel."""
232
247
  fig, row = ax
233
-
234
- dash_map = {"-": "solid", "--": "dash", ":": "dot", "-.": "dashdot"}
235
- dash = dash_map.get(linestyle, "dash")
248
+ dash = self._DASH_MAP.get(linestyle, "dash")
236
249
 
237
250
  fig.add_hline(
238
251
  y=y,
@@ -335,25 +348,25 @@ class PlotlyBackend:
335
348
  def set_xlim(self, ax: Tuple[go.Figure, int], left: float, right: float) -> None:
336
349
  """Set x-axis limits."""
337
350
  fig, row = ax
338
- xaxis = f"xaxis{row}" if row > 1 else "xaxis"
339
- fig.update_layout(**{xaxis: dict(range=[left, right])})
351
+ fig.update_layout(**{self._axis_name("xaxis", row): dict(range=[left, right])})
340
352
 
341
353
  def set_ylim(self, ax: Tuple[go.Figure, int], bottom: float, top: float) -> None:
342
354
  """Set y-axis limits."""
343
355
  fig, row = ax
344
- yaxis = f"yaxis{row}" if row > 1 else "yaxis"
345
- fig.update_layout(**{yaxis: dict(range=[bottom, top])})
356
+ fig.update_layout(**{self._axis_name("yaxis", row): dict(range=[bottom, top])})
346
357
 
347
358
  def set_xlabel(
348
359
  self, ax: Tuple[go.Figure, int], label: str, fontsize: int = 12
349
360
  ) -> None:
350
361
  """Set x-axis label."""
351
362
  fig, row = ax
352
- xaxis = f"xaxis{row}" if row > 1 else "xaxis"
353
- # Convert LaTeX-style labels to Unicode for Plotly
354
363
  label = self._convert_label(label)
355
364
  fig.update_layout(
356
- **{xaxis: dict(title=dict(text=label, font=dict(size=fontsize)))}
365
+ **{
366
+ self._axis_name("xaxis", row): dict(
367
+ title=dict(text=label, font=dict(size=fontsize))
368
+ )
369
+ }
357
370
  )
358
371
 
359
372
  def set_ylabel(
@@ -361,16 +374,35 @@ class PlotlyBackend:
361
374
  ) -> None:
362
375
  """Set y-axis label."""
363
376
  fig, row = ax
364
- yaxis = f"yaxis{row}" if row > 1 else "yaxis"
365
- # Convert LaTeX-style labels to Unicode for Plotly
366
377
  label = self._convert_label(label)
367
378
  fig.update_layout(
368
- **{yaxis: dict(title=dict(text=label, font=dict(size=fontsize)))}
379
+ **{
380
+ self._axis_name("yaxis", row): dict(
381
+ title=dict(text=label, font=dict(size=fontsize))
382
+ )
383
+ }
369
384
  )
370
385
 
386
+ def _axis_name(self, axis: str, row: int) -> str:
387
+ """Get Plotly axis name for a given row.
388
+
389
+ Plotly names axes as 'xaxis', 'yaxis' for row 1, and
390
+ 'xaxis2', 'yaxis2', etc. for subsequent rows.
391
+ """
392
+ return f"{axis}{row}" if row > 1 else axis
393
+
394
+ def _get_legend_position(self, loc: str) -> dict:
395
+ """Map matplotlib-style legend location to Plotly position dict."""
396
+ loc_map = {
397
+ "upper left": dict(x=0.01, y=0.99, xanchor="left", yanchor="top"),
398
+ "upper right": dict(x=0.99, y=0.99, xanchor="right", yanchor="top"),
399
+ "lower left": dict(x=0.01, y=0.01, xanchor="left", yanchor="bottom"),
400
+ "lower right": dict(x=0.99, y=0.01, xanchor="right", yanchor="bottom"),
401
+ }
402
+ return loc_map.get(loc, loc_map["upper left"])
403
+
371
404
  def _convert_label(self, label: str) -> str:
372
405
  """Convert LaTeX-style labels to Unicode for Plotly display."""
373
- # Common conversions for genomics plots
374
406
  conversions = [
375
407
  (r"$-\log_{10}$ P", "-log₁₀(P)"),
376
408
  (r"$-\log_{10}$", "-log₁₀"),
@@ -428,9 +460,7 @@ class PlotlyBackend:
428
460
  ) -> Any:
429
461
  """Create a line plot on secondary y-axis."""
430
462
  fig, row = ax
431
-
432
- dash_map = {"-": "solid", "--": "dash", ":": "dot", "-.": "dashdot"}
433
- dash = dash_map.get(linestyle, "solid")
463
+ dash = self._DASH_MAP.get(linestyle, "solid")
434
464
 
435
465
  trace = go.Scatter(
436
466
  x=x,
@@ -487,7 +517,9 @@ class PlotlyBackend:
487
517
  ) -> None:
488
518
  """Set secondary y-axis limits."""
489
519
  fig, row = ax
490
- yaxis_key = "yaxis" + yaxis_name[1:] if yaxis_name.startswith("y") else yaxis_name
520
+ yaxis_key = (
521
+ "yaxis" + yaxis_name[1:] if yaxis_name.startswith("y") else yaxis_name
522
+ )
491
523
  fig.update_layout(**{yaxis_key: dict(range=[bottom, top])})
492
524
 
493
525
  def set_secondary_ylabel(
@@ -501,7 +533,9 @@ class PlotlyBackend:
501
533
  """Set secondary y-axis label."""
502
534
  fig, row = ax
503
535
  label = self._convert_label(label)
504
- yaxis_key = "yaxis" + yaxis_name[1:] if yaxis_name.startswith("y") else yaxis_name
536
+ yaxis_key = (
537
+ "yaxis" + yaxis_name[1:] if yaxis_name.startswith("y") else yaxis_name
538
+ )
505
539
  fig.update_layout(
506
540
  **{
507
541
  yaxis_key: dict(
@@ -511,48 +545,87 @@ class PlotlyBackend:
511
545
  }
512
546
  )
513
547
 
548
+ def _get_panel_y_top(self, fig: go.Figure, row: int) -> float:
549
+ """Get the top y-coordinate (in paper coords) for a subplot row.
550
+
551
+ Plotly subplots have y-axis domains that define their vertical position.
552
+ This returns the top of the domain for positioning legends.
553
+ """
554
+ yaxis = getattr(fig.layout, self._axis_name("yaxis", row), None)
555
+ if yaxis and yaxis.domain:
556
+ return yaxis.domain[1]
557
+ return 0.99
558
+
559
+ def _add_legend_item(
560
+ self,
561
+ fig: go.Figure,
562
+ row: int,
563
+ name: str,
564
+ color: str,
565
+ symbol: str,
566
+ size: int,
567
+ legend_group: str,
568
+ ) -> None:
569
+ """Add an invisible scatter trace for a legend entry."""
570
+ fig.add_trace(
571
+ go.Scatter(
572
+ x=[None],
573
+ y=[None],
574
+ mode="markers",
575
+ marker=dict(
576
+ symbol=symbol,
577
+ size=size,
578
+ color=color,
579
+ line=dict(color="black", width=0.5),
580
+ ),
581
+ name=name,
582
+ showlegend=True,
583
+ legend=legend_group,
584
+ ),
585
+ row=row,
586
+ col=1,
587
+ )
588
+
589
+ def _configure_legend(
590
+ self, fig: go.Figure, row: int, legend_key: str, title: str
591
+ ) -> None:
592
+ """Configure legend position and styling."""
593
+ y_pos = self._get_panel_y_top(fig, row)
594
+ fig.update_layout(
595
+ **{
596
+ legend_key: dict(
597
+ title=dict(text=title),
598
+ x=0.99,
599
+ y=y_pos,
600
+ xanchor="right",
601
+ yanchor="top",
602
+ bgcolor="rgba(255,255,255,0.9)",
603
+ bordercolor="black",
604
+ borderwidth=1,
605
+ )
606
+ }
607
+ )
608
+
514
609
  def add_ld_legend(
515
610
  self,
516
611
  ax: Tuple[go.Figure, int],
517
612
  ld_bins: List[Tuple[float, str, str]],
518
613
  lead_snp_color: str,
519
614
  ) -> None:
520
- """Add LD color legend using invisible scatter traces."""
615
+ """Add LD color legend using invisible scatter traces.
616
+
617
+ Uses Plotly's separate legend feature (legend="legend") so LD legend
618
+ can be positioned independently from eQTL and fine-mapping legends.
619
+ """
521
620
  fig, row = ax
522
621
 
523
- # Add LD bin markers (no lead SNP - it's shown in the actual plot)
622
+ self._add_legend_item(
623
+ fig, row, "Lead SNP", lead_snp_color, "diamond", 12, "legend"
624
+ )
524
625
  for _, label, color in ld_bins:
525
- fig.add_trace(
526
- go.Scatter(
527
- x=[None],
528
- y=[None],
529
- mode="markers",
530
- marker=dict(
531
- symbol="square",
532
- size=10,
533
- color=color,
534
- line=dict(color="black", width=0.5),
535
- ),
536
- name=label,
537
- showlegend=True,
538
- ),
539
- row=row,
540
- col=1,
541
- )
626
+ self._add_legend_item(fig, row, label, color, "square", 10, "legend")
542
627
 
543
- # Position legend
544
- fig.update_layout(
545
- legend=dict(
546
- x=0.99,
547
- y=0.99,
548
- xanchor="right",
549
- yanchor="top",
550
- title=dict(text="r²"),
551
- bgcolor="rgba(255,255,255,0.9)",
552
- bordercolor="black",
553
- borderwidth=1,
554
- )
555
- )
628
+ self._configure_legend(fig, row, "legend", "r²")
556
629
 
557
630
  def add_legend(
558
631
  self,
@@ -568,16 +641,7 @@ class PlotlyBackend:
568
641
  This method updates legend positioning.
569
642
  """
570
643
  fig, _ = ax
571
-
572
- # Map matplotlib locations to plotly
573
- loc_map = {
574
- "upper left": dict(x=0.01, y=0.99, xanchor="left", yanchor="top"),
575
- "upper right": dict(x=0.99, y=0.99, xanchor="right", yanchor="top"),
576
- "lower left": dict(x=0.01, y=0.01, xanchor="left", yanchor="bottom"),
577
- "lower right": dict(x=0.99, y=0.01, xanchor="right", yanchor="bottom"),
578
- }
579
-
580
- legend_pos = loc_map.get(loc, loc_map["upper left"])
644
+ legend_pos = self._get_legend_position(loc)
581
645
  fig.update_layout(
582
646
  legend=dict(
583
647
  **legend_pos,
@@ -597,6 +661,20 @@ class PlotlyBackend:
597
661
  # No action needed - method exists for API compatibility
598
662
  pass
599
663
 
664
+ def hide_yaxis(self, ax: Tuple[go.Figure, int]) -> None:
665
+ """Hide y-axis ticks, labels, line, and grid for gene track panels."""
666
+ fig, row = ax
667
+ fig.update_layout(
668
+ **{
669
+ self._axis_name("yaxis", row): dict(
670
+ showticklabels=False,
671
+ showline=False,
672
+ showgrid=False,
673
+ ticks="",
674
+ )
675
+ }
676
+ )
677
+
600
678
  def format_xaxis_mb(self, ax: Tuple[go.Figure, int]) -> None:
601
679
  """Format x-axis to show megabase values.
602
680
 
@@ -634,6 +712,168 @@ class PlotlyBackend:
634
712
  """Close the figure (no-op for plotly)."""
635
713
  pass
636
714
 
715
+ def add_eqtl_legend(
716
+ self,
717
+ ax: Tuple[go.Figure, int],
718
+ eqtl_positive_bins: List[Tuple[float, float, str, str]],
719
+ eqtl_negative_bins: List[Tuple[float, float, str, str]],
720
+ ) -> None:
721
+ """Add eQTL effect size legend using invisible scatter traces.
722
+
723
+ Uses Plotly's separate legend feature (legend="legend2") so eQTL legend
724
+ is positioned independently below the LD legend.
725
+ """
726
+ fig, row = ax
727
+
728
+ for _, _, label, color in eqtl_positive_bins:
729
+ self._add_legend_item(fig, row, label, color, "triangle-up", 10, "legend2")
730
+ for _, _, label, color in eqtl_negative_bins:
731
+ self._add_legend_item(
732
+ fig, row, label, color, "triangle-down", 10, "legend2"
733
+ )
734
+
735
+ self._configure_legend(fig, row, "legend2", "eQTL effect")
736
+
737
+ def add_finemapping_legend(
738
+ self,
739
+ ax: Tuple[go.Figure, int],
740
+ credible_sets: List[int],
741
+ get_color_func: Any,
742
+ ) -> None:
743
+ """Add fine-mapping credible set legend using invisible scatter traces.
744
+
745
+ Uses Plotly's separate legend feature (legend="legend2") so fine-mapping
746
+ legend is positioned independently below the LD legend.
747
+ """
748
+ if not credible_sets:
749
+ return
750
+
751
+ fig, row = ax
752
+
753
+ for cs_id in credible_sets:
754
+ self._add_legend_item(
755
+ fig, row, f"CS{cs_id}", get_color_func(cs_id), "circle", 10, "legend2"
756
+ )
757
+
758
+ self._configure_legend(fig, row, "legend2", "Credible sets")
759
+
760
+ def add_simple_legend(
761
+ self,
762
+ ax: Tuple[go.Figure, int],
763
+ label: str,
764
+ loc: str = "upper right",
765
+ ) -> None:
766
+ """Add simple legend positioning.
767
+
768
+ Plotly handles legends automatically from trace names.
769
+ This just positions the legend.
770
+ """
771
+ fig, _ = ax
772
+ legend_pos = self._get_legend_position(loc)
773
+ fig.update_layout(
774
+ legend=dict(
775
+ **legend_pos,
776
+ bgcolor="rgba(255,255,255,0.9)",
777
+ bordercolor="black",
778
+ borderwidth=1,
779
+ )
780
+ )
781
+
782
+ def axvline(
783
+ self,
784
+ ax: Tuple[go.Figure, int],
785
+ x: float,
786
+ color: str = "grey",
787
+ linestyle: str = "--",
788
+ linewidth: float = 1.0,
789
+ alpha: float = 1.0,
790
+ zorder: int = 1,
791
+ ) -> Any:
792
+ """Add a vertical line across the panel."""
793
+ fig, row = ax
794
+ dash = self._DASH_MAP.get(linestyle, "dash")
795
+
796
+ fig.add_vline(
797
+ x=x,
798
+ line_dash=dash,
799
+ line_color=color,
800
+ line_width=linewidth,
801
+ opacity=alpha,
802
+ row=row,
803
+ col=1,
804
+ )
805
+
806
+ def hbar(
807
+ self,
808
+ ax: Tuple[go.Figure, int],
809
+ y: pd.Series,
810
+ width: pd.Series,
811
+ height: float = 0.8,
812
+ left: Union[float, pd.Series] = 0,
813
+ color: Union[str, List[str]] = "blue",
814
+ edgecolor: str = "black",
815
+ linewidth: float = 0.5,
816
+ zorder: int = 2,
817
+ ) -> Any:
818
+ """Create horizontal bar chart."""
819
+ fig, row = ax
820
+
821
+ # Convert left to array if scalar
822
+ if isinstance(left, (int, float)):
823
+ left_arr = [left] * len(y)
824
+ else:
825
+ left_arr = list(left) if hasattr(left, "tolist") else left
826
+
827
+ trace = go.Bar(
828
+ y=y,
829
+ x=width,
830
+ orientation="h",
831
+ base=left_arr,
832
+ marker=dict(
833
+ color=color,
834
+ line=dict(color=edgecolor, width=linewidth),
835
+ ),
836
+ showlegend=False,
837
+ )
838
+
839
+ fig.add_trace(trace, row=row, col=1)
840
+ return trace
841
+
842
+ def errorbar_h(
843
+ self,
844
+ ax: Tuple[go.Figure, int],
845
+ x: pd.Series,
846
+ y: pd.Series,
847
+ xerr_lower: pd.Series,
848
+ xerr_upper: pd.Series,
849
+ color: str = "black",
850
+ linewidth: float = 1.5,
851
+ capsize: float = 3,
852
+ zorder: int = 3,
853
+ ) -> Any:
854
+ """Add horizontal error bars."""
855
+ fig, row = ax
856
+
857
+ trace = go.Scatter(
858
+ x=x,
859
+ y=y,
860
+ mode="markers",
861
+ marker=dict(size=0),
862
+ error_x=dict(
863
+ type="data",
864
+ symmetric=False,
865
+ array=xerr_upper,
866
+ arrayminus=xerr_lower,
867
+ color=color,
868
+ thickness=linewidth,
869
+ width=capsize,
870
+ ),
871
+ showlegend=False,
872
+ )
873
+
874
+ fig.add_trace(trace, row=row, col=1)
875
+ return trace
876
+
637
877
  def finalize_layout(
638
878
  self,
639
879
  fig: go.Figure,
@@ -664,7 +904,7 @@ class PlotlyBackend:
664
904
  import numpy as np
665
905
 
666
906
  for row in fig._mb_format_rows:
667
- xaxis_name = f"xaxis{row}" if row > 1 else "xaxis"
907
+ xaxis_name = self._axis_name("xaxis", row)
668
908
  xaxis = getattr(fig.layout, xaxis_name, None)
669
909
 
670
910
  # Get x-range from the axis or compute from data
@@ -672,18 +912,16 @@ class PlotlyBackend:
672
912
  if xaxis and xaxis.range:
673
913
  x_range = xaxis.range
674
914
  else:
675
- # Compute from trace data
915
+ # Compute from trace data (filter out None values from legend traces)
676
916
  x_vals = []
677
917
  for trace in fig.data:
678
918
  if hasattr(trace, "x") and trace.x is not None:
679
- x_vals.extend(list(trace.x))
919
+ x_vals.extend([v for v in trace.x if v is not None])
680
920
  if x_vals:
681
921
  x_range = [min(x_vals), max(x_vals)]
682
922
 
683
923
  if x_range:
684
- # Create nice tick values in Mb
685
- x_min_mb = x_range[0] / 1e6
686
- x_max_mb = x_range[1] / 1e6
924
+ x_min_mb, x_max_mb = x_range[0] / 1e6, x_range[1] / 1e6
687
925
  span_mb = x_max_mb - x_min_mb
688
926
 
689
927
  # Choose tick spacing based on range
@@ -700,7 +938,9 @@ class PlotlyBackend:
700
938
 
701
939
  # Generate ticks
702
940
  first_tick = np.ceil(x_min_mb / tick_step) * tick_step
703
- tickvals_mb = np.arange(first_tick, x_max_mb + tick_step / 2, tick_step)
941
+ tickvals_mb = np.arange(
942
+ first_tick, x_max_mb + tick_step / 2, tick_step
943
+ )
704
944
  tickvals_bp = [v * 1e6 for v in tickvals_mb]
705
945
  ticktext = [f"{v:.2f}" for v in tickvals_mb]
706
946
 
pylocuszoom/colors.py CHANGED
@@ -236,4 +236,47 @@ def get_credible_set_color_palette(n_sets: int = 10) -> dict[int, str]:
236
236
  >>> palette[1]
237
237
  '#FF7F00'
238
238
  """
239
- return {i + 1: CREDIBLE_SET_COLORS[i % len(CREDIBLE_SET_COLORS)] for i in range(n_sets)}
239
+ return {
240
+ i + 1: CREDIBLE_SET_COLORS[i % len(CREDIBLE_SET_COLORS)] for i in range(n_sets)
241
+ }
242
+
243
+
244
+ # PheWAS category colors - distinct colors for phenotype categories
245
+ PHEWAS_CATEGORY_COLORS: List[str] = [
246
+ "#E41A1C", # red
247
+ "#377EB8", # blue
248
+ "#4DAF4A", # green
249
+ "#984EA3", # purple
250
+ "#FF7F00", # orange
251
+ "#FFFF33", # yellow
252
+ "#A65628", # brown
253
+ "#F781BF", # pink
254
+ "#999999", # grey
255
+ "#66C2A5", # teal
256
+ "#FC8D62", # salmon
257
+ "#8DA0CB", # periwinkle
258
+ ]
259
+
260
+
261
+ def get_phewas_category_color(category_idx: int) -> str:
262
+ """Get color for a PheWAS category by index.
263
+
264
+ Args:
265
+ category_idx: Zero-indexed category number.
266
+
267
+ Returns:
268
+ Hex color code string.
269
+ """
270
+ return PHEWAS_CATEGORY_COLORS[category_idx % len(PHEWAS_CATEGORY_COLORS)]
271
+
272
+
273
+ def get_phewas_category_palette(categories: List[str]) -> dict[str, str]:
274
+ """Get color palette mapping category names to colors.
275
+
276
+ Args:
277
+ categories: List of unique category names.
278
+
279
+ Returns:
280
+ Dictionary mapping category names to hex colors.
281
+ """
282
+ return {cat: get_phewas_category_color(i) for i, cat in enumerate(categories)}
pylocuszoom/forest.py ADDED
@@ -0,0 +1,37 @@
1
+ """Forest plot data validation and preparation.
2
+
3
+ Validates and prepares meta-analysis/forest plot data for visualization.
4
+ """
5
+
6
+ import pandas as pd
7
+
8
+ from .utils import ValidationError
9
+
10
+
11
+ def validate_forest_df(
12
+ df: pd.DataFrame,
13
+ study_col: str = "study",
14
+ effect_col: str = "effect",
15
+ ci_lower_col: str = "ci_lower",
16
+ ci_upper_col: str = "ci_upper",
17
+ ) -> None:
18
+ """Validate forest plot DataFrame has required columns.
19
+
20
+ Args:
21
+ df: Forest plot data DataFrame.
22
+ study_col: Column name for study/phenotype names.
23
+ effect_col: Column name for effect sizes (beta, OR, HR).
24
+ ci_lower_col: Column name for lower confidence interval.
25
+ ci_upper_col: Column name for upper confidence interval.
26
+
27
+ Raises:
28
+ ValidationError: If required columns are missing.
29
+ """
30
+ required = [study_col, effect_col, ci_lower_col, ci_upper_col]
31
+ missing = [col for col in required if col not in df.columns]
32
+
33
+ if missing:
34
+ raise ValidationError(
35
+ f"Forest plot DataFrame missing required columns: {missing}. "
36
+ f"Required: {required}. Found: {list(df.columns)}"
37
+ )
pylocuszoom/gene_track.py CHANGED
@@ -362,6 +362,7 @@ def plot_gene_track_generic(
362
362
 
363
363
  backend.set_xlim(ax, start, end)
364
364
  backend.set_ylabel(ax, "", fontsize=10)
365
+ backend.hide_yaxis(ax)
365
366
 
366
367
  if region_genes.empty:
367
368
  backend.set_ylim(ax, 0, 1)