MindsDB 25.4.3.1__py3-none-any.whl → 25.4.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of MindsDB might be problematic. Click here for more details.
- mindsdb/__about__.py +1 -1
- mindsdb/__main__.py +18 -4
- mindsdb/api/executor/data_types/response_type.py +1 -0
- mindsdb/api/executor/datahub/classes/tables_row.py +3 -10
- mindsdb/api/executor/datahub/datanodes/datanode.py +7 -2
- mindsdb/api/executor/datahub/datanodes/information_schema_datanode.py +44 -10
- mindsdb/api/executor/datahub/datanodes/integration_datanode.py +57 -38
- mindsdb/api/executor/datahub/datanodes/project_datanode.py +39 -7
- mindsdb/api/executor/datahub/datanodes/system_tables.py +116 -109
- mindsdb/api/executor/planner/query_planner.py +10 -1
- mindsdb/api/executor/planner/steps.py +8 -2
- mindsdb/api/executor/sql_query/steps/apply_predictor_step.py +5 -5
- mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py +1 -1
- mindsdb/api/executor/sql_query/steps/insert_step.py +2 -1
- mindsdb/api/executor/sql_query/steps/prepare_steps.py +2 -3
- mindsdb/api/litellm/start.py +82 -0
- mindsdb/api/mysql/mysql_proxy/libs/constants/mysql.py +133 -0
- mindsdb/integrations/handlers/chromadb_handler/chromadb_handler.py +7 -2
- mindsdb/integrations/handlers/chromadb_handler/settings.py +1 -0
- mindsdb/integrations/handlers/mssql_handler/mssql_handler.py +13 -4
- mindsdb/integrations/handlers/mysql_handler/mysql_handler.py +14 -5
- mindsdb/integrations/handlers/oracle_handler/oracle_handler.py +14 -4
- mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py +34 -19
- mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +21 -18
- mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py +14 -4
- mindsdb/integrations/handlers/web_handler/urlcrawl_helpers.py +1 -1
- mindsdb/integrations/libs/response.py +80 -32
- mindsdb/integrations/utilities/rag/rerankers/reranker_compressor.py +208 -13
- mindsdb/interfaces/agents/litellm_server.py +345 -0
- mindsdb/interfaces/agents/mcp_client_agent.py +252 -0
- mindsdb/interfaces/agents/run_mcp_agent.py +205 -0
- mindsdb/interfaces/knowledge_base/controller.py +17 -7
- mindsdb/interfaces/skills/skill_tool.py +7 -1
- mindsdb/interfaces/skills/sql_agent.py +8 -3
- mindsdb/utilities/config.py +8 -1
- mindsdb/utilities/starters.py +7 -0
- {mindsdb-25.4.3.1.dist-info → mindsdb-25.4.4.0.dist-info}/METADATA +232 -230
- {mindsdb-25.4.3.1.dist-info → mindsdb-25.4.4.0.dist-info}/RECORD +42 -39
- {mindsdb-25.4.3.1.dist-info → mindsdb-25.4.4.0.dist-info}/WHEEL +1 -1
- mindsdb/integrations/handlers/snowflake_handler/tests/test_snowflake_handler.py +0 -230
- /mindsdb/{integrations/handlers/snowflake_handler/tests → api/litellm}/__init__.py +0 -0
- {mindsdb-25.4.3.1.dist-info → mindsdb-25.4.4.0.dist-info}/licenses/LICENSE +0 -0
- {mindsdb-25.4.3.1.dist-info → mindsdb-25.4.4.0.dist-info}/top_level.txt +0 -0
|
@@ -10,7 +10,8 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple
|
|
|
10
10
|
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
|
|
11
11
|
from langchain_core.callbacks import Callbacks, dispatch_custom_event
|
|
12
12
|
from langchain_core.documents import Document
|
|
13
|
-
from openai import AsyncOpenAI
|
|
13
|
+
from openai import AsyncOpenAI, AsyncAzureOpenAI
|
|
14
|
+
from pydantic import field_validator
|
|
14
15
|
|
|
15
16
|
from mindsdb.integrations.utilities.rag.settings import DEFAULT_RERANKING_MODEL, DEFAULT_LLM_ENDPOINT
|
|
16
17
|
|
|
@@ -19,12 +20,15 @@ log = logging.getLogger(__name__)
|
|
|
19
20
|
|
|
20
21
|
class LLMReranker(BaseDocumentCompressor):
|
|
21
22
|
filtering_threshold: float = 0.0 # Default threshold for filtering
|
|
23
|
+
provider: str = 'openai'
|
|
22
24
|
model: str = DEFAULT_RERANKING_MODEL # Model to use for reranking
|
|
23
25
|
temperature: float = 0.0 # Temperature for the model
|
|
24
|
-
|
|
26
|
+
api_key: Optional[str] = None
|
|
25
27
|
remove_irrelevant: bool = True # New flag to control removal of irrelevant documents
|
|
26
|
-
base_url: str =
|
|
28
|
+
base_url: Optional[str] = None
|
|
29
|
+
api_version: Optional[str] = None
|
|
27
30
|
num_docs_to_keep: Optional[int] = None # How many of the top documents to keep after reranking & compressing.
|
|
31
|
+
method: str = "multi-class" # Scoring method: 'multi-class' or 'binary'
|
|
28
32
|
_api_key_var: str = "OPENAI_API_KEY"
|
|
29
33
|
client: Optional[AsyncOpenAI] = None
|
|
30
34
|
_semaphore: Optional[asyncio.Semaphore] = None
|
|
@@ -38,21 +42,40 @@ class LLMReranker(BaseDocumentCompressor):
|
|
|
38
42
|
class Config:
|
|
39
43
|
arbitrary_types_allowed = True
|
|
40
44
|
|
|
45
|
+
@field_validator('provider')
@classmethod
def validate_provider(cls, v: str) -> str:
    """Normalise the ``provider`` field to lower case and reject unknown values.

    Args:
        v: Provider name supplied for the ``provider`` field.

    Returns:
        The lower-cased provider name.

    Raises:
        ValueError: If the lower-cased name is neither ``openai`` nor
            ``azure_openai``.
    """
    normalized = v.lower()
    if normalized in ('openai', 'azure_openai'):
        return normalized
    # The message reports the value exactly as the caller supplied it.
    raise ValueError(f"Unsupported provider: {v}.")
|
|
53
|
+
|
|
41
54
|
def __init__(self, **kwargs):
    """Initialise the model fields, then create the request semaphore.

    Args:
        **kwargs: Field values forwarded unchanged to the parent constructor.
    """
    super().__init__(**kwargs)
    # The semaphore can only be sized once the fields have been populated.
    concurrency_limit = self.max_concurrent_requests
    self._semaphore = asyncio.Semaphore(concurrency_limit)
|
|
44
57
|
|
|
45
58
|
async def _init_client(self):
    """Create the async chat client on first use; later calls do nothing.

    For each setting, the value configured on the instance takes precedence
    over the matching environment variable.

    Raises:
        ValueError: For the ``openai`` provider when neither ``api_key`` nor
            the ``OPENAI_API_KEY`` environment variable is set.
    """
    if self.client is not None:
        return

    if self.provider == "azure_openai":
        self.client = AsyncAzureOpenAI(
            api_key=self.api_key or os.getenv("AZURE_OPENAI_API_KEY"),
            azure_endpoint=self.base_url or os.environ.get("AZURE_OPENAI_ENDPOINT"),
            api_version=self.api_version or os.environ.get("AZURE_OPENAI_API_VERSION"),
            timeout=self.request_timeout,
            max_retries=2,
        )
        return

    if self.provider == "openai":
        api_key_var: str = "OPENAI_API_KEY"
        resolved_key = self.api_key or os.getenv(api_key_var)
        if not resolved_key:
            raise ValueError(f"OpenAI API key not found in environment variable {api_key_var}")

        self.client = AsyncOpenAI(
            api_key=resolved_key,
            base_url=self.base_url or DEFAULT_LLM_ENDPOINT,
            timeout=self.request_timeout,
            max_retries=2,
        )
    # Any other provider value leaves self.client as None.
|
|
56
79
|
|
|
57
80
|
async def search_relevancy(self, query: str, document: str, custom_event: bool = True) -> Any:
|
|
58
81
|
await self._init_client()
|
|
@@ -147,6 +170,173 @@ class LLMReranker(BaseDocumentCompressor):
|
|
|
147
170
|
except Exception as e:
|
|
148
171
|
log.error(f"Batch processing error: {str(e)}")
|
|
149
172
|
continue
|
|
173
|
+
return ranked_results
|
|
174
|
+
|
|
175
|
+
async def search_relevancy_score(self, query: str, document: str) -> Any:
    """Score one query/document pair with a four-class LLM prompt.

    The model is asked to answer with ``class_1`` .. ``class_4``. The top
    log-probabilities of the final generated token are turned into a
    probability distribution over class labels, and the score is the
    probability-weighted sum of the class weights (0.25, 0.5, 0.75, 1.0).

    Args:
        query: The search query.
        document: The document chunk to evaluate against the query.

    Returns:
        A dict with keys ``document``, ``answer`` (the label text the model
        returned) and ``relevance_score``.

    Raises:
        Exception: Re-raises the last error once ``max_retries`` attempts
            have failed.
    """
    await self._init_client()

    async with self._semaphore:
        for attempt in range(self.max_retries):
            try:
                # NOTE(review): the leading whitespace of the prompt lines is not
                # visible in the reviewed source; confirm against the real file.
                system_prompt = """
You are an intelligent assistant that evaluates how relevant a given document chunk is to a user's search query.
Your task is to analyze the similarity between the search query and the document chunk, and return **only the class label** that best represents the relevance:

- "class_1": Not relevant (score between 0.0 and 0.25)
- "class_2": Slightly relevant (score between 0.25 and 0.5)
- "class_3": Moderately relevant (score between 0.5 and 0.75)
- "class_4": Highly relevant (score between 0.75 and 1.0)

Respond with only one of: "class_1", "class_2", "class_3", or "class_4".

Examples:

Search query: "How to reset a router to factory settings?"
Document chunk: "Computers often come with customizable parental control settings."
Score: class_1

Search query: "Symptoms of vitamin D deficiency"
Document chunk: "Vitamin D deficiency has been linked to fatigue, bone pain, and muscle weakness."
Score: class_4

Search query: "Best practices for onboarding remote employees"
Document chunk: "An employee handbook can be useful for new hires, outlining company policies and benefits."
Score: class_2

Search query: "Benefits of mindfulness meditation"
Document chunk: "Practicing mindfulness has shown to reduce stress and improve focus in multiple studies."
Score: class_3

Search query: "What is Kubernetes used for?"
Document chunk: "Kubernetes is an open-source system for automating deployment, scaling, and management of containerized applications."
Score: class_4

Search query: "How to bake sourdough bread at home"
Document chunk: "The French Revolution began in 1789 and radically transformed society."
Score: class_1

Search query: "Machine learning algorithms for image classification"
Document chunk: "Convolutional Neural Networks (CNNs) are particularly effective in image classification tasks."
Score: class_4

Search query: "How to improve focus while working remotely"
Document chunk: "Creating a dedicated workspace and setting a consistent schedule can significantly improve focus during remote work."
Score: class_4

Search query: "Carbon emissions from electric vehicles vs gas cars"
Document chunk: "Electric vehicles produce zero emissions while driving, but battery production has environmental impacts."
Score: class_3

Search query: "Time zones in the United States"
Document chunk: "The U.S. is divided into six primary time zones: Eastern, Central, Mountain, Pacific, Alaska, and Hawaii-Aleutian."
Score: class_4
"""
                user_prompt = f"""
Now evaluate the following pair:

Search query: {query}
Document chunk: {document}

Which class best represents the relevance?
"""
                response = await self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": user_prompt},
                    ],
                    temperature=self.temperature,
                    n=1,
                    logprobs=True,
                    top_logprobs=4,
                    max_tokens=3
                )

                first_choice = response.choices[0]
                class_label = first_choice.message.content.strip()

                # Only the last generated token is inspected; each alternative
                # for that token is mapped to a "class_<token>" label.
                last_token = first_choice.logprobs.content[-1]
                raw_probs = {
                    f"class_{candidate.token}": math.exp(candidate.logprob)
                    for candidate in last_token.top_logprobs
                }

                # Renormalise so the retained alternatives sum to 1.
                probability_mass = sum(raw_probs.values())
                class_probs = {label: p / probability_mass for label, p in raw_probs.items()}

                class_weights = {
                    "class_1": 0.25,
                    "class_2": 0.5,
                    "class_3": 0.75,
                    "class_4": 1.0
                }
                # Labels outside class_1..class_4 contribute a weight of 0.
                relevance_score = sum(
                    class_weights.get(label, 0) * p for label, p in class_probs.items()
                )
                return {
                    "document": document,
                    "answer": class_label,
                    "relevance_score": relevance_score
                }

            except Exception as e:
                out_of_attempts = attempt == self.max_retries - 1
                if out_of_attempts:
                    log.error(f"Failed after {self.max_retries} attempts: {str(e)}")
                    raise
                # Exponential backoff with a little jitter before the next attempt.
                retry_delay = self.retry_delay * (2 ** attempt) + random.uniform(0, 0.1)
                await asyncio.sleep(retry_delay)
|
|
292
|
+
|
|
293
|
+
async def _rank_score(self, query_document_pairs: List[Tuple[str, str]]) -> List[Tuple[str, float]]:
    """Score query/document pairs in concurrent batches.

    Each pair is scored with ``search_relevancy_score``. A pair whose scoring
    raised is recorded with a score of 0.0. Scores are clamped to [0.0, 1.0].
    When early stopping is enabled and enough documents reach
    ``filtering_threshold``, the results gathered so far are returned.

    Args:
        query_document_pairs: ``(query, document)`` tuples to score.

    Returns:
        ``(document, score)`` tuples in processing order. Empty when the input
        is empty.
    """
    ranked_results = []

    if not query_document_pairs:
        # An empty input would give batch_size == 0, and range() raises
        # ValueError for a zero step.
        return ranked_results

    # Process in larger batches for better throughput
    batch_size = min(self.max_concurrent_requests * 2, len(query_document_pairs))
    for i in range(0, len(query_document_pairs), batch_size):
        batch = query_document_pairs[i:i + batch_size]
        try:
            results = await asyncio.gather(
                *[self.search_relevancy_score(query=query, document=document) for (query, document) in batch],
                return_exceptions=True
            )

            for idx, result in enumerate(results):
                if isinstance(result, Exception):
                    log.error(f"Error processing document {i+idx}: {str(result)}")
                    ranked_results.append((batch[idx][1], 0.0))
                    continue

                score = result["relevance_score"]
                if score is not None:
                    if score > 1.0:
                        score = 1.0
                    elif score < 0.0:
                        score = 0.0

                ranked_results.append((batch[idx][1], score))

                # Check if we should stop early
                try:
                    high_scoring_docs = [r for r in ranked_results if r[1] >= self.filtering_threshold]
                    can_stop_early = (
                        self.early_stop  # Early stopping is enabled
                        and self.num_docs_to_keep  # We have a target number of docs
                        and len(high_scoring_docs) >= self.num_docs_to_keep  # Found enough good docs
                        and score >= self.early_stop_threshold  # Current doc is good enough
                    )

                    if can_stop_early:
                        log.info(f"Early stopping after finding {self.num_docs_to_keep} documents with high confidence")
                        return ranked_results
                except Exception as e:
                    # Don't let early stopping errors stop the whole process
                    log.warning(f"Error in early stopping check: {str(e)}")

        except Exception as e:
            # A failed batch is logged and skipped; later batches still run.
            log.error(f"Batch processing error: {str(e)}")
            continue

    return ranked_results
|
|
152
342
|
|
|
@@ -226,6 +416,7 @@ class LLMReranker(BaseDocumentCompressor):
|
|
|
226
416
|
"model": self.model,
|
|
227
417
|
"temperature": self.temperature,
|
|
228
418
|
"remove_irrelevant": self.remove_irrelevant,
|
|
419
|
+
"method": self.method,
|
|
229
420
|
}
|
|
230
421
|
|
|
231
422
|
def get_scores(self, query: str, documents: list[str], custom_event: bool = False):
|
|
@@ -239,6 +430,10 @@ class LLMReranker(BaseDocumentCompressor):
|
|
|
239
430
|
loop = asyncio.new_event_loop()
|
|
240
431
|
asyncio.set_event_loop(loop)
|
|
241
432
|
|
|
242
|
-
|
|
433
|
+
if self.method == "multi-class": # default 'multi-class' method
|
|
434
|
+
documents_and_scores = loop.run_until_complete(self._rank_score(query_document_pairs))
|
|
435
|
+
else:
|
|
436
|
+
documents_and_scores = loop.run_until_complete(self._rank(query_document_pairs, custom_event=custom_event))
|
|
437
|
+
|
|
243
438
|
scores = [score for _, score in documents_and_scores]
|
|
244
439
|
return scores
|
|
@@ -0,0 +1,345 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import argparse
|
|
3
|
+
import json
|
|
4
|
+
|
|
5
|
+
from typing import List, Dict, Optional
|
|
6
|
+
from contextlib import AsyncExitStack
|
|
7
|
+
|
|
8
|
+
import uvicorn
|
|
9
|
+
from fastapi import FastAPI, HTTPException, BackgroundTasks
|
|
10
|
+
from fastapi.responses import StreamingResponse
|
|
11
|
+
from fastapi.middleware.cors import CORSMiddleware
|
|
12
|
+
from pydantic import BaseModel, Field
|
|
13
|
+
from mcp import ClientSession, StdioServerParameters
|
|
14
|
+
from mcp.client.stdio import stdio_client
|
|
15
|
+
|
|
16
|
+
from mindsdb.utilities import log
|
|
17
|
+
from mindsdb.interfaces.agents.mcp_client_agent import create_mcp_agent
|
|
18
|
+
|
|
19
|
+
logger = log.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
app = FastAPI(title="MindsDB MCP Agent LiteLLM API")

# CORS: every origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True is a very
# permissive policy - confirm this is intended outside local testing.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Module-level state shared by the request handlers in this module.
agent_wrapper = None  # agent wrapper; assigned by init_agent()
mcp_session = None  # MCP session used for direct SQL queries
exit_stack = AsyncExitStack()  # owns the async contexts opened for mcp_session
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class ChatMessage(BaseModel):
    """One chat message: who sent it and what it says."""

    role: str  # sender role of the message
    content: str  # message text
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class ChatCompletionRequest(BaseModel):
    """Request body accepted by the chat completions endpoint."""

    model: str  # model name, echoed back in the response
    messages: List[ChatMessage]  # conversation to send to the agent
    stream: bool = False  # True selects the streaming response
    temperature: Optional[float] = None  # optional sampling temperature
    max_tokens: Optional[int] = None  # optional completion length limit
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class ChatCompletionChoice(BaseModel):
    """A single choice inside a chat completion response."""

    index: int = 0  # position of this choice in the response
    message: Optional[Dict[str, str]] = None  # complete message, when present
    delta: Optional[Dict[str, str]] = None  # partial message, when present
    finish_reason: Optional[str] = "stop"  # defaults to "stop"
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class ChatCompletionResponse(BaseModel):
    """Response body returned by the non-streaming chat completions endpoint."""

    id: str = "mcp-agent-response"  # fixed identifier
    object: str = "chat.completion"  # fixed object type
    created: int = 0  # always 0; no timestamp is recorded
    model: str  # model name taken from the request
    choices: List[ChatCompletionChoice]  # generated choices
    # Token counts all default to zero; a fresh dict is built per instance.
    usage: Dict[str, int] = Field(default_factory=lambda: {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0})
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class DirectSQLRequest(BaseModel):
    """Request body for the direct SQL endpoint."""

    query: str  # SQL text passed to the MCP "query" tool
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Run the agent on the request messages and return the completion.

    With ``request.stream`` set, the reply is a ``text/event-stream`` response:
    one ``data:`` frame per chunk from the agent, then ``data: [DONE]``.
    Otherwise the agent's answer is wrapped in a ``ChatCompletionResponse``.

    Args:
        request: Parsed chat completion request.

    Returns:
        A ``StreamingResponse`` or a ``ChatCompletionResponse``.

    Raises:
        HTTPException: 500 when the agent is not initialised or the
            completion fails.
    """
    if agent_wrapper is None:
        raise HTTPException(status_code=500, detail="Agent not initialized. Make sure MindsDB server is running with MCP enabled: python -m mindsdb --api=mysql,mcp,http")

    try:
        # Convert request to messages format
        messages = [
            {"role": msg.role, "content": msg.content}
            for msg in request.messages
        ]

        if request.stream:
            # Return a streaming response
            async def generate():
                try:
                    async for chunk in agent_wrapper.acompletion_stream(messages, model=request.model):
                        yield f"data: {json.dumps(chunk)}\n\n"
                    yield "data: [DONE]\n\n"
                except Exception as e:
                    logger.error(f"Streaming error: {str(e)}")
                    # Serialise with json.dumps so the frame is valid JSON, like
                    # the chunk frames. The old literal emitted doubled braces
                    # and single quotes.
                    error_payload = json.dumps({"error": "Streaming failed due to an internal error."})
                    yield f"data: {error_payload}\n\n"
            return StreamingResponse(generate(), media_type="text/event-stream")

        # Return a regular response
        response = await agent_wrapper.acompletion(messages)

        # Ensure the content is a string
        content = response["choices"][0]["message"].get("content", "")
        if not isinstance(content, str):
            content = str(content)

        # Transform to proper OpenAI format
        return ChatCompletionResponse(
            model=request.model,
            choices=[
                ChatCompletionChoice(
                    message={"role": "assistant", "content": content}
                )
            ]
        )

    except Exception as e:
        logger.error(f"Error in chat completion: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
@app.post("/direct-sql")
async def direct_sql(request: DirectSQLRequest, background_tasks: BackgroundTasks):
    """Execute a direct SQL query via MCP (for testing).

    The agent's own MCP session is preferred; the module-level ``mcp_session``
    is the fallback.

    Args:
        request: Request carrying the SQL text.
        background_tasks: Accepted for interface compatibility; not used.

    Returns:
        ``{"result": <content returned by the MCP "query" tool>}``.

    Raises:
        HTTPException: 500 when no session is available or the query fails.
    """
    if agent_wrapper is None and mcp_session is None:
        raise HTTPException(status_code=500, detail="No MCP session available. Make sure MindsDB server is running with MCP enabled.")

    # agent_wrapper may be None while mcp_session exists, so look the agent
    # session up defensively instead of dereferencing agent_wrapper.agent.
    agent = getattr(agent_wrapper, "agent", None)
    session = getattr(agent, "session", None) or mcp_session
    if not session:
        # Raised outside the try block so it is not caught and re-wrapped.
        raise HTTPException(status_code=500, detail="No MCP session available")

    try:
        result = await session.call_tool("query", {"query": request.query})
        return {"result": result.content}
    except Exception as e:
        logger.error(f"Error executing direct SQL: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
@app.get("/v1/models")
async def list_models():
    """List available models - always a single entry.

    The entry id is the agent's ``model_name`` argument when an agent is
    initialised and has one, otherwise ``"mcp-agent"``.
    """
    if agent_wrapper is None:
        model_id = "mcp-agent"
    else:
        model_id = agent_wrapper.agent.args.get("model_name", "mcp-agent")

    model_entry = {
        "id": model_id,
        "object": "model",
        "created": 0,
        "owned_by": "mindsdb"
    }
    return {"object": "list", "data": [model_entry]}
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
@app.get("/health")
async def health_check():
    """Health check endpoint.

    Always reports ``status: ok``. Agent details (MCP connection flag, agent
    name, model name) are included only when an agent is initialised.
    """
    if agent_wrapper is None:
        return {"status": "ok", "agent_initialized": False}

    agent = agent_wrapper.agent
    return {
        "status": "ok",
        "agent_initialized": True,
        "mcp_connected": hasattr(agent, "session") and agent.session is not None,
        "agent_name": agent.agent.name,
        "model_name": agent.args.get("model_name", "unknown"),
    }
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
@app.get("/test-mcp-connection")
async def test_mcp_connection():
    """Test the connection to the MCP server.

    An existing ``mcp_session`` is probed first. If there is none, or the
    probe fails, a new stdio session is opened, stored in ``mcp_session`` and
    used to list the available tools.

    Returns:
        A dict with ``status``, ``message`` and the list of tool names.

    Raises:
        HTTPException: 500 when the session cannot be created or queried.
    """
    global mcp_session, exit_stack

    try:
        # If we already have a session, test it
        if mcp_session:
            try:
                tools_response = await mcp_session.list_tools()
                return {
                    "status": "ok",
                    "message": "Successfully connected to MCP server",
                    "tools": [tool.name for tool in tools_response.tools]
                }
            except Exception:
                # If error, close existing session and create a new one
                await exit_stack.aclose()
                mcp_session = None

        # NOTE(review): with stdio parameters the MCP client presumably launches
        # this command as a child process rather than attaching to an
        # already-running server - confirm against the MCP SDK.
        server_params = StdioServerParameters(
            command="python",
            args=["-m", "mindsdb", "--api=mcp"],
            env=None
        )

        stdio_transport = await exit_stack.enter_async_context(stdio_client(server_params))
        stdio, write = stdio_transport
        session = await exit_stack.enter_async_context(ClientSession(stdio, write))

        await session.initialize()

        # Save the session for future use
        mcp_session = session

        # Get available tools
        tools_response = await session.list_tools()

        return {
            "status": "ok",
            "message": "Successfully connected to MCP server",
            "tools": [tool.name for tool in tools_response.tools]
        }
    except Exception as e:
        logger.error(f"Error connecting to MCP server: {str(e)}")
        # Release whatever was opened before the failure, so a half-open
        # transport is not kept in exit_stack until shutdown.
        mcp_session = None
        try:
            await exit_stack.aclose()
        except Exception as cleanup_error:
            logger.warning(f"Error releasing MCP session resources: {str(cleanup_error)}")
        error_detail = f"Error connecting to MCP server: {str(e)}. Make sure MindsDB server is running with MCP enabled: python -m mindsdb --api=mysql,mcp,http"
        raise HTTPException(status_code=500, detail=error_detail)
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
async def init_agent(agent_name: str, project_name: str, mcp_host: str, mcp_port: int):
    """Create the MCP agent and store it in the module-level ``agent_wrapper``.

    Args:
        agent_name: Name of the agent to load.
        project_name: Project that contains the agent.
        mcp_host: Host of the MCP server.
        mcp_port: Port of the MCP server.

    Returns:
        True when the agent was created, False when creation raised. The
        error is logged, not propagated.
    """
    global agent_wrapper

    logger.info(f"Initializing MCP agent '{agent_name}' in project '{project_name}'")
    logger.info(f"Connecting to MCP server at {mcp_host}:{mcp_port}")
    logger.info("Make sure MindsDB server is running with MCP enabled: python -m mindsdb --api=mysql,mcp,http")

    try:
        agent_wrapper = create_mcp_agent(
            agent_name=agent_name,
            project_name=project_name,
            mcp_host=mcp_host,
            mcp_port=mcp_port
        )
    except Exception as e:
        logger.error(f"Failed to initialize agent: {str(e)}")
        return False

    logger.info("Agent initialized successfully")
    return True
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
@app.on_event("shutdown")
async def shutdown_event():
    """Clean up resources on server shutdown.

    Runs the agent wrapper's cleanup, when an agent exists, and always closes
    the module-level exit stack.
    """
    try:
        if agent_wrapper:
            await agent_wrapper.cleanup()
    finally:
        # Close the stack even when the agent cleanup raises, so the contexts
        # it owns are still released.
        await exit_stack.aclose()
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
async def run_server_async(
    agent_name: str,
    project_name: str = "mindsdb",
    mcp_host: str = "127.0.0.1",
    mcp_port: int = 47337,
    host: str = "0.0.0.0",
    port: int = 8000
):
    """Initialise the agent and report the outcome as an exit code.

    This coroutine does not start the HTTP server itself.

    Args:
        agent_name: Name of the agent to load.
        project_name: Project that contains the agent.
        mcp_host: Host of the MCP server.
        mcp_port: Port of the MCP server.
        host: Accepted but not used here.
        port: Accepted but not used here.

    Returns:
        0 when the agent was initialised, 1 otherwise.
    """
    if await init_agent(agent_name, project_name, mcp_host, mcp_port):
        return 0

    logger.error("Failed to initialize agent. Make sure MindsDB server is running with MCP enabled.")
    return 1
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def run_server(
    agent_name: str,
    project_name: str = "mindsdb",
    mcp_host: str = "127.0.0.1",
    mcp_port: int = 47337,
    host: str = "0.0.0.0",
    port: int = 8000
):
    """Initialise storage and the agent, then serve the FastAPI app.

    Args:
        agent_name: Name of the agent to load.
        project_name: Project that contains the agent.
        mcp_host: Host of the MCP server.
        mcp_port: Port of the MCP server.
        host: Interface the HTTP server binds to.
        port: Port the HTTP server listens on.

    Returns:
        The non-zero code from agent initialisation when it fails, otherwise
        0 after the server has stopped.
    """
    logger.info("Make sure MindsDB server is running with MCP enabled: python -m mindsdb --api=mysql,mcp,http")

    # Initialize database; the import is local to this function.
    from mindsdb.interfaces.storage import db
    db.init()

    # The agent is initialised on a fresh event loop before uvicorn starts.
    # NOTE(review): this loop is never closed - confirm whether the agent
    # still needs it once the server is running.
    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)
    init_status = event_loop.run_until_complete(
        run_server_async(agent_name, project_name, mcp_host, mcp_port)
    )
    if init_status != 0:
        return init_status

    logger.info(f"Starting server on {host}:{port}")
    uvicorn.run(app, host=host, port=port)
    return 0
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to run_server().
    parser = argparse.ArgumentParser(description="Run a LiteLLM-compatible API server for MCP agent")

    # Agent selection
    parser.add_argument("--agent", type=str, required=True, help="Name of the agent to use")
    parser.add_argument("--project", type=str, default="mindsdb", help="Project containing the agent")

    # MCP server location
    parser.add_argument("--mcp-host", type=str, default="127.0.0.1", help="MCP server host")
    parser.add_argument("--mcp-port", type=int, default=47337, help="MCP server port")

    # HTTP bind address
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind the server to")
    parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")

    args = parser.parse_args()

    run_server(args.agent, args.project, args.mcp_host, args.mcp_port, args.host, args.port)
|