MindsDB 25.4.3.2__py3-none-any.whl → 25.4.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of MindsDB might be problematic; see the release details below for more information.

Files changed (68)
  1. mindsdb/__about__.py +1 -1
  2. mindsdb/__main__.py +18 -4
  3. mindsdb/api/executor/command_executor.py +12 -2
  4. mindsdb/api/executor/data_types/response_type.py +1 -0
  5. mindsdb/api/executor/datahub/classes/tables_row.py +3 -10
  6. mindsdb/api/executor/datahub/datanodes/datanode.py +7 -2
  7. mindsdb/api/executor/datahub/datanodes/information_schema_datanode.py +44 -10
  8. mindsdb/api/executor/datahub/datanodes/integration_datanode.py +57 -38
  9. mindsdb/api/executor/datahub/datanodes/mindsdb_tables.py +2 -1
  10. mindsdb/api/executor/datahub/datanodes/project_datanode.py +39 -7
  11. mindsdb/api/executor/datahub/datanodes/system_tables.py +116 -109
  12. mindsdb/api/executor/planner/query_plan.py +1 -0
  13. mindsdb/api/executor/planner/query_planner.py +15 -1
  14. mindsdb/api/executor/planner/steps.py +8 -2
  15. mindsdb/api/executor/sql_query/sql_query.py +24 -8
  16. mindsdb/api/executor/sql_query/steps/apply_predictor_step.py +25 -8
  17. mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py +4 -2
  18. mindsdb/api/executor/sql_query/steps/insert_step.py +2 -1
  19. mindsdb/api/executor/sql_query/steps/prepare_steps.py +2 -3
  20. mindsdb/api/http/namespaces/config.py +19 -11
  21. mindsdb/api/litellm/start.py +82 -0
  22. mindsdb/api/mysql/mysql_proxy/libs/constants/mysql.py +133 -0
  23. mindsdb/integrations/handlers/chromadb_handler/chromadb_handler.py +7 -2
  24. mindsdb/integrations/handlers/chromadb_handler/settings.py +1 -0
  25. mindsdb/integrations/handlers/mssql_handler/mssql_handler.py +13 -4
  26. mindsdb/integrations/handlers/mysql_handler/mysql_handler.py +14 -5
  27. mindsdb/integrations/handlers/openai_handler/helpers.py +3 -5
  28. mindsdb/integrations/handlers/openai_handler/openai_handler.py +20 -8
  29. mindsdb/integrations/handlers/oracle_handler/oracle_handler.py +14 -4
  30. mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py +34 -19
  31. mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +21 -18
  32. mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py +14 -4
  33. mindsdb/integrations/handlers/togetherai_handler/__about__.py +9 -0
  34. mindsdb/integrations/handlers/togetherai_handler/__init__.py +20 -0
  35. mindsdb/integrations/handlers/togetherai_handler/creation_args.py +14 -0
  36. mindsdb/integrations/handlers/togetherai_handler/icon.svg +15 -0
  37. mindsdb/integrations/handlers/togetherai_handler/model_using_args.py +5 -0
  38. mindsdb/integrations/handlers/togetherai_handler/requirements.txt +2 -0
  39. mindsdb/integrations/handlers/togetherai_handler/settings.py +33 -0
  40. mindsdb/integrations/handlers/togetherai_handler/togetherai_handler.py +234 -0
  41. mindsdb/integrations/handlers/web_handler/urlcrawl_helpers.py +1 -1
  42. mindsdb/integrations/libs/response.py +80 -32
  43. mindsdb/integrations/utilities/handler_utils.py +4 -0
  44. mindsdb/integrations/utilities/rag/rerankers/base_reranker.py +360 -0
  45. mindsdb/integrations/utilities/rag/rerankers/reranker_compressor.py +8 -153
  46. mindsdb/interfaces/agents/litellm_server.py +345 -0
  47. mindsdb/interfaces/agents/mcp_client_agent.py +252 -0
  48. mindsdb/interfaces/agents/run_mcp_agent.py +205 -0
  49. mindsdb/interfaces/functions/controller.py +3 -2
  50. mindsdb/interfaces/knowledge_base/controller.py +106 -82
  51. mindsdb/interfaces/query_context/context_controller.py +55 -15
  52. mindsdb/interfaces/query_context/query_task.py +19 -0
  53. mindsdb/interfaces/skills/skill_tool.py +7 -1
  54. mindsdb/interfaces/skills/sql_agent.py +8 -3
  55. mindsdb/interfaces/storage/db.py +2 -2
  56. mindsdb/interfaces/tasks/task_monitor.py +5 -1
  57. mindsdb/interfaces/tasks/task_thread.py +6 -0
  58. mindsdb/migrations/versions/2025-04-22_53502b6d63bf_query_database.py +27 -0
  59. mindsdb/utilities/config.py +20 -2
  60. mindsdb/utilities/context.py +1 -0
  61. mindsdb/utilities/starters.py +7 -0
  62. {mindsdb-25.4.3.2.dist-info → mindsdb-25.4.5.0.dist-info}/METADATA +226 -221
  63. {mindsdb-25.4.3.2.dist-info → mindsdb-25.4.5.0.dist-info}/RECORD +67 -53
  64. {mindsdb-25.4.3.2.dist-info → mindsdb-25.4.5.0.dist-info}/WHEEL +1 -1
  65. mindsdb/integrations/handlers/snowflake_handler/tests/test_snowflake_handler.py +0 -230
  66. /mindsdb/{integrations/handlers/snowflake_handler/tests → api/litellm}/__init__.py +0 -0
  67. {mindsdb-25.4.3.2.dist-info → mindsdb-25.4.5.0.dist-info}/licenses/LICENSE +0 -0
  68. {mindsdb-25.4.3.2.dist-info → mindsdb-25.4.5.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,205 @@
1
+ import sys
2
+ import argparse
3
+ import asyncio
4
+ from typing import List, Dict
5
+ from contextlib import AsyncExitStack
6
+
7
+ from mcp import ClientSession, StdioServerParameters
8
+ from mcp.client.stdio import stdio_client
9
+
10
+ from mindsdb.utilities import log
11
+ from mindsdb.interfaces.agents.mcp_client_agent import create_mcp_agent
12
+
13
+ logger = log.getLogger(__name__)
14
+
15
+
16
+ async def run_conversation(agent_wrapper, messages: List[Dict[str, str]], stream: bool = False):
17
+ """Run a conversation with the agent and print responses"""
18
+ try:
19
+ if stream:
20
+ logger.info("Streaming response:")
21
+ async for chunk in agent_wrapper.acompletion_stream(messages):
22
+ content = chunk["choices"][0]["delta"].get("content", "")
23
+ if content:
24
+ # We still need to print content for streaming display
25
+ # but we'll log it as debug as well
26
+ logger.debug(f"Stream content: {content}")
27
+ sys.stdout.write(content)
28
+ sys.stdout.flush()
29
+ logger.debug("End of stream")
30
+ sys.stdout.write("\n\n")
31
+ sys.stdout.flush()
32
+ else:
33
+ logger.info("Getting response...")
34
+ response = await agent_wrapper.acompletion(messages)
35
+ content = response["choices"][0]["message"]["content"]
36
+ logger.info(f"Response: {content}")
37
+ # We still need to display the response to the user
38
+ sys.stdout.write(f"{content}\n")
39
+ sys.stdout.flush()
40
+ except Exception as e:
41
+ logger.error(f"Error during agent conversation: {str(e)}")
42
+
43
+
44
+ async def execute_direct_query(query):
45
+ """Execute a direct SQL query using MCP"""
46
+ logger.info(f"Executing direct SQL query: {query}")
47
+
48
+ # Set up MCP client to connect to the running server
49
+ async with AsyncExitStack() as stack:
50
+ # Connect to MCP server
51
+ server_params = StdioServerParameters(
52
+ command="python",
53
+ args=["-m", "mindsdb", "--api=mcp"],
54
+ env=None
55
+ )
56
+
57
+ try:
58
+ stdio_transport = await stack.enter_async_context(stdio_client(server_params))
59
+ stdio, write = stdio_transport
60
+ session = await stack.enter_async_context(ClientSession(stdio, write))
61
+
62
+ await session.initialize()
63
+
64
+ # List available tools
65
+ tools_response = await session.list_tools()
66
+ tool_names = [tool.name for tool in tools_response.tools]
67
+ logger.info(f"Available tools: {tool_names}")
68
+
69
+ # Find query tool
70
+ query_tool = None
71
+ for tool in tools_response.tools:
72
+ if tool.name == "query":
73
+ query_tool = tool
74
+ break
75
+
76
+ if not query_tool:
77
+ logger.error("No 'query' tool found in MCP server")
78
+ return
79
+
80
+ # Execute query
81
+ result = await session.call_tool("query", {"query": query})
82
+ logger.info(f"Query result: {result.content}")
83
+ except Exception as e:
84
+ logger.error(f"Error executing query: {str(e)}")
85
+ logger.info("Make sure the MindsDB server is running with MCP enabled: python -m mindsdb --api=mysql,mcp,http")
86
+
87
+
88
+ async def main():
89
+ parser = argparse.ArgumentParser(description="Run an agent as an MCP client")
90
+ parser.add_argument("--agent", type=str, help="Name of the agent to use")
91
+ parser.add_argument("--project", type=str, default="mindsdb", help="Project containing the agent")
92
+ parser.add_argument("--host", type=str, default="127.0.0.1", help="MCP server host")
93
+ parser.add_argument("--port", type=int, default=47337, help="MCP server port")
94
+ parser.add_argument("--query", type=str, help="Query to send to the agent")
95
+ parser.add_argument("--stream", action="store_true", help="Stream the response")
96
+ parser.add_argument("--execute-direct", type=str, help="Execute a direct SQL query via MCP (for testing)")
97
+
98
+ args = parser.parse_args()
99
+
100
+ try:
101
+ # Initialize database connection
102
+ from mindsdb.interfaces.storage import db
103
+ db.init()
104
+
105
+ # Direct SQL execution mode (for testing MCP connection)
106
+ if args.execute_direct:
107
+ await execute_direct_query(args.execute_direct)
108
+ return 0
109
+
110
+ # Make sure agent name is provided
111
+ if not args.agent:
112
+ parser.error("the --agent argument is required unless --execute-direct is used")
113
+
114
+ # Create the agent
115
+ logger.info(f"Creating MCP client agent for '{args.agent}' in project '{args.project}'")
116
+ logger.info(f"Connecting to MCP server at {args.host}:{args.port}")
117
+ logger.info("Make sure MindsDB server is running with MCP enabled: python -m mindsdb --api=mysql,mcp,http")
118
+
119
+ agent_wrapper = create_mcp_agent(
120
+ agent_name=args.agent,
121
+ project_name=args.project,
122
+ mcp_host=args.host,
123
+ mcp_port=args.port
124
+ )
125
+
126
+ # Run an example query if provided
127
+ if args.query:
128
+ messages = [{"role": "user", "content": args.query}]
129
+ await run_conversation(agent_wrapper, messages, args.stream)
130
+ else:
131
+ # Interactive mode
132
+ logger.info("Entering interactive mode. Type 'exit' to quit.")
133
+ logger.info("Available commands: exit/quit, clear, sql:")
134
+
135
+ # We still need to show these instructions to the user
136
+ sys.stdout.write("\nEntering interactive mode. Type 'exit' to quit.\n")
137
+ sys.stdout.write("\nAvailable commands:\n")
138
+ sys.stdout.write(" exit, quit - Exit the program\n")
139
+ sys.stdout.write(" clear - Clear conversation history\n")
140
+ sys.stdout.write(" sql: <query> - Execute a direct SQL query via MCP\n")
141
+ sys.stdout.flush()
142
+
143
+ messages = []
144
+
145
+ while True:
146
+ # We need to keep input for user interaction
147
+ user_input = input("\nYou: ")
148
+
149
+ # Check for special commands
150
+ if user_input.lower() in ["exit", "quit"]:
151
+ logger.info("Exiting interactive mode")
152
+ break
153
+ elif user_input.lower() == "clear":
154
+ messages = []
155
+ logger.info("Conversation history cleared")
156
+ sys.stdout.write("Conversation history cleared\n")
157
+ sys.stdout.flush()
158
+ continue
159
+ elif user_input.lower().startswith("sql:"):
160
+ # Direct SQL execution using the agent's session
161
+ sql_query = user_input[4:].strip()
162
+ logger.info(f"Executing SQL: {sql_query}")
163
+ try:
164
+ # Use the tool from the agent
165
+ if hasattr(agent_wrapper.agent, "session") and agent_wrapper.agent.session:
166
+ result = await agent_wrapper.agent.session.call_tool("query", {"query": sql_query})
167
+ logger.info(f"SQL result: {result.content}")
168
+ # We need to show the result to the user
169
+ sys.stdout.write(f"Result: {result.content}\n")
170
+ sys.stdout.flush()
171
+ else:
172
+ logger.error("No active MCP session")
173
+ sys.stdout.write("Error: No active MCP session\n")
174
+ sys.stdout.flush()
175
+ except Exception as e:
176
+ logger.error(f"SQL Error: {str(e)}")
177
+ sys.stdout.write(f"SQL Error: {str(e)}\n")
178
+ sys.stdout.flush()
179
+ continue
180
+
181
+ messages.append({"role": "user", "content": user_input})
182
+ await run_conversation(agent_wrapper, messages, args.stream)
183
+
184
+ # Add assistant's response to the conversation history
185
+ if not args.stream:
186
+ response = await agent_wrapper.acompletion(messages)
187
+ messages.append({
188
+ "role": "assistant",
189
+ "content": response["choices"][0]["message"]["content"]
190
+ })
191
+
192
+ # Clean up resources
193
+ logger.info("Cleaning up resources")
194
+ await agent_wrapper.cleanup()
195
+
196
+ except Exception as e:
197
+ logger.error(f"Error running MCP agent: {str(e)}")
198
+ logger.info("Make sure the MindsDB server is running with MCP enabled: python -m mindsdb --api=mysql,mcp,http")
199
+ return 1
200
+
201
+ return 0
202
+
203
+
204
+ if __name__ == "__main__":
205
+ sys.exit(asyncio.run(main()))
@@ -3,6 +3,7 @@ import os
3
3
  from duckdb.typing import BIGINT, DOUBLE, VARCHAR, BLOB, BOOLEAN
4
4
  from mindsdb.interfaces.functions.to_markdown import ToMarkdown
5
5
  from mindsdb.interfaces.storage.model_fs import HandlerStorage
6
+ from mindsdb.utilities.config import config
6
7
 
7
8
 
8
9
  def python_to_duckdb_type(py_type):
@@ -164,7 +165,7 @@ class FunctionController(BYOMFunctionsController):
164
165
  return self.callbacks[name]
165
166
 
166
167
  def callback(file_path_or_url, use_llm):
167
- chat_model_params = self._parse_chat_model_params()
168
+ chat_model_params = self._parse_chat_model_params('TO_MARKDOWN_FUNCTION_')
168
169
 
169
170
  llm_client = None
170
171
  llm_model = None
@@ -192,7 +193,7 @@ class FunctionController(BYOMFunctionsController):
192
193
  """
193
194
  Parses the environment variables for chat model parameters.
194
195
  """
195
- chat_model_params = {}
196
+ chat_model_params = config.get("default_llm") or {}
196
197
  for k, v in os.environ.items():
197
198
  if k.startswith(param_prefix):
198
199
  param_name = k[len(param_prefix):]
@@ -27,7 +27,7 @@ from mindsdb.integrations.libs.vectordatabase_handler import (
27
27
  from mindsdb.integrations.utilities.rag.rag_pipeline_builder import RAG
28
28
  from mindsdb.integrations.utilities.rag.config_loader import load_rag_config
29
29
  from mindsdb.integrations.utilities.handler_utils import get_api_key
30
- from mindsdb.integrations.handlers.langchain_embedding_handler.langchain_embedding_handler import construct_model_from_args, row_to_document
30
+ from mindsdb.integrations.handlers.langchain_embedding_handler.langchain_embedding_handler import construct_model_from_args
31
31
 
32
32
  from mindsdb.interfaces.agents.constants import DEFAULT_EMBEDDINGS_MODEL_CLASS
33
33
  from mindsdb.interfaces.agents.langchain_agent import create_chat_model, get_llm_provider
@@ -37,11 +37,12 @@ from mindsdb.interfaces.knowledge_base.preprocessing.document_preprocessor impor
37
37
  from mindsdb.interfaces.model.functions import PredictorRecordNotFound
38
38
  from mindsdb.utilities.exception import EntityExistsError, EntityNotExistsError
39
39
  from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator
40
+ from mindsdb.utilities.config import config
40
41
  from mindsdb.utilities.context import context as ctx
41
42
 
42
43
  from mindsdb.api.executor.command_executor import ExecuteCommands
43
44
  from mindsdb.utilities import log
44
- from mindsdb.integrations.utilities.rag.rerankers.reranker_compressor import LLMReranker
45
+ from mindsdb.integrations.utilities.rag.rerankers.base_reranker import BaseLLMReranker
45
46
 
46
47
  logger = log.getLogger(__name__)
47
48
 
@@ -52,6 +53,18 @@ KB_TO_VECTORDB_COLUMNS = {
52
53
  }
53
54
 
54
55
 
56
+ def get_model_params(model_params: dict, default_config_key: str):
57
+ """
58
+ Get model parameters by combining default config with user provided parameters.
59
+ """
60
+ combined_model_params = copy.deepcopy(config.get(default_config_key, {}))
61
+
62
+ if model_params:
63
+ combined_model_params.update(model_params)
64
+
65
+ return combined_model_params
66
+
67
+
55
68
  def get_embedding_model_from_params(embedding_model_params: dict):
56
69
  """
57
70
  Create embedding model from parameters.
@@ -65,6 +78,11 @@ def get_embedding_model_from_params(embedding_model_params: dict):
65
78
  if provider == 'azure_openai':
66
79
  # Azure OpenAI expects the api_key to be passed as 'openai_api_key'.
67
80
  params_copy['openai_api_key'] = api_key
81
+ params_copy['azure_endpoint'] = params_copy.pop('base_url')
82
+ if 'chunk_size' not in params_copy:
83
+ params_copy['chunk_size'] = 2048
84
+ if 'api_version' in params_copy:
85
+ params_copy['openai_api_version'] = params_copy['api_version']
68
86
  else:
69
87
  params_copy[f"{provider}_api_key"] = api_key
70
88
  params_copy.pop('api_key', None)
@@ -78,14 +96,13 @@ def get_reranking_model_from_params(reranking_model_params: dict):
78
96
  Create reranking model from parameters.
79
97
  """
80
98
  params_copy = copy.deepcopy(reranking_model_params)
81
- provider = params_copy.pop('provider', "openai").lower()
82
- if provider != 'openai':
83
- raise ValueError("Only OpenAI provider is supported for the reranking model.")
84
- params_copy[f"{provider}_api_key"] = get_api_key(provider, params_copy, strict=False) or params_copy.get('api_key')
85
- params_copy.pop('api_key', None)
99
+ provider = params_copy.get('provider', "openai").lower()
100
+
101
+ if "api_key" not in params_copy:
102
+ params_copy["api_key"] = get_api_key(provider, params_copy, strict=False)
86
103
  params_copy['model'] = params_copy.pop('model_name', None)
87
104
 
88
- return LLMReranker(**params_copy)
105
+ return BaseLLMReranker(**params_copy)
89
106
 
90
107
 
91
108
  class KnowledgeBaseTable:
@@ -211,7 +228,7 @@ class KnowledgeBaseTable:
211
228
  def add_relevance(self, df, query_text, relevance_threshold=None):
212
229
  relevance_column = TableField.RELEVANCE.value
213
230
 
214
- reranking_model_params = self._kb.params.get("reranking_model")
231
+ reranking_model_params = get_model_params(self._kb.params.get("reranking_model"), "default_llm")
215
232
  if reranking_model_params and query_text and len(df) > 0:
216
233
  # Use reranker for relevance score
217
234
  try:
@@ -424,11 +441,12 @@ class KnowledgeBaseTable:
424
441
  db_handler = self.get_vector_db()
425
442
  db_handler.delete(self._kb.vector_database_table)
426
443
 
427
- def insert(self, df: pd.DataFrame):
444
+ def insert(self, df: pd.DataFrame, params: dict = None):
428
445
  """Insert dataframe to KB table.
429
446
 
430
447
  Args:
431
448
  df: DataFrame to insert
449
+ params: User parameters of insert
432
450
  """
433
451
  if df.empty:
434
452
  return
@@ -497,7 +515,12 @@ class KnowledgeBaseTable:
497
515
  df_emb = self._df_to_embeddings(df)
498
516
  df = pd.concat([df, df_emb], axis=1)
499
517
  db_handler = self.get_vector_db()
500
- db_handler.do_upsert(self._kb.vector_database_table, df)
518
+
519
+ if params is not None and params.get('kb_no_upsert', False):
520
+ # speed up inserting by disable checking existing records
521
+ db_handler.insert(self._kb.vector_database_table, df)
522
+ else:
523
+ db_handler.do_upsert(self._kb.vector_database_table, df)
501
524
 
502
525
  def _adapt_column_names(self, df: pd.DataFrame) -> pd.DataFrame:
503
526
  '''
@@ -647,47 +670,34 @@ class KnowledgeBaseTable:
647
670
  if df.empty:
648
671
  return pd.DataFrame([], columns=[TableField.EMBEDDINGS.value])
649
672
 
650
- # keep only content
651
- df = df[[TableField.CONTENT.value]]
652
-
653
673
  model_id = self._kb.embedding_model_id
654
- if model_id:
655
- # get the input columns
656
- model_rec = db.session.query(db.Predictor).filter_by(id=model_id).first()
657
674
 
658
- assert model_rec is not None, f"Model not found: {model_id}"
659
- model_project = db.session.query(db.Project).filter_by(id=model_rec.project_id).first()
675
+ # get the input columns
676
+ model_rec = db.session.query(db.Predictor).filter_by(id=model_id).first()
660
677
 
661
- project_datanode = self.session.datahub.get(model_project.name)
678
+ assert model_rec is not None, f"Model not found: {model_id}"
679
+ model_project = db.session.query(db.Project).filter_by(id=model_rec.project_id).first()
662
680
 
663
- model_using = model_rec.learn_args.get('using', {})
664
- input_col = model_using.get('question_column')
665
- if input_col is None:
666
- input_col = model_using.get('input_column')
667
-
668
- if input_col is not None and input_col != TableField.CONTENT.value:
669
- df = df.rename(columns={TableField.CONTENT.value: input_col})
670
-
671
- df_out = project_datanode.predict(
672
- model_name=model_rec.name,
673
- df=df,
674
- params=self.model_params
675
- )
681
+ project_datanode = self.session.datahub.get(model_project.name)
676
682
 
677
- target = model_rec.to_predict[0]
678
- if target != TableField.EMBEDDINGS.value:
679
- # adapt output for vectordb
680
- df_out = df_out.rename(columns={target: TableField.EMBEDDINGS.value})
683
+ model_using = model_rec.learn_args.get('using', {})
684
+ input_col = model_using.get('question_column')
685
+ if input_col is None:
686
+ input_col = model_using.get('input_column')
681
687
 
682
- elif self._kb.params.get('embedding_model'):
683
- embedding_model = get_embedding_model_from_params(self._kb.params.get('embedding_model'))
688
+ if input_col is not None and input_col != TableField.CONTENT.value:
689
+ df = df.rename(columns={TableField.CONTENT.value: input_col})
684
690
 
685
- df_texts = df.apply(row_to_document, axis=1)
686
- embeddings = embedding_model.embed_documents(df_texts.tolist())
687
- df_out = df.copy().assign(**{TableField.EMBEDDINGS.value: embeddings})
691
+ df_out = project_datanode.predict(
692
+ model_name=model_rec.name,
693
+ df=df,
694
+ params=self.model_params
695
+ )
688
696
 
689
- else:
690
- raise ValueError("No embedding model found for the knowledge base.")
697
+ target = model_rec.to_predict[0]
698
+ if target != TableField.EMBEDDINGS.value:
699
+ # adapt output for vectordb
700
+ df_out = df_out.rename(columns={target: TableField.EMBEDDINGS.value})
691
701
 
692
702
  df_out = df_out[[TableField.EMBEDDINGS.value]]
693
703
 
@@ -718,14 +728,15 @@ class KnowledgeBaseTable:
718
728
  """
719
729
  # Get embedding model from knowledge base
720
730
  embeddings_model = None
731
+ embedding_model_params = get_model_params(self._kb.params.get('embedding_model', {}), 'default_embedding_model')
721
732
  if self._kb.embedding_model:
722
733
  # Extract embedding model args from knowledge base table
723
734
  embedding_args = self._kb.embedding_model.learn_args.get('using', {})
724
735
  # Construct the embedding model directly
725
736
  embeddings_model = construct_model_from_args(embedding_args)
726
737
  logger.debug(f"Using knowledge base embedding model with args: {embedding_args}")
727
- elif self._kb.params.get('embedding_model'):
728
- embeddings_model = get_embedding_model_from_params(self._kb.params['embedding_model'])
738
+ elif embedding_model_params:
739
+ embeddings_model = get_embedding_model_from_params(embedding_model_params)
729
740
  logger.debug(f"Using knowledge base embedding model from params: {self._kb.params['embedding_model']}")
730
741
  else:
731
742
  embeddings_model = DEFAULT_EMBEDDINGS_MODEL_CLASS()
@@ -859,35 +870,33 @@ class KnowledgeBaseController:
859
870
  return kb
860
871
  raise EntityExistsError("Knowledge base already exists", name)
861
872
 
862
- embedding_model_params = params.get('embedding_model', None)
863
- reranking_model_params = params.get('reranking_model', None)
873
+ embedding_params = copy.deepcopy(config.get('default_embedding_model', {}))
864
874
 
875
+ model_name = None
876
+ model_project = project
865
877
  if embedding_model:
866
878
  model_name = embedding_model.parts[-1]
879
+ if len(embedding_model.parts) > 1:
880
+ model_project = self.session.database_controller.get_project(embedding_model.parts[-2])
867
881
 
868
- elif embedding_model_params:
869
- # Get embedding model from params.
870
- # This is called here to check validaity of the parameters.
871
- get_embedding_model_from_params(
872
- embedding_model_params
873
- )
882
+ elif 'embedding_model' in params:
883
+ if isinstance(params['embedding_model'], str):
884
+ # it is model name
885
+ model_name = params['embedding_model']
886
+ else:
887
+ # it is params for model
888
+ embedding_params.update(params['embedding_model'])
874
889
 
875
- else:
876
- model_name = self._get_default_embedding_model(
890
+ if model_name is None:
891
+ model_name = self._create_embedding_model(
877
892
  project.name,
878
- params=params
893
+ params=embedding_params,
894
+ kb_name=name,
879
895
  )
880
- params['default_embedding_model'] = model_name
881
-
882
- model_project = None
883
- if embedding_model is not None and len(embedding_model.parts) > 1:
884
- # model project is set
885
- model_project = self.session.database_controller.get_project(embedding_model.parts[-2])
886
- elif not embedding_model_params:
887
- model_project = project
896
+ params['created_embedding_model'] = model_name
888
897
 
889
898
  embedding_model_id = None
890
- if model_project:
899
+ if model_name is not None:
891
900
  model = self.session.model_controller.get_model(
892
901
  name=model_name,
893
902
  project_name=model_project.name
@@ -895,6 +904,7 @@ class KnowledgeBaseController:
895
904
  model_record = db.Predictor.query.get(model['id'])
896
905
  embedding_model_id = model_record.id
897
906
 
907
+ reranking_model_params = get_model_params(params.get('reranking_model', {}), 'default_llm')
898
908
  if reranking_model_params:
899
909
  # Get reranking model from params.
900
910
  # This is called here to check validaity of the parameters.
@@ -979,38 +989,52 @@ class KnowledgeBaseController:
979
989
  self.session.integration_controller.add(vector_store_name, engine, connection_args)
980
990
  return vector_store_name
981
991
 
982
- def _get_default_embedding_model(self, project_name, engine="langchain_embedding", params: dict = None):
992
+ def _create_embedding_model(self, project_name, engine="openai", params: dict = None, kb_name=''):
983
993
  """create a default embedding model for knowledge base, if not specified"""
984
- model_name = "kb_default_embedding_model"
994
+ model_name = f"kb_embedding_{kb_name}"
985
995
 
986
- # check exists
996
+ # drop if exists - parameters can be different
987
997
  try:
988
998
  model = self.session.model_controller.get_model(model_name, project_name=project_name)
989
999
  if model is not None:
990
- return model_name
1000
+ self.session.model_controller.delete_model(model_name, project_name)
991
1001
  except PredictorRecordNotFound:
992
1002
  pass
993
1003
 
994
- using_args = {
995
- 'engine': engine
996
- }
997
- if engine == 'langchain_embedding':
998
- # Use default embeddings.
999
- using_args['class'] = 'openai'
1004
+ if 'provider' in params:
1005
+ engine = params.pop('provider').lower()
1006
+
1007
+ if engine == 'azure_openai':
1008
+ engine = 'openai'
1009
+ params['provider'] = 'azure'
1010
+
1011
+ if engine == 'openai':
1012
+ if 'question_column' not in params:
1013
+ params['question_column'] = 'content'
1014
+ if 'api_key' in params:
1015
+ params[f"{engine}_api_key"] = params.pop('api_key')
1016
+ if 'base_url' in params:
1017
+ params['api_base'] = params.pop('base_url')
1018
+
1019
+ params['engine'] = engine
1020
+ params['join_learn_process'] = True
1021
+ params['mode'] = 'embedding'
1000
1022
 
1001
1023
  # Include API key if provided.
1002
- using_args.update({k: v for k, v in params.items() if 'api_key' in k})
1003
1024
  statement = CreatePredictor(
1004
1025
  name=Identifier(parts=[project_name, model_name]),
1005
- using=using_args,
1026
+ using=params,
1006
1027
  targets=[
1007
1028
  Identifier(parts=[TableField.EMBEDDINGS.value])
1008
1029
  ]
1009
1030
  )
1010
1031
 
1011
1032
  command_executor = ExecuteCommands(self.session)
1012
- command_executor.answer_create_predictor(statement, project_name)
1013
-
1033
+ resp = command_executor.answer_create_predictor(statement, project_name)
1034
+ # check model status
1035
+ record = resp.data.records[0]
1036
+ if record['STATUS'] == 'error':
1037
+ raise ValueError('Embedding model error:' + record['ERROR'])
1014
1038
  return model_name
1015
1039
 
1016
1040
  def delete(self, name: str, project_name: int, if_exists: bool = False) -> None:
@@ -1044,9 +1068,9 @@ class KnowledgeBaseController:
1044
1068
  self.session.integration_controller.delete(kb.params['default_vector_storage'])
1045
1069
  except EntityNotExistsError:
1046
1070
  pass
1047
- if 'default_embedding_model' in kb.params:
1071
+ if 'created_embedding_model' in kb.params:
1048
1072
  try:
1049
- self.session.model_controller.delete_model(kb.params['default_embedding_model'], project_name)
1073
+ self.session.model_controller.delete_model(kb.params['created_embedding_model'], project_name)
1050
1074
  except EntityNotExistsError:
1051
1075
  pass
1052
1076