MindsDB 25.9.2.0a1__py3-none-any.whl → 25.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of MindsDB might be problematic. Click here for more details.

Files changed (164)
  1. mindsdb/__about__.py +1 -1
  2. mindsdb/__main__.py +40 -29
  3. mindsdb/api/a2a/__init__.py +1 -1
  4. mindsdb/api/a2a/agent.py +16 -10
  5. mindsdb/api/a2a/common/server/server.py +7 -3
  6. mindsdb/api/a2a/common/server/task_manager.py +12 -5
  7. mindsdb/api/a2a/common/types.py +66 -0
  8. mindsdb/api/a2a/task_manager.py +65 -17
  9. mindsdb/api/common/middleware.py +10 -12
  10. mindsdb/api/executor/command_executor.py +51 -40
  11. mindsdb/api/executor/datahub/datanodes/datanode.py +2 -2
  12. mindsdb/api/executor/datahub/datanodes/information_schema_datanode.py +7 -13
  13. mindsdb/api/executor/datahub/datanodes/integration_datanode.py +101 -49
  14. mindsdb/api/executor/datahub/datanodes/project_datanode.py +8 -4
  15. mindsdb/api/executor/datahub/datanodes/system_tables.py +3 -2
  16. mindsdb/api/executor/exceptions.py +29 -10
  17. mindsdb/api/executor/planner/plan_join.py +17 -3
  18. mindsdb/api/executor/planner/query_prepare.py +2 -20
  19. mindsdb/api/executor/sql_query/sql_query.py +74 -74
  20. mindsdb/api/executor/sql_query/steps/fetch_dataframe.py +1 -2
  21. mindsdb/api/executor/sql_query/steps/subselect_step.py +0 -1
  22. mindsdb/api/executor/utilities/functions.py +6 -6
  23. mindsdb/api/executor/utilities/sql.py +37 -20
  24. mindsdb/api/http/gui.py +5 -11
  25. mindsdb/api/http/initialize.py +75 -61
  26. mindsdb/api/http/namespaces/agents.py +10 -15
  27. mindsdb/api/http/namespaces/analysis.py +13 -20
  28. mindsdb/api/http/namespaces/auth.py +1 -1
  29. mindsdb/api/http/namespaces/chatbots.py +0 -5
  30. mindsdb/api/http/namespaces/config.py +15 -11
  31. mindsdb/api/http/namespaces/databases.py +140 -201
  32. mindsdb/api/http/namespaces/file.py +17 -4
  33. mindsdb/api/http/namespaces/handlers.py +17 -7
  34. mindsdb/api/http/namespaces/knowledge_bases.py +28 -7
  35. mindsdb/api/http/namespaces/models.py +94 -126
  36. mindsdb/api/http/namespaces/projects.py +13 -22
  37. mindsdb/api/http/namespaces/sql.py +33 -25
  38. mindsdb/api/http/namespaces/tab.py +27 -37
  39. mindsdb/api/http/namespaces/views.py +1 -1
  40. mindsdb/api/http/start.py +16 -10
  41. mindsdb/api/mcp/__init__.py +2 -1
  42. mindsdb/api/mysql/mysql_proxy/executor/mysql_executor.py +15 -20
  43. mindsdb/api/mysql/mysql_proxy/mysql_proxy.py +26 -50
  44. mindsdb/api/mysql/mysql_proxy/utilities/__init__.py +0 -1
  45. mindsdb/api/mysql/mysql_proxy/utilities/dump.py +8 -2
  46. mindsdb/integrations/handlers/byom_handler/byom_handler.py +165 -190
  47. mindsdb/integrations/handlers/databricks_handler/databricks_handler.py +98 -46
  48. mindsdb/integrations/handlers/druid_handler/druid_handler.py +32 -40
  49. mindsdb/integrations/handlers/file_handler/file_handler.py +7 -0
  50. mindsdb/integrations/handlers/gitlab_handler/gitlab_handler.py +5 -2
  51. mindsdb/integrations/handlers/lightwood_handler/functions.py +45 -79
  52. mindsdb/integrations/handlers/mssql_handler/mssql_handler.py +438 -100
  53. mindsdb/integrations/handlers/mssql_handler/requirements_odbc.txt +3 -0
  54. mindsdb/integrations/handlers/mysql_handler/mysql_handler.py +235 -3
  55. mindsdb/integrations/handlers/oracle_handler/__init__.py +2 -0
  56. mindsdb/integrations/handlers/oracle_handler/connection_args.py +7 -1
  57. mindsdb/integrations/handlers/oracle_handler/oracle_handler.py +321 -16
  58. mindsdb/integrations/handlers/oracle_handler/requirements.txt +1 -1
  59. mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +14 -2
  60. mindsdb/integrations/handlers/shopify_handler/requirements.txt +1 -0
  61. mindsdb/integrations/handlers/shopify_handler/shopify_handler.py +80 -13
  62. mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py +2 -1
  63. mindsdb/integrations/handlers/statsforecast_handler/requirements.txt +1 -0
  64. mindsdb/integrations/handlers/statsforecast_handler/requirements_extra.txt +1 -0
  65. mindsdb/integrations/handlers/web_handler/urlcrawl_helpers.py +4 -4
  66. mindsdb/integrations/handlers/zendesk_handler/zendesk_tables.py +144 -111
  67. mindsdb/integrations/libs/api_handler.py +10 -10
  68. mindsdb/integrations/libs/base.py +4 -4
  69. mindsdb/integrations/libs/llm/utils.py +2 -2
  70. mindsdb/integrations/libs/ml_handler_process/create_engine_process.py +4 -7
  71. mindsdb/integrations/libs/ml_handler_process/func_call_process.py +2 -7
  72. mindsdb/integrations/libs/ml_handler_process/learn_process.py +37 -47
  73. mindsdb/integrations/libs/ml_handler_process/update_engine_process.py +4 -7
  74. mindsdb/integrations/libs/ml_handler_process/update_process.py +2 -7
  75. mindsdb/integrations/libs/process_cache.py +132 -140
  76. mindsdb/integrations/libs/response.py +18 -12
  77. mindsdb/integrations/libs/vectordatabase_handler.py +26 -0
  78. mindsdb/integrations/utilities/files/file_reader.py +6 -7
  79. mindsdb/integrations/utilities/handlers/auth_utilities/snowflake/__init__.py +1 -0
  80. mindsdb/integrations/utilities/handlers/auth_utilities/snowflake/snowflake_jwt_gen.py +151 -0
  81. mindsdb/integrations/utilities/rag/config_loader.py +37 -26
  82. mindsdb/integrations/utilities/rag/rerankers/base_reranker.py +83 -30
  83. mindsdb/integrations/utilities/rag/rerankers/reranker_compressor.py +4 -4
  84. mindsdb/integrations/utilities/rag/retrievers/sql_retriever.py +55 -133
  85. mindsdb/integrations/utilities/rag/settings.py +58 -133
  86. mindsdb/integrations/utilities/rag/splitters/file_splitter.py +5 -15
  87. mindsdb/interfaces/agents/agents_controller.py +2 -3
  88. mindsdb/interfaces/agents/constants.py +0 -2
  89. mindsdb/interfaces/agents/litellm_server.py +34 -58
  90. mindsdb/interfaces/agents/mcp_client_agent.py +10 -10
  91. mindsdb/interfaces/agents/mindsdb_database_agent.py +5 -5
  92. mindsdb/interfaces/agents/run_mcp_agent.py +12 -21
  93. mindsdb/interfaces/chatbot/chatbot_task.py +20 -23
  94. mindsdb/interfaces/chatbot/polling.py +30 -18
  95. mindsdb/interfaces/data_catalog/data_catalog_loader.py +16 -17
  96. mindsdb/interfaces/data_catalog/data_catalog_reader.py +15 -4
  97. mindsdb/interfaces/database/data_handlers_cache.py +190 -0
  98. mindsdb/interfaces/database/database.py +3 -3
  99. mindsdb/interfaces/database/integrations.py +7 -110
  100. mindsdb/interfaces/database/projects.py +2 -6
  101. mindsdb/interfaces/database/views.py +1 -4
  102. mindsdb/interfaces/file/file_controller.py +6 -6
  103. mindsdb/interfaces/functions/controller.py +1 -1
  104. mindsdb/interfaces/functions/to_markdown.py +2 -2
  105. mindsdb/interfaces/jobs/jobs_controller.py +5 -9
  106. mindsdb/interfaces/jobs/scheduler.py +3 -9
  107. mindsdb/interfaces/knowledge_base/controller.py +244 -128
  108. mindsdb/interfaces/knowledge_base/evaluate.py +36 -41
  109. mindsdb/interfaces/knowledge_base/executor.py +11 -0
  110. mindsdb/interfaces/knowledge_base/llm_client.py +51 -17
  111. mindsdb/interfaces/knowledge_base/preprocessing/json_chunker.py +40 -61
  112. mindsdb/interfaces/model/model_controller.py +172 -168
  113. mindsdb/interfaces/query_context/context_controller.py +14 -2
  114. mindsdb/interfaces/skills/custom/text2sql/mindsdb_sql_toolkit.py +10 -14
  115. mindsdb/interfaces/skills/retrieval_tool.py +43 -50
  116. mindsdb/interfaces/skills/skill_tool.py +2 -2
  117. mindsdb/interfaces/skills/skills_controller.py +1 -4
  118. mindsdb/interfaces/skills/sql_agent.py +25 -19
  119. mindsdb/interfaces/storage/db.py +16 -6
  120. mindsdb/interfaces/storage/fs.py +114 -169
  121. mindsdb/interfaces/storage/json.py +19 -18
  122. mindsdb/interfaces/tabs/tabs_controller.py +49 -72
  123. mindsdb/interfaces/tasks/task_monitor.py +3 -9
  124. mindsdb/interfaces/tasks/task_thread.py +7 -9
  125. mindsdb/interfaces/triggers/trigger_task.py +7 -13
  126. mindsdb/interfaces/triggers/triggers_controller.py +47 -52
  127. mindsdb/migrations/migrate.py +16 -16
  128. mindsdb/utilities/api_status.py +58 -0
  129. mindsdb/utilities/config.py +68 -2
  130. mindsdb/utilities/exception.py +40 -1
  131. mindsdb/utilities/fs.py +0 -1
  132. mindsdb/utilities/hooks/profiling.py +17 -14
  133. mindsdb/utilities/json_encoder.py +24 -10
  134. mindsdb/utilities/langfuse.py +40 -45
  135. mindsdb/utilities/log.py +272 -0
  136. mindsdb/utilities/ml_task_queue/consumer.py +52 -58
  137. mindsdb/utilities/ml_task_queue/producer.py +26 -30
  138. mindsdb/utilities/render/sqlalchemy_render.py +22 -20
  139. mindsdb/utilities/starters.py +0 -10
  140. mindsdb/utilities/utils.py +2 -2
  141. {mindsdb-25.9.2.0a1.dist-info → mindsdb-25.10.0.dist-info}/METADATA +286 -267
  142. {mindsdb-25.9.2.0a1.dist-info → mindsdb-25.10.0.dist-info}/RECORD +145 -159
  143. mindsdb/api/mysql/mysql_proxy/utilities/exceptions.py +0 -14
  144. mindsdb/api/postgres/__init__.py +0 -0
  145. mindsdb/api/postgres/postgres_proxy/__init__.py +0 -0
  146. mindsdb/api/postgres/postgres_proxy/executor/__init__.py +0 -1
  147. mindsdb/api/postgres/postgres_proxy/executor/executor.py +0 -189
  148. mindsdb/api/postgres/postgres_proxy/postgres_packets/__init__.py +0 -0
  149. mindsdb/api/postgres/postgres_proxy/postgres_packets/errors.py +0 -322
  150. mindsdb/api/postgres/postgres_proxy/postgres_packets/postgres_fields.py +0 -34
  151. mindsdb/api/postgres/postgres_proxy/postgres_packets/postgres_message.py +0 -31
  152. mindsdb/api/postgres/postgres_proxy/postgres_packets/postgres_message_formats.py +0 -1265
  153. mindsdb/api/postgres/postgres_proxy/postgres_packets/postgres_message_identifiers.py +0 -31
  154. mindsdb/api/postgres/postgres_proxy/postgres_packets/postgres_packets.py +0 -253
  155. mindsdb/api/postgres/postgres_proxy/postgres_proxy.py +0 -477
  156. mindsdb/api/postgres/postgres_proxy/utilities/__init__.py +0 -10
  157. mindsdb/api/postgres/start.py +0 -11
  158. mindsdb/integrations/handlers/mssql_handler/tests/__init__.py +0 -0
  159. mindsdb/integrations/handlers/mssql_handler/tests/test_mssql_handler.py +0 -169
  160. mindsdb/integrations/handlers/oracle_handler/tests/__init__.py +0 -0
  161. mindsdb/integrations/handlers/oracle_handler/tests/test_oracle_handler.py +0 -32
  162. {mindsdb-25.9.2.0a1.dist-info → mindsdb-25.10.0.dist-info}/WHEEL +0 -0
  163. {mindsdb-25.9.2.0a1.dist-info → mindsdb-25.10.0.dist-info}/licenses/LICENSE +0 -0
  164. {mindsdb-25.9.2.0a1.dist-info → mindsdb-25.10.0.dist-info}/top_level.txt +0 -0
@@ -1,17 +1,16 @@
1
- from dataclasses import dataclass, astuple
2
- import traceback
3
- import json
4
1
  import csv
5
- from io import BytesIO, StringIO, IOBase
6
- from pathlib import Path
2
+ import json
7
3
  import codecs
4
+ from io import BytesIO, StringIO, IOBase
8
5
  from typing import List, Generator
6
+ from pathlib import Path
7
+ from dataclasses import dataclass, astuple
9
8
 
10
9
  import filetype
11
10
  import pandas as pd
12
11
  from charset_normalizer import from_bytes
13
- from mindsdb.interfaces.knowledge_base.preprocessing.text_splitter import TextSplitter
14
12
 
13
+ from mindsdb.interfaces.knowledge_base.preprocessing.text_splitter import TextSplitter
15
14
  from mindsdb.utilities import log
16
15
 
17
16
  logger = log.getLogger(__name__)
@@ -76,7 +75,7 @@ def decode(file_obj: IOBase) -> StringIO:
76
75
 
77
76
  data_str = StringIO(byte_str.decode(encoding, errors))
78
77
  except Exception as e:
79
- logger.error(traceback.format_exc())
78
+ logger.exception("Error during file decode:")
80
79
  raise FileProcessingError("Could not load into string") from e
81
80
 
82
81
  return data_str
@@ -0,0 +1 @@
1
+ from .snowflake_jwt_gen import get_validated_jwt as get_validated_jwt
@@ -0,0 +1,151 @@
1
+ # Based on https://docs.snowflake.com/en/developer-guide/sql-api/authenticating
2
+
3
+ import time
4
+ import base64
5
+ import hashlib
6
+ import logging
7
+ from datetime import timedelta, timezone, datetime
8
+
9
+ from cryptography.hazmat.primitives.serialization import load_pem_private_key
10
+ from cryptography.hazmat.primitives.serialization import Encoding
11
+ from cryptography.hazmat.primitives.serialization import PublicFormat
12
+ from cryptography.hazmat.backends import default_backend
13
+ import jwt
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ ISSUER = "iss"
18
+ EXPIRE_TIME = "exp"
19
+ ISSUE_TIME = "iat"
20
+ SUBJECT = "sub"
21
+
22
+
23
+ class JWTGenerator(object):
24
+ """
25
+ Creates and signs a JWT with the specified private key file, username, and account identifier. The JWTGenerator keeps the
26
+ generated token and only regenerates the token if a specified period of time has passed.
27
+ """
28
+
29
+ LIFETIME = timedelta(minutes=60) # The tokens will have a 59 minute lifetime
30
+ ALGORITHM = "RS256" # Tokens will be generated using RSA with SHA256
31
+
32
+ def __init__(self, account: str, user: str, private_key: str, lifetime: timedelta = LIFETIME):
33
+ """
34
+ __init__ creates an object that generates JWTs for the specified user, account identifier, and private key.
35
+ :param account: Your Snowflake account identifier. See https://docs.snowflake.com/en/user-guide/admin-account-identifier.html. Note that if you are using the account locator, exclude any region information from the account locator.
36
+ :param user: The Snowflake username.
37
+ :param private_key: The private key file used for signing the JWTs.
38
+ :param lifetime: The number of minutes (as a timedelta) during which the key will be valid.
39
+ """
40
+
41
+ logger.info(
42
+ """Creating JWTGenerator with arguments
43
+ account : %s, user : %s, lifetime : %s""",
44
+ account,
45
+ user,
46
+ lifetime,
47
+ )
48
+
49
+ # Construct the fully qualified name of the user in uppercase.
50
+ self.account = self.prepare_account_name_for_jwt(account)
51
+ self.user = user.upper()
52
+ self.qualified_username = self.account + "." + self.user
53
+
54
+ self.lifetime = lifetime
55
+ self.renew_time = datetime.now(timezone.utc)
56
+ self.token = None
57
+
58
+ self.private_key = load_pem_private_key(private_key.encode(), None, default_backend())
59
+
60
+ def prepare_account_name_for_jwt(self, raw_account: str) -> str:
61
+ """
62
+ Prepare the account identifier for use in the JWT.
63
+ For the JWT, the account identifier must not include the subdomain or any region or cloud provider information.
64
+ :param raw_account: The specified account identifier.
65
+ :return: The account identifier in a form that can be used to generate JWT.
66
+ """
67
+ account = raw_account
68
+ if ".global" not in account:
69
+ # Handle the general case.
70
+ idx = account.find(".")
71
+ if idx > 0:
72
+ account = account[0:idx]
73
+ else:
74
+ # Handle the replication case.
75
+ idx = account.find("-")
76
+ if idx > 0:
77
+ account = account[0:idx]
78
+ # Use uppercase for the account identifier.
79
+ return account.upper()
80
+
81
+ def get_token(self) -> str:
82
+ """
83
+ Generates a new JWT.
84
+ :return: the new token
85
+ """
86
+ now = datetime.now(timezone.utc) # Fetch the current time
87
+
88
+ # Prepare the fields for the payload.
89
+ # Generate the public key fingerprint for the issuer in the payload.
90
+ public_key_fp = self.calculate_public_key_fingerprint(self.private_key)
91
+
92
+ # Create our payload
93
+ payload = {
94
+ # Set the issuer to the fully qualified username concatenated with the public key fingerprint.
95
+ ISSUER: self.qualified_username + "." + public_key_fp,
96
+ # Set the subject to the fully qualified username.
97
+ SUBJECT: self.qualified_username,
98
+ # Set the issue time to now.
99
+ ISSUE_TIME: now,
100
+ # Set the expiration time, based on the lifetime specified for this object.
101
+ EXPIRE_TIME: now + self.lifetime,
102
+ }
103
+
104
+ # Regenerate the actual token
105
+ token = jwt.encode(payload, key=self.private_key, algorithm=JWTGenerator.ALGORITHM)
106
+ # If you are using a version of PyJWT prior to 2.0, jwt.encode returns a byte string, rather than a string.
107
+ # If the token is a byte string, convert it to a string.
108
+ if isinstance(token, bytes):
109
+ token = token.decode("utf-8")
110
+ self.token = token
111
+
112
+ return self.token
113
+
114
+ def calculate_public_key_fingerprint(self, private_key: str) -> str:
115
+ """
116
+ Given a private key in PEM format, return the public key fingerprint.
117
+ :param private_key: private key string
118
+ :return: public key fingerprint
119
+ """
120
+ # Get the raw bytes of public key.
121
+ public_key_raw = private_key.public_key().public_bytes(Encoding.DER, PublicFormat.SubjectPublicKeyInfo)
122
+
123
+ # Get the sha256 hash of the raw bytes.
124
+ sha256hash = hashlib.sha256()
125
+ sha256hash.update(public_key_raw)
126
+
127
+ # Base64-encode the value and prepend the prefix 'SHA256:'.
128
+ public_key_fp = "SHA256:" + base64.b64encode(sha256hash.digest()).decode("utf-8")
129
+ logger.info("Public key fingerprint is %s", public_key_fp)
130
+
131
+ return public_key_fp
132
+
133
+
134
+ def get_validated_jwt(token: str, account: str, user: str, private_key: str) -> str:
135
+ try:
136
+ content = jwt.decode(token, algorithms=[JWTGenerator.ALGORITHM], options={"verify_signature": False})
137
+
138
+ expired = content.get("exp", 0)
139
+ # add 5 seconds before limit
140
+ if expired - 5 > time.time():
141
+ # keep the same
142
+ return token
143
+
144
+ except jwt.DecodeError:
145
+ # wrong key
146
+ ...
147
+
148
+ # generate new token
149
+ if private_key is None:
150
+ raise ValueError("Private key is missing")
151
+ return JWTGenerator(account, user, private_key).get_token()
@@ -1,17 +1,26 @@
1
1
  """Utility functions for RAG pipeline configuration"""
2
+
2
3
  from typing import Dict, Any, Optional
3
4
 
4
5
  from mindsdb.utilities.log import getLogger
5
6
  from mindsdb.integrations.utilities.rag.settings import (
6
- RetrieverType, MultiVectorRetrieverMode, SearchType,
7
- SearchKwargs, SummarizationConfig, VectorStoreConfig,
8
- RerankerConfig, RAGPipelineModel, DEFAULT_COLLECTION_NAME
7
+ RetrieverType,
8
+ MultiVectorRetrieverMode,
9
+ SearchType,
10
+ SearchKwargs,
11
+ SummarizationConfig,
12
+ VectorStoreConfig,
13
+ RerankerConfig,
14
+ RAGPipelineModel,
15
+ DEFAULT_COLLECTION_NAME,
9
16
  )
10
17
 
11
18
  logger = getLogger(__name__)
12
19
 
13
20
 
14
- def load_rag_config(base_config: Dict[str, Any], kb_params: Optional[Dict[str, Any]] = None, embedding_model: Any = None) -> RAGPipelineModel:
21
+ def load_rag_config(
22
+ base_config: Dict[str, Any], kb_params: Optional[Dict[str, Any]] = None, embedding_model: Any = None
23
+ ) -> RAGPipelineModel:
15
24
  """
16
25
  Load and validate RAG configuration parameters. This function handles the conversion of configuration
17
26
  parameters into their appropriate types and ensures all required settings are properly configured.
@@ -37,41 +46,43 @@ def load_rag_config(base_config: Dict[str, Any], kb_params: Optional[Dict[str, A
37
46
 
38
47
  # Set embedding model if provided
39
48
  if embedding_model is not None:
40
- rag_params['embedding_model'] = embedding_model
49
+ rag_params["embedding_model"] = embedding_model
41
50
 
42
51
  # Handle enums and type conversions
43
- if 'retriever_type' in rag_params:
44
- rag_params['retriever_type'] = RetrieverType(rag_params['retriever_type'])
45
- if 'multi_retriever_mode' in rag_params:
46
- rag_params['multi_retriever_mode'] = MultiVectorRetrieverMode(rag_params['multi_retriever_mode'])
47
- if 'search_type' in rag_params:
48
- rag_params['search_type'] = SearchType(rag_params['search_type'])
52
+ if "retriever_type" in rag_params:
53
+ rag_params["retriever_type"] = RetrieverType(rag_params["retriever_type"])
54
+ if "multi_retriever_mode" in rag_params:
55
+ rag_params["multi_retriever_mode"] = MultiVectorRetrieverMode(rag_params["multi_retriever_mode"])
56
+ if "search_type" in rag_params:
57
+ rag_params["search_type"] = SearchType(rag_params["search_type"])
49
58
 
50
59
  # Handle search kwargs if present
51
- if 'search_kwargs' in rag_params and isinstance(rag_params['search_kwargs'], dict):
52
- rag_params['search_kwargs'] = SearchKwargs(**rag_params['search_kwargs'])
60
+ if "search_kwargs" in rag_params and isinstance(rag_params["search_kwargs"], dict):
61
+ rag_params["search_kwargs"] = SearchKwargs(**rag_params["search_kwargs"])
53
62
 
54
63
  # Handle summarization config if present
55
- summarization_config = rag_params.get('summarization_config')
64
+ summarization_config = rag_params.get("summarization_config")
56
65
  if summarization_config is not None and isinstance(summarization_config, dict):
57
- rag_params['summarization_config'] = SummarizationConfig(**summarization_config)
66
+ rag_params["summarization_config"] = SummarizationConfig(**summarization_config)
58
67
 
59
68
  # Handle vector store config
60
- if 'vector_store_config' in rag_params:
61
- if isinstance(rag_params['vector_store_config'], dict):
62
- rag_params['vector_store_config'] = VectorStoreConfig(**rag_params['vector_store_config'])
69
+ if "vector_store_config" in rag_params:
70
+ if isinstance(rag_params["vector_store_config"], dict):
71
+ rag_params["vector_store_config"] = VectorStoreConfig(**rag_params["vector_store_config"])
63
72
  else:
64
- rag_params['vector_store_config'] = {}
65
- logger.warning(f'No collection_name specified for the retrieval tool, '
66
- f"using default collection_name: '{DEFAULT_COLLECTION_NAME}'"
67
- f'\nWarning: If this collection does not exist, no data will be retrieved')
73
+ rag_params["vector_store_config"] = {}
74
+ logger.warning(
75
+ f"No collection_name specified for the retrieval tool, "
76
+ f"using default collection_name: '{DEFAULT_COLLECTION_NAME}'"
77
+ f"\nWarning: If this collection does not exist, no data will be retrieved"
78
+ )
68
79
 
69
- if 'reranker_config' in rag_params:
70
- rag_params['reranker_config'] = RerankerConfig(**rag_params['reranker_config'])
80
+ if "reranker_config" in rag_params:
81
+ rag_params["reranker_config"] = RerankerConfig(**rag_params["reranker_config"])
71
82
 
72
83
  # Convert to RAGPipelineModel with validation
73
84
  try:
74
85
  return RAGPipelineModel(**rag_params)
75
86
  except Exception as e:
76
- logger.error(f"Invalid RAG configuration: {str(e)}")
77
- raise ValueError(f"Configuration validation failed: {str(e)}")
87
+ logger.exception("Invalid RAG configuration:")
88
+ raise ValueError(f"Configuration validation failed: {str(e)}") from e
@@ -7,18 +7,35 @@ import math
7
7
  import os
8
8
  import random
9
9
  from abc import ABC
10
- from textwrap import dedent
11
10
  from typing import Any, List, Optional, Tuple
12
11
 
13
12
  from openai import AsyncOpenAI, AsyncAzureOpenAI
14
13
  from pydantic import BaseModel
15
14
 
16
- from mindsdb.integrations.utilities.rag.settings import DEFAULT_RERANKING_MODEL, DEFAULT_LLM_ENDPOINT
15
+ from mindsdb.integrations.utilities.rag.settings import (
16
+ DEFAULT_RERANKING_MODEL,
17
+ DEFAULT_LLM_ENDPOINT,
18
+ DEFAULT_RERANKER_N,
19
+ DEFAULT_RERANKER_LOGPROBS,
20
+ DEFAULT_RERANKER_TOP_LOGPROBS,
21
+ DEFAULT_RERANKER_MAX_TOKENS,
22
+ DEFAULT_VALID_CLASS_TOKENS,
23
+ )
17
24
  from mindsdb.integrations.libs.base import BaseMLEngine
18
25
 
19
26
  log = logging.getLogger(__name__)
20
27
 
21
28
 
29
+ def get_event_loop():
30
+ try:
31
+ loop = asyncio.get_running_loop()
32
+ except RuntimeError:
33
+ # If no running loop exists, create a new one
34
+ loop = asyncio.new_event_loop()
35
+ asyncio.set_event_loop(loop)
36
+ return loop
37
+
38
+
22
39
  class BaseLLMReranker(BaseModel, ABC):
23
40
  filtering_threshold: float = 0.0 # Default threshold for filtering
24
41
  provider: str = "openai"
@@ -38,6 +55,11 @@ class BaseLLMReranker(BaseModel, ABC):
38
55
  request_timeout: float = 20.0 # Timeout for API requests
39
56
  early_stop: bool = True # Whether to enable early stopping
40
57
  early_stop_threshold: float = 0.8 # Confidence threshold for early stopping
58
+ n: int = DEFAULT_RERANKER_N # Number of completions to generate
59
+ logprobs: bool = DEFAULT_RERANKER_LOGPROBS # Whether to include log probabilities
60
+ top_logprobs: int = DEFAULT_RERANKER_TOP_LOGPROBS # Number of top log probabilities to include
61
+ max_tokens: int = DEFAULT_RERANKER_MAX_TOKENS # Maximum tokens to generate
62
+ valid_class_tokens: List[str] = DEFAULT_VALID_CLASS_TOKENS
41
63
 
42
64
  class Config:
43
65
  arbitrary_types_allowed = True
@@ -61,7 +83,12 @@ class BaseLLMReranker(BaseModel, ABC):
61
83
  timeout=self.request_timeout,
62
84
  max_retries=2,
63
85
  )
64
- elif self.provider == "openai":
86
+ elif self.provider in ("openai", "ollama"):
87
+ if self.provider == "ollama":
88
+ self.method = "no-logprobs"
89
+ if self.api_key is None:
90
+ self.api_key = "n/a"
91
+
65
92
  api_key_var: str = "OPENAI_API_KEY"
66
93
  openai_api_key = self.api_key or os.getenv(api_key_var)
67
94
  if not openai_api_key:
@@ -71,7 +98,6 @@ class BaseLLMReranker(BaseModel, ABC):
71
98
  self.client = AsyncOpenAI(
72
99
  api_key=openai_api_key, base_url=base_url, timeout=self.request_timeout, max_retries=2
73
100
  )
74
-
75
101
  else:
76
102
  # try to use litellm
77
103
  from mindsdb.api.executor.controllers.session_controller import SessionController
@@ -86,7 +112,7 @@ class BaseLLMReranker(BaseModel, ABC):
86
112
  self.method = "no-logprobs"
87
113
 
88
114
  async def _call_llm(self, messages):
89
- if self.provider in ("azure_openai", "openai"):
115
+ if self.provider in ("azure_openai", "openai", "ollama"):
90
116
  return await self.client.chat.completions.create(
91
117
  model=self.model,
92
118
  messages=messages,
@@ -121,7 +147,7 @@ class BaseLLMReranker(BaseModel, ABC):
121
147
  for idx, result in enumerate(results):
122
148
  if isinstance(result, Exception):
123
149
  log.error(f"Error processing document {i + idx}: {str(result)}")
124
- raise RuntimeError(f"Error during reranking: {result}")
150
+ raise RuntimeError(f"Error during reranking: {result}") from result
125
151
 
126
152
  score = result["relevance_score"]
127
153
 
@@ -142,7 +168,7 @@ class BaseLLMReranker(BaseModel, ABC):
142
168
  return ranked_results
143
169
  except Exception as e:
144
170
  # Don't let early stopping errors stop the whole process
145
- log.warning(f"Error in early stopping check: {str(e)}")
171
+ log.warning(f"Error in early stopping check: {e}")
146
172
 
147
173
  return ranked_results
148
174
 
@@ -204,13 +230,11 @@ class BaseLLMReranker(BaseModel, ABC):
204
230
  return rerank_data
205
231
 
206
232
  async def search_relevancy_no_logprob(self, query: str, document: str) -> Any:
207
- prompt = dedent(
208
- f"""
209
- Score the relevance between search query and user message on scale between 0 and 100 per cents.
210
- Consider semantic meaning, key concepts, and contextual relevance.
211
- Return ONLY a numerical score between 0 and 100 per cents. No other text. Stop after sending a number
212
- Search query: {query}
213
- """
233
+ prompt = (
234
+ f"Score the relevance between search query and user message on scale between 0 and 100 per cents. "
235
+ f"Consider semantic meaning, key concepts, and contextual relevance. "
236
+ f"Return ONLY a numerical score between 0 and 100 per cents. No other text. Stop after sending a number. "
237
+ f"Search query: {query}"
214
238
  )
215
239
 
216
240
  response = await self._call_llm(
@@ -234,6 +258,28 @@ class BaseLLMReranker(BaseModel, ABC):
234
258
  return rerank_data
235
259
 
236
260
  async def search_relevancy_score(self, query: str, document: str) -> Any:
261
+ """
262
+ This method is used to score the relevance of a document to a query.
263
+
264
+ Args:
265
+ query: The query to score the relevance of.
266
+ document: The document to score the relevance of.
267
+
268
+ Returns:
269
+ A dictionary with the document and the relevance score.
270
+ """
271
+
272
+ log.debug("Start search_relevancy_score")
273
+ log.debug(f"Reranker query: {query[:5]}")
274
+ log.debug(f"Reranker document: {document[:50]}")
275
+ log.debug(f"Reranker model: {self.model}")
276
+ log.debug(f"Reranker temperature: {self.temperature}")
277
+ log.debug(f"Reranker n: {self.n}")
278
+ log.debug(f"Reranker logprobs: {self.logprobs}")
279
+ log.debug(f"Reranker top_logprobs: {self.top_logprobs}")
280
+ log.debug(f"Reranker max_tokens: {self.max_tokens}")
281
+ log.debug(f"Reranker valid_class_tokens: {self.valid_class_tokens}")
282
+
237
283
  response = await self.client.chat.completions.create(
238
284
  model=self.model,
239
285
  messages=[
@@ -306,17 +352,30 @@ class BaseLLMReranker(BaseModel, ABC):
306
352
  },
307
353
  ],
308
354
  temperature=self.temperature,
309
- n=1,
310
- logprobs=True,
311
- top_logprobs=4,
312
- max_tokens=3,
355
+ n=self.n,
356
+ logprobs=self.logprobs,
357
+ top_logprobs=self.top_logprobs,
358
+ max_tokens=self.max_tokens,
313
359
  )
314
360
 
315
361
  # Extract response and logprobs
316
362
  token_logprobs = response.choices[0].logprobs.content
317
- # Reconstruct the prediction and extract the top logprobs from the final token (e.g., "1")
318
- final_token_logprob = token_logprobs[-1]
319
- top_logprobs = final_token_logprob.top_logprobs
363
+
364
+ # Find the token that contains the class number
365
+ # Instead of just taking the last token, search for the actual class number token
366
+ class_token_logprob = None
367
+ for token_logprob in reversed(token_logprobs):
368
+ if token_logprob.token in self.valid_class_tokens:
369
+ class_token_logprob = token_logprob
370
+ break
371
+
372
+ # If we couldn't find a class token, fall back to the last non-empty token
373
+ if class_token_logprob is None:
374
+ log.warning("No class token logprob found, using the last token as fallback")
375
+ class_token_logprob = token_logprobs[-1]
376
+
377
+ top_logprobs = class_token_logprob.top_logprobs
378
+
320
379
  # Create a map of 'class_1' -> probability, using token combinations
321
380
  class_probs = {}
322
381
  for top_token in top_logprobs:
@@ -337,21 +396,15 @@ class BaseLLMReranker(BaseModel, ABC):
337
396
  score = 0.0
338
397
 
339
398
  rerank_data = {"document": document, "relevance_score": score}
399
+ log.debug(f"Reranker score: {score}")
400
+ log.debug("End search_relevancy_score")
340
401
  return rerank_data
341
402
 
342
403
  def get_scores(self, query: str, documents: list[str]):
343
404
  query_document_pairs = [(query, doc) for doc in documents]
344
405
  # Create event loop and run async code
345
- import asyncio
346
-
347
- try:
348
- loop = asyncio.get_running_loop()
349
- except RuntimeError:
350
- # If no running loop exists, create a new one
351
- loop = asyncio.new_event_loop()
352
- asyncio.set_event_loop(loop)
353
406
 
354
- documents_and_scores = loop.run_until_complete(self._rank(query_document_pairs))
407
+ documents_and_scores = get_event_loop().run_until_complete(self._rank(query_document_pairs))
355
408
 
356
409
  scores = [score for _, score in documents_and_scores]
357
410
  return scores
@@ -36,7 +36,7 @@ class LLMReranker(BaseDocumentCompressor, BaseLLMReranker):
36
36
  return []
37
37
 
38
38
  # Stream reranking update.
39
- dispatch_custom_event('rerank_begin', {'num_documents': len(documents)})
39
+ dispatch_custom_event("rerank_begin", {"num_documents": len(documents)})
40
40
 
41
41
  try:
42
42
  # Prepare query-document pairs
@@ -73,10 +73,10 @@ class LLMReranker(BaseDocumentCompressor, BaseLLMReranker):
73
73
  return filtered_docs
74
74
 
75
75
  except Exception as e:
76
- error_msg = f"Error during async document compression: {str(e)}"
77
- log.error(error_msg)
76
+ error_msg = "Error during async document compression:"
77
+ log.exception(error_msg)
78
78
  if callbacks:
79
- await callbacks.on_retriever_error(error_msg)
79
+ await callbacks.on_retriever_error(f"{error_msg} {e}")
80
80
  return documents # Return original documents on error
81
81
 
82
82
  def compress_documents(