MatplotLibAPI 4.0.0__py3-none-any.whl → 4.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
MatplotLibAPI/heatmap.py CHANGED
@@ -1,24 +1,22 @@
1
1
  """Heatmap and correlation matrix helpers."""
2
2
 
3
- from typing import Any, Optional, Sequence, Tuple
3
+ from typing import Any, Optional, Sequence, Tuple, cast, Literal
4
4
 
5
5
  import pandas as pd
6
- from pandas.api.extensions import register_dataframe_accessor
7
- import matplotlib.pyplot as plt
8
6
  import seaborn as sns
9
7
  from matplotlib.axes import Axes
10
8
  from matplotlib.figure import Figure
9
+ from pandas.api.extensions import register_dataframe_accessor
11
10
 
12
11
  from .base_plot import BasePlot
13
-
14
12
  from .style_template import (
15
13
  HEATMAP_STYLE_TEMPLATE,
16
14
  StyleTemplate,
17
15
  string_formatter,
18
16
  validate_dataframe,
19
17
  )
20
- from .utils import _get_axis
21
18
  from .typing import CorrelationMethod
19
+ from .utils import _get_axis, _merge_kwargs, create_fig
22
20
 
23
21
  __all__ = [
24
22
  "HEATMAP_STYLE_TEMPLATE",
@@ -45,20 +43,30 @@ class Heatmap(BasePlot):
45
43
  self.y = y
46
44
  self.value = value
47
45
 
48
- @property
49
- def correlation_matrix(self) -> pd.DataFrame:
46
+ def correlation_matrix(
47
+ self,
48
+ correlation_method: CorrelationMethod = "pearson",
49
+ ) -> pd.DataFrame:
50
50
  """Compute the correlation matrix for the underlying DataFrame."""
51
- return self._obj.corr()
51
+ return self._obj.corr(method=correlation_method)
52
52
 
53
53
  def aplot(
54
54
  self,
55
55
  title: Optional[str] = None,
56
- style: StyleTemplate = HEATMAP_STYLE_TEMPLATE,
56
+ style: Optional[StyleTemplate] = None,
57
57
  ax: Optional[Axes] = None,
58
58
  **kwargs: Any,
59
59
  ) -> Axes:
60
+ """Plot a heatmap on an existing Matplotlib axes."""
61
+ if not style:
62
+ style = HEATMAP_STYLE_TEMPLATE
60
63
  plot_ax = _get_axis(ax)
61
- sns.heatmap(self._obj, cmap=style.palette, ax=plot_ax)
64
+ heatmap_kwargs: dict[str, Any] = {
65
+ "data": self._obj,
66
+ "cmap": style.palette,
67
+ "ax": plot_ax,
68
+ }
69
+ sns.heatmap(**_merge_kwargs(heatmap_kwargs, kwargs))
62
70
 
63
71
  plot_ax.set_xlabel(string_formatter(self.x))
64
72
  plot_ax.set_ylabel(string_formatter(self.y))
@@ -66,36 +74,26 @@ class Heatmap(BasePlot):
66
74
  plot_ax.set_title(title)
67
75
  return plot_ax
68
76
 
69
- def fplot(
70
- self,
71
- title: Optional[str] = None,
72
- style: StyleTemplate = HEATMAP_STYLE_TEMPLATE,
73
- figsize: Tuple[float, float] = (10, 6),
74
- ) -> Figure:
75
- fig = Figure(
76
- figsize=figsize,
77
- facecolor=style.background_color,
78
- edgecolor=style.background_color,
79
- )
80
- ax = Axes(fig=fig, facecolor=style.background_color)
81
- self.aplot(title=title, style=style, ax=ax)
82
- return fig
83
-
84
77
  def aplot_correlation_matrix(
85
78
  self,
86
79
  title: Optional[str] = None,
87
- style: StyleTemplate = HEATMAP_STYLE_TEMPLATE,
80
+ style: Optional[StyleTemplate] = None,
81
+ correlation_method: CorrelationMethod = "pearson",
88
82
  ax: Optional[Axes] = None,
89
83
  **kwargs: Any,
90
84
  ) -> Axes:
85
+ """Plot a correlation matrix heatmap on existing axes."""
86
+ if not style:
87
+ style = HEATMAP_STYLE_TEMPLATE
91
88
  plot_ax = _get_axis(ax)
92
- sns.heatmap(
93
- self.correlation_matrix,
94
- cmap=style.palette,
95
- annot=True,
96
- fmt=".2f",
97
- ax=plot_ax,
98
- )
89
+ heatmap_kwargs: dict[str, Any] = {
90
+ "data": self.correlation_matrix(correlation_method),
91
+ "cmap": style.palette,
92
+ "annot": True,
93
+ "fmt": ".2f",
94
+ "ax": plot_ax,
95
+ }
96
+ sns.heatmap(**_merge_kwargs(heatmap_kwargs, kwargs))
99
97
  if title:
100
98
  plot_ax.set_title(title)
101
99
  return plot_ax
@@ -103,18 +101,18 @@ class Heatmap(BasePlot):
103
101
  def fplot_correlation_matrix(
104
102
  self,
105
103
  title: Optional[str] = None,
106
- style: StyleTemplate = HEATMAP_STYLE_TEMPLATE,
104
+ style: Optional[StyleTemplate] = None,
107
105
  figsize: Tuple[float, float] = (10, 6),
106
+ correlation_method: CorrelationMethod = "pearson",
108
107
  ) -> Figure:
109
- fig = Figure(
110
- figsize=figsize,
111
- facecolor=style.background_color,
112
- edgecolor=style.background_color,
113
- )
114
- ax = Axes(fig=fig, facecolor=style.background_color)
115
- self.aplot(
108
+ """Plot a correlation matrix heatmap on a new figure."""
109
+ if not style:
110
+ style = HEATMAP_STYLE_TEMPLATE
111
+ fig, ax = create_fig(figsize=figsize, style=style)
112
+ self.aplot_correlation_matrix(
116
113
  title=title,
117
114
  style=style,
115
+ correlation_method=correlation_method,
118
116
  ax=ax,
119
117
  )
120
118
  return fig
@@ -126,7 +124,7 @@ def _prepare_data(
126
124
  y: str,
127
125
  value: str,
128
126
  ) -> pd.DataFrame:
129
- """Prepare data for treemap plotting."""
127
+ """Prepare data for heatmap plotting."""
130
128
  validate_dataframe(pd_df, cols=[x, y, value])
131
129
  plot_df = pd_df[[x, y, value]].pivot_table(
132
130
  index=y, columns=x, values=value, aggfunc="mean"
@@ -134,13 +132,26 @@ def _prepare_data(
134
132
  return plot_df
135
133
 
136
134
 
135
+ def _compute_correlation_matrix(
136
+ pd_df: pd.DataFrame,
137
+ columns: Optional[Sequence[str]],
138
+ method: CorrelationMethod,
139
+ ) -> pd.DataFrame:
140
+ """Compute a correlation matrix from numeric dataframe columns."""
141
+ source_df = pd_df[list(columns)] if columns else pd_df
142
+ numeric_df = source_df.select_dtypes(include="number")
143
+ if numeric_df.empty:
144
+ raise ValueError("No numeric columns available to compute correlation matrix.")
145
+ return numeric_df.corr(method=cast(Any, method))
146
+
147
+
137
148
  def aplot_heatmap(
138
149
  pd_df: pd.DataFrame,
139
150
  x: str,
140
151
  y: str,
141
152
  value: str,
142
153
  title: Optional[str] = None,
143
- style: StyleTemplate = HEATMAP_STYLE_TEMPLATE,
154
+ style: Optional[StyleTemplate] = None,
144
155
  ax: Optional[Axes] = None,
145
156
  **kwargs: Any,
146
157
  ) -> Axes:
@@ -160,26 +171,50 @@ def aplot_heatmap(
160
171
 
161
172
  def aplot_correlation_matrix(
162
173
  pd_df: pd.DataFrame,
163
- columns: Optional[Sequence[str]] = None,
164
- method: CorrelationMethod = "pearson",
174
+ x: str,
175
+ y: str,
176
+ value: str,
177
+ correlation_method: CorrelationMethod = "pearson",
165
178
  title: Optional[str] = None,
166
- style: StyleTemplate = HEATMAP_STYLE_TEMPLATE,
179
+ style: Optional[StyleTemplate] = None,
167
180
  ax: Optional[Axes] = None,
168
181
  **kwargs: Any,
169
182
  ) -> Axes:
170
- """Plot a correlation matrix heatmap for numeric columns."""
183
+ """Plot a correlation matrix heatmap on existing Matplotlib axes.
184
+
185
+ Parameters
186
+ ----------
187
+ pd_df : pd.DataFrame
188
+ Source dataframe containing correlation inputs.
189
+ x : str
190
+ Column used for heatmap x-axis labels.
191
+ y : str
192
+ Column used for heatmap y-axis labels.
193
+ value : str
194
+ Column providing values before correlation aggregation.
195
+ correlation_method : CorrelationMethod, optional
196
+ Correlation method. The default is ``"pearson"``.
197
+ title : str, optional
198
+ Plot title. The default is ``None``.
199
+ style : StyleTemplate, optional
200
+ Style template for rendering. The default is ``None``.
201
+ ax : Axes, optional
202
+ Matplotlib axes to draw on. If ``None``, use current axes.
203
+ **kwargs : Any
204
+ Additional keyword arguments forwarded to seaborn.
205
+
206
+ Returns
207
+ -------
208
+ Axes
209
+ The Matplotlib axes containing the correlation heatmap.
210
+ """
171
211
  return Heatmap(
172
212
  pd_df=pd_df,
173
- x="", # Placeholder since correlation matrix is square
174
- y="", # Placeholder since correlation matrix is square
175
- value="", # Placeholder since correlation matrix is computed internally
213
+ x=x,
214
+ y=y,
215
+ value=value,
176
216
  ).aplot_correlation_matrix(
177
- method=method,
178
- title=title,
179
- style=style,
180
- ax=ax,
181
- columns=columns,
182
- **kwargs,
217
+ title=title, style=style, correlation_method=correlation_method, ax=ax
183
218
  )
184
219
 
185
220
 
@@ -189,7 +224,7 @@ def fplot_heatmap(
189
224
  y: str,
190
225
  value: str,
191
226
  title: Optional[str] = None,
192
- style: StyleTemplate = HEATMAP_STYLE_TEMPLATE,
227
+ style: Optional[StyleTemplate] = None,
193
228
  figsize: Tuple[float, float] = (10, 6),
194
229
  ) -> Figure:
195
230
  """Plot a matrix heatmap on a new figure."""
@@ -211,13 +246,16 @@ def fplot_correlation_matrix(
211
246
  y: str,
212
247
  value: str,
213
248
  title: Optional[str] = None,
214
- style: StyleTemplate = HEATMAP_STYLE_TEMPLATE,
249
+ style: Optional[StyleTemplate] = None,
215
250
  figsize: Tuple[float, float] = (10, 6),
251
+ correlation_method: CorrelationMethod = "pearson",
216
252
  ) -> Figure:
217
253
  """Plot a correlation matrix heatmap on a new figure."""
218
254
  return Heatmap(
219
255
  pd_df=pd_df,
220
- x=x, # Placeholder since correlation matrix is square
221
- y=y, # Placeholder since correlation matrix is square
222
- value=value, # Placeholder since correlation matrix is computed internally
223
- ).fplot_correlation_matrix(title=title, style=style, figsize=figsize)
256
+ x=x,
257
+ y=y,
258
+ value=value,
259
+ ).fplot_correlation_matrix(
260
+ title=title, style=style, figsize=figsize, correlation_method=correlation_method
261
+ )
@@ -19,7 +19,7 @@ from .style_template import (
19
19
  string_formatter,
20
20
  validate_dataframe,
21
21
  )
22
- from .utils import _get_axis
22
+ from .utils import _get_axis, _merge_kwargs
23
23
 
24
24
  __all__ = ["DISTRIBUTION_STYLE_TEMPLATE", "aplot_histogram", "fplot_histogram"]
25
25
 
@@ -50,15 +50,16 @@ class Histogram(BasePlot):
50
50
 
51
51
  validate_dataframe(self._obj, cols=[self.column])
52
52
  plot_ax = _get_axis(ax)
53
- sns.histplot(
54
- data=self._obj,
55
- x=self.column,
56
- bins=self.bins,
57
- kde=self.kde,
58
- color=style.font_color,
59
- edgecolor=style.background_color,
60
- ax=plot_ax,
61
- )
53
+ histplot_kwargs: dict[str, Any] = {
54
+ "data": self._obj,
55
+ "x": self.column,
56
+ "bins": self.bins,
57
+ "kde": self.kde,
58
+ "color": style.font_color,
59
+ "edgecolor": style.background_color,
60
+ "ax": plot_ax,
61
+ }
62
+ sns.histplot(**_merge_kwargs(histplot_kwargs, kwargs))
62
63
  plot_ax.set_facecolor(style.background_color)
63
64
  plot_ax.set_xlabel(string_formatter(self.column))
64
65
  plot_ax.set_ylabel("Frequency")
@@ -66,25 +67,6 @@ class Histogram(BasePlot):
66
67
  plot_ax.set_title(title)
67
68
  return plot_ax
68
69
 
69
- def fplot(
70
- self,
71
- title: Optional[str] = None,
72
- style: StyleTemplate = NETWORK_STYLE_TEMPLATE,
73
- figsize: Tuple[float, float] = (10, 6),
74
- ) -> Figure:
75
- fig = Figure(
76
- figsize=figsize,
77
- facecolor=style.background_color,
78
- edgecolor=style.background_color,
79
- )
80
- ax = Axes(fig=fig, facecolor=style.background_color)
81
- self.aplot(
82
- title=title,
83
- style=style,
84
- ax=ax,
85
- )
86
- return fig
87
-
88
70
 
89
71
  def aplot_histogram(
90
72
  pd_df: pd.DataFrame,
@@ -7,6 +7,7 @@ organizing implementation details in submodules.
7
7
  from .constants import _DEFAULT, _WEIGHT_PERCENTILES
8
8
  from .core import NETWORK_STYLE_TEMPLATE, NetworkGraph
9
9
  from .plot import (
10
+ trim_low_degree_nodes,
10
11
  aplot_network,
11
12
  aplot_network_node,
12
13
  aplot_network_components,
@@ -17,6 +18,7 @@ from .plot import (
17
18
  from .scaling import _scale_weights, _softmax
18
19
 
19
20
  __all__ = [
21
+ "trim_low_degree_nodes",
20
22
  "aplot_network",
21
23
  "aplot_network_node",
22
24
  "aplot_network_components",
@@ -29,6 +29,14 @@ __all__ = [
29
29
  ]
30
30
 
31
31
 
32
+ def _compute_deciles(weights: Iterable[float]) -> Optional[np.ndarray]:
33
+ """Return deciles for ``weights`` or ``None`` when empty."""
34
+ weights_arr = np.asarray(list(weights), dtype=float)
35
+ if weights_arr.size == 0:
36
+ return None
37
+ return np.percentile(weights_arr, _WEIGHT_PERCENTILES)
38
+
39
+
32
40
  class NodeView(nx.classes.reportviews.NodeView):
33
41
  """Extended node view with convenience helpers."""
34
42
 
@@ -520,6 +528,9 @@ class NetworkGraph(BasePlot):
520
528
  max_font_size: int = _DEFAULT["MAX_FONT_SIZE"],
521
529
  min_font_size: int = _DEFAULT["MIN_FONT_SIZE"],
522
530
  edge_weight_col: str = "weight",
531
+ *,
532
+ node_deciles: Optional[np.ndarray],
533
+ edge_deciles: Optional[np.ndarray],
523
534
  ) -> Tuple[List[float], List[float], Dict[int, List[str]]]:
524
535
  """Calculate node, edge and font sizes based on weights.
525
536
 
@@ -536,7 +547,11 @@ class NetworkGraph(BasePlot):
536
547
  min_font_size : int, optional
537
548
  Lower bound for font size. The default is `_DEFAULT["MIN_FONT_SIZE"]`.
538
549
  edge_weight_col : str, optional
539
- Node attribute used for weighting. The default is "weight".
550
+ Edge attribute used for weighting. The default is "weight".
551
+ node_deciles : np.ndarray, optional
552
+ Node-weight deciles used to scale node and font sizes.
553
+ edge_deciles : np.ndarray, optional
554
+ Edge-weight deciles used to scale edge widths.
540
555
 
541
556
  Returns
542
557
  -------
@@ -547,11 +562,6 @@ class NetworkGraph(BasePlot):
547
562
  node_weights = [
548
563
  data.get(edge_weight_col, 1) for node, data in self.node_view(data=True)
549
564
  ]
550
- node_deciles = (
551
- np.percentile(np.array(node_weights), _WEIGHT_PERCENTILES)
552
- if node_weights
553
- else None
554
- )
555
565
  node_size = _scale_weights(
556
566
  weights=node_weights,
557
567
  scale_max=max_node_size,
@@ -563,7 +573,11 @@ class NetworkGraph(BasePlot):
563
573
  edge_weights = [
564
574
  data.get(edge_weight_col, 1) for _, _, data in self.edge_view(data=True)
565
575
  ]
566
- edges_width = _scale_weights(weights=edge_weights, scale_max=max_edge_width)
576
+ edges_width = _scale_weights(
577
+ weights=edge_weights,
578
+ scale_max=max_edge_width,
579
+ deciles=edge_deciles,
580
+ )
567
581
 
568
582
  # Scale the normalized node weights within the desired range of font sizes
569
583
  node_size_dict = dict(
@@ -683,7 +697,7 @@ class NetworkGraph(BasePlot):
683
697
  def aplot(
684
698
  self,
685
699
  title: Optional[str] = None,
686
- style: StyleTemplate = NETWORK_STYLE_TEMPLATE,
700
+ style: Optional[StyleTemplate] = None,
687
701
  edge_weight_col: str = "weight",
688
702
  layout_seed: Optional[int] = _DEFAULT["SPRING_LAYOUT_SEED"],
689
703
  ax: Optional[Axes] = None,
@@ -709,6 +723,8 @@ class NetworkGraph(BasePlot):
709
723
  Axes
710
724
  Matplotlib axes with the plotted network.
711
725
  """
726
+ if not style:
727
+ style = NETWORK_STYLE_TEMPLATE
712
728
  sns.set_palette(style.palette)
713
729
  if ax is None:
714
730
  ax = cast(Axes, plt.gca())
@@ -716,6 +732,7 @@ class NetworkGraph(BasePlot):
716
732
  isolated_nodes = list(nx.isolates(self._nx_graph))
717
733
  graph_nx = self._nx_graph
718
734
  if isolated_nodes:
735
+
719
736
  graph_nx = graph_nx.copy()
720
737
  graph_nx.remove_nodes_from(isolated_nodes)
721
738
 
@@ -733,6 +750,14 @@ class NetworkGraph(BasePlot):
733
750
 
734
751
  mapped_min_font_size = style.font_mapping.get(0)
735
752
  mapped_max_font_size = style.font_mapping.get(4)
753
+ node_weights = [
754
+ data.get(edge_weight_col, 1) for _, data in graph.node_view(data=True)
755
+ ]
756
+ edge_weights = [
757
+ data.get(edge_weight_col, 1) for _, _, data in graph.edge_view(data=True)
758
+ ]
759
+ node_deciles = _compute_deciles(node_weights)
760
+ edge_deciles = _compute_deciles(edge_weights)
736
761
 
737
762
  node_sizes, edge_widths, font_sizes = graph.layout(
738
763
  min_node_size=_DEFAULT["MIN_NODE_SIZE"],
@@ -749,6 +774,8 @@ class NetworkGraph(BasePlot):
749
774
  else _DEFAULT["MAX_FONT_SIZE"]
750
775
  ),
751
776
  edge_weight_col=edge_weight_col,
777
+ node_deciles=node_deciles,
778
+ edge_deciles=edge_deciles,
752
779
  )
753
780
  pos = graph.compute_positions(seed=layout_seed)
754
781
  # nodes
@@ -791,50 +818,10 @@ class NetworkGraph(BasePlot):
791
818
 
792
819
  return ax
793
820
 
794
- def fplot(
795
- self,
796
- title: Optional[str] = None,
797
- style: StyleTemplate = NETWORK_STYLE_TEMPLATE,
798
- layout_seed: Optional[int] = _DEFAULT["SPRING_LAYOUT_SEED"],
799
- figsize: Tuple[float, float] = FIG_SIZE,
800
- ) -> Figure:
801
- """Plot the graph using node and edge weights.
802
-
803
- Parameters
804
- ----------
805
- title : str, optional
806
- Plot title.
807
- style : StyleTemplate, optional
808
- Style configuration. The default is `NETWORK_STYLE_TEMPLATE`.
809
- edge_weight_col : str, optional
810
- Edge attribute used for weighting. The default is "weight".
811
- layout_seed : int, optional
812
- Seed for the spring layout used to place nodes. The default is ``_DEFAULT["SPRING_LAYOUT_SEED"]``.
813
-
814
- Returns
815
- -------
816
- Figure
817
- Matplotlib figure with the plotted network.
818
- """
819
- fig = Figure(
820
- figsize=figsize,
821
- facecolor=style.background_color,
822
- edgecolor=style.background_color,
823
- )
824
- ax = Axes(fig=fig, facecolor=style.background_color)
825
- self.aplot(
826
- title=title,
827
- style=style,
828
- edge_weight_col="",
829
- layout_seed=layout_seed,
830
- ax=ax,
831
- )
832
- return fig
833
-
834
821
  def aplot_connected_components(
835
822
  self,
836
823
  title: Optional[str] = None,
837
- style: StyleTemplate = NETWORK_STYLE_TEMPLATE,
824
+ style: Optional[StyleTemplate] = None,
838
825
  edge_weight_col: str = "weight",
839
826
  layout_seed: Optional[int] = _DEFAULT["SPRING_LAYOUT_SEED"],
840
827
  axes: Optional[np.ndarray] = None,
@@ -863,6 +850,8 @@ class NetworkGraph(BasePlot):
863
850
  or created, the flattened array of axes is returned; otherwise, a
864
851
  single Axes is returned.
865
852
  """
853
+ if not style:
854
+ style = NETWORK_STYLE_TEMPLATE
866
855
  sns.set_palette(style.palette)
867
856
 
868
857
  graph = self
@@ -911,9 +900,10 @@ class NetworkGraph(BasePlot):
911
900
  def fplot_connected_components(
912
901
  self,
913
902
  title: Optional[str] = None,
914
- style: StyleTemplate = NETWORK_STYLE_TEMPLATE,
903
+ style: Optional[StyleTemplate] = None,
915
904
  edge_weight_col: str = "weight",
916
905
  layout_seed: Optional[int] = _DEFAULT["SPRING_LAYOUT_SEED"],
906
+ figsize: Tuple[float, float] = FIG_SIZE,
917
907
  ) -> Figure:
918
908
  """Plot all connected components of the graph.
919
909
 
@@ -933,9 +923,11 @@ class NetworkGraph(BasePlot):
933
923
  Figure
934
924
  Matplotlib figure with the plotted network.
935
925
  """
936
- fig, ax = plt.subplots(figsize=FIG_SIZE)
937
- fig = cast(Figure, fig)
938
- fig.set_facecolor(style.background_color)
926
+ if not style:
927
+ style = NETWORK_STYLE_TEMPLATE
928
+
929
+ fig, ax = BasePlot.create_fig(figsize=figsize, style=style)
930
+
939
931
  self.aplot_connected_components(
940
932
  title=title,
941
933
  style=style,