visual-rag-toolkit 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
demo/ui/sidebar.py CHANGED
@@ -3,6 +3,8 @@
3
3
  import os
4
4
  import streamlit as st
5
5
 
6
+ from qdrant_client.models import VectorParamsDiff
7
+
6
8
  from demo.qdrant_utils import (
7
9
  get_qdrant_credentials,
8
10
  init_qdrant_client_with_creds,
@@ -17,8 +19,8 @@ def render_sidebar():
17
19
  with st.sidebar:
18
20
  st.subheader("🔑 Qdrant Credentials")
19
21
 
20
- env_url = os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL") or ""
21
- env_key = os.getenv("SIGIR_QDRANT_KEY") or os.getenv("SIGIR_QDRANT_API_KEY") or os.getenv("DEST_QDRANT_API_KEY") or os.getenv("QDRANT_API_KEY") or ""
22
+ env_url = os.getenv("QDRANT_URL") or os.getenv("SIGIR_QDRANT_URL") or ""
23
+ env_key = os.getenv("QDRANT_API_KEY") or os.getenv("SIGIR_QDRANT_KEY") or ""
22
24
 
23
25
  if "qdrant_url_input" not in st.session_state:
24
26
  st.session_state["qdrant_url_input"] = env_url
@@ -136,7 +138,6 @@ def render_sidebar():
136
138
  if target_in_ram != current_in_ram:
137
139
  if st.button("💾 Apply Change", key="admin_apply"):
138
140
  try:
139
- from qdrant_client.models import VectorParamsDiff
140
141
  client.update_collection(
141
142
  collection_name=active,
142
143
  vectors_config={sel_vec: VectorParamsDiff(on_disk=not target_in_ram)}
demo/ui/upload.py CHANGED
@@ -9,6 +9,7 @@ import inspect
9
9
  from datetime import datetime
10
10
  from pathlib import Path
11
11
 
12
+ import numpy as np
12
13
  import streamlit as st
13
14
 
14
15
  from demo.config import AVAILABLE_MODELS
@@ -17,6 +18,10 @@ from demo.qdrant_utils import (
17
18
  get_collection_stats,
18
19
  sample_points_cached,
19
20
  )
21
+ from visual_rag.embedding.visual_embedder import VisualEmbedder
22
+ from visual_rag.indexing.qdrant_indexer import QdrantIndexer
23
+ from visual_rag.indexing.cloudinary_uploader import CloudinaryUploader
24
+ from visual_rag.indexing.pipeline import ProcessingPipeline
20
25
 
21
26
 
22
27
  VECTOR_TYPES = ["initial", "mean_pooling", "experimental_pooling", "global_pooling"]
@@ -251,10 +256,6 @@ def process_pdfs(uploaded_files, config):
251
256
  model_short = model_name.split("/")[-1]
252
257
  model_status.info(f"Loading `{model_short}`...")
253
258
 
254
- import numpy as np
255
- from visual_rag import VisualEmbedder
256
- from visual_rag.indexing import QdrantIndexer, CloudinaryUploader, ProcessingPipeline
257
-
258
259
  output_dtype = np.float16 if vector_dtype == "float16" else np.float32
259
260
  embedder_key = f"{model_name}::{vector_dtype}"
260
261
  embedder = None
visual_rag/__init__.py CHANGED
@@ -31,7 +31,47 @@ Quick Start:
31
31
  Each component works independently - use only what you need.
32
32
  """
33
33
 
34
- __version__ = "0.1.0"
34
+ import logging
35
+
36
+ __version__ = "0.1.3"
37
+
38
+
39
+ def setup_logging(level: str = "INFO", format: str = None) -> None:
40
+ """
41
+ Configure logging for visual_rag package.
42
+
43
+ Args:
44
+ level: Log level ("DEBUG", "INFO", "WARNING", "ERROR")
45
+ format: Custom format string. Default shows time, level, and message.
46
+
47
+ Example:
48
+ >>> import visual_rag
49
+ >>> visual_rag.setup_logging("INFO")
50
+ >>> # Now you'll see processing logs
51
+ """
52
+ if format is None:
53
+ format = "[%(asctime)s] %(levelname)s - %(message)s"
54
+
55
+ logging.basicConfig(
56
+ level=getattr(logging, level.upper(), logging.INFO),
57
+ format=format,
58
+ datefmt="%H:%M:%S",
59
+ )
60
+
61
+ # Also set the visual_rag logger specifically
62
+ logger = logging.getLogger("visual_rag")
63
+ logger.setLevel(getattr(logging, level.upper(), logging.INFO))
64
+
65
+
66
+ # Enable INFO logging by default for visual_rag package and all submodules
67
+ # This ensures logs like "Processing PDF...", "Embedding pages..." are visible
68
+ _logger = logging.getLogger("visual_rag")
69
+ if not _logger.handlers:
70
+ _handler = logging.StreamHandler()
71
+ _handler.setFormatter(logging.Formatter("[%(asctime)s] %(message)s", datefmt="%H:%M:%S"))
72
+ _logger.addHandler(_handler)
73
+ _logger.setLevel(logging.INFO)
74
+ _logger.propagate = False # Don't duplicate to root logger
35
75
 
36
76
  # Import main classes at package level for convenience
37
77
  # These are optional - if dependencies aren't installed, we catch the error
@@ -95,4 +135,6 @@ __all__ = [
95
135
  "load_config",
96
136
  "get",
97
137
  "get_section",
138
+ # Logging
139
+ "setup_logging",
98
140
  ]
visual_rag/config.py CHANGED
@@ -21,16 +21,13 @@ _raw_config_cache_path: Optional[str] = None
21
21
 
22
22
 
23
23
  def _env_qdrant_url() -> Optional[str]:
24
- return os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL")
24
+ """Get Qdrant URL from environment. Prefers QDRANT_URL."""
25
+ return os.getenv("QDRANT_URL") or os.getenv("SIGIR_QDRANT_URL") # legacy fallback
25
26
 
26
27
 
27
28
  def _env_qdrant_api_key() -> Optional[str]:
28
- return (
29
- os.getenv("SIGIR_QDRANT_KEY")
30
- or os.getenv("SIGIR_QDRANT_API_KEY")
31
- or os.getenv("DEST_QDRANT_API_KEY")
32
- or os.getenv("QDRANT_API_KEY")
33
- )
29
+ """Get Qdrant API key from environment. Prefers QDRANT_API_KEY."""
30
+ return os.getenv("QDRANT_API_KEY") or os.getenv("SIGIR_QDRANT_KEY") # legacy fallback
34
31
 
35
32
 
36
33
  def load_config(
@@ -8,10 +8,27 @@ Components:
8
8
  - ProcessingPipeline: End-to-end PDF → Qdrant pipeline
9
9
  """
10
10
 
11
- from visual_rag.indexing.cloudinary_uploader import CloudinaryUploader
12
- from visual_rag.indexing.pdf_processor import PDFProcessor
13
- from visual_rag.indexing.pipeline import ProcessingPipeline
14
- from visual_rag.indexing.qdrant_indexer import QdrantIndexer
11
+ # Lazy imports to avoid failures when optional dependencies aren't installed
12
+
13
+ try:
14
+ from visual_rag.indexing.pdf_processor import PDFProcessor
15
+ except ImportError:
16
+ PDFProcessor = None
17
+
18
+ try:
19
+ from visual_rag.indexing.qdrant_indexer import QdrantIndexer
20
+ except ImportError:
21
+ QdrantIndexer = None
22
+
23
+ try:
24
+ from visual_rag.indexing.cloudinary_uploader import CloudinaryUploader
25
+ except ImportError:
26
+ CloudinaryUploader = None
27
+
28
+ try:
29
+ from visual_rag.indexing.pipeline import ProcessingPipeline
30
+ except ImportError:
31
+ ProcessingPipeline = None
15
32
 
16
33
  __all__ = [
17
34
  "PDFProcessor",
@@ -19,6 +19,23 @@ from urllib.parse import urlparse
19
19
 
20
20
  import numpy as np
21
21
 
22
+ try:
23
+ from qdrant_client import QdrantClient
24
+ from qdrant_client.http import models as qdrant_models
25
+ from qdrant_client.http.models import Distance, VectorParams
26
+ from qdrant_client.models import FieldCondition, Filter, MatchValue
27
+
28
+ QDRANT_AVAILABLE = True
29
+ except ImportError:
30
+ QDRANT_AVAILABLE = False
31
+ QdrantClient = None
32
+ qdrant_models = None
33
+ Distance = None
34
+ VectorParams = None
35
+ FieldCondition = None
36
+ Filter = None
37
+ MatchValue = None
38
+
22
39
  logger = logging.getLogger(__name__)
23
40
 
24
41
 
@@ -58,9 +75,7 @@ class QdrantIndexer:
58
75
  prefer_grpc: bool = False,
59
76
  vector_datatype: str = "float32",
60
77
  ):
61
- try:
62
- from qdrant_client import QdrantClient
63
- except ImportError:
78
+ if not QDRANT_AVAILABLE:
64
79
  raise ImportError(
65
80
  "Qdrant client not installed. "
66
81
  "Install with: pip install visual-rag-toolkit[qdrant]"
@@ -139,9 +154,6 @@ class QdrantIndexer:
139
154
  Returns:
140
155
  True if created, False if already existed
141
156
  """
142
- from qdrant_client.http import models
143
- from qdrant_client.http.models import Distance, VectorParams
144
-
145
157
  if self.collection_exists():
146
158
  if force_recreate:
147
159
  logger.info(f"🗑️ Deleting existing collection: {self.collection_name}")
@@ -153,15 +165,15 @@ class QdrantIndexer:
153
165
  logger.info(f"📦 Creating collection: {self.collection_name}")
154
166
 
155
167
  # Multi-vector config for ColBERT-style MaxSim
156
- multivector_config = models.MultiVectorConfig(
157
- comparator=models.MultiVectorComparator.MAX_SIM
168
+ multivector_config = qdrant_models.MultiVectorConfig(
169
+ comparator=qdrant_models.MultiVectorComparator.MAX_SIM
158
170
  )
159
171
 
160
172
  # Vector configs - simplified for compatibility
161
173
  datatype = (
162
- models.Datatype.FLOAT16
174
+ qdrant_models.Datatype.FLOAT16
163
175
  if self.vector_datatype == "float16"
164
- else models.Datatype.FLOAT32
176
+ else qdrant_models.Datatype.FLOAT32
165
177
  )
166
178
  vectors_config = {
167
179
  "initial": VectorParams(
@@ -198,6 +210,18 @@ class QdrantIndexer:
198
210
  vectors_config=vectors_config,
199
211
  )
200
212
 
213
+ # Create required payload index for skip_existing functionality
214
+ # This index is needed for filtering by filename when checking existing docs
215
+ try:
216
+ self.client.create_payload_index(
217
+ collection_name=self.collection_name,
218
+ field_name="filename",
219
+ field_schema=qdrant_models.PayloadSchemaType.KEYWORD,
220
+ )
221
+ logger.info(" 📇 Created payload index: filename")
222
+ except Exception as e:
223
+ logger.warning(f" ⚠️ Could not create filename index: {e}")
224
+
201
225
  logger.info(f"✅ Collection created: {self.collection_name}")
202
226
  return True
203
227
 
@@ -212,14 +236,12 @@ class QdrantIndexer:
212
236
  fields: List of {field, type} dicts
213
237
  type can be: integer, keyword, bool, float, text
214
238
  """
215
- from qdrant_client.http import models
216
-
217
239
  type_mapping = {
218
- "integer": models.PayloadSchemaType.INTEGER,
219
- "keyword": models.PayloadSchemaType.KEYWORD,
220
- "bool": models.PayloadSchemaType.BOOL,
221
- "float": models.PayloadSchemaType.FLOAT,
222
- "text": models.PayloadSchemaType.TEXT,
240
+ "integer": qdrant_models.PayloadSchemaType.INTEGER,
241
+ "keyword": qdrant_models.PayloadSchemaType.KEYWORD,
242
+ "bool": qdrant_models.PayloadSchemaType.BOOL,
243
+ "float": qdrant_models.PayloadSchemaType.FLOAT,
244
+ "text": qdrant_models.PayloadSchemaType.TEXT,
223
245
  }
224
246
 
225
247
  if not fields:
@@ -230,7 +252,7 @@ class QdrantIndexer:
230
252
  for field_config in fields:
231
253
  field_name = field_config["field"]
232
254
  field_type_str = field_config.get("type", "keyword")
233
- field_type = type_mapping.get(field_type_str, models.PayloadSchemaType.KEYWORD)
255
+ field_type = type_mapping.get(field_type_str, qdrant_models.PayloadSchemaType.KEYWORD)
234
256
 
235
257
  try:
236
258
  self.client.create_payload_index(
@@ -271,8 +293,6 @@ class QdrantIndexer:
271
293
  Returns:
272
294
  Number of successfully uploaded points
273
295
  """
274
- from qdrant_client.http import models
275
-
276
296
  if not points:
277
297
  return 0
278
298
 
@@ -315,8 +335,8 @@ class QdrantIndexer:
315
335
  return val.tolist()
316
336
  return val
317
337
 
318
- def _build_qdrant_points(batch_points: List[Dict[str, Any]]) -> List[models.PointStruct]:
319
- qdrant_points: List[models.PointStruct] = []
338
+ def _build_qdrant_points(batch_points: List[Dict[str, Any]]) -> List[qdrant_models.PointStruct]:
339
+ qdrant_points: List[qdrant_models.PointStruct] = []
320
340
  for p in batch_points:
321
341
  global_pooled = p.get("global_pooled_embedding")
322
342
  if global_pooled is None:
@@ -336,7 +356,7 @@ class QdrantIndexer:
336
356
  global_pooling = global_pooled.astype(self._np_vector_dtype, copy=False)
337
357
 
338
358
  qdrant_points.append(
339
- models.PointStruct(
359
+ qdrant_models.PointStruct(
340
360
  id=p["id"],
341
361
  vector={
342
362
  "initial": _to_list(initial),
@@ -361,6 +381,8 @@ class QdrantIndexer:
361
381
  wait=wait,
362
382
  )
363
383
 
384
+ logger.info(f" ✅ Uploaded {len(points)} points to Qdrant")
385
+
364
386
  if delay_between_batches > 0:
365
387
  if _is_cancelled():
366
388
  return 0
@@ -413,32 +435,60 @@ class QdrantIndexer:
413
435
  return False
414
436
 
415
437
  def get_existing_ids(self, filename: str) -> Set[str]:
416
- """Get all point IDs for a specific file."""
417
- from qdrant_client.models import FieldCondition, Filter, MatchValue
438
+ """Get all point IDs for a specific file.
418
439
 
440
+ Requires a payload index on 'filename' field. If the index doesn't exist,
441
+ this method will attempt to create it automatically.
442
+ """
419
443
  existing_ids = set()
420
444
  offset = None
421
445
 
422
- while True:
423
- results = self.client.scroll(
424
- collection_name=self.collection_name,
425
- scroll_filter=Filter(
426
- must=[FieldCondition(key="filename", match=MatchValue(value=filename))]
427
- ),
428
- limit=100,
429
- offset=offset,
430
- with_payload=["page_number"],
431
- with_vectors=False,
432
- )
446
+ try:
447
+ while True:
448
+ results = self.client.scroll(
449
+ collection_name=self.collection_name,
450
+ scroll_filter=Filter(
451
+ must=[FieldCondition(key="filename", match=MatchValue(value=filename))]
452
+ ),
453
+ limit=100,
454
+ offset=offset,
455
+ with_payload=["page_number"],
456
+ with_vectors=False,
457
+ )
458
+
459
+ points, next_offset = results
433
460
 
434
- points, next_offset = results
461
+ for point in points:
462
+ existing_ids.add(str(point.id))
435
463
 
436
- for point in points:
437
- existing_ids.add(str(point.id))
464
+ if next_offset is None or len(points) == 0:
465
+ break
466
+ offset = next_offset
438
467
 
439
- if next_offset is None or len(points) == 0:
440
- break
441
- offset = next_offset
468
+ except Exception as e:
469
+ error_msg = str(e).lower()
470
+ if "index required" in error_msg or "index" in error_msg and "filename" in error_msg:
471
+ # Missing payload index - try to create it
472
+ logger.warning(
473
+ "⚠️ Missing 'filename' payload index. Creating it now... "
474
+ "(skip_existing requires this index for filtering)"
475
+ )
476
+ try:
477
+ self.client.create_payload_index(
478
+ collection_name=self.collection_name,
479
+ field_name="filename",
480
+ field_schema=qdrant_models.PayloadSchemaType.KEYWORD,
481
+ )
482
+ logger.info(" ✅ Created 'filename' index. Retrying query...")
483
+ # Retry the query
484
+ return self.get_existing_ids(filename)
485
+ except Exception as idx_err:
486
+ logger.warning(f" ❌ Could not create index: {idx_err}")
487
+ logger.warning(" Returning empty set - all pages will be processed")
488
+ return set()
489
+ else:
490
+ logger.warning(f"⚠️ Error checking existing IDs: {e}")
491
+ return set()
442
492
 
443
493
  return existing_ids
444
494
 
@@ -2,6 +2,25 @@ import os
2
2
  from typing import Any, Dict, List, Optional
3
3
  from urllib.parse import urlparse
4
4
 
5
+ import numpy as np
6
+ import torch
7
+
8
+ try:
9
+ from dotenv import load_dotenv
10
+
11
+ DOTENV_AVAILABLE = True
12
+ except ImportError:
13
+ DOTENV_AVAILABLE = False
14
+ load_dotenv = None
15
+
16
+ try:
17
+ from qdrant_client import QdrantClient
18
+
19
+ QDRANT_AVAILABLE = True
20
+ except ImportError:
21
+ QDRANT_AVAILABLE = False
22
+ QdrantClient = None
23
+
5
24
  from visual_rag.embedding.visual_embedder import VisualEmbedder
6
25
  from visual_rag.retrieval.single_stage import SingleStageRetriever
7
26
  from visual_rag.retrieval.three_stage import ThreeStageRetriever
@@ -11,9 +30,7 @@ from visual_rag.retrieval.two_stage import TwoStageRetriever
11
30
  class MultiVectorRetriever:
12
31
  @staticmethod
13
32
  def _maybe_load_dotenv() -> None:
14
- try:
15
- from dotenv import load_dotenv
16
- except ImportError:
33
+ if not DOTENV_AVAILABLE:
17
34
  return
18
35
  if os.path.exists(".env"):
19
36
  load_dotenv(".env")
@@ -33,87 +50,84 @@ class MultiVectorRetriever:
33
50
  ):
34
51
  if qdrant_client is None:
35
52
  self._maybe_load_dotenv()
36
- try:
37
- from qdrant_client import QdrantClient
38
- except ImportError as e:
53
+ if not QDRANT_AVAILABLE:
39
54
  raise ImportError(
40
55
  "Qdrant client not installed. Install with: pip install visual-rag-toolkit[qdrant]"
41
- ) from e
56
+ )
42
57
 
43
58
  qdrant_url = (
44
59
  qdrant_url
45
- or os.getenv("SIGIR_QDRANT_URL")
46
- or os.getenv("DEST_QDRANT_URL")
47
60
  or os.getenv("QDRANT_URL")
61
+ or os.getenv("SIGIR_QDRANT_URL") # legacy
48
62
  )
49
63
  if not qdrant_url:
50
64
  raise ValueError(
51
- "QDRANT_URL is required (pass qdrant_url or set env var). "
52
- "You can also set DEST_QDRANT_URL to override."
65
+ "QDRANT_URL is required (pass qdrant_url or set env var)."
53
66
  )
54
67
 
55
68
  qdrant_api_key = (
56
69
  qdrant_api_key
57
- or os.getenv("SIGIR_QDRANT_KEY")
58
- or os.getenv("SIGIR_QDRANT_API_KEY")
59
- or os.getenv("DEST_QDRANT_API_KEY")
60
70
  or os.getenv("QDRANT_API_KEY")
71
+ or os.getenv("SIGIR_QDRANT_KEY") # legacy
61
72
  )
62
73
 
63
74
  grpc_port = None
64
75
  if prefer_grpc:
65
76
  try:
66
- if urlparse(qdrant_url).port == 6333:
77
+ parsed = urlparse(qdrant_url)
78
+ port = parsed.port
79
+ if port == 6333:
67
80
  grpc_port = 6334
68
81
  except Exception:
69
- grpc_port = None
82
+ pass
70
83
 
71
84
  def _make_client(use_grpc: bool):
72
85
  return QdrantClient(
73
86
  url=qdrant_url,
74
87
  api_key=qdrant_api_key,
88
+ timeout=request_timeout,
75
89
  prefer_grpc=bool(use_grpc),
76
90
  grpc_port=grpc_port,
77
- timeout=int(request_timeout),
78
91
  check_compatibility=False,
79
92
  )
80
93
 
81
- qdrant_client = _make_client(prefer_grpc)
94
+ client = _make_client(prefer_grpc)
82
95
  if prefer_grpc:
83
96
  try:
84
- _ = qdrant_client.get_collections()
97
+ _ = client.get_collections()
85
98
  except Exception as e:
86
99
  msg = str(e)
87
- if (
88
- "StatusCode.PERMISSION_DENIED" in msg
89
- or "http2 header with status: 403" in msg
90
- ):
91
- qdrant_client = _make_client(False)
100
+ if "StatusCode.PERMISSION_DENIED" in msg or "http2 header with status: 403" in msg:
101
+ client = _make_client(False)
92
102
  else:
93
103
  raise
104
+ qdrant_client = client
94
105
 
95
106
  self.client = qdrant_client
96
107
  self.collection_name = collection_name
108
+
97
109
  self.embedder = embedder or VisualEmbedder(model_name=model_name)
98
110
 
99
111
  self._two_stage = TwoStageRetriever(
100
- self.client,
101
- collection_name=self.collection_name,
102
- request_timeout=int(request_timeout),
103
- max_retries=int(max_retries),
104
- retry_sleep=float(retry_sleep),
112
+ qdrant_client=qdrant_client,
113
+ collection_name=collection_name,
114
+ request_timeout=request_timeout,
115
+ max_retries=max_retries,
116
+ retry_sleep=retry_sleep,
105
117
  )
106
118
  self._three_stage = ThreeStageRetriever(
107
- self.client,
108
- collection_name=self.collection_name,
109
- request_timeout=int(request_timeout),
110
- max_retries=int(max_retries),
111
- retry_sleep=float(retry_sleep),
119
+ qdrant_client=qdrant_client,
120
+ collection_name=collection_name,
121
+ request_timeout=request_timeout,
122
+ max_retries=max_retries,
123
+ retry_sleep=retry_sleep,
112
124
  )
113
125
  self._single_stage = SingleStageRetriever(
114
- self.client,
115
- collection_name=self.collection_name,
116
- request_timeout=int(request_timeout),
126
+ qdrant_client=qdrant_client,
127
+ collection_name=collection_name,
128
+ request_timeout=request_timeout,
129
+ max_retries=max_retries,
130
+ retry_sleep=retry_sleep,
117
131
  )
118
132
 
119
133
  def build_filter(
@@ -143,14 +157,10 @@ class MultiVectorRetriever:
143
157
  return_embeddings: bool = False,
144
158
  ) -> List[Dict[str, Any]]:
145
159
  q = self.embedder.embed_query(query)
146
- try:
147
- import torch
148
- except ImportError:
149
- torch = None
150
- if torch is not None and isinstance(q, torch.Tensor):
160
+ if isinstance(q, torch.Tensor):
151
161
  query_embedding = q.detach().cpu().numpy()
152
162
  else:
153
- query_embedding = q.numpy()
163
+ query_embedding = np.asarray(q)
154
164
 
155
165
  return self.search_embedded(
156
166
  query_embedding=query_embedding,
@@ -179,27 +189,17 @@ class MultiVectorRetriever:
179
189
  return self._single_stage.search(
180
190
  query_embedding=query_embedding,
181
191
  top_k=top_k,
182
- strategy="multi_vector",
183
192
  filter_obj=filter_obj,
193
+ using="initial",
184
194
  )
185
-
186
- if mode == "single_tiles":
195
+ elif mode == "single_pooled":
187
196
  return self._single_stage.search(
188
197
  query_embedding=query_embedding,
189
198
  top_k=top_k,
190
- strategy="tiles_maxsim",
191
199
  filter_obj=filter_obj,
200
+ using="mean_pooling",
192
201
  )
193
-
194
- if mode == "single_global":
195
- return self._single_stage.search(
196
- query_embedding=query_embedding,
197
- top_k=top_k,
198
- strategy="pooled_global",
199
- filter_obj=filter_obj,
200
- )
201
-
202
- if mode == "two_stage":
202
+ elif mode == "two_stage":
203
203
  return self._two_stage.search_server_side(
204
204
  query_embedding=query_embedding,
205
205
  top_k=top_k,
@@ -207,16 +207,14 @@ class MultiVectorRetriever:
207
207
  filter_obj=filter_obj,
208
208
  stage1_mode=stage1_mode,
209
209
  )
210
-
211
- if mode == "three_stage":
212
- s1 = int(stage1_k) if stage1_k is not None else 1000
213
- s2 = int(stage2_k) if stage2_k is not None else 300
210
+ elif mode == "three_stage":
214
211
  return self._three_stage.search_server_side(
215
212
  query_embedding=query_embedding,
216
213
  top_k=top_k,
217
- stage1_k=s1,
218
- stage2_k=s2,
214
+ stage1_k=stage1_k,
215
+ stage2_k=stage2_k,
219
216
  filter_obj=filter_obj,
217
+ stage1_mode=stage1_mode,
220
218
  )
221
-
222
- raise ValueError(f"Unknown mode: {mode}")
219
+ else:
220
+ raise ValueError(f"Unknown mode: {mode}")
@@ -30,6 +30,9 @@ class SingleStageRetriever:
30
30
  Args:
31
31
  qdrant_client: Connected Qdrant client
32
32
  collection_name: Name of the Qdrant collection
33
+ request_timeout: Timeout for Qdrant requests (seconds)
34
+ max_retries: Number of retry attempts on failure
35
+ retry_sleep: Sleep time between retries (seconds)
33
36
 
34
37
  Example:
35
38
  >>> retriever = SingleStageRetriever(client, "my_collection")
@@ -41,10 +44,14 @@ class SingleStageRetriever:
41
44
  qdrant_client,
42
45
  collection_name: str,
43
46
  request_timeout: int = 120,
47
+ max_retries: int = 3,
48
+ retry_sleep: float = 1.0,
44
49
  ):
45
50
  self.client = qdrant_client
46
51
  self.collection_name = collection_name
47
52
  self.request_timeout = int(request_timeout)
53
+ self.max_retries = max_retries
54
+ self.retry_sleep = retry_sleep
48
55
 
49
56
  def search(
50
57
  self,