onnxtr 0.6.3__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -115,7 +115,7 @@ class OCRPredictor(NestedObject, _OCRPredictor):
         # Crop images
         crops, loc_preds = self._prepare_crops(
             pages,
-            loc_preds,  # type: ignore[arg-type]
+            loc_preds,
             channels_last=True,
             assume_straight_pages=self.assume_straight_pages,
             assume_horizontal=self._page_orientation_disabled,
@@ -3,3 +3,4 @@ from .sar import *
 from .master import *
 from .vitstr import *
 from .parseq import *
+from .viptr import *
@@ -0,0 +1,179 @@
+# Copyright (C) 2021-2025, Mindee | Felix Dittrich.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import logging
+from copy import deepcopy
+from itertools import groupby
+from typing import Any
+
+import numpy as np
+from scipy.special import softmax
+
+from onnxtr.utils import VOCABS
+
+from ...engine import Engine, EngineConfig
+from ..core import RecognitionPostProcessor
+
+__all__ = ["VIPTR", "viptr_tiny"]
+
+default_cfgs: dict[str, dict[str, Any]] = {
+    "viptr_tiny": {
+        "mean": (0.694, 0.695, 0.693),
+        "std": (0.299, 0.296, 0.301),
+        "input_shape": (3, 32, 128),
+        "vocab": VOCABS["french"],
+        "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.6.3/viptr_tiny-499b8015.onnx",
+        "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.6.3/viptr_tiny-499b8015.onnx",
+    },
+}
+
+
+class VIPTRPostProcessor(RecognitionPostProcessor):
+    """Postprocess raw prediction of the model (logits) to a list of words using CTC decoding
+
+    Args:
+        vocab: string containing the ordered sequence of supported characters
+    """
+
+    def __init__(self, vocab):
+        self.vocab = vocab
+
+    def decode_sequence(self, sequence, vocab):
+        return "".join([vocab[int(char)] for char in sequence])
+
+    def ctc_best_path(
+        self,
+        logits,
+        vocab,
+        blank=0,
+    ):
+        """Implements best path decoding as shown by Graves (Dissertation, p63), highly inspired from
+        <https://github.com/githubharald/CTCDecoder>`_.
+
+        Args:
+            logits: model output, shape: N x T x C
+            vocab: vocabulary to use
+            blank: index of blank label
+
+        Returns:
+            A list of tuples: (word, confidence)
+        """
+        # Gather the most confident characters, and assign the smallest conf among those to the sequence prob
+        probs = softmax(logits, axis=-1).max(axis=-1).min(axis=1)
+
+        # collapse best path (using itertools.groupby), map to chars, join char list to string
+        words = [
+            self.decode_sequence([k for k, _ in groupby(seq.tolist()) if k != blank], vocab)
+            for seq in np.argmax(logits, axis=-1)
+        ]
+
+        return list(zip(words, probs.astype(float).tolist()))
+
+    def __call__(self, logits):
+        """Performs decoding of raw output with CTC and decoding of CTC predictions
+        with label_to_idx mapping dictionary
+
+        Args:
+            logits: raw output of the model, shape (N, C + 1, seq_len)
+
+        Returns:
+            A tuple of 2 lists: a list of str (words) and a list of float (probs)
+
+        """
+        # Decode CTC
+        return self.ctc_best_path(logits=logits, vocab=self.vocab, blank=len(self.vocab))
+
+
+class VIPTR(Engine):
+    """VIPTR Onnx loader
+
+    Args:
+        model_path: path or url to onnx model file
+        vocab: vocabulary used for encoding
+        engine_cfg: configuration for the inference engine
+        cfg: configuration dictionary
+        **kwargs: additional arguments to be passed to `Engine`
+    """
+
+    _children_names: list[str] = ["postprocessor"]
+
+    def __init__(
+        self,
+        model_path: str,
+        vocab: str,
+        engine_cfg: EngineConfig | None = None,
+        cfg: dict[str, Any] | None = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
+
+        self.vocab = vocab
+        self.cfg = cfg
+
+        self.postprocessor = VIPTRPostProcessor(self.vocab)
+
+    def __call__(
+        self,
+        x: np.ndarray,
+        return_model_output: bool = False,
+    ) -> dict[str, Any]:
+        logits = self.run(x)
+
+        out: dict[str, Any] = {}
+        if return_model_output:
+            out["out_map"] = logits
+
+        # Post-process
+        out["preds"] = self.postprocessor(logits)
+
+        return out
+
+
+def _viptr(
+    arch: str,
+    model_path: str,
+    load_in_8_bit: bool = False,
+    engine_cfg: EngineConfig | None = None,
+    **kwargs: Any,
+) -> VIPTR:
+    if load_in_8_bit:
+        logging.warning("VIPTR models do not support 8-bit quantization yet. Loading full precision model...")
+    kwargs["vocab"] = kwargs.get("vocab", default_cfgs[arch]["vocab"])
+
+    _cfg = deepcopy(default_cfgs[arch])
+    _cfg["vocab"] = kwargs["vocab"]
+    _cfg["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
+    # Patch the url
+    model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
+
+    # Build the model
+    return VIPTR(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs)
+
+
+def viptr_tiny(
+    model_path: str = default_cfgs["viptr_tiny"]["url"],
+    load_in_8_bit: bool = False,
+    engine_cfg: EngineConfig | None = None,
+    **kwargs: Any,
+) -> VIPTR:
+    """VIPTR as described in `"A Vision Permutable Extractor for Fast and Efficient
+    Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
+
+    >>> import numpy as np
+    >>> from onnxtr.models import viptr_tiny
+    >>> model = viptr_tiny()
+    >>> input_tensor = np.random.rand(1, 3, 32, 128)
+    >>> out = model(input_tensor)
+
+    Args:
+        model_path: path to onnx model file, defaults to url in default_cfgs
+        load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
+        engine_cfg: configuration for the inference engine
+        **kwargs: keyword arguments of the VIPTR architecture
+
+    Returns:
+        text recognition architecture
+    """
+    return _viptr("viptr_tiny", model_path, load_in_8_bit, engine_cfg, **kwargs)
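
The new file above registers VIPTR as an additional ONNX recognition architecture. A minimal usage sketch (not part of the package), assuming the recognition zoo entry point recognition_predictor and the "viptr_tiny" entry added to ARCHS further down in this diff; the input size follows default_cfgs above:

import numpy as np
from onnxtr.models import recognition_predictor

predictor = recognition_predictor("viptr_tiny")  # downloads the released ONNX weights on first use
crop = (np.random.rand(32, 128, 3) * 255).astype(np.uint8)  # a single word crop, (H, W, C)
print(predictor([crop]))  # -> list of (word, confidence) tuples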
@@ -4,6 +4,8 @@
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 
+import math
+
 import numpy as np
 
 from ..utils import merge_multi_strings
@@ -15,69 +17,129 @@ def split_crops(
     crops: list[np.ndarray],
     max_ratio: float,
     target_ratio: int,
-    dilation: float,
+    split_overlap_ratio: float,
     channels_last: bool = True,
-) -> tuple[list[np.ndarray], list[int | tuple[int, int]], bool]:
-    """Chunk crops horizontally to match a given aspect ratio
+) -> tuple[list[np.ndarray], list[int | tuple[int, int, float]], bool]:
+    """
+    Split crops horizontally if they exceed a given aspect ratio.
 
     Args:
-        crops: list of numpy array of shape (H, W, 3) if channels_last or (3, H, W) otherwise
-        max_ratio: the maximum aspect ratio that won't trigger the chunk
-        target_ratio: when crops are chunked, they will be chunked to match this aspect ratio
-        dilation: the width dilation of final chunks (to provide some overlaps)
-        channels_last: whether the numpy array has dimensions in channels last order
+        crops: List of image crops (H, W, C) if channels_last else (C, H, W).
+        max_ratio: Aspect ratio threshold above which crops are split.
+        target_ratio: Target aspect ratio after splitting (e.g., 4 for 128x32).
+        split_overlap_ratio: Desired overlap between splits (as a fraction of split width).
+        channels_last: Whether the crops are in channels-last format.
 
     Returns:
-        a tuple with the new crops, their mapping, and a boolean specifying whether any remap is required
+        A tuple containing:
+        - The new list of crops (possibly with splits),
+        - A mapping indicating how to reassemble predictions,
+        - A boolean indicating whether remapping is required.
     """
-    _remap_required = False
-    crop_map: list[int | tuple[int, int]] = []
+    if split_overlap_ratio <= 0.0 or split_overlap_ratio >= 1.0:
+        raise ValueError(f"Valid range for split_overlap_ratio is (0.0, 1.0), but is: {split_overlap_ratio}")
+
+    remap_required = False
     new_crops: list[np.ndarray] = []
+    crop_map: list[int | tuple[int, int, float]] = []
+
     for crop in crops:
         h, w = crop.shape[:2] if channels_last else crop.shape[-2:]
         aspect_ratio = w / h
+
         if aspect_ratio > max_ratio:
-            # Determine the number of crops, reference aspect ratio = 4 = 128 / 32
-            num_subcrops = int(aspect_ratio // target_ratio)
-            # Find the new widths, additional dilation factor to overlap crops
-            width = dilation * w / num_subcrops
-            centers = [(w / num_subcrops) * (1 / 2 + idx) for idx in range(num_subcrops)]
-            # Get the crops
-            if channels_last:
-                _crops = [
-                    crop[:, max(0, int(round(center - width / 2))) : min(w - 1, int(round(center + width / 2))), :]
-                    for center in centers
-                ]
+            split_width = max(1, math.ceil(h * target_ratio))
+            overlap_width = max(0, math.floor(split_width * split_overlap_ratio))
+
+            splits, last_overlap = _split_horizontally(crop, split_width, overlap_width, channels_last)
+
+            # Remove any empty splits
+            splits = [s for s in splits if all(dim > 0 for dim in s.shape)]
+            if splits:
+                crop_map.append((len(new_crops), len(new_crops) + len(splits), last_overlap))
+                new_crops.extend(splits)
+                remap_required = True
             else:
-                _crops = [
-                    crop[:, :, max(0, int(round(center - width / 2))) : min(w - 1, int(round(center + width / 2)))]
-                    for center in centers
-                ]
-            # Avoid sending zero-sized crops
-            _crops = [crop for crop in _crops if all(s > 0 for s in crop.shape)]
-            # Record the slice of crops
-            crop_map.append((len(new_crops), len(new_crops) + len(_crops)))
-            new_crops.extend(_crops)
-            # At least one crop will require merging
-            _remap_required = True
+                # Fallback: treat it as a single crop
+                crop_map.append(len(new_crops))
+                new_crops.append(crop)
         else:
             crop_map.append(len(new_crops))
             new_crops.append(crop)
 
-    return new_crops, crop_map, _remap_required
+    return new_crops, crop_map, remap_required
+
+
+def _split_horizontally(
+    image: np.ndarray, split_width: int, overlap_width: int, channels_last: bool
+) -> tuple[list[np.ndarray], float]:
+    """
+    Horizontally split a single image with overlapping regions.
+
+    Args:
+        image: The image to split (H, W, C) if channels_last else (C, H, W).
+        split_width: Width of each split.
+        overlap_width: Width of the overlapping region.
+        channels_last: Whether the image is in channels-last format.
+
+    Returns:
+        - A list of horizontal image slices.
+        - The actual overlap ratio of the last split.
+    """
+    image_width = image.shape[1] if channels_last else image.shape[-1]
+    if image_width <= split_width:
+        return [image], 0.0
+
+    # Compute start columns for each split
+    step = split_width - overlap_width
+    starts = list(range(0, image_width - split_width + 1, step))
+
+    # Ensure the last patch reaches the end of the image
+    if starts[-1] + split_width < image_width:
+        starts.append(image_width - split_width)
+
+    splits = []
+    for start_col in starts:
+        end_col = start_col + split_width
+        if channels_last:
+            split = image[:, start_col:end_col, :]
+        else:
+            split = image[:, :, start_col:end_col]
+        splits.append(split)
+
+    # Calculate the last overlap ratio, if only one split no overlap
+    last_overlap = 0
+    if len(starts) > 1:
+        last_overlap = (starts[-2] + split_width) - starts[-1]
+    last_overlap_ratio = last_overlap / split_width if split_width else 0.0
+
+    return splits, last_overlap_ratio
 
 
 def remap_preds(
-    preds: list[tuple[str, float]], crop_map: list[int | tuple[int, int]], dilation: float
+    preds: list[tuple[str, float]],
+    crop_map: list[int | tuple[int, int, float]],
+    overlap_ratio: float,
 ) -> list[tuple[str, float]]:
-    remapped_out = []
-    for _idx in crop_map:
-        # Crop hasn't been split
-        if isinstance(_idx, int):
-            remapped_out.append(preds[_idx])
+    """
+    Reconstruct predictions from possibly split crops.
+
+    Args:
+        preds: List of (text, confidence) tuples from each crop.
+        crop_map: Map returned by `split_crops`.
+        overlap_ratio: Overlap ratio used during splitting.
+
+    Returns:
+        List of merged (text, confidence) tuples corresponding to original crops.
+    """
+    remapped = []
+    for item in crop_map:
+        if isinstance(item, int):
+            remapped.append(preds[item])
         else:
-            # unzip
-            vals, probs = zip(*preds[_idx[0] : _idx[1]])
-            # Merge the string values
-            remapped_out.append((merge_multi_strings(vals, dilation), min(probs)))  # type: ignore[arg-type]
-    return remapped_out
+            start_idx, end_idx, last_overlap = item
+            text_parts, confidences = zip(*preds[start_idx:end_idx])
+            merged_text = merge_multi_strings(list(text_parts), overlap_ratio, last_overlap)
+            merged_conf = sum(confidences) / len(confidences)  # average confidence
+            remapped.append((merged_text, merged_conf))
+    return remapped
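
For reference, a small numeric sketch of the new splitting behaviour (not part of the package), assuming the helpers above are importable from onnxtr.models.recognition.predictor._utils (the upstream docTR layout). A 600-pixel-wide crop of height 32 is cut into fixed-width, overlapping sub-crops instead of the previous dilation-based chunks; the parameter values mirror RecognitionPredictor below (critical_ar=8, target_ar=6, overlap_ratio=0.5):

import numpy as np
from onnxtr.models.recognition.predictor._utils import split_crops  # assumed module path

wide_crop = np.zeros((32, 600, 3), dtype=np.uint8)  # aspect ratio 18.75 > max_ratio of 8
crops, crop_map, remapped = split_crops([wide_crop], max_ratio=8, target_ratio=6, split_overlap_ratio=0.5)

# split_width = ceil(32 * 6) = 192, overlap_width = floor(192 * 0.5) = 96, step = 96
# start columns: [0, 96, 192, 288, 384], plus a final 408 so the last patch reaches column 600
print(len(crops))  # 6 sub-crops
print(crop_map)    # [(0, 6, 0.875)] -> last overlap is (384 + 192 - 408) / 192
print(remapped)    # True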
@@ -36,7 +36,7 @@ class RecognitionPredictor(NestedObject):
         self.model = model
         self.split_wide_crops = split_wide_crops
         self.critical_ar = 8  # Critical aspect ratio
-        self.dil_factor = 1.4  # Dilation factor to overlap the crops
+        self.overlap_ratio = 0.5  # Ratio of overlap between neighboring crops
         self.target_ar = 6  # Target aspect ratio
 
     def __call__(
@@ -57,7 +57,7 @@ class RecognitionPredictor(NestedObject):
                 crops,  # type: ignore[arg-type]
                 self.critical_ar,
                 self.target_ar,
-                self.dil_factor,
+                self.overlap_ratio,
                 True,
             )
             if remapped:
@@ -74,6 +74,6 @@ class RecognitionPredictor(NestedObject):
 
         # Remap crops
         if self.split_wide_crops and remapped:
-            out = remap_preds(out, crop_map, self.dil_factor)
+            out = remap_preds(out, crop_map, self.overlap_ratio)
 
         return out
@@ -4,81 +4,90 @@
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 
-from rapidfuzz.distance import Levenshtein
+from rapidfuzz.distance import Hamming
 
 __all__ = ["merge_strings", "merge_multi_strings"]
 
 
-def merge_strings(a: str, b: str, dil_factor: float) -> str:
+def merge_strings(a: str, b: str, overlap_ratio: float) -> str:
     """Merges 2 character sequences in the best way to maximize the alignment of their overlapping characters.
 
     Args:
         a: first char seq, suffix should be similar to b's prefix.
         b: second char seq, prefix should be similar to a's suffix.
-        dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is
-            only used when the mother sequence is splitted on a character repetition
+        overlap_ratio: estimated ratio of overlapping characters.
 
     Returns:
         A merged character sequence.
 
     Example::
-        >>> from onnxtr.model.recognition.utils import merge_sequences
-        >>> merge_sequences('abcd', 'cdefgh', 1.4)
+        >>> from doctr.models.recognition.utils import merge_strings
+        >>> merge_strings('abcd', 'cdefgh', 0.5)
         'abcdefgh'
-        >>> merge_sequences('abcdi', 'cdefgh', 1.4)
+        >>> merge_strings('abcdi', 'cdefgh', 0.5)
         'abcdefgh'
     """
     seq_len = min(len(a), len(b))
-    if seq_len == 0:  # One sequence is empty, return the other
-        return b if len(a) == 0 else a
-
-    # Initialize merging index and corresponding score (mean Levenstein)
-    min_score, index = 1.0, 0  # No overlap, just concatenate
-
-    scores = [Levenshtein.distance(a[-i:], b[:i], processor=None) / i for i in range(1, seq_len + 1)]
-
-    # Edge case (split in the middle of char repetitions): if it starts with 2 or more 0
-    if len(scores) > 1 and (scores[0], scores[1]) == (0, 0):
-        # Compute n_overlap (number of overlapping chars, geometrically determined)
-        n_overlap = round(len(b) * (dil_factor - 1) / dil_factor)
-        # Find the number of consecutive zeros in the scores list
-        # Impossible to have a zero after a non-zero score in that case
-        n_zeros = sum(val == 0 for val in scores)
-        # Index is bounded by the geometrical overlap to avoid collapsing repetitions
-        min_score, index = 0, min(n_zeros, n_overlap)
-
-    else:  # Common case: choose the min score index
-        for i, score in enumerate(scores):
-            if score < min_score:
-                min_score, index = score, i + 1  # Add one because first index is an overlap of 1 char
-
-    # Merge with correct overlap
-    if index == 0:
+    if seq_len <= 1:  # One sequence is empty or will be after cropping in next step, return both to keep data
         return a + b
-    return a[:-1] + b[index - 1 :]
 
+    a_crop, b_crop = a[:-1], b[1:]  # Remove last letter of "a" and first of "b", because they might be cut off
+    max_overlap = min(len(a_crop), len(b_crop))
 
-def merge_multi_strings(seq_list: list[str], dil_factor: float) -> str:
-    """Recursively merges consecutive string sequences with overlapping characters.
+    # Compute Hamming distances for all possible overlaps
+    scores = [Hamming.distance(a_crop[-i:], b_crop[:i], processor=None) for i in range(1, max_overlap + 1)]
+
+    # Find zero-score matches
+    zero_matches = [i for i, score in enumerate(scores) if score == 0]
+
+    expected_overlap = round(len(b) * overlap_ratio) - 3  # adjust for cropping and index
+
+    # Case 1: One perfect match - exactly one zero score - just merge there
+    if len(zero_matches) == 1:
+        i = zero_matches[0]
+        return a_crop + b_crop[i + 1 :]
+
+    # Case 2: Multiple perfect matches - likely due to repeated characters.
+    # Use the estimated overlap length to choose the match closest to the expected alignment.
+    elif len(zero_matches) > 1:
+        best_i = min(zero_matches, key=lambda x: abs(x - expected_overlap))
+        return a_crop + b_crop[best_i + 1 :]
+
+    # Case 3: Absence of zero scores indicates that the same character in the image was recognized differently OR that
+    # the overlap was too small and we just need to merge the crops fully
+    if expected_overlap < -1:
+        return a + b
+    elif expected_overlap < 0:
+        return a_crop + b_crop
+
+    # Find best overlap by minimizing Hamming distance + distance from expected overlap size
+    combined_scores = [score + abs(i - expected_overlap) for i, score in enumerate(scores)]
+    best_i = combined_scores.index(min(combined_scores))
+    return a_crop + b_crop[best_i + 1 :]
+
+
+def merge_multi_strings(seq_list: list[str], overlap_ratio: float, last_overlap_ratio: float) -> str:
+    """
+    Merges consecutive string sequences with overlapping characters.
 
     Args:
         seq_list: list of sequences to merge. Sequences need to be ordered from left to right.
-        dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is
-            only used when the mother sequence is splitted on a character repetition
+        overlap_ratio: Estimated ratio of overlapping letters between neighboring strings.
+        last_overlap_ratio: Estimated ratio of overlapping letters for the last element in seq_list.
 
     Returns:
         A merged character sequence
 
     Example::
-        >>> from onnxtr.model.recognition.utils import merge_multi_sequences
-        >>> merge_multi_sequences(['abc', 'bcdef', 'difghi', 'aijkl'], 1.4)
+        >>> from doctr.models.recognition.utils import merge_multi_strings
+        >>> merge_multi_strings(['abc', 'bcdef', 'difghi', 'aijkl'], 0.5, 0.1)
         'abcdefghijkl'
     """
-
-    def _recursive_merge(a: str, seq_list: list[str], dil_factor: float) -> str:
-        # Recursive version of compute_overlap
-        if len(seq_list) == 1:
-            return merge_strings(a, seq_list[0], dil_factor)
-        return _recursive_merge(merge_strings(a, seq_list[0], dil_factor), seq_list[1:], dil_factor)
-
-    return _recursive_merge("", seq_list, dil_factor)
+    if not seq_list:
+        return ""
+    result = seq_list[0]
+    for i in range(1, len(seq_list)):
+        text_b = seq_list[i]
+        ratio = last_overlap_ratio if i == len(seq_list) - 1 else overlap_ratio
+        result = merge_strings(result, text_b, ratio)
+    return result
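
A hand-checked sketch of the new Hamming-based merging (not part of the package), again assuming the docTR-style module path onnxtr.models.recognition.utils. When exactly one zero-distance alignment exists between the cropped suffix of the left piece and the cropped prefix of the right piece, the strings are stitched at that position and the overlap-ratio estimate is not consulted:

from onnxtr.models.recognition.utils import merge_strings, merge_multi_strings  # assumed module path

# "abcdef" / "cdefgh": after cropping to "abcde" / "defgh", the only zero-distance
# alignment is the 2-character overlap "de", so the pieces are joined there.
print(merge_strings("abcdef", "cdefgh", 0.5))  # 'abcdefgh'

# With several pieces, the final pair is merged using the dedicated last_overlap_ratio
# that split_crops records as the third element of its crop_map tuples.
print(merge_multi_strings(["abcdef", "cdefgh"], 0.5, 0.875))  # 'abcdefgh'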
@@ -22,6 +22,7 @@ ARCHS: list[str] = [
     "vitstr_small",
     "vitstr_base",
     "parseq",
+    "viptr_tiny",
 ]
 
 
@@ -35,7 +36,15 @@ def _predictor(
         _model = recognition.__dict__[arch](load_in_8_bit=load_in_8_bit, engine_cfg=engine_cfg)
     else:
         if not isinstance(
-            arch, (recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq)
+            arch,
+            (
+                recognition.CRNN,
+                recognition.SAR,
+                recognition.MASTER,
+                recognition.ViTSTR,
+                recognition.PARSeq,
+                recognition.VIPTR,
+            ),
         ):
             raise ValueError(f"unknown architecture: {type(arch)}")
         _model = arch