MindsDB 25.9.2.0a1__py3-none-any.whl → 25.9.3rc1__py3-none-any.whl

This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.

Potentially problematic release: this version of MindsDB might be problematic.
Files changed (116)
  1. mindsdb/__about__.py +1 -1
  2. mindsdb/__main__.py +39 -20
  3. mindsdb/api/a2a/agent.py +7 -9
  4. mindsdb/api/a2a/common/server/server.py +3 -3
  5. mindsdb/api/a2a/common/server/task_manager.py +4 -4
  6. mindsdb/api/a2a/task_manager.py +15 -17
  7. mindsdb/api/common/middleware.py +9 -11
  8. mindsdb/api/executor/command_executor.py +2 -4
  9. mindsdb/api/executor/datahub/datanodes/datanode.py +2 -2
  10. mindsdb/api/executor/datahub/datanodes/integration_datanode.py +100 -48
  11. mindsdb/api/executor/datahub/datanodes/project_datanode.py +8 -4
  12. mindsdb/api/executor/datahub/datanodes/system_tables.py +1 -1
  13. mindsdb/api/executor/exceptions.py +29 -10
  14. mindsdb/api/executor/planner/plan_join.py +17 -3
  15. mindsdb/api/executor/sql_query/sql_query.py +74 -74
  16. mindsdb/api/executor/sql_query/steps/fetch_dataframe.py +1 -2
  17. mindsdb/api/executor/sql_query/steps/subselect_step.py +0 -1
  18. mindsdb/api/executor/utilities/functions.py +6 -6
  19. mindsdb/api/executor/utilities/sql.py +32 -16
  20. mindsdb/api/http/gui.py +5 -11
  21. mindsdb/api/http/initialize.py +8 -10
  22. mindsdb/api/http/namespaces/agents.py +10 -12
  23. mindsdb/api/http/namespaces/analysis.py +13 -20
  24. mindsdb/api/http/namespaces/auth.py +1 -1
  25. mindsdb/api/http/namespaces/config.py +15 -11
  26. mindsdb/api/http/namespaces/databases.py +140 -201
  27. mindsdb/api/http/namespaces/file.py +15 -4
  28. mindsdb/api/http/namespaces/handlers.py +7 -2
  29. mindsdb/api/http/namespaces/knowledge_bases.py +8 -7
  30. mindsdb/api/http/namespaces/models.py +94 -126
  31. mindsdb/api/http/namespaces/projects.py +13 -22
  32. mindsdb/api/http/namespaces/sql.py +33 -25
  33. mindsdb/api/http/namespaces/tab.py +27 -37
  34. mindsdb/api/http/namespaces/views.py +1 -1
  35. mindsdb/api/http/start.py +14 -8
  36. mindsdb/api/mcp/__init__.py +2 -1
  37. mindsdb/api/mysql/mysql_proxy/executor/mysql_executor.py +15 -20
  38. mindsdb/api/mysql/mysql_proxy/mysql_proxy.py +26 -50
  39. mindsdb/api/mysql/mysql_proxy/utilities/__init__.py +0 -1
  40. mindsdb/api/postgres/postgres_proxy/executor/executor.py +6 -13
  41. mindsdb/api/postgres/postgres_proxy/postgres_packets/postgres_packets.py +40 -28
  42. mindsdb/integrations/handlers/byom_handler/byom_handler.py +168 -185
  43. mindsdb/integrations/handlers/file_handler/file_handler.py +7 -0
  44. mindsdb/integrations/handlers/lightwood_handler/functions.py +45 -79
  45. mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +13 -1
  46. mindsdb/integrations/handlers/shopify_handler/shopify_handler.py +25 -12
  47. mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py +2 -1
  48. mindsdb/integrations/handlers/statsforecast_handler/requirements.txt +1 -0
  49. mindsdb/integrations/handlers/statsforecast_handler/requirements_extra.txt +1 -0
  50. mindsdb/integrations/handlers/web_handler/urlcrawl_helpers.py +4 -4
  51. mindsdb/integrations/libs/api_handler.py +10 -10
  52. mindsdb/integrations/libs/base.py +4 -4
  53. mindsdb/integrations/libs/llm/utils.py +2 -2
  54. mindsdb/integrations/libs/ml_handler_process/create_engine_process.py +4 -7
  55. mindsdb/integrations/libs/ml_handler_process/func_call_process.py +2 -7
  56. mindsdb/integrations/libs/ml_handler_process/learn_process.py +37 -47
  57. mindsdb/integrations/libs/ml_handler_process/update_engine_process.py +4 -7
  58. mindsdb/integrations/libs/ml_handler_process/update_process.py +2 -7
  59. mindsdb/integrations/libs/process_cache.py +132 -140
  60. mindsdb/integrations/libs/response.py +18 -12
  61. mindsdb/integrations/libs/vectordatabase_handler.py +26 -0
  62. mindsdb/integrations/utilities/files/file_reader.py +6 -7
  63. mindsdb/integrations/utilities/rag/config_loader.py +37 -26
  64. mindsdb/integrations/utilities/rag/rerankers/base_reranker.py +59 -9
  65. mindsdb/integrations/utilities/rag/rerankers/reranker_compressor.py +4 -4
  66. mindsdb/integrations/utilities/rag/retrievers/sql_retriever.py +55 -133
  67. mindsdb/integrations/utilities/rag/settings.py +58 -133
  68. mindsdb/integrations/utilities/rag/splitters/file_splitter.py +5 -15
  69. mindsdb/interfaces/agents/agents_controller.py +2 -1
  70. mindsdb/interfaces/agents/constants.py +0 -2
  71. mindsdb/interfaces/agents/litellm_server.py +34 -58
  72. mindsdb/interfaces/agents/mcp_client_agent.py +10 -10
  73. mindsdb/interfaces/agents/mindsdb_database_agent.py +5 -5
  74. mindsdb/interfaces/agents/run_mcp_agent.py +12 -21
  75. mindsdb/interfaces/chatbot/chatbot_task.py +20 -23
  76. mindsdb/interfaces/chatbot/polling.py +30 -18
  77. mindsdb/interfaces/data_catalog/data_catalog_loader.py +10 -10
  78. mindsdb/interfaces/database/integrations.py +19 -2
  79. mindsdb/interfaces/file/file_controller.py +6 -6
  80. mindsdb/interfaces/functions/controller.py +1 -1
  81. mindsdb/interfaces/functions/to_markdown.py +2 -2
  82. mindsdb/interfaces/jobs/jobs_controller.py +5 -5
  83. mindsdb/interfaces/jobs/scheduler.py +3 -8
  84. mindsdb/interfaces/knowledge_base/controller.py +50 -23
  85. mindsdb/interfaces/knowledge_base/preprocessing/json_chunker.py +40 -61
  86. mindsdb/interfaces/model/model_controller.py +170 -166
  87. mindsdb/interfaces/query_context/context_controller.py +14 -2
  88. mindsdb/interfaces/skills/custom/text2sql/mindsdb_sql_toolkit.py +6 -4
  89. mindsdb/interfaces/skills/retrieval_tool.py +43 -50
  90. mindsdb/interfaces/skills/skill_tool.py +2 -2
  91. mindsdb/interfaces/skills/sql_agent.py +25 -19
  92. mindsdb/interfaces/storage/fs.py +114 -169
  93. mindsdb/interfaces/storage/json.py +19 -18
  94. mindsdb/interfaces/tabs/tabs_controller.py +49 -72
  95. mindsdb/interfaces/tasks/task_monitor.py +3 -9
  96. mindsdb/interfaces/tasks/task_thread.py +7 -9
  97. mindsdb/interfaces/triggers/trigger_task.py +7 -13
  98. mindsdb/interfaces/triggers/triggers_controller.py +47 -50
  99. mindsdb/migrations/migrate.py +16 -16
  100. mindsdb/utilities/api_status.py +58 -0
  101. mindsdb/utilities/config.py +49 -0
  102. mindsdb/utilities/exception.py +40 -1
  103. mindsdb/utilities/fs.py +0 -1
  104. mindsdb/utilities/hooks/profiling.py +17 -14
  105. mindsdb/utilities/langfuse.py +40 -45
  106. mindsdb/utilities/log.py +272 -0
  107. mindsdb/utilities/ml_task_queue/consumer.py +52 -58
  108. mindsdb/utilities/ml_task_queue/producer.py +26 -30
  109. mindsdb/utilities/render/sqlalchemy_render.py +7 -6
  110. mindsdb/utilities/utils.py +2 -2
  111. {mindsdb-25.9.2.0a1.dist-info → mindsdb-25.9.3rc1.dist-info}/METADATA +269 -264
  112. {mindsdb-25.9.2.0a1.dist-info → mindsdb-25.9.3rc1.dist-info}/RECORD +115 -115
  113. mindsdb/api/mysql/mysql_proxy/utilities/exceptions.py +0 -14
  114. {mindsdb-25.9.2.0a1.dist-info → mindsdb-25.9.3rc1.dist-info}/WHEEL +0 -0
  115. {mindsdb-25.9.2.0a1.dist-info → mindsdb-25.9.3rc1.dist-info}/licenses/LICENSE +0 -0
  116. {mindsdb-25.9.2.0a1.dist-info → mindsdb-25.9.3rc1.dist-info}/top_level.txt +0 -0
mindsdb/integrations/utilities/rag/config_loader.py
@@ -1,17 +1,26 @@
 """Utility functions for RAG pipeline configuration"""
+
 from typing import Dict, Any, Optional
 
 from mindsdb.utilities.log import getLogger
 from mindsdb.integrations.utilities.rag.settings import (
-    RetrieverType, MultiVectorRetrieverMode, SearchType,
-    SearchKwargs, SummarizationConfig, VectorStoreConfig,
-    RerankerConfig, RAGPipelineModel, DEFAULT_COLLECTION_NAME
+    RetrieverType,
+    MultiVectorRetrieverMode,
+    SearchType,
+    SearchKwargs,
+    SummarizationConfig,
+    VectorStoreConfig,
+    RerankerConfig,
+    RAGPipelineModel,
+    DEFAULT_COLLECTION_NAME,
 )
 
 logger = getLogger(__name__)
 
 
-def load_rag_config(base_config: Dict[str, Any], kb_params: Optional[Dict[str, Any]] = None, embedding_model: Any = None) -> RAGPipelineModel:
+def load_rag_config(
+    base_config: Dict[str, Any], kb_params: Optional[Dict[str, Any]] = None, embedding_model: Any = None
+) -> RAGPipelineModel:
     """
     Load and validate RAG configuration parameters. This function handles the conversion of configuration
     parameters into their appropriate types and ensures all required settings are properly configured.
@@ -37,41 +46,43 @@ def load_rag_config(base_config: Dict[str, Any], kb_params: Optional[Dict[str, A
 
     # Set embedding model if provided
     if embedding_model is not None:
-        rag_params['embedding_model'] = embedding_model
+        rag_params["embedding_model"] = embedding_model
 
     # Handle enums and type conversions
-    if 'retriever_type' in rag_params:
-        rag_params['retriever_type'] = RetrieverType(rag_params['retriever_type'])
-    if 'multi_retriever_mode' in rag_params:
-        rag_params['multi_retriever_mode'] = MultiVectorRetrieverMode(rag_params['multi_retriever_mode'])
-    if 'search_type' in rag_params:
-        rag_params['search_type'] = SearchType(rag_params['search_type'])
+    if "retriever_type" in rag_params:
+        rag_params["retriever_type"] = RetrieverType(rag_params["retriever_type"])
+    if "multi_retriever_mode" in rag_params:
+        rag_params["multi_retriever_mode"] = MultiVectorRetrieverMode(rag_params["multi_retriever_mode"])
+    if "search_type" in rag_params:
+        rag_params["search_type"] = SearchType(rag_params["search_type"])
 
     # Handle search kwargs if present
-    if 'search_kwargs' in rag_params and isinstance(rag_params['search_kwargs'], dict):
-        rag_params['search_kwargs'] = SearchKwargs(**rag_params['search_kwargs'])
+    if "search_kwargs" in rag_params and isinstance(rag_params["search_kwargs"], dict):
+        rag_params["search_kwargs"] = SearchKwargs(**rag_params["search_kwargs"])
 
     # Handle summarization config if present
-    summarization_config = rag_params.get('summarization_config')
+    summarization_config = rag_params.get("summarization_config")
     if summarization_config is not None and isinstance(summarization_config, dict):
-        rag_params['summarization_config'] = SummarizationConfig(**summarization_config)
+        rag_params["summarization_config"] = SummarizationConfig(**summarization_config)
 
     # Handle vector store config
-    if 'vector_store_config' in rag_params:
-        if isinstance(rag_params['vector_store_config'], dict):
-            rag_params['vector_store_config'] = VectorStoreConfig(**rag_params['vector_store_config'])
+    if "vector_store_config" in rag_params:
+        if isinstance(rag_params["vector_store_config"], dict):
+            rag_params["vector_store_config"] = VectorStoreConfig(**rag_params["vector_store_config"])
     else:
-        rag_params['vector_store_config'] = {}
-        logger.warning(f'No collection_name specified for the retrieval tool, '
-                       f"using default collection_name: '{DEFAULT_COLLECTION_NAME}'"
-                       f'\nWarning: If this collection does not exist, no data will be retrieved')
+        rag_params["vector_store_config"] = {}
+        logger.warning(
+            f"No collection_name specified for the retrieval tool, "
+            f"using default collection_name: '{DEFAULT_COLLECTION_NAME}'"
+            f"\nWarning: If this collection does not exist, no data will be retrieved"
+        )
 
-    if 'reranker_config' in rag_params:
-        rag_params['reranker_config'] = RerankerConfig(**rag_params['reranker_config'])
+    if "reranker_config" in rag_params:
+        rag_params["reranker_config"] = RerankerConfig(**rag_params["reranker_config"])
 
     # Convert to RAGPipelineModel with validation
     try:
         return RAGPipelineModel(**rag_params)
     except Exception as e:
-        logger.error(f"Invalid RAG configuration: {str(e)}")
-        raise ValueError(f"Configuration validation failed: {str(e)}")
+        logger.exception("Invalid RAG configuration:")
+        raise ValueError(f"Configuration validation failed: {str(e)}") from e
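For orientation, here is a minimal usage sketch of the reworked loader. The key names mirror the hunk above, but the concrete values are illustrative assumptions, not taken from this release.

```python
# Hedged usage sketch for load_rag_config; the values are placeholders.
from mindsdb.integrations.utilities.rag.config_loader import load_rag_config

base_config = {
    "retriever_type": "vector_store",  # string, coerced to RetrieverType(...)
    "search_type": "similarity",       # string, coerced to SearchType(...)
    "search_kwargs": {"k": 5},         # dict, coerced to SearchKwargs(**...)
}

# Returns a validated RAGPipelineModel; after this release a validation
# failure raises ValueError chained to the original exception ("from e").
pipeline_model = load_rag_config(base_config)
```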
mindsdb/integrations/utilities/rag/rerankers/base_reranker.py
@@ -13,7 +13,15 @@ from typing import Any, List, Optional, Tuple
 from openai import AsyncOpenAI, AsyncAzureOpenAI
 from pydantic import BaseModel
 
-from mindsdb.integrations.utilities.rag.settings import DEFAULT_RERANKING_MODEL, DEFAULT_LLM_ENDPOINT
+from mindsdb.integrations.utilities.rag.settings import (
+    DEFAULT_RERANKING_MODEL,
+    DEFAULT_LLM_ENDPOINT,
+    DEFAULT_RERANKER_N,
+    DEFAULT_RERANKER_LOGPROBS,
+    DEFAULT_RERANKER_TOP_LOGPROBS,
+    DEFAULT_RERANKER_MAX_TOKENS,
+    DEFAULT_VALID_CLASS_TOKENS,
+)
 from mindsdb.integrations.libs.base import BaseMLEngine
 
 log = logging.getLogger(__name__)
@@ -38,6 +46,11 @@ class BaseLLMReranker(BaseModel, ABC):
     request_timeout: float = 20.0  # Timeout for API requests
     early_stop: bool = True  # Whether to enable early stopping
     early_stop_threshold: float = 0.8  # Confidence threshold for early stopping
+    n: int = DEFAULT_RERANKER_N  # Number of completions to generate
+    logprobs: bool = DEFAULT_RERANKER_LOGPROBS  # Whether to include log probabilities
+    top_logprobs: int = DEFAULT_RERANKER_TOP_LOGPROBS  # Number of top log probabilities to include
+    max_tokens: int = DEFAULT_RERANKER_MAX_TOKENS  # Maximum tokens to generate
+    valid_class_tokens: List[str] = DEFAULT_VALID_CLASS_TOKENS
 
     class Config:
         arbitrary_types_allowed = True
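These five fields replace values that were hardcoded in `search_relevancy_score` (see the `@@ -306,17 +341,30 @@` hunk below). Since the reranker is a pydantic model, they can now be overridden at construction time; a hedged sketch with placeholder values follows, where only the field names come from this diff.

```python
# Placeholder values; the model name and the chosen numbers are illustrative.
from mindsdb.integrations.utilities.rag.rerankers.reranker_compressor import LLMReranker

reranker = LLMReranker(
    model="gpt-4o",                           # illustrative model name
    top_logprobs=6,                           # hardcoded to 4 before this release
    max_tokens=5,                             # hardcoded to 3 before this release
    valid_class_tokens=["0", "1", "2", "3"],  # tokens accepted as class labels
)
```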
@@ -142,7 +155,7 @@ class BaseLLMReranker(BaseModel, ABC):
                 return ranked_results
         except Exception as e:
             # Don't let early stopping errors stop the whole process
-            log.warning(f"Error in early stopping check: {str(e)}")
+            log.warning(f"Error in early stopping check: {e}")
 
         return ranked_results
 
@@ -234,6 +247,28 @@ class BaseLLMReranker(BaseModel, ABC):
         return rerank_data
 
     async def search_relevancy_score(self, query: str, document: str) -> Any:
+        """
+        This method is used to score the relevance of a document to a query.
+
+        Args:
+            query: The query to score the relevance of.
+            document: The document to score the relevance of.
+
+        Returns:
+            A dictionary with the document and the relevance score.
+        """
+
+        log.debug("Start search_relevancy_score")
+        log.debug(f"Reranker query: {query[:5]}")
+        log.debug(f"Reranker document: {document[:50]}")
+        log.debug(f"Reranker model: {self.model}")
+        log.debug(f"Reranker temperature: {self.temperature}")
+        log.debug(f"Reranker n: {self.n}")
+        log.debug(f"Reranker logprobs: {self.logprobs}")
+        log.debug(f"Reranker top_logprobs: {self.top_logprobs}")
+        log.debug(f"Reranker max_tokens: {self.max_tokens}")
+        log.debug(f"Reranker valid_class_tokens: {self.valid_class_tokens}")
+
         response = await self.client.chat.completions.create(
             model=self.model,
             messages=[
@@ -306,17 +341,30 @@ class BaseLLMReranker(BaseModel, ABC):
                 },
             ],
             temperature=self.temperature,
-            n=1,
-            logprobs=True,
-            top_logprobs=4,
-            max_tokens=3,
+            n=self.n,
+            logprobs=self.logprobs,
+            top_logprobs=self.top_logprobs,
+            max_tokens=self.max_tokens,
         )
 
         # Extract response and logprobs
         token_logprobs = response.choices[0].logprobs.content
-        # Reconstruct the prediction and extract the top logprobs from the final token (e.g., "1")
-        final_token_logprob = token_logprobs[-1]
-        top_logprobs = final_token_logprob.top_logprobs
+
+        # Find the token that contains the class number
+        # Instead of just taking the last token, search for the actual class number token
+        class_token_logprob = None
+        for token_logprob in reversed(token_logprobs):
+            if token_logprob.token in self.valid_class_tokens:
+                class_token_logprob = token_logprob
+                break
+
+        # If we couldn't find a class token, fall back to the last non-empty token
+        if class_token_logprob is None:
+            log.warning("No class token logprob found, using the last token as fallback")
+            class_token_logprob = token_logprobs[-1]
+
+        top_logprobs = class_token_logprob.top_logprobs
+
         # Create a map of 'class_1' -> probability, using token combinations
         class_probs = {}
         for top_token in top_logprobs:
@@ -337,6 +385,8 @@ class BaseLLMReranker(BaseModel, ABC):
             score = 0.0
 
         rerank_data = {"document": document, "relevance_score": score}
+        log.debug(f"Reranker score: {score}")
+        log.debug("End search_relevancy_score")
         return rerank_data
 
     def get_scores(self, query: str, documents: list[str]):
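The behavioral change in the `@@ -306,17 +341,30 @@` hunk is the backward token scan: rather than assuming the class label is the last generated token, the reranker now walks the logprobs from the end until it finds a token in `valid_class_tokens`. A self-contained sketch of that technique; the dataclass is a hypothetical stand-in for the OpenAI token-logprob objects.

```python
# Standalone sketch of the backward class-token scan; TokenLogprob stands in
# for the objects in response.choices[0].logprobs.content.
from dataclasses import dataclass, field
from typing import List


@dataclass
class TokenLogprob:
    token: str
    top_logprobs: List = field(default_factory=list)


def find_class_token(token_logprobs: List[TokenLogprob], valid: set) -> TokenLogprob:
    # Walk from the end so trailing whitespace or punctuation tokens are skipped.
    for tl in reversed(token_logprobs):
        if tl.token in valid:
            return tl
    # Fallback mirrors the diff: use the last token if no class token is found.
    return token_logprobs[-1]


tokens = [TokenLogprob("class_"), TokenLogprob("2"), TokenLogprob("\n")]
assert find_class_token(tokens, {"0", "1", "2", "3"}).token == "2"
```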
mindsdb/integrations/utilities/rag/rerankers/reranker_compressor.py
@@ -36,7 +36,7 @@ class LLMReranker(BaseDocumentCompressor, BaseLLMReranker):
             return []
 
         # Stream reranking update.
-        dispatch_custom_event('rerank_begin', {'num_documents': len(documents)})
+        dispatch_custom_event("rerank_begin", {"num_documents": len(documents)})
 
         try:
             # Prepare query-document pairs
@@ -73,10 +73,10 @@ class LLMReranker(BaseDocumentCompressor, BaseLLMReranker):
             return filtered_docs
 
         except Exception as e:
-            error_msg = f"Error during async document compression: {str(e)}"
-            log.error(error_msg)
+            error_msg = "Error during async document compression:"
+            log.exception(error_msg)
             if callbacks:
-                await callbacks.on_retriever_error(error_msg)
+                await callbacks.on_retriever_error(f"{error_msg} {e}")
             return documents  # Return original documents on error
 
     def compress_documents(
mindsdb/integrations/utilities/rag/retrievers/sql_retriever.py
@@ -1,10 +1,10 @@
 import re
-
-from pydantic import BaseModel, Field
-from typing import List, Any, Optional, Dict, Tuple, Union, Callable
-import collections
 import math
+import logging
+import collections
+from typing import List, Any, Optional, Dict, Tuple, Union, Callable
 
+from pydantic import BaseModel, Field
 from langchain.chains.llm import LLMChain
 from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
 from langchain_core.documents.base import Document
@@ -39,9 +39,7 @@ class MetadataFilter(BaseModel):
     """Represents an LLM generated metadata filter to apply to a PostgreSQL query."""
 
     attribute: str = Field(description="Database column to apply filter to")
-    comparator: str = Field(
-        description="PostgreSQL comparator to use to filter database column"
-    )
+    comparator: str = Field(description="PostgreSQL comparator to use to filter database column")
     value: Any = Field(description="Value to use to filter database column")
 
 
@@ -56,9 +54,7 @@ class AblativeMetadataFilter(MetadataFilter):
 class MetadataFilters(BaseModel):
     """List of LLM generated metadata filters to apply to a PostgreSQL query."""
 
-    filters: List[MetadataFilter] = Field(
-        description="List of PostgreSQL metadata filters to apply for user query"
-    )
+    filters: List[MetadataFilter] = Field(description="List of PostgreSQL metadata filters to apply for user query")
 
 
 class SQLRetriever(BaseRetriever):
@@ -142,25 +138,17 @@ class SQLRetriever(BaseRetriever):
         elif isinstance(schema, ColumnSchema):
             collection_key = "values"
         else:
-            raise Exception(
-                "schema must be either a DatabaseSchema, TableSchema, or ColumnSchema."
-            )
+            raise Exception("schema must be either a DatabaseSchema, TableSchema, or ColumnSchema.")
 
         if update is not None:
-            ordered = collections.OrderedDict(
-                sorted(update.items(), key=key, reverse=True)
-            )
+            ordered = collections.OrderedDict(sorted(update.items(), key=key, reverse=True))
         else:
-            ordered = collections.OrderedDict(
-                sorted(getattr(schema, collection_key).items(), key=key, reverse=True)
-            )
+            ordered = collections.OrderedDict(sorted(getattr(schema, collection_key).items(), key=key, reverse=True))
         schema = schema.model_copy(update={collection_key: ordered})
 
         return schema
 
-    def _sort_database_schema_by_key(
-        self, database_schema: DatabaseSchema, key: Callable
-    ) -> DatabaseSchema:
+    def _sort_database_schema_by_key(self, database_schema: DatabaseSchema, key: Callable) -> DatabaseSchema:
         """Re-build schema with OrderedDicts"""
         tables = {}
         # build new tables dict
@@ -169,17 +157,11 @@ class SQLRetriever(BaseRetriever):
             # build new column dict
             for column_key, column_schema in table_schema.columns.items():
                 # sort values directly and update column schema
-                columns[column_key] = self._sort_schema_by_key(
-                    schema=column_schema, key=key
-                )
+                columns[column_key] = self._sort_schema_by_key(schema=column_schema, key=key)
             # update table schema and sort
-            tables[table_key] = self._sort_schema_by_key(
-                schema=table_schema, key=key, update=columns
-            )
+            tables[table_key] = self._sort_schema_by_key(schema=table_schema, key=key, update=columns)
         # update table schema and sort
-        database_schema = self._sort_schema_by_key(
-            schema=database_schema, key=key, update=tables
-        )
+        database_schema = self._sort_schema_by_key(schema=database_schema, key=key, update=tables)
 
         return database_schema
 
@@ -191,15 +173,12 @@ class SQLRetriever(BaseRetriever):
         boolean_system_prompt: bool = True,
         format_instructions: Optional[str] = None,
     ) -> ChatPromptTemplate:
-
         if boolean_system_prompt is True:
             system_prompt = self.boolean_system_prompt
         else:
             system_prompt = self.generative_system_prompt
 
-        prepared_column_prompt = self._prepare_column_prompt(
-            column_schema=column_schema, table_schema=table_schema
-        )
+        prepared_column_prompt = self._prepare_column_prompt(column_schema=column_schema, table_schema=table_schema)
         column_schema_str = (
             prepared_column_prompt.messages[1]
             .format(
@@ -290,7 +269,6 @@ Below is a list of comparison operators for constructing filters for this value
         table_schema: TableSchema,
         boolean_system_prompt: bool = True,
     ) -> ChatPromptTemplate:
-
         if boolean_system_prompt is True:
             system_prompt = self.boolean_system_prompt
         else:
@@ -312,9 +290,7 @@ Below is a list of comparison operators for constructing filters for this value
             [("system", system_prompt), ("user", self.column_prompt_template)]
         )
 
-        header_str = (
-            f"This schema describes a column in the {table_schema.table} table."
-        )
+        header_str = f"This schema describes a column in the {table_schema.table} table."
 
         value_str = """
 ## **Content**
@@ -388,26 +364,18 @@ Below is a description of the contents in this column in list format:
         )
 
     def _rank_schema(self, prompt: ChatPromptTemplate, query: str) -> float:
-        rank_chain = LLMChain(
-            llm=self.llm.bind(logprobs=True), prompt=prompt, return_final_only=False
-        )
+        rank_chain = LLMChain(llm=self.llm.bind(logprobs=True), prompt=prompt, return_final_only=False)
         output = rank_chain({"query": query})  # returns metadata
 
         # parse through metadata tokens until encountering either yes, or no.
         score = None  # a None score indicates the model output could not be parsed.
-        for content in output["full_generation"][0].message.response_metadata[
-            "logprobs"
-        ]["content"]:
+        for content in output["full_generation"][0].message.response_metadata["logprobs"]["content"]:
             # Convert answer to score using the model's confidence
             if content["token"].lower().strip() == "yes":
-                score = (
-                    1 + math.exp(content["logprob"])
-                ) / 2  # If yes, use the model's confidence
+                score = (1 + math.exp(content["logprob"])) / 2  # If yes, use the model's confidence
                 break
             elif content["token"].lower().strip() == "no":
-                score = (
-                    1 - math.exp(content["logprob"])
-                ) / 2  # If no, invert the confidence
+                score = (1 - math.exp(content["logprob"])) / 2  # If no, invert the confidence
                 break
 
         if score is None:
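The `(1 ± math.exp(logprob)) / 2` mapping in `_rank_schema` is worth a note: `exp(logprob)` recovers the token probability p, so a "yes" lands in (0.5, 1] and a "no" in [0, 0.5). A standalone sketch of the same mapping:

```python
import math


def logprob_to_score(token: str, logprob: float):
    """Map a yes/no token plus its logprob to a relevance score in [0, 1].

    exp(logprob) is the model's probability p for the emitted token, so
    "yes" maps to (1 + p) / 2 in (0.5, 1] and "no" to (1 - p) / 2 in [0, 0.5).
    """
    p = math.exp(logprob)
    answer = token.lower().strip()
    if answer == "yes":
        return (1 + p) / 2
    if answer == "no":
        return (1 - p) / 2
    return None  # unparseable output, mirroring the None sentinel in the diff


# A confident "yes" (p = 0.9) scores 0.95; a confident "no" scores 0.05.
assert round(logprob_to_score("yes", math.log(0.9)), 6) == 0.95
assert round(logprob_to_score("no", math.log(0.9)), 6) == 0.05
```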
@@ -465,9 +433,7 @@ Below is a description of the contents in this column in list format:
                     table_schema=table_schema,
                     boolean_system_prompt=True,
                 )
-                column_schema.relevance = self._rank_schema(
-                    prompt=prompt, query=query
-                )
+                column_schema.relevance = self._rank_schema(prompt=prompt, query=query)
 
             columns[column_key] = column_schema
 
@@ -512,9 +478,7 @@ Below is a description of the contents in this column in list format:
                         table_schema=table_schema,
                         boolean_system_prompt=True,
                     )
-                    value_schema.relevance = self._rank_schema(
-                        prompt=prompt, query=query
-                    )
+                    value_schema.relevance = self._rank_schema(prompt=prompt, query=query)
 
                 values[value_key] = value_schema
 
@@ -592,19 +556,13 @@ Below is a description of the contents in this column in list format:
         for table_key, table_schema in ordered_database_schema.tables.items():
             for column_key, column_schema in table_schema.columns.items():
                 for value_key, value_schema in column_schema.values.items():
-                    ablation_value_dict[(table_key, column_key, value_key)] = (
-                        value_schema.relevance
-                    )
+                    ablation_value_dict[(table_key, column_key, value_key)] = value_schema.relevance
 
-        ablation_value_dict = collections.OrderedDict(
-            sorted(ablation_value_dict.items(), key=lambda x: x[1])
-        )
+        ablation_value_dict = collections.OrderedDict(sorted(ablation_value_dict.items(), key=lambda x: x[1]))
 
         relevance_scores = list(ablation_value_dict.values())
         if len(relevance_scores) > 0:
-            ablation_quantiles = np.quantile(
-                relevance_scores, np.linspace(0, 1, self.num_retries + 2)[1:-1]
-            )
+            ablation_quantiles = np.quantile(relevance_scores, np.linspace(0, 1, self.num_retries + 2)[1:-1])
         else:
             ablation_quantiles = None
 
@@ -628,11 +586,7 @@ Below is a description of the contents in this column in list format:
         ablated_filters = []
         for filter in metadata_filters:
             for key in ablated_dict.keys():
-                if (
-                    filter.schema_table in key
-                    and filter.schema_column in key
-                    and filter.schema_value in key
-                ):
+                if filter.schema_table in key and filter.schema_column in key and filter.schema_value in key:
                     ablated_filters.append(filter)
 
         return ablated_filters
@@ -646,9 +600,7 @@ Below is a description of the contents in this column in list format:
         pass
 
     def _prepare_retrieval_query(self, query: str) -> str:
-        rewrite_prompt = PromptTemplate(
-            input_variables=["input"], template=self.rewrite_prompt_template
-        )
+        rewrite_prompt = PromptTemplate(input_variables=["input"], template=self.rewrite_prompt_template)
         rewrite_chain = LLMChain(llm=self.llm, prompt=rewrite_prompt)
         return rewrite_chain.predict(input=query)
 
@@ -668,9 +620,7 @@ Below is a description of the contents in this column in list format:
         # Add Table JOIN statements
         join_clauses = set()
         for metadata_filter in metadata_filters:
-            join_clause = ranked_database_schema.tables[
-                metadata_filter.schema_table
-            ].join
+            join_clause = ranked_database_schema.tables[metadata_filter.schema_table].join
             if join_clause in join_clauses:
                 continue
             else:
@@ -688,12 +638,12 @@ Below is a description of the contents in this column in list format:
             if i < len(metadata_filters) - 1:
                 base_query += " AND "
 
-        base_query += f" ORDER BY e.embeddings {self.distance_function.value[0]} '{{embeddings}}' LIMIT {self.search_kwargs.k};"
+        base_query += (
+            f" ORDER BY e.embeddings {self.distance_function.value[0]} '{{embeddings}}' LIMIT {self.search_kwargs.k};"
+        )
         return base_query
 
-    def _generate_filter(
-        self, prompt: ChatPromptTemplate, query: str
-    ) -> MetadataFilter:
+    def _generate_filter(self, prompt: ChatPromptTemplate, query: str) -> MetadataFilter:
         gen_filter_chain = LLMChain(llm=self.llm, prompt=prompt)
         output = gen_filter_chain({"query": query})
         return output
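The wrapped string above builds the pgvector tail of the similarity query. A small sketch of what it produces, assuming `<->` (pgvector's L2 distance operator) and k=5 as illustrative stand-ins for `self.distance_function` and `self.search_kwargs.k`:

```python
# Illustrative stand-ins: "<->" and k=5 do not come from this release.
distance_op = "<->"
k = 5

tail = f" ORDER BY e.embeddings {distance_op} '{{embeddings}}' LIMIT {k};"
print(tail)
# -> " ORDER BY e.embeddings <-> '{embeddings}' LIMIT 5;"
# The '{embeddings}' placeholder is filled in later via str.format with the
# embedded query vector, as the _similarity_search hunk further down shows.
```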
@@ -714,28 +664,22 @@ Below is a description of the contents in this column in list format:
                 # must use generation if field is a dictionary of tuples or a list
                 if type(value_schema.value) in [list, dict]:
                     try:
-                        metadata_prompt: ChatPromptTemplate = (
-                            self._prepare_value_prompt(
-                                format_instructions=parser.get_format_instructions(),
-                                value_schema=value_schema,
-                                column_schema=column_schema,
-                                table_schema=table_schema,
-                                boolean_system_prompt=False,
-                            )
+                        metadata_prompt: ChatPromptTemplate = self._prepare_value_prompt(
+                            format_instructions=parser.get_format_instructions(),
+                            value_schema=value_schema,
+                            column_schema=column_schema,
+                            table_schema=table_schema,
+                            boolean_system_prompt=False,
                         )
 
-                        metadata_filters_chain = LLMChain(
-                            llm=self.llm, prompt=metadata_prompt
-                        )
+                        metadata_filters_chain = LLMChain(llm=self.llm, prompt=metadata_prompt)
                         metadata_filter_output = metadata_filters_chain.predict(
                             query=query,
                         )
 
                         # If the LLM outputs raw JSON, use it as-is.
                         # If the LLM outputs anything including a json markdown section, use the last one.
-                        json_markdown_output = re.findall(
-                            r"```json.*```", metadata_filter_output, re.DOTALL
-                        )
+                        json_markdown_output = re.findall(r"```json.*```", metadata_filter_output, re.DOTALL)
                         if json_markdown_output:
                             metadata_filter_output = json_markdown_output[-1]
                             # Clean the json tags.
@@ -754,11 +698,10 @@ Below is a description of the contents in this column in list format:
                         metadata_filter = AblativeMetadataFilter(**model_dump)
                     except OutputParserException as e:
                         logger.warning(
-                            f"LLM failed to generate structured metadata filters: {str(e)}"
-                        )
-                        return HandlerResponse(
-                            RESPONSE_TYPE.ERROR, error_message=str(e)
+                            f"LLM failed to generate structured metadata filters: {e}",
+                            exc_info=logger.isEnabledFor(logging.DEBUG),
                         )
+                        return HandlerResponse(RESPONSE_TYPE.ERROR, error_message=str(e))
                 else:
                     metadata_filter = AblativeMetadataFilter(
                         attribute=column_schema.column,
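This hunk and the next one adopt the same logging pattern: the warning message always logs, but the full traceback is attached only when the logger is at DEBUG level. A minimal, self-contained reproduction:

```python
# Minimal reproduction of the debug-gated traceback pattern; at INFO level
# exc_info evaluates to False, so no traceback is emitted.
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

try:
    raise RuntimeError("boom")
except Exception as e:
    logger.warning(
        f"Something failed: {e}",
        exc_info=logger.isEnabledFor(logging.DEBUG),
    )
```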
@@ -779,24 +722,17 @@ Below is a description of the contents in this column in list format:
         embeddings_str: str,
     ) -> HandlerResponse:
         try:
-            checked_sql_query = self._prepare_pgvector_query(
-                ranked_database_schema, metadata_filters
-            )
-            checked_sql_query_with_embeddings = checked_sql_query.format(
-                embeddings=embeddings_str
-            )
-            return self.vector_store_handler.native_query(
-                checked_sql_query_with_embeddings
-            )
+            checked_sql_query = self._prepare_pgvector_query(ranked_database_schema, metadata_filters)
+            checked_sql_query_with_embeddings = checked_sql_query.format(embeddings=embeddings_str)
+            return self.vector_store_handler.native_query(checked_sql_query_with_embeddings)
         except Exception as e:
             logger.warning(
-                f"Failed to prepare and execute SQL query from structured metadata: {str(e)}"
+                f"Failed to prepare and execute SQL query from structured metadata: {e}",
+                exc_info=logger.isEnabledFor(logging.DEBUG),
             )
             return HandlerResponse(RESPONSE_TYPE.ERROR, error_message=str(e))
 
-    def _get_relevant_documents(
-        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
-    ) -> List[Document]:
+    def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]:
         # Rewrite query to be suitable for retrieval.
         retrieval_query = self._prepare_retrieval_query(query)
 
@@ -804,14 +740,10 @@ Below is a description of the contents in this column in list format:
         embedded_query = self.embeddings_model.embed_query(retrieval_query)
 
         # Search for relevant filters
-        ranked_database_schema, ablation_value_dict, ablation_quantiles = (
-            self._breadth_first_search(query=query)
-        )
+        ranked_database_schema, ablation_value_dict, ablation_quantiles = self._breadth_first_search(query=query)
 
         # Generate metadata filters
-        metadata_filters = self._generate_metadata_filters(
-            query=query, ranked_database_schema=ranked_database_schema
-        )
+        metadata_filters = self._generate_metadata_filters(query=query, ranked_database_schema=ranked_database_schema)
 
         if type(metadata_filters) is list:
             # Initial Execution of the similarity search with metadata filters.
@@ -830,9 +762,7 @@ Below is a description of the contents in this column in list format:
                     break
                 elif document_response.resp_type == RESPONSE_TYPE.ERROR:
                     # LLMs won't always generate structured metadata so we should have a fallback after retrying.
-                    logger.info(
-                        f"SQL Retriever query failed with error {document_response.error_message}"
-                    )
+                    logger.info(f"SQL Retriever query failed with error {document_response.error_message}")
                 else:
                     logger.info(
                         f"SQL Retriever did not retrieve {self.min_k} documents: {len(document_response.data_frame)} documents retrieved."
@@ -867,17 +797,9 @@ Below is a description of the contents in this column in list format:
                 return retrieved_documents
 
             # If the SQL query constructed did not return any documents, fallback.
-            logger.info(
-                "No documents returned from SQL retriever, using fallback retriever."
-            )
-            return self.fallback_retriever._get_relevant_documents(
-                retrieval_query, run_manager=run_manager
-            )
+            logger.info("No documents returned from SQL retriever, using fallback retriever.")
+            return self.fallback_retriever._get_relevant_documents(retrieval_query, run_manager=run_manager)
         else:
             # If no metadata fields could be generated fallback.
-            logger.info(
-                "No metadata fields were successfully generated, using fallback retriever."
-            )
-            return self.fallback_retriever._get_relevant_documents(
-                retrieval_query, run_manager=run_manager
-            )
+            logger.info("No metadata fields were successfully generated, using fallback retriever.")
+            return self.fallback_retriever._get_relevant_documents(retrieval_query, run_manager=run_manager)