plotcraft 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
plotcraft/__init__.py ADDED
File without changes
plotcraft/draw.py ADDED
@@ -0,0 +1,739 @@
1
+ import matplotlib.pyplot as plt
2
+ import matplotlib.transforms as transforms
3
+ from typing import Union, List, Optional
4
+ import numpy as np
5
+ from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
6
+ from .utils import floor_significant_digits
7
+ from matplotlib.colors import Normalize
8
+ from matplotlib.cm import ScalarMappable
9
+ import matplotlib.patches as patches
10
+ import pandas as pd
11
+ from matplotlib.figure import Figure
12
+ from matplotlib.axes import Axes
13
+
14
+
15
def train_test_lift(
    train: Union[List[List], np.ndarray],
    test: Union[List[List], np.ndarray],
    paired: bool = True,
    colors: Optional[List[str]] = None,
    labels: Optional[List[str]] = None,
    yticks_interval: Optional[int | float] = None,
    axis_range: Optional[List[Optional[int | float]]] = None,
    offset: Optional[int | float] = None,
) -> tuple[Figure, Axes]:
    """
    Plot a lifted bar-chart comparison between training and test distributions.

    Both data sets are drawn as semi-transparent bars on one Axes. The test
    bars are drawn with ``bottom=offset`` so that they sit above a dashed
    baseline, and two sets of hand-drawn Y tick marks are added at the left
    edge: one for the training baseline (y = 0 reference) and one shifted by
    ``offset`` for the test baseline.

    Parameters
    ----------
    train : list of lists or np.ndarray
        Training data, either paired ``[[x1, y1], ...]`` or separated
        ``[x_vals, y_vals]`` (see ``paired``).
    test : list of lists or np.ndarray
        Test data, in the same layout as ``train``.
    paired : bool, default=True
        If True, column 0 holds x values and column 1 holds y values.
        If False, the input is unpacked as ``(x_vals, y_vals)``.
    colors : list of str, optional
        Two colors, for the training and test bars.
        Default: ``['#E0726D', '#5187B0']``.
    labels : list of str, optional
        Legend labels for the two data sets. Default: ``["Train", "Test"]``.
    yticks_interval : int or float, optional
        Step between Y tick marks. If None, a quarter of the Y range floored
        to two significant digits is used. Must be positive.
    axis_range : list of int/float/None, optional
        ``[X_min, X_max, Y_min, Y_max]``. Any entry given as None is derived
        from the data (``Y_min`` defaults to 0).
    offset : int or float, optional
        Vertical lift applied to the test bars. If None, half of
        ``yticks_interval`` is used.

    Returns
    -------
    Figure
        The figure object containing the plot.
    Axes
        The Axes holding the finished plot, for further styling.

    Raises
    ------
    ValueError
        If the tick interval (given or derived) is not strictly positive,
        e.g. when all y values equal ``Y_min``. Pass ``yticks_interval`` or
        ``axis_range`` explicitly in that case.

    Examples
    --------
    >>> from plotcraft.draw import train_test_lift
    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> train_data = np.arange(21, 100, dtype=int)
    >>> sigma, mu = 15, 60
    >>> y = np.exp(-(train_data - mu) ** 2 / (2 * sigma ** 2))
    >>> train_count = (y * 50 + 10).astype(int)
    >>> test_data = train_data.copy()
    >>> test_count = train_count.copy()
    >>> fig, ax = train_test_lift([train_data, train_count], [test_data, test_count], paired=False)
    >>> ax.set_xlabel('Length', fontsize=11)
    >>> ax.set_ylabel('Frequency', fontsize=11, labelpad=35)
    >>> plt.show()
    """
    train = np.asarray(train)
    test = np.asarray(test)
    if paired:
        train_x, train_y = train[:, 0], train[:, 1]
        test_x, test_y = test[:, 0], test[:, 1]
    else:
        train_x, train_y = train
        test_x, test_y = test

    # Treat a missing axis_range as "derive every limit from the data".
    x_min, x_max, y_min, y_max = (
        (None, None, None, None) if axis_range is None else axis_range
    )
    if x_min is None:
        x_min = min(np.min(train_x), np.min(test_x))
    if x_max is None:
        x_max = max(np.max(train_x), np.max(test_x))
    if y_min is None:
        y_min = 0
    if y_max is None:
        y_max = max(np.max(train_y), np.max(test_y))

    if labels is None:
        labels = ["Train", "Test"]
    if colors is None:
        colors = ['#E0726D', '#5187B0']

    if yticks_interval is None:
        yticks_interval = floor_significant_digits((y_max - y_min) / 4, 2)
    # A zero or negative step cannot produce tick positions; fail with a clear
    # message instead of an opaque error from np.arange.
    if not yticks_interval > 0:
        raise ValueError(
            "yticks_interval must be positive, got "
            f"{yticks_interval!r}; pass yticks_interval or axis_range explicitly"
        )
    tick_vals = np.arange(y_min, y_max, yticks_interval)

    if offset is None:
        offset = yticks_interval / 2

    fig, ax = plt.subplots()

    ax.bar(train_x, train_y, alpha=0.5,
           color=colors[0], edgecolor='white', linewidth=0.5, label=labels[0])
    ax.bar(test_x, test_y, bottom=offset, alpha=0.5,
           color=colors[1], edgecolor='white', linewidth=0.5, label=labels[1])

    ax.set_xlim(x_min - 1, x_max + 1)
    ax.set_ylim(y_min, y_max + offset)
    # Built-in ticks are removed; both tick scales are drawn manually below.
    ax.set_yticks([])

    # Dashed baseline for the lifted (test) bars.
    ax.axhline(y=offset, color='#888888', linestyle='--', linewidth=1.5,
               dashes=(5, 2), alpha=0.8)

    # x in axes-fraction coordinates, y in data coordinates.
    blend = transforms.blended_transform_factory(ax.transAxes, ax.transData)

    for i, v in enumerate(tick_vals):
        # Training scale: label and tick just outside the left spine.
        ax.text(-0.03, v, f'{v:.2f}', transform=blend,
                fontsize=8, color=colors[0], va='center', ha='right')
        ax.plot([-0.02, 0], [v, v], color=colors[0], linewidth=0.8,
                clip_on=False, transform=blend)

        # Test scale: same values shifted by offset, drawn just inside the
        # spine. The first tick is skipped; it coincides with the baseline.
        if i:
            ax.text(0.03, v + offset, f'{v:.2f}', transform=blend,
                    fontsize=8, color=colors[1], va='center', ha='left')
            ax.plot([0, 0.02], [v + offset, v + offset], color=colors[1],
                    linewidth=0.8, clip_on=False, transform=blend)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(1.5)
    ax.spines['bottom'].set_linewidth(1.5)

    ax.legend(frameon=True, fontsize=9, loc='upper right')
    # Leave room on the left for the manually drawn training tick labels.
    fig.subplots_adjust(left=0.18)
    return fig, ax
154
+
155
+
156
def triangular_heatmap(
    data: pd.DataFrame | np.ndarray,
    annot: bool = True,
    annot_kws: Optional[dict] = None,
    linewidths: float | int = 1.5,
    linecolor: str = 'white',
    ticks_size: int | float = 9,
    vmin: float | int = -1,
    vmax: float | int = 1,
    cmap: Optional[str | plt.Colormap] = None,
    norm: Optional[Normalize] = None
) -> tuple[Figure, Axes]:
    """
    Draw the lower triangle of a square matrix as a rotated diamond heatmap.

    Each cell ``(row, col)`` with ``col <= row`` is drawn as a diamond-shaped
    patch coloured through ``cmap``/``norm``. Variable names are written along
    the two upper edges of the triangle and a colorbar is attached.

    Parameters
    ----------
    data : pd.DataFrame or np.ndarray
        Square (n x n) matrix, typically correlations. Only the lower
        triangle (including the diagonal) is drawn. For a DataFrame the column
        names are used as labels; otherwise ``Var1 ... Varn`` are generated.
    annot : bool, default=True
        Whether to write the value (two decimals) inside each cell.
    annot_kws : dict or None, default=None
        Overrides for the annotation text. Recognised keys:

        - ``'size'``: font size (default 20)
        - ``'color'``: fixed text colour; if unset, white is used where
          ``abs(value) > 0.60`` and ``'#222222'`` otherwise
        - ``'fontweight'``: font weight (default ``'bold'``)
        - ``'fontfamily'``: font family (default None, i.e. not passed on)
    linewidths : float or int, default=1.5
        Width of the cell border lines.
    linecolor : str, default='white'
        Colour of the cell border lines.
    ticks_size : float or int, default=9
        Font size of the variable labels.
    vmin, vmax : float or int, default=-1 and 1
        Limits of the default ``Normalize`` used when ``norm`` is None.
        ``vmax`` must be greater than ``vmin``.
    cmap : str or matplotlib.colors.Colormap, default=None
        Colormap; ``'RdBu_r'`` when None.
    norm : matplotlib.colors.Normalize, default=None
        Custom normalisation (e.g. ``CenteredNorm``, ``TwoSlopeNorm``). Any
        limit the norm leaves unset is autoscaled once from the plotted
        values, and the colorbar ticks follow the norm's limits.

    Returns
    -------
    Figure
        The figure object containing the plot.
    Axes
        The Axes holding the finished plot, for further styling.

    Raises
    ------
    AssertionError
        If ``vmax <= vmin`` or ``data`` is not a square matrix.

    Examples
    --------
    >>> import numpy as np
    >>> import pandas as pd
    >>> import matplotlib.pyplot as plt
    >>> from scipy import stats
    >>> from plotcraft.draw import triangular_heatmap
    >>> n_samples, n_vars = 200, 20
    >>> data = np.random.randn(n_samples, n_vars)
    >>> cols = [f"Var{i+1}" for i in range(n_vars)]
    >>> df = pd.DataFrame(data, columns=cols)
    >>> n = n_vars
    >>> corr = np.ones((n, n))
    >>> for i in range(n):
    ...     for j in range(i + 1, n):
    ...         r, _ = stats.spearmanr(df.iloc[:, i], df.iloc[:, j])
    ...         corr[i, j] = r
    ...         corr[j, i] = r
    >>> corr_df = pd.DataFrame(corr, index=cols, columns=cols)
    >>> fig, ax = triangular_heatmap(
    ...     corr_df,
    ...     annot=True,
    ...     annot_kws={'size': 7.2},
    ...     linewidths=0.5,
    ...     linecolor='white',
    ...     ticks_size=8,
    ...     vmax=1,
    ...     vmin=-1,
    ... )
    >>> plt.show()
    """
    assert vmax > vmin, "vmax must be greater than vmin"
    if isinstance(data, pd.DataFrame):
        columns = list(data.columns)
        corr = data.values
    else:
        corr = np.asarray(data)
        columns = [f"Var{i+1}" for i in range(corr.shape[0])]

    n = corr.shape[0]
    # The message reads "data must be a square matrix".
    assert corr.shape == (n, n), "data 必须是方阵"

    _annot_kws = {'size': 20, 'fontweight': 'bold', 'fontfamily': None, 'color': None}
    if annot_kws:
        _annot_kws.update(annot_kws)

    def to_canvas(row, col):
        """Map matrix indices to the centre of the rotated cell on the canvas."""
        cx = 2 * (n - 1) - (row + col)
        cy = row - col
        return cx, cy

    # Half-diagonal of each diamond, in canvas units.
    half = 1.0

    fig, ax = plt.subplots(figsize=(11, 9))
    ax.set_aspect('equal')
    ax.axis('off')
    fig.patch.set_facecolor('white')

    if cmap is None:
        cmap = 'RdBu_r'
    if isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)
    norm_c = Normalize(vmin=vmin, vmax=vmax) if norm is None else norm
    # A user-supplied norm may have unset limits. Scale it once over all
    # plotted values; otherwise the first scalar lookup below would autoscale
    # it to a single cell and collapse the colour range.
    norm_c.autoscale_None(corr[np.tril_indices(n)])

    for row in range(n):
        for col in range(row + 1):
            val = corr[row, col]
            color = cmap(norm_c(val))
            cx, cy = to_canvas(row, col)

            diamond = patches.Polygon(
                [(cx, cy + half), (cx + half, cy), (cx, cy - half), (cx - half, cy)],
                closed=True,
                facecolor=color,
                edgecolor=linecolor,
                linewidth=linewidths,
                zorder=2,
            )
            ax.add_patch(diamond)

            if annot:
                if _annot_kws['color'] is not None:
                    txt_color = _annot_kws['color']
                else:
                    # Light text on strongly coloured cells, dark text otherwise.
                    txt_color = 'white' if abs(val) > 0.60 else '#222222'

                txt_kws = dict(
                    ha='center', va='center', zorder=3,
                    fontsize=_annot_kws['size'],
                    color=txt_color,
                    fontweight=_annot_kws['fontweight'],
                )
                if _annot_kws['fontfamily']:
                    txt_kws['fontfamily'] = _annot_kws['fontfamily']

                ax.text(cx, cy, f'{val:.2f}', **txt_kws)

    # Label anchor placement along the diamond edges, plus a small gap.
    t = n * 0.005 + 0.6
    offset = 0.18
    sq2 = np.sqrt(2)

    for i in range(n):
        # Label on the upper-right edge: first column of row i.
        cx, cy = to_canvas(i, 0)
        lx = cx + half * t + offset / sq2
        ly = cy + half * (1 - t) + offset / sq2
        ax.text(lx, ly, columns[i],
                ha='left', va='bottom',
                fontsize=ticks_size, rotation=45,
                rotation_mode='anchor', zorder=4)

        # Label on the upper-left edge: column i of the last row.
        cx2, cy2 = to_canvas(n - 1, i)
        lx2 = cx2 - half * t - offset / sq2
        ly2 = cy2 + half * (1 - t) + offset / sq2
        ax.text(lx2, ly2, columns[i],
                ha='right', va='bottom',
                fontsize=ticks_size, rotation=-45,
                rotation_mode='anchor', zorder=4)

    sm = ScalarMappable(cmap=cmap, norm=norm_c)
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax, fraction=0.022, pad=0.01, shrink=0.65, aspect=22)
    # Ticks span the limits actually used for colouring, so a custom norm
    # gets matching ticks; fall back to vmin/vmax if the norm exposes none.
    tick_lo = getattr(norm_c, 'vmin', None)
    tick_hi = getattr(norm_c, 'vmax', None)
    if tick_lo is None:
        tick_lo = vmin
    if tick_hi is None:
        tick_hi = vmax
    cbar.set_ticks(np.linspace(tick_lo, tick_hi, 9))
    cbar.ax.tick_params(labelsize=8.5)
    cbar.outline.set_linewidth(0.5)

    # Extra margin so the rotated labels are not clipped.
    ax.set_xlim(-half - 3.0, 2 * (n - 1) + half + 3.0)
    ax.set_ylim(-half - 0.5, (n - 1) + half + 2.5)

    fig.tight_layout()
    return fig, ax
360
+
361
def enlarged_roc_curve(
    *true_score_pairs,
    colors: Optional[List[str]] = None,
    labels: Optional[List[str]] = None,
    paired: bool = False,
    calculate: bool = True,
    plot_kwargs: Optional[dict] = None,
    enlarged: bool = False,
    to_enlarge_frame_location: Optional[List[int | float]] = None,
    enlarged_frame_location: Optional[List[int | float]] = None,
    enlarged_frame_xticks: Optional[List[int | float]] = None,
    enlarged_frame_yticks: Optional[List[int | float]] = None,
    enlarged_frame_transparent: bool = True,
    legend_kwargs: Optional[dict] = None
) -> tuple[Figure, Axes]:
    """
    Plot ROC curves with an optional zoomed-in inset.

    Draws one ROC curve per ``(y_true, y_score)`` pair on a single Axes with
    the chance diagonal, optionally appends the AUC to each legend label, and
    optionally adds an inset Axes magnifying a chosen region.

    Parameters
    ----------
    *true_score_pairs : sequence of array-like
        Each argument is a pair (y_true, y_score). Multiple pairs can be
        passed to compare ROC curves across models.
    colors : list of str, default=None
        One colour per curve. If None, Matplotlib's colour cycle is used.
    labels : list of str, default=None
        One label per curve. A legend is drawn only when labels are given.
    paired : bool, default=False
        If True, each input is an N x 2 array whose rows are
        ``[y_true, score]``. If False, each input is unpacked as
        ``(y_true_array, score_array)``.
    calculate : bool, default=True
        Whether to compute the AUC and append ``"(AUC = ...)"`` to the label.
    plot_kwargs : dict, default=None
        Keyword arguments passed to ``ax.plot()`` for every curve. When None,
        ``linewidth=2`` is used.
    enlarged : bool, default=False
        Whether to add an inset Axes with a zoomed view of a subregion.
    to_enlarge_frame_location : list of float, length 4
        Region to magnify, ``[x1, y1, x2, y2]`` in data coordinates within
        [0, 1], where (x1, y1) is lower-left and (x2, y2) upper-right.
        Required when ``enlarged`` is True.
    enlarged_frame_location : list of float, length 4
        Bounds of the inset, passed unchanged to ``Axes.inset_axes``:
        ``[x0, y0, width, height]`` in axes-relative coordinates.
        Required when ``enlarged`` is True.
    enlarged_frame_xticks : array-like, default=None
        Custom x tick positions for the inset.
    enlarged_frame_yticks : array-like, default=None
        Custom y tick positions for the inset.
    enlarged_frame_transparent : bool, default=True
        Whether to make the inset background transparent.
    legend_kwargs : dict, default=None
        Keyword arguments passed to ``ax.legend()``. Defaults to
        ``fontsize=12``; ``loc`` defaults to ``"lower right"`` and may be
        overridden here.

    Returns
    -------
    Figure
        The figure object containing the plot.
    Axes
        The main Axes holding the finished plot, for further styling.

    Raises
    ------
    AssertionError
        If ``enlarged`` is True and either frame location is missing, or the
        region to magnify is not ordered within [0, 1].

    Examples
    --------
    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> from plotcraft.draw import enlarged_roc_curve
    >>> arr = np.load('examples/data/true_score.npy')
    >>> data_list = [[arr[i], arr[i+1]] for i in range(0, arr.shape[0], 2)]
    >>> fig, ax = enlarged_roc_curve(
    ...     *data_list,
    ...     labels=[f'model{i}' for i in range(len(data_list))],
    ...     enlarged=True,
    ...     to_enlarge_frame_location=[0.01, 0.80, 0.15, 0.98],
    ...     enlarged_frame_location=[0.3, 0.5, 0.4, 0.4],
    ...     enlarged_frame_xticks=[0.045, 0.08, 0.115],
    ...     enlarged_frame_yticks=[0.9, 0.93, 0.96]
    ... )
    >>> plt.show()
    """

    def line_style(index, label=None):
        """Build the ax.plot() keyword arguments for curve number *index*."""
        style = {}
        if colors is not None:
            style['color'] = colors[index]
        if label is not None:
            style['label'] = label
        # User kwargs are applied last so they can override colour and label.
        if plot_kwargs is not None:
            style.update(plot_kwargs)
        else:
            style['linewidth'] = 2
        return style

    fig, ax = plt.subplots(figsize=(8, 8))

    # Chance-level diagonal.
    ax.plot([0, 1], [0, 1], color="lightgray", linestyle="--")

    curves = []
    for i, true_score_pair in enumerate(true_score_pairs):
        true_score_pair = np.array(true_score_pair)
        if paired:
            y_true, score = true_score_pair[:, 0], true_score_pair[:, 1]
        else:
            y_true, score = true_score_pair
        fpr, tpr, _ = roc_curve(y_true, score)
        add_str = f"(AUC = {auc(fpr, tpr):.3f})" if calculate else ""
        # Kept so the inset can redraw the curves without recomputing them.
        curves.append((fpr, tpr))

        label = labels[i] + add_str if labels is not None else None
        ax.plot(fpr, tpr, **line_style(i, label))

    # Y axis is shown on the right-hand side.
    ax.spines[["top", "left"]].set_visible(False)
    ax.spines["right"].set_visible(True)
    ax.yaxis.tick_right()
    ax.yaxis.set_label_position("right")

    # Main-plot labels and title.
    ax.set_xlabel("False positive rate", fontsize=22, labelpad=10)
    ax.set_ylabel("True positive rate", fontsize=22, labelpad=20)
    ax.set_title("ROC curve", fontsize=22, pad=20)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    if labels is not None:
        # Merge instead of passing loc separately, so a user-supplied 'loc'
        # overrides the default rather than raising a duplicate-keyword error.
        legend_options = {'loc': "lower right"}
        legend_options.update({'fontsize': 12} if legend_kwargs is None else legend_kwargs)
        ax.legend(**legend_options)
    ax.grid(False)

    if enlarged:
        assert to_enlarge_frame_location is not None, \
            "to_enlarge_frame_location is required when enlarged=True"
        assert enlarged_frame_location is not None, \
            "enlarged_frame_location is required when enlarged=True"
        x1, y1, x2, y2 = to_enlarge_frame_location
        assert 0 <= x1 < x2 <= 1, "need 0 <= x1 < x2 <= 1"
        assert 0 <= y1 < y2 <= 1, "need 0 <= y1 < y2 <= 1"
        axins = ax.inset_axes(enlarged_frame_location,
                              xlim=(x1, x2), ylim=(y1, y2))

        if enlarged_frame_transparent:
            axins.patch.set_alpha(0.0)

        for i, (fpr, tpr) in enumerate(curves):
            axins.plot(fpr, tpr, **line_style(i))

        axins.yaxis.tick_right()
        if enlarged_frame_xticks is not None:
            axins.set_xticks(enlarged_frame_xticks)
        if enlarged_frame_yticks is not None:
            axins.set_yticks(enlarged_frame_yticks)
        axins.grid(False)

        # Draw the rectangle and connectors linking the region to the inset.
        ax.indicate_inset_zoom(axins, edgecolor="black", linewidth=1.5)

    fig.tight_layout()
    return fig, ax
539
+
540
def enlarged_pr_curve(*true_score_pairs: List[List] | np.ndarray,
                      colors: Optional[List[str]] = None,
                      labels: Optional[List[str]] = None,
                      paired: bool = False,
                      calculate: bool = True,
                      plot_kwargs: Optional[dict] = None,
                      enlarged: bool = False,
                      to_enlarge_frame_location: Optional[List[int | float]] = None,
                      enlarged_frame_location: Optional[List[int | float]] = None,
                      enlarged_frame_xticks: Optional[List[int | float]] = None,
                      enlarged_frame_yticks: Optional[List[int | float]] = None,
                      enlarged_frame_transparent: bool = True,
                      legend_kwargs: Optional[dict] = None) -> tuple[Figure, Axes]:
    """
    Plot precision-recall curves with an optional zoomed-in inset.

    Draws one PR curve per ``(y_true, y_score)`` pair on a single Axes,
    optionally appends a summary score to each legend label, and optionally
    adds an inset Axes magnifying a chosen region.

    Parameters
    ----------
    *true_score_pairs : sequence of array-like
        Each argument is a pair (y_true, y_score). Multiple pairs can be
        passed to compare PR curves across models.
    colors : list of str, default=None
        One colour per curve. If None, Matplotlib's colour cycle is used.
    labels : list of str, default=None
        One label per curve. A legend is drawn only when labels are given.
    paired : bool, default=False
        If True, each input is an N x 2 array whose rows are
        ``[y_true, score]``. If False, each input is unpacked as
        ``(y_true_array, score_array)``.
    calculate : bool, default=True
        Whether to append a score to the label. The value is computed with
        ``average_precision_score`` and rendered as ``"(AUC = ...)"``.
    plot_kwargs : dict, default=None
        Keyword arguments passed to ``ax.plot()`` for every curve. When None,
        ``linewidth=2`` is used.
    enlarged : bool, default=False
        Whether to add an inset Axes with a zoomed view of a subregion.
    to_enlarge_frame_location : list of float, length 4
        Region to magnify, ``[x1, y1, x2, y2]`` in data coordinates within
        [0, 1], where (x1, y1) is lower-left and (x2, y2) upper-right.
        Required when ``enlarged`` is True.
    enlarged_frame_location : list of float, length 4
        Bounds of the inset, passed unchanged to ``Axes.inset_axes``:
        ``[x0, y0, width, height]`` in axes-relative coordinates.
        Required when ``enlarged`` is True.
    enlarged_frame_xticks : array-like, default=None
        Custom x tick positions for the inset.
    enlarged_frame_yticks : array-like, default=None
        Custom y tick positions for the inset.
    enlarged_frame_transparent : bool, default=True
        Whether to make the inset background transparent.
    legend_kwargs : dict, default=None
        Keyword arguments passed to ``ax.legend()``. Defaults to
        ``fontsize=12``; ``loc`` defaults to ``"lower left"`` and may be
        overridden here.

    Returns
    -------
    Figure
        The figure object containing the plot.
    Axes
        The main Axes holding the finished plot, for further styling.

    Raises
    ------
    AssertionError
        If ``enlarged`` is True and either frame location is missing, or the
        region to magnify is not ordered within [0, 1].

    Examples
    --------
    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> from plotcraft.draw import enlarged_pr_curve
    >>> arr = np.load('./data/true_score.npy')
    >>> data_list = [[arr[i], arr[i+1]] for i in range(0, arr.shape[0], 2)]
    >>> fig, ax = enlarged_pr_curve(*data_list,
    ...     labels=[f'model{i}' for i in range(len(data_list))],
    ...     enlarged=True,
    ...     to_enlarge_frame_location=[0.82, 0.75, 0.97, 0.93],
    ...     enlarged_frame_location=[0.3, 0.5, 0.4, 0.4],
    ...     enlarged_frame_xticks=[0.858, 0.895, 0.93],
    ...     enlarged_frame_yticks=[0.795, 0.84, 0.885]
    ... )
    >>> plt.show()
    """

    def line_style(index, label=None):
        """Build the ax.plot() keyword arguments for curve number *index*."""
        style = {}
        if colors is not None:
            style['color'] = colors[index]
        if label is not None:
            style['label'] = label
        # User kwargs are applied last so they can override colour and label.
        if plot_kwargs is not None:
            style.update(plot_kwargs)
        else:
            style['linewidth'] = 2
        return style

    fig, ax = plt.subplots(figsize=(8, 8))

    curves = []
    for i, true_score_pair in enumerate(true_score_pairs):
        true_score_pair = np.array(true_score_pair)
        if paired:
            y_true, score = true_score_pair[:, 0], true_score_pair[:, 1]
        else:
            y_true, score = true_score_pair
        precision, recall, _ = precision_recall_curve(y_true, score)
        if calculate:
            AP = average_precision_score(y_true, score)
            add_str = f"(AUC = {AP:.3f})"
        else:
            add_str = ""
        # Kept so the inset can redraw the curves without recomputing them.
        curves.append((recall, precision))

        label = labels[i] + add_str if labels is not None else None
        ax.plot(recall, precision, **line_style(i, label))

    ax.spines[["top", "right"]].set_visible(False)

    # Main-plot labels and title.
    ax.set_xlabel("Recall", fontsize=22, labelpad=10)
    ax.set_ylabel("Precision", fontsize=22, labelpad=20)
    ax.set_title("PR curve", fontsize=22, pad=20)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    if labels is not None:
        # Merge instead of passing loc separately, so a user-supplied 'loc'
        # overrides the default rather than raising a duplicate-keyword error.
        legend_options = {'loc': "lower left"}
        legend_options.update({'fontsize': 12} if legend_kwargs is None else legend_kwargs)
        ax.legend(**legend_options)
    ax.grid(False)

    if enlarged:
        assert to_enlarge_frame_location is not None, \
            "to_enlarge_frame_location is required when enlarged=True"
        assert enlarged_frame_location is not None, \
            "enlarged_frame_location is required when enlarged=True"
        x1, y1, x2, y2 = to_enlarge_frame_location
        assert 0 <= x1 < x2 <= 1, "need 0 <= x1 < x2 <= 1"
        assert 0 <= y1 < y2 <= 1, "need 0 <= y1 < y2 <= 1"
        axins = ax.inset_axes(enlarged_frame_location,
                              xlim=(x1, x2), ylim=(y1, y2))
        if enlarged_frame_transparent:
            axins.patch.set_alpha(0.0)

        for i, (recall, precision) in enumerate(curves):
            axins.plot(recall, precision, **line_style(i))

        if enlarged_frame_xticks is not None:
            axins.set_xticks(enlarged_frame_xticks)
        if enlarged_frame_yticks is not None:
            axins.set_yticks(enlarged_frame_yticks)
        axins.grid(False)

        # Draw the rectangle and connectors linking the region to the inset.
        ax.indicate_inset_zoom(axins, edgecolor="black", linewidth=1.5)

    fig.tight_layout()
    return fig, ax
710
+
711
+
712
if __name__ == '__main__':
    # Manual demo: draw a triangular heatmap of Spearman correlations
    # computed on random data.
    import numpy as np
    import pandas as pd
    from scipy import stats

    n_samples, n_vars = 200, 20
    data = np.random.randn(n_samples, n_vars)
    cols = [f"Var{i + 1}" for i in range(n_vars)]
    df = pd.DataFrame(data, columns=cols)
    n = n_vars
    # The diagonal stays at 1; each off-diagonal pair is filled symmetrically.
    corr = np.ones((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            r, _ = stats.spearmanr(df.iloc[:, i], df.iloc[:, j])
            corr[i, j] = r
            corr[j, i] = r
    corr_df = pd.DataFrame(corr, index=cols, columns=cols)
    # triangular_heatmap returns a (Figure, Axes) pair; unpack it so the
    # names match what they hold.
    fig, ax = triangular_heatmap(
        corr_df,
        annot=True,
        annot_kws={'size': 7.2},
        linewidths=0.5,
        linecolor='white',
        ticks_size=8,
        vmax=1,
        vmin=-1,
    )
    plt.show()
plotcraft/tools.py ADDED
File without changes
plotcraft/utils.py ADDED
@@ -0,0 +1,58 @@
1
+ import math
2
+
3
+ def floor_significant_digits(x:int | float, digits:int) -> int | float:
4
+ """Round a number DOWN to the specified number of significant digits.
5
+
6
+ This function always rounds toward negative infinity
7
+ to retain a fixed number of significant digits,
8
+ without rounding up. This is especially useful for
9
+ truncating values strictly downward for numerical precision.
10
+
11
+ Parameters
12
+ ----------
13
+ x : int or float
14
+ Input number to be rounded down to significant digits.
15
+
16
+ digits : int
17
+ Number of significant digits to retain.
18
+ Must be a positive integer.
19
+
20
+ Returns
21
+ -------
22
+ int or float
23
+ The input value rounded down to the specified
24
+ number of significant digits.
25
+
26
+ Raises
27
+ ------
28
+ ValueError
29
+ If ``digits`` is not a positive integer.
30
+
31
+ Examples
32
+ --------
33
+ >>> floor_significant_digits(123456, 2)
34
+ 120000
35
+ >>> floor_significant_digits(-123456, 2)
36
+ -120000
37
+ >>> floor_significant_digits(1.23456, 2)
38
+ 1.2
39
+ >>> floor_significant_digits(-1.23456, 2)
40
+ -1.2
41
+ """
42
+ if digits <= 0 or type(digits) != int:
43
+ raise ValueError("floor significant digits should be positive int")
44
+ if x == 0:
45
+ return 0
46
+ elif x > 0:
47
+ exp = math.floor(math.log10(x))
48
+ decimals = exp - digits + 1
49
+ if decimals < 0:
50
+ decimals = -decimals
51
+ scale = 10 ** decimals
52
+ return math.floor(x * scale) / scale
53
+ else:
54
+ scale = 10 ** decimals
55
+ return math.floor(x / scale) * scale
56
+ else:
57
+ x = abs(x)
58
+ return -floor_significant_digits(x,digits)
@@ -0,0 +1,19 @@
1
+ Metadata-Version: 2.4
2
+ Name: plotcraft
3
+ Version: 0.1.0
4
+ Summary: supply more plot function for python, it will renew forever
5
+ Author: descartescy
6
+ Author-email: caowangyangcao@163.com
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: License :: OSI Approved :: MIT License
9
+ Classifier: Operating System :: OS Independent
10
+ Requires-Python: >=3.12
11
+ Requires-Dist: matplotlib==3.10.8
12
+ Requires-Dist: scikit-learn==1.8.0
13
+ Requires-Dist: pandas==3.0.1
14
+ Dynamic: author
15
+ Dynamic: author-email
16
+ Dynamic: classifier
17
+ Dynamic: requires-dist
18
+ Dynamic: requires-python
19
+ Dynamic: summary
@@ -0,0 +1,8 @@
1
+ plotcraft/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ plotcraft/draw.py,sha256=bqYi_SInFXiT1y3U1xGeqCO8auDTBJFry6X9mw9fuWQ,27283
3
+ plotcraft/tools.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
+ plotcraft/utils.py,sha256=6oxKrjr9mx4hwOGZQTYtKjkcva36I_fD8VPmAk7ebgk,1701
5
+ plotcraft-0.1.0.dist-info/METADATA,sha256=O3sDOpr_KiasI1-zzdJZAiVz-F7JaW2bKorAHPvt1TU,589
6
+ plotcraft-0.1.0.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
7
+ plotcraft-0.1.0.dist-info/top_level.txt,sha256=RyOmPcxfwKJ1F2ublKd26MJoiQkQkJqctfjJE1ofH4Q,10
8
+ plotcraft-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (82.0.1)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1 @@
1
+ plotcraft