oafuncs 0.0.97.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- oafuncs/__init__.py +54 -0
 - oafuncs/_script/__init__.py +27 -0
 - oafuncs/_script/plot_dataset.py +299 -0
 - oafuncs/data_store/OAFuncs.png +0 -0
 - oafuncs/data_store/hycom_3hourly.png +0 -0
 - oafuncs/oa_cmap.py +215 -0
 - oafuncs/oa_data.py +293 -0
 - oafuncs/oa_down/User_Agent-list.txt +6697 -0
 - oafuncs/oa_down/__init__.py +22 -0
 - oafuncs/oa_down/hycom_3hourly.py +1309 -0
 - oafuncs/oa_down/hycom_3hourly_20250129.py +1307 -0
 - oafuncs/oa_down/idm.py +50 -0
 - oafuncs/oa_down/literature.py +288 -0
 - oafuncs/oa_down/test_ua.py +151 -0
 - oafuncs/oa_down/user_agent.py +31 -0
 - oafuncs/oa_draw.py +326 -0
 - oafuncs/oa_file.py +413 -0
 - oafuncs/oa_help.py +144 -0
 - oafuncs/oa_model/__init__.py +19 -0
 - oafuncs/oa_model/roms/__init__.py +20 -0
 - oafuncs/oa_model/roms/test.py +19 -0
 - oafuncs/oa_model/wrf/__init__.py +18 -0
 - oafuncs/oa_model/wrf/little_r.py +186 -0
 - oafuncs/oa_nc.py +523 -0
 - oafuncs/oa_python.py +108 -0
 - oafuncs/oa_sign/__init__.py +21 -0
 - oafuncs/oa_sign/meteorological.py +168 -0
 - oafuncs/oa_sign/ocean.py +158 -0
 - oafuncs/oa_sign/scientific.py +139 -0
 - oafuncs/oa_tool/__init__.py +19 -0
 - oafuncs/oa_tool/email.py +114 -0
 - oafuncs/oa_tool/parallel.py +90 -0
 - oafuncs/oa_tool/time.py +22 -0
 - oafuncs-0.0.97.1.dist-info/LICENSE.txt +19 -0
 - oafuncs-0.0.97.1.dist-info/METADATA +106 -0
 - oafuncs-0.0.97.1.dist-info/RECORD +38 -0
 - oafuncs-0.0.97.1.dist-info/WHEEL +5 -0
 - oafuncs-0.0.97.1.dist-info/top_level.txt +1 -0
 
    
        oafuncs/__init__.py
    ADDED
    
    | 
         @@ -0,0 +1,54 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            #!/usr/bin/env python
         
     | 
| 
      
 2 
     | 
    
         
            +
            # coding=utf-8
         
     | 
| 
      
 3 
     | 
    
         
            +
            """
         
     | 
| 
      
 4 
     | 
    
         
            +
            Author: Liu Kun && 16031215@qq.com
         
     | 
| 
      
 5 
     | 
    
         
            +
            Date: 2024-09-17 16:09:20
         
     | 
| 
      
 6 
     | 
    
         
            +
            LastEditors: Liu Kun && 16031215@qq.com
         
     | 
| 
      
 7 
     | 
    
         
            +
            LastEditTime: 2025-03-09 16:28:01
         
     | 
| 
      
 8 
     | 
    
         
            +
            FilePath: \\Python\\My_Funcs\\OAFuncs\\oafuncs\\__init__.py
         
     | 
| 
      
 9 
     | 
    
         
            +
            Description:
         
     | 
| 
      
 10 
     | 
    
         
            +
            EditPlatform: vscode
         
     | 
| 
      
 11 
     | 
    
         
            +
            ComputerInfo: XPS 15 9510
         
     | 
| 
      
 12 
     | 
    
         
            +
            SystemInfo: Windows 11
         
     | 
| 
      
 13 
     | 
    
         
            +
            Python Version: 3.12
         
     | 
| 
      
 14 
     | 
    
         
            +
            """
         
     | 
| 
      
 15 
     | 
    
         
            +
             
     | 
| 
      
 16 
     | 
    
         
            +
             
     | 
| 
      
 17 
     | 
    
         
            +
            # 会导致OAFuncs直接导入所有函数,不符合模块化设计
         
     | 
| 
      
 18 
     | 
    
         
            +
            # from oafuncs.oa_s.oa_cmap import *
         
     | 
| 
      
 19 
     | 
    
         
            +
            # from oafuncs.oa_s.oa_data import *
         
     | 
| 
      
 20 
     | 
    
         
            +
            # from oafuncs.oa_s.oa_draw import *
         
     | 
| 
      
 21 
     | 
    
         
            +
            # from oafuncs.oa_s.oa_file import *
         
     | 
| 
      
 22 
     | 
    
         
            +
            # from oafuncs.oa_s.oa_help import *
         
     | 
| 
      
 23 
     | 
    
         
            +
            # from oafuncs.oa_s.oa_nc import *
         
     | 
| 
      
 24 
     | 
    
         
            +
            # from oafuncs.oa_s.oa_python import *
         
     | 
| 
      
 25 
     | 
    
         
            +
             
     | 
| 
      
 26 
     | 
    
         
            +
            # ------------------- 2024-12-13 12:31:06 -------------------
         
     | 
| 
      
 27 
     | 
    
         
            +
            # path: My_Funcs/OAFuncs/oafuncs/
         
     | 
| 
      
 28 
     | 
    
         
            +
            from .oa_cmap import *
         
     | 
| 
      
 29 
     | 
    
         
            +
            from .oa_data import *
         
     | 
| 
      
 30 
     | 
    
         
            +
             
     | 
| 
      
 31 
     | 
    
         
            +
            # ------------------- 2024-12-13 12:31:06 -------------------
         
     | 
| 
      
 32 
     | 
    
         
            +
            # path: My_Funcs/OAFuncs/oafuncs/oa_down/
         
     | 
| 
      
 33 
     | 
    
         
            +
            from .oa_down import *
         
     | 
| 
      
 34 
     | 
    
         
            +
            from .oa_draw import *
         
     | 
| 
      
 35 
     | 
    
         
            +
            from .oa_file import *
         
     | 
| 
      
 36 
     | 
    
         
            +
            from .oa_help import *
         
     | 
| 
      
 37 
     | 
    
         
            +
             
     | 
| 
      
 38 
     | 
    
         
            +
            # ------------------- 2024-12-13 12:31:06 -------------------
         
     | 
| 
      
 39 
     | 
    
         
            +
            # path: My_Funcs/OAFuncs/oafuncs/oa_model/
         
     | 
| 
      
 40 
     | 
    
         
            +
            from .oa_model import *
         
     | 
| 
      
 41 
     | 
    
         
            +
            from .oa_nc import *
         
     | 
| 
      
 42 
     | 
    
         
            +
            from .oa_python import *
         
     | 
| 
      
 43 
     | 
    
         
            +
             
     | 
| 
      
 44 
     | 
    
         
            +
            # ------------------- 2024-12-13 12:31:06 -------------------
         
     | 
| 
      
 45 
     | 
    
         
            +
            # path: My_Funcs/OAFuncs/oafuncs/oa_sign/
         
     | 
| 
      
 46 
     | 
    
         
            +
            from .oa_sign import *
         
     | 
| 
      
 47 
     | 
    
         
            +
             
     | 
| 
      
 48 
     | 
    
         
            +
            # ------------------- 2024-12-13 12:31:06 -------------------
         
     | 
| 
      
 49 
     | 
    
         
            +
            # path: My_Funcs/OAFuncs/oafuncs/oa_tool/
         
     | 
| 
      
 50 
     | 
    
         
            +
            from .oa_tool import *
         
     | 
| 
      
 51 
     | 
    
         
            +
            # ------------------- 2025-03-09 16:28:01 -------------------
         
     | 
| 
      
 52 
     | 
    
         
            +
            # path: My_Funcs/OAFuncs/oafuncs/_script/
         
     | 
| 
      
 53 
     | 
    
         
            +
            from ._script import *
         
     | 
| 
      
 54 
     | 
    
         
            +
            # ------------------- 2025-03-16 15:56:01 -------------------
         
     | 
| 
         @@ -0,0 +1,27 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            #!/usr/bin/env python
         
     | 
| 
      
 2 
     | 
    
         
            +
            # coding=utf-8
         
     | 
| 
      
 3 
     | 
    
         
            +
            """
         
     | 
| 
      
 4 
     | 
    
         
            +
            Author: Liu Kun && 16031215@qq.com
         
     | 
| 
      
 5 
     | 
    
         
            +
            Date: 2025-03-13 15:26:15
         
     | 
| 
      
 6 
     | 
    
         
            +
            LastEditors: Liu Kun && 16031215@qq.com
         
     | 
| 
      
 7 
     | 
    
         
            +
            LastEditTime: 2025-03-13 15:26:18
         
     | 
| 
      
 8 
     | 
    
         
            +
            FilePath: \\Python\\My_Funcs\\OAFuncs\\oafuncs\\oa_script\\__init__.py
         
     | 
| 
      
 9 
     | 
    
         
            +
            Description:
         
     | 
| 
      
 10 
     | 
    
         
            +
            EditPlatform: vscode
         
     | 
| 
      
 11 
     | 
    
         
            +
            ComputerInfo: XPS 15 9510
         
     | 
| 
      
 12 
     | 
    
         
            +
            SystemInfo: Windows 11
         
     | 
| 
      
 13 
     | 
    
         
            +
            Python Version: 3.12
         
     | 
| 
      
 14 
     | 
    
         
            +
            """
         
     | 
| 
      
 15 
     | 
    
         
            +
             
     | 
| 
      
 16 
     | 
    
         
            +
             
     | 
| 
      
 17 
     | 
    
         
            +
             
     | 
| 
      
 18 
     | 
    
         
            +
            # 会导致OAFuncs直接导入所有函数,不符合模块化设计
         
     | 
| 
      
 19 
     | 
    
         
            +
            # from oafuncs.oa_s.oa_cmap import *
         
     | 
| 
      
 20 
     | 
    
         
            +
            # from oafuncs.oa_s.oa_data import *
         
     | 
| 
      
 21 
     | 
    
         
            +
            # from oafuncs.oa_s.oa_draw import *
         
     | 
| 
      
 22 
     | 
    
         
            +
            # from oafuncs.oa_s.oa_file import *
         
     | 
| 
      
 23 
     | 
    
         
            +
            # from oafuncs.oa_s.oa_help import *
         
     | 
| 
      
 24 
     | 
    
         
            +
            # from oafuncs.oa_s.oa_nc import *
         
     | 
| 
      
 25 
     | 
    
         
            +
            # from oafuncs.oa_s.oa_python import *
         
     | 
| 
      
 26 
     | 
    
         
            +
             
     | 
| 
      
 27 
     | 
    
         
            +
            from .plot_dataset import func_plot_dataset
         
     | 
| 
         @@ -0,0 +1,299 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            import os
         
     | 
| 
      
 2 
     | 
    
         
            +
            from typing import Optional, Tuple
         
     | 
| 
      
 3 
     | 
    
         
            +
             
     | 
| 
      
 4 
     | 
    
         
            +
            import matplotlib as mpl
         
     | 
| 
      
 5 
     | 
    
         
            +
             
     | 
| 
      
 6 
     | 
    
         
            +
            mpl.use("Agg")  # Use non-interactive backend
         
     | 
| 
      
 7 
     | 
    
         
            +
             
     | 
| 
      
 8 
     | 
    
         
            +
            import cftime
         
     | 
| 
      
 9 
     | 
    
         
            +
            import matplotlib.pyplot as plt
         
     | 
| 
      
 10 
     | 
    
         
            +
            import numpy as np
         
     | 
| 
      
 11 
     | 
    
         
            +
            from rich import print
         
     | 
| 
      
 12 
     | 
    
         
            +
            import cartopy.crs as ccrs
         
     | 
| 
      
 13 
     | 
    
         
            +
            import xarray as xr
         
     | 
| 
      
 14 
     | 
    
         
            +
             
     | 
| 
      
 15 
     | 
    
         
            +
            import oafuncs
         
     | 
| 
      
 16 
     | 
    
         
            +
             
     | 
| 
      
 17 
     | 
    
         
            +
             
     | 
| 
      
 18 
     | 
    
         
            +
            def plot_1d(data: xr.DataArray, output_path: str, x_dim: str, y_dim: str, z_dim: str, t_dim: str) -> None:
         
     | 
| 
      
 19 
     | 
    
         
            +
                """Plot 1D data."""
         
     | 
| 
      
 20 
     | 
    
         
            +
                plt.figure(figsize=(10, 6))
         
     | 
| 
      
 21 
     | 
    
         
            +
             
     | 
| 
      
 22 
     | 
    
         
            +
                # Handle time dimension
         
     | 
| 
      
 23 
     | 
    
         
            +
                if t_dim in data.dims and isinstance(data[t_dim].values[0], cftime.datetime):
         
     | 
| 
      
 24 
     | 
    
         
            +
                    try:
         
     | 
| 
      
 25 
     | 
    
         
            +
                        data[t_dim] = data.indexes[t_dim].to_datetimeindex()
         
     | 
| 
      
 26 
     | 
    
         
            +
                    except (AttributeError, ValueError, TypeError) as e:
         
     | 
| 
      
 27 
     | 
    
         
            +
                        print(f"Warning: Could not convert {t_dim} to datetime index: {e}")
         
     | 
| 
      
 28 
     | 
    
         
            +
             
     | 
| 
      
 29 
     | 
    
         
            +
                # Determine X axis data
         
     | 
| 
      
 30 
     | 
    
         
            +
                x, x_label = determine_x_axis(data, x_dim, y_dim, z_dim, t_dim)
         
     | 
| 
      
 31 
     | 
    
         
            +
             
     | 
| 
      
 32 
     | 
    
         
            +
                y = data.values
         
     | 
| 
      
 33 
     | 
    
         
            +
                plt.plot(x, y, linewidth=2)
         
     | 
| 
      
 34 
     | 
    
         
            +
             
     | 
| 
      
 35 
     | 
    
         
            +
                # Add chart info
         
     | 
| 
      
 36 
     | 
    
         
            +
                long_name = getattr(data, "long_name", "No long_name")
         
     | 
| 
      
 37 
     | 
    
         
            +
                units = getattr(data, "units", "")
         
     | 
| 
      
 38 
     | 
    
         
            +
                plt.title(f"{data.name} | {long_name}", fontsize=12)
         
     | 
| 
      
 39 
     | 
    
         
            +
                plt.xlabel(x_label)
         
     | 
| 
      
 40 
     | 
    
         
            +
                plt.ylabel(f"{data.name} ({units})" if units else data.name)
         
     | 
| 
      
 41 
     | 
    
         
            +
             
     | 
| 
      
 42 
     | 
    
         
            +
                plt.grid(True, linestyle="--", alpha=0.7)
         
     | 
| 
      
 43 
     | 
    
         
            +
                plt.tight_layout()
         
     | 
| 
      
 44 
     | 
    
         
            +
             
     | 
| 
      
 45 
     | 
    
         
            +
                # Save image
         
     | 
| 
      
 46 
     | 
    
         
            +
                os.makedirs(os.path.dirname(output_path), exist_ok=True)
         
     | 
| 
      
 47 
     | 
    
         
            +
                plt.savefig(output_path, bbox_inches="tight", dpi=600)
         
     | 
| 
      
 48 
     | 
    
         
            +
                plt.clf()
         
     | 
| 
      
 49 
     | 
    
         
            +
                plt.close()
         
     | 
| 
      
 50 
     | 
    
         
            +
             
     | 
| 
      
 51 
     | 
    
         
            +
             
     | 
| 
      
 52 
     | 
    
         
            +
            def determine_x_axis(data: xr.DataArray, x_dim: str, y_dim: str, z_dim: str, t_dim: str) -> Tuple[np.ndarray, str]:
         
     | 
| 
      
 53 
     | 
    
         
            +
                """Determine the X axis data and label."""
         
     | 
| 
      
 54 
     | 
    
         
            +
                if x_dim in data.dims:
         
     | 
| 
      
 55 
     | 
    
         
            +
                    return data[x_dim].values, x_dim
         
     | 
| 
      
 56 
     | 
    
         
            +
                elif y_dim in data.dims:
         
     | 
| 
      
 57 
     | 
    
         
            +
                    return data[y_dim].values, y_dim
         
     | 
| 
      
 58 
     | 
    
         
            +
                elif z_dim in data.dims:
         
     | 
| 
      
 59 
     | 
    
         
            +
                    return data[z_dim].values, z_dim
         
     | 
| 
      
 60 
     | 
    
         
            +
                elif t_dim in data.dims:
         
     | 
| 
      
 61 
     | 
    
         
            +
                    return data[t_dim].values, t_dim
         
     | 
| 
      
 62 
     | 
    
         
            +
                else:
         
     | 
| 
      
 63 
     | 
    
         
            +
                    return np.arange(len(data)), "Index"
         
     | 
| 
      
 64 
     | 
    
         
            +
             
     | 
| 
      
 65 
     | 
    
         
            +
             
     | 
| 
      
 66 
     | 
    
         
            +
            def plot_2d(data: xr.DataArray, output_path: str, data_range: Optional[Tuple[float, float]], x_dim: str, y_dim: str, t_dim: str, plot_type: str) -> bool:
         
     | 
| 
      
 67 
     | 
    
         
            +
                """Plot 2D data."""
         
     | 
| 
      
 68 
     | 
    
         
            +
                if x_dim in data.dims and y_dim in data.dims and x_dim.lower() in ["lon", "longitude"] and y_dim.lower() in ["lat", "latitude"]:
         
     | 
| 
      
 69 
     | 
    
         
            +
                    lon_range = data[x_dim].values
         
     | 
| 
      
 70 
     | 
    
         
            +
                    lat_range = data[y_dim].values
         
     | 
| 
      
 71 
     | 
    
         
            +
                    lon_lat_ratio = np.abs(np.max(lon_range) - np.min(lon_range)) / (np.max(lat_range) - np.min(lat_range))
         
     | 
| 
      
 72 
     | 
    
         
            +
                    figsize = (10, 10 / lon_lat_ratio)
         
     | 
| 
      
 73 
     | 
    
         
            +
                    fig, ax = plt.subplots(figsize=figsize, subplot_kw={"projection": ccrs.PlateCarree()})
         
     | 
| 
      
 74 
     | 
    
         
            +
                    oafuncs.oa_draw.add_cartopy(ax, lon_range, lat_range)
         
     | 
| 
      
 75 
     | 
    
         
            +
                else:
         
     | 
| 
      
 76 
     | 
    
         
            +
                    fig, ax = plt.subplots(figsize=(10, 8))
         
     | 
| 
      
 77 
     | 
    
         
            +
             
     | 
| 
      
 78 
     | 
    
         
            +
                # Handle time dimension
         
     | 
| 
      
 79 
     | 
    
         
            +
                if t_dim in data.dims and isinstance(data[t_dim].values[0], cftime.datetime):
         
     | 
| 
      
 80 
     | 
    
         
            +
                    try:
         
     | 
| 
      
 81 
     | 
    
         
            +
                        data[t_dim] = data.indexes[t_dim].to_datetimeindex()
         
     | 
| 
      
 82 
     | 
    
         
            +
                    except (AttributeError, ValueError, TypeError) as e:
         
     | 
| 
      
 83 
     | 
    
         
            +
                        print(f"Warning: Could not convert {t_dim} to datetime index: {e}")
         
     | 
| 
      
 84 
     | 
    
         
            +
             
     | 
| 
      
 85 
     | 
    
         
            +
                # Check for valid data
         
     | 
| 
      
 86 
     | 
    
         
            +
                if np.all(np.isnan(data.values)) or data.size == 0:
         
     | 
| 
      
 87 
     | 
    
         
            +
                    print(f"Skipping {data.name}: All values are NaN or empty")
         
     | 
| 
      
 88 
     | 
    
         
            +
                    plt.close()
         
     | 
| 
      
 89 
     | 
    
         
            +
                    return False
         
     | 
| 
      
 90 
     | 
    
         
            +
             
     | 
| 
      
 91 
     | 
    
         
            +
                data_range = calculate_data_range(data, data_range)
         
     | 
| 
      
 92 
     | 
    
         
            +
             
     | 
| 
      
 93 
     | 
    
         
            +
                if data_range is None:
         
     | 
| 
      
 94 
     | 
    
         
            +
                    print(f"Skipping {data.name} due to all NaN values")
         
     | 
| 
      
 95 
     | 
    
         
            +
                    plt.close()
         
     | 
| 
      
 96 
     | 
    
         
            +
                    return False
         
     | 
| 
      
 97 
     | 
    
         
            +
             
     | 
| 
      
 98 
     | 
    
         
            +
                # Select appropriate colormap and levels
         
     | 
| 
      
 99 
     | 
    
         
            +
                cmap, norm, levels = select_colormap_and_levels(data_range, plot_type)
         
     | 
| 
      
 100 
     | 
    
         
            +
             
     | 
| 
      
 101 
     | 
    
         
            +
                mappable = None
         
     | 
| 
      
 102 
     | 
    
         
            +
                try:
         
     | 
| 
      
 103 
     | 
    
         
            +
                    if plot_type == "contourf":
         
     | 
| 
      
 104 
     | 
    
         
            +
                        if np.ptp(data.values) < 1e-10 and not np.all(np.isnan(data.values)):
         
     | 
| 
      
 105 
     | 
    
         
            +
                            print(f"Warning: {data.name} has very little variation. Using imshow instead.")
         
     | 
| 
      
 106 
     | 
    
         
            +
                            mappable = ax.imshow(data.values, cmap=cmap, aspect="auto", interpolation="none")
         
     | 
| 
      
 107 
     | 
    
         
            +
                            colorbar = plt.colorbar(mappable, ax=ax)
         
     | 
| 
      
 108 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 109 
     | 
    
         
            +
                            mappable = ax.contourf(data[x_dim], data[y_dim], data.values, levels=levels, cmap=cmap, norm=norm)
         
     | 
| 
      
 110 
     | 
    
         
            +
                            colorbar = plt.colorbar(mappable, ax=ax)
         
     | 
| 
      
 111 
     | 
    
         
            +
                    elif plot_type == "contour":
         
     | 
| 
      
 112 
     | 
    
         
            +
                        if np.ptp(data.values) < 1e-10 and not np.all(np.isnan(data.values)):
         
     | 
| 
      
 113 
     | 
    
         
            +
                            print(f"Warning: {data.name} has very little variation. Using imshow instead.")
         
     | 
| 
      
 114 
     | 
    
         
            +
                            mappable = ax.imshow(data.values, cmap=cmap, aspect="auto", interpolation="none")
         
     | 
| 
      
 115 
     | 
    
         
            +
                            colorbar = plt.colorbar(mappable, ax=ax)
         
     | 
| 
      
 116 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 117 
     | 
    
         
            +
                            mappable = ax.contour(data[x_dim], data[y_dim], data.values, levels=levels, cmap=cmap, norm=norm)
         
     | 
| 
      
 118 
     | 
    
         
            +
                            ax.clabel(mappable, inline=True, fontsize=8, fmt="%1.1f")
         
     | 
| 
      
 119 
     | 
    
         
            +
                            colorbar = plt.colorbar(mappable, ax=ax)
         
     | 
| 
      
 120 
     | 
    
         
            +
                except (ValueError, TypeError) as e:
         
     | 
| 
      
 121 
     | 
    
         
            +
                    print(f"Warning: Could not plot with specified parameters: {e}. Trying simplified parameters.")
         
     | 
| 
      
 122 
     | 
    
         
            +
                    try:
         
     | 
| 
      
 123 
     | 
    
         
            +
                        mappable = data.plot(ax=ax, cmap=cmap, add_colorbar=False)
         
     | 
| 
      
 124 
     | 
    
         
            +
                        colorbar = plt.colorbar(mappable, ax=ax)
         
     | 
| 
      
 125 
     | 
    
         
            +
                    except Exception as e2:
         
     | 
| 
      
 126 
     | 
    
         
            +
                        print(f"Error plotting {data.name}: {e2}")
         
     | 
| 
      
 127 
     | 
    
         
            +
                        plt.figure(figsize=(10, 8))
         
     | 
| 
      
 128 
     | 
    
         
            +
                        mappable = ax.imshow(data.values, cmap="viridis", aspect="auto")
         
     | 
| 
      
 129 
     | 
    
         
            +
                        colorbar = plt.colorbar(mappable, ax=ax, label=getattr(data, "units", ""))
         
     | 
| 
      
 130 
     | 
    
         
            +
                        plt.title(f"{data.name} | {getattr(data, 'long_name', 'No long_name')} (basic plot)", fontsize=12)
         
     | 
| 
      
 131 
     | 
    
         
            +
                        plt.tight_layout()
         
     | 
| 
      
 132 
     | 
    
         
            +
                        os.makedirs(os.path.dirname(output_path), exist_ok=True)
         
     | 
| 
      
 133 
     | 
    
         
            +
                        plt.savefig(output_path, bbox_inches="tight", dpi=600)
         
     | 
| 
      
 134 
     | 
    
         
            +
                        plt.close()
         
     | 
| 
      
 135 
     | 
    
         
            +
                        return True
         
     | 
| 
      
 136 
     | 
    
         
            +
             
     | 
| 
      
 137 
     | 
    
         
            +
                plt.title(f"{data.name} | {getattr(data, 'long_name', 'No long_name')}", fontsize=12)
         
     | 
| 
      
 138 
     | 
    
         
            +
                units = getattr(data, "units", "")
         
     | 
| 
      
 139 
     | 
    
         
            +
                if units and colorbar:
         
     | 
| 
      
 140 
     | 
    
         
            +
                    colorbar.set_label(units)
         
     | 
| 
      
 141 
     | 
    
         
            +
             
     | 
| 
      
 142 
     | 
    
         
            +
                plt.tight_layout()
         
     | 
| 
      
 143 
     | 
    
         
            +
                os.makedirs(os.path.dirname(output_path), exist_ok=True)
         
     | 
| 
      
 144 
     | 
    
         
            +
                plt.savefig(output_path, bbox_inches="tight", dpi=600)
         
     | 
| 
      
 145 
     | 
    
         
            +
                plt.close()
         
     | 
| 
      
 146 
     | 
    
         
            +
                return True
         
     | 
| 
      
 147 
     | 
    
         
            +
             
     | 
| 
      
 148 
     | 
    
         
            +
             
     | 
| 
      
 149 
     | 
    
         
            +
            def calculate_data_range(data: xr.DataArray, data_range: Optional[Tuple[float, float]]) -> Optional[Tuple[float, float]]:
         
     | 
| 
      
 150 
     | 
    
         
            +
                """Calculate the data range, ignoring extreme outliers."""
         
     | 
| 
      
 151 
     | 
    
         
            +
                if data_range is None:
         
     | 
| 
      
 152 
     | 
    
         
            +
                    flat_data = data.values.flatten()
         
     | 
| 
      
 153 
     | 
    
         
            +
                    if flat_data.size == 0:
         
     | 
| 
      
 154 
     | 
    
         
            +
                        return None
         
     | 
| 
      
 155 
     | 
    
         
            +
                    valid_data = flat_data[~np.isnan(flat_data)]
         
     | 
| 
      
 156 
     | 
    
         
            +
                    if len(valid_data) == 0:
         
     | 
| 
      
 157 
     | 
    
         
            +
                        return None
         
     | 
| 
      
 158 
     | 
    
         
            +
                    low, high = np.percentile(valid_data, [0.5, 99.5])
         
     | 
| 
      
 159 
     | 
    
         
            +
                    filtered_data = valid_data[(valid_data >= low) & (valid_data <= high)]
         
     | 
| 
      
 160 
     | 
    
         
            +
                    if len(filtered_data) > 0:
         
     | 
| 
      
 161 
     | 
    
         
            +
                        data_range = (np.min(filtered_data), np.max(filtered_data))
         
     | 
| 
      
 162 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 163 
     | 
    
         
            +
                        data_range = (np.nanmin(valid_data), np.nanmax(valid_data))
         
     | 
| 
      
 164 
     | 
    
         
            +
                    if abs(data_range[1] - data_range[0]) < 1e-10:
         
     | 
| 
      
 165 
     | 
    
         
            +
                        mean = (data_range[0] + data_range[1]) / 2
         
     | 
| 
      
 166 
     | 
    
         
            +
                        data_range = (mean - 1e-10 if mean != 0 else -1e-10, mean + 1e-10 if mean != 0 else 1e-10)
         
     | 
| 
      
 167 
     | 
    
         
            +
                return data_range
         
     | 
| 
      
 168 
     | 
    
         
            +
             
     | 
| 
      
 169 
     | 
    
         
            +
             
     | 
| 
      
 170 
     | 
    
         
            +
            def select_colormap_and_levels(data_range: Tuple[float, float], plot_type: str) -> Tuple[mpl.colors.Colormap, mpl.colors.Normalize, np.ndarray]:
                """Select a colormap, a normalization and contour levels for a data range.

                Args:
                    data_range: ``(min, max)`` of the values to be plotted.
                    plot_type: ``"contour"`` uses 10 levels; any other value (e.g.
                        ``"contourf"``) uses 128 levels.

                Returns:
                    ``(cmap, norm, levels)`` where ``levels`` is strictly increasing.

                A range that straddles zero gets the ``"diverging_1"`` colormap with a
                ``TwoSlopeNorm`` centred on 0 and symmetric limits. Otherwise
                ``"cool_1"`` is used when the lower bound is negative and ``"warm_1"``
                when it is not.
                """
                # Line contours need few levels; filled contours look smoother with many.
                num_levels = 10 if plot_type == "contour" else 128
                lo, hi = data_range

                if lo * hi < 0:
                    cmap = oafuncs.oa_cmap.get("diverging_1")
                    bdy = max(abs(lo), abs(hi))
                    norm = mpl.colors.TwoSlopeNorm(vmin=-bdy, vcenter=0, vmax=bdy)
                    levels = np.linspace(-bdy, bdy, num_levels)
                else:
                    cmap = oafuncs.oa_cmap.get("cool_1" if lo < 0 else "warm_1")
                    norm = mpl.colors.Normalize(vmin=lo, vmax=hi)
                    levels = np.linspace(lo, hi, num_levels)

                if np.any(np.diff(levels) <= 0):
                    # The range is degenerate (lo == hi, or too narrow to hold num_levels
                    # distinct floats) or reversed. Re-linspacing the same bounds would
                    # reproduce the problem, and contour/contourf reject non-increasing
                    # levels. So order the bounds and widen the interval symmetrically
                    # around its centre before building the fallback.
                    lo, hi = min(lo, hi), max(lo, hi)
                    center = (lo + hi) / 2
                    half_width = max((hi - lo) / 2, abs(center) * 1e-6, 1e-10)
                    lo, hi = center - half_width, center + half_width
                    norm = mpl.colors.Normalize(vmin=lo, vmax=hi)
                    levels = np.linspace(lo, hi, 10)
                return cmap, norm, levels
         
     | 
| 
      
 192 
     | 
    
         
            +
             
     | 
| 
      
 193 
     | 
    
         
            +
             
     | 
| 
      
 194 
     | 
    
         
            +
            def _plot_2d_with_retry(slice_data: xr.DataArray, label: str, output_dir: str, global_data_range, x_dim: str, y_dim: str, t_dim: str, plot_type: str, max_attempts: int) -> None:
                """Plot one 2D slice to ``<output_dir>/<label>.png``, retrying when an exception is raised.

                ``global_data_range`` is forwarded unchanged to ``plot_2d`` (``None`` or a
                ``(min, max)`` tuple). Empty slices are reported and skipped. A falsy
                return from ``plot_2d`` is reported as invalid data. If every one of the
                ``max_attempts`` attempts raises, the last exception is printed, not raised.
                """
                for attempt in range(max_attempts):
                    try:
                        # NOTE(review): reading .values is inside the retry loop, presumably
                        # because loading (possibly lazy/remote) data can fail transiently.
                        if slice_data.values.size == 0:
                            print(f"Skipped {label} (empty data)")
                            return
                        success = plot_2d(slice_data, os.path.join(output_dir, f"{label}.png"), global_data_range, x_dim, y_dim, t_dim, plot_type)
                        if success:
                            print(f"{label}.png")
                        else:
                            print(f"Skipped {label} (invalid data)")
                        return
                    except Exception as e:
                        if attempt < max_attempts - 1:
                            print(f"Retrying {label} (attempt {attempt + 1})")
                        else:
                            print(f"Error processing {label}: {e}")


            def process_variable(var: str, data: xr.DataArray, dims: int, dims_name: Tuple[str, ...], output_dir: str, x_dim: str, y_dim: str, z_dim: str, t_dim: str, fixed_colorscale: bool, plot_type: str) -> None:
                """Plot a single dataset variable, dispatching on its number of dimensions.

                - 1D: one line plot ``<var>.png`` via ``plot_1d`` (character data is skipped).
                - 2D: one map ``<var>.png`` via ``plot_2d``.
                - 3D: one map per index of the first dimension, ``<var>_<dim0>-<i>.png``
                  (up to 10 attempts each).
                - 4D: one map per (i, j) of the first two dimensions,
                  ``<var>_<dim0>-<i>_<dim1>-<j>.png`` (up to 3 attempts each).

                Variables whose dimension names are not all among
                ``x_dim``/``y_dim``/``z_dim``/``t_dim`` are skipped, as are 0-D (scalar)
                variables. When ``fixed_colorscale`` is set and the variable has 2 or more
                dimensions, a single data range computed over the whole variable is
                passed to every slice.
                """
                valid_dims = {x_dim, y_dim, z_dim, t_dim}
                if not set(dims_name).issubset(valid_dims):
                    print(f"Skipping {var} due to unsupported dimensions: {dims_name}")
                    return

                # Process 1D data
                if dims == 1:
                    if np.issubdtype(data.dtype, np.character):
                        print(f"Skipping {var} due to character data type")
                        return
                    plot_1d(data, os.path.join(output_dir, f"{var}.png"), x_dim, y_dim, z_dim, t_dim)
                    print(f"{var}.png")
                    return

                # 0-D scalars would otherwise fall through every branch silently.
                if dims not in (2, 3, 4):
                    print(f"Skipping {var} due to unsupported number of dimensions: {dims}")
                    return

                # Compute global data range for fixed colorscale
                global_data_range = None
                if fixed_colorscale:
                    global_data_range = calculate_data_range(data, None)
                    if global_data_range is None:
                        print(f"Skipping {var} due to no valid data")
                        return
                    print(f"Fixed colorscale range: {global_data_range}")

                # Process 2D data
                if dims == 2:
                    success = plot_2d(data, os.path.join(output_dir, f"{var}.png"), global_data_range, x_dim, y_dim, t_dim, plot_type)
                    if success:
                        print(f"{var}.png")
                    else:
                        print(f"Skipped {var} (invalid data)")
                    return

                # Process 3D data: one 2D map per index of the leading dimension
                if dims == 3:
                    for i in range(data.shape[0]):
                        _plot_2d_with_retry(data[i], f"{var}_{dims_name[0]}-{i}", output_dir, global_data_range, x_dim, y_dim, t_dim, plot_type, 10)
                    return

                # Process 4D data: one 2D map per (i, j) of the two leading dimensions
                for i in range(data.shape[0]):
                    for j in range(data.shape[1]):
                        _plot_2d_with_retry(data[i, j], f"{var}_{dims_name[0]}-{i}_{dims_name[1]}-{j}", output_dir, global_data_range, x_dim, y_dim, t_dim, plot_type, 3)
         
     | 
| 
      
 265 
     | 
    
         
            +
             
     | 
| 
      
 266 
     | 
    
         
            +
             
     | 
| 
      
 267 
     | 
    
         
            +
            def func_plot_dataset(ds_in: xr.Dataset, output_dir: str, xyzt_dims: Tuple[str, str, str, str] = ("longitude", "latitude", "level", "time"), plot_type: str = "contourf", fixed_colorscale: bool = False) -> None:
                """Plot every data variable of a dataset and save the figures as PNG files.

                Args:
                    ds_in: dataset whose ``data_vars`` are plotted one after another.
                    output_dir: destination directory; it is created if missing.
                    xyzt_dims: names of the x, y, z and t dimensions, in that order.
                    plot_type: forwarded to the per-variable plotting (e.g. ``"contourf"``).
                    fixed_colorscale: forwarded to the per-variable plotting; requests one
                        shared colour range per variable.

                An exception raised while plotting one variable is printed and the next
                variable is processed. Any other exception is printed as a dataset-level
                error. Nothing is re-raised from the plotting loop.

                Note: ``ds_in.close()`` is always called before returning, so the
                caller's dataset is closed afterwards.
                """
                os.makedirs(output_dir, exist_ok=True)
                x_dim, y_dim, z_dim, t_dim = xyzt_dims

                dataset = ds_in
                try:
                    var_names = list(dataset.data_vars)
                    print(f"Found {len(var_names)} variables in dataset")

                    for name in var_names:
                        print("=" * 120)
                        print(f"Processing: {name}")
                        variable = dataset[name]
                        ndim = len(variable.shape)
                        dim_names = variable.dims
                        try:
                            process_variable(name, variable, ndim, dim_names, output_dir, x_dim, y_dim, z_dim, t_dim, fixed_colorscale, plot_type)
                        except Exception as err:
                            print(f"Error processing variable {name}: {err}")
                except Exception as err:
                    print(f"Error processing dataset: {err}")
                finally:
                    # `dataset` is bound before the try block, so it always exists here.
                    dataset.close()
                    print("Dataset closed")
         
     | 
| 
      
 295 
     | 
    
         
            +
             
     | 
| 
      
 296 
     | 
    
         
            +
             
     | 
| 
      
 297 
     | 
    
         
            +
            if __name__ == "__main__":
                # No script behaviour: running this file directly does nothing.
                pass
                # Example call (needs an opened xarray Dataset `ds` and an `output_dir` path):
                # func_plot_dataset(ds, output_dir, xyzt_dims=("longitude", "latitude", "level", "time"), plot_type="contourf", fixed_colorscale=False)
         
     | 
| 
         Binary file 
     | 
| 
         Binary file 
     | 
    
        oafuncs/oa_cmap.py
    ADDED
    
    | 
         @@ -0,0 +1,215 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            #!/usr/bin/env python
         
     | 
| 
      
 2 
     | 
    
         
            +
            # coding=utf-8
         
     | 
| 
      
 3 
     | 
    
         
            +
            """
         
     | 
| 
      
 4 
     | 
    
         
            +
            Author: Liu Kun && 16031215@qq.com
         
     | 
| 
      
 5 
     | 
    
         
            +
            Date: 2024-09-17 16:55:11
         
     | 
| 
      
 6 
     | 
    
         
            +
            LastEditors: Liu Kun && 16031215@qq.com
         
     | 
| 
      
 7 
     | 
    
         
            +
            LastEditTime: 2024-11-21 13:14:24
         
     | 
| 
      
 8 
     | 
    
         
            +
            FilePath: \\Python\\My_Funcs\\OAFuncs\\oafuncs\\oa_cmap.py
         
     | 
| 
      
 9 
     | 
    
         
            +
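            Description: Colormap utilities - preview (show), sample (to_color), build (create, create_rgbtxt) and select (get) matplotlib colormaps.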
         
     | 
| 
      
 10 
     | 
    
         
            +
            EditPlatform: vscode
         
     | 
| 
      
 11 
     | 
    
         
            +
            ComputerInfo: XPS 15 9510
         
     | 
| 
      
 12 
     | 
    
         
            +
            SystemInfo: Windows 11
         
     | 
| 
      
 13 
     | 
    
         
            +
            Python Version: 3.11
         
     | 
| 
      
 14 
     | 
    
         
            +
            """
         
     | 
| 
      
 15 
     | 
    
         
            +
             
     | 
| 
      
 16 
     | 
    
         
            +
            import matplotlib as mpl
         
     | 
| 
      
 17 
     | 
    
         
            +
            import matplotlib.pyplot as plt
         
     | 
| 
      
 18 
     | 
    
         
            +
            import numpy as np
         
     | 
| 
      
 19 
     | 
    
         
            +
            from rich import print
         
     | 
| 
      
 20 
     | 
    
         
            +
             
     | 
| 
      
 21 
     | 
    
         
            +
            # Public API of oafuncs.oa_cmap (the names exported by `from ... import *`).
            __all__ = ["show", "to_color", "create", "create_rgbtxt", "get"]
         
     | 
| 
      
 22 
     | 
    
         
            +
             
     | 
| 
      
 23 
     | 
    
         
            +
            # ** Visualize colormaps with a pcolormesh plot (adapted from the matplotlib documentation)
            def show(colormaps):
                """
                Description:
                    Helper function to plot random sample data with each of the given colormaps.
                Parameters:
                    colormaps : list of colormaps, or a single colormap; each entry can be a
                                colormap name (str) or a matplotlib Colormap object.
                Raises:
                    ValueError : if colormaps is an empty sequence.
                Example:
                    cmap = mpl.colors.ListedColormap(["darkorange", "gold", "lawngreen", "lightseagreen"])
                    show([cmap]); show("viridis"); show(["viridis", "cividis"])
                """
                if isinstance(colormaps, (str, mpl.colors.Colormap)):
                    colormaps = [colormaps]
                n = len(colormaps)
                if n == 0:
                    raise ValueError("show() needs at least one colormap")
                # A private generator with the same fixed seed gives the same sample data
                # on every call without resetting the caller's global NumPy RNG state
                # (np.random.seed would do that as a side effect).
                data = np.random.RandomState(19680801).randn(30, 30)
                fig, axs = plt.subplots(1, n, figsize=(n * 2 + 2, 3), constrained_layout=True, squeeze=False)
                for ax, cmap in zip(axs.flat, colormaps):
                    psm = ax.pcolormesh(data, cmap=cmap, rasterized=True, vmin=-4, vmax=4)
                    fig.colorbar(psm, ax=ax)
                plt.show()
         
     | 
| 
      
 44 
     | 
    
         
            +
             
     | 
| 
      
 45 
     | 
    
         
            +
             
     | 
| 
      
 46 
     | 
    
         
            +
            # ** Sample a colormap into a list of colors
            def to_color(cmap, n=256):
                """
                Description:
                    Sample a colormap at n evenly spaced positions in [0, 1].
                Parameters:
                    cmap : str or Colormap; a registered colormap name or a colormap object
                    n    : int, optional; number of colors to sample (default 256)
                Return:
                    out_colors : list of RGBA tuples, one per sample position
                Example:
                    out_colors = to_color('viridis', 256)
                """
                colormap = mpl.colormaps.get_cmap(cmap)
                sample_points = np.linspace(0, 1, n)
                # Call the colormap once per scalar position, producing one RGBA tuple each.
                return list(map(colormap, sample_points))
         
     | 
| 
      
 62 
     | 
    
         
            +
             
     | 
| 
      
 63 
     | 
    
         
            +
             
     | 
| 
      
 64 
     | 
    
         
            +
            # ** Custom multi-color colormap, optionally with explicit node positions
            def create(colors: list, nodes=None, under=None, over=None):  # quickly build a colormap from a list of colors
                """
                Description:
                    Create a custom linearly interpolated colormap.
                Parameters:
                    colors : list of colors
                    nodes  : list of positions in [0, 1], one per color; None spaces the
                             colors evenly. Matplotlib requires the first node to be 0 and
                             the last to be 1.
                    under  : color for values below the normalized range (optional)
                    over   : color for values above the normalized range (optional)
                Return:
                    cmap : LinearSegmentedColormap named "mycmap"
                Raises:
                    ValueError : if nodes is given and its length differs from colors
                Example:
                    cmap = create(['#C2B7F3','#B3BBF2','#B0CBF1','#ACDCF0','#A8EEED'])
                    cmap = create(['aliceblue','skyblue','deepskyblue'],[0.0,0.5,1.0])
                """
                if nodes is None:  # distribute the colors evenly
                    cmap_color = mpl.colors.LinearSegmentedColormap.from_list("mycmap", colors)
                else:  # place each color at its given node position
                    # zip() would silently drop the surplus entries and build the wrong map.
                    if len(nodes) != len(colors):
                        raise ValueError(f"nodes and colors must have the same length, got {len(nodes)} nodes and {len(colors)} colors")
                    cmap_color = mpl.colors.LinearSegmentedColormap.from_list("mycmap", list(zip(nodes, colors)))
                if under is not None:
                    cmap_color.set_under(under)
                if over is not None:
                    cmap_color.set_over(over)
                return cmap_color
         
     | 
| 
      
 90 
     | 
    
         
            +
             
     | 
| 
      
 91 
     | 
    
         
            +
             
     | 
| 
      
 92 
     | 
    
         
            +
            # ** Build a colormap from a text file of RGB values (e.g. a GrADS palette)
            def create_rgbtxt(rgbtxt_file,split_mark=','):  # build a colormap from an RGB txt file / from rgb values
                """
                Description
                -----------
                Make a colormap from an RGB text file. Each non-blank line holds one RGB
                value (only the first three fields are used), separated by split_mark,
                such as: 251,251,253. Values larger than 2 are treated as 0-255 and
                scaled to 0-1.

                Parameters
                ----------
                rgbtxt_file : str, the path of txt file
                split_mark  : str, optional, default is ','; the split mark of rgb value

                Returns
                -------
                cmap : ListedColormap named "my_color"

                Raises
                ------
                ValueError : if a line has fewer than 3 values or the file has no RGB values

                Example
                -------
                cmap=create_rgbtxt(path,split_mark=',')

                txt example
                -----------
                251,251,253
                225,125,25
                250,205,255
                """
                rows = []
                with open(rgbtxt_file, encoding="utf-8") as fid:
                    for lineno, line in enumerate(fid, start=1):
                        if not line.strip():
                            continue  # tolerate blank lines, e.g. an extra newline at EOF
                        parts = line.split(split_mark)
                        if len(parts) < 3:
                            raise ValueError(f"{rgbtxt_file}:{lineno}: expected 3 values separated by {split_mark!r}, got {line.strip()!r}")
                        # float() tolerates surrounding whitespace such as the trailing newline.
                        rows.append([float(p) for p in parts[:3]])
                if not rows:
                    raise ValueError(f"No RGB values found in {rgbtxt_file}")
                rgb = np.array(rows, dtype=float)
                if np.max(rgb) > 2:  # values look like 0-255, so normalize to 0-1
                    rgb = rgb / 255.0
                my_cmap = mpl.colors.ListedColormap(rgb, name="my_color")
                return my_cmap
         
     | 
| 
      
 131 
     | 
    
         
            +
             
     | 
| 
      
 132 
     | 
    
         
            +
             
     | 
| 
      
 133 
     | 
    
         
            +
            # ** Choose a colormap
            def get(cmap_name=None, query=False):
                """
                Description:
                    Choose a colormap from this module's custom colormaps or from the
                    colormaps registered with Matplotlib.
                Parameters:
                    cmap_name : str, optional; the name of the colormap. If None, nothing
                                is returned (useful together with ``query=True``).
                    query     : bool, optional; if True, print the available colormap
                                names (custom ones first, then Matplotlib's registry)
                                before resolving ``cmap_name``.
                Return:
                    cmap : colormap, or None when ``cmap_name`` is None.
                           An unknown name prints a message and falls back to 'rainbow'.
                Example:
                    cmap = get('viridis')
                    cmap = get('diverging_1')
                    cmap = get('cool_1')
                    cmap = get('warm_1')
                    cmap = get('colorful_1')
                    get(query=True)  # only print the available names
                """

                # Zero-argument factories: only the requested colormap is built, instead of
                # constructing every custom colormap on each call.
                my_cmap_factories = {
                    "diverging_1": lambda: create(["#4e00b3", "#0000FF", "#00c0ff", "#a1d3ff", "#DCDCDC", "#FFD39B", "#FF8247", "#FF0000", "#FF5F9E"]),
                    "cool_1": lambda: create(["#4e00b3", "#0000FF", "#00c0ff", "#a1d3ff", "#DCDCDC"]),
                    "warm_1": lambda: create(["#DCDCDC", "#FFD39B", "#FF8247", "#FF0000", "#FF5F9E"]),
                    # "land_1": lambda: create_custom(["#3E6436", "#678A59", "#91A176", "#B8A87D", "#D9CBB2"], under="#A6CEE3", over="#FFFFFF"),
                    # "ocean_1": lambda: create_custom(["#126697", "#2D88B3", "#4EA1C9", "#78B9D8", "#A6CEE3"], under="#8470FF", over="#3E6436"),
                    # "ocean_land_1": lambda: create_custom(
                    #     [
                    #         "#126697",  # dark blue (deep sea)
                    #         "#2D88B3",  # blue
                    #         "#4EA1C9",  # blue-green
                    #         "#78B9D8",  # light blue (shallow sea)
                    #         "#A6CEE3",  # light blue (near shore)
                    #         "#AAAAAA",  # grey (value 0, sea level)
                    #         "#D9CBB2",  # sandy-soil colour (land begins)
                    #         "#B8A87D",  # light brown
                    #         "#91A176",  # light green
                    #         "#678A59",  # medium green
                    #         "#3E6436",  # dark green (high mountains)
                    #     ]
                    # ),
                    "colorful_1": lambda: create(["#6d00db", "#9800cb", "#F2003C", "#ff4500", "#ff7f00", "#FE28A2", "#FFC0CB", "#DDA0DD", "#40E0D0", "#1a66f2", "#00f7fb", "#8fff88", "#E3FF00"]),
                }

                if query:
                    print("Available cmap names:")
                    print('-' * 20)
                    print('Defined by myself:')
                    for key in my_cmap_factories:
                        print(key)
                    print('-' * 20)
                    print('Matplotlib built-in:')
                    # list(registry) is the documented way to get the registered names;
                    # calling the registry directly only exists for backward compatibility.
                    print(list(mpl.colormaps))
                    print("-" * 20)

                if cmap_name is None:
                    return

                if cmap_name in my_cmap_factories:
                    return my_cmap_factories[cmap_name]()

                try:
                    # ColormapRegistry.get_cmap raises ValueError for an unregistered name.
                    return mpl.colormaps.get_cmap(cmap_name)
                except ValueError:
                    # Deliberate best-effort fallback instead of raising, so plotting code keeps running.
                    print(f"Unknown cmap name: {cmap_name}\nNow return 'rainbow' as default.")
                    return mpl.colormaps.get_cmap("rainbow")
         
     | 
| 
      
 197 
     | 
    
         
            +
             
     | 
| 
      
 198 
     | 
    
         
            +
             
     | 
| 
      
 199 
     | 
    
         
            +
            if __name__ == "__main__":
         
     | 
| 
      
 200 
     | 
    
         
            +
                # ** 测试自制cmap
         
     | 
| 
      
 201 
     | 
    
         
            +
                colors = ["#C2B7F3", "#B3BBF2", "#B0CBF1", "#ACDCF0", "#A8EEED"]
         
     | 
| 
      
 202 
     | 
    
         
            +
                nodes = [0.0, 0.2, 0.4, 0.6, 1.0]
         
     | 
| 
      
 203 
     | 
    
         
            +
                c_map = create(colors, nodes)
         
     | 
| 
      
 204 
     | 
    
         
            +
                show([c_map])
         
     | 
| 
      
 205 
     | 
    
         
            +
             
     | 
| 
      
 206 
     | 
    
         
            +
                # ** 测试自制diverging型cmap
         
     | 
| 
      
 207 
     | 
    
         
            +
                diverging_cmap = create(["#4e00b3", "#0000FF", "#00c0ff", "#a1d3ff", "#DCDCDC", "#FFD39B", "#FF8247", "#FF0000", "#FF5F9E"])
         
     | 
| 
      
 208 
     | 
    
         
            +
                show([diverging_cmap])
         
     | 
| 
      
 209 
     | 
    
         
            +
             
     | 
| 
      
 210 
     | 
    
         
            +
                # ** 测试根据RGB的txt文档制作色卡
         
     | 
| 
      
 211 
     | 
    
         
            +
                file_path = "E:/python/colorbar/test.txt"
         
     | 
| 
      
 212 
     | 
    
         
            +
                cmap_rgb = create_rgbtxt(file_path)
         
     | 
| 
      
 213 
     | 
    
         
            +
             
     | 
| 
      
 214 
     | 
    
         
            +
                # ** 测试将cmap转为list
         
     | 
| 
      
 215 
     | 
    
         
            +
                out_colors = to_color("viridis", 256)
         
     |