docling-ibm-models 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docling_ibm_models/layoutmodel/layout_predictor.py +171 -0
- docling_ibm_models/tableformer/__init__.py +0 -0
- docling_ibm_models/tableformer/common.py +200 -0
- docling_ibm_models/tableformer/data_management/__init__.py +0 -0
- docling_ibm_models/tableformer/data_management/data_transformer.py +504 -0
- docling_ibm_models/tableformer/data_management/functional.py +574 -0
- docling_ibm_models/tableformer/data_management/matching_post_processor.py +1325 -0
- docling_ibm_models/tableformer/data_management/tf_cell_matcher.py +596 -0
- docling_ibm_models/tableformer/data_management/tf_dataset.py +1233 -0
- docling_ibm_models/tableformer/data_management/tf_predictor.py +1020 -0
- docling_ibm_models/tableformer/data_management/transforms.py +396 -0
- docling_ibm_models/tableformer/models/__init__.py +0 -0
- docling_ibm_models/tableformer/models/common/__init__.py +0 -0
- docling_ibm_models/tableformer/models/common/base_model.py +279 -0
- docling_ibm_models/tableformer/models/table04_rs/__init__.py +0 -0
- docling_ibm_models/tableformer/models/table04_rs/bbox_decoder_rs.py +163 -0
- docling_ibm_models/tableformer/models/table04_rs/encoder04_rs.py +72 -0
- docling_ibm_models/tableformer/models/table04_rs/tablemodel04_rs.py +324 -0
- docling_ibm_models/tableformer/models/table04_rs/transformer_rs.py +203 -0
- docling_ibm_models/tableformer/otsl.py +541 -0
- docling_ibm_models/tableformer/settings.py +90 -0
- docling_ibm_models/tableformer/test_dataset_cache.py +37 -0
- docling_ibm_models/tableformer/test_prepare_image.py +99 -0
- docling_ibm_models/tableformer/utils/__init__.py +0 -0
- docling_ibm_models/tableformer/utils/app_profiler.py +243 -0
- docling_ibm_models/tableformer/utils/torch_utils.py +216 -0
- docling_ibm_models/tableformer/utils/utils.py +376 -0
- docling_ibm_models/tableformer/utils/variance.py +175 -0
- docling_ibm_models-0.1.0.dist-info/LICENSE +21 -0
- docling_ibm_models-0.1.0.dist-info/METADATA +172 -0
- docling_ibm_models-0.1.0.dist-info/RECORD +32 -0
- docling_ibm_models-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,376 @@
|
|
1
|
+
#
|
2
|
+
# Copyright IBM Corp. 2024 - 2024
|
3
|
+
# SPDX-License-Identifier: MIT
|
4
|
+
#
|
5
|
+
import numpy as np
|
6
|
+
import torch
|
7
|
+
import torch.nn as nn
|
8
|
+
import torch.nn.functional as F
|
9
|
+
from PIL import Image
|
10
|
+
from torchvision.models.resnet import BasicBlock, conv1x1
|
11
|
+
from torchvision.ops.boxes import box_area
|
12
|
+
|
13
|
+
|
14
|
+
def remove_padding(seq):
    r"""
    Strip the trailing zeros from the provided sequence

    Parameters
    ----------
    seq : sequence of integers
        Predicted sequence, possibly right-padded with zeros

    Returns
    -------
    un_padded : sequence of integers
        The part of the input before the zero padding. The input object itself is
        returned when there is no padding at all
    pad_len : int
        Number of trailing zeros that have been dropped
    """
    # Walking from the end, the position of the first non-zero element equals the pad size
    for pad_len, token in enumerate(reversed(seq)):
        if token != 0:
            break
    else:
        # Empty input or nothing but zeros: everything counts as padding
        pad_len = len(seq)

    if pad_len:
        return seq[:-pad_len], pad_len
    return seq, 0
|
39
|
+
|
40
|
+
|
41
|
+
def probabilities_to_predictions(probabilities):
    r"""
    Convert probabilities to predictions

    Parameters
    ----------
    probabilities : Tensor[batch_size, vocab_size, seq_len]
        All log probabilities coming out at the last stage of the decoder

    Returns
    -------
    predictions : tensor [batch_size, output_sequence_length]
        The predicted tags, i.e. the index of the largest entry along dimension 1
    """
    # Reducing dimension 1 (the vocabulary) leaves [batch_size, seq_len]
    return probabilities.argmax(dim=1)
|
59
|
+
|
60
|
+
|
61
|
+
def print_target_predict(target, predictions, filenames=None, batch_idx=0):
    r"""
    For the tags, print the target and predicted tensors for the specified batch index

    The batch size is expected to be the first dimension. Only the requested batch entry is
    extracted and its remaining dimensions are flattened. The output is 2 aligned lines:
    the target on top and the predictions underneath.

    Parameters
    ----------
    target : tensor [batch_size, output_sequence_length]
        The ground truth tags

    predictions : tensor [batch_size, output_sequence_length]
        The predicted tags

    filenames : list of string
        The actual filenames that provide the data. When given, the filename of the selected
        batch entry is used as the label of the target line

    batch_idx : int
        Which index in the batch dimension will be printed
    """
    rows = (
        (None, target[batch_idx].flatten()),
        ("predict", predictions[batch_idx].flatten()),
    )
    target_label = "target" if filenames is None else filenames[batch_idx]
    labels = (target_label, rows[1][0])

    # Pad both labels to the same width so the two lists line up
    width = max(len(label) for label in labels)
    for label, (_, values) in zip(labels, rows):
        print(f"{label.ljust(width)}: {values.tolist()}")
|
94
|
+
|
95
|
+
|
96
|
+
def load_image(full_fn):
    r"""
    Load an image from the disk as a channels-first numpy array

    Parameters
    ----------
    full_fn : string
        The full path filename of the image

    Returns
    -------
    img : numpy array: (channels, height, width)
        The loaded image. A PIL image converts to numpy as (height, width, channels), so the
        channel axis is moved to the front

    Raises
    ------
    ValueError
        If the decoded image is not 3-dimensional (e.g. single-band grayscale or palette
        images), because there is no channel axis to move
    """
    with Image.open(full_fn) as f:
        img = np.asarray(f)  # (height, width, channels)

    if img.ndim != 3:
        # The bare transpose would fail with an opaque "axes don't match array" message
        raise ValueError(
            "Expected an image with a channel axis, got an array of shape {} from {}".format(
                img.shape, full_fn
            )
        )
    return img.transpose(2, 0, 1)  # (channels, height, width)
|
114
|
+
|
115
|
+
|
116
|
+
def resnet_block(stride=1):
    r"""
    Build a stage of 2 residual BasicBlocks that goes from 256 to 512 channels

    Parameters
    ----------
    stride : int
        Stride used by the first block and by the 1x1 convolution on its shortcut path

    Returns
    -------
    nn.Sequential
        BasicBlock(256 -> 512, with a conv1x1 + BatchNorm shortcut) followed by
        BasicBlock(512 -> 512, stride 1)
    """
    # The shortcut has to match the main path of the first block: 512 channels, same stride
    shortcut = nn.Sequential(conv1x1(256, 512, stride), nn.BatchNorm2d(512))
    return nn.Sequential(
        BasicBlock(256, 512, stride, shortcut),
        BasicBlock(512, 512, 1),
    )
|
125
|
+
|
126
|
+
|
127
|
+
def repackage_hidden(h):
    r"""
    Detach hidden states from their autograd history.

    A tensor is returned detached. Anything else is treated as an iterable of hidden states
    and converted, recursively, into a tuple of detached tensors.
    """
    if not isinstance(h, torch.Tensor):
        return tuple(map(repackage_hidden, h))
    return h.detach()
|
135
|
+
|
136
|
+
|
137
|
+
def accuracy(scores, targets, k):
    """
    Computes top-k accuracy, from predicted and true labels.

    The k best entries are selected along dimension 1 of the scores and a sample counts as
    correct when its target is among them.

    :param scores: scores from the model
    :param targets: true labels
    :param k: k in top-k accuracy
    :return: top-k accuracy as a percentage of the batch (python float)
    """
    _, top_idx = scores.topk(k, dim=1, largest=True, sorted=True)

    # Compare every one of the k candidates against the target of its row
    wanted = targets.view(-1, 1).expand_as(top_idx)
    hits = top_idx.eq(wanted).view(-1).float().sum()  # 0D tensor

    return hits.item() * (100.0 / targets.size(0))
|
152
|
+
|
153
|
+
|
154
|
+
def clip_gradient(optimizer, grad_clip):
    """
    Clips gradients computed during backpropagation to avoid explosion of gradients.

    Every gradient element is clamped in place into [-grad_clip, grad_clip]. Parameters
    without a gradient are left untouched.

    :param optimizer: optimizer with the gradients to be clipped
    :param grad_clip: clip value
    """
    all_params = (p for group in optimizer.param_groups for p in group["params"])
    for p in all_params:
        if p.grad is None:
            continue
        p.grad.data.clamp_(-grad_clip, grad_clip)
|
165
|
+
|
166
|
+
|
167
|
+
class AverageMeter:
    """
    Keeps track of most recent, average, sum, and count of a metric.

    Attributes: val (last value), sum (weighted sum), count (total weight), avg (sum / count)
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything recorded so far"""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record the value val that stands for n observations"""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
|
186
|
+
|
187
|
+
|
188
|
+
@torch.no_grad()
def bip_accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k

    :param output: scores; the top-k entries are selected along dimension 1
    :param target: true labels, one per row of output
    :param topk: the values of k to evaluate
    :return: list with one 0D tensor per entry of topk, each a percentage of the batch.
             For an empty target every entry is zero
    """
    if target.numel() == 0:
        # One value per requested k, so callers can always unpack len(topk) results
        return [torch.zeros([], device=output.device) for _ in topk]
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # [maxk, batch_size]
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape instead of view: the slice of the transposed tensor is not contiguous,
        # so view(-1) raises for k > 1
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
|
205
|
+
|
206
|
+
|
207
|
+
def box_cxcywh_to_xyxy(x):
    """Convert boxes from [center_x, center_y, width, height] to [x0, y0, x1, y1]

    The box components are read from, and written to, the last dimension of the tensor.
    """
    x_c, y_c, w, h = x.unbind(-1)
    half_w = 0.5 * w
    half_h = 0.5 * h
    corners = (x_c - half_w, y_c - half_h, x_c + half_w, y_c + half_h)
    return torch.stack(corners, dim=-1)
|
211
|
+
|
212
|
+
|
213
|
+
def box_xyxy_to_cxcywh(x):
    """Convert boxes from [x0, y0, x1, y1] to [center_x, center_y, width, height]

    The box components are read from, and written to, the last dimension of the tensor.
    """
    x0, y0, x1, y1 = x.unbind(-1)
    center_x = (x0 + x1) / 2
    center_y = (y0 + y1) / 2
    return torch.stack((center_x, center_y, x1 - x0, y1 - y0), dim=-1)
|
217
|
+
|
218
|
+
|
219
|
+
# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
    """
    Pairwise intersection-over-union between 2 sets of boxes

    The first 2 columns of a box are used as its left-top corner and the last 2 columns as
    its right-bottom corner.

    Returns the tuple (iou, union), both [N, M] where N = len(boxes1) and M = len(boxes2).
    No guard exists against a zero union.
    """
    # Corners of the overlap rectangle for every pair of boxes
    overlap_lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    overlap_rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    # Disjoint boxes give negative extents, which are clamped to an empty overlap
    overlap_wh = (overlap_rb - overlap_lt).clamp(min=0)  # [N,M,2]
    inter = overlap_wh[:, :, 0] * overlap_wh[:, :, 1]  # [N,M]

    union = box_area(boxes1)[:, None] + box_area(boxes2) - inter

    return inter / union, union
|
234
|
+
|
235
|
+
|
236
|
+
def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/

    The boxes should be in [x0, y0, x1, y1] format

    Returns a [N, M] pairwise matrix, where N = len(boxes1)
    and M = len(boxes2)
    """
    # degenerate boxes gives inf / nan results
    # so do an early check
    for boxes in (boxes1, boxes2):
        assert (boxes[:, 2:] >= boxes[:, :2]).all()
    iou, union = box_iou(boxes1, boxes2)

    # Smallest rectangle that encloses both boxes of each pair
    hull_lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    hull_rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    hull_wh = (hull_rb - hull_lt).clamp(min=0)  # [N,M,2]
    hull_area = hull_wh[:, :, 0] * hull_wh[:, :, 1]

    # IoU penalised by the share of the enclosing rectangle that the union leaves empty
    return iou - (hull_area - union) / hull_area
|
258
|
+
|
259
|
+
|
260
|
+
class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)

    A chain of Linear layers input_dim -> hidden_dim -> ... -> output_dim. ReLU is applied
    after every layer except the last one.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        # Consecutive entries are the (in_features, out_features) of one Linear layer
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(widths[:-1], widths[1:])]
        )

    def forward(self, x):
        last = self.num_layers - 1
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            if idx < last:
                x = F.relu(x)
        return x
|
275
|
+
|
276
|
+
|
277
|
+
def generate_square_subsequent_mask(sz: int, device: str = "cpu") -> torch.Tensor:
    """Generate the attention mask for causal decoding

    Returns a float [sz, sz] tensor on the requested device with 0.0 where the column index
    does not exceed the row index and -inf everywhere else.
    """
    # True on and below the diagonal: the positions that stay visible
    visible = torch.ones(sz, sz).triu().eq(1).t()
    hidden = visible == 0

    additive = visible.float()
    additive = additive.masked_fill(hidden, float("-inf"))
    additive = additive.masked_fill(visible == 1, float(0.0))
    return additive.to(device=device)
|
286
|
+
|
287
|
+
|
288
|
+
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience.
    Source from: https://github.com/Bjarten/early-stopping-pytorch
    """

    def __init__(self, patience=2, verbose=False, delta=0, trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 2
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            trace_func (function): trace print function.
                            Default: print
        """
        self._patience = patience
        self._verbose = verbose
        self._counter = 0
        self._best_score = None
        self._early_stop = False
        # np.inf instead of the np.Inf alias, which NumPy 2.0 removed
        self._val_loss_min = np.inf
        self._delta = delta
        self._trace_func = trace_func

    def __call__(self, val_loss):
        """
        Register the validation loss of one more evaluation round

        Args:
            val_loss (float): The validation loss just measured

        Returns:
            bool: The "save checkpoint" flag. True on the first call and on every
                  improvement, False once the patience is exhausted.
        """
        # Higher score == better, so that "improvement" reads as an increase
        score = -val_loss

        no_progress = (
            self._best_score is not None and score < self._best_score + self._delta
        )
        if no_progress:
            self._counter += 1
            self._trace_func(
                f"EarlyStopping counter: {self._counter} out of {self._patience}"
            )
            # NOTE(review): the reviewed listing had lost its indentation. The flag is
            # cleared here only together with the early stop, i.e. non-improving calls
            # below the patience still return True - confirm against the training loop.
            if self._counter >= self._patience:
                self._early_stop = True
                return False
            return True

        # First measurement or an improvement: remember it and restart the patience
        self._best_score = score
        self._counter = 0
        if self._verbose:
            verb = f"Validation loss decreased ({self._val_loss_min:.6f} --> {val_loss:.6f})."
            self._trace_func(verb)
        self._val_loss_min = val_loss
        return True
|
343
|
+
|
344
|
+
|
345
|
+
def print_dict(m: dict):
    r"""
    Print dict elements in separate lines sorted by keys

    When the first key is a numeric string the keys are ordered by their integer value,
    otherwise (or when some key cannot be converted to an integer) by plain comparison.
    Every key is printed exactly as it is stored in the dict.
    """
    if not m:
        return

    # Check if the key is a stringified integer
    first_key = next(iter(m))
    keys = None
    if isinstance(first_key, str) and first_key.isnumeric():
        try:
            # Sort the original keys by their integer value. Converting back with str()
            # loses spellings such as "007" and then fails the lookup
            keys = sorted(m, key=int)
        except (TypeError, ValueError):
            keys = None  # Some key is not an integer after all
    if keys is None:
        keys = sorted(m)

    for k in keys:
        print("{}: {}".format(k, m[k]))
|
366
|
+
|
367
|
+
|
368
|
+
def print_list(lst: list):
    r"""
    Print list elements in separate lines

    Each line starts with the index of the element. Elements that are lists themselves
    additionally show their length in parentheses.
    """
    for idx, item in enumerate(lst):
        line = f"{idx}: ({len(item)}) - {item}" if isinstance(item, list) else f"{idx}: {item}"
        print(line)
|
@@ -0,0 +1,175 @@
|
|
1
|
+
#
|
2
|
+
# Copyright IBM Corp. 2024 - 2024
|
3
|
+
# SPDX-License-Identifier: MIT
|
4
|
+
#
|
5
|
+
import logging
|
6
|
+
|
7
|
+
import numpy as np
|
8
|
+
|
9
|
+
import docling_ibm_models.tableformer.settings as s
|
10
|
+
|
11
|
+
# Level passed to the custom logger that the self-test under `__main__` creates
LOG_LEVEL = logging.INFO
|
12
|
+
|
13
|
+
|
14
|
+
class MyWelford:
    r"""
    Running computation of the mean and the variance using Welford's algorithm
    """

    def __init__(self):
        # Running index, running mean and running sum of squared deviations
        self._i, self._m, self._s = 0, 0, 0

    def reset(self):
        r"""
        Reset the object
        """
        self._i, self._m, self._s = 0, 0, 0

    def add(self, xi):
        r"""
        Invoke add each time a new sample arrives

        Inputs:
            xi: The next sample of data
        """
        self._i += 1
        previous_mean = self._m
        # New objects are built on purpose (no in-place update): previous_mean must keep
        # the old value while the squared deviations are accumulated
        self._m = previous_mean + (xi - previous_mean) / self._i
        self._s = self._s + (xi - self._m) * (xi - previous_mean)

    def results(self):
        r"""
        Get the computed mean, variance and standard deviation up to now

        Outputs:
            m: Mean of the samples
            v: Population variance (the squared deviations divided by the number of samples)
            std: Population standard deviation
            All 3 are None as long as fewer than 2 samples have been added
        """
        if self._i <= 1:
            return None, None, None

        # Dividing by (self._i - 1) instead would give the sample variance
        v = self._s / self._i
        return self._m, v, np.sqrt(v)
|
60
|
+
|
61
|
+
|
62
|
+
class MyWelfordImg(MyWelford):
    r"""
    Welford algorithm to calculate the running mean and variance for images

    Every pixel is fed to the parent class as one sample with one value per channel.
    """

    def __init__(self):
        super().__init__()

    def add(self, img):
        r"""
        Input:
            img: An image numpy array (channel, width, height). The only requirement is to have the
                 channels as the first dimension and have 3 dimensions in total
        """
        channels = img.shape[0]
        flat_dim = img.shape[1] * img.shape[2]

        # Transposed view: one row per pixel, holding the values of all channels
        pixels = img.reshape(channels, flat_dim).T
        for i in range(flat_dim):
            super().add(pixels[i])
|
82
|
+
|
83
|
+
|
84
|
+
class ChanVarianceImg:
    r"""
    Chan's algorithm to compute a running variance with support of sub-samples
    In this implementation each sub-sample is an image

    Math for the original paper:
    https://github.ibm.com/nli/variance_formulae
    """

    def __init__(self):
        r"""Start without any accumulated statistics"""
        # True until the first image has been added
        self._first = True
        # Size of the calculated dataset (pixels per channel seen so far)
        self._n = 0
        # Sum of the samples, one entry per image channel
        self._t = 0
        # Sum of the squared deviations of the samples from the mean
        self._s = 0

    def add(self, img):
        r"""
        Add the provided image to the computation of the dataset statistics

        Input:
            img: An image numpy array (channel, width, height). The only requirement is to have the
                 channels as the first dimension and have 3 dimensions in total
        """
        ch = img.shape[0]
        n = img.shape[1] * img.shape[2]
        pixels = img.reshape(ch, n)

        # Statistics of this image alone: per-channel sum and squared deviations
        img_t = pixels.sum(axis=1)
        img_s = ((pixels - (img_t.reshape(ch, 1) / n)) ** 2).sum(axis=1)

        if self._first:
            self._s = img_s
            self._t = img_t
            self._first = False
        else:
            # Correction term of the pairwise update: accounts for the gap between the
            # mean accumulated so far and the mean of the new image
            c = (self._n / (n * (self._n + n))) * (
                ((n / self._n) * self._t - img_t) ** 2
            )
            self._s += img_s + c
            self._t += img_t
        self._n += n

    def results(self):
        r"""
        Get the computed statistics

        Output:
            mean: Mean for the complete dataset
            var: Population variance for the complete dataset
            std: Population standard deviation for the complete dataset
            Each one is a list with one entry per channel
        """
        total = self._n
        mean = list(self._t / total)
        var = list(self._s / total)  # Population variance
        return mean, var, list(np.sqrt(var))

    def reset(self):
        r"""
        Reset the object to start over again
        """
        self._first = True
        self._n = 0
        self._t = 0
        self._s = 0
|
154
|
+
|
155
|
+
|
156
|
+
if __name__ == "__main__":
    # Self-test: run ChanVarianceImg over random images and check the size of its results
    logger = s.get_custom_logger("variance", LOG_LEVEL)

    n = 50000
    channels = 3
    width = 448
    height = 448

    my = ChanVarianceImg()
    # Generate random images
    for i in range(n):
        logger.info(i)
        img = 255 * np.random.rand(channels, width, height)
        my.add(img)

    # Calculate the statistics
    m, v, std = my.results()
    # results() returns plain lists, which have no .shape attribute; np.shape accepts them
    assert np.shape(m) == (channels,), "Wrong mean dimension"
    assert np.shape(v) == (channels,), "Wrong variance dimension"
    assert np.shape(std) == (channels,), "Wrong std dimension"
|
@@ -0,0 +1,21 @@
|
|
1
|
+
MIT License
|
2
|
+
|
3
|
+
Copyright (c) [year] [fullname]
|
4
|
+
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
7
|
+
in the Software without restriction, including without limitation the rights
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
10
|
+
furnished to do so, subject to the following conditions:
|
11
|
+
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
13
|
+
copies or substantial portions of the Software.
|
14
|
+
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21
|
+
SOFTWARE.
|