jacksung-dev 0.0.4.15 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jacksung/__init__.py +1 -0
- jacksung/ai/GeoAttX.py +356 -0
- jacksung/ai/GeoNet/__init__.py +0 -0
- jacksung/ai/GeoNet/m_block.py +393 -0
- jacksung/ai/GeoNet/m_blockV2.py +442 -0
- jacksung/ai/GeoNet/m_network.py +107 -0
- jacksung/ai/GeoNet/m_networkV2.py +91 -0
- jacksung/ai/__init__.py +0 -0
- jacksung/ai/latex_tool.py +199 -0
- jacksung/ai/metrics.py +181 -0
- jacksung/ai/utils/__init__.py +0 -0
- jacksung/ai/utils/cmorph.py +42 -0
- jacksung/ai/utils/data_parallelV2.py +90 -0
- jacksung/ai/utils/fy.py +333 -0
- jacksung/ai/utils/goes.py +161 -0
- jacksung/ai/utils/gsmap.py +24 -0
- jacksung/ai/utils/imerg.py +159 -0
- jacksung/ai/utils/metsat.py +164 -0
- jacksung/ai/utils/norm_util.py +109 -0
- jacksung/ai/utils/util.py +300 -0
- jacksung/libs/times.ttf +0 -0
- jacksung/utils/__init__.py +1 -0
- jacksung/utils/base_db.py +72 -0
- jacksung/utils/cache.py +71 -0
- jacksung/utils/data_convert.py +273 -0
- jacksung/utils/exception.py +27 -0
- jacksung/utils/fastnumpy.py +115 -0
- jacksung/utils/figure.py +251 -0
- jacksung/utils/hash.py +26 -0
- jacksung/utils/image.py +221 -0
- jacksung/utils/log.py +86 -0
- jacksung/utils/login.py +149 -0
- jacksung/utils/mean_std.py +66 -0
- jacksung/utils/multi_task.py +129 -0
- jacksung/utils/number.py +6 -0
- jacksung/utils/nvidia.py +140 -0
- jacksung/utils/time.py +87 -0
- jacksung/utils/web.py +63 -0
- jacksung_dev-0.0.4.15.dist-info/LICENSE +201 -0
- jacksung_dev-0.0.4.15.dist-info/METADATA +228 -0
- jacksung_dev-0.0.4.15.dist-info/RECORD +44 -0
- jacksung_dev-0.0.4.15.dist-info/WHEEL +5 -0
- jacksung_dev-0.0.4.15.dist-info/entry_points.txt +3 -0
- jacksung_dev-0.0.4.15.dist-info/top_level.txt +1 -0
jacksung/ai/GeoNet/m_blockV2.py
ADDED
@@ -0,0 +1,442 @@
+import math
+import random
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+import numpy as np
+
+
+class ShiftConv2d0(nn.Module):
+    def __init__(self, inp_channels, out_channels, stride):
+        super(ShiftConv2d0, self).__init__()
+        self.inp_channels = inp_channels
+        self.out_channels = out_channels
+        self.n_div = 5
+        self.stride = stride
+        g = inp_channels // self.n_div
+
+        conv3x3 = nn.Conv2d(inp_channels, out_channels, 3, 1, 1)
+        mask = nn.Parameter(torch.zeros((self.out_channels, self.inp_channels, 3, 3)), requires_grad=False)
+        mask[:, 0 * g:1 * g, 1, 2] = 1.0
+        mask[:, 1 * g:2 * g, 1, 0] = 1.0
+        mask[:, 2 * g:3 * g, 2, 1] = 1.0
+        mask[:, 3 * g:4 * g, 0, 1] = 1.0
+        mask[:, 4 * g:, 1, 1] = 1.0
+        self.w = conv3x3.weight
+        self.b = conv3x3.bias
+        self.m = mask
+
+    def forward(self, x):
+        y = F.conv2d(input=x, weight=self.w * self.m, bias=self.b, stride=self.stride, padding=1)
+        return y
+
+
+class ShiftConv2d1(nn.Module):
+    def __init__(self, inp_channels, out_channels, stride):
+        super(ShiftConv2d1, self).__init__()
+        self.inp_channels = inp_channels
+        self.out_channels = out_channels
+        self.stride = stride
+        self.weight = nn.Parameter(torch.zeros(inp_channels, 1, 3, 3), requires_grad=False)
+        self.n_div = 5
+        g = inp_channels // self.n_div
+
+        channels_idx = list(range(inp_channels))
+        random.shuffle(channels_idx)
+        self.weight[channels_idx[0 * g:1 * g], 0, 1, 2] = 1.0  ## left
+        self.weight[channels_idx[1 * g:2 * g], 0, 1, 0] = 1.0  ## right
+        self.weight[channels_idx[2 * g:3 * g], 0, 2, 1] = 1.0  ## up
+        self.weight[channels_idx[3 * g:4 * g], 0, 0, 1] = 1.0  ## down
+        self.weight[channels_idx[4 * g:], 0, 1, 1] = 1.0  ## identity
+        self.conv1x1 = nn.Conv2d(inp_channels, out_channels, 1)
+
+    def forward(self, x):
+        y = F.conv2d(input=x, weight=self.weight, bias=None, stride=self.stride, padding=1, groups=self.inp_channels)
+        y = self.conv1x1(y)
+        return y
+
+
+class ShiftConv2d(nn.Module):
+    def __init__(self, inp_channels, out_channels, conv_type='conv3', stride=1):
+        super(ShiftConv2d, self).__init__()
+        self.inp_channels = inp_channels
+        self.out_channels = out_channels
+        self.conv_type = conv_type
+        if conv_type == 'low-training-memory':
+            self.shift_conv = ShiftConv2d0(inp_channels, out_channels, stride=stride)
+        elif conv_type == 'fast-training-speed':
+            self.shift_conv = ShiftConv2d1(inp_channels, out_channels, stride=stride)
+        elif conv_type == 'common':
+            self.shift_conv = nn.Conv2d(inp_channels, out_channels, kernel_size=1, stride=stride)
+        elif conv_type == 'conv3':
+            self.shift_conv = nn.Conv2d(inp_channels, out_channels, kernel_size=3, stride=stride, padding=1)
+        else:
+            raise ValueError('invalid type of shift-conv2d')
+
+    def forward(self, x):
+        y = self.shift_conv(x)
+        return y
+
+
+class ACT(nn.Module):
+    def __init__(self):
+        super(ACT, self).__init__()
+        # self.act = nn.Mish()
+        self.act = nn.SiLU()
+
+    def forward(self, x):
+        return self.act(x)
+
+
+class DownBlock(nn.Module):
+    def __init__(self, c_lgan, downscale=2):
+        super(DownBlock, self).__init__()
+        self.down = nn.Conv2d(c_lgan, c_lgan * downscale, kernel_size=downscale, stride=downscale)
+        self.norm = Norm(c_lgan * downscale)
+        self.act = ACT()
+        self.conv = nn.Conv2d(c_lgan * downscale, c_lgan * downscale, kernel_size=3, stride=1, padding=1)
+
+    def forward(self, x):
+        x = self.down(x)
+        x = self.norm(x)
+        x = self.act(x)
+        x = self.conv(x)
+        return x
+
+
+class UpBlock(nn.Module):
+    def __init__(self, c_lgan, downscale=2):
+        super(UpBlock, self).__init__()
+        self.conv = nn.Conv2d(c_lgan * downscale, c_lgan * downscale, kernel_size=3, stride=1, padding=1)
+        self.norm = Norm(c_lgan * downscale)
+        self.act = ACT()
+        self.up = nn.ConvTranspose2d(c_lgan * downscale, c_lgan, kernel_size=downscale, stride=downscale)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.norm(x)
+        x = self.act(x)
+        x = self.up(x)
+        return x
+
+
+class Norm(nn.Module):
+    def __init__(self, c_in):
+        super(Norm, self).__init__()
+        self.norm = nn.BatchNorm2d(c_in)
+        # self.norm = nn.GroupNorm(1, c_in, eps=1e-6, affine=True)
+
+    def forward(self, x):
+        x = self.norm(x)
+        return x
+
+    # def __init__(self, c_in):
+    #     super(Norm, self).__init__()
+    #     self.norm = nn.LayerNorm(c_in, elementwise_affine=False)
+    #
+    # def forward(self, x):
+    #     b, c, h, w = x.shape
+    #     x = rearrange(x, 'b c h w->b (h w) c')
+    #     x = self.norm(x)
+    #     x = rearrange(x, 'b (h w) c->b c h w', w=w, h=h)
+    #     return x
+
+
+class CubeEmbeding(nn.Module):
+    def __init__(self, c_lgan, c_in, down_sample=1):
+        super(CubeEmbeding, self).__init__()
+        self.embeding = nn.Conv3d(c_in, c_lgan, kernel_size=(2, down_sample, down_sample),
+                                  stride=(2, down_sample, down_sample), padding=(0, 0, 0))
+        self.norm = Norm(c_lgan)
+        self.act = ACT()
+        self.conv2 = ShiftConv2d(c_lgan, c_lgan)
+
+    def forward(self, x):
+        x = self.embeding(x)
+        x = self.norm(x[:, :, 0, :, :])
+        x = self.act(x)
+        x = self.conv2(x)
+        return x
+
+
+class CubeUnEmbeding(nn.Module):
+    def __init__(self, c_lgan, c_in, downscale=2):
+        super(CubeUnEmbeding, self).__init__()
+        self.conv1 = nn.Conv2d(c_lgan, c_lgan, kernel_size=3, stride=1, padding=1)
+        self.norm = Norm(c_lgan)
+        self.act = ACT()
+        self.up = nn.ConvTranspose2d(c_lgan, c_lgan, kernel_size=downscale, stride=downscale)
+        self.conv2 = nn.Conv2d(c_lgan, c_in, kernel_size=3, stride=1, padding=1)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.norm(x)
+        x = self.act(x)
+        x = self.up(x)
+        x = self.conv2(x)
+        return x
+
+
+class Head(nn.Module):
+    def __init__(self, c_lgan, c_in, down_sample):
+        super(Head, self).__init__()
+        self.stage = nn.ModuleList()
+        while down_sample >= 2:
+            self.stage.append(
+                ShiftConv2d(c_in, c_in * 2, stride=2))
+            c_in *= 2
+            self.stage.append(Norm(c_in))
+            self.stage.append(ACT())
+            down_sample = down_sample // 2
+
+        self.conv2 = ShiftConv2d(c_in, c_lgan)
+
+    def forward(self, x):
+        # body
+        for stage in self.stage:
+            x = stage(x)
+        x = self.conv2(x)
+        return x
+
+
+class Tail(nn.Module):
+    def __init__(self, c_lgan, c_in, down_sample):
+        super(Tail, self).__init__()
+        self.down_sample = down_sample
+        self.conv1 = nn.Conv2d(c_lgan, c_lgan, kernel_size=3, stride=1, padding=1)
+        # self.norm = nn.BatchNorm2d(c_in * down_sample * down_sample)
+        self.norm = Norm(c_lgan)
+        self.act = ACT()
+        if self.down_sample == 1:
+            self.conv2 = nn.Conv2d(c_lgan, c_in, kernel_size=3,
+                                   stride=1, padding=1)
+        else:
+            self.conv2 = nn.Conv2d(c_lgan, c_in * down_sample * down_sample, kernel_size=3,
+                                   stride=1, padding=1)
+            self.ps = nn.PixelShuffle(down_sample)
+        # self.conv2 = nn.Conv2d(c_lgan, c_lgan * down_sample * down_sample, kernel_size=3,
+        #                        stride=1, padding=1)
+        # self.ps = nn.PixelShuffle(down_sample)
+        # self.conv3 = nn.Conv2d(c_lgan, c_in, kernel_size=3,
+        #                        stride=1, padding=1)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.norm(x)
+        x = self.act(x)
+        x = self.conv2(x)
+        if self.down_sample != 1:
+            x = self.ps(x)
+        # x = self.conv3(x)
+        return x
+
+
+# class Tail(nn.Module):
+#     def __init__(self, c_lgan, c_in, down_sample):
+#         super(Tail, self).__init__()
+#         self.c_in = c_in
+#         self.down_sample = down_sample
+#         self.fc = nn.Linear(c_lgan, c_in * down_sample * down_sample)
+#         self.norm = nn.LayerNorm(c_lgan)
+
+#     def forward(self, x):
+#         b, c, h, w = x.shape
+#         x = rearrange(x, 'b c h w->b (h w) c')
+#         x = self.norm(x)
+#         x = self.fc(x)
+#         x = rearrange(x, 'b (h w) (c d1 d2)->b c (d1 h) (d2 w)',
+#                       c=self.c_in, d1=self.down_sample, d2=self.down_sample, h=h, w=w)
+#         x = nn.functional.interpolate(x, size=[721, 1440], mode='bilinear')
+#         return x
+
+
+class FD(nn.Module):
+    def __init__(self, inp_channels, out_channels, exp_ratio=4):
+        super(FD, self).__init__()
+        # self.fc1 = MLP(inp_channels, inp_channels * exp_ratio)
+        # self.fc2 = MLP(inp_channels * exp_ratio, out_channels)
+        self.fc1 = ShiftConv2d(inp_channels, inp_channels * exp_ratio)
+        self.fc2 = ShiftConv2d(inp_channels * exp_ratio, out_channels)
+        # self.fc1 = nn.Conv2d(inp_channels, inp_channels * exp_ratio, kernel_size=3, stride=1, padding=1)
+        # self.fc2 = nn.Conv2d(inp_channels * exp_ratio, out_channels, kernel_size=3, stride=1, padding=1)
+        self.act1 = ACT()
+        self.act2 = ACT()
+
+    def forward(self, x):
+        y = self.fc1(x)
+        y = self.act1(y)
+        y = self.fc2(y)
+        return y
+
+
+class LGAB(nn.Module):
+
+    def __init__(self, channels, window_size=5, num_heads=8, split_part=3):
+        super(LGAB, self).__init__()
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.split_chns = [int(channels * 2 / split_part) for _ in range(split_part)]
+        self.f_split_chns = [int(channels / split_part) for _ in range(split_part)]
+        # self.project_inp = MLP(channels, channels * 3)
+        # self.project_out = MLP(channels, channels)
+        self.project_inp = ShiftConv2d(channels, channels * 2)
+        self.f_project_inp = ShiftConv2d(channels, channels)
+        self.project_out = ShiftConv2d(channels, channels)
+        # self.f_project_inp = nn.Conv2d(channels, channels, kernel_size=1, stride=1)
+        # self.project_inp = nn.Conv2d(channels, channels * 2, kernel_size=1, stride=1)
+        # self.project_out = nn.Conv2d(channels, channels, kernel_size=1, stride=1)
+
+        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((self.num_heads, 1, 1))), requires_grad=True)
+        self.lr_logit_scale = nn.Parameter(torch.log(10 * torch.ones((self.num_heads, 1, 1))), requires_grad=True)
+        # #########################################################################
+        # mlp to generate continuous relative position bias
+        self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
+                                     nn.ReLU(inplace=True),
+                                     nn.Linear(512, self.num_heads, bias=False))
+
+        # get relative_coords_table
+        relative_coords_h = torch.arange(-(self.window_size - 1), self.window_size, dtype=torch.float32)
+        relative_coords_w = torch.arange(-(self.window_size - 1), self.window_size, dtype=torch.float32)
+        relative_coords_table = (torch.stack(
+            torch.meshgrid([relative_coords_h, relative_coords_w], indexing='ij'))
+            .permute(1, 2, 0).contiguous().unsqueeze(0))  # 1, 2*Wh-1, 2*Ww-1, 2
+
+        relative_coords_table[:, :, :, 0] /= (self.window_size - 1)
+        relative_coords_table[:, :, :, 1] /= (self.window_size - 1)
+        relative_coords_table *= 8  # normalize to -8, 8
+        relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
+            torch.abs(relative_coords_table) + 1.0) / np.log2(8)
+
+        self.register_buffer("relative_coords_table", relative_coords_table)
+
+        # get pair-wise relative position index for each token inside the window
+        coords_h = torch.arange(self.window_size)
+        coords_w = torch.arange(self.window_size)
+        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij'))  # 2, Wh, Ww
+        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+        relative_coords[:, :, 0] += self.window_size - 1  # shift to start from 0
+        relative_coords[:, :, 1] += self.window_size - 1
+        relative_coords[:, :, 0] *= 2 * self.window_size - 1
+        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+        self.register_buffer("relative_position_index", relative_position_index)
+        # #########################################################################
+
+    def wa(self, f, x, wsize):
+        b, c, h, w = x.shape
+        q = rearrange(
+            f, 'b (head c) (h dh) (w dw) -> (b h w) head (dh dw) c',
+            dh=wsize, dw=wsize, head=self.num_heads
+        )
+        k, v = rearrange(
+            x, 'b (kv head c) (h dh) (w dw) -> kv (b h w) head (dh dw) c',
+            kv=2, dh=wsize, dw=wsize, head=self.num_heads
+        )
+        atn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
+        t = torch.tensor(1. / 0.01).to(atn.device)
+        logit_scale = torch.clamp(self.logit_scale, max=torch.log(t)).exp()
+        atn = atn * logit_scale
+
+        atn = atn.softmax(dim=-1)
+        y_ = (atn @ v)
+        y_ = rearrange(y_, '(b h w) head (dh dw) c-> b (head c) (h dh) (w dw)',
+                       h=h // wsize, w=w // wsize, dh=wsize, dw=wsize, head=self.num_heads)
+        return y_
+
+    # def forward(self, x, roll=False):
+    #     y_ = self.project_inp(x)
+    #     _, _, h, w = y_.shape
+    #     wsize = self.window_size
+    #     shifted window attention
+    #     if roll:
+    #         y_ = torch.roll(y_, shifts=(-wsize // 2, -wsize // 2), dims=(2, 3))
+    #     y_ = self.wa(y_, wsize, (h, w))
+    #     if roll:
+    #         y_ = torch.roll(y_, shifts=(wsize // 2, wsize // 2), dims=(2, 3))
+    #     y = self.project_out(y_)
+    #     return y
+    def forward(self, x):
+        b, c, h, w = x.shape
+        x_ = x
+        x = self.project_inp(x_)
+        xs = torch.split(x, self.split_chns, dim=1)
+        f = self.f_project_inp(x_)
+        fs = torch.split(f, self.f_split_chns, dim=1)
+        wsize = self.window_size
+        ys = []
+        # window attention
+        y_ = self.wa(fs[0], xs[0], wsize)
+        ys.append(y_)
+        # shifted window attention
+        x_ = torch.roll(xs[1], shifts=(-wsize // 2, -wsize // 2), dims=(2, 3))
+        f_ = torch.roll(fs[1], shifts=(-wsize // 2, -wsize // 2), dims=(2, 3))
+        y_ = self.wa(f_, x_, wsize)
+        y_ = torch.roll(y_, shifts=(wsize // 2, wsize // 2), dims=(2, 3))
+        ys.append(y_)
+
+        # long-range attentin
+        # for longitude
+        q = rearrange(fs[2], 'b (head c) h w -> (b h) head w c', head=self.num_heads)
+        k, v = rearrange(xs[2], 'b (kv head c) h w -> kv (b h) head w c', kv=2, head=self.num_heads)
+        atn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
+        t = torch.tensor(1. / 0.01).to(atn.device)
+        logit_scale = torch.clamp(self.lr_logit_scale, max=torch.log(t)).exp()
+        atn = atn * logit_scale
+        # atn = (q @ k.transpose(-2, -1))
+        atn = atn.softmax(dim=-1)
+        v = (atn @ v)
+        # for latitude
+        q, k, v = (rearrange(q, '(b h) head w c -> (b w) head h c', h=h),
+                   rearrange(k, '(b h) head w c -> (b w) head h c', h=h),
+                   rearrange(v, '(b h) head w c -> (b w) head h c', h=h))
+        atn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
+        atn = atn * logit_scale
+        # atn = (q @ k.transpose(-2, -1))
+        atn = atn.softmax(dim=-1)
+        v = (atn @ v)
+        y_ = rearrange(v, '(b w) head h c-> b (head c) h w', b=b)
+        ys.append(y_)
+
+        y = torch.cat(ys, dim=1)
+        y = self.project_out(y)
+        return y
+
+
+class FEB(nn.Module):
+    def __init__(self, inp_channels, exp_ratio=2, window_size=5, num_heads=8):
+        super(FEB, self).__init__()
+        self.exp_ratio = exp_ratio
+        self.inp_channels = inp_channels
+        self.down = DownBlock(inp_channels, downscale=2)
+        self.up = UpBlock(inp_channels, downscale=2)
+
+        self.FD = FD(inp_channels=inp_channels * 2, out_channels=inp_channels * 2, exp_ratio=exp_ratio)
+        self.LGAB = LGAB(channels=inp_channels * 2, window_size=window_size, num_heads=num_heads)
+        self.norm1 = Norm(inp_channels * 2)
+        self.norm2 = Norm(inp_channels * 2)
+        self.drop = nn.Dropout2d(0.2)
+
+    def forward(self, x):
+        res = x
+        x = self.down(x)
+        shortcut = x
+        x = self.LGAB(x)
+        x = self.drop(x)
+        x = self.norm1(x) + shortcut
+        shortcut = x
+        x = self.FD(x)
+        x = self.norm2(x) + shortcut
+        x = self.up(x)
+        x = x + res
+        return x
+
+
+if __name__ == '__main__':
+    input_data = torch.zeros((5, 32, 32))
+    conv = nn.Conv2d(5, 5, stride=2, kernel_size=3, padding=1)
+    for i in range(5):
+        input_data = conv(input_data)
+    print(input_data.shape)
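The FEB block above is shape-preserving: it halves the spatial resolution and doubles the channels internally (DownBlock), runs the three-branch LGAB attention and the FD feed-forward with residual connections, then restores the original shape (UpBlock) and adds the outer residual. A minimal sketch of driving it in isolation is shown below; the channel count, window size, and spatial size are hypothetical values chosen only so that LGAB's three-way channel split, the per-head channel count, and the 5x5 window partition all divide evenly. eval() is used so BatchNorm and Dropout behave deterministically.

```python
import torch
from jacksung.ai.GeoNet.m_blockV2 import FEB

# Hypothetical sizes: 24 channels (divisible by 3 for LGAB's split, with the
# per-head channel counts still integral after the internal 2x expansion) and
# a 40x40 grid (divisible by the stride-2 down-sampling and the 5x5 window).
block = FEB(inp_channels=24, exp_ratio=2, window_size=5, num_heads=8).eval()
x = torch.randn(1, 24, 40, 40)
with torch.no_grad():
    y = block(x)
print(y.shape)  # expected: torch.Size([1, 24, 40, 40])
```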
jacksung/ai/GeoNet/m_network.py
ADDED
@@ -0,0 +1,107 @@
+import torch.nn as nn
+from jacksung.ai.GeoNet.m_block import FEB, Tail, Head, DownBlock, UpBlock, CubeEmbeding, CubeUnEmbeding, Norm, ACT
+import torch
+
+
+class GeoNet(nn.Module):
+    def __init__(self, window_sizes, n_lgab, c_in, c_lgan, r_expand=4, down_sample=4, num_heads=8, task='pred',
+                 downstage=2):
+        super(GeoNet, self).__init__()
+        self.window_sizes = window_sizes
+        self.n_lgab = n_lgab
+        self.c_in = c_in
+        self.c_lgan = c_lgan
+        self.r_expand = r_expand
+        self.task = task
+        self.down_sample = down_sample
+        # define head module
+        if self.task == 'prec':
+            self.head = Head(self.c_lgan, self.c_in, self.down_sample)
+        else:
+            self.head = CubeEmbeding(self.c_lgan, self.c_in, self.down_sample)
+            self.head_res = Head(self.c_lgan, self.c_in, self.down_sample)
+        # self.head = Head(self.c_lgan, self.c_in * 2, self.down_sample)
+        # define body module
+        self.body = nn.ModuleList()
+        self.downstage = downstage
+        for i in range(self.n_lgab):
+            if i / self.downstage in [1]:
+                self.body.append(DownBlock(self.c_lgan, 2))
+                self.c_lgan = self.c_lgan * 2
+            elif (self.n_lgab - i) / self.downstage in [1]:
+                self.body.append(UpBlock(self.c_lgan // 2, 2))
+                self.c_lgan = self.c_lgan // 2
+            self.body.append(
+                FEB(self.c_lgan, self.r_expand, self.window_sizes[i % len(self.window_sizes)], num_heads=num_heads))
+        # self.conv1 = nn.Conv2d(self.c_lgan, self.c_lgan, 3, 1, 1)
+        # self.norm = Norm(self.c_lgan)
+        # self.act = ACT()
+        # self.conv2 = nn.Conv2d(self.c_lgan, self.c_lgan, 3, 1, 1)
+        self.tail = Tail(self.c_lgan, 1 if self.task == 'prec' else self.c_in, down_sample)
+        # self.tail = nn.ConvTranspose3d(self.c_lgan, 5 if self.task == 'prec' else self.c_in, kernel_size=(2, 2, 2),
+        #                                stride=(2, 2, 2))
+        # self.tail = CubeUnEmbeding(self.c_lgan, 5 if self.task == 'prec' else self.c_in, self.down_sample)
+
+    def forward(self, f, x, roll=0):
+        # head
+        head_res = None
+        if self.task == 'pred':
+            head_res = self.head_res(x)
+            x = torch.stack([x, f], dim=2)
+        if roll > 0:
+            x = torch.roll(x, shifts=roll, dims=-1)
+        # f, x = rearrange(x, 'b c z h w->b (c z) h w'), rearrange(f, 'b c z h w->b (c z) h w')
+        # x = nn.functional.interpolate(x, size=[1600, 2000], mode='bilinear')
+        # f = nn.functional.interpolate(f, size=[1600, 2000], mode='bilinear')
+        x = self.head(x)
+        # shortcut = x
+        # body
+        x_res = None
+        for idx, stage in enumerate(self.body):
+            if idx / self.downstage in [1]:
+                x_res = x
+            x = stage(x)
+            if (self.n_lgab + 1 - idx) / self.downstage in [1]:
+                x += x_res
+        if self.task == 'pred':
+            x = head_res + x
+        # tail
+        x = self.tail(x)
+        # x = x[:, :, 0, :, :]
+        if roll > 0:
+            x = torch.roll(x, shifts=-roll, dims=-1)
+        # x = nn.functional.interpolate(x, size=[1607, 2008], mode='bilinear')
+        return x
+
+    def init_model(self):
+        print('Initializing the model!')
+        for m in self.children():
+            if isinstance(m, (nn.Conv2d, nn.Linear)):
+                nn.init.xavier_normal_(m.weight)
+
+    def load(self, state_dict, strict=True):
+        own_state = self.state_dict()
+        for name, param in state_dict.items():
+            name = name[name.index('.') + 1:]
+            if name in own_state.keys():
+                if isinstance(param, nn.Parameter):
+                    param = param.data
+                try:
+                    own_state[name].copy_(param)
+                    # own_state[name].requires_grad = False
+                except Exception as e:
+                    err_log = f'While copying the parameter named {name}, ' \
+                              f'whose dimensions in the model are {own_state[name].size()} and ' \
+                              f'whose dimensions in the checkpoint are {param.size()}.'
+                    if not strict:
+                        print(err_log)
+                    else:
+                        raise Exception(err_log)
+            elif strict:
+                raise KeyError(f'unexpected key {name} in {own_state.keys()}')
+            else:
+                print(f'{name} not loaded by model')
+
+
+if __name__ == '__main__':
+    pass
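One detail worth noting in GeoNet.load() above: every checkpoint key is truncated at its first '.', which is what lets checkpoints saved from a wrapped model (for example keys carrying an nn.DataParallel 'module.' prefix) line up with this model's bare parameter names. A small standalone illustration of that remapping rule follows; the key names are made up for the example, but the slicing is exactly the one used in load().

```python
# Illustrative only: hypothetical checkpoint keys, real remapping rule.
ckpt_keys = ['module.head.conv2.shift_conv.weight', 'module.tail.conv1.bias']
stripped = [k[k.index('.') + 1:] for k in ckpt_keys]
print(stripped)  # ['head.conv2.shift_conv.weight', 'tail.conv1.bias']
```

Note that the rule strips the first dotted component unconditionally, so it assumes every key in the checkpoint carries exactly one wrapper prefix.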
jacksung/ai/GeoNet/m_networkV2.py
ADDED
@@ -0,0 +1,91 @@
+import torch.nn as nn
+from jacksung.ai.GeoNet.m_blockV2 import FEB, Tail, Head, DownBlock, UpBlock, CubeEmbeding, CubeUnEmbeding, Norm, ACT
+import torch
+import jacksung.utils.fastnumpy as fnp
+
+
+class GeoNet(nn.Module):
+    def __init__(self, window_sizes, n_lgab, c_in, c_lgan, r_expand=4, down_sample=2, num_heads=8, downstage=2):
+        super(GeoNet, self).__init__()
+        self.window_sizes = window_sizes
+        self.n_lgab = n_lgab
+        self.c_in = c_in
+        self.c_lgan = c_lgan
+        self.r_expand = r_expand
+        self.down_sample = down_sample
+        # define head module
+        self.head = Head(self.c_lgan, self.c_in, self.down_sample)
+        # define body module
+        self.body = nn.ModuleList()
+        self.downstage = downstage
+        for i in range(self.n_lgab):
+            if i % self.downstage in [1]:
+                if 0 <= i < self.n_lgab / 2 - 1:
+                    self.body.append(DownBlock(self.c_lgan, 2))
+                    self.c_lgan = self.c_lgan * 2
+                elif i > self.n_lgab / 2 + 1:
+                    self.body.append(UpBlock(self.c_lgan // 2, 2))
+                    self.c_lgan = self.c_lgan // 2
+            self.body.append(
+                FEB(self.c_lgan, self.r_expand, self.window_sizes[i % len(self.window_sizes)], num_heads=num_heads))
+        self.tail = Tail(self.c_lgan, 3, down_sample)
+
+    def forward(self, x, roll=0):
+        # head
+
+        x = torch.roll(x, shifts=roll, dims=-1)
+        x = self.head(x)
+        x_res = list()
+        x_res.append(x)
+        idx = 0
+        for stage in self.body:
+            if str(stage.__class__) == "<class 'jacksung.ai.GeoNet.m_blockV2.FEB'>":
+                # 7,9
+                if idx > self.n_lgab / 2 + 1 and idx % self.downstage in [1]:
+                    x += x_res.pop()
+                # 0,2
+                x = stage(x)
+                if idx < self.n_lgab / 2 - 1 and idx % self.downstage in [0]:
+                    x_res.append(x)
+                idx += 1
+            else:
+                x = stage(x)
+        # tail
+        x += x_res.pop()
+        x = self.tail(x)
+        if roll > 0:
+            x = torch.roll(x, shifts=-roll, dims=-1)
+        return x
+
+    def init_model(self):
+        print('Initializing the model!')
+        for m in self.children():
+            if isinstance(m, (nn.Conv2d, nn.Linear)):
+                nn.init.xavier_normal_(m.weight)
+
+    def load(self, state_dict, strict=True):
+        own_state = self.state_dict()
+        for name, param in state_dict.items():
+            name = name[name.index('.') + 1:]
+            if name in own_state.keys():
+                if isinstance(param, nn.Parameter):
+                    param = param.data
+                try:
+                    own_state[name].copy_(param)
+                    # own_state[name].requires_grad = False
+                except Exception as e:
+                    err_log = f'While copying the parameter named {name}, ' \
+                              f'whose dimensions in the model are {own_state[name].size()} and ' \
+                              f'whose dimensions in the checkpoint are {param.size()}.'
+                    if not strict:
+                        print(err_log)
+                    else:
+                        raise Exception(err_log)
+            elif strict:
+                raise KeyError(f'unexpected key {name} in {own_state.keys()}')
+            else:
+                print(f'{name} not loaded by model')
+
+
+if __name__ == '__main__':
+    pass
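Since this V2 network is built entirely from the m_blockV2 components shown earlier, an end-to-end forward pass can be sketched as below. All hyper-parameters here are hypothetical and were picked only so that the construction constraints hold: n_lgab=10 with downstage=2 yields the encoder/decoder layout whose skips are saved at FEB indices 0 and 2 and merged at 7 and 9 (the '# 0,2' and '# 7,9' comments in forward), c_lgan=24 keeps LGAB's channel splits divisible by the head count at every scale, and an 80x80 input stays divisible by the head's stride-2 stage, the body's extra down-sampling, and the 5x5 attention windows.

```python
import torch
from jacksung.ai.GeoNet.m_networkV2 import GeoNet

# Hypothetical configuration; see the divisibility constraints noted above.
model = GeoNet(window_sizes=[5], n_lgab=10, c_in=5, c_lgan=24,
               r_expand=2, down_sample=2, num_heads=8, downstage=2).eval()
x = torch.randn(1, 5, 80, 80)
with torch.no_grad():
    y = model(x)
print(y.shape)  # expected: torch.Size([1, 3, 80, 80]); the Tail here emits 3 channels
```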
jacksung/ai/__init__.py
ADDED
File without changes