MindsDB mindsdb-25.6.2.0-py3-none-any.whl → mindsdb-25.6.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of MindsDB might be problematic.

Files changed (30)
  1. mindsdb/__about__.py +1 -1
  2. mindsdb/api/a2a/agent.py +25 -4
  3. mindsdb/api/a2a/task_manager.py +68 -6
  4. mindsdb/api/executor/datahub/datanodes/mindsdb_tables.py +91 -84
  5. mindsdb/api/http/namespaces/knowledge_bases.py +132 -154
  6. mindsdb/integrations/handlers/bigquery_handler/bigquery_handler.py +219 -28
  7. mindsdb/integrations/handlers/llama_index_handler/requirements.txt +1 -1
  8. mindsdb/integrations/handlers/mysql_handler/mysql_handler.py +3 -0
  9. mindsdb/integrations/handlers/openai_handler/openai_handler.py +277 -356
  10. mindsdb/integrations/handlers/salesforce_handler/salesforce_handler.py +94 -8
  11. mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py +19 -1
  12. mindsdb/integrations/libs/api_handler.py +19 -1
  13. mindsdb/integrations/libs/base.py +86 -2
  14. mindsdb/interfaces/agents/agents_controller.py +32 -6
  15. mindsdb/interfaces/agents/constants.py +1 -0
  16. mindsdb/interfaces/agents/mindsdb_database_agent.py +23 -18
  17. mindsdb/interfaces/data_catalog/data_catalog_loader.py +22 -6
  18. mindsdb/interfaces/data_catalog/data_catalog_reader.py +4 -0
  19. mindsdb/interfaces/database/integrations.py +4 -2
  20. mindsdb/interfaces/knowledge_base/controller.py +3 -15
  21. mindsdb/interfaces/knowledge_base/evaluate.py +0 -3
  22. mindsdb/interfaces/skills/skills_controller.py +0 -23
  23. mindsdb/interfaces/skills/sql_agent.py +8 -4
  24. mindsdb/interfaces/storage/db.py +20 -4
  25. mindsdb/utilities/config.py +5 -1
  26. {mindsdb-25.6.2.0.dist-info → mindsdb-25.6.3.0.dist-info}/METADATA +250 -250
  27. {mindsdb-25.6.2.0.dist-info → mindsdb-25.6.3.0.dist-info}/RECORD +30 -30
  28. {mindsdb-25.6.2.0.dist-info → mindsdb-25.6.3.0.dist-info}/WHEEL +0 -0
  29. {mindsdb-25.6.2.0.dist-info → mindsdb-25.6.3.0.dist-info}/licenses/LICENSE +0 -0
  30. {mindsdb-25.6.2.0.dist-info → mindsdb-25.6.3.0.dist-info}/top_level.txt +0 -0
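
Only one of the 30 files is diffed in full below: mindsdb/integrations/handlers/openai_handler/openai_handler.py. Nearly all of its hunks appear to be a mechanical style pass (single quotes replaced with double quotes, long statements re-wrapped); the one visible behavioral change is in create(), where model-name validation is now skipped for embedding mode. Sketches of that change and of the handler's batching fallback follow the diff.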
@@ -30,7 +30,7 @@ from mindsdb.integrations.handlers.openai_handler.constants import (
  OPENAI_API_BASE,
  DEFAULT_CHAT_MODEL,
  DEFAULT_EMBEDDING_MODEL,
- DEFAULT_IMAGE_MODEL
+ DEFAULT_IMAGE_MODEL,
  )
  from mindsdb.integrations.libs.llm.utils import get_completed_prompts
  from mindsdb.integrations.utilities.handler_utils import get_api_key
@@ -43,7 +43,7 @@ class OpenAIHandler(BaseMLEngine):
  This handler handles connection and inference with the OpenAI API.
  """

- name = 'openai'
+ name = "openai"

  def __init__(self, *args, **kwargs):
  super().__init__(*args, **kwargs)
@@ -51,15 +51,13 @@ class OpenAIHandler(BaseMLEngine):
  self.default_model = DEFAULT_CHAT_MODEL
  self.default_embedding_model = DEFAULT_EMBEDDING_MODEL
  self.default_image_model = DEFAULT_IMAGE_MODEL
- self.default_mode = (
- 'default' # can also be 'conversational' or 'conversational-full'
- )
+ self.default_mode = "default" # can also be 'conversational' or 'conversational-full'
  self.supported_modes = [
- 'default',
- 'conversational',
- 'conversational-full',
- 'image',
- 'embedding',
+ "default",
+ "conversational",
+ "conversational-full",
+ "image",
+ "embedding",
  ]
  self.rate_limit = 60 # requests per minute
  self.max_batch_size = 20
@@ -67,8 +65,8 @@ class OpenAIHandler(BaseMLEngine):
  self.chat_completion_models = CHAT_MODELS
  self.supported_ft_models = FINETUNING_MODELS # base models compatible with finetuning
  # For now this are only used for handlers that inherits OpenAIHandler and don't need to override base methods
- self.api_key_name = getattr(self, 'api_key_name', self.name)
- self.api_base = getattr(self, 'api_base', OPENAI_API_BASE)
+ self.api_key_name = getattr(self, "api_key_name", self.name)
+ self.api_base = getattr(self, "api_base", OPENAI_API_BASE)

  def create_engine(self, connection_args: Dict) -> None:
  """
@@ -84,10 +82,10 @@ class OpenAIHandler(BaseMLEngine):
  None
  """
  connection_args = {k.lower(): v for k, v in connection_args.items()}
- api_key = connection_args.get('openai_api_key')
+ api_key = connection_args.get("openai_api_key")
  if api_key is not None:
- org = connection_args.get('api_organization')
- api_base = connection_args.get('api_base') or os.environ.get('OPENAI_API_BASE', OPENAI_API_BASE)
+ org = connection_args.get("api_organization")
+ api_base = connection_args.get("api_base") or os.environ.get("OPENAI_API_BASE", OPENAI_API_BASE)
  client = self._get_client(api_key=api_key, base_url=api_base, org=org, args=connection_args)
  OpenAIHandler._check_client_connection(client)

@@ -106,13 +104,13 @@ class OpenAIHandler(BaseMLEngine):
  None
  """
  try:
- client.models.retrieve('test')
+ client.models.retrieve("test")
  except NotFoundError:
  pass
  except AuthenticationError as e:
- if e.body['code'] == 'invalid_api_key':
- raise Exception('Invalid api key')
- raise Exception(f'Something went wrong: {e}')
+ if e.body["code"] == "invalid_api_key":
+ raise Exception("Invalid api key")
+ raise Exception(f"Something went wrong: {e}")

  @staticmethod
  def create_validation(target: Text, args: Dict = None, **kwargs: Any) -> None:
@@ -130,41 +128,29 @@ class OpenAIHandler(BaseMLEngine):
  Returns:
  None
  """
- if 'using' not in args:
- raise Exception(
- "OpenAI engine requires a USING clause! Refer to its documentation for more details."
- )
+ if "using" not in args:
+ raise Exception("OpenAI engine requires a USING clause! Refer to its documentation for more details.")
  else:
- args = args['using']
+ args = args["using"]

- if (
- len(
- set(args.keys())
- & {'question_column', 'prompt_template', 'prompt'}
- )
- == 0
- ):
- raise Exception(
- 'One of `question_column`, `prompt_template` or `prompt` is required for this engine.'
- )
+ if len(set(args.keys()) & {"question_column", "prompt_template", "prompt"}) == 0:
+ raise Exception("One of `question_column`, `prompt_template` or `prompt` is required for this engine.")

  keys_collection = [
- ['prompt_template'],
- ['question_column', 'context_column'],
- ['prompt', 'user_column', 'assistant_column'],
+ ["prompt_template"],
+ ["question_column", "context_column"],
+ ["prompt", "user_column", "assistant_column"],
  ]
  for keys in keys_collection:
- if keys[0] in args and any(
- x[0] in args for x in keys_collection if x != keys
- ):
+ if keys[0] in args and any(x[0] in args for x in keys_collection if x != keys):
  raise Exception(
  textwrap.dedent(
- '''\
+ """\
  Please provide one of
  1) a `prompt_template`
  2) a `question_column` and an optional `context_column`
  3) a `prompt', 'user_column' and 'assistant_column`
- '''
+ """
  )
  )

@@ -202,11 +188,15 @@ class OpenAIHandler(BaseMLEngine):
  f"Unknown arguments: {', '.join(unknown_args)}.\n Known arguments are: {', '.join(known_args)}"
  )

- engine_storage = kwargs['handler_storage']
+ engine_storage = kwargs["handler_storage"]
  connection_args = engine_storage.get_connection_args()
- api_key = get_api_key('openai', args, engine_storage=engine_storage)
- api_base = args.get('api_base') or connection_args.get('api_base') or os.environ.get('OPENAI_API_BASE', OPENAI_API_BASE)
- org = args.get('api_organization')
+ api_key = get_api_key("openai", args, engine_storage=engine_storage)
+ api_base = (
+ args.get("api_base")
+ or connection_args.get("api_base")
+ or os.environ.get("OPENAI_API_BASE", OPENAI_API_BASE)
+ )
+ org = args.get("api_organization")
  client = OpenAIHandler._get_client(api_key=api_key, base_url=api_base, org=org, args=args)
  OpenAIHandler._check_client_connection(client)

@@ -225,33 +215,36 @@ class OpenAIHandler(BaseMLEngine):
  Returns:
  None
  """
- args = args['using']
- args['target'] = target
+ args = args["using"]
+ args["target"] = target
  try:
  api_key = get_api_key(self.api_key_name, args, self.engine_storage)
  connection_args = self.engine_storage.get_connection_args()
- api_base = args.get('api_base') or connection_args.get('api_base') or os.environ.get('OPENAI_API_BASE') or self.api_base
- client = self._get_client(api_key=api_key, base_url=api_base, org=args.get('api_organization'), args=args)
+ api_base = (
+ args.get("api_base")
+ or connection_args.get("api_base")
+ or os.environ.get("OPENAI_API_BASE")
+ or self.api_base
+ )
+ client = self._get_client(api_key=api_key, base_url=api_base, org=args.get("api_organization"), args=args)
  available_models = get_available_models(client)

- if not args.get('mode'):
- args['mode'] = self.default_mode
- elif args['mode'] not in self.supported_modes:
- raise Exception(
- f"Invalid operation mode. Please use one of {self.supported_modes}"
- )
+ if not args.get("mode"):
+ args["mode"] = self.default_mode
+ elif args["mode"] not in self.supported_modes:
+ raise Exception(f"Invalid operation mode. Please use one of {self.supported_modes}")

- if not args.get('model_name'):
- if args['mode'] == 'embedding':
- args['model_name'] = self.default_embedding_model
- elif args['mode'] == 'image':
- args['model_name'] = self.default_image_model
+ if not args.get("model_name"):
+ if args["mode"] == "embedding":
+ args["model_name"] = self.default_embedding_model
+ elif args["mode"] == "image":
+ args["model_name"] = self.default_image_model
  else:
- args['model_name'] = self.default_model
- elif args['model_name'] not in available_models:
+ args["model_name"] = self.default_model
+ elif (args["model_name"] not in available_models) and (args["mode"] != "embedding"):
  raise Exception(f"Invalid model name. Please use one of {available_models}")
  finally:
- self.model_storage.json_set('args', args)
+ self.model_storage.json_set("args", args)

  def predict(self, df: pd.DataFrame, args: Optional[Dict] = None) -> pd.DataFrame:
  """
@@ -269,173 +262,138 @@ class OpenAIHandler(BaseMLEngine):
  """ # noqa
  # TODO: support for edits, embeddings and moderation

- pred_args = args['predict_params'] if args else {}
- args = self.model_storage.json_get('args')
+ pred_args = args["predict_params"] if args else {}
+ args = self.model_storage.json_get("args")
  connection_args = self.engine_storage.get_connection_args()

- args['api_base'] = (pred_args.get('api_base')
- or args.get('api_base')
- or connection_args.get('api_base')
- or os.environ.get('OPENAI_API_BASE')
- or self.api_base)
+ args["api_base"] = (
+ pred_args.get("api_base")
+ or args.get("api_base")
+ or connection_args.get("api_base")
+ or os.environ.get("OPENAI_API_BASE")
+ or self.api_base
+ )

- if pred_args.get('api_organization'):
- args['api_organization'] = pred_args['api_organization']
+ if pred_args.get("api_organization"):
+ args["api_organization"] = pred_args["api_organization"]
  df = df.reset_index(drop=True)

- if pred_args.get('mode'):
- if pred_args['mode'] in self.supported_modes:
- args['mode'] = pred_args['mode']
+ if pred_args.get("mode"):
+ if pred_args["mode"] in self.supported_modes:
+ args["mode"] = pred_args["mode"]
  else:
- raise Exception(
- f"Invalid operation mode. Please use one of {self.supported_modes}."
- ) # noqa
+ raise Exception(f"Invalid operation mode. Please use one of {self.supported_modes}.") # noqa

  strict_prompt_template = True
- if pred_args.get('prompt_template', False):
- base_template = pred_args[
- 'prompt_template'
- ] # override with predict-time template if available
+ if pred_args.get("prompt_template", False):
+ base_template = pred_args["prompt_template"] # override with predict-time template if available
  strict_prompt_template = False
- elif args.get('prompt_template', False):
- base_template = args['prompt_template']
+ elif args.get("prompt_template", False):
+ base_template = args["prompt_template"]
  else:
  base_template = None

  # Embedding mode
- if args.get('mode', self.default_mode) == 'embedding':
+ if args.get("mode", self.default_mode) == "embedding":
  api_args = {
- 'question_column': pred_args.get('question_column', None),
- 'model': pred_args.get('model_name') or args.get('model_name'),
+ "question_column": pred_args.get("question_column", None),
+ "model": pred_args.get("model_name") or args.get("model_name"),
  }
- model_name = 'embedding'
- if args.get('question_column'):
- prompts = list(df[args['question_column']].apply(lambda x: str(x)))
- empty_prompt_ids = np.where(
- df[[args['question_column']]].isna().all(axis=1).values
- )[0]
+ model_name = "embedding"
+ if args.get("question_column"):
+ prompts = list(df[args["question_column"]].apply(lambda x: str(x)))
+ empty_prompt_ids = np.where(df[[args["question_column"]]].isna().all(axis=1).values)[0]
  else:
- raise Exception('Embedding mode needs a question_column')
+ raise Exception("Embedding mode needs a question_column")

  # Image mode
- elif args.get('mode', self.default_mode) == 'image':
+ elif args.get("mode", self.default_mode) == "image":
  api_args = {
- 'n': pred_args.get('n', None),
- 'size': pred_args.get('size', None),
- 'response_format': pred_args.get('response_format', None),
+ "n": pred_args.get("n", None),
+ "size": pred_args.get("size", None),
+ "response_format": pred_args.get("response_format", None),
  }
- api_args = {
- k: v for k, v in api_args.items() if v is not None
- } # filter out non-specified api args
- model_name = pred_args.get('model_name') or args.get('model_name')
+ api_args = {k: v for k, v in api_args.items() if v is not None} # filter out non-specified api args
+ model_name = pred_args.get("model_name") or args.get("model_name")

- if args.get('question_column'):
- prompts = list(df[args['question_column']].apply(lambda x: str(x)))
- empty_prompt_ids = np.where(
- df[[args['question_column']]].isna().all(axis=1).values
- )[0]
- elif args.get('prompt_template'):
+ if args.get("question_column"):
+ prompts = list(df[args["question_column"]].apply(lambda x: str(x)))
+ empty_prompt_ids = np.where(df[[args["question_column"]]].isna().all(axis=1).values)[0]
+ elif args.get("prompt_template"):
  prompts, empty_prompt_ids = get_completed_prompts(base_template, df)
  else:
- raise Exception(
- 'Image mode needs either `prompt_template` or `question_column`.'
- )
+ raise Exception("Image mode needs either `prompt_template` or `question_column`.")

  # Chat or normal completion mode
  else:
- if (
- args.get('question_column', False)
- and args['question_column'] not in df.columns
- ):
- raise Exception(
- f"This model expects a question to answer in the '{args['question_column']}' column."
- )
+ if args.get("question_column", False) and args["question_column"] not in df.columns:
+ raise Exception(f"This model expects a question to answer in the '{args['question_column']}' column.")

- if (
- args.get('context_column', False)
- and args['context_column'] not in df.columns
- ):
- raise Exception(
- f"This model expects context in the '{args['context_column']}' column."
- )
+ if args.get("context_column", False) and args["context_column"] not in df.columns:
+ raise Exception(f"This model expects context in the '{args['context_column']}' column.")

  # API argument validation
- model_name = args.get('model_name', self.default_model)
+ model_name = args.get("model_name", self.default_model)
  api_args = {
- 'max_tokens': pred_args.get(
- 'max_tokens', args.get('max_tokens', self.default_max_tokens)
- ),
- 'temperature': min(
+ "max_tokens": pred_args.get("max_tokens", args.get("max_tokens", self.default_max_tokens)),
+ "temperature": min(
  1.0,
- max(
- 0.0, pred_args.get('temperature', args.get('temperature', 0.0))
- ),
+ max(0.0, pred_args.get("temperature", args.get("temperature", 0.0))),
  ),
- 'top_p': pred_args.get('top_p', None),
- 'n': pred_args.get('n', None),
- 'stop': pred_args.get('stop', None),
- 'presence_penalty': pred_args.get('presence_penalty', None),
- 'frequency_penalty': pred_args.get('frequency_penalty', None),
- 'best_of': pred_args.get('best_of', None),
- 'logit_bias': pred_args.get('logit_bias', None),
- 'user': pred_args.get('user', None),
+ "top_p": pred_args.get("top_p", None),
+ "n": pred_args.get("n", None),
+ "stop": pred_args.get("stop", None),
+ "presence_penalty": pred_args.get("presence_penalty", None),
+ "frequency_penalty": pred_args.get("frequency_penalty", None),
+ "best_of": pred_args.get("best_of", None),
+ "logit_bias": pred_args.get("logit_bias", None),
+ "user": pred_args.get("user", None),
  }

- if (
- args.get('mode', self.default_mode) != 'default'
- and model_name not in self.chat_completion_models
- ):
+ if args.get("mode", self.default_mode) != "default" and model_name not in self.chat_completion_models:
  raise Exception(
  f"Conversational modes are only available for the following models: {', '.join(self.chat_completion_models)}"
  ) # noqa

- if args.get('prompt_template', False):
+ if args.get("prompt_template", False):
  prompts, empty_prompt_ids = get_completed_prompts(base_template, df, strict=strict_prompt_template)

- elif args.get('context_column', False):
+ elif args.get("context_column", False):
  empty_prompt_ids = np.where(
- df[[args['context_column'], args['question_column']]]
- .isna()
- .all(axis=1)
- .values
+ df[[args["context_column"], args["question_column"]]].isna().all(axis=1).values
  )[0]
- contexts = list(df[args['context_column']].apply(lambda x: str(x)))
- questions = list(df[args['question_column']].apply(lambda x: str(x)))
- prompts = [
- f'Context: {c}\nQuestion: {q}\nAnswer: '
- for c, q in zip(contexts, questions)
- ]
-
- elif 'prompt' in args:
+ contexts = list(df[args["context_column"]].apply(lambda x: str(x)))
+ questions = list(df[args["question_column"]].apply(lambda x: str(x)))
+ prompts = [f"Context: {c}\nQuestion: {q}\nAnswer: " for c, q in zip(contexts, questions)]
+
+ elif "prompt" in args:
  empty_prompt_ids = []
- prompts = list(df[args['user_column']])
+ prompts = list(df[args["user_column"]])
  else:
- empty_prompt_ids = np.where(
- df[[args['question_column']]].isna().all(axis=1).values
- )[0]
- prompts = list(df[args['question_column']].apply(lambda x: str(x)))
+ empty_prompt_ids = np.where(df[[args["question_column"]]].isna().all(axis=1).values)[0]
+ prompts = list(df[args["question_column"]].apply(lambda x: str(x)))

  # add json struct if available
- if args.get('json_struct', False):
+ if args.get("json_struct", False):
  for i, prompt in enumerate(prompts):
- json_struct = ''
- if 'json_struct' in df.columns and i not in empty_prompt_ids:
+ json_struct = ""
+ if "json_struct" in df.columns and i not in empty_prompt_ids:
  # if row has a specific json, we try to use it instead of the base prompt template
  try:
- if isinstance(df['json_struct'][i], str):
- df['json_struct'][i] = json.loads(df['json_struct'][i])
- for ind, val in enumerate(df['json_struct'][i].values()):
- json_struct = json_struct + f'{ind}. {val}\n'
+ if isinstance(df["json_struct"][i], str):
+ df["json_struct"][i] = json.loads(df["json_struct"][i])
+ for ind, val in enumerate(df["json_struct"][i].values()):
+ json_struct = json_struct + f"{ind}. {val}\n"
  except Exception:
  pass # if the row's json is invalid, we use the prompt template instead

- if json_struct == '':
- for ind, val in enumerate(args['json_struct'].values()):
- json_struct = json_struct + f'{ind + 1}. {val}\n'
+ if json_struct == "":
+ for ind, val in enumerate(args["json_struct"].values()):
+ json_struct = json_struct + f"{ind + 1}. {val}\n"

  p = textwrap.dedent(
- f'''\
- Based on the text following 'The reference text is:', assign values to the following {len(args['json_struct'])} JSON attributes:
+ f"""\
+ Based on the text following 'The reference text is:', assign values to the following {len(args["json_struct"])} JSON attributes:
  {{{{json_struct}}}}

  Values should follow the same order as the attributes above.
@@ -456,48 +414,51 @@ class OpenAIHandler(BaseMLEngine):

  Now for the real task. The reference text is:
  {prompt}
- '''
+ """
  )

- p = p.replace('{{json_struct}}', json_struct)
+ p = p.replace("{{json_struct}}", json_struct)
  prompts[i] = p

  # remove prompts without signal from completion queue
  prompts = [j for i, j in enumerate(prompts) if i not in empty_prompt_ids]

  api_key = get_api_key(self.api_key_name, args, self.engine_storage)
- api_args = {
- k: v for k, v in api_args.items() if v is not None
- } # filter out non-specified api args
+ api_args = {k: v for k, v in api_args.items() if v is not None} # filter out non-specified api args
  completion = self._completion(model_name, prompts, api_key, api_args, args, df)

  # add null completion for empty prompts
  for i in sorted(empty_prompt_ids):
  completion.insert(i, None)

- pred_df = pd.DataFrame(completion, columns=[args['target']])
+ pred_df = pd.DataFrame(completion, columns=[args["target"]])

  # restore json struct
- if args.get('json_struct', False):
+ if args.get("json_struct", False):
  for i in pred_df.index:
  try:
- if 'json_struct' in df.columns:
- json_keys = df['json_struct'][i].keys()
+ if "json_struct" in df.columns:
+ json_keys = df["json_struct"][i].keys()
  else:
- json_keys = args['json_struct'].keys()
- responses = pred_df[args['target']][i].split('\n')
+ json_keys = args["json_struct"].keys()
+ responses = pred_df[args["target"]][i].split("\n")
  responses = [x[3:] for x in responses] # del question index

- pred_df[args['target']][i] = {
- key: val for key, val in zip(json_keys, responses)
- }
+ pred_df[args["target"]][i] = {key: val for key, val in zip(json_keys, responses)}
  except Exception:
- pred_df[args['target']][i] = None
+ pred_df[args["target"]][i] = None

  return pred_df

  def _completion(
- self, model_name: Text, prompts: List[Text], api_key: Text, api_args: Dict, args: Dict, df: pd.DataFrame, parallel: bool = True
+ self,
+ model_name: Text,
+ prompts: List[Text],
+ api_key: Text,
+ api_args: Dict,
+ args: Dict,
+ df: pd.DataFrame,
+ parallel: bool = True,
  ) -> List[Any]:
  """
  Handles completion for an arbitrary amount of rows using a model connected to the OpenAI API.
@@ -531,7 +492,9 @@ class OpenAIHandler(BaseMLEngine):
  """

  @retry_with_exponential_backoff()
- def _submit_completion(model_name: Text, prompts: List[Text], api_args: Dict, args: Dict, df: pd.DataFrame) -> List[Text]:
+ def _submit_completion(
+ model_name: Text, prompts: List[Text], api_args: Dict, args: Dict, df: pd.DataFrame
+ ) -> List[Text]:
  """
  Submit a request to the relevant completion endpoint of the OpenAI API based on the type of task.

@@ -546,26 +509,22 @@ class OpenAIHandler(BaseMLEngine):
  List[Text]: List of completions.
  """
  kwargs = {
- 'model': model_name,
+ "model": model_name,
  }
  if model_name in IMAGE_MODELS:
  return _submit_image_completion(kwargs, prompts, api_args)
- elif model_name == 'embedding':
+ elif model_name == "embedding":
  return _submit_embedding_completion(kwargs, prompts, api_args)
  elif model_name in self.chat_completion_models:
  if model_name == "gpt-3.5-turbo-instruct":
- return _submit_normal_completion(
- kwargs,
- prompts,
- api_args
- )
+ return _submit_normal_completion(kwargs, prompts, api_args)
  else:
  return _submit_chat_completion(
  kwargs,
  prompts,
  api_args,
  df,
- mode=args.get('mode', 'conversational'),
+ mode=args.get("mode", "conversational"),
  )
  else:
  return _submit_normal_completion(kwargs, prompts, api_args)
@@ -584,9 +543,9 @@ class OpenAIHandler(BaseMLEngine):
  after_openai_query(params, response)

  params2 = params.copy()
- params2.pop('api_key', None)
- params2.pop('user', None)
- logger.debug(f'>>>openai call: {params2}:\n{response}')
+ params2.pop("api_key", None)
+ params2.pop("user", None)
+ logger.debug(f">>>openai call: {params2}:\n{response}")

  def _submit_normal_completion(kwargs: Dict, prompts: List[Text], api_args: Dict) -> List[Text]:
  """
@@ -616,11 +575,11 @@ class OpenAIHandler(BaseMLEngine):
  """
  tidy_comps = []
  for c in comp.choices:
- if hasattr(c, 'text'):
- tidy_comps.append(c.text.strip('\n').strip(''))
+ if hasattr(c, "text"):
+ tidy_comps.append(c.text.strip("\n").strip(""))
  return tidy_comps

- kwargs['prompt'] = prompts
+ kwargs["prompt"] = prompts
  kwargs = {**kwargs, **api_args}

  before_openai_query(kwargs)
@@ -656,11 +615,11 @@ class OpenAIHandler(BaseMLEngine):
  """
  tidy_comps = []
  for c in comp.data:
- if hasattr(c, 'embedding'):
+ if hasattr(c, "embedding"):
  tidy_comps.append([c.embedding])
  return tidy_comps

- kwargs['input'] = prompts
+ kwargs["input"] = prompts
  kwargs = {**kwargs, **api_args}

  before_openai_query(kwargs)
@@ -668,7 +627,9 @@ class OpenAIHandler(BaseMLEngine):
  _log_api_call(kwargs, resp)
  return resp

- def _submit_chat_completion(kwargs: Dict, prompts: List[Text], api_args: Dict, df: pd.DataFrame, mode: Text = 'conversational') -> List[Text]:
+ def _submit_chat_completion(
+ kwargs: Dict, prompts: List[Text], api_args: Dict, df: pd.DataFrame, mode: Text = "conversational"
+ ) -> List[Text]:
  """
  Submit a request to the chat completion endpoint of the OpenAI API.

@@ -698,48 +659,42 @@ class OpenAIHandler(BaseMLEngine):
  """
  tidy_comps = []
  for c in comp.choices:
- if hasattr(c, 'message'):
- tidy_comps.append(c.message.content.strip('\n').strip(''))
+ if hasattr(c, "message"):
+ tidy_comps.append(c.message.content.strip("\n").strip(""))
  return tidy_comps

  completions = []
- if mode != 'conversational' or 'prompt' not in args:
+ if mode != "conversational" or "prompt" not in args:
  initial_prompt = {
  "role": "system",
  "content": "You are a helpful assistant. Your task is to continue the chat.",
  } # noqa
  else:
  # get prompt from model
- initial_prompt = {"role": "system", "content": args['prompt']} # noqa
+ initial_prompt = {"role": "system", "content": args["prompt"]} # noqa

- kwargs['messages'] = [initial_prompt]
+ kwargs["messages"] = [initial_prompt]
  last_completion_content = None

  for pidx in range(len(prompts)):
- if mode != 'conversational':
- kwargs['messages'].append(
- {'role': 'user', 'content': prompts[pidx]}
- )
+ if mode != "conversational":
+ kwargs["messages"].append({"role": "user", "content": prompts[pidx]})
  else:
  question = prompts[pidx]
  if question:
- kwargs['messages'].append({'role': 'user', 'content': question})
+ kwargs["messages"].append({"role": "user", "content": question})

- assistant_column = args.get('assistant_column')
+ assistant_column = args.get("assistant_column")
  if assistant_column in df.columns:
  answer = df.iloc[pidx][assistant_column]
  else:
  answer = None
  if answer:
- kwargs['messages'].append(
- {'role': 'assistant', 'content': answer}
- )
-
- if mode == 'conversational-full' or (
- mode == 'conversational' and pidx == len(prompts) - 1
- ):
- kwargs['messages'] = truncate_msgs_for_token_limit(
- kwargs['messages'], kwargs['model'], api_args['max_tokens']
+ kwargs["messages"].append({"role": "assistant", "content": answer})
+
+ if mode == "conversational-full" or (mode == "conversational" and pidx == len(prompts) - 1):
+ kwargs["messages"] = truncate_msgs_for_token_limit(
+ kwargs["messages"], kwargs["model"], api_args["max_tokens"]
  )
  pkwargs = {**kwargs, **api_args}

@@ -748,8 +703,8 @@ class OpenAIHandler(BaseMLEngine):
  _log_api_call(pkwargs, resp)

  completions.extend(resp)
- elif mode == 'default':
- kwargs['messages'] = [initial_prompt] + [kwargs['messages'][-1]]
+ elif mode == "default":
+ kwargs["messages"] = [initial_prompt] + [kwargs["messages"][-1]]
  pkwargs = {**kwargs, **api_args}

  before_openai_query(kwargs)
@@ -760,13 +715,11 @@ class OpenAIHandler(BaseMLEngine):
  else:
  # in "normal" conversational mode, we request completions only for the last row
  last_completion_content = None
- completions.extend([''])
+ completions.extend([""])

  if last_completion_content:
  # interleave assistant responses with user input
- kwargs['messages'].append(
- {'role': 'assistant', 'content': last_completion_content[0]}
- )
+ kwargs["messages"].append({"role": "assistant", "content": last_completion_content[0]})

  return completions

@@ -799,46 +752,36 @@ class OpenAIHandler(BaseMLEngine):
  Returns:
  List[Text]: List of image completions as URLs or base64 encoded images.
  """
- return [
- c.url if hasattr(c, 'url') else c.b64_json
- for c in comp
- ]
-
- completions = [
- client.images.generate(**{'prompt': p, **kwargs, **api_args}).data[0]
- for p in prompts
- ]
+ return [c.url if hasattr(c, "url") else c.b64_json for c in comp]
+
+ completions = [client.images.generate(**{"prompt": p, **kwargs, **api_args}).data[0] for p in prompts]
  return _tidy(completions)

  client = self._get_client(
  api_key=api_key,
- base_url=args.get('api_base'),
- org=args.pop('api_organization') if 'api_organization' in args else None,
- args=args
+ base_url=args.get("api_base"),
+ org=args.pop("api_organization") if "api_organization" in args else None,
+ args=args,
  )

  try:
  # check if simple completion works
- completion = _submit_completion(
- model_name, prompts, api_args, args, df
- )
+ completion = _submit_completion(model_name, prompts, api_args, args, df)
  return completion
  except Exception as e:
  # else, we get the max batch size
- if 'you can currently request up to at most a total of' in str(e):
- pattern = 'a total of'
- max_batch_size = int(e[e.find(pattern) + len(pattern):].split(').')[0])
+ if "you can currently request up to at most a total of" in str(e):
+ pattern = "a total of"
+ max_batch_size = int(e[e.find(pattern) + len(pattern) :].split(").")[0])
  else:
- max_batch_size = (
- self.max_batch_size
- ) # guards against changes in the API message
+ max_batch_size = self.max_batch_size # guards against changes in the API message

  if not parallel:
  completion = None
  for i in range(math.ceil(len(prompts) / max_batch_size)):
  partial = _submit_completion(
  model_name,
- prompts[i * max_batch_size: (i + 1) * max_batch_size],
+ prompts[i * max_batch_size : (i + 1) * max_batch_size],
  api_args,
  args,
  df,
@@ -846,20 +789,18 @@ class OpenAIHandler(BaseMLEngine):
  if not completion:
  completion = partial
  else:
- completion['choices'].extend(partial['choices'])
- for field in ('prompt_tokens', 'completion_tokens', 'total_tokens'):
- completion['usage'][field] += partial['usage'][field]
+ completion["choices"].extend(partial["choices"])
+ for field in ("prompt_tokens", "completion_tokens", "total_tokens"):
+ completion["usage"][field] += partial["usage"][field]
  else:
  promises = []
  with concurrent.futures.ThreadPoolExecutor() as executor:
  for i in range(math.ceil(len(prompts) / max_batch_size)):
- logger.debug(
- f'{i * max_batch_size}:{(i+1) * max_batch_size}/{len(prompts)}'
- )
+ logger.debug(f"{i * max_batch_size}:{(i + 1) * max_batch_size}/{len(prompts)}")
  future = executor.submit(
  _submit_completion,
  model_name,
- prompts[i * max_batch_size: (i + 1) * max_batch_size],
+ prompts[i * max_batch_size : (i + 1) * max_batch_size],
  api_args,
  args,
  df,
@@ -868,9 +809,9 @@ class OpenAIHandler(BaseMLEngine):
  completion = None
  for p in promises:
  if not completion:
- completion = p['choices'].result()
+ completion = p["choices"].result()
  else:
- completion.extend(p['choices'].result())
+ completion.extend(p["choices"].result())

  return completion

@@ -886,26 +827,26 @@ class OpenAIHandler(BaseMLEngine):
  """
  # TODO: Update to use update() artifacts

- args = self.model_storage.json_get('args')
+ args = self.model_storage.json_get("args")
  api_key = get_api_key(self.api_key_name, args, self.engine_storage)
- if attribute == 'args':
- return pd.DataFrame(args.items(), columns=['key', 'value'])
- elif attribute == 'metadata':
- model_name = args.get('model_name', self.default_model)
+ if attribute == "args":
+ return pd.DataFrame(args.items(), columns=["key", "value"])
+ elif attribute == "metadata":
+ model_name = args.get("model_name", self.default_model)
  try:
  client = self._get_client(
  api_key=api_key,
- base_url=args.get('api_base'),
- org=args.get('api_organization'),
+ base_url=args.get("api_base"),
+ org=args.get("api_organization"),
  args=args,
  )
  meta = client.models.retrieve(model_name)
  except Exception as e:
- meta = {'error': str(e)}
- return pd.DataFrame(dict(meta).items(), columns=['key', 'value'])
+ meta = {"error": str(e)}
+ return pd.DataFrame(dict(meta).items(), columns=["key", "value"])
  else:
- tables = ['args', 'metadata']
- return pd.DataFrame(tables, columns=['tables'])
+ tables = ["args", "metadata"]
+ return pd.DataFrame(tables, columns=["tables"])

  def finetune(self, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> None:
  """
@@ -937,14 +878,14 @@ class OpenAIHandler(BaseMLEngine):

  api_key = get_api_key(self.api_key_name, args, self.engine_storage)

- using_args = args.pop('using') if 'using' in args else {}
+ using_args = args.pop("using") if "using" in args else {}

- api_base = using_args.get('api_base', os.environ.get('OPENAI_API_BASE', OPENAI_API_BASE))
- org = using_args.get('api_organization')
+ api_base = using_args.get("api_base", os.environ.get("OPENAI_API_BASE", OPENAI_API_BASE))
+ org = using_args.get("api_organization")
  client = self._get_client(api_key=api_key, base_url=api_base, org=org, args=args)

  args = {**using_args, **args}
- prev_model_name = self.base_model_storage.json_get('args').get('model_name', '')
+ prev_model_name = self.base_model_storage.json_get("args").get("model_name", "")

  if prev_model_name not in self.supported_ft_models:
  # base model may be already FTed, check prefixes
@@ -956,80 +897,73 @@ class OpenAIHandler(BaseMLEngine):
  f"This model cannot be finetuned. Supported base models are {self.supported_ft_models}."
  )

- finetune_time = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
+ finetune_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

  temp_storage_path = tempfile.mkdtemp()
  temp_file_name = f"ft_{finetune_time}"
  temp_model_storage_path = f"{temp_storage_path}/{temp_file_name}.jsonl"

- file_names = self._prepare_ft_jsonl(
- df, temp_storage_path, temp_file_name, temp_model_storage_path
- )
+ file_names = self._prepare_ft_jsonl(df, temp_storage_path, temp_file_name, temp_model_storage_path)

  jsons = {k: None for k in file_names.keys()}
  for split, file_name in file_names.items():
  if os.path.isfile(os.path.join(temp_storage_path, file_name)):
  jsons[split] = client.files.create(
- file=open(f"{temp_storage_path}/{file_name}", "rb"),
- purpose='fine-tune'
+ file=open(f"{temp_storage_path}/{file_name}", "rb"), purpose="fine-tune"
  )

- if type(jsons['train']) is openai.types.FileObject:
- train_file_id = jsons['train'].id
+ if type(jsons["train"]) is openai.types.FileObject:
+ train_file_id = jsons["train"].id
  else:
- train_file_id = jsons['base'].id
+ train_file_id = jsons["base"].id

- if type(jsons['val']) is openai.types.FileObject:
- val_file_id = jsons['val'].id
+ if type(jsons["val"]) is openai.types.FileObject:
+ val_file_id = jsons["val"].id
  else:
  val_file_id = None

  # `None` values are internally imputed by OpenAI to `null` or default values
  ft_params = {
- 'training_file': train_file_id,
- 'validation_file': val_file_id,
- 'model': self._get_ft_model_type(prev_model_name),
+ "training_file": train_file_id,
+ "validation_file": val_file_id,
+ "model": self._get_ft_model_type(prev_model_name),
  }
  ft_params = self._add_extra_ft_params(ft_params, using_args)

  start_time = datetime.datetime.now()

- ft_stats, result_file_id = self._ft_call(ft_params, client, args.get('hour_budget', 8))
+ ft_stats, result_file_id = self._ft_call(ft_params, client, args.get("hour_budget", 8))
  ft_model_name = ft_stats.fine_tuned_model

  end_time = datetime.datetime.now()
  runtime = end_time - start_time
  name_extension = client.files.retrieve(file_id=result_file_id).filename
- result_path = f'{temp_storage_path}/ft_{finetune_time}_result_{name_extension}'
+ result_path = f"{temp_storage_path}/ft_{finetune_time}_result_{name_extension}"

  try:
  client.files.content(file_id=result_file_id).stream_to_file(result_path)
- if '.csv' in name_extension:
+ if ".csv" in name_extension:
  # legacy endpoint
  train_stats = pd.read_csv(result_path)
- if 'validation_token_accuracy' in train_stats.columns:
- train_stats = train_stats[
- train_stats['validation_token_accuracy'].notnull()
- ]
- args['ft_api_info'] = ft_stats.dict()
- args['ft_result_stats'] = train_stats.to_dict()
-
- elif '.json' in name_extension:
- train_stats = pd.read_json(
- path_or_buf=result_path, lines=True
- ) # new endpoint
- args['ft_api_info'] = args['ft_result_stats'] = train_stats.to_dict()
+ if "validation_token_accuracy" in train_stats.columns:
+ train_stats = train_stats[train_stats["validation_token_accuracy"].notnull()]
+ args["ft_api_info"] = ft_stats.dict()
+ args["ft_result_stats"] = train_stats.to_dict()
+
+ elif ".json" in name_extension:
+ train_stats = pd.read_json(path_or_buf=result_path, lines=True) # new endpoint
+ args["ft_api_info"] = args["ft_result_stats"] = train_stats.to_dict()

  except Exception:
- logger.info(f'Error retrieving fine-tuning results. Please check manually for information on job {ft_stats.id} (result file {result_file_id}).')
+ logger.info(
+ f"Error retrieving fine-tuning results. Please check manually for information on job {ft_stats.id} (result file {result_file_id})."
+ )

- args['model_name'] = ft_model_name
- args['runtime'] = runtime.total_seconds()
- args['mode'] = self.base_model_storage.json_get('args').get(
- 'mode', self.default_mode
- )
+ args["model_name"] = ft_model_name
+ args["runtime"] = runtime.total_seconds()
+ args["mode"] = self.base_model_storage.json_get("args").get("mode", self.default_mode)

- self.model_storage.json_set('args', args)
+ self.model_storage.json_set("args", args)
  shutil.rmtree(temp_storage_path)

  @staticmethod
@@ -1045,7 +979,7 @@ class OpenAIHandler(BaseMLEngine):
  Returns:
  Dict: File names for the fine-tuning process.
  """
- df.to_json(temp_model_path, orient='records', lines=True)
+ df.to_json(temp_model_path, orient="records", lines=True)

  # TODO avoid subprocess usage once OpenAI enables non-CLI access, or refactor to use our own LLM utils instead
  subprocess.run(
@@ -1055,7 +989,7 @@ class OpenAIHandler(BaseMLEngine):
  "fine_tunes.prepare_data",
  "-f",
  temp_model_path, # from file
- '-q', # quiet mode (accepts all suggestions)
+ "-q", # quiet mode (accepts all suggestions)
  ],
  stdout=subprocess.PIPE,
  stderr=subprocess.PIPE,
@@ -1063,10 +997,10 @@ class OpenAIHandler(BaseMLEngine):
  )

  file_names = {
- 'original': f'{temp_filename}.jsonl',
- 'base': f'{temp_filename}_prepared.jsonl',
- 'train': f'{temp_filename}_prepared_train.jsonl',
- 'val': f'{temp_filename}_prepared_valid.jsonl',
+ "original": f"{temp_filename}.jsonl",
+ "base": f"{temp_filename}_prepared.jsonl",
+ "train": f"{temp_filename}_prepared_train.jsonl",
+ "val": f"{temp_filename}_prepared_valid.jsonl",
  }
  return file_names

@@ -1083,7 +1017,7 @@ class OpenAIHandler(BaseMLEngine):
  for model_type in self.supported_ft_models:
  if model_type in model_name.lower():
  return model_type
- return 'babbage-002'
+ return "babbage-002"

  @staticmethod
  def _add_extra_ft_params(ft_params: Dict, using_args: Dict) -> Dict:
@@ -1098,22 +1032,14 @@ class OpenAIHandler(BaseMLEngine):
  Dict: Fine-tuning parameters with extra parameters.
  """
  extra_params = {
- 'n_epochs': using_args.get('n_epochs', None),
- 'batch_size': using_args.get('batch_size', None),
- 'learning_rate_multiplier': using_args.get(
- 'learning_rate_multiplier', None
- ),
- 'prompt_loss_weight': using_args.get('prompt_loss_weight', None),
- 'compute_classification_metrics': using_args.get(
- 'compute_classification_metrics', None
- ),
- 'classification_n_classes': using_args.get(
- 'classification_n_classes', None
- ),
- 'classification_positive_class': using_args.get(
- 'classification_positive_class', None
- ),
- 'classification_betas': using_args.get('classification_betas', None),
+ "n_epochs": using_args.get("n_epochs", None),
+ "batch_size": using_args.get("batch_size", None),
+ "learning_rate_multiplier": using_args.get("learning_rate_multiplier", None),
+ "prompt_loss_weight": using_args.get("prompt_loss_weight", None),
+ "compute_classification_metrics": using_args.get("compute_classification_metrics", None),
+ "classification_n_classes": using_args.get("classification_n_classes", None),
+ "classification_positive_class": using_args.get("classification_positive_class", None),
+ "classification_betas": using_args.get("classification_betas", None),
  }
  return {**ft_params, **extra_params}

@@ -1137,9 +1063,7 @@ class OpenAIHandler(BaseMLEngine):
  Returns:
  Tuple[FineTuningJob, Text]: Fine-tuning stats and result file ID.
  """
- ft_result = client.fine_tuning.jobs.create(
- **{k: v for k, v in ft_params.items() if v is not None}
- )
+ ft_result = client.fine_tuning.jobs.create(**{k: v for k, v in ft_params.items() if v is not None})

  @retry_with_exponential_backoff(
  hour_budget=hour_budget,
@@ -1158,22 +1082,22 @@ class OpenAIHandler(BaseMLEngine):
  FineTuningJob: Fine-tuning stats.
  """
  ft_retrieved = client.fine_tuning.jobs.retrieve(fine_tuning_job_id=job_id)
- if ft_retrieved.status in ('succeeded', 'failed', 'cancelled'):
+ if ft_retrieved.status in ("succeeded", "failed", "cancelled"):
  return ft_retrieved
  else:
- raise PendingFT('Fine-tuning still pending!')
+ raise PendingFT("Fine-tuning still pending!")

  ft_stats = _check_ft_status(ft_result.id)

- if ft_stats.status != 'succeeded':
- err_message = ft_stats.events[-1].message if hasattr(ft_stats, 'events') else 'could not retrieve!'
- ft_status = ft_stats.status if hasattr(ft_stats, 'status') else 'N/A'
+ if ft_stats.status != "succeeded":
+ err_message = ft_stats.events[-1].message if hasattr(ft_stats, "events") else "could not retrieve!"
+ ft_status = ft_stats.status if hasattr(ft_stats, "status") else "N/A"
  raise Exception(
  f"Fine-tuning did not complete successfully (status: {ft_status}). Error message: {err_message}"
  ) # noqa

  result_file_id = client.fine_tuning.jobs.retrieve(fine_tuning_job_id=ft_result.id).result_files[0]
- if hasattr(result_file_id, 'id'):
+ if hasattr(result_file_id, "id"):
  result_file_id = result_file_id.id # legacy endpoint

  return ft_stats, result_file_id
@@ -1191,11 +1115,8 @@ class OpenAIHandler(BaseMLEngine):
  Returns:
  openai.OpenAI: OpenAI client.
  """
- if args is not None and args.get('provider') == 'azure':
+ if args is not None and args.get("provider") == "azure":
  return AzureOpenAI(
- api_key=api_key,
- azure_endpoint=base_url,
- api_version=args.get('api_version'),
- organization=org
+ api_key=api_key, azure_endpoint=base_url, api_version=args.get("api_version"), organization=org
  )
  return OpenAI(api_key=api_key, base_url=base_url, organization=org)
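
Two patterns in the diff above are worth unpacking. First, the behavioral change in create(): in 25.6.2.0 any explicitly supplied model_name absent from the client's model listing raised, while 25.6.3.0 exempts embedding mode from the check. A minimal sketch of the new condition, assuming args comes from the USING clause and available_models from get_available_models(client); check_model_name and the example values are hypothetical stand-ins, not the handler's real API:

from typing import Dict, List


def check_model_name(args: Dict, available_models: List[str]) -> None:
    # Mirrors the elif in create(): the check only fires when a model_name
    # was explicitly supplied.
    if (
        args.get("model_name")
        and args["model_name"] not in available_models
        and args.get("mode") != "embedding"  # condition added in 25.6.3.0
    ):
        raise Exception(f"Invalid model name. Please use one of {available_models}")


# Hypothetical values: an embedding model name missing from the listing no
# longer raises in 25.6.3.0, but the same name in any other mode still would.
check_model_name({"model_name": "custom-embedder", "mode": "embedding"}, ["gpt-4o"])

Second, the reflowed _completion hunks show the handler's batching fallback: when a single request fails, a maximum batch size is parsed from the API error message (or self.max_batch_size, 20 by default, is used), and the prompts are re-submitted in slices of that size, in parallel via ThreadPoolExecutor unless parallel=False. A self-contained sketch of that pattern, with submit standing in for _submit_completion:

import math
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List


def submit_in_batches(
    prompts: List[str],
    submit: Callable[[List[str]], List[str]],
    max_batch_size: int = 20,
) -> List[str]:
    # Slice the prompts exactly as _completion does:
    # prompts[i * max_batch_size : (i + 1) * max_batch_size]
    chunks = [
        prompts[i * max_batch_size : (i + 1) * max_batch_size]
        for i in range(math.ceil(len(prompts) / max_batch_size))
    ]
    results: List[str] = []
    with ThreadPoolExecutor() as executor:
        # executor.map preserves chunk order, so results line up with prompts
        for partial in executor.map(submit, chunks):
            results.extend(partial)
    return results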