pensiev-0.25.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111)
  1. memos/__init__.py +6 -0
  2. memos/cmds/__init__.py +0 -0
  3. memos/cmds/library.py +1289 -0
  4. memos/cmds/plugin.py +96 -0
  5. memos/commands.py +865 -0
  6. memos/config.py +225 -0
  7. memos/crud.py +605 -0
  8. memos/databases/__init__.py +0 -0
  9. memos/databases/initializers.py +481 -0
  10. memos/dataset_extractor_for_florence.py +165 -0
  11. memos/dataset_extractor_for_internvl2.py +192 -0
  12. memos/default_config.yaml +88 -0
  13. memos/embedding.py +129 -0
  14. memos/frame_extractor.py +53 -0
  15. memos/logging_config.py +35 -0
  16. memos/main.py +104 -0
  17. memos/migrations/alembic/README +1 -0
  18. memos/migrations/alembic/__pycache__/env.cpython-310.pyc +0 -0
  19. memos/migrations/alembic/env.py +108 -0
  20. memos/migrations/alembic/script.py.mako +30 -0
  21. memos/migrations/alembic/versions/00904ac8c6fc_add_indexes_to_entitymodel.py +63 -0
  22. memos/migrations/alembic/versions/04acdaf75664_add_indices_to_entitytags_and_metadata.py +86 -0
  23. memos/migrations/alembic/versions/12504c5b1d3c_add_extra_columns_for_embedding.py +67 -0
  24. memos/migrations/alembic/versions/31a1ad0e10b3_add_entity_plugin_status.py +71 -0
  25. memos/migrations/alembic/versions/__pycache__/00904ac8c6fc_add_indexes_to_entitymodel.cpython-310.pyc +0 -0
  26. memos/migrations/alembic/versions/__pycache__/04acdaf75664_add_indices_to_entitytags_and_metadata.cpython-310.pyc +0 -0
  27. memos/migrations/alembic/versions/__pycache__/12504c5b1d3c_add_extra_columns_for_embedding.cpython-310.pyc +0 -0
  28. memos/migrations/alembic/versions/__pycache__/20f5ecab014d_add_entity_plugin_status.cpython-310.pyc +0 -0
  29. memos/migrations/alembic/versions/__pycache__/31a1ad0e10b3_add_entity_plugin_status.cpython-310.pyc +0 -0
  30. memos/migrations/alembic/versions/__pycache__/4fcb062c5128_add_extra_columns_for_embedding.cpython-310.pyc +0 -0
  31. memos/migrations/alembic/versions/__pycache__/d10c55fbb7d2_add_index_for_entity_file_type_group_.cpython-310.pyc +0 -0
  32. memos/migrations/alembic/versions/__pycache__/f8f158182416_add_active_app_index.cpython-310.pyc +0 -0
  33. memos/migrations/alembic/versions/d10c55fbb7d2_add_index_for_entity_file_type_group_.py +44 -0
  34. memos/migrations/alembic/versions/f8f158182416_add_active_app_index.py +75 -0
  35. memos/migrations/alembic.ini +116 -0
  36. memos/migrations.py +19 -0
  37. memos/models.py +199 -0
  38. memos/plugins/__init__.py +0 -0
  39. memos/plugins/ocr/__init__.py +0 -0
  40. memos/plugins/ocr/main.py +251 -0
  41. memos/plugins/ocr/models/ch_PP-OCRv4_det_infer.onnx +0 -0
  42. memos/plugins/ocr/models/ch_PP-OCRv4_rec_infer.onnx +0 -0
  43. memos/plugins/ocr/models/ch_ppocr_mobile_v2.0_cls_train.onnx +0 -0
  44. memos/plugins/ocr/ppocr-gpu.yaml +43 -0
  45. memos/plugins/ocr/ppocr.yaml +44 -0
  46. memos/plugins/ocr/server.py +227 -0
  47. memos/plugins/ocr/temp_ppocr.yaml +42 -0
  48. memos/plugins/vlm/__init__.py +0 -0
  49. memos/plugins/vlm/main.py +251 -0
  50. memos/prepare_dataset.py +107 -0
  51. memos/process_webp.py +55 -0
  52. memos/read_metadata.py +32 -0
  53. memos/record.py +358 -0
  54. memos/schemas.py +289 -0
  55. memos/search.py +1198 -0
  56. memos/server.py +883 -0
  57. memos/shotsum.py +105 -0
  58. memos/shotsum_with_ocr.py +145 -0
  59. memos/simple_tokenizer/dict/README.md +31 -0
  60. memos/simple_tokenizer/dict/hmm_model.utf8 +34 -0
  61. memos/simple_tokenizer/dict/idf.utf8 +258826 -0
  62. memos/simple_tokenizer/dict/jieba.dict.utf8 +348982 -0
  63. memos/simple_tokenizer/dict/pos_dict/char_state_tab.utf8 +6653 -0
  64. memos/simple_tokenizer/dict/pos_dict/prob_emit.utf8 +166 -0
  65. memos/simple_tokenizer/dict/pos_dict/prob_start.utf8 +259 -0
  66. memos/simple_tokenizer/dict/pos_dict/prob_trans.utf8 +5222 -0
  67. memos/simple_tokenizer/dict/stop_words.utf8 +1534 -0
  68. memos/simple_tokenizer/dict/user.dict.utf8 +4 -0
  69. memos/simple_tokenizer/linux/libsimple.so +0 -0
  70. memos/simple_tokenizer/macos/libsimple.dylib +0 -0
  71. memos/simple_tokenizer/windows/simple.dll +0 -0
  72. memos/static/_app/immutable/assets/0.e250c031.css +1 -0
  73. memos/static/_app/immutable/assets/_layout.e7937cfe.css +1 -0
  74. memos/static/_app/immutable/chunks/index.5c08976b.js +1 -0
  75. memos/static/_app/immutable/chunks/index.60ee613b.js +4 -0
  76. memos/static/_app/immutable/chunks/runtime.a7926cf6.js +5 -0
  77. memos/static/_app/immutable/chunks/scheduler.5c1cff6e.js +1 -0
  78. memos/static/_app/immutable/chunks/singletons.583bdf4e.js +1 -0
  79. memos/static/_app/immutable/entry/app.666c1643.js +1 -0
  80. memos/static/_app/immutable/entry/start.aed5c701.js +3 -0
  81. memos/static/_app/immutable/nodes/0.5862ea38.js +7 -0
  82. memos/static/_app/immutable/nodes/1.35378a5e.js +1 -0
  83. memos/static/_app/immutable/nodes/2.1ccf9ea5.js +81 -0
  84. memos/static/_app/version.json +1 -0
  85. memos/static/app.html +36 -0
  86. memos/static/favicon.png +0 -0
  87. memos/static/logos/memos_logo_1024.png +0 -0
  88. memos/static/logos/memos_logo_1024@2x.png +0 -0
  89. memos/static/logos/memos_logo_128.png +0 -0
  90. memos/static/logos/memos_logo_128@2x.png +0 -0
  91. memos/static/logos/memos_logo_16.png +0 -0
  92. memos/static/logos/memos_logo_16@2x.png +0 -0
  93. memos/static/logos/memos_logo_256.png +0 -0
  94. memos/static/logos/memos_logo_256@2x.png +0 -0
  95. memos/static/logos/memos_logo_32.png +0 -0
  96. memos/static/logos/memos_logo_32@2x.png +0 -0
  97. memos/static/logos/memos_logo_512.png +0 -0
  98. memos/static/logos/memos_logo_512@2x.png +0 -0
  99. memos/static/logos/memos_logo_64.png +0 -0
  100. memos/static/logos/memos_logo_64@2x.png +0 -0
  101. memos/test_server.py +802 -0
  102. memos/utils.py +49 -0
  103. memos_ml_backends/florence2_server.py +176 -0
  104. memos_ml_backends/qwen2vl_server.py +182 -0
  105. memos_ml_backends/schemas.py +48 -0
  106. pensiev-0.25.5.dist-info/LICENSE +201 -0
  107. pensiev-0.25.5.dist-info/METADATA +541 -0
  108. pensiev-0.25.5.dist-info/RECORD +111 -0
  109. pensiev-0.25.5.dist-info/WHEEL +5 -0
  110. pensiev-0.25.5.dist-info/entry_points.txt +2 -0
  111. pensiev-0.25.5.dist-info/top_level.txt +2 -0
memos/plugins/ocr/server.py ADDED
@@ -0,0 +1,227 @@
+from PIL import Image
+import numpy as np
+import logging
+from fastapi import FastAPI, Body, HTTPException
+from contextlib import asynccontextmanager
+import base64
+import io
+import asyncio
+from pydantic import BaseModel, Field
+from typing import List
+from multiprocessing import Pool
+import threading
+import time
+import uvicorn
+import os
+
+# Configure logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Create the process pool
+process_pool = None
+
+# Read parameters from environment variables
+max_workers = int(os.getenv("MAX_WORKERS", 1))
+
+
+def str_to_bool(value):
+    return value.lower() in ("true", "1", "t", "y", "yes")
+
+
+use_gpu = str_to_bool(os.getenv("USE_GPU", "false"))
+
+
+@asynccontextmanager
+async def lifespan(app):
+    global process_pool
+    if process_pool is None:
+        init_process_pool(max_workers=max_workers, use_gpu=use_gpu)
+    yield
+    if process_pool:
+        logger.info("Shutting down process pool...")
+        process_pool.close()
+        process_pool.join()
+        logger.info("Process pool shut down.")
+
+
+app = FastAPI(lifespan=lifespan)
+
+
+def init_worker(use_gpu):
+    global ocr
+    ocr = init_ocr(use_gpu)
+
+
+def init_process_pool(max_workers, use_gpu):
+    global process_pool
+    process_pool = Pool(
+        processes=max_workers, initializer=init_worker, initargs=(use_gpu,)
+    )
+
+
+def init_ocr(use_gpu):
+    if use_gpu:
+        try:
+            from rapidocr_paddle import RapidOCR as RapidOCRPaddle
+
+            ocr = RapidOCRPaddle(
+                det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True
+            )
+            logger.info("Initialized OCR with RapidOCR Paddle (GPU)")
+        except ImportError:
+            logger.error(
+                "Failed to import rapidocr_paddle. Make sure it's installed for GPU usage."
+            )
+            raise
+    else:
+        try:
+            from rapidocr_onnxruntime import RapidOCR
+
+            ocr = RapidOCR()
+            logger.info("Initialized OCR with RapidOCR ONNX Runtime (CPU)")
+        except ImportError:
+            logger.error(
+                "Failed to import rapidocr_onnxruntime. Make sure it's installed for CPU usage."
+            )
+            raise
+    return ocr
+
+
+def convert_ocr_results(results):
+    if results is None:
+        return []
+
+    converted = []
+    for result in results:
+        item = {"dt_boxes": result[0], "rec_txt": result[1], "score": result[2]}
+        converted.append(item)
+    return converted
+
+
+def predict(image_data):
+    global ocr
+    if ocr is None:
+        raise ValueError("OCR engine not initialized")
+
+    image = Image.open(io.BytesIO(image_data))
+    img_array = np.array(image)
+    results, _ = ocr(img_array)
+    converted_results = convert_ocr_results(results)
+    return converted_results
+
+
+def convert_to_python_type(item):
+    if isinstance(item, np.ndarray):
+        return item.tolist()
+    elif isinstance(item, np.generic):  # This includes numpy scalars like numpy.float32
+        return item.item()
+    elif isinstance(item, list):
+        return [convert_to_python_type(sub_item) for sub_item in item]
+    elif isinstance(item, dict):
+        return {key: convert_to_python_type(value) for key, value in item.items()}
+    else:
+        return item
+
+
+async def async_predict(image_data):
+    loop = asyncio.get_running_loop()
+    results = await loop.run_in_executor(
+        None, process_pool.apply, predict, (image_data,)
+    )
+    return results
+
+
+class OCRResult(BaseModel):
+    dt_boxes: List[List[float]] = Field(..., description="Bounding box coordinates")
+    rec_txt: str = Field(..., description="Recognized text")
+    score: float = Field(..., description="Confidence score")
+
+
+@app.post("/predict", response_model=List[OCRResult])
+async def predict_base64(image_base64: str = Body(..., embed=True)):
+    try:
+        if not image_base64:
+            raise HTTPException(status_code=400, detail="Missing image_base64 field")
+
+        # Remove header part if present
+        if image_base64.startswith("data:image"):
+            image_base64 = image_base64.split(",")[1]
+
+        # Decode the base64 image
+        image_data = base64.b64decode(image_base64)
+
+        # Pass the image data directly to async_predict
+        ocr_result = await async_predict(image_data)
+
+        return convert_to_python_type(ocr_result)
+
+    except Exception as e:
+        logging.error(f"Error during OCR processing: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+shutdown_event = threading.Event()
+
+
+def signal_handler(signum, frame):
+    logger.info("Received interrupt signal. Initiating shutdown...")
+    shutdown_event.set()
+
+
+def run_server(app, host, port):
+    config = uvicorn.Config(app, host=host, port=port, loop="asyncio")
+    server = uvicorn.Server(config)
+    server.install_signal_handlers = (
+        lambda: None
+    )  # Disable Uvicorn's own signal handlers
+
+    async def serve():
+        await server.serve()
+
+    thread = threading.Thread(target=asyncio.run, args=(serve(),))
+    thread.start()
+
+    try:
+        while not shutdown_event.is_set():
+            time.sleep(1)
+    except KeyboardInterrupt:
+        logger.info("Keyboard interrupt received. Initiating shutdown...")
+    finally:
+        shutdown_event.set()
+        logger.info("Stopping the server...")
+        asyncio.run(server.shutdown())
+        thread.join()
+        logger.info("Server stopped.")
+
+
+if __name__ == "__main__":
+    import uvicorn
+    import argparse
+
+    parser = argparse.ArgumentParser(description="OCR Service")
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=8000,
+        help="Port to run the OCR service on",
+    )
+    parser.add_argument(
+        "--max-workers",
+        type=int,
+        default=1,
+        help="Maximum number of worker threads for OCR processing",
+    )
+    parser.add_argument(
+        "--gpu",
+        action="store_true",
+        help="Use GPU for OCR processing",
+    )
+
+    args = parser.parse_args()
+    port = args.port
+    max_workers = args.max_workers
+    use_gpu = args.gpu
+
+    run_server(app, "0.0.0.0", port)
+
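A minimal client sketch for exercising the /predict endpoint above, assuming the server is running locally on port 8000; the host, port, and image file name are illustrative and not part of the package.

# Hypothetical client for the OCR server's /predict endpoint. Because the
# handler declares Body(..., embed=True), the payload must wrap the string in
# an {"image_base64": ...} object.
import base64
import httpx

with open("screenshot.png", "rb") as f:  # illustrative input file
    image_base64 = base64.b64encode(f.read()).decode("utf-8")

resp = httpx.post(
    "http://127.0.0.1:8000/predict",
    json={"image_base64": image_base64},
    timeout=60,
)
resp.raise_for_status()
for item in resp.json():  # List[OCRResult]: dt_boxes, rec_txt, score
    print(item["rec_txt"], item["score"])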
memos/plugins/ocr/temp_ppocr.yaml ADDED
@@ -0,0 +1,42 @@
+Cls:
+  cls_batch_num: 6
+  cls_image_shape:
+  - 3
+  - 48
+  - 192
+  cls_thresh: 0.9
+  label_list:
+  - '0'
+  - '180'
+  model_path: /Users/shanchuanxu/projects/memos/memos/plugins/ocr/models/ch_ppocr_mobile_v2.0_cls_train.onnx
+  use_cuda: false
+Det:
+  box_thresh: 0.3
+  limit_side_len: 1500
+  limit_type: min
+  max_candidates: 1000
+  model_path: /Users/shanchuanxu/projects/memos/memos/plugins/ocr/models/ch_PP-OCRv4_det_infer.onnx
+  score_mode: fast
+  thresh: 0.3
+  unclip_ratio: 1.6
+  use_cuda: false
+  use_dilation: true
+Global:
+  max_side_len: 2000
+  min_height: 30
+  min_side_len: 30
+  print_verbose: false
+  text_score: 0.5
+  use_cls: true
+  use_det: true
+  use_rec: true
+  use_space_char: true
+  width_height_ratio: 40
+Rec:
+  model_path: /Users/shanchuanxu/projects/memos/memos/plugins/ocr/models/ch_PP-OCRv4_rec_infer.onnx
+  rec_batch_num: 6
+  rec_img_shape:
+  - 3
+  - 48
+  - 320
+  use_cuda: false
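For reference, a quick sketch of inspecting the thresholds in this RapidOCR config with PyYAML; the relative path below is illustrative (in the wheel the file ships under memos/plugins/ocr/), and PyYAML is assumed to be installed.

# Sketch only: load the bundled config and print a few key settings.
import yaml

with open("memos/plugins/ocr/temp_ppocr.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["Det"]["box_thresh"])     # 0.3 - detection box threshold
print(cfg["Global"]["text_score"])  # 0.5 - minimum text confidence kept
print(cfg["Rec"]["rec_img_shape"])  # [3, 48, 320] - recognizer input shape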
memos/plugins/vlm/__init__.py ADDED
File without changes
memos/plugins/vlm/main.py ADDED
@@ -0,0 +1,251 @@
+import base64
+import httpx
+from PIL import Image
+import asyncio
+from typing import Optional
+from fastapi import APIRouter, FastAPI, Request, HTTPException
+from memos.schemas import Entity, MetadataType
+import logging
+import uvicorn
+import os
+import io
+import numpy as np
+
+
+# Configure logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+PLUGIN_NAME = "vlm"
+
+router = APIRouter(tags=[PLUGIN_NAME], responses={404: {"description": "Not found"}})
+
+modelname = None
+endpoint = None
+token = None
+concurrency = None
+semaphore = None
+force_jpeg = None
+prompt = None
+
+
+def get_metadata_name() -> str:
+    """Return the metadata field name used by this plugin."""
+    global modelname
+    return f"{modelname.replace('-', '_')}_result"
+
+
+def image2base64(img_path):
+    try:
+        with Image.open(img_path) as img:
+            img.verify()  # Verify the image file
+
+        with Image.open(img_path) as img:
+            # Check image size and skip if it's too small
+            if img.width < 10 or img.height < 10:
+                logger.warning(f"Image is too small: {img.width}x{img.height}. Skipping processing.")
+                return None
+
+            # Convert image to RGB mode (removes alpha channel if present)
+            img = img.convert("RGB")
+
+            # Convert to numpy array and check shape
+            img_array = np.array(img)
+            logger.info(f"Image shape: {img_array.shape}")
+
+            if img_array.shape[2] != 3:
+                logger.warning(f"Unexpected number of channels: {img_array.shape[2]}. Expected 3. Skipping processing.")
+                return None
+
+            if force_jpeg:
+                # Save as JPEG in memory
+                buffer = io.BytesIO()
+                img.save(buffer, format="JPEG")
+                buffer.seek(0)
+                encoded_string = base64.b64encode(buffer.getvalue()).decode("utf-8")
+            else:
+                # Use original format, but ensure it's RGB
+                buffer = io.BytesIO()
+                img.save(buffer, format=img.format or "JPEG")
+                buffer.seek(0)
+                encoded_string = base64.b64encode(buffer.getvalue()).decode("utf-8")
+            return encoded_string
+    except Exception as e:
+        logger.error(f"Error processing image {img_path}: {str(e)}")
+        return None
+
+
+async def fetch(endpoint: str, client, request_data, headers: Optional[dict] = None):
+    async with semaphore:
+        try:
+            response = await client.post(
+                f"{endpoint}/v1/chat/completions",
+                json=request_data,
+                timeout=60,
+                headers=headers,
+            )
+            response.raise_for_status()
+            result = response.json()
+            choices = result.get("choices", [])
+            if (
+                choices
+                and "message" in choices[0]
+                and "content" in choices[0]["message"]
+            ):
+                return choices[0]["message"]["content"]
+            return ""
+        except Exception as e:
+            logger.error(f"Exception occurred: {str(e)}")
+            return None
+
+
+async def predict(
+    endpoint: str, modelname: str, img_path: str, token: Optional[str] = None
+) -> Optional[str]:
+    return await predict_remote(endpoint, modelname, img_path, token)
+
+
+async def predict_remote(
+    endpoint: str, modelname: str, img_path: str, token: Optional[str] = None
+) -> Optional[str]:
+    img_base64 = image2base64(img_path)
+    if not img_base64:
+        logger.warning(f"Skipping processing for file: {img_path} due to invalid or small image")
+        return None
+
+    mime_type = (
+        "image/jpeg" if force_jpeg else "image/jpeg"
+    )  # Default to JPEG if force_jpeg is True
+
+    if not force_jpeg:
+        # Only determine MIME type if not forcing JPEG
+        _, file_extension = os.path.splitext(img_path)
+        file_extension = file_extension.lower()[1:]
+        mime_types = {
+            "png": "image/png",
+            "jpg": "image/jpeg",
+            "jpeg": "image/jpeg",
+            "webp": "image/webp",
+        }
+        mime_type = mime_types.get(file_extension, "image/jpeg")
+
+    request_data = {
+        "model": modelname,
+        "messages": [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": f"data:{mime_type};base64,{img_base64}"},
+                    },
+                    {"type": "text", "text": prompt},  # Use the global prompt variable here
+                ],
+            }
+        ],
+        "stream": False,
+        "max_tokens": 1024,
+        "temperature": 0.1,
+        "repetition_penalty": 1.1,
+        "top_p": 0.8,
+    }
+    async with httpx.AsyncClient() as client:
+        headers = {}
+        if token:
+            headers["Authorization"] = f"Bearer {token.get_secret_value()}"
+        return await fetch(endpoint, client, request_data, headers=headers)
+
+
+@router.get("/")
+async def read_root():
+    return {"healthy": True}
+
+
+@router.post("", include_in_schema=False)
+@router.post("/")
+async def vlm(entity: Entity, request: Request):
+    global modelname, endpoint, token
+    metadata_field_name = get_metadata_name()
+    if not entity.file_type_group == "image":
+        return {metadata_field_name: ""}
+
+    # Check if the METADATA_FIELD_NAME field is empty or null
+    existing_metadata = entity.get_metadata_by_key(metadata_field_name)
+    if (
+        existing_metadata
+        and existing_metadata.value
+        and existing_metadata.value.strip()
+    ):
+        logger.info(
+            f"Skipping processing for file: {entity.filepath} due to existing metadata"
+        )
+        # If the field is not empty, return without processing
+        return {metadata_field_name: existing_metadata.value}
+
+    # Check if the entity contains the tag "low_info"
+    if any(tag.name == "low_info" for tag in entity.tags):
+        # If the tag is present, return without processing
+        logger.info(
+            f"Skipping processing for file: {entity.filepath} due to 'low_info' tag"
+        )
+        return {metadata_field_name: ""}
+
+    location_url = request.headers.get("Location")
+    if not location_url:
+        raise HTTPException(status_code=400, detail="Location header is missing")
+
+    patch_url = f"{location_url}/metadata"
+
+    vlm_result = await predict(endpoint, modelname, entity.filepath, token=token)
+
+    logger.info(f"VLM result: {vlm_result[:100]}...")
+    if not vlm_result:
+        logger.info(f"No VLM result found for file: {entity.filepath}")
+        return {metadata_field_name: "{}"}
+
+    async with httpx.AsyncClient() as client:
+        response = await client.patch(
+            patch_url,
+            json={
+                "metadata_entries": [
+                    {
+                        "key": metadata_field_name,
+                        "value": vlm_result,
+                        "source": PLUGIN_NAME,
+                        "data_type": MetadataType.TEXT_DATA.value,
+                    }
+                ]
+            },
+            timeout=30,
+        )

+    if response.status_code != 200:
+        raise HTTPException(
+            status_code=response.status_code, detail="Failed to patch entity metadata"
+        )
+
+    return {
+        metadata_field_name: vlm_result,
+    }
+
+
+def init_plugin(config):
+    global modelname, endpoint, token, concurrency, semaphore, force_jpeg, prompt
+
+    modelname = config.modelname
+    endpoint = config.endpoint
+    token = config.token
+    concurrency = config.concurrency
+    force_jpeg = config.force_jpeg
+    prompt = config.prompt
+    semaphore = asyncio.Semaphore(concurrency)
+
+    # Print the parameters
+    logger.info("VLM plugin initialized")
+    logger.info(f"Model Name: {modelname}")
+    logger.info(f"Endpoint: {endpoint}")
+    logger.info(f"Token: {token}")
+    logger.info(f"Concurrency: {concurrency}")
+    logger.info(f"Force JPEG: {force_jpeg}")
+    logger.info(f"Prompt: {prompt}")
+
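A sketch of how this plugin might be wired into a host FastAPI app, assuming an OpenAI-compatible chat-completions backend. The config object, model name, endpoint URL, and mount prefix below are placeholders rather than values defined by the package; init_plugin() only requires the attributes shown, and the token must expose get_secret_value(), so a pydantic SecretStr fits.

# Hypothetical wiring for the VLM plugin (all values are illustrative).
from types import SimpleNamespace

from fastapi import FastAPI
from pydantic import SecretStr

from memos.plugins.vlm import main as vlm_plugin

config = SimpleNamespace(
    modelname="qwen2-vl",                # placeholder model name
    endpoint="http://127.0.0.1:11434",   # placeholder OpenAI-compatible endpoint
    token=SecretStr("sk-..."),           # predict_remote() calls token.get_secret_value()
    concurrency=2,
    force_jpeg=True,
    prompt="Describe this screenshot in detail.",
)

vlm_plugin.init_plugin(config)

app = FastAPI()
app.include_router(vlm_plugin.router, prefix="/plugins/vlm")  # prefix is an assumption

Note that the POST handler reads a Location header from the incoming request and PATCHes {Location}/metadata, so it is designed to be invoked by the memos server rather than called directly by end users.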
memos/prepare_dataset.py ADDED
@@ -0,0 +1,107 @@
+import json
+import os
+import shutil
+from tqdm import tqdm
+from PIL import Image
+import io
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+
+def compress_and_convert_image(image_path, output_path, max_size_mb=2):
+    with Image.open(image_path) as img:
+        if (
+            image_path.lower().endswith(".png")
+            and os.path.getsize(image_path) > max_size_mb * 1024 * 1024
+        ):
+            # Convert the PNG to JPG and compress it
+            img = img.convert("RGB")
+            output_path = output_path.rsplit(".", 1)[0] + ".jpg"
+            img.save(output_path, "JPEG", quality=85)
+        else:
+            # Copy the original image as-is
+            img.save(output_path)
+    return output_path
+
+
+def process_image(data, shots_dir):
+    image_path = data.get("image")
+    if image_path and os.path.exists(image_path):
+        image_name = os.path.basename(image_path)
+        new_image_path = os.path.join(shots_dir, image_name)
+        if new_image_path == image_path:
+            return None
+        try:
+            new_image_path = compress_and_convert_image(image_path, new_image_path)
+            data["image"] = os.path.relpath(new_image_path, shots_dir)
+            return data, os.path.basename(new_image_path)
+        except Exception as e:
+            print(f"Error processing image {image_path}: {e}")
+    elif not image_path:
+        return data, None
+    return None
+
+
+def update_image_paths(input_file, output_file, shots_dir, max_workers=8):
+    if not os.path.exists(shots_dir):
+        os.makedirs(shots_dir)
+
+    updated_data = []
+    seen_images = set()
+
+    with open(input_file, "r") as infile:
+        data_list = [json.loads(line) for line in infile]
+
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        futures = [
+            executor.submit(process_image, data, shots_dir) for data in data_list
+        ]
+
+        for future in tqdm(
+            as_completed(futures), total=len(futures), desc="Processing images"
+        ):
+            result = future.result()
+            if result:
+                data, image_name = result
+                if image_name and image_name not in seen_images:
+                    seen_images.add(image_name)
+                    updated_data.append(data)
+                elif not image_name:
+                    updated_data.append(data)
+
+    with open(output_file, "w") as outfile:
+        for item in tqdm(updated_data, desc="Writing updated data", unit="item"):
+            json.dump(item, outfile, ensure_ascii=False)
+            outfile.write("\n")
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Update image paths in the dataset.")
+    parser.add_argument(
+        "--input_file",
+        type=str,
+        default="dataset.updated.jsonl",
+        help="Path to the input JSONL file.",
+    )
+    parser.add_argument(
+        "--output_file",
+        type=str,
+        default="dataset.updated.formatted.jsonl",
+        help="Path to the output JSONL file.",
+    )
+    parser.add_argument(
+        "--shots_dir",
+        type=str,
+        default="shots",
+        help="Directory to save processed images.",
+    )
+    parser.add_argument(
+        "--max_workers", type=int, default=8, help="Maximum number of worker threads."
+    )
+
+    args = parser.parse_args()
+
+    update_image_paths(
+        args.input_file, args.output_file, args.shots_dir, args.max_workers
+    )
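The same conversion can be driven from Python instead of the CLI; a short sketch using the script's own argparse defaults (the JSONL file names are those defaults and may not exist in your environment):

# Sketch: call update_image_paths() directly with the default paths above.
from memos.prepare_dataset import update_image_paths

update_image_paths(
    input_file="dataset.updated.jsonl",
    output_file="dataset.updated.formatted.jsonl",
    shots_dir="shots",
    max_workers=8,
)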
memos/process_webp.py ADDED
@@ -0,0 +1,55 @@
+from pathlib import Path
+from PIL import Image
+import piexif
+import json
+from memos.utils import write_image_metadata, get_image_metadata
+from tqdm import tqdm
+
+def convert_webp_metadata(directory):
+    webp_files = list(Path(directory).glob('**/*.webp'))
+
+    for webp_file in tqdm(webp_files, desc="Converting WebP metadata", unit="file"):
+        try:
+            # Try to get metadata using the new method
+            new_metadata = get_image_metadata(webp_file)
+
+            if new_metadata:
+                tqdm.write(f"Skipping {webp_file}: Already in new format")
+                continue
+
+            # If new method fails, try to get old metadata
+            img = Image.open(webp_file)
+            old_metadata = img.info.get("exif", None)
+
+            if old_metadata is None:
+                tqdm.write(f"Skipping {webp_file}: No metadata found")
+                continue
+
+            if isinstance(old_metadata, bytes):
+                try:
+                    old_metadata = old_metadata.decode('utf-8')
+                except UnicodeDecodeError:
+                    tqdm.write(f"Skipping {webp_file}: Unable to decode metadata")
+                    continue
+
+            try:
+                metadata = json.loads(old_metadata)
+            except json.JSONDecodeError:
+                tqdm.write(f"Skipping {webp_file}: Invalid metadata format")
+                continue
+
+            # Convert to new format
+            write_image_metadata(webp_file, metadata)
+            tqdm.write(f"Converted metadata for {webp_file}")
+
+        except Exception as e:
+            tqdm.write(f"Error processing {webp_file}: {str(e)}")
+
+if __name__ == "__main__":
+    import sys
+    if len(sys.argv) != 2:
+        print("Usage: python convert_webp_metadata.py <directory>")
+        sys.exit(1)
+
+    directory = sys.argv[1]
+    convert_webp_metadata(directory)
+ convert_webp_metadata(directory)