MindsDB 25.4.3.2__py3-none-any.whl → 25.4.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of MindsDB might be problematic. Click here for more details.
- mindsdb/__about__.py +1 -1
- mindsdb/__main__.py +18 -4
- mindsdb/api/executor/command_executor.py +12 -2
- mindsdb/api/executor/data_types/response_type.py +1 -0
- mindsdb/api/executor/datahub/classes/tables_row.py +3 -10
- mindsdb/api/executor/datahub/datanodes/datanode.py +7 -2
- mindsdb/api/executor/datahub/datanodes/information_schema_datanode.py +44 -10
- mindsdb/api/executor/datahub/datanodes/integration_datanode.py +57 -38
- mindsdb/api/executor/datahub/datanodes/mindsdb_tables.py +2 -1
- mindsdb/api/executor/datahub/datanodes/project_datanode.py +39 -7
- mindsdb/api/executor/datahub/datanodes/system_tables.py +116 -109
- mindsdb/api/executor/planner/query_plan.py +1 -0
- mindsdb/api/executor/planner/query_planner.py +15 -1
- mindsdb/api/executor/planner/steps.py +8 -2
- mindsdb/api/executor/sql_query/sql_query.py +24 -8
- mindsdb/api/executor/sql_query/steps/apply_predictor_step.py +25 -8
- mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py +4 -2
- mindsdb/api/executor/sql_query/steps/insert_step.py +2 -1
- mindsdb/api/executor/sql_query/steps/prepare_steps.py +2 -3
- mindsdb/api/http/namespaces/config.py +19 -11
- mindsdb/api/litellm/start.py +82 -0
- mindsdb/api/mysql/mysql_proxy/libs/constants/mysql.py +133 -0
- mindsdb/integrations/handlers/chromadb_handler/chromadb_handler.py +7 -2
- mindsdb/integrations/handlers/chromadb_handler/settings.py +1 -0
- mindsdb/integrations/handlers/mssql_handler/mssql_handler.py +13 -4
- mindsdb/integrations/handlers/mysql_handler/mysql_handler.py +14 -5
- mindsdb/integrations/handlers/openai_handler/helpers.py +3 -5
- mindsdb/integrations/handlers/openai_handler/openai_handler.py +20 -8
- mindsdb/integrations/handlers/oracle_handler/oracle_handler.py +14 -4
- mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py +34 -19
- mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +21 -18
- mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py +14 -4
- mindsdb/integrations/handlers/togetherai_handler/__about__.py +9 -0
- mindsdb/integrations/handlers/togetherai_handler/__init__.py +20 -0
- mindsdb/integrations/handlers/togetherai_handler/creation_args.py +14 -0
- mindsdb/integrations/handlers/togetherai_handler/icon.svg +15 -0
- mindsdb/integrations/handlers/togetherai_handler/model_using_args.py +5 -0
- mindsdb/integrations/handlers/togetherai_handler/requirements.txt +2 -0
- mindsdb/integrations/handlers/togetherai_handler/settings.py +33 -0
- mindsdb/integrations/handlers/togetherai_handler/togetherai_handler.py +234 -0
- mindsdb/integrations/handlers/web_handler/urlcrawl_helpers.py +1 -1
- mindsdb/integrations/libs/response.py +80 -32
- mindsdb/integrations/utilities/handler_utils.py +4 -0
- mindsdb/integrations/utilities/rag/rerankers/base_reranker.py +360 -0
- mindsdb/integrations/utilities/rag/rerankers/reranker_compressor.py +8 -153
- mindsdb/interfaces/agents/litellm_server.py +345 -0
- mindsdb/interfaces/agents/mcp_client_agent.py +252 -0
- mindsdb/interfaces/agents/run_mcp_agent.py +205 -0
- mindsdb/interfaces/functions/controller.py +3 -2
- mindsdb/interfaces/knowledge_base/controller.py +106 -82
- mindsdb/interfaces/query_context/context_controller.py +55 -15
- mindsdb/interfaces/query_context/query_task.py +19 -0
- mindsdb/interfaces/skills/skill_tool.py +7 -1
- mindsdb/interfaces/skills/sql_agent.py +8 -3
- mindsdb/interfaces/storage/db.py +2 -2
- mindsdb/interfaces/tasks/task_monitor.py +5 -1
- mindsdb/interfaces/tasks/task_thread.py +6 -0
- mindsdb/migrations/versions/2025-04-22_53502b6d63bf_query_database.py +27 -0
- mindsdb/utilities/config.py +20 -2
- mindsdb/utilities/context.py +1 -0
- mindsdb/utilities/starters.py +7 -0
- {mindsdb-25.4.3.2.dist-info → mindsdb-25.4.5.0.dist-info}/METADATA +226 -221
- {mindsdb-25.4.3.2.dist-info → mindsdb-25.4.5.0.dist-info}/RECORD +67 -53
- {mindsdb-25.4.3.2.dist-info → mindsdb-25.4.5.0.dist-info}/WHEEL +1 -1
- mindsdb/integrations/handlers/snowflake_handler/tests/test_snowflake_handler.py +0 -230
- /mindsdb/{integrations/handlers/snowflake_handler/tests → api/litellm}/__init__.py +0 -0
- {mindsdb-25.4.3.2.dist-info → mindsdb-25.4.5.0.dist-info}/licenses/LICENSE +0 -0
- {mindsdb-25.4.3.2.dist-info → mindsdb-25.4.5.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from mindsdb.utilities import log
|
|
3
|
+
from mindsdb.utilities.config import Config
|
|
4
|
+
from mindsdb.interfaces.agents.litellm_server import run_server, run_server_async
|
|
5
|
+
|
|
6
|
+
logger = log.getLogger(__name__)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
async def start_async(verbose=False):
|
|
10
|
+
"""Start the LiteLLM server
|
|
11
|
+
|
|
12
|
+
Args:
|
|
13
|
+
verbose (bool): Whether to enable verbose logging
|
|
14
|
+
"""
|
|
15
|
+
config = Config()
|
|
16
|
+
|
|
17
|
+
# Get agent name from command line args
|
|
18
|
+
agent_name = config.cmd_args.agent
|
|
19
|
+
if not agent_name:
|
|
20
|
+
logger.error("Agent name is required for LiteLLM server. Use --agent parameter.")
|
|
21
|
+
return 1
|
|
22
|
+
|
|
23
|
+
# Get project name or use default
|
|
24
|
+
project_name = config.cmd_args.project or "mindsdb"
|
|
25
|
+
|
|
26
|
+
# Get MCP server connection details
|
|
27
|
+
mcp_host = config.get('api', {}).get('mcp', {}).get('host', '127.0.0.1')
|
|
28
|
+
mcp_port = int(config.get('api', {}).get('mcp', {}).get('port', 47337))
|
|
29
|
+
|
|
30
|
+
# Get LiteLLM server settings
|
|
31
|
+
litellm_host = config.get('api', {}).get('litellm', {}).get('host', '0.0.0.0')
|
|
32
|
+
litellm_port = int(config.get('api', {}).get('litellm', {}).get('port', 8000))
|
|
33
|
+
|
|
34
|
+
logger.info(f"Starting LiteLLM server for agent '{agent_name}' in project '{project_name}'")
|
|
35
|
+
logger.info(f"Connecting to MCP server at {mcp_host}:{mcp_port}")
|
|
36
|
+
logger.info(f"Binding to {litellm_host}:{litellm_port}")
|
|
37
|
+
|
|
38
|
+
return await run_server_async(
|
|
39
|
+
agent_name=agent_name,
|
|
40
|
+
project_name=project_name,
|
|
41
|
+
mcp_host=mcp_host,
|
|
42
|
+
mcp_port=mcp_port,
|
|
43
|
+
host=litellm_host,
|
|
44
|
+
port=litellm_port
|
|
45
|
+
)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def start(verbose=False):
|
|
49
|
+
"""Start the LiteLLM server (synchronous wrapper)
|
|
50
|
+
|
|
51
|
+
Args:
|
|
52
|
+
verbose (bool): Whether to enable verbose logging
|
|
53
|
+
"""
|
|
54
|
+
from mindsdb.interfaces.storage import db
|
|
55
|
+
db.init()
|
|
56
|
+
|
|
57
|
+
# Run the async function in the event loop
|
|
58
|
+
loop = asyncio.new_event_loop()
|
|
59
|
+
asyncio.set_event_loop(loop)
|
|
60
|
+
result = loop.run_until_complete(start_async(verbose))
|
|
61
|
+
|
|
62
|
+
if result == 0:
|
|
63
|
+
# Run the server
|
|
64
|
+
config = Config()
|
|
65
|
+
agent_name = config.cmd_args.agent
|
|
66
|
+
project_name = config.cmd_args.project or "mindsdb"
|
|
67
|
+
mcp_host = config.get('api', {}).get('mcp', {}).get('host', '127.0.0.1')
|
|
68
|
+
mcp_port = int(config.get('api', {}).get('mcp', {}).get('port', 47337))
|
|
69
|
+
litellm_host = config.get('api', {}).get('litellm', {}).get('host', '0.0.0.0')
|
|
70
|
+
litellm_port = int(config.get('api', {}).get('litellm', {}).get('port', 8000))
|
|
71
|
+
|
|
72
|
+
return run_server(
|
|
73
|
+
agent_name=agent_name,
|
|
74
|
+
project_name=project_name,
|
|
75
|
+
mcp_host=mcp_host,
|
|
76
|
+
mcp_port=mcp_port,
|
|
77
|
+
host=litellm_host,
|
|
78
|
+
port=litellm_port
|
|
79
|
+
)
|
|
80
|
+
else:
|
|
81
|
+
logger.error("LiteLLM server initialization failed")
|
|
82
|
+
return result
|
|
@@ -179,6 +179,139 @@ class MYSQL_DATA_TYPE(enum.Enum):
|
|
|
179
179
|
BOOLEAN = 'BOOLEAN'
|
|
180
180
|
|
|
181
181
|
|
|
182
|
+
# Default values for attributes of MySQL data types as they appear in information_schema.columns
|
|
183
|
+
# These values match the MySQL v8.0.37 defaults and are used to properly represent column metadata
|
|
184
|
+
MYSQL_DATA_TYPE_COLUMNS_DEFAULT = {
|
|
185
|
+
MYSQL_DATA_TYPE.TINYINT: {
|
|
186
|
+
'NUMERIC_PRECISION': 3,
|
|
187
|
+
'NUMERIC_SCALE': 0
|
|
188
|
+
},
|
|
189
|
+
MYSQL_DATA_TYPE.SMALLINT: {
|
|
190
|
+
'NUMERIC_PRECISION': 5,
|
|
191
|
+
'NUMERIC_SCALE': 0
|
|
192
|
+
},
|
|
193
|
+
MYSQL_DATA_TYPE.MEDIUMINT: {
|
|
194
|
+
'NUMERIC_PRECISION': 7,
|
|
195
|
+
'NUMERIC_SCALE': 0
|
|
196
|
+
},
|
|
197
|
+
MYSQL_DATA_TYPE.INT: {
|
|
198
|
+
'NUMERIC_PRECISION': 10,
|
|
199
|
+
'NUMERIC_SCALE': 0
|
|
200
|
+
},
|
|
201
|
+
MYSQL_DATA_TYPE.BIGINT: {
|
|
202
|
+
'NUMERIC_PRECISION': 19,
|
|
203
|
+
'NUMERIC_SCALE': 0
|
|
204
|
+
},
|
|
205
|
+
MYSQL_DATA_TYPE.FLOAT: {
|
|
206
|
+
'NUMERIC_PRECISION': 12
|
|
207
|
+
},
|
|
208
|
+
MYSQL_DATA_TYPE.DOUBLE: {
|
|
209
|
+
'NUMERIC_PRECISION': 22
|
|
210
|
+
},
|
|
211
|
+
MYSQL_DATA_TYPE.DECIMAL: {
|
|
212
|
+
'NUMERIC_PRECISION': 10,
|
|
213
|
+
'NUMERIC_SCALE': 0,
|
|
214
|
+
'COLUMN_TYPE': 'decimal(10,0)'
|
|
215
|
+
},
|
|
216
|
+
MYSQL_DATA_TYPE.YEAR: {
|
|
217
|
+
# every column is null
|
|
218
|
+
},
|
|
219
|
+
MYSQL_DATA_TYPE.TIME: {
|
|
220
|
+
'DATETIME_PRECISION': 0
|
|
221
|
+
},
|
|
222
|
+
MYSQL_DATA_TYPE.DATE: {
|
|
223
|
+
# every column is null
|
|
224
|
+
},
|
|
225
|
+
MYSQL_DATA_TYPE.DATETIME: {
|
|
226
|
+
'DATETIME_PRECISION': 0
|
|
227
|
+
},
|
|
228
|
+
MYSQL_DATA_TYPE.TIMESTAMP: {
|
|
229
|
+
'DATETIME_PRECISION': 0
|
|
230
|
+
},
|
|
231
|
+
MYSQL_DATA_TYPE.CHAR: {
|
|
232
|
+
'CHARACTER_MAXIMUM_LENGTH': 1,
|
|
233
|
+
'CHARACTER_OCTET_LENGTH': 4,
|
|
234
|
+
'CHARACTER_SET_NAME': 'utf8',
|
|
235
|
+
'COLLATION_NAME': 'utf8_bin',
|
|
236
|
+
'COLUMN_TYPE': 'char(1)'
|
|
237
|
+
},
|
|
238
|
+
MYSQL_DATA_TYPE.BINARY: {
|
|
239
|
+
'CHARACTER_MAXIMUM_LENGTH': 1,
|
|
240
|
+
'CHARACTER_OCTET_LENGTH': 1,
|
|
241
|
+
'COLUMN_TYPE': 'binary(1)'
|
|
242
|
+
},
|
|
243
|
+
MYSQL_DATA_TYPE.VARCHAR: {
|
|
244
|
+
'CHARACTER_MAXIMUM_LENGTH': 1024, # NOTE mandatory for field creation
|
|
245
|
+
'CHARACTER_OCTET_LENGTH': 4096, # NOTE mandatory for field creation
|
|
246
|
+
'CHARACTER_SET_NAME': 'utf8',
|
|
247
|
+
'COLLATION_NAME': 'utf8_bin',
|
|
248
|
+
'COLUMN_TYPE': 'varchar(1024)'
|
|
249
|
+
},
|
|
250
|
+
MYSQL_DATA_TYPE.VARBINARY: {
|
|
251
|
+
'CHARACTER_MAXIMUM_LENGTH': 1024, # NOTE mandatory for field creation
|
|
252
|
+
'CHARACTER_OCTET_LENGTH': 1024, # NOTE mandatory for field creation
|
|
253
|
+
'COLUMN_TYPE': 'varbinary(1024)'
|
|
254
|
+
},
|
|
255
|
+
MYSQL_DATA_TYPE.TINYBLOB: {
|
|
256
|
+
'CHARACTER_MAXIMUM_LENGTH': 255,
|
|
257
|
+
'CHARACTER_OCTET_LENGTH': 255
|
|
258
|
+
},
|
|
259
|
+
MYSQL_DATA_TYPE.TINYTEXT: {
|
|
260
|
+
'CHARACTER_MAXIMUM_LENGTH': 255,
|
|
261
|
+
'CHARACTER_OCTET_LENGTH': 255,
|
|
262
|
+
'CHARACTER_SET_NAME': 'utf8',
|
|
263
|
+
'COLLATION_NAME': 'utf8_bin'
|
|
264
|
+
},
|
|
265
|
+
MYSQL_DATA_TYPE.BLOB: {
|
|
266
|
+
'CHARACTER_MAXIMUM_LENGTH': 65535,
|
|
267
|
+
'CHARACTER_OCTET_LENGTH': 65535
|
|
268
|
+
},
|
|
269
|
+
MYSQL_DATA_TYPE.TEXT: {
|
|
270
|
+
'CHARACTER_MAXIMUM_LENGTH': 65535,
|
|
271
|
+
'CHARACTER_OCTET_LENGTH': 65535,
|
|
272
|
+
'CHARACTER_SET_NAME': 'utf8',
|
|
273
|
+
'COLLATION_NAME': 'utf8_bin'
|
|
274
|
+
},
|
|
275
|
+
MYSQL_DATA_TYPE.MEDIUMBLOB: {
|
|
276
|
+
'CHARACTER_MAXIMUM_LENGTH': 16777215,
|
|
277
|
+
'CHARACTER_OCTET_LENGTH': 16777215
|
|
278
|
+
},
|
|
279
|
+
MYSQL_DATA_TYPE.MEDIUMTEXT: {
|
|
280
|
+
'CHARACTER_MAXIMUM_LENGTH': 16777215,
|
|
281
|
+
'CHARACTER_OCTET_LENGTH': 16777215,
|
|
282
|
+
'CHARACTER_SET_NAME': 'utf8',
|
|
283
|
+
'COLLATION_NAME': 'utf8_bin'
|
|
284
|
+
},
|
|
285
|
+
MYSQL_DATA_TYPE.LONGBLOB: {
|
|
286
|
+
'CHARACTER_MAXIMUM_LENGTH': 4294967295,
|
|
287
|
+
'CHARACTER_OCTET_LENGTH': 4294967295,
|
|
288
|
+
},
|
|
289
|
+
MYSQL_DATA_TYPE.LONGTEXT: {
|
|
290
|
+
'CHARACTER_MAXIMUM_LENGTH': 4294967295,
|
|
291
|
+
'CHARACTER_OCTET_LENGTH': 4294967295,
|
|
292
|
+
'CHARACTER_SET_NAME': 'utf8',
|
|
293
|
+
'COLLATION_NAME': 'utf8_bin'
|
|
294
|
+
},
|
|
295
|
+
MYSQL_DATA_TYPE.BIT: {
|
|
296
|
+
'NUMERIC_PRECISION': 1,
|
|
297
|
+
'COLUMN_TYPE': 'bit(1)'
|
|
298
|
+
# 'NUMERIC_SCALE': null
|
|
299
|
+
},
|
|
300
|
+
MYSQL_DATA_TYPE.BOOL: {
|
|
301
|
+
'DATA_TYPE': 'tinyint',
|
|
302
|
+
'NUMERIC_PRECISION': 3,
|
|
303
|
+
'NUMERIC_SCALE': 0,
|
|
304
|
+
'COLUMN_TYPE': 'tinyint(1)'
|
|
305
|
+
},
|
|
306
|
+
MYSQL_DATA_TYPE.BOOLEAN: {
|
|
307
|
+
'DATA_TYPE': 'tinyint',
|
|
308
|
+
'NUMERIC_PRECISION': 3,
|
|
309
|
+
'NUMERIC_SCALE': 0,
|
|
310
|
+
'COLUMN_TYPE': 'tinyint(1)'
|
|
311
|
+
}
|
|
312
|
+
}
|
|
313
|
+
|
|
314
|
+
|
|
182
315
|
# Map between data types and C types
|
|
183
316
|
# https://dev.mysql.com/doc/c-api/8.0/en/c-api-prepared-statement-type-codes.html
|
|
184
317
|
DATA_C_TYPE_MAP = {
|
|
@@ -68,6 +68,10 @@ class ChromaDBHandler(VectorStoreHandler):
|
|
|
68
68
|
"persist_directory": self.persist_directory,
|
|
69
69
|
}
|
|
70
70
|
|
|
71
|
+
self.create_collection_metadata = {
|
|
72
|
+
"hnsw:space": config.distance,
|
|
73
|
+
}
|
|
74
|
+
|
|
71
75
|
self._use_handler_storage = False
|
|
72
76
|
|
|
73
77
|
self.connect()
|
|
@@ -398,7 +402,7 @@ class ChromaDBHandler(VectorStoreHandler):
|
|
|
398
402
|
Insert/Upsert data into ChromaDB collection.
|
|
399
403
|
If records with same IDs exist, they will be updated.
|
|
400
404
|
"""
|
|
401
|
-
collection = self._client.get_or_create_collection(collection_name)
|
|
405
|
+
collection = self._client.get_or_create_collection(collection_name, metadata=self.create_collection_metadata)
|
|
402
406
|
|
|
403
407
|
# Convert metadata from string to dict if needed
|
|
404
408
|
if TableField.METADATA.value in df.columns:
|
|
@@ -484,7 +488,8 @@ class ChromaDBHandler(VectorStoreHandler):
|
|
|
484
488
|
"""
|
|
485
489
|
Create a collection with the given name in the ChromaDB database.
|
|
486
490
|
"""
|
|
487
|
-
self._client.create_collection(table_name, get_or_create=if_not_exists
|
|
491
|
+
self._client.create_collection(table_name, get_or_create=if_not_exists,
|
|
492
|
+
metadata=self.create_collection_metadata)
|
|
488
493
|
self._sync()
|
|
489
494
|
|
|
490
495
|
def drop_table(self, table_name: str, if_exists=True):
|
|
@@ -241,14 +241,23 @@ class SqlServerHandler(DatabaseHandler):
|
|
|
241
241
|
|
|
242
242
|
query = f"""
|
|
243
243
|
SELECT
|
|
244
|
-
|
|
245
|
-
|
|
244
|
+
COLUMN_NAME,
|
|
245
|
+
DATA_TYPE,
|
|
246
|
+
ORDINAL_POSITION,
|
|
247
|
+
COLUMN_DEFAULT,
|
|
248
|
+
IS_NULLABLE,
|
|
249
|
+
CHARACTER_MAXIMUM_LENGTH,
|
|
250
|
+
CHARACTER_OCTET_LENGTH,
|
|
251
|
+
NUMERIC_PRECISION,
|
|
252
|
+
NUMERIC_SCALE,
|
|
253
|
+
DATETIME_PRECISION,
|
|
254
|
+
CHARACTER_SET_NAME,
|
|
255
|
+
COLLATION_NAME
|
|
246
256
|
FROM
|
|
247
257
|
information_schema.columns
|
|
248
258
|
WHERE
|
|
249
259
|
table_name = '{table_name}'
|
|
250
260
|
"""
|
|
251
261
|
result = self.native_query(query)
|
|
252
|
-
|
|
253
|
-
result.data_frame['mysql_data_type'] = result.data_frame['Type'].apply(_map_type)
|
|
262
|
+
result.to_columns_table_response(map_type_fn=_map_type)
|
|
254
263
|
return result
|
|
@@ -231,14 +231,23 @@ class MySQLHandler(DatabaseHandler):
|
|
|
231
231
|
"""
|
|
232
232
|
q = f"""
|
|
233
233
|
select
|
|
234
|
-
COLUMN_NAME
|
|
234
|
+
COLUMN_NAME,
|
|
235
|
+
DATA_TYPE,
|
|
236
|
+
ORDINAL_POSITION,
|
|
237
|
+
COLUMN_DEFAULT,
|
|
238
|
+
IS_NULLABLE,
|
|
239
|
+
CHARACTER_MAXIMUM_LENGTH,
|
|
240
|
+
CHARACTER_OCTET_LENGTH,
|
|
241
|
+
NUMERIC_PRECISION,
|
|
242
|
+
NUMERIC_SCALE,
|
|
243
|
+
DATETIME_PRECISION,
|
|
244
|
+
CHARACTER_SET_NAME,
|
|
245
|
+
COLLATION_NAME
|
|
235
246
|
from
|
|
236
247
|
information_schema.columns
|
|
237
248
|
where
|
|
238
|
-
table_name = '{table_name}'
|
|
249
|
+
table_name = '{table_name}';
|
|
239
250
|
"""
|
|
240
251
|
result = self.native_query(q)
|
|
241
|
-
|
|
242
|
-
result.data_frame = result.data_frame.rename(columns={'FIELD': 'Field', 'TYPE': 'Type'})
|
|
243
|
-
result.data_frame['mysql_data_type'] = result.data_frame['Type'].apply(_map_type)
|
|
252
|
+
result.to_columns_table_response(map_type_fn=_map_type)
|
|
244
253
|
return result
|
|
@@ -4,7 +4,6 @@ import time
|
|
|
4
4
|
import math
|
|
5
5
|
|
|
6
6
|
import openai
|
|
7
|
-
from openai import OpenAI
|
|
8
7
|
|
|
9
8
|
import tiktoken
|
|
10
9
|
|
|
@@ -181,17 +180,16 @@ def count_tokens(messages: List[Dict], encoder: tiktoken.core.Encoding, model_na
|
|
|
181
180
|
)
|
|
182
181
|
|
|
183
182
|
|
|
184
|
-
def get_available_models(
|
|
183
|
+
def get_available_models(client) -> List[Text]:
|
|
185
184
|
"""
|
|
186
185
|
Returns a list of available openai models for the given API key.
|
|
187
186
|
|
|
188
187
|
Args:
|
|
189
|
-
|
|
190
|
-
api_base (Text): OpenAI API base URL
|
|
188
|
+
client: openai sdk client
|
|
191
189
|
|
|
192
190
|
Returns:
|
|
193
191
|
List[Text]: List of available models
|
|
194
192
|
"""
|
|
195
|
-
res =
|
|
193
|
+
res = client.models.list()
|
|
196
194
|
|
|
197
195
|
return [models.id for models in res.data]
|
|
@@ -9,7 +9,7 @@ import subprocess
|
|
|
9
9
|
import concurrent.futures
|
|
10
10
|
from typing import Text, Tuple, Dict, List, Optional, Any
|
|
11
11
|
import openai
|
|
12
|
-
from openai import OpenAI, NotFoundError, AuthenticationError
|
|
12
|
+
from openai import OpenAI, AzureOpenAI, NotFoundError, AuthenticationError
|
|
13
13
|
import numpy as np
|
|
14
14
|
import pandas as pd
|
|
15
15
|
|
|
@@ -87,7 +87,7 @@ class OpenAIHandler(BaseMLEngine):
|
|
|
87
87
|
if api_key is not None:
|
|
88
88
|
org = connection_args.get('api_organization')
|
|
89
89
|
api_base = connection_args.get('api_base') or os.environ.get('OPENAI_API_BASE', OPENAI_API_BASE)
|
|
90
|
-
client = self._get_client(api_key=api_key, base_url=api_base, org=org)
|
|
90
|
+
client = self._get_client(api_key=api_key, base_url=api_base, org=org, args=connection_args)
|
|
91
91
|
OpenAIHandler._check_client_connection(client)
|
|
92
92
|
|
|
93
93
|
@staticmethod
|
|
@@ -188,7 +188,9 @@ class OpenAIHandler(BaseMLEngine):
|
|
|
188
188
|
"temperature",
|
|
189
189
|
"openai_api_key",
|
|
190
190
|
"api_organization",
|
|
191
|
-
"api_base"
|
|
191
|
+
"api_base",
|
|
192
|
+
"api_version",
|
|
193
|
+
"provider",
|
|
192
194
|
}
|
|
193
195
|
)
|
|
194
196
|
|
|
@@ -204,7 +206,7 @@ class OpenAIHandler(BaseMLEngine):
|
|
|
204
206
|
api_key = get_api_key('openai', args, engine_storage=engine_storage)
|
|
205
207
|
api_base = args.get('api_base') or connection_args.get('api_base') or os.environ.get('OPENAI_API_BASE', OPENAI_API_BASE)
|
|
206
208
|
org = args.get('api_organization')
|
|
207
|
-
client = OpenAIHandler._get_client(api_key=api_key, base_url=api_base, org=org)
|
|
209
|
+
client = OpenAIHandler._get_client(api_key=api_key, base_url=api_base, org=org, args=args)
|
|
208
210
|
OpenAIHandler._check_client_connection(client)
|
|
209
211
|
|
|
210
212
|
def create(self, target, args: Dict = None, **kwargs: Any) -> None:
|
|
@@ -228,7 +230,8 @@ class OpenAIHandler(BaseMLEngine):
|
|
|
228
230
|
api_key = get_api_key(self.api_key_name, args, self.engine_storage)
|
|
229
231
|
connection_args = self.engine_storage.get_connection_args()
|
|
230
232
|
api_base = args.get('api_base') or connection_args.get('api_base') or os.environ.get('OPENAI_API_BASE') or self.api_base
|
|
231
|
-
|
|
233
|
+
client = self._get_client(api_key=api_key, base_url=api_base, org=args.get('api_organization'), args=args)
|
|
234
|
+
available_models = get_available_models(client)
|
|
232
235
|
|
|
233
236
|
if not args.get('mode'):
|
|
234
237
|
args['mode'] = self.default_mode
|
|
@@ -810,6 +813,7 @@ class OpenAIHandler(BaseMLEngine):
|
|
|
810
813
|
api_key=api_key,
|
|
811
814
|
base_url=args.get('api_base'),
|
|
812
815
|
org=args.pop('api_organization') if 'api_organization' in args else None,
|
|
816
|
+
args=args
|
|
813
817
|
)
|
|
814
818
|
|
|
815
819
|
try:
|
|
@@ -891,7 +895,8 @@ class OpenAIHandler(BaseMLEngine):
|
|
|
891
895
|
client = self._get_client(
|
|
892
896
|
api_key=api_key,
|
|
893
897
|
base_url=args.get('api_base'),
|
|
894
|
-
org=args.get('api_organization')
|
|
898
|
+
org=args.get('api_organization'),
|
|
899
|
+
args=args,
|
|
895
900
|
)
|
|
896
901
|
meta = client.models.retrieve(model_name)
|
|
897
902
|
except Exception as e:
|
|
@@ -935,7 +940,7 @@ class OpenAIHandler(BaseMLEngine):
|
|
|
935
940
|
|
|
936
941
|
api_base = using_args.get('api_base', os.environ.get('OPENAI_API_BASE', OPENAI_API_BASE))
|
|
937
942
|
org = using_args.get('api_organization')
|
|
938
|
-
client = self._get_client(api_key=api_key, base_url=api_base, org=org)
|
|
943
|
+
client = self._get_client(api_key=api_key, base_url=api_base, org=org, args=args)
|
|
939
944
|
|
|
940
945
|
args = {**using_args, **args}
|
|
941
946
|
prev_model_name = self.base_model_storage.json_get('args').get('model_name', '')
|
|
@@ -1173,7 +1178,7 @@ class OpenAIHandler(BaseMLEngine):
|
|
|
1173
1178
|
return ft_stats, result_file_id
|
|
1174
1179
|
|
|
1175
1180
|
@staticmethod
|
|
1176
|
-
def _get_client(api_key: Text, base_url: Text, org: Optional[Text] = None) -> OpenAI:
|
|
1181
|
+
def _get_client(api_key: Text, base_url: Text, org: Optional[Text] = None, args: dict = None) -> OpenAI:
|
|
1177
1182
|
"""
|
|
1178
1183
|
Get an OpenAI client with the given API key, base URL, and organization.
|
|
1179
1184
|
|
|
@@ -1185,4 +1190,11 @@ class OpenAIHandler(BaseMLEngine):
|
|
|
1185
1190
|
Returns:
|
|
1186
1191
|
openai.OpenAI: OpenAI client.
|
|
1187
1192
|
"""
|
|
1193
|
+
if args is not None and args.get('provider') == 'azure':
|
|
1194
|
+
return AzureOpenAI(
|
|
1195
|
+
api_key=api_key,
|
|
1196
|
+
azure_endpoint=base_url,
|
|
1197
|
+
api_version=args.get('api_version'),
|
|
1198
|
+
organization=org
|
|
1199
|
+
)
|
|
1188
1200
|
return OpenAI(api_key=api_key, base_url=base_url, organization=org)
|
|
@@ -282,13 +282,23 @@ class OracleHandler(DatabaseHandler):
|
|
|
282
282
|
"""
|
|
283
283
|
query = f"""
|
|
284
284
|
SELECT
|
|
285
|
-
|
|
286
|
-
|
|
285
|
+
COLUMN_NAME,
|
|
286
|
+
DATA_TYPE,
|
|
287
|
+
COLUMN_ID AS ORDINAL_POSITION,
|
|
288
|
+
DATA_DEFAULT AS COLUMN_DEFAULT,
|
|
289
|
+
CASE NULLABLE WHEN 'Y' THEN 'YES' ELSE 'NO' END AS IS_NULLABLE,
|
|
290
|
+
CHAR_LENGTH AS CHARACTER_MAXIMUM_LENGTH,
|
|
291
|
+
NULL AS CHARACTER_OCTET_LENGTH,
|
|
292
|
+
DATA_PRECISION AS NUMERIC_PRECISION,
|
|
293
|
+
DATA_SCALE AS NUMERIC_SCALE,
|
|
294
|
+
NULL AS DATETIME_PRECISION,
|
|
295
|
+
CHARACTER_SET_NAME,
|
|
296
|
+
NULL AS COLLATION_NAME
|
|
287
297
|
FROM USER_TAB_COLUMNS
|
|
288
298
|
WHERE table_name = '{table_name}'
|
|
299
|
+
ORDER BY TABLE_NAME, COLUMN_ID;
|
|
289
300
|
"""
|
|
290
301
|
result = self.native_query(query)
|
|
291
302
|
if result.resp_type is RESPONSE_TYPE.TABLE:
|
|
292
|
-
result.
|
|
293
|
-
result.data_frame['mysql_data_type'] = result.data_frame['type'].apply(_map_type)
|
|
303
|
+
result.to_columns_table_response(map_type_fn=_map_type)
|
|
294
304
|
return result
|
|
@@ -40,8 +40,31 @@ class PgVectorHandler(PostgresHandler, VectorStoreHandler):
|
|
|
40
40
|
# we get these from the connection args on PostgresHandler parent
|
|
41
41
|
self._is_sparse = self.connection_args.get('is_sparse', False)
|
|
42
42
|
self._vector_size = self.connection_args.get('vector_size', None)
|
|
43
|
-
|
|
44
|
-
|
|
43
|
+
|
|
44
|
+
if self._is_sparse:
|
|
45
|
+
if not self._vector_size:
|
|
46
|
+
raise ValueError("vector_size is required when is_sparse=True")
|
|
47
|
+
|
|
48
|
+
# Use inner product for sparse vectors
|
|
49
|
+
distance_op = "<#>"
|
|
50
|
+
|
|
51
|
+
else:
|
|
52
|
+
distance_op = '<=>'
|
|
53
|
+
if 'distance' in self.connection_args:
|
|
54
|
+
distance_ops = {
|
|
55
|
+
'l1': '<+>',
|
|
56
|
+
'l2': '<->',
|
|
57
|
+
'ip': '<#>', # inner product
|
|
58
|
+
'cosine': '<=>',
|
|
59
|
+
'hamming': '<~>',
|
|
60
|
+
'jaccard': '<%>'
|
|
61
|
+
}
|
|
62
|
+
|
|
63
|
+
distance_op = distance_ops.get(self.connection_args['distance'])
|
|
64
|
+
if distance_op is None:
|
|
65
|
+
raise ValueError(f'Wrong distance type. Allowed options are {list(distance_ops.keys())}')
|
|
66
|
+
|
|
67
|
+
self.distance_op = distance_op
|
|
45
68
|
self.connect()
|
|
46
69
|
|
|
47
70
|
def _make_connection_args(self):
|
|
@@ -224,20 +247,16 @@ class PgVectorHandler(PostgresHandler, VectorStoreHandler):
|
|
|
224
247
|
from pgvector.utils import SparseVector
|
|
225
248
|
embedding = SparseVector(search_vector, self._vector_size)
|
|
226
249
|
search_vector = embedding.to_text()
|
|
227
|
-
# Use inner product for sparse vectors
|
|
228
|
-
distance_op = "<#>"
|
|
229
250
|
else:
|
|
230
251
|
# Convert list to vector string if needed
|
|
231
252
|
if isinstance(search_vector, list):
|
|
232
253
|
search_vector = f"[{','.join(str(x) for x in search_vector)}]"
|
|
233
|
-
# Use cosine similarity for dense vectors
|
|
234
|
-
distance_op = "<=>"
|
|
235
254
|
|
|
236
255
|
# Calculate distance as part of the query if needed
|
|
237
256
|
if has_distance:
|
|
238
|
-
targets = f"{targets}, (embeddings {distance_op} '{search_vector}') as distance"
|
|
257
|
+
targets = f"{targets}, (embeddings {self.distance_op} '{search_vector}') as distance"
|
|
239
258
|
|
|
240
|
-
return f"SELECT {targets} FROM {table_name} {where_clause} ORDER BY embeddings {distance_op} '{search_vector}' ASC {limit_clause} {offset_clause} "
|
|
259
|
+
return f"SELECT {targets} FROM {table_name} {where_clause} ORDER BY embeddings {self.distance_op} '{search_vector}' ASC {limit_clause} {offset_clause} "
|
|
241
260
|
|
|
242
261
|
else:
|
|
243
262
|
# if filter conditions, return rows that satisfy the conditions
|
|
@@ -418,18 +437,14 @@ class PgVectorHandler(PostgresHandler, VectorStoreHandler):
|
|
|
418
437
|
"""
|
|
419
438
|
table_name = self._check_table(table_name)
|
|
420
439
|
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
if 'metadata' in data_dict:
|
|
424
|
-
data_dict['metadata'] = [json.dumps(i) for i in data_dict['metadata']]
|
|
425
|
-
transposed_data = list(zip(*data_dict.values()))
|
|
426
|
-
|
|
427
|
-
columns = ", ".join(data.keys())
|
|
428
|
-
values = ", ".join(["%s"] * len(data.keys()))
|
|
440
|
+
if 'metadata' in data.columns:
|
|
441
|
+
data['metadata'] = data['metadata'].apply(json.dumps)
|
|
429
442
|
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
443
|
+
resp = super().insert(table_name, data)
|
|
444
|
+
if resp.resp_type == RESPONSE_TYPE.ERROR:
|
|
445
|
+
raise RuntimeError(resp.error_message)
|
|
446
|
+
if resp.resp_type == RESPONSE_TYPE.TABLE:
|
|
447
|
+
return resp.data_frame
|
|
433
448
|
|
|
434
449
|
def update(
|
|
435
450
|
self, table_name: str, data: pd.DataFrame, key_columns: List[str] = None
|
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
import time
|
|
2
2
|
import json
|
|
3
3
|
from typing import Optional
|
|
4
|
-
import threading
|
|
5
4
|
|
|
6
5
|
import pandas as pd
|
|
7
6
|
import psycopg
|
|
@@ -43,7 +42,8 @@ def _map_type(internal_type_name: str) -> MYSQL_DATA_TYPE:
|
|
|
43
42
|
('real', 'money', 'float'): MYSQL_DATA_TYPE.FLOAT,
|
|
44
43
|
('numeric', 'decimal'): MYSQL_DATA_TYPE.DECIMAL,
|
|
45
44
|
('double precision',): MYSQL_DATA_TYPE.DOUBLE,
|
|
46
|
-
('character varying', 'varchar'
|
|
45
|
+
('character varying', 'varchar'): MYSQL_DATA_TYPE.VARCHAR,
|
|
46
|
+
('character', 'char', 'bpchar', 'bpchar', 'text'): MYSQL_DATA_TYPE.TEXT,
|
|
47
47
|
('timestamp', 'timestamp without time zone', 'timestamp with time zone'): MYSQL_DATA_TYPE.DATETIME,
|
|
48
48
|
('date', ): MYSQL_DATA_TYPE.DATE,
|
|
49
49
|
('time', 'time without time zone', 'time with time zone'): MYSQL_DATA_TYPE.TIME,
|
|
@@ -76,9 +76,7 @@ class PostgresHandler(DatabaseHandler):
|
|
|
76
76
|
|
|
77
77
|
self.connection = None
|
|
78
78
|
self.is_connected = False
|
|
79
|
-
self.thread_safe =
|
|
80
|
-
|
|
81
|
-
self._insert_lock = threading.Lock()
|
|
79
|
+
self.thread_safe = False
|
|
82
80
|
|
|
83
81
|
def __del__(self):
|
|
84
82
|
if self.is_connected:
|
|
@@ -266,15 +264,13 @@ class PostgresHandler(DatabaseHandler):
|
|
|
266
264
|
|
|
267
265
|
columns = df.columns
|
|
268
266
|
|
|
269
|
-
|
|
270
|
-
with self._insert_lock:
|
|
271
|
-
resp = self.get_columns(table_name)
|
|
267
|
+
resp = self.get_columns(table_name)
|
|
272
268
|
|
|
273
269
|
# copy requires precise cases of names: get current column names from table and adapt input dataframe columns
|
|
274
270
|
if resp.data_frame is not None and not resp.data_frame.empty:
|
|
275
271
|
db_columns = {
|
|
276
272
|
c.lower(): c
|
|
277
|
-
for c in resp.data_frame['
|
|
273
|
+
for c in resp.data_frame['COLUMN_NAME']
|
|
278
274
|
}
|
|
279
275
|
|
|
280
276
|
# try to get case of existing column
|
|
@@ -288,11 +284,10 @@ class PostgresHandler(DatabaseHandler):
|
|
|
288
284
|
|
|
289
285
|
with connection.cursor() as cur:
|
|
290
286
|
try:
|
|
291
|
-
with
|
|
292
|
-
|
|
293
|
-
df.to_csv(copy, index=False, header=False)
|
|
287
|
+
with cur.copy(f'copy "{table_name}" ({",".join(columns)}) from STDIN WITH CSV') as copy:
|
|
288
|
+
df.to_csv(copy, index=False, header=False)
|
|
294
289
|
|
|
295
|
-
|
|
290
|
+
connection.commit()
|
|
296
291
|
except Exception as e:
|
|
297
292
|
logger.error(f'Error running insert to {table_name} on {self.database}, {e}!')
|
|
298
293
|
connection.rollback()
|
|
@@ -366,8 +361,18 @@ class PostgresHandler(DatabaseHandler):
|
|
|
366
361
|
schema_name = 'current_schema()'
|
|
367
362
|
query = f"""
|
|
368
363
|
SELECT
|
|
369
|
-
|
|
370
|
-
|
|
364
|
+
COLUMN_NAME,
|
|
365
|
+
DATA_TYPE,
|
|
366
|
+
ORDINAL_POSITION,
|
|
367
|
+
COLUMN_DEFAULT,
|
|
368
|
+
IS_NULLABLE,
|
|
369
|
+
CHARACTER_MAXIMUM_LENGTH,
|
|
370
|
+
CHARACTER_OCTET_LENGTH,
|
|
371
|
+
NUMERIC_PRECISION,
|
|
372
|
+
NUMERIC_SCALE,
|
|
373
|
+
DATETIME_PRECISION,
|
|
374
|
+
CHARACTER_SET_NAME,
|
|
375
|
+
COLLATION_NAME
|
|
371
376
|
FROM
|
|
372
377
|
information_schema.columns
|
|
373
378
|
WHERE
|
|
@@ -376,9 +381,7 @@ class PostgresHandler(DatabaseHandler):
|
|
|
376
381
|
table_schema = {schema_name}
|
|
377
382
|
"""
|
|
378
383
|
result = self.native_query(query)
|
|
379
|
-
|
|
380
|
-
result.data_frame.columns = [name.lower() for name in result.data_frame.columns]
|
|
381
|
-
result.data_frame['mysql_data_type'] = result.data_frame['type'].apply(_map_type)
|
|
384
|
+
result.to_columns_table_response(map_type_fn=_map_type)
|
|
382
385
|
return result
|
|
383
386
|
|
|
384
387
|
def subscribe(self, stop_event, callback, table_name, columns=None, **kwargs):
|