dataplot 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,92 @@
1
+ """
2
+ Contains a plotter class: QQPlot.
3
+
4
+ NOTE: this module is private. All functions and objects are available in the main
5
+ `dataplot` namespace - use that instead.
6
+
7
+ """
8
+
9
+ from typing import TYPE_CHECKING
10
+
11
+ import numpy as np
12
+ from scipy import stats
13
+ from validating import dataclass
14
+
15
+ from .._typing import DistName
16
+ from ..setting import PlotSettable
17
+ from ..utils.math import get_quantile, linear_regression_1d
18
+ from .base import Plotter
19
+
20
+ if TYPE_CHECKING:
21
+ from ..container import AxesWrapper
22
+ from ..dataset import PlotDataSet
23
+
24
+ __all__ = ["QQPlot"]
25
+
26
+
27
+ @dataclass(validate_methods=True)
28
+ class QQPlot(Plotter):
29
+ """
30
+ A plotter class that creates a Q-Q plot.
31
+
32
+ """
33
+
34
+ dist_or_sample: "DistName | np.ndarray | PlotDataSet"
35
+ dots: int
36
+ edge_precision: float
37
+ fmt: str
38
+
39
+ def paint(self, ax: "AxesWrapper", **_) -> None:
40
+ ax.set_default(
41
+ title="Quantile-Quantile Plot",
42
+ xlabel="quantiles",
43
+ ylabel="quantiles",
44
+ )
45
+ ax.load(self.settings)
46
+ self.__plot(ax)
47
+
48
+ def __plot(self, ax: "AxesWrapper") -> None:
49
+ xlabel, p, q1 = self._generate_dist()
50
+ q2 = get_quantile(self.data, p)
51
+ ax.ax.plot(q1, q2, self.fmt, zorder=2.1, label=f"{self.label} & {xlabel}")
52
+ self._plot_fitted_line(ax, q1, q2)
53
+
54
+ def _generate_dist(self) -> tuple[str, np.ndarray, np.ndarray]:
55
+ if not 0 <= self.edge_precision < 0.5:
56
+ raise ValueError(
57
+ "'edge_precision' should be on the interval [0, 0.5), got "
58
+ f"{self.edge_precision} instead"
59
+ )
60
+ p = np.linspace(self.edge_precision, 1 - self.edge_precision, self.dots)
61
+ if isinstance(x := self.dist_or_sample, str):
62
+ xlabel = x + "-distribution"
63
+ q = self._get_ppf(x, p)
64
+ elif isinstance(x, PlotSettable):
65
+ xlabel = x.formatted_label()
66
+ q = get_quantile(x.data, p)
67
+ elif isinstance(x, (list, np.ndarray)):
68
+ xlabel = "sample"
69
+ q = get_quantile(x, p)
70
+ else:
71
+ raise TypeError(
72
+ f"'dist_or_sample' can not be instance of {x.__class__.__name__!r}"
73
+ )
74
+ return xlabel, p, q
75
+
76
+ @staticmethod
77
+ def _plot_fitted_line(ax: "AxesWrapper", x: np.ndarray, y: np.ndarray) -> None:
78
+ a, b = linear_regression_1d(y, x)
79
+ l, r = x.min(), x.max()
80
+ ax.ax.plot(
81
+ [l, r], [a + l * b, a + r * b], "--", label=f"y = {a:.3f} + {b:.3f}x"
82
+ )
83
+
84
+ @staticmethod
85
+ def _get_ppf(dist: str, p: np.ndarray) -> np.ndarray:
86
+ match dist:
87
+ case "normal":
88
+ return stats.norm.ppf(p)
89
+ case "expon":
90
+ return stats.expon.ppf(p)
91
+ case _:
92
+ raise ValueError(f"no such distribution: {dist!r}")
@@ -0,0 +1,59 @@
1
+ """
2
+ Contains a plotter class: ScatterChart.
3
+
4
+ NOTE: this module is private. All functions and objects are available in the main
5
+ `dataplot` namespace - use that instead.
6
+
7
+ """
8
+
9
+ from typing import TYPE_CHECKING, Optional
10
+
11
+ import numpy as np
12
+ from validating import dataclass
13
+
14
+ from ..setting import PlotSettable
15
+ from .base import Plotter
16
+
17
+ if TYPE_CHECKING:
18
+ from ..container import AxesWrapper
19
+ from ..dataset import PlotDataSet
20
+
21
+ __all__ = ["ScatterChart"]
22
+
23
+
@dataclass(validate_methods=True)
class ScatterChart(Plotter):
    """
    A plotter class that creates a scatter chart.

    """

    # X positions of the points: an array, a dataset (its `.data` is used), or
    # None to use 0, 1, 2, ... as positions.
    xticks: Optional["np.ndarray | PlotDataSet"]
    # Format string handed to matplotlib for the markers.
    fmt: str
    # Whether to order the points by their x position before plotting.
    sorted: bool

    def paint(self, ax: "AxesWrapper", **_) -> None:
        """
        Draw the scatter chart on the given axes.

        A default title is registered on `ax`, the plotter's settings are
        loaded, and the points are plotted. Extra keyword arguments are
        accepted and ignored.

        """
        ax.set_default(title="Scatter Chart")
        ax.load(self.settings)
        self.__plot(ax)

    def __plot(self, ax: "AxesWrapper") -> None:
        """
        Resolve the x positions, optionally sort the points, and plot them.

        Raises
        ------
        ValueError
            If explicit x-ticks are given whose length differs from the data.

        """
        if isinstance(self.xticks, PlotSettable):
            xticks = self.xticks.data
        else:
            xticks = self.xticks

        if xticks is None:
            xticks = range(len(self.data))
        elif (len_t := len(xticks)) != (len_d := len(self.data)):
            raise ValueError(
                "x-ticks and data must have the same length, but have "
                f"lengths {len_t} and {len_d}"
            )

        data = self.data
        # Sorting an empty dataset must be skipped: `zip(*[])` yields nothing,
        # so unpacking it into two names would raise an unrelated ValueError.
        # `len()` is used instead of truthiness because the truth value of an
        # array is ambiguous.
        if self.sorted and len(data) > 0:
            paired = sorted(zip(xticks, data, strict=True), key=lambda pair: pair[0])
            xticks, data = zip(*paired, strict=True)

        # linestyle="None" keeps the markers unconnected regardless of `fmt`.
        ax.ax.plot(xticks, data, self.fmt, linestyle="None", label=self.label)
dataplot/container.py ADDED
@@ -0,0 +1,203 @@
1
+ """
2
+ Contains container classes: FigWrapper and AxesWrapper.
3
+
4
+ NOTE: this module is private. All functions and objects are available in the main
5
+ `dataplot` namespace - use that instead.
6
+
7
+ """
8
+
9
+ import logging
10
+ from typing import TYPE_CHECKING, Any, Self, Unpack
11
+
12
+ import matplotlib.pyplot as plt
13
+ import numpy as np
14
+ from matplotlib.figure import Figure
15
+ from matplotlib.pyplot import Axes
16
+ from validating import attr, dataclass
17
+
18
+ from ._typing import AxesSettingDict, FigureSettingDict, SettingKey
19
+ from .setting import PlotSettable
20
+
21
+ if TYPE_CHECKING:
22
+ from .artist import Artist
23
+
24
+ __all__ = ["FigWrapper", "AxesWrapper"]
25
+
26
+
@dataclass(validate_methods=True)
class AxesWrapper(PlotSettable):
    """
    Wraps a matplotlib axes object and carries the settings applied to it.

    Instances are not meant to be created directly; they are produced when a
    `FigWrapper` is entered and are exposed through `FigWrapper.axes`.

    """

    # The wrapped matplotlib axes.
    ax: Axes

    def set_axes(self, **kwargs: Unpack[AxesSettingDict]) -> None:
        """
        Update the settings of the axes in place.

        Parameters
        ----------
        title : str, optional
            Title of the axes. Note that `.set_figure()` has a parameter with
            the same name.
        xlabel : str, optional
            Label for the x-axis.
        ylabel : str, optional
            Label for the y-axis.
        alpha : float, optional
            Transparency of the plotted elements, a float between 0
            (completely transparent) and 1 (completely opaque).
        grid : bool, optional
            Whether to show the grid.
        grid_alpha : float, optional
            Transparency of the grid.
        fontdict : FontDict, optional
            A dictionary controlling the appearance of the title text.
        legend_loc : LegendLoc, optional
            Location of the legend.

        """
        self._set(inplace=True, **kwargs)

    def exit(self) -> None:
        """
        Apply the stored settings to the wrapped axes.

        Sets the axis labels, the legend (only when at least one labelled
        artist exists), the grid and the title. Intended to be called only
        by `FigWrapper.__exit__()`.

        """
        axes = self.ax
        axes.set_xlabel(self.settings.xlabel)
        axes.set_ylabel(self.settings.ylabel)

        # Only draw a legend when there is something to list in it.
        handles, _ = axes.get_legend_handles_labels()
        if handles:
            axes.legend(loc=self.settings.legend_loc)

        if not self.get_setting("grid", True):
            axes.grid(False)
        else:
            # The grid defaults to half the opacity of the plotted elements.
            base_alpha = self.get_setting("alpha", 1.0)
            axes.grid(alpha=self.get_setting("grid_alpha", base_alpha / 2))

        title_style = self.get_setting("fontdict", {})
        axes.set_title(self.settings.title, **title_style)
84
+
85
+
@dataclass(validate_methods=True)
class FigWrapper(PlotSettable):
    """
    A wrapper of figure.

    Note that this should NEVER be instantiated directly, but always through the
    module-level function `dataplot.figure()`.

    """

    # Grid shape passed to `plt.subplots()`.
    nrows: int = 1
    ncols: int = 1
    # When False, `__enter__()` and `__exit__()` do nothing.
    active: bool = attr(repr=False, default=True)
    # True between `__enter__()` and the end of `__exit__()`.
    entered: bool = attr(init=False, repr=False, default=False)
    # Created by `__enter__()`; not available before the wrapper is entered.
    fig: Figure = attr(init=False, repr=False)
    axes: list[AxesWrapper] = attr(init=False, repr=False)
    # Artists painted, one per axes, by `__repr__()`.
    artists: "list[Artist]" = attr(default_factory=list, init=False, repr=False)

    def __enter__(self) -> Self:
        """
        Create subplots and set the style.

        Returns
        -------
        Self
            An instance of self.

        Raises
        ------
        DoubleEnteredError
            If the wrapper is entered again before it has been exited.

        """
        if not self.active:
            return self
        if self.entered:
            raise DoubleEnteredError(
                f"can't enter an instance of {self.__class__.__name__!r} for twice; "
                "please do all the operations in one single context manager"
            )

        self.set_default(
            style="seaborn-v0_8-darkgrid",
            figsize=(10 * self.ncols, 5 * self.nrows),
            subplots_adjust={"hspace": 0.5},
            fontdict={"fontsize": "x-large"},
        )
        plt.style.use(self.settings.style)
        self.fig, axes = plt.subplots(self.nrows, self.ncols)
        # `plt.subplots()` returns a single axes, a 1-D array or a 2-D array
        # depending on the grid shape; flatten all of them to one list.
        self.axes: list[AxesWrapper] = [AxesWrapper(x) for x in np.reshape(axes, -1)]
        self.entered = True
        return self

    def __exit__(self, *args) -> None:
        """
        Set various properties for the figure and paint it.

        The figure is closed, the matplotlib style is restored to "default"
        and the wrapper is marked as exited even if finalizing or showing the
        figure raises, so a failure cannot leave the wrapper stuck in the
        "entered" state or leak the figure and the global style.

        """
        if not self.active:
            return

        try:
            if len(self.axes) > 1:
                self.fig.suptitle(self.settings.title, **self.settings.fontdict)
            else:
                self.axes[0].ax.set_title(
                    self.settings.title, **self.settings.fontdict
                )

            self.fig.set_size_inches(*self.settings.figsize)
            self.fig.subplots_adjust(**self.settings.subplots_adjust)
            self.fig.set_dpi(self.get_setting("dpi", 100))

            for ax in self.axes:
                ax.exit()
                # Drop subplots nothing was drawn on.
                if not ax.ax.has_data():
                    self.fig.delaxes(ax.ax)

            plt.show()
        finally:
            plt.close(self.fig)
            plt.style.use("default")
            self.entered = False

    def __repr__(self) -> str:
        """
        Paint the registered artists and return a short description.

        NOTE: this has side effects. When the wrapper is active it is entered,
        every artist is painted on its own axes, and the figure is shown on
        exit. An inactive wrapper has no axes to paint on, so only the
        description is returned.

        """
        if self.active:
            with self as fig:
                # zip() stops at the shorter sequence, so surplus axes or
                # surplus artists are simply left out.
                for artist, ax in zip(self.artists, fig.axes):
                    artist.paint(ax)
        return f"<{self.__class__.__name__}(nrows={self.nrows}, ncols={self.ncols})>"

    def set_figure(self, **kwargs: Unpack[FigureSettingDict]) -> None:
        """
        Set the settings of figure.

        Parameters
        ----------
        title : str, optional
            Title of figure. Please note that there's another parameter with
            the same name in `AxesWrapper.set_axes()`.
        dpi : float, optional
            Sets the resolution of figure in dots-per-inch.
        style : StyleName, optional
            A style specification.
        figsize : tuple[int, int], optional
            Figure size, this takes a tuple of two integers that specifies the
            width and height of the figure in inches.
        fontdict : FontDict, optional
            A dictionary controlling the appearance of the title text.
        subplots_adjust : SubplotDict, optional
            Adjusts the subplot layout parameters including: left, right, bottom,
            top, wspace, and hspace. See `SubplotDict` for more details.

        """
        self._set(inplace=True, **kwargs)

    def setting_check(self, key: SettingKey, value: Any) -> None:
        """
        Warn when a setting is changed too late to take effect.

        The style is applied in `__enter__()`, so changing it while the
        wrapper is entered is reported through a logging warning.

        """
        if self.entered and key == "style":
            logging.warning(
                "setting the '%s' of a figure has no effect unless it's done "
                "before invoking context manager",
                key,
            )
200
+
201
+
class DoubleEnteredError(Exception):
    """Raised when a `FigWrapper` is entered again while it is already entered."""
dataplot/core.py ADDED
@@ -0,0 +1,202 @@
1
+ """
2
+ Contains the core of dataplot: figure(), data(), show(), etc.
3
+
4
+ NOTE: this module is private. All functions and objects are available in the main
5
+ `dataplot` namespace - use that instead.
6
+
7
+ """
8
+
9
+ import dis
10
+ import re
11
+ import sys
12
+ from math import ceil, sqrt
13
+ from typing import TYPE_CHECKING, Any, Optional, Unpack
14
+
15
+ import numpy as np
16
+
17
+ from ._typing import FigureSettingDict
18
+ from .container import FigWrapper
19
+ from .dataset import PlotDataSet, PlotDataSets
20
+
21
+ if TYPE_CHECKING:
22
+ from .artist import Artist
23
+
24
+
25
+ __all__ = ["data", "figure"]
26
+
27
+
28
+ def _infer_var_names(*values: Any) -> list[Optional[str]]:
29
+ try:
30
+ search_frame = sys._getframe(1)
31
+ except ValueError:
32
+ return [None] * len(values)
33
+
34
+ labels: list[Optional[str]] = []
35
+ try:
36
+ for value in values:
37
+ name = None
38
+ current = search_frame
39
+ while current is not None and name is None:
40
+ local_items = list(current.f_locals.items())
41
+ global_items = list(current.f_globals.items())
42
+ name = next((k for k, v in local_items if v is value), None)
43
+ if name is None:
44
+ name = next((k for k, v in global_items if v is value), None)
45
+ current = current.f_back
46
+ labels.append(name)
47
+ finally:
48
+ del search_frame
49
+ return labels
50
+
51
+
52
+ def _infer_assigned_name() -> Optional[str]:
53
+ """Try inferring assignment target name from call-site."""
54
+ try:
55
+ frame = sys._getframe(2)
56
+ except ValueError:
57
+ return None
58
+
59
+ # Bytecode inspection works in REPL/notebook contexts where source code file
60
+ # is unavailable.
61
+ try:
62
+ instructions = list(dis.get_instructions(frame.f_code))
63
+ store_index = next(
64
+ (
65
+ i
66
+ for i, ins in enumerate(instructions)
67
+ if ins.offset > frame.f_lasti
68
+ and ins.opname in {"STORE_NAME", "STORE_FAST", "STORE_GLOBAL"}
69
+ ),
70
+ None,
71
+ )
72
+ if store_index is not None:
73
+ # `a = b = data(...)` compiles to STORE_* b then STORE_* a;
74
+ # returning the last one better matches user expectation.
75
+ last_store = store_index
76
+ while last_store + 1 < len(instructions) and instructions[
77
+ last_store + 1
78
+ ].opname in {"STORE_NAME", "STORE_FAST", "STORE_GLOBAL"}:
79
+ last_store += 1
80
+ return str(instructions[last_store].argval)
81
+ except Exception:
82
+ pass
83
+
84
+ try:
85
+ context = frame.f_code.co_filename
86
+ lineno = frame.f_lineno
87
+ with open(context, "r", encoding="utf-8") as f:
88
+ lines = f.readlines()
89
+ line = lines[lineno - 1].strip()
90
+ except (OSError, IndexError):
91
+ return None
92
+ finally:
93
+ del frame
94
+
95
+ if "=" not in line:
96
+ return None
97
+ lhs = line.split("=", 1)[0].strip()
98
+ if not lhs:
99
+ return None
100
+ m = re.match(r"([A-Za-z_][A-Za-z0-9_]*)", lhs)
101
+ return m.group(1) if m else None
102
+
103
+
def data(*x: Any, label: Optional[str | list[str]] = None) -> PlotDataSet:
    """
    Create a dataset interface offering mathematical operations and plotting.

    Parameters
    ----------
    *x : np.ndarray | Any
        One or more array-likes, each of them representing one dataset.
    label : str | list[str], optional
        Label(s) of the data. For a single dataset this must be a string; for
        several datasets it must be a list of strings with one entry per
        dataset. If None, the labels are inferred from the variable names at
        the call-site where possible, falling back to "x{i}"
        (i = 1, 2, 3, ...). By default None.

    Returns
    -------
    PlotDataSet
        Provides methods for mathematical operations and plotting. When
        several datasets are given, a `PlotDataSets` combining them is
        returned.

    Raises
    ------
    ValueError
        If no dataset is given, or if `label` does not fit the number of
        datasets.

    """
    if not x:
        raise ValueError("at least one dataset should be provided")

    if len(x) == 1:
        if isinstance(label, list):
            raise ValueError(
                "it seems not necessary to provide a list of labels, since "
                "the data has only one dimension"
            )
        if label is None:
            # Prefer the assignment target, then the argument's own name.
            label = _infer_assigned_name() or _infer_var_names(x[0])[0] or "x1"
        return PlotDataSet(np.array(x[0]), label=label)

    if label is None:
        inferred = _infer_var_names(*x)
        label = [
            name if name is not None else f"x{position}"
            for position, name in enumerate(inferred, start=1)
        ]
    elif isinstance(label, str):
        raise ValueError(
            "for multiple datasets, please provide labels as a list of strings"
        )
    elif len(label) != len(x):
        raise ValueError(
            f"label should have the same length as x ({len(x)}), got {len(label)}"
        )

    datasets = [PlotDataSet(np.array(values), name) for values, name in zip(x, label)]
    return PlotDataSets(*datasets)
154
+
155
+
def figure(
    artist: "Artist | list[Artist]",
    nrows: int | None = None,
    ncols: int | None = None,
    **kwargs: Unpack[FigureSettingDict],
) -> FigWrapper:
    """
    Provides a context manager interface (`__enter__` and `__exit__` methods) for
    creating a figure with subplots and setting various properties for the figure.

    Parameters
    ----------
    artist : Artist | list[Artist]
        Artist or list of artists.
    nrows : int, optional
        Determines how many subplots can be arranged vertically in the figure,
        If None, will be automatically set according to ``len(artist)``. By default
        None.
    ncols : int, optional
        Determines how many subplots can be arranged horizontally in the figure.
        If None, will be automatically set according to ``len(artist)``. By default
        None.
    **kwargs : **FigureSettingDict
        Specifies the figure settings, see `FigWrapper.set_figure()` for more details.

    Returns
    -------
    FigWrapper
        A wrapper of figure.

    """
    artists = artist if isinstance(artist, list) else [artist]
    # An empty list of artists still gets a 1 x 1 grid.
    len_a = max(len(artists), 1)

    if nrows is None and ncols is None:
        # Roughly square layout, with at least as many rows as columns.
        ncols = int(sqrt(len_a))
        nrows = ceil(len_a / ncols)
    elif ncols is None:
        # Only the row count is given: cap it at the number of artists and
        # derive the columns from it.
        nrows = min(nrows, len_a)
        ncols = ceil(len_a / nrows)
    else:
        # The column count is given: cap it and derive the rows from it.
        # When both counts are given, `ncols` takes precedence and `nrows`
        # is recomputed.
        ncols = min(ncols, len_a)
        nrows = ceil(len_a / ncols)

    figw = FigWrapper(nrows=nrows, ncols=ncols)
    figw.set_figure(**kwargs)
    figw.artists = artists
    return figw