signalwire-agents 0.1.47__py3-none-any.whl → 0.1.48__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- signalwire_agents/__init__.py +1 -1
- signalwire_agents/cli/build_search.py +516 -12
- signalwire_agents/search/__init__.py +7 -1
- signalwire_agents/search/document_processor.py +11 -8
- signalwire_agents/search/index_builder.py +112 -13
- signalwire_agents/search/migration.py +418 -0
- signalwire_agents/search/models.py +30 -0
- signalwire_agents/search/pgvector_backend.py +236 -13
- signalwire_agents/search/query_processor.py +87 -9
- signalwire_agents/search/search_engine.py +835 -31
- signalwire_agents/search/search_service.py +56 -6
- signalwire_agents/skills/native_vector_search/skill.py +208 -33
- {signalwire_agents-0.1.47.dist-info → signalwire_agents-0.1.48.dist-info}/METADATA +1 -1
- {signalwire_agents-0.1.47.dist-info → signalwire_agents-0.1.48.dist-info}/RECORD +18 -16
- {signalwire_agents-0.1.47.dist-info → signalwire_agents-0.1.48.dist-info}/WHEEL +0 -0
- {signalwire_agents-0.1.47.dist-info → signalwire_agents-0.1.48.dist-info}/entry_points.txt +0 -0
- {signalwire_agents-0.1.47.dist-info → signalwire_agents-0.1.48.dist-info}/licenses/LICENSE +0 -0
- {signalwire_agents-0.1.47.dist-info → signalwire_agents-0.1.48.dist-info}/top_level.txt +0 -0
signalwire_agents/search/models.py
@@ -0,0 +1,30 @@
+"""
+Copyright (c) 2025 SignalWire
+
+This file is part of the SignalWire AI Agents SDK.
+
+Licensed under the MIT License.
+See LICENSE file in the project root for full license information.
+"""
+
+# Embedding model configuration
+MODEL_ALIASES = {
+    'mini': 'sentence-transformers/all-MiniLM-L6-v2',    # 384 dims, ~5x faster
+    'base': 'sentence-transformers/all-mpnet-base-v2',   # 768 dims, balanced
+    'large': 'sentence-transformers/all-mpnet-base-v2',  # Same as base for now
+}
+
+# Default model for new indexes
+DEFAULT_MODEL = MODEL_ALIASES['mini']
+
+def resolve_model_alias(model_name: str) -> str:
+    """
+    Resolve model alias to full model name
+
+    Args:
+        model_name: Model name or alias (mini, base, large)
+
+    Returns:
+        Full model name
+    """
+    return MODEL_ALIASES.get(model_name, model_name)
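
Note: this new module (it matches the +30/-0 models.py entry in the file list above) lets callers pass short aliases anywhere a model name is accepted. A minimal sketch of how resolution behaves, assuming that module path:

    from signalwire_agents.search.models import DEFAULT_MODEL, resolve_model_alias

    print(resolve_model_alias('mini'))         # sentence-transformers/all-MiniLM-L6-v2
    print(resolve_model_alias('base'))         # sentence-transformers/all-mpnet-base-v2
    print(resolve_model_alias('acme/custom'))  # unknown names pass through unchanged
    print(DEFAULT_MODEL)                       # sentence-transformers/all-MiniLM-L6-v2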
signalwire_agents/search/pgvector_backend.py
@@ -99,6 +99,7 @@ class PgVectorBackend:
                     section TEXT,
                     tags JSONB DEFAULT '[]'::jsonb,
                     metadata JSONB DEFAULT '{{}}'::jsonb,
+                    metadata_text TEXT,  -- Searchable text representation of all metadata
                     created_at TIMESTAMP DEFAULT NOW()
                 )
             """)
@@ -120,6 +121,16 @@ class PgVectorBackend:
                 ON {table_name} USING gin (tags)
             """)
 
+            cursor.execute(f"""
+                CREATE INDEX IF NOT EXISTS idx_{table_name}_metadata
+                ON {table_name} USING gin (metadata)
+            """)
+
+            cursor.execute(f"""
+                CREATE INDEX IF NOT EXISTS idx_{table_name}_metadata_text
+                ON {table_name} USING gin (metadata_text gin_trgm_ops)
+            """)
+
             # Create config table
             cursor.execute("""
                 CREATE TABLE IF NOT EXISTS collection_config (
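
Note: gin_trgm_ops is provided by PostgreSQL's pg_trgm extension, so the metadata_text index will fail to create on databases where the extension is absent. A sketch of a guard one might run first (hypothetical helper; this diff does not show where, or whether, the SDK does this itself):

    import psycopg2

    def ensure_pg_trgm(dsn: str) -> None:
        """Create the pg_trgm extension if missing (needs sufficient privileges)."""
        with psycopg2.connect(dsn) as conn:
            with conn.cursor() as cursor:
                cursor.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")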
@@ -136,6 +147,36 @@ class PgVectorBackend:
         self.conn.commit()
         logger.info(f"Created schema for collection '{collection_name}'")
 
+    def _extract_metadata_from_json_content(self, content: str) -> Dict[str, Any]:
+        """
+        Extract metadata from JSON content if present
+
+        Returns:
+            metadata_dict
+        """
+        metadata_dict = {}
+
+        # Try to extract metadata from JSON structure in content
+        if '"metadata":' in content:
+            try:
+                import re
+                # Find all metadata objects
+                pattern = r'"metadata"\s*:\s*(\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\})'
+                matches = re.finditer(pattern, content)
+
+                for match in matches:
+                    try:
+                        json_metadata = json.loads(match.group(1))
+                        # Merge all found metadata
+                        if isinstance(json_metadata, dict):
+                            metadata_dict.update(json_metadata)
+                    except:
+                        pass
+            except Exception as e:
+                logger.debug(f"Error extracting JSON metadata: {e}")
+
+        return metadata_dict
+
     def store_chunks(self, chunks: List[Dict[str, Any]], collection_name: str,
                      config: Dict[str, Any]):
         """
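
A quick illustration of what that regex pulls out of chunk content (standalone sketch; the sample string is invented):

    import json
    import re

    content = '{"text": "reset your API key", "metadata": {"product": "voice", "version": "2.1"}}'
    pattern = r'"metadata"\s*:\s*(\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\})'

    for match in re.finditer(pattern, content):
        print(json.loads(match.group(1)))   # {'product': 'voice', 'version': '2.1'}

The pattern tolerates one level of nested braces inside the metadata object; deeper nesting either fails to match or yields a fragment that fails json.loads and is skipped.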
@@ -166,6 +207,9 @@ class PgVectorBackend:
                section = chunk.get('section') or metadata.get('section', '')
                tags = chunk.get('tags', []) or metadata.get('tags', [])
 
+               # Extract metadata from JSON content and merge with chunk metadata
+               json_metadata = self._extract_metadata_from_json_content(chunk['content'])
+
                # Build metadata from all fields except the ones we store separately
                chunk_metadata = {}
                for key, value in chunk.items():
@@ -176,6 +220,30 @@ class PgVectorBackend:
                    if key not in ['filename', 'section', 'tags']:
                        chunk_metadata[key] = value
 
+               # Merge metadata: chunk metadata takes precedence over JSON metadata
+               merged_metadata = {**json_metadata, **chunk_metadata}
+
+               # Create searchable metadata text
+               metadata_text_parts = []
+
+               # Add all metadata keys and values
+               for key, value in merged_metadata.items():
+                   metadata_text_parts.append(str(key).lower())
+                   if isinstance(value, list):
+                       metadata_text_parts.extend(str(v).lower() for v in value)
+                   else:
+                       metadata_text_parts.append(str(value).lower())
+
+               # Add tags
+               if tags:
+                   metadata_text_parts.extend(str(tag).lower() for tag in tags)
+
+               # Add section if present
+               if section:
+                   metadata_text_parts.append(section.lower())
+
+               metadata_text = ' '.join(metadata_text_parts)
+
                data.append((
                    chunk['content'],
                    chunk.get('processed_content', chunk['content']),
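
A standalone sketch of the metadata_text that this construction produces for one chunk (values invented):

    merged_metadata = {'product': 'voice', 'versions': ['2.0', '2.1']}
    tags = ['howto']
    section = 'Quick Start'

    parts = []
    for key, value in merged_metadata.items():
        parts.append(str(key).lower())
        if isinstance(value, list):
            parts.extend(str(v).lower() for v in value)
        else:
            parts.append(str(value).lower())
    parts.extend(str(tag).lower() for tag in tags)
    parts.append(section.lower())

    print(' '.join(parts))  # product voice versions 2.0 2.1 howto quick start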
@@ -183,7 +251,8 @@ class PgVectorBackend:
                    filename,
                    section,
                    json.dumps(tags),
-                   json.dumps(
+                   json.dumps(merged_metadata),
+                   metadata_text
                ))
 
            # Batch insert chunks
@@ -192,11 +261,11 @@ class PgVectorBackend:
                cursor,
                f"""
                INSERT INTO {table_name}
-               (content, processed_content, embedding, filename, section, tags, metadata)
+               (content, processed_content, embedding, filename, section, tags, metadata, metadata_text)
                VALUES %s
                """,
                data,
-               template="(%s, %s, %s, %s, %s, %s::jsonb, %s::jsonb)"
+               template="(%s, %s, %s, %s, %s, %s::jsonb, %s::jsonb, %s)"
            )
 
            # Update or insert config
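
This is psycopg2's execute_values: the template string is rendered once per tuple in data, so it must carry exactly as many placeholders as each tuple has fields, hence the extra trailing %s for metadata_text. A minimal standalone sketch (table and values invented):

    from psycopg2.extras import execute_values

    rows = [("raw text", "processed text", '["howto"]', '{"product": "voice"}', "product voice howto")]
    execute_values(
        cursor,
        "INSERT INTO demo_chunks (content, processed_content, tags, metadata, metadata_text) VALUES %s",
        rows,
        template="(%s, %s, %s::jsonb, %s::jsonb, %s)",
    )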
@@ -355,9 +424,10 @@ class PgVectorSearchBackend:
 
    def search(self, query_vector: List[float], enhanced_text: str,
               count: int = 5, distance_threshold: float = 0.0,
-              tags: Optional[List[str]] = None) -> List[Dict[str, Any]]:
+              tags: Optional[List[str]] = None,
+              keyword_weight: Optional[float] = None) -> List[Dict[str, Any]]:
        """
-       Perform hybrid search (vector + keyword)
+       Perform hybrid search (vector + keyword + metadata)
 
        Args:
            query_vector: Embedding vector for the query
@@ -365,20 +435,27 @@ class PgVectorSearchBackend:
            count: Number of results to return
            distance_threshold: Minimum similarity score
            tags: Filter by tags
+           keyword_weight: Manual keyword weight (0.0-1.0). If None, uses default weighting
 
        Returns:
            List of search results with scores and metadata
        """
        self._ensure_connection()
 
+       # Extract query terms for metadata search
+       query_terms = enhanced_text.lower().split()
+
        # Vector search
        vector_results = self._vector_search(query_vector, count * 2, tags)
 
        # Keyword search
        keyword_results = self._keyword_search(enhanced_text, count * 2, tags)
 
-       #
-
+       # Metadata search
+       metadata_results = self._metadata_search(query_terms, count * 2, tags)
+
+       # Merge all results
+       merged_results = self._merge_all_results(vector_results, keyword_results, metadata_results, keyword_weight)
 
        # Filter by distance threshold
        filtered_results = [
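
A sketch of calling the extended search (assumes a connected PgVectorSearchBackend instance and a precomputed query embedding; result fields as used elsewhere in this backend):

    # `backend` and `vec` are assumed to exist; keyword_weight=None keeps the defaults
    results = backend.search(
        query_vector=vec,
        enhanced_text="reset api key",
        count=5,
        tags=["docs"],
        keyword_weight=0.5,  # bias the merge toward keyword matches
    )
    for r in results:
        print(r["score"], r["metadata"].get("filename"))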
@@ -386,6 +463,11 @@ class PgVectorSearchBackend:
            if r['score'] >= distance_threshold
        ]
 
+       # Ensure 'score' field exists for CLI compatibility
+       for r in filtered_results:
+           if 'score' not in r:
+               r['score'] = r.get('final_score', 0.0)
+
        return filtered_results[:count]
 
    def _vector_search(self, query_vector: List[float], count: int,
@@ -478,31 +560,172 @@ class PgVectorSearchBackend:
 
        return results
 
+   def _metadata_search(self, query_terms: List[str], count: int,
+                        tags: Optional[List[str]] = None) -> List[Dict[str, Any]]:
+       """
+       Perform metadata search using JSONB operators and metadata_text
+       """
+       with self.conn.cursor() as cursor:
+           # Build WHERE conditions
+           where_conditions = []
+           params = []
+
+           # Use metadata_text for trigram search
+           if query_terms:
+               # Create AND conditions for all terms
+               for term in query_terms:
+                   where_conditions.append(f"metadata_text ILIKE %s")
+                   params.append(f'%{term}%')
+
+           # Add tag filter if specified
+           if tags:
+               where_conditions.append("tags ?| %s")
+               params.append(tags)
+
+           # Build query
+           where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"
+
+           query = f"""
+               SELECT id, content, filename, section, tags, metadata,
+                      metadata_text
+               FROM {self.table_name}
+               WHERE {where_clause}
+               LIMIT %s
+           """
+
+           params.append(count)
+
+           cursor.execute(query, params)
+
+           results = []
+           for row in cursor.fetchall():
+               chunk_id, content, filename, section, tags_json, metadata_json, metadata_text = row
+
+               # Calculate score based on term matches
+               score = 0.0
+               if metadata_text:
+                   metadata_lower = metadata_text.lower()
+                   for term in query_terms:
+                       if term.lower() in metadata_lower:
+                           score += 0.3  # Base score for each match
+
+               # Bonus for exact matches in JSONB keys/values
+               if metadata_json:
+                   json_str = json.dumps(metadata_json).lower()
+                   for term in query_terms:
+                       if term.lower() in json_str:
+                           score += 0.2
+
+               # Normalize score
+               score = min(1.0, score)
+
+               results.append({
+                   'id': chunk_id,
+                   'content': content,
+                   'score': float(score),
+                   'metadata': {
+                       'filename': filename,
+                       'section': section,
+                       'tags': tags_json if isinstance(tags_json, list) else [],
+                       **metadata_json
+                   },
+                   'search_type': 'metadata'
+               })
+
+       # Sort by score
+       results.sort(key=lambda x: x['score'], reverse=True)
+       return results[:count]
+
    def _merge_results(self, vector_results: List[Dict[str, Any]],
-                      keyword_results: List[Dict[str, Any]]
+                      keyword_results: List[Dict[str, Any]],
+                      keyword_weight: Optional[float] = None) -> List[Dict[str, Any]]:
        """Merge and rank results from vector and keyword search"""
+       # Use provided weights or defaults
+       if keyword_weight is None:
+           keyword_weight = 0.3
+       vector_weight = 1.0 - keyword_weight
+
        # Create a map to track unique results
        results_map = {}
 
-       # Add vector results
+       # Add vector results
        for result in vector_results:
            chunk_id = result['id']
            if chunk_id not in results_map:
                results_map[chunk_id] = result
-               results_map[chunk_id]['score'] *=
+               results_map[chunk_id]['score'] *= vector_weight
            else:
                # Combine scores if result appears in both
-               results_map[chunk_id]['score'] += result['score'] *
+               results_map[chunk_id]['score'] += result['score'] * vector_weight
 
        # Add keyword results
        for result in keyword_results:
            chunk_id = result['id']
            if chunk_id not in results_map:
                results_map[chunk_id] = result
-               results_map[chunk_id]['score'] *=
+               results_map[chunk_id]['score'] *= keyword_weight
            else:
                # Combine scores if result appears in both
-               results_map[chunk_id]['score'] += result['score'] *
+               results_map[chunk_id]['score'] += result['score'] * keyword_weight
+
+       # Sort by combined score
+       merged = list(results_map.values())
+       merged.sort(key=lambda x: x['score'], reverse=True)
+
+       return merged
+
+   def _merge_all_results(self, vector_results: List[Dict[str, Any]],
+                          keyword_results: List[Dict[str, Any]],
+                          metadata_results: List[Dict[str, Any]],
+                          keyword_weight: Optional[float] = None) -> List[Dict[str, Any]]:
+       """Merge and rank results from vector, keyword, and metadata search"""
+       # Use provided weights or defaults
+       if keyword_weight is None:
+           keyword_weight = 0.3
+       vector_weight = 0.5
+       metadata_weight = 0.2
+
+       # Create a map to track unique results
+       results_map = {}
+       all_sources = {}
+
+       # Add vector results
+       for result in vector_results:
+           chunk_id = result['id']
+           if chunk_id not in results_map:
+               results_map[chunk_id] = result.copy()
+               results_map[chunk_id]['score'] = result['score'] * vector_weight
+               all_sources[chunk_id] = {'vector': result['score']}
+           else:
+               results_map[chunk_id]['score'] += result['score'] * vector_weight
+               all_sources[chunk_id]['vector'] = result['score']
+
+       # Add keyword results
+       for result in keyword_results:
+           chunk_id = result['id']
+           if chunk_id not in results_map:
+               results_map[chunk_id] = result.copy()
+               results_map[chunk_id]['score'] = result['score'] * keyword_weight
+               all_sources.setdefault(chunk_id, {})['keyword'] = result['score']
+           else:
+               results_map[chunk_id]['score'] += result['score'] * keyword_weight
+               all_sources[chunk_id]['keyword'] = result['score']
+
+       # Add metadata results
+       for result in metadata_results:
+           chunk_id = result['id']
+           if chunk_id not in results_map:
+               results_map[chunk_id] = result.copy()
+               results_map[chunk_id]['score'] = result['score'] * metadata_weight
+               all_sources.setdefault(chunk_id, {})['metadata'] = result['score']
+           else:
+               results_map[chunk_id]['score'] += result['score'] * metadata_weight
+               all_sources[chunk_id]['metadata'] = result['score']
+
+       # Add sources to results for transparency
+       for chunk_id, result in results_map.items():
+           result['sources'] = all_sources.get(chunk_id, {})
+           result['final_score'] = result['score']
 
        # Sort by combined score
        merged = list(results_map.values())
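
Worked example of the three-way weighting above (vector 0.5 and metadata 0.2 are fixed; keyword_weight defaults to 0.3, so the three weights sum to 1.0 only at the default):

    vector_weight, keyword_weight, metadata_weight = 0.5, 0.3, 0.2

    # A chunk returned by all three searches:
    vector_score, keyword_score, metadata_score = 0.8, 0.6, 0.3
    final = (vector_score * vector_weight          # 0.40
             + keyword_score * keyword_weight      # 0.18
             + metadata_score * metadata_weight)   # 0.06
    print(round(final, 2))  # 0.64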
signalwire_agents/search/query_processor.py
@@ -77,22 +77,90 @@ def load_spacy_model(language: str):
        _spacy_warning_shown = True
        return None
 
-def vectorize_query(query: str):
+# Global model cache
+_cached_model = None
+_model_lock = None
+
+def set_global_model(model):
+    """Set the global cached model instance"""
+    global _cached_model
+    _cached_model = model
+    logger.info("Global model set for query processor")
+
+def _get_cached_model(model_name: str = None):
+    """Get or create cached sentence transformer model
+
+    Args:
+        model_name: Optional model name. If not provided, uses default.
+    """
+    global _cached_model, _model_lock
+
+    # Default model
+    if model_name is None:
+        model_name = 'sentence-transformers/all-mpnet-base-v2'
+
+    # Initialize lock if needed
+    if _model_lock is None:
+        import threading
+        _model_lock = threading.Lock()
+
+    # Return cached model if available and same model
+    if _cached_model is not None:
+        # Check if it's the same model (simple check - assumes model has a name attribute)
+        try:
+            if hasattr(_cached_model, 'model_name') and _cached_model.model_name == model_name:
+                return _cached_model
+        except:
+            pass
+
+    # Load model with lock to prevent race conditions
+    with _model_lock:
+        # Double check in case another thread loaded it
+        if _cached_model is not None:
+            try:
+                if hasattr(_cached_model, 'model_name') and _cached_model.model_name == model_name:
+                    return _cached_model
+            except:
+                pass
+
+        try:
+            from sentence_transformers import SentenceTransformer
+            logger.info(f"Loading sentence transformer model: {model_name}")
+            _cached_model = SentenceTransformer(model_name)
+            _cached_model.model_name = model_name  # Store for later comparison
+            logger.info("Model loaded and cached successfully")
+            return _cached_model
+        except ImportError:
+            logger.error("sentence-transformers not available. Cannot load model.")
+            return None
+
+def vectorize_query(query: str, model=None, model_name: str = None):
    """
    Vectorize query using sentence transformers
    Returns numpy array of embeddings
+
+   Args:
+       query: Query string to vectorize
+       model: Optional pre-loaded model instance. If not provided, uses cached model.
+       model_name: Optional model name to use if loading a new model
    """
    try:
-       from sentence_transformers import SentenceTransformer
        import numpy as np
 
-       # Use
-       model
+       # Use provided model or get cached one
+       if model is None:
+           model = _get_cached_model(model_name)
+           if model is None:
+               return None
+
        embedding = model.encode(query, show_progress_bar=False)
        return embedding
 
    except ImportError:
-       logger.error("
+       logger.error("numpy not available. Cannot vectorize query.")
+       return None
+   except Exception as e:
+       logger.error(f"Error vectorizing query: {e}")
        return None
 
# Language to NLTK stopwords mapping
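
A sketch of the intended warm-up pattern: load the embedding model once at startup, register it, and let every later vectorize_query call reuse it (assumes sentence-transformers is installed):

    from sentence_transformers import SentenceTransformer
    from signalwire_agents.search.query_processor import set_global_model, vectorize_query

    model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
    model.model_name = 'sentence-transformers/all-MiniLM-L6-v2'  # name used by the cache check
    set_global_model(model)

    vec = vectorize_query("reset my api key")  # reuses the cached model, no reload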
@@ -200,7 +268,8 @@ def remove_duplicate_words(input_string: str) -> str:
def preprocess_query(query: str, language: str = 'en', pos_to_expand: Optional[List[str]] = None,
                     max_synonyms: int = 5, debug: bool = False, vector: bool = False,
                     vectorize_query_param: bool = False, nlp_backend: str = None,
-                    query_nlp_backend: str = 'nltk') -> Dict[str, Any]:
+                    query_nlp_backend: str = 'nltk', model_name: str = None,
+                    preserve_original: bool = True) -> Dict[str, Any]:
    """
    Advanced query preprocessing with language detection, POS tagging, synonym expansion, and vectorization
 
@@ -333,14 +402,23 @@ def preprocess_query(query: str, language: str = 'en', pos_to_expand: Optional[List[str]]
    expanded_query_set = set()
    expanded_query = []
 
+   # If preserve_original is True, always include the original query first
+   if preserve_original:
+       # Add original query terms first (maintains exact phrases)
+       original_tokens = query.lower().split()
+       for token in original_tokens:
+           if token not in expanded_query_set:
+               expanded_query.append(token)
+               expanded_query_set.add(token)
+
    for original, lemma in lemmas:
        if original not in expanded_query_set:
            expanded_query.append(original)
            expanded_query_set.add(original)
-       if lemma not in expanded_query_set:
+       if lemma not in expanded_query_set and not preserve_original:  # Only add lemmas if not preserving original
            expanded_query.append(lemma)
            expanded_query_set.add(lemma)
-       if pos_tags.get(original) in pos_to_expand:
+       if pos_tags.get(original) in pos_to_expand and max_synonyms > 0:
            synonyms = get_synonyms(lemma, pos_tags[original], max_synonyms)
            for synonym in synonyms:
                if synonym not in expanded_query_set:
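
A standalone sketch of the new expansion order (this replicates the loop above, it is not the SDK function; synonym expansion omitted):

    def expand(query, lemmas, preserve_original=True):
        seen, out = set(), []
        if preserve_original:
            for tok in query.lower().split():  # original tokens lead, phrases stay intact
                if tok not in seen:
                    out.append(tok)
                    seen.add(tok)
        for original, lemma in lemmas:
            if original not in seen:
                out.append(original)
                seen.add(original)
            if lemma not in seen and not preserve_original:
                out.append(lemma)
                seen.add(lemma)
        return out

    lemmas = [("running", "run"), ("shoes", "shoe")]
    print(expand("running shoes", lemmas))                           # ['running', 'shoes']
    print(expand("running shoes", lemmas, preserve_original=False))  # ['running', 'run', 'shoes', 'shoe']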
@@ -365,7 +443,7 @@ def preprocess_query(query: str, language: str = 'en', pos_to_expand: Optional[List[str]]
 
    # Vectorize query if requested
    if vector:
-       vectorized_query = vectorize_query(final_query_str)
+       vectorized_query = vectorize_query(final_query_str, model_name=model_name)
        if vectorized_query is not None:
            formatted_output['vector'] = vectorized_query.tolist()
        else: