oafuncs 0.0.79__py2.py3-none-any.whl → 0.0.81__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. oafuncs/__init__.py +25 -7
  2. oafuncs/oa_cmap.py +31 -52
  3. oafuncs/oa_down/hycom_3hourly.py +68 -25
  4. oafuncs/oa_down/test_ua.py +151 -0
  5. oafuncs/oa_nc.py +120 -10
  6. oafuncs/oa_s/__init__.py +23 -0
  7. oafuncs/oa_s/oa_cmap.py +163 -0
  8. oafuncs/oa_s/oa_data.py +187 -0
  9. oafuncs/oa_s/oa_draw.py +451 -0
  10. oafuncs/oa_s/oa_file.py +332 -0
  11. oafuncs/oa_s/oa_help.py +39 -0
  12. oafuncs/oa_s/oa_nc.py +410 -0
  13. oafuncs/oa_s/oa_python.py +107 -0
  14. oafuncs - 副本/__init__.py +26 -0
  15. oafuncs - 副本/oa_cmap.py +163 -0
  16. oafuncs - 副本/oa_data.py +187 -0
  17. oafuncs - 副本/oa_down/__init__.py +20 -0
  18. oafuncs - 副本/oa_down/hycom_3hourly.py +1176 -0
  19. oafuncs - 副本/oa_down/literature.py +332 -0
  20. oafuncs - 副本/oa_down/test_ua.py +151 -0
  21. oafuncs - 副本/oa_draw.py +451 -0
  22. oafuncs - 副本/oa_file.py +332 -0
  23. oafuncs - 副本/oa_help.py +39 -0
  24. oafuncs - 副本/oa_nc.py +410 -0
  25. oafuncs - 副本/oa_python.py +107 -0
  26. oafuncs - 副本/oa_sign/__init__.py +21 -0
  27. oafuncs - 副本/oa_sign/meteorological.py +168 -0
  28. oafuncs - 副本/oa_sign/ocean.py +158 -0
  29. oafuncs - 副本/oa_sign/scientific.py +139 -0
  30. oafuncs - 副本/oa_tool/__init__.py +18 -0
  31. oafuncs - 副本/oa_tool/email.py +114 -0
  32. {oafuncs-0.0.79.dist-info → oafuncs-0.0.81.dist-info}/METADATA +1 -2
  33. oafuncs-0.0.81.dist-info/RECORD +51 -0
  34. oafuncs-0.0.79.dist-info/RECORD +0 -24
  35. {oafuncs-0.0.79.dist-info → oafuncs-0.0.81.dist-info}/LICENSE.txt +0 -0
  36. {oafuncs-0.0.79.dist-info → oafuncs-0.0.81.dist-info}/WHEEL +0 -0
  37. {oafuncs-0.0.79.dist-info → oafuncs-0.0.81.dist-info}/top_level.txt +0 -0
oafuncs/oa_nc.py CHANGED
@@ -19,7 +19,7 @@ import netCDF4 as nc
19
19
  import numpy as np
20
20
  import xarray as xr
21
21
 
22
- __all__ = ["get_var", "extract5nc", "write2nc", "merge5nc", "modify_var_value", "modify_var_attr", "rename_var_or_dim", "check_ncfile"]
22
+ __all__ = ["get_var", "extract5nc", "write2nc", "merge5nc", "modify_var_value", "modify_var_attr", "rename_var_or_dim", "check_ncfile", "longitude_change", "nc_isel"]
23
23
 
24
24
 
25
25
  def get_var(file, *vars):
@@ -38,7 +38,7 @@ def get_var(file, *vars):
38
38
  return datas
39
39
 
40
40
 
41
- def extract5nc(file, varname):
41
+ def extract5nc(file, varname, only_value=True):
42
42
  """
43
43
  描述:
44
44
  1、提取nc文件中的变量
@@ -47,16 +47,22 @@ def extract5nc(file, varname):
47
47
  参数:
48
48
  file: 文件路径
49
49
  varname: 变量名
50
+ only_value: 变量和维度是否只保留数值
50
51
  example: data, dimdict = extract5nc(file_ecm, 'h')
51
52
  """
52
53
  ds = xr.open_dataset(file)
53
54
  vardata = ds[varname]
55
+ ds.close()
54
56
  dims = vardata.dims
55
57
  dimdict = {}
56
58
  for dim in dims:
57
- dimdict[dim] = ds[dim].values
58
- ds.close()
59
- return np.array(vardata), dimdict
59
+ if only_value:
60
+ dimdict[dim] = vardata[dim].values
61
+ else:
62
+ dimdict[dim] = ds[dim]
63
+ if only_value:
64
+ vardata = np.array(vardata)
65
+ return vardata, dimdict
60
66
 
61
67
 
62
68
  def _numpy_to_nc_type(numpy_type):
@@ -76,15 +82,27 @@ def _numpy_to_nc_type(numpy_type):
76
82
  return numpy_to_nc.get(str(numpy_type), "f4") # 默认使用 'float32'
77
83
 
78
84
 
79
- def write2nc(file, data, varname, coords, mode):
85
+ def _calculate_scale_and_offset(data, n=16):
86
+ data_min, data_max = np.nanmin(data), np.nanmax(data)
87
+ scale_factor = (data_max - data_min) / (2 ** n - 1)
88
+ add_offset = data_min + 2 ** (n - 1) * scale_factor
89
+ # S = Q * scale_factor + add_offset
90
+ return scale_factor, add_offset
91
+
92
+
93
+ def write2nc(file, data, varname=None, coords=None, mode='w', scale_offset_switch=True, compile_switch=True):
80
94
  """
81
95
  description: 写入数据到nc文件
96
+
82
97
  参数:
83
98
  file: 文件路径
84
99
  data: 数据
85
100
  varname: 变量名
86
101
  coords: 坐标,字典,键为维度名称,值为坐标数据
87
102
  mode: 写入模式,'w'为写入,'a'为追加
103
+ scale_offset_switch: 是否使用scale_factor和add_offset,默认为True
104
+ compile_switch: 是否使用压缩参数,默认为True
105
+
88
106
  example: write2nc(r'test.nc', data, 'data', {'time': np.linspace(0, 120, 100), 'lev': np.linspace(0, 120, 50)}, 'a')
89
107
  """
90
108
  # 判断mode是写入还是追加
@@ -96,6 +114,21 @@ def write2nc(file, data, varname, coords, mode):
96
114
  if not os.path.exists(file):
97
115
  print("Warning: File doesn't exist. Creating a new file.")
98
116
  mode = "w"
117
+
118
+ complete = False
119
+ if varname is None and coords is None:
120
+ try:
121
+ data.to_netcdf(file)
122
+ complete = True
123
+ # 不能在这里return
124
+ except AttributeError:
125
+ raise ValueError("If varname and coords are None, data must be a DataArray.")
126
+
127
+ if complete:
128
+ return
129
+
130
+ kwargs = {'zlib': True, 'complevel': 4} # 压缩参数
131
+ # kwargs = {"compression": 'zlib', "complevel": 4} # 压缩参数
99
132
 
100
133
  # 打开 NetCDF 文件
101
134
  with nc.Dataset(file, mode, format="NETCDF4") as ncfile:
@@ -116,8 +149,17 @@ def write2nc(file, data, varname, coords, mode):
116
149
  if add_coords:
117
150
  # 创建新坐标
118
151
  ncfile.createDimension(dim, len(coord_data))
119
- ncfile.createVariable(dim, _numpy_to_nc_type(coord_data.dtype), (dim,))
152
+ if compile_switch:
153
+ ncfile.createVariable(dim, _numpy_to_nc_type(coord_data.dtype), (dim,), **kwargs)
154
+ else:
155
+ ncfile.createVariable(dim, _numpy_to_nc_type(coord_data.dtype), (dim,))
120
156
  ncfile.variables[dim][:] = np.array(coord_data)
157
+
158
+ if isinstance(coord_data, xr.DataArray):
159
+ current_var = ncfile.variables[dim]
160
+ if coord_data.attrs:
161
+ for key, value in coord_data.attrs.items():
162
+ current_var.setncattr(key, value)
121
163
 
122
164
  # 判断变量是否存在,若存在,则删除原变量
123
165
  add_var = True
@@ -127,22 +169,48 @@ def write2nc(file, data, varname, coords, mode):
127
169
  raise ValueError("Shape of data does not match the variable shape.")
128
170
  else:
129
171
  # 写入数据
130
- ncfile.variables[varname][:] = data
172
+ ncfile.variables[varname][:] = np.array(data)
131
173
  add_var = False
132
174
  print(f"Warning: Variable '{varname}' already exists. Replacing it.")
133
175
 
134
176
  if add_var:
135
177
  # 创建变量及其维度
136
178
  dim_names = tuple(coords.keys()) # 使用coords传入的维度名称
137
- ncfile.createVariable(varname, _numpy_to_nc_type(data.dtype), dim_names)
179
+ if scale_offset_switch:
180
+ scale_factor, add_offset = _calculate_scale_and_offset(np.array(data))
181
+ _FillValue = -32767
182
+ missing_value = -32767
183
+ dtype = 'i2' # short类型
184
+ else:
185
+ dtype = _numpy_to_nc_type(data.dtype)
186
+
187
+ if compile_switch:
188
+ ncfile.createVariable(varname, dtype, dim_names, **kwargs)
189
+ else:
190
+ ncfile.createVariable(varname, dtype, dim_names)
191
+
192
+ if scale_offset_switch: # 需要在写入数据之前设置scale_factor和add_offset
193
+ ncfile.variables[varname].setncattr('scale_factor', scale_factor)
194
+ ncfile.variables[varname].setncattr('add_offset', add_offset)
195
+ ncfile.variables[varname].setncattr('_FillValue', _FillValue)
196
+ ncfile.variables[varname].setncattr('missing_value', missing_value)
197
+
138
198
  # ncfile.createVariable('data', 'f4', ('time','lev'))
139
199
 
140
200
  # 写入数据
141
- ncfile.variables[varname][:] = data
201
+ ncfile.variables[varname][:] = np.array(data)
142
202
 
143
203
  # 判断维度是否匹配
144
204
  if len(data.shape) != len(coords):
145
205
  raise ValueError("Number of dimensions does not match the data shape.")
206
+ # 判断data是否带有属性信息,如果有,写入属性信息
207
+ if isinstance(data, xr.DataArray):
208
+ current_var = ncfile.variables[varname]
209
+ if data.attrs:
210
+ for key, value in data.attrs.items():
211
+ if key in ["scale_factor", "add_offset", "_FillValue", "missing_value"] and scale_offset_switch:
212
+ continue
213
+ current_var.setncattr(key, value)
146
214
 
147
215
 
148
216
  def merge5nc(file_list, var_name=None, dim_name=None, target_filename=None):
@@ -330,6 +398,48 @@ def check_ncfile(ncfile, if_delete=False):
330
398
  return False
331
399
 
332
400
 
401
+ def longitude_change(ds, lon_name="longitude", to_which="180"):
402
+ """
403
+ 将经度转换为 -180 到 180 之间
404
+
405
+ 参数:
406
+ lon (numpy.ndarray): 经度数组
407
+
408
+ 返回值:
409
+ numpy.ndarray: 转换后的经度数组
410
+ """
411
+ # return (lon + 180) % 360 - 180
412
+ # ds = ds.assign_coords(longitude=(((ds.longitude + 180) % 360) - 180)).sortby("longitude")
413
+ if to_which == "180":
414
+ # ds = ds.assign_coords(**{lon_name: (((ds[lon_name] + 180) % 360) - 180)}).sortby(lon_name)
415
+ ds = ds.assign_coords(**{lon_name: (ds[lon_name] + 180) % 360 - 180}).sortby(lon_name)
416
+ elif to_which == "360":
417
+ # -180 to 180 to 0 to 360
418
+ ds = ds.assign_coords(**{lon_name: (ds[lon_name] + 360) % 360}).sortby(lon_name)
419
+ return ds
420
+
421
+
422
+ def nc_isel(ncfile, dim_name, slice_list):
423
+ """
424
+ Description: Choose the data by the index of the dimension
425
+
426
+ Parameters:
427
+ ncfile: str, the path of the netCDF file
428
+ dim_name: str, the name of the dimension
429
+ slice_list: list, the index of the dimension
430
+
431
+ slice_list example: slice_list = [[y*12+m for m in range(11,14)] for y in range(84)]
432
+ or
433
+ slice_list = [y * 12 + m for y in range(84) for m in range(11, 14)]
434
+ """
435
+ ds = xr.open_dataset(ncfile)
436
+ slice_list = np.array(slice_list).flatten()
437
+ slice_list = [int(i) for i in slice_list]
438
+ ds_new = ds.isel(**{dim_name: slice_list})
439
+ ds.close()
440
+ return ds_new
441
+
442
+
333
443
  if __name__ == "__main__":
334
444
  data = np.random.rand(100, 50)
335
445
  write2nc(r"test.nc", data, "data", {"time": np.linspace(0, 120, 100), "lev": np.linspace(0, 120, 50)}, "a")
oafuncs/oa_s/__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ """
4
+ Author: Liu Kun && 16031215@qq.com
5
+ Date: 2024-09-17 16:09:20
6
+ LastEditors: Liu Kun && 16031215@qq.com
7
+ LastEditTime: 2024-12-13 10:47:40
8
+ FilePath: \\Python\\My_Funcs\\OAFuncs\\oafuncs\\oa_s\\__init__.py
9
+ Description:
10
+ EditPlatform: vscode
11
+ ComputerInfo: XPS 15 9510
12
+ SystemInfo: Windows 11
13
+ Python Version: 3.11
14
+ """
15
+
16
+ # 会导致OAFuncs直接导入所有函数,不符合模块化设计
17
+ from .oa_cmap import *
18
+ from .oa_data import *
19
+ from .oa_draw import *
20
+ from .oa_file import *
21
+ from .oa_help import *
22
+ from .oa_nc import *
23
+ from .oa_python import *
oafuncs/oa_s/oa_cmap.py ADDED
@@ -0,0 +1,163 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ """
4
+ Author: Liu Kun && 16031215@qq.com
5
+ Date: 2024-09-17 16:55:11
6
+ LastEditors: Liu Kun && 16031215@qq.com
7
+ LastEditTime: 2024-11-21 13:14:24
8
+ FilePath: \\Python\\My_Funcs\\OAFuncs\\oafuncs\\oa_cmap.py
9
+ Description:
10
+ EditPlatform: vscode
11
+ ComputerInfo: XPS 15 9510
12
+ SystemInfo: Windows 11
13
+ Python Version: 3.11
14
+ """
15
+
16
+ import matplotlib as mpl
17
+ import matplotlib.pyplot as plt
18
+ import numpy as np
19
+
20
+ __all__ = ["show", "cmap2colors", "create_cmap", "create_cmap_rgbtxt", "choose_cmap"]
21
+
22
+ # ** 将cmap用填色图可视化(官网摘抄函数)
23
+ def show(colormaps: list):
24
+ """
25
+ Helper function to plot data with associated colormap.
26
+ example:
27
+ cmap = ListedColormap(["darkorange", "gold", "lawngreen", "lightseagreen"])
28
+ show([cmap])
29
+ """
30
+ np.random.seed(19680801)
31
+ data = np.random.randn(30, 30)
32
+ n = len(colormaps)
33
+ fig, axs = plt.subplots(1, n, figsize=(n * 2 + 2, 3), constrained_layout=True, squeeze=False)
34
+ for [ax, cmap] in zip(axs.flat, colormaps):
35
+ psm = ax.pcolormesh(data, cmap=cmap, rasterized=True, vmin=-4, vmax=4)
36
+ fig.colorbar(psm, ax=ax)
37
+ plt.show()
38
+
39
+
40
+ # ** 将cmap转为list,即多个颜色的列表
41
+ def cmap2colors(cmap, n=256):
42
+ """
43
+ cmap : cmap名称
44
+ n : 提取颜色数量
45
+ return : 提取的颜色列表
46
+ example : out_colors = cmap2colors('viridis', 256)
47
+ """
48
+ c_map = mpl.colormaps.get_cmap(cmap)
49
+ out_colors = [c_map(i) for i in np.linspace(0, 1, n)]
50
+ return out_colors
51
+
52
+
53
+ # ** 自制cmap,多色,可带位置
54
+ def create_cmap(colors: list, nodes=None, under=None, over=None): # 利用颜色快速配色
55
+ """
56
+ func : 自制cmap,自动确定颜色位置(等比例)
57
+ description : colors可以是颜色名称,也可以是十六进制颜色代码
58
+ param {*} colors 颜色
59
+ param {*} nodes 颜色位置,默认不提供,等间距
60
+ return {*} cmap
61
+ example : cmap = create_cmap(['#C2B7F3','#B3BBF2','#B0CBF1','#ACDCF0','#A8EEED'])
62
+ cmap = create_cmap(['aliceblue','skyblue','deepskyblue'],[0.0,0.5,1.0])
63
+ """
64
+ if nodes is None: # 采取自动分配比例
65
+ cmap_color = mpl.colors.LinearSegmentedColormap.from_list("mycmap", colors)
66
+ else: # 按照提供比例分配
67
+ cmap_color = mpl.colors.LinearSegmentedColormap.from_list("mycmap", list(zip(nodes, colors)))
68
+ if under is not None:
69
+ cmap_color.set_under(under)
70
+ if over is not None:
71
+ cmap_color.set_over(over)
72
+ return cmap_color
73
+
74
+
75
+ # ** 根据RGB的txt文档制作色卡(利用Grads调色盘)
76
+ def create_cmap_rgbtxt(rgbtxt_file,split_mark=','): # 根据RGB的txt文档制作色卡/根据rgb值制作
77
+ """
78
+ func : 根据RGB的txt文档制作色卡
79
+ description : rgbtxt_file='E:/python/colorbar/test.txt'
80
+ param {*} rgbtxt_file txt文件路径
81
+ return {*} camp
82
+ example : cmap=create_cmap_rgbtxt(path,split_mark=',') #
83
+
84
+ txt example : 251,251,253
85
+ 225,125,25
86
+ 250,205,255
87
+ """
88
+ with open(rgbtxt_file) as fid:
89
+ data = fid.readlines()
90
+ n = len(data)
91
+ rgb = np.zeros((n, 3))
92
+ for i in np.arange(n):
93
+ rgb[i][0] = data[i].split(split_mark)[0]
94
+ rgb[i][1] = data[i].split(split_mark)[1]
95
+ rgb[i][2] = data[i].split(split_mark)[2]
96
+ max_rgb = np.max(rgb)
97
+ if max_rgb > 2: # 如果rgb值大于2,则认为是0-255的值,需要归一化
98
+ rgb = rgb / 255.0
99
+ icmap = mpl.colors.ListedColormap(rgb, name="my_color")
100
+ return icmap
101
+
102
+
103
+ def choose_cmap(cmap_name=None, query=False):
104
+ """
105
+ description: Choosing a colormap from the list of available colormaps or a custom colormap
106
+ param {*} cmap_name:
107
+ param {*} query:
108
+ return {*}
109
+ """
110
+
111
+ my_cmap_dict = {
112
+ "diverging_1": create_cmap(["#4e00b3", "#0000FF", "#00c0ff", "#a1d3ff", "#DCDCDC", "#FFD39B", "#FF8247", "#FF0000", "#FF5F9E"]),
113
+ "cold_1": create_cmap(["#4e00b3", "#0000FF", "#00c0ff", "#a1d3ff", "#DCDCDC"]),
114
+ "warm_1": create_cmap(["#DCDCDC", "#FFD39B", "#FF8247", "#FF0000", "#FF5F9E"]),
115
+ # "land_1": create_custom(["#3E6436", "#678A59", "#91A176", "#B8A87D", "#D9CBB2"], under="#A6CEE3", over="#FFFFFF"), # 陆地颜色从深绿到浅棕,表示从植被到沙地的递减
116
+ # "ocean_1": create_custom(["#126697", "#2D88B3", "#4EA1C9", "#78B9D8", "#A6CEE3"], under="#8470FF", over="#3E6436"), # 海洋颜色从深蓝到浅蓝,表示从深海到浅海的递减
117
+ # "ocean_land_1": create_custom(
118
+ # [
119
+ # "#126697", # 深蓝(深海)
120
+ # "#2D88B3", # 蓝
121
+ # "#4EA1C9", # 蓝绿
122
+ # "#78B9D8", # 浅蓝(浅海)
123
+ # "#A6CEE3", # 浅蓝(近岸)
124
+ # "#AAAAAA", # 灰色(0值,海平面)
125
+ # "#D9CBB2", # 沙质土壤色(陆地开始)
126
+ # "#B8A87D", # 浅棕
127
+ # "#91A176", # 浅绿
128
+ # "#678A59", # 中绿
129
+ # "#3E6436", # 深绿(高山)
130
+ # ]
131
+ # ),
132
+ "colorful_1": create_cmap(["#6d00db", "#9800cb", "#F2003C", "#ff4500", "#ff7f00", "#FE28A2", "#FFC0CB", "#DDA0DD", "#40E0D0", "#1a66f2", "#00f7fb", "#8fff88", "#E3FF00"]),
133
+ }
134
+ if query:
135
+ for key, _ in my_cmap_dict.items():
136
+ print(key)
137
+
138
+ if cmap_name in my_cmap_dict:
139
+ return my_cmap_dict[cmap_name]
140
+ else:
141
+ try:
142
+ return mpl.colormaps.get_cmap(cmap_name)
143
+ except ValueError:
144
+ raise ValueError(f"Unknown cmap name: {cmap_name}")
145
+
146
+
147
+ if __name__ == "__main__":
148
+ # ** 测试自制cmap
149
+ colors = ["#C2B7F3", "#B3BBF2", "#B0CBF1", "#ACDCF0", "#A8EEED"]
150
+ nodes = [0.0, 0.2, 0.4, 0.6, 1.0]
151
+ c_map = create_cmap(colors, nodes)
152
+ show([c_map])
153
+
154
+ # ** 测试自制diverging型cmap
155
+ diverging_cmap = create_cmap(["#4e00b3", "#0000FF", "#00c0ff", "#a1d3ff", "#DCDCDC", "#FFD39B", "#FF8247", "#FF0000", "#FF5F9E"])
156
+ show([diverging_cmap])
157
+
158
+ # ** 测试根据RGB的txt文档制作色卡
159
+ file_path = "E:/python/colorbar/test.txt"
160
+ cmap_rgb = create_cmap_rgbtxt(file_path)
161
+
162
+ # ** 测试将cmap转为list
163
+ out_colors = cmap2colors("viridis", 256)
oafuncs/oa_s/oa_data.py ADDED
@@ -0,0 +1,187 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ '''
4
+ Author: Liu Kun && 16031215@qq.com
5
+ Date: 2024-09-17 17:12:47
6
+ LastEditors: Liu Kun && 16031215@qq.com
7
+ LastEditTime: 2024-11-21 13:13:20
8
+ FilePath: \\Python\\My_Funcs\\OAFuncs\\oafuncs\\oa_data.py
9
+ Description:
10
+ EditPlatform: vscode
11
+ ComputerInfo: XPS 15 9510
12
+ SystemInfo: Windows 11
13
+ Python Version: 3.11
14
+ '''
15
+
16
+
17
+ import multiprocessing as mp
18
+ from concurrent.futures import ThreadPoolExecutor
19
+
20
+ import numpy as np
21
+ from scipy.interpolate import griddata
22
+
23
+ __all__ = ['interp_2d', 'interp_2d_parallel']
24
+
25
+ # ** 高维插值函数,插值最后两个维度
26
+
27
+
28
+ def interp_2d(target_x, target_y, origin_x, origin_y, data, method='linear'):
29
+ """
30
+ 高维插值函数,默认插值最后两个维度,传输数据前请确保数据的维度正确
31
+ 参数:
32
+ target_y (array-like): 目标经度网格 1D 或 2D
33
+ target_x (array-like): 目标纬度网格 1D 或 2D
34
+ origin_y (array-like): 初始经度网格 1D 或 2D
35
+ origin_x (array-like): 初始纬度网格 1D 或 2D
36
+ data (array-like): 数据 (*, lat, lon) 2D, 3D, 4D
37
+ method (str, optional): 插值方法,可选 'linear', 'nearest', 'cubic' 等,默认为 'linear'
38
+ 返回:
39
+ array-like: 插值结果
40
+ """
41
+
42
+ # 确保目标网格和初始网格都是二维的
43
+ if len(target_y.shape) == 1:
44
+ target_x, target_y = np.meshgrid(target_x, target_y)
45
+ if len(origin_y.shape) == 1:
46
+ origin_x, origin_y = np.meshgrid(origin_x, origin_y)
47
+
48
+ dims = data.shape
49
+ len_dims = len(dims)
50
+ # print(dims[-2:])
51
+ # 根据经纬度网格判断输入数据的形状是否匹配
52
+
53
+ if origin_x.shape != dims[-2:] or origin_y.shape != dims[-2:]:
54
+ print(origin_x.shape, dims[-2:])
55
+ raise ValueError('Shape of data does not match shape of origin_x or origin_y.')
56
+
57
+ # 将目标网格展平成一维数组
58
+ target_points = np.column_stack((np.ravel(target_y), np.ravel(target_x)))
59
+
60
+ # 将初始网格展平成一维数组
61
+ origin_points = np.column_stack((np.ravel(origin_y), np.ravel(origin_x)))
62
+
63
+ # 进行插值
64
+ if len_dims == 2:
65
+ interpolated_data = griddata(origin_points, np.ravel(data), target_points, method=method)
66
+ interpolated_data = np.reshape(interpolated_data, target_y.shape)
67
+ elif len_dims == 3:
68
+ interpolated_data = []
69
+ for i in range(dims[0]):
70
+ dt = griddata(origin_points, np.ravel(data[i, :, :]), target_points, method=method)
71
+ interpolated_data.append(np.reshape(dt, target_y.shape))
72
+ print(f'Interpolating {i+1}/{dims[0]}...')
73
+ interpolated_data = np.array(interpolated_data)
74
+ elif len_dims == 4:
75
+ interpolated_data = []
76
+ for i in range(dims[0]):
77
+ interpolated_data.append([])
78
+ for j in range(dims[1]):
79
+ dt = griddata(origin_points, np.ravel(data[i, j, :, :]), target_points, method=method)
80
+ interpolated_data[i].append(np.reshape(dt, target_y.shape))
81
+ print(f'\rInterpolating {i*dims[1]+j+1}/{dims[0]*dims[1]}...', end='')
82
+ print('\n')
83
+ interpolated_data = np.array(interpolated_data)
84
+
85
+ return interpolated_data
86
+
87
+
88
+ # ** 高维插值函数,插值最后两个维度,使用多线程进行插值
89
+ # 在本地电脑上可以提速三倍左右,超算上暂时无法加速
90
+ def interp_2d_parallel(target_x, target_y, origin_x, origin_y, data, method='linear'):
91
+ '''
92
+ param {*} target_x 目标经度网格 1D 或 2D
93
+ param {*} target_y 目标纬度网格 1D 或 2D
94
+ param {*} origin_x 初始经度网格 1D 或 2D
95
+ param {*} origin_y 初始纬度网格 1D 或 2D
96
+ param {*} data 数据 (*, lat, lon) 2D, 3D, 4D
97
+ param {*} method 插值方法,可选 'linear', 'nearest', 'cubic' 等,默认为 'linear'
98
+ return {*} 插值结果
99
+ description : 高维插值函数,默认插值最后两个维度,传输数据前请确保数据的维度正确
100
+ example : interpolated_data = interp_2d_parallel(target_x, target_y, origin_x, origin_y, data, method='linear')
101
+ '''
102
+ def interp_single2d(target_y, target_x, origin_y, origin_x, data, method='linear'):
103
+ target_points = np.column_stack((np.ravel(target_y), np.ravel(target_x)))
104
+ origin_points = np.column_stack((np.ravel(origin_y), np.ravel(origin_x)))
105
+
106
+ dt = griddata(origin_points, np.ravel(data[:, :]), target_points, method=method)
107
+ return np.reshape(dt, target_y.shape)
108
+
109
+ def interp_single3d(i, target_y, target_x, origin_y, origin_x, data, method='linear'):
110
+ target_points = np.column_stack((np.ravel(target_y), np.ravel(target_x)))
111
+ origin_points = np.column_stack((np.ravel(origin_y), np.ravel(origin_x)))
112
+
113
+ dt = griddata(origin_points, np.ravel(data[i, :, :]), target_points, method=method)
114
+ return np.reshape(dt, target_y.shape)
115
+
116
+ def interp_single4d(i, j, target_y, target_x, origin_y, origin_x, data, method='linear'):
117
+ target_points = np.column_stack((np.ravel(target_y), np.ravel(target_x)))
118
+ origin_points = np.column_stack((np.ravel(origin_y), np.ravel(origin_x)))
119
+
120
+ dt = griddata(origin_points, np.ravel(data[i, j, :, :]), target_points, method=method)
121
+ return np.reshape(dt, target_y.shape)
122
+
123
+ if len(target_y.shape) == 1:
124
+ target_x, target_y = np.meshgrid(target_x, target_y)
125
+ if len(origin_y.shape) == 1:
126
+ origin_x, origin_y = np.meshgrid(origin_x, origin_y)
127
+
128
+ dims = data.shape
129
+ len_dims = len(dims)
130
+
131
+ if origin_x.shape != dims[-2:] or origin_y.shape != dims[-2:]:
132
+ raise ValueError('数据形状与 origin_x 或 origin_y 的形状不匹配.')
133
+
134
+ interpolated_data = []
135
+
136
+ # 使用多线程进行插值
137
+ with ThreadPoolExecutor(max_workers=mp.cpu_count()-2) as executor:
138
+ print(f'Using {mp.cpu_count()-2} threads...')
139
+ if len_dims == 2:
140
+ interpolated_data = list(executor.map(interp_single2d, [target_y], [target_x], [origin_y], [origin_x], [data], [method]))
141
+ elif len_dims == 3:
142
+ interpolated_data = list(executor.map(interp_single3d, [i for i in range(dims[0])], [target_y]*dims[0], [target_x]*dims[0], [origin_y]*dims[0], [origin_x]*dims[0], [data]*dims[0], [method]*dims[0]))
143
+ elif len_dims == 4:
144
+ interpolated_data = list(executor.map(interp_single4d, [i for i in range(dims[0]) for j in range(dims[1])], [j for i in range(dims[0]) for j in range(dims[1])], [target_y]*dims[0]*dims[1], [target_x]*dims[0]*dims[1], [origin_y]*dims[0]*dims[1], [origin_x]*dims[0]*dims[1], [data]*dims[0]*dims[1], [method]*dims[0]*dims[1]))
145
+ interpolated_data = np.array(interpolated_data).reshape(dims[0], dims[1], target_y.shape[0], target_x.shape[1])
146
+
147
+ interpolated_data = np.array(interpolated_data)
148
+
149
+ return interpolated_data
150
+
151
+
152
+ if __name__ == '__main__':
153
+ import time
154
+
155
+ import matplotlib.pyplot as plt
156
+
157
+ # 测试数据
158
+ origin_x = np.linspace(0, 10, 11)
159
+ origin_y = np.linspace(0, 10, 11)
160
+ target_x = np.linspace(0, 10, 101)
161
+ target_y = np.linspace(0, 10, 101)
162
+ data = np.random.rand(11, 11)
163
+
164
+ # 高维插值
165
+ origin_x = np.linspace(0, 10, 11)
166
+ origin_y = np.linspace(0, 10, 11)
167
+ target_x = np.linspace(0, 10, 101)
168
+ target_y = np.linspace(0, 10, 101)
169
+ data = np.random.rand(10, 10, 11, 11)
170
+
171
+ start = time.time()
172
+ interpolated_data = interp_2d(target_x, target_y, origin_x, origin_y, data)
173
+ print(f'Interpolation time: {time.time()-start:.2f}s')
174
+
175
+ print(interpolated_data.shape)
176
+
177
+ # 高维插值多线程
178
+ start = time.time()
179
+ interpolated_data = interp_2d_parallel(target_x, target_y, origin_x, origin_y, data)
180
+ print(f'Interpolation time: {time.time()-start:.2f}s')
181
+
182
+ print(interpolated_data.shape)
183
+ print(interpolated_data[0, 0, :, :].shape)
184
+ plt.figure()
185
+ plt.contourf(target_x, target_y, interpolated_data[0, 0, :, :])
186
+ plt.colorbar()
187
+ plt.show()