pensiev-0.25.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- memos/__init__.py +6 -0
- memos/cmds/__init__.py +0 -0
- memos/cmds/library.py +1289 -0
- memos/cmds/plugin.py +96 -0
- memos/commands.py +865 -0
- memos/config.py +225 -0
- memos/crud.py +605 -0
- memos/databases/__init__.py +0 -0
- memos/databases/initializers.py +481 -0
- memos/dataset_extractor_for_florence.py +165 -0
- memos/dataset_extractor_for_internvl2.py +192 -0
- memos/default_config.yaml +88 -0
- memos/embedding.py +129 -0
- memos/frame_extractor.py +53 -0
- memos/logging_config.py +35 -0
- memos/main.py +104 -0
- memos/migrations/alembic/README +1 -0
- memos/migrations/alembic/__pycache__/env.cpython-310.pyc +0 -0
- memos/migrations/alembic/env.py +108 -0
- memos/migrations/alembic/script.py.mako +30 -0
- memos/migrations/alembic/versions/00904ac8c6fc_add_indexes_to_entitymodel.py +63 -0
- memos/migrations/alembic/versions/04acdaf75664_add_indices_to_entitytags_and_metadata.py +86 -0
- memos/migrations/alembic/versions/12504c5b1d3c_add_extra_columns_for_embedding.py +67 -0
- memos/migrations/alembic/versions/31a1ad0e10b3_add_entity_plugin_status.py +71 -0
- memos/migrations/alembic/versions/__pycache__/00904ac8c6fc_add_indexes_to_entitymodel.cpython-310.pyc +0 -0
- memos/migrations/alembic/versions/__pycache__/04acdaf75664_add_indices_to_entitytags_and_metadata.cpython-310.pyc +0 -0
- memos/migrations/alembic/versions/__pycache__/12504c5b1d3c_add_extra_columns_for_embedding.cpython-310.pyc +0 -0
- memos/migrations/alembic/versions/__pycache__/20f5ecab014d_add_entity_plugin_status.cpython-310.pyc +0 -0
- memos/migrations/alembic/versions/__pycache__/31a1ad0e10b3_add_entity_plugin_status.cpython-310.pyc +0 -0
- memos/migrations/alembic/versions/__pycache__/4fcb062c5128_add_extra_columns_for_embedding.cpython-310.pyc +0 -0
- memos/migrations/alembic/versions/__pycache__/d10c55fbb7d2_add_index_for_entity_file_type_group_.cpython-310.pyc +0 -0
- memos/migrations/alembic/versions/__pycache__/f8f158182416_add_active_app_index.cpython-310.pyc +0 -0
- memos/migrations/alembic/versions/d10c55fbb7d2_add_index_for_entity_file_type_group_.py +44 -0
- memos/migrations/alembic/versions/f8f158182416_add_active_app_index.py +75 -0
- memos/migrations/alembic.ini +116 -0
- memos/migrations.py +19 -0
- memos/models.py +199 -0
- memos/plugins/__init__.py +0 -0
- memos/plugins/ocr/__init__.py +0 -0
- memos/plugins/ocr/main.py +251 -0
- memos/plugins/ocr/models/ch_PP-OCRv4_det_infer.onnx +0 -0
- memos/plugins/ocr/models/ch_PP-OCRv4_rec_infer.onnx +0 -0
- memos/plugins/ocr/models/ch_ppocr_mobile_v2.0_cls_train.onnx +0 -0
- memos/plugins/ocr/ppocr-gpu.yaml +43 -0
- memos/plugins/ocr/ppocr.yaml +44 -0
- memos/plugins/ocr/server.py +227 -0
- memos/plugins/ocr/temp_ppocr.yaml +42 -0
- memos/plugins/vlm/__init__.py +0 -0
- memos/plugins/vlm/main.py +251 -0
- memos/prepare_dataset.py +107 -0
- memos/process_webp.py +55 -0
- memos/read_metadata.py +32 -0
- memos/record.py +358 -0
- memos/schemas.py +289 -0
- memos/search.py +1198 -0
- memos/server.py +883 -0
- memos/shotsum.py +105 -0
- memos/shotsum_with_ocr.py +145 -0
- memos/simple_tokenizer/dict/README.md +31 -0
- memos/simple_tokenizer/dict/hmm_model.utf8 +34 -0
- memos/simple_tokenizer/dict/idf.utf8 +258826 -0
- memos/simple_tokenizer/dict/jieba.dict.utf8 +348982 -0
- memos/simple_tokenizer/dict/pos_dict/char_state_tab.utf8 +6653 -0
- memos/simple_tokenizer/dict/pos_dict/prob_emit.utf8 +166 -0
- memos/simple_tokenizer/dict/pos_dict/prob_start.utf8 +259 -0
- memos/simple_tokenizer/dict/pos_dict/prob_trans.utf8 +5222 -0
- memos/simple_tokenizer/dict/stop_words.utf8 +1534 -0
- memos/simple_tokenizer/dict/user.dict.utf8 +4 -0
- memos/simple_tokenizer/linux/libsimple.so +0 -0
- memos/simple_tokenizer/macos/libsimple.dylib +0 -0
- memos/simple_tokenizer/windows/simple.dll +0 -0
- memos/static/_app/immutable/assets/0.e250c031.css +1 -0
- memos/static/_app/immutable/assets/_layout.e7937cfe.css +1 -0
- memos/static/_app/immutable/chunks/index.5c08976b.js +1 -0
- memos/static/_app/immutable/chunks/index.60ee613b.js +4 -0
- memos/static/_app/immutable/chunks/runtime.a7926cf6.js +5 -0
- memos/static/_app/immutable/chunks/scheduler.5c1cff6e.js +1 -0
- memos/static/_app/immutable/chunks/singletons.583bdf4e.js +1 -0
- memos/static/_app/immutable/entry/app.666c1643.js +1 -0
- memos/static/_app/immutable/entry/start.aed5c701.js +3 -0
- memos/static/_app/immutable/nodes/0.5862ea38.js +7 -0
- memos/static/_app/immutable/nodes/1.35378a5e.js +1 -0
- memos/static/_app/immutable/nodes/2.1ccf9ea5.js +81 -0
- memos/static/_app/version.json +1 -0
- memos/static/app.html +36 -0
- memos/static/favicon.png +0 -0
- memos/static/logos/memos_logo_1024.png +0 -0
- memos/static/logos/memos_logo_1024@2x.png +0 -0
- memos/static/logos/memos_logo_128.png +0 -0
- memos/static/logos/memos_logo_128@2x.png +0 -0
- memos/static/logos/memos_logo_16.png +0 -0
- memos/static/logos/memos_logo_16@2x.png +0 -0
- memos/static/logos/memos_logo_256.png +0 -0
- memos/static/logos/memos_logo_256@2x.png +0 -0
- memos/static/logos/memos_logo_32.png +0 -0
- memos/static/logos/memos_logo_32@2x.png +0 -0
- memos/static/logos/memos_logo_512.png +0 -0
- memos/static/logos/memos_logo_512@2x.png +0 -0
- memos/static/logos/memos_logo_64.png +0 -0
- memos/static/logos/memos_logo_64@2x.png +0 -0
- memos/test_server.py +802 -0
- memos/utils.py +49 -0
- memos_ml_backends/florence2_server.py +176 -0
- memos_ml_backends/qwen2vl_server.py +182 -0
- memos_ml_backends/schemas.py +48 -0
- pensiev-0.25.5.dist-info/LICENSE +201 -0
- pensiev-0.25.5.dist-info/METADATA +541 -0
- pensiev-0.25.5.dist-info/RECORD +111 -0
- pensiev-0.25.5.dist-info/WHEEL +5 -0
- pensiev-0.25.5.dist-info/entry_points.txt +2 -0
- pensiev-0.25.5.dist-info/top_level.txt +2 -0
memos/search.py
ADDED
@@ -0,0 +1,1198 @@
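Orientation for the listing below: memos/search.py implements hybrid retrieval behind a common SearchProvider interface, with two concrete backends — a PostgreSQL provider that uses tsvector full-text search (websearch_to_tsquery) plus a vector(...) embedding column in the pgvector style, and a SQLite provider that uses FTS5-style MATCH queries plus sqlite-vec embeddings. Each query runs both channels and merges them with weighted reciprocal rank fusion.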
from abc import ABC, abstractmethod
from sqlalchemy import text, bindparam
from sqlalchemy.orm import Session
from typing import List, Optional, Tuple
import time
import logging
import logfire
from sqlite_vec import serialize_float32
from collections import defaultdict
from datetime import datetime
from .embedding import get_embeddings
import json
import jieba
import os

logger = logging.getLogger(__name__)


class SearchProvider(ABC):
    @abstractmethod
    def full_text_search(
        self,
        query: str,
        db: Session,
        limit: int,
        library_ids: Optional[List[int]] = None,
        start: Optional[int] = None,
        end: Optional[int] = None,
        app_names: Optional[List[str]] = None,
    ) -> List[int]:
        pass

    @abstractmethod
    def vector_search(
        self,
        embeddings: List[float],
        db: Session,
        limit: int,
        library_ids: Optional[List[int]] = None,
        start: Optional[int] = None,
        end: Optional[int] = None,
        app_names: Optional[List[str]] = None,
    ) -> List[int]:
        pass

    @abstractmethod
    def update_entity_index(self, entity_id: int, db: Session):
        """Update both FTS and vector indexes for an entity"""
        pass

    @abstractmethod
    def batch_update_entity_indices(self, entity_ids: List[int], db: Session):
        """Batch update both FTS and vector indexes for multiple entities"""
        pass

    @abstractmethod
    def get_search_stats(
        self,
        query: str,
        db: Session,
        library_ids: Optional[List[int]] = None,
        start: Optional[int] = None,
        end: Optional[int] = None,
        app_names: Optional[List[str]] = None,
    ) -> dict:
        """Get statistics for search results including date range and app name counts."""
        pass

    def prepare_vec_data(self, entity) -> str:
        """Prepare metadata for vector embedding.

        Args:
            entity: The entity object containing metadata entries

        Returns:
            str: Processed metadata string for vector embedding
        """
        vec_metadata = "\n".join(
            [
                f"{entry.key}: {entry.value}"
                for entry in entity.metadata_entries
                if entry.key not in ["ocr_result", "sequence"]
            ]
        )
        ocr_result = next(
            (
                entry.value
                for entry in entity.metadata_entries
                if entry.key == "ocr_result"
            ),
            "",
        )
        vec_metadata += (
            f"\nocr_result: {self.process_ocr_result(ocr_result, max_length=128)}"
        )
        return vec_metadata

    def process_ocr_result(self, value, max_length=4096):
        """Process OCR result data.

        Args:
            value: OCR result data as string
            max_length: Maximum number of items to process

        Returns:
            str: Processed OCR result
        """
        try:
            ocr_data = json.loads(value)
            if isinstance(ocr_data, list) and all(
                isinstance(item, dict)
                and "dt_boxes" in item
                and "rec_txt" in item
                and "score" in item
                for item in ocr_data
            ):
                return " ".join(item["rec_txt"] for item in ocr_data[:max_length])
            else:
                return json.dumps(ocr_data, indent=2)
        except json.JSONDecodeError:
            return value
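
To make the shared helpers concrete: process_ocr_result expects the OCR plugin's JSON output, a list of records with "dt_boxes", "rec_txt", and "score" keys, and collapses it to a plain string of recognized text; anything else is pretty-printed, and non-JSON input passes through unchanged. A minimal sketch with a hypothetical payload (not taken from the package):

    import json

    # Two fake OCR records in the exact shape the code checks for.
    payload = json.dumps([
        {"dt_boxes": [[0, 0], [10, 0], [10, 5], [0, 5]], "rec_txt": "Hello", "score": 0.98},
        {"dt_boxes": [[0, 6], [10, 6], [10, 11], [0, 11]], "rec_txt": "world", "score": 0.95},
    ])
    # process_ocr_result(payload)      -> "Hello world"
    # process_ocr_result('{"a": 1}')   -> pretty-printed JSON
    # process_ocr_result("not json")   -> "not json"

Note the asymmetric truncation: prepare_vec_data caps the OCR text at 128 items before embedding, while the FTS paths below use the default of up to 4096 items, so the vector index sees a summary and the keyword index sees nearly everything.
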
class PostgreSQLSearchProvider(SearchProvider):
    """
    PostgreSQL implementation of SearchProvider.
    """

    def tokenize_text(self, text: str) -> str:
        """Tokenize text using jieba for both Chinese and English text."""
        if not text:
            return ""
        # Tokenize the text using jieba
        words = jieba.cut(text)
        # Join with spaces for PostgreSQL full-text search
        return " ".join(words)

    def prepare_fts_data(self, entity) -> tuple[str, str, str]:
        """Prepare data for full-text search with jieba tokenization."""
        # Process filepath: keep directory structure but normalize separators
        # Also extract the filename without extension for better searchability
        filepath = entity.filepath.replace("\\", "/")  # normalize separators
        filename = os.path.basename(filepath)
        filename_without_ext = os.path.splitext(filename)[0]
        # Split filename by common separators (-, _, etc) to make parts searchable
        filename_parts = filename_without_ext.replace("-", " ").replace("_", " ")
        processed_filepath = f"{filepath} {filename_parts}"

        # Tokenize tags
        tags = " ".join(entity.tag_names)
        tokenized_tags = self.tokenize_text(tags)

        # Tokenize metadata
        metadata_entries = [
            f"{entry.key}: {self.process_ocr_result(entry.value) if entry.key == 'ocr_result' else entry.value}"
            for entry in entity.metadata_entries
        ]
        metadata = "\n".join(metadata_entries)
        tokenized_metadata = self.tokenize_text(metadata)

        return processed_filepath, tokenized_tags, tokenized_metadata

    def update_entity_index(self, entity_id: int, db: Session):
        """Update both FTS and vector indexes for an entity"""
        try:
            from .crud import get_entity_by_id

            entity = get_entity_by_id(entity_id, db)
            if not entity:
                raise ValueError(f"Entity with id {entity_id} not found")

            # Update FTS index with tokenized data
            processed_filepath, tokenized_tags, tokenized_metadata = (
                self.prepare_fts_data(entity)
            )

            db.execute(
                text(
                    """
                    INSERT INTO entities_fts (id, filepath, tags, metadata)
                    VALUES (:id, :filepath, :tags, :metadata)
                    ON CONFLICT (id) DO UPDATE SET
                        filepath = :filepath,
                        tags = :tags,
                        metadata = :metadata
                    """
                ),
                {
                    "id": entity.id,
                    "filepath": processed_filepath,
                    "tags": tokenized_tags,
                    "metadata": tokenized_metadata,
                },
            )

            # Update vector index
            vec_metadata = self.prepare_vec_data(entity)
            with logfire.span("get embedding for entity metadata"):
                embeddings = get_embeddings([vec_metadata])
                logfire.info(f"vec_metadata: {vec_metadata}")

            if embeddings and embeddings[0]:
                # Extract app_name from metadata_entries
                app_name = next(
                    (
                        entry.value
                        for entry in entity.metadata_entries
                        if entry.key == "active_app"
                    ),
                    "unknown",  # Default to 'unknown' if not found
                )
                # Get file_type_group from entity
                file_type_group = entity.file_type_group or "unknown"

                # Convert file_created_at to integer timestamp
                created_at_timestamp = int(datetime.now().timestamp())
                file_created_at_timestamp = int(entity.file_created_at.timestamp())
                file_created_at_date = entity.file_created_at.strftime("%Y-%m-%d")

                db.execute(
                    text(
                        """
                        INSERT INTO entities_vec_v2 (
                            rowid, embedding, app_name, file_type_group,
                            created_at_timestamp, file_created_at_timestamp,
                            file_created_at_date, library_id
                        )
                        VALUES (
                            :id, vector(:embedding), :app_name, :file_type_group,
                            :created_at_timestamp, :file_created_at_timestamp,
                            :file_created_at_date, :library_id
                        )
                        ON CONFLICT (rowid) DO UPDATE SET
                            embedding = vector(:embedding),
                            app_name = :app_name,
                            file_type_group = :file_type_group,
                            created_at_timestamp = :created_at_timestamp,
                            file_created_at_timestamp = :file_created_at_timestamp,
                            file_created_at_date = :file_created_at_date,
                            library_id = :library_id
                        """
                    ),
                    {
                        "id": entity.id,
                        "embedding": str(
                            embeddings[0]
                        ),  # Convert to string for PostgreSQL vector type
                        "app_name": app_name,
                        "file_type_group": file_type_group,
                        "created_at_timestamp": created_at_timestamp,
                        "file_created_at_timestamp": file_created_at_timestamp,
                        "file_created_at_date": file_created_at_date,
                        "library_id": entity.library_id,
                    },
                )

            db.commit()
        except Exception as e:
            logger.error(f"Error updating indexes for entity {entity_id}: {e}")
            db.rollback()
            raise

    def batch_update_entity_indices(self, entity_ids: List[int], db: Session):
        """Batch update both FTS and vector indexes for multiple entities"""
        try:
            from sqlalchemy.orm import selectinload
            from .models import EntityModel

            entities = (
                db.query(EntityModel)
                .filter(EntityModel.id.in_(entity_ids))
                .options(
                    selectinload(EntityModel.metadata_entries),
                    selectinload(EntityModel.tags),
                )
                .all()
            )
            found_ids = {entity.id for entity in entities}

            missing_ids = set(entity_ids) - found_ids
            if missing_ids:
                raise ValueError(f"Entities not found: {missing_ids}")

            # Check existing vector indices and their timestamps
            existing_vec_indices = db.execute(
                text(
                    """
                    SELECT rowid, created_at_timestamp
                    FROM entities_vec_v2
                    WHERE rowid = ANY(:entity_ids)
                    """
                ),
                {"entity_ids": entity_ids},
            ).fetchall()

            # Create lookup of vector index timestamps
            vec_timestamps = {row[0]: row[1] for row in existing_vec_indices}

            # Separate entities that need indexing
            needs_index = []
            for entity in entities:
                entity_last_scan = int(entity.last_scan_at.timestamp())
                vec_timestamp = vec_timestamps.get(entity.id, 0)

                # Entity needs full indexing if last_scan_at is
                # more recent than the vector index timestamp
                if entity_last_scan > vec_timestamp:
                    needs_index.append(entity)

            logfire.info(
                f"Entities needing full indexing: {len(needs_index)}/{len(entity_ids)}"
            )

            # Update vector index only for entities that need it
            if needs_index:
                vec_metadata_list = [
                    self.prepare_vec_data(entity) for entity in needs_index
                ]
                with logfire.span("get embedding in batch indexing"):
                    embeddings = get_embeddings(vec_metadata_list)
                    logfire.info(f"vec_metadata_list: {vec_metadata_list}")

                # Prepare batch insert data for vector index
                created_at_timestamp = int(datetime.now().timestamp())
                insert_values = []
                for entity, embedding in zip(needs_index, embeddings):
                    if embedding:
                        app_name = next(
                            (
                                entry.value
                                for entry in entity.metadata_entries
                                if entry.key == "active_app"
                            ),
                            "unknown",
                        )
                        file_type_group = entity.file_type_group or "unknown"
                        file_created_at_timestamp = int(
                            entity.file_created_at.timestamp()
                        )
                        file_created_at_date = entity.file_created_at.strftime(
                            "%Y-%m-%d"
                        )

                        insert_values.append(
                            {
                                "id": entity.id,
                                "embedding": str(
                                    embedding
                                ),  # Convert to string for PostgreSQL vector type
                                "app_name": app_name,
                                "file_type_group": file_type_group,
                                "created_at_timestamp": created_at_timestamp,
                                "file_created_at_timestamp": file_created_at_timestamp,
                                "file_created_at_date": file_created_at_date,
                                "library_id": entity.library_id,
                            }
                        )

                # Batch insert/update vector index
                if insert_values:
                    db.execute(
                        text(
                            """
                            INSERT INTO entities_vec_v2 (
                                rowid, embedding, app_name, file_type_group,
                                created_at_timestamp, file_created_at_timestamp,
                                file_created_at_date, library_id
                            )
                            VALUES (
                                :id, vector(:embedding), :app_name, :file_type_group,
                                :created_at_timestamp, :file_created_at_timestamp,
                                :file_created_at_date, :library_id
                            )
                            ON CONFLICT (rowid) DO UPDATE SET
                                embedding = vector(:embedding),
                                app_name = :app_name,
                                file_type_group = :file_type_group,
                                created_at_timestamp = :created_at_timestamp,
                                file_created_at_timestamp = :file_created_at_timestamp,
                                file_created_at_date = :file_created_at_date,
                                library_id = :library_id
                            """
                        ),
                        insert_values,
                    )

                # Update FTS index
                for entity in needs_index:
                    processed_filepath, tokenized_tags, tokenized_metadata = (
                        self.prepare_fts_data(entity)
                    )

                    db.execute(
                        text(
                            """
                            INSERT INTO entities_fts (id, filepath, tags, metadata)
                            VALUES (:id, :filepath, :tags, :metadata)
                            ON CONFLICT (id) DO UPDATE SET
                                filepath = :filepath,
                                tags = :tags,
                                metadata = :metadata
                            """
                        ),
                        {
                            "id": entity.id,
                            "filepath": processed_filepath,
                            "tags": tokenized_tags,
                            "metadata": tokenized_metadata,
                        },
                    )

            db.commit()

        except Exception as e:
            logger.error(f"Error batch updating indexes: {e}")
            db.rollback()
            raise

    def full_text_search(
        self,
        query: str,
        db: Session,
        limit: int = 200,
        library_ids: Optional[List[int]] = None,
        start: Optional[int] = None,
        end: Optional[int] = None,
        app_names: Optional[List[str]] = None,
    ) -> List[int]:
        base_sql = """
        WITH search_results AS (
            SELECT e.id,
                   ts_rank_cd(f.search_vector, websearch_to_tsquery('simple', :query)) as rank
            FROM entities_fts f
            JOIN entities e ON e.id = f.id
            WHERE f.search_vector @@ websearch_to_tsquery('simple', :query)
            AND e.file_type_group = 'image'
        """

        where_clauses = []
        if library_ids:
            where_clauses.append("e.library_id = ANY(:library_ids)")

        if start is not None and end is not None:
            where_clauses.append(
                "EXTRACT(EPOCH FROM e.file_created_at) BETWEEN :start AND :end"
            )

        if app_names:
            where_clauses.append(
                """
                EXISTS (
                    SELECT 1 FROM metadata_entries me
                    WHERE me.entity_id = e.id
                    AND me.key = 'active_app'
                    AND me.value = ANY(:app_names)
                )
                """
            )

        if where_clauses:
            base_sql += " AND " + " AND ".join(where_clauses)

        base_sql += ")\nSELECT id FROM search_results ORDER BY rank DESC LIMIT :limit"

        params = {"query": query, "limit": limit}

        if library_ids:
            params["library_ids"] = library_ids

        if start is not None and end is not None:
            params["start"] = start
            params["end"] = end

        if app_names:
            params["app_names"] = app_names

        logfire.info(
            "full text search {query=} {limit=}",
            query=query,
            limit=limit,
        )

        sql = text(base_sql)
        result = db.execute(sql, params).fetchall()
        return [row[0] for row in result]

    def vector_search(
        self,
        embeddings: List[float],
        db: Session,
        limit: int = 200,
        library_ids: Optional[List[int]] = None,
        start: Optional[int] = None,
        end: Optional[int] = None,
        app_names: Optional[List[str]] = None,
    ) -> List[int]:
        sql_query = """
        SELECT rowid
        FROM entities_vec_v2
        WHERE file_type_group = 'image'
        """

        params = {
            "embedding": str(
                embeddings
            ),  # Convert to string for PostgreSQL vector type
            "limit": limit,
        }

        if library_ids:
            sql_query += " AND library_id = ANY(:library_ids)"
            params["library_ids"] = library_ids

        if start is not None and end is not None:
            sql_query += " AND file_created_at_timestamp BETWEEN :start AND :end"
            params["start"] = start
            params["end"] = end

        if app_names:
            sql_query += " AND app_name = ANY(:app_names)"
            params["app_names"] = app_names

        # Add vector similarity search
        sql_query += """
        ORDER BY embedding <=> vector(:embedding)
        LIMIT :limit
        """

        sql = text(sql_query)
        result = db.execute(sql, params).fetchall()

        return [row[0] for row in result]

    def reciprocal_rank_fusion(
        self, fts_results: List[int], vec_results: List[int], k: int = 60
    ) -> List[Tuple[int, float]]:
        rank_dict = defaultdict(float)

        # Weight for full-text search results: 0.7
        for rank, result_id in enumerate(fts_results):
            rank_dict[result_id] += 0.7 * (1 / (k + rank + 1))

        # Weight for vector search results: 0.3
        for rank, result_id in enumerate(vec_results):
            rank_dict[result_id] += 0.3 * (1 / (k + rank + 1))

        return sorted(rank_dict.items(), key=lambda x: x[1], reverse=True)

    def hybrid_search(
        self,
        query: str,
        db: Session,
        limit: int = 200,
        library_ids: Optional[List[int]] = None,
        start: Optional[int] = None,
        end: Optional[int] = None,
        app_names: Optional[List[str]] = None,
    ) -> List[int]:
        with logfire.span("full_text_search {query=}", query=query):
            fts_results = self.full_text_search(
                query, db, limit, library_ids, start, end, app_names
            )
            logger.info(f"Full-text search obtained {len(fts_results)} results")

        with logfire.span("vector_search {query=}", query=query):
            embeddings = get_embeddings([query])
            if embeddings and embeddings[0]:
                vec_results = self.vector_search(
                    embeddings[0], db, limit * 2, library_ids, start, end, app_names
                )
                logger.info(f"Vector search obtained {len(vec_results)} results")
            else:
                vec_results = []

        with logfire.span("reciprocal_rank_fusion {query=}", query=query):
            combined_results = self.reciprocal_rank_fusion(fts_results, vec_results)

        sorted_ids = [id for id, _ in combined_results][:limit]
        logger.info(f"Hybrid search results (sorted IDs): {sorted_ids}")

        return sorted_ids

    @logfire.instrument
    def get_search_stats(
        self,
        query: str,
        db: Session,
        library_ids: Optional[List[int]] = None,
        start: Optional[int] = None,
        end: Optional[int] = None,
        app_names: Optional[List[str]] = None,
    ) -> dict:
        """Get statistics for search results including date range and app name counts."""
        MIN_SAMPLE_SIZE = 1024
        MAX_SAMPLE_SIZE = 2048

        with logfire.span(
            "full_text_search in stats {query=} {limit=}",
            query=query,
            limit=MAX_SAMPLE_SIZE,
        ):
            fts_results = self.full_text_search(
                query,
                db,
                limit=MAX_SAMPLE_SIZE,
                library_ids=library_ids,
                start=start,
                end=end,
                app_names=app_names,
            )

        vec_limit = max(min(len(fts_results) * 2, MAX_SAMPLE_SIZE), MIN_SAMPLE_SIZE)

        with logfire.span(
            "vec_search in stats {query=} {limit=}", query=query, limit=vec_limit
        ):
            embeddings = get_embeddings([query])
            if embeddings and embeddings[0]:
                vec_results = self.vector_search(
                    embeddings[0],
                    db,
                    limit=vec_limit,
                    library_ids=library_ids,
                    start=start,
                    end=end,
                    app_names=app_names,
                )
            else:
                vec_results = []

        logfire.info(f"fts_results: {len(fts_results)} vec_results: {len(vec_results)}")

        entity_ids = set(fts_results + vec_results)

        if not entity_ids:
            return {
                "date_range": {"earliest": None, "latest": None},
                "app_name_counts": {},
            }

        entity_ids_str = ",".join(str(id) for id in entity_ids)
        date_range = db.execute(
            text(
                f"""
                SELECT
                    MIN(file_created_at) as earliest,
                    MAX(file_created_at) as latest
                FROM entities
                WHERE id IN ({entity_ids_str})
                """
            )
        ).first()

        app_name_counts = db.execute(
            text(
                f"""
                SELECT me.value, COUNT(*) as count
                FROM metadata_entries me
                WHERE me.entity_id IN ({entity_ids_str}) and me.key = 'active_app'
                GROUP BY me.value
                ORDER BY count DESC
                """
            )
        ).all()

        return {
            "date_range": {
                "earliest": date_range.earliest,
                "latest": date_range.latest,
            },
            "app_name_counts": {app_name: count for app_name, count in app_name_counts},
        }
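
Both providers share the same rank-fusion arithmetic: each result list contributes weight / (k + rank + 1) per hit, with k = 60 damping the advantage of the very top ranks, FTS weighted 0.7 and vector search 0.3. A standalone worked check mirroring the method above (example IDs are made up):

    from collections import defaultdict

    def rrf(fts, vec, k=60):
        scores = defaultdict(float)
        for rank, rid in enumerate(fts):
            scores[rid] += 0.7 * (1 / (k + rank + 1))  # FTS weight 0.7
        for rank, rid in enumerate(vec):
            scores[rid] += 0.3 * (1 / (k + rank + 1))  # vector weight 0.3
        return sorted(scores.items(), key=lambda x: x[1], reverse=True)

    # Entity 7 is first in FTS and third in vector search:
    #   0.7/61 + 0.3/63 ~= 0.0162
    # Entity 9 is second in FTS only: 0.7/62 ~= 0.0113
    print(rrf([7, 9], [5, 8, 7]))  # entity 7 ranks first

Appearing in both channels beats a slightly better rank in a single channel, which is the usual motivation for reciprocal rank fusion.
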
class SqliteSearchProvider(SearchProvider):
    def and_words(self, input_string: str) -> str:
        words = input_string.split()
        result = " AND ".join(words)
        return result

    def prepare_fts_data(self, entity) -> tuple[str, str]:
        tags = ", ".join(entity.tag_names)
        fts_metadata = "\n".join(
            [
                f"{entry.key}: {self.process_ocr_result(entry.value) if entry.key == 'ocr_result' else entry.value}"
                for entry in entity.metadata_entries
            ]
        )
        return tags, fts_metadata

    def update_entity_index(self, entity_id: int, db: Session):
        """Update both FTS and vector indexes for an entity"""
        try:
            from .crud import get_entity_by_id

            entity = get_entity_by_id(entity_id, db)
            if not entity:
                raise ValueError(f"Entity with id {entity_id} not found")

            # Update FTS index
            tags, fts_metadata = self.prepare_fts_data(entity)
            db.execute(
                text(
                    """
                    INSERT OR REPLACE INTO entities_fts(id, filepath, tags, metadata)
                    VALUES(:id, :filepath, :tags, :metadata)
                    """
                ),
                {
                    "id": entity.id,
                    "filepath": entity.filepath,
                    "tags": tags,
                    "metadata": fts_metadata,
                },
            )

            # Update vector index
            vec_metadata = self.prepare_vec_data(entity)
            with logfire.span("get embedding for entity metadata"):
                embeddings = get_embeddings([vec_metadata])
                logfire.info(f"vec_metadata: {vec_metadata}")

            if embeddings and embeddings[0]:
                db.execute(
                    text("DELETE FROM entities_vec_v2 WHERE rowid = :id"),
                    {"id": entity.id},
                )

                # Extract app_name from metadata_entries
                app_name = next(
                    (
                        entry.value
                        for entry in entity.metadata_entries
                        if entry.key == "active_app"
                    ),
                    "unknown",  # Default to 'unknown' if not found
                )
                # Get file_type_group from entity
                file_type_group = entity.file_type_group or "unknown"

                # Convert file_created_at to integer timestamp
                created_at_timestamp = int(entity.file_created_at.timestamp())

                db.execute(
                    text(
                        """
                        INSERT INTO entities_vec_v2 (
                            rowid, embedding, app_name, file_type_group, created_at_timestamp, file_created_at_timestamp,
                            file_created_at_date, library_id
                        )
                        VALUES (:id, :embedding, :app_name, :file_type_group, :created_at_timestamp, :file_created_at_timestamp, :file_created_at_date, :library_id)
                        """
                    ),
                    {
                        "id": entity.id,
                        "embedding": serialize_float32(embeddings[0]),
                        "app_name": app_name,
                        "file_type_group": file_type_group,
                        "created_at_timestamp": created_at_timestamp,
                        "file_created_at_timestamp": int(
                            entity.file_created_at.timestamp()
                        ),
                        "file_created_at_date": entity.file_created_at.strftime(
                            "%Y-%m-%d"
                        ),
                        "library_id": entity.library_id,
                    },
                )

            db.commit()
        except Exception as e:
            logger.error(f"Error updating indexes for entity {entity_id}: {e}")
            db.rollback()
            raise

    def batch_update_entity_indices(self, entity_ids: List[int], db: Session):
        """Batch update both FTS and vector indexes for multiple entities"""
        try:
            from sqlalchemy.orm import selectinload
            from .models import EntityModel

            entities = (
                db.query(EntityModel)
                .filter(EntityModel.id.in_(entity_ids))
                .options(
                    selectinload(EntityModel.metadata_entries),
                    selectinload(EntityModel.tags),
                )
                .all()
            )
            found_ids = {entity.id for entity in entities}

            missing_ids = set(entity_ids) - found_ids
            if missing_ids:
                raise ValueError(f"Entities not found: {missing_ids}")

            # Check existing vector indices and their timestamps
            existing_vec_indices = db.execute(
                text(
                    """
                    SELECT rowid, created_at_timestamp
                    FROM entities_vec_v2
                    WHERE rowid IN :entity_ids
                    """
                ).bindparams(bindparam("entity_ids", expanding=True)),
                {"entity_ids": tuple(entity_ids)},
            ).fetchall()

            # Create lookup of vector index timestamps
            vec_timestamps = {row[0]: row[1] for row in existing_vec_indices}

            # Separate entities that need indexing
            needs_index = []

            for entity in entities:
                entity_last_scan = int(entity.last_scan_at.timestamp())
                vec_timestamp = vec_timestamps.get(entity.id, 0)

                # Entity needs full indexing if last_scan_at is
                # more recent than the vector index timestamp
                if entity_last_scan > vec_timestamp:
                    needs_index.append(entity)

            logfire.info(
                f"Entities needing full indexing: {len(needs_index)}/{len(entity_ids)}"
            )

            # Handle entities needing full indexing
            if needs_index:
                vec_metadata_list = [
                    self.prepare_vec_data(entity) for entity in needs_index
                ]
                with logfire.span("get embedding in batch indexing"):
                    embeddings = get_embeddings(vec_metadata_list)
                    logfire.info(f"vec_metadata_list: {vec_metadata_list}")

                # Delete all existing vector indices in one query
                if needs_index:
                    db.execute(
                        text(
                            "DELETE FROM entities_vec_v2 WHERE rowid IN :ids"
                        ).bindparams(bindparam("ids", expanding=True)),
                        {"ids": tuple(entity.id for entity in needs_index)},
                    )

                # Prepare batch insert data
                created_at_timestamp = int(datetime.now().timestamp())
                insert_values = []
                for entity, embedding in zip(needs_index, embeddings):
                    app_name = next(
                        (
                            entry.value
                            for entry in entity.metadata_entries
                            if entry.key == "active_app"
                        ),
                        "unknown",
                    )
                    file_type_group = entity.file_type_group or "unknown"

                    insert_values.append(
                        {
                            "id": entity.id,
                            "embedding": serialize_float32(embedding),
                            "app_name": app_name,
                            "file_type_group": file_type_group,
                            "created_at_timestamp": created_at_timestamp,
                            "file_created_at_timestamp": int(
                                entity.file_created_at.timestamp()
                            ),
                            "file_created_at_date": entity.file_created_at.strftime(
                                "%Y-%m-%d"
                            ),
                            "library_id": entity.library_id,
                        }
                    )

                # Execute batch insert
                db.execute(
                    text(
                        """
                        INSERT INTO entities_vec_v2 (
                            rowid, embedding, app_name, file_type_group,
                            created_at_timestamp, file_created_at_timestamp,
                            file_created_at_date, library_id
                        )
                        VALUES (
                            :id, :embedding, :app_name, :file_type_group,
                            :created_at_timestamp, :file_created_at_timestamp,
                            :file_created_at_date, :library_id
                        )
                        """
                    ),
                    insert_values,
                )

            # Update FTS index for all entities
            for entity in entities:
                tags, fts_metadata = self.prepare_fts_data(entity)
                db.execute(
                    text(
                        """
                        INSERT OR REPLACE INTO entities_fts(id, filepath, tags, metadata)
                        VALUES(:id, :filepath, :tags, :metadata)
                        """
                    ),
                    {
                        "id": entity.id,
                        "filepath": entity.filepath,
                        "tags": tags,
                        "metadata": fts_metadata,
                    },
                )

            db.commit()

        except Exception as e:
            logger.error(f"Error batch updating indexes: {e}")
            db.rollback()
            raise

    def full_text_search(
        self,
        query: str,
        db: Session,
        limit: int = 200,
        library_ids: Optional[List[int]] = None,
        start: Optional[int] = None,
        end: Optional[int] = None,
        app_names: Optional[List[str]] = None,
    ) -> List[int]:
        start_time = time.time()

        and_query = self.and_words(query)

        sql_query = """
        WITH fts_matches AS (
            SELECT id, rank
            FROM entities_fts
            WHERE entities_fts MATCH jieba_query(:query)
        )
        SELECT e.id
        FROM fts_matches f
        JOIN entities e ON e.id = f.id
        WHERE e.file_type_group = 'image'
        """

        params = {"query": and_query, "limit": limit}
        bindparams = []

        if library_ids:
            sql_query += " AND e.library_id IN :library_ids"
            params["library_ids"] = tuple(library_ids)
            bindparams.append(bindparam("library_ids", expanding=True))

        if start is not None and end is not None:
            sql_query += (
                " AND strftime('%s', e.file_created_at, 'utc') BETWEEN :start AND :end"
            )
            params["start"] = start
            params["end"] = end

        if app_names:
            sql_query += """
                AND EXISTS (
                    SELECT 1 FROM metadata_entries me
                    WHERE me.entity_id = e.id
                    AND me.key = 'active_app'
                    AND me.value IN :app_names
                )
            """
            params["app_names"] = tuple(app_names)
            bindparams.append(bindparam("app_names", expanding=True))

        sql_query += " ORDER BY f.rank LIMIT :limit"

        sql = text(sql_query)
        if bindparams:
            sql = sql.bindparams(*bindparams)

        result = db.execute(sql, params).fetchall()

        execution_time = time.time() - start_time
        logger.info(f"Full-text search execution time: {execution_time:.4f} seconds")

        return [row[0] for row in result]

    def vector_search(
        self,
        embeddings: List[float],
        db: Session,
        limit: int = 200,
        library_ids: Optional[List[int]] = None,
        start: Optional[int] = None,
        end: Optional[int] = None,
        app_names: Optional[List[str]] = None,
    ) -> List[int]:
        start_date = None
        end_date = None
        if start is not None and end is not None:
            start_date = datetime.fromtimestamp(start).strftime("%Y-%m-%d")
            end_date = datetime.fromtimestamp(end).strftime("%Y-%m-%d")

        sql_query = f"""
            SELECT rowid
            FROM entities_vec_v2
            WHERE embedding MATCH :embedding
            AND file_type_group = 'image'
            AND K = :limit
            {"AND file_created_at_date BETWEEN :start_date AND :end_date" if start_date is not None and end_date is not None else ""}
            {"AND file_created_at_timestamp BETWEEN :start AND :end" if start is not None and end is not None else ""}
            {"AND library_id IN :library_ids" if library_ids else ""}
            {"AND app_name IN :app_names" if app_names else ""}
            ORDER BY distance ASC
        """

        params = {
            "embedding": serialize_float32(embeddings),
            "limit": limit,
        }

        if start is not None and end is not None:
            params["start"] = int(start)
            params["end"] = int(end)
            params["start_date"] = start_date
            params["end_date"] = end_date
        if library_ids:
            params["library_ids"] = tuple(library_ids)
        if app_names:
            params["app_names"] = tuple(app_names)

        sql = text(sql_query)
        if app_names:
            sql = sql.bindparams(bindparam("app_names", expanding=True))
        if library_ids:
            sql = sql.bindparams(bindparam("library_ids", expanding=True))

        with logfire.span("vec_search"):
            result = db.execute(sql, params).fetchall()

        return [row[0] for row in result]

    def reciprocal_rank_fusion(
        self, fts_results: List[int], vec_results: List[int], k: int = 60
    ) -> List[Tuple[int, float]]:
        rank_dict = defaultdict(float)

        # Weight for full-text search results: 0.7
        for rank, result_id in enumerate(fts_results):
            rank_dict[result_id] += 0.7 * (1 / (k + rank + 1))

        # Weight for vector search results: 0.3
        for rank, result_id in enumerate(vec_results):
            rank_dict[result_id] += 0.3 * (1 / (k + rank + 1))

        return sorted(rank_dict.items(), key=lambda x: x[1], reverse=True)

    def hybrid_search(
        self,
        query: str,
        db: Session,
        limit: int = 200,
        library_ids: Optional[List[int]] = None,
        start: Optional[int] = None,
        end: Optional[int] = None,
        app_names: Optional[List[str]] = None,
    ) -> List[int]:
        with logfire.span("full_text_search"):
            fts_results = self.full_text_search(
                query, db, limit, library_ids, start, end, app_names
            )
            logger.info(f"Full-text search obtained {len(fts_results)} results")

        with logfire.span("vector_search"):
            embeddings = get_embeddings([query])
            if embeddings and embeddings[0]:
                vec_results = self.vector_search(
                    embeddings[0], db, limit * 2, library_ids, start, end, app_names
                )
                logger.info(f"Vector search obtained {len(vec_results)} results")
            else:
                vec_results = []

        with logfire.span("reciprocal_rank_fusion"):
            combined_results = self.reciprocal_rank_fusion(fts_results, vec_results)

        sorted_ids = [id for id, _ in combined_results][:limit]
        logger.info(f"Hybrid search results (sorted IDs): {sorted_ids}")

        return sorted_ids

    @logfire.instrument
    def get_search_stats(
        self,
        query: str,
        db: Session,
        library_ids: Optional[List[int]] = None,
        start: Optional[int] = None,
        end: Optional[int] = None,
        app_names: Optional[List[str]] = None,
    ) -> dict:
        """Get statistics for search results including date range and tag counts."""
        MIN_SAMPLE_SIZE = 2048
        MAX_SAMPLE_SIZE = 4096

        with logfire.span(
            "full_text_search in stats {query=} {limit=}",
            query=query,
            limit=MAX_SAMPLE_SIZE,
        ):
            fts_results = self.full_text_search(
                query,
                db,
                limit=MAX_SAMPLE_SIZE,
                library_ids=library_ids,
                start=start,
                end=end,
                app_names=app_names,
            )

        vec_limit = max(min(len(fts_results) * 2, MAX_SAMPLE_SIZE), MIN_SAMPLE_SIZE)

        with logfire.span(
            "vec_search in stats {query=} {limit=}", query=query, limit=vec_limit
        ):
            embeddings = get_embeddings([query])
            if embeddings and embeddings[0]:
                vec_results = self.vector_search(
                    embeddings[0],
                    db,
                    limit=vec_limit,
                    library_ids=library_ids,
                    start=start,
                    end=end,
                    app_names=app_names,
                )
            else:
                vec_results = []

        logfire.info(f"fts_results: {len(fts_results)} vec_results: {len(vec_results)}")

        entity_ids = set(fts_results + vec_results)

        if not entity_ids:
            return {
                "date_range": {"earliest": None, "latest": None},
                "app_name_counts": {},
            }

        entity_ids_str = ",".join(str(id) for id in entity_ids)
        date_range = db.execute(
            text(
                f"""
                SELECT
                    MIN(file_created_at) as earliest,
                    MAX(file_created_at) as latest
                FROM entities
                WHERE id IN ({entity_ids_str})
                """
            )
        ).first()

        app_name_counts = db.execute(
            text(
                f"""
                SELECT me.value, COUNT(*) as count
                FROM metadata_entries me
                WHERE me.entity_id IN ({entity_ids_str}) and me.key = 'active_app'
                GROUP BY me.value
                ORDER BY count DESC
                """
            )
        ).all()

        return {
            "date_range": {
                "earliest": date_range.earliest,
                "latest": date_range.latest,
            },
            "app_name_counts": {app_name: count for app_name, count in app_name_counts},
        }


def create_search_provider(database_url: str) -> SearchProvider:
    """
    Factory function to create appropriate SearchProvider based on database URL.

    Args:
        database_url: Database connection URL

    Returns:
        SearchProvider: Appropriate search provider instance
    """
    if database_url.startswith("postgresql://"):
        logger.info("Using PostgreSQL search provider")
        return PostgreSQLSearchProvider()
    else:
        logger.info("Using SQLite search provider")
        return SqliteSearchProvider()
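
For orientation, a minimal sketch of how a caller might wire these providers up. This is illustrative only: the URL and session setup are assumptions, the real wiring lives elsewhere in the package (presumably memos/server.py), and running it requires the entities/entities_fts/entities_vec_v2 schema plus a reachable embedding service.

    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker

    database_url = "sqlite:///memos.db"  # hypothetical URL; anything non-postgresql:// selects SQLite
    provider = create_search_provider(database_url)  # SqliteSearchProvider here

    engine = create_engine(database_url)
    with sessionmaker(bind=engine)() as db:
        provider.update_entity_index(entity_id=42, db=db)  # refresh one entity's FTS + vector rows
        ids = provider.hybrid_search("meeting notes", db, limit=20)

Note that hybrid_search and reciprocal_rank_fusion are implemented identically on both concrete classes but are not declared on the SearchProvider ABC, so code holding the abstract type relies on the concrete providers supplying them.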