visual-rag-toolkit 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- demo/__init__.py +1 -1
- demo/app.py +20 -8
- demo/evaluation.py +5 -45
- demo/indexing.py +180 -221
- demo/qdrant_utils.py +12 -5
- demo/ui/playground.py +1 -1
- demo/ui/sidebar.py +26 -3
- demo/ui/upload.py +6 -5
- visual_rag/__init__.py +63 -6
- visual_rag/config.py +4 -7
- visual_rag/demo_runner.py +3 -5
- visual_rag/indexing/__init__.py +21 -4
- visual_rag/indexing/qdrant_indexer.py +94 -42
- visual_rag/retrieval/multi_vector.py +62 -65
- visual_rag/retrieval/single_stage.py +7 -0
- visual_rag/retrieval/two_stage.py +7 -10
- {visual_rag_toolkit-0.1.2.dist-info → visual_rag_toolkit-0.1.4.dist-info}/METADATA +28 -16
- {visual_rag_toolkit-0.1.2.dist-info → visual_rag_toolkit-0.1.4.dist-info}/RECORD +21 -22
- demo/example_metadata_mapping_sigir.json +0 -37
- {visual_rag_toolkit-0.1.2.dist-info → visual_rag_toolkit-0.1.4.dist-info}/WHEEL +0 -0
- {visual_rag_toolkit-0.1.2.dist-info → visual_rag_toolkit-0.1.4.dist-info}/entry_points.txt +0 -0
- {visual_rag_toolkit-0.1.2.dist-info → visual_rag_toolkit-0.1.4.dist-info}/licenses/LICENSE +0 -0
demo/ui/playground.py
CHANGED
|
@@ -9,6 +9,7 @@ from demo.qdrant_utils import (
|
|
|
9
9
|
sample_points_cached,
|
|
10
10
|
search_collection,
|
|
11
11
|
)
|
|
12
|
+
from visual_rag.retrieval import MultiVectorRetriever
|
|
12
13
|
|
|
13
14
|
|
|
14
15
|
def render_playground_tab():
|
|
@@ -46,7 +47,6 @@ def render_playground_tab():
|
|
|
46
47
|
if not st.session_state.get("model_loaded"):
|
|
47
48
|
with st.spinner(f"Loading {model_short}..."):
|
|
48
49
|
try:
|
|
49
|
-
from visual_rag.retrieval import MultiVectorRetriever
|
|
50
50
|
_ = MultiVectorRetriever(collection_name=active_collection, model_name=model_name)
|
|
51
51
|
st.session_state["model_loaded"] = True
|
|
52
52
|
st.session_state["loaded_model_key"] = cache_key
|
demo/ui/sidebar.py
CHANGED
|
@@ -3,6 +3,8 @@
|
|
|
3
3
|
import os
|
|
4
4
|
import streamlit as st
|
|
5
5
|
|
|
6
|
+
from qdrant_client.models import VectorParamsDiff
|
|
7
|
+
|
|
6
8
|
from demo.qdrant_utils import (
|
|
7
9
|
get_qdrant_credentials,
|
|
8
10
|
init_qdrant_client_with_creds,
|
|
@@ -14,11 +16,33 @@ from demo.qdrant_utils import (
|
|
|
14
16
|
|
|
15
17
|
|
|
16
18
|
def render_sidebar():
|
|
19
|
+
# CSS to make sidebar metrics smaller
|
|
20
|
+
st.markdown("""
|
|
21
|
+
<style>
|
|
22
|
+
/* Smaller metrics in sidebar */
|
|
23
|
+
[data-testid="stSidebar"] [data-testid="stMetricValue"] {
|
|
24
|
+
font-size: 1.2rem !important;
|
|
25
|
+
}
|
|
26
|
+
[data-testid="stSidebar"] [data-testid="stMetricLabel"] {
|
|
27
|
+
font-size: 0.75rem !important;
|
|
28
|
+
}
|
|
29
|
+
/* Smaller expander headers in sidebar */
|
|
30
|
+
[data-testid="stSidebar"] [data-testid="stExpander"] summary {
|
|
31
|
+
font-size: 0.9rem !important;
|
|
32
|
+
}
|
|
33
|
+
/* Compact subheaders */
|
|
34
|
+
[data-testid="stSidebar"] h3 {
|
|
35
|
+
font-size: 1rem !important;
|
|
36
|
+
margin-bottom: 0.5rem !important;
|
|
37
|
+
}
|
|
38
|
+
</style>
|
|
39
|
+
""", unsafe_allow_html=True)
|
|
40
|
+
|
|
17
41
|
with st.sidebar:
|
|
18
42
|
st.subheader("🔑 Qdrant Credentials")
|
|
19
43
|
|
|
20
|
-
env_url = os.getenv("
|
|
21
|
-
env_key = os.getenv("
|
|
44
|
+
env_url = os.getenv("QDRANT_URL") or os.getenv("SIGIR_QDRANT_URL") or ""
|
|
45
|
+
env_key = os.getenv("QDRANT_API_KEY") or os.getenv("SIGIR_QDRANT_KEY") or ""
|
|
22
46
|
|
|
23
47
|
if "qdrant_url_input" not in st.session_state:
|
|
24
48
|
st.session_state["qdrant_url_input"] = env_url
|
|
@@ -136,7 +160,6 @@ def render_sidebar():
|
|
|
136
160
|
if target_in_ram != current_in_ram:
|
|
137
161
|
if st.button("💾 Apply Change", key="admin_apply"):
|
|
138
162
|
try:
|
|
139
|
-
from qdrant_client.models import VectorParamsDiff
|
|
140
163
|
client.update_collection(
|
|
141
164
|
collection_name=active,
|
|
142
165
|
vectors_config={sel_vec: VectorParamsDiff(on_disk=not target_in_ram)}
|
demo/ui/upload.py
CHANGED
|
@@ -9,6 +9,7 @@ import inspect
|
|
|
9
9
|
from datetime import datetime
|
|
10
10
|
from pathlib import Path
|
|
11
11
|
|
|
12
|
+
import numpy as np
|
|
12
13
|
import streamlit as st
|
|
13
14
|
|
|
14
15
|
from demo.config import AVAILABLE_MODELS
|
|
@@ -17,6 +18,10 @@ from demo.qdrant_utils import (
|
|
|
17
18
|
get_collection_stats,
|
|
18
19
|
sample_points_cached,
|
|
19
20
|
)
|
|
21
|
+
from visual_rag.embedding.visual_embedder import VisualEmbedder
|
|
22
|
+
from visual_rag.indexing.qdrant_indexer import QdrantIndexer
|
|
23
|
+
from visual_rag.indexing.cloudinary_uploader import CloudinaryUploader
|
|
24
|
+
from visual_rag.indexing.pipeline import ProcessingPipeline
|
|
20
25
|
|
|
21
26
|
|
|
22
27
|
VECTOR_TYPES = ["initial", "mean_pooling", "experimental_pooling", "global_pooling"]
|
|
@@ -251,10 +256,6 @@ def process_pdfs(uploaded_files, config):
|
|
|
251
256
|
model_short = model_name.split("/")[-1]
|
|
252
257
|
model_status.info(f"Loading `{model_short}`...")
|
|
253
258
|
|
|
254
|
-
import numpy as np
|
|
255
|
-
from visual_rag import VisualEmbedder
|
|
256
|
-
from visual_rag.indexing import QdrantIndexer, CloudinaryUploader, ProcessingPipeline
|
|
257
|
-
|
|
258
259
|
output_dtype = np.float16 if vector_dtype == "float16" else np.float32
|
|
259
260
|
embedder_key = f"{model_name}::{vector_dtype}"
|
|
260
261
|
embedder = None
|
|
@@ -448,7 +449,7 @@ def process_pdfs(uploaded_files, config):
|
|
|
448
449
|
|
|
449
450
|
if total_uploaded > 0:
|
|
450
451
|
st.session_state["upload_success"] = f"Uploaded {total_uploaded} pages to {collection_name}"
|
|
451
|
-
st.
|
|
452
|
+
st.rerun() # Immediately refresh to show success toast + balloons
|
|
452
453
|
|
|
453
454
|
except Exception as e:
|
|
454
455
|
st.error(f"❌ Processing error: {e}")
|
visual_rag/__init__.py
CHANGED
|
@@ -31,7 +31,47 @@ Quick Start:
|
|
|
31
31
|
Each component works independently - use only what you need.
|
|
32
32
|
"""
|
|
33
33
|
|
|
34
|
-
|
|
34
|
+
import logging
|
|
35
|
+
|
|
36
|
+
__version__ = "0.1.4"
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def setup_logging(level: str = "INFO", format: str = None) -> None:
|
|
40
|
+
"""
|
|
41
|
+
Configure logging for visual_rag package.
|
|
42
|
+
|
|
43
|
+
Args:
|
|
44
|
+
level: Log level ("DEBUG", "INFO", "WARNING", "ERROR")
|
|
45
|
+
format: Custom format string. Default shows time, level, and message.
|
|
46
|
+
|
|
47
|
+
Example:
|
|
48
|
+
>>> import visual_rag
|
|
49
|
+
>>> visual_rag.setup_logging("INFO")
|
|
50
|
+
>>> # Now you'll see processing logs
|
|
51
|
+
"""
|
|
52
|
+
if format is None:
|
|
53
|
+
format = "[%(asctime)s] %(levelname)s - %(message)s"
|
|
54
|
+
|
|
55
|
+
logging.basicConfig(
|
|
56
|
+
level=getattr(logging, level.upper(), logging.INFO),
|
|
57
|
+
format=format,
|
|
58
|
+
datefmt="%H:%M:%S",
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
# Also set the visual_rag logger specifically
|
|
62
|
+
logger = logging.getLogger("visual_rag")
|
|
63
|
+
logger.setLevel(getattr(logging, level.upper(), logging.INFO))
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
# Enable INFO logging by default for visual_rag package and all submodules
|
|
67
|
+
# This ensures logs like "Processing PDF...", "Embedding pages..." are visible
|
|
68
|
+
_logger = logging.getLogger("visual_rag")
|
|
69
|
+
if not _logger.handlers:
|
|
70
|
+
_handler = logging.StreamHandler()
|
|
71
|
+
_handler.setFormatter(logging.Formatter("[%(asctime)s] %(message)s", datefmt="%H:%M:%S"))
|
|
72
|
+
_logger.addHandler(_handler)
|
|
73
|
+
_logger.setLevel(logging.INFO)
|
|
74
|
+
_logger.propagate = False # Don't duplicate to root logger
|
|
35
75
|
|
|
36
76
|
# Import main classes at package level for convenience
|
|
37
77
|
# These are optional - if dependencies aren't installed, we catch the error
|
|
@@ -71,13 +111,16 @@ try:
|
|
|
71
111
|
except ImportError:
|
|
72
112
|
QdrantAdmin = None
|
|
73
113
|
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
except ImportError:
|
|
77
|
-
demo = None
|
|
114
|
+
# demo is lazily imported to avoid RuntimeWarning when running as __main__
|
|
115
|
+
# Access via visual_rag.demo() which triggers __getattr__
|
|
78
116
|
|
|
79
117
|
# Config utilities (always available)
|
|
80
|
-
|
|
118
|
+
try:
|
|
119
|
+
from visual_rag.config import get, get_section, load_config
|
|
120
|
+
except ImportError:
|
|
121
|
+
get = None
|
|
122
|
+
get_section = None
|
|
123
|
+
load_config = None
|
|
81
124
|
|
|
82
125
|
__all__ = [
|
|
83
126
|
# Version
|
|
@@ -95,4 +138,18 @@ __all__ = [
|
|
|
95
138
|
"load_config",
|
|
96
139
|
"get",
|
|
97
140
|
"get_section",
|
|
141
|
+
# Logging
|
|
142
|
+
"setup_logging",
|
|
98
143
|
]
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def __getattr__(name: str):
|
|
147
|
+
"""Lazy import for demo to avoid RuntimeWarning when running as __main__."""
|
|
148
|
+
if name == "demo":
|
|
149
|
+
try:
|
|
150
|
+
from visual_rag.demo_runner import demo
|
|
151
|
+
|
|
152
|
+
return demo
|
|
153
|
+
except ImportError:
|
|
154
|
+
return None
|
|
155
|
+
raise AttributeError(f"module 'visual_rag' has no attribute {name!r}")
|
visual_rag/config.py
CHANGED
|
@@ -21,16 +21,13 @@ _raw_config_cache_path: Optional[str] = None
|
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
def _env_qdrant_url() -> Optional[str]:
|
|
24
|
-
|
|
24
|
+
"""Get Qdrant URL from environment. Prefers QDRANT_URL."""
|
|
25
|
+
return os.getenv("QDRANT_URL") or os.getenv("SIGIR_QDRANT_URL") # legacy fallback
|
|
25
26
|
|
|
26
27
|
|
|
27
28
|
def _env_qdrant_api_key() -> Optional[str]:
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
or os.getenv("SIGIR_QDRANT_API_KEY")
|
|
31
|
-
or os.getenv("DEST_QDRANT_API_KEY")
|
|
32
|
-
or os.getenv("QDRANT_API_KEY")
|
|
33
|
-
)
|
|
29
|
+
"""Get Qdrant API key from environment. Prefers QDRANT_API_KEY."""
|
|
30
|
+
return os.getenv("QDRANT_API_KEY") or os.getenv("SIGIR_QDRANT_KEY") # legacy fallback
|
|
34
31
|
|
|
35
32
|
|
|
36
33
|
def load_config(
|
visual_rag/demo_runner.py
CHANGED
|
@@ -52,13 +52,11 @@ def demo(
|
|
|
52
52
|
cmd = [sys.executable, "-m", "streamlit", "run", str(app_path)]
|
|
53
53
|
cmd += ["--server.address", str(host)]
|
|
54
54
|
cmd += ["--server.port", str(int(port))]
|
|
55
|
-
|
|
55
|
+
# headless=true prevents browser from auto-opening; open_browser overrides
|
|
56
|
+
should_be_headless = headless and not open_browser
|
|
57
|
+
cmd += ["--server.headless", "true" if should_be_headless else "false"]
|
|
56
58
|
cmd += ["--browser.gatherUsageStats", "false"]
|
|
57
59
|
cmd += ["--server.runOnSave", "false"]
|
|
58
|
-
cmd += ["--browser.serverAddress", str(host)]
|
|
59
|
-
if not open_browser:
|
|
60
|
-
cmd += ["--browser.serverPort", str(int(port))]
|
|
61
|
-
cmd += ["--browser.open", "false"]
|
|
62
60
|
|
|
63
61
|
if extra_args:
|
|
64
62
|
cmd += list(extra_args)
|
visual_rag/indexing/__init__.py
CHANGED
|
@@ -8,10 +8,27 @@ Components:
|
|
|
8
8
|
- ProcessingPipeline: End-to-end PDF → Qdrant pipeline
|
|
9
9
|
"""
|
|
10
10
|
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
from visual_rag.indexing.
|
|
11
|
+
# Lazy imports to avoid failures when optional dependencies aren't installed
|
|
12
|
+
|
|
13
|
+
try:
|
|
14
|
+
from visual_rag.indexing.pdf_processor import PDFProcessor
|
|
15
|
+
except ImportError:
|
|
16
|
+
PDFProcessor = None
|
|
17
|
+
|
|
18
|
+
try:
|
|
19
|
+
from visual_rag.indexing.qdrant_indexer import QdrantIndexer
|
|
20
|
+
except ImportError:
|
|
21
|
+
QdrantIndexer = None
|
|
22
|
+
|
|
23
|
+
try:
|
|
24
|
+
from visual_rag.indexing.cloudinary_uploader import CloudinaryUploader
|
|
25
|
+
except ImportError:
|
|
26
|
+
CloudinaryUploader = None
|
|
27
|
+
|
|
28
|
+
try:
|
|
29
|
+
from visual_rag.indexing.pipeline import ProcessingPipeline
|
|
30
|
+
except ImportError:
|
|
31
|
+
ProcessingPipeline = None
|
|
15
32
|
|
|
16
33
|
__all__ = [
|
|
17
34
|
"PDFProcessor",
|
|
@@ -19,6 +19,23 @@ from urllib.parse import urlparse
|
|
|
19
19
|
|
|
20
20
|
import numpy as np
|
|
21
21
|
|
|
22
|
+
try:
|
|
23
|
+
from qdrant_client import QdrantClient
|
|
24
|
+
from qdrant_client.http import models as qdrant_models
|
|
25
|
+
from qdrant_client.http.models import Distance, VectorParams
|
|
26
|
+
from qdrant_client.models import FieldCondition, Filter, MatchValue
|
|
27
|
+
|
|
28
|
+
QDRANT_AVAILABLE = True
|
|
29
|
+
except ImportError:
|
|
30
|
+
QDRANT_AVAILABLE = False
|
|
31
|
+
QdrantClient = None
|
|
32
|
+
qdrant_models = None
|
|
33
|
+
Distance = None
|
|
34
|
+
VectorParams = None
|
|
35
|
+
FieldCondition = None
|
|
36
|
+
Filter = None
|
|
37
|
+
MatchValue = None
|
|
38
|
+
|
|
22
39
|
logger = logging.getLogger(__name__)
|
|
23
40
|
|
|
24
41
|
|
|
@@ -58,9 +75,7 @@ class QdrantIndexer:
|
|
|
58
75
|
prefer_grpc: bool = False,
|
|
59
76
|
vector_datatype: str = "float32",
|
|
60
77
|
):
|
|
61
|
-
|
|
62
|
-
from qdrant_client import QdrantClient
|
|
63
|
-
except ImportError:
|
|
78
|
+
if not QDRANT_AVAILABLE:
|
|
64
79
|
raise ImportError(
|
|
65
80
|
"Qdrant client not installed. "
|
|
66
81
|
"Install with: pip install visual-rag-toolkit[qdrant]"
|
|
@@ -139,9 +154,6 @@ class QdrantIndexer:
|
|
|
139
154
|
Returns:
|
|
140
155
|
True if created, False if already existed
|
|
141
156
|
"""
|
|
142
|
-
from qdrant_client.http import models
|
|
143
|
-
from qdrant_client.http.models import Distance, VectorParams
|
|
144
|
-
|
|
145
157
|
if self.collection_exists():
|
|
146
158
|
if force_recreate:
|
|
147
159
|
logger.info(f"🗑️ Deleting existing collection: {self.collection_name}")
|
|
@@ -153,15 +165,15 @@ class QdrantIndexer:
|
|
|
153
165
|
logger.info(f"📦 Creating collection: {self.collection_name}")
|
|
154
166
|
|
|
155
167
|
# Multi-vector config for ColBERT-style MaxSim
|
|
156
|
-
multivector_config =
|
|
157
|
-
comparator=
|
|
168
|
+
multivector_config = qdrant_models.MultiVectorConfig(
|
|
169
|
+
comparator=qdrant_models.MultiVectorComparator.MAX_SIM
|
|
158
170
|
)
|
|
159
171
|
|
|
160
172
|
# Vector configs - simplified for compatibility
|
|
161
173
|
datatype = (
|
|
162
|
-
|
|
174
|
+
qdrant_models.Datatype.FLOAT16
|
|
163
175
|
if self.vector_datatype == "float16"
|
|
164
|
-
else
|
|
176
|
+
else qdrant_models.Datatype.FLOAT32
|
|
165
177
|
)
|
|
166
178
|
vectors_config = {
|
|
167
179
|
"initial": VectorParams(
|
|
@@ -198,6 +210,18 @@ class QdrantIndexer:
|
|
|
198
210
|
vectors_config=vectors_config,
|
|
199
211
|
)
|
|
200
212
|
|
|
213
|
+
# Create required payload index for skip_existing functionality
|
|
214
|
+
# This index is needed for filtering by filename when checking existing docs
|
|
215
|
+
try:
|
|
216
|
+
self.client.create_payload_index(
|
|
217
|
+
collection_name=self.collection_name,
|
|
218
|
+
field_name="filename",
|
|
219
|
+
field_schema=qdrant_models.PayloadSchemaType.KEYWORD,
|
|
220
|
+
)
|
|
221
|
+
logger.info(" 📇 Created payload index: filename")
|
|
222
|
+
except Exception as e:
|
|
223
|
+
logger.warning(f" ⚠️ Could not create filename index: {e}")
|
|
224
|
+
|
|
201
225
|
logger.info(f"✅ Collection created: {self.collection_name}")
|
|
202
226
|
return True
|
|
203
227
|
|
|
@@ -212,14 +236,12 @@ class QdrantIndexer:
|
|
|
212
236
|
fields: List of {field, type} dicts
|
|
213
237
|
type can be: integer, keyword, bool, float, text
|
|
214
238
|
"""
|
|
215
|
-
from qdrant_client.http import models
|
|
216
|
-
|
|
217
239
|
type_mapping = {
|
|
218
|
-
"integer":
|
|
219
|
-
"keyword":
|
|
220
|
-
"bool":
|
|
221
|
-
"float":
|
|
222
|
-
"text":
|
|
240
|
+
"integer": qdrant_models.PayloadSchemaType.INTEGER,
|
|
241
|
+
"keyword": qdrant_models.PayloadSchemaType.KEYWORD,
|
|
242
|
+
"bool": qdrant_models.PayloadSchemaType.BOOL,
|
|
243
|
+
"float": qdrant_models.PayloadSchemaType.FLOAT,
|
|
244
|
+
"text": qdrant_models.PayloadSchemaType.TEXT,
|
|
223
245
|
}
|
|
224
246
|
|
|
225
247
|
if not fields:
|
|
@@ -230,7 +252,7 @@ class QdrantIndexer:
|
|
|
230
252
|
for field_config in fields:
|
|
231
253
|
field_name = field_config["field"]
|
|
232
254
|
field_type_str = field_config.get("type", "keyword")
|
|
233
|
-
field_type = type_mapping.get(field_type_str,
|
|
255
|
+
field_type = type_mapping.get(field_type_str, qdrant_models.PayloadSchemaType.KEYWORD)
|
|
234
256
|
|
|
235
257
|
try:
|
|
236
258
|
self.client.create_payload_index(
|
|
@@ -271,8 +293,6 @@ class QdrantIndexer:
|
|
|
271
293
|
Returns:
|
|
272
294
|
Number of successfully uploaded points
|
|
273
295
|
"""
|
|
274
|
-
from qdrant_client.http import models
|
|
275
|
-
|
|
276
296
|
if not points:
|
|
277
297
|
return 0
|
|
278
298
|
|
|
@@ -315,8 +335,10 @@ class QdrantIndexer:
|
|
|
315
335
|
return val.tolist()
|
|
316
336
|
return val
|
|
317
337
|
|
|
318
|
-
def _build_qdrant_points(
|
|
319
|
-
|
|
338
|
+
def _build_qdrant_points(
|
|
339
|
+
batch_points: List[Dict[str, Any]],
|
|
340
|
+
) -> List[qdrant_models.PointStruct]:
|
|
341
|
+
qdrant_points: List[qdrant_models.PointStruct] = []
|
|
320
342
|
for p in batch_points:
|
|
321
343
|
global_pooled = p.get("global_pooled_embedding")
|
|
322
344
|
if global_pooled is None:
|
|
@@ -336,7 +358,7 @@ class QdrantIndexer:
|
|
|
336
358
|
global_pooling = global_pooled.astype(self._np_vector_dtype, copy=False)
|
|
337
359
|
|
|
338
360
|
qdrant_points.append(
|
|
339
|
-
|
|
361
|
+
qdrant_models.PointStruct(
|
|
340
362
|
id=p["id"],
|
|
341
363
|
vector={
|
|
342
364
|
"initial": _to_list(initial),
|
|
@@ -361,6 +383,8 @@ class QdrantIndexer:
|
|
|
361
383
|
wait=wait,
|
|
362
384
|
)
|
|
363
385
|
|
|
386
|
+
logger.info(f" ✅ Uploaded {len(points)} points to Qdrant")
|
|
387
|
+
|
|
364
388
|
if delay_between_batches > 0:
|
|
365
389
|
if _is_cancelled():
|
|
366
390
|
return 0
|
|
@@ -413,32 +437,60 @@ class QdrantIndexer:
|
|
|
413
437
|
return False
|
|
414
438
|
|
|
415
439
|
def get_existing_ids(self, filename: str) -> Set[str]:
|
|
416
|
-
"""Get all point IDs for a specific file.
|
|
417
|
-
from qdrant_client.models import FieldCondition, Filter, MatchValue
|
|
440
|
+
"""Get all point IDs for a specific file.
|
|
418
441
|
|
|
442
|
+
Requires a payload index on 'filename' field. If the index doesn't exist,
|
|
443
|
+
this method will attempt to create it automatically.
|
|
444
|
+
"""
|
|
419
445
|
existing_ids = set()
|
|
420
446
|
offset = None
|
|
421
447
|
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
448
|
+
try:
|
|
449
|
+
while True:
|
|
450
|
+
results = self.client.scroll(
|
|
451
|
+
collection_name=self.collection_name,
|
|
452
|
+
scroll_filter=Filter(
|
|
453
|
+
must=[FieldCondition(key="filename", match=MatchValue(value=filename))]
|
|
454
|
+
),
|
|
455
|
+
limit=100,
|
|
456
|
+
offset=offset,
|
|
457
|
+
with_payload=["page_number"],
|
|
458
|
+
with_vectors=False,
|
|
459
|
+
)
|
|
433
460
|
|
|
434
|
-
|
|
461
|
+
points, next_offset = results
|
|
435
462
|
|
|
436
|
-
|
|
437
|
-
|
|
463
|
+
for point in points:
|
|
464
|
+
existing_ids.add(str(point.id))
|
|
438
465
|
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
466
|
+
if next_offset is None or len(points) == 0:
|
|
467
|
+
break
|
|
468
|
+
offset = next_offset
|
|
469
|
+
|
|
470
|
+
except Exception as e:
|
|
471
|
+
error_msg = str(e).lower()
|
|
472
|
+
if "index required" in error_msg or "index" in error_msg and "filename" in error_msg:
|
|
473
|
+
# Missing payload index - try to create it
|
|
474
|
+
logger.warning(
|
|
475
|
+
"⚠️ Missing 'filename' payload index. Creating it now... "
|
|
476
|
+
"(skip_existing requires this index for filtering)"
|
|
477
|
+
)
|
|
478
|
+
try:
|
|
479
|
+
self.client.create_payload_index(
|
|
480
|
+
collection_name=self.collection_name,
|
|
481
|
+
field_name="filename",
|
|
482
|
+
field_schema=qdrant_models.PayloadSchemaType.KEYWORD,
|
|
483
|
+
)
|
|
484
|
+
logger.info(" ✅ Created 'filename' index. Retrying query...")
|
|
485
|
+
# Retry the query
|
|
486
|
+
return self.get_existing_ids(filename)
|
|
487
|
+
except Exception as idx_err:
|
|
488
|
+
logger.warning(f" ❌ Could not create index: {idx_err}")
|
|
489
|
+
logger.warning(" Returning empty set - all pages will be processed")
|
|
490
|
+
return set()
|
|
491
|
+
else:
|
|
492
|
+
logger.warning(f"⚠️ Error checking existing IDs: {e}")
|
|
493
|
+
return set()
|
|
442
494
|
|
|
443
495
|
return existing_ids
|
|
444
496
|
|