jacksung-dev 0.0.4.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. jacksung/__init__.py +1 -0
  2. jacksung/ai/GeoAttX.py +356 -0
  3. jacksung/ai/GeoNet/__init__.py +0 -0
  4. jacksung/ai/GeoNet/m_block.py +393 -0
  5. jacksung/ai/GeoNet/m_blockV2.py +442 -0
  6. jacksung/ai/GeoNet/m_network.py +107 -0
  7. jacksung/ai/GeoNet/m_networkV2.py +91 -0
  8. jacksung/ai/__init__.py +0 -0
  9. jacksung/ai/latex_tool.py +199 -0
  10. jacksung/ai/metrics.py +181 -0
  11. jacksung/ai/utils/__init__.py +0 -0
  12. jacksung/ai/utils/cmorph.py +42 -0
  13. jacksung/ai/utils/data_parallelV2.py +90 -0
  14. jacksung/ai/utils/fy.py +333 -0
  15. jacksung/ai/utils/goes.py +161 -0
  16. jacksung/ai/utils/gsmap.py +24 -0
  17. jacksung/ai/utils/imerg.py +159 -0
  18. jacksung/ai/utils/metsat.py +164 -0
  19. jacksung/ai/utils/norm_util.py +109 -0
  20. jacksung/ai/utils/util.py +300 -0
  21. jacksung/libs/times.ttf +0 -0
  22. jacksung/utils/__init__.py +1 -0
  23. jacksung/utils/base_db.py +72 -0
  24. jacksung/utils/cache.py +71 -0
  25. jacksung/utils/data_convert.py +273 -0
  26. jacksung/utils/exception.py +27 -0
  27. jacksung/utils/fastnumpy.py +115 -0
  28. jacksung/utils/figure.py +251 -0
  29. jacksung/utils/hash.py +26 -0
  30. jacksung/utils/image.py +221 -0
  31. jacksung/utils/log.py +86 -0
  32. jacksung/utils/login.py +149 -0
  33. jacksung/utils/mean_std.py +66 -0
  34. jacksung/utils/multi_task.py +129 -0
  35. jacksung/utils/number.py +6 -0
  36. jacksung/utils/nvidia.py +140 -0
  37. jacksung/utils/time.py +87 -0
  38. jacksung/utils/web.py +63 -0
  39. jacksung_dev-0.0.4.15.dist-info/LICENSE +201 -0
  40. jacksung_dev-0.0.4.15.dist-info/METADATA +228 -0
  41. jacksung_dev-0.0.4.15.dist-info/RECORD +44 -0
  42. jacksung_dev-0.0.4.15.dist-info/WHEEL +5 -0
  43. jacksung_dev-0.0.4.15.dist-info/entry_points.txt +3 -0
  44. jacksung_dev-0.0.4.15.dist-info/top_level.txt +1 -0
@@ -0,0 +1,164 @@
1
+ import traceback
2
+
3
+ from satpy import Scene
4
+ from pyresample import create_area_def
5
+ import numpy as np
6
+ import os
7
+ from datetime import timedelta
8
+ from jacksung.utils.data_convert import np2tif, Coordinate
9
+ from contextlib import contextmanager
10
+
11
+
12
+ @contextmanager
13
+ def satpy_scene_context(filenames, reader):
14
+ """satpy Scene上下文管理器"""
15
+ scn = None
16
+ try:
17
+ scn = Scene(filenames=filenames, reader=reader)
18
+ yield scn
19
+ finally:
20
+ # 显式清理satpy资源
21
+ if scn is not None:
22
+ try:
23
+ # 尝试调用satpy的清理方法
24
+ if hasattr(scn, 'close'):
25
+ scn.close()
26
+ # 手动删除引用
27
+ del scn
28
+ except Exception as e:
29
+ print(f"清理satpy资源时警告: {e}")
30
+
31
+
32
+ def _define_wgs84_area(resolution=0.05, area_extent=(-14.5, -60, 105.5, 60.0)):
33
+ """定义WGS84坐标系目标区域(60°N-60°S,全经度),保持您原有的区域定义"""
34
+
35
+ wgs84_proj = "+proj=longlat +datum=WGS84 +ellps=WGS84 +no_defs" # 旧版projection参数(PROJ4字符串)
36
+ target_area = create_area_def(
37
+ area_id="wgs84_60n60s",
38
+ projection=wgs84_proj, # 旧版参数:projection(PROJ4字符串)
39
+ area_extent=area_extent, # 保持您原有的区域范围
40
+ resolution=resolution,
41
+ description=f"WGS84 Lat/Lon, 60N-60S, {resolution}deg resolution"
42
+ )
43
+ return target_area
44
+
45
+
46
+ def _extract_time_from_filename(filename):
47
+ """从MSG文件名中提取时间戳(格式:20251126062741 → 2025-11-26_06-27-41)"""
48
+ try:
49
+ # MSG文件名格式示例:MSG2-SEVI-MSG15-0100-NA-20251126062741.272000000Z-NA.nat
50
+ time_part = filename.split("-")[5].split(".")[0] # 提取"20251126062741"
51
+ formatted_time = f"{time_part[:4]}-{time_part[4:6]}-{time_part[6:8]}_{time_part[8:10]}-{time_part[10:12]}-{time_part[12:14]}"
52
+ return formatted_time
53
+ except IndexError:
54
+ print("警告:无法从文件名提取时间戳,使用默认时间格式")
55
+ return "unknown_time"
56
+
57
+
58
+ def _process_msg_seviri_to_numpy(nat_file_path, resolution=0.05, resampler="nearest", channels=("WV_062",), lock=None,
59
+ only_ld=False):
60
+ try:
61
+ if lock is not None:
62
+ lock.acquire()
63
+ # ========== 主要改动开始 ==========
64
+ # 使用上下文管理器确保Scene资源被释放(替换原来的Scene创建)
65
+ with satpy_scene_context(filenames=[nat_file_path], reader="seviri_l1b_native") as scn:
66
+ scn.load(channels)
67
+ ld = scn[channels[0]].attrs['orbital_parameters']['projection_longitude']
68
+ if only_ld:
69
+ return ld
70
+ area_extent = (ld - 60, -60, ld + 60, 60.0)
71
+ target_area = _define_wgs84_area(resolution=resolution, area_extent=area_extent)
72
+ scn_wgs84 = scn.resample(target_area, resampler=resampler)
73
+ time_str = _extract_time_from_filename(os.path.basename(nat_file_path))
74
+ result = {
75
+ 'data': {},
76
+ 'metadata': {},
77
+ 'coordinates': {},
78
+ 'global_attrs': {
79
+ 'source_file': os.path.basename(nat_file_path),
80
+ 'processing_time': time_str,
81
+ 'resolution_degrees': resolution,
82
+ 'resampling_method': resampler,
83
+ 'area_extent': area_extent
84
+ }
85
+ }
86
+ # 提取每个通道的数据
87
+ for ch in channels:
88
+ try:
89
+ data_array = scn_wgs84[ch].values
90
+ result['data'][ch] = data_array
91
+ result['metadata'][ch] = dict(scn_wgs84[ch].attrs)
92
+ except Exception as e:
93
+ print(f"提取{ch}通道数据失败:{str(e)}")
94
+ # 提取坐标信息
95
+ if channels:
96
+ first_ch = channels[0]
97
+ try:
98
+ area = scn_wgs84[first_ch].attrs['area']
99
+ lons, lats = area.get_lonlats()
100
+ result['coordinates']['longitude'] = lons
101
+ result['coordinates']['latitude'] = lats
102
+ result['coordinates']['shape'] = lons.shape
103
+ except Exception as e:
104
+ print(f"提取坐标信息失败:{str(e)}")
105
+ return result
106
+ except Exception as e:
107
+ print(f"处理失败:{str(e)}")
108
+ traceback.print_exc()
109
+ return None
110
+ finally:
111
+ if lock is not None:
112
+ lock.release()
113
+
114
+
115
+ def getNPfromNAT(file_path, save_file=False, lock=None, return_coord=False, only_coord=False):
116
+ all_target_channels = ["WV_062", "WV_073", "IR_087", "IR_097", "IR_108", "IR_120", "IR_134"]
117
+ # all_target_channels = ["VIS006", "VIS008"]
118
+ if only_coord:
119
+ ld = _process_msg_seviri_to_numpy(nat_file_path=file_path, resolution=0.05, resampler="nearest",
120
+ channels=all_target_channels, lock=lock, only_ld=True)
121
+ return Coordinate(left=ld - 60, bottom=-60, right=ld + 60, top=60, x_res=0.05, y_res=0.05)
122
+ else:
123
+ result = _process_msg_seviri_to_numpy(nat_file_path=file_path, resolution=0.05, resampler="nearest",
124
+ channels=all_target_channels, lock=lock, only_ld=False)
125
+ np_data = None
126
+ coord = None
127
+ if result is not None:
128
+ for idx, channel in enumerate(all_target_channels):
129
+ if channel in result['data']:
130
+ chann_data = result['data'][channel].copy()
131
+ if np_data is None:
132
+ np_data = np.zeros((len(all_target_channels),) + chann_data.shape, dtype=chann_data.dtype)
133
+ area_extent = result['global_attrs']['area_extent']
134
+ coord = Coordinate(left=area_extent[0], bottom=area_extent[1], right=area_extent[2],
135
+ top=area_extent[3], x_res=0.05, y_res=0.05)
136
+ np_data[idx] = chann_data
137
+ else:
138
+ del result
139
+ raise Exception(f"文件{file_path},通道 {channel} 数据未找到")
140
+ else:
141
+ raise Exception(f"文件{file_path}处理失败,未获取数据")
142
+ if save_file:
143
+ np2tif(np_data, save_path='np2tif_dir', coord=coord, out_name='MetSat',
144
+ dtype='float32')
145
+ if return_coord:
146
+ return np_data, coord
147
+ else:
148
+ return np_data
149
+
150
+
151
+ def get_seviri_file_path(data_path, data_date):
152
+ e_date = data_date + timedelta(minutes=14, seconds=59)
153
+ parent_dir = rf'{data_path}/{data_date.year}/{data_date.month}/{data_date.day}/'
154
+ start_date_str = data_date.strftime('%Y%m%d%H%M%S')
155
+ end_date_str = e_date.strftime('%Y%m%d%H%M%S')
156
+ for file in os.listdir(parent_dir):
157
+ if file.endswith('.nat') and start_date_str <= file.split('-')[5].split('.')[0] <= end_date_str:
158
+ return os.path.join(parent_dir, file)
159
+ return None
160
+
161
+
162
+ if __name__ == '__main__':
163
+ np_data = getNPfromNAT("MSG4-SEVI-MSG15-0100-NA-20221230031243.610000000Z-NA.nat", save_file=False)
164
+ print()
@@ -0,0 +1,109 @@
1
+ import os.path
2
+
3
+ import torch
4
+ import numpy as np
5
+ from einops import rearrange
6
+
7
+
8
+ class PredNormalization:
9
+ def __init__(self, data_path):
10
+ self.mean = np.load(os.path.join(data_path, 'mean_level.npy')).astype(np.float32)[2:]
11
+ self.std = np.load(os.path.join(data_path, 'std_level.npy')).astype(np.float32)[2:]
12
+ self.mean = torch.from_numpy(self.mean)
13
+ self.std = torch.from_numpy(self.std)
14
+
15
+ def norm(self, data):
16
+ data = rearrange(data, 'b c h w->b h w c')
17
+ data = (data - self.mean) / self.std
18
+ return rearrange(data, 'b h w c->b c h w')
19
+
20
+ def denorm(self, data):
21
+ data = rearrange(data, 'b c h w->b h w c')
22
+ data = data * self.std + self.mean
23
+ return rearrange(data, 'b h w c->b c h w')
24
+
25
+
26
+ class PrecNormalization:
27
+ def __init__(self, data_path):
28
+ self.mean_qpe = torch.from_numpy(
29
+ np.load(os.path.join(data_path, 'mean_level_qpe.npy')).astype(np.float32))
30
+ self.std_qpe = torch.from_numpy(
31
+ np.load(os.path.join(data_path, 'std_level_qpe.npy')).astype(np.float32))
32
+ self.mean_fy = torch.from_numpy(
33
+ np.load(os.path.join(data_path, 'mean_level_fy.npy')).astype(np.float32).mean(axis=0)[2:])
34
+ self.std_fy = torch.from_numpy(
35
+ np.load(os.path.join(data_path, 'std_level_fy.npy')).astype(np.float32).mean(axis=0)[2:])
36
+
37
+ def norm(self, data, norm_type='fy'):
38
+ if norm_type == 'fy':
39
+ data = rearrange(data, 'b t c h w->b h w t c')
40
+ data = (data - self.mean_fy) / self.std_fy
41
+ return rearrange(data, 'b h w t c->b t c h w')
42
+ elif norm_type == 'qpe':
43
+ data = rearrange(data, 'b c h w->b h w c')
44
+ data = (data - self.mean_qpe) / self.std_qpe
45
+ return rearrange(data, 'b h w c->b c h w')
46
+
47
+ def denorm(self, data, norm_type='fy'):
48
+ if norm_type == 'fy':
49
+ data = rearrange(data, 'b t c h w->b h w t c')
50
+ data = data * self.std_fy + self.mean_fy
51
+ return rearrange(data, 'b h w t c->b t c h w')
52
+ elif norm_type == 'qpe':
53
+ data = rearrange(data, 'b c h w->b h w c')
54
+ data = data * self.std_qpe + self.mean_qpe
55
+ return rearrange(data, 'b h w c->b c h w')
56
+
57
+
58
+ class PremNormalization:
59
+
60
+ def __init__(self, prec_data_path):
61
+ self.mean = torch.from_numpy(
62
+ np.load(os.path.join(prec_data_path, 'imerg_mean.npy')).astype(np.float32))
63
+ self.std = torch.from_numpy(
64
+ np.load(os.path.join(prec_data_path, 'imerg_std.npy')).astype(np.float32))
65
+ # self.mean = torch.Tensor(
66
+ # [3.6614e+03, 3.3748e+03, 3.3829e+03, 2.8368e+03, 2.5664e+03, 2.4914e+03, 2.4259e+03, 0.5723, 0, 0])
67
+ # self.std = torch.Tensor([164.7376, 265.8857, 252.9820, 509.8994, 532.4901, 518.8191, 414.1427, 0.6308, 1, 1])
68
+ # print(self.mean, self.std)
69
+
70
+ def norm(self, data, fy_norm=True):
71
+ data = rearrange(data, 'b c h w->b h w c')
72
+ if fy_norm:
73
+ data = (data - self.mean[:7]) / self.std[:7]
74
+ else:
75
+ # data[:, 0][data[:, 0] == 0] = torch.inf
76
+ # data[:, 0] = 1 / torch.sqrt(data[:, 0].clone())
77
+ data = (data - self.mean[7:]) / self.std[7:]
78
+ return rearrange(data, 'b h w c->b c h w')
79
+
80
+ def denorm(self, data, fy_norm=True):
81
+ data = rearrange(data, 'b c h w->b h w c')
82
+ if fy_norm:
83
+ data = data * self.std[:7] + self.mean[:7]
84
+ else:
85
+ data = data * self.std[7:] + self.mean[7:]
86
+ # data[:, 0][data[:, 0] == 0] = torch.inf
87
+ # data[:, 0] = 1 / (data[:, 0].clone() ** 2)
88
+ return rearrange(data, 'b h w c->b c h w')
89
+
90
+
91
+ class Normalization:
92
+ def __init__(self, mean_std_npy, idx=None):
93
+ mean_std = torch.from_numpy(mean_std_npy.astype(np.float32))
94
+ if idx:
95
+ self.mean = mean_std[0, idx[0]:idx[1]]
96
+ self.std = mean_std[1, idx[0]:idx[1]]
97
+ else:
98
+ self.mean = mean_std[0]
99
+ self.std = mean_std[1]
100
+
101
+ def norm(self, data):
102
+ data = rearrange(data, 'b c h w->b h w c')
103
+ data = (data - self.mean) / self.std
104
+ return rearrange(data, 'b h w c->b c h w')
105
+
106
+ def denorm(self, data):
107
+ data = rearrange(data, 'b c h w->b h w c')
108
+ data = data * self.std + self.mean
109
+ return rearrange(data, 'b h w c->b c h w')
@@ -0,0 +1,300 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import cv2
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ import os
7
+ from jacksung.utils.image import crop_png, zoom_image, zoomAndDock
8
+ import rasterio
9
+ from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
10
+ from matplotlib.colors import LinearSegmentedColormap
11
+ import cartopy.feature as cfeature
12
+ import cartopy.crs as ccrs
13
+ import yaml
14
+ import argparse
15
+ from datetime import datetime, timedelta
16
+ from matplotlib.ticker import MaxNLocator
17
+ import random
18
+
19
+
20
+ def load_model(model, state_dict, strict=True):
21
+ own_state = model.state_dict()
22
+ for name, param in state_dict.items():
23
+ name = name[name.index('.') + 1:]
24
+ if name in own_state.keys():
25
+ if isinstance(param, nn.Parameter):
26
+ param = param.data
27
+ try:
28
+ own_state[name].copy_(param)
29
+ # own_state[name].requires_grad = False
30
+ except Exception as e:
31
+ err_log = f'While copying the parameter named {name}, ' \
32
+ f'whose dimensions in the model are {own_state[name].size()} and ' \
33
+ f'whose dimensions in the checkpoint are {param.size()}.'
34
+ if not strict:
35
+ print(err_log)
36
+ else:
37
+ raise Exception(err_log)
38
+ elif strict:
39
+ raise KeyError(f'unexpected key {name} in {own_state.keys()}')
40
+ else:
41
+ print(f'{name} not loaded by model')
42
+
43
+
44
+ def get_stat_dict(metrics, extra_info=None):
45
+ if extra_info is None:
46
+ extra_info = dict()
47
+ stat_dict = {'epochs': 0, 'loss': [], 'metrics': {}}
48
+ for idx, metrics in enumerate(metrics):
49
+ name, default_value, op = metrics
50
+ stat_dict['metrics'][name] = {'value': [], 'best': {'value': default_value, 'epoch': 0, 'op': op}}
51
+ for key, value in extra_info.items():
52
+ stat_dict[key] = value
53
+ return stat_dict
54
+
55
+
56
+ def data_to_device(datas, device, fp=32):
57
+ outs = []
58
+ for data in datas:
59
+ if fp == 16:
60
+ data = data.type(torch.HalfTensor)
61
+ if fp == 64:
62
+ data = data.type(torch.DoubleTensor)
63
+ if fp == 32:
64
+ data = data.type(torch.FloatTensor)
65
+ data = data.to(device)
66
+ outs.append(data)
67
+ return outs
68
+
69
+
70
+ def draw_lines(stat_dict_path):
71
+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
72
+ print('[TemporaryTag]Producing the LinePicture of Log...', end='[TemporaryTag]\n')
73
+ yaml_args = yaml.load(open(stat_dict_path), Loader=yaml.FullLoader)
74
+ # 创建图表
75
+ m_len = len(yaml_args['metrics'])
76
+ sub_loss_count = 0
77
+ for i in range(0, 5):
78
+ if f'loss{i}es' in yaml_args:
79
+ sub_loss_count += 1
80
+ else:
81
+ break
82
+ plt.figure(figsize=(8 * (m_len + sub_loss_count), 6)) # 设置图表的大小
83
+ x = np.array(range(1, yaml_args['epochs'] + 1))
84
+ for idx, d in enumerate(yaml_args['metrics'].items()):
85
+ m_name, m_value = d
86
+ y = np.array(m_value['value'])
87
+ # 生成数据
88
+ plt.subplot(1, m_len + sub_loss_count, idx + 1)
89
+ plt.plot(x, y)
90
+ plt.title(m_name)
91
+ plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))
92
+ for i in range(0, sub_loss_count):
93
+ y = np.array(yaml_args[f'loss{i}es'])
94
+ plt.subplot(1, m_len + sub_loss_count, m_len + i + 1)
95
+ scale = len(y) / yaml_args['epochs']
96
+ # 把X坐标的纯统计范围缩放到和其他图表一致(epoch)
97
+ x = np.array(range(1, len(y) + 1)) / scale
98
+ plt.plot(x, y)
99
+ plt.title(f'Loss{i}')
100
+ plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))
101
+ # 添加图例
102
+ # plt.legend()
103
+ plt.savefig(os.path.join(os.path.dirname(stat_dict_path), 'Metrics.jpg'))
104
+
105
+
106
+ def make_best_metric(stat_dict, metrics, epoch, save_model_param, server_log_param):
107
+ save_model_flag = False
108
+ experiment_model_path, model, optimizer, scheduler = save_model_param
109
+ log, epochs, cloudLogName = server_log_param
110
+
111
+ for name, m_value in metrics:
112
+ stat_dict['metrics'][name]['value'].append(m_value)
113
+ inf = float('inf')
114
+ if eval(str(m_value) + stat_dict['metrics'][name]['best']['op'] + str(
115
+ stat_dict['metrics'][name]['best']['value'])):
116
+ stat_dict['metrics'][name]['best']['value'] = m_value
117
+ stat_dict['metrics'][name]['best']['epoch'] = epoch
118
+ log.send_log('{}:{} epoch:{}/{}'.format(name, m_value, epoch, epochs), cloudLogName)
119
+ save_model_flag = True
120
+
121
+ if save_model_flag:
122
+ # sava best model
123
+ save_model(os.path.join(experiment_model_path, 'model_{}.pt'.format(epoch)), epoch,
124
+ model, optimizer, scheduler, stat_dict)
125
+ # '[Validation] nRMSE/RMSE: {:.4f}/{:.4f} (Best: {:.4f}/{:.4f}, Epoch: {}/{})\n'
126
+ test_log = f'[Val epoch:{epoch}] ' \
127
+ + ' '.join([str(m[0]) + ':' + str(round(m[1], 4)) + '('
128
+ + str(round(stat_dict['metrics'][m[0]]['best']['value'], 4)) + ')' for m in metrics]) \
129
+ + ' (Best Epoch: ' \
130
+ + '/'.join([str(stat_dict['metrics'][m[0]]['best']['epoch']) for m in metrics]) \
131
+ + ')'
132
+ save_model(os.path.join(experiment_model_path, 'model_latest.pt'), epoch, model, optimizer, scheduler, stat_dict)
133
+ return test_log
134
+
135
+
136
+ def save_model(_path, _epoch, _model, _optimizer, _scheduler, _stat_dict):
137
+ # torch.save(model.state_dict(), saved_model_path)
138
+ torch.save({
139
+ 'epoch': _epoch,
140
+ 'model_state_dict': _model.state_dict(),
141
+ 'optimizer_state_dict': _optimizer.state_dict(),
142
+ 'scheduler_state_dict': _scheduler.state_dict(),
143
+ 'stat_dict': _stat_dict
144
+ }, _path)
145
+
146
+
147
+ def parse_config(config=None, set_gpu=True):
148
+ parser = argparse.ArgumentParser(description='config')
149
+ parser.add_argument('--config', type=str, default=None, help='pre-config file for training')
150
+ parser.add_argument('--prec_data_path', type=str, default=None, help='dataset path')
151
+ args = parser.parse_args()
152
+ if args.config:
153
+ opt = vars(args)
154
+ yaml_args = yaml.load(open(args.config), Loader=yaml.FullLoader)
155
+ opt.update(yaml_args)
156
+ else:
157
+ opt = vars(args)
158
+ yaml_args = yaml.load(open(config), Loader=yaml.FullLoader)
159
+ opt.update(yaml_args)
160
+
161
+ if set_gpu:
162
+ # set visible gpu
163
+ os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
164
+ os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(x) for x in args.gpu_ids])
165
+
166
+ # select active gpu devices
167
+ if args.gpu_ids is not None and torch.cuda.is_available():
168
+ print('use cuda & cudnn for acceleration!')
169
+ print('the gpu id is: {}'.format(args.gpu_ids))
170
+ device = torch.device('cuda')
171
+ # device = torch.device('cuda:' + str(args.gpu_ids[0]))
172
+ else:
173
+ print('use cpu for training!')
174
+ device = torch.device('cpu')
175
+ return device, args
176
+ else:
177
+ return args
178
+
179
+
180
+ def _get_color_normalization(data, colors):
181
+ max_value = colors[-1][0]
182
+ min_value = colors[0][0]
183
+ data[data < min_value] = min_value
184
+ data[data > max_value] = max_value
185
+ data = (data - min_value) / (max_value - min_value)
186
+ new_colors = []
187
+ for color in colors:
188
+ new_colors.append([(color[0] - min_value) / (max_value - min_value), color[1]])
189
+ return data, new_colors
190
+
191
+
192
+ def data_augmentation(images):
193
+ # 旋转
194
+ rotate = random.random()
195
+ if 0 <= rotate < 0.25:
196
+ images = [torch.rot90(image, 1, [2, 3]) for image in images]
197
+ elif 0.25 <= rotate < 0.5:
198
+ images = [torch.rot90(image, 2, [2, 3]) for image in images]
199
+ elif 0.5 <= rotate < 0.75:
200
+ images = [torch.rot90(image, 3, [2, 3]) for image in images]
201
+ # 水平翻折
202
+ if random.random() > 0.5:
203
+ images = [torch.flip(image, [2]) for image in images]
204
+ # 垂直翻折
205
+ if random.random() > 0.5:
206
+ images = [torch.flip(image, [3]) for image in images]
207
+ return images
208
+
209
+
210
+ def make_fig(file_name, root_path, out_folder=None, tz='UTC',
211
+ colors=((0, '#1E90FF'), (0.1, '#1874CD'), (0.2, '#3A5FCD'), (0.3, '#0000CD'), (1, '#9400D3')),
212
+ area=((100, 140, 10), (20, 60, 10)), font_size=20, corp=(0, 0, None, None),
213
+ zoom_rectangle=(310 * 5, 300 * 5, 50 * 5, 40 * 5), docker=(300, 730), dpi=500, filter=0.3, exposure=None):
214
+ # corp = [92, 31, 542, 456]
215
+ print('过期的make_figure方法!')
216
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
217
+ extents = [100, 140, 20, 60]
218
+ proj = ccrs.PlateCarree()
219
+ fig = plt.figure(dpi=dpi)
220
+ ax = fig.add_subplot(111, projection=proj)
221
+ ax.set_extent(extents, crs=proj)
222
+ # 读取TIFF数据
223
+ elevation = None
224
+ if type(file_name) == list:
225
+ for each_file in file_name:
226
+ file_path = os.path.join(root_path, each_file)
227
+ with rasterio.open(file_path) as dataset:
228
+ el_rd = dataset.read(1)
229
+ elevation[elevation < filter] = np.nan
230
+ if elevation is None:
231
+ elevation = el_rd
232
+ else:
233
+ elevation += el_rd
234
+ else:
235
+ file_path = os.path.join(root_path, file_name)
236
+ with rasterio.open(file_path) as dataset:
237
+ elevation = dataset.read(1)
238
+ elevation[elevation <= filter] = np.nan
239
+ elevation, colors = _get_color_normalization(elevation, colors)
240
+ cmap = LinearSegmentedColormap.from_list('custom_cmap', colors)
241
+ # 添加各种特征
242
+ land = cfeature.NaturalEarthFeature('physical', 'land', '50m', edgecolor='black', facecolor='none',
243
+ linewidth=0.4)
244
+ ax.add_feature(land)
245
+ ax.imshow(elevation, origin='upper', extent=extents, transform=proj, cmap=cmap)
246
+ ax.add_feature(cfeature.OCEAN)
247
+ ax.add_feature(cfeature.LAND)
248
+ # ax.add_feature(cfeature.LAKES, edgecolor='black')
249
+ ax.add_feature(cfeature.RIVERS)
250
+ # ax.add_feature(cfeature.BORDERS)
251
+ # 添加网格线
252
+ # ax.gridlines(linestyle='--')
253
+ # 设置大刻度和小刻度
254
+ tick_proj = ccrs.PlateCarree()
255
+ ax.set_xticks(np.arange(area[0][0], area[0][1] + 1, area[0][2]), crs=tick_proj)
256
+ # ax.set_xticks(np.arange(-180, 180 + 30, 30), minor=True, crs=tick_proj)
257
+ ax.set_yticks(np.arange(area[1][0], area[1][1] + 1, area[1][2]), crs=tick_proj)
258
+ # ax.set_yticks(np.arange(-90, 90 + 15, 15), minor=True, crs=tick_proj)
259
+ # 利用Formatter格式化刻度标签
260
+ ax.xaxis.set_major_formatter(LongitudeFormatter())
261
+ ax.yaxis.set_major_formatter(LatitudeFormatter())
262
+ if out_folder is None:
263
+ if type(file_name) == list:
264
+ file_name = file_name[0]
265
+ file_dir = 'concate'
266
+ else:
267
+ file_dir = 'outs'
268
+ else:
269
+ file_dir = out_folder
270
+ os.makedirs(os.path.join(root_path, file_dir), exist_ok=True)
271
+ file_name = file_name.replace('.tif', '.png')
272
+ file_title = datetime.strptime(file_name.split('-')[0].replace('target_', ''), '%Y%m%d_%H%M%S')
273
+ file_name = file_title.strftime('%Y%m%d_%H%M%S.png')
274
+ exposure = exposure if exposure else (60 if file_dir == 'concate' else 15)
275
+ if tz == 'BJT':
276
+ exposure_end = (file_title + timedelta(hours=8) + timedelta(minutes=exposure)).strftime('%H:%M')
277
+ file_title = (file_title + timedelta(hours=8)).strftime('%Y-%m-%d %H:%M')
278
+ else:
279
+ exposure_end = (file_title + timedelta(minutes=exposure)).strftime('%H:%M')
280
+ file_title = file_title.strftime('%Y-%m-%d %H:%M')
281
+ ax.set_title(file_title + f'-' + exposure_end + f' ({tz})', fontsize=font_size)
282
+ plt.xticks(fontsize=font_size)
283
+ plt.yticks(fontsize=font_size)
284
+ # plt.title(fontsize=font_size)
285
+ file_save_path = os.path.join(root_path, file_dir, file_name)
286
+ plt.savefig(file_save_path)
287
+ if zoom_rectangle is not None:
288
+ read_png = cv2.imread(file_save_path)
289
+ read_png = zoomAndDock(read_png, zoom_rectangle, docker, scale_factor=5, border=14)
290
+ cv2.imwrite(file_save_path, read_png)
291
+ crop_png(file_save_path, left=corp[0], top=corp[1], right=corp[2], bottom=corp[3])
292
+ return file_save_path
293
+ # plt.show()
294
+
295
+
296
+ def clipSatelliteNP(np_data, ld, area=((100, 140, 10), (20, 60, 10))):
297
+ lon_d = int((ld - (area[0][0] + area[0][1]) / 2) * 20)
298
+ lat_d = int(((area[1][0] + area[1][1]) / 2) * 20)
299
+ np_data = np_data[:, 800 - lat_d:1600 - lat_d, 800 - lon_d:1600 - lon_d]
300
+ return np_data
Binary file
@@ -0,0 +1 @@
1
+ from . import *
@@ -0,0 +1,72 @@
1
+ import os
2
+ import time
3
+ import pymysql
4
+ from datetime import datetime, timedelta
5
+ from tqdm import tqdm
6
+
7
+ import threading
8
+ import multiprocessing
9
+ import traceback, sys
10
+
11
+ import configparser
12
+
13
+
14
+ def convert_str(s):
15
+ return "'" + str(s).replace("'", '"') + "'"
16
+
17
+
18
+ def convert_num(n):
19
+ return str(n)
20
+
21
+
22
+ class BaseDB:
23
+ def __init__(self, ini_path='db.ini'):
24
+ self.conn = None
25
+ self.db = None
26
+ self.passwd = None
27
+ self.user = None
28
+ self.host = None
29
+ self.port = None
30
+ self.autocommit = True
31
+ self.ini_path = ini_path
32
+ pymysql.install_as_MySQLdb()
33
+ self.lock_t = threading.Lock()
34
+
35
+ def read_config(self):
36
+ config = configparser.ConfigParser()
37
+ config.read(self.ini_path)
38
+ # Create the connection object
39
+ self.host = config['database']['host']
40
+ self.passwd = config['database']['password']
41
+ self.db = config['database']['database']
42
+ self.user = config['database'].get('user', 'root')
43
+ self.port = config['database'].getint('port', 3306)
44
+
45
+ def reconnect(self):
46
+ if not self.conn:
47
+ self.read_config()
48
+ # 打开数据库连接
49
+ connection = pymysql.connect(host=self.host, user=self.user, passwd=self.passwd, db=self.db, port=self.port,
50
+ autocommit=self.autocommit)
51
+ tqdm.write('Reconnected!')
52
+ return connection
53
+
54
+ def execute(self, sql):
55
+ self.lock_t.acquire()
56
+ cursor = None
57
+ try:
58
+ cursor = self.conn.cursor()
59
+ result = cursor.execute(sql)
60
+ except Exception as e:
61
+ self.conn = self.reconnect()
62
+ cursor = self.conn.cursor()
63
+ try:
64
+ result = cursor.execute(sql)
65
+ except Exception as e:
66
+ print(f'err SQL: {sql}')
67
+ raise e
68
+ finally:
69
+ if cursor:
70
+ cursor.close()
71
+ self.lock_t.release()
72
+ return result, cursor