oafuncs-0.0.98.20.tar.gz → oafuncs-0.0.98.22.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. {oafuncs-0.0.98.20/oafuncs.egg-info → oafuncs-0.0.98.22}/PKG-INFO +1 -1
  2. oafuncs-0.0.98.22/oafuncs/_script/netcdf_merge.py +132 -0
  3. oafuncs-0.0.98.22/oafuncs/_script/netcdf_write.py +467 -0
  4. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/oa_cmap.py +18 -9
  5. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/oa_down/hycom_3hourly.py +4 -20
  6. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/oa_draw.py +2 -2
  7. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/oa_nc.py +73 -6
  8. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/oa_tool.py +1 -1
  9. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22/oafuncs.egg-info}/PKG-INFO +1 -1
  10. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/setup.py +1 -1
  11. oafuncs-0.0.98.20/oafuncs/_script/netcdf_merge.py +0 -103
  12. oafuncs-0.0.98.20/oafuncs/_script/netcdf_write.py +0 -253
  13. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/LICENSE.txt +0 -0
  14. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/MANIFEST.in +0 -0
  15. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/README.md +0 -0
  16. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/__init__.py +0 -0
  17. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/_data/hycom.png +0 -0
  18. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/_data/oafuncs.png +0 -0
  19. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/_script/cprogressbar.py +0 -0
  20. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/_script/data_interp.py +0 -0
  21. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/_script/data_interp_geo.py +0 -0
  22. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/_script/email.py +0 -0
  23. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/_script/netcdf_modify.py +0 -0
  24. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/_script/parallel.py +0 -0
  25. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/_script/parallel_test.py +0 -0
  26. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/_script/plot_dataset.py +0 -0
  27. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/_script/replace_file_content.py +0 -0
  28. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/oa_data.py +0 -0
  29. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/oa_date.py +0 -0
  30. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/oa_down/User_Agent-list.txt +0 -0
  31. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/oa_down/__init__.py +0 -0
  32. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/oa_down/hycom_3hourly_proxy.py +0 -0
  33. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/oa_down/idm.py +0 -0
  34. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/oa_down/literature.py +0 -0
  35. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/oa_down/read_proxy.py +0 -0
  36. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/oa_down/test_ua.py +0 -0
  37. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/oa_down/user_agent.py +0 -0
  38. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/oa_file.py +0 -0
  39. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/oa_help.py +0 -0
  40. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/oa_model/__init__.py +0 -0
  41. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/oa_model/roms/__init__.py +0 -0
  42. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/oa_model/roms/test.py +0 -0
  43. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/oa_model/wrf/__init__.py +0 -0
  44. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/oa_model/wrf/little_r.py +0 -0
  45. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/oa_python.py +0 -0
  46. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/oa_sign/__init__.py +0 -0
  47. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/oa_sign/meteorological.py +0 -0
  48. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/oa_sign/ocean.py +0 -0
  49. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs/oa_sign/scientific.py +0 -0
  50. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs.egg-info/SOURCES.txt +0 -0
  51. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs.egg-info/dependency_links.txt +0 -0
  52. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs.egg-info/requires.txt +0 -0
  53. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/oafuncs.egg-info/top_level.txt +0 -0
  54. {oafuncs-0.0.98.20 → oafuncs-0.0.98.22}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: oafuncs
- Version: 0.0.98.20
+ Version: 0.0.98.22
  Summary: Oceanic and Atmospheric Functions
  Home-page: https://github.com/Industry-Pays/OAFuncs
  Author: Kun Liu
@@ -0,0 +1,132 @@
+ import logging
+ import os
+ from typing import List, Optional, Union
+
+ import xarray as xr
+
+ from oafuncs import pbar
+
+
+ def merge_nc(file_list: Union[str, List[str]], var_name: Optional[Union[str, List[str]]] = None, dim_name: Optional[str] = None, target_filename: Optional[str] = None) -> None:
+     """
+     Description:
+         Merge variables from multiple NetCDF files along a specified dimension and write the result to a new file.
+         If var_name is a string, it is treated as a single variable; a one-element list is also a single variable;
+         a list with more than one element means multiple variables; if var_name is None, all variables are merged.
+
+     Parameters:
+         file_list: List of NetCDF file paths, or a single file path as a string
+         var_name: Name of the variable to extract, or a list of variable names; default None, meaning all variables are extracted
+         dim_name: Name of the dimension to merge along
+         target_filename: Name of the merged output file
+
+     Example:
+         merge_nc(file_list, var_name='u', dim_name='time', target_filename='merged.nc')
+         merge_nc(file_list, var_name=['u', 'v'], dim_name='time', target_filename='merged.nc')
+         merge_nc(file_list, var_name=None, dim_name='time', target_filename='merged.nc')
+     """
+
+     if target_filename is None:
+         target_filename = "merged.nc"
+
+     # Make sure the target directory exists
+     target_dir = os.path.dirname(target_filename)
+     if target_dir and not os.path.exists(target_dir):
+         os.makedirs(target_dir)
+
+     if isinstance(file_list, str):
+         file_list = [file_list]
+
+     # Initialize the list of variable names
+     if var_name is None:
+         with xr.open_dataset(file_list[0]) as ds:
+             var_names = list(ds.variables.keys())
+     elif isinstance(var_name, str):
+         var_names = [var_name]
+     elif isinstance(var_name, list):
+         var_names = var_name
+     else:
+         raise ValueError("var_name must be a string, a list of strings, or None")
+
+     # Initialize the merged-data dictionary
+     merged_data = {}
+
+     for i, file in pbar(enumerate(file_list), "Reading files", total=len(file_list)):
+         with xr.open_dataset(file) as ds:
+             for var in var_names:
+                 data_var = ds[var]
+                 if dim_name in data_var.dims:
+                     merged_data.setdefault(var, []).append(data_var)
+                 elif var not in merged_data:
+                     # Only merge here; NaN filling is left entirely to netcdf_write.py
+                     merged_data[var] = data_var
+
+     # Record each variable's fill-value and missing-value attributes so they are not lost
+     fill_values = {}
+     missing_values = {}
+     for var_name, var_data in merged_data.items():
+         if isinstance(var_data, list) and var_data:
+             # For variables to be concatenated, inspect the attributes of the first element
+             attrs = var_data[0].attrs
+             if "_FillValue" in attrs:
+                 fill_values[var_name] = attrs["_FillValue"]
+             if "missing_value" in attrs:
+                 missing_values[var_name] = attrs["missing_value"]
+         else:
+             # For a single variable, inspect its attributes directly
+             attrs = var_data.attrs if hasattr(var_data, "attrs") else {}
+             if "_FillValue" in attrs:
+                 fill_values[var_name] = attrs["_FillValue"]
+             if "missing_value" in attrs:
+                 missing_values[var_name] = attrs["missing_value"]
+
+     for var in pbar(merged_data, "Merging variables"):
+         if isinstance(merged_data[var], list):
+             # Use coords='minimal' instead of the default, and drop the potentially conflicting compat='override'
+             merged_data[var] = xr.concat(merged_data[var], dim=dim_name, coords="minimal")
+             # Restore the original fill-value and missing-value attributes
+             if var in fill_values:
+                 merged_data[var].attrs["_FillValue"] = fill_values[var]
+             if var in missing_values:
+                 merged_data[var].attrs["missing_value"] = missing_values[var]
+
+     # Build the Dataset after merging; merged_data now holds only data variables, no coordinate variables
+     merged_ds = xr.Dataset(merged_data)
+
+     # Automatically add coordinate variables (e.g. time, lat, lon), taking the first file as the reference
+     with xr.open_dataset(file_list[0]) as ds0:
+         for coord in ds0.coords:
+             # Make sure coordinate variables are not overwritten and keep their dtypes and attributes consistent
+             if coord not in merged_ds.coords:
+                 merged_ds = merged_ds.assign_coords({coord: ds0[coord]})
+
+     """ # Revised merge-dimension validation: check dimension compatibility across all files more sensibly
+     if dim_name in merged_ds.coords and len(file_list) > 1:
+         logging.info(f"Validating the merge dimension {dim_name} ...")
+
+         # Collect this dimension's values from every file
+         all_dim_values = []
+         for file in file_list:
+             with xr.open_dataset(file) as ds:
+                 if dim_name in ds.coords:
+                     all_dim_values.append(ds[dim_name].values)
+
+         # Warn only when two or more distinct value sets are present
+         unique_values_count = len({tuple(vals.tolist()) if hasattr(vals, "tolist") else tuple(vals) for vals in all_dim_values})
+         if unique_values_count > 1:
+             logging.warning(f"Detected {unique_values_count} distinct sets of {dim_name} coordinate values; merging may reorder the data")
+         else:
+             logging.info(f"The {dim_name} coordinate values are identical in every file; merging keeps the original order") """
+
+     if os.path.exists(target_filename):
+         logging.warning("The target file already exists. Removing it ...")
+         os.remove(target_filename)
+
+     merged_ds.to_netcdf(target_filename, mode="w")
+
+
+ # Example usage
+ if __name__ == "__main__":
+     files_to_merge = ["file1.nc", "file2.nc", "file3.nc"]
+     output_path = "merged_output.nc"
+     merge_nc(files_to_merge, var_name=None, dim_name="time", target_filename=output_path)
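
For orientation, a minimal sketch of how the new merge_nc is meant to be called. The file names are hypothetical, and the import path simply mirrors the file location shown in this diff:

    from oafuncs._script.netcdf_merge import merge_nc

    # Concatenate "u" from three hypothetical daily files along "time";
    # coordinates come from the first file and fill-value attributes are preserved.
    merge_nc(["day1.nc", "day2.nc", "day3.nc"], var_name="u", dim_name="time", target_filename="u_merged.nc")

    # var_name=None merges every variable found in the first file.
    merge_nc(["day1.nc", "day2.nc", "day3.nc"], var_name=None, dim_name="time", target_filename="all_merged.nc")
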
@@ -0,0 +1,467 @@
+ import os
+ import warnings
+
+ import netCDF4 as nc
+ import numpy as np
+ import xarray as xr
+
+ warnings.filterwarnings("ignore", category=RuntimeWarning)
+
+
+
+ def _nan_to_fillvalue(ncfile, set_fill_value):
+     """
+     Replace NaN and masked values in every variable of a NetCDF file with that variable's _FillValue (if absent, set_fill_value is added as _FillValue and used instead).
+     Invalid entries in masked arrays are handled as well.
+     Only numeric variables (floating-point and integer types) are affected.
+     """
+     with nc.Dataset(ncfile, "r+") as ds:
+         for var_name in ds.variables:
+             var = ds.variables[var_name]
+             # Only handle numeric variables (f: float, i: signed int, u: unsigned int)
+             if var.dtype.kind not in ["f", "i", "u"]:
+                 continue
+
+             # Read the data
+             arr = var[:]
+
+             # Determine the fill value
+             if "_FillValue" in var.ncattrs():
+                 fill_value = var.getncattr("_FillValue")
+             elif hasattr(var, "missing_value"):
+                 fill_value = var.getncattr("missing_value")
+             else:
+                 fill_value = set_fill_value
+                 try:
+                     var.setncattr("_FillValue", fill_value)
+                 except Exception:
+                     # Some variables may not allow adding _FillValue dynamically
+                     continue
+
+             # Handle masked arrays
+             if hasattr(arr, "mask"):
+                 # For masked arrays, set masked positions to fill_value
+                 if np.any(arr.mask):
+                     arr = np.where(arr.mask, fill_value, arr.data if hasattr(arr, "data") else arr)
+
+             # Handle remaining NaN and infinite values
+             if arr.dtype.kind in ["f", "i", "u"] and np.any(~np.isfinite(arr)):
+                 arr = np.nan_to_num(arr, nan=fill_value, posinf=fill_value, neginf=fill_value)
+
+             # Write back to the variable
+             var[:] = arr
+
+
+ def _numpy_to_nc_type(numpy_type):
+     """Map a NumPy dtype to the corresponding NetCDF type string."""
+     numpy_to_nc = {
+         "float32": "f4",
+         "float64": "f8",
+         "int8": "i1",
+         "int16": "i2",
+         "int32": "i4",
+         "int64": "i8",
+         "uint8": "u1",
+         "uint16": "u2",
+         "uint32": "u4",
+         "uint64": "u8",
+     }
+     numpy_type_str = str(numpy_type) if not isinstance(numpy_type, str) else numpy_type
+     return numpy_to_nc.get(numpy_type_str, "f4")
+
+
+ def _calculate_scale_and_offset(data, dtype="int32"):
+     """
+     Compute scale_factor and add_offset from valid data only (not NaN, not the fill value, not a custom missing value).
+     With the default int32 type, n = 32.
+     """
+     if not isinstance(data, np.ndarray):
+         raise ValueError("Input data must be a NumPy array.")
+
+     if dtype == "int32":
+         n = 32
+         fill_value = np.iinfo(np.int32).min  # -2147483648
+     elif dtype == "int16":
+         n = 16
+         fill_value = np.iinfo(np.int16).min  # -32768
+     else:
+         raise ValueError("Unsupported dtype. Supported types are 'int16' and 'int32'.")
+
+     # Valid mask: not NaN, not inf, not fill_value
+     valid_mask = np.isfinite(data) & (data != fill_value)
+     if hasattr(data, "mask") and np.ma.is_masked(data):
+         valid_mask &= ~data.mask
+
+     if np.any(valid_mask):
+         data_min = np.min(data[valid_mask]) - 1
+         data_max = np.max(data[valid_mask]) + 1
+     else:
+         data_min, data_max = 0, 1
+
+     # Prevent a zero scale and make sure scale/offset cannot collide with the fill value
+     if data_max == data_min:
+         scale_factor = 1.0
+         add_offset = data_min
+     else:
+         scale_factor = (data_max - data_min) / (2**n - 2)
+         add_offset = (data_max + data_min) / 2.0
+     return scale_factor, add_offset
+
+
+ def _data_to_scale_offset(data, scale, offset, dtype="int32"):
+     """
+     Scale valid data only; NaN/inf/fill values are assigned fill_value directly.
+     Values in masked regions are preserved and scaled, unless the mask itself marks them invalid.
+     Uses the int32 type by default.
+     """
+     if not isinstance(data, np.ndarray):
+         raise ValueError("Input data must be a NumPy array.")
+
+     if dtype == "int32":
+         # n = 32
+         np_dtype = np.int32
+         fill_value = np.iinfo(np.int32).min  # -2147483648
+         clip_min = np.iinfo(np.int32).min + 1  # -2147483647
+         clip_max = np.iinfo(np.int32).max  # 2147483647
+     elif dtype == "int16":
+         # n = 16
+         np_dtype = np.int16
+         fill_value = np.iinfo(np.int16).min  # -32768
+         clip_min = np.iinfo(np.int16).min + 1  # -32767
+         clip_max = np.iinfo(np.int16).max  # 32767
+     else:
+         raise ValueError("Unsupported dtype. Supported types are 'int16' and 'int32'.")
+
+     # Build the mask, excluding only NaN/inf and the explicit fill value
+     valid_mask = np.isfinite(data)
+     valid_mask &= data != fill_value
+
+     # If the data carries a mask, take it into account as well
+     if hasattr(data, "mask") and np.ma.is_masked(data):
+         # Only positions flagged by the mask are treated as invalid
+         valid_mask &= ~data.mask
+
+     result = data.copy()
+     if np.any(valid_mask):
+         # The inverse mapping recovers the original values
+         scaled = (data[valid_mask] - offset) / scale
+         scaled = np.round(scaled).astype(np_dtype)
+         # Clip to the integer range, keeping the largest usable span for the conversion
+         scaled = np.clip(scaled, clip_min, clip_max)  # The type minimum (e.g. -2147483648) is reserved as _FillValue
+         result[valid_mask] = scaled
+     return result
+
+
+ def save_to_nc(file, data, varname=None, coords=None, mode="w", convert_dtype="int32", scale_offset_switch=True, compile_switch=True, preserve_mask_values=True):
+     """
+     Save data to a NetCDF file. Supports xarray objects (DataArray or Dataset) and NumPy arrays.
+
+     Only numeric data variables are compressed (converted to int32 after a scale_factor/add_offset transform);
+     non-numeric data and all coordinate variables are written as-is with compression disabled.
+
+     Parameters:
+     - file: Path of the output file
+     - data: xarray.DataArray, xarray.Dataset, or a NumPy array
+     - varname: Variable name (only used when passing a NumPy array or DataArray)
+     - coords: Coordinate dictionary (used in the NumPy-array branch); coordinate variables are never compressed
+     - mode: "w" (overwrite) or "a" (append)
+     - convert_dtype: Target integer type ("int16" or "int32"), default "int32"
+     - scale_offset_switch: Whether to apply the scale/offset compression to numeric data variables
+     - compile_switch: Whether to enable NetCDF4 zlib compression (numeric data only)
+     - missing_value: Custom missing value replaced with fill_value (note: detected from the input, not an actual parameter)
+     - preserve_mask_values: Keep the original values in masked regions (True) or replace them with the fill value (False)
+     """
+     if convert_dtype not in ["int16", "int32"]:
+         convert_dtype = "int32"
+     nc_dtype = _numpy_to_nc_type(convert_dtype)
+     # fill_value = np.iinfo(np.convert_dtype).min  # -2147483648 or -32768
+     # fill_value = np.iinfo(eval('np.' + convert_dtype)).min  # -2147483648 or -32768
+     np_dtype = getattr(np, convert_dtype)  # A safer way to obtain the type
+     fill_value = np.iinfo(np_dtype).min
+     # ----------------------------------------------------------------------------
+     # Handle xarray objects (DataArray or Dataset)
+     if isinstance(data, (xr.DataArray, xr.Dataset)):
+         encoding = {}
+
+         if isinstance(data, xr.DataArray):
+             if data.name is None:
+                 data = data.rename("data")
+             varname = data.name if varname is None else varname
+             arr = np.array(data.values)
+             try:
+                 data_missing_val = data.attrs.get("missing_value")
+             except AttributeError:
+                 data_missing_val = data.attrs.get("_FillValue", None)
+             # Compute scale/offset from valid data only
+             valid_mask = np.ones(arr.shape, dtype=bool)  # All values are valid by default
+             if arr.dtype.kind in ["f", "i", "u"]:  # Apply isfinite to numeric data only
+                 valid_mask = np.isfinite(arr)
+                 if data_missing_val is not None:
+                     valid_mask &= arr != data_missing_val
+                 if hasattr(arr, "mask"):
+                     valid_mask &= ~getattr(arr, "mask", False)
+             if np.issubdtype(arr.dtype, np.number) and scale_offset_switch:
+                 arr_valid = arr[valid_mask]
+                 scale, offset = _calculate_scale_and_offset(arr_valid, convert_dtype)
+                 # Handle invalid values before writing (done only here!)
+                 arr_to_save = arr.copy()
+                 # Handle the custom missing value
+                 if data_missing_val is not None:
+                     arr_to_save[arr == data_missing_val] = fill_value
+                 # Handle NaN/inf
+                 arr_to_save[~np.isfinite(arr_to_save)] = fill_value
+                 new_values = _data_to_scale_offset(arr_to_save, scale, offset)
+                 new_da = data.copy(data=new_values)
+                 # Remove the _FillValue and missing_value attributes
+                 for k in ["_FillValue", "missing_value"]:
+                     if k in new_da.attrs:
+                         del new_da.attrs[k]
+                 new_da.attrs["scale_factor"] = float(scale)
+                 new_da.attrs["add_offset"] = float(offset)
+                 encoding[varname] = {
+                     "zlib": compile_switch,
+                     "complevel": 4,
+                     "dtype": nc_dtype,
+                     # "_FillValue": -2147483648,
+                 }
+                 new_da.to_dataset(name=varname).to_netcdf(file, mode=mode, encoding=encoding)
+             else:
+                 for k in ["_FillValue", "missing_value"]:
+                     if k in data.attrs:
+                         del data.attrs[k]
+                 data.to_dataset(name=varname).to_netcdf(file, mode=mode)
+             _nan_to_fillvalue(file, fill_value)
+             return
+
+         else:  # Dataset case
+             new_vars = {}
+             encoding = {}
+             for var in data.data_vars:
+                 da = data[var]
+                 arr = np.array(da.values)
+                 try:
+                     data_missing_val = da.attrs.get("missing_value")
+                 except AttributeError:
+                     data_missing_val = da.attrs.get("_FillValue", None)
+                 valid_mask = np.ones(arr.shape, dtype=bool)  # All values are valid by default
+                 if arr.dtype.kind in ["f", "i", "u"]:  # Apply isfinite to numeric data only
+                     valid_mask = np.isfinite(arr)
+                     if data_missing_val is not None:
+                         valid_mask &= arr != data_missing_val
+                     if hasattr(arr, "mask"):
+                         valid_mask &= ~getattr(arr, "mask", False)
+
+                 # Copy the attributes to avoid mutating the original dataset
+                 attrs = da.attrs.copy()
+                 for k in ["_FillValue", "missing_value"]:
+                     if k in attrs:
+                         del attrs[k]
+
+                 if np.issubdtype(arr.dtype, np.number) and scale_offset_switch:
+                     # Edge case: check whether any valid data exists
+                     if not np.any(valid_mask):
+                         # No valid data: keep a plain copy without any conversion
+                         new_vars[var] = xr.DataArray(arr, dims=da.dims, coords=da.coords, attrs=attrs)
+                         continue
+
+                     arr_valid = arr[valid_mask]
+                     scale, offset = _calculate_scale_and_offset(arr_valid, convert_dtype)
+                     arr_to_save = arr.copy()
+
+                     # Same logic as the DataArray branch: process the data with _data_to_scale_offset
+                     # Handle the custom missing value
+                     if data_missing_val is not None:
+                         arr_to_save[arr == data_missing_val] = fill_value
+                     # Handle NaN/inf
+                     arr_to_save[~np.isfinite(arr_to_save)] = fill_value
+                     new_values = _data_to_scale_offset(arr_to_save, scale, offset)
+                     new_da = xr.DataArray(new_values, dims=da.dims, coords=da.coords, attrs=attrs)
+                     new_da.attrs["scale_factor"] = float(scale)
+                     new_da.attrs["add_offset"] = float(offset)
+                     # Do not set the _FillValue attribute; use missing_value instead
+                     # new_da.attrs["missing_value"] = -2147483648
+                     new_vars[var] = new_da
+                     encoding[var] = {
+                         "zlib": compile_switch,
+                         "complevel": 4,
+                         "dtype": nc_dtype,
+                     }
+                 else:
+                     new_vars[var] = xr.DataArray(arr, dims=da.dims, coords=da.coords, attrs=attrs)
+
+             # Make sure coordinate variables are copied correctly
+             new_ds = xr.Dataset(new_vars, coords=data.coords.copy())
+             new_ds.to_netcdf(file, mode=mode, encoding=encoding if encoding else None)
+             _nan_to_fillvalue(file, fill_value)
+             return
+
+     # Handle the plain NumPy array case
+     if mode == "w" and os.path.exists(file):
+         os.remove(file)
+     elif mode == "a" and not os.path.exists(file):
+         mode = "w"
+     data = np.asarray(data)
+     is_numeric = np.issubdtype(data.dtype, np.number)
+
+     if hasattr(data, "mask") and np.ma.is_masked(data):
+         # For masked arrays, unwrap the data and pick up the missing value
+         data = data.data
+         missing_value = getattr(data, "missing_value", None)
+     else:
+         missing_value = None
+
+     try:
+         with nc.Dataset(file, mode, format="NETCDF4") as ncfile:
+             if coords is not None:
+                 for dim, values in coords.items():
+                     if dim not in ncfile.dimensions:
+                         ncfile.createDimension(dim, len(values))
+                         var_obj = ncfile.createVariable(dim, _numpy_to_nc_type(np.asarray(values).dtype), (dim,))
+                         var_obj[:] = values
+
+             dims = list(coords.keys()) if coords else []
+             if is_numeric and scale_offset_switch:
+                 arr = np.array(data)
+
+                 # Build the valid mask without excluding masked values (when preserve_mask_values is True)
+                 valid_mask = np.isfinite(arr)  # Exclude NaN and infinite values
+                 if missing_value is not None:
+                     valid_mask &= arr != missing_value  # Exclude explicit missing values
+
+                 # If masked values are not preserved, treat masked regions as invalid
+                 if not preserve_mask_values and hasattr(arr, "mask"):
+                     valid_mask &= ~arr.mask
+
+                 arr_to_save = arr.copy()
+
+                 # Make sure there is valid data
+                 if not np.any(valid_mask):
+                     # No valid data: skip compression and keep the original dtype
+                     dtype = _numpy_to_nc_type(data.dtype)
+                     var = ncfile.createVariable(varname, dtype, dims, zlib=False)
+                     # Make sure no NaN is written
+                     clean_data = np.nan_to_num(data, nan=missing_value if missing_value is not None else fill_value)
+                     var[:] = clean_data
+                     return
+
+                 # Compute scale and offset from the valid region only
+                 arr_valid = arr_to_save[valid_mask]
+                 scale, offset = _calculate_scale_and_offset(arr_valid, convert_dtype)
+
+                 # Apply the compression transform
+                 new_data = _data_to_scale_offset(arr_to_save, scale, offset)
+
+                 # Create the variable and set its attributes (_FillValue must be given at creation time)
+                 var = ncfile.createVariable(varname, nc_dtype, dims, zlib=compile_switch, fill_value=fill_value)
+                 var.scale_factor = scale
+                 var.add_offset = offset
+                 # The fill value was set via the fill_value keyword above
+                 var[:] = new_data
+             else:
+                 dtype = _numpy_to_nc_type(data.dtype)
+                 # Make sure no NaN is written (_FillValue must be given at creation time)
+                 if np.issubdtype(data.dtype, np.floating) and np.any(~np.isfinite(data)):
+                     fill_val = missing_value if missing_value is not None else fill_value
+                     var = ncfile.createVariable(varname, dtype, dims, zlib=False, fill_value=fill_val)
+                     clean_data = np.nan_to_num(data, nan=fill_val)
+                     var[:] = clean_data
+                 else:
+                     var = ncfile.createVariable(varname, dtype, dims, zlib=False)
+                     var[:] = data
+         # Finally make sure every NaN value has been handled
+         _nan_to_fillvalue(file, fill_value)
+     except Exception as e:
+         raise RuntimeError(f"netCDF4 save failed: {str(e)}") from e
+
+
+ def _compress_netcdf(src_path, dst_path=None, tolerance=1e-10, preserve_mask_values=True):
+     """
+     Compress a NetCDF file using scale_factor/add_offset packing.
+     If dst_path is omitted, a new name is generated automatically; after writing, the original file is deleted and the new file renamed back to the original name.
+     The data are validated for distortion after compression.
+
+     Parameters:
+     - src_path: Path of the original NetCDF file
+     - dst_path: Path of the compressed file (optional)
+     - tolerance: Allowed error when validating the data (default 1e-10)
+     - preserve_mask_values: Keep the original values in masked regions (True) or replace them with the fill value (False)
+     """
+     # Decide whether the original file should be replaced
+     delete_orig = dst_path is None
+     if delete_orig:
+         dst_path = src_path.replace(".nc", "_compress.nc")
+     # Open the original file and write the compressed copy
+     ds = xr.open_dataset(src_path)
+     save_to_nc(dst_path, ds, convert_dtype="int32", scale_offset_switch=True, compile_switch=True, preserve_mask_values=preserve_mask_values)
+     ds.close()
+
+     # Verify that the compressed data are not distorted
+     original_ds = xr.open_dataset(src_path)
+     compressed_ds = xr.open_dataset(dst_path)
+     # Validate the data in more detail
+     for var in original_ds.data_vars:
+         original_data = original_ds[var].values
+         compressed_data = compressed_ds[var].values
+         # Skip non-numeric variables
+         if not np.issubdtype(original_data.dtype, np.number):
+             continue
+         # Get the mask (if any)
+         original_mask = None
+         if hasattr(original_data, "mask") and np.ma.is_masked(original_data):  # Fix: make sure this really is a masked array
+             original_mask = original_data.mask.copy()
+         # Check that valid data stay within the allowed tolerance
+         valid_mask = np.isfinite(original_data)
+         if original_mask is not None:
+             valid_mask &= ~original_mask
+         if np.any(valid_mask):
+             if np.issubdtype(original_data.dtype, np.floating):
+                 diff = np.abs(original_data[valid_mask] - compressed_data[valid_mask])
+                 max_diff = np.max(diff)
+                 if max_diff > tolerance:
+                     print(f"Warning: compression error {max_diff} of variable {var} exceeds the tolerance {tolerance}")
+                     if max_diff > tolerance * 10:  # Raise on severe deviation
+                         raise ValueError(f"Data of variable {var} is severely distorted after compression (max_diff={max_diff})")
+             elif np.issubdtype(original_data.dtype, np.integer):
+                 # Integer data must match exactly
+                 if not np.array_equal(original_data[valid_mask], compressed_data[valid_mask]):
+                     raise ValueError(f"Integer data of variable {var} is inconsistent after compression")
+         # If masked values should be preserved, check the masked region as well
+         if preserve_mask_values and original_mask is not None and np.any(original_mask):
+             # Make sure the original values in the masked region were kept
+             # Fix: masked arrays can have dtype mismatches, so add a safety check
+             try:
+                 mask_diff = np.abs(original_data[original_mask] - compressed_data[original_mask])
+                 if np.any(mask_diff > tolerance):
+                     print(f"Warning: masked-region data of variable {var} changed after compression")
+             except Exception as e:
+                 print(f"Warning: masked-region comparison for variable {var} failed: {str(e)}")
+     original_ds.close()
+     compressed_ds.close()
+
+     # Replace the original file
+     if delete_orig:
+         os.remove(src_path)
+         os.rename(dst_path, src_path)
+
+
+ # Test cases
+ if __name__ == "__main__":
+     # Example file path; adjust to your environment
+     file = "dataset_test.nc"
+     ds = xr.open_dataset(file)
+     outfile = "dataset_test_compressed.nc"
+     save_to_nc(outfile, ds)
+     ds.close()
+
+     # dataarray
+     data = np.random.rand(4, 3, 2)
+     coords = {"x": np.arange(4), "y": np.arange(3), "z": np.arange(2)}
+     varname = "test_var"
+     data = xr.DataArray(data, dims=("x", "y", "z"), coords=coords, name=varname)
+     outfile = "test_dataarray.nc"
+     save_to_nc(outfile, data)
+
+     # numpy array with a custom missing value (-999 stands in for missing data)
+     coords = {"dim0": np.arange(5)}
+     data = np.array([1, 2, -999, 4, np.nan])
+     save_to_nc("test_numpy_missing.nc", data, varname="data", coords=coords)  # note: save_to_nc takes no missing_value argument
@@ -8,7 +8,9 @@ __all__ = ["show", "to_color", "create", "get"]
 
 
  # ** Visualize a cmap with a filled plot (adapted from the official matplotlib docs)
- def show(colormaps: Union[str, mpl.colors.Colormap, List[Union[str, mpl.colors.Colormap]]]) -> None:
+ def show(
+     colormaps: Union[str, mpl.colors.Colormap, List[Union[str, mpl.colors.Colormap]]],
+ ) -> None:
      """Helper function to plot data with associated colormap.
 
      This function creates a visualization of one or more colormaps by applying them
@@ -97,7 +99,14 @@ def to_color(colormap_name: str, num_colors: int = 256) -> List[tuple]:
 
 
  # ** Custom cmap: multiple colors, with optional position stops
- def create(color_list: Optional[List[Union[str, tuple]]] = None, rgb_file: Optional[str] = None, color_positions: Optional[List[float]] = None, below_range_color: Optional[Union[str, tuple]] = None, above_range_color: Optional[Union[str, tuple]] = None, value_delimiter: str = ",") -> mpl.colors.Colormap:
+ def create(
+     color_list: Optional[List[Union[str, tuple]]] = None,
+     rgb_file: Optional[str] = None,
+     color_positions: Optional[List[float]] = None,
+     below_range_color: Optional[Union[str, tuple]] = None,
+     above_range_color: Optional[Union[str, tuple]] = None,
+     value_delimiter: str = ",",
+ ) -> mpl.colors.Colormap:
      """Create a custom colormap from a list of colors or an RGB txt document.
 
      Args:
@@ -144,7 +153,7 @@ def create(color_list: Optional[List[Union[str, tuple]]] = None, rgb_file: Optio
 
      if rgb_file:
          try:
-             print(f"Reading RGB data from {rgb_file}...")
+             # print(f"Reading RGB data from {rgb_file}...")
 
              with open(rgb_file) as fid:
                  data = [line.strip() for line in fid if line.strip() and not line.strip().startswith("#")]
@@ -178,7 +187,7 @@ def create(color_list: Optional[List[Union[str, tuple]]] = None, rgb_file: Optio
              if max_rgb > 2:
                  rgb = rgb / 255.0
              cmap_color = mpl.colors.ListedColormap(rgb, name="my_color")
-             print(f"Successfully created colormap from {rgb_file}")
+             # print(f"Successfully created colormap from {rgb_file}")
          except FileNotFoundError:
              error_msg = f"RGB file not found: {rgb_file}"
              print(error_msg)
@@ -189,15 +198,15 @@ def create(color_list: Optional[List[Union[str, tuple]]] = None, rgb_file: Optio
              cmap_color = mpl.colors.LinearSegmentedColormap.from_list("mycmap", color_list)
          else:
              cmap_color = mpl.colors.LinearSegmentedColormap.from_list("mycmap", list(zip(color_positions, color_list)))
-         print(f"Successfully created colormap from {len(color_list)} colors")
+         # print(f"Successfully created colormap from {len(color_list)} colors")
 
      # Set below/above range colors if provided
      if below_range_color is not None:
          cmap_color.set_under(below_range_color)
-         print(f"Set below-range color to {below_range_color}")
+         # print(f"Set below-range color to {below_range_color}")
      if above_range_color is not None:
          cmap_color.set_over(above_range_color)
-         print(f"Set above-range color to {above_range_color}")
+         # print(f"Set above-range color to {above_range_color}")
 
      return cmap_color
 
@@ -246,12 +255,12 @@ def get(colormap_name: Optional[str] = None, show_available: bool = False) -> Op
          return None
 
      if colormap_name in my_cmap_dict:
-         print(f"Using custom colormap: {colormap_name}")
+         # print(f"Using custom colormap: {colormap_name}")
          return create(my_cmap_dict[colormap_name])
      else:
          try:
              cmap = mpl.colormaps.get_cmap(colormap_name)
-             print(f"Using matplotlib colormap: {colormap_name}")
+             # print(f"Using matplotlib colormap: {colormap_name}")
              return cmap
          except ValueError:
              print(f"Warning: Unknown cmap name: {colormap_name}")