oafuncs 0.0.98.19__py3-none-any.whl → 0.0.98.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,13 +1,57 @@
1
1
  import os
2
+ import warnings
2
3
 
3
4
  import netCDF4 as nc
4
5
  import numpy as np
5
6
  import xarray as xr
6
- import warnings
7
7
 
8
8
  warnings.filterwarnings("ignore", category=RuntimeWarning)
9
9
 
10
10
 
11
+
12
def _nan_to_fillvalue(ncfile, set_fill_value):
    """
    Replace NaN and +/-inf in the stored values of every numeric NetCDF variable.

    The file is opened in "r+" mode and modified in place. Only variables whose dtype
    kind is float ("f"), signed int ("i") or unsigned int ("u") are considered.

    - Replacement value: the variable's ``_FillValue`` attribute, else its
      ``missing_value`` attribute (first element if it is a vector). If the variable
      has neither, ``set_fill_value`` is added as ``_FillValue``. If adding that
      attribute fails (some variables do not allow it after creation), the variable
      is skipped (best effort).
    - NaN/inf are searched in the *raw stored* values. Automatic mask-and-scale is
      disabled for the variable, so packed variables (``scale_factor`` /
      ``add_offset``) are neither unpacked on read nor re-packed on write. The fill
      value, which is expressed in stored units, is therefore written verbatim.
    - A variable is rewritten only when at least one value was replaced.

    Integer storage cannot hold NaN/inf. Integer variables therefore only receive
    the ``_FillValue`` attribute (when missing); their data is left untouched.

    Parameters:
    - ncfile: path of the NetCDF file to modify in place
    - set_fill_value: fill value registered on variables that declare none
    """
    with nc.Dataset(ncfile, "r+") as ds:
        for var in ds.variables.values():
            # Variable-length string variables report the Python type ``str`` as
            # their dtype, which has no ``kind`` attribute.
            kind = getattr(var.dtype, "kind", None)
            if kind not in ("f", "i", "u"):
                continue

            attrs = var.ncattrs()
            if "_FillValue" in attrs:
                fill_value = var.getncattr("_FillValue")
            elif "missing_value" in attrs:
                fill_value = var.getncattr("missing_value")
            else:
                fill_value = set_fill_value
                try:
                    var.setncattr("_FillValue", fill_value)
                except Exception:
                    # Some variables do not allow adding _FillValue dynamically.
                    continue

            # NaN/inf can only be stored in floating-point variables.
            if kind != "f":
                continue

            # missing_value may legitimately be a vector; use its first entry.
            fill_value = np.ravel(fill_value)[0]

            # Work on raw stored values: no masking and no scale/offset unpacking,
            # so the fill value (in stored units) is not re-packed on write.
            var.set_auto_maskandscale(False)
            raw = np.array(var[:])
            invalid = ~np.isfinite(raw)
            if np.any(invalid):
                raw[invalid] = fill_value
                var[:] = raw
53
+
54
+
11
55
  def _numpy_to_nc_type(numpy_type):
12
56
  """将 NumPy 数据类型映射到 NetCDF 数据类型"""
13
57
  numpy_to_nc = {
@@ -26,55 +70,93 @@ def _numpy_to_nc_type(numpy_type):
26
70
  return numpy_to_nc.get(numpy_type_str, "f4")
27
71
 
28
72
 
29
def _calculate_scale_and_offset(data, dtype="int32"):
    """
    Compute CF ``scale_factor`` / ``add_offset`` for packing ``data`` into ``dtype``.

    Only valid samples take part. NaN, +/-inf, values equal to the packing fill value
    (the minimum of ``dtype``) and, for masked arrays, masked elements are ignored.
    The valid range [min, max] is mapped symmetrically onto
    [iinfo(dtype).min + 1, iinfo(dtype).max], which keeps the minimum of ``dtype``
    free for use as ``_FillValue``::

        scale_factor = (max - min) / (2**n - 2)
        add_offset   = (max + min) / 2

    No absolute padding is added to min/max. Padding by a fixed amount (e.g. +/-1)
    depends on the data's units and destroys the resolution of small-magnitude data.
    The symmetric mapping above already keeps the extremes inside the integer range.

    Parameters:
    - data: NumPy array (plain or masked) of numeric values
    - dtype: packed integer type, "int16" or "int32" (default)

    Returns:
    - (scale_factor, add_offset) as Python floats. Constant data gives
      (1.0, value). Data without any valid sample uses the range [0, 1].

    Raises:
    - ValueError: if ``data`` is not a NumPy array or ``dtype`` is unsupported.
    """
    if not isinstance(data, np.ndarray):
        raise ValueError("Input data must be a NumPy array.")

    if dtype == "int32":
        n = 32
        fill_value = np.iinfo(np.int32).min  # -2147483648
    elif dtype == "int16":
        n = 16
        fill_value = np.iinfo(np.int16).min  # -32768
    else:
        raise ValueError("Unsupported dtype. Supported types are 'int16' and 'int32'.")

    # Valid samples: finite, not already the fill value, not masked.
    raw = np.ma.getdata(data)
    valid_mask = np.isfinite(raw) & (raw != fill_value)
    if np.ma.is_masked(data):
        valid_mask &= ~np.ma.getmaskarray(data)

    if np.any(valid_mask):
        # float() avoids integer overflow in the range arithmetic below.
        data_min = float(np.min(raw[valid_mask]))
        data_max = float(np.max(raw[valid_mask]))
    else:
        data_min, data_max = 0.0, 1.0

    # A zero range would give scale_factor == 0; store the constant as the offset.
    if data_max == data_min:
        return 1.0, data_min

    scale_factor = (data_max - data_min) / (2**n - 2)
    add_offset = (data_max + data_min) / 2.0
    return scale_factor, add_offset
52
109
 
53
110
 
54
def _data_to_scale_offset(data, scale, offset, dtype="int32"):
    """
    Pack ``data`` into integers using ``packed = round((data - offset) / scale)``.

    Valid samples are packed and clipped to [iinfo(dtype).min + 1, iinfo(dtype).max].
    NaN, +/-inf, values already equal to the fill value (iinfo(dtype).min) and masked
    elements of a masked array are set to the fill value.

    The arithmetic is done in float64, so float32 input keeps its resolution when
    mapped onto the 32-bit integer range. Values are clipped *before* the integer
    cast, so out-of-range samples saturate instead of wrapping around.

    Parameters:
    - data: NumPy array (plain or masked) of numeric values
    - scale: scale_factor used for packing
    - offset: add_offset used for packing
    - dtype: target integer type, "int16" or "int32" (default)

    Returns:
    - NumPy array with the shape of ``data`` and integer dtype ``dtype``.

    Raises:
    - ValueError: if ``data`` is not a NumPy array or ``dtype`` is unsupported.
    """
    if not isinstance(data, np.ndarray):
        raise ValueError("Input data must be a NumPy array.")

    if dtype == "int32":
        np_dtype = np.int32
    elif dtype == "int16":
        np_dtype = np.int16
    else:
        raise ValueError("Unsupported dtype. Supported types are 'int16' and 'int32'.")

    info = np.iinfo(np_dtype)
    fill_value = info.min  # reserved for _FillValue
    clip_min = info.min + 1
    clip_max = info.max

    # Only finite, non-fill, unmasked samples are packed.
    raw = np.ma.getdata(data)
    valid_mask = np.isfinite(raw) & (raw != fill_value)
    if np.ma.is_masked(data):
        valid_mask &= ~np.ma.getmaskarray(data)

    result = np.full(raw.shape, fill_value, dtype=np_dtype)
    if np.any(valid_mask):
        scaled = (raw[valid_mask].astype(np.float64) - offset) / scale
        # Clip in float space first, then cast: casting out-of-range floats wraps.
        result[valid_mask] = np.clip(np.round(scaled), clip_min, clip_max).astype(np_dtype)
    return result
71
153
 
72
154
 
73
- def save_to_nc(file, data, varname=None, coords=None, mode="w", scale_offset_switch=True, compile_switch=True):
155
+ def save_to_nc(file, data, varname=None, coords=None, mode="w", convert_dtype='int32',scale_offset_switch=True, compile_switch=True, preserve_mask_values=True):
74
156
  """
75
157
  保存数据到 NetCDF 文件,支持 xarray 对象(DataArray 或 Dataset)和 numpy 数组。
76
158
 
77
- 仅对数据变量中数值型数据进行压缩转换(利用 scale_factor/add_offset 转换后转为 int16),
159
+ 仅对数据变量中数值型数据进行压缩转换(利用 scale_factor/add_offset 转换后转为 int32),
78
160
  非数值型数据以及所有坐标变量将禁用任何压缩,直接保存原始数据。
79
161
 
80
162
  参数:
@@ -83,62 +165,134 @@ def save_to_nc(file, data, varname=None, coords=None, mode="w", scale_offset_swi
83
165
  - varname: 变量名(仅适用于传入 numpy 数组或 DataArray 时)
84
166
  - coords: 坐标字典(numpy 数组分支时使用),所有坐标变量均不压缩
85
167
  - mode: "w"(覆盖)或 "a"(追加)
168
+ - convert_dtype: 转换为的数值类型("int16" 或 "int32"),默认为 "int32"
86
169
  - scale_offset_switch: 是否对数值型数据变量进行压缩转换
87
170
  - compile_switch: 是否启用 NetCDF4 的 zlib 压缩(仅针对数值型数据有效)
171
+ - missing_value: 自定义缺失值,将被替换为 fill_value
172
+ - preserve_mask_values: 是否保留掩码区域的原始值(True)或将其替换为缺省值(False)
88
173
  """
174
+ if convert_dtype not in ["int16", "int32"]:
175
+ convert_dtype = "int32"
176
+ nc_dtype = _numpy_to_nc_type(convert_dtype)
177
+ # fill_value = np.iinfo(np.convert_dtype).min # -2147483648 或 -32768
178
+ # fill_value = np.iinfo(eval('np.' + convert_dtype)).min # -2147483648 或 -32768
179
+ np_dtype = getattr(np, convert_dtype) # 更安全的类型获取方式
180
+ fill_value = np.iinfo(np_dtype).min
181
+ # ----------------------------------------------------------------------------
89
182
  # 处理 xarray 对象(DataArray 或 Dataset)的情况
90
183
  if isinstance(data, (xr.DataArray, xr.Dataset)):
91
- encoding = {} # 用于保存数据变量的编码信息
184
+ encoding = {}
92
185
 
93
186
  if isinstance(data, xr.DataArray):
94
187
  if data.name is None:
95
188
  data = data.rename("data")
96
189
  varname = data.name if varname is None else varname
97
- # 判断数据是否为数值型
98
- if np.issubdtype(data.values.dtype, np.number) and scale_offset_switch:
99
- scale, offset = _calculate_scale_and_offset(data.values)
100
- new_values = _data_to_scale_offset(data.values, scale, offset)
101
- # 生成新 DataArray,保留原坐标和属性,同时写入转换参数到属性中
190
+ arr = np.array(data.values)
191
+ try:
192
+ data_missing_val = data.attrs.get("missing_value")
193
+ except AttributeError:
194
+ data_missing_val = data.attrs.get("_FillValue", None)
195
+ # 只对有效数据计算scale/offset
196
+ valid_mask = np.ones(arr.shape, dtype=bool) # 默认所有值都有效
197
+ if arr.dtype.kind in ["f", "i", "u"]: # 仅对数值数据应用isfinite
198
+ valid_mask = np.isfinite(arr)
199
+ if data_missing_val is not None:
200
+ valid_mask &= arr != data_missing_val
201
+ if hasattr(arr, "mask"):
202
+ valid_mask &= ~getattr(arr, "mask", False)
203
+ if np.issubdtype(arr.dtype, np.number) and scale_offset_switch:
204
+ arr_valid = arr[valid_mask]
205
+ scale, offset = _calculate_scale_and_offset(arr_valid, convert_dtype)
206
+ # 写入前处理无效值(只在这里做!)
207
+ arr_to_save = arr.copy()
208
+ # 处理自定义缺失值
209
+ if data_missing_val is not None:
210
+ arr_to_save[arr == data_missing_val] = fill_value
211
+ # 处理 NaN/inf
212
+ arr_to_save[~np.isfinite(arr_to_save)] = fill_value
213
+ new_values = _data_to_scale_offset(arr_to_save, scale, offset)
102
214
  new_da = data.copy(data=new_values)
215
+ # 移除 _FillValue 和 missing_value 属性
216
+ for k in ["_FillValue", "missing_value"]:
217
+ if k in new_da.attrs:
218
+ del new_da.attrs[k]
103
219
  new_da.attrs["scale_factor"] = float(scale)
104
220
  new_da.attrs["add_offset"] = float(offset)
105
221
  encoding[varname] = {
106
222
  "zlib": compile_switch,
107
223
  "complevel": 4,
108
- "dtype": "int16",
109
- "_FillValue": -32767,
224
+ "dtype": nc_dtype,
225
+ # "_FillValue": -2147483648,
110
226
  }
111
227
  new_da.to_dataset(name=varname).to_netcdf(file, mode=mode, encoding=encoding)
112
228
  else:
229
+ for k in ["_FillValue", "missing_value"]:
230
+ if k in data.attrs:
231
+ del data.attrs[k]
113
232
  data.to_dataset(name=varname).to_netcdf(file, mode=mode)
233
+ _nan_to_fillvalue(file, fill_value)
114
234
  return
115
235
 
116
- else:
117
- # 处理 Dataset 的情况,仅处理 data_vars 数据变量,坐标变量保持原样
236
+ else: # Dataset 情况
118
237
  new_vars = {}
119
238
  encoding = {}
120
239
  for var in data.data_vars:
121
240
  da = data[var]
122
- if np.issubdtype(np.asarray(da.values).dtype, np.number) and scale_offset_switch:
123
- scale, offset = _calculate_scale_and_offset(da.values)
124
- new_values = _data_to_scale_offset(da.values, scale, offset)
125
- new_da = xr.DataArray(new_values, dims=da.dims, coords=da.coords, attrs=da.attrs)
241
+ arr = np.array(da.values)
242
+ try:
243
+ data_missing_val = da.attrs.get("missing_value")
244
+ except AttributeError:
245
+ data_missing_val = da.attrs.get("_FillValue", None)
246
+ valid_mask = np.ones(arr.shape, dtype=bool) # 默认所有值都有效
247
+ if arr.dtype.kind in ["f", "i", "u"]: # 仅对数值数据应用isfinite
248
+ valid_mask = np.isfinite(arr)
249
+ if data_missing_val is not None:
250
+ valid_mask &= arr != data_missing_val
251
+ if hasattr(arr, "mask"):
252
+ valid_mask &= ~getattr(arr, "mask", False)
253
+
254
+ # 创建属性的副本以避免修改原始数据集
255
+ attrs = da.attrs.copy()
256
+ for k in ["_FillValue", "missing_value"]:
257
+ if k in attrs:
258
+ del attrs[k]
259
+
260
+ if np.issubdtype(arr.dtype, np.number) and scale_offset_switch:
261
+ # 处理边缘情况:检查是否有有效数据
262
+ if not np.any(valid_mask):
263
+ # 如果没有有效数据,创建一个简单的拷贝,不做转换
264
+ new_vars[var] = xr.DataArray(arr, dims=da.dims, coords=da.coords, attrs=attrs)
265
+ continue
266
+
267
+ arr_valid = arr[valid_mask]
268
+ scale, offset = _calculate_scale_and_offset(arr_valid, convert_dtype)
269
+ arr_to_save = arr.copy()
270
+
271
+ # 使用与DataArray相同的逻辑,使用_data_to_scale_offset处理数据
272
+ # 处理自定义缺失值
273
+ if data_missing_val is not None:
274
+ arr_to_save[arr == data_missing_val] = fill_value
275
+ # 处理 NaN/inf
276
+ arr_to_save[~np.isfinite(arr_to_save)] = fill_value
277
+ new_values = _data_to_scale_offset(arr_to_save, scale, offset)
278
+ new_da = xr.DataArray(new_values, dims=da.dims, coords=da.coords, attrs=attrs)
126
279
  new_da.attrs["scale_factor"] = float(scale)
127
280
  new_da.attrs["add_offset"] = float(offset)
281
+ # 不设置_FillValue属性,改为使用missing_value
282
+ # new_da.attrs["missing_value"] = -2147483648
128
283
  new_vars[var] = new_da
129
284
  encoding[var] = {
130
285
  "zlib": compile_switch,
131
286
  "complevel": 4,
132
- "dtype": "int16",
133
- "_FillValue": -32767,
287
+ "dtype": nc_dtype,
134
288
  }
135
289
  else:
136
- new_vars[var] = da
137
- new_ds = xr.Dataset(new_vars, coords=data.coords)
138
- if encoding:
139
- new_ds.to_netcdf(file, mode=mode, encoding=encoding)
140
- else:
141
- new_ds.to_netcdf(file, mode=mode)
290
+ new_vars[var] = xr.DataArray(arr, dims=da.dims, coords=da.coords, attrs=attrs)
291
+
292
+ # 确保坐标变量被正确复制
293
+ new_ds = xr.Dataset(new_vars, coords=data.coords.copy())
294
+ new_ds.to_netcdf(file, mode=mode, encoding=encoding if encoding else None)
295
+ _nan_to_fillvalue(file, fill_value)
142
296
  return
143
297
 
144
298
  # 处理纯 numpy 数组情况
@@ -148,9 +302,16 @@ def save_to_nc(file, data, varname=None, coords=None, mode="w", scale_offset_swi
148
302
  mode = "w"
149
303
  data = np.asarray(data)
150
304
  is_numeric = np.issubdtype(data.dtype, np.number)
305
+
306
+ if hasattr(data, "mask") and np.ma.is_masked(data):
307
+ # 处理掩码数组,获取缺失值
308
+ data = data.data
309
+ missing_value = getattr(data, "missing_value", None)
310
+ else:
311
+ missing_value = None
312
+
151
313
  try:
152
314
  with nc.Dataset(file, mode, format="NETCDF4") as ncfile:
153
- # 坐标变量直接写入,不做压缩
154
315
  if coords is not None:
155
316
  for dim, values in coords.items():
156
317
  if dim not in ncfile.dimensions:
@@ -160,44 +321,147 @@ def save_to_nc(file, data, varname=None, coords=None, mode="w", scale_offset_swi
160
321
 
161
322
  dims = list(coords.keys()) if coords else []
162
323
  if is_numeric and scale_offset_switch:
163
- scale, offset = _calculate_scale_and_offset(data)
164
- new_data = _data_to_scale_offset(data, scale, offset)
165
- var = ncfile.createVariable(varname, "i2", dims, fill_value=-32767, zlib=compile_switch)
324
+ arr = np.array(data)
325
+
326
+ # 构建有效掩码,但不排除掩码区域的数值(如果 preserve_mask_values True)
327
+ valid_mask = np.isfinite(arr) # 排除 NaN 和无限值
328
+ if missing_value is not None:
329
+ valid_mask &= arr != missing_value # 排除明确的缺失值
330
+
331
+ # 如果不保留掩码区域的值,则将掩码区域视为无效
332
+ if not preserve_mask_values and hasattr(arr, "mask"):
333
+ valid_mask &= ~arr.mask
334
+
335
+ arr_to_save = arr.copy()
336
+
337
+ # 确保有有效数据
338
+ if not np.any(valid_mask):
339
+ # 如果没有有效数据,不进行压缩,直接保存原始数据类型
340
+ dtype = _numpy_to_nc_type(data.dtype)
341
+ var = ncfile.createVariable(varname, dtype, dims, zlib=False)
342
+ # 确保没有 NaN
343
+ clean_data = np.nan_to_num(data, nan=missing_value if missing_value is not None else fill_value)
344
+ var[:] = clean_data
345
+ return
346
+
347
+ # 计算 scale 和 offset 仅使用有效区域数据
348
+ arr_valid = arr_to_save[valid_mask]
349
+ scale, offset = _calculate_scale_and_offset(arr_valid, convert_dtype)
350
+
351
+ # 执行压缩转换
352
+ new_data = _data_to_scale_offset(arr_to_save, scale, offset)
353
+
354
+ # 创建变量并设置属性
355
+ var = ncfile.createVariable(varname, nc_dtype, dims, zlib=compile_switch)
166
356
  var.scale_factor = scale
167
357
  var.add_offset = offset
168
- # Ensure no invalid values in new_data before assignment
358
+ var._FillValue = fill_value # 明确设置填充值
169
359
  var[:] = new_data
170
360
  else:
171
- # 非数值型数据,禁止压缩
172
361
  dtype = _numpy_to_nc_type(data.dtype)
173
362
  var = ncfile.createVariable(varname, dtype, dims, zlib=False)
174
- var[:] = data
363
+ # 确保不写入 NaN
364
+ if np.issubdtype(data.dtype, np.floating) and np.any(~np.isfinite(data)):
365
+ fill_val = missing_value if missing_value is not None else fill_value
366
+ var._FillValue = fill_val
367
+ clean_data = np.nan_to_num(data, nan=fill_val)
368
+ var[:] = clean_data
369
+ else:
370
+ var[:] = data
371
+ # 最后确保所有 NaN 值被处理
372
+ _nan_to_fillvalue(file, fill_value)
175
373
  except Exception as e:
176
374
  raise RuntimeError(f"netCDF4 保存失败: {str(e)}") from e
177
375
 
178
376
 
377
def _verify_compressed_variable(name, original_data, compressed_da, tolerance, preserve_mask_values):
    """
    Compare one data variable of a compressed file with its original values.

    Parameters:
    - name: variable name (used in messages)
    - original_data: values read from the source file
    - compressed_da: the same variable read (and decoded) from the compressed file;
      only its ``values`` and ``encoding`` attributes are used
    - tolerance: absolute error that is always accepted
    - preserve_mask_values: whether masked samples are also compared (warning only)

    Raises:
    - ValueError: if floating data is severely distorted, or if integer data is not
      recovered exactly after rounding.
    """
    # Non-numeric variables are not packed, so there is nothing to verify.
    if not np.issubdtype(original_data.dtype, np.number):
        return
    compressed_data = np.asarray(compressed_da.values)

    original_mask = None
    if np.ma.is_masked(original_data):
        original_mask = np.ma.getmaskarray(original_data).copy()

    original_raw = np.ma.getdata(original_data)
    valid_mask = np.isfinite(original_raw)
    if original_mask is not None:
        valid_mask &= ~original_mask

    if np.any(valid_mask):
        original_valid = original_raw[valid_mask]
        compressed_valid = compressed_data[valid_mask]
        if np.issubdtype(original_data.dtype, np.floating):
            diff = np.abs(original_valid.astype(np.float64) - compressed_valid)
            # A valid sample that decodes to NaN/inf has been lost entirely.
            diff = np.where(np.isfinite(diff), diff, np.inf)
            max_diff = np.max(diff)
            # Packing cannot be more accurate than one quantization step
            # (scale_factor) or the original dtype's own resolution.
            allowed = tolerance
            scale = compressed_da.encoding.get("scale_factor")
            if scale is not None:
                allowed = max(allowed, abs(float(scale)))
            allowed = max(allowed, float(np.finfo(original_data.dtype).eps * np.max(np.abs(original_valid))))
            if max_diff > allowed:
                print(f"警告: 变量 {name} 的压缩误差 {max_diff} 超出容许范围 {allowed}")
                if max_diff > allowed * 10:  # raise on severe deviation
                    raise ValueError(f"变量 {name} 的数据在压缩后严重失真 (max_diff={max_diff})")
        elif np.issubdtype(original_data.dtype, np.integer):
            # Packed integers are decoded as floats; they must round back exactly.
            if not np.array_equal(original_valid, np.rint(compressed_valid)):
                raise ValueError(f"变量 {name} 的整数数据在压缩后不一致")

    # Optionally check that masked samples kept their original values.
    if preserve_mask_values and original_mask is not None and np.any(original_mask):
        try:
            mask_diff = np.abs(original_raw[original_mask] - compressed_data[original_mask])
            if np.any(mask_diff > tolerance):
                print(f"警告: 变量 {name} 的掩码区域数据在压缩后发生变化")
        except Exception as e:
            print(f"警告: 变量 {name} 的掩码区域数据比较失败: {str(e)}")


def _compress_netcdf(src_path, dst_path=None, tolerance=1e-10, preserve_mask_values=True):
    """
    Compress a NetCDF file (int32 scale_factor/add_offset packing + zlib) and verify it.

    If ``dst_path`` is omitted, the output is written next to the source as
    ``<name>_compress<ext>``. Once verification passes, that file is atomically moved
    over the source. If writing or verification fails, the auto-generated file is
    removed, the source is left untouched and the error is re-raised.

    Verification covers numeric data variables only, on finite/unmasked samples:
    - floating variables: the allowed error is the largest of ``tolerance``, one
      packing step (the variable's ``scale_factor``) and the original dtype's
      resolution at the data's magnitude. Exceeding it prints a warning; exceeding
      ten times that bound raises ``ValueError``.
    - integer variables: decoded values rounded to the nearest integer must equal
      the originals, otherwise ``ValueError`` is raised.

    Parameters:
    - src_path: source NetCDF file path
    - dst_path: output file path (optional, see above)
    - tolerance: absolute error that is always accepted (default 1e-10)
    - preserve_mask_values: keep original values in masked regions (True) or replace
      them with the fill value (False); forwarded to ``save_to_nc``
    """
    delete_orig = dst_path is None
    if delete_orig:
        # splitext touches only the real extension. str.replace would rewrite every
        # ".nc" in the path and return src_path itself when there is no extension.
        root, ext = os.path.splitext(src_path)
        dst_path = f"{root}_compress{ext or '.nc'}"

    try:
        with xr.open_dataset(src_path) as ds:
            save_to_nc(dst_path, ds, convert_dtype="int32", scale_offset_switch=True, compile_switch=True, preserve_mask_values=preserve_mask_values)

        # Context managers release both files even when verification raises.
        with xr.open_dataset(src_path) as original_ds, xr.open_dataset(dst_path) as compressed_ds:
            for var in original_ds.data_vars:
                _verify_compressed_variable(var, original_ds[var].values, compressed_ds[var], tolerance, preserve_mask_values)
    except Exception:
        if delete_orig and os.path.exists(dst_path):
            os.remove(dst_path)
        raise

    if delete_orig:
        # Atomic on POSIX and works on Windows even though the target exists.
        os.replace(dst_path, src_path)
445
+
446
+
179
447
# Test cases / usage examples
if __name__ == "__main__":
    # Dataset example: adjust the path to an existing file before running.
    file = "dataset_test.nc"
    if os.path.exists(file):
        with xr.open_dataset(file) as ds:
            outfile = "dataset_test_compressed.nc"
            save_to_nc(outfile, ds)
    else:
        print(f"Skipping Dataset example: {file} not found")

    # DataArray example
    data = np.random.rand(4, 3, 2)
    coords = {"x": np.arange(4), "y": np.arange(3), "z": np.arange(2)}
    varname = "test_var"
    data = xr.DataArray(data, dims=("x", "y", "z"), coords=coords, name=varname)
    outfile = "test_dataarray.nc"
    save_to_nc(outfile, data)

    # Custom missing value example. save_to_nc has no ``missing_value`` keyword, so
    # passing one raises TypeError. It reads the value from the DataArray's
    # "missing_value" attribute instead.
    coords = {"dim0": np.arange(5)}
    data = xr.DataArray(
        np.array([1, 2, -999, 4, np.nan]),
        dims=("dim0",),
        coords=coords,
        name="data",
        attrs={"missing_value": -999},
    )
    save_to_nc("test_numpy_missing.nc", data)
@@ -6,6 +6,7 @@ import time
6
6
  from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
7
7
  from typing import Any, Callable, Dict, List, Optional, Tuple
8
8
 
9
+
9
10
  import psutil
10
11
 
11
12
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
@@ -14,6 +15,7 @@ __all__ = ["ParallelExecutor"]
14
15
 
15
16
 
16
17
  class ParallelExecutor:
18
+
17
19
  def __init__(
18
20
  self,
19
21
  max_workers: Optional[int] = None,
@@ -120,6 +122,7 @@ class ParallelExecutor:
120
122
  self._executor = None
121
123
 
122
124
  def _execute_batch(self, func: Callable, params: List[Tuple], chunk_size: int) -> List[Any]:
125
+ from oafuncs.oa_tool import pbar
123
126
  if not params:
124
127
  return []
125
128
 
@@ -129,7 +132,7 @@ class ParallelExecutor:
129
132
  results = [None] * len(params)
130
133
  with self._get_executor() as executor:
131
134
  futures = {executor.submit(func, *args): idx for idx, args in enumerate(params)}
132
- for future in as_completed(futures):
135
+ for future in pbar(as_completed(futures), "Parallel Tasks", total=len(futures)):
133
136
  idx = futures[future]
134
137
  try:
135
138
  results[idx] = future.result(timeout=self.timeout_per_task)
oafuncs/oa_cmap.py CHANGED
@@ -246,12 +246,12 @@ def get(colormap_name: Optional[str] = None, show_available: bool = False) -> Op
246
246
  return None
247
247
 
248
248
  if colormap_name in my_cmap_dict:
249
- print(f"Using custom colormap: {colormap_name}")
249
+ # print(f"Using custom colormap: {colormap_name}")
250
250
  return create(my_cmap_dict[colormap_name])
251
251
  else:
252
252
  try:
253
253
  cmap = mpl.colormaps.get_cmap(colormap_name)
254
- print(f"Using matplotlib colormap: {colormap_name}")
254
+ # print(f"Using matplotlib colormap: {colormap_name}")
255
255
  return cmap
256
256
  except ValueError:
257
257
  print(f"Warning: Unknown cmap name: {colormap_name}")
oafuncs/oa_data.py CHANGED
@@ -23,7 +23,7 @@ from rich import print
23
23
  from scipy.interpolate import interp1d
24
24
 
25
25
 
26
- __all__ = ["interp_along_dim", "interp_2d", "ensure_list", "mask_shapefile"]
26
+ __all__ = ["interp_along_dim", "interp_2d", "interp_2d_geo", "ensure_list", "mask_shapefile"]
27
27
 
28
28
 
29
29
  def ensure_list(input_value: Any) -> List[str]:
@@ -120,7 +120,7 @@ def interp_2d(
120
120
  source_x_coordinates: Union[np.ndarray, List[float]],
121
121
  source_y_coordinates: Union[np.ndarray, List[float]],
122
122
  source_data: np.ndarray,
123
- interpolation_method: str = "linear",
123
+ interpolation_method: str = "cubic",
124
124
  ) -> np.ndarray:
125
125
  """
126
126
  Perform 2D interpolation on the last two dimensions of a multi-dimensional array.
@@ -132,7 +132,7 @@ def interp_2d(
132
132
  source_y_coordinates (Union[np.ndarray, List[float]]): Original grid's y-coordinates.
133
133
  source_data (np.ndarray): Multi-dimensional array with the last two dimensions as spatial.
134
134
  >>> must be [y, x] or [*, y, x] or [*, *, y, x]
135
- interpolation_method (str, optional): Interpolation method. Defaults to "linear".
135
+ interpolation_method (str, optional): Interpolation method. Defaults to "cubic".
136
136
  >>> optional: 'linear', 'nearest', 'cubic', 'quintic', etc.
137
137
  use_parallel (bool, optional): Enable parallel processing. Defaults to True.
138
138
 
@@ -162,7 +162,47 @@ def interp_2d(
162
162
  interpolation_method=interpolation_method,
163
163
  )
164
164
 
165
def interp_2d_geo(target_x_coordinates: Union[np.ndarray, List[float]], target_y_coordinates: Union[np.ndarray, List[float]], source_x_coordinates: Union[np.ndarray, List[float]], source_y_coordinates: Union[np.ndarray, List[float]], source_data: np.ndarray, interpolation_method: str = "cubic") -> np.ndarray:
    """
    Perform 2D interpolation on the last two dimensions of a multi-dimensional array (spherical coordinates).

    Delegates to ``_script.data_interp_geo.interp_2d_func_geo``. According to the
    original notes, that helper interpolates in a spherical coordinate system and is
    meant for global geographic data, including longitudes that cross the dateline
    (behaviour implemented in the helper, not in this wrapper).

    Args:
        target_x_coordinates (Union[np.ndarray, List[float]]): Target grid's longitude (-180 to 180 or 0 to 360).
        target_y_coordinates (Union[np.ndarray, List[float]]): Target grid's latitude (-90 to 90).
        source_x_coordinates (Union[np.ndarray, List[float]]): Original grid's longitude (-180 to 180 or 0 to 360).
        source_y_coordinates (Union[np.ndarray, List[float]]): Original grid's latitude (-90 to 90).
        source_data (np.ndarray): Multi-dimensional array with the last two dimensions as spatial.
        interpolation_method (str, optional): Interpolation method. Defaults to "cubic".
            >>> optional: 'linear', 'nearest', 'cubic', 'quintic', etc.

    Returns:
        np.ndarray: Interpolated data array (whatever ``interp_2d_func_geo`` returns).

    Raises:
        ValueError: If input shapes are invalid (presumably raised by the helper; not checked here).

    Examples:
        >>> # Build a global-grid example
        >>> target_lon = np.arange(-180, 181, 1)  # 1-degree target grid
        >>> target_lat = np.arange(-90, 91, 1)
        >>> source_lon = np.arange(-180, 181, 5)  # 5-degree source grid
        >>> source_lat = np.arange(-90, 91, 5)
        >>> # A simple data field (e.g. a temperature-like pattern)
        >>> source_data = np.cos(np.deg2rad(source_lat.reshape(-1, 1))) * np.cos(np.deg2rad(source_lon))
        >>> # Interpolate to the higher-resolution grid
        >>> result = interp_2d_geo(target_lon, target_lat, source_lon, source_lat, source_data)
        >>> print(result.shape)  # Expected output: (181, 361)
    """
    from ._script.data_interp_geo import interp_2d_func_geo

    # Return the helper's result; previously it was discarded and None was returned.
    return interp_2d_func_geo(
        target_x_coordinates=target_x_coordinates,
        target_y_coordinates=target_y_coordinates,
        source_x_coordinates=source_x_coordinates,
        source_y_coordinates=source_y_coordinates,
        source_data=source_data,
        interpolation_method=interpolation_method,
    )
166
206
 
167
207
  def mask_shapefile(
168
208
  data_array: np.ndarray,