xfmr-zem 0.2.2__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. xfmr_zem/cli.py +32 -3
  2. xfmr_zem/client.py +59 -8
  3. xfmr_zem/server.py +21 -4
  4. xfmr_zem/servers/data_juicer/server.py +1 -1
  5. xfmr_zem/servers/instruction_gen/server.py +1 -1
  6. xfmr_zem/servers/io/server.py +1 -1
  7. xfmr_zem/servers/llm/parameters.yml +10 -0
  8. xfmr_zem/servers/nemo_curator/server.py +1 -1
  9. xfmr_zem/servers/ocr/deepdoc_vietocr/__init__.py +90 -0
  10. xfmr_zem/servers/ocr/deepdoc_vietocr/implementations.py +1286 -0
  11. xfmr_zem/servers/ocr/deepdoc_vietocr/layout_recognizer.py +562 -0
  12. xfmr_zem/servers/ocr/deepdoc_vietocr/ocr.py +512 -0
  13. xfmr_zem/servers/ocr/deepdoc_vietocr/onnx/.gitattributes +35 -0
  14. xfmr_zem/servers/ocr/deepdoc_vietocr/onnx/README.md +5 -0
  15. xfmr_zem/servers/ocr/deepdoc_vietocr/onnx/ocr.res +6623 -0
  16. xfmr_zem/servers/ocr/deepdoc_vietocr/operators.py +725 -0
  17. xfmr_zem/servers/ocr/deepdoc_vietocr/phases.py +191 -0
  18. xfmr_zem/servers/ocr/deepdoc_vietocr/pipeline.py +561 -0
  19. xfmr_zem/servers/ocr/deepdoc_vietocr/postprocess.py +370 -0
  20. xfmr_zem/servers/ocr/deepdoc_vietocr/recognizer.py +436 -0
  21. xfmr_zem/servers/ocr/deepdoc_vietocr/table_structure_recognizer.py +569 -0
  22. xfmr_zem/servers/ocr/deepdoc_vietocr/utils/__init__.py +81 -0
  23. xfmr_zem/servers/ocr/deepdoc_vietocr/utils/file_utils.py +246 -0
  24. xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/__init__.py +0 -0
  25. xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/config/base.yml +58 -0
  26. xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/config/vgg-seq2seq.yml +38 -0
  27. xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/__init__.py +0 -0
  28. xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/backbone/cnn.py +25 -0
  29. xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/backbone/vgg.py +51 -0
  30. xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/seqmodel/seq2seq.py +175 -0
  31. xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/transformerocr.py +29 -0
  32. xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/vocab.py +36 -0
  33. xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/tool/config.py +37 -0
  34. xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/tool/translate.py +111 -0
  35. xfmr_zem/servers/ocr/engines.py +242 -0
  36. xfmr_zem/servers/ocr/install_models.py +63 -0
  37. xfmr_zem/servers/ocr/parameters.yml +4 -0
  38. xfmr_zem/servers/ocr/server.py +44 -0
  39. xfmr_zem/servers/profiler/parameters.yml +4 -0
  40. xfmr_zem/servers/sinks/parameters.yml +6 -0
  41. xfmr_zem/servers/unstructured/parameters.yml +6 -0
  42. xfmr_zem/servers/unstructured/server.py +62 -0
  43. xfmr_zem/zenml_wrapper.py +20 -7
  44. {xfmr_zem-0.2.2.dist-info → xfmr_zem-0.2.5.dist-info}/METADATA +19 -1
  45. xfmr_zem-0.2.5.dist-info/RECORD +58 -0
  46. xfmr_zem-0.2.2.dist-info/RECORD +0 -23
  47. /xfmr_zem/servers/data_juicer/{parameter.yaml → parameters.yml} +0 -0
  48. /xfmr_zem/servers/instruction_gen/{parameter.yaml → parameters.yml} +0 -0
  49. /xfmr_zem/servers/io/{parameter.yaml → parameters.yml} +0 -0
  50. /xfmr_zem/servers/nemo_curator/{parameter.yaml → parameters.yml} +0 -0
  51. {xfmr_zem-0.2.2.dist-info → xfmr_zem-0.2.5.dist-info}/WHEEL +0 -0
  52. {xfmr_zem-0.2.2.dist-info → xfmr_zem-0.2.5.dist-info}/entry_points.txt +0 -0
  53. {xfmr_zem-0.2.2.dist-info → xfmr_zem-0.2.5.dist-info}/licenses/LICENSE +0 -0
xfmr_zem/servers/ocr/deepdoc_vietocr/utils/file_utils.py
@@ -0,0 +1,246 @@
+ #
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ import base64
+ import json
+ import os
+ import re
+ import sys
+ import threading
+ from io import BytesIO
+
+ import pdfplumber
+ from PIL import Image
+ from cachetools import LRUCache, cached
+ from ruamel.yaml import YAML
+
+ from enum import Enum
+ # from .db import FileType
+ # from .constants import IMG_BASE64_PREFIX
+ IMG_BASE64_PREFIX = 'data:image/png;base64,'
+
+ class FileType(Enum):
+     PDF = 'pdf'
+     DOC = 'doc'
+     VISUAL = 'visual'
+     AURAL = 'aural'
+     VIRTUAL = 'virtual'
+     FOLDER = 'folder'
+     OTHER = "other"
+
+ PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")
+ RAG_BASE = os.getenv("RAG_BASE")
+
+ LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
+ if LOCK_KEY_pdfplumber not in sys.modules:
+     sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
+
+
+ def get_project_base_directory(*args):
+     base_dir = os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir))
+     if args:
+         return os.path.join(base_dir, *args)
+     return base_dir
+
+ def get_rag_directory(*args):
+     global RAG_BASE
+     if RAG_BASE is None:
+         RAG_BASE = os.path.abspath(
+             os.path.join(
+                 os.path.dirname(os.path.realpath(__file__)),
+                 os.pardir,
+                 os.pardir,
+                 os.pardir,
+             )
+         )
+     if args:
+         return os.path.join(RAG_BASE, *args)
+     return RAG_BASE
+
+
+ def get_rag_python_directory(*args):
+     return get_rag_directory("python", *args)
+
+
+ def get_home_cache_dir():
+     dir = os.path.join(os.path.expanduser('~'), ".ragflow")
+     try:
+         os.mkdir(dir)
+     except OSError:
+         pass
+     return dir
+
+
+ @cached(cache=LRUCache(maxsize=10))
+ def load_json_conf(conf_path):
+     if os.path.isabs(conf_path):
+         json_conf_path = conf_path
+     else:
+         json_conf_path = os.path.join(get_project_base_directory(), conf_path)
+     try:
+         with open(json_conf_path) as f:
+             return json.load(f)
+     except BaseException:
+         raise EnvironmentError(
+             "loading json file config from '{}' failed!".format(json_conf_path)
+         )
+
+
+ def dump_json_conf(config_data, conf_path):
+     if os.path.isabs(conf_path):
+         json_conf_path = conf_path
+     else:
+         json_conf_path = os.path.join(get_project_base_directory(), conf_path)
+     try:
+         with open(json_conf_path, "w") as f:
+             json.dump(config_data, f, indent=4)
+     except BaseException:
+         raise EnvironmentError(
+             "loading json file config from '{}' failed!".format(json_conf_path)
+         )
+
+
+ def load_json_conf_real_time(conf_path):
+     if os.path.isabs(conf_path):
+         json_conf_path = conf_path
+     else:
+         json_conf_path = os.path.join(get_project_base_directory(), conf_path)
+     try:
+         with open(json_conf_path) as f:
+             return json.load(f)
+     except BaseException:
+         raise EnvironmentError(
+             "loading json file config from '{}' failed!".format(json_conf_path)
+         )
+
+
+ def load_yaml_conf(conf_path):
+     if not os.path.isabs(conf_path):
+         conf_path = os.path.join(get_project_base_directory(), conf_path)
+     try:
+         with open(conf_path) as f:
+             yaml = YAML(typ='safe', pure=True)
+             return yaml.load(f)
+     except Exception as e:
+         raise EnvironmentError(
+             "loading yaml file config from {} failed:".format(conf_path), e
+         )
+
+
+ def rewrite_yaml_conf(conf_path, config):
+     if not os.path.isabs(conf_path):
+         conf_path = os.path.join(get_project_base_directory(), conf_path)
+     try:
+         with open(conf_path, "w") as f:
+             yaml = YAML(typ="safe")
+             yaml.dump(config, f)
+     except Exception as e:
+         raise EnvironmentError(
+             "rewrite yaml file config {} failed:".format(conf_path), e
+         )
+
+
+ def rewrite_json_file(filepath, json_data):
+     with open(filepath, "w", encoding='utf-8') as f:
+         json.dump(json_data, f, indent=4, separators=(",", ": "))
+         f.close()
+
+
+ def filename_type(filename):
+     filename = filename.lower()
+     if re.match(r".*\.pdf$", filename):
+         return FileType.PDF.value
+
+     if re.match(
+             r".*\.(eml|doc|docx|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|html|sql)$", filename):
+         return FileType.DOC.value
+
+     if re.match(
+             r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
+         return FileType.AURAL.value
+
+     if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$", filename):
+         return FileType.VISUAL.value
+
+     return FileType.OTHER.value
+
+ def thumbnail_img(filename, blob):
+     """
+     MySQL LongText max length is 65535
+     """
+     filename = filename.lower()
+     if re.match(r".*\.pdf$", filename):
+         with sys.modules[LOCK_KEY_pdfplumber]:
+             pdf = pdfplumber.open(BytesIO(blob))
+             buffered = BytesIO()
+             resolution = 32
+             img = None
+             for _ in range(10):
+                 # https://github.com/jsvine/pdfplumber?tab=readme-ov-file#creating-a-pageimage-with-to_image
+                 pdf.pages[0].to_image(resolution=resolution).annotated.save(buffered, format="png")
+                 img = buffered.getvalue()
+                 if len(img) >= 64000 and resolution >= 2:
+                     resolution = resolution / 2
+                     buffered = BytesIO()
+                 else:
+                     break
+             pdf.close()
+             return img
+
+     elif re.match(r".*\.(jpg|jpeg|png|tif|gif|icon|ico|webp)$", filename):
+         image = Image.open(BytesIO(blob))
+         image.thumbnail((30, 30))
+         buffered = BytesIO()
+         image.save(buffered, format="png")
+         return buffered.getvalue()
+
+     elif re.match(r".*\.(ppt|pptx)$", filename):
+         import aspose.slides as slides
+         import aspose.pydrawing as drawing
+         try:
+             with slides.Presentation(BytesIO(blob)) as presentation:
+                 buffered = BytesIO()
+                 scale = 0.03
+                 img = None
+                 for _ in range(10):
+                     # https://reference.aspose.com/slides/python-net/aspose.slides/slide/get_thumbnail/#float-float
+                     presentation.slides[0].get_thumbnail(scale, scale).save(
+                         buffered, drawing.imaging.ImageFormat.png)
+                     img = buffered.getvalue()
+                     if len(img) >= 64000:
+                         scale = scale / 2.0
+                         buffered = BytesIO()
+                     else:
+                         break
+                 return img
+         except Exception:
+             pass
+     return None
+
+
+ def thumbnail(filename, blob):
+     img = thumbnail_img(filename, blob)
+     if img is not None:
+         return IMG_BASE64_PREFIX + \
+             base64.b64encode(img).decode("utf-8")
+     else:
+         return ''
+
+
+ def traversal_files(base):
+     for root, ds, fs in os.walk(base):
+         for f in fs:
+             fullname = os.path.join(root, f)
+             yield fullname
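
As a quick orientation, a minimal usage sketch of the two entry points above (illustrative only; the import path follows the file list in this diff, and sample.pdf is a hypothetical input file):

    from xfmr_zem.servers.ocr.deepdoc_vietocr.utils.file_utils import filename_type, thumbnail

    # read any local PDF as raw bytes (hypothetical sample file)
    with open("sample.pdf", "rb") as f:
        blob = f.read()

    print(filename_type("sample.pdf"))   # "pdf" (FileType.PDF.value)
    data_uri = thumbnail("sample.pdf", blob)
    print(data_uri[:30])                 # "data:image/png;base64,..." or "" on failure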
xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/config/base.yml
@@ -0,0 +1,58 @@
+ # change to list chars of your dataset or use default vietnamese chars
+ vocab: 'aAàÀảẢãÃáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬbBcCdDđĐeEèÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆfFgGhHiIìÌỉỈĩĨíÍịỊjJkKlLmMnNoOòÒỏỎõÕóÓọỌôÔồỒổỔỗỖốỐộỘơƠờỜởỞỡỠớỚợỢpPqQrRsStTuUùÙủỦũŨúÚụỤưƯừỪửỬữỮứỨựỰvVwWxXyYỳỲỷỶỹỸýÝỵỴzZ0123456789!"#$%&''()*+,-./:;<=>?@[\]^_`{|}~ '
+
+ # cpu, cuda, cuda:0
+ device: cuda:0
+
+ seq_modeling: transformer
+ transformer:
+     d_model: 256
+     nhead: 8
+     num_encoder_layers: 6
+     num_decoder_layers: 6
+     dim_feedforward: 2048
+     max_seq_length: 1024
+     pos_dropout: 0.1
+     trans_dropout: 0.1
+
+ optimizer:
+     max_lr: 0.0003
+     pct_start: 0.1
+
+ trainer:
+     batch_size: 32
+     print_every: 200
+     valid_every: 4000
+     iters: 100000
+     # where to save our model for prediction
+     export: ./weights/transformerocr.pth
+     checkpoint: ./checkpoint/transformerocr_checkpoint.pth
+     log: ./train.log
+     # null to disable compuate accuracy, or change to number of sample to enable validiation while training
+     metrics: null
+
+ dataset:
+     # name of your dataset
+     name: data
+     # path to annotation and image
+     data_root: ./img/
+     train_annotation: annotation_train.txt
+     valid_annotation: annotation_val_small.txt
+     # resize image to 32 height, larger height will increase accuracy
+     image_height: 32
+     image_min_width: 32
+     image_max_width: 512
+
+ dataloader:
+     num_workers: 3
+     pin_memory: True
+
+ aug:
+     image_aug: true
+     masked_language_model: true
+
+ predictor:
+     # disable or enable beamsearch while prediction, use beamsearch will be slower
+     beamsearch: False
+
+ quiet: False
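
A hedged sketch of reading this shipped config with the load_yaml_conf helper from file_utils.py above (relative paths resolve against the deepdoc_vietocr package directory; whether the OCR server loads it exactly this way is not shown in this diff):

    from xfmr_zem.servers.ocr.deepdoc_vietocr.utils.file_utils import load_yaml_conf

    cfg = load_yaml_conf("vietocr/config/base.yml")
    print(cfg["device"])                  # "cuda:0"
    print(cfg["transformer"]["d_model"])  # 256
    print(cfg["trainer"]["batch_size"])   # 32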
xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/config/vgg-seq2seq.yml
@@ -0,0 +1,38 @@
+ pretrain:
+     id_or_url: 1nTKlEog9YFK74kPyX0qLwCWi60_YHHk4
+     md5: efcabaa6d3adfca8e52bda2fd7d2ee04
+     cached: /tmp/tranformerorc.pth
+
+ # url or local path
+ weights: https://drive.google.com/uc?id=1nTKlEog9YFK74kPyX0qLwCWi60_YHHk4
+
+ backbone: vgg19_bn
+ cnn:
+     # pooling stride size
+     ss:
+         - [2, 2]
+         - [2, 2]
+         - [2, 1]
+         - [2, 1]
+         - [1, 1]
+     # pooling kernel size
+     ks:
+         - [2, 2]
+         - [2, 2]
+         - [2, 1]
+         - [2, 1]
+         - [1, 1]
+     # dim of ouput feature map
+     hidden: 256
+
+ seq_modeling: seq2seq
+ transformer:
+     encoder_hidden: 256
+     decoder_hidden: 256
+     img_channel: 256
+     decoder_embedded: 256
+     dropout: 0.1
+
+ optimizer:
+     max_lr: 0.001
+     pct_start: 0.1
xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/backbone/cnn.py
@@ -0,0 +1,25 @@
+ import torch
+ from torch import nn
+
+ from . import vgg
+
+ class CNN(nn.Module):
+     def __init__(self, backbone, **kwargs):
+         super(CNN, self).__init__()
+
+         if backbone == 'vgg11_bn':
+             self.model = vgg.vgg11_bn(**kwargs)
+         elif backbone == 'vgg19_bn':
+             self.model = vgg.vgg19_bn(**kwargs)
+
+     def forward(self, x):
+         return self.model(x)
+
+     def freeze(self):
+         for name, param in self.model.features.named_parameters():
+             if name != 'last_conv_1x1':
+                 param.requires_grad = False
+
+     def unfreeze(self):
+         for param in self.model.features.parameters():
+             param.requires_grad = True
xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/backbone/vgg.py
@@ -0,0 +1,51 @@
+ import torch
+ from torch import nn
+ from torchvision import models
+ from einops import rearrange
+ from torchvision.models._utils import IntermediateLayerGetter
+
+
+ class Vgg(nn.Module):
+     def __init__(self, name, ss, ks, hidden, pretrained=True, dropout=0.5):
+         super(Vgg, self).__init__()
+
+         weights = "DEFAULT" if pretrained else None
+         if name == 'vgg11_bn':
+             cnn = models.vgg11_bn(weights=weights)
+         elif name == 'vgg19_bn':
+             cnn = models.vgg19_bn(weights=weights)
+
+         pool_idx = 0
+
+         for i, layer in enumerate(cnn.features):
+             if isinstance(layer, torch.nn.MaxPool2d):
+                 cnn.features[i] = torch.nn.AvgPool2d(kernel_size=ks[pool_idx], stride=ss[pool_idx], padding=0)
+                 pool_idx += 1
+
+         self.features = cnn.features
+         self.dropout = nn.Dropout(dropout)
+         self.last_conv_1x1 = nn.Conv2d(512, hidden, 1)
+
+     def forward(self, x):
+         """
+         Shape:
+             - x: (N, C, H, W)
+             - output: (W, N, C)
+         """
+
+         conv = self.features(x)
+         conv = self.dropout(conv)
+         conv = self.last_conv_1x1(conv)
+
+         # conv = rearrange(conv, 'b d h w -> b d (w h)')
+         conv = conv.permute(0, 1, 3, 2)
+         conv = conv.flatten(2)
+         conv = conv.permute(2, 0, 1)
+
+         return conv
+
+ def vgg11_bn(ss, ks, hidden, pretrained=True, dropout=0.5):
+     return Vgg('vgg11_bn', ss, ks, hidden, pretrained, dropout)
+
+ def vgg19_bn(ss, ks, hidden, pretrained=True, dropout=0.5):
+     return Vgg('vgg19_bn', ss, ks, hidden, pretrained, dropout)
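
To see what the ss/ks overrides from vgg-seq2seq.yml do to the feature map, a small shape-check sketch (an assumption-laden example, not package code; pretrained=False avoids a weight download):

    import torch
    from xfmr_zem.servers.ocr.deepdoc_vietocr.vietocr.model.backbone import vgg

    ss = [[2, 2], [2, 2], [2, 1], [2, 1], [1, 1]]   # pooling strides from vgg-seq2seq.yml
    ks = [[2, 2], [2, 2], [2, 1], [2, 1], [1, 1]]   # pooling kernels from vgg-seq2seq.yml

    backbone = vgg.vgg19_bn(ss, ks, hidden=256, pretrained=False)
    feats = backbone(torch.randn(1, 3, 32, 160))    # (N, C, H, W) as in the docstring
    print(feats.shape)                              # e.g. torch.Size([80, 1, 256]) -> (W*H, N, hidden)

The (2, 1) strides downsample height aggressively while largely preserving width, which is what turns a 32-pixel-high text line into a long feature sequence for the decoder.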
xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/seqmodel/seq2seq.py
@@ -0,0 +1,175 @@
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ import torch.nn.functional as F
+
+ class Encoder(nn.Module):
+     def __init__(self, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
+         super().__init__()
+
+         self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional = True)
+         self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, src):
+         """
+         src: src_len x batch_size x img_channel
+         outputs: src_len x batch_size x hid_dim
+         hidden: batch_size x hid_dim
+         """
+
+         embedded = self.dropout(src)
+
+         outputs, hidden = self.rnn(embedded)
+
+         hidden = torch.tanh(self.fc(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)))
+
+         return outputs, hidden
+
+ class Attention(nn.Module):
+     def __init__(self, enc_hid_dim, dec_hid_dim):
+         super().__init__()
+
+         self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
+         self.v = nn.Linear(dec_hid_dim, 1, bias = False)
+
+     def forward(self, hidden, encoder_outputs):
+         """
+         hidden: batch_size x hid_dim
+         encoder_outputs: src_len x batch_size x hid_dim,
+         outputs: batch_size x src_len
+         """
+
+         batch_size = encoder_outputs.shape[1]
+         src_len = encoder_outputs.shape[0]
+
+         hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
+
+         encoder_outputs = encoder_outputs.permute(1, 0, 2)
+
+         energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim = 2)))
+
+         attention = self.v(energy).squeeze(2)
+
+         return F.softmax(attention, dim = 1)
+
+ class Decoder(nn.Module):
+     def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
+         super().__init__()
+
+         self.output_dim = output_dim
+         self.attention = attention
+
+         self.embedding = nn.Embedding(output_dim, emb_dim)
+         self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim)
+         self.fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, input, hidden, encoder_outputs):
+         """
+         inputs: batch_size
+         hidden: batch_size x hid_dim
+         encoder_outputs: src_len x batch_size x hid_dim
+         """
+
+         input = input.unsqueeze(0)
+
+         embedded = self.dropout(self.embedding(input))
+
+         a = self.attention(hidden, encoder_outputs)
+
+         a = a.unsqueeze(1)
+
+         encoder_outputs = encoder_outputs.permute(1, 0, 2)
+
+         weighted = torch.bmm(a, encoder_outputs)
+
+         weighted = weighted.permute(1, 0, 2)
+
+         rnn_input = torch.cat((embedded, weighted), dim = 2)
+
+         output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))
+
+         assert (output == hidden).all()
+
+         embedded = embedded.squeeze(0)
+         output = output.squeeze(0)
+         weighted = weighted.squeeze(0)
+
+         prediction = self.fc_out(torch.cat((output, weighted, embedded), dim = 1))
+
+         return prediction, hidden.squeeze(0), a.squeeze(1)
+
+ class Seq2Seq(nn.Module):
+     def __init__(self, vocab_size, encoder_hidden, decoder_hidden, img_channel, decoder_embedded, dropout=0.1):
+         super().__init__()
+
+         attn = Attention(encoder_hidden, decoder_hidden)
+
+         self.encoder = Encoder(img_channel, encoder_hidden, decoder_hidden, dropout)
+         self.decoder = Decoder(vocab_size, decoder_embedded, encoder_hidden, decoder_hidden, dropout, attn)
+
+     def forward_encoder(self, src):
+         """
+         src: timestep x batch_size x channel
+         hidden: batch_size x hid_dim
+         encoder_outputs: src_len x batch_size x hid_dim
+         """
+
+         encoder_outputs, hidden = self.encoder(src)
+
+         return (hidden, encoder_outputs)
+
+     def forward_decoder(self, tgt, memory):
+         """
+         tgt: timestep x batch_size
+         hidden: batch_size x hid_dim
+         encouder: src_len x batch_size x hid_dim
+         output: batch_size x 1 x vocab_size
+         """
+
+         tgt = tgt[-1]
+         hidden, encoder_outputs = memory
+         output, hidden, _ = self.decoder(tgt, hidden, encoder_outputs)
+         output = output.unsqueeze(1)
+
+         return output, (hidden, encoder_outputs)
+
+     def forward(self, src, trg):
+         """
+         src: time_step x batch_size
+         trg: time_step x batch_size
+         outputs: batch_size x time_step x vocab_size
+         """
+
+         batch_size = src.shape[1]
+         trg_len = trg.shape[0]
+         trg_vocab_size = self.decoder.output_dim
+         device = src.device
+
+         outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(device)
+         encoder_outputs, hidden = self.encoder(src)
+
+         for t in range(trg_len):
+             input = trg[t]
+             output, hidden, _ = self.decoder(input, hidden, encoder_outputs)
+
+             outputs[t] = output
+
+         outputs = outputs.transpose(0, 1).contiguous()
+
+         return outputs
+
+     def expand_memory(self, memory, beam_size):
+         hidden, encoder_outputs = memory
+         hidden = hidden.repeat(beam_size, 1)
+         encoder_outputs = encoder_outputs.repeat(1, beam_size, 1)
+
+         return (hidden, encoder_outputs)
+
+     def get_memory(self, memory, i):
+         hidden, encoder_outputs = memory
+         hidden = hidden[[i]]
+         encoder_outputs = encoder_outputs[:, [i],:]
+
+         return (hidden, encoder_outputs)
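
The forward_encoder/forward_decoder pair is what step-wise decoding hangs on. A minimal greedy-decode sketch under assumed shapes (random features stand in for CNN output, and 233 is only a placeholder vocab size):

    import torch
    from xfmr_zem.servers.ocr.deepdoc_vietocr.vietocr.model.seqmodel.seq2seq import Seq2Seq

    model = Seq2Seq(vocab_size=233, encoder_hidden=256, decoder_hidden=256,
                    img_channel=256, decoder_embedded=256, dropout=0.1).eval()

    src = torch.randn(80, 1, 256)            # (timestep, batch, img_channel), e.g. CNN features
    memory = model.forward_encoder(src)      # (hidden, encoder_outputs)

    sos, eos = 1, 2
    tgt = torch.tensor([[sos]])              # (timestep, batch), grows one id per step
    with torch.no_grad():
        for _ in range(30):
            output, memory = model.forward_decoder(tgt, memory)  # (batch, 1, vocab_size)
            next_id = output.argmax(-1)                          # (batch, 1)
            tgt = torch.cat([tgt, next_id.T], dim=0)
            if int(next_id) == eos:
                break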
xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/transformerocr.py
@@ -0,0 +1,29 @@
+ from .backbone.cnn import CNN
+ from .seqmodel.seq2seq import Seq2Seq
+ from torch import nn
+
+ class VietOCR(nn.Module):
+     def __init__(self, vocab_size,
+                  backbone,
+                  cnn_args,
+                  transformer_args, seq_modeling='transformer'):
+
+         super(VietOCR, self).__init__()
+
+         self.cnn = CNN(backbone, **cnn_args)
+         self.seq_modeling = seq_modeling
+         self.transformer = Seq2Seq(vocab_size, **transformer_args)
+
+
+     def forward(self, img, tgt_input, tgt_key_padding_mask):
+         """
+         Shape:
+             - img: (N, C, H, W)
+             - tgt_input: (T, N)
+             - tgt_key_padding_mask: (N, T)
+             - output: b t v
+         """
+         src = self.cnn(img)
+         outputs = self.transformer(src, tgt_input)
+
+         return outputs
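
Wiring the pieces together: a hedged construction sketch showing how the cnn and transformer blocks from vgg-seq2seq.yml map onto this class (the short charset is a stand-in for the full vocab string in base.yml):

    from xfmr_zem.servers.ocr.deepdoc_vietocr.vietocr.model.transformerocr import VietOCR
    from xfmr_zem.servers.ocr.deepdoc_vietocr.vietocr.model.vocab import Vocab

    cnn_args = {"ss": [[2, 2], [2, 2], [2, 1], [2, 1], [1, 1]],
                "ks": [[2, 2], [2, 2], [2, 1], [2, 1], [1, 1]],
                "hidden": 256, "pretrained": False}
    transformer_args = {"encoder_hidden": 256, "decoder_hidden": 256,
                        "img_channel": 256, "decoder_embedded": 256, "dropout": 0.1}

    vocab = Vocab(chars="abcdefghijklmnopqrstuvwxyz0123456789 ")   # placeholder charset
    model = VietOCR(vocab_size=len(vocab), backbone="vgg19_bn",
                    cnn_args=cnn_args, transformer_args=transformer_args,
                    seq_modeling="seq2seq")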
xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/vocab.py
@@ -0,0 +1,36 @@
+ class Vocab():
+     def __init__(self, chars):
+         self.pad = 0
+         self.go = 1
+         self.eos = 2
+         self.mask_token = 3
+
+         self.chars = chars
+
+         self.c2i = {c:i+4 for i, c in enumerate(chars)}
+
+         self.i2c = {i+4:c for i, c in enumerate(chars)}
+
+         self.i2c[0] = '<pad>'
+         self.i2c[1] = '<sos>'
+         self.i2c[2] = '<eos>'
+         self.i2c[3] = '*'
+
+     def encode(self, chars):
+         return [self.go] + [self.c2i[c] for c in chars] + [self.eos]
+
+     def decode(self, ids):
+         first = 1 if self.go in ids else 0
+         last = ids.index(self.eos) if self.eos in ids else None
+         sent = ''.join([self.i2c[i] for i in ids[first:last]])
+         return sent
+
+     def __len__(self):
+         return len(self.c2i) + 4
+
+     def batch_decode(self, arr):
+         texts = [self.decode(ids) for ids in arr]
+         return texts
+
+     def __str__(self):
+         return self.chars
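
A tiny round-trip sketch of the index scheme above (ids 0-3 are reserved for the pad, start, end, and mask tokens, so characters start at 4; the toy charset is illustrative):

    from xfmr_zem.servers.ocr.deepdoc_vietocr.vietocr.model.vocab import Vocab

    v = Vocab(chars="abc xyz")
    ids = v.encode("cab")                          # [1, 6, 4, 5, 2] -> <sos> c a b <eos>
    print(v.decode(ids))                           # "cab"
    print(v.batch_decode([ids, v.encode("xa")]))   # ["cab", "xa"]
    print(len(v))                                  # 11 = 7 chars + 4 special ids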