MindsDB 25.5.4.0__py3-none-any.whl → 25.5.4.1__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of MindsDB (25.5.4.1) might be problematic; see the package's registry listing for more details.

@@ -20,121 +20,109 @@ from mindsdb.integrations.libs.response import HandlerStatusResponse
20
20
  logger = log.getLogger(__name__)
21
21
 
22
22
 
23
- @ns_conf.route('/')
24
- @ns_conf.param('name', 'Get config')
23
+ @ns_conf.route("/")
24
+ @ns_conf.param("name", "Get config")
25
25
  class GetConfig(Resource):
26
- @ns_conf.doc('get_config')
27
- @api_endpoint_metrics('GET', '/config')
26
+ @ns_conf.doc("get_config")
27
+ @api_endpoint_metrics("GET", "/config")
28
28
  def get(self):
29
29
  config = Config()
30
- resp = {
31
- 'auth': {
32
- 'http_auth_enabled': config['auth']['http_auth_enabled']
33
- }
34
- }
35
- for key in ['default_llm', 'default_embedding_model']:
30
+ resp = {"auth": {"http_auth_enabled": config["auth"]["http_auth_enabled"]}}
31
+ for key in ["default_llm", "default_embedding_model", "default_reranking_model"]:
36
32
  value = config.get(key)
37
33
  if value is not None:
38
34
  resp[key] = value
39
35
  return resp
40
36
 
41
- @ns_conf.doc('put_config')
42
- @api_endpoint_metrics('PUT', '/config')
37
+ @ns_conf.doc("put_config")
38
+ @api_endpoint_metrics("PUT", "/config")
43
39
  def put(self):
44
40
  data = request.json
45
41
 
46
- allowed_arguments = {'auth', 'default_llm', 'default_embedding_model'}
42
+ allowed_arguments = {"auth", "default_llm", "default_embedding_model", "default_reranking_model"}
47
43
  unknown_arguments = list(set(data.keys()) - allowed_arguments)
48
44
  if len(unknown_arguments) > 0:
49
- return http_error(
50
- HTTPStatus.BAD_REQUEST, 'Wrong arguments',
51
- f'Unknown argumens: {unknown_arguments}'
52
- )
45
+ return http_error(HTTPStatus.BAD_REQUEST, "Wrong arguments", f"Unknown argumens: {unknown_arguments}")
53
46
 
54
- nested_keys_to_validate = {'auth'}
47
+ nested_keys_to_validate = {"auth"}
55
48
  for key in data.keys():
56
49
  if key in nested_keys_to_validate:
57
- unknown_arguments = list(
58
- set(data[key].keys()) - set(Config()[key].keys())
59
- )
50
+ unknown_arguments = list(set(data[key].keys()) - set(Config()[key].keys()))
60
51
  if len(unknown_arguments) > 0:
61
52
  return http_error(
62
- HTTPStatus.BAD_REQUEST, 'Wrong arguments',
63
- f'Unknown argumens: {unknown_arguments}'
53
+ HTTPStatus.BAD_REQUEST, "Wrong arguments", f"Unknown argumens: {unknown_arguments}"
64
54
  )
65
55
 
66
56
  Config().update(data)
67
57
 
68
- return '', 200
58
+ return "", 200
69
59
 
70
60
 
71
- @ns_conf.route('/integrations')
72
- @ns_conf.param('name', 'List all database integration')
61
+ @ns_conf.route("/integrations")
62
+ @ns_conf.param("name", "List all database integration")
73
63
  class ListIntegration(Resource):
74
- @api_endpoint_metrics('GET', '/config/integrations')
64
+ @api_endpoint_metrics("GET", "/config/integrations")
75
65
  def get(self):
76
- return {
77
- 'integrations': [k for k in ca.integration_controller.get_all(show_secrets=False)]
78
- }
66
+ return {"integrations": [k for k in ca.integration_controller.get_all(show_secrets=False)]}
79
67
 
80
68
 
81
- @ns_conf.route('/all_integrations')
82
- @ns_conf.param('name', 'List all database integration')
69
+ @ns_conf.route("/all_integrations")
70
+ @ns_conf.param("name", "List all database integration")
83
71
  class AllIntegration(Resource):
84
- @ns_conf.doc('get_all_integrations')
85
- @api_endpoint_metrics('GET', '/config/all_integrations')
72
+ @ns_conf.doc("get_all_integrations")
73
+ @api_endpoint_metrics("GET", "/config/all_integrations")
86
74
  def get(self):
87
75
  integrations = ca.integration_controller.get_all(show_secrets=False)
88
76
  return integrations
89
77
 
90
78
 
91
- @ns_conf.route('/integrations/<name>')
92
- @ns_conf.param('name', 'Database integration')
79
+ @ns_conf.route("/integrations/<name>")
80
+ @ns_conf.param("name", "Database integration")
93
81
  class Integration(Resource):
94
- @ns_conf.doc('get_integration')
95
- @api_endpoint_metrics('GET', '/config/integrations/integration')
82
+ @ns_conf.doc("get_integration")
83
+ @api_endpoint_metrics("GET", "/config/integrations/integration")
96
84
  def get(self, name):
97
85
  integration = ca.integration_controller.get(name, show_secrets=False)
98
86
  if integration is None:
99
- return http_error(HTTPStatus.NOT_FOUND, 'Not found', f'Can\'t find integration: {name}')
87
+ return http_error(HTTPStatus.NOT_FOUND, "Not found", f"Can't find integration: {name}")
100
88
  integration = copy.deepcopy(integration)
101
89
  return integration
102
90
 
103
- @ns_conf.doc('put_integration')
104
- @api_endpoint_metrics('PUT', '/config/integrations/integration')
91
+ @ns_conf.doc("put_integration")
92
+ @api_endpoint_metrics("PUT", "/config/integrations/integration")
105
93
  def put(self, name):
106
94
  params = {}
107
95
  if request.is_json:
108
- params.update((request.json or {}).get('params', {}))
96
+ params.update((request.json or {}).get("params", {}))
109
97
  else:
110
98
  params.update(request.form or {})
111
99
 
112
100
  if len(params) == 0:
113
- return http_error(HTTPStatus.BAD_REQUEST, 'Wrong argument', "type of 'params' must be dict")
101
+ return http_error(HTTPStatus.BAD_REQUEST, "Wrong argument", "type of 'params' must be dict")
114
102
 
115
103
  files = request.files
116
104
  temp_dir = None
117
105
  if files is not None and len(files) > 0:
118
- temp_dir = tempfile.mkdtemp(prefix='integration_files_')
106
+ temp_dir = tempfile.mkdtemp(prefix="integration_files_")
119
107
  for key, file in files.items():
120
108
  temp_dir_path = Path(temp_dir)
121
109
  file_name = Path(file.filename)
122
110
  file_path = temp_dir_path.joinpath(file_name).resolve()
123
111
  if temp_dir_path not in file_path.parents:
124
- raise Exception(f'Can not save file at path: {file_path}')
112
+ raise Exception(f"Can not save file at path: {file_path}")
125
113
  file.save(file_path)
126
114
  params[key] = str(file_path)
127
115
 
128
- is_test = params.get('test', False)
116
+ is_test = params.get("test", False)
129
117
  # TODO: Move this to new Endpoint
130
118
 
131
119
  config = Config()
132
- secret_key = config.get('secret_key', 'dummy-key')
120
+ secret_key = config.get("secret_key", "dummy-key")
133
121
 
134
122
  if is_test:
135
- del params['test']
136
- handler_type = params.pop('type', None)
137
- params.pop('publish', None)
123
+ del params["test"]
124
+ handler_type = params.pop("type", None)
125
+ params.pop("publish", None)
138
126
  try:
139
127
  handler = ca.integration_controller.create_tmp_handler(name, handler_type, params)
140
128
  status = handler.check_connection()
@@ -145,33 +133,32 @@ class Integration(Resource):
145
133
 
146
134
  resp = status.to_json()
147
135
 
148
- if status.success and 'code' in params:
149
- if hasattr(handler, 'handler_storage'):
136
+ if status.success and "code" in params:
137
+ if hasattr(handler, "handler_storage"):
150
138
  # attach storage if exists
151
139
  export = handler.handler_storage.export_files()
152
140
  if export:
153
141
  # encrypt with flask secret key
154
142
  encrypted = encrypt(export, secret_key)
155
- resp['storage'] = encrypted.decode()
143
+ resp["storage"] = encrypted.decode()
156
144
 
157
145
  return resp, 200
158
146
 
159
147
  config = Config()
160
- secret_key = config.get('secret_key', 'dummy-key')
148
+ secret_key = config.get("secret_key", "dummy-key")
161
149
 
162
150
  integration = ca.integration_controller.get(name, show_secrets=False)
163
151
  if integration is not None:
164
152
  return http_error(
165
- HTTPStatus.BAD_REQUEST, 'Wrong argument',
166
- f"Integration with name '{name}' already exists"
153
+ HTTPStatus.BAD_REQUEST, "Wrong argument", f"Integration with name '{name}' already exists"
167
154
  )
168
155
 
169
156
  try:
170
- engine = params['type']
157
+ engine = params["type"]
171
158
  if engine is not None:
172
- del params['type']
173
- params.pop('publish', False)
174
- storage = params.pop('storage', None)
159
+ del params["type"]
160
+ params.pop("publish", False)
161
+ storage = params.pop("storage", None)
175
162
  ca.integration_controller.add(name, engine, params)
176
163
 
177
164
  # copy storage
@@ -185,62 +172,50 @@ class Integration(Resource):
185
172
  logger.error(str(e))
186
173
  if temp_dir is not None:
187
174
  shutil.rmtree(temp_dir)
188
- return http_error(
189
- HTTPStatus.INTERNAL_SERVER_ERROR, 'Error',
190
- f'Error during config update: {str(e)}'
191
- )
175
+ return http_error(HTTPStatus.INTERNAL_SERVER_ERROR, "Error", f"Error during config update: {str(e)}")
192
176
 
193
177
  if temp_dir is not None:
194
178
  shutil.rmtree(temp_dir)
195
179
  return {}, 200
196
180
 
197
- @ns_conf.doc('delete_integration')
198
- @api_endpoint_metrics('DELETE', '/config/integrations/integration')
181
+ @ns_conf.doc("delete_integration")
182
+ @api_endpoint_metrics("DELETE", "/config/integrations/integration")
199
183
  def delete(self, name):
200
184
  integration = ca.integration_controller.get(name)
201
185
  if integration is None:
202
186
  return http_error(
203
- HTTPStatus.BAD_REQUEST, 'Integration does not exists',
204
- f"Nothing to delete. '{name}' not exists."
187
+ HTTPStatus.BAD_REQUEST, "Integration does not exists", f"Nothing to delete. '{name}' not exists."
205
188
  )
206
189
  try:
207
190
  ca.integration_controller.delete(name)
208
191
  except Exception as e:
209
192
  logger.error(str(e))
210
- return http_error(
211
- HTTPStatus.INTERNAL_SERVER_ERROR, 'Error',
212
- f"Error during integration delete: {str(e)}"
213
- )
193
+ return http_error(HTTPStatus.INTERNAL_SERVER_ERROR, "Error", f"Error during integration delete: {str(e)}")
214
194
  return "", 200
215
195
 
216
- @ns_conf.doc('modify_integration')
217
- @api_endpoint_metrics('POST', '/config/integrations/integration')
196
+ @ns_conf.doc("modify_integration")
197
+ @api_endpoint_metrics("POST", "/config/integrations/integration")
218
198
  def post(self, name):
219
199
  params = {}
220
- params.update((request.json or {}).get('params', {}))
200
+ params.update((request.json or {}).get("params", {}))
221
201
  params.update(request.form or {})
222
202
 
223
203
  if not isinstance(params, dict):
224
- return http_error(
225
- HTTPStatus.BAD_REQUEST, 'Wrong argument',
226
- "type of 'params' must be dict"
227
- )
204
+ return http_error(HTTPStatus.BAD_REQUEST, "Wrong argument", "type of 'params' must be dict")
228
205
  integration = ca.integration_controller.get(name)
229
206
  if integration is None:
230
207
  return http_error(
231
- HTTPStatus.BAD_REQUEST, 'Integration does not exists',
232
- f"Nothin to modify. '{name}' not exists."
208
+ HTTPStatus.BAD_REQUEST, "Integration does not exists", f"Nothin to modify. '{name}' not exists."
233
209
  )
234
210
  try:
235
- if 'enabled' in params:
236
- params['publish'] = params['enabled']
237
- del params['enabled']
211
+ if "enabled" in params:
212
+ params["publish"] = params["enabled"]
213
+ del params["enabled"]
238
214
  ca.integration_controller.modify(name, params)
239
215
 
240
216
  except Exception as e:
241
217
  logger.error(str(e))
242
218
  return http_error(
243
- HTTPStatus.INTERNAL_SERVER_ERROR, 'Error',
244
- f"Error during integration modification: {str(e)}"
219
+ HTTPStatus.INTERNAL_SERVER_ERROR, "Error", f"Error during integration modification: {str(e)}"
245
220
  )
246
221
  return "", 200
@@ -1,2 +1 @@
1
- virtualenv
2
- pyarrow==19.0.0
1
+ virtualenv
@@ -1,3 +1,2 @@
1
1
  lancedb~=0.3.1
2
2
  lance
3
- pyarrow~=19.0.0
@@ -1,6 +1,8 @@
1
1
  import ast
2
2
  from typing import Dict, Optional, List
3
3
 
4
+
5
+ from litellm import completion, batch_completion, embedding
4
6
  import pandas as pd
5
7
 
6
8
  from mindsdb.integrations.libs.base import BaseMLEngine
@@ -8,8 +10,6 @@ from mindsdb.utilities import log
8
10
 
9
11
  from mindsdb.integrations.handlers.litellm_handler.settings import CompletionParameters
10
12
 
11
- from litellm import completion, batch_completion
12
-
13
13
 
14
14
  logger = log.getLogger(__name__)
15
15
 
@@ -28,10 +28,24 @@ class LiteLLMHandler(BaseMLEngine):
28
28
  @staticmethod
29
29
  def create_validation(target, args=None, **kwargs):
30
30
  if "using" not in args:
31
- raise Exception(
32
- "Litellm engine requires a USING clause. See settings.py for more info on supported args."
31
+ raise Exception("Litellm engine requires a USING clause. See settings.py for more info on supported args.")
32
+
33
+ @staticmethod
34
+ def embeddings(model: str, messages: List[str], args: dict) -> List[list]:
35
+ response = embedding(model=model, input=messages, **args)
36
+ return [rec["embedding"] for rec in response.data]
37
+
38
+ @staticmethod
39
+ async def acompletion(model: str, messages: List[dict], args: dict):
40
+ if model.startswith("snowflake/") and "snowflake_account_id" in args:
41
+ args["api_base"] = (
42
+ f"https://{args['snowflake_account_id']}.snowflakecomputing.com/api/v2/cortex/inference:complete"
33
43
  )
34
44
 
45
+ from litellm import acompletion
46
+
47
+ return await acompletion(model=model, messages=messages, stream=False, **args)
48
+
35
49
  def create(
36
50
  self,
37
51
  target: str,
@@ -70,9 +84,9 @@ class LiteLLMHandler(BaseMLEngine):
70
84
  self._build_messages(args, df)
71
85
 
72
86
  # remove prompt_template from args
73
- args.pop('prompt_template', None)
87
+ args.pop("prompt_template", None)
74
88
 
75
- if len(args['messages']) > 1:
89
+ if len(args["messages"]) > 1:
76
90
  # if more than one message, use batch completion
77
91
  responses = batch_completion(**args)
78
92
  return pd.DataFrame({"result": [response.choices[0].message.content for response in responses]})
@@ -103,36 +117,39 @@ class LiteLLMHandler(BaseMLEngine):
103
117
 
104
118
  if "prompt_template" in prompt_kwargs:
105
119
  # if prompt_template is passed in predict query, use it
106
- logger.info("Using 'prompt_template' passed in SELECT Predict query. "
107
- "Note this will overwrite a 'prompt_template' passed in create MODEL query.")
120
+ logger.info(
121
+ "Using 'prompt_template' passed in SELECT Predict query. "
122
+ "Note this will overwrite a 'prompt_template' passed in create MODEL query."
123
+ )
108
124
 
109
- args['prompt_template'] = prompt_kwargs.pop('prompt_template')
125
+ args["prompt_template"] = prompt_kwargs.pop("prompt_template")
110
126
 
111
- if 'mock_response' in prompt_kwargs:
127
+ if "mock_response" in prompt_kwargs:
112
128
  # used for testing to save on real completion api calls
113
- args['mock_response']: str = prompt_kwargs.pop('mock_response')
129
+ args["mock_response"]: str = prompt_kwargs.pop("mock_response")
114
130
 
115
- if 'messages' in prompt_kwargs and len(prompt_kwargs) > 1:
131
+ if "messages" in prompt_kwargs and len(prompt_kwargs) > 1:
116
132
  # if user passes in messages, no other args can be passed in
117
- raise Exception(
118
- "If 'messages' is passed in SELECT Predict query, no other args can be passed in."
119
- )
133
+ raise Exception("If 'messages' is passed in SELECT Predict query, no other args can be passed in.")
120
134
 
121
135
  # if user passes in messages, use those instead
122
- if 'messages' in prompt_kwargs:
136
+ if "messages" in prompt_kwargs:
123
137
  logger.info("Using messages passed in SELECT Predict query. 'prompt_template' will be ignored.")
124
138
 
125
- args['messages']: List = ast.literal_eval(df['messages'].iloc[0])
139
+ args["messages"]: List = ast.literal_eval(df["messages"].iloc[0])
126
140
 
127
141
  else:
128
142
  # if user passes in prompt_template, use that to create messages
129
143
  if len(prompt_kwargs) == 1:
130
- args['messages'] = self._prompt_to_messages(args['prompt_template'], **prompt_kwargs) \
131
- if args['prompt_template'] else self._prompt_to_messages(df.iloc[0][0])
144
+ args["messages"] = (
145
+ self._prompt_to_messages(args["prompt_template"], **prompt_kwargs)
146
+ if args["prompt_template"]
147
+ else self._prompt_to_messages(df.iloc[0][0])
148
+ )
132
149
 
133
150
  elif len(prompt_kwargs) > 1:
134
151
  try:
135
- args['messages'] = self._prompt_to_messages(args['prompt_template'], **prompt_kwargs)
152
+ args["messages"] = self._prompt_to_messages(args["prompt_template"], **prompt_kwargs)
136
153
  except KeyError as e:
137
154
  raise Exception(
138
155
  f"{e}: Please pass in either a prompt_template on create MODEL or "
@@ -114,3 +114,16 @@ class GoogleConfig(BaseLLMConfig):
114
114
  top_k: Optional[int] = Field(default=None, description="Number of highest probability tokens to consider")
115
115
  max_output_tokens: Optional[int] = Field(default=None, description="Maximum number of tokens to generate")
116
116
  google_api_key: Optional[str] = Field(default=None, description="API key for Google Generative AI")
117
+
118
+
119
+ # See https://api.python.langchain.com/en/latest/llms/langchain_community.llms.writer.Writer.html
120
+ class WriterConfig(BaseLLMConfig):
121
+ model_name: str = Field(default="palmyra-x5", alias="model_id")
122
+ temperature: Optional[float] = Field(default=0.7)
123
+ max_tokens: Optional[int] = Field(default=None)
124
+ top_p: Optional[float] = Field(default=None)
125
+ stop: Optional[List[str]] = Field(default=None)
126
+ best_of: Optional[int] = Field(default=None)
127
+ writer_api_key: Optional[str] = Field(default=None)
128
+ writer_org_id: Optional[str] = Field(default=None)
129
+ base_url: Optional[str] = Field(default=None)
@@ -16,6 +16,7 @@ from mindsdb.integrations.libs.llm.config import (
16
16
  OpenAIConfig,
17
17
  NvidiaNIMConfig,
18
18
  MindsdbConfig,
19
+ WriterConfig,
19
20
  )
20
21
  from mindsdb.utilities.config import config
21
22
  from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
@@ -41,16 +42,12 @@ DEFAULT_LITELLM_BASE_URL = "https://ai.dev.mindsdb.com"
41
42
  DEFAULT_OLLAMA_BASE_URL = "http://localhost:11434"
42
43
  DEFAULT_OLLAMA_MODEL = "llama2"
43
44
 
44
- DEFAULT_NVIDIA_NIM_BASE_URL = (
45
- "http://localhost:8000/v1" # Assumes local port forwarding through ssh
46
- )
45
+ DEFAULT_NVIDIA_NIM_BASE_URL = "http://localhost:8000/v1" # Assumes local port forwarding through ssh
47
46
  DEFAULT_NVIDIA_NIM_MODEL = "meta/llama-3_1-8b-instruct"
48
47
  DEFAULT_VLLM_SERVER_URL = "http://localhost:8000/v1"
49
48
 
50
49
 
51
- def get_completed_prompts(
52
- base_template: str, df: pd.DataFrame, strict=True
53
- ) -> Tuple[List[str], np.ndarray]:
50
+ def get_completed_prompts(base_template: str, df: pd.DataFrame, strict=True) -> Tuple[List[str], np.ndarray]:
54
51
  """
55
52
  Helper method that produces formatted prompts given a template and data in a Pandas DataFrame.
56
53
  It also returns the ID of any empty templates that failed to be filled due to missing data.
@@ -69,9 +66,7 @@ def get_completed_prompts(
69
66
  if len(matches) == 0:
70
67
  # no placeholders
71
68
  if strict:
72
- raise AssertionError(
73
- "No placeholders found in the prompt, please provide a valid prompt template."
74
- )
69
+ raise AssertionError("No placeholders found in the prompt, please provide a valid prompt template.")
75
70
  prompts = [base_template] * len(df)
76
71
  return prompts, np.ndarray(0)
77
72
 
@@ -95,12 +90,8 @@ def get_completed_prompts(
95
90
  for i in range(len(template)):
96
91
  atom = template[i]
97
92
  if i < len(columns):
98
- col = df[columns[i]].replace(
99
- to_replace=[None], value=""
100
- ) # add empty quote if data is missing
101
- df["__mdb_prompt"] = df["__mdb_prompt"].apply(
102
- lambda x: x + atom
103
- ) + col.astype("string")
93
+ col = df[columns[i]].replace(to_replace=[None], value="") # add empty quote if data is missing
94
+ df["__mdb_prompt"] = df["__mdb_prompt"].apply(lambda x: x + atom) + col.astype("string")
104
95
  else:
105
96
  df["__mdb_prompt"] = df["__mdb_prompt"].apply(lambda x: x + atom)
106
97
  prompts = list(df["__mdb_prompt"])
@@ -119,8 +110,7 @@ def get_llm_config(provider: str, args: Dict) -> BaseLLMConfig:
119
110
  """
120
111
  temperature = min(1.0, max(0.0, args.get("temperature", 0.0)))
121
112
  if provider == "openai":
122
-
123
- if any(x in args.get("model_name", "") for x in ['o1', 'o3']):
113
+ if any(x in args.get("model_name", "") for x in ["o1", "o3"]):
124
114
  # for o1 and 03, 'temperature' does not support 0.0 with this model. Only the default (1) value is supported
125
115
  temperature = 1
126
116
 
@@ -173,9 +163,7 @@ def get_llm_config(provider: str, args: Dict) -> BaseLLMConfig:
173
163
  max_tokens=args.get("max_tokens", DEFAULT_OPENAI_MAX_TOKENS),
174
164
  top_p=args.get("top_p", None),
175
165
  top_k=args.get("top_k", None),
176
- custom_llm_provider=args.get(
177
- "custom_llm_provider", DEFAULT_LITELLM_PROVIDER
178
- ),
166
+ custom_llm_provider=args.get("custom_llm_provider", DEFAULT_LITELLM_PROVIDER),
179
167
  model_kwargs=model_kwargs,
180
168
  )
181
169
  if provider == "ollama":
@@ -237,6 +225,18 @@ def get_llm_config(provider: str, args: Dict) -> BaseLLMConfig:
237
225
  max_output_tokens=args.get("max_tokens", None),
238
226
  google_api_key=args["api_keys"].get("google", None),
239
227
  )
228
+ if provider == "writer":
229
+ return WriterConfig(
230
+ model_name=args.get("model_name", "palmyra-x5"),
231
+ temperature=temperature,
232
+ max_tokens=args.get("max_tokens", None),
233
+ top_p=args.get("top_p", None),
234
+ stop=args.get("stop", None),
235
+ best_of=args.get("best_of", None),
236
+ writer_api_key=args["api_keys"].get("writer", None),
237
+ writer_org_id=args.get("writer_org_id", None),
238
+ base_url=args.get("base_url", None),
239
+ )
240
240
 
241
241
  raise ValueError(f"Provider {provider} is not supported.")
242
242
 
@@ -290,9 +290,7 @@ def ft_jsonl_validation(
290
290
  ) # noqa
291
291
 
292
292
  if messages_col not in batch:
293
- raise Exception(
294
- f"{prefix}Each line in the provided data should have a '{messages_col}' key"
295
- )
293
+ raise Exception(f"{prefix}Each line in the provided data should have a '{messages_col}' key")
296
294
 
297
295
  messages = batch[messages_col]
298
296
  try:
@@ -350,30 +348,22 @@ def ft_chat_format_validation(
350
348
 
351
349
  for c in chat:
352
350
  if any(k not in valid_keys for k in c.keys()):
353
- raise Exception(
354
- f"Each message should only have these keys: `{valid_keys}`. Found: `{c.keys()}`"
355
- )
351
+ raise Exception(f"Each message should only have these keys: `{valid_keys}`. Found: `{c.keys()}`")
356
352
 
357
353
  roles = [m[role_key] for m in chat]
358
354
  contents = [m[content_key] for m in chat]
359
355
 
360
356
  if len(roles) != len(contents):
361
- raise Exception(
362
- f"Each message should contain both `{role_key}` and `{content_key}` fields"
363
- )
357
+ raise Exception(f"Each message should contain both `{role_key}` and `{content_key}` fields")
364
358
 
365
359
  if len(roles) == 0:
366
360
  raise Exception("Chat should have at least one message")
367
361
 
368
362
  if assistant_key not in roles:
369
- raise Exception(
370
- "Chat should have at least one assistant message"
371
- ) # otherwise it is useless for FT
363
+ raise Exception("Chat should have at least one assistant message") # otherwise it is useless for FT
372
364
 
373
365
  if user_key not in roles:
374
- raise Exception(
375
- "Chat should have at least one user message"
376
- ) # perhaps remove in the future
366
+ raise Exception("Chat should have at least one user message") # perhaps remove in the future
377
367
 
378
368
  # set default transitions for finite state machine if undefined
379
369
  if transitions is None:
@@ -387,20 +377,15 @@ def ft_chat_format_validation(
387
377
  # check order is valid via finite state machine
388
378
  state = None
389
379
  for i, (role, content) in enumerate(zip(roles, contents)):
390
-
391
380
  prefix = f"message #{i + 1}: "
392
381
 
393
382
  # check invalid roles
394
383
  if role not in valid_roles:
395
- raise Exception(
396
- f"{prefix}Invalid role (found `{role}`, expected one of `{valid_roles}`)"
397
- )
384
+ raise Exception(f"{prefix}Invalid role (found `{role}`, expected one of `{valid_roles}`)")
398
385
 
399
386
  # check content
400
387
  if not isinstance(content, str):
401
- raise Exception(
402
- f"{prefix}Content should be a string, got type `{type(content)}`"
403
- )
388
+ raise Exception(f"{prefix}Content should be a string, got type `{type(content)}`")
404
389
 
405
390
  # check transition
406
391
  if role not in transitions[state]:
@@ -464,9 +449,7 @@ def ft_chat_formatter(df: pd.DataFrame) -> List[Dict]:
464
449
  df = df.sort_values(["chat_id"], kind="stable")
465
450
  elif "message_id" in df.columns:
466
451
  if df["message_id"].duplicated().any():
467
- raise Exception(
468
- "If `message_id` is provided, it must not contain duplicate IDs."
469
- )
452
+ raise Exception("If `message_id` is provided, it must not contain duplicate IDs.")
470
453
  df = df.sort_values(["message_id"])
471
454
 
472
455
  # 2. build chats
@@ -477,12 +460,8 @@ def ft_chat_formatter(df: pd.DataFrame) -> List[Dict]:
477
460
  for _, row in df.iterrows():
478
461
  try:
479
462
  chat = json.loads(row["chat_json"])
480
- assert list(chat.keys()) == [
481
- "messages"
482
- ], "Each chat should have a 'messages' key, and nothing else."
483
- ft_chat_format_validation(
484
- chat["messages"]
485
- ) # will raise Exception if chat is invalid
463
+ assert list(chat.keys()) == ["messages"], "Each chat should have a 'messages' key, and nothing else."
464
+ ft_chat_format_validation(chat["messages"]) # will raise Exception if chat is invalid
486
465
  chats.append(chat)
487
466
  except json.JSONDecodeError:
488
467
  pass # TODO: add logger info here, prompt user to clean dataset carefully
@@ -492,9 +471,7 @@ def ft_chat_formatter(df: pd.DataFrame) -> List[Dict]:
492
471
  chat = []
493
472
  for i, row in df.iterrows():
494
473
  if row["role"] == "system" and len(chat) > 0:
495
- ft_chat_format_validation(
496
- chat
497
- ) # will raise Exception if chat is invalid
474
+ ft_chat_format_validation(chat) # will raise Exception if chat is invalid
498
475
  chats.append({"messages": chat})
499
476
  chat = []
500
477
  event = {"role": row["role"], "content": row["content"]}
@@ -529,15 +506,11 @@ def ft_code_formatter(
529
506
  # input and setup validation
530
507
  assert len(df) > 0, "Input dataframe should not be empty"
531
508
  assert "code" in df.columns, "Input dataframe should have a 'code' column"
532
- assert chunk_size > 0 and isinstance(
533
- chunk_size, int
534
- ), "`chunk_size` should be a positive integer"
509
+ assert chunk_size > 0 and isinstance(chunk_size, int), "`chunk_size` should be a positive integer"
535
510
 
536
511
  supported_formats = ["chat", "fim"]
537
512
  supported_langs = [e.value for e in Language]
538
- assert (
539
- language.lower() in supported_langs
540
- ), f"Invalid language. Valid choices are: {supported_langs}"
513
+ assert language.lower() in supported_langs, f"Invalid language. Valid choices are: {supported_langs}"
541
514
 
542
515
  # ensure correct encoding
543
516
  df["code"] = df["code"].map(lambda x: x.encode("utf8").decode("unicode_escape"))
@@ -574,7 +547,7 @@ def ft_code_formatter(
574
547
  roles = []
575
548
  contents = []
576
549
  for idx in range(0, len(chunks), 3):
577
- pre, mid, suf = chunks[idx: idx + 3]
550
+ pre, mid, suf = chunks[idx : idx + 3]
578
551
  interleaved = list(itertools.chain(*zip(templates, (pre, suf, mid))))
579
552
  user = "\n".join(interleaved[:-1])
580
553
  assistant = "\n".join(interleaved[-1:])
@@ -595,12 +568,11 @@ def ft_cqa_formatter(
595
568
  default_instruction="You are a helpful assistant.",
596
569
  default_context="",
597
570
  ) -> pd.DataFrame:
598
-
599
571
  # input and setup validation
600
572
  assert len(df) > 0, "Input dataframe should not be empty"
601
- assert {question_col, answer_col}.issubset(
602
- set(df.columns)
603
- ), f"Input dataframe must have columns `{question_col}`, and `{answer_col}`" # noqa
573
+ assert {question_col, answer_col}.issubset(set(df.columns)), (
574
+ f"Input dataframe must have columns `{question_col}`, and `{answer_col}`"
575
+ ) # noqa
604
576
 
605
577
  if instruction_col not in df.columns:
606
578
  df[instruction_col] = default_instruction