oafuncs 0.0.97.13__py3-none-any.whl → 0.0.97.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,354 +1,107 @@
1
- #!/usr/bin/env python
2
- # coding=utf-8
3
- """
4
- Author: Liu Kun && 16031215@qq.com
5
- Date: 2025-03-30 11:16:29
6
- LastEditors: Liu Kun && 16031215@qq.com
7
- LastEditTime: 2025-03-30 11:16:31
8
- FilePath: \\Python\\My_Funcs\\OAFuncs\\oafuncs\\_script\\netcdf_merge.py
9
- Description:
10
- EditPlatform: vscode
11
- ComputerInfo: XPS 15 9510
12
- SystemInfo: Windows 11
13
- Python Version: 3.12
14
- """
15
-
16
- import logging
17
1
  import os
18
- from typing import Dict, List, Union
19
-
20
- import numpy as np
21
- import xarray as xr
2
+ from typing import List, Optional, Union
22
3
  from dask.diagnostics import ProgressBar
4
+ import xarray as xr
5
+ from oafuncs import pbar
23
6
 
24
- # Configure logging
25
- logging.basicConfig(level=logging.INFO)
26
- logger = logging.getLogger(__name__)
27
-
28
-
29
- def merge(file_list: Union[str, List[str]], var_name: Union[str, List[str], None] = None, dim_name: str = "time", target_filename: str = "merged.nc", chunk_config: Dict = {"time": 1000}, compression: Union[bool, Dict] = True, sanity_check: bool = True, overwrite: bool = True, parallel: bool = True) -> None:
7
+ def merge_nc(file_list: Union[str, List[str]], var_name: Optional[Union[str, List[str]]] = None, dim_name: Optional[str] = None, target_filename: Optional[str] = None) -> None:
30
8
  """
31
- Ultimate NetCDF merge function
9
+ Description:
10
+ Merge variables from multiple NetCDF files along a specified dimension and write to a new file.
11
+ If var_name is a string, it is treated as a single variable; a list with only one element is likewise treated as a single variable;
12
+ a list with more than one element selects multiple variables; if var_name is None, all variables are merged.
32
13
 
33
14
  Parameters:
34
- file_list: List of file paths or single file path
35
- var_name: Variables to merge (single variable name/list of variables/None means all)
36
- dim_name: Dimension to merge along, default is 'time'
37
- target_filename: Output file path
38
- chunk_config: Dask chunking configuration, e.g. {"time": 1000}
39
- compression: Compression configuration (True enables default compression, or custom encoding dictionary)
40
- sanity_check: Whether to perform data integrity validation
41
- overwrite: Whether to overwrite existing files
42
- parallel: Whether to enable parallel processing
15
+ file_list: List of NetCDF file paths or a single file path as a string
16
+ var_name: Variable name or list of variable names to extract; default is None, meaning all variables are extracted
17
+ dim_name: Dimension name used for merging
18
+ target_filename: Target file name after merging
43
19
 
44
20
  Example:
45
- merge(["data1.nc", "data2.nc"],
46
- var_name=["temp", "salt"],
47
- target_filename="result.nc",
48
- chunk_config={"time": 500})
21
+ merge_nc(file_list, var_name='u', dim_name='time', target_filename='merged.nc')
22
+ merge_nc(file_list, var_name=['u', 'v'], dim_name='time', target_filename='merged.nc')
23
+ merge_nc(file_list, var_name=None, dim_name='time', target_filename='merged.nc')
49
24
  """
50
- # ------------------------ Parameter preprocessing ------------------------#
51
- file_list = _validate_and_preprocess_inputs(file_list, target_filename, overwrite)
52
- all_vars, var_names = _determine_variables(file_list, var_name)
53
- static_vars = _identify_static_vars(file_list[0], var_names, dim_name)
54
-
55
- # Estimate required memory for processing
56
- _estimate_memory_usage(file_list, var_names, chunk_config)
57
-
58
- # ------------------------ Data validation phase ------------------------#
59
- if sanity_check:
60
- _perform_sanity_checks(file_list, var_names, dim_name, static_vars)
61
-
62
- # ------------------------ Core merging logic ------------------------#
63
- with xr.set_options(keep_attrs=True): # Preserve metadata attributes
64
- # Merge dynamic variables
65
- merged_ds = xr.open_mfdataset(
66
- file_list,
67
- combine="nested",
68
- concat_dim=dim_name,
69
- data_vars=[var for var in var_names if var not in static_vars],
70
- chunks=chunk_config,
71
- parallel=parallel,
72
- preprocess=lambda ds: ds[var_names], # Only load target variables
73
- )
74
-
75
- # Process static variables
76
- if static_vars:
77
- with xr.open_dataset(file_list[0], chunks=chunk_config) as ref_ds:
78
- merged_ds = merged_ds.assign({var: ref_ds[var] for var in static_vars})
79
-
80
- # ------------------------ Time dimension processing ------------------------#
81
- if dim_name == "time":
82
- merged_ds = _process_time_dimension(merged_ds)
83
-
84
- # ------------------------ File output ------------------------#
85
- encoding = _generate_encoding_config(merged_ds, compression)
86
- _write_to_netcdf(merged_ds, target_filename, encoding)
87
-
88
-
89
- # ------------------------ Helper functions ------------------------#
90
- def _validate_and_preprocess_inputs(file_list: Union[str, List[str]], target_filename: str, overwrite: bool) -> List[str]:
91
- """Input parameter validation and preprocessing"""
92
- if not file_list:
93
- raise ValueError("File list cannot be empty")
25
+
26
+ if target_filename is None:
27
+ target_filename = "merged.nc"
94
28
 
95
- file_list = [file_list] if isinstance(file_list, str) else file_list
96
- for f in file_list:
97
- if not os.path.exists(f):
98
- raise FileNotFoundError(f"Input file does not exist: {f}")
29
+ # Ensure the target directory exists
30
+ target_dir = os.path.dirname(target_filename)
31
+ if target_dir and not os.path.exists(target_dir):
32
+ os.makedirs(target_dir)
99
33
 
100
- target_dir = os.path.dirname(os.path.abspath(target_filename))
101
- os.makedirs(target_dir, exist_ok=True)
102
-
103
- if os.path.exists(target_filename):
104
- if overwrite:
105
- logger.warning(f"Overwriting existing file: {target_filename}")
106
- os.remove(target_filename)
107
- else:
108
- raise FileExistsError(f"Target file already exists: {target_filename}")
109
-
110
- return file_list
111
-
112
-
113
- def _determine_variables(file_list: List[str], var_name: Union[str, List[str], None]) -> tuple:
114
- """Determine the list of variables to process"""
115
- with xr.open_dataset(file_list[0]) as ds:
116
- all_vars = list(ds.data_vars.keys())
34
+ if isinstance(file_list, str):
35
+ file_list = [file_list]
117
36
 
37
+ # Initialize the list of variable names
118
38
  if var_name is None:
119
- return all_vars, all_vars
39
+ with xr.open_dataset(file_list[0]) as ds:
40
+ var_names = list(ds.variables.keys())
120
41
  elif isinstance(var_name, str):
121
- if var_name not in all_vars:
122
- raise ValueError(f"Invalid variable name: {var_name}")
123
- return all_vars, [var_name]
42
+ var_names = [var_name]
124
43
  elif isinstance(var_name, list):
125
- if not var_name: # Handle empty list case
126
- logger.warning("Empty variable list provided, will use all variables")
127
- return all_vars, all_vars
128
- invalid_vars = set(var_name) - set(all_vars)
129
- if invalid_vars:
130
- raise ValueError(f"Invalid variable names: {invalid_vars}")
131
- return all_vars, var_name
44
+ var_names = var_name
132
45
  else:
133
- raise TypeError("var_name parameter must be of type str/list/None")
134
-
135
-
136
- def _identify_static_vars(sample_file: str, var_names: List[str], dim_name: str) -> List[str]:
137
- """Identify static variables"""
138
- with xr.open_dataset(sample_file) as ds:
139
- return [var for var in var_names if dim_name not in ds[var].dims]
46
+ raise ValueError("var_name must be a string, a list of strings, or None")
140
47
 
48
+ # Initialize the dictionary of merged data
49
+ merged_data = {}
141
50
 
142
- def _perform_sanity_checks(file_list: List[str], var_names: List[str], dim_name: str, static_vars: List[str]) -> None:
143
- """Perform data integrity validation"""
144
- logger.info("Performing data integrity validation...")
145
-
146
- # Check consistency of static variables
147
- with xr.open_dataset(file_list[0]) as ref_ds:
148
- for var in static_vars:
149
- ref = ref_ds[var]
150
- for f in file_list[1:]:
151
- with xr.open_dataset(f) as ds:
152
- if not ref.equals(ds[var]):
153
- raise ValueError(f"Static variable {var} inconsistent\nReference file: {file_list[0]}\nProblem file: {f}")
154
-
155
- # Check dimensions of dynamic variables
156
- dim_sizes = {}
157
- for f in file_list:
158
- with xr.open_dataset(f) as ds:
51
+ for i, file in pbar(enumerate(file_list), description="Reading files", color="green", total=len(file_list)):
52
+ with xr.open_dataset(file) as ds:
159
53
  for var in var_names:
160
- if var not in static_vars:
161
- dims = ds[var].dims
162
- if dim_name not in dims:
163
- raise ValueError(f"Variable {var} in file {f} missing merge dimension {dim_name}")
164
- dim_sizes.setdefault(var, []).append(ds[var].sizes[dim_name])
165
-
166
- # Check dimension continuity
167
- for var, sizes in dim_sizes.items():
168
- if len(set(sizes[1:])) > 1:
169
- raise ValueError(f"Variable {var} has inconsistent {dim_name} dimension lengths: {sizes}")
170
-
171
-
172
- def _process_time_dimension(ds: xr.Dataset) -> xr.Dataset:
173
- """Special processing for time dimension"""
174
- if "time" not in ds.dims:
175
- return ds
176
-
177
- # Sort and deduplicate
178
- ds = ds.sortby("time")
179
- # Find indices of unique timestamps
180
- _, index = np.unique(ds["time"], return_index=True)
181
- # No need to sort indices again as we want to keep original time order
182
- return ds.isel(time=index)
183
-
184
-
185
- def _generate_encoding_config(ds: xr.Dataset, compression: Union[bool, Dict]) -> Dict:
186
- """Generate compression encoding configuration"""
187
- if not compression:
188
- return {}
189
-
190
- # Default compression settings base
191
- def _get_default_encoding(var):
192
- return {"zlib": True, "complevel": 3, "dtype": "float32" if ds[var].dtype == "float64" else ds[var].dtype}
193
-
194
- # Handle custom compression configuration
195
- encoding = {}
196
- if isinstance(compression, dict):
197
- for var in ds.data_vars:
198
- encoding[var] = _get_default_encoding(var)
199
- encoding[var].update(compression.get(var, {})) # Use dict.update() to merge dictionaries
200
- else:
201
- for var in ds.data_vars:
202
- encoding[var] = _get_default_encoding(var)
203
-
204
- return encoding
205
-
206
- def _calculate_file_size(filepath: str) -> str:
207
- """Calculate file size with adaptive unit conversion"""
208
- if os.path.exists(filepath):
209
- size_in_bytes = os.path.getsize(filepath)
210
- if size_in_bytes < 1e3:
211
- return f"{size_in_bytes:.2f} B"
212
- elif size_in_bytes < 1e6:
213
- return f"{size_in_bytes / 1e3:.2f} KB"
214
- elif size_in_bytes < 1e9:
215
- return f"{size_in_bytes / 1e6:.2f} MB"
216
- else:
217
- return f"{size_in_bytes / 1e9:.2f} GB"
218
- else:
219
- raise FileNotFoundError(f"File not found: {filepath}")
220
-
221
- def _write_to_netcdf(ds: xr.Dataset, filename: str, encoding: Dict) -> None:
222
- """Improved safe writing to NetCDF file"""
223
- logger.info("Starting file write...")
224
- unlimited_dims = [dim for dim in ds.dims if ds[dim].encoding.get("unlimited", False)]
225
-
226
- delayed = ds.to_netcdf(filename, encoding=encoding, compute=False, unlimited_dims=unlimited_dims)
227
-
228
- try:
54
+ data_var = ds[var]
55
+ if dim_name in data_var.dims:
56
+ merged_data.setdefault(var, []).append(data_var)
57
+ elif var not in merged_data:
58
+ merged_data[var] = data_var.fillna(0) # Fill NaN values with 0
59
+
60
+ for var in pbar(merged_data, description="Merging variables", color="#9b45d1"):
61
+ if isinstance(merged_data[var], list):
62
+ merged_data[var] = xr.concat(merged_data[var], dim=dim_name).fillna(0)
63
+ # print(f"Variable '{var}' merged: min={merged_data[var].min().values:.3f}, max={merged_data[var].max().values:.3f}, mean={merged_data[var].mean().values:.3f}")
64
+
65
+ # Write the data with compression, applying an offset (add_offset) and scale factor
66
+ # print("\nWriting data to file ...")
67
+ if os.path.exists(target_filename):
68
+ print("Warning: The target file already exists. Removing it ...")
69
+ os.remove(target_filename)
70
+
71
+ with xr.Dataset(merged_data) as merged_dataset:
72
+ encoding = {}
73
+ for var in merged_dataset.data_vars:
74
+ data = merged_dataset[var].values
75
+ # print(f"Variable '{var}' ready for writing: min={data.min():.3f}, max={data.max():.3f}, mean={data.mean():.3f}")
76
+ if data.dtype.kind in {"i", "u", "f"}: # Only compress numeric data
77
+ data_range = data.max() - data.min()
78
+ if data_range > 0: # Avoid precision problems when the range is too small
79
+ scale_factor = data_range / (2**16 - 1)
80
+ add_offset = data.min()
81
+ encoding[var] = {
82
+ "zlib": True,
83
+ "complevel": 4,
84
+ "dtype": "int16",
85
+ "scale_factor": scale_factor,
86
+ "add_offset": add_offset,
87
+ "_FillValue": -32767,
88
+ }
89
+ else:
90
+ encoding[var] = {"zlib": True, "complevel": 4} # 范围过小时禁用缩放
91
+ else:
92
+ encoding[var] = {"zlib": True, "complevel": 4} # 非数值型数据不使用缩放
93
+
94
+ # Make sure the encoding does not cause data loss during writing
95
+ # merged_dataset.to_netcdf(target_filename, encoding=encoding)
96
+ delayed_write = merged_dataset.to_netcdf(target_filename, encoding=encoding, compute=False)
229
97
  with ProgressBar():
230
- delayed.compute()
231
-
232
- logger.info(f"Merge completed → {filename}")
233
- # logger.info(f"File size: {os.path.getsize(filename) / 1e9:.2f}GB")
234
- logger.info(f"File size: {_calculate_file_size(filename)}")
235
- except MemoryError as e:
236
- _handle_write_error(filename, "Insufficient memory to complete file write. Try adjusting chunk_config parameter to reduce memory usage", e)
237
- except Exception as e:
238
- _handle_write_error(filename, f"Failed to write file: {str(e)}", e)
239
-
240
-
241
- def _handle_write_error(filename: str, message: str, exception: Exception) -> None:
242
- """Unified handling of file write exceptions"""
243
- logger.error(message)
244
- if os.path.exists(filename):
245
- os.remove(filename)
246
- raise exception
98
+ delayed_write.compute()
247
99
 
248
-
249
- def _estimate_memory_usage(file_list: List[str], var_names: List[str], chunk_config: Dict) -> None:
250
- """Improved memory usage estimation"""
251
- try:
252
- total_size = 0
253
- sample_file = file_list[0]
254
- with xr.open_dataset(sample_file) as ds:
255
- for var in var_names:
256
- if var in ds:
257
- # Consider variable dimension sizes
258
- var_size = np.prod([ds[var].sizes[dim] for dim in ds[var].dims]) * ds[var].dtype.itemsize
259
- total_size += var_size * len(file_list)
260
-
261
- # Estimate memory usage during Dask processing (typically 2-3x original data)
262
- estimated_memory = total_size * 3
263
-
264
- if estimated_memory > 8e9:
265
- logger.warning(f"Estimated memory usage may be high (approx. {estimated_memory / 1e9:.1f}GB). If memory issues occur, adjust chunk_config parameter: {chunk_config}")
266
- except Exception as e:
267
- logger.debug(f"Memory estimation failed: {str(e)}")
100
+ print(f'\nFile "{target_filename}" has been successfully created.')
268
101
 
269
102
 
103
+ # Example usage
270
104
  if __name__ == "__main__":
271
- # Example file list (replace with actual file paths)
272
- sample_files = ["data/file1.nc", "data/file2.nc", "data/file3.nc"]
273
-
274
- # Example 1: Basic usage - merge all variables
275
- print("\n" + "=" * 40)
276
- print("示例1: 合并所有变量(默认配置)")
277
- merge(file_list=sample_files, target_filename="merged_all_vars.nc")
278
-
279
- # Example 2: Merge selected variables
280
- print("\n" + "=" * 40)
281
- print("示例2: 合并指定变量(温度、盐度)")
282
- merge(
283
- file_list=sample_files,
284
- var_name=["temperature", "salinity"],
285
- target_filename="merged_selected_vars.nc",
286
- chunk_config={"time": 500}, # 更保守的内存分配
287
- )
288
-
289
- # Example 3: Custom compression configuration
290
- print("\n" + "=" * 40)
291
- print("示例3: 自定义压缩参数")
292
- merge(file_list=sample_files, var_name="chlorophyll", compression={"chlorophyll": {"zlib": True, "complevel": 5, "dtype": "float32"}}, target_filename="merged_compressed.nc")
293
-
294
- # Example 4: Handling large datasets
295
- print("\n" + "=" * 40)
296
- print("示例4: 大文件分块策略")
297
- merge(file_list=sample_files, chunk_config={"time": 2000, "lat": 100, "lon": 100}, target_filename="merged_large_dataset.nc", parallel=True)
298
-
299
- # Example 5: Special handling of the time dimension
300
- print("\n" + "=" * 40)
301
- print("示例5: 时间维度排序去重")
302
- merge(
303
- file_list=sample_files,
304
- dim_name="time",
305
- target_filename="merged_time_processed.nc",
306
- sanity_check=True, # Force data validation
307
- )
308
-
309
- # Example 6: Overwrite an existing file
310
- print("\n" + "=" * 40)
311
- print("示例6: 强制覆盖现有文件")
312
- try:
313
- merge(
314
- file_list=sample_files,
315
- target_filename="merged_all_vars.nc", # 与示例1相同文件名
316
- overwrite=True, # Explicitly enable overwriting
317
- )
318
- except FileExistsError as e:
319
- print(f"捕获预期外异常: {str(e)}")
320
-
321
- # Example 7: Disable parallel processing
322
- print("\n" + "=" * 40)
323
- print("示例7: 单线程模式运行")
324
- merge(file_list=sample_files, target_filename="merged_single_thread.nc", parallel=False)
325
-
326
- # Example 8: Merging along a different dimension
327
- print("\n" + "=" * 40)
328
- print("示例8: 按深度维度合并")
329
- merge(file_list=sample_files, dim_name="depth", var_name=["density", "oxygen"], target_filename="merged_by_depth.nc")
330
-
331
- # Example 9: Mixed variable types
332
- print("\n" + "=" * 40)
333
- print("示例9: 混合静态/动态变量")
334
- merge(
335
- file_list=sample_files,
336
- var_name=["bathymetry", "temperature"], # bathymetry为静态变量
337
- target_filename="merged_mixed_vars.nc",
338
- sanity_check=True, # Verify consistency of static variables
339
- )
340
-
341
- # Example 10: Full configuration demo
342
- print("\n" + "=" * 40)
343
- print("示例10: 全参数配置演示")
344
- merge(
345
- file_list=sample_files,
346
- var_name=None, # All variables
347
- dim_name="time",
348
- target_filename="merged_full_config.nc",
349
- chunk_config={"time": 1000, "lat": 500, "lon": 500},
350
- compression={"temperature": {"complevel": 4}, "salinity": {"zlib": False}},
351
- sanity_check=True,
352
- overwrite=True,
353
- parallel=True,
354
- )
105
+ files_to_merge = ["file1.nc", "file2.nc", "file3.nc"]
106
+ output_path = "merged_output.nc"
107
+ merge_nc(files_to_merge, var_name=None, dim_name="time", target_filename=output_path)
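
For reference, the int16 packing produced by the new encoding block follows the usual netCDF scale/offset convention, so stored integers relate to physical values via value = packed * scale_factor + add_offset. The read-back sketch below illustrates this; the file name "merged.nc" and variable name "u" are taken from the docstring example above and are purely illustrative.

    import xarray as xr

    # Illustrative names only: "merged.nc" / "u" mirror the docstring example above.
    raw = xr.open_dataset("merged.nc", mask_and_scale=False)   # int16 values exactly as stored
    auto = xr.open_dataset("merged.nc")                        # xarray applies scale_factor/add_offset on read

    stored = raw["u"]
    physical = stored * stored.attrs["scale_factor"] + stored.attrs["add_offset"]
    # Away from _FillValue cells, the difference should be at most about one scale_factor step.
    print(float(abs(physical - auto["u"]).max()))
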
@@ -2,10 +2,10 @@
2
2
  # coding=utf-8
3
3
  """
4
4
  Author: Liu Kun && 16031215@qq.com
5
- Date: 2025-01-11 19:47:08
5
+ Date: 2025-04-04 20:19:23
6
6
  LastEditors: Liu Kun && 16031215@qq.com
7
- LastEditTime: 2025-03-18 19:21:36
8
- FilePath: \\Python\\My_Funcs\\OAFuncs\\oafuncs\\oa_tool\\parallel.py
7
+ LastEditTime: 2025-04-04 20:19:23
8
+ FilePath: \\Python\\My_Funcs\\OAFuncs\\oafuncs\\_script\\parallel.py
9
9
  Description:
10
10
  EditPlatform: vscode
11
11
  ComputerInfo: XPS 15 9510
@@ -13,6 +13,8 @@ SystemInfo: Windows 11
13
13
  Python Version: 3.12
14
14
  """
15
15
 
16
+
17
+
16
18
  import contextlib
17
19
  import logging
18
20
  import multiprocessing as mp
oafuncs/oa_data.py CHANGED
@@ -17,14 +17,13 @@ import itertools
17
17
  import multiprocessing as mp
18
18
  from concurrent.futures import ThreadPoolExecutor
19
19
 
20
+
20
21
  import numpy as np
21
22
  import salem
22
23
  import xarray as xr
23
- from scipy.interpolate import griddata
24
- from scipy.interpolate import interp1d
25
- from typing import Iterable
24
+ from scipy.interpolate import griddata, interp1d
26
25
 
27
- __all__ = ["interp_along_dim", "interp_2d", "ensure_list", "mask_shapefile", "pbar"]
26
+ __all__ = ["interp_along_dim", "interp_2d", "ensure_list", "mask_shapefile"]
28
27
 
29
28
 
30
29
  def ensure_list(input_data):
@@ -255,26 +254,6 @@ def mask_shapefile(data: np.ndarray, lons: np.ndarray, lats: np.ndarray, shapefi
255
254
  return None
256
255
 
257
256
 
258
- def pbar(iterable: Iterable, prefix: str = "", color: str = "cyan", cmap: str = None, **kwargs) -> Iterable:
259
- """
260
- Wrapper function for quickly creating a progress bar
261
- :param iterable: An iterable object
262
- :param prefix: Prefix for the progress bar
263
- :param color: Base color
264
- :param cmap: Name of the color gradient
265
- :param kwargs: Other parameters supported by ColorProgressBar
266
-
267
- example:
268
- from oafuncs.oa_data import pbar
269
- from time import sleep
270
- for i in pbar(range(100), prefix="Processing", color="green", cmap="viridis"):
271
- sleep(0.1)
272
- """
273
- from ._script.cprogressbar import ColorProgressBar # Import the ColorProgressBar class
274
-
275
- return ColorProgressBar(iterable=iterable, prefix=prefix, color=color, cmap=cmap, **kwargs)
276
-
277
-
278
257
  if __name__ == "__main__":
279
258
  pass
280
259
  """ import time
oafuncs/oa_date.py CHANGED
@@ -4,7 +4,7 @@
4
4
  Author: Liu Kun && 16031215@qq.com
5
5
  Date: 2025-03-27 16:56:57
6
6
  LastEditors: Liu Kun && 16031215@qq.com
7
- LastEditTime: 2025-03-27 16:56:57
7
+ LastEditTime: 2025-04-04 12:58:15
8
8
  FilePath: \\Python\\My_Funcs\\OAFuncs\\oafuncs\\oa_date.py
9
9
  Description:
10
10
  EditPlatform: vscode
@@ -13,10 +13,12 @@ SystemInfo: Windows 11
13
13
  Python Version: 3.12
14
14
  """
15
15
 
16
+
17
+
16
18
  import calendar
17
19
  import datetime
18
20
 
19
- __all__ = ["get_days_in_month", "generate_hour_list", "adjust_time"]
21
+ __all__ = ["get_days_in_month", "generate_hour_list", "adjust_time", "timeit"]
20
22
 
21
23
 
22
24
  def get_days_in_month(year, month):
@@ -88,3 +90,23 @@ def adjust_time(initial_time, amount, time_unit="hours", output_format=None):
88
90
  elif time_unit == "days":
89
91
  default_format = "%Y%m%d"
90
92
  return time_obj.strftime(default_format)
93
+
94
+ class timeit:
95
+ """
96
+ A decorator to measure the execution time of a function.
97
+
98
+ Usage:
99
+ @timeit
100
+ def my_function():
101
+ # Function code here
102
+ """
103
+ def __init__(self, func):
104
+ self.func = func
105
+
106
+ def __call__(self, *args, **kwargs):
107
+ start_time = datetime.datetime.now()
108
+ result = self.func(*args, **kwargs)
109
+ end_time = datetime.datetime.now()
110
+ elapsed_time = (end_time - start_time).total_seconds()
111
+ print(f"Function '{self.func.__name__}' executed in {elapsed_time:.2f} seconds.")
112
+ return result
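
The timeit class added here is an ordinary decorator; a small usage sketch (the decorated function and its sleep are illustrative):

    import time

    from oafuncs.oa_date import timeit

    @timeit
    def slow_square(x):
        time.sleep(0.3)  # stand-in for real work
        return x * x

    slow_square(4)  # prints something like: Function 'slow_square' executed in 0.30 seconds.
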
oafuncs/oa_nc.py CHANGED
@@ -14,12 +14,12 @@ Python Version: 3.11
14
14
  """
15
15
 
16
16
  import os
17
+ from typing import List, Optional, Union
17
18
 
18
19
  import netCDF4 as nc
19
20
  import numpy as np
20
21
  import xarray as xr
21
22
  from rich import print
22
- from typing import Dict, List, Union
23
23
 
24
24
  __all__ = ["get_var", "extract", "save", "merge", "modify", "rename", "check", "convert_longitude", "isel", "draw"]
25
25
 
@@ -136,7 +136,18 @@ def save(file, data, varname=None, coords=None, mode="w", scale_offset_switch=Tr
136
136
  with nc.Dataset(file, mode, format="NETCDF4") as ncfile:
137
137
  # If data is a DataArray and neither varname nor coords is provided
138
138
  if varname is None and coords is None and isinstance(data, xr.DataArray):
139
- data.to_netcdf(file, mode=mode)
139
+ encoding = {}
140
+ for var in data.data_vars:
141
+ scale_factor, add_offset = _calculate_scale_and_offset(data[var].values)
142
+ encoding[var] = {
143
+ "zlib": True,
144
+ "complevel": 4,
145
+ "dtype": "int16",
146
+ "scale_factor": scale_factor,
147
+ "add_offset": add_offset,
148
+ "_FillValue": -32767,
149
+ }
150
+ data.to_netcdf(file, mode=mode, encoding=encoding)
140
151
  return
141
152
 
142
153
  # Add coordinates
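
The helper _calculate_scale_and_offset called above is not shown in this diff. Assuming it follows the same arithmetic as the encoding block in the new netcdf_merge.py (scale over the full 16-bit range, offset at the minimum), an illustrative stand-in would be the sketch below; this is an assumption, not the package's actual implementation.

    import numpy as np

    def calculate_scale_and_offset_sketch(values: np.ndarray):
        """Hypothetical stand-in mirroring the arithmetic used in netcdf_merge.py."""
        vmin = float(np.nanmin(values))
        vmax = float(np.nanmax(values))
        scale_factor = (vmax - vmin) / (2**16 - 1)  # map the data range onto 65535 steps
        add_offset = vmin
        return scale_factor, add_offset
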
@@ -182,29 +193,10 @@ def save(file, data, varname=None, coords=None, mode="w", scale_offset_switch=Tr
182
193
  var.setncattr(key, value)
183
194
 
184
195
 
185
- def merge(file_list: Union[str, List[str]], var_name: Union[str, List[str], None] = None, dim_name: str = "time", target_filename: str = "merged.nc", chunk_config: Dict = {"time": 1000}, compression: Union[bool, Dict] = True, sanity_check: bool = True, overwrite: bool = True, parallel: bool = True) -> None:
186
- """
187
- NetCDF merge function
188
-
189
- Parameters:
190
- file_list: List of file paths or a single file path
191
- var_name: Variables to merge (single variable name / list of variables / None for all)
192
- dim_name: Dimension to merge along, default is 'time'
193
- target_filename: Output file path
194
- chunk_config: Dask chunking configuration, e.g. {"time": 1000}
195
- compression: Compression configuration (True enables default compression, or a custom encoding dict)
196
- sanity_check: Whether to perform data integrity validation
197
- overwrite: Whether to overwrite existing files
198
- parallel: Whether to enable parallel processing
196
+ def merge(file_list: Union[str, List[str]], var_name: Optional[Union[str, List[str]]] = None, dim_name: Optional[str] = None, target_filename: Optional[str] = None) -> None:
197
+ from ._script.netcdf_merge import merge_nc
199
198
 
200
- Example:
201
- merge(["data1.nc", "data2.nc"],
202
- var_name=["temp", "salt"],
203
- target_filename="result.nc",
204
- chunk_config={"time": 500})
205
- """
206
- from ._script.netcdf_merge import merge as nc_merge
207
- nc_merge(file_list, var_name, dim_name, target_filename, chunk_config, compression, sanity_check, overwrite, parallel)
199
+ merge_nc(file_list, var_name, dim_name, target_filename)
208
200
 
209
201
 
210
202
  def _modify_var(nc_file_path, variable_name, new_value):
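
After this change the public oa_nc.merge is a thin wrapper that forwards its four arguments to _script.netcdf_merge.merge_nc. A minimal call (file names are illustrative):

    from oafuncs.oa_nc import merge

    # File names are illustrative; var_name=None merges every variable.
    merge(
        file_list=["file1.nc", "file2.nc", "file3.nc"],
        var_name=None,
        dim_name="time",
        target_filename="merged_output.nc",
    )
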
@@ -230,8 +222,7 @@ def _modify_var(nc_file_path, variable_name, new_value):
230
222
  variable = dataset.variables[variable_name]
231
223
  # Check if the shape of the new value matches the variable's shape
232
224
  if variable.shape != new_value.shape:
233
- raise ValueError(f"Shape mismatch: Variable '{variable_name}' has shape {variable.shape}, "
234
- f"but new value has shape {new_value.shape}.")
225
+ raise ValueError(f"Shape mismatch: Variable '{variable_name}' has shape {variable.shape}, but new value has shape {new_value.shape}.")
235
226
  # Modify the value of the variable
236
227
  variable[:] = new_value
237
228
  print(f"Successfully modified variable '{variable_name}' in '{nc_file_path}'.")
@@ -264,8 +255,7 @@ def _modify_attr(nc_file_path, variable_name, attribute_name, attribute_value):
264
255
  variable.setncattr(attribute_name, attribute_value)
265
256
  print(f"Successfully modified attribute '{attribute_name}' of variable '{variable_name}' in '{nc_file_path}'.")
266
257
  except Exception as e:
267
- print(f"[red]Error:[/red] Failed to modify attribute '{attribute_name}' of variable '{variable_name}' "
268
- f"in file '{nc_file_path}'. [bold]Details:[/bold] {e}")
258
+ print(f"[red]Error:[/red] Failed to modify attribute '{attribute_name}' of variable '{variable_name}' in file '{nc_file_path}'. [bold]Details:[/bold] {e}")
269
259
 
270
260
 
271
261
  def modify(nc_file, var_name, attr_name=None, new_value=None):
@@ -435,7 +425,7 @@ def isel(ncfile, dim_name, slice_list):
435
425
  return ds_new
436
426
 
437
427
 
438
- def draw(output_dir=None, dataset=None, ncfile=None, xyzt_dims=("longitude", "latitude", "level", "time"), plot_type="contourf",fixed_colorscale=False):
428
+ def draw(output_dir=None, dataset=None, ncfile=None, xyzt_dims=("longitude", "latitude", "level", "time"), plot_type="contourf", fixed_colorscale=False):
439
429
  """
440
430
  Description:
441
431
  Draw the data in the netCDF file
@@ -454,6 +444,7 @@ def draw(output_dir=None, dataset=None, ncfile=None, xyzt_dims=("longitude", "la
454
444
  draw(output_dir=output_dir, ncfile=ncfile, xyzt_dims=("longitude", "latitude", "level", "time"), plot_type="contourf", fixed_colorscale=False)
455
445
  """
456
446
  from ._script.plot_dataset import func_plot_dataset
447
+
457
448
  if output_dir is None:
458
449
  output_dir = str(os.getcwd())
459
450
  if isinstance(xyzt_dims, (list, tuple)):