xfmr-zem 0.2.4 → 0.2.6 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xfmr_zem/cli.py +32 -3
- xfmr_zem/client.py +59 -8
- xfmr_zem/server.py +21 -4
- xfmr_zem/servers/data_juicer/server.py +1 -1
- xfmr_zem/servers/instruction_gen/server.py +1 -1
- xfmr_zem/servers/io/server.py +1 -1
- xfmr_zem/servers/llm/parameters.yml +10 -0
- xfmr_zem/servers/nemo_curator/server.py +1 -1
- xfmr_zem/servers/ocr/deepdoc_vietocr/__init__.py +90 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/implementations.py +1286 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/layout_recognizer.py +562 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/ocr.py +512 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/onnx/.gitattributes +35 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/onnx/README.md +5 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/onnx/ocr.res +6623 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/operators.py +725 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/phases.py +191 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/pipeline.py +561 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/postprocess.py +370 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/recognizer.py +436 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/table_structure_recognizer.py +569 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/utils/__init__.py +81 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/utils/file_utils.py +246 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/__init__.py +0 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/config/base.yml +58 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/config/vgg-seq2seq.yml +38 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/__init__.py +0 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/backbone/cnn.py +25 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/backbone/vgg.py +51 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/seqmodel/seq2seq.py +175 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/transformerocr.py +29 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/vocab.py +36 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/tool/config.py +37 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/tool/translate.py +111 -0
- xfmr_zem/servers/ocr/engines.py +242 -0
- xfmr_zem/servers/ocr/install_models.py +63 -0
- xfmr_zem/servers/ocr/parameters.yml +4 -0
- xfmr_zem/servers/ocr/server.py +102 -0
- xfmr_zem/servers/profiler/parameters.yml +4 -0
- xfmr_zem/servers/sinks/parameters.yml +6 -0
- xfmr_zem/servers/unstructured/parameters.yml +6 -0
- xfmr_zem/servers/unstructured/server.py +62 -0
- xfmr_zem/zenml_wrapper.py +20 -7
- {xfmr_zem-0.2.4.dist-info → xfmr_zem-0.2.6.dist-info}/METADATA +20 -1
- xfmr_zem-0.2.6.dist-info/RECORD +58 -0
- xfmr_zem-0.2.4.dist-info/RECORD +0 -23
- /xfmr_zem/servers/data_juicer/{parameter.yaml → parameters.yml} +0 -0
- /xfmr_zem/servers/instruction_gen/{parameter.yaml → parameters.yml} +0 -0
- /xfmr_zem/servers/io/{parameter.yaml → parameters.yml} +0 -0
- /xfmr_zem/servers/nemo_curator/{parameter.yaml → parameters.yml} +0 -0
- {xfmr_zem-0.2.4.dist-info → xfmr_zem-0.2.6.dist-info}/WHEEL +0 -0
- {xfmr_zem-0.2.4.dist-info → xfmr_zem-0.2.6.dist-info}/entry_points.txt +0 -0
- {xfmr_zem-0.2.4.dist-info → xfmr_zem-0.2.6.dist-info}/licenses/LICENSE +0 -0
xfmr_zem/servers/ocr/deepdoc_vietocr/utils/file_utils.py (new file)

@@ -0,0 +1,246 @@

```python
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import json
import os
import re
import sys
import threading
from io import BytesIO

import pdfplumber
from PIL import Image
from cachetools import LRUCache, cached
from ruamel.yaml import YAML

from enum import Enum
# from .db import FileType
# from .constants import IMG_BASE64_PREFIX
IMG_BASE64_PREFIX = 'data:image/png;base64,'

class FileType(Enum):
    PDF = 'pdf'
    DOC = 'doc'
    VISUAL = 'visual'
    AURAL = 'aural'
    VIRTUAL = 'virtual'
    FOLDER = 'folder'
    OTHER = "other"

PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")
RAG_BASE = os.getenv("RAG_BASE")

LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
if LOCK_KEY_pdfplumber not in sys.modules:
    sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()


def get_project_base_directory(*args):
    base_dir = os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir))
    if args:
        return os.path.join(base_dir, *args)
    return base_dir

def get_rag_directory(*args):
    global RAG_BASE
    if RAG_BASE is None:
        RAG_BASE = os.path.abspath(
            os.path.join(
                os.path.dirname(os.path.realpath(__file__)),
                os.pardir,
                os.pardir,
                os.pardir,
            )
        )
    if args:
        return os.path.join(RAG_BASE, *args)
    return RAG_BASE


def get_rag_python_directory(*args):
    return get_rag_directory("python", *args)


def get_home_cache_dir():
    dir = os.path.join(os.path.expanduser('~'), ".ragflow")
    try:
        os.mkdir(dir)
    except OSError:
        pass
    return dir


@cached(cache=LRUCache(maxsize=10))
def load_json_conf(conf_path):
    if os.path.isabs(conf_path):
        json_conf_path = conf_path
    else:
        json_conf_path = os.path.join(get_project_base_directory(), conf_path)
    try:
        with open(json_conf_path) as f:
            return json.load(f)
    except BaseException:
        raise EnvironmentError(
            "loading json file config from '{}' failed!".format(json_conf_path)
        )


def dump_json_conf(config_data, conf_path):
    if os.path.isabs(conf_path):
        json_conf_path = conf_path
    else:
        json_conf_path = os.path.join(get_project_base_directory(), conf_path)
    try:
        with open(json_conf_path, "w") as f:
            json.dump(config_data, f, indent=4)
    except BaseException:
        raise EnvironmentError(
            "loading json file config from '{}' failed!".format(json_conf_path)
        )


def load_json_conf_real_time(conf_path):
    if os.path.isabs(conf_path):
        json_conf_path = conf_path
    else:
        json_conf_path = os.path.join(get_project_base_directory(), conf_path)
    try:
        with open(json_conf_path) as f:
            return json.load(f)
    except BaseException:
        raise EnvironmentError(
            "loading json file config from '{}' failed!".format(json_conf_path)
        )


def load_yaml_conf(conf_path):
    if not os.path.isabs(conf_path):
        conf_path = os.path.join(get_project_base_directory(), conf_path)
    try:
        with open(conf_path) as f:
            yaml = YAML(typ='safe', pure=True)
            return yaml.load(f)
    except Exception as e:
        raise EnvironmentError(
            "loading yaml file config from {} failed:".format(conf_path), e
        )


def rewrite_yaml_conf(conf_path, config):
    if not os.path.isabs(conf_path):
        conf_path = os.path.join(get_project_base_directory(), conf_path)
    try:
        with open(conf_path, "w") as f:
            yaml = YAML(typ="safe")
            yaml.dump(config, f)
    except Exception as e:
        raise EnvironmentError(
            "rewrite yaml file config {} failed:".format(conf_path), e
        )


def rewrite_json_file(filepath, json_data):
    with open(filepath, "w", encoding='utf-8') as f:
        json.dump(json_data, f, indent=4, separators=(",", ": "))
        f.close()


def filename_type(filename):
    filename = filename.lower()
    if re.match(r".*\.pdf$", filename):
        return FileType.PDF.value

    if re.match(
            r".*\.(eml|doc|docx|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|html|sql)$", filename):
        return FileType.DOC.value

    if re.match(
            r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
        return FileType.AURAL.value

    if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$", filename):
        return FileType.VISUAL.value

    return FileType.OTHER.value

def thumbnail_img(filename, blob):
    """
    MySQL LongText max length is 65535
    """
    filename = filename.lower()
    if re.match(r".*\.pdf$", filename):
        with sys.modules[LOCK_KEY_pdfplumber]:
            pdf = pdfplumber.open(BytesIO(blob))
            buffered = BytesIO()
            resolution = 32
            img = None
            for _ in range(10):
                # https://github.com/jsvine/pdfplumber?tab=readme-ov-file#creating-a-pageimage-with-to_image
                pdf.pages[0].to_image(resolution=resolution).annotated.save(buffered, format="png")
                img = buffered.getvalue()
                if len(img) >= 64000 and resolution >= 2:
                    resolution = resolution / 2
                    buffered = BytesIO()
                else:
                    break
            pdf.close()
            return img

    elif re.match(r".*\.(jpg|jpeg|png|tif|gif|icon|ico|webp)$", filename):
        image = Image.open(BytesIO(blob))
        image.thumbnail((30, 30))
        buffered = BytesIO()
        image.save(buffered, format="png")
        return buffered.getvalue()

    elif re.match(r".*\.(ppt|pptx)$", filename):
        import aspose.slides as slides
        import aspose.pydrawing as drawing
        try:
            with slides.Presentation(BytesIO(blob)) as presentation:
                buffered = BytesIO()
                scale = 0.03
                img = None
                for _ in range(10):
                    # https://reference.aspose.com/slides/python-net/aspose.slides/slide/get_thumbnail/#float-float
                    presentation.slides[0].get_thumbnail(scale, scale).save(
                        buffered, drawing.imaging.ImageFormat.png)
                    img = buffered.getvalue()
                    if len(img) >= 64000:
                        scale = scale / 2.0
                        buffered = BytesIO()
                    else:
                        break
                return img
        except Exception:
            pass
    return None


def thumbnail(filename, blob):
    img = thumbnail_img(filename, blob)
    if img is not None:
        return IMG_BASE64_PREFIX + \
            base64.b64encode(img).decode("utf-8")
    else:
        return ''


def traversal_files(base):
    for root, ds, fs in os.walk(base):
        for f in fs:
            fullname = os.path.join(root, f)
            yield fullname
```
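For orientation, a minimal sketch of how these helpers behave (the sample filenames are hypothetical, and the import path is inferred from the file list above):

```python
# Hedged usage sketch of the new file_utils module; sample files are assumptions.
from xfmr_zem.servers.ocr.deepdoc_vietocr.utils.file_utils import (
    filename_type,
    thumbnail,
)

print(filename_type("report.pdf"))   # 'pdf'
print(filename_type("notes.md"))     # 'doc'
print(filename_type("scan.jpeg"))    # 'visual'
print(filename_type("track.flac"))   # 'aural'

with open("scan.jpeg", "rb") as fh:  # hypothetical local file
    blob = fh.read()

# Returns 'data:image/png;base64,...' on success, '' when no thumbnail
# could be produced (e.g. an unsupported extension).
print(thumbnail("scan.jpeg", blob)[:40])
```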
xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/__init__.py (file without changes)
xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/config/base.yml (new file)

@@ -0,0 +1,58 @@

```yaml
# change to list chars of your dataset or use default vietnamese chars
vocab: 'aAàÀảẢãÃáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬbBcCdDđĐeEèÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆfFgGhHiIìÌỉỈĩĨíÍịỊjJkKlLmMnNoOòÒỏỎõÕóÓọỌôÔồỒổỔỗỖốỐộỘơƠờỜởỞỡỠớỚợỢpPqQrRsStTuUùÙủỦũŨúÚụỤưƯừỪửỬữỮứỨựỰvVwWxXyYỳỲỷỶỹỸýÝỵỴzZ0123456789!"#$%&''()*+,-./:;<=>?@[\]^_`{|}~ '

# cpu, cuda, cuda:0
device: cuda:0

seq_modeling: transformer
transformer:
    d_model: 256
    nhead: 8
    num_encoder_layers: 6
    num_decoder_layers: 6
    dim_feedforward: 2048
    max_seq_length: 1024
    pos_dropout: 0.1
    trans_dropout: 0.1

optimizer:
    max_lr: 0.0003
    pct_start: 0.1

trainer:
    batch_size: 32
    print_every: 200
    valid_every: 4000
    iters: 100000
    # where to save our model for prediction
    export: ./weights/transformerocr.pth
    checkpoint: ./checkpoint/transformerocr_checkpoint.pth
    log: ./train.log
    # null to disable compuate accuracy, or change to number of sample to enable validiation while training
    metrics: null

dataset:
    # name of your dataset
    name: data
    # path to annotation and image
    data_root: ./img/
    train_annotation: annotation_train.txt
    valid_annotation: annotation_val_small.txt
    # resize image to 32 height, larger height will increase accuracy
    image_height: 32
    image_min_width: 32
    image_max_width: 512

dataloader:
    num_workers: 3
    pin_memory: True

aug:
    image_aug: true
    masked_language_model: true

predictor:
    # disable or enable beamsearch while prediction, use beamsearch will be slower
    beamsearch: False

quiet: False
```
xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/config/vgg-seq2seq.yml (new file)

@@ -0,0 +1,38 @@

```yaml
pretrain:
    id_or_url: 1nTKlEog9YFK74kPyX0qLwCWi60_YHHk4
    md5: efcabaa6d3adfca8e52bda2fd7d2ee04
    cached: /tmp/tranformerorc.pth

# url or local path
weights: https://drive.google.com/uc?id=1nTKlEog9YFK74kPyX0qLwCWi60_YHHk4

backbone: vgg19_bn
cnn:
    # pooling stride size
    ss:
        - [2, 2]
        - [2, 2]
        - [2, 1]
        - [2, 1]
        - [1, 1]
    # pooling kernel size
    ks:
        - [2, 2]
        - [2, 2]
        - [2, 1]
        - [2, 1]
        - [1, 1]
    # dim of ouput feature map
    hidden: 256

seq_modeling: seq2seq
transformer:
    encoder_hidden: 256
    decoder_hidden: 256
    img_channel: 256
    decoder_embedded: 256
    dropout: 0.1

optimizer:
    max_lr: 0.001
    pct_start: 0.1
```
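These two config files follow the upstream VietOCR convention of merging a model config over base.yml, with top-level keys in the model file winning. A minimal sketch of that merge; the `load_config` helper and paths are illustrative assumptions, not this package's API:

```python
# Illustrative sketch: shallow merge of vgg-seq2seq.yml over base.yml,
# assuming the upstream VietOCR convention. Not this package's API.
from ruamel.yaml import YAML


def load_config(base_path="base.yml", model_path="vgg-seq2seq.yml"):
    yaml = YAML(typ="safe", pure=True)
    with open(base_path) as f:
        config = yaml.load(f)
    with open(model_path) as f:
        # Top-level keys in the model file (backbone, cnn, seq_modeling, ...)
        # replace the corresponding keys from base.yml.
        config.update(yaml.load(f))
    return config


cfg = load_config()
print(cfg["seq_modeling"])  # 'seq2seq' (overrides base.yml's 'transformer')
print(cfg["transformer"])   # the seq2seq hidden sizes, not the d_model block
```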
xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/__init__.py (file without changes)
xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/backbone/cnn.py (new file)

@@ -0,0 +1,25 @@

```python
import torch
from torch import nn

from . import vgg

class CNN(nn.Module):
    def __init__(self, backbone, **kwargs):
        super(CNN, self).__init__()

        if backbone == 'vgg11_bn':
            self.model = vgg.vgg11_bn(**kwargs)
        elif backbone == 'vgg19_bn':
            self.model = vgg.vgg19_bn(**kwargs)

    def forward(self, x):
        return self.model(x)

    def freeze(self):
        for name, param in self.model.features.named_parameters():
            if name != 'last_conv_1x1':
                param.requires_grad = False

    def unfreeze(self):
        for param in self.model.features.parameters():
            param.requires_grad = True
```
xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/backbone/vgg.py (new file)

@@ -0,0 +1,51 @@

```python
import torch
from torch import nn
from torchvision import models
from einops import rearrange
from torchvision.models._utils import IntermediateLayerGetter


class Vgg(nn.Module):
    def __init__(self, name, ss, ks, hidden, pretrained=True, dropout=0.5):
        super(Vgg, self).__init__()

        weights = "DEFAULT" if pretrained else None
        if name == 'vgg11_bn':
            cnn = models.vgg11_bn(weights=weights)
        elif name == 'vgg19_bn':
            cnn = models.vgg19_bn(weights=weights)

        pool_idx = 0

        for i, layer in enumerate(cnn.features):
            if isinstance(layer, torch.nn.MaxPool2d):
                cnn.features[i] = torch.nn.AvgPool2d(kernel_size=ks[pool_idx], stride=ss[pool_idx], padding=0)
                pool_idx += 1

        self.features = cnn.features
        self.dropout = nn.Dropout(dropout)
        self.last_conv_1x1 = nn.Conv2d(512, hidden, 1)

    def forward(self, x):
        """
        Shape:
            - x: (N, C, H, W)
            - output: (W, N, C)
        """

        conv = self.features(x)
        conv = self.dropout(conv)
        conv = self.last_conv_1x1(conv)

        # conv = rearrange(conv, 'b d h w -> b d (w h)')
        conv = conv.permute(0, 1, 3, 2)
        conv = conv.flatten(2)
        conv = conv.permute(2, 0, 1)

        return conv

def vgg11_bn(ss, ks, hidden, pretrained=True, dropout=0.5):
    return Vgg('vgg11_bn', ss, ks, hidden, pretrained, dropout)

def vgg19_bn(ss, ks, hidden, pretrained=True, dropout=0.5):
    return Vgg('vgg19_bn', ss, ks, hidden, pretrained, dropout)
```
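A quick shape check against the `forward()` contract above, using the pooling settings shipped in vgg-seq2seq.yml (a sketch; `pretrained=False` just avoids downloading ImageNet weights):

```python
# Sketch: confirm the (W, N, C) output contract of the Vgg backbone with the
# ss/ks values from vgg-seq2seq.yml above.
import torch

from xfmr_zem.servers.ocr.deepdoc_vietocr.vietocr.model.backbone.vgg import vgg19_bn

ss = [[2, 2], [2, 2], [2, 1], [2, 1], [1, 1]]  # pooling strides
ks = [[2, 2], [2, 2], [2, 1], [2, 1], [1, 1]]  # pooling kernels
model = vgg19_bn(ss, ks, hidden=256, pretrained=False)

x = torch.randn(1, 3, 32, 128)  # (N, C, H, W); height 32 matches the dataset config
out = model(x)

# Height 32 is pooled by 2*2*2*2*1 = 16 -> 2; width 128 by 2*2*1*1*1 = 4 -> 32.
# Both spatial dims are flattened into one sequence axis: 32 * 2 = 64 steps.
print(out.shape)  # torch.Size([64, 1, 256])
```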
xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/seqmodel/seq2seq.py (new file)

@@ -0,0 +1,175 @@

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()

        self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional = True)
        self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """
        src: src_len x batch_size x img_channel
        outputs: src_len x batch_size x hid_dim
        hidden: batch_size x hid_dim
        """

        embedded = self.dropout(src)

        outputs, hidden = self.rnn(embedded)

        hidden = torch.tanh(self.fc(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)))

        return outputs, hidden

class Attention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()

        self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias = False)

    def forward(self, hidden, encoder_outputs):
        """
        hidden: batch_size x hid_dim
        encoder_outputs: src_len x batch_size x hid_dim,
        outputs: batch_size x src_len
        """

        batch_size = encoder_outputs.shape[1]
        src_len = encoder_outputs.shape[0]

        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)

        encoder_outputs = encoder_outputs.permute(1, 0, 2)

        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim = 2)))

        attention = self.v(energy).squeeze(2)

        return F.softmax(attention, dim = 1)

class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
        super().__init__()

        self.output_dim = output_dim
        self.attention = attention

        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim)
        self.fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, encoder_outputs):
        """
        inputs: batch_size
        hidden: batch_size x hid_dim
        encoder_outputs: src_len x batch_size x hid_dim
        """

        input = input.unsqueeze(0)

        embedded = self.dropout(self.embedding(input))

        a = self.attention(hidden, encoder_outputs)

        a = a.unsqueeze(1)

        encoder_outputs = encoder_outputs.permute(1, 0, 2)

        weighted = torch.bmm(a, encoder_outputs)

        weighted = weighted.permute(1, 0, 2)

        rnn_input = torch.cat((embedded, weighted), dim = 2)

        output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))

        assert (output == hidden).all()

        embedded = embedded.squeeze(0)
        output = output.squeeze(0)
        weighted = weighted.squeeze(0)

        prediction = self.fc_out(torch.cat((output, weighted, embedded), dim = 1))

        return prediction, hidden.squeeze(0), a.squeeze(1)

class Seq2Seq(nn.Module):
    def __init__(self, vocab_size, encoder_hidden, decoder_hidden, img_channel, decoder_embedded, dropout=0.1):
        super().__init__()

        attn = Attention(encoder_hidden, decoder_hidden)

        self.encoder = Encoder(img_channel, encoder_hidden, decoder_hidden, dropout)
        self.decoder = Decoder(vocab_size, decoder_embedded, encoder_hidden, decoder_hidden, dropout, attn)

    def forward_encoder(self, src):
        """
        src: timestep x batch_size x channel
        hidden: batch_size x hid_dim
        encoder_outputs: src_len x batch_size x hid_dim
        """

        encoder_outputs, hidden = self.encoder(src)

        return (hidden, encoder_outputs)

    def forward_decoder(self, tgt, memory):
        """
        tgt: timestep x batch_size
        hidden: batch_size x hid_dim
        encouder: src_len x batch_size x hid_dim
        output: batch_size x 1 x vocab_size
        """

        tgt = tgt[-1]
        hidden, encoder_outputs = memory
        output, hidden, _ = self.decoder(tgt, hidden, encoder_outputs)
        output = output.unsqueeze(1)

        return output, (hidden, encoder_outputs)

    def forward(self, src, trg):
        """
        src: time_step x batch_size
        trg: time_step x batch_size
        outputs: batch_size x time_step x vocab_size
        """

        batch_size = src.shape[1]
        trg_len = trg.shape[0]
        trg_vocab_size = self.decoder.output_dim
        device = src.device

        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(device)
        encoder_outputs, hidden = self.encoder(src)

        for t in range(trg_len):
            input = trg[t]
            output, hidden, _ = self.decoder(input, hidden, encoder_outputs)

            outputs[t] = output

        outputs = outputs.transpose(0, 1).contiguous()

        return outputs

    def expand_memory(self, memory, beam_size):
        hidden, encoder_outputs = memory
        hidden = hidden.repeat(beam_size, 1)
        encoder_outputs = encoder_outputs.repeat(1, beam_size, 1)

        return (hidden, encoder_outputs)

    def get_memory(self, memory, i):
        hidden, encoder_outputs = memory
        hidden = hidden[[i]]
        encoder_outputs = encoder_outputs[:, [i],:]

        return (hidden, encoder_outputs)
```
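The `forward_encoder`/`forward_decoder` split exists so inference can decode one token at a time against a cached memory. A minimal greedy-decoding sketch over that interface (token ids 1 and 2 are the `<sos>`/`<eos>` ids from vocab.py below; `vocab_size=233` and the length cap are arbitrary assumptions, and this is not the package's own translate routine):

```python
# Sketch of greedy decoding over the memory interface above. Hidden sizes
# mirror vgg-seq2seq.yml; vocab_size and max length are assumptions.
import torch

from xfmr_zem.servers.ocr.deepdoc_vietocr.vietocr.model.seqmodel.seq2seq import Seq2Seq

model = Seq2Seq(vocab_size=233, encoder_hidden=256, decoder_hidden=256,
                img_channel=256, decoder_embedded=256)
model.eval()

src = torch.randn(64, 1, 256)        # (timestep, batch, img_channel) from the CNN
memory = model.forward_encoder(src)  # -> (hidden, encoder_outputs)

sent = [1]                           # start with <sos>
with torch.no_grad():
    for _ in range(128):
        tgt = torch.tensor(sent).unsqueeze(1)     # (timestep, batch=1)
        output, memory = model.forward_decoder(tgt, memory)
        token = output.argmax(-1).item()          # output: (batch, 1, vocab)
        sent.append(token)
        if token == 2:                            # stop at <eos>
            break

print(sent)  # e.g. [1, ..., 2] -- ready for Vocab.decode()
```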
xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/transformerocr.py (new file)

@@ -0,0 +1,29 @@

```python
from .backbone.cnn import CNN
from .seqmodel.seq2seq import Seq2Seq
from torch import nn

class VietOCR(nn.Module):
    def __init__(self, vocab_size,
                 backbone,
                 cnn_args,
                 transformer_args, seq_modeling='transformer'):

        super(VietOCR, self).__init__()

        self.cnn = CNN(backbone, **cnn_args)
        self.seq_modeling = seq_modeling
        self.transformer = Seq2Seq(vocab_size, **transformer_args)


    def forward(self, img, tgt_input, tgt_key_padding_mask):
        """
        Shape:
            - img: (N, C, H, W)
            - tgt_input: (T, N)
            - tgt_key_padding_mask: (N, T)
            - output: b t v
        """
        src = self.cnn(img)
        outputs = self.transformer(src, tgt_input)

        return outputs
```
xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/vocab.py (new file)

@@ -0,0 +1,36 @@

```python
class Vocab():
    def __init__(self, chars):
        self.pad = 0
        self.go = 1
        self.eos = 2
        self.mask_token = 3

        self.chars = chars

        self.c2i = {c:i+4 for i, c in enumerate(chars)}

        self.i2c = {i+4:c for i, c in enumerate(chars)}

        self.i2c[0] = '<pad>'
        self.i2c[1] = '<sos>'
        self.i2c[2] = '<eos>'
        self.i2c[3] = '*'

    def encode(self, chars):
        return [self.go] + [self.c2i[c] for c in chars] + [self.eos]

    def decode(self, ids):
        first = 1 if self.go in ids else 0
        last = ids.index(self.eos) if self.eos in ids else None
        sent = ''.join([self.i2c[i] for i in ids[first:last]])
        return sent

    def __len__(self):
        return len(self.c2i) + 4

    def batch_decode(self, arr):
        texts = [self.decode(ids) for ids in arr]
        return texts

    def __str__(self):
        return self.chars
```
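A round trip through the index mapping above (a sketch; the three-character vocab is a tiny stand-in for the full `vocab` string in base.yml):

```python
# Sketch: encode/decode round trip with Vocab. The tiny character set is an
# assumption; the package feeds in the vocab string from base.yml.
from xfmr_zem.servers.ocr.deepdoc_vietocr.vietocr.model.vocab import Vocab

vocab = Vocab("abc")        # a->4, b->5, c->6; ids 0..3 are reserved specials

ids = vocab.encode("cab")
print(ids)                  # [1, 6, 4, 5, 2] -> <sos> c a b <eos>
print(vocab.decode(ids))    # 'cab' -- <sos>/<eos> are stripped
print(len(vocab))           # 7 = 3 chars + 4 special tokens
```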