MindsDB 25.9.2.0a1__py3-none-any.whl → 25.10.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of MindsDB might be problematic. Click here for more details.
- mindsdb/__about__.py +1 -1
- mindsdb/__main__.py +40 -29
- mindsdb/api/a2a/__init__.py +1 -1
- mindsdb/api/a2a/agent.py +16 -10
- mindsdb/api/a2a/common/server/server.py +7 -3
- mindsdb/api/a2a/common/server/task_manager.py +12 -5
- mindsdb/api/a2a/common/types.py +66 -0
- mindsdb/api/a2a/task_manager.py +65 -17
- mindsdb/api/common/middleware.py +10 -12
- mindsdb/api/executor/command_executor.py +51 -40
- mindsdb/api/executor/datahub/datanodes/datanode.py +2 -2
- mindsdb/api/executor/datahub/datanodes/information_schema_datanode.py +7 -13
- mindsdb/api/executor/datahub/datanodes/integration_datanode.py +101 -49
- mindsdb/api/executor/datahub/datanodes/project_datanode.py +8 -4
- mindsdb/api/executor/datahub/datanodes/system_tables.py +3 -2
- mindsdb/api/executor/exceptions.py +29 -10
- mindsdb/api/executor/planner/plan_join.py +17 -3
- mindsdb/api/executor/planner/query_prepare.py +2 -20
- mindsdb/api/executor/sql_query/sql_query.py +74 -74
- mindsdb/api/executor/sql_query/steps/fetch_dataframe.py +1 -2
- mindsdb/api/executor/sql_query/steps/subselect_step.py +0 -1
- mindsdb/api/executor/utilities/functions.py +6 -6
- mindsdb/api/executor/utilities/sql.py +37 -20
- mindsdb/api/http/gui.py +5 -11
- mindsdb/api/http/initialize.py +75 -61
- mindsdb/api/http/namespaces/agents.py +10 -15
- mindsdb/api/http/namespaces/analysis.py +13 -20
- mindsdb/api/http/namespaces/auth.py +1 -1
- mindsdb/api/http/namespaces/chatbots.py +0 -5
- mindsdb/api/http/namespaces/config.py +15 -11
- mindsdb/api/http/namespaces/databases.py +140 -201
- mindsdb/api/http/namespaces/file.py +17 -4
- mindsdb/api/http/namespaces/handlers.py +17 -7
- mindsdb/api/http/namespaces/knowledge_bases.py +28 -7
- mindsdb/api/http/namespaces/models.py +94 -126
- mindsdb/api/http/namespaces/projects.py +13 -22
- mindsdb/api/http/namespaces/sql.py +33 -25
- mindsdb/api/http/namespaces/tab.py +27 -37
- mindsdb/api/http/namespaces/views.py +1 -1
- mindsdb/api/http/start.py +16 -10
- mindsdb/api/mcp/__init__.py +2 -1
- mindsdb/api/mysql/mysql_proxy/executor/mysql_executor.py +15 -20
- mindsdb/api/mysql/mysql_proxy/mysql_proxy.py +26 -50
- mindsdb/api/mysql/mysql_proxy/utilities/__init__.py +0 -1
- mindsdb/api/mysql/mysql_proxy/utilities/dump.py +8 -2
- mindsdb/integrations/handlers/byom_handler/byom_handler.py +165 -190
- mindsdb/integrations/handlers/databricks_handler/databricks_handler.py +98 -46
- mindsdb/integrations/handlers/druid_handler/druid_handler.py +32 -40
- mindsdb/integrations/handlers/file_handler/file_handler.py +7 -0
- mindsdb/integrations/handlers/gitlab_handler/gitlab_handler.py +5 -2
- mindsdb/integrations/handlers/lightwood_handler/functions.py +45 -79
- mindsdb/integrations/handlers/mssql_handler/mssql_handler.py +438 -100
- mindsdb/integrations/handlers/mssql_handler/requirements_odbc.txt +3 -0
- mindsdb/integrations/handlers/mysql_handler/mysql_handler.py +235 -3
- mindsdb/integrations/handlers/oracle_handler/__init__.py +2 -0
- mindsdb/integrations/handlers/oracle_handler/connection_args.py +7 -1
- mindsdb/integrations/handlers/oracle_handler/oracle_handler.py +321 -16
- mindsdb/integrations/handlers/oracle_handler/requirements.txt +1 -1
- mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +14 -2
- mindsdb/integrations/handlers/shopify_handler/shopify_handler.py +25 -12
- mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py +2 -1
- mindsdb/integrations/handlers/statsforecast_handler/requirements.txt +1 -0
- mindsdb/integrations/handlers/statsforecast_handler/requirements_extra.txt +1 -0
- mindsdb/integrations/handlers/web_handler/urlcrawl_helpers.py +4 -4
- mindsdb/integrations/handlers/zendesk_handler/zendesk_tables.py +144 -111
- mindsdb/integrations/libs/api_handler.py +10 -10
- mindsdb/integrations/libs/base.py +4 -4
- mindsdb/integrations/libs/llm/utils.py +2 -2
- mindsdb/integrations/libs/ml_handler_process/create_engine_process.py +4 -7
- mindsdb/integrations/libs/ml_handler_process/func_call_process.py +2 -7
- mindsdb/integrations/libs/ml_handler_process/learn_process.py +37 -47
- mindsdb/integrations/libs/ml_handler_process/update_engine_process.py +4 -7
- mindsdb/integrations/libs/ml_handler_process/update_process.py +2 -7
- mindsdb/integrations/libs/process_cache.py +132 -140
- mindsdb/integrations/libs/response.py +18 -12
- mindsdb/integrations/libs/vectordatabase_handler.py +26 -0
- mindsdb/integrations/utilities/files/file_reader.py +6 -7
- mindsdb/integrations/utilities/handlers/auth_utilities/snowflake/__init__.py +1 -0
- mindsdb/integrations/utilities/handlers/auth_utilities/snowflake/snowflake_jwt_gen.py +151 -0
- mindsdb/integrations/utilities/rag/config_loader.py +37 -26
- mindsdb/integrations/utilities/rag/rerankers/base_reranker.py +83 -30
- mindsdb/integrations/utilities/rag/rerankers/reranker_compressor.py +4 -4
- mindsdb/integrations/utilities/rag/retrievers/sql_retriever.py +55 -133
- mindsdb/integrations/utilities/rag/settings.py +58 -133
- mindsdb/integrations/utilities/rag/splitters/file_splitter.py +5 -15
- mindsdb/interfaces/agents/agents_controller.py +2 -3
- mindsdb/interfaces/agents/constants.py +0 -2
- mindsdb/interfaces/agents/litellm_server.py +34 -58
- mindsdb/interfaces/agents/mcp_client_agent.py +10 -10
- mindsdb/interfaces/agents/mindsdb_database_agent.py +5 -5
- mindsdb/interfaces/agents/run_mcp_agent.py +12 -21
- mindsdb/interfaces/chatbot/chatbot_task.py +20 -23
- mindsdb/interfaces/chatbot/polling.py +30 -18
- mindsdb/interfaces/data_catalog/data_catalog_loader.py +16 -17
- mindsdb/interfaces/data_catalog/data_catalog_reader.py +15 -4
- mindsdb/interfaces/database/data_handlers_cache.py +190 -0
- mindsdb/interfaces/database/database.py +3 -3
- mindsdb/interfaces/database/integrations.py +7 -110
- mindsdb/interfaces/database/projects.py +2 -6
- mindsdb/interfaces/database/views.py +1 -4
- mindsdb/interfaces/file/file_controller.py +6 -6
- mindsdb/interfaces/functions/controller.py +1 -1
- mindsdb/interfaces/functions/to_markdown.py +2 -2
- mindsdb/interfaces/jobs/jobs_controller.py +5 -9
- mindsdb/interfaces/jobs/scheduler.py +3 -9
- mindsdb/interfaces/knowledge_base/controller.py +244 -128
- mindsdb/interfaces/knowledge_base/evaluate.py +36 -41
- mindsdb/interfaces/knowledge_base/executor.py +11 -0
- mindsdb/interfaces/knowledge_base/llm_client.py +51 -17
- mindsdb/interfaces/knowledge_base/preprocessing/json_chunker.py +40 -61
- mindsdb/interfaces/model/model_controller.py +172 -168
- mindsdb/interfaces/query_context/context_controller.py +14 -2
- mindsdb/interfaces/skills/custom/text2sql/mindsdb_sql_toolkit.py +10 -14
- mindsdb/interfaces/skills/retrieval_tool.py +43 -50
- mindsdb/interfaces/skills/skill_tool.py +2 -2
- mindsdb/interfaces/skills/skills_controller.py +1 -4
- mindsdb/interfaces/skills/sql_agent.py +25 -19
- mindsdb/interfaces/storage/db.py +16 -6
- mindsdb/interfaces/storage/fs.py +114 -169
- mindsdb/interfaces/storage/json.py +19 -18
- mindsdb/interfaces/tabs/tabs_controller.py +49 -72
- mindsdb/interfaces/tasks/task_monitor.py +3 -9
- mindsdb/interfaces/tasks/task_thread.py +7 -9
- mindsdb/interfaces/triggers/trigger_task.py +7 -13
- mindsdb/interfaces/triggers/triggers_controller.py +47 -52
- mindsdb/migrations/migrate.py +16 -16
- mindsdb/utilities/api_status.py +58 -0
- mindsdb/utilities/config.py +68 -2
- mindsdb/utilities/exception.py +40 -1
- mindsdb/utilities/fs.py +0 -1
- mindsdb/utilities/hooks/profiling.py +17 -14
- mindsdb/utilities/json_encoder.py +24 -10
- mindsdb/utilities/langfuse.py +40 -45
- mindsdb/utilities/log.py +272 -0
- mindsdb/utilities/ml_task_queue/consumer.py +52 -58
- mindsdb/utilities/ml_task_queue/producer.py +26 -30
- mindsdb/utilities/render/sqlalchemy_render.py +22 -20
- mindsdb/utilities/starters.py +0 -10
- mindsdb/utilities/utils.py +2 -2
- {mindsdb-25.9.2.0a1.dist-info → mindsdb-25.10.0rc1.dist-info}/METADATA +293 -276
- {mindsdb-25.9.2.0a1.dist-info → mindsdb-25.10.0rc1.dist-info}/RECORD +144 -158
- mindsdb/api/mysql/mysql_proxy/utilities/exceptions.py +0 -14
- mindsdb/api/postgres/__init__.py +0 -0
- mindsdb/api/postgres/postgres_proxy/__init__.py +0 -0
- mindsdb/api/postgres/postgres_proxy/executor/__init__.py +0 -1
- mindsdb/api/postgres/postgres_proxy/executor/executor.py +0 -189
- mindsdb/api/postgres/postgres_proxy/postgres_packets/__init__.py +0 -0
- mindsdb/api/postgres/postgres_proxy/postgres_packets/errors.py +0 -322
- mindsdb/api/postgres/postgres_proxy/postgres_packets/postgres_fields.py +0 -34
- mindsdb/api/postgres/postgres_proxy/postgres_packets/postgres_message.py +0 -31
- mindsdb/api/postgres/postgres_proxy/postgres_packets/postgres_message_formats.py +0 -1265
- mindsdb/api/postgres/postgres_proxy/postgres_packets/postgres_message_identifiers.py +0 -31
- mindsdb/api/postgres/postgres_proxy/postgres_packets/postgres_packets.py +0 -253
- mindsdb/api/postgres/postgres_proxy/postgres_proxy.py +0 -477
- mindsdb/api/postgres/postgres_proxy/utilities/__init__.py +0 -10
- mindsdb/api/postgres/start.py +0 -11
- mindsdb/integrations/handlers/mssql_handler/tests/__init__.py +0 -0
- mindsdb/integrations/handlers/mssql_handler/tests/test_mssql_handler.py +0 -169
- mindsdb/integrations/handlers/oracle_handler/tests/__init__.py +0 -0
- mindsdb/integrations/handlers/oracle_handler/tests/test_oracle_handler.py +0 -32
- {mindsdb-25.9.2.0a1.dist-info → mindsdb-25.10.0rc1.dist-info}/WHEEL +0 -0
- {mindsdb-25.9.2.0a1.dist-info → mindsdb-25.10.0rc1.dist-info}/licenses/LICENSE +0 -0
- {mindsdb-25.9.2.0a1.dist-info → mindsdb-25.10.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -1,17 +1,16 @@
|
|
|
1
|
-
from dataclasses import dataclass, astuple
|
|
2
|
-
import traceback
|
|
3
|
-
import json
|
|
4
1
|
import csv
|
|
5
|
-
|
|
6
|
-
from pathlib import Path
|
|
2
|
+
import json
|
|
7
3
|
import codecs
|
|
4
|
+
from io import BytesIO, StringIO, IOBase
|
|
8
5
|
from typing import List, Generator
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from dataclasses import dataclass, astuple
|
|
9
8
|
|
|
10
9
|
import filetype
|
|
11
10
|
import pandas as pd
|
|
12
11
|
from charset_normalizer import from_bytes
|
|
13
|
-
from mindsdb.interfaces.knowledge_base.preprocessing.text_splitter import TextSplitter
|
|
14
12
|
|
|
13
|
+
from mindsdb.interfaces.knowledge_base.preprocessing.text_splitter import TextSplitter
|
|
15
14
|
from mindsdb.utilities import log
|
|
16
15
|
|
|
17
16
|
logger = log.getLogger(__name__)
|
|
@@ -76,7 +75,7 @@ def decode(file_obj: IOBase) -> StringIO:
|
|
|
76
75
|
|
|
77
76
|
data_str = StringIO(byte_str.decode(encoding, errors))
|
|
78
77
|
except Exception as e:
|
|
79
|
-
logger.
|
|
78
|
+
logger.exception("Error during file decode:")
|
|
80
79
|
raise FileProcessingError("Could not load into string") from e
|
|
81
80
|
|
|
82
81
|
return data_str
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .snowflake_jwt_gen import get_validated_jwt as get_validated_jwt
|
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
# Based on https://docs.snowflake.com/en/developer-guide/sql-api/authenticating
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
import base64
|
|
5
|
+
import hashlib
|
|
6
|
+
import logging
|
|
7
|
+
from datetime import timedelta, timezone, datetime
|
|
8
|
+
|
|
9
|
+
from cryptography.hazmat.primitives.serialization import load_pem_private_key
|
|
10
|
+
from cryptography.hazmat.primitives.serialization import Encoding
|
|
11
|
+
from cryptography.hazmat.primitives.serialization import PublicFormat
|
|
12
|
+
from cryptography.hazmat.backends import default_backend
|
|
13
|
+
import jwt
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)

# Registered JWT claim names (RFC 7519, section 4.1) used when building the token payload.
ISSUER = "iss"  # who issued the token: qualified username + public key fingerprint
SUBJECT = "sub"  # the fully qualified Snowflake username
ISSUE_TIME = "iat"  # when the token was issued
EXPIRE_TIME = "exp"  # when the token stops being valid
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class JWTGenerator:
    """
    Create and sign JWTs for Snowflake key-pair authentication.

    Tokens are built from a PEM private key, a username, and an account identifier. Every call to
    ``get_token`` signs a brand-new token. The most recent one is stored in ``self.token``, but this
    class never hands out a previously generated token again. To reuse a token until it is about
    to expire, go through ``get_validated_jwt`` instead.
    """

    LIFETIME = timedelta(minutes=60)  # Tokens are valid for 60 minutes after they are issued
    ALGORITHM = "RS256"  # Tokens are signed using RSA with SHA-256

    def __init__(self, account: str, user: str, private_key: str, lifetime: timedelta = LIFETIME):
        """
        Create an object that generates JWTs for the given user, account identifier and private key.

        :param account: Your Snowflake account identifier. See
            https://docs.snowflake.com/en/user-guide/admin-account-identifier.html. Note that if you
            are using the account locator, exclude any region information from the account locator.
        :param user: The Snowflake username.
        :param private_key: The private key as PEM-formatted text. No password is passed when loading
            it, so the key must not be passphrase-protected.
        :param lifetime: How long (as a timedelta) each generated token stays valid.

        Errors raised by ``load_pem_private_key`` while parsing the key are not caught here.
        """

        logger.info(
            """Creating JWTGenerator with arguments
            account : %s, user : %s, lifetime : %s""",
            account,
            user,
            lifetime,
        )

        # Construct the fully qualified name of the user in uppercase.
        self.account = self.prepare_account_name_for_jwt(account)
        self.user = user.upper()
        self.qualified_username = f"{self.account}.{self.user}"

        self.lifetime = lifetime
        # Kept for compatibility only: nothing in this class reads it, because get_token() always
        # signs a new token instead of checking whether a renewal is due.
        self.renew_time = datetime.now(timezone.utc)
        self.token = None

        # Parsed key object (not text); it is used both for signing and for the fingerprint.
        self.private_key = load_pem_private_key(private_key.encode(), None, default_backend())

    def prepare_account_name_for_jwt(self, raw_account: str) -> str:
        """
        Prepare the account identifier for use in the JWT.

        For the JWT, the account identifier must not include the subdomain or any region or cloud
        provider information.

        :param raw_account: The specified account identifier.
        :return: The account identifier, upper-cased, in a form that can be used to generate a JWT.
        """
        account = raw_account
        if ".global" not in account:
            # General case: keep everything before the first "." (e.g. drop ".us-east-2.aws").
            idx = account.find(".")
            if idx > 0:
                account = account[0:idx]
        else:
            # Replication case: keep everything before the first "-".
            idx = account.find("-")
            if idx > 0:
                account = account[0:idx]
        # Use uppercase for the account identifier.
        return account.upper()

    def get_token(self) -> str:
        """
        Sign and return a new JWT.

        A fresh token is generated on every call. It is also stored in ``self.token``.

        :return: the new token as a ``str``
        """
        now = datetime.now(timezone.utc)  # Fetch the current time

        # The issuer claim embeds the public key fingerprint, so compute it first.
        public_key_fp = self.calculate_public_key_fingerprint(self.private_key)

        payload = {
            # Issuer: the fully qualified username concatenated with the public key fingerprint.
            ISSUER: f"{self.qualified_username}.{public_key_fp}",
            # Subject: the fully qualified username.
            SUBJECT: self.qualified_username,
            # Issue time: now.
            ISSUE_TIME: now,
            # Expiration time: now plus the lifetime configured for this object.
            EXPIRE_TIME: now + self.lifetime,
        }

        token = jwt.encode(payload, key=self.private_key, algorithm=JWTGenerator.ALGORITHM)
        # PyJWT versions prior to 2.0 return bytes from jwt.encode rather than str; normalize to str.
        if isinstance(token, bytes):
            token = token.decode("utf-8")
        self.token = token

        return self.token

    def calculate_public_key_fingerprint(self, private_key) -> str:
        """
        Return the public key fingerprint for a loaded private key.

        :param private_key: a loaded private key object, such as ``self.private_key``, that exposes
            ``public_key()``. It is not PEM text.
        :return: the fingerprint in the form ``"SHA256:<base64 of the SHA-256 digest>"``, computed over
            the DER-encoded SubjectPublicKeyInfo of the public key
        """
        # Get the raw DER bytes of the public key.
        public_key_raw = private_key.public_key().public_bytes(Encoding.DER, PublicFormat.SubjectPublicKeyInfo)

        # Hash the raw bytes with SHA-256.
        sha256hash = hashlib.sha256()
        sha256hash.update(public_key_raw)

        # Base64-encode the digest and prepend the prefix 'SHA256:'.
        public_key_fp = "SHA256:" + base64.b64encode(sha256hash.digest()).decode("utf-8")
        logger.info("Public key fingerprint is %s", public_key_fp)

        return public_key_fp
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def get_validated_jwt(token: str, account: str, user: str, private_key: str) -> str:
    """
    Return ``token`` if it is still usable, otherwise sign and return a new Snowflake JWT.

    The token's signature is NOT verified here; it is only decoded to read its ``exp`` claim.
    The token is reused when it expires more than 5 seconds from now. If it is empty, malformed,
    expired, or lacks a numeric ``exp`` claim, a new token is generated from ``account``, ``user``
    and ``private_key``.

    NOTE(review): the reused token is not compared against ``account``/``user``. A still-valid token
    issued for different credentials would be returned unchanged.

    :param token: a previously issued token, or an empty/None value to force regeneration
    :param account: Snowflake account identifier used if a new token is needed
    :param user: Snowflake username used if a new token is needed
    :param private_key: PEM private key text used if a new token is needed
    :return: a JWT string
    :raises ValueError: if a new token is needed but ``private_key`` is None
    """
    if token:
        try:
            # Only the expiry claim is needed. Expiry verification is disabled explicitly because
            # older PyJWT releases check "exp" even when signature verification is off, and an
            # expired token must fall through to regeneration rather than raise.
            content = jwt.decode(
                token,
                algorithms=[JWTGenerator.ALGORITHM],
                options={"verify_signature": False, "verify_exp": False},
            )
        except jwt.InvalidTokenError:
            # Malformed or otherwise undecodable token (DecodeError is a subclass): regenerate below.
            content = None

        expires_at = content.get("exp", 0) if isinstance(content, dict) else 0
        # Keep the same token only if it stays valid for at least 5 more seconds.
        if isinstance(expires_at, (int, float)) and expires_at - 5 > time.time():
            return token

    # Generate a new token.
    if private_key is None:
        raise ValueError("Private key is missing")
    return JWTGenerator(account, user, private_key).get_token()
|
|
@@ -1,17 +1,26 @@
|
|
|
1
1
|
"""Utility functions for RAG pipeline configuration"""
|
|
2
|
+
|
|
2
3
|
from typing import Dict, Any, Optional
|
|
3
4
|
|
|
4
5
|
from mindsdb.utilities.log import getLogger
|
|
5
6
|
from mindsdb.integrations.utilities.rag.settings import (
|
|
6
|
-
RetrieverType,
|
|
7
|
-
|
|
8
|
-
|
|
7
|
+
RetrieverType,
|
|
8
|
+
MultiVectorRetrieverMode,
|
|
9
|
+
SearchType,
|
|
10
|
+
SearchKwargs,
|
|
11
|
+
SummarizationConfig,
|
|
12
|
+
VectorStoreConfig,
|
|
13
|
+
RerankerConfig,
|
|
14
|
+
RAGPipelineModel,
|
|
15
|
+
DEFAULT_COLLECTION_NAME,
|
|
9
16
|
)
|
|
10
17
|
|
|
11
18
|
logger = getLogger(__name__)
|
|
12
19
|
|
|
13
20
|
|
|
14
|
-
def load_rag_config(
|
|
21
|
+
def load_rag_config(
|
|
22
|
+
base_config: Dict[str, Any], kb_params: Optional[Dict[str, Any]] = None, embedding_model: Any = None
|
|
23
|
+
) -> RAGPipelineModel:
|
|
15
24
|
"""
|
|
16
25
|
Load and validate RAG configuration parameters. This function handles the conversion of configuration
|
|
17
26
|
parameters into their appropriate types and ensures all required settings are properly configured.
|
|
@@ -37,41 +46,43 @@ def load_rag_config(base_config: Dict[str, Any], kb_params: Optional[Dict[str, A
|
|
|
37
46
|
|
|
38
47
|
# Set embedding model if provided
|
|
39
48
|
if embedding_model is not None:
|
|
40
|
-
rag_params[
|
|
49
|
+
rag_params["embedding_model"] = embedding_model
|
|
41
50
|
|
|
42
51
|
# Handle enums and type conversions
|
|
43
|
-
if
|
|
44
|
-
rag_params[
|
|
45
|
-
if
|
|
46
|
-
rag_params[
|
|
47
|
-
if
|
|
48
|
-
rag_params[
|
|
52
|
+
if "retriever_type" in rag_params:
|
|
53
|
+
rag_params["retriever_type"] = RetrieverType(rag_params["retriever_type"])
|
|
54
|
+
if "multi_retriever_mode" in rag_params:
|
|
55
|
+
rag_params["multi_retriever_mode"] = MultiVectorRetrieverMode(rag_params["multi_retriever_mode"])
|
|
56
|
+
if "search_type" in rag_params:
|
|
57
|
+
rag_params["search_type"] = SearchType(rag_params["search_type"])
|
|
49
58
|
|
|
50
59
|
# Handle search kwargs if present
|
|
51
|
-
if
|
|
52
|
-
rag_params[
|
|
60
|
+
if "search_kwargs" in rag_params and isinstance(rag_params["search_kwargs"], dict):
|
|
61
|
+
rag_params["search_kwargs"] = SearchKwargs(**rag_params["search_kwargs"])
|
|
53
62
|
|
|
54
63
|
# Handle summarization config if present
|
|
55
|
-
summarization_config = rag_params.get(
|
|
64
|
+
summarization_config = rag_params.get("summarization_config")
|
|
56
65
|
if summarization_config is not None and isinstance(summarization_config, dict):
|
|
57
|
-
rag_params[
|
|
66
|
+
rag_params["summarization_config"] = SummarizationConfig(**summarization_config)
|
|
58
67
|
|
|
59
68
|
# Handle vector store config
|
|
60
|
-
if
|
|
61
|
-
if isinstance(rag_params[
|
|
62
|
-
rag_params[
|
|
69
|
+
if "vector_store_config" in rag_params:
|
|
70
|
+
if isinstance(rag_params["vector_store_config"], dict):
|
|
71
|
+
rag_params["vector_store_config"] = VectorStoreConfig(**rag_params["vector_store_config"])
|
|
63
72
|
else:
|
|
64
|
-
rag_params[
|
|
65
|
-
logger.warning(
|
|
66
|
-
|
|
67
|
-
|
|
73
|
+
rag_params["vector_store_config"] = {}
|
|
74
|
+
logger.warning(
|
|
75
|
+
f"No collection_name specified for the retrieval tool, "
|
|
76
|
+
f"using default collection_name: '{DEFAULT_COLLECTION_NAME}'"
|
|
77
|
+
f"\nWarning: If this collection does not exist, no data will be retrieved"
|
|
78
|
+
)
|
|
68
79
|
|
|
69
|
-
if
|
|
70
|
-
rag_params[
|
|
80
|
+
if "reranker_config" in rag_params:
|
|
81
|
+
rag_params["reranker_config"] = RerankerConfig(**rag_params["reranker_config"])
|
|
71
82
|
|
|
72
83
|
# Convert to RAGPipelineModel with validation
|
|
73
84
|
try:
|
|
74
85
|
return RAGPipelineModel(**rag_params)
|
|
75
86
|
except Exception as e:
|
|
76
|
-
logger.
|
|
77
|
-
raise ValueError(f"Configuration validation failed: {str(e)}")
|
|
87
|
+
logger.exception("Invalid RAG configuration:")
|
|
88
|
+
raise ValueError(f"Configuration validation failed: {str(e)}") from e
|
|
@@ -7,18 +7,35 @@ import math
|
|
|
7
7
|
import os
|
|
8
8
|
import random
|
|
9
9
|
from abc import ABC
|
|
10
|
-
from textwrap import dedent
|
|
11
10
|
from typing import Any, List, Optional, Tuple
|
|
12
11
|
|
|
13
12
|
from openai import AsyncOpenAI, AsyncAzureOpenAI
|
|
14
13
|
from pydantic import BaseModel
|
|
15
14
|
|
|
16
|
-
from mindsdb.integrations.utilities.rag.settings import
|
|
15
|
+
from mindsdb.integrations.utilities.rag.settings import (
|
|
16
|
+
DEFAULT_RERANKING_MODEL,
|
|
17
|
+
DEFAULT_LLM_ENDPOINT,
|
|
18
|
+
DEFAULT_RERANKER_N,
|
|
19
|
+
DEFAULT_RERANKER_LOGPROBS,
|
|
20
|
+
DEFAULT_RERANKER_TOP_LOGPROBS,
|
|
21
|
+
DEFAULT_RERANKER_MAX_TOKENS,
|
|
22
|
+
DEFAULT_VALID_CLASS_TOKENS,
|
|
23
|
+
)
|
|
17
24
|
from mindsdb.integrations.libs.base import BaseMLEngine
|
|
18
25
|
|
|
19
26
|
log = logging.getLogger(__name__)
|
|
20
27
|
|
|
21
28
|
|
|
29
|
+
def get_event_loop():
    """
    Return an event loop for the current thread.

    If a loop is already running in this thread, that loop is returned. Callers must not call
    ``run_until_complete`` on it, because a running loop cannot be re-entered.

    Otherwise, the loop created by an earlier call in the same thread is reused as long as it has
    not been closed. A new loop is created only when there is no such loop. In both cases the
    returned loop is installed as the thread's current loop with ``asyncio.set_event_loop``.

    Reusing the loop, instead of creating a fresh one on every call, avoids piling up unclosed
    loops. Objects bound to that loop, such as async HTTP connection pools held by a reused client,
    presumably stay usable across calls as a result.
    """
    try:
        return asyncio.get_running_loop()
    except RuntimeError:
        pass  # No loop is running in this thread.

    import threading

    # Per-thread storage for the loop this function created. dict.setdefault keeps the first
    # threading.local() attached to the function even if several threads get here at once.
    state = get_event_loop.__dict__.setdefault("_thread_state", threading.local())
    loop = getattr(state, "loop", None)
    if loop is None or loop.is_closed():
        loop = asyncio.new_event_loop()
        state.loop = loop
    asyncio.set_event_loop(loop)
    return loop
|
|
37
|
+
|
|
38
|
+
|
|
22
39
|
class BaseLLMReranker(BaseModel, ABC):
|
|
23
40
|
filtering_threshold: float = 0.0 # Default threshold for filtering
|
|
24
41
|
provider: str = "openai"
|
|
@@ -38,6 +55,11 @@ class BaseLLMReranker(BaseModel, ABC):
|
|
|
38
55
|
request_timeout: float = 20.0 # Timeout for API requests
|
|
39
56
|
early_stop: bool = True # Whether to enable early stopping
|
|
40
57
|
early_stop_threshold: float = 0.8 # Confidence threshold for early stopping
|
|
58
|
+
n: int = DEFAULT_RERANKER_N # Number of completions to generate
|
|
59
|
+
logprobs: bool = DEFAULT_RERANKER_LOGPROBS # Whether to include log probabilities
|
|
60
|
+
top_logprobs: int = DEFAULT_RERANKER_TOP_LOGPROBS # Number of top log probabilities to include
|
|
61
|
+
max_tokens: int = DEFAULT_RERANKER_MAX_TOKENS # Maximum tokens to generate
|
|
62
|
+
valid_class_tokens: List[str] = DEFAULT_VALID_CLASS_TOKENS
|
|
41
63
|
|
|
42
64
|
class Config:
|
|
43
65
|
arbitrary_types_allowed = True
|
|
@@ -61,7 +83,12 @@ class BaseLLMReranker(BaseModel, ABC):
|
|
|
61
83
|
timeout=self.request_timeout,
|
|
62
84
|
max_retries=2,
|
|
63
85
|
)
|
|
64
|
-
elif self.provider
|
|
86
|
+
elif self.provider in ("openai", "ollama"):
|
|
87
|
+
if self.provider == "ollama":
|
|
88
|
+
self.method = "no-logprobs"
|
|
89
|
+
if self.api_key is None:
|
|
90
|
+
self.api_key = "n/a"
|
|
91
|
+
|
|
65
92
|
api_key_var: str = "OPENAI_API_KEY"
|
|
66
93
|
openai_api_key = self.api_key or os.getenv(api_key_var)
|
|
67
94
|
if not openai_api_key:
|
|
@@ -71,7 +98,6 @@ class BaseLLMReranker(BaseModel, ABC):
|
|
|
71
98
|
self.client = AsyncOpenAI(
|
|
72
99
|
api_key=openai_api_key, base_url=base_url, timeout=self.request_timeout, max_retries=2
|
|
73
100
|
)
|
|
74
|
-
|
|
75
101
|
else:
|
|
76
102
|
# try to use litellm
|
|
77
103
|
from mindsdb.api.executor.controllers.session_controller import SessionController
|
|
@@ -86,7 +112,7 @@ class BaseLLMReranker(BaseModel, ABC):
|
|
|
86
112
|
self.method = "no-logprobs"
|
|
87
113
|
|
|
88
114
|
async def _call_llm(self, messages):
|
|
89
|
-
if self.provider in ("azure_openai", "openai"):
|
|
115
|
+
if self.provider in ("azure_openai", "openai", "ollama"):
|
|
90
116
|
return await self.client.chat.completions.create(
|
|
91
117
|
model=self.model,
|
|
92
118
|
messages=messages,
|
|
@@ -121,7 +147,7 @@ class BaseLLMReranker(BaseModel, ABC):
|
|
|
121
147
|
for idx, result in enumerate(results):
|
|
122
148
|
if isinstance(result, Exception):
|
|
123
149
|
log.error(f"Error processing document {i + idx}: {str(result)}")
|
|
124
|
-
raise RuntimeError(f"Error during reranking: {result}")
|
|
150
|
+
raise RuntimeError(f"Error during reranking: {result}") from result
|
|
125
151
|
|
|
126
152
|
score = result["relevance_score"]
|
|
127
153
|
|
|
@@ -142,7 +168,7 @@ class BaseLLMReranker(BaseModel, ABC):
|
|
|
142
168
|
return ranked_results
|
|
143
169
|
except Exception as e:
|
|
144
170
|
# Don't let early stopping errors stop the whole process
|
|
145
|
-
log.warning(f"Error in early stopping check: {
|
|
171
|
+
log.warning(f"Error in early stopping check: {e}")
|
|
146
172
|
|
|
147
173
|
return ranked_results
|
|
148
174
|
|
|
@@ -204,13 +230,11 @@ class BaseLLMReranker(BaseModel, ABC):
|
|
|
204
230
|
return rerank_data
|
|
205
231
|
|
|
206
232
|
async def search_relevancy_no_logprob(self, query: str, document: str) -> Any:
|
|
207
|
-
prompt =
|
|
208
|
-
f""
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
Search query: {query}
|
|
213
|
-
"""
|
|
233
|
+
prompt = (
|
|
234
|
+
f"Score the relevance between search query and user message on scale between 0 and 100 per cents. "
|
|
235
|
+
f"Consider semantic meaning, key concepts, and contextual relevance. "
|
|
236
|
+
f"Return ONLY a numerical score between 0 and 100 per cents. No other text. Stop after sending a number. "
|
|
237
|
+
f"Search query: {query}"
|
|
214
238
|
)
|
|
215
239
|
|
|
216
240
|
response = await self._call_llm(
|
|
@@ -234,6 +258,28 @@ class BaseLLMReranker(BaseModel, ABC):
|
|
|
234
258
|
return rerank_data
|
|
235
259
|
|
|
236
260
|
async def search_relevancy_score(self, query: str, document: str) -> Any:
|
|
261
|
+
"""
|
|
262
|
+
This method is used to score the relevance of a document to a query.
|
|
263
|
+
|
|
264
|
+
Args:
|
|
265
|
+
query: The query to score the relevance of.
|
|
266
|
+
document: The document to score the relevance of.
|
|
267
|
+
|
|
268
|
+
Returns:
|
|
269
|
+
A dictionary with the document and the relevance score.
|
|
270
|
+
"""
|
|
271
|
+
|
|
272
|
+
log.debug("Start search_relevancy_score")
|
|
273
|
+
log.debug(f"Reranker query: {query[:5]}")
|
|
274
|
+
log.debug(f"Reranker document: {document[:50]}")
|
|
275
|
+
log.debug(f"Reranker model: {self.model}")
|
|
276
|
+
log.debug(f"Reranker temperature: {self.temperature}")
|
|
277
|
+
log.debug(f"Reranker n: {self.n}")
|
|
278
|
+
log.debug(f"Reranker logprobs: {self.logprobs}")
|
|
279
|
+
log.debug(f"Reranker top_logprobs: {self.top_logprobs}")
|
|
280
|
+
log.debug(f"Reranker max_tokens: {self.max_tokens}")
|
|
281
|
+
log.debug(f"Reranker valid_class_tokens: {self.valid_class_tokens}")
|
|
282
|
+
|
|
237
283
|
response = await self.client.chat.completions.create(
|
|
238
284
|
model=self.model,
|
|
239
285
|
messages=[
|
|
@@ -306,17 +352,30 @@ class BaseLLMReranker(BaseModel, ABC):
|
|
|
306
352
|
},
|
|
307
353
|
],
|
|
308
354
|
temperature=self.temperature,
|
|
309
|
-
n=
|
|
310
|
-
logprobs=
|
|
311
|
-
top_logprobs=
|
|
312
|
-
max_tokens=
|
|
355
|
+
n=self.n,
|
|
356
|
+
logprobs=self.logprobs,
|
|
357
|
+
top_logprobs=self.top_logprobs,
|
|
358
|
+
max_tokens=self.max_tokens,
|
|
313
359
|
)
|
|
314
360
|
|
|
315
361
|
# Extract response and logprobs
|
|
316
362
|
token_logprobs = response.choices[0].logprobs.content
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
363
|
+
|
|
364
|
+
# Find the token that contains the class number
|
|
365
|
+
# Instead of just taking the last token, search for the actual class number token
|
|
366
|
+
class_token_logprob = None
|
|
367
|
+
for token_logprob in reversed(token_logprobs):
|
|
368
|
+
if token_logprob.token in self.valid_class_tokens:
|
|
369
|
+
class_token_logprob = token_logprob
|
|
370
|
+
break
|
|
371
|
+
|
|
372
|
+
# If we couldn't find a class token, fall back to the last non-empty token
|
|
373
|
+
if class_token_logprob is None:
|
|
374
|
+
log.warning("No class token logprob found, using the last token as fallback")
|
|
375
|
+
class_token_logprob = token_logprobs[-1]
|
|
376
|
+
|
|
377
|
+
top_logprobs = class_token_logprob.top_logprobs
|
|
378
|
+
|
|
320
379
|
# Create a map of 'class_1' -> probability, using token combinations
|
|
321
380
|
class_probs = {}
|
|
322
381
|
for top_token in top_logprobs:
|
|
@@ -337,21 +396,15 @@ class BaseLLMReranker(BaseModel, ABC):
|
|
|
337
396
|
score = 0.0
|
|
338
397
|
|
|
339
398
|
rerank_data = {"document": document, "relevance_score": score}
|
|
399
|
+
log.debug(f"Reranker score: {score}")
|
|
400
|
+
log.debug("End search_relevancy_score")
|
|
340
401
|
return rerank_data
|
|
341
402
|
|
|
342
403
|
def get_scores(self, query: str, documents: list[str]):
|
|
343
404
|
query_document_pairs = [(query, doc) for doc in documents]
|
|
344
405
|
# Create event loop and run async code
|
|
345
|
-
import asyncio
|
|
346
|
-
|
|
347
|
-
try:
|
|
348
|
-
loop = asyncio.get_running_loop()
|
|
349
|
-
except RuntimeError:
|
|
350
|
-
# If no running loop exists, create a new one
|
|
351
|
-
loop = asyncio.new_event_loop()
|
|
352
|
-
asyncio.set_event_loop(loop)
|
|
353
406
|
|
|
354
|
-
documents_and_scores =
|
|
407
|
+
documents_and_scores = get_event_loop().run_until_complete(self._rank(query_document_pairs))
|
|
355
408
|
|
|
356
409
|
scores = [score for _, score in documents_and_scores]
|
|
357
410
|
return scores
|
|
@@ -36,7 +36,7 @@ class LLMReranker(BaseDocumentCompressor, BaseLLMReranker):
|
|
|
36
36
|
return []
|
|
37
37
|
|
|
38
38
|
# Stream reranking update.
|
|
39
|
-
dispatch_custom_event(
|
|
39
|
+
dispatch_custom_event("rerank_begin", {"num_documents": len(documents)})
|
|
40
40
|
|
|
41
41
|
try:
|
|
42
42
|
# Prepare query-document pairs
|
|
@@ -73,10 +73,10 @@ class LLMReranker(BaseDocumentCompressor, BaseLLMReranker):
|
|
|
73
73
|
return filtered_docs
|
|
74
74
|
|
|
75
75
|
except Exception as e:
|
|
76
|
-
error_msg =
|
|
77
|
-
log.
|
|
76
|
+
error_msg = "Error during async document compression:"
|
|
77
|
+
log.exception(error_msg)
|
|
78
78
|
if callbacks:
|
|
79
|
-
await callbacks.on_retriever_error(error_msg)
|
|
79
|
+
await callbacks.on_retriever_error(f"{error_msg} {e}")
|
|
80
80
|
return documents # Return original documents on error
|
|
81
81
|
|
|
82
82
|
def compress_documents(
|