manuscript-ocr 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- manuscript_ocr-0.1.0/MANIFEST.in +2 -0
- manuscript_ocr-0.1.0/PKG-INFO +73 -0
- manuscript_ocr-0.1.0/README.md +39 -0
- manuscript_ocr-0.1.0/pyproject.toml +3 -0
- manuscript_ocr-0.1.0/requirements.txt +13 -0
- manuscript_ocr-0.1.0/setup.cfg +4 -0
- manuscript_ocr-0.1.0/setup.py +35 -0
- manuscript_ocr-0.1.0/src/example.py +13 -0
- manuscript_ocr-0.1.0/src/manuscript/detectors/__init__.py +3 -0
- manuscript_ocr-0.1.0/src/manuscript/detectors/__pycache__/__init__.cpython-311.pyc +0 -0
- manuscript_ocr-0.1.0/src/manuscript/detectors/__pycache__/types.cpython-311.pyc +0 -0
- manuscript_ocr-0.1.0/src/manuscript/detectors/east/__init__.py +112 -0
- manuscript_ocr-0.1.0/src/manuscript/detectors/east/__pycache__/__init__.cpython-311.pyc +0 -0
- manuscript_ocr-0.1.0/src/manuscript/detectors/east/__pycache__/dataset.cpython-311.pyc +0 -0
- manuscript_ocr-0.1.0/src/manuscript/detectors/east/__pycache__/east.cpython-311.pyc +0 -0
- manuscript_ocr-0.1.0/src/manuscript/detectors/east/__pycache__/lanms.cpython-311.pyc +0 -0
- manuscript_ocr-0.1.0/src/manuscript/detectors/east/__pycache__/loss.cpython-311.pyc +0 -0
- manuscript_ocr-0.1.0/src/manuscript/detectors/east/__pycache__/sam.cpython-311.pyc +0 -0
- manuscript_ocr-0.1.0/src/manuscript/detectors/east/__pycache__/utils.cpython-311.pyc +0 -0
- manuscript_ocr-0.1.0/src/manuscript/detectors/east/dataset.py +154 -0
- manuscript_ocr-0.1.0/src/manuscript/detectors/east/east.py +115 -0
- manuscript_ocr-0.1.0/src/manuscript/detectors/east/lanms.py +214 -0
- manuscript_ocr-0.1.0/src/manuscript/detectors/east/loss.py +63 -0
- manuscript_ocr-0.1.0/src/manuscript/detectors/east/sam.py +76 -0
- manuscript_ocr-0.1.0/src/manuscript/detectors/east/train.py +305 -0
- manuscript_ocr-0.1.0/src/manuscript/detectors/east/utils.py +235 -0
- manuscript_ocr-0.1.0/src/manuscript/detectors/types.py +27 -0
- manuscript_ocr-0.1.0/src/manuscript_ocr.egg-info/PKG-INFO +73 -0
- manuscript_ocr-0.1.0/src/manuscript_ocr.egg-info/SOURCES.txt +30 -0
- manuscript_ocr-0.1.0/src/manuscript_ocr.egg-info/dependency_links.txt +1 -0
- manuscript_ocr-0.1.0/src/manuscript_ocr.egg-info/requires.txt +12 -0
- manuscript_ocr-0.1.0/src/manuscript_ocr.egg-info/top_level.txt +1 -0
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: manuscript-ocr
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: manuscript-ocr
|
|
5
|
+
Home-page:
|
|
6
|
+
Author:
|
|
7
|
+
Author-email:
|
|
8
|
+
License: MIT
|
|
9
|
+
Classifier: Programming Language :: Python :: 3
|
|
10
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
11
|
+
Classifier: Operating System :: OS Independent
|
|
12
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
13
|
+
Requires-Python: >=3.8
|
|
14
|
+
Description-Content-Type: text/markdown
|
|
15
|
+
Requires-Dist: torch==2.0.1+cu118
|
|
16
|
+
Requires-Dist: torchvision==0.15.2+cu118
|
|
17
|
+
Requires-Dist: torchaudio==2.0.2
|
|
18
|
+
Requires-Dist: numpy<2
|
|
19
|
+
Requires-Dist: opencv-python
|
|
20
|
+
Requires-Dist: Pillow
|
|
21
|
+
Requires-Dist: shapely
|
|
22
|
+
Requires-Dist: numba
|
|
23
|
+
Requires-Dist: tensorboard
|
|
24
|
+
Requires-Dist: gdown
|
|
25
|
+
Requires-Dist: pydantic
|
|
26
|
+
Requires-Dist: scikit-image
|
|
27
|
+
Dynamic: classifier
|
|
28
|
+
Dynamic: description
|
|
29
|
+
Dynamic: description-content-type
|
|
30
|
+
Dynamic: license
|
|
31
|
+
Dynamic: requires-dist
|
|
32
|
+
Dynamic: requires-python
|
|
33
|
+
Dynamic: summary
|
|
34
|
+
|
|
35
|
+
## Installation
|
|
36
|
+
|
|
37
|
+
```bash
|
|
38
|
+
pip install manuscript-ocr
|
|
39
|
+
```
|
|
40
|
+
|
|
41
|
+
## Usage Example
|
|
42
|
+
|
|
43
|
+
```python
|
|
44
|
+
from PIL import Image
|
|
45
|
+
from manuscript.detectors import EASTInfer
|
|
46
|
+
|
|
47
|
+
# Инициализация
|
|
48
|
+
det = EASTInfer(score_thresh=0.9)
|
|
49
|
+
|
|
50
|
+
# Инфер с визуализацией
|
|
51
|
+
page, vis_image = det.infer(r"example\ocr_example_image.jpg", vis=True)
|
|
52
|
+
|
|
53
|
+
print(page)
|
|
54
|
+
|
|
55
|
+
# Покажет картинку с наложенными боксами
|
|
56
|
+
Image.fromarray(vis_image).show()
|
|
57
|
+
|
|
58
|
+
# Или сохранить результат на диск:
|
|
59
|
+
Image.fromarray(vis_image).save(r"example\ocr_example_image_infer.png")
|
|
60
|
+
```
|
|
61
|
+
|
|
62
|
+
### Результат
|
|
63
|
+
|
|
64
|
+
Текстовые блоки будут выведены в консоль, например:
|
|
65
|
+
|
|
66
|
+
```
|
|
67
|
+
Page(blocks=[Block(words=[Word(polygon=[(874.1005, 909.1005), (966.8995, 909.1005), (966.8995, 956.8995), (874.1005, 956.8995)]),
|
|
68
|
+
Word(polygon=[(849.1234, 810.5678), … ])])])
|
|
69
|
+
```
|
|
70
|
+
|
|
71
|
+
А визуализация сохранится в файл `example/ocr_example_image_infer.png`:
|
|
72
|
+
|
|
73
|
+

|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
## Installation
|
|
2
|
+
|
|
3
|
+
```bash
|
|
4
|
+
pip install manuscript-ocr
|
|
5
|
+
```
|
|
6
|
+
|
|
7
|
+
## Usage Example
|
|
8
|
+
|
|
9
|
+
```python
|
|
10
|
+
from PIL import Image
|
|
11
|
+
from manuscript.detectors import EASTInfer
|
|
12
|
+
|
|
13
|
+
# Инициализация
|
|
14
|
+
det = EASTInfer(score_thresh=0.9)
|
|
15
|
+
|
|
16
|
+
# Инфер с визуализацией
|
|
17
|
+
page, vis_image = det.infer(r"example\ocr_example_image.jpg", vis=True)
|
|
18
|
+
|
|
19
|
+
print(page)
|
|
20
|
+
|
|
21
|
+
# Покажет картинку с наложенными боксами
|
|
22
|
+
Image.fromarray(vis_image).show()
|
|
23
|
+
|
|
24
|
+
# Или сохранить результат на диск:
|
|
25
|
+
Image.fromarray(vis_image).save(r"example\ocr_example_image_infer.png")
|
|
26
|
+
```
|
|
27
|
+
|
|
28
|
+
### Результат
|
|
29
|
+
|
|
30
|
+
Текстовые блоки будут выведены в консоль, например:
|
|
31
|
+
|
|
32
|
+
```
|
|
33
|
+
Page(blocks=[Block(words=[Word(polygon=[(874.1005, 909.1005), (966.8995, 909.1005), (966.8995, 956.8995), (874.1005, 956.8995)]),
|
|
34
|
+
Word(polygon=[(849.1234, 810.5678), … ])])])
|
|
35
|
+
```
|
|
36
|
+
|
|
37
|
+
А визуализация сохранится в файл `example/ocr_example_image_infer.png`:
|
|
38
|
+
|
|
39
|
+

|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from setuptools import setup, find_packages
|
|
3
|
+
|
|
4
|
+
def parse_requirements(fname="requirements.txt"):
    """Read a pip requirements file and return its requirement lines.

    Blank lines and comment lines are skipped.

    :param fname: file name resolved relative to this setup.py
                  (an absolute path is used as-is by ``os.path.join``).
    :return: list of requirement strings, stripped of surrounding whitespace.
    """
    here = os.path.dirname(__file__)
    requirements = []
    # Explicit encoding: requirement files may carry non-ASCII comments.
    with open(os.path.join(here, fname), encoding="utf-8") as f:
        for raw in f:
            line = raw.strip()
            # Test the *stripped* line for the comment marker: the original
            # `raw.startswith("#")` let indented comment lines through.
            if line and not line.startswith("#"):
                requirements.append(line)
    return requirements
|
|
8
|
+
|
|
9
|
+
# Read the long description up front inside a context manager so the file
# handle is closed promptly (a bare open(...).read() leaks it until GC).
with open("README.md", encoding="utf-8") as _readme:
    _long_description = _readme.read()

setup(
    name="manuscript-ocr",
    version="0.1.0",
    description="manuscript-ocr",
    long_description=_long_description,
    long_description_content_type="text/markdown",

    author="",
    author_email="",

    url="",
    license="MIT",

    # src-layout: the importable package lives under src/
    packages=find_packages(where="src"),
    package_dir={"": "src"},

    python_requires=">=3.8",
    install_requires=parse_requirements(),

    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
        "Topic :: Software Development :: Libraries :: Python Modules",
    ],
    include_package_data=True,
)
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from PIL import Image
from manuscript.detectors import EASTInfer

# Build the detector
detector = EASTInfer(score_thresh=0.9)

# Run inference with the visualization overlay enabled
page, overlay = detector.infer(r"example\ocr_example_image.jpg", vis=True)
print(page)

# Show the detected boxes drawn over the (resized) input image
Image.fromarray(overlay).show()
|
|
Binary file
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
import cv2
|
|
2
|
+
import numpy as np
|
|
3
|
+
import torch
|
|
4
|
+
from torchvision import transforms
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Union, Optional, List, Tuple
|
|
7
|
+
|
|
8
|
+
from .east import TextDetectionFCN
|
|
9
|
+
from .utils import decode_boxes_from_maps, draw_boxes
|
|
10
|
+
from ..types import Word, Block, Page
|
|
11
|
+
import os
|
|
12
|
+
|
|
13
|
+
import gdown
|
|
14
|
+
|
|
15
|
+
class EASTInfer:
    """High-level inference wrapper for the EAST text detector.

    When no ``weights_path`` is given, the published release checkpoint is
    downloaded once and cached in the user's home directory.  :meth:`infer`
    maps an image (path or RGB ndarray) to a :class:`Page` of detected word
    polygons, optionally also returning a visualization image.
    """

    def __init__(
        self,
        weights_path: Optional[Union[str, Path]] = None,
        device: Optional[str] = None,
        target_size: int = 1024,
        score_geo_scale: float = 0.25,
        shrink_ratio: float = 0.5,
        score_thresh: float = 0.9,
        iou_threshold: float = 0.2,
    ):
        """
        :param weights_path: path to a ``.pth`` checkpoint; ``None`` triggers
            download of the release weights to ``~/.east_weights.pth``.
        :param device: torch device string; defaults to CUDA when available.
        :param target_size: square side the input image is resized to.
        :param score_geo_scale: score/geometry map size relative to the input.
        :param shrink_ratio: expansion ratio used when decoding quads.
        :param score_thresh: minimum score for a pixel to seed a detection.
        :param iou_threshold: NMS IoU threshold for decoded quads.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        if weights_path is None:
            url = (
                "https://github.com/konstantinkozhin/manuscript-ocr"
                "/releases/download/v0.1.0/east_quad_14_05.pth"
            )
            out = os.path.expanduser("~/.east_weights.pth")
            # Download once; subsequent runs reuse the cached file.
            if not os.path.exists(out):
                print(f"Downloading EAST weights from {url} …")
                gdown.download(url, out, quiet=False)
            weights_path = out
        # Load the model with the resolved checkpoint.
        # (A stray debug `print(weights_path)` was removed here.)
        self.model = TextDetectionFCN(
            pretrained_backbone=False,
            pretrained_model_path=str(weights_path),
        ).to(self.device)
        self.model.eval()

        self.target_size = target_size
        self.score_geo_scale = score_geo_scale
        self.shrink_ratio = shrink_ratio
        self.score_thresh = score_thresh
        self.iou_threshold = iou_threshold

        # Normalization matching the training pipeline (mean/std 0.5).
        self.tf = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5),
                                 std=(0.5, 0.5, 0.5)),
        ])

    def infer(
        self,
        img_or_path: Union[str, Path, np.ndarray],
        vis: bool = False
    ) -> Union[Page, Tuple[Page, np.ndarray]]:
        """Detect text regions in an image.

        :param img_or_path: image path, or an ndarray assumed to be RGB —
            TODO confirm callers never pass BGR arrays.
        :param vis: when True, also return the resized image with boxes drawn.
        :return: ``Page``, or ``(Page, vis_image)`` when ``vis`` is True.
            Polygon coordinates are in the resized (``target_size``) frame.
        :raises FileNotFoundError: if a path is given but cannot be read.
        :raises TypeError: for unsupported input types.
        """
        # 1) Read & convert to RGB
        if isinstance(img_or_path, (str, Path)):
            img = cv2.imread(str(img_or_path))
            if img is None:
                raise FileNotFoundError(f"Cannot read image: {img_or_path}")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        elif isinstance(img_or_path, np.ndarray):
            img = img_or_path
        else:
            raise TypeError(f"Unsupported type {type(img_or_path)}")

        # 2) Resize + ToTensor + Normalize
        resized = cv2.resize(img, (self.target_size, self.target_size))
        img_t = self.tf(resized).to(self.device)

        # 3) Forward pass (inference only — no gradients)
        with torch.no_grad():
            out = self.model(img_t.unsqueeze(0))

        # 4) Extract maps as numpy: score (H, W), geometry (H, W, 8)
        score_map = out["score"][0].cpu().numpy().squeeze(0)
        geo_map = out["geometry"][0].cpu().numpy().transpose(1, 2, 0)

        # 5) Decode quads; the first 8 values of each entry are corner coords
        quads9 = decode_boxes_from_maps(
            score_map=score_map,
            geo_map=geo_map,
            score_thresh=self.score_thresh,
            scale=1.0 / self.score_geo_scale,
            iou_threshold=self.iou_threshold,
            expand_ratio=self.shrink_ratio,
        )

        # 6) Build the Page (one block holding all detected words)
        words: List[Word] = []
        for quad in quads9:
            pts = quad[:8].reshape(4, 2).tolist()
            words.append(Word(polygon=pts))
        page = Page(blocks=[Block(words=words)])

        # 7) Optional visualization drawn on the resized image
        if vis:
            vis_img = draw_boxes(resized, quads9)
            return page, vis_img
        return page
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import json
|
|
3
|
+
import cv2
|
|
4
|
+
import numpy as np
|
|
5
|
+
import torch
|
|
6
|
+
from torch.utils.data import Dataset
|
|
7
|
+
import math
|
|
8
|
+
from shapely.geometry import Polygon
|
|
9
|
+
import torchvision.transforms as transforms
|
|
10
|
+
import skimage.draw
|
|
11
|
+
from .utils import quad_to_rbox
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def order_vertices_clockwise(poly):
    """Return the four vertices of *poly* ordered TL, TR, BR, BL.

    Classic corner heuristic: x + y is minimal at the top-left and maximal
    at the bottom-right, while y - x is minimal at the top-right and
    maximal at the bottom-left.
    """
    pts = np.array(poly).reshape(-1, 2)
    corner_sum = pts[:, 0] + pts[:, 1]   # x + y
    corner_diff = pts[:, 1] - pts[:, 0]  # y - x
    ordered = [
        pts[corner_sum.argmin()],   # top-left
        pts[corner_diff.argmin()],  # top-right
        pts[corner_sum.argmax()],   # bottom-right
        pts[corner_diff.argmax()],  # bottom-left
    ]
    return np.array(ordered, dtype=np.float32)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def shrink_poly(poly, shrink_ratio=0.3):
    """Shrink a quadrilateral towards its interior.

    Each vertex moves along the averaged inward normal of its two adjacent
    edges by ``shrink_ratio`` times the length of the shorter edge.

    :param poly: vertices, any shape coercible to (4, 2).
    :param shrink_ratio: fraction of the shorter adjacent edge to move by.
    :return: (4, 2) float32 array of shrunken vertices.
    :raises ValueError: if *poly* does not have exactly 4 vertices.
    """
    verts = np.array(poly, dtype=np.float32).reshape(-1, 2)
    count = verts.shape[0]
    if count != 4:
        raise ValueError("Expected quadrilateral with 4 vertices")

    # Shoelace formula: the orientation sign tells which side of each
    # edge points into the polygon.
    signed_area = 0.0
    for i in range(count):
        xa, ya = verts[i]
        xb, yb = verts[(i + 1) % count]
        signed_area += xa * yb - xb * ya
    orient = 1.0 if signed_area > 0 else -1.0

    shrunk = np.zeros_like(verts)
    for i in range(count):
        prev_pt = verts[(i - 1) % count]
        curr_pt = verts[i]
        next_pt = verts[(i + 1) % count]

        # Inward-oriented unit normals of the incoming and outgoing edges
        # (epsilon guards against zero-length, degenerate edges).
        incoming = curr_pt - prev_pt
        in_len = np.linalg.norm(incoming)
        in_normal = orient * np.array([incoming[1], -incoming[0]]) / (in_len + 1e-6)

        outgoing = next_pt - curr_pt
        out_len = np.linalg.norm(outgoing)
        out_normal = orient * np.array([outgoing[1], -outgoing[0]]) / (out_len + 1e-6)

        direction = in_normal + out_normal
        direction_len = np.linalg.norm(direction)
        if direction_len > 0:
            direction /= direction_len

        step = shrink_ratio * min(in_len, out_len)
        shrunk[i] = curr_pt - step * direction
    return shrunk.astype(np.float32)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class EASTDataset(Dataset):
    """COCO-backed dataset producing EAST training targets.

    Each item is ``(image_tensor, target)`` where ``target`` contains the
    shrunken score map, the per-pixel QUAD geometry map, and the rotated
    boxes derived from the annotation polygons.
    """

    def __init__(
        self,
        images_folder,
        coco_annotation_file,
        target_size=512,
        score_geo_scale=0.25,
        transform=None,
    ):
        """
        :param images_folder: directory containing the image files.
        :param coco_annotation_file: path to a COCO-format JSON file.
        :param target_size: square side images are resized to.
        :param score_geo_scale: output map size relative to ``target_size``.
        :param transform: optional torchvision transform; when None a
            default ColorJitter + Normalize pipeline is used.
        """
        self.images_folder = images_folder
        self.target_size = target_size
        self.score_geo_scale = score_geo_scale
        # transform pipeline (default adds color augmentation)
        if transform is None:
            self.transform = transforms.Compose(
                [
                    transforms.ToPILImage(),
                    transforms.ColorJitter(0.5, 0.5, 0.5, 0.25),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
                ]
            )
        else:
            self.transform = transform
        # load COCO annotations and index them by image id
        with open(coco_annotation_file, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.images_info = {img["id"]: img for img in data["images"]}
        self.image_ids = list(self.images_info.keys())
        self.annots = {}
        for ann in data["annotations"]:
            self.annots.setdefault(ann["image_id"], []).append(ann)

    def __len__(self):
        """Number of images listed in the annotation file."""
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Load, resize and encode one training sample.

        :raises FileNotFoundError: if the image file cannot be read.
        """
        image_id = self.image_ids[idx]
        info = self.images_info[image_id]
        path = os.path.join(self.images_folder, info["file_name"])
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f"Image not found: {path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # resize to the fixed square input size
        img_resized = cv2.resize(img, (self.target_size, self.target_size))
        # scale annotations to the resized image
        anns = self.annots.get(image_id, [])
        quads = []
        for ann in anns:
            if "segmentation" not in ann:
                continue
            seg = ann["segmentation"]
            pts = np.array(seg, dtype=np.float32).reshape(-1, 2)
            # replace the arbitrary polygon with its minimum-area rectangle
            rect = cv2.minAreaRect(pts)
            box = cv2.boxPoints(rect)
            quad = order_vertices_clockwise(box)
            # rescale the quad into resized-image coordinates
            quad[:, 0] *= self.target_size / info["width"]
            quad[:, 1] *= self.target_size / info["height"]
            quads.append(quad)
        # generate score/geometry maps from the quads
        score_map, geo_map = self.compute_quad_maps(quads)
        # NOTE(review): np.stack raises on an empty list — an image with no
        # usable segmentation annotations would crash here; confirm that the
        # dataset never contains such images.
        rboxes = np.stack([quad_to_rbox(q.flatten()) for q in quads], axis=0).astype(
            np.float32
        )
        # transform image (augmentation + normalization)
        img_tensor = self.transform(img_resized)
        target = {
            "score_map": torch.tensor(score_map).unsqueeze(0),
            "geo_map": torch.tensor(geo_map),
            "rboxes": torch.tensor(rboxes),
        }
        return img_tensor, target

    def compute_quad_maps(self, quads):
        """
        quads: list of (4,2) arrays
        returns score_map (H',W') and geo_map (8,H',W')
        """
        out_h = int(self.target_size * self.score_geo_scale)
        out_w = int(self.target_size * self.score_geo_scale)
        score_map = np.zeros((out_h, out_w), dtype=np.float32)
        geo_map = np.zeros((8, out_h, out_w), dtype=np.float32)
        for quad in quads:
            # shrink & order vertices before rasterizing
            shrunk = shrink_poly(order_vertices_clockwise(quad), shrink_ratio=0.3)
            # scale to map resolution
            coords = shrunk * self.score_geo_scale
            rr, cc = skimage.draw.polygon(
                coords[:, 1], coords[:, 0], shape=(out_h, out_w)
            )
            score_map[rr, cc] = 1
            # per-pixel offsets from each covered pixel to the 4 quad corners
            for i, (vx, vy) in enumerate(coords):
                geo_map[2 * i, rr, cc] = vx - cc
                geo_map[2 * i + 1, rr, cc] = vy - rr
        return score_map, geo_map
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
from torchvision.models import resnet50, ResNet50_Weights
|
|
5
|
+
from torchvision.models.feature_extraction import create_feature_extractor
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class DecoderBlock(nn.Module):
    """1x1 bottleneck followed by a 3x3 conv, each with BatchNorm + ReLU.

    Attribute names (``conv1x1``, ``conv3x3``) are part of the state_dict
    layout and must not change.
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()
        # 1x1 conv reduces the channel count before the spatial 3x3 conv.
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
        )
        self.conv3x3 = nn.Sequential(
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply the 1x1 stage, then the 3x3 stage."""
        return self.conv3x3(self.conv1x1(x))
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ResNetFeatureExtractor(nn.Module):
    """Expose the four residual stages of a ResNet-50 as named features."""

    # torchvision layer names -> feature keys consumed by the decoder;
    # strides 4/8/16/32 with 256/512/1024/2048 channels respectively.
    _RETURN_NODES = {
        "layer1": "res1",
        "layer2": "res2",
        "layer3": "res3",
        "layer4": "res4",
    }

    def __init__(self, pretrained=True, freeze_first=False):
        """
        :param pretrained: load the default ImageNet weights.
        :param freeze_first: freeze the stem (conv1/bn1) and layer1.
        """
        super().__init__()
        backbone_weights = ResNet50_Weights.DEFAULT if pretrained else None
        self.model = resnet50(weights=backbone_weights)
        if freeze_first:
            frozen_prefixes = ("conv1", "bn1", "layer1")
            for name, param in self.model.named_parameters():
                if name.startswith(frozen_prefixes):
                    param.requires_grad = False
        self.extractor = create_feature_extractor(
            self.model,
            return_nodes=dict(self._RETURN_NODES),
        )

    def forward(self, x):
        """Return {'res1'..'res4': feature map} for the input batch."""
        return self.extractor(x)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class FeatureMergingBranchResNet(nn.Module):
    """U-shaped decoder fusing the ResNet stages into a 32-channel map.

    Attribute names (``block1``..``block4``) are part of the state_dict
    layout and must not change.
    """

    def __init__(self):
        super().__init__()
        # Channel plan: 2048 -> 512, (512+1024) -> 256, (256+512) -> 128,
        # (128+256) -> 32, with x2 bilinear upsampling between stages.
        self.block1 = DecoderBlock(in_channels=2048, mid_channels=512, out_channels=512)
        self.block2 = DecoderBlock(
            in_channels=512 + 1024, mid_channels=256, out_channels=256
        )
        self.block3 = DecoderBlock(
            in_channels=256 + 512, mid_channels=128, out_channels=128
        )
        self.block4 = DecoderBlock(
            in_channels=128 + 256, mid_channels=64, out_channels=32
        )

    def forward(self, feats):
        """Merge features top-down; returns a 32-channel map at res1 scale."""

        def upx2(t):
            # x2 bilinear upsampling to match the next-shallower stage.
            return F.interpolate(
                t, scale_factor=2, mode="bilinear", align_corners=False
            )

        merged = self.block1(feats["res4"])
        merged = self.block2(torch.cat([upx2(merged), feats["res3"]], dim=1))
        merged = self.block3(torch.cat([upx2(merged), feats["res2"]], dim=1))
        merged = self.block4(torch.cat([upx2(merged), feats["res1"]], dim=1))
        return merged
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class OutputHead(nn.Module):
    """Final 1x1 convolutions producing the score and geometry maps."""

    def __init__(self):
        super().__init__()
        # Per-pixel text confidence (1 channel) and QUAD offsets (8 channels).
        self.score_map = nn.Conv2d(32, 1, kernel_size=1)
        self.geo_map = nn.Conv2d(32, 8, kernel_size=1)

    def forward(self, x):
        """Return (score squashed to [0, 1], raw geometry offsets)."""
        return torch.sigmoid(self.score_map(x)), self.geo_map(x)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class TextDetectionFCN(nn.Module):
    """EAST-style detector: ResNet-50 backbone + merging decoder + heads."""

    def __init__(
        self, pretrained_backbone=True, freeze_first=False, pretrained_model_path=None
    ):
        """
        :param pretrained_backbone: load ImageNet weights for the backbone.
        :param freeze_first: freeze the backbone stem and first stage.
        :param pretrained_model_path: optional whole-model checkpoint path.
        """
        super().__init__()
        self.backbone = ResNetFeatureExtractor(
            pretrained=pretrained_backbone, freeze_first=freeze_first
        )
        self.decoder = FeatureMergingBranchResNet()
        self.output_head = OutputHead()
        # Both output maps are produced at 1/4 of the input resolution.
        self.score_scale = 0.25
        self.geo_scale = 0.25
        if pretrained_model_path:
            # NOTE(review): torch.load unpickles arbitrary objects — only
            # load checkpoints from trusted sources. strict=False silently
            # tolerates missing/unexpected keys.
            state = torch.load(pretrained_model_path, map_location="cpu")
            self.load_state_dict(state, strict=False)
            print(f"Loaded pretrained model from {pretrained_model_path}")

    def forward(self, x):
        """Return {'score': (N,1,H/4,W/4), 'geometry': (N,8,H/4,W/4)}."""
        feats = self.backbone(x)
        score, geometry = self.output_head(self.decoder(feats))
        return {"score": score, "geometry": geometry}
|