MindsDB 25.4.2.0__py3-none-any.whl → 25.4.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of MindsDB might be problematic.
- mindsdb/__about__.py +1 -1
- mindsdb/api/executor/command_executor.py +29 -0
- mindsdb/api/executor/datahub/datanodes/information_schema_datanode.py +3 -2
- mindsdb/api/executor/datahub/datanodes/mindsdb_tables.py +43 -1
- mindsdb/api/executor/planner/plan_join.py +1 -1
- mindsdb/api/executor/planner/query_plan.py +1 -0
- mindsdb/api/executor/planner/query_planner.py +86 -14
- mindsdb/api/executor/planner/steps.py +9 -1
- mindsdb/api/executor/sql_query/sql_query.py +37 -6
- mindsdb/api/executor/sql_query/steps/__init__.py +1 -0
- mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py +288 -0
- mindsdb/integrations/handlers/langchain_embedding_handler/langchain_embedding_handler.py +17 -16
- mindsdb/integrations/handlers/langchain_handler/langchain_handler.py +1 -0
- mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py +7 -11
- mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +28 -4
- mindsdb/integrations/libs/llm/config.py +11 -1
- mindsdb/integrations/libs/llm/utils.py +12 -0
- mindsdb/interfaces/agents/constants.py +12 -1
- mindsdb/interfaces/agents/langchain_agent.py +6 -0
- mindsdb/interfaces/knowledge_base/controller.py +128 -43
- mindsdb/interfaces/query_context/context_controller.py +221 -0
- mindsdb/interfaces/storage/db.py +23 -0
- mindsdb/migrations/versions/2025-03-21_fda503400e43_queries.py +45 -0
- mindsdb/utilities/context_executor.py +1 -1
- mindsdb/utilities/partitioning.py +35 -20
- {mindsdb-25.4.2.0.dist-info → mindsdb-25.4.2.1.dist-info}/METADATA +224 -222
- {mindsdb-25.4.2.0.dist-info → mindsdb-25.4.2.1.dist-info}/RECORD +30 -28
- {mindsdb-25.4.2.0.dist-info → mindsdb-25.4.2.1.dist-info}/WHEEL +0 -0
- {mindsdb-25.4.2.0.dist-info → mindsdb-25.4.2.1.dist-info}/licenses/LICENSE +0 -0
- {mindsdb-25.4.2.0.dist-info → mindsdb-25.4.2.1.dist-info}/top_level.txt +0 -0
mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py (new file)

@@ -0,0 +1,288 @@
+import pandas as pd
+import threading
+import queue
+from typing import List
+
+from mindsdb_sql_parser import ASTNode
+from mindsdb.api.executor.planner.steps import FetchDataframeStepPartition
+from mindsdb.integrations.utilities.query_traversal import query_traversal
+
+from mindsdb.interfaces.query_context.context_controller import RunningQuery
+from mindsdb.api.executor.sql_query.result_set import ResultSet
+from mindsdb.utilities import log
+from mindsdb.utilities.config import Config
+from mindsdb.utilities.context import Context, context as ctx
+from mindsdb.utilities.partitioning import get_max_thread_count, split_data_frame
+from mindsdb.api.executor.sql_query.steps.fetch_dataframe import get_table_alias, get_fill_param_fnc
+
+from .base import BaseStepCall
+
+
+logger = log.getLogger(__name__)
+
+
+class FetchDataframePartitionCall(BaseStepCall):
+    """
+    Alternative to FetchDataframeCall, but fetches data in batches by wrapping the user's query as:
+
+        select * from ({user query})
+        where {track_column} > {previous value}
+        order by track_column
+        limit {batch_size}
+    """
+
+    bind = FetchDataframeStepPartition
+
+    def call(self, step: FetchDataframeStepPartition) -> ResultSet:
+        """
+        Parameters:
+        - batch_size - count of rows to fetch from the database per iteration, optional, default 1000
+        - threads - run partitioning in threads, bool or int, optional, if set:
+          - int value: use this as the count of threads
+          - true: enable threads, autodetect the count of threads
+          - false: disable threads even if the ml task queue is enabled
+        - track_column - column used for creating partitions
+          - the query will be sorted by this column and the select will be limited by batch_size
+        - error (default 'raise')
+          - when `error='skip'`, errors in a partition are skipped and execution continues
+        """
+
+        self.dn = self.session.datahub.get(step.integration)
+        query = step.query
+
+        # fill params
+        fill_params = get_fill_param_fnc(self.steps_data)
+        query_traversal(query, fill_params)
+
+        # get query record
+        run_query = self.sql_query.run_query
+        if run_query is None:
+            raise RuntimeError('Error with partitioning of the query')
+        run_query.set_params(step.params)
+
+        self.table_alias = get_table_alias(step.query.from_table, self.context.get('database'))
+        self.current_step_num = step.step_num
+        self.substeps = step.steps
+
+        config = Config()
+
+        # ml task queue enabled?
+        use_threads, thread_count = False, None
+        if config['ml_task_queue']['type'] == 'redis':
+            use_threads = True
+
+        # use threads?
+        if 'threads' in step.params:
+            threads = step.params['threads']
+            if isinstance(threads, int):
+                thread_count = threads
+                use_threads = True
+            if threads is True:
+                use_threads = True
+            if threads is False:
+                # disable even with ml task queue
+                use_threads = False
+
+        on_error = step.params.get('error', 'raise')
+        if use_threads:
+            return self.fetch_threads(run_query, query, thread_count=thread_count, on_error=on_error)
+        else:
+            return self.fetch_iterate(run_query, query, on_error=on_error)
+
+    def fetch_iterate(self, run_query: RunningQuery, query: ASTNode, on_error: str = None) -> ResultSet:
+        """
+        Process batches one by one in a loop
+        """
+
+        results = []
+        while True:
+
+            # fetch batch
+            query2 = run_query.get_partition_query(self.current_step_num, query)
+            response = self.dn.query(
+                query=query2,
+                session=self.session
+            )
+            df = response.data_frame
+
+            if df is None or len(df) == 0:
+                break
+
+            # executing sub-steps can modify dataframe columns, memorise the max tracking value first
+            max_track_value = run_query.get_max_track_value(df)
+            try:
+                sub_data = self.exec_sub_steps(df)
+                results.append(sub_data)
+            except Exception as e:
+                if on_error == 'skip':
+                    logger.error(e)
+                else:
+                    raise e
+
+            run_query.set_progress(df, max_track_value)
+
+        return self.concat_results(results)
+
+    def concat_results(self, results: List[ResultSet]) -> ResultSet:
+        """
+        Concatenate a list of result sets into a single result set
+        """
+        df_list = []
+        for res in results:
+            df, col_names = res.to_df_cols()
+            if len(df) > 0:
+                df_list.append(df)
+
+        data = ResultSet()
+        if len(df_list) > 0:
+            data.from_df_cols(pd.concat(df_list), col_names)
+
+        return data
+
+    def exec_sub_steps(self, df: pd.DataFrame) -> ResultSet:
+        """
+        FetchDataframeStepPartition has substeps defined.
+        Every batch of data has to be used to execute these substeps:
+        - the batch of data is put as the result of FetchDataframeStepPartition
+        - substeps are executed using the result of the previous step (as if all fetched data were available)
+        - the final result is returned and used outside to concatenate with the results of the other batches
+        """
+
+        input_data = ResultSet()
+
+        input_data.from_df(
+            df,
+            table_name=self.table_alias[1],
+            table_alias=self.table_alias[2],
+            database=self.table_alias[0]
+        )
+
+        # execute with modified previous results
+        steps_data2 = self.steps_data.copy()
+        steps_data2[self.current_step_num] = input_data
+
+        sub_data = None
+        for substep in self.substeps:
+            sub_data = self.sql_query.execute_step(substep, steps_data=steps_data2)
+            steps_data2[substep.step_num] = sub_data
+        return sub_data
+
+    def fetch_threads(self, run_query: RunningQuery, query: ASTNode,
+                      thread_count: int = None, on_error: str = None) -> ResultSet:
+        """
+        Process batches in threads
+        - spawn the required count of threads
+        - create in/out queues to communicate with the threads
+        - send tasks to the threads and receive results
+        """
+
+        # create communication queues
+        queue_in = queue.Queue()
+        queue_out = queue.Queue()
+        self.stop_event = threading.Event()
+
+        if thread_count is None:
+            thread_count = get_max_thread_count()
+
+        # 3 tasks per worker during 1 batch
+        partition_size = int(run_query.batch_size / thread_count / 3)
+        # min partition size
+        if partition_size < 10:
+            partition_size = 10
+
+        # create N workers pool
+        workers = []
+        results = []
+
+        try:
+            for i in range(thread_count):
+                worker = threading.Thread(target=self._worker, daemon=True, args=(ctx.dump(), queue_in,
+                                                                                  queue_out, self.stop_event))
+                worker.start()
+                workers.append(worker)
+
+            while True:
+                # fetch batch
+                query2 = run_query.get_partition_query(self.current_step_num, query)
+                response = self.dn.query(
+                    query=query2,
+                    session=self.session
+                )
+                df = response.data_frame
+
+                if df is None or len(df) == 0:
+                    # TODO detect circles: data handler ignores condition and output is repeated
+
+                    # exit & stop workers
+                    break
+
+                max_track_value = run_query.get_max_track_value(df)
+
+                # split into chunks and send to workers
+                sent_chunks = 0
+                for df2 in split_data_frame(df, partition_size):
+                    queue_in.put([sent_chunks, df2])
+                    sent_chunks += 1
+
+                batch_results = []
+                for i in range(sent_chunks):
+                    res = queue_out.get()
+                    if 'error' in res:
+                        if on_error == 'skip':
+                            logger.error(res['error'])
+                        else:
+                            raise RuntimeError(res['error'])
+
+                    if res['data']:
+                        batch_results.append(res)
+
+                # sort results
+                batch_results.sort(key=lambda x: x['num'])
+
+                results.append(self.concat_results(
+                    [item['data'] for item in batch_results]
+                ))
+
+                # TODO
+                # 1. get next batch without updating track_value:
+                #    it allows keeping queue_in filled with data between fetching batches
+                run_query.set_progress(df, max_track_value)
+        finally:
+            self.close_workers(workers)
+
+        return self.concat_results(results)
+
+    def close_workers(self, workers: List[threading.Thread]):
+        """
+        Send a signal to the workers to stop
+        """
+
+        self.stop_event.set()
+        for worker in workers:
+            if worker.is_alive():
+                worker.join()
+
+    def _worker(self, context: Context, queue_in: queue.Queue, queue_out: queue.Queue, stop_event: threading.Event):
+        """
+        Worker function. Execute incoming tasks unless stop_event is set
+        """
+        ctx.load(context)
+        while True:
+            if stop_event.is_set():
+                break
+
+            try:
+                chunk_num, df = queue_in.get(timeout=1)
+                if df is None:
+                    continue
+
+                sub_data = self.exec_sub_steps(df)
+
+                queue_out.put({'data': sub_data, 'num': chunk_num})
+            except queue.Empty:
+                continue
+
+            except Exception as e:
+                queue_out.put({'error': str(e)})
+                stop_event.set()
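The new step implements keyset pagination over the data source: each iteration wraps the user's query so that only rows past the last seen value of track_column are fetched, ordered by that column and capped at batch_size. The following is a minimal, self-contained sketch of that wrapping idea; the function name and the exact SQL produced by RunningQuery.get_partition_query are illustrative assumptions, not the MindsDB implementation.

# Illustrative sketch of keyset-pagination wrapping (not the MindsDB implementation).
def wrap_partition_query(user_query: str, track_column: str, last_value, batch_size: int = 1000) -> str:
    """Wrap a user query so it returns the next batch ordered by track_column."""
    where = f"WHERE {track_column} > {last_value!r} " if last_value is not None else ""
    return (
        f"SELECT * FROM ({user_query}) AS t "
        f"{where}ORDER BY {track_column} LIMIT {batch_size}"
    )

# Fetch batches until the source returns no rows (run() stands in for the data node's query call):
# last = None
# while True:
#     df = run(wrap_partition_query("SELECT * FROM sales", "id", last, 1000))
#     if df is None or len(df) == 0:
#         break
#     last = df["id"].max()   # remember the max tracking value before sub-steps mutate df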
mindsdb/integrations/handlers/langchain_embedding_handler/langchain_embedding_handler.py

@@ -104,6 +104,22 @@ def construct_model_from_args(args: Dict) -> Embeddings:
     return model
 
 
+def row_to_document(row: pd.Series) -> str:
+    """
+    Convert a row in the input dataframe into a document
+
+    Default implementation is to concatenate all the columns
+    in the form of
+    field1: value1\nfield2: value2\n...
+    """
+    fields = row.index.tolist()
+    values = row.values.tolist()
+    document = "\n".join(
+        [f"{field}: {value}" for field, value in zip(fields, values)]
+    )
+    return document
+
+
 class LangchainEmbeddingHandler(BaseMLEngine):
     """
     Bridge class to connect langchain.embeddings module to mindsDB

@@ -180,7 +196,7 @@ class LangchainEmbeddingHandler(BaseMLEngine):
         )
 
         # convert each row into a document
-        df_texts = df[input_columns].apply(
+        df_texts = df[input_columns].apply(row_to_document, axis=1)
         embeddings = model.embed_documents(df_texts.tolist())
 
         # create a new dataframe with the embeddings

@@ -188,21 +204,6 @@
 
         return df_embeddings
 
-    def row_to_document(self, row: pd.Series) -> str:
-        """
-        Convert a row in the input dataframe into a document
-
-        Default implementation is to concatenate all the columns
-        in the form of
-        field1: value1\nfield2: value2\n...
-        """
-        fields = row.index.tolist()
-        values = row.values.tolist()
-        document = "\n".join(
-            [f"{field}: {value}" for field, value in zip(fields, values)]
-        )
-        return document
-
     def finetune(
         self, df: Union[DataFrame, None] = None, args: Union[Dict, None] = None
     ) -> None:
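This release promotes row_to_document from an instance method to a module-level function so it can be passed directly to DataFrame.apply, as the second hunk above shows. A small self-contained sketch of the resulting behaviour, using made-up column data:

import pandas as pd

def row_to_document(row: pd.Series) -> str:
    # Concatenate all columns of a row as "field: value" lines, one per column.
    return "\n".join(f"{field}: {value}" for field, value in zip(row.index.tolist(), row.values.tolist()))

df = pd.DataFrame({"title": ["a", "b"], "body": ["x", "y"]})
df_texts = df[["title", "body"]].apply(row_to_document, axis=1)
# df_texts[0] == "title: a\nbody: x"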
mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py

@@ -46,7 +46,8 @@ class PgVectorHandler(PostgresHandler, VectorStoreHandler):
 
     def _make_connection_args(self):
         cloud_pgvector_url = os.environ.get('KB_PGVECTOR_URL')
-        if
+        # if no connection args and shared pg vector defined - use it
+        if len(self.connection_args) == 0 and cloud_pgvector_url is not None:
             result = urlparse(cloud_pgvector_url)
             self.connection_args = {
                 'host': result.hostname,

@@ -157,7 +158,7 @@ class PgVectorHandler(PostgresHandler, VectorStoreHandler):
             where_clauses.append(f'{key} {value["op"]} {value["value"]}')
 
         if len(where_clauses) > 1:
-            return f"WHERE{' AND '.join(where_clauses)}"
+            return f"WHERE {' AND '.join(where_clauses)}"
         elif len(where_clauses) == 1:
             return f"WHERE {where_clauses[0]}"
         else:

@@ -195,11 +196,6 @@ class PgVectorHandler(PostgresHandler, VectorStoreHandler):
         # given filter conditions, construct where clause
         where_clause = self._construct_where_clause(filter_conditions)
 
-        # construct full after from clause, where clause + offset clause + limit clause
-        after_from_clause = self._construct_full_after_from_clause(
-            where_clause, offset_clause, limit_clause
-        )
-
         # Handle distance column specially since it's calculated, not stored
         modified_columns = []
         has_distance = False

@@ -219,7 +215,7 @@ class PgVectorHandler(PostgresHandler, VectorStoreHandler):
         if filter_conditions:
 
             if embedding_search:
-                search_vector = filter_conditions["embeddings"]["value"]
+                search_vector = filter_conditions["embeddings"]["value"]
                 filter_conditions.pop("embeddings")
 
                 if self._is_sparse:

@@ -241,15 +237,15 @@ class PgVectorHandler(PostgresHandler, VectorStoreHandler):
                 if has_distance:
                     targets = f"{targets}, (embeddings {distance_op} '{search_vector}') as distance"
 
-                return f"SELECT {targets} FROM {table_name} ORDER BY embeddings {distance_op} '{search_vector}' ASC {
+                return f"SELECT {targets} FROM {table_name} {where_clause} ORDER BY embeddings {distance_op} '{search_vector}' ASC {limit_clause} {offset_clause} "
 
             else:
                 # if filter conditions, return rows that satisfy the conditions
-                return f"SELECT {targets} FROM {table_name} {
+                return f"SELECT {targets} FROM {table_name} {where_clause} {limit_clause} {offset_clause}"
 
         else:
             # if no filter conditions, return all rows
-            return f"SELECT {targets} FROM {table_name} {
+            return f"SELECT {targets} FROM {table_name} {limit_clause} {offset_clause}"
 
     def _check_table(self, table_name: str):
         # Apply namespace for a user
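The pgvector changes fix the generated SQL: a missing space after WHERE, and where/limit/offset clauses that were previously dropped from the similarity-search and filtered SELECT strings. A rough sketch of the corrected clause assembly follows; it is an illustration only, not the handler's actual method.

def build_select(targets, table_name, where_clause="", limit_clause="", offset_clause="",
                 distance_op=None, search_vector=None):
    # Assemble the SELECT the way the fixed handler does: where, then order-by-distance, then limit/offset.
    if distance_op and search_vector:
        return (f"SELECT {targets} FROM {table_name} {where_clause} "
                f"ORDER BY embeddings {distance_op} '{search_vector}' ASC {limit_clause} {offset_clause}").strip()
    return f"SELECT {targets} FROM {table_name} {where_clause} {limit_clause} {offset_clause}".strip()

# build_select("id, content", "items", "WHERE kind = 'doc'", "LIMIT 5", "", "<->", "[0.1, 0.2]")
# -> "SELECT id, content FROM items WHERE kind = 'doc' ORDER BY embeddings <-> '[0.1, 0.2]' ASC LIMIT 5"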
mindsdb/integrations/handlers/postgres_handler/postgres_handler.py

@@ -1,6 +1,7 @@
 import time
 import json
 from typing import Optional
+import threading
 
 import pandas as pd
 import psycopg

@@ -77,6 +78,8 @@ class PostgresHandler(DatabaseHandler):
         self.is_connected = False
         self.thread_safe = True
 
+        self._insert_lock = threading.Lock()
+
     def __del__(self):
         if self.is_connected:
             self.disconnect()

@@ -261,14 +264,35 @@
 
         connection = self.connect()
 
-        columns =
+        columns = df.columns
+
+        # postgres 'copy' is not thread safe. use lock to prevent concurrent execution
+        with self._insert_lock:
+            resp = self.get_columns(table_name)
+
+        # copy requires precise cases of names: get current column names from table and adapt input dataframe columns
+        if resp.data_frame is not None and not resp.data_frame.empty:
+            db_columns = {
+                c.lower(): c
+                for c in resp.data_frame['Field']
+            }
+
+            # try to get case of existing column
+            columns = [
+                db_columns.get(c.lower(), c)
+                for c in columns
+            ]
+
+        columns = [f'"{c}"' for c in columns]
         rowcount = None
+
        with connection.cursor() as cur:
             try:
-                with
-
+                with self._insert_lock:
+                    with cur.copy(f'copy "{table_name}" ({",".join(columns)}) from STDIN WITH CSV') as copy:
+                        df.to_csv(copy, index=False, header=False)
 
-
+                connection.commit()
             except Exception as e:
                 logger.error(f'Error running insert to {table_name} on {self.database}, {e}!')
                 connection.rollback()
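On the Postgres side, bulk insert now normalises incoming dataframe column names to the exact case used by the target table and serialises COPY behind a lock (the handler's own comment notes that COPY on this path is not thread safe). A small sketch of the case-normalisation step, with illustrative inputs:

def normalize_columns(df_columns, table_columns):
    # Map incoming column names onto the exact case the target table uses; keep unknown names as-is.
    db_columns = {c.lower(): c for c in table_columns}
    return [db_columns.get(c.lower(), c) for c in df_columns]

# normalize_columns(["ID", "createdAt"], ["id", "CreatedAt"]) -> ["id", "CreatedAt"]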
mindsdb/integrations/libs/llm/config.py

@@ -1,6 +1,6 @@
 from typing import Any, Dict, List, Optional
 
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, Field
 
 
 class BaseLLMConfig(BaseModel):

@@ -104,3 +104,13 @@ class NvidiaNIMConfig(BaseLLMConfig):
 class MindsdbConfig(BaseLLMConfig):
     model_name: str
     project_name: str
+
+
+# See https://python.langchain.com/api_reference/google_genai/chat_models/langchain_google_genai.chat_models.ChatGoogleGenerativeAI.html
+class GoogleConfig(BaseLLMConfig):
+    model: str = Field(description="Gemini model name to use (e.g., 'gemini-1.5-pro')")
+    temperature: Optional[float] = Field(default=None, description="Controls randomness in responses")
+    top_p: Optional[float] = Field(default=None, description="Nucleus sampling parameter")
+    top_k: Optional[int] = Field(default=None, description="Number of highest probability tokens to consider")
+    max_output_tokens: Optional[int] = Field(default=None, description="Maximum number of tokens to generate")
+    google_api_key: Optional[str] = Field(default=None, description="API key for Google Generative AI")
mindsdb/integrations/libs/llm/utils.py

@@ -10,6 +10,7 @@ from mindsdb.integrations.libs.llm.config import (
     AnthropicConfig,
     AnyscaleConfig,
     BaseLLMConfig,
+    GoogleConfig,
     LiteLLMConfig,
     OllamaConfig,
     OpenAIConfig,

@@ -31,6 +32,8 @@ DEFAULT_ANTHROPIC_MODEL = "claude-3-haiku-20240307"
 DEFAULT_ANYSCALE_MODEL = "meta-llama/Llama-2-7b-chat-hf"
 DEFAULT_ANYSCALE_BASE_URL = "https://api.endpoints.anyscale.com/v1"
 
+DEFAULT_GOOGLE_MODEL = "gemini-2.5-pro-preview-03-25"
+
 DEFAULT_LITELLM_MODEL = "gpt-3.5-turbo"
 DEFAULT_LITELLM_PROVIDER = "openai"
 DEFAULT_LITELLM_BASE_URL = "https://ai.dev.mindsdb.com"

@@ -225,6 +228,15 @@ def get_llm_config(provider: str, args: Dict) -> BaseLLMConfig:
             openai_organization=args.get("api_organization", None),
             request_timeout=args.get("request_timeout", None),
         )
+    if provider == "google":
+        return GoogleConfig(
+            model=args.get("model_name", DEFAULT_GOOGLE_MODEL),
+            temperature=temperature,
+            top_p=args.get("top_p", None),
+            top_k=args.get("top_k", None),
+            max_output_tokens=args.get("max_tokens", None),
+            google_api_key=args["api_keys"].get("google", None),
+        )
 
     raise ValueError(f"Provider {provider} is not supported.")
 
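For reference, here is a hedged example of arguments that would now route through the new "google" branch of get_llm_config; the key names follow the mapping in the hunk above, and the api_keys structure is assumed to match what the other providers already use.

# Illustrative arguments for get_llm_config("google", args); values are made up.
args = {
    "model_name": "gemini-1.5-pro",
    "top_p": 0.9,
    "max_tokens": 1024,
    "api_keys": {"google": "AIza..."},
}
# get_llm_config("google", args) then builds, roughly:
# GoogleConfig(model="gemini-1.5-pro", temperature=..., top_p=0.9, top_k=None,
#              max_output_tokens=1024, google_api_key="AIza...")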
mindsdb/interfaces/agents/constants.py

@@ -15,7 +15,8 @@ SUPPORTED_PROVIDERS = {
     "litellm",
     "ollama",
     "nvidia_nim",
-    "vllm"
+    "vllm",
+    "google"
 }
 # Chat models
 ANTHROPIC_CHAT_MODELS = (

@@ -153,6 +154,15 @@ NVIDIA_NIM_CHAT_MODELS = (
     "ibm/granite-34b-code-instruct",
 )
 
+GOOGLE_GEMINI_CHAT_MODELS = (
+    "gemini-2.5-pro-preview-03-25",
+    "gemini-2.0-flash",
+    "gemini-2.0-flash-lite",
+    "gemini-1.5-flash",
+    "gemini-1.5-flash-8b",
+    "gemini-1.5-pro",
+)
+
 # Define a read-only dictionary mapping providers to their models
 PROVIDER_TO_MODELS = MappingProxyType(
     {

@@ -160,6 +170,7 @@ PROVIDER_TO_MODELS = MappingProxyType(
         "ollama": OLLAMA_CHAT_MODELS,
         "openai": OPEN_AI_CHAT_MODELS,
         "nvidia_nim": NVIDIA_NIM_CHAT_MODELS,
+        "google": GOOGLE_GEMINI_CHAT_MODELS,
     }
 )
 
mindsdb/interfaces/agents/langchain_agent.py

@@ -15,6 +15,7 @@ from langchain_community.chat_models import (
     ChatAnyscale,
     ChatLiteLLM,
     ChatOllama)
+from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_core.agents import AgentAction, AgentStep
 from langchain_core.callbacks.base import BaseCallbackHandler
 

@@ -50,6 +51,7 @@ from .constants import (
     DEFAULT_TIKTOKEN_MODEL_NAME,
     SUPPORTED_PROVIDERS,
     ANTHROPIC_CHAT_MODELS,
+    GOOGLE_GEMINI_CHAT_MODELS,
     OLLAMA_CHAT_MODELS,
     NVIDIA_NIM_CHAT_MODELS,
     USER_COLUMN,

@@ -85,6 +87,8 @@ def get_llm_provider(args: Dict) -> str:
         return "ollama"
     if args["model_name"] in NVIDIA_NIM_CHAT_MODELS:
         return "nvidia_nim"
+    if args["model_name"] in GOOGLE_GEMINI_CHAT_MODELS:
+        return "google"
 
     # For vLLM, require explicit provider specification
     raise ValueError("Invalid model name. Please define a supported llm provider")

@@ -162,6 +166,8 @@ def create_chat_model(args: Dict):
         return ChatOllama(**model_kwargs)
     if args["provider"] == "nvidia_nim":
         return ChatNVIDIA(**model_kwargs)
+    if args["provider"] == "google":
+        return ChatGoogleGenerativeAI(**model_kwargs)
     if args["provider"] == "mindsdb":
         return ChatMindsdb(**model_kwargs)
     raise ValueError(f'Unknown provider: {args["provider"]}')
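Taken together, the agent path can now resolve a Gemini model name to the "google" provider and instantiate ChatGoogleGenerativeAI. A minimal sketch of that flow, assuming model_kwargs carries the GoogleConfig fields and that the langchain-google-genai package is installed:

# Illustrative only: mirrors the new branches in get_llm_provider / create_chat_model.
from langchain_google_genai import ChatGoogleGenerativeAI

GOOGLE_GEMINI_CHAT_MODELS = ("gemini-1.5-pro", "gemini-2.0-flash")  # subset, for the example

def pick_provider(model_name: str) -> str:
    if model_name in GOOGLE_GEMINI_CHAT_MODELS:
        return "google"
    raise ValueError("Invalid model name. Please define a supported llm provider")

model_kwargs = {"model": "gemini-1.5-pro", "temperature": 0.2, "google_api_key": "AIza..."}
if pick_provider("gemini-1.5-pro") == "google":
    chat_model = ChatGoogleGenerativeAI(**model_kwargs)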