tooluniverse 0.2.0__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tooluniverse might be problematic.
- tooluniverse/__init__.py +340 -4
- tooluniverse/admetai_tool.py +84 -0
- tooluniverse/agentic_tool.py +563 -0
- tooluniverse/alphafold_tool.py +96 -0
- tooluniverse/base_tool.py +129 -6
- tooluniverse/boltz_tool.py +207 -0
- tooluniverse/chem_tool.py +192 -0
- tooluniverse/compose_scripts/__init__.py +1 -0
- tooluniverse/compose_scripts/biomarker_discovery.py +293 -0
- tooluniverse/compose_scripts/comprehensive_drug_discovery.py +186 -0
- tooluniverse/compose_scripts/drug_safety_analyzer.py +89 -0
- tooluniverse/compose_scripts/literature_tool.py +34 -0
- tooluniverse/compose_scripts/output_summarizer.py +279 -0
- tooluniverse/compose_scripts/tool_description_optimizer.py +681 -0
- tooluniverse/compose_scripts/tool_discover.py +705 -0
- tooluniverse/compose_scripts/tool_graph_composer.py +448 -0
- tooluniverse/compose_tool.py +371 -0
- tooluniverse/ctg_tool.py +1002 -0
- tooluniverse/custom_tool.py +81 -0
- tooluniverse/dailymed_tool.py +108 -0
- tooluniverse/data/admetai_tools.json +155 -0
- tooluniverse/data/adverse_event_tools.json +108 -0
- tooluniverse/data/agentic_tools.json +1156 -0
- tooluniverse/data/alphafold_tools.json +87 -0
- tooluniverse/data/boltz_tools.json +9 -0
- tooluniverse/data/chembl_tools.json +16 -0
- tooluniverse/data/clinicaltrials_gov_tools.json +326 -0
- tooluniverse/data/compose_tools.json +202 -0
- tooluniverse/data/dailymed_tools.json +70 -0
- tooluniverse/data/dataset_tools.json +646 -0
- tooluniverse/data/disease_target_score_tools.json +712 -0
- tooluniverse/data/efo_tools.json +17 -0
- tooluniverse/data/embedding_tools.json +319 -0
- tooluniverse/data/enrichr_tools.json +31 -0
- tooluniverse/data/europe_pmc_tools.json +22 -0
- tooluniverse/data/expert_feedback_tools.json +10 -0
- tooluniverse/data/fda_drug_adverse_event_tools.json +491 -0
- tooluniverse/data/fda_drug_labeling_tools.json +1 -1
- tooluniverse/data/fda_drugs_with_brand_generic_names_for_tool.py +76929 -148860
- tooluniverse/data/finder_tools.json +209 -0
- tooluniverse/data/gene_ontology_tools.json +113 -0
- tooluniverse/data/gwas_tools.json +1082 -0
- tooluniverse/data/hpa_tools.json +333 -0
- tooluniverse/data/humanbase_tools.json +47 -0
- tooluniverse/data/idmap_tools.json +74 -0
- tooluniverse/data/mcp_client_tools_example.json +113 -0
- tooluniverse/data/mcpautoloadertool_defaults.json +28 -0
- tooluniverse/data/medlineplus_tools.json +141 -0
- tooluniverse/data/monarch_tools.json +1 -1
- tooluniverse/data/openalex_tools.json +36 -0
- tooluniverse/data/opentarget_tools.json +1 -1
- tooluniverse/data/output_summarization_tools.json +101 -0
- tooluniverse/data/packages/bioinformatics_core_tools.json +1756 -0
- tooluniverse/data/packages/categorized_tools.txt +206 -0
- tooluniverse/data/packages/cheminformatics_tools.json +347 -0
- tooluniverse/data/packages/earth_sciences_tools.json +74 -0
- tooluniverse/data/packages/genomics_tools.json +776 -0
- tooluniverse/data/packages/image_processing_tools.json +38 -0
- tooluniverse/data/packages/machine_learning_tools.json +789 -0
- tooluniverse/data/packages/neuroscience_tools.json +62 -0
- tooluniverse/data/packages/original_tools.txt +0 -0
- tooluniverse/data/packages/physics_astronomy_tools.json +62 -0
- tooluniverse/data/packages/scientific_computing_tools.json +560 -0
- tooluniverse/data/packages/single_cell_tools.json +453 -0
- tooluniverse/data/packages/structural_biology_tools.json +396 -0
- tooluniverse/data/packages/visualization_tools.json +399 -0
- tooluniverse/data/pubchem_tools.json +215 -0
- tooluniverse/data/pubtator_tools.json +68 -0
- tooluniverse/data/rcsb_pdb_tools.json +1332 -0
- tooluniverse/data/reactome_tools.json +19 -0
- tooluniverse/data/semantic_scholar_tools.json +26 -0
- tooluniverse/data/special_tools.json +2 -25
- tooluniverse/data/tool_composition_tools.json +88 -0
- tooluniverse/data/toolfinderkeyword_defaults.json +34 -0
- tooluniverse/data/txagent_client_tools.json +9 -0
- tooluniverse/data/uniprot_tools.json +211 -0
- tooluniverse/data/url_fetch_tools.json +94 -0
- tooluniverse/data/uspto_downloader_tools.json +9 -0
- tooluniverse/data/uspto_tools.json +811 -0
- tooluniverse/data/xml_tools.json +3275 -0
- tooluniverse/dataset_tool.py +296 -0
- tooluniverse/default_config.py +165 -0
- tooluniverse/efo_tool.py +42 -0
- tooluniverse/embedding_database.py +630 -0
- tooluniverse/embedding_sync.py +396 -0
- tooluniverse/enrichr_tool.py +266 -0
- tooluniverse/europe_pmc_tool.py +52 -0
- tooluniverse/execute_function.py +1775 -95
- tooluniverse/extended_hooks.py +444 -0
- tooluniverse/gene_ontology_tool.py +194 -0
- tooluniverse/graphql_tool.py +158 -36
- tooluniverse/gwas_tool.py +358 -0
- tooluniverse/hpa_tool.py +1645 -0
- tooluniverse/humanbase_tool.py +389 -0
- tooluniverse/logging_config.py +254 -0
- tooluniverse/mcp_client_tool.py +764 -0
- tooluniverse/mcp_integration.py +413 -0
- tooluniverse/mcp_tool_registry.py +925 -0
- tooluniverse/medlineplus_tool.py +337 -0
- tooluniverse/openalex_tool.py +228 -0
- tooluniverse/openfda_adv_tool.py +283 -0
- tooluniverse/openfda_tool.py +393 -160
- tooluniverse/output_hook.py +1122 -0
- tooluniverse/package_tool.py +195 -0
- tooluniverse/pubchem_tool.py +158 -0
- tooluniverse/pubtator_tool.py +168 -0
- tooluniverse/rcsb_pdb_tool.py +38 -0
- tooluniverse/reactome_tool.py +108 -0
- tooluniverse/remote/boltz/boltz_mcp_server.py +50 -0
- tooluniverse/remote/depmap_24q2/depmap_24q2_mcp_tool.py +442 -0
- tooluniverse/remote/expert_feedback/human_expert_mcp_tools.py +2013 -0
- tooluniverse/remote/expert_feedback/simple_test.py +23 -0
- tooluniverse/remote/expert_feedback/start_web_interface.py +188 -0
- tooluniverse/remote/expert_feedback/web_only_interface.py +0 -0
- tooluniverse/remote/immune_compass/compass_tool.py +327 -0
- tooluniverse/remote/pinnacle/pinnacle_tool.py +328 -0
- tooluniverse/remote/transcriptformer/transcriptformer_tool.py +586 -0
- tooluniverse/remote/uspto_downloader/uspto_downloader_mcp_server.py +61 -0
- tooluniverse/remote/uspto_downloader/uspto_downloader_tool.py +120 -0
- tooluniverse/remote_tool.py +99 -0
- tooluniverse/restful_tool.py +53 -30
- tooluniverse/scripts/generate_tool_graph.py +408 -0
- tooluniverse/scripts/visualize_tool_graph.py +829 -0
- tooluniverse/semantic_scholar_tool.py +62 -0
- tooluniverse/smcp.py +2452 -0
- tooluniverse/smcp_server.py +975 -0
- tooluniverse/test/mcp_server_test.py +0 -0
- tooluniverse/test/test_admetai_tool.py +370 -0
- tooluniverse/test/test_agentic_tool.py +129 -0
- tooluniverse/test/test_alphafold_tool.py +71 -0
- tooluniverse/test/test_chem_tool.py +37 -0
- tooluniverse/test/test_compose_lieraturereview.py +63 -0
- tooluniverse/test/test_compose_tool.py +448 -0
- tooluniverse/test/test_dailymed.py +69 -0
- tooluniverse/test/test_dataset_tool.py +200 -0
- tooluniverse/test/test_disease_target_score.py +56 -0
- tooluniverse/test/test_drugbank_filter_examples.py +179 -0
- tooluniverse/test/test_efo.py +31 -0
- tooluniverse/test/test_enrichr_tool.py +21 -0
- tooluniverse/test/test_europe_pmc_tool.py +20 -0
- tooluniverse/test/test_fda_adv.py +95 -0
- tooluniverse/test/test_fda_drug_labeling.py +91 -0
- tooluniverse/test/test_gene_ontology_tools.py +66 -0
- tooluniverse/test/test_gwas_tool.py +139 -0
- tooluniverse/test/test_hpa.py +625 -0
- tooluniverse/test/test_humanbase_tool.py +20 -0
- tooluniverse/test/test_idmap_tools.py +61 -0
- tooluniverse/test/test_mcp_server.py +211 -0
- tooluniverse/test/test_mcp_tool.py +247 -0
- tooluniverse/test/test_medlineplus.py +220 -0
- tooluniverse/test/test_openalex_tool.py +32 -0
- tooluniverse/test/test_opentargets.py +28 -0
- tooluniverse/test/test_pubchem_tool.py +116 -0
- tooluniverse/test/test_pubtator_tool.py +37 -0
- tooluniverse/test/test_rcsb_pdb_tool.py +86 -0
- tooluniverse/test/test_reactome.py +54 -0
- tooluniverse/test/test_semantic_scholar_tool.py +24 -0
- tooluniverse/test/test_software_tools.py +147 -0
- tooluniverse/test/test_tool_description_optimizer.py +49 -0
- tooluniverse/test/test_tool_finder.py +26 -0
- tooluniverse/test/test_tool_finder_llm.py +252 -0
- tooluniverse/test/test_tools_find.py +195 -0
- tooluniverse/test/test_uniprot_tools.py +74 -0
- tooluniverse/test/test_uspto_tool.py +72 -0
- tooluniverse/test/test_xml_tool.py +113 -0
- tooluniverse/tool_finder_embedding.py +267 -0
- tooluniverse/tool_finder_keyword.py +693 -0
- tooluniverse/tool_finder_llm.py +699 -0
- tooluniverse/tool_graph_web_ui.py +955 -0
- tooluniverse/tool_registry.py +416 -0
- tooluniverse/uniprot_tool.py +155 -0
- tooluniverse/url_tool.py +253 -0
- tooluniverse/uspto_tool.py +240 -0
- tooluniverse/utils.py +369 -41
- tooluniverse/xml_tool.py +369 -0
- tooluniverse-1.0.1.dist-info/METADATA +387 -0
- tooluniverse-1.0.1.dist-info/RECORD +182 -0
- tooluniverse-1.0.1.dist-info/entry_points.txt +9 -0
- tooluniverse/generate_mcp_tools.py +0 -113
- tooluniverse/mcp_server.py +0 -3340
- tooluniverse-0.2.0.dist-info/METADATA +0 -139
- tooluniverse-0.2.0.dist-info/RECORD +0 -21
- tooluniverse-0.2.0.dist-info/entry_points.txt +0 -4
- {tooluniverse-0.2.0.dist-info → tooluniverse-1.0.1.dist-info}/WHEEL +0 -0
- {tooluniverse-0.2.0.dist-info → tooluniverse-1.0.1.dist-info}/licenses/LICENSE +0 -0
- {tooluniverse-0.2.0.dist-info → tooluniverse-1.0.1.dist-info}/top_level.txt +0 -0
tooluniverse/embedding_database.py

@@ -0,0 +1,630 @@

```python
"""
Embedding Database Tool for ToolUniverse

A unified tool for managing embedding databases with FAISS vector search and SQLite metadata storage.
Supports creating databases from documents, adding documents, searching, and loading existing databases.
Uses OpenAI's embedding models for text-to-vector conversion, with support for Azure OpenAI.
"""

import os
import json
import sqlite3
import numpy as np
from pathlib import Path
from typing import List, Dict
import hashlib

try:
    import faiss
except ImportError:
    raise ImportError("faiss-cpu is required. Install with: pip install faiss-cpu")

try:
    from openai import OpenAI, AzureOpenAI
except ImportError:
    raise ImportError("openai is required. Install with: pip install openai")

from .base_tool import BaseTool
from .tool_registry import register_tool
from .logging_config import get_logger


@register_tool("EmbeddingDatabase")
class EmbeddingDatabase(BaseTool):
    """
    Unified embedding database tool supporting multiple operations:
    - create_from_docs: Create new database from documents
    - add_docs: Add documents to existing database
    - search: Search for similar documents
    - load_database: Load existing database from path
    """

    def __init__(self, tool_config):
        super().__init__(tool_config)
        self.logger = get_logger("EmbeddingDatabase")

        # OpenAI configuration
        openai_config = tool_config.get("configs", {}).get("openai_config", {})
        azure_config = tool_config.get("configs", {}).get("azure_openai_config", {})

        # Initialize OpenAI client (regular or Azure)
        self.openai_client = None
        self.azure_client = None

        # Initialize both clients for flexibility
        if openai_config.get("api_key") or os.getenv("OPENAI_API_KEY"):
            self.openai_client = self._init_openai_client(openai_config)

        if azure_config.get("api_key") or os.getenv("AZURE_OPENAI_API_KEY"):
            self.azure_client = self._init_azure_client(azure_config)

        if not self.openai_client and not self.azure_client:
            raise ValueError(
                "Either OpenAI or Azure OpenAI API credentials must be provided"
            )

        # Storage configuration
        storage_config = tool_config.get("configs", {}).get("storage_config", {})
        self.data_dir = Path(storage_config.get("data_dir", "./data/embeddings"))
        self.faiss_index_type = storage_config.get("faiss_index_type", "IndexFlatIP")

        # Ensure data directory exists
        self.data_dir.mkdir(parents=True, exist_ok=True)

        # Database paths
        self.db_path = self.data_dir / "embeddings.db"

        # Initialize SQLite database
        self._init_database()

    def _init_openai_client(self, config):
        """Initialize OpenAI client with configuration"""
        # Handle environment variable substitution
        api_key = self._substitute_env_vars(config.get("api_key")) or os.getenv(
            "OPENAI_API_KEY"
        )
        if not api_key:
            return None

        base_url = self._substitute_env_vars(config.get("base_url")) or os.getenv(
            "OPENAI_BASE_URL", "https://api.openai.com/v1"
        )

        return OpenAI(
            api_key=api_key,
            base_url=base_url,
            timeout=config.get("timeout", 60),
            max_retries=config.get("max_retries", 3),
        )

    def _substitute_env_vars(self, value):
        """Substitute environment variables in configuration values"""
        if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
            # Handle default values like ${VAR:default}
            if ":" in value:
                var_part = value[2:-1]  # Remove ${ and }
                var_name, default_value = var_part.split(":", 1)
                return os.getenv(var_name, default_value)
            else:
                var_name = value[2:-1]  # Remove ${ and }
                return os.getenv(var_name)
        return value

    def _init_azure_client(self, config):
        """Initialize Azure OpenAI client with configuration"""
        # Handle environment variable substitution
        api_key = self._substitute_env_vars(config.get("api_key")) or os.getenv(
            "AZURE_OPENAI_API_KEY"
        )
        endpoint = self._substitute_env_vars(config.get("azure_endpoint")) or os.getenv(
            "AZURE_OPENAI_ENDPOINT"
        )
        api_version = self._substitute_env_vars(config.get("api_version")) or os.getenv(
            "AZURE_OPENAI_API_VERSION", "2024-02-01"
        )

        if not api_key or not endpoint:
            return None

        return AzureOpenAI(
            api_key=api_key,
            azure_endpoint=endpoint,
            api_version=api_version,
            timeout=120,  # Increased timeout for Azure
            max_retries=5,  # Increased retries for Azure
        )

    def _init_database(self):
        """Initialize SQLite database with required tables"""
        with sqlite3.connect(self.db_path) as conn:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS databases (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    name TEXT UNIQUE NOT NULL,
                    description TEXT,
                    embedding_model TEXT,
                    embedding_dimensions INTEGER,
                    document_count INTEGER DEFAULT 0,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """
            )

            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS documents (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    database_name TEXT NOT NULL,
                    faiss_index INTEGER NOT NULL,
                    text TEXT NOT NULL,
                    metadata_json TEXT,
                    text_hash TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (database_name) REFERENCES databases (name)
                )
            """
            )

            conn.execute(
                """
                CREATE INDEX IF NOT EXISTS idx_database_name ON documents (database_name)
            """
            )

            conn.execute(
                """
                CREATE INDEX IF NOT EXISTS idx_text_hash ON documents (text_hash)
            """
            )

    def run(self, arguments):
        """Main entry point for the tool"""
        action = arguments.get("action")

        if action == "create_from_docs":
            return self._create_from_documents(arguments)
        elif action == "add_docs":
            return self._add_documents(arguments)
        elif action == "search":
            return self._search(arguments)
        elif action == "load_database":
            return self._load_database(arguments)
        else:
            return {"error": f"Unknown action: {action}"}

    def _create_from_documents(self, arguments):
        """Create new embedding database from documents"""
        database_name = arguments.get("database_name")
        documents = arguments.get("documents", [])
        metadata = arguments.get("metadata", [])
        model = arguments.get("model", "text-embedding-3-small")
        description = arguments.get("description", "")
        use_azure = arguments.get("use_azure", False)

        if not database_name:
            return {"error": "database_name is required"}
        if not documents:
            return {"error": "documents list cannot be empty"}

        # Check if database already exists
        if self._database_exists(database_name):
            return {
                "error": f"Database '{database_name}' already exists. Use 'add_docs' to add more documents."
            }

        try:
            # Generate embeddings
            self.logger.info(
                f"Generating embeddings for {len(documents)} documents using {model}"
            )
            embeddings = self._generate_embeddings(documents, model, use_azure)

            if not embeddings:
                return {"error": "Failed to generate embeddings"}

            # Get embedding dimensions
            dimensions = len(embeddings[0])

            # Create FAISS index
            if self.faiss_index_type == "IndexFlatIP":
                index = faiss.IndexFlatIP(dimensions)
            elif self.faiss_index_type == "IndexFlatL2":
                index = faiss.IndexFlatL2(dimensions)
            else:
                index = faiss.IndexFlatIP(dimensions)  # Default fallback

            # Add embeddings to FAISS index
            embedding_matrix = np.array(embeddings, dtype=np.float32)

            # Normalize embeddings for cosine similarity if using IndexFlatIP
            if self.faiss_index_type == "IndexFlatIP":
                # Normalize the embeddings to unit vectors for cosine similarity
                norms = np.linalg.norm(embedding_matrix, axis=1, keepdims=True)
                embedding_matrix = embedding_matrix / norms
                self.logger.info(
                    f"Normalized embeddings for IndexFlatIP. Norms: {norms.flatten()[:3]}..."
                )

            index.add(embedding_matrix)

            # Save FAISS index
            index_path = self.data_dir / f"{database_name}.faiss"
            faiss.write_index(index, str(index_path))

            # Store database info and documents in SQLite
            with sqlite3.connect(self.db_path) as conn:
                # Insert database record
                conn.execute(
                    """
                    INSERT INTO databases (name, description, embedding_model, embedding_dimensions, document_count)
                    VALUES (?, ?, ?, ?, ?)
                """,
                    (database_name, description, model, dimensions, len(documents)),
                )

                # Insert document records
                for i, (doc, meta) in enumerate(
                    zip(documents, metadata + [{}] * len(documents))
                ):
                    text_hash = hashlib.md5(doc.encode()).hexdigest()
                    metadata_json = json.dumps(meta)

                    conn.execute(
                        """
                        INSERT INTO documents (database_name, faiss_index, text, metadata_json, text_hash)
                        VALUES (?, ?, ?, ?, ?)
                    """,
                        (database_name, i, doc, metadata_json, text_hash),
                    )

            self.logger.info(
                f"Created database '{database_name}' with {len(documents)} documents"
            )

            return {
                "status": "success",
                "database_name": database_name,
                "documents_added": len(documents),
                "embedding_model": model,
                "dimensions": dimensions,
                "index_path": str(index_path),
            }

        except Exception as e:
            self.logger.error(f"Error creating database: {str(e)}")
            return {"error": f"Failed to create database: {str(e)}"}

    def _add_documents(self, arguments):
        """Add documents to existing database"""
        database_name = arguments.get("database_name")
        documents = arguments.get("documents", [])
        metadata = arguments.get("metadata", [])
        use_azure = arguments.get("use_azure", False)

        if not database_name:
            return {"error": "database_name is required"}
        if not documents:
            return {"error": "documents list cannot be empty"}

        if not self._database_exists(database_name):
            return {
                "error": f"Database '{database_name}' does not exist. Use 'create_from_docs' first."
            }

        try:
            # Get database info
            db_info = self._get_database_info(database_name)
            model = db_info["embedding_model"]

            # Generate embeddings for new documents
            self.logger.info(
                f"Generating embeddings for {len(documents)} new documents"
            )
            new_embeddings = self._generate_embeddings(documents, model, use_azure)

            if not new_embeddings:
                return {"error": "Failed to generate embeddings"}

            # Load existing FAISS index
            index_path = self.data_dir / f"{database_name}.faiss"
            index = faiss.read_index(str(index_path))

            # Get current document count for new indices
            current_count = index.ntotal

            # Add new embeddings to index
            new_embedding_matrix = np.array(new_embeddings, dtype=np.float32)

            # Normalize embeddings for cosine similarity if using IndexFlatIP
            if self.faiss_index_type == "IndexFlatIP":
                norms = np.linalg.norm(new_embedding_matrix, axis=1, keepdims=True)
                new_embedding_matrix = new_embedding_matrix / norms
                self.logger.info(
                    f"Normalized new embeddings for IndexFlatIP. Norms: {norms.flatten()[:3]}..."
                )

            index.add(new_embedding_matrix)

            # Save updated index
            faiss.write_index(index, str(index_path))

            # Add documents to SQLite
            with sqlite3.connect(self.db_path) as conn:
                for i, (doc, meta) in enumerate(
                    zip(documents, metadata + [{}] * len(documents))
                ):
                    text_hash = hashlib.md5(doc.encode()).hexdigest()
                    metadata_json = json.dumps(meta)
                    faiss_index = current_count + i

                    conn.execute(
                        """
                        INSERT INTO documents (database_name, faiss_index, text, metadata_json, text_hash)
                        VALUES (?, ?, ?, ?, ?)
                    """,
                        (database_name, faiss_index, doc, metadata_json, text_hash),
                    )

                # Update document count
                conn.execute(
                    """
                    UPDATE databases
                    SET document_count = document_count + ?, updated_at = CURRENT_TIMESTAMP
                    WHERE name = ?
                """,
                    (len(documents), database_name),
                )

            self.logger.info(
                f"Added {len(documents)} documents to database '{database_name}'"
            )

            return {
                "status": "success",
                "database_name": database_name,
                "documents_added": len(documents),
                "total_documents": current_count + len(documents),
            }

        except Exception as e:
            self.logger.error(f"Error adding documents: {str(e)}")
            return {"error": f"Failed to add documents: {str(e)}"}

    def _search(self, arguments):
        """Search for similar documents in database"""
        database_name = arguments.get("database_name")
        query = arguments.get("query")
        top_k = arguments.get("top_k", 5)
        filters = arguments.get(
            "metadata_filter", arguments.get("filters", {})
        )  # Support both parameter names
        use_azure = arguments.get("use_azure", False)

        if not database_name:
            return {"error": "database_name is required"}
        if not query:
            return {"error": "query is required"}

        if not self._database_exists(database_name):
            return {"error": f"Database '{database_name}' does not exist"}

        try:
            # Get database info
            db_info = self._get_database_info(database_name)
            model = db_info["embedding_model"]

            # Generate query embedding
            query_embedding = self._generate_embeddings([query], model, use_azure)
            if not query_embedding:
                return {"error": "Failed to generate query embedding"}

            # Load FAISS index
            index_path = self.data_dir / f"{database_name}.faiss"
            index = faiss.read_index(str(index_path))

            # Search for similar vectors
            query_vector = np.array([query_embedding[0]], dtype=np.float32)

            # Normalize query vector if using IndexFlatIP for cosine similarity
            if self.faiss_index_type == "IndexFlatIP":
                query_norm = np.linalg.norm(query_vector, axis=1, keepdims=True)
                query_vector = query_vector / query_norm
                self.logger.info(
                    f"Normalized query vector. Query norm: {query_norm[0][0]:.3f}"
                )

            scores, indices = index.search(query_vector, min(top_k, index.ntotal))
            self.logger.info(
                f"FAISS search results - Scores: {scores[0][:3]}, Indices: {indices[0][:3]}"
            )

            # Get document details from SQLite
            results = []
            with sqlite3.connect(self.db_path) as conn:
                for score, idx in zip(scores[0], indices[0]):
                    if idx == -1:  # FAISS returns -1 for unfilled positions
                        continue

                    cursor = conn.execute(
                        """
                        SELECT text, metadata_json FROM documents
                        WHERE database_name = ? AND faiss_index = ?
                    """,
                        (database_name, int(idx)),
                    )

                    row = cursor.fetchone()
                    if row:
                        text, metadata_json = row
                        metadata = json.loads(metadata_json) if metadata_json else {}

                        # Apply metadata filters if specified
                        if self._matches_filters(metadata, filters):
                            results.append(
                                {
                                    "text": text,
                                    "metadata": metadata,
                                    "similarity_score": float(score),
                                }
                            )

            # Sort by similarity score (descending)
            results.sort(key=lambda x: x["similarity_score"], reverse=True)

            return {
                "status": "success",
                "database_name": database_name,
                "query": query,
                "results": results[:top_k],
                "total_found": len(results),
            }

        except Exception as e:
            self.logger.error(f"Error searching database: {str(e)}")
            return {"error": f"Failed to search database: {str(e)}"}

    def _load_database(self, arguments):
        """Load existing database from path"""
        database_path = arguments.get("database_path")
        database_name = arguments.get("database_name")

        if not database_path:
            return {"error": "database_path is required"}
        if not database_name:
            return {"error": "database_name is required"}

        # This is a placeholder for loading external databases
        # Implementation would depend on the specific format of the external database
        return {"error": "load_database not yet implemented"}

    def _generate_embeddings(
        self, texts: List[str], model: str, use_azure: bool = False
    ) -> List[List[float]]:
        """Generate embeddings using OpenAI or Azure OpenAI API"""
        import time

        try:
            # Choose which client to use
            client = None
            if use_azure and self.azure_client:
                client = self.azure_client
                self.logger.info("Using Azure OpenAI for embeddings")
            elif not use_azure and self.openai_client:
                client = self.openai_client
                self.logger.info("Using OpenAI for embeddings")
            elif self.azure_client:  # Fallback to Azure if available
                client = self.azure_client
                self.logger.info("Falling back to Azure OpenAI")
            elif self.openai_client:  # Fallback to OpenAI if available
                client = self.openai_client
                self.logger.info("Falling back to OpenAI")
            else:
                raise ValueError("No OpenAI or Azure OpenAI client available")

            # Process in smaller batches for Azure OpenAI
            batch_size = 10 if use_azure else 100
            all_embeddings = []

            for _i in range(0, len(texts), batch_size):
                batch = texts[_i : _i + batch_size]
                retry_count = 0
                max_retries = 3

                while retry_count < max_retries:
                    try:
                        response = client.embeddings.create(input=batch, model=model)
                        batch_embeddings = [
                            embedding.embedding for embedding in response.data
                        ]
                        all_embeddings.extend(batch_embeddings)

                        # Small delay between batches for Azure
                        if use_azure and _i + batch_size < len(texts):
                            time.sleep(0.5)
                        break

                    except Exception as batch_error:
                        retry_count += 1
                        if retry_count >= max_retries:
                            raise batch_error

                        self.logger.warning(
                            f"Batch {_i//batch_size + 1} failed, retrying ({retry_count}/{max_retries})"
                        )
                        time.sleep(retry_count * 2)  # Exponential backoff

            return all_embeddings

        except Exception as e:
            self.logger.error(f"Error generating embeddings: {str(e)}")
            return []

    def _database_exists(self, database_name: str) -> bool:
        """Check if database exists"""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute(
                "SELECT 1 FROM databases WHERE name = ?", (database_name,)
            )
            return cursor.fetchone() is not None

    def _get_database_info(self, database_name: str) -> Dict:
        """Get database information"""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute(
                """
                SELECT name, description, embedding_model, embedding_dimensions, document_count, created_at
                FROM databases WHERE name = ?
            """,
                (database_name,),
            )
            row = cursor.fetchone()
            if row:
                return {
                    "name": row[0],
                    "description": row[1],
                    "embedding_model": row[2],
                    "embedding_dimensions": row[3],
                    "document_count": row[4],
                    "created_at": row[5],
                }
            return {}

    def _matches_filters(self, metadata: Dict, filters: Dict) -> bool:
        """Check if metadata matches the given filters"""
        if not filters:
            return True

        for key, filter_value in filters.items():
            if key not in metadata:
                return False

            meta_value = metadata[key]

            # Handle different filter types
            if isinstance(filter_value, dict):
                # Range filters like {"$gte": 2022, "$lt": 2025}
                if "$gte" in filter_value and meta_value < filter_value["$gte"]:
                    return False
                if "$gt" in filter_value and meta_value <= filter_value["$gt"]:
                    return False
                if "$lte" in filter_value and meta_value > filter_value["$lte"]:
                    return False
                if "$lt" in filter_value and meta_value >= filter_value["$lt"]:
                    return False
                if "$in" in filter_value and meta_value not in filter_value["$in"]:
                    return False
                if "$contains" in filter_value:
                    if isinstance(meta_value, list):
                        if filter_value["$contains"] not in meta_value:
                            return False
                    else:
                        if filter_value["$contains"] not in str(meta_value):
                            return False
            else:
                # Exact match
                if meta_value != filter_value:
                    return False

        return True
```
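For orientation, here is a minimal usage sketch of the new `EmbeddingDatabase` tool, pieced together from the `run()` dispatch and argument names above. The `tool_config` shape and the standalone instantiation are assumptions for illustration (in ToolUniverse the registry normally constructs tools from the JSON config files); the action names, argument names, and the `${VAR}` substitution syntax come straight from the source.

```python
# Hypothetical standalone use of the EmbeddingDatabase tool; assumes
# OPENAI_API_KEY is set in the environment and faiss-cpu/openai are installed.
from tooluniverse.embedding_database import EmbeddingDatabase

tool = EmbeddingDatabase({
    "configs": {
        # "${OPENAI_API_KEY}" is resolved by the tool's own
        # ${VAR} / ${VAR:default} environment-variable substitution.
        "openai_config": {"api_key": "${OPENAI_API_KEY}"},
        "storage_config": {
            "data_dir": "./data/embeddings",
            "faiss_index_type": "IndexFlatIP",  # inner product over unit vectors = cosine
        },
    }
})

# Create a database from documents (embeddings via OpenAI by default).
tool.run({
    "action": "create_from_docs",
    "database_name": "papers",
    "documents": [
        "FAISS enables fast vector similarity search.",
        "SQLite stores the document text and metadata.",
    ],
    "metadata": [
        {"year": 2024, "tags": ["faiss"]},
        {"year": 2023, "tags": ["sqlite"]},
    ],
})

# Search with a Mongo-style metadata filter ($gte/$gt/$lte/$lt/$in/$contains).
result = tool.run({
    "action": "search",
    "database_name": "papers",
    "query": "vector similarity search",
    "top_k": 3,
    "metadata_filter": {"year": {"$gte": 2024}},
})
print(result.get("results", result))
```

Note that `_search` applies the metadata filter only after the FAISS lookup, so a restrictive filter can return fewer than `top_k` results.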