visual_rag_toolkit-0.1.1-py3-none-any.whl
This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
- benchmarks/README.md +101 -0
- benchmarks/__init__.py +11 -0
- benchmarks/analyze_results.py +187 -0
- benchmarks/benchmark_datasets.txt +105 -0
- benchmarks/prepare_submission.py +205 -0
- benchmarks/quick_test.py +566 -0
- benchmarks/run_vidore.py +513 -0
- benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +1365 -0
- benchmarks/vidore_tatdqa_test/COMMANDS.md +83 -0
- benchmarks/vidore_tatdqa_test/__init__.py +6 -0
- benchmarks/vidore_tatdqa_test/dataset_loader.py +363 -0
- benchmarks/vidore_tatdqa_test/metrics.py +44 -0
- benchmarks/vidore_tatdqa_test/run_qdrant.py +799 -0
- benchmarks/vidore_tatdqa_test/sweep_eval.py +372 -0
- demo/__init__.py +10 -0
- demo/app.py +45 -0
- demo/commands.py +334 -0
- demo/config.py +34 -0
- demo/download_models.py +75 -0
- demo/evaluation.py +602 -0
- demo/example_metadata_mapping_sigir.json +37 -0
- demo/indexing.py +286 -0
- demo/qdrant_utils.py +211 -0
- demo/results.py +35 -0
- demo/test_qdrant_connection.py +119 -0
- demo/ui/__init__.py +15 -0
- demo/ui/benchmark.py +355 -0
- demo/ui/header.py +30 -0
- demo/ui/playground.py +339 -0
- demo/ui/sidebar.py +162 -0
- demo/ui/upload.py +487 -0
- visual_rag/__init__.py +98 -0
- visual_rag/cli/__init__.py +1 -0
- visual_rag/cli/main.py +629 -0
- visual_rag/config.py +230 -0
- visual_rag/demo_runner.py +90 -0
- visual_rag/embedding/__init__.py +26 -0
- visual_rag/embedding/pooling.py +343 -0
- visual_rag/embedding/visual_embedder.py +622 -0
- visual_rag/indexing/__init__.py +21 -0
- visual_rag/indexing/cloudinary_uploader.py +274 -0
- visual_rag/indexing/pdf_processor.py +324 -0
- visual_rag/indexing/pipeline.py +628 -0
- visual_rag/indexing/qdrant_indexer.py +478 -0
- visual_rag/preprocessing/__init__.py +3 -0
- visual_rag/preprocessing/crop_empty.py +120 -0
- visual_rag/qdrant_admin.py +222 -0
- visual_rag/retrieval/__init__.py +19 -0
- visual_rag/retrieval/multi_vector.py +222 -0
- visual_rag/retrieval/single_stage.py +126 -0
- visual_rag/retrieval/three_stage.py +173 -0
- visual_rag/retrieval/two_stage.py +471 -0
- visual_rag/visualization/__init__.py +19 -0
- visual_rag/visualization/saliency.py +335 -0
- visual_rag_toolkit-0.1.1.dist-info/METADATA +305 -0
- visual_rag_toolkit-0.1.1.dist-info/RECORD +59 -0
- visual_rag_toolkit-0.1.1.dist-info/WHEEL +4 -0
- visual_rag_toolkit-0.1.1.dist-info/entry_points.txt +3 -0
- visual_rag_toolkit-0.1.1.dist-info/licenses/LICENSE +22 -0
visual_rag/indexing/qdrant_indexer.py
@@ -0,0 +1,478 @@
"""
Qdrant Indexer - Upload embeddings to Qdrant vector database.

Works INDEPENDENTLY of PDF processing and embedding generation.
Use it if you already have embeddings and just need to upload.

Features:
- Named vectors for multi-vector and pooled search
- Batch uploading with retry logic
- Skip-existing for incremental updates
- Configurable payload indexes
"""

import hashlib
import logging
import time
from typing import Any, Dict, List, Optional, Set
from urllib.parse import urlparse

import numpy as np

logger = logging.getLogger(__name__)


class QdrantIndexer:
    """
    Upload visual embeddings to Qdrant.

    Works independently - just needs embeddings and metadata.

    Args:
        url: Qdrant server URL
        api_key: Qdrant API key
        collection_name: Name of the collection
        timeout: Request timeout in seconds
        prefer_grpc: Use gRPC protocol (faster but may have issues)

    Example:
        >>> indexer = QdrantIndexer(
        ...     url="https://your-cluster.qdrant.io:6333",
        ...     api_key="your-api-key",
        ...     collection_name="my_collection",
        ... )
        >>>
        >>> # Create collection
        >>> indexer.create_collection()
        >>>
        >>> # Upload points
        >>> indexer.upload_batch(points)
    """

    def __init__(
        self,
        url: str,
        api_key: str,
        collection_name: str,
        timeout: int = 60,
        prefer_grpc: bool = False,
        vector_datatype: str = "float32",
    ):
        try:
            from qdrant_client import QdrantClient
        except ImportError:
            raise ImportError(
                "Qdrant client not installed. "
                "Install with: pip install visual-rag-toolkit[qdrant]"
            )

        self.collection_name = collection_name
        self.timeout = timeout
        if vector_datatype not in ("float32", "float16"):
            raise ValueError("vector_datatype must be 'float32' or 'float16'")
        self.vector_datatype = vector_datatype
        self._np_vector_dtype = np.float16 if vector_datatype == "float16" else np.float32

        grpc_port = None
        if prefer_grpc:
            try:
                parsed = urlparse(url)
                port = parsed.port
                if port == 6333:
                    grpc_port = 6334
            except Exception:
                grpc_port = None

        def _make_client(use_grpc: bool):
            return QdrantClient(
                url=url,
                api_key=api_key,
                timeout=timeout,
                prefer_grpc=bool(use_grpc),
                grpc_port=grpc_port,
                check_compatibility=False,
            )

        self.client = _make_client(prefer_grpc)
        if prefer_grpc:
            try:
                _ = self.client.get_collections()
            except Exception as e:
                msg = str(e)
                if "StatusCode.PERMISSION_DENIED" in msg or "http2 header with status: 403" in msg:
                    self.client = _make_client(False)
                else:
                    raise

        logger.info(f"🔌 Connected to Qdrant: {url}")
        logger.info(f"   Collection: {collection_name}")
        logger.info(f"   Vector datatype: {self.vector_datatype}")

    def collection_exists(self) -> bool:
        """Check if collection exists."""
        collections = self.client.get_collections().collections
        return any(c.name == self.collection_name for c in collections)

    def create_collection(
        self,
        embedding_dim: int = 128,
        force_recreate: bool = False,
        enable_quantization: bool = False,
        indexing_threshold: int = 20000,
        full_scan_threshold: int = 0,
    ) -> bool:
        """
        Create collection with multi-vector support.

        Creates named vectors:
        - initial: Full multi-vector embeddings (num_patches × dim)
        - mean_pooling: Tile-level pooled vectors (num_tiles × dim)
        - experimental_pooling: Experimental multi-vector pooling (varies by model)
        - global_pooling: Single vector pooled representation (dim)

        Args:
            embedding_dim: Embedding dimension (128 for ColSmol)
            force_recreate: Delete and recreate if exists
            enable_quantization: Enable int8 quantization
            indexing_threshold: Qdrant optimizer indexing threshold (set 0 to always build ANN indexes)

        Returns:
            True if created, False if already existed
        """
        from qdrant_client.http import models
        from qdrant_client.http.models import Distance, VectorParams

        if self.collection_exists():
            if force_recreate:
                logger.info(f"🗑️ Deleting existing collection: {self.collection_name}")
                self.client.delete_collection(self.collection_name)
            else:
                logger.info(f"✅ Collection already exists: {self.collection_name}")
                return False

        logger.info(f"📦 Creating collection: {self.collection_name}")

        # Multi-vector config for ColBERT-style MaxSim
        multivector_config = models.MultiVectorConfig(
            comparator=models.MultiVectorComparator.MAX_SIM
        )

        # Vector configs - simplified for compatibility
        datatype = (
            models.Datatype.FLOAT16
            if self.vector_datatype == "float16"
            else models.Datatype.FLOAT32
        )
        vectors_config = {
            "initial": VectorParams(
                size=embedding_dim,
                distance=Distance.COSINE,
                on_disk=True,
                multivector_config=multivector_config,
                datatype=datatype,
            ),
            "mean_pooling": VectorParams(
                size=embedding_dim,
                distance=Distance.COSINE,
                on_disk=False,
                multivector_config=multivector_config,
                datatype=datatype,
            ),
            "experimental_pooling": VectorParams(
                size=embedding_dim,
                distance=Distance.COSINE,
                on_disk=False,
                multivector_config=multivector_config,
                datatype=datatype,
            ),
            "global_pooling": VectorParams(
                size=embedding_dim,
                distance=Distance.COSINE,
                on_disk=False,
                datatype=datatype,
            ),
        }

        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=vectors_config,
        )

        logger.info(f"✅ Collection created: {self.collection_name}")
        return True

    def create_payload_indexes(
        self,
        fields: Optional[List[Dict[str, str]]] = None,
    ):
        """
        Create payload indexes for filtering.

        Args:
            fields: List of {field, type} dicts
                type can be: integer, keyword, bool, float, text
        """
        from qdrant_client.http import models

        type_mapping = {
            "integer": models.PayloadSchemaType.INTEGER,
            "keyword": models.PayloadSchemaType.KEYWORD,
            "bool": models.PayloadSchemaType.BOOL,
            "float": models.PayloadSchemaType.FLOAT,
            "text": models.PayloadSchemaType.TEXT,
        }

        if not fields:
            return

        logger.info("📇 Creating payload indexes...")

        for field_config in fields:
            field_name = field_config["field"]
            field_type_str = field_config.get("type", "keyword")
            field_type = type_mapping.get(field_type_str, models.PayloadSchemaType.KEYWORD)

            try:
                self.client.create_payload_index(
                    collection_name=self.collection_name,
                    field_name=field_name,
                    field_schema=field_type,
                )
                logger.info(f"  ✅ {field_name} ({field_type_str})")
            except Exception as e:
                logger.debug(f"  Index {field_name} might already exist: {e}")

    def upload_batch(
        self,
        points: List[Dict[str, Any]],
        max_retries: int = 3,
        delay_between_batches: float = 0.5,
        wait: bool = True,
        stop_event=None,
    ) -> int:
        """
        Upload a batch of points to Qdrant.

        Each point should have:
        - id: Unique point ID (string or UUID)
        - visual_embedding: Full embedding [num_patches, dim]
        - tile_pooled_embedding: Pooled embedding [num_tiles, dim]
        - experimental_pooled_embedding: Experimental pooled embedding [*, dim]
        - global_pooled_embedding: Pooled embedding [dim]
        - metadata: Payload dict

        Args:
            points: List of point dicts
            max_retries: Retry attempts on failure
            delay_between_batches: Delay after upload
            wait: Wait for operation to complete on Qdrant server
            stop_event: Optional threading.Event used to cancel uploads early

        Returns:
            Number of successfully uploaded points
        """
        from qdrant_client.http import models

        if not points:
            return 0

        def _is_cancelled() -> bool:
            return stop_event is not None and getattr(stop_event, "is_set", lambda: False)()

        def _is_payload_too_large_error(e: Exception) -> bool:
            msg = str(e)
            if ("JSON payload" in msg and "larger than allowed" in msg) or (
                "Payload error:" in msg and "limit:" in msg
            ):
                return True
            content = getattr(e, "content", None)
            if content is not None:
                try:
                    if isinstance(content, (bytes, bytearray)):
                        text = content.decode("utf-8", errors="ignore")
                    else:
                        text = str(content)
                except Exception:
                    text = ""
                if ("JSON payload" in text and "larger than allowed" in text) or (
                    "Payload error" in text and "limit" in text
                ):
                    return True
            resp = getattr(e, "response", None)
            if resp is not None:
                try:
                    text = str(getattr(resp, "text", "") or "")
                except Exception:
                    text = ""
                if ("JSON payload" in text and "larger than allowed" in text) or (
                    "Payload error" in text and "limit" in text
                ):
                    return True
            return False

        def _to_list(val):
            if isinstance(val, np.ndarray):
                return val.tolist()
            return val

        def _build_qdrant_points(batch_points: List[Dict[str, Any]]) -> List[models.PointStruct]:
            qdrant_points: List[models.PointStruct] = []
            for p in batch_points:
                global_pooled = p.get("global_pooled_embedding")
                if global_pooled is None:
                    tile_pooled = np.array(p["tile_pooled_embedding"], dtype=np.float32)
                    global_pooled = tile_pooled.mean(axis=0)
                global_pooled = np.array(global_pooled, dtype=np.float32).reshape(-1)

                initial = np.array(p["visual_embedding"], dtype=np.float32).astype(
                    self._np_vector_dtype, copy=False
                )
                mean_pooling = np.array(p["tile_pooled_embedding"], dtype=np.float32).astype(
                    self._np_vector_dtype, copy=False
                )
                experimental_pooling = np.array(
                    p["experimental_pooled_embedding"], dtype=np.float32
                ).astype(self._np_vector_dtype, copy=False)
                global_pooling = global_pooled.astype(self._np_vector_dtype, copy=False)

                qdrant_points.append(
                    models.PointStruct(
                        id=p["id"],
                        vector={
                            "initial": _to_list(initial),
                            "mean_pooling": _to_list(mean_pooling),
                            "experimental_pooling": _to_list(experimental_pooling),
                            "global_pooling": _to_list(global_pooling),
                        },
                        payload=p["metadata"],
                    )
                )
            return qdrant_points

        # Upload with retry
        for attempt in range(max_retries):
            try:
                if _is_cancelled():
                    return 0
                qdrant_points = _build_qdrant_points(points)
                self.client.upsert(
                    collection_name=self.collection_name,
                    points=qdrant_points,
                    wait=wait,
                )

                if delay_between_batches > 0:
                    if _is_cancelled():
                        return 0
                    time.sleep(delay_between_batches)

                return len(points)

            except Exception as e:
                if _is_payload_too_large_error(e) and len(points) > 1:
                    mid = len(points) // 2
                    left = points[:mid]
                    right = points[mid:]
                    logger.warning(
                        f"Upload payload too large for {len(points)} points; splitting into {len(left)} + {len(right)}"
                    )
                    return self.upload_batch(
                        left,
                        max_retries=max_retries,
                        delay_between_batches=delay_between_batches,
                        wait=wait,
                        stop_event=stop_event,
                    ) + self.upload_batch(
                        right,
                        max_retries=max_retries,
                        delay_between_batches=delay_between_batches,
                        wait=wait,
                        stop_event=stop_event,
                    )

                logger.warning(f"Upload attempt {attempt + 1}/{max_retries} failed: {e}")
                if attempt < max_retries - 1:
                    if _is_cancelled():
                        return 0
                    time.sleep(2**attempt)  # Exponential backoff

        logger.error(f"❌ Upload failed after {max_retries} attempts")
        return 0

    def check_exists(self, chunk_id: str) -> bool:
        """Check if a point already exists."""
        try:
            result = self.client.retrieve(
                collection_name=self.collection_name,
                ids=[chunk_id],
                with_payload=False,
                with_vectors=False,
            )
            return len(result) > 0
        except Exception:
            return False

    def get_existing_ids(self, filename: str) -> Set[str]:
        """Get all point IDs for a specific file."""
        from qdrant_client.models import FieldCondition, Filter, MatchValue

        existing_ids = set()
        offset = None

        while True:
            results = self.client.scroll(
                collection_name=self.collection_name,
                scroll_filter=Filter(
                    must=[FieldCondition(key="filename", match=MatchValue(value=filename))]
                ),
                limit=100,
                offset=offset,
                with_payload=["page_number"],
                with_vectors=False,
            )

            points, next_offset = results

            for point in points:
                existing_ids.add(str(point.id))

            if next_offset is None or len(points) == 0:
                break
            offset = next_offset

        return existing_ids

    def get_collection_info(self) -> Optional[Dict[str, Any]]:
        """Get collection statistics."""
        try:
            info = self.client.get_collection(self.collection_name)

            status = info.status
            if hasattr(status, "value"):
                status = status.value

            indexed_count = getattr(info, "indexed_vectors_count", 0) or 0
            if isinstance(indexed_count, dict):
                indexed_count = sum(indexed_count.values())

            return {
                "status": str(status),
                "points_count": getattr(info, "points_count", 0),
                "indexed_vectors_count": indexed_count,
            }
        except Exception as e:
            logger.warning(f"Could not get collection info: {e}")
            return None

    @staticmethod
    def generate_point_id(filename: str, page_number: int) -> str:
        """
        Generate deterministic point ID from filename and page.

        Returns a valid UUID string.
        """
        content = f"{filename}:page:{page_number}"
        hash_obj = hashlib.sha256(content.encode())
        hex_str = hash_obj.hexdigest()[:32]
        # Format as UUID
        return f"{hex_str[:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:32]}"
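For orientation (not part of the wheel itself), here is a minimal usage sketch of the indexer above. The cluster URL, API key, filename, and array shapes are placeholders; the point layout follows the `upload_batch` docstring, and the ID comes from `generate_point_id`, so re-indexing the same page overwrites rather than duplicates.

import numpy as np

from visual_rag.indexing.qdrant_indexer import QdrantIndexer

indexer = QdrantIndexer(
    url="https://your-cluster.qdrant.io:6333",  # placeholder URL
    api_key="your-api-key",                     # placeholder key
    collection_name="my_collection",
)
indexer.create_collection(embedding_dim=128)  # named vectors: initial, mean_pooling, experimental_pooling, global_pooling

# One point per page, with all four named vectors and a metadata payload.
point = {
    "id": QdrantIndexer.generate_point_id("report.pdf", page_number=1),
    "visual_embedding": np.random.rand(1024, 128).astype(np.float32),    # [num_patches, dim], placeholder values
    "tile_pooled_embedding": np.random.rand(4, 128).astype(np.float32),  # [num_tiles, dim]
    "experimental_pooled_embedding": np.random.rand(16, 128).astype(np.float32),
    "global_pooled_embedding": np.random.rand(128).astype(np.float32),   # [dim]
    "metadata": {"filename": "report.pdf", "page_number": 1},
}
uploaded = indexer.upload_batch([point])
print(f"uploaded {uploaded} point(s)")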
visual_rag/preprocessing/crop_empty.py
@@ -0,0 +1,120 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, Tuple

import numpy as np
from PIL import Image


@dataclass(frozen=True)
class CropEmptyConfig:
    percentage_to_remove: float = 0.9
    remove_page_number: bool = False
    color_threshold: int = 240
    min_white_fraction: float = 0.99
    content_density_sides: float = 0.001
    content_density_main_text: float = 0.05
    content_density_any: float = 1e-6
    preserve_border_px: int = 1
    uniform_rowcol_std_threshold: float = 0.0


def crop_empty(
    image: Image.Image, *, config: CropEmptyConfig
) -> Tuple[Image.Image, Dict[str, Any]]:
    img = image.convert("RGB")
    arr = np.array(img)
    intensity = arr.mean(axis=2)

    def _find_border_start(axis: int, *, min_content_density_threshold: float) -> int:
        size = intensity.shape[axis]
        for i in range(size):
            pixels = intensity[i, :] if axis == 0 else intensity[:, i]
            white = float(np.mean(pixels > config.color_threshold))
            non_white = 1.0 - white
            if float(config.uniform_rowcol_std_threshold) > 0.0 and float(np.std(pixels)) <= float(
                config.uniform_rowcol_std_threshold
            ):
                continue
            if (white < config.min_white_fraction) and (non_white > min_content_density_threshold):
                return int(i)
        return int(size)

    def _find_border_end(axis: int, *, min_content_density_threshold: float) -> int:
        size = intensity.shape[axis]
        for i in range(size - 1, -1, -1):
            pixels = intensity[i, :] if axis == 0 else intensity[:, i]
            white = float(np.mean(pixels > config.color_threshold))
            non_white = 1.0 - white
            if float(config.uniform_rowcol_std_threshold) > 0.0 and float(np.std(pixels)) <= float(
                config.uniform_rowcol_std_threshold
            ):
                continue
            if (white < config.min_white_fraction) and (non_white > min_content_density_threshold):
                return int(i + 1)
        return 0

    top = _find_border_start(0, min_content_density_threshold=float(config.content_density_sides))
    left = _find_border_start(1, min_content_density_threshold=float(config.content_density_sides))
    right = _find_border_end(1, min_content_density_threshold=float(config.content_density_sides))

    main_text_end = _find_border_end(
        0, min_content_density_threshold=float(config.content_density_main_text)
    )
    last_content_end = _find_border_end(
        0, min_content_density_threshold=float(config.content_density_any)
    )
    bottom = main_text_end if config.remove_page_number else last_content_end

    width, height = img.size
    pad = max(int(getattr(config, "preserve_border_px", 0) or 0), 0)
    if pad > 0:
        left = max(int(left) - pad, 0)
        top = max(int(top) - pad, 0)
        right = min(int(right) + pad, int(width))
        bottom = min(int(bottom) + pad, int(height))
    crop_box = (int(left), int(top), int(right), int(bottom))
    valid = 0 <= crop_box[0] < crop_box[2] <= width and 0 <= crop_box[1] < crop_box[3] <= height

    if not valid:
        return image, {
            "applied": False,
            "crop_box": None,
            "original_width": int(width),
            "original_height": int(height),
            "cropped_width": int(width),
            "cropped_height": int(height),
            "config": {
                "percentage_to_remove": float(config.percentage_to_remove),
                "remove_page_number": bool(config.remove_page_number),
                "color_threshold": int(config.color_threshold),
                "min_white_fraction": float(config.min_white_fraction),
                "content_density_sides": float(config.content_density_sides),
                "content_density_main_text": float(config.content_density_main_text),
                "content_density_any": float(config.content_density_any),
                "preserve_border_px": int(config.preserve_border_px),
                "uniform_rowcol_std_threshold": float(config.uniform_rowcol_std_threshold),
            },
        }

    cropped = img.crop(crop_box)
    return cropped, {
        "applied": True,
        "crop_box": [int(crop_box[0]), int(crop_box[1]), int(crop_box[2]), int(crop_box[3])],
        "original_width": int(width),
        "original_height": int(height),
        "cropped_width": int(cropped.width),
        "cropped_height": int(cropped.height),
        "config": {
            "percentage_to_remove": float(config.percentage_to_remove),
            "remove_page_number": bool(config.remove_page_number),
            "color_threshold": int(config.color_threshold),
            "min_white_fraction": float(config.min_white_fraction),
            "content_density_sides": float(config.content_density_sides),
            "content_density_main_text": float(config.content_density_main_text),
            "content_density_any": float(config.content_density_any),
            "preserve_border_px": int(config.preserve_border_px),
            "uniform_rowcol_std_threshold": float(config.uniform_rowcol_std_threshold),
        },
    }
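Similarly, a minimal sketch (again not part of the package) of driving the whitespace cropper above; the input path is a placeholder, and the report keys match the dicts returned by `crop_empty`:

from PIL import Image

from visual_rag.preprocessing.crop_empty import CropEmptyConfig, crop_empty

page = Image.open("page_01.png")  # placeholder page render
cfg = CropEmptyConfig(remove_page_number=True)  # also trim a trailing page-number strip
cropped, report = crop_empty(page, config=cfg)

if report["applied"]:
    print("crop box:", report["crop_box"])
    print(f"{report['original_width']}x{report['original_height']} -> "
          f"{report['cropped_width']}x{report['cropped_height']}")
else:
    print("no valid crop found; original image returned")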