jacksung-dev 0.0.4.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. jacksung/__init__.py +1 -0
  2. jacksung/ai/GeoAttX.py +356 -0
  3. jacksung/ai/GeoNet/__init__.py +0 -0
  4. jacksung/ai/GeoNet/m_block.py +393 -0
  5. jacksung/ai/GeoNet/m_blockV2.py +442 -0
  6. jacksung/ai/GeoNet/m_network.py +107 -0
  7. jacksung/ai/GeoNet/m_networkV2.py +91 -0
  8. jacksung/ai/__init__.py +0 -0
  9. jacksung/ai/latex_tool.py +199 -0
  10. jacksung/ai/metrics.py +181 -0
  11. jacksung/ai/utils/__init__.py +0 -0
  12. jacksung/ai/utils/cmorph.py +42 -0
  13. jacksung/ai/utils/data_parallelV2.py +90 -0
  14. jacksung/ai/utils/fy.py +333 -0
  15. jacksung/ai/utils/goes.py +161 -0
  16. jacksung/ai/utils/gsmap.py +24 -0
  17. jacksung/ai/utils/imerg.py +159 -0
  18. jacksung/ai/utils/metsat.py +164 -0
  19. jacksung/ai/utils/norm_util.py +109 -0
  20. jacksung/ai/utils/util.py +300 -0
  21. jacksung/libs/times.ttf +0 -0
  22. jacksung/utils/__init__.py +1 -0
  23. jacksung/utils/base_db.py +72 -0
  24. jacksung/utils/cache.py +71 -0
  25. jacksung/utils/data_convert.py +273 -0
  26. jacksung/utils/exception.py +27 -0
  27. jacksung/utils/fastnumpy.py +115 -0
  28. jacksung/utils/figure.py +251 -0
  29. jacksung/utils/hash.py +26 -0
  30. jacksung/utils/image.py +221 -0
  31. jacksung/utils/log.py +86 -0
  32. jacksung/utils/login.py +149 -0
  33. jacksung/utils/mean_std.py +66 -0
  34. jacksung/utils/multi_task.py +129 -0
  35. jacksung/utils/number.py +6 -0
  36. jacksung/utils/nvidia.py +140 -0
  37. jacksung/utils/time.py +87 -0
  38. jacksung/utils/web.py +63 -0
  39. jacksung_dev-0.0.4.15.dist-info/LICENSE +201 -0
  40. jacksung_dev-0.0.4.15.dist-info/METADATA +228 -0
  41. jacksung_dev-0.0.4.15.dist-info/RECORD +44 -0
  42. jacksung_dev-0.0.4.15.dist-info/WHEEL +5 -0
  43. jacksung_dev-0.0.4.15.dist-info/entry_points.txt +3 -0
  44. jacksung_dev-0.0.4.15.dist-info/top_level.txt +1 -0
@@ -0,0 +1,442 @@
1
+ import math
2
+ import random
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ import numpy as np
9
+
10
+
11
class ShiftConv2d0(nn.Module):
    """Shift convolution expressed as a masked dense 3x3 convolution.

    Input channels are split into ``n_div`` (5) groups of ``inp_channels // 5``
    channels. For each of the first four groups only one off-centre tap of the
    3x3 kernel is kept: (1, 2), (1, 0), (2, 1) and (0, 1). All remaining
    channels keep only the centre tap (1, 1). The mask is multiplied into the
    weight on every forward pass, so masked taps contribute nothing.

    Args:
        inp_channels: number of input channels.
        out_channels: number of output channels.
        stride: convolution stride (padding is fixed to 1).
    """

    def __init__(self, inp_channels, out_channels, stride):
        super(ShiftConv2d0, self).__init__()
        self.inp_channels = inp_channels
        self.out_channels = out_channels
        self.n_div = 5
        self.stride = stride
        group = inp_channels // self.n_div

        # Only used to obtain initialised weight/bias tensors; the module itself is not registered.
        dense = nn.Conv2d(inp_channels, out_channels, 3, 1, 1)
        keep = nn.Parameter(torch.zeros((self.out_channels, self.inp_channels, 3, 3)), requires_grad=False)
        # (row, col) of the single active tap for each of the four shifted groups
        shifted_taps = ((1, 2), (1, 0), (2, 1), (0, 1))
        for slot, (row, col) in enumerate(shifted_taps):
            keep[:, slot * group:(slot + 1) * group, row, col] = 1.0
        # everything after the shifted groups passes through the centre tap
        keep[:, len(shifted_taps) * group:, 1, 1] = 1.0

        self.w = dense.weight
        self.b = dense.bias
        self.m = keep

    def forward(self, x):
        masked_weight = self.w * self.m
        return F.conv2d(input=x, weight=masked_weight, bias=self.b, stride=self.stride, padding=1)
34
+
35
+
36
class ShiftConv2d1(nn.Module):
    """Shift convolution as a fixed depthwise 3x3 shift followed by a 1x1 conv.

    A frozen depthwise kernel moves each channel by one pixel. The channels are
    assigned to the shift groups in a random order drawn with
    ``random.shuffle``. Groups of ``inp_channels // 5`` channels use taps
    (1, 2), (1, 0), (2, 1) and (0, 1); the rest use the centre tap (identity).
    A learnable 1x1 convolution then mixes channels.

    Args:
        inp_channels: number of input channels.
        out_channels: number of output channels of the 1x1 mixing conv.
        stride: stride of the depthwise shift (padding is fixed to 1).
    """

    def __init__(self, inp_channels, out_channels, stride):
        super(ShiftConv2d1, self).__init__()
        self.inp_channels = inp_channels
        self.out_channels = out_channels
        self.stride = stride
        self.weight = nn.Parameter(torch.zeros(inp_channels, 1, 3, 3), requires_grad=False)
        self.n_div = 5
        group = inp_channels // self.n_div

        order = list(range(inp_channels))
        random.shuffle(order)
        # (row, col) taps; labels follow the original author's naming
        shifted_taps = ((1, 2),  # left
                        (1, 0),  # right
                        (2, 1),  # up
                        (0, 1))  # down
        for slot, (row, col) in enumerate(shifted_taps):
            self.weight[order[slot * group:(slot + 1) * group], 0, row, col] = 1.0
        self.weight[order[len(shifted_taps) * group:], 0, 1, 1] = 1.0  # identity
        self.conv1x1 = nn.Conv2d(inp_channels, out_channels, 1)

    def forward(self, x):
        shifted = F.conv2d(input=x, weight=self.weight, bias=None, stride=self.stride, padding=1,
                           groups=self.inp_channels)
        return self.conv1x1(shifted)
59
+
60
+
61
class ShiftConv2d(nn.Module):
    """Factory wrapper selecting one of several 'shift conv' implementations.

    ``conv_type`` values:
        'low-training-memory': ``ShiftConv2d0`` (masked dense 3x3 conv)
        'fast-training-speed': ``ShiftConv2d1`` (fixed depthwise shift + 1x1)
        'common':              plain 1x1 ``nn.Conv2d``
        'conv3' (default):     plain 3x3 ``nn.Conv2d`` with padding 1

    Raises:
        ValueError: for any other ``conv_type``.
    """

    def __init__(self, inp_channels, out_channels, conv_type='conv3', stride=1):
        super(ShiftConv2d, self).__init__()
        self.inp_channels = inp_channels
        self.out_channels = out_channels
        self.conv_type = conv_type
        # Builders are only invoked for the selected type so no extra layers are created.
        builders = (
            ('low-training-memory', lambda: ShiftConv2d0(inp_channels, out_channels, stride=stride)),
            ('fast-training-speed', lambda: ShiftConv2d1(inp_channels, out_channels, stride=stride)),
            ('common', lambda: nn.Conv2d(inp_channels, out_channels, kernel_size=1, stride=stride)),
            ('conv3', lambda: nn.Conv2d(inp_channels, out_channels, kernel_size=3, stride=stride, padding=1)),
        )
        for type_name, build in builders:
            if conv_type == type_name:
                self.shift_conv = build()
                break
        else:
            raise ValueError('invalid type of shift-conv2d')

    def forward(self, x):
        return self.shift_conv(x)
81
+
82
+
83
class ACT(nn.Module):
    """Activation shared by all GeoNet blocks; currently SiLU (nn.Mish was an alternative)."""

    def __init__(self):
        super(ACT, self).__init__()
        self.act = nn.SiLU()

    def forward(self, x):
        activated = self.act(x)
        return activated
91
+
92
+
93
class DownBlock(nn.Module):
    """Downsample spatially by ``downscale`` and widen channels by ``downscale``.

    Pipeline: Conv2d(kernel = stride = downscale, no padding) -> Norm -> ACT ->
    3x3 Conv2d (same size). Channels go from ``c_lgan`` to ``c_lgan * downscale``.
    """

    def __init__(self, c_lgan, downscale=2):
        super(DownBlock, self).__init__()
        wide = c_lgan * downscale
        self.down = nn.Conv2d(c_lgan, wide, kernel_size=downscale, stride=downscale)
        self.norm = Norm(wide)
        self.act = ACT()
        self.conv = nn.Conv2d(wide, wide, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        for layer in (self.down, self.norm, self.act, self.conv):
            x = layer(x)
        return x
107
+
108
+
109
class UpBlock(nn.Module):
    """Inverse of ``DownBlock``: 3x3 conv -> Norm -> ACT -> transposed conv.

    Takes ``c_lgan * downscale`` channels and returns ``c_lgan`` channels with
    the spatial size multiplied by ``downscale``. The ConvTranspose2d uses
    kernel = stride = downscale.
    """

    def __init__(self, c_lgan, downscale=2):
        super(UpBlock, self).__init__()
        wide = c_lgan * downscale
        self.conv = nn.Conv2d(wide, wide, kernel_size=3, stride=1, padding=1)
        self.norm = Norm(wide)
        self.act = ACT()
        self.up = nn.ConvTranspose2d(wide, c_lgan, kernel_size=downscale, stride=downscale)

    def forward(self, x):
        for layer in (self.conv, self.norm, self.act, self.up):
            x = layer(x)
        return x
123
+
124
+
125
class Norm(nn.Module):
    """Channel normalisation used by every GeoNet block.

    Currently ``nn.BatchNorm2d(c_in)``. A single-group GroupNorm and a
    per-pixel LayerNorm over channels were tried previously as alternatives.
    """

    def __init__(self, c_in):
        super(Norm, self).__init__()
        self.norm = nn.BatchNorm2d(c_in)

    def forward(self, x):
        normalised = self.norm(x)
        return normalised
145
+
146
+
147
class CubeEmbeding(nn.Module):
    """3-D patch embedding that collapses a depth-2 stack into a 2-D feature map.

    A Conv3d with kernel = stride = (2, down_sample, down_sample) maps ``c_in``
    to ``c_lgan`` channels. Only depth slice 0 of its output is kept. That slice
    then goes through Norm -> ACT -> ShiftConv2d(c_lgan, c_lgan).

    NOTE(review): assumes the input is (B, c_in, 2, H, W), so the depth axis
    collapses to 1 — the GeoNet 'pred' task stacks two tensors on dim 2.
    """

    def __init__(self, c_lgan, c_in, down_sample=1):
        super(CubeEmbeding, self).__init__()
        patch = (2, down_sample, down_sample)
        self.embeding = nn.Conv3d(c_in, c_lgan, kernel_size=patch, stride=patch, padding=(0, 0, 0))
        self.norm = Norm(c_lgan)
        self.act = ACT()
        self.conv2 = ShiftConv2d(c_lgan, c_lgan)

    def forward(self, x):
        embedded = self.embeding(x)
        first_slice = embedded[:, :, 0, :, :]
        return self.conv2(self.act(self.norm(first_slice)))
162
+
163
+
164
class CubeUnEmbeding(nn.Module):
    """Decode ``c_lgan`` features back to ``c_in`` channels at ``downscale``x resolution.

    Pipeline: 3x3 conv -> Norm -> ACT -> ConvTranspose2d(kernel = stride =
    downscale) -> 3x3 conv to ``c_in`` channels.
    """

    def __init__(self, c_lgan, c_in, downscale=2):
        super(CubeUnEmbeding, self).__init__()
        self.conv1 = nn.Conv2d(c_lgan, c_lgan, kernel_size=3, stride=1, padding=1)
        self.norm = Norm(c_lgan)
        self.act = ACT()
        self.up = nn.ConvTranspose2d(c_lgan, c_lgan, kernel_size=downscale, stride=downscale)
        self.conv2 = nn.Conv2d(c_lgan, c_in, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        for layer in (self.conv1, self.norm, self.act, self.up, self.conv2):
            x = layer(x)
        return x
180
+
181
+
182
class Head(nn.Module):
    """Strided-conv encoder from ``c_in`` input channels to ``c_lgan`` features.

    While the remaining factor is >= 2, it adds a stride-2
    ``ShiftConv2d(ch, 2 * ch)`` -> Norm -> ACT stage and halves the factor
    (integer division). A final ``ShiftConv2d`` then maps to ``c_lgan``
    channels. For a power-of-two ``down_sample``, this yields
    ``log2(down_sample)`` stages.
    """

    def __init__(self, c_lgan, c_in, down_sample):
        super(Head, self).__init__()
        self.stage = nn.ModuleList()
        channels, factor = c_in, down_sample
        while factor >= 2:
            # list literal keeps the construction order conv -> norm -> act
            self.stage.extend([ShiftConv2d(channels, channels * 2, stride=2), Norm(channels * 2), ACT()])
            channels *= 2
            factor //= 2
        self.conv2 = ShiftConv2d(channels, c_lgan)

    def forward(self, x):
        for layer in self.stage:
            x = layer(x)
        return self.conv2(x)
202
+
203
+
204
class Tail(nn.Module):
    """Decode ``c_lgan`` features into ``c_in`` output channels.

    Pipeline: 3x3 conv -> Norm -> ACT -> 3x3 conv. When ``down_sample != 1`` the
    last conv produces ``c_in * down_sample**2`` channels, and
    ``nn.PixelShuffle`` rearranges them into ``c_in`` channels at
    ``down_sample``x spatial resolution.
    """

    def __init__(self, c_lgan, c_in, down_sample):
        super(Tail, self).__init__()
        self.down_sample = down_sample
        self.conv1 = nn.Conv2d(c_lgan, c_lgan, kernel_size=3, stride=1, padding=1)
        self.norm = Norm(c_lgan)
        self.act = ACT()
        upscaled = self.down_sample != 1
        out_channels = c_in * down_sample * down_sample if upscaled else c_in
        self.conv2 = nn.Conv2d(c_lgan, out_channels, kernel_size=3, stride=1, padding=1)
        if upscaled:
            self.ps = nn.PixelShuffle(down_sample)

    def forward(self, x):
        for layer in (self.conv1, self.norm, self.act, self.conv2):
            x = layer(x)
        return x if self.down_sample == 1 else self.ps(x)
234
+
235
+
236
+ # class Tail(nn.Module):
237
+ # def __init__(self, c_lgan, c_in, down_sample):
238
+ # super(Tail, self).__init__()
239
+ # self.c_in = c_in
240
+ # self.down_sample = down_sample
241
+ # self.fc = nn.Linear(c_lgan, c_in * down_sample * down_sample)
242
+ # self.norm = nn.LayerNorm(c_lgan)
243
+
244
+ # def forward(self, x):
245
+ # b, c, h, w = x.shape
246
+ # x = rearrange(x, 'b c h w->b (h w) c')
247
+ # x = self.norm(x)
248
+ # x = self.fc(x)
249
+ # x = rearrange(x, 'b (h w) (c d1 d2)->b c (d1 h) (d2 w)',
250
+ # c=self.c_in, d1=self.down_sample, d2=self.down_sample, h=h, w=w)
251
+ # x = nn.functional.interpolate(x, size=[721, 1440], mode='bilinear')
252
+ # return x
253
+
254
+
255
class FD(nn.Module):
    """Convolutional feed-forward block: expand -> ACT -> project.

    ``ShiftConv2d(inp, inp * exp_ratio)`` -> ACT -> ``ShiftConv2d(inp * exp_ratio, out)``.
    ``act2`` is constructed but not used by ``forward``.
    """

    def __init__(self, inp_channels, out_channels, exp_ratio=4):
        super(FD, self).__init__()
        hidden = inp_channels * exp_ratio
        self.fc1 = ShiftConv2d(inp_channels, hidden)
        self.fc2 = ShiftConv2d(hidden, out_channels)
        self.act1 = ACT()
        self.act2 = ACT()  # unused in forward; kept so the module layout is unchanged

    def forward(self, x):
        hidden = self.act1(self.fc1(x))
        return self.fc2(hidden)
272
+
273
+
274
class LGAB(nn.Module):
    """Local-global attention block with three parallel branches.

    The input is projected to queries (``f_project_inp``, ``channels`` outputs)
    and keys/values (``project_inp``, ``2 * channels`` outputs). Both are split
    channel-wise into three equal parts, one per branch:

    1. window attention inside non-overlapping ``window_size`` x ``window_size``
       windows;
    2. the same window attention on a cyclically shifted map (shifted windows);
    3. long-range axial attention, first along the width (longitude) axis and
       then along the height (latitude) axis.

    Attention logits are cosine similarities multiplied by a learned per-head
    scale ``exp(logit_scale)``, which is clamped at 100. Branch outputs are
    concatenated and passed through ``project_out``.

    Args:
        channels: input/output channels; must be divisible by ``split_part``.
            Each part must also be divisible by ``num_heads``.
        window_size: window side; the input H and W must be multiples of it.
        num_heads: attention heads per branch.
        split_part: number of channel parts. ``forward`` implements exactly
            three branches, so only 3 produces a working block.

    Raises:
        ValueError: if ``channels`` is not divisible by ``split_part``.

    Note:
        ``cpb_mlp``, ``relative_coords_table`` and ``relative_position_index``
        are a continuous relative-position-bias setup (resembling Swin V2), but
        ``forward`` never uses them. They are kept so that existing checkpoints
        (which contain these keys) still load.
    """

    # Upper bound of the learned log-scale: exp(_MAX_LOGIT_SCALE) == 1 / 0.01 == 100.
    _MAX_LOGIT_SCALE = math.log(1. / 0.01)

    def __init__(self, channels, window_size=5, num_heads=8, split_part=3):
        super(LGAB, self).__init__()
        if channels % split_part != 0:
            # otherwise torch.split in forward fails because the parts do not sum up
            raise ValueError(f'channels ({channels}) must be divisible by split_part ({split_part})')
        self.num_heads = num_heads
        self.window_size = window_size
        self.split_chns = [channels * 2 // split_part for _ in range(split_part)]
        self.f_split_chns = [channels // split_part for _ in range(split_part)]
        self.project_inp = ShiftConv2d(channels, channels * 2)
        self.f_project_inp = ShiftConv2d(channels, channels)
        self.project_out = ShiftConv2d(channels, channels)

        # per-head log-scales, initialised to log(10)
        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((self.num_heads, 1, 1))), requires_grad=True)
        self.lr_logit_scale = nn.Parameter(torch.log(10 * torch.ones((self.num_heads, 1, 1))), requires_grad=True)

        # --- relative position bias (unused by forward; kept for checkpoint compatibility) ---
        self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
                                     nn.ReLU(inplace=True),
                                     nn.Linear(512, self.num_heads, bias=False))

        relative_coords_h = torch.arange(-(self.window_size - 1), self.window_size, dtype=torch.float32)
        relative_coords_w = torch.arange(-(self.window_size - 1), self.window_size, dtype=torch.float32)
        relative_coords_table = (torch.stack(
            torch.meshgrid([relative_coords_h, relative_coords_w], indexing='ij'))
                                 .permute(1, 2, 0).contiguous().unsqueeze(0))  # 1, 2*Wh-1, 2*Ww-1, 2
        relative_coords_table[:, :, :, 0] /= (self.window_size - 1)
        relative_coords_table[:, :, :, 1] /= (self.window_size - 1)
        relative_coords_table *= 8  # normalize to -8, 8
        relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
            torch.abs(relative_coords_table) + 1.0) / np.log2(8)
        self.register_buffer("relative_coords_table", relative_coords_table)

        # pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size)
        coords_w = torch.arange(self.window_size)
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij'))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size - 1
        relative_coords[:, :, 0] *= 2 * self.window_size - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

    @staticmethod
    def _cosine_attention(q, k, v, logit_scale):
        """Softmax attention over cosine similarities of ``q``/``k`` scaled by ``logit_scale``."""
        atn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
        return (atn * logit_scale).softmax(dim=-1) @ v

    def wa(self, f, x, wsize):
        """Window attention with queries from ``f`` and keys/values from ``x``.

        ``x`` carries keys in its first channel half and values in the second.
        Both tensors are tiled into ``wsize`` x ``wsize`` windows, so H and W
        must be multiples of ``wsize``. The output has H x W of ``x`` and
        ``num_heads * c_v`` channels (same as ``f`` when ``x`` has twice its
        channels).
        """
        _, _, h, w = x.shape
        q = rearrange(f, 'b (head c) (h dh) (w dw) -> (b h w) head (dh dw) c',
                      dh=wsize, dw=wsize, head=self.num_heads)
        k, v = rearrange(x, 'b (kv head c) (h dh) (w dw) -> kv (b h w) head (dh dw) c',
                         kv=2, dh=wsize, dw=wsize, head=self.num_heads)
        logit_scale = torch.clamp(self.logit_scale, max=self._MAX_LOGIT_SCALE).exp()
        y = self._cosine_attention(q, k, v, logit_scale)
        return rearrange(y, '(b h w) head (dh dw) c-> b (head c) (h dh) (w dw)',
                         h=h // wsize, w=w // wsize, dh=wsize, dw=wsize, head=self.num_heads)

    def _axial_attention(self, f, x, b, h):
        """Attention along the width axis, then along the height axis (shared scale)."""
        q = rearrange(f, 'b (head c) h w -> (b h) head w c', head=self.num_heads)
        k, v = rearrange(x, 'b (kv head c) h w -> kv (b h) head w c', kv=2, head=self.num_heads)
        logit_scale = torch.clamp(self.lr_logit_scale, max=self._MAX_LOGIT_SCALE).exp()
        # width (longitude) axis: every row attends within itself
        v = self._cosine_attention(q, k, v, logit_scale)
        # height (latitude) axis: every column attends within itself; v is the row output
        q, k, v = (rearrange(t, '(b h) head w c -> (b w) head h c', h=h) for t in (q, k, v))
        v = self._cosine_attention(q, k, v, logit_scale)
        return rearrange(v, '(b w) head h c-> b (head c) h w', b=b)

    def forward(self, x):
        b, _, h, _ = x.shape
        wsize = self.window_size
        # Parenthesised on purpose: ``-wsize // 2`` is ``(-wsize) // 2`` (e.g. -3 for 5)
        # while the inverse roll uses ``wsize // 2`` (2), which left the shifted branch
        # misaligned by one pixel for odd window sizes.
        shift = wsize // 2

        kv_parts = torch.split(self.project_inp(x), self.split_chns, dim=1)
        q_parts = torch.split(self.f_project_inp(x), self.f_split_chns, dim=1)

        # branch 1: window attention
        local = self.wa(q_parts[0], kv_parts[0], wsize)

        # branch 2: shifted-window attention (roll, attend, roll back by the same amount)
        q_rolled = torch.roll(q_parts[1], shifts=(-shift, -shift), dims=(2, 3))
        kv_rolled = torch.roll(kv_parts[1], shifts=(-shift, -shift), dims=(2, 3))
        shifted = torch.roll(self.wa(q_rolled, kv_rolled, wsize), shifts=(shift, shift), dims=(2, 3))

        # branch 3: long-range axial attention
        long_range = self._axial_attention(q_parts[2], kv_parts[2], b, h)

        y = torch.cat([local, shifted, long_range], dim=1)
        return self.project_out(y)
406
+
407
+
408
class FEB(nn.Module):
    """Residual block that runs attention and feed-forward at half resolution.

    ``DownBlock`` halves H/W and doubles channels. Then two post-norm residual
    sub-blocks run at ``2 * inp_channels``:

    * LGAB attention -> Dropout2d(0.2) -> Norm, plus shortcut;
    * FD feed-forward -> Norm, plus shortcut.

    ``UpBlock`` restores the input shape, and the block input is added back.
    """

    def __init__(self, inp_channels, exp_ratio=2, window_size=5, num_heads=8):
        super(FEB, self).__init__()
        self.exp_ratio = exp_ratio
        self.inp_channels = inp_channels
        wide = inp_channels * 2
        self.down = DownBlock(inp_channels, downscale=2)
        self.up = UpBlock(inp_channels, downscale=2)
        self.FD = FD(inp_channels=wide, out_channels=wide, exp_ratio=exp_ratio)
        self.LGAB = LGAB(channels=wide, window_size=window_size, num_heads=num_heads)
        self.norm1 = Norm(wide)
        self.norm2 = Norm(wide)
        self.drop = nn.Dropout2d(0.2)

    def forward(self, x):
        identity = x
        z = self.down(x)
        z = self.norm1(self.drop(self.LGAB(z))) + z
        z = self.norm2(self.FD(z)) + z
        return self.up(z) + identity
435
+
436
+
437
if __name__ == '__main__':
    # Ad-hoc shape check: apply the same stride-2 3x3 conv five times to an
    # unbatched 5 x 32 x 32 tensor and print the shape after each step.
    conv = nn.Conv2d(5, 5, stride=2, kernel_size=3, padding=1)
    data = torch.zeros((5, 32, 32))
    for _ in range(5):
        data = conv(data)
        print(data.shape)
@@ -0,0 +1,107 @@
1
+ import torch.nn as nn
2
+ from jacksung.ai.GeoNet.m_block import FEB, Tail, Head, DownBlock, UpBlock, CubeEmbeding, CubeUnEmbeding, Norm, ACT
3
+ import torch
4
+
5
+
6
class GeoNet(nn.Module):
    """GeoNet: head -> FEB body with one down/up stage -> tail.

    Args:
        window_sizes: attention window sizes, cycled over the FEB blocks.
        n_lgab: number of FEB blocks in the body.
        c_in: channels of each input tensor.
        c_lgan: body feature channels (doubled between the Down/Up blocks).
        r_expand: expansion ratio passed to every FEB.
        down_sample: downsampling factor passed to the head and tail.
        num_heads: attention heads per FEB.
        task: ``'prec'`` encodes ``x`` alone with ``Head``, ignores ``f`` and
            predicts 1 channel. ``'pred'`` does the following:
            * stacks ``x`` and ``f`` on dim 2 for a ``CubeEmbeding`` head;
            * adds ``head_res(x)`` after the body;
            * predicts ``c_in`` channels.
            Other values build the 'pred' modules, but ``forward`` does not
            stack the inputs for them.
        downstage: body position of the ``DownBlock``. The ``UpBlock`` is
            inserted before FEB number ``n_lgab - downstage``.
    """

    def __init__(self, window_sizes, n_lgab, c_in, c_lgan, r_expand=4, down_sample=4, num_heads=8, task='pred',
                 downstage=2):
        super(GeoNet, self).__init__()
        self.window_sizes = window_sizes
        self.n_lgab = n_lgab
        self.c_in = c_in
        self.c_lgan = c_lgan
        self.r_expand = r_expand
        self.task = task
        self.down_sample = down_sample
        # head
        if self.task == 'prec':
            self.head = Head(self.c_lgan, self.c_in, self.down_sample)
        else:
            self.head = CubeEmbeding(self.c_lgan, self.c_in, self.down_sample)
            self.head_res = Head(self.c_lgan, self.c_in, self.down_sample)
        # body; self.c_lgan tracks the current width and ends at the tail's input width
        self.body = nn.ModuleList()
        self.downstage = downstage
        for i in range(self.n_lgab):
            if i / self.downstage == 1:
                self.body.append(DownBlock(self.c_lgan, 2))
                self.c_lgan = self.c_lgan * 2
            elif (self.n_lgab - i) / self.downstage == 1:
                self.body.append(UpBlock(self.c_lgan // 2, 2))
                self.c_lgan = self.c_lgan // 2
            self.body.append(
                FEB(self.c_lgan, self.r_expand, self.window_sizes[i % len(self.window_sizes)], num_heads=num_heads))
        # tail: one output channel for 'prec', c_in otherwise
        self.tail = Tail(self.c_lgan, 1 if self.task == 'prec' else self.c_in, down_sample)

    def forward(self, f, x, roll=0):
        """Run the network.

        ``roll > 0`` cyclically shifts the inputs along their last axis before
        the head and shifts the output back after the tail. The shift is
        applied before ``head_res`` too, so the residual is added in the same
        (shifted) frame as the body features.
        """
        if roll > 0:
            x = torch.roll(x, shifts=roll, dims=-1)
            if self.task == 'pred':
                f = torch.roll(f, shifts=roll, dims=-1)
        head_res = None
        if self.task == 'pred':
            head_res = self.head_res(x)
            x = torch.stack([x, f], dim=2)
        x = self.head(x)

        # Skip connection around the down/up stage. Body position ``downstage``
        # holds the DownBlock; position ``n_lgab + 1 - downstage`` holds the
        # UpBlock (assuming n_lgab > 2 * downstage).
        x_res = None
        for idx, stage in enumerate(self.body):
            if idx / self.downstage == 1:
                x_res = x
            x = stage(x)
            if (self.n_lgab + 1 - idx) / self.downstage == 1:
                x += x_res
        if self.task == 'pred':
            x = head_res + x
        x = self.tail(x)
        if roll > 0:
            x = torch.roll(x, shifts=-roll, dims=-1)
        return x

    def init_model(self):
        """Xavier-normal initialise the weight of every Conv2d / Linear submodule.

        Walks ``self.modules()`` recursively. ``self.children()`` only yields
        the top-level containers (head, body, tail, ...), so no layer was ever
        re-initialised. Biases and other layer types are left untouched.
        """
        print('Initializing the model!')
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_normal_(m.weight)

    def load(self, state_dict, strict=True):
        """Copy matching tensors from ``state_dict`` into this model in place.

        A key is used as-is when it matches one of this model's own names.
        Otherwise its first dotted component (e.g. ``module.`` from
        ``nn.DataParallel``) is stripped before matching.

        Args:
            state_dict: mapping of parameter/buffer names to tensors.
            strict: if True, a failed copy raises ``Exception`` (chained to the
                original error) and an unknown key raises ``KeyError``.
                Otherwise both cases are printed and skipped.
        """
        own_state = self.state_dict()
        for key, param in state_dict.items():
            name = key
            if name not in own_state and '.' in name:
                name = name.split('.', 1)[1]
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception as e:
                    err_log = (f'While copying the parameter named {name}, '
                               f'whose dimensions in the model are {own_state[name].size()} and '
                               f'whose dimensions in the checkpoint are {getattr(param, "shape", type(param))}.')
                    if not strict:
                        print(err_log)
                    else:
                        raise Exception(err_log) from e
            elif strict:
                raise KeyError(f'unexpected key {name} in {own_state.keys()}')
            else:
                print(f'{name} not loaded by model')
104
+
105
+
106
if __name__ == '__main__':
    # No standalone entry point: this module only defines the model.
    ...
@@ -0,0 +1,91 @@
1
+ import torch.nn as nn
2
+ from jacksung.ai.GeoNet.m_blockV2 import FEB, Tail, Head, DownBlock, UpBlock, CubeEmbeding, CubeUnEmbeding, Norm, ACT
3
+ import torch
4
+ import jacksung.utils.fastnumpy as fnp
5
+
6
+
7
class GeoNet(nn.Module):
    """GeoNet V2: Head -> U-Net-like stack of FEB blocks -> Tail.

    For every odd FEB index ``i`` (when ``downstage == 2``) in the first half,
    a ``DownBlock`` is inserted before the FEB, while ``i < n_lgab / 2 - 1``.
    A matching ``UpBlock`` is inserted for ``i > n_lgab / 2 + 1``. ``forward``
    saves encoder features after the FEBs at indices ``< n_lgab / 2 - 1`` with
    ``idx % downstage == 0``. It adds them back (LIFO) before the FEBs at
    indices ``> n_lgab / 2 + 1`` with ``idx % downstage == 1``. The head output
    is added right before the tail.

    Args:
        window_sizes: attention window sizes, cycled over the FEB blocks.
        n_lgab: number of FEB blocks.
        c_in: input channels.
        c_lgan: base feature channels (doubled by each DownBlock).
        r_expand: expansion ratio passed to every FEB.
        down_sample: downsampling factor of the head / upsampling of the tail.
        num_heads: attention heads per FEB.
        downstage: period used for placing Down/Up blocks and skips.
        c_out: output channels of the tail (previously hard-coded to 3).
    """

    def __init__(self, window_sizes, n_lgab, c_in, c_lgan, r_expand=4, down_sample=2, num_heads=8, downstage=2,
                 c_out=3):
        super(GeoNet, self).__init__()
        self.window_sizes = window_sizes
        self.n_lgab = n_lgab
        self.c_in = c_in
        self.c_lgan = c_lgan
        self.r_expand = r_expand
        self.down_sample = down_sample
        self.c_out = c_out
        self.head = Head(self.c_lgan, self.c_in, self.down_sample)
        # body; self.c_lgan tracks the current width and ends at the tail's input width
        self.body = nn.ModuleList()
        self.downstage = downstage
        for i in range(self.n_lgab):
            if i % self.downstage == 1:
                if 0 <= i < self.n_lgab / 2 - 1:
                    self.body.append(DownBlock(self.c_lgan, 2))
                    self.c_lgan = self.c_lgan * 2
                elif i > self.n_lgab / 2 + 1:
                    self.body.append(UpBlock(self.c_lgan // 2, 2))
                    self.c_lgan = self.c_lgan // 2
            self.body.append(
                FEB(self.c_lgan, self.r_expand, self.window_sizes[i % len(self.window_sizes)], num_heads=num_heads))
        self.tail = Tail(self.c_lgan, c_out, down_sample)

    def forward(self, x, roll=0):
        """Run the network.

        ``roll`` cyclically shifts the input along its last axis. The same
        shift is undone on the output for any non-zero value; previously a
        negative ``roll`` was never undone.
        """
        # torch.roll always returns a new tensor, so the in-place adds below never touch the caller's input
        x = torch.roll(x, shifts=roll, dims=-1)
        x = self.head(x)
        skips = [x]  # head output, added back right before the tail
        half = self.n_lgab / 2
        feb_idx = 0
        for stage in self.body:
            if not isinstance(stage, FEB):
                x = stage(x)  # DownBlock / UpBlock
                continue
            if feb_idx > half + 1 and feb_idx % self.downstage == 1:
                x += skips.pop()
            x = stage(x)
            if feb_idx < half - 1 and feb_idx % self.downstage == 0:
                skips.append(x)
            feb_idx += 1
        x += skips.pop()
        x = self.tail(x)
        if roll != 0:
            x = torch.roll(x, shifts=-roll, dims=-1)
        return x

    def init_model(self):
        """Xavier-normal initialise the weight of every Conv2d / Linear submodule.

        Walks ``self.modules()`` recursively. ``self.children()`` only yields
        the top-level containers, so no layer was ever re-initialised. Biases
        and other layer types are left untouched.
        """
        print('Initializing the model!')
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_normal_(m.weight)

    def load(self, state_dict, strict=True):
        """Copy matching tensors from ``state_dict`` into this model in place.

        A key is used as-is when it matches one of this model's own names.
        Otherwise its first dotted component (e.g. ``module.`` from
        ``nn.DataParallel``) is stripped before matching.

        Args:
            state_dict: mapping of parameter/buffer names to tensors.
            strict: if True, a failed copy raises ``Exception`` (chained to the
                original error) and an unknown key raises ``KeyError``.
                Otherwise both cases are printed and skipped.
        """
        own_state = self.state_dict()
        for key, param in state_dict.items():
            name = key
            if name not in own_state and '.' in name:
                name = name.split('.', 1)[1]
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception as e:
                    err_log = (f'While copying the parameter named {name}, '
                               f'whose dimensions in the model are {own_state[name].size()} and '
                               f'whose dimensions in the checkpoint are {getattr(param, "shape", type(param))}.')
                    if not strict:
                        print(err_log)
                    else:
                        raise Exception(err_log) from e
            elif strict:
                raise KeyError(f'unexpected key {name} in {own_state.keys()}')
            else:
                print(f'{name} not loaded by model')
88
+
89
+
90
if __name__ == '__main__':
    # No standalone entry point: this module only defines the model.
    ...
File without changes