MindsDB 25.7.4.0__py3-none-any.whl → 25.8.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of MindsDB might be problematic. Click here for more details.
- mindsdb/__about__.py +1 -1
- mindsdb/__main__.py +13 -1
- mindsdb/api/a2a/agent.py +6 -16
- mindsdb/api/a2a/common/types.py +3 -4
- mindsdb/api/a2a/task_manager.py +24 -35
- mindsdb/api/a2a/utils.py +63 -0
- mindsdb/api/executor/command_executor.py +9 -15
- mindsdb/api/executor/sql_query/steps/fetch_dataframe.py +21 -24
- mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py +9 -3
- mindsdb/api/executor/sql_query/steps/subselect_step.py +11 -8
- mindsdb/api/executor/utilities/mysql_to_duckdb_functions.py +264 -0
- mindsdb/api/executor/utilities/sql.py +30 -0
- mindsdb/api/http/initialize.py +2 -1
- mindsdb/api/http/namespaces/agents.py +6 -7
- mindsdb/api/http/namespaces/views.py +56 -72
- mindsdb/integrations/handlers/db2_handler/db2_handler.py +19 -23
- mindsdb/integrations/handlers/gong_handler/__about__.py +2 -0
- mindsdb/integrations/handlers/gong_handler/__init__.py +30 -0
- mindsdb/integrations/handlers/gong_handler/connection_args.py +37 -0
- mindsdb/integrations/handlers/gong_handler/gong_handler.py +164 -0
- mindsdb/integrations/handlers/gong_handler/gong_tables.py +508 -0
- mindsdb/integrations/handlers/gong_handler/icon.svg +25 -0
- mindsdb/integrations/handlers/gong_handler/test_gong_handler.py +125 -0
- mindsdb/integrations/handlers/huggingface_handler/__init__.py +8 -12
- mindsdb/integrations/handlers/huggingface_handler/finetune.py +203 -223
- mindsdb/integrations/handlers/huggingface_handler/huggingface_handler.py +360 -383
- mindsdb/integrations/handlers/huggingface_handler/requirements.txt +7 -7
- mindsdb/integrations/handlers/huggingface_handler/requirements_cpu.txt +7 -7
- mindsdb/integrations/handlers/huggingface_handler/settings.py +25 -25
- mindsdb/integrations/handlers/langchain_handler/langchain_handler.py +1 -2
- mindsdb/integrations/handlers/openai_handler/constants.py +11 -30
- mindsdb/integrations/handlers/openai_handler/helpers.py +27 -34
- mindsdb/integrations/handlers/openai_handler/openai_handler.py +14 -12
- mindsdb/integrations/handlers/salesforce_handler/constants.py +9 -2
- mindsdb/integrations/libs/llm/config.py +0 -14
- mindsdb/integrations/libs/llm/utils.py +0 -15
- mindsdb/integrations/utilities/files/file_reader.py +5 -19
- mindsdb/integrations/utilities/rag/rerankers/base_reranker.py +1 -1
- mindsdb/interfaces/agents/agents_controller.py +83 -45
- mindsdb/interfaces/agents/constants.py +16 -3
- mindsdb/interfaces/agents/langchain_agent.py +84 -21
- mindsdb/interfaces/database/projects.py +111 -7
- mindsdb/interfaces/knowledge_base/controller.py +7 -1
- mindsdb/interfaces/knowledge_base/preprocessing/document_preprocessor.py +6 -10
- mindsdb/interfaces/knowledge_base/preprocessing/text_splitter.py +73 -0
- mindsdb/interfaces/query_context/context_controller.py +14 -15
- mindsdb/interfaces/skills/custom/text2sql/mindsdb_sql_toolkit.py +7 -1
- mindsdb/interfaces/skills/skill_tool.py +7 -1
- mindsdb/interfaces/skills/sql_agent.py +6 -2
- mindsdb/utilities/config.py +2 -0
- mindsdb/utilities/fs.py +60 -17
- {mindsdb-25.7.4.0.dist-info → mindsdb-25.8.3.0.dist-info}/METADATA +277 -262
- {mindsdb-25.7.4.0.dist-info → mindsdb-25.8.3.0.dist-info}/RECORD +57 -56
- mindsdb/integrations/handlers/anyscale_endpoints_handler/__about__.py +0 -9
- mindsdb/integrations/handlers/anyscale_endpoints_handler/__init__.py +0 -20
- mindsdb/integrations/handlers/anyscale_endpoints_handler/anyscale_endpoints_handler.py +0 -290
- mindsdb/integrations/handlers/anyscale_endpoints_handler/creation_args.py +0 -14
- mindsdb/integrations/handlers/anyscale_endpoints_handler/icon.svg +0 -4
- mindsdb/integrations/handlers/anyscale_endpoints_handler/requirements.txt +0 -2
- mindsdb/integrations/handlers/anyscale_endpoints_handler/settings.py +0 -51
- mindsdb/integrations/handlers/anyscale_endpoints_handler/tests/test_anyscale_endpoints_handler.py +0 -212
- /mindsdb/integrations/handlers/{anyscale_endpoints_handler/tests/__init__.py → gong_handler/requirements.txt} +0 -0
- {mindsdb-25.7.4.0.dist-info → mindsdb-25.8.3.0.dist-info}/WHEEL +0 -0
- {mindsdb-25.7.4.0.dist-info → mindsdb-25.8.3.0.dist-info}/licenses/LICENSE +0 -0
- {mindsdb-25.7.4.0.dist-info → mindsdb-25.8.3.0.dist-info}/top_level.txt +0 -0
|
@@ -1,383 +1,360 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
#
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
#
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
#
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
#
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
#
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
#
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
#
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
#
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
#
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
#
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
#
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
#
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
# if task not in FINETUNE_MAP:
|
|
363
|
-
# raise KeyError(
|
|
364
|
-
# f"{task} is not currently supported, please choose a supported task - {', '.join(FINETUNE_MAP)}"
|
|
365
|
-
# )
|
|
366
|
-
|
|
367
|
-
# tokenizer, trainer = FINETUNE_MAP[task](df, args)
|
|
368
|
-
|
|
369
|
-
# try:
|
|
370
|
-
# trainer.train()
|
|
371
|
-
# trainer.save_model(
|
|
372
|
-
# model_folder
|
|
373
|
-
# ) # TODO: save entire pipeline instead https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.Pipeline.save_pretrained
|
|
374
|
-
# tokenizer.save_pretrained(model_folder)
|
|
375
|
-
|
|
376
|
-
# # persist changes
|
|
377
|
-
# self.model_storage.json_set("args", args)
|
|
378
|
-
# self.model_storage.folder_sync(model_folder_name)
|
|
379
|
-
|
|
380
|
-
# except Exception as e:
|
|
381
|
-
# err_str = f"Finetune failed with error: {str(e)}"
|
|
382
|
-
# logger.debug(err_str)
|
|
383
|
-
# raise Exception(err_str)
|
|
1
|
+
from typing import Dict, Optional
|
|
2
|
+
|
|
3
|
+
import pandas as pd
|
|
4
|
+
import transformers
|
|
5
|
+
from huggingface_hub import HfApi
|
|
6
|
+
|
|
7
|
+
from mindsdb.integrations.handlers.huggingface_handler.settings import FINETUNE_MAP
|
|
8
|
+
from mindsdb.integrations.libs.base import BaseMLEngine
|
|
9
|
+
from mindsdb.utilities import log
|
|
10
|
+
|
|
11
|
+
logger = log.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class HuggingFaceHandler(BaseMLEngine):
    """MindsDB ML engine that wraps Hugging Face ``transformers`` pipelines.

    Handles inference for text-classification, text-generation,
    zero-shot-classification, translation, summarization,
    text2text-generation and fill-mask tasks, and finetuning for any task
    registered in ``FINETUNE_MAP``.
    """

    name = "huggingface"

    @staticmethod
    def create_validation(target, args=None, **kwargs):
        """Validate creation arguments before :meth:`create` runs.

        Checks that the requested Hub model is PyTorch based, that its task
        is supported and consistent with ``args['task']``, that all required
        parameters for the task are present, and that no unknown parameters
        were passed.

        Args:
            target: Name of the target column (unused here; validated later).
            args: Creation arguments; may be nested under a ``"using"`` key.

        Raises:
            Exception: for unsupported models/tasks, task mismatch, missing
                required parameters, or unexpected parameters.
        """
        if "using" in args:
            args = args["using"]

        hf_api = HfApi()

        # check model is pytorch based
        metadata = hf_api.model_info(args["model_name"])
        if "pytorch" not in metadata.tags:
            raise Exception(
                "Currently only PyTorch models are supported (https://huggingface.co/models?library=pytorch&sort=downloads). To request another library, please contact us on our community slack (https://mindsdb.com/joincommunity)."
            )

        # check model task
        supported_tasks = [
            "text-classification",
            "text-generation",
            "zero-shot-classification",
            "translation",
            "summarization",
            "text2text-generation",
            "fill-mask",
        ]

        if metadata.pipeline_tag not in supported_tasks:
            raise Exception(
                f"Not supported task for model: {metadata.pipeline_tag}.\
                Should be one of {', '.join(supported_tasks)}"
            )

        if "task" not in args:
            args["task"] = metadata.pipeline_tag
        elif args["task"] != metadata.pipeline_tag:
            raise Exception(f"Task mismatch for model: {args['task']}!={metadata.pipeline_tag}")

        input_keys = list(args.keys())

        # task, model_name, input_column is essential
        for key in ["task", "model_name", "input_column"]:
            if key not in args:
                raise Exception(f'Parameter "{key}" is required')
            input_keys.remove(key)

        # check tasks input

        if args["task"] == "zero-shot-classification":
            key = "candidate_labels"
            if key not in args:
                raise Exception('"candidate_labels" is required for zero-shot-classification')
            input_keys.remove(key)

        if args["task"] == "translation":
            keys = ["lang_input", "lang_output"]
            for key in keys:
                if key not in args:
                    raise Exception(f"{key} is required for translation")
                input_keys.remove(key)

        if args["task"] == "summarization":
            keys = ["min_output_length", "max_output_length"]
            for key in keys:
                if key not in args:
                    raise Exception(f"{key} is required for summarization")
                input_keys.remove(key)

        # optional keys
        for key in ["labels", "max_length", "truncation_policy"]:
            if key in input_keys:
                input_keys.remove(key)

        if len(input_keys) > 0:
            raise Exception(f"Not expected parameters: {', '.join(input_keys)}")

    def create(self, target, args=None, **kwargs):
        """Set up a pipeline for ``args['model_name']`` and persist args.

        Reuses an already-downloaded pipeline from engine storage when one
        exists; otherwise downloads it from the Hub and saves it there.
        Also derives ``task_proper``, ``max_length`` and ``labels_map``.

        Raises:
            Exception: if the model cannot be downloaded/set up (original
                error is chained for debuggability).
        """
        # TODO change BaseMLEngine api?
        if "using" in args:
            args = args["using"]

        args["target"] = target

        model_name = args["model_name"]
        hf_model_storage_path = self.engine_storage.folder_get(model_name)  # real

        # translation pipelines need a language-qualified task name
        if args["task"] == "translation":
            args["task_proper"] = f"translation_{args['lang_input']}_to_{args['lang_output']}"
        else:
            args["task_proper"] = args["task"]

        logger.debug(f"Checking file system for {model_name}...")

        ####
        # Check if pipeline has already been downloaded
        try:
            pipeline = transformers.pipeline(
                task=args["task_proper"], model=hf_model_storage_path, tokenizer=hf_model_storage_path
            )
            logger.debug("Model already downloaded!")
        ####
        # Otherwise download it
        except (ValueError, OSError):
            try:
                logger.debug(f"Downloading {model_name}...")
                pipeline = transformers.pipeline(task=args["task_proper"], model=model_name)

                pipeline.save_pretrained(hf_model_storage_path)

                logger.debug(f"Saved to {hf_model_storage_path}")
            except Exception as download_err:
                # chain the root cause so the real download failure is not lost
                raise Exception(
                    "Error while downloading and setting up the model. Please try a different model. We're working on expanding the list of supported models, so we would appreciate it if you let us know about this in our community slack (https://mindsdb.com/joincommunity)."
                ) from download_err  # noqa
        ####

        # resolve max_length: explicit arg wins, then model config fields
        if "max_length" in args:
            pass
        elif "max_position_embeddings" in pipeline.model.config.to_dict().keys():
            args["max_length"] = pipeline.model.config.max_position_embeddings
        elif "max_length" in pipeline.model.config.to_dict().keys():
            args["max_length"] = pipeline.model.config.max_length
        else:
            logger.debug("No max_length found!")

        # map the model's default id2label names to user-provided labels, if any
        labels_default = pipeline.model.config.id2label
        labels_map = {}
        if "labels" in args:
            for num in labels_default.keys():
                labels_map[labels_default[num]] = args["labels"][num]
            args["labels_map"] = labels_map
        else:
            for num in labels_default.keys():
                labels_map[labels_default[num]] = labels_default[num]
            args["labels_map"] = labels_map

        # store and persist in model folder
        self.model_storage.json_set("args", args)

        # persist changes to handler folder
        self.engine_storage.folder_sync(model_name)

    # todo move infer tasks to a separate file
    def predict_text_classification(self, pipeline, item, args):
        """Classify one text item; returns target label plus per-label scores."""
        top_k = args.get("top_k", 1000)

        result = pipeline([item], top_k=top_k, truncation=True, max_length=args["max_length"])[0]

        final = {}
        explain = {}
        # with top_k=1 the pipeline returns a single dict instead of a list
        if isinstance(result, dict):
            result = [result]
        final[args["target"]] = args["labels_map"][result[0]["label"]]
        for elem in result:
            if args["labels_map"]:
                explain[args["labels_map"][elem["label"]]] = elem["score"]
            else:
                explain[elem["label"]] = elem["score"]
        final[f"{args['target']}_explain"] = explain
        return final

    def predict_text_generation(self, pipeline, item, args):
        """Generate a continuation for one text item."""
        result = pipeline([item], max_length=args["max_length"])[0]

        final = {}
        final[args["target"]] = result["generated_text"]

        return final

    def predict_zero_shot(self, pipeline, item, args):
        """Zero-shot classify one item against ``args['candidate_labels']``."""
        top_k = args.get("top_k", 1000)

        result = pipeline(
            [item],
            candidate_labels=args["candidate_labels"],
            truncation=True,
            top_k=top_k,
            max_length=args["max_length"],
        )[0]

        final = {}
        final[args["target"]] = result["labels"][0]

        explain = dict(zip(result["labels"], result["scores"]))
        final[f"{args['target']}_explain"] = explain

        return final

    def predict_translation(self, pipeline, item, args):
        """Translate one text item."""
        result = pipeline([item], max_length=args["max_length"])[0]

        final = {}
        final[args["target"]] = result["translation_text"]

        return final

    def predict_summarization(self, pipeline, item, args):
        """Summarize one text item within the configured output length bounds."""
        result = pipeline(
            [item],
            min_length=args["min_output_length"],
            max_length=args["max_output_length"],
        )[0]

        final = {}
        final[args["target"]] = result["summary_text"]

        return final

    def predict_text2text(self, pipeline, item, args):
        """Run a text2text-generation model on one item."""
        result = pipeline([item], max_length=args["max_length"])[0]

        final = {}
        final[args["target"]] = result["generated_text"]

        return final

    def predict_fill_mask(self, pipeline, item, args):
        """Fill the masked token in one item; returns best sequence plus scores."""
        result = pipeline([item])[0]

        final = {}
        final[args["target"]] = result[0]["sequence"]
        explain = {elem["sequence"]: elem["score"] for elem in result}
        final[f"{args['target']}_explain"] = explain

        return final

    def predict(self, df, args=None):
        """Run the stored model over ``df[input_column]`` row by row.

        Over-long inputs are handled per ``truncation_policy``: "strict"
        emits a per-row error, "left" keeps the tail, otherwise the head.
        Per-row prediction errors are captured in an "error" column instead
        of aborting the whole batch.

        Returns:
            pd.DataFrame: one row of predictions (or error) per input row.

        Raises:
            RuntimeError: if the stored task is unknown or the input column
                is missing from ``df``.
        """
        fnc_list = {
            "text-classification": self.predict_text_classification,
            "text-generation": self.predict_text_generation,
            "zero-shot-classification": self.predict_zero_shot,
            "translation": self.predict_translation,
            "summarization": self.predict_summarization,
            # BUGFIX: text2text-generation is validated as a supported task in
            # create_validation but was missing here, so valid models failed
            # with "Unknown task" at predict time.
            "text2text-generation": self.predict_text2text,
            "fill-mask": self.predict_fill_mask,
        }

        # get stuff from model folder
        args = self.model_storage.json_get("args")

        task = args["task"]

        if task not in fnc_list:
            raise RuntimeError(f"Unknown task: {task}")

        fnc = fnc_list[task]

        try:
            # load from model storage (finetuned models will use this)
            hf_model_storage_path = self.model_storage.folder_get(args["model_name"])
            pipeline = transformers.pipeline(
                task=args["task_proper"],
                model=hf_model_storage_path,
                tokenizer=hf_model_storage_path,
            )
        except (ValueError, OSError):
            # load from engine storage (i.e. 'common' models)
            hf_model_storage_path = self.engine_storage.folder_get(args["model_name"])
            pipeline = transformers.pipeline(
                task=args["task_proper"],
                model=hf_model_storage_path,
                tokenizer=hf_model_storage_path,
            )

        input_column = args["input_column"]
        if input_column not in df.columns:
            raise RuntimeError(f'Column "{input_column}" not found in input data')
        input_list = df[input_column]

        max_tokens = pipeline.tokenizer.model_max_length

        results = []
        for item in input_list:
            if max_tokens is not None:
                tokens = pipeline.tokenizer.encode(item)
                if len(tokens) > max_tokens:
                    truncation_policy = args.get("truncation_policy", "strict")
                    if truncation_policy == "strict":
                        results.append({"error": f"Tokens count exceed model limit: {len(tokens)} > {max_tokens}"})
                        continue
                    elif truncation_policy == "left":
                        tokens = tokens[-max_tokens + 1 : -1]  # cut 2 empty tokens from left and right
                    else:
                        tokens = tokens[1 : max_tokens - 1]  # cut 2 empty tokens from left and right

                    item = pipeline.tokenizer.decode(tokens)

            item = str(item)
            try:
                result = fnc(pipeline, item, args)
            except Exception as e:
                msg = str(e).strip()
                if msg == "":
                    msg = e.__class__.__name__
                result = {"error": msg}
            results.append(result)

        pred_df = pd.DataFrame(results)

        return pred_df

    def describe(self, attribute: Optional[str] = None) -> pd.DataFrame:
        """Describe the model: its args, Hub metadata, or available tables.

        Args:
            attribute: "args", "metadata", or None for the table listing.
        """
        args = self.model_storage.json_get("args")
        if attribute == "args":
            return pd.DataFrame(args.items(), columns=["key", "value"])
        elif attribute == "metadata":
            hf_api = HfApi()
            metadata = hf_api.model_info(args["model_name"])
            data = metadata.__dict__
            return pd.DataFrame(list(data.items()), columns=["key", "value"])
        else:
            tables = ["args", "metadata"]
            return pd.DataFrame(tables, columns=["tables"])

    def finetune(self, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> None:
        """Finetune the base model on ``df`` using the task's FINETUNE_MAP entry.

        Merges ``args`` over the base model's stored args, trains, saves the
        model and tokenizer into this model's folder, and persists the result.

        Raises:
            KeyError: if the task has no finetuning support.
            Exception: if training or saving fails.
        """
        finetune_args = args if args else {}
        args = self.base_model_storage.json_get("args")
        args.update(finetune_args)

        model_name = args["model_name"]
        model_folder = self.model_storage.folder_get(model_name)
        args["model_folder"] = model_folder
        model_folder_name = model_folder.split("/")[-1]
        task = args["task"]

        if task not in FINETUNE_MAP:
            raise KeyError(
                f"{task} is not currently supported, please choose a supported task - {', '.join(FINETUNE_MAP)}"
            )

        tokenizer, trainer = FINETUNE_MAP[task](df, args)

        try:
            trainer.train()
            trainer.save_model(
                model_folder
            )  # TODO: save entire pipeline instead https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.Pipeline.save_pretrained
            tokenizer.save_pretrained(model_folder)

            # persist changes
            self.model_storage.json_set("args", args)
            self.model_storage.folder_sync(model_folder_name)

        except Exception as e:
            err_str = f"Finetune failed with error: {str(e)}"
            logger.debug(err_str)
            raise Exception(err_str)