manuscript-ocr 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32) hide show
  1. manuscript_ocr-0.1.0/MANIFEST.in +2 -0
  2. manuscript_ocr-0.1.0/PKG-INFO +73 -0
  3. manuscript_ocr-0.1.0/README.md +39 -0
  4. manuscript_ocr-0.1.0/pyproject.toml +3 -0
  5. manuscript_ocr-0.1.0/requirements.txt +13 -0
  6. manuscript_ocr-0.1.0/setup.cfg +4 -0
  7. manuscript_ocr-0.1.0/setup.py +35 -0
  8. manuscript_ocr-0.1.0/src/example.py +13 -0
  9. manuscript_ocr-0.1.0/src/manuscript/detectors/__init__.py +3 -0
  10. manuscript_ocr-0.1.0/src/manuscript/detectors/__pycache__/__init__.cpython-311.pyc +0 -0
  11. manuscript_ocr-0.1.0/src/manuscript/detectors/__pycache__/types.cpython-311.pyc +0 -0
  12. manuscript_ocr-0.1.0/src/manuscript/detectors/east/__init__.py +112 -0
  13. manuscript_ocr-0.1.0/src/manuscript/detectors/east/__pycache__/__init__.cpython-311.pyc +0 -0
  14. manuscript_ocr-0.1.0/src/manuscript/detectors/east/__pycache__/dataset.cpython-311.pyc +0 -0
  15. manuscript_ocr-0.1.0/src/manuscript/detectors/east/__pycache__/east.cpython-311.pyc +0 -0
  16. manuscript_ocr-0.1.0/src/manuscript/detectors/east/__pycache__/lanms.cpython-311.pyc +0 -0
  17. manuscript_ocr-0.1.0/src/manuscript/detectors/east/__pycache__/loss.cpython-311.pyc +0 -0
  18. manuscript_ocr-0.1.0/src/manuscript/detectors/east/__pycache__/sam.cpython-311.pyc +0 -0
  19. manuscript_ocr-0.1.0/src/manuscript/detectors/east/__pycache__/utils.cpython-311.pyc +0 -0
  20. manuscript_ocr-0.1.0/src/manuscript/detectors/east/dataset.py +154 -0
  21. manuscript_ocr-0.1.0/src/manuscript/detectors/east/east.py +115 -0
  22. manuscript_ocr-0.1.0/src/manuscript/detectors/east/lanms.py +214 -0
  23. manuscript_ocr-0.1.0/src/manuscript/detectors/east/loss.py +63 -0
  24. manuscript_ocr-0.1.0/src/manuscript/detectors/east/sam.py +76 -0
  25. manuscript_ocr-0.1.0/src/manuscript/detectors/east/train.py +305 -0
  26. manuscript_ocr-0.1.0/src/manuscript/detectors/east/utils.py +235 -0
  27. manuscript_ocr-0.1.0/src/manuscript/detectors/types.py +27 -0
  28. manuscript_ocr-0.1.0/src/manuscript_ocr.egg-info/PKG-INFO +73 -0
  29. manuscript_ocr-0.1.0/src/manuscript_ocr.egg-info/SOURCES.txt +30 -0
  30. manuscript_ocr-0.1.0/src/manuscript_ocr.egg-info/dependency_links.txt +1 -0
  31. manuscript_ocr-0.1.0/src/manuscript_ocr.egg-info/requires.txt +12 -0
  32. manuscript_ocr-0.1.0/src/manuscript_ocr.egg-info/top_level.txt +1 -0
@@ -0,0 +1,2 @@
1
+ include README.md LICENSE requirements.txt
2
+ recursive-include src *
@@ -0,0 +1,73 @@
1
+ Metadata-Version: 2.4
2
+ Name: manuscript-ocr
3
+ Version: 0.1.0
4
+ Summary: manuscript-ocr
5
+ Home-page:
6
+ Author:
7
+ Author-email:
8
+ License: MIT
9
+ Classifier: Programming Language :: Python :: 3
10
+ Classifier: License :: OSI Approved :: MIT License
11
+ Classifier: Operating System :: OS Independent
12
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
13
+ Requires-Python: >=3.8
14
+ Description-Content-Type: text/markdown
15
+ Requires-Dist: torch==2.0.1+cu118
16
+ Requires-Dist: torchvision==0.15.2+cu118
17
+ Requires-Dist: torchaudio==2.0.2
18
+ Requires-Dist: numpy<2
19
+ Requires-Dist: opencv-python
20
+ Requires-Dist: Pillow
21
+ Requires-Dist: shapely
22
+ Requires-Dist: numba
23
+ Requires-Dist: tensorboard
24
+ Requires-Dist: gdown
25
+ Requires-Dist: pydantic
26
+ Requires-Dist: scikit-image
27
+ Dynamic: classifier
28
+ Dynamic: description
29
+ Dynamic: description-content-type
30
+ Dynamic: license
31
+ Dynamic: requires-dist
32
+ Dynamic: requires-python
33
+ Dynamic: summary
34
+
35
+ ## Installation
36
+
37
+ ```bash
38
+ pip install manuscript-ocr
39
+ ```
40
+
41
+ ## Usage Example
42
+
43
+ ```python
44
+ from PIL import Image
45
+ from manuscript.detectors import EASTInfer
46
+
47
+ # Инициализация
48
+ det = EASTInfer(score_thresh=0.9)
49
+
50
+ # Инфер с визуализацией
51
+ page, vis_image = det.infer(r"example\ocr_example_image.jpg", vis=True)
52
+
53
+ print(page)
54
+
55
+ # Покажет картинку с наложенными боксами
56
+ Image.fromarray(vis_image).show()
57
+
58
+ # Или сохранить результат на диск:
59
+ Image.fromarray(vis_image).save(r"example\ocr_example_image_infer.png")
60
+ ```
61
+
62
+ ### Результат
63
+
64
+ Текстовые блоки будут выведены в консоль, например:
65
+
66
+ ```
67
+ Page(blocks=[Block(words=[Word(polygon=[(874.1005, 909.1005), (966.8995, 909.1005), (966.8995, 956.8995), (874.1005, 956.8995)]),
68
+ Word(polygon=[(849.1234, 810.5678), … ])])])
69
+ ```
70
+
71
+ А визуализация сохранится в файл `example/ocr_example_image_infer.png`:
72
+
73
+ ![OCR Inference Result](example/ocr_example_image_infer.png)
@@ -0,0 +1,39 @@
1
+ ## Installation
2
+
3
+ ```bash
4
+ pip install manuscript-ocr
5
+ ```
6
+
7
+ ## Usage Example
8
+
9
+ ```python
10
+ from PIL import Image
11
+ from manuscript.detectors import EASTInfer
12
+
13
+ # Инициализация
14
+ det = EASTInfer(score_thresh=0.9)
15
+
16
+ # Инфер с визуализацией
17
+ page, vis_image = det.infer(r"example\ocr_example_image.jpg", vis=True)
18
+
19
+ print(page)
20
+
21
+ # Покажет картинку с наложенными боксами
22
+ Image.fromarray(vis_image).show()
23
+
24
+ # Или сохранить результат на диск:
25
+ Image.fromarray(vis_image).save(r"example\ocr_example_image_infer.png")
26
+ ```
27
+
28
+ ### Результат
29
+
30
+ Текстовые блоки будут выведены в консоль, например:
31
+
32
+ ```
33
+ Page(blocks=[Block(words=[Word(polygon=[(874.1005, 909.1005), (966.8995, 909.1005), (966.8995, 956.8995), (874.1005, 956.8995)]),
34
+ Word(polygon=[(849.1234, 810.5678), … ])])])
35
+ ```
36
+
37
+ А визуализация сохранится в файл `example/ocr_example_image_infer.png`:
38
+
39
+ ![OCR Inference Result](example/ocr_example_image_infer.png)
@@ -0,0 +1,3 @@
1
+ [build-system]
2
+ requires = ["setuptools>=42", "wheel"]
3
+ build-backend = "setuptools.build_meta"
@@ -0,0 +1,13 @@
1
+ torch==2.0.1+cu118
2
+ torchvision==0.15.2+cu118
3
+ torchaudio==2.0.2
4
+
5
+ numpy<2
6
+ opencv-python
7
+ Pillow
8
+ shapely
9
+ numba
10
+ tensorboard
11
+ gdown
12
+ pydantic
13
+ scikit-image
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,35 @@
1
import os
from setuptools import setup, find_packages


def parse_requirements(fname="requirements.txt"):
    """Return non-empty, non-comment requirement lines from *fname*.

    The file is resolved relative to this setup.py so the build works
    regardless of the current working directory.
    """
    here = os.path.dirname(__file__)
    with open(os.path.join(here, fname), encoding="utf-8") as f:
        # strip() first so indented comment lines are also skipped
        return [
            ln.strip()
            for ln in f
            if ln.strip() and not ln.strip().startswith("#")
        ]


def _read_long_description():
    """Read README.md for the PyPI long description, closing the file promptly."""
    with open("README.md", encoding="utf-8") as f:
        return f.read()


setup(
    name="manuscript-ocr",
    version="0.1.0",
    description="manuscript-ocr",
    long_description=_read_long_description(),
    long_description_content_type="text/markdown",

    author="",
    author_email="",

    url="",
    license="MIT",

    packages=find_packages(where="src"),
    package_dir={"": "src"},

    python_requires=">=3.8",
    # NOTE(review): requirements.txt pins local CUDA wheels
    # (e.g. torch==2.0.1+cu118) that are not installable from PyPI —
    # verify before publishing.
    install_requires=parse_requirements(),

    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
        "Topic :: Software Development :: Libraries :: Python Modules",
    ],
    include_package_data=True,
)
@@ -0,0 +1,13 @@
1
"""Minimal EAST detector demo: run inference on a sample image and show it."""
from PIL import Image

from manuscript.detectors import EASTInfer

# Build the detector (default weights are downloaded on first use).
detector = EASTInfer(score_thresh=0.9)

# Run inference with visualization enabled.
page, vis = detector.infer(r"example\ocr_example_image.jpg", vis=True)
print(page)

# Display the image with the detected boxes overlaid.
Image.fromarray(vis).show()
@@ -0,0 +1,3 @@
1
"""Public API of the detectors subpackage."""

from .east import EASTInfer

__all__ = ["EASTInfer"]
@@ -0,0 +1,112 @@
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ from torchvision import transforms
5
+ from pathlib import Path
6
+ from typing import Union, Optional, List, Tuple
7
+
8
+ from .east import TextDetectionFCN
9
+ from .utils import decode_boxes_from_maps, draw_boxes
10
+ from ..types import Word, Block, Page
11
+ import os
12
+
13
+ import gdown
14
+
15
class EASTInfer:
    """High-level inference wrapper around the EAST text detector.

    Loads a ``TextDetectionFCN`` model (downloading default weights when
    none are supplied), runs a forward pass on an image, decodes the score
    and geometry maps into quadrilaterals, and packages them as a ``Page``.
    """

    # Default release asset with pretrained weights.
    _WEIGHTS_URL = (
        "https://github.com/konstantinkozhin/manuscript-ocr"
        "/releases/download/v0.1.0/east_quad_14_05.pth"
    )

    def __init__(
        self,
        weights_path: Optional[Union[str, Path]] = None,
        device: Optional[str] = None,
        target_size: int = 1024,
        score_geo_scale: float = 0.25,
        shrink_ratio: float = 0.5,
        score_thresh: float = 0.9,
        iou_threshold: float = 0.2,
    ):
        """
        :param weights_path: path to a .pth checkpoint; downloaded if None
        :param device: "cuda"/"cpu"; auto-detected if None
        :param target_size: square side the input image is resized to
        :param score_geo_scale: map resolution relative to the input size
        :param shrink_ratio: expansion ratio used when decoding boxes
        :param score_thresh: minimum score for a pixel to seed a box
        :param iou_threshold: NMS IoU threshold
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        if weights_path is None:
            weights_path = self._download_default_weights()

        # Load the model with the given weights.
        self.model = TextDetectionFCN(
            pretrained_backbone=False,
            pretrained_model_path=str(weights_path),
        ).to(self.device)
        self.model.eval()

        self.target_size = target_size
        self.score_geo_scale = score_geo_scale
        self.shrink_ratio = shrink_ratio
        self.score_thresh = score_thresh
        self.iou_threshold = iou_threshold

        self.tf = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5),
                                 std=(0.5, 0.5, 0.5)),
        ])

    def _download_default_weights(self) -> str:
        """Download the default checkpoint to ``~/.east_weights.pth`` (cached)."""
        out = os.path.expanduser("~/.east_weights.pth")
        if not os.path.exists(out):
            print(f"Downloading EAST weights from {self._WEIGHTS_URL} …")
            gdown.download(self._WEIGHTS_URL, out, quiet=False)
        return out

    def infer(
        self,
        img_or_path: Union[str, Path, np.ndarray],
        vis: bool = False
    ) -> Union[Page, Tuple[Page, np.ndarray]]:
        """Run text detection on an image.

        :param img_or_path: file path or RGB ndarray
        :param vis: if True, also return the resized image with boxes drawn
        :return: Page, or (Page, vis_image) when vis=True
        :raises FileNotFoundError: if a path is given and cannot be read
        :raises TypeError: if the input is neither a path nor an ndarray

        NOTE(review): polygon coordinates in the returned Page are relative
        to the resized (target_size x target_size) image, not the original —
        confirm callers expect this.
        """
        # 1) Read & convert to RGB.
        if isinstance(img_or_path, (str, Path)):
            img = cv2.imread(str(img_or_path))
            if img is None:
                raise FileNotFoundError(f"Cannot read image: {img_or_path}")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        elif isinstance(img_or_path, np.ndarray):
            img = img_or_path
        else:
            raise TypeError(f"Unsupported type {type(img_or_path)}")

        # 2) Resize (aspect ratio is not preserved) + ToTensor + Normalize.
        resized = cv2.resize(img, (self.target_size, self.target_size))
        img_t = self.tf(resized).to(self.device)

        # 3) Forward pass.
        with torch.no_grad():
            out = self.model(img_t.unsqueeze(0))

        # 4) Extract score/geometry maps as numpy arrays.
        score_map = out["score"][0].cpu().numpy().squeeze(0)
        geo_map = out["geometry"][0].cpu().numpy().transpose(1, 2, 0)

        # 5) Decode quadrilaterals (each entry: 8 coords + score).
        quads9 = decode_boxes_from_maps(
            score_map=score_map,
            geo_map=geo_map,
            score_thresh=self.score_thresh,
            scale=1.0 / self.score_geo_scale,
            iou_threshold=self.iou_threshold,
            expand_ratio=self.shrink_ratio,
        )

        # 6) Build the Page structure (single block containing all words).
        words: List[Word] = []
        for quad in quads9:
            pts = quad[:8].reshape(4, 2).tolist()
            words.append(Word(polygon=pts))
        page = Page(blocks=[Block(words=words)])

        # 7) Optional visualization on the resized image.
        if vis:
            vis_img = draw_boxes(resized, quads9)
            return page, vis_img
        return page
@@ -0,0 +1,154 @@
1
+ import os
2
+ import json
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ from torch.utils.data import Dataset
7
+ import math
8
+ from shapely.geometry import Polygon
9
+ import torchvision.transforms as transforms
10
+ import skimage.draw
11
+ from .utils import quad_to_rbox
12
+
13
+
14
def order_vertices_clockwise(poly):
    """Order the 4 vertices of a quad as [top-left, top-right, bottom-right, bottom-left].

    Uses the classic sum/diff heuristic: TL minimizes x+y, BR maximizes it,
    TR minimizes y-x, BL maximizes y-x.
    """
    pts = np.array(poly).reshape(-1, 2)
    coord_sum = pts.sum(axis=1)
    coord_diff = np.diff(pts, axis=1).ravel()
    ordered = [
        pts[np.argmin(coord_sum)],   # top-left
        pts[np.argmin(coord_diff)],  # top-right
        pts[np.argmax(coord_sum)],   # bottom-right
        pts[np.argmax(coord_diff)],  # bottom-left
    ]
    return np.array(ordered, dtype=np.float32)
23
+
24
+
25
def shrink_poly(poly, shrink_ratio=0.3):
    """Shrink a quadrilateral towards its interior.

    Each vertex moves inward along the averaged normals of its two adjacent
    edges by ``shrink_ratio`` times the shorter adjacent edge length.

    :param poly: 4 vertices, any shape coercible to (4, 2)
    :param shrink_ratio: fraction of the shorter adjacent edge to move by
    :return: (4, 2) float32 array of shrunk vertices
    :raises ValueError: if the polygon does not have exactly 4 vertices
    """
    quad = np.array(poly, dtype=np.float32).reshape(-1, 2)
    n = quad.shape[0]
    if n != 4:
        raise ValueError("Expected quadrilateral with 4 vertices")

    # Shoelace formula: the sign of the area picks the inward normal side.
    signed_area = 0.0
    for i in range(n):
        xa, ya = quad[i]
        xb, yb = quad[(i + 1) % n]
        signed_area += xa * yb - xb * ya
    orient = 1.0 if signed_area > 0 else -1.0

    shrunk = np.zeros_like(quad)
    for i in range(n):
        prev_pt = quad[(i - 1) % n]
        curr_pt = quad[i]
        next_pt = quad[(i + 1) % n]

        incoming = curr_pt - prev_pt
        len_in = np.linalg.norm(incoming)
        normal_in = orient * np.array([incoming[1], -incoming[0]]) / (len_in + 1e-6)

        outgoing = next_pt - curr_pt
        len_out = np.linalg.norm(outgoing)
        normal_out = orient * np.array([outgoing[1], -outgoing[0]]) / (len_out + 1e-6)

        # Average the two edge normals to get the inward direction.
        direction = normal_in + normal_out
        magnitude = np.linalg.norm(direction)
        if magnitude > 0:
            direction /= magnitude

        step = shrink_ratio * min(len_in, len_out)
        shrunk[i] = curr_pt - step * direction
    return shrunk.astype(np.float32)
55
+
56
+
57
class EASTDataset(Dataset):
    """COCO-style dataset producing EAST training targets (score/geometry maps)."""

    def __init__(
        self,
        images_folder,
        coco_annotation_file,
        target_size=512,
        score_geo_scale=0.25,
        transform=None,
    ):
        self.images_folder = images_folder
        self.target_size = target_size
        self.score_geo_scale = score_geo_scale
        # Default augmentation + normalization pipeline.
        if transform is None:
            self.transform = transforms.Compose(
                [
                    transforms.ToPILImage(),
                    transforms.ColorJitter(0.5, 0.5, 0.5, 0.25),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
                ]
            )
        else:
            self.transform = transform
        # Load COCO annotations and index them by image id.
        with open(coco_annotation_file, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.images_info = {img["id"]: img for img in data["images"]}
        self.image_ids = list(self.images_info.keys())
        self.annots = {}
        for ann in data["annotations"]:
            self.annots.setdefault(ann["image_id"], []).append(ann)

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        info = self.images_info[image_id]
        path = os.path.join(self.images_folder, info["file_name"])
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f"Image not found: {path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Resize to the square training resolution.
        img_resized = cv2.resize(img, (self.target_size, self.target_size))
        # Convert each segmentation to its minimum-area rectangle, order it
        # clockwise, and scale it into resized-image coordinates.
        anns = self.annots.get(image_id, [])
        quads = []
        for ann in anns:
            if "segmentation" not in ann:
                continue
            pts = np.array(ann["segmentation"], dtype=np.float32).reshape(-1, 2)
            box = cv2.boxPoints(cv2.minAreaRect(pts))
            quad = order_vertices_clockwise(box)
            quad[:, 0] *= self.target_size / info["width"]
            quad[:, 1] *= self.target_size / info["height"]
            quads.append(quad)
        # Generate the training maps.
        score_map, geo_map = self.compute_quad_maps(quads)
        # NOTE(review): np.stack raises on an image with zero annotations —
        # confirm every image in the dataset has at least one quad.
        rboxes = np.stack([quad_to_rbox(q.flatten()) for q in quads], axis=0).astype(
            np.float32
        )
        # Transform the image (augmentation + normalization).
        img_tensor = self.transform(img_resized)
        target = {
            "score_map": torch.tensor(score_map).unsqueeze(0),
            "geo_map": torch.tensor(geo_map),
            "rboxes": torch.tensor(rboxes),
        }
        return img_tensor, target

    def compute_quad_maps(self, quads):
        """Rasterize shrunk quads into a score map and an 8-channel geometry map.

        :param quads: list of (4, 2) arrays in resized-image coordinates
        :return: score_map (H', W') and geo_map (8, H', W') where
                 H' = W' = target_size * score_geo_scale
        """
        out_h = int(self.target_size * self.score_geo_scale)
        out_w = int(self.target_size * self.score_geo_scale)
        score_map = np.zeros((out_h, out_w), dtype=np.float32)
        geo_map = np.zeros((8, out_h, out_w), dtype=np.float32)
        for quad in quads:
            # Shrink the ordered quad, then scale into map resolution.
            shrunk = shrink_poly(order_vertices_clockwise(quad), shrink_ratio=0.3)
            coords = shrunk * self.score_geo_scale
            rr, cc = skimage.draw.polygon(
                coords[:, 1], coords[:, 0], shape=(out_h, out_w)
            )
            score_map[rr, cc] = 1
            # Each channel pair stores the per-pixel offset (dx, dy) to one
            # quad vertex, in map-resolution units.
            for i, (vx, vy) in enumerate(coords):
                geo_map[2 * i, rr, cc] = vx - cc
                geo_map[2 * i + 1, rr, cc] = vy - rr
        return score_map, geo_map
@@ -0,0 +1,115 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchvision.models import resnet50, ResNet50_Weights
5
+ from torchvision.models.feature_extraction import create_feature_extractor
6
+
7
+
8
class DecoderBlock(nn.Module):
    """1x1 channel-reduction conv followed by a 3x3 conv, each with BN + ReLU."""

    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
        )
        self.conv3x3 = nn.Sequential(
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # Reduce channels first, then mix spatially.
        return self.conv3x3(self.conv1x1(x))
26
+
27
+
28
class ResNetFeatureExtractor(nn.Module):
    """ResNet-50 backbone exposing the four residual stage outputs.

    Returned dict keys: res1 (stride 4, 256 ch), res2 (stride 8, 512 ch),
    res3 (stride 16, 1024 ch), res4 (stride 32, 2048 ch).
    """

    def __init__(self, pretrained=True, freeze_first=False):
        super().__init__()
        weights = ResNet50_Weights.DEFAULT if pretrained else None
        self.model = resnet50(weights=weights)
        if freeze_first:
            # Freeze the stem and the first residual stage.
            for name, param in self.model.named_parameters():
                if name.startswith(("conv1", "bn1", "layer1")):
                    param.requires_grad = False
        self.extractor = create_feature_extractor(
            self.model,
            return_nodes={
                "layer1": "res1",
                "layer2": "res2",
                "layer3": "res3",
                "layer4": "res4",
            },
        )

    def forward(self, x):
        return self.extractor(x)
48
+
49
+
50
class FeatureMergingBranchResNet(nn.Module):
    """U-shaped decoder merging the four ResNet stages into a 32-channel map.

    The output spatial size equals that of ``res1`` (1/4 of the input).
    """

    def __init__(self):
        super().__init__()
        self.block1 = DecoderBlock(in_channels=2048, mid_channels=512, out_channels=512)
        self.block2 = DecoderBlock(
            in_channels=512 + 1024, mid_channels=256, out_channels=256
        )
        self.block3 = DecoderBlock(
            in_channels=256 + 512, mid_channels=128, out_channels=128
        )
        self.block4 = DecoderBlock(
            in_channels=128 + 256, mid_channels=64, out_channels=32
        )

    def forward(self, feats):
        # Decode the deepest features and progressively fuse shallower ones,
        # doubling the spatial resolution before each concatenation.
        top = self.block1(feats["res4"])
        top = F.interpolate(top, scale_factor=2, mode="bilinear", align_corners=False)
        mid = self.block2(torch.cat([top, feats["res3"]], dim=1))
        mid = F.interpolate(mid, scale_factor=2, mode="bilinear", align_corners=False)
        low = self.block3(torch.cat([mid, feats["res2"]], dim=1))
        low = F.interpolate(low, scale_factor=2, mode="bilinear", align_corners=False)
        return self.block4(torch.cat([low, feats["res1"]], dim=1))
77
+
78
+
79
class OutputHead(nn.Module):
    """Predicts a sigmoid score map (1 ch) and raw geometry offsets (8 ch)."""

    def __init__(self):
        super().__init__()
        self.score_map = nn.Conv2d(32, 1, kernel_size=1)
        self.geo_map = nn.Conv2d(32, 8, kernel_size=1)

    def forward(self, x):
        # Score is squashed into [0, 1]; geometry stays unbounded.
        return torch.sigmoid(self.score_map(x)), self.geo_map(x)
89
+
90
+
91
class TextDetectionFCN(nn.Module):
    """Full EAST-style detector: ResNet-50 backbone + merging decoder + output heads."""

    def __init__(
        self, pretrained_backbone=True, freeze_first=False, pretrained_model_path=None
    ):
        super().__init__()
        # ResNet-50 feature pyramid.
        self.backbone = ResNetFeatureExtractor(
            pretrained=pretrained_backbone, freeze_first=freeze_first
        )
        self.decoder = FeatureMergingBranchResNet()
        self.output_head = OutputHead()
        # Both output maps are at 1/4 of the input resolution.
        self.score_scale = 0.25
        self.geo_scale = 0.25
        # Optionally load full-model weights.
        if pretrained_model_path:
            state = torch.load(pretrained_model_path, map_location="cpu")
            # NOTE(review): strict=False silently ignores missing/unexpected
            # keys — confirm the checkpoint matches this architecture.
            self.load_state_dict(state, strict=False)
            print(f"Loaded pretrained model from {pretrained_model_path}")

    def forward(self, x):
        feats = self.backbone(x)
        merged = self.decoder(feats)
        score, geometry = self.output_head(merged)
        return {"score": score, "geometry": geometry}