hossam 0.4.4__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
hossam/hs_plot.py CHANGED
@@ -1,13 +1,14 @@
1
1
  # -*- coding: utf-8 -*-
2
2
  from __future__ import annotations
3
3
  from types import SimpleNamespace
4
+ from typing import Callable
4
5
 
5
6
  # ===================================================================
6
7
  import numpy as np
7
8
  import pandas as pd
8
9
  import seaborn as sb
9
10
  import matplotlib.pyplot as plt
10
- from matplotlib.pyplot import Axes
11
+ from matplotlib.pyplot import Axes # type: ignore
11
12
  from math import sqrt
12
13
  from pandas import DataFrame
13
14
 
@@ -15,7 +16,7 @@ from pandas import DataFrame
15
16
  from scipy.stats import t
16
17
  from scipy.spatial import ConvexHull
17
18
  from statsmodels.graphics.gofplots import qqplot as sm_qqplot
18
- from statsmodels.nonparametric.smoothers_lowess import lowess
19
+ from statsmodels.nonparametric.smoothers_lowess import lowess as sm_lowess
19
20
 
20
21
  # ===================================================================
21
22
  from statannotations.Annotator import Annotator
@@ -30,9 +31,6 @@ from sklearn.metrics import (
30
31
  )
31
32
 
32
33
  # ===================================================================
33
- if pd.__version__ > "2.0.0":
34
- pd.DataFrame.iteritems = pd.DataFrame.items
35
-
36
34
  config = SimpleNamespace(
37
35
  dpi=200,
38
36
  width=600,
@@ -49,7 +47,7 @@ config = SimpleNamespace(
49
47
  # ===================================================================
50
48
  # 기본 크기가 설정된 Figure와 Axes를 생성한다
51
49
  # ===================================================================
52
- def get_default_ax(width: int = config.width, height: int = config.height, rows: int = 1, cols: int = 1, dpi: int = config.dpi, flatten: bool = False, ws: int | None = None, hs: int | None = None, title: str = None):
50
+ def get_default_ax(width: int = config.width, height: int = config.height, rows: int = 1, cols: int = 1, dpi: int = config.dpi, flatten: bool = False, ws: int | None = None, hs: int | None = None, title: str | None = None):
53
51
  """기본 크기의 Figure와 Axes를 생성한다.
54
52
 
55
53
  Args:
@@ -75,7 +73,7 @@ def get_default_ax(width: int = config.width, height: int = config.height, rows:
75
73
  if is_array and (ws != None and hs != None):
76
74
  fig.subplots_adjust(wspace=ws, hspace=hs)
77
75
 
78
- if title and not is_array:
76
+ if title and is_array:
79
77
  fig.suptitle(title, fontsize=config.font_size * 1.5, fontweight='bold')
80
78
 
81
79
  if flatten == True:
@@ -95,7 +93,7 @@ def get_default_ax(width: int = config.width, height: int = config.height, rows:
95
93
  for spine in a.spines.values():
96
94
  spine.set_linewidth(config.frame_width)
97
95
  else:
98
- for spine in ax.spines.values():
96
+ for spine in ax.spines.values(): # type: ignore
99
97
  spine.set_linewidth(config.frame_width)
100
98
 
101
99
  return fig, ax
@@ -104,7 +102,7 @@ def get_default_ax(width: int = config.width, height: int = config.height, rows:
104
102
  # ===================================================================
105
103
  # 기본 크기가 설정된 Figure와 Axes를 생성한다
106
104
  # ===================================================================
107
- def create_figure(width: int = config.width, height: int = config.height, rows: int = 1, cols: int = 1, dpi: int = config.dpi, flatten: bool = False, ws: int | None = None, hs: int | None = None, title: str = None):
105
+ def create_figure(width: int = config.width, height: int = config.height, rows: int = 1, cols: int = 1, dpi: int = config.dpi, flatten: bool = False, ws: int | None = None, hs: int | None = None, title: str | None = None):
108
106
  """기본 크기의 Figure와 Axes를 생성한다. get_default_ax의 래퍼 함수.
109
107
 
110
108
  Args:
@@ -128,7 +126,7 @@ def create_figure(width: int = config.width, height: int = config.height, rows:
128
126
  # ===================================================================
129
127
  # 그래프의 그리드, 레이아웃을 정리하고 필요 시 저장 또는 표시한다
130
128
  # ===================================================================
131
- def finalize_plot(ax: Axes | np.ndarray, callback: any = None, outparams: bool = False, save_path: str = None, grid: bool = True, title: str = None) -> None:
129
+ def finalize_plot(ax: Axes | np.ndarray | list, callback: Callable | None = None, outparams: bool = False, save_path: str | None = None, grid: bool = True, title: str | None = None) -> None:
132
130
  """공통 후처리를 수행한다: 콜백 실행, 레이아웃 정리, 필요 시 표시/종료.
133
131
 
134
132
  Args:
@@ -176,7 +174,7 @@ def finalize_plot(ax: Axes | np.ndarray, callback: any = None, outparams: bool =
176
174
  # ===================================================================
177
175
  # 그래프의 그리드, 레이아웃을 정리하고 필요 시 저장 또는 표시한다
178
176
  # ===================================================================
179
- def show_figure(ax: Axes | np.ndarray, callback: any = None, outparams: bool = False, save_path: str = None, grid: bool = True, title: str = None) -> None:
177
+ def show_figure(ax: Axes | np.ndarray, callback: Callable | None = None, outparams: bool = False, save_path: str | None = None, grid: bool = True, title: str | None = None) -> None:
180
178
  """공통 후처리를 수행한다: 콜백 실행, 레이아웃 정리, 필요 시 표시/종료.
181
179
  finalize_plot의 래퍼 함수.
182
180
 
@@ -199,19 +197,19 @@ def show_figure(ax: Axes | np.ndarray, callback: any = None, outparams: bool = F
199
197
  # ===================================================================
200
198
  def lineplot(
201
199
  df: DataFrame,
202
- xname: str = None,
203
- yname: str = None,
204
- hue: str = None,
200
+ xname: str | None = None,
201
+ yname: str | None = None,
202
+ hue: str | None = None,
205
203
  title: str | None = None,
206
- marker: str = None,
207
- palette: str = None,
204
+ marker: str | None = None,
205
+ palette: str | None = None,
208
206
  width: int = config.width,
209
207
  height: int = config.height,
210
208
  linewidth: float = config.line_width,
211
209
  dpi: int = config.dpi,
212
- save_path: str = None,
213
- callback: any = None,
214
- ax: Axes = None,
210
+ save_path: str | None = None,
211
+ callback: Callable | None = None,
212
+ ax: Axes | None = None,
215
213
  **params,
216
214
  ) -> None:
217
215
  """선 그래프를 그린다.
@@ -239,7 +237,7 @@ def lineplot(
239
237
  outparams = False
240
238
 
241
239
  if ax is None:
242
- fig, ax = get_default_ax(width, height, 1, 1, dpi)
240
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
243
241
  outparams = True
244
242
 
245
243
  # hue가 있을 때만 palette 사용, 없으면 color 사용
@@ -260,7 +258,7 @@ def lineplot(
260
258
  lineplot_kwargs.update(params)
261
259
 
262
260
  sb.lineplot(**lineplot_kwargs, linewidth=linewidth)
263
- finalize_plot(ax, callback, outparams, save_path, True, title)
261
+ finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
264
262
 
265
263
 
266
264
  # ===================================================================
@@ -268,18 +266,22 @@ def lineplot(
268
266
  # ===================================================================
269
267
  def boxplot(
270
268
  df: DataFrame,
271
- xname: str = None,
272
- yname: str = None,
269
+ xname: str | None = None,
270
+ yname: str | None = None,
273
271
  title: str | None = None,
274
272
  orient: str = "v",
275
- palette: str = None,
273
+ stat_test: str | None = None,
274
+ stat_pairs: list[tuple] | None = None,
275
+ stat_text_format: str = "star",
276
+ stat_loc: str = "inside",
277
+ palette: str | None = None,
276
278
  width: int = config.width,
277
279
  height: int = config.height,
278
280
  linewidth: float = config.line_width,
279
281
  dpi: int = config.dpi,
280
- save_path: str = None,
281
- callback: any = None,
282
- ax: Axes = None,
282
+ save_path: str | None = None,
283
+ callback: Callable | None = None,
284
+ ax: Axes | None = None,
283
285
  **params,
284
286
  ) -> None:
285
287
  """상자그림(boxplot)을 그린다.
@@ -290,6 +292,10 @@ def boxplot(
290
292
  yname (str|None): y축 값 컬럼명.
291
293
  title (str|None): 그래프 제목.
292
294
  orient (str): 'v' 또는 'h' 방향.
295
+ stat_test (str|None): 통계 검정 방법. None이면 검정 안함. xname과 yname이 모두 지정되어야 함.
296
+ stat_pairs (list[tuple]|None): 통계 검정할 그룹 쌍 목록.
297
+ stat_text_format (str): 통계 결과 표시 형식.
298
+ stat_loc (str): 통계 결과 위치.
293
299
  palette (str|None): 팔레트 이름.
294
300
  width (int): 캔버스 가로 픽셀.
295
301
  height (int): 캔버스 세로 픽셀.
@@ -306,7 +312,7 @@ def boxplot(
306
312
  outparams = False
307
313
 
308
314
  if ax is None:
309
- fig, ax = get_default_ax(width, height, 1, 1, dpi)
315
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
310
316
  outparams = True
311
317
 
312
318
  if xname is not None and yname is not None:
@@ -328,10 +334,65 @@ def boxplot(
328
334
 
329
335
  boxplot_kwargs.update(params)
330
336
  sb.boxplot(**boxplot_kwargs, linewidth=linewidth)
337
+
338
+ # 통계 검정 추가
339
+ if stat_test is not None:
340
+ if stat_pairs is None:
341
+ stat_pairs = [df[xname].dropna().unique().tolist()]
342
+
343
+ annotator = Annotator(ax, data=df, x=xname, y=yname, pairs=stat_pairs, orient=orient)
344
+ annotator.configure(test=stat_test, text_format=stat_text_format, loc=stat_loc)
345
+ annotator.apply_and_annotate()
331
346
  else:
332
- sb.boxplot(data=df, orient=orient, ax=ax, linewidth=linewidth, **params)
347
+ sb.boxplot(data=df, orient=orient, ax=ax, linewidth=linewidth, **params) # type: ignore
333
348
 
334
- finalize_plot(ax, callback, outparams, save_path, True, title)
349
+ finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
350
+
351
+
352
+ # ===================================================================
353
+ # 상자그림에 p-value 주석을 추가한다
354
+ # ===================================================================
355
+ def pvalue1_anotation(
356
+ data: DataFrame,
357
+ target: str,
358
+ hue: str,
359
+ title: str | None = None,
360
+ pairs: list | None = None,
361
+ test: str = "t-test_ind",
362
+ text_format: str = "star",
363
+ loc: str = "outside",
364
+ width: int = config.width,
365
+ height: int = config.height,
366
+ linewidth: float = config.line_width,
367
+ dpi: int = config.dpi,
368
+ save_path: str | None = None,
369
+ callback: Callable | None = None,
370
+ ax: Axes | None = None,
371
+ **params
372
+ ) -> None:
373
+ """
374
+ boxplot의 wrapper 함수로, 상자그림에 p-value 주석을 추가한다.
375
+ """
376
+ boxplot(
377
+ data,
378
+ xname=hue,
379
+ yname=target,
380
+ title=title,
381
+ orient="v",
382
+ stat_test=test,
383
+ stat_pairs=pairs,
384
+ stat_text_format=text_format,
385
+ stat_loc=loc,
386
+ palette=None,
387
+ width=width,
388
+ height=height,
389
+ linewidth=linewidth,
390
+ dpi=dpi,
391
+ save_path=save_path,
392
+ callback=callback,
393
+ ax=ax,
394
+ **params
395
+ )
335
396
 
336
397
 
337
398
  # ===================================================================
@@ -339,11 +400,11 @@ def boxplot(
339
400
  # ===================================================================
340
401
  def kdeplot(
341
402
  df: DataFrame,
342
- xname: str = None,
343
- yname: str = None,
344
- hue: str = None,
403
+ xname: str | None = None,
404
+ yname: str | None = None,
405
+ hue: str | None = None,
345
406
  title: str | None = None,
346
- palette: str = None,
407
+ palette: str | None = None,
347
408
  fill: bool = False,
348
409
  fill_alpha: float = config.fill_alpha,
349
410
  linewidth: float = config.line_width,
@@ -351,9 +412,9 @@ def kdeplot(
351
412
  width: int = config.width,
352
413
  height: int = config.height,
353
414
  dpi: int = config.dpi,
354
- save_path: str = None,
355
- callback: any = None,
356
- ax: Axes = None,
415
+ save_path: str | None = None,
416
+ callback: Callable | None = None,
417
+ ax: Axes | None = None,
357
418
  **params,
358
419
  ) -> None:
359
420
  """커널 밀도 추정(KDE) 그래프를 그린다.
@@ -434,7 +495,7 @@ def kdeplot(
434
495
  return
435
496
 
436
497
  if ax is None:
437
- fig, ax = get_default_ax(width, height, 1, 1, dpi)
498
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
438
499
  outparams = True
439
500
 
440
501
  # 기본 kwargs 설정
@@ -464,7 +525,7 @@ def kdeplot(
464
525
 
465
526
  sb.kdeplot(**kdeplot_kwargs)
466
527
 
467
- finalize_plot(ax, callback, outparams, save_path, True, title)
528
+ finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
468
529
 
469
530
 
470
531
  # ===================================================================
@@ -477,14 +538,14 @@ def histplot(
477
538
  title: str | None = None,
478
539
  bins: int | None = None,
479
540
  kde: bool = True,
480
- palette: str = None,
541
+ palette: str | None = None,
481
542
  width: int = config.width,
482
543
  height: int = config.height,
483
544
  linewidth: float = config.line_width,
484
545
  dpi: int = config.dpi,
485
- save_path: str = None,
486
- callback: any = None,
487
- ax: Axes = None,
546
+ save_path: str | None = None,
547
+ callback: Callable | None = None,
548
+ ax: Axes | None = None,
488
549
  **params,
489
550
  ) -> None:
490
551
  """히스토그램을 그리고 필요 시 KDE를 함께 표시한다.
@@ -511,7 +572,7 @@ def histplot(
511
572
  outparams = False
512
573
 
513
574
  if ax is None:
514
- fig, ax = get_default_ax(width, height, 1, 1, dpi)
575
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
515
576
  outparams = True
516
577
 
517
578
  if bins:
@@ -550,7 +611,7 @@ def histplot(
550
611
  histplot_kwargs.update(params)
551
612
  sb.histplot(**histplot_kwargs)
552
613
 
553
- finalize_plot(ax, callback, outparams, save_path, True, title)
614
+ finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
554
615
 
555
616
 
556
617
  # ===================================================================
@@ -561,14 +622,14 @@ def stackplot(
561
622
  xname: str,
562
623
  hue: str,
563
624
  title: str | None = None,
564
- palette: str = None,
625
+ palette: str | None = None,
565
626
  width: int = config.width,
566
627
  height: int = config.height,
567
628
  linewidth: float = 0.25,
568
629
  dpi: int = config.dpi,
569
- save_path: str = None,
570
- callback: any = None,
571
- ax: Axes = None,
630
+ save_path: str | None = None,
631
+ callback: Callable | None = None,
632
+ ax: Axes | None = None,
572
633
  **params,
573
634
  ) -> None:
574
635
  """클래스 비율을 100% 누적 막대로 표현한다.
@@ -593,7 +654,7 @@ def stackplot(
593
654
  outparams = False
594
655
 
595
656
  if ax is None:
596
- fig, ax = get_default_ax(width, height, 1, 1, dpi)
657
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
597
658
  outparams = True
598
659
 
599
660
  df2 = df[[xname, hue]].copy()
@@ -620,11 +681,11 @@ def stackplot(
620
681
  sb.histplot(**stackplot_kwargs)
621
682
 
622
683
  # 그래프의 x축 항목 수 만큼 반복
623
- for p in ax.patches:
684
+ for p in ax.patches: # type: ignore
624
685
  # 각 막대의 위치, 넓이, 높이
625
- left, bottom, width, height = p.get_bbox().bounds
686
+ left, bottom, width, height = p.get_bbox().bounds # type: ignore
626
687
  # 막대의 중앙에 글자 표시하기
627
- ax.annotate(
688
+ ax.annotate( # type: ignore
628
689
  "%0.1f%%" % (height * 100),
629
690
  xy=(left + width / 2, bottom + height / 2),
630
691
  ha="center",
@@ -633,10 +694,10 @@ def stackplot(
633
694
 
634
695
  if str(df[xname].dtype) in ["int", "int32", "int64", "float", "float32", "float64"]:
635
696
  xticks = list(df[xname].unique())
636
- ax.set_xticks(xticks)
637
- ax.set_xticklabels(xticks)
697
+ ax.set_xticks(xticks) # type: ignore
698
+ ax.set_xticklabels(xticks) # type: ignore
638
699
 
639
- finalize_plot(ax, callback, outparams, save_path, True, title)
700
+ finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
640
701
 
641
702
 
642
703
  # ===================================================================
@@ -648,14 +709,14 @@ def scatterplot(
648
709
  yname: str,
649
710
  hue=None,
650
711
  title: str | None = None,
651
- palette: str = None,
712
+ palette: str | None = None,
652
713
  width: int = config.width,
653
714
  height: int = config.height,
654
715
  linewidth: float = config.line_width,
655
716
  dpi: int = config.dpi,
656
- save_path: str = None,
657
- callback: any = None,
658
- ax: Axes = None,
717
+ save_path: str | None = None,
718
+ callback: Callable | None = None,
719
+ ax: Axes | None = None,
659
720
  **params,
660
721
  ) -> None:
661
722
  """산점도를 그린다.
@@ -681,7 +742,7 @@ def scatterplot(
681
742
  outparams = False
682
743
 
683
744
  if ax is None:
684
- fig, ax = get_default_ax(width, height, 1, 1, dpi)
745
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
685
746
  outparams = True
686
747
 
687
748
  # hue가 있을 때만 palette 사용, 없으면 color 사용
@@ -703,7 +764,7 @@ def scatterplot(
703
764
 
704
765
  sb.scatterplot(**scatterplot_kwargs)
705
766
 
706
- finalize_plot(ax, callback, outparams, save_path, True, title)
767
+ finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
707
768
 
708
769
 
709
770
  # ===================================================================
@@ -714,14 +775,14 @@ def regplot(
714
775
  xname: str,
715
776
  yname: str,
716
777
  title: str | None = None,
717
- palette: str = None,
778
+ palette: str | None = None,
718
779
  width: int = config.width,
719
780
  height: int = config.height,
720
781
  linewidth: float = config.line_width,
721
782
  dpi: int = config.dpi,
722
- save_path: str = None,
723
- callback: any = None,
724
- ax: Axes = None,
783
+ save_path: str | None = None,
784
+ callback: Callable | None = None,
785
+ ax: Axes | None = None,
725
786
  **params,
726
787
  ) -> None:
727
788
  """단순 회귀선이 포함된 산점도를 그린다.
@@ -746,7 +807,7 @@ def regplot(
746
807
  outparams = False
747
808
 
748
809
  if ax is None:
749
- fig, ax = get_default_ax(width, height, 1, 1, dpi)
810
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
750
811
  outparams = True
751
812
 
752
813
  # regplot은 hue를 지원하지 않으므로 palette를 color로 변환
@@ -758,7 +819,12 @@ def regplot(
758
819
  "data": df,
759
820
  "x": xname,
760
821
  "y": yname,
761
- "scatter_kws": {"color": scatter_color} if scatter_color else {},
822
+ "scatter_kws": {
823
+ "s": 20,
824
+ "linewidths": 0.5,
825
+ "edgecolor": "w",
826
+ "color": scatter_color
827
+ },
762
828
  "line_kws": {
763
829
  "color": "red",
764
830
  "linestyle": "--",
@@ -771,7 +837,7 @@ def regplot(
771
837
 
772
838
  sb.regplot(**regplot_kwargs)
773
839
 
774
- finalize_plot(ax, callback, outparams, save_path, True, title)
840
+ finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
775
841
 
776
842
 
777
843
  # ===================================================================
@@ -783,12 +849,12 @@ def lmplot(
783
849
  yname: str,
784
850
  hue=None,
785
851
  title: str | None = None,
786
- palette: str = None,
852
+ palette: str | None = None,
787
853
  width: int = config.width,
788
854
  height: int = config.height,
789
855
  linewidth: float = config.line_width,
790
856
  dpi: int = config.dpi,
791
- save_path: str = None,
857
+ save_path: str | None = None,
792
858
  **params,
793
859
  ) -> None:
794
860
  """seaborn lmplot으로 선형 모델 시각화를 수행한다.
@@ -835,7 +901,7 @@ def lmplot(
835
901
  continue
836
902
  line.set_linewidth(linewidth)
837
903
 
838
- g.fig.grid(True, alpha=config.grid_alpha, linewidth=config.grid_width)
904
+ g.fig.grid(True, alpha=config.grid_alpha, linewidth=config.grid_width) # type: ignore
839
905
 
840
906
  if title:
841
907
  g.fig.suptitle(title, fontsize=config.font_size * 1.5, fontweight='bold')
@@ -858,12 +924,12 @@ def pairplot(
858
924
  title: str | None = None,
859
925
  diag_kind: str = "kde",
860
926
  hue=None,
861
- palette: str = None,
927
+ palette: str | None = None,
862
928
  width: int = config.height,
863
929
  height: int = config.height,
864
930
  linewidth: float = config.line_width,
865
931
  dpi: int = config.dpi,
866
- save_path: str = None,
932
+ save_path: str | None = None,
867
933
  **params,
868
934
  ) -> None:
869
935
  """연속형 변수의 숫자형 컬럼 쌍에 대한 관계를 그린다.
@@ -936,7 +1002,7 @@ def pairplot(
936
1002
  g.map_upper(func=sb.scatterplot, linewidth=linewidth)
937
1003
 
938
1004
  # KDE 대각선에도 linewidth 적용
939
- for ax in g.axes.diag:
1005
+ for ax in g.axes.diag: # type: ignore
940
1006
  for line in ax.get_lines():
941
1007
  line.set_linewidth(linewidth)
942
1008
 
@@ -957,15 +1023,15 @@ def countplot(
957
1023
  xname: str,
958
1024
  hue=None,
959
1025
  title: str | None = None,
960
- palette: str = None,
1026
+ palette: str | None = None,
961
1027
  order: int = 1,
962
1028
  width: int = config.width,
963
1029
  height: int = config.height,
964
1030
  linewidth: float = config.line_width,
965
1031
  dpi: int = config.dpi,
966
- save_path: str = None,
967
- callback: any = None,
968
- ax: Axes = None,
1032
+ save_path: str | None = None,
1033
+ callback: Callable | None = None,
1034
+ ax: Axes | None = None,
969
1035
  **params,
970
1036
  ) -> None:
971
1037
  """범주 빈도 막대그래프를 그린다.
@@ -997,7 +1063,7 @@ def countplot(
997
1063
  sort = sorted(list(df[xname].value_counts().index))
998
1064
 
999
1065
  if ax is None:
1000
- fig, ax = get_default_ax(width, height, 1, 1, dpi)
1066
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
1001
1067
  outparams = True
1002
1068
 
1003
1069
  # hue가 있을 때만 palette 사용, 없으면 color 사용
@@ -1020,7 +1086,7 @@ def countplot(
1020
1086
 
1021
1087
  sb.countplot(**countplot_kwargs)
1022
1088
 
1023
- finalize_plot(ax, callback, outparams, save_path, True, title)
1089
+ finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
1024
1090
 
1025
1091
 
1026
1092
  # ===================================================================
@@ -1032,14 +1098,14 @@ def barplot(
1032
1098
  yname: str,
1033
1099
  hue=None,
1034
1100
  title: str | None = None,
1035
- palette: str = None,
1101
+ palette: str | None = None,
1036
1102
  width: int = config.width,
1037
1103
  height: int = config.height,
1038
1104
  linewidth: float = config.line_width,
1039
1105
  dpi: int = config.dpi,
1040
- save_path: str = None,
1041
- callback: any = None,
1042
- ax: Axes = None,
1106
+ save_path: str | None = None,
1107
+ callback: Callable | None = None,
1108
+ ax: Axes | None = None,
1043
1109
  **params,
1044
1110
  ) -> None:
1045
1111
  """막대그래프를 그린다.
@@ -1065,7 +1131,7 @@ def barplot(
1065
1131
  outparams = False
1066
1132
 
1067
1133
  if ax is None:
1068
- fig, ax = get_default_ax(width, height, 1, 1, dpi)
1134
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
1069
1135
  outparams = True
1070
1136
 
1071
1137
  # hue가 있을 때만 palette 사용, 없으면 color 사용
@@ -1086,11 +1152,11 @@ def barplot(
1086
1152
  barplot_kwargs.update(params)
1087
1153
 
1088
1154
  sb.barplot(**barplot_kwargs)
1089
- finalize_plot(ax, callback, outparams, save_path, True, title)
1155
+ finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
1090
1156
 
1091
1157
 
1092
1158
  # ===================================================================
1093
- # 바이올린 플롯을 그린다
1159
+ # boxen 플롯을 그린다
1094
1160
  # ===================================================================
1095
1161
  def boxenplot(
1096
1162
  df: DataFrame,
@@ -1098,14 +1164,14 @@ def boxenplot(
1098
1164
  yname: str,
1099
1165
  hue=None,
1100
1166
  title: str | None = None,
1101
- palette: str = None,
1167
+ palette: str | None = None,
1102
1168
  width: int = config.width,
1103
1169
  height: int = config.height,
1104
1170
  linewidth: float = config.line_width,
1105
1171
  dpi: int = config.dpi,
1106
- save_path: str = None,
1107
- callback: any = None,
1108
- ax: Axes = None,
1172
+ save_path: str | None = None,
1173
+ callback: Callable | None = None,
1174
+ ax: Axes | None = None,
1109
1175
  **params,
1110
1176
  ) -> None:
1111
1177
  """박스앤 위스커 확장(boxen) 플롯을 그린다.
@@ -1131,7 +1197,7 @@ def boxenplot(
1131
1197
  outparams = False
1132
1198
 
1133
1199
  if ax is None:
1134
- fig, ax = get_default_ax(width, height, 1, 1, dpi)
1200
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
1135
1201
  outparams = True
1136
1202
 
1137
1203
  # palette은 hue가 있을 때만 사용
@@ -1150,7 +1216,7 @@ def boxenplot(
1150
1216
  boxenplot_kwargs.update(params)
1151
1217
 
1152
1218
  sb.boxenplot(**boxenplot_kwargs)
1153
- finalize_plot(ax, callback, outparams, save_path, True, title)
1219
+ finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
1154
1220
 
1155
1221
 
1156
1222
  # ===================================================================
@@ -1162,14 +1228,14 @@ def violinplot(
1162
1228
  yname: str,
1163
1229
  hue=None,
1164
1230
  title: str | None = None,
1165
- palette: str = None,
1231
+ palette: str | None = None,
1166
1232
  width: int = config.width,
1167
1233
  height: int = config.height,
1168
1234
  linewidth: float = config.line_width,
1169
1235
  dpi: int = config.dpi,
1170
- save_path: str = None,
1171
- callback: any = None,
1172
- ax: Axes = None,
1236
+ save_path: str | None = None,
1237
+ callback: Callable | None = None,
1238
+ ax: Axes | None = None,
1173
1239
  **params,
1174
1240
  ) -> None:
1175
1241
  """바이올린 플롯을 그린다.
@@ -1195,7 +1261,7 @@ def violinplot(
1195
1261
  outparams = False
1196
1262
 
1197
1263
  if ax is None:
1198
- fig, ax = get_default_ax(width, height, 1, 1, dpi)
1264
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
1199
1265
  outparams = True
1200
1266
 
1201
1267
  # palette은 hue가 있을 때만 사용
@@ -1213,7 +1279,7 @@ def violinplot(
1213
1279
 
1214
1280
  violinplot_kwargs.update(params)
1215
1281
  sb.violinplot(**violinplot_kwargs)
1216
- finalize_plot(ax, callback, outparams, save_path, True, title)
1282
+ finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
1217
1283
 
1218
1284
 
1219
1285
  # ===================================================================
@@ -1225,14 +1291,14 @@ def pointplot(
1225
1291
  yname: str,
1226
1292
  hue=None,
1227
1293
  title: str | None = None,
1228
- palette: str = None,
1294
+ palette: str | None = None,
1229
1295
  width: int = config.width,
1230
1296
  height: int = config.height,
1231
1297
  linewidth: float = config.line_width,
1232
1298
  dpi: int = config.dpi,
1233
- save_path: str = None,
1234
- callback: any = None,
1235
- ax: Axes = None,
1299
+ save_path: str | None = None,
1300
+ callback: Callable | None = None,
1301
+ ax: Axes | None = None,
1236
1302
  **params,
1237
1303
  ) -> None:
1238
1304
  """포인트 플롯을 그린다.
@@ -1258,7 +1324,7 @@ def pointplot(
1258
1324
  outparams = False
1259
1325
 
1260
1326
  if ax is None:
1261
- fig, ax = get_default_ax(width, height, 1, 1, dpi)
1327
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
1262
1328
  outparams = True
1263
1329
 
1264
1330
  # hue가 있을 때만 palette 사용, 없으면 color 사용
@@ -1278,7 +1344,7 @@ def pointplot(
1278
1344
 
1279
1345
  pointplot_kwargs.update(params)
1280
1346
  sb.pointplot(**pointplot_kwargs)
1281
- finalize_plot(ax, callback, outparams, save_path, True, title)
1347
+ finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
1282
1348
 
1283
1349
 
1284
1350
  # ===================================================================
@@ -1290,12 +1356,12 @@ def jointplot(
1290
1356
  yname: str,
1291
1357
  hue=None,
1292
1358
  title: str | None = None,
1293
- palette: str = None,
1359
+ palette: str | None = None,
1294
1360
  width: int = config.width,
1295
1361
  height: int = config.height,
1296
1362
  linewidth: float = config.line_width,
1297
1363
  dpi: int = config.dpi,
1298
- save_path: str = None,
1364
+ save_path: str | None = None,
1299
1365
  **params,
1300
1366
  ) -> None:
1301
1367
  """공동 분포(joint) 플롯을 그린다.
@@ -1358,14 +1424,14 @@ def jointplot(
1358
1424
  def heatmap(
1359
1425
  data: DataFrame,
1360
1426
  title: str | None = None,
1361
- palette: str = None,
1427
+ palette: str | None = None,
1362
1428
  width: int | None = None,
1363
1429
  height: int | None = None,
1364
1430
  linewidth: float = 0.25,
1365
1431
  dpi: int = config.dpi,
1366
- save_path: str = None,
1367
- callback: any = None,
1368
- ax: Axes = None,
1432
+ save_path: str | None = None,
1433
+ callback: Callable | None = None,
1434
+ ax: Axes | None = None,
1369
1435
  **params,
1370
1436
  ) -> None:
1371
1437
  """히트맵을 그린다(값 주석 포함).
@@ -1389,10 +1455,10 @@ def heatmap(
1389
1455
 
1390
1456
  if width == None or height == None:
1391
1457
  width = (config.font_size * config.dpi / 72) * 4.5 * len(data.columns)
1392
- height = width * 0.8
1458
+ height = width * 0.8 # type: ignore
1393
1459
 
1394
1460
  if ax is None:
1395
- fig, ax = get_default_ax(width, height, 1, 1, dpi)
1461
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
1396
1462
  outparams = True
1397
1463
 
1398
1464
  heatmatp_kwargs = {
@@ -1410,26 +1476,26 @@ def heatmap(
1410
1476
  # heatmap은 hue를 지원하지 않으므로 cmap에 palette 사용
1411
1477
  sb.heatmap(**heatmatp_kwargs)
1412
1478
 
1413
- finalize_plot(ax, callback, outparams, save_path, True, title)
1479
+ finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
1414
1480
 
1415
1481
 
1416
1482
  # ===================================================================
1417
- # 클러스터별 볼록 ꯐ막(convex hull)을 그린다
1483
+ # 클러스터별 볼록 경계막(convex hull)을 그린다
1418
1484
  # ===================================================================
1419
1485
  def convex_hull(
1420
1486
  data: DataFrame,
1421
1487
  xname: str,
1422
1488
  yname: str,
1423
- hue: str,
1489
+ hue: str | None = None,
1424
1490
  title: str | None = None,
1425
- palette: str = None,
1491
+ palette: str | None = None,
1426
1492
  width: int = config.width,
1427
1493
  height: int = config.height,
1428
1494
  linewidth: float = config.line_width,
1429
1495
  dpi: int = config.dpi,
1430
- save_path: str = None,
1431
- callback: any = None,
1432
- ax: Axes = None,
1496
+ save_path: str | None = None,
1497
+ callback: Callable | None = None,
1498
+ ax: Axes | None = None,
1433
1499
  **params,
1434
1500
  ):
1435
1501
  """클러스터별 볼록 껍질(convex hull)과 산점도를 그린다.
@@ -1455,7 +1521,7 @@ def convex_hull(
1455
1521
  outparams = False
1456
1522
 
1457
1523
  if ax is None:
1458
- fig, ax = get_default_ax(width, height, 1, 1, dpi)
1524
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
1459
1525
  outparams = True
1460
1526
 
1461
1527
  # 군집별 값의 종류별로 반복 수행
@@ -1473,10 +1539,10 @@ def convex_hull(
1473
1539
  # 마지막 좌표 이후에 첫 번째 좌표를 연결
1474
1540
  points = np.append(hull.vertices, hull.vertices[0])
1475
1541
 
1476
- ax.plot(
1542
+ ax.plot( # type: ignore
1477
1543
  df_c.iloc[points, 0], df_c.iloc[points, 1], linewidth=linewidth, linestyle=":"
1478
1544
  )
1479
- ax.fill(df_c.iloc[points, 0], df_c.iloc[points, 1], alpha=0.1)
1545
+ ax.fill(df_c.iloc[points, 0], df_c.iloc[points, 1], alpha=0.1) # type: ignore
1480
1546
  except:
1481
1547
  pass
1482
1548
 
@@ -1484,7 +1550,7 @@ def convex_hull(
1484
1550
  sb.scatterplot(
1485
1551
  data=data, x=xname, y=yname, hue=hue, palette=palette, ax=ax, **params
1486
1552
  )
1487
- finalize_plot(ax, callback, outparams, save_path, True, title)
1553
+ finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
1488
1554
 
1489
1555
 
1490
1556
  # ===================================================================
@@ -1500,9 +1566,9 @@ def kde_confidence_interval(
1500
1566
  linewidth: float = config.line_width,
1501
1567
  fill: bool = False,
1502
1568
  dpi: int = config.dpi,
1503
- save_path: str = None,
1504
- callback: any = None,
1505
- ax: Axes = None,
1569
+ save_path: str | None = None,
1570
+ callback: Callable | None = None,
1571
+ ax: Axes | None = None,
1506
1572
  ) -> None:
1507
1573
  """각 숫자 컬럼에 대해 KDE와 t-분포 기반 신뢰구간을 그린다.
1508
1574
 
@@ -1572,7 +1638,7 @@ def kde_confidence_interval(
1572
1638
  cmin, cmax = t.interval(clevel, dof, loc=sample_mean, scale=sample_std_error)
1573
1639
 
1574
1640
  # 현재 컬럼에 대한 커널밀도추정
1575
- sb.kdeplot(data=column, linewidth=linewidth, ax=current_ax, fill=fill, alpha=config.fill_alpha)
1641
+ sb.kdeplot(data=column, linewidth=linewidth, ax=current_ax, fill=fill, alpha=config.fill_alpha) # type: ignore
1576
1642
 
1577
1643
  # 그래프 축의 범위
1578
1644
  xmin, xmax, ymin, ymax = current_ax.get_position().bounds
@@ -1597,89 +1663,7 @@ def kde_confidence_interval(
1597
1663
 
1598
1664
  current_ax.grid(True, alpha=config.grid_alpha, linewidth=config.grid_width)
1599
1665
 
1600
- finalize_plot(axes[0] if isinstance(axes, list) and len(axes) > 0 else ax, callback, outparams, save_path, True, title)
1601
-
1602
-
1603
- # ===================================================================
1604
- # 상자그림에 p-value 주석을 추가한다
1605
- # ===================================================================
1606
- def pvalue1_anotation(
1607
- data: DataFrame,
1608
- target: str,
1609
- hue: str,
1610
- title: str | None = None,
1611
- pairs: list = None,
1612
- test: str = "t-test_ind",
1613
- text_format: str = "star",
1614
- loc: str = "outside",
1615
- width: int = config.width,
1616
- height: int = config.height,
1617
- linewidth: float = config.line_width,
1618
- dpi: int = config.dpi,
1619
- save_path: str = None,
1620
- callback: any = None,
1621
- ax: Axes = None,
1622
- **params
1623
- ) -> None:
1624
- """statannotations를 이용해 상자그림에 p-value 주석을 추가한다.
1625
-
1626
- Args:
1627
- data (DataFrame): 시각화할 데이터.
1628
- target (str): 값 컬럼명.
1629
- hue (str): 그룹 컬럼명.
1630
- title (str|None): 그래프 제목.
1631
- pairs (list|None): 비교할 (group_a, group_b) 튜플 목록. None이면 hue 컬럼의 모든 고유값 조합을 자동 생성.
1632
- test (str): 적용할 통계 검정 이름.
1633
- text_format (str): 주석 형식('star' 등).
1634
- loc (str): 주석 위치.
1635
- width (int): 캔버스 가로 픽셀.
1636
- height (int): 캔버스 세로 픽셀.
1637
- linewidth (float): 선 굵기.
1638
- dpi (int): 그림 크기 및 해상도.
1639
- callback (Callable|None): Axes 후처리 콜백.
1640
- ax (Axes|None): 외부에서 전달한 Axes.
1641
- **params: seaborn boxplot 추가 인자.
1642
-
1643
- Returns:
1644
- None
1645
- """
1646
- # pairs가 None이면 hue 컬럼의 고유값으로 모든 조합 생성
1647
- if pairs is None:
1648
- from itertools import combinations
1649
- unique_values = sorted(data[hue].unique())
1650
- pairs = list(combinations(unique_values, 2))
1651
-
1652
- outparams = False
1653
-
1654
- if ax is None:
1655
- fig, ax = get_default_ax(width, height, 1, 1, dpi)
1656
- outparams = True
1657
-
1658
- # params에서 palette 추출 (있으면)
1659
- palette_value = params.pop("palette", None)
1660
-
1661
- # boxplot kwargs 구성
1662
- boxplot_kwargs = {
1663
- "data": data,
1664
- "x": hue,
1665
- "y": target,
1666
- "linewidth": linewidth,
1667
- "ax": ax,
1668
- }
1669
-
1670
- # palette가 있으면 추가 (hue는 x에 이미 할당됨)
1671
- if palette_value is not None:
1672
- boxplot_kwargs["palette"] = palette_value
1673
-
1674
- boxplot_kwargs.update(params)
1675
-
1676
- sb.boxplot(**boxplot_kwargs)
1677
- annotator = Annotator(ax, data=data, x=hue, y=target, pairs=pairs)
1678
- annotator.configure(test=test, text_format=text_format, loc=loc)
1679
- annotator.apply_and_annotate()
1680
-
1681
- sb.despine()
1682
- finalize_plot(ax, callback, outparams, save_path, True, title)
1666
+ finalize_plot(axes[0] if isinstance(axes, list) and len(axes) > 0 else ax, callback, outparams, save_path, True, title) # type: ignore
1683
1667
 
1684
1668
 
1685
1669
 
@@ -1695,9 +1679,9 @@ def ols_residplot(
1695
1679
  height: int = config.height,
1696
1680
  linewidth: float = config.line_width,
1697
1681
  dpi: int = config.dpi,
1698
- save_path: str = None,
1699
- callback: any = None,
1700
- ax: Axes = None,
1682
+ save_path: str | None = None,
1683
+ callback: Callable | None = None,
1684
+ ax: Axes | None = None,
1701
1685
  **params,
1702
1686
  ) -> None:
1703
1687
  """잔차도를 그린다(선택적으로 MSE 범위와 LOWESS 포함).
@@ -1739,24 +1723,23 @@ def ols_residplot(
1739
1723
  y = y_pred + resid # 실제값 = 적합값 + 잔차
1740
1724
 
1741
1725
  if ax is None:
1742
- fig, ax = get_default_ax(width + 150 if mse else width, height, 1, 1, dpi)
1726
+ fig, ax = get_default_ax(width + 150 if mse else width, height, 1, 1, dpi) # type: ignore
1743
1727
  outparams = True
1744
1728
 
1745
- # 산점도 직접 그리기 (seaborn.residplot보다 훨씬 빠름)
1746
- ax.scatter(y_pred, resid, edgecolor="white", alpha=0.7, **params)
1729
+ # 산점도 seaborn으로 그리기
1730
+ sb.scatterplot(x=y_pred, y=resid, ax=ax, s=20, edgecolor="white", **params)
1747
1731
 
1748
1732
  # 기준선 (잔차 = 0)
1749
- ax.axhline(0, color="gray", linestyle="--", linewidth=linewidth)
1733
+ ax.axhline(0, color="gray", linestyle="--", linewidth=linewidth*0.7) # type: ignore
1750
1734
 
1751
1735
  # LOWESS 스무딩 (선택적)
1752
1736
  if lowess:
1753
- from statsmodels.nonparametric.smoothers_lowess import lowess as sm_lowess
1754
1737
  lowess_result = sm_lowess(resid, y_pred, frac=0.6667)
1755
- ax.plot(lowess_result[:, 0], lowess_result[:, 1],
1756
- color="red", linewidth=linewidth, label="LOWESS")
1738
+ ax.plot(lowess_result[:, 0], lowess_result[:, 1], # type: ignore
1739
+ color="red", linewidth=linewidth, label="LOWESS") # type: ignore
1757
1740
 
1758
- ax.set_xlabel("Fitted values")
1759
- ax.set_ylabel("Residuals")
1741
+ ax.set_xlabel("Fitted values") # type: ignore
1742
+ ax.set_ylabel("Residuals") # type: ignore
1760
1743
 
1761
1744
  if mse:
1762
1745
  mse_val = mean_squared_error(y, y_pred)
@@ -1776,40 +1759,40 @@ def ols_residplot(
1776
1759
 
1777
1760
  mse_r = [r1, r2, r3]
1778
1761
 
1779
- xmin, xmax = ax.get_xlim()
1762
+ xmin, xmax = ax.get_xlim() # type: ignore
1780
1763
 
1781
1764
  # 구간별 반투명 색상 채우기 (안쪽부터 바깥쪽으로, 진한 색에서 연한 색으로)
1782
1765
  colors = ["red", "green", "blue"]
1783
1766
  alphas = [0.15, 0.10, 0.05] # 안쪽이 더 진하게
1784
1767
 
1785
1768
  # 3σ 영역 (가장 바깥쪽, 가장 연함)
1786
- ax.axhspan(-3 * mse_sq, 3 * mse_sq, facecolor=colors[2], alpha=alphas[2], zorder=0)
1769
+ ax.axhspan(-3 * mse_sq, 3 * mse_sq, facecolor=colors[2], alpha=alphas[2], zorder=0) # type: ignore
1787
1770
  # 2σ 영역 (중간)
1788
- ax.axhspan(-2 * mse_sq, 2 * mse_sq, facecolor=colors[1], alpha=alphas[1], zorder=1)
1771
+ ax.axhspan(-2 * mse_sq, 2 * mse_sq, facecolor=colors[1], alpha=alphas[1], zorder=1) # type: ignore
1789
1772
  # 1σ 영역 (가장 안쪽, 가장 진함)
1790
- ax.axhspan(-mse_sq, mse_sq, facecolor=colors[0], alpha=alphas[0], zorder=2)
1773
+ ax.axhspan(-mse_sq, mse_sq, facecolor=colors[0], alpha=alphas[0], zorder=2) # type: ignore
1791
1774
 
1792
1775
  # 경계선 그리기
1793
1776
  for i, c in enumerate(["red", "green", "blue"]):
1794
- ax.axhline(mse_sq * (i + 1), color=c, linestyle="--", linewidth=linewidth/2)
1795
- ax.axhline(mse_sq * (-(i + 1)), color=c, linestyle="--", linewidth=linewidth/2)
1777
+ ax.axhline(mse_sq * (i + 1), color=c, linestyle="--", linewidth=linewidth/2) # type: ignore
1778
+ ax.axhline(mse_sq * (-(i + 1)), color=c, linestyle="--", linewidth=linewidth/2) # type: ignore
1796
1779
 
1797
1780
  target = [68, 95, 99.7]
1798
1781
  for i, c in enumerate(["red", "green", "blue"]):
1799
- ax.text(
1782
+ ax.text( # type: ignore
1800
1783
  s=f"{i+1} sqrt(MSE) = {mse_r[i]:.2f}% ({mse_r[i] - target[i]:.2f}%)",
1801
- x=xmax + 0.2,
1784
+ x=xmax + 0.05,
1802
1785
  y=(i + 1) * mse_sq,
1803
1786
  color=c,
1804
1787
  )
1805
- ax.text(
1788
+ ax.text( # type: ignore
1806
1789
  s=f"-{i+1} sqrt(MSE) = {mse_r[i]:.2f}% ({mse_r[i] - target[i]:.2f}%)",
1807
- x=xmax + 0.2,
1790
+ x=xmax + 0.05,
1808
1791
  y=-(i + 1) * mse_sq,
1809
1792
  color=c,
1810
1793
  )
1811
1794
 
1812
- finalize_plot(ax, callback, outparams, save_path, True, title)
1795
+ finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
1813
1796
 
1814
1797
 
1815
1798
  # ===================================================================
@@ -1823,9 +1806,9 @@ def ols_qqplot(
1823
1806
  height: int = config.height,
1824
1807
  linewidth: float = config.line_width,
1825
1808
  dpi: int = config.dpi,
1826
- save_path: str = None,
1827
- callback: any = None,
1828
- ax: Axes = None,
1809
+ save_path: str | None = None,
1810
+ callback: Callable | None = None,
1811
+ ax: Axes | None = None,
1829
1812
  **params,
1830
1813
  ) -> None:
1831
1814
  """표준화된 잔차의 정규성 확인을 위한 QQ 플롯을 그린다.
@@ -1870,7 +1853,7 @@ def ols_qqplot(
1870
1853
  outparams = False
1871
1854
 
1872
1855
  if ax is None:
1873
- fig, ax = get_default_ax(width, height, 1, 1, dpi)
1856
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
1874
1857
  outparams = True
1875
1858
 
1876
1859
  # fit 객체에서 잔차(residuals) 추출
@@ -1885,7 +1868,7 @@ def ols_qqplot(
1885
1868
  sm_qqplot(residuals, line=line, ax=ax, **params)
1886
1869
 
1887
1870
  # 점의 스타일 개선: 연한 내부, 진한 테두리
1888
- for collection in ax.collections:
1871
+ for collection in ax.collections: # type: ignore
1889
1872
  # PathCollection (scatter plot의 점들)
1890
1873
  collection.set_facecolor('#4A90E2') # 연한 파란색 내부
1891
1874
  collection.set_edgecolor('#1E3A8A') # 진한 파란색 테두리
@@ -1893,11 +1876,11 @@ def ols_qqplot(
1893
1876
  collection.set_alpha(0.7) # 약간의 투명도
1894
1877
 
1895
1878
  # 선 굵기 조정
1896
- for line in ax.get_lines():
1897
- if line.get_linestyle() == '--' or line.get_color() == 'r':
1898
- line.set_linewidth(linewidth)
1879
+ for line in ax.get_lines(): # type: ignore
1880
+ if line.get_linestyle() == '--' or line.get_color() == 'r': # type: ignore
1881
+ line.set_linewidth(linewidth) # type: ignore
1899
1882
 
1900
- finalize_plot(ax, callback, outparams, save_path, True, title)
1883
+ finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
1901
1884
 
1902
1885
 
1903
1886
  # ===================================================================
@@ -1906,18 +1889,18 @@ def ols_qqplot(
1906
1889
  def distribution_by_class(
1907
1890
  data: DataFrame,
1908
1891
  title: str | None = None,
1909
- xnames: list = None,
1910
- hue: str = None,
1892
+ xnames: list | None = None,
1893
+ hue: str | None = None,
1911
1894
  type: str = "kde",
1912
- bins: any = 5,
1913
- palette: str = None,
1895
+ bins: list[int] | int = 5,
1896
+ palette: str | None = None,
1914
1897
  fill: bool = False,
1915
1898
  width: int = config.width,
1916
1899
  height: int = config.height,
1917
1900
  linewidth: float = config.line_width,
1918
1901
  dpi: int = config.dpi,
1919
- save_path: str = None,
1920
- callback: any = None,
1902
+ save_path: str | None = None,
1903
+ callback: Callable | None = None,
1921
1904
  ) -> None:
1922
1905
  """클래스별로 각 숫자형 특징의 분포를 KDE 또는 히스토그램으로 그린다.
1923
1906
 
@@ -1940,9 +1923,9 @@ def distribution_by_class(
1940
1923
  None
1941
1924
  """
1942
1925
  if xnames is None:
1943
- xnames = data.columns
1926
+ xnames = data.columns # type: ignore
1944
1927
 
1945
- for i, v in enumerate(xnames):
1928
+ for i, v in enumerate(xnames): # type: ignore
1946
1929
  # 종속변수이거나 숫자형이 아닌 경우는 제외
1947
1930
  if v == hue or data[v].dtype not in [
1948
1931
  "int",
@@ -1973,7 +1956,7 @@ def distribution_by_class(
1973
1956
  df=data,
1974
1957
  xname=v,
1975
1958
  hue=hue,
1976
- bins=bins,
1959
+ bins=bins, # type: ignore
1977
1960
  kde=False,
1978
1961
  palette=palette,
1979
1962
  width=width,
@@ -1988,7 +1971,7 @@ def distribution_by_class(
1988
1971
  df=data,
1989
1972
  xname=v,
1990
1973
  hue=hue,
1991
- bins=bins,
1974
+ bins=bins, # type: ignore
1992
1975
  kde=True,
1993
1976
  palette=palette,
1994
1977
  width=width,
@@ -2015,8 +1998,8 @@ def scatter_by_class(
2015
1998
  height: int = config.height,
2016
1999
  linewidth: float = config.line_width,
2017
2000
  dpi: int = config.dpi,
2018
- save_path: str = None,
2019
- callback: any = None,
2001
+ save_path: str | None = None,
2002
+ callback: Callable | None = None,
2020
2003
  ) -> None:
2021
2004
  """종속변수(y)와 각 연속형 독립변수(x) 간 산점도/볼록껍질을 그린다.
2022
2005
 
@@ -2071,7 +2054,7 @@ def scatter_by_class(
2071
2054
  for v in group:
2072
2055
  scatterplot(data=data, xname=v[0], yname=v[1], hue=hue, palette=palette,
2073
2056
  width=width, height=height, linewidth=linewidth, dpi=dpi, callback=callback,
2074
- save_path=save_path)
2057
+ save_path=save_path) # type: ignore
2075
2058
 
2076
2059
 
2077
2060
  # ===================================================================
@@ -2090,8 +2073,8 @@ def categorical_target_distribution(
2090
2073
  linewidth: float = config.line_width,
2091
2074
  dpi: int = config.dpi,
2092
2075
  cols: int = 2,
2093
- save_path: str = None,
2094
- callback: any = None,
2076
+ save_path: str | None = None,
2077
+ callback: Callable | None = None,
2095
2078
  ) -> None:
2096
2079
  """명목형 변수별로 종속변수 분포 차이를 시각화한다.
2097
2080
 
@@ -2149,7 +2132,7 @@ def categorical_target_distribution(
2149
2132
  plot_kwargs.update({"x": yname, "hue": col, "palette": palette, "fill": kde_fill, "common_norm": False, "linewidth": linewidth})
2150
2133
  sb.kdeplot(**plot_kwargs)
2151
2134
  else: # box
2152
- plot_kwargs.update({"x": col, "y": yname, "palette": palette})
2135
+ plot_kwargs.update({"x": col, "y": yname, "hue": col, "palette": palette})
2153
2136
  sb.boxplot(**plot_kwargs, linewidth=linewidth)
2154
2137
 
2155
2138
  ax.set_title(f"{col} vs {yname}")
@@ -2166,16 +2149,16 @@ def categorical_target_distribution(
2166
2149
  # ===================================================================
2167
2150
  def roc_curve_plot(
2168
2151
  fit,
2169
- y: np.ndarray | pd.Series = None,
2170
- X: pd.DataFrame | np.ndarray = None,
2152
+ y: np.ndarray | pd.Series | None = None,
2153
+ X: pd.DataFrame | np.ndarray | None = None,
2171
2154
  title: str | None = None,
2172
2155
  width: int = config.height,
2173
2156
  height: int = config.height,
2174
2157
  linewidth: float = config.line_width,
2175
2158
  dpi: int = config.dpi,
2176
- save_path: str = None,
2177
- callback: any = None,
2178
- ax: Axes = None,
2159
+ save_path: str | None = None,
2160
+ callback: Callable | None = None,
2161
+ ax: Axes | None = None,
2179
2162
  ) -> None:
2180
2163
  """로지스틱 회귀 적합 결과의 ROC 곡선을 시각화한다.
2181
2164
 
@@ -2200,7 +2183,7 @@ def roc_curve_plot(
2200
2183
  """
2201
2184
  outparams = False
2202
2185
  if ax is None:
2203
- fig, ax = get_default_ax(width, height, 1, 1, dpi)
2186
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
2204
2187
  outparams = True
2205
2188
 
2206
2189
  # 실제값(y_true) 결정
@@ -2221,16 +2204,16 @@ def roc_curve_plot(
2221
2204
  roc_auc = auc(fpr, tpr)
2222
2205
 
2223
2206
  # ROC 곡선 그리기
2224
- ax.plot(fpr, tpr, color='darkorange', lw=linewidth, label=f'ROC curve (AUC = {roc_auc:.4f})')
2225
- ax.plot([0, 1], [0, 1], color='navy', lw=linewidth, linestyle='--', label='Random Classifier')
2207
+ ax.plot(fpr, tpr, color='darkorange', lw=linewidth, label=f'ROC curve (AUC = {roc_auc:.4f})') # type: ignore
2208
+ ax.plot([0, 1], [0, 1], color='navy', lw=linewidth, linestyle='--', label='Random Classifier') # type: ignore
2226
2209
 
2227
- ax.set_xlim([0.0, 1.0])
2228
- ax.set_ylim([0.0, 1.05])
2229
- ax.set_xlabel('위양성율 (False Positive Rate)', fontsize=8)
2230
- ax.set_ylabel('재현율 (True Positive Rate)', fontsize=8)
2231
- ax.set_title('ROC 곡선', fontsize=10, fontweight='bold')
2232
- ax.legend(loc="lower right", fontsize=7)
2233
- finalize_plot(ax, callback, outparams, save_path, True, title)
2210
+ ax.set_xlim([0.0, 1.0]) # type: ignore
2211
+ ax.set_ylim([0.0, 1.05]) # type: ignore
2212
+ ax.set_xlabel('위양성율 (False Positive Rate)', fontsize=8) # type: ignore
2213
+ ax.set_ylabel('재현율 (True Positive Rate)', fontsize=8) # type: ignore
2214
+ ax.set_title('ROC 곡선', fontsize=10, fontweight='bold') # type: ignore
2215
+ ax.legend(loc="lower right", fontsize=7) # type: ignore
2216
+ finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
2234
2217
 
2235
2218
 
2236
2219
  # ===================================================================
@@ -2243,9 +2226,9 @@ def confusion_matrix_plot(
2243
2226
  width: int = config.width,
2244
2227
  height: int = config.height,
2245
2228
  dpi: int = config.dpi,
2246
- save_path: str = None,
2247
- callback: any = None,
2248
- ax: Axes = None,
2229
+ save_path: str | None = None,
2230
+ callback: Callable | None = None,
2231
+ ax: Axes | None = None,
2249
2232
  ) -> None:
2250
2233
  """로지스틱 회귀 적합 결과의 혼동행렬을 시각화한다.
2251
2234
 
@@ -2264,7 +2247,7 @@ def confusion_matrix_plot(
2264
2247
  """
2265
2248
  outparams = False
2266
2249
  if ax is None:
2267
- fig, ax = get_default_ax(width, height, 1, 1, dpi)
2250
+ fig, ax = get_default_ax(width, height, 1, 1, dpi) # type: ignore
2268
2251
  outparams = True
2269
2252
 
2270
2253
  # 학습 데이터 기반 실제값/예측 확률 결정
@@ -2280,9 +2263,9 @@ def confusion_matrix_plot(
2280
2263
  # 가독성을 위해 텍스트 크기/굵기 조정
2281
2264
  disp.plot(ax=ax, cmap='Blues', values_format='d', text_kw={"fontsize": 16, "weight": "bold"})
2282
2265
 
2283
- ax.set_title(f'혼동행렬 (임계값: {threshold})', fontsize=8, fontweight='bold')
2266
+ ax.set_title(f'혼동행렬 (임계값: {threshold})', fontsize=8, fontweight='bold') # type: ignore
2284
2267
 
2285
- finalize_plot(ax, callback, outparams, save_path, False, title)
2268
+ finalize_plot(ax, callback, outparams, save_path, False, title) # type: ignore
2286
2269
 
2287
2270
 
2288
2271
  # ===================================================================
@@ -2290,20 +2273,20 @@ def confusion_matrix_plot(
2290
2273
  # ===================================================================
2291
2274
  def radarplot(
2292
2275
  df: DataFrame,
2293
- columns: list = None,
2294
- hue: str = None,
2276
+ columns: list | None = None,
2277
+ hue: str | None = None,
2295
2278
  title: str | None = None,
2296
2279
  normalize: bool = True,
2297
2280
  fill: bool = True,
2298
2281
  fill_alpha: float = 0.25,
2299
- palette: str = None,
2282
+ palette: str | None = None,
2300
2283
  width: int = config.width,
2301
2284
  height: int = config.height,
2302
2285
  linewidth: float = config.line_width,
2303
2286
  dpi: int = config.dpi,
2304
- save_path: str = None,
2305
- callback: any = None,
2306
- ax: Axes = None,
2287
+ save_path: str | None = None,
2288
+ callback: Callable | None = None,
2289
+ ax: Axes | None = None,
2307
2290
  **params,
2308
2291
  ) -> None:
2309
2292
  """레이더 차트(방사형 차트)를 그린다.
@@ -2414,7 +2397,7 @@ def radarplot(
2414
2397
  else:
2415
2398
  ax.set_title('Radar Chart', pad=20)
2416
2399
 
2417
- finalize_plot(ax, callback, outparams, save_path, True, title)
2400
+ finalize_plot(ax, callback, outparams, save_path, True, title) # type: ignore
2418
2401
 
2419
2402
 
2420
2403
  # ===================================================================
@@ -2422,8 +2405,7 @@ def radarplot(
2422
2405
  # ===================================================================
2423
2406
  def distribution_plot(
2424
2407
  data: DataFrame,
2425
- column: str,
2426
- title: str | None = None,
2408
+ column: str | list[str],
2427
2409
  clevel: float = 0.95,
2428
2410
  orient: str = "h",
2429
2411
  hue: str | None = None,
@@ -2432,8 +2414,8 @@ def distribution_plot(
2432
2414
  height: int = config.height,
2433
2415
  linewidth: float = config.line_width,
2434
2416
  dpi: int = config.dpi,
2435
- save_path: str = None,
2436
- callback: any = None,
2417
+ save_path: str | None = None,
2418
+ callback: Callable | None = None,
2437
2419
  ) -> None:
2438
2420
  """연속형 데이터의 분포를 KDE와 Boxplot으로 시각화한다.
2439
2421
 
@@ -2444,7 +2426,6 @@ def distribution_plot(
2444
2426
  Args:
2445
2427
  data (DataFrame): 시각화할 데이터.
2446
2428
  column (str): 분석할 컬럼명.
2447
- title (str|None): 그래프 제목.
2448
2429
  clevel (float): KDE 신뢰수준 (0~1). 기본값 0.95.
2449
2430
  orient (str): Boxplot 방향 ('v' 또는 'h'). 기본값 'h'.
2450
2431
  hue (str|None): 명목형 컬럼명. 지정하면 각 범주별로 행을 늘려 KDE와 boxplot을 그림.
@@ -2459,76 +2440,82 @@ def distribution_plot(
2459
2440
  Returns:
2460
2441
  None
2461
2442
  """
2462
- if hue is None:
2463
- # 1행 2열 서브플롯 생성
2464
- fig, axes = get_default_ax(width, height, rows=1, cols=2, dpi=dpi)
2465
-
2466
- kde_confidence_interval(
2467
- data=data,
2468
- xnames=column,
2469
- clevel=clevel,
2470
- linewidth=linewidth,
2471
- ax=axes[0],
2472
- )
2473
-
2474
- if kind == "hist":
2475
- histplot(
2476
- df=data,
2477
- xname=column,
2478
- linewidth=linewidth,
2479
- ax=axes[1]
2480
- )
2481
- else:
2482
- boxplot(
2483
- df=data[column],
2484
- linewidth=linewidth,
2485
- ax=axes[1]
2486
- )
2487
-
2488
- fig.suptitle(f"Distribution of {column}", fontsize=14, y=1.02)
2489
- else:
2490
- if hue not in data.columns:
2491
- raise ValueError(f"hue column '{hue}' not found in DataFrame")
2492
-
2493
- categories = list(pd.Series(data[hue].dropna().unique()).sort_values())
2494
- n_cat = len(categories) if categories else 1
2443
+ if isinstance(column, str):
2444
+ column = [column]
2495
2445
 
2496
- fig, axes = get_default_ax(width, height, rows=n_cat, cols=2, dpi=dpi)
2497
- axes_2d = np.atleast_2d(axes)
2446
+ for c in column:
2447
+ title = f"Distribution Plot of {c}"
2498
2448
 
2499
- for idx, cat in enumerate(categories):
2500
- subset = data[data[hue] == cat]
2501
- left_ax, right_ax = axes_2d[idx, 0], axes_2d[idx, 1]
2449
+ if hue is None:
2450
+ # 1행 2열 서브플롯 생성
2451
+ fig, axes = get_default_ax(width, height, rows=1, cols=2, dpi=dpi, title=title)
2502
2452
 
2503
2453
  kde_confidence_interval(
2504
- data=subset,
2505
- xnames=column,
2454
+ data=data,
2455
+ xnames=c,
2506
2456
  clevel=clevel,
2507
2457
  linewidth=linewidth,
2508
- ax=left_ax,
2458
+ ax=axes[0],
2509
2459
  )
2510
- left_ax.set_title(f"{hue} = {cat}")
2511
2460
 
2512
2461
  if kind == "hist":
2513
2462
  histplot(
2514
- df=subset,
2515
- xname=column,
2463
+ df=data,
2464
+ xname=c,
2516
2465
  linewidth=linewidth,
2517
- ax=right_ax,
2466
+ ax=axes[1]
2518
2467
  )
2519
2468
  else:
2520
2469
  boxplot(
2521
- df=subset[column],
2470
+ df=data[column], # type: ignore
2522
2471
  linewidth=linewidth,
2523
- ax=right_ax
2472
+ ax=axes[1]
2524
2473
  )
2525
2474
 
2526
- fig.suptitle(f"Distribution of {column} by {hue}", fontsize=14, y=1.02)
2475
+ fig.suptitle(title, fontsize=14, y=1.02)
2476
+ else:
2477
+ if hue not in data.columns:
2478
+ raise ValueError(f"hue column '{hue}' not found in DataFrame")
2527
2479
 
2528
- plt.tight_layout()
2480
+ categories = list(pd.Series(data[hue].dropna().unique()).sort_values())
2481
+ n_cat = len(categories) if categories else 1
2529
2482
 
2530
- if save_path:
2531
- plt.savefig(save_path, bbox_inches='tight', dpi=dpi)
2532
- plt.close()
2533
- else:
2534
- plt.show()
2483
+ fig, axes = get_default_ax(width, height, rows=n_cat, cols=2, dpi=dpi, title=title)
2484
+ axes_2d = np.atleast_2d(axes)
2485
+
2486
+ for idx, cat in enumerate(categories):
2487
+ subset = data[data[hue] == cat]
2488
+ left_ax, right_ax = axes_2d[idx, 0], axes_2d[idx, 1]
2489
+
2490
+ kde_confidence_interval(
2491
+ data=subset,
2492
+ xnames=c,
2493
+ clevel=clevel,
2494
+ linewidth=linewidth,
2495
+ ax=left_ax,
2496
+ )
2497
+ left_ax.set_title(f"{hue} = {cat}")
2498
+
2499
+ if kind == "hist":
2500
+ histplot(
2501
+ df=subset,
2502
+ xname=c,
2503
+ linewidth=linewidth,
2504
+ ax=right_ax,
2505
+ )
2506
+ else:
2507
+ boxplot(
2508
+ df=subset[c], # type: ignore
2509
+ linewidth=linewidth,
2510
+ ax=right_ax
2511
+ )
2512
+
2513
+ fig.suptitle(f"{title} by {hue}", fontsize=14, y=1.02)
2514
+
2515
+ plt.tight_layout()
2516
+
2517
+ if save_path:
2518
+ plt.savefig(save_path, bbox_inches='tight', dpi=dpi)
2519
+ plt.close()
2520
+ else:
2521
+ plt.show()