MindsDB 25.6.2.0__py3-none-any.whl → 25.6.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of MindsDB might be problematic.
- mindsdb/__about__.py +1 -1
- mindsdb/api/a2a/agent.py +25 -4
- mindsdb/api/a2a/task_manager.py +68 -6
- mindsdb/api/executor/datahub/datanodes/mindsdb_tables.py +91 -84
- mindsdb/api/http/namespaces/knowledge_bases.py +132 -154
- mindsdb/integrations/handlers/bigquery_handler/bigquery_handler.py +219 -28
- mindsdb/integrations/handlers/llama_index_handler/requirements.txt +1 -1
- mindsdb/integrations/handlers/mysql_handler/mysql_handler.py +3 -0
- mindsdb/integrations/handlers/openai_handler/openai_handler.py +277 -356
- mindsdb/integrations/handlers/salesforce_handler/salesforce_handler.py +94 -8
- mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py +19 -1
- mindsdb/integrations/libs/api_handler.py +19 -1
- mindsdb/integrations/libs/base.py +86 -2
- mindsdb/interfaces/agents/agents_controller.py +32 -6
- mindsdb/interfaces/agents/constants.py +1 -0
- mindsdb/interfaces/agents/mindsdb_database_agent.py +23 -18
- mindsdb/interfaces/data_catalog/data_catalog_loader.py +22 -6
- mindsdb/interfaces/data_catalog/data_catalog_reader.py +4 -0
- mindsdb/interfaces/database/integrations.py +4 -2
- mindsdb/interfaces/knowledge_base/controller.py +3 -15
- mindsdb/interfaces/knowledge_base/evaluate.py +0 -3
- mindsdb/interfaces/skills/skills_controller.py +0 -23
- mindsdb/interfaces/skills/sql_agent.py +8 -4
- mindsdb/interfaces/storage/db.py +20 -4
- mindsdb/utilities/config.py +5 -1
- {mindsdb-25.6.2.0.dist-info → mindsdb-25.6.3.0.dist-info}/METADATA +250 -250
- {mindsdb-25.6.2.0.dist-info → mindsdb-25.6.3.0.dist-info}/RECORD +30 -30
- {mindsdb-25.6.2.0.dist-info → mindsdb-25.6.3.0.dist-info}/WHEEL +0 -0
- {mindsdb-25.6.2.0.dist-info → mindsdb-25.6.3.0.dist-info}/licenses/LICENSE +0 -0
- {mindsdb-25.6.2.0.dist-info → mindsdb-25.6.3.0.dist-info}/top_level.txt +0 -0
mindsdb/integrations/handlers/openai_handler/openai_handler.py

@@ -30,7 +30,7 @@ from mindsdb.integrations.handlers.openai_handler.constants import (
     OPENAI_API_BASE,
     DEFAULT_CHAT_MODEL,
     DEFAULT_EMBEDDING_MODEL,
-    DEFAULT_IMAGE_MODEL
+    DEFAULT_IMAGE_MODEL,
 )
 from mindsdb.integrations.libs.llm.utils import get_completed_prompts
 from mindsdb.integrations.utilities.handler_utils import get_api_key
@@ -43,7 +43,7 @@ class OpenAIHandler(BaseMLEngine):
     This handler handles connection and inference with the OpenAI API.
     """

-    name = 'openai'
+    name = "openai"

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -51,15 +51,13 @@ class OpenAIHandler(BaseMLEngine):
         self.default_model = DEFAULT_CHAT_MODEL
         self.default_embedding_model = DEFAULT_EMBEDDING_MODEL
         self.default_image_model = DEFAULT_IMAGE_MODEL
-        self.default_mode = (
-            'default'  # can also be 'conversational' or 'conversational-full'
-        )
+        self.default_mode = "default"  # can also be 'conversational' or 'conversational-full'
         self.supported_modes = [
-            'default',
-            'conversational',
-            'conversational-full',
-            'image',
-            'embedding',
+            "default",
+            "conversational",
+            "conversational-full",
+            "image",
+            "embedding",
         ]
         self.rate_limit = 60  # requests per minute
         self.max_batch_size = 20
@@ -67,8 +65,8 @@ class OpenAIHandler(BaseMLEngine):
         self.chat_completion_models = CHAT_MODELS
         self.supported_ft_models = FINETUNING_MODELS  # base models compatible with finetuning
         # For now this are only used for handlers that inherits OpenAIHandler and don't need to override base methods
-        self.api_key_name = getattr(self, 'api_key_name', self.name)
-        self.api_base = getattr(self, 'api_base', OPENAI_API_BASE)
+        self.api_key_name = getattr(self, "api_key_name", self.name)
+        self.api_base = getattr(self, "api_base", OPENAI_API_BASE)

     def create_engine(self, connection_args: Dict) -> None:
         """
@@ -84,10 +82,10 @@ class OpenAIHandler(BaseMLEngine):
             None
         """
         connection_args = {k.lower(): v for k, v in connection_args.items()}
-        api_key = connection_args.get('openai_api_key')
+        api_key = connection_args.get("openai_api_key")
         if api_key is not None:
-            org = connection_args.get('api_organization')
-            api_base = connection_args.get('api_base') or os.environ.get('OPENAI_API_BASE', OPENAI_API_BASE)
+            org = connection_args.get("api_organization")
+            api_base = connection_args.get("api_base") or os.environ.get("OPENAI_API_BASE", OPENAI_API_BASE)
             client = self._get_client(api_key=api_key, base_url=api_base, org=org, args=connection_args)
             OpenAIHandler._check_client_connection(client)

@@ -106,13 +104,13 @@ class OpenAIHandler(BaseMLEngine):
             None
         """
         try:
-            client.models.retrieve('test')
+            client.models.retrieve("test")
         except NotFoundError:
             pass
         except AuthenticationError as e:
-            if e.body['code'] == 'invalid_api_key':
-                raise Exception('Invalid api key')
-            raise Exception(f'Something went wrong: {e}')
+            if e.body["code"] == "invalid_api_key":
+                raise Exception("Invalid api key")
+            raise Exception(f"Something went wrong: {e}")

     @staticmethod
     def create_validation(target: Text, args: Dict = None, **kwargs: Any) -> None:
@@ -130,41 +128,29 @@ class OpenAIHandler(BaseMLEngine):
         Returns:
             None
         """
-        if 'using' not in args:
-            raise Exception(
-                "OpenAI engine requires a USING clause! Refer to its documentation for more details."
-            )
+        if "using" not in args:
+            raise Exception("OpenAI engine requires a USING clause! Refer to its documentation for more details.")
         else:
-            args = args['using']
+            args = args["using"]

-        if (
-            len(
-                set(args.keys())
-                & {'question_column', 'prompt_template', 'prompt'}
-            )
-            == 0
-        ):
-            raise Exception(
-                'One of `question_column`, `prompt_template` or `prompt` is required for this engine.'
-            )
+        if len(set(args.keys()) & {"question_column", "prompt_template", "prompt"}) == 0:
+            raise Exception("One of `question_column`, `prompt_template` or `prompt` is required for this engine.")

         keys_collection = [
-            ['prompt_template'],
-            ['question_column', 'context_column'],
-            ['prompt', 'user_column', 'assistant_column'],
+            ["prompt_template"],
+            ["question_column", "context_column"],
+            ["prompt", "user_column", "assistant_column"],
         ]
         for keys in keys_collection:
-            if keys[0] in args and any(
-                x[0] in args for x in keys_collection if x != keys
-            ):
+            if keys[0] in args and any(x[0] in args for x in keys_collection if x != keys):
                 raise Exception(
                     textwrap.dedent(
-                        '''\
+                        """\
                         Please provide one of
                             1) a `prompt_template`
                             2) a `question_column` and an optional `context_column`
                             3) a `prompt', 'user_column' and 'assistant_column`
-                        '''
+                        """
                     )
                 )

@@ -202,11 +188,15 @@ class OpenAIHandler(BaseMLEngine):
                 f"Unknown arguments: {', '.join(unknown_args)}.\n Known arguments are: {', '.join(known_args)}"
             )

-        engine_storage = kwargs['handler_storage']
+        engine_storage = kwargs["handler_storage"]
         connection_args = engine_storage.get_connection_args()
-        api_key = get_api_key('openai', args, engine_storage=engine_storage)
-        api_base = args.get('api_base') or connection_args.get('api_base') or os.environ.get('OPENAI_API_BASE', OPENAI_API_BASE)
-        org = args.get('api_organization')
+        api_key = get_api_key("openai", args, engine_storage=engine_storage)
+        api_base = (
+            args.get("api_base")
+            or connection_args.get("api_base")
+            or os.environ.get("OPENAI_API_BASE", OPENAI_API_BASE)
+        )
+        org = args.get("api_organization")
         client = OpenAIHandler._get_client(api_key=api_key, base_url=api_base, org=org, args=args)
         OpenAIHandler._check_client_connection(client)

@@ -225,33 +215,36 @@ class OpenAIHandler(BaseMLEngine):
         Returns:
             None
         """
-        args = args['using']
-        args['target'] = target
+        args = args["using"]
+        args["target"] = target
         try:
             api_key = get_api_key(self.api_key_name, args, self.engine_storage)
             connection_args = self.engine_storage.get_connection_args()
-            api_base = args.get('api_base') or connection_args.get('api_base') or os.environ.get('OPENAI_API_BASE') or self.api_base
-            client = self._get_client(api_key=api_key, base_url=api_base, org=args.get('api_organization'), args=args)
+            api_base = (
+                args.get("api_base")
+                or connection_args.get("api_base")
+                or os.environ.get("OPENAI_API_BASE")
+                or self.api_base
+            )
+            client = self._get_client(api_key=api_key, base_url=api_base, org=args.get("api_organization"), args=args)
             available_models = get_available_models(client)

-            if not args.get('mode'):
-                args['mode'] = self.default_mode
-            elif args['mode'] not in self.supported_modes:
-                raise Exception(
-                    f"Invalid operation mode. Please use one of {self.supported_modes}"
-                )
+            if not args.get("mode"):
+                args["mode"] = self.default_mode
+            elif args["mode"] not in self.supported_modes:
+                raise Exception(f"Invalid operation mode. Please use one of {self.supported_modes}")

-            if not args.get('model_name'):
-                if args['mode'] == 'embedding':
-                    args['model_name'] = self.default_embedding_model
-                elif args['mode'] == 'image':
-                    args['model_name'] = self.default_image_model
+            if not args.get("model_name"):
+                if args["mode"] == "embedding":
+                    args["model_name"] = self.default_embedding_model
+                elif args["mode"] == "image":
+                    args["model_name"] = self.default_image_model
                 else:
-                    args['model_name'] = self.default_model
-            elif args['model_name'] not in available_models and args['mode'] != 'embedding':
+                    args["model_name"] = self.default_model
+            elif (args["model_name"] not in available_models) and (args["mode"] != "embedding"):
                 raise Exception(f"Invalid model name. Please use one of {available_models}")
         finally:
-            self.model_storage.json_set('args', args)
+            self.model_storage.json_set("args", args)

     def predict(self, df: pd.DataFrame, args: Optional[Dict] = None) -> pd.DataFrame:
         """
@@ -269,173 +262,138 @@ class OpenAIHandler(BaseMLEngine):
        """  # noqa
        # TODO: support for edits, embeddings and moderation

-        pred_args = args['predict_params'] if args else {}
-        args = self.model_storage.json_get('args')
+        pred_args = args["predict_params"] if args else {}
+        args = self.model_storage.json_get("args")
        connection_args = self.engine_storage.get_connection_args()

-        args['api_base'] = (
-            pred_args.get('api_base') or args.get('api_base')
-            or connection_args.get('api_base')
-            or os.environ.get('OPENAI_API_BASE') or self.api_base
-        )
+        args["api_base"] = (
+            pred_args.get("api_base")
+            or args.get("api_base")
+            or connection_args.get("api_base")
+            or os.environ.get("OPENAI_API_BASE")
+            or self.api_base
+        )

-        if pred_args.get('api_organization'):
-            args['api_organization'] = pred_args['api_organization']
+        if pred_args.get("api_organization"):
+            args["api_organization"] = pred_args["api_organization"]
        df = df.reset_index(drop=True)

-        if pred_args.get('mode'):
-            if pred_args['mode'] in self.supported_modes:
-                args['mode'] = pred_args['mode']
+        if pred_args.get("mode"):
+            if pred_args["mode"] in self.supported_modes:
+                args["mode"] = pred_args["mode"]
            else:
-                raise Exception(
-                    f"Invalid operation mode. Please use one of {self.supported_modes}."
-                )  # noqa
+                raise Exception(f"Invalid operation mode. Please use one of {self.supported_modes}.")  # noqa

        strict_prompt_template = True
-        if pred_args.get('prompt_template', False):
-            base_template = pred_args[
-                'prompt_template'
-            ]  # override with predict-time template if available
+        if pred_args.get("prompt_template", False):
+            base_template = pred_args["prompt_template"]  # override with predict-time template if available
            strict_prompt_template = False
-        elif args.get('prompt_template', False):
-            base_template = args['prompt_template']
+        elif args.get("prompt_template", False):
+            base_template = args["prompt_template"]
        else:
            base_template = None

        # Embedding mode
-        if args.get('mode', self.default_mode) == 'embedding':
+        if args.get("mode", self.default_mode) == "embedding":
            api_args = {
-                'question_column': pred_args.get('question_column', None),
-                'model': pred_args.get('model_name') or args.get('model_name'),
+                "question_column": pred_args.get("question_column", None),
+                "model": pred_args.get("model_name") or args.get("model_name"),
            }
-            model_name = 'embedding'
-            if args.get('question_column'):
-                prompts = list(df[args['question_column']].apply(lambda x: str(x)))
-                empty_prompt_ids = np.where(
-                    df[[args['question_column']]].isna().all(axis=1).values
-                )[0]
+            model_name = "embedding"
+            if args.get("question_column"):
+                prompts = list(df[args["question_column"]].apply(lambda x: str(x)))
+                empty_prompt_ids = np.where(df[[args["question_column"]]].isna().all(axis=1).values)[0]
            else:
-                raise Exception('Embedding mode needs a question_column')
+                raise Exception("Embedding mode needs a question_column")

        # Image mode
-        elif args.get('mode', self.default_mode) == 'image':
+        elif args.get("mode", self.default_mode) == "image":
            api_args = {
-                'n': pred_args.get('n', None),
-                'size': pred_args.get('size', None),
-                'response_format': pred_args.get('response_format', None),
+                "n": pred_args.get("n", None),
+                "size": pred_args.get("size", None),
+                "response_format": pred_args.get("response_format", None),
            }
-            api_args = {
-                k: v for k, v in api_args.items() if v is not None
-            }  # filter out non-specified api args
-            model_name = pred_args.get('model_name') or args.get('model_name')
+            api_args = {k: v for k, v in api_args.items() if v is not None}  # filter out non-specified api args
+            model_name = pred_args.get("model_name") or args.get("model_name")

-            if args.get('question_column'):
-                prompts = list(df[args['question_column']].apply(lambda x: str(x)))
-                empty_prompt_ids = np.where(
-                    df[[args['question_column']]].isna().all(axis=1).values
-                )[0]
-            elif args.get('prompt_template'):
+            if args.get("question_column"):
+                prompts = list(df[args["question_column"]].apply(lambda x: str(x)))
+                empty_prompt_ids = np.where(df[[args["question_column"]]].isna().all(axis=1).values)[0]
+            elif args.get("prompt_template"):
                prompts, empty_prompt_ids = get_completed_prompts(base_template, df)
            else:
-                raise Exception(
-                    'Image mode needs either `prompt_template` or `question_column`.'
-                )
+                raise Exception("Image mode needs either `prompt_template` or `question_column`.")

        # Chat or normal completion mode
        else:
-            if (
-                args.get('question_column', False)
-                and args['question_column'] not in df.columns
-            ):
-                raise Exception(
-                    f"This model expects a question to answer in the '{args['question_column']}' column."
-                )
+            if args.get("question_column", False) and args["question_column"] not in df.columns:
+                raise Exception(f"This model expects a question to answer in the '{args['question_column']}' column.")

-            if (
-                args.get('context_column', False)
-                and args['context_column'] not in df.columns
-            ):
-                raise Exception(
-                    f"This model expects context in the '{args['context_column']}' column."
-                )
+            if args.get("context_column", False) and args["context_column"] not in df.columns:
+                raise Exception(f"This model expects context in the '{args['context_column']}' column.")

            # API argument validation
-            model_name = args.get('model_name', self.default_model)
+            model_name = args.get("model_name", self.default_model)
            api_args = {
-                'max_tokens': pred_args.get(
-                    'max_tokens', args.get('max_tokens', self.default_max_tokens)
-                ),
-                'temperature': min(
+                "max_tokens": pred_args.get("max_tokens", args.get("max_tokens", self.default_max_tokens)),
+                "temperature": min(
                    1.0,
-                    max(
-                        0.0, pred_args.get('temperature', args.get('temperature', 0.0))
-                    ),
+                    max(0.0, pred_args.get("temperature", args.get("temperature", 0.0))),
                ),
-                'top_p': pred_args.get('top_p', None),
-                'n': pred_args.get('n', None),
-                'stop': pred_args.get('stop', None),
-                'presence_penalty': pred_args.get('presence_penalty', None),
-                'frequency_penalty': pred_args.get('frequency_penalty', None),
-                'best_of': pred_args.get('best_of', None),
-                'logit_bias': pred_args.get('logit_bias', None),
-                'user': pred_args.get('user', None),
+                "top_p": pred_args.get("top_p", None),
+                "n": pred_args.get("n", None),
+                "stop": pred_args.get("stop", None),
+                "presence_penalty": pred_args.get("presence_penalty", None),
+                "frequency_penalty": pred_args.get("frequency_penalty", None),
+                "best_of": pred_args.get("best_of", None),
+                "logit_bias": pred_args.get("logit_bias", None),
+                "user": pred_args.get("user", None),
            }

-            if (
-                args.get('mode', self.default_mode) != 'default'
-                and model_name not in self.chat_completion_models
-            ):
+            if args.get("mode", self.default_mode) != "default" and model_name not in self.chat_completion_models:
                raise Exception(
                    f"Conversational modes are only available for the following models: {', '.join(self.chat_completion_models)}"
                )  # noqa

-            if args.get('prompt_template', False):
+            if args.get("prompt_template", False):
                prompts, empty_prompt_ids = get_completed_prompts(base_template, df, strict=strict_prompt_template)

-            elif args.get('context_column', False):
+            elif args.get("context_column", False):
                empty_prompt_ids = np.where(
-                    df[[args['context_column'], args['question_column']]]
-                    .isna()
-                    .all(axis=1)
-                    .values
+                    df[[args["context_column"], args["question_column"]]].isna().all(axis=1).values
                )[0]
-                contexts = list(df[args['context_column']].apply(lambda x: str(x)))
-                questions = list(df[args['question_column']].apply(lambda x: str(x)))
-                prompts = [
-                    f'Context: {c}\nQuestion: {q}\nAnswer: '
-                    for c, q in zip(contexts, questions)
-                ]
-
-            elif 'prompt' in args:
+                contexts = list(df[args["context_column"]].apply(lambda x: str(x)))
+                questions = list(df[args["question_column"]].apply(lambda x: str(x)))
+                prompts = [f"Context: {c}\nQuestion: {q}\nAnswer: " for c, q in zip(contexts, questions)]
+
+            elif "prompt" in args:
                empty_prompt_ids = []
-                prompts = list(df[args['user_column']])
+                prompts = list(df[args["user_column"]])
            else:
-                empty_prompt_ids = np.where(
-                    df[[args['question_column']]].isna().all(axis=1).values
-                )[0]
-                prompts = list(df[args['question_column']].apply(lambda x: str(x)))
+                empty_prompt_ids = np.where(df[[args["question_column"]]].isna().all(axis=1).values)[0]
+                prompts = list(df[args["question_column"]].apply(lambda x: str(x)))

        # add json struct if available
-        if args.get('json_struct', False):
+        if args.get("json_struct", False):
            for i, prompt in enumerate(prompts):
-                json_struct = ''
-                if 'json_struct' in df.columns and i not in empty_prompt_ids:
+                json_struct = ""
+                if "json_struct" in df.columns and i not in empty_prompt_ids:
                    # if row has a specific json, we try to use it instead of the base prompt template
                    try:
-                        if isinstance(df['json_struct'][i], str):
-                            df['json_struct'][i] = json.loads(df['json_struct'][i])
-                        for ind, val in enumerate(df['json_struct'][i].values()):
-                            json_struct = json_struct + f'{ind}. {val}\n'
+                        if isinstance(df["json_struct"][i], str):
+                            df["json_struct"][i] = json.loads(df["json_struct"][i])
+                        for ind, val in enumerate(df["json_struct"][i].values()):
+                            json_struct = json_struct + f"{ind}. {val}\n"
                    except Exception:
                        pass  # if the row's json is invalid, we use the prompt template instead

-                if json_struct == '':
-                    for ind, val in enumerate(args['json_struct'].values()):
-                        json_struct = json_struct + f'{ind + 1}. {val}\n'
+                if json_struct == "":
+                    for ind, val in enumerate(args["json_struct"].values()):
+                        json_struct = json_struct + f"{ind + 1}. {val}\n"

                p = textwrap.dedent(
-                    f'''\
-                    Based on the text following 'The reference text is:', assign values to the following {len(args['json_struct'])} JSON attributes:
+                    f"""\
+                    Based on the text following 'The reference text is:', assign values to the following {len(args["json_struct"])} JSON attributes:
                    {{{{json_struct}}}}

                    Values should follow the same order as the attributes above.
@@ -456,48 +414,51 @@ class OpenAIHandler(BaseMLEngine):

                    Now for the real task. The reference text is:
                    {prompt}
-                    '''
+                    """
                )

-                p = p.replace('{{json_struct}}', json_struct)
+                p = p.replace("{{json_struct}}", json_struct)
                prompts[i] = p

        # remove prompts without signal from completion queue
        prompts = [j for i, j in enumerate(prompts) if i not in empty_prompt_ids]

        api_key = get_api_key(self.api_key_name, args, self.engine_storage)
-        api_args = {
-            k: v for k, v in api_args.items() if v is not None
-        }  # filter out non-specified api args
+        api_args = {k: v for k, v in api_args.items() if v is not None}  # filter out non-specified api args
        completion = self._completion(model_name, prompts, api_key, api_args, args, df)

        # add null completion for empty prompts
        for i in sorted(empty_prompt_ids):
            completion.insert(i, None)

-        pred_df = pd.DataFrame(completion, columns=[args['target']])
+        pred_df = pd.DataFrame(completion, columns=[args["target"]])

        # restore json struct
-        if args.get('json_struct', False):
+        if args.get("json_struct", False):
            for i in pred_df.index:
                try:
-                    if 'json_struct' in df.columns:
-                        json_keys = df['json_struct'][i].keys()
+                    if "json_struct" in df.columns:
+                        json_keys = df["json_struct"][i].keys()
                    else:
-                        json_keys = args['json_struct'].keys()
-                    responses = pred_df[args['target']][i].split('\n')
+                        json_keys = args["json_struct"].keys()
+                    responses = pred_df[args["target"]][i].split("\n")
                    responses = [x[3:] for x in responses]  # del question index

-                    pred_df[args['target']][i] = {
-                        key: val for key, val in zip(json_keys, responses)
-                    }
+                    pred_df[args["target"]][i] = {key: val for key, val in zip(json_keys, responses)}
                except Exception:
-                    pred_df[args['target']][i] = None
+                    pred_df[args["target"]][i] = None

        return pred_df

    def _completion(
-        self, model_name: Text, prompts: List[Text], api_key: Text, api_args: Dict, args: Dict, df: pd.DataFrame, parallel: bool = True
+        self,
+        model_name: Text,
+        prompts: List[Text],
+        api_key: Text,
+        api_args: Dict,
+        args: Dict,
+        df: pd.DataFrame,
+        parallel: bool = True,
    ) -> List[Any]:
        """
        Handles completion for an arbitrary amount of rows using a model connected to the OpenAI API.
@@ -531,7 +492,9 @@ class OpenAIHandler(BaseMLEngine):
        """

        @retry_with_exponential_backoff()
-        def _submit_completion(model_name: Text, prompts: List[Text], api_args: Dict, args: Dict, df: pd.DataFrame) -> List[Text]:
+        def _submit_completion(
+            model_name: Text, prompts: List[Text], api_args: Dict, args: Dict, df: pd.DataFrame
+        ) -> List[Text]:
            """
            Submit a request to the relevant completion endpoint of the OpenAI API based on the type of task.

@@ -546,26 +509,22 @@ class OpenAIHandler(BaseMLEngine):
                List[Text]: List of completions.
            """
            kwargs = {
-                'model': model_name,
+                "model": model_name,
            }
            if model_name in IMAGE_MODELS:
                return _submit_image_completion(kwargs, prompts, api_args)
-            elif model_name == 'embedding':
+            elif model_name == "embedding":
                return _submit_embedding_completion(kwargs, prompts, api_args)
            elif model_name in self.chat_completion_models:
                if model_name == "gpt-3.5-turbo-instruct":
-                    return _submit_normal_completion(
-                        kwargs,
-                        prompts,
-                        api_args
-                    )
+                    return _submit_normal_completion(kwargs, prompts, api_args)
                else:
                    return _submit_chat_completion(
                        kwargs,
                        prompts,
                        api_args,
                        df,
-                        mode=args.get('mode', 'conversational'),
+                        mode=args.get("mode", "conversational"),
                    )
            else:
                return _submit_normal_completion(kwargs, prompts, api_args)
@@ -584,9 +543,9 @@ class OpenAIHandler(BaseMLEngine):
            after_openai_query(params, response)

            params2 = params.copy()
-            params2.pop('api_key', None)
-            params2.pop('user', None)
-            logger.debug(f'>>>openai call: {params2}:\n{response}')
+            params2.pop("api_key", None)
+            params2.pop("user", None)
+            logger.debug(f">>>openai call: {params2}:\n{response}")

        def _submit_normal_completion(kwargs: Dict, prompts: List[Text], api_args: Dict) -> List[Text]:
            """
@@ -616,11 +575,11 @@ class OpenAIHandler(BaseMLEngine):
                """
                tidy_comps = []
                for c in comp.choices:
-                    if hasattr(c, 'text'):
-                        tidy_comps.append(c.text.strip('\n').strip(''))
+                    if hasattr(c, "text"):
+                        tidy_comps.append(c.text.strip("\n").strip(""))
                return tidy_comps

-            kwargs['prompt'] = prompts
+            kwargs["prompt"] = prompts
            kwargs = {**kwargs, **api_args}

            before_openai_query(kwargs)
@@ -656,11 +615,11 @@ class OpenAIHandler(BaseMLEngine):
                """
                tidy_comps = []
                for c in comp.data:
-                    if hasattr(c, 'embedding'):
+                    if hasattr(c, "embedding"):
                        tidy_comps.append([c.embedding])
                return tidy_comps

-            kwargs['input'] = prompts
+            kwargs["input"] = prompts
            kwargs = {**kwargs, **api_args}

            before_openai_query(kwargs)
@@ -668,7 +627,9 @@ class OpenAIHandler(BaseMLEngine):
            _log_api_call(kwargs, resp)
            return resp

-        def _submit_chat_completion(kwargs: Dict, prompts: List[Text], api_args: Dict, df: pd.DataFrame, mode: Text = 'conversational') -> List[Text]:
+        def _submit_chat_completion(
+            kwargs: Dict, prompts: List[Text], api_args: Dict, df: pd.DataFrame, mode: Text = "conversational"
+        ) -> List[Text]:
            """
            Submit a request to the chat completion endpoint of the OpenAI API.

@@ -698,48 +659,42 @@ class OpenAIHandler(BaseMLEngine):
                """
                tidy_comps = []
                for c in comp.choices:
-                    if hasattr(c, 'message'):
-                        tidy_comps.append(c.message.content.strip('\n').strip(''))
+                    if hasattr(c, "message"):
+                        tidy_comps.append(c.message.content.strip("\n").strip(""))
                return tidy_comps

            completions = []
-            if mode != 'conversational' or 'prompt' not in args:
+            if mode != "conversational" or "prompt" not in args:
                initial_prompt = {
                    "role": "system",
                    "content": "You are a helpful assistant. Your task is to continue the chat.",
                }  # noqa
            else:
                # get prompt from model
-                initial_prompt = {"role": "system", "content": args['prompt']}  # noqa
+                initial_prompt = {"role": "system", "content": args["prompt"]}  # noqa

-            kwargs['messages'] = [initial_prompt]
+            kwargs["messages"] = [initial_prompt]
            last_completion_content = None

            for pidx in range(len(prompts)):
-                if mode != 'conversational':
-                    kwargs['messages'].append(
-                        {'role': 'user', 'content': prompts[pidx]}
-                    )
+                if mode != "conversational":
+                    kwargs["messages"].append({"role": "user", "content": prompts[pidx]})
                else:
                    question = prompts[pidx]
                    if question:
-                        kwargs['messages'].append({'role': 'user', 'content': question})
+                        kwargs["messages"].append({"role": "user", "content": question})

-                    assistant_column = args.get('assistant_column')
+                    assistant_column = args.get("assistant_column")
                    if assistant_column in df.columns:
                        answer = df.iloc[pidx][assistant_column]
                    else:
                        answer = None
                    if answer:
-                        kwargs['messages'].append(
-                            {'role': 'assistant', 'content': answer}
-                        )
-
-                if mode == 'conversational-full' or (
-                    mode == 'conversational' and pidx == len(prompts) - 1
-                ):
-                    kwargs['messages'] = truncate_msgs_for_token_limit(
-                        kwargs['messages'], kwargs['model'], api_args['max_tokens']
+                        kwargs["messages"].append({"role": "assistant", "content": answer})
+
+                if mode == "conversational-full" or (mode == "conversational" and pidx == len(prompts) - 1):
+                    kwargs["messages"] = truncate_msgs_for_token_limit(
+                        kwargs["messages"], kwargs["model"], api_args["max_tokens"]
                    )
                    pkwargs = {**kwargs, **api_args}

@@ -748,8 +703,8 @@ class OpenAIHandler(BaseMLEngine):
                    _log_api_call(pkwargs, resp)

                    completions.extend(resp)
-                elif mode == 'default':
-                    kwargs['messages'] = [initial_prompt] + [kwargs['messages'][-1]]
+                elif mode == "default":
+                    kwargs["messages"] = [initial_prompt] + [kwargs["messages"][-1]]
                    pkwargs = {**kwargs, **api_args}

                    before_openai_query(kwargs)
@@ -760,13 +715,11 @@ class OpenAIHandler(BaseMLEngine):
                else:
                    # in "normal" conversational mode, we request completions only for the last row
                    last_completion_content = None
-                    completions.extend([''])
+                    completions.extend([""])

                if last_completion_content:
                    # interleave assistant responses with user input
-                    kwargs['messages'].append(
-                        {'role': 'assistant', 'content': last_completion_content[0]}
-                    )
+                    kwargs["messages"].append({"role": "assistant", "content": last_completion_content[0]})

            return completions

@@ -799,46 +752,36 @@ class OpenAIHandler(BaseMLEngine):
            Returns:
                List[Text]: List of image completions as URLs or base64 encoded images.
            """
-            return [
-                c.url if hasattr(c, 'url') else c.b64_json
-                for c in comp
-            ]
-
-        completions = [
-            client.images.generate(**{'prompt': p, **kwargs, **api_args}).data[0]
-            for p in prompts
-        ]
+            return [c.url if hasattr(c, "url") else c.b64_json for c in comp]
+
+        completions = [client.images.generate(**{"prompt": p, **kwargs, **api_args}).data[0] for p in prompts]
            return _tidy(completions)

        client = self._get_client(
            api_key=api_key,
-            base_url=args.get('api_base'),
-            org=args.pop('api_organization') if 'api_organization' in args else None,
-            args=args
+            base_url=args.get("api_base"),
+            org=args.pop("api_organization") if "api_organization" in args else None,
+            args=args,
        )

        try:
            # check if simple completion works
-            completion = _submit_completion(
-                model_name, prompts, api_args, args, df
-            )
+            completion = _submit_completion(model_name, prompts, api_args, args, df)
            return completion
        except Exception as e:
            # else, we get the max batch size
-            if 'you can currently request up to at most a total of' in str(e):
-                pattern = 'a total of'
-                max_batch_size = int(e[e.find(pattern) + len(pattern):].split(').')[0])
+            if "you can currently request up to at most a total of" in str(e):
+                pattern = "a total of"
+                max_batch_size = int(e[e.find(pattern) + len(pattern) :].split(").")[0])
            else:
-                max_batch_size = (
-                    self.max_batch_size
-                )  # guards against changes in the API message
+                max_batch_size = self.max_batch_size  # guards against changes in the API message

            if not parallel:
                completion = None
                for i in range(math.ceil(len(prompts) / max_batch_size)):
                    partial = _submit_completion(
                        model_name,
-                        prompts[i * max_batch_size: (i + 1) * max_batch_size],
+                        prompts[i * max_batch_size : (i + 1) * max_batch_size],
                        api_args,
                        args,
                        df,
@@ -846,20 +789,18 @@ class OpenAIHandler(BaseMLEngine):
                    if not completion:
                        completion = partial
                    else:
-                        completion['choices'].extend(partial['choices'])
-                        for field in ('prompt_tokens', 'completion_tokens', 'total_tokens'):
-                            completion['usage'][field] += partial['usage'][field]
+                        completion["choices"].extend(partial["choices"])
+                        for field in ("prompt_tokens", "completion_tokens", "total_tokens"):
+                            completion["usage"][field] += partial["usage"][field]
            else:
                promises = []
                with concurrent.futures.ThreadPoolExecutor() as executor:
                    for i in range(math.ceil(len(prompts) / max_batch_size)):
-                        logger.debug(
-                            f'{i * max_batch_size}:{(i+1) * max_batch_size}/{len(prompts)}'
-                        )
+                        logger.debug(f"{i * max_batch_size}:{(i + 1) * max_batch_size}/{len(prompts)}")
                        future = executor.submit(
                            _submit_completion,
                            model_name,
-                            prompts[i * max_batch_size: (i + 1) * max_batch_size],
+                            prompts[i * max_batch_size : (i + 1) * max_batch_size],
                            api_args,
                            args,
                            df,
@@ -868,9 +809,9 @@ class OpenAIHandler(BaseMLEngine):
            completion = None
            for p in promises:
                if not completion:
-                    completion = p['choices'].result()
+                    completion = p["choices"].result()
                else:
-                    completion.extend(p['choices'].result())
+                    completion.extend(p["choices"].result())

        return completion

@@ -886,26 +827,26 @@ class OpenAIHandler(BaseMLEngine):
        """
        # TODO: Update to use update() artifacts

-        args = self.model_storage.json_get('args')
+        args = self.model_storage.json_get("args")
        api_key = get_api_key(self.api_key_name, args, self.engine_storage)
-        if attribute == 'args':
-            return pd.DataFrame(args.items(), columns=['key', 'value'])
-        elif attribute == 'metadata':
-            model_name = args.get('model_name', self.default_model)
+        if attribute == "args":
+            return pd.DataFrame(args.items(), columns=["key", "value"])
+        elif attribute == "metadata":
+            model_name = args.get("model_name", self.default_model)
            try:
                client = self._get_client(
                    api_key=api_key,
-                    base_url=args.get('api_base'),
-                    org=args.get('api_organization'),
+                    base_url=args.get("api_base"),
+                    org=args.get("api_organization"),
                    args=args,
                )
                meta = client.models.retrieve(model_name)
            except Exception as e:
-                meta = {'error': str(e)}
-            return pd.DataFrame(dict(meta).items(), columns=['key', 'value'])
+                meta = {"error": str(e)}
+            return pd.DataFrame(dict(meta).items(), columns=["key", "value"])
        else:
-            tables = ['args', 'metadata']
-            return pd.DataFrame(tables, columns=['tables'])
+            tables = ["args", "metadata"]
+            return pd.DataFrame(tables, columns=["tables"])

    def finetune(self, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> None:
        """
@@ -937,14 +878,14 @@ class OpenAIHandler(BaseMLEngine):

        api_key = get_api_key(self.api_key_name, args, self.engine_storage)

-        using_args = args.pop('using') if 'using' in args else {}
+        using_args = args.pop("using") if "using" in args else {}

-        api_base = using_args.get('api_base', os.environ.get('OPENAI_API_BASE', OPENAI_API_BASE))
-        org = using_args.get('api_organization')
+        api_base = using_args.get("api_base", os.environ.get("OPENAI_API_BASE", OPENAI_API_BASE))
+        org = using_args.get("api_organization")
        client = self._get_client(api_key=api_key, base_url=api_base, org=org, args=args)

        args = {**using_args, **args}
-        prev_model_name = self.base_model_storage.json_get('args').get('model_name', '')
+        prev_model_name = self.base_model_storage.json_get("args").get("model_name", "")

        if prev_model_name not in self.supported_ft_models:
            # base model may be already FTed, check prefixes
@@ -956,80 +897,73 @@ class OpenAIHandler(BaseMLEngine):
                f"This model cannot be finetuned. Supported base models are {self.supported_ft_models}."
            )

-        finetune_time = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
+        finetune_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

        temp_storage_path = tempfile.mkdtemp()
        temp_file_name = f"ft_{finetune_time}"
        temp_model_storage_path = f"{temp_storage_path}/{temp_file_name}.jsonl"

-        file_names = self._prepare_ft_jsonl(
-            df, temp_storage_path, temp_file_name, temp_model_storage_path
-        )
+        file_names = self._prepare_ft_jsonl(df, temp_storage_path, temp_file_name, temp_model_storage_path)

        jsons = {k: None for k in file_names.keys()}
        for split, file_name in file_names.items():
            if os.path.isfile(os.path.join(temp_storage_path, file_name)):
                jsons[split] = client.files.create(
-                    file=open(f"{temp_storage_path}/{file_name}", "rb"),
-                    purpose='fine-tune'
+                    file=open(f"{temp_storage_path}/{file_name}", "rb"), purpose="fine-tune"
                )

-        if type(jsons['train']) is openai.types.FileObject:
-            train_file_id = jsons['train'].id
+        if type(jsons["train"]) is openai.types.FileObject:
+            train_file_id = jsons["train"].id
        else:
-            train_file_id = jsons['base'].id
+            train_file_id = jsons["base"].id

-        if type(jsons['val']) is openai.types.FileObject:
-            val_file_id = jsons['val'].id
+        if type(jsons["val"]) is openai.types.FileObject:
+            val_file_id = jsons["val"].id
        else:
            val_file_id = None

        # `None` values are internally imputed by OpenAI to `null` or default values
        ft_params = {
-            'training_file': train_file_id,
-            'validation_file': val_file_id,
-            'model': self._get_ft_model_type(prev_model_name),
+            "training_file": train_file_id,
+            "validation_file": val_file_id,
+            "model": self._get_ft_model_type(prev_model_name),
        }
        ft_params = self._add_extra_ft_params(ft_params, using_args)

        start_time = datetime.datetime.now()

-        ft_stats, result_file_id = self._ft_call(ft_params, client, args.get('hour_budget', 8))
+        ft_stats, result_file_id = self._ft_call(ft_params, client, args.get("hour_budget", 8))
        ft_model_name = ft_stats.fine_tuned_model

        end_time = datetime.datetime.now()
        runtime = end_time - start_time
        name_extension = client.files.retrieve(file_id=result_file_id).filename
-        result_path = f'{temp_storage_path}/ft_{finetune_time}_result_{name_extension}'
+        result_path = f"{temp_storage_path}/ft_{finetune_time}_result_{name_extension}"

        try:
            client.files.content(file_id=result_file_id).stream_to_file(result_path)
-            if '.csv' in name_extension:
+            if ".csv" in name_extension:
                # legacy endpoint
                train_stats = pd.read_csv(result_path)
-                if 'validation_token_accuracy' in train_stats.columns:
-                    train_stats = train_stats[
-                        train_stats['validation_token_accuracy'].notnull()
-                    ]
-                args['ft_api_info'] = ft_stats.dict()
-                args['ft_result_stats'] = train_stats.to_dict()
-
-            elif '.json' in name_extension:
-                train_stats = pd.read_json(
-                    path_or_buf=result_path, lines=True
-                )  # new endpoint
-                args['ft_api_info'] = args['ft_result_stats'] = train_stats.to_dict()
+                if "validation_token_accuracy" in train_stats.columns:
+                    train_stats = train_stats[train_stats["validation_token_accuracy"].notnull()]
+                args["ft_api_info"] = ft_stats.dict()
+                args["ft_result_stats"] = train_stats.to_dict()
+
+            elif ".json" in name_extension:
+                train_stats = pd.read_json(path_or_buf=result_path, lines=True)  # new endpoint
+                args["ft_api_info"] = args["ft_result_stats"] = train_stats.to_dict()

        except Exception:
-            logger.info(f'Error retrieving fine-tuning results. Please check manually for information on job {ft_stats.id} (result file {result_file_id}).')
+            logger.info(
+                f"Error retrieving fine-tuning results. Please check manually for information on job {ft_stats.id} (result file {result_file_id})."
+            )

-        args['model_name'] = ft_model_name
-        args['runtime'] = runtime.total_seconds()
-        args['mode'] = self.base_model_storage.json_get('args').get(
-            'mode', self.default_mode
-        )
+        args["model_name"] = ft_model_name
+        args["runtime"] = runtime.total_seconds()
+        args["mode"] = self.base_model_storage.json_get("args").get("mode", self.default_mode)

-        self.model_storage.json_set('args', args)
+        self.model_storage.json_set("args", args)
        shutil.rmtree(temp_storage_path)

    @staticmethod
@@ -1045,7 +979,7 @@ class OpenAIHandler(BaseMLEngine):
        Returns:
            Dict: File names for the fine-tuning process.
        """
-        df.to_json(temp_model_path, orient='records', lines=True)
+        df.to_json(temp_model_path, orient="records", lines=True)

        # TODO avoid subprocess usage once OpenAI enables non-CLI access, or refactor to use our own LLM utils instead
        subprocess.run(
@@ -1055,7 +989,7 @@ class OpenAIHandler(BaseMLEngine):
                "fine_tunes.prepare_data",
                "-f",
                temp_model_path,  # from file
-                '-q',  # quiet mode (accepts all suggestions)
+                "-q",  # quiet mode (accepts all suggestions)
            ],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
@@ -1063,10 +997,10 @@ class OpenAIHandler(BaseMLEngine):
        )

        file_names = {
-            'original': f'{temp_filename}.jsonl',
-            'base': f'{temp_filename}_prepared.jsonl',
-            'train': f'{temp_filename}_prepared_train.jsonl',
-            'val': f'{temp_filename}_prepared_valid.jsonl',
+            "original": f"{temp_filename}.jsonl",
+            "base": f"{temp_filename}_prepared.jsonl",
+            "train": f"{temp_filename}_prepared_train.jsonl",
+            "val": f"{temp_filename}_prepared_valid.jsonl",
        }
        return file_names

@@ -1083,7 +1017,7 @@ class OpenAIHandler(BaseMLEngine):
        for model_type in self.supported_ft_models:
            if model_type in model_name.lower():
                return model_type
-        return 'babbage-002'
+        return "babbage-002"

    @staticmethod
    def _add_extra_ft_params(ft_params: Dict, using_args: Dict) -> Dict:
@@ -1098,22 +1032,14 @@ class OpenAIHandler(BaseMLEngine):
            Dict: Fine-tuning parameters with extra parameters.
        """
        extra_params = {
-            'n_epochs': using_args.get('n_epochs', None),
-            'batch_size': using_args.get('batch_size', None),
-            'learning_rate_multiplier': using_args.get(
-                'learning_rate_multiplier', None
-            ),
-            'prompt_loss_weight': using_args.get('prompt_loss_weight', None),
-            'compute_classification_metrics': using_args.get(
-                'compute_classification_metrics', None
-            ),
-            'classification_n_classes': using_args.get(
-                'classification_n_classes', None
-            ),
-            'classification_positive_class': using_args.get(
-                'classification_positive_class', None
-            ),
-            'classification_betas': using_args.get('classification_betas', None),
+            "n_epochs": using_args.get("n_epochs", None),
+            "batch_size": using_args.get("batch_size", None),
+            "learning_rate_multiplier": using_args.get("learning_rate_multiplier", None),
+            "prompt_loss_weight": using_args.get("prompt_loss_weight", None),
+            "compute_classification_metrics": using_args.get("compute_classification_metrics", None),
+            "classification_n_classes": using_args.get("classification_n_classes", None),
+            "classification_positive_class": using_args.get("classification_positive_class", None),
+            "classification_betas": using_args.get("classification_betas", None),
        }
        return {**ft_params, **extra_params}

@@ -1137,9 +1063,7 @@ class OpenAIHandler(BaseMLEngine):
        Returns:
            Tuple[FineTuningJob, Text]: Fine-tuning stats and result file ID.
        """
-        ft_result = client.fine_tuning.jobs.create(
-            **{k: v for k, v in ft_params.items() if v is not None}
-        )
+        ft_result = client.fine_tuning.jobs.create(**{k: v for k, v in ft_params.items() if v is not None})

        @retry_with_exponential_backoff(
            hour_budget=hour_budget,
@@ -1158,22 +1082,22 @@ class OpenAIHandler(BaseMLEngine):
                FineTuningJob: Fine-tuning stats.
            """
            ft_retrieved = client.fine_tuning.jobs.retrieve(fine_tuning_job_id=job_id)
-            if ft_retrieved.status in ('succeeded', 'failed', 'cancelled'):
+            if ft_retrieved.status in ("succeeded", "failed", "cancelled"):
                return ft_retrieved
            else:
-                raise PendingFT('Fine-tuning still pending!')
+                raise PendingFT("Fine-tuning still pending!")

        ft_stats = _check_ft_status(ft_result.id)

-        if ft_stats.status != 'succeeded':
-            err_message = ft_stats.events[-1].message if hasattr(ft_stats, 'events') else 'could not retrieve!'
-            ft_status = ft_stats.status if hasattr(ft_stats, 'status') else 'N/A'
+        if ft_stats.status != "succeeded":
+            err_message = ft_stats.events[-1].message if hasattr(ft_stats, "events") else "could not retrieve!"
+            ft_status = ft_stats.status if hasattr(ft_stats, "status") else "N/A"
            raise Exception(
                f"Fine-tuning did not complete successfully (status: {ft_status}). Error message: {err_message}"
            )  # noqa

        result_file_id = client.fine_tuning.jobs.retrieve(fine_tuning_job_id=ft_result.id).result_files[0]
-        if hasattr(result_file_id, 'id'):
+        if hasattr(result_file_id, "id"):
            result_file_id = result_file_id.id  # legacy endpoint

        return ft_stats, result_file_id
@@ -1191,11 +1115,8 @@ class OpenAIHandler(BaseMLEngine):
        Returns:
            openai.OpenAI: OpenAI client.
        """
-        if args is not None and args.get('provider') == 'azure':
+        if args is not None and args.get("provider") == "azure":
            return AzureOpenAI(
-                api_key=api_key,
-                azure_endpoint=base_url,
-                api_version=args.get('api_version'),
-                organization=org
+                api_key=api_key, azure_endpoint=base_url, api_version=args.get("api_version"), organization=org
            )
        return OpenAI(api_key=api_key, base_url=base_url, organization=org)