oafuncs 0.0.97.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
oafuncs/oa_draw.py ADDED
@@ -0,0 +1,326 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ """
4
+ Author: Liu Kun && 16031215@qq.com
5
+ Date: 2024-09-17 17:26:11
6
+ LastEditors: Liu Kun && 16031215@qq.com
7
+ LastEditTime: 2024-11-21 13:10:47
8
+ FilePath: \\Python\\My_Funcs\\OAFuncs\\oafuncs\\oa_draw.py
9
+ Description:
10
+ EditPlatform: vscode
11
+ ComputerInfo: XPS 15 9510
12
+ SystemInfo: Windows 11
13
+ Python Version: 3.11
14
+ """
15
+
16
+
17
+ import warnings
18
+
19
+ import cartopy.crs as ccrs
20
+ import cartopy.feature as cfeature
21
+ import matplotlib as mpl
22
+ import matplotlib.pyplot as plt
23
+ import numpy as np
24
+ import xarray as xr
25
+ from cartopy.mpl.ticker import LatitudeFormatter, LongitudeFormatter
26
+ from rich import print
27
+
28
# Public API of this module (what `from oafuncs.oa_draw import *` exports).
__all__ = [
    "fig_minus",
    "gif",
    "add_cartopy",
    "add_gridlines",
    "MidpointNormalize",
    "add_lonlat_unit",
    "contour",
    "contourf",
    "quiver",
]

# NOTE(review): this installs an "ignore everything" warnings filter process-wide
# as soon as the module is imported, not only for warnings raised here.
warnings.simplefilter("ignore")
31
+
32
+
33
def fig_minus(ax_x=None, ax_y=None, cbar=None, decimal=None, add_space=False):
    """
    Description: replace the ASCII hyphen in tick labels with the Unicode minus sign (U+2212).

    Exactly one target is processed. If several are given, the first non-None one in the
    order ax_x -> ax_y -> cbar is used both for reading the ticks and for writing the labels.

    param {*} ax_x : Axes whose x-axis tick labels are rewritten
    param {*} ax_y : Axes whose y-axis tick labels are rewritten
    param {*} cbar : Colorbar whose tick labels are rewritten
    param {*} decimal : number of decimal places for the labels; a zero tick is always written as "0".
                        None keeps the plain str() form of each tick value.
    param {*} add_space : if True, prefix a space to the labels from the first non-negative tick onward
                          (presumes the ticks are in ascending order)

    return {*} the target that was modified (ax_x, ax_y or cbar)

    raises ValueError : if none of ax_x, ax_y, cbar is given
    """
    if ax_x is not None:
        target, ticks, set_labels = ax_x, ax_x.get_xticks(), ax_x.set_xticklabels
    elif ax_y is not None:
        target, ticks, set_labels = ax_y, ax_y.get_yticks(), ax_y.set_yticklabels
    elif cbar is not None:
        target, ticks, set_labels = cbar, cbar.get_ticks(), cbar.set_ticklabels
    else:
        raise ValueError("fig_minus: one of ax_x, ax_y or cbar must be given")

    if decimal is not None:
        labels = [f"{val:.{decimal}f}" if val != 0 else "0" for val in ticks]
    else:
        labels = [f"{val}" for val in ticks]
    labels = [label.replace("-", "\u2212") for label in labels]

    if add_space:
        # Index of the first non-negative tick; if every tick is negative nothing gets a space
        # (previously the default index 0 prefixed a space to all of the negative labels).
        first_nonneg = next((i for i, val in enumerate(ticks) if val >= 0), len(labels))
        labels[first_nonneg:] = [" " + label for label in labels[first_nonneg:]]

    set_labels(labels)
    return target
76
+
77
+
78
# ** Build an animated GIF from existing image files
def gif(image_list: list, gif_name: str, duration=0.2):  # default frame duration 0.2
    """
    Description
        Assemble a sequence of image files into an animated GIF.
    Parameters
        image_list : list, paths of the frame images, in playback order
        gif_name : str, output file name of the GIF
        duration : float, forwarded unchanged to imageio's ``mimsave`` as the per-frame duration
                   (NOTE(review): the unit, seconds vs. milliseconds, depends on the installed
                   imageio version — confirm)
    Returns
        None
    Example
        gif(["1.png", "2.png"], "test.gif", duration=0.2)
    """
    import imageio.v2 as imageio

    frames = [imageio.imread(path) for path in image_list]
    imageio.mimsave(gif_name, frames, format="GIF", duration=duration)
    print("Gif制作完成!")
100
+
101
+
102
# ** Convert longitude/latitude values into labelled strings (e.g. "120.00°E")
def add_lonlat_unit(lon=None, lat=None, decimal=2):
    """
    Description: format longitude/latitude values as strings with a hemisphere suffix.

    param {*} lon : longitude value(s): scalar, list or array. Values > 180 are shifted by -360
                    (so 200 becomes 160°W). 0 and 180 get no E/W suffix.
    param {*} lat : latitude value(s): scalar, list or array. 0 gets no N/S suffix.
    param {*} decimal : number of decimal places
    return {*} for each of lon/lat: a list of strings, or a single string when it holds exactly
               one value. If both are given, a tuple (lon_labels, lat_labels); if neither is given
               (None or empty), None.
    example : add_lonlat_unit(lon=lon, lat=lat, decimal=2)
    """

    def _as_iterable(values):
        # Wrap scalars so they can be iterated; sequences/arrays are kept as-is so the
        # element types (and therefore the output formatting) are unchanged.
        return values if np.ndim(values) > 0 else [values]

    def _given(values):
        # Size check instead of truthiness: `if array` raises for multi-element numpy
        # arrays, and would treat the scalar 0 as "not given".
        return values is not None and np.size(values) > 0

    def _format_longitude(x_list):
        out_list = []
        for x in _as_iterable(x_list):
            if x > 180:
                x -= 360
            degrees = round(abs(x), decimal)
            direction = "E" if x >= 0 else "W"
            out_list.append(f"{degrees:.{decimal}f}°{direction}" if x != 0 and x != 180 else f"{degrees}°")
        return out_list if len(out_list) > 1 else out_list[0]

    def _format_latitude(y_list):
        out_list = []
        for y in _as_iterable(y_list):
            if y > 90:
                # NOTE(review): kept from the original; shifting latitudes > 90 by 180 looks
                # unusual — confirm the intended convention.
                y -= 180
            degrees = round(abs(y), decimal)
            direction = "N" if y >= 0 else "S"
            out_list.append(f"{degrees:.{decimal}f}°{direction}" if y != 0 else f"{degrees}°")
        return out_list if len(out_list) > 1 else out_list[0]

    has_lon, has_lat = _given(lon), _given(lat)
    if has_lon and has_lat:
        return _format_longitude(lon), _format_latitude(lat)
    if has_lon:
        return _format_longitude(lon)
    if has_lat:
        return _format_latitude(lat)
    return None
140
+
141
+
142
# ** Add labelled gridlines to a map axes
def add_gridlines(ax, projection=ccrs.PlateCarree(), color="k", alpha=0.5, linestyle="--", linewidth=0.5):
    """
    Description: draw gridlines with longitude/latitude labels on the left and bottom edges only.

    param {*} ax : axes providing ``gridlines`` (presumably a cartopy GeoAxes)
    param {*} projection : CRS in which the gridlines are drawn
    param {*} color : gridline colour
    param {*} alpha : gridline transparency
    param {*} linestyle : gridline style
    param {*} linewidth : gridline width
    return {*} (ax, gl): the axes and the gridliner returned by ``ax.gridlines``
    """
    line_style = {"linewidth": linewidth, "color": color, "alpha": alpha, "linestyle": linestyle}
    gridliner = ax.gridlines(crs=projection, draw_labels=True, **line_style)

    # Keep labels only on the left and bottom edges
    gridliner.top_labels = False
    gridliner.right_labels = False

    gridliner.xformatter = LongitudeFormatter(zero_direction_label=False)
    gridliner.yformatter = LatitudeFormatter()
    return ax, gridliner
152
+
153
+
154
# ** Add land/ocean/coastline features (and optionally gridlines and extent) to a map
def add_cartopy(ax, lon=None, lat=None, projection=ccrs.PlateCarree(), gridlines=True, landcolor="lightgrey", oceancolor="lightblue", cartopy_linewidth=0.5):
    """
    Description: decorate a map axes with land, ocean and coastlines, optional labelled
    gridlines and lon/lat tick formatters. If both lon and lat are given, the map extent
    is set to their range.

    param {*} ax : axes with ``add_feature``/``set_extent`` (presumably a cartopy GeoAxes)
    param {*} lon : longitudes (scalar, list or array); NaNs are ignored when computing the extent
    param {*} lat : latitudes (scalar, list or array); NaNs are ignored when computing the extent
    param {*} projection : CRS used for the gridlines and for interpreting lon/lat in set_extent
    param {*} gridlines : whether to draw labelled gridlines (via add_gridlines)
    param {*} landcolor : face colour of land
    param {*} oceancolor : face colour of the ocean
    param {*} cartopy_linewidth : coastline line width
    return {*} None
    """
    ax.add_feature(cfeature.LAND, facecolor=landcolor)
    ax.add_feature(cfeature.OCEAN, facecolor=oceancolor)
    ax.add_feature(cfeature.COASTLINE, linewidth=cartopy_linewidth)

    if gridlines:
        add_gridlines(ax, projection)

    ax.xaxis.set_major_formatter(LongitudeFormatter(zero_direction_label=False))
    ax.yaxis.set_major_formatter(LatitudeFormatter())

    if lon is not None and lat is not None:
        # np.asarray accepts plain lists (which have no .min()), and nanmin/nanmax
        # keep NaN fill values from propagating into the extent.
        lon_arr = np.asarray(lon, dtype=float)
        lat_arr = np.asarray(lat, dtype=float)
        extent = [
            float(np.nanmin(lon_arr)),
            float(np.nanmax(lon_arr)),
            float(np.nanmin(lat_arr)),
            float(np.nanmax(lat_arr)),
        ]
        ax.set_extent(extent, crs=projection)
177
+
178
+
179
# ** Normalization that puts a chosen centre value (0 by default) at the middle of the colormap
class MidpointNormalize(mpl.colors.Normalize):
    """
    Description: piecewise-linear normalization that maps vmin -> 0, vcenter -> 0.5 and
    vmax -> 1, so that vcenter (0 when not given) sits at the centre of the colormap.
    Values below vmin / above vmax map to -inf / +inf (unless clipped), so the colormap's
    under/over colours are used.

    param {*} vmin : lower data limit; taken from the data on the first call when None
    param {*} vmax : upper data limit; taken from the data on the first call when None
    param {*} vcenter : data value mapped to 0.5; treated as 0 when None
    param {*} clip : if True, normalized values are clipped to [0, 1]

    NOTE: assumes vmin <= vcenter <= vmax, because np.interp needs increasing sample points.

    Example:
        norm = MidpointNormalize(vmin=-2, vmax=1, vcenter=0)
    """

    def __init__(self, vmin=None, vmax=None, vcenter=None, clip=False):
        self.vcenter = vcenter
        super().__init__(vmin, vmax, clip)

    def _anchors(self):
        """Return ([vmin, vcenter, vmax], [0, 0.5, 1]), the control points of the mapping."""
        vcenter = 0.0 if self.vcenter is None else self.vcenter
        return [self.vmin, vcenter, self.vmax], [0.0, 0.5, 1.0]

    def __call__(self, value, clip=None):
        data = np.ma.asarray(value)
        # Fill in vmin/vmax from the data when they were not given (as Normalize does);
        # otherwise np.interp would receive None.
        self.autoscale_None(data)
        x, y = self._anchors()
        # np.interp ignores masks, so interpolate the raw values and re-apply the input mask.
        result = np.interp(np.ma.getdata(data), x, y, left=-np.inf, right=np.inf)
        if clip is None:
            clip = self.clip
        if clip:
            result = np.clip(result, 0.0, 1.0)
        return np.ma.masked_array(result, mask=np.ma.getmask(data))

    def inverse(self, value):
        x, y = self._anchors()
        # Map normalized values [0, 0.5, 1] back to [vmin, vcenter, vmax].
        return np.interp(value, y, x, left=-np.inf, right=np.inf)
202
+
203
+
204
+ # -----------------------------------------------------------------------------------------------------------------------------------------------------------------
205
+
206
# ** Filled contour plot
def contourf(data, x=None, y=None, cmap="coolwarm", show=True, store=None, cartopy=False):
    """
    Description: draw a filled contour plot of a 2-D field with a colorbar, then save or show it.

    param {*} data : 2-D data, indexed as data[y, x]
    param {*} x : x coordinates (longitudes when cartopy=True); defaults to the column indices
    param {*} y : y coordinates (latitudes when cartopy=True); defaults to the row indices
    param {*} cmap : colormap
    param {*} show : whether to display the figure when it is not saved
    param {*} store : output path; when given the figure is saved (dpi=600) instead of shown
    param {*} cartopy : draw on a PlateCarree map with land/ocean/coastlines (via add_cartopy)

    return {*} None
    """
    data = np.array(data)
    # Fill in only the missing coordinate; a user-supplied axis is kept.
    if x is None:
        x = np.arange(data.shape[1])
    if y is None:
        y = np.arange(data.shape[0])

    if cartopy:
        _, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()})
        add_cartopy(ax, lon=x, lat=y)
        cs = ax.contourf(x, y, data, transform=ccrs.PlateCarree(), cmap=cmap)
    else:
        cs = plt.contourf(x, y, data, cmap=cmap)

    # Pass the contour set explicitly: pyplot.contourf registers its result as the current
    # image, but Axes.contourf does not, so a bare plt.colorbar() fails in the cartopy branch.
    plt.colorbar(cs)

    if store:
        plt.savefig(store, dpi=600, bbox_inches="tight")
    elif show:
        plt.show()
    plt.close()
234
+
235
+
236
# ** Contour-line plot
def contour(data, x=None, y=None, cmap="coolwarm", show=True, store=None, cartopy=False):
    """
    Description: draw labelled contour lines of a 2-D field, then save or show the figure.

    param {*} data : 2-D data, indexed as data[y, x]
    param {*} x : x coordinates (longitudes when cartopy=True); defaults to the column indices
    param {*} y : y coordinates (latitudes when cartopy=True); defaults to the row indices
    param {*} cmap : colormap for the contour lines
    param {*} show : whether to display the figure when it is not saved
    param {*} store : output path; when given the figure is saved (dpi=600) instead of shown
    param {*} cartopy : draw on a PlateCarree map with land/ocean/coastlines (via add_cartopy)

    return {*} None
    """
    data = np.array(data)
    # Fill in only the missing coordinate; a user-supplied axis is kept.
    if x is None:
        x = np.arange(data.shape[1])
    if y is None:
        y = np.arange(data.shape[0])

    if cartopy:
        _, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()})
        add_cartopy(ax, lon=x, lat=y)
        cr = ax.contour(x, y, data, transform=ccrs.PlateCarree(), cmap=cmap)
    else:
        cr = plt.contour(x, y, data, cmap=cmap)

    # Label the contour levels on the lines themselves
    plt.clabel(cr, inline=True, fontsize=10)

    if store:
        plt.savefig(store, dpi=600, bbox_inches="tight")
    elif show:
        plt.show()
    plt.close()
264
+
265
+
266
# ** Vector (arrow) field plot
def quiver(u, v, lon, lat, picname=None, cmap="coolwarm", scale=0.25, width=0.002, x_space=5, y_space=5):
    """
    Description: draw a vector field with arrows coloured by speed, a colorbar, and a
    reference arrow (quiverkey) whose length equals the mean speed.

    param {*} u : 2-D x-component, indexed [y, x]
    param {*} v : 2-D y-component, same shape as u
    param {*} lon : x coordinates (longitude), 1-D (combined with lat via meshgrid) or 2-D
    param {*} lat : y coordinates (latitude), 1-D or 2-D
    param {*} picname : output path; if None the figure is shown instead of saved
    param {*} cmap : colormap for the speed colouring (default coolwarm)
    param {*} scale : matplotlib quiver ``scale`` (a smaller value gives longer arrows)
    param {*} width : matplotlib quiver ``width`` (arrow shaft width)
    param {*} x_space : plot every x_space-th point along x
    param {*} y_space : plot every y_space-th point along y
    return {*} None
    """
    # asanyarray accepts lists / DataArrays and keeps numpy masked arrays masked.
    u = np.asanyarray(u)
    v = np.asanyarray(v)
    lon = np.asanyarray(lon)
    lat = np.asanyarray(lat)

    # Build 2-D coordinate grids when 1-D axes are given
    if lon.ndim == 1 and lat.ndim == 1:
        lon_c, lat_c = np.meshgrid(lon, lat)
    else:
        lon_c, lat_c = lon, lat

    speed = np.hypot(u, v)
    # Mean speed ignoring NaN/inf and masked points; used as the reference-arrow length.
    # (The previous xr.DataArray(...).nanmean() call does not exist and always raised.)
    mean_speed = float(np.ma.masked_invalid(speed).mean())

    sub = np.s_[::y_space, ::x_space]  # thin out the arrows
    quiver_plot = plt.quiver(
        lon_c[sub],
        lat_c[sub],
        u[sub],
        v[sub],
        speed[sub],  # colour each arrow by its speed
        pivot="middle",
        scale=scale,
        cmap=cmap,
        width=width,
    )
    plt.quiverkey(quiver_plot, X=0.87, Y=0.975, U=mean_speed, label=f"{mean_speed:.2f} m/s", labelpos="E", fontproperties={"size": 10})
    plt.colorbar(quiver_plot)
    plt.xlabel("X")
    plt.ylabel("Y")

    if picname is not None:
        plt.savefig(picname, bbox_inches="tight")
    else:
        plt.show()
    plt.clf()
    plt.close()
322
+
323
+
324
+
325
if __name__ == "__main__":
    # Nothing to run: this module only provides plotting helpers.
    ...