hossam 0.4.13__py3-none-any.whl → 0.4.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hossam/__init__.py +23 -16
- hossam/hs_cluster.py +777 -40
- hossam/hs_plot.py +322 -218
- hossam/hs_study.py +76 -0
- {hossam-0.4.13.dist-info → hossam-0.4.14.dist-info}/METADATA +1 -1
- {hossam-0.4.13.dist-info → hossam-0.4.14.dist-info}/RECORD +9 -8
- {hossam-0.4.13.dist-info → hossam-0.4.14.dist-info}/WHEEL +0 -0
- {hossam-0.4.13.dist-info → hossam-0.4.14.dist-info}/licenses/LICENSE +0 -0
- {hossam-0.4.13.dist-info → hossam-0.4.14.dist-info}/top_level.txt +0 -0
hossam/hs_plot.py
CHANGED
|
@@ -7,7 +7,7 @@ from typing import Callable
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
import seaborn as sb
|
|
9
9
|
import matplotlib.pyplot as plt
|
|
10
|
-
from matplotlib.pyplot import Axes
|
|
10
|
+
from matplotlib.pyplot import Axes # type: ignore
|
|
11
11
|
from pandas import Series, DataFrame
|
|
12
12
|
from math import sqrt
|
|
13
13
|
from pandas import DataFrame
|
|
@@ -31,7 +31,7 @@ from sklearn.metrics import (
|
|
|
31
31
|
auc,
|
|
32
32
|
confusion_matrix,
|
|
33
33
|
silhouette_score,
|
|
34
|
-
silhouette_samples
|
|
34
|
+
silhouette_samples,
|
|
35
35
|
)
|
|
36
36
|
|
|
37
37
|
# ===================================================================
|
|
@@ -45,13 +45,24 @@ config = SimpleNamespace(
|
|
|
45
45
|
line_width=1,
|
|
46
46
|
grid_alpha=0.3,
|
|
47
47
|
grid_width=0.5,
|
|
48
|
-
fill_alpha=0.3
|
|
48
|
+
fill_alpha=0.3,
|
|
49
49
|
)
|
|
50
50
|
|
|
51
|
+
|
|
51
52
|
# ===================================================================
|
|
52
53
|
# 기본 크기가 설정된 Figure와 Axes를 생성한다
|
|
53
54
|
# ===================================================================
|
|
54
|
-
def get_default_ax(
|
|
55
|
+
def get_default_ax(
|
|
56
|
+
width: int = config.width,
|
|
57
|
+
height: int = config.height,
|
|
58
|
+
rows: int = 1,
|
|
59
|
+
cols: int = 1,
|
|
60
|
+
dpi: int = config.dpi,
|
|
61
|
+
flatten: bool = False,
|
|
62
|
+
ws: int | None = None,
|
|
63
|
+
hs: int | None = None,
|
|
64
|
+
title: str | None = None,
|
|
65
|
+
):
|
|
55
66
|
"""기본 크기의 Figure와 Axes를 생성한다.
|
|
56
67
|
|
|
57
68
|
Args:
|
|
@@ -78,7 +89,7 @@ def get_default_ax(width: int = config.width, height: int = config.height, rows:
|
|
|
78
89
|
fig.subplots_adjust(wspace=ws, hspace=hs)
|
|
79
90
|
|
|
80
91
|
if title and is_array:
|
|
81
|
-
fig.suptitle(title, fontsize=config.font_size * 1.5, fontweight=
|
|
92
|
+
fig.suptitle(title, fontsize=config.font_size * 1.5, fontweight="bold")
|
|
82
93
|
|
|
83
94
|
if flatten == True:
|
|
84
95
|
# 단일 Axes인 경우 리스트로 변환
|
|
@@ -97,7 +108,7 @@ def get_default_ax(width: int = config.width, height: int = config.height, rows:
|
|
|
97
108
|
for spine in a.spines.values():
|
|
98
109
|
spine.set_linewidth(config.frame_width)
|
|
99
110
|
else:
|
|
100
|
-
for spine in ax.spines.values():
|
|
111
|
+
for spine in ax.spines.values(): # type: ignore
|
|
101
112
|
spine.set_linewidth(config.frame_width)
|
|
102
113
|
|
|
103
114
|
return fig, ax
|
|
@@ -106,7 +117,17 @@ def get_default_ax(width: int = config.width, height: int = config.height, rows:
|
|
|
106
117
|
# ===================================================================
|
|
107
118
|
# 기본 크기가 설정된 Figure와 Axes를 생성한다
|
|
108
119
|
# ===================================================================
|
|
109
|
-
def create_figure(
|
|
120
|
+
def create_figure(
|
|
121
|
+
width: int = config.width,
|
|
122
|
+
height: int = config.height,
|
|
123
|
+
rows: int = 1,
|
|
124
|
+
cols: int = 1,
|
|
125
|
+
dpi: int = config.dpi,
|
|
126
|
+
flatten: bool = False,
|
|
127
|
+
ws: int | None = None,
|
|
128
|
+
hs: int | None = None,
|
|
129
|
+
title: str | None = None,
|
|
130
|
+
):
|
|
110
131
|
"""기본 크기의 Figure와 Axes를 생성한다. get_default_ax의 래퍼 함수.
|
|
111
132
|
|
|
112
133
|
Args:
|
|
@@ -130,7 +151,14 @@ def create_figure(width: int = config.width, height: int = config.height, rows:
|
|
|
130
151
|
# ===================================================================
|
|
131
152
|
# 그래프의 그리드, 레이아웃을 정리하고 필요 시 저장 또는 표시한다
|
|
132
153
|
# ===================================================================
|
|
133
|
-
def finalize_plot(
|
|
154
|
+
def finalize_plot(
|
|
155
|
+
ax: Axes | np.ndarray | list,
|
|
156
|
+
callback: Callable | None = None,
|
|
157
|
+
outparams: bool = False,
|
|
158
|
+
save_path: str | None = None,
|
|
159
|
+
grid: bool = True,
|
|
160
|
+
title: str | None = None,
|
|
161
|
+
) -> None:
|
|
134
162
|
"""공통 후처리를 수행한다: 콜백 실행, 레이아웃 정리, 필요 시 표시/종료.
|
|
135
163
|
|
|
136
164
|
Args:
|
|
@@ -149,7 +177,7 @@ def finalize_plot(ax: Axes | np.ndarray | list, callback: Callable | None = None
|
|
|
149
177
|
# callback 실행
|
|
150
178
|
if callback:
|
|
151
179
|
if is_array:
|
|
152
|
-
for a in
|
|
180
|
+
for a in ax.flat if isinstance(ax, np.ndarray) else ax:
|
|
153
181
|
callback(a)
|
|
154
182
|
else:
|
|
155
183
|
callback(ax)
|
|
@@ -157,7 +185,7 @@ def finalize_plot(ax: Axes | np.ndarray | list, callback: Callable | None = None
|
|
|
157
185
|
# grid 설정
|
|
158
186
|
if grid:
|
|
159
187
|
if is_array:
|
|
160
|
-
for a in
|
|
188
|
+
for a in ax.flat if isinstance(ax, np.ndarray) else ax:
|
|
161
189
|
a.grid(True, alpha=config.grid_alpha, linewidth=config.grid_width)
|
|
162
190
|
else:
|
|
163
191
|
ax.grid(True, alpha=config.grid_alpha, linewidth=config.grid_width)
|
|
@@ -165,10 +193,10 @@ def finalize_plot(ax: Axes | np.ndarray | list, callback: Callable | None = None
|
|
|
165
193
|
plt.tight_layout()
|
|
166
194
|
|
|
167
195
|
if title and not is_array:
|
|
168
|
-
ax.set_title(title, fontsize=config.font_size * 1.3, pad=7, fontweight=
|
|
196
|
+
ax.set_title(title, fontsize=config.font_size * 1.3, pad=7, fontweight="bold")
|
|
169
197
|
|
|
170
198
|
if save_path is not None:
|
|
171
|
-
plt.savefig(save_path, dpi=config.dpi * 2, bbox_inches=
|
|
199
|
+
plt.savefig(save_path, dpi=config.dpi * 2, bbox_inches="tight")
|
|
172
200
|
|
|
173
201
|
if outparams:
|
|
174
202
|
plt.show()
|
|
@@ -178,7 +206,14 @@ def finalize_plot(ax: Axes | np.ndarray | list, callback: Callable | None = None
|
|
|
178
206
|
# ===================================================================
|
|
179
207
|
# 그래프의 그리드, 레이아웃을 정리하고 필요 시 저장 또는 표시한다
|
|
180
208
|
# ===================================================================
|
|
181
|
-
def show_figure(
|
|
209
|
+
def show_figure(
|
|
210
|
+
ax: Axes | np.ndarray,
|
|
211
|
+
callback: Callable | None = None,
|
|
212
|
+
outparams: bool = False,
|
|
213
|
+
save_path: str | None = None,
|
|
214
|
+
grid: bool = True,
|
|
215
|
+
title: str | None = None,
|
|
216
|
+
) -> None:
|
|
182
217
|
"""공통 후처리를 수행한다: 콜백 실행, 레이아웃 정리, 필요 시 표시/종료.
|
|
183
218
|
finalize_plot의 래퍼 함수.
|
|
184
219
|
|
|
@@ -241,7 +276,7 @@ def lineplot(
|
|
|
241
276
|
outparams = False
|
|
242
277
|
|
|
243
278
|
if ax is None:
|
|
244
|
-
fig, ax = get_default_ax(width, height, 1, 1, dpi)
|
|
279
|
+
fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
|
|
245
280
|
outparams = True
|
|
246
281
|
|
|
247
282
|
# hue가 있을 때만 palette 사용, 없으면 color 사용
|
|
@@ -262,7 +297,7 @@ def lineplot(
|
|
|
262
297
|
lineplot_kwargs.update(params)
|
|
263
298
|
|
|
264
299
|
sb.lineplot(**lineplot_kwargs, linewidth=linewidth)
|
|
265
|
-
finalize_plot(ax, callback, outparams, save_path, True, title)
|
|
300
|
+
finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
|
|
266
301
|
|
|
267
302
|
|
|
268
303
|
# ===================================================================
|
|
@@ -316,7 +351,7 @@ def boxplot(
|
|
|
316
351
|
outparams = False
|
|
317
352
|
|
|
318
353
|
if ax is None:
|
|
319
|
-
fig, ax = get_default_ax(width, height, 1, 1, dpi)
|
|
354
|
+
fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
|
|
320
355
|
outparams = True
|
|
321
356
|
|
|
322
357
|
if xname is not None and yname is not None:
|
|
@@ -344,13 +379,17 @@ def boxplot(
|
|
|
344
379
|
if stat_pairs is None:
|
|
345
380
|
stat_pairs = [df[xname].dropna().unique().tolist()]
|
|
346
381
|
|
|
347
|
-
annotator = Annotator(
|
|
348
|
-
|
|
382
|
+
annotator = Annotator(
|
|
383
|
+
ax, data=df, x=xname, y=yname, pairs=stat_pairs, orient=orient
|
|
384
|
+
)
|
|
385
|
+
annotator.configure(
|
|
386
|
+
test=stat_test, text_format=stat_text_format, loc=stat_loc
|
|
387
|
+
)
|
|
349
388
|
annotator.apply_and_annotate()
|
|
350
389
|
else:
|
|
351
|
-
sb.boxplot(data=df, orient=orient, ax=ax, linewidth=linewidth, **params)
|
|
390
|
+
sb.boxplot(data=df, orient=orient, ax=ax, linewidth=linewidth, **params) # type: ignore
|
|
352
391
|
|
|
353
|
-
finalize_plot(ax, callback, outparams, save_path, True, title)
|
|
392
|
+
finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
|
|
354
393
|
|
|
355
394
|
|
|
356
395
|
# ===================================================================
|
|
@@ -372,7 +411,7 @@ def pvalue1_anotation(
|
|
|
372
411
|
save_path: str | None = None,
|
|
373
412
|
callback: Callable | None = None,
|
|
374
413
|
ax: Axes | None = None,
|
|
375
|
-
**params
|
|
414
|
+
**params,
|
|
376
415
|
) -> None:
|
|
377
416
|
"""
|
|
378
417
|
boxplot의 wrapper 함수로, 상자그림에 p-value 주석을 추가한다.
|
|
@@ -395,7 +434,7 @@ def pvalue1_anotation(
|
|
|
395
434
|
save_path=save_path,
|
|
396
435
|
callback=callback,
|
|
397
436
|
ax=ax,
|
|
398
|
-
**params
|
|
437
|
+
**params,
|
|
399
438
|
)
|
|
400
439
|
|
|
401
440
|
|
|
@@ -452,7 +491,9 @@ def kdeplot(
|
|
|
452
491
|
# 사분위수 분할 전용 처리 (1D KDE만 지원)
|
|
453
492
|
if quartile_split:
|
|
454
493
|
if yname is not None:
|
|
455
|
-
raise ValueError(
|
|
494
|
+
raise ValueError(
|
|
495
|
+
"quartile_split은 1차원 KDE(xname)에서만 사용할 수 있습니다."
|
|
496
|
+
)
|
|
456
497
|
|
|
457
498
|
series = df[xname].dropna()
|
|
458
499
|
if series.empty:
|
|
@@ -499,7 +540,7 @@ def kdeplot(
|
|
|
499
540
|
return
|
|
500
541
|
|
|
501
542
|
if ax is None:
|
|
502
|
-
fig, ax = get_default_ax(width, height, 1, 1, dpi)
|
|
543
|
+
fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
|
|
503
544
|
outparams = True
|
|
504
545
|
|
|
505
546
|
# 기본 kwargs 설정
|
|
@@ -529,7 +570,7 @@ def kdeplot(
|
|
|
529
570
|
|
|
530
571
|
sb.kdeplot(**kdeplot_kwargs)
|
|
531
572
|
|
|
532
|
-
finalize_plot(ax, callback, outparams, save_path, True, title)
|
|
573
|
+
finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
|
|
533
574
|
|
|
534
575
|
|
|
535
576
|
# ===================================================================
|
|
@@ -576,7 +617,7 @@ def histplot(
|
|
|
576
617
|
outparams = False
|
|
577
618
|
|
|
578
619
|
if ax is None:
|
|
579
|
-
fig, ax = get_default_ax(width, height, 1, 1, dpi)
|
|
620
|
+
fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
|
|
580
621
|
outparams = True
|
|
581
622
|
|
|
582
623
|
if bins:
|
|
@@ -604,7 +645,7 @@ def histplot(
|
|
|
604
645
|
"hue": hue,
|
|
605
646
|
"kde": kde,
|
|
606
647
|
"linewidth": linewidth,
|
|
607
|
-
"ax": ax
|
|
648
|
+
"ax": ax,
|
|
608
649
|
}
|
|
609
650
|
|
|
610
651
|
if hue is not None and palette is not None:
|
|
@@ -615,7 +656,7 @@ def histplot(
|
|
|
615
656
|
histplot_kwargs.update(params)
|
|
616
657
|
sb.histplot(**histplot_kwargs)
|
|
617
658
|
|
|
618
|
-
finalize_plot(ax, callback, outparams, save_path, True, title)
|
|
659
|
+
finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
|
|
619
660
|
|
|
620
661
|
|
|
621
662
|
# ===================================================================
|
|
@@ -658,7 +699,7 @@ def stackplot(
|
|
|
658
699
|
outparams = False
|
|
659
700
|
|
|
660
701
|
if ax is None:
|
|
661
|
-
fig, ax = get_default_ax(width, height, 1, 1, dpi)
|
|
702
|
+
fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
|
|
662
703
|
outparams = True
|
|
663
704
|
|
|
664
705
|
df2 = df[[xname, hue]].copy()
|
|
@@ -685,11 +726,11 @@ def stackplot(
|
|
|
685
726
|
sb.histplot(**stackplot_kwargs)
|
|
686
727
|
|
|
687
728
|
# 그래프의 x축 항목 수 만큼 반복
|
|
688
|
-
for p in ax.patches:
|
|
729
|
+
for p in ax.patches: # type: ignore
|
|
689
730
|
# 각 막대의 위치, 넓이, 높이
|
|
690
|
-
left, bottom, width, height = p.get_bbox().bounds
|
|
731
|
+
left, bottom, width, height = p.get_bbox().bounds # type: ignore
|
|
691
732
|
# 막대의 중앙에 글자 표시하기
|
|
692
|
-
ax.annotate(
|
|
733
|
+
ax.annotate( # type: ignore
|
|
693
734
|
"%0.1f%%" % (height * 100),
|
|
694
735
|
xy=(left + width / 2, bottom + height / 2),
|
|
695
736
|
ha="center",
|
|
@@ -698,10 +739,10 @@ def stackplot(
|
|
|
698
739
|
|
|
699
740
|
if str(df[xname].dtype) in ["int", "int32", "int64", "float", "float32", "float64"]:
|
|
700
741
|
xticks = list(df[xname].unique())
|
|
701
|
-
ax.set_xticks(xticks)
|
|
742
|
+
ax.set_xticks(xticks) # type: ignore
|
|
702
743
|
ax.set_xticklabels(xticks) # type: ignore
|
|
703
744
|
|
|
704
|
-
finalize_plot(ax, callback, outparams, save_path, True, title)
|
|
745
|
+
finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
|
|
705
746
|
|
|
706
747
|
|
|
707
748
|
# ===================================================================
|
|
@@ -750,10 +791,9 @@ def scatterplot(
|
|
|
750
791
|
outparams = False
|
|
751
792
|
|
|
752
793
|
if ax is None:
|
|
753
|
-
fig, ax = get_default_ax(width, height, 1, 1, dpi)
|
|
794
|
+
fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
|
|
754
795
|
outparams = True
|
|
755
796
|
|
|
756
|
-
|
|
757
797
|
if outline and hue is not None:
|
|
758
798
|
# 군집별 값의 종류별로 반복 수행
|
|
759
799
|
for c in df[hue].unique():
|
|
@@ -770,8 +810,11 @@ def scatterplot(
|
|
|
770
810
|
# 마지막 좌표 이후에 첫 번째 좌표를 연결
|
|
771
811
|
points = np.append(hull.vertices, hull.vertices[0])
|
|
772
812
|
|
|
773
|
-
ax.plot(
|
|
774
|
-
df_c.iloc[points, 0],
|
|
813
|
+
ax.plot( # type: ignore
|
|
814
|
+
df_c.iloc[points, 0],
|
|
815
|
+
df_c.iloc[points, 1],
|
|
816
|
+
linewidth=linewidth,
|
|
817
|
+
linestyle=":",
|
|
775
818
|
)
|
|
776
819
|
ax.fill(df_c.iloc[points, 0], df_c.iloc[points, 1], alpha=0.1) # type: ignore
|
|
777
820
|
except:
|
|
@@ -798,27 +841,26 @@ def scatterplot(
|
|
|
798
841
|
sb.scatterplot(data=df, **scatterplot_kwargs)
|
|
799
842
|
else:
|
|
800
843
|
# 핵심벡터
|
|
801
|
-
scatterplot_kwargs[
|
|
844
|
+
scatterplot_kwargs["edgecolor"] = "#ffffff"
|
|
802
845
|
sb.scatterplot(data=df[df[vector] == "core"], **scatterplot_kwargs)
|
|
803
846
|
|
|
804
847
|
# 외곽백터
|
|
805
|
-
scatterplot_kwargs[
|
|
806
|
-
scatterplot_kwargs[
|
|
807
|
-
scatterplot_kwargs[
|
|
808
|
-
scatterplot_kwargs[
|
|
848
|
+
scatterplot_kwargs["edgecolor"] = "#000000"
|
|
849
|
+
scatterplot_kwargs["s"] = 25
|
|
850
|
+
scatterplot_kwargs["marker"] = "^"
|
|
851
|
+
scatterplot_kwargs["linewidth"] = 0.8
|
|
809
852
|
sb.scatterplot(data=df[df[vector] == "border"], **scatterplot_kwargs)
|
|
810
853
|
|
|
811
854
|
# 노이즈벡터
|
|
812
|
-
scatterplot_kwargs[
|
|
813
|
-
scatterplot_kwargs[
|
|
814
|
-
scatterplot_kwargs[
|
|
815
|
-
scatterplot_kwargs[
|
|
816
|
-
scatterplot_kwargs[
|
|
817
|
-
scatterplot_kwargs[
|
|
855
|
+
scatterplot_kwargs["edgecolor"] = None
|
|
856
|
+
scatterplot_kwargs["s"] = 25
|
|
857
|
+
scatterplot_kwargs["marker"] = "x"
|
|
858
|
+
scatterplot_kwargs["linewidth"] = 2
|
|
859
|
+
scatterplot_kwargs["color"] = "#ff0000"
|
|
860
|
+
scatterplot_kwargs["hue"] = None
|
|
818
861
|
sb.scatterplot(data=df[df[vector] == "noise"], **scatterplot_kwargs)
|
|
819
862
|
|
|
820
|
-
|
|
821
|
-
finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
|
|
863
|
+
finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
|
|
822
864
|
|
|
823
865
|
|
|
824
866
|
# ===================================================================
|
|
@@ -861,7 +903,7 @@ def regplot(
|
|
|
861
903
|
outparams = False
|
|
862
904
|
|
|
863
905
|
if ax is None:
|
|
864
|
-
fig, ax = get_default_ax(width, height, 1, 1, dpi)
|
|
906
|
+
fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
|
|
865
907
|
outparams = True
|
|
866
908
|
|
|
867
909
|
# regplot은 hue를 지원하지 않으므로 palette를 color로 변환
|
|
@@ -877,13 +919,9 @@ def regplot(
|
|
|
877
919
|
"s": 20,
|
|
878
920
|
"linewidths": 0.5,
|
|
879
921
|
"edgecolor": "w",
|
|
880
|
-
"color": scatter_color
|
|
881
|
-
},
|
|
882
|
-
"line_kws": {
|
|
883
|
-
"color": "red",
|
|
884
|
-
"linestyle": "--",
|
|
885
|
-
"linewidth": linewidth
|
|
922
|
+
"color": scatter_color,
|
|
886
923
|
},
|
|
924
|
+
"line_kws": {"color": "red", "linestyle": "--", "linewidth": linewidth},
|
|
887
925
|
"ax": ax,
|
|
888
926
|
}
|
|
889
927
|
|
|
@@ -891,7 +929,7 @@ def regplot(
|
|
|
891
929
|
|
|
892
930
|
sb.regplot(**regplot_kwargs)
|
|
893
931
|
|
|
894
|
-
finalize_plot(ax, callback, outparams, save_path, True, title)
|
|
932
|
+
finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
|
|
895
933
|
|
|
896
934
|
|
|
897
935
|
# ===================================================================
|
|
@@ -951,19 +989,19 @@ def lmplot(
|
|
|
951
989
|
# 회귀선에 linewidth 적용
|
|
952
990
|
for ax in g.axes.flat:
|
|
953
991
|
for line in ax.get_lines():
|
|
954
|
-
if line.get_marker() ==
|
|
992
|
+
if line.get_marker() == "o": # 산점도는 건너뛰기
|
|
955
993
|
continue
|
|
956
994
|
line.set_linewidth(linewidth)
|
|
957
995
|
|
|
958
996
|
g.fig.grid(True, alpha=config.grid_alpha, linewidth=config.grid_width) # type: ignore
|
|
959
997
|
|
|
960
998
|
if title:
|
|
961
|
-
g.fig.suptitle(title, fontsize=config.font_size * 1.5, fontweight=
|
|
999
|
+
g.fig.suptitle(title, fontsize=config.font_size * 1.5, fontweight="bold")
|
|
962
1000
|
|
|
963
1001
|
plt.tight_layout()
|
|
964
1002
|
|
|
965
1003
|
if save_path is not None:
|
|
966
|
-
plt.savefig(save_path, dpi=dpi*2, bbox_inches=
|
|
1004
|
+
plt.savefig(save_path, dpi=dpi * 2, bbox_inches="tight")
|
|
967
1005
|
|
|
968
1006
|
plt.show()
|
|
969
1007
|
plt.close()
|
|
@@ -1012,7 +1050,7 @@ def pairplot(
|
|
|
1012
1050
|
if xnames is None:
|
|
1013
1051
|
# 모든 연속형(숫자형) 컬럼 선택 (명목형/카테고리 제외)
|
|
1014
1052
|
numeric_cols = df.select_dtypes(include=[np.number]).columns
|
|
1015
|
-
target_cols = [col for col in numeric_cols if df[col].dtype.name !=
|
|
1053
|
+
target_cols = [col for col in numeric_cols if df[col].dtype.name != "category"]
|
|
1016
1054
|
elif isinstance(xnames, str):
|
|
1017
1055
|
# 문자열: 해당 컬럼만
|
|
1018
1056
|
target_cols = [xnames]
|
|
@@ -1022,7 +1060,7 @@ def pairplot(
|
|
|
1022
1060
|
else:
|
|
1023
1061
|
# 기본값으로 연속형 컬럼
|
|
1024
1062
|
numeric_cols = df.select_dtypes(include=[np.number]).columns
|
|
1025
|
-
target_cols = [col for col in numeric_cols if df[col].dtype.name !=
|
|
1063
|
+
target_cols = [col for col in numeric_cols if df[col].dtype.name != "category"]
|
|
1026
1064
|
|
|
1027
1065
|
# hue 컬럼이 있으면 target_cols에 포함시키기 (pairplot 자체에서 필요)
|
|
1028
1066
|
if hue is not None and hue not in target_cols:
|
|
@@ -1050,9 +1088,11 @@ def pairplot(
|
|
|
1050
1088
|
g.fig.set_dpi(dpi)
|
|
1051
1089
|
|
|
1052
1090
|
if title:
|
|
1053
|
-
g.fig.suptitle(title, fontsize=config.font_size * 1.5, fontweight=
|
|
1091
|
+
g.fig.suptitle(title, fontsize=config.font_size * 1.5, fontweight="bold")
|
|
1054
1092
|
|
|
1055
|
-
g.map_lower(
|
|
1093
|
+
g.map_lower(
|
|
1094
|
+
func=sb.kdeplot, fill=True, alpha=config.fill_alpha, linewidth=linewidth
|
|
1095
|
+
)
|
|
1056
1096
|
g.map_upper(func=sb.scatterplot, linewidth=linewidth)
|
|
1057
1097
|
|
|
1058
1098
|
# KDE 대각선에도 linewidth 적용
|
|
@@ -1063,7 +1103,7 @@ def pairplot(
|
|
|
1063
1103
|
plt.tight_layout()
|
|
1064
1104
|
|
|
1065
1105
|
if save_path is not None:
|
|
1066
|
-
plt.savefig(save_path, dpi=dpi*2, bbox_inches=
|
|
1106
|
+
plt.savefig(save_path, dpi=dpi * 2, bbox_inches="tight")
|
|
1067
1107
|
|
|
1068
1108
|
plt.show()
|
|
1069
1109
|
plt.close()
|
|
@@ -1117,7 +1157,7 @@ def countplot(
|
|
|
1117
1157
|
sort = sorted(list(df[xname].value_counts().index))
|
|
1118
1158
|
|
|
1119
1159
|
if ax is None:
|
|
1120
|
-
fig, ax = get_default_ax(width, height, 1, 1, dpi)
|
|
1160
|
+
fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
|
|
1121
1161
|
outparams = True
|
|
1122
1162
|
|
|
1123
1163
|
# hue가 있을 때만 palette 사용, 없으면 color 사용
|
|
@@ -1140,7 +1180,7 @@ def countplot(
|
|
|
1140
1180
|
|
|
1141
1181
|
sb.countplot(**countplot_kwargs)
|
|
1142
1182
|
|
|
1143
|
-
finalize_plot(ax, callback, outparams, save_path, True, title)
|
|
1183
|
+
finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
|
|
1144
1184
|
|
|
1145
1185
|
|
|
1146
1186
|
# ===================================================================
|
|
@@ -1185,7 +1225,7 @@ def barplot(
|
|
|
1185
1225
|
outparams = False
|
|
1186
1226
|
|
|
1187
1227
|
if ax is None:
|
|
1188
|
-
fig, ax = get_default_ax(width, height, 1, 1, dpi)
|
|
1228
|
+
fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
|
|
1189
1229
|
outparams = True
|
|
1190
1230
|
|
|
1191
1231
|
# hue가 있을 때만 palette 사용, 없으면 color 사용
|
|
@@ -1206,7 +1246,7 @@ def barplot(
|
|
|
1206
1246
|
barplot_kwargs.update(params)
|
|
1207
1247
|
|
|
1208
1248
|
sb.barplot(**barplot_kwargs)
|
|
1209
|
-
finalize_plot(ax, callback, outparams, save_path, True, title)
|
|
1249
|
+
finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
|
|
1210
1250
|
|
|
1211
1251
|
|
|
1212
1252
|
# ===================================================================
|
|
@@ -1251,7 +1291,7 @@ def boxenplot(
|
|
|
1251
1291
|
outparams = False
|
|
1252
1292
|
|
|
1253
1293
|
if ax is None:
|
|
1254
|
-
fig, ax = get_default_ax(width, height, 1, 1, dpi)
|
|
1294
|
+
fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
|
|
1255
1295
|
outparams = True
|
|
1256
1296
|
|
|
1257
1297
|
# palette은 hue가 있을 때만 사용
|
|
@@ -1270,7 +1310,7 @@ def boxenplot(
|
|
|
1270
1310
|
boxenplot_kwargs.update(params)
|
|
1271
1311
|
|
|
1272
1312
|
sb.boxenplot(**boxenplot_kwargs)
|
|
1273
|
-
finalize_plot(ax, callback, outparams, save_path, True, title)
|
|
1313
|
+
finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
|
|
1274
1314
|
|
|
1275
1315
|
|
|
1276
1316
|
# ===================================================================
|
|
@@ -1315,7 +1355,7 @@ def violinplot(
|
|
|
1315
1355
|
outparams = False
|
|
1316
1356
|
|
|
1317
1357
|
if ax is None:
|
|
1318
|
-
fig, ax = get_default_ax(width, height, 1, 1, dpi)
|
|
1358
|
+
fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
|
|
1319
1359
|
outparams = True
|
|
1320
1360
|
|
|
1321
1361
|
# palette은 hue가 있을 때만 사용
|
|
@@ -1333,7 +1373,7 @@ def violinplot(
|
|
|
1333
1373
|
|
|
1334
1374
|
violinplot_kwargs.update(params)
|
|
1335
1375
|
sb.violinplot(**violinplot_kwargs)
|
|
1336
|
-
finalize_plot(ax, callback, outparams, save_path, True, title)
|
|
1376
|
+
finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
|
|
1337
1377
|
|
|
1338
1378
|
|
|
1339
1379
|
# ===================================================================
|
|
@@ -1378,7 +1418,7 @@ def pointplot(
|
|
|
1378
1418
|
outparams = False
|
|
1379
1419
|
|
|
1380
1420
|
if ax is None:
|
|
1381
|
-
fig, ax = get_default_ax(width, height, 1, 1, dpi)
|
|
1421
|
+
fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
|
|
1382
1422
|
outparams = True
|
|
1383
1423
|
|
|
1384
1424
|
# hue가 있을 때만 palette 사용, 없으면 color 사용
|
|
@@ -1398,7 +1438,7 @@ def pointplot(
|
|
|
1398
1438
|
|
|
1399
1439
|
pointplot_kwargs.update(params)
|
|
1400
1440
|
sb.pointplot(**pointplot_kwargs)
|
|
1401
|
-
finalize_plot(ax, callback, outparams, save_path, True, title)
|
|
1441
|
+
finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
|
|
1402
1442
|
|
|
1403
1443
|
|
|
1404
1444
|
# ===================================================================
|
|
@@ -1456,7 +1496,7 @@ def jointplot(
|
|
|
1456
1496
|
g.fig.set_dpi(dpi)
|
|
1457
1497
|
|
|
1458
1498
|
if title:
|
|
1459
|
-
g.fig.suptitle(title, fontsize=config.font_size * 1.5, fontweight=
|
|
1499
|
+
g.fig.suptitle(title, fontsize=config.font_size * 1.5, fontweight="bold")
|
|
1460
1500
|
|
|
1461
1501
|
# 중앙 및 주변 플롯에 grid 추가
|
|
1462
1502
|
g.ax_joint.grid(True, alpha=config.grid_alpha, linewidth=config.grid_width)
|
|
@@ -1466,7 +1506,7 @@ def jointplot(
|
|
|
1466
1506
|
plt.tight_layout()
|
|
1467
1507
|
|
|
1468
1508
|
if save_path is not None:
|
|
1469
|
-
plt.savefig(save_path, dpi=dpi*2, bbox_inches=
|
|
1509
|
+
plt.savefig(save_path, dpi=dpi * 2, bbox_inches="tight")
|
|
1470
1510
|
|
|
1471
1511
|
plt.show()
|
|
1472
1512
|
plt.close()
|
|
@@ -1509,7 +1549,7 @@ def heatmap(
|
|
|
1509
1549
|
|
|
1510
1550
|
if width == None or height == None:
|
|
1511
1551
|
width = (config.font_size * config.dpi / 72) * 4.5 * len(data.columns)
|
|
1512
|
-
height = width * 0.8
|
|
1552
|
+
height = width * 0.8 # type: ignore
|
|
1513
1553
|
|
|
1514
1554
|
if ax is None:
|
|
1515
1555
|
fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
|
|
@@ -1522,7 +1562,7 @@ def heatmap(
|
|
|
1522
1562
|
"fmt": ".2f",
|
|
1523
1563
|
"ax": ax,
|
|
1524
1564
|
"linewidths": linewidth,
|
|
1525
|
-
"annot_kws": {"size": 10}
|
|
1565
|
+
"annot_kws": {"size": 10},
|
|
1526
1566
|
}
|
|
1527
1567
|
|
|
1528
1568
|
heatmatp_kwargs.update(params)
|
|
@@ -1533,7 +1573,6 @@ def heatmap(
|
|
|
1533
1573
|
finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
|
|
1534
1574
|
|
|
1535
1575
|
|
|
1536
|
-
|
|
1537
1576
|
# ===================================================================
|
|
1538
1577
|
# KDE와 신뢰구간을 나타낸 그래프를 그린다
|
|
1539
1578
|
# ===================================================================
|
|
@@ -1619,19 +1658,30 @@ def kde_confidence_interval(
|
|
|
1619
1658
|
cmin, cmax = t.interval(clevel, dof, loc=sample_mean, scale=sample_std_error)
|
|
1620
1659
|
|
|
1621
1660
|
# 현재 컬럼에 대한 커널밀도추정
|
|
1622
|
-
sb.kdeplot(data=column, linewidth=linewidth, ax=current_ax, fill=fill, alpha=config.fill_alpha)
|
|
1661
|
+
sb.kdeplot(data=column, linewidth=linewidth, ax=current_ax, fill=fill, alpha=config.fill_alpha) # type: ignore
|
|
1623
1662
|
|
|
1624
1663
|
# 그래프 축의 범위
|
|
1625
1664
|
xmin, xmax, ymin, ymax = current_ax.get_position().bounds
|
|
1626
1665
|
ymin_val, ymax_val = 0, current_ax.get_ylim()[1]
|
|
1627
1666
|
|
|
1628
1667
|
# 신뢰구간 그리기
|
|
1629
|
-
current_ax.plot(
|
|
1630
|
-
|
|
1631
|
-
|
|
1668
|
+
current_ax.plot(
|
|
1669
|
+
[cmin, cmin], [ymin_val, ymax_val], linestyle=":", linewidth=linewidth * 0.5
|
|
1670
|
+
)
|
|
1671
|
+
current_ax.plot(
|
|
1672
|
+
[cmax, cmax], [ymin_val, ymax_val], linestyle=":", linewidth=linewidth * 0.5
|
|
1673
|
+
)
|
|
1674
|
+
current_ax.fill_between(
|
|
1675
|
+
[cmin, cmax], y1=ymin_val, y2=ymax_val, alpha=config.fill_alpha
|
|
1676
|
+
)
|
|
1632
1677
|
|
|
1633
1678
|
# 평균 그리기
|
|
1634
|
-
current_ax.plot(
|
|
1679
|
+
current_ax.plot(
|
|
1680
|
+
[sample_mean, sample_mean],
|
|
1681
|
+
[0, ymax_val],
|
|
1682
|
+
linestyle="--",
|
|
1683
|
+
linewidth=linewidth,
|
|
1684
|
+
)
|
|
1635
1685
|
|
|
1636
1686
|
current_ax.text(
|
|
1637
1687
|
x=(cmax - cmin) / 2 + cmin,
|
|
@@ -1644,8 +1694,7 @@ def kde_confidence_interval(
|
|
|
1644
1694
|
|
|
1645
1695
|
current_ax.grid(True, alpha=config.grid_alpha, linewidth=config.grid_width)
|
|
1646
1696
|
|
|
1647
|
-
finalize_plot(axes[0] if isinstance(axes, list) and len(axes) > 0 else ax, callback, outparams, save_path, True, title)
|
|
1648
|
-
|
|
1697
|
+
finalize_plot(axes[0] if isinstance(axes, list) and len(axes) > 0 else ax, callback, outparams, save_path, True, title) # type: ignore
|
|
1649
1698
|
|
|
1650
1699
|
|
|
1651
1700
|
# ===================================================================
|
|
@@ -1711,32 +1760,29 @@ def ols_residplot(
|
|
|
1711
1760
|
sb.scatterplot(x=y_pred, y=resid, ax=ax, s=20, edgecolor="white", **params)
|
|
1712
1761
|
|
|
1713
1762
|
# 기준선 (잔차 = 0)
|
|
1714
|
-
ax.axhline(0, color="gray", linestyle="--", linewidth=linewidth*0.7)
|
|
1763
|
+
ax.axhline(0, color="gray", linestyle="--", linewidth=linewidth * 0.7) # type: ignore
|
|
1715
1764
|
|
|
1716
1765
|
# LOWESS 스무딩 (선택적)
|
|
1717
1766
|
if lowess:
|
|
1718
1767
|
lowess_result = sm_lowess(resid, y_pred, frac=0.6667)
|
|
1719
|
-
ax.plot(
|
|
1720
|
-
|
|
1768
|
+
ax.plot( # type: ignore
|
|
1769
|
+
lowess_result[:, 0],
|
|
1770
|
+
lowess_result[:, 1], # type: ignore
|
|
1771
|
+
color="red",
|
|
1772
|
+
linewidth=linewidth,
|
|
1773
|
+
label="LOWESS",
|
|
1774
|
+
) # type: ignore
|
|
1721
1775
|
|
|
1722
1776
|
ax.set_xlabel("Fitted values") # type: ignore
|
|
1723
|
-
ax.set_ylabel("Residuals")
|
|
1777
|
+
ax.set_ylabel("Residuals") # type: ignore
|
|
1724
1778
|
|
|
1725
1779
|
if mse:
|
|
1726
1780
|
mse_val = mean_squared_error(y, y_pred)
|
|
1727
1781
|
mse_sq = np.sqrt(mse_val)
|
|
1728
1782
|
|
|
1729
1783
|
r1 = resid[(resid > -mse_sq) & (resid < mse_sq)].size / resid.size * 100
|
|
1730
|
-
r2 = (
|
|
1731
|
-
|
|
1732
|
-
/ resid.size
|
|
1733
|
-
* 100
|
|
1734
|
-
)
|
|
1735
|
-
r3 = (
|
|
1736
|
-
resid[(resid > -3 * mse_sq) & (resid < 3 * mse_sq)].size
|
|
1737
|
-
/ resid.size
|
|
1738
|
-
* 100
|
|
1739
|
-
)
|
|
1784
|
+
r2 = resid[(resid > -2 * mse_sq) & (resid < 2 * mse_sq)].size / resid.size * 100
|
|
1785
|
+
r3 = resid[(resid > -3 * mse_sq) & (resid < 3 * mse_sq)].size / resid.size * 100
|
|
1740
1786
|
|
|
1741
1787
|
mse_r = [r1, r2, r3]
|
|
1742
1788
|
|
|
@@ -1747,26 +1793,26 @@ def ols_residplot(
|
|
|
1747
1793
|
alphas = [0.15, 0.10, 0.05] # 안쪽이 더 진하게
|
|
1748
1794
|
|
|
1749
1795
|
# 3σ 영역 (가장 바깥쪽, 가장 연함)
|
|
1750
|
-
ax.axhspan(-3 * mse_sq, 3 * mse_sq, facecolor=colors[2], alpha=alphas[2], zorder=0)
|
|
1796
|
+
ax.axhspan(-3 * mse_sq, 3 * mse_sq, facecolor=colors[2], alpha=alphas[2], zorder=0) # type: ignore
|
|
1751
1797
|
# 2σ 영역 (중간)
|
|
1752
|
-
ax.axhspan(-2 * mse_sq, 2 * mse_sq, facecolor=colors[1], alpha=alphas[1], zorder=1)
|
|
1798
|
+
ax.axhspan(-2 * mse_sq, 2 * mse_sq, facecolor=colors[1], alpha=alphas[1], zorder=1) # type: ignore
|
|
1753
1799
|
# 1σ 영역 (가장 안쪽, 가장 진함)
|
|
1754
|
-
ax.axhspan(-mse_sq, mse_sq, facecolor=colors[0], alpha=alphas[0], zorder=2)
|
|
1800
|
+
ax.axhspan(-mse_sq, mse_sq, facecolor=colors[0], alpha=alphas[0], zorder=2) # type: ignore
|
|
1755
1801
|
|
|
1756
1802
|
# 경계선 그리기
|
|
1757
1803
|
for i, c in enumerate(["red", "green", "blue"]):
|
|
1758
|
-
ax.axhline(mse_sq * (i + 1), color=c, linestyle="--", linewidth=linewidth/2)
|
|
1759
|
-
ax.axhline(mse_sq * (-(i + 1)), color=c, linestyle="--", linewidth=linewidth/2)
|
|
1804
|
+
ax.axhline(mse_sq * (i + 1), color=c, linestyle="--", linewidth=linewidth / 2) # type: ignore
|
|
1805
|
+
ax.axhline(mse_sq * (-(i + 1)), color=c, linestyle="--", linewidth=linewidth / 2) # type: ignore
|
|
1760
1806
|
|
|
1761
1807
|
target = [68, 95, 99.7]
|
|
1762
1808
|
for i, c in enumerate(["red", "green", "blue"]):
|
|
1763
|
-
ax.text(
|
|
1809
|
+
ax.text( # type: ignore
|
|
1764
1810
|
s=f"{i+1} sqrt(MSE) = {mse_r[i]:.2f}% ({mse_r[i] - target[i]:.2f}%)",
|
|
1765
1811
|
x=xmax + 0.05,
|
|
1766
1812
|
y=(i + 1) * mse_sq,
|
|
1767
1813
|
color=c,
|
|
1768
1814
|
)
|
|
1769
|
-
ax.text(
|
|
1815
|
+
ax.text( # type: ignore
|
|
1770
1816
|
s=f"-{i+1} sqrt(MSE) = {mse_r[i]:.2f}% ({mse_r[i] - target[i]:.2f}%)",
|
|
1771
1817
|
x=xmax + 0.05,
|
|
1772
1818
|
y=-(i + 1) * mse_sq,
|
|
@@ -1782,7 +1828,7 @@ def ols_residplot(
|
|
|
1782
1828
|
def ols_qqplot(
|
|
1783
1829
|
fit,
|
|
1784
1830
|
title: str | None = None,
|
|
1785
|
-
line: str =
|
|
1831
|
+
line: str = "s",
|
|
1786
1832
|
width: int = config.width,
|
|
1787
1833
|
height: int = config.height,
|
|
1788
1834
|
linewidth: float = config.line_width,
|
|
@@ -1834,32 +1880,32 @@ def ols_qqplot(
|
|
|
1834
1880
|
outparams = False
|
|
1835
1881
|
|
|
1836
1882
|
if ax is None:
|
|
1837
|
-
fig, ax = get_default_ax(width, height, 1, 1, dpi)
|
|
1883
|
+
fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
|
|
1838
1884
|
outparams = True
|
|
1839
1885
|
|
|
1840
1886
|
# fit 객체에서 잔차(residuals) 추출
|
|
1841
1887
|
residuals = fit.resid
|
|
1842
1888
|
|
|
1843
1889
|
# markersize 기본값 설정 (기존 크기의 2/3)
|
|
1844
|
-
if
|
|
1845
|
-
params[
|
|
1890
|
+
if "markersize" not in params:
|
|
1891
|
+
params["markersize"] = 2
|
|
1846
1892
|
|
|
1847
1893
|
# statsmodels의 qqplot 사용 (더 전문적이고 최적화된 구현)
|
|
1848
1894
|
# line 옵션으로 다양한 참조선 지원
|
|
1849
1895
|
sm_qqplot(residuals, line=line, ax=ax, **params)
|
|
1850
1896
|
|
|
1851
1897
|
# 점의 스타일 개선: 연한 내부, 진한 테두리
|
|
1852
|
-
for collection in ax.collections:
|
|
1898
|
+
for collection in ax.collections: # type: ignore
|
|
1853
1899
|
# PathCollection (scatter plot의 점들)
|
|
1854
|
-
collection.set_facecolor(
|
|
1855
|
-
collection.set_edgecolor(
|
|
1900
|
+
collection.set_facecolor("#4A90E2") # 연한 파란색 내부
|
|
1901
|
+
collection.set_edgecolor("#1E3A8A") # 진한 파란색 테두리
|
|
1856
1902
|
collection.set_linewidth(0.8) # 테두리 굵기
|
|
1857
1903
|
collection.set_alpha(0.7) # 약간의 투명도
|
|
1858
1904
|
|
|
1859
1905
|
# 선 굵기 조정
|
|
1860
|
-
for line in ax.get_lines():
|
|
1861
|
-
if line.get_linestyle() ==
|
|
1862
|
-
line.set_linewidth(linewidth)
|
|
1906
|
+
for line in ax.get_lines(): # type: ignore
|
|
1907
|
+
if line.get_linestyle() == "--" or line.get_color() == "r": # type: ignore
|
|
1908
|
+
line.set_linewidth(linewidth) # type: ignore
|
|
1863
1909
|
|
|
1864
1910
|
finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
|
|
1865
1911
|
|
|
@@ -1904,7 +1950,7 @@ def distribution_by_class(
|
|
|
1904
1950
|
None
|
|
1905
1951
|
"""
|
|
1906
1952
|
if xnames is None:
|
|
1907
|
-
xnames = data.columns
|
|
1953
|
+
xnames = data.columns # type: ignore
|
|
1908
1954
|
|
|
1909
1955
|
for i, v in enumerate(xnames): # type: ignore
|
|
1910
1956
|
# 종속변수이거나 숫자형이 아닌 경우는 제외
|
|
@@ -1930,14 +1976,14 @@ def distribution_by_class(
|
|
|
1930
1976
|
linewidth=linewidth,
|
|
1931
1977
|
dpi=dpi,
|
|
1932
1978
|
callback=callback,
|
|
1933
|
-
save_path=save_path
|
|
1979
|
+
save_path=save_path,
|
|
1934
1980
|
)
|
|
1935
1981
|
elif type == "hist":
|
|
1936
1982
|
histplot(
|
|
1937
1983
|
df=data,
|
|
1938
1984
|
xname=v,
|
|
1939
1985
|
hue=hue,
|
|
1940
|
-
bins=bins,
|
|
1986
|
+
bins=bins, # type: ignore
|
|
1941
1987
|
kde=False,
|
|
1942
1988
|
palette=palette,
|
|
1943
1989
|
width=width,
|
|
@@ -1945,14 +1991,14 @@ def distribution_by_class(
|
|
|
1945
1991
|
linewidth=linewidth,
|
|
1946
1992
|
dpi=dpi,
|
|
1947
1993
|
callback=callback,
|
|
1948
|
-
save_path=save_path
|
|
1994
|
+
save_path=save_path,
|
|
1949
1995
|
)
|
|
1950
1996
|
elif type == "histkde":
|
|
1951
1997
|
histplot(
|
|
1952
1998
|
df=data,
|
|
1953
1999
|
xname=v,
|
|
1954
2000
|
hue=hue,
|
|
1955
|
-
bins=bins,
|
|
2001
|
+
bins=bins, # type: ignore
|
|
1956
2002
|
kde=True,
|
|
1957
2003
|
palette=palette,
|
|
1958
2004
|
width=width,
|
|
@@ -1960,7 +2006,7 @@ def distribution_by_class(
|
|
|
1960
2006
|
linewidth=linewidth,
|
|
1961
2007
|
dpi=dpi,
|
|
1962
2008
|
callback=callback,
|
|
1963
|
-
save_path=save_path
|
|
2009
|
+
save_path=save_path,
|
|
1964
2010
|
)
|
|
1965
2011
|
|
|
1966
2012
|
|
|
@@ -2027,7 +2073,7 @@ def scatter_by_class(
|
|
|
2027
2073
|
group = processed
|
|
2028
2074
|
|
|
2029
2075
|
for v in group:
|
|
2030
|
-
scatterplot(data=data, xname=v[0], yname=v[1], outline=outline, hue=hue, palette=palette, width=width, height=height, linewidth=linewidth, dpi=dpi, callback=callback, save_path=save_path)
|
|
2076
|
+
scatterplot(data=data, xname=v[0], yname=v[1], outline=outline, hue=hue, palette=palette, width=width, height=height, linewidth=linewidth, dpi=dpi, callback=callback, save_path=save_path) # type: ignore
|
|
2031
2077
|
|
|
2032
2078
|
|
|
2033
2079
|
# ===================================================================
|
|
@@ -2072,7 +2118,9 @@ def categorical_target_distribution(
|
|
|
2072
2118
|
|
|
2073
2119
|
# 명목형 컬럼 후보: object, category, bool
|
|
2074
2120
|
if hue is None:
|
|
2075
|
-
cat_cols = data.select_dtypes(
|
|
2121
|
+
cat_cols = data.select_dtypes(
|
|
2122
|
+
include=["object", "category", "bool", "boolean"]
|
|
2123
|
+
).columns
|
|
2076
2124
|
target_cols = [c for c in cat_cols if c != yname]
|
|
2077
2125
|
elif isinstance(hue, str):
|
|
2078
2126
|
target_cols = [hue]
|
|
@@ -2102,7 +2150,16 @@ def categorical_target_distribution(
|
|
|
2102
2150
|
plot_kwargs.update({"x": col, "y": yname, "palette": palette})
|
|
2103
2151
|
sb.violinplot(**plot_kwargs, linewidth=linewidth)
|
|
2104
2152
|
elif kind == "kde":
|
|
2105
|
-
plot_kwargs.update(
|
|
2153
|
+
plot_kwargs.update(
|
|
2154
|
+
{
|
|
2155
|
+
"x": yname,
|
|
2156
|
+
"hue": col,
|
|
2157
|
+
"palette": palette,
|
|
2158
|
+
"fill": kde_fill,
|
|
2159
|
+
"common_norm": False,
|
|
2160
|
+
"linewidth": linewidth,
|
|
2161
|
+
}
|
|
2162
|
+
)
|
|
2106
2163
|
sb.kdeplot(**plot_kwargs)
|
|
2107
2164
|
else: # box
|
|
2108
2165
|
plot_kwargs.update({"x": col, "y": yname, "hue": col, "palette": palette})
|
|
@@ -2156,7 +2213,7 @@ def roc_curve_plot(
|
|
|
2156
2213
|
"""
|
|
2157
2214
|
outparams = False
|
|
2158
2215
|
if ax is None:
|
|
2159
|
-
fig, ax = get_default_ax(width, height, 1, 1, dpi)
|
|
2216
|
+
fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
|
|
2160
2217
|
outparams = True
|
|
2161
2218
|
|
|
2162
2219
|
# 실제값(y_true) 결정
|
|
@@ -2177,15 +2234,15 @@ def roc_curve_plot(
|
|
|
2177
2234
|
roc_auc = auc(fpr, tpr)
|
|
2178
2235
|
|
|
2179
2236
|
# ROC 곡선 그리기
|
|
2180
|
-
ax.plot(fpr, tpr, color=
|
|
2181
|
-
ax.plot([0, 1], [0, 1], color=
|
|
2182
|
-
|
|
2183
|
-
ax.set_xlim([0.0, 1.0])
|
|
2184
|
-
ax.set_ylim([0.0, 1.05])
|
|
2185
|
-
ax.set_xlabel(
|
|
2186
|
-
ax.set_ylabel(
|
|
2187
|
-
ax.set_title(
|
|
2188
|
-
ax.legend(loc="lower right", fontsize=7)
|
|
2237
|
+
ax.plot(fpr, tpr, color="darkorange", lw=linewidth, label=f"ROC curve (AUC = {roc_auc:.4f})") # type: ignore
|
|
2238
|
+
ax.plot([0, 1], [0, 1], color="navy", lw=linewidth, linestyle="--", label="Random Classifier") # type: ignore
|
|
2239
|
+
|
|
2240
|
+
ax.set_xlim([0.0, 1.0]) # type: ignore
|
|
2241
|
+
ax.set_ylim([0.0, 1.05]) # type: ignore
|
|
2242
|
+
ax.set_xlabel("위양성율 (False Positive Rate)", fontsize=8) # type: ignore
|
|
2243
|
+
ax.set_ylabel("재현율 (True Positive Rate)", fontsize=8) # type: ignore
|
|
2244
|
+
ax.set_title("ROC 곡선", fontsize=10, fontweight="bold") # type: ignore
|
|
2245
|
+
ax.legend(loc="lower right", fontsize=7) # type: ignore
|
|
2189
2246
|
finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
|
|
2190
2247
|
|
|
2191
2248
|
|
|
@@ -2232,13 +2289,18 @@ def confusion_matrix_plot(
|
|
|
2232
2289
|
cm = confusion_matrix(y_true, y_pred)
|
|
2233
2290
|
|
|
2234
2291
|
# 혼동행렬 시각화
|
|
2235
|
-
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[
|
|
2292
|
+
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["음성", "양성"])
|
|
2236
2293
|
# 가독성을 위해 텍스트 크기/굵기 조정
|
|
2237
|
-
disp.plot(
|
|
2294
|
+
disp.plot(
|
|
2295
|
+
ax=ax,
|
|
2296
|
+
cmap="Blues",
|
|
2297
|
+
values_format="d",
|
|
2298
|
+
text_kw={"fontsize": 16, "weight": "bold"},
|
|
2299
|
+
)
|
|
2238
2300
|
|
|
2239
|
-
ax.set_title(f
|
|
2301
|
+
ax.set_title(f"혼동행렬 (임계값: {threshold})", fontsize=8, fontweight="bold") # type: ignore
|
|
2240
2302
|
|
|
2241
|
-
finalize_plot(ax, callback, outparams, save_path, False, title)
|
|
2303
|
+
finalize_plot(ax, callback, outparams, save_path, False, title) # type: ignore
|
|
2242
2304
|
|
|
2243
2305
|
|
|
2244
2306
|
# ===================================================================
|
|
@@ -2323,7 +2385,7 @@ def radarplot(
|
|
|
2323
2385
|
# Axes 생성 (polar projection)
|
|
2324
2386
|
if ax is None:
|
|
2325
2387
|
fig = plt.figure(figsize=(width / 100, height / 100), dpi=dpi)
|
|
2326
|
-
ax = fig.add_subplot(111, projection=
|
|
2388
|
+
ax = fig.add_subplot(111, projection="polar")
|
|
2327
2389
|
outparams = True
|
|
2328
2390
|
|
|
2329
2391
|
# 각도 계산
|
|
@@ -2345,8 +2407,15 @@ def radarplot(
|
|
|
2345
2407
|
color = colors[idx]
|
|
2346
2408
|
|
|
2347
2409
|
# 선 그리기
|
|
2348
|
-
ax.plot(
|
|
2349
|
-
|
|
2410
|
+
ax.plot(
|
|
2411
|
+
angles,
|
|
2412
|
+
values,
|
|
2413
|
+
"o-",
|
|
2414
|
+
linewidth=linewidth,
|
|
2415
|
+
label=str(label_name),
|
|
2416
|
+
color=color,
|
|
2417
|
+
**params,
|
|
2418
|
+
)
|
|
2350
2419
|
|
|
2351
2420
|
# 영역 채우기
|
|
2352
2421
|
if fill:
|
|
@@ -2362,15 +2431,15 @@ def radarplot(
|
|
|
2362
2431
|
|
|
2363
2432
|
# 범례
|
|
2364
2433
|
if len(labels) <= 10: # 너무 많으면 범례 생략
|
|
2365
|
-
ax.legend(loc=
|
|
2434
|
+
ax.legend(loc="upper right", bbox_to_anchor=(1.3, 1.1))
|
|
2366
2435
|
|
|
2367
2436
|
# 제목
|
|
2368
2437
|
if hue is not None:
|
|
2369
|
-
ax.set_title(f
|
|
2438
|
+
ax.set_title(f"Radar Chart by {hue}", pad=20)
|
|
2370
2439
|
else:
|
|
2371
|
-
ax.set_title(
|
|
2440
|
+
ax.set_title("Radar Chart", pad=20)
|
|
2372
2441
|
|
|
2373
|
-
finalize_plot(ax, callback, outparams, save_path, True, title)
|
|
2442
|
+
finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
|
|
2374
2443
|
|
|
2375
2444
|
|
|
2376
2445
|
# ===================================================================
|
|
@@ -2421,7 +2490,9 @@ def distribution_plot(
|
|
|
2421
2490
|
|
|
2422
2491
|
if hue is None:
|
|
2423
2492
|
# 1행 2열 서브플롯 생성
|
|
2424
|
-
fig, axes = get_default_ax(
|
|
2493
|
+
fig, axes = get_default_ax(
|
|
2494
|
+
width, height, rows=1, cols=2, dpi=dpi, title=title
|
|
2495
|
+
)
|
|
2425
2496
|
|
|
2426
2497
|
kde_confidence_interval(
|
|
2427
2498
|
data=data,
|
|
@@ -2432,17 +2503,10 @@ def distribution_plot(
|
|
|
2432
2503
|
)
|
|
2433
2504
|
|
|
2434
2505
|
if kind == "hist":
|
|
2435
|
-
histplot(
|
|
2436
|
-
df=data,
|
|
2437
|
-
xname=c,
|
|
2438
|
-
linewidth=linewidth,
|
|
2439
|
-
ax=axes[1]
|
|
2440
|
-
)
|
|
2506
|
+
histplot(df=data, xname=c, linewidth=linewidth, ax=axes[1])
|
|
2441
2507
|
else:
|
|
2442
2508
|
boxplot(
|
|
2443
|
-
df=data[column],
|
|
2444
|
-
linewidth=linewidth,
|
|
2445
|
-
ax=axes[1]
|
|
2509
|
+
df=data[column], linewidth=linewidth, ax=axes[1] # type: ignore
|
|
2446
2510
|
)
|
|
2447
2511
|
|
|
2448
2512
|
fig.suptitle(title, fontsize=14, y=1.02)
|
|
@@ -2453,7 +2517,9 @@ def distribution_plot(
|
|
|
2453
2517
|
categories = list(Series(data[hue].dropna().unique()).sort_values())
|
|
2454
2518
|
n_cat = len(categories) if categories else 1
|
|
2455
2519
|
|
|
2456
|
-
fig, axes = get_default_ax(
|
|
2520
|
+
fig, axes = get_default_ax(
|
|
2521
|
+
width, height, rows=n_cat, cols=2, dpi=dpi, title=title
|
|
2522
|
+
)
|
|
2457
2523
|
axes_2d = np.atleast_2d(axes)
|
|
2458
2524
|
|
|
2459
2525
|
for idx, cat in enumerate(categories):
|
|
@@ -2478,9 +2544,7 @@ def distribution_plot(
|
|
|
2478
2544
|
)
|
|
2479
2545
|
else:
|
|
2480
2546
|
boxplot(
|
|
2481
|
-
df=subset[c], # type: ignore
|
|
2482
|
-
linewidth=linewidth,
|
|
2483
|
-
ax=right_ax
|
|
2547
|
+
df=subset[c], linewidth=linewidth, ax=right_ax # type: ignore
|
|
2484
2548
|
)
|
|
2485
2549
|
|
|
2486
2550
|
fig.suptitle(f"{title} by {hue}", fontsize=14, y=1.02)
|
|
@@ -2488,24 +2552,24 @@ def distribution_plot(
|
|
|
2488
2552
|
plt.tight_layout()
|
|
2489
2553
|
|
|
2490
2554
|
if save_path:
|
|
2491
|
-
plt.savefig(save_path, bbox_inches=
|
|
2555
|
+
plt.savefig(save_path, bbox_inches="tight", dpi=dpi)
|
|
2492
2556
|
plt.close()
|
|
2493
2557
|
else:
|
|
2494
2558
|
plt.show()
|
|
2495
2559
|
|
|
2496
2560
|
|
|
2497
2561
|
def silhouette_plot(
|
|
2498
|
-
|
|
2499
|
-
|
|
2500
|
-
|
|
2501
|
-
|
|
2502
|
-
|
|
2503
|
-
|
|
2504
|
-
|
|
2505
|
-
|
|
2506
|
-
|
|
2507
|
-
|
|
2508
|
-
|
|
2562
|
+
estimator: KMeans,
|
|
2563
|
+
data: DataFrame,
|
|
2564
|
+
title: str | None = None,
|
|
2565
|
+
width: int = config.width,
|
|
2566
|
+
height: int = config.height,
|
|
2567
|
+
linewidth: float = config.line_width,
|
|
2568
|
+
dpi: int = config.dpi,
|
|
2569
|
+
save_path: str | None = None,
|
|
2570
|
+
callback: Callable | None = None,
|
|
2571
|
+
ax: Axes | None = None,
|
|
2572
|
+
) -> None:
|
|
2509
2573
|
"""
|
|
2510
2574
|
군집분석 결과의 실루엣 플롯을 시각화함.
|
|
2511
2575
|
|
|
@@ -2532,7 +2596,7 @@ def silhouette_plot(
|
|
|
2532
2596
|
|
|
2533
2597
|
outparams = False
|
|
2534
2598
|
if ax is None:
|
|
2535
|
-
fig, ax = get_default_ax(width, height, 1, 1, dpi)
|
|
2599
|
+
fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
|
|
2536
2600
|
outparams = True
|
|
2537
2601
|
|
|
2538
2602
|
sil_avg = silhouette_score(X=data, labels=estimator.labels_)
|
|
@@ -2541,14 +2605,14 @@ def silhouette_plot(
|
|
|
2541
2605
|
y_lower = 10
|
|
2542
2606
|
|
|
2543
2607
|
# 클러스터링 갯수별로 fill_betweenx( )형태의 막대 그래프 표현.
|
|
2544
|
-
for i in range(estimator.n_clusters):
|
|
2545
|
-
ith_cluster_sil_values = sil_values[estimator.labels_ == i]
|
|
2546
|
-
ith_cluster_sil_values.sort()
|
|
2608
|
+
for i in range(estimator.n_clusters): # type: ignore
|
|
2609
|
+
ith_cluster_sil_values = sil_values[estimator.labels_ == i] # type: ignore
|
|
2610
|
+
ith_cluster_sil_values.sort() # type: ignore
|
|
2547
2611
|
|
|
2548
|
-
size_cluster_i = ith_cluster_sil_values.shape[0]
|
|
2612
|
+
size_cluster_i = ith_cluster_sil_values.shape[0] # type: ignore
|
|
2549
2613
|
y_upper = y_lower + size_cluster_i
|
|
2550
2614
|
|
|
2551
|
-
ax.fill_betweenx(
|
|
2615
|
+
ax.fill_betweenx( # type: ignore
|
|
2552
2616
|
np.arange(y_lower, y_upper),
|
|
2553
2617
|
0,
|
|
2554
2618
|
ith_cluster_sil_values,
|
|
@@ -2564,23 +2628,24 @@ def silhouette_plot(
|
|
|
2564
2628
|
ax.set_xlim([-0.1, 1]) # type: ignore
|
|
2565
2629
|
ax.set_ylim([0, len(data) + (estimator.n_clusters + 1) * 10]) # type: ignore
|
|
2566
2630
|
ax.set_yticks([]) # type: ignore
|
|
2567
|
-
ax.set_xticks([0, 0.2, 0.4, 0.6, 0.8, 1])
|
|
2631
|
+
ax.set_xticks([0, 0.2, 0.4, 0.6, 0.8, 1]) # type: ignore
|
|
2568
2632
|
|
|
2569
2633
|
if title is None:
|
|
2570
|
-
title = "Number of Cluster : " + str(estimator.n_clusters) + ", Silhouette Score :" + str(round(sil_avg, 3))
|
|
2634
|
+
title = "Number of Cluster : " + str(estimator.n_clusters) + ", Silhouette Score :" + str(round(sil_avg, 3)) # type: ignore
|
|
2571
2635
|
|
|
2572
|
-
finalize_plot(ax, callback, outparams, save_path, True, title)
|
|
2636
|
+
finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
|
|
2573
2637
|
|
|
2574
2638
|
|
|
2575
2639
|
def cluster_plot(
|
|
2576
|
-
estimator: KMeans,
|
|
2577
|
-
data: DataFrame,
|
|
2640
|
+
estimator: KMeans | None = None,
|
|
2641
|
+
data: DataFrame | None = None,
|
|
2578
2642
|
xname: str | None = None,
|
|
2579
2643
|
yname: str | None = None,
|
|
2580
2644
|
hue: str | None = None,
|
|
2645
|
+
vector: str | None = None,
|
|
2581
2646
|
title: str | None = None,
|
|
2582
2647
|
palette: str | None = None,
|
|
2583
|
-
outline: bool =
|
|
2648
|
+
outline: bool = True,
|
|
2584
2649
|
width: int = config.width,
|
|
2585
2650
|
height: int = config.height,
|
|
2586
2651
|
linewidth: float = config.line_width,
|
|
@@ -2588,7 +2653,6 @@ def cluster_plot(
|
|
|
2588
2653
|
save_path: str | None = None,
|
|
2589
2654
|
ax: Axes | None = None,
|
|
2590
2655
|
) -> None:
|
|
2591
|
-
|
|
2592
2656
|
"""
|
|
2593
2657
|
2차원 공간에서 군집분석 결과를 산점도로 시각화함.
|
|
2594
2658
|
|
|
@@ -2598,6 +2662,7 @@ def cluster_plot(
|
|
|
2598
2662
|
xname (str, optional): x축에 사용할 컬럼명. None이면 첫 번째 컬럼 사용.
|
|
2599
2663
|
yname (str, optional): y축에 사용할 컬럼명. None이면 두 번째 컬럼 사용.
|
|
2600
2664
|
hue (str, optional): 군집 구분에 사용할 컬럼명. None이면 'cluster' 자동 생성.
|
|
2665
|
+
vector (str, optional): 벡터 종류를 의미하는 컬럼명. None이면 사용 안함.
|
|
2601
2666
|
title (str, optional): 플롯 제목. None이면 기본값 사용.
|
|
2602
2667
|
palette (str, optional): 색상 팔레트.
|
|
2603
2668
|
outline (bool, optional): 외곽선 표시 여부.
|
|
@@ -2622,49 +2687,53 @@ def cluster_plot(
|
|
|
2622
2687
|
"""
|
|
2623
2688
|
outparams = False
|
|
2624
2689
|
if ax is None:
|
|
2625
|
-
fig, ax = get_default_ax(width, height, 1, 1, dpi)
|
|
2690
|
+
fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
|
|
2626
2691
|
outparams = True
|
|
2627
2692
|
|
|
2628
|
-
df = data.copy()
|
|
2693
|
+
df = data.copy() if data is not None else None # type: ignore
|
|
2629
2694
|
|
|
2630
2695
|
if not hue:
|
|
2631
|
-
df[
|
|
2632
|
-
hue =
|
|
2696
|
+
df["cluster"] = estimator.labels_ # type: ignore
|
|
2697
|
+
hue = "cluster"
|
|
2633
2698
|
|
|
2634
2699
|
if xname is None:
|
|
2635
|
-
xname = df.columns[0]
|
|
2700
|
+
xname = df.columns[0] # type: ignore
|
|
2636
2701
|
|
|
2637
2702
|
if yname is None:
|
|
2638
|
-
yname = df.columns[1]
|
|
2703
|
+
yname = df.columns[1] # type: ignore
|
|
2639
2704
|
|
|
2640
|
-
xindex = df.columns.get_loc(xname)
|
|
2641
|
-
yindex = df.columns.get_loc(yname)
|
|
2705
|
+
xindex = df.columns.get_loc(xname) # type: ignore
|
|
2706
|
+
yindex = df.columns.get_loc(yname) # type: ignore
|
|
2642
2707
|
|
|
2643
2708
|
def callback(ax: Axes) -> None:
|
|
2644
|
-
# 클러스터 중심점 표시
|
|
2645
|
-
centers = estimator.cluster_centers_ # type: ignore
|
|
2646
|
-
ax.scatter( # type: ignore
|
|
2647
|
-
centers[:, xindex],
|
|
2648
|
-
centers[:, yindex],
|
|
2649
|
-
marker="o",
|
|
2650
|
-
color="white",
|
|
2651
|
-
alpha=1,
|
|
2652
|
-
s=200,
|
|
2653
|
-
edgecolor="r",
|
|
2654
|
-
linewidth=linewidth
|
|
2655
|
-
)
|
|
2656
|
-
|
|
2657
|
-
for i, c in enumerate(centers):
|
|
2658
|
-
ax.scatter(c[xindex], c[yindex], marker="$%d$" % i, alpha=1, s=50, edgecolor="k")
|
|
2659
|
-
|
|
2660
2709
|
ax.set_xlabel("Feature space for the " + xname)
|
|
2661
2710
|
ax.set_ylabel("Feature space for the " + yname)
|
|
2662
2711
|
|
|
2712
|
+
if hasattr(estimator, "cluster_centers_"):
|
|
2713
|
+
# 클러스터 중심점 표시
|
|
2714
|
+
centers = estimator.cluster_centers_ # type: ignore
|
|
2715
|
+
ax.scatter( # type: ignore
|
|
2716
|
+
centers[:, xindex],
|
|
2717
|
+
centers[:, yindex],
|
|
2718
|
+
marker="o",
|
|
2719
|
+
color="white",
|
|
2720
|
+
alpha=1,
|
|
2721
|
+
s=200,
|
|
2722
|
+
edgecolor="r",
|
|
2723
|
+
linewidth=linewidth,
|
|
2724
|
+
)
|
|
2725
|
+
|
|
2726
|
+
for i, c in enumerate(centers):
|
|
2727
|
+
ax.scatter(
|
|
2728
|
+
c[xindex], c[yindex], marker="$%d$" % i, alpha=1, s=50, edgecolor="k"
|
|
2729
|
+
)
|
|
2730
|
+
|
|
2663
2731
|
scatterplot(
|
|
2664
|
-
df=df,
|
|
2732
|
+
df=df, # type: ignore
|
|
2665
2733
|
xname=xname,
|
|
2666
2734
|
yname=yname,
|
|
2667
2735
|
hue=hue,
|
|
2736
|
+
vector=vector,
|
|
2668
2737
|
title="The visualization of the clustered data." if title is None else title,
|
|
2669
2738
|
outline=outline,
|
|
2670
2739
|
palette=palette,
|
|
@@ -2674,16 +2743,40 @@ def cluster_plot(
|
|
|
2674
2743
|
dpi=dpi,
|
|
2675
2744
|
save_path=save_path,
|
|
2676
2745
|
callback=callback,
|
|
2677
|
-
ax=ax
|
|
2746
|
+
ax=ax,
|
|
2678
2747
|
)
|
|
2679
2748
|
|
|
2680
|
-
|
|
2749
|
+
|
|
2750
|
+
def visualize_silhouette(
|
|
2751
|
+
estimator: KMeans,
|
|
2752
|
+
data: DataFrame,
|
|
2753
|
+
xname: str | None = None,
|
|
2754
|
+
yname: str | None = None,
|
|
2755
|
+
title: str | None = None,
|
|
2756
|
+
palette: str | None = None,
|
|
2757
|
+
outline: bool = False,
|
|
2758
|
+
width: int = config.width,
|
|
2759
|
+
height: int = config.height,
|
|
2760
|
+
linewidth: float = config.line_width,
|
|
2761
|
+
dpi: int = config.dpi,
|
|
2762
|
+
save_path: str | None = None,
|
|
2763
|
+
) -> None:
|
|
2681
2764
|
"""
|
|
2682
2765
|
군집분석 결과의 실루엣 플롯과 군집 산점도를 한 화면에 함께 시각화함.
|
|
2683
2766
|
|
|
2684
2767
|
Args:
|
|
2685
2768
|
estimator (KMeans): 학습된 KMeans 군집 모델 객체.
|
|
2686
2769
|
data (DataFrame): 군집분석에 사용된 입력 데이터 (n_samples, n_features).
|
|
2770
|
+
xname (str, optional): 산점도 x축에 사용할 컬럼명. None이면 첫 번째 컬럼 사용.
|
|
2771
|
+
yname (str, optional): 산점도 y축에 사용할 컬럼명. None이면 두 번째 컬럼 사용.
|
|
2772
|
+
title (str, optional): 플롯 제목. None이면 기본값 사용.
|
|
2773
|
+
palette (str, optional): 색상 팔레트.
|
|
2774
|
+
outline (bool, optional): 산점도 외곽선 표시 여부.
|
|
2775
|
+
width (int, optional): 플롯 가로 크기 (inch 단위).
|
|
2776
|
+
height (int, optional): 플롯 세로 크기 (inch 단위).
|
|
2777
|
+
linewidth (float, optional): 기준선 등 선 두께.
|
|
2778
|
+
dpi (int, optional): 플롯 해상도(DPI).
|
|
2779
|
+
save_path (str, optional): 저장 경로 지정 시 파일로 저장.
|
|
2687
2780
|
|
|
2688
2781
|
Returns:
|
|
2689
2782
|
None
|
|
@@ -2692,18 +2785,29 @@ def visualize_silhouette(estimator: KMeans, data: DataFrame) -> None:
|
|
|
2692
2785
|
- 실루엣 플롯(왼쪽)과 2차원 군집 산점도(오른쪽)를 동시에 확인 가능
|
|
2693
2786
|
- 군집 품질과 분포를 한눈에 비교·분석할 때 유용
|
|
2694
2787
|
"""
|
|
2695
|
-
fig, ax = get_default_ax(rows=1, cols=2)
|
|
2788
|
+
fig, ax = get_default_ax(rows=1, cols=2, width=width, height=height, dpi=dpi, title=title)
|
|
2696
2789
|
|
|
2697
2790
|
silhouette_plot(
|
|
2698
2791
|
estimator=estimator,
|
|
2699
2792
|
data=data,
|
|
2700
2793
|
ax=ax[0],
|
|
2794
|
+
linewidth=linewidth,
|
|
2795
|
+
width=width,
|
|
2796
|
+
height=height,
|
|
2797
|
+
dpi=dpi
|
|
2701
2798
|
)
|
|
2702
2799
|
|
|
2703
2800
|
cluster_plot(
|
|
2704
2801
|
estimator=estimator,
|
|
2705
2802
|
data=data,
|
|
2803
|
+
xname=xname,
|
|
2804
|
+
yname=yname,
|
|
2706
2805
|
ax=ax[1],
|
|
2806
|
+
outline=outline,
|
|
2807
|
+
palette=palette,
|
|
2808
|
+
width=width,
|
|
2809
|
+
height=height,
|
|
2810
|
+
dpi=dpi
|
|
2707
2811
|
)
|
|
2708
2812
|
|
|
2709
|
-
finalize_plot(ax)
|
|
2813
|
+
finalize_plot(ax)
|