MindsDB 25.4.3.2__py3-none-any.whl → 25.4.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of MindsDB might be problematic.
- mindsdb/__about__.py +1 -1
- mindsdb/__main__.py +18 -4
- mindsdb/api/executor/command_executor.py +12 -2
- mindsdb/api/executor/data_types/response_type.py +1 -0
- mindsdb/api/executor/datahub/classes/tables_row.py +3 -10
- mindsdb/api/executor/datahub/datanodes/datanode.py +7 -2
- mindsdb/api/executor/datahub/datanodes/information_schema_datanode.py +44 -10
- mindsdb/api/executor/datahub/datanodes/integration_datanode.py +57 -38
- mindsdb/api/executor/datahub/datanodes/mindsdb_tables.py +2 -1
- mindsdb/api/executor/datahub/datanodes/project_datanode.py +39 -7
- mindsdb/api/executor/datahub/datanodes/system_tables.py +116 -109
- mindsdb/api/executor/planner/query_plan.py +1 -0
- mindsdb/api/executor/planner/query_planner.py +15 -1
- mindsdb/api/executor/planner/steps.py +8 -2
- mindsdb/api/executor/sql_query/sql_query.py +24 -8
- mindsdb/api/executor/sql_query/steps/apply_predictor_step.py +25 -8
- mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py +4 -2
- mindsdb/api/executor/sql_query/steps/insert_step.py +2 -1
- mindsdb/api/executor/sql_query/steps/prepare_steps.py +2 -3
- mindsdb/api/http/namespaces/config.py +19 -11
- mindsdb/api/litellm/start.py +82 -0
- mindsdb/api/mysql/mysql_proxy/libs/constants/mysql.py +133 -0
- mindsdb/integrations/handlers/chromadb_handler/chromadb_handler.py +7 -2
- mindsdb/integrations/handlers/chromadb_handler/settings.py +1 -0
- mindsdb/integrations/handlers/mssql_handler/mssql_handler.py +13 -4
- mindsdb/integrations/handlers/mysql_handler/mysql_handler.py +14 -5
- mindsdb/integrations/handlers/openai_handler/helpers.py +3 -5
- mindsdb/integrations/handlers/openai_handler/openai_handler.py +20 -8
- mindsdb/integrations/handlers/oracle_handler/oracle_handler.py +14 -4
- mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py +34 -19
- mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +21 -18
- mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py +14 -4
- mindsdb/integrations/handlers/togetherai_handler/__about__.py +9 -0
- mindsdb/integrations/handlers/togetherai_handler/__init__.py +20 -0
- mindsdb/integrations/handlers/togetherai_handler/creation_args.py +14 -0
- mindsdb/integrations/handlers/togetherai_handler/icon.svg +15 -0
- mindsdb/integrations/handlers/togetherai_handler/model_using_args.py +5 -0
- mindsdb/integrations/handlers/togetherai_handler/requirements.txt +2 -0
- mindsdb/integrations/handlers/togetherai_handler/settings.py +33 -0
- mindsdb/integrations/handlers/togetherai_handler/togetherai_handler.py +234 -0
- mindsdb/integrations/handlers/web_handler/urlcrawl_helpers.py +1 -1
- mindsdb/integrations/libs/response.py +80 -32
- mindsdb/integrations/utilities/handler_utils.py +4 -0
- mindsdb/integrations/utilities/rag/rerankers/base_reranker.py +360 -0
- mindsdb/integrations/utilities/rag/rerankers/reranker_compressor.py +8 -153
- mindsdb/interfaces/agents/litellm_server.py +345 -0
- mindsdb/interfaces/agents/mcp_client_agent.py +252 -0
- mindsdb/interfaces/agents/run_mcp_agent.py +205 -0
- mindsdb/interfaces/functions/controller.py +3 -2
- mindsdb/interfaces/knowledge_base/controller.py +106 -82
- mindsdb/interfaces/query_context/context_controller.py +55 -15
- mindsdb/interfaces/query_context/query_task.py +19 -0
- mindsdb/interfaces/skills/skill_tool.py +7 -1
- mindsdb/interfaces/skills/sql_agent.py +8 -3
- mindsdb/interfaces/storage/db.py +2 -2
- mindsdb/interfaces/tasks/task_monitor.py +5 -1
- mindsdb/interfaces/tasks/task_thread.py +6 -0
- mindsdb/migrations/versions/2025-04-22_53502b6d63bf_query_database.py +27 -0
- mindsdb/utilities/config.py +20 -2
- mindsdb/utilities/context.py +1 -0
- mindsdb/utilities/starters.py +7 -0
- {mindsdb-25.4.3.2.dist-info → mindsdb-25.4.5.0.dist-info}/METADATA +226 -221
- {mindsdb-25.4.3.2.dist-info → mindsdb-25.4.5.0.dist-info}/RECORD +67 -53
- {mindsdb-25.4.3.2.dist-info → mindsdb-25.4.5.0.dist-info}/WHEEL +1 -1
- mindsdb/integrations/handlers/snowflake_handler/tests/test_snowflake_handler.py +0 -230
- /mindsdb/{integrations/handlers/snowflake_handler/tests → api/litellm}/__init__.py +0 -0
- {mindsdb-25.4.3.2.dist-info → mindsdb-25.4.5.0.dist-info}/licenses/LICENSE +0 -0
- {mindsdb-25.4.3.2.dist-info → mindsdb-25.4.5.0.dist-info}/top_level.txt +0 -0
mindsdb/interfaces/agents/run_mcp_agent.py (new file)

```diff
@@ -0,0 +1,205 @@
+import sys
+import argparse
+import asyncio
+from typing import List, Dict
+from contextlib import AsyncExitStack
+
+from mcp import ClientSession, StdioServerParameters
+from mcp.client.stdio import stdio_client
+
+from mindsdb.utilities import log
+from mindsdb.interfaces.agents.mcp_client_agent import create_mcp_agent
+
+logger = log.getLogger(__name__)
+
+
+async def run_conversation(agent_wrapper, messages: List[Dict[str, str]], stream: bool = False):
+    """Run a conversation with the agent and print responses"""
+    try:
+        if stream:
+            logger.info("Streaming response:")
+            async for chunk in agent_wrapper.acompletion_stream(messages):
+                content = chunk["choices"][0]["delta"].get("content", "")
+                if content:
+                    # We still need to print content for streaming display
+                    # but we'll log it as debug as well
+                    logger.debug(f"Stream content: {content}")
+                    sys.stdout.write(content)
+                    sys.stdout.flush()
+            logger.debug("End of stream")
+            sys.stdout.write("\n\n")
+            sys.stdout.flush()
+        else:
+            logger.info("Getting response...")
+            response = await agent_wrapper.acompletion(messages)
+            content = response["choices"][0]["message"]["content"]
+            logger.info(f"Response: {content}")
+            # We still need to display the response to the user
+            sys.stdout.write(f"{content}\n")
+            sys.stdout.flush()
+    except Exception as e:
+        logger.error(f"Error during agent conversation: {str(e)}")
+
+
+async def execute_direct_query(query):
+    """Execute a direct SQL query using MCP"""
+    logger.info(f"Executing direct SQL query: {query}")
+
+    # Set up MCP client to connect to the running server
+    async with AsyncExitStack() as stack:
+        # Connect to MCP server
+        server_params = StdioServerParameters(
+            command="python",
+            args=["-m", "mindsdb", "--api=mcp"],
+            env=None
+        )
+
+        try:
+            stdio_transport = await stack.enter_async_context(stdio_client(server_params))
+            stdio, write = stdio_transport
+            session = await stack.enter_async_context(ClientSession(stdio, write))
+
+            await session.initialize()
+
+            # List available tools
+            tools_response = await session.list_tools()
+            tool_names = [tool.name for tool in tools_response.tools]
+            logger.info(f"Available tools: {tool_names}")
+
+            # Find query tool
+            query_tool = None
+            for tool in tools_response.tools:
+                if tool.name == "query":
+                    query_tool = tool
+                    break
+
+            if not query_tool:
+                logger.error("No 'query' tool found in MCP server")
+                return
+
+            # Execute query
+            result = await session.call_tool("query", {"query": query})
+            logger.info(f"Query result: {result.content}")
+        except Exception as e:
+            logger.error(f"Error executing query: {str(e)}")
+            logger.info("Make sure the MindsDB server is running with MCP enabled: python -m mindsdb --api=mysql,mcp,http")
+
+
+async def main():
+    parser = argparse.ArgumentParser(description="Run an agent as an MCP client")
+    parser.add_argument("--agent", type=str, help="Name of the agent to use")
+    parser.add_argument("--project", type=str, default="mindsdb", help="Project containing the agent")
+    parser.add_argument("--host", type=str, default="127.0.0.1", help="MCP server host")
+    parser.add_argument("--port", type=int, default=47337, help="MCP server port")
+    parser.add_argument("--query", type=str, help="Query to send to the agent")
+    parser.add_argument("--stream", action="store_true", help="Stream the response")
+    parser.add_argument("--execute-direct", type=str, help="Execute a direct SQL query via MCP (for testing)")
+
+    args = parser.parse_args()
+
+    try:
+        # Initialize database connection
+        from mindsdb.interfaces.storage import db
+        db.init()
+
+        # Direct SQL execution mode (for testing MCP connection)
+        if args.execute_direct:
+            await execute_direct_query(args.execute_direct)
+            return 0
+
+        # Make sure agent name is provided
+        if not args.agent:
+            parser.error("the --agent argument is required unless --execute-direct is used")
+
+        # Create the agent
+        logger.info(f"Creating MCP client agent for '{args.agent}' in project '{args.project}'")
+        logger.info(f"Connecting to MCP server at {args.host}:{args.port}")
+        logger.info("Make sure MindsDB server is running with MCP enabled: python -m mindsdb --api=mysql,mcp,http")
+
+        agent_wrapper = create_mcp_agent(
+            agent_name=args.agent,
+            project_name=args.project,
+            mcp_host=args.host,
+            mcp_port=args.port
+        )
+
+        # Run an example query if provided
+        if args.query:
+            messages = [{"role": "user", "content": args.query}]
+            await run_conversation(agent_wrapper, messages, args.stream)
+        else:
+            # Interactive mode
+            logger.info("Entering interactive mode. Type 'exit' to quit.")
+            logger.info("Available commands: exit/quit, clear, sql:")
+
+            # We still need to show these instructions to the user
+            sys.stdout.write("\nEntering interactive mode. Type 'exit' to quit.\n")
+            sys.stdout.write("\nAvailable commands:\n")
+            sys.stdout.write("  exit, quit - Exit the program\n")
+            sys.stdout.write("  clear - Clear conversation history\n")
+            sys.stdout.write("  sql: <query> - Execute a direct SQL query via MCP\n")
+            sys.stdout.flush()
+
+            messages = []
+
+            while True:
+                # We need to keep input for user interaction
+                user_input = input("\nYou: ")
+
+                # Check for special commands
+                if user_input.lower() in ["exit", "quit"]:
+                    logger.info("Exiting interactive mode")
+                    break
+                elif user_input.lower() == "clear":
+                    messages = []
+                    logger.info("Conversation history cleared")
+                    sys.stdout.write("Conversation history cleared\n")
+                    sys.stdout.flush()
+                    continue
+                elif user_input.lower().startswith("sql:"):
+                    # Direct SQL execution using the agent's session
+                    sql_query = user_input[4:].strip()
+                    logger.info(f"Executing SQL: {sql_query}")
+                    try:
+                        # Use the tool from the agent
+                        if hasattr(agent_wrapper.agent, "session") and agent_wrapper.agent.session:
+                            result = await agent_wrapper.agent.session.call_tool("query", {"query": sql_query})
+                            logger.info(f"SQL result: {result.content}")
+                            # We need to show the result to the user
+                            sys.stdout.write(f"Result: {result.content}\n")
+                            sys.stdout.flush()
+                        else:
+                            logger.error("No active MCP session")
+                            sys.stdout.write("Error: No active MCP session\n")
+                            sys.stdout.flush()
+                    except Exception as e:
+                        logger.error(f"SQL Error: {str(e)}")
+                        sys.stdout.write(f"SQL Error: {str(e)}\n")
+                        sys.stdout.flush()
+                    continue
+
+                messages.append({"role": "user", "content": user_input})
+                await run_conversation(agent_wrapper, messages, args.stream)
+
+                # Add assistant's response to the conversation history
+                if not args.stream:
+                    response = await agent_wrapper.acompletion(messages)
+                    messages.append({
+                        "role": "assistant",
+                        "content": response["choices"][0]["message"]["content"]
+                    })
+
+        # Clean up resources
+        logger.info("Cleaning up resources")
+        await agent_wrapper.cleanup()
+
+    except Exception as e:
+        logger.error(f"Error running MCP agent: {str(e)}")
+        logger.info("Make sure the MindsDB server is running with MCP enabled: python -m mindsdb --api=mysql,mcp,http")
+        return 1
+
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(asyncio.run(main()))
```
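For orientation, here is a minimal sketch of how this new entry point composes programmatically. The module paths and argument names come straight from the hunk above; the behavior of the returned wrapper beyond the methods shown there (`acompletion`, `acompletion_stream`, `cleanup`) is an assumption, and the agent name is hypothetical.

```python
# Hypothetical usage sketch, assuming a MindsDB server is already running
# with MCP enabled (the script's own log hints suggest --api=mysql,mcp,http).
import asyncio

from mindsdb.interfaces.agents.mcp_client_agent import create_mcp_agent
from mindsdb.interfaces.agents.run_mcp_agent import run_conversation


async def demo():
    agent_wrapper = create_mcp_agent(
        agent_name="my_agent",   # hypothetical agent name
        project_name="mindsdb",
        mcp_host="127.0.0.1",
        mcp_port=47337,          # default port from the argparse setup above
    )
    messages = [{"role": "user", "content": "Summarize my_table"}]
    await run_conversation(agent_wrapper, messages, stream=False)
    await agent_wrapper.cleanup()


asyncio.run(demo())
```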
mindsdb/interfaces/functions/controller.py

```diff
@@ -3,6 +3,7 @@ import os
 from duckdb.typing import BIGINT, DOUBLE, VARCHAR, BLOB, BOOLEAN
 from mindsdb.interfaces.functions.to_markdown import ToMarkdown
 from mindsdb.interfaces.storage.model_fs import HandlerStorage
+from mindsdb.utilities.config import config
 
 
 def python_to_duckdb_type(py_type):
@@ -164,7 +165,7 @@ class FunctionController(BYOMFunctionsController):
             return self.callbacks[name]
 
         def callback(file_path_or_url, use_llm):
-            chat_model_params = self._parse_chat_model_params()
+            chat_model_params = self._parse_chat_model_params('TO_MARKDOWN_FUNCTION_')
 
             llm_client = None
             llm_model = None
@@ -192,7 +193,7 @@ class FunctionController(BYOMFunctionsController):
         """
         Parses the environment variables for chat model parameters.
         """
-        chat_model_params = {}
+        chat_model_params = config.get("default_llm") or {}
        for k, v in os.environ.items():
            if k.startswith(param_prefix):
                param_name = k[len(param_prefix):]
```
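Net effect of the `controller.py` changes above: TO_MARKDOWN's chat-model settings now start from the `default_llm` config entry, with `TO_MARKDOWN_FUNCTION_`-prefixed environment variables layered on top. A standalone sketch of that merge, with illustrative variable names (how MindsDB normalizes the stripped key's case is not shown in the hunk):

```python
import os


def parse_prefixed_params(param_prefix: str, defaults: dict | None) -> dict:
    # Start from configured defaults, then overlay env vars with the prefix stripped.
    params = dict(defaults or {})
    for k, v in os.environ.items():
        if k.startswith(param_prefix):
            params[k[len(param_prefix):]] = v
    return params


os.environ["TO_MARKDOWN_FUNCTION_MODEL_NAME"] = "gpt-4o"  # illustrative only
print(parse_prefixed_params("TO_MARKDOWN_FUNCTION_", {"provider": "openai"}))
# {'provider': 'openai', 'MODEL_NAME': 'gpt-4o'}
```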
mindsdb/interfaces/knowledge_base/controller.py

```diff
@@ -27,7 +27,7 @@ from mindsdb.integrations.libs.vectordatabase_handler import (
 from mindsdb.integrations.utilities.rag.rag_pipeline_builder import RAG
 from mindsdb.integrations.utilities.rag.config_loader import load_rag_config
 from mindsdb.integrations.utilities.handler_utils import get_api_key
-from mindsdb.integrations.handlers.langchain_embedding_handler.langchain_embedding_handler import construct_model_from_args
+from mindsdb.integrations.handlers.langchain_embedding_handler.langchain_embedding_handler import construct_model_from_args
 
 from mindsdb.interfaces.agents.constants import DEFAULT_EMBEDDINGS_MODEL_CLASS
 from mindsdb.interfaces.agents.langchain_agent import create_chat_model, get_llm_provider
@@ -37,11 +37,12 @@ from mindsdb.interfaces.knowledge_base.preprocessing.document_preprocessor impor
 from mindsdb.interfaces.model.functions import PredictorRecordNotFound
 from mindsdb.utilities.exception import EntityExistsError, EntityNotExistsError
 from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator
+from mindsdb.utilities.config import config
 from mindsdb.utilities.context import context as ctx
 
 from mindsdb.api.executor.command_executor import ExecuteCommands
 from mindsdb.utilities import log
-from mindsdb.integrations.utilities.rag.rerankers.
+from mindsdb.integrations.utilities.rag.rerankers.base_reranker import BaseLLMReranker
 
 logger = log.getLogger(__name__)
 
@@ -52,6 +53,18 @@ KB_TO_VECTORDB_COLUMNS = {
 }
 
 
+def get_model_params(model_params: dict, default_config_key: str):
+    """
+    Get model parameters by combining default config with user provided parameters.
+    """
+    combined_model_params = copy.deepcopy(config.get(default_config_key, {}))
+
+    if model_params:
+        combined_model_params.update(model_params)
+
+    return combined_model_params
+
+
 def get_embedding_model_from_params(embedding_model_params: dict):
     """
     Create embedding model from parameters.
```
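The merge order in the new `get_model_params` helper is worth spelling out: config-level defaults are deep-copied first, then user-supplied parameters overwrite them key by key. A tiny self-contained illustration (the values are made up):

```python
import copy

defaults = {"provider": "openai", "model_name": "gpt-4o-mini"}  # e.g. config["default_llm"]
user_params = {"model_name": "gpt-4o"}                          # e.g. kb.params["reranking_model"]

merged = copy.deepcopy(defaults)
merged.update(user_params)
assert merged == {"provider": "openai", "model_name": "gpt-4o"}  # user value wins
```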
```diff
@@ -65,6 +78,11 @@ def get_embedding_model_from_params(embedding_model_params: dict):
     if provider == 'azure_openai':
         # Azure OpenAI expects the api_key to be passed as 'openai_api_key'.
         params_copy['openai_api_key'] = api_key
+        params_copy['azure_endpoint'] = params_copy.pop('base_url')
+        if 'chunk_size' not in params_copy:
+            params_copy['chunk_size'] = 2048
+        if 'api_version' in params_copy:
+            params_copy['openai_api_version'] = params_copy['api_version']
     else:
         params_copy[f"{provider}_api_key"] = api_key
     params_copy.pop('api_key', None)
@@ -78,14 +96,13 @@ def get_reranking_model_from_params(reranking_model_params: dict):
     Create reranking model from parameters.
     """
     params_copy = copy.deepcopy(reranking_model_params)
-    provider = params_copy.
-
-
-
-    params_copy.pop('api_key', None)
+    provider = params_copy.get('provider', "openai").lower()
+
+    if "api_key" not in params_copy:
+        params_copy["api_key"] = get_api_key(provider, params_copy, strict=False)
     params_copy['model'] = params_copy.pop('model_name', None)
 
-    return
+    return BaseLLMReranker(**params_copy)
 
 
 class KnowledgeBaseTable:
@@ -211,7 +228,7 @@ class KnowledgeBaseTable:
     def add_relevance(self, df, query_text, relevance_threshold=None):
         relevance_column = TableField.RELEVANCE.value
 
-        reranking_model_params = self._kb.params.get("reranking_model")
+        reranking_model_params = get_model_params(self._kb.params.get("reranking_model"), "default_llm")
         if reranking_model_params and query_text and len(df) > 0:
             # Use reranker for relevance score
             try:
@@ -424,11 +441,12 @@ class KnowledgeBaseTable:
         db_handler = self.get_vector_db()
         db_handler.delete(self._kb.vector_database_table)
 
-    def insert(self, df: pd.DataFrame):
+    def insert(self, df: pd.DataFrame, params: dict = None):
         """Insert dataframe to KB table.
 
         Args:
             df: DataFrame to insert
+            params: User parameters of the insert
         """
         if df.empty:
             return
@@ -497,7 +515,12 @@ class KnowledgeBaseTable:
         df_emb = self._df_to_embeddings(df)
         df = pd.concat([df, df_emb], axis=1)
         db_handler = self.get_vector_db()
-
+
+        if params is not None and params.get('kb_no_upsert', False):
+            # speed up inserting by disabling the check for existing records
+            db_handler.insert(self._kb.vector_database_table, df)
+        else:
+            db_handler.do_upsert(self._kb.vector_database_table, df)
 
     def _adapt_column_names(self, df: pd.DataFrame) -> pd.DataFrame:
         '''
```
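The new `kb_no_upsert` flag selects between two write paths during knowledge base inserts. A self-contained sketch of that branch, with the handler calls replaced by strings so it runs anywhere:

```python
def choose_write_path(params: dict | None) -> str:
    # Mirrors the branch above: a truthy 'kb_no_upsert' skips the
    # existing-record check and does a plain insert; otherwise upsert.
    if params is not None and params.get("kb_no_upsert", False):
        return "insert"     # fast path for initial bulk loads
    return "do_upsert"      # default: dedup against existing rows


assert choose_write_path({"kb_no_upsert": True}) == "insert"
assert choose_write_path(None) == "do_upsert"
```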
```diff
@@ -647,47 +670,34 @@ class KnowledgeBaseTable:
         if df.empty:
             return pd.DataFrame([], columns=[TableField.EMBEDDINGS.value])
 
-        # keep only content
-        df = df[[TableField.CONTENT.value]]
-
         model_id = self._kb.embedding_model_id
-        if model_id:
-            # get the input columns
-            model_rec = db.session.query(db.Predictor).filter_by(id=model_id).first()
 
-
-
+        # get the input columns
+        model_rec = db.session.query(db.Predictor).filter_by(id=model_id).first()
 
-
+        assert model_rec is not None, f"Model not found: {model_id}"
+        model_project = db.session.query(db.Project).filter_by(id=model_rec.project_id).first()
 
-
-            input_col = model_using.get('question_column')
-            if input_col is None:
-                input_col = model_using.get('input_column')
-
-            if input_col is not None and input_col != TableField.CONTENT.value:
-                df = df.rename(columns={TableField.CONTENT.value: input_col})
-
-            df_out = project_datanode.predict(
-                model_name=model_rec.name,
-                df=df,
-                params=self.model_params
-            )
+        project_datanode = self.session.datahub.get(model_project.name)
 
-
-
-
-
+        model_using = model_rec.learn_args.get('using', {})
+        input_col = model_using.get('question_column')
+        if input_col is None:
+            input_col = model_using.get('input_column')
 
-
-
+        if input_col is not None and input_col != TableField.CONTENT.value:
+            df = df.rename(columns={TableField.CONTENT.value: input_col})
 
-
-
-
+        df_out = project_datanode.predict(
+            model_name=model_rec.name,
+            df=df,
+            params=self.model_params
+        )
 
-
-
+        target = model_rec.to_predict[0]
+        if target != TableField.EMBEDDINGS.value:
+            # adapt output for vectordb
+            df_out = df_out.rename(columns={target: TableField.EMBEDDINGS.value})
 
         df_out = df_out[[TableField.EMBEDDINGS.value]]
 
```
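The restructured `_df_to_embeddings` performs two renames around the predict call: the KB's `content` column is renamed to whatever input column the embedding model was trained with, and the model's target column is renamed back to `embeddings` for the vector store. A pandas-only sketch of those two steps (column names invented for illustration):

```python
import pandas as pd

df = pd.DataFrame({"content": ["hello world"]})
input_col = "question"  # e.g. model_using.get('question_column')

# 1. Adapt the input: the model expects its own input column name.
if input_col is not None and input_col != "content":
    df = df.rename(columns={"content": input_col})

# 2. Adapt the output: the vector store expects an 'embeddings' column.
df_out = pd.DataFrame({"predicted": [[0.1, 0.2, 0.3]]})  # stand-in for predict()
target = "predicted"  # e.g. model_rec.to_predict[0]
if target != "embeddings":
    df_out = df_out.rename(columns={target: "embeddings"})
df_out = df_out[["embeddings"]]
```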
```diff
@@ -718,14 +728,15 @@ class KnowledgeBaseTable:
         """
         # Get embedding model from knowledge base
         embeddings_model = None
+        embedding_model_params = get_model_params(self._kb.params.get('embedding_model', {}), 'default_embedding_model')
         if self._kb.embedding_model:
             # Extract embedding model args from knowledge base table
             embedding_args = self._kb.embedding_model.learn_args.get('using', {})
             # Construct the embedding model directly
             embeddings_model = construct_model_from_args(embedding_args)
             logger.debug(f"Using knowledge base embedding model with args: {embedding_args}")
-        elif
-            embeddings_model = get_embedding_model_from_params(
+        elif embedding_model_params:
+            embeddings_model = get_embedding_model_from_params(embedding_model_params)
             logger.debug(f"Using knowledge base embedding model from params: {self._kb.params['embedding_model']}")
         else:
             embeddings_model = DEFAULT_EMBEDDINGS_MODEL_CLASS()
@@ -859,35 +870,33 @@ class KnowledgeBaseController:
             return kb
         raise EntityExistsError("Knowledge base already exists", name)
 
-
-        reranking_model_params = params.get('reranking_model', None)
+        embedding_params = copy.deepcopy(config.get('default_embedding_model', {}))
 
+        model_name = None
+        model_project = project
         if embedding_model:
             model_name = embedding_model.parts[-1]
+            if len(embedding_model.parts) > 1:
+                model_project = self.session.database_controller.get_project(embedding_model.parts[-2])
 
-        elif
-
-
-
-
+        elif 'embedding_model' in params:
+            if isinstance(params['embedding_model'], str):
+                # it is a model name
+                model_name = params['embedding_model']
+            else:
+                # it is params for the model
+                embedding_params.update(params['embedding_model'])
 
-
-        model_name = self.
+        if model_name is None:
+            model_name = self._create_embedding_model(
             project.name,
-            params=
+            params=embedding_params,
+            kb_name=name,
         )
-        params['
-
-        model_project = None
-        if embedding_model is not None and len(embedding_model.parts) > 1:
-            # model project is set
-            model_project = self.session.database_controller.get_project(embedding_model.parts[-2])
-        elif not embedding_model_params:
-            model_project = project
+        params['created_embedding_model'] = model_name
 
         embedding_model_id = None
-        if
+        if model_name is not None:
             model = self.session.model_controller.get_model(
                 name=model_name,
                 project_name=model_project.name
@@ -895,6 +904,7 @@ class KnowledgeBaseController:
             model_record = db.Predictor.query.get(model['id'])
             embedding_model_id = model_record.id
 
+        reranking_model_params = get_model_params(params.get('reranking_model', {}), 'default_llm')
         if reranking_model_params:
             # Get reranking model from params.
             # This is called here to check validity of the parameters.
@@ -979,38 +989,52 @@ class KnowledgeBaseController:
         self.session.integration_controller.add(vector_store_name, engine, connection_args)
         return vector_store_name
 
-    def
+    def _create_embedding_model(self, project_name, engine="openai", params: dict = None, kb_name=''):
         """create a default embedding model for knowledge base, if not specified"""
-        model_name = "
+        model_name = f"kb_embedding_{kb_name}"
 
-        #
+        # drop if exists - parameters can be different
         try:
             model = self.session.model_controller.get_model(model_name, project_name=project_name)
             if model is not None:
-
+                self.session.model_controller.delete_model(model_name, project_name)
         except PredictorRecordNotFound:
             pass
 
-
-
-
-        if engine == '
-
-
+        if 'provider' in params:
+            engine = params.pop('provider').lower()
+
+        if engine == 'azure_openai':
+            engine = 'openai'
+            params['provider'] = 'azure'
+
+        if engine == 'openai':
+            if 'question_column' not in params:
+                params['question_column'] = 'content'
+            if 'api_key' in params:
+                params[f"{engine}_api_key"] = params.pop('api_key')
+            if 'base_url' in params:
+                params['api_base'] = params.pop('base_url')
+
+        params['engine'] = engine
+        params['join_learn_process'] = True
+        params['mode'] = 'embedding'
 
         # Include API key if provided.
-        using_args.update({k: v for k, v in params.items() if 'api_key' in k})
         statement = CreatePredictor(
             name=Identifier(parts=[project_name, model_name]),
-            using=
+            using=params,
             targets=[
                 Identifier(parts=[TableField.EMBEDDINGS.value])
             ]
         )
 
         command_executor = ExecuteCommands(self.session)
-        command_executor.answer_create_predictor(statement, project_name)
-
+        resp = command_executor.answer_create_predictor(statement, project_name)
+        # check model status
+        record = resp.data.records[0]
+        if record['STATUS'] == 'error':
+            raise ValueError('Embedding model error:' + record['ERROR'])
         return model_name
 
     def delete(self, name: str, project_name: int, if_exists: bool = False) -> None:
```
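The provider-to-engine normalization in `_create_embedding_model` is easy to misread, so here it is isolated: an explicit `provider` parameter overrides the engine argument, and `azure_openai` collapses to the `openai` engine with `provider='azure'`. A runnable sketch (the OpenAI-specific key remapping is omitted):

```python
def normalize_engine(engine: str, params: dict) -> str:
    # Mirrors the branches above.
    if "provider" in params:
        engine = params.pop("provider").lower()
    if engine == "azure_openai":
        engine = "openai"
        params["provider"] = "azure"
    return engine


p = {"provider": "Azure_OpenAI"}
assert normalize_engine("openai", p) == "openai"
assert p == {"provider": "azure"}
```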
```diff
@@ -1044,9 +1068,9 @@ class KnowledgeBaseController:
             self.session.integration_controller.delete(kb.params['default_vector_storage'])
         except EntityNotExistsError:
             pass
-        if '
+        if 'created_embedding_model' in kb.params:
             try:
-                self.session.model_controller.delete_model(kb.params['
+                self.session.model_controller.delete_model(kb.params['created_embedding_model'], project_name)
             except EntityNotExistsError:
                 pass
 
```