oafuncs 0.0.98.19__py3-none-any.whl → 0.0.98.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- oafuncs/_script/data_interp.py +25 -15
- oafuncs/_script/data_interp_geo.py +144 -75
- oafuncs/_script/netcdf_merge.py +27 -28
- oafuncs/_script/netcdf_write.py +57 -7
- oafuncs/_script/parallel.py +4 -1
- oafuncs/oa_data.py +43 -3
- {oafuncs-0.0.98.19.dist-info → oafuncs-0.0.98.20.dist-info}/METADATA +1 -1
- {oafuncs-0.0.98.19.dist-info → oafuncs-0.0.98.20.dist-info}/RECORD +11 -11
- {oafuncs-0.0.98.19.dist-info → oafuncs-0.0.98.20.dist-info}/WHEEL +0 -0
- {oafuncs-0.0.98.19.dist-info → oafuncs-0.0.98.20.dist-info}/licenses/LICENSE.txt +0 -0
- {oafuncs-0.0.98.19.dist-info → oafuncs-0.0.98.20.dist-info}/top_level.txt +0 -0
oafuncs/_script/data_interp.py
CHANGED
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+# coding=utf-8
+"""
+Author: Liu Kun && 16031215@qq.com
+Date: 2025-04-25 16:22:52
+LastEditors: Liu Kun && 16031215@qq.com
+LastEditTime: 2025-04-26 19:21:31
+FilePath: \\Python\\My_Funcs\\OAFuncs\\oafuncs\\_script\\data_interp.py
+Description:
+EditPlatform: vscode
+ComputerInfo: XPS 15 9510
+SystemInfo: Windows 11
+Python Version: 3.12
+"""
+
 from typing import List, Union
 
 import numpy as np
@@ -8,10 +23,10 @@ from oafuncs.oa_tool import PEx
 
 def _interp_single_worker(*args):
     """
-    Single-slice interpolation worker for PEx parallel execution
+    Single-slice interpolation worker for PEx parallel execution.
+    Args: data_slice, origin_points, target_points, interpolation_method, target_shape
     """
     data_slice, origin_points, target_points, interpolation_method, target_shape = args
-
     # Filter out points containing NaN
     valid_mask = ~np.isnan(data_slice.ravel())
     valid_data = data_slice.ravel()[valid_mask]
@@ -21,20 +36,16 @@ def _interp_single_worker(*args):
         return np.full(target_shape, np.nanmean(data_slice))
 
     # Interpolate using the valid data
-    result = griddata(valid_points, valid_data, target_points, method=interpolation_method)
-
+    result = griddata(valid_points, valid_data, target_points, method=interpolation_method).reshape(target_shape)
+    # Fill any points that are still NaN with nearest-neighbour values
+    if np.isnan(result).any():
+        nn = griddata(valid_points, valid_data, target_points, method="nearest").reshape(target_shape)
+        result[np.isnan(result)] = nn[np.isnan(result)]
 
     return result
 
 
-def interp_2d_func(
-    target_x_coordinates: Union[np.ndarray, List[float]],
-    target_y_coordinates: Union[np.ndarray, List[float]],
-    source_x_coordinates: Union[np.ndarray, List[float]],
-    source_y_coordinates: Union[np.ndarray, List[float]],
-    source_data: np.ndarray,
-    interpolation_method: str = "cubic",
-) -> np.ndarray:
+def interp_2d_func(target_x_coordinates: Union[np.ndarray, List[float]], target_y_coordinates: Union[np.ndarray, List[float]], source_x_coordinates: Union[np.ndarray, List[float]], source_y_coordinates: Union[np.ndarray, List[float]], source_data: np.ndarray, interpolation_method: str = "cubic") -> np.ndarray:
     """
     Perform 2D interpolation on the last two dimensions of a multi-dimensional array.
 
@@ -46,7 +57,6 @@ def interp_2d_func(
         source_data (np.ndarray): Multi-dimensional array with the last two dimensions as spatial.
         interpolation_method (str, optional): Interpolation method. Defaults to "cubic".
             >>> optional: 'linear', 'nearest', 'cubic', 'quintic', etc.
-        use_parallel (bool, optional): Enable parallel processing. Defaults to True.
 
     Returns:
         np.ndarray: Interpolated data array.
@@ -60,7 +70,7 @@ def interp_2d_func(
     >>> source_x_coordinates = np.array([7, 8, 9])
     >>> source_y_coordinates = np.array([10, 11, 12])
     >>> source_data = np.random.rand(3, 3)
-    >>> result =
+    >>> result = interp_2d_func(target_x_coordinates, target_y_coordinates, source_x_coordinates, source_y_coordinates, source_data)
     >>> print(result.shape)  # Expected output: (3, 3)
     """
     if len(target_y_coordinates.shape) == 1:
@@ -80,7 +90,7 @@ def interp_2d_func(
         raise ValueError(f"[red]Source data must have at least 2 dimensions, but got {data_dims}.[/red]")
     elif data_dims > 4:
         # Or handle cases with more than 4 dimensions if necessary
-        raise ValueError(f"
+        raise ValueError(f"Source data has {data_dims} dimensions, but this function currently supports only up to 4.")
 
     # Reshape to 4D by adding leading dimensions of size 1 if needed
     num_dims_to_add = 4 - data_dims
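The substantive change above is the nearest-neighbour backfill after the primary griddata pass: cubic interpolation returns NaN outside the convex hull of the valid points, and the second pass patches those cells. A minimal standalone sketch of the same pattern, on toy grids chosen here for illustration rather than the package's real inputs:

import numpy as np
from scipy.interpolate import griddata

# Toy source grid with a NaN hole, mimicking what _interp_single_worker receives.
sx, sy = np.meshgrid(np.arange(5.0), np.arange(5.0))
data = sx + sy
data[2, 2] = np.nan

points = np.column_stack((sx.ravel(), sy.ravel()))
valid = ~np.isnan(data.ravel())

tx, ty = np.meshgrid(np.linspace(0.0, 4.0, 9), np.linspace(0.0, 4.0, 9))
target = np.column_stack((tx.ravel(), ty.ravel()))

# Primary pass: cubic interpolation, reshaped straight to the target grid.
result = griddata(points[valid], data.ravel()[valid], target, method="cubic").reshape(tx.shape)

# Backfill pass: any cell still NaN is patched from a nearest-neighbour pass.
if np.isnan(result).any():
    nn = griddata(points[valid], data.ravel()[valid], target, method="nearest").reshape(tx.shape)
    result[np.isnan(result)] = nn[np.isnan(result)]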
oafuncs/_script/data_interp_geo.py
CHANGED
@@ -1,98 +1,167 @@
 from typing import List, Union
 
 import numpy as np
-from scipy.interpolate import
+from scipy.interpolate import NearestNDInterpolator, griddata
 
 from oafuncs.oa_tool import PEx
-from oafuncs.oa_data import data_clip
-from oafuncs._script.data_interp import _fill_nan_nearest
 
 
-def
+def _normalize_lon(lon, ref_lon):
     """
-
+    Normalize the longitude array lon to the same interval as ref_lon ([-180, 180] or [0, 360]),
+    widening automatically near the longitude seam (e.g. 180/-180, 0/360) to avoid interpolation breaks.
     """
-
-
-
-
-    # Handle NaN
-    if np.isnan(data_slice).any():
-        mask = np.isnan(data_slice)
-        if mask.any():
-            data_slice = _fill_nan_nearest(data_slice)
-    x1d = np.unique(sx[0, :])
-    y1d = np.unique(sy[:, 0])
-    if sx.shape != (len(y1d), len(x1d)) or sy.shape != (len(y1d), len(x1d)):
-        from scipy.interpolate import griddata
-
-        grid_points = np.column_stack((sx.ravel(), sy.ravel()))
-        grid_values = data_slice.ravel()
-        data_slice = griddata(grid_points, grid_values, (x1d[None, :], y1d[:, None]), method="linear")
-    if interpolation_method == "linear":
-        kx = ky = 1
+    lon = np.asarray(lon)
+    ref_lon = np.asarray(ref_lon)
+    if np.nanmax(ref_lon) > 180:
+        lon = np.where(lon < 0, lon + 360, lon)
     else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        lon = np.where(lon > 180, lon - 360, lon)
+    return lon
+
+
+def _expand_lonlat_for_dateline(points, values):
+    """
+    Widen data near the longitude seam (e.g. 180/-180, 0/360) to avoid interpolation breaks.
+    points: (N,2) [lon,lat]
+    values: (N,)
+    Returns the widened points, values
+    """
+    lon = points[:, 0]
+    lat = points[:, 1]
+    expanded_points = [points]
+    expanded_values = [values]
+    if (np.nanmax(lon) > 170) and (np.nanmin(lon) < -170):
+        expanded_points.append(np.column_stack((lon + 360, lat)))
+        expanded_points.append(np.column_stack((lon - 360, lat)))
+        expanded_values.append(values)
+        expanded_values.append(values)
+    if (np.nanmax(lon) > 350) and (np.nanmin(lon) < 10):
+        expanded_points.append(np.column_stack((lon - 360, lat)))
+        expanded_points.append(np.column_stack((lon + 360, lat)))
+        expanded_values.append(values)
+        expanded_values.append(values)
+    points_new = np.vstack(expanded_points)
+    values_new = np.concatenate(expanded_values)
+    return points_new, values_new
+
+
+def _interp_single_worker(*args):
     """
-
-
-
-    After interpolation, automatically clip and fill out-of-range values and NaN with nearest neighbours, the range taken from the original data's nanmin/nanmax
+    Single-slice interpolation worker for PEx parallel execution.
+    Args: data_slice, origin_points, target_points, interpolation_method, target_shape
+    Spherical interpolation: convert lon/lat to spherical coordinates before interpolating
     """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    data_slice, origin_points, target_points, interpolation_method, target_shape = args
+
+    # Normalize longitudes
+    origin_points = origin_points.copy()
+    target_points = target_points.copy()
+    origin_points[:, 0] = _normalize_lon(origin_points[:, 0], target_points[:, 0])
+    target_points[:, 0] = _normalize_lon(target_points[:, 0], origin_points[:, 0])
+
+    def lonlat2xyz(lon, lat):
+        lon_rad = np.deg2rad(lon)
+        lat_rad = np.deg2rad(lat)
+        x = np.cos(lat_rad) * np.cos(lon_rad)
+        y = np.cos(lat_rad) * np.sin(lon_rad)
+        z = np.sin(lat_rad)
+        return np.stack([x, y, z], axis=-1)
+
+    # Filter out points containing NaN
+    valid_mask = ~np.isnan(data_slice.ravel())
+    valid_data = data_slice.ravel()[valid_mask]
+    valid_points = origin_points[valid_mask]
+
+    if len(valid_data) < 10:
+        return np.full(target_shape, np.nanmean(data_slice))
+
+    # Widen across the longitude seam to avoid breaks such as 179/-181
+    valid_points_exp, valid_data_exp = _expand_lonlat_for_dateline(valid_points, valid_data)
+
+    valid_xyz = lonlat2xyz(valid_points_exp[:, 0], valid_points_exp[:, 1])
+    target_xyz = lonlat2xyz(target_points[:, 0], target_points[:, 1])
+
+    # Use griddata's cubic interpolation for smoother results
+    result = griddata(valid_xyz, valid_data_exp, target_xyz, method=interpolation_method).reshape(target_shape)
+
+    # Handle residual NaN with nearest neighbours
+    if np.isnan(result).any():
+        nn_interp = NearestNDInterpolator(valid_xyz, valid_data_exp)
+        nn = nn_interp(target_xyz).reshape(target_shape)
+        result[np.isnan(result)] = nn[np.isnan(result)]
+
+    return result
+
+
+def interp_2d_func_geo(target_x_coordinates: Union[np.ndarray, List[float]], target_y_coordinates: Union[np.ndarray, List[float]], source_x_coordinates: Union[np.ndarray, List[float]], source_y_coordinates: Union[np.ndarray, List[float]], source_data: np.ndarray, interpolation_method: str = "cubic") -> np.ndarray:
+    """
+    Perform 2D interpolation on the last two dimensions of a multi-dimensional array (spherical coordinates).
+    Interpolates in a spherical coordinate system; suited to global-scale geographic data and correctly handles longitudes that cross the dateline.
+
+    Args:
+        target_x_coordinates (Union[np.ndarray, List[float]]): Target grid's longitude (-180 to 180 or 0 to 360).
+        target_y_coordinates (Union[np.ndarray, List[float]]): Target grid's latitude (-90 to 90).
+        source_x_coordinates (Union[np.ndarray, List[float]]): Original grid's longitude (-180 to 180 or 0 to 360).
+        source_y_coordinates (Union[np.ndarray, List[float]]): Original grid's latitude (-90 to 90).
+        source_data (np.ndarray): Multi-dimensional array with the last two dimensions as spatial.
+        interpolation_method (str, optional): Interpolation method. Defaults to "cubic".
+            >>> optional: 'linear', 'nearest', 'cubic', 'quintic', etc.
+
+    Returns:
+        np.ndarray: Interpolated data array.
+
+    Raises:
+        ValueError: If input shapes are invalid.
+
+    Examples:
+        >>> # Example on a global grid
+        >>> target_lon = np.arange(-180, 181, 1)  # 1-degree-resolution target grid
+        >>> target_lat = np.arange(-90, 91, 1)
+        >>> source_lon = np.arange(-180, 181, 5)  # 5-degree-resolution source grid
+        >>> source_lat = np.arange(-90, 91, 5)
+        >>> # Build a simple data field (e.g. a temperature field)
+        >>> source_data = np.cos(np.deg2rad(source_lat.reshape(-1, 1))) * np.cos(np.deg2rad(source_lon))
+        >>> # Interpolate onto the high-resolution grid
+        >>> result = interp_2d_geo(target_lon, target_lat, source_lon, source_lat, source_data)
+        >>> print(result.shape)  # Expected output: (181, 361)
+    """
+    # Validate input data ranges
+    if np.nanmin(target_y_coordinates) < -90 or np.nanmax(target_y_coordinates) > 90:
+        raise ValueError("[red]Target latitude must be in range [-90, 90].[/red]")
+    if np.nanmin(source_y_coordinates) < -90 or np.nanmax(source_y_coordinates) > 90:
+        raise ValueError("[red]Source latitude must be in range [-90, 90].[/red]")
+
+    if len(target_y_coordinates.shape) == 1:
+        target_x_coordinates, target_y_coordinates = np.meshgrid(target_x_coordinates, target_y_coordinates)
+    if len(source_y_coordinates.shape) == 1:
+        source_x_coordinates, source_y_coordinates = np.meshgrid(source_x_coordinates, source_y_coordinates)
+
+    if source_x_coordinates.shape != source_data.shape[-2:] or source_y_coordinates.shape != source_data.shape[-2:]:
+        raise ValueError("[red]Shape of source_data does not match shape of source_x_coordinates or source_y_coordinates.[/red]")
+
+    target_points = np.column_stack((np.array(target_x_coordinates).ravel(), np.array(target_y_coordinates).ravel()))
+    origin_points = np.column_stack((np.array(source_x_coordinates).ravel(), np.array(source_y_coordinates).ravel()))
+
+    data_dims = len(source_data.shape)
     if data_dims < 2:
-        raise ValueError("Source data must have at least 2 dimensions.")
+        raise ValueError(f"[red]Source data must have at least 2 dimensions, but got {data_dims}.[/red]")
     elif data_dims > 4:
-        raise ValueError("Source data has
+        raise ValueError(f"Source data has {data_dims} dimensions, but this function currently supports only up to 4.")
 
     num_dims_to_add = 4 - data_dims
-    new_shape = (1,) * num_dims_to_add +
-
-    t, z, ny, nx = data4d.shape
+    new_shape = (1,) * num_dims_to_add + source_data.shape
+    new_src_data = source_data.reshape(new_shape)
 
-
-    target_shape = ty.shape
+    t, z, y, x = new_src_data.shape
 
-    # Prepare parameters for the parallel run
     params = []
-
-
-
+    target_shape = target_y_coordinates.shape
+    for t_index in range(t):
+        for z_index in range(z):
+            params.append((new_src_data[t_index, z_index], origin_points, target_points, interpolation_method, target_shape))
 
     with PEx() as excutor:
         result = excutor.run(_interp_single_worker, params)
 
-
-    result = np.squeeze(result)
-    return result
+    return np.squeeze(np.array(result).reshape(t, z, *target_shape))
oafuncs/_script/netcdf_merge.py
CHANGED
@@ -1,26 +1,10 @@
-
-# coding=utf-8
-"""
-Author: Liu Kun && 16031215@qq.com
-Date: 2025-03-30 11:16:29
-LastEditors: Liu Kun && 16031215@qq.com
-LastEditTime: 2025-04-25 14:23:10
-FilePath: \\Python\\My_Funcs\\OAFuncs\\oafuncs\\_script\\netcdf_merge.py
-Description
-EditPlatform: vscode
-ComputerInfo: XPS 15 9510
-SystemInfo: Windows 11
-Python Version: 3.12
-"""
-
+import logging
 import os
 from typing import List, Optional, Union
-
 import numpy as np
 import xarray as xr
 
 from oafuncs import pbar
-import logging
 
 
 def merge_nc(file_list: Union[str, List[str]], var_name: Optional[Union[str, List[str]]] = None, dim_name: Optional[str] = None, target_filename: Optional[str] = None) -> None:
@@ -75,26 +59,41 @@ def merge_nc(file_list: Union[str, List[str]], var_name: Optional[Union[str, Lis
         if dim_name in data_var.dims:
             merged_data.setdefault(var, []).append(data_var)
         elif var not in merged_data:
-            #
-            if
-                merged_data[var] = data_var
-
-
+            # Distinguish types at a finer grain, not just time vs non-time
+            if data_var.dtype.kind == "M":  # datetime64 type
+                merged_data[var] = data_var  # time types are handled in save_to_nc
+            elif data_var.dtype.kind in ["f", "i", "u"]:  # numeric types
+                # Fill NaN in numeric variables with -32767, not 0
+                merged_data[var] = data_var.fillna(-32767)
+            else:  # string or other types
+                merged_data[var] = data_var  # keep non-numeric types as-is
 
     for var in pbar(merged_data, description="Merging variables"):
         if isinstance(merged_data[var], list):
-            #
-            if
-                merged_data[var] = xr.concat(merged_data[var], dim=dim_name)
-
-
+            # Finer-grained type handling
+            if merged_data[var][0].dtype.kind == "M":  # datetime64 type
+                merged_data[var] = xr.concat(merged_data[var], dim=dim_name)
+            elif merged_data[var][0].dtype.kind in ["f", "i", "u"]:  # numeric types
+                # Fill NaN with -32767, not 0
+                merged_data[var] = xr.concat(merged_data[var], dim=dim_name).fillna(-32767)
+            else:  # string or other types
+                merged_data[var] = xr.concat(merged_data[var], dim=dim_name)
+
+    # Before building the final dataset, check again that no numeric variable still contains NaN
+    merged_ds = xr.Dataset(merged_data)
+    for var_name in merged_ds.data_vars:
+        var = merged_ds[var_name]
+        if var.dtype.kind in ["f", "i", "u"]:
+            if np.isnan(var.values).any():
+                logging.warning(f"Variable {var_name} still contains NaN after merging; replacing with -32767")
+                merged_ds[var_name] = var.fillna(-32767)
 
     if os.path.exists(target_filename):
         # print("Warning: The target file already exists. Removing it ...")
         logging.warning("The target file already exists. Removing it ...")
         os.remove(target_filename)
 
-    save_to_nc(target_filename,
+    save_to_nc(target_filename, merged_ds)
 
 
 # Example usage
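The merge now branches on dtype.kind so that only numeric variables get the -32767 sentinel, while datetime64 and string variables pass through untouched. A small self-contained illustration of that branching (a toy DataArray, not the package's merge loop):

import numpy as np
import xarray as xr

da = xr.DataArray(np.array([[1.0, np.nan], [np.nan, 4.0]]), dims=("y", "x"))

if da.dtype.kind == "M":                 # datetime64: left for save_to_nc to encode
    merged = da
elif da.dtype.kind in ("f", "i", "u"):   # numeric: NaN becomes the sentinel, not 0
    merged = da.fillna(-32767)
else:                                    # strings and other kinds stay as-is
    merged = da

print(merged.values)  # [[ 1.0, -32767.0], [-32767.0, 4.0]]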
oafuncs/_script/netcdf_write.py
CHANGED
@@ -1,13 +1,55 @@
 import os
+import warnings
 
 import netCDF4 as nc
 import numpy as np
 import xarray as xr
-import warnings
 
 warnings.filterwarnings("ignore", category=RuntimeWarning)
 
 
+def _nan_to_fillvalue(ncfile):
+    """
+    Replace NaN and masked values in all variables of a NetCDF file with their _FillValue attribute (automatically adding _FillValue=-32767 when absent).
+    Also handles invalid values in masked arrays.
+    Only applies to numeric variables (floating-point and integer).
+    """
+    with nc.Dataset(ncfile, "r+") as ds:
+        for var_name in ds.variables:
+            var = ds.variables[var_name]
+            # Only process numeric variables (f: float, i: signed int, u: unsigned int)
+            if var.dtype.kind not in ["f", "i", "u"]:
+                continue
+
+            # Read the data
+            arr = var[:]
+
+            # Determine the fill value
+            if "_FillValue" in var.ncattrs():
+                fill_value = var.getncattr("_FillValue")
+            else:
+                fill_value = -32767
+                try:
+                    var.setncattr("_FillValue", fill_value)
+                except Exception:
+                    # Some variables may not allow adding _FillValue dynamically
+                    continue
+
+            # Handle masked arrays
+            if hasattr(arr, "mask"):
+                # For a masked array, set masked positions to fill_value
+                if np.any(arr.mask):
+                    # Convert to a plain array, filling masked positions with fill_value
+                    arr = np.ma.filled(arr, fill_value=fill_value)
+
+            # Handle remaining NaN values
+            if np.any(np.isnan(arr)):
+                arr = np.nan_to_num(arr, nan=fill_value, posinf=fill_value, neginf=fill_value)
+
+            # Write the data back to the variable
+            var[:] = arr
+
+
 def _numpy_to_nc_type(numpy_type):
     """Map NumPy dtypes to NetCDF data types"""
     numpy_to_nc = {
@@ -30,17 +72,22 @@ def _calculate_scale_and_offset(data, n=16):
     """
     Compute scale_factor and add_offset for numeric data,
     mapping the data into the range [0, 2**n - 1].
-
-    Requires data to be a numeric NumPy array; all-NaN input is not allowed.
+    Also handles NaN values in the data.
     """
     if not isinstance(data, np.ndarray):
         raise ValueError("Input data must be a NumPy array.")
 
-
-
+    # First handle NaN and infinite values in the data
+    clean_data = np.nan_to_num(data, nan=-32767, posinf=np.finfo(float).max, neginf=np.finfo(float).min)
 
-
-
+    # Compute the min/max of valid values (excluding the fill value)
+    valid_mask = clean_data != -32767
+    if np.any(valid_mask):
+        data_min = np.min(clean_data[valid_mask])
+        data_max = np.max(clean_data[valid_mask])
+    else:
+        # All values are fill values; use a default range
+        data_min, data_max = 0, 1
 
     if data_max == data_min:
         scale_factor = 1.0
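For context, this is the conventional packing arithmetic that a scale_factor/add_offset pair implements. The hunk above shows the sentinel cleanup, the valid-range scan, and the degenerate-range branch, but not the final formula, so the mapping below is an assumption about the rest of the function rather than a copy of it:

import numpy as np

def scale_and_offset(data, n=16):
    # Mirror the hunk: NaN/inf collapse to the sentinel before the range scan.
    clean = np.nan_to_num(data, nan=-32767.0,
                          posinf=np.finfo(float).max, neginf=np.finfo(float).min)
    valid = clean != -32767
    if np.any(valid):
        data_min, data_max = clean[valid].min(), clean[valid].max()
    else:
        data_min, data_max = 0.0, 1.0       # all sentinel: default range
    if data_max == data_min:
        return 1.0, data_min                # degenerate range, scale_factor = 1.0
    # Assumed standard mapping of [data_min, data_max] onto [0, 2**n - 1].
    scale_factor = (data_max - data_min) / (2**n - 1)
    return scale_factor, data_min

scale, offset = scale_and_offset(np.array([2.0, np.nan, 12.0]))
print(scale, offset)  # unpacked value = stored * scale + offset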
@@ -111,6 +158,7 @@ def save_to_nc(file, data, varname=None, coords=None, mode="w", scale_offset_swi
             new_da.to_dataset(name=varname).to_netcdf(file, mode=mode, encoding=encoding)
         else:
             data.to_dataset(name=varname).to_netcdf(file, mode=mode)
+        _nan_to_fillvalue(file)  # Replace NaN with _FillValue
         return
 
     else:
@@ -139,6 +187,7 @@ def save_to_nc(file, data, varname=None, coords=None, mode="w", scale_offset_swi
             new_ds.to_netcdf(file, mode=mode, encoding=encoding)
         else:
             new_ds.to_netcdf(file, mode=mode)
+        _nan_to_fillvalue(file)  # Replace NaN with _FillValue
         return
 
     # Handle the plain numpy array case
@@ -172,6 +221,7 @@ def save_to_nc(file, data, varname=None, coords=None, mode="w", scale_offset_swi
             dtype = _numpy_to_nc_type(data.dtype)
             var = ncfile.createVariable(varname, dtype, dims, zlib=False)
             var[:] = data
+        _nan_to_fillvalue(file)  # Replace NaN with _FillValue
     except Exception as e:
         raise RuntimeError(f"netCDF4 save failed: {str(e)}") from e
 
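_nan_to_fillvalue runs as a post-pass over the freshly written file and normalizes two distinct kinds of "missing": masked positions from netCDF4's masked-array reads, and literal NaN/inf values. The two-step cleanup can be shown with plain NumPy (a sketch of the array handling only, with no file I/O):

import numpy as np

fill_value = -32767
arr = np.ma.masked_array([1.0, 2.0, np.nan, 4.0],
                         mask=[False, True, False, False])

# Step 1: masked positions become the fill value.
if hasattr(arr, "mask") and np.any(arr.mask):
    arr = np.ma.filled(arr, fill_value=fill_value)

# Step 2: any remaining NaN or inf collapses to the fill value as well.
if np.any(np.isnan(arr)):
    arr = np.nan_to_num(arr, nan=fill_value, posinf=fill_value, neginf=fill_value)

print(arr)  # [ 1.0, -32767.0, -32767.0, 4.0 ]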
oafuncs/_script/parallel.py
CHANGED
@@ -6,6 +6,7 @@ import time
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
 from typing import Any, Callable, Dict, List, Optional, Tuple
 
+
 import psutil
 
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
@@ -14,6 +15,7 @@ __all__ = ["ParallelExecutor"]
 
 
 class ParallelExecutor:
+
     def __init__(
         self,
         max_workers: Optional[int] = None,
@@ -120,6 +122,7 @@ class ParallelExecutor:
         self._executor = None
 
     def _execute_batch(self, func: Callable, params: List[Tuple], chunk_size: int) -> List[Any]:
+        from oafuncs.oa_tool import pbar
         if not params:
             return []
 
@@ -129,7 +132,7 @@ class ParallelExecutor:
         results = [None] * len(params)
         with self._get_executor() as executor:
             futures = {executor.submit(func, *args): idx for idx, args in enumerate(params)}
-            for future in as_completed(futures):
+            for future in pbar(as_completed(futures), "Parallel Tasks", total=len(futures)):
                 idx = futures[future]
                 try:
                     results[idx] = future.result(timeout=self.timeout_per_task)
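The loop change only wraps as_completed in oafuncs' own pbar for progress reporting; results are still written back through the futures index map, so output order matches submission order regardless of completion order. A dependency-free sketch of the same pattern, with a plain print standing in for pbar:

from concurrent.futures import ThreadPoolExecutor, as_completed

def work(x):
    return x * x

params = [(i,) for i in range(8)]
results = [None] * len(params)

with ThreadPoolExecutor(max_workers=4) as executor:
    futures = {executor.submit(work, *args): idx for idx, args in enumerate(params)}
    for done, future in enumerate(as_completed(futures), start=1):
        results[futures[future]] = future.result()   # index map preserves order
        print(f"Parallel Tasks: {done}/{len(futures)}")

print(results)  # [0, 1, 4, 9, 16, 25, 36, 49] whatever the completion order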
oafuncs/oa_data.py
CHANGED
@@ -23,7 +23,7 @@ from rich import print
 from scipy.interpolate import interp1d
 
 
-__all__ = ["interp_along_dim", "interp_2d", "ensure_list", "mask_shapefile"]
+__all__ = ["interp_along_dim", "interp_2d", "interp_2d_geo", "ensure_list", "mask_shapefile"]
 
 
 def ensure_list(input_value: Any) -> List[str]:
@@ -120,7 +120,7 @@ def interp_2d(
     source_x_coordinates: Union[np.ndarray, List[float]],
     source_y_coordinates: Union[np.ndarray, List[float]],
     source_data: np.ndarray,
-    interpolation_method: str = "
+    interpolation_method: str = "cubic",
 ) -> np.ndarray:
     """
     Perform 2D interpolation on the last two dimensions of a multi-dimensional array.
@@ -132,7 +132,7 @@ def interp_2d(
         source_y_coordinates (Union[np.ndarray, List[float]]): Original grid's y-coordinates.
         source_data (np.ndarray): Multi-dimensional array with the last two dimensions as spatial.
             >>> must be [y, x] or [*, y, x] or [*, *, y, x]
-        interpolation_method (str, optional): Interpolation method. Defaults to "
+        interpolation_method (str, optional): Interpolation method. Defaults to "cubic".
             >>> optional: 'linear', 'nearest', 'cubic', 'quintic', etc.
         use_parallel (bool, optional): Enable parallel processing. Defaults to True.
 
@@ -162,7 +162,47 @@ def interp_2d(
         interpolation_method=interpolation_method,
     )
 
+def interp_2d_geo(target_x_coordinates: Union[np.ndarray, List[float]], target_y_coordinates: Union[np.ndarray, List[float]], source_x_coordinates: Union[np.ndarray, List[float]], source_y_coordinates: Union[np.ndarray, List[float]], source_data: np.ndarray, interpolation_method: str = "cubic") -> np.ndarray:
+    """
+    Perform 2D interpolation on the last two dimensions of a multi-dimensional array (spherical coordinates).
+    Interpolates in a spherical coordinate system; suited to global-scale geographic data and correctly handles longitudes that cross the dateline.
+
+    Args:
+        target_x_coordinates (Union[np.ndarray, List[float]]): Target grid's longitude (-180 to 180 or 0 to 360).
+        target_y_coordinates (Union[np.ndarray, List[float]]): Target grid's latitude (-90 to 90).
+        source_x_coordinates (Union[np.ndarray, List[float]]): Original grid's longitude (-180 to 180 or 0 to 360).
+        source_y_coordinates (Union[np.ndarray, List[float]]): Original grid's latitude (-90 to 90).
+        source_data (np.ndarray): Multi-dimensional array with the last two dimensions as spatial.
+        interpolation_method (str, optional): Interpolation method. Defaults to "cubic".
+            >>> optional: 'linear', 'nearest', 'cubic', 'quintic', etc.
+
+    Returns:
+        np.ndarray: Interpolated data array.
+
+    Raises:
+        ValueError: If input shapes are invalid.
 
+    Examples:
+        >>> # Example on a global grid
+        >>> target_lon = np.arange(-180, 181, 1)  # 1-degree-resolution target grid
+        >>> target_lat = np.arange(-90, 91, 1)
+        >>> source_lon = np.arange(-180, 181, 5)  # 5-degree-resolution source grid
+        >>> source_lat = np.arange(-90, 91, 5)
+        >>> # Build a simple data field (e.g. a temperature field)
+        >>> source_data = np.cos(np.deg2rad(source_lat.reshape(-1, 1))) * np.cos(np.deg2rad(source_lon))
+        >>> # Interpolate onto the high-resolution grid
+        >>> result = interp_2d_geo(target_lon, target_lat, source_lon, source_lat, source_data)
+        >>> print(result.shape)  # Expected output: (181, 361)
+    """
+    from ._script.data_interp_geo import interp_2d_func_geo
+
+    return interp_2d_func_geo(
+        target_x_coordinates=target_x_coordinates,
+        target_y_coordinates=target_y_coordinates,
+        source_x_coordinates=source_x_coordinates,
+        source_y_coordinates=source_y_coordinates,
+        source_data=source_data,
+        interpolation_method=interpolation_method,
+    )
 
 def mask_shapefile(
     data_array: np.ndarray,
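Usage of the new public wrapper, mirroring the docstring example above (a coarse analytic field interpolated onto a 1-degree global grid):

import numpy as np
from oafuncs.oa_data import interp_2d_geo

source_lon = np.arange(-180, 181, 5)   # 5-degree source grid
source_lat = np.arange(-90, 91, 5)
target_lon = np.arange(-180, 181, 1)   # 1-degree target grid
target_lat = np.arange(-90, 91, 1)
source_data = np.cos(np.deg2rad(source_lat.reshape(-1, 1))) * np.cos(np.deg2rad(source_lon))

result = interp_2d_geo(target_lon, target_lat, source_lon, source_lat, source_data)
print(result.shape)  # (181, 361)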
{oafuncs-0.0.98.19.dist-info → oafuncs-0.0.98.20.dist-info}/RECORD
CHANGED
@@ -1,6 +1,6 @@
 oafuncs/__init__.py,sha256=T_-VtnWWllV3Q91twT5Yt2sUapeA051QbPNnBxmg9nw,1456
 oafuncs/oa_cmap.py,sha256=DimWT4Bg7uE5Lx8hSw1REp7whpsR2pFRStAwk1cowEM,11494
-oafuncs/oa_data.py,sha256=
+oafuncs/oa_data.py,sha256=Aat9ktxxRGevaqQya3IJWfXeoEs-FCXGUcNE2pKnzfU,10931
 oafuncs/oa_date.py,sha256=WhM6cyD4G3IeghjLTHhAMtlvJbA7kwQG2sHnxdTgyso,6303
 oafuncs/oa_draw.py,sha256=Wj2QBgyIPpV_dxaDrH10jqj_puK9ZM9rd-si-3VrsrE,17631
 oafuncs/oa_file.py,sha256=j9gXJgPOJsliu4IOUc4bc-luW4yBvQyNCEmMyDVjUwQ,16404
@@ -11,13 +11,13 @@ oafuncs/oa_tool.py,sha256=rpPkLqWhqMmqlCc5wjL8qMTg3gThCkSrYJckbX_0iJc,8631
 oafuncs/_data/hycom.png,sha256=MadKs6Gyj5n9-TOu7L4atQfTXtF9dvN9w-tdU9IfygI,10945710
 oafuncs/_data/oafuncs.png,sha256=o3VD7wm-kwDea5E98JqxXl04_78cBX7VcdUt7uQXGiU,3679898
 oafuncs/_script/cprogressbar.py,sha256=UIgGcLFs-6IgWlITuBLaQqrpt4OAK3Mst5RlCiNfZdQ,15772
-oafuncs/_script/data_interp.py,sha256=
-oafuncs/_script/data_interp_geo.py,sha256=
+oafuncs/_script/data_interp.py,sha256=EiZbt6n5BEaRKcng88UgX7TFPhKE6TLVZniS01awXjg,5146
+oafuncs/_script/data_interp_geo.py,sha256=ZRFb3fKRiYQViZNHd19eW20C9i38BsiIU8w0fG5mbqM,7789
 oafuncs/_script/email.py,sha256=lL4HGKrr524-g0xLlgs-4u7x4-u7DtgNoD9AL8XJKj4,3058
-oafuncs/_script/netcdf_merge.py,sha256=
+oafuncs/_script/netcdf_merge.py,sha256=zasqrFpB2GHAjZ1LkrWtI7kpHu7uCjdIGf4C0lEraYA,4816
 oafuncs/_script/netcdf_modify.py,sha256=sGRUYNhfGgf9JV70rnBzw3bzuTRSXzBTL_RMDnDPeLQ,4552
-oafuncs/_script/netcdf_write.py,sha256=
-oafuncs/_script/parallel.py,sha256=
+oafuncs/_script/netcdf_write.py,sha256=gEqeagAcu0_xa6cIuu7a5d9uXMEaGqnci02KNc5znEY,10672
+oafuncs/_script/parallel.py,sha256=07-BJVHxXJNlrOrhrSGt7qCZiKWq6dBvNDBA1AANYnI,8861
 oafuncs/_script/parallel_test.py,sha256=0GBqZOX7IaCOKF2t1y8N8YYu53GJ33OkfsWgpvZNqM4,372
 oafuncs/_script/plot_dataset.py,sha256=zkSEnO_-biyagorwWXPoihts_cwuvripzEt-l9bHJ2E,13989
 oafuncs/_script/replace_file_content.py,sha256=eCFZjnZcwyRvy6b4mmIfBna-kylSZTyJRfgXd6DdCjk,5982
@@ -39,8 +39,8 @@ oafuncs/oa_sign/__init__.py,sha256=QKqTFrJDFK40C5uvk48GlRRbGFzO40rgkYwu6dYxatM,5
 oafuncs/oa_sign/meteorological.py,sha256=8091SHo2L8kl4dCFmmSH5NGVHDku5i5lSiLEG5DLnOQ,6489
 oafuncs/oa_sign/ocean.py,sha256=xrW-rWD7xBWsB5PuCyEwQ1Q_RDKq2KCLz-LOONHgldU,5932
 oafuncs/oa_sign/scientific.py,sha256=a4JxOBgm9vzNZKpJ_GQIQf7cokkraV5nh23HGbmTYKw,5064
-oafuncs-0.0.98.
-oafuncs-0.0.98.
-oafuncs-0.0.98.
-oafuncs-0.0.98.
-oafuncs-0.0.98.
+oafuncs-0.0.98.20.dist-info/licenses/LICENSE.txt,sha256=rMtLpVg8sKiSlwClfR9w_Dd_5WubTQgoOzE2PDFxzs4,1074
+oafuncs-0.0.98.20.dist-info/METADATA,sha256=QgOUX1DWWv2bLiGauzfcJsKtuGQfyJkinMJoCO7FRdI,4273
+oafuncs-0.0.98.20.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
+oafuncs-0.0.98.20.dist-info/top_level.txt,sha256=bgC35QkXbN4EmPHEveg_xGIZ5i9NNPYWqtJqaKqTPsQ,8
+oafuncs-0.0.98.20.dist-info/RECORD,,
File without changes
|
File without changes
|
File without changes
|