signalwire-agents 0.1.36__py3-none-any.whl → 0.1.38__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- signalwire_agents/__init__.py +1 -1
- signalwire_agents/cli/build_search.py +95 -19
- signalwire_agents/core/agent_base.py +38 -0
- signalwire_agents/core/mixins/ai_config_mixin.py +120 -0
- signalwire_agents/core/skill_manager.py +47 -0
- signalwire_agents/search/index_builder.py +105 -10
- signalwire_agents/search/pgvector_backend.py +523 -0
- signalwire_agents/search/search_engine.py +41 -4
- signalwire_agents/search/search_service.py +86 -35
- signalwire_agents/skills/api_ninjas_trivia/skill.py +37 -1
- signalwire_agents/skills/datasphere/skill.py +82 -0
- signalwire_agents/skills/datasphere_serverless/skill.py +82 -0
- signalwire_agents/skills/joke/skill.py +21 -0
- signalwire_agents/skills/mcp_gateway/skill.py +82 -0
- signalwire_agents/skills/native_vector_search/README.md +210 -0
- signalwire_agents/skills/native_vector_search/skill.py +197 -7
- signalwire_agents/skills/play_background_file/skill.py +36 -0
- signalwire_agents/skills/registry.py +36 -0
- signalwire_agents/skills/spider/skill.py +113 -0
- signalwire_agents/skills/swml_transfer/skill.py +90 -0
- signalwire_agents/skills/weather_api/skill.py +28 -0
- signalwire_agents/skills/wikipedia_search/skill.py +22 -0
- {signalwire_agents-0.1.36.dist-info → signalwire_agents-0.1.38.dist-info}/METADATA +66 -1
- {signalwire_agents-0.1.36.dist-info → signalwire_agents-0.1.38.dist-info}/RECORD +28 -26
- {signalwire_agents-0.1.36.dist-info → signalwire_agents-0.1.38.dist-info}/WHEEL +0 -0
- {signalwire_agents-0.1.36.dist-info → signalwire_agents-0.1.38.dist-info}/entry_points.txt +0 -0
- {signalwire_agents-0.1.36.dist-info → signalwire_agents-0.1.38.dist-info}/licenses/LICENSE +0 -0
- {signalwire_agents-0.1.36.dist-info → signalwire_agents-0.1.38.dist-info}/top_level.txt +0 -0
signalwire_agents/__init__.py
CHANGED
@@ -18,7 +18,7 @@ A package for building AI agents using SignalWire's AI and SWML capabilities.
 from .core.logging_config import configure_logging
 configure_logging()
 
-__version__ = "0.1.36"
+__version__ = "0.1.38"
 
 # Import core classes for easier access
 from .core.agent_base import AgentBase

signalwire_agents/cli/build_search.py
CHANGED
@@ -89,6 +89,24 @@ Examples:
     # Search via remote API
     sw-search remote http://localhost:8001 "how to create an agent" --index-name docs
     sw-search remote localhost:8001 "API reference" --index-name docs --count 3 --verbose
+
+    # PostgreSQL pgvector backend
+    sw-search ./docs \\
+        --backend pgvector \\
+        --connection-string "postgresql://user:pass@localhost/knowledge" \\
+        --output docs_collection
+
+    # Overwrite existing pgvector collection
+    sw-search ./docs \\
+        --backend pgvector \\
+        --connection-string "postgresql://user:pass@localhost/knowledge" \\
+        --output docs_collection \\
+        --overwrite
+
+    # Search in pgvector collection
+    sw-search search docs_collection "how to create an agent" \\
+        --backend pgvector \\
+        --connection-string "postgresql://user:pass@localhost/knowledge"
 """
     )
 
@@ -100,7 +118,25 @@ Examples:
 
     parser.add_argument(
         '--output',
-        help='Output .swsearch file (default: sources.swsearch)'
+        help='Output .swsearch file (default: sources.swsearch) or collection name for pgvector'
+    )
+
+    parser.add_argument(
+        '--backend',
+        choices=['sqlite', 'pgvector'],
+        default='sqlite',
+        help='Storage backend to use (default: sqlite)'
+    )
+
+    parser.add_argument(
+        '--connection-string',
+        help='PostgreSQL connection string for pgvector backend'
+    )
+
+    parser.add_argument(
+        '--overwrite',
+        action='store_true',
+        help='Overwrite existing collection (pgvector backend only)'
     )
 
     parser.add_argument(
@@ -213,18 +249,31 @@ Examples:
         print("Error: No valid sources found")
         sys.exit(1)
 
+    # Validate backend configuration
+    if args.backend == 'pgvector' and not args.connection_string:
+        print("Error: --connection-string is required for pgvector backend")
+        sys.exit(1)
+
     # Default output filename
     if not args.output:
-        if len(valid_sources) == 1:
-            # Single source - use its name
-            source_name = valid_sources[0].stem if valid_sources[0].is_file() else valid_sources[0].name
-            args.output = f"{source_name}.swsearch"
+        if args.backend == 'sqlite':
+            if len(valid_sources) == 1:
+                # Single source - use its name
+                source_name = valid_sources[0].stem if valid_sources[0].is_file() else valid_sources[0].name
+                args.output = f"{source_name}.swsearch"
+            else:
+                # Multiple sources - use generic name
+                args.output = "sources.swsearch"
         else:
-            # Multiple sources - use generic name
-            args.output = "sources.swsearch"
+            # For pgvector, use a default collection name
+            if len(valid_sources) == 1:
+                source_name = valid_sources[0].stem if valid_sources[0].is_file() else valid_sources[0].name
+                args.output = source_name
+            else:
+                args.output = "documents"
 
-    # Ensure output has .swsearch extension
-    if not args.output.endswith('.swsearch'):
+    # Ensure output has .swsearch extension for sqlite
+    if args.backend == 'sqlite' and not args.output.endswith('.swsearch'):
        args.output += '.swsearch'
 
     # Parse lists
@@ -235,8 +284,13 @@ Examples:
 
     if args.verbose:
         print(f"Building search index:")
+        print(f"  Backend: {args.backend}")
         print(f"  Sources: {[str(s) for s in valid_sources]}")
-        print(f"  Output file: {args.output}")
+        if args.backend == 'sqlite':
+            print(f"  Output file: {args.output}")
+        else:
+            print(f"  Collection name: {args.output}")
+            print(f"  Connection: {args.connection_string}")
         print(f"  File types (for directories): {file_types}")
         print(f"  Exclude patterns: {exclude_patterns}")
         print(f"  Languages: {languages}")
@@ -278,7 +332,9 @@ Examples:
             index_nlp_backend=args.index_nlp_backend,
             verbose=args.verbose,
             semantic_threshold=args.semantic_threshold,
-            topic_threshold=args.topic_threshold
+            topic_threshold=args.topic_threshold,
+            backend=args.backend,
+            connection_string=args.connection_string
         )
 
         # Build index with multiple sources
@@ -288,7 +344,8 @@ Examples:
             file_types=file_types,
             exclude_patterns=exclude_patterns,
             languages=languages,
-            tags=tags
+            tags=tags,
+            overwrite=args.overwrite if args.backend == 'pgvector' else False
         )
 
         # Validate if requested
@@ -307,7 +364,11 @@ Examples:
                 print(f"✗ Index validation failed: {validation['error']}")
                 sys.exit(1)
 
-        print(f"\n✓ Search index created successfully: {args.output}")
+        if args.backend == 'sqlite':
+            print(f"\n✓ Search index created successfully: {args.output}")
+        else:
+            print(f"\n✓ Search collection created successfully: {args.output}")
+            print(f"  Connection: {args.connection_string}")
 
     except KeyboardInterrupt:
         print("\n\nBuild interrupted by user")
@@ -359,9 +420,12 @@ def validate_command():
 
 def search_command():
     """Search within an existing search index"""
-    parser = argparse.ArgumentParser(description='Search within a .swsearch index file')
-    parser.add_argument('index_file', help='Path to .swsearch index file')
+    parser = argparse.ArgumentParser(description='Search within a .swsearch index file or pgvector collection')
+    parser.add_argument('index_source', help='Path to .swsearch file or collection name for pgvector')
     parser.add_argument('query', help='Search query')
+    parser.add_argument('--backend', choices=['sqlite', 'pgvector'], default='sqlite',
+                        help='Storage backend (default: sqlite)')
+    parser.add_argument('--connection-string', help='PostgreSQL connection string for pgvector backend')
     parser.add_argument('--count', type=int, default=5, help='Number of results to return (default: 5)')
     parser.add_argument('--distance-threshold', type=float, default=0.0, help='Minimum similarity score (default: 0.0)')
     parser.add_argument('--tags', help='Comma-separated tags to filter by')
@@ -373,8 +437,13 @@ def search_command():
 
     args = parser.parse_args()
 
-    if not Path(args.index_file).exists():
-        print(f"Error: Index file does not exist: {args.index_file}")
+    # Validate backend configuration
+    if args.backend == 'pgvector' and not args.connection_string:
+        print("Error: --connection-string is required for pgvector backend")
+        sys.exit(1)
+
+    if args.backend == 'sqlite' and not Path(args.index_source).exists():
+        print(f"Error: Index file does not exist: {args.index_source}")
         sys.exit(1)
 
     try:
@@ -389,9 +458,16 @@ def search_command():
 
         # Load search engine
         if args.verbose:
-            print(f"Loading search index: {args.index_file}")
+            if args.backend == 'sqlite':
+                print(f"Loading search index: {args.index_source}")
+            else:
+                print(f"Connecting to pgvector collection: {args.index_source}")
 
-        engine = SearchEngine(args.index_file)
+        if args.backend == 'sqlite':
+            engine = SearchEngine(backend='sqlite', index_path=args.index_source)
+        else:
+            engine = SearchEngine(backend='pgvector', connection_string=args.connection_string,
+                                  collection_name=args.index_source)
 
         # Get index stats
         stats = engine.get_stats()
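
The two constructor branches above define the whole backend split at query time. A minimal sketch of the same calls from Python, assuming the import path follows the module layout in the file list above; the keyword arguments and the get_stats() call are exactly those used by search_command, while the paths, credentials, and collection name are placeholders:

    from signalwire_agents.search.search_engine import SearchEngine

    # Local file-based index (sqlite backend)
    local_engine = SearchEngine(backend='sqlite', index_path='docs.swsearch')

    # Shared PostgreSQL collection (pgvector backend)
    pg_engine = SearchEngine(
        backend='pgvector',
        connection_string='postgresql://user:pass@localhost/knowledge',
        collection_name='docs_collection',
    )

    print(pg_engine.get_stats())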

signalwire_agents/core/agent_base.py
CHANGED
@@ -238,6 +238,20 @@ class AgentBase(
         self._params = {}
         self._global_data = {}
         self._function_includes = []
+        # Initialize with default LLM params
+        self._prompt_llm_params = {
+            'temperature': 0.3,
+            'top_p': 1.0,
+            'barge_confidence': 0.0,
+            'presence_penalty': 0.1,
+            'frequency_penalty': 0.1
+        }
+        self._post_prompt_llm_params = {
+            'temperature': 0.0,
+            'top_p': 1.0,
+            'presence_penalty': 0.0,
+            'frequency_penalty': 0.0
+        }
 
         # Dynamic configuration callback
         self._dynamic_config_callback = None
@@ -763,6 +777,30 @@ class AgentBase(
             # Add global_data if any
             if agent_to_use._global_data:
                 ai_config["global_data"] = agent_to_use._global_data
+
+            # Always add LLM parameters to prompt
+            if "prompt" in ai_config:
+                # Update existing prompt with LLM params
+                if isinstance(ai_config["prompt"], dict):
+                    ai_config["prompt"].update(agent_to_use._prompt_llm_params)
+                elif isinstance(ai_config["prompt"], str):
+                    # Convert string prompt to dict format
+                    ai_config["prompt"] = {
+                        "text": ai_config["prompt"],
+                        **agent_to_use._prompt_llm_params
+                    }
+
+            # Always add LLM parameters to post_prompt if post_prompt exists
+            if post_prompt and "post_prompt" in ai_config:
+                # Update existing post_prompt with LLM params
+                if isinstance(ai_config["post_prompt"], dict):
+                    ai_config["post_prompt"].update(agent_to_use._post_prompt_llm_params)
+                elif isinstance(ai_config["post_prompt"], str):
+                    # Convert string post_prompt to dict format
+                    ai_config["post_prompt"] = {
+                        "text": ai_config["post_prompt"],
+                        **agent_to_use._post_prompt_llm_params
+                    }
 
         except ValueError as e:
             if not agent_to_use._suppress_logs:
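
The effect of this merge is easiest to see on a string prompt. A worked before/after using the default parameter values initialized in __init__ above; the prompt text itself is illustrative:

    # As authored on the agent
    ai_config["prompt"] = "You are a helpful assistant."

    # After the merge above, with the defaults from __init__
    ai_config["prompt"] = {
        "text": "You are a helpful assistant.",
        "temperature": 0.3,
        "top_p": 1.0,
        "barge_confidence": 0.0,
        "presence_penalty": 0.1,
        "frequency_penalty": 0.1,
    }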

signalwire_agents/core/mixins/ai_config_mixin.py
CHANGED
@@ -370,4 +370,124 @@ class AIConfigMixin:
                 valid_includes.append(include)
 
         self._function_includes = valid_includes
+        return self
+
+    def set_prompt_llm_params(
+        self,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        barge_confidence: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        frequency_penalty: Optional[float] = None
+    ) -> 'AgentBase':
+        """
+        Set LLM parameters for the main prompt.
+
+        Args:
+            temperature: Randomness setting (0.0-1.5). Lower values make output more deterministic.
+                Default: 0.3
+            top_p: Alternative to temperature (0.0-1.0). Controls nucleus sampling.
+                Default: 1.0
+            barge_confidence: ASR confidence to interrupt (0.0-1.0). Higher values make it harder to interrupt.
+                Default: 0.0
+            presence_penalty: Topic diversity (-2.0 to 2.0). Positive values encourage new topics.
+                Default: 0.1
+            frequency_penalty: Repetition control (-2.0 to 2.0). Positive values reduce repetition.
+                Default: 0.1
+
+        Returns:
+            Self for method chaining
+
+        Example:
+            agent.set_prompt_llm_params(
+                temperature=0.7,
+                top_p=0.9,
+                barge_confidence=0.6
+            )
+        """
+        # Validate and set temperature
+        if temperature is not None:
+            if not 0.0 <= temperature <= 1.5:
+                raise ValueError("temperature must be between 0.0 and 1.5")
+            self._prompt_llm_params['temperature'] = temperature
+
+        # Validate and set top_p
+        if top_p is not None:
+            if not 0.0 <= top_p <= 1.0:
+                raise ValueError("top_p must be between 0.0 and 1.0")
+            self._prompt_llm_params['top_p'] = top_p
+
+        # Validate and set barge_confidence
+        if barge_confidence is not None:
+            if not 0.0 <= barge_confidence <= 1.0:
+                raise ValueError("barge_confidence must be between 0.0 and 1.0")
+            self._prompt_llm_params['barge_confidence'] = barge_confidence
+
+        # Validate and set presence_penalty
+        if presence_penalty is not None:
+            if not -2.0 <= presence_penalty <= 2.0:
+                raise ValueError("presence_penalty must be between -2.0 and 2.0")
+            self._prompt_llm_params['presence_penalty'] = presence_penalty
+
+        # Validate and set frequency_penalty
+        if frequency_penalty is not None:
+            if not -2.0 <= frequency_penalty <= 2.0:
+                raise ValueError("frequency_penalty must be between -2.0 and 2.0")
+            self._prompt_llm_params['frequency_penalty'] = frequency_penalty
+
+        return self
+
+    def set_post_prompt_llm_params(
+        self,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        frequency_penalty: Optional[float] = None
+    ) -> 'AgentBase':
+        """
+        Set LLM parameters for the post-prompt.
+
+        Args:
+            temperature: Randomness setting (0.0-1.5). Lower values make output more deterministic.
+                Default: 0.0
+            top_p: Alternative to temperature (0.0-1.0). Controls nucleus sampling.
+                Default: 1.0
+            presence_penalty: Topic diversity (-2.0 to 2.0). Positive values encourage new topics.
+                Default: 0.0
+            frequency_penalty: Repetition control (-2.0 to 2.0). Positive values reduce repetition.
+                Default: 0.0
+
+        Returns:
+            Self for method chaining
+
+        Example:
+            agent.set_post_prompt_llm_params(
+                temperature=0.5,  # More deterministic for post-prompt
+                top_p=0.9
+            )
+        """
+        # Validate and set temperature
+        if temperature is not None:
+            if not 0.0 <= temperature <= 1.5:
+                raise ValueError("temperature must be between 0.0 and 1.5")
+            self._post_prompt_llm_params['temperature'] = temperature
+
+        # Validate and set top_p
+        if top_p is not None:
+            if not 0.0 <= top_p <= 1.0:
+                raise ValueError("top_p must be between 0.0 and 1.0")
+            self._post_prompt_llm_params['top_p'] = top_p
+
+        # Validate and set presence_penalty
+        if presence_penalty is not None:
+            if not -2.0 <= presence_penalty <= 2.0:
+                raise ValueError("presence_penalty must be between -2.0 and 2.0")
+            self._post_prompt_llm_params['presence_penalty'] = presence_penalty
+
+        # Validate and set frequency_penalty
+        if frequency_penalty is not None:
+            if not -2.0 <= frequency_penalty <= 2.0:
+                raise ValueError("frequency_penalty must be between -2.0 and 2.0")
+            self._post_prompt_llm_params['frequency_penalty'] = frequency_penalty
+
         return self
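
Both setters validate ranges eagerly and return self, so they chain with the rest of the agent configuration API. A short usage sketch; the constructor arguments are illustrative, while the setter calls mirror the docstring examples above:

    from signalwire_agents import AgentBase

    agent = AgentBase(name="docs-assistant")  # illustrative constructor arguments
    agent.set_prompt_llm_params(temperature=0.7, top_p=0.9, barge_confidence=0.6) \
         .set_post_prompt_llm_params(temperature=0.0)

    # Out-of-range values fail fast:
    # agent.set_prompt_llm_params(temperature=2.0)  -> ValueError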

signalwire_agents/core/skill_manager.py
CHANGED
@@ -45,6 +45,53 @@ class SkillManager:
             self.logger.error(error_msg)
             return False, error_msg
 
+        # Validate that the skill has a proper parameter schema
+        if not hasattr(skill_class, 'get_parameter_schema') or not callable(getattr(skill_class, 'get_parameter_schema')):
+            error_msg = f"Skill '{skill_name}' must have get_parameter_schema() classmethod"
+            self.logger.error(error_msg)
+            return False, error_msg
+
+        try:
+            # Validate the parameter schema
+            schema = skill_class.get_parameter_schema()
+            if not isinstance(schema, dict):
+                error_msg = f"Skill '{skill_name}'.get_parameter_schema() must return a dictionary"
+                self.logger.error(error_msg)
+                return False, error_msg
+
+            # Ensure it's not an empty schema
+            if not schema:
+                error_msg = f"Skill '{skill_name}'.get_parameter_schema() returned empty dictionary"
+                self.logger.error(error_msg)
+                return False, error_msg
+
+            # Check if the skill has overridden the method
+            from signalwire_agents.core.skill_base import SkillBase
+            skill_method = getattr(skill_class, 'get_parameter_schema', None)
+            base_method = getattr(SkillBase, 'get_parameter_schema', None)
+
+            if skill_method and base_method:
+                # For class methods, check the underlying function
+                skill_func = skill_method.__func__ if hasattr(skill_method, '__func__') else skill_method
+                base_func = base_method.__func__ if hasattr(base_method, '__func__') else base_method
+
+                if skill_func is base_func:
+                    # Get base schema to check if skill added any parameters
+                    base_schema = SkillBase.get_parameter_schema()
+                    if set(schema.keys()) == set(base_schema.keys()):
+                        error_msg = f"Skill '{skill_name}' must override get_parameter_schema() to define its specific parameters"
+                        self.logger.error(error_msg)
+                        return False, error_msg
+
+        except AttributeError as e:
+            error_msg = f"Skill '{skill_name}' must properly implement get_parameter_schema() classmethod"
+            self.logger.error(error_msg)
+            return False, error_msg
+        except Exception as e:
+            error_msg = f"Skill '{skill_name}'.get_parameter_schema() failed: {e}"
+            self.logger.error(error_msg)
+            return False, error_msg
+
         try:
             # Create skill instance with parameters to get the instance key
             skill_instance = skill_class(self.agent, params)
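
This new gate is why nearly every bundled skill in the file list above gains lines: each must now ship a non-trivial get_parameter_schema(). A hypothetical skill that would pass the validation, assuming the pattern of extending the base schema; the class name, SKILL_NAME attribute, and schema entry format are assumptions, not taken from this diff:

    from signalwire_agents.core.skill_base import SkillBase

    class AcmeLookupSkill(SkillBase):  # hypothetical skill
        SKILL_NAME = "acme_lookup"  # assumed SkillBase attribute

        @classmethod
        def get_parameter_schema(cls):
            # Extend the base schema so the key-set comparison in
            # SkillManager sees skill-specific parameters.
            schema = super().get_parameter_schema()
            schema["api_key"] = {  # assumed entry layout
                "type": "string",
                "description": "API key for the Acme service",
                "required": True,
            }
            return schema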

signalwire_agents/search/index_builder.py
CHANGED
@@ -46,7 +46,9 @@ class IndexBuilder:
                  index_nlp_backend: str = 'nltk',
                  verbose: bool = False,
                  semantic_threshold: float = 0.5,
-                 topic_threshold: float = 0.3
+                 topic_threshold: float = 0.3,
+                 backend: str = 'sqlite',
+                 connection_string: Optional[str] = None
                 ):
         """
         Initialize the index builder
@@ -62,6 +64,8 @@ class IndexBuilder:
             verbose: Whether to enable verbose logging (default: False)
             semantic_threshold: Similarity threshold for semantic chunking (default: 0.5)
             topic_threshold: Similarity threshold for topic chunking (default: 0.3)
+            backend: Storage backend ('sqlite' or 'pgvector') (default: 'sqlite')
+            connection_string: PostgreSQL connection string for pgvector backend
         """
         self.model_name = model_name
         self.chunking_strategy = chunking_strategy
@@ -73,8 +77,17 @@ class IndexBuilder:
         self.verbose = verbose
         self.semantic_threshold = semantic_threshold
         self.topic_threshold = topic_threshold
+        self.backend = backend
+        self.connection_string = connection_string
         self.model = None
 
+        # Validate backend
+        if self.backend not in ['sqlite', 'pgvector']:
+            raise ValueError(f"Invalid backend '{self.backend}'. Must be 'sqlite' or 'pgvector'")
+
+        if self.backend == 'pgvector' and not self.connection_string:
+            raise ValueError("connection_string is required for pgvector backend")
+
         # Validate NLP backend
         if self.index_nlp_backend not in ['nltk', 'spacy']:
             logger.warning(f"Invalid index_nlp_backend '{self.index_nlp_backend}', using 'nltk'")
@@ -109,7 +122,8 @@ class IndexBuilder:
 
     def build_index_from_sources(self, sources: List[Path], output_file: str,
                                  file_types: List[str], exclude_patterns: Optional[List[str]] = None,
-                                 languages: List[str] = None, tags: Optional[List[str]] = None):
+                                 languages: List[str] = None, tags: Optional[List[str]] = None,
+                                 overwrite: bool = False):
         """
         Build complete search index from multiple sources (files and directories)
 
|
@@ -191,13 +205,18 @@ class IndexBuilder:
|
|
191
205
|
else:
|
192
206
|
chunk['embedding'] = b''
|
193
207
|
|
194
|
-
#
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
208
|
+
# Store chunks based on backend
|
209
|
+
if self.backend == 'sqlite':
|
210
|
+
# Create SQLite database
|
211
|
+
sources_info = [str(s) for s in sources]
|
212
|
+
self._create_database(output_file, chunks, languages or ['en'], sources_info, file_types)
|
213
|
+
|
214
|
+
if self.verbose:
|
215
|
+
print(f"Index created: {output_file}")
|
216
|
+
print(f"Total chunks: {len(chunks)}")
|
217
|
+
else:
|
218
|
+
# Use pgvector backend
|
219
|
+
self._store_chunks_pgvector(chunks, output_file, languages or ['en'], overwrite)
|
201
220
|
|
202
221
|
def build_index(self, source_dir: str, output_file: str,
|
203
222
|
file_types: List[str], exclude_patterns: Optional[List[str]] = None,
|
@@ -607,4 +626,80 @@ class IndexBuilder:
                 }
 
         except Exception as e:
-            return {"valid": False, "error": str(e)}
\ No newline at end of file
+            return {"valid": False, "error": str(e)}
+
+    def _store_chunks_pgvector(self, chunks: List[Dict[str, Any]], collection_name: str,
+                               languages: List[str], overwrite: bool = False):
+        """
+        Store chunks in pgvector backend
+
+        Args:
+            chunks: List of processed chunks
+            collection_name: Name for the collection (from output_file parameter)
+            languages: List of supported languages
+        """
+        from .pgvector_backend import PgVectorBackend
+
+        # Extract collection name from the provided name
+        if collection_name.endswith('.swsearch'):
+            collection_name = collection_name[:-9]  # Remove .swsearch extension
+
+        # Clean collection name for PostgreSQL
+        import re
+        collection_name = re.sub(r'[^a-zA-Z0-9_]', '_', collection_name)
+
+        if self.verbose:
+            print(f"Storing chunks in pgvector collection: {collection_name}")
+
+        # Create backend instance
+        backend = PgVectorBackend(self.connection_string)
+
+        try:
+            # Get embedding dimensions from model
+            if self.model:
+                embedding_dim = self.model.get_sentence_embedding_dimension()
+            else:
+                embedding_dim = 768  # Default for all-mpnet-base-v2
+
+            # Delete existing collection if overwrite is requested
+            if overwrite:
+                if self.verbose:
+                    print(f"Dropping existing collection: {collection_name}")
+                backend.delete_collection(collection_name)
+
+            # Create schema
+            backend.create_schema(collection_name, embedding_dim)
+
+            # Convert embeddings from bytes to numpy arrays
+            for chunk in chunks:
+                if chunk.get('embedding') and isinstance(chunk['embedding'], bytes):
+                    if np:
+                        chunk['embedding'] = np.frombuffer(chunk['embedding'], dtype=np.float32)
+                    else:
+                        # If numpy not available, leave as bytes
+                        pass
+
+            # Prepare config
+            config = {
+                'model_name': self.model_name,
+                'embedding_dimensions': embedding_dim,
+                'chunking_strategy': self.chunking_strategy,
+                'languages': languages,
+                'metadata': {
+                    'max_sentences_per_chunk': self.max_sentences_per_chunk,
+                    'chunk_size': self.chunk_size,
+                    'chunk_overlap': self.chunk_overlap,
+                    'index_nlp_backend': self.index_nlp_backend
+                }
+            }
+
+            # Store chunks
+            backend.store_chunks(chunks, collection_name, config)
+
+            if self.verbose:
+                stats = backend.get_stats(collection_name)
+                print(f"Stored {stats['total_chunks']} chunks in pgvector")
+                print(f"Collection: {collection_name}")
+
+        finally:
+            backend.close()