MindsDB 25.4.2.0__py3-none-any.whl → 25.4.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of MindsDB might be problematic.
Files changed (30)
  1. mindsdb/__about__.py +1 -1
  2. mindsdb/api/executor/command_executor.py +29 -0
  3. mindsdb/api/executor/datahub/datanodes/information_schema_datanode.py +3 -2
  4. mindsdb/api/executor/datahub/datanodes/mindsdb_tables.py +43 -1
  5. mindsdb/api/executor/planner/plan_join.py +1 -1
  6. mindsdb/api/executor/planner/query_plan.py +1 -0
  7. mindsdb/api/executor/planner/query_planner.py +86 -14
  8. mindsdb/api/executor/planner/steps.py +9 -1
  9. mindsdb/api/executor/sql_query/sql_query.py +37 -6
  10. mindsdb/api/executor/sql_query/steps/__init__.py +1 -0
  11. mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py +288 -0
  12. mindsdb/integrations/handlers/langchain_embedding_handler/langchain_embedding_handler.py +17 -16
  13. mindsdb/integrations/handlers/langchain_handler/langchain_handler.py +1 -0
  14. mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py +7 -11
  15. mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +28 -4
  16. mindsdb/integrations/libs/llm/config.py +11 -1
  17. mindsdb/integrations/libs/llm/utils.py +12 -0
  18. mindsdb/interfaces/agents/constants.py +12 -1
  19. mindsdb/interfaces/agents/langchain_agent.py +6 -0
  20. mindsdb/interfaces/knowledge_base/controller.py +128 -43
  21. mindsdb/interfaces/query_context/context_controller.py +221 -0
  22. mindsdb/interfaces/storage/db.py +23 -0
  23. mindsdb/migrations/versions/2025-03-21_fda503400e43_queries.py +45 -0
  24. mindsdb/utilities/context_executor.py +1 -1
  25. mindsdb/utilities/partitioning.py +35 -20
  26. {mindsdb-25.4.2.0.dist-info → mindsdb-25.4.2.1.dist-info}/METADATA +224 -222
  27. {mindsdb-25.4.2.0.dist-info → mindsdb-25.4.2.1.dist-info}/RECORD +30 -28
  28. {mindsdb-25.4.2.0.dist-info → mindsdb-25.4.2.1.dist-info}/WHEEL +0 -0
  29. {mindsdb-25.4.2.0.dist-info → mindsdb-25.4.2.1.dist-info}/licenses/LICENSE +0 -0
  30. {mindsdb-25.4.2.0.dist-info → mindsdb-25.4.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,288 @@
+import pandas as pd
+import threading
+import queue
+from typing import List
+
+from mindsdb_sql_parser import ASTNode
+from mindsdb.api.executor.planner.steps import FetchDataframeStepPartition
+from mindsdb.integrations.utilities.query_traversal import query_traversal
+
+from mindsdb.interfaces.query_context.context_controller import RunningQuery
+from mindsdb.api.executor.sql_query.result_set import ResultSet
+from mindsdb.utilities import log
+from mindsdb.utilities.config import Config
+from mindsdb.utilities.context import Context, context as ctx
+from mindsdb.utilities.partitioning import get_max_thread_count, split_data_frame
+from mindsdb.api.executor.sql_query.steps.fetch_dataframe import get_table_alias, get_fill_param_fnc
+
+from .base import BaseStepCall
+
+
+logger = log.getLogger(__name__)
+
+
+class FetchDataframePartitionCall(BaseStepCall):
+    """
+    Alternative to FetchDataframeCall that fetches data in batches by wrapping the user's query as:
+
+        select * from ({user query})
+        where {track_column} > {previous value}
+        order by {track_column}
+        limit {batch_size}
+
+    """
+
+    bind = FetchDataframeStepPartition
+
+    def call(self, step: FetchDataframeStepPartition) -> ResultSet:
+        """
+        Parameters:
+        - batch_size - count of rows to fetch from the database per iteration, optional, default 1000
+        - threads - run partitioning in threads, bool or int, optional; if set:
+          - int value: use it as the count of threads
+          - true: enable threads, autodetect the count of threads
+          - false: disable threads even if the ml task queue is enabled
+        - track_column - column used for creating partitions;
+          the query is sorted by this column and the select is limited by batch_size
+        - error - behavior on a partition error, default 'raise';
+          when error='skip', errors in a partition are skipped and execution continues
+        """
+
+        self.dn = self.session.datahub.get(step.integration)
+        query = step.query
+
+        # fill params
+        fill_params = get_fill_param_fnc(self.steps_data)
+        query_traversal(query, fill_params)
+
+        # get query record
+        run_query = self.sql_query.run_query
+        if run_query is None:
+            raise RuntimeError('Error with partitioning of the query')
+        run_query.set_params(step.params)
+
+        self.table_alias = get_table_alias(step.query.from_table, self.context.get('database'))
+        self.current_step_num = step.step_num
+        self.substeps = step.steps
+
+        config = Config()
+
+        # is the ml task queue enabled?
+        use_threads, thread_count = False, None
+        if config['ml_task_queue']['type'] == 'redis':
+            use_threads = True
+
+        # use threads?
+        if 'threads' in step.params:
+            threads = step.params['threads']
+            if isinstance(threads, int):
+                thread_count = threads
+                use_threads = True
+            if threads is True:
+                use_threads = True
+            if threads is False:
+                # disable even with ml task queue
+                use_threads = False
+
+        on_error = step.params.get('error', 'raise')
+        if use_threads:
+            return self.fetch_threads(run_query, query, thread_count=thread_count, on_error=on_error)
+        else:
+            return self.fetch_iterate(run_query, query, on_error=on_error)
+
+    def fetch_iterate(self, run_query: RunningQuery, query: ASTNode, on_error: str = None) -> ResultSet:
+        """
+        Process batches one by one in a loop
+        """
+
+        results = []
+        while True:
+
+            # fetch batch
+            query2 = run_query.get_partition_query(self.current_step_num, query)
+            response = self.dn.query(
+                query=query2,
+                session=self.session
+            )
+            df = response.data_frame
+
+            if df is None or len(df) == 0:
+                break
+
+            # executing sub-steps can modify dataframe columns; memorize the max tracking value first
+            max_track_value = run_query.get_max_track_value(df)
+            try:
+                sub_data = self.exec_sub_steps(df)
+                results.append(sub_data)
+            except Exception as e:
+                if on_error == 'skip':
+                    logger.error(e)
+                else:
+                    raise e
+
+            run_query.set_progress(df, max_track_value)
+
+        return self.concat_results(results)
+
+    def concat_results(self, results: List[ResultSet]) -> ResultSet:
+        """
+        Concatenate a list of result sets into a single result set
+        """
+        df_list = []
+        for res in results:
+            df, col_names = res.to_df_cols()
+            if len(df) > 0:
+                df_list.append(df)
+
+        data = ResultSet()
+        if len(df_list) > 0:
+            data.from_df_cols(pd.concat(df_list), col_names)
+
+        return data
+
+    def exec_sub_steps(self, df: pd.DataFrame) -> ResultSet:
+        """
+        FetchDataframeStepPartition has substeps defined.
+        Every batch of data has to be used to execute these substeps:
+        - the batch of data is set as the result of FetchDataframeStepPartition
+        - substeps are executed using the result of the previous step (as if all fetched data were available)
+        - the final result is returned and used outside to concatenate with the results of the other batches
+        """
+
+        input_data = ResultSet()
+
+        input_data.from_df(
+            df,
+            table_name=self.table_alias[1],
+            table_alias=self.table_alias[2],
+            database=self.table_alias[0]
+        )
+
+        # execute with modified previous results
+        steps_data2 = self.steps_data.copy()
+        steps_data2[self.current_step_num] = input_data
+
+        sub_data = None
+        for substep in self.substeps:
+            sub_data = self.sql_query.execute_step(substep, steps_data=steps_data2)
+            steps_data2[substep.step_num] = sub_data
+        return sub_data
+
+    def fetch_threads(self, run_query: RunningQuery, query: ASTNode,
+                      thread_count: int = None, on_error: str = None) -> ResultSet:
+        """
+        Process batches in threads:
+        - spawn the required count of threads
+        - create in/out queues to communicate with the threads
+        - send tasks to the threads and receive results
+        """
+
+        # create communication queues
+        queue_in = queue.Queue()
+        queue_out = queue.Queue()
+        self.stop_event = threading.Event()
+
+        if thread_count is None:
+            thread_count = get_max_thread_count()
+
+        # 3 tasks per worker per batch
+        partition_size = int(run_query.batch_size / thread_count / 3)
+        # min partition size
+        if partition_size < 10:
+            partition_size = 10
+
+        # create a pool of N workers
+        workers = []
+        results = []
+
+        try:
+            for i in range(thread_count):
+                worker = threading.Thread(target=self._worker, daemon=True, args=(ctx.dump(), queue_in,
+                                                                                  queue_out, self.stop_event))
+                worker.start()
+                workers.append(worker)
+
+            while True:
+                # fetch batch
+                query2 = run_query.get_partition_query(self.current_step_num, query)
+                response = self.dn.query(
+                    query=query2,
+                    session=self.session
+                )
+                df = response.data_frame
+
+                if df is None or len(df) == 0:
+                    # TODO detect cycles: data handler ignores condition and output is repeated
+
+                    # exit & stop workers
+                    break
+
+                max_track_value = run_query.get_max_track_value(df)
+
+                # split into chunks and send to workers
+                sent_chunks = 0
+                for df2 in split_data_frame(df, partition_size):
+                    queue_in.put([sent_chunks, df2])
+                    sent_chunks += 1
+
+                batch_results = []
+                for i in range(sent_chunks):
+                    res = queue_out.get()
+                    if 'error' in res:
+                        if on_error == 'skip':
+                            logger.error(res['error'])
+                        else:
+                            raise RuntimeError(res['error'])
+
+                    # use .get(): a skipped error response carries no 'data' key
+                    if res.get('data') is not None:
+                        batch_results.append(res)
+
+                # sort results
+                batch_results.sort(key=lambda x: x['num'])
+
+                results.append(self.concat_results(
+                    [item['data'] for item in batch_results]
+                ))
+
+                # TODO
+                # 1. get the next batch without updating track_value:
+                #    it allows keeping queue_in filled with data between fetching batches
+                run_query.set_progress(df, max_track_value)
+        finally:
+            self.close_workers(workers)
+
+        return self.concat_results(results)
+
+    def close_workers(self, workers: List[threading.Thread]):
+        """
+        Send a signal to the workers to stop
+        """
+
+        self.stop_event.set()
+        for worker in workers:
+            if worker.is_alive():
+                worker.join()
+
+    def _worker(self, context: Context, queue_in: queue.Queue, queue_out: queue.Queue, stop_event: threading.Event):
+        """
+        Worker function. Executes incoming tasks until stop_event is set
+        """
+        ctx.load(context)
+        while True:
+            if stop_event.is_set():
+                break
+
+            try:
+                chunk_num, df = queue_in.get(timeout=1)
+                if df is None:
+                    continue

+                sub_data = self.exec_sub_steps(df)
+
+                queue_out.put({'data': sub_data, 'num': chunk_num})
+            except queue.Empty:
+                continue
+
+            except Exception as e:
+                queue_out.put({'error': str(e)})
+                stop_event.set()
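
The class docstring above describes keyset pagination: each batch re-runs the user's query filtered past the last seen value of track_column. A minimal, self-contained sketch of the same pattern over an in-memory dataframe (fetch_in_batches is illustrative only, not part of MindsDB):

    import pandas as pd

    def fetch_in_batches(df: pd.DataFrame, track_column: str, batch_size: int = 1000):
        # Same keyset pattern as get_partition_query: keep rows past the last
        # seen track value, order by that column, cap at batch_size.
        last = None
        while True:
            batch = df if last is None else df[df[track_column] > last]
            batch = batch.sort_values(track_column).head(batch_size)
            if batch.empty:
                break
            yield batch
            last = batch[track_column].max()

    # Concatenating the batches reproduces the full result, as concat_results does.
    data = pd.DataFrame({'id': range(10)})
    parts = list(fetch_in_batches(data, 'id', batch_size=4))
    assert pd.concat(parts)['id'].tolist() == list(range(10))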
@@ -104,6 +104,22 @@ def construct_model_from_args(args: Dict) -> Embeddings:
     return model
 
 
+def row_to_document(row: pd.Series) -> str:
+    """
+    Convert a row in the input dataframe into a document
+
+    Default implementation is to concatenate all the columns
+    in the form of
+    field1: value1\nfield2: value2\n...
+    """
+    fields = row.index.tolist()
+    values = row.values.tolist()
+    document = "\n".join(
+        [f"{field}: {value}" for field, value in zip(fields, values)]
+    )
+    return document
+
+
 class LangchainEmbeddingHandler(BaseMLEngine):
     """
     Bridge class to connect langchain.embeddings module to mindsDB
@@ -180,7 +196,7 @@ class LangchainEmbeddingHandler(BaseMLEngine):
         )
 
         # convert each row into a document
-        df_texts = df[input_columns].apply(self.row_to_document, axis=1)
+        df_texts = df[input_columns].apply(row_to_document, axis=1)
         embeddings = model.embed_documents(df_texts.tolist())
 
         # create a new dataframe with the embeddings
@@ -188,21 +204,6 @@ class LangchainEmbeddingHandler(BaseMLEngine):
 
         return df_embeddings
 
-    def row_to_document(self, row: pd.Series) -> str:
-        """
-        Convert a row in the input dataframe into a document
-
-        Default implementation is to concatenate all the columns
-        in the form of
-        field1: value1\nfield2: value2\n...
-        """
-        fields = row.index.tolist()
-        values = row.values.tolist()
-        document = "\n".join(
-            [f"{field}: {value}" for field, value in zip(fields, values)]
-        )
-        return document
-
     def finetune(
         self, df: Union[DataFrame, None] = None, args: Union[Dict, None] = None
     ) -> None:
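
To illustrate the relocated helper: row_to_document renders a row as newline-separated "field: value" pairs. A quick usage sketch (the function body is condensed from the diff above):

    import pandas as pd

    def row_to_document(row: pd.Series) -> str:
        # concatenate columns as "field: value" lines
        return "\n".join([f"{field}: {value}" for field, value in zip(row.index.tolist(), row.values.tolist())])

    row = pd.Series({'title': 'MindsDB', 'category': 'database'})
    assert row_to_document(row) == 'title: MindsDB\ncategory: database'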
@@ -50,6 +50,7 @@ class LangChainHandler(BaseMLEngine):
     - OpenAI
     - Anthropic
     - Anyscale
+    - Google
     - LiteLLM
     - Ollama
 
@@ -46,7 +46,8 @@ class PgVectorHandler(PostgresHandler, VectorStoreHandler):
 
     def _make_connection_args(self):
        cloud_pgvector_url = os.environ.get('KB_PGVECTOR_URL')
-        if cloud_pgvector_url is not None:
+        # if no connection args and a shared pgvector is defined, use it
+        if len(self.connection_args) == 0 and cloud_pgvector_url is not None:
            result = urlparse(cloud_pgvector_url)
            self.connection_args = {
                'host': result.hostname,
@@ -157,7 +158,7 @@ class PgVectorHandler(PostgresHandler, VectorStoreHandler):
             where_clauses.append(f'{key} {value["op"]} {value["value"]}')
 
         if len(where_clauses) > 1:
-            return f"WHERE{' AND '.join(where_clauses)}"
+            return f"WHERE {' AND '.join(where_clauses)}"
         elif len(where_clauses) == 1:
             return f"WHERE {where_clauses[0]}"
         else:
@@ -195,11 +196,6 @@ class PgVectorHandler(PostgresHandler, VectorStoreHandler):
         # given filter conditions, construct where clause
         where_clause = self._construct_where_clause(filter_conditions)
 
-        # construct full after from clause, where clause + offset clause + limit clause
-        after_from_clause = self._construct_full_after_from_clause(
-            where_clause, offset_clause, limit_clause
-        )
-
         # Handle distance column specially since it's calculated, not stored
         modified_columns = []
         has_distance = False
@@ -219,7 +215,7 @@ class PgVectorHandler(PostgresHandler, VectorStoreHandler):
         if filter_conditions:
 
             if embedding_search:
-                search_vector = filter_conditions["embeddings"]["value"][0]
+                search_vector = filter_conditions["embeddings"]["value"]
                 filter_conditions.pop("embeddings")
 
                 if self._is_sparse:
@@ -241,15 +237,15 @@ class PgVectorHandler(PostgresHandler, VectorStoreHandler):
                 if has_distance:
                     targets = f"{targets}, (embeddings {distance_op} '{search_vector}') as distance"
 
-                return f"SELECT {targets} FROM {table_name} ORDER BY embeddings {distance_op} '{search_vector}' ASC {after_from_clause}"
+                return f"SELECT {targets} FROM {table_name} {where_clause} ORDER BY embeddings {distance_op} '{search_vector}' ASC {limit_clause} {offset_clause} "
 
             else:
                 # if filter conditions, return rows that satisfy the conditions
-                return f"SELECT {targets} FROM {table_name} {after_from_clause}"
+                return f"SELECT {targets} FROM {table_name} {where_clause} {limit_clause} {offset_clause}"
 
         else:
             # if no filter conditions, return all rows
-            return f"SELECT {targets} FROM {table_name} {after_from_clause}"
+            return f"SELECT {targets} FROM {table_name} {limit_clause} {offset_clause}"
 
     def _check_table(self, table_name: str):
         # Apply namespace for a user
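
The rewritten queries fix two problems: the missing space after WHERE and clause placement, since WHERE must precede ORDER BY, LIMIT, and OFFSET. A hypothetical sketch of the ordering (build_select is invented here for illustration):

    def build_select(targets: str, table: str, where: str = '',
                     order_by: str = '', limit: str = '', offset: str = '') -> str:
        # Assemble clauses in the order PostgreSQL requires:
        # SELECT ... FROM ... [WHERE] [ORDER BY] [LIMIT] [OFFSET]
        parts = [f"SELECT {targets} FROM {table}", where, order_by, limit, offset]
        return ' '.join(p for p in parts if p)

    print(build_select('id, content', 'items',
                       where='WHERE id > 5',
                       order_by="ORDER BY embeddings <-> '[1,2,3]' ASC",
                       limit='LIMIT 10'))
    # SELECT id, content FROM items WHERE id > 5 ORDER BY embeddings <-> '[1,2,3]' ASC LIMIT 10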
@@ -1,6 +1,7 @@
 import time
 import json
 from typing import Optional
+import threading
 
 import pandas as pd
 import psycopg
@@ -77,6 +78,8 @@ class PostgresHandler(DatabaseHandler):
         self.is_connected = False
         self.thread_safe = True
 
+        self._insert_lock = threading.Lock()
+
     def __del__(self):
         if self.is_connected:
             self.disconnect()
@@ -261,14 +264,35 @@ class PostgresHandler(DatabaseHandler):
 
         connection = self.connect()
 
-        columns = [f'"{c}"' for c in df.columns]
+        columns = df.columns
+
+        # postgres 'copy' is not thread safe. use a lock to prevent concurrent execution
+        with self._insert_lock:
+            resp = self.get_columns(table_name)
+
+        # copy requires the precise casing of names: get the current column names from the table and adapt the input dataframe columns
+        if resp.data_frame is not None and not resp.data_frame.empty:
+            db_columns = {
+                c.lower(): c
+                for c in resp.data_frame['Field']
+            }
+
+            # try to get the casing of the existing column
+            columns = [
+                db_columns.get(c.lower(), c)
+                for c in columns
+            ]
+
+        columns = [f'"{c}"' for c in columns]
         rowcount = None
+
         with connection.cursor() as cur:
             try:
-                with cur.copy(f'copy "{table_name}" ({",".join(columns)}) from STDIN WITH CSV') as copy:
-                    df.to_csv(copy, index=False, header=False)
+                with self._insert_lock:
+                    with cur.copy(f'copy "{table_name}" ({",".join(columns)}) from STDIN WITH CSV') as copy:
+                        df.to_csv(copy, index=False, header=False)
 
-                connection.commit()
+                    connection.commit()
             except Exception as e:
                 logger.error(f'Error running insert to {table_name} on {self.database}, {e}!')
                 connection.rollback()
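
The casing lookup matters because COPY with quoted identifiers is case-sensitive in PostgreSQL. A small sketch of the mapping the handler now performs (the table columns here are hypothetical):

    # Columns as they exist in the table vs. incoming dataframe columns.
    table_columns = ['Id', 'CreatedAt', 'payload']
    db_columns = {c.lower(): c for c in table_columns}

    incoming = ['id', 'createdat', 'payload', 'extra']
    resolved = [db_columns.get(c.lower(), c) for c in incoming]
    assert resolved == ['Id', 'CreatedAt', 'payload', 'extra']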
@@ -1,6 +1,6 @@
 from typing import Any, Dict, List, Optional
 
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, Field
 
 
 class BaseLLMConfig(BaseModel):
@@ -104,3 +104,13 @@ class NvidiaNIMConfig(BaseLLMConfig):
 class MindsdbConfig(BaseLLMConfig):
     model_name: str
     project_name: str
+
+
+# See https://python.langchain.com/api_reference/google_genai/chat_models/langchain_google_genai.chat_models.ChatGoogleGenerativeAI.html
+class GoogleConfig(BaseLLMConfig):
+    model: str = Field(description="Gemini model name to use (e.g., 'gemini-1.5-pro')")
+    temperature: Optional[float] = Field(default=None, description="Controls randomness in responses")
+    top_p: Optional[float] = Field(default=None, description="Nucleus sampling parameter")
+    top_k: Optional[int] = Field(default=None, description="Number of highest probability tokens to consider")
+    max_output_tokens: Optional[int] = Field(default=None, description="Maximum number of tokens to generate")
+    google_api_key: Optional[str] = Field(default=None, description="API key for Google Generative AI")
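
The new model mirrors ChatGoogleGenerativeAI's parameters; constructing it directly might look like this (all values below are placeholders):

    cfg = GoogleConfig(
        model='gemini-1.5-pro',          # required
        temperature=0.2,                 # optional sampling controls
        max_output_tokens=1024,
        google_api_key='YOUR_API_KEY',   # placeholder
    )
    # unset optional fields stay None and can be dropped with
    # cfg.model_dump(exclude_none=True) before passing to the chat model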
@@ -10,6 +10,7 @@ from mindsdb.integrations.libs.llm.config import (
     AnthropicConfig,
     AnyscaleConfig,
     BaseLLMConfig,
+    GoogleConfig,
     LiteLLMConfig,
     OllamaConfig,
     OpenAIConfig,
@@ -31,6 +32,8 @@ DEFAULT_ANTHROPIC_MODEL = "claude-3-haiku-20240307"
 DEFAULT_ANYSCALE_MODEL = "meta-llama/Llama-2-7b-chat-hf"
 DEFAULT_ANYSCALE_BASE_URL = "https://api.endpoints.anyscale.com/v1"
 
+DEFAULT_GOOGLE_MODEL = "gemini-2.5-pro-preview-03-25"
+
 DEFAULT_LITELLM_MODEL = "gpt-3.5-turbo"
 DEFAULT_LITELLM_PROVIDER = "openai"
 DEFAULT_LITELLM_BASE_URL = "https://ai.dev.mindsdb.com"
@@ -225,6 +228,15 @@ def get_llm_config(provider: str, args: Dict) -> BaseLLMConfig:
             openai_organization=args.get("api_organization", None),
             request_timeout=args.get("request_timeout", None),
         )
+    if provider == "google":
+        return GoogleConfig(
+            model=args.get("model_name", DEFAULT_GOOGLE_MODEL),
+            temperature=temperature,
+            top_p=args.get("top_p", None),
+            top_k=args.get("top_k", None),
+            max_output_tokens=args.get("max_tokens", None),
+            google_api_key=args["api_keys"].get("google", None),
+        )
 
     raise ValueError(f"Provider {provider} is not supported.")
 
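Per the new branch, a google provider call reads the model name (falling back to DEFAULT_GOOGLE_MODEL) and expects the API key under args['api_keys']['google']; temperature is resolved earlier in the function from the args. A hedged usage sketch:

    args = {
        'model_name': 'gemini-1.5-pro',
        'api_keys': {'google': 'YOUR_API_KEY'},  # placeholder
        # 'top_p', 'top_k' and 'max_tokens' are optional and default to None
    }
    config = get_llm_config('google', args)  # returns a GoogleConfig instance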
@@ -15,7 +15,8 @@ SUPPORTED_PROVIDERS = {
     "litellm",
     "ollama",
     "nvidia_nim",
-    "vllm"
+    "vllm",
+    "google"
 }
 # Chat models
 ANTHROPIC_CHAT_MODELS = (
@@ -153,6 +154,15 @@ NVIDIA_NIM_CHAT_MODELS = (
     "ibm/granite-34b-code-instruct",
 )
 
+GOOGLE_GEMINI_CHAT_MODELS = (
+    "gemini-2.5-pro-preview-03-25",
+    "gemini-2.0-flash",
+    "gemini-2.0-flash-lite",
+    "gemini-1.5-flash",
+    "gemini-1.5-flash-8b",
+    "gemini-1.5-pro",
+)
+
 # Define a read-only dictionary mapping providers to their models
 PROVIDER_TO_MODELS = MappingProxyType(
     {
@@ -160,6 +170,7 @@ PROVIDER_TO_MODELS = MappingProxyType(
         "ollama": OLLAMA_CHAT_MODELS,
         "openai": OPEN_AI_CHAT_MODELS,
         "nvidia_nim": NVIDIA_NIM_CHAT_MODELS,
+        "google": GOOGLE_GEMINI_CHAT_MODELS,
     }
 )
 
@@ -15,6 +15,7 @@ from langchain_community.chat_models import (
     ChatAnyscale,
     ChatLiteLLM,
     ChatOllama)
+from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_core.agents import AgentAction, AgentStep
 from langchain_core.callbacks.base import BaseCallbackHandler
 
@@ -50,6 +51,7 @@ from .constants import (
     DEFAULT_TIKTOKEN_MODEL_NAME,
     SUPPORTED_PROVIDERS,
     ANTHROPIC_CHAT_MODELS,
+    GOOGLE_GEMINI_CHAT_MODELS,
     OLLAMA_CHAT_MODELS,
     NVIDIA_NIM_CHAT_MODELS,
     USER_COLUMN,
@@ -85,6 +87,8 @@ def get_llm_provider(args: Dict) -> str:
         return "ollama"
     if args["model_name"] in NVIDIA_NIM_CHAT_MODELS:
         return "nvidia_nim"
+    if args["model_name"] in GOOGLE_GEMINI_CHAT_MODELS:
+        return "google"
 
     # For vLLM, require explicit provider specification
     raise ValueError("Invalid model name. Please define a supported llm provider")
@@ -162,6 +166,8 @@ def create_chat_model(args: Dict):
         return ChatOllama(**model_kwargs)
     if args["provider"] == "nvidia_nim":
         return ChatNVIDIA(**model_kwargs)
+    if args["provider"] == "google":
+        return ChatGoogleGenerativeAI(**model_kwargs)
     if args["provider"] == "mindsdb":
         return ChatMindsdb(**model_kwargs)
     raise ValueError(f'Unknown provider: {args["provider"]}')
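
Taken together with the constants change, Gemini model names now resolve end to end; for instance:

    args = {'model_name': 'gemini-2.0-flash'}
    assert get_llm_provider(args) == 'google'   # matched via GOOGLE_GEMINI_CHAT_MODELS
    # create_chat_model then instantiates ChatGoogleGenerativeAI for provider 'google'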