oafuncs 0.0.91__tar.gz → 0.0.93__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35) hide show
  1. {oafuncs-0.0.91/oafuncs.egg-info → oafuncs-0.0.93}/PKG-INFO +12 -2
  2. oafuncs-0.0.93/oafuncs/oa_data.py +153 -0
  3. {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_down/hycom_3hourly.py +246 -148
  4. {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_down/idm.py +1 -1
  5. {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_down/literature.py +11 -10
  6. {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_file.py +6 -2
  7. {oafuncs-0.0.91 → oafuncs-0.0.93/oafuncs.egg-info}/PKG-INFO +12 -2
  8. {oafuncs-0.0.91 → oafuncs-0.0.93}/setup.py +1 -1
  9. oafuncs-0.0.91/oafuncs/oa_data.py +0 -278
  10. {oafuncs-0.0.91 → oafuncs-0.0.93}/LICENSE.txt +0 -0
  11. {oafuncs-0.0.91 → oafuncs-0.0.93}/MANIFEST.in +0 -0
  12. {oafuncs-0.0.91 → oafuncs-0.0.93}/README.md +0 -0
  13. {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/__init__.py +0 -0
  14. {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/data_store/OAFuncs.png +0 -0
  15. {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_cmap.py +0 -0
  16. {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_down/User_Agent-list.txt +0 -0
  17. {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_down/__init__.py +0 -0
  18. {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_down/test_ua.py +0 -0
  19. {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_down/user_agent.py +0 -0
  20. {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_draw.py +0 -0
  21. {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_help.py +0 -0
  22. {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_nc.py +0 -0
  23. {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_python.py +0 -0
  24. {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_sign/__init__.py +0 -0
  25. {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_sign/meteorological.py +0 -0
  26. {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_sign/ocean.py +0 -0
  27. {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_sign/scientific.py +0 -0
  28. {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_tool/__init__.py +0 -0
  29. {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_tool/email.py +0 -0
  30. {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_tool/parallel.py +0 -0
  31. {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs.egg-info/SOURCES.txt +0 -0
  32. {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs.egg-info/dependency_links.txt +0 -0
  33. {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs.egg-info/requires.txt +0 -0
  34. {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs.egg-info/top_level.txt +0 -0
  35. {oafuncs-0.0.91 → oafuncs-0.0.93}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: oafuncs
3
- Version: 0.0.91
3
+ Version: 0.0.93
4
4
  Summary: Oceanic and Atmospheric Functions
5
5
  Home-page: https://github.com/Industry-Pays/OAFuncs
6
6
  Author: Kun Liu
@@ -30,6 +30,16 @@ Requires-Dist: matplotlib
30
30
  Requires-Dist: Cartopy
31
31
  Requires-Dist: netCDF4
32
32
  Requires-Dist: xlrd
33
+ Dynamic: author
34
+ Dynamic: author-email
35
+ Dynamic: classifier
36
+ Dynamic: description
37
+ Dynamic: description-content-type
38
+ Dynamic: home-page
39
+ Dynamic: license
40
+ Dynamic: requires-dist
41
+ Dynamic: requires-python
42
+ Dynamic: summary
33
43
 
34
44
 
35
45
  # oafuncs
@@ -0,0 +1,153 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ """
4
+ Author: Liu Kun && 16031215@qq.com
5
+ Date: 2024-09-17 17:12:47
6
+ LastEditors: Liu Kun && 16031215@qq.com
7
+ LastEditTime: 2024-12-13 19:11:08
8
+ FilePath: \\Python\\My_Funcs\\OAFuncs\\oafuncs\\oa_data.py
9
+ Description:
10
+ EditPlatform: vscode
11
+ ComputerInfo: XPS 15 9510
12
+ SystemInfo: Windows 11
13
+ Python Version: 3.11
14
+ """
15
+
16
+ import itertools
17
+ import multiprocessing as mp
18
+ from concurrent.futures import ThreadPoolExecutor
19
+
20
+ import numpy as np
21
+ from scipy.interpolate import griddata
22
+
23
+
24
+ __all__ = ["interp_2d", "ensure_list"]
25
+
26
+
27
def ensure_list(input_data):
    """
    Return the input wrapped as a list.

    A list is handed back as-is (the very same object, not a copy).
    A string becomes a one-element list holding that string.
    Any other value is first converted with ``str()`` and then wrapped, so
    e.g. a tuple or ``None`` ends up as a single string element.

    :param input_data: A list, a string, or any other object.
    :return: The original list, or a one-element list of strings.
    """
    if isinstance(input_data, list):
        return input_data
    # Strings are kept verbatim; everything else is stringified before wrapping.
    element = input_data if isinstance(input_data, str) else str(input_data)
    return [element]
45
+
46
+
47
+
48
def interp_2d(target_x, target_y, origin_x, origin_y, data, method="linear", parallel=True):
    """
    Perform 2D interpolation on the last two dimensions of a multi-dimensional array.

    Every 2D slice spanned by the last two axes of ``data`` is interpolated
    independently with ``scipy.interpolate.griddata``.

    Parameters:
    - target_x (array-like): x-coordinates of the target grid (1D axis or 2D mesh).
    - target_y (array-like): y-coordinates of the target grid (1D axis or 2D mesh).
    - origin_x (array-like): x-coordinates of the original grid (1D axis or 2D mesh).
    - origin_y (array-like): y-coordinates of the original grid (1D axis or 2D mesh).
    - data (array-like): Array with at least two dimensions whose last two dimensions correspond to the original grid.
    - method (str, optional): Interpolation method passed to griddata, default is 'linear'. Other options include 'nearest', 'cubic'.
    - parallel (bool, optional): Interpolate the slices in a thread pool. Default is True.

    Returns:
    - interpolated_data (numpy.ndarray): Array with the same leading dimensions as ``data`` and the last two dimensions replaced by the target grid shape.

    Raises:
    - ValueError: If the last two dimensions of ``data`` do not match the shape of the origin grid (this includes data with fewer than two dimensions).

    Usage:
    - Interpolate a 2D array:
        result = interp_2d(target_x, target_y, origin_x, origin_y, data_2d)
    - Interpolate a 3D / 4D / N-D array (the last two dimensions are spatial):
        result = interp_2d(target_x, target_y, origin_x, origin_y, data_nd)
    """
    # Accept plain lists as well as arrays: the parameters are documented as array-like.
    target_x = np.asarray(target_x)
    target_y = np.asarray(target_y)
    origin_x = np.asarray(origin_x)
    origin_y = np.asarray(origin_y)
    data = np.asarray(data)

    # Make sure both the target grid and the origin grid are two-dimensional meshes.
    if target_y.ndim == 1:
        target_x, target_y = np.meshgrid(target_x, target_y)
    if origin_y.ndim == 1:
        origin_x, origin_y = np.meshgrid(origin_x, origin_y)

    # The origin mesh must match the trailing two dimensions of the data.
    if origin_x.shape != data.shape[-2:] or origin_y.shape != data.shape[-2:]:
        raise ValueError("Shape of data does not match shape of origin_x or origin_y.")

    # Flatten both meshes into (n_points, 2) coordinate lists in (y, x) order.
    target_points = np.column_stack((target_y.ravel(), target_x.ravel()))
    origin_points = np.column_stack((origin_y.ravel(), origin_x.ravel()))
    target_shape = target_y.shape

    def interp_single(data_slice):
        # Interpolate one 2D slice and restore the target grid shape.
        return griddata(origin_points, data_slice.ravel(), target_points, method=method).reshape(target_shape)

    # Collapse all leading dimensions into one so that any rank >= 2 is handled
    # by the same code path (previously ranks above 4 left the result undefined).
    leading_shape = data.shape[:-2]
    n_slices = int(np.prod(leading_shape))
    slices = data.reshape(n_slices, *data.shape[-2:])

    if parallel and n_slices > 1:
        # cpu_count() - 2 is zero or negative on small machines, which
        # ThreadPoolExecutor rejects, so always keep at least one worker.
        max_workers = max(1, mp.cpu_count() - 2)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            results = list(executor.map(interp_single, slices))
    else:
        results = [interp_single(data_slice) for data_slice in slices]

    if not results:
        # A zero-length leading dimension yields an empty result of matching shape.
        return np.empty(leading_shape + target_shape, dtype=float)

    # Restore the leading dimensions; a 2D input therefore returns a 2D result
    # in both the serial and the parallel path.
    return np.stack(results).reshape(*leading_shape, *target_shape)
113
+
114
+
115
+
116
if __name__ == "__main__":
    # Running this module as a script does nothing: the only executed statement is `pass`.
    pass
    # The string literal below is a disabled demo kept for reference. It is a bare
    # expression statement, so it is evaluated and discarded without running anything.
    # It shows interp_2d on a 4D array, timed with parallel=False and then with the
    # default parallel=True, followed by a contour plot of one slice.
    """ import time

    import matplotlib.pyplot as plt

    # 测试数据
    origin_x = np.linspace(0, 10, 11)
    origin_y = np.linspace(0, 10, 11)
    target_x = np.linspace(0, 10, 101)
    target_y = np.linspace(0, 10, 101)
    data = np.random.rand(11, 11)

    # 高维插值
    origin_x = np.linspace(0, 10, 11)
    origin_y = np.linspace(0, 10, 11)
    target_x = np.linspace(0, 10, 101)
    target_y = np.linspace(0, 10, 101)
    data = np.random.rand(10, 10, 11, 11)

    start = time.time()
    interpolated_data = interp_2d(target_x, target_y, origin_x, origin_y, data, parallel=False)
    print(f"Interpolation time: {time.time()-start:.2f}s")

    print(interpolated_data.shape)

    # 高维插值多线程
    start = time.time()
    interpolated_data = interp_2d(target_x, target_y, origin_x, origin_y, data)
    print(f"Interpolation time: {time.time()-start:.2f}s")

    print(interpolated_data.shape)
    print(interpolated_data[0, 0, :, :].shape)
    plt.figure()
    plt.contourf(target_x, target_y, interpolated_data[0, 0, :, :])
    plt.colorbar()
    plt.show() """
@@ -26,13 +26,17 @@ from threading import Lock
26
26
  import matplotlib.pyplot as plt
27
27
  import numpy as np
28
28
  import pandas as pd
29
+ import xarray as xr
29
30
  import requests
30
31
  from rich import print
31
32
  from rich.progress import Progress
33
+ import netCDF4 as nc
32
34
 
33
35
  from oafuncs.oa_down.user_agent import get_ua
34
36
  from oafuncs.oa_file import file_size, mean_size
35
37
  from oafuncs.oa_nc import check as check_nc
38
+ from oafuncs.oa_nc import modify as modify_nc
39
+ from oafuncs.oa_down.idm import downloader as idm_downloader
36
40
 
37
41
  warnings.filterwarnings("ignore", category=RuntimeWarning, message="Engine '.*' loading failed:.*")
38
42
 
@@ -571,20 +575,26 @@ def _check_existing_file(file_full_path, avg_size):
571
575
  if abs(delta_size_ratio) > 0.025:
572
576
  if check_nc(file_full_path):
573
577
  # print(f"File size is abnormal but can be opened normally, file size: {fsize:.2f} KB")
574
- return True
578
+ if not _check_ftime(file_full_path, if_print=True):
579
+ return False
580
+ else:
581
+ return True
575
582
  else:
576
583
  print(f"File size is abnormal and cannot be opened, {file_full_path}: {fsize:.2f} KB")
577
584
  return False
578
585
  else:
579
- return True
586
+ if not _check_ftime(file_full_path, if_print=True):
587
+ return False
588
+ else:
589
+ return True
580
590
  else:
581
591
  return False
582
592
 
583
593
 
584
594
  def _get_mean_size30(store_path, same_file):
585
595
  if same_file not in fsize_dict.keys():
586
- # print(f'Same file name: {same_file}')
587
- fsize_dict[same_file] = {"size": 0, "count": 0}
596
+ # print(f'Same file name: {same_file}')
597
+ fsize_dict[same_file] = {"size": 0, "count": 0}
588
598
 
589
599
  if fsize_dict[same_file]["count"] < 30 or fsize_dict[same_file]["size"] == 0:
590
600
  # 更新30次文件最小值,后续认为可以代表所有文件,不再更新占用时间
@@ -599,7 +609,7 @@ def _get_mean_size30(store_path, same_file):
599
609
 
600
610
  def _get_mean_size_move(same_file, current_file):
601
611
  # 获取锁
602
- with fsize_dict_lock: # 全局锁,确保同一时间只能有一个线程访问
612
+ with fsize_dict_lock: # 全局锁,确保同一时间只能有一个线程访问
603
613
  # 初始化字典中的值,如果文件不在字典中
604
614
  if same_file not in fsize_dict.keys():
605
615
  fsize_dict[same_file] = {"size_list": [], "mean_size": 1.0}
@@ -633,6 +643,61 @@ def _get_mean_size_move(same_file, current_file):
633
643
  return fsize_dict[same_file]["mean_size"]
634
644
 
635
645
 
646
+ def _check_ftime(nc_file, tname="time", if_print=False):
647
+ if not os.path.exists(nc_file):
648
+ return False
649
+ nc_file = str(nc_file)
650
+ try:
651
+ ds = xr.open_dataset(nc_file)
652
+ real_time = ds[tname].values[0]
653
+ ds.close()
654
+ real_time = str(real_time)[:13]
655
+ real_time = real_time.replace("-", "").replace("T", "")
656
+ # -----------------------------------------------------
657
+ f_time = re.findall(r"\d{10}", nc_file)[0]
658
+ if real_time == f_time:
659
+ return True
660
+ else:
661
+ if if_print:
662
+ print(f"[bold #daff5c]File time error, file/real time: [bold blue]{f_time}/{real_time}")
663
+ return False
664
+ except Exception as e:
665
+ if if_print:
666
+ print(f"[bold #daff5c]File time check failed, {nc_file}: {e}")
667
+ return False
668
+
669
+
670
def _correct_time(nc_file):
    """
    Recompute the time value of a NetCDF file from the YYYYMMDDHH timestamp in
    its file name and pass it to ``modify_nc``.

    The offset is measured from the origin given in the time variable's
    ``units`` attribute ("<unit> since YYYY-MM-DD HH:MM:SS") and expressed in
    hours, minutes or days when the units string names one of them, otherwise
    in seconds.

    Parameters:
        nc_file: str or path-like, path of the NetCDF file to correct.

    Raises:
        IndexError: if the units string has no "since" part or the path
            contains no 10-digit timestamp.
        ValueError: if the origin or the timestamp cannot be parsed.
    """
    # Read the time units; the context manager closes the file even when the
    # "time" variable or its units attribute is missing.
    with nc.Dataset(nc_file) as dataset:
        time_units = dataset.variables["time"].units

    # Parse the units string to get the time origin.
    origin_str = time_units.split("since")[1].strip()
    origin_datetime = datetime.datetime.strptime(origin_str, "%Y-%m-%d %H:%M:%S")

    # Extract the timestamp from the file name first so that digit runs in
    # directory names are not picked up; fall back to the whole path to keep
    # the previous behaviour when the name itself has no timestamp.
    path_str = str(nc_file)
    stamps = re.findall(r"\d{10}", os.path.basename(path_str)) or re.findall(r"\d{10}", path_str)
    given_datetime = datetime.datetime.strptime(stamps[0], "%Y%m%d%H")

    # Offset between the file-name time and the origin, converted from seconds
    # into the unit named by the time variable.
    time_difference = (given_datetime - origin_datetime).total_seconds()
    if "hours" in time_units:
        time_difference /= 3600
    elif "minutes" in time_units:
        time_difference /= 60
    elif "days" in time_units:
        time_difference /= 3600 * 24

    # Hand the new value to modify_nc for the "time" variable.
    # NOTE(review): the meaning of the third argument (None) is defined by oa_nc.modify — confirm.
    modify_nc(nc_file, "time", None, time_difference)
699
+
700
+
636
701
  def _download_file(target_url, store_path, file_name, check=False):
637
702
  # Check if the file exists
638
703
  fname = Path(store_path) / file_name
@@ -640,108 +705,122 @@ def _download_file(target_url, store_path, file_name, check=False):
640
705
  file_name_split = file_name_split[:-1]
641
706
  # same_file = f"{file_name_split[0]}_{file_name_split[1]}*nc"
642
707
  same_file = "_".join(file_name_split) + "*nc"
643
-
708
+
644
709
  if check:
645
- if same_file not in fsize_dict.keys(): # 对第一个文件单独进行检查,因为没有大小可以对比
646
- check_nc(fname,if_delete=True)
710
+ if same_file not in fsize_dict.keys(): # 对第一个文件单独进行检查,因为没有大小可以对比
711
+ check_nc(fname, if_delete=True)
647
712
 
648
713
  # set_min_size = _get_mean_size30(store_path, same_file) # 原方案,只30次取平均值;若遇变化,无法判断
649
714
  get_mean_size = _get_mean_size_move(same_file, fname)
650
-
715
+
651
716
  if _check_existing_file(fname, get_mean_size):
652
717
  count_dict["skip"] += 1
653
718
  return
654
719
  _clear_existing_file(fname)
655
720
 
656
- # -----------------------------------------------
657
- print(f"[bold #f0f6d0]Requesting {file_name}...")
658
- # 创建会话
659
- s = requests.Session()
660
- download_success = False
661
- request_times = 0
721
+ if not use_idm:
722
+ # -----------------------------------------------
723
+ print(f"[bold #f0f6d0]Requesting {file_name} ...")
724
+ # 创建会话
725
+ s = requests.Session()
726
+ download_success = False
727
+ request_times = 0
662
728
 
663
- def calculate_wait_time(time_str, target_url):
664
- # 定义正则表达式,匹配YYYYMMDDHH格式的时间
665
- time_pattern = r"\d{10}"
729
+ def calculate_wait_time(time_str, target_url):
730
+ # 定义正则表达式,匹配YYYYMMDDHH格式的时间
731
+ time_pattern = r"\d{10}"
666
732
 
667
- # 定义两个字符串
668
- # str1 = 'HYCOM_water_u_2018010100-2018010112.nc'
669
- # str2 = 'HYCOM_water_u_2018010100.nc'
733
+ # 定义两个字符串
734
+ # str1 = 'HYCOM_water_u_2018010100-2018010112.nc'
735
+ # str2 = 'HYCOM_water_u_2018010100.nc'
670
736
 
671
- # 使用正则表达式查找时间
672
- times_in_str = re.findall(time_pattern, time_str)
737
+ # 使用正则表达式查找时间
738
+ times_in_str = re.findall(time_pattern, time_str)
673
739
 
674
- # 计算每个字符串中的时间数量
675
- num_times_str = len(times_in_str)
740
+ # 计算每个字符串中的时间数量
741
+ num_times_str = len(times_in_str)
676
742
 
677
- if num_times_str > 1:
678
- delta_t = datetime.datetime.strptime(times_in_str[1], "%Y%m%d%H") - datetime.datetime.strptime(times_in_str[0], "%Y%m%d%H")
679
- delta_t = delta_t.total_seconds() / 3600
680
- delta_t = delta_t / 3 + 1
681
- else:
682
- delta_t = 1
683
- # 单个要素最多等待5分钟,不宜太短,太短可能请求失败;也不宜太长,太长可能会浪费时间
684
- num_var = int(target_url.count("var="))
685
- if num_var <= 0:
686
- num_var = 1
687
- return int(delta_t * 5 * 60 * num_var)
688
-
689
- max_timeout = calculate_wait_time(file_name, target_url)
690
- print(f"[bold #912dbc]Max timeout: {max_timeout} seconds")
691
-
692
- # print(f'Download_start_time: {datetime.datetime.now()}')
693
- download_time_s = datetime.datetime.now()
694
- order_list = ["1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th", "9th", "10th"]
695
- while not download_success:
696
- if request_times >= 10:
697
- # print(f'下载失败,已重试 {request_times} 次\n可先跳过,后续再试')
698
- print(f"[bold #ffe5c0]Download failed after {request_times} times\nYou can skip it and try again later")
699
- count_dict["fail"] += 1
700
- break
701
- if request_times > 0:
702
- # print(f'\r正在重试第 {request_times} 次', end="")
703
- print(f"[bold #ffe5c0]Retrying the {order_list[request_times-1]} time...")
704
- # 尝试下载文件
705
- try:
706
- headers = {"User-Agent": get_ua()}
707
- """ response = s.get(target_url, headers=headers, timeout=random.randint(5, max_timeout))
708
- response.raise_for_status() # 如果请求返回的不是200,将抛出HTTPError异常
709
-
710
- # 保存文件
711
- with open(filename, 'wb') as f:
712
- f.write(response.content) """
713
-
714
- response = s.get(target_url, headers=headers, stream=True, timeout=random.randint(5, max_timeout)) # 启用流式传输
715
- response.raise_for_status() # 如果请求返回的不是200,将抛出HTTPError异常
716
- # 保存文件
717
- with open(fname, "wb") as f:
718
- print(f"[bold #96cbd7]Downloading {file_name}...")
719
- for chunk in response.iter_content(chunk_size=1024):
720
- if chunk:
721
- f.write(chunk)
722
-
723
- f.close()
724
-
725
- # print(f'\r文件 {fname} 下载成功', end="")
726
- if os.path.exists(fname):
727
- download_success = True
728
- download_time_e = datetime.datetime.now()
729
- download_delta = download_time_e - download_time_s
730
- print(f"[#3dfc40]File [bold #dfff73]{fname} [#3dfc40]has been downloaded successfully, Time: [#39cbdd]{download_delta}")
731
- count_dict["success"] += 1
732
- # print(f'Download_end_time: {datetime.datetime.now()}')
733
-
734
- except requests.exceptions.HTTPError as errh:
735
- print(f"Http Error: {errh}")
736
- except requests.exceptions.ConnectionError as errc:
737
- print(f"Error Connecting: {errc}")
738
- except requests.exceptions.Timeout as errt:
739
- print(f"Timeout Error: {errt}")
740
- except requests.exceptions.RequestException as err:
741
- print(f"OOps: Something Else: {err}")
742
-
743
- time.sleep(3)
744
- request_times += 1
743
+ if num_times_str > 1:
744
+ delta_t = datetime.datetime.strptime(times_in_str[1], "%Y%m%d%H") - datetime.datetime.strptime(times_in_str[0], "%Y%m%d%H")
745
+ delta_t = delta_t.total_seconds() / 3600
746
+ delta_t = delta_t / 3 + 1
747
+ else:
748
+ delta_t = 1
749
+ # 单个要素最多等待5分钟,不宜太短,太短可能请求失败;也不宜太长,太长可能会浪费时间
750
+ num_var = int(target_url.count("var="))
751
+ if num_var <= 0:
752
+ num_var = 1
753
+ return int(delta_t * 5 * 60 * num_var)
754
+
755
+ max_timeout = calculate_wait_time(file_name, target_url)
756
+ print(f"[bold #912dbc]Max timeout: {max_timeout} seconds")
757
+
758
+ # print(f'Download_start_time: {datetime.datetime.now()}')
759
+ download_time_s = datetime.datetime.now()
760
+ order_list = ["1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th", "9th", "10th"]
761
+ while not download_success:
762
+ if request_times >= 10:
763
+ # print(f'下载失败,已重试 {request_times} 次\n可先跳过,后续再试')
764
+ print(f"[bold #ffe5c0]Download failed after {request_times} times\nYou can skip it and try again later")
765
+ count_dict["fail"] += 1
766
+ break
767
+ if request_times > 0:
768
+ # print(f'\r正在重试第 {request_times} 次', end="")
769
+ print(f"[bold #ffe5c0]Retrying the {order_list[request_times - 1]} time...")
770
+ # 尝试下载文件
771
+ try:
772
+ headers = {"User-Agent": get_ua()}
773
+ """ response = s.get(target_url, headers=headers, timeout=random.randint(5, max_timeout))
774
+ response.raise_for_status() # 如果请求返回的不是200,将抛出HTTPError异常
775
+
776
+ # 保存文件
777
+ with open(filename, 'wb') as f:
778
+ f.write(response.content) """
779
+
780
+ response = s.get(target_url, headers=headers, stream=True, timeout=random.randint(5, max_timeout)) # 启用流式传输
781
+ response.raise_for_status() # 如果请求返回的不是200,将抛出HTTPError异常
782
+ # 保存文件
783
+ with open(fname, "wb") as f:
784
+ print(f"[bold #96cbd7]Downloading {file_name} ...")
785
+ for chunk in response.iter_content(chunk_size=1024):
786
+ if chunk:
787
+ f.write(chunk)
788
+
789
+ f.close()
790
+
791
+ if not _check_ftime(fname, if_print=True):
792
+ if match_time:
793
+ _correct_time(fname)
794
+ else:
795
+ _clear_existing_file(fname)
796
+ # print(f"[bold #ffe5c0]File time error, {fname}")
797
+ count_dict["no_data"] += 1
798
+ break
799
+
800
+ # print(f'\r文件 {fname} 下载成功', end="")
801
+ if os.path.exists(fname):
802
+ download_success = True
803
+ download_time_e = datetime.datetime.now()
804
+ download_delta = download_time_e - download_time_s
805
+ print(f"[#3dfc40]File [bold #dfff73]{fname} [#3dfc40]has been downloaded successfully, Time: [#39cbdd]{download_delta}")
806
+ count_dict["success"] += 1
807
+ # print(f'Download_end_time: {datetime.datetime.now()}')
808
+
809
+ except requests.exceptions.HTTPError as errh:
810
+ print(f"Http Error: {errh}")
811
+ except requests.exceptions.ConnectionError as errc:
812
+ print(f"Error Connecting: {errc}")
813
+ except requests.exceptions.Timeout as errt:
814
+ print(f"Timeout Error: {errt}")
815
+ except requests.exceptions.RequestException as err:
816
+ print(f"OOps: Something Else: {err}")
817
+
818
+ time.sleep(3)
819
+ request_times += 1
820
+ else:
821
+ idm_downloader(target_url, store_path, file_name, given_idm_engine)
822
+ idm_download_list.append(fname)
823
+ print(f"[bold #3dfc40]File [bold #dfff73]{fname} [#3dfc40]has been submit to IDM for downloading")
745
824
 
746
825
 
747
826
  def _check_hour_is_valid(ymdh_str):
@@ -849,11 +928,11 @@ def _prepare_url_to_download(var, lon_min=0, lon_max=359.92, lat_min=-80, lat_ma
849
928
  var = current_group[0]
850
929
  submit_url = _get_submit_url_var(var, depth, level_num, lon_min, lon_max, lat_min, lat_max, dataset_name, version_name, download_time, download_time_end)
851
930
  file_name = f"HYCOM_{variable_info[var]['var_name']}_{download_time}.nc"
852
- old_str = f'var={variable_info[var]["var_name"]}'
853
- new_str = f'var={variable_info[var]["var_name"]}'
931
+ old_str = f"var={variable_info[var]['var_name']}"
932
+ new_str = f"var={variable_info[var]['var_name']}"
854
933
  if len(current_group) > 1:
855
934
  for v in current_group[1:]:
856
- new_str = f'{new_str}&var={variable_info[v]["var_name"]}'
935
+ new_str = f"{new_str}&var={variable_info[v]['var_name']}"
857
936
  submit_url = submit_url.replace(old_str, new_str)
858
937
  # file_name = f'HYCOM_{'-'.join([variable_info[v]["var_name"] for v in current_group])}_{download_time}.nc'
859
938
  file_name = f"HYCOM_{key}_{download_time}.nc"
@@ -949,7 +1028,7 @@ def _download_hourly_func(var, time_s, time_e, lon_min=0, lon_max=359.92, lat_mi
949
1028
  # 串行方式
950
1029
  for i, time_str in enumerate(time_list):
951
1030
  _prepare_url_to_download(var, lon_min, lon_max, lat_min, lat_max, time_str, None, depth, level, store_path, dataset_name, version_name, check)
952
- progress.update(task, advance=1, description=f"[cyan]Downloading... {i+1}/{len(time_list)}")
1031
+ progress.update(task, advance=1, description=f"[cyan]Downloading... {i + 1}/{len(time_list)}")
953
1032
  else:
954
1033
  # 并行方式
955
1034
  with ThreadPoolExecutor(max_workers=num_workers) as executor:
@@ -967,7 +1046,7 @@ def _download_hourly_func(var, time_s, time_e, lon_min=0, lon_max=359.92, lat_mi
967
1046
  time_str_end_index = int(min(len(time_list) - 1, int(i * ftimes + ftimes - 1)))
968
1047
  time_str_end = time_list[time_str_end_index]
969
1048
  _prepare_url_to_download(var, lon_min, lon_max, lat_min, lat_max, time_str, time_str_end, depth, level, store_path, dataset_name, version_name, check)
970
- progress.update(task, advance=1, description=f"[cyan]Downloading... {i+1}/{total_num}")
1049
+ progress.update(task, advance=1, description=f"[cyan]Downloading... {i + 1}/{total_num}")
971
1050
  else:
972
1051
  # 并行方式
973
1052
  with ThreadPoolExecutor(max_workers=num_workers) as executor:
@@ -977,10 +1056,10 @@ def _download_hourly_func(var, time_s, time_e, lon_min=0, lon_max=359.92, lat_mi
977
1056
  for feature in as_completed(futures):
978
1057
  _done_callback(feature, progress, task, len(time_list), counter_lock)
979
1058
  else:
980
- print("Please ensure the time_s is no more than time_e")
1059
+ print("[bold red]Please ensure the time_s is no more than time_e")
981
1060
 
982
1061
 
983
- def download(var, time_s, time_e=None, lon_min=0, lon_max=359.92, lat_min=-80, lat_max=90, depth=None, level=None, store_path=None, dataset_name=None, version_name=None, num_workers=None, check=False, ftimes=1):
1062
+ def download(var, time_s, time_e=None, lon_min=0, lon_max=359.92, lat_min=-80, lat_max=90, depth=None, level=None, store_path=None, dataset_name=None, version_name=None, num_workers=None, check=False, ftimes=1, idm_engine=None, fill_time=False):
984
1063
  """
985
1064
  Description:
986
1065
  Download the data of single time or a series of time
@@ -1001,6 +1080,8 @@ def download(var, time_s, time_e=None, lon_min=0, lon_max=359.92, lat_min=-80, l
1001
1080
  num_workers: int, the number of workers, default is None, if not set, the number of workers will be 1; suggest not to set the number of workers too large
1002
1081
  check: bool, whether to check the existing file, default is False, if set to True, the existing file will be checked and not downloaded again; else, the existing file will be covered
1003
1082
  ftimes: int, the number of time in one file, default is 1, if set to 1, the data of single time will be downloaded; the maximum is 8, if set to 8, the data of 8 times will be downloaded in one file
1083
+ idm_engine: str, the IDM engine, default is None, if set, the IDM will be used to download the data; example: "D:\\Programs\\Internet Download Manager\\IDMan.exe"
1084
+ fill_time: bool, whether to match the time, default is False, if set to True, the time in the file name will be corrected according to the time in the file; else, the data will be skip if the time is not correct. Because the real time of some data that has been downloaded does not match the time in the file name, eg. the required time is 2024110100, but the time in the file name is 2024110103, so the data will be skip if the fill_time is False. Note: it is not the right time data, so it is not recommended to set fill_time to True
1004
1085
 
1005
1086
  Returns:
1006
1087
  None
@@ -1048,7 +1129,7 @@ def download(var, time_s, time_e=None, lon_min=0, lon_max=359.92, lat_min=-80, l
1048
1129
  os.makedirs(str(store_path), exist_ok=True)
1049
1130
 
1050
1131
  if num_workers is not None:
1051
- num_workers = max(min(num_workers, 10), 1) # 暂时不限制最大值,再检查的时候可以多开一些线程
1132
+ num_workers = max(min(num_workers, 10), 1) # 暂时不限制最大值,再检查的时候可以多开一些线程
1052
1133
  # num_workers = int(max(num_workers, 1))
1053
1134
  time_s = str(time_s)
1054
1135
  if len(time_s) == 8:
@@ -1068,12 +1149,52 @@ def download(var, time_s, time_e=None, lon_min=0, lon_max=359.92, lat_min=-80, l
1068
1149
 
1069
1150
  global fsize_dict
1070
1151
  fsize_dict = {}
1071
-
1152
+
1072
1153
  global fsize_dict_lock
1073
1154
  fsize_dict_lock = Lock()
1074
1155
 
1156
+ global use_idm, given_idm_engine, idm_download_list
1157
+ if idm_engine is not None:
1158
+ use_idm = True
1159
+ num_workers = 1
1160
+ given_idm_engine = idm_engine
1161
+ idm_download_list = []
1162
+ else:
1163
+ use_idm = False
1164
+
1165
+ global match_time
1166
+ if fill_time:
1167
+ match_time = True
1168
+ else:
1169
+ match_time = False
1170
+
1075
1171
  _download_hourly_func(var, time_s, time_e, lon_min, lon_max, lat_min, lat_max, depth, level, store_path, dataset_name, version_name, num_workers, check, ftimes)
1076
1172
 
1173
+ if idm_download_list:
1174
+ for f in idm_download_list:
1175
+ wait_success = 0
1176
+ success = False
1177
+ while not success:
1178
+ if check_nc(f):
1179
+ if match_time:
1180
+ _correct_time(f)
1181
+ count_dict["success"] += 1
1182
+ else:
1183
+ if not _check_ftime(f):
1184
+ _clear_existing_file(f)
1185
+ count_dict["no_data"] += 1
1186
+ count_dict["no_data_list"].append(str(f).split("_")[-1].split(".")[0])
1187
+ else:
1188
+ count_dict["success"] += 1
1189
+ success = True
1190
+ else:
1191
+ wait_success += 1
1192
+ time.sleep(3)
1193
+ if wait_success >= 20:
1194
+ success = True
1195
+ # print(f'{f} download failed')
1196
+ count_dict["fail"] += 1
1197
+
1077
1198
  count_dict["total"] = count_dict["success"] + count_dict["fail"] + count_dict["skip"] + count_dict["no_data"]
1078
1199
 
1079
1200
  print("[bold #ecdbfe]-" * 160)
@@ -1140,10 +1261,6 @@ def how_to_use():
1140
1261
 
1141
1262
 
1142
1263
  if __name__ == "__main__":
1143
- time_s, time_e = "2024101012", "2024101018"
1144
- merge_name = f"{time_s}_{time_e}" # 合并后的文件名
1145
- root_path = r"G:\Data\HYCOM\3hourly_test"
1146
- location_dict = {"west": 105, "east": 130, "south": 15, "north": 45}
1147
1264
  download_dict = {
1148
1265
  "water_u": {"simple_name": "u", "download": 1},
1149
1266
  "water_v": {"simple_name": "v", "download": 1},
@@ -1158,50 +1275,31 @@ if __name__ == "__main__":
1158
1275
 
1159
1276
  var_list = [var_name for var_name in download_dict.keys() if download_dict[var_name]["download"]]
1160
1277
 
1161
- # set depth or level, only one can be True
1162
- # if you wanna download all depth or level, set both False
1163
- depth = None # or 0-5000 meters
1164
- level = None # or 1-40 levels
1165
- num_workers = 3
1166
-
1167
- check = True
1168
- ftimes = 1
1169
-
1170
- download_switch, single_var = True, False
1171
- combine_switch = False
1172
- copy_switch, copy_dir = False, r"G:\Data\HYCOM\3hourly"
1278
+ single_var = False
1173
1279
 
1174
1280
  # draw_time_range(pic_save_folder=r'I:\Delete')
1175
1281
 
1176
- if download_switch:
1177
- if single_var:
1178
- for var_name in var_list:
1179
- download(var=var_name, time_s=time_s, time_e=time_e, store_path=Path(root_path), lon_min=location_dict["west"], lon_max=location_dict["east"], lat_min=location_dict["south"], lat_max=location_dict["north"], num_workers=num_workers, check=check, depth=depth, level=level, ftimes=ftimes)
1180
- else:
1181
- download(var=var_list, time_s=time_s, time_e=time_e, store_path=Path(root_path), lon_min=location_dict["west"], lon_max=location_dict["east"], lat_min=location_dict["south"], lat_max=location_dict["north"], num_workers=num_workers, check=check, depth=depth, level=level, ftimes=ftimes)
1282
+ options = {
1283
+ "var": var_list,
1284
+ "time_s": "2018010100",
1285
+ "time_e": "2020123121",
1286
+ "store_path": r"F:\Data\HYCOM\3hourly",
1287
+ "lon_min": 105,
1288
+ "lon_max": 130,
1289
+ "lat_min": 15,
1290
+ "lat_max": 45,
1291
+ "num_workers": 3,
1292
+ "check": True,
1293
+ "depth": None, # or 0-5000 meters
1294
+ "level": None, # or 1-40 levels
1295
+ "ftimes": 1,
1296
+ "idm_engine": r"D:\Programs\Internet Download Manager\IDMan.exe", # 查漏补缺不建议开启
1297
+ "fill_time": False
1298
+ }
1182
1299
 
1183
- """ if combine_switch or copy_switch:
1184
- time_list = get_time_list(time_s, time_e, 3, 'hour')
1300
+ if single_var:
1185
1301
  for var_name in var_list:
1186
- file_list = []
1187
- if single_var:
1188
- for time_str in time_list:
1189
- file_list.append(Path(root_path)/f'HYCOM_{var_name}_{time_str}.nc')
1190
- merge_path_name = Path(root_path)/f'HYCOM_{var_name}_{merge_name}.nc'
1191
- else:
1192
- # 如果混合,需要看情况获取文件列表
1193
- fname = ''
1194
- if var_name in ['water_u', 'water_v', 'water_u_bottom', 'water_v_bottom']:
1195
- fname = 'uv3z'
1196
- elif var_name in ['water_temp', 'salinity', 'water_temp_bottom', 'salinity_bottom']:
1197
- fname = 'ts3z'
1198
- elif var_name in ['surf_el']:
1199
- fname = 'surf_el'
1200
- for time_str in time_list:
1201
- file_list.append(Path(root_path)/f'HYCOM_{fname}_{time_str}.nc')
1202
- merge_path_name = Path(root_path)/f'HYCOM_{fname}_{merge_name}.nc'
1203
- if combine_switch:
1204
- # 这里的var_name必须是官方变量名,不能再是简写了
1205
- merge(file_list, var_name, 'time', merge_path_name)
1206
- if copy_switch:
1207
- copy_file(merge_path_name, copy_dir) """
1302
+ options["var"] = var_name
1303
+ download(**options)
1304
+ else:
1305
+ download(**options)
@@ -38,7 +38,7 @@ def downloader(task_url, folder_path, file_name, idm_engine=r"D:\Programs\Intern
38
38
  Return:
39
39
  None
40
40
  Example:
41
- downloader("https://www.test.com/data.nc", r"E:\Data", "test.nc", r"D:\Programs\Internet Download Manager\IDMan.exe")
41
+ downloader("https://www.test.com/data.nc", "E:\\Data", "test.nc", "D:\\Programs\\Internet Download Manager\\IDMan.exe")
42
42
  """
43
43
  os.makedirs(folder_path, exist_ok=True)
44
44
  # 将任务添加至队列
@@ -23,6 +23,8 @@ import requests
23
23
  from rich import print
24
24
  from rich.progress import track
25
25
  from oafuncs.oa_down.user_agent import get_ua
26
+ from oafuncs.oa_file import remove
27
+ from oafuncs.oa_data import ensure_list
26
28
 
27
29
  __all__ = ["download5doi"]
28
30
 
@@ -221,7 +223,7 @@ class _Downloader:
221
223
  print("Try another URL...")
222
224
 
223
225
 
224
- def read_excel(file, col_name=r"DOI"):
226
+ def _read_excel(file, col_name=r"DOI"):
225
227
  df = pd.read_excel(file)
226
228
  df_list = df[col_name].tolist()
227
229
  # 去掉nan
@@ -229,7 +231,7 @@ def read_excel(file, col_name=r"DOI"):
229
231
  return df_list
230
232
 
231
233
 
232
- def read_txt(file):
234
+ def _read_txt(file):
233
235
  with open(file, "r") as f:
234
236
  lines = f.readlines()
235
237
  # 去掉换行符以及空行
@@ -267,13 +269,13 @@ def download5doi(store_path=None, doi_list=None, txt_file=None, excel_file=None,
267
269
  store_path.mkdir(parents=True, exist_ok=True)
268
270
  store_path = str(store_path)
269
271
 
270
- # 如果doi_list是str,转换为list
271
- if isinstance(doi_list, str) and doi_list:
272
- doi_list = [doi_list]
272
+ if doi_list:
273
+ doi_list = ensure_list(doi_list)
273
274
  if txt_file:
274
- doi_list = read_txt(txt_file)
275
+ doi_list = _read_txt(txt_file)
275
276
  if excel_file:
276
- doi_list = read_excel(excel_file, col_name)
277
+ doi_list = _read_excel(excel_file, col_name)
278
+ remove(Path(store_path) / "wrong_record.txt")
277
279
  print(f"Downloading {len(doi_list)} PDF files...")
278
280
  for doi in track(doi_list, description="Downloading..."):
279
281
  download = _Downloader(doi, store_path)
@@ -281,7 +283,6 @@ def download5doi(store_path=None, doi_list=None, txt_file=None, excel_file=None,
281
283
 
282
284
 
283
285
  if __name__ == "__main__":
284
- store_path = r"F:\AAA-Delete\DOI_Reference\pdf"
285
- excel_file = r"F:\AAA-Delete\DOI_Reference\savedrecs.xls"
286
- # download5doi(store_path, doi_list='10.1007/s00382-022-06260-x')
286
+ store_path = r"F:\AAA-Delete\DOI_Reference\5\pdf"
287
+ excel_file = r"F:\AAA-Delete\DOI_Reference\5\savedrecs.xls"
287
288
  download5doi(store_path, excel_file=excel_file)
@@ -226,8 +226,12 @@ def make_dir(directory):
226
226
  make_dir(r"E:\Data\2024\09\17\var1")
227
227
  """
228
228
  directory = str(directory)
229
- os.makedirs(directory, exist_ok=True)
230
- print(f"Created directory: {directory}")
229
+ if os.path.exists(directory):
230
+ print(f"Directory already exists: {directory}")
231
+ return
232
+ else:
233
+ os.makedirs(directory, exist_ok=True)
234
+ print(f"Created directory: {directory}")
231
235
 
232
236
 
233
237
  # ** 清空文件夹
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: oafuncs
3
- Version: 0.0.91
3
+ Version: 0.0.93
4
4
  Summary: Oceanic and Atmospheric Functions
5
5
  Home-page: https://github.com/Industry-Pays/OAFuncs
6
6
  Author: Kun Liu
@@ -30,6 +30,16 @@ Requires-Dist: matplotlib
30
30
  Requires-Dist: Cartopy
31
31
  Requires-Dist: netCDF4
32
32
  Requires-Dist: xlrd
33
+ Dynamic: author
34
+ Dynamic: author-email
35
+ Dynamic: classifier
36
+ Dynamic: description
37
+ Dynamic: description-content-type
38
+ Dynamic: home-page
39
+ Dynamic: license
40
+ Dynamic: requires-dist
41
+ Dynamic: requires-python
42
+ Dynamic: summary
33
43
 
34
44
 
35
45
  # oafuncs
@@ -18,7 +18,7 @@ URL = 'https://github.com/Industry-Pays/OAFuncs'
18
18
  EMAIL = 'liukun0312@stu.ouc.edu.cn'
19
19
  AUTHOR = 'Kun Liu'
20
20
  REQUIRES_PYTHON = '>=3.9.0' # 2025/01/05
21
- VERSION = '0.0.91'
21
+ VERSION = '0.0.93'
22
22
 
23
23
  # What packages are required for this module to be executed?
24
24
  REQUIRED = [
@@ -1,278 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding=utf-8
3
- """
4
- Author: Liu Kun && 16031215@qq.com
5
- Date: 2024-09-17 17:12:47
6
- LastEditors: Liu Kun && 16031215@qq.com
7
- LastEditTime: 2024-12-13 19:11:08
8
- FilePath: \\Python\\My_Funcs\\OAFuncs\\oafuncs\\oa_data.py
9
- Description:
10
- EditPlatform: vscode
11
- ComputerInfo: XPS 15 9510
12
- SystemInfo: Windows 11
13
- Python Version: 3.11
14
- """
15
-
16
- import itertools
17
- import multiprocessing as mp
18
- from concurrent.futures import ThreadPoolExecutor
19
-
20
- import numpy as np
21
- from rich import print
22
- from scipy.interpolate import griddata
23
-
24
-
25
- __all__ = ["interp_2d"]
26
-
27
-
28
- def interp_2d(target_x, target_y, origin_x, origin_y, data, method="linear", parallel=True):
29
- """
30
- Perform 2D interpolation on the last two dimensions of a multi-dimensional array.
31
-
32
- Parameters:
33
- - target_x (array-like): 1D array of target grid's x-coordinates.
34
- - target_y (array-like): 1D array of target grid's y-coordinates.
35
- - origin_x (array-like): 1D array of original grid's x-coordinates.
36
- - origin_y (array-like): 1D array of original grid's y-coordinates.
37
- - data (numpy.ndarray): Multi-dimensional array where the last two dimensions correspond to the original grid.
38
- - method (str, optional): Interpolation method, default is 'linear'. Other options include 'nearest', 'cubic', etc.
39
- - parallel (bool, optional): Flag to enable parallel processing. Default is True.
40
-
41
- Returns:
42
- - interpolated_data (numpy.ndarray): Interpolated data with the same leading dimensions as the input data, but with the last two dimensions corresponding to the target grid.
43
-
44
- Raises:
45
- - ValueError: If the shape of the data does not match the shape of the origin_x or origin_y grids.
46
-
47
- Usage:
48
- - Interpolate a 2D array:
49
- result = interp_2d(target_x, target_y, origin_x, origin_y, data_2d)
50
- - Interpolate a 3D array (where the last two dimensions are spatial):
51
- result = interp_2d(target_x, target_y, origin_x, origin_y, data_3d)
52
- - Interpolate a 4D array (where the last two dimensions are spatial):
53
- result = interp_2d(target_x, target_y, origin_x, origin_y, data_4d)
54
- """
55
-
56
- def interp_single(data_slice, target_points, origin_points, method):
57
- return griddata(origin_points, data_slice.ravel(), target_points, method=method).reshape(target_y.shape)
58
-
59
- # 确保目标网格和初始网格都是二维的
60
- if len(target_y.shape) == 1:
61
- target_x, target_y = np.meshgrid(target_x, target_y)
62
- if len(origin_y.shape) == 1:
63
- origin_x, origin_y = np.meshgrid(origin_x, origin_y)
64
-
65
- # 根据经纬度网格判断输入数据的形状是否匹配
66
- if origin_x.shape != data.shape[-2:] or origin_y.shape != data.shape[-2:]:
67
- raise ValueError("Shape of data does not match shape of origin_x or origin_y.")
68
-
69
- # 创建网格和展平数据
70
- target_points = np.column_stack((target_y.ravel(), target_x.ravel()))
71
- origin_points = np.column_stack((origin_y.ravel(), origin_x.ravel()))
72
-
73
- # 根据是否并行选择不同的执行方式
74
- if parallel:
75
- with ThreadPoolExecutor(max_workers=mp.cpu_count() - 2) as executor:
76
- if len(data.shape) == 2:
77
- interpolated_data = list(executor.map(interp_single, [data], [target_points], [origin_points], [method]))
78
- elif len(data.shape) == 3:
79
- interpolated_data = list(executor.map(interp_single, [data[i] for i in range(data.shape[0])], [target_points] * data.shape[0], [origin_points] * data.shape[0], [method] * data.shape[0]))
80
- elif len(data.shape) == 4:
81
- index_combinations = list(itertools.product(range(data.shape[0]), range(data.shape[1])))
82
- interpolated_data = list(executor.map(interp_single, [data[i, j] for i, j in index_combinations], [target_points] * len(index_combinations), [origin_points] * len(index_combinations), [method] * len(index_combinations)))
83
- interpolated_data = np.array(interpolated_data).reshape(data.shape[0], data.shape[1], *target_y.shape)
84
- else:
85
- if len(data.shape) == 2:
86
- interpolated_data = interp_single(data, target_points, origin_points, method)
87
- elif len(data.shape) == 3:
88
- interpolated_data = np.stack([interp_single(data[i], target_points, origin_points, method) for i in range(data.shape[0])])
89
- elif len(data.shape) == 4:
90
- interpolated_data = np.stack([np.stack([interp_single(data[i, j], target_points, origin_points, method) for j in range(data.shape[1])]) for i in range(data.shape[0])])
91
-
92
- return np.array(interpolated_data)
93
-
94
-
95
-
96
-
97
-
98
- # ---------------------------------------------------------------------------------- not used below ----------------------------------------------------------------------------------
99
- # ** 高维插值函数,插值最后两个维度
100
- def interp_2d_20241213(target_x, target_y, origin_x, origin_y, data, method="linear"):
101
- """
102
- 高维插值函数,默认插值最后两个维度,传输数据前请确保数据的维度正确
103
- 参数:
104
- target_y (array-like): 目标经度网格 1D 或 2D
105
- target_x (array-like): 目标纬度网格 1D 或 2D
106
- origin_y (array-like): 初始经度网格 1D 或 2D
107
- origin_x (array-like): 初始纬度网格 1D 或 2D
108
- data (array-like): 数据 (*, lat, lon) 2D, 3D, 4D
109
- method (str, optional): 插值方法,可选 'linear', 'nearest', 'cubic' 等,默认为 'linear'
110
- 返回:
111
- array-like: 插值结果
112
- """
113
-
114
- # 确保目标网格和初始网格都是二维的
115
- if len(target_y.shape) == 1:
116
- target_x, target_y = np.meshgrid(target_x, target_y)
117
- if len(origin_y.shape) == 1:
118
- origin_x, origin_y = np.meshgrid(origin_x, origin_y)
119
-
120
- dims = data.shape
121
- len_dims = len(dims)
122
- # print(dims[-2:])
123
- # 根据经纬度网格判断输入数据的形状是否匹配
124
-
125
- if origin_x.shape != dims[-2:] or origin_y.shape != dims[-2:]:
126
- print(origin_x.shape, dims[-2:])
127
- raise ValueError("Shape of data does not match shape of origin_x or origin_y.")
128
-
129
- # 将目标网格展平成一维数组
130
- target_points = np.column_stack((np.ravel(target_y), np.ravel(target_x)))
131
-
132
- # 将初始网格展平成一维数组
133
- origin_points = np.column_stack((np.ravel(origin_y), np.ravel(origin_x)))
134
-
135
- # 进行插值
136
- if len_dims == 2:
137
- interpolated_data = griddata(origin_points, np.ravel(data), target_points, method=method)
138
- interpolated_data = np.reshape(interpolated_data, target_y.shape)
139
- elif len_dims == 3:
140
- interpolated_data = []
141
- for i in range(dims[0]):
142
- dt = griddata(origin_points, np.ravel(data[i, :, :]), target_points, method=method)
143
- interpolated_data.append(np.reshape(dt, target_y.shape))
144
- print(f"Interpolating {i + 1}/{dims[0]}...")
145
- interpolated_data = np.array(interpolated_data)
146
- elif len_dims == 4:
147
- interpolated_data = []
148
- for i in range(dims[0]):
149
- interpolated_data.append([])
150
- for j in range(dims[1]):
151
- dt = griddata(origin_points, np.ravel(data[i, j, :, :]), target_points, method=method)
152
- interpolated_data[i].append(np.reshape(dt, target_y.shape))
153
- print(f"\rInterpolating {i * dims[1] + j + 1}/{dims[0] * dims[1]}...", end="")
154
- print("\n")
155
- interpolated_data = np.array(interpolated_data)
156
-
157
- return interpolated_data
158
-
159
-
160
- # ** 高维插值函数,插值最后两个维度,使用多线程进行插值
161
- # 在本地电脑上可以提速三倍左右,超算上暂时无法加速
162
- def interp_2d_parallel_20241213(target_x, target_y, origin_x, origin_y, data, method="linear"):
163
- """
164
- param {*} target_x 目标经度网格 1D 或 2D
165
- param {*} target_y 目标纬度网格 1D 或 2D
166
- param {*} origin_x 初始经度网格 1D 或 2D
167
- param {*} origin_y 初始纬度网格 1D 或 2D
168
- param {*} data 数据 (*, lat, lon) 2D, 3D, 4D
169
- param {*} method 插值方法,可选 'linear', 'nearest', 'cubic' 等,默认为 'linear'
170
- return {*} 插值结果
171
- description : 高维插值函数,默认插值最后两个维度,传输数据前请确保数据的维度正确
172
- example : interpolated_data = interp_2d_parallel(target_x, target_y, origin_x, origin_y, data, method='linear')
173
- """
174
-
175
- def interp_single2d(target_y, target_x, origin_y, origin_x, data, method="linear"):
176
- target_points = np.column_stack((np.ravel(target_y), np.ravel(target_x)))
177
- origin_points = np.column_stack((np.ravel(origin_y), np.ravel(origin_x)))
178
-
179
- dt = griddata(origin_points, np.ravel(data[:, :]), target_points, method=method)
180
- return np.reshape(dt, target_y.shape)
181
-
182
- def interp_single3d(i, target_y, target_x, origin_y, origin_x, data, method="linear"):
183
- target_points = np.column_stack((np.ravel(target_y), np.ravel(target_x)))
184
- origin_points = np.column_stack((np.ravel(origin_y), np.ravel(origin_x)))
185
-
186
- dt = griddata(origin_points, np.ravel(data[i, :, :]), target_points, method=method)
187
- return np.reshape(dt, target_y.shape)
188
-
189
- def interp_single4d(i, j, target_y, target_x, origin_y, origin_x, data, method="linear"):
190
- target_points = np.column_stack((np.ravel(target_y), np.ravel(target_x)))
191
- origin_points = np.column_stack((np.ravel(origin_y), np.ravel(origin_x)))
192
-
193
- dt = griddata(origin_points, np.ravel(data[i, j, :, :]), target_points, method=method)
194
- return np.reshape(dt, target_y.shape)
195
-
196
- if len(target_y.shape) == 1:
197
- target_x, target_y = np.meshgrid(target_x, target_y)
198
- if len(origin_y.shape) == 1:
199
- origin_x, origin_y = np.meshgrid(origin_x, origin_y)
200
-
201
- dims = data.shape
202
- len_dims = len(dims)
203
-
204
- if origin_x.shape != dims[-2:] or origin_y.shape != dims[-2:]:
205
- raise ValueError("数据形状与 origin_x 或 origin_y 的形状不匹配.")
206
-
207
- interpolated_data = []
208
-
209
- # 使用多线程进行插值
210
- with ThreadPoolExecutor(max_workers=mp.cpu_count() - 2) as executor:
211
- print(f"Using {mp.cpu_count() - 2} threads...")
212
- if len_dims == 2:
213
- interpolated_data = list(executor.map(interp_single2d, [target_y], [target_x], [origin_y], [origin_x], [data], [method]))
214
- elif len_dims == 3:
215
- interpolated_data = list(executor.map(interp_single3d, [i for i in range(dims[0])], [target_y] * dims[0], [target_x] * dims[0], [origin_y] * dims[0], [origin_x] * dims[0], [data] * dims[0], [method] * dims[0]))
216
- elif len_dims == 4:
217
- interpolated_data = list(
218
- executor.map(
219
- interp_single4d,
220
- [i for i in range(dims[0]) for j in range(dims[1])],
221
- [j for i in range(dims[0]) for j in range(dims[1])],
222
- [target_y] * dims[0] * dims[1],
223
- [target_x] * dims[0] * dims[1],
224
- [origin_y] * dims[0] * dims[1],
225
- [origin_x] * dims[0] * dims[1],
226
- [data] * dims[0] * dims[1],
227
- [method] * dims[0] * dims[1],
228
- )
229
- )
230
- interpolated_data = np.array(interpolated_data).reshape(dims[0], dims[1], target_y.shape[0], target_x.shape[1])
231
-
232
- interpolated_data = np.array(interpolated_data)
233
-
234
- return interpolated_data
235
-
236
-
237
- def _test_sum(a, b):
238
- return a + b
239
-
240
-
241
- if __name__ == "__main__":
242
-
243
- pass
244
- """ import time
245
-
246
- import matplotlib.pyplot as plt
247
-
248
- # 测试数据
249
- origin_x = np.linspace(0, 10, 11)
250
- origin_y = np.linspace(0, 10, 11)
251
- target_x = np.linspace(0, 10, 101)
252
- target_y = np.linspace(0, 10, 101)
253
- data = np.random.rand(11, 11)
254
-
255
- # 高维插值
256
- origin_x = np.linspace(0, 10, 11)
257
- origin_y = np.linspace(0, 10, 11)
258
- target_x = np.linspace(0, 10, 101)
259
- target_y = np.linspace(0, 10, 101)
260
- data = np.random.rand(10, 10, 11, 11)
261
-
262
- start = time.time()
263
- interpolated_data = interp_2d(target_x, target_y, origin_x, origin_y, data, parallel=False)
264
- print(f"Interpolation time: {time.time()-start:.2f}s")
265
-
266
- print(interpolated_data.shape)
267
-
268
- # 高维插值多线程
269
- start = time.time()
270
- interpolated_data = interp_2d(target_x, target_y, origin_x, origin_y, data)
271
- print(f"Interpolation time: {time.time()-start:.2f}s")
272
-
273
- print(interpolated_data.shape)
274
- print(interpolated_data[0, 0, :, :].shape)
275
- plt.figure()
276
- plt.contourf(target_x, target_y, interpolated_data[0, 0, :, :])
277
- plt.colorbar()
278
- plt.show() """
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes