oafuncs 0.0.98.22__py3-none-any.whl → 0.0.98.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,167 +1,172 @@
1
+ import importlib.util
1
2
  from typing import List, Union
2
3
 
3
4
  import numpy as np
4
- from scipy.interpolate import NearestNDInterpolator, griddata
5
5
 
6
6
  from oafuncs.oa_tool import PEx
7
7
 
8
+ # 检查pyinterp是否可用
9
+ pyinterp_available = importlib.util.find_spec("pyinterp") is not None
8
10
 
9
- def _normalize_lon(lon, ref_lon):
10
- """
11
- 将经度数组 lon 归一化到与 ref_lon 相同的区间([-180,180] 或 [0,360])
12
- 并在经度分界(如180/-180, 0/360)附近自动拓宽,避免插值断裂。
13
- """
14
- lon = np.asarray(lon)
15
- ref_lon = np.asarray(ref_lon)
16
- if np.nanmax(ref_lon) > 180:
17
- lon = np.where(lon < 0, lon + 360, lon)
18
- else:
19
- lon = np.where(lon > 180, lon - 360, lon)
20
- return lon
11
+ if pyinterp_available:
12
+ import pyinterp
13
+ import pyinterp.backends.xarray as pyxr
14
+ import xarray as xr
21
15
 
22
16
 
23
- def _expand_lonlat_for_dateline(points, values):
17
+ def _fill_nan_with_nearest(data: np.ndarray, source_lons: np.ndarray, source_lats: np.ndarray) -> np.ndarray:
24
18
  """
25
- 对经度分界(如180/-180, 0/360)附近的数据进行拓宽,避免插值断裂。
26
- points: (N,2) [lon,lat]
27
- values: (N,)
28
- 返回拓宽后的 points, values
19
+ 使用最近邻方法填充NaN值,适合地理数据。
29
20
  """
30
- lon = points[:, 0]
31
- lat = points[:, 1]
32
- expanded_points = [points]
33
- expanded_values = [values]
34
- if (np.nanmax(lon) > 170) and (np.nanmin(lon) < -170):
35
- expanded_points.append(np.column_stack((lon + 360, lat)))
36
- expanded_points.append(np.column_stack((lon - 360, lat)))
37
- expanded_values.append(values)
38
- expanded_values.append(values)
39
- if (np.nanmax(lon) > 350) and (np.nanmin(lon) < 10):
40
- expanded_points.append(np.column_stack((lon - 360, lat)))
41
- expanded_points.append(np.column_stack((lon + 360, lat)))
42
- expanded_values.append(values)
43
- expanded_values.append(values)
44
- points_new = np.vstack(expanded_points)
45
- values_new = np.concatenate(expanded_values)
46
- return points_new, values_new
21
+ if not np.isnan(data).any():
22
+ return data
47
23
 
24
+ # 创建掩码,区分有效值和NaN值
25
+ mask = ~np.isnan(data)
26
+ if not np.any(mask):
27
+ return data # 全是NaN,无法填充
48
28
 
49
- def _interp_single_worker(*args):
50
- """
51
- 用于PEx并行的单slice插值worker。
52
- 参数: data_slice, origin_points, target_points, interpolation_method, target_shape
53
- 球面插值:经纬度转球面坐标后插值
54
- """
55
- data_slice, origin_points, target_points, interpolation_method, target_shape = args
29
+ # 使用pyinterp的RTree进行最近邻插值填充NaN
30
+ try:
31
+ if not pyinterp_available:
32
+ raise ImportError("pyinterp not available")
33
+
34
+ # 获取有效数据点的位置和值
35
+ valid_points = np.column_stack((source_lons[mask].ravel(), source_lats[mask].ravel()))
36
+ valid_values = data[mask].ravel()
56
37
 
57
- # 经纬度归一化
58
- origin_points = origin_points.copy()
59
- target_points = target_points.copy()
60
- origin_points[:, 0] = _normalize_lon(origin_points[:, 0], target_points[:, 0])
61
- target_points[:, 0] = _normalize_lon(target_points[:, 0], origin_points[:, 0])
38
+ # 创建RTree
39
+ tree = pyinterp.RTree()
40
+ tree.insert(valid_points.astype(np.float64), valid_values.astype(np.float64))
62
41
 
63
- def lonlat2xyz(lon, lat):
64
- lon_rad = np.deg2rad(lon)
65
- lat_rad = np.deg2rad(lat)
66
- x = np.cos(lat_rad) * np.cos(lon_rad)
67
- y = np.cos(lat_rad) * np.sin(lon_rad)
68
- z = np.sin(lat_rad)
69
- return np.stack([x, y, z], axis=-1)
42
+ # 获取所有点的坐标
43
+ all_points = np.column_stack((source_lons.ravel(), source_lats.ravel()))
70
44
 
71
- # 过滤掉包含 NaN 的点
72
- valid_mask = ~np.isnan(data_slice.ravel())
73
- valid_data = data_slice.ravel()[valid_mask]
74
- valid_points = origin_points[valid_mask]
45
+ # 最近邻插值
46
+ filled_values = tree.query(all_points[:, 0], all_points[:, 1], k=1)
75
47
 
76
- if len(valid_data) < 10:
77
- return np.full(target_shape, np.nanmean(data_slice))
48
+ return filled_values.reshape(data.shape)
78
49
 
79
- # 拓宽经度分界,避免如179/-181插值断裂
80
- valid_points_exp, valid_data_exp = _expand_lonlat_for_dateline(valid_points, valid_data)
50
+ except Exception:
51
+ # 备选方案:使用scipy的最近邻
52
+ from scipy.interpolate import NearestNDInterpolator
81
53
 
82
- valid_xyz = lonlat2xyz(valid_points_exp[:, 0], valid_points_exp[:, 1])
83
- target_xyz = lonlat2xyz(target_points[:, 0], target_points[:, 1])
54
+ points = np.column_stack((source_lons[mask].ravel(), source_lats[mask].ravel()))
55
+ values = data[mask].ravel()
84
56
 
85
- # 使用 griddata 的 cubic 插值以获得更好平滑效果
86
- result = griddata(valid_xyz, valid_data_exp, target_xyz, method=interpolation_method).reshape(target_shape)
57
+ if len(values) > 0:
58
+ interp = NearestNDInterpolator(points, values)
59
+ return interp(source_lons.ravel(), source_lats.ravel()).reshape(data.shape)
60
+ else:
61
+ return data # 无有效值可用于填充
87
62
 
88
- # 用最近邻处理残余 NaN
89
- if np.isnan(result).any():
90
- nn_interp = NearestNDInterpolator(valid_xyz, valid_data_exp)
91
- nn = nn_interp(target_xyz).reshape(target_shape)
92
- result[np.isnan(result)] = nn[np.isnan(result)]
63
+
64
+ def _interp_single_worker(*args):
65
+ """
66
+ 单slice插值worker,只使用pyinterp的bicubic方法,失败直接报错。
67
+ 参数: data_slice, source_lons, source_lats, target_lons, target_lats
68
+ """
69
+ if not pyinterp_available:
70
+ raise ImportError("pyinterp package is required for geographic interpolation")
71
+
72
+ data_slice, source_lons, source_lats, target_lons, target_lats = args
73
+
74
+ # 预处理:填充NaN值以确保数据完整
75
+ if np.isnan(data_slice).any():
76
+ data_filled = _fill_nan_with_nearest(data_slice, source_lons, source_lats)
77
+ else:
78
+ data_filled = data_slice
79
+
80
+ # 创建xarray DataArray
81
+ da = xr.DataArray(
82
+ data_filled,
83
+ coords={"latitude": source_lats, "longitude": source_lons},
84
+ dims=("latitude", "longitude"),
85
+ )
86
+
87
+ # 创建Grid2D对象
88
+ grid = pyxr.Grid2D(da)
89
+
90
+ # 使用bicubic方法插值
91
+ result = grid.bicubic(coords={"longitude": target_lons.ravel(), "latitude": target_lats.ravel()}, bounds_error=False, num_threads=1).reshape(target_lons.shape)
93
92
 
94
93
  return result
95
94
 
96
95
 
97
- def interp_2d_func_geo(target_x_coordinates: Union[np.ndarray, List[float]], target_y_coordinates: Union[np.ndarray, List[float]], source_x_coordinates: Union[np.ndarray, List[float]], source_y_coordinates: Union[np.ndarray, List[float]], source_data: np.ndarray, interpolation_method: str = "cubic") -> np.ndarray:
96
+ def interp_2d_func_geo(
97
+ target_x_coordinates: Union[np.ndarray, List[float]],
98
+ target_y_coordinates: Union[np.ndarray, List[float]],
99
+ source_x_coordinates: Union[np.ndarray, List[float]],
100
+ source_y_coordinates: Union[np.ndarray, List[float]],
101
+ source_data: np.ndarray,
102
+ ) -> np.ndarray:
98
103
  """
99
- Perform 2D interpolation on the last two dimensions of a multi-dimensional array (spherical coordinates).
100
- 使用球面坐标系进行插值,适用于全球尺度的地理数据,能正确处理经度跨越日期线的情况。
104
+ 使用pyinterp进行地理插值,只使用bicubic方法。
101
105
 
102
106
  Args:
103
- target_x_coordinates (Union[np.ndarray, List[float]]): Target grid's longitude (-180 to 180 or 0 to 360).
104
- target_y_coordinates (Union[np.ndarray, List[float]]): Target grid's latitude (-90 to 90).
105
- source_x_coordinates (Union[np.ndarray, List[float]]): Original grid's longitude (-180 to 180 or 0 to 360).
106
- source_y_coordinates (Union[np.ndarray, List[float]]): Original grid's latitude (-90 to 90).
107
- source_data (np.ndarray): Multi-dimensional array with the last two dimensions as spatial.
108
- interpolation_method (str, optional): Interpolation method. Defaults to "cubic".
109
- >>> optional: 'linear', 'nearest', 'cubic', 'quintic', etc.
107
+ target_x_coordinates: 目标点经度 (-180 to 180 0 to 360)
108
+ target_y_coordinates: 目标点纬度 (-90 to 90)
109
+ source_x_coordinates: 源数据经度 (-180 to 180 0 to 360)
110
+ source_y_coordinates: 源数据纬度 (-90 to 90)
111
+ source_data: 多维数组,最后两个维度为空间维度
110
112
 
111
113
  Returns:
112
- np.ndarray: Interpolated data array.
113
-
114
- Raises:
115
- ValueError: If input shapes are invalid.
116
-
117
- Examples:
118
- >>> # 创建一个全球网格示例
119
- >>> target_lon = np.arange(-180, 181, 1) # 1度分辨率目标网格
120
- >>> target_lat = np.arange(-90, 91, 1)
121
- >>> source_lon = np.arange(-180, 181, 5) # 5度分辨率源网格
122
- >>> source_lat = np.arange(-90, 91, 5)
123
- >>> # 创建一个简单的数据场 (例如温度场)
124
- >>> source_data = np.cos(np.deg2rad(source_lat.reshape(-1, 1))) * np.cos(np.deg2rad(source_lon))
125
- >>> # 插值到高分辨率网格
126
- >>> result = interp_2d_geo(target_lon, target_lat, source_lon, source_lat, source_data)
127
- >>> print(result.shape) # Expected output: (181, 361)
114
+ np.ndarray: 插值后的数据数组
128
115
  """
129
- # 验证输入数据范围
116
+ if not pyinterp_available:
117
+ raise ImportError("pyinterp package is required for geographic interpolation")
118
+
119
+ # 验证纬度范围
130
120
  if np.nanmin(target_y_coordinates) < -90 or np.nanmax(target_y_coordinates) > 90:
131
- raise ValueError("[red]Target latitude must be in range [-90, 90].[/red]")
121
+ raise ValueError("Target latitude must be in range [-90, 90].")
132
122
  if np.nanmin(source_y_coordinates) < -90 or np.nanmax(source_y_coordinates) > 90:
133
- raise ValueError("[red]Source latitude must be in range [-90, 90].[/red]")
123
+ raise ValueError("Source latitude must be in range [-90, 90].")
134
124
 
135
- if len(target_y_coordinates.shape) == 1:
136
- target_x_coordinates, target_y_coordinates = np.meshgrid(target_x_coordinates, target_y_coordinates)
137
- if len(source_y_coordinates.shape) == 1:
125
+ # 确保使用numpy数组
126
+ source_x_coordinates = np.array(source_x_coordinates)
127
+ source_y_coordinates = np.array(source_y_coordinates)
128
+ target_x_coordinates = np.array(target_x_coordinates)
129
+ target_y_coordinates = np.array(target_y_coordinates)
130
+
131
+ # 创建网格坐标(如果是一维的)
132
+ if source_x_coordinates.ndim == 1:
138
133
  source_x_coordinates, source_y_coordinates = np.meshgrid(source_x_coordinates, source_y_coordinates)
134
+ if target_x_coordinates.ndim == 1:
135
+ target_x_coordinates, target_y_coordinates = np.meshgrid(target_x_coordinates, target_y_coordinates)
139
136
 
137
+ # 验证源数据形状
140
138
  if source_x_coordinates.shape != source_data.shape[-2:] or source_y_coordinates.shape != source_data.shape[-2:]:
141
- raise ValueError("[red]Shape of source_data does not match shape of source_x_coordinates or source_y_coordinates.[/red]")
142
-
143
- target_points = np.column_stack((np.array(target_x_coordinates).ravel(), np.array(target_y_coordinates).ravel()))
144
- origin_points = np.column_stack((np.array(source_x_coordinates).ravel(), np.array(source_y_coordinates).ravel()))
139
+ raise ValueError("Shape of source_data does not match shape of source_x_coordinates or source_y_coordinates.")
145
140
 
146
- data_dims = len(source_data.shape)
141
+ # 处理多维数据
142
+ data_dims = source_data.ndim
147
143
  if data_dims < 2:
148
- raise ValueError(f"[red]Source data must have at least 2 dimensions, but got {data_dims}.[/red]")
144
+ raise ValueError(f"Source data must have at least 2 dimensions, but got {data_dims}.")
149
145
  elif data_dims > 4:
150
- raise ValueError(f"Source data has {data_dims} dimensions, but this function currently supports only up to 4.")
146
+ raise ValueError(f"Source data has {data_dims} dimensions, but this function currently supports up to 4.")
151
147
 
148
+ # 扩展到4D
152
149
  num_dims_to_add = 4 - data_dims
153
- new_shape = (1,) * num_dims_to_add + source_data.shape
154
- new_src_data = source_data.reshape(new_shape)
155
-
156
- t, z, y, x = new_src_data.shape
150
+ source_data = source_data.reshape((1,) * num_dims_to_add + source_data.shape)
151
+ t, z, y, x = source_data.shape
157
152
 
153
+ # 准备并行处理参数
158
154
  params = []
159
- target_shape = target_y_coordinates.shape
160
155
  for t_index in range(t):
161
156
  for z_index in range(z):
162
- params.append((new_src_data[t_index, z_index], origin_points, target_points, interpolation_method, target_shape))
163
-
164
- with PEx() as excutor:
165
- result = excutor.run(_interp_single_worker, params)
166
-
167
- return np.squeeze(np.array(result).reshape(t, z, *target_shape))
157
+ params.append(
158
+ (
159
+ source_data[t_index, z_index],
160
+ source_x_coordinates[0, :], # 假设经度在每行都相同
161
+ source_y_coordinates[:, 0], # 假设纬度在每列都相同
162
+ target_x_coordinates,
163
+ target_y_coordinates,
164
+ )
165
+ )
166
+
167
+ # 并行执行插值
168
+ with PEx() as executor:
169
+ results = executor.run(_interp_single_worker, params)
170
+
171
+ # 还原到原始维度
172
+ return np.squeeze(np.array(results).reshape((t, z) + target_x_coordinates.shape))
@@ -263,10 +263,45 @@ def process_variable(var: str, data: xr.DataArray, dims: int, dims_name: Tuple[s
263
263
  print(f"Error processing {var}_{dims_name[0]}-{i}_{dims_name[1]}-{j}: {e}")
264
264
 
265
265
 
266
- def func_plot_dataset(ds_in: Union[xr.Dataset, xr.DataArray], output_dir: str, xyzt_dims: Tuple[str, str, str, str] = ("longitude", "latitude", "level", "time"), plot_type: str = "contourf", fixed_colorscale: bool = False) -> None:
266
+ def get_xyzt_names(ds_in, xyzt_dims):
267
+ dims_dict = {
268
+ "x": ["longitude", "lon", "x", "lon_rho", "lon_u", "lon_v", "xi_rho", "xi_u", "xi_v",
269
+ "xc", "x_rho", "xlon", "nlon", "east_west", "i", "xh", "xq", "nav_lon"],
270
+ "y": ["latitude", "lat", "y", "lat_rho", "lat_u", "lat_v", "eta_rho", "eta_u", "eta_v",
271
+ "yc", "y_rho", "ylat", "nlat", "north_south", "j", "yh", "yq", "nav_lat"],
272
+ "z": ["level", "lev", "z", "depth", "height", "pressure", "s_rho", "s_w",
273
+ "altitude", "plev", "isobaric", "vertical", "k", "sigma", "hybrid", "theta",
274
+ "pres", "sigma_level", "z_rho", "z_w", "layers", "deptht", "nav_lev"],
275
+ "t": ["time", "t", "ocean_time", "bry_time", 'frc_time',
276
+ "time_counter", "Time", "Julian_day", "forecast_time", "clim_time", "model_time"],
277
+ }
278
+ if xyzt_dims is not None:
279
+ x_dim, y_dim, z_dim, t_dim = xyzt_dims
280
+ return x_dim, y_dim, z_dim, t_dim
281
+ data_dim_names = ds_in.dims
282
+ x_dim, y_dim, z_dim, t_dim = None, None, None, None
283
+ for dim in dims_dict['x']:
284
+ if dim in data_dim_names:
285
+ x_dim = dim
286
+ break
287
+ for dim in dims_dict['y']:
288
+ if dim in data_dim_names:
289
+ y_dim = dim
290
+ break
291
+ for dim in dims_dict['z']:
292
+ if dim in data_dim_names:
293
+ z_dim = dim
294
+ break
295
+ for dim in dims_dict['t']:
296
+ if dim in data_dim_names:
297
+ t_dim = dim
298
+ break
299
+ return x_dim, y_dim, z_dim, t_dim
300
+
301
+
302
+ def func_plot_dataset(ds_in: Union[xr.Dataset, xr.DataArray], output_dir: str, xyzt_dims: Tuple[str, str, str, str] = None, plot_type: str = "contourf", fixed_colorscale: bool = False) -> None:
267
303
  """Plot variables from a NetCDF file and save the plots to the specified directory."""
268
304
  os.makedirs(output_dir, exist_ok=True)
269
- x_dim, y_dim, z_dim, t_dim = xyzt_dims
270
305
 
271
306
  # Main processing function
272
307
  try:
@@ -277,9 +312,11 @@ def func_plot_dataset(ds_in: Union[xr.Dataset, xr.DataArray], output_dir: str, x
277
312
  var = ds_in.name if ds_in.name is not None else "unnamed_variable"
278
313
  print("=" * 120)
279
314
  print(f"Processing: {var}")
315
+
280
316
  try:
281
317
  dims = len(ds_in.shape)
282
318
  dims_name = ds_in.dims
319
+ x_dim, y_dim, z_dim, t_dim = get_xyzt_names(ds_in, xyzt_dims)
283
320
  process_variable(var, ds_in, dims, dims_name, output_dir, x_dim, y_dim, z_dim, t_dim, fixed_colorscale, plot_type)
284
321
  except Exception as e:
285
322
  print(f"Error processing variable {var}: {e}")
@@ -295,6 +332,7 @@ def func_plot_dataset(ds_in: Union[xr.Dataset, xr.DataArray], output_dir: str, x
295
332
  data = ds[var]
296
333
  dims = len(data.shape)
297
334
  dims_name = data.dims
335
+ x_dim, y_dim, z_dim, t_dim = get_xyzt_names(data, xyzt_dims)
298
336
  try:
299
337
  process_variable(var, data, dims, dims_name, output_dir, x_dim, y_dim, z_dim, t_dim, fixed_colorscale, plot_type)
300
338
  except Exception as e:
oafuncs/oa_data.py CHANGED
@@ -13,7 +13,6 @@ SystemInfo: Windows 11
13
13
  Python Version: 3.11
14
14
  """
15
15
 
16
-
17
16
  from typing import Any, List, Union
18
17
 
19
18
  import numpy as np
@@ -22,7 +21,6 @@ import xarray as xr
22
21
  from rich import print
23
22
  from scipy.interpolate import interp1d
24
23
 
25
-
26
24
  __all__ = ["interp_along_dim", "interp_2d", "interp_2d_geo", "ensure_list", "mask_shapefile"]
27
25
 
28
26
 
@@ -152,7 +150,7 @@ def interp_2d(
152
150
  >>> print(result.shape) # Expected output: (3, 3)
153
151
  """
154
152
  from ._script.data_interp import interp_2d_func
155
-
153
+
156
154
  return interp_2d_func(
157
155
  target_x_coordinates=target_x_coordinates,
158
156
  target_y_coordinates=target_y_coordinates,
@@ -162,47 +160,73 @@ def interp_2d(
162
160
  interpolation_method=interpolation_method,
163
161
  )
164
162
 
165
- def interp_2d_geo(target_x_coordinates: Union[np.ndarray, List[float]], target_y_coordinates: Union[np.ndarray, List[float]], source_x_coordinates: Union[np.ndarray, List[float]], source_y_coordinates: Union[np.ndarray, List[float]], source_data: np.ndarray, interpolation_method: str = "cubic") -> np.ndarray:
163
+
164
+ def interp_2d_geo(
165
+ target_x_coordinates: Union[np.ndarray, List[float]],
166
+ target_y_coordinates: Union[np.ndarray, List[float]],
167
+ source_x_coordinates: Union[np.ndarray, List[float]],
168
+ source_y_coordinates: Union[np.ndarray, List[float]],
169
+ source_data: np.ndarray,
170
+ ) -> np.ndarray:
166
171
  """
167
- Perform 2D interpolation on the last two dimensions of a multi-dimensional array (spherical coordinates).
168
- 使用球面坐标系进行插值,适用于全球尺度的地理数据,能正确处理经度跨越日期线的情况。
172
+ 使用pyinterp进行地理插值,适用于全球尺度的地理数据与区域数据。
173
+
174
+ 特点:
175
+ - 正确处理经度跨越日期线的情况
176
+ - 自动选择最佳插值策略
177
+ - 处理规则网格和非规则数据
178
+ - 支持多维数据并行处理
169
179
 
170
180
  Args:
171
- target_x_coordinates (Union[np.ndarray, List[float]]): Target grid's longitude (-180 to 180 or 0 to 360).
172
- target_y_coordinates (Union[np.ndarray, List[float]]): Target grid's latitude (-90 to 90).
173
- source_x_coordinates (Union[np.ndarray, List[float]]): Original grid's longitude (-180 to 180 or 0 to 360).
174
- source_y_coordinates (Union[np.ndarray, List[float]]): Original grid's latitude (-90 to 90).
175
- source_data (np.ndarray): Multi-dimensional array with the last two dimensions as spatial.
176
- interpolation_method (str, optional): Interpolation method. Defaults to "cubic".
177
- >>> optional: 'linear', 'nearest', 'cubic', 'quintic', etc.
181
+ target_x_coordinates: 目标点经度 (-180 to 180 0 to 360)
182
+ target_y_coordinates: 目标点纬度 (-90 to 90)
183
+ source_x_coordinates: 源数据经度 (-180 to 180 0 to 360)
184
+ source_y_coordinates: 源数据纬度 (-90 to 90)
185
+ source_data: 多维数组,最后两个维度为空间维度
186
+ interpolation_method: 插值方法: 只会使用 'bicubic' 方法。
178
187
 
179
188
  Returns:
180
- np.ndarray: Interpolated data array.
181
-
182
- Raises:
183
- ValueError: If input shapes are invalid.
189
+ np.ndarray: 插值后的数据数组
184
190
 
185
191
  Examples:
186
- >>> # 创建一个全球网格示例
187
- >>> target_lon = np.arange(-180, 181, 1) # 1度分辨率目标网格
192
+ >>> # 全球数据插值示例
193
+ >>> target_lon = np.arange(-180, 181, 1)
188
194
  >>> target_lat = np.arange(-90, 91, 1)
189
- >>> source_lon = np.arange(-180, 181, 5) # 5度分辨率源网格
195
+ >>> source_lon = np.arange(-180, 181, 5)
190
196
  >>> source_lat = np.arange(-90, 91, 5)
191
- >>> # 创建一个简单的数据场 (例如温度场)
192
197
  >>> source_data = np.cos(np.deg2rad(source_lat.reshape(-1, 1))) * np.cos(np.deg2rad(source_lon))
193
- >>> # 插值到高分辨率网格
194
198
  >>> result = interp_2d_geo(target_lon, target_lat, source_lon, source_lat, source_data)
195
- >>> print(result.shape) # Expected output: (181, 361)
196
199
  """
197
- from ._script.data_interp_geo import interp_2d_func_geo
198
- interp_2d_func_geo(
199
- target_x_coordinates=target_x_coordinates,
200
- target_y_coordinates=target_y_coordinates,
201
- source_x_coordinates=source_x_coordinates,
202
- source_y_coordinates=source_y_coordinates,
203
- source_data=source_data,
204
- interpolation_method=interpolation_method,
205
- )
200
+ # 使用importlib检查pyinterp是否可用,避免直接import导致的警告
201
+ import importlib.util
202
+
203
+ pyinterp_available = importlib.util.find_spec("pyinterp") is not None
204
+
205
+ if pyinterp_available:
206
+ # 只在pyinterp可用时才导入相关模块
207
+ from ._script.data_interp_geo import interp_2d_func_geo
208
+
209
+ return interp_2d_func_geo(
210
+ target_x_coordinates=target_x_coordinates,
211
+ target_y_coordinates=target_y_coordinates,
212
+ source_x_coordinates=source_x_coordinates,
213
+ source_y_coordinates=source_y_coordinates,
214
+ source_data=source_data,
215
+ )
216
+ else:
217
+ print("[yellow]警告: pyinterp模块未安装,无法使用球面坐标插值。尝试使用平面插值作为备选方案。[/yellow]")
218
+ print("[yellow]推荐使用 pip install pyinterp 安装pyinterp以获得更准确的地理数据插值结果。[/yellow]")
219
+ try:
220
+ return interp_2d(
221
+ target_x_coordinates=target_x_coordinates,
222
+ target_y_coordinates=target_y_coordinates,
223
+ source_x_coordinates=source_x_coordinates,
224
+ source_y_coordinates=source_y_coordinates,
225
+ source_data=source_data,
226
+ )
227
+ except Exception as e:
228
+ raise ImportError(f"pyinterp不可用且备选插值方法也失败: {e}")
229
+
206
230
 
207
231
  def mask_shapefile(
208
232
  data_array: np.ndarray,
oafuncs/oa_nc.py CHANGED
@@ -246,7 +246,7 @@ def draw(
246
246
  output_directory: Optional[str] = None,
247
247
  dataset: Optional[xr.Dataset] = None,
248
248
  file_path: Optional[str] = None,
249
- dimensions: Union[List[str], Tuple[str, str, str, str]] = ("longitude", "latitude", "level", "time"),
249
+ dims_xyzt: Union[List[str], Tuple[str, str, str, str]] = None,
250
250
  plot_style: str = "contourf",
251
251
  use_fixed_colorscale: bool = False,
252
252
  ) -> None:
@@ -257,7 +257,7 @@ def draw(
257
257
  output_directory (Optional[str]): Path of the output directory.
258
258
  dataset (Optional[xr.Dataset]): Xarray dataset to plot.
259
259
  file_path (Optional[str]): Path to the NetCDF file.
260
- dimensions (Union[List[str], Tuple[str, str, str, str]]): Dimensions for plotting.
260
+ dims_xyzt (Union[List[str], Tuple[str, str, str, str]]): Dimensions for plotting. xyzt
261
261
  plot_style (str): Type of the plot, e.g., "contourf" or "contour". Default is "contourf".
262
262
  use_fixed_colorscale (bool): Whether to use a fixed colorscale. Default is False.
263
263
 
@@ -268,15 +268,15 @@ def draw(
268
268
 
269
269
  if output_directory is None:
270
270
  output_directory = os.getcwd()
271
- if not isinstance(dimensions, (list, tuple)):
271
+ if not isinstance(dims_xyzt, (list, tuple)):
272
272
  raise ValueError("dimensions must be a list or tuple")
273
273
 
274
274
  if dataset is not None:
275
- func_plot_dataset(dataset, output_directory, tuple(dimensions), plot_style, use_fixed_colorscale)
275
+ func_plot_dataset(dataset, output_directory, tuple(dims_xyzt), plot_style, use_fixed_colorscale)
276
276
  elif file_path is not None:
277
277
  if check(file_path):
278
278
  ds = xr.open_dataset(file_path)
279
- func_plot_dataset(ds, output_directory, tuple(dimensions), plot_style, use_fixed_colorscale)
279
+ func_plot_dataset(ds, output_directory, tuple(dims_xyzt), plot_style, use_fixed_colorscale)
280
280
  else:
281
281
  print(f"[red]Invalid file: {file_path}[/red]")
282
282
  else:
@@ -288,10 +288,11 @@ def compress(src_path, dst_path=None,convert_dtype='int16'):
288
288
  压缩 NetCDF 文件,使用 scale_factor/add_offset 压缩数据。
289
289
  若 dst_path 省略,则自动生成新文件名,写出后删除原文件并将新文件改回原名。
290
290
  """
291
+ src_path = str(src_path)
291
292
  # 判断是否要替换原文件
292
293
  delete_orig = dst_path is None
293
294
  if delete_orig:
294
- dst_path = src_path.replace(".nc", "_compress.nc")
295
+ dst_path = src_path.replace(".nc", "_compress_temp.nc")
295
296
 
296
297
  ds = xr.open_dataset(src_path)
297
298
  save(dst_path, ds, convert_dtype=convert_dtype, use_scale_offset=True, use_compression=True)
@@ -313,10 +314,11 @@ def unscale(src_path, dst_path=None, compression_level=4):
313
314
  dst_path: 目标文件路径,None则替换原文件
314
315
  compression_level: 压缩级别(1-9),数值越大压缩比越高,速度越慢
315
316
  """
317
+ src_path = str(src_path)
316
318
  # 判断是否要替换原文件
317
319
  delete_orig = dst_path is None
318
320
  if delete_orig:
319
- dst_path = src_path.replace(".nc", "_unpacked.nc")
321
+ dst_path = src_path.replace(".nc", "_unpacked_temp.nc")
320
322
 
321
323
  # 打开原始文件,获取文件大小
322
324
  orig_size = os.path.getsize(src_path) / (1024 * 1024) # MB
oafuncs/oa_tool.py CHANGED
@@ -135,7 +135,7 @@ def email(title: str = "Title", content: Optional[str] = None, send_to: str = "1
135
135
 
136
136
  def pbar(
137
137
  iterable: Iterable = range(100),
138
- description: str = "Working",
138
+ description: str = None,
139
139
  total: Optional[float] = None,
140
140
  completed: float = 0,
141
141
  color: Any = "None",
@@ -162,21 +162,9 @@ def pbar(
162
162
 
163
163
  Returns:
164
164
  Any: An instance of ColorProgressBar.
165
-
166
- Example:
167
- >>> for i in pbar(range(10), description="Processing"):
168
- ... time.sleep(0.1)
169
- >>> for i in pbar(range(10), description="Processing", color="green"):
170
- ... time.sleep(0.1)
171
- >>> for i in pbar(range(10), description="Processing", cmap=["red", "green"]):
172
- ... time.sleep(0.1)
173
- >>> for i in pbar(range(10), description="Processing", cmap="viridis"):
174
- ... time.sleep(0.1)
175
165
  """
176
166
  from ._script.cprogressbar import ColorProgressBar
177
167
  import random
178
-
179
- # number = random.randint(1, 999)
180
168
 
181
169
  def _generate_random_color_hex():
182
170
  """Generate a random color in hexadecimal format."""
@@ -188,11 +176,10 @@ def pbar(
188
176
  if color == 'None' and cmap is None:
189
177
  color = _generate_random_color_hex()
190
178
 
191
- style = f"bold {color if color != 'None' else 'green'}"
192
- # print(f"[{style}]~*^* {description} *^*~ -> {number:03d}[/{style}]")
193
- print(f"[{style}]~*^* {description} *^*~[/{style}]")
179
+ if description is not None:
180
+ style = f"bold {color if color != 'None' else 'green'}"
181
+ print(f"[{style}]~*^* {description} *^*~[/{style}]")
194
182
 
195
- # description=f'{number:03d}'
196
183
  description = ""
197
184
 
198
185
  return ColorProgressBar(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: oafuncs
3
- Version: 0.0.98.22
3
+ Version: 0.0.98.24
4
4
  Summary: Oceanic and Atmospheric Functions
5
5
  Home-page: https://github.com/Industry-Pays/OAFuncs
6
6
  Author: Kun Liu
@@ -1,25 +1,25 @@
1
1
  oafuncs/__init__.py,sha256=T_-VtnWWllV3Q91twT5Yt2sUapeA051QbPNnBxmg9nw,1456
2
2
  oafuncs/oa_cmap.py,sha256=pUFAGzbIg0WLxObBP2t_--ZIg00Dxdojx0y7OjTeqEo,11551
3
- oafuncs/oa_data.py,sha256=Aat9ktxxRGevaqQya3IJWfXeoEs-FCXGUcNE2pKnzfU,10931
3
+ oafuncs/oa_data.py,sha256=QiIDwAy0Gqvv-ulWFcMk0nND81GU3Cf_xgGJtJ7p2mc,11397
4
4
  oafuncs/oa_date.py,sha256=WhM6cyD4G3IeghjLTHhAMtlvJbA7kwQG2sHnxdTgyso,6303
5
5
  oafuncs/oa_draw.py,sha256=IaBGDx-EOxyMM2IuJ4zLZt6ruHHV5qFStPItmUOXoWk,17635
6
6
  oafuncs/oa_file.py,sha256=j9gXJgPOJsliu4IOUc4bc-luW4yBvQyNCEmMyDVjUwQ,16404
7
7
  oafuncs/oa_help.py,sha256=_4AZgRDq5Or0vauNvq5IDDHIBoBfdOQtzak-mG1wwAw,4537
8
- oafuncs/oa_nc.py,sha256=UUXnBg2cO5XiJ8w0jNqCZJg83FVKqxlEHxOJG5o08Z8,15201
8
+ oafuncs/oa_nc.py,sha256=pxTyR8f2mlu1Zkz4PJ5ImOyhrFD_mgytXHJjt9ohnUw,15233
9
9
  oafuncs/oa_python.py,sha256=NkopwkYFGSEuVljnTBvXCl6o2CeyRNBqRXSsUl3euEE,5192
10
- oafuncs/oa_tool.py,sha256=QBjJh3pf54yXVuOmu97rW6Tsr6uNMyZ5KqZbR4VQFTc,8628
10
+ oafuncs/oa_tool.py,sha256=Zuaoa92wll0YqXGRf0oF_c7wlATtl7bvjCuLt9VLXp0,8046
11
11
  oafuncs/_data/hycom.png,sha256=MadKs6Gyj5n9-TOu7L4atQfTXtF9dvN9w-tdU9IfygI,10945710
12
12
  oafuncs/_data/oafuncs.png,sha256=o3VD7wm-kwDea5E98JqxXl04_78cBX7VcdUt7uQXGiU,3679898
13
13
  oafuncs/_script/cprogressbar.py,sha256=UIgGcLFs-6IgWlITuBLaQqrpt4OAK3Mst5RlCiNfZdQ,15772
14
14
  oafuncs/_script/data_interp.py,sha256=EiZbt6n5BEaRKcng88UgX7TFPhKE6TLVZniS01awXjg,5146
15
- oafuncs/_script/data_interp_geo.py,sha256=ZRFb3fKRiYQViZNHd19eW20C9i38BsiIU8w0fG5mbqM,7789
15
+ oafuncs/_script/data_interp_geo.py,sha256=edddYkI2D0X8VIIrVUILz7cBXnosbmV8wZehp3w04Jw,6540
16
16
  oafuncs/_script/email.py,sha256=lL4HGKrr524-g0xLlgs-4u7x4-u7DtgNoD9AL8XJKj4,3058
17
17
  oafuncs/_script/netcdf_merge.py,sha256=tM9ePqLiEsE7eIsNM5XjEYeXwxjYOdNz5ejnEuI7xKw,6066
18
18
  oafuncs/_script/netcdf_modify.py,sha256=sGRUYNhfGgf9JV70rnBzw3bzuTRSXzBTL_RMDnDPeLQ,4552
19
19
  oafuncs/_script/netcdf_write.py,sha256=GvyUyUhzMonzSp3y4pT8ZAfbQrsh5J3dLnmINYJKhuE,21422
20
20
  oafuncs/_script/parallel.py,sha256=07-BJVHxXJNlrOrhrSGt7qCZiKWq6dBvNDBA1AANYnI,8861
21
21
  oafuncs/_script/parallel_test.py,sha256=0GBqZOX7IaCOKF2t1y8N8YYu53GJ33OkfsWgpvZNqM4,372
22
- oafuncs/_script/plot_dataset.py,sha256=zkSEnO_-biyagorwWXPoihts_cwuvripzEt-l9bHJ2E,13989
22
+ oafuncs/_script/plot_dataset.py,sha256=Hr4X0BHJ1qmf2YHT40Vu3nF8JS_4MlZ2MK6yeJCSHOg,15642
23
23
  oafuncs/_script/replace_file_content.py,sha256=eCFZjnZcwyRvy6b4mmIfBna-kylSZTyJRfgXd6DdCjk,5982
24
24
  oafuncs/oa_down/User_Agent-list.txt,sha256=pHaMlElMvZ8TG4vf4BqkZYKqe0JIGkr4kCN0lM1Y9FQ,514295
25
25
  oafuncs/oa_down/__init__.py,sha256=kRX5eTUCbAiz3zTaQM1501paOYS_3fizDN4Pa0mtNUA,585
@@ -39,8 +39,8 @@ oafuncs/oa_sign/__init__.py,sha256=QKqTFrJDFK40C5uvk48GlRRbGFzO40rgkYwu6dYxatM,5
39
39
  oafuncs/oa_sign/meteorological.py,sha256=8091SHo2L8kl4dCFmmSH5NGVHDku5i5lSiLEG5DLnOQ,6489
40
40
  oafuncs/oa_sign/ocean.py,sha256=xrW-rWD7xBWsB5PuCyEwQ1Q_RDKq2KCLz-LOONHgldU,5932
41
41
  oafuncs/oa_sign/scientific.py,sha256=a4JxOBgm9vzNZKpJ_GQIQf7cokkraV5nh23HGbmTYKw,5064
42
- oafuncs-0.0.98.22.dist-info/licenses/LICENSE.txt,sha256=rMtLpVg8sKiSlwClfR9w_Dd_5WubTQgoOzE2PDFxzs4,1074
43
- oafuncs-0.0.98.22.dist-info/METADATA,sha256=ctJ9aAoY3RztAP6gD2STCFB0ZZaCbXQV8SufCLMGkbM,4273
44
- oafuncs-0.0.98.22.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
45
- oafuncs-0.0.98.22.dist-info/top_level.txt,sha256=bgC35QkXbN4EmPHEveg_xGIZ5i9NNPYWqtJqaKqTPsQ,8
46
- oafuncs-0.0.98.22.dist-info/RECORD,,
42
+ oafuncs-0.0.98.24.dist-info/licenses/LICENSE.txt,sha256=rMtLpVg8sKiSlwClfR9w_Dd_5WubTQgoOzE2PDFxzs4,1074
43
+ oafuncs-0.0.98.24.dist-info/METADATA,sha256=ZeGzkxArxlWU9YOHhouWFTTNffZrZVY64OrqNxXKQTc,4273
44
+ oafuncs-0.0.98.24.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
45
+ oafuncs-0.0.98.24.dist-info/top_level.txt,sha256=bgC35QkXbN4EmPHEveg_xGIZ5i9NNPYWqtJqaKqTPsQ,8
46
+ oafuncs-0.0.98.24.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.3.1)
2
+ Generator: setuptools (80.7.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5