MindsDB 25.4.3.2__py3-none-any.whl → 25.4.5.0__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.

This version of MindsDB has been flagged as potentially problematic.

Files changed (68)
  1. mindsdb/__about__.py +1 -1
  2. mindsdb/__main__.py +18 -4
  3. mindsdb/api/executor/command_executor.py +12 -2
  4. mindsdb/api/executor/data_types/response_type.py +1 -0
  5. mindsdb/api/executor/datahub/classes/tables_row.py +3 -10
  6. mindsdb/api/executor/datahub/datanodes/datanode.py +7 -2
  7. mindsdb/api/executor/datahub/datanodes/information_schema_datanode.py +44 -10
  8. mindsdb/api/executor/datahub/datanodes/integration_datanode.py +57 -38
  9. mindsdb/api/executor/datahub/datanodes/mindsdb_tables.py +2 -1
  10. mindsdb/api/executor/datahub/datanodes/project_datanode.py +39 -7
  11. mindsdb/api/executor/datahub/datanodes/system_tables.py +116 -109
  12. mindsdb/api/executor/planner/query_plan.py +1 -0
  13. mindsdb/api/executor/planner/query_planner.py +15 -1
  14. mindsdb/api/executor/planner/steps.py +8 -2
  15. mindsdb/api/executor/sql_query/sql_query.py +24 -8
  16. mindsdb/api/executor/sql_query/steps/apply_predictor_step.py +25 -8
  17. mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py +4 -2
  18. mindsdb/api/executor/sql_query/steps/insert_step.py +2 -1
  19. mindsdb/api/executor/sql_query/steps/prepare_steps.py +2 -3
  20. mindsdb/api/http/namespaces/config.py +19 -11
  21. mindsdb/api/litellm/start.py +82 -0
  22. mindsdb/api/mysql/mysql_proxy/libs/constants/mysql.py +133 -0
  23. mindsdb/integrations/handlers/chromadb_handler/chromadb_handler.py +7 -2
  24. mindsdb/integrations/handlers/chromadb_handler/settings.py +1 -0
  25. mindsdb/integrations/handlers/mssql_handler/mssql_handler.py +13 -4
  26. mindsdb/integrations/handlers/mysql_handler/mysql_handler.py +14 -5
  27. mindsdb/integrations/handlers/openai_handler/helpers.py +3 -5
  28. mindsdb/integrations/handlers/openai_handler/openai_handler.py +20 -8
  29. mindsdb/integrations/handlers/oracle_handler/oracle_handler.py +14 -4
  30. mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py +34 -19
  31. mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +21 -18
  32. mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py +14 -4
  33. mindsdb/integrations/handlers/togetherai_handler/__about__.py +9 -0
  34. mindsdb/integrations/handlers/togetherai_handler/__init__.py +20 -0
  35. mindsdb/integrations/handlers/togetherai_handler/creation_args.py +14 -0
  36. mindsdb/integrations/handlers/togetherai_handler/icon.svg +15 -0
  37. mindsdb/integrations/handlers/togetherai_handler/model_using_args.py +5 -0
  38. mindsdb/integrations/handlers/togetherai_handler/requirements.txt +2 -0
  39. mindsdb/integrations/handlers/togetherai_handler/settings.py +33 -0
  40. mindsdb/integrations/handlers/togetherai_handler/togetherai_handler.py +234 -0
  41. mindsdb/integrations/handlers/web_handler/urlcrawl_helpers.py +1 -1
  42. mindsdb/integrations/libs/response.py +80 -32
  43. mindsdb/integrations/utilities/handler_utils.py +4 -0
  44. mindsdb/integrations/utilities/rag/rerankers/base_reranker.py +360 -0
  45. mindsdb/integrations/utilities/rag/rerankers/reranker_compressor.py +8 -153
  46. mindsdb/interfaces/agents/litellm_server.py +345 -0
  47. mindsdb/interfaces/agents/mcp_client_agent.py +252 -0
  48. mindsdb/interfaces/agents/run_mcp_agent.py +205 -0
  49. mindsdb/interfaces/functions/controller.py +3 -2
  50. mindsdb/interfaces/knowledge_base/controller.py +106 -82
  51. mindsdb/interfaces/query_context/context_controller.py +55 -15
  52. mindsdb/interfaces/query_context/query_task.py +19 -0
  53. mindsdb/interfaces/skills/skill_tool.py +7 -1
  54. mindsdb/interfaces/skills/sql_agent.py +8 -3
  55. mindsdb/interfaces/storage/db.py +2 -2
  56. mindsdb/interfaces/tasks/task_monitor.py +5 -1
  57. mindsdb/interfaces/tasks/task_thread.py +6 -0
  58. mindsdb/migrations/versions/2025-04-22_53502b6d63bf_query_database.py +27 -0
  59. mindsdb/utilities/config.py +20 -2
  60. mindsdb/utilities/context.py +1 -0
  61. mindsdb/utilities/starters.py +7 -0
  62. {mindsdb-25.4.3.2.dist-info → mindsdb-25.4.5.0.dist-info}/METADATA +226 -221
  63. {mindsdb-25.4.3.2.dist-info → mindsdb-25.4.5.0.dist-info}/RECORD +67 -53
  64. {mindsdb-25.4.3.2.dist-info → mindsdb-25.4.5.0.dist-info}/WHEEL +1 -1
  65. mindsdb/integrations/handlers/snowflake_handler/tests/test_snowflake_handler.py +0 -230
  66. /mindsdb/{integrations/handlers/snowflake_handler/tests → api/litellm}/__init__.py +0 -0
  67. {mindsdb-25.4.3.2.dist-info → mindsdb-25.4.5.0.dist-info}/licenses/LICENSE +0 -0
  68. {mindsdb-25.4.3.2.dist-info → mindsdb-25.4.5.0.dist-info}/top_level.txt +0 -0
mindsdb/api/litellm/start.py
@@ -0,0 +1,82 @@
+import asyncio
+from mindsdb.utilities import log
+from mindsdb.utilities.config import Config
+from mindsdb.interfaces.agents.litellm_server import run_server, run_server_async
+
+logger = log.getLogger(__name__)
+
+
+async def start_async(verbose=False):
+    """Start the LiteLLM server
+
+    Args:
+        verbose (bool): Whether to enable verbose logging
+    """
+    config = Config()
+
+    # Get agent name from command line args
+    agent_name = config.cmd_args.agent
+    if not agent_name:
+        logger.error("Agent name is required for LiteLLM server. Use --agent parameter.")
+        return 1
+
+    # Get project name or use default
+    project_name = config.cmd_args.project or "mindsdb"
+
+    # Get MCP server connection details
+    mcp_host = config.get('api', {}).get('mcp', {}).get('host', '127.0.0.1')
+    mcp_port = int(config.get('api', {}).get('mcp', {}).get('port', 47337))
+
+    # Get LiteLLM server settings
+    litellm_host = config.get('api', {}).get('litellm', {}).get('host', '0.0.0.0')
+    litellm_port = int(config.get('api', {}).get('litellm', {}).get('port', 8000))
+
+    logger.info(f"Starting LiteLLM server for agent '{agent_name}' in project '{project_name}'")
+    logger.info(f"Connecting to MCP server at {mcp_host}:{mcp_port}")
+    logger.info(f"Binding to {litellm_host}:{litellm_port}")
+
+    return await run_server_async(
+        agent_name=agent_name,
+        project_name=project_name,
+        mcp_host=mcp_host,
+        mcp_port=mcp_port,
+        host=litellm_host,
+        port=litellm_port
+    )
+
+
+def start(verbose=False):
+    """Start the LiteLLM server (synchronous wrapper)
+
+    Args:
+        verbose (bool): Whether to enable verbose logging
+    """
+    from mindsdb.interfaces.storage import db
+    db.init()
+
+    # Run the async function in the event loop
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+    result = loop.run_until_complete(start_async(verbose))
+
+    if result == 0:
+        # Run the server
+        config = Config()
+        agent_name = config.cmd_args.agent
+        project_name = config.cmd_args.project or "mindsdb"
+        mcp_host = config.get('api', {}).get('mcp', {}).get('host', '127.0.0.1')
+        mcp_port = int(config.get('api', {}).get('mcp', {}).get('port', 47337))
+        litellm_host = config.get('api', {}).get('litellm', {}).get('host', '0.0.0.0')
+        litellm_port = int(config.get('api', {}).get('litellm', {}).get('port', 8000))
+
+        return run_server(
+            agent_name=agent_name,
+            project_name=project_name,
+            mcp_host=mcp_host,
+            mcp_port=mcp_port,
+            host=litellm_host,
+            port=litellm_port
+        )
+    else:
+        logger.error("LiteLLM server initialization failed")
+        return result
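
Note: the starter follows a common asyncio hand-off pattern: run one-off async initialization on a fresh event loop, then fall through to the blocking server entry point. A minimal, self-contained sketch of that pattern (the init()/serve() names are illustrative stand-ins, not MindsDB APIs):

    import asyncio

    async def init() -> int:
        # one-off async validation; 0 means "ready to serve"
        return 0

    def serve() -> int:
        # the blocking server loop would run here
        return 0

    def start() -> int:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        result = loop.run_until_complete(init())
        # hand off to the blocking server only if initialization succeeded
        return serve() if result == 0 else result
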
mindsdb/api/mysql/mysql_proxy/libs/constants/mysql.py
@@ -179,6 +179,139 @@ class MYSQL_DATA_TYPE(enum.Enum):
     BOOLEAN = 'BOOLEAN'


+# Default values for attributes of MySQL data types as they appear in information_schema.columns
+# These values match the MySQL v8.0.37 defaults and are used to properly represent column metadata
+MYSQL_DATA_TYPE_COLUMNS_DEFAULT = {
+    MYSQL_DATA_TYPE.TINYINT: {
+        'NUMERIC_PRECISION': 3,
+        'NUMERIC_SCALE': 0
+    },
+    MYSQL_DATA_TYPE.SMALLINT: {
+        'NUMERIC_PRECISION': 5,
+        'NUMERIC_SCALE': 0
+    },
+    MYSQL_DATA_TYPE.MEDIUMINT: {
+        'NUMERIC_PRECISION': 7,
+        'NUMERIC_SCALE': 0
+    },
+    MYSQL_DATA_TYPE.INT: {
+        'NUMERIC_PRECISION': 10,
+        'NUMERIC_SCALE': 0
+    },
+    MYSQL_DATA_TYPE.BIGINT: {
+        'NUMERIC_PRECISION': 19,
+        'NUMERIC_SCALE': 0
+    },
+    MYSQL_DATA_TYPE.FLOAT: {
+        'NUMERIC_PRECISION': 12
+    },
+    MYSQL_DATA_TYPE.DOUBLE: {
+        'NUMERIC_PRECISION': 22
+    },
+    MYSQL_DATA_TYPE.DECIMAL: {
+        'NUMERIC_PRECISION': 10,
+        'NUMERIC_SCALE': 0,
+        'COLUMN_TYPE': 'decimal(10,0)'
+    },
+    MYSQL_DATA_TYPE.YEAR: {
+        # every column is null
+    },
+    MYSQL_DATA_TYPE.TIME: {
+        'DATETIME_PRECISION': 0
+    },
+    MYSQL_DATA_TYPE.DATE: {
+        # every column is null
+    },
+    MYSQL_DATA_TYPE.DATETIME: {
+        'DATETIME_PRECISION': 0
+    },
+    MYSQL_DATA_TYPE.TIMESTAMP: {
+        'DATETIME_PRECISION': 0
+    },
+    MYSQL_DATA_TYPE.CHAR: {
+        'CHARACTER_MAXIMUM_LENGTH': 1,
+        'CHARACTER_OCTET_LENGTH': 4,
+        'CHARACTER_SET_NAME': 'utf8',
+        'COLLATION_NAME': 'utf8_bin',
+        'COLUMN_TYPE': 'char(1)'
+    },
+    MYSQL_DATA_TYPE.BINARY: {
+        'CHARACTER_MAXIMUM_LENGTH': 1,
+        'CHARACTER_OCTET_LENGTH': 1,
+        'COLUMN_TYPE': 'binary(1)'
+    },
+    MYSQL_DATA_TYPE.VARCHAR: {
+        'CHARACTER_MAXIMUM_LENGTH': 1024,  # NOTE mandatory for field creation
+        'CHARACTER_OCTET_LENGTH': 4096,  # NOTE mandatory for field creation
+        'CHARACTER_SET_NAME': 'utf8',
+        'COLLATION_NAME': 'utf8_bin',
+        'COLUMN_TYPE': 'varchar(1024)'
+    },
+    MYSQL_DATA_TYPE.VARBINARY: {
+        'CHARACTER_MAXIMUM_LENGTH': 1024,  # NOTE mandatory for field creation
+        'CHARACTER_OCTET_LENGTH': 1024,  # NOTE mandatory for field creation
+        'COLUMN_TYPE': 'varbinary(1024)'
+    },
+    MYSQL_DATA_TYPE.TINYBLOB: {
+        'CHARACTER_MAXIMUM_LENGTH': 255,
+        'CHARACTER_OCTET_LENGTH': 255
+    },
+    MYSQL_DATA_TYPE.TINYTEXT: {
+        'CHARACTER_MAXIMUM_LENGTH': 255,
+        'CHARACTER_OCTET_LENGTH': 255,
+        'CHARACTER_SET_NAME': 'utf8',
+        'COLLATION_NAME': 'utf8_bin'
+    },
+    MYSQL_DATA_TYPE.BLOB: {
+        'CHARACTER_MAXIMUM_LENGTH': 65535,
+        'CHARACTER_OCTET_LENGTH': 65535
+    },
+    MYSQL_DATA_TYPE.TEXT: {
+        'CHARACTER_MAXIMUM_LENGTH': 65535,
+        'CHARACTER_OCTET_LENGTH': 65535,
+        'CHARACTER_SET_NAME': 'utf8',
+        'COLLATION_NAME': 'utf8_bin'
+    },
+    MYSQL_DATA_TYPE.MEDIUMBLOB: {
+        'CHARACTER_MAXIMUM_LENGTH': 16777215,
+        'CHARACTER_OCTET_LENGTH': 16777215
+    },
+    MYSQL_DATA_TYPE.MEDIUMTEXT: {
+        'CHARACTER_MAXIMUM_LENGTH': 16777215,
+        'CHARACTER_OCTET_LENGTH': 16777215,
+        'CHARACTER_SET_NAME': 'utf8',
+        'COLLATION_NAME': 'utf8_bin'
+    },
+    MYSQL_DATA_TYPE.LONGBLOB: {
+        'CHARACTER_MAXIMUM_LENGTH': 4294967295,
+        'CHARACTER_OCTET_LENGTH': 4294967295,
+    },
+    MYSQL_DATA_TYPE.LONGTEXT: {
+        'CHARACTER_MAXIMUM_LENGTH': 4294967295,
+        'CHARACTER_OCTET_LENGTH': 4294967295,
+        'CHARACTER_SET_NAME': 'utf8',
+        'COLLATION_NAME': 'utf8_bin'
+    },
+    MYSQL_DATA_TYPE.BIT: {
+        'NUMERIC_PRECISION': 1,
+        'COLUMN_TYPE': 'bit(1)'
+        # 'NUMERIC_SCALE': null
+    },
+    MYSQL_DATA_TYPE.BOOL: {
+        'DATA_TYPE': 'tinyint',
+        'NUMERIC_PRECISION': 3,
+        'NUMERIC_SCALE': 0,
+        'COLUMN_TYPE': 'tinyint(1)'
+    },
+    MYSQL_DATA_TYPE.BOOLEAN: {
+        'DATA_TYPE': 'tinyint',
+        'NUMERIC_PRECISION': 3,
+        'NUMERIC_SCALE': 0,
+        'COLUMN_TYPE': 'tinyint(1)'
+    }
+}
+
+
 # Map between data types and C types
 # https://dev.mysql.com/doc/c-api/8.0/en/c-api-prepared-statement-type-codes.html
 DATA_C_TYPE_MAP = {
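
Note: a defaults table like this lets the server fill in the type-dependent attributes of information_schema.columns rows that MySQL clients expect. A hedged sketch of how it might be consulted (the helper below is illustrative, not the actual MindsDB code):

    def columns_row(name, dtype):
        # start with NULLs for every type-dependent attribute ...
        row = {
            'COLUMN_NAME': name,
            'DATA_TYPE': dtype.value.lower(),
            'CHARACTER_MAXIMUM_LENGTH': None,
            'NUMERIC_PRECISION': None,
            'NUMERIC_SCALE': None,
            'DATETIME_PRECISION': None,
            'CHARACTER_SET_NAME': None,
            'COLLATION_NAME': None,
        }
        # ... then overlay the per-type defaults; note they may also override
        # DATA_TYPE/COLUMN_TYPE (BOOL/BOOLEAN report as tinyint(1))
        row.update(MYSQL_DATA_TYPE_COLUMNS_DEFAULT.get(dtype, {}))
        return row
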
mindsdb/integrations/handlers/chromadb_handler/chromadb_handler.py
@@ -68,6 +68,10 @@ class ChromaDBHandler(VectorStoreHandler):
             "persist_directory": self.persist_directory,
         }

+        self.create_collection_metadata = {
+            "hnsw:space": config.distance,
+        }
+
         self._use_handler_storage = False

         self.connect()
@@ -398,7 +402,7 @@ class ChromaDBHandler(VectorStoreHandler):
         Insert/Upsert data into ChromaDB collection.
         If records with same IDs exist, they will be updated.
         """
-        collection = self._client.get_or_create_collection(collection_name)
+        collection = self._client.get_or_create_collection(collection_name, metadata=self.create_collection_metadata)

         # Convert metadata from string to dict if needed
         if TableField.METADATA.value in df.columns:
@@ -484,7 +488,8 @@ class ChromaDBHandler(VectorStoreHandler):
         """
         Create a collection with the given name in the ChromaDB database.
         """
-        self._client.create_collection(table_name, get_or_create=if_not_exists)
+        self._client.create_collection(table_name, get_or_create=if_not_exists,
+                                       metadata=self.create_collection_metadata)
         self._sync()

     def drop_table(self, table_name: str, if_exists=True):
mindsdb/integrations/handlers/chromadb_handler/settings.py
@@ -14,6 +14,7 @@ class ChromaHandlerConfig(BaseModel):
    host: str = None
    port: str = None
    password: str = None
+   distance: str = 'cosine'

    class Config:
        extra = "forbid"
mindsdb/integrations/handlers/mssql_handler/mssql_handler.py
@@ -241,14 +241,23 @@ class SqlServerHandler(DatabaseHandler):

        query = f"""
            SELECT
-               column_name as "Field",
-               data_type as "Type"
+               COLUMN_NAME,
+               DATA_TYPE,
+               ORDINAL_POSITION,
+               COLUMN_DEFAULT,
+               IS_NULLABLE,
+               CHARACTER_MAXIMUM_LENGTH,
+               CHARACTER_OCTET_LENGTH,
+               NUMERIC_PRECISION,
+               NUMERIC_SCALE,
+               DATETIME_PRECISION,
+               CHARACTER_SET_NAME,
+               COLLATION_NAME
            FROM
                information_schema.columns
            WHERE
                table_name = '{table_name}'
        """
        result = self.native_query(query)
-       if result.resp_type is RESPONSE_TYPE.TABLE:
-           result.data_frame['mysql_data_type'] = result.data_frame['Type'].apply(_map_type)
+       result.to_columns_table_response(map_type_fn=_map_type)
        return result
mindsdb/integrations/handlers/mysql_handler/mysql_handler.py
@@ -231,14 +231,23 @@ class MySQLHandler(DatabaseHandler):
        """
        q = f"""
            select
-               COLUMN_NAME AS FIELD, DATA_TYPE AS TYPE
+               COLUMN_NAME,
+               DATA_TYPE,
+               ORDINAL_POSITION,
+               COLUMN_DEFAULT,
+               IS_NULLABLE,
+               CHARACTER_MAXIMUM_LENGTH,
+               CHARACTER_OCTET_LENGTH,
+               NUMERIC_PRECISION,
+               NUMERIC_SCALE,
+               DATETIME_PRECISION,
+               CHARACTER_SET_NAME,
+               COLLATION_NAME
            from
                information_schema.columns
            where
-               table_name = '{table_name}'
+               table_name = '{table_name}';
        """
        result = self.native_query(q)
-       if result.resp_type is RESPONSE_TYPE.TABLE:
-           result.data_frame = result.data_frame.rename(columns={'FIELD': 'Field', 'TYPE': 'Type'})
-           result.data_frame['mysql_data_type'] = result.data_frame['Type'].apply(_map_type)
+       result.to_columns_table_response(map_type_fn=_map_type)
        return result
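
Note: the per-handler get_columns post-processing is now consolidated into a shared to_columns_table_response method on the response object (see mindsdb/integrations/libs/response.py, +80 -32, in the file list above). The actual implementation is not shown in this diff; a speculative sketch of the kind of normalization such a method performs:

    def to_columns_table_response(self, map_type_fn):
        # assumption: self.data_frame holds an information_schema.columns result
        if self.resp_type is not RESPONSE_TYPE.TABLE:
            return
        # normalize column-name case across backends (Postgres lowercases,
        # Oracle uppercases) and derive the MySQL-facing type per row
        self.data_frame.columns = [name.upper() for name in self.data_frame.columns]
        self.data_frame['MYSQL_DATA_TYPE'] = self.data_frame['DATA_TYPE'].apply(map_type_fn)
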
mindsdb/integrations/handlers/openai_handler/helpers.py
@@ -4,7 +4,6 @@ import time
 import math

 import openai
-from openai import OpenAI

 import tiktoken

@@ -181,17 +180,16 @@ def count_tokens(messages: List[Dict], encoder: tiktoken.core.Encoding, model_na
    )


-def get_available_models(api_key: Text, api_base: Text) -> List[Text]:
+def get_available_models(client) -> List[Text]:
    """
    Returns a list of available openai models for the given API key.

    Args:
-       api_key (Text): OpenAI API key
-       api_base (Text): OpenAI API base URL
+       client: openai sdk client

    Returns:
        List[Text]: List of available models
    """
-   res = OpenAI(api_key=api_key, base_url=api_base).models.list()
+   res = client.models.list()

    return [models.id for models in res.data]
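
Note: the helper now takes a pre-built SDK client instead of constructing one, so the same code path serves both OpenAI and Azure OpenAI clients. A short usage sketch (the key is a placeholder):

    from openai import OpenAI

    client = OpenAI(api_key="sk-...")  # placeholder key
    model_ids = [m.id for m in client.models.list().data]
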
mindsdb/integrations/handlers/openai_handler/openai_handler.py
@@ -9,7 +9,7 @@ import subprocess
 import concurrent.futures
 from typing import Text, Tuple, Dict, List, Optional, Any
 import openai
-from openai import OpenAI, NotFoundError, AuthenticationError
+from openai import OpenAI, AzureOpenAI, NotFoundError, AuthenticationError
 import numpy as np
 import pandas as pd

@@ -87,7 +87,7 @@ class OpenAIHandler(BaseMLEngine):
        if api_key is not None:
            org = connection_args.get('api_organization')
            api_base = connection_args.get('api_base') or os.environ.get('OPENAI_API_BASE', OPENAI_API_BASE)
-           client = self._get_client(api_key=api_key, base_url=api_base, org=org)
+           client = self._get_client(api_key=api_key, base_url=api_base, org=org, args=connection_args)
            OpenAIHandler._check_client_connection(client)

    @staticmethod
@@ -188,7 +188,9 @@ class OpenAIHandler(BaseMLEngine):
                "temperature",
                "openai_api_key",
                "api_organization",
-               "api_base"
+               "api_base",
+               "api_version",
+               "provider",
            }
        )

@@ -204,7 +206,7 @@ class OpenAIHandler(BaseMLEngine):
        api_key = get_api_key('openai', args, engine_storage=engine_storage)
        api_base = args.get('api_base') or connection_args.get('api_base') or os.environ.get('OPENAI_API_BASE', OPENAI_API_BASE)
        org = args.get('api_organization')
-       client = OpenAIHandler._get_client(api_key=api_key, base_url=api_base, org=org)
+       client = OpenAIHandler._get_client(api_key=api_key, base_url=api_base, org=org, args=args)
        OpenAIHandler._check_client_connection(client)

    def create(self, target, args: Dict = None, **kwargs: Any) -> None:
@@ -228,7 +230,8 @@ class OpenAIHandler(BaseMLEngine):
        api_key = get_api_key(self.api_key_name, args, self.engine_storage)
        connection_args = self.engine_storage.get_connection_args()
        api_base = args.get('api_base') or connection_args.get('api_base') or os.environ.get('OPENAI_API_BASE') or self.api_base
-       available_models = get_available_models(api_key, api_base)
+       client = self._get_client(api_key=api_key, base_url=api_base, org=args.get('api_organization'), args=args)
+       available_models = get_available_models(client)

        if not args.get('mode'):
            args['mode'] = self.default_mode
@@ -810,6 +813,7 @@ class OpenAIHandler(BaseMLEngine):
            api_key=api_key,
            base_url=args.get('api_base'),
            org=args.pop('api_organization') if 'api_organization' in args else None,
+           args=args
        )

        try:
@@ -891,7 +895,8 @@ class OpenAIHandler(BaseMLEngine):
            client = self._get_client(
                api_key=api_key,
                base_url=args.get('api_base'),
-               org=args.get('api_organization')
+               org=args.get('api_organization'),
+               args=args,
            )
            meta = client.models.retrieve(model_name)
        except Exception as e:
@@ -935,7 +940,7 @@ class OpenAIHandler(BaseMLEngine):

        api_base = using_args.get('api_base', os.environ.get('OPENAI_API_BASE', OPENAI_API_BASE))
        org = using_args.get('api_organization')
-       client = self._get_client(api_key=api_key, base_url=api_base, org=org)
+       client = self._get_client(api_key=api_key, base_url=api_base, org=org, args=args)

        args = {**using_args, **args}
        prev_model_name = self.base_model_storage.json_get('args').get('model_name', '')
@@ -1173,7 +1178,7 @@ class OpenAIHandler(BaseMLEngine):
        return ft_stats, result_file_id

    @staticmethod
-   def _get_client(api_key: Text, base_url: Text, org: Optional[Text] = None) -> OpenAI:
+   def _get_client(api_key: Text, base_url: Text, org: Optional[Text] = None, args: dict = None) -> OpenAI:
        """
        Get an OpenAI client with the given API key, base URL, and organization.

@@ -1185,4 +1190,11 @@ class OpenAIHandler(BaseMLEngine):
        Returns:
            openai.OpenAI: OpenAI client.
        """
+       if args is not None and args.get('provider') == 'azure':
+           return AzureOpenAI(
+               api_key=api_key,
+               azure_endpoint=base_url,
+               api_version=args.get('api_version'),
+               organization=org
+           )
        return OpenAI(api_key=api_key, base_url=base_url, organization=org)
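
Note: with provider='azure' (plus api_version) in the args, the handler returns an AzureOpenAI client instead of the standard one; both expose the same models/completions surface. A hedged sketch of the two constructors from the openai SDK (endpoint, key, and version values are placeholders):

    from openai import OpenAI, AzureOpenAI

    plain = OpenAI(api_key="sk-...", base_url="https://api.openai.com/v1")
    azure = AzureOpenAI(
        api_key="...",
        azure_endpoint="https://my-resource.openai.azure.com",  # placeholder resource
        api_version="2024-02-01",  # any supported Azure OpenAI API version
    )
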
mindsdb/integrations/handlers/oracle_handler/oracle_handler.py
@@ -282,13 +282,23 @@ class OracleHandler(DatabaseHandler):
        """
        query = f"""
            SELECT
-               column_name AS field,
-               data_type AS type
+               COLUMN_NAME,
+               DATA_TYPE,
+               COLUMN_ID AS ORDINAL_POSITION,
+               DATA_DEFAULT AS COLUMN_DEFAULT,
+               CASE NULLABLE WHEN 'Y' THEN 'YES' ELSE 'NO' END AS IS_NULLABLE,
+               CHAR_LENGTH AS CHARACTER_MAXIMUM_LENGTH,
+               NULL AS CHARACTER_OCTET_LENGTH,
+               DATA_PRECISION AS NUMERIC_PRECISION,
+               DATA_SCALE AS NUMERIC_SCALE,
+               NULL AS DATETIME_PRECISION,
+               CHARACTER_SET_NAME,
+               NULL AS COLLATION_NAME
            FROM USER_TAB_COLUMNS
            WHERE table_name = '{table_name}'
+           ORDER BY TABLE_NAME, COLUMN_ID;
        """
        result = self.native_query(query)
        if result.resp_type is RESPONSE_TYPE.TABLE:
-           result.data_frame.columns = [name.lower() for name in result.data_frame.columns]
-           result.data_frame['mysql_data_type'] = result.data_frame['type'].apply(_map_type)
+           result.to_columns_table_response(map_type_fn=_map_type)
        return result
mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py
@@ -40,8 +40,31 @@ class PgVectorHandler(PostgresHandler, VectorStoreHandler):
        # we get these from the connection args on PostgresHandler parent
        self._is_sparse = self.connection_args.get('is_sparse', False)
        self._vector_size = self.connection_args.get('vector_size', None)
-       if self._is_sparse and not self._vector_size:
-           raise ValueError("vector_size is required when is_sparse=True")
+
+       if self._is_sparse:
+           if not self._vector_size:
+               raise ValueError("vector_size is required when is_sparse=True")
+
+           # Use inner product for sparse vectors
+           distance_op = "<#>"
+
+       else:
+           distance_op = '<=>'
+           if 'distance' in self.connection_args:
+               distance_ops = {
+                   'l1': '<+>',
+                   'l2': '<->',
+                   'ip': '<#>',  # inner product
+                   'cosine': '<=>',
+                   'hamming': '<~>',
+                   'jaccard': '<%>'
+               }
+
+               distance_op = distance_ops.get(self.connection_args['distance'])
+               if distance_op is None:
+                   raise ValueError(f'Wrong distance type. Allowed options are {list(distance_ops.keys())}')
+
+       self.distance_op = distance_op
        self.connect()

    def _make_connection_args(self):
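
Note: these are pgvector's distance operators: <-> is L2, <+> L1, <#> negative inner product, <=> cosine distance, and <~>/<%> Hamming/Jaccard for binary vectors. A small sketch of the kind of k-NN query string the handler builds from them (table and vector are illustrative):

    # hedged sketch: assembling a pgvector nearest-neighbour query
    distance_op = '<=>'  # cosine distance
    search_vector = '[0.1,0.2,0.3]'
    query = (
        f"SELECT id, (embeddings {distance_op} '{search_vector}') AS distance "
        f"FROM items ORDER BY embeddings {distance_op} '{search_vector}' ASC LIMIT 5"
    )
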
@@ -224,20 +247,16 @@ class PgVectorHandler(PostgresHandler, VectorStoreHandler):
                from pgvector.utils import SparseVector
                embedding = SparseVector(search_vector, self._vector_size)
                search_vector = embedding.to_text()
-               # Use inner product for sparse vectors
-               distance_op = "<#>"
            else:
                # Convert list to vector string if needed
                if isinstance(search_vector, list):
                    search_vector = f"[{','.join(str(x) for x in search_vector)}]"
-               # Use cosine similarity for dense vectors
-               distance_op = "<=>"

            # Calculate distance as part of the query if needed
            if has_distance:
-               targets = f"{targets}, (embeddings {distance_op} '{search_vector}') as distance"
+               targets = f"{targets}, (embeddings {self.distance_op} '{search_vector}') as distance"

-           return f"SELECT {targets} FROM {table_name} {where_clause} ORDER BY embeddings {distance_op} '{search_vector}' ASC {limit_clause} {offset_clause} "
+           return f"SELECT {targets} FROM {table_name} {where_clause} ORDER BY embeddings {self.distance_op} '{search_vector}' ASC {limit_clause} {offset_clause} "

        else:
            # if filter conditions, return rows that satisfy the conditions
@@ -418,18 +437,14 @@ class PgVectorHandler(PostgresHandler, VectorStoreHandler):
        """
        table_name = self._check_table(table_name)

-       data_dict = data.to_dict(orient="list")
-
-       if 'metadata' in data_dict:
-           data_dict['metadata'] = [json.dumps(i) for i in data_dict['metadata']]
-       transposed_data = list(zip(*data_dict.values()))
-
-       columns = ", ".join(data.keys())
-       values = ", ".join(["%s"] * len(data.keys()))
+       if 'metadata' in data.columns:
+           data['metadata'] = data['metadata'].apply(json.dumps)

-       insert_statement = f"INSERT INTO {table_name} ({columns}) VALUES ({values})"
-
-       self.raw_query(insert_statement, params=transposed_data)
+       resp = super().insert(table_name, data)
+       if resp.resp_type == RESPONSE_TYPE.ERROR:
+           raise RuntimeError(resp.error_message)
+       if resp.resp_type == RESPONSE_TYPE.TABLE:
+           return resp.data_frame

    def update(
        self, table_name: str, data: pd.DataFrame, key_columns: List[str] = None
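
Note: the insert path now JSON-encodes the metadata column and defers the actual write to the PostgresHandler parent, which uses COPY. A minimal pandas sketch of the serialization step:

    import json
    import pandas as pd

    data = pd.DataFrame({
        'id': [1, 2],
        'metadata': [{'source': 'a'}, {'source': 'b'}],
    })
    # dict -> JSON text, so the column survives the CSV/COPY round trip
    data['metadata'] = data['metadata'].apply(json.dumps)
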
mindsdb/integrations/handlers/postgres_handler/postgres_handler.py
@@ -1,7 +1,6 @@
 import time
 import json
 from typing import Optional
-import threading

 import pandas as pd
 import psycopg
@@ -43,7 +42,8 @@ def _map_type(internal_type_name: str) -> MYSQL_DATA_TYPE:
        ('real', 'money', 'float'): MYSQL_DATA_TYPE.FLOAT,
        ('numeric', 'decimal'): MYSQL_DATA_TYPE.DECIMAL,
        ('double precision',): MYSQL_DATA_TYPE.DOUBLE,
-       ('character varying', 'varchar', 'character', 'char', 'bpchar', 'bpchar', 'text'): MYSQL_DATA_TYPE.TEXT,
+       ('character varying', 'varchar'): MYSQL_DATA_TYPE.VARCHAR,
+       ('character', 'char', 'bpchar', 'bpchar', 'text'): MYSQL_DATA_TYPE.TEXT,
        ('timestamp', 'timestamp without time zone', 'timestamp with time zone'): MYSQL_DATA_TYPE.DATETIME,
        ('date', ): MYSQL_DATA_TYPE.DATE,
        ('time', 'time without time zone', 'time with time zone'): MYSQL_DATA_TYPE.TIME,
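
Note: the map keys are tuples of Postgres type names that share one MySQL-facing type, so a lookup presumably scans the keys for membership. An illustrative sketch of such a lookup (not the handler's actual code):

    def lookup(type_map, internal_type_name):
        name = internal_type_name.lower()
        for names, mysql_type in type_map.items():
            if name in names:
                return mysql_type
        return None  # caller picks the fallback type
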
@@ -76,9 +76,7 @@ class PostgresHandler(DatabaseHandler):

        self.connection = None
        self.is_connected = False
-       self.thread_safe = True
-
-       self._insert_lock = threading.Lock()
+       self.thread_safe = False

    def __del__(self):
        if self.is_connected:
@@ -266,15 +264,13 @@ class PostgresHandler(DatabaseHandler):

        columns = df.columns

-       # postgres 'copy' is not thread safe. use lock to prevent concurrent execution
-       with self._insert_lock:
-           resp = self.get_columns(table_name)
+       resp = self.get_columns(table_name)

        # copy requires precise cases of names: get current column names from table and adapt input dataframe columns
        if resp.data_frame is not None and not resp.data_frame.empty:
            db_columns = {
                c.lower(): c
-               for c in resp.data_frame['field']
+               for c in resp.data_frame['COLUMN_NAME']
            }

            # try to get case of existing column
@@ -288,11 +284,10 @@ class PostgresHandler(DatabaseHandler):

        with connection.cursor() as cur:
            try:
-               with self._insert_lock:
-                   with cur.copy(f'copy "{table_name}" ({",".join(columns)}) from STDIN WITH CSV') as copy:
-                       df.to_csv(copy, index=False, header=False)
+               with cur.copy(f'copy "{table_name}" ({",".join(columns)}) from STDIN WITH CSV') as copy:
+                   df.to_csv(copy, index=False, header=False)

-                   connection.commit()
+               connection.commit()
            except Exception as e:
                logger.error(f'Error running insert to {table_name} on {self.database}, {e}!')
                connection.rollback()
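
Note: with the lock gone, each handler instance now declares itself non-thread-safe (thread_safe = False above) instead of serializing COPY calls itself. The COPY pattern streams a DataFrame straight to the server; pandas' to_csv accepts any object with a write() method, which psycopg 3's Copy provides. A standalone hedged sketch (DSN and table are placeholders):

    import pandas as pd
    import psycopg

    df = pd.DataFrame({'id': [1, 2], 'name': ['a', 'b']})
    with psycopg.connect("dbname=test") as conn:  # placeholder DSN
        with conn.cursor() as cur:
            with cur.copy('copy "items" (id,name) from STDIN WITH CSV') as copy:
                df.to_csv(copy, index=False, header=False)
        conn.commit()
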
@@ -366,8 +361,18 @@ class PostgresHandler(DatabaseHandler):
            schema_name = 'current_schema()'
        query = f"""
            SELECT
-               column_name as "Field",
-               data_type as "Type"
+               COLUMN_NAME,
+               DATA_TYPE,
+               ORDINAL_POSITION,
+               COLUMN_DEFAULT,
+               IS_NULLABLE,
+               CHARACTER_MAXIMUM_LENGTH,
+               CHARACTER_OCTET_LENGTH,
+               NUMERIC_PRECISION,
+               NUMERIC_SCALE,
+               DATETIME_PRECISION,
+               CHARACTER_SET_NAME,
+               COLLATION_NAME
            FROM
                information_schema.columns
            WHERE
@@ -376,9 +381,7 @@ class PostgresHandler(DatabaseHandler):
                table_schema = {schema_name}
        """
        result = self.native_query(query)
-       if result.resp_type is RESPONSE_TYPE.TABLE:
-           result.data_frame.columns = [name.lower() for name in result.data_frame.columns]
-           result.data_frame['mysql_data_type'] = result.data_frame['type'].apply(_map_type)
+       result.to_columns_table_response(map_type_fn=_map_type)
        return result

    def subscribe(self, stop_event, callback, table_name, columns=None, **kwargs):