MindsDB 25.4.2.0__py3-none-any.whl → 25.4.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of MindsDB might be problematic. Click here for more details.
- mindsdb/__about__.py +1 -1
- mindsdb/api/executor/command_executor.py +29 -0
- mindsdb/api/executor/datahub/datanodes/information_schema_datanode.py +3 -2
- mindsdb/api/executor/datahub/datanodes/mindsdb_tables.py +43 -1
- mindsdb/api/executor/planner/plan_join.py +1 -1
- mindsdb/api/executor/planner/query_plan.py +1 -0
- mindsdb/api/executor/planner/query_planner.py +86 -14
- mindsdb/api/executor/planner/steps.py +9 -1
- mindsdb/api/executor/sql_query/sql_query.py +37 -6
- mindsdb/api/executor/sql_query/steps/__init__.py +1 -0
- mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py +288 -0
- mindsdb/integrations/handlers/langchain_embedding_handler/langchain_embedding_handler.py +17 -16
- mindsdb/integrations/handlers/langchain_handler/langchain_handler.py +1 -0
- mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py +7 -11
- mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +28 -4
- mindsdb/integrations/libs/llm/config.py +11 -1
- mindsdb/integrations/libs/llm/utils.py +12 -0
- mindsdb/interfaces/agents/constants.py +12 -1
- mindsdb/interfaces/agents/langchain_agent.py +6 -0
- mindsdb/interfaces/knowledge_base/controller.py +128 -43
- mindsdb/interfaces/query_context/context_controller.py +221 -0
- mindsdb/interfaces/storage/db.py +23 -0
- mindsdb/migrations/versions/2025-03-21_fda503400e43_queries.py +45 -0
- mindsdb/utilities/context_executor.py +1 -1
- mindsdb/utilities/partitioning.py +35 -20
- {mindsdb-25.4.2.0.dist-info → mindsdb-25.4.2.1.dist-info}/METADATA +224 -222
- {mindsdb-25.4.2.0.dist-info → mindsdb-25.4.2.1.dist-info}/RECORD +30 -28
- {mindsdb-25.4.2.0.dist-info → mindsdb-25.4.2.1.dist-info}/WHEEL +0 -0
- {mindsdb-25.4.2.0.dist-info → mindsdb-25.4.2.1.dist-info}/licenses/LICENSE +0 -0
- {mindsdb-25.4.2.0.dist-info → mindsdb-25.4.2.1.dist-info}/top_level.txt +0 -0
|
@@ -27,6 +27,8 @@ from mindsdb.integrations.libs.vectordatabase_handler import (
|
|
|
27
27
|
)
|
|
28
28
|
from mindsdb.integrations.utilities.rag.rag_pipeline_builder import RAG
|
|
29
29
|
from mindsdb.integrations.utilities.rag.config_loader import load_rag_config
|
|
30
|
+
from mindsdb.integrations.utilities.handler_utils import get_api_key
|
|
31
|
+
from mindsdb.integrations.handlers.langchain_embedding_handler.langchain_embedding_handler import construct_model_from_args, row_to_document
|
|
30
32
|
|
|
31
33
|
from mindsdb.interfaces.agents.constants import DEFAULT_EMBEDDINGS_MODEL_CLASS
|
|
32
34
|
from mindsdb.interfaces.agents.langchain_agent import create_chat_model, get_llm_provider
|
|
@@ -36,6 +38,7 @@ from mindsdb.interfaces.knowledge_base.preprocessing.document_preprocessor impor
|
|
|
36
38
|
from mindsdb.interfaces.model.functions import PredictorRecordNotFound
|
|
37
39
|
from mindsdb.utilities.exception import EntityExistsError, EntityNotExistsError
|
|
38
40
|
from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator
|
|
41
|
+
from mindsdb.utilities.context import context as ctx
|
|
39
42
|
|
|
40
43
|
from mindsdb.api.executor.command_executor import ExecuteCommands
|
|
41
44
|
from mindsdb.utilities import log
|
|
@@ -50,6 +53,42 @@ KB_TO_VECTORDB_COLUMNS = {
|
|
|
50
53
|
}
|
|
51
54
|
|
|
52
55
|
|
|
56
|
+
def get_embedding_model_from_params(embedding_model_params: dict):
    """
    Create embedding model from parameters.

    The input dict is deep-copied and never mutated. The required ``provider``
    value (lower-cased, underscores removed) becomes the model ``class``; the API
    key (resolved via ``get_api_key``, falling back to a plain ``api_key`` entry)
    is passed as ``<provider>_api_key`` (``openai_api_key`` for ``azure_openai``);
    ``model_name`` is renamed to ``model``. The result of
    ``construct_model_from_args`` is returned.

    :param embedding_model_params: the ``embedding_model`` parameters of a knowledge base
    :return: the constructed embedding model
    :raises ValueError: if ``provider`` is missing or empty
    """
    params_copy = copy.deepcopy(embedding_model_params)
    provider = params_copy.pop('provider', None)
    if not provider:
        # Without this check a missing provider fails with an opaque
        # "'NoneType' object has no attribute 'lower'" error.
        raise ValueError("The 'provider' parameter is required for the embedding model.")
    provider = provider.lower()

    api_key = get_api_key(provider, params_copy, strict=False) or params_copy.get('api_key')
    # Underscores are replaced because the provider name ultimately gets mapped to a class name.
    # This is mostly to support Azure OpenAI (azure_openai); the mapped class name is 'AzureOpenAIEmbeddings'.
    params_copy['class'] = provider.replace('_', '')
    if provider == 'azure_openai':
        # Azure OpenAI expects the api_key to be passed as 'openai_api_key'.
        params_copy['openai_api_key'] = api_key
    else:
        params_copy[f"{provider}_api_key"] = api_key
    params_copy.pop('api_key', None)
    params_copy['model'] = params_copy.pop('model_name', None)

    return construct_model_from_args(params_copy)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def get_reranking_model_from_params(reranking_model_params: dict):
    """
    Build the reranking model described by knowledge base parameters.

    Works on a deep copy of ``reranking_model_params``. ``provider`` defaults to
    ``"openai"`` and is the only accepted value. The API key is resolved with
    ``get_api_key`` (falling back to a plain ``api_key`` entry) and passed as
    ``<provider>_api_key``; ``model_name`` is renamed to ``model``. All
    remaining entries are forwarded to ``LLMReranker``.

    :raises ValueError: if a provider other than OpenAI is requested
    """
    settings = copy.deepcopy(reranking_model_params)
    provider_name = settings.pop('provider', "openai").lower()
    if provider_name != 'openai':
        raise ValueError("Only OpenAI provider is supported for the reranking model.")

    resolved_key = get_api_key(provider_name, settings, strict=False)
    if not resolved_key:
        resolved_key = settings.get('api_key')
    settings[f"{provider_name}_api_key"] = resolved_key
    settings.pop('api_key', None)

    settings['model'] = settings.pop('model_name', None)
    return LLMReranker(**settings)
|
|
90
|
+
|
|
91
|
+
|
|
53
92
|
class KnowledgeBaseTable:
|
|
54
93
|
"""
|
|
55
94
|
Knowledge base table interface
|
|
@@ -163,18 +202,17 @@ class KnowledgeBaseTable:
|
|
|
163
202
|
def add_relevance(self, df, query_text, reranking_threshold=None):
|
|
164
203
|
relevance_column = TableField.RELEVANCE.value
|
|
165
204
|
|
|
166
|
-
|
|
167
|
-
if
|
|
205
|
+
reranking_model_params = self._kb.params.get("reranking_model")
|
|
206
|
+
if reranking_model_params and query_text and len(df) > 0:
|
|
168
207
|
# Use reranker for relevance score
|
|
169
208
|
try:
|
|
170
|
-
logger.info(f"Using
|
|
171
|
-
reranker_params = {"model": rerank_model}
|
|
209
|
+
logger.info(f"Using knowledge reranking model from params: {reranking_model_params}")
|
|
172
210
|
# Apply custom filtering threshold if provided
|
|
173
211
|
if reranking_threshold is not None:
|
|
174
|
-
|
|
212
|
+
reranking_model_params["filtering_threshold"] = reranking_threshold
|
|
175
213
|
logger.info(f"Using custom filtering threshold: {reranking_threshold}")
|
|
176
214
|
|
|
177
|
-
reranker =
|
|
215
|
+
reranker = get_reranking_model_from_params(reranking_model_params)
|
|
178
216
|
# Get documents to rerank
|
|
179
217
|
documents = df['chunk_content'].tolist()
|
|
180
218
|
# Use the get_scores method with disable_events=True
|
|
@@ -185,7 +223,7 @@ class KnowledgeBaseTable:
|
|
|
185
223
|
# Filter by threshold
|
|
186
224
|
scores_array = np.array(scores)
|
|
187
225
|
df = df[scores_array > reranker.filtering_threshold]
|
|
188
|
-
logger.debug(f"Applied reranking with
|
|
226
|
+
logger.debug(f"Applied reranking with params: {reranking_model_params}")
|
|
189
227
|
except Exception as e:
|
|
190
228
|
logger.error(f"Error during reranking: {str(e)}")
|
|
191
229
|
# Fallback to distance-based relevance
|
|
@@ -198,6 +236,8 @@ class KnowledgeBaseTable:
|
|
|
198
236
|
# Calculate relevance from distance
|
|
199
237
|
logger.info("Calculating relevance from vector distance")
|
|
200
238
|
df[relevance_column] = 1 / (1 + df['distance'])
|
|
239
|
+
if reranking_threshold is not None:
|
|
240
|
+
df = df[df[relevance_column] > reranking_threshold]
|
|
201
241
|
|
|
202
242
|
else:
|
|
203
243
|
df[relevance_column] = None
|
|
@@ -373,6 +413,16 @@ class KnowledgeBaseTable:
|
|
|
373
413
|
if df.empty:
|
|
374
414
|
return
|
|
375
415
|
|
|
416
|
+
try:
|
|
417
|
+
run_query_id = ctx.run_query_id
|
|
418
|
+
# Link current KB to running query (where KB is used to insert data)
|
|
419
|
+
if run_query_id is not None:
|
|
420
|
+
self._kb.query_id = run_query_id
|
|
421
|
+
db.session.commit()
|
|
422
|
+
|
|
423
|
+
except AttributeError:
|
|
424
|
+
...
|
|
425
|
+
|
|
376
426
|
# First adapt column names to identify content and metadata columns
|
|
377
427
|
adapted_df = self._adapt_column_names(df)
|
|
378
428
|
content_columns = self._kb.params.get('content_columns', [TableField.CONTENT.value])
|
|
@@ -577,36 +627,48 @@ class KnowledgeBaseTable:
|
|
|
577
627
|
if df.empty:
|
|
578
628
|
return pd.DataFrame([], columns=[TableField.EMBEDDINGS.value])
|
|
579
629
|
|
|
630
|
+
# keep only content
|
|
631
|
+
df = df[[TableField.CONTENT.value]]
|
|
632
|
+
|
|
580
633
|
model_id = self._kb.embedding_model_id
|
|
581
|
-
|
|
582
|
-
|
|
634
|
+
if model_id:
|
|
635
|
+
# get the input columns
|
|
636
|
+
model_rec = db.session.query(db.Predictor).filter_by(id=model_id).first()
|
|
583
637
|
|
|
584
|
-
|
|
585
|
-
|
|
638
|
+
assert model_rec is not None, f"Model not found: {model_id}"
|
|
639
|
+
model_project = db.session.query(db.Project).filter_by(id=model_rec.project_id).first()
|
|
586
640
|
|
|
587
|
-
|
|
641
|
+
project_datanode = self.session.datahub.get(model_project.name)
|
|
588
642
|
|
|
589
|
-
|
|
590
|
-
|
|
643
|
+
model_using = model_rec.learn_args.get('using', {})
|
|
644
|
+
input_col = model_using.get('question_column')
|
|
645
|
+
if input_col is None:
|
|
646
|
+
input_col = model_using.get('input_column')
|
|
591
647
|
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
if input_col is None:
|
|
595
|
-
input_col = model_using.get('input_column')
|
|
648
|
+
if input_col is not None and input_col != TableField.CONTENT.value:
|
|
649
|
+
df = df.rename(columns={TableField.CONTENT.value: input_col})
|
|
596
650
|
|
|
597
|
-
|
|
598
|
-
|
|
651
|
+
df_out = project_datanode.predict(
|
|
652
|
+
model_name=model_rec.name,
|
|
653
|
+
df=df,
|
|
654
|
+
params=self.model_params
|
|
655
|
+
)
|
|
599
656
|
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
657
|
+
target = model_rec.to_predict[0]
|
|
658
|
+
if target != TableField.EMBEDDINGS.value:
|
|
659
|
+
# adapt output for vectordb
|
|
660
|
+
df_out = df_out.rename(columns={target: TableField.EMBEDDINGS.value})
|
|
661
|
+
|
|
662
|
+
elif self._kb.params.get('embedding_model'):
|
|
663
|
+
embedding_model = get_embedding_model_from_params(self._kb.params.get('embedding_model'))
|
|
664
|
+
|
|
665
|
+
df_texts = df.apply(row_to_document, axis=1)
|
|
666
|
+
embeddings = embedding_model.embed_documents(df_texts.tolist())
|
|
667
|
+
df_out = df.copy().assign(**{TableField.EMBEDDINGS.value: embeddings})
|
|
668
|
+
|
|
669
|
+
else:
|
|
670
|
+
raise ValueError("No embedding model found for the knowledge base.")
|
|
605
671
|
|
|
606
|
-
target = model_rec.to_predict[0]
|
|
607
|
-
if target != TableField.EMBEDDINGS.value:
|
|
608
|
-
# adapt output for vectordb
|
|
609
|
-
df_out = df_out.rename(columns={target: TableField.EMBEDDINGS.value})
|
|
610
672
|
df_out = df_out[[TableField.EMBEDDINGS.value]]
|
|
611
673
|
|
|
612
674
|
return df_out
|
|
@@ -640,9 +702,11 @@ class KnowledgeBaseTable:
|
|
|
640
702
|
# Extract embedding model args from knowledge base table
|
|
641
703
|
embedding_args = self._kb.embedding_model.learn_args.get('using', {})
|
|
642
704
|
# Construct the embedding model directly
|
|
643
|
-
from mindsdb.integrations.handlers.langchain_embedding_handler.langchain_embedding_handler import construct_model_from_args
|
|
644
705
|
embeddings_model = construct_model_from_args(embedding_args)
|
|
645
706
|
logger.debug(f"Using knowledge base embedding model with args: {embedding_args}")
|
|
707
|
+
elif self._kb.params.get('embedding_model'):
|
|
708
|
+
embeddings_model = get_embedding_model_from_params(self._kb.params['embedding_model'])
|
|
709
|
+
logger.debug(f"Using knowledge base embedding model from params: {self._kb.params['embedding_model']}")
|
|
646
710
|
else:
|
|
647
711
|
embeddings_model = DEFAULT_EMBEDDINGS_MODEL_CLASS()
|
|
648
712
|
logger.debug("Using default embedding model as knowledge base has no embedding model")
|
|
@@ -788,26 +852,46 @@ class KnowledgeBaseController:
|
|
|
788
852
|
return kb
|
|
789
853
|
raise EntityExistsError("Knowledge base already exists", name)
|
|
790
854
|
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
else:
|
|
796
|
-
# get embedding model from input
|
|
855
|
+
embedding_model_params = params.get('embedding_model', None)
|
|
856
|
+
reranking_model_params = params.get('reranking_model', None)
|
|
857
|
+
|
|
858
|
+
if embedding_model:
|
|
797
859
|
model_name = embedding_model.parts[-1]
|
|
798
860
|
|
|
861
|
+
elif embedding_model_params:
|
|
862
|
+
# Get embedding model from params.
|
|
863
|
+
# This is called here to check validaity of the parameters.
|
|
864
|
+
get_embedding_model_from_params(
|
|
865
|
+
embedding_model_params
|
|
866
|
+
)
|
|
867
|
+
|
|
868
|
+
else:
|
|
869
|
+
model_name = self._get_default_embedding_model(
|
|
870
|
+
project.name,
|
|
871
|
+
params=params
|
|
872
|
+
)
|
|
873
|
+
params['default_embedding_model'] = model_name
|
|
874
|
+
|
|
875
|
+
model_project = None
|
|
799
876
|
if embedding_model is not None and len(embedding_model.parts) > 1:
|
|
800
877
|
# model project is set
|
|
801
878
|
model_project = self.session.database_controller.get_project(embedding_model.parts[-2])
|
|
802
|
-
|
|
879
|
+
elif not embedding_model_params:
|
|
803
880
|
model_project = project
|
|
804
881
|
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
882
|
+
embedding_model_id = None
|
|
883
|
+
if model_project:
|
|
884
|
+
model = self.session.model_controller.get_model(
|
|
885
|
+
name=model_name,
|
|
886
|
+
project_name=model_project.name
|
|
887
|
+
)
|
|
888
|
+
model_record = db.Predictor.query.get(model['id'])
|
|
889
|
+
embedding_model_id = model_record.id
|
|
890
|
+
|
|
891
|
+
if reranking_model_params:
|
|
892
|
+
# Get reranking model from params.
|
|
893
|
+
# This is called here to check validaity of the parameters.
|
|
894
|
+
get_reranking_model_from_params(reranking_model_params)
|
|
811
895
|
|
|
812
896
|
# search for the vector database table
|
|
813
897
|
if storage is None:
|
|
@@ -1029,6 +1113,7 @@ class KnowledgeBaseController:
|
|
|
1029
1113
|
'embedding_model': embedding_model.name if embedding_model is not None else None,
|
|
1030
1114
|
'vector_database': None if vector_database is None else vector_database.name,
|
|
1031
1115
|
'vector_database_table': record.vector_database_table,
|
|
1116
|
+
'query_id': record.query_id,
|
|
1032
1117
|
'params': record.params
|
|
1033
1118
|
})
|
|
1034
1119
|
|
|
@@ -1,11 +1,17 @@
|
|
|
1
1
|
from typing import List
|
|
2
|
+
import pickle
|
|
3
|
+
import datetime as dt
|
|
2
4
|
|
|
5
|
+
from sqlalchemy.orm.attributes import flag_modified
|
|
3
6
|
import pandas as pd
|
|
4
7
|
|
|
8
|
+
from mindsdb_sql_parser import Select, Star, OrderBy
|
|
9
|
+
|
|
5
10
|
from mindsdb_sql_parser.ast import (
|
|
6
11
|
Identifier, BinaryOperation, Last, Constant, ASTNode
|
|
7
12
|
)
|
|
8
13
|
from mindsdb.integrations.utilities.query_traversal import query_traversal
|
|
14
|
+
from mindsdb.utilities.cache import get_cache
|
|
9
15
|
|
|
10
16
|
from mindsdb.interfaces.storage import db
|
|
11
17
|
from mindsdb.utilities.context import context as ctx
|
|
@@ -13,6 +19,147 @@ from mindsdb.utilities.context import context as ctx
|
|
|
13
19
|
from .last_query import LastQuery
|
|
14
20
|
|
|
15
21
|
|
|
22
|
+
class RunningQuery:
    """
    Query in progress.

    Wraps a ``db.Queries`` record and keeps in it what is needed to split a step
    into partitions and to resume the query after a failure: the partitioning
    parameters, the last processed value of the tracking column, the current
    step number, the number of processed rows and the last error.
    """

    # used when the partitioning parameters don't define `batch_size`
    DEFAULT_BATCH_SIZE = 1000

    def __init__(self, record: db.Queries):
        self.record = record
        self.sql = record.sql
        # Available even before `set_params` is called, e.g. when the query was
        # loaded from the database to be resumed.
        self.batch_size = (record.parameters or {}).get('batch_size')

    def _cache_key(self) -> str:
        """Key of the saved steps data in the 'steps_data' cache (always a string)."""
        return str(self.record.id)

    def get_partition_query(self, step_num: int, query: Select) -> Select:
        """
        Generate query for fetching the next partition.

        It wraps the query to:
            select * from ({query})
            where {track_column} > {previous_value}
            order by {track_column}
            limit {batch_size}
        and fills in track_column, previous_value and batch_size. The `where`
        clause is added only when a track value is stored for this step.

        :param step_num: number of the step being partitioned; when it differs
            from the stored step number, the stored track value is discarded
        :param query: query of the step
        :return: the wrapped query
        """
        parameters = self.record.parameters
        track_column = parameters['track_column']
        # read from the record so this works without a prior `set_params` call
        batch_size = parameters.get('batch_size', self.DEFAULT_BATCH_SIZE)

        query = Select(
            targets=[Star()],
            from_table=query,
            order_by=[OrderBy(Identifier(track_column))],
            limit=Constant(batch_size)
        )

        context = self.record.context
        track_value = context.get('track_value')
        cur_step_num = context.get('step_num')
        if cur_step_num != step_num:
            if cur_step_num is not None:
                # another step is partitioned now: reset track_value
                track_value = None
                context['track_value'] = None
            # Store the step number even when none was stored before; otherwise a
            # switch to the next partitioned step would go unnoticed and its query
            # would be filtered by the previous step's track value.
            context['step_num'] = step_num
            flag_modified(self.record, 'context')
            db.session.commit()

        if track_value is not None:
            query.where = BinaryOperation(
                op='>',
                args=[Identifier(track_column), Constant(track_value)],
            )

        return query

    def set_params(self, params: dict):
        """
        Store parameters of the step which is about to be split into partitions.

        :param params: must contain `track_column`; `batch_size` is set to 1000
            in `params` when missing
        :raises ValueError: if `track_column` is not defined
        """
        if 'track_column' not in params:
            raise ValueError('Track column is not defined')
        if 'batch_size' not in params:
            params['batch_size'] = self.DEFAULT_BATCH_SIZE

        self.record.parameters = params
        self.batch_size = params['batch_size']
        db.session.commit()

    def get_max_track_value(self, df: pd.DataFrame):
        """
        Return the max value of the tracking column in `df`, to use in `set_progress`.

        This function is called before the substeps are executed,
        `set_progress` - after.
        """
        track_column = self.record.parameters['track_column']
        return df[track_column].max()

    def set_progress(self, df: pd.DataFrame, max_track_value: int):
        """
        Store progress of the query; it is called after a batch is processed.

        Adds the number of rows in `df` to `processed_rows` and moves the stored
        track value forward to `max_track_value` (it is never decreased).
        An empty batch changes nothing.
        """
        if len(df) == 0:
            return

        self.record.processed_rows = (self.record.processed_rows or 0) + len(df)

        cur_value = self.record.context.get('track_value')
        if max_track_value is not None and (cur_value is None or max_track_value > cur_value):
            self.record.context['track_value'] = max_track_value
            flag_modified(self.record, 'context')

        db.session.commit()

    def on_error(self, error: Exception, step_num: int, steps_data: dict):
        """
        Save the error of the query in the database.

        Also saves the step number in the record context and the pickled
        `steps_data` in the 'steps_data' cache, to be able to resume the query.
        """
        self.record.error = str(error)
        self.record.context['step_num'] = step_num
        flag_modified(self.record, 'context')

        # save steps_data
        cache = get_cache('steps_data')
        data = pickle.dumps(steps_data, protocol=5)
        cache.set(self._cache_key(), data)

        db.session.commit()

    def clear_error(self):
        """
        Reset error of the query in database.
        """
        if self.record.error is not None:
            self.record.error = None
            db.session.commit()

    def get_state(self) -> dict:
        """
        Return the stored state for resuming the query and remove it from the cache.

        :return: dict with `step_num` and `steps_data`
        :raises RuntimeError: if no saved steps data is found for the query
        """
        cache = get_cache('steps_data')
        # must be the same (string) key that `on_error` used to save the data
        key = self._cache_key()
        data = cache.get(key)
        if data is None:
            raise RuntimeError(f'Saved state of the query is not found: {self.record.id}')
        cache.delete(key)

        # NOTE(review): unpickling assumes the cache content was written by
        # `on_error` and the cache is not writable by untrusted parties.
        steps_data = pickle.loads(data)

        return {
            'step_num': self.record.context.get('step_num'),
            'steps_data': steps_data,
        }

    def finish(self):
        """
        Mark query as finished.
        """
        self.record.finished_at = dt.datetime.now()
        db.session.commit()
|
|
161
|
+
|
|
162
|
+
|
|
16
163
|
class QueryContextController:
|
|
17
164
|
IGNORE_CONTEXT = '<IGNORE>'
|
|
18
165
|
|
|
@@ -287,5 +434,79 @@ class QueryContextController:
|
|
|
287
434
|
rec.values = values
|
|
288
435
|
db.session.commit()
|
|
289
436
|
|
|
437
|
+
def get_query(self, query_id: int) -> RunningQuery:
    """
    Look up a running query of the current company by its id.

    :param query_id: id of the `db.Queries` record
    :return: `RunningQuery` wrapping the found record
    :raises RuntimeError: if the company has no query with this id
    """
    lookup = db.Queries.query.filter(
        db.Queries.id == query_id,
        db.Queries.company_id == ctx.company_id
    )
    found = lookup.first()
    if found is not None:
        return RunningQuery(found)
    raise RuntimeError(f'Query not found: {query_id}')
|
|
450
|
+
|
|
451
|
+
def create_query(self, query: ASTNode) -> RunningQuery:
    """
    Register a new running query for an AST query.

    Before inserting, records of the current company that finished more than
    one day ago are deleted; both changes are committed together.

    :param query: the query; its string form is stored as `sql`
    :return: `RunningQuery` wrapping the new record
    """
    expiry_border = dt.datetime.now() - dt.timedelta(days=1)
    stale_records = db.session.query(db.Queries).filter(
        db.Queries.company_id == ctx.company_id,
        db.Queries.finished_at < expiry_border
    ).all()
    for stale in stale_records:
        db.session.delete(stale)

    new_record = db.Queries(sql=str(query), company_id=ctx.company_id)
    db.session.add(new_record)
    db.session.commit()
    return RunningQuery(new_record)
|
|
472
|
+
|
|
473
|
+
def list_queries(self) -> List[dict]:
    """
    Describe every running query of the current company.

    :return: one dict per `db.Queries` record with its id, sql, timestamps,
        parameters, context, processed row count and error
    """
    reported_fields = (
        'id', 'sql', 'started_at', 'finished_at', 'parameters',
        'context', 'processed_rows', 'error', 'updated_at',
    )
    company_records = db.session.query(db.Queries).filter(
        db.Queries.company_id == ctx.company_id
    )
    return [
        {field: getattr(record, field) for field in reported_fields}
        for record in company_records
    ]
|
|
495
|
+
|
|
496
|
+
def cancel_query(self, query_id: int):
    """
    Cancel a running query of the current company by deleting its record.

    :param query_id: id of the `db.Queries` record
    :raises RuntimeError: if the company has no query with this id
    """
    matching = db.Queries.query.filter(
        db.Queries.id == query_id,
        db.Queries.company_id == ctx.company_id
    )
    target = matching.first()
    if target is None:
        raise RuntimeError(f'Query not found: {query_id}')

    # The process executing the query fails when it next tries to update
    # the status, because its record no longer exists.
    db.session.delete(target)
    db.session.commit()
|
|
510
|
+
|
|
290
511
|
|
|
291
512
|
query_context_controller = QueryContextController()
|
mindsdb/interfaces/storage/db.py
CHANGED
|
@@ -523,6 +523,7 @@ class KnowledgeBase(Base):
|
|
|
523
523
|
embedding_model = relationship(
|
|
524
524
|
"Predictor", foreign_keys=[embedding_model_id], doc="embedding model"
|
|
525
525
|
)
|
|
526
|
+
query_id = Column(Integer, nullable=True)
|
|
526
527
|
|
|
527
528
|
created_at = Column(DateTime, default=datetime.datetime.now)
|
|
528
529
|
updated_at = Column(
|
|
@@ -564,6 +565,28 @@ class QueryContext(Base):
|
|
|
564
565
|
created_at: datetime.datetime = Column(DateTime, default=datetime.datetime.now)
|
|
565
566
|
|
|
566
567
|
|
|
568
|
+
class Queries(Base):
    """
    State of a (possibly long-running, resumable) query execution.

    Holds the query text, timing, the partitioning parameters, a JSON context
    with progress data (e.g. the last tracked value and step number), the number
    of processed rows and the last error.
    """
    __tablename__ = "queries"
    id: int = Column(Integer, primary_key=True)
    company_id: int = Column(Integer, nullable=True)

    sql: str = Column(String, nullable=False)

    started_at: datetime.datetime = Column(DateTime, default=datetime.datetime.now)
    finished_at: datetime.datetime = Column(DateTime)

    # Callable defaults: every new row gets its own empty dict instead of all rows
    # sharing (and possibly mutating in place) a single module-level `{}` instance.
    parameters = Column(JSON, default=dict)
    context = Column(JSON, default=dict)
    processed_rows = Column(Integer, default=0)
    error: str = Column(String, nullable=True)

    updated_at: datetime.datetime = Column(
        DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now
    )
    created_at: datetime.datetime = Column(DateTime, default=datetime.datetime.now)
|
|
588
|
+
|
|
589
|
+
|
|
567
590
|
class LLMLog(Base):
|
|
568
591
|
__tablename__ = "llm_log"
|
|
569
592
|
id: int = Column(Integer, primary_key=True)
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""queries
|
|
2
|
+
|
|
3
|
+
Revision ID: fda503400e43
|
|
4
|
+
Revises: 11347c213b36
|
|
5
|
+
Create Date: 2025-03-21 18:50:20.795930
|
|
6
|
+
|
|
7
|
+
"""
|
|
8
|
+
from alembic import op
|
|
9
|
+
import sqlalchemy as sa
|
|
10
|
+
import mindsdb.interfaces.storage.db # noqa
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# revision identifiers, used by Alembic.
# This revision creates the `queries` table and the nullable
# `knowledge_base.query_id` column (see `upgrade` / `downgrade` below);
# it is applied on top of revision 11347c213b36.
revision = 'fda503400e43'
down_revision = '11347c213b36'
branch_labels = None
depends_on = None
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def upgrade():
    """Create the `queries` table and add the nullable `knowledge_base.query_id` column."""
    queries_columns = [
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('company_id', sa.Integer(), nullable=True),
        sa.Column('sql', sa.String(), nullable=False),
        sa.Column('started_at', sa.DateTime(), nullable=True),
        sa.Column('finished_at', sa.DateTime(), nullable=True),
        sa.Column('parameters', sa.JSON(), nullable=True),
        sa.Column('context', sa.JSON(), nullable=True),
        sa.Column('processed_rows', sa.Integer(), nullable=True),
        sa.Column('error', sa.String(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
    ]
    op.create_table('queries', *queries_columns, sa.PrimaryKeyConstraint('id'))

    # batch mode lets the column be added on SQLite as well
    with op.batch_alter_table('knowledge_base', schema=None) as kb_table:
        kb_table.add_column(sa.Column('query_id', sa.INTEGER(), nullable=True))
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def downgrade():
    """Revert `upgrade`: drop `knowledge_base.query_id`, then the `queries` table."""
    with op.batch_alter_table('knowledge_base', schema=None) as kb_table:
        kb_table.drop_column('query_id')
    op.drop_table('queries')
|
|
@@ -43,7 +43,7 @@ def execute_in_threads(func, tasks, thread_count=3, queue_size_k=1.5):
|
|
|
43
43
|
for i in range(queue_size):
|
|
44
44
|
try:
|
|
45
45
|
args = next(tasks)
|
|
46
|
-
futures.append(executor.submit(func,
|
|
46
|
+
futures.append(executor.submit(func, args))
|
|
47
47
|
except StopIteration:
|
|
48
48
|
break
|
|
49
49
|
|
|
@@ -6,6 +6,35 @@ from mindsdb.utilities.config import Config
|
|
|
6
6
|
from mindsdb.utilities.context_executor import execute_in_threads
|
|
7
7
|
|
|
8
8
|
|
|
9
|
+
def get_max_thread_count() -> int:
    """
    Calculate the maximum number of threads allowed for partitioned processing.

    In the cloud the limit is read from the ``MINDSDB_MAX_PARTITIONING_THREADS``
    environment variable (default 10); otherwise it is the CPU count minus 3.
    The result is never less than 1.

    :return: maximum number of worker threads
    """
    if Config().is_cloud:
        max_threads = int(os.getenv('MINDSDB_MAX_PARTITIONING_THREADS', 10))
    else:
        # os.cpu_count() returns None when the number of CPUs can't be determined
        max_threads = (os.cpu_count() or 1) - 3

    return max(max_threads, 1)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def split_data_frame(df: pd.DataFrame, partition_size: int) -> Iterable[pd.DataFrame]:
    """
    Split data frame into chunks of `partition_size` rows and yield them out.

    Chunks are consecutive `iloc` slices; the last one may be shorter.
    An empty data frame yields nothing.

    :param df: data frame to split
    :param partition_size: number of rows per chunk, must be >= 1
    :raises ValueError: if `partition_size` is less than 1 (a zero or negative
        size would otherwise make the generator loop forever)
    """
    if partition_size < 1:
        raise ValueError(f'partition_size must be a positive integer, got {partition_size}')

    for start in range(0, len(df), partition_size):
        yield df.iloc[start: start + partition_size]
|
|
36
|
+
|
|
37
|
+
|
|
9
38
|
def process_dataframe_in_partitions(df: pd.DataFrame, callback: Callable, partition_size: int) -> Iterable:
|
|
10
39
|
"""
|
|
11
40
|
Splits dataframe into partitions and apply callback on each partition
|
|
@@ -17,35 +46,21 @@ def process_dataframe_in_partitions(df: pd.DataFrame, callback: Callable, partit
|
|
|
17
46
|
"""
|
|
18
47
|
|
|
19
48
|
# tasks
|
|
20
|
-
def split_data_f(df):
|
|
21
|
-
chunk = 0
|
|
22
|
-
while chunk * partition_size < len(df):
|
|
23
|
-
# create results with partition
|
|
24
|
-
df1 = df.iloc[chunk * partition_size: (chunk + 1) * partition_size]
|
|
25
|
-
chunk += 1
|
|
26
|
-
yield [df1]
|
|
27
49
|
|
|
28
|
-
tasks =
|
|
50
|
+
tasks = split_data_frame(df, partition_size)
|
|
29
51
|
|
|
30
|
-
|
|
31
|
-
is_cloud = Config().is_cloud
|
|
32
|
-
if is_cloud:
|
|
33
|
-
max_threads = int(os.getenv('MINDSDB_MAX_PARTITIONING_THREADS', 10))
|
|
34
|
-
else:
|
|
35
|
-
max_threads = os.cpu_count() - 2
|
|
52
|
+
max_threads = get_max_thread_count()
|
|
36
53
|
|
|
37
|
-
# don't exceed chunk_count
|
|
38
54
|
chunk_count = int(len(df) / partition_size)
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
max_threads = 1
|
|
55
|
+
# don't exceed chunk_count
|
|
56
|
+
if chunk_count > 0:
|
|
57
|
+
max_threads = min(max_threads, chunk_count)
|
|
43
58
|
|
|
44
59
|
if max_threads == 1:
|
|
45
60
|
# don't spawn threads
|
|
46
61
|
|
|
47
62
|
for task in tasks:
|
|
48
|
-
yield callback(
|
|
63
|
+
yield callback(task)
|
|
49
64
|
|
|
50
65
|
else:
|
|
51
66
|
for result in execute_in_threads(callback, tasks, thread_count=max_threads):
|