yomitoku-0.4.0.post1.dev0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. yomitoku/__init__.py +20 -0
  2. yomitoku/base.py +136 -0
  3. yomitoku/cli/__init__.py +0 -0
  4. yomitoku/cli/main.py +230 -0
  5. yomitoku/configs/__init__.py +13 -0
  6. yomitoku/configs/cfg_layout_parser_rtdtrv2.py +89 -0
  7. yomitoku/configs/cfg_table_structure_recognizer_rtdtrv2.py +80 -0
  8. yomitoku/configs/cfg_text_detector_dbnet.py +49 -0
  9. yomitoku/configs/cfg_text_recognizer_parseq.py +51 -0
  10. yomitoku/constants.py +32 -0
  11. yomitoku/data/__init__.py +3 -0
  12. yomitoku/data/dataset.py +40 -0
  13. yomitoku/data/functions.py +279 -0
  14. yomitoku/document_analyzer.py +315 -0
  15. yomitoku/export/__init__.py +6 -0
  16. yomitoku/export/export_csv.py +71 -0
  17. yomitoku/export/export_html.py +188 -0
  18. yomitoku/export/export_json.py +34 -0
  19. yomitoku/export/export_markdown.py +145 -0
  20. yomitoku/layout_analyzer.py +66 -0
  21. yomitoku/layout_parser.py +189 -0
  22. yomitoku/models/__init__.py +9 -0
  23. yomitoku/models/dbnet_plus.py +272 -0
  24. yomitoku/models/layers/__init__.py +0 -0
  25. yomitoku/models/layers/activate.py +38 -0
  26. yomitoku/models/layers/dbnet_feature_attention.py +160 -0
  27. yomitoku/models/layers/parseq_transformer.py +218 -0
  28. yomitoku/models/layers/rtdetr_backbone.py +333 -0
  29. yomitoku/models/layers/rtdetr_hybrid_encoder.py +433 -0
  30. yomitoku/models/layers/rtdetrv2_decoder.py +811 -0
  31. yomitoku/models/parseq.py +243 -0
  32. yomitoku/models/rtdetr.py +22 -0
  33. yomitoku/ocr.py +87 -0
  34. yomitoku/postprocessor/__init__.py +9 -0
  35. yomitoku/postprocessor/dbnet_postporcessor.py +137 -0
  36. yomitoku/postprocessor/parseq_tokenizer.py +128 -0
  37. yomitoku/postprocessor/rtdetr_postprocessor.py +107 -0
  38. yomitoku/reading_order.py +214 -0
  39. yomitoku/resource/MPLUS1p-Medium.ttf +0 -0
  40. yomitoku/resource/charset.txt +1 -0
  41. yomitoku/table_structure_recognizer.py +244 -0
  42. yomitoku/text_detector.py +103 -0
  43. yomitoku/text_recognizer.py +128 -0
  44. yomitoku/utils/__init__.py +0 -0
  45. yomitoku/utils/graph.py +20 -0
  46. yomitoku/utils/logger.py +15 -0
  47. yomitoku/utils/misc.py +102 -0
  48. yomitoku/utils/visualizer.py +179 -0
  49. yomitoku-0.4.0.post1.dev0.dist-info/METADATA +127 -0
  50. yomitoku-0.4.0.post1.dev0.dist-info/RECORD +52 -0
  51. yomitoku-0.4.0.post1.dev0.dist-info/WHEEL +4 -0
  52. yomitoku-0.4.0.post1.dev0.dist-info/entry_points.txt +2 -0
yomitoku/models/parseq.py ADDED
@@ -0,0 +1,243 @@
1
+ # Scene Text Recognition Model Hub
2
+ # Copyright 2022 Darwin Bautista
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # https://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import partial
17
+ from typing import Optional, Sequence
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ from huggingface_hub import PyTorchModelHubMixin
22
+ from timm.models.helpers import named_apply
23
+ from torch import Tensor
24
+
25
+ from ..postprocessor import ParseqTokenizer as Tokenizer
26
+ from .layers.parseq_transformer import Decoder, Encoder, TokenEmbedding
27
+
28
+
29
+ def init_weights(
30
+ module: nn.Module, name: str = "", exclude: Sequence[str] = ()
31
+ ):
32
+ """Initialize the weights using the typical initialization schemes used in SOTA models."""
33
+ if any(map(name.startswith, exclude)):
34
+ return
35
+ if isinstance(module, nn.Linear):
36
+ nn.init.trunc_normal_(module.weight, std=0.02)
37
+ if module.bias is not None:
38
+ nn.init.zeros_(module.bias)
39
+ elif isinstance(module, nn.Embedding):
40
+ nn.init.trunc_normal_(module.weight, std=0.02)
41
+ if module.padding_idx is not None:
42
+ module.weight.data[module.padding_idx].zero_()
43
+ elif isinstance(module, nn.Conv2d):
44
+ nn.init.kaiming_normal_(
45
+ module.weight, mode="fan_out", nonlinearity="relu"
46
+ )
47
+ if module.bias is not None:
48
+ nn.init.zeros_(module.bias)
49
+ elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
50
+ nn.init.ones_(module.weight)
51
+ nn.init.zeros_(module.bias)
52
+
53
+
54
+ class PARSeq(nn.Module, PyTorchModelHubMixin):
55
+ def __init__(
56
+ self,
57
+ cfg,
58
+ ) -> None:
59
+ super().__init__()
60
+ self.cfg = cfg
61
+ self.max_label_length = self.cfg.max_label_length
62
+ self.decode_ar = self.cfg.decode_ar
63
+ self.refine_iters = self.cfg.refine_iters
64
+ embed_dim = self.cfg.decoder.embed_dim
65
+
66
+ self.encoder = Encoder(
67
+ self.cfg.data.img_size,
68
+ **self.cfg.encoder,
69
+ )
70
+
71
+ self.decoder = Decoder(
72
+ norm=nn.LayerNorm(self.cfg.decoder.embed_dim),
73
+ cfg=self.cfg.decoder,
74
+ )
75
+
76
+ # We don't predict <bos> nor <pad>
77
+ self.head = nn.Linear(embed_dim, self.cfg.num_tokens - 2)
78
+ self.text_embed = TokenEmbedding(self.cfg.num_tokens, embed_dim)
79
+
80
+ # +1 for <eos>
81
+ self.pos_queries = nn.Parameter(
82
+ torch.Tensor(1, self.max_label_length + 1, embed_dim)
83
+ )
84
+ self.dropout = nn.Dropout()
85
+ # Encoder has its own init.
86
+ named_apply(partial(init_weights, exclude=["encoder"]), self)
87
+ nn.init.trunc_normal_(self.pos_queries, std=0.02)
88
+
89
+ @property
90
+ def _device(self) -> torch.device:
91
+ return next(self.head.parameters(recurse=False)).device
92
+
93
+ @torch.jit.ignore
94
+ def no_weight_decay(self):
95
+ param_names = {"text_embed.embedding.weight", "pos_queries"}
96
+ enc_param_names = {
97
+ "encoder." + n for n in self.encoder.no_weight_decay()
98
+ }
99
+ return param_names.union(enc_param_names)
100
+
101
+ def encode(self, img: torch.Tensor):
102
+ return self.encoder(img)
103
+
104
+ def decode(
105
+ self,
106
+ tgt: torch.Tensor,
107
+ memory: torch.Tensor,
108
+ tgt_mask: Optional[Tensor] = None,
109
+ tgt_padding_mask: Optional[Tensor] = None,
110
+ tgt_query: Optional[Tensor] = None,
111
+ tgt_query_mask: Optional[Tensor] = None,
112
+ ):
113
+ N, L = tgt.shape
114
+ # <bos> stands for the null context. We only supply position information for characters after <bos>.
115
+ null_ctx = self.text_embed(tgt[:, :1])
116
+ tgt_emb = self.pos_queries[:, : L - 1] + self.text_embed(tgt[:, 1:])
117
+ tgt_emb = self.dropout(torch.cat([null_ctx, tgt_emb], dim=1))
118
+ if tgt_query is None:
119
+ tgt_query = self.pos_queries[:, :L].expand(N, -1, -1)
120
+ tgt_query = self.dropout(tgt_query)
121
+ return self.decoder(
122
+ tgt_query,
123
+ tgt_emb,
124
+ memory,
125
+ tgt_query_mask,
126
+ tgt_mask,
127
+ tgt_padding_mask,
128
+ )
129
+
130
+ def forward(
131
+ self,
132
+ tokenizer: Tokenizer,
133
+ images: Tensor,
134
+ max_length: Optional[int] = None,
135
+ ) -> Tensor:
136
+ testing = max_length is None
137
+ max_length = (
138
+ self.max_label_length
139
+ if max_length is None
140
+ else min(max_length, self.max_label_length)
141
+ )
142
+ bs = images.shape[0]
143
+ # +1 for <eos> at end of sequence.
144
+ num_steps = max_length + 1
145
+ memory = self.encode(images)
146
+
147
+ # Query positions up to `num_steps`
148
+ pos_queries = self.pos_queries[:, :num_steps].expand(bs, -1, -1)
149
+
150
+ # Special case for the forward permutation. Faster than using `generate_attn_masks()`
151
+ tgt_mask = query_mask = torch.triu(
152
+ torch.ones(
153
+ (num_steps, num_steps), dtype=torch.bool, device=self._device
154
+ ),
155
+ 1,
156
+ )
157
+
158
+ if self.decode_ar:
159
+ tgt_in = torch.full(
160
+ (bs, num_steps),
161
+ tokenizer.pad_id,
162
+ dtype=torch.long,
163
+ device=self._device,
164
+ )
165
+ tgt_in[:, 0] = tokenizer.bos_id
166
+
167
+ logits = []
168
+ for i in range(num_steps):
169
+ j = i + 1 # next token index
170
+ # Efficient decoding:
171
+ # Input the context up to the ith token. We use only one query (at position = i) at a time.
172
+ # This works because of the lookahead masking effect of the canonical (forward) AR context.
173
+ # Past tokens have no access to future tokens, hence are fixed once computed.
174
+ tgt_out = self.decode(
175
+ tgt_in[:, :j],
176
+ memory,
177
+ tgt_mask[:j, :j],
178
+ tgt_query=pos_queries[:, i:j],
179
+ tgt_query_mask=query_mask[i:j, :j],
180
+ )
181
+ # the next token probability is in the output's ith token position
182
+ p_i = self.head(tgt_out)
183
+ logits.append(p_i)
184
+ if j < num_steps:
185
+ # greedy decode. add the next token index to the target input
186
+ tgt_in[:, j] = p_i.squeeze().argmax(-1)
187
+ # Efficient batch decoding: If all output words have at least one EOS token, end decoding.
188
+ if (
189
+ testing
190
+ and (tgt_in == tokenizer.eos_id).any(dim=-1).all()
191
+ ):
192
+ break
193
+
194
+ logits = torch.cat(logits, dim=1)
195
+ else:
196
+ # No prior context, so input is just <bos>. We query all positions.
197
+ tgt_in = torch.full(
198
+ (bs, 1),
199
+ tokenizer.bos_id,
200
+ dtype=torch.long,
201
+ device=self._device,
202
+ )
203
+ tgt_out = self.decode(tgt_in, memory, tgt_query=pos_queries)
204
+ logits = self.head(tgt_out)
205
+
206
+ if self.refine_iters:
207
+ # For iterative refinement, we always use a 'cloze' mask.
208
+ # We can derive it from the AR forward mask by unmasking the token context to the right.
209
+ query_mask[
210
+ torch.triu(
211
+ torch.ones(
212
+ num_steps,
213
+ num_steps,
214
+ dtype=torch.bool,
215
+ device=self._device,
216
+ ),
217
+ 2,
218
+ )
219
+ ] = 0
220
+ bos = torch.full(
221
+ (bs, 1),
222
+ tokenizer.bos_id,
223
+ dtype=torch.long,
224
+ device=self._device,
225
+ )
226
+ for i in range(self.refine_iters):
227
+ # Prior context is the previous output.
228
+ tgt_in = torch.cat([bos, logits[:, :-1].argmax(-1)], dim=1)
229
+ # Mask tokens beyond the first EOS token.
230
+ tgt_padding_mask = (tgt_in == tokenizer.eos_id).int().cumsum(
231
+ -1
232
+ ) > 0
233
+ tgt_out = self.decode(
234
+ tgt_in,
235
+ memory,
236
+ tgt_mask,
237
+ tgt_padding_mask,
238
+ pos_queries,
239
+ query_mask[:, : tgt_in.shape[1]],
240
+ )
241
+ logits = self.head(tgt_out)
242
+
243
+ return logits
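forward() above already performs greedy autoregressive decoding (plus the optional cloze-style refinement), so inference reduces to running the model and handing softmaxed logits to the tokenizer. A rough sketch, assuming a checkpoint published via PyTorchModelHubMixin; the repo id, charset path, and crop shape below are placeholders, not values taken from this package, and the packaged TextRecognizer presumably wraps these steps with its own preprocessing:

    import torch

    from yomitoku.models.parseq import PARSeq
    from yomitoku.postprocessor import ParseqTokenizer

    # Placeholders: charset path, Hub repo id, and crop shape are illustrative only.
    charset = open("charset.txt", encoding="utf-8").read().rstrip("\n")
    tokenizer = ParseqTokenizer(charset)

    model = PARSeq.from_pretrained("someone/parseq-checkpoint")  # hypothetical checkpoint
    model.eval()

    crops = torch.randn(4, 3, 32, 128)     # batch of normalized text-line crops (shape assumed)
    with torch.inference_mode():
        logits = model(tokenizer, crops)    # (N, L, num_tokens - 2) class scores
        texts, scores = tokenizer.decode(logits.softmax(-1))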
yomitoku/models/rtdetr.py ADDED
@@ -0,0 +1,22 @@
1
+ import torch.nn as nn
2
+ from huggingface_hub import PyTorchModelHubMixin
3
+
4
+ from .layers.rtdetr_backbone import PResNet
5
+ from .layers.rtdetr_hybrid_encoder import HybridEncoder
6
+ from .layers.rtdetrv2_decoder import RTDETRTransformerv2
7
+
8
+
9
+ class RTDETRv2(nn.Module, PyTorchModelHubMixin):
10
+ def __init__(self, cfg):
11
+ super().__init__()
12
+ self.cfg = cfg
13
+ self.backbone = PResNet(**cfg.PResNet)
14
+ self.encoder = HybridEncoder(**cfg.HybridEncoder)
15
+ self.decoder = RTDETRTransformerv2(**cfg.RTDETRTransformerv2)
16
+
17
+ def forward(self, x, targets=None):
18
+ x = self.backbone(x)
19
+ x = self.encoder(x)
20
+ x = self.decoder(x, targets)
21
+
22
+ return x
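RTDETRv2 is a thin composition: backbone features go through the hybrid encoder and then the transformer decoder. A load-and-run sketch under the assumption that a checkpoint is available through the Hub mixin; the repo id and input resolution are placeholders:

    import torch

    from yomitoku.models.rtdetr import RTDETRv2

    model = RTDETRv2.from_pretrained("someone/rtdetrv2-layout")  # hypothetical checkpoint id
    model.eval()

    img = torch.randn(1, 3, 640, 640)   # assumed input resolution
    with torch.inference_mode():
        outputs = model(img)             # backbone -> HybridEncoder -> RTDETRTransformerv2

The raw decoder outputs would then typically be converted to boxes and labels by the RTDETRPostProcessor shipped in this package.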
yomitoku/ocr.py ADDED
@@ -0,0 +1,87 @@
1
+ from typing import List
2
+
3
+ from pydantic import conlist
4
+
5
+ from yomitoku.text_detector import TextDetector
6
+ from yomitoku.text_recognizer import TextRecognizer
7
+
8
+ from .base import BaseSchema
9
+
10
+
11
+ class WordPrediction(BaseSchema):
12
+ points: conlist(
13
+ conlist(int, min_length=2, max_length=2),
14
+ min_length=4,
15
+ max_length=4,
16
+ )
17
+ content: str
18
+ direction: str
19
+ det_score: float
20
+ rec_score: float
21
+
22
+
23
+ class OCRSchema(BaseSchema):
24
+ words: List[WordPrediction]
25
+
26
+
27
+ class OCR:
28
+ def __init__(self, configs=None, device="cuda", visualize=False):
29
+ text_detector_kwargs = {
30
+ "device": device,
31
+ "visualize": visualize,
32
+ }
33
+ text_recognizer_kwargs = {
34
+ "device": device,
35
+ "visualize": visualize,
36
+ }
37
+
38
+ if isinstance(configs, dict):
39
+ assert (
40
+ "text_detector" in configs or "text_recognizer" in configs
41
+ ), "Invalid config key. Please check the config keys."
42
+
43
+ if "text_detector" in configs:
44
+ text_detector_kwargs.update(configs["text_detector"])
45
+ if "text_recognizer" in configs:
46
+ text_recognizer_kwargs.update(configs["text_recognizer"])
47
+ elif configs is not None:
48
+ raise ValueError(
49
+ "configs must be a dict. See the https://kotaro-kinoshita.github.io/yomitoku-dev/usage/"
50
+ )
51
+
52
+ self.detector = TextDetector(**text_detector_kwargs)
53
+ self.recognizer = TextRecognizer(**text_recognizer_kwargs)
54
+
55
+ def aggregate(self, det_outputs, rec_outputs):
56
+ words = []
57
+ for points, det_score, pred, rec_score, direction in zip(
58
+ det_outputs.points,
59
+ det_outputs.scores,
60
+ rec_outputs.contents,
61
+ rec_outputs.scores,
62
+ rec_outputs.directions,
63
+ ):
64
+ words.append(
65
+ {
66
+ "points": points,
67
+ "content": pred,
68
+ "direction": direction,
69
+ "det_score": det_score,
70
+ "rec_score": rec_score,
71
+ }
72
+ )
73
+ return words
74
+
75
+ def __call__(self, img):
76
+ """_summary_
77
+
78
+ Args:
79
+ img (np.ndarray): OpenCV image in BGR channel order.
80
+ """
81
+
82
+ det_outputs, vis = self.detector(img)
83
+ rec_outputs, vis = self.recognizer(img, det_outputs.points, vis=vis)
84
+
85
+ outputs = {"words": self.aggregate(det_outputs, rec_outputs)}
86
+ results = OCRSchema(**outputs)
87
+ return results, vis
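For context, the OCR facade is driven end to end like this: pass a BGR image, get back an OCRSchema plus an optional visualization. A minimal usage sketch; the image path is a placeholder, and CPU execution plus an image-like visualization when visualize=True are assumptions:

    import cv2

    from yomitoku.ocr import OCR

    img = cv2.imread("sample.jpg")   # BGR image, as the docstring expects

    ocr = OCR(
        configs={"text_detector": {}, "text_recognizer": {}},  # empty overrides keep the defaults
        device="cpu",                                           # assumes the models also run on CPU
        visualize=True,
    )

    results, vis = ocr(img)
    for word in results.words:
        print(word.content, word.direction, word.det_score, word.rec_score)
    cv2.imwrite("ocr_vis.jpg", vis)   # assumed to be a drawable overlay when visualize=True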
yomitoku/postprocessor/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from .dbnet_postporcessor import DBnetPostProcessor
2
+ from .parseq_tokenizer import ParseqTokenizer
3
+ from .rtdetr_postprocessor import RTDETRPostProcessor
4
+
5
+ __all__ = [
6
+ "DBnetPostProcessor",
7
+ "RTDETRPostProcessor",
8
+ "ParseqTokenizer",
9
+ ]
yomitoku/postprocessor/dbnet_postporcessor.py ADDED
@@ -0,0 +1,137 @@
1
+ import cv2
2
+ import numpy as np
3
+ import pyclipper
4
+ from shapely.geometry import Polygon
5
+
6
+
7
+ class DBnetPostProcessor:
8
+ def __init__(
9
+ self, min_size, thresh, box_thresh, max_candidates, unclip_ratio
10
+ ):
11
+ self.min_size = min_size
12
+ self.thresh = thresh
13
+ self.box_thresh = box_thresh
14
+ self.max_candidates = max_candidates
15
+ self.unclip_ratio = unclip_ratio
16
+
17
+ def __call__(self, preds, image_size):
18
+ """
19
+ pred:
20
+ binary: text region segmentation map, with shape (N, H, W)
21
+ thresh: [if exists] thresh hold prediction with shape (N, H, W)
22
+ thresh_binary: [if exists] binarized with threshhold, (N, H, W)
23
+ """
24
+ pred = preds["binary"][0]
25
+ segmentation = self.binarize(pred)[0]
26
+ height, width = image_size
27
+ quads, scores = self.boxes_from_bitmap(
28
+ pred, segmentation, width, height
29
+ )
30
+ return quads, scores
31
+
32
+ def binarize(self, pred):
33
+ return pred > self.thresh
34
+
35
+ def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
36
+ """
37
+ _bitmap: single map with shape (H, W),
38
+ whose values are binarized as {0, 1}
39
+ """
40
+
41
+ assert len(_bitmap.shape) == 2
42
+ bitmap = _bitmap.cpu().numpy() # The first channel
43
+
44
+ pred = pred.cpu().detach().numpy()[0]
45
+ height, width = bitmap.shape
46
+ contours, _ = cv2.findContours(
47
+ (bitmap * 255).astype(np.uint8),
48
+ cv2.RETR_LIST,
49
+ cv2.CHAIN_APPROX_SIMPLE,
50
+ )
51
+
52
+ num_contours = min(len(contours), self.max_candidates)
53
+
54
+ boxes = []
55
+ scores = []
56
+ for index in range(num_contours):
57
+ contour = contours[index].squeeze(1)
58
+ points, sside = self.get_mini_boxes(contour)
59
+
60
+ if sside < self.min_size:
61
+ continue
62
+ points = np.array(points)
63
+ score = self.box_score_fast(pred, contour)
64
+
65
+ if self.box_thresh > score:
66
+ continue
67
+
68
+ box = self.unclip(points, unclip_ratio=self.unclip_ratio).reshape(
69
+ -1, 1, 2
70
+ )
71
+ box, sside = self.get_mini_boxes(box)
72
+ if sside < self.min_size + 2:
73
+ continue
74
+ box = np.array(box)
75
+ if not isinstance(dest_width, int):
76
+ dest_width = dest_width.item()
77
+ dest_height = dest_height.item()
78
+
79
+ box[:, 0] = np.clip(
80
+ np.round(box[:, 0] / width * dest_width), 0, dest_width
81
+ )
82
+ box[:, 1] = np.clip(
83
+ np.round(box[:, 1] / height * dest_height), 0, dest_height
84
+ )
85
+
86
+ boxes.append(box.astype(np.int16).tolist())
87
+ scores.append(score)
88
+
89
+ return boxes, scores
90
+
91
+ def unclip(self, box, unclip_ratio=1.5):
92
+ poly = Polygon(box)
93
+ distance = poly.area * unclip_ratio / poly.length
94
+ offset = pyclipper.PyclipperOffset()
95
+ offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
96
+ expanded = np.array(offset.Execute(distance))
97
+ return expanded
98
+
99
+ def get_mini_boxes(self, contour):
100
+ bounding_box = cv2.minAreaRect(contour)
101
+ points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
102
+
103
+ index_1, index_2, index_3, index_4 = 0, 1, 2, 3
104
+ if points[1][1] > points[0][1]:
105
+ index_1 = 0
106
+ index_4 = 1
107
+ else:
108
+ index_1 = 1
109
+ index_4 = 0
110
+ if points[3][1] > points[2][1]:
111
+ index_2 = 2
112
+ index_3 = 3
113
+ else:
114
+ index_2 = 3
115
+ index_3 = 2
116
+
117
+ box = [
118
+ points[index_1],
119
+ points[index_2],
120
+ points[index_3],
121
+ points[index_4],
122
+ ]
123
+ return box, min(bounding_box[1])
124
+
125
+ def box_score_fast(self, bitmap, _box):
126
+ h, w = bitmap.shape[:2]
127
+ box = _box.copy()
128
+ xmin = np.clip(np.floor(box[:, 0].min()).astype(int), 0, w - 1)
129
+ xmax = np.clip(np.ceil(box[:, 0].max()).astype(int), 0, w - 1)
130
+ ymin = np.clip(np.floor(box[:, 1].min()).astype(int), 0, h - 1)
131
+ ymax = np.clip(np.ceil(box[:, 1].max()).astype(int), 0, h - 1)
132
+
133
+ mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
134
+ box[:, 0] = box[:, 0] - xmin
135
+ box[:, 1] = box[:, 1] - ymin
136
+ cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
137
+ return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
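The key geometric step above is unclip(): each detected region is dilated outward by an offset of area x unclip_ratio / perimeter before the minimum-area box is recomputed. A self-contained illustration of just that offset computation (the quad below is made up, not package data):

    import numpy as np
    import pyclipper
    from shapely.geometry import Polygon

    box = np.array([[0, 0], [100, 0], [100, 40], [0, 40]])   # a 100x40 toy text quad
    poly = Polygon(box)
    distance = poly.area * 1.5 / poly.length                  # area * unclip_ratio / perimeter ~= 21.4
    offset = pyclipper.PyclipperOffset()
    offset.AddPath(box.tolist(), pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
    expanded = np.array(offset.Execute(distance)[0])          # dilated polygon vertices
    print(expanded.shape)                                      # rounded corners -> more than 4 points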
yomitoku/postprocessor/parseq_tokenizer.py ADDED
@@ -0,0 +1,128 @@
1
+ # Scene Text Recognition Model Hub
2
+ # Copyright 2022 Darwin Bautista
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # https://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from abc import ABC, abstractmethod
17
+ from typing import Optional
18
+
19
+ import torch
20
+ from torch import Tensor
21
+ from torch.nn.utils.rnn import pad_sequence
22
+
23
+
24
+ class BaseTokenizer(ABC):
25
+ def __init__(
26
+ self,
27
+ charset: str,
28
+ specials_first: tuple = (),
29
+ specials_last: tuple = (),
30
+ ) -> None:
31
+ self._itos = specials_first + tuple(charset) + specials_last
32
+ self._stoi = {s: i for i, s in enumerate(self._itos)}
33
+
34
+ def __len__(self):
35
+ return len(self._itos)
36
+
37
+ def _tok2ids(self, tokens: str) -> list[int]:
38
+ return [self._stoi[s] for s in tokens]
39
+
40
+ def _ids2tok(self, token_ids: list[int], join: bool = True) -> str:
41
+ tokens = [self._itos[i] for i in token_ids]
42
+ return "".join(tokens) if join else tokens
43
+
44
+ @abstractmethod
45
+ def encode(
46
+ self, labels: list[str], device: Optional[torch.device] = None
47
+ ) -> Tensor:
48
+ """Encode a batch of labels to a representation suitable for the model.
49
+
50
+ Args:
51
+ labels: List of labels. Each can be of arbitrary length.
52
+ device: Create tensor on this device.
53
+
54
+ Returns:
55
+ Batched tensor representation padded to the max label length. Shape: N, L
56
+ """
57
+ raise NotImplementedError
58
+
59
+ @abstractmethod
60
+ def _filter(self, probs: Tensor, ids: Tensor) -> tuple[Tensor, list[int]]:
61
+ """Internal method which performs the necessary filtering prior to decoding."""
62
+ raise NotImplementedError
63
+
64
+ def decode(
65
+ self, token_dists: Tensor, raw: bool = False
66
+ ) -> tuple[list[str], list[Tensor]]:
67
+ """Decode a batch of token distributions.
68
+
69
+ Args:
70
+ token_dists: softmax probabilities over the token distribution. Shape: N, L, C
71
+ raw: return unprocessed labels (will return list of list of strings)
72
+
73
+ Returns:
74
+ list of string labels (arbitrary length) and
75
+ their corresponding sequence probabilities as a list of Tensors
76
+ """
77
+ batch_tokens = []
78
+ batch_probs = []
79
+ for dist in token_dists:
80
+ probs, ids = dist.max(-1) # greedy selection
81
+ if not raw:
82
+ probs, ids = self._filter(probs, ids)
83
+ tokens = self._ids2tok(ids, not raw)
84
+ probs = probs.cpu().numpy()
85
+ probs = float(probs.prod())
86
+ batch_tokens.append(tokens)
87
+ batch_probs.append(probs)
88
+ return batch_tokens, batch_probs
89
+
90
+
91
+ class ParseqTokenizer(BaseTokenizer):
92
+ BOS = "[B]"
93
+ EOS = "[E]"
94
+ PAD = "[P]"
95
+
96
+ def __init__(self, charset: str) -> None:
97
+ specials_first = (self.EOS,)
98
+ specials_last = (self.BOS, self.PAD)
99
+ super().__init__(charset, specials_first, specials_last)
100
+ self.eos_id, self.bos_id, self.pad_id = [
101
+ self._stoi[s] for s in specials_first + specials_last
102
+ ]
103
+
104
+ def encode(
105
+ self, labels: list[str], device: Optional[torch.device] = None
106
+ ) -> Tensor:
107
+ batch = [
108
+ torch.as_tensor(
109
+ [self.bos_id] + self._tok2ids(y) + [self.eos_id],
110
+ dtype=torch.long,
111
+ device=device,
112
+ )
113
+ for y in labels
114
+ ]
115
+ return pad_sequence(batch, batch_first=True, padding_value=self.pad_id)
116
+
117
+ def _filter(self, probs: Tensor, ids: Tensor) -> tuple[Tensor, list[int]]:
118
+ ids = ids.tolist()
119
+ try:
120
+ eos_idx = ids.index(self.eos_id)
121
+ except ValueError:
122
+ eos_idx = len(ids) # Nothing to truncate.
123
+ # Truncate after EOS
124
+ ids = ids[:eos_idx]
125
+ probs = probs[
126
+ : eos_idx + 1
127
+ ] # but include prob. for EOS (if it exists)
128
+ return probs, ids
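Because EOS is placed first and BOS/PAD last, eos_id is always 0 and the two highest ids are never produced by the model head (which has num_tokens - 2 outputs). A small self-contained round trip with a toy charset (the real charset ships in yomitoku/resource/charset.txt):

    import torch
    import torch.nn.functional as F

    from yomitoku.postprocessor import ParseqTokenizer

    tokenizer = ParseqTokenizer("0123456789")   # toy charset for illustration
    print(len(tokenizer), tokenizer.eos_id, tokenizer.bos_id, tokenizer.pad_id)   # 13 0 11 12

    # encode() wraps each label in <bos> ... <eos> and pads the batch with <pad>.
    print(tokenizer.encode(["123", "45"]).tolist())   # [[11, 2, 3, 4, 0], [11, 5, 6, 0, 12]]

    # decode() expects per-position probability distributions over <eos> + charset;
    # one-hot distributions give an exact round trip.
    ids = torch.tensor([[2, 3, 4, 0]])                # "123" followed by <eos>
    dists = F.one_hot(ids, num_classes=len(tokenizer) - 2).float()
    texts, probs = tokenizer.decode(dists)
    print(texts, probs)                               # ['123'] [1.0]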