MindsDB 25.4.1.0__py3-none-any.whl → 25.4.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of MindsDB might be problematic.

Files changed (63)
  1. mindsdb/__about__.py +1 -1
  2. mindsdb/api/executor/command_executor.py +91 -61
  3. mindsdb/api/executor/data_types/answer.py +9 -12
  4. mindsdb/api/executor/datahub/classes/response.py +11 -0
  5. mindsdb/api/executor/datahub/datanodes/datanode.py +4 -4
  6. mindsdb/api/executor/datahub/datanodes/information_schema_datanode.py +10 -11
  7. mindsdb/api/executor/datahub/datanodes/integration_datanode.py +22 -16
  8. mindsdb/api/executor/datahub/datanodes/mindsdb_tables.py +43 -1
  9. mindsdb/api/executor/datahub/datanodes/project_datanode.py +20 -20
  10. mindsdb/api/executor/planner/plan_join.py +2 -2
  11. mindsdb/api/executor/planner/query_plan.py +1 -0
  12. mindsdb/api/executor/planner/query_planner.py +86 -14
  13. mindsdb/api/executor/planner/steps.py +11 -2
  14. mindsdb/api/executor/sql_query/result_set.py +10 -7
  15. mindsdb/api/executor/sql_query/sql_query.py +69 -84
  16. mindsdb/api/executor/sql_query/steps/__init__.py +1 -0
  17. mindsdb/api/executor/sql_query/steps/delete_step.py +2 -3
  18. mindsdb/api/executor/sql_query/steps/fetch_dataframe.py +5 -3
  19. mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py +288 -0
  20. mindsdb/api/executor/sql_query/steps/insert_step.py +2 -2
  21. mindsdb/api/executor/sql_query/steps/prepare_steps.py +2 -2
  22. mindsdb/api/executor/sql_query/steps/subselect_step.py +20 -8
  23. mindsdb/api/executor/sql_query/steps/update_step.py +4 -6
  24. mindsdb/api/http/namespaces/sql.py +4 -1
  25. mindsdb/api/mysql/mysql_proxy/data_types/mysql_packets/ok_packet.py +1 -1
  26. mindsdb/api/mysql/mysql_proxy/executor/mysql_executor.py +4 -27
  27. mindsdb/api/mysql/mysql_proxy/libs/constants/mysql.py +1 -0
  28. mindsdb/api/mysql/mysql_proxy/mysql_proxy.py +38 -37
  29. mindsdb/integrations/handlers/chromadb_handler/chromadb_handler.py +23 -13
  30. mindsdb/integrations/handlers/langchain_embedding_handler/langchain_embedding_handler.py +17 -16
  31. mindsdb/integrations/handlers/langchain_handler/langchain_handler.py +1 -0
  32. mindsdb/integrations/handlers/mssql_handler/mssql_handler.py +1 -1
  33. mindsdb/integrations/handlers/mysql_handler/mysql_handler.py +3 -2
  34. mindsdb/integrations/handlers/oracle_handler/oracle_handler.py +4 -4
  35. mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py +26 -16
  36. mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +36 -7
  37. mindsdb/integrations/handlers/redshift_handler/redshift_handler.py +1 -1
  38. mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py +18 -11
  39. mindsdb/integrations/libs/llm/config.py +11 -1
  40. mindsdb/integrations/libs/llm/utils.py +12 -0
  41. mindsdb/integrations/libs/ml_handler_process/learn_process.py +1 -2
  42. mindsdb/integrations/libs/response.py +9 -4
  43. mindsdb/integrations/libs/vectordatabase_handler.py +17 -5
  44. mindsdb/integrations/utilities/rag/rerankers/reranker_compressor.py +8 -98
  45. mindsdb/interfaces/agents/constants.py +12 -1
  46. mindsdb/interfaces/agents/langchain_agent.py +6 -0
  47. mindsdb/interfaces/database/log.py +8 -9
  48. mindsdb/interfaces/database/projects.py +1 -5
  49. mindsdb/interfaces/functions/controller.py +59 -17
  50. mindsdb/interfaces/functions/to_markdown.py +194 -0
  51. mindsdb/interfaces/jobs/jobs_controller.py +3 -3
  52. mindsdb/interfaces/knowledge_base/controller.py +223 -97
  53. mindsdb/interfaces/knowledge_base/preprocessing/document_preprocessor.py +3 -14
  54. mindsdb/interfaces/query_context/context_controller.py +224 -1
  55. mindsdb/interfaces/storage/db.py +23 -0
  56. mindsdb/migrations/versions/2025-03-21_fda503400e43_queries.py +45 -0
  57. mindsdb/utilities/context_executor.py +1 -1
  58. mindsdb/utilities/partitioning.py +35 -20
  59. {mindsdb-25.4.1.0.dist-info → mindsdb-25.4.2.1.dist-info}/METADATA +227 -224
  60. {mindsdb-25.4.1.0.dist-info → mindsdb-25.4.2.1.dist-info}/RECORD +63 -59
  61. {mindsdb-25.4.1.0.dist-info → mindsdb-25.4.2.1.dist-info}/WHEEL +0 -0
  62. {mindsdb-25.4.1.0.dist-info → mindsdb-25.4.2.1.dist-info}/licenses/LICENSE +0 -0
  63. {mindsdb-25.4.1.0.dist-info → mindsdb-25.4.2.1.dist-info}/top_level.txt +0 -0
mindsdb/integrations/handlers/chromadb_handler/chromadb_handler.py

@@ -1,5 +1,6 @@
  import ast
  import sys
+ import os
  from typing import Dict, List, Optional, Union
  import hashlib

@@ -67,6 +68,8 @@ class ChromaDBHandler(VectorStoreHandler):
  "persist_directory": self.persist_directory,
  }

+ self._use_handler_storage = False
+
  self.connect()

  def validate_connection_parameters(self, name, **kwargs):
@@ -79,11 +82,15 @@ class ChromaDBHandler(VectorStoreHandler):

  config = ChromaHandlerConfig(**_config)

- if config.persist_directory and not self.handler_storage.is_temporal:
- # get full persistence directory from handler storage
- self.persist_directory = self.handler_storage.folder_get(
- config.persist_directory
- )
+ if config.persist_directory:
+ if os.path.isabs(config.persist_directory):
+ self.persist_directory = config.persist_directory
+ elif not self.handler_storage.is_temporal:
+ # get full persistence directory from handler storage
+ self.persist_directory = self.handler_storage.folder_get(
+ config.persist_directory
+ )
+ self._use_handler_storage = True

  return config

@@ -105,7 +112,7 @@ class ChromaDBHandler(VectorStoreHandler):

  def _sync(self):
  """Sync the database to disk if using persistent storage"""
- if self.persist_directory:
+ if self.persist_directory and self._use_handler_storage:
  self.handler_storage.folder_sync(self.persist_directory)

  def __del__(self):
@@ -162,6 +169,8 @@ class ChromaDBHandler(VectorStoreHandler):
  FilterOperator.LESS_THAN_OR_EQUAL: "$lte",
  FilterOperator.GREATER_THAN: "$gt",
  FilterOperator.GREATER_THAN_OR_EQUAL: "$gte",
+ FilterOperator.IN: "$in",
+ FilterOperator.NOT_IN: "$nin",
  }

  if operator not in mapping:
@@ -308,7 +317,7 @@
  }

  if columns is not None:
- payload = {column: payload[column] for column in columns}
+ payload = {column: payload[column] for column in columns if column != TableField.DISTANCE.value}

  # always include distance
  distance_filter = None
@@ -316,10 +325,11 @@
  if distances is not None:
  payload[distance_col] = distances

- for cond in conditions:
- if cond.column == distance_col:
- distance_filter = cond
- break
+ if conditions is not None:
+ for cond in conditions:
+ if cond.column == distance_col:
+ distance_filter = cond
+ break

  df = pd.DataFrame(payload)
  if distance_filter is not None:
@@ -413,8 +423,8 @@
  collection.upsert(
  ids=data_dict[TableField.ID.value],
  documents=data_dict[TableField.CONTENT.value],
- embeddings=data_dict.get(TableField.EMBEDDINGS.value),
- metadatas=data_dict.get(TableField.METADATA.value)
+ embeddings=data_dict.get(TableField.EMBEDDINGS.value, None),
+ metadatas=data_dict.get(TableField.METADATA.value, None)
  )
  self._sync()
  except Exception as e:
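Note: the new IN / NOT IN entries in the operator mapping let list-valued conditions be pushed down as ChromaDB `$in` / `$nin` metadata filters. A minimal, hypothetical sketch of that translation (the helper and the (column, operator, value) condition shape below are illustrative, not MindsDB's API):

    # Sketch: translate (column, operator, value) conditions into a ChromaDB-style
    # "where" filter, using the same operator mapping the handler now carries.
    OPERATOR_MAP = {
        "=": "$eq",
        "<": "$lt",
        "<=": "$lte",
        ">": "$gt",
        ">=": "$gte",
        "in": "$in",        # added in this release
        "not in": "$nin",   # added in this release
    }

    def to_chroma_where(conditions):
        clauses = [{column: {OPERATOR_MAP[op.lower()]: value}} for column, op, value in conditions]
        # Chroma expects a single clause, or several clauses joined with $and.
        return clauses[0] if len(clauses) == 1 else {"$and": clauses}

    print(to_chroma_where([("category", "in", ["news", "blog"]), ("year", ">=", 2024)]))
    # {'$and': [{'category': {'$in': ['news', 'blog']}}, {'year': {'$gte': 2024}}]}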
mindsdb/integrations/handlers/langchain_embedding_handler/langchain_embedding_handler.py

@@ -104,6 +104,22 @@ def construct_model_from_args(args: Dict) -> Embeddings:
  return model


+ def row_to_document(row: pd.Series) -> str:
+ """
+ Convert a row in the input dataframe into a document
+
+ Default implementation is to concatenate all the columns
+ in the form of
+ field1: value1\nfield2: value2\n...
+ """
+ fields = row.index.tolist()
+ values = row.values.tolist()
+ document = "\n".join(
+ [f"{field}: {value}" for field, value in zip(fields, values)]
+ )
+ return document
+
+
  class LangchainEmbeddingHandler(BaseMLEngine):
  """
  Bridge class to connect langchain.embeddings module to mindsDB
@@ -180,7 +196,7 @@ class LangchainEmbeddingHandler(BaseMLEngine):
  )

  # convert each row into a document
- df_texts = df[input_columns].apply(self.row_to_document, axis=1)
+ df_texts = df[input_columns].apply(row_to_document, axis=1)
  embeddings = model.embed_documents(df_texts.tolist())

  # create a new dataframe with the embeddings
@@ -188,21 +204,6 @@

  return df_embeddings

- def row_to_document(self, row: pd.Series) -> str:
- """
- Convert a row in the input dataframe into a document
-
- Default implementation is to concatenate all the columns
- in the form of
- field1: value1\nfield2: value2\n...
- """
- fields = row.index.tolist()
- values = row.values.tolist()
- document = "\n".join(
- [f"{field}: {value}" for field, value in zip(fields, values)]
- )
- return document
-
  def finetune(
  self, df: Union[DataFrame, None] = None, args: Union[Dict, None] = None
  ) -> None:
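Note: `row_to_document` is now a module-level helper rather than an instance method, so it can be passed straight to `DataFrame.apply`. A small, self-contained usage sketch (the helper is reproduced so the snippet runs on its own):

    import pandas as pd

    def row_to_document(row: pd.Series) -> str:
        # Concatenate every column as "field: value" lines, one per field.
        return "\n".join(f"{field}: {value}" for field, value in zip(row.index, row.values))

    df = pd.DataFrame({"title": ["a", "b"], "body": ["x", "y"]})
    print(df.apply(row_to_document, axis=1).tolist())
    # ['title: a\nbody: x', 'title: b\nbody: y']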
mindsdb/integrations/handlers/langchain_handler/langchain_handler.py

@@ -50,6 +50,7 @@ class LangChainHandler(BaseMLEngine):
  - OpenAI
  - Anthropic
  - Anyscale
+ - Google
  - LiteLLM
  - Ollama

mindsdb/integrations/handlers/mssql_handler/mssql_handler.py

@@ -177,7 +177,7 @@ class SqlServerHandler(DatabaseHandler):
  )
  )
  else:
- response = Response(RESPONSE_TYPE.OK)
+ response = Response(RESPONSE_TYPE.OK, affected_rows=cur.rowcount)
  connection.commit()
  except Exception as e:
  logger.error(f'Error running query: {query} on {self.database}, {e}!')
mindsdb/integrations/handlers/mysql_handler/mysql_handler.py

@@ -178,10 +178,11 @@ class MySQLHandler(DatabaseHandler):
  pd.DataFrame(
  result,
  columns=[x[0] for x in cur.description]
- )
+ ),
+ affected_rows=cur.rowcount
  )
  else:
- response = Response(RESPONSE_TYPE.OK)
+ response = Response(RESPONSE_TYPE.OK, affected_rows=cur.rowcount)
  except mysql.connector.Error as e:
  logger.error(f'Error running query: {query} on {self.connection_data["database"]}!')
  response = Response(
mindsdb/integrations/handlers/oracle_handler/oracle_handler.py

@@ -205,8 +205,10 @@ class OracleHandler(DatabaseHandler):
  with connection.cursor() as cur:
  try:
  cur.execute(query)
- result = cur.fetchall()
- if result:
+ if cur.description is None:
+ response = Response(RESPONSE_TYPE.OK, affected_rows=cur.rowcount)
+ else:
+ result = cur.fetchall()
  response = Response(
  RESPONSE_TYPE.TABLE,
  data_frame=pd.DataFrame(
@@ -214,8 +216,6 @@
  columns=[row[0] for row in cur.description],
  ),
  )
- else:
- response = Response(RESPONSE_TYPE.OK)

  connection.commit()
  except DatabaseError as database_error:
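Note: the Oracle handler now branches on `cursor.description` instead of the truthiness of `fetchall()`. Per DB-API 2.0, `description` is None for statements that produce no result set (INSERT/UPDATE/DELETE/DDL), so an empty SELECT is no longer misreported as a plain OK. A generic, illustrative sketch of the pattern (not the handler's exact code):

    def run_statement(cursor, query):
        cursor.execute(query)
        if cursor.description is None:
            # No result set: a DML/DDL statement; report rows touched.
            return {"type": "ok", "affected_rows": cursor.rowcount}
        # Result set (possibly empty): fetch rows and column names.
        columns = [col[0] for col in cursor.description]
        return {"type": "table", "columns": columns, "rows": cursor.fetchall()}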
mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py

@@ -46,7 +46,8 @@ class PgVectorHandler(PostgresHandler, VectorStoreHandler):

  def _make_connection_args(self):
  cloud_pgvector_url = os.environ.get('KB_PGVECTOR_URL')
- if cloud_pgvector_url is not None:
+ # if no connection args and shared pg vector defined - use it
+ if len(self.connection_args) == 0 and cloud_pgvector_url is not None:
  result = urlparse(cloud_pgvector_url)
  self.connection_args = {
  'host': result.hostname,
@@ -149,7 +150,7 @@ class PgVectorHandler(PostgresHandler, VectorStoreHandler):
  for key, value in filter_conditions.items():
  if key == "embeddings":
  continue
- if value['op'].lower() == 'in':
+ if value['op'].lower() in ('in', 'not in'):
  values = list(repr(i) for i in value['value'])
  value['value'] = '({})'.format(', '.join(values))
  else:
@@ -157,7 +158,7 @@ class PgVectorHandler(PostgresHandler, VectorStoreHandler):
  where_clauses.append(f'{key} {value["op"]} {value["value"]}')

  if len(where_clauses) > 1:
- return f"WHERE{' AND '.join(where_clauses)}"
+ return f"WHERE {' AND '.join(where_clauses)}"
  elif len(where_clauses) == 1:
  return f"WHERE {where_clauses[0]}"
  else:
@@ -165,9 +166,9 @@ class PgVectorHandler(PostgresHandler, VectorStoreHandler):

  @staticmethod
  def _construct_full_after_from_clause(
+ where_clause: str,
  offset_clause: str,
  limit_clause: str,
- where_clause: str,
  ) -> str:

  return f"{where_clause} {offset_clause} {limit_clause}"
@@ -195,21 +196,26 @@
  # given filter conditions, construct where clause
  where_clause = self._construct_where_clause(filter_conditions)

- # construct full after from clause, where clause + offset clause + limit clause
- after_from_clause = self._construct_full_after_from_clause(
- where_clause, offset_clause, limit_clause
- )
-
- if columns is None:
- targets = '*'
+ # Handle distance column specially since it's calculated, not stored
+ modified_columns = []
+ has_distance = False
+ if columns is not None:
+ for col in columns:
+ if col == TableField.DISTANCE.value:
+ has_distance = True
+ else:
+ modified_columns.append(col)
  else:
- targets = ', '.join(columns)
+ modified_columns = ['id', 'content', 'embeddings', 'metadata']
+ has_distance = True
+
+ targets = ', '.join(modified_columns)


  if filter_conditions:

  if embedding_search:
- search_vector = filter_conditions["embeddings"]["value"][0]
+ search_vector = filter_conditions["embeddings"]["value"]
  filter_conditions.pop("embeddings")

  if self._is_sparse:
@@ -227,15 +233,19 @@
  # Use cosine similarity for dense vectors
  distance_op = "<=>"

- return f"SELECT {targets} FROM {table_name} ORDER BY embeddings {distance_op} '{search_vector}' ASC {after_from_clause}"
+ # Calculate distance as part of the query if needed
+ if has_distance:
+ targets = f"{targets}, (embeddings {distance_op} '{search_vector}') as distance"
+
+ return f"SELECT {targets} FROM {table_name} {where_clause} ORDER BY embeddings {distance_op} '{search_vector}' ASC {limit_clause} {offset_clause} "

  else:
  # if filter conditions, return rows that satisfy the conditions
- return f"SELECT {targets} FROM {table_name} {after_from_clause}"
+ return f"SELECT {targets} FROM {table_name} {where_clause} {limit_clause} {offset_clause}"

  else:
  # if no filter conditions, return all rows
- return f"SELECT {targets} FROM {table_name} {after_from_clause}"
+ return f"SELECT {targets} FROM {table_name} {limit_clause} {offset_clause}"

  def _check_table(self, table_name: str):
  # Apply namespace for a user
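Note: the rewritten builder keeps the WHERE clause ahead of ORDER BY and, when the caller asks for the calculated `distance` column, adds it to the select list. A rough sketch of the query shape for a dense search (table, filter, and vector values are made up; `<=>` is pgvector's cosine-distance operator):

    targets = "id, content, metadata"
    table_name = "my_kb_table"
    search_vector = "[0.1, 0.2, 0.3]"
    where_clause = "WHERE category = 'news'"
    limit_clause, offset_clause = "LIMIT 5", ""
    distance_op = "<=>"

    # distance is computed in the select list only when it was requested
    targets = f"{targets}, (embeddings {distance_op} '{search_vector}') as distance"
    print(
        f"SELECT {targets} FROM {table_name} {where_clause} "
        f"ORDER BY embeddings {distance_op} '{search_vector}' ASC {limit_clause} {offset_clause}"
    )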
mindsdb/integrations/handlers/postgres_handler/postgres_handler.py

@@ -1,6 +1,7 @@
  import time
  import json
  from typing import Optional
+ import threading

  import pandas as pd
  import psycopg
@@ -77,6 +78,8 @@ class PostgresHandler(DatabaseHandler):
  self.is_connected = False
  self.thread_safe = True

+ self._insert_lock = threading.Lock()
+
  def __del__(self):
  if self.is_connected:
  self.disconnect()
@@ -228,7 +231,7 @@
  else:
  cur.execute(query)
  if cur.pgresult is None or ExecStatus(cur.pgresult.status) == ExecStatus.COMMAND_OK:
- response = Response(RESPONSE_TYPE.OK)
+ response = Response(RESPONSE_TYPE.OK, affected_rows=cur.rowcount)
  else:
  result = cur.fetchall()
  df = DataFrame(
@@ -238,7 +241,8 @@
  self._cast_dtypes(df, cur.description)
  response = Response(
  RESPONSE_TYPE.TABLE,
- df
+ data_frame=df,
+ affected_rows=cur.rowcount
  )
  connection.commit()
  except Exception as e:
@@ -255,26 +259,51 @@

  return response

- def insert(self, table_name: str, df: pd.DataFrame):
+ def insert(self, table_name: str, df: pd.DataFrame) -> Response:
  need_to_close = not self.is_connected

  connection = self.connect()

- columns = [f'"{c}"' for c in df.columns]
+ columns = df.columns
+
+ # postgres 'copy' is not thread safe. use lock to prevent concurrent execution
+ with self._insert_lock:
+ resp = self.get_columns(table_name)
+
+ # copy requires precise cases of names: get current column names from table and adapt input dataframe columns
+ if resp.data_frame is not None and not resp.data_frame.empty:
+ db_columns = {
+ c.lower(): c
+ for c in resp.data_frame['Field']
+ }
+
+ # try to get case of existing column
+ columns = [
+ db_columns.get(c.lower(), c)
+ for c in columns
+ ]
+
+ columns = [f'"{c}"' for c in columns]
+ rowcount = None
+
  with connection.cursor() as cur:
  try:
- with cur.copy(f'copy "{table_name}" ({",".join(columns)}) from STDIN WITH CSV') as copy:
- df.to_csv(copy, index=False, header=False)
+ with self._insert_lock:
+ with cur.copy(f'copy "{table_name}" ({",".join(columns)}) from STDIN WITH CSV') as copy:
+ df.to_csv(copy, index=False, header=False)

- connection.commit()
+ connection.commit()
  except Exception as e:
  logger.error(f'Error running insert to {table_name} on {self.database}, {e}!')
  connection.rollback()
  raise e
+ rowcount = cur.rowcount

  if need_to_close:
  self.disconnect()

+ return Response(RESPONSE_TYPE.OK, affected_rows=rowcount)
+
  @profiler.profile()
  def query(self, query: ASTNode) -> Response:
  """
mindsdb/integrations/handlers/redshift_handler/redshift_handler.py

@@ -52,7 +52,7 @@ class RedshiftHandler(PostgresHandler):
  with connection.cursor() as cur:
  try:
  cur.executemany(query, df.values.tolist())
- response = Response(RESPONSE_TYPE.OK)
+ response = Response(RESPONSE_TYPE.OK, affected_rows=cur.rowcount)

  connection.commit()
  except Exception as e:
mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py

@@ -230,18 +230,25 @@
  # Fallback for CREATE/DELETE/UPDATE. These commands returns table with single column,
  # but it cannot be retrieved as pandas DataFrame.
  result = cur.fetchall()
- if result:
- response = Response(
- RESPONSE_TYPE.TABLE,
- DataFrame(
- result,
- columns=[x[0] for x in cur.description]
+ match result:
+ case (
+ [{'number of rows inserted': affected_rows}]
+ | [{'number of rows deleted': affected_rows}]
+ | [{'number of rows updated': affected_rows, 'number of multi-joined rows updated': _}]
+ ):
+ response = Response(RESPONSE_TYPE.OK, affected_rows=affected_rows)
+ case list():
+ response = Response(
+ RESPONSE_TYPE.TABLE,
+ DataFrame(
+ result,
+ columns=[x[0] for x in cur.description]
+ )
  )
- )
- else:
- # Looks like SnowFlake always returns something in response, so this is suspicious
- logger.warning('Snowflake did not return any data in response.')
- response = Response(RESPONSE_TYPE.OK)
+ case _:
+ # Looks like SnowFlake always returns something in response, so this is suspicious
+ logger.warning('Snowflake did not return any data in response.')
+ response = Response(RESPONSE_TYPE.OK)
  except Exception as e:
  logger.error(f"Error running query: {query} on {self.connection_data.get('database')}, {e}!")
  response = Response(
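Note: the Snowflake fallback now uses structural pattern matching (Python 3.10+) to recognize the one-row status dict Snowflake returns for DML and lift the count into `affected_rows`; any other list of rows is still treated as a result table. A standalone sketch of the same match shape:

    def interpret_result(result):
        match result:
            case (
                [{'number of rows inserted': affected_rows}]
                | [{'number of rows deleted': affected_rows}]
                | [{'number of rows updated': affected_rows, 'number of multi-joined rows updated': _}]
            ):
                return ("ok", affected_rows)
            case list():
                return ("table", len(result))
            case _:
                return ("ok", None)

    print(interpret_result([{'number of rows inserted': 3}]))  # ('ok', 3)
    print(interpret_result([{'ID': 1}, {'ID': 2}]))            # ('table', 2)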
mindsdb/integrations/libs/llm/config.py

@@ -1,6 +1,6 @@
  from typing import Any, Dict, List, Optional

- from pydantic import BaseModel, ConfigDict
+ from pydantic import BaseModel, ConfigDict, Field


  class BaseLLMConfig(BaseModel):
@@ -104,3 +104,13 @@ class NvidiaNIMConfig(BaseLLMConfig):
  class MindsdbConfig(BaseLLMConfig):
  model_name: str
  project_name: str
+
+
+ # See https://python.langchain.com/api_reference/google_genai/chat_models/langchain_google_genai.chat_models.ChatGoogleGenerativeAI.html
+ class GoogleConfig(BaseLLMConfig):
+ model: str = Field(description="Gemini model name to use (e.g., 'gemini-1.5-pro')")
+ temperature: Optional[float] = Field(default=None, description="Controls randomness in responses")
+ top_p: Optional[float] = Field(default=None, description="Nucleus sampling parameter")
+ top_k: Optional[int] = Field(default=None, description="Number of highest probability tokens to consider")
+ max_output_tokens: Optional[int] = Field(default=None, description="Maximum number of tokens to generate")
+ google_api_key: Optional[str] = Field(default=None, description="API key for Google Generative AI")
mindsdb/integrations/libs/llm/utils.py

@@ -10,6 +10,7 @@ from mindsdb.integrations.libs.llm.config import (
  AnthropicConfig,
  AnyscaleConfig,
  BaseLLMConfig,
+ GoogleConfig,
  LiteLLMConfig,
  OllamaConfig,
  OpenAIConfig,
@@ -31,6 +32,8 @@ DEFAULT_ANTHROPIC_MODEL = "claude-3-haiku-20240307"
  DEFAULT_ANYSCALE_MODEL = "meta-llama/Llama-2-7b-chat-hf"
  DEFAULT_ANYSCALE_BASE_URL = "https://api.endpoints.anyscale.com/v1"

+ DEFAULT_GOOGLE_MODEL = "gemini-2.5-pro-preview-03-25"
+
  DEFAULT_LITELLM_MODEL = "gpt-3.5-turbo"
  DEFAULT_LITELLM_PROVIDER = "openai"
  DEFAULT_LITELLM_BASE_URL = "https://ai.dev.mindsdb.com"
@@ -225,6 +228,15 @@ def get_llm_config(provider: str, args: Dict) -> BaseLLMConfig:
  openai_organization=args.get("api_organization", None),
  request_timeout=args.get("request_timeout", None),
  )
+ if provider == "google":
+ return GoogleConfig(
+ model=args.get("model_name", DEFAULT_GOOGLE_MODEL),
+ temperature=temperature,
+ top_p=args.get("top_p", None),
+ top_k=args.get("top_k", None),
+ max_output_tokens=args.get("max_tokens", None),
+ google_api_key=args["api_keys"].get("google", None),
+ )

  raise ValueError(f"Provider {provider} is not supported.")

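Note: with the new branch, `get_llm_config("google", args)` returns a `GoogleConfig` built from the generic model args. A hedged sketch of what such a config carries, using a pydantic stand-in that mirrors the fields added in config.py so the snippet runs without MindsDB installed (values are illustrative):

    from typing import Optional
    from pydantic import BaseModel, Field

    class GoogleConfig(BaseModel):  # stand-in mirroring the fields added in config.py
        model: str = Field(description="Gemini model name to use")
        temperature: Optional[float] = None
        top_p: Optional[float] = None
        top_k: Optional[int] = None
        max_output_tokens: Optional[int] = None
        google_api_key: Optional[str] = None

    config = GoogleConfig(model="gemini-1.5-pro", temperature=0.2, max_output_tokens=1024)
    print(config.model_dump(exclude_none=True))
    # {'model': 'gemini-1.5-pro', 'temperature': 0.2, 'max_output_tokens': 1024}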
mindsdb/integrations/libs/ml_handler_process/learn_process.py

@@ -78,8 +78,7 @@ def learn_process(data_integration_ref: dict, problem_definition: dict, fetch_da
  query_ast = parse_sql(fetch_data_query)
  sqlquery = SQLQuery(query_ast, session=sql_session)

- result = sqlquery.fetch(view='dataframe')
- training_data_df = result['result']
+ training_data_df = sqlquery.fetched_data.to_df()

  training_data_columns_count, training_data_rows_count = 0, 0
  if training_data_df is not None:
mindsdb/integrations/libs/response.py

@@ -1,3 +1,4 @@
+ from typing import Optional
  from pandas import DataFrame

  from mindsdb.utilities import log
@@ -8,13 +9,16 @@ from mindsdb_sql_parser.ast import ASTNode
  logger = log.getLogger(__name__)

  class HandlerResponse:
- def __init__(self, resp_type: RESPONSE_TYPE, data_frame: DataFrame = None,
- query: ASTNode = 0, error_code: int = 0, error_message: str = None) -> None:
+ def __init__(self, resp_type: RESPONSE_TYPE, data_frame: DataFrame = None, query: ASTNode = 0, error_code: int = 0,
+ error_message: Optional[str] = None, affected_rows: Optional[int] = None) -> None:
  self.resp_type = resp_type
  self.query = query
  self.data_frame = data_frame
  self.error_code = error_code
  self.error_message = error_message
+ self.affected_rows = affected_rows
+ if isinstance(self.affected_rows, int) is False or self.affected_rows < 0:
+ self.affected_rows = 0

  @property
  def type(self):
@@ -35,13 +39,14 @@
  "error": self.error_message}

  def __repr__(self):
- return "%s: resp_type=%s, query=%s, data_frame=%s, err_code=%s, error=%s" % (
+ return "%s: resp_type=%s, query=%s, data_frame=%s, err_code=%s, error=%s, affected_rows=%s" % (
  self.__class__.__name__,
  self.resp_type,
  self.query,
  self.data_frame,
  self.error_code,
- self.error_message
+ self.error_message,
+ self.affected_rows
  )

  class HandlerStatusResponse:
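Note: `affected_rows` is normalized in the constructor: anything that is not a non-negative int (including the -1 that DB-API cursors report for "unknown") becomes 0, so downstream code can rely on the field. The rule in isolation:

    from typing import Optional

    def normalize_affected_rows(affected_rows: Optional[int]) -> int:
        # Same check as HandlerResponse.__init__: non-int or negative collapses to 0.
        if isinstance(affected_rows, int) is False or affected_rows < 0:
            return 0
        return affected_rows

    print(normalize_affected_rows(None))  # 0
    print(normalize_affected_rows(-1))    # 0 (DB-API "unknown")
    print(normalize_affected_rows(7))     # 7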
mindsdb/integrations/libs/vectordatabase_handler.py

@@ -20,7 +20,7 @@ from mindsdb_sql_parser.ast.base import ASTNode

  from mindsdb.integrations.libs.response import RESPONSE_TYPE, HandlerResponse
  from mindsdb.utilities import log
- from mindsdb.integrations.utilities.sql_utils import conditions_to_filter, FilterCondition, FilterOperator
+ from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator

  from mindsdb.integrations.utilities.query_traversal import query_traversal
  from .base import BaseHandler
@@ -39,6 +39,7 @@ class TableField(Enum):
  METADATA = "metadata"
  SEARCH_VECTOR = "search_vector"
  DISTANCE = "distance"
+ RELEVANCE = "relevance"


  class DistanceFunction(Enum):
@@ -69,6 +70,10 @@ class VectorStoreHandler(BaseHandler):
  "name": TableField.METADATA.value,
  "data_type": "json",
  },
+ {
+ "name": TableField.DISTANCE.value,
+ "data_type": "float",
+ },
  ]

  def validate_connection_parameters(self, name, **kwargs):
@@ -231,7 +236,7 @@

  return self.do_upsert(table_name, pd.DataFrame(data))

- def _dispatch_update(self, query: Update):
+ def dispatch_update(self, query: Update, conditions: List[FilterCondition] = None):
  """
  Dispatch update query to the appropriate method.
  """
@@ -250,8 +255,15 @@
  pass
  row[k] = v

- filters = conditions_to_filter(query.where)
- row.update(filters)
+ if conditions is None:
+ where_statement = query.where
+ conditions = self.extract_conditions(where_statement)
+
+ for condition in conditions:
+ if condition.op != FilterOperator.EQUAL:
+ raise NotImplementedError
+
+ row[condition.column] = condition.value

  # checks
  if TableField.EMBEDDINGS.value not in row:
@@ -381,7 +393,7 @@
  CreateTable: self._dispatch_create_table,
  DropTables: self._dispatch_drop_table,
  Insert: self._dispatch_insert,
- Update: self._dispatch_update,
+ Update: self.dispatch_update,
  Delete: self.dispatch_delete,
  Select: self.dispatch_select,
  }
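Note: `dispatch_update` (now public and able to take pre-extracted conditions) only supports equality filters: each EQUAL condition's column/value pair is folded into the row being upserted, and any other operator raises NotImplementedError. A compact sketch of that folding, with hypothetical stand-ins for FilterCondition/FilterOperator:

    from dataclasses import dataclass
    from enum import Enum

    class FilterOperator(Enum):
        EQUAL = "="
        GREATER_THAN = ">"

    @dataclass
    class FilterCondition:
        column: str
        op: FilterOperator
        value: object

    def fold_conditions_into_row(row, conditions):
        # Mirrors dispatch_update: only '=' filters are supported for UPDATE,
        # and each one becomes a column value on the row to upsert.
        for condition in conditions:
            if condition.op != FilterOperator.EQUAL:
                raise NotImplementedError("only '=' filters are supported for UPDATE")
            row[condition.column] = condition.value
        return row

    print(fold_conditions_into_row({"content": "new text"},
                                   [FilterCondition("id", FilterOperator.EQUAL, "doc-1")]))
    # {'content': 'new text', 'id': 'doc-1'}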