py2ls 0.1.9.0__py3-none-any.whl → 0.1.9.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- py2ls/data/.DS_Store +0 -0
 - py2ls/data/styles/style1.json +145 -0
 - py2ls/data/styles/style2.json +146 -0
 - py2ls/data/styles/style3.json +146 -0
 - py2ls/data/styles/style4.json +142 -0
 - py2ls/data/styles/style5.json +142 -0
 - py2ls/data/styles/style6.json +142 -0
 - py2ls/plot.py +117 -197
 - py2ls/stats.py +783 -390
 - {py2ls-0.1.9.0.dist-info → py2ls-0.1.9.2.dist-info}/METADATA +1 -1
 - {py2ls-0.1.9.0.dist-info → py2ls-0.1.9.2.dist-info}/RECORD +12 -6
 - {py2ls-0.1.9.0.dist-info → py2ls-0.1.9.2.dist-info}/WHEEL +0 -0
 
    
        py2ls/plot.py
    CHANGED
    
    | 
         @@ -3,12 +3,14 @@ import numpy as np 
     | 
|
| 
       3 
3 
     | 
    
         
             
            import pandas as pd
         
     | 
| 
       4 
4 
     | 
    
         
             
            from matplotlib.colors import to_rgba
         
     | 
| 
       5 
5 
     | 
    
         
             
            from scipy.stats import gaussian_kde
         
     | 
| 
       6 
     | 
    
         
            -
             
     | 
| 
      
 6 
     | 
    
         
            +
            import seaborn as sns
         
     | 
| 
       7 
7 
     | 
    
         
             
            import matplotlib
         
     | 
| 
       8 
8 
     | 
    
         
             
            import matplotlib.ticker as tck
         
     | 
| 
       9 
9 
     | 
    
         
             
            from cycler import cycler
         
     | 
| 
       10 
10 
     | 
    
         
             
            import logging
         
     | 
| 
       11 
     | 
    
         
            -
             
     | 
| 
      
 11 
     | 
    
         
            +
            import os
         
     | 
| 
      
 12 
     | 
    
         
            +
            from .ips import fsave, fload, mkdir
         
     | 
| 
      
 13 
     | 
    
         
            +
            from .stats import *
         
     | 
| 
       12 
14 
     | 
    
         | 
| 
       13 
15 
     | 
    
         
             
            # Suppress INFO messages from fontTools
         
     | 
| 
       14 
16 
     | 
    
         
             
            logging.getLogger("fontTools").setLevel(logging.WARNING)
         
     | 
| 
         @@ -219,8 +221,11 @@ def catplot(data, *args, **kwargs): 
     | 
|
| 
       219 
221 
     | 
    
         
             
                    # MeanLine or MedianLine only keep only one
         
     | 
| 
       220 
222 
     | 
    
         
             
                    if bx_opt["MeanLine"]:  # MeanLine has priority
         
     | 
| 
       221 
223 
     | 
    
         
             
                        bx_opt["MedianLine"] = False
         
     | 
| 
      
 224 
     | 
    
         
            +
                    # rm NaNs
         
     | 
| 
      
 225 
     | 
    
         
            +
                    cleaned_data = [data[~np.isnan(data[:, i]), i] for i in range(data.shape[1])]
         
     | 
| 
      
 226 
     | 
    
         
            +
             
     | 
| 
       222 
227 
     | 
    
         
             
                    bxp = ax.boxplot(
         
     | 
| 
       223 
     | 
    
         
            -
                         
     | 
| 
      
 228 
     | 
    
         
            +
                        cleaned_data,
         
     | 
| 
       224 
229 
     | 
    
         
             
                        positions=X_bx,
         
     | 
| 
       225 
230 
     | 
    
         
             
                        notch=bx_opt["Notch"],
         
     | 
| 
       226 
231 
     | 
    
         
             
                        patch_artist=True,
         
     | 
| 
         @@ -475,31 +480,69 @@ def catplot(data, *args, **kwargs): 
     | 
|
| 
       475 
480 
     | 
    
         
             
                        data = df2array(data=data, x=x, y=y, hue=hue)
         
     | 
| 
       476 
481 
     | 
    
         
             
                        xticklabels = []
         
     | 
| 
       477 
482 
     | 
    
         
             
                        if hue is not None:
         
     | 
| 
      
 483 
     | 
    
         
            +
                            # for i in df[x].unique().tolist():
         
     | 
| 
      
 484 
     | 
    
         
            +
                            #     for j in df[hue].unique().tolist():
         
     | 
| 
      
 485 
     | 
    
         
            +
                            #         xticklabels.append(i + "-" + j)
         
     | 
| 
       478 
486 
     | 
    
         
             
                            for i in df[x].unique().tolist():
         
     | 
| 
       479 
     | 
    
         
            -
                                 
     | 
| 
       480 
     | 
    
         
            -
                                    xticklabels.append(i + "-" + j)
         
     | 
| 
      
 487 
     | 
    
         
            +
                                xticklabels.append(i)
         
     | 
| 
       481 
488 
     | 
    
         
             
                            x_len = len(df[x].unique().tolist())
         
     | 
| 
       482 
489 
     | 
    
         
             
                            hue_len = len(df[hue].unique().tolist())
         
     | 
| 
       483 
490 
     | 
    
         
             
                            xticks = generate_xticks_with_gap(x_len, hue_len)
         
     | 
| 
      
 491 
     | 
    
         
            +
                            xticks_x_loc = generate_xticks_x_labels(x_len, hue_len)
         
     | 
| 
       484 
492 
     | 
    
         
             
                            default_x_width = 0.85
         
     | 
| 
       485 
493 
     | 
    
         
             
                            legend_hue = df[hue].unique().tolist()
         
     | 
| 
       486 
494 
     | 
    
         
             
                            default_colors = get_color(hue_len)
         
     | 
| 
      
 495 
     | 
    
         
            +
             
     | 
| 
      
 496 
     | 
    
         
            +
                            # ! stats info
         
     | 
| 
      
 497 
     | 
    
         
            +
                            stats_param = kwargs.get("stats", False)
         
     | 
| 
      
 498 
     | 
    
         
            +
                            res = pd.DataFrame()  # Initialize an empty DataFrame to store results
         
     | 
| 
      
 499 
     | 
    
         
            +
                            for i in df[x].unique().tolist():
         
     | 
| 
      
 500 
     | 
    
         
            +
                                print(i)
         
     | 
| 
      
 501 
     | 
    
         
            +
                                if hue and stats_param:
         
     | 
| 
      
 502 
     | 
    
         
            +
                                    if isinstance(stats_param, dict):
         
     | 
| 
      
 503 
     | 
    
         
            +
                                        if "factor" in stats_param.keys():
         
     | 
| 
      
 504 
     | 
    
         
            +
                                            res_tmp = FuncMultiCmpt(data=df, dv=y, **stats_param)
         
     | 
| 
      
 505 
     | 
    
         
            +
                                        else:
         
     | 
| 
      
 506 
     | 
    
         
            +
                                            res_tmp = FuncMultiCmpt(
         
     | 
| 
      
 507 
     | 
    
         
            +
                                                data=df[df[x] == i], dv=y, factor=hue, **stats_param
         
     | 
| 
      
 508 
     | 
    
         
            +
                                            )
         
     | 
| 
      
 509 
     | 
    
         
            +
                                    elif bool(stats_param):
         
     | 
| 
      
 510 
     | 
    
         
            +
                                        res_tmp = FuncMultiCmpt(data=df, dv=y, factor=hue)
         
     | 
| 
      
 511 
     | 
    
         
            +
                                    else:
         
     | 
| 
      
 512 
     | 
    
         
            +
                                        res_tmp = "did not work properly"
         
     | 
| 
      
 513 
     | 
    
         
            +
                                    display_output(res_tmp)
         
     | 
| 
      
 514 
     | 
    
         
            +
                                    res_tmp = [{"x": i, **res_tmp}]
         
     | 
| 
      
 515 
     | 
    
         
            +
                                    res = pd.concat(
         
     | 
| 
      
 516 
     | 
    
         
            +
                                        [res, pd.DataFrame([res_tmp])], ignore_index=True
         
     | 
| 
      
 517 
     | 
    
         
            +
                                    )
         
     | 
| 
      
 518 
     | 
    
         
            +
                            display_output(res)
         
     | 
| 
       487 
519 
     | 
    
         
             
                        else:
         
     | 
| 
      
 520 
     | 
    
         
            +
                            # ! stats info
         
     | 
| 
      
 521 
     | 
    
         
            +
                            stats_param = kwargs.get("stats", False)
         
     | 
| 
       488 
522 
     | 
    
         
             
                            for i in df[x].unique().tolist():
         
     | 
| 
       489 
523 
     | 
    
         
             
                                xticklabels.append(i)
         
     | 
| 
       490 
524 
     | 
    
         
             
                            xticks = np.arange(1, len(xticklabels) + 1).tolist()
         
     | 
| 
      
 525 
     | 
    
         
            +
                            xticks_x_loc = np.arange(1, len(xticklabels) + 1).tolist()
         
     | 
| 
       491 
526 
     | 
    
         
             
                            legend_hue = xticklabels
         
     | 
| 
       492 
527 
     | 
    
         
             
                            default_colors = get_color(len(xticklabels))
         
     | 
| 
       493 
528 
     | 
    
         
             
                            default_x_width = 0.5
         
     | 
| 
      
 529 
     | 
    
         
            +
                            res = None
         
     | 
| 
      
 530 
     | 
    
         
            +
                            if x and stats_param:
         
     | 
| 
      
 531 
     | 
    
         
            +
                                if isinstance(stats_param, dict):
         
     | 
| 
      
 532 
     | 
    
         
            +
                                    res = FuncMultiCmpt(data=df, dv=y, factor=x, **stats_param)
         
     | 
| 
      
 533 
     | 
    
         
            +
                                elif bool(stats_param):
         
     | 
| 
      
 534 
     | 
    
         
            +
                                    res = FuncMultiCmpt(data=df, dv=y, factor=x)
         
     | 
| 
      
 535 
     | 
    
         
            +
                                else:
         
     | 
| 
      
 536 
     | 
    
         
            +
                                    res = "did not work properly"
         
     | 
| 
      
 537 
     | 
    
         
            +
                            display_output(res)
         
     | 
| 
       494 
538 
     | 
    
         | 
| 
       495 
539 
     | 
    
         
             
                        # when the xticklabels are too long, rotate the labels a bit
         
     | 
| 
       496 
     | 
    
         
            -
             
     | 
| 
       497 
     | 
    
         
            -
                        xangle = 30 if max([len(i) for i in xticklabels]) > 5 else 0
         
     | 
| 
      
 540 
     | 
    
         
            +
                        xangle = 30 if max([len(i) for i in xticklabels]) > 50 else 0
         
     | 
| 
       498 
541 
     | 
    
         
             
                        if kw_figsets is not None:
         
     | 
| 
       499 
542 
     | 
    
         
             
                            kw_figsets = {
         
     | 
| 
       500 
543 
     | 
    
         
             
                                "ylabel": y,
         
     | 
| 
       501 
544 
     | 
    
         
             
                                # "xlabel": x,
         
     | 
| 
       502 
     | 
    
         
            -
                                "xticks": xticks,
         
     | 
| 
      
 545 
     | 
    
         
            +
                                "xticks": xticks_x_loc,  # xticks,
         
     | 
| 
       503 
546 
     | 
    
         
             
                                "xticklabels": xticklabels,
         
     | 
| 
       504 
547 
     | 
    
         
             
                                "xangle": xangle,
         
     | 
| 
       505 
548 
     | 
    
         
             
                                **kw_figsets,
         
     | 
| 
         @@ -508,7 +551,7 @@ def catplot(data, *args, **kwargs): 
     | 
|
| 
       508 
551 
     | 
    
         
             
                            kw_figsets = {
         
     | 
| 
       509 
552 
     | 
    
         
             
                                "ylabel": y,
         
     | 
| 
       510 
553 
     | 
    
         
             
                                # "xlabel": x,
         
     | 
| 
       511 
     | 
    
         
            -
                                "xticks": xticks,
         
     | 
| 
      
 554 
     | 
    
         
            +
                                "xticks": xticks_x_loc,  # xticks,
         
     | 
| 
       512 
555 
     | 
    
         
             
                                "xticklabels": xticklabels,
         
     | 
| 
       513 
556 
     | 
    
         
             
                                "xangle": xangle,
         
     | 
| 
       514 
557 
     | 
    
         
             
                            }
         
     | 
| 
         @@ -521,6 +564,22 @@ def catplot(data, *args, **kwargs): 
     | 
|
| 
       521 
564 
     | 
    
         | 
| 
       522 
565 
     | 
    
         
             
                    # full_order
         
     | 
| 
       523 
566 
     | 
    
         
             
                    opt = kwargs.get("opt", {})
         
     | 
| 
      
 567 
     | 
    
         
            +
             
     | 
| 
      
 568 
     | 
    
         
            +
                    # load style:
         
     | 
| 
      
 569 
     | 
    
         
            +
                    style_use = None
         
     | 
| 
      
 570 
     | 
    
         
            +
                    for k, v in kwargs.items():
         
     | 
| 
      
 571 
     | 
    
         
            +
                        if "style" in k and "exp" not in k:
         
     | 
| 
      
 572 
     | 
    
         
            +
                            style_use = v
         
     | 
| 
      
 573 
     | 
    
         
            +
                            break
         
     | 
| 
      
 574 
     | 
    
         
            +
                    if style_use:
         
     | 
| 
      
 575 
     | 
    
         
            +
                        try:
         
     | 
| 
      
 576 
     | 
    
         
            +
                            dir_curr_script = os.path.dirname(os.path.abspath(__file__))
         
     | 
| 
      
 577 
     | 
    
         
            +
                            dir_style = dir_curr_script + "/data/styles/"
         
     | 
| 
      
 578 
     | 
    
         
            +
                            style_load = fload(dir_style + style_use + ".json")
         
     | 
| 
      
 579 
     | 
    
         
            +
                            style_load = remove_colors_in_dict(style_load)
         
     | 
| 
      
 580 
     | 
    
         
            +
                            opt.update(style_load)
         
     | 
| 
      
 581 
     | 
    
         
            +
                        except:
         
     | 
| 
      
 582 
     | 
    
         
            +
                            print(f"cannot find the style'{style_name}'")
         
     | 
| 
       524 
583 
     | 
    
         
             
                    ax = kwargs.get("ax", None)
         
     | 
| 
       525 
584 
     | 
    
         
             
                    if "ax" not in locals() or ax is None:
         
     | 
| 
       526 
585 
     | 
    
         
             
                        ax = plt.gca()
         
     | 
| 
         @@ -536,9 +595,9 @@ def catplot(data, *args, **kwargs): 
     | 
|
| 
       536 
595 
     | 
    
         
             
                    opt["loc"].setdefault("xloc", xticks)
         
     | 
| 
       537 
596 
     | 
    
         | 
| 
       538 
597 
     | 
    
         
             
                    # export setting
         
     | 
| 
       539 
     | 
    
         
            -
                    opt.setdefault(" 
     | 
| 
       540 
     | 
    
         
            -
                    opt[" 
     | 
| 
       541 
     | 
    
         
            -
                    print(opt[" 
     | 
| 
      
 598 
     | 
    
         
            +
                    opt.setdefault("style", {})
         
     | 
| 
      
 599 
     | 
    
         
            +
                    opt["style"].setdefault("export", None)
         
     | 
| 
      
 600 
     | 
    
         
            +
                    print(opt["style"])
         
     | 
| 
       542 
601 
     | 
    
         | 
| 
       543 
602 
     | 
    
         
             
                    # opt.setdefault('layer', {})
         
     | 
| 
       544 
603 
     | 
    
         
             
                    opt.setdefault("layer", ["b", "bx", "e", "v", "s", "l"])
         
     | 
| 
         @@ -572,7 +631,7 @@ def catplot(data, *args, **kwargs): 
     | 
|
| 
       572 
631 
     | 
    
         
             
                    opt["e"].setdefault("Visible", True)
         
     | 
| 
       573 
632 
     | 
    
         
             
                    opt["e"].setdefault("Orientation", "vertical")
         
     | 
| 
       574 
633 
     | 
    
         
             
                    opt["e"].setdefault("error", "sem")
         
     | 
| 
       575 
     | 
    
         
            -
                    opt["e"].setdefault("x_width",  
     | 
| 
      
 634 
     | 
    
         
            +
                    opt["e"].setdefault("x_width", default_x_width / 5)
         
     | 
| 
       576 
635 
     | 
    
         
             
                    opt["e"].setdefault("cap_dir", "b")
         
     | 
| 
       577 
636 
     | 
    
         | 
| 
       578 
637 
     | 
    
         
             
                    opt.setdefault("s", {})
         
     | 
| 
         @@ -581,7 +640,7 @@ def catplot(data, *args, **kwargs): 
     | 
|
| 
       581 
640 
     | 
    
         
             
                    opt["s"].setdefault("FaceColor", "w")
         
     | 
| 
       582 
641 
     | 
    
         
             
                    opt["s"].setdefault("cmap", None)
         
     | 
| 
       583 
642 
     | 
    
         
             
                    opt["s"].setdefault("FaceAlpha", 1)
         
     | 
| 
       584 
     | 
    
         
            -
                    opt["s"].setdefault("x_width",  
     | 
| 
      
 643 
     | 
    
         
            +
                    opt["s"].setdefault("x_width", default_x_width / 5)
         
     | 
| 
       585 
644 
     | 
    
         
             
                    opt["s"].setdefault("Marker", "o")
         
     | 
| 
       586 
645 
     | 
    
         
             
                    opt["s"].setdefault("MarkerSize", 15)
         
     | 
| 
       587 
646 
     | 
    
         
             
                    opt["s"].setdefault("LineWidth", 0.8)
         
     | 
| 
         @@ -602,7 +661,7 @@ def catplot(data, *args, **kwargs): 
     | 
|
| 
       602 
661 
     | 
    
         
             
                    opt["bx"].setdefault("FaceAlpha", 0.85)
         
     | 
| 
       603 
662 
     | 
    
         
             
                    opt["bx"].setdefault("EdgeAlpha", 1)
         
     | 
| 
       604 
663 
     | 
    
         
             
                    opt["bx"].setdefault("LineStyle", "-")
         
     | 
| 
       605 
     | 
    
         
            -
                    opt["bx"].setdefault("x_width",  
     | 
| 
      
 664 
     | 
    
         
            +
                    opt["bx"].setdefault("x_width", default_x_width / 5)
         
     | 
| 
       606 
665 
     | 
    
         
             
                    opt["bx"].setdefault("ShowBaseLine", "off")
         
     | 
| 
       607 
666 
     | 
    
         
             
                    opt["bx"].setdefault("Notch", False)
         
     | 
| 
       608 
667 
     | 
    
         
             
                    opt["bx"].setdefault("Outliers", "on")
         
     | 
| 
         @@ -611,28 +670,29 @@ def catplot(data, *args, **kwargs): 
     | 
|
| 
       611 
670 
     | 
    
         
             
                    opt["bx"].setdefault("OutlierSize", 6)
         
     | 
| 
       612 
671 
     | 
    
         
             
                    # opt['bx'].setdefault('PlotStyle', 'traditional')
         
     | 
| 
       613 
672 
     | 
    
         
             
                    # opt['bx'].setdefault('FactorDirection', 'auto')
         
     | 
| 
       614 
     | 
    
         
            -
                    opt["bx"].setdefault(" 
     | 
| 
      
 673 
     | 
    
         
            +
                    opt["bx"].setdefault("LineWidth", 0.5)
         
     | 
| 
      
 674 
     | 
    
         
            +
                    opt["bx"].setdefault("Whisker", opt["bx"]["LineWidth"])
         
     | 
| 
       615 
675 
     | 
    
         
             
                    opt["bx"].setdefault("Orientation", "vertical")
         
     | 
| 
       616 
     | 
    
         
            -
                    opt["bx"].setdefault("BoxLineWidth",  
     | 
| 
      
 676 
     | 
    
         
            +
                    opt["bx"].setdefault("BoxLineWidth", opt["bx"]["LineWidth"])
         
     | 
| 
       617 
677 
     | 
    
         
             
                    opt["bx"].setdefault("FaceColor", "k")
         
     | 
| 
       618 
678 
     | 
    
         
             
                    opt["bx"].setdefault("WhiskerLineStyle", "-")
         
     | 
| 
       619 
679 
     | 
    
         
             
                    opt["bx"].setdefault("WhiskerLineColor", "k")
         
     | 
| 
       620 
     | 
    
         
            -
                    opt["bx"].setdefault("WhiskerLineWidth",  
     | 
| 
      
 680 
     | 
    
         
            +
                    opt["bx"].setdefault("WhiskerLineWidth", opt["bx"]["LineWidth"])
         
     | 
| 
       621 
681 
     | 
    
         
             
                    opt["bx"].setdefault("Caps", True)
         
     | 
| 
       622 
682 
     | 
    
         
             
                    opt["bx"].setdefault("CapLineColor", "k")
         
     | 
| 
       623 
     | 
    
         
            -
                    opt["bx"].setdefault("CapLineWidth",  
     | 
| 
      
 683 
     | 
    
         
            +
                    opt["bx"].setdefault("CapLineWidth", opt["bx"]["LineWidth"])
         
     | 
| 
       624 
684 
     | 
    
         
             
                    opt["bx"].setdefault("CapSize", 0.2)
         
     | 
| 
       625 
685 
     | 
    
         
             
                    opt["bx"].setdefault("MedianLine", True)
         
     | 
| 
       626 
686 
     | 
    
         
             
                    opt["bx"].setdefault("MedianLineStyle", "-")
         
     | 
| 
       627 
687 
     | 
    
         
             
                    opt["bx"].setdefault("MedianStyle", "line")
         
     | 
| 
       628 
688 
     | 
    
         
             
                    opt["bx"].setdefault("MedianLineColor", "k")
         
     | 
| 
       629 
     | 
    
         
            -
                    opt["bx"].setdefault("MedianLineWidth",  
     | 
| 
      
 689 
     | 
    
         
            +
                    opt["bx"].setdefault("MedianLineWidth", opt["bx"]["LineWidth"] * 4)
         
     | 
| 
       630 
690 
     | 
    
         
             
                    opt["bx"].setdefault("MedianLineTop", False)
         
     | 
| 
       631 
691 
     | 
    
         
             
                    opt["bx"].setdefault("MeanLine", False)
         
     | 
| 
       632 
692 
     | 
    
         
             
                    opt["bx"].setdefault("showmeans", opt["bx"]["MeanLine"])
         
     | 
| 
       633 
693 
     | 
    
         
             
                    opt["bx"].setdefault("MeanLineStyle", "-")
         
     | 
| 
       634 
694 
     | 
    
         
             
                    opt["bx"].setdefault("MeanLineColor", "w")
         
     | 
| 
       635 
     | 
    
         
            -
                    opt["bx"].setdefault("MeanLineWidth",  
     | 
| 
      
 695 
     | 
    
         
            +
                    opt["bx"].setdefault("MeanLineWidth", opt["bx"]["LineWidth"] * 4)
         
     | 
| 
       636 
696 
     | 
    
         | 
| 
       637 
697 
     | 
    
         
             
                    # Violin plot options
         
     | 
| 
       638 
698 
     | 
    
         
             
                    opt.setdefault("v", {})
         
     | 
| 
         @@ -676,7 +736,6 @@ def catplot(data, *args, **kwargs): 
     | 
|
| 
       676 
736 
     | 
    
         
             
                        legend_which = "v"
         
     | 
| 
       677 
737 
     | 
    
         
             
                    else:
         
     | 
| 
       678 
738 
     | 
    
         
             
                        legend_which = None
         
     | 
| 
       679 
     | 
    
         
            -
             
     | 
| 
       680 
739 
     | 
    
         
             
                    for layer in layers:
         
     | 
| 
       681 
740 
     | 
    
         
             
                        if layer == "b" and opt["b"]["go"]:
         
     | 
| 
       682 
741 
     | 
    
         
             
                            if legend_which == "b":
         
     | 
| 
         @@ -711,9 +770,13 @@ def catplot(data, *args, **kwargs): 
     | 
|
| 
       711 
770 
     | 
    
         
             
                    show_legend = kwargs.get("show_legend", True)
         
     | 
| 
       712 
771 
     | 
    
         
             
                    if show_legend:
         
     | 
| 
       713 
772 
     | 
    
         
             
                        ax.legend()
         
     | 
| 
       714 
     | 
    
         
            -
             
     | 
| 
       715 
     | 
    
         
            -
                     
     | 
| 
       716 
     | 
    
         
            -
             
     | 
| 
      
 773 
     | 
    
         
            +
             
     | 
| 
      
 774 
     | 
    
         
            +
                    style_export = kwargs.get("style_export", None)
         
     | 
| 
      
 775 
     | 
    
         
            +
                    if style_export and (style_export != style_use):
         
     | 
| 
      
 776 
     | 
    
         
            +
                        dir_curr_script = os.path.dirname(os.path.abspath(__file__))
         
     | 
| 
      
 777 
     | 
    
         
            +
                        dir_style = dir_curr_script + "/data/styles/"
         
     | 
| 
      
 778 
     | 
    
         
            +
                        fsave(dir_style + style_export + ".json", opt)
         
     | 
| 
      
 779 
     | 
    
         
            +
             
     | 
| 
       717 
780 
     | 
    
         
             
                    return ax, opt
         
     | 
| 
       718 
781 
     | 
    
         
             
                else:
         
     | 
| 
       719 
782 
     | 
    
         
             
                    col_names = data[col].unique().tolist()
         
     | 
| 
         @@ -729,10 +792,13 @@ def catplot(data, *args, **kwargs): 
     | 
|
| 
       729 
792 
     | 
    
         
             
                        # ax = axs[i][0] if len(col_names) > 1 else axs[0]
         
     | 
| 
       730 
793 
     | 
    
         
             
                        if i < len(col_names):
         
     | 
| 
       731 
794 
     | 
    
         
             
                            df_sub = data.loc[data[col] == col_names[i]]
         
     | 
| 
       732 
     | 
    
         
            -
                            catplot(ax=ax, data=df_sub, **kwargs)
         
     | 
| 
       733 
     | 
    
         
            -
                            ax.set_title(col_names[i])
         
     | 
| 
      
 795 
     | 
    
         
            +
                            _, opt = catplot(ax=ax, data=df_sub, **kwargs)
         
     | 
| 
      
 796 
     | 
    
         
            +
                            ax.set_title(f"{col}={col_names[i]}")
         
     | 
| 
      
 797 
     | 
    
         
            +
                            x_label = kwargs.get("x", None)
         
     | 
| 
      
 798 
     | 
    
         
            +
                            if x_label:
         
     | 
| 
      
 799 
     | 
    
         
            +
                                ax.set_xlabel(x_label)
         
     | 
| 
       734 
800 
     | 
    
         
             
                    print(f"Axis layout shape: {axs.shape}")
         
     | 
| 
       735 
     | 
    
         
            -
                    return axs
         
     | 
| 
      
 801 
     | 
    
         
            +
                    return axs, opt
         
     | 
| 
       736 
802 
     | 
    
         | 
| 
       737 
803 
     | 
    
         | 
| 
       738 
804 
     | 
    
         
             
            def get_cmap():
         
     | 
| 
         @@ -1510,175 +1576,6 @@ def add_colorbar(im, width=None, pad=None, **kwargs): 
     | 
|
| 
       1510 
1576 
     | 
    
         
             
                return fig.colorbar(im, cax=cax, **kwargs)  # draw cbar
         
     | 
| 
       1511 
1577 
     | 
    
         | 
| 
       1512 
1578 
     | 
    
         | 
| 
       1513 
     | 
    
         
            -
            # def padcat(*args, fill_value=np.nan, axis=1):
         
     | 
| 
       1514 
     | 
    
         
            -
            #     """
         
     | 
| 
       1515 
     | 
    
         
            -
            #     Concatenate vectors with padding.
         
     | 
| 
       1516 
     | 
    
         
            -
             
     | 
| 
       1517 
     | 
    
         
            -
            #     Parameters:
         
     | 
| 
       1518 
     | 
    
         
            -
            #     *args : variable number of list or 1D arrays
         
     | 
| 
       1519 
     | 
    
         
            -
            #         Input arrays to concatenate.
         
     | 
| 
       1520 
     | 
    
         
            -
            #     fill_value : scalar, optional
         
     | 
| 
       1521 
     | 
    
         
            -
            #         The value to use for padding the shorter lists (default is np.nan).
         
     | 
| 
       1522 
     | 
    
         
            -
            #     axis : int, optional
         
     | 
| 
       1523 
     | 
    
         
            -
            #         The axis along which to concatenate (0 for rows, 1 for columns, default is 0).
         
     | 
| 
       1524 
     | 
    
         
            -
             
     | 
| 
       1525 
     | 
    
         
            -
            #     Returns:
         
     | 
| 
       1526 
     | 
    
         
            -
            #     np.ndarray
         
     | 
| 
       1527 
     | 
    
         
            -
            #         A 2D array with the input arrays concatenated along the specified axis, padded with fill_value where necessary.
         
     | 
| 
       1528 
     | 
    
         
            -
            #     """
         
     | 
| 
       1529 
     | 
    
         
            -
            #     if axis == 0:
         
     | 
| 
       1530 
     | 
    
         
            -
            #         # Concatenate along rows
         
     | 
| 
       1531 
     | 
    
         
            -
            #         max_len = max(len(lst) for lst in args)
         
     | 
| 
       1532 
     | 
    
         
            -
            #         result = np.full((len(args), max_len), fill_value)
         
     | 
| 
       1533 
     | 
    
         
            -
            #         for i, lst in enumerate(args):
         
     | 
| 
       1534 
     | 
    
         
            -
            #             result[i, : len(lst)] = lst
         
     | 
| 
       1535 
     | 
    
         
            -
            #     elif axis == 1:
         
     | 
| 
       1536 
     | 
    
         
            -
            #         # Concatenate along columns
         
     | 
| 
       1537 
     | 
    
         
            -
            #         max_len = max(len(lst) for lst in args)
         
     | 
| 
       1538 
     | 
    
         
            -
            #         result = np.full((max_len, len(args)), fill_value)
         
     | 
| 
       1539 
     | 
    
         
            -
            #         for i, lst in enumerate(args):
         
     | 
| 
       1540 
     | 
    
         
            -
            #             result[: len(lst), i] = lst
         
     | 
| 
       1541 
     | 
    
         
            -
            #     else:
         
     | 
| 
       1542 
     | 
    
         
            -
            #         raise ValueError("axis must be 0 or 1")
         
     | 
| 
       1543 
     | 
    
         
            -
             
     | 
| 
       1544 
     | 
    
         
            -
            #     return result
         
     | 
| 
       1545 
     | 
    
         
            -
            import numpy as np
         
     | 
| 
       1546 
     | 
    
         
            -
             
     | 
| 
       1547 
     | 
    
         
            -
             
     | 
| 
       1548 
     | 
    
         
            -
            def padcat(*args, fill_value=np.nan, axis=1, order="row"):
         
     | 
| 
       1549 
     | 
    
         
            -
                """
         
     | 
| 
       1550 
     | 
    
         
            -
                Concatenate vectors with padding.
         
     | 
| 
       1551 
     | 
    
         
            -
             
     | 
| 
       1552 
     | 
    
         
            -
                Parameters:
         
     | 
| 
       1553 
     | 
    
         
            -
                *args : variable number of list or 1D arrays
         
     | 
| 
       1554 
     | 
    
         
            -
                    Input arrays to concatenate.
         
     | 
| 
       1555 
     | 
    
         
            -
                fill_value : scalar, optional
         
     | 
| 
       1556 
     | 
    
         
            -
                    The value to use for padding the shorter lists (default is np.nan).
         
     | 
| 
       1557 
     | 
    
         
            -
                axis : int, optional
         
     | 
| 
       1558 
     | 
    
         
            -
                    The axis along which to concatenate (0 for rows, 1 for columns, default is 1).
         
     | 
| 
       1559 
     | 
    
         
            -
                order : str, optional
         
     | 
| 
       1560 
     | 
    
         
            -
                    The order for flattening when required: "row" or "column" (default is "row").
         
     | 
| 
       1561 
     | 
    
         
            -
             
     | 
| 
       1562 
     | 
    
         
            -
                Returns:
         
     | 
| 
       1563 
     | 
    
         
            -
                np.ndarray
         
     | 
| 
       1564 
     | 
    
         
            -
                    A 2D array with the input arrays concatenated along the specified axis,
         
     | 
| 
       1565 
     | 
    
         
            -
                    padded with fill_value where necessary.
         
     | 
| 
       1566 
     | 
    
         
            -
                """
         
     | 
| 
       1567 
     | 
    
         
            -
                # Set the order for processing
         
     | 
| 
       1568 
     | 
    
         
            -
                if "ro" in order.lower():
         
     | 
| 
       1569 
     | 
    
         
            -
                    order = "C"  # row-major order
         
     | 
| 
       1570 
     | 
    
         
            -
                else:
         
     | 
| 
       1571 
     | 
    
         
            -
                    order = "F"  # column-major order
         
     | 
| 
       1572 
     | 
    
         
            -
             
     | 
| 
       1573 
     | 
    
         
            -
                # Process input arrays based on their dimensions
         
     | 
| 
       1574 
     | 
    
         
            -
                processed_arrays = []
         
     | 
| 
       1575 
     | 
    
         
            -
                for arg in args:
         
     | 
| 
       1576 
     | 
    
         
            -
                    arr = np.asarray(arg)
         
     | 
| 
       1577 
     | 
    
         
            -
                    if arr.ndim == 1:
         
     | 
| 
       1578 
     | 
    
         
            -
                        processed_arrays.append(arr)  # Keep 1D arrays as is
         
     | 
| 
       1579 
     | 
    
         
            -
                    elif arr.ndim == 2:
         
     | 
| 
       1580 
     | 
    
         
            -
                        if axis == 0:
         
     | 
| 
       1581 
     | 
    
         
            -
                            # If concatenating along rows, split 2D arrays into 1D arrays row-wise
         
     | 
| 
       1582 
     | 
    
         
            -
                            processed_arrays.extend(arr)
         
     | 
| 
       1583 
     | 
    
         
            -
                        elif axis == 1:
         
     | 
| 
       1584 
     | 
    
         
            -
                            # If concatenating along columns, split 2D arrays into 1D arrays column-wise
         
     | 
| 
       1585 
     | 
    
         
            -
                            processed_arrays.extend(arr.T)
         
     | 
| 
       1586 
     | 
    
         
            -
                        else:
         
     | 
| 
       1587 
     | 
    
         
            -
                            raise ValueError("axis must be 0 or 1")
         
     | 
| 
       1588 
     | 
    
         
            -
                    else:
         
     | 
| 
       1589 
     | 
    
         
            -
                        raise ValueError("Input arrays must be 1D or 2D")
         
     | 
| 
       1590 
     | 
    
         
            -
             
     | 
| 
       1591 
     | 
    
         
            -
                if axis == 0:
         
     | 
| 
       1592 
     | 
    
         
            -
                    # Concatenate along rows
         
     | 
| 
       1593 
     | 
    
         
            -
                    max_len = max(arr.size for arr in processed_arrays)
         
     | 
| 
       1594 
     | 
    
         
            -
                    result = np.full((len(processed_arrays), max_len), fill_value)
         
     | 
| 
       1595 
     | 
    
         
            -
                    for i, arr in enumerate(processed_arrays):
         
     | 
| 
       1596 
     | 
    
         
            -
                        result[i, : arr.size] = arr
         
     | 
| 
       1597 
     | 
    
         
            -
                elif axis == 1:
         
     | 
| 
       1598 
     | 
    
         
            -
                    # Concatenate along columns
         
     | 
| 
       1599 
     | 
    
         
            -
                    max_len = max(arr.size for arr in processed_arrays)
         
     | 
| 
       1600 
     | 
    
         
            -
                    result = np.full((max_len, len(processed_arrays)), fill_value)
         
     | 
| 
       1601 
     | 
    
         
            -
                    for i, arr in enumerate(processed_arrays):
         
     | 
| 
       1602 
     | 
    
         
            -
                        result[: arr.size, i] = arr
         
     | 
| 
       1603 
     | 
    
         
            -
                else:
         
     | 
| 
       1604 
     | 
    
         
            -
                    raise ValueError("axis must be 0 or 1")
         
     | 
| 
       1605 
     | 
    
         
            -
             
     | 
| 
       1606 
     | 
    
         
            -
                return result
         
     | 
| 
       1607 
     | 
    
         
            -
             
     | 
| 
       1608 
     | 
    
         
            -
             
     | 
| 
       1609 
     | 
    
         
            -
            # # Example usage:
         
     | 
| 
       1610 
     | 
    
         
            -
            # a = [1, np.nan]
         
     | 
| 
       1611 
     | 
    
         
            -
            # b = [1, 3, 4, np.nan, 2, np.nan]
         
     | 
| 
       1612 
     | 
    
         
            -
            # c = [1, 2, 3, 4, 5, 6, 7, 8, 10]
         
     | 
| 
       1613 
     | 
    
         
            -
            # d = padcat(a, b)
         
     | 
| 
       1614 
     | 
    
         
            -
            # result1 = padcat(d, c)
         
     | 
| 
       1615 
     | 
    
         
            -
            # result2 = padcat(a, b, c)
         
     | 
| 
       1616 
     | 
    
         
            -
            # print("Result of padcat(d, c):\n", result1)
         
     | 
| 
       1617 
     | 
    
         
            -
            # print("Result of padcat(a, b, c):\n", result2)
         
     | 
| 
       1618 
     | 
    
         
            -
             
     | 
| 
       1619 
     | 
    
         
            -
             
     | 
| 
       1620 
     | 
    
         
            -
            def sort_rows_move_nan(arr, sort=False):
         
     | 
| 
       1621 
     | 
    
         
            -
                # Handle edge cases where all values are NaN
         
     | 
| 
       1622 
     | 
    
         
            -
                if np.all(np.isnan(arr)):
         
     | 
| 
       1623 
     | 
    
         
            -
                    return arr  # Return unchanged if the entire array is NaN
         
     | 
| 
       1624 
     | 
    
         
            -
             
     | 
| 
       1625 
     | 
    
         
            -
                if sort:
         
     | 
| 
       1626 
     | 
    
         
            -
                    # Replace NaNs with a temporary large value for sorting
         
     | 
| 
       1627 
     | 
    
         
            -
                    temp_value = (
         
     | 
| 
       1628 
     | 
    
         
            -
                        np.nanmax(arr[np.isfinite(arr)]) + 1 if np.any(np.isfinite(arr)) else np.inf
         
     | 
| 
       1629 
     | 
    
         
            -
                    )
         
     | 
| 
       1630 
     | 
    
         
            -
                    arr_no_nan = np.where(np.isnan(arr), temp_value, arr)
         
     | 
| 
       1631 
     | 
    
         
            -
             
     | 
| 
       1632 
     | 
    
         
            -
                    # Sort each row
         
     | 
| 
       1633 
     | 
    
         
            -
                    sorted_arr = np.sort(arr_no_nan, axis=1)
         
     | 
| 
       1634 
     | 
    
         
            -
             
     | 
| 
       1635 
     | 
    
         
            -
                    # Move NaNs to the end
         
     | 
| 
       1636 
     | 
    
         
            -
                    result_arr = np.where(sorted_arr == temp_value, np.nan, sorted_arr)
         
     | 
| 
       1637 
     | 
    
         
            -
                else:
         
     | 
| 
       1638 
     | 
    
         
            -
                    result_rows = []
         
     | 
| 
       1639 
     | 
    
         
            -
                    for row in arr:
         
     | 
| 
       1640 
     | 
    
         
            -
                        # Separate non-NaN and NaN values
         
     | 
| 
       1641 
     | 
    
         
            -
                        non_nan_values = row[~np.isnan(row)]
         
     | 
| 
       1642 
     | 
    
         
            -
                        nan_count = np.isnan(row).sum()
         
     | 
| 
       1643 
     | 
    
         
            -
                        # Create a new row with non-NaN values followed by NaNs
         
     | 
| 
       1644 
     | 
    
         
            -
                        new_row = np.concatenate([non_nan_values, [np.nan] * nan_count])
         
     | 
| 
       1645 
     | 
    
         
            -
                        result_rows.append(new_row)
         
     | 
| 
       1646 
     | 
    
         
            -
                    # Convert the list of rows back into a 2D NumPy array
         
     | 
| 
       1647 
     | 
    
         
            -
                    result_arr = np.array(result_rows)
         
     | 
| 
       1648 
     | 
    
         
            -
             
     | 
| 
       1649 
     | 
    
         
            -
                # Remove rows/columns that contain only NaNs
         
     | 
| 
       1650 
     | 
    
         
            -
                clean_arr = result_arr[~np.isnan(result_arr).all(axis=1)]
         
     | 
| 
       1651 
     | 
    
         
            -
                clean_arr_ = clean_arr[:, ~np.isnan(clean_arr).all(axis=0)]
         
     | 
| 
       1652 
     | 
    
         
            -
             
     | 
| 
       1653 
     | 
    
         
            -
                return clean_arr_
         
     | 
| 
       1654 
     | 
    
         
            -
             
     | 
| 
       1655 
     | 
    
         
            -
             
     | 
| 
       1656 
     | 
    
         
            -
            def df2array(data: pd.DataFrame, x, y, hue=None, sort=False):
         
     | 
| 
       1657 
     | 
    
         
            -
                if hue is None:
         
     | 
| 
       1658 
     | 
    
         
            -
                    a = []
         
     | 
| 
       1659 
     | 
    
         
            -
                    if sort:
         
     | 
| 
       1660 
     | 
    
         
            -
                        np.sort(data[x].unique().tolist()).tolist()
         
     | 
| 
       1661 
     | 
    
         
            -
                    else:
         
     | 
| 
       1662 
     | 
    
         
            -
                        cat_x = data[x].unique().tolist()
         
     | 
| 
       1663 
     | 
    
         
            -
                    for i, x_ in enumerate(cat_x):
         
     | 
| 
       1664 
     | 
    
         
            -
                        new_ = data.loc[data[x] == x_, y].to_list()
         
     | 
| 
       1665 
     | 
    
         
            -
                        a = padcat(a, new_, axis=0)
         
     | 
| 
       1666 
     | 
    
         
            -
                    return sort_rows_move_nan(a).T
         
     | 
| 
       1667 
     | 
    
         
            -
                else:
         
     | 
| 
       1668 
     | 
    
         
            -
                    a = []
         
     | 
| 
       1669 
     | 
    
         
            -
                    if sort:
         
     | 
| 
       1670 
     | 
    
         
            -
                        cat_x = np.sort(data[x].unique().tolist()).tolist()
         
     | 
| 
       1671 
     | 
    
         
            -
                        cat_hue = np.sort(data[hue].unique().tolist()).tolist()
         
     | 
| 
       1672 
     | 
    
         
            -
                    else:
         
     | 
| 
       1673 
     | 
    
         
            -
                        cat_x = data[x].unique().tolist()
         
     | 
| 
       1674 
     | 
    
         
            -
                        cat_hue = data[hue].unique().tolist()
         
     | 
| 
       1675 
     | 
    
         
            -
                    for i, x_ in enumerate(cat_x):
         
     | 
| 
       1676 
     | 
    
         
            -
                        for j, hue_ in enumerate(cat_hue):
         
     | 
| 
       1677 
     | 
    
         
            -
                            new_ = data.loc[(data[x] == x_) & (data[hue] == hue_), y].to_list()
         
     | 
| 
       1678 
     | 
    
         
            -
                            a = padcat(a, new_, axis=0)
         
     | 
| 
       1679 
     | 
    
         
            -
                    return sort_rows_move_nan(a).T
         
     | 
| 
       1680 
     | 
    
         
            -
             
     | 
| 
       1681 
     | 
    
         
            -
             
     | 
| 
       1682 
1579 
     | 
    
         
             
            def generate_xticks_with_gap(x_len, hue_len):
         
     | 
| 
       1683 
1580 
     | 
    
         
             
                """
         
     | 
| 
       1684 
1581 
     | 
    
         
             
                Generate a concatenated array based on x_len and hue_len,
         
     | 
| 
         @@ -1700,3 +1597,26 @@ def generate_xticks_with_gap(x_len, hue_len): 
     | 
|
| 
       1700 
1597 
     | 
    
         
             
                positive_array = concatenated_array[concatenated_array > 0].tolist()
         
     | 
| 
       1701 
1598 
     | 
    
         | 
| 
       1702 
1599 
     | 
    
         
             
                return positive_array
         
     | 
| 
      
 1600 
     | 
    
         
            +
             
     | 
| 
      
 1601 
     | 
    
         
            +
             
     | 
| 
      
 1602 
     | 
    
         
            +
            def generate_xticks_x_labels(x_len, hue_len):
         
     | 
| 
      
 1603 
     | 
    
         
            +
                arrays = [
         
     | 
| 
      
 1604 
     | 
    
         
            +
                    np.arange(1, hue_len + 1) + hue_len * (x_len - i) + (x_len - i)
         
     | 
| 
      
 1605 
     | 
    
         
            +
                    for i in range(max(x_len, hue_len), 0, -1)  # i iterates from 3 to 1
         
     | 
| 
      
 1606 
     | 
    
         
            +
                ]
         
     | 
| 
      
 1607 
     | 
    
         
            +
                return [np.mean(i) for i in arrays if np.mean(i) > 0]
         
     | 
| 
      
 1608 
     | 
    
         
            +
             
     | 
| 
      
 1609 
     | 
    
         
            +
             
     | 
| 
      
 1610 
     | 
    
         
            +
            def remove_colors_in_dict(
         
     | 
| 
      
 1611 
     | 
    
         
            +
                data: dict, sections_to_remove_facecolor=["b", "e", "s", "bx", "v"]
         
     | 
| 
      
 1612 
     | 
    
         
            +
            ):
         
     | 
| 
      
 1613 
     | 
    
         
            +
                # Remove "FaceColor" from specified sections
         
     | 
| 
      
 1614 
     | 
    
         
            +
                for section in sections_to_remove_facecolor:
         
     | 
| 
      
 1615 
     | 
    
         
            +
                    if section in data and ("FaceColor" in data[section]):
         
     | 
| 
      
 1616 
     | 
    
         
            +
                        del data[section]["FaceColor"]
         
     | 
| 
      
 1617 
     | 
    
         
            +
             
     | 
| 
      
 1618 
     | 
    
         
            +
                if "c" in data:
         
     | 
| 
      
 1619 
     | 
    
         
            +
                    del data["c"]
         
     | 
| 
      
 1620 
     | 
    
         
            +
                if "loc" in data:
         
     | 
| 
      
 1621 
     | 
    
         
            +
                    del data["loc"]
         
     | 
| 
      
 1622 
     | 
    
         
            +
                return data
         
     |