MindsDB 25.5.4.0__py3-none-any.whl → 25.5.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of MindsDB might be problematic. Click here for more details.
- mindsdb/__about__.py +8 -8
- mindsdb/api/a2a/__main__.py +38 -8
- mindsdb/api/a2a/run_a2a.py +10 -53
- mindsdb/api/a2a/task_manager.py +19 -53
- mindsdb/api/executor/command_executor.py +147 -291
- mindsdb/api/http/namespaces/config.py +61 -86
- mindsdb/integrations/handlers/byom_handler/requirements.txt +1 -2
- mindsdb/integrations/handlers/lancedb_handler/requirements.txt +0 -1
- mindsdb/integrations/handlers/litellm_handler/litellm_handler.py +37 -20
- mindsdb/integrations/libs/llm/config.py +13 -0
- mindsdb/integrations/libs/llm/utils.py +37 -65
- mindsdb/integrations/utilities/rag/rerankers/base_reranker.py +230 -227
- mindsdb/interfaces/agents/constants.py +17 -13
- mindsdb/interfaces/agents/langchain_agent.py +93 -94
- mindsdb/interfaces/knowledge_base/controller.py +230 -221
- mindsdb/utilities/config.py +43 -84
- mindsdb/utilities/starters.py +9 -1
- {mindsdb-25.5.4.0.dist-info → mindsdb-25.5.4.2.dist-info}/METADATA +268 -266
- {mindsdb-25.5.4.0.dist-info → mindsdb-25.5.4.2.dist-info}/RECORD +22 -26
- mindsdb/api/a2a/a2a_client.py +0 -439
- mindsdb/api/a2a/common/client/__init__.py +0 -4
- mindsdb/api/a2a/common/client/card_resolver.py +0 -21
- mindsdb/api/a2a/common/client/client.py +0 -86
- {mindsdb-25.5.4.0.dist-info → mindsdb-25.5.4.2.dist-info}/WHEEL +0 -0
- {mindsdb-25.5.4.0.dist-info → mindsdb-25.5.4.2.dist-info}/licenses/LICENSE +0 -0
- {mindsdb-25.5.4.0.dist-info → mindsdb-25.5.4.2.dist-info}/top_level.txt +0 -0
|
@@ -20,121 +20,109 @@ from mindsdb.integrations.libs.response import HandlerStatusResponse
|
|
|
20
20
|
logger = log.getLogger(__name__)
|
|
21
21
|
|
|
22
22
|
|
|
23
|
-
@ns_conf.route(
|
|
24
|
-
@ns_conf.param(
|
|
23
|
+
@ns_conf.route("/")
|
|
24
|
+
@ns_conf.param("name", "Get config")
|
|
25
25
|
class GetConfig(Resource):
|
|
26
|
-
@ns_conf.doc(
|
|
27
|
-
@api_endpoint_metrics(
|
|
26
|
+
@ns_conf.doc("get_config")
|
|
27
|
+
@api_endpoint_metrics("GET", "/config")
|
|
28
28
|
def get(self):
|
|
29
29
|
config = Config()
|
|
30
|
-
resp = {
|
|
31
|
-
|
|
32
|
-
'http_auth_enabled': config['auth']['http_auth_enabled']
|
|
33
|
-
}
|
|
34
|
-
}
|
|
35
|
-
for key in ['default_llm', 'default_embedding_model']:
|
|
30
|
+
resp = {"auth": {"http_auth_enabled": config["auth"]["http_auth_enabled"]}}
|
|
31
|
+
for key in ["default_llm", "default_embedding_model", "default_reranking_model"]:
|
|
36
32
|
value = config.get(key)
|
|
37
33
|
if value is not None:
|
|
38
34
|
resp[key] = value
|
|
39
35
|
return resp
|
|
40
36
|
|
|
41
|
-
@ns_conf.doc(
|
|
42
|
-
@api_endpoint_metrics(
|
|
37
|
+
@ns_conf.doc("put_config")
|
|
38
|
+
@api_endpoint_metrics("PUT", "/config")
|
|
43
39
|
def put(self):
|
|
44
40
|
data = request.json
|
|
45
41
|
|
|
46
|
-
allowed_arguments = {
|
|
42
|
+
allowed_arguments = {"auth", "default_llm", "default_embedding_model", "default_reranking_model"}
|
|
47
43
|
unknown_arguments = list(set(data.keys()) - allowed_arguments)
|
|
48
44
|
if len(unknown_arguments) > 0:
|
|
49
|
-
return http_error(
|
|
50
|
-
HTTPStatus.BAD_REQUEST, 'Wrong arguments',
|
|
51
|
-
f'Unknown argumens: {unknown_arguments}'
|
|
52
|
-
)
|
|
45
|
+
return http_error(HTTPStatus.BAD_REQUEST, "Wrong arguments", f"Unknown argumens: {unknown_arguments}")
|
|
53
46
|
|
|
54
|
-
nested_keys_to_validate = {
|
|
47
|
+
nested_keys_to_validate = {"auth"}
|
|
55
48
|
for key in data.keys():
|
|
56
49
|
if key in nested_keys_to_validate:
|
|
57
|
-
unknown_arguments = list(
|
|
58
|
-
set(data[key].keys()) - set(Config()[key].keys())
|
|
59
|
-
)
|
|
50
|
+
unknown_arguments = list(set(data[key].keys()) - set(Config()[key].keys()))
|
|
60
51
|
if len(unknown_arguments) > 0:
|
|
61
52
|
return http_error(
|
|
62
|
-
HTTPStatus.BAD_REQUEST,
|
|
63
|
-
f'Unknown argumens: {unknown_arguments}'
|
|
53
|
+
HTTPStatus.BAD_REQUEST, "Wrong arguments", f"Unknown argumens: {unknown_arguments}"
|
|
64
54
|
)
|
|
65
55
|
|
|
66
56
|
Config().update(data)
|
|
67
57
|
|
|
68
|
-
return
|
|
58
|
+
return "", 200
|
|
69
59
|
|
|
70
60
|
|
|
71
|
-
@ns_conf.route(
|
|
72
|
-
@ns_conf.param(
|
|
61
|
+
@ns_conf.route("/integrations")
|
|
62
|
+
@ns_conf.param("name", "List all database integration")
|
|
73
63
|
class ListIntegration(Resource):
|
|
74
|
-
@api_endpoint_metrics(
|
|
64
|
+
@api_endpoint_metrics("GET", "/config/integrations")
|
|
75
65
|
def get(self):
|
|
76
|
-
return {
|
|
77
|
-
'integrations': [k for k in ca.integration_controller.get_all(show_secrets=False)]
|
|
78
|
-
}
|
|
66
|
+
return {"integrations": [k for k in ca.integration_controller.get_all(show_secrets=False)]}
|
|
79
67
|
|
|
80
68
|
|
|
81
|
-
@ns_conf.route(
|
|
82
|
-
@ns_conf.param(
|
|
69
|
+
@ns_conf.route("/all_integrations")
|
|
70
|
+
@ns_conf.param("name", "List all database integration")
|
|
83
71
|
class AllIntegration(Resource):
|
|
84
|
-
@ns_conf.doc(
|
|
85
|
-
@api_endpoint_metrics(
|
|
72
|
+
@ns_conf.doc("get_all_integrations")
|
|
73
|
+
@api_endpoint_metrics("GET", "/config/all_integrations")
|
|
86
74
|
def get(self):
|
|
87
75
|
integrations = ca.integration_controller.get_all(show_secrets=False)
|
|
88
76
|
return integrations
|
|
89
77
|
|
|
90
78
|
|
|
91
|
-
@ns_conf.route(
|
|
92
|
-
@ns_conf.param(
|
|
79
|
+
@ns_conf.route("/integrations/<name>")
|
|
80
|
+
@ns_conf.param("name", "Database integration")
|
|
93
81
|
class Integration(Resource):
|
|
94
|
-
@ns_conf.doc(
|
|
95
|
-
@api_endpoint_metrics(
|
|
82
|
+
@ns_conf.doc("get_integration")
|
|
83
|
+
@api_endpoint_metrics("GET", "/config/integrations/integration")
|
|
96
84
|
def get(self, name):
|
|
97
85
|
integration = ca.integration_controller.get(name, show_secrets=False)
|
|
98
86
|
if integration is None:
|
|
99
|
-
return http_error(HTTPStatus.NOT_FOUND,
|
|
87
|
+
return http_error(HTTPStatus.NOT_FOUND, "Not found", f"Can't find integration: {name}")
|
|
100
88
|
integration = copy.deepcopy(integration)
|
|
101
89
|
return integration
|
|
102
90
|
|
|
103
|
-
@ns_conf.doc(
|
|
104
|
-
@api_endpoint_metrics(
|
|
91
|
+
@ns_conf.doc("put_integration")
|
|
92
|
+
@api_endpoint_metrics("PUT", "/config/integrations/integration")
|
|
105
93
|
def put(self, name):
|
|
106
94
|
params = {}
|
|
107
95
|
if request.is_json:
|
|
108
|
-
params.update((request.json or {}).get(
|
|
96
|
+
params.update((request.json or {}).get("params", {}))
|
|
109
97
|
else:
|
|
110
98
|
params.update(request.form or {})
|
|
111
99
|
|
|
112
100
|
if len(params) == 0:
|
|
113
|
-
return http_error(HTTPStatus.BAD_REQUEST,
|
|
101
|
+
return http_error(HTTPStatus.BAD_REQUEST, "Wrong argument", "type of 'params' must be dict")
|
|
114
102
|
|
|
115
103
|
files = request.files
|
|
116
104
|
temp_dir = None
|
|
117
105
|
if files is not None and len(files) > 0:
|
|
118
|
-
temp_dir = tempfile.mkdtemp(prefix=
|
|
106
|
+
temp_dir = tempfile.mkdtemp(prefix="integration_files_")
|
|
119
107
|
for key, file in files.items():
|
|
120
108
|
temp_dir_path = Path(temp_dir)
|
|
121
109
|
file_name = Path(file.filename)
|
|
122
110
|
file_path = temp_dir_path.joinpath(file_name).resolve()
|
|
123
111
|
if temp_dir_path not in file_path.parents:
|
|
124
|
-
raise Exception(f
|
|
112
|
+
raise Exception(f"Can not save file at path: {file_path}")
|
|
125
113
|
file.save(file_path)
|
|
126
114
|
params[key] = str(file_path)
|
|
127
115
|
|
|
128
|
-
is_test = params.get(
|
|
116
|
+
is_test = params.get("test", False)
|
|
129
117
|
# TODO: Move this to new Endpoint
|
|
130
118
|
|
|
131
119
|
config = Config()
|
|
132
|
-
secret_key = config.get(
|
|
120
|
+
secret_key = config.get("secret_key", "dummy-key")
|
|
133
121
|
|
|
134
122
|
if is_test:
|
|
135
|
-
del params[
|
|
136
|
-
handler_type = params.pop(
|
|
137
|
-
params.pop(
|
|
123
|
+
del params["test"]
|
|
124
|
+
handler_type = params.pop("type", None)
|
|
125
|
+
params.pop("publish", None)
|
|
138
126
|
try:
|
|
139
127
|
handler = ca.integration_controller.create_tmp_handler(name, handler_type, params)
|
|
140
128
|
status = handler.check_connection()
|
|
@@ -145,33 +133,32 @@ class Integration(Resource):
|
|
|
145
133
|
|
|
146
134
|
resp = status.to_json()
|
|
147
135
|
|
|
148
|
-
if status.success and
|
|
149
|
-
if hasattr(handler,
|
|
136
|
+
if status.success and "code" in params:
|
|
137
|
+
if hasattr(handler, "handler_storage"):
|
|
150
138
|
# attach storage if exists
|
|
151
139
|
export = handler.handler_storage.export_files()
|
|
152
140
|
if export:
|
|
153
141
|
# encrypt with flask secret key
|
|
154
142
|
encrypted = encrypt(export, secret_key)
|
|
155
|
-
resp[
|
|
143
|
+
resp["storage"] = encrypted.decode()
|
|
156
144
|
|
|
157
145
|
return resp, 200
|
|
158
146
|
|
|
159
147
|
config = Config()
|
|
160
|
-
secret_key = config.get(
|
|
148
|
+
secret_key = config.get("secret_key", "dummy-key")
|
|
161
149
|
|
|
162
150
|
integration = ca.integration_controller.get(name, show_secrets=False)
|
|
163
151
|
if integration is not None:
|
|
164
152
|
return http_error(
|
|
165
|
-
HTTPStatus.BAD_REQUEST,
|
|
166
|
-
f"Integration with name '{name}' already exists"
|
|
153
|
+
HTTPStatus.BAD_REQUEST, "Wrong argument", f"Integration with name '{name}' already exists"
|
|
167
154
|
)
|
|
168
155
|
|
|
169
156
|
try:
|
|
170
|
-
engine = params[
|
|
157
|
+
engine = params["type"]
|
|
171
158
|
if engine is not None:
|
|
172
|
-
del params[
|
|
173
|
-
params.pop(
|
|
174
|
-
storage = params.pop(
|
|
159
|
+
del params["type"]
|
|
160
|
+
params.pop("publish", False)
|
|
161
|
+
storage = params.pop("storage", None)
|
|
175
162
|
ca.integration_controller.add(name, engine, params)
|
|
176
163
|
|
|
177
164
|
# copy storage
|
|
@@ -185,62 +172,50 @@ class Integration(Resource):
|
|
|
185
172
|
logger.error(str(e))
|
|
186
173
|
if temp_dir is not None:
|
|
187
174
|
shutil.rmtree(temp_dir)
|
|
188
|
-
return http_error(
|
|
189
|
-
HTTPStatus.INTERNAL_SERVER_ERROR, 'Error',
|
|
190
|
-
f'Error during config update: {str(e)}'
|
|
191
|
-
)
|
|
175
|
+
return http_error(HTTPStatus.INTERNAL_SERVER_ERROR, "Error", f"Error during config update: {str(e)}")
|
|
192
176
|
|
|
193
177
|
if temp_dir is not None:
|
|
194
178
|
shutil.rmtree(temp_dir)
|
|
195
179
|
return {}, 200
|
|
196
180
|
|
|
197
|
-
@ns_conf.doc(
|
|
198
|
-
@api_endpoint_metrics(
|
|
181
|
+
@ns_conf.doc("delete_integration")
|
|
182
|
+
@api_endpoint_metrics("DELETE", "/config/integrations/integration")
|
|
199
183
|
def delete(self, name):
|
|
200
184
|
integration = ca.integration_controller.get(name)
|
|
201
185
|
if integration is None:
|
|
202
186
|
return http_error(
|
|
203
|
-
HTTPStatus.BAD_REQUEST,
|
|
204
|
-
f"Nothing to delete. '{name}' not exists."
|
|
187
|
+
HTTPStatus.BAD_REQUEST, "Integration does not exists", f"Nothing to delete. '{name}' not exists."
|
|
205
188
|
)
|
|
206
189
|
try:
|
|
207
190
|
ca.integration_controller.delete(name)
|
|
208
191
|
except Exception as e:
|
|
209
192
|
logger.error(str(e))
|
|
210
|
-
return http_error(
|
|
211
|
-
HTTPStatus.INTERNAL_SERVER_ERROR, 'Error',
|
|
212
|
-
f"Error during integration delete: {str(e)}"
|
|
213
|
-
)
|
|
193
|
+
return http_error(HTTPStatus.INTERNAL_SERVER_ERROR, "Error", f"Error during integration delete: {str(e)}")
|
|
214
194
|
return "", 200
|
|
215
195
|
|
|
216
|
-
@ns_conf.doc(
|
|
217
|
-
@api_endpoint_metrics(
|
|
196
|
+
@ns_conf.doc("modify_integration")
|
|
197
|
+
@api_endpoint_metrics("POST", "/config/integrations/integration")
|
|
218
198
|
def post(self, name):
|
|
219
199
|
params = {}
|
|
220
|
-
params.update((request.json or {}).get(
|
|
200
|
+
params.update((request.json or {}).get("params", {}))
|
|
221
201
|
params.update(request.form or {})
|
|
222
202
|
|
|
223
203
|
if not isinstance(params, dict):
|
|
224
|
-
return http_error(
|
|
225
|
-
HTTPStatus.BAD_REQUEST, 'Wrong argument',
|
|
226
|
-
"type of 'params' must be dict"
|
|
227
|
-
)
|
|
204
|
+
return http_error(HTTPStatus.BAD_REQUEST, "Wrong argument", "type of 'params' must be dict")
|
|
228
205
|
integration = ca.integration_controller.get(name)
|
|
229
206
|
if integration is None:
|
|
230
207
|
return http_error(
|
|
231
|
-
HTTPStatus.BAD_REQUEST,
|
|
232
|
-
f"Nothin to modify. '{name}' not exists."
|
|
208
|
+
HTTPStatus.BAD_REQUEST, "Integration does not exists", f"Nothin to modify. '{name}' not exists."
|
|
233
209
|
)
|
|
234
210
|
try:
|
|
235
|
-
if
|
|
236
|
-
params[
|
|
237
|
-
del params[
|
|
211
|
+
if "enabled" in params:
|
|
212
|
+
params["publish"] = params["enabled"]
|
|
213
|
+
del params["enabled"]
|
|
238
214
|
ca.integration_controller.modify(name, params)
|
|
239
215
|
|
|
240
216
|
except Exception as e:
|
|
241
217
|
logger.error(str(e))
|
|
242
218
|
return http_error(
|
|
243
|
-
HTTPStatus.INTERNAL_SERVER_ERROR,
|
|
244
|
-
f"Error during integration modification: {str(e)}"
|
|
219
|
+
HTTPStatus.INTERNAL_SERVER_ERROR, "Error", f"Error during integration modification: {str(e)}"
|
|
245
220
|
)
|
|
246
221
|
return "", 200
|
|
@@ -1,2 +1 @@
|
|
|
1
|
-
virtualenv
|
|
2
|
-
pyarrow==19.0.0
|
|
1
|
+
virtualenv
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
import ast
|
|
2
2
|
from typing import Dict, Optional, List
|
|
3
3
|
|
|
4
|
+
|
|
5
|
+
from litellm import completion, batch_completion, embedding
|
|
4
6
|
import pandas as pd
|
|
5
7
|
|
|
6
8
|
from mindsdb.integrations.libs.base import BaseMLEngine
|
|
@@ -8,8 +10,6 @@ from mindsdb.utilities import log
|
|
|
8
10
|
|
|
9
11
|
from mindsdb.integrations.handlers.litellm_handler.settings import CompletionParameters
|
|
10
12
|
|
|
11
|
-
from litellm import completion, batch_completion
|
|
12
|
-
|
|
13
13
|
|
|
14
14
|
logger = log.getLogger(__name__)
|
|
15
15
|
|
|
@@ -28,10 +28,24 @@ class LiteLLMHandler(BaseMLEngine):
|
|
|
28
28
|
@staticmethod
|
|
29
29
|
def create_validation(target, args=None, **kwargs):
|
|
30
30
|
if "using" not in args:
|
|
31
|
-
raise Exception(
|
|
32
|
-
|
|
31
|
+
raise Exception("Litellm engine requires a USING clause. See settings.py for more info on supported args.")
|
|
32
|
+
|
|
33
|
+
@staticmethod
|
|
34
|
+
def embeddings(model: str, messages: List[str], args: dict) -> List[list]:
|
|
35
|
+
response = embedding(model=model, input=messages, **args)
|
|
36
|
+
return [rec["embedding"] for rec in response.data]
|
|
37
|
+
|
|
38
|
+
@staticmethod
|
|
39
|
+
async def acompletion(model: str, messages: List[dict], args: dict):
|
|
40
|
+
if model.startswith("snowflake/") and "snowflake_account_id" in args:
|
|
41
|
+
args["api_base"] = (
|
|
42
|
+
f"https://{args['snowflake_account_id']}.snowflakecomputing.com/api/v2/cortex/inference:complete"
|
|
33
43
|
)
|
|
34
44
|
|
|
45
|
+
from litellm import acompletion
|
|
46
|
+
|
|
47
|
+
return await acompletion(model=model, messages=messages, stream=False, **args)
|
|
48
|
+
|
|
35
49
|
def create(
|
|
36
50
|
self,
|
|
37
51
|
target: str,
|
|
@@ -70,9 +84,9 @@ class LiteLLMHandler(BaseMLEngine):
|
|
|
70
84
|
self._build_messages(args, df)
|
|
71
85
|
|
|
72
86
|
# remove prompt_template from args
|
|
73
|
-
args.pop(
|
|
87
|
+
args.pop("prompt_template", None)
|
|
74
88
|
|
|
75
|
-
if len(args[
|
|
89
|
+
if len(args["messages"]) > 1:
|
|
76
90
|
# if more than one message, use batch completion
|
|
77
91
|
responses = batch_completion(**args)
|
|
78
92
|
return pd.DataFrame({"result": [response.choices[0].message.content for response in responses]})
|
|
@@ -103,36 +117,39 @@ class LiteLLMHandler(BaseMLEngine):
|
|
|
103
117
|
|
|
104
118
|
if "prompt_template" in prompt_kwargs:
|
|
105
119
|
# if prompt_template is passed in predict query, use it
|
|
106
|
-
logger.info(
|
|
107
|
-
|
|
120
|
+
logger.info(
|
|
121
|
+
"Using 'prompt_template' passed in SELECT Predict query. "
|
|
122
|
+
"Note this will overwrite a 'prompt_template' passed in create MODEL query."
|
|
123
|
+
)
|
|
108
124
|
|
|
109
|
-
args[
|
|
125
|
+
args["prompt_template"] = prompt_kwargs.pop("prompt_template")
|
|
110
126
|
|
|
111
|
-
if
|
|
127
|
+
if "mock_response" in prompt_kwargs:
|
|
112
128
|
# used for testing to save on real completion api calls
|
|
113
|
-
args[
|
|
129
|
+
args["mock_response"]: str = prompt_kwargs.pop("mock_response")
|
|
114
130
|
|
|
115
|
-
if
|
|
131
|
+
if "messages" in prompt_kwargs and len(prompt_kwargs) > 1:
|
|
116
132
|
# if user passes in messages, no other args can be passed in
|
|
117
|
-
raise Exception(
|
|
118
|
-
"If 'messages' is passed in SELECT Predict query, no other args can be passed in."
|
|
119
|
-
)
|
|
133
|
+
raise Exception("If 'messages' is passed in SELECT Predict query, no other args can be passed in.")
|
|
120
134
|
|
|
121
135
|
# if user passes in messages, use those instead
|
|
122
|
-
if
|
|
136
|
+
if "messages" in prompt_kwargs:
|
|
123
137
|
logger.info("Using messages passed in SELECT Predict query. 'prompt_template' will be ignored.")
|
|
124
138
|
|
|
125
|
-
args[
|
|
139
|
+
args["messages"]: List = ast.literal_eval(df["messages"].iloc[0])
|
|
126
140
|
|
|
127
141
|
else:
|
|
128
142
|
# if user passes in prompt_template, use that to create messages
|
|
129
143
|
if len(prompt_kwargs) == 1:
|
|
130
|
-
args[
|
|
131
|
-
|
|
144
|
+
args["messages"] = (
|
|
145
|
+
self._prompt_to_messages(args["prompt_template"], **prompt_kwargs)
|
|
146
|
+
if args["prompt_template"]
|
|
147
|
+
else self._prompt_to_messages(df.iloc[0][0])
|
|
148
|
+
)
|
|
132
149
|
|
|
133
150
|
elif len(prompt_kwargs) > 1:
|
|
134
151
|
try:
|
|
135
|
-
args[
|
|
152
|
+
args["messages"] = self._prompt_to_messages(args["prompt_template"], **prompt_kwargs)
|
|
136
153
|
except KeyError as e:
|
|
137
154
|
raise Exception(
|
|
138
155
|
f"{e}: Please pass in either a prompt_template on create MODEL or "
|
|
@@ -114,3 +114,16 @@ class GoogleConfig(BaseLLMConfig):
|
|
|
114
114
|
top_k: Optional[int] = Field(default=None, description="Number of highest probability tokens to consider")
|
|
115
115
|
max_output_tokens: Optional[int] = Field(default=None, description="Maximum number of tokens to generate")
|
|
116
116
|
google_api_key: Optional[str] = Field(default=None, description="API key for Google Generative AI")
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
# See https://api.python.langchain.com/en/latest/llms/langchain_community.llms.writer.Writer.html
|
|
120
|
+
class WriterConfig(BaseLLMConfig):
|
|
121
|
+
model_name: str = Field(default="palmyra-x5", alias="model_id")
|
|
122
|
+
temperature: Optional[float] = Field(default=0.7)
|
|
123
|
+
max_tokens: Optional[int] = Field(default=None)
|
|
124
|
+
top_p: Optional[float] = Field(default=None)
|
|
125
|
+
stop: Optional[List[str]] = Field(default=None)
|
|
126
|
+
best_of: Optional[int] = Field(default=None)
|
|
127
|
+
writer_api_key: Optional[str] = Field(default=None)
|
|
128
|
+
writer_org_id: Optional[str] = Field(default=None)
|
|
129
|
+
base_url: Optional[str] = Field(default=None)
|
|
@@ -16,6 +16,7 @@ from mindsdb.integrations.libs.llm.config import (
|
|
|
16
16
|
OpenAIConfig,
|
|
17
17
|
NvidiaNIMConfig,
|
|
18
18
|
MindsdbConfig,
|
|
19
|
+
WriterConfig,
|
|
19
20
|
)
|
|
20
21
|
from mindsdb.utilities.config import config
|
|
21
22
|
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
|
|
@@ -41,16 +42,12 @@ DEFAULT_LITELLM_BASE_URL = "https://ai.dev.mindsdb.com"
|
|
|
41
42
|
DEFAULT_OLLAMA_BASE_URL = "http://localhost:11434"
|
|
42
43
|
DEFAULT_OLLAMA_MODEL = "llama2"
|
|
43
44
|
|
|
44
|
-
DEFAULT_NVIDIA_NIM_BASE_URL =
|
|
45
|
-
"http://localhost:8000/v1" # Assumes local port forwarding through ssh
|
|
46
|
-
)
|
|
45
|
+
DEFAULT_NVIDIA_NIM_BASE_URL = "http://localhost:8000/v1" # Assumes local port forwarding through ssh
|
|
47
46
|
DEFAULT_NVIDIA_NIM_MODEL = "meta/llama-3_1-8b-instruct"
|
|
48
47
|
DEFAULT_VLLM_SERVER_URL = "http://localhost:8000/v1"
|
|
49
48
|
|
|
50
49
|
|
|
51
|
-
def get_completed_prompts(
|
|
52
|
-
base_template: str, df: pd.DataFrame, strict=True
|
|
53
|
-
) -> Tuple[List[str], np.ndarray]:
|
|
50
|
+
def get_completed_prompts(base_template: str, df: pd.DataFrame, strict=True) -> Tuple[List[str], np.ndarray]:
|
|
54
51
|
"""
|
|
55
52
|
Helper method that produces formatted prompts given a template and data in a Pandas DataFrame.
|
|
56
53
|
It also returns the ID of any empty templates that failed to be filled due to missing data.
|
|
@@ -69,9 +66,7 @@ def get_completed_prompts(
|
|
|
69
66
|
if len(matches) == 0:
|
|
70
67
|
# no placeholders
|
|
71
68
|
if strict:
|
|
72
|
-
raise AssertionError(
|
|
73
|
-
"No placeholders found in the prompt, please provide a valid prompt template."
|
|
74
|
-
)
|
|
69
|
+
raise AssertionError("No placeholders found in the prompt, please provide a valid prompt template.")
|
|
75
70
|
prompts = [base_template] * len(df)
|
|
76
71
|
return prompts, np.ndarray(0)
|
|
77
72
|
|
|
@@ -95,12 +90,8 @@ def get_completed_prompts(
|
|
|
95
90
|
for i in range(len(template)):
|
|
96
91
|
atom = template[i]
|
|
97
92
|
if i < len(columns):
|
|
98
|
-
col = df[columns[i]].replace(
|
|
99
|
-
|
|
100
|
-
) # add empty quote if data is missing
|
|
101
|
-
df["__mdb_prompt"] = df["__mdb_prompt"].apply(
|
|
102
|
-
lambda x: x + atom
|
|
103
|
-
) + col.astype("string")
|
|
93
|
+
col = df[columns[i]].replace(to_replace=[None], value="") # add empty quote if data is missing
|
|
94
|
+
df["__mdb_prompt"] = df["__mdb_prompt"].apply(lambda x: x + atom) + col.astype("string")
|
|
104
95
|
else:
|
|
105
96
|
df["__mdb_prompt"] = df["__mdb_prompt"].apply(lambda x: x + atom)
|
|
106
97
|
prompts = list(df["__mdb_prompt"])
|
|
@@ -119,8 +110,7 @@ def get_llm_config(provider: str, args: Dict) -> BaseLLMConfig:
|
|
|
119
110
|
"""
|
|
120
111
|
temperature = min(1.0, max(0.0, args.get("temperature", 0.0)))
|
|
121
112
|
if provider == "openai":
|
|
122
|
-
|
|
123
|
-
if any(x in args.get("model_name", "") for x in ['o1', 'o3']):
|
|
113
|
+
if any(x in args.get("model_name", "") for x in ["o1", "o3"]):
|
|
124
114
|
# for o1 and 03, 'temperature' does not support 0.0 with this model. Only the default (1) value is supported
|
|
125
115
|
temperature = 1
|
|
126
116
|
|
|
@@ -173,9 +163,7 @@ def get_llm_config(provider: str, args: Dict) -> BaseLLMConfig:
|
|
|
173
163
|
max_tokens=args.get("max_tokens", DEFAULT_OPENAI_MAX_TOKENS),
|
|
174
164
|
top_p=args.get("top_p", None),
|
|
175
165
|
top_k=args.get("top_k", None),
|
|
176
|
-
custom_llm_provider=args.get(
|
|
177
|
-
"custom_llm_provider", DEFAULT_LITELLM_PROVIDER
|
|
178
|
-
),
|
|
166
|
+
custom_llm_provider=args.get("custom_llm_provider", DEFAULT_LITELLM_PROVIDER),
|
|
179
167
|
model_kwargs=model_kwargs,
|
|
180
168
|
)
|
|
181
169
|
if provider == "ollama":
|
|
@@ -237,6 +225,18 @@ def get_llm_config(provider: str, args: Dict) -> BaseLLMConfig:
|
|
|
237
225
|
max_output_tokens=args.get("max_tokens", None),
|
|
238
226
|
google_api_key=args["api_keys"].get("google", None),
|
|
239
227
|
)
|
|
228
|
+
if provider == "writer":
|
|
229
|
+
return WriterConfig(
|
|
230
|
+
model_name=args.get("model_name", "palmyra-x5"),
|
|
231
|
+
temperature=temperature,
|
|
232
|
+
max_tokens=args.get("max_tokens", None),
|
|
233
|
+
top_p=args.get("top_p", None),
|
|
234
|
+
stop=args.get("stop", None),
|
|
235
|
+
best_of=args.get("best_of", None),
|
|
236
|
+
writer_api_key=args["api_keys"].get("writer", None),
|
|
237
|
+
writer_org_id=args.get("writer_org_id", None),
|
|
238
|
+
base_url=args.get("base_url", None),
|
|
239
|
+
)
|
|
240
240
|
|
|
241
241
|
raise ValueError(f"Provider {provider} is not supported.")
|
|
242
242
|
|
|
@@ -290,9 +290,7 @@ def ft_jsonl_validation(
|
|
|
290
290
|
) # noqa
|
|
291
291
|
|
|
292
292
|
if messages_col not in batch:
|
|
293
|
-
raise Exception(
|
|
294
|
-
f"{prefix}Each line in the provided data should have a '{messages_col}' key"
|
|
295
|
-
)
|
|
293
|
+
raise Exception(f"{prefix}Each line in the provided data should have a '{messages_col}' key")
|
|
296
294
|
|
|
297
295
|
messages = batch[messages_col]
|
|
298
296
|
try:
|
|
@@ -350,30 +348,22 @@ def ft_chat_format_validation(
|
|
|
350
348
|
|
|
351
349
|
for c in chat:
|
|
352
350
|
if any(k not in valid_keys for k in c.keys()):
|
|
353
|
-
raise Exception(
|
|
354
|
-
f"Each message should only have these keys: `{valid_keys}`. Found: `{c.keys()}`"
|
|
355
|
-
)
|
|
351
|
+
raise Exception(f"Each message should only have these keys: `{valid_keys}`. Found: `{c.keys()}`")
|
|
356
352
|
|
|
357
353
|
roles = [m[role_key] for m in chat]
|
|
358
354
|
contents = [m[content_key] for m in chat]
|
|
359
355
|
|
|
360
356
|
if len(roles) != len(contents):
|
|
361
|
-
raise Exception(
|
|
362
|
-
f"Each message should contain both `{role_key}` and `{content_key}` fields"
|
|
363
|
-
)
|
|
357
|
+
raise Exception(f"Each message should contain both `{role_key}` and `{content_key}` fields")
|
|
364
358
|
|
|
365
359
|
if len(roles) == 0:
|
|
366
360
|
raise Exception("Chat should have at least one message")
|
|
367
361
|
|
|
368
362
|
if assistant_key not in roles:
|
|
369
|
-
raise Exception(
|
|
370
|
-
"Chat should have at least one assistant message"
|
|
371
|
-
) # otherwise it is useless for FT
|
|
363
|
+
raise Exception("Chat should have at least one assistant message") # otherwise it is useless for FT
|
|
372
364
|
|
|
373
365
|
if user_key not in roles:
|
|
374
|
-
raise Exception(
|
|
375
|
-
"Chat should have at least one user message"
|
|
376
|
-
) # perhaps remove in the future
|
|
366
|
+
raise Exception("Chat should have at least one user message") # perhaps remove in the future
|
|
377
367
|
|
|
378
368
|
# set default transitions for finite state machine if undefined
|
|
379
369
|
if transitions is None:
|
|
@@ -387,20 +377,15 @@ def ft_chat_format_validation(
|
|
|
387
377
|
# check order is valid via finite state machine
|
|
388
378
|
state = None
|
|
389
379
|
for i, (role, content) in enumerate(zip(roles, contents)):
|
|
390
|
-
|
|
391
380
|
prefix = f"message #{i + 1}: "
|
|
392
381
|
|
|
393
382
|
# check invalid roles
|
|
394
383
|
if role not in valid_roles:
|
|
395
|
-
raise Exception(
|
|
396
|
-
f"{prefix}Invalid role (found `{role}`, expected one of `{valid_roles}`)"
|
|
397
|
-
)
|
|
384
|
+
raise Exception(f"{prefix}Invalid role (found `{role}`, expected one of `{valid_roles}`)")
|
|
398
385
|
|
|
399
386
|
# check content
|
|
400
387
|
if not isinstance(content, str):
|
|
401
|
-
raise Exception(
|
|
402
|
-
f"{prefix}Content should be a string, got type `{type(content)}`"
|
|
403
|
-
)
|
|
388
|
+
raise Exception(f"{prefix}Content should be a string, got type `{type(content)}`")
|
|
404
389
|
|
|
405
390
|
# check transition
|
|
406
391
|
if role not in transitions[state]:
|
|
@@ -464,9 +449,7 @@ def ft_chat_formatter(df: pd.DataFrame) -> List[Dict]:
|
|
|
464
449
|
df = df.sort_values(["chat_id"], kind="stable")
|
|
465
450
|
elif "message_id" in df.columns:
|
|
466
451
|
if df["message_id"].duplicated().any():
|
|
467
|
-
raise Exception(
|
|
468
|
-
"If `message_id` is provided, it must not contain duplicate IDs."
|
|
469
|
-
)
|
|
452
|
+
raise Exception("If `message_id` is provided, it must not contain duplicate IDs.")
|
|
470
453
|
df = df.sort_values(["message_id"])
|
|
471
454
|
|
|
472
455
|
# 2. build chats
|
|
@@ -477,12 +460,8 @@ def ft_chat_formatter(df: pd.DataFrame) -> List[Dict]:
|
|
|
477
460
|
for _, row in df.iterrows():
|
|
478
461
|
try:
|
|
479
462
|
chat = json.loads(row["chat_json"])
|
|
480
|
-
assert list(chat.keys()) == [
|
|
481
|
-
|
|
482
|
-
], "Each chat should have a 'messages' key, and nothing else."
|
|
483
|
-
ft_chat_format_validation(
|
|
484
|
-
chat["messages"]
|
|
485
|
-
) # will raise Exception if chat is invalid
|
|
463
|
+
assert list(chat.keys()) == ["messages"], "Each chat should have a 'messages' key, and nothing else."
|
|
464
|
+
ft_chat_format_validation(chat["messages"]) # will raise Exception if chat is invalid
|
|
486
465
|
chats.append(chat)
|
|
487
466
|
except json.JSONDecodeError:
|
|
488
467
|
pass # TODO: add logger info here, prompt user to clean dataset carefully
|
|
@@ -492,9 +471,7 @@ def ft_chat_formatter(df: pd.DataFrame) -> List[Dict]:
|
|
|
492
471
|
chat = []
|
|
493
472
|
for i, row in df.iterrows():
|
|
494
473
|
if row["role"] == "system" and len(chat) > 0:
|
|
495
|
-
ft_chat_format_validation(
|
|
496
|
-
chat
|
|
497
|
-
) # will raise Exception if chat is invalid
|
|
474
|
+
ft_chat_format_validation(chat) # will raise Exception if chat is invalid
|
|
498
475
|
chats.append({"messages": chat})
|
|
499
476
|
chat = []
|
|
500
477
|
event = {"role": row["role"], "content": row["content"]}
|
|
@@ -529,15 +506,11 @@ def ft_code_formatter(
|
|
|
529
506
|
# input and setup validation
|
|
530
507
|
assert len(df) > 0, "Input dataframe should not be empty"
|
|
531
508
|
assert "code" in df.columns, "Input dataframe should have a 'code' column"
|
|
532
|
-
assert chunk_size > 0 and isinstance(
|
|
533
|
-
chunk_size, int
|
|
534
|
-
), "`chunk_size` should be a positive integer"
|
|
509
|
+
assert chunk_size > 0 and isinstance(chunk_size, int), "`chunk_size` should be a positive integer"
|
|
535
510
|
|
|
536
511
|
supported_formats = ["chat", "fim"]
|
|
537
512
|
supported_langs = [e.value for e in Language]
|
|
538
|
-
assert (
|
|
539
|
-
language.lower() in supported_langs
|
|
540
|
-
), f"Invalid language. Valid choices are: {supported_langs}"
|
|
513
|
+
assert language.lower() in supported_langs, f"Invalid language. Valid choices are: {supported_langs}"
|
|
541
514
|
|
|
542
515
|
# ensure correct encoding
|
|
543
516
|
df["code"] = df["code"].map(lambda x: x.encode("utf8").decode("unicode_escape"))
|
|
@@ -574,7 +547,7 @@ def ft_code_formatter(
|
|
|
574
547
|
roles = []
|
|
575
548
|
contents = []
|
|
576
549
|
for idx in range(0, len(chunks), 3):
|
|
577
|
-
pre, mid, suf = chunks[idx: idx + 3]
|
|
550
|
+
pre, mid, suf = chunks[idx : idx + 3]
|
|
578
551
|
interleaved = list(itertools.chain(*zip(templates, (pre, suf, mid))))
|
|
579
552
|
user = "\n".join(interleaved[:-1])
|
|
580
553
|
assistant = "\n".join(interleaved[-1:])
|
|
@@ -595,12 +568,11 @@ def ft_cqa_formatter(
|
|
|
595
568
|
default_instruction="You are a helpful assistant.",
|
|
596
569
|
default_context="",
|
|
597
570
|
) -> pd.DataFrame:
|
|
598
|
-
|
|
599
571
|
# input and setup validation
|
|
600
572
|
assert len(df) > 0, "Input dataframe should not be empty"
|
|
601
|
-
assert {question_col, answer_col}.issubset(
|
|
602
|
-
|
|
603
|
-
)
|
|
573
|
+
assert {question_col, answer_col}.issubset(set(df.columns)), (
|
|
574
|
+
f"Input dataframe must have columns `{question_col}`, and `{answer_col}`"
|
|
575
|
+
) # noqa
|
|
604
576
|
|
|
605
577
|
if instruction_col not in df.columns:
|
|
606
578
|
df[instruction_col] = default_instruction
|