hossam 0.4.12__py3-none-any.whl → 0.4.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
hossam/hs_plot.py CHANGED
@@ -7,7 +7,7 @@ from typing import Callable
7
7
  import numpy as np
8
8
  import seaborn as sb
9
9
  import matplotlib.pyplot as plt
10
- from matplotlib.pyplot import Axes # type: ignore
10
+ from matplotlib.pyplot import Axes # type: ignore
11
11
  from pandas import Series, DataFrame
12
12
  from math import sqrt
13
13
  from pandas import DataFrame
@@ -31,7 +31,7 @@ from sklearn.metrics import (
31
31
  auc,
32
32
  confusion_matrix,
33
33
  silhouette_score,
34
- silhouette_samples
34
+ silhouette_samples,
35
35
  )
36
36
 
37
37
  # ===================================================================
@@ -45,13 +45,24 @@ config = SimpleNamespace(
45
45
  line_width=1,
46
46
  grid_alpha=0.3,
47
47
  grid_width=0.5,
48
- fill_alpha=0.3
48
+ fill_alpha=0.3,
49
49
  )
50
50
 
51
+
51
52
  # ===================================================================
52
53
  # 기본 크기가 설정된 Figure와 Axes를 생성한다
53
54
  # ===================================================================
54
- def get_default_ax(width: int = config.width, height: int = config.height, rows: int = 1, cols: int = 1, dpi: int = config.dpi, flatten: bool = False, ws: int | None = None, hs: int | None = None, title: str | None = None):
55
+ def get_default_ax(
56
+ width: int = config.width,
57
+ height: int = config.height,
58
+ rows: int = 1,
59
+ cols: int = 1,
60
+ dpi: int = config.dpi,
61
+ flatten: bool = False,
62
+ ws: int | None = None,
63
+ hs: int | None = None,
64
+ title: str | None = None,
65
+ ):
55
66
  """기본 크기의 Figure와 Axes를 생성한다.
56
67
 
57
68
  Args:
@@ -78,7 +89,7 @@ def get_default_ax(width: int = config.width, height: int = config.height, rows:
78
89
  fig.subplots_adjust(wspace=ws, hspace=hs)
79
90
 
80
91
  if title and is_array:
81
- fig.suptitle(title, fontsize=config.font_size * 1.5, fontweight='bold')
92
+ fig.suptitle(title, fontsize=config.font_size * 1.5, fontweight="bold")
82
93
 
83
94
  if flatten == True:
84
95
  # 단일 Axes인 경우 리스트로 변환
@@ -97,7 +108,7 @@ def get_default_ax(width: int = config.width, height: int = config.height, rows:
97
108
  for spine in a.spines.values():
98
109
  spine.set_linewidth(config.frame_width)
99
110
  else:
100
- for spine in ax.spines.values(): # type: ignore
111
+ for spine in ax.spines.values(): # type: ignore
101
112
  spine.set_linewidth(config.frame_width)
102
113
 
103
114
  return fig, ax
@@ -106,7 +117,17 @@ def get_default_ax(width: int = config.width, height: int = config.height, rows:
106
117
  # ===================================================================
107
118
  # 기본 크기가 설정된 Figure와 Axes를 생성한다
108
119
  # ===================================================================
109
- def create_figure(width: int = config.width, height: int = config.height, rows: int = 1, cols: int = 1, dpi: int = config.dpi, flatten: bool = False, ws: int | None = None, hs: int | None = None, title: str | None = None):
120
+ def create_figure(
121
+ width: int = config.width,
122
+ height: int = config.height,
123
+ rows: int = 1,
124
+ cols: int = 1,
125
+ dpi: int = config.dpi,
126
+ flatten: bool = False,
127
+ ws: int | None = None,
128
+ hs: int | None = None,
129
+ title: str | None = None,
130
+ ):
110
131
  """기본 크기의 Figure와 Axes를 생성한다. get_default_ax의 래퍼 함수.
111
132
 
112
133
  Args:
@@ -130,7 +151,14 @@ def create_figure(width: int = config.width, height: int = config.height, rows:
130
151
  # ===================================================================
131
152
  # 그래프의 그리드, 레이아웃을 정리하고 필요 시 저장 또는 표시한다
132
153
  # ===================================================================
133
- def finalize_plot(ax: Axes | np.ndarray | list, callback: Callable | None = None, outparams: bool = False, save_path: str | None = None, grid: bool = True, title: str | None = None) -> None:
154
+ def finalize_plot(
155
+ ax: Axes | np.ndarray | list,
156
+ callback: Callable | None = None,
157
+ outparams: bool = False,
158
+ save_path: str | None = None,
159
+ grid: bool = True,
160
+ title: str | None = None,
161
+ ) -> None:
134
162
  """공통 후처리를 수행한다: 콜백 실행, 레이아웃 정리, 필요 시 표시/종료.
135
163
 
136
164
  Args:
@@ -149,7 +177,7 @@ def finalize_plot(ax: Axes | np.ndarray | list, callback: Callable | None = None
149
177
  # callback 실행
150
178
  if callback:
151
179
  if is_array:
152
- for a in (ax.flat if isinstance(ax, np.ndarray) else ax):
180
+ for a in ax.flat if isinstance(ax, np.ndarray) else ax:
153
181
  callback(a)
154
182
  else:
155
183
  callback(ax)
@@ -157,7 +185,7 @@ def finalize_plot(ax: Axes | np.ndarray | list, callback: Callable | None = None
157
185
  # grid 설정
158
186
  if grid:
159
187
  if is_array:
160
- for a in (ax.flat if isinstance(ax, np.ndarray) else ax):
188
+ for a in ax.flat if isinstance(ax, np.ndarray) else ax:
161
189
  a.grid(True, alpha=config.grid_alpha, linewidth=config.grid_width)
162
190
  else:
163
191
  ax.grid(True, alpha=config.grid_alpha, linewidth=config.grid_width)
@@ -165,10 +193,10 @@ def finalize_plot(ax: Axes | np.ndarray | list, callback: Callable | None = None
165
193
  plt.tight_layout()
166
194
 
167
195
  if title and not is_array:
168
- ax.set_title(title, fontsize=config.font_size * 1.3, pad=7, fontweight='bold')
196
+ ax.set_title(title, fontsize=config.font_size * 1.3, pad=7, fontweight="bold")
169
197
 
170
198
  if save_path is not None:
171
- plt.savefig(save_path, dpi=config.dpi * 2, bbox_inches='tight')
199
+ plt.savefig(save_path, dpi=config.dpi * 2, bbox_inches="tight")
172
200
 
173
201
  if outparams:
174
202
  plt.show()
@@ -178,7 +206,14 @@ def finalize_plot(ax: Axes | np.ndarray | list, callback: Callable | None = None
178
206
  # ===================================================================
179
207
  # 그래프의 그리드, 레이아웃을 정리하고 필요 시 저장 또는 표시한다
180
208
  # ===================================================================
181
- def show_figure(ax: Axes | np.ndarray, callback: Callable | None = None, outparams: bool = False, save_path: str | None = None, grid: bool = True, title: str | None = None) -> None:
209
+ def show_figure(
210
+ ax: Axes | np.ndarray,
211
+ callback: Callable | None = None,
212
+ outparams: bool = False,
213
+ save_path: str | None = None,
214
+ grid: bool = True,
215
+ title: str | None = None,
216
+ ) -> None:
182
217
  """공통 후처리를 수행한다: 콜백 실행, 레이아웃 정리, 필요 시 표시/종료.
183
218
  finalize_plot의 래퍼 함수.
184
219
 
@@ -241,7 +276,7 @@ def lineplot(
241
276
  outparams = False
242
277
 
243
278
  if ax is None:
244
- fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
279
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
245
280
  outparams = True
246
281
 
247
282
  # hue가 있을 때만 palette 사용, 없으면 color 사용
@@ -262,7 +297,7 @@ def lineplot(
262
297
  lineplot_kwargs.update(params)
263
298
 
264
299
  sb.lineplot(**lineplot_kwargs, linewidth=linewidth)
265
- finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
300
+ finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
266
301
 
267
302
 
268
303
  # ===================================================================
@@ -316,7 +351,7 @@ def boxplot(
316
351
  outparams = False
317
352
 
318
353
  if ax is None:
319
- fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
354
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
320
355
  outparams = True
321
356
 
322
357
  if xname is not None and yname is not None:
@@ -344,13 +379,17 @@ def boxplot(
344
379
  if stat_pairs is None:
345
380
  stat_pairs = [df[xname].dropna().unique().tolist()]
346
381
 
347
- annotator = Annotator(ax, data=df, x=xname, y=yname, pairs=stat_pairs, orient=orient)
348
- annotator.configure(test=stat_test, text_format=stat_text_format, loc=stat_loc)
382
+ annotator = Annotator(
383
+ ax, data=df, x=xname, y=yname, pairs=stat_pairs, orient=orient
384
+ )
385
+ annotator.configure(
386
+ test=stat_test, text_format=stat_text_format, loc=stat_loc
387
+ )
349
388
  annotator.apply_and_annotate()
350
389
  else:
351
- sb.boxplot(data=df, orient=orient, ax=ax, linewidth=linewidth, **params) # type: ignore
390
+ sb.boxplot(data=df, orient=orient, ax=ax, linewidth=linewidth, **params) # type: ignore
352
391
 
353
- finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
392
+ finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
354
393
 
355
394
 
356
395
  # ===================================================================
@@ -372,7 +411,7 @@ def pvalue1_anotation(
372
411
  save_path: str | None = None,
373
412
  callback: Callable | None = None,
374
413
  ax: Axes | None = None,
375
- **params
414
+ **params,
376
415
  ) -> None:
377
416
  """
378
417
  boxplot의 wrapper 함수로, 상자그림에 p-value 주석을 추가한다.
@@ -395,7 +434,7 @@ def pvalue1_anotation(
395
434
  save_path=save_path,
396
435
  callback=callback,
397
436
  ax=ax,
398
- **params
437
+ **params,
399
438
  )
400
439
 
401
440
 
@@ -452,7 +491,9 @@ def kdeplot(
452
491
  # 사분위수 분할 전용 처리 (1D KDE만 지원)
453
492
  if quartile_split:
454
493
  if yname is not None:
455
- raise ValueError("quartile_split은 1차원 KDE(xname)에서만 사용할 수 있습니다.")
494
+ raise ValueError(
495
+ "quartile_split은 1차원 KDE(xname)에서만 사용할 수 있습니다."
496
+ )
456
497
 
457
498
  series = df[xname].dropna()
458
499
  if series.empty:
@@ -499,7 +540,7 @@ def kdeplot(
499
540
  return
500
541
 
501
542
  if ax is None:
502
- fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
543
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
503
544
  outparams = True
504
545
 
505
546
  # 기본 kwargs 설정
@@ -529,7 +570,7 @@ def kdeplot(
529
570
 
530
571
  sb.kdeplot(**kdeplot_kwargs)
531
572
 
532
- finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
573
+ finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
533
574
 
534
575
 
535
576
  # ===================================================================
@@ -576,7 +617,7 @@ def histplot(
576
617
  outparams = False
577
618
 
578
619
  if ax is None:
579
- fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
620
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
580
621
  outparams = True
581
622
 
582
623
  if bins:
@@ -604,7 +645,7 @@ def histplot(
604
645
  "hue": hue,
605
646
  "kde": kde,
606
647
  "linewidth": linewidth,
607
- "ax": ax
648
+ "ax": ax,
608
649
  }
609
650
 
610
651
  if hue is not None and palette is not None:
@@ -615,7 +656,7 @@ def histplot(
615
656
  histplot_kwargs.update(params)
616
657
  sb.histplot(**histplot_kwargs)
617
658
 
618
- finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
659
+ finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
619
660
 
620
661
 
621
662
  # ===================================================================
@@ -658,7 +699,7 @@ def stackplot(
658
699
  outparams = False
659
700
 
660
701
  if ax is None:
661
- fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
702
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
662
703
  outparams = True
663
704
 
664
705
  df2 = df[[xname, hue]].copy()
@@ -685,11 +726,11 @@ def stackplot(
685
726
  sb.histplot(**stackplot_kwargs)
686
727
 
687
728
  # 그래프의 x축 항목 수 만큼 반복
688
- for p in ax.patches: # type: ignore
729
+ for p in ax.patches: # type: ignore
689
730
  # 각 막대의 위치, 넓이, 높이
690
- left, bottom, width, height = p.get_bbox().bounds # type: ignore
731
+ left, bottom, width, height = p.get_bbox().bounds # type: ignore
691
732
  # 막대의 중앙에 글자 표시하기
692
- ax.annotate( # type: ignore
733
+ ax.annotate( # type: ignore
693
734
  "%0.1f%%" % (height * 100),
694
735
  xy=(left + width / 2, bottom + height / 2),
695
736
  ha="center",
@@ -698,10 +739,10 @@ def stackplot(
698
739
 
699
740
  if str(df[xname].dtype) in ["int", "int32", "int64", "float", "float32", "float64"]:
700
741
  xticks = list(df[xname].unique())
701
- ax.set_xticks(xticks) # type: ignore
742
+ ax.set_xticks(xticks) # type: ignore
702
743
  ax.set_xticklabels(xticks) # type: ignore
703
744
 
704
- finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
745
+ finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
705
746
 
706
747
 
707
748
  # ===================================================================
@@ -750,10 +791,9 @@ def scatterplot(
750
791
  outparams = False
751
792
 
752
793
  if ax is None:
753
- fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
794
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
754
795
  outparams = True
755
796
 
756
-
757
797
  if outline and hue is not None:
758
798
  # 군집별 값의 종류별로 반복 수행
759
799
  for c in df[hue].unique():
@@ -770,8 +810,11 @@ def scatterplot(
770
810
  # 마지막 좌표 이후에 첫 번째 좌표를 연결
771
811
  points = np.append(hull.vertices, hull.vertices[0])
772
812
 
773
- ax.plot( # type: ignore
774
- df_c.iloc[points, 0], df_c.iloc[points, 1], linewidth=linewidth, linestyle=":"
813
+ ax.plot( # type: ignore
814
+ df_c.iloc[points, 0],
815
+ df_c.iloc[points, 1],
816
+ linewidth=linewidth,
817
+ linestyle=":",
775
818
  )
776
819
  ax.fill(df_c.iloc[points, 0], df_c.iloc[points, 1], alpha=0.1) # type: ignore
777
820
  except:
@@ -798,27 +841,26 @@ def scatterplot(
798
841
  sb.scatterplot(data=df, **scatterplot_kwargs)
799
842
  else:
800
843
  # 핵심벡터
801
- scatterplot_kwargs['edgecolor'] = '#ffffff'
844
+ scatterplot_kwargs["edgecolor"] = "#ffffff"
802
845
  sb.scatterplot(data=df[df[vector] == "core"], **scatterplot_kwargs)
803
846
 
804
847
  # 외곽백터
805
- scatterplot_kwargs['edgecolor'] = '#000000'
806
- scatterplot_kwargs['s'] = 25
807
- scatterplot_kwargs['marker'] = '^'
808
- scatterplot_kwargs['linewidth'] = 0.8
848
+ scatterplot_kwargs["edgecolor"] = "#000000"
849
+ scatterplot_kwargs["s"] = 25
850
+ scatterplot_kwargs["marker"] = "^"
851
+ scatterplot_kwargs["linewidth"] = 0.8
809
852
  sb.scatterplot(data=df[df[vector] == "border"], **scatterplot_kwargs)
810
853
 
811
854
  # 노이즈벡터
812
- scatterplot_kwargs['edgecolor'] = None
813
- scatterplot_kwargs['s'] = 25
814
- scatterplot_kwargs['marker'] = 'x'
815
- scatterplot_kwargs['linewidth'] = 2
816
- scatterplot_kwargs['color'] = '#ff0000'
817
- scatterplot_kwargs['hue'] = None
855
+ scatterplot_kwargs["edgecolor"] = None
856
+ scatterplot_kwargs["s"] = 25
857
+ scatterplot_kwargs["marker"] = "x"
858
+ scatterplot_kwargs["linewidth"] = 2
859
+ scatterplot_kwargs["color"] = "#ff0000"
860
+ scatterplot_kwargs["hue"] = None
818
861
  sb.scatterplot(data=df[df[vector] == "noise"], **scatterplot_kwargs)
819
862
 
820
-
821
- finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
863
+ finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
822
864
 
823
865
 
824
866
  # ===================================================================
@@ -861,7 +903,7 @@ def regplot(
861
903
  outparams = False
862
904
 
863
905
  if ax is None:
864
- fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
906
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
865
907
  outparams = True
866
908
 
867
909
  # regplot은 hue를 지원하지 않으므로 palette를 color로 변환
@@ -877,13 +919,9 @@ def regplot(
877
919
  "s": 20,
878
920
  "linewidths": 0.5,
879
921
  "edgecolor": "w",
880
- "color": scatter_color
881
- },
882
- "line_kws": {
883
- "color": "red",
884
- "linestyle": "--",
885
- "linewidth": linewidth
922
+ "color": scatter_color,
886
923
  },
924
+ "line_kws": {"color": "red", "linestyle": "--", "linewidth": linewidth},
887
925
  "ax": ax,
888
926
  }
889
927
 
@@ -891,7 +929,7 @@ def regplot(
891
929
 
892
930
  sb.regplot(**regplot_kwargs)
893
931
 
894
- finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
932
+ finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
895
933
 
896
934
 
897
935
  # ===================================================================
@@ -951,19 +989,19 @@ def lmplot(
951
989
  # 회귀선에 linewidth 적용
952
990
  for ax in g.axes.flat:
953
991
  for line in ax.get_lines():
954
- if line.get_marker() == 'o': # 산점도는 건너뛰기
992
+ if line.get_marker() == "o": # 산점도는 건너뛰기
955
993
  continue
956
994
  line.set_linewidth(linewidth)
957
995
 
958
996
  g.fig.grid(True, alpha=config.grid_alpha, linewidth=config.grid_width) # type: ignore
959
997
 
960
998
  if title:
961
- g.fig.suptitle(title, fontsize=config.font_size * 1.5, fontweight='bold')
999
+ g.fig.suptitle(title, fontsize=config.font_size * 1.5, fontweight="bold")
962
1000
 
963
1001
  plt.tight_layout()
964
1002
 
965
1003
  if save_path is not None:
966
- plt.savefig(save_path, dpi=dpi*2, bbox_inches='tight')
1004
+ plt.savefig(save_path, dpi=dpi * 2, bbox_inches="tight")
967
1005
 
968
1006
  plt.show()
969
1007
  plt.close()
@@ -1012,7 +1050,7 @@ def pairplot(
1012
1050
  if xnames is None:
1013
1051
  # 모든 연속형(숫자형) 컬럼 선택 (명목형/카테고리 제외)
1014
1052
  numeric_cols = df.select_dtypes(include=[np.number]).columns
1015
- target_cols = [col for col in numeric_cols if df[col].dtype.name != 'category']
1053
+ target_cols = [col for col in numeric_cols if df[col].dtype.name != "category"]
1016
1054
  elif isinstance(xnames, str):
1017
1055
  # 문자열: 해당 컬럼만
1018
1056
  target_cols = [xnames]
@@ -1022,7 +1060,7 @@ def pairplot(
1022
1060
  else:
1023
1061
  # 기본값으로 연속형 컬럼
1024
1062
  numeric_cols = df.select_dtypes(include=[np.number]).columns
1025
- target_cols = [col for col in numeric_cols if df[col].dtype.name != 'category']
1063
+ target_cols = [col for col in numeric_cols if df[col].dtype.name != "category"]
1026
1064
 
1027
1065
  # hue 컬럼이 있으면 target_cols에 포함시키기 (pairplot 자체에서 필요)
1028
1066
  if hue is not None and hue not in target_cols:
@@ -1050,9 +1088,11 @@ def pairplot(
1050
1088
  g.fig.set_dpi(dpi)
1051
1089
 
1052
1090
  if title:
1053
- g.fig.suptitle(title, fontsize=config.font_size * 1.5, fontweight='bold')
1091
+ g.fig.suptitle(title, fontsize=config.font_size * 1.5, fontweight="bold")
1054
1092
 
1055
- g.map_lower(func=sb.kdeplot, fill=True, alpha=config.fill_alpha, linewidth=linewidth)
1093
+ g.map_lower(
1094
+ func=sb.kdeplot, fill=True, alpha=config.fill_alpha, linewidth=linewidth
1095
+ )
1056
1096
  g.map_upper(func=sb.scatterplot, linewidth=linewidth)
1057
1097
 
1058
1098
  # KDE 대각선에도 linewidth 적용
@@ -1063,7 +1103,7 @@ def pairplot(
1063
1103
  plt.tight_layout()
1064
1104
 
1065
1105
  if save_path is not None:
1066
- plt.savefig(save_path, dpi=dpi*2, bbox_inches='tight')
1106
+ plt.savefig(save_path, dpi=dpi * 2, bbox_inches="tight")
1067
1107
 
1068
1108
  plt.show()
1069
1109
  plt.close()
@@ -1117,7 +1157,7 @@ def countplot(
1117
1157
  sort = sorted(list(df[xname].value_counts().index))
1118
1158
 
1119
1159
  if ax is None:
1120
- fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
1160
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
1121
1161
  outparams = True
1122
1162
 
1123
1163
  # hue가 있을 때만 palette 사용, 없으면 color 사용
@@ -1140,7 +1180,7 @@ def countplot(
1140
1180
 
1141
1181
  sb.countplot(**countplot_kwargs)
1142
1182
 
1143
- finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
1183
+ finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
1144
1184
 
1145
1185
 
1146
1186
  # ===================================================================
@@ -1185,7 +1225,7 @@ def barplot(
1185
1225
  outparams = False
1186
1226
 
1187
1227
  if ax is None:
1188
- fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
1228
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
1189
1229
  outparams = True
1190
1230
 
1191
1231
  # hue가 있을 때만 palette 사용, 없으면 color 사용
@@ -1206,7 +1246,7 @@ def barplot(
1206
1246
  barplot_kwargs.update(params)
1207
1247
 
1208
1248
  sb.barplot(**barplot_kwargs)
1209
- finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
1249
+ finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
1210
1250
 
1211
1251
 
1212
1252
  # ===================================================================
@@ -1251,7 +1291,7 @@ def boxenplot(
1251
1291
  outparams = False
1252
1292
 
1253
1293
  if ax is None:
1254
- fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
1294
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
1255
1295
  outparams = True
1256
1296
 
1257
1297
  # palette은 hue가 있을 때만 사용
@@ -1270,7 +1310,7 @@ def boxenplot(
1270
1310
  boxenplot_kwargs.update(params)
1271
1311
 
1272
1312
  sb.boxenplot(**boxenplot_kwargs)
1273
- finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
1313
+ finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
1274
1314
 
1275
1315
 
1276
1316
  # ===================================================================
@@ -1315,7 +1355,7 @@ def violinplot(
1315
1355
  outparams = False
1316
1356
 
1317
1357
  if ax is None:
1318
- fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
1358
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
1319
1359
  outparams = True
1320
1360
 
1321
1361
  # palette은 hue가 있을 때만 사용
@@ -1333,7 +1373,7 @@ def violinplot(
1333
1373
 
1334
1374
  violinplot_kwargs.update(params)
1335
1375
  sb.violinplot(**violinplot_kwargs)
1336
- finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
1376
+ finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
1337
1377
 
1338
1378
 
1339
1379
  # ===================================================================
@@ -1378,7 +1418,7 @@ def pointplot(
1378
1418
  outparams = False
1379
1419
 
1380
1420
  if ax is None:
1381
- fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
1421
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
1382
1422
  outparams = True
1383
1423
 
1384
1424
  # hue가 있을 때만 palette 사용, 없으면 color 사용
@@ -1398,7 +1438,7 @@ def pointplot(
1398
1438
 
1399
1439
  pointplot_kwargs.update(params)
1400
1440
  sb.pointplot(**pointplot_kwargs)
1401
- finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
1441
+ finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
1402
1442
 
1403
1443
 
1404
1444
  # ===================================================================
@@ -1456,7 +1496,7 @@ def jointplot(
1456
1496
  g.fig.set_dpi(dpi)
1457
1497
 
1458
1498
  if title:
1459
- g.fig.suptitle(title, fontsize=config.font_size * 1.5, fontweight='bold')
1499
+ g.fig.suptitle(title, fontsize=config.font_size * 1.5, fontweight="bold")
1460
1500
 
1461
1501
  # 중앙 및 주변 플롯에 grid 추가
1462
1502
  g.ax_joint.grid(True, alpha=config.grid_alpha, linewidth=config.grid_width)
@@ -1466,7 +1506,7 @@ def jointplot(
1466
1506
  plt.tight_layout()
1467
1507
 
1468
1508
  if save_path is not None:
1469
- plt.savefig(save_path, dpi=dpi*2, bbox_inches='tight')
1509
+ plt.savefig(save_path, dpi=dpi * 2, bbox_inches="tight")
1470
1510
 
1471
1511
  plt.show()
1472
1512
  plt.close()
@@ -1509,7 +1549,7 @@ def heatmap(
1509
1549
 
1510
1550
  if width == None or height == None:
1511
1551
  width = (config.font_size * config.dpi / 72) * 4.5 * len(data.columns)
1512
- height = width * 0.8 # type: ignore
1552
+ height = width * 0.8 # type: ignore
1513
1553
 
1514
1554
  if ax is None:
1515
1555
  fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
@@ -1522,7 +1562,7 @@ def heatmap(
1522
1562
  "fmt": ".2f",
1523
1563
  "ax": ax,
1524
1564
  "linewidths": linewidth,
1525
- "annot_kws": {"size": 10}
1565
+ "annot_kws": {"size": 10},
1526
1566
  }
1527
1567
 
1528
1568
  heatmatp_kwargs.update(params)
@@ -1533,7 +1573,6 @@ def heatmap(
1533
1573
  finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
1534
1574
 
1535
1575
 
1536
-
1537
1576
  # ===================================================================
1538
1577
  # KDE와 신뢰구간을 나타낸 그래프를 그린다
1539
1578
  # ===================================================================
@@ -1619,19 +1658,30 @@ def kde_confidence_interval(
1619
1658
  cmin, cmax = t.interval(clevel, dof, loc=sample_mean, scale=sample_std_error)
1620
1659
 
1621
1660
  # 현재 컬럼에 대한 커널밀도추정
1622
- sb.kdeplot(data=column, linewidth=linewidth, ax=current_ax, fill=fill, alpha=config.fill_alpha) # type: ignore
1661
+ sb.kdeplot(data=column, linewidth=linewidth, ax=current_ax, fill=fill, alpha=config.fill_alpha) # type: ignore
1623
1662
 
1624
1663
  # 그래프 축의 범위
1625
1664
  xmin, xmax, ymin, ymax = current_ax.get_position().bounds
1626
1665
  ymin_val, ymax_val = 0, current_ax.get_ylim()[1]
1627
1666
 
1628
1667
  # 신뢰구간 그리기
1629
- current_ax.plot([cmin, cmin], [ymin_val, ymax_val], linestyle=":", linewidth=linewidth*0.5)
1630
- current_ax.plot([cmax, cmax], [ymin_val, ymax_val], linestyle=":", linewidth=linewidth*0.5)
1631
- current_ax.fill_between([cmin, cmax], y1=ymin_val, y2=ymax_val, alpha=config.fill_alpha)
1668
+ current_ax.plot(
1669
+ [cmin, cmin], [ymin_val, ymax_val], linestyle=":", linewidth=linewidth * 0.5
1670
+ )
1671
+ current_ax.plot(
1672
+ [cmax, cmax], [ymin_val, ymax_val], linestyle=":", linewidth=linewidth * 0.5
1673
+ )
1674
+ current_ax.fill_between(
1675
+ [cmin, cmax], y1=ymin_val, y2=ymax_val, alpha=config.fill_alpha
1676
+ )
1632
1677
 
1633
1678
  # 평균 그리기
1634
- current_ax.plot([sample_mean, sample_mean], [0, ymax_val], linestyle="--", linewidth=linewidth)
1679
+ current_ax.plot(
1680
+ [sample_mean, sample_mean],
1681
+ [0, ymax_val],
1682
+ linestyle="--",
1683
+ linewidth=linewidth,
1684
+ )
1635
1685
 
1636
1686
  current_ax.text(
1637
1687
  x=(cmax - cmin) / 2 + cmin,
@@ -1644,8 +1694,7 @@ def kde_confidence_interval(
1644
1694
 
1645
1695
  current_ax.grid(True, alpha=config.grid_alpha, linewidth=config.grid_width)
1646
1696
 
1647
- finalize_plot(axes[0] if isinstance(axes, list) and len(axes) > 0 else ax, callback, outparams, save_path, True, title) # type: ignore
1648
-
1697
+ finalize_plot(axes[0] if isinstance(axes, list) and len(axes) > 0 else ax, callback, outparams, save_path, True, title) # type: ignore
1649
1698
 
1650
1699
 
1651
1700
  # ===================================================================
@@ -1711,32 +1760,29 @@ def ols_residplot(
1711
1760
  sb.scatterplot(x=y_pred, y=resid, ax=ax, s=20, edgecolor="white", **params)
1712
1761
 
1713
1762
  # 기준선 (잔차 = 0)
1714
- ax.axhline(0, color="gray", linestyle="--", linewidth=linewidth*0.7) # type: ignore
1763
+ ax.axhline(0, color="gray", linestyle="--", linewidth=linewidth * 0.7) # type: ignore
1715
1764
 
1716
1765
  # LOWESS 스무딩 (선택적)
1717
1766
  if lowess:
1718
1767
  lowess_result = sm_lowess(resid, y_pred, frac=0.6667)
1719
- ax.plot(lowess_result[:, 0], lowess_result[:, 1], # type: ignore
1720
- color="red", linewidth=linewidth, label="LOWESS") # type: ignore
1768
+ ax.plot( # type: ignore
1769
+ lowess_result[:, 0],
1770
+ lowess_result[:, 1], # type: ignore
1771
+ color="red",
1772
+ linewidth=linewidth,
1773
+ label="LOWESS",
1774
+ ) # type: ignore
1721
1775
 
1722
1776
  ax.set_xlabel("Fitted values") # type: ignore
1723
- ax.set_ylabel("Residuals") # type: ignore
1777
+ ax.set_ylabel("Residuals") # type: ignore
1724
1778
 
1725
1779
  if mse:
1726
1780
  mse_val = mean_squared_error(y, y_pred)
1727
1781
  mse_sq = np.sqrt(mse_val)
1728
1782
 
1729
1783
  r1 = resid[(resid > -mse_sq) & (resid < mse_sq)].size / resid.size * 100
1730
- r2 = (
1731
- resid[(resid > -2 * mse_sq) & (resid < 2 * mse_sq)].size
1732
- / resid.size
1733
- * 100
1734
- )
1735
- r3 = (
1736
- resid[(resid > -3 * mse_sq) & (resid < 3 * mse_sq)].size
1737
- / resid.size
1738
- * 100
1739
- )
1784
+ r2 = resid[(resid > -2 * mse_sq) & (resid < 2 * mse_sq)].size / resid.size * 100
1785
+ r3 = resid[(resid > -3 * mse_sq) & (resid < 3 * mse_sq)].size / resid.size * 100
1740
1786
 
1741
1787
  mse_r = [r1, r2, r3]
1742
1788
 
@@ -1747,26 +1793,26 @@ def ols_residplot(
1747
1793
  alphas = [0.15, 0.10, 0.05] # 안쪽이 더 진하게
1748
1794
 
1749
1795
  # 3σ 영역 (가장 바깥쪽, 가장 연함)
1750
- ax.axhspan(-3 * mse_sq, 3 * mse_sq, facecolor=colors[2], alpha=alphas[2], zorder=0) # type: ignore
1796
+ ax.axhspan(-3 * mse_sq, 3 * mse_sq, facecolor=colors[2], alpha=alphas[2], zorder=0) # type: ignore
1751
1797
  # 2σ 영역 (중간)
1752
- ax.axhspan(-2 * mse_sq, 2 * mse_sq, facecolor=colors[1], alpha=alphas[1], zorder=1) # type: ignore
1798
+ ax.axhspan(-2 * mse_sq, 2 * mse_sq, facecolor=colors[1], alpha=alphas[1], zorder=1) # type: ignore
1753
1799
  # 1σ 영역 (가장 안쪽, 가장 진함)
1754
- ax.axhspan(-mse_sq, mse_sq, facecolor=colors[0], alpha=alphas[0], zorder=2) # type: ignore
1800
+ ax.axhspan(-mse_sq, mse_sq, facecolor=colors[0], alpha=alphas[0], zorder=2) # type: ignore
1755
1801
 
1756
1802
  # 경계선 그리기
1757
1803
  for i, c in enumerate(["red", "green", "blue"]):
1758
- ax.axhline(mse_sq * (i + 1), color=c, linestyle="--", linewidth=linewidth/2) # type: ignore
1759
- ax.axhline(mse_sq * (-(i + 1)), color=c, linestyle="--", linewidth=linewidth/2) # type: ignore
1804
+ ax.axhline(mse_sq * (i + 1), color=c, linestyle="--", linewidth=linewidth / 2) # type: ignore
1805
+ ax.axhline(mse_sq * (-(i + 1)), color=c, linestyle="--", linewidth=linewidth / 2) # type: ignore
1760
1806
 
1761
1807
  target = [68, 95, 99.7]
1762
1808
  for i, c in enumerate(["red", "green", "blue"]):
1763
- ax.text( # type: ignore
1809
+ ax.text( # type: ignore
1764
1810
  s=f"{i+1} sqrt(MSE) = {mse_r[i]:.2f}% ({mse_r[i] - target[i]:.2f}%)",
1765
1811
  x=xmax + 0.05,
1766
1812
  y=(i + 1) * mse_sq,
1767
1813
  color=c,
1768
1814
  )
1769
- ax.text( # type: ignore
1815
+ ax.text( # type: ignore
1770
1816
  s=f"-{i+1} sqrt(MSE) = {mse_r[i]:.2f}% ({mse_r[i] - target[i]:.2f}%)",
1771
1817
  x=xmax + 0.05,
1772
1818
  y=-(i + 1) * mse_sq,
@@ -1782,7 +1828,7 @@ def ols_residplot(
1782
1828
  def ols_qqplot(
1783
1829
  fit,
1784
1830
  title: str | None = None,
1785
- line: str = 's',
1831
+ line: str = "s",
1786
1832
  width: int = config.width,
1787
1833
  height: int = config.height,
1788
1834
  linewidth: float = config.line_width,
@@ -1834,32 +1880,32 @@ def ols_qqplot(
1834
1880
  outparams = False
1835
1881
 
1836
1882
  if ax is None:
1837
- fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
1883
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
1838
1884
  outparams = True
1839
1885
 
1840
1886
  # fit 객체에서 잔차(residuals) 추출
1841
1887
  residuals = fit.resid
1842
1888
 
1843
1889
  # markersize 기본값 설정 (기존 크기의 2/3)
1844
- if 'markersize' not in params:
1845
- params['markersize'] = 2
1890
+ if "markersize" not in params:
1891
+ params["markersize"] = 2
1846
1892
 
1847
1893
  # statsmodels의 qqplot 사용 (더 전문적이고 최적화된 구현)
1848
1894
  # line 옵션으로 다양한 참조선 지원
1849
1895
  sm_qqplot(residuals, line=line, ax=ax, **params)
1850
1896
 
1851
1897
  # 점의 스타일 개선: 연한 내부, 진한 테두리
1852
- for collection in ax.collections: # type: ignore
1898
+ for collection in ax.collections: # type: ignore
1853
1899
  # PathCollection (scatter plot의 점들)
1854
- collection.set_facecolor('#4A90E2') # 연한 파란색 내부
1855
- collection.set_edgecolor('#1E3A8A') # 진한 파란색 테두리
1900
+ collection.set_facecolor("#4A90E2") # 연한 파란색 내부
1901
+ collection.set_edgecolor("#1E3A8A") # 진한 파란색 테두리
1856
1902
  collection.set_linewidth(0.8) # 테두리 굵기
1857
1903
  collection.set_alpha(0.7) # 약간의 투명도
1858
1904
 
1859
1905
  # 선 굵기 조정
1860
- for line in ax.get_lines(): # type: ignore
1861
- if line.get_linestyle() == '--' or line.get_color() == 'r': # type: ignore
1862
- line.set_linewidth(linewidth) # type: ignore
1906
+ for line in ax.get_lines(): # type: ignore
1907
+ if line.get_linestyle() == "--" or line.get_color() == "r": # type: ignore
1908
+ line.set_linewidth(linewidth) # type: ignore
1863
1909
 
1864
1910
  finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
1865
1911
 
@@ -1904,7 +1950,7 @@ def distribution_by_class(
1904
1950
  None
1905
1951
  """
1906
1952
  if xnames is None:
1907
- xnames = data.columns # type: ignore
1953
+ xnames = data.columns # type: ignore
1908
1954
 
1909
1955
  for i, v in enumerate(xnames): # type: ignore
1910
1956
  # 종속변수이거나 숫자형이 아닌 경우는 제외
@@ -1930,14 +1976,14 @@ def distribution_by_class(
1930
1976
  linewidth=linewidth,
1931
1977
  dpi=dpi,
1932
1978
  callback=callback,
1933
- save_path=save_path
1979
+ save_path=save_path,
1934
1980
  )
1935
1981
  elif type == "hist":
1936
1982
  histplot(
1937
1983
  df=data,
1938
1984
  xname=v,
1939
1985
  hue=hue,
1940
- bins=bins, # type: ignore
1986
+ bins=bins, # type: ignore
1941
1987
  kde=False,
1942
1988
  palette=palette,
1943
1989
  width=width,
@@ -1945,14 +1991,14 @@ def distribution_by_class(
1945
1991
  linewidth=linewidth,
1946
1992
  dpi=dpi,
1947
1993
  callback=callback,
1948
- save_path=save_path
1994
+ save_path=save_path,
1949
1995
  )
1950
1996
  elif type == "histkde":
1951
1997
  histplot(
1952
1998
  df=data,
1953
1999
  xname=v,
1954
2000
  hue=hue,
1955
- bins=bins, # type: ignore
2001
+ bins=bins, # type: ignore
1956
2002
  kde=True,
1957
2003
  palette=palette,
1958
2004
  width=width,
@@ -1960,7 +2006,7 @@ def distribution_by_class(
1960
2006
  linewidth=linewidth,
1961
2007
  dpi=dpi,
1962
2008
  callback=callback,
1963
- save_path=save_path
2009
+ save_path=save_path,
1964
2010
  )
1965
2011
 
1966
2012
 
@@ -2027,7 +2073,7 @@ def scatter_by_class(
2027
2073
  group = processed
2028
2074
 
2029
2075
  for v in group:
2030
- scatterplot(data=data, xname=v[0], yname=v[1], outline=outline, hue=hue, palette=palette, width=width, height=height, linewidth=linewidth, dpi=dpi, callback=callback, save_path=save_path) # type: ignore
2076
+ scatterplot(data=data, xname=v[0], yname=v[1], outline=outline, hue=hue, palette=palette, width=width, height=height, linewidth=linewidth, dpi=dpi, callback=callback, save_path=save_path) # type: ignore
2031
2077
 
2032
2078
 
2033
2079
  # ===================================================================
@@ -2072,7 +2118,9 @@ def categorical_target_distribution(
2072
2118
 
2073
2119
  # 명목형 컬럼 후보: object, category, bool
2074
2120
  if hue is None:
2075
- cat_cols = data.select_dtypes(include=["object", "category", "bool", "boolean"]).columns
2121
+ cat_cols = data.select_dtypes(
2122
+ include=["object", "category", "bool", "boolean"]
2123
+ ).columns
2076
2124
  target_cols = [c for c in cat_cols if c != yname]
2077
2125
  elif isinstance(hue, str):
2078
2126
  target_cols = [hue]
@@ -2102,7 +2150,16 @@ def categorical_target_distribution(
2102
2150
  plot_kwargs.update({"x": col, "y": yname, "palette": palette})
2103
2151
  sb.violinplot(**plot_kwargs, linewidth=linewidth)
2104
2152
  elif kind == "kde":
2105
- plot_kwargs.update({"x": yname, "hue": col, "palette": palette, "fill": kde_fill, "common_norm": False, "linewidth": linewidth})
2153
+ plot_kwargs.update(
2154
+ {
2155
+ "x": yname,
2156
+ "hue": col,
2157
+ "palette": palette,
2158
+ "fill": kde_fill,
2159
+ "common_norm": False,
2160
+ "linewidth": linewidth,
2161
+ }
2162
+ )
2106
2163
  sb.kdeplot(**plot_kwargs)
2107
2164
  else: # box
2108
2165
  plot_kwargs.update({"x": col, "y": yname, "hue": col, "palette": palette})
@@ -2156,7 +2213,7 @@ def roc_curve_plot(
2156
2213
  """
2157
2214
  outparams = False
2158
2215
  if ax is None:
2159
- fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
2216
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
2160
2217
  outparams = True
2161
2218
 
2162
2219
  # 실제값(y_true) 결정
@@ -2177,15 +2234,15 @@ def roc_curve_plot(
2177
2234
  roc_auc = auc(fpr, tpr)
2178
2235
 
2179
2236
  # ROC 곡선 그리기
2180
- ax.plot(fpr, tpr, color='darkorange', lw=linewidth, label=f'ROC curve (AUC = {roc_auc:.4f})') # type: ignore
2181
- ax.plot([0, 1], [0, 1], color='navy', lw=linewidth, linestyle='--', label='Random Classifier') # type: ignore
2182
-
2183
- ax.set_xlim([0.0, 1.0]) # type: ignore
2184
- ax.set_ylim([0.0, 1.05]) # type: ignore
2185
- ax.set_xlabel('위양성율 (False Positive Rate)', fontsize=8) # type: ignore
2186
- ax.set_ylabel('재현율 (True Positive Rate)', fontsize=8) # type: ignore
2187
- ax.set_title('ROC 곡선', fontsize=10, fontweight='bold') # type: ignore
2188
- ax.legend(loc="lower right", fontsize=7) # type: ignore
2237
+ ax.plot(fpr, tpr, color="darkorange", lw=linewidth, label=f"ROC curve (AUC = {roc_auc:.4f})") # type: ignore
2238
+ ax.plot([0, 1], [0, 1], color="navy", lw=linewidth, linestyle="--", label="Random Classifier") # type: ignore
2239
+
2240
+ ax.set_xlim([0.0, 1.0]) # type: ignore
2241
+ ax.set_ylim([0.0, 1.05]) # type: ignore
2242
+ ax.set_xlabel("위양성율 (False Positive Rate)", fontsize=8) # type: ignore
2243
+ ax.set_ylabel("재현율 (True Positive Rate)", fontsize=8) # type: ignore
2244
+ ax.set_title("ROC 곡선", fontsize=10, fontweight="bold") # type: ignore
2245
+ ax.legend(loc="lower right", fontsize=7) # type: ignore
2189
2246
  finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
2190
2247
 
2191
2248
 
@@ -2232,13 +2289,18 @@ def confusion_matrix_plot(
2232
2289
  cm = confusion_matrix(y_true, y_pred)
2233
2290
 
2234
2291
  # 혼동행렬 시각화
2235
- disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['음성', '양성'])
2292
+ disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["음성", "양성"])
2236
2293
  # 가독성을 위해 텍스트 크기/굵기 조정
2237
- disp.plot(ax=ax, cmap='Blues', values_format='d', text_kw={"fontsize": 16, "weight": "bold"})
2294
+ disp.plot(
2295
+ ax=ax,
2296
+ cmap="Blues",
2297
+ values_format="d",
2298
+ text_kw={"fontsize": 16, "weight": "bold"},
2299
+ )
2238
2300
 
2239
- ax.set_title(f'혼동행렬 (임계값: {threshold})', fontsize=8, fontweight='bold') # type: ignore
2301
+ ax.set_title(f"혼동행렬 (임계값: {threshold})", fontsize=8, fontweight="bold") # type: ignore
2240
2302
 
2241
- finalize_plot(ax, callback, outparams, save_path, False, title) # type: ignore
2303
+ finalize_plot(ax, callback, outparams, save_path, False, title) # type: ignore
2242
2304
 
2243
2305
 
2244
2306
  # ===================================================================
@@ -2323,7 +2385,7 @@ def radarplot(
2323
2385
  # Axes 생성 (polar projection)
2324
2386
  if ax is None:
2325
2387
  fig = plt.figure(figsize=(width / 100, height / 100), dpi=dpi)
2326
- ax = fig.add_subplot(111, projection='polar')
2388
+ ax = fig.add_subplot(111, projection="polar")
2327
2389
  outparams = True
2328
2390
 
2329
2391
  # 각도 계산
@@ -2345,8 +2407,15 @@ def radarplot(
2345
2407
  color = colors[idx]
2346
2408
 
2347
2409
  # 선 그리기
2348
- ax.plot(angles, values, 'o-', linewidth=linewidth,
2349
- label=str(label_name), color=color, **params)
2410
+ ax.plot(
2411
+ angles,
2412
+ values,
2413
+ "o-",
2414
+ linewidth=linewidth,
2415
+ label=str(label_name),
2416
+ color=color,
2417
+ **params,
2418
+ )
2350
2419
 
2351
2420
  # 영역 채우기
2352
2421
  if fill:
@@ -2362,15 +2431,15 @@ def radarplot(
2362
2431
 
2363
2432
  # 범례
2364
2433
  if len(labels) <= 10: # 너무 많으면 범례 생략
2365
- ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))
2434
+ ax.legend(loc="upper right", bbox_to_anchor=(1.3, 1.1))
2366
2435
 
2367
2436
  # 제목
2368
2437
  if hue is not None:
2369
- ax.set_title(f'Radar Chart by {hue}', pad=20)
2438
+ ax.set_title(f"Radar Chart by {hue}", pad=20)
2370
2439
  else:
2371
- ax.set_title('Radar Chart', pad=20)
2440
+ ax.set_title("Radar Chart", pad=20)
2372
2441
 
2373
- finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
2442
+ finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
2374
2443
 
2375
2444
 
2376
2445
  # ===================================================================
@@ -2421,7 +2490,9 @@ def distribution_plot(
2421
2490
 
2422
2491
  if hue is None:
2423
2492
  # 1행 2열 서브플롯 생성
2424
- fig, axes = get_default_ax(width, height, rows=1, cols=2, dpi=dpi, title=title)
2493
+ fig, axes = get_default_ax(
2494
+ width, height, rows=1, cols=2, dpi=dpi, title=title
2495
+ )
2425
2496
 
2426
2497
  kde_confidence_interval(
2427
2498
  data=data,
@@ -2432,17 +2503,10 @@ def distribution_plot(
2432
2503
  )
2433
2504
 
2434
2505
  if kind == "hist":
2435
- histplot(
2436
- df=data,
2437
- xname=c,
2438
- linewidth=linewidth,
2439
- ax=axes[1]
2440
- )
2506
+ histplot(df=data, xname=c, linewidth=linewidth, ax=axes[1])
2441
2507
  else:
2442
2508
  boxplot(
2443
- df=data[column], # type: ignore
2444
- linewidth=linewidth,
2445
- ax=axes[1]
2509
+ df=data[column], linewidth=linewidth, ax=axes[1] # type: ignore
2446
2510
  )
2447
2511
 
2448
2512
  fig.suptitle(title, fontsize=14, y=1.02)
@@ -2453,7 +2517,9 @@ def distribution_plot(
2453
2517
  categories = list(Series(data[hue].dropna().unique()).sort_values())
2454
2518
  n_cat = len(categories) if categories else 1
2455
2519
 
2456
- fig, axes = get_default_ax(width, height, rows=n_cat, cols=2, dpi=dpi, title=title)
2520
+ fig, axes = get_default_ax(
2521
+ width, height, rows=n_cat, cols=2, dpi=dpi, title=title
2522
+ )
2457
2523
  axes_2d = np.atleast_2d(axes)
2458
2524
 
2459
2525
  for idx, cat in enumerate(categories):
@@ -2478,9 +2544,7 @@ def distribution_plot(
2478
2544
  )
2479
2545
  else:
2480
2546
  boxplot(
2481
- df=subset[c], # type: ignore
2482
- linewidth=linewidth,
2483
- ax=right_ax
2547
+ df=subset[c], linewidth=linewidth, ax=right_ax # type: ignore
2484
2548
  )
2485
2549
 
2486
2550
  fig.suptitle(f"{title} by {hue}", fontsize=14, y=1.02)
@@ -2488,24 +2552,24 @@ def distribution_plot(
2488
2552
  plt.tight_layout()
2489
2553
 
2490
2554
  if save_path:
2491
- plt.savefig(save_path, bbox_inches='tight', dpi=dpi)
2555
+ plt.savefig(save_path, bbox_inches="tight", dpi=dpi)
2492
2556
  plt.close()
2493
2557
  else:
2494
2558
  plt.show()
2495
2559
 
2496
2560
 
2497
2561
  def silhouette_plot(
2498
- estimator: KMeans,
2499
- data: DataFrame,
2500
- title: str | None = None,
2501
- width: int = config.width,
2502
- height: int = config.height,
2503
- linewidth: float = config.line_width,
2504
- dpi: int = config.dpi,
2505
- save_path: str | None = None,
2506
- callback: Callable | None = None,
2507
- ax: Axes | None = None,
2508
- ) -> None:
2562
+ estimator: KMeans,
2563
+ data: DataFrame,
2564
+ title: str | None = None,
2565
+ width: int = config.width,
2566
+ height: int = config.height,
2567
+ linewidth: float = config.line_width,
2568
+ dpi: int = config.dpi,
2569
+ save_path: str | None = None,
2570
+ callback: Callable | None = None,
2571
+ ax: Axes | None = None,
2572
+ ) -> None:
2509
2573
  """
2510
2574
  군집분석 결과의 실루엣 플롯을 시각화함.
2511
2575
 
@@ -2532,7 +2596,7 @@ def silhouette_plot(
2532
2596
 
2533
2597
  outparams = False
2534
2598
  if ax is None:
2535
- fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
2599
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
2536
2600
  outparams = True
2537
2601
 
2538
2602
  sil_avg = silhouette_score(X=data, labels=estimator.labels_)
@@ -2541,14 +2605,14 @@ def silhouette_plot(
2541
2605
  y_lower = 10
2542
2606
 
2543
2607
  # 클러스터링 갯수별로 fill_betweenx( )형태의 막대 그래프 표현.
2544
- for i in range(estimator.n_clusters): # type: ignore
2545
- ith_cluster_sil_values = sil_values[estimator.labels_ == i] # type: ignore
2546
- ith_cluster_sil_values.sort() # type: ignore
2608
+ for i in range(estimator.n_clusters): # type: ignore
2609
+ ith_cluster_sil_values = sil_values[estimator.labels_ == i] # type: ignore
2610
+ ith_cluster_sil_values.sort() # type: ignore
2547
2611
 
2548
- size_cluster_i = ith_cluster_sil_values.shape[0] # type: ignore
2612
+ size_cluster_i = ith_cluster_sil_values.shape[0] # type: ignore
2549
2613
  y_upper = y_lower + size_cluster_i
2550
2614
 
2551
- ax.fill_betweenx( # type: ignore
2615
+ ax.fill_betweenx( # type: ignore
2552
2616
  np.arange(y_lower, y_upper),
2553
2617
  0,
2554
2618
  ith_cluster_sil_values,
@@ -2564,23 +2628,24 @@ def silhouette_plot(
2564
2628
  ax.set_xlim([-0.1, 1]) # type: ignore
2565
2629
  ax.set_ylim([0, len(data) + (estimator.n_clusters + 1) * 10]) # type: ignore
2566
2630
  ax.set_yticks([]) # type: ignore
2567
- ax.set_xticks([0, 0.2, 0.4, 0.6, 0.8, 1]) # type: ignore
2631
+ ax.set_xticks([0, 0.2, 0.4, 0.6, 0.8, 1]) # type: ignore
2568
2632
 
2569
2633
  if title is None:
2570
- title = "Number of Cluster : " + str(estimator.n_clusters) + ", Silhouette Score :" + str(round(sil_avg, 3)) # type: ignore
2634
+ title = "Number of Cluster : " + str(estimator.n_clusters) + ", Silhouette Score :" + str(round(sil_avg, 3)) # type: ignore
2571
2635
 
2572
- finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
2636
+ finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
2573
2637
 
2574
2638
 
2575
2639
  def cluster_plot(
2576
- estimator: KMeans,
2577
- data: DataFrame,
2640
+ estimator: KMeans | None = None,
2641
+ data: DataFrame | None = None,
2578
2642
  xname: str | None = None,
2579
2643
  yname: str | None = None,
2580
2644
  hue: str | None = None,
2645
+ vector: str | None = None,
2581
2646
  title: str | None = None,
2582
2647
  palette: str | None = None,
2583
- outline: bool = False,
2648
+ outline: bool = True,
2584
2649
  width: int = config.width,
2585
2650
  height: int = config.height,
2586
2651
  linewidth: float = config.line_width,
@@ -2588,7 +2653,6 @@ def cluster_plot(
2588
2653
  save_path: str | None = None,
2589
2654
  ax: Axes | None = None,
2590
2655
  ) -> None:
2591
-
2592
2656
  """
2593
2657
  2차원 공간에서 군집분석 결과를 산점도로 시각화함.
2594
2658
 
@@ -2598,6 +2662,7 @@ def cluster_plot(
2598
2662
  xname (str, optional): x축에 사용할 컬럼명. None이면 첫 번째 컬럼 사용.
2599
2663
  yname (str, optional): y축에 사용할 컬럼명. None이면 두 번째 컬럼 사용.
2600
2664
  hue (str, optional): 군집 구분에 사용할 컬럼명. None이면 'cluster' 자동 생성.
2665
+ vector (str, optional): 벡터 종류를 의미하는 컬럼명. None이면 사용 안함.
2601
2666
  title (str, optional): 플롯 제목. None이면 기본값 사용.
2602
2667
  palette (str, optional): 색상 팔레트.
2603
2668
  outline (bool, optional): 외곽선 표시 여부.
@@ -2622,49 +2687,53 @@ def cluster_plot(
2622
2687
  """
2623
2688
  outparams = False
2624
2689
  if ax is None:
2625
- fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
2690
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
2626
2691
  outparams = True
2627
2692
 
2628
- df = data.copy()
2693
+ df = data.copy() if data is not None else None # type: ignore
2629
2694
 
2630
2695
  if not hue:
2631
- df['cluster'] = estimator.labels_ # type: ignore
2632
- hue = 'cluster'
2696
+ df["cluster"] = estimator.labels_ # type: ignore
2697
+ hue = "cluster"
2633
2698
 
2634
2699
  if xname is None:
2635
- xname = df.columns[0] # type: ignore
2700
+ xname = df.columns[0] # type: ignore
2636
2701
 
2637
2702
  if yname is None:
2638
- yname = df.columns[1] # type: ignore
2703
+ yname = df.columns[1] # type: ignore
2639
2704
 
2640
- xindex = df.columns.get_loc(xname) # type: ignore
2641
- yindex = df.columns.get_loc(yname) # type: ignore
2705
+ xindex = df.columns.get_loc(xname) # type: ignore
2706
+ yindex = df.columns.get_loc(yname) # type: ignore
2642
2707
 
2643
2708
  def callback(ax: Axes) -> None:
2644
- # 클러스터 중심점 표시
2645
- centers = estimator.cluster_centers_ # type: ignore
2646
- ax.scatter( # type: ignore
2647
- centers[:, xindex],
2648
- centers[:, yindex],
2649
- marker="o",
2650
- color="white",
2651
- alpha=1,
2652
- s=200,
2653
- edgecolor="r",
2654
- linewidth=linewidth
2655
- )
2656
-
2657
- for i, c in enumerate(centers):
2658
- ax.scatter(c[xindex], c[yindex], marker="$%d$" % i, alpha=1, s=50, edgecolor="k")
2659
-
2660
2709
  ax.set_xlabel("Feature space for the " + xname)
2661
2710
  ax.set_ylabel("Feature space for the " + yname)
2662
2711
 
2712
+ if hasattr(estimator, "cluster_centers_"):
2713
+ # 클러스터 중심점 표시
2714
+ centers = estimator.cluster_centers_ # type: ignore
2715
+ ax.scatter( # type: ignore
2716
+ centers[:, xindex],
2717
+ centers[:, yindex],
2718
+ marker="o",
2719
+ color="white",
2720
+ alpha=1,
2721
+ s=200,
2722
+ edgecolor="r",
2723
+ linewidth=linewidth,
2724
+ )
2725
+
2726
+ for i, c in enumerate(centers):
2727
+ ax.scatter(
2728
+ c[xindex], c[yindex], marker="$%d$" % i, alpha=1, s=50, edgecolor="k"
2729
+ )
2730
+
2663
2731
  scatterplot(
2664
- df=df,
2732
+ df=df, # type: ignore
2665
2733
  xname=xname,
2666
2734
  yname=yname,
2667
2735
  hue=hue,
2736
+ vector=vector,
2668
2737
  title="The visualization of the clustered data." if title is None else title,
2669
2738
  outline=outline,
2670
2739
  palette=palette,
@@ -2674,16 +2743,40 @@ def cluster_plot(
2674
2743
  dpi=dpi,
2675
2744
  save_path=save_path,
2676
2745
  callback=callback,
2677
- ax=ax
2746
+ ax=ax,
2678
2747
  )
2679
2748
 
2680
- def visualize_silhouette(estimator: KMeans, data: DataFrame) -> None:
2749
+
2750
+ def visualize_silhouette(
2751
+ estimator: KMeans,
2752
+ data: DataFrame,
2753
+ xname: str | None = None,
2754
+ yname: str | None = None,
2755
+ title: str | None = None,
2756
+ palette: str | None = None,
2757
+ outline: bool = False,
2758
+ width: int = config.width,
2759
+ height: int = config.height,
2760
+ linewidth: float = config.line_width,
2761
+ dpi: int = config.dpi,
2762
+ save_path: str | None = None,
2763
+ ) -> None:
2681
2764
  """
2682
2765
  군집분석 결과의 실루엣 플롯과 군집 산점도를 한 화면에 함께 시각화함.
2683
2766
 
2684
2767
  Args:
2685
2768
  estimator (KMeans): 학습된 KMeans 군집 모델 객체.
2686
2769
  data (DataFrame): 군집분석에 사용된 입력 데이터 (n_samples, n_features).
2770
+ xname (str, optional): 산점도 x축에 사용할 컬럼명. None이면 첫 번째 컬럼 사용.
2771
+ yname (str, optional): 산점도 y축에 사용할 컬럼명. None이면 두 번째 컬럼 사용.
2772
+ title (str, optional): 플롯 제목. None이면 기본값 사용.
2773
+ palette (str, optional): 색상 팔레트.
2774
+ outline (bool, optional): 산점도 외곽선 표시 여부.
2775
+ width (int, optional): 플롯 가로 크기 (inch 단위).
2776
+ height (int, optional): 플롯 세로 크기 (inch 단위).
2777
+ linewidth (float, optional): 기준선 등 선 두께.
2778
+ dpi (int, optional): 플롯 해상도(DPI).
2779
+ save_path (str, optional): 저장 경로 지정 시 파일로 저장.
2687
2780
 
2688
2781
  Returns:
2689
2782
  None
@@ -2692,18 +2785,29 @@ def visualize_silhouette(estimator: KMeans, data: DataFrame) -> None:
2692
2785
  - 실루엣 플롯(왼쪽)과 2차원 군집 산점도(오른쪽)를 동시에 확인 가능
2693
2786
  - 군집 품질과 분포를 한눈에 비교·분석할 때 유용
2694
2787
  """
2695
- fig, ax = get_default_ax(rows=1, cols=2)
2788
+ fig, ax = get_default_ax(rows=1, cols=2, width=width, height=height, dpi=dpi, title=title)
2696
2789
 
2697
2790
  silhouette_plot(
2698
2791
  estimator=estimator,
2699
2792
  data=data,
2700
2793
  ax=ax[0],
2794
+ linewidth=linewidth,
2795
+ width=width,
2796
+ height=height,
2797
+ dpi=dpi
2701
2798
  )
2702
2799
 
2703
2800
  cluster_plot(
2704
2801
  estimator=estimator,
2705
2802
  data=data,
2803
+ xname=xname,
2804
+ yname=yname,
2706
2805
  ax=ax[1],
2806
+ outline=outline,
2807
+ palette=palette,
2808
+ width=width,
2809
+ height=height,
2810
+ dpi=dpi
2707
2811
  )
2708
2812
 
2709
- finalize_plot(ax)
2813
+ finalize_plot(ax)