oafuncs 0.0.91.tar.gz → 0.0.93.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {oafuncs-0.0.91/oafuncs.egg-info → oafuncs-0.0.93}/PKG-INFO +12 -2
- oafuncs-0.0.93/oafuncs/oa_data.py +153 -0
- {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_down/hycom_3hourly.py +246 -148
- {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_down/idm.py +1 -1
- {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_down/literature.py +11 -10
- {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_file.py +6 -2
- {oafuncs-0.0.91 → oafuncs-0.0.93/oafuncs.egg-info}/PKG-INFO +12 -2
- {oafuncs-0.0.91 → oafuncs-0.0.93}/setup.py +1 -1
- oafuncs-0.0.91/oafuncs/oa_data.py +0 -278
- {oafuncs-0.0.91 → oafuncs-0.0.93}/LICENSE.txt +0 -0
- {oafuncs-0.0.91 → oafuncs-0.0.93}/MANIFEST.in +0 -0
- {oafuncs-0.0.91 → oafuncs-0.0.93}/README.md +0 -0
- {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/__init__.py +0 -0
- {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/data_store/OAFuncs.png +0 -0
- {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_cmap.py +0 -0
- {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_down/User_Agent-list.txt +0 -0
- {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_down/__init__.py +0 -0
- {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_down/test_ua.py +0 -0
- {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_down/user_agent.py +0 -0
- {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_draw.py +0 -0
- {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_help.py +0 -0
- {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_nc.py +0 -0
- {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_python.py +0 -0
- {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_sign/__init__.py +0 -0
- {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_sign/meteorological.py +0 -0
- {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_sign/ocean.py +0 -0
- {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_sign/scientific.py +0 -0
- {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_tool/__init__.py +0 -0
- {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_tool/email.py +0 -0
- {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs/oa_tool/parallel.py +0 -0
- {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs.egg-info/SOURCES.txt +0 -0
- {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs.egg-info/dependency_links.txt +0 -0
- {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs.egg-info/requires.txt +0 -0
- {oafuncs-0.0.91 → oafuncs-0.0.93}/oafuncs.egg-info/top_level.txt +0 -0
- {oafuncs-0.0.91 → oafuncs-0.0.93}/setup.cfg +0 -0

```diff
--- oafuncs-0.0.91/oafuncs.egg-info/PKG-INFO
+++ oafuncs-0.0.93/PKG-INFO
@@ -1,6 +1,6 @@
-Metadata-Version: 2.1
+Metadata-Version: 2.2
 Name: oafuncs
-Version: 0.0.91
+Version: 0.0.93
 Summary: Oceanic and Atmospheric Functions
 Home-page: https://github.com/Industry-Pays/OAFuncs
 Author: Kun Liu
@@ -30,6 +30,16 @@ Requires-Dist: matplotlib
 Requires-Dist: Cartopy
 Requires-Dist: netCDF4
 Requires-Dist: xlrd
+Dynamic: author
+Dynamic: author-email
+Dynamic: classifier
+Dynamic: description
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: license
+Dynamic: requires-dist
+Dynamic: requires-python
+Dynamic: summary
 
 
 # oafuncs
```
```diff
--- /dev/null
+++ oafuncs-0.0.93/oafuncs/oa_data.py
@@ -0,0 +1,153 @@
+#!/usr/bin/env python
+# coding=utf-8
+"""
+Author: Liu Kun && 16031215@qq.com
+Date: 2024-09-17 17:12:47
+LastEditors: Liu Kun && 16031215@qq.com
+LastEditTime: 2024-12-13 19:11:08
+FilePath: \\Python\\My_Funcs\\OAFuncs\\oafuncs\\oa_data.py
+Description:
+EditPlatform: vscode
+ComputerInfo: XPS 15 9510
+SystemInfo: Windows 11
+Python Version: 3.11
+"""
+
+import itertools
+import multiprocessing as mp
+from concurrent.futures import ThreadPoolExecutor
+
+import numpy as np
+from scipy.interpolate import griddata
+
+
+__all__ = ["interp_2d", "ensure_list"]
+
+
+def ensure_list(input_data):
+    """
+    Ensures that the input is converted into a list.
+
+    If the input is already a list, it returns it directly.
+    If the input is a string, it wraps it in a list and returns.
+    For other types of input, it converts them to a string and then wraps in a list.
+
+    :param input_data: The input which can be a list, a string, or any other type.
+    :return: A list containing the input or the string representation of the input.
+    """
+    if isinstance(input_data, list):
+        return input_data
+    elif isinstance(input_data, str):
+        return [input_data]
+    else:
+        # For non-list and non-string inputs, convert to string and wrap in a list
+        return [str(input_data)]
+
+
+def interp_2d(target_x, target_y, origin_x, origin_y, data, method="linear", parallel=True):
+    """
+    Perform 2D interpolation on the last two dimensions of a multi-dimensional array.
+
+    Parameters:
+    - target_x (array-like): 1D array of target grid's x-coordinates.
+    - target_y (array-like): 1D array of target grid's y-coordinates.
+    - origin_x (array-like): 1D array of original grid's x-coordinates.
+    - origin_y (array-like): 1D array of original grid's y-coordinates.
+    - data (numpy.ndarray): Multi-dimensional array where the last two dimensions correspond to the original grid.
+    - method (str, optional): Interpolation method, default is 'linear'. Other options include 'nearest', 'cubic', etc.
+    - parallel (bool, optional): Flag to enable parallel processing. Default is True.
+
+    Returns:
+    - interpolated_data (numpy.ndarray): Interpolated data with the same leading dimensions as the input data, but with the last two dimensions corresponding to the target grid.
+
+    Raises:
+    - ValueError: If the shape of the data does not match the shape of the origin_x or origin_y grids.
+
+    Usage:
+    - Interpolate a 2D array:
+      result = interp_2d(target_x, target_y, origin_x, origin_y, data_2d)
+    - Interpolate a 3D array (where the last two dimensions are spatial):
+      result = interp_2d(target_x, target_y, origin_x, origin_y, data_3d)
+    - Interpolate a 4D array (where the last two dimensions are spatial):
+      result = interp_2d(target_x, target_y, origin_x, origin_y, data_4d)
+    """
+
+    def interp_single(data_slice, target_points, origin_points, method):
+        return griddata(origin_points, data_slice.ravel(), target_points, method=method).reshape(target_y.shape)
+
+    # Make sure both the target grid and the source grid are 2D
+    if len(target_y.shape) == 1:
+        target_x, target_y = np.meshgrid(target_x, target_y)
+    if len(origin_y.shape) == 1:
+        origin_x, origin_y = np.meshgrid(origin_x, origin_y)
+
+    # Check that the input data shape matches the lon/lat grids
+    if origin_x.shape != data.shape[-2:] or origin_y.shape != data.shape[-2:]:
+        raise ValueError("Shape of data does not match shape of origin_x or origin_y.")
+
+    # Build the point sets and flatten the data
+    target_points = np.column_stack((target_y.ravel(), target_x.ravel()))
+    origin_points = np.column_stack((origin_y.ravel(), origin_x.ravel()))
+
+    # Choose the execution path depending on the parallel flag
+    if parallel:
+        with ThreadPoolExecutor(max_workers=mp.cpu_count() - 2) as executor:
+            if len(data.shape) == 2:
+                interpolated_data = list(executor.map(interp_single, [data], [target_points], [origin_points], [method]))
+            elif len(data.shape) == 3:
+                interpolated_data = list(executor.map(interp_single, [data[i] for i in range(data.shape[0])], [target_points] * data.shape[0], [origin_points] * data.shape[0], [method] * data.shape[0]))
+            elif len(data.shape) == 4:
+                index_combinations = list(itertools.product(range(data.shape[0]), range(data.shape[1])))
+                interpolated_data = list(executor.map(interp_single, [data[i, j] for i, j in index_combinations], [target_points] * len(index_combinations), [origin_points] * len(index_combinations), [method] * len(index_combinations)))
+                interpolated_data = np.array(interpolated_data).reshape(data.shape[0], data.shape[1], *target_y.shape)
+    else:
+        if len(data.shape) == 2:
+            interpolated_data = interp_single(data, target_points, origin_points, method)
+        elif len(data.shape) == 3:
+            interpolated_data = np.stack([interp_single(data[i], target_points, origin_points, method) for i in range(data.shape[0])])
+        elif len(data.shape) == 4:
+            interpolated_data = np.stack([np.stack([interp_single(data[i, j], target_points, origin_points, method) for j in range(data.shape[1])]) for i in range(data.shape[0])])
+
+    return np.array(interpolated_data)
+
+
+if __name__ == "__main__":
+
+    pass
+    """ import time
+
+    import matplotlib.pyplot as plt
+
+    # Test data
+    origin_x = np.linspace(0, 10, 11)
+    origin_y = np.linspace(0, 10, 11)
+    target_x = np.linspace(0, 10, 101)
+    target_y = np.linspace(0, 10, 101)
+    data = np.random.rand(11, 11)
+
+    # High-dimensional interpolation
+    origin_x = np.linspace(0, 10, 11)
+    origin_y = np.linspace(0, 10, 11)
+    target_x = np.linspace(0, 10, 101)
+    target_y = np.linspace(0, 10, 101)
+    data = np.random.rand(10, 10, 11, 11)
+
+    start = time.time()
+    interpolated_data = interp_2d(target_x, target_y, origin_x, origin_y, data, parallel=False)
+    print(f"Interpolation time: {time.time()-start:.2f}s")
+
+    print(interpolated_data.shape)
+
+    # High-dimensional interpolation, multi-threaded
+    start = time.time()
+    interpolated_data = interp_2d(target_x, target_y, origin_x, origin_y, data)
+    print(f"Interpolation time: {time.time()-start:.2f}s")
+
+    print(interpolated_data.shape)
+    print(interpolated_data[0, 0, :, :].shape)
+    plt.figure()
+    plt.contourf(target_x, target_y, interpolated_data[0, 0, :, :])
+    plt.colorbar()
+    plt.show() """
```
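For orientation, a minimal usage sketch of the two exported helpers. It mirrors the commented-out test block at the bottom of the new module; the array shapes are illustrative, not taken from the package.

```python
# Usage sketch for the new oafuncs.oa_data module (shapes are illustrative).
import numpy as np
from oafuncs.oa_data import ensure_list, interp_2d

# ensure_list normalizes its input to a list.
print(ensure_list(["a", "b"]))   # ['a', 'b']  (lists pass through)
print(ensure_list("water_u"))    # ['water_u'] (strings are wrapped)
print(ensure_list(42))           # ['42']      (other types are stringified, then wrapped)

# interp_2d regrids the last two axes; leading axes (e.g. time, depth) pass through.
origin_x = np.linspace(0, 10, 11)
origin_y = np.linspace(0, 10, 11)
target_x = np.linspace(0, 10, 101)
target_y = np.linspace(0, 10, 101)
data = np.random.rand(10, 10, 11, 11)  # (time, depth, lat, lon)

result = interp_2d(target_x, target_y, origin_x, origin_y, data, parallel=False)
print(result.shape)  # (10, 10, 101, 101)
```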
```diff
--- oafuncs-0.0.91/oafuncs/oa_down/hycom_3hourly.py
+++ oafuncs-0.0.93/oafuncs/oa_down/hycom_3hourly.py
@@ -26,13 +26,17 @@ from threading import Lock
 import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
+import xarray as xr
 import requests
 from rich import print
 from rich.progress import Progress
+import netCDF4 as nc
 
 from oafuncs.oa_down.user_agent import get_ua
 from oafuncs.oa_file import file_size, mean_size
 from oafuncs.oa_nc import check as check_nc
+from oafuncs.oa_nc import modify as modify_nc
+from oafuncs.oa_down.idm import downloader as idm_downloader
 
 warnings.filterwarnings("ignore", category=RuntimeWarning, message="Engine '.*' loading failed:.*")
 
```
```diff
@@ -571,20 +575,26 @@ def _check_existing_file(file_full_path, avg_size):
         if abs(delta_size_ratio) > 0.025:
             if check_nc(file_full_path):
                 # print(f"File size is abnormal but can be opened normally, file size: {fsize:.2f} KB")
-                return True
+                if not _check_ftime(file_full_path, if_print=True):
+                    return False
+                else:
+                    return True
             else:
                 print(f"File size is abnormal and cannot be opened, {file_full_path}: {fsize:.2f} KB")
                 return False
         else:
-            return True
+            if not _check_ftime(file_full_path, if_print=True):
+                return False
+            else:
+                return True
     else:
         return False
 
 
 def _get_mean_size30(store_path, same_file):
     if same_file not in fsize_dict.keys():
-        [... old lines 586-587: content not rendered in this diff view ...]
+        # print(f'Same file name: {same_file}')
+        fsize_dict[same_file] = {"size": 0, "count": 0}
 
     if fsize_dict[same_file]["count"] < 30 or fsize_dict[same_file]["size"] == 0:
         # Update the minimum over the first 30 files; afterwards it is treated as representative and no longer updated, to save time
@@ -599,7 +609,7 @@ def _get_mean_size30(store_path, same_file):
 
 def _get_mean_size_move(same_file, current_file):
     # Acquire the lock
-    with fsize_dict_lock:
+    with fsize_dict_lock:  # Global lock: ensures only one thread accesses the dict at a time
         # Initialize the dict entry if the file is not in the dict
         if same_file not in fsize_dict.keys():
             fsize_dict[same_file] = {"size_list": [], "mean_size": 1.0}
```
```diff
@@ -633,6 +643,61 @@ def _get_mean_size_move(same_file, current_file):
     return fsize_dict[same_file]["mean_size"]
 
 
+def _check_ftime(nc_file, tname="time", if_print=False):
+    if not os.path.exists(nc_file):
+        return False
+    nc_file = str(nc_file)
+    try:
+        ds = xr.open_dataset(nc_file)
+        real_time = ds[tname].values[0]
+        ds.close()
+        real_time = str(real_time)[:13]
+        real_time = real_time.replace("-", "").replace("T", "")
+        # -----------------------------------------------------
+        f_time = re.findall(r"\d{10}", nc_file)[0]
+        if real_time == f_time:
+            return True
+        else:
+            if if_print:
+                print(f"[bold #daff5c]File time error, file/real time: [bold blue]{f_time}/{real_time}")
+            return False
+    except Exception as e:
+        if if_print:
+            print(f"[bold #daff5c]File time check failed, {nc_file}: {e}")
+        return False
+
+
+def _correct_time(nc_file):
+    # Open the NetCDF file
+    dataset = nc.Dataset(nc_file)
+
+    # Read the time units
+    time_units = dataset.variables["time"].units
+
+    # Close the file
+    dataset.close()
+
+    # Parse the time-units string to get the time origin
+    origin_str = time_units.split("since")[1].strip()
+    origin_datetime = datetime.datetime.strptime(origin_str, "%Y-%m-%d %H:%M:%S")
+
+    # Extract the date string from the file name
+    given_date_str = re.findall(r"\d{10}", str(nc_file))[0]
+
+    # Convert the extracted date string to a datetime object
+    given_datetime = datetime.datetime.strptime(given_date_str, "%Y%m%d%H")
+
+    # Compute the difference between the given date and the time origin (in the file's units)
+    time_difference = (given_datetime - origin_datetime).total_seconds()
+    if "hours" in time_units:
+        time_difference /= 3600
+    elif "days" in time_units:
+        time_difference /= 3600 * 24
+
+    # Modify the time variable in the NetCDF file
+    modify_nc(nc_file, "time", None, time_difference)
+
+
 def _download_file(target_url, store_path, file_name, check=False):
     # Check if the file exists
     fname = Path(store_path) / file_name
```
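The division of labor between these two new helpers: `_check_ftime` compares the `YYYYMMDDHH` token in the file name against the first timestamp stored in the file, and `_correct_time` rewrites the stored time axis so it matches the name. A worked illustration of the comparison (file name and timestamp are made up):

```python
# Illustration of the _check_ftime comparison (values are made up).
import re

nc_file = "HYCOM_water_u_2024110100.nc"
first_timestamp = "2024-11-01T03:00:00.000000000"  # ds["time"].values[0] as a string

real_time = str(first_timestamp)[:13].replace("-", "").replace("T", "")  # -> "2024110103"
f_time = re.findall(r"\d{10}", nc_file)[0]                               # -> "2024110100"
print(real_time == f_time)  # False: the file stores 03Z data under a 00Z name
```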
```diff
@@ -640,108 +705,122 @@ def _download_file(target_url, store_path, file_name, check=False):
     file_name_split = file_name_split[:-1]
     # same_file = f"{file_name_split[0]}_{file_name_split[1]}*nc"
     same_file = "_".join(file_name_split) + "*nc"
-
+
     if check:
-        if same_file not in fsize_dict.keys():
-            check_nc(fname,if_delete=True)
+        if same_file not in fsize_dict.keys():  # Check the first file on its own, since there is no size yet to compare against
+            check_nc(fname, if_delete=True)
 
         # set_min_size = _get_mean_size30(store_path, same_file)  # Original approach: averaged over only 30 samples; cannot adapt if sizes change
         get_mean_size = _get_mean_size_move(same_file, fname)
-
+
         if _check_existing_file(fname, get_mean_size):
             count_dict["skip"] += 1
             return
     _clear_existing_file(fname)
 
-    [... old lines 656-744: the previous inline request/download body; content not rendered in this diff view ...]
+    if not use_idm:
+        # -----------------------------------------------
+        print(f"[bold #f0f6d0]Requesting {file_name} ...")
+        # Create a session
+        s = requests.Session()
+        download_success = False
+        request_times = 0
+
+        def calculate_wait_time(time_str, target_url):
+            # Regex matching times in YYYYMMDDHH format
+            time_pattern = r"\d{10}"
+
+            # Two example strings
+            # str1 = 'HYCOM_water_u_2018010100-2018010112.nc'
+            # str2 = 'HYCOM_water_u_2018010100.nc'
+
+            # Find the times with the regex
+            times_in_str = re.findall(time_pattern, time_str)
+
+            # Count how many times appear in the string
+            num_times_str = len(times_in_str)
+
+            if num_times_str > 1:
+                delta_t = datetime.datetime.strptime(times_in_str[1], "%Y%m%d%H") - datetime.datetime.strptime(times_in_str[0], "%Y%m%d%H")
+                delta_t = delta_t.total_seconds() / 3600
+                delta_t = delta_t / 3 + 1
+            else:
+                delta_t = 1
+            # Wait at most 5 minutes per variable: not too short, or the request may fail; not too long, or time is wasted
+            num_var = int(target_url.count("var="))
+            if num_var <= 0:
+                num_var = 1
+            return int(delta_t * 5 * 60 * num_var)
+
+        max_timeout = calculate_wait_time(file_name, target_url)
+        print(f"[bold #912dbc]Max timeout: {max_timeout} seconds")
+
+        # print(f'Download_start_time: {datetime.datetime.now()}')
+        download_time_s = datetime.datetime.now()
+        order_list = ["1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th", "9th", "10th"]
+        while not download_success:
+            if request_times >= 10:
+                # print(f'Download failed; retried {request_times} times. Skip it for now and try again later')
+                print(f"[bold #ffe5c0]Download failed after {request_times} times\nYou can skip it and try again later")
+                count_dict["fail"] += 1
+                break
+            if request_times > 0:
+                # print(f'\rRetrying, attempt {request_times}', end="")
+                print(f"[bold #ffe5c0]Retrying the {order_list[request_times - 1]} time...")
+            # Try to download the file
+            try:
+                headers = {"User-Agent": get_ua()}
+                """ response = s.get(target_url, headers=headers, timeout=random.randint(5, max_timeout))
+                response.raise_for_status()  # Raise HTTPError if the response is not 200
+
+                # Save the file
+                with open(filename, 'wb') as f:
+                    f.write(response.content) """
+
+                response = s.get(target_url, headers=headers, stream=True, timeout=random.randint(5, max_timeout))  # Enable streaming
+                response.raise_for_status()  # Raise HTTPError if the response is not 200
+                # Save the file
+                with open(fname, "wb") as f:
+                    print(f"[bold #96cbd7]Downloading {file_name} ...")
+                    for chunk in response.iter_content(chunk_size=1024):
+                        if chunk:
+                            f.write(chunk)
+
+                f.close()
+
+                if not _check_ftime(fname, if_print=True):
+                    if match_time:
+                        _correct_time(fname)
+                    else:
+                        _clear_existing_file(fname)
+                        # print(f"[bold #ffe5c0]File time error, {fname}")
+                        count_dict["no_data"] += 1
+                        break
+
+                # print(f'\rFile {fname} downloaded successfully', end="")
+                if os.path.exists(fname):
+                    download_success = True
+                    download_time_e = datetime.datetime.now()
+                    download_delta = download_time_e - download_time_s
+                    print(f"[#3dfc40]File [bold #dfff73]{fname} [#3dfc40]has been downloaded successfully, Time: [#39cbdd]{download_delta}")
+                    count_dict["success"] += 1
+                    # print(f'Download_end_time: {datetime.datetime.now()}')
+
+            except requests.exceptions.HTTPError as errh:
+                print(f"Http Error: {errh}")
+            except requests.exceptions.ConnectionError as errc:
+                print(f"Error Connecting: {errc}")
+            except requests.exceptions.Timeout as errt:
+                print(f"Timeout Error: {errt}")
+            except requests.exceptions.RequestException as err:
+                print(f"OOps: Something Else: {err}")
+
+            time.sleep(3)
+            request_times += 1
+    else:
+        idm_downloader(target_url, store_path, file_name, given_idm_engine)
+        idm_download_list.append(fname)
+        print(f"[bold #3dfc40]File [bold #dfff73]{fname} [#3dfc40]has been submit to IDM for downloading")
 
 
 def _check_hour_is_valid(ymdh_str):
```
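The new `calculate_wait_time` heuristic scales the request timeout with the time span encoded in the file name and the number of `var=` parameters in the URL. Restating it standalone makes the numbers easy to check; the inputs are the example file names from the code's own comments:

```python
import datetime
import re

def calculate_wait_time(time_str, target_url):
    # Same heuristic as the new code: base of 5 minutes per variable,
    # scaled up when the file name spans more than one time step.
    times_in_str = re.findall(r"\d{10}", time_str)
    if len(times_in_str) > 1:
        delta_t = datetime.datetime.strptime(times_in_str[1], "%Y%m%d%H") - datetime.datetime.strptime(times_in_str[0], "%Y%m%d%H")
        delta_t = delta_t.total_seconds() / 3600 / 3 + 1
    else:
        delta_t = 1
    num_var = max(int(target_url.count("var=")), 1)
    return int(delta_t * 5 * 60 * num_var)

print(calculate_wait_time("HYCOM_water_u_2018010100-2018010112.nc", "?var=water_u"))  # 1500 s (12 h span -> factor 5)
print(calculate_wait_time("HYCOM_water_u_2018010100.nc", "?var=water_u"))             # 300 s
```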
```diff
@@ -849,11 +928,11 @@ def _prepare_url_to_download(var, lon_min=0, lon_max=359.92, lat_min=-80, lat_ma
             var = current_group[0]
             submit_url = _get_submit_url_var(var, depth, level_num, lon_min, lon_max, lat_min, lat_max, dataset_name, version_name, download_time, download_time_end)
             file_name = f"HYCOM_{variable_info[var]['var_name']}_{download_time}.nc"
-            old_str = f
-            new_str = f
+            old_str = f"var={variable_info[var]['var_name']}"
+            new_str = f"var={variable_info[var]['var_name']}"
             if len(current_group) > 1:
                 for v in current_group[1:]:
-                    new_str = f
+                    new_str = f"{new_str}&var={variable_info[v]['var_name']}"
             submit_url = submit_url.replace(old_str, new_str)
             # file_name = f'HYCOM_{'-'.join([variable_info[v]["var_name"] for v in current_group])}_{download_time}.nc'
             file_name = f"HYCOM_{key}_{download_time}.nc"
```

(The three removed lines are truncated after `f` in this diff view; their full original content is not rendered.)
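The `old_str`/`new_str` replacement simply appends one `&var=...` per additional variable in the group to the submit URL. A tiny sketch with an assumed placeholder URL and variable names:

```python
# Sketch of the grouped-variable URL rewrite (URL and names are placeholders).
submit_url = "https://example.org/ncss?var=water_u&north=45&south=15"
old_str = "var=water_u"
new_str = "var=water_u&var=water_v"  # one extra &var=... per grouped variable

print(submit_url.replace(old_str, new_str))
# https://example.org/ncss?var=water_u&var=water_v&north=45&south=15
```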
```diff
@@ -949,7 +1028,7 @@ def _download_hourly_func(var, time_s, time_e, lon_min=0, lon_max=359.92, lat_mi
             # Serial mode
             for i, time_str in enumerate(time_list):
                 _prepare_url_to_download(var, lon_min, lon_max, lat_min, lat_max, time_str, None, depth, level, store_path, dataset_name, version_name, check)
-                progress.update(task, advance=1, description=f"[cyan]Downloading... {i+1}/{len(time_list)}")
+                progress.update(task, advance=1, description=f"[cyan]Downloading... {i + 1}/{len(time_list)}")
         else:
             # Parallel mode
             with ThreadPoolExecutor(max_workers=num_workers) as executor:
@@ -967,7 +1046,7 @@ def _download_hourly_func(var, time_s, time_e, lon_min=0, lon_max=359.92, lat_mi
                 time_str_end_index = int(min(len(time_list) - 1, int(i * ftimes + ftimes - 1)))
                 time_str_end = time_list[time_str_end_index]
                 _prepare_url_to_download(var, lon_min, lon_max, lat_min, lat_max, time_str, time_str_end, depth, level, store_path, dataset_name, version_name, check)
-                progress.update(task, advance=1, description=f"[cyan]Downloading... {i+1}/{total_num}")
+                progress.update(task, advance=1, description=f"[cyan]Downloading... {i + 1}/{total_num}")
         else:
             # Parallel mode
             with ThreadPoolExecutor(max_workers=num_workers) as executor:
```
```diff
@@ -977,10 +1056,10 @@ def _download_hourly_func(var, time_s, time_e, lon_min=0, lon_max=359.92, lat_mi
                 for feature in as_completed(futures):
                     _done_callback(feature, progress, task, len(time_list), counter_lock)
     else:
-        print("Please ensure the time_s is no more than time_e")
+        print("[bold red]Please ensure the time_s is no more than time_e")
 
 
-def download(var, time_s, time_e=None, lon_min=0, lon_max=359.92, lat_min=-80, lat_max=90, depth=None, level=None, store_path=None, dataset_name=None, version_name=None, num_workers=None, check=False, ftimes=1):
+def download(var, time_s, time_e=None, lon_min=0, lon_max=359.92, lat_min=-80, lat_max=90, depth=None, level=None, store_path=None, dataset_name=None, version_name=None, num_workers=None, check=False, ftimes=1, idm_engine=None, fill_time=False):
     """
     Description:
         Download the data of single time or a series of time
@@ -1001,6 +1080,8 @@ def download(var, time_s, time_e=None, lon_min=0, lon_max=359.92, lat_min=-80, l
         num_workers: int, the number of workers, default is None, if not set, the number of workers will be 1; suggest not to set the number of workers too large
         check: bool, whether to check the existing file, default is False, if set to True, the existing file will be checked and not downloaded again; else, the existing file will be covered
         ftimes: int, the number of time in one file, default is 1, if set to 1, the data of single time will be downloaded; the maximum is 8, if set to 8, the data of 8 times will be downloaded in one file
+        idm_engine: str, the IDM engine, default is None, if set, the IDM will be used to download the data; example: "D:\\Programs\\Internet Download Manager\\IDMan.exe"
+        fill_time: bool, whether to match the time, default is False, if set to True, the time in the file name will be corrected according to the time in the file; else, the data will be skip if the time is not correct. Because the real time of some data that has been downloaded does not match the time in the file name, eg. the required time is 2024110100, but the time in the file name is 2024110103, so the data will be skip if the fill_time is False. Note: it is not the right time data, so it is not recommended to set fill_time to True
 
     Returns:
         None
```
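A hypothetical call that exercises both new keyword arguments; paths and times are examples, with the IDM path mirroring the one in the docstring:

```python
# Hypothetical call of download() with the new arguments (paths/times are examples).
from oafuncs.oa_down.hycom_3hourly import download

download(
    var="water_u",
    time_s="2024110100",
    time_e="2024110121",
    lon_min=105, lon_max=130, lat_min=15, lat_max=45,
    store_path=r"F:\Data\HYCOM\3hourly",
    check=True,
    idm_engine=r"D:\Programs\Internet Download Manager\IDMan.exe",  # hand downloads to IDM
    fill_time=False,  # mismatched files are dropped rather than having their time axis rewritten
)
```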
```diff
@@ -1048,7 +1129,7 @@ def download(var, time_s, time_e=None, lon_min=0, lon_max=359.92, lat_min=-80, l
     os.makedirs(str(store_path), exist_ok=True)
 
     if num_workers is not None:
-        num_workers = max(min(num_workers, 10), 1)
+        num_workers = max(min(num_workers, 10), 1)  # No hard cap for now; more threads can be opened when re-checking
     # num_workers = int(max(num_workers, 1))
     time_s = str(time_s)
     if len(time_s) == 8:
```
```diff
@@ -1068,12 +1149,52 @@ def download(var, time_s, time_e=None, lon_min=0, lon_max=359.92, lat_min=-80, l
 
     global fsize_dict
     fsize_dict = {}
-
+
     global fsize_dict_lock
     fsize_dict_lock = Lock()
 
+    global use_idm, given_idm_engine, idm_download_list
+    if idm_engine is not None:
+        use_idm = True
+        num_workers = 1
+        given_idm_engine = idm_engine
+        idm_download_list = []
+    else:
+        use_idm = False
+
+    global match_time
+    if fill_time:
+        match_time = True
+    else:
+        match_time = False
+
     _download_hourly_func(var, time_s, time_e, lon_min, lon_max, lat_min, lat_max, depth, level, store_path, dataset_name, version_name, num_workers, check, ftimes)
 
+    if idm_download_list:
+        for f in idm_download_list:
+            wait_success = 0
+            success = False
+            while not success:
+                if check_nc(f):
+                    if match_time:
+                        _correct_time(f)
+                        count_dict["success"] += 1
+                    else:
+                        if not _check_ftime(f):
+                            _clear_existing_file(f)
+                            count_dict["no_data"] += 1
+                            count_dict["no_data_list"].append(str(f).split("_")[-1].split(".")[0])
+                        else:
+                            count_dict["success"] += 1
+                    success = True
+                else:
+                    wait_success += 1
+                    time.sleep(3)
+                    if wait_success >= 20:
+                        success = True
+                        # print(f'{f} download failed')
+                        count_dict["fail"] += 1
+
     count_dict["total"] = count_dict["success"] + count_dict["fail"] + count_dict["skip"] + count_dict["no_data"]
 
     print("[bold #ecdbfe]-" * 160)
```
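After `_download_hourly_func` returns, every file that was handed to IDM is polled until it opens as a valid NetCDF file; the per-file budget implied by the constants in the loop:

```python
# Wait budget per IDM-submitted file, as implied by the new loop.
POLL_INTERVAL_S = 3   # time.sleep(3) between checks
MAX_POLLS = 20        # wait_success >= 20 marks the file as failed
print(f"Up to {POLL_INTERVAL_S * MAX_POLLS} s per file")  # up to 60 s
```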
```diff
@@ -1140,10 +1261,6 @@ def how_to_use():
 
 
 if __name__ == "__main__":
-    time_s, time_e = "2024101012", "2024101018"
-    merge_name = f"{time_s}_{time_e}"  # Name of the merged file
-    root_path = r"G:\Data\HYCOM\3hourly_test"
-    location_dict = {"west": 105, "east": 130, "south": 15, "north": 45}
     download_dict = {
         "water_u": {"simple_name": "u", "download": 1},
         "water_v": {"simple_name": "v", "download": 1},
```
```diff
@@ -1158,50 +1275,31 @@ if __name__ == "__main__":
 
     var_list = [var_name for var_name in download_dict.keys() if download_dict[var_name]["download"]]
 
-
-    # if you wanna download all depth or level, set both False
-    depth = None  # or 0-5000 meters
-    level = None  # or 1-40 levels
-    num_workers = 3
-
-    check = True
-    ftimes = 1
-
-    download_switch, single_var = True, False
-    combine_switch = False
-    copy_switch, copy_dir = False, r"G:\Data\HYCOM\3hourly"
+    single_var = False
 
     # draw_time_range(pic_save_folder=r'I:\Delete')
 
-    [... old lines 1176-1181: content not rendered in this diff view ...]
+    options = {
+        "var": var_list,
+        "time_s": "2018010100",
+        "time_e": "2020123121",
+        "store_path": r"F:\Data\HYCOM\3hourly",
+        "lon_min": 105,
+        "lon_max": 130,
+        "lat_min": 15,
+        "lat_max": 45,
+        "num_workers": 3,
+        "check": True,
+        "depth": None,  # or 0-5000 meters
+        "level": None,  # or 1-40 levels
+        "ftimes": 1,
+        "idm_engine": r"D:\Programs\Internet Download Manager\IDMan.exe",  # Not recommended when only back-filling missing files
+        "fill_time": False
+    }
 
-    time_list = get_time_list(time_s, time_e, 3, 'hour')
+    if single_var:
         for var_name in var_list:
-            [... old lines 1186-1189: content not rendered in this diff view ...]
-            merge_path_name = Path(root_path)/f'HYCOM_{var_name}_{merge_name}.nc'
-        else:
-            # If variables are mixed, the file list depends on the case
-            fname = ''
-            if var_name in ['water_u', 'water_v', 'water_u_bottom', 'water_v_bottom']:
-                fname = 'uv3z'
-            elif var_name in ['water_temp', 'salinity', 'water_temp_bottom', 'salinity_bottom']:
-                fname = 'ts3z'
-            elif var_name in ['surf_el']:
-                fname = 'surf_el'
-            for time_str in time_list:
-                file_list.append(Path(root_path)/f'HYCOM_{fname}_{time_str}.nc')
-            merge_path_name = Path(root_path)/f'HYCOM_{fname}_{merge_name}.nc'
-        if combine_switch:
-            # var_name here must be the official variable name, not the short form
-            merge(file_list, var_name, 'time', merge_path_name)
-        if copy_switch:
-            copy_file(merge_path_name, copy_dir) """
+            options["var"] = var_name
+            download(**options)
+    else:
+        download(**options)
```
```diff
--- oafuncs-0.0.91/oafuncs/oa_down/idm.py
+++ oafuncs-0.0.93/oafuncs/oa_down/idm.py
@@ -38,7 +38,7 @@ def downloader(task_url, folder_path, file_name, idm_engine=r"D:\Programs\Intern
     Return:
         None
     Example:
-        downloader("https://www.test.com/data.nc",
+        downloader("https://www.test.com/data.nc", "E:\\Data", "test.nc", "D:\\Programs\\Internet Download Manager\\IDMan.exe")
     """
     os.makedirs(folder_path, exist_ok=True)
    # Add the task to the queue
```

(The removed line is truncated in this diff view; its remainder is not rendered.)
```diff
--- oafuncs-0.0.91/oafuncs/oa_down/literature.py
+++ oafuncs-0.0.93/oafuncs/oa_down/literature.py
@@ -23,6 +23,8 @@ import requests
 from rich import print
 from rich.progress import track
 from oafuncs.oa_down.user_agent import get_ua
+from oafuncs.oa_file import remove
+from oafuncs.oa_data import ensure_list
 
 __all__ = ["download5doi"]
 
```
```diff
@@ -221,7 +223,7 @@ class _Downloader:
             print("Try another URL...")
 
 
-def read_excel(file, col_name=r"DOI"):
+def _read_excel(file, col_name=r"DOI"):
     df = pd.read_excel(file)
     df_list = df[col_name].tolist()
     # Drop NaN values
@@ -229,7 +231,7 @@ def read_excel(file, col_name=r"DOI"):
     return df_list
 
 
-def read_txt(file):
+def _read_txt(file):
     with open(file, "r") as f:
         lines = f.readlines()
     # Strip newlines and drop empty lines
```
```diff
@@ -267,13 +269,13 @@ def download5doi(store_path=None, doi_list=None, txt_file=None, excel_file=None,
     store_path.mkdir(parents=True, exist_ok=True)
     store_path = str(store_path)
 
-    [... old lines 270-271: content not rendered in this diff view ...]
-        doi_list = [doi_list]
+    if doi_list:
+        doi_list = ensure_list(doi_list)
     if txt_file:
-        doi_list = read_txt(txt_file)
+        doi_list = _read_txt(txt_file)
     if excel_file:
-        doi_list = read_excel(excel_file, col_name)
+        doi_list = _read_excel(excel_file, col_name)
+    remove(Path(store_path) / "wrong_record.txt")
     print(f"Downloading {len(doi_list)} PDF files...")
     for doi in track(doi_list, description="Downloading..."):
         download = _Downloader(doi, store_path)
```
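With `ensure_list` in place, a single DOI string and a list of DOIs now go through the same path. Hypothetical invocations (the single-DOI value reuses the one from the old commented-out example; paths are placeholders):

```python
# Hypothetical invocations of download5doi after this change (paths are examples).
from oafuncs.oa_down.literature import download5doi

download5doi(r"F:\pdf", doi_list="10.1007/s00382-022-06260-x")    # single DOI, normalized via ensure_list
download5doi(r"F:\pdf", doi_list=["10.1007/s00382-022-06260-x"])  # a list works the same way
download5doi(r"F:\pdf", excel_file=r"F:\savedrecs.xls")           # DOIs read from the 'DOI' column
```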
```diff
@@ -281,7 +283,6 @@ def download5doi(store_path=None, doi_list=None, txt_file=None, excel_file=None,
 
 
 if __name__ == "__main__":
-    store_path = r"F:\AAA-Delete\DOI_Reference\pdf"
-    excel_file = r"F:\AAA-Delete\DOI_Reference\savedrecs.xls"
-    # download5doi(store_path, doi_list='10.1007/s00382-022-06260-x')
+    store_path = r"F:\AAA-Delete\DOI_Reference\5\pdf"
+    excel_file = r"F:\AAA-Delete\DOI_Reference\5\savedrecs.xls"
     download5doi(store_path, excel_file=excel_file)
```
```diff
--- oafuncs-0.0.91/oafuncs/oa_file.py
+++ oafuncs-0.0.93/oafuncs/oa_file.py
@@ -226,8 +226,12 @@ def make_dir(directory):
     make_dir(r"E:\Data\2024\09\17\var1")
     """
     directory = str(directory)
-    os.makedirs(directory, exist_ok=True)
-
+    if os.path.exists(directory):
+        print(f"Directory already exists: {directory}")
+        return
+    else:
+        os.makedirs(directory, exist_ok=True)
+        print(f"Created directory: {directory}")
 
 
 # ** Clear the folder
```
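The revised `make_dir` keeps `os.makedirs(..., exist_ok=True)` but now reports which branch was taken. A behavior sketch, using the path from the function's own docstring example:

```python
# Behavior sketch of the revised make_dir (path from the docstring example).
from oafuncs.oa_file import make_dir

make_dir(r"E:\Data\2024\09\17\var1")  # first call:  "Created directory: E:\Data\2024\09\17\var1"
make_dir(r"E:\Data\2024\09\17\var1")  # second call: "Directory already exists: ..." and returns early
```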
```diff
--- oafuncs-0.0.91/PKG-INFO
+++ oafuncs-0.0.93/oafuncs.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
-Metadata-Version: 2.1
+Metadata-Version: 2.2
 Name: oafuncs
-Version: 0.0.91
+Version: 0.0.93
 Summary: Oceanic and Atmospheric Functions
 Home-page: https://github.com/Industry-Pays/OAFuncs
 Author: Kun Liu
@@ -30,6 +30,16 @@ Requires-Dist: matplotlib
 Requires-Dist: Cartopy
 Requires-Dist: netCDF4
 Requires-Dist: xlrd
+Dynamic: author
+Dynamic: author-email
+Dynamic: classifier
+Dynamic: description
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: license
+Dynamic: requires-dist
+Dynamic: requires-python
+Dynamic: summary
 
 
 # oafuncs
```
```diff
--- oafuncs-0.0.91/setup.py
+++ oafuncs-0.0.93/setup.py
@@ -18,7 +18,7 @@ URL = 'https://github.com/Industry-Pays/OAFuncs'
 EMAIL = 'liukun0312@stu.ouc.edu.cn'
 AUTHOR = 'Kun Liu'
 REQUIRES_PYTHON = '>=3.9.0'  # 2025/01/05
-VERSION = '0.0.91'
+VERSION = '0.0.93'
 
 # What packages are required for this module to be executed?
 REQUIRED = [
```
```diff
--- oafuncs-0.0.91/oafuncs/oa_data.py
+++ /dev/null
@@ -1,278 +0,0 @@
-#!/usr/bin/env python
-# coding=utf-8
-"""
-Author: Liu Kun && 16031215@qq.com
-Date: 2024-09-17 17:12:47
-LastEditors: Liu Kun && 16031215@qq.com
-LastEditTime: 2024-12-13 19:11:08
-FilePath: \\Python\\My_Funcs\\OAFuncs\\oafuncs\\oa_data.py
-Description:
-EditPlatform: vscode
-ComputerInfo: XPS 15 9510
-SystemInfo: Windows 11
-Python Version: 3.11
-"""
-
-import itertools
-import multiprocessing as mp
-from concurrent.futures import ThreadPoolExecutor
-
-import numpy as np
-from rich import print
-from scipy.interpolate import griddata
-
-
-__all__ = ["interp_2d"]
-
-
-def interp_2d(target_x, target_y, origin_x, origin_y, data, method="linear", parallel=True):
-    [... docstring and body identical to the 0.0.93 version shown above ...]
-    return np.array(interpolated_data)
-
-
-
-
-# ---------------------------------------------------------------------------------- not used below ----------------------------------------------------------------------------------
-# ** High-dimensional interpolation over the last two dimensions
-def interp_2d_20241213(target_x, target_y, origin_x, origin_y, data, method="linear"):
-    """
-    High-dimensional interpolation; by default the last two dimensions are interpolated. Make sure the data dimensions are correct before passing them in.
-    Parameters:
-        target_y (array-like): target longitude grid, 1D or 2D
-        target_x (array-like): target latitude grid, 1D or 2D
-        origin_y (array-like): source longitude grid, 1D or 2D
-        origin_x (array-like): source latitude grid, 1D or 2D
-        data (array-like): data (*, lat, lon), 2D, 3D or 4D
-        method (str, optional): interpolation method: 'linear', 'nearest', 'cubic', etc.; default 'linear'
-    Returns:
-        array-like: interpolation result
-    """
-
-    # Make sure both the target grid and the source grid are 2D
-    if len(target_y.shape) == 1:
-        target_x, target_y = np.meshgrid(target_x, target_y)
-    if len(origin_y.shape) == 1:
-        origin_x, origin_y = np.meshgrid(origin_x, origin_y)
-
-    dims = data.shape
-    len_dims = len(dims)
-    # print(dims[-2:])
-    # Check that the input data shape matches the lon/lat grids
-
-    if origin_x.shape != dims[-2:] or origin_y.shape != dims[-2:]:
-        print(origin_x.shape, dims[-2:])
-        raise ValueError("Shape of data does not match shape of origin_x or origin_y.")
-
-    # Flatten the target grid into a 1D point set
-    target_points = np.column_stack((np.ravel(target_y), np.ravel(target_x)))
-
-    # Flatten the source grid into a 1D point set
-    origin_points = np.column_stack((np.ravel(origin_y), np.ravel(origin_x)))
-
-    # Interpolate
-    if len_dims == 2:
-        interpolated_data = griddata(origin_points, np.ravel(data), target_points, method=method)
-        interpolated_data = np.reshape(interpolated_data, target_y.shape)
-    elif len_dims == 3:
-        interpolated_data = []
-        for i in range(dims[0]):
-            dt = griddata(origin_points, np.ravel(data[i, :, :]), target_points, method=method)
-            interpolated_data.append(np.reshape(dt, target_y.shape))
-            print(f"Interpolating {i + 1}/{dims[0]}...")
-        interpolated_data = np.array(interpolated_data)
-    elif len_dims == 4:
-        interpolated_data = []
-        for i in range(dims[0]):
-            interpolated_data.append([])
-            for j in range(dims[1]):
-                dt = griddata(origin_points, np.ravel(data[i, j, :, :]), target_points, method=method)
-                interpolated_data[i].append(np.reshape(dt, target_y.shape))
-                print(f"\rInterpolating {i * dims[1] + j + 1}/{dims[0] * dims[1]}...", end="")
-        print("\n")
-        interpolated_data = np.array(interpolated_data)
-
-    return interpolated_data
-
-
-# ** High-dimensional interpolation over the last two dimensions, multi-threaded
-# Gives roughly a 3x speedup on a local machine; no speedup on the HPC cluster for now
-def interp_2d_parallel_20241213(target_x, target_y, origin_x, origin_y, data, method="linear"):
-    """
-    param {*} target_x  target longitude grid, 1D or 2D
-    param {*} target_y  target latitude grid, 1D or 2D
-    param {*} origin_x  source longitude grid, 1D or 2D
-    param {*} origin_y  source latitude grid, 1D or 2D
-    param {*} data      data (*, lat, lon), 2D, 3D or 4D
-    param {*} method    interpolation method: 'linear', 'nearest', 'cubic', etc.; default 'linear'
-    return {*}          interpolation result
-    description : High-dimensional interpolation; by default the last two dimensions are interpolated. Make sure the data dimensions are correct before passing them in.
-    example : interpolated_data = interp_2d_parallel(target_x, target_y, origin_x, origin_y, data, method='linear')
-    """
-
-    def interp_single2d(target_y, target_x, origin_y, origin_x, data, method="linear"):
-        target_points = np.column_stack((np.ravel(target_y), np.ravel(target_x)))
-        origin_points = np.column_stack((np.ravel(origin_y), np.ravel(origin_x)))
-
-        dt = griddata(origin_points, np.ravel(data[:, :]), target_points, method=method)
-        return np.reshape(dt, target_y.shape)
-
-    def interp_single3d(i, target_y, target_x, origin_y, origin_x, data, method="linear"):
-        target_points = np.column_stack((np.ravel(target_y), np.ravel(target_x)))
-        origin_points = np.column_stack((np.ravel(origin_y), np.ravel(origin_x)))
-
-        dt = griddata(origin_points, np.ravel(data[i, :, :]), target_points, method=method)
-        return np.reshape(dt, target_y.shape)
-
-    def interp_single4d(i, j, target_y, target_x, origin_y, origin_x, data, method="linear"):
-        target_points = np.column_stack((np.ravel(target_y), np.ravel(target_x)))
-        origin_points = np.column_stack((np.ravel(origin_y), np.ravel(origin_x)))
-
-        dt = griddata(origin_points, np.ravel(data[i, j, :, :]), target_points, method=method)
-        return np.reshape(dt, target_y.shape)
-
-    if len(target_y.shape) == 1:
-        target_x, target_y = np.meshgrid(target_x, target_y)
-    if len(origin_y.shape) == 1:
-        origin_x, origin_y = np.meshgrid(origin_x, origin_y)
-
-    dims = data.shape
-    len_dims = len(dims)
-
-    if origin_x.shape != dims[-2:] or origin_y.shape != dims[-2:]:
-        raise ValueError("Shape of data does not match shape of origin_x or origin_y.")
-
-    interpolated_data = []
-
-    # Interpolate using multiple threads
-    with ThreadPoolExecutor(max_workers=mp.cpu_count() - 2) as executor:
-        print(f"Using {mp.cpu_count() - 2} threads...")
-        if len_dims == 2:
-            interpolated_data = list(executor.map(interp_single2d, [target_y], [target_x], [origin_y], [origin_x], [data], [method]))
-        elif len_dims == 3:
-            interpolated_data = list(executor.map(interp_single3d, [i for i in range(dims[0])], [target_y] * dims[0], [target_x] * dims[0], [origin_y] * dims[0], [origin_x] * dims[0], [data] * dims[0], [method] * dims[0]))
-        elif len_dims == 4:
-            interpolated_data = list(
-                executor.map(
-                    interp_single4d,
-                    [i for i in range(dims[0]) for j in range(dims[1])],
-                    [j for i in range(dims[0]) for j in range(dims[1])],
-                    [target_y] * dims[0] * dims[1],
-                    [target_x] * dims[0] * dims[1],
-                    [origin_y] * dims[0] * dims[1],
-                    [origin_x] * dims[0] * dims[1],
-                    [data] * dims[0] * dims[1],
-                    [method] * dims[0] * dims[1],
-                )
-            )
-            interpolated_data = np.array(interpolated_data).reshape(dims[0], dims[1], target_y.shape[0], target_x.shape[1])
-
-    interpolated_data = np.array(interpolated_data)
-
-    return interpolated_data
-
-
-def _test_sum(a, b):
-    return a + b
-
-
-if __name__ == "__main__":
-
-    pass
-    """ import time
-    [... commented-out test block identical to the 0.0.93 version shown above ...]
-    plt.show() """
```
All remaining files listed above with `+0 -0` are unchanged between 0.0.91 and 0.0.93.