pensiev 0.25.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- memos/__init__.py +6 -0
- memos/cmds/__init__.py +0 -0
- memos/cmds/library.py +1289 -0
- memos/cmds/plugin.py +96 -0
- memos/commands.py +865 -0
- memos/config.py +225 -0
- memos/crud.py +605 -0
- memos/databases/__init__.py +0 -0
- memos/databases/initializers.py +481 -0
- memos/dataset_extractor_for_florence.py +165 -0
- memos/dataset_extractor_for_internvl2.py +192 -0
- memos/default_config.yaml +88 -0
- memos/embedding.py +129 -0
- memos/frame_extractor.py +53 -0
- memos/logging_config.py +35 -0
- memos/main.py +104 -0
- memos/migrations/alembic/README +1 -0
- memos/migrations/alembic/__pycache__/env.cpython-310.pyc +0 -0
- memos/migrations/alembic/env.py +108 -0
- memos/migrations/alembic/script.py.mako +30 -0
- memos/migrations/alembic/versions/00904ac8c6fc_add_indexes_to_entitymodel.py +63 -0
- memos/migrations/alembic/versions/04acdaf75664_add_indices_to_entitytags_and_metadata.py +86 -0
- memos/migrations/alembic/versions/12504c5b1d3c_add_extra_columns_for_embedding.py +67 -0
- memos/migrations/alembic/versions/31a1ad0e10b3_add_entity_plugin_status.py +71 -0
- memos/migrations/alembic/versions/__pycache__/00904ac8c6fc_add_indexes_to_entitymodel.cpython-310.pyc +0 -0
- memos/migrations/alembic/versions/__pycache__/04acdaf75664_add_indices_to_entitytags_and_metadata.cpython-310.pyc +0 -0
- memos/migrations/alembic/versions/__pycache__/12504c5b1d3c_add_extra_columns_for_embedding.cpython-310.pyc +0 -0
- memos/migrations/alembic/versions/__pycache__/20f5ecab014d_add_entity_plugin_status.cpython-310.pyc +0 -0
- memos/migrations/alembic/versions/__pycache__/31a1ad0e10b3_add_entity_plugin_status.cpython-310.pyc +0 -0
- memos/migrations/alembic/versions/__pycache__/4fcb062c5128_add_extra_columns_for_embedding.cpython-310.pyc +0 -0
- memos/migrations/alembic/versions/__pycache__/d10c55fbb7d2_add_index_for_entity_file_type_group_.cpython-310.pyc +0 -0
- memos/migrations/alembic/versions/__pycache__/f8f158182416_add_active_app_index.cpython-310.pyc +0 -0
- memos/migrations/alembic/versions/d10c55fbb7d2_add_index_for_entity_file_type_group_.py +44 -0
- memos/migrations/alembic/versions/f8f158182416_add_active_app_index.py +75 -0
- memos/migrations/alembic.ini +116 -0
- memos/migrations.py +19 -0
- memos/models.py +199 -0
- memos/plugins/__init__.py +0 -0
- memos/plugins/ocr/__init__.py +0 -0
- memos/plugins/ocr/main.py +251 -0
- memos/plugins/ocr/models/ch_PP-OCRv4_det_infer.onnx +0 -0
- memos/plugins/ocr/models/ch_PP-OCRv4_rec_infer.onnx +0 -0
- memos/plugins/ocr/models/ch_ppocr_mobile_v2.0_cls_train.onnx +0 -0
- memos/plugins/ocr/ppocr-gpu.yaml +43 -0
- memos/plugins/ocr/ppocr.yaml +44 -0
- memos/plugins/ocr/server.py +227 -0
- memos/plugins/ocr/temp_ppocr.yaml +42 -0
- memos/plugins/vlm/__init__.py +0 -0
- memos/plugins/vlm/main.py +251 -0
- memos/prepare_dataset.py +107 -0
- memos/process_webp.py +55 -0
- memos/read_metadata.py +32 -0
- memos/record.py +358 -0
- memos/schemas.py +289 -0
- memos/search.py +1198 -0
- memos/server.py +883 -0
- memos/shotsum.py +105 -0
- memos/shotsum_with_ocr.py +145 -0
- memos/simple_tokenizer/dict/README.md +31 -0
- memos/simple_tokenizer/dict/hmm_model.utf8 +34 -0
- memos/simple_tokenizer/dict/idf.utf8 +258826 -0
- memos/simple_tokenizer/dict/jieba.dict.utf8 +348982 -0
- memos/simple_tokenizer/dict/pos_dict/char_state_tab.utf8 +6653 -0
- memos/simple_tokenizer/dict/pos_dict/prob_emit.utf8 +166 -0
- memos/simple_tokenizer/dict/pos_dict/prob_start.utf8 +259 -0
- memos/simple_tokenizer/dict/pos_dict/prob_trans.utf8 +5222 -0
- memos/simple_tokenizer/dict/stop_words.utf8 +1534 -0
- memos/simple_tokenizer/dict/user.dict.utf8 +4 -0
- memos/simple_tokenizer/linux/libsimple.so +0 -0
- memos/simple_tokenizer/macos/libsimple.dylib +0 -0
- memos/simple_tokenizer/windows/simple.dll +0 -0
- memos/static/_app/immutable/assets/0.e250c031.css +1 -0
- memos/static/_app/immutable/assets/_layout.e7937cfe.css +1 -0
- memos/static/_app/immutable/chunks/index.5c08976b.js +1 -0
- memos/static/_app/immutable/chunks/index.60ee613b.js +4 -0
- memos/static/_app/immutable/chunks/runtime.a7926cf6.js +5 -0
- memos/static/_app/immutable/chunks/scheduler.5c1cff6e.js +1 -0
- memos/static/_app/immutable/chunks/singletons.583bdf4e.js +1 -0
- memos/static/_app/immutable/entry/app.666c1643.js +1 -0
- memos/static/_app/immutable/entry/start.aed5c701.js +3 -0
- memos/static/_app/immutable/nodes/0.5862ea38.js +7 -0
- memos/static/_app/immutable/nodes/1.35378a5e.js +1 -0
- memos/static/_app/immutable/nodes/2.1ccf9ea5.js +81 -0
- memos/static/_app/version.json +1 -0
- memos/static/app.html +36 -0
- memos/static/favicon.png +0 -0
- memos/static/logos/memos_logo_1024.png +0 -0
- memos/static/logos/memos_logo_1024@2x.png +0 -0
- memos/static/logos/memos_logo_128.png +0 -0
- memos/static/logos/memos_logo_128@2x.png +0 -0
- memos/static/logos/memos_logo_16.png +0 -0
- memos/static/logos/memos_logo_16@2x.png +0 -0
- memos/static/logos/memos_logo_256.png +0 -0
- memos/static/logos/memos_logo_256@2x.png +0 -0
- memos/static/logos/memos_logo_32.png +0 -0
- memos/static/logos/memos_logo_32@2x.png +0 -0
- memos/static/logos/memos_logo_512.png +0 -0
- memos/static/logos/memos_logo_512@2x.png +0 -0
- memos/static/logos/memos_logo_64.png +0 -0
- memos/static/logos/memos_logo_64@2x.png +0 -0
- memos/test_server.py +802 -0
- memos/utils.py +49 -0
- memos_ml_backends/florence2_server.py +176 -0
- memos_ml_backends/qwen2vl_server.py +182 -0
- memos_ml_backends/schemas.py +48 -0
- pensiev-0.25.5.dist-info/LICENSE +201 -0
- pensiev-0.25.5.dist-info/METADATA +541 -0
- pensiev-0.25.5.dist-info/RECORD +111 -0
- pensiev-0.25.5.dist-info/WHEEL +5 -0
- pensiev-0.25.5.dist-info/entry_points.txt +2 -0
- pensiev-0.25.5.dist-info/top_level.txt +2 -0
@@ -0,0 +1,192 @@
|
|
1
|
+
"""
|
2
|
+
Prepare image paths and OCR data that are later fed to InternVL2 to generate captions that better fit our needs.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import json
|
6
|
+
import argparse
|
7
|
+
from sqlalchemy.orm import sessionmaker
|
8
|
+
from memos.models import EntityModel, EntityMetadataModel
|
9
|
+
from memos.config import get_database_path
|
10
|
+
from sqlalchemy import create_engine
|
11
|
+
from tqdm import tqdm
|
12
|
+
from pathlib import Path
|
13
|
+
import argilla as rg
|
14
|
+
from PIL import Image
|
15
|
+
import io
|
16
|
+
|
17
|
+
|
18
|
+
def prepare_huggingface_dataset(output_file, batch_size=100, record_count=10000, libraries=None):
    """Export entities that have OCR results to a JSONL file.

    Each written line is a JSON object with the keys ``id``, ``image`` (the
    entity's file path) and ``ocr`` (the simplified OCR result).

    Args:
        output_file: Path of the JSONL file to write (overwritten if present).
        batch_size: Number of entities fetched from the database per query.
        record_count: Maximum number of records to write; values <= 0 write
            nothing.
        libraries: Optional collection of library ids; when given, only
            entities whose ``library_id`` is in it are considered.
    """
    db_path = get_database_path()
    engine = create_engine(f"sqlite:///{db_path}")
    Session = sessionmaker(bind=engine)

    def simplify_ocr_data(ocr_data):
        # Keep only the first and third box points (top-left and bottom-right
        # corners) plus the recognised text; the score is dropped.
        return [
            {
                "dt_boxes": [item["dt_boxes"][0], item["dt_boxes"][2]],
                "rec_txt": item["rec_txt"],
            }
            for item in ocr_data
        ]

    def build_record(entity):
        # Returns the JSON-serialisable record for an entity, or None when the
        # entity must be skipped.
        if any(tag.name == "low_info" for tag in entity.tags):
            return None

        metadata = {entry.key: entry.value for entry in entity.metadata_entries}
        ocr = metadata.get("ocr_result")
        if not ocr or not Path(entity.filepath).exists():
            return None

        return {
            "id": entity.id,
            "image": entity.filepath,
            "ocr": simplify_ocr_data(json.loads(ocr)),
        }

    with Session() as session, open(output_file, "w", encoding="utf-8") as f:
        query = session.query(EntityModel)

        # Restrict to the requested libraries, if any were given.
        if libraries:
            query = query.filter(EntityModel.library_id.in_(libraries))

        total = query.count()

        # LIMIT/OFFSET paging is only well defined on an ordered query; without
        # an ORDER BY the database may return rows in a different order per
        # page, which can duplicate or skip entities.
        query = query.order_by(EntityModel.id)

        inserted_records = 0

        # The context manager closes the bar even if a query or write fails.
        with tqdm(
            total=max(0, min(total, record_count)),
            desc="Processing entities",
            unit="entity",
        ) as progress_bar:
            for offset in range(0, total, batch_size):
                if inserted_records >= record_count:
                    break

                batch = query.limit(batch_size).offset(offset).all()

                for entity in batch:
                    # Checked before writing so record_count <= 0 emits nothing.
                    if inserted_records >= record_count:
                        break

                    # The bar advances once per visited entity, skipped or not.
                    progress_bar.update(1)

                    record = build_record(entity)
                    if record is None:
                        continue

                    f.write(json.dumps(record, ensure_ascii=False) + "\n")
                    inserted_records += 1

    print(f"Dataset saved to {output_file}")
|
92
|
+
|
93
|
+
|
94
|
+
def init_argilla_dataset(client, dataset_name="image_captioning"):
    """Return the Argilla dataset called ``dataset_name``, creating it if needed.

    The "argilla" workspace is created first when the client lookup returns
    None for it. A new dataset is configured with an ``image`` field, a
    ``filepath`` text field and one required markdown text question named
    ``text``.
    """
    workspace_name = "argilla"

    if client.workspaces(workspace_name) is None:
        new_workspace = rg.Workspace(name=workspace_name, client=client)
        new_workspace.create()
        print(f"Workspace created: {workspace_name}")

    existing = client.datasets(name=dataset_name)
    if existing is not None:
        return existing

    image_field = rg.ImageField(name="image")
    path_field = rg.TextField(name="filepath")
    caption_question = rg.TextQuestion(
        name="text",
        title="Description of the image",
        required=True,
        use_markdown=True,
    )
    settings = rg.Settings(
        fields=[image_field, path_field],
        questions=[caption_question],
    )

    created = rg.Dataset(
        name=dataset_name,
        workspace=workspace_name,
        settings=settings,
        client=client,
    )
    created.create()
    print(f"Dataset created: {dataset_name}")

    return created
|
132
|
+
|
133
|
+
|
134
|
+
def upload_to_argilla(
    input_file,
    batch_size=10,
    dataset_name="image_captioning",
    api_url="http://localhost:6900",
    api_key="argilla.apikey",
):
    """Upload a JSONL dataset to Argilla.

    Every non-blank line of ``input_file`` must be a JSON object with the keys
    ``id``, ``image`` (path of the image file) and ``answer`` (the caption that
    is logged as a suggestion attributed to the "internvl2" agent).

    Args:
        input_file: Path of the UTF-8 encoded JSONL file to upload.
        batch_size: Number of records sent to Argilla per ``log`` call.
        dataset_name: Name of the target Argilla dataset.
        api_url: Base URL of the Argilla server.
        api_key: API key used to authenticate against the Argilla server.
    """
    client = rg.Argilla(api_url=api_url, api_key=api_key)
    dataset = init_argilla_dataset(client, dataset_name)

    # Count with the same encoding as the read below, and close the handle.
    with open(input_file, "r", encoding="utf-8") as f:
        total_records = sum(1 for line in f if line.strip())

    records = []

    with open(input_file, "r", encoding="utf-8") as f, tqdm(
        total=total_records, desc="Uploading to Argilla", unit="record"
    ) as progress_bar:
        for line in f:
            # A blank line (e.g. a trailing newline) is not a record.
            if not line.strip():
                continue

            record_data = json.loads(line)

            # convert() returns an independent image, so the source file can be
            # closed right away instead of staying open for every record.
            with Image.open(record_data["image"]) as source_image:
                image = source_image.convert("RGB")
            image.thumbnail((1280, 1280))

            records.append(
                rg.Record(
                    id=str(record_data["id"]),
                    fields={
                        "image": image,
                        "filepath": record_data["image"],
                    },
                    suggestions=[
                        rg.Suggestion(
                            "text", record_data["answer"], score=1.0, agent="internvl2"
                        )
                    ],
                )
            )

            if len(records) >= batch_size:
                dataset.records.log(records)
                progress_bar.update(len(records))
                records = []

        # Flush the final, partially filled batch.
        if records:
            dataset.records.log(records)
            progress_bar.update(len(records))

    print(f"Dataset uploaded to Argilla: {dataset_name}")
|
180
|
+
|
181
|
+
|
182
|
+
if __name__ == "__main__":
    # Command-line entry point: export the dataset to a JSONL file.
    cli = argparse.ArgumentParser(description="Prepare and upload dataset")
    cli.add_argument(
        "--output_file", default="dataset.jsonl", help="Output file path"
    )
    cli.add_argument(
        "--size", type=int, default=10000, help="Number of records to extract"
    )
    cli.add_argument(
        "--libraries",
        nargs="+",
        type=int,
        help="List of library IDs to filter entities",
    )
    options = cli.parse_args()

    prepare_huggingface_dataset(
        options.output_file,
        record_count=options.size,
        libraries=options.libraries,
    )
    print(f"Dataset saved to {options.output_file}")
    # Uncomment the following line if you want to upload to Argilla
    # upload_to_argilla(options.output_file)
|
@@ -0,0 +1,88 @@
|
|
1
|
+
base_dir: ~/.memos
|
2
|
+
|
3
|
+
# Database settings
|
4
|
+
# Can be either:
|
5
|
+
# 1. A file path for SQLite (relative to base_dir): database.db
|
6
|
+
# 2. SQLite URL: sqlite:///absolute/path/to/database.db
|
7
|
+
# 3. PostgreSQL URL: postgresql://postgres:mysecretpassword@localhost:5432/postgres
|
8
|
+
database_path: database.db
|
9
|
+
|
10
|
+
default_library: screenshots
|
11
|
+
screenshots_dir: screenshots
|
12
|
+
|
13
|
+
server_host: 0.0.0.0
|
14
|
+
server_port: 8839
|
15
|
+
|
16
|
+
# Enable authentication by uncommenting the following lines
|
17
|
+
# auth_username: admin
|
18
|
+
# auth_password: changeme
|
19
|
+
|
20
|
+
default_plugins:
|
21
|
+
- builtin_ocr
|
22
|
+
# - builtin_vlm
|
23
|
+
|
24
|
+
# using ollama as the vlm server
|
25
|
+
vlm:
|
26
|
+
concurrency: 8
|
27
|
+
endpoint: http://localhost:11434
|
28
|
+
force_jpeg: true
|
29
|
+
modelname: minicpm-v
|
30
|
+
# Chinese version
|
31
|
+
prompt: 请帮描述这个图片中的内容,包括画面格局、出现的视觉元素等
|
32
|
+
# English version
|
33
|
+
prompt: Please describe the content of this image, including the layout and visual elements.
|
34
|
+
token: ''
|
35
|
+
|
36
|
+
# using local ocr
|
37
|
+
ocr:
|
38
|
+
concurrency: 8
|
39
|
+
# this is not used if use_local is true
|
40
|
+
endpoint: http://localhost:5555/predict
|
41
|
+
force_jpeg: false
|
42
|
+
token: ''
|
43
|
+
use_local: true
|
44
|
+
|
45
|
+
# using local embedding for English as the main language
|
46
|
+
embedding:
|
47
|
+
model: arkohut/jina-embeddings-v2-base-en
|
48
|
+
num_dim: 768
|
49
|
+
use_local: true
|
50
|
+
use_modelscope: false
|
51
|
+
|
52
|
+
watch:
|
53
|
+
# number of recent events to consider when calculating processing rates
|
54
|
+
rate_window_size: 10
|
55
|
+
# sparsity factor for file processing
|
56
|
+
# a higher value means less frequent processing
|
57
|
+
# 1.0 means process every file
|
58
|
+
sparsity_factor: 3.0
|
59
|
+
# initial processing interval: process one file with plugins for every N files
|
60
|
+
# but will be adjusted automatically based on the processing rate
|
61
|
+
# 1 means process the first file at the beginning
|
62
|
+
processing_interval: 12
|
63
|
+
|
64
|
+
# A watch config like this means process every file with plugins at the beginning
|
65
|
+
# but if files are processed more slowly than they are generated, the processing interval
|
66
|
+
# will be increased automatically
|
67
|
+
# watch:
|
68
|
+
# rate_window_size: 10
|
69
|
+
# sparsity_factor: 1.0
|
70
|
+
# processing_interval: 1
|
71
|
+
|
72
|
+
# using local embedding for Chinese as the main language
|
73
|
+
# embedding:
|
74
|
+
# model: arkohut/jina-embeddings-v2-base-zh
|
75
|
+
# num_dim: 768
|
76
|
+
# use_local: true
|
77
|
+
# use_modelscope: true
|
78
|
+
|
79
|
+
# using ollama embedding
|
80
|
+
# embedding:
|
81
|
+
# endpoint: http://localhost:11434/v1/embeddings # this is not used
|
82
|
+
# model: arkohut/gte-qwen2-1.5b-instruct:q8_0
|
83
|
+
# num_dim: 1536
|
84
|
+
# use_local: false
|
85
|
+
# use_modelscope: false
|
86
|
+
|
87
|
+
record_interval: 4 # seconds
|
88
|
+
facet: false # support facet filter
|
memos/embedding.py
ADDED
@@ -0,0 +1,129 @@
|
|
1
|
+
from typing import List
|
2
|
+
import numpy as np
|
3
|
+
from .config import settings
|
4
|
+
import logging
|
5
|
+
import httpx
|
6
|
+
import logfire
|
7
|
+
from functools import lru_cache
|
8
|
+
import hashlib
|
9
|
+
import json
|
10
|
+
|
11
|
+
# Configure logger
# logging.basicConfig runs at import time; per the standard library it does
# nothing when the root logger already has handlers.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables
# Both stay None until init_embedding_model() assigns them.
model = None
device = None
|
18
|
+
|
19
|
+
|
20
|
+
def init_embedding_model():
    """Load the configured SentenceTransformer into the module globals.

    Sets ``device`` to CUDA when available, otherwise MPS when available,
    otherwise CPU. The model is then built from
    ``settings.embedding.model``, either directly or from a ModelScope
    snapshot download when ``settings.embedding.use_modelscope`` is set, and
    moved to that device.
    """
    # Heavy dependencies are imported here rather than at module import time.
    import torch
    from sentence_transformers import SentenceTransformer

    global model, device

    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    device = torch.device(backend)

    if not settings.embedding.use_modelscope:
        model_dir = settings.embedding.model
        logger.info(f"Using model: {model_dir}")
    else:
        from modelscope import snapshot_download

        model_dir = snapshot_download(settings.embedding.model)
        logger.info(f"Model downloaded from ModelScope to: {model_dir}")

    model = SentenceTransformer(model_dir, trust_remote_code=True)
    model.to(device)
    logger.info(f"Embedding model initialized on device: {device}")
|
43
|
+
|
44
|
+
|
45
|
+
def generate_embeddings(texts: List[str]) -> List[List[float]]:
    """Encode ``texts`` with the local model and L2-normalise each vector.

    Args:
        texts: Strings to embed.

    Returns:
        One list of floats per input text; an empty list for empty input.
    """
    # Nothing to encode: return before touching the model so an empty request
    # does not trigger the (expensive) model initialisation.
    if not texts:
        return []

    if model is None:
        init_embedding_model()

    embeddings = model.encode(texts, convert_to_tensor=True, show_progress_bar=False)
    embeddings = embeddings.cpu().numpy()

    # Normalise each row; zero-length rows are divided by 1 and so stay zero
    # instead of producing NaN.
    norms = np.linalg.norm(embeddings, ord=2, axis=1, keepdims=True)
    norms[norms == 0] = 1
    embeddings = embeddings / norms

    return embeddings.tolist()
|
63
|
+
|
64
|
+
|
65
|
+
def _hash_texts(texts: List[str]) -> str:
    """Return the SHA-256 hex digest of the JSON serialisation of ``texts``."""
    serialized = json.dumps(texts, sort_keys=True).encode()
    digest = hashlib.sha256()
    digest.update(serialized)
    return digest.hexdigest()
|
69
|
+
|
70
|
+
|
71
|
+
@logfire.instrument
@lru_cache(maxsize=100)  # Cache last 100 requests
def get_embeddings_cached(texts_hash: str, texts_tuple: tuple) -> List[List[float]]:
    """Compute embeddings for ``texts_tuple``, memoised on both arguments.

    The arguments are hashable so that ``lru_cache`` can key on them. The
    local model is used when ``settings.embedding.use_local`` is set,
    otherwise the remote endpoint. Every value is rounded to 5 decimals.
    """
    # NOTE(review): whatever the backend returns is cached, including the
    # empty list get_remote_embeddings returns after a failed request.
    backend = generate_embeddings if settings.embedding.use_local else get_remote_embeddings
    vectors = backend(list(texts_tuple))

    rounded = []
    for vector in vectors:
        rounded.append([round(float(value), 5) for value in vector])
    return rounded
|
86
|
+
|
87
|
+
|
88
|
+
@logfire.instrument
def get_embeddings(texts: List[str]) -> List[List[float]]:
    """Return embeddings for ``texts`` through the cached implementation."""
    # The cached function needs hashable arguments, so it is given a digest of
    # the texts together with the texts as a tuple.
    return get_embeddings_cached(_hash_texts(texts), tuple(texts))
|
95
|
+
|
96
|
+
|
97
|
+
def get_remote_embeddings(texts: List[str]) -> List[List[float]]:
    """Fetch embeddings for ``texts`` from the configured remote endpoint.

    An endpoint ending in ``/embed`` is treated as Ollama-style (embeddings
    read from ``result["embeddings"]``); anything else is treated as an
    OpenAI-compatible API (embeddings read from ``result["data"]``).

    Returns:
        One embedding per text, or an empty list when the request fails or the
        server answers with an HTTP error status.
    """
    headers = {
        "Content-Type": "application/json"
    }

    token = settings.embedding.token.get_secret_value()
    if token:
        headers["Authorization"] = f"Bearer {token}"

    endpoint = settings.embedding.endpoint
    is_ollama = endpoint.endswith("/embed")

    if is_ollama:
        payload = {"model": settings.embedding.model, "input": texts}
    else:  # openai compatible api
        payload = {
            "input": texts,
            "model": settings.embedding.model,
            "encoding_format": "float"
        }

    try:
        with httpx.Client(timeout=60) as client:
            response = client.post(endpoint, json=payload, headers=headers)
            response.raise_for_status()
            result = response.json()
    except httpx.HTTPError as e:
        # httpx.HTTPError is the common base of RequestError (transport
        # failures) and HTTPStatusError (raised by raise_for_status), so 4xx
        # and 5xx answers are handled the same way as connection problems.
        logger.error(f"Error fetching embeddings from remote endpoint: {e}")
        return []  # Return an empty list instead of raising an exception

    if is_ollama:
        return result["embeddings"]
    # openai compatible api
    return [item["embedding"] for item in result["data"]]
|
memos/frame_extractor.py
ADDED
@@ -0,0 +1,53 @@
|
|
1
|
+
import cv2
|
2
|
+
from PIL import Image
|
3
|
+
from pathlib import Path
|
4
|
+
import argparse
|
5
|
+
|
6
|
+
|
7
|
+
def extract_video_frame(video_path: Path, frame_number: int) -> Image.Image:
    """Read frame ``frame_number`` of ``video_path`` and return it as an RGB image.

    Returns None when OpenCV reports that the frame could not be read.
    """
    capture = cv2.VideoCapture(str(video_path))
    capture.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
    ok, bgr_frame = capture.read()
    capture.release()

    if ok:
        # The frame is converted from BGR to RGB before PIL wraps it.
        return Image.fromarray(cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB))
    return None
|
18
|
+
|
19
|
+
|
20
|
+
def test_extract_video_frame(video_path: Path, frame_number: int):
    """Extract one frame, report the outcome and save it next to the video.

    The frame is saved as ``extracted_frame_<frame_number>.png`` in the
    directory of ``video_path``.
    """
    # Make sure the video file exists.
    if not video_path.is_file():
        print(f"Error: Video file not found at {video_path}")
        return

    # Try to extract the requested frame.
    frame_image = extract_video_frame(video_path, frame_number)
    if frame_image is None:
        print(f"Error: Failed to extract frame {frame_number} from video")
        return

    print(f"Successfully extracted frame {frame_number} from video")
    print(f"Frame dimensions: {frame_image.size}")

    # Save the extracted frame as an image file.
    output_path = video_path.with_name(f"extracted_frame_{frame_number}.png")
    frame_image.save(output_path)
    print(f"Extracted frame saved to: {output_path}")
|
39
|
+
|
40
|
+
|
41
|
+
if __name__ == "__main__":
    # Command-line entry point: extract one frame from the given video.
    cli = argparse.ArgumentParser(description="Extract a frame from a video file.")
    cli.add_argument("video_path", type=str, help="Path to the video file")
    cli.add_argument(
        "frame_number",
        type=int,
        help="Frame number to extract (0-based index)",
    )
    parsed = cli.parse_args()

    test_extract_video_frame(Path(parsed.video_path), parsed.frame_number)
|
memos/logging_config.py
ADDED
@@ -0,0 +1,35 @@
|
|
1
|
+
import logging
|
2
|
+
import sys
|
3
|
+
|
4
|
+
# Logging configuration laid out in the dictionary schema (version 1) that
# logging.config.dictConfig accepts: one "default" formatter and one "default"
# stdout handler, attached to the root logger and to "uvicorn.access".
LOGGING_CONFIG = dict(
    version=1,
    disable_existing_loggers=False,
    formatters={
        "default": dict(format="%(asctime)s - %(levelname)s - %(message)s"),
    },
    handlers={
        "default": {
            "class": "logging.StreamHandler",
            "formatter": "default",
            "level": "INFO",
            "stream": sys.stdout,
        },
    },
    loggers={
        # Root logger.
        "": dict(handlers=["default"], level="INFO", propagate=False),
        # Only the level is set here; no handler is attached explicitly.
        "uvicorn.error": dict(level="INFO"),
        "uvicorn.access": dict(handlers=["default"], level="INFO", propagate=False),
    },
)
|
memos/main.py
ADDED
@@ -0,0 +1,104 @@
|
|
1
|
+
from fastapi import FastAPI
|
2
|
+
from pydantic import BaseModel
|
3
|
+
from typing import List
|
4
|
+
import time
|
5
|
+
import random
|
6
|
+
from fastapi import Response, HTTPException
|
7
|
+
from datetime import datetime
|
8
|
+
|
9
|
+
# FastAPI application that the route decorators below register on.
app = FastAPI()
# In-memory store of Library objects; the handlers in this file read and
# mutate this list directly, and nothing here persists it.
libraries: List["Library"] = []
|
11
|
+
|
12
|
+
|
13
|
+
# A folder path together with the id of the library it is attached to.
class Folder(BaseModel):
    path: str
    # Set by the handlers in this file to the id of the owning Library.
    libraryId: int
|
16
|
+
|
17
|
+
|
18
|
+
# Library record as kept in the module-level ``libraries`` list and returned
# by the library endpoints.
class Library(BaseModel):
    id: int
    name: str
    # May be None, but is declared without a default value.
    description: str | None
    folders: List[Folder] = []
    # Defaults to None; no handler visible in this file assigns it.
    lastScanAt: datetime | None = None

    # Example payload supplied through the model's ``json_schema_extra``.
    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "id": 1,
                    "name": "Main Library",
                    "description": "A primary collection of various documents.",
                    "folders": [
                        {"path": "/documents/reports", "libraryId": 1},
                        {"path": "/documents/notes", "libraryId": 1},
                    ],
                    "lastScanAt": "2023-10-04T14:48:00",
                }
            ]
        }
    }
|
41
|
+
|
42
|
+
|
43
|
+
# Root endpoint: always returns the literal string "ok".
@app.get("/")
async def root():
    return "ok"
|
46
|
+
|
47
|
+
|
48
|
+
# Returns the module-level ``libraries`` list itself (not a copy); the
# response is declared as a list of Library.
@app.get("/libraries", response_model=List[Library])
async def get_libraries():
    return libraries
|
51
|
+
|
52
|
+
|
53
|
+
@app.get("/libraries/{library_id}", response_model=Library)
async def get_library(library_id: int):
    # Return the first stored library whose id matches; answer 404 otherwise.
    found = next((lib for lib in libraries if lib.id == library_id), None)
    if found is None:
        raise HTTPException(status_code=404, detail="Library not found")
    return found
|
59
|
+
|
60
|
+
|
61
|
+
# Request body used by the create and update library endpoints.
class LibraryParam(BaseModel):
    name: str
    # May be None, but is declared without a default value.
    description: str | None
    # Plain folder paths; the handlers wrap each one in a Folder.
    folders: List[str]
|
65
|
+
|
66
|
+
|
67
|
+
@app.post("/libraries", status_code=201)
async def create_library(library: LibraryParam):
    # Build a Library from the request body, store it in ``libraries`` and
    # return it.
    #
    # The id is derived from the current time plus a random offset, which can
    # repeat for libraries created close together; step past ids that are
    # already taken so two libraries never share one.
    taken_ids = {lib.id for lib in libraries}
    nextid = int(time.time()) + random.randint(1, 1000)
    while nextid in taken_ids:
        nextid += 1

    new_library = Library(
        id=nextid,
        name=library.name,
        description=library.description,
        folders=[Folder(path=path, libraryId=nextid) for path in library.folders],
    )
    libraries.append(new_library)
    return new_library
|
78
|
+
|
79
|
+
|
80
|
+
@app.put("/libraries/{library_id}")
async def update_library(library_id: int, library: LibraryParam):
    # Overwrite name, description and folders of the first library whose id
    # matches and answer 204; answer 404 when no library matches.
    target = next((lib for lib in libraries if lib.id == library_id), None)
    if target is None:
        raise HTTPException(status_code=404, detail="Library not found")

    target.name = library.name
    target.description = library.description
    target.folders = [
        Folder(path=folder_path, libraryId=library_id)
        for folder_path in library.folders
    ]
    return Response(status_code=204)
|
91
|
+
|
92
|
+
|
93
|
+
@app.delete("/libraries/{library_id}", status_code=204)
async def delete_library(library_id: int):
    # Drop the first library whose id matches and answer 204; answer 404 when
    # no library matches.
    for position, candidate in enumerate(libraries):
        if candidate.id != library_id:
            continue
        del libraries[position]
        return Response(status_code=204)
    raise HTTPException(status_code=404, detail="Library not found")
|
100
|
+
|
101
|
+
|
102
|
+
# Stub endpoint: the body does nothing, so the handler returns None with the
# declared 202 status. ``library_id`` carries no type annotation here.
@app.post("/libraries/{library_id}/scan_tasks", status_code=202)
async def request_scan_library(library_id):
    pass
|
@@ -0,0 +1 @@
|
|
1
|
+
Generic single-database configuration.
|
Binary file
|