pensiev 0.25.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- memos/__init__.py +6 -0
- memos/cmds/__init__.py +0 -0
- memos/cmds/library.py +1289 -0
- memos/cmds/plugin.py +96 -0
- memos/commands.py +865 -0
- memos/config.py +225 -0
- memos/crud.py +605 -0
- memos/databases/__init__.py +0 -0
- memos/databases/initializers.py +481 -0
- memos/dataset_extractor_for_florence.py +165 -0
- memos/dataset_extractor_for_internvl2.py +192 -0
- memos/default_config.yaml +88 -0
- memos/embedding.py +129 -0
- memos/frame_extractor.py +53 -0
- memos/logging_config.py +35 -0
- memos/main.py +104 -0
- memos/migrations/alembic/README +1 -0
- memos/migrations/alembic/__pycache__/env.cpython-310.pyc +0 -0
- memos/migrations/alembic/env.py +108 -0
- memos/migrations/alembic/script.py.mako +30 -0
- memos/migrations/alembic/versions/00904ac8c6fc_add_indexes_to_entitymodel.py +63 -0
- memos/migrations/alembic/versions/04acdaf75664_add_indices_to_entitytags_and_metadata.py +86 -0
- memos/migrations/alembic/versions/12504c5b1d3c_add_extra_columns_for_embedding.py +67 -0
- memos/migrations/alembic/versions/31a1ad0e10b3_add_entity_plugin_status.py +71 -0
- memos/migrations/alembic/versions/__pycache__/00904ac8c6fc_add_indexes_to_entitymodel.cpython-310.pyc +0 -0
- memos/migrations/alembic/versions/__pycache__/04acdaf75664_add_indices_to_entitytags_and_metadata.cpython-310.pyc +0 -0
- memos/migrations/alembic/versions/__pycache__/12504c5b1d3c_add_extra_columns_for_embedding.cpython-310.pyc +0 -0
- memos/migrations/alembic/versions/__pycache__/20f5ecab014d_add_entity_plugin_status.cpython-310.pyc +0 -0
- memos/migrations/alembic/versions/__pycache__/31a1ad0e10b3_add_entity_plugin_status.cpython-310.pyc +0 -0
- memos/migrations/alembic/versions/__pycache__/4fcb062c5128_add_extra_columns_for_embedding.cpython-310.pyc +0 -0
- memos/migrations/alembic/versions/__pycache__/d10c55fbb7d2_add_index_for_entity_file_type_group_.cpython-310.pyc +0 -0
- memos/migrations/alembic/versions/__pycache__/f8f158182416_add_active_app_index.cpython-310.pyc +0 -0
- memos/migrations/alembic/versions/d10c55fbb7d2_add_index_for_entity_file_type_group_.py +44 -0
- memos/migrations/alembic/versions/f8f158182416_add_active_app_index.py +75 -0
- memos/migrations/alembic.ini +116 -0
- memos/migrations.py +19 -0
- memos/models.py +199 -0
- memos/plugins/__init__.py +0 -0
- memos/plugins/ocr/__init__.py +0 -0
- memos/plugins/ocr/main.py +251 -0
- memos/plugins/ocr/models/ch_PP-OCRv4_det_infer.onnx +0 -0
- memos/plugins/ocr/models/ch_PP-OCRv4_rec_infer.onnx +0 -0
- memos/plugins/ocr/models/ch_ppocr_mobile_v2.0_cls_train.onnx +0 -0
- memos/plugins/ocr/ppocr-gpu.yaml +43 -0
- memos/plugins/ocr/ppocr.yaml +44 -0
- memos/plugins/ocr/server.py +227 -0
- memos/plugins/ocr/temp_ppocr.yaml +42 -0
- memos/plugins/vlm/__init__.py +0 -0
- memos/plugins/vlm/main.py +251 -0
- memos/prepare_dataset.py +107 -0
- memos/process_webp.py +55 -0
- memos/read_metadata.py +32 -0
- memos/record.py +358 -0
- memos/schemas.py +289 -0
- memos/search.py +1198 -0
- memos/server.py +883 -0
- memos/shotsum.py +105 -0
- memos/shotsum_with_ocr.py +145 -0
- memos/simple_tokenizer/dict/README.md +31 -0
- memos/simple_tokenizer/dict/hmm_model.utf8 +34 -0
- memos/simple_tokenizer/dict/idf.utf8 +258826 -0
- memos/simple_tokenizer/dict/jieba.dict.utf8 +348982 -0
- memos/simple_tokenizer/dict/pos_dict/char_state_tab.utf8 +6653 -0
- memos/simple_tokenizer/dict/pos_dict/prob_emit.utf8 +166 -0
- memos/simple_tokenizer/dict/pos_dict/prob_start.utf8 +259 -0
- memos/simple_tokenizer/dict/pos_dict/prob_trans.utf8 +5222 -0
- memos/simple_tokenizer/dict/stop_words.utf8 +1534 -0
- memos/simple_tokenizer/dict/user.dict.utf8 +4 -0
- memos/simple_tokenizer/linux/libsimple.so +0 -0
- memos/simple_tokenizer/macos/libsimple.dylib +0 -0
- memos/simple_tokenizer/windows/simple.dll +0 -0
- memos/static/_app/immutable/assets/0.e250c031.css +1 -0
- memos/static/_app/immutable/assets/_layout.e7937cfe.css +1 -0
- memos/static/_app/immutable/chunks/index.5c08976b.js +1 -0
- memos/static/_app/immutable/chunks/index.60ee613b.js +4 -0
- memos/static/_app/immutable/chunks/runtime.a7926cf6.js +5 -0
- memos/static/_app/immutable/chunks/scheduler.5c1cff6e.js +1 -0
- memos/static/_app/immutable/chunks/singletons.583bdf4e.js +1 -0
- memos/static/_app/immutable/entry/app.666c1643.js +1 -0
- memos/static/_app/immutable/entry/start.aed5c701.js +3 -0
- memos/static/_app/immutable/nodes/0.5862ea38.js +7 -0
- memos/static/_app/immutable/nodes/1.35378a5e.js +1 -0
- memos/static/_app/immutable/nodes/2.1ccf9ea5.js +81 -0
- memos/static/_app/version.json +1 -0
- memos/static/app.html +36 -0
- memos/static/favicon.png +0 -0
- memos/static/logos/memos_logo_1024.png +0 -0
- memos/static/logos/memos_logo_1024@2x.png +0 -0
- memos/static/logos/memos_logo_128.png +0 -0
- memos/static/logos/memos_logo_128@2x.png +0 -0
- memos/static/logos/memos_logo_16.png +0 -0
- memos/static/logos/memos_logo_16@2x.png +0 -0
- memos/static/logos/memos_logo_256.png +0 -0
- memos/static/logos/memos_logo_256@2x.png +0 -0
- memos/static/logos/memos_logo_32.png +0 -0
- memos/static/logos/memos_logo_32@2x.png +0 -0
- memos/static/logos/memos_logo_512.png +0 -0
- memos/static/logos/memos_logo_512@2x.png +0 -0
- memos/static/logos/memos_logo_64.png +0 -0
- memos/static/logos/memos_logo_64@2x.png +0 -0
- memos/test_server.py +802 -0
- memos/utils.py +49 -0
- memos_ml_backends/florence2_server.py +176 -0
- memos_ml_backends/qwen2vl_server.py +182 -0
- memos_ml_backends/schemas.py +48 -0
- pensiev-0.25.5.dist-info/LICENSE +201 -0
- pensiev-0.25.5.dist-info/METADATA +541 -0
- pensiev-0.25.5.dist-info/RECORD +111 -0
- pensiev-0.25.5.dist-info/WHEEL +5 -0
- pensiev-0.25.5.dist-info/entry_points.txt +2 -0
- pensiev-0.25.5.dist-info/top_level.txt +2 -0
memos/migrations.py
ADDED
@@ -0,0 +1,19 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
from alembic.config import Config
|
3
|
+
from alembic import command
|
4
|
+
from .config import settings
|
5
|
+
|
6
|
+
|
7
|
+
def run_migrations():
    """Upgrade the configured database to the newest Alembic revision ("head").

    The Alembic ini file and script directory are located relative to this
    module; the connection URL comes from ``settings.database_url``.
    """
    base = Path(__file__).parent / "migrations"

    cfg = Config(str(base / "alembic.ini"))
    # Point Alembic at the packaged script directory (absolute path) and at the
    # configured database, overriding whatever alembic.ini contains.
    overrides = (
        ("script_location", str(base / "alembic")),
        ("sqlalchemy.url", settings.database_url),
    )
    for option, value in overrides:
        cfg.set_main_option(option, value)

    command.upgrade(cfg, "head")
|
memos/models.py
ADDED
@@ -0,0 +1,199 @@
|
|
1
|
+
from sqlalchemy import (
|
2
|
+
Integer,
|
3
|
+
String,
|
4
|
+
Text,
|
5
|
+
DateTime,
|
6
|
+
Enum,
|
7
|
+
ForeignKey,
|
8
|
+
func,
|
9
|
+
Index,
|
10
|
+
)
|
11
|
+
from datetime import datetime
|
12
|
+
from sqlalchemy.orm import relationship, DeclarativeBase, Mapped, mapped_column, Session
|
13
|
+
from typing import List
|
14
|
+
from .schemas import MetadataSource, MetadataType, FolderType
|
15
|
+
|
16
|
+
|
17
|
+
class RawBase(DeclarativeBase):
    """Declarative root shared by every ORM model in this module.

    Models wanting the common ``id``/timestamp columns derive from ``Base``;
    tables with their own composite key derive from this class directly.
    """
|
19
|
+
|
20
|
+
|
21
|
+
class Base(RawBase):
    """Abstract base adding an integer primary key and audit timestamps.

    Both timestamps default to the database's ``now()``; ``updated_at`` is
    also set to ``now()`` by SQLAlchemy whenever it emits an UPDATE.
    """

    __abstract__ = True

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    created_at: Mapped[datetime] = mapped_column(
        DateTime,
        nullable=False,
        server_default=func.now(),
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime,
        nullable=False,
        server_default=func.now(),
        onupdate=func.now(),
    )
|
30
|
+
|
31
|
+
|
32
|
+
class LibraryModel(Base):
    """A uniquely named library: its folders plus the plugins enabled for it."""

    __tablename__ = "libraries"

    name: Mapped[str] = mapped_column(String, unique=True, nullable=False)

    # Only folders whose type is 'DEFAULT' appear in this collection; it is
    # eager-loaded with a JOIN whenever a library is loaded.
    folders: Mapped[List["FolderModel"]] = relationship(
        "FolderModel",
        lazy="joined",
        back_populates="library",
        primaryjoin="and_(LibraryModel.id==FolderModel.library_id, FolderModel.type=='DEFAULT')",
    )
    # Many-to-many through the library_plugins association table (eager JOIN).
    plugins: Mapped[List["PluginModel"]] = relationship(
        "PluginModel",
        lazy="joined",
        secondary="library_plugins",
    )
|
44
|
+
|
45
|
+
|
46
|
+
class FolderModel(Base):
    """A filesystem folder registered under a library, owning its entities."""

    __tablename__ = "folders"
    path: Mapped[str] = mapped_column(String, nullable=False)
    library_id: Mapped[int] = mapped_column(
        Integer, ForeignKey("libraries.id"), nullable=False
    )
    library: Mapped["LibraryModel"] = relationship(
        "LibraryModel", back_populates="folders"
    )
    entities: Mapped[List["EntityModel"]] = relationship(
        "EntityModel", back_populates="folder"
    )
    type: Mapped[FolderType] = mapped_column(Enum(FolderType), nullable=False)
    # The column is NOT NULL, so the annotation is non-optional to agree with
    # the schema (an explicit ``nullable=`` takes precedence over the hint, so
    # the generated DDL is the same either way).
    last_modified_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
60
|
+
|
61
|
+
|
62
|
+
class EntityPluginStatusModel(RawBase):
    """Marks that a plugin has processed an entity, and when.

    ``(entity_id, plugin_id)`` is the composite primary key. Both foreign keys
    are declared ``ON DELETE CASCADE``.
    """

    __tablename__ = "entity_plugin_status"
    __table_args__ = (
        Index("idx_entity_plugin_entity_id", "entity_id"),
        Index("idx_entity_plugin_plugin_id", "plugin_id"),
    )

    entity_id: Mapped[int] = mapped_column(
        Integer,
        ForeignKey("entities.id", ondelete="CASCADE"),
        primary_key=True,
    )
    plugin_id: Mapped[int] = mapped_column(
        Integer,
        ForeignKey("plugins.id", ondelete="CASCADE"),
        primary_key=True,
    )
    # Defaults to the database's now() at insert time.
    processed_at: Mapped[datetime] = mapped_column(
        DateTime,
        nullable=False,
        server_default=func.now(),
    )
|
79
|
+
|
80
|
+
|
81
|
+
class EntityModel(Base):
    """A file indexed in a library folder, with its tags, metadata and plugin status."""

    __tablename__ = "entities"
    filepath: Mapped[str] = mapped_column(String, nullable=False)
    filename: Mapped[str] = mapped_column(String, nullable=False)
    size: Mapped[int] = mapped_column(Integer, nullable=False)
    file_created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
    file_last_modified_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
    file_type: Mapped[str] = mapped_column(String, nullable=False)
    file_type_group: Mapped[str] = mapped_column(String, nullable=False)
    last_scan_at: Mapped[datetime | None] = mapped_column(
        DateTime, server_default=func.now(), onupdate=func.now(), nullable=True
    )
    library_id: Mapped[int] = mapped_column(
        Integer, ForeignKey("libraries.id"), nullable=False
    )
    folder_id: Mapped[int] = mapped_column(
        Integer, ForeignKey("folders.id"), nullable=False
    )
    folder: Mapped["FolderModel"] = relationship(
        "FolderModel",
        back_populates="entities",
        lazy="select"
    )
    # EntityMetadataModel.entity declares back_populates="metadata_entries";
    # mirroring it here makes the pairing bidirectional so in-session changes
    # on either side are reflected on the other.
    metadata_entries: Mapped[List["EntityMetadataModel"]] = relationship(
        "EntityMetadataModel",
        back_populates="entity",
        lazy="select",
        cascade="all, delete-orphan"
    )
    # Tags are shared between entities, so this many-to-many must not use the
    # "delete" cascade (which "all" includes): it would delete the TagModel rows
    # themselves whenever one tagged entity is deleted, stripping the tag from
    # every other entity. Only the entity_tags link rows should go, and
    # SQLAlchemy removes those itself for a `secondary` relationship.
    tags: Mapped[List["TagModel"]] = relationship(
        "TagModel",
        secondary="entity_tags",
        lazy="select",
        cascade="save-update, merge, refresh-expire, expunge",
        overlaps="entities",
    )
    plugin_status: Mapped[List["EntityPluginStatusModel"]] = relationship(
        "EntityPluginStatusModel",
        cascade="all, delete-orphan",
        lazy="select"
    )

    # Indexes
    __table_args__ = (
        Index("idx_filepath", "filepath"),
        Index("idx_filename", "filename"),
        Index("idx_file_type", "file_type"),
        Index("idx_library_id", "library_id"),
        Index("idx_folder_id", "folder_id"),
        Index("idx_file_type_group", "file_type_group"),
        Index("idx_file_created_at", "file_created_at"),
    )

    @classmethod
    def update_last_scan_at(cls, session: Session, entity: "EntityModel"):
        """Set ``entity.last_scan_at`` to the SQL ``now()`` and add it to ``session``.

        The assigned value is a SQL expression, so the timestamp is produced by
        the database when the session flushes.
        """
        entity.last_scan_at = func.now()
        session.add(entity)

    @property
    def tag_names(self) -> List[str]:
        """Names of all tags attached to this entity."""
        return [tag.name for tag in self.tags]
|
141
|
+
|
142
|
+
|
143
|
+
class TagModel(Base):
    """A label (with optional description and color) that entities can carry.

    Where a tag came from is recorded per link on EntityTagModel.source rather
    than on the tag itself.
    """

    __tablename__ = "tags"
    name: Mapped[str] = mapped_column(String, nullable=False)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    color: Mapped[str | None] = mapped_column(String, nullable=True)
|
149
|
+
|
150
|
+
|
151
|
+
class EntityTagModel(Base):
    """Link row attaching a tag to an entity, plus the source of that link.

    This table is also the ``secondary`` of ``EntityModel.tags``.
    """

    __tablename__ = "entity_tags"
    entity_id: Mapped[int] = mapped_column(
        Integer,
        ForeignKey("entities.id", ondelete="CASCADE"),
        nullable=False,
    )
    tag_id: Mapped[int] = mapped_column(
        Integer,
        ForeignKey("tags.id"),
        nullable=False,
    )
    source: Mapped[MetadataSource] = mapped_column(
        Enum(MetadataSource),
        nullable=False,
    )

    __table_args__ = (
        Index("idx_entity_tag_entity_id", "entity_id"),
        Index("idx_entity_tag_tag_id", "tag_id"),
    )
|
163
|
+
|
164
|
+
|
165
|
+
class EntityMetadataModel(Base):
    """A typed key/value metadata entry belonging to one entity."""

    __tablename__ = "metadata_entries"
    entity_id: Mapped[int] = mapped_column(
        Integer,
        ForeignKey("entities.id"),
        nullable=False,
    )
    key: Mapped[str] = mapped_column(String, nullable=False)
    value: Mapped[str] = mapped_column(Text, nullable=False)
    source_type: Mapped[MetadataSource] = mapped_column(
        Enum(MetadataSource),
        nullable=False,
    )
    # Optional free-form producer name.
    source: Mapped[str | None] = mapped_column(String, nullable=True)
    data_type: Mapped[MetadataType] = mapped_column(
        Enum(MetadataType),
        nullable=False,
    )
    entity: Mapped["EntityModel"] = relationship(
        "EntityModel",
        back_populates="metadata_entries",
    )

    __table_args__ = (
        Index("idx_metadata_entity_id", "entity_id"),
        Index("idx_metadata_key", "key"),
    )
|
183
|
+
|
184
|
+
|
185
|
+
class PluginModel(Base):
    """A plugin record: name, optional description and webhook URL."""

    __tablename__ = "plugins"
    name: Mapped[str] = mapped_column(String, nullable=False)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    webhook_url: Mapped[str] = mapped_column(String, nullable=False)
|
190
|
+
|
191
|
+
|
192
|
+
class LibraryPluginModel(Base):
    """Association row enabling a plugin for a library.

    Serves as the ``secondary`` table of ``LibraryModel.plugins``.
    """

    __tablename__ = "library_plugins"
    library_id: Mapped[int] = mapped_column(
        Integer,
        ForeignKey("libraries.id"),
        nullable=False,
    )
    plugin_id: Mapped[int] = mapped_column(
        Integer,
        ForeignKey("plugins.id"),
        nullable=False,
    )
|
File without changes
|
File without changes
|
@@ -0,0 +1,251 @@
|
|
1
|
+
import asyncio
|
2
|
+
import logging
|
3
|
+
import os
|
4
|
+
from typing import Optional
|
5
|
+
import httpx
|
6
|
+
import json
|
7
|
+
import base64
|
8
|
+
from PIL import Image
|
9
|
+
import numpy as np
|
10
|
+
from concurrent.futures import ThreadPoolExecutor
|
11
|
+
from functools import partial
|
12
|
+
import yaml
|
13
|
+
import io
|
14
|
+
import platform
|
15
|
+
import cpuinfo
|
16
|
+
|
17
|
+
# (width, height) bound passed to PIL's Image.thumbnail when downscaling images
# before OCR; thumbnail shrinks in place and preserves the aspect ratio.
MAX_THUMBNAIL_SIZE = (1920, 1920)
|
18
|
+
|
19
|
+
from fastapi import APIRouter, Request, HTTPException
|
20
|
+
from memos.schemas import Entity, MetadataType
|
21
|
+
|
22
|
+
METADATA_FIELD_NAME = "ocr_result"
PLUGIN_NAME = "ocr"

router = APIRouter(tags=[PLUGIN_NAME], responses={404: {"description": "Not found"}})

# Runtime configuration; every value below is (re)assigned by init_plugin().
endpoint = None       # URL of the remote OCR service
token = None          # optional secret; its get_secret_value() is sent as a Bearer token
concurrency = None    # size of the semaphore / local thread pool
semaphore = None      # asyncio.Semaphore bounding concurrent remote requests
use_local = False     # True -> run OCR in-process instead of calling `endpoint`
# NOTE(review): this name is rebound twice: by the `async def ocr` route
# handler defined later in this module, and by init_plugin() to the RapidOCR
# engine when use_local is enabled. predict_local() relies on the latter.
ocr = None
thread_pool = None    # ThreadPoolExecutor used by async_predict_local()

# Configure logging for this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
37
|
+
|
38
|
+
|
39
|
+
def get_metadata_name() -> str:
    """Metadata key under which this plugin stores (and looks up) OCR output."""
    field_name = METADATA_FIELD_NAME
    return field_name
|
42
|
+
|
43
|
+
|
44
|
+
def image2base64(img_path):
    """Return the image at ``img_path`` as a base64-encoded JPEG string.

    The image is converted to RGB and shrunk to fit within MAX_THUMBNAIL_SIZE
    (aspect ratio preserved) before encoding.

    Returns:
        The base64 text, or None if any step fails (the error is logged).
    """
    try:
        with Image.open(img_path) as source:
            rgb = source.convert("RGB")
            rgb.thumbnail(MAX_THUMBNAIL_SIZE)
            jpeg_bytes = io.BytesIO()
            rgb.save(jpeg_bytes, format="JPEG")
            return base64.b64encode(jpeg_bytes.getvalue()).decode("utf-8")
    except Exception as e:
        logger.error(f"Error processing image {img_path}: {str(e)}")
        return None
|
56
|
+
|
57
|
+
|
58
|
+
async def fetch(endpoint: str, client, image_base64, headers: Optional[dict] = None):
    """POST a base64 image to the remote OCR service and return its JSON reply.

    The number of in-flight requests is bounded by the module-level
    ``semaphore`` created in init_plugin().

    Returns:
        The decoded JSON body, or None when the service answers with a
        non-200 status (now logged instead of being silently dropped).
        Transport errors raised by ``client.post`` propagate to the caller.
    """
    # Use the semaphore to limit the number of concurrent OCR requests.
    async with semaphore:
        response = await client.post(
            # str() mirrors the original f-string coercion, in case the config
            # supplies a URL object rather than a plain string.
            str(endpoint),
            json={"image_base64": image_base64},
            timeout=60,
            headers=headers,
        )
    if response.status_code != 200:
        logger.warning(
            "OCR endpoint %s returned HTTP %s", endpoint, response.status_code
        )
        return None
    return response.json()
|
69
|
+
|
70
|
+
|
71
|
+
def convert_ocr_results(results):
    """Turn RapidOCR ``(box, text, score)`` sequences into dicts.

    Each item becomes ``{"dt_boxes": box, "rec_txt": text, "score": score}``.
    A ``None`` input yields an empty list.
    """
    if results is None:
        return []
    return [
        {"dt_boxes": entry[0], "rec_txt": entry[1], "score": entry[2]}
        for entry in results
    ]
|
80
|
+
|
81
|
+
|
82
|
+
def convert_ocr_data(ocr_data):
    """Convert ``(text, score, bbox)`` tuples into the same dict layout as RapidOCR output.

    ``bbox`` is unpacked as ``(x_min, y_min, x_max, y_max)`` and expanded into
    four corner points ordered (x_min, y_min), (x_max, y_min), (x_max, y_max),
    (x_min, y_max). The score is coerced to ``float``.
    """
    return [
        {
            "dt_boxes": [[x0, y0], [x1, y0], [x1, y1], [x0, y1]],
            "rec_txt": text,
            "score": float(score),
        }
        for text, score, (x0, y0, x1, y1) in ocr_data
    ]
|
99
|
+
|
100
|
+
|
101
|
+
def predict_local(img_path):
    """Run OCR in-process and return a list of result dicts, or None on error.

    On macOS the file is passed directly to ``ocrmac`` (preferred language
    'zh-Hans'). Elsewhere the image is converted to RGB, shrunk to fit
    MAX_THUMBNAIL_SIZE and fed as a NumPy array to the module-level ``ocr``
    engine. Any exception is logged and turned into ``None``.
    """
    try:
        if platform.system() != 'Darwin':
            # NOTE(review): `ocr` must already be the RapidOCR instance set up
            # by init_plugin(); before that the name refers to the route handler.
            with Image.open(img_path) as source:
                rgb = source.convert("RGB")
                rgb.thumbnail(MAX_THUMBNAIL_SIZE)
                pixels = np.array(rgb)
                detections, _ = ocr(pixels)
            return convert_ocr_results(detections)

        # macOS-only dependency, imported lazily.
        from ocrmac import ocrmac
        recognizer = ocrmac.OCR(img_path, language_preference=['zh-Hans'])
        return convert_ocr_data(recognizer.recognize(px=True))
    except Exception as e:
        logger.error(f"Error processing image {img_path}: {str(e)}")
        return None
|
117
|
+
|
118
|
+
|
119
|
+
async def async_predict_local(img_path):
    """Await predict_local(img_path) on ``thread_pool`` so the event loop stays free."""
    running_loop = asyncio.get_running_loop()
    return await running_loop.run_in_executor(thread_pool, predict_local, img_path)
|
123
|
+
|
124
|
+
|
125
|
+
async def predict(img_path):
    """OCR an image file, either locally or via the remote service.

    Returns:
        The local OCR result list, the remote service's JSON response, or None
        when the image could not be encoded or the service returned a non-200
        status.
    """
    if use_local:
        return await async_predict_local(img_path)

    # image2base64 does blocking file I/O and CPU-heavy PIL work (decode,
    # resize, JPEG encode). Run it in a worker thread so this coroutine does
    # not stall the event loop for every other request meanwhile.
    image_base64 = await asyncio.to_thread(image2base64, img_path)
    if not image_base64:
        return None

    headers = {"Authorization": f"Bearer {token.get_secret_value()}"} if token else {}
    async with httpx.AsyncClient() as client:
        return await fetch(endpoint, client, image_base64, headers)
|
137
|
+
|
138
|
+
|
139
|
+
@router.get("/")
async def read_root():
    """Health-check endpoint for the OCR plugin."""
    return dict(healthy=True)
|
142
|
+
|
143
|
+
|
144
|
+
def _dump_ocr_result(ocr_result):
    """Serialize OCR output to a JSON string, converting NumPy values.

    Objects exposing ``tolist()`` (NumPy arrays and scalars) or ``item()`` are
    converted to plain Python values; anything else raises ``TypeError``, as a
    ``json.dumps`` ``default`` hook is expected to do.
    """
    def _convert(obj):
        if hasattr(obj, "tolist"):
            return obj.tolist()
        if hasattr(obj, "item"):
            return obj.item()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    return json.dumps(ocr_result, default=_convert)


@router.post("", include_in_schema=False)
@router.post("/")
async def ocr(entity: Entity, request: Request):
    """Run OCR on an image entity and store the result as entity metadata.

    Non-image entities, entities whose OCR metadata is already non-blank, and
    entities tagged ``low_info`` are skipped. Otherwise the result is PATCHed to
    ``{Location header}/metadata`` as JSON metadata and also returned.

    Raises:
        HTTPException: 400 if the ``Location`` header is missing, or the PATCH
            response's status code if storing the metadata fails.
    """
    metadata_field_name = get_metadata_name()
    if entity.file_type_group != "image":
        return {metadata_field_name: "{}"}

    # Reuse an existing, non-blank result instead of running OCR again.
    existing_metadata = entity.get_metadata_by_key(metadata_field_name)
    if existing_metadata and existing_metadata.value and existing_metadata.value.strip():
        logger.info(f"Skipping OCR processing for file: {entity.filepath} due to existing metadata")
        return {metadata_field_name: existing_metadata.value}

    # Entities tagged "low_info" are not OCR-ed.
    if any(tag.name == "low_info" for tag in entity.tags):
        logger.info(f"Skipping OCR processing for file: {entity.filepath} due to 'low_info' tag")
        return {metadata_field_name: "{}"}

    location_url = request.headers.get("Location")
    if not location_url:
        raise HTTPException(status_code=400, detail="Location header is missing")

    patch_url = f"{location_url}/metadata"

    ocr_result = await predict(entity.filepath)
    if not ocr_result:
        logger.info(f"No OCR result found for file: {entity.filepath}")
        return {metadata_field_name: "{}"}

    # Log a short preview: up to 10 entries whose score exceeds 0.5.
    filtered_results = [r for r in ocr_result if r['score'] > 0.5][:10]
    texts = [f"{r['rec_txt']}({r['score']:.2f})" for r in filtered_results]
    total = len(ocr_result)
    logger.info(f"First {len(texts)}/{total} OCR results: {texts}")

    # Serialize once; the same JSON text is stored and returned.
    serialized = _dump_ocr_result(ocr_result)

    # Call the URL to patch the entity's metadata.
    async with httpx.AsyncClient() as client:
        response = await client.patch(
            patch_url,
            json={
                "metadata_entries": [
                    {
                        "key": metadata_field_name,
                        "value": serialized,
                        "source": PLUGIN_NAME,
                        "data_type": MetadataType.JSON_DATA.value,
                    }
                ]
            },
            timeout=30,
        )

    if response.status_code != 200:
        raise HTTPException(
            status_code=response.status_code, detail="Failed to patch entity metadata"
        )

    return {metadata_field_name: serialized}
|
210
|
+
|
211
|
+
|
212
|
+
def _uses_openvino():
    """Return True on Windows machines whose CPU brand string contains 'Intel'.

    py-cpuinfo is consulted only on Windows, and a missing 'brand_raw' key is
    treated as a non-Intel CPU instead of raising KeyError.
    """
    if platform.system() != 'Windows':
        return False
    return 'Intel' in cpuinfo.get_cpu_info().get('brand_raw', '')


def _write_ocr_config(ocr_config, plugin_dir):
    """Write the patched RapidOCR config as YAML and return the file's path.

    Writes ``temp_ppocr.yaml`` next to this module first. If that location is
    not writable (OSError, e.g. a read-only install), falls back to a fresh
    file in the system temporary directory.
    """
    path = os.path.join(plugin_dir, "temp_ppocr.yaml")
    try:
        with open(path, 'w', encoding='utf-8') as f:
            yaml.safe_dump(ocr_config, f)
        return path
    except OSError as e:
        import tempfile

        logger.warning(f"Cannot write {path} ({e}); using a temporary file instead")
        fd, path = tempfile.mkstemp(prefix="ppocr_", suffix=".yaml")
        with os.fdopen(fd, 'w', encoding='utf-8') as f:
            yaml.safe_dump(ocr_config, f)
        return path


def init_plugin(config):
    """Configure the OCR plugin from ``config``.

    Copies ``endpoint``, ``token``, ``concurrency`` and ``use_local`` into module
    globals and creates the request semaphore. When ``use_local`` is set, it
    also builds a RapidOCR engine from ppocr.yaml (model paths redirected to
    the bundled ``models`` directory) and a thread pool with ``concurrency``
    workers.
    """
    global endpoint, token, concurrency, semaphore, use_local, ocr, thread_pool
    endpoint = config.endpoint
    token = config.token
    concurrency = config.concurrency
    use_local = config.use_local
    semaphore = asyncio.Semaphore(concurrency)

    ocr_library = None
    if use_local:
        plugin_dir = os.path.dirname(__file__)
        config_path = os.path.join(plugin_dir, "ppocr.yaml")

        with open(config_path, 'r', encoding='utf-8') as f:
            ocr_config = yaml.safe_load(f)

        # Keep only each model's file name from the YAML and resolve it against
        # the models directory shipped with this plugin.
        model_dir = os.path.join(plugin_dir, "models")
        for section in ("Det", "Cls", "Rec"):
            model_file = os.path.basename(ocr_config[section]['model_path'])
            ocr_config[section]['model_path'] = os.path.join(model_dir, model_file)

        temp_config_path = _write_ocr_config(ocr_config, plugin_dir)

        # Query the CPU once; the same decision drives the import and the log line.
        if _uses_openvino():
            from rapidocr_openvino import RapidOCR
            ocr_library = 'rapidocr_openvino'
        else:
            from rapidocr_onnxruntime import RapidOCR
            ocr_library = 'rapidocr_onnxruntime'
        ocr = RapidOCR(config_path=temp_config_path)
        thread_pool = ThreadPoolExecutor(max_workers=concurrency)

    logger.info("OCR plugin initialized")
    logger.info(f"Endpoint: {endpoint}")
    logger.info(f"Token: {token}")
    logger.info(f"Concurrency: {concurrency}")
    logger.info(f"Use local: {use_local}")
    if use_local:
        logger.info(f"OCR library: {ocr_library}")
|
Binary file
|
Binary file
|
Binary file
|
@@ -0,0 +1,43 @@
|
|
1
|
+
Global:
|
2
|
+
text_score: 0.5
|
3
|
+
use_det: true
|
4
|
+
use_cls: true
|
5
|
+
use_rec: true
|
6
|
+
print_verbose: false
|
7
|
+
min_height: 30
|
8
|
+
width_height_ratio: 40
|
9
|
+
max_side_len: 2000
|
10
|
+
min_side_len: 30
|
11
|
+
|
12
|
+
Det:
|
13
|
+
use_cuda: true
|
14
|
+
|
15
|
+
model_path: models/ch_PP-OCRv4_det_infer.onnx
|
16
|
+
|
17
|
+
limit_side_len: 1500
|
18
|
+
limit_type: min
|
19
|
+
|
20
|
+
thresh: 0.3
|
21
|
+
box_thresh: 0.3
|
22
|
+
max_candidates: 1000
|
23
|
+
unclip_ratio: 1.6
|
24
|
+
use_dilation: true
|
25
|
+
score_mode: fast
|
26
|
+
|
27
|
+
Cls:
|
28
|
+
use_cuda: true
|
29
|
+
|
30
|
+
model_path: models/ch_ppocr_mobile_v2.0_cls_train.onnx
|
31
|
+
|
32
|
+
cls_image_shape: [3, 48, 192]
|
33
|
+
cls_batch_num: 6
|
34
|
+
cls_thresh: 0.9
|
35
|
+
label_list: ['0', '180']
|
36
|
+
|
37
|
+
Rec:
|
38
|
+
use_cuda: true
|
39
|
+
|
40
|
+
model_path: models/ch_PP-OCRv4_rec_infer.onnx
|
41
|
+
|
42
|
+
rec_img_shape: [3, 48, 320]
|
43
|
+
rec_batch_num: 6
|
@@ -0,0 +1,44 @@
|
|
1
|
+
Global:
|
2
|
+
text_score: 0.5
|
3
|
+
use_det: true
|
4
|
+
use_cls: true
|
5
|
+
use_rec: true
|
6
|
+
print_verbose: false
|
7
|
+
min_height: 30
|
8
|
+
width_height_ratio: 40
|
9
|
+
use_space_char: true
|
10
|
+
max_side_len: 2000
|
11
|
+
min_side_len: 30
|
12
|
+
|
13
|
+
Det:
|
14
|
+
use_cuda: false
|
15
|
+
|
16
|
+
model_path: models/ch_PP-OCRv4_det_infer.onnx
|
17
|
+
|
18
|
+
limit_side_len: 1500
|
19
|
+
limit_type: min
|
20
|
+
|
21
|
+
thresh: 0.3
|
22
|
+
box_thresh: 0.3
|
23
|
+
max_candidates: 1000
|
24
|
+
unclip_ratio: 1.6
|
25
|
+
use_dilation: true
|
26
|
+
score_mode: fast
|
27
|
+
|
28
|
+
Cls:
|
29
|
+
use_cuda: false
|
30
|
+
|
31
|
+
model_path: models/ch_ppocr_mobile_v2.0_cls_train.onnx
|
32
|
+
|
33
|
+
cls_image_shape: [3, 48, 192]
|
34
|
+
cls_batch_num: 6
|
35
|
+
cls_thresh: 0.9
|
36
|
+
label_list: ['0', '180']
|
37
|
+
|
38
|
+
Rec:
|
39
|
+
use_cuda: false
|
40
|
+
|
41
|
+
model_path: models/ch_PP-OCRv4_rec_infer.onnx
|
42
|
+
|
43
|
+
rec_img_shape: [3, 48, 320]
|
44
|
+
rec_batch_num: 6
|