jacksung-dev 0.0.4.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. jacksung/__init__.py +1 -0
  2. jacksung/ai/GeoAttX.py +356 -0
  3. jacksung/ai/GeoNet/__init__.py +0 -0
  4. jacksung/ai/GeoNet/m_block.py +393 -0
  5. jacksung/ai/GeoNet/m_blockV2.py +442 -0
  6. jacksung/ai/GeoNet/m_network.py +107 -0
  7. jacksung/ai/GeoNet/m_networkV2.py +91 -0
  8. jacksung/ai/__init__.py +0 -0
  9. jacksung/ai/latex_tool.py +199 -0
  10. jacksung/ai/metrics.py +181 -0
  11. jacksung/ai/utils/__init__.py +0 -0
  12. jacksung/ai/utils/cmorph.py +42 -0
  13. jacksung/ai/utils/data_parallelV2.py +90 -0
  14. jacksung/ai/utils/fy.py +333 -0
  15. jacksung/ai/utils/goes.py +161 -0
  16. jacksung/ai/utils/gsmap.py +24 -0
  17. jacksung/ai/utils/imerg.py +159 -0
  18. jacksung/ai/utils/metsat.py +164 -0
  19. jacksung/ai/utils/norm_util.py +109 -0
  20. jacksung/ai/utils/util.py +300 -0
  21. jacksung/libs/times.ttf +0 -0
  22. jacksung/utils/__init__.py +1 -0
  23. jacksung/utils/base_db.py +72 -0
  24. jacksung/utils/cache.py +71 -0
  25. jacksung/utils/data_convert.py +273 -0
  26. jacksung/utils/exception.py +27 -0
  27. jacksung/utils/fastnumpy.py +115 -0
  28. jacksung/utils/figure.py +251 -0
  29. jacksung/utils/hash.py +26 -0
  30. jacksung/utils/image.py +221 -0
  31. jacksung/utils/log.py +86 -0
  32. jacksung/utils/login.py +149 -0
  33. jacksung/utils/mean_std.py +66 -0
  34. jacksung/utils/multi_task.py +129 -0
  35. jacksung/utils/number.py +6 -0
  36. jacksung/utils/nvidia.py +140 -0
  37. jacksung/utils/time.py +87 -0
  38. jacksung/utils/web.py +63 -0
  39. jacksung_dev-0.0.4.15.dist-info/LICENSE +201 -0
  40. jacksung_dev-0.0.4.15.dist-info/METADATA +228 -0
  41. jacksung_dev-0.0.4.15.dist-info/RECORD +44 -0
  42. jacksung_dev-0.0.4.15.dist-info/WHEEL +5 -0
  43. jacksung_dev-0.0.4.15.dist-info/entry_points.txt +3 -0
  44. jacksung_dev-0.0.4.15.dist-info/top_level.txt +1 -0
@@ -0,0 +1,393 @@
1
+ import random
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+ import numpy as np
8
+
9
+
10
class ShiftConv2d0(nn.Module):
    """Masked 3x3 convolution (the "low-training-memory" shift-conv variant).

    The input channels are cut into ``n_div = 5`` groups of
    ``g = inp_channels // 5`` channels. For every output channel, each group
    may only use a single tap of the 3x3 kernel:
    group 0 -> (1, 2), group 1 -> (1, 0), group 2 -> (2, 1), group 3 -> (0, 1),
    and all remaining channels -> the centre tap (1, 1).
    The mask is multiplied into the dense weight on every forward pass.

    Args:
        inp_channels: number of input channels.
        out_channels: number of output channels.
        stride: convolution stride (padding is fixed at 1).
    """

    def __init__(self, inp_channels, out_channels, stride):
        super(ShiftConv2d0, self).__init__()
        self.inp_channels = inp_channels
        self.out_channels = out_channels
        self.n_div = 5
        self.stride = stride
        g = inp_channels // self.n_div

        # Only the weight/bias of this throwaway Conv2d are kept; it is used
        # for its default initialisation.
        dense = nn.Conv2d(inp_channels, out_channels, 3, 1, 1)

        taps = ((1, 2), (1, 0), (2, 1), (0, 1))
        mask = torch.zeros((out_channels, inp_channels, 3, 3))
        for group, (row, col) in enumerate(taps):
            mask[:, group * g:(group + 1) * g, row, col] = 1.0
        # Everything past the four shifted groups goes through the centre tap.
        mask[:, len(taps) * g:, 1, 1] = 1.0

        self.w = dense.weight
        self.b = dense.bias
        # Frozen parameter (requires_grad=False): listed in parameters() and
        # stored in the state_dict under the key 'm'.
        self.m = nn.Parameter(mask, requires_grad=False)

    def forward(self, x):
        masked_weight = self.w * self.m
        return F.conv2d(input=x, weight=masked_weight, bias=self.b, stride=self.stride, padding=1)
33
+
34
+
35
class ShiftConv2d1(nn.Module):
    """Fixed depthwise shift followed by a 1x1 conv ("fast-training-speed" variant).

    A frozen depthwise 3x3 kernel copies every input channel from exactly one
    tap. Channels are assigned to taps via ``random.shuffle`` of the channel
    indices, which uses Python's global ``random`` state. They are taken in
    groups of ``g = inp_channels // 5``: taps (1, 2), (1, 0), (2, 1), (0, 1),
    and the centre (1, 1) for all remaining channels. A learnable 1x1
    convolution then mixes the shifted channels.

    Args:
        inp_channels: number of input channels.
        out_channels: number of output channels.
        stride: stride of the depthwise shift (padding is fixed at 1).
    """

    def __init__(self, inp_channels, out_channels, stride):
        super(ShiftConv2d1, self).__init__()
        self.inp_channels = inp_channels
        self.out_channels = out_channels
        self.stride = stride
        self.weight = nn.Parameter(torch.zeros(inp_channels, 1, 3, 3), requires_grad=False)
        self.n_div = 5
        g = inp_channels // self.n_div

        order = list(range(inp_channels))
        random.shuffle(order)
        taps = (
            (1, 2),  # left
            (1, 0),  # right
            (2, 1),  # up
            (0, 1),  # down
        )
        for group, (row, col) in enumerate(taps):
            self.weight[order[group * g:(group + 1) * g], 0, row, col] = 1.0
        self.weight[order[4 * g:], 0, 1, 1] = 1.0  # identity
        self.conv1x1 = nn.Conv2d(inp_channels, out_channels, 1)

    def forward(self, x):
        shifted = F.conv2d(input=x, weight=self.weight, bias=None, stride=self.stride,
                           padding=1, groups=self.inp_channels)
        return self.conv1x1(shifted)
58
+
59
+
60
class ShiftConv2d(nn.Module):
    """Wrapper that selects one of several 2D convolution flavours.

    ``conv_type``:
        * ``'low-training-memory'`` -> ``ShiftConv2d0`` (masked 3x3 conv)
        * ``'fast-training-speed'`` -> ``ShiftConv2d1`` (fixed depthwise shift + 1x1 conv)
        * ``'common'``              -> plain 1x1 ``nn.Conv2d``
        * ``'conv3'`` (default)     -> plain 3x3 ``nn.Conv2d`` with padding 1

    The chosen layer is stored as ``self.shift_conv`` and applied in ``forward``.

    Raises:
        ValueError: if ``conv_type`` is none of the names above.
    """

    def __init__(self, inp_channels, out_channels, conv_type='conv3', stride=1):
        super(ShiftConv2d, self).__init__()
        self.inp_channels = inp_channels
        self.out_channels = out_channels
        self.conv_type = conv_type

        # Factories are lazy so that only the selected layer is constructed.
        factories = (
            ('low-training-memory', lambda: ShiftConv2d0(inp_channels, out_channels, stride=stride)),
            ('fast-training-speed', lambda: ShiftConv2d1(inp_channels, out_channels, stride=stride)),
            ('common', lambda: nn.Conv2d(inp_channels, out_channels, kernel_size=1, stride=stride)),
            ('conv3', lambda: nn.Conv2d(inp_channels, out_channels, kernel_size=3, stride=stride, padding=1)),
        )
        for name, make in factories:
            if conv_type == name:
                self.shift_conv = make()
                break
        else:
            raise ValueError('invalid type of shift-conv2d')

    def forward(self, x):
        return self.shift_conv(x)
80
+
81
+
82
class ACT(nn.Module):
    """Activation used throughout the file (SiLU, i.e. ``x * sigmoid(x)``)."""

    def __init__(self):
        super().__init__()
        self.act = nn.SiLU()

    def forward(self, x):
        activated = self.act(x)
        return activated
90
+
91
+
92
class DownBlock(nn.Module):
    """Learned 2D downsampling block.

    A convolution with ``kernel_size = stride = downscale`` (no padding)
    shrinks H and W by a factor of ``downscale`` (floor division) and
    multiplies the channel count by ``downscale``. It is followed by
    Norm -> ACT -> a 3x3 convolution at the widened channel count.
    """

    def __init__(self, c_lgan, downscale=2):
        super(DownBlock, self).__init__()
        c_out = c_lgan * downscale
        # Construction order (down, norm, act, conv) is kept unchanged.
        self.down = nn.Conv2d(c_lgan, c_out, kernel_size=downscale, stride=downscale)
        self.norm = Norm(c_out)
        self.act = ACT()
        self.conv = nn.Conv2d(c_out, c_out, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        return self.conv(self.act(self.norm(self.down(x))))
106
+
107
+
108
class UpBlock(nn.Module):
    """Counterpart of ``DownBlock``.

    Applies a 3x3 conv -> Norm -> ACT at ``c_lgan * downscale`` channels.
    A transposed convolution with ``kernel_size = stride = downscale`` then
    multiplies H and W by ``downscale`` and maps back to ``c_lgan`` channels.
    """

    def __init__(self, c_lgan, downscale=2):
        super(UpBlock, self).__init__()
        wide = c_lgan * downscale
        self.conv = nn.Conv2d(wide, wide, kernel_size=3, stride=1, padding=1)
        self.norm = Norm(wide)
        self.act = ACT()
        self.up = nn.ConvTranspose2d(wide, c_lgan, kernel_size=downscale, stride=downscale)

    def forward(self, x):
        features = self.act(self.norm(self.conv(x)))
        return self.up(features)
122
+
123
+
124
class Norm(nn.Module):
    """Normalisation wrapper used throughout the file: ``BatchNorm2d`` over ``c_in`` channels."""

    def __init__(self, c_in):
        super().__init__()
        self.norm = nn.BatchNorm2d(c_in)

    def forward(self, x):
        return self.norm(x)
133
+
134
+
135
class CubeEmbeding(nn.Module):
    """Embed a 5D ``(B, C, D, H, W)`` input into a 2D feature map.

    A ``Conv3d`` with kernel/stride ``(2, down_sample, down_sample)`` maps
    ``c_in`` -> ``c_lgan`` channels. Only depth index 0 of its output is kept,
    followed by Norm -> ACT -> ``ShiftConv2d`` (3x3 by default).

    NOTE(review): presumably the input depth D is 2, so the conv's output
    depth is exactly 1. If D > 2, the remaining depth slices are discarded
    silently. TODO confirm with callers.
    """

    def __init__(self, c_lgan, c_in, down_sample=1):
        super(CubeEmbeding, self).__init__()
        kernel = (2, down_sample, down_sample)
        self.embeding = nn.Conv3d(c_in, c_lgan, kernel_size=kernel, stride=kernel, padding=(0, 0, 0))
        self.norm = Norm(c_lgan)
        self.act = ACT()
        self.conv2 = ShiftConv2d(c_lgan, c_lgan)

    def forward(self, x):
        first_slice = self.embeding(x)[:, :, 0]
        return self.conv2(self.act(self.norm(first_slice)))
150
+
151
+
152
class CubeUnEmbeding(nn.Module):
    """Map ``c_lgan`` feature channels back to ``c_in`` output channels.

    The pipeline is 3x3 conv -> Norm -> ACT -> transposed conv (H, W x ``downscale``)
    -> 3x3 conv to ``c_in`` channels.
    """

    def __init__(self, c_lgan, c_in, downscale=2):
        super(CubeUnEmbeding, self).__init__()
        self.conv1 = nn.Conv2d(c_lgan, c_lgan, kernel_size=3, stride=1, padding=1)
        self.norm = Norm(c_lgan)
        self.act = ACT()
        self.up = nn.ConvTranspose2d(c_lgan, c_lgan, kernel_size=downscale, stride=downscale)
        self.conv2 = nn.Conv2d(c_lgan, c_in, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        hidden = self.act(self.norm(self.conv1(x)))
        return self.conv2(self.up(hidden))
168
+
169
+
170
class Head(nn.Module):
    """Strided encoder stem.

    While ``down_sample >= 2``, it appends a stride-2 ``ShiftConv2d`` that
    doubles the channel count, followed by ``Norm`` and ``ACT``. Each time,
    ``down_sample`` is halved with integer division. A final ``ShiftConv2d``
    then maps to ``c_lgan`` channels.

    NOTE(review): the number of stages is floor(log2(down_sample)), so the
    total spatial reduction is the largest power of two <= ``down_sample``.
    For example, 3 gives 2x and 6 gives 4x. If this is paired with ``Tail``
    using the same ``down_sample``, a non-power-of-two value will not
    round-trip, because ``Tail``'s PixelShuffle upsamples by exactly
    ``down_sample``.
    """

    def __init__(self, c_lgan, c_in, down_sample):
        super(Head, self).__init__()
        self.stage = nn.ModuleList()
        channels = c_in
        remaining = down_sample
        while remaining >= 2:
            self.stage.extend([
                ShiftConv2d(channels, channels * 2, stride=2),
                Norm(channels * 2),
                ACT(),
            ])
            channels *= 2
            remaining //= 2

        self.conv2 = ShiftConv2d(channels, c_lgan)

    def forward(self, x):
        out = x
        for layer in self.stage:
            out = layer(out)
        return self.conv2(out)
190
+
191
+
192
class Tail(nn.Module):
    """Decoder head that restores the output resolution.

    The pipeline is 3x3 conv -> Norm -> ACT -> 3x3 conv. When
    ``down_sample != 1``, the last conv emits ``c_in * down_sample * down_sample``
    channels and ``PixelShuffle(down_sample)`` turns them into ``c_in`` channels
    at ``down_sample`` times the spatial size. When ``down_sample == 1``, the
    last conv emits ``c_in`` channels directly and no shuffle is created.
    """

    def __init__(self, c_lgan, c_in, down_sample):
        super(Tail, self).__init__()
        self.down_sample = down_sample
        self.conv1 = nn.Conv2d(c_lgan, c_lgan, kernel_size=3, stride=1, padding=1)
        self.norm = Norm(c_lgan)
        self.act = ACT()
        upsampling = not (self.down_sample == 1)
        out_channels = c_in * down_sample * down_sample if upsampling else c_in
        self.conv2 = nn.Conv2d(c_lgan, out_channels, kernel_size=3, stride=1, padding=1)
        if upsampling:
            self.ps = nn.PixelShuffle(down_sample)

    def forward(self, x):
        y = self.act(self.norm(self.conv1(x)))
        y = self.conv2(y)
        return self.ps(y) if self.down_sample != 1 else y
216
+
217
+
218
class FD(nn.Module):
    """Two-layer convolutional feed-forward block.

    ``ShiftConv2d`` expands to ``inp_channels * exp_ratio`` channels, ``ACT``
    is applied, and a second ``ShiftConv2d`` projects to ``out_channels``.

    NOTE: ``act2`` is constructed but not applied in ``forward``, so the
    output has no final activation.
    """

    def __init__(self, inp_channels, out_channels, exp_ratio=4):
        super(FD, self).__init__()
        hidden = inp_channels * exp_ratio
        self.fc1 = ShiftConv2d(inp_channels, hidden)
        self.fc2 = ShiftConv2d(hidden, out_channels)
        self.act1 = ACT()
        self.act2 = ACT()

    def forward(self, x):
        return self.fc2(self.act1(self.fc1(x)))
235
+
236
+
237
class LGAB(nn.Module):
    """Local-global attention block.

    The input is projected into query features (``f_project_inp``,
    ``channels`` wide) and key/value features (``project_inp``,
    ``2 * channels`` wide). Both are split into ``split_part`` equal channel
    groups, and three attention branches run on the first three groups:

    1. cosine-similarity attention inside non-overlapping windows;
    2. the same, on features rolled by ``window_size // 2`` (shifted windows);
    3. axial "long-range" attention, first along W and then along H.

    The three results are concatenated along channels and projected back to
    ``channels`` by ``project_out``.

    Args:
        channels: input/output channel count. Because ``forward`` concatenates
            exactly three groups of ``channels / split_part`` channels and
            ``project_out`` expects ``channels``, ``split_part`` must be 3 and
            ``channels`` must be divisible by 3. Each query group
            (``channels / 3``) must also be divisible by ``num_heads``.
        window_size: side of the square attention windows. H and W must be
            multiples of it.
        num_heads: number of attention heads.
        split_part: number of channel groups (see ``channels``).
    """

    def __init__(self, channels, window_size=5, num_heads=8, split_part=3):
        super(LGAB, self).__init__()
        self.num_heads = num_heads
        self.window_size = window_size
        # Group sizes: each key/value group holds K followed by V (hence the 2x);
        # each query group holds Q.
        self.split_chns = [int(channels * 2 / split_part) for _ in range(split_part)]
        self.f_split_chns = [int(channels / split_part) for _ in range(split_part)]
        self.project_inp = ShiftConv2d(channels, channels * 2)
        self.f_project_inp = ShiftConv2d(channels, channels)
        self.project_out = ShiftConv2d(channels, channels)

        # Learnable per-head log-temperatures, initialised to log(10).
        # ``logit_scale`` serves the two window branches; ``lr_logit_scale``
        # serves the long-range branch.
        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((self.num_heads, 1, 1))), requires_grad=True)
        self.lr_logit_scale = nn.Parameter(torch.log(10 * torch.ones((self.num_heads, 1, 1))), requires_grad=True)

        # #########################################################################
        # NOTE: the continuous relative-position-bias members below (cpb_mlp,
        # relative_coords_table, relative_position_index) are built but never
        # read by ``wa`` or ``forward``, so cpb_mlp receives no gradient. They
        # are kept so the module's parameters and state_dict keys stay unchanged.
        self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
                                     nn.ReLU(inplace=True),
                                     nn.Linear(512, self.num_heads, bias=False))

        # get relative_coords_table
        relative_coords_h = torch.arange(-(self.window_size - 1), self.window_size, dtype=torch.float32)
        relative_coords_w = torch.arange(-(self.window_size - 1), self.window_size, dtype=torch.float32)
        relative_coords_table = (torch.stack(
            torch.meshgrid([relative_coords_h, relative_coords_w], indexing='ij'))
                                 .permute(1, 2, 0).contiguous().unsqueeze(0))  # 1, 2*Wh-1, 2*Ww-1, 2

        relative_coords_table[:, :, :, 0] /= (self.window_size - 1)
        relative_coords_table[:, :, :, 1] /= (self.window_size - 1)
        relative_coords_table *= 8  # normalize to -8, 8
        relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
            torch.abs(relative_coords_table) + 1.0) / np.log2(8)

        self.register_buffer("relative_coords_table", relative_coords_table)

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size)
        coords_w = torch.arange(self.window_size)
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij'))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size - 1
        relative_coords[:, :, 0] *= 2 * self.window_size - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)
        # #########################################################################

    def wa(self, f, x, wsize):
        """Cosine-similarity attention inside non-overlapping ``wsize x wsize`` windows.

        Args:
            f: query features, ``(B, num_heads * c, H, W)``.
            x: key/value features, ``(B, 2 * num_heads * c, H, W)``. The first
                half of the channels is K and the second half is V.
            wsize: window side. H and W must be multiples of it.

        Returns:
            Tensor of shape ``(B, num_heads * c, H, W)``.
        """
        h, w = x.shape[-2:]
        q = rearrange(
            f, 'b (head c) (h dh) (w dw) -> (b h w) head (dh dw) c',
            dh=wsize, dw=wsize, head=self.num_heads
        )
        k, v = rearrange(
            x, 'b (kv head c) (h dh) (w dw) -> kv (b h w) head (dh dw) c',
            kv=2, dh=wsize, dw=wsize, head=self.num_heads
        )
        atn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
        # Temperature exp(logit_scale), capped at 1 / 0.01 = 100.
        t = torch.tensor(1. / 0.01).to(atn.device)
        logit_scale = torch.clamp(self.logit_scale, max=torch.log(t)).exp()
        atn = (atn * logit_scale).softmax(dim=-1)
        y_ = atn @ v
        return rearrange(y_, '(b h w) head (dh dw) c-> b (head c) (h dh) (w dw)',
                         h=h // wsize, w=w // wsize, dh=wsize, dw=wsize, head=self.num_heads)

    def forward(self, x):
        b, _, h, _ = x.shape
        inp = x
        xs = torch.split(self.project_inp(inp), self.split_chns, dim=1)
        fs = torch.split(self.f_project_inp(inp), self.f_split_chns, dim=1)
        wsize = self.window_size
        ys = []

        # 1) window attention
        ys.append(self.wa(fs[0], xs[0], wsize))

        # 2) shifted window attention.
        # One shift magnitude is used for both the roll and the inverse roll,
        # so the branch output is re-aligned with the input. Writing the
        # forward shift as ``-wsize // 2`` parses as ``(-wsize) // 2`` (-3 for
        # wsize=5), which does not match a ``wsize // 2`` (2) roll-back for odd
        # window sizes.
        # No attention mask is applied, so windows that straddle the roll's
        # wrap-around mix pixels from opposite borders.
        shift = wsize // 2
        x_ = torch.roll(xs[1], shifts=(-shift, -shift), dims=(2, 3))
        f_ = torch.roll(fs[1], shifts=(-shift, -shift), dims=(2, 3))
        y_ = self.wa(f_, x_, wsize)
        ys.append(torch.roll(y_, shifts=(shift, shift), dims=(2, 3)))

        # 3) long-range axial attention
        # along W ("longitude"): every row attends over its W positions
        q = rearrange(fs[2], 'b (head c) h w -> (b h) head w c', head=self.num_heads)
        k, v = rearrange(xs[2], 'b (kv head c) h w -> kv (b h) head w c', kv=2, head=self.num_heads)
        atn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
        t = torch.tensor(1. / 0.01).to(atn.device)
        logit_scale = torch.clamp(self.lr_logit_scale, max=torch.log(t)).exp()
        atn = (atn * logit_scale).softmax(dim=-1)
        v = atn @ v
        # along H ("latitude"): same q/k and temperature, applied to the row-attended values
        q, k, v = (rearrange(q, '(b h) head w c -> (b w) head h c', h=h),
                   rearrange(k, '(b h) head w c -> (b w) head h c', h=h),
                   rearrange(v, '(b h) head w c -> (b w) head h c', h=h))
        atn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
        atn = (atn * logit_scale).softmax(dim=-1)
        v = atn @ v
        ys.append(rearrange(v, '(b w) head h c-> b (head c) h w', b=b))

        y = torch.cat(ys, dim=1)
        return self.project_out(y)
357
+
358
+
359
class FEB(nn.Module):
    """Feature extraction block.

    The input is downsampled with ``DownBlock`` (channels x2, H and W /2). The
    result goes through a residual ``LGAB`` branch (with ``Dropout2d(0.2)``
    and ``Norm``) and then a residual ``FD`` branch (with ``Norm``). It is
    upsampled back with ``UpBlock``, and the block input is added as a global
    residual.

    H and W must be even. H / 2 and W / 2 must be multiples of
    ``window_size``, as required by LGAB's window attention.
    """

    def __init__(self, inp_channels, exp_ratio=2, window_size=5, num_heads=8):
        super(FEB, self).__init__()
        self.exp_ratio = exp_ratio
        self.inp_channels = inp_channels
        self.down = DownBlock(inp_channels, downscale=2)
        self.up = UpBlock(inp_channels, downscale=2)

        inner = inp_channels * 2
        self.FD = FD(inp_channels=inner, out_channels=inner, exp_ratio=exp_ratio)
        self.LGAB = LGAB(channels=inner, window_size=window_size, num_heads=num_heads)
        self.norm1 = Norm(inner)
        self.norm2 = Norm(inner)
        self.drop = nn.Dropout2d(0.2)

    def forward(self, x):
        identity = x
        z = self.down(x)
        z = self.norm1(self.drop(self.LGAB(z))) + z
        z = self.norm2(self.FD(z)) + z
        return self.up(z) + identity
386
+
387
+
388
if __name__ == '__main__':
    # Smoke check: apply the same stride-2 3x3 conv five times to an
    # unbatched (C=5, 32, 32) tensor (32 -> 16 -> 8 -> 4 -> 2 -> 1).
    # NOTE(review): this prints the shape once, after the loop; confirm that
    # is the intended placement of the print.
    sample = torch.zeros((5, 32, 32))
    halve = nn.Conv2d(5, 5, kernel_size=3, stride=2, padding=1)
    for _ in range(5):
        sample = halve(sample)
    print(sample.shape)