visual-rag-toolkit 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
demo/ui/playground.py CHANGED
@@ -9,6 +9,7 @@ from demo.qdrant_utils import (
9
9
  sample_points_cached,
10
10
  search_collection,
11
11
  )
12
+ from visual_rag.retrieval import MultiVectorRetriever
12
13
 
13
14
 
14
15
  def render_playground_tab():
@@ -46,7 +47,6 @@ def render_playground_tab():
46
47
  if not st.session_state.get("model_loaded"):
47
48
  with st.spinner(f"Loading {model_short}..."):
48
49
  try:
49
- from visual_rag.retrieval import MultiVectorRetriever
50
50
  _ = MultiVectorRetriever(collection_name=active_collection, model_name=model_name)
51
51
  st.session_state["model_loaded"] = True
52
52
  st.session_state["loaded_model_key"] = cache_key
demo/ui/sidebar.py CHANGED
@@ -3,6 +3,8 @@
3
3
  import os
4
4
  import streamlit as st
5
5
 
6
+ from qdrant_client.models import VectorParamsDiff
7
+
6
8
  from demo.qdrant_utils import (
7
9
  get_qdrant_credentials,
8
10
  init_qdrant_client_with_creds,
@@ -14,11 +16,33 @@ from demo.qdrant_utils import (
14
16
 
15
17
 
16
18
  def render_sidebar():
19
+ # CSS to make sidebar metrics smaller
20
+ st.markdown("""
21
+ <style>
22
+ /* Smaller metrics in sidebar */
23
+ [data-testid="stSidebar"] [data-testid="stMetricValue"] {
24
+ font-size: 1.2rem !important;
25
+ }
26
+ [data-testid="stSidebar"] [data-testid="stMetricLabel"] {
27
+ font-size: 0.75rem !important;
28
+ }
29
+ /* Smaller expander headers in sidebar */
30
+ [data-testid="stSidebar"] [data-testid="stExpander"] summary {
31
+ font-size: 0.9rem !important;
32
+ }
33
+ /* Compact subheaders */
34
+ [data-testid="stSidebar"] h3 {
35
+ font-size: 1rem !important;
36
+ margin-bottom: 0.5rem !important;
37
+ }
38
+ </style>
39
+ """, unsafe_allow_html=True)
40
+
17
41
  with st.sidebar:
18
42
  st.subheader("🔑 Qdrant Credentials")
19
43
 
20
- env_url = os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL") or ""
21
- env_key = os.getenv("SIGIR_QDRANT_KEY") or os.getenv("SIGIR_QDRANT_API_KEY") or os.getenv("DEST_QDRANT_API_KEY") or os.getenv("QDRANT_API_KEY") or ""
44
+ env_url = os.getenv("QDRANT_URL") or os.getenv("SIGIR_QDRANT_URL") or ""
45
+ env_key = os.getenv("QDRANT_API_KEY") or os.getenv("SIGIR_QDRANT_KEY") or ""
22
46
 
23
47
  if "qdrant_url_input" not in st.session_state:
24
48
  st.session_state["qdrant_url_input"] = env_url
@@ -136,7 +160,6 @@ def render_sidebar():
136
160
  if target_in_ram != current_in_ram:
137
161
  if st.button("💾 Apply Change", key="admin_apply"):
138
162
  try:
139
- from qdrant_client.models import VectorParamsDiff
140
163
  client.update_collection(
141
164
  collection_name=active,
142
165
  vectors_config={sel_vec: VectorParamsDiff(on_disk=not target_in_ram)}
demo/ui/upload.py CHANGED
@@ -9,6 +9,7 @@ import inspect
9
9
  from datetime import datetime
10
10
  from pathlib import Path
11
11
 
12
+ import numpy as np
12
13
  import streamlit as st
13
14
 
14
15
  from demo.config import AVAILABLE_MODELS
@@ -17,6 +18,10 @@ from demo.qdrant_utils import (
17
18
  get_collection_stats,
18
19
  sample_points_cached,
19
20
  )
21
+ from visual_rag.embedding.visual_embedder import VisualEmbedder
22
+ from visual_rag.indexing.qdrant_indexer import QdrantIndexer
23
+ from visual_rag.indexing.cloudinary_uploader import CloudinaryUploader
24
+ from visual_rag.indexing.pipeline import ProcessingPipeline
20
25
 
21
26
 
22
27
  VECTOR_TYPES = ["initial", "mean_pooling", "experimental_pooling", "global_pooling"]
@@ -251,10 +256,6 @@ def process_pdfs(uploaded_files, config):
251
256
  model_short = model_name.split("/")[-1]
252
257
  model_status.info(f"Loading `{model_short}`...")
253
258
 
254
- import numpy as np
255
- from visual_rag import VisualEmbedder
256
- from visual_rag.indexing import QdrantIndexer, CloudinaryUploader, ProcessingPipeline
257
-
258
259
  output_dtype = np.float16 if vector_dtype == "float16" else np.float32
259
260
  embedder_key = f"{model_name}::{vector_dtype}"
260
261
  embedder = None
@@ -448,7 +449,7 @@ def process_pdfs(uploaded_files, config):
448
449
 
449
450
  if total_uploaded > 0:
450
451
  st.session_state["upload_success"] = f"Uploaded {total_uploaded} pages to {collection_name}"
451
- st.balloons()
452
+ st.rerun() # Immediately refresh to show success toast + balloons
452
453
 
453
454
  except Exception as e:
454
455
  st.error(f"❌ Processing error: {e}")
visual_rag/__init__.py CHANGED
@@ -31,7 +31,47 @@ Quick Start:
31
31
  Each component works independently - use only what you need.
32
32
  """
33
33
 
34
- __version__ = "0.1.0"
34
+ import logging
35
+
36
+ __version__ = "0.1.4"
37
+
38
+
39
def setup_logging(level: str = "INFO", format: str = None) -> None:
    """
    Configure logging for the visual_rag package.

    Args:
        level: Log level name ("DEBUG", "INFO", "WARNING", "ERROR"),
            case-insensitive. A numeric level is accepted as well.
            Anything that does not resolve to a level falls back to INFO.
        format: Custom format string. When omitted, the root logger is
            configured to show time, level, and message.

    Example:
        >>> import visual_rag
        >>> visual_rag.setup_logging("INFO")
        >>> # Now you'll see processing logs
    """
    # Resolve the level once. getattr() alone is not enough: names such as
    # "basic_format" resolve to non-level attributes of the logging module.
    if isinstance(level, int):
        resolved_level = level
    else:
        candidate = getattr(logging, str(level).upper(), None)
        resolved_level = candidate if isinstance(candidate, int) else logging.INFO

    custom_format = format is not None
    log_format = format if custom_format else "[%(asctime)s] %(levelname)s - %(message)s"

    # Configures the root logger only (and only if it has no handlers yet).
    logging.basicConfig(
        level=resolved_level,
        format=log_format,
        datefmt="%H:%M:%S",
    )

    # Also set the visual_rag logger specifically
    package_logger = logging.getLogger("visual_rag")
    package_logger.setLevel(resolved_level)

    # basicConfig() never touches handlers attached directly to the package
    # logger, so an explicitly requested format has to be applied to them here;
    # otherwise the caller's format string would be silently ignored.
    if custom_format:
        formatter = logging.Formatter(log_format, datefmt="%H:%M:%S")
        for handler in package_logger.handlers:
            handler.setFormatter(formatter)
64
+
65
+
66
# Give the "visual_rag" logger (and therefore every submodule logger below it)
# a console handler at INFO level, unless handlers were already attached.
# Progress messages such as "Processing PDF...", "Embedding pages..." then show
# up without any logging setup on the caller's side.
_logger = logging.getLogger("visual_rag")
if not _logger.handlers:
    _handler = logging.StreamHandler()
    _handler.setFormatter(
        logging.Formatter(
            fmt="[%(asctime)s] %(message)s",
            datefmt="%H:%M:%S",
        )
    )
    _logger.addHandler(_handler)
    # Keep records away from the root logger so they are not printed twice.
    _logger.propagate = False
    _logger.setLevel(logging.INFO)
35
75
 
36
76
  # Import main classes at package level for convenience
37
77
  # These are optional - if dependencies aren't installed, we catch the error
@@ -71,13 +111,16 @@ try:
71
111
  except ImportError:
72
112
  QdrantAdmin = None
73
113
 
74
- try:
75
- from visual_rag.demo_runner import demo
76
- except ImportError:
77
- demo = None
114
+ # demo is lazily imported to avoid RuntimeWarning when running as __main__
115
+ # Access via visual_rag.demo() which triggers __getattr__
78
116
 
79
117
  # Config utilities (each name is None if visual_rag.config cannot be imported)
80
- from visual_rag.config import get, get_section, load_config
118
+ try:
119
+ from visual_rag.config import get, get_section, load_config
120
+ except ImportError:
121
+ get = None
122
+ get_section = None
123
+ load_config = None
81
124
 
82
125
  __all__ = [
83
126
  # Version
@@ -95,4 +138,18 @@ __all__ = [
95
138
  "load_config",
96
139
  "get",
97
140
  "get_section",
141
+ # Logging
142
+ "setup_logging",
98
143
  ]
144
+
145
+
146
def __getattr__(name: str):
    """Resolve the `demo` attribute on first access instead of at import time.

    Deferring the import avoids a RuntimeWarning when the package is run as
    __main__. Returns the `demo` callable from `visual_rag.demo_runner`, or
    None when that module cannot be imported.

    Raises:
        AttributeError: for any attribute name other than "demo".
    """
    if name != "demo":
        raise AttributeError(f"module 'visual_rag' has no attribute {name!r}")

    try:
        from visual_rag.demo_runner import demo as demo_entry
    except ImportError:
        # Demo dependencies are optional; signal "unavailable" with None.
        return None
    return demo_entry
visual_rag/config.py CHANGED
@@ -21,16 +21,13 @@ _raw_config_cache_path: Optional[str] = None
21
21
 
22
22
 
23
23
def _env_qdrant_url() -> Optional[str]:
    """Read the Qdrant URL from the environment.

    QDRANT_URL wins when it holds a non-empty value; otherwise whatever
    SIGIR_QDRANT_URL contains is returned (None when it is unset).
    """
    preferred = os.getenv("QDRANT_URL")
    if preferred:
        return preferred
    # legacy fallback
    return os.getenv("SIGIR_QDRANT_URL")
25
26
 
26
27
 
27
28
def _env_qdrant_api_key() -> Optional[str]:
    """Read the Qdrant API key from the environment.

    QDRANT_API_KEY wins when it holds a non-empty value; otherwise whatever
    SIGIR_QDRANT_KEY contains is returned (None when it is unset).
    """
    preferred = os.getenv("QDRANT_API_KEY")
    if preferred:
        return preferred
    # legacy fallback
    return os.getenv("SIGIR_QDRANT_KEY")
34
31
 
35
32
 
36
33
  def load_config(
visual_rag/demo_runner.py CHANGED
@@ -52,13 +52,11 @@ def demo(
52
52
  cmd = [sys.executable, "-m", "streamlit", "run", str(app_path)]
53
53
  cmd += ["--server.address", str(host)]
54
54
  cmd += ["--server.port", str(int(port))]
55
- cmd += ["--server.headless", "true" if headless else "false"]
55
+ # headless=true prevents browser from auto-opening; open_browser overrides
56
+ should_be_headless = headless and not open_browser
57
+ cmd += ["--server.headless", "true" if should_be_headless else "false"]
56
58
  cmd += ["--browser.gatherUsageStats", "false"]
57
59
  cmd += ["--server.runOnSave", "false"]
58
- cmd += ["--browser.serverAddress", str(host)]
59
- if not open_browser:
60
- cmd += ["--browser.serverPort", str(int(port))]
61
- cmd += ["--browser.open", "false"]
62
60
 
63
61
  if extra_args:
64
62
  cmd += list(extra_args)
@@ -8,10 +8,27 @@ Components:
8
8
  - ProcessingPipeline: End-to-end PDF → Qdrant pipeline
9
9
  """
10
10
 
11
- from visual_rag.indexing.cloudinary_uploader import CloudinaryUploader
12
- from visual_rag.indexing.pdf_processor import PDFProcessor
13
- from visual_rag.indexing.pipeline import ProcessingPipeline
14
- from visual_rag.indexing.qdrant_indexer import QdrantIndexer
11
+ # Guarded imports: a component whose optional dependency is missing is exported as None instead of failing the package import
12
+
13
+ try:
14
+ from visual_rag.indexing.pdf_processor import PDFProcessor
15
+ except ImportError:
16
+ PDFProcessor = None
17
+
18
+ try:
19
+ from visual_rag.indexing.qdrant_indexer import QdrantIndexer
20
+ except ImportError:
21
+ QdrantIndexer = None
22
+
23
+ try:
24
+ from visual_rag.indexing.cloudinary_uploader import CloudinaryUploader
25
+ except ImportError:
26
+ CloudinaryUploader = None
27
+
28
+ try:
29
+ from visual_rag.indexing.pipeline import ProcessingPipeline
30
+ except ImportError:
31
+ ProcessingPipeline = None
15
32
 
16
33
  __all__ = [
17
34
  "PDFProcessor",
@@ -19,6 +19,23 @@ from urllib.parse import urlparse
19
19
 
20
20
  import numpy as np
21
21
 
22
+ try:
23
+ from qdrant_client import QdrantClient
24
+ from qdrant_client.http import models as qdrant_models
25
+ from qdrant_client.http.models import Distance, VectorParams
26
+ from qdrant_client.models import FieldCondition, Filter, MatchValue
27
+
28
+ QDRANT_AVAILABLE = True
29
+ except ImportError:
30
+ QDRANT_AVAILABLE = False
31
+ QdrantClient = None
32
+ qdrant_models = None
33
+ Distance = None
34
+ VectorParams = None
35
+ FieldCondition = None
36
+ Filter = None
37
+ MatchValue = None
38
+
22
39
  logger = logging.getLogger(__name__)
23
40
 
24
41
 
@@ -58,9 +75,7 @@ class QdrantIndexer:
58
75
  prefer_grpc: bool = False,
59
76
  vector_datatype: str = "float32",
60
77
  ):
61
- try:
62
- from qdrant_client import QdrantClient
63
- except ImportError:
78
+ if not QDRANT_AVAILABLE:
64
79
  raise ImportError(
65
80
  "Qdrant client not installed. "
66
81
  "Install with: pip install visual-rag-toolkit[qdrant]"
@@ -139,9 +154,6 @@ class QdrantIndexer:
139
154
  Returns:
140
155
  True if created, False if already existed
141
156
  """
142
- from qdrant_client.http import models
143
- from qdrant_client.http.models import Distance, VectorParams
144
-
145
157
  if self.collection_exists():
146
158
  if force_recreate:
147
159
  logger.info(f"🗑️ Deleting existing collection: {self.collection_name}")
@@ -153,15 +165,15 @@ class QdrantIndexer:
153
165
  logger.info(f"📦 Creating collection: {self.collection_name}")
154
166
 
155
167
  # Multi-vector config for ColBERT-style MaxSim
156
- multivector_config = models.MultiVectorConfig(
157
- comparator=models.MultiVectorComparator.MAX_SIM
168
+ multivector_config = qdrant_models.MultiVectorConfig(
169
+ comparator=qdrant_models.MultiVectorComparator.MAX_SIM
158
170
  )
159
171
 
160
172
  # Vector configs - simplified for compatibility
161
173
  datatype = (
162
- models.Datatype.FLOAT16
174
+ qdrant_models.Datatype.FLOAT16
163
175
  if self.vector_datatype == "float16"
164
- else models.Datatype.FLOAT32
176
+ else qdrant_models.Datatype.FLOAT32
165
177
  )
166
178
  vectors_config = {
167
179
  "initial": VectorParams(
@@ -198,6 +210,18 @@ class QdrantIndexer:
198
210
  vectors_config=vectors_config,
199
211
  )
200
212
 
213
+ # Create required payload index for skip_existing functionality
214
+ # This index is needed for filtering by filename when checking existing docs
215
+ try:
216
+ self.client.create_payload_index(
217
+ collection_name=self.collection_name,
218
+ field_name="filename",
219
+ field_schema=qdrant_models.PayloadSchemaType.KEYWORD,
220
+ )
221
+ logger.info(" 📇 Created payload index: filename")
222
+ except Exception as e:
223
+ logger.warning(f" ⚠️ Could not create filename index: {e}")
224
+
201
225
  logger.info(f"✅ Collection created: {self.collection_name}")
202
226
  return True
203
227
 
@@ -212,14 +236,12 @@ class QdrantIndexer:
212
236
  fields: List of {field, type} dicts
213
237
  type can be: integer, keyword, bool, float, text
214
238
  """
215
- from qdrant_client.http import models
216
-
217
239
  type_mapping = {
218
- "integer": models.PayloadSchemaType.INTEGER,
219
- "keyword": models.PayloadSchemaType.KEYWORD,
220
- "bool": models.PayloadSchemaType.BOOL,
221
- "float": models.PayloadSchemaType.FLOAT,
222
- "text": models.PayloadSchemaType.TEXT,
240
+ "integer": qdrant_models.PayloadSchemaType.INTEGER,
241
+ "keyword": qdrant_models.PayloadSchemaType.KEYWORD,
242
+ "bool": qdrant_models.PayloadSchemaType.BOOL,
243
+ "float": qdrant_models.PayloadSchemaType.FLOAT,
244
+ "text": qdrant_models.PayloadSchemaType.TEXT,
223
245
  }
224
246
 
225
247
  if not fields:
@@ -230,7 +252,7 @@ class QdrantIndexer:
230
252
  for field_config in fields:
231
253
  field_name = field_config["field"]
232
254
  field_type_str = field_config.get("type", "keyword")
233
- field_type = type_mapping.get(field_type_str, models.PayloadSchemaType.KEYWORD)
255
+ field_type = type_mapping.get(field_type_str, qdrant_models.PayloadSchemaType.KEYWORD)
234
256
 
235
257
  try:
236
258
  self.client.create_payload_index(
@@ -271,8 +293,6 @@ class QdrantIndexer:
271
293
  Returns:
272
294
  Number of successfully uploaded points
273
295
  """
274
- from qdrant_client.http import models
275
-
276
296
  if not points:
277
297
  return 0
278
298
 
@@ -315,8 +335,10 @@ class QdrantIndexer:
315
335
  return val.tolist()
316
336
  return val
317
337
 
318
- def _build_qdrant_points(batch_points: List[Dict[str, Any]]) -> List[models.PointStruct]:
319
- qdrant_points: List[models.PointStruct] = []
338
+ def _build_qdrant_points(
339
+ batch_points: List[Dict[str, Any]],
340
+ ) -> List[qdrant_models.PointStruct]:
341
+ qdrant_points: List[qdrant_models.PointStruct] = []
320
342
  for p in batch_points:
321
343
  global_pooled = p.get("global_pooled_embedding")
322
344
  if global_pooled is None:
@@ -336,7 +358,7 @@ class QdrantIndexer:
336
358
  global_pooling = global_pooled.astype(self._np_vector_dtype, copy=False)
337
359
 
338
360
  qdrant_points.append(
339
- models.PointStruct(
361
+ qdrant_models.PointStruct(
340
362
  id=p["id"],
341
363
  vector={
342
364
  "initial": _to_list(initial),
@@ -361,6 +383,8 @@ class QdrantIndexer:
361
383
  wait=wait,
362
384
  )
363
385
 
386
+ logger.info(f" ✅ Uploaded {len(points)} points to Qdrant")
387
+
364
388
  if delay_between_batches > 0:
365
389
  if _is_cancelled():
366
390
  return 0
@@ -413,32 +437,60 @@ class QdrantIndexer:
413
437
  return False
414
438
 
415
439
  def get_existing_ids(self, filename: str) -> Set[str]:
416
- """Get all point IDs for a specific file."""
417
- from qdrant_client.models import FieldCondition, Filter, MatchValue
440
+ """Get all point IDs for a specific file.
418
441
 
442
+ Requires a payload index on 'filename' field. If the index doesn't exist,
443
+ this method will attempt to create it automatically.
444
+ """
419
445
  existing_ids = set()
420
446
  offset = None
421
447
 
422
- while True:
423
- results = self.client.scroll(
424
- collection_name=self.collection_name,
425
- scroll_filter=Filter(
426
- must=[FieldCondition(key="filename", match=MatchValue(value=filename))]
427
- ),
428
- limit=100,
429
- offset=offset,
430
- with_payload=["page_number"],
431
- with_vectors=False,
432
- )
448
+ try:
449
+ while True:
450
+ results = self.client.scroll(
451
+ collection_name=self.collection_name,
452
+ scroll_filter=Filter(
453
+ must=[FieldCondition(key="filename", match=MatchValue(value=filename))]
454
+ ),
455
+ limit=100,
456
+ offset=offset,
457
+ with_payload=["page_number"],
458
+ with_vectors=False,
459
+ )
433
460
 
434
- points, next_offset = results
461
+ points, next_offset = results
435
462
 
436
- for point in points:
437
- existing_ids.add(str(point.id))
463
+ for point in points:
464
+ existing_ids.add(str(point.id))
438
465
 
439
- if next_offset is None or len(points) == 0:
440
- break
441
- offset = next_offset
466
+ if next_offset is None or len(points) == 0:
467
+ break
468
+ offset = next_offset
469
+
470
+ except Exception as e:
471
+ error_msg = str(e).lower()
472
+ if "index required" in error_msg or "index" in error_msg and "filename" in error_msg:
473
+ # Missing payload index - try to create it
474
+ logger.warning(
475
+ "⚠️ Missing 'filename' payload index. Creating it now... "
476
+ "(skip_existing requires this index for filtering)"
477
+ )
478
+ try:
479
+ self.client.create_payload_index(
480
+ collection_name=self.collection_name,
481
+ field_name="filename",
482
+ field_schema=qdrant_models.PayloadSchemaType.KEYWORD,
483
+ )
484
+ logger.info(" ✅ Created 'filename' index. Retrying query...")
485
+ # Retry the query
486
+ return self.get_existing_ids(filename)
487
+ except Exception as idx_err:
488
+ logger.warning(f" ❌ Could not create index: {idx_err}")
489
+ logger.warning(" Returning empty set - all pages will be processed")
490
+ return set()
491
+ else:
492
+ logger.warning(f"⚠️ Error checking existing IDs: {e}")
493
+ return set()
442
494
 
443
495
  return existing_ids
444
496