jacksung-dev 0.0.4.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jacksung/__init__.py +1 -0
- jacksung/ai/GeoAttX.py +356 -0
- jacksung/ai/GeoNet/__init__.py +0 -0
- jacksung/ai/GeoNet/m_block.py +393 -0
- jacksung/ai/GeoNet/m_blockV2.py +442 -0
- jacksung/ai/GeoNet/m_network.py +107 -0
- jacksung/ai/GeoNet/m_networkV2.py +91 -0
- jacksung/ai/__init__.py +0 -0
- jacksung/ai/latex_tool.py +199 -0
- jacksung/ai/metrics.py +181 -0
- jacksung/ai/utils/__init__.py +0 -0
- jacksung/ai/utils/cmorph.py +42 -0
- jacksung/ai/utils/data_parallelV2.py +90 -0
- jacksung/ai/utils/fy.py +333 -0
- jacksung/ai/utils/goes.py +161 -0
- jacksung/ai/utils/gsmap.py +24 -0
- jacksung/ai/utils/imerg.py +159 -0
- jacksung/ai/utils/metsat.py +164 -0
- jacksung/ai/utils/norm_util.py +109 -0
- jacksung/ai/utils/util.py +300 -0
- jacksung/libs/times.ttf +0 -0
- jacksung/utils/__init__.py +1 -0
- jacksung/utils/base_db.py +72 -0
- jacksung/utils/cache.py +71 -0
- jacksung/utils/data_convert.py +273 -0
- jacksung/utils/exception.py +27 -0
- jacksung/utils/fastnumpy.py +115 -0
- jacksung/utils/figure.py +251 -0
- jacksung/utils/hash.py +26 -0
- jacksung/utils/image.py +221 -0
- jacksung/utils/log.py +86 -0
- jacksung/utils/login.py +149 -0
- jacksung/utils/mean_std.py +66 -0
- jacksung/utils/multi_task.py +129 -0
- jacksung/utils/number.py +6 -0
- jacksung/utils/nvidia.py +140 -0
- jacksung/utils/time.py +87 -0
- jacksung/utils/web.py +63 -0
- jacksung_dev-0.0.4.15.dist-info/LICENSE +201 -0
- jacksung_dev-0.0.4.15.dist-info/METADATA +228 -0
- jacksung_dev-0.0.4.15.dist-info/RECORD +44 -0
- jacksung_dev-0.0.4.15.dist-info/WHEEL +5 -0
- jacksung_dev-0.0.4.15.dist-info/entry_points.txt +3 -0
- jacksung_dev-0.0.4.15.dist-info/top_level.txt +1 -0
jacksung/ai/utils/metsat.py
@@ -0,0 +1,164 @@
import traceback

from satpy import Scene
from pyresample import create_area_def
import numpy as np
import os
from datetime import timedelta
from jacksung.utils.data_convert import np2tif, Coordinate
from contextlib import contextmanager


@contextmanager
def satpy_scene_context(filenames, reader):
    """Context manager for a satpy Scene."""
    scn = None
    try:
        scn = Scene(filenames=filenames, reader=reader)
        yield scn
    finally:
        # Explicitly release satpy resources
        if scn is not None:
            try:
                # Call satpy's cleanup method if it exists
                if hasattr(scn, 'close'):
                    scn.close()
                # Drop the reference manually
                del scn
            except Exception as e:
                print(f"Warning while cleaning up satpy resources: {e}")


def _define_wgs84_area(resolution=0.05, area_extent=(-14.5, -60, 105.5, 60.0)):
    """Define the WGS84 target area (60°N-60°S), keeping the original area definition."""

    wgs84_proj = "+proj=longlat +datum=WGS84 +ellps=WGS84 +no_defs"  # legacy projection parameter (PROJ4 string)
    target_area = create_area_def(
        area_id="wgs84_60n60s",
        projection=wgs84_proj,  # legacy parameter: projection (PROJ4 string)
        area_extent=area_extent,  # keep the original area extent
        resolution=resolution,
        description=f"WGS84 Lat/Lon, 60N-60S, {resolution}deg resolution"
    )
    return target_area


def _extract_time_from_filename(filename):
    """Extract the timestamp from an MSG filename (format: 20251126062741 -> 2025-11-26_06-27-41)."""
    try:
        # Example MSG filename: MSG2-SEVI-MSG15-0100-NA-20251126062741.272000000Z-NA.nat
        time_part = filename.split("-")[5].split(".")[0]  # extract "20251126062741"
        formatted_time = f"{time_part[:4]}-{time_part[4:6]}-{time_part[6:8]}_{time_part[8:10]}-{time_part[10:12]}-{time_part[12:14]}"
        return formatted_time
    except IndexError:
        print("Warning: could not extract the timestamp from the filename, using the default time format")
        return "unknown_time"


def _process_msg_seviri_to_numpy(nat_file_path, resolution=0.05, resampler="nearest", channels=("WV_062",), lock=None,
                                 only_ld=False):
    try:
        if lock is not None:
            lock.acquire()
        # ========== main changes begin ==========
        # Use the context manager to guarantee the Scene resources are released (replaces the previous Scene creation)
        with satpy_scene_context(filenames=[nat_file_path], reader="seviri_l1b_native") as scn:
            scn.load(channels)
            ld = scn[channels[0]].attrs['orbital_parameters']['projection_longitude']
            if only_ld:
                return ld
            area_extent = (ld - 60, -60, ld + 60, 60.0)
            target_area = _define_wgs84_area(resolution=resolution, area_extent=area_extent)
            scn_wgs84 = scn.resample(target_area, resampler=resampler)
            time_str = _extract_time_from_filename(os.path.basename(nat_file_path))
            result = {
                'data': {},
                'metadata': {},
                'coordinates': {},
                'global_attrs': {
                    'source_file': os.path.basename(nat_file_path),
                    'processing_time': time_str,
                    'resolution_degrees': resolution,
                    'resampling_method': resampler,
                    'area_extent': area_extent
                }
            }
            # Extract the data of each channel
            for ch in channels:
                try:
                    data_array = scn_wgs84[ch].values
                    result['data'][ch] = data_array
                    result['metadata'][ch] = dict(scn_wgs84[ch].attrs)
                except Exception as e:
                    print(f"Failed to extract data for channel {ch}: {str(e)}")
            # Extract the coordinate information
            if channels:
                first_ch = channels[0]
                try:
                    area = scn_wgs84[first_ch].attrs['area']
                    lons, lats = area.get_lonlats()
                    result['coordinates']['longitude'] = lons
                    result['coordinates']['latitude'] = lats
                    result['coordinates']['shape'] = lons.shape
                except Exception as e:
                    print(f"Failed to extract coordinate information: {str(e)}")
            return result
    except Exception as e:
        print(f"Processing failed: {str(e)}")
        traceback.print_exc()
        return None
    finally:
        if lock is not None:
            lock.release()


def getNPfromNAT(file_path, save_file=False, lock=None, return_coord=False, only_coord=False):
    all_target_channels = ["WV_062", "WV_073", "IR_087", "IR_097", "IR_108", "IR_120", "IR_134"]
    # all_target_channels = ["VIS006", "VIS008"]
    if only_coord:
        ld = _process_msg_seviri_to_numpy(nat_file_path=file_path, resolution=0.05, resampler="nearest",
                                          channels=all_target_channels, lock=lock, only_ld=True)
        return Coordinate(left=ld - 60, bottom=-60, right=ld + 60, top=60, x_res=0.05, y_res=0.05)
    else:
        result = _process_msg_seviri_to_numpy(nat_file_path=file_path, resolution=0.05, resampler="nearest",
                                              channels=all_target_channels, lock=lock, only_ld=False)
        np_data = None
        coord = None
        if result is not None:
            for idx, channel in enumerate(all_target_channels):
                if channel in result['data']:
                    chann_data = result['data'][channel].copy()
                    if np_data is None:
                        np_data = np.zeros((len(all_target_channels),) + chann_data.shape, dtype=chann_data.dtype)
                        area_extent = result['global_attrs']['area_extent']
                        coord = Coordinate(left=area_extent[0], bottom=area_extent[1], right=area_extent[2],
                                           top=area_extent[3], x_res=0.05, y_res=0.05)
                    np_data[idx] = chann_data
                else:
                    del result
                    raise Exception(f"File {file_path}: no data found for channel {channel}")
        else:
            raise Exception(f"Processing of file {file_path} failed, no data obtained")
        if save_file:
            np2tif(np_data, save_path='np2tif_dir', coord=coord, out_name='MetSat',
                   dtype='float32')
        if return_coord:
            return np_data, coord
        else:
            return np_data


def get_seviri_file_path(data_path, data_date):
    e_date = data_date + timedelta(minutes=14, seconds=59)
    parent_dir = rf'{data_path}/{data_date.year}/{data_date.month}/{data_date.day}/'
    start_date_str = data_date.strftime('%Y%m%d%H%M%S')
    end_date_str = e_date.strftime('%Y%m%d%H%M%S')
    for file in os.listdir(parent_dir):
        if file.endswith('.nat') and start_date_str <= file.split('-')[5].split('.')[0] <= end_date_str:
            return os.path.join(parent_dir, file)
    return None


if __name__ == '__main__':
    np_data = getNPfromNAT("MSG4-SEVI-MSG15-0100-NA-20221230031243.610000000Z-NA.nat", save_file=False)
    print()
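For orientation, a minimal usage sketch of the metsat.py helpers added in this hunk; the data root and timestamp below are hypothetical placeholders, not values shipped with the package:

# Hypothetical usage sketch for jacksung/ai/utils/metsat.py (paths are placeholders).
from datetime import datetime
from jacksung.ai.utils.metsat import getNPfromNAT, get_seviri_file_path

# Find the .nat file covering the 15-minute slot starting at the given time
nat_path = get_seviri_file_path('/data/msg', datetime(2022, 12, 30, 3, 0, 0))
if nat_path is not None:
    # (7, H, W) array of the seven WV/IR channels plus the matching Coordinate
    np_data, coord = getNPfromNAT(nat_path, save_file=False, return_coord=True)
    print(np_data.shape)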
jacksung/ai/utils/norm_util.py
@@ -0,0 +1,109 @@
import os.path

import torch
import numpy as np
from einops import rearrange


class PredNormalization:
    def __init__(self, data_path):
        self.mean = np.load(os.path.join(data_path, 'mean_level.npy')).astype(np.float32)[2:]
        self.std = np.load(os.path.join(data_path, 'std_level.npy')).astype(np.float32)[2:]
        self.mean = torch.from_numpy(self.mean)
        self.std = torch.from_numpy(self.std)

    def norm(self, data):
        data = rearrange(data, 'b c h w->b h w c')
        data = (data - self.mean) / self.std
        return rearrange(data, 'b h w c->b c h w')

    def denorm(self, data):
        data = rearrange(data, 'b c h w->b h w c')
        data = data * self.std + self.mean
        return rearrange(data, 'b h w c->b c h w')


class PrecNormalization:
    def __init__(self, data_path):
        self.mean_qpe = torch.from_numpy(
            np.load(os.path.join(data_path, 'mean_level_qpe.npy')).astype(np.float32))
        self.std_qpe = torch.from_numpy(
            np.load(os.path.join(data_path, 'std_level_qpe.npy')).astype(np.float32))
        self.mean_fy = torch.from_numpy(
            np.load(os.path.join(data_path, 'mean_level_fy.npy')).astype(np.float32).mean(axis=0)[2:])
        self.std_fy = torch.from_numpy(
            np.load(os.path.join(data_path, 'std_level_fy.npy')).astype(np.float32).mean(axis=0)[2:])

    def norm(self, data, norm_type='fy'):
        if norm_type == 'fy':
            data = rearrange(data, 'b t c h w->b h w t c')
            data = (data - self.mean_fy) / self.std_fy
            return rearrange(data, 'b h w t c->b t c h w')
        elif norm_type == 'qpe':
            data = rearrange(data, 'b c h w->b h w c')
            data = (data - self.mean_qpe) / self.std_qpe
            return rearrange(data, 'b h w c->b c h w')

    def denorm(self, data, norm_type='fy'):
        if norm_type == 'fy':
            data = rearrange(data, 'b t c h w->b h w t c')
            data = data * self.std_fy + self.mean_fy
            return rearrange(data, 'b h w t c->b t c h w')
        elif norm_type == 'qpe':
            data = rearrange(data, 'b c h w->b h w c')
            data = data * self.std_qpe + self.mean_qpe
            return rearrange(data, 'b h w c->b c h w')


class PremNormalization:

    def __init__(self, prec_data_path):
        self.mean = torch.from_numpy(
            np.load(os.path.join(prec_data_path, 'imerg_mean.npy')).astype(np.float32))
        self.std = torch.from_numpy(
            np.load(os.path.join(prec_data_path, 'imerg_std.npy')).astype(np.float32))
        # self.mean = torch.Tensor(
        #     [3.6614e+03, 3.3748e+03, 3.3829e+03, 2.8368e+03, 2.5664e+03, 2.4914e+03, 2.4259e+03, 0.5723, 0, 0])
        # self.std = torch.Tensor([164.7376, 265.8857, 252.9820, 509.8994, 532.4901, 518.8191, 414.1427, 0.6308, 1, 1])
        # print(self.mean, self.std)

    def norm(self, data, fy_norm=True):
        data = rearrange(data, 'b c h w->b h w c')
        if fy_norm:
            data = (data - self.mean[:7]) / self.std[:7]
        else:
            # data[:, 0][data[:, 0] == 0] = torch.inf
            # data[:, 0] = 1 / torch.sqrt(data[:, 0].clone())
            data = (data - self.mean[7:]) / self.std[7:]
        return rearrange(data, 'b h w c->b c h w')

    def denorm(self, data, fy_norm=True):
        data = rearrange(data, 'b c h w->b h w c')
        if fy_norm:
            data = data * self.std[:7] + self.mean[:7]
        else:
            data = data * self.std[7:] + self.mean[7:]
            # data[:, 0][data[:, 0] == 0] = torch.inf
            # data[:, 0] = 1 / (data[:, 0].clone() ** 2)
        return rearrange(data, 'b h w c->b c h w')


class Normalization:
    def __init__(self, mean_std_npy, idx=None):
        mean_std = torch.from_numpy(mean_std_npy.astype(np.float32))
        if idx:
            self.mean = mean_std[0, idx[0]:idx[1]]
            self.std = mean_std[1, idx[0]:idx[1]]
        else:
            self.mean = mean_std[0]
            self.std = mean_std[1]

    def norm(self, data):
        data = rearrange(data, 'b c h w->b h w c')
        data = (data - self.mean) / self.std
        return rearrange(data, 'b h w c->b c h w')

    def denorm(self, data):
        data = rearrange(data, 'b c h w->b h w c')
        data = data * self.std + self.mean
        return rearrange(data, 'b h w c->b c h w')
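A minimal round-trip sketch of the generic Normalization class above, assuming a made-up 7-channel mean/std array (the values are placeholders):

# Hypothetical round-trip check for Normalization; mean_std values are made up.
import numpy as np
import torch
from jacksung.ai.utils.norm_util import Normalization

mean_std = np.stack([np.full(7, 100.0), np.full(7, 5.0)])  # row 0: means, row 1: stds
norm = Normalization(mean_std)
x = torch.randn(2, 7, 16, 16) * 5 + 100  # (b, c, h, w)
assert torch.allclose(norm.denorm(norm.norm(x)), x, atol=1e-3)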
jacksung/ai/utils/util.py
@@ -0,0 +1,300 @@
import torch
import torch.nn as nn
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os
from jacksung.utils.image import crop_png, zoom_image, zoomAndDock
import rasterio
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
from matplotlib.colors import LinearSegmentedColormap
import cartopy.feature as cfeature
import cartopy.crs as ccrs
import yaml
import argparse
from datetime import datetime, timedelta
from matplotlib.ticker import MaxNLocator
import random


def load_model(model, state_dict, strict=True):
    own_state = model.state_dict()
    for name, param in state_dict.items():
        name = name[name.index('.') + 1:]
        if name in own_state.keys():
            if isinstance(param, nn.Parameter):
                param = param.data
            try:
                own_state[name].copy_(param)
                # own_state[name].requires_grad = False
            except Exception as e:
                err_log = f'While copying the parameter named {name}, ' \
                          f'whose dimensions in the model are {own_state[name].size()} and ' \
                          f'whose dimensions in the checkpoint are {param.size()}.'
                if not strict:
                    print(err_log)
                else:
                    raise Exception(err_log)
        elif strict:
            raise KeyError(f'unexpected key {name} in {own_state.keys()}')
        else:
            print(f'{name} not loaded by model')


def get_stat_dict(metrics, extra_info=None):
    if extra_info is None:
        extra_info = dict()
    stat_dict = {'epochs': 0, 'loss': [], 'metrics': {}}
    for idx, metrics in enumerate(metrics):
        name, default_value, op = metrics
        stat_dict['metrics'][name] = {'value': [], 'best': {'value': default_value, 'epoch': 0, 'op': op}}
    for key, value in extra_info.items():
        stat_dict[key] = value
    return stat_dict


def data_to_device(datas, device, fp=32):
    outs = []
    for data in datas:
        if fp == 16:
            data = data.type(torch.HalfTensor)
        if fp == 64:
            data = data.type(torch.DoubleTensor)
        if fp == 32:
            data = data.type(torch.FloatTensor)
        data = data.to(device)
        outs.append(data)
    return outs


def draw_lines(stat_dict_path):
    os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
    print('[TemporaryTag]Producing the LinePicture of Log...', end='[TemporaryTag]\n')
    yaml_args = yaml.load(open(stat_dict_path), Loader=yaml.FullLoader)
    # create the figure
    m_len = len(yaml_args['metrics'])
    sub_loss_count = 0
    for i in range(0, 5):
        if f'loss{i}es' in yaml_args:
            sub_loss_count += 1
        else:
            break
    plt.figure(figsize=(8 * (m_len + sub_loss_count), 6))  # set the figure size
    x = np.array(range(1, yaml_args['epochs'] + 1))
    for idx, d in enumerate(yaml_args['metrics'].items()):
        m_name, m_value = d
        y = np.array(m_value['value'])
        # generate the data
        plt.subplot(1, m_len + sub_loss_count, idx + 1)
        plt.plot(x, y)
        plt.title(m_name)
        plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))
    for i in range(0, sub_loss_count):
        y = np.array(yaml_args[f'loss{i}es'])
        plt.subplot(1, m_len + sub_loss_count, m_len + i + 1)
        scale = len(y) / yaml_args['epochs']
        # rescale the x axis from raw step counts to epochs so it matches the other plots
        x = np.array(range(1, len(y) + 1)) / scale
        plt.plot(x, y)
        plt.title(f'Loss{i}')
        plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))
    # add a legend
    # plt.legend()
    plt.savefig(os.path.join(os.path.dirname(stat_dict_path), 'Metrics.jpg'))


def make_best_metric(stat_dict, metrics, epoch, save_model_param, server_log_param):
    save_model_flag = False
    experiment_model_path, model, optimizer, scheduler = save_model_param
    log, epochs, cloudLogName = server_log_param

    for name, m_value in metrics:
        stat_dict['metrics'][name]['value'].append(m_value)
        inf = float('inf')
        if eval(str(m_value) + stat_dict['metrics'][name]['best']['op'] + str(
                stat_dict['metrics'][name]['best']['value'])):
            stat_dict['metrics'][name]['best']['value'] = m_value
            stat_dict['metrics'][name]['best']['epoch'] = epoch
            log.send_log('{}:{} epoch:{}/{}'.format(name, m_value, epoch, epochs), cloudLogName)
            save_model_flag = True

    if save_model_flag:
        # save best model
        save_model(os.path.join(experiment_model_path, 'model_{}.pt'.format(epoch)), epoch,
                   model, optimizer, scheduler, stat_dict)
    # '[Validation] nRMSE/RMSE: {:.4f}/{:.4f} (Best: {:.4f}/{:.4f}, Epoch: {}/{})\n'
    test_log = f'[Val epoch:{epoch}] ' \
               + ' '.join([str(m[0]) + ':' + str(round(m[1], 4)) + '('
                           + str(round(stat_dict['metrics'][m[0]]['best']['value'], 4)) + ')' for m in metrics]) \
               + ' (Best Epoch: ' \
               + '/'.join([str(stat_dict['metrics'][m[0]]['best']['epoch']) for m in metrics]) \
               + ')'
    save_model(os.path.join(experiment_model_path, 'model_latest.pt'), epoch, model, optimizer, scheduler, stat_dict)
    return test_log


def save_model(_path, _epoch, _model, _optimizer, _scheduler, _stat_dict):
    # torch.save(model.state_dict(), saved_model_path)
    torch.save({
        'epoch': _epoch,
        'model_state_dict': _model.state_dict(),
        'optimizer_state_dict': _optimizer.state_dict(),
        'scheduler_state_dict': _scheduler.state_dict(),
        'stat_dict': _stat_dict
    }, _path)


def parse_config(config=None, set_gpu=True):
    parser = argparse.ArgumentParser(description='config')
    parser.add_argument('--config', type=str, default=None, help='pre-config file for training')
    parser.add_argument('--prec_data_path', type=str, default=None, help='dataset path')
    args = parser.parse_args()
    if args.config:
        opt = vars(args)
        yaml_args = yaml.load(open(args.config), Loader=yaml.FullLoader)
        opt.update(yaml_args)
    else:
        opt = vars(args)
        yaml_args = yaml.load(open(config), Loader=yaml.FullLoader)
        opt.update(yaml_args)

    if set_gpu:
        # set visible gpu
        os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(x) for x in args.gpu_ids])

        # select active gpu devices
        if args.gpu_ids is not None and torch.cuda.is_available():
            print('use cuda & cudnn for acceleration!')
            print('the gpu id is: {}'.format(args.gpu_ids))
            device = torch.device('cuda')
            # device = torch.device('cuda:' + str(args.gpu_ids[0]))
        else:
            print('use cpu for training!')
            device = torch.device('cpu')
        return device, args
    else:
        return args


def _get_color_normalization(data, colors):
    max_value = colors[-1][0]
    min_value = colors[0][0]
    data[data < min_value] = min_value
    data[data > max_value] = max_value
    data = (data - min_value) / (max_value - min_value)
    new_colors = []
    for color in colors:
        new_colors.append([(color[0] - min_value) / (max_value - min_value), color[1]])
    return data, new_colors


def data_augmentation(images):
    # rotation
    rotate = random.random()
    if 0 <= rotate < 0.25:
        images = [torch.rot90(image, 1, [2, 3]) for image in images]
    elif 0.25 <= rotate < 0.5:
        images = [torch.rot90(image, 2, [2, 3]) for image in images]
    elif 0.5 <= rotate < 0.75:
        images = [torch.rot90(image, 3, [2, 3]) for image in images]
    # horizontal flip
    if random.random() > 0.5:
        images = [torch.flip(image, [2]) for image in images]
    # vertical flip
    if random.random() > 0.5:
        images = [torch.flip(image, [3]) for image in images]
    return images


def make_fig(file_name, root_path, out_folder=None, tz='UTC',
             colors=((0, '#1E90FF'), (0.1, '#1874CD'), (0.2, '#3A5FCD'), (0.3, '#0000CD'), (1, '#9400D3')),
             area=((100, 140, 10), (20, 60, 10)), font_size=20, corp=(0, 0, None, None),
             zoom_rectangle=(310 * 5, 300 * 5, 50 * 5, 40 * 5), docker=(300, 730), dpi=500, filter=0.3, exposure=None):
    # corp = [92, 31, 542, 456]
    print('Deprecated make_figure method!')
    os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
    extents = [100, 140, 20, 60]
    proj = ccrs.PlateCarree()
    fig = plt.figure(dpi=dpi)
    ax = fig.add_subplot(111, projection=proj)
    ax.set_extent(extents, crs=proj)
    # read the TIFF data
    elevation = None
    if type(file_name) == list:
        for each_file in file_name:
            file_path = os.path.join(root_path, each_file)
            with rasterio.open(file_path) as dataset:
                el_rd = dataset.read(1)
                el_rd[el_rd < filter] = np.nan
                if elevation is None:
                    elevation = el_rd
                else:
                    elevation += el_rd
    else:
        file_path = os.path.join(root_path, file_name)
        with rasterio.open(file_path) as dataset:
            elevation = dataset.read(1)
            elevation[elevation <= filter] = np.nan
    elevation, colors = _get_color_normalization(elevation, colors)
    cmap = LinearSegmentedColormap.from_list('custom_cmap', colors)
    # add the map features
    land = cfeature.NaturalEarthFeature('physical', 'land', '50m', edgecolor='black', facecolor='none',
                                        linewidth=0.4)
    ax.add_feature(land)
    ax.imshow(elevation, origin='upper', extent=extents, transform=proj, cmap=cmap)
    ax.add_feature(cfeature.OCEAN)
    ax.add_feature(cfeature.LAND)
    # ax.add_feature(cfeature.LAKES, edgecolor='black')
    ax.add_feature(cfeature.RIVERS)
    # ax.add_feature(cfeature.BORDERS)
    # add gridlines
    # ax.gridlines(linestyle='--')
    # set the major and minor ticks
    tick_proj = ccrs.PlateCarree()
    ax.set_xticks(np.arange(area[0][0], area[0][1] + 1, area[0][2]), crs=tick_proj)
    # ax.set_xticks(np.arange(-180, 180 + 30, 30), minor=True, crs=tick_proj)
    ax.set_yticks(np.arange(area[1][0], area[1][1] + 1, area[1][2]), crs=tick_proj)
    # ax.set_yticks(np.arange(-90, 90 + 15, 15), minor=True, crs=tick_proj)
    # format the tick labels with the formatters
    ax.xaxis.set_major_formatter(LongitudeFormatter())
    ax.yaxis.set_major_formatter(LatitudeFormatter())
    if out_folder is None:
        if type(file_name) == list:
            file_name = file_name[0]
            file_dir = 'concate'
        else:
            file_dir = 'outs'
    else:
        file_dir = out_folder
    os.makedirs(os.path.join(root_path, file_dir), exist_ok=True)
    file_name = file_name.replace('.tif', '.png')
    file_title = datetime.strptime(file_name.split('-')[0].replace('target_', ''), '%Y%m%d_%H%M%S')
    file_name = file_title.strftime('%Y%m%d_%H%M%S.png')
    exposure = exposure if exposure else (60 if file_dir == 'concate' else 15)
    if tz == 'BJT':
        exposure_end = (file_title + timedelta(hours=8) + timedelta(minutes=exposure)).strftime('%H:%M')
        file_title = (file_title + timedelta(hours=8)).strftime('%Y-%m-%d %H:%M')
    else:
        exposure_end = (file_title + timedelta(minutes=exposure)).strftime('%H:%M')
        file_title = file_title.strftime('%Y-%m-%d %H:%M')
    ax.set_title(file_title + f'-' + exposure_end + f' ({tz})', fontsize=font_size)
    plt.xticks(fontsize=font_size)
    plt.yticks(fontsize=font_size)
    # plt.title(fontsize=font_size)
    file_save_path = os.path.join(root_path, file_dir, file_name)
    plt.savefig(file_save_path)
    if zoom_rectangle is not None:
        read_png = cv2.imread(file_save_path)
        read_png = zoomAndDock(read_png, zoom_rectangle, docker, scale_factor=5, border=14)
        cv2.imwrite(file_save_path, read_png)
    crop_png(file_save_path, left=corp[0], top=corp[1], right=corp[2], bottom=corp[3])
    return file_save_path
    # plt.show()


def clipSatelliteNP(np_data, ld, area=((100, 140, 10), (20, 60, 10))):
    lon_d = int((ld - (area[0][0] + area[0][1]) / 2) * 20)
    lat_d = int(((area[1][0] + area[1][1]) / 2) * 20)
    np_data = np_data[:, 800 - lat_d:1600 - lat_d, 800 - lon_d:1600 - lon_d]
    return np_data
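A short sketch of how the augmentation and device helpers above compose in a training loop; the tensor shapes and CPU device are assumptions for illustration:

# Hypothetical usage sketch for the util.py training helpers (shapes are made up).
import torch
from jacksung.ai.utils.util import data_augmentation, data_to_device

# The same random rotation/flips are applied to every tensor in the list,
# keeping inputs and targets spatially aligned
inp = torch.randn(2, 7, 64, 64)  # (b, c, h, w)
tgt = torch.randn(2, 1, 64, 64)
inp, tgt = data_augmentation([inp, tgt])
# Cast to float32 and move to the chosen device
inp, tgt = data_to_device([inp, tgt], torch.device('cpu'), fp=32)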
jacksung/libs/times.ttf
ADDED
Binary file
jacksung/utils/__init__.py
@@ -0,0 +1 @@
from . import *
jacksung/utils/base_db.py
@@ -0,0 +1,72 @@
import os
import time
import pymysql
from datetime import datetime, timedelta
from tqdm import tqdm

import threading
import multiprocessing
import traceback, sys

import configparser


def convert_str(s):
    return "'" + str(s).replace("'", '"') + "'"


def convert_num(n):
    return str(n)


class BaseDB:
    def __init__(self, ini_path='db.ini'):
        self.conn = None
        self.db = None
        self.passwd = None
        self.user = None
        self.host = None
        self.port = None
        self.autocommit = True
        self.ini_path = ini_path
        pymysql.install_as_MySQLdb()
        self.lock_t = threading.Lock()

    def read_config(self):
        config = configparser.ConfigParser()
        config.read(self.ini_path)
        # Create the connection object
        self.host = config['database']['host']
        self.passwd = config['database']['password']
        self.db = config['database']['database']
        self.user = config['database'].get('user', 'root')
        self.port = config['database'].getint('port', 3306)

    def reconnect(self):
        if not self.conn:
            self.read_config()
        # open the database connection
        connection = pymysql.connect(host=self.host, user=self.user, passwd=self.passwd, db=self.db, port=self.port,
                                     autocommit=self.autocommit)
        tqdm.write('Reconnected!')
        return connection

    def execute(self, sql):
        self.lock_t.acquire()
        cursor = None
        try:
            cursor = self.conn.cursor()
            result = cursor.execute(sql)
        except Exception as e:
            self.conn = self.reconnect()
            cursor = self.conn.cursor()
            try:
                result = cursor.execute(sql)
            except Exception as e:
                print(f'err SQL: {sql}')
                raise e
        finally:
            if cursor:
                cursor.close()
            self.lock_t.release()
        return result, cursor