onnxtr 0.6.3 (py3-none-any.whl) → 0.7.1 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxtr/models/predictor/predictor.py +1 -1
- onnxtr/models/recognition/models/__init__.py +1 -0
- onnxtr/models/recognition/models/viptr.py +179 -0
- onnxtr/models/recognition/predictor/_utils.py +107 -45
- onnxtr/models/recognition/predictor/base.py +3 -3
- onnxtr/models/recognition/utils.py +56 -47
- onnxtr/models/recognition/zoo.py +10 -1
- onnxtr/utils/vocabs.py +1061 -76
- onnxtr/version.py +1 -1
- {onnxtr-0.6.3.dist-info → onnxtr-0.7.1.dist-info}/METADATA +5 -3
- {onnxtr-0.6.3.dist-info → onnxtr-0.7.1.dist-info}/RECORD +15 -14
- {onnxtr-0.6.3.dist-info → onnxtr-0.7.1.dist-info}/WHEEL +1 -1
- {onnxtr-0.6.3.dist-info → onnxtr-0.7.1.dist-info}/licenses/LICENSE +0 -0
- {onnxtr-0.6.3.dist-info → onnxtr-0.7.1.dist-info}/top_level.txt +0 -0
- {onnxtr-0.6.3.dist-info → onnxtr-0.7.1.dist-info}/zip-safe +0 -0
onnxtr/models/predictor/predictor.py CHANGED

@@ -115,7 +115,7 @@ class OCRPredictor(NestedObject, _OCRPredictor):
         # Crop images
         crops, loc_preds = self._prepare_crops(
             pages,
-            loc_preds,
+            loc_preds,
             channels_last=True,
             assume_straight_pages=self.assume_straight_pages,
             assume_horizontal=self._page_orientation_disabled,
onnxtr/models/recognition/models/viptr.py ADDED

@@ -0,0 +1,179 @@
+# Copyright (C) 2021-2025, Mindee | Felix Dittrich.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import logging
+from copy import deepcopy
+from itertools import groupby
+from typing import Any
+
+import numpy as np
+from scipy.special import softmax
+
+from onnxtr.utils import VOCABS
+
+from ...engine import Engine, EngineConfig
+from ..core import RecognitionPostProcessor
+
+__all__ = ["VIPTR", "viptr_tiny"]
+
+default_cfgs: dict[str, dict[str, Any]] = {
+    "viptr_tiny": {
+        "mean": (0.694, 0.695, 0.693),
+        "std": (0.299, 0.296, 0.301),
+        "input_shape": (3, 32, 128),
+        "vocab": VOCABS["french"],
+        "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.6.3/viptr_tiny-499b8015.onnx",
+        "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.6.3/viptr_tiny-499b8015.onnx",
+    },
+}
+
+
+class VIPTRPostProcessor(RecognitionPostProcessor):
+    """Postprocess raw prediction of the model (logits) to a list of words using CTC decoding
+
+    Args:
+        vocab: string containing the ordered sequence of supported characters
+    """
+
+    def __init__(self, vocab):
+        self.vocab = vocab
+
+    def decode_sequence(self, sequence, vocab):
+        return "".join([vocab[int(char)] for char in sequence])
+
+    def ctc_best_path(
+        self,
+        logits,
+        vocab,
+        blank=0,
+    ):
+        """Implements best path decoding as shown by Graves (Dissertation, p63), highly inspired from
+        <https://github.com/githubharald/CTCDecoder>`_.
+
+        Args:
+            logits: model output, shape: N x T x C
+            vocab: vocabulary to use
+            blank: index of blank label
+
+        Returns:
+            A list of tuples: (word, confidence)
+        """
+        # Gather the most confident characters, and assign the smallest conf among those to the sequence prob
+        probs = softmax(logits, axis=-1).max(axis=-1).min(axis=1)
+
+        # collapse best path (using itertools.groupby), map to chars, join char list to string
+        words = [
+            self.decode_sequence([k for k, _ in groupby(seq.tolist()) if k != blank], vocab)
+            for seq in np.argmax(logits, axis=-1)
+        ]
+
+        return list(zip(words, probs.astype(float).tolist()))
+
+    def __call__(self, logits):
+        """Performs decoding of raw output with CTC and decoding of CTC predictions
+        with label_to_idx mapping dictionnary
+
+        Args:
+            logits: raw output of the model, shape (N, C + 1, seq_len)
+
+        Returns:
+            A tuple of 2 lists: a list of str (words) and a list of float (probs)
+
+        """
+        # Decode CTC
+        return self.ctc_best_path(logits=logits, vocab=self.vocab, blank=len(self.vocab))
+
+
+class VIPTR(Engine):
+    """VIPTR Onnx loader
+
+    Args:
+        model_path: path or url to onnx model file
+        vocab: vocabulary used for encoding
+        engine_cfg: configuration for the inference engine
+        cfg: configuration dictionary
+        **kwargs: additional arguments to be passed to `Engine`
+    """
+
+    _children_names: list[str] = ["postprocessor"]
+
+    def __init__(
+        self,
+        model_path: str,
+        vocab: str,
+        engine_cfg: EngineConfig | None = None,
+        cfg: dict[str, Any] | None = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
+
+        self.vocab = vocab
+        self.cfg = cfg
+
+        self.postprocessor = VIPTRPostProcessor(self.vocab)
+
+    def __call__(
+        self,
+        x: np.ndarray,
+        return_model_output: bool = False,
+    ) -> dict[str, Any]:
+        logits = self.run(x)
+
+        out: dict[str, Any] = {}
+        if return_model_output:
+            out["out_map"] = logits
+
+        # Post-process
+        out["preds"] = self.postprocessor(logits)
+
+        return out
+
+
+def _viptr(
+    arch: str,
+    model_path: str,
+    load_in_8_bit: bool = False,
+    engine_cfg: EngineConfig | None = None,
+    **kwargs: Any,
+) -> VIPTR:
+    if load_in_8_bit:
+        logging.warning("VIPTR models do not support 8-bit quantization yet. Loading full precision model...")
+    kwargs["vocab"] = kwargs.get("vocab", default_cfgs[arch]["vocab"])
+
+    _cfg = deepcopy(default_cfgs[arch])
+    _cfg["vocab"] = kwargs["vocab"]
+    _cfg["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
+    # Patch the url
+    model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
+
+    # Build the model
+    return VIPTR(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs)
+
+
+def viptr_tiny(
+    model_path: str = default_cfgs["viptr_tiny"]["url"],
+    load_in_8_bit: bool = False,
+    engine_cfg: EngineConfig | None = None,
+    **kwargs: Any,
+) -> VIPTR:
+    """VIPTR as described in `"A Vision Permutable Extractor for Fast and Efficient
+    Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
+
+    >>> import numpy as np
+    >>> from onnxtr.models import viptr_tiny
+    >>> model = viptr_tiny()
+    >>> input_tensor = np.random.rand(1, 3, 32, 128)
+    >>> out = model(input_tensor)
+
+    Args:
+        model_path: path to onnx model file, defaults to url in default_cfgs
+        load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
+        engine_cfg: configuration for the inference engine
+        **kwargs: keyword arguments of the VIPTR architecture
+
+    Returns:
+        text recognition architecture
+    """
+    return _viptr("viptr_tiny", model_path, load_in_8_bit, engine_cfg, **kwargs)
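The post-processor above is a plain best-path CTC decoder: take the arg-max class per time step, collapse repeats, drop the blank index, and use the weakest per-step confidence as the word confidence. A minimal standalone sketch of that same scheme (illustrative only, not part of the package; the toy vocab and logits are made up):

# Illustrative only: the same best-path CTC scheme as VIPTRPostProcessor, on toy logits.
from itertools import groupby

import numpy as np
from scipy.special import softmax

vocab = "abc"                               # assumed toy vocabulary; blank index = len(vocab) = 3
logits = np.zeros((1, 6, len(vocab) + 1))   # (N, T, C)
# Make the per-step argmax spell out: a, a, blank, b, b, c
for t, idx in enumerate([0, 0, 3, 1, 1, 2]):
    logits[0, t, idx] = 5.0

# Word confidence: weakest per-step max probability
probs = softmax(logits, axis=-1).max(axis=-1).min(axis=1)
# Collapse repeats and drop the blank index
best_path = np.argmax(logits, axis=-1)[0].tolist()
word = "".join(vocab[k] for k, _ in groupby(best_path) if k != len(vocab))
print(word, float(probs[0]))  # "abc" with a confidence close to 1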
onnxtr/models/recognition/predictor/_utils.py CHANGED

@@ -4,6 +4,8 @@
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 
+import math
+
 import numpy as np
 
 from ..utils import merge_multi_strings
@@ -15,69 +17,129 @@ def split_crops(
     crops: list[np.ndarray],
     max_ratio: float,
     target_ratio: int,
-    …
+    split_overlap_ratio: float,
     channels_last: bool = True,
-) -> tuple[list[np.ndarray], list[int | tuple[int, int]], bool]:
-    """
+) -> tuple[list[np.ndarray], list[int | tuple[int, int, float]], bool]:
+    """
+    Split crops horizontally if they exceed a given aspect ratio.
 
     Args:
-        crops: …
-        max_ratio: …
-        target_ratio: …
-        …
-        channels_last: …
+        crops: List of image crops (H, W, C) if channels_last else (C, H, W).
+        max_ratio: Aspect ratio threshold above which crops are split.
+        target_ratio: Target aspect ratio after splitting (e.g., 4 for 128x32).
+        split_overlap_ratio: Desired overlap between splits (as a fraction of split width).
+        channels_last: Whether the crops are in channels-last format.
 
     Returns:
-        …
+        A tuple containing:
+        - The new list of crops (possibly with splits),
+        - A mapping indicating how to reassemble predictions,
+        - A boolean indicating whether remapping is required.
     """
-    …
+    if split_overlap_ratio <= 0.0 or split_overlap_ratio >= 1.0:
+        raise ValueError(f"Valid range for split_overlap_ratio is (0.0, 1.0), but is: {split_overlap_ratio}")
+
+    remap_required = False
     new_crops: list[np.ndarray] = []
+    crop_map: list[int | tuple[int, int, float]] = []
+
     for crop in crops:
         h, w = crop.shape[:2] if channels_last else crop.shape[-2:]
         aspect_ratio = w / h
+
         if aspect_ratio > max_ratio:
-            …
+            split_width = max(1, math.ceil(h * target_ratio))
+            overlap_width = max(0, math.floor(split_width * split_overlap_ratio))
+
+            splits, last_overlap = _split_horizontally(crop, split_width, overlap_width, channels_last)
+
+            # Remove any empty splits
+            splits = [s for s in splits if all(dim > 0 for dim in s.shape)]
+            if splits:
+                crop_map.append((len(new_crops), len(new_crops) + len(splits), last_overlap))
+                new_crops.extend(splits)
+                remap_required = True
             else:
-                …
-                ]
-            # Avoid sending zero-sized crops
-            _crops = [crop for crop in _crops if all(s > 0 for s in crop.shape)]
-            # Record the slice of crops
-            crop_map.append((len(new_crops), len(new_crops) + len(_crops)))
-            new_crops.extend(_crops)
-            # At least one crop will require merging
-            _remap_required = True
+                # Fallback: treat it as a single crop
+                crop_map.append(len(new_crops))
+                new_crops.append(crop)
         else:
             crop_map.append(len(new_crops))
             new_crops.append(crop)
 
-    return new_crops, crop_map, …
+    return new_crops, crop_map, remap_required
+
+
+def _split_horizontally(
+    image: np.ndarray, split_width: int, overlap_width: int, channels_last: bool
+) -> tuple[list[np.ndarray], float]:
+    """
+    Horizontally split a single image with overlapping regions.
+
+    Args:
+        image: The image to split (H, W, C) if channels_last else (C, H, W).
+        split_width: Width of each split.
+        overlap_width: Width of the overlapping region.
+        channels_last: Whether the image is in channels-last format.
+
+    Returns:
+        - A list of horizontal image slices.
+        - The actual overlap ratio of the last split.
+    """
+    image_width = image.shape[1] if channels_last else image.shape[-1]
+    if image_width <= split_width:
+        return [image], 0.0
+
+    # Compute start columns for each split
+    step = split_width - overlap_width
+    starts = list(range(0, image_width - split_width + 1, step))
+
+    # Ensure the last patch reaches the end of the image
+    if starts[-1] + split_width < image_width:
+        starts.append(image_width - split_width)
+
+    splits = []
+    for start_col in starts:
+        end_col = start_col + split_width
+        if channels_last:
+            split = image[:, start_col:end_col, :]
+        else:
+            split = image[:, :, start_col:end_col]
+        splits.append(split)
+
+    # Calculate the last overlap ratio, if only one split no overlap
+    last_overlap = 0
+    if len(starts) > 1:
+        last_overlap = (starts[-2] + split_width) - starts[-1]
+    last_overlap_ratio = last_overlap / split_width if split_width else 0.0
+
+    return splits, last_overlap_ratio
 
 
 def remap_preds(
-    preds: list[tuple[str, float]], …
+    preds: list[tuple[str, float]],
+    crop_map: list[int | tuple[int, int, float]],
+    overlap_ratio: float,
 ) -> list[tuple[str, float]]:
-    …
+    """
+    Reconstruct predictions from possibly split crops.
+
+    Args:
+        preds: List of (text, confidence) tuples from each crop.
+        crop_map: Map returned by `split_crops`.
+        overlap_ratio: Overlap ratio used during splitting.
+
+    Returns:
+        List of merged (text, confidence) tuples corresponding to original crops.
+    """
+    remapped = []
+    for item in crop_map:
+        if isinstance(item, int):
+            remapped.append(preds[item])
         else:
-            …
+            start_idx, end_idx, last_overlap = item
+            text_parts, confidences = zip(*preds[start_idx:end_idx])
+            merged_text = merge_multi_strings(list(text_parts), overlap_ratio, last_overlap)
+            merged_conf = sum(confidences) / len(confidences)  # average confidence
+            remapped.append((merged_text, merged_conf))
+    return remapped
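The key change in this file is that wide crops are now cut into fixed-width, overlapping slices instead of dilated sub-crops, and the crop map carries the actual overlap of the final slice so predictions can be merged back. A standalone sketch of the start-column arithmetic (illustrative numbers only, not package code):

# Standalone sketch of the layout computed by the new splitting logic.
# Assumed example: a 32x300 crop with target_ratio=6 and split_overlap_ratio=0.5
# (matching the target_ar=6 and overlap_ratio=0.5 defaults wired into RecognitionPredictor).
import math

h, w = 32, 300
target_ratio, split_overlap_ratio = 6, 0.5

split_width = max(1, math.ceil(h * target_ratio))                      # 192
overlap_width = max(0, math.floor(split_width * split_overlap_ratio))  # 96
step = split_width - overlap_width                                     # 96

starts = list(range(0, w - split_width + 1, step))                     # [0, 96]
if starts[-1] + split_width < w:      # the last slice must reach the right edge
    starts.append(w - split_width)    # -> [0, 96, 108]

last_overlap = (starts[-2] + split_width) - starts[-1]                 # 180
last_overlap_ratio = last_overlap / split_width                        # 0.9375
print(starts, last_overlap_ratio)
# Three 192-px slices; remap_preds passes last_overlap_ratio to merge_multi_strings
# so the final slice's prediction is merged with the right amount of overlap.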
onnxtr/models/recognition/predictor/base.py CHANGED

@@ -36,7 +36,7 @@ class RecognitionPredictor(NestedObject):
         self.model = model
         self.split_wide_crops = split_wide_crops
         self.critical_ar = 8  # Critical aspect ratio
-        self.…
+        self.overlap_ratio = 0.5  # Ratio of overlap between neighboring crops
         self.target_ar = 6  # Target aspect ratio
 
     def __call__(
@@ -57,7 +57,7 @@ class RecognitionPredictor(NestedObject):
                 crops,  # type: ignore[arg-type]
                 self.critical_ar,
                 self.target_ar,
-                self.…
+                self.overlap_ratio,
                 True,
             )
             if remapped:
@@ -74,6 +74,6 @@ class RecognitionPredictor(NestedObject):
 
         # Remap crops
         if self.split_wide_crops and remapped:
-            out = remap_preds(out, crop_map, self.…
+            out = remap_preds(out, crop_map, self.overlap_ratio)
 
         return out
onnxtr/models/recognition/utils.py CHANGED

@@ -4,81 +4,90 @@
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 
-from rapidfuzz.distance import Levenshtein
+from rapidfuzz.distance import Hamming
 
 __all__ = ["merge_strings", "merge_multi_strings"]
 
 
-def merge_strings(a: str, b: str, …
+def merge_strings(a: str, b: str, overlap_ratio: float) -> str:
     """Merges 2 character sequences in the best way to maximize the alignment of their overlapping characters.
 
     Args:
         a: first char seq, suffix should be similar to b's prefix.
         b: second char seq, prefix should be similar to a's suffix.
-        dil_factor: …
-            only used when the mother sequence is splitted on a character repetition
+        overlap_ratio: estimated ratio of overlapping characters.
 
     Returns:
         A merged character sequence.
 
     Example::
-        >>> from …
-        >>> …
+        >>> from doctr.models.recognition.utils import merge_strings
+        >>> merge_strings('abcd', 'cdefgh', 0.5)
         'abcdefgh'
-        >>> …
+        >>> merge_strings('abcdi', 'cdefgh', 0.5)
         'abcdefgh'
     """
     seq_len = min(len(a), len(b))
-    if seq_len …
-        return b if len(a) == 0 else a
-
-    # Initialize merging index and corresponding score (mean Levenstein)
-    min_score, index = 1.0, 0  # No overlap, just concatenate
-
-    scores = [Levenshtein.distance(a[-i:], b[:i], processor=None) / i for i in range(1, seq_len + 1)]
-
-    # Edge case (split in the middle of char repetitions): if it starts with 2 or more 0
-    if len(scores) > 1 and (scores[0], scores[1]) == (0, 0):
-        # Compute n_overlap (number of overlapping chars, geometrically determined)
-        n_overlap = round(len(b) * (dil_factor - 1) / dil_factor)
-        # Find the number of consecutive zeros in the scores list
-        # Impossible to have a zero after a non-zero score in that case
-        n_zeros = sum(val == 0 for val in scores)
-        # Index is bounded by the geometrical overlap to avoid collapsing repetitions
-        min_score, index = 0, min(n_zeros, n_overlap)
-
-    else:  # Common case: choose the min score index
-        for i, score in enumerate(scores):
-            if score < min_score:
-                min_score, index = score, i + 1  # Add one because first index is an overlap of 1 char
-
-    # Merge with correct overlap
-    if index == 0:
+    if seq_len <= 1:  # One sequence is empty or will be after cropping in next step, return both to keep data
         return a + b
-    return a[:-1] + b[index - 1 :]
 
+    a_crop, b_crop = a[:-1], b[1:]  # Remove last letter of "a" and first of "b", because they might be cut off
+    max_overlap = min(len(a_crop), len(b_crop))
 
-    …
+    # Compute Hamming distances for all possible overlaps
+    scores = [Hamming.distance(a_crop[-i:], b_crop[:i], processor=None) for i in range(1, max_overlap + 1)]
+
+    # Find zero-score matches
+    zero_matches = [i for i, score in enumerate(scores) if score == 0]
+
+    expected_overlap = round(len(b) * overlap_ratio) - 3  # adjust for cropping and index
+
+    # Case 1: One perfect match - exactly one zero score - just merge there
+    if len(zero_matches) == 1:
+        i = zero_matches[0]
+        return a_crop + b_crop[i + 1 :]
+
+    # Case 2: Multiple perfect matches - likely due to repeated characters.
+    # Use the estimated overlap length to choose the match closest to the expected alignment.
+    elif len(zero_matches) > 1:
+        best_i = min(zero_matches, key=lambda x: abs(x - expected_overlap))
+        return a_crop + b_crop[best_i + 1 :]
+
+    # Case 3: Absence of zero scores indicates that the same character in the image was recognized differently OR that
+    # the overlap was too small and we just need to merge the crops fully
+    if expected_overlap < -1:
+        return a + b
+    elif expected_overlap < 0:
+        return a_crop + b_crop
+
+    # Find best overlap by minimizing Hamming distance + distance from expected overlap size
+    combined_scores = [score + abs(i - expected_overlap) for i, score in enumerate(scores)]
+    best_i = combined_scores.index(min(combined_scores))
+    return a_crop + b_crop[best_i + 1 :]
+
+
+def merge_multi_strings(seq_list: list[str], overlap_ratio: float, last_overlap_ratio: float) -> str:
+    """
+    Merges consecutive string sequences with overlapping characters.
 
     Args:
         seq_list: list of sequences to merge. Sequences need to be ordered from left to right.
-        …
+        overlap_ratio: Estimated ratio of overlapping letters between neighboring strings.
+        last_overlap_ratio: Estimated ratio of overlapping letters for the last element in seq_list.
 
     Returns:
         A merged character sequence
 
     Example::
-        >>> from …
-        >>> …
+        >>> from doctr.models.recognition.utils import merge_multi_strings
+        >>> merge_multi_strings(['abc', 'bcdef', 'difghi', 'aijkl'], 0.5, 0.1)
         'abcdefghijkl'
     """
-    …
-    return …
+    if not seq_list:
+        return ""
+    result = seq_list[0]
+    for i in range(1, len(seq_list)):
+        text_b = seq_list[i]
+        ratio = last_overlap_ratio if i == len(seq_list) - 1 else overlap_ratio
+        result = merge_strings(result, text_b, ratio)
+    return result
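merge_strings now aligns neighbouring predictions with Hamming distances over candidate overlaps (after dropping the possibly cut-off boundary characters) instead of the previous Levenshtein/dilation-factor scheme. A small illustrative trace of the unique-perfect-match case (hand-made strings, not package code):

# Illustrative trace of the new Hamming-based merge (case 1: a unique perfect match).
# "abcde" and "cdefgh" stand in for predictions from two overlapping slices.
from rapidfuzz.distance import Hamming

a, b = "abcde", "cdefgh"
a_crop, b_crop = a[:-1], b[1:]   # "abcd", "defgh": boundary chars may be cut off, so drop them
max_overlap = min(len(a_crop), len(b_crop))

scores = [
    Hamming.distance(a_crop[-i:], b_crop[:i], processor=None)
    for i in range(1, max_overlap + 1)
]
# scores == [0, 2, 3, 4]: the single zero (overlap of 1 char, "d" == "d") is a
# unique perfect match, so the two sequences are joined at that point.
merged = a_crop + b_crop[scores.index(0) + 1 :]
print(merged)  # abcdefgh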
onnxtr/models/recognition/zoo.py CHANGED

@@ -22,6 +22,7 @@ ARCHS: list[str] = [
     "vitstr_small",
     "vitstr_base",
     "parseq",
+    "viptr_tiny",
 ]
 
 
@@ -35,7 +36,15 @@ def _predictor(
         _model = recognition.__dict__[arch](load_in_8_bit=load_in_8_bit, engine_cfg=engine_cfg)
     else:
         if not isinstance(
-            arch, …
+            arch,
+            (
+                recognition.CRNN,
+                recognition.SAR,
+                recognition.MASTER,
+                recognition.ViTSTR,
+                recognition.PARSeq,
+                recognition.VIPTR,
+            ),
         ):
             raise ValueError(f"unknown architecture: {type(arch)}")
         _model = arch