dataplot 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dataplot/dataset.py ADDED
@@ -0,0 +1,968 @@
1
+ """
2
+ Contains the dataset interface: PlotDataSet.
3
+
4
+ NOTE: this module is private. All functions and objects are available in the main
5
+ `dataplot` namespace - use that instead.
6
+
7
+ """
8
+
9
+ from abc import ABCMeta
10
+ from functools import partial
11
+ from typing import (
12
+ TYPE_CHECKING,
13
+ Any,
14
+ Callable,
15
+ Literal,
16
+ Optional,
17
+ Self,
18
+ Unpack,
19
+ overload,
20
+ )
21
+
22
+ import numpy as np
23
+ import pandas as pd
24
+ from validating import attr, dataclass
25
+
26
+ from ._typing import DistName, ResampleRule, SettingDict
27
+ from .artist import (
28
+ Artist,
29
+ CorrMap,
30
+ Histogram,
31
+ KSPlot,
32
+ LineChart,
33
+ PPPlot,
34
+ QQPlot,
35
+ ScatterChart,
36
+ )
37
+ from .setting import PlotSettable, PlotSettings
38
+ from .utils.multi import (
39
+ REMAIN,
40
+ UNSUBSCRIPTABLE,
41
+ MultiObject,
42
+ multipartial,
43
+ single,
44
+ )
45
+
46
+ if TYPE_CHECKING:
47
+ from .artist import Plotter
48
+ from .container import AxesWrapper
49
+
50
+
51
# Only the single-dataset class is public; `PlotDataSets` is an internal
# helper produced by `PlotDataSet.join()`.
__all__: list[str] = ["PlotDataSet"]
52
+
53
+
54
+ @dataclass(validate_methods=True)
55
+ class PlotDataSet(PlotSettable, metaclass=ABCMeta):
56
+ """
57
+ A dataset class providing methods for mathematical operations and plotting.
58
+
59
+ Note that this should NEVER be instantiated directly, but always through the
60
+ module-level function `dataplot.data()`.
61
+
62
+ Parameters
63
+ ----------
64
+ data : np.ndarray
65
+ Input data.
66
+ label : str, optional
67
+ Label of the data. If set to None, use "x" as the label. By default None.
68
+
69
+ Properties
70
+ ----------
71
+ fmt : str
72
+ A string recording the mathmatical operations done on the data.
73
+ original_data : np.ndarray
74
+ Original input data.
75
+ settings : PlotSettings
76
+ Settings for plot (whether a figure or an axes).
77
+ priority : int
78
+ Priority of the latest mathmatical operation, where:
79
+ 0 : Highest priority, refering to `repr()` and some of unary operations;
80
+ 10 : Refers to binary operations that are prior to / (e.g., **);
81
+ 19 : Particularly refers to /;
82
+ 20 : Particularly refers to *;
83
+ 29 : Particularly refers to binary -;
84
+ 30 : Particularly refers to +;
85
+ 40 : Particularly refers to unary -.
86
+ Note that / and binary - are distinguished from * or + because the
87
+ former ones disobey the associative law.
88
+
89
+ """
90
+
91
+ data: np.ndarray
92
+ label: Optional[str] = attr(default=None)
93
+ fmtb: str = attr(init=False, default="{0}")
94
+ original_data: np.ndarray = attr(init=False)
95
+ settings: PlotSettings = attr(init=False, default_factory=PlotSettings)
96
+ priority: int = attr(init=False, default=0)
97
+
98
+ @classmethod
99
+ def __subclasshook__(cls, __subclass: type) -> bool:
100
+ if __subclass is PlotDataSet or issubclass(__subclass, PlotDataSets):
101
+ return True
102
+ return False
103
+
104
+ def __post_init__(self) -> None:
105
+ self.label = "x" if self.label is None else self.label
106
+ self.original_data = self.data
107
+
108
+ def __create(
109
+ self, fmt: str, data: np.ndarray, priority: int = 0, label: Optional[str] = None
110
+ ) -> Self:
111
+ obj = self.customize(
112
+ self.__class__,
113
+ self.original_data,
114
+ self.label if label is None else label,
115
+ fmtb=fmt,
116
+ priority=priority,
117
+ )
118
+ obj.data = data
119
+ return obj
120
+
121
+ def __repr__(self) -> str:
122
+ return self.__class__.__name__ + "\n- " + self.data_info()
123
+
124
+ def data_info(self) -> str:
125
+ """
126
+ Information of dataset.
127
+
128
+ Returns
129
+ -------
130
+ str
131
+ A string indicating the data label and the plot settings.
132
+
133
+ """
134
+ not_none = self.settings.repr_not_none()
135
+ return f"{self.formatted_label()}{': ' if not_none else ''}{not_none}"
136
+
137
+ def __getitem__(self, __key: int) -> Self | Any:
138
+ return UNSUBSCRIPTABLE
139
+
140
+ def __neg__(self) -> Self:
141
+ new_fmt = f"(-{self.__remove_brackets(self.fmtb, priority=28)})"
142
+ new_data = -self.data
143
+ return self.__create(new_fmt, new_data, priority=40)
144
+
145
+ def __add__(self, __other: "float | int | PlotDataSet") -> Self:
146
+ return self.__binary_operation(__other, "+", np.add, priority=30)
147
+
148
+ def __radd__(self, __other: "float | int | PlotDataSet") -> Self:
149
+ return self.__binary_operation(__other, "+", np.add, reverse=True, priority=30)
150
+
151
+ def __sub__(self, __other: "float | int | PlotDataSet") -> Self:
152
+ return self.__binary_operation(__other, "-", np.subtract, priority=29)
153
+
154
+ def __rsub__(self, __other: "float | int | PlotDataSet") -> Self:
155
+ return self.__binary_operation(
156
+ __other, "-", np.subtract, reverse=True, priority=29
157
+ )
158
+
159
+ def __mul__(self, __other: "float | int | PlotDataSet") -> Self:
160
+ return self.__binary_operation(__other, "*", np.multiply, priority=20)
161
+
162
+ def __rmul__(self, __other: "float | int | PlotDataSet") -> Self:
163
+ return self.__binary_operation(
164
+ __other, "*", np.multiply, reverse=True, priority=20
165
+ )
166
+
167
+ def __truediv__(self, __other: "float | int | PlotDataSet") -> Self:
168
+ return self.__binary_operation(__other, "/", np.true_divide, priority=19)
169
+
170
+ def __rtruediv__(self, __other: "float | int | PlotDataSet") -> Self:
171
+ return self.__binary_operation(
172
+ __other, "/", np.true_divide, reverse=True, priority=19
173
+ )
174
+
175
+ def __pow__(self, __other: "float | int | PlotDataSet") -> Self:
176
+ return self.__binary_operation(__other, "**", np.power)
177
+
178
+ def __rpow__(self, __other: "float | int | PlotDataSet") -> Self:
179
+ return self.__binary_operation(__other, "**", np.power, reverse=True)
180
+
181
+ def __binary_operation(
182
+ self,
183
+ other: "float | int | PlotDataSet | Any",
184
+ sign: str,
185
+ func: Callable[[Any, Any], np.ndarray],
186
+ reverse: bool = False,
187
+ priority: int = 10,
188
+ ) -> Self:
189
+ if reverse:
190
+ this_fmt = self.__remove_brackets(self.fmtb, priority=priority)
191
+ new_fmt = f"({other}{sign}{this_fmt})"
192
+ new_data = func(other, self.data)
193
+ return self.__create(new_fmt, new_data, priority=priority)
194
+
195
+ this_fmt = self.__remove_brackets(self.fmtb, priority=priority + 1)
196
+ if isinstance(other, (float, int)):
197
+ new_fmt = f"({this_fmt}{sign}{other})"
198
+ new_data = func(self.data, other)
199
+ elif isinstance(other, PlotDataSet):
200
+ other_label = other.formatted_label(priority=priority)
201
+ new_fmt = f"({this_fmt}{sign}{other_label})"
202
+ new_data = func(self.data, other.data)
203
+ else:
204
+ raise ValueError(
205
+ f"{sign!r} not supported between instances of 'PlotDataSet' and "
206
+ f"{other.__class__.__name__!r}"
207
+ )
208
+ return self.__create(new_fmt, new_data, priority=priority)
209
+
210
+ def __remove_brackets(self, string: str, priority: int = 0):
211
+ if priority == 0 or self.priority <= priority:
212
+ if string.startswith("(") and string.endswith(")"):
213
+ return string[1:-1]
214
+ return string
215
+
216
+ @property
217
+ def format(self) -> str:
218
+ """
219
+ Return the label format.
220
+
221
+ Returns
222
+ -------
223
+ str
224
+ Label format.
225
+
226
+ """
227
+ return self.__remove_brackets(self.fmtb)
228
+
229
+ def formatted_label(self, priority: int = 0) -> str:
230
+ """
231
+ Return the formatted label, but remove the pair of brackets at both ends
232
+ of the string if neccessary.
233
+
234
+ Parameters
235
+ ----------
236
+ priority : int, optional
237
+ Indicates whether to remove the brackets, by default 0.
238
+
239
+ Returns
240
+ -------
241
+ str
242
+ Formatted label.
243
+
244
+ """
245
+ if priority == self.priority and priority in (19, 29):
246
+ priority -= 1
247
+ return self.__remove_brackets(self.fmtb.format(self.label), priority=priority)
248
+
249
+ def join(self, *others: "PlotDataSet") -> Self:
250
+ """
251
+ Merge two or more `PlotDataSet` instances.
252
+
253
+ Parameters
254
+ ----------
255
+ *others : PlotDataSet
256
+ The instances to be merged.
257
+
258
+ Returns
259
+ -------
260
+ Self
261
+ A new instance of self.__class__.
262
+
263
+ """
264
+ return PlotDataSets(self, *others)
265
+
266
+ def resample(self, n: int, rule: ResampleRule = "head") -> Self:
267
+ """
268
+ Resample from the data.
269
+
270
+ Parameters
271
+ ----------
272
+ n : int
273
+ Length of new sample.
274
+ rule : ResampleRule, optional
275
+ Resample rule, by default "head".
276
+
277
+ Returns
278
+ -------
279
+ Self
280
+ A new instance of self.__class__.
281
+
282
+ Raises
283
+ ------
284
+ ValueError
285
+ Raised when receiving illegal rule.
286
+
287
+ """
288
+ new_fmt = f"resample({self.format}, {n})"
289
+ match rule:
290
+ case "random":
291
+ idx = np.random.randint(0, len(self.data), n)
292
+ new_data = self.data[idx]
293
+ case "head":
294
+ new_data = self.data[:n]
295
+ case "tail":
296
+ new_data = self.data[-n:]
297
+ case _:
298
+ raise ValueError(f"rule not supported: {rule!r}")
299
+ return self.__create(new_fmt, new_data)
300
+
301
+ def log(self) -> Self:
302
+ """
303
+ Perform a log operation on the data.
304
+
305
+ Returns
306
+ -------
307
+ Self
308
+ A new instance of self.__class__.
309
+
310
+ """
311
+ new_fmt = f"log({self.format})"
312
+ new_data = np.log(np.where(self.data > 0, self.data, np.nan))
313
+ return self.__create(new_fmt, new_data)
314
+
315
+ def log10(self) -> Self:
316
+ """
317
+ Perform a log operation on the data (with base 10).
318
+
319
+ Returns
320
+ -------
321
+ Self
322
+ A new instance of self.__class__.
323
+
324
+ """
325
+ new_fmt = f"log10({self.format})"
326
+ new_data = np.log10(np.where(self.data > 0, self.data, np.nan))
327
+ return self.__create(new_fmt, new_data)
328
+
329
+ def signedlog(self) -> Self:
330
+ """
331
+ Perform a log operation on the data, but keep the sign.
332
+
333
+ signedlog(x) =
334
+
335
+ * log(x), for x > 0;
336
+ * 0, for x = 0;
337
+ * -log(-x), for x < 0.
338
+
339
+ Returns
340
+ -------
341
+ Self
342
+ A new instance of self.__class__.
343
+
344
+ """
345
+ new_fmt = f"signedlog({self.format})"
346
+ new_data = np.log(np.where(self.data > 0, self.data, np.nan))
347
+ new_data[self.data < 0] = np.log(-self.data[self.data < 0])
348
+ new_data[self.data == 0] = 0
349
+ return self.__create(new_fmt, new_data)
350
+
351
+ def signedpow(self, n: float) -> Self:
352
+ """
353
+ Perform a power operation on the data, but keep the sign.
354
+
355
+ signedpow(x, n) =
356
+
357
+ * x**n, for x > 0;
358
+ * 0, for x = 0;
359
+ * -x**(-n) for x < 0.
360
+
361
+ Returns
362
+ -------
363
+ Self
364
+ A new instance of self.__class__.
365
+
366
+ """
367
+ new_fmt = f"signedpow({self.format})"
368
+ new_data = np.where(self.data > 0, self.data, np.nan) ** n
369
+ new_data[self.data < 0] = -((-self.data[self.data < 0]) ** n)
370
+ new_data[self.data == 0] = 0
371
+ return self.__create(new_fmt, new_data)
372
+
373
+ def rolling(self, n: int) -> Self:
374
+ """
375
+ Perform a rolling-mean operation on the data.
376
+
377
+ Parameters
378
+ ----------
379
+ n : int
380
+ Specifies the window size for calculating the rolling average of
381
+ the data points.
382
+
383
+ Returns
384
+ -------
385
+ Self
386
+ A new instance of self.__class__.
387
+
388
+ """
389
+ new_fmt = f"rolling({self.format}, {n})"
390
+ new_data = pd.Series(self.data).rolling(n).mean().values
391
+ return self.__create(new_fmt, new_data)
392
+
393
+ def exp(self) -> Self:
394
+ """
395
+ Perform an exp operation on the data.
396
+
397
+ Returns
398
+ -------
399
+ Self
400
+ A new instance of self.__class__.
401
+
402
+ """
403
+ new_fmt = f"exp({self.format})"
404
+ new_data = np.exp(self.data)
405
+ return self.__create(new_fmt, new_data)
406
+
407
+ def abs(self) -> Self:
408
+ """
409
+ Perform an abs operation on the data.
410
+
411
+ Returns
412
+ -------
413
+ Self
414
+ A new instance of self.__class__.
415
+
416
+ """
417
+ new_fmt = f"abs({self.format})"
418
+ new_data = np.abs(self.data)
419
+ return self.__create(new_fmt, new_data)
420
+
421
+ def demean(self) -> Self:
422
+ """
423
+ Perform a demean operation on the data by subtracting its mean.
424
+
425
+ Returns
426
+ -------
427
+ Self
428
+ A new instance of self.__class__.
429
+
430
+ """
431
+ new_fmt = f"({self.format}-mean({self.format}))"
432
+ new_data = self.data - np.nanmean(self.data)
433
+ return self.__create(new_fmt, new_data)
434
+
435
+ def zscore(self) -> Self:
436
+ """
437
+ Perform a zscore operation on the data by subtracting its mean and then
438
+ dividing by its standard deviation.
439
+
440
+ Returns
441
+ -------
442
+ Self
443
+ A new instance of self.__class__.
444
+
445
+ """
446
+ new_fmt = f"zscore({self.format})"
447
+ new_data = (self.data - np.nanmean(self.data)) / np.nanstd(self.data)
448
+ return self.__create(new_fmt, new_data)
449
+
450
+ def cumsum(self) -> Self:
451
+ """
452
+ Perform a cumsum operation on the data by calculating its cummulative
453
+ sums.
454
+
455
+ Returns
456
+ -------
457
+ Self
458
+ A new instance of self.__class__.
459
+
460
+ """
461
+ new_fmt = f"csum({self.format})"
462
+ new_data = np.cumsum(self.data)
463
+ return self.__create(new_fmt, new_data)
464
+
465
+ def copy(self) -> Self:
466
+ return self.__create(self.fmtb, self.data, priority=self.priority)
467
+
468
+ def reset(self) -> Self:
469
+ """
470
+ Copy and reset the plot settings.
471
+
472
+ Returns
473
+ -------
474
+ Self
475
+ A new instance of self.__class__.
476
+
477
+ """
478
+ obj = self.copy()
479
+ obj.settings.reset()
480
+ return obj
481
+
482
+ def undo_all(self) -> None:
483
+ """
484
+ Undo all the operations performed on the data and clean the records.
485
+
486
+ """
487
+ self.fmtb = "{0}"
488
+ self.data = self.original_data
489
+
490
+ def set_label(
491
+ self, label: Optional[str] = None, reset_format: bool = True, /, **kwargs: str
492
+ ) -> Self:
493
+ """
494
+ Set the labels.
495
+
496
+ Parameters
497
+ ----------
498
+ label : str, optional
499
+ The new label (if specified), by default None.
500
+ reset_format : bool, optional
501
+ Determines whether to reset the format of the label (which shows
502
+ the operations done on the data), by default True.
503
+ **kwargs : str
504
+ Works as a mapper to find the new label. If `self.label` is in
505
+ `kwargs`, the label will be set to `kwargs[self.label]`.
506
+
507
+ Returns
508
+ -------
509
+ Self
510
+ A new instance of self.__class__.
511
+
512
+ """
513
+ if isinstance(label, str):
514
+ new_label = label
515
+ elif self.label in kwargs:
516
+ new_label = kwargs[self.label]
517
+ else:
518
+ new_label = self.label
519
+ return self.__create(
520
+ "{0}" if reset_format else self.fmtb,
521
+ self.data,
522
+ priority=self.priority,
523
+ label=new_label,
524
+ )
525
+
526
+ @overload
527
+ def set_plot(
528
+ self, *, inplace: Literal[False] = False, **kwargs: Unpack[SettingDict]
529
+ ) -> Self: ...
530
+ @overload
531
+ def set_plot(
532
+ self, *, inplace: Literal[True] = True, **kwargs: Unpack[SettingDict]
533
+ ) -> None: ...
534
+ def set_plot(
535
+ self, *, inplace: bool = False, **kwargs: Unpack[SettingDict]
536
+ ) -> Self | None:
537
+ """
538
+ Set the settings of a plot (whether a figure or an axes).
539
+
540
+ Parameters
541
+ ----------
542
+ inplace : bool, optional
543
+ Determines whether the changes of settings will happen in self or
544
+ in a new copy of self, by default False.
545
+ title : str, optional
546
+ Title of plot.
547
+ xlabel : str, optional
548
+ Label for the x-axis.
549
+ ylabel : str, optional
550
+ Label for the y-axis.
551
+ alpha : float, optional
552
+ Controls the transparency of the plotted elements. It takes a float
553
+ value between 0 and 1, where 0 means completely transparent and 1
554
+ means completely opaque.
555
+ dpi : float, optional
556
+ Sets the resolution of figure in dots-per-inch.
557
+ grid : bool, optional
558
+ Determines whether to show the grids or not.
559
+ grid_alpha : float, optional
560
+ Controls the transparency of the grid.
561
+ style : StyleName, optional
562
+ A style specification.
563
+ figsize : tuple[int, int], optional
564
+ Figure size, this takes a tuple of two integers that specifies the
565
+ width and height of the figure in inches.
566
+ fontdict : FontDict, optional
567
+ A dictionary controlling the appearance of the title text.
568
+ legend_loc : LegendLoc, optional
569
+ Location of the legend.
570
+ format_label : bool, optional
571
+ Determines whether to format the label (to show the operations done
572
+ on the data).
573
+
574
+ Returns
575
+ -------
576
+ Self | None
577
+ A new instance of self.__class__, or None.
578
+
579
+ """
580
+ return self._set(inplace=inplace, **kwargs)
581
+
582
+ def batched(self, n: int = 1) -> Self:
583
+ """
584
+ If this instance is joined by multiple `PlotDataSet` objects, batch the
585
+ objects into tuples of length n, otherwise return self.
586
+
587
+ Use this together with `.plot()`, `.hist()`, etc.
588
+
589
+ Parameters
590
+ ----------
591
+ n : int, optional
592
+ Specifies the batch size, by default 1.
593
+
594
+ Returns
595
+ -------
596
+ Self
597
+ A new instance of self.__class__.
598
+
599
+ """
600
+ if n <= 0:
601
+ raise ValueError(f"batch size should be greater than 0, got {n} instead")
602
+ return MultiObject([self])
603
+
604
+ def hist(
605
+ self,
606
+ bins: int | list[float] = 100,
607
+ fit: bool = True,
608
+ density: bool = True,
609
+ log: bool = False,
610
+ same_bin: bool = True,
611
+ stats: bool = True,
612
+ ax: Optional["AxesWrapper"] = None,
613
+ **kwargs: Unpack[SettingDict],
614
+ ) -> Artist:
615
+ """
616
+ Create a histogram of the data.
617
+
618
+ Parameters
619
+ ----------
620
+ bins : int | list[float], optional
621
+ Specifies the bins to divide the data into. If int, should be the number
622
+ of bins. By default 100.
623
+ fit : bool, optional
624
+ Determines whether to fit a curve to the histogram, only available when
625
+ `density=True`, by default True.
626
+ density : bool, optional
627
+ Determines whether to draw a probability density. If True, the histogram
628
+ will be normalized such that the area under it equals to 1. By default
629
+ True.
630
+ log : bool, optional
631
+ Determines whether to set the histogram axis to a log scale, by default
632
+ False.
633
+ same_bin : bool, optional
634
+ Determines whether the bins should be the same for all sets of data, by
635
+ default True.
636
+ stats : bool, optional
637
+ Determines whether to show the statistics, including the calculated mean,
638
+ standard deviation, skewness, and kurtosis of the input, by default True.
639
+ on : Optional[AxesWrapper], optional
640
+ Specifies the axes-wrapper on which the plot should be painted. If
641
+ not specified, the histogram will be plotted on a new axes in a new
642
+ figure. By default None.
643
+ **kwargs : **SettingDict
644
+ Specifies the plot settings, see `.set_plot()` for more details.
645
+
646
+ Returns
647
+ -------
648
+ Artist
649
+ An instance of Artist.
650
+
651
+ """
652
+ return self._get_artist(Histogram, locals())
653
+
654
+ def plot(
655
+ self,
656
+ xticks: Optional["np.ndarray | PlotDataSet"] = None,
657
+ fmt: str = "",
658
+ scatter: bool = False,
659
+ sorted: bool = False,
660
+ ax: Optional["AxesWrapper"] = None,
661
+ **kwargs: Unpack[SettingDict],
662
+ ) -> Artist:
663
+ """
664
+ Create a line chart for the data. If there are more than one datasets, all of
665
+ them should have the same length.
666
+
667
+ Parameters
668
+ ----------
669
+ xticks : np.ndarray | PlotDataSet, optional
670
+ Specifies the x-ticks for the line chart. If not provided, the x-ticks will
671
+ be set to `range(len(data))`. By default None.
672
+ fmt : str, optional
673
+ A format string, e.g. 'ro' for red circles, by default ''.
674
+ scatter : bool, optional
675
+ Determines whether to include scatter points in the line chart, by default
676
+ False.
677
+ sorted : bool, optional
678
+ Determines whether to sort by x-ticks before drawing the chart, by
679
+ default False.
680
+ ax : AxesWrapper, optional
681
+ Specifies the axes-wrapper on which the plot should be painted If
682
+ not specified, the histogram will be plotted on a new axes in a new
683
+ figure. By default None.
684
+ **kwargs : **SettingDict
685
+ Specifies the plot settings, see `.set_plot()` for more details.
686
+
687
+ Returns
688
+ -------
689
+ Artist
690
+ An instance of Artist.
691
+
692
+ """
693
+ if isinstance(xticks, PlotSettable) and "xlabel" not in kwargs:
694
+ if kwargs.get("format_label", True):
695
+ kwargs["xlabel"] = xticks.formatted_label()
696
+ else:
697
+ kwargs["xlabel"] = xticks.label
698
+ return self._get_artist(LineChart, locals())
699
+
700
+ def scatter(
701
+ self,
702
+ xticks: Optional["np.ndarray | PlotDataSet"] = None,
703
+ fmt: str = "",
704
+ sorted: bool = False,
705
+ ax: Optional["AxesWrapper"] = None,
706
+ **kwargs: Unpack[SettingDict],
707
+ ) -> Artist:
708
+ """
709
+ Create a scatter chart for the data. If there are more than one datasets,
710
+ all of them should have the same length.
711
+
712
+ Parameters
713
+ ----------
714
+ xticks : np.ndarray | PlotDataSet, optional
715
+ Specifies the x-ticks for the chart. If not provided, the x-ticks will
716
+ be set to `range(len(data))`. By default None.
717
+ fmt : str, optional
718
+ A format string, e.g. 'ro' for red circles, by default ''.
719
+ sorted : bool, optional
720
+ Determines whether to sort by x-ticks before drawing the chart, by
721
+ default False.
722
+ ax : AxesWrapper, optional
723
+ Specifies the axes-wrapper on which the plot should be painted If
724
+ not specified, the histogram will be plotted on a new axes in a new
725
+ figure. By default None.
726
+ **kwargs : **SettingDict
727
+ Specifies the plot settings, see `.set_plot()` for more details.
728
+
729
+ Returns
730
+ -------
731
+ Artist
732
+ An instance of Artist.
733
+
734
+ """
735
+ if isinstance(xticks, PlotSettable) and "xlabel" not in kwargs:
736
+ if kwargs.get("format_label", True):
737
+ kwargs["xlabel"] = xticks.formatted_label()
738
+ else:
739
+ kwargs["xlabel"] = xticks.label
740
+ return self._get_artist(ScatterChart, locals())
741
+
742
+ def qqplot(
743
+ self,
744
+ dist_or_sample: "DistName | np.ndarray | PlotDataSet" = "normal",
745
+ dots: int = 30,
746
+ edge_precision: float = 1e-2,
747
+ fmt: str = "o",
748
+ ax: Optional["AxesWrapper"] = None,
749
+ **kwargs: Unpack[SettingDict],
750
+ ) -> Artist:
751
+ """
752
+ Create a quantile-quantile plot.
753
+
754
+ Parameters
755
+ ----------
756
+ dist_or_sample : DistName | np.ndarray | PlotDataSet, optional
757
+ Specifies the distribution to compare with. If str, specifies a
758
+ theoretical distribution; if np.ndarray or PlotDataSet, specifies
759
+ another real sample. By default 'normal'.
760
+ dots : int, optional
761
+ Number of dots, by default 30.
762
+ edge_precision : float, optional
763
+ Specifies the lowest quantile (`=edge_precision`) and the highest
764
+ quantile (`=1-edge_precision`), by default 1e-2.
765
+ fmt : str, optional
766
+ A format string, e.g. 'ro' for red circles, by default 'o'.
767
+ ax : AxesWrapper, optional
768
+ Specifies the axes-wrapper on which the plot should be painted. If
769
+ not specified, the histogram will be plotted on a new axes in a new
770
+ figure. By default None.
771
+ **kwargs : **SettingDict
772
+ Specifies the plot settings, see `.set_plot()` for more details.
773
+
774
+ Returns
775
+ -------
776
+ Artist
777
+ An instance of Artist.
778
+
779
+ """
780
+ return self._get_artist(QQPlot, locals())
781
+
782
+ def ppplot(
783
+ self,
784
+ dist_or_sample: "DistName | np.ndarray | PlotDataSet" = "normal",
785
+ dots: int = 30,
786
+ edge_precision: float = 1e-6,
787
+ fmt: str = "o",
788
+ ax: Optional["AxesWrapper"] = None,
789
+ **kwargs: Unpack[SettingDict],
790
+ ) -> Artist:
791
+ """
792
+ Create a probability-probability plot.
793
+
794
+ Parameters
795
+ ----------
796
+ dist_or_sample : DistName | np.ndarray | PlotDataSet, optional
797
+ Specifies the distribution to compare with. If str, specifies a
798
+ theoretical distribution; if np.ndarray or PlotDataSet, specifies
799
+ another real sample. By default 'normal'.
800
+ dots : int, optional
801
+ Number of dots, by default 30.
802
+ edge_precision : float, optional
803
+ Specifies the lowest quantile (`=edge_precision`) and the highest
804
+ quantile (`=1-edge_precision`), by default 1e-6.
805
+ fmt : str, optional
806
+ A format string, e.g. 'ro' for red circles, by default 'o'.
807
+ ax : AxesWrapper, optional
808
+ Specifies the axes-wrapper on which the plot should be painted. If
809
+ not specified, the histogram will be plotted on a new axes in a new
810
+ figure. By default None.
811
+ **kwargs : **SettingDict
812
+ Specifies the plot settings, see `.set_plot()` for more details.
813
+
814
+ Returns
815
+ -------
816
+ Artist
817
+ An instance of Artist.
818
+
819
+ """
820
+ return self._get_artist(PPPlot, locals())
821
+
822
+ def ksplot(
823
+ self,
824
+ dist_or_sample: "DistName | np.ndarray | PlotDataSet" = "normal",
825
+ dots: int = 1000,
826
+ edge_precision: float = 1e-6,
827
+ fmt: str = "",
828
+ ax: Optional["AxesWrapper"] = None,
829
+ **kwargs: Unpack[SettingDict],
830
+ ) -> Artist:
831
+ """
832
+ Create a kolmogorov-smirnov plot.
833
+
834
+ Parameters
835
+ ----------
836
+ dist_or_sample : DistName | np.ndarray | PlotDataSet, optional
837
+ Specifies the distribution to compare with. If str, specifies a
838
+ theoretical distribution; if np.ndarray or PlotDataSet, specifies
839
+ another real sample. By default 'normal'.
840
+ dots : int, optional
841
+ Number of dots, by default 1000.
842
+ edge_precision : float, optional
843
+ Specifies the lowest quantile (`=edge_precision`) and the highest
844
+ quantile (`=1-edge_precision`), by default 1e-6.
845
+ fmt : str, optional
846
+ A format string, e.g. 'ro' for red circles, by default ''.
847
+ ax : AxesWrapper, optional
848
+ Specifies the axes-wrapper on which the plot should be painted. If
849
+ not specified, the histogram will be plotted on a new axes in a new
850
+ figure. By default None.
851
+ **kwargs : **SettingDict
852
+ Specifies the plot settings, see `.set_plot()` for more details.
853
+
854
+ Returns
855
+ -------
856
+ Artist
857
+ An instance of Artist.
858
+
859
+ """
860
+ return self._get_artist(KSPlot, locals())
861
+
862
+ def corrmap(
863
+ self,
864
+ annot: bool = True,
865
+ ax: Optional["AxesWrapper"] = None,
866
+ **kwargs: Unpack[SettingDict],
867
+ ) -> Artist:
868
+ """
869
+ Create a correlation heatmap.
870
+
871
+ Parameters
872
+ ----------
873
+ annot : bool, optional
874
+ Specifies whether to write the data value in each cell, by default
875
+ True.
876
+ ax : AxesWrapper, optional
877
+ Specifies the axes-wrapper on which the plot should be painted. If
878
+ not specified, the histogram will be plotted on a new axes in a new
879
+ figure. By default None.
880
+ **kwargs : **SettingDict
881
+ Specifies the plot settings, see `.set_plot()` for more details.
882
+
883
+ Returns
884
+ -------
885
+ Artist
886
+ An instance of Artist.
887
+
888
+ """
889
+ return self._get_artist(CorrMap, locals())
890
+
891
+ def _get_artist(self, cls: type["Plotter"], local: dict[str, Any]) -> Artist:
892
+ params: dict[str, Any] = {}
893
+ for key in cls.__init__.__code__.co_varnames[1:]:
894
+ params[key] = local[key]
895
+ if "format_label" in local["kwargs"] and not local["kwargs"]["format_label"]:
896
+ label = self.label
897
+ else:
898
+ label = self.formatted_label()
899
+ plotter = self.customize(cls, data=self.data, label=label, **params)
900
+ artist = single(self.customize)(Artist, plotter=plotter, ax=local["ax"])
901
+ if local["kwargs"]:
902
+ artist.plotter.load(local["kwargs"])
903
+ artist.load(local["kwargs"])
904
+ return artist
905
+
906
+
907
class PlotDataSets(MultiObject[PlotDataSet]):
    """
    A duck subclass of `PlotDataSet` holding several joined datasets.

    Attribute lookups are routed through `__dataset_attr_reducer`: plotting
    methods (and `join`) are the unbound `PlotDataSet` methods called with this
    joined object as `self`; other public attributes go through `multipartial`,
    whose call results are re-joined by `__join_if_dataset`.
    """

    def __init__(self, *args: Any) -> None:
        if len(args) == 0:
            raise ValueError("no args")
        members: list[PlotDataSet] = []
        for arg in args:
            if isinstance(arg, self.__class__):
                # Flatten nested joins into a single level.
                members += arg.__multiobjects__
            elif isinstance(arg, PlotDataSet):
                members.append(arg)
            else:
                raise TypeError(f"invalid type: {arg.__class__.__name__!r}")
        super().__init__(members, attr_reducer=self.__dataset_attr_reducer)

    def __repr__(self) -> str:
        infos = [member.data_info() for member in self.__multiobjects__]
        return PlotDataSet.__name__ + "\n- " + "\n- ".join(infos)

    def batched(self, n: int = 1) -> MultiObject:
        """Overrides `PlotDataSet.batched()`."""
        # Reuse the single-dataset implementation only to validate `n`.
        PlotDataSet.batched(self, n)
        members = self.__multiobjects__
        batches = MultiObject()
        for start in range(0, len(members), n):
            batches.__multiobjects__.append(PlotDataSets(*members[start : start + n]))
        return batches

    def __dataset_attr_reducer(self, n: str) -> Callable:
        plotting_names = (
            "hist",
            "plot",
            "scatter",
            "ppplot",
            "qqplot",
            "ksplot",
            "corrmap",
            "join",
            "_get_artist",
        )
        if n in plotting_names:
            # Run the PlotDataSet method once, with the joined object as `self`.
            return lambda _: partial(getattr(PlotDataSet, n), self)
        if n == "customize":
            paint_reflector = multipartial(
                attr_reducer=lambda x: multipartial(call_reflex=x == "paint")
            )
            return multipartial(call_reducer=paint_reflector)
        if n.startswith("_"):
            raise AttributeError(
                f"cannot reach attribute '{n}' after dataset is joined"
            )
        return multipartial(call_reducer=self.__join_if_dataset)

    @classmethod
    def __join_if_dataset(cls, x: list) -> Any:
        # Re-join dataset results; collapse all-None results; otherwise keep as-is.
        if x and isinstance(x[0], PlotDataSet):
            return cls(*x)
        return None if all(item is None for item in x) else REMAIN