MatplotLibAPI 4.0.0__tar.gz → 4.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/PKG-INFO +1 -1
  2. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/pyproject.toml +1 -1
  3. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/src/MatplotLibAPI/accessor.py +14 -14
  4. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/src/MatplotLibAPI/area.py +78 -20
  5. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/src/MatplotLibAPI/bar.py +18 -10
  6. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/src/MatplotLibAPI/box_violin.py +13 -7
  7. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/src/MatplotLibAPI/bubble.py +2 -1
  8. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/src/MatplotLibAPI/heatmap.py +56 -33
  9. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/src/MatplotLibAPI/histogram.py +13 -11
  10. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/src/MatplotLibAPI/network/core.py +2 -1
  11. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/src/MatplotLibAPI/network/plot.py +1 -1
  12. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/src/MatplotLibAPI/pie.py +11 -10
  13. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/src/MatplotLibAPI/pivot.py +11 -3
  14. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/src/MatplotLibAPI/table.py +2 -1
  15. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/src/MatplotLibAPI/timeserie.py +2 -1
  16. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/src/MatplotLibAPI/utils.py +23 -0
  17. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/src/MatplotLibAPI/waffle.py +2 -1
  18. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/src/MatplotLibAPI/word_cloud.py +2 -1
  19. matplotlibapi-4.0.1/tests/test_area.py +77 -0
  20. matplotlibapi-4.0.1/tests/test_bar.py +34 -0
  21. matplotlibapi-4.0.1/tests/test_heatmap.py +63 -0
  22. matplotlibapi-4.0.1/tests/test_histogram.py +30 -0
  23. matplotlibapi-4.0.1/tests/test_pie.py +31 -0
  24. matplotlibapi-4.0.1/tests/test_pivot.py +31 -0
  25. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/tests/test_smoke.py +22 -0
  26. matplotlibapi-4.0.0/tests/test_area.py +0 -17
  27. matplotlibapi-4.0.0/tests/test_bar.py +0 -17
  28. matplotlibapi-4.0.0/tests/test_heatmap.py +0 -27
  29. matplotlibapi-4.0.0/tests/test_histogram.py +0 -15
  30. matplotlibapi-4.0.0/tests/test_pie.py +0 -15
  31. matplotlibapi-4.0.0/tests/test_pivot.py +0 -15
  32. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/.github/dependabot.yml +0 -0
  33. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/.github/workflows/ci.yml +0 -0
  34. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/.github/workflows/python-publish.yml +0 -0
  35. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/.gitignore +0 -0
  36. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/.pre-commit-config.yaml +0 -0
  37. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/AGENTS.md +0 -0
  38. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/LICENSE +0 -0
  39. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/README.md +0 -0
  40. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/SECURITY.md +0 -0
  41. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/SUGGESTIONS.md +0 -0
  42. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/examples/__init__.py +0 -0
  43. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/examples/network.png +0 -0
  44. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/examples/network.py +0 -0
  45. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/examples/sample_data.py +0 -0
  46. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/scripts/pre_commit.sh +0 -0
  47. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/src/MatplotLibAPI/__init__.py +0 -0
  48. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/src/MatplotLibAPI/base_plot.py +0 -0
  49. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/src/MatplotLibAPI/composite.py +0 -0
  50. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/src/MatplotLibAPI/mcp/__init__.py +0 -0
  51. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/src/MatplotLibAPI/mcp/metadata.py +0 -0
  52. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/src/MatplotLibAPI/mcp/renderers.py +0 -0
  53. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/src/MatplotLibAPI/mcp_server.py +0 -0
  54. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/src/MatplotLibAPI/network/__init__.py +0 -0
  55. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/src/MatplotLibAPI/network/constants.py +0 -0
  56. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/src/MatplotLibAPI/network/scaling.py +0 -0
  57. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/src/MatplotLibAPI/sankey.py +0 -0
  58. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/src/MatplotLibAPI/style_template.py +0 -0
  59. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/src/MatplotLibAPI/sunburst.py +0 -0
  60. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/src/MatplotLibAPI/treemap.py +0 -0
  61. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/src/MatplotLibAPI/typing.py +0 -0
  62. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/tests/__init__.py +0 -0
  63. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/tests/conftest.py +0 -0
  64. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/tests/test_box_violin.py +0 -0
  65. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/tests/test_bubble.py +0 -0
  66. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/tests/test_composite.py +0 -0
  67. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/tests/test_dependencies.py +0 -0
  68. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/tests/test_mcp_server.py +0 -0
  69. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/tests/test_network.py +0 -0
  70. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/tests/test_sankey.py +0 -0
  71. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/tests/test_style_template.py +0 -0
  72. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/tests/test_sunburst.py +0 -0
  73. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/tests/test_table.py +0 -0
  74. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/tests/test_timeseries.py +0 -0
  75. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/tests/test_treemap.py +0 -0
  76. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/tests/test_waffle.py +0 -0
  77. {matplotlibapi-4.0.0 → matplotlibapi-4.0.1}/tests/test_wordcloud.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: MatplotLibAPI
3
- Version: 4.0.0
3
+ Version: 4.0.1
4
4
  License-File: LICENSE
5
5
  Requires-Python: >=3.8
6
6
  Requires-Dist: kaleido
@@ -3,7 +3,7 @@ requires = ["hatchling"]
3
3
  build-backend = "hatchling.build"
4
4
  [project]
5
5
  name = "MatplotLibAPI"
6
- version = "4.0.0"
6
+ version = "4.0.1"
7
7
  readme = "README.md"
8
8
  requires-python = ">=3.8"
9
9
  dependencies = [
@@ -428,16 +428,8 @@ class DataFrameAccessor:
428
428
  max_values=max_values,
429
429
  center_to_mean=center_to_mean,
430
430
  ).fplot(
431
- label=label,
432
- x=x,
433
- y=y,
434
- z=z,
435
431
  title=title,
436
432
  style=style or bubble_style_template,
437
- max_values=max_values,
438
- center_to_mean=center_to_mean,
439
- sort_by=sort_by,
440
- ascending=ascending,
441
433
  hline=hline,
442
434
  vline=vline,
443
435
  figsize=figsize,
@@ -897,6 +889,7 @@ class DataFrameAccessor:
897
889
  title: Optional[str] = None,
898
890
  style: StyleTemplate = AREA_STYLE_TEMPLATE,
899
891
  ax: Optional[Axes] = None,
892
+ **kwargs: Any,
900
893
  ) -> Axes:
901
894
  """Plot an area chart on existing axes.
902
895
 
@@ -916,6 +909,8 @@ class DataFrameAccessor:
916
909
  Styling template. The default is ``AREA_STYLE_TEMPLATE``.
917
910
  ax : Axes, optional
918
911
  Matplotlib axes to plot on. If None, uses the current axes.
912
+ **kwargs : Any
913
+ Additional keyword arguments forwarded to the underlying area plot call.
919
914
 
920
915
  Returns
921
916
  -------
@@ -923,14 +918,15 @@ class DataFrameAccessor:
923
918
  The Matplotlib axes object with the area chart.
924
919
  """
925
920
  return aplot_area(
926
- pd_df=self._obj,
927
- x=x,
928
- y=y,
921
+ self._obj,
922
+ x,
923
+ y,
929
924
  label=label,
930
925
  stacked=stacked,
931
926
  title=title,
932
927
  style=style,
933
928
  ax=ax,
929
+ **kwargs,
934
930
  )
935
931
 
936
932
  def fplot_area(
@@ -942,6 +938,7 @@ class DataFrameAccessor:
942
938
  title: Optional[str] = None,
943
939
  style: StyleTemplate = AREA_STYLE_TEMPLATE,
944
940
  figsize: Tuple[float, float] = FIG_SIZE,
941
+ **kwargs: Any,
945
942
  ) -> Figure:
946
943
  """Plot an area chart on a new figure.
947
944
 
@@ -961,6 +958,8 @@ class DataFrameAccessor:
961
958
  Styling template. The default is ``AREA_STYLE_TEMPLATE``.
962
959
  figsize : tuple[float, float], optional
963
960
  Figure size. The default is FIG_SIZE.
961
+ **kwargs : Any
962
+ Additional keyword arguments forwarded to the underlying area plot call.
964
963
 
965
964
  Returns
966
965
  -------
@@ -968,14 +967,15 @@ class DataFrameAccessor:
968
967
  The new Matplotlib figure with the area chart.
969
968
  """
970
969
  return fplot_area(
971
- pd_df=self._obj,
972
- x=x,
973
- y=y,
970
+ self._obj,
971
+ x,
972
+ y,
974
973
  label=label,
975
974
  stacked=stacked,
976
975
  title=title,
977
976
  style=style,
978
977
  figsize=figsize,
978
+ **kwargs,
979
979
  )
980
980
 
981
981
  def aplot_pie_donut(
@@ -1,9 +1,8 @@
1
- """Area chart helpers."""
1
+ """Area chart helpers for Matplotlib-based area visualizations."""
2
2
 
3
3
  from typing import Any, Optional, Tuple
4
4
 
5
5
  import pandas as pd
6
- import matplotlib.pyplot as plt
7
6
  from matplotlib.axes import Axes
8
7
  from matplotlib.figure import Figure
9
8
 
@@ -15,7 +14,7 @@ from .style_template import (
15
14
  string_formatter,
16
15
  validate_dataframe,
17
16
  )
18
- from .utils import _get_axis
17
+ from .utils import _get_axis, _merge_kwargs
19
18
 
20
19
  __all__ = ["AREA_STYLE_TEMPLATE", "aplot_area", "fplot_area"]
21
20
 
@@ -39,6 +38,21 @@ class AreaChart(BasePlot):
39
38
  label: Optional[str] = None,
40
39
  stacked: bool = True,
41
40
  ):
41
+ """Initialize an area chart plotter.
42
+
43
+ Parameters
44
+ ----------
45
+ pd_df : pd.DataFrame
46
+ DataFrame containing the data to visualize.
47
+ x : str
48
+ Column name used for the x-axis.
49
+ y : str
50
+ Column name used for the y-axis values.
51
+ label : str, optional
52
+ Column used to split the area into groups. The default is None.
53
+ stacked : bool, optional
54
+ Whether grouped areas are stacked. The default is True.
55
+ """
42
56
  super().__init__(pd_df=pd_df)
43
57
  self.x = x
44
58
  self.y = y
@@ -50,6 +64,52 @@ class AreaChart(BasePlot):
50
64
  cols.append(self.label)
51
65
  validate_dataframe(self._obj, cols=cols)
52
66
 
67
+ def _plot_grouped_area(
68
+ self,
69
+ plot_ax: Axes,
70
+ **kwargs: Any,
71
+ ) -> None:
72
+ """Plot grouped area data using a pivoted dataframe."""
73
+ pivot_df = self._obj.pivot_table(
74
+ index=self.x,
75
+ columns=self.label,
76
+ values=self.y,
77
+ aggfunc="sum",
78
+ ).sort_index()
79
+
80
+ plot_kwargs: dict[str, Any] = {
81
+ "kind": "area",
82
+ "stacked": self.stacked,
83
+ "alpha": 0.7,
84
+ "ax": plot_ax,
85
+ }
86
+ pivot_df.plot(**_merge_kwargs(plot_kwargs, kwargs))
87
+
88
+ legend = plot_ax.get_legend()
89
+ if legend is not None:
90
+ legend.set_title(string_formatter(self.label or ""))
91
+
92
+ def _plot_single_area(
93
+ self,
94
+ plot_ax: Axes,
95
+ style: StyleTemplate,
96
+ **kwargs: Any,
97
+ ) -> None:
98
+ """Plot a single-series area chart."""
99
+ sorted_df = self._obj.sort_values(by=self.x)
100
+ fill_between_kwargs: dict[str, Any] = {
101
+ "color": style.font_color,
102
+ "alpha": 0.4,
103
+ }
104
+ merged_fill_between_kwargs = _merge_kwargs(fill_between_kwargs, kwargs)
105
+
106
+ plot_ax.fill_between(
107
+ sorted_df[self.x],
108
+ sorted_df[self.y],
109
+ **merged_fill_between_kwargs,
110
+ )
111
+ plot_ax.plot(sorted_df[self.x], sorted_df[self.y], color=style.font_color)
112
+
53
113
  def aplot(
54
114
  self,
55
115
  title: Optional[str] = None,
@@ -68,7 +128,7 @@ class AreaChart(BasePlot):
68
128
  ax : Axes, optional
69
129
  Matplotlib axes to plot on. If None, use the current axes.
70
130
  **kwargs : Any
71
- Additional keyword arguments reserved for compatibility.
131
+ Additional keyword arguments forwarded to the area plotting call.
72
132
 
73
133
  Returns
74
134
  -------
@@ -76,21 +136,16 @@ class AreaChart(BasePlot):
76
136
  The Matplotlib axes containing the area chart.
77
137
  """
78
138
  plot_ax = _get_axis(ax)
139
+ plot_ax.set_facecolor(style.background_color)
79
140
 
80
141
  if self.label:
81
- pivot_df = self._obj.pivot_table(
82
- index=self.x, columns=self.label, values=self.y, aggfunc="sum"
83
- ).sort_index()
84
- pivot_df.plot(kind="area", stacked=self.stacked, alpha=0.7, ax=plot_ax)
142
+ self._plot_grouped_area(plot_ax=plot_ax, **kwargs)
85
143
  else:
86
- sorted_df = self._obj.sort_values(by=self.x)
87
- plot_ax.fill_between(
88
- sorted_df[self.x], sorted_df[self.y], color=style.font_color, alpha=0.4
89
- )
90
- plot_ax.plot(sorted_df[self.x], sorted_df[self.y], color=style.font_color)
144
+ self._plot_single_area(plot_ax=plot_ax, style=style, **kwargs)
91
145
 
92
146
  plot_ax.set_xlabel(string_formatter(self.x))
93
147
  plot_ax.set_ylabel(string_formatter(self.y))
148
+ plot_ax.tick_params(axis="x", labelrotation=45)
94
149
  if title:
95
150
  plot_ax.set_title(title)
96
151
  return plot_ax
@@ -100,6 +155,7 @@ class AreaChart(BasePlot):
100
155
  title: Optional[str] = None,
101
156
  style: StyleTemplate = AREA_STYLE_TEMPLATE,
102
157
  figsize: Tuple[float, float] = (10, 6),
158
+ **kwargs: Any,
103
159
  ) -> Figure:
104
160
  """Plot an area chart on a new figure.
105
161
 
@@ -112,18 +168,18 @@ class AreaChart(BasePlot):
112
168
  figsize : tuple[float, float], optional
113
169
  Figure size. The default is (10, 6).
114
170
 
171
+ **kwargs : Any
172
+ Additional keyword arguments forwarded to ``aplot``.
173
+
115
174
  Returns
116
175
  -------
117
176
  Figure
118
177
  The Matplotlib figure containing the area chart.
119
178
  """
120
- fig = Figure(
121
- figsize=figsize,
122
- facecolor=style.background_color,
123
- edgecolor=style.background_color,
124
- )
125
- ax = Axes(fig=fig, facecolor=style.background_color)
126
- self.aplot(title=title, style=style, ax=ax)
179
+ fig = Figure(figsize=figsize)
180
+ fig.set_facecolor(style.background_color)
181
+ ax = fig.add_subplot(111)
182
+ self.aplot(title=title, style=style, ax=ax, **kwargs)
127
183
  return fig
128
184
 
129
185
 
@@ -162,6 +218,7 @@ def fplot_area(
162
218
  title: Optional[str] = None,
163
219
  style: StyleTemplate = AREA_STYLE_TEMPLATE,
164
220
  figsize: Tuple[float, float] = (10, 6),
221
+ **kwargs: Any,
165
222
  ) -> Figure:
166
223
  """Plot area charts on a new figure."""
167
224
  return AreaChart(
@@ -174,4 +231,5 @@ def fplot_area(
174
231
  title=title,
175
232
  style=style,
176
233
  figsize=figsize,
234
+ **kwargs,
177
235
  )
@@ -16,7 +16,7 @@ from .style_template import (
16
16
  string_formatter,
17
17
  validate_dataframe,
18
18
  )
19
- from .utils import _get_axis
19
+ from .utils import _get_axis, _merge_kwargs
20
20
 
21
21
  __all__ = ["DISTRIBUTION_STYLE_TEMPLATE", "aplot_bar", "fplot_bar"]
22
22
 
@@ -84,15 +84,22 @@ class BarChart(BasePlot):
84
84
  values=self.value,
85
85
  aggfunc="sum",
86
86
  )
87
- pivot_df.plot(kind="bar", stacked=self.stacked, ax=plot_ax, alpha=0.85)
87
+ plot_kwargs: dict[str, Any] = {
88
+ "kind": "bar",
89
+ "stacked": self.stacked,
90
+ "ax": plot_ax,
91
+ "alpha": 0.85,
92
+ }
93
+ pivot_df.plot(**_merge_kwargs(plot_kwargs, kwargs))
88
94
  else:
89
- sns.barplot(
90
- data=self._obj,
91
- x=self.category,
92
- y=self.value,
93
- palette=style.palette,
94
- ax=plot_ax,
95
- )
95
+ barplot_kwargs: dict[str, Any] = {
96
+ "data": self._obj,
97
+ "x": self.category,
98
+ "y": self.value,
99
+ "palette": style.palette,
100
+ "ax": plot_ax,
101
+ }
102
+ sns.barplot(**_merge_kwargs(barplot_kwargs, kwargs))
96
103
 
97
104
  plot_ax.set_facecolor(style.background_color)
98
105
  plot_ax.set_xlabel(string_formatter(self.category))
@@ -129,7 +136,8 @@ class BarChart(BasePlot):
129
136
  facecolor=style.background_color,
130
137
  edgecolor=style.background_color,
131
138
  )
132
- ax = Axes(fig=fig, facecolor=style.background_color)
139
+ ax = fig.add_subplot(111)
140
+ ax.set_facecolor(style.background_color)
133
141
  fig.set_facecolor(style.background_color)
134
142
  self.aplot(title=title, style=style, ax=ax)
135
143
  return fig
@@ -16,7 +16,7 @@ from .style_template import (
16
16
  string_formatter,
17
17
  validate_dataframe,
18
18
  )
19
- from .utils import _get_axis
19
+ from .utils import _get_axis, _merge_kwargs
20
20
 
21
21
  __all__ = ["DISTRIBUTION_STYLE_TEMPLATE", "aplot_box_violin", "fplot_box_violin"]
22
22
 
@@ -82,10 +82,18 @@ class BoxViolinPlot(BasePlot):
82
82
  "palette": style.palette,
83
83
  }
84
84
 
85
+ plot_kwargs: dict[str, Any] = {
86
+ **common_kwargs,
87
+ "hue": self.by,
88
+ "legend": False,
89
+ "ax": plot_ax,
90
+ }
91
+ merged_plot_kwargs = _merge_kwargs(plot_kwargs, kwargs)
92
+
85
93
  if self.violin:
86
- sns.violinplot(**common_kwargs, hue=self.by, legend=False, ax=plot_ax)
94
+ sns.violinplot(**merged_plot_kwargs)
87
95
  else:
88
- sns.boxplot(**common_kwargs, hue=self.by, legend=False, ax=plot_ax)
96
+ sns.boxplot(**merged_plot_kwargs)
89
97
 
90
98
  plot_ax.set_facecolor(style.background_color)
91
99
  plot_ax.set_ylabel(string_formatter(self.column))
@@ -122,11 +130,9 @@ class BoxViolinPlot(BasePlot):
122
130
  facecolor=style.background_color,
123
131
  edgecolor=style.background_color,
124
132
  )
125
- ax = Axes(fig=fig, facecolor=style.background_color)
133
+ ax = fig.add_subplot(111)
134
+ ax.set_facecolor(style.background_color)
126
135
  self.aplot(
127
- column=self.column,
128
- by=self.by,
129
- violin=self.violin,
130
136
  title=title,
131
137
  style=style,
132
138
  ax=ax,
@@ -492,7 +492,8 @@ class Bubble(BasePlot):
492
492
  facecolor=style.background_color,
493
493
  edgecolor=style.background_color,
494
494
  )
495
- ax = Axes(fig=fig, facecolor=style.background_color)
495
+ ax = fig.add_subplot(111)
496
+ ax.set_facecolor(style.background_color)
496
497
 
497
498
  self.aplot(
498
499
  title=title,
@@ -1,24 +1,22 @@
1
1
  """Heatmap and correlation matrix helpers."""
2
2
 
3
- from typing import Any, Optional, Sequence, Tuple
3
+ from typing import Any, Optional, Sequence, Tuple, cast
4
4
 
5
5
  import pandas as pd
6
- from pandas.api.extensions import register_dataframe_accessor
7
- import matplotlib.pyplot as plt
8
6
  import seaborn as sns
9
7
  from matplotlib.axes import Axes
10
8
  from matplotlib.figure import Figure
9
+ from pandas.api.extensions import register_dataframe_accessor
11
10
 
12
11
  from .base_plot import BasePlot
13
-
14
12
  from .style_template import (
15
13
  HEATMAP_STYLE_TEMPLATE,
16
14
  StyleTemplate,
17
15
  string_formatter,
18
16
  validate_dataframe,
19
17
  )
20
- from .utils import _get_axis
21
18
  from .typing import CorrelationMethod
19
+ from .utils import _get_axis, _merge_kwargs
22
20
 
23
21
  __all__ = [
24
22
  "HEATMAP_STYLE_TEMPLATE",
@@ -57,8 +55,14 @@ class Heatmap(BasePlot):
57
55
  ax: Optional[Axes] = None,
58
56
  **kwargs: Any,
59
57
  ) -> Axes:
58
+ """Plot a heatmap on an existing Matplotlib axes."""
60
59
  plot_ax = _get_axis(ax)
61
- sns.heatmap(self._obj, cmap=style.palette, ax=plot_ax)
60
+ heatmap_kwargs: dict[str, Any] = {
61
+ "data": self._obj,
62
+ "cmap": style.palette,
63
+ "ax": plot_ax,
64
+ }
65
+ sns.heatmap(**_merge_kwargs(heatmap_kwargs, kwargs))
62
66
 
63
67
  plot_ax.set_xlabel(string_formatter(self.x))
64
68
  plot_ax.set_ylabel(string_formatter(self.y))
@@ -72,12 +76,14 @@ class Heatmap(BasePlot):
72
76
  style: StyleTemplate = HEATMAP_STYLE_TEMPLATE,
73
77
  figsize: Tuple[float, float] = (10, 6),
74
78
  ) -> Figure:
79
+ """Plot a heatmap on a new Matplotlib figure."""
75
80
  fig = Figure(
76
81
  figsize=figsize,
77
82
  facecolor=style.background_color,
78
83
  edgecolor=style.background_color,
79
84
  )
80
- ax = Axes(fig=fig, facecolor=style.background_color)
85
+ ax = fig.add_subplot(111)
86
+ ax.set_facecolor(style.background_color)
81
87
  self.aplot(title=title, style=style, ax=ax)
82
88
  return fig
83
89
 
@@ -88,14 +94,16 @@ class Heatmap(BasePlot):
88
94
  ax: Optional[Axes] = None,
89
95
  **kwargs: Any,
90
96
  ) -> Axes:
97
+ """Plot a correlation matrix heatmap on existing axes."""
91
98
  plot_ax = _get_axis(ax)
92
- sns.heatmap(
93
- self.correlation_matrix,
94
- cmap=style.palette,
95
- annot=True,
96
- fmt=".2f",
97
- ax=plot_ax,
98
- )
99
+ heatmap_kwargs: dict[str, Any] = {
100
+ "data": self.correlation_matrix,
101
+ "cmap": style.palette,
102
+ "annot": True,
103
+ "fmt": ".2f",
104
+ "ax": plot_ax,
105
+ }
106
+ sns.heatmap(**_merge_kwargs(heatmap_kwargs, kwargs))
99
107
  if title:
100
108
  plot_ax.set_title(title)
101
109
  return plot_ax
@@ -106,13 +114,15 @@ class Heatmap(BasePlot):
106
114
  style: StyleTemplate = HEATMAP_STYLE_TEMPLATE,
107
115
  figsize: Tuple[float, float] = (10, 6),
108
116
  ) -> Figure:
117
+ """Plot a correlation matrix heatmap on a new figure."""
109
118
  fig = Figure(
110
119
  figsize=figsize,
111
120
  facecolor=style.background_color,
112
121
  edgecolor=style.background_color,
113
122
  )
114
- ax = Axes(fig=fig, facecolor=style.background_color)
115
- self.aplot(
123
+ ax = fig.add_subplot(111)
124
+ ax.set_facecolor(style.background_color)
125
+ self.aplot_correlation_matrix(
116
126
  title=title,
117
127
  style=style,
118
128
  ax=ax,
@@ -126,7 +136,7 @@ def _prepare_data(
126
136
  y: str,
127
137
  value: str,
128
138
  ) -> pd.DataFrame:
129
- """Prepare data for treemap plotting."""
139
+ """Prepare data for heatmap plotting."""
130
140
  validate_dataframe(pd_df, cols=[x, y, value])
131
141
  plot_df = pd_df[[x, y, value]].pivot_table(
132
142
  index=y, columns=x, values=value, aggfunc="mean"
@@ -134,6 +144,19 @@ def _prepare_data(
134
144
  return plot_df
135
145
 
136
146
 
147
+ def _compute_correlation_matrix(
148
+ pd_df: pd.DataFrame,
149
+ columns: Optional[Sequence[str]],
150
+ method: CorrelationMethod,
151
+ ) -> pd.DataFrame:
152
+ """Compute a correlation matrix from numeric dataframe columns."""
153
+ source_df = pd_df[list(columns)] if columns else pd_df
154
+ numeric_df = source_df.select_dtypes(include="number")
155
+ if numeric_df.empty:
156
+ raise ValueError("No numeric columns available to compute correlation matrix.")
157
+ return numeric_df.corr(method=cast(Any, method))
158
+
159
+
137
160
  def aplot_heatmap(
138
161
  pd_df: pd.DataFrame,
139
162
  x: str,
@@ -168,19 +191,19 @@ def aplot_correlation_matrix(
168
191
  **kwargs: Any,
169
192
  ) -> Axes:
170
193
  """Plot a correlation matrix heatmap for numeric columns."""
171
- return Heatmap(
172
- pd_df=pd_df,
173
- x="", # Placeholder since correlation matrix is square
174
- y="", # Placeholder since correlation matrix is square
175
- value="", # Placeholder since correlation matrix is computed internally
176
- ).aplot_correlation_matrix(
177
- method=method,
178
- title=title,
179
- style=style,
180
- ax=ax,
181
- columns=columns,
182
- **kwargs,
183
- )
194
+ corr_df = _compute_correlation_matrix(pd_df=pd_df, columns=columns, method=method)
195
+ plot_ax = _get_axis(ax)
196
+ heatmap_kwargs: dict[str, Any] = {
197
+ "data": corr_df,
198
+ "cmap": style.palette,
199
+ "annot": True,
200
+ "fmt": ".2f",
201
+ "ax": plot_ax,
202
+ }
203
+ sns.heatmap(**_merge_kwargs(heatmap_kwargs, kwargs))
204
+ if title:
205
+ plot_ax.set_title(title)
206
+ return plot_ax
184
207
 
185
208
 
186
209
  def fplot_heatmap(
@@ -217,7 +240,7 @@ def fplot_correlation_matrix(
217
240
  """Plot a correlation matrix heatmap on a new figure."""
218
241
  return Heatmap(
219
242
  pd_df=pd_df,
220
- x=x, # Placeholder since correlation matrix is square
221
- y=y, # Placeholder since correlation matrix is square
222
- value=value, # Placeholder since correlation matrix is computed internally
243
+ x=x,
244
+ y=y,
245
+ value=value,
223
246
  ).fplot_correlation_matrix(title=title, style=style, figsize=figsize)
@@ -19,7 +19,7 @@ from .style_template import (
19
19
  string_formatter,
20
20
  validate_dataframe,
21
21
  )
22
- from .utils import _get_axis
22
+ from .utils import _get_axis, _merge_kwargs
23
23
 
24
24
  __all__ = ["DISTRIBUTION_STYLE_TEMPLATE", "aplot_histogram", "fplot_histogram"]
25
25
 
@@ -50,15 +50,16 @@ class Histogram(BasePlot):
50
50
 
51
51
  validate_dataframe(self._obj, cols=[self.column])
52
52
  plot_ax = _get_axis(ax)
53
- sns.histplot(
54
- data=self._obj,
55
- x=self.column,
56
- bins=self.bins,
57
- kde=self.kde,
58
- color=style.font_color,
59
- edgecolor=style.background_color,
60
- ax=plot_ax,
61
- )
53
+ histplot_kwargs: dict[str, Any] = {
54
+ "data": self._obj,
55
+ "x": self.column,
56
+ "bins": self.bins,
57
+ "kde": self.kde,
58
+ "color": style.font_color,
59
+ "edgecolor": style.background_color,
60
+ "ax": plot_ax,
61
+ }
62
+ sns.histplot(**_merge_kwargs(histplot_kwargs, kwargs))
62
63
  plot_ax.set_facecolor(style.background_color)
63
64
  plot_ax.set_xlabel(string_formatter(self.column))
64
65
  plot_ax.set_ylabel("Frequency")
@@ -77,7 +78,8 @@ class Histogram(BasePlot):
77
78
  facecolor=style.background_color,
78
79
  edgecolor=style.background_color,
79
80
  )
80
- ax = Axes(fig=fig, facecolor=style.background_color)
81
+ ax = fig.add_subplot(111)
82
+ ax.set_facecolor(style.background_color)
81
83
  self.aplot(
82
84
  title=title,
83
85
  style=style,
@@ -821,7 +821,8 @@ class NetworkGraph(BasePlot):
821
821
  facecolor=style.background_color,
822
822
  edgecolor=style.background_color,
823
823
  )
824
- ax = Axes(fig=fig, facecolor=style.background_color)
824
+ ax = fig.add_subplot(111)
825
+ ax.set_facecolor(style.background_color)
825
826
  self.aplot(
826
827
  title=title,
827
828
  style=style,
@@ -480,7 +480,7 @@ def fplot_network_node(
480
480
  """
481
481
  fig = cast(Figure, plt.figure(figsize=figsize))
482
482
  fig.set_facecolor(style.background_color)
483
- ax = fig.add_subplot()
483
+ ax = fig.add_subplot(111)
484
484
  ax = aplot_network_node(
485
485
  pd_df,
486
486
  node=node,
@@ -11,7 +11,7 @@ from matplotlib.figure import Figure
11
11
  from .base_plot import BasePlot
12
12
 
13
13
  from .style_template import PIE_STYLE_TEMPLATE, StyleTemplate, validate_dataframe
14
- from .utils import _get_axis
14
+ from .utils import _get_axis, _merge_kwargs
15
15
 
16
16
  __all__ = ["PIE_STYLE_TEMPLATE", "aplot_pie", "fplot_pie"]
17
17
 
@@ -67,14 +67,14 @@ class PieChart(BasePlot):
67
67
  wedgeprops: Optional[Dict[str, Any]] = None
68
68
  if donut:
69
69
  wedgeprops = {"width": 0.3}
70
- plot_ax.pie(
71
- sizes,
72
- labels=labels,
73
- autopct="%1.1f%%",
74
- colors=sns.color_palette(style.palette),
75
- wedgeprops=wedgeprops,
76
- textprops={"color": style.font_color, "fontsize": style.font_size},
77
- )
70
+ pie_kwargs: Dict[str, Any] = {
71
+ "labels": labels,
72
+ "autopct": "%1.1f%%",
73
+ "colors": sns.color_palette(style.palette),
74
+ "wedgeprops": wedgeprops,
75
+ "textprops": {"color": style.font_color, "fontsize": style.font_size},
76
+ }
77
+ plot_ax.pie(sizes, **_merge_kwargs(pie_kwargs, kwargs))
78
78
  plot_ax.axis("equal")
79
79
  if title:
80
80
  plot_ax.set_title(title)
@@ -110,7 +110,8 @@ class PieChart(BasePlot):
110
110
  facecolor=style.background_color,
111
111
  edgecolor=style.background_color,
112
112
  )
113
- ax = Axes(fig=fig, facecolor=style.background_color)
113
+ ax = fig.add_subplot(111)
114
+ ax.set_facecolor(style.background_color)
114
115
  fig.set_facecolor(style.background_color)
115
116
  self.aplot(donut=donut, title=title, style=style, ax=ax)
116
117
  return fig