doctra 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. doctra/__init__.py +4 -0
  2. doctra/cli/main.py +168 -0
  3. doctra/engines/image_restoration/__init__.py +10 -0
  4. doctra/engines/image_restoration/docres_engine.py +566 -0
  5. doctra/engines/vlm/service.py +0 -12
  6. doctra/parsers/enhanced_pdf_parser.py +370 -0
  7. doctra/parsers/structured_pdf_parser.py +11 -60
  8. doctra/parsers/table_chart_extractor.py +8 -44
  9. doctra/third_party/docres/data/MBD/MBD.py +110 -0
  10. doctra/third_party/docres/data/MBD/MBD_utils.py +291 -0
  11. doctra/third_party/docres/data/MBD/infer.py +151 -0
  12. doctra/third_party/docres/data/MBD/model/deep_lab_model/aspp.py +95 -0
  13. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/__init__.py +13 -0
  14. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/drn.py +402 -0
  15. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/mobilenet.py +151 -0
  16. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/resnet.py +170 -0
  17. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/xception.py +288 -0
  18. doctra/third_party/docres/data/MBD/model/deep_lab_model/decoder.py +59 -0
  19. doctra/third_party/docres/data/MBD/model/deep_lab_model/deeplab.py +81 -0
  20. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/__init__.py +12 -0
  21. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/batchnorm.py +282 -0
  22. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/comm.py +129 -0
  23. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/replicate.py +88 -0
  24. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/unittest.py +29 -0
  25. doctra/third_party/docres/data/preprocess/crop_merge_image.py +142 -0
  26. doctra/third_party/docres/inference.py +370 -0
  27. doctra/third_party/docres/models/restormer_arch.py +308 -0
  28. doctra/third_party/docres/utils.py +464 -0
  29. doctra/ui/app.py +5 -32
  30. doctra/utils/progress.py +13 -98
  31. doctra/utils/structured_utils.py +45 -49
  32. doctra/version.py +1 -1
  33. {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/METADATA +1 -1
  34. doctra-0.4.0.dist-info/RECORD +67 -0
  35. doctra-0.3.2.dist-info/RECORD +0 -44
  36. {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/WHEEL +0 -0
  37. {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/licenses/LICENSE +0 -0
  38. {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/top_level.txt +0 -0

doctra/third_party/docres/data/MBD/model/deep_lab_model/deeplab.py
@@ -0,0 +1,81 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
+ from model.deep_lab_model.aspp import build_aspp
+ from model.deep_lab_model.decoder import build_decoder
+ from model.deep_lab_model.backbone import build_backbone
+
+ class DeepLab(nn.Module):
+     def __init__(self, backbone='resnet', output_stride=16, num_classes=21,
+                  sync_bn=True, freeze_bn=False):
+         super(DeepLab, self).__init__()
+         if backbone == 'drn':
+             output_stride = 8
+
+         if sync_bn == True:
+             BatchNorm = SynchronizedBatchNorm2d
+         else:
+             BatchNorm = nn.BatchNorm2d
+
+         self.backbone = build_backbone(backbone, output_stride, BatchNorm)
+         self.aspp = build_aspp(backbone, output_stride, BatchNorm)
+         self.decoder = build_decoder(num_classes, backbone, BatchNorm)
+
+         self.freeze_bn = freeze_bn
+
+     def forward(self, input):
+         x, low_level_feat = self.backbone(input)
+         x = self.aspp(x)
+         x = self.decoder(x, low_level_feat)
+         x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)
+
+         return x
+
+     def freeze_bn(self):
+         for m in self.modules():
+             if isinstance(m, SynchronizedBatchNorm2d):
+                 m.eval()
+             elif isinstance(m, nn.BatchNorm2d):
+                 m.eval()
+
+     def get_1x_lr_params(self):
+         modules = [self.backbone]
+         for i in range(len(modules)):
+             for m in modules[i].named_modules():
+                 if self.freeze_bn:
+                     if isinstance(m[1], nn.Conv2d):
+                         for p in m[1].parameters():
+                             if p.requires_grad:
+                                 yield p
+                 else:
+                     if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
+                             or isinstance(m[1], nn.BatchNorm2d):
+                         for p in m[1].parameters():
+                             if p.requires_grad:
+                                 yield p
+
+     def get_10x_lr_params(self):
+         modules = [self.aspp, self.decoder]
+         for i in range(len(modules)):
+             for m in modules[i].named_modules():
+                 if self.freeze_bn:
+                     if isinstance(m[1], nn.Conv2d):
+                         for p in m[1].parameters():
+                             if p.requires_grad:
+                                 yield p
+                 else:
+                     if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
+                             or isinstance(m[1], nn.BatchNorm2d):
+                         for p in m[1].parameters():
+                             if p.requires_grad:
+                                 yield p
+
+ if __name__ == "__main__":
+     model = DeepLab(backbone='mobilenet', output_stride=16)
+     model.eval()
+     input = torch.rand(1, 3, 513, 513)
+     output = model(input)
+     print(output.size())
+
+
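For context, this vendored file is the standard DeepLab v3+ wrapper: the backbone returns high-level and low-level features, ASPP processes the high-level features, and the decoder fuses both before bilinear upsampling back to the input resolution. The two generator methods exist to feed an optimizer with the conventional 1x/10x learning-rate split; note that `self.freeze_bn = freeze_bn` in `__init__` shadows the `freeze_bn()` method, so only the boolean attribute is reachable on instances. A minimal usage sketch (assuming the `MBD` directory is on `sys.path` so the file's absolute `model.deep_lab_model.*` imports resolve, and that the chosen backbone builds without downloading pretrained weights; these are assumptions, not guarantees of this release):

import torch

# Assumed import path: the vendored file imports itself via `model.deep_lab_model.*`,
# so that tree must be importable as a top-level package.
from model.deep_lab_model.deeplab import DeepLab

model = DeepLab(backbone='mobilenet', output_stride=16, num_classes=21, sync_bn=False)

# Conventional DeepLab optimizer setup: backbone at the base LR,
# ASPP + decoder at 10x the base LR, using the two generators above.
base_lr = 0.007
optimizer = torch.optim.SGD(
    [{'params': model.get_1x_lr_params(), 'lr': base_lr},
     {'params': model.get_10x_lr_params(), 'lr': base_lr * 10}],
    momentum=0.9, weight_decay=5e-4)

model.eval()
with torch.no_grad():
    out = model(torch.rand(1, 3, 513, 513))
print(out.size())  # expected: torch.Size([1, 21, 513, 513])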

doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/__init__.py
@@ -0,0 +1,12 @@
+ # -*- coding: utf-8 -*-
+ # File : __init__.py
+ # Author : Jiayuan Mao
+ # Email : maojiayuan@gmail.com
+ # Date : 27/01/2018
+ #
+ # This file is part of Synchronized-BatchNorm-PyTorch.
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+ # Distributed under MIT License.
+
+ from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
+ from .replicate import DataParallelWithCallback, patch_replication_callback

doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/batchnorm.py
@@ -0,0 +1,282 @@
+ # -*- coding: utf-8 -*-
+ # File : batchnorm.py
+ # Author : Jiayuan Mao
+ # Email : maojiayuan@gmail.com
+ # Date : 27/01/2018
+ #
+ # This file is part of Synchronized-BatchNorm-PyTorch.
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+ # Distributed under MIT License.
+
+ import collections
+
+ import torch
+ import torch.nn.functional as F
+
+ from torch.nn.modules.batchnorm import _BatchNorm
+ from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
+
+ from .comm import SyncMaster
+
+ __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
+
+
+ def _sum_ft(tensor):
+     """sum over the first and last dimention"""
+     return tensor.sum(dim=0).sum(dim=-1)
+
+
+ def _unsqueeze_ft(tensor):
+     """add new dementions at the front and the tail"""
+     return tensor.unsqueeze(0).unsqueeze(-1)
+
+
+ _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
+ _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
+
+
+ class _SynchronizedBatchNorm(_BatchNorm):
+     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
+         super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
+
+         self._sync_master = SyncMaster(self._data_parallel_master)
+
+         self._is_parallel = False
+         self._parallel_id = None
+         self._slave_pipe = None
+
+     def forward(self, input):
+         # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
+         if not (self._is_parallel and self.training):
+             return F.batch_norm(
+                 input, self.running_mean, self.running_var, self.weight, self.bias,
+                 self.training, self.momentum, self.eps)
+
+         # Resize the input to (B, C, -1).
+         input_shape = input.size()
+         input = input.view(input.size(0), self.num_features, -1)
+
+         # Compute the sum and square-sum.
+         sum_size = input.size(0) * input.size(2)
+         input_sum = _sum_ft(input)
+         input_ssum = _sum_ft(input ** 2)
+
+         # Reduce-and-broadcast the statistics.
+         if self._parallel_id == 0:
+             mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
+         else:
+             mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
+
+         # Compute the output.
+         if self.affine:
+             # MJY:: Fuse the multiplication for speed.
+             output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
+         else:
+             output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
+
+         # Reshape it.
+         return output.view(input_shape)
+
+     def __data_parallel_replicate__(self, ctx, copy_id):
+         self._is_parallel = True
+         self._parallel_id = copy_id
+
+         # parallel_id == 0 means master device.
+         if self._parallel_id == 0:
+             ctx.sync_master = self._sync_master
+         else:
+             self._slave_pipe = ctx.sync_master.register_slave(copy_id)
+
+     def _data_parallel_master(self, intermediates):
+         """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
+
+         # Always using same "device order" makes the ReduceAdd operation faster.
+         # Thanks to:: Tete Xiao (http://tetexiao.com/)
+         intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
+
+         to_reduce = [i[1][:2] for i in intermediates]
+         to_reduce = [j for i in to_reduce for j in i]  # flatten
+         target_gpus = [i[1].sum.get_device() for i in intermediates]
+
+         sum_size = sum([i[1].sum_size for i in intermediates])
+         sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
+         mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
+
+         broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
+
+         outputs = []
+         for i, rec in enumerate(intermediates):
+             outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))
+
+         return outputs
+
+     def _compute_mean_std(self, sum_, ssum, size):
+         """Compute the mean and standard-deviation with sum and square-sum. This method
+         also maintains the moving average on the master device."""
+         assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
+         mean = sum_ / size
+         sumvar = ssum - sum_ * mean
+         unbias_var = sumvar / (size - 1)
+         bias_var = sumvar / size
+
+         self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
+         self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
+
+         return mean, bias_var.clamp(self.eps) ** -0.5
+
+
+ class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
+     r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
+     mini-batch.
+     .. math::
+         y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
+     This module differs from the built-in PyTorch BatchNorm1d as the mean and
+     standard-deviation are reduced across all devices during training.
+     For example, when one uses `nn.DataParallel` to wrap the network during
+     training, PyTorch's implementation normalize the tensor on each device using
+     the statistics only on that device, which accelerated the computation and
+     is also easy to implement, but the statistics might be inaccurate.
+     Instead, in this synchronized version, the statistics will be computed
+     over all training samples distributed on multiple devices.
+
+     Note that, for one-GPU or CPU-only case, this module behaves exactly same
+     as the built-in PyTorch implementation.
+     The mean and standard-deviation are calculated per-dimension over
+     the mini-batches and gamma and beta are learnable parameter vectors
+     of size C (where C is the input size).
+     During training, this layer keeps a running estimate of its computed mean
+     and variance. The running sum is kept with a default momentum of 0.1.
+     During evaluation, this running mean/variance is used for normalization.
+     Because the BatchNorm is done over the `C` dimension, computing statistics
+     on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
+     Args:
+         num_features: num_features from an expected input of size
+             `batch_size x num_features [x width]`
+         eps: a value added to the denominator for numerical stability.
+             Default: 1e-5
+         momentum: the value used for the running_mean and running_var
+             computation. Default: 0.1
+         affine: a boolean value that when set to ``True``, gives the layer learnable
+             affine parameters. Default: ``True``
+     Shape:
+         - Input: :math:`(N, C)` or :math:`(N, C, L)`
+         - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
+     Examples:
+         >>> # With Learnable Parameters
+         >>> m = SynchronizedBatchNorm1d(100)
+         >>> # Without Learnable Parameters
+         >>> m = SynchronizedBatchNorm1d(100, affine=False)
+         >>> input = torch.autograd.Variable(torch.randn(20, 100))
+         >>> output = m(input)
+     """
+
+     def _check_input_dim(self, input):
+         if input.dim() != 2 and input.dim() != 3:
+             raise ValueError('expected 2D or 3D input (got {}D input)'
+                              .format(input.dim()))
+         super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
+
+
+ class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
+     r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
+     of 3d inputs
+     .. math::
+         y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
+     This module differs from the built-in PyTorch BatchNorm2d as the mean and
+     standard-deviation are reduced across all devices during training.
+     For example, when one uses `nn.DataParallel` to wrap the network during
+     training, PyTorch's implementation normalize the tensor on each device using
+     the statistics only on that device, which accelerated the computation and
+     is also easy to implement, but the statistics might be inaccurate.
+     Instead, in this synchronized version, the statistics will be computed
+     over all training samples distributed on multiple devices.
+
+     Note that, for one-GPU or CPU-only case, this module behaves exactly same
+     as the built-in PyTorch implementation.
+     The mean and standard-deviation are calculated per-dimension over
+     the mini-batches and gamma and beta are learnable parameter vectors
+     of size C (where C is the input size).
+     During training, this layer keeps a running estimate of its computed mean
+     and variance. The running sum is kept with a default momentum of 0.1.
+     During evaluation, this running mean/variance is used for normalization.
+     Because the BatchNorm is done over the `C` dimension, computing statistics
+     on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
+     Args:
+         num_features: num_features from an expected input of
+             size batch_size x num_features x height x width
+         eps: a value added to the denominator for numerical stability.
+             Default: 1e-5
+         momentum: the value used for the running_mean and running_var
+             computation. Default: 0.1
+         affine: a boolean value that when set to ``True``, gives the layer learnable
+             affine parameters. Default: ``True``
+     Shape:
+         - Input: :math:`(N, C, H, W)`
+         - Output: :math:`(N, C, H, W)` (same shape as input)
+     Examples:
+         >>> # With Learnable Parameters
+         >>> m = SynchronizedBatchNorm2d(100)
+         >>> # Without Learnable Parameters
+         >>> m = SynchronizedBatchNorm2d(100, affine=False)
+         >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
+         >>> output = m(input)
+     """
+
+     def _check_input_dim(self, input):
+         if input.dim() != 4:
+             raise ValueError('expected 4D input (got {}D input)'
+                              .format(input.dim()))
+         super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
+
+
+ class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
+     r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
+     of 4d inputs
+     .. math::
+         y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
+     This module differs from the built-in PyTorch BatchNorm3d as the mean and
+     standard-deviation are reduced across all devices during training.
+     For example, when one uses `nn.DataParallel` to wrap the network during
+     training, PyTorch's implementation normalize the tensor on each device using
+     the statistics only on that device, which accelerated the computation and
+     is also easy to implement, but the statistics might be inaccurate.
+     Instead, in this synchronized version, the statistics will be computed
+     over all training samples distributed on multiple devices.
+
+     Note that, for one-GPU or CPU-only case, this module behaves exactly same
+     as the built-in PyTorch implementation.
+     The mean and standard-deviation are calculated per-dimension over
+     the mini-batches and gamma and beta are learnable parameter vectors
+     of size C (where C is the input size).
+     During training, this layer keeps a running estimate of its computed mean
+     and variance. The running sum is kept with a default momentum of 0.1.
+     During evaluation, this running mean/variance is used for normalization.
+     Because the BatchNorm is done over the `C` dimension, computing statistics
+     on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
+     or Spatio-temporal BatchNorm
+     Args:
+         num_features: num_features from an expected input of
+             size batch_size x num_features x depth x height x width
+         eps: a value added to the denominator for numerical stability.
+             Default: 1e-5
+         momentum: the value used for the running_mean and running_var
+             computation. Default: 0.1
+         affine: a boolean value that when set to ``True``, gives the layer learnable
+             affine parameters. Default: ``True``
+     Shape:
+         - Input: :math:`(N, C, D, H, W)`
+         - Output: :math:`(N, C, D, H, W)` (same shape as input)
+     Examples:
+         >>> # With Learnable Parameters
+         >>> m = SynchronizedBatchNorm3d(100)
+         >>> # Without Learnable Parameters
+         >>> m = SynchronizedBatchNorm3d(100, affine=False)
+         >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
+         >>> output = m(input)
+     """
+
+     def _check_input_dim(self, input):
+         if input.dim() != 5:
+             raise ValueError('expected 5D input (got {}D input)'
+                              .format(input.dim()))
+         super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
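The three `SynchronizedBatchNorm*d` classes share `_SynchronizedBatchNorm.forward` above: on a single device or in eval mode they defer to `F.batch_norm`, while under data parallelism each replica sends its per-device sum and square-sum to the master, which computes global statistics and broadcasts the mean and inverse std back. A minimal sketch of how the layer is typically wrapped (a sketch only; it assumes at least two CUDA devices and that the vendored path below is importable as a package):

import torch
import torch.nn as nn

# Assumed vendored import path added by this release.
from doctra.third_party.docres.data.MBD.model.deep_lab_model.sync_batchnorm import (
    SynchronizedBatchNorm2d, DataParallelWithCallback)

net = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    SynchronizedBatchNorm2d(16),
    nn.ReLU(inplace=True),
)

x = torch.randn(8, 3, 64, 64)
if torch.cuda.device_count() >= 2:
    # The callback-aware wrapper invokes __data_parallel_replicate__ on every
    # replica, wiring each copy to the shared SyncMaster so batch statistics
    # are reduced across devices during training.
    net = DataParallelWithCallback(net.cuda(), device_ids=[0, 1])
    x = x.cuda()
out = net(x)  # on one GPU or CPU the layer simply falls back to F.batch_norm
print(out.shape)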

doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/comm.py
@@ -0,0 +1,129 @@
+ # -*- coding: utf-8 -*-
+ # File : comm.py
+ # Author : Jiayuan Mao
+ # Email : maojiayuan@gmail.com
+ # Date : 27/01/2018
+ #
+ # This file is part of Synchronized-BatchNorm-PyTorch.
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+ # Distributed under MIT License.
+
+ import queue
+ import collections
+ import threading
+
+ __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
+
+
+ class FutureResult(object):
+     """A thread-safe future implementation. Used only as one-to-one pipe."""
+
+     def __init__(self):
+         self._result = None
+         self._lock = threading.Lock()
+         self._cond = threading.Condition(self._lock)
+
+     def put(self, result):
+         with self._lock:
+             assert self._result is None, 'Previous result has\'t been fetched.'
+             self._result = result
+             self._cond.notify()
+
+     def get(self):
+         with self._lock:
+             if self._result is None:
+                 self._cond.wait()
+
+             res = self._result
+             self._result = None
+             return res
+
+
+ _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
+ _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
+
+
+ class SlavePipe(_SlavePipeBase):
+     """Pipe for master-slave communication."""
+
+     def run_slave(self, msg):
+         self.queue.put((self.identifier, msg))
+         ret = self.result.get()
+         self.queue.put(True)
+         return ret
+
+
+ class SyncMaster(object):
+     """An abstract `SyncMaster` object.
+     - During the replication, as the data parallel will trigger an callback of each module, all slave devices should
+     call `register(id)` and obtain an `SlavePipe` to communicate with the master.
+     - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
+     and passed to a registered callback.
+     - After receiving the messages, the master device should gather the information and determine to message passed
+     back to each slave devices.
+     """
+
+     def __init__(self, master_callback):
+         """
+         Args:
+             master_callback: a callback to be invoked after having collected messages from slave devices.
+         """
+         self._master_callback = master_callback
+         self._queue = queue.Queue()
+         self._registry = collections.OrderedDict()
+         self._activated = False
+
+     def __getstate__(self):
+         return {'master_callback': self._master_callback}
+
+     def __setstate__(self, state):
+         self.__init__(state['master_callback'])
+
+     def register_slave(self, identifier):
+         """
+         Register an slave device.
+         Args:
+             identifier: an identifier, usually is the device id.
+         Returns: a `SlavePipe` object which can be used to communicate with the master device.
+         """
+         if self._activated:
+             assert self._queue.empty(), 'Queue is not clean before next initialization.'
+             self._activated = False
+             self._registry.clear()
+         future = FutureResult()
+         self._registry[identifier] = _MasterRegistry(future)
+         return SlavePipe(identifier, self._queue, future)
+
+     def run_master(self, master_msg):
+         """
+         Main entry for the master device in each forward pass.
+         The messages were first collected from each devices (including the master device), and then
+         an callback will be invoked to compute the message to be sent back to each devices
+         (including the master device).
+         Args:
+             master_msg: the message that the master want to send to itself. This will be placed as the first
+             message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
+         Returns: the message to be sent back to the master device.
+         """
+         self._activated = True
+
+         intermediates = [(0, master_msg)]
+         for i in range(self.nr_slaves):
+             intermediates.append(self._queue.get())
+
+         results = self._master_callback(intermediates)
+         assert results[0][0] == 0, 'The first result should belongs to the master.'
+
+         for i, res in results:
+             if i == 0:
+                 continue
+             self._registry[i].result.put(res)
+
+         for i in range(self.nr_slaves):
+             assert self._queue.get() is True
+
+         return results[0][1]
+
+     @property
+     def nr_slaves(self):
+         return len(self._registry)
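`SyncMaster` and `SlavePipe` implement the queue-and-future handshake used by the batch-norm layers: each slave pushes its message onto the master's queue, the master gathers one message per registered slave plus its own, runs the registered callback, and pushes each slave's share of the result back through its `FutureResult`. A stripped-down round trip with one slave thread (the doubling callback and the import path below are illustrative assumptions):

import threading

# Assumed vendored import path added by this release.
from doctra.third_party.docres.data.MBD.model.deep_lab_model.sync_batchnorm.comm import SyncMaster

# The master callback receives [(copy_id, msg), ...] with the master's entry first
# and must return one (copy_id, result) pair per replica, master first.
master = SyncMaster(lambda intermediates: [(i, msg * 2) for i, msg in intermediates])
pipe = master.register_slave(1)  # one slave replica with id 1

slave_result = {}

def slave():
    # Blocks until run_master() has executed the callback and pushed our result back.
    slave_result['value'] = pipe.run_slave(10)

t = threading.Thread(target=slave)
t.start()
master_result = master.run_master(5)  # collects the slave message, runs the callback
t.join()

print(master_result, slave_result['value'])  # 10 20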

doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/replicate.py
@@ -0,0 +1,88 @@
+ # -*- coding: utf-8 -*-
+ # File : replicate.py
+ # Author : Jiayuan Mao
+ # Email : maojiayuan@gmail.com
+ # Date : 27/01/2018
+ #
+ # This file is part of Synchronized-BatchNorm-PyTorch.
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+ # Distributed under MIT License.
+
+ import functools
+
+ from torch.nn.parallel.data_parallel import DataParallel
+
+ __all__ = [
+     'CallbackContext',
+     'execute_replication_callbacks',
+     'DataParallelWithCallback',
+     'patch_replication_callback'
+ ]
+
+
+ class CallbackContext(object):
+     pass
+
+
+ def execute_replication_callbacks(modules):
+     """
+     Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.
+     The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
+     Note that, as all modules are isomorphism, we assign each sub-module with a context
+     (shared among multiple copies of this module on different devices).
+     Through this context, different copies can share some information.
+     We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
+     of any slave copies.
+     """
+     master_copy = modules[0]
+     nr_modules = len(list(master_copy.modules()))
+     ctxs = [CallbackContext() for _ in range(nr_modules)]
+
+     for i, module in enumerate(modules):
+         for j, m in enumerate(module.modules()):
+             if hasattr(m, '__data_parallel_replicate__'):
+                 m.__data_parallel_replicate__(ctxs[j], i)
+
+
+ class DataParallelWithCallback(DataParallel):
+     """
+     Data Parallel with a replication callback.
+     An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
+     original `replicate` function.
+     The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
+     Examples:
+         > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+         > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
+         # sync_bn.__data_parallel_replicate__ will be invoked.
+     """
+
+     def replicate(self, module, device_ids):
+         modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
+         execute_replication_callbacks(modules)
+         return modules
+
+
+ def patch_replication_callback(data_parallel):
+     """
+     Monkey-patch an existing `DataParallel` object. Add the replication callback.
+     Useful when you have customized `DataParallel` implementation.
+     Examples:
+         > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+         > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
+         > patch_replication_callback(sync_bn)
+         # this is equivalent to
+         > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+         > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
+     """
+
+     assert isinstance(data_parallel, DataParallel)
+
+     old_replicate = data_parallel.replicate
+
+     @functools.wraps(old_replicate)
+     def new_replicate(module, device_ids):
+         modules = old_replicate(module, device_ids)
+         execute_replication_callbacks(modules)
+         return modules
+
+     data_parallel.replicate = new_replicate
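`execute_replication_callbacks` is what makes the synchronization work under `nn.DataParallel`-style replication: every sub-module of every replica that defines `__data_parallel_replicate__` receives a context object shared across replicas, and the master copy (copy id 0) is always called first. The behaviour can be observed without any GPU by invoking it directly on two stand-in replicas (the `Probe` module and import path below are my own illustrative assumptions, not part of the package):

import torch.nn as nn

# Assumed vendored import path added by this release.
from doctra.third_party.docres.data.MBD.model.deep_lab_model.sync_batchnorm.replicate import (
    execute_replication_callbacks)

class Probe(nn.Module):
    """Toy module that records the shared context and its copy id."""
    def __data_parallel_replicate__(self, ctx, copy_id):
        self.ctx, self.copy_id = ctx, copy_id

replicas = [Probe(), Probe()]            # stand-ins for the DataParallel replicas
execute_replication_callbacks(replicas)  # what DataParallelWithCallback adds on replicate()

print(replicas[0].copy_id, replicas[1].copy_id)  # 0 1  (copy 0 is the master)
print(replicas[0].ctx is replicas[1].ctx)        # True (one context shared per sub-module)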

doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/unittest.py
@@ -0,0 +1,29 @@
+ # -*- coding: utf-8 -*-
+ # File : unittest.py
+ # Author : Jiayuan Mao
+ # Email : maojiayuan@gmail.com
+ # Date : 27/01/2018
+ #
+ # This file is part of Synchronized-BatchNorm-PyTorch.
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+ # Distributed under MIT License.
+
+ import unittest
+
+ import numpy as np
+ from torch.autograd import Variable
+
+
+ def as_numpy(v):
+     if isinstance(v, Variable):
+         v = v.data
+     return v.cpu().numpy()
+
+
+ class TorchTestCase(unittest.TestCase):
+     def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
+         npa, npb = as_numpy(a), as_numpy(b)
+         self.assertTrue(
+             np.allclose(npa, npb, atol=atol),
+             'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
+         )
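`TorchTestCase.assertTensorClose` compares two tensors element-wise after converting them to NumPy. As a quick sanity check of the vendored layers, one could assert that `SynchronizedBatchNorm2d` matches the built-in `nn.BatchNorm2d` when no replication is involved, since the forward pass then falls back to `F.batch_norm`. A sketch only; the import paths are assumptions about where the vendored modules land, and the test class is mine, not shipped with the package:

import unittest

import torch
import torch.nn as nn

# Assumed vendored import paths added by this release.
from doctra.third_party.docres.data.MBD.model.deep_lab_model.sync_batchnorm import SynchronizedBatchNorm2d
from doctra.third_party.docres.data.MBD.model.deep_lab_model.sync_batchnorm.unittest import TorchTestCase

class SyncBNFallbackTest(TorchTestCase):
    def test_matches_builtin_bn_on_a_single_device(self):
        sync_bn = SynchronizedBatchNorm2d(8)
        plain_bn = nn.BatchNorm2d(8)
        plain_bn.load_state_dict(sync_bn.state_dict())  # identical affine params and stats
        x = torch.randn(4, 8, 16, 16)
        # Both layers are in training mode and normalize with batch statistics,
        # so their outputs should agree to within the default tolerance.
        self.assertTensorClose(sync_bn(x), plain_bn(x))

if __name__ == '__main__':
    unittest.main()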