MindsDB 25.2.3.0__py3-none-any.whl → 25.3.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of MindsDB might be problematic. Click here for more details.

Files changed (86)
  1. mindsdb/__about__.py +1 -1
  2. mindsdb/__main__.py +16 -11
  3. mindsdb/api/executor/command_executor.py +1 -1
  4. mindsdb/api/executor/datahub/datanodes/system_tables.py +10 -2
  5. mindsdb/api/executor/planner/query_planner.py +6 -2
  6. mindsdb/api/executor/sql_query/steps/prepare_steps.py +2 -1
  7. mindsdb/api/http/initialize.py +8 -5
  8. mindsdb/api/http/namespaces/agents.py +0 -7
  9. mindsdb/api/http/namespaces/config.py +0 -48
  10. mindsdb/api/http/namespaces/knowledge_bases.py +1 -1
  11. mindsdb/api/http/namespaces/util.py +0 -28
  12. mindsdb/api/mongo/classes/query_sql.py +2 -1
  13. mindsdb/api/mongo/responders/aggregate.py +2 -2
  14. mindsdb/api/mongo/responders/coll_stats.py +3 -2
  15. mindsdb/api/mongo/responders/db_stats.py +2 -1
  16. mindsdb/api/mongo/responders/insert.py +4 -2
  17. mindsdb/api/mysql/mysql_proxy/classes/fake_mysql_proxy/fake_mysql_proxy.py +2 -1
  18. mindsdb/api/mysql/mysql_proxy/mysql_proxy.py +5 -4
  19. mindsdb/api/postgres/postgres_proxy/postgres_proxy.py +2 -4
  20. mindsdb/integrations/handlers/anyscale_endpoints_handler/requirements.txt +0 -1
  21. mindsdb/integrations/handlers/autosklearn_handler/autosklearn_handler.py +1 -1
  22. mindsdb/integrations/handlers/dspy_handler/requirements.txt +0 -1
  23. mindsdb/integrations/handlers/gmail_handler/connection_args.py +2 -2
  24. mindsdb/integrations/handlers/gmail_handler/gmail_handler.py +19 -66
  25. mindsdb/integrations/handlers/gmail_handler/requirements.txt +0 -1
  26. mindsdb/integrations/handlers/google_calendar_handler/connection_args.py +15 -0
  27. mindsdb/integrations/handlers/google_calendar_handler/google_calendar_handler.py +31 -41
  28. mindsdb/integrations/handlers/google_calendar_handler/requirements.txt +0 -2
  29. mindsdb/integrations/handlers/langchain_embedding_handler/requirements.txt +0 -1
  30. mindsdb/integrations/handlers/langchain_handler/requirements.txt +0 -1
  31. mindsdb/integrations/handlers/llama_index_handler/requirements.txt +0 -1
  32. mindsdb/integrations/handlers/openai_handler/constants.py +3 -1
  33. mindsdb/integrations/handlers/openai_handler/requirements.txt +0 -1
  34. mindsdb/integrations/handlers/rag_handler/requirements.txt +0 -1
  35. mindsdb/integrations/handlers/ray_serve_handler/ray_serve_handler.py +33 -8
  36. mindsdb/integrations/handlers/web_handler/urlcrawl_helpers.py +3 -2
  37. mindsdb/integrations/handlers/web_handler/web_handler.py +42 -33
  38. mindsdb/integrations/handlers/youtube_handler/__init__.py +2 -0
  39. mindsdb/integrations/handlers/youtube_handler/connection_args.py +32 -0
  40. mindsdb/integrations/handlers/youtube_handler/youtube_handler.py +2 -38
  41. mindsdb/integrations/libs/llm/utils.py +7 -1
  42. mindsdb/integrations/libs/process_cache.py +2 -2
  43. mindsdb/integrations/utilities/handlers/auth_utilities/google/google_user_oauth_utilities.py +29 -38
  44. mindsdb/integrations/utilities/pydantic_utils.py +208 -0
  45. mindsdb/integrations/utilities/rag/chains/local_context_summarizer_chain.py +227 -0
  46. mindsdb/integrations/utilities/rag/pipelines/rag.py +11 -4
  47. mindsdb/integrations/utilities/rag/retrievers/sql_retriever.py +800 -135
  48. mindsdb/integrations/utilities/rag/settings.py +390 -152
  49. mindsdb/integrations/utilities/sql_utils.py +2 -1
  50. mindsdb/interfaces/agents/agents_controller.py +14 -10
  51. mindsdb/interfaces/agents/callback_handlers.py +52 -5
  52. mindsdb/interfaces/agents/langchain_agent.py +5 -3
  53. mindsdb/interfaces/agents/mindsdb_chat_model.py +4 -2
  54. mindsdb/interfaces/chatbot/chatbot_controller.py +9 -8
  55. mindsdb/interfaces/database/database.py +3 -2
  56. mindsdb/interfaces/database/integrations.py +1 -1
  57. mindsdb/interfaces/database/projects.py +28 -2
  58. mindsdb/interfaces/jobs/jobs_controller.py +4 -1
  59. mindsdb/interfaces/jobs/scheduler.py +1 -1
  60. mindsdb/interfaces/knowledge_base/preprocessing/constants.py +2 -2
  61. mindsdb/interfaces/model/model_controller.py +5 -2
  62. mindsdb/interfaces/skills/retrieval_tool.py +128 -39
  63. mindsdb/interfaces/skills/skill_tool.py +7 -7
  64. mindsdb/interfaces/skills/skills_controller.py +10 -6
  65. mindsdb/interfaces/skills/sql_agent.py +6 -1
  66. mindsdb/interfaces/storage/db.py +14 -12
  67. mindsdb/interfaces/storage/json.py +59 -0
  68. mindsdb/interfaces/storage/model_fs.py +85 -3
  69. mindsdb/interfaces/triggers/triggers_controller.py +2 -1
  70. mindsdb/migrations/versions/2022-10-14_43c52d23845a_projects.py +17 -3
  71. mindsdb/migrations/versions/2025-02-10_6ab9903fc59a_del_log_table.py +33 -0
  72. mindsdb/migrations/versions/2025-02-14_4521dafe89ab_added_encrypted_content_to_json_storage.py +29 -0
  73. mindsdb/migrations/versions/2025-02-19_11347c213b36_added_metadata_to_projects.py +41 -0
  74. mindsdb/utilities/config.py +6 -1
  75. mindsdb/utilities/functions.py +11 -0
  76. mindsdb/utilities/log.py +17 -2
  77. mindsdb/utilities/ml_task_queue/consumer.py +4 -2
  78. mindsdb/utilities/render/sqlalchemy_render.py +4 -0
  79. {MindsDB-25.2.3.0.dist-info → mindsdb-25.3.1.0.dist-info}/METADATA +226 -247
  80. {MindsDB-25.2.3.0.dist-info → mindsdb-25.3.1.0.dist-info}/RECORD +83 -80
  81. {MindsDB-25.2.3.0.dist-info → mindsdb-25.3.1.0.dist-info}/WHEEL +1 -1
  82. mindsdb/integrations/handlers/gmail_handler/utils.py +0 -45
  83. mindsdb/utilities/log_controller.py +0 -39
  84. mindsdb/utilities/telemetry.py +0 -44
  85. {MindsDB-25.2.3.0.dist-info → mindsdb-25.3.1.0.dist-info}/LICENSE +0 -0
  86. {MindsDB-25.2.3.0.dist-info → mindsdb-25.3.1.0.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,9 @@
1
- import json
2
1
  import re
2
+
3
3
  from pydantic import BaseModel, Field
4
- from typing import Any, List, Optional
4
+ from typing import List, Any, Optional, Dict, Tuple, Union, Callable
5
+ import collections
6
+ import math
5
7
 
6
8
  from langchain.chains.llm import LLMChain
7
9
  from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
@@ -10,32 +12,57 @@ from langchain_core.embeddings import Embeddings
10
12
  from langchain_core.exceptions import OutputParserException
11
13
  from langchain_core.language_models.chat_models import BaseChatModel
12
14
  from langchain_core.output_parsers import PydanticOutputParser
13
- from langchain_core.prompts import PromptTemplate
15
+ from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
14
16
  from langchain_core.retrievers import BaseRetriever
15
17
 
16
18
  from mindsdb.api.executor.data_types.response_type import RESPONSE_TYPE
17
19
  from mindsdb.integrations.libs.response import HandlerResponse
18
- from mindsdb.integrations.libs.vectordatabase_handler import DistanceFunction, VectorStoreHandler
19
- from mindsdb.integrations.utilities.rag.settings import LLMExample, MetadataSchema, SearchKwargs
20
+ from mindsdb.integrations.libs.vectordatabase_handler import (
21
+ DistanceFunction,
22
+ VectorStoreHandler,
23
+ )
24
+ from mindsdb.integrations.utilities.rag.settings import (
25
+ DatabaseSchema,
26
+ TableSchema,
27
+ ColumnSchema,
28
+ ValueSchema,
29
+ SearchKwargs,
30
+ )
20
31
  from mindsdb.utilities import log
21
32
 
33
+ import numpy as np
34
+
22
35
  logger = log.getLogger(__name__)
23
36
 
24
37
 
25
38
  class MetadataFilter(BaseModel):
26
- '''Represents an LLM generated metadata filter to apply to a PostgreSQL query.'''
39
+ """Represents an LLM generated metadata filter to apply to a PostgreSQL query."""
40
+
27
41
  attribute: str = Field(description="Database column to apply filter to")
28
- comparator: str = Field(description="PostgreSQL comparator to use to filter database column")
42
+ comparator: str = Field(
43
+ description="PostgreSQL comparator to use to filter database column"
44
+ )
29
45
  value: Any = Field(description="Value to use to filter database column")
30
46
 
31
47
 
48
+ class AblativeMetadataFilter(MetadataFilter):
49
+ """Adds additional fields to support ablation."""
50
+
51
+ schema_table: str = Field(description="schema name of the table for this filter")
52
+ schema_column: str = Field(description="schema name of the column for this filter")
53
+ schema_value: str = Field(description="schema name of the value for this filter")
54
+
55
+
32
56
  class MetadataFilters(BaseModel):
33
- '''List of LLM generated metadata filters to apply to a PostgreSQL query.'''
34
- filters: List[MetadataFilter] = Field(description="List of PostgreSQL metadata filters to apply for user query")
57
+ """List of LLM generated metadata filters to apply to a PostgreSQL query."""
58
+
59
+ filters: List[MetadataFilter] = Field(
60
+ description="List of PostgreSQL metadata filters to apply for user query"
61
+ )
35
62
 
36
63
 
37
64
  class SQLRetriever(BaseRetriever):
38
- '''Retriever that uses a LLM to generate pgvector queries to do similarity search with metadata filters.
65
+ """Retriever that uses a LLM to generate pgvector queries to do similarity search with metadata filters.
39
66
 
40
67
  How it works:
41
68
 
@@ -48,134 +75,723 @@ class SQLRetriever(BaseRetriever):
48
75
  3. Generate a prepared PostgreSQL query from the structured metadata filters.
49
76
 
50
77
  4. Actually execute the query against our vector database to retrieve documents & return them.
51
- '''
78
+ """
79
+
52
80
  fallback_retriever: BaseRetriever
53
81
  vector_store_handler: VectorStoreHandler
54
- metadata_schemas: Optional[List[MetadataSchema]] = None
55
- examples: Optional[List[LLMExample]] = None
82
+ # search parameters
83
+ max_filters: int
84
+ filter_threshold: float
85
+ min_k: int
56
86
 
57
- rewrite_prompt_template: str
58
- metadata_filters_prompt_template: str
87
+ # Schema description
88
+ database_schema: Optional[DatabaseSchema] = None
89
+
90
+ # Embeddings
59
91
  embeddings_model: Embeddings
92
+ search_kwargs: SearchKwargs
93
+
94
+ # prompt templates
95
+ rewrite_prompt_template: str
96
+
97
+ # schema templates
98
+ table_prompt_template: str
99
+ column_prompt_template: str
100
+ value_prompt_template: str
101
+
102
+ # formatting templates
103
+ boolean_system_prompt: str
104
+ generative_system_prompt: str
105
+
106
+ # SQL search config
60
107
  num_retries: int
61
108
  embeddings_table: str
62
109
  source_table: str
63
- source_id_column: str = 'Id'
110
+ source_id_column: str
64
111
  distance_function: DistanceFunction
65
- search_kwargs: SearchKwargs
66
112
 
113
+ # Re-rank and metadata generation model.
67
114
  llm: BaseChatModel
68
115
 
69
- def _prepare_metadata_prompt(self) -> PromptTemplate:
70
- base_prompt_template = PromptTemplate(
71
- input_variables=['format_instructions', 'schema', 'examples', 'input', 'embeddings'],
72
- template=self.metadata_filters_prompt_template
73
- )
74
- schema_prompt_str = ''
75
- if self.metadata_schemas is not None:
76
- for i, schema in enumerate(self.metadata_schemas):
77
- column_mapping = {}
78
- for column in schema.columns:
79
- column_mapping[column.name] = {
80
- 'type': column.type,
81
- 'description': column.description,
82
- }
83
- if column.values is not None:
84
- column_mapping[column.name]['values'] = column.values
85
- column_mapping_json_str = json.dumps(column_mapping, indent=4)
86
- schema_str = f'''{i+1}. {schema.table} - {schema.description}
87
-
88
- Columns:
89
- ```json
90
- {column_mapping_json_str}
91
- ```
92
-
93
- '''
94
- schema_prompt_str += schema_str
95
-
96
- examples_prompt_str = ''
97
- if self.examples is not None:
98
- for i, example in enumerate(self.examples):
99
- example_str = f'''{i + 1}. User input: "{example.input}"
100
-
101
- Output:
102
- {example.output}
103
-
104
- '''
105
- examples_prompt_str += example_str
116
+ def _sort_schema_by_priority_key(
117
+ self,
118
+ schema_dict_item: Tuple[str, Union[TableSchema, ColumnSchema, ValueSchema]],
119
+ ):
120
+ return schema_dict_item[1].priority
121
+
122
+ def _sort_schema_by_relevance_key(
123
+ self,
124
+ schema_dict_item: Tuple[str, Union[TableSchema, ColumnSchema, ValueSchema]],
125
+ ):
126
+ if schema_dict_item[1].relevance is not None:
127
+ return schema_dict_item[1].relevance
128
+ else:
129
+ return 0
130
+
131
+ def _sort_schema_by_key(
132
+ self,
133
+ schema: Union[DatabaseSchema, TableSchema, ColumnSchema],
134
+ key: Callable,
135
+ update: Dict[str, Any] = None,
136
+ ) -> Union[DatabaseSchema, TableSchema, ColumnSchema]:
137
+ """Takes a schema and converts its dict into an OrderedDict"""
138
+ if isinstance(schema, DatabaseSchema):
139
+ collection_key = "tables"
140
+ elif isinstance(schema, TableSchema):
141
+ collection_key = "columns"
142
+ elif isinstance(schema, ColumnSchema):
143
+ collection_key = "values"
144
+ else:
145
+ raise Exception(
146
+ "schema must be either a DatabaseSchema, TableSchema, or ColumnSchema."
147
+ )
148
+
149
+ if update is not None:
150
+ ordered = collections.OrderedDict(
151
+ sorted(update.items(), key=key, reverse=True)
152
+ )
153
+ else:
154
+ ordered = collections.OrderedDict(
155
+ sorted(getattr(schema, collection_key).items(), key=key, reverse=True)
156
+ )
157
+ schema = schema.model_copy(update={collection_key: ordered})
158
+
159
+ return schema
160
+
161
+ def _sort_database_schema_by_key(
162
+ self, database_schema: DatabaseSchema, key: Callable
163
+ ) -> DatabaseSchema:
164
+ """Re-build schema with OrderedDicts"""
165
+ tables = {}
166
+ # build new tables dict
167
+ for table_key, table_schema in database_schema.tables.items():
168
+ columns = {}
169
+ # build new column dict
170
+ for column_key, column_schema in table_schema.columns.items():
171
+ # sort values directly and update column schema
172
+ columns[column_key] = self._sort_schema_by_key(
173
+ schema=column_schema, key=key
174
+ )
175
+ # update table schema and sort
176
+ tables[table_key] = self._sort_schema_by_key(
177
+ schema=table_schema, key=key, update=columns
178
+ )
179
+ # update database schema and sort
180
+ database_schema = self._sort_schema_by_key(
181
+ schema=database_schema, key=key, update=tables
182
+ )
183
+
184
+ return database_schema
185
+
186
+ def _prepare_value_prompt(
187
+ self,
188
+ value_schema: ValueSchema,
189
+ column_schema: ColumnSchema,
190
+ table_schema: TableSchema,
191
+ boolean_system_prompt: bool = True,
192
+ format_instructions: Optional[str] = None,
193
+ ) -> ChatPromptTemplate:
194
+
195
+ if boolean_system_prompt is True:
196
+ system_prompt = self.boolean_system_prompt
197
+ else:
198
+ system_prompt = self.generative_system_prompt
199
+
200
+ prepared_column_prompt = self._prepare_column_prompt(
201
+ column_schema=column_schema, table_schema=table_schema
202
+ )
203
+ column_schema_str = (
204
+ prepared_column_prompt.messages[1]
205
+ .format(
206
+ **prepared_column_prompt.partial_variables,
207
+ query="See query at the lowest level schema.",
208
+ )
209
+ .content
210
+ )
211
+
212
+ base_prompt_template = ChatPromptTemplate.from_messages(
213
+ [("system", system_prompt), ("user", self.value_prompt_template)]
214
+ )
215
+
216
+ value_str = ""
217
+ header_str = ""
218
+ if type(value_schema.value) in [str, int, float, bool]:
219
+ header_str = f"This schema describes a single value in the {column_schema.column} column."
220
+
221
+ value_str = f"""
222
+ -**Value**: {value_schema.value}
223
+ """
224
+
225
+ elif type(value_schema.value) is dict:
226
+ header_str = f"This schema describes enumerated values in the {column_schema.column} column."
227
+
228
+ value_str = """
229
+ ## **Enumerated Values**
230
+
231
+ The values in the column are an enumeration of named values. These are listed below with format **[Column Value]**: [named value].
232
+ """
233
+ for value, value_name in value_schema.value.items():
234
+ value_str += f"""
235
+ - **{value}:** {value_name}"""
236
+
237
+ elif type(value_schema.value) is list:
238
+ header_str = f"This schema describes some of the values in the {column_schema.column} column."
239
+
240
+ value_str = """
241
+ ## **Sample Values**
242
+
243
+ There are too many values in this column to list exhaustively. Below is a sampling of values found in the column:
244
+ """
245
+ for value in value_schema.value:
246
+ value_str += f"""
247
+ - {value}"""
248
+
249
+ if getattr(value_schema, "comparator", None) is not None:
250
+ comparator_str = """
251
+
252
+ ## **Comparators**
253
+
254
+ Below is a list of comparison operators for constructing filters for this value schema:
255
+ """
256
+ if type(value_schema.comparator) is str:
257
+ comparator_str += f"""- {value_schema.comparator}
258
+ """
259
+ else:
260
+ for comp in value_schema.comparator:
261
+ comparator_str += f"""- {comp}
262
+ """
263
+ else:
264
+ comparator_str = ""
265
+
266
+ if getattr(value_schema, "example_questions", None) is not None:
267
+ example_str = """## **Example Questions**
268
+ """
269
+ for i, example in enumerate(value_schema.example_questions):
270
+ example_str += f"""{i}. **Query:** {example.input} **Answer:** {example.output}
271
+ """
272
+ else:
273
+ example_str = ""
274
+
275
+ return base_prompt_template.partial(
276
+ format_instructions=format_instructions,
277
+ header=header_str,
278
+ column_schema=column_schema_str,
279
+ value=value_str,
280
+ comparator=comparator_str,
281
+ type=value_schema.type,
282
+ description=value_schema.description,
283
+ usage=value_schema.usage,
284
+ examples=example_str,
285
+ )
286
+
287
+ def _prepare_column_prompt(
288
+ self,
289
+ column_schema: ColumnSchema,
290
+ table_schema: TableSchema,
291
+ boolean_system_prompt: bool = True,
292
+ ) -> ChatPromptTemplate:
293
+
294
+ if boolean_system_prompt is True:
295
+ system_prompt = self.boolean_system_prompt
296
+ else:
297
+ system_prompt = self.generative_system_prompt
298
+
299
+ prepared_table_prompt = self._prepare_table_prompt(
300
+ table_schema=table_schema, boolean_system_prompt=boolean_system_prompt
301
+ )
302
+ table_schema_str = (
303
+ prepared_table_prompt.messages[1]
304
+ .format(
305
+ **prepared_table_prompt.partial_variables,
306
+ query="See query at the lowest level schema",
307
+ )
308
+ .content
309
+ )
310
+
311
+ base_prompt_template = ChatPromptTemplate.from_messages(
312
+ [("system", system_prompt), ("user", self.column_prompt_template)]
313
+ )
314
+
315
+ header_str = (
316
+ f"This schema describes a column in the {table_schema.table} table."
317
+ )
318
+
319
+ value_str = """
320
+ ## **Content**
321
+
322
+ Below is a description of the contents in this column in list format:
323
+ """
324
+ for value_schema in column_schema.values.values():
325
+ value_str += f"""
326
+ - {value_schema.description}
327
+ """
328
+ value_str += """
329
+ **Important:** The above descriptions are not the actual values stored in this column. See the Value schema for actual values.
330
+ """
331
+
332
+ if getattr(column_schema, "examples", None) is not None:
333
+ example_str = """## **Example Questions**
334
+ """
335
+ for example in column_schema.examples:
336
+ example_str += f"""- {example}
337
+ """
338
+ else:
339
+ example_str = ""
340
+
341
+ return base_prompt_template.partial(
342
+ table_schema=table_schema_str,
343
+ header=header_str,
344
+ column=column_schema.column,
345
+ type=column_schema.type,
346
+ description=column_schema.description,
347
+ usage=column_schema.usage,
348
+ values=value_str,
349
+ examples=example_str,
350
+ )
351
+
352
+ def _prepare_table_prompt(
353
+ self, table_schema: TableSchema, boolean_system_prompt: bool = True
354
+ ) -> ChatPromptTemplate:
355
+ if boolean_system_prompt is True:
356
+ system_prompt = self.boolean_system_prompt
357
+ else:
358
+ system_prompt = self.generative_system_prompt
359
+
360
+ base_prompt_template = ChatPromptTemplate.from_messages(
361
+ [("system", system_prompt), ("user", self.table_prompt_template)]
362
+ )
363
+
364
+ header_str = "This schema describes a table in the database."
365
+
366
+ columns_str = ""
367
+ for column_key, column_schema in table_schema.columns.items():
368
+ columns_str += f"""
369
+ - **{column_schema.column}:** {column_schema.description}
370
+ """
371
+
372
+ if getattr(table_schema, "examples", None) is not None:
373
+ example_str = """## **Example Questions**
374
+ """
375
+ for example in table_schema.examples:
376
+ example_str += f"""- {example}
377
+ """
378
+ else:
379
+ example_str = ""
380
+
106
381
  return base_prompt_template.partial(
107
- schema=schema_prompt_str,
108
- examples=examples_prompt_str
382
+ header=header_str,
383
+ table=table_schema.table,
384
+ description=table_schema.description,
385
+ usage=table_schema.usage,
386
+ columns=columns_str,
387
+ examples=example_str,
109
388
  )
110
389
 
390
+ def _rank_schema(self, prompt: ChatPromptTemplate, query: str) -> float:
391
+ rank_chain = LLMChain(
392
+ llm=self.llm.bind(logprobs=True), prompt=prompt, return_final_only=False
393
+ )
394
+ output = rank_chain({"query": query}) # returns metadata
395
+
396
+ # parse through metadata tokens until encountering either yes, or no.
397
+ score = None # a None score indicates the model output could not be parsed.
398
+ for content in output["full_generation"][0].message.response_metadata[
399
+ "logprobs"
400
+ ]["content"]:
401
+ # Convert answer to score using the model's confidence
402
+ if content["token"].lower().strip() == "yes":
403
+ score = (
404
+ 1 + math.exp(content["logprob"])
405
+ ) / 2 # If yes, use the model's confidence
406
+ break
407
+ elif content["token"].lower().strip() == "no":
408
+ score = (
409
+ 1 - math.exp(content["logprob"])
410
+ ) / 2 # If no, invert the confidence
411
+ break
412
+
413
+ if score is None:
414
+ score = 0.0
415
+
416
+ return score
417
+
418
+ def _breadth_first_search(self, query: str, greedy: bool = False) -> Tuple:
419
+ """Search breadth wise through Tables, then Columns, then Values. Uses a greedy strategy to maximize quota if greedy=True, otherwise a dynamic strategy."""
420
+
421
+ # sort based on priority
422
+ ordered_database_schema = self._sort_database_schema_by_key(
423
+ database_schema=self.database_schema, key=self._sort_schema_by_priority_key
424
+ )
425
+
426
+ # Rank Tables ########################################################
427
+ greedy_count = 0
428
+ tables = {}
429
+ # rank tables by relevance
430
+ for table_key, table_schema in ordered_database_schema.tables.items():
431
+ prompt: ChatPromptTemplate = self._prepare_table_prompt(
432
+ table_schema=table_schema, boolean_system_prompt=True
433
+ )
434
+ table_schema.relevance = self._rank_schema(prompt=prompt, query=query)
435
+
436
+ # only keep greedy tables
437
+ tables[table_key] = table_schema
438
+
439
+ if greedy:
440
+ if table_schema.relevance >= ordered_database_schema.filter_threshold:
441
+ greedy_count += 1
442
+ if greedy_count >= ordered_database_schema.max_filters:
443
+ break
444
+
445
+ # sort tables
446
+ ordered_database_schema = self._sort_schema_by_key(
447
+ schema=ordered_database_schema,
448
+ key=self._sort_schema_by_relevance_key,
449
+ update=tables,
450
+ )
451
+
452
+ # Rank Columns #######################################################
453
+ # iterate through tables to rank columns
454
+ tables = {}
455
+ table_count = 0 # take only the top n number of tables specified by the databases max filters
456
+ for table_key, table_schema in ordered_database_schema.tables.items():
457
+ # only drop into tables above the filter threshold
458
+ if table_schema.relevance >= ordered_database_schema.filter_threshold:
459
+ greedy_count = 0
460
+ # rank columns by relevance
461
+ columns = {}
462
+ for column_key, column_schema in table_schema.columns.items():
463
+ prompt: ChatPromptTemplate = self._prepare_column_prompt(
464
+ column_schema=column_schema,
465
+ table_schema=table_schema,
466
+ boolean_system_prompt=True,
467
+ )
468
+ column_schema.relevance = self._rank_schema(
469
+ prompt=prompt, query=query
470
+ )
471
+
472
+ columns[column_key] = column_schema
473
+
474
+ if greedy:
475
+ if column_schema.relevance >= table_schema.filter_threshold:
476
+ greedy_count += 1
477
+ if greedy_count >= table_schema.max_filters:
478
+ break
479
+
480
+ # sort columns and keep only columns that made the cut.
481
+ tables[table_key] = self._sort_schema_by_key(
482
+ table_schema, key=self._sort_schema_by_relevance_key, update=columns
483
+ )
484
+
485
+ table_count += 1
486
+ if table_count >= ordered_database_schema.max_filters:
487
+ break
488
+
489
+ # sort tables and keep only tables that made the cut.
490
+ ordered_database_schema = self._sort_schema_by_key(
491
+ ordered_database_schema,
492
+ key=self._sort_schema_by_relevance_key,
493
+ update=tables,
494
+ )
495
+
496
+ # Rank Values ########################################################
497
+ # iterate through tables to rank values
498
+ tables = {}
499
+ for table_key, table_schema in ordered_database_schema.tables.items():
500
+ columns = {}
501
+ column_count = 0
502
+ # iterate through columns to rank values
503
+ for column_key, column_schema in table_schema.columns.items():
504
+ if column_schema.relevance >= table_schema.filter_threshold:
505
+ greedy_count = 0
506
+ values = {}
507
+ # rank values by relevance
508
+ for value_key, value_schema in column_schema.values.items():
509
+ prompt: ChatPromptTemplate = self._prepare_value_prompt(
510
+ value_schema=value_schema,
511
+ column_schema=column_schema,
512
+ table_schema=table_schema,
513
+ boolean_system_prompt=True,
514
+ )
515
+ value_schema.relevance = self._rank_schema(
516
+ prompt=prompt, query=query
517
+ )
518
+
519
+ values[value_key] = value_schema
520
+
521
+ if greedy:
522
+ if value_schema.relevance >= column_schema.filter_threshold:
523
+ greedy_count += 1
524
+ if greedy_count >= column_schema.max_filters:
525
+ break
526
+
527
+ # sort values and keep only values that make the cut
528
+ columns[column_key] = self._sort_schema_by_key(
529
+ column_schema,
530
+ key=self._sort_schema_by_relevance_key,
531
+ update=values,
532
+ )
533
+
534
+ column_count += 1
535
+ if column_count >= table_schema.max_filters:
536
+ break
537
+
538
+ # sort columns and keep only columns that made the cut
539
+ tables[table_key] = self._sort_schema_by_key(
540
+ table_schema, key=self._sort_schema_by_relevance_key, update=columns
541
+ )
542
+
543
+ # sort tables and keep only tables that made the cut.
544
+ ordered_database_schema = self._sort_schema_by_key(
545
+ ordered_database_schema,
546
+ key=self._sort_schema_by_relevance_key,
547
+ update=tables,
548
+ )
549
+
550
+ # discard low ranked values ###################################################################################
551
+ tables = {}
552
+ for table_key, table_schema in ordered_database_schema.tables.items():
553
+ columns = {}
554
+ # iterate through columns to filter values
555
+ for column_key, column_schema in table_schema.columns.items():
556
+ value_count = 0
557
+ values = {}
558
+ # keep only values that meet the column's filter threshold
559
+ for value_key, value_schema in column_schema.values.items():
560
+ if value_schema.relevance >= column_schema.filter_threshold:
561
+ values[value_key] = value_schema
562
+
563
+ value_count += 1
564
+ if value_count >= column_schema.max_filters:
565
+ break
566
+
567
+ # sort values and keep only values that make the cut
568
+ columns[column_key] = self._sort_schema_by_key(
569
+ column_schema,
570
+ key=self._sort_schema_by_relevance_key,
571
+ update=values,
572
+ )
573
+
574
+ # sort columns and keep only columns that made the cut
575
+ tables[table_key] = self._sort_schema_by_key(
576
+ table_schema, key=self._sort_schema_by_relevance_key, update=columns
577
+ )
578
+
579
+ # sort tables and keep only tables that made the cut.
580
+ ordered_database_schema = self._sort_schema_by_key(
581
+ ordered_database_schema,
582
+ key=self._sort_schema_by_relevance_key,
583
+ update=tables,
584
+ )
585
+
586
+ ranked_database_schema = ordered_database_schema
587
+
588
+ # Build Ablation #####################################################
589
+
590
+ ablation_value_dict = {}
591
+ # assemble a relevance dictionary
592
+ for table_key, table_schema in ordered_database_schema.tables.items():
593
+ for column_key, column_schema in table_schema.columns.items():
594
+ for value_key, value_schema in column_schema.values.items():
595
+ ablation_value_dict[(table_key, column_key, value_key)] = (
596
+ value_schema.relevance
597
+ )
598
+
599
+ ablation_value_dict = collections.OrderedDict(
600
+ sorted(ablation_value_dict.items(), key=lambda x: x[1])
601
+ )
602
+
603
+ relevance_scores = list(ablation_value_dict.values())
604
+ if len(relevance_scores) > 0:
605
+ ablation_quantiles = np.quantile(
606
+ relevance_scores, np.linspace(0, 1, self.num_retries + 2)[1:-1]
607
+ )
608
+ else:
609
+ ablation_quantiles = None
610
+
611
+ return ranked_database_schema, ablation_value_dict, ablation_quantiles
612
+
613
+ def _dynamic_ablation(
614
+ self,
615
+ metadata_filters: List[AblativeMetadataFilter],
616
+ ablation_value_dict,
617
+ ablation_quantiles,
618
+ retry: int,
619
+ ):
620
+ """Ablate metadata filters in aggregate by quantiles until the required minimum number of documents are returned."""
621
+
622
+ ablated_dict = {}
623
+ for key, value in ablation_value_dict.items():
624
+ if value >= ablation_quantiles[retry]:
625
+ ablated_dict[key] = value
626
+
627
+ # discard low ranked filters ##################################################################################
628
+ ablated_filters = []
629
+ for filter in metadata_filters:
630
+ for key in ablated_dict.keys():
631
+ if (
632
+ filter.schema_table in key
633
+ and filter.schema_column in key
634
+ and filter.schema_value in key
635
+ ):
636
+ ablated_filters.append(filter)
637
+
638
+ return ablated_filters
639
+
640
+ def depth_first_search(self, greedy=True):
641
+ """Search depth wise through Tables, then Columns, then Values. Uses a greedy strategy to maximize quota if greedy=True, otherwise a dynamic strategy."""
642
+ pass
643
+
644
+ def depth_first_ablation(self):
645
+ """Ablate metadata filters in reverse depth first search until the required minimum number of documents are returned."""
646
+ pass
647
+
111
648
  def _prepare_retrieval_query(self, query: str) -> str:
112
649
  rewrite_prompt = PromptTemplate(
113
- input_variables=['input'],
114
- template=self.rewrite_prompt_template
650
+ input_variables=["input"], template=self.rewrite_prompt_template
115
651
  )
116
652
  rewrite_chain = LLMChain(llm=self.llm, prompt=rewrite_prompt)
117
653
  return rewrite_chain.predict(input=query)
118
654
 
119
- def _prepare_pgvector_query(self, metadata_filters: List[MetadataFilter]) -> str:
655
+ def _prepare_pgvector_query(
656
+ self,
657
+ ranked_database_schema: DatabaseSchema,
658
+ metadata_filters: List[AblativeMetadataFilter],
659
+ retry: int = 0,
660
+ ) -> str:
120
661
  # Base select JOINed with document source table.
121
- base_query = f'''SELECT * FROM {self.embeddings_table} AS e INNER JOIN {self.source_table} AS s ON (e.metadata->>'original_row_id')::int = s."{self.source_id_column}" '''
122
- col_to_schema = {}
123
- if not self.metadata_schemas:
124
- return ''
125
- for schema in self.metadata_schemas:
126
- for col in schema.columns:
127
- col_to_schema[col.name] = schema
128
- joined_schemas = set()
129
- for filter in metadata_filters:
130
- # Join schemas before filtering.
131
- schema = col_to_schema.get(filter.attribute)
132
- if schema is None or schema.table in joined_schemas or schema.table == self.source_table:
662
+ base_query = f"""SELECT * FROM {self.embeddings_table} AS e INNER JOIN {self.source_table} AS s ON (e.metadata->>'original_row_id')::int = s."{self.source_id_column}" """
663
+
664
+ # return an empty string if schema has not been ranked
665
+ if not ranked_database_schema:
666
+ return ""
667
+
668
+ # Add Table JOIN statements
669
+ join_clauses = set()
670
+ for metadata_filter in metadata_filters:
671
+ join_clause = ranked_database_schema.tables[
672
+ metadata_filter.schema_table
673
+ ].join
674
+ if join_clause in join_clauses:
133
675
  continue
134
- joined_schemas.add(schema.table)
135
- base_query += schema.join + ' '
136
- # Actually construct WHERE conditions from metadata filters.
676
+ else:
677
+ join_clauses.add(join_clause)
678
+ base_query += join_clause + " "
679
+
680
+ # Add WHERE conditions from metadata filters
137
681
  if metadata_filters:
138
- base_query += 'WHERE '
682
+ base_query += "WHERE "
139
683
  for i, filter in enumerate(metadata_filters):
140
684
  value = filter.value
141
685
  if isinstance(value, str):
142
686
  value = f"'{value}'"
143
687
  base_query += f'"{filter.attribute}" {filter.comparator} {value}'
144
688
  if i < len(metadata_filters) - 1:
145
- base_query += ' AND '
689
+ base_query += " AND "
690
+
146
691
  base_query += f" ORDER BY e.embeddings {self.distance_function.value[0]} '{{embeddings}}' LIMIT {self.search_kwargs.k};"
147
692
  return base_query
148
693
 
149
- def _generate_metadata_filters(self, query: str) -> List[MetadataFilter]:
150
- parser = PydanticOutputParser(pydantic_object=MetadataFilters)
151
- metadata_prompt = self._prepare_metadata_prompt()
152
- metadata_filters_chain = LLMChain(llm=self.llm, prompt=metadata_prompt)
153
- metadata_filters_output = metadata_filters_chain.predict(
154
- format_instructions=parser.get_format_instructions(),
155
- input=query
156
- )
157
- # If the LLM outputs raw JSON, use it as-is.
158
- # If the LLM outputs anything including a json markdown section, use the last one.
159
- json_markdown_output = re.findall(r'```json.*```', metadata_filters_output, re.DOTALL)
160
- if json_markdown_output:
161
- metadata_filters_output = json_markdown_output[-1]
162
- # Clean the json tags.
163
- metadata_filters_output = metadata_filters_output[7:]
164
- metadata_filters_output = metadata_filters_output[:-3]
165
- metadata_filters = parser.invoke(metadata_filters_output)
166
- return metadata_filters.filters
167
-
168
- def _prepare_and_execute_query(self, query: str, embeddings_str: str) -> HandlerResponse:
694
def _generate_filter(
    self, prompt: ChatPromptTemplate, query: str
) -> dict:
    """Run ``prompt`` through the LLM for ``query`` and return the raw chain output.

    Args:
        prompt: Prompt template; it is invoked with a single ``query`` input.
        query: User query passed to the chain as the ``query`` variable.

    Returns:
        The unparsed output of the ``LLMChain`` call. Per LangChain's Chain
        API this is a mapping of the chain's inputs and outputs, not a
        parsed ``MetadataFilter``; callers must parse it themselves.
    """
    gen_filter_chain = LLMChain(llm=self.llm, prompt=prompt)
    # ``invoke`` is the supported replacement for calling the chain object
    # directly, which LangChain has deprecated.
    return gen_filter_chain.invoke({"query": query})
700
+
701
def _generate_metadata_filters(
    self, query: str, ranked_database_schema
) -> Union[List[AblativeMetadataFilter], HandlerResponse]:
    """Build one metadata filter per value of every sufficiently relevant column.

    Walks ``ranked_database_schema.tables`` -> columns -> values. Columns
    whose ``relevance`` is below their table's ``filter_threshold`` are
    skipped. For each remaining value:

    * if the value is a ``list`` or ``dict``, the LLM is prompted to produce
      a ``MetadataFilter`` as JSON, which is parsed and tagged with its
      schema location;
    * otherwise the filter is built directly from the column and value
      schema.

    Args:
        query: User query given to the LLM when a filter must be generated.
        ranked_database_schema: Ranked schema to derive filters from.

    Returns:
        The list of filters, or an error ``HandlerResponse`` as soon as the
        LLM output for any value cannot be parsed (filters collected up to
        that point are discarded).
    """
    parser = PydanticOutputParser(pydantic_object=MetadataFilter)
    # Loop-invariant, so computed once rather than per value.
    format_instructions = parser.get_format_instructions()

    metadata_filter_list = []
    # iterate through tables to rank values
    for table_key, table_schema in ranked_database_schema.tables.items():
        # iterate through columns to rank values
        for column_key, column_schema in table_schema.columns.items():
            if not (column_schema.relevance >= table_schema.filter_threshold):
                continue

            # generate filters
            for value_key, value_schema in column_schema.values.items():
                schema_location = {
                    "schema_table": table_key,
                    "schema_column": column_key,
                    "schema_value": value_key,
                }

                # must use generation if field is a dictionary of tuples or a list
                if isinstance(value_schema.value, (list, dict)):
                    try:
                        metadata_prompt = self._prepare_value_prompt(
                            format_instructions=format_instructions,
                            value_schema=value_schema,
                            column_schema=column_schema,
                            table_schema=table_schema,
                            boolean_system_prompt=False,
                        )
                        metadata_filters_chain = LLMChain(
                            llm=self.llm, prompt=metadata_prompt
                        )
                        metadata_filter_output = metadata_filters_chain.predict(
                            query=query,
                        )

                        # If the LLM outputs raw JSON, use it as-is.
                        # If the LLM outputs anything including a json markdown section, use the last one.
                        # The match is non-greedy so that each fenced section is
                        # found separately; a greedy match would span from the
                        # first opening fence to the final closing fence.
                        json_sections = re.findall(
                            r"```json(.*?)```", metadata_filter_output, re.DOTALL
                        )
                        if json_sections:
                            # The capture group already excludes the fences;
                            # strip the surrounding whitespace/newlines.
                            metadata_filter_output = json_sections[-1].strip()

                        parsed_filter = parser.invoke(metadata_filter_output)
                        model_dump = parsed_filter.model_dump()
                        model_dump.update(schema_location)
                        metadata_filter = AblativeMetadataFilter(**model_dump)
                    except OutputParserException as e:
                        logger.warning(
                            f"LLM failed to generate structured metadata filters: {str(e)}"
                        )
                        return HandlerResponse(
                            RESPONSE_TYPE.ERROR, error_message=str(e)
                        )
                else:
                    metadata_filter = AblativeMetadataFilter(
                        attribute=column_schema.column,
                        comparator=value_schema.comparator,
                        value=value_schema.value,
                        **schema_location,
                    )
                metadata_filter_list.append(metadata_filter)

    return metadata_filter_list
774
+
775
+ def _prepare_and_execute_query(
776
+ self,
777
+ ranked_database_schema: DatabaseSchema,
778
+ metadata_filters: List[AblativeMetadataFilter],
779
+ embeddings_str: str,
780
+ ) -> HandlerResponse:
169
781
  try:
170
- metadata_filters = self._generate_metadata_filters(query)
171
- checked_sql_query = self._prepare_pgvector_query(metadata_filters)
172
- checked_sql_query_with_embeddings = checked_sql_query.format(embeddings=embeddings_str)
173
- return self.vector_store_handler.native_query(checked_sql_query_with_embeddings)
174
- except OutputParserException as e:
175
- logger.warning(f'LLM failed to generate structured metadata filters: {str(e)}')
176
- return HandlerResponse(RESPONSE_TYPE.ERROR, error_message=str(e))
782
+ checked_sql_query = self._prepare_pgvector_query(
783
+ ranked_database_schema, metadata_filters
784
+ )
785
+ checked_sql_query_with_embeddings = checked_sql_query.format(
786
+ embeddings=embeddings_str
787
+ )
788
+ return self.vector_store_handler.native_query(
789
+ checked_sql_query_with_embeddings
790
+ )
177
791
  except Exception as e:
178
- logger.warning(f'Failed to prepare and execute SQL query from structured metadata: {str(e)}')
792
+ logger.warning(
793
+ f"Failed to prepare and execute SQL query from structured metadata: {str(e)}"
794
+ )
179
795
  return HandlerResponse(RESPONSE_TYPE.ERROR, error_message=str(e))
180
796
 
181
797
  def _get_relevant_documents(
@@ -183,36 +799,85 @@ Output:
183
799
  ) -> List[Document]:
184
800
  # Rewrite query to be suitable for retrieval.
185
801
  retrieval_query = self._prepare_retrieval_query(query)
802
+
186
803
  # Embed the rewritten retrieval query & include it in the similarity search pgvector query.
187
804
  embedded_query = self.embeddings_model.embed_query(retrieval_query)
188
- # Actually execute the similarity search with metadata filters.
189
- document_response = self._prepare_and_execute_query(retrieval_query, str(embedded_query))
190
- num_retries = 0
191
- while num_retries < self.num_retries:
192
- if document_response.resp_type != RESPONSE_TYPE.ERROR and len(document_response.data_frame) > 0:
193
- # Successfully retrieved documents.
194
- break
195
- if document_response.resp_type == RESPONSE_TYPE.ERROR:
196
- # LLMs won't always generate structured metadata so we should have a fallback after retrying.
197
- logger.info(f'SQL Retriever query failed with error {document_response.error_message}')
198
- elif len(document_response.data_frame) == 0:
199
- logger.info('No documents retrieved from SQL Retriever query')
200
-
201
- document_response = self._prepare_and_execute_query(retrieval_query, str(embedded_query))
202
- num_retries += 1
203
- if num_retries >= self.num_retries:
204
- logger.info('Using fallback retriever in SQL retriever.')
205
- return self.fallback_retriever._get_relevant_documents(retrieval_query, run_manager=run_manager)
206
-
207
- document_df = document_response.data_frame
208
- retrieved_documents = []
209
- for _, document_row in document_df.iterrows():
210
- retrieved_documents.append(Document(
211
- document_row.get('content', ''),
212
- metadata=document_row.get('metadata', {})
213
- ))
214
- if retrieved_documents:
215
- return retrieved_documents
216
- # If the SQL query constructed did not return any documents, fallback.
217
- logger.info('No documents returned from SQL retriever. using fallback retriever.')
218
- return self.fallback_retriever._get_relevant_documents(retrieval_query, run_manager=run_manager)
805
+
806
+ # Search for relevant filters
807
+ ranked_database_schema, ablation_value_dict, ablation_quantiles = (
808
+ self._breadth_first_search(query=query)
809
+ )
810
+
811
+ # Generate metadata filters
812
+ metadata_filters = self._generate_metadata_filters(
813
+ query=query, ranked_database_schema=ranked_database_schema
814
+ )
815
+
816
+ if type(metadata_filters) is list:
817
+ # Initial Execution of the similarity search with metadata filters.
818
+ document_response = self._prepare_and_execute_query(
819
+ ranked_database_schema=ranked_database_schema,
820
+ metadata_filters=metadata_filters,
821
+ embeddings_str=str(embedded_query),
822
+ )
823
+ num_retries = 0
824
+ while num_retries < self.num_retries:
825
+ if (
826
+ document_response.resp_type != RESPONSE_TYPE.ERROR
827
+ and len(document_response.data_frame) >= self.min_k
828
+ ):
829
+ # Successfully retrieved k documents to send to re-ranker.
830
+ break
831
+ elif document_response.resp_type == RESPONSE_TYPE.ERROR:
832
+ # LLMs won't always generate structured metadata so we should have a fallback after retrying.
833
+ logger.info(
834
+ f"SQL Retriever query failed with error {document_response.error_message}"
835
+ )
836
+ else:
837
+ logger.info(
838
+ f"SQL Retriever did not retrieve {self.min_k} documents: {len(document_response.data_frame)} documents retrieved."
839
+ )
840
+
841
+ ablated_metadata_filters = self._dynamic_ablation(
842
+ metadata_filters=metadata_filters,
843
+ ablation_value_dict=ablation_value_dict,
844
+ ablation_quantiles=ablation_quantiles,
845
+ retry=num_retries,
846
+ )
847
+
848
+ document_response = self._prepare_and_execute_query(
849
+ ranked_database_schema=ranked_database_schema,
850
+ metadata_filters=ablated_metadata_filters,
851
+ embeddings_str=str(embedded_query),
852
+ )
853
+
854
+ num_retries += 1
855
+
856
+ retrieved_documents = []
857
+ if document_response.resp_type != RESPONSE_TYPE.ERROR:
858
+ document_df = document_response.data_frame
859
+ for _, document_row in document_df.iterrows():
860
+ retrieved_documents.append(
861
+ Document(
862
+ document_row.get("content", ""),
863
+ metadata=document_row.get("metadata", {}),
864
+ )
865
+ )
866
+ if retrieved_documents:
867
+ return retrieved_documents
868
+
869
+ # If the SQL query constructed did not return any documents, fallback.
870
+ logger.info(
871
+ "No documents returned from SQL retriever, using fallback retriever."
872
+ )
873
+ return self.fallback_retriever._get_relevant_documents(
874
+ retrieval_query, run_manager=run_manager
875
+ )
876
+ else:
877
+ # If no metadata fields could be generated fallback.
878
+ logger.info(
879
+ "No metadata fields were successfully generated, using fallback retriever."
880
+ )
881
+ return self.fallback_retriever._get_relevant_documents(
882
+ retrieval_query, run_manager=run_manager
883
+ )