MindsDB 25.6.4.0__py3-none-any.whl → 25.7.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- mindsdb/__about__.py +1 -1
- mindsdb/api/executor/command_executor.py +8 -6
- mindsdb/api/executor/datahub/datanodes/information_schema_datanode.py +1 -1
- mindsdb/api/executor/datahub/datanodes/integration_datanode.py +9 -11
- mindsdb/api/executor/datahub/datanodes/system_tables.py +1 -1
- mindsdb/api/executor/planner/query_prepare.py +68 -87
- mindsdb/api/executor/sql_query/steps/fetch_dataframe.py +6 -1
- mindsdb/api/executor/sql_query/steps/union_step.py +11 -9
- mindsdb/api/http/namespaces/file.py +49 -24
- mindsdb/api/mcp/start.py +45 -31
- mindsdb/integrations/handlers/chromadb_handler/chromadb_handler.py +45 -52
- mindsdb/integrations/handlers/huggingface_handler/__init__.py +17 -12
- mindsdb/integrations/handlers/huggingface_handler/finetune.py +223 -223
- mindsdb/integrations/handlers/huggingface_handler/huggingface_handler.py +383 -383
- mindsdb/integrations/handlers/huggingface_handler/requirements.txt +7 -6
- mindsdb/integrations/handlers/huggingface_handler/requirements_cpu.txt +7 -6
- mindsdb/integrations/handlers/huggingface_handler/settings.py +25 -25
- mindsdb/integrations/handlers/litellm_handler/litellm_handler.py +22 -15
- mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py +150 -140
- mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +1 -1
- mindsdb/integrations/handlers/statsforecast_handler/requirements.txt +1 -0
- mindsdb/integrations/handlers/statsforecast_handler/requirements_extra.txt +1 -0
- mindsdb/integrations/libs/vectordatabase_handler.py +86 -77
- mindsdb/integrations/utilities/rag/rerankers/base_reranker.py +36 -42
- mindsdb/interfaces/agents/agents_controller.py +29 -9
- mindsdb/interfaces/agents/langchain_agent.py +7 -5
- mindsdb/interfaces/agents/mcp_client_agent.py +4 -4
- mindsdb/interfaces/agents/mindsdb_database_agent.py +10 -43
- mindsdb/interfaces/data_catalog/data_catalog_reader.py +3 -1
- mindsdb/interfaces/knowledge_base/controller.py +115 -89
- mindsdb/interfaces/knowledge_base/evaluate.py +16 -4
- mindsdb/interfaces/knowledge_base/executor.py +346 -0
- mindsdb/interfaces/knowledge_base/llm_client.py +5 -6
- mindsdb/interfaces/knowledge_base/preprocessing/document_preprocessor.py +20 -45
- mindsdb/interfaces/knowledge_base/preprocessing/models.py +36 -69
- mindsdb/interfaces/skills/custom/text2sql/mindsdb_kb_tools.py +2 -0
- mindsdb/interfaces/skills/sql_agent.py +181 -130
- mindsdb/interfaces/storage/db.py +9 -7
- mindsdb/utilities/config.py +12 -1
- mindsdb/utilities/exception.py +47 -7
- mindsdb/utilities/security.py +54 -11
- {mindsdb-25.6.4.0.dist-info → mindsdb-25.7.1.0.dist-info}/METADATA +248 -262
- {mindsdb-25.6.4.0.dist-info → mindsdb-25.7.1.0.dist-info}/RECORD +46 -45
- {mindsdb-25.6.4.0.dist-info → mindsdb-25.7.1.0.dist-info}/WHEEL +0 -0
- {mindsdb-25.6.4.0.dist-info → mindsdb-25.7.1.0.dist-info}/licenses/LICENSE +0 -0
- {mindsdb-25.6.4.0.dist-info → mindsdb-25.7.1.0.dist-info}/top_level.txt +0 -0
mindsdb/integrations/libs/vectordatabase_handler.py

@@ -2,6 +2,7 @@ import ast
 import hashlib
 from enum import Enum
 from typing import Dict, List, Optional
+import datetime as dt
 
 import pandas as pd
 from mindsdb_sql_parser.ast import (

@@ -28,6 +29,9 @@ from .base import BaseHandler
 LOG = log.getLogger(__name__)
 
 
+class VectorHandlerException(Exception): ...
+
+
 class TableField(Enum):
     """
     Enum for table fields.

@@ -43,9 +47,9 @@ class TableField(Enum):
 
 
 class DistanceFunction(Enum):
-    SQUARED_EUCLIDEAN_DISTANCE = '<->',
-    NEGATIVE_DOT_PRODUCT = '<#>',
-    COSINE_DISTANCE = '<=>'
+    SQUARED_EUCLIDEAN_DISTANCE = ("<->",)
+    NEGATIVE_DOT_PRODUCT = ("<#>",)
+    COSINE_DISTANCE = "<=>"
 
 
 class VectorStoreHandler(BaseHandler):

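Note the trailing commas: two of the three members now carry one-element tuple values, while COSINE_DISTANCE stays a plain string. A standalone reproduction (not an import from mindsdb) of what callers observe:

    from enum import Enum

    class DistanceFunction(Enum):  # reproduced from the hunk above
        SQUARED_EUCLIDEAN_DISTANCE = ("<->",)
        NEGATIVE_DOT_PRODUCT = ("<#>",)
        COSINE_DISTANCE = "<=>"

    # Code that interpolates .value into SQL has to account for the mix.
    assert DistanceFunction.SQUARED_EUCLIDEAN_DISTANCE.value == ("<->",)
    assert DistanceFunction.COSINE_DISTANCE.value == "<=>"
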
@@ -118,9 +122,7 @@ class VectorStoreHandler(BaseHandler):
                 right_hand = [item.value for item in node.args[1].items]
             else:
                 raise Exception(f"Unsupported right hand side: {node.args[1]}")
-            conditions.append(
-                FilterCondition(column=left_hand, op=op, value=right_hand)
-            )
+            conditions.append(FilterCondition(column=left_hand, op=op, value=right_hand))
 
         query_traversal(where_statement, _extract_comparison_conditions)
 
@@ -129,15 +131,23 @@ class VectorStoreHandler(BaseHandler):
 
         return conditions
 
-    def _convert_metadata_filters(self, conditions):
+    def _convert_metadata_filters(self, conditions, allowed_metadata_columns=None):
         if conditions is None:
             return
         # try to treat conditions that are not in TableField as metadata conditions
         for condition in conditions:
-            if self._is_metadata_condition(condition) and not condition.column.startswith(
-                TableField.METADATA.value
-            ):
-                condition.column = TableField.METADATA.value + "." + condition.column
+            if self._is_metadata_condition(condition):
+                # check restriction
+                if allowed_metadata_columns is not None:
+                    # system columns are underscored, skip them
+                    if condition.column.lower() not in allowed_metadata_columns and not condition.column.startswith(
+                        "_"
+                    ):
+                        raise ValueError(f"Column is not found: {condition.column}")
+
+                # convert if required
+                if not condition.column.startswith(TableField.METADATA.value):
+                    condition.column = TableField.METADATA.value + "." + condition.column
 
     def _is_columns_allowed(self, columns: List[str]) -> bool:
         """

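A minimal standalone sketch of the new restriction logic (the FilterCondition class and metadata prefix below are stand-ins, not the mindsdb imports):

    from dataclasses import dataclass

    METADATA_PREFIX = "metadata"  # stand-in for TableField.METADATA.value

    @dataclass
    class FilterCondition:  # stand-in for mindsdb's FilterCondition
        column: str
        op: str
        value: object

    def convert_metadata_filters(conditions, allowed_metadata_columns=None):
        # Mirror of the new checks: unknown columns raise, system columns
        # (underscore-prefixed) bypass the restriction, and bare metadata
        # columns get the "metadata." prefix.
        for condition in conditions:
            if allowed_metadata_columns is not None:
                if condition.column.lower() not in allowed_metadata_columns and not condition.column.startswith("_"):
                    raise ValueError(f"Column is not found: {condition.column}")
            if not condition.column.startswith(METADATA_PREFIX):
                condition.column = METADATA_PREFIX + "." + condition.column

    conds = [FilterCondition("category", "=", "docs"), FilterCondition("_updated_at", ">", "2025-07-01")]
    convert_metadata_filters(conds, allowed_metadata_columns={"category"})
    assert conds[0].column == "metadata.category"
    assert conds[1].column == "metadata._updated_at"  # system column: exempt from the check, still prefixed
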
@@ -146,16 +156,11 @@ class VectorStoreHandler(BaseHandler):
         allowed_columns = set([col["name"] for col in self.SCHEMA])
         return set(columns).issubset(allowed_columns)
 
-    def
+    def _is_metadata_condition(self, condition: FilterCondition) -> bool:
         allowed_field_values = set([field.value for field in TableField])
         if condition.column in allowed_field_values:
-            return
-
-        # check if column is a metadata column
-        if condition.column.startswith(TableField.METADATA.value):
-            return True
-        else:
-            return False
+            return False
+        return True
 
     def _dispatch_create_table(self, query: CreateTable):
         """

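Stated differently, the rewritten predicate treats every column outside the reserved TableField names as metadata; a tiny sketch (the reserved set below is illustrative, not the full TableField enum):

    RESERVED_FIELDS = {"id", "content", "embeddings", "metadata", "distance"}  # illustrative subset

    def is_metadata_condition(column: str) -> bool:
        # New behavior: anything that is not a reserved table field is metadata.
        return column not in RESERVED_FIELDS

    assert is_metadata_condition("category")
    assert not is_metadata_condition("content")
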
@@ -184,17 +189,12 @@ class VectorStoreHandler(BaseHandler):
         columns = [column.name for column in query.columns]
 
         if not self._is_columns_allowed(columns):
-            raise Exception(
-                f"Columns {columns} not allowed."
-                f"Allowed columns are {[col['name'] for col in self.SCHEMA]}"
-            )
+            raise Exception(f"Columns {columns} not allowed.Allowed columns are {[col['name'] for col in self.SCHEMA]}")
 
         # get content column if it is present
         if TableField.CONTENT.value in columns:
             content_col_index = columns.index("content")
-            content = [
-                self._value_or_self(row[content_col_index]) for row in query.values
-            ]
+            content = [self._value_or_self(row[content_col_index]) for row in query.values]
         else:
             content = None
 
@@ -209,19 +209,13 @@ class VectorStoreHandler(BaseHandler):
         # get embeddings column if it is present
         if TableField.EMBEDDINGS.value in columns:
             embeddings_col_index = columns.index("embeddings")
-            embeddings = [
-                ast.literal_eval(self._value_or_self(row[embeddings_col_index]))
-                for row in query.values
-            ]
+            embeddings = [ast.literal_eval(self._value_or_self(row[embeddings_col_index])) for row in query.values]
         else:
             raise Exception("Embeddings column is required!")
 
         if TableField.METADATA.value in columns:
             metadata_col_index = columns.index("metadata")
-            metadata = [
-                ast.literal_eval(self._value_or_self(row[metadata_col_index]))
-                for row in query.values
-            ]
+            metadata = [ast.literal_eval(self._value_or_self(row[metadata_col_index])) for row in query.values]
         else:
             metadata = None
 
@@ -277,6 +271,15 @@ class VectorStoreHandler(BaseHandler):
 
         return self.do_upsert(table_name, df)
 
+    def set_metadata_cur_time(self, df, col_name):
+        metadata_col = TableField.METADATA.value
+        cur_date = dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+
+        def set_time(meta):
+            meta[col_name] = cur_date
+
+        df[metadata_col].apply(set_time)
+
     def do_upsert(self, table_name, df):
         """Upsert data into table, handling document updates and deletions.
 
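The helper stamps one timestamp string into every row's metadata dict in place via Series.apply; a self-contained sketch of the same pattern (plain column names instead of TableField):

    import datetime as dt

    import pandas as pd

    def set_metadata_cur_time(df: pd.DataFrame, col_name: str) -> None:
        # One timestamp is computed once, then written into each row's
        # metadata dict; the dicts are mutated in place.
        cur_date = dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        def set_time(meta):
            meta[col_name] = cur_date

        df["metadata"].apply(set_time)

    df = pd.DataFrame({"id": ["a", "b"], "metadata": [{}, {}]})
    set_metadata_cur_time(df, "_updated_at")
    assert all("_updated_at" in meta for meta in df["metadata"])
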
@@ -289,6 +292,7 @@ class VectorStoreHandler(BaseHandler):
         2. Updated documents: Delete old chunks and insert new ones
         """
         id_col = TableField.ID.value
+        metadata_col = TableField.METADATA.value
         content_col = TableField.CONTENT.value
 
         def gen_hash(v):

@@ -309,37 +313,48 @@ class VectorStoreHandler(BaseHandler):
         # id is string TODO is it ok?
         df[id_col] = df[id_col].apply(str)
 
-        if hasattr(self, 'upsert'):
+        # set updated_at
+        self.set_metadata_cur_time(df, "_updated_at")
+
+        if hasattr(self, "upsert"):
             self.upsert(table_name, df)
             return
 
         # find existing ids
-
+        df_existed = self.select(
             table_name,
-            columns=[id_col],
-            conditions=[
-                FilterCondition(column=id_col, op=FilterOperator.IN, value=list(df[id_col]))
-            ]
+            columns=[id_col, metadata_col],
+            conditions=[FilterCondition(column=id_col, op=FilterOperator.IN, value=list(df[id_col]))],
         )
-        existed_ids = list(
+        existed_ids = list(df_existed[id_col])
 
         # update existed
         df_update = df[df[id_col].isin(existed_ids)]
         df_insert = df[~df[id_col].isin(existed_ids)]
 
         if not df_update.empty:
+            # get values of existed `created_at` and return them to metadata
+            created_dates = {row[id_col]: row[metadata_col].get("_created_at") for _, row in df_existed.iterrows()}
+
+            def keep_created_at(row):
+                val = created_dates.get(row[id_col])
+                if val:
+                    row[metadata_col]["_created_at"] = val
+                return row
+
+            df_update.apply(keep_created_at, axis=1)
+
             try:
                 self.update(table_name, df_update, [id_col])
             except NotImplementedError:
                 # not implemented? do it with delete and insert
-                conditions = [FilterCondition(
-                    column=id_col,
-                    op=FilterOperator.IN,
-                    value=list(df[id_col])
-                )]
+                conditions = [FilterCondition(column=id_col, op=FilterOperator.IN, value=list(df[id_col]))]
                 self.delete(table_name, conditions)
                 self.insert(table_name, df_update)
         if not df_insert.empty:
+            # set created_at
+            self.set_metadata_cur_time(df_insert, "_created_at")
+
             self.insert(table_name, df_insert)
 
     def dispatch_delete(self, query: Delete, conditions: List[FilterCondition] = None):

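The `_created_at` handling can be exercised in isolation; a sketch with plain pandas (the column names and timestamps are illustrative):

    import pandas as pd

    # Rows already in the store, as select(columns=[id, metadata]) would return them.
    df_existed = pd.DataFrame({"id": ["a"], "metadata": [{"_created_at": "2025-06-01 09:30:00"}]})
    # Incoming upsert batch: "a" is an update, "b" is an insert.
    df = pd.DataFrame({"id": ["a", "b"], "metadata": [{}, {}]})

    created_dates = {row["id"]: row["metadata"].get("_created_at") for _, row in df_existed.iterrows()}

    def keep_created_at(row):
        # Restore the original creation timestamp on updated rows.
        val = created_dates.get(row["id"])
        if val:
            row["metadata"]["_created_at"] = val
        return row

    df_update = df[df["id"].isin(created_dates)]
    df_update.apply(keep_created_at, axis=1)
    assert df.loc[0, "metadata"]["_created_at"] == "2025-06-01 09:30:00"
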
@@ -356,7 +371,9 @@ class VectorStoreHandler(BaseHandler):
         # dispatch delete
         return self.delete(table_name, conditions=conditions)
 
-    def dispatch_select(self, query: Select, conditions: List[FilterCondition] = None):
+    def dispatch_select(
+        self, query: Select, conditions: List[FilterCondition] = None, allowed_metadata_columns: List[str] = None
+    ):
         """
         Dispatch select query to the appropriate method.
         """

@@ -369,29 +386,30 @@ class VectorStoreHandler(BaseHandler):
         columns = [col.parts[-1] for col in query.targets]
 
         if not self._is_columns_allowed(columns):
-            raise Exception(
-                f"Columns {columns} not allowed."
-                f"Allowed columns are {[col['name'] for col in self.SCHEMA]}"
-            )
+            raise Exception(f"Columns {columns} not allowed.Allowed columns are {[col['name'] for col in self.SCHEMA]}")
 
         # check if columns are allowed
         if conditions is None:
             where_statement = query.where
             conditions = self.extract_conditions(where_statement)
-        self._convert_metadata_filters(conditions)
+        self._convert_metadata_filters(conditions, allowed_metadata_columns=allowed_metadata_columns)
 
         # get offset and limit
         offset = query.offset.value if query.offset is not None else None
         limit = query.limit.value if query.limit is not None else None
 
         # dispatch select
-        return self.select(
-            table_name,
-            columns=columns,
-            conditions=conditions,
-            offset=offset,
-            limit=limit,
-        )
+        try:
+            return self.select(
+                table_name,
+                columns=columns,
+                conditions=conditions,
+                offset=offset,
+                limit=limit,
+            )
+        except Exception as e:
+            handler_engine = self.__class__.name
+            raise VectorHandlerException(f"Error in {handler_engine} database: {e}")
 
     def _dispatch(self, query: ASTNode) -> HandlerResponse:
         """

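With the new try/except, every backend failure surfaces as one exception type tagged with the handler's engine name; a minimal sketch of the pattern:

    class VectorHandlerException(Exception): ...

    class FakeVectorHandler:
        name = "chromadb"  # handlers expose an engine name as a class attribute

        def select(self, table_name, **kwargs):
            raise RuntimeError("collection does not exist")

        def dispatch_select(self, table_name):
            try:
                return self.select(table_name)
            except Exception as e:
                raise VectorHandlerException(f"Error in {self.__class__.name} database: {e}")

    try:
        FakeVectorHandler().dispatch_select("my_table")
    except VectorHandlerException as e:
        print(e)  # Error in chromadb database: collection does not exist
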
@@ -408,10 +426,7 @@ class VectorStoreHandler(BaseHandler):
         if type(query) in dispatch_router:
             resp = dispatch_router[type(query)](query)
             if resp is not None:
-                return HandlerResponse(
-                    resp_type=RESPONSE_TYPE.TABLE,
-                    data_frame=resp
-                )
+                return HandlerResponse(resp_type=RESPONSE_TYPE.TABLE, data_frame=resp)
             else:
                 return HandlerResponse(resp_type=RESPONSE_TYPE.OK)
 
@@ -455,9 +470,7 @@ class VectorStoreHandler(BaseHandler):
         """
         raise NotImplementedError()
 
-    def insert(
-        self, table_name: str, data: pd.DataFrame
-    ) -> HandlerResponse:
+    def insert(self, table_name: str, data: pd.DataFrame) -> HandlerResponse:
         """Insert data into table
 
         Args:

@@ -470,9 +483,7 @@ class VectorStoreHandler(BaseHandler):
         """
         raise NotImplementedError()
 
-    def update(
-        self, table_name: str, data: pd.DataFrame, key_columns: List[str] = None
-    ):
+    def update(self, table_name: str, data: pd.DataFrame, key_columns: List[str] = None):
         """Update data in table
 
         Args:

@@ -485,9 +496,7 @@ class VectorStoreHandler(BaseHandler):
         """
         raise NotImplementedError()
 
-    def delete(
-        self, table_name: str, conditions: List[FilterCondition] = None
-    ) -> HandlerResponse:
+    def delete(self, table_name: str, conditions: List[FilterCondition] = None) -> HandlerResponse:
         """Delete data from table
 
         Args:

@@ -535,9 +544,9 @@ class VectorStoreHandler(BaseHandler):
         query: str = None,
         metadata: Dict[str, str] = None,
         distance_function=DistanceFunction.COSINE_DISTANCE,
-        **kwargs
+        **kwargs,
     ) -> pd.DataFrame:
-        '''
+        """
         Executes a hybrid search, combining semantic search and one or both of keyword/metadata search.
 
         For insight on the query construction, see: https://docs.pgvecto.rs/use-case/hybrid-search.html#advanced-search-merge-the-results-of-full-text-search-and-vector-search.

@@ -551,11 +560,11 @@ class VectorStoreHandler(BaseHandler):
 
         Returns:
             df(pd.DataFrame): Hybrid search result, sorted by hybrid search rank
-        '''
-        raise NotImplementedError(f'Hybrid search not supported for VectorStoreHandler {self.name}')
+        """
+        raise NotImplementedError(f"Hybrid search not supported for VectorStoreHandler {self.name}")
 
     def create_index(self, *args, **kwargs):
         """
         Create an index on the specified table.
         """
-        raise NotImplementedError(f'create_index not supported for VectorStoreHandler {self.name}')
+        raise NotImplementedError(f"create_index not supported for VectorStoreHandler {self.name}")

mindsdb/integrations/utilities/rag/rerankers/base_reranker.py

@@ -33,7 +33,7 @@ class BaseLLMReranker(BaseModel, ABC):
     client: Optional[AsyncOpenAI | BaseMLEngine] = None
     _semaphore: Optional[asyncio.Semaphore] = None
     max_concurrent_requests: int = 20
-    max_retries: int =
+    max_retries: int = 2
     retry_delay: float = 1.0
     request_timeout: float = 20.0  # Timeout for API requests
     early_stop: bool = True  # Whether to enable early stopping

@@ -100,7 +100,7 @@ class BaseLLMReranker(BaseModel, ABC):
         if self.api_key is not None:
             kwargs["api_key"] = self.api_key
 
-        return await self.client.acompletion(
+        return await self.client.acompletion(self.provider, model=self.model, messages=messages, args=kwargs)
 
     async def _rank(self, query_document_pairs: List[Tuple[str, str]], rerank_callback=None) -> List[Tuple[str, float]]:
         ranked_results = []

@@ -109,47 +109,41 @@ class BaseLLMReranker(BaseModel, ABC):
         batch_size = min(self.max_concurrent_requests * 2, len(query_document_pairs))
         for i in range(0, len(query_document_pairs), batch_size):
             batch = query_document_pairs[i : i + batch_size]
-            try:
-                results = await asyncio.gather(
-                    *[
-                        self._backoff_wrapper(query=query, document=document, rerank_callback=rerank_callback)
-                        for (query, document) in batch
-                    ],
-                    return_exceptions=True,
-                )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            results = await asyncio.gather(
+                *[
+                    self._backoff_wrapper(query=query, document=document, rerank_callback=rerank_callback)
+                    for (query, document) in batch
+                ],
+                return_exceptions=True,
+            )
+
+            for idx, result in enumerate(results):
+                if isinstance(result, Exception):
+                    log.error(f"Error processing document {i + idx}: {str(result)}")
+                    raise RuntimeError(f"Error during reranking: {result}")
+
+                score = result["relevance_score"]
+
+                ranked_results.append((batch[idx][1], score))
+
+                # Check if we should stop early
+                try:
+                    high_scoring_docs = [r for r in ranked_results if r[1] >= self.filtering_threshold]
+                    can_stop_early = (
+                        self.early_stop  # Early stopping is enabled
+                        and self.num_docs_to_keep  # We have a target number of docs
+                        and len(high_scoring_docs) >= self.num_docs_to_keep  # Found enough good docs
+                        and score >= self.early_stop_threshold  # Current doc is good enough
+                    )
+
+                    if can_stop_early:
+                        log.info(f"Early stopping after finding {self.num_docs_to_keep} documents with high confidence")
+                        return ranked_results
+                except Exception as e:
+                    # Don't let early stopping errors stop the whole process
+                    log.warning(f"Error in early stopping check: {str(e)}")
+
         return ranked_results
 
     async def _backoff_wrapper(self, query: str, document: str, rerank_callback=None) -> Any:

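The restructured loop fails fast on any gathered exception and can return before scoring every batch; a runnable sketch of that control flow (the thresholds and the scorer are stand-ins for the reranker's fields and its LLM call):

    import asyncio

    FILTERING_THRESHOLD = 0.8  # stand-in for self.filtering_threshold
    NUM_DOCS_TO_KEEP = 2       # stand-in for self.num_docs_to_keep

    async def score(doc: str) -> dict:
        # Stand-in for the per-document LLM relevance call.
        return {"relevance_score": 0.9 if "relevant" in doc else 0.1}

    async def rank(docs):
        ranked = []
        results = await asyncio.gather(*[score(d) for d in docs], return_exceptions=True)
        for idx, result in enumerate(results):
            if isinstance(result, Exception):
                # A single failed request now aborts the rerank instead of being skipped.
                raise RuntimeError(f"Error during reranking: {result}")
            ranked.append((docs[idx], result["relevance_score"]))
            # Early stop once enough documents clear the threshold.
            if len([r for r in ranked if r[1] >= FILTERING_THRESHOLD]) >= NUM_DOCS_TO_KEEP:
                return ranked
        return ranked

    print(asyncio.run(rank(["relevant a", "relevant b", "noise c"])))
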
mindsdb/interfaces/agents/agents_controller.py

@@ -160,7 +160,7 @@ class AgentsController:
         Parameters:
             name (str): The name of the new agent
             project_name (str): The containing project
-            model_name (str): The name of the existing ML model the agent will use
+            model_name (str | dict): The name of the existing ML model the agent will use
             skills (List[Union[str, dict]]): List of existing skill names to add to the new agent, or list of dicts
                 with one of keys is "name", and other is additional parameters for relationship agent<>skill
             provider (str): The provider of the model

@@ -172,6 +172,9 @@ class AgentsController:
             include_knowledge_bases: List of knowledge bases to include for text2sql skills
             ignore_knowledge_bases: List of knowledge bases to ignore for text2sql skills
             <provider>_api_key: API key for the provider (e.g., openai_api_key)
+            data: Dict, data sources for an agent, keys:
+                - knowledge_bases: List of KBs to use (alternative to `include_knowledge_bases`)
+                - tables: list of tables to use (alternative to `include_tables`)
 
         Returns:
             agent (db.Agents): The created agent

@@ -188,12 +191,17 @@ class AgentsController:
         if agent is not None:
             raise ValueError(f"Agent with name already exists: {name}")
 
-        if model_name is not None:
-            _, provider = self.check_model_provider(model_name, provider)
-
         # No need to copy params since we're not preserving the original reference
         params = params or {}
 
+        if isinstance(model_name, dict):
+            # move into params
+            params["model"] = model_name
+            model_name = None
+
+        if model_name is not None:
+            _, provider = self.check_model_provider(model_name, provider)
+
         if model_name is None:
             logger.warning("'model_name' param is not provided. Using default global llm model at runtime.")

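Accepting a dict for model_name means a full model spec can be passed where a name used to go; a sketch of the normalization step (the spec keys shown are illustrative):

    def normalize_model_arg(model_name, params: dict):
        # A dict model spec moves into params["model"]; a plain string
        # keeps following the old named-model path.
        if isinstance(model_name, dict):
            params["model"] = model_name
            model_name = None
        return model_name, params

    model_name, params = normalize_model_arg({"model_name": "gpt-4o", "provider": "openai"}, {})
    assert model_name is None and params["model"]["provider"] == "openai"
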
@@ -230,6 +238,12 @@ class AgentsController:
         if "database" in params or need_params:
             params["database"] = database
 
+        if "data" in params:
+            if include_knowledge_bases is None:
+                include_knowledge_bases = params["data"].get("knowledge_bases")
+            if include_tables is None:
+                include_tables = params["data"].get("tables")
+
         if "knowledge_base_database" in params or include_knowledge_bases or ignore_knowledge_bases:
             params["knowledge_base_database"] = knowledge_base_database
 
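The new `data` key is an alternative spelling for the include_* arguments; a sketch of the fallback (the KB and table names are made up):

    params = {"data": {"knowledge_bases": ["my_project.my_kb"], "tables": ["example_db.sales"]}}
    include_knowledge_bases = None
    include_tables = None

    # Explicit arguments win; the "data" dict only fills in what is unset.
    if "data" in params:
        if include_knowledge_bases is None:
            include_knowledge_bases = params["data"].get("knowledge_bases")
        if include_tables is None:
            include_tables = params["data"].get("tables")

    assert include_knowledge_bases == ["my_project.my_kb"]
    assert include_tables == ["example_db.sales"]
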
@@ -549,13 +563,19 @@ class AgentsController:
         agent.deleted_at = datetime.datetime.now()
         db.session.commit()
 
-    def get_agent_llm_params(self,
+    def get_agent_llm_params(self, agent_params: dict):
         """
         Get agent LLM parameters by combining default config with user provided parameters.
         Similar to how knowledge bases handle default parameters.
         """
         combined_model_params = copy.deepcopy(config.get("default_llm", {}))
 
+        if "model" in agent_params:
+            model_params = agent_params["model"]
+        else:
+            # params for LLM can be arbitrary
+            model_params = agent_params
+
         if model_params:
             combined_model_params.update(model_params)
 
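Merging proceeds defaults-first, with a nested "model" dict taking precedence over treating the whole params blob as LLM arguments; a sketch (the default_llm values below are assumed):

    import copy

    DEFAULT_LLM = {"provider": "openai", "model_name": "gpt-4o"}  # assumed default_llm config

    def get_agent_llm_params(agent_params: dict) -> dict:
        combined = copy.deepcopy(DEFAULT_LLM)
        # Prefer an explicit "model" dict; otherwise the agent params
        # themselves are treated as LLM parameters.
        model_params = agent_params["model"] if "model" in agent_params else agent_params
        if model_params:
            combined.update(model_params)
        return combined

    merged = get_agent_llm_params({"model": {"provider": "anthropic", "model_name": "claude-3-5-sonnet"}})
    assert merged["provider"] == "anthropic"
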
@@ -596,9 +616,9 @@ class AgentsController:
         db.session.commit()
 
         # Get agent parameters and combine with default LLM parameters at runtime
-
+        llm_params = self.get_agent_llm_params(agent.params)
 
-        lang_agent = LangchainAgent(agent, model,
+        lang_agent = LangchainAgent(agent, model, llm_params=llm_params)
         return lang_agent.get_completion(messages)
 
     def _get_completion_stream(

@@ -636,7 +656,7 @@ class AgentsController:
         db.session.commit()
 
         # Get agent parameters and combine with default LLM parameters at runtime
-
+        llm_params = self.get_agent_llm_params(agent.params)
 
-        lang_agent = LangchainAgent(agent, model=model,
+        lang_agent = LangchainAgent(agent, model=model, llm_params=llm_params)
         return lang_agent.get_completion(messages, stream=True)

mindsdb/interfaces/agents/langchain_agent.py

@@ -228,7 +228,7 @@ def process_chunk(chunk):
 
 
 class LangchainAgent:
-    def __init__(self, agent: db.Agents, model: dict = None,
+    def __init__(self, agent: db.Agents, model: dict = None, llm_params: dict = None):
         self.agent = agent
         self.model = model
 
@@ -241,12 +241,12 @@ class LangchainAgent:
         self.mdb_langfuse_callback_handler: Optional[object] = None  # custom (see langfuse_callback_handler.py)
 
         self.langfuse_client_wrapper = LangfuseClientWrapper()
-        self.args = self._initialize_args(
+        self.args = self._initialize_args(llm_params)
 
         # Back compatibility for old models
         self.provider = self.args.get("provider", get_llm_provider(self.args))
 
-    def _initialize_args(self,
+    def _initialize_args(self, llm_params: dict = None) -> dict:
         """
         Initialize the arguments for agent execution.
 
@@ -254,14 +254,16 @@ class LangchainAgent:
         The params are already merged with defaults by AgentsController.get_agent_llm_params.
 
         Args:
-
+            llm_params: Parameters for agent execution (already merged with defaults)
 
         Returns:
             dict: Final parameters for agent execution
         """
         # Use the parameters passed to the method (already merged with defaults by AgentsController)
         # No fallback needed as AgentsController.get_agent_llm_params already handles this
-        args = params.copy()
+        args = self.agent.params.copy()
+        if llm_params:
+            args.update(llm_params)
 
         # Set model name and provider if given in create agent otherwise use global llm defaults
         # AgentsController.get_agent_llm_params

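The precedence in _initialize_args is therefore: stored agent params as the base, merged llm_params on top; sketched with plain dicts (the values are illustrative):

    agent_params = {"prompt_template": "You are helpful.", "temperature": 0.7}
    llm_params = {"temperature": 0.0, "model_name": "gpt-4o"}  # merged defaults from the controller

    args = agent_params.copy()
    if llm_params:
        args.update(llm_params)  # runtime LLM params override stored agent params

    assert args["temperature"] == 0.0
    assert args["prompt_template"] == "You are helpful."
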
mindsdb/interfaces/agents/mcp_client_agent.py

@@ -71,11 +71,11 @@ class MCPLangchainAgent(LangchainAgent):
         self,
         agent: db.Agents,
         model: dict = None,
-
+        llm_params: dict = None,
         mcp_host: str = "127.0.0.1",
         mcp_port: int = 47337,
     ):
-        super().__init__(agent, model,
+        super().__init__(agent, model, llm_params)
         self.mcp_host = mcp_host
         self.mcp_port = mcp_port
         self.exit_stack = AsyncExitStack()

@@ -251,10 +251,10 @@ def create_mcp_agent(
         raise ValueError(f"Agent {agent_name} not found in project {project_name}")
 
     # Get merged parameters (defaults + agent params)
-
+    llm_params = agent_controller.get_agent_llm_params(agent_db.params)
 
     # Create MCP agent with merged parameters
-    mcp_agent = MCPLangchainAgent(agent_db,
+    mcp_agent = MCPLangchainAgent(agent_db, llm_params=llm_params, mcp_host=mcp_host, mcp_port=mcp_port)
 
     # Wrap for LiteLLM compatibility
     return LiteLLMAgentWrapper(mcp_agent)