oafuncs 0.0.97.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
oafuncs/__init__.py ADDED
@@ -0,0 +1,54 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ """
4
+ Author: Liu Kun && 16031215@qq.com
5
+ Date: 2024-09-17 16:09:20
6
+ LastEditors: Liu Kun && 16031215@qq.com
7
+ LastEditTime: 2025-03-09 16:28:01
8
+ FilePath: \\Python\\My_Funcs\\OAFuncs\\oafuncs\\__init__.py
9
+ Description:
10
+ EditPlatform: vscode
11
+ ComputerInfo: XPS 15 9510
12
+ SystemInfo: Windows 11
13
+ Python Version: 3.12
14
+ """
15
+
16
+
17
+ # 会导致OAFuncs直接导入所有函数,不符合模块化设计
18
+ # from oafuncs.oa_s.oa_cmap import *
19
+ # from oafuncs.oa_s.oa_data import *
20
+ # from oafuncs.oa_s.oa_draw import *
21
+ # from oafuncs.oa_s.oa_file import *
22
+ # from oafuncs.oa_s.oa_help import *
23
+ # from oafuncs.oa_s.oa_nc import *
24
+ # from oafuncs.oa_s.oa_python import *
25
+
26
+ # ------------------- 2024-12-13 12:31:06 -------------------
27
+ # path: My_Funcs/OAFuncs/oafuncs/
28
+ from .oa_cmap import *
29
+ from .oa_data import *
30
+
31
+ # ------------------- 2024-12-13 12:31:06 -------------------
32
+ # path: My_Funcs/OAFuncs/oafuncs/oa_down/
33
+ from .oa_down import *
34
+ from .oa_draw import *
35
+ from .oa_file import *
36
+ from .oa_help import *
37
+
38
+ # ------------------- 2024-12-13 12:31:06 -------------------
39
+ # path: My_Funcs/OAFuncs/oafuncs/oa_model/
40
+ from .oa_model import *
41
+ from .oa_nc import *
42
+ from .oa_python import *
43
+
44
+ # ------------------- 2024-12-13 12:31:06 -------------------
45
+ # path: My_Funcs/OAFuncs/oafuncs/oa_sign/
46
+ from .oa_sign import *
47
+
48
+ # ------------------- 2024-12-13 12:31:06 -------------------
49
+ # path: My_Funcs/OAFuncs/oafuncs/oa_tool/
50
+ from .oa_tool import *
51
+ # ------------------- 2025-03-09 16:28:01 -------------------
52
+ # path: My_Funcs/OAFuncs/oafuncs/_script/
53
+ from ._script import *
54
+ # ------------------- 2025-03-16 15:56:01 -------------------
@@ -0,0 +1,27 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ """
4
+ Author: Liu Kun && 16031215@qq.com
5
+ Date: 2025-03-13 15:26:15
6
+ LastEditors: Liu Kun && 16031215@qq.com
7
+ LastEditTime: 2025-03-13 15:26:18
8
+ FilePath: \\Python\\My_Funcs\\OAFuncs\\oafuncs\\oa_script\\__init__.py
9
+ Description:
10
+ EditPlatform: vscode
11
+ ComputerInfo: XPS 15 9510
12
+ SystemInfo: Windows 11
13
+ Python Version: 3.12
14
+ """
15
+
16
+
17
+
18
+ # 会导致OAFuncs直接导入所有函数,不符合模块化设计
19
+ # from oafuncs.oa_s.oa_cmap import *
20
+ # from oafuncs.oa_s.oa_data import *
21
+ # from oafuncs.oa_s.oa_draw import *
22
+ # from oafuncs.oa_s.oa_file import *
23
+ # from oafuncs.oa_s.oa_help import *
24
+ # from oafuncs.oa_s.oa_nc import *
25
+ # from oafuncs.oa_s.oa_python import *
26
+
27
+ from .plot_dataset import func_plot_dataset
@@ -0,0 +1,299 @@
1
+ import os
2
+ from typing import Optional, Tuple
3
+
4
+ import matplotlib as mpl
5
+
6
+ mpl.use("Agg") # Use non-interactive backend
7
+
8
+ import cftime
9
+ import matplotlib.pyplot as plt
10
+ import numpy as np
11
+ from rich import print
12
+ import cartopy.crs as ccrs
13
+ import xarray as xr
14
+
15
+ import oafuncs
16
+
17
+
18
def plot_1d(data: xr.DataArray, output_path: str, x_dim: str, y_dim: str, z_dim: str, t_dim: str) -> None:
    """Plot a 1-D variable as a line chart and save it as an image.

    The X axis comes from ``determine_x_axis``: the first of x/y/z/t present in
    ``data.dims``, otherwise a plain index. If the time coordinate holds
    ``cftime`` objects it is converted to a datetime index first. When that
    conversion fails, a warning is printed and the original values are kept.

    Args:
        data: 1-D variable to plot. Its ``long_name`` and ``units`` attributes,
            when present, are used in the title and Y label.
        output_path: Destination image file. Missing parent directories are created.
        x_dim, y_dim, z_dim, t_dim: Names of the x/y/z/time dimensions.
    """
    fig = plt.figure(figsize=(10, 6))
    try:
        # matplotlib cannot place cftime datetimes on an axis; convert when possible.
        # The size check avoids an IndexError on an empty time axis.
        if t_dim in data.dims and data[t_dim].size > 0 and isinstance(data[t_dim].values[0], cftime.datetime):
            try:
                data[t_dim] = data.indexes[t_dim].to_datetimeindex()
            except (AttributeError, ValueError, TypeError) as e:
                print(f"Warning: Could not convert {t_dim} to datetime index: {e}")

        x, x_label = determine_x_axis(data, x_dim, y_dim, z_dim, t_dim)
        plt.plot(x, data.values, linewidth=2)

        long_name = getattr(data, "long_name", "No long_name")
        units = getattr(data, "units", "")
        plt.title(f"{data.name} | {long_name}", fontsize=12)
        plt.xlabel(x_label)
        plt.ylabel(f"{data.name} ({units})" if units else data.name)
        plt.grid(True, linestyle="--", alpha=0.7)
        plt.tight_layout()

        # os.makedirs("") raises, so only create a directory when the path names one.
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(output_path, bbox_inches="tight", dpi=600)
    finally:
        # Release the figure even when plotting or saving raised.
        plt.close(fig)
50
+
51
+
52
def determine_x_axis(data: xr.DataArray, x_dim: str, y_dim: str, z_dim: str, t_dim: str) -> Tuple[np.ndarray, str]:
    """Return ``(values, label)`` for the X axis of a 1-D plot.

    The dimension names are tried in the order x, y, z, t. The first one found
    in ``data.dims`` supplies both the coordinate values and the axis label.
    When none is present, ``np.arange(len(data))`` is used with the label "Index".
    """
    for candidate in (x_dim, y_dim, z_dim, t_dim):
        if candidate in data.dims:
            return data[candidate].values, candidate
    return np.arange(len(data)), "Index"
64
+
65
+
66
def _resolve_plot_2d_dims(dims: Tuple[str, ...], x_dim: str, y_dim: str) -> Tuple[str, str]:
    """Choose which of the two dims of a 2-D slice go on the X and Y axes.

    The configured ``x_dim``/``y_dim`` are used when present. A missing one is
    replaced by the remaining dim. With neither present, array order is kept:
    first dim -> Y (rows), second dim -> X (columns).
    """
    first, second = dims
    if x_dim in dims and y_dim in dims:
        return x_dim, y_dim
    if x_dim in dims:
        return x_dim, (second if first == x_dim else first)
    if y_dim in dims:
        return (second if first == y_dim else first), y_dim
    return second, first


def _save_plot_2d_figure(fig, output_path: str) -> None:
    """Save ``fig`` at 600 dpi, creating the parent directory if the path has one."""
    out_dir = os.path.dirname(output_path)
    if out_dir:  # os.makedirs("") would raise for a bare file name
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(output_path, bbox_inches="tight", dpi=600)


def plot_2d(data: xr.DataArray, output_path: str, data_range: Optional[Tuple[float, float]], x_dim: str, y_dim: str, t_dim: str, plot_type: str) -> bool:
    """Plot a 2-D field with contourf/contour and save it as an image.

    Fields whose X/Y dims are named "lon"/"longitude" and "lat"/"latitude"
    (case-insensitive) are drawn on a PlateCarree map decorated by
    ``oafuncs.oa_draw.add_cartopy``; other fields use plain axes.

    The X/Y dims are the configured ``x_dim``/``y_dim`` when present, otherwise
    the slice's own dims; the data is transposed to (Y, X) before plotting.

    Nearly constant fields (peak-to-peak < 1e-10) are drawn with ``imshow``.
    If contouring raises ValueError/TypeError, xarray's ``DataArray.plot`` is
    tried next. If that also fails, a plain ``imshow`` "basic plot" is saved
    on a fresh figure.

    Args:
        data: 2-D variable. ``name``, ``long_name`` and ``units`` feed the title
            and colorbar label.
        output_path: Destination image file. Missing parent directories are created.
        data_range: Fixed (low, high) color range, or None to derive one from ``data``.
        x_dim, y_dim: Preferred X and Y dimension names.
        t_dim: Time dimension name. cftime values on it are converted for matplotlib.
        plot_type: "contourf" or "contour".

    Returns:
        True when an image was written. False when the field was skipped
        because it is empty or has no usable values.

    Raises:
        ValueError: for an unsupported ``plot_type`` or non-2-D ``data``.
    """
    # Validate before creating any figure so nothing leaks on bad input.
    if plot_type not in ("contourf", "contour"):
        raise ValueError(f"Unsupported plot_type {plot_type!r}; expected 'contourf' or 'contour'")
    if data.ndim != 2:
        raise ValueError(f"plot_2d expects 2-D data, got dims {data.dims}")

    x_dim, y_dim = _resolve_plot_2d_dims(data.dims, x_dim, y_dim)
    # contourf/contour/imshow expect rows along Y and columns along X.
    data = data.transpose(y_dim, x_dim)

    # matplotlib cannot place cftime datetimes on an axis; convert when possible.
    if t_dim in data.dims and data[t_dim].size > 0 and isinstance(data[t_dim].values[0], cftime.datetime):
        try:
            data[t_dim] = data.indexes[t_dim].to_datetimeindex()
        except (AttributeError, ValueError, TypeError) as e:
            print(f"Warning: Could not convert {t_dim} to datetime index: {e}")

    values = data.values
    if data.size == 0 or np.all(np.isnan(values)):
        print(f"Skipping {data.name}: All values are NaN or empty")
        return False

    data_range = calculate_data_range(data, data_range)
    if data_range is None:
        print(f"Skipping {data.name} due to all NaN values")
        return False

    cmap, norm, levels = select_colormap_and_levels(data_range, plot_type)
    # All-NaN fields were rejected above, so ptp is NaN only for partly-NaN
    # fields, and NaN < 1e-10 is False (those are contoured normally).
    nearly_constant = np.ptp(values) < 1e-10

    figsize = (10, 8)
    subplot_kw = {}
    is_map = x_dim.lower() in ("lon", "longitude") and y_dim.lower() in ("lat", "latitude")
    if is_map:
        lon_values = data[x_dim].values
        lat_values = data[y_dim].values
        lon_span = float(np.nanmax(lon_values) - np.nanmin(lon_values))
        lat_span = float(np.nanmax(lat_values) - np.nanmin(lat_values))
        # Match the figure aspect to the map extent. A single row/column extent
        # would divide by zero, so keep the default size in that case.
        if lon_span > 0 and lat_span > 0:
            figsize = (10, 10 * lat_span / lon_span)
        subplot_kw = {"projection": ccrs.PlateCarree()}

    fig, ax = plt.subplots(figsize=figsize, subplot_kw=subplot_kw)
    try:
        if is_map:
            oafuncs.oa_draw.add_cartopy(ax, lon_values, lat_values)

        colorbar = None
        try:
            if nearly_constant:
                print(f"Warning: {data.name} has very little variation. Using imshow instead.")
                mappable = ax.imshow(values, cmap=cmap, aspect="auto", interpolation="none")
            elif plot_type == "contourf":
                mappable = ax.contourf(data[x_dim], data[y_dim], values, levels=levels, cmap=cmap, norm=norm)
            else:  # "contour"
                mappable = ax.contour(data[x_dim], data[y_dim], values, levels=levels, cmap=cmap, norm=norm)
                ax.clabel(mappable, inline=True, fontsize=8, fmt="%1.1f")
            colorbar = fig.colorbar(mappable, ax=ax)
        except (ValueError, TypeError) as e:
            print(f"Warning: Could not plot with specified parameters: {e}. Trying simplified parameters.")
            try:
                mappable = data.plot(ax=ax, cmap=cmap, add_colorbar=False)
                colorbar = fig.colorbar(mappable, ax=ax)
            except Exception as e2:
                print(f"Error plotting {data.name}: {e2}")
                # Last resort: draw the raw array on a fresh figure. The old
                # figure may hold a half-drawn plot, so discard it first.
                plt.close(fig)
                fig, ax = plt.subplots(figsize=(10, 8))
                mappable = ax.imshow(values, cmap="viridis", aspect="auto")
                fig.colorbar(mappable, ax=ax, label=getattr(data, "units", ""))
                ax.set_title(f"{data.name} | {getattr(data, 'long_name', 'No long_name')} (basic plot)", fontsize=12)
                fig.tight_layout()
                _save_plot_2d_figure(fig, output_path)
                return True

        ax.set_title(f"{data.name} | {getattr(data, 'long_name', 'No long_name')}", fontsize=12)
        units = getattr(data, "units", "")
        if units and colorbar is not None:
            colorbar.set_label(units)
        fig.tight_layout()
        _save_plot_2d_figure(fig, output_path)
        return True
    finally:
        # Closes whichever figure is current, including the fallback one.
        plt.close(fig)
147
+
148
+
149
def calculate_data_range(data: "xr.DataArray", data_range: Optional[Tuple[float, float]]) -> Optional[Tuple[float, float]]:
    """Return the (low, high) color range for ``data``, ignoring extreme outliers.

    When ``data_range`` is None, the range comes from the finite values of
    ``data.values`` restricted to the 0.5-99.5 percentile band. If that band
    holds no values, the full finite min/max is used. A caller-supplied
    ``data_range`` is used as given.

    In both cases a range narrower than 1e-10 is widened symmetrically around
    its midpoint, so downstream norms and contour levels stay strictly
    increasing. The padding scales with the midpoint's magnitude: a fixed
    1e-10 is below float64 resolution for values above ~2e6 and would be lost
    to rounding.

    Args:
        data: Object exposing a NumPy-compatible ``.values`` array
            (e.g. an ``xarray.DataArray``).
        data_range: Explicit (low, high) range, or None to derive it from ``data``.

    Returns:
        The (low, high) range, or None when ``data`` has no finite values.
    """
    if data_range is None:
        flat = np.asarray(data.values).ravel()
        # Drop NaN and +/-inf: an infinite value would make the percentiles
        # and the resulting contour levels infinite or NaN.
        finite = flat[np.isfinite(flat)]
        if finite.size == 0:
            return None
        low, high = np.percentile(finite, [0.5, 99.5])
        in_band = finite[(finite >= low) & (finite <= high)]
        # With very few samples, interpolated percentiles can exclude every value.
        source = in_band if in_band.size > 0 else finite
        data_range = (np.min(source), np.max(source))

    if abs(data_range[1] - data_range[0]) < 1e-10:
        mid = (data_range[0] + data_range[1]) / 2
        pad = max(1e-10, abs(mid) * 1e-10)
        data_range = (mid - pad, mid + pad)
    return data_range
168
+
169
+
170
def select_colormap_and_levels(data_range: Tuple[float, float], plot_type: str) -> Tuple[mpl.colors.Colormap, mpl.colors.Normalize, np.ndarray]:
    """Pick a colormap, a normalization and contour levels for ``data_range``.

    A range that straddles zero (low * high < 0) gets the "diverging_1"
    colormap, a ``TwoSlopeNorm`` centred on 0, and symmetric limits. Any other
    range gets "cool_1" (when low < 0) or "warm_1", with a linear norm
    spanning the range.

    Args:
        data_range: (low, high) value limits.
        plot_type: "contour" uses 10 levels; any other value uses 128.

    Returns:
        (cmap, norm, levels)
    """
    n_levels = 10 if plot_type == "contour" else 128
    low, high = data_range[0], data_range[1]

    if low * high < 0:
        half_width = max(abs(low), abs(high))
        cmap = oafuncs.oa_cmap.get("diverging_1")
        norm = mpl.colors.TwoSlopeNorm(vmin=-half_width, vcenter=0, vmax=half_width)
        levels = np.linspace(-half_width, half_width, n_levels)
    else:
        cmap = oafuncs.oa_cmap.get("cool_1" if low < 0 else "warm_1")
        norm = mpl.colors.Normalize(vmin=low, vmax=high)
        levels = np.linspace(low, high, n_levels)

    # Float rounding can make a very fine level grid non-increasing; use a coarser one then.
    if (np.diff(levels) <= 0).any():
        levels = np.linspace(low, high, 10)
    return cmap, norm, levels
192
+
193
+
194
def _plot_variable_slices(var: str, data: xr.DataArray, dims_name: Tuple[str, ...], n_lead: int, max_attempts: int, output_dir: str, data_range: Optional[Tuple[float, float]], x_dim: str, y_dim: str, t_dim: str, plot_type: str) -> None:
    """Plot every 2-D slice obtained by indexing the first ``n_lead`` dims of ``data``.

    Each slice is saved as ``{var}_{dim0}-{i}[_{dim1}-{j}].png``. A slice whose
    plotting raises is retried up to ``max_attempts`` times in total. A
    deterministic error repeats on every attempt; only then is it reported.
    """
    for index in np.ndindex(*data.shape[:n_lead]):
        tag = var + "".join(f"_{dims_name[k]}-{i}" for k, i in enumerate(index))
        for attempt in range(max_attempts):
            try:
                field = data[index]
                if field.values.size == 0:
                    print(f"Skipped {tag} (empty data)")
                    break
                if plot_2d(field, os.path.join(output_dir, f"{tag}.png"), data_range, x_dim, y_dim, t_dim, plot_type):
                    print(f"{tag}.png")
                else:
                    print(f"Skipped {tag} (invalid data)")
                break
            except Exception as e:
                if attempt < max_attempts - 1:
                    print(f"Retrying {tag} (attempt {attempt + 1})")
                else:
                    print(f"Error processing {tag}: {e}")


def process_variable(var: str, data: xr.DataArray, dims: int, dims_name: Tuple[str, ...], output_dir: str, x_dim: str, y_dim: str, z_dim: str, t_dim: str, fixed_colorscale: bool, plot_type: str) -> None:
    """Plot a single dataset variable and save the image(s) to ``output_dir``.

    Output files:
        - 1-D: ``{var}.png`` (line plot via ``plot_1d``)
        - 2-D: ``{var}.png`` (via ``plot_2d``)
        - 3-D: one ``{var}_{dim0}-{i}.png`` per index of the first dim
          (up to 10 attempts each)
        - 4-D: one ``{var}_{dim0}-{i}_{dim1}-{j}.png`` per index pair of the
          first two dims (up to 3 attempts each)

    A variable is skipped, with a message, when any of its dims is not one of
    the x/y/z/t names, when its dtype is character/string, or when it is
    scalar (0-D).

    Args:
        var: Variable name, used in messages and file names.
        data: The variable's data.
        dims: Number of dimensions of ``data``.
        dims_name: Dimension names of ``data``, in order.
        output_dir: Directory for the images.
        x_dim, y_dim, z_dim, t_dim: Names of the x/y/z/time dims.
        fixed_colorscale: If True, every 2-D plot of this variable uses one
            color range computed from the whole variable.
        plot_type: Passed through to ``plot_2d``.
    """
    valid_dims = {x_dim, y_dim, z_dim, t_dim}
    if not set(dims_name).issubset(valid_dims):
        print(f"Skipping {var} due to unsupported dimensions: {dims_name}")
        return

    # Text variables cannot be plotted at any rank (np.isnan would raise on them).
    if np.issubdtype(data.dtype, np.character):
        print(f"Skipping {var} due to character data type")
        return

    if dims == 0:
        print(f"Skipping {var}: scalar variable has nothing to plot")
        return

    if dims == 1:
        plot_1d(data, os.path.join(output_dir, f"{var}.png"), x_dim, y_dim, z_dim, t_dim)
        print(f"{var}.png")
        return

    # One color range shared by every slice when a fixed colorscale is requested.
    global_data_range = None
    if fixed_colorscale:
        global_data_range = calculate_data_range(data, None)
        if global_data_range is None:
            print(f"Skipping {var} due to no valid data")
            return
        print(f"Fixed colorscale range: {global_data_range}")

    if dims == 2:
        if plot_2d(data, os.path.join(output_dir, f"{var}.png"), global_data_range, x_dim, y_dim, t_dim, plot_type):
            print(f"{var}.png")
    elif dims == 3:
        _plot_variable_slices(var, data, dims_name, 1, 10, output_dir, global_data_range, x_dim, y_dim, t_dim, plot_type)
    elif dims == 4:
        _plot_variable_slices(var, data, dims_name, 2, 3, output_dir, global_data_range, x_dim, y_dim, t_dim, plot_type)
    else:
        print(f"Skipping {var}: {dims}-D data is not supported")
265
+
266
+
267
def func_plot_dataset(ds_in: xr.Dataset, output_dir: str, xyzt_dims: Tuple[str, str, str, str] = ("longitude", "latitude", "level", "time"), plot_type: str = "contourf", fixed_colorscale: bool = False, close_dataset: bool = True) -> None:
    """Plot every data variable of a dataset and save the images to ``output_dir``.

    Each variable is handled by ``process_variable``. An error in one variable
    is printed and processing continues with the next one.

    Args:
        ds_in: Dataset whose ``data_vars`` are plotted.
        output_dir: Directory for the images; created if missing.
        xyzt_dims: Names of the (x, y, z, time) dimensions.
        plot_type: "contourf" or "contour" for 2-D fields.
        fixed_colorscale: Use one color range per variable across all its slices.
        close_dataset: Close ``ds_in`` when finished (the historical behavior).
            Pass False to keep using the dataset afterwards.
    """
    os.makedirs(output_dir, exist_ok=True)
    x_dim, y_dim, z_dim, t_dim = xyzt_dims

    try:
        varlist = list(ds_in.data_vars)
        print(f"Found {len(varlist)} variables in dataset")

        for var in varlist:
            print("=" * 120)
            print(f"Processing: {var}")
            data = ds_in[var]
            try:
                process_variable(var, data, len(data.shape), data.dims, output_dir, x_dim, y_dim, z_dim, t_dim, fixed_colorscale, plot_type)
            except Exception as e:
                # One bad variable should not abort the whole dataset.
                print(f"Error processing variable {var}: {e}")

    except Exception as e:
        print(f"Error processing dataset: {e}")
    finally:
        if close_dataset:
            ds_in.close()
            print("Dataset closed")
295
+
296
+
297
if __name__ == "__main__":
    # Nothing runs when this module is executed directly. Example call
    # (ds must be an xarray Dataset and output_dir a directory path):
    pass
    # func_plot_dataset(ds, output_dir, xyzt_dims=("longitude", "latitude", "level", "time"), plot_type="contourf", fixed_colorscale=False)
Binary file
Binary file
oafuncs/oa_cmap.py ADDED
@@ -0,0 +1,215 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ """
4
+ Author: Liu Kun && 16031215@qq.com
5
+ Date: 2024-09-17 16:55:11
6
+ LastEditors: Liu Kun && 16031215@qq.com
7
+ LastEditTime: 2024-11-21 13:14:24
8
+ FilePath: \\Python\\My_Funcs\\OAFuncs\\oafuncs\\oa_cmap.py
9
+ Description:
10
+ EditPlatform: vscode
11
+ ComputerInfo: XPS 15 9510
12
+ SystemInfo: Windows 11
13
+ Python Version: 3.11
14
+ """
15
+
16
+ import matplotlib as mpl
17
+ import matplotlib.pyplot as plt
18
+ import numpy as np
19
+ from rich import print
20
+
21
# Public names exported by ``from oafuncs.oa_cmap import *``.
__all__ = ["show", "to_color", "create", "create_rgbtxt", "get"]
22
+
23
# ** Visualize colormaps as pseudocolor plots (adapted from the Matplotlib documentation)
def show(colormaps):
    """
    Description:
        Helper function to plot the same random data once per colormap.
    Parameters:
        colormaps : list of colormaps, or a single colormap; each can be a string (registered name) or a colormap object.
    Example:
        cmap = ListedColormap(["darkorange", "gold", "lawngreen", "lightseagreen"])
        show([cmap]); show("viridis"); show(["viridis", "cividis"])
    """
    if isinstance(colormaps, (str, mpl.colors.Colormap)):
        colormaps = [colormaps]
    # A private generator with the fixed seed yields the same sample data on
    # every call without resetting NumPy's global random state for the caller.
    data = np.random.RandomState(19680801).randn(30, 30)
    n = len(colormaps)
    fig, axs = plt.subplots(1, n, figsize=(n * 2 + 2, 3), constrained_layout=True, squeeze=False)
    for ax, cmap in zip(axs.flat, colormaps):
        psm = ax.pcolormesh(data, cmap=cmap, rasterized=True, vmin=-4, vmax=4)
        fig.colorbar(psm, ax=ax)
    plt.show()
44
+
45
+
46
# ** Sample a colormap into a list of colors
def to_color(cmap, n=256):
    """
    Description:
        Sample n evenly spaced colors (positions 0 to 1 inclusive) from a colormap.
    Parameters:
        cmap : str or colormap; anything accepted by matplotlib.colormaps.get_cmap
        n : int, optional; the number of colors (default 256)
    Return:
        list of n colors, each the value returned by calling the colormap on one position
    Example:
        out_colors = to_color('viridis', 256)
    """
    colormap = mpl.colormaps.get_cmap(cmap)
    positions = np.linspace(0, 1, n)
    return list(map(colormap, positions))
62
+
63
+
64
# ** Build a custom multi-color colormap, optionally with explicit positions
def create(colors: list, nodes=None, under=None, over=None):  # quick palette from a list of colors
    """
    Description:
        Create a custom linearly interpolated colormap (named "mycmap")
    Parameters:
        colors : list of colors (any color spec accepted by Matplotlib)
        nodes : list of positions, one per color; None spaces the colors evenly.
            Matplotlib expects the positions to increase from 0.0 to 1.0.
        under : color used for values below the colormap range (optional)
        over : color used for values above the colormap range (optional)
    Return:
        cmap : colormap
    Raises:
        ValueError : if nodes is given and its length differs from colors
    Example:
        cmap = create(['#C2B7F3','#B3BBF2','#B0CBF1','#ACDCF0','#A8EEED'])
        cmap = create(['aliceblue','skyblue','deepskyblue'],[0.0,0.5,1.0])
    """
    if nodes is None:  # spread the colors evenly
        cmap_color = mpl.colors.LinearSegmentedColormap.from_list("mycmap", colors)
    else:  # place each color at its given position
        if len(nodes) != len(colors):
            # zip() would otherwise silently drop the surplus colors or nodes.
            raise ValueError(f"nodes ({len(nodes)}) and colors ({len(colors)}) must have the same length")
        cmap_color = mpl.colors.LinearSegmentedColormap.from_list("mycmap", list(zip(nodes, colors)))
    if under is not None:
        cmap_color.set_under(under)
    if over is not None:
        cmap_color.set_over(over)
    return cmap_color
90
+
91
+
92
# ** Build a colormap from a txt file of RGB values (e.g. a GrADS palette)
def create_rgbtxt(rgbtxt_file, split_mark=','):
    """
    Description
    -----------
    Make a colormap from an RGB txt document. Each line in the txt file is one
    RGB value, separated by split_mark, such as: 251,251,253
    Blank lines are ignored. If any value exceeds 2, all values are treated as
    0-255 levels and divided by 255; otherwise they are used as 0-1 fractions.

    Parameters
    ----------
    rgbtxt_file : str, the path of txt file
    split_mark : str, optional, default is ','; the split mark of rgb value

    Returns
    -------
    cmap : colormap (ListedColormap named "my_color")

    Raises
    ------
    ValueError : if the file has no RGB rows, a row has fewer than three values,
                 or a value is not a number

    Example
    -------
    cmap=create_rgbtxt(path,split_mark=',')

    txt example
    -----------
    251,251,253
    225,125,25
    250,205,255
    """
    rgb_rows = []
    # "utf-8-sig" also drops a leading byte-order mark (written by some Windows
    # editors), which would otherwise make the first value unparseable.
    with open(rgbtxt_file, encoding="utf-8-sig") as fid:
        for line_no, line in enumerate(fid, start=1):
            if not line.strip():
                continue  # tolerate blank lines, e.g. an extra newline at the end of the file
            fields = line.split(split_mark)
            if len(fields) < 3:
                raise ValueError(f"{rgbtxt_file}, line {line_no}: expected three values separated by {split_mark!r}, got {line.strip()!r}")
            rgb_rows.append([float(value) for value in fields[:3]])
    if not rgb_rows:
        raise ValueError(f"No RGB values found in {rgbtxt_file}")
    rgb = np.array(rgb_rows)
    if rgb.max() > 2:  # values look like 0-255 levels; normalize to 0-1
        rgb = rgb / 255.0
    return mpl.colors.ListedColormap(rgb, name="my_color")
131
+
132
+
133
# ** Choose a colormap by name
def get(cmap_name=None, query=False):
    """
    Description:
        Choosing a colormap from the list of available colormaps or a custom colormap
    Parameters:
        cmap_name : str, optional; the name of the colormap
        query : bool, optional; whether to print the available colormap names
    Return:
        cmap : colormap; None when cmap_name is None. An unknown name prints a
               message and returns Matplotlib's "rainbow".
    Example:
        cmap = get('viridis')
        cmap = get('diverging_1')
        cmap = get('cool_1')
        cmap = get('warm_1')
        cmap = get('colorful_1')
    """
    # Factories instead of ready-made colormaps, so a call only builds the one
    # requested (each call still returns a fresh colormap object).
    my_cmap_dict = {
        "diverging_1": lambda: create(["#4e00b3", "#0000FF", "#00c0ff", "#a1d3ff", "#DCDCDC", "#FFD39B", "#FF8247", "#FF0000", "#FF5F9E"]),
        "cool_1": lambda: create(["#4e00b3", "#0000FF", "#00c0ff", "#a1d3ff", "#DCDCDC"]),
        "warm_1": lambda: create(["#DCDCDC", "#FFD39B", "#FF8247", "#FF0000", "#FF5F9E"]),
        "colorful_1": lambda: create(["#6d00db", "#9800cb", "#F2003C", "#ff4500", "#ff7f00", "#FE28A2", "#FFC0CB", "#DDA0DD", "#40E0D0", "#1a66f2", "#00f7fb", "#8fff88", "#E3FF00"]),
    }
    if query:
        print("Available cmap names:")
        print("-" * 20)
        print("Defined by myself:")
        for key in my_cmap_dict:
            print(key)
        print("-" * 20)
        print("Matplotlib built-in:")
        print(list(mpl.colormaps))
        print("-" * 20)

    if cmap_name is None:
        return None

    if cmap_name in my_cmap_dict:
        return my_cmap_dict[cmap_name]()
    try:
        return mpl.colormaps.get_cmap(cmap_name)
    except ValueError:
        # Deliberate best-effort fallback instead of raising.
        print(f"Unknown cmap name: {cmap_name}\nNow return 'rainbow' as default.")
        return mpl.colormaps.get_cmap("rainbow")
197
+
198
+
199
if __name__ == "__main__":
    import os

    # ** Test a custom colormap with explicit node positions
    colors = ["#C2B7F3", "#B3BBF2", "#B0CBF1", "#ACDCF0", "#A8EEED"]
    nodes = [0.0, 0.2, 0.4, 0.6, 1.0]
    c_map = create(colors, nodes)
    show([c_map])

    # ** Test a custom diverging colormap
    diverging_cmap = create(["#4e00b3", "#0000FF", "#00c0ff", "#a1d3ff", "#DCDCDC", "#FFD39B", "#FF8247", "#FF0000", "#FF5F9E"])
    show([diverging_cmap])

    # ** Test building a colormap from an RGB txt file (the path is machine-specific)
    file_path = "E:/python/colorbar/test.txt"
    if os.path.isfile(file_path):
        cmap_rgb = create_rgbtxt(file_path)
        show([cmap_rgb])
    else:
        print(f"Skipping create_rgbtxt demo: {file_path} not found")

    # ** Test converting a colormap to a list of colors
    out_colors = to_color("viridis", 256)
    print(f"to_color returned {len(out_colors)} colors")