pensiev 0.25.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- memos/__init__.py +6 -0
- memos/cmds/__init__.py +0 -0
- memos/cmds/library.py +1289 -0
- memos/cmds/plugin.py +96 -0
- memos/commands.py +865 -0
- memos/config.py +225 -0
- memos/crud.py +605 -0
- memos/databases/__init__.py +0 -0
- memos/databases/initializers.py +481 -0
- memos/dataset_extractor_for_florence.py +165 -0
- memos/dataset_extractor_for_internvl2.py +192 -0
- memos/default_config.yaml +88 -0
- memos/embedding.py +129 -0
- memos/frame_extractor.py +53 -0
- memos/logging_config.py +35 -0
- memos/main.py +104 -0
- memos/migrations/alembic/README +1 -0
- memos/migrations/alembic/__pycache__/env.cpython-310.pyc +0 -0
- memos/migrations/alembic/env.py +108 -0
- memos/migrations/alembic/script.py.mako +30 -0
- memos/migrations/alembic/versions/00904ac8c6fc_add_indexes_to_entitymodel.py +63 -0
- memos/migrations/alembic/versions/04acdaf75664_add_indices_to_entitytags_and_metadata.py +86 -0
- memos/migrations/alembic/versions/12504c5b1d3c_add_extra_columns_for_embedding.py +67 -0
- memos/migrations/alembic/versions/31a1ad0e10b3_add_entity_plugin_status.py +71 -0
- memos/migrations/alembic/versions/__pycache__/00904ac8c6fc_add_indexes_to_entitymodel.cpython-310.pyc +0 -0
- memos/migrations/alembic/versions/__pycache__/04acdaf75664_add_indices_to_entitytags_and_metadata.cpython-310.pyc +0 -0
- memos/migrations/alembic/versions/__pycache__/12504c5b1d3c_add_extra_columns_for_embedding.cpython-310.pyc +0 -0
- memos/migrations/alembic/versions/__pycache__/20f5ecab014d_add_entity_plugin_status.cpython-310.pyc +0 -0
- memos/migrations/alembic/versions/__pycache__/31a1ad0e10b3_add_entity_plugin_status.cpython-310.pyc +0 -0
- memos/migrations/alembic/versions/__pycache__/4fcb062c5128_add_extra_columns_for_embedding.cpython-310.pyc +0 -0
- memos/migrations/alembic/versions/__pycache__/d10c55fbb7d2_add_index_for_entity_file_type_group_.cpython-310.pyc +0 -0
- memos/migrations/alembic/versions/__pycache__/f8f158182416_add_active_app_index.cpython-310.pyc +0 -0
- memos/migrations/alembic/versions/d10c55fbb7d2_add_index_for_entity_file_type_group_.py +44 -0
- memos/migrations/alembic/versions/f8f158182416_add_active_app_index.py +75 -0
- memos/migrations/alembic.ini +116 -0
- memos/migrations.py +19 -0
- memos/models.py +199 -0
- memos/plugins/__init__.py +0 -0
- memos/plugins/ocr/__init__.py +0 -0
- memos/plugins/ocr/main.py +251 -0
- memos/plugins/ocr/models/ch_PP-OCRv4_det_infer.onnx +0 -0
- memos/plugins/ocr/models/ch_PP-OCRv4_rec_infer.onnx +0 -0
- memos/plugins/ocr/models/ch_ppocr_mobile_v2.0_cls_train.onnx +0 -0
- memos/plugins/ocr/ppocr-gpu.yaml +43 -0
- memos/plugins/ocr/ppocr.yaml +44 -0
- memos/plugins/ocr/server.py +227 -0
- memos/plugins/ocr/temp_ppocr.yaml +42 -0
- memos/plugins/vlm/__init__.py +0 -0
- memos/plugins/vlm/main.py +251 -0
- memos/prepare_dataset.py +107 -0
- memos/process_webp.py +55 -0
- memos/read_metadata.py +32 -0
- memos/record.py +358 -0
- memos/schemas.py +289 -0
- memos/search.py +1198 -0
- memos/server.py +883 -0
- memos/shotsum.py +105 -0
- memos/shotsum_with_ocr.py +145 -0
- memos/simple_tokenizer/dict/README.md +31 -0
- memos/simple_tokenizer/dict/hmm_model.utf8 +34 -0
- memos/simple_tokenizer/dict/idf.utf8 +258826 -0
- memos/simple_tokenizer/dict/jieba.dict.utf8 +348982 -0
- memos/simple_tokenizer/dict/pos_dict/char_state_tab.utf8 +6653 -0
- memos/simple_tokenizer/dict/pos_dict/prob_emit.utf8 +166 -0
- memos/simple_tokenizer/dict/pos_dict/prob_start.utf8 +259 -0
- memos/simple_tokenizer/dict/pos_dict/prob_trans.utf8 +5222 -0
- memos/simple_tokenizer/dict/stop_words.utf8 +1534 -0
- memos/simple_tokenizer/dict/user.dict.utf8 +4 -0
- memos/simple_tokenizer/linux/libsimple.so +0 -0
- memos/simple_tokenizer/macos/libsimple.dylib +0 -0
- memos/simple_tokenizer/windows/simple.dll +0 -0
- memos/static/_app/immutable/assets/0.e250c031.css +1 -0
- memos/static/_app/immutable/assets/_layout.e7937cfe.css +1 -0
- memos/static/_app/immutable/chunks/index.5c08976b.js +1 -0
- memos/static/_app/immutable/chunks/index.60ee613b.js +4 -0
- memos/static/_app/immutable/chunks/runtime.a7926cf6.js +5 -0
- memos/static/_app/immutable/chunks/scheduler.5c1cff6e.js +1 -0
- memos/static/_app/immutable/chunks/singletons.583bdf4e.js +1 -0
- memos/static/_app/immutable/entry/app.666c1643.js +1 -0
- memos/static/_app/immutable/entry/start.aed5c701.js +3 -0
- memos/static/_app/immutable/nodes/0.5862ea38.js +7 -0
- memos/static/_app/immutable/nodes/1.35378a5e.js +1 -0
- memos/static/_app/immutable/nodes/2.1ccf9ea5.js +81 -0
- memos/static/_app/version.json +1 -0
- memos/static/app.html +36 -0
- memos/static/favicon.png +0 -0
- memos/static/logos/memos_logo_1024.png +0 -0
- memos/static/logos/memos_logo_1024@2x.png +0 -0
- memos/static/logos/memos_logo_128.png +0 -0
- memos/static/logos/memos_logo_128@2x.png +0 -0
- memos/static/logos/memos_logo_16.png +0 -0
- memos/static/logos/memos_logo_16@2x.png +0 -0
- memos/static/logos/memos_logo_256.png +0 -0
- memos/static/logos/memos_logo_256@2x.png +0 -0
- memos/static/logos/memos_logo_32.png +0 -0
- memos/static/logos/memos_logo_32@2x.png +0 -0
- memos/static/logos/memos_logo_512.png +0 -0
- memos/static/logos/memos_logo_512@2x.png +0 -0
- memos/static/logos/memos_logo_64.png +0 -0
- memos/static/logos/memos_logo_64@2x.png +0 -0
- memos/test_server.py +802 -0
- memos/utils.py +49 -0
- memos_ml_backends/florence2_server.py +176 -0
- memos_ml_backends/qwen2vl_server.py +182 -0
- memos_ml_backends/schemas.py +48 -0
- pensiev-0.25.5.dist-info/LICENSE +201 -0
- pensiev-0.25.5.dist-info/METADATA +541 -0
- pensiev-0.25.5.dist-info/RECORD +111 -0
- pensiev-0.25.5.dist-info/WHEEL +5 -0
- pensiev-0.25.5.dist-info/entry_points.txt +2 -0
- pensiev-0.25.5.dist-info/top_level.txt +2 -0
@@ -0,0 +1,227 @@
|
|
1
|
+
from PIL import Image
|
2
|
+
import numpy as np
|
3
|
+
import logging
|
4
|
+
from fastapi import FastAPI, Body, HTTPException
|
5
|
+
from contextlib import asynccontextmanager
|
6
|
+
import base64
|
7
|
+
import io
|
8
|
+
import asyncio
|
9
|
+
from pydantic import BaseModel, Field
|
10
|
+
from typing import List
|
11
|
+
from multiprocessing import Pool
|
12
|
+
import threading
|
13
|
+
import time
|
14
|
+
import uvicorn
|
15
|
+
import os
|
16
|
+
|
17
|
+
# Configure logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Process pool used for OCR work; stays None until init_process_pool() runs.
process_pool = None

# Worker count, read from the MAX_WORKERS environment variable (default 1).
max_workers = int(os.getenv("MAX_WORKERS", 1))
|
26
|
+
|
27
|
+
|
28
|
+
def str_to_bool(value):
    """Interpret a string flag as a boolean.

    Returns True when ``value``, compared case-insensitively, is one of
    "true", "1", "t", "y" or "yes"; every other string yields False.
    """
    truthy_words = ("true", "1", "t", "y", "yes")
    lowered = value.lower()
    return lowered in truthy_words


# GPU switch, read from the USE_GPU environment variable (off by default).
use_gpu = str_to_bool(os.getenv("USE_GPU", "false"))
|
33
|
+
|
34
|
+
|
35
|
+
@asynccontextmanager
async def lifespan(app):
    """Manage the OCR process pool for the lifetime of the FastAPI app.

    Before the app starts serving, the pool is created unless one already
    exists. After the app stops, an existing pool is closed and joined.
    """
    global process_pool

    # Startup: only build a pool when nobody created one earlier.
    if process_pool is None:
        init_process_pool(max_workers=max_workers, use_gpu=use_gpu)

    yield

    # Shutdown: nothing to release when there is no pool.
    if not process_pool:
        return
    logger.info("Shutting down process pool...")
    process_pool.close()
    process_pool.join()
    logger.info("Process pool shut down.")


app = FastAPI(lifespan=lifespan)
|
49
|
+
|
50
|
+
|
51
|
+
def init_worker(use_gpu):
    """Pool initializer: create the OCR engine for the current process.

    The engine returned by ``init_ocr`` is stored in the module-level ``ocr``
    global, which ``predict`` reads when it runs in the same process.
    """
    global ocr
    engine = init_ocr(use_gpu)
    ocr = engine
|
54
|
+
|
55
|
+
|
56
|
+
def init_process_pool(max_workers, use_gpu):
    """Create the global multiprocessing pool used for OCR work.

    Args:
        max_workers: number of worker processes in the pool.
        use_gpu: passed to ``init_worker`` when each worker process starts.
    """
    global process_pool
    pool_options = {
        "processes": max_workers,
        "initializer": init_worker,
        "initargs": (use_gpu,),
    }
    process_pool = Pool(**pool_options)
|
61
|
+
|
62
|
+
|
63
|
+
def init_ocr(use_gpu):
    """Build and return a RapidOCR engine.

    Args:
        use_gpu: when true, ``rapidocr_paddle`` is used with CUDA enabled for
            detection, classification and recognition; otherwise
            ``rapidocr_onnxruntime`` is used.

    Returns:
        The constructed OCR engine object.

    Raises:
        ImportError: the backend for the requested mode could not be imported.
            An error is logged before the exception is re-raised.
    """
    # CPU path first, so the GPU path reads without an else branch.
    if not use_gpu:
        try:
            from rapidocr_onnxruntime import RapidOCR

            engine = RapidOCR()
            logger.info("Initialized OCR with RapidOCR ONNX Runtime (CPU)")
        except ImportError:
            logger.error(
                "Failed to import rapidocr_onnxruntime. Make sure it's installed for CPU usage."
            )
            raise
        return engine

    try:
        from rapidocr_paddle import RapidOCR as RapidOCRPaddle

        cuda_flags = {
            "det_use_cuda": True,
            "cls_use_cuda": True,
            "rec_use_cuda": True,
        }
        engine = RapidOCRPaddle(**cuda_flags)
        logger.info("Initialized OCR with RapidOCR Paddle (GPU)")
    except ImportError:
        logger.error(
            "Failed to import rapidocr_paddle. Make sure it's installed for GPU usage."
        )
        raise
    return engine
|
89
|
+
|
90
|
+
|
91
|
+
def convert_ocr_results(results):
    """Turn raw OCR rows into a list of dictionaries.

    Each row is read positionally: element 0 becomes ``dt_boxes``, element 1
    ``rec_txt`` and element 2 ``score``. ``None`` yields an empty list.
    """
    if results is None:
        return []
    return [
        {"dt_boxes": row[0], "rec_txt": row[1], "score": row[2]}
        for row in results
    ]
|
100
|
+
|
101
|
+
|
102
|
+
def predict(image_data):
    """Run OCR on encoded image bytes.

    Args:
        image_data: raw bytes of an image file that PIL can open.

    Returns:
        A list of dicts as produced by ``convert_ocr_results``.

    Raises:
        ValueError: no OCR engine has been initialized in this process.
    """
    # ``ocr`` only exists once ``init_worker`` has run in this process. Look it
    # up defensively: a bare ``ocr is None`` test would raise NameError instead
    # of the intended ValueError when the global was never created.
    engine = globals().get("ocr")
    if engine is None:
        raise ValueError("OCR engine not initialized")

    image = Image.open(io.BytesIO(image_data))
    img_array = np.array(image)
    results, _ = engine(img_array)
    return convert_ocr_results(results)
|
112
|
+
|
113
|
+
|
114
|
+
def convert_to_python_type(item):
    """Recursively replace NumPy values with plain Python equivalents.

    Arrays become (nested) lists and NumPy scalars become Python scalars.
    Lists, tuples and dict values are converted element by element; anything
    else is returned unchanged.

    Tuples are traversed as well, so NumPy values nested inside a tuple are
    converted too. The tuple type (including namedtuples) is preserved.
    """
    if isinstance(item, np.ndarray):
        return item.tolist()
    if isinstance(item, np.generic):  # NumPy scalars such as numpy.float32
        return item.item()
    if isinstance(item, list):
        return [convert_to_python_type(sub_item) for sub_item in item]
    if isinstance(item, tuple):
        converted = [convert_to_python_type(sub_item) for sub_item in item]
        # Namedtuples take positional fields rather than a single iterable.
        if hasattr(item, "_fields"):
            return type(item)(*converted)
        return tuple(converted)
    if isinstance(item, dict):
        return {key: convert_to_python_type(value) for key, value in item.items()}
    return item
|
125
|
+
|
126
|
+
|
127
|
+
async def async_predict(image_data):
    """Run ``predict`` in the process pool without blocking the event loop.

    The blocking ``process_pool.apply`` call is handed to the running loop's
    default executor and its result is awaited.
    """
    event_loop = asyncio.get_running_loop()
    task_args = (image_data,)
    return await event_loop.run_in_executor(
        None, process_pool.apply, predict, task_args
    )
|
133
|
+
|
134
|
+
|
135
|
+
class OCRResult(BaseModel):
    """Response item describing one recognized text region."""

    # Bounding box coordinates of the region.
    dt_boxes: List[List[float]] = Field(..., description="Bounding box coordinates")
    # Text recognized inside the region.
    rec_txt: str = Field(..., description="Recognized text")
    # Confidence score for the recognition.
    score: float = Field(..., description="Confidence score")
|
139
|
+
|
140
|
+
|
141
|
+
@app.post("/predict", response_model=List[OCRResult])
async def predict_base64(image_base64: str = Body(..., embed=True)):
    """Run OCR on a base64-encoded image sent in the request body.

    Args:
        image_base64: the image as base64 text, optionally prefixed with a
            ``data:image...,`` data-URI header.

    Returns:
        The OCR results with NumPy values converted to plain Python types.

    Raises:
        HTTPException: 400 when the field is empty, the data URI has no
            payload separator, or the base64 text cannot be decoded; 500 when
            OCR processing itself fails.
    """
    # Input validation happens outside the generic error handler below, so
    # client errors are reported as 400 instead of being rewrapped as 500.
    if not image_base64:
        raise HTTPException(status_code=400, detail="Missing image_base64 field")

    # Remove header part if present
    if image_base64.startswith("data:image"):
        _, separator, payload = image_base64.partition(",")
        if not separator:
            raise HTTPException(status_code=400, detail="Malformed data URI")
        image_base64 = payload

    # Decode the base64 image (binascii.Error is a ValueError subclass)
    try:
        image_data = base64.b64decode(image_base64)
    except ValueError as e:
        raise HTTPException(
            status_code=400, detail=f"Invalid base64 image data: {e}"
        ) from e

    try:
        # Pass the decoded image bytes straight to async_predict
        ocr_result = await async_predict(image_data)
        return convert_to_python_type(ocr_result)
    except Exception as e:
        logger.error(f"Error during OCR processing: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
162
|
+
|
163
|
+
|
164
|
+
# Set once a shutdown has been requested; run_server() polls it.
shutdown_event = threading.Event()


def signal_handler(signum, frame):
    """Signal-handler-shaped callback that requests a shutdown.

    Both arguments are ignored. The function logs the request and sets
    ``shutdown_event``.
    """
    logger.info("Received interrupt signal. Initiating shutdown...")
    shutdown_event.set()
|
170
|
+
|
171
|
+
|
172
|
+
def run_server(app, host, port):
    """Serve ``app`` with Uvicorn in a background thread until interrupted.

    The calling thread waits until ``shutdown_event`` is set, the server
    thread ends on its own, or a KeyboardInterrupt arrives. It then asks the
    server to exit and joins the server thread.

    Args:
        app: the ASGI application to serve.
        host: interface to bind to.
        port: TCP port to listen on.
    """
    config = uvicorn.Config(app, host=host, port=port, loop="asyncio")
    server = uvicorn.Server(config)
    server.install_signal_handlers = (
        lambda: None
    )  # Disable Uvicorn's own signal handlers

    def serve_in_thread():
        # The coroutine is created and run on the server thread's own loop.
        asyncio.run(server.serve())

    thread = threading.Thread(target=serve_in_thread)
    thread.start()

    try:
        # Also stop waiting when the server thread has died by itself
        # (for example because startup failed); otherwise this loop would
        # spin forever with nothing being served.
        while not shutdown_event.is_set() and thread.is_alive():
            time.sleep(1)
    except KeyboardInterrupt:
        logger.info("Keyboard interrupt received. Initiating shutdown...")
    finally:
        shutdown_event.set()
        logger.info("Stopping the server...")
        # Request a graceful exit from the serving loop. Calling
        # server.shutdown() from this thread would run it on a different
        # event loop and never make serve() return, so join() could hang.
        server.should_exit = True
        thread.join()
        logger.info("Server stopped.")
|
196
|
+
|
197
|
+
|
198
|
+
if __name__ == "__main__":
    import argparse

    # Command-line entry point. Options that also have an environment
    # variable (MAX_WORKERS, USE_GPU) use the environment value as their
    # default, so omitting a flag no longer discards the environment setting.
    parser = argparse.ArgumentParser(description="OCR Service")
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="Port to run the OCR service on",
    )
    parser.add_argument(
        "--max-workers",
        type=int,
        default=max_workers,
        help="Maximum number of worker processes for OCR processing",
    )
    parser.add_argument(
        "--gpu",
        action="store_true",
        default=use_gpu,
        help="Use GPU for OCR processing",
    )

    args = parser.parse_args()
    port = args.port
    # These module globals are read by lifespan() when the app starts.
    max_workers = args.max_workers
    use_gpu = args.gpu

    run_server(app, "0.0.0.0", port)
|
227
|
+
|
@@ -0,0 +1,42 @@
|
|
1
|
+
Cls:
|
2
|
+
cls_batch_num: 6
|
3
|
+
cls_image_shape:
|
4
|
+
- 3
|
5
|
+
- 48
|
6
|
+
- 192
|
7
|
+
cls_thresh: 0.9
|
8
|
+
label_list:
|
9
|
+
- '0'
|
10
|
+
- '180'
|
11
|
+
model_path: /Users/shanchuanxu/projects/memos/memos/plugins/ocr/models/ch_ppocr_mobile_v2.0_cls_train.onnx
|
12
|
+
use_cuda: false
|
13
|
+
Det:
|
14
|
+
box_thresh: 0.3
|
15
|
+
limit_side_len: 1500
|
16
|
+
limit_type: min
|
17
|
+
max_candidates: 1000
|
18
|
+
model_path: /Users/shanchuanxu/projects/memos/memos/plugins/ocr/models/ch_PP-OCRv4_det_infer.onnx
|
19
|
+
score_mode: fast
|
20
|
+
thresh: 0.3
|
21
|
+
unclip_ratio: 1.6
|
22
|
+
use_cuda: false
|
23
|
+
use_dilation: true
|
24
|
+
Global:
|
25
|
+
max_side_len: 2000
|
26
|
+
min_height: 30
|
27
|
+
min_side_len: 30
|
28
|
+
print_verbose: false
|
29
|
+
text_score: 0.5
|
30
|
+
use_cls: true
|
31
|
+
use_det: true
|
32
|
+
use_rec: true
|
33
|
+
use_space_char: true
|
34
|
+
width_height_ratio: 40
|
35
|
+
Rec:
|
36
|
+
model_path: /Users/shanchuanxu/projects/memos/memos/plugins/ocr/models/ch_PP-OCRv4_rec_infer.onnx
|
37
|
+
rec_batch_num: 6
|
38
|
+
rec_img_shape:
|
39
|
+
- 3
|
40
|
+
- 48
|
41
|
+
- 320
|
42
|
+
use_cuda: false
|
File without changes
|
@@ -0,0 +1,251 @@
|
|
1
|
+
import base64
|
2
|
+
import httpx
|
3
|
+
from PIL import Image
|
4
|
+
import asyncio
|
5
|
+
from typing import Optional
|
6
|
+
from fastapi import APIRouter, FastAPI, Request, HTTPException
|
7
|
+
from memos.schemas import Entity, MetadataType
|
8
|
+
import logging
|
9
|
+
import uvicorn
|
10
|
+
import os
|
11
|
+
import io
|
12
|
+
import numpy as np
|
13
|
+
|
14
|
+
|
15
|
+
# Configure logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

PLUGIN_NAME = "vlm"

router = APIRouter(tags=[PLUGIN_NAME], responses={404: {"description": "Not found"}})

# Plugin settings. All start as None and are assigned by init_plugin().
modelname = None
endpoint = None
token = None
concurrency = None
semaphore = None
force_jpeg = None
prompt = None
|
30
|
+
|
31
|
+
|
32
|
+
def get_metadata_name() -> str:
    """Return the metadata field name used by this plugin.

    The name is the configured model name with every "-" replaced by "_",
    followed by "_result".
    """
    normalized_model = modelname.replace("-", "_")
    return f"{normalized_model}_result"
|
36
|
+
|
37
|
+
|
38
|
+
def image2base64(img_path):
    """Load an image, normalize it to RGB and return it base64-encoded.

    Args:
        img_path: path of the image file to encode.

    Returns:
        The encoded image as a UTF-8 string, or ``None`` when the image is
        smaller than 10 pixels in either dimension or cannot be processed
        (the error is logged).

    The image is written as JPEG when ``force_jpeg`` is set. Otherwise the
    source format is kept when it is JPEG, PNG or WebP, and any other format
    falls back to JPEG.
    """
    try:
        with Image.open(img_path) as img:
            img.verify()  # Verify the image file

        # An image must be reopened after verify() before it can be loaded.
        with Image.open(img_path) as img:
            # Check image size and skip if it's too small
            if img.width < 10 or img.height < 10:
                logger.warning(
                    f"Image is too small: {img.width}x{img.height}. Skipping processing."
                )
                return None

            # Remember the format first: images produced by convert() carry
            # format None, which previously forced every image to JPEG.
            source_format = img.format

            # Convert image to RGB mode (removes alpha channel if present)
            rgb_img = img.convert("RGB")
            logger.info(f"Image shape: {(rgb_img.height, rgb_img.width, 3)}")

            if force_jpeg or source_format not in ("JPEG", "PNG", "WEBP"):
                target_format = "JPEG"
            else:
                target_format = source_format

            buffer = io.BytesIO()
            rgb_img.save(buffer, format=target_format)
            return base64.b64encode(buffer.getvalue()).decode("utf-8")
    except Exception as e:
        logger.error(f"Error processing image {img_path}: {str(e)}")
        return None
|
76
|
+
|
77
|
+
|
78
|
+
async def fetch(endpoint: str, client, request_data, headers: Optional[dict] = None):
    """POST a chat-completion request and return the first choice's content.

    The request is made while holding the module-level ``semaphore``.

    Returns:
        The ``content`` of the first choice's ``message``; an empty string
        when the response has no such field; ``None`` when any exception is
        raised (the exception is logged).
    """
    async with semaphore:
        try:
            url = f"{endpoint}/v1/chat/completions"
            response = await client.post(
                url, json=request_data, timeout=60, headers=headers
            )
            response.raise_for_status()
            choices = response.json().get("choices", [])
            if not choices:
                return ""
            first_choice = choices[0]
            if "message" not in first_choice:
                return ""
            message = first_choice["message"]
            if "content" not in message:
                return ""
            return message["content"]
        except Exception as e:
            logger.error(f"Exception occurred: {str(e)}")
            return None
|
100
|
+
|
101
|
+
|
102
|
+
async def predict(
    endpoint: str, modelname: str, img_path: str, token: Optional[str] = None
) -> Optional[str]:
    """Return the model output for an image by delegating to ``predict_remote``."""
    remote_result = await predict_remote(endpoint, modelname, img_path, token)
    return remote_result
|
106
|
+
|
107
|
+
|
108
|
+
async def predict_remote(
    endpoint: str, modelname: str, img_path: str, token: Optional[str] = None
) -> Optional[str]:
    """Send an image plus the configured prompt to a chat-completions endpoint.

    Args:
        endpoint: base URL of the service.
        modelname: value sent as the request's ``model``.
        img_path: path of the image to describe.
        token: optional bearer token; either a plain string or an object
            exposing ``get_secret_value()``.

    Returns:
        Whatever ``fetch`` returns, or ``None`` when the image could not be
        encoded.
    """
    img_base64 = image2base64(img_path)
    if not img_base64:
        logger.warning(
            f"Skipping processing for file: {img_path} due to invalid or small image"
        )
        return None

    if force_jpeg:
        mime_type = "image/jpeg"
    else:
        # Only determine MIME type from the extension if not forcing JPEG
        _, file_extension = os.path.splitext(img_path)
        file_extension = file_extension.lower()[1:]
        mime_types = {
            "png": "image/png",
            "jpg": "image/jpeg",
            "jpeg": "image/jpeg",
            "webp": "image/webp",
        }
        mime_type = mime_types.get(file_extension, "image/jpeg")

    request_data = {
        "model": modelname,
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:{mime_type};base64,{img_base64}"},
                    },
                    {"type": "text", "text": prompt},  # Use the global prompt variable here
                ],
            }
        ],
        "stream": False,
        "max_tokens": 1024,
        "temperature": 0.1,
        "repetition_penalty": 1.1,
        "top_p": 0.8,
    }

    headers = {}
    if token:
        # Accept both a plain string (as the annotation promises) and a
        # secret wrapper; calling get_secret_value() on a str would fail.
        reveal = getattr(token, "get_secret_value", None)
        secret = reveal() if callable(reveal) else token
        headers["Authorization"] = f"Bearer {secret}"

    async with httpx.AsyncClient() as client:
        return await fetch(endpoint, client, request_data, headers=headers)
|
157
|
+
|
158
|
+
|
159
|
+
@router.get("/")
async def read_root():
    """Health-check endpoint; always reports ``{"healthy": True}``."""
    health_status = {"healthy": True}
    return health_status
|
162
|
+
|
163
|
+
|
164
|
+
@router.post("", include_in_schema=False)
@router.post("/")
async def vlm(entity: Entity, request: Request):
    """Generate the VLM description for an entity and store it as metadata.

    Processing is skipped, returning the existing or an empty value, when the
    entity is not an image, already has a non-blank value for this plugin's
    metadata field, or carries the ``low_info`` tag. Otherwise the image is
    sent to the model and the result is PATCHed to ``<Location>/metadata``.

    Returns:
        A dict mapping the metadata field name to the resulting value
        (``"{}"`` when the model produced nothing).

    Raises:
        HTTPException: 400 when the ``Location`` header is missing; the
            upstream status code when the metadata PATCH does not return 200.
    """
    metadata_field_name = get_metadata_name()
    if not entity.file_type_group == "image":
        return {metadata_field_name: ""}

    # Check if the metadata field is empty or null
    existing_metadata = entity.get_metadata_by_key(metadata_field_name)
    if (
        existing_metadata
        and existing_metadata.value
        and existing_metadata.value.strip()
    ):
        logger.info(
            f"Skipping processing for file: {entity.filepath} due to existing metadata"
        )
        # If the field is not empty, return without processing
        return {metadata_field_name: existing_metadata.value}

    # Check if the entity contains the tag "low_info"
    if any(tag.name == "low_info" for tag in entity.tags):
        # If the tag is present, return without processing
        logger.info(
            f"Skipping processing for file: {entity.filepath} due to 'low_info' tag"
        )
        return {metadata_field_name: ""}

    location_url = request.headers.get("Location")
    if not location_url:
        raise HTTPException(status_code=400, detail="Location header is missing")

    patch_url = f"{location_url}/metadata"

    vlm_result = await predict(endpoint, modelname, entity.filepath, token=token)

    # Check for an empty result before slicing it for the log line:
    # predict() returns None on failure and None cannot be sliced.
    if not vlm_result:
        logger.info(f"No VLM result found for file: {entity.filepath}")
        return {metadata_field_name: "{}"}

    logger.info(f"VLM result: {vlm_result[:100]}...")

    async with httpx.AsyncClient() as client:
        response = await client.patch(
            patch_url,
            json={
                "metadata_entries": [
                    {
                        "key": metadata_field_name,
                        "value": vlm_result,
                        "source": PLUGIN_NAME,
                        "data_type": MetadataType.TEXT_DATA.value,
                    }
                ]
            },
            timeout=30,
        )

    if response.status_code != 200:
        raise HTTPException(
            status_code=response.status_code, detail="Failed to patch entity metadata"
        )

    return {
        metadata_field_name: vlm_result,
    }
|
230
|
+
|
231
|
+
|
232
|
+
def init_plugin(config):
    """Copy plugin settings from ``config`` into module globals and log them.

    Reads ``modelname``, ``endpoint``, ``token``, ``concurrency``,
    ``force_jpeg`` and ``prompt`` from ``config`` and creates an
    ``asyncio.Semaphore`` sized by ``concurrency``.
    """
    global modelname, endpoint, token, concurrency, semaphore, force_jpeg, prompt

    modelname = config.modelname
    endpoint = config.endpoint
    token = config.token
    concurrency = config.concurrency
    force_jpeg = config.force_jpeg
    prompt = config.prompt
    semaphore = asyncio.Semaphore(concurrency)

    # Report the effective settings, one log line per value.
    logger.info("VLM plugin initialized")
    reported_settings = (
        ("Model Name", modelname),
        ("Endpoint", endpoint),
        ("Token", token),
        ("Concurrency", concurrency),
        ("Force JPEG", force_jpeg),
        ("Prompt", prompt),
    )
    for label, setting in reported_settings:
        logger.info(f"{label}: {setting}")
|
251
|
+
|
memos/prepare_dataset.py
ADDED
@@ -0,0 +1,107 @@
|
|
1
|
+
import json
|
2
|
+
import os
|
3
|
+
import shutil
|
4
|
+
from tqdm import tqdm
|
5
|
+
from PIL import Image
|
6
|
+
import io
|
7
|
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
8
|
+
|
9
|
+
|
10
|
+
def compress_and_convert_image(image_path, output_path, max_size_mb=2):
    """Write an image to ``output_path``, shrinking oversized PNGs to JPEG.

    Args:
        image_path: source image file.
        output_path: destination path; its extension becomes ``.jpg`` when
            the image is converted.
        max_size_mb: PNG files larger than this many megabytes are converted
            to JPEG (quality 85).

    Returns:
        The path that was actually written.
    """
    size_limit = max_size_mb * 1024 * 1024

    # Opening the file still rejects anything PIL cannot read as an image.
    with Image.open(image_path) as img:
        if (
            image_path.lower().endswith(".png")
            and os.path.getsize(image_path) > size_limit
        ):
            # Convert the PNG to JPG and compress it
            output_path = output_path.rsplit(".", 1)[0] + ".jpg"
            img.convert("RGB").save(output_path, "JPEG", quality=85)
            return output_path

    # Copy the original bytes as they are. Saving through PIL would re-encode
    # the image, which degrades lossy formats instead of copying them.
    shutil.copyfile(image_path, output_path)
    return output_path
|
24
|
+
|
25
|
+
|
26
|
+
def process_image(data, shots_dir):
    """Place one record's image in ``shots_dir`` and rewrite its path.

    Returns:
        ``(data, file_name)`` after the image was written, with
        ``data["image"]`` updated to the path relative to ``shots_dir``;
        ``(data, None)`` when the record has no image path; ``None`` when the
        image file does not exist, already sits at the target path, or
        processing raised (the error is printed).
    """
    image_path = data.get("image")
    if not image_path:
        return data, None
    if not os.path.exists(image_path):
        return None

    target_path = os.path.join(shots_dir, os.path.basename(image_path))
    if target_path == image_path:
        return None

    try:
        target_path = compress_and_convert_image(image_path, target_path)
        data["image"] = os.path.relpath(target_path, shots_dir)
        return data, os.path.basename(target_path)
    except Exception as e:
        print(f"Error processing image {image_path}: {e}")
    return None
|
42
|
+
|
43
|
+
|
44
|
+
def update_image_paths(input_file, output_file, shots_dir, max_workers=8):
    """Process every record of a JSONL dataset and write the updated records.

    Each record is passed to ``process_image`` on a thread pool. Records
    without an image are kept; records with an image are kept once per
    resulting image file name. Output order follows task completion.

    Args:
        input_file: JSONL file to read (UTF-8).
        output_file: JSONL file to write (UTF-8).
        shots_dir: directory that receives the processed images; created
            when missing.
        max_workers: size of the thread pool.
    """
    os.makedirs(shots_dir, exist_ok=True)

    updated_data = []
    seen_images = set()

    # Explicit UTF-8 keeps non-ASCII text portable across platforms, and
    # blank lines (such as a trailing newline) are skipped instead of failing.
    with open(input_file, "r", encoding="utf-8") as infile:
        data_list = [json.loads(line) for line in infile if line.strip()]

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(process_image, data, shots_dir) for data in data_list
        ]

        for future in tqdm(
            as_completed(futures), total=len(futures), desc="Processing images"
        ):
            result = future.result()
            if not result:
                continue
            data, image_name = result
            if not image_name:
                updated_data.append(data)
            elif image_name not in seen_images:
                seen_images.add(image_name)
                updated_data.append(data)

    # ensure_ascii=False emits raw non-ASCII characters, so the output file
    # must be opened as UTF-8 rather than with the platform default.
    with open(output_file, "w", encoding="utf-8") as outfile:
        for item in tqdm(updated_data, desc="Writing updated data", unit="item"):
            json.dump(item, outfile, ensure_ascii=False)
            outfile.write("\n")
|
75
|
+
|
76
|
+
|
77
|
+
if __name__ == "__main__":
    import argparse

    # Command-line entry point: three string options and the worker count.
    parser = argparse.ArgumentParser(description="Update image paths in the dataset.")
    string_options = (
        ("--input_file", "dataset.updated.jsonl", "Path to the input JSONL file."),
        (
            "--output_file",
            "dataset.updated.formatted.jsonl",
            "Path to the output JSONL file.",
        ),
        ("--shots_dir", "shots", "Directory to save processed images."),
    )
    for flag, default_value, help_text in string_options:
        parser.add_argument(flag, type=str, default=default_value, help=help_text)
    parser.add_argument(
        "--max_workers", type=int, default=8, help="Maximum number of worker threads."
    )

    args = parser.parse_args()

    update_image_paths(
        args.input_file, args.output_file, args.shots_dir, args.max_workers
    )
|
memos/process_webp.py
ADDED
@@ -0,0 +1,55 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
from PIL import Image
|
3
|
+
import piexif
|
4
|
+
import json
|
5
|
+
from memos.utils import write_image_metadata, get_image_metadata
|
6
|
+
from tqdm import tqdm
|
7
|
+
|
8
|
+
def convert_webp_metadata(directory):
    """Rewrite old-style metadata of every WebP file under ``directory``.

    Files for which ``get_image_metadata`` already returns data are skipped.
    For the rest, the ``exif`` entry of the image info is decoded as UTF-8
    JSON and written back with ``write_image_metadata``. Files without
    usable metadata are skipped, and the reason is reported per file.
    """
    webp_files = list(Path(directory).glob('**/*.webp'))

    for webp_file in tqdm(webp_files, desc="Converting WebP metadata", unit="file"):
        try:
            # Try to get metadata using the new method
            new_metadata = get_image_metadata(webp_file)

            if new_metadata:
                tqdm.write(f"Skipping {webp_file}: Already in new format")
                continue

            # If new method fails, try to get old metadata. The file is
            # closed again before it is rewritten further down.
            with Image.open(webp_file) as img:
                old_metadata = img.info.get("exif", None)

            if old_metadata is None:
                tqdm.write(f"Skipping {webp_file}: No metadata found")
                continue

            if isinstance(old_metadata, bytes):
                try:
                    old_metadata = old_metadata.decode('utf-8')
                except UnicodeDecodeError:
                    tqdm.write(f"Skipping {webp_file}: Unable to decode metadata")
                    continue

            try:
                metadata = json.loads(old_metadata)
            except json.JSONDecodeError:
                tqdm.write(f"Skipping {webp_file}: Invalid metadata format")
                continue

            # Convert to new format
            write_image_metadata(webp_file, metadata)
            tqdm.write(f"Converted metadata for {webp_file}")

        except Exception as e:
            tqdm.write(f"Error processing {webp_file}: {str(e)}")
|
47
|
+
|
48
|
+
if __name__ == "__main__":
    import sys

    # Exactly one argument is expected: the directory to scan.
    if len(sys.argv) != 2:
        # Show the name the script was actually invoked with instead of a
        # hard-coded file name that does not match this module.
        print(f"Usage: python {sys.argv[0]} <directory>")
        sys.exit(1)

    directory = sys.argv[1]
    convert_webp_metadata(directory)
|