oafuncs 0.0.98.19__py3-none-any.whl → 0.0.98.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
oafuncs/_script/data_interp.py CHANGED
@@ -1,3 +1,18 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+ """
+ Author: Liu Kun && 16031215@qq.com
+ Date: 2025-04-25 16:22:52
+ LastEditors: Liu Kun && 16031215@qq.com
+ LastEditTime: 2025-04-26 19:21:31
+ FilePath: \\Python\\My_Funcs\\OAFuncs\\oafuncs\\_script\\data_interp.py
+ Description:
+ EditPlatform: vscode
+ ComputerInfo: XPS 15 9510
+ SystemInfo: Windows 11
+ Python Version: 3.12
+ """
+
  from typing import List, Union

  import numpy as np
@@ -8,10 +23,10 @@ from oafuncs.oa_tool import PEx

  def _interp_single_worker(*args):
      """
-     Single-slice interpolation worker for PEx parallelism; arguments are (t, z, source_data, origin_points, target_points, interpolation_method, target_shape)
+     Single-slice interpolation worker for PEx parallelism.
+     Arguments: data_slice, origin_points, target_points, interpolation_method, target_shape
      """
      data_slice, origin_points, target_points, interpolation_method, target_shape = args
-
      # Filter out points containing NaN
      valid_mask = ~np.isnan(data_slice.ravel())
      valid_data = data_slice.ravel()[valid_mask]
@@ -21,20 +36,16 @@ def _interp_single_worker(*args):
          return np.full(target_shape, np.nanmean(data_slice))

      # Interpolate using the valid data
-     result = griddata(valid_points, valid_data, target_points, method=interpolation_method)
-     result = result.reshape(target_shape)
+     result = griddata(valid_points, valid_data, target_points, method=interpolation_method).reshape(target_shape)
+     # Fill points that are still NaN using nearest-neighbor interpolation
+     if np.isnan(result).any():
+         nn = griddata(valid_points, valid_data, target_points, method="nearest").reshape(target_shape)
+         result[np.isnan(result)] = nn[np.isnan(result)]

      return result


- def interp_2d_func(
-     target_x_coordinates: Union[np.ndarray, List[float]],
-     target_y_coordinates: Union[np.ndarray, List[float]],
-     source_x_coordinates: Union[np.ndarray, List[float]],
-     source_y_coordinates: Union[np.ndarray, List[float]],
-     source_data: np.ndarray,
-     interpolation_method: str = "cubic",
- ) -> np.ndarray:
+ def interp_2d_func(target_x_coordinates: Union[np.ndarray, List[float]], target_y_coordinates: Union[np.ndarray, List[float]], source_x_coordinates: Union[np.ndarray, List[float]], source_y_coordinates: Union[np.ndarray, List[float]], source_data: np.ndarray, interpolation_method: str = "cubic") -> np.ndarray:
      """
      Perform 2D interpolation on the last two dimensions of a multi-dimensional array.

@@ -46,7 +57,6 @@ def interp_2d_func(
          source_data (np.ndarray): Multi-dimensional array with the last two dimensions as spatial.
          interpolation_method (str, optional): Interpolation method. Defaults to "cubic".
              >>> optional: 'linear', 'nearest', 'cubic', 'quintic', etc.
-         use_parallel (bool, optional): Enable parallel processing. Defaults to True.

      Returns:
          np.ndarray: Interpolated data array.
@@ -60,7 +70,7 @@
          >>> source_x_coordinates = np.array([7, 8, 9])
          >>> source_y_coordinates = np.array([10, 11, 12])
          >>> source_data = np.random.rand(3, 3)
-         >>> result = interp_2d(target_x_coordinates, target_y_coordinates, source_x_coordinates, source_y_coordinates, source_data)
+         >>> result = interp_2d_func(target_x_coordinates, target_y_coordinates, source_x_coordinates, source_y_coordinates, source_data)
          >>> print(result.shape)  # Expected output: (3, 3)
      """
      if len(target_y_coordinates.shape) == 1:
@@ -80,7 +90,7 @@
          raise ValueError(f"[red]Source data must have at least 2 dimensions, but got {data_dims}.[/red]")
      elif data_dims > 4:
          # Or handle cases with more than 4 dimensions if necessary
-         raise ValueError(f"[red]Source data has {data_dims} dimensions, but this function currently supports only up to 4.[/red]")
+         raise ValueError(f"Source data has {data_dims} dimensions, but this function currently supports only up to 4.")

      # Reshape to 4D by adding leading dimensions of size 1 if needed
      num_dims_to_add = 4 - data_dims
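
A note on the rewritten worker above: griddata leaves NaN wherever a target point falls outside the convex hull of the valid source points, so the new code runs a second, nearest-neighbor pass to patch those gaps. A minimal standalone sketch of that cubic-then-nearest pattern (the toy grid and variable names are illustrative, not from the package):

    import numpy as np
    from scipy.interpolate import griddata

    # Toy 5x5 source grid with one NaN hole
    sx, sy = np.meshgrid(np.arange(5.0), np.arange(5.0))
    data = (sx + sy).astype(float)
    data[2, 2] = np.nan

    points = np.column_stack((sx.ravel(), sy.ravel()))
    valid = ~np.isnan(data.ravel())

    # Denser 9x9 target grid over the same extent
    tx, ty = np.meshgrid(np.linspace(0, 4, 9), np.linspace(0, 4, 9))
    targets = np.column_stack((tx.ravel(), ty.ravel()))

    # Pass 1: cubic interpolation over the valid points only
    result = griddata(points[valid], data.ravel()[valid], targets, method="cubic").reshape(tx.shape)

    # Pass 2: nearest-neighbor fill for anything still NaN
    if np.isnan(result).any():
        nn = griddata(points[valid], data.ravel()[valid], targets, method="nearest").reshape(tx.shape)
        result[np.isnan(result)] = nn[np.isnan(result)]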
oafuncs/_script/data_interp_geo.py CHANGED
@@ -1,98 +1,167 @@
  from typing import List, Union

  import numpy as np
- from scipy.interpolate import RectBivariateSpline
+ from scipy.interpolate import NearestNDInterpolator, griddata

  from oafuncs.oa_tool import PEx
- from oafuncs.oa_data import data_clip
- from oafuncs._script.data_interp import _fill_nan_nearest


- def _interp_single_worker(*args):
+ def _normalize_lon(lon, ref_lon):
      """
-     Single-slice interpolation worker; arguments are (data_slice, sx, sy, tx, ty, interpolation_method, data_min, data_max)
+     Normalize the longitude array lon to the same interval as ref_lon ([-180,180] or [0,360]),
+     widening automatically near the longitude seam (e.g. 180/-180, 0/360) to avoid interpolation breaks.
      """
-     # Compatible with the PEx calling convention: args may be a tuple or a list
-     if len(args) == 1 and isinstance(args[0], (tuple, list)):
-         args = args[0]
-     data_slice, sx, sy, tx, ty, interpolation_method, data_min, data_max = args
-     # Handle NaN
-     if np.isnan(data_slice).any():
-         mask = np.isnan(data_slice)
-         if mask.any():
-             data_slice = _fill_nan_nearest(data_slice)
-     x1d = np.unique(sx[0, :])
-     y1d = np.unique(sy[:, 0])
-     if sx.shape != (len(y1d), len(x1d)) or sy.shape != (len(y1d), len(x1d)):
-         from scipy.interpolate import griddata
-
-         grid_points = np.column_stack((sx.ravel(), sy.ravel()))
-         grid_values = data_slice.ravel()
-         data_slice = griddata(grid_points, grid_values, (x1d[None, :], y1d[:, None]), method="linear")
-     if interpolation_method == "linear":
-         kx = ky = 1
+     lon = np.asarray(lon)
+     ref_lon = np.asarray(ref_lon)
+     if np.nanmax(ref_lon) > 180:
+         lon = np.where(lon < 0, lon + 360, lon)
      else:
-         kx = ky = 3
-     interp_func = RectBivariateSpline(y1d, x1d, data_slice, kx=kx, ky=ky)
-     out = interp_func(ty[:, 0], tx[0, :])
-     # Refined clipping logic: set out-of-range points to NaN, then fill with fill_nan_nearest
-     arr = np.asarray(out)
-     arr = data_clip(arr,data_min,data_max)
-     return arr
-
-
- def interp_2d_geo(
-     target_x_coordinates: Union[np.ndarray, List[float]],
-     target_y_coordinates: Union[np.ndarray, List[float]],
-     source_x_coordinates: Union[np.ndarray, List[float]],
-     source_y_coordinates: Union[np.ndarray, List[float]],
-     source_data: np.ndarray,
-     interpolation_method: str = "cubic",
- ) -> np.ndarray:
+         lon = np.where(lon > 180, lon - 360, lon)
+     return lon
+
+
+ def _expand_lonlat_for_dateline(points, values):
+     """
+     Widen data near the longitude seam (e.g. 180/-180, 0/360) to avoid interpolation breaks.
+     points: (N,2) [lon,lat]
+     values: (N,)
+     Returns the widened points, values
+     """
+     lon = points[:, 0]
+     lat = points[:, 1]
+     expanded_points = [points]
+     expanded_values = [values]
+     if (np.nanmax(lon) > 170) and (np.nanmin(lon) < -170):
+         expanded_points.append(np.column_stack((lon + 360, lat)))
+         expanded_points.append(np.column_stack((lon - 360, lat)))
+         expanded_values.append(values)
+         expanded_values.append(values)
+     if (np.nanmax(lon) > 350) and (np.nanmin(lon) < 10):
+         expanded_points.append(np.column_stack((lon - 360, lat)))
+         expanded_points.append(np.column_stack((lon + 360, lat)))
+         expanded_values.append(values)
+         expanded_values.append(values)
+     points_new = np.vstack(expanded_points)
+     values_new = np.concatenate(expanded_values)
+     return points_new, values_new
+
+
+ def _interp_single_worker(*args):
      """
-     Smoother 2D interpolation using RectBivariateSpline for a bicubic effect; interface compatible with interp_2d.
-     Supports 2D/3D/4D input; the last two dimensions are spatial.
-     interpolation_method: "cubic" (default, bicubic), "linear" (bilinear)
-     After interpolation, out-of-range values and NaN are automatically clipped and filled by nearest neighbor; the range is the nanmin/nanmax of the original data
+     Single-slice interpolation worker for PEx parallelism.
+     Arguments: data_slice, origin_points, target_points, interpolation_method, target_shape
+     Spherical interpolation: convert lon/lat to Cartesian coordinates on the sphere before interpolating
      """
-     # Ensure inputs are ndarrays
-     tx = np.asarray(target_x_coordinates)
-     ty = np.asarray(target_y_coordinates)
-     sx = np.asarray(source_x_coordinates)
-     sy = np.asarray(source_y_coordinates)
-     data = np.asarray(source_data)
-
-     if ty.ndim == 1:
-         tx, ty = np.meshgrid(tx, ty)
-     if sy.ndim == 1:
-         sx, sy = np.meshgrid(sx, sy)
-
-     if sx.shape != data.shape[-2:] or sy.shape != data.shape[-2:]:
-         raise ValueError("Shape of source_data does not match shape of source_x_coordinates or source_y_coordinates.")
-
-     data_dims = data.ndim
+     data_slice, origin_points, target_points, interpolation_method, target_shape = args
+
+     # Normalize longitudes
+     origin_points = origin_points.copy()
+     target_points = target_points.copy()
+     origin_points[:, 0] = _normalize_lon(origin_points[:, 0], target_points[:, 0])
+     target_points[:, 0] = _normalize_lon(target_points[:, 0], origin_points[:, 0])
+
+     def lonlat2xyz(lon, lat):
+         lon_rad = np.deg2rad(lon)
+         lat_rad = np.deg2rad(lat)
+         x = np.cos(lat_rad) * np.cos(lon_rad)
+         y = np.cos(lat_rad) * np.sin(lon_rad)
+         z = np.sin(lat_rad)
+         return np.stack([x, y, z], axis=-1)
+
+     # Filter out points containing NaN
+     valid_mask = ~np.isnan(data_slice.ravel())
+     valid_data = data_slice.ravel()[valid_mask]
+     valid_points = origin_points[valid_mask]
+
+     if len(valid_data) < 10:
+         return np.full(target_shape, np.nanmean(data_slice))
+
+     # Widen across the longitude seam to avoid breaks such as 179/-181
+     valid_points_exp, valid_data_exp = _expand_lonlat_for_dateline(valid_points, valid_data)
+
+     valid_xyz = lonlat2xyz(valid_points_exp[:, 0], valid_points_exp[:, 1])
+     target_xyz = lonlat2xyz(target_points[:, 0], target_points[:, 1])
+
+     # Use griddata's cubic interpolation for better smoothness
+     result = griddata(valid_xyz, valid_data_exp, target_xyz, method=interpolation_method).reshape(target_shape)
+
+     # Handle residual NaN with nearest neighbor
+     if np.isnan(result).any():
+         nn_interp = NearestNDInterpolator(valid_xyz, valid_data_exp)
+         nn = nn_interp(target_xyz).reshape(target_shape)
+         result[np.isnan(result)] = nn[np.isnan(result)]
+
+     return result
+
+
+ def interp_2d_func_geo(target_x_coordinates: Union[np.ndarray, List[float]], target_y_coordinates: Union[np.ndarray, List[float]], source_x_coordinates: Union[np.ndarray, List[float]], source_y_coordinates: Union[np.ndarray, List[float]], source_data: np.ndarray, interpolation_method: str = "cubic") -> np.ndarray:
+     """
+     Perform 2D interpolation on the last two dimensions of a multi-dimensional array (spherical coordinates).
+     Interpolates in spherical coordinates, suited to global-scale geographic data; correctly handles longitudes crossing the date line.
+
+     Args:
+         target_x_coordinates (Union[np.ndarray, List[float]]): Target grid's longitude (-180 to 180 or 0 to 360).
+         target_y_coordinates (Union[np.ndarray, List[float]]): Target grid's latitude (-90 to 90).
+         source_x_coordinates (Union[np.ndarray, List[float]]): Original grid's longitude (-180 to 180 or 0 to 360).
+         source_y_coordinates (Union[np.ndarray, List[float]]): Original grid's latitude (-90 to 90).
+         source_data (np.ndarray): Multi-dimensional array with the last two dimensions as spatial.
+         interpolation_method (str, optional): Interpolation method. Defaults to "cubic".
+             >>> optional: 'linear', 'nearest', 'cubic', 'quintic', etc.
+
+     Returns:
+         np.ndarray: Interpolated data array.
+
+     Raises:
+         ValueError: If input shapes are invalid.
+
+     Examples:
+         >>> # Build a global grid example
+         >>> target_lon = np.arange(-180, 181, 1)  # 1-degree target grid
+         >>> target_lat = np.arange(-90, 91, 1)
+         >>> source_lon = np.arange(-180, 181, 5)  # 5-degree source grid
+         >>> source_lat = np.arange(-90, 91, 5)
+         >>> # Build a simple data field (e.g., a temperature field)
+         >>> source_data = np.cos(np.deg2rad(source_lat.reshape(-1, 1))) * np.cos(np.deg2rad(source_lon))
+         >>> # Interpolate to the higher-resolution grid
+         >>> result = interp_2d_func_geo(target_lon, target_lat, source_lon, source_lat, source_data)
+         >>> print(result.shape)  # Expected output: (181, 361)
+     """
+     # Validate input ranges
+     if np.nanmin(target_y_coordinates) < -90 or np.nanmax(target_y_coordinates) > 90:
+         raise ValueError("[red]Target latitude must be in range [-90, 90].[/red]")
+     if np.nanmin(source_y_coordinates) < -90 or np.nanmax(source_y_coordinates) > 90:
+         raise ValueError("[red]Source latitude must be in range [-90, 90].[/red]")
+
+     if len(target_y_coordinates.shape) == 1:
+         target_x_coordinates, target_y_coordinates = np.meshgrid(target_x_coordinates, target_y_coordinates)
+     if len(source_y_coordinates.shape) == 1:
+         source_x_coordinates, source_y_coordinates = np.meshgrid(source_x_coordinates, source_y_coordinates)
+
+     if source_x_coordinates.shape != source_data.shape[-2:] or source_y_coordinates.shape != source_data.shape[-2:]:
+         raise ValueError("[red]Shape of source_data does not match shape of source_x_coordinates or source_y_coordinates.[/red]")
+
+     target_points = np.column_stack((np.array(target_x_coordinates).ravel(), np.array(target_y_coordinates).ravel()))
+     origin_points = np.column_stack((np.array(source_x_coordinates).ravel(), np.array(source_y_coordinates).ravel()))
+
+     data_dims = len(source_data.shape)
      if data_dims < 2:
-         raise ValueError("Source data must have at least 2 dimensions.")
+         raise ValueError(f"[red]Source data must have at least 2 dimensions, but got {data_dims}.[/red]")
      elif data_dims > 4:
-         raise ValueError("Source data has more than 4 dimensions, not supported.")
+         raise ValueError(f"Source data has {data_dims} dimensions, but this function currently supports only up to 4.")

      num_dims_to_add = 4 - data_dims
-     new_shape = (1,) * num_dims_to_add + data.shape
-     data4d = data.reshape(new_shape)
-     t, z, ny, nx = data4d.shape
+     new_shape = (1,) * num_dims_to_add + source_data.shape
+     new_src_data = source_data.reshape(new_shape)

-     data_min, data_max = np.nanmin(data), np.nanmax(data)
-     target_shape = ty.shape
+     t, z, y, x = new_src_data.shape

-     # Prepare parameters for parallel execution
      params = []
-     for ti in range(t):
-         for zi in range(z):
-             params.append((data4d[ti, zi], sx, sy, tx, ty, interpolation_method, data_min, data_max))
+     target_shape = target_y_coordinates.shape
+     for t_index in range(t):
+         for z_index in range(z):
+             params.append((new_src_data[t_index, z_index], origin_points, target_points, interpolation_method, target_shape))

      with PEx() as executor:
          result = executor.run(_interp_single_worker, params)

-     result = np.array(result).reshape(t, z, *target_shape)
-     result = np.squeeze(result)
-     return result
+     return np.squeeze(np.array(result).reshape(t, z, *target_shape))
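
The new geographic path sidesteps the dateline twice: scattered source points are duplicated across the seam by _expand_lonlat_for_dateline, and both grids are mapped onto the unit sphere before griddata is called, so no seam exists in xyz space at all. A self-contained sketch of the lon/lat-to-xyz mapping plus the nearest-neighbor backfill (toy grids and names are illustrative; note that SciPy's griddata accepts only 'linear' and 'nearest' for point sets in more than two dimensions, which is why the sketch passes 'linear'):

    import numpy as np
    from scipy.interpolate import NearestNDInterpolator, griddata

    def lonlat2xyz(lon, lat):
        # Map lon/lat in degrees onto the unit sphere
        lon_rad, lat_rad = np.deg2rad(lon), np.deg2rad(lat)
        return np.stack([np.cos(lat_rad) * np.cos(lon_rad),
                         np.cos(lat_rad) * np.sin(lon_rad),
                         np.sin(lat_rad)], axis=-1)

    # Coarse global source grid and a finer target grid (toy example)
    src_lon, src_lat = np.meshgrid(np.arange(-180, 180, 10.0), np.arange(-80, 81, 10.0))
    tgt_lon, tgt_lat = np.meshgrid(np.arange(-180, 180, 5.0), np.arange(-80, 81, 5.0))
    values = np.cos(np.deg2rad(src_lat)) * np.cos(np.deg2rad(src_lon))

    src_xyz = lonlat2xyz(src_lon.ravel(), src_lat.ravel())
    tgt_xyz = lonlat2xyz(tgt_lon.ravel(), tgt_lat.ravel())

    # Scattered-data interpolation on the sphere; no dateline seam in xyz space
    result = griddata(src_xyz, values.ravel(), tgt_xyz, method="linear")

    # Backfill targets outside the convex hull with nearest neighbor
    nan_mask = np.isnan(result)
    if nan_mask.any():
        result[nan_mask] = NearestNDInterpolator(src_xyz, values.ravel())(tgt_xyz[nan_mask])
    result = result.reshape(tgt_lon.shape)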
oafuncs/_script/netcdf_merge.py CHANGED
@@ -1,26 +1,10 @@
- #!/usr/bin/env python
- # coding=utf-8
- """
- Author: Liu Kun && 16031215@qq.com
- Date: 2025-03-30 11:16:29
- LastEditors: Liu Kun && 16031215@qq.com
- LastEditTime: 2025-04-25 14:23:10
- FilePath: \\Python\\My_Funcs\\OAFuncs\\oafuncs\\_script\\netcdf_merge.py
- Description
- EditPlatform: vscode
- ComputerInfo: XPS 15 9510
- SystemInfo: Windows 11
- Python Version: 3.12
- """
-
+ import logging
  import os
  from typing import List, Optional, Union
-
  import numpy as np
  import xarray as xr

  from oafuncs import pbar
- import logging


  def merge_nc(file_list: Union[str, List[str]], var_name: Optional[Union[str, List[str]]] = None, dim_name: Optional[str] = None, target_filename: Optional[str] = None) -> None:
@@ -75,26 +59,41 @@ def merge_nc(file_list: Union[str, List[str]], var_name: Optional[Union[str, Lis
              if dim_name in data_var.dims:
                  merged_data.setdefault(var, []).append(data_var)
              elif var not in merged_data:
-                 # Check the type; fill datetime variables with NaT
-                 if np.issubdtype(data_var.dtype, np.datetime64):
-                     merged_data[var] = data_var.fillna(np.datetime64("NaT"))
-                 else:
-                     merged_data[var] = data_var.fillna(0)
+                 # Distinguish types more finely, not just datetime vs. non-datetime
+                 if data_var.dtype.kind == "M":  # datetime64
+                     merged_data[var] = data_var  # datetime variables are handled in save_to_nc
+                 elif data_var.dtype.kind in ["f", "i", "u"]:  # numeric types
+                     # Fill NaN in numeric variables with -32767 instead of 0
+                     merged_data[var] = data_var.fillna(-32767)
+                 else:  # strings or other types
+                     merged_data[var] = data_var  # keep non-numeric types as-is

      for var in pbar(merged_data, description="Merging variables"):
          if isinstance(merged_data[var], list):
-             # Check the type; fill datetime variables with NaT
-             if np.issubdtype(merged_data[var][0].dtype, np.datetime64):
-                 merged_data[var] = xr.concat(merged_data[var], dim=dim_name).fillna(np.datetime64("NaT"))
-             else:
-                 merged_data[var] = xr.concat(merged_data[var], dim=dim_name).fillna(0)
+             # Finer-grained type handling
+             if merged_data[var][0].dtype.kind == "M":  # datetime64
+                 merged_data[var] = xr.concat(merged_data[var], dim=dim_name)
+             elif merged_data[var][0].dtype.kind in ["f", "i", "u"]:  # numeric types
+                 # Fill NaN with -32767 instead of 0
+                 merged_data[var] = xr.concat(merged_data[var], dim=dim_name).fillna(-32767)
+             else:  # strings or other types
+                 merged_data[var] = xr.concat(merged_data[var], dim=dim_name)
+
+     # Before building the final dataset, check again that no numeric variable still contains NaN
+     merged_ds = xr.Dataset(merged_data)
+     for var_name in merged_ds.data_vars:
+         var = merged_ds[var_name]
+         if var.dtype.kind in ["f", "i", "u"]:
+             if np.isnan(var.values).any():
+                 logging.warning(f"Variable {var_name} still contains NaN after merging; replacing with -32767")
+                 merged_ds[var_name] = var.fillna(-32767)

      if os.path.exists(target_filename):
          # print("Warning: The target file already exists. Removing it ...")
          logging.warning("The target file already exists. Removing it ...")
          os.remove(target_filename)

-     save_to_nc(target_filename, xr.Dataset(merged_data))
+     save_to_nc(target_filename, merged_ds)


  # Example usage
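
The merge step now dispatches on numpy's dtype.kind codes ('M' for datetime64; 'f', 'i', 'u' for numeric) rather than a single datetime check, so strings and other object variables pass through untouched. A small sketch of that dispatch on a toy xarray Dataset (variable names are illustrative):

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({
        "sst": ("x", np.array([1.0, np.nan, 3.0])),
        "time": ("x", np.array(["2025-01-01", "2025-01-02", "2025-01-03"], dtype="datetime64[ns]")),
        "station": ("x", np.array(["A", "B", "C"], dtype=object)),
    })

    for name, da in ds.data_vars.items():
        if da.dtype.kind == "M":        # datetime64: leave NaT handling to the writer
            continue
        elif da.dtype.kind in "fiu":    # numeric: fill NaN with the sentinel
            ds[name] = da.fillna(-32767)
        # anything else (strings, objects) is kept as-is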
oafuncs/_script/netcdf_write.py CHANGED
@@ -1,13 +1,55 @@
  import os
+ import warnings

  import netCDF4 as nc
  import numpy as np
  import xarray as xr
- import warnings

  warnings.filterwarnings("ignore", category=RuntimeWarning)


+ def _nan_to_fillvalue(ncfile):
+     """
+     Replace NaN and masked values in every variable of a NetCDF file with the variable's _FillValue
+     attribute (if absent, add _FillValue=-32767 automatically, then replace).
+     Also handles invalid values in masked arrays.
+     Only applies to numeric variables (floating-point and integer types).
+     """
+     with nc.Dataset(ncfile, "r+") as ds:
+         for var_name in ds.variables:
+             var = ds.variables[var_name]
+             # Only process numeric variables (f: float, i: signed int, u: unsigned int)
+             if var.dtype.kind not in ["f", "i", "u"]:
+                 continue
+
+             # Read the data
+             arr = var[:]
+
+             # Determine the fill value
+             if "_FillValue" in var.ncattrs():
+                 fill_value = var.getncattr("_FillValue")
+             else:
+                 fill_value = -32767
+                 try:
+                     var.setncattr("_FillValue", fill_value)
+                 except Exception:
+                     # Some variables may not allow adding _FillValue dynamically
+                     continue
+
+             # Handle masked arrays
+             if hasattr(arr, "mask"):
+                 # If masked, set the masked positions to fill_value
+                 if np.any(arr.mask):
+                     # Convert to a plain array, filling masked positions with fill_value
+                     arr = np.ma.filled(arr, fill_value=fill_value)
+
+             # Handle remaining NaN values
+             if np.any(np.isnan(arr)):
+                 arr = np.nan_to_num(arr, nan=fill_value, posinf=fill_value, neginf=fill_value)
+
+             # Write back to the variable
+             var[:] = arr
+
+
  def _numpy_to_nc_type(numpy_type):
      """Map NumPy data types to NetCDF data types."""
      numpy_to_nc = {
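
A standalone equivalent of the _nan_to_fillvalue pass needs only netCDF4 and numpy. A minimal sketch, assuming -32767 as the fallback fill (the file name is hypothetical, and the getattr guard for vlen/string dtypes is an extra safety net not present in the package code):

    import netCDF4 as nc
    import numpy as np

    def nan_to_fillvalue(path, default_fill=-32767):
        """Replace NaN/masked entries of numeric variables with their _FillValue."""
        with nc.Dataset(path, "r+") as ds:
            for var in ds.variables.values():
                if getattr(var.dtype, "kind", "O") not in "fiu":  # numeric variables only
                    continue
                fill = var.getncattr("_FillValue") if "_FillValue" in var.ncattrs() else default_fill
                arr = np.ma.filled(var[:], fill_value=fill)  # masked entries -> fill
                if arr.dtype.kind == "f":
                    arr = np.nan_to_num(arr, nan=fill, posinf=fill, neginf=fill)
                var[:] = arr

    # nan_to_fillvalue("example.nc")  # hypothetical file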
@@ -30,17 +72,22 @@ def _calculate_scale_and_offset(data, n=16):
      """
      Compute the scale_factor and add_offset for numeric data,
      mapping the data into the range [0, 2**n - 1].
-
-     Requires data to be a numeric NumPy array; all-NaN input is not allowed.
+     Also handles NaN values in the data.
      """
      if not isinstance(data, np.ndarray):
          raise ValueError("Input data must be a NumPy array.")

-     data_min = np.nanmin(data)
-     data_max = np.nanmax(data)
+     # First handle NaN and infinite values in the data
+     clean_data = np.nan_to_num(data, nan=-32767, posinf=np.finfo(float).max, neginf=np.finfo(float).min)

-     if np.isnan(data_min) or np.isnan(data_max):
-         raise ValueError("Input data contains NaN values.")
+     # Compute the min/max of valid values (excluding the fill value)
+     valid_mask = clean_data != -32767
+     if np.any(valid_mask):
+         data_min = np.min(clean_data[valid_mask])
+         data_max = np.max(clean_data[valid_mask])
+     else:
+         # Everything is fill values; use a default range
+         data_min, data_max = 0, 1

      if data_max == data_min:
          scale_factor = 1.0
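
For reference, scale/offset packing stores each value as round((v - add_offset) / scale_factor), and a common way to span the [0, 2**n - 1] range the docstring describes is sketched below. The exact constants the package uses are not visible in this hunk, so treat the formula as an assumption rather than the package's implementation:

    import numpy as np

    def pack_params(data_min, data_max, n=16):
        # Map [data_min, data_max] linearly onto the integer range [0, 2**n - 1]
        if data_max == data_min:
            return 1.0, data_min
        scale_factor = (data_max - data_min) / (2**n - 1)
        add_offset = data_min
        return scale_factor, add_offset

    data = np.array([2.5, 7.0, 11.25])
    scale, offset = pack_params(data.min(), data.max())
    packed = np.round((data - offset) / scale).astype(np.uint16)
    unpacked = packed * scale + offset  # recovers data to within scale/2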
@@ -111,6 +158,7 @@ def save_to_nc(file, data, varname=None, coords=None, mode="w", scale_offset_swi
              new_da.to_dataset(name=varname).to_netcdf(file, mode=mode, encoding=encoding)
          else:
              data.to_dataset(name=varname).to_netcdf(file, mode=mode)
+         _nan_to_fillvalue(file)  # Replace NaN with _FillValue
          return

      else:
@@ -139,6 +187,7 @@ def save_to_nc(file, data, varname=None, coords=None, mode="w", scale_offset_swi
              new_ds.to_netcdf(file, mode=mode, encoding=encoding)
          else:
              new_ds.to_netcdf(file, mode=mode)
+         _nan_to_fillvalue(file)  # Replace NaN with _FillValue
          return

      # Handle the plain numpy array case
@@ -172,6 +221,7 @@ def save_to_nc(file, data, varname=None, coords=None, mode="w", scale_offset_swi
          dtype = _numpy_to_nc_type(data.dtype)
          var = ncfile.createVariable(varname, dtype, dims, zlib=False)
          var[:] = data
+         _nan_to_fillvalue(file)  # Replace NaN with _FillValue
      except Exception as e:
          raise RuntimeError(f"netCDF4 save failed: {str(e)}") from e

oafuncs/_script/parallel.py CHANGED
@@ -6,6 +6,7 @@ import time
  from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
  from typing import Any, Callable, Dict, List, Optional, Tuple

+
  import psutil

  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
@@ -14,6 +15,7 @@ __all__ = ["ParallelExecutor"]


  class ParallelExecutor:
+
      def __init__(
          self,
          max_workers: Optional[int] = None,
@@ -120,6 +122,7 @@ class ParallelExecutor:
          self._executor = None

      def _execute_batch(self, func: Callable, params: List[Tuple], chunk_size: int) -> List[Any]:
+         from oafuncs.oa_tool import pbar
          if not params:
              return []

@@ -129,7 +132,7 @@ class ParallelExecutor:
          results = [None] * len(params)
          with self._get_executor() as executor:
              futures = {executor.submit(func, *args): idx for idx, args in enumerate(params)}
-             for future in as_completed(futures):
+             for future in pbar(as_completed(futures), "Parallel Tasks", total=len(futures)):
                  idx = futures[future]
                  try:
                      results[idx] = future.result(timeout=self.timeout_per_task)
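
The loop change above threads the package's own pbar around as_completed so batch progress is visible while futures finish out of order. The pattern works with any iterator-wrapping progress indicator; a generic standalone sketch using only concurrent.futures (the worker function and task list are illustrative):

    from concurrent.futures import ProcessPoolExecutor, as_completed

    def square(x):
        return x * x

    if __name__ == "__main__":
        params = [(i,) for i in range(8)]
        results = [None] * len(params)
        with ProcessPoolExecutor() as executor:
            futures = {executor.submit(square, *args): idx for idx, args in enumerate(params)}
            # A progress bar would wrap as_completed(futures) here, with
            # total=len(futures) since the futures dict has a known length.
            for done, future in enumerate(as_completed(futures), start=1):
                results[futures[future]] = future.result()
                print(f"Parallel Tasks: {done}/{len(futures)}")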
oafuncs/oa_data.py CHANGED
@@ -23,7 +23,7 @@ from rich import print
  from scipy.interpolate import interp1d


- __all__ = ["interp_along_dim", "interp_2d", "ensure_list", "mask_shapefile"]
+ __all__ = ["interp_along_dim", "interp_2d", "interp_2d_geo", "ensure_list", "mask_shapefile"]


  def ensure_list(input_value: Any) -> List[str]:
@@ -120,7 +120,7 @@ def interp_2d(
      source_x_coordinates: Union[np.ndarray, List[float]],
      source_y_coordinates: Union[np.ndarray, List[float]],
      source_data: np.ndarray,
-     interpolation_method: str = "linear",
+     interpolation_method: str = "cubic",
  ) -> np.ndarray:
      """
      Perform 2D interpolation on the last two dimensions of a multi-dimensional array.
@@ -132,7 +132,7 @@ def interp_2d(
          source_y_coordinates (Union[np.ndarray, List[float]]): Original grid's y-coordinates.
          source_data (np.ndarray): Multi-dimensional array with the last two dimensions as spatial.
              >>> must be [y, x] or [*, y, x] or [*, *, y, x]
-         interpolation_method (str, optional): Interpolation method. Defaults to "linear".
+         interpolation_method (str, optional): Interpolation method. Defaults to "cubic".
              >>> optional: 'linear', 'nearest', 'cubic', 'quintic', etc.
          use_parallel (bool, optional): Enable parallel processing. Defaults to True.

@@ -162,7 +162,47 @@ def interp_2d(
          interpolation_method=interpolation_method,
      )

+ def interp_2d_geo(target_x_coordinates: Union[np.ndarray, List[float]], target_y_coordinates: Union[np.ndarray, List[float]], source_x_coordinates: Union[np.ndarray, List[float]], source_y_coordinates: Union[np.ndarray, List[float]], source_data: np.ndarray, interpolation_method: str = "cubic") -> np.ndarray:
+     """
+     Perform 2D interpolation on the last two dimensions of a multi-dimensional array (spherical coordinates).
+     Interpolates in spherical coordinates, suited to global-scale geographic data; correctly handles longitudes crossing the date line.
+
+     Args:
+         target_x_coordinates (Union[np.ndarray, List[float]]): Target grid's longitude (-180 to 180 or 0 to 360).
+         target_y_coordinates (Union[np.ndarray, List[float]]): Target grid's latitude (-90 to 90).
+         source_x_coordinates (Union[np.ndarray, List[float]]): Original grid's longitude (-180 to 180 or 0 to 360).
+         source_y_coordinates (Union[np.ndarray, List[float]]): Original grid's latitude (-90 to 90).
+         source_data (np.ndarray): Multi-dimensional array with the last two dimensions as spatial.
+         interpolation_method (str, optional): Interpolation method. Defaults to "cubic".
+             >>> optional: 'linear', 'nearest', 'cubic', 'quintic', etc.
+
+     Returns:
+         np.ndarray: Interpolated data array.
+
+     Raises:
+         ValueError: If input shapes are invalid.

+     Examples:
+         >>> # Build a global grid example
+         >>> target_lon = np.arange(-180, 181, 1)  # 1-degree target grid
+         >>> target_lat = np.arange(-90, 91, 1)
+         >>> source_lon = np.arange(-180, 181, 5)  # 5-degree source grid
+         >>> source_lat = np.arange(-90, 91, 5)
+         >>> # Build a simple data field (e.g., a temperature field)
+         >>> source_data = np.cos(np.deg2rad(source_lat.reshape(-1, 1))) * np.cos(np.deg2rad(source_lon))
+         >>> # Interpolate to the higher-resolution grid
+         >>> result = interp_2d_geo(target_lon, target_lat, source_lon, source_lat, source_data)
+         >>> print(result.shape)  # Expected output: (181, 361)
+     """
+     from ._script.data_interp_geo import interp_2d_func_geo
+     return interp_2d_func_geo(
+         target_x_coordinates=target_x_coordinates,
+         target_y_coordinates=target_y_coordinates,
+         source_x_coordinates=source_x_coordinates,
+         source_y_coordinates=source_y_coordinates,
+         source_data=source_data,
+         interpolation_method=interpolation_method,
+     )


  def mask_shapefile(
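
A short usage sketch of the new interp_2d_geo wrapper, mirroring its docstring example (assumes oafuncs 0.0.98.20 is installed; interpolation_method="linear" is passed explicitly because SciPy's griddata does not support "cubic" for the three-dimensional point sets the spherical path produces):

    import numpy as np
    from oafuncs.oa_data import interp_2d_geo

    # Coarse 5-degree global source grid, 1-degree target grid
    target_lon = np.arange(-180, 181, 1)
    target_lat = np.arange(-90, 91, 1)
    source_lon = np.arange(-180, 181, 5)
    source_lat = np.arange(-90, 91, 5)

    # Simple smooth field (a temperature-like pattern)
    source_data = np.cos(np.deg2rad(source_lat.reshape(-1, 1))) * np.cos(np.deg2rad(source_lon))

    result = interp_2d_geo(target_lon, target_lat, source_lon, source_lat, source_data,
                           interpolation_method="linear")
    print(result.shape)  # (181, 361)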
oafuncs-0.0.98.19.dist-info/METADATA → oafuncs-0.0.98.20.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: oafuncs
- Version: 0.0.98.19
+ Version: 0.0.98.20
  Summary: Oceanic and Atmospheric Functions
  Home-page: https://github.com/Industry-Pays/OAFuncs
  Author: Kun Liu
oafuncs-0.0.98.19.dist-info/RECORD → oafuncs-0.0.98.20.dist-info/RECORD
@@ -1,6 +1,6 @@
  oafuncs/__init__.py,sha256=T_-VtnWWllV3Q91twT5Yt2sUapeA051QbPNnBxmg9nw,1456
  oafuncs/oa_cmap.py,sha256=DimWT4Bg7uE5Lx8hSw1REp7whpsR2pFRStAwk1cowEM,11494
- oafuncs/oa_data.py,sha256=y11xxaVNZ6_eveVjSG4PisRXYpKr_FFsBBh0mj_ss2g,8436
+ oafuncs/oa_data.py,sha256=Aat9ktxxRGevaqQya3IJWfXeoEs-FCXGUcNE2pKnzfU,10931
  oafuncs/oa_date.py,sha256=WhM6cyD4G3IeghjLTHhAMtlvJbA7kwQG2sHnxdTgyso,6303
  oafuncs/oa_draw.py,sha256=Wj2QBgyIPpV_dxaDrH10jqj_puK9ZM9rd-si-3VrsrE,17631
  oafuncs/oa_file.py,sha256=j9gXJgPOJsliu4IOUc4bc-luW4yBvQyNCEmMyDVjUwQ,16404
@@ -11,13 +11,13 @@ oafuncs/oa_tool.py,sha256=rpPkLqWhqMmqlCc5wjL8qMTg3gThCkSrYJckbX_0iJc,8631
  oafuncs/_data/hycom.png,sha256=MadKs6Gyj5n9-TOu7L4atQfTXtF9dvN9w-tdU9IfygI,10945710
  oafuncs/_data/oafuncs.png,sha256=o3VD7wm-kwDea5E98JqxXl04_78cBX7VcdUt7uQXGiU,3679898
  oafuncs/_script/cprogressbar.py,sha256=UIgGcLFs-6IgWlITuBLaQqrpt4OAK3Mst5RlCiNfZdQ,15772
- oafuncs/_script/data_interp.py,sha256=KJ-p-UN3Op1MmtCoN4KdjFVHFE3GNHrTD3vBjzaYSjQ,4688
- oafuncs/_script/data_interp_geo.py,sha256=X89KxLYhpltWi0Sf96gIhBL3r1M5aExd_JCmgBmmvUc,3742
+ oafuncs/_script/data_interp.py,sha256=EiZbt6n5BEaRKcng88UgX7TFPhKE6TLVZniS01awXjg,5146
+ oafuncs/_script/data_interp_geo.py,sha256=ZRFb3fKRiYQViZNHd19eW20C9i38BsiIU8w0fG5mbqM,7789
  oafuncs/_script/email.py,sha256=lL4HGKrr524-g0xLlgs-4u7x4-u7DtgNoD9AL8XJKj4,3058
- oafuncs/_script/netcdf_merge.py,sha256=9hCyxfeUHnBzs50_0v0jzVfxpMxTX4dNTo0pmsp_T6g,4226
+ oafuncs/_script/netcdf_merge.py,sha256=zasqrFpB2GHAjZ1LkrWtI7kpHu7uCjdIGf4C0lEraYA,4816
  oafuncs/_script/netcdf_modify.py,sha256=sGRUYNhfGgf9JV70rnBzw3bzuTRSXzBTL_RMDnDPeLQ,4552
- oafuncs/_script/netcdf_write.py,sha256=iO1Qv9bp6RLiw1D8Nrv7tX_8X-diUZaX3Nxhk6pJ5Nw,8556
- oafuncs/_script/parallel.py,sha256=T9Aie-e4LcbKlFTLZ0l4lhEN3SBVa84jRcrAsIm8s0I,8767
+ oafuncs/_script/netcdf_write.py,sha256=gEqeagAcu0_xa6cIuu7a5d9uXMEaGqnci02KNc5znEY,10672
+ oafuncs/_script/parallel.py,sha256=07-BJVHxXJNlrOrhrSGt7qCZiKWq6dBvNDBA1AANYnI,8861
  oafuncs/_script/parallel_test.py,sha256=0GBqZOX7IaCOKF2t1y8N8YYu53GJ33OkfsWgpvZNqM4,372
  oafuncs/_script/plot_dataset.py,sha256=zkSEnO_-biyagorwWXPoihts_cwuvripzEt-l9bHJ2E,13989
  oafuncs/_script/replace_file_content.py,sha256=eCFZjnZcwyRvy6b4mmIfBna-kylSZTyJRfgXd6DdCjk,5982
@@ -39,8 +39,8 @@ oafuncs/oa_sign/__init__.py,sha256=QKqTFrJDFK40C5uvk48GlRRbGFzO40rgkYwu6dYxatM,5
  oafuncs/oa_sign/meteorological.py,sha256=8091SHo2L8kl4dCFmmSH5NGVHDku5i5lSiLEG5DLnOQ,6489
  oafuncs/oa_sign/ocean.py,sha256=xrW-rWD7xBWsB5PuCyEwQ1Q_RDKq2KCLz-LOONHgldU,5932
  oafuncs/oa_sign/scientific.py,sha256=a4JxOBgm9vzNZKpJ_GQIQf7cokkraV5nh23HGbmTYKw,5064
- oafuncs-0.0.98.19.dist-info/licenses/LICENSE.txt,sha256=rMtLpVg8sKiSlwClfR9w_Dd_5WubTQgoOzE2PDFxzs4,1074
- oafuncs-0.0.98.19.dist-info/METADATA,sha256=PJ_1BA6QOeA2QHSb61jwOFA2pDD2H4QT--m3tf_f44I,4273
- oafuncs-0.0.98.19.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
- oafuncs-0.0.98.19.dist-info/top_level.txt,sha256=bgC35QkXbN4EmPHEveg_xGIZ5i9NNPYWqtJqaKqTPsQ,8
- oafuncs-0.0.98.19.dist-info/RECORD,,
+ oafuncs-0.0.98.20.dist-info/licenses/LICENSE.txt,sha256=rMtLpVg8sKiSlwClfR9w_Dd_5WubTQgoOzE2PDFxzs4,1074
+ oafuncs-0.0.98.20.dist-info/METADATA,sha256=QgOUX1DWWv2bLiGauzfcJsKtuGQfyJkinMJoCO7FRdI,4273
+ oafuncs-0.0.98.20.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
+ oafuncs-0.0.98.20.dist-info/top_level.txt,sha256=bgC35QkXbN4EmPHEveg_xGIZ5i9NNPYWqtJqaKqTPsQ,8
+ oafuncs-0.0.98.20.dist-info/RECORD,,