doctra 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. doctra/__init__.py +4 -0
  2. doctra/cli/main.py +168 -0
  3. doctra/engines/image_restoration/__init__.py +10 -0
  4. doctra/engines/image_restoration/docres_engine.py +566 -0
  5. doctra/engines/vlm/service.py +0 -12
  6. doctra/parsers/enhanced_pdf_parser.py +370 -0
  7. doctra/parsers/structured_pdf_parser.py +11 -60
  8. doctra/parsers/table_chart_extractor.py +8 -44
  9. doctra/third_party/docres/data/MBD/MBD.py +110 -0
  10. doctra/third_party/docres/data/MBD/MBD_utils.py +291 -0
  11. doctra/third_party/docres/data/MBD/infer.py +151 -0
  12. doctra/third_party/docres/data/MBD/model/deep_lab_model/aspp.py +95 -0
  13. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/__init__.py +13 -0
  14. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/drn.py +402 -0
  15. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/mobilenet.py +151 -0
  16. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/resnet.py +170 -0
  17. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/xception.py +288 -0
  18. doctra/third_party/docres/data/MBD/model/deep_lab_model/decoder.py +59 -0
  19. doctra/third_party/docres/data/MBD/model/deep_lab_model/deeplab.py +81 -0
  20. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/__init__.py +12 -0
  21. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/batchnorm.py +282 -0
  22. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/comm.py +129 -0
  23. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/replicate.py +88 -0
  24. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/unittest.py +29 -0
  25. doctra/third_party/docres/data/preprocess/crop_merge_image.py +142 -0
  26. doctra/third_party/docres/inference.py +370 -0
  27. doctra/third_party/docres/models/restormer_arch.py +308 -0
  28. doctra/third_party/docres/utils.py +464 -0
  29. doctra/ui/app.py +5 -32
  30. doctra/utils/progress.py +13 -98
  31. doctra/utils/structured_utils.py +45 -49
  32. doctra/version.py +1 -1
  33. {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/METADATA +1 -1
  34. doctra-0.4.0.dist-info/RECORD +67 -0
  35. doctra-0.3.2.dist-info/RECORD +0 -44
  36. {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/WHEEL +0 -0
  37. {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/licenses/LICENSE +0 -0
  38. {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,151 @@
1
+ import torch
2
+ import argparse
3
+ import numpy as np
4
+ import torch.nn.functional as F
5
+ import glob
6
+ import cv2
7
+ from tqdm import tqdm
8
+
9
+ import time
10
+ import os
11
+ from model.deep_lab_model.deeplab import *
12
+ from MBD import mask_base_dewarper
13
+ import time
14
+
15
+ from utils import cvimg2torch,torch2cvimg
16
+
17
+
18
+
19
def net1_net2_infer(model, img_paths, args):
    """Dewarp every image in ``img_paths`` using a predicted document mask.

    For each input path, three outputs are written next to it:

    - the dewarped image, at the path with ``_origin`` replaced by ``_capture``;
    - the binary mask, with ``_origin`` replaced by ``_mask_new``;
    - a 128x128 two-channel sampling grid, saved with ``np.save`` under the
      path with ``_origin`` replaced by ``_grid1``.

    Paths whose ``_capture`` output already exists are skipped, so an
    interrupted run can be resumed. Unreadable images are reported and skipped.
    If ``mask_base_dewarper`` fails, an identity sampling grid is used instead.

    Args:
        model: segmentation network; channel 0 of its output is used as the
            document mask (assumed to lie in [0, 1] -- TODO confirm).
        img_paths: iterable of image file paths.
        args: parsed command-line arguments; not read here, kept for
            backward compatibility with existing callers.
    """
    seg_model = model
    seg_model.eval()
    # Send inputs to the device the model lives on instead of assuming CUDA.
    device = next(seg_model.parameters()).device
    kernel = np.ones((3, 3))
    for img_path in tqdm(img_paths):
        capture_path = img_path.replace('_origin', '_capture')
        if os.path.exists(capture_path):
            continue

        img_org = cv2.imread(img_path)
        if img_org is None:
            # cv2.imread reports missing/unreadable files by returning None.
            print(f'fail: could not read {img_path}')
            continue
        h_org, w_org = img_org.shape[:2]

        # Segmentation mask prediction on a blurred 448x448 RGB copy.
        img = cv2.resize(img_org, (448, 448))
        img = cv2.GaussianBlur(img, (15, 15), 0, 0)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cvimg2torch(img)

        with torch.no_grad():
            pred = seg_model(img.to(device))
        mask_pred = pred[:, 0, :, :].unsqueeze(1)
        mask_pred = F.interpolate(mask_pred, (h_org, w_org))
        mask_pred = mask_pred.squeeze(0).squeeze(0).cpu().numpy()
        mask_pred = (mask_pred * 255).astype(np.uint8)
        # Dilate then erode (morphological closing) to fill small gaps.
        mask_pred = cv2.dilate(mask_pred, kernel, iterations=3)
        mask_pred = cv2.erode(mask_pred, kernel, iterations=3)
        # Binarize in one step. The former ">100 -> 255, <100 -> 0" pair left
        # pixels equal to exactly 100 untouched, so the mask was not binary.
        mask_pred = np.where(mask_pred > 100, 255, 0).astype(np.uint8)

        # TPS transform based on the mask; best-effort fallback to an
        # identity grid so one bad image does not abort the whole batch.
        try:
            dewarp, grid = mask_base_dewarper(img_org, mask_pred)
        except Exception as exc:
            print(f'fail: {img_path}: {exc}')
            grid = np.meshgrid(np.arange(w_org), np.arange(h_org)) / np.array([w_org, h_org]).reshape(2, 1, 1)
            grid = torch.from_numpy((grid - 0.5) * 2).float().unsqueeze(0).permute(0, 2, 3, 1)
            dewarp = torch2cvimg(F.grid_sample(cvimg2torch(img_org), grid))[0]
            grid = grid[0].numpy()

        cv2.imwrite(capture_path, dewarp)
        cv2.imwrite(img_path.replace('_origin', '_mask_new'), mask_pred)

        # Store a downsampled 128x128 copy of both grid channels.
        grid0 = cv2.resize(grid[:, :, 0], (128, 128))
        grid1 = cv2.resize(grid[:, :, 1], (128, 128))
        grid = np.stack((grid0, grid1), axis=-1)
        np.save(img_path.replace('_origin', '_grid1'), grid)
67
+
68
+
69
def net1_net2_infer_single_im(img, model_path):
    """Predict a binary document mask for one image.

    The DeepLab segmentation model (ResNet backbone, one output class) is
    rebuilt, and its checkpoint reloaded, on every call. Use it for one-off
    inference rather than batch processing.

    Args:
        img: BGR image array as returned by ``cv2.imread`` (presumably uint8 --
            TODO confirm with callers).
        model_path: path to a checkpoint dict with a ``'model_state'`` entry
            whose keys match the ``DataParallel``-wrapped model built here.

    Returns:
        ``np.uint8`` mask with the input's height and width, containing only
        0 and 255. This assumes ``cvimg2torch`` adds a batch dimension of one.
    """
    # Use CUDA when present. Without it, DataParallel gets no device ids and
    # simply forwards to the wrapped module on the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    seg_model = DeepLab(num_classes=1,
                        backbone='resnet',
                        output_stride=16,
                        sync_bn=None,
                        freeze_bn=False)
    seg_model = torch.nn.DataParallel(seg_model, device_ids=range(torch.cuda.device_count()))
    seg_model.to(device)
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    seg_model.load_state_dict(checkpoint['model_state'])
    seg_model.eval()

    # Segmentation mask prediction on a blurred 448x448 RGB copy.
    h_org, w_org = img.shape[:2]
    net_in = cv2.resize(img, (448, 448))
    net_in = cv2.GaussianBlur(net_in, (15, 15), 0, 0)
    net_in = cv2.cvtColor(net_in, cv2.COLOR_BGR2RGB)
    net_in = cvimg2torch(net_in)

    with torch.no_grad():
        pred = seg_model(net_in.to(device))
    mask_pred = pred[:, 0, :, :].unsqueeze(1)
    mask_pred = F.interpolate(mask_pred, (h_org, w_org))
    mask_pred = mask_pred.squeeze(0).squeeze(0).cpu().numpy()
    mask_pred = (mask_pred * 255).astype(np.uint8)
    # Dilate then erode (morphological closing) to fill small gaps.
    kernel = np.ones((3, 3))
    mask_pred = cv2.dilate(mask_pred, kernel, iterations=3)
    mask_pred = cv2.erode(mask_pred, kernel, iterations=3)
    # Binarize in one step. The former ">100 -> 255, <100 -> 0" pair left
    # pixels equal to exactly 100 untouched, so the mask was not binary.
    mask_pred = np.where(mask_pred > 100, 255, 0).astype(np.uint8)
    return mask_pred
124
+
125
+
126
+
127
if __name__ == '__main__':
    # Command-line entry point: load the mask-segmentation DeepLab checkpoint
    # and dewarp every ``*_origin.*`` image found in --img_folder.
    cli = argparse.ArgumentParser(description='Hyperparams')
    # Each entry is (flag, type, default, help). --img_rows and --img_cols are
    # parsed but not read by net1_net2_infer, which always resizes to 448x448.
    option_specs = (
        ('--img_folder', str, './all_data', 'Data path to load data'),
        ('--img_rows', int, 448, 'Height of the input image'),
        ('--img_cols', int, 448, 'Width of the input image'),
        ('--seg_model_path', str, 'checkpoints/mbd.pkl',
         'Path to previous saved model to restart from'),
    )
    for flag, value_type, default, help_text in option_specs:
        cli.add_argument(flag, nargs='?', type=value_type, default=default,
                         help=help_text)
    args = cli.parse_args()

    model = DeepLab(num_classes=1, backbone='resnet', output_stride=16,
                    sync_bn=None, freeze_bn=False)
    model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
    model.cuda()
    state = torch.load(args.seg_model_path)
    model.load_state_dict(state['model_state'])

    origin_paths = glob.glob(os.path.join(args.img_folder, '*_origin.*'))
    net1_net2_infer(model, origin_paths, args)
151
+
@@ -0,0 +1,95 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
6
+
7
class _ASPPModule(nn.Module):
    """One ASPP branch: convolution -> normalization -> ReLU.

    Args:
        inplanes: input channel count.
        planes: output channel count.
        kernel_size, padding, dilation: forwarded to ``nn.Conv2d``, which uses
            stride 1 and no bias.
        BatchNorm: normalization layer class, called as ``BatchNorm(planes)``.
    """

    def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm):
        super().__init__()
        self.atrous_conv = nn.Conv2d(
            inplanes,
            planes,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            dilation=dilation,
            bias=False,
        )
        self.bn = BatchNorm(planes)
        self.relu = nn.ReLU()
        self._init_weight()

    def forward(self, x):
        return self.relu(self.bn(self.atrous_conv(x)))

    def _init_weight(self):
        # Kaiming-normal conv weights; norm layers get weight 1 and bias 0.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                torch.nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, (SynchronizedBatchNorm2d, nn.BatchNorm2d)):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
+
34
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head.

    Five parallel branches each produce 256 channels:

    - a 1x1 conv;
    - three dilated 3x3 convs;
    - an image-level global-average-pool branch.

    Their 5 * 256 channels are concatenated, projected back to 256 with a 1x1
    conv, normalized, passed through ReLU, and then through dropout (p=0.5).

    Args:
        backbone: backbone name, which selects the input channel count:
            'drn' -> 512, 'mobilenet' -> 320, anything else -> 2048.
        output_stride: 16 or 8; selects the branch dilation rates.
        BatchNorm: normalization layer class, called with a channel count.

    Raises:
        NotImplementedError: if ``output_stride`` is neither 16 nor 8.
    """

    def __init__(self, backbone, output_stride, BatchNorm):
        super(ASPP, self).__init__()
        if backbone == 'drn':
            inplanes = 512
        elif backbone == 'mobilenet':
            inplanes = 320
        else:
            inplanes = 2048
        if output_stride == 16:
            dilations = [1, 6, 12, 18]
        elif output_stride == 8:
            dilations = [1, 12, 24, 36]
        else:
            raise NotImplementedError(
                f'ASPP only supports output_stride 16 or 8, got {output_stride!r}')

        # The 3x3 branches use padding == dilation, so every branch keeps the
        # input's spatial size.
        self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm)
        self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm)
        self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm)
        self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm)

        # Image-level context: global average pool -> 1x1 conv -> BN -> ReLU.
        self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
                                             nn.Conv2d(inplanes, 256, 1, stride=1, bias=False),
                                             BatchNorm(256),
                                             nn.ReLU())
        # Five concatenated 256-channel branches are projected back to 256.
        self.conv1 = nn.Conv2d(5 * 256, 256, 1, bias=False)
        self.bn1 = BatchNorm(256)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self._init_weight()

    def forward(self, x):
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        # Broadcast the 1x1 pooled features back to the branch resolution.
        x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        return self.dropout(x)

    def _init_weight(self):
        # Kaiming-normal conv weights; norm layers get weight 1 and bias 0.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (SynchronizedBatchNorm2d, nn.BatchNorm2d)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
92
+
93
+
94
def build_aspp(backbone, output_stride, BatchNorm):
    """Return a new ``ASPP`` head built from the given arguments."""
    head = ASPP(backbone=backbone, output_stride=output_stride, BatchNorm=BatchNorm)
    return head
@@ -0,0 +1,13 @@
1
+ from model.deep_lab_model.backbone import resnet, xception, drn, mobilenet
2
+
3
def build_backbone(backbone, output_stride, BatchNorm):
    """Instantiate the feature-extraction backbone named ``backbone``.

    Args:
        backbone: one of 'resnet', 'xception', 'drn' or 'mobilenet'.
        output_stride: forwarded to the resnet, xception and mobilenet
            constructors. ``drn.drn_d_54`` takes no such argument, so it is
            not passed there.
        BatchNorm: normalization layer class passed to the constructor.

    Returns:
        ``resnet.ResNet101``, ``xception.AlignedXception``, ``drn.drn_d_54``
        or ``mobilenet.MobileNetV2`` respectively.

    Raises:
        NotImplementedError: for any other backbone name.
    """
    if backbone == 'resnet':
        return resnet.ResNet101(output_stride, BatchNorm)
    if backbone == 'xception':
        return xception.AlignedXception(output_stride, BatchNorm)
    if backbone == 'drn':
        # drn_d_54 defaults to pretrained=True, i.e. it downloads weights.
        return drn.drn_d_54(BatchNorm)
    if backbone == 'mobilenet':
        return mobilenet.MobileNetV2(output_stride, BatchNorm)
    raise NotImplementedError(
        f"Unsupported backbone {backbone!r}; expected one of "
        "'resnet', 'xception', 'drn', 'mobilenet'")
@@ -0,0 +1,402 @@
1
+ import torch.nn as nn
2
+ import math
3
+ import torch.utils.model_zoo as model_zoo
4
+ from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
5
+
6
webroot = 'http://dl.yf.io/drn/'

# Pretrained checkpoint locations keyed by architecture name. The DRN entries
# are files under ``webroot``; 'resnet50' is hosted on download.pytorch.org
# and is the checkpoint loaded by ``drn_a_50``.
model_urls = {'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth'}
model_urls.update(
    (arch, webroot + filename)
    for arch, filename in (
        ('drn-c-26', 'drn_c_26-ddedf421.pth'),
        ('drn-c-42', 'drn_c_42-9d336e8c.pth'),
        ('drn-c-58', 'drn_c_58-0a53a92c.pth'),
        ('drn-d-22', 'drn_d_22-4bd2f8ea.pth'),
        ('drn-d-38', 'drn_d_38-eebb45f0.pth'),
        ('drn-d-54', 'drn_d_54-0e0534ff.pth'),
        ('drn-d-105', 'drn_d_105-12b40979.pth'),
    )
)
18
+
19
+
20
def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with the given stride, padding and dilation."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=False,
    )
23
+
24
+
25
class BasicBlock(nn.Module):
    """Two-convolution residual block used by DRN (channel expansion 1).

    Args:
        inplanes: input channel count.
        planes: output channel count.
        stride: stride of the first 3x3 convolution.
        downsample: optional module applied to the input to form the shortcut.
        dilation: (first conv, second conv) dilation pair. Each conv's padding
            equals its dilation, so spatial size changes only through
            ``stride``.
        residual: when False, the shortcut is not added.
        BatchNorm: normalization layer class.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 dilation=(1, 1), residual=True, BatchNorm=None):
        super().__init__()
        first_dilation = dilation[0]
        second_dilation = dilation[1]
        self.conv1 = conv3x3(inplanes, planes, stride,
                             padding=first_dilation, dilation=first_dilation)
        self.bn1 = BatchNorm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes,
                             padding=second_dilation, dilation=second_dilation)
        self.bn2 = BatchNorm(planes)
        self.downsample = downsample
        self.stride = stride
        self.residual = residual

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.residual:
            out += shortcut
        return self.relu(out)
59
+
60
+
61
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block used by DRN (expansion 4).

    Args:
        inplanes: input channel count.
        planes: bottleneck width; the block outputs ``planes * 4`` channels.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input to form the shortcut.
        dilation: pair whose second element sets the 3x3 conv's dilation and
            padding. The first element is not used by this block.
        residual: when False, the shortcut is not added, matching
            ``BasicBlock``. This argument used to be accepted and silently
            ignored.
        BatchNorm: normalization layer class.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 dilation=(1, 1), residual=True, BatchNorm=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=dilation[1], bias=False,
                               dilation=dilation[1])
        self.bn2 = BatchNorm(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = BatchNorm(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        # Plain attribute (not a module), so state_dict keys are unaffected.
        self.residual = residual

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)
        if self.residual:
            out += residual
        out = self.relu(out)

        return out
100
+
101
+
102
class DRN(nn.Module):
    """Dilated Residual Network backbone (DRN-C / DRN-D variants).

    ``forward`` returns ``(features, low_level_feat)``, where
    ``low_level_feat`` is the output of ``layer3``.

    Args:
        block: residual block class used for stages 3-6.
        layers: eight per-stage block/conv counts. A zero for stage 6, 7 or 8
            disables that stage.
        arch: the architecture variant.
            - 'C': 7x7 conv stem, residual ``BasicBlock`` stages 1-2, and
              non-residual ``BasicBlock`` stages 7-8.
            - 'D': plain conv/BN/ReLU stacks for stages 0-2 and 7-8.
        channels: eight per-stage channel widths.
        BatchNorm: normalization layer class.
    """

    def __init__(self, block, layers, arch='D',
                 channels=(16, 32, 64, 128, 256, 512, 512, 512),
                 BatchNorm=None):
        super().__init__()
        self.inplanes = channels[0]
        self.out_dim = channels[-1]
        self.arch = arch

        # Stem and stages 1-2 differ between the two variants.
        if arch == 'C':
            self.conv1 = nn.Conv2d(3, channels[0], kernel_size=7, stride=1,
                                   padding=3, bias=False)
            self.bn1 = BatchNorm(channels[0])
            self.relu = nn.ReLU(inplace=True)
            self.layer1 = self._make_layer(BasicBlock, channels[0], layers[0],
                                           stride=1, BatchNorm=BatchNorm)
            self.layer2 = self._make_layer(BasicBlock, channels[1], layers[1],
                                           stride=2, BatchNorm=BatchNorm)
        elif arch == 'D':
            stem_conv = nn.Conv2d(3, channels[0], kernel_size=7, stride=1,
                                  padding=3, bias=False)
            self.layer0 = nn.Sequential(stem_conv, BatchNorm(channels[0]),
                                        nn.ReLU(inplace=True))
            self.layer1 = self._make_conv_layers(channels[0], layers[0],
                                                 stride=1, BatchNorm=BatchNorm)
            self.layer2 = self._make_conv_layers(channels[1], layers[1],
                                                 stride=2, BatchNorm=BatchNorm)

        # Shared residual stages. Stages 5-6 keep the resolution and grow the
        # dilation instead of striding.
        self.layer3 = self._make_layer(block, channels[2], layers[2],
                                       stride=2, BatchNorm=BatchNorm)
        self.layer4 = self._make_layer(block, channels[3], layers[3],
                                       stride=2, BatchNorm=BatchNorm)
        self.layer5 = self._make_layer(block, channels[4], layers[4], dilation=2,
                                       new_level=False, BatchNorm=BatchNorm)
        if layers[5] == 0:
            self.layer6 = None
        else:
            self.layer6 = self._make_layer(block, channels[5], layers[5], dilation=4,
                                           new_level=False, BatchNorm=BatchNorm)

        # Stages 7-8 step the dilation back down to 2, then 1.
        if arch == 'C':
            if layers[6] == 0:
                self.layer7 = None
            else:
                self.layer7 = self._make_layer(BasicBlock, channels[6], layers[6],
                                               dilation=2, new_level=False,
                                               residual=False, BatchNorm=BatchNorm)
            if layers[7] == 0:
                self.layer8 = None
            else:
                self.layer8 = self._make_layer(BasicBlock, channels[7], layers[7],
                                               dilation=1, new_level=False,
                                               residual=False, BatchNorm=BatchNorm)
        elif arch == 'D':
            if layers[6] == 0:
                self.layer7 = None
            else:
                self.layer7 = self._make_conv_layers(channels[6], layers[6],
                                                     dilation=2, BatchNorm=BatchNorm)
            if layers[7] == 0:
                self.layer8 = None
            else:
                self.layer8 = self._make_conv_layers(channels[7], layers[7],
                                                     dilation=1, BatchNorm=BatchNorm)

        self._init_weight()

    def _init_weight(self):
        # Conv weights ~ N(0, sqrt(2 / (kh * kw * out_channels))); norm
        # layers get weight 1 and bias 0.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(module, (SynchronizedBatchNorm2d, nn.BatchNorm2d)):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
                    new_level=True, residual=True, BatchNorm=None):
        assert dilation == 1 or dilation % 2 == 0
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            # 1x1 projection so the shortcut matches the block's output shape.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm(out_planes),
            )

        # The first block of a new dilation level starts at half the dilation.
        if dilation == 1:
            first_dilation = (1, 1)
        elif new_level:
            first_dilation = (dilation // 2, dilation)
        else:
            first_dilation = (dilation, dilation)

        stage = [block(self.inplanes, planes, stride, downsample,
                       dilation=first_dilation, residual=residual,
                       BatchNorm=BatchNorm)]
        self.inplanes = out_planes
        stage.extend(
            block(self.inplanes, planes, residual=residual,
                  dilation=(dilation, dilation), BatchNorm=BatchNorm)
            for _ in range(1, blocks)
        )
        return nn.Sequential(*stage)

    def _make_conv_layers(self, channels, convs, stride=1, dilation=1, BatchNorm=None):
        # ``convs`` repetitions of conv3x3 -> norm -> ReLU; only the first
        # conv applies ``stride``.
        modules = []
        for index in range(convs):
            modules += [
                nn.Conv2d(self.inplanes, channels, kernel_size=3,
                          stride=stride if index == 0 else 1,
                          padding=dilation, bias=False, dilation=dilation),
                BatchNorm(channels),
                nn.ReLU(inplace=True),
            ]
            self.inplanes = channels
        return nn.Sequential(*modules)

    def forward(self, x):
        if self.arch == 'C':
            x = self.relu(self.bn1(self.conv1(x)))
        elif self.arch == 'D':
            x = self.layer0(x)

        x = self.layer2(self.layer1(x))
        low_level_feat = self.layer3(x)
        x = self.layer5(self.layer4(low_level_feat))

        # Optional trailing stages are None when disabled via ``layers``.
        for stage in (self.layer6, self.layer7, self.layer8):
            if stage is not None:
                x = stage(x)

        return x, low_level_feat
235
+
236
+
237
class DRN_A(nn.Module):
    """DRN-A: a ResNet-style network whose last two stages dilate instead of striding.

    It uses a stride-2 7x7 stem plus max-pool, then four ``block`` stages.
    ``layer3`` uses dilation 2 and ``layer4`` dilation 4, both with stride 1.
    Unlike ``DRN``, ``forward`` returns only the final feature map.

    Args:
        block: residual block class, which must expose ``expansion``.
        layers: four per-stage block counts.
        BatchNorm: normalization layer class.
    """

    def __init__(self, block, layers, BatchNorm=None):
        super().__init__()
        self.inplanes = 64
        self.out_dim = 512 * block.expansion
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = BatchNorm(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], BatchNorm=BatchNorm)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, BatchNorm=BatchNorm)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
                                       dilation=2, BatchNorm=BatchNorm)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
                                       dilation=4, BatchNorm=BatchNorm)

        self._init_weight()

    def _init_weight(self):
        # Conv weights ~ N(0, sqrt(2 / (kh * kw * out_channels))); norm
        # layers get weight 1 and bias 0.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(module, (SynchronizedBatchNorm2d, nn.BatchNorm2d)):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
        width = planes * block.expansion
        needs_projection = stride != 1 or self.inplanes != width
        downsample = nn.Sequential(
            nn.Conv2d(self.inplanes, width, kernel_size=1, stride=stride, bias=False),
            BatchNorm(width),
        ) if needs_projection else None

        # NOTE: the first block of each stage is built without ``dilation``;
        # only the remaining blocks receive it.
        stage = [block(self.inplanes, planes, stride, downsample, BatchNorm=BatchNorm)]
        self.inplanes = width
        for _ in range(1, blocks):
            stage.append(block(self.inplanes, planes,
                               dilation=(dilation, dilation), BatchNorm=BatchNorm))

        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return x
299
+
300
def drn_a_50(BatchNorm, pretrained=True):
    """Build DRN-A-50 (``Bottleneck`` blocks, stage sizes [3, 4, 6, 3]).

    When ``pretrained`` is true, the ResNet-50 checkpoint at
    ``model_urls['resnet50']`` is downloaded and loaded. ``DRN_A`` defines no
    ``fc`` classifier, so the checkpoint's ``fc.weight``/``fc.bias`` entries
    are dropped first; the default strict ``load_state_dict`` would otherwise
    reject them as unexpected keys.
    """
    model = DRN_A(Bottleneck, [3, 4, 6, 3], BatchNorm=BatchNorm)
    if pretrained:
        state_dict = model_zoo.load_url(model_urls['resnet50'])
        # pop(..., None) also tolerates checkpoints that were already stripped.
        for key in ('fc.weight', 'fc.bias'):
            state_dict.pop(key, None)
        model.load_state_dict(state_dict)
    return model
305
+
306
+
307
def drn_c_26(BatchNorm, pretrained=True):
    """Build DRN-C-26 (``BasicBlock``, stage sizes [1, 1, 2, 2, 2, 2, 1, 1]).

    If ``pretrained``, weights from ``model_urls['drn-c-26']`` are loaded after
    deleting the checkpoint's ``fc.weight``/``fc.bias`` classifier entries,
    which ``DRN`` does not define.
    """
    net = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='C', BatchNorm=BatchNorm)
    if not pretrained:
        return net
    state_dict = model_zoo.load_url(model_urls['drn-c-26'])
    for key in ('fc.weight', 'fc.bias'):
        del state_dict[key]
    net.load_state_dict(state_dict)
    return net
315
+
316
+
317
def drn_c_42(BatchNorm, pretrained=True):
    """Build DRN-C-42 (``BasicBlock``, stage sizes [1, 1, 3, 4, 6, 3, 1, 1]).

    If ``pretrained``, weights from ``model_urls['drn-c-42']`` are loaded after
    deleting the checkpoint's ``fc.weight``/``fc.bias`` classifier entries,
    which ``DRN`` does not define.
    """
    net = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='C', BatchNorm=BatchNorm)
    if not pretrained:
        return net
    state_dict = model_zoo.load_url(model_urls['drn-c-42'])
    for key in ('fc.weight', 'fc.bias'):
        del state_dict[key]
    net.load_state_dict(state_dict)
    return net
325
+
326
+
327
def drn_c_58(BatchNorm, pretrained=True):
    """Build DRN-C-58 (``Bottleneck``, stage sizes [1, 1, 3, 4, 6, 3, 1, 1]).

    If ``pretrained``, weights from ``model_urls['drn-c-58']`` are loaded after
    deleting the checkpoint's ``fc.weight``/``fc.bias`` classifier entries,
    which ``DRN`` does not define.
    """
    net = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch='C', BatchNorm=BatchNorm)
    if not pretrained:
        return net
    state_dict = model_zoo.load_url(model_urls['drn-c-58'])
    for key in ('fc.weight', 'fc.bias'):
        del state_dict[key]
    net.load_state_dict(state_dict)
    return net
335
+
336
+
337
def drn_d_22(BatchNorm, pretrained=True):
    """Build DRN-D-22 (``BasicBlock``, stage sizes [1, 1, 2, 2, 2, 2, 1, 1]).

    If ``pretrained``, weights from ``model_urls['drn-d-22']`` are loaded after
    deleting the checkpoint's ``fc.weight``/``fc.bias`` classifier entries,
    which ``DRN`` does not define.
    """
    net = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='D', BatchNorm=BatchNorm)
    if not pretrained:
        return net
    state_dict = model_zoo.load_url(model_urls['drn-d-22'])
    for key in ('fc.weight', 'fc.bias'):
        del state_dict[key]
    net.load_state_dict(state_dict)
    return net
345
+
346
+
347
def drn_d_24(BatchNorm, pretrained=True):
    """Build DRN-D-24 (``BasicBlock``, stage sizes [1, 1, 2, 2, 2, 2, 2, 2]).

    ``model_urls`` has no ``'drn-d-24'`` entry, so the default
    ``pretrained=True`` cannot be satisfied. Previously this surfaced as a bare
    KeyError only after the whole network had been built. Now a KeyError that
    explains the problem is raised up front. Pass ``pretrained=False`` for
    randomly initialised weights. If a URL is registered later, it is used,
    and the checkpoint's ``fc.*`` entries are dropped before loading.
    """
    url = model_urls.get('drn-d-24')
    if pretrained and url is None:
        raise KeyError("no pretrained checkpoint registered in model_urls for "
                       "'drn-d-24'; call drn_d_24(BatchNorm, pretrained=False)")
    model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 2, 2], arch='D', BatchNorm=BatchNorm)
    if pretrained:
        state_dict = model_zoo.load_url(url)
        del state_dict['fc.weight']
        del state_dict['fc.bias']
        model.load_state_dict(state_dict)
    return model
355
+
356
+
357
def drn_d_38(BatchNorm, pretrained=True):
    """Build DRN-D-38 (``BasicBlock``, stage sizes [1, 1, 3, 4, 6, 3, 1, 1]).

    If ``pretrained``, weights from ``model_urls['drn-d-38']`` are loaded after
    deleting the checkpoint's ``fc.weight``/``fc.bias`` classifier entries,
    which ``DRN`` does not define.
    """
    net = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='D', BatchNorm=BatchNorm)
    if not pretrained:
        return net
    state_dict = model_zoo.load_url(model_urls['drn-d-38'])
    for key in ('fc.weight', 'fc.bias'):
        del state_dict[key]
    net.load_state_dict(state_dict)
    return net
365
+
366
+
367
def drn_d_40(BatchNorm, pretrained=True):
    """Build DRN-D-40 (``BasicBlock``, stage sizes [1, 1, 3, 4, 6, 3, 2, 2]).

    ``model_urls`` has no ``'drn-d-40'`` entry, so the default
    ``pretrained=True`` cannot be satisfied. Previously this surfaced as a bare
    KeyError only after the whole network had been built. Now a KeyError that
    explains the problem is raised up front. Pass ``pretrained=False`` for
    randomly initialised weights. If a URL is registered later, it is used,
    and the checkpoint's ``fc.*`` entries are dropped before loading.
    """
    url = model_urls.get('drn-d-40')
    if pretrained and url is None:
        raise KeyError("no pretrained checkpoint registered in model_urls for "
                       "'drn-d-40'; call drn_d_40(BatchNorm, pretrained=False)")
    model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 2, 2], arch='D', BatchNorm=BatchNorm)
    if pretrained:
        state_dict = model_zoo.load_url(url)
        del state_dict['fc.weight']
        del state_dict['fc.bias']
        model.load_state_dict(state_dict)
    return model
375
+
376
+
377
def drn_d_54(BatchNorm, pretrained=True):
    """Build DRN-D-54 (``Bottleneck``, stage sizes [1, 1, 3, 4, 6, 3, 1, 1]).

    If ``pretrained``, weights from ``model_urls['drn-d-54']`` are loaded after
    deleting the checkpoint's ``fc.weight``/``fc.bias`` classifier entries,
    which ``DRN`` does not define.
    """
    net = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch='D', BatchNorm=BatchNorm)
    if not pretrained:
        return net
    state_dict = model_zoo.load_url(model_urls['drn-d-54'])
    for key in ('fc.weight', 'fc.bias'):
        del state_dict[key]
    net.load_state_dict(state_dict)
    return net
385
+
386
+
387
def drn_d_105(BatchNorm, pretrained=True):
    """Build DRN-D-105 (``Bottleneck``, stage sizes [1, 1, 3, 4, 23, 3, 1, 1]).

    If ``pretrained``, weights from ``model_urls['drn-d-105']`` are loaded
    after deleting the checkpoint's ``fc.weight``/``fc.bias`` classifier
    entries, which ``DRN`` does not define.
    """
    net = DRN(Bottleneck, [1, 1, 3, 4, 23, 3, 1, 1], arch='D', BatchNorm=BatchNorm)
    if not pretrained:
        return net
    state_dict = model_zoo.load_url(model_urls['drn-d-105'])
    for key in ('fc.weight', 'fc.bias'):
        del state_dict[key]
    net.load_state_dict(state_dict)
    return net
395
+
396
if __name__ == "__main__":
    import torch

    # Smoke test: print feature-map shapes for a random 512x512 input.
    # pretrained=False keeps it offline and avoids checkpoint-loading issues.
    dummy = torch.rand(1, 3, 512, 512)

    with torch.no_grad():
        # DRN_A.forward returns a single tensor. Unpacking it into two names,
        # as the old smoke test did, raises ValueError.
        model = drn_a_50(BatchNorm=nn.BatchNorm2d, pretrained=False)
        print(model(dummy).size())

        # DRN variants return (features, low_level_feat).
        model = drn_d_54(BatchNorm=nn.BatchNorm2d, pretrained=False)
        output, low_level_feat = model(dummy)
        print(output.size())
        print(low_level_feat.size())