MindsDB 25.4.2.0__py3-none-any.whl → 25.4.3.0__py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release.

Files changed (39)
  1. mindsdb/__about__.py +1 -1
  2. mindsdb/__main__.py +30 -7
  3. mindsdb/api/executor/command_executor.py +29 -0
  4. mindsdb/api/executor/datahub/datanodes/information_schema_datanode.py +3 -2
  5. mindsdb/api/executor/datahub/datanodes/mindsdb_tables.py +43 -1
  6. mindsdb/api/executor/planner/plan_join.py +1 -1
  7. mindsdb/api/executor/planner/query_plan.py +1 -0
  8. mindsdb/api/executor/planner/query_planner.py +86 -14
  9. mindsdb/api/executor/planner/steps.py +9 -1
  10. mindsdb/api/executor/sql_query/sql_query.py +37 -6
  11. mindsdb/api/executor/sql_query/steps/__init__.py +1 -0
  12. mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py +231 -0
  13. mindsdb/integrations/handlers/chromadb_handler/chromadb_handler.py +2 -1
  14. mindsdb/integrations/handlers/langchain_embedding_handler/langchain_embedding_handler.py +17 -16
  15. mindsdb/integrations/handlers/langchain_handler/langchain_handler.py +1 -0
  16. mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py +7 -11
  17. mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +28 -4
  18. mindsdb/integrations/libs/llm/config.py +11 -1
  19. mindsdb/integrations/libs/llm/utils.py +12 -0
  20. mindsdb/integrations/libs/vectordatabase_handler.py +9 -1
  21. mindsdb/integrations/utilities/rag/rerankers/reranker_compressor.py +1 -1
  22. mindsdb/interfaces/agents/constants.py +12 -1
  23. mindsdb/interfaces/agents/langchain_agent.py +6 -0
  24. mindsdb/interfaces/database/projects.py +7 -1
  25. mindsdb/interfaces/knowledge_base/controller.py +166 -74
  26. mindsdb/interfaces/knowledge_base/preprocessing/document_preprocessor.py +43 -62
  27. mindsdb/interfaces/knowledge_base/utils.py +28 -0
  28. mindsdb/interfaces/query_context/context_controller.py +221 -0
  29. mindsdb/interfaces/storage/db.py +23 -0
  30. mindsdb/migrations/versions/2025-03-21_fda503400e43_queries.py +45 -0
  31. mindsdb/utilities/auth.py +5 -1
  32. mindsdb/utilities/cache.py +4 -1
  33. mindsdb/utilities/context_executor.py +1 -1
  34. mindsdb/utilities/partitioning.py +35 -20
  35. {mindsdb-25.4.2.0.dist-info → mindsdb-25.4.3.0.dist-info}/METADATA +221 -219
  36. {mindsdb-25.4.2.0.dist-info → mindsdb-25.4.3.0.dist-info}/RECORD +39 -36
  37. {mindsdb-25.4.2.0.dist-info → mindsdb-25.4.3.0.dist-info}/WHEEL +0 -0
  38. {mindsdb-25.4.2.0.dist-info → mindsdb-25.4.3.0.dist-info}/licenses/LICENSE +0 -0
  39. {mindsdb-25.4.2.0.dist-info → mindsdb-25.4.3.0.dist-info}/top_level.txt +0 -0
mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py (new file)
@@ -0,0 +1,231 @@
+ import pandas as pd
+ from typing import List
+
+ from mindsdb_sql_parser import ASTNode
+ from mindsdb.api.executor.planner.steps import FetchDataframeStepPartition
+ from mindsdb.integrations.utilities.query_traversal import query_traversal
+
+ from mindsdb.interfaces.query_context.context_controller import RunningQuery
+ from mindsdb.api.executor.sql_query.result_set import ResultSet
+ from mindsdb.utilities import log
+ from mindsdb.utilities.config import Config
+ from mindsdb.utilities.partitioning import get_max_thread_count, split_data_frame
+ from mindsdb.api.executor.sql_query.steps.fetch_dataframe import get_table_alias, get_fill_param_fnc
+ from mindsdb.utilities.context_executor import ContextThreadPoolExecutor
+
+
+ from .base import BaseStepCall
+
+
+ logger = log.getLogger(__name__)
+
+
+ class FetchDataframePartitionCall(BaseStepCall):
+     """
+     Alternative to FetchDataframeCall but fetches data in batches, wrapping the user's query as:
+
+         select * from ({user query})
+         where {track_column} > {previous value}
+         order by track_column
+         limit size {batch_size}
+
+     """
+
+     bind = FetchDataframeStepPartition
+
+     def call(self, step: FetchDataframeStepPartition) -> ResultSet:
+         """
+         Parameters:
+         - batch_size - count of rows to fetch from database per iteration, optional, default 1000
+         - threads - run partitioning in threads, bool or int, optional, if set:
+             - int value: use this as count of threads
+             - true: enable threads, autodetect count of threads
+             - false: disable threads even if ml task queue is enabled
+         - track_column - column used for creating partitions
+             - query will be sorted by this column and select will be limited by batch_size
+         - error (default raise)
+             - when `error='skip'`, errors in a partition will be skipped and execution will continue
+         """
+
+         self.dn = self.session.datahub.get(step.integration)
+         query = step.query
+
+         # fill params
+         fill_params = get_fill_param_fnc(self.steps_data)
+         query_traversal(query, fill_params)
+
+         # get query record
+         run_query = self.sql_query.run_query
+         if run_query is None:
+             raise RuntimeError('Error with partitioning of the query')
+         run_query.set_params(step.params)
+
+         self.table_alias = get_table_alias(step.query.from_table, self.context.get('database'))
+         self.current_step_num = step.step_num
+         self.substeps = step.steps
+
+         config = Config()
+
+         # ml task queue enabled?
+         use_threads, thread_count = False, None
+         if config['ml_task_queue']['type'] == 'redis':
+             use_threads = True
+
+         # use threads?
+         if 'threads' in step.params:
+             threads = step.params['threads']
+             if isinstance(threads, int):
+                 thread_count = threads
+                 use_threads = True
+             if threads is True:
+                 use_threads = True
+             if threads is False:
+                 # disable even with ml task queue
+                 use_threads = False
+
+         on_error = step.params.get('error', 'raise')
+         if use_threads:
+             return self.fetch_threads(run_query, query, thread_count=thread_count, on_error=on_error)
+         else:
+             return self.fetch_iterate(run_query, query, on_error=on_error)
+
+     def fetch_iterate(self, run_query: RunningQuery, query: ASTNode, on_error: str = None) -> ResultSet:
+         """
+         Process batches one by one in a loop
+         """
+
+         results = []
+         while True:
+
+             # fetch batch
+             query2 = run_query.get_partition_query(self.current_step_num, query)
+             response = self.dn.query(
+                 query=query2,
+                 session=self.session
+             )
+             df = response.data_frame
+
+             if df is None or len(df) == 0:
+                 break
+
+             # executing sub steps can modify dataframe columns, so memorise the max tracking value first
+             max_track_value = run_query.get_max_track_value(df)
+             try:
+                 sub_data = self.exec_sub_steps(df)
+                 results.append(sub_data)
+             except Exception as e:
+                 if on_error == 'skip':
+                     logger.error(e)
+                 else:
+                     raise e
+
+             run_query.set_progress(df, max_track_value)
+
+         return self.concat_results(results)
+
+     def concat_results(self, results: List[ResultSet]) -> ResultSet:
+         """
+         Concatenate a list of result sets into a single result set
+         """
+         df_list = []
+         for res in results:
+             df, col_names = res.to_df_cols()
+             if len(df) > 0:
+                 df_list.append(df)
+
+         data = ResultSet()
+         if len(df_list) > 0:
+             data.from_df_cols(pd.concat(df_list), col_names)
+
+         return data
+
+     def exec_sub_steps(self, df: pd.DataFrame) -> ResultSet:
+         """
+         FetchDataframeStepPartition has substeps defined.
+         Every batch of data has to be used to execute these substeps:
+         - the batch of data is put in place of the result of FetchDataframeStepPartition
+         - substeps are executed using the result of the previous step (as if all fetched data were available)
+         - the final result is returned and used outside to concatenate with the results of other batches
+         """
+
+         input_data = ResultSet()
+
+         input_data.from_df(
+             df,
+             table_name=self.table_alias[1],
+             table_alias=self.table_alias[2],
+             database=self.table_alias[0]
+         )
+
+         # execute with modified previous results
+         steps_data2 = self.steps_data.copy()
+         steps_data2[self.current_step_num] = input_data
+
+         sub_data = None
+         for substep in self.substeps:
+             sub_data = self.sql_query.execute_step(substep, steps_data=steps_data2)
+             steps_data2[substep.step_num] = sub_data
+         return sub_data
+
+     def fetch_threads(self, run_query: RunningQuery, query: ASTNode,
+                       thread_count: int = None, on_error: str = None) -> ResultSet:
+         """
+         Process batches in threads
+         - spawn required count of threads
+         - create in/out queue to communicate with threads
+         - send task to threads and receive results
+         """
+
+         # create communication queues
+
+         if thread_count is None:
+             thread_count = get_max_thread_count()
+
+         # 3 tasks per worker during 1 batch
+         partition_size = int(run_query.batch_size / thread_count / 3)
+         # min partition size
+         if partition_size < 10:
+             partition_size = 10
+
+         results = []
+
+         with ContextThreadPoolExecutor(max_workers=thread_count) as executor:
+
+             while True:
+                 # fetch batch
+                 query2 = run_query.get_partition_query(self.current_step_num, query)
+                 response = self.dn.query(
+                     query=query2,
+                     session=self.session
+                 )
+                 df = response.data_frame
+
+                 if df is None or len(df) == 0:
+                     # TODO detect cycles: data handler ignores condition and output is repeated
+
+                     # exit & stop workers
+                     break
+
+                 max_track_value = run_query.get_max_track_value(df)
+
+                 # split into chunks and send to workers
+                 futures = []
+                 for df2 in split_data_frame(df, partition_size):
+                     futures.append(executor.submit(self.exec_sub_steps, df2))
+
+                 for future in futures:
+                     try:
+                         results.append(future.result())
+                     except Exception as e:
+                         if on_error == 'skip':
+                             logger.error(e)
+                         else:
+                             executor.shutdown()
+                             raise e
+
+                 # TODO
+                 # 1. get next batch without updating track_value:
+                 #    it allows keeping queue_in filled with data between fetching batches
+                 run_query.set_progress(df, max_track_value)
+
+         return self.concat_results(results)
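The class docstring above states the batching contract without showing it end to end, so here is a minimal, self-contained sketch of the same idea: filter on the track column, order by it, take at most batch_size rows, and repeat until a fetch comes back empty. The fetch_in_partitions helper and the sample dataframe below are illustrative only and are not part of the package.

import pandas as pd

def fetch_in_partitions(source: pd.DataFrame, track_column: str, batch_size: int = 1000):
    # Emulates: select * from ({user query}) where {track_column} > {last}
    #           order by {track_column} limit {batch_size}
    last_value = None
    while True:
        batch = source if last_value is None else source[source[track_column] > last_value]
        batch = batch.sort_values(track_column).head(batch_size)
        if batch.empty:
            break  # nothing left to fetch: partitioning is finished
        last_value = batch[track_column].max()  # remember progress before substeps run
        yield batch

# Usage: walk a 10,000-row frame in 1,000-row partitions.
df = pd.DataFrame({'id': range(10_000), 'value': range(10_000)})
assert sum(len(part) for part in fetch_in_partitions(df, track_column='id', batch_size=1000)) == 10_000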
mindsdb/integrations/handlers/chromadb_handler/chromadb_handler.py
@@ -244,6 +244,7 @@ class ChromaDBHandler(VectorStoreHandler):
          offset: int = None,
          limit: int = None,
      ) -> pd.DataFrame:
+
      collection = self._client.get_collection(table_name)
      filters = self._translate_metadata_condition(conditions)

@@ -313,7 +314,7 @@ class ChromaDBHandler(VectorStoreHandler):
              TableField.ID.value: ids,
              TableField.CONTENT.value: documents,
              TableField.METADATA.value: metadatas,
-             TableField.EMBEDDINGS.value: embeddings,
+             TableField.EMBEDDINGS.value: list(embeddings),
          }

          if columns is not None:
mindsdb/integrations/handlers/langchain_embedding_handler/langchain_embedding_handler.py
@@ -104,6 +104,22 @@ def construct_model_from_args(args: Dict) -> Embeddings:
      return model


+ def row_to_document(row: pd.Series) -> str:
+     """
+     Convert a row in the input dataframe into a document
+
+     Default implementation is to concatenate all the columns
+     in the form of
+     field1: value1\nfield2: value2\n...
+     """
+     fields = row.index.tolist()
+     values = row.values.tolist()
+     document = "\n".join(
+         [f"{field}: {value}" for field, value in zip(fields, values)]
+     )
+     return document
+
+
  class LangchainEmbeddingHandler(BaseMLEngine):
      """
      Bridge class to connect langchain.embeddings module to mindsDB
@@ -180,7 +196,7 @@ class LangchainEmbeddingHandler(BaseMLEngine):
          )

          # convert each row into a document
-         df_texts = df[input_columns].apply(self.row_to_document, axis=1)
+         df_texts = df[input_columns].apply(row_to_document, axis=1)
          embeddings = model.embed_documents(df_texts.tolist())

          # create a new dataframe with the embeddings
@@ -188,21 +204,6 @@ class LangchainEmbeddingHandler(BaseMLEngine):

          return df_embeddings

-     def row_to_document(self, row: pd.Series) -> str:
-         """
-         Convert a row in the input dataframe into a document
-
-         Default implementation is to concatenate all the columns
-         in the form of
-         field1: value1\nfield2: value2\n...
-         """
-         fields = row.index.tolist()
-         values = row.values.tolist()
-         document = "\n".join(
-             [f"{field}: {value}" for field, value in zip(fields, values)]
-         )
-         return document
-
      def finetune(
          self, df: Union[DataFrame, None] = None, args: Union[Dict, None] = None
      ) -> None:
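For reference, the relocated row_to_document helper (now module-level, as shown above) simply renders a row as "field: value" pairs. The tiny example below reproduces that behaviour on a made-up dataframe and is not part of the handler.

import pandas as pd

def row_to_document(row: pd.Series) -> str:
    # Same "field1: value1\nfield2: value2" convention as the handler's helper.
    return "\n".join(f"{field}: {value}" for field, value in zip(row.index.tolist(), row.values.tolist()))

df = pd.DataFrame({'title': ['MindsDB'], 'body': ['partitioned fetch and Google LLM support']})
print(df.apply(row_to_document, axis=1).iloc[0])
# title: MindsDB
# body: partitioned fetch and Google LLM support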
mindsdb/integrations/handlers/langchain_handler/langchain_handler.py
@@ -50,6 +50,7 @@ class LangChainHandler(BaseMLEngine):
      - OpenAI
      - Anthropic
      - Anyscale
+     - Google
      - LiteLLM
      - Ollama

mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py
@@ -46,7 +46,8 @@ class PgVectorHandler(PostgresHandler, VectorStoreHandler):

      def _make_connection_args(self):
          cloud_pgvector_url = os.environ.get('KB_PGVECTOR_URL')
-         if cloud_pgvector_url is not None:
+         # if no connection args and shared pg vector defined - use it
+         if len(self.connection_args) == 0 and cloud_pgvector_url is not None:
              result = urlparse(cloud_pgvector_url)
              self.connection_args = {
                  'host': result.hostname,
@@ -157,7 +158,7 @@ class PgVectorHandler(PostgresHandler, VectorStoreHandler):
              where_clauses.append(f'{key} {value["op"]} {value["value"]}')

          if len(where_clauses) > 1:
-             return f"WHERE{' AND '.join(where_clauses)}"
+             return f"WHERE {' AND '.join(where_clauses)}"
          elif len(where_clauses) == 1:
              return f"WHERE {where_clauses[0]}"
          else:
@@ -195,11 +196,6 @@ class PgVectorHandler(PostgresHandler, VectorStoreHandler):
          # given filter conditions, construct where clause
          where_clause = self._construct_where_clause(filter_conditions)

-         # construct full after from clause, where clause + offset clause + limit clause
-         after_from_clause = self._construct_full_after_from_clause(
-             where_clause, offset_clause, limit_clause
-         )
-
          # Handle distance column specially since it's calculated, not stored
          modified_columns = []
          has_distance = False
@@ -219,7 +215,7 @@ class PgVectorHandler(PostgresHandler, VectorStoreHandler):
          if filter_conditions:

              if embedding_search:
-                 search_vector = filter_conditions["embeddings"]["value"][0]
+                 search_vector = filter_conditions["embeddings"]["value"]
                  filter_conditions.pop("embeddings")

                  if self._is_sparse:
@@ -241,15 +237,15 @@ class PgVectorHandler(PostgresHandler, VectorStoreHandler):
                  if has_distance:
                      targets = f"{targets}, (embeddings {distance_op} '{search_vector}') as distance"

-                 return f"SELECT {targets} FROM {table_name} ORDER BY embeddings {distance_op} '{search_vector}' ASC {after_from_clause}"
+                 return f"SELECT {targets} FROM {table_name} {where_clause} ORDER BY embeddings {distance_op} '{search_vector}' ASC {limit_clause} {offset_clause} "

              else:
                  # if filter conditions, return rows that satisfy the conditions
-                 return f"SELECT {targets} FROM {table_name} {after_from_clause}"
+                 return f"SELECT {targets} FROM {table_name} {where_clause} {limit_clause} {offset_clause}"

          else:
              # if no filter conditions, return all rows
-             return f"SELECT {targets} FROM {table_name} {after_from_clause}"
+             return f"SELECT {targets} FROM {table_name} {limit_clause} {offset_clause}"

      def _check_table(self, table_name: str):
          # Apply namespace for a user
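To make the query-building change above concrete: the rewritten branch places the WHERE clause before ORDER BY and appends LIMIT/OFFSET at the end, instead of tacking a combined after-FROM clause onto the ORDER BY. The snippet below just interpolates made-up values into the same f-string shape; the operator, table, and filter are illustrative and not taken from the handler.

targets = "id, content, metadata, (embeddings <=> '[0.1, 0.2]') as distance"
table_name = "items"
where_clause = "WHERE metadata->>'lang' = 'en'"   # e.g. output of the where-clause builder
distance_op = "<=>"
search_vector = "[0.1, 0.2]"
limit_clause, offset_clause = "LIMIT 10", ""

query = (
    f"SELECT {targets} FROM {table_name} {where_clause} "
    f"ORDER BY embeddings {distance_op} '{search_vector}' ASC {limit_clause} {offset_clause}"
)
print(query)
# SELECT ... FROM items WHERE metadata->>'lang' = 'en' ORDER BY embeddings <=> '[0.1, 0.2]' ASC LIMIT 10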
mindsdb/integrations/handlers/postgres_handler/postgres_handler.py
@@ -1,6 +1,7 @@
  import time
  import json
  from typing import Optional
+ import threading

  import pandas as pd
  import psycopg
@@ -77,6 +78,8 @@ class PostgresHandler(DatabaseHandler):
          self.is_connected = False
          self.thread_safe = True

+         self._insert_lock = threading.Lock()
+
      def __del__(self):
          if self.is_connected:
              self.disconnect()
@@ -261,14 +264,35 @@ class PostgresHandler(DatabaseHandler):

          connection = self.connect()

-         columns = [f'"{c}"' for c in df.columns]
+         columns = df.columns
+
+         # postgres 'copy' is not thread safe. use lock to prevent concurrent execution
+         with self._insert_lock:
+             resp = self.get_columns(table_name)
+
+         # copy requires precise cases of names: get current column names from table and adapt input dataframe columns
+         if resp.data_frame is not None and not resp.data_frame.empty:
+             db_columns = {
+                 c.lower(): c
+                 for c in resp.data_frame['Field']
+             }
+
+             # try to get case of existing column
+             columns = [
+                 db_columns.get(c.lower(), c)
+                 for c in columns
+             ]
+
+         columns = [f'"{c}"' for c in columns]
          rowcount = None
+
          with connection.cursor() as cur:
              try:
-                 with cur.copy(f'copy "{table_name}" ({",".join(columns)}) from STDIN WITH CSV') as copy:
-                     df.to_csv(copy, index=False, header=False)
+                 with self._insert_lock:
+                     with cur.copy(f'copy "{table_name}" ({",".join(columns)}) from STDIN WITH CSV') as copy:
+                         df.to_csv(copy, index=False, header=False)

-                 connection.commit()
+                     connection.commit()
              except Exception as e:
                  logger.error(f'Error running insert to {table_name} on {self.database}, {e}!')
                  connection.rollback()
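The column-case handling added above can be hard to read inside the diff, so here is the same mapping in isolation: incoming dataframe column names are matched case-insensitively against the names reported by the table, unmatched names fall through unchanged, and the result is quoted for the COPY column list. The sample names are invented.

db_fields = ['Id', 'CreatedAt', 'payload']      # e.g. values from resp.data_frame['Field']
incoming = ['id', 'createdat', 'extra_col']     # e.g. df.columns from the caller

db_columns = {c.lower(): c for c in db_fields}
adapted = [db_columns.get(c.lower(), c) for c in incoming]
quoted = [f'"{c}"' for c in adapted]

print(quoted)  # ['"Id"', '"CreatedAt"', '"extra_col"'] -> used in the COPY column list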
mindsdb/integrations/libs/llm/config.py
@@ -1,6 +1,6 @@
  from typing import Any, Dict, List, Optional

- from pydantic import BaseModel, ConfigDict
+ from pydantic import BaseModel, ConfigDict, Field


  class BaseLLMConfig(BaseModel):
@@ -104,3 +104,13 @@ class NvidiaNIMConfig(BaseLLMConfig):
  class MindsdbConfig(BaseLLMConfig):
      model_name: str
      project_name: str
+
+
+ # See https://python.langchain.com/api_reference/google_genai/chat_models/langchain_google_genai.chat_models.ChatGoogleGenerativeAI.html
+ class GoogleConfig(BaseLLMConfig):
+     model: str = Field(description="Gemini model name to use (e.g., 'gemini-1.5-pro')")
+     temperature: Optional[float] = Field(default=None, description="Controls randomness in responses")
+     top_p: Optional[float] = Field(default=None, description="Nucleus sampling parameter")
+     top_k: Optional[int] = Field(default=None, description="Number of highest probability tokens to consider")
+     max_output_tokens: Optional[int] = Field(default=None, description="Maximum number of tokens to generate")
+     google_api_key: Optional[str] = Field(default=None, description="API key for Google Generative AI")
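As a rough usage sketch (not MindsDB code), the shape of the new GoogleConfig can be mirrored with a plain pydantic model to show how handler args would map onto it, matching the "google" branch added to get_llm_config in the next file. The GoogleConfigSketch name and the argument values below are invented for illustration.

from typing import Optional
from pydantic import BaseModel

class GoogleConfigSketch(BaseModel):
    # Mirrors the fields GoogleConfig declares above.
    model: str
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    max_output_tokens: Optional[int] = None
    google_api_key: Optional[str] = None

args = {"model_name": "gemini-1.5-pro", "top_p": 0.9, "max_tokens": 1024,
        "api_keys": {"google": "example-key"}}

config = GoogleConfigSketch(
    model=args.get("model_name", "gemini-2.5-pro-preview-03-25"),
    temperature=args.get("temperature"),
    top_p=args.get("top_p"),
    top_k=args.get("top_k"),
    max_output_tokens=args.get("max_tokens"),
    google_api_key=args["api_keys"].get("google"),
)
print(config.model, config.max_output_tokens)  # gemini-1.5-pro 1024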
mindsdb/integrations/libs/llm/utils.py
@@ -10,6 +10,7 @@ from mindsdb.integrations.libs.llm.config import (
      AnthropicConfig,
      AnyscaleConfig,
      BaseLLMConfig,
+     GoogleConfig,
      LiteLLMConfig,
      OllamaConfig,
      OpenAIConfig,
@@ -31,6 +32,8 @@ DEFAULT_ANTHROPIC_MODEL = "claude-3-haiku-20240307"
  DEFAULT_ANYSCALE_MODEL = "meta-llama/Llama-2-7b-chat-hf"
  DEFAULT_ANYSCALE_BASE_URL = "https://api.endpoints.anyscale.com/v1"

+ DEFAULT_GOOGLE_MODEL = "gemini-2.5-pro-preview-03-25"
+
  DEFAULT_LITELLM_MODEL = "gpt-3.5-turbo"
  DEFAULT_LITELLM_PROVIDER = "openai"
  DEFAULT_LITELLM_BASE_URL = "https://ai.dev.mindsdb.com"
@@ -225,6 +228,15 @@ def get_llm_config(provider: str, args: Dict) -> BaseLLMConfig:
              openai_organization=args.get("api_organization", None),
              request_timeout=args.get("request_timeout", None),
          )
+     if provider == "google":
+         return GoogleConfig(
+             model=args.get("model_name", DEFAULT_GOOGLE_MODEL),
+             temperature=temperature,
+             top_p=args.get("top_p", None),
+             top_k=args.get("top_k", None),
+             max_output_tokens=args.get("max_tokens", None),
+             google_api_key=args["api_keys"].get("google", None),
+         )

      raise ValueError(f"Provider {provider} is not supported.")

mindsdb/integrations/libs/vectordatabase_handler.py
@@ -278,8 +278,16 @@ class VectorStoreHandler(BaseHandler):
          return self.do_upsert(table_name, df)

      def do_upsert(self, table_name, df):
-         # if handler supports it, call upsert method
+         """Upsert data into table, handling document updates and deletions.

+         Args:
+             table_name (str): Name of the table
+             df (pd.DataFrame): DataFrame containing the data to upsert
+
+         The function handles three cases:
+         1. New documents: Insert them
+         2. Updated documents: Delete old chunks and insert new ones
+         """
          id_col = TableField.ID.value
          content_col = TableField.CONTENT.value

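The new do_upsert docstring above summarises the behaviour without showing it; the fragment below is only an illustration (not the handler's implementation) of the split it describes, separating incoming rows into ids already present (whose old chunks would be deleted before re-insert) and genuinely new ids.

import pandas as pd

existing_ids = {'doc-1', 'doc-2'}                       # ids already stored in the table
df = pd.DataFrame({'id': ['doc-2', 'doc-3'],
                   'content': ['updated text', 'brand new text']})

is_update = df['id'].isin(existing_ids)
updated_docs, new_docs = df[is_update], df[~is_update]

print(list(updated_docs['id']))  # ['doc-2'] -> delete old chunks, then insert fresh ones
print(list(new_docs['id']))      # ['doc-3'] -> plain insert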
mindsdb/integrations/utilities/rag/rerankers/reranker_compressor.py
@@ -18,7 +18,7 @@ log = logging.getLogger(__name__)


  class LLMReranker(BaseDocumentCompressor):
-     filtering_threshold: float = 0.5  # Default threshold for filtering
+     filtering_threshold: float = 0.0  # Default threshold for filtering
      model: str = DEFAULT_RERANKING_MODEL  # Model to use for reranking
      temperature: float = 0.0  # Temperature for the model
      openai_api_key: Optional[str] = None
mindsdb/interfaces/agents/constants.py
@@ -15,7 +15,8 @@ SUPPORTED_PROVIDERS = {
      "litellm",
      "ollama",
      "nvidia_nim",
-     "vllm"
+     "vllm",
+     "google"
  }
  # Chat models
  ANTHROPIC_CHAT_MODELS = (
@@ -153,6 +154,15 @@ NVIDIA_NIM_CHAT_MODELS = (
      "ibm/granite-34b-code-instruct",
  )

+ GOOGLE_GEMINI_CHAT_MODELS = (
+     "gemini-2.5-pro-preview-03-25",
+     "gemini-2.0-flash",
+     "gemini-2.0-flash-lite",
+     "gemini-1.5-flash",
+     "gemini-1.5-flash-8b",
+     "gemini-1.5-pro",
+ )
+
  # Define a read-only dictionary mapping providers to their models
  PROVIDER_TO_MODELS = MappingProxyType(
      {
@@ -160,6 +170,7 @@ PROVIDER_TO_MODELS = MappingProxyType(
          "ollama": OLLAMA_CHAT_MODELS,
          "openai": OPEN_AI_CHAT_MODELS,
          "nvidia_nim": NVIDIA_NIM_CHAT_MODELS,
+         "google": GOOGLE_GEMINI_CHAT_MODELS,
      }
  )

mindsdb/interfaces/agents/langchain_agent.py
@@ -15,6 +15,7 @@ from langchain_community.chat_models import (
      ChatAnyscale,
      ChatLiteLLM,
      ChatOllama)
+ from langchain_google_genai import ChatGoogleGenerativeAI
  from langchain_core.agents import AgentAction, AgentStep
  from langchain_core.callbacks.base import BaseCallbackHandler

@@ -50,6 +51,7 @@ from .constants import (
      DEFAULT_TIKTOKEN_MODEL_NAME,
      SUPPORTED_PROVIDERS,
      ANTHROPIC_CHAT_MODELS,
+     GOOGLE_GEMINI_CHAT_MODELS,
      OLLAMA_CHAT_MODELS,
      NVIDIA_NIM_CHAT_MODELS,
      USER_COLUMN,
@@ -85,6 +87,8 @@ def get_llm_provider(args: Dict) -> str:
          return "ollama"
      if args["model_name"] in NVIDIA_NIM_CHAT_MODELS:
          return "nvidia_nim"
+     if args["model_name"] in GOOGLE_GEMINI_CHAT_MODELS:
+         return "google"

      # For vLLM, require explicit provider specification
      raise ValueError("Invalid model name. Please define a supported llm provider")
@@ -162,6 +166,8 @@ def create_chat_model(args: Dict):
          return ChatOllama(**model_kwargs)
      if args["provider"] == "nvidia_nim":
          return ChatNVIDIA(**model_kwargs)
+     if args["provider"] == "google":
+         return ChatGoogleGenerativeAI(**model_kwargs)
      if args["provider"] == "mindsdb":
          return ChatMindsdb(**model_kwargs)
      raise ValueError(f'Unknown provider: {args["provider"]}')
mindsdb/interfaces/database/projects.py
@@ -69,6 +69,12 @@ class Project:
          self.id = record.id

      def delete(self):
+         if self.record.metadata_ and self.record.metadata_.get('is_default', False):
+             raise Exception(
+                 f"Project '{self.name}' cannot be deleted because it is the default project. "
+                 "The default project can be changed in the config file or by setting the environment variable MINDSDB_DEFAULT_PROJECT."
+             )
+
          tables = self.get_tables()
          tables = [key for key, val in tables.items() if val['type'] != 'table']
          if len(tables) > 0:
@@ -466,7 +472,7 @@ class ProjectController:

          if new_metadata is not None:
              project.metadata = new_metadata
-             project.record.metadata = new_metadata
+             project.record.metadata_ = new_metadata
              flag_modified(project.record, 'metadata_')

          db.session.commit()