jacksung-dev 0.0.4.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. jacksung/__init__.py +1 -0
  2. jacksung/ai/GeoAttX.py +356 -0
  3. jacksung/ai/GeoNet/__init__.py +0 -0
  4. jacksung/ai/GeoNet/m_block.py +393 -0
  5. jacksung/ai/GeoNet/m_blockV2.py +442 -0
  6. jacksung/ai/GeoNet/m_network.py +107 -0
  7. jacksung/ai/GeoNet/m_networkV2.py +91 -0
  8. jacksung/ai/__init__.py +0 -0
  9. jacksung/ai/latex_tool.py +199 -0
  10. jacksung/ai/metrics.py +181 -0
  11. jacksung/ai/utils/__init__.py +0 -0
  12. jacksung/ai/utils/cmorph.py +42 -0
  13. jacksung/ai/utils/data_parallelV2.py +90 -0
  14. jacksung/ai/utils/fy.py +333 -0
  15. jacksung/ai/utils/goes.py +161 -0
  16. jacksung/ai/utils/gsmap.py +24 -0
  17. jacksung/ai/utils/imerg.py +159 -0
  18. jacksung/ai/utils/metsat.py +164 -0
  19. jacksung/ai/utils/norm_util.py +109 -0
  20. jacksung/ai/utils/util.py +300 -0
  21. jacksung/libs/times.ttf +0 -0
  22. jacksung/utils/__init__.py +1 -0
  23. jacksung/utils/base_db.py +72 -0
  24. jacksung/utils/cache.py +71 -0
  25. jacksung/utils/data_convert.py +273 -0
  26. jacksung/utils/exception.py +27 -0
  27. jacksung/utils/fastnumpy.py +115 -0
  28. jacksung/utils/figure.py +251 -0
  29. jacksung/utils/hash.py +26 -0
  30. jacksung/utils/image.py +221 -0
  31. jacksung/utils/log.py +86 -0
  32. jacksung/utils/login.py +149 -0
  33. jacksung/utils/mean_std.py +66 -0
  34. jacksung/utils/multi_task.py +129 -0
  35. jacksung/utils/number.py +6 -0
  36. jacksung/utils/nvidia.py +140 -0
  37. jacksung/utils/time.py +87 -0
  38. jacksung/utils/web.py +63 -0
  39. jacksung_dev-0.0.4.15.dist-info/LICENSE +201 -0
  40. jacksung_dev-0.0.4.15.dist-info/METADATA +228 -0
  41. jacksung_dev-0.0.4.15.dist-info/RECORD +44 -0
  42. jacksung_dev-0.0.4.15.dist-info/WHEEL +5 -0
  43. jacksung_dev-0.0.4.15.dist-info/entry_points.txt +3 -0
  44. jacksung_dev-0.0.4.15.dist-info/top_level.txt +1 -0
jacksung/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from jacksung import *
jacksung/ai/GeoAttX.py ADDED
@@ -0,0 +1,356 @@
1
+ import os.path
2
+ import random
3
+ from datetime import timedelta
4
+
5
+ import torch
6
+ from jacksung.utils.time import RemainTime, Stopwatch, cur_timestamp_str
7
+ from jacksung.ai.utils.norm_util import PredNormalization, PrecNormalization, Normalization
8
+ import numpy as np
9
+ from jacksung.utils.data_convert import np2tif
10
+ from jacksung.utils.cache import Cache
11
+ from jacksung.ai.utils import fy, goes, metsat
12
+ from einops import rearrange
13
+ from jacksung.ai.utils.util import parse_config, data_to_device, clipSatelliteNP
14
+ from jacksung.ai.GeoNet.m_network import GeoNet
15
+ from jacksung.ai.GeoNet.m_networkV2 import GeoNet as GeoNetV2
16
+ from jacksung.utils.exception import NoFileException, NanNPException
17
+ import torch.nn as nn
18
+
19
# Sensor identifiers. Huayu compares ``args.sensor_type`` against AGRI, ABI and
# SEVIRI; AHI and FCI are declared here but not referenced elsewhere in this file.
AGRI = 'agri'
ABI = 'abi'
SEVIRI = 'seviri'
AHI = 'ahi'
FCI = 'fci'
24
+
25
+
26
class GeoAttX:
    """Common base for the inference wrappers defined in this module.

    Holds the parsed configuration, the compute device, the target area and
    the per-run output directory, and builds GeoNet / GeoNetV2 models from
    checkpoint files.
    """

    def __init__(self, config=None, root_path=None, task_type=None, area=((100, 140, 10), (20, 60, 10)), lock=None):
        """Parse ``config`` and choose the output directory for this run.

        Args:
            config: Passed unchanged to ``parse_config``, which yields the
                device and the argument namespace.
            root_path: Parent of the run directory; ``args.save_path`` is used
                when this is falsy.
            task_type: Task label used in the run directory name and, for
                version-1 models, forwarded as the ``task`` argument.
            area: Region descriptor that subclasses hand to the clipping and
                coordinate helpers (presumably two (start, end, step) triples
                -- TODO confirm against ``clipSatelliteNP``).
            lock: Stored on the instance; not used by this class.
        """
        self.root_path = None
        self.timestamp = None
        self.dir_name = None
        self.device, self.args = parse_config(config)
        self.task_type = task_type
        self.set_root_path(root_path)
        self.area = area
        self.lock = lock

    def get_root_path(self):
        """Return the full path of the current run directory."""
        return self.root_path

    def get_dir_name(self):
        """Return the name (last path component) of the current run directory."""
        return self.dir_name

    def _default_dir_name(self):
        """Build ``<task>-<model>-<timestamp>_<4 random digits>``.

        The task label falls back to ``args.task`` (the same fallback
        ``load_model`` uses) and finally to ``'unknown'``, so that a
        ``task_type`` of ``None`` -- the constructor default -- no longer
        raises ``TypeError`` during string concatenation.
        """
        task_label = self.task_type or getattr(self.args, 'task', None) or 'unknown'
        suffix = random.randint(1000, 9999)
        return f'{task_label}-{self.args.model}-{self.timestamp}_{suffix}'

    def set_root_path(self, root_path=None, dir_name=None):
        """Refresh the timestamp and (re)compute the run directory.

        Args:
            root_path: Parent directory; ``args.save_path`` when falsy.
            dir_name: Explicit directory name; a generated one is used when
                falsy.
        """
        self.timestamp = cur_timestamp_str()
        parent = root_path if root_path else self.args.save_path
        # Only generate (and consume a random number) when no name was given.
        self.dir_name = dir_name if dir_name else self._default_dir_name()
        self.root_path = os.path.join(parent, self.dir_name)

    def load_model(self, path, version=1, c_in=None):
        """Build a network, load the checkpoint at ``path`` and freeze it.

        Args:
            path: Checkpoint file containing a ``'model_state_dict'`` entry.
            version: ``1`` builds ``GeoNet``; any other value builds
                ``GeoNetV2``.
            c_in: Number of input channels. When omitted, ``args.c_in`` is
                used (previously the argument was ignored for version 1 and
                passed as ``None`` for version 2).

        Returns:
            The model moved to ``self.device``, in eval mode, with gradients
            disabled.
        """
        shared_kwargs = dict(window_sizes=self.args.window_sizes, n_lgab=self.args.n_lgab,
                             c_lgan=self.args.c_lgan, r_expand=self.args.r_expand,
                             down_sample=self.args.down_sample, num_heads=self.args.num_heads,
                             downstage=self.args.downstage)
        if version == 1:
            channels = self.args.c_in if c_in is None else c_in
            task = self.task_type if self.task_type else self.args.task
            model = GeoNet(c_in=channels, task=task, **shared_kwargs)
        else:
            channels = getattr(self.args, 'c_in', None) if c_in is None else c_in
            model = GeoNetV2(c_in=channels, **shared_kwargs)

        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        ckpt = torch.load(path, map_location=torch.device(self.device))
        model.load(ckpt['model_state_dict'])
        model = model.to(self.device)
        model = model.eval()
        model.requires_grad_(False)
        return model
68
+
69
+
70
class GeoAttX_I(GeoAttX):
    """Iterative forecaster that chains the 1-, 4- and 12-step GeoNet models.

    One model step corresponds to 15 minutes. Input files are looked up under
    ``<data_path>/downloaded_file/<year>/<month>/<day>/<file_name>``.
    """

    def __init__(self, data_path, x1_path, x4_path, x12_path, root_path=None, config='config_predict.yml',
                 area=((100, 140, 10), (20, 60, 10)), cache_size=1, lock=None, device=None):
        """Load the three checkpoints and the prediction normalisation.

        Args:
            data_path: Root directory that contains ``downloaded_file``.
            x1_path, x4_path, x12_path: Checkpoints of the models that advance
                1, 4 and 12 steps respectively.
            root_path, config, area, lock: Forwarded to ``GeoAttX``.
            cache_size: Size passed to ``Cache``.
            device: Overrides the device chosen by the configuration when not
                ``None``.
        """
        super().__init__(config=config, root_path=root_path, task_type='pred', area=area, lock=lock)
        if device is not None:
            self.device = device
        self.f, self.n, self.ys = None, None, None
        self.data_path = data_path
        self.x1 = self.load_model(x1_path)
        self.x4 = self.load_model(x4_path)
        self.x12 = self.load_model(x12_path)
        # No 48-step model is loaded; predict() skips step sizes without a model.
        self.norm = PredNormalization(self.args.pred_data_path)
        self.norm.mean, self.norm.std = data_to_device([self.norm.mean, self.norm.std], self.device, self.args.fp)
        self.ld = None
        self.cache = Cache(cache_size)

    def save(self, file_name, ys):
        """Write every prediction in ``ys`` as a GeoTIFF plus an ``info.log``.

        For every entry after the first, the matching observed file (if it
        exists on disk) is clipped and written as ``target_<time>``.

        Args:
            file_name: Name of the input file the predictions started from.
            ys: Mapping of datetime to array, as returned by ``predict``.

        Returns:
            The run directory the files were written to.
        """
        file_info = fy.prase_filename(file_name)
        file_ld = int(file_info["position"])
        # predict() normally sets self.ld; fall back to the value parsed from
        # the file name so save() also works when called on its own.
        ld = self.ld if self.ld is not None else file_ld
        band_names = [str(band) for band in range(9, 16)]
        for idx, (k, y) in enumerate(ys.items()):
            stamp = k.strftime("%Y%m%d_%H%M%S")
            coord = fy.getFY_coord_clip(self.area)
            np2tif(y, save_path=self.root_path, out_name=f'{stamp}', coord=coord,
                   dtype=np.float32, print_log=False, dim_value=[{'value': list(band_names)}])
            td = k - file_info['start']
            mins = td.days * 24 * 60 + td.seconds // 60
            target_filename = self.get_filename(file_name, mins)
            p_path = self.get_path_by_filename(target_filename)
            if idx >= 1 and os.path.exists(p_path):
                coord = fy.getFY_coord_clip(self.area)
                p_data = fy.getNPfromHDF(p_path)
                p_data = clipSatelliteNP(p_data, ld=ld, area=self.area)
                # Other code in this class treats a str result as a failed
                # clip, so reject both None and str here.
                if p_data is not None and not isinstance(p_data, str):
                    p_data = p_data[2:, :, :]
                    np2tif(p_data, save_path=self.root_path, out_name=f'target_{stamp}',
                           print_log=False, coord=coord, dtype=np.float32,
                           dim_value=[{'value': list(band_names)}])
        print(f'data saved in {self.root_path}')
        # With an empty ``ys`` nothing has been written yet, so the directory
        # may not exist.
        os.makedirs(self.root_path, exist_ok=True)
        # The log contains non-ASCII text; do not depend on the locale encoding.
        with open(os.path.join(self.root_path, 'info.log'), 'w', encoding='utf-8') as f:
            f.write(f'输入数据:{file_info["start"]} {file_info["position"]} {file_info["end"]}\n')
            for k in ys:
                f.write(f'预测:{k}\n')
        return self.root_path

    def get_filename(self, file_name, mins):
        """Return ``file_name`` with its start/end stamps shifted by ``mins`` minutes."""
        file_info = fy.prase_filename(file_name)
        fmt = "%Y%m%d%H%M%S"
        shift = timedelta(minutes=mins)
        old_stamp = f'{file_info["start"].strftime(fmt)}_{file_info["end"].strftime(fmt)}'
        new_stamp = f'{(file_info["start"] + shift).strftime(fmt)}_{(file_info["end"] + shift).strftime(fmt)}'
        return file_name.replace(old_stamp, new_stamp)

    def get_path_by_filename(self, file_name):
        """Return the on-disk path of ``file_name`` under ``data_path``."""
        start = file_info_start = fy.prase_filename(file_name)["start"]
        parts = [f'{self.data_path}', 'downloaded_file', f'{start.year}', f'{start.month}', f'{start.day}',
                 f'{file_name}']
        return os.sep.join(parts)

    def numpy2tensor(self, f_data):
        """Convert a (c, h, w) array into a (1, c, h, w) tensor on the device."""
        tensor = torch.from_numpy(f_data)
        tensor = data_to_device([tensor], self.device, self.args.fp)[0]
        return rearrange(tensor, '(b c) h w -> b c h w', b=1)

    def get_exist_by_filename_and_mins(self, file_name, mins):
        """Load the file ``mins`` minutes away from ``file_name`` as a tensor.

        Raises:
            NoFileException: The shifted file does not exist.
            NanNPException: Clipping did not produce an array.
        """
        f_path = self.get_path_by_filename(self.get_filename(file_name, mins))
        if not os.path.exists(f_path):
            raise NoFileException(f_path)
        f_data = fy.getNPfromHDF(f_path)
        f_data = clipSatelliteNP(f_data, ld=self.ld, area=self.area)
        # save() treats None as a failed clip; reject both None and str so the
        # failure surfaces as NanNPException instead of a TypeError.
        if f_data is None or isinstance(f_data, str):
            raise NanNPException(f_path)
        # Drop the first two channels; the remaining ones are labelled 9..15
        # when saved.
        return self.numpy2tensor(f_data[2:, :, :])

    def mean_std2Tensor(self, in_data, h, w):
        """Broadcast per-channel statistics of shape (1, 7) to (1, 7, h, w)."""
        in_data = rearrange(in_data, '(b h w) c->b c h w', h=1, w=1)
        return in_data.expand(1, 7, h, w)

    def secdOrderStd(self, t_data, mean0, std0, mean, std):
        """Re-standardise ``t_data`` from (mean0, std0) to (mean, std) per channel."""
        h, w = t_data.shape[2], t_data.shape[3]
        standardised = (t_data - self.mean_std2Tensor(mean0, h, w)) / self.mean_std2Tensor(std0, h, w)
        return standardised * self.mean_std2Tensor(std, h, w) + self.mean_std2Tensor(mean, h, w)

    def _plan_steps(self, total_steps, p_steps):
        """Split ``total_steps`` greedily into available step sizes.

        Only positive sizes for which a callable ``self.x<size>`` exists are
        used. The result is ordered smallest first, e.g. 7 steps with sizes
        (12, 4, 1) gives ``[1, 1, 1, 4]``.

        Raises:
            ValueError: The remaining steps cannot be covered by any available
                size (previously this case looped forever).
        """
        available = sorted(
            (size for size in p_steps if size > 0 and callable(getattr(self, f'x{size}', None))),
            reverse=True)
        plan = []
        remaining = total_steps
        while remaining > 0:
            chosen = next((size for size in available if size <= remaining), None)
            if chosen is None:
                raise ValueError(f'cannot cover {remaining} remaining step(s) with step sizes {available}')
            plan.append(chosen)
            remaining -= chosen
        plan.reverse()
        return plan

    def predict(self, file_name, step=360, p_steps=(48, 12, 4, 1), print_log=True):
        """Forecast ``step`` minutes ahead of the observation in ``file_name``.

        Args:
            file_name: Name of the starting observation file.
            step: Forecast length in minutes; rounded down to a multiple of 15.
            p_steps: Candidate model step sizes. Sizes without a loaded model
                (48 by default) are skipped.
            print_log: Print progress messages.

        Returns:
            Mapping of datetime to de-normalised array, including the starting
            observation; an empty dict when a required file is missing or
            invalid (the failure is appended to ``err.log``).
        """
        try:
            file_info = fy.prase_filename(file_name)
            self.ld = int(file_info["position"])
            step = step // 15
            if print_log:
                print(f'当前时刻:{file_info["start"]}\n预测长度:{step * 15}分钟')
            task_progress = self._plan_steps(step, p_steps)
            if print_log:
                print(f'正在预测:{file_info["start"] + timedelta(minutes=sum(task_progress) * 15)}...')
            n = self.get_exist_by_filename_and_mins(file_name, 0)
            now_date = file_info["start"]
            porcess_list = {now_date: n.detach().cpu().numpy()[0]}
            n = self.norm.norm(n)
            # Statistics of the normalised input; every prediction is rescaled
            # to them below.
            o_mean = torch.mean(n, dim=(2, 3))
            o_std = torch.std(n, dim=(2, 3))
            for p_step in task_progress:
                span = timedelta(minutes=15 * p_step)
                pre_date = now_date - span
                now_date += span
                if print_log:
                    print(f'正在预测 {now_date}...')
                if pre_date in porcess_list:
                    # Reuse an earlier (de-normalised) frame as history.
                    f = self.numpy2tensor(porcess_list[pre_date])
                else:
                    # total_seconds() keeps whole days, unlike ``.seconds``.
                    offset = int((file_info["start"] - pre_date).total_seconds() / 60)
                    f = self.get_exist_by_filename_and_mins(file_name, -offset)
                f = self.norm.norm(f)
                st = Stopwatch()
                y_ = getattr(self, f'x{p_step}')(f, n)
                if print_log:
                    print(f'预测耗时: {st.reset()} 秒')
                # Statistics of the raw prediction ...
                y_mean = torch.mean(y_, dim=(2, 3))
                y_std = torch.std(y_, dim=(2, 3))
                # ... used to re-standardise it to the input statistics.
                y_ = self.secdOrderStd(y_, y_mean, y_std, o_mean, o_std)
                del f
                n = y_
                porcess_list[now_date] = self.norm.denorm(y_).detach().cpu().numpy()[0]
        except (NoFileException, NanNPException) as e:
            os.makedirs(self.root_path, exist_ok=True)
            with open(os.path.join(self.root_path, 'err.log'), 'a') as f:
                filename = e.file_name.split(os.sep)[-1]
                file_info = fy.prase_filename(filename)
                f.write(f'{e.__class__}: {file_info["start"]} {e.file_name}\n')
            return {}
        return porcess_list
223
+
224
+
225
class GeoAttX_P(GeoAttX):
    """Single-model wrapper that retrieves QPE from satellite input."""

    def __init__(self, model_path, root_path=None, config='predict_qpe.yml', area=((100, 140, 10), (20, 60, 10)),
                 device=None):
        """Load the checkpoint at ``model_path``.

        Args:
            model_path: Checkpoint passed to ``load_model``.
            root_path, config, area: Forwarded to ``GeoAttX``.
            device: Overrides the device chosen by the configuration when not
                ``None``.
        """
        super().__init__(config=config, root_path=root_path, task_type='prec', area=area)
        if device is not None:
            self.device = device
        self.model = self.load_model(model_path)

    def save(self, y, save_name, info_log=True, print_log=True):
        """Write ``y`` as a GeoTIFF named ``save_name`` in the run directory.

        Args:
            y: Array to write.
            save_name: Output name passed to ``np2tif``.
            info_log: Also (over)write ``info.log`` in the run directory.
            print_log: Print where the data was saved.

        Returns:
            The run directory.
        """
        np2tif(y, save_path=self.root_path, out_name=save_name, coord=fy.getFY_coord_clip(self.area), dtype=np.float32,
               print_log=False, dim_value=[{'value': ['qpe']}])
        if print_log:
            print(f'data saved in {self.root_path}')
        if info_log:
            # The log contains non-ASCII text; do not depend on the locale encoding.
            with open(os.path.join(self.root_path, 'info.log'), 'w', encoding='utf-8') as f:
                f.write(f'QPE 反演:{save_name}\n')
        return self.root_path

    def predict(self, fy_npy):
        """Run the model on one input.

        Args:
            fy_npy: Path of a ``.npy`` file (``str`` or ``os.PathLike``) or a
                numpy array.

        Returns:
            The de-normalised output as a numpy array, or ``None`` when the
            file does not exist (the failure is appended to ``err.log``).

        Raises:
            TypeError: ``fy_npy`` is neither a path nor a numpy array.
        """
        try:
            if isinstance(fy_npy, (str, os.PathLike)):
                fy_npy = os.fspath(fy_npy)
                print(f'正在反演:{fy_npy}...')
                if not os.path.exists(fy_npy):
                    raise NoFileException(fy_npy)
                n_data = np.load(fy_npy)
            elif isinstance(fy_npy, np.ndarray):
                n_data = fy_npy
            else:
                # TypeError is still caught by callers that catch Exception.
                raise TypeError('输入数据类型错误,仅支持文件路径或numpy数组')
            n_data = torch.from_numpy(n_data)
            norm = PrecNormalization(self.args.prec_data_path)
            norm.mean_fy, norm.mean_qpe, norm.std_fy, norm.std_qpe = \
                data_to_device([norm.mean_fy, norm.mean_qpe, norm.std_fy, norm.std_qpe], self.device, self.args.fp)
            n_data = data_to_device([n_data], self.device, self.args.fp)[0]
            n_data = rearrange(n_data, '(b t c) h w -> b t c h w', b=1, t=1)
            n = norm.norm(n_data, norm_type='fy')[:, 0, :, :, :]
            # The same tensor is passed for both model inputs.
            y_ = self.model(n, n)
            return norm.denorm(y_, norm_type='qpe').detach().cpu().numpy()[0]
        except NoFileException as e:
            os.makedirs(self.root_path, exist_ok=True)
            with open(os.path.join(self.root_path, 'err.log'), 'a') as f:
                filename = e.file_name.split(os.sep)[-1]
                file_info = fy.prase_filename(filename)
                f.write(f'Not exist {file_info["start"]} {e.file_name}\n')
            return None
271
+
272
+
273
class Huayu(GeoAttX):
    """GeoNetV2 wrapper that retrieves IMERG-like output from satellite input.

    The sensor named by ``args.sensor_type`` selects the reader used in
    ``predict`` and the channel layout of the normalisation statistics.
    """

    def __init__(self, model_path, root_path=None, config='predict_imerg.yml', area=((100, 140, 10), (20, 60, 10)),
                 device=None):
        """Select the sensor layout and load the version-2 model.

        Args:
            model_path: Checkpoint passed to ``load_model``.
            root_path, config, area: Forwarded to ``GeoAttX``.
            device: Overrides the device chosen by the configuration when not
                ``None``.

        Raises:
            ValueError: ``args.sensor_type`` is not AGRI, ABI or SEVIRI.
        """
        super().__init__(config=config, root_path=root_path, task_type='prem', area=area)
        if device is not None:
            self.device = device
        if self.args.sensor_type == AGRI:
            self.sensor = AGRI
            self.satellite_num = 2
            self.satellite_channel = 7
        elif self.args.sensor_type == ABI:
            self.sensor = ABI
            self.satellite_num = 3
            self.satellite_channel = 9
        elif self.args.sensor_type == SEVIRI:
            self.sensor = SEVIRI
            self.satellite_num = 2
            self.satellite_channel = 7
        else:
            raise ValueError(f'Unknown sensor type{self.args.sensor_type}')
        self.model = self.load_model(model_path, version=2, c_in=self.satellite_channel)

    def save(self, y, save_name, info_log=True, print_log=True):
        """Write ``y`` as a GeoTIFF named ``save_name`` in the run directory.

        Args:
            y: Array to write.
            save_name: Output name passed to ``np2tif``.
            info_log: Also (over)write ``info.log`` in the run directory.
            print_log: Print where the data was saved.

        Returns:
            The run directory.
        """
        np2tif(y, save_path=self.root_path, out_name=save_name, coord=fy.getFY_coord_clip(self.area), dtype=np.float32,
               print_log=False, dim_value=[{'value': ['imerg']}])
        if print_log:
            print(f'data saved in {self.root_path}')
        if info_log:
            # The log contains non-ASCII text; do not depend on the locale encoding.
            with open(os.path.join(self.root_path, 'info.log'), 'w', encoding='utf-8') as f:
                f.write(f'Imerg 反演:{save_name}\n')
        return self.root_path

    def predict(self, satellite_file, npy_path=None, smooth=True, up=True, area=None, satellite_date=None):
        """Run the model on one satellite scene.

        Args:
            satellite_file: Input handed to the sensor-specific reader; ignored
                when ``npy_path`` is given.
            npy_path: Optional ``.npy`` file loaded instead of reading
                ``satellite_file``.
            smooth: Replace the interior of the result with a 3x3 average.
            up: Pixel-shuffle the four sub-grid outputs back to the input
                resolution; when false the sub-grids are averaged before the
                model runs.
            area: Clip area; ``self.area`` when ``None``.
            satellite_date: Forwarded to the ABI reader only.

        Returns:
            A numpy array of shape (1, H, W), or ``None`` when a
            ``NoFileException`` was raised (appended to ``err.log``).
        """
        try:
            if npy_path is None:
                if self.sensor == AGRI:
                    n_data, coord = fy.getNPfromHDF(satellite_file, return_coord=True)
                elif self.sensor == ABI:
                    n_data, coord = goes.getNPfromDir(satellite_file, satellite_date, return_coord=True)
                elif self.sensor == SEVIRI:
                    n_data, coord = metsat.getNPfromNAT(satellite_file, return_coord=True)
                else:
                    raise ValueError(f'Unknown sensor type{self.args.sensor_type}')
                n_data = clipSatelliteNP(n_data, coord.ld, self.area if area is None else area)
            else:
                n_data = np.load(npy_path)
            mean_std_npy = np.load(os.path.join(rf'{self.args.data_path}', 'mean_std.npy'))
            satellite_norm = Normalization(mean_std_npy, (0, self.satellite_channel))
            # The three output statistics follow the satellite channels of all
            # input frames.
            imerg_start = self.satellite_channel * self.satellite_num
            imerg_norm = Normalization(mean_std_npy, (imerg_start, imerg_start + 3))
            satellite_norm.mean, satellite_norm.std, imerg_norm.mean, imerg_norm.std = \
                data_to_device([satellite_norm.mean, satellite_norm.std, imerg_norm.mean, imerg_norm.std],
                               self.device, self.args.fp)
            # NOTE(review): unlike GeoAttX_P.predict there is no
            # torch.from_numpy here -- confirm data_to_device accepts arrays.
            n_data = data_to_device([n_data], self.device, self.args.fp)[0]
            n_data = rearrange(n_data, '(b c) h w -> b c h w', b=1)
            n = satellite_norm.norm(n_data)[:, :, :, :]
            ps = nn.PixelShuffle(2)
            ups = nn.PixelUnshuffle(2)
            # Keep the flag and the pooling module in separate names.
            smoother = nn.AvgPool2d(kernel_size=3, stride=1, padding=1) if smooth else None
            # Split the scene into four interleaved half-resolution sub-grids.
            n = ups(n)
            n = rearrange(n, 'b (c dsize) h w -> (b dsize) c h w', dsize=4)
            if not up:
                n = n.mean(dim=0, keepdim=True)
            y_ = self.model(n)
            y_ = rearrange(y_, '(b dsize) c h w -> b (c dsize) h w', b=1)
            if up:
                y_ = ps(y_)
            y = imerg_norm.denorm(y_)[0]
            # Zero the first channel where channel 1 exceeds channel 2, then
            # clamp negatives.
            y[0][y[1] > y[2]] = 0
            y[0][y[0] < 0] = 0
            y = rearrange(y[0], '(b h) w -> b h w', b=1)
            _, H, W = y.shape
            if smoother is not None:
                # Border pixels keep their unsmoothed values.
                y[0, 1:H - 1, 1:W - 1] = smoother(y)[0, 1:H - 1, 1:W - 1]
            return y.detach().cpu().numpy()
        except NoFileException as e:
            os.makedirs(self.root_path, exist_ok=True)
            with open(os.path.join(self.root_path, 'err.log'), 'a') as f:
                filename = e.file_name.split(os.sep)[-1]
                file_info = fy.prase_filename(filename)
                f.write(f'Not exist {file_info["start"]} {e.file_name}\n')
            return None
File without changes