MindsDB 25.9.2.0a1__py3-none-any.whl → 25.9.3rc1__py3-none-any.whl
This diff shows the changes between two publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of MindsDB might be problematic.
- mindsdb/__about__.py +1 -1
- mindsdb/__main__.py +39 -20
- mindsdb/api/a2a/agent.py +7 -9
- mindsdb/api/a2a/common/server/server.py +3 -3
- mindsdb/api/a2a/common/server/task_manager.py +4 -4
- mindsdb/api/a2a/task_manager.py +15 -17
- mindsdb/api/common/middleware.py +9 -11
- mindsdb/api/executor/command_executor.py +2 -4
- mindsdb/api/executor/datahub/datanodes/datanode.py +2 -2
- mindsdb/api/executor/datahub/datanodes/integration_datanode.py +100 -48
- mindsdb/api/executor/datahub/datanodes/project_datanode.py +8 -4
- mindsdb/api/executor/datahub/datanodes/system_tables.py +1 -1
- mindsdb/api/executor/exceptions.py +29 -10
- mindsdb/api/executor/planner/plan_join.py +17 -3
- mindsdb/api/executor/sql_query/sql_query.py +74 -74
- mindsdb/api/executor/sql_query/steps/fetch_dataframe.py +1 -2
- mindsdb/api/executor/sql_query/steps/subselect_step.py +0 -1
- mindsdb/api/executor/utilities/functions.py +6 -6
- mindsdb/api/executor/utilities/sql.py +32 -16
- mindsdb/api/http/gui.py +5 -11
- mindsdb/api/http/initialize.py +8 -10
- mindsdb/api/http/namespaces/agents.py +10 -12
- mindsdb/api/http/namespaces/analysis.py +13 -20
- mindsdb/api/http/namespaces/auth.py +1 -1
- mindsdb/api/http/namespaces/config.py +15 -11
- mindsdb/api/http/namespaces/databases.py +140 -201
- mindsdb/api/http/namespaces/file.py +15 -4
- mindsdb/api/http/namespaces/handlers.py +7 -2
- mindsdb/api/http/namespaces/knowledge_bases.py +8 -7
- mindsdb/api/http/namespaces/models.py +94 -126
- mindsdb/api/http/namespaces/projects.py +13 -22
- mindsdb/api/http/namespaces/sql.py +33 -25
- mindsdb/api/http/namespaces/tab.py +27 -37
- mindsdb/api/http/namespaces/views.py +1 -1
- mindsdb/api/http/start.py +14 -8
- mindsdb/api/mcp/__init__.py +2 -1
- mindsdb/api/mysql/mysql_proxy/executor/mysql_executor.py +15 -20
- mindsdb/api/mysql/mysql_proxy/mysql_proxy.py +26 -50
- mindsdb/api/mysql/mysql_proxy/utilities/__init__.py +0 -1
- mindsdb/api/postgres/postgres_proxy/executor/executor.py +6 -13
- mindsdb/api/postgres/postgres_proxy/postgres_packets/postgres_packets.py +40 -28
- mindsdb/integrations/handlers/byom_handler/byom_handler.py +168 -185
- mindsdb/integrations/handlers/file_handler/file_handler.py +7 -0
- mindsdb/integrations/handlers/lightwood_handler/functions.py +45 -79
- mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +13 -1
- mindsdb/integrations/handlers/shopify_handler/shopify_handler.py +25 -12
- mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py +2 -1
- mindsdb/integrations/handlers/statsforecast_handler/requirements.txt +1 -0
- mindsdb/integrations/handlers/statsforecast_handler/requirements_extra.txt +1 -0
- mindsdb/integrations/handlers/web_handler/urlcrawl_helpers.py +4 -4
- mindsdb/integrations/libs/api_handler.py +10 -10
- mindsdb/integrations/libs/base.py +4 -4
- mindsdb/integrations/libs/llm/utils.py +2 -2
- mindsdb/integrations/libs/ml_handler_process/create_engine_process.py +4 -7
- mindsdb/integrations/libs/ml_handler_process/func_call_process.py +2 -7
- mindsdb/integrations/libs/ml_handler_process/learn_process.py +37 -47
- mindsdb/integrations/libs/ml_handler_process/update_engine_process.py +4 -7
- mindsdb/integrations/libs/ml_handler_process/update_process.py +2 -7
- mindsdb/integrations/libs/process_cache.py +132 -140
- mindsdb/integrations/libs/response.py +18 -12
- mindsdb/integrations/libs/vectordatabase_handler.py +26 -0
- mindsdb/integrations/utilities/files/file_reader.py +6 -7
- mindsdb/integrations/utilities/rag/config_loader.py +37 -26
- mindsdb/integrations/utilities/rag/rerankers/base_reranker.py +59 -9
- mindsdb/integrations/utilities/rag/rerankers/reranker_compressor.py +4 -4
- mindsdb/integrations/utilities/rag/retrievers/sql_retriever.py +55 -133
- mindsdb/integrations/utilities/rag/settings.py +58 -133
- mindsdb/integrations/utilities/rag/splitters/file_splitter.py +5 -15
- mindsdb/interfaces/agents/agents_controller.py +2 -1
- mindsdb/interfaces/agents/constants.py +0 -2
- mindsdb/interfaces/agents/litellm_server.py +34 -58
- mindsdb/interfaces/agents/mcp_client_agent.py +10 -10
- mindsdb/interfaces/agents/mindsdb_database_agent.py +5 -5
- mindsdb/interfaces/agents/run_mcp_agent.py +12 -21
- mindsdb/interfaces/chatbot/chatbot_task.py +20 -23
- mindsdb/interfaces/chatbot/polling.py +30 -18
- mindsdb/interfaces/data_catalog/data_catalog_loader.py +10 -10
- mindsdb/interfaces/database/integrations.py +19 -2
- mindsdb/interfaces/file/file_controller.py +6 -6
- mindsdb/interfaces/functions/controller.py +1 -1
- mindsdb/interfaces/functions/to_markdown.py +2 -2
- mindsdb/interfaces/jobs/jobs_controller.py +5 -5
- mindsdb/interfaces/jobs/scheduler.py +3 -8
- mindsdb/interfaces/knowledge_base/controller.py +50 -23
- mindsdb/interfaces/knowledge_base/preprocessing/json_chunker.py +40 -61
- mindsdb/interfaces/model/model_controller.py +170 -166
- mindsdb/interfaces/query_context/context_controller.py +14 -2
- mindsdb/interfaces/skills/custom/text2sql/mindsdb_sql_toolkit.py +6 -4
- mindsdb/interfaces/skills/retrieval_tool.py +43 -50
- mindsdb/interfaces/skills/skill_tool.py +2 -2
- mindsdb/interfaces/skills/sql_agent.py +25 -19
- mindsdb/interfaces/storage/fs.py +114 -169
- mindsdb/interfaces/storage/json.py +19 -18
- mindsdb/interfaces/tabs/tabs_controller.py +49 -72
- mindsdb/interfaces/tasks/task_monitor.py +3 -9
- mindsdb/interfaces/tasks/task_thread.py +7 -9
- mindsdb/interfaces/triggers/trigger_task.py +7 -13
- mindsdb/interfaces/triggers/triggers_controller.py +47 -50
- mindsdb/migrations/migrate.py +16 -16
- mindsdb/utilities/api_status.py +58 -0
- mindsdb/utilities/config.py +49 -0
- mindsdb/utilities/exception.py +40 -1
- mindsdb/utilities/fs.py +0 -1
- mindsdb/utilities/hooks/profiling.py +17 -14
- mindsdb/utilities/langfuse.py +40 -45
- mindsdb/utilities/log.py +272 -0
- mindsdb/utilities/ml_task_queue/consumer.py +52 -58
- mindsdb/utilities/ml_task_queue/producer.py +26 -30
- mindsdb/utilities/render/sqlalchemy_render.py +7 -6
- mindsdb/utilities/utils.py +2 -2
- {mindsdb-25.9.2.0a1.dist-info → mindsdb-25.9.3rc1.dist-info}/METADATA +269 -264
- {mindsdb-25.9.2.0a1.dist-info → mindsdb-25.9.3rc1.dist-info}/RECORD +115 -115
- mindsdb/api/mysql/mysql_proxy/utilities/exceptions.py +0 -14
- {mindsdb-25.9.2.0a1.dist-info → mindsdb-25.9.3rc1.dist-info}/WHEEL +0 -0
- {mindsdb-25.9.2.0a1.dist-info → mindsdb-25.9.3rc1.dist-info}/licenses/LICENSE +0 -0
- {mindsdb-25.9.2.0a1.dist-info → mindsdb-25.9.3rc1.dist-info}/top_level.txt +0 -0
Lines marked `…` below were truncated in the diff source and could not be recovered.

mindsdb/integrations/utilities/rag/config_loader.py

```diff
@@ -1,17 +1,26 @@
 """Utility functions for RAG pipeline configuration"""
+
 from typing import Dict, Any, Optional
 
 from mindsdb.utilities.log import getLogger
 from mindsdb.integrations.utilities.rag.settings import (
-    RetrieverType,
-    …
-    …
+    RetrieverType,
+    MultiVectorRetrieverMode,
+    SearchType,
+    SearchKwargs,
+    SummarizationConfig,
+    VectorStoreConfig,
+    RerankerConfig,
+    RAGPipelineModel,
+    DEFAULT_COLLECTION_NAME,
 )
 
 logger = getLogger(__name__)
 
 
-def load_rag_config(base_config: Dict[str, Any], kb_params: Optional[Dict[str, Any]] = None, embedding_model: Any = None) -> RAGPipelineModel:
+def load_rag_config(
+    base_config: Dict[str, Any], kb_params: Optional[Dict[str, Any]] = None, embedding_model: Any = None
+) -> RAGPipelineModel:
     """
     Load and validate RAG configuration parameters. This function handles the conversion of configuration
     parameters into their appropriate types and ensures all required settings are properly configured.
@@ -37,41 +46,43 @@ def load_rag_config(base_config: Dict[str, Any], kb_params: Optional[Dict[str, A
 
     # Set embedding model if provided
     if embedding_model is not None:
-        rag_params[…
+        rag_params["embedding_model"] = embedding_model
 
     # Handle enums and type conversions
-    if …
-        rag_params[…
-    if …
-        rag_params[…
-    if …
-        rag_params[…
+    if "retriever_type" in rag_params:
+        rag_params["retriever_type"] = RetrieverType(rag_params["retriever_type"])
+    if "multi_retriever_mode" in rag_params:
+        rag_params["multi_retriever_mode"] = MultiVectorRetrieverMode(rag_params["multi_retriever_mode"])
+    if "search_type" in rag_params:
+        rag_params["search_type"] = SearchType(rag_params["search_type"])
 
     # Handle search kwargs if present
-    if …
-        rag_params[…
+    if "search_kwargs" in rag_params and isinstance(rag_params["search_kwargs"], dict):
+        rag_params["search_kwargs"] = SearchKwargs(**rag_params["search_kwargs"])
 
     # Handle summarization config if present
-    summarization_config = rag_params.get(…
+    summarization_config = rag_params.get("summarization_config")
     if summarization_config is not None and isinstance(summarization_config, dict):
-        rag_params[…
+        rag_params["summarization_config"] = SummarizationConfig(**summarization_config)
 
     # Handle vector store config
-    if …
-        if isinstance(rag_params[…
-            rag_params[…
+    if "vector_store_config" in rag_params:
+        if isinstance(rag_params["vector_store_config"], dict):
+            rag_params["vector_store_config"] = VectorStoreConfig(**rag_params["vector_store_config"])
     else:
-        rag_params[…
-        logger.warning(…
-            …
-            …
+        rag_params["vector_store_config"] = {}
+        logger.warning(
+            f"No collection_name specified for the retrieval tool, "
+            f"using default collection_name: '{DEFAULT_COLLECTION_NAME}'"
+            f"\nWarning: If this collection does not exist, no data will be retrieved"
+        )
 
-    if …
-        rag_params[…
+    if "reranker_config" in rag_params:
+        rag_params["reranker_config"] = RerankerConfig(**rag_params["reranker_config"])
 
     # Convert to RAGPipelineModel with validation
     try:
         return RAGPipelineModel(**rag_params)
     except Exception as e:
-        logger.…
-        raise ValueError(f"Configuration validation failed: {str(e)}")
+        logger.exception("Invalid RAG configuration:")
+        raise ValueError(f"Configuration validation failed: {str(e)}") from e
```
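The pattern worth noting in `load_rag_config` is that every user-supplied string or dict is promoted to a typed value (an enum, `SearchKwargs`, `VectorStoreConfig`, and so on) before `RAGPipelineModel` validates the whole bundle. A minimal, self-contained sketch of that coercion step, using a stand-in enum rather than the real definitions in `rag.settings`:

```python
from enum import Enum
from typing import Any, Dict


class RetrieverType(Enum):
    # Stand-in for the real enum imported from rag.settings.
    VECTOR_STORE = "vector_store"
    AUTO = "auto"


def coerce_params(rag_params: Dict[str, Any]) -> Dict[str, Any]:
    # Raw strings from user config become typed enum members; an unknown
    # value raises ValueError, which the loader catches and re-raises with context.
    if "retriever_type" in rag_params:
        rag_params["retriever_type"] = RetrieverType(rag_params["retriever_type"])
    return rag_params


print(coerce_params({"retriever_type": "auto"}))  # {'retriever_type': <RetrieverType.AUTO: 'auto'>}
```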
mindsdb/integrations/utilities/rag/rerankers/base_reranker.py

```diff
@@ -13,7 +13,15 @@ from typing import Any, List, Optional, Tuple
 from openai import AsyncOpenAI, AsyncAzureOpenAI
 from pydantic import BaseModel
 
-from mindsdb.integrations.utilities.rag.settings import …
+from mindsdb.integrations.utilities.rag.settings import (
+    DEFAULT_RERANKING_MODEL,
+    DEFAULT_LLM_ENDPOINT,
+    DEFAULT_RERANKER_N,
+    DEFAULT_RERANKER_LOGPROBS,
+    DEFAULT_RERANKER_TOP_LOGPROBS,
+    DEFAULT_RERANKER_MAX_TOKENS,
+    DEFAULT_VALID_CLASS_TOKENS,
+)
 from mindsdb.integrations.libs.base import BaseMLEngine
 
 log = logging.getLogger(__name__)
@@ -38,6 +46,11 @@ class BaseLLMReranker(BaseModel, ABC):
     request_timeout: float = 20.0  # Timeout for API requests
     early_stop: bool = True  # Whether to enable early stopping
     early_stop_threshold: float = 0.8  # Confidence threshold for early stopping
+    n: int = DEFAULT_RERANKER_N  # Number of completions to generate
+    logprobs: bool = DEFAULT_RERANKER_LOGPROBS  # Whether to include log probabilities
+    top_logprobs: int = DEFAULT_RERANKER_TOP_LOGPROBS  # Number of top log probabilities to include
+    max_tokens: int = DEFAULT_RERANKER_MAX_TOKENS  # Maximum tokens to generate
+    valid_class_tokens: List[str] = DEFAULT_VALID_CLASS_TOKENS
 
     class Config:
         arbitrary_types_allowed = True
@@ -142,7 +155,7 @@ class BaseLLMReranker(BaseModel, ABC):
             return ranked_results
         except Exception as e:
             # Don't let early stopping errors stop the whole process
-            log.warning(f"Error in early stopping check: {…
+            log.warning(f"Error in early stopping check: {e}")
 
         return ranked_results
 
@@ -234,6 +247,28 @@ class BaseLLMReranker(BaseModel, ABC):
         return rerank_data
 
     async def search_relevancy_score(self, query: str, document: str) -> Any:
+        """
+        This method is used to score the relevance of a document to a query.
+
+        Args:
+            query: The query to score the relevance of.
+            document: The document to score the relevance of.
+
+        Returns:
+            A dictionary with the document and the relevance score.
+        """
+
+        log.debug("Start search_relevancy_score")
+        log.debug(f"Reranker query: {query[:5]}")
+        log.debug(f"Reranker document: {document[:50]}")
+        log.debug(f"Reranker model: {self.model}")
+        log.debug(f"Reranker temperature: {self.temperature}")
+        log.debug(f"Reranker n: {self.n}")
+        log.debug(f"Reranker logprobs: {self.logprobs}")
+        log.debug(f"Reranker top_logprobs: {self.top_logprobs}")
+        log.debug(f"Reranker max_tokens: {self.max_tokens}")
+        log.debug(f"Reranker valid_class_tokens: {self.valid_class_tokens}")
+
         response = await self.client.chat.completions.create(
             model=self.model,
             messages=[
@@ -306,17 +341,30 @@ class BaseLLMReranker(BaseModel, ABC):
                 },
             ],
             temperature=self.temperature,
-            n=…
-            logprobs=…
-            top_logprobs=…
-            max_tokens=…
+            n=self.n,
+            logprobs=self.logprobs,
+            top_logprobs=self.top_logprobs,
+            max_tokens=self.max_tokens,
         )
 
         # Extract response and logprobs
         token_logprobs = response.choices[0].logprobs.content
-        …
-        …
-        …
+
+        # Find the token that contains the class number
+        # Instead of just taking the last token, search for the actual class number token
+        class_token_logprob = None
+        for token_logprob in reversed(token_logprobs):
+            if token_logprob.token in self.valid_class_tokens:
+                class_token_logprob = token_logprob
+                break
+
+        # If we couldn't find a class token, fall back to the last non-empty token
+        if class_token_logprob is None:
+            log.warning("No class token logprob found, using the last token as fallback")
+            class_token_logprob = token_logprobs[-1]
+
+        top_logprobs = class_token_logprob.top_logprobs
+
         # Create a map of 'class_1' -> probability, using token combinations
         class_probs = {}
         for top_token in top_logprobs:
```
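The replacement logic no longer assumes the class label is the last token of the completion: it walks `response.choices[0].logprobs.content` backwards until it hits a token from `valid_class_tokens`, and only falls back to the last token when none is found. A self-contained sketch of that scan, with a plain dataclass standing in for the OpenAI logprob objects:

```python
import math
from dataclasses import dataclass
from typing import List


@dataclass
class TokenLogprob:
    # Stand-in for the token/logprob entries on response.choices[0].logprobs.content.
    token: str
    logprob: float


def find_class_token(token_logprobs: List[TokenLogprob], valid_class_tokens: List[str]) -> TokenLogprob:
    # Walk the completion backwards so the final class label wins even when
    # the model emitted extra prose around it; fall back to the last token.
    for tl in reversed(token_logprobs):
        if tl.token in valid_class_tokens:
            return tl
    return token_logprobs[-1]


tokens = [TokenLogprob("class_", -0.20), TokenLogprob("1", -0.05)]
chosen = find_class_token(tokens, valid_class_tokens=["0", "1", "2", "3"])
print(chosen.token, round(math.exp(chosen.logprob), 3))  # 1 0.951
```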
mindsdb/integrations/utilities/rag/rerankers/base_reranker.py (continued)

```diff
@@ -337,6 +385,8 @@ class BaseLLMReranker(BaseModel, ABC):
             score = 0.0
 
         rerank_data = {"document": document, "relevance_score": score}
+        log.debug(f"Reranker score: {score}")
+        log.debug("End search_relevancy_score")
         return rerank_data
 
     def get_scores(self, query: str, documents: list[str]):
```
mindsdb/integrations/utilities/rag/rerankers/reranker_compressor.py

```diff
@@ -36,7 +36,7 @@ class LLMReranker(BaseDocumentCompressor, BaseLLMReranker):
             return []
 
         # Stream reranking update.
-        dispatch_custom_event(…
+        dispatch_custom_event("rerank_begin", {"num_documents": len(documents)})
 
         try:
             # Prepare query-document pairs
@@ -73,10 +73,10 @@ class LLMReranker(BaseDocumentCompressor, BaseLLMReranker):
             return filtered_docs
 
         except Exception as e:
-            error_msg = …
-            log.…
+            error_msg = "Error during async document compression:"
+            log.exception(error_msg)
             if callbacks:
-                await callbacks.on_retriever_error(error_msg)
+                await callbacks.on_retriever_error(f"{error_msg} {e}")
             return documents  # Return original documents on error
 
     def compress_documents(
```
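Both fixed branches follow the same degrade-gracefully rule: `log.exception` records the full traceback, the callback gets the error message with the exception appended, and the caller still receives the original documents. A minimal sketch of that pattern, with a hypothetical `rerank` callable standing in for the real compressor:

```python
import logging
from typing import Callable, List

log = logging.getLogger(__name__)


def compress_documents_safely(rerank: Callable[[List[str]], List[str]], documents: List[str]) -> List[str]:
    # On any reranking failure, log the traceback but hand back the input
    # documents unchanged so retrieval still returns results.
    try:
        return rerank(documents)
    except Exception:
        log.exception("Error during document compression:")
        return documents


print(compress_documents_safely(sorted, ["b", "a"]))           # ['a', 'b']
print(compress_documents_safely(lambda d: 1 / 0, ["b", "a"]))  # ['b', 'a']
```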
mindsdb/integrations/utilities/rag/retrievers/sql_retriever.py

```diff
@@ -1,10 +1,10 @@
 import re
-
-from pydantic import BaseModel, Field
-from typing import List, Any, Optional, Dict, Tuple, Union, Callable
-import collections
 import math
+import logging
+import collections
+from typing import List, Any, Optional, Dict, Tuple, Union, Callable
 
+from pydantic import BaseModel, Field
 from langchain.chains.llm import LLMChain
 from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
 from langchain_core.documents.base import Document
@@ -39,9 +39,7 @@ class MetadataFilter(BaseModel):
     """Represents an LLM generated metadata filter to apply to a PostgreSQL query."""
 
     attribute: str = Field(description="Database column to apply filter to")
-    comparator: str = Field(
-        description="PostgreSQL comparator to use to filter database column"
-    )
+    comparator: str = Field(description="PostgreSQL comparator to use to filter database column")
     value: Any = Field(description="Value to use to filter database column")
 
 
@@ -56,9 +54,7 @@ class AblativeMetadataFilter(MetadataFilter):
 class MetadataFilters(BaseModel):
     """List of LLM generated metadata filters to apply to a PostgreSQL query."""
 
-    filters: List[MetadataFilter] = Field(
-        description="List of PostgreSQL metadata filters to apply for user query"
-    )
+    filters: List[MetadataFilter] = Field(description="List of PostgreSQL metadata filters to apply for user query")
 
 
 class SQLRetriever(BaseRetriever):
@@ -142,25 +138,17 @@ class SQLRetriever(BaseRetriever):
         elif isinstance(schema, ColumnSchema):
             collection_key = "values"
         else:
-            raise Exception(
-                "schema must be either a DatabaseSchema, TableSchema, or ColumnSchema."
-            )
+            raise Exception("schema must be either a DatabaseSchema, TableSchema, or ColumnSchema.")
 
         if update is not None:
-            ordered = collections.OrderedDict(
-                sorted(update.items(), key=key, reverse=True)
-            )
+            ordered = collections.OrderedDict(sorted(update.items(), key=key, reverse=True))
         else:
-            ordered = collections.OrderedDict(
-                sorted(getattr(schema, collection_key).items(), key=key, reverse=True)
-            )
+            ordered = collections.OrderedDict(sorted(getattr(schema, collection_key).items(), key=key, reverse=True))
         schema = schema.model_copy(update={collection_key: ordered})
 
         return schema
 
-    def _sort_database_schema_by_key(
-        self, database_schema: DatabaseSchema, key: Callable
-    ) -> DatabaseSchema:
+    def _sort_database_schema_by_key(self, database_schema: DatabaseSchema, key: Callable) -> DatabaseSchema:
         """Re-build schema with OrderedDicts"""
         tables = {}
         # build new tables dict
@@ -169,17 +157,11 @@ class SQLRetriever(BaseRetriever):
             # build new column dict
             for column_key, column_schema in table_schema.columns.items():
                 # sort values directly and update column schema
-                columns[column_key] = self._sort_schema_by_key(
-                    schema=column_schema, key=key
-                )
+                columns[column_key] = self._sort_schema_by_key(schema=column_schema, key=key)
             # update table schema and sort
-            tables[table_key] = self._sort_schema_by_key(
-                schema=table_schema, key=key, update=columns
-            )
+            tables[table_key] = self._sort_schema_by_key(schema=table_schema, key=key, update=columns)
         # update table schema and sort
-        database_schema = self._sort_schema_by_key(
-            schema=database_schema, key=key, update=tables
-        )
+        database_schema = self._sort_schema_by_key(schema=database_schema, key=key, update=tables)
 
         return database_schema
 
@@ -191,15 +173,12 @@
         boolean_system_prompt: bool = True,
         format_instructions: Optional[str] = None,
     ) -> ChatPromptTemplate:
-
         if boolean_system_prompt is True:
             system_prompt = self.boolean_system_prompt
         else:
             system_prompt = self.generative_system_prompt
 
-        prepared_column_prompt = self._prepare_column_prompt(
-            column_schema=column_schema, table_schema=table_schema
-        )
+        prepared_column_prompt = self._prepare_column_prompt(column_schema=column_schema, table_schema=table_schema)
         column_schema_str = (
             prepared_column_prompt.messages[1]
             .format(
@@ -290,7 +269,6 @@ Below is a list of comparison operators for constructing filters for this value
         table_schema: TableSchema,
         boolean_system_prompt: bool = True,
     ) -> ChatPromptTemplate:
-
         if boolean_system_prompt is True:
             system_prompt = self.boolean_system_prompt
         else:
@@ -312,9 +290,7 @@
             [("system", system_prompt), ("user", self.column_prompt_template)]
         )
 
-        header_str = (
-            f"This schema describes a column in the {table_schema.table} table."
-        )
+        header_str = f"This schema describes a column in the {table_schema.table} table."
 
         value_str = """
 ## **Content**
@@ -388,26 +364,18 @@ Below is a description of the contents in this column in list format:
         )
 
     def _rank_schema(self, prompt: ChatPromptTemplate, query: str) -> float:
-        rank_chain = LLMChain(
-            llm=self.llm.bind(logprobs=True), prompt=prompt, return_final_only=False
-        )
+        rank_chain = LLMChain(llm=self.llm.bind(logprobs=True), prompt=prompt, return_final_only=False)
         output = rank_chain({"query": query})  # returns metadata
 
         # parse through metadata tokens until encountering either yes, or no.
         score = None  # a None score indicates the model output could not be parsed.
-        for content in output["full_generation"][0].message.response_metadata[
-            "logprobs"
-        ]["content"]:
+        for content in output["full_generation"][0].message.response_metadata["logprobs"]["content"]:
             # Convert answer to score using the model's confidence
             if content["token"].lower().strip() == "yes":
-                score = (
-                    1 + math.exp(content["logprob"])
-                ) / 2  # If yes, use the model's confidence
+                score = (1 + math.exp(content["logprob"])) / 2  # If yes, use the model's confidence
                 break
             elif content["token"].lower().strip() == "no":
-                score = (
-                    1 - math.exp(content["logprob"])
-                ) / 2  # If no, invert the confidence
+                score = (1 - math.exp(content["logprob"])) / 2  # If no, invert the confidence
                 break
 
         if score is None:
```
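The scoring arithmetic in `_rank_schema` maps the answer token's probability `p = exp(logprob)` into [0, 1]: a confident "yes" lands near 1.0, a confident "no" near 0.0, and a hesitant answer of either kind stays near 0.5. The same arithmetic as a standalone sketch:

```python
import math


def yes_no_score(token: str, logprob: float) -> float:
    # p is the model's probability for the token it actually emitted.
    p = math.exp(logprob)
    answer = token.lower().strip()
    if answer == "yes":
        return (1 + p) / 2  # upper half of [0, 1]
    if answer == "no":
        return (1 - p) / 2  # lower half of [0, 1]
    raise ValueError("unparseable answer token")


print(round(yes_no_score("yes", -0.01), 3))  # 0.995
print(round(yes_no_score("no", -0.01), 3))   # 0.005
print(round(yes_no_score("yes", -2.0), 3))   # 0.568 -- a low-confidence yes stays near 0.5
```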
mindsdb/integrations/utilities/rag/retrievers/sql_retriever.py (continued)

```diff
@@ -465,9 +433,7 @@ Below is a description of the contents in this column in list format:
                     table_schema=table_schema,
                     boolean_system_prompt=True,
                 )
-                column_schema.relevance = self._rank_schema(
-                    prompt=prompt, query=query
-                )
+                column_schema.relevance = self._rank_schema(prompt=prompt, query=query)
 
                 columns[column_key] = column_schema
 
@@ -512,9 +478,7 @@ Below is a description of the contents in this column in list format:
                     table_schema=table_schema,
                     boolean_system_prompt=True,
                 )
-                value_schema.relevance = self._rank_schema(
-                    prompt=prompt, query=query
-                )
+                value_schema.relevance = self._rank_schema(prompt=prompt, query=query)
 
                 values[value_key] = value_schema
 
@@ -592,19 +556,13 @@ Below is a description of the contents in this column in list format:
         for table_key, table_schema in ordered_database_schema.tables.items():
             for column_key, column_schema in table_schema.columns.items():
                 for value_key, value_schema in column_schema.values.items():
-                    ablation_value_dict[(table_key, column_key, value_key)] = (
-                        value_schema.relevance
-                    )
+                    ablation_value_dict[(table_key, column_key, value_key)] = value_schema.relevance
 
-        ablation_value_dict = collections.OrderedDict(
-            sorted(ablation_value_dict.items(), key=lambda x: x[1])
-        )
+        ablation_value_dict = collections.OrderedDict(sorted(ablation_value_dict.items(), key=lambda x: x[1]))
 
         relevance_scores = list(ablation_value_dict.values())
         if len(relevance_scores) > 0:
-            ablation_quantiles = np.quantile(
-                relevance_scores, np.linspace(0, 1, self.num_retries + 2)[1:-1]
-            )
+            ablation_quantiles = np.quantile(relevance_scores, np.linspace(0, 1, self.num_retries + 2)[1:-1])
         else:
             ablation_quantiles = None
 
```
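The quantile line takes only the interior points of `np.linspace(0, 1, self.num_retries + 2)`, which yields exactly one relevance threshold per retry; how each threshold is then used to ablate filters is defined by the surrounding code. A worked example with assumed scores:

```python
import numpy as np

num_retries = 3
relevance_scores = [0.1, 0.2, 0.4, 0.6, 0.9]

# linspace(0, 1, 5) == [0.0, 0.25, 0.5, 0.75, 1.0]; slicing off the endpoints
# leaves the interior quantile levels [0.25, 0.5, 0.75] -- one per retry.
ablation_quantiles = np.quantile(relevance_scores, np.linspace(0, 1, num_retries + 2)[1:-1])
print(ablation_quantiles)  # [0.2 0.4 0.6]
```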
mindsdb/integrations/utilities/rag/retrievers/sql_retriever.py (continued)

```diff
@@ -628,11 +586,7 @@ Below is a description of the contents in this column in list format:
         ablated_filters = []
         for filter in metadata_filters:
             for key in ablated_dict.keys():
-                if (
-                    filter.schema_table in key
-                    and filter.schema_column in key
-                    and filter.schema_value in key
-                ):
+                if filter.schema_table in key and filter.schema_column in key and filter.schema_value in key:
                     ablated_filters.append(filter)
 
         return ablated_filters
@@ -646,9 +600,7 @@ Below is a description of the contents in this column in list format:
         pass
 
     def _prepare_retrieval_query(self, query: str) -> str:
-        rewrite_prompt = PromptTemplate(
-            input_variables=["input"], template=self.rewrite_prompt_template
-        )
+        rewrite_prompt = PromptTemplate(input_variables=["input"], template=self.rewrite_prompt_template)
        rewrite_chain = LLMChain(llm=self.llm, prompt=rewrite_prompt)
         return rewrite_chain.predict(input=query)
 
@@ -668,9 +620,7 @@ Below is a description of the contents in this column in list format:
         # Add Table JOIN statements
         join_clauses = set()
         for metadata_filter in metadata_filters:
-            join_clause = ranked_database_schema.tables[
-                metadata_filter.schema_table
-            ].join
+            join_clause = ranked_database_schema.tables[metadata_filter.schema_table].join
             if join_clause in join_clauses:
                 continue
             else:
@@ -688,12 +638,12 @@ Below is a description of the contents in this column in list format:
             if i < len(metadata_filters) - 1:
                 base_query += " AND "
 
-        base_query += …
+        base_query += (
+            f" ORDER BY e.embeddings {self.distance_function.value[0]} '{{embeddings}}' LIMIT {self.search_kwargs.k};"
+        )
         return base_query
 
-    def _generate_filter(
-        self, prompt: ChatPromptTemplate, query: str
-    ) -> MetadataFilter:
+    def _generate_filter(self, prompt: ChatPromptTemplate, query: str) -> MetadataFilter:
         gen_filter_chain = LLMChain(llm=self.llm, prompt=prompt)
         output = gen_filter_chain({"query": query})
         return output
```
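The rebuilt `base_query` embeds a literal `{embeddings}` placeholder (the doubled braces escape it inside the f-string) so the vector itself can be substituted later via `checked_sql_query.format(embeddings=embeddings_str)`, as the query-execution helper further down does. A sketch with assumed values for `self.distance_function.value[0]` and `self.search_kwargs.k`, and a made-up table and filter:

```python
# Assumed stand-ins: "<=>" is pgvector's cosine-distance operator, k is the result limit.
distance_operator = "<=>"
k = 5

base_query = "SELECT e.id, e.content FROM embeddings e WHERE e.metadata->>'year' = '2024'"
base_query += f" ORDER BY e.embeddings {distance_operator} '{{embeddings}}' LIMIT {k};"

# The f-string already resolved {distance_operator} and {k}; '{embeddings}'
# survives as a plain placeholder and is filled in at execution time.
print(base_query.format(embeddings="[0.1, 0.2, 0.3]"))
```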
mindsdb/integrations/utilities/rag/retrievers/sql_retriever.py (continued)

````diff
@@ -714,28 +664,22 @@ Below is a description of the contents in this column in list format:
         # must use generation if field is a dictionary of tuples or a list
         if type(value_schema.value) in [list, dict]:
             try:
-                metadata_prompt: ChatPromptTemplate = (
-                    …
-                    …
-                    …
-                    …
-                    …
-                    boolean_system_prompt=False,
-                )
+                metadata_prompt: ChatPromptTemplate = self._prepare_value_prompt(
+                    format_instructions=parser.get_format_instructions(),
+                    value_schema=value_schema,
+                    column_schema=column_schema,
+                    table_schema=table_schema,
+                    boolean_system_prompt=False,
                 )
 
-                metadata_filters_chain = LLMChain(
-                    llm=self.llm, prompt=metadata_prompt
-                )
+                metadata_filters_chain = LLMChain(llm=self.llm, prompt=metadata_prompt)
                 metadata_filter_output = metadata_filters_chain.predict(
                     query=query,
                 )
 
                 # If the LLM outputs raw JSON, use it as-is.
                 # If the LLM outputs anything including a json markdown section, use the last one.
-                json_markdown_output = re.findall(
-                    r"```json.*```", metadata_filter_output, re.DOTALL
-                )
+                json_markdown_output = re.findall(r"```json.*```", metadata_filter_output, re.DOTALL)
                 if json_markdown_output:
                     metadata_filter_output = json_markdown_output[-1]
                     # Clean the json tags.
@@ -754,11 +698,10 @@ Below is a description of the contents in this column in list format:
                 metadata_filter = AblativeMetadataFilter(**model_dump)
             except OutputParserException as e:
                 logger.warning(
-                    f"LLM failed to generate structured metadata filters: {…
-                    …
-                return HandlerResponse(
-                    RESPONSE_TYPE.ERROR, error_message=str(e)
+                    f"LLM failed to generate structured metadata filters: {e}",
+                    exc_info=logger.isEnabledFor(logging.DEBUG),
                 )
+                return HandlerResponse(RESPONSE_TYPE.ERROR, error_message=str(e))
             else:
                 metadata_filter = AblativeMetadataFilter(
                     attribute=column_schema.column,
@@ -779,24 +722,17 @@ Below is a description of the contents in this column in list format:
         embeddings_str: str,
     ) -> HandlerResponse:
         try:
-            checked_sql_query = self._prepare_pgvector_query(
-                ranked_database_schema, metadata_filters
-            )
-            checked_sql_query_with_embeddings = checked_sql_query.format(
-                embeddings=embeddings_str
-            )
-            return self.vector_store_handler.native_query(
-                checked_sql_query_with_embeddings
-            )
+            checked_sql_query = self._prepare_pgvector_query(ranked_database_schema, metadata_filters)
+            checked_sql_query_with_embeddings = checked_sql_query.format(embeddings=embeddings_str)
+            return self.vector_store_handler.native_query(checked_sql_query_with_embeddings)
         except Exception as e:
             logger.warning(
-                f"Failed to prepare and execute SQL query from structured metadata: {…
+                f"Failed to prepare and execute SQL query from structured metadata: {e}",
+                exc_info=logger.isEnabledFor(logging.DEBUG),
             )
             return HandlerResponse(RESPONSE_TYPE.ERROR, error_message=str(e))
 
-    def _get_relevant_documents(
-        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
-    ) -> List[Document]:
+    def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]:
         # Rewrite query to be suitable for retrieval.
         retrieval_query = self._prepare_retrieval_query(query)
 
@@ -804,14 +740,10 @@ Below is a description of the contents in this column in list format:
         embedded_query = self.embeddings_model.embed_query(retrieval_query)
 
         # Search for relevant filters
-        ranked_database_schema, ablation_value_dict, ablation_quantiles = (
-            self._breadth_first_search(query=query)
-        )
+        ranked_database_schema, ablation_value_dict, ablation_quantiles = self._breadth_first_search(query=query)
 
         # Generate metadata filters
-        metadata_filters = self._generate_metadata_filters(
-            query=query, ranked_database_schema=ranked_database_schema
-        )
+        metadata_filters = self._generate_metadata_filters(query=query, ranked_database_schema=ranked_database_schema)
 
         if type(metadata_filters) is list:
             # Initial Execution of the similarity search with metadata filters.
@@ -830,9 +762,7 @@ Below is a description of the contents in this column in list format:
                     break
                 elif document_response.resp_type == RESPONSE_TYPE.ERROR:
                     # LLMs won't always generate structured metadata so we should have a fallback after retrying.
-                    logger.info(
-                        f"SQL Retriever query failed with error {document_response.error_message}"
-                    )
+                    logger.info(f"SQL Retriever query failed with error {document_response.error_message}")
                 else:
                     logger.info(
                         f"SQL Retriever did not retrieve {self.min_k} documents: {len(document_response.data_frame)} documents retrieved."
@@ -867,17 +797,9 @@ Below is a description of the contents in this column in list format:
                 return retrieved_documents
 
             # If the SQL query constructed did not return any documents, fallback.
-            logger.info(
-                "No documents returned from SQL retriever, using fallback retriever."
-            )
-            return self.fallback_retriever._get_relevant_documents(
-                retrieval_query, run_manager=run_manager
-            )
+            logger.info("No documents returned from SQL retriever, using fallback retriever.")
+            return self.fallback_retriever._get_relevant_documents(retrieval_query, run_manager=run_manager)
         else:
             # If no metadata fields could be generated fallback.
-            logger.info(
-                "No metadata fields were successfully generated, using fallback retriever."
-            )
-            return self.fallback_retriever._get_relevant_documents(
-                retrieval_query, run_manager=run_manager
-            )
+            logger.info("No metadata fields were successfully generated, using fallback retriever.")
+            return self.fallback_retriever._get_relevant_documents(retrieval_query, run_manager=run_manager)
````
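Both tails of `_get_relevant_documents` now funnel into the same two-line fallback: when the structured SQL path produces nothing, or no metadata filters could be generated at all, the rewritten query goes to `fallback_retriever` instead. The control flow reduced to a sketch with hypothetical retriever callables:

```python
from typing import Callable, List


def retrieve_with_fallback(
    primary: Callable[[str], List[str]],
    fallback: Callable[[str], List[str]],
    query: str,
) -> List[str]:
    # Prefer the structured SQL path; hand the query to the plain vector
    # search only when the primary path comes back empty.
    documents = primary(query)
    if documents:
        return documents
    return fallback(query)


print(retrieve_with_fallback(lambda q: [], lambda q: ["doc from fallback"], "test"))  # ['doc from fallback']
```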