oafuncs 0.0.98.19__py3-none-any.whl → 0.0.98.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,3 +1,18 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+ """
+ Author: Liu Kun && 16031215@qq.com
+ Date: 2025-04-25 16:22:52
+ LastEditors: Liu Kun && 16031215@qq.com
+ LastEditTime: 2025-04-26 19:21:31
+ FilePath: \\Python\\My_Funcs\\OAFuncs\\oafuncs\\_script\\data_interp.py
+ Description:
+ EditPlatform: vscode
+ ComputerInfo: XPS 15 9510
+ SystemInfo: Windows 11
+ Python Version: 3.12
+ """
+
  from typing import List, Union

  import numpy as np
@@ -8,10 +23,10 @@ from oafuncs.oa_tool import PEx

  def _interp_single_worker(*args):
  """
- Single-slice interpolation worker for PEx parallelism; arguments are (t, z, source_data, origin_points, target_points, interpolation_method, target_shape)
+ Single-slice interpolation worker for PEx parallelism.
+ Arguments: data_slice, origin_points, target_points, interpolation_method, target_shape
  """
  data_slice, origin_points, target_points, interpolation_method, target_shape = args
-
  # Filter out points that contain NaN
  valid_mask = ~np.isnan(data_slice.ravel())
  valid_data = data_slice.ravel()[valid_mask]
@@ -21,20 +36,16 @@ def _interp_single_worker(*args):
  return np.full(target_shape, np.nanmean(data_slice))

  # Interpolate using the valid data
- result = griddata(valid_points, valid_data, target_points, method=interpolation_method)
- result = result.reshape(target_shape)
+ result = griddata(valid_points, valid_data, target_points, method=interpolation_method).reshape(target_shape)
+ # Fill points that are still NaN using nearest-neighbour interpolation
+ if np.isnan(result).any():
+ nn = griddata(valid_points, valid_data, target_points, method="nearest").reshape(target_shape)
+ result[np.isnan(result)] = nn[np.isnan(result)]

  return result


- def interp_2d_func(
- target_x_coordinates: Union[np.ndarray, List[float]],
- target_y_coordinates: Union[np.ndarray, List[float]],
- source_x_coordinates: Union[np.ndarray, List[float]],
- source_y_coordinates: Union[np.ndarray, List[float]],
- source_data: np.ndarray,
- interpolation_method: str = "cubic",
- ) -> np.ndarray:
+ def interp_2d_func(target_x_coordinates: Union[np.ndarray, List[float]], target_y_coordinates: Union[np.ndarray, List[float]], source_x_coordinates: Union[np.ndarray, List[float]], source_y_coordinates: Union[np.ndarray, List[float]], source_data: np.ndarray, interpolation_method: str = "cubic") -> np.ndarray:
  """
  Perform 2D interpolation on the last two dimensions of a multi-dimensional array.

@@ -46,7 +57,6 @@ def interp_2d_func(
  source_data (np.ndarray): Multi-dimensional array with the last two dimensions as spatial.
  interpolation_method (str, optional): Interpolation method. Defaults to "cubic".
  >>> optional: 'linear', 'nearest', 'cubic', 'quintic', etc.
- use_parallel (bool, optional): Enable parallel processing. Defaults to True.

  Returns:
  np.ndarray: Interpolated data array.
@@ -60,7 +70,7 @@ def interp_2d_func(
  >>> source_x_coordinates = np.array([7, 8, 9])
  >>> source_y_coordinates = np.array([10, 11, 12])
  >>> source_data = np.random.rand(3, 3)
- >>> result = interp_2d(target_x_coordinates, target_y_coordinates, source_x_coordinates, source_y_coordinates, source_data)
+ >>> result = interp_2d_func(target_x_coordinates, target_y_coordinates, source_x_coordinates, source_y_coordinates, source_data)
  >>> print(result.shape) # Expected output: (3, 3)
  """
  if len(target_y_coordinates.shape) == 1:
@@ -80,7 +90,7 @@ def interp_2d_func(
  raise ValueError(f"[red]Source data must have at least 2 dimensions, but got {data_dims}.[/red]")
  elif data_dims > 4:
  # Or handle cases with more than 4 dimensions if necessary
- raise ValueError(f"[red]Source data has {data_dims} dimensions, but this function currently supports only up to 4.[/red]")
+ raise ValueError(f"Source data has {data_dims} dimensions, but this function currently supports only up to 4.")

  # Reshape to 4D by adding leading dimensions of size 1 if needed
  num_dims_to_add = 4 - data_dims
@@ -1,98 +1,167 @@
  from typing import List, Union

  import numpy as np
- from scipy.interpolate import RectBivariateSpline
+ from scipy.interpolate import NearestNDInterpolator, griddata

  from oafuncs.oa_tool import PEx
- from oafuncs.oa_data import data_clip
- from oafuncs._script.data_interp import _fill_nan_nearest


- def _interp_single_worker(*args):
+ def _normalize_lon(lon, ref_lon):
  """
- Single-slice interpolation worker; arguments are (data_slice, sx, sy, tx, ty, interpolation_method, data_min, data_max)
+ Normalize the longitude array lon to the same interval as ref_lon ([-180, 180] or [0, 360]),
+ widening automatically near longitude seams (e.g. 180/-180, 0/360) to avoid breaks in the interpolation.
  """
- # Compatibility with how PEx calls the worker: args is a tuple or list
- if len(args) == 1 and isinstance(args[0], (tuple, list)):
- args = args[0]
- data_slice, sx, sy, tx, ty, interpolation_method, data_min, data_max = args
- # Handle NaN
- if np.isnan(data_slice).any():
- mask = np.isnan(data_slice)
- if mask.any():
- data_slice = _fill_nan_nearest(data_slice)
- x1d = np.unique(sx[0, :])
- y1d = np.unique(sy[:, 0])
- if sx.shape != (len(y1d), len(x1d)) or sy.shape != (len(y1d), len(x1d)):
- from scipy.interpolate import griddata
-
- grid_points = np.column_stack((sx.ravel(), sy.ravel()))
- grid_values = data_slice.ravel()
- data_slice = griddata(grid_points, grid_values, (x1d[None, :], y1d[:, None]), method="linear")
- if interpolation_method == "linear":
- kx = ky = 1
+ lon = np.asarray(lon)
+ ref_lon = np.asarray(ref_lon)
+ if np.nanmax(ref_lon) > 180:
+ lon = np.where(lon < 0, lon + 360, lon)
  else:
- kx = ky = 3
- interp_func = RectBivariateSpline(y1d, x1d, data_slice, kx=kx, ky=ky)
- out = interp_func(ty[:, 0], tx[0, :])
- # Improved clipping logic: set out-of-range points to NaN, then fill them with fill_nan_nearest
- arr = np.asarray(out)
- arr = data_clip(arr,data_min,data_max)
- return arr
-
-
- def interp_2d_geo(
- target_x_coordinates: Union[np.ndarray, List[float]],
- target_y_coordinates: Union[np.ndarray, List[float]],
- source_x_coordinates: Union[np.ndarray, List[float]],
- source_y_coordinates: Union[np.ndarray, List[float]],
- source_data: np.ndarray,
- interpolation_method: str = "cubic",
- ) -> np.ndarray:
+ lon = np.where(lon > 180, lon - 360, lon)
+ return lon
+
+
+ def _expand_lonlat_for_dateline(points, values):
+ """
+ Widen the data near longitude seams (e.g. 180/-180, 0/360) to avoid breaks in the interpolation.
+ points: (N,2) [lon,lat]
+ values: (N,)
+ Returns the widened points, values
+ """
+ lon = points[:, 0]
+ lat = points[:, 1]
+ expanded_points = [points]
+ expanded_values = [values]
+ if (np.nanmax(lon) > 170) and (np.nanmin(lon) < -170):
+ expanded_points.append(np.column_stack((lon + 360, lat)))
+ expanded_points.append(np.column_stack((lon - 360, lat)))
+ expanded_values.append(values)
+ expanded_values.append(values)
+ if (np.nanmax(lon) > 350) and (np.nanmin(lon) < 10):
+ expanded_points.append(np.column_stack((lon - 360, lat)))
+ expanded_points.append(np.column_stack((lon + 360, lat)))
+ expanded_values.append(values)
+ expanded_values.append(values)
+ points_new = np.vstack(expanded_points)
+ values_new = np.concatenate(expanded_values)
+ return points_new, values_new
+
+
+ def _interp_single_worker(*args):
  """
- Smoother 2D interpolation using RectBivariateSpline for a bicubic effect, interface-compatible with interp_2d.
- Supports 2D/3D/4D input data, with the last two dimensions as spatial.
- interpolation_method: "cubic" (default, bicubic), "linear" (bilinear)
- After interpolation, the result is clipped automatically and out-of-range values and NaNs are filled by nearest neighbour; the range is the nanmin/nanmax of the original data
+ Single-slice interpolation worker for PEx parallelism.
+ Arguments: data_slice, origin_points, target_points, interpolation_method, target_shape
+ Spherical interpolation: lon/lat are converted to spherical coordinates before interpolating
  """
- # Ensure the inputs are ndarrays
- tx = np.asarray(target_x_coordinates)
- ty = np.asarray(target_y_coordinates)
- sx = np.asarray(source_x_coordinates)
- sy = np.asarray(source_y_coordinates)
- data = np.asarray(source_data)
-
- if ty.ndim == 1:
- tx, ty = np.meshgrid(tx, ty)
- if sy.ndim == 1:
- sx, sy = np.meshgrid(sx, sy)
-
- if sx.shape != data.shape[-2:] or sy.shape != data.shape[-2:]:
- raise ValueError("Shape of source_data does not match shape of source_x_coordinates or source_y_coordinates.")
-
- data_dims = data.ndim
+ data_slice, origin_points, target_points, interpolation_method, target_shape = args
+
+ # Normalize the longitudes
+ origin_points = origin_points.copy()
+ target_points = target_points.copy()
+ origin_points[:, 0] = _normalize_lon(origin_points[:, 0], target_points[:, 0])
+ target_points[:, 0] = _normalize_lon(target_points[:, 0], origin_points[:, 0])
+
+ def lonlat2xyz(lon, lat):
+ lon_rad = np.deg2rad(lon)
+ lat_rad = np.deg2rad(lat)
+ x = np.cos(lat_rad) * np.cos(lon_rad)
+ y = np.cos(lat_rad) * np.sin(lon_rad)
+ z = np.sin(lat_rad)
+ return np.stack([x, y, z], axis=-1)
+
+ # Filter out points that contain NaN
+ valid_mask = ~np.isnan(data_slice.ravel())
+ valid_data = data_slice.ravel()[valid_mask]
+ valid_points = origin_points[valid_mask]
+
+ if len(valid_data) < 10:
+ return np.full(target_shape, np.nanmean(data_slice))
+
+ # Widen across the longitude seam to avoid breaks such as 179/-181
+ valid_points_exp, valid_data_exp = _expand_lonlat_for_dateline(valid_points, valid_data)
+
+ valid_xyz = lonlat2xyz(valid_points_exp[:, 0], valid_points_exp[:, 1])
+ target_xyz = lonlat2xyz(target_points[:, 0], target_points[:, 1])
+
+ # Use griddata's cubic interpolation for smoother results
+ result = griddata(valid_xyz, valid_data_exp, target_xyz, method=interpolation_method).reshape(target_shape)
+
+ # Handle residual NaN with nearest neighbour
+ if np.isnan(result).any():
+ nn_interp = NearestNDInterpolator(valid_xyz, valid_data_exp)
+ nn = nn_interp(target_xyz).reshape(target_shape)
+ result[np.isnan(result)] = nn[np.isnan(result)]
+
+ return result
+
+
+ def interp_2d_func_geo(target_x_coordinates: Union[np.ndarray, List[float]], target_y_coordinates: Union[np.ndarray, List[float]], source_x_coordinates: Union[np.ndarray, List[float]], source_y_coordinates: Union[np.ndarray, List[float]], source_data: np.ndarray, interpolation_method: str = "cubic") -> np.ndarray:
+ """
+ Perform 2D interpolation on the last two dimensions of a multi-dimensional array (spherical coordinates).
+ Interpolation is performed in spherical coordinates, which suits global-scale geographic data and correctly handles longitudes that cross the dateline.
+
+ Args:
+ target_x_coordinates (Union[np.ndarray, List[float]]): Target grid's longitude (-180 to 180 or 0 to 360).
+ target_y_coordinates (Union[np.ndarray, List[float]]): Target grid's latitude (-90 to 90).
+ source_x_coordinates (Union[np.ndarray, List[float]]): Original grid's longitude (-180 to 180 or 0 to 360).
+ source_y_coordinates (Union[np.ndarray, List[float]]): Original grid's latitude (-90 to 90).
+ source_data (np.ndarray): Multi-dimensional array with the last two dimensions as spatial.
+ interpolation_method (str, optional): Interpolation method. Defaults to "cubic".
+ >>> optional: 'linear', 'nearest', 'cubic', 'quintic', etc.
+
+ Returns:
+ np.ndarray: Interpolated data array.
+
+ Raises:
+ ValueError: If input shapes are invalid.
+
+ Examples:
+ >>> # Build an example global grid
+ >>> target_lon = np.arange(-180, 181, 1) # 1-degree target grid
+ >>> target_lat = np.arange(-90, 91, 1)
+ >>> source_lon = np.arange(-180, 181, 5) # 5-degree source grid
+ >>> source_lat = np.arange(-90, 91, 5)
+ >>> # Build a simple data field (e.g. a temperature field)
+ >>> source_data = np.cos(np.deg2rad(source_lat.reshape(-1, 1))) * np.cos(np.deg2rad(source_lon))
+ >>> # Interpolate to the higher-resolution grid
+ >>> result = interp_2d_func_geo(target_lon, target_lat, source_lon, source_lat, source_data)
+ >>> print(result.shape) # Expected output: (181, 361)
+ """
+ # Validate the input coordinate ranges
+ if np.nanmin(target_y_coordinates) < -90 or np.nanmax(target_y_coordinates) > 90:
+ raise ValueError("[red]Target latitude must be in range [-90, 90].[/red]")
+ if np.nanmin(source_y_coordinates) < -90 or np.nanmax(source_y_coordinates) > 90:
+ raise ValueError("[red]Source latitude must be in range [-90, 90].[/red]")
+
+ if len(target_y_coordinates.shape) == 1:
+ target_x_coordinates, target_y_coordinates = np.meshgrid(target_x_coordinates, target_y_coordinates)
+ if len(source_y_coordinates.shape) == 1:
+ source_x_coordinates, source_y_coordinates = np.meshgrid(source_x_coordinates, source_y_coordinates)
+
+ if source_x_coordinates.shape != source_data.shape[-2:] or source_y_coordinates.shape != source_data.shape[-2:]:
+ raise ValueError("[red]Shape of source_data does not match shape of source_x_coordinates or source_y_coordinates.[/red]")
+
+ target_points = np.column_stack((np.array(target_x_coordinates).ravel(), np.array(target_y_coordinates).ravel()))
+ origin_points = np.column_stack((np.array(source_x_coordinates).ravel(), np.array(source_y_coordinates).ravel()))
+
+ data_dims = len(source_data.shape)
  if data_dims < 2:
- raise ValueError("Source data must have at least 2 dimensions.")
+ raise ValueError(f"[red]Source data must have at least 2 dimensions, but got {data_dims}.[/red]")
  elif data_dims > 4:
- raise ValueError("Source data has more than 4 dimensions, not supported.")
+ raise ValueError(f"Source data has {data_dims} dimensions, but this function currently supports only up to 4.")

  num_dims_to_add = 4 - data_dims
- new_shape = (1,) * num_dims_to_add + data.shape
- data4d = data.reshape(new_shape)
- t, z, ny, nx = data4d.shape
+ new_shape = (1,) * num_dims_to_add + source_data.shape
+ new_src_data = source_data.reshape(new_shape)

- data_min, data_max = np.nanmin(data), np.nanmax(data)
- target_shape = ty.shape
+ t, z, y, x = new_src_data.shape

- # Prepare the parameters for parallel execution
  params = []
- for ti in range(t):
- for zi in range(z):
- params.append((data4d[ti, zi], sx, sy, tx, ty, interpolation_method, data_min, data_max))
+ target_shape = target_y_coordinates.shape
+ for t_index in range(t):
+ for z_index in range(z):
+ params.append((new_src_data[t_index, z_index], origin_points, target_points, interpolation_method, target_shape))

  with PEx() as excutor:
  result = excutor.run(_interp_single_worker, params)

- result = np.array(result).reshape(t, z, *target_shape)
- result = np.squeeze(result)
- return result
+ return np.squeeze(np.array(result).reshape(t, z, *target_shape))
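
The rewritten data_interp_geo.py drops the RectBivariateSpline path entirely: longitudes are normalized to one convention, source points near a longitude seam are duplicated at ±360°, and both grids are mapped to 3-D Cartesian points on the unit sphere before interpolating. A self-contained sketch of why the unit-sphere mapping removes the dateline seam (synthetic field, illustrative names; nearest-neighbour is used here only to keep the demo short):

import numpy as np
from scipy.interpolate import NearestNDInterpolator

def lonlat2xyz(lon, lat):
    # Map degrees of lon/lat to points on the unit sphere.
    lon_rad, lat_rad = np.deg2rad(lon), np.deg2rad(lat)
    return np.stack([np.cos(lat_rad) * np.cos(lon_rad),
                     np.cos(lat_rad) * np.sin(lon_rad),
                     np.sin(lat_rad)], axis=-1)

# Coarse 5-degree global grid and a field that is continuous across the dateline.
lon, lat = np.meshgrid(np.arange(-180, 180, 5), np.arange(-90, 91, 5))
field = np.cos(np.deg2rad(lon)) * np.cos(np.deg2rad(lat))

interp = NearestNDInterpolator(lonlat2xyz(lon.ravel(), lat.ravel()), field.ravel())

# Two target points straddling the dateline: ~359.8 degrees apart in lon/lat
# space, but neighbours on the sphere, so no seam appears.
print(interp(lonlat2xyz(np.array([179.9, -179.9]), np.array([0.0, 0.0]))))

Both printed values are nearly identical because 179.9°E and 179.9°W map to adjacent xyz points, which is exactly the property the new worker relies on.
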
@@ -1,26 +1,10 @@
- #!/usr/bin/env python
- # coding=utf-8
- """
- Author: Liu Kun && 16031215@qq.com
- Date: 2025-03-30 11:16:29
- LastEditors: Liu Kun && 16031215@qq.com
- LastEditTime: 2025-04-25 14:23:10
- FilePath: \\Python\\My_Funcs\\OAFuncs\\oafuncs\\_script\\netcdf_merge.py
- Description
- EditPlatform: vscode
- ComputerInfo: XPS 15 9510
- SystemInfo: Windows 11
- Python Version: 3.12
- """
-
+ import logging
  import os
  from typing import List, Optional, Union

- import numpy as np
  import xarray as xr

  from oafuncs import pbar
- import logging


  def merge_nc(file_list: Union[str, List[str]], var_name: Optional[Union[str, List[str]]] = None, dim_name: Optional[str] = None, target_filename: Optional[str] = None) -> None:
@@ -41,7 +25,6 @@ def merge_nc(file_list: Union[str, List[str]], var_name: Optional[Union[str, Lis
  merge(file_list, var_name=['u', 'v'], dim_name='time', target_filename='merged.nc')
  merge(file_list, var_name=None, dim_name='time', target_filename='merged.nc')
  """
- from oafuncs._script.netcdf_write import save_to_nc

  if target_filename is None:
  target_filename = "merged.nc"
@@ -68,33 +51,68 @@ def merge_nc(file_list: Union[str, List[str]], var_name: Optional[Union[str, Lis
  # Initialize the dictionary of merged data
  merged_data = {}

- for i, file in pbar(enumerate(file_list), description="Reading files", total=len(file_list)):
+ for i, file in pbar(enumerate(file_list), "Reading files", total=len(file_list)):
  with xr.open_dataset(file) as ds:
  for var in var_names:
  data_var = ds[var]
  if dim_name in data_var.dims:
  merged_data.setdefault(var, []).append(data_var)
  elif var not in merged_data:
- # Check the dtype; fill time-like variables with NaT
- if np.issubdtype(data_var.dtype, np.datetime64):
- merged_data[var] = data_var.fillna(np.datetime64("NaT"))
- else:
- merged_data[var] = data_var.fillna(0)
-
- for var in pbar(merged_data, description="Merging variables"):
+ # Only merge here; NaN filling is left entirely to netcdf_write.py
+ merged_data[var] = data_var
+
+ # Record each variable's fill value and missing value so they are not lost
+ fill_values = {}
+ missing_values = {}
+ for var_name, var_data in merged_data.items():
+ if isinstance(var_data, list) and var_data:
+ # For variables to be concatenated, check the attributes of the first element
+ attrs = var_data[0].attrs
+ if "_FillValue" in attrs:
+ fill_values[var_name] = attrs["_FillValue"]
+ if "missing_value" in attrs:
+ missing_values[var_name] = attrs["missing_value"]
+ else:
+ # For a single variable, check its attributes directly
+ attrs = var_data.attrs if hasattr(var_data, "attrs") else {}
+ if "_FillValue" in attrs:
+ fill_values[var_name] = attrs["_FillValue"]
+ if "missing_value" in attrs:
+ missing_values[var_name] = attrs["missing_value"]
+
+ for var in pbar(merged_data, "Merging variables"):
  if isinstance(merged_data[var], list):
- # Check the dtype; fill time-like variables with NaT
- if np.issubdtype(merged_data[var][0].dtype, np.datetime64):
- merged_data[var] = xr.concat(merged_data[var], dim=dim_name).fillna(np.datetime64("NaT"))
- else:
- merged_data[var] = xr.concat(merged_data[var], dim=dim_name).fillna(0)
+ # Use compat='override' so attributes do not conflict during the merge
+ merged_data[var] = xr.concat(merged_data[var], dim=dim_name, compat="override")
+ # Restore the original fill-value and missing-value attributes
+ if var in fill_values:
+ merged_data[var].attrs["_FillValue"] = fill_values[var]
+ if var in missing_values:
+ merged_data[var].attrs["missing_value"] = missing_values[var]
+
+ # Build the Dataset after merging; at this point merged_data holds only data variables, no coordinate variables
+ merged_ds = xr.Dataset(merged_data)
+
+ # Automatically fill in coordinate variables (e.g. time, lat, lon), taking the first file as reference
+ with xr.open_dataset(file_list[0]) as ds0:
+ for coord in ds0.coords:
+ # Make sure coordinate variables are not overwritten and keep consistent dtypes and attributes
+ if coord not in merged_ds.coords:
+ merged_ds = merged_ds.assign_coords({coord: ds0[coord]})
+
+ # If the merge dimension is a coordinate, check that it is consistent across all files
+ if dim_name in merged_ds.coords and len(file_list) > 1:
+ logging.info(f"Verifying the consistency of merge dimension {dim_name} ...")
+ for file in file_list[1:]:
+ with xr.open_dataset(file) as ds:
+ if dim_name in ds.coords and not ds[dim_name].equals(merged_ds[dim_name]):
+ logging.warning(f"The {dim_name} coordinate in file {file} differs from the merged data and may distort the result")

  if os.path.exists(target_filename):
- # print("Warning: The target file already exists. Removing it ...")
  logging.warning("The target file already exists. Removing it ...")
  os.remove(target_filename)

- save_to_nc(target_filename, xr.Dataset(merged_data))
+ merged_ds.to_netcdf(target_filename, mode="w")


  # Example usage
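
The reworked merge_nc no longer zero-fills variables; it records `_FillValue`/`missing_value` before `xr.concat` and restores them afterwards, then writes with `Dataset.to_netcdf`. A minimal in-memory sketch of that preservation pattern (dataset names and values are illustrative, not from the package):

import numpy as np
import xarray as xr

# Two file-like datasets sharing a variable that carries a _FillValue attribute.
a = xr.Dataset({"u": ("time", np.array([1.0, 2.0]), {"_FillValue": -999.0})})
b = xr.Dataset({"u": ("time", np.array([3.0, 4.0]), {"_FillValue": -999.0})})

pieces = [a["u"], b["u"]]
fill = pieces[0].attrs.get("_FillValue")  # record before concat

merged = xr.concat(pieces, dim="time", compat="override")
if fill is not None:
    merged.attrs["_FillValue"] = fill  # restore after concat

ds = xr.Dataset({"u": merged})
print(ds["u"].attrs)  # {'_FillValue': -999.0}

Restoring the attribute explicitly is defensive: xarray's concat keeps the first element's attrs by default, but the explicit round trip makes the behaviour independent of that default.
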