genesis_flow-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (645)
  1. genesis_flow-1.0.0.dist-info/METADATA +822 -0
  2. genesis_flow-1.0.0.dist-info/RECORD +645 -0
  3. genesis_flow-1.0.0.dist-info/WHEEL +5 -0
  4. genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
  5. genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
  6. genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
  7. mlflow/__init__.py +367 -0
  8. mlflow/__main__.py +3 -0
  9. mlflow/ag2/__init__.py +56 -0
  10. mlflow/ag2/ag2_logger.py +294 -0
  11. mlflow/anthropic/__init__.py +40 -0
  12. mlflow/anthropic/autolog.py +129 -0
  13. mlflow/anthropic/chat.py +144 -0
  14. mlflow/artifacts/__init__.py +268 -0
  15. mlflow/autogen/__init__.py +144 -0
  16. mlflow/autogen/chat.py +142 -0
  17. mlflow/azure/__init__.py +26 -0
  18. mlflow/azure/auth_handler.py +257 -0
  19. mlflow/azure/client.py +319 -0
  20. mlflow/azure/config.py +120 -0
  21. mlflow/azure/connection_factory.py +340 -0
  22. mlflow/azure/exceptions.py +27 -0
  23. mlflow/azure/stores.py +327 -0
  24. mlflow/azure/utils.py +183 -0
  25. mlflow/bedrock/__init__.py +45 -0
  26. mlflow/bedrock/_autolog.py +202 -0
  27. mlflow/bedrock/chat.py +122 -0
  28. mlflow/bedrock/stream.py +160 -0
  29. mlflow/bedrock/utils.py +43 -0
  30. mlflow/cli.py +707 -0
  31. mlflow/client.py +12 -0
  32. mlflow/config/__init__.py +56 -0
  33. mlflow/crewai/__init__.py +79 -0
  34. mlflow/crewai/autolog.py +253 -0
  35. mlflow/crewai/chat.py +29 -0
  36. mlflow/data/__init__.py +75 -0
  37. mlflow/data/artifact_dataset_sources.py +170 -0
  38. mlflow/data/code_dataset_source.py +40 -0
  39. mlflow/data/dataset.py +123 -0
  40. mlflow/data/dataset_registry.py +168 -0
  41. mlflow/data/dataset_source.py +110 -0
  42. mlflow/data/dataset_source_registry.py +219 -0
  43. mlflow/data/delta_dataset_source.py +167 -0
  44. mlflow/data/digest_utils.py +108 -0
  45. mlflow/data/evaluation_dataset.py +562 -0
  46. mlflow/data/filesystem_dataset_source.py +81 -0
  47. mlflow/data/http_dataset_source.py +145 -0
  48. mlflow/data/huggingface_dataset.py +258 -0
  49. mlflow/data/huggingface_dataset_source.py +118 -0
  50. mlflow/data/meta_dataset.py +104 -0
  51. mlflow/data/numpy_dataset.py +223 -0
  52. mlflow/data/pandas_dataset.py +231 -0
  53. mlflow/data/polars_dataset.py +352 -0
  54. mlflow/data/pyfunc_dataset_mixin.py +31 -0
  55. mlflow/data/schema.py +76 -0
  56. mlflow/data/sources.py +1 -0
  57. mlflow/data/spark_dataset.py +406 -0
  58. mlflow/data/spark_dataset_source.py +74 -0
  59. mlflow/data/spark_delta_utils.py +118 -0
  60. mlflow/data/tensorflow_dataset.py +350 -0
  61. mlflow/data/uc_volume_dataset_source.py +81 -0
  62. mlflow/db.py +27 -0
  63. mlflow/dspy/__init__.py +17 -0
  64. mlflow/dspy/autolog.py +197 -0
  65. mlflow/dspy/callback.py +398 -0
  66. mlflow/dspy/constant.py +1 -0
  67. mlflow/dspy/load.py +93 -0
  68. mlflow/dspy/save.py +393 -0
  69. mlflow/dspy/util.py +109 -0
  70. mlflow/dspy/wrapper.py +226 -0
  71. mlflow/entities/__init__.py +104 -0
  72. mlflow/entities/_mlflow_object.py +52 -0
  73. mlflow/entities/assessment.py +545 -0
  74. mlflow/entities/assessment_error.py +80 -0
  75. mlflow/entities/assessment_source.py +141 -0
  76. mlflow/entities/dataset.py +92 -0
  77. mlflow/entities/dataset_input.py +51 -0
  78. mlflow/entities/dataset_summary.py +62 -0
  79. mlflow/entities/document.py +48 -0
  80. mlflow/entities/experiment.py +109 -0
  81. mlflow/entities/experiment_tag.py +35 -0
  82. mlflow/entities/file_info.py +45 -0
  83. mlflow/entities/input_tag.py +35 -0
  84. mlflow/entities/lifecycle_stage.py +35 -0
  85. mlflow/entities/logged_model.py +228 -0
  86. mlflow/entities/logged_model_input.py +26 -0
  87. mlflow/entities/logged_model_output.py +32 -0
  88. mlflow/entities/logged_model_parameter.py +46 -0
  89. mlflow/entities/logged_model_status.py +74 -0
  90. mlflow/entities/logged_model_tag.py +33 -0
  91. mlflow/entities/metric.py +200 -0
  92. mlflow/entities/model_registry/__init__.py +29 -0
  93. mlflow/entities/model_registry/_model_registry_entity.py +13 -0
  94. mlflow/entities/model_registry/model_version.py +243 -0
  95. mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
  96. mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
  97. mlflow/entities/model_registry/model_version_search.py +25 -0
  98. mlflow/entities/model_registry/model_version_stages.py +25 -0
  99. mlflow/entities/model_registry/model_version_status.py +35 -0
  100. mlflow/entities/model_registry/model_version_tag.py +35 -0
  101. mlflow/entities/model_registry/prompt.py +73 -0
  102. mlflow/entities/model_registry/prompt_version.py +244 -0
  103. mlflow/entities/model_registry/registered_model.py +175 -0
  104. mlflow/entities/model_registry/registered_model_alias.py +35 -0
  105. mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
  106. mlflow/entities/model_registry/registered_model_search.py +25 -0
  107. mlflow/entities/model_registry/registered_model_tag.py +35 -0
  108. mlflow/entities/multipart_upload.py +74 -0
  109. mlflow/entities/param.py +49 -0
  110. mlflow/entities/run.py +97 -0
  111. mlflow/entities/run_data.py +84 -0
  112. mlflow/entities/run_info.py +188 -0
  113. mlflow/entities/run_inputs.py +59 -0
  114. mlflow/entities/run_outputs.py +43 -0
  115. mlflow/entities/run_status.py +41 -0
  116. mlflow/entities/run_tag.py +36 -0
  117. mlflow/entities/source_type.py +31 -0
  118. mlflow/entities/span.py +774 -0
  119. mlflow/entities/span_event.py +96 -0
  120. mlflow/entities/span_status.py +102 -0
  121. mlflow/entities/trace.py +317 -0
  122. mlflow/entities/trace_data.py +71 -0
  123. mlflow/entities/trace_info.py +220 -0
  124. mlflow/entities/trace_info_v2.py +162 -0
  125. mlflow/entities/trace_location.py +173 -0
  126. mlflow/entities/trace_state.py +39 -0
  127. mlflow/entities/trace_status.py +68 -0
  128. mlflow/entities/view_type.py +51 -0
  129. mlflow/environment_variables.py +866 -0
  130. mlflow/evaluation/__init__.py +16 -0
  131. mlflow/evaluation/assessment.py +369 -0
  132. mlflow/evaluation/evaluation.py +411 -0
  133. mlflow/evaluation/evaluation_tag.py +61 -0
  134. mlflow/evaluation/fluent.py +48 -0
  135. mlflow/evaluation/utils.py +201 -0
  136. mlflow/exceptions.py +213 -0
  137. mlflow/experiments.py +140 -0
  138. mlflow/gemini/__init__.py +81 -0
  139. mlflow/gemini/autolog.py +186 -0
  140. mlflow/gemini/chat.py +261 -0
  141. mlflow/genai/__init__.py +71 -0
  142. mlflow/genai/datasets/__init__.py +67 -0
  143. mlflow/genai/datasets/evaluation_dataset.py +131 -0
  144. mlflow/genai/evaluation/__init__.py +3 -0
  145. mlflow/genai/evaluation/base.py +411 -0
  146. mlflow/genai/evaluation/constant.py +23 -0
  147. mlflow/genai/evaluation/utils.py +244 -0
  148. mlflow/genai/judges/__init__.py +21 -0
  149. mlflow/genai/judges/databricks.py +404 -0
  150. mlflow/genai/label_schemas/__init__.py +153 -0
  151. mlflow/genai/label_schemas/label_schemas.py +209 -0
  152. mlflow/genai/labeling/__init__.py +159 -0
  153. mlflow/genai/labeling/labeling.py +250 -0
  154. mlflow/genai/optimize/__init__.py +13 -0
  155. mlflow/genai/optimize/base.py +198 -0
  156. mlflow/genai/optimize/optimizers/__init__.py +4 -0
  157. mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
  158. mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
  159. mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
  160. mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
  161. mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
  162. mlflow/genai/optimize/types.py +75 -0
  163. mlflow/genai/optimize/util.py +30 -0
  164. mlflow/genai/prompts/__init__.py +206 -0
  165. mlflow/genai/scheduled_scorers.py +431 -0
  166. mlflow/genai/scorers/__init__.py +26 -0
  167. mlflow/genai/scorers/base.py +492 -0
  168. mlflow/genai/scorers/builtin_scorers.py +765 -0
  169. mlflow/genai/scorers/scorer_utils.py +138 -0
  170. mlflow/genai/scorers/validation.py +165 -0
  171. mlflow/genai/utils/data_validation.py +146 -0
  172. mlflow/genai/utils/enum_utils.py +23 -0
  173. mlflow/genai/utils/trace_utils.py +211 -0
  174. mlflow/groq/__init__.py +42 -0
  175. mlflow/groq/_groq_autolog.py +74 -0
  176. mlflow/johnsnowlabs/__init__.py +888 -0
  177. mlflow/langchain/__init__.py +24 -0
  178. mlflow/langchain/api_request_parallel_processor.py +330 -0
  179. mlflow/langchain/autolog.py +147 -0
  180. mlflow/langchain/chat_agent_langgraph.py +340 -0
  181. mlflow/langchain/constant.py +1 -0
  182. mlflow/langchain/constants.py +1 -0
  183. mlflow/langchain/databricks_dependencies.py +444 -0
  184. mlflow/langchain/langchain_tracer.py +597 -0
  185. mlflow/langchain/model.py +919 -0
  186. mlflow/langchain/output_parsers.py +142 -0
  187. mlflow/langchain/retriever_chain.py +153 -0
  188. mlflow/langchain/runnables.py +527 -0
  189. mlflow/langchain/utils/chat.py +402 -0
  190. mlflow/langchain/utils/logging.py +671 -0
  191. mlflow/langchain/utils/serialization.py +36 -0
  192. mlflow/legacy_databricks_cli/__init__.py +0 -0
  193. mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
  194. mlflow/legacy_databricks_cli/configure/provider.py +482 -0
  195. mlflow/litellm/__init__.py +175 -0
  196. mlflow/llama_index/__init__.py +22 -0
  197. mlflow/llama_index/autolog.py +55 -0
  198. mlflow/llama_index/chat.py +43 -0
  199. mlflow/llama_index/constant.py +1 -0
  200. mlflow/llama_index/model.py +577 -0
  201. mlflow/llama_index/pyfunc_wrapper.py +332 -0
  202. mlflow/llama_index/serialize_objects.py +188 -0
  203. mlflow/llama_index/tracer.py +561 -0
  204. mlflow/metrics/__init__.py +479 -0
  205. mlflow/metrics/base.py +39 -0
  206. mlflow/metrics/genai/__init__.py +25 -0
  207. mlflow/metrics/genai/base.py +101 -0
  208. mlflow/metrics/genai/genai_metric.py +771 -0
  209. mlflow/metrics/genai/metric_definitions.py +450 -0
  210. mlflow/metrics/genai/model_utils.py +371 -0
  211. mlflow/metrics/genai/prompt_template.py +68 -0
  212. mlflow/metrics/genai/prompts/__init__.py +0 -0
  213. mlflow/metrics/genai/prompts/v1.py +422 -0
  214. mlflow/metrics/genai/utils.py +6 -0
  215. mlflow/metrics/metric_definitions.py +619 -0
  216. mlflow/mismatch.py +34 -0
  217. mlflow/mistral/__init__.py +34 -0
  218. mlflow/mistral/autolog.py +71 -0
  219. mlflow/mistral/chat.py +135 -0
  220. mlflow/ml_package_versions.py +452 -0
  221. mlflow/models/__init__.py +97 -0
  222. mlflow/models/auth_policy.py +83 -0
  223. mlflow/models/cli.py +354 -0
  224. mlflow/models/container/__init__.py +294 -0
  225. mlflow/models/container/scoring_server/__init__.py +0 -0
  226. mlflow/models/container/scoring_server/nginx.conf +39 -0
  227. mlflow/models/dependencies_schemas.py +287 -0
  228. mlflow/models/display_utils.py +158 -0
  229. mlflow/models/docker_utils.py +211 -0
  230. mlflow/models/evaluation/__init__.py +23 -0
  231. mlflow/models/evaluation/_shap_patch.py +64 -0
  232. mlflow/models/evaluation/artifacts.py +194 -0
  233. mlflow/models/evaluation/base.py +1811 -0
  234. mlflow/models/evaluation/calibration_curve.py +109 -0
  235. mlflow/models/evaluation/default_evaluator.py +996 -0
  236. mlflow/models/evaluation/deprecated.py +23 -0
  237. mlflow/models/evaluation/evaluator_registry.py +80 -0
  238. mlflow/models/evaluation/evaluators/classifier.py +704 -0
  239. mlflow/models/evaluation/evaluators/default.py +233 -0
  240. mlflow/models/evaluation/evaluators/regressor.py +96 -0
  241. mlflow/models/evaluation/evaluators/shap.py +296 -0
  242. mlflow/models/evaluation/lift_curve.py +178 -0
  243. mlflow/models/evaluation/utils/metric.py +123 -0
  244. mlflow/models/evaluation/utils/trace.py +179 -0
  245. mlflow/models/evaluation/validation.py +434 -0
  246. mlflow/models/flavor_backend.py +93 -0
  247. mlflow/models/flavor_backend_registry.py +53 -0
  248. mlflow/models/model.py +1639 -0
  249. mlflow/models/model_config.py +150 -0
  250. mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
  251. mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
  252. mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
  253. mlflow/models/python_api.py +369 -0
  254. mlflow/models/rag_signatures.py +128 -0
  255. mlflow/models/resources.py +321 -0
  256. mlflow/models/signature.py +662 -0
  257. mlflow/models/utils.py +2054 -0
  258. mlflow/models/wheeled_model.py +280 -0
  259. mlflow/openai/__init__.py +57 -0
  260. mlflow/openai/_agent_tracer.py +364 -0
  261. mlflow/openai/api_request_parallel_processor.py +131 -0
  262. mlflow/openai/autolog.py +509 -0
  263. mlflow/openai/constant.py +1 -0
  264. mlflow/openai/model.py +824 -0
  265. mlflow/openai/utils/chat_schema.py +367 -0
  266. mlflow/optuna/__init__.py +3 -0
  267. mlflow/optuna/storage.py +646 -0
  268. mlflow/plugins/__init__.py +72 -0
  269. mlflow/plugins/base.py +358 -0
  270. mlflow/plugins/builtin/__init__.py +24 -0
  271. mlflow/plugins/builtin/pytorch_plugin.py +150 -0
  272. mlflow/plugins/builtin/sklearn_plugin.py +158 -0
  273. mlflow/plugins/builtin/transformers_plugin.py +187 -0
  274. mlflow/plugins/cli.py +321 -0
  275. mlflow/plugins/discovery.py +340 -0
  276. mlflow/plugins/manager.py +465 -0
  277. mlflow/plugins/registry.py +316 -0
  278. mlflow/plugins/templates/framework_plugin_template.py +329 -0
  279. mlflow/prompt/constants.py +20 -0
  280. mlflow/prompt/promptlab_model.py +197 -0
  281. mlflow/prompt/registry_utils.py +248 -0
  282. mlflow/promptflow/__init__.py +495 -0
  283. mlflow/protos/__init__.py +0 -0
  284. mlflow/protos/assessments_pb2.py +174 -0
  285. mlflow/protos/databricks_artifacts_pb2.py +489 -0
  286. mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
  287. mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
  288. mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
  289. mlflow/protos/databricks_pb2.py +267 -0
  290. mlflow/protos/databricks_trace_server_pb2.py +374 -0
  291. mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
  292. mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
  293. mlflow/protos/facet_feature_statistics_pb2.py +296 -0
  294. mlflow/protos/internal_pb2.py +77 -0
  295. mlflow/protos/mlflow_artifacts_pb2.py +336 -0
  296. mlflow/protos/model_registry_pb2.py +1073 -0
  297. mlflow/protos/scalapb/__init__.py +0 -0
  298. mlflow/protos/scalapb/scalapb_pb2.py +104 -0
  299. mlflow/protos/service_pb2.py +2600 -0
  300. mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
  301. mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
  302. mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
  303. mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
  304. mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
  305. mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
  306. mlflow/py.typed +0 -0
  307. mlflow/pydantic_ai/__init__.py +57 -0
  308. mlflow/pydantic_ai/autolog.py +173 -0
  309. mlflow/pyfunc/__init__.py +3844 -0
  310. mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
  311. mlflow/pyfunc/backend.py +523 -0
  312. mlflow/pyfunc/context.py +78 -0
  313. mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
  314. mlflow/pyfunc/loaders/__init__.py +7 -0
  315. mlflow/pyfunc/loaders/chat_agent.py +117 -0
  316. mlflow/pyfunc/loaders/chat_model.py +125 -0
  317. mlflow/pyfunc/loaders/code_model.py +31 -0
  318. mlflow/pyfunc/loaders/responses_agent.py +112 -0
  319. mlflow/pyfunc/mlserver.py +46 -0
  320. mlflow/pyfunc/model.py +1473 -0
  321. mlflow/pyfunc/scoring_server/__init__.py +604 -0
  322. mlflow/pyfunc/scoring_server/app.py +7 -0
  323. mlflow/pyfunc/scoring_server/client.py +146 -0
  324. mlflow/pyfunc/spark_model_cache.py +48 -0
  325. mlflow/pyfunc/stdin_server.py +44 -0
  326. mlflow/pyfunc/utils/__init__.py +3 -0
  327. mlflow/pyfunc/utils/data_validation.py +224 -0
  328. mlflow/pyfunc/utils/environment.py +22 -0
  329. mlflow/pyfunc/utils/input_converter.py +47 -0
  330. mlflow/pyfunc/utils/serving_data_parser.py +11 -0
  331. mlflow/pytorch/__init__.py +1171 -0
  332. mlflow/pytorch/_lightning_autolog.py +580 -0
  333. mlflow/pytorch/_pytorch_autolog.py +50 -0
  334. mlflow/pytorch/pickle_module.py +35 -0
  335. mlflow/rfunc/__init__.py +42 -0
  336. mlflow/rfunc/backend.py +134 -0
  337. mlflow/runs.py +89 -0
  338. mlflow/server/__init__.py +302 -0
  339. mlflow/server/auth/__init__.py +1224 -0
  340. mlflow/server/auth/__main__.py +4 -0
  341. mlflow/server/auth/basic_auth.ini +6 -0
  342. mlflow/server/auth/cli.py +11 -0
  343. mlflow/server/auth/client.py +537 -0
  344. mlflow/server/auth/config.py +34 -0
  345. mlflow/server/auth/db/__init__.py +0 -0
  346. mlflow/server/auth/db/cli.py +18 -0
  347. mlflow/server/auth/db/migrations/__init__.py +0 -0
  348. mlflow/server/auth/db/migrations/alembic.ini +110 -0
  349. mlflow/server/auth/db/migrations/env.py +76 -0
  350. mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
  351. mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
  352. mlflow/server/auth/db/models.py +67 -0
  353. mlflow/server/auth/db/utils.py +37 -0
  354. mlflow/server/auth/entities.py +165 -0
  355. mlflow/server/auth/logo.py +14 -0
  356. mlflow/server/auth/permissions.py +65 -0
  357. mlflow/server/auth/routes.py +18 -0
  358. mlflow/server/auth/sqlalchemy_store.py +263 -0
  359. mlflow/server/graphql/__init__.py +0 -0
  360. mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
  361. mlflow/server/graphql/graphql_custom_scalars.py +24 -0
  362. mlflow/server/graphql/graphql_errors.py +15 -0
  363. mlflow/server/graphql/graphql_no_batching.py +89 -0
  364. mlflow/server/graphql/graphql_schema_extensions.py +74 -0
  365. mlflow/server/handlers.py +3217 -0
  366. mlflow/server/prometheus_exporter.py +17 -0
  367. mlflow/server/validation.py +30 -0
  368. mlflow/shap/__init__.py +691 -0
  369. mlflow/sklearn/__init__.py +1994 -0
  370. mlflow/sklearn/utils.py +1041 -0
  371. mlflow/smolagents/__init__.py +66 -0
  372. mlflow/smolagents/autolog.py +139 -0
  373. mlflow/smolagents/chat.py +29 -0
  374. mlflow/store/__init__.py +10 -0
  375. mlflow/store/_unity_catalog/__init__.py +1 -0
  376. mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
  377. mlflow/store/_unity_catalog/lineage/constants.py +2 -0
  378. mlflow/store/_unity_catalog/registry/__init__.py +6 -0
  379. mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
  380. mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
  381. mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
  382. mlflow/store/_unity_catalog/registry/utils.py +121 -0
  383. mlflow/store/artifact/__init__.py +0 -0
  384. mlflow/store/artifact/artifact_repo.py +472 -0
  385. mlflow/store/artifact/artifact_repository_registry.py +154 -0
  386. mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
  387. mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
  388. mlflow/store/artifact/cli.py +141 -0
  389. mlflow/store/artifact/cloud_artifact_repo.py +332 -0
  390. mlflow/store/artifact/databricks_artifact_repo.py +729 -0
  391. mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
  392. mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
  393. mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
  394. mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
  395. mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
  396. mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
  397. mlflow/store/artifact/ftp_artifact_repo.py +132 -0
  398. mlflow/store/artifact/gcs_artifact_repo.py +296 -0
  399. mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
  400. mlflow/store/artifact/http_artifact_repo.py +218 -0
  401. mlflow/store/artifact/local_artifact_repo.py +142 -0
  402. mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
  403. mlflow/store/artifact/models_artifact_repo.py +259 -0
  404. mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
  405. mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
  406. mlflow/store/artifact/r2_artifact_repo.py +70 -0
  407. mlflow/store/artifact/runs_artifact_repo.py +265 -0
  408. mlflow/store/artifact/s3_artifact_repo.py +330 -0
  409. mlflow/store/artifact/sftp_artifact_repo.py +141 -0
  410. mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
  411. mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
  412. mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
  413. mlflow/store/artifact/utils/__init__.py +0 -0
  414. mlflow/store/artifact/utils/models.py +148 -0
  415. mlflow/store/db/__init__.py +0 -0
  416. mlflow/store/db/base_sql_model.py +3 -0
  417. mlflow/store/db/db_types.py +10 -0
  418. mlflow/store/db/utils.py +314 -0
  419. mlflow/store/db_migrations/__init__.py +0 -0
  420. mlflow/store/db_migrations/alembic.ini +74 -0
  421. mlflow/store/db_migrations/env.py +84 -0
  422. mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
  423. mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
  424. mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
  425. mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
  426. mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
  427. mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
  428. mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
  429. mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
  430. mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
  431. mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
  432. mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
  433. mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
  434. mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
  435. mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
  436. mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
  437. mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
  438. mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
  439. mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
  440. mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
  441. mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
  442. mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
  443. mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
  444. mlflow/store/db_migrations/versions/__init__.py +0 -0
  445. mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
  446. mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
  447. mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
  448. mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
  449. mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
  450. mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
  451. mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
  452. mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
  453. mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
  454. mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
  455. mlflow/store/entities/__init__.py +3 -0
  456. mlflow/store/entities/paged_list.py +18 -0
  457. mlflow/store/model_registry/__init__.py +10 -0
  458. mlflow/store/model_registry/abstract_store.py +1081 -0
  459. mlflow/store/model_registry/base_rest_store.py +44 -0
  460. mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
  461. mlflow/store/model_registry/dbmodels/__init__.py +0 -0
  462. mlflow/store/model_registry/dbmodels/models.py +206 -0
  463. mlflow/store/model_registry/file_store.py +1091 -0
  464. mlflow/store/model_registry/rest_store.py +481 -0
  465. mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
  466. mlflow/store/tracking/__init__.py +23 -0
  467. mlflow/store/tracking/abstract_store.py +816 -0
  468. mlflow/store/tracking/dbmodels/__init__.py +0 -0
  469. mlflow/store/tracking/dbmodels/initial_models.py +243 -0
  470. mlflow/store/tracking/dbmodels/models.py +1073 -0
  471. mlflow/store/tracking/file_store.py +2438 -0
  472. mlflow/store/tracking/postgres_managed_identity.py +146 -0
  473. mlflow/store/tracking/rest_store.py +1131 -0
  474. mlflow/store/tracking/sqlalchemy_store.py +2785 -0
  475. mlflow/system_metrics/__init__.py +61 -0
  476. mlflow/system_metrics/metrics/__init__.py +0 -0
  477. mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
  478. mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
  479. mlflow/system_metrics/metrics/disk_monitor.py +21 -0
  480. mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
  481. mlflow/system_metrics/metrics/network_monitor.py +34 -0
  482. mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
  483. mlflow/system_metrics/system_metrics_monitor.py +198 -0
  484. mlflow/tracing/__init__.py +16 -0
  485. mlflow/tracing/assessment.py +356 -0
  486. mlflow/tracing/client.py +531 -0
  487. mlflow/tracing/config.py +125 -0
  488. mlflow/tracing/constant.py +105 -0
  489. mlflow/tracing/destination.py +81 -0
  490. mlflow/tracing/display/__init__.py +40 -0
  491. mlflow/tracing/display/display_handler.py +196 -0
  492. mlflow/tracing/export/async_export_queue.py +186 -0
  493. mlflow/tracing/export/inference_table.py +138 -0
  494. mlflow/tracing/export/mlflow_v3.py +137 -0
  495. mlflow/tracing/export/utils.py +70 -0
  496. mlflow/tracing/fluent.py +1417 -0
  497. mlflow/tracing/processor/base_mlflow.py +199 -0
  498. mlflow/tracing/processor/inference_table.py +175 -0
  499. mlflow/tracing/processor/mlflow_v3.py +47 -0
  500. mlflow/tracing/processor/otel.py +73 -0
  501. mlflow/tracing/provider.py +487 -0
  502. mlflow/tracing/trace_manager.py +200 -0
  503. mlflow/tracing/utils/__init__.py +616 -0
  504. mlflow/tracing/utils/artifact_utils.py +28 -0
  505. mlflow/tracing/utils/copy.py +55 -0
  506. mlflow/tracing/utils/environment.py +55 -0
  507. mlflow/tracing/utils/exception.py +21 -0
  508. mlflow/tracing/utils/once.py +35 -0
  509. mlflow/tracing/utils/otlp.py +63 -0
  510. mlflow/tracing/utils/processor.py +54 -0
  511. mlflow/tracing/utils/search.py +292 -0
  512. mlflow/tracing/utils/timeout.py +250 -0
  513. mlflow/tracing/utils/token.py +19 -0
  514. mlflow/tracing/utils/truncation.py +124 -0
  515. mlflow/tracing/utils/warning.py +76 -0
  516. mlflow/tracking/__init__.py +39 -0
  517. mlflow/tracking/_model_registry/__init__.py +1 -0
  518. mlflow/tracking/_model_registry/client.py +764 -0
  519. mlflow/tracking/_model_registry/fluent.py +853 -0
  520. mlflow/tracking/_model_registry/registry.py +67 -0
  521. mlflow/tracking/_model_registry/utils.py +251 -0
  522. mlflow/tracking/_tracking_service/__init__.py +0 -0
  523. mlflow/tracking/_tracking_service/client.py +883 -0
  524. mlflow/tracking/_tracking_service/registry.py +56 -0
  525. mlflow/tracking/_tracking_service/utils.py +275 -0
  526. mlflow/tracking/artifact_utils.py +179 -0
  527. mlflow/tracking/client.py +5900 -0
  528. mlflow/tracking/context/__init__.py +0 -0
  529. mlflow/tracking/context/abstract_context.py +35 -0
  530. mlflow/tracking/context/databricks_cluster_context.py +15 -0
  531. mlflow/tracking/context/databricks_command_context.py +15 -0
  532. mlflow/tracking/context/databricks_job_context.py +49 -0
  533. mlflow/tracking/context/databricks_notebook_context.py +41 -0
  534. mlflow/tracking/context/databricks_repo_context.py +43 -0
  535. mlflow/tracking/context/default_context.py +51 -0
  536. mlflow/tracking/context/git_context.py +32 -0
  537. mlflow/tracking/context/registry.py +98 -0
  538. mlflow/tracking/context/system_environment_context.py +15 -0
  539. mlflow/tracking/default_experiment/__init__.py +1 -0
  540. mlflow/tracking/default_experiment/abstract_context.py +43 -0
  541. mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
  542. mlflow/tracking/default_experiment/registry.py +75 -0
  543. mlflow/tracking/fluent.py +3595 -0
  544. mlflow/tracking/metric_value_conversion_utils.py +93 -0
  545. mlflow/tracking/multimedia.py +206 -0
  546. mlflow/tracking/registry.py +86 -0
  547. mlflow/tracking/request_auth/__init__.py +0 -0
  548. mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
  549. mlflow/tracking/request_auth/registry.py +60 -0
  550. mlflow/tracking/request_header/__init__.py +0 -0
  551. mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
  552. mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
  553. mlflow/tracking/request_header/default_request_header_provider.py +17 -0
  554. mlflow/tracking/request_header/registry.py +79 -0
  555. mlflow/transformers/__init__.py +2982 -0
  556. mlflow/transformers/flavor_config.py +258 -0
  557. mlflow/transformers/hub_utils.py +83 -0
  558. mlflow/transformers/llm_inference_utils.py +468 -0
  559. mlflow/transformers/model_io.py +301 -0
  560. mlflow/transformers/peft.py +51 -0
  561. mlflow/transformers/signature.py +183 -0
  562. mlflow/transformers/torch_utils.py +55 -0
  563. mlflow/types/__init__.py +21 -0
  564. mlflow/types/agent.py +270 -0
  565. mlflow/types/chat.py +240 -0
  566. mlflow/types/llm.py +935 -0
  567. mlflow/types/responses.py +139 -0
  568. mlflow/types/responses_helpers.py +416 -0
  569. mlflow/types/schema.py +1505 -0
  570. mlflow/types/type_hints.py +647 -0
  571. mlflow/types/utils.py +753 -0
  572. mlflow/utils/__init__.py +283 -0
  573. mlflow/utils/_capture_modules.py +256 -0
  574. mlflow/utils/_capture_transformers_modules.py +75 -0
  575. mlflow/utils/_spark_utils.py +201 -0
  576. mlflow/utils/_unity_catalog_oss_utils.py +97 -0
  577. mlflow/utils/_unity_catalog_utils.py +479 -0
  578. mlflow/utils/annotations.py +218 -0
  579. mlflow/utils/arguments_utils.py +16 -0
  580. mlflow/utils/async_logging/__init__.py +1 -0
  581. mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
  582. mlflow/utils/async_logging/async_logging_queue.py +366 -0
  583. mlflow/utils/async_logging/run_artifact.py +38 -0
  584. mlflow/utils/async_logging/run_batch.py +58 -0
  585. mlflow/utils/async_logging/run_operations.py +49 -0
  586. mlflow/utils/autologging_utils/__init__.py +737 -0
  587. mlflow/utils/autologging_utils/client.py +432 -0
  588. mlflow/utils/autologging_utils/config.py +33 -0
  589. mlflow/utils/autologging_utils/events.py +294 -0
  590. mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
  591. mlflow/utils/autologging_utils/metrics_queue.py +71 -0
  592. mlflow/utils/autologging_utils/safety.py +1104 -0
  593. mlflow/utils/autologging_utils/versioning.py +95 -0
  594. mlflow/utils/checkpoint_utils.py +206 -0
  595. mlflow/utils/class_utils.py +6 -0
  596. mlflow/utils/cli_args.py +257 -0
  597. mlflow/utils/conda.py +354 -0
  598. mlflow/utils/credentials.py +231 -0
  599. mlflow/utils/data_utils.py +17 -0
  600. mlflow/utils/databricks_utils.py +1436 -0
  601. mlflow/utils/docstring_utils.py +477 -0
  602. mlflow/utils/doctor.py +133 -0
  603. mlflow/utils/download_cloud_file_chunk.py +43 -0
  604. mlflow/utils/env_manager.py +16 -0
  605. mlflow/utils/env_pack.py +131 -0
  606. mlflow/utils/environment.py +1009 -0
  607. mlflow/utils/exception_utils.py +14 -0
  608. mlflow/utils/file_utils.py +978 -0
  609. mlflow/utils/git_utils.py +77 -0
  610. mlflow/utils/gorilla.py +797 -0
  611. mlflow/utils/import_hooks/__init__.py +363 -0
  612. mlflow/utils/lazy_load.py +51 -0
  613. mlflow/utils/logging_utils.py +168 -0
  614. mlflow/utils/mime_type_utils.py +58 -0
  615. mlflow/utils/mlflow_tags.py +103 -0
  616. mlflow/utils/model_utils.py +486 -0
  617. mlflow/utils/name_utils.py +346 -0
  618. mlflow/utils/nfs_on_spark.py +62 -0
  619. mlflow/utils/openai_utils.py +164 -0
  620. mlflow/utils/os.py +12 -0
  621. mlflow/utils/oss_registry_utils.py +29 -0
  622. mlflow/utils/plugins.py +17 -0
  623. mlflow/utils/process.py +182 -0
  624. mlflow/utils/promptlab_utils.py +146 -0
  625. mlflow/utils/proto_json_utils.py +743 -0
  626. mlflow/utils/pydantic_utils.py +54 -0
  627. mlflow/utils/request_utils.py +279 -0
  628. mlflow/utils/requirements_utils.py +704 -0
  629. mlflow/utils/rest_utils.py +673 -0
  630. mlflow/utils/search_logged_model_utils.py +127 -0
  631. mlflow/utils/search_utils.py +2111 -0
  632. mlflow/utils/secure_loading.py +221 -0
  633. mlflow/utils/security_validation.py +384 -0
  634. mlflow/utils/server_cli_utils.py +61 -0
  635. mlflow/utils/spark_utils.py +15 -0
  636. mlflow/utils/string_utils.py +138 -0
  637. mlflow/utils/thread_utils.py +63 -0
  638. mlflow/utils/time.py +54 -0
  639. mlflow/utils/timeout.py +42 -0
  640. mlflow/utils/uri.py +572 -0
  641. mlflow/utils/validation.py +662 -0
  642. mlflow/utils/virtualenv.py +458 -0
  643. mlflow/utils/warnings_utils.py +25 -0
  644. mlflow/utils/yaml_utils.py +179 -0
  645. mlflow/version.py +24 -0
@@ -0,0 +1,2982 @@
1
+ """MLflow module for HuggingFace/transformer support."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import ast
6
+ import base64
7
+ import binascii
8
+ import contextlib
9
+ import copy
10
+ import functools
11
+ import importlib
12
+ import json
13
+ import logging
14
+ import os
15
+ import pathlib
16
+ import re
17
+ import shutil
18
+ import string
19
+ import sys
20
+ from collections import namedtuple
21
+ from types import MappingProxyType
22
+ from typing import TYPE_CHECKING, Any, Optional, Union
23
+ from urllib.parse import urlparse
24
+
25
+ import numpy as np
26
+ import pandas as pd
27
+ import yaml
28
+ from packaging.version import Version
29
+
30
+ from mlflow import pyfunc
31
+ from mlflow.entities.model_registry.prompt import Prompt
32
+ from mlflow.environment_variables import (
33
+ MLFLOW_DEFAULT_PREDICTION_DEVICE,
34
+ MLFLOW_HUGGINGFACE_DEVICE_MAP_STRATEGY,
35
+ MLFLOW_HUGGINGFACE_USE_DEVICE_MAP,
36
+ MLFLOW_HUGGINGFACE_USE_LOW_CPU_MEM_USAGE,
37
+ MLFLOW_INPUT_EXAMPLE_INFERENCE_TIMEOUT,
38
+ )
39
+ from mlflow.exceptions import MlflowException
40
+ from mlflow.models import (
41
+ Model,
42
+ ModelInputExample,
43
+ ModelSignature,
44
+ )
45
+ from mlflow.models.model import MLMODEL_FILE_NAME
46
+ from mlflow.models.utils import _save_example
47
+ from mlflow.protos.databricks_pb2 import (
48
+ BAD_REQUEST,
49
+ INVALID_PARAMETER_VALUE,
50
+ RESOURCE_DOES_NOT_EXIST,
51
+ )
52
+ from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository
53
+ from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
54
+ from mlflow.tracking.artifact_utils import _get_root_uri_and_artifact_path
55
+ from mlflow.transformers.flavor_config import (
56
+ FlavorKey,
57
+ build_flavor_config,
58
+ build_flavor_config_from_local_checkpoint,
59
+ update_flavor_conf_to_persist_pretrained_model,
60
+ )
61
+ from mlflow.transformers.hub_utils import (
62
+ is_valid_hf_repo_id,
63
+ )
64
+ from mlflow.transformers.llm_inference_utils import (
65
+ _LLM_INFERENCE_TASK_CHAT,
66
+ _LLM_INFERENCE_TASK_COMPLETIONS,
67
+ _LLM_INFERENCE_TASK_EMBEDDING,
68
+ _LLM_INFERENCE_TASK_KEY,
69
+ _LLM_INFERENCE_TASK_PREFIX,
70
+ _METADATA_LLM_INFERENCE_TASK_KEY,
71
+ _SUPPORTED_LLM_INFERENCE_TASK_TYPES_BY_PIPELINE_TASK,
72
+ _get_default_task_for_llm_inference_task,
73
+ convert_messages_to_prompt,
74
+ infer_signature_from_llm_inference_task,
75
+ postprocess_output_for_llm_inference_task,
76
+ postprocess_output_for_llm_v1_embedding_task,
77
+ preprocess_llm_embedding_params,
78
+ preprocess_llm_inference_input,
79
+ )
80
+ from mlflow.transformers.model_io import (
81
+ _COMPONENTS_BINARY_DIR_NAME,
82
+ _MODEL_BINARY_FILE_NAME,
83
+ load_model_and_components_from_huggingface_hub,
84
+ load_model_and_components_from_local,
85
+ save_local_checkpoint,
86
+ save_pipeline_pretrained_weights,
87
+ )
88
+ from mlflow.transformers.peft import (
89
+ _PEFT_ADAPTOR_DIR_NAME,
90
+ get_model_with_peft_adapter,
91
+ get_peft_base_model,
92
+ is_peft_model,
93
+ )
94
+ from mlflow.transformers.signature import (
95
+ format_input_example_for_special_cases,
96
+ infer_or_get_default_signature,
97
+ )
98
+ from mlflow.transformers.torch_utils import _TORCH_DTYPE_KEY, _deserialize_torch_dtype
99
+ from mlflow.types.utils import _validate_input_dictionary_contains_only_strings_and_lists_of_strings
100
+ from mlflow.utils import _truncate_and_ellipsize
101
+ from mlflow.utils.autologging_utils import (
102
+ autologging_integration,
103
+ disable_discrete_autologging,
104
+ safe_patch,
105
+ )
106
+ from mlflow.utils.docstring_utils import (
107
+ LOG_MODEL_PARAM_DOCS,
108
+ docstring_version_compatibility_warning,
109
+ format_docstring,
110
+ )
111
+ from mlflow.utils.environment import (
112
+ _CONDA_ENV_FILE_NAME,
113
+ _CONSTRAINTS_FILE_NAME,
114
+ _PYTHON_ENV_FILE_NAME,
115
+ _REQUIREMENTS_FILE_NAME,
116
+ _mlflow_conda_env,
117
+ _process_conda_env,
118
+ _process_pip_requirements,
119
+ _PythonEnv,
120
+ _validate_env_arguments,
121
+ infer_pip_requirements,
122
+ )
123
+ from mlflow.utils.file_utils import TempDir, get_total_file_size, write_to
124
+ from mlflow.utils.logging_utils import suppress_logs
125
+ from mlflow.utils.model_utils import (
126
+ _add_code_from_conf_to_system_path,
127
+ _download_artifact_from_uri,
128
+ _get_flavor_configuration,
129
+ _get_flavor_configuration_from_uri,
130
+ _validate_and_copy_code_paths,
131
+ _validate_and_prepare_target_save_path,
132
+ )
133
+ from mlflow.utils.requirements_utils import _get_pinned_requirement
134
+
135
+ # The following import is only used for type hinting
136
+ if TYPE_CHECKING:
137
+ import torch
138
+ from transformers import Pipeline
139
+
140
+ # Transformers pipeline complains that PeftModel is not supported for any task type, even
141
+ # when the wrapped model is supported. As MLflow require users to use pipeline for logging,
142
+ # we should suppress that confusing error message.
143
+ _PEFT_PIPELINE_ERROR_MSG = re.compile(r"The model 'PeftModel[^']*' is not supported for")
144
+
145
+ FLAVOR_NAME = "transformers"
146
+
147
+ _CARD_TEXT_FILE_NAME = "model_card.md"
148
+ _CARD_DATA_FILE_NAME = "model_card_data.yaml"
149
+ _INFERENCE_CONFIG_BINARY_KEY = "inference_config.txt"
150
+ _LICENSE_FILE_NAME = "LICENSE.txt"
151
+ _LICENSE_FILE_PATTERN = re.compile(r"license(\.[a-z]+|$)", re.IGNORECASE)
152
+
153
+ _SUPPORTED_RETURN_TYPES = {"pipeline", "components"}
154
+ # The default device id for CPU is -1 and GPU IDs are ordinal starting at 0, as documented here:
155
+ # https://huggingface.co/transformers/v4.7.0/main_classes/pipelines.html
156
+ _TRANSFORMERS_DEFAULT_CPU_DEVICE_ID = -1
157
+ _TRANSFORMERS_DEFAULT_GPU_DEVICE_ID = 0
158
+ _SUPPORTED_SAVE_KEYS = {
159
+ FlavorKey.MODEL,
160
+ FlavorKey.TOKENIZER,
161
+ FlavorKey.FEATURE_EXTRACTOR,
162
+ FlavorKey.IMAGE_PROCESSOR,
163
+ FlavorKey.TORCH_DTYPE,
164
+ }
165
+
166
+ _SUPPORTED_PROMPT_TEMPLATING_TASK_TYPES = {
167
+ "feature-extraction",
168
+ "fill-mask",
169
+ "summarization",
170
+ "text2text-generation",
171
+ "text-generation",
172
+ }
173
+
174
+ _PROMPT_TEMPLATE_RETURN_FULL_TEXT_INFO = (
175
+ "text-generation pipelines saved with prompt templates have the `return_full_text` "
176
+ "pipeline kwarg set to False by default. To override this behavior, provide a "
177
+ "`model_config` dict with `return_full_text` set to `True` when saving the model."
178
+ )
179
+
180
+
181
+ # Alias for the audio data types that Transformers pipeline (e.g. Whisper) expects.
182
+ # It can be one of:
183
+ # 1. A string representing the path or URL to an audio file.
184
+ # 2. A bytes object representing the raw audio data.
185
+ # 3. A float numpy array representing the audio time series.
186
+ AudioInput = Union[str, bytes, np.ndarray]
187
+
188
+ _logger = logging.getLogger(__name__)
189
+
190
+
191
+ def get_default_pip_requirements(model) -> list[str]:
192
+ """
193
+ Args:
194
+ model: The model instance to be saved in order to provide the required underlying
195
+ deep learning execution framework dependency requirements. Note that this must
196
+ be the actual model instance and not a Pipeline.
197
+
198
+ Returns:
199
+ A list of default pip requirements for MLflow Models that have been produced with the
200
+ ``transformers`` flavor. Calls to :py:func:`save_model()` and :py:func:`log_model()`
201
+ produce a pip environment that contain these requirements at a minimum.
202
+ """
203
+ packages = ["transformers"]
204
+
205
+ try:
206
+ engine = _get_engine_type(model)
207
+ packages.append(engine)
208
+ except Exception as e:
209
+ packages += ["torch", "tensorflow"]
210
+ _logger.warning(
211
+ "Could not infer model execution engine type due to huggingface_hub not "
212
+ "being installed or unable to connect in online mode. Adding both Pytorch"
213
+ f"and Tensorflow to requirements.\nFailure cause: {e}"
214
+ )
215
+
216
+ if "torch" in packages:
217
+ packages.append("torchvision")
218
+ if importlib.util.find_spec("accelerate"):
219
+ packages.append("accelerate")
220
+
221
+ if is_peft_model(model):
222
+ packages.append("peft")
223
+
224
+ return [_get_pinned_requirement(module) for module in packages]
225
+
226
+
227
+ def _validate_transformers_model_dict(transformers_model):
228
+ """
229
+ Validator for a submitted save dictionary for the transformers model. If any additional keys
230
+ are provided, raise to indicate which invalid keys were submitted.
231
+ """
232
+ if isinstance(transformers_model, dict):
233
+ invalid_keys = [key for key in transformers_model.keys() if key not in _SUPPORTED_SAVE_KEYS]
234
+ if invalid_keys:
235
+ raise MlflowException(
236
+ "Invalid dictionary submitted for 'transformers_model'. The "
237
+ f"key(s) {invalid_keys} are not permitted. Must be one of: "
238
+ f"{_SUPPORTED_SAVE_KEYS}",
239
+ error_code=INVALID_PARAMETER_VALUE,
240
+ )
241
+ if FlavorKey.MODEL not in transformers_model:
242
+ raise MlflowException(
243
+ f"The 'transformers_model' dictionary must have an entry for {FlavorKey.MODEL}",
244
+ error_code=INVALID_PARAMETER_VALUE,
245
+ )
246
+ model = transformers_model[FlavorKey.MODEL]
247
+ else:
248
+ model = transformers_model.model
249
+ if not hasattr(model, "name_or_path"):
250
+ raise MlflowException(
251
+ f"The submitted model type {type(model).__name__} does not inherit "
252
+ "from a transformers pre-trained model. It is missing the attribute "
253
+ "'name_or_path'. Please verify that the model is a supported "
254
+ "transformers model.",
255
+ error_code=INVALID_PARAMETER_VALUE,
256
+ )
257
+
258
+
259
+ def get_default_conda_env(model):
260
+ """
261
+ Returns:
262
+ The default Conda environment for MLflow Models produced with the ``transformers``
263
+ flavor, based on the model instance framework type of the model to be logged.
264
+ """
265
+ return _mlflow_conda_env(additional_pip_deps=get_default_pip_requirements(model))
266
+
267
+
268
+ @docstring_version_compatibility_warning(integration_name=FLAVOR_NAME)
269
+ @format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name=FLAVOR_NAME))
270
+ def save_model(
271
+ transformers_model,
272
+ path: str,
273
+ processor=None,
274
+ task: Optional[str] = None,
275
+ torch_dtype: Optional[torch.dtype] = None,
276
+ model_card=None,
277
+ code_paths: Optional[list[str]] = None,
278
+ mlflow_model: Optional[Model] = None,
279
+ signature: Optional[ModelSignature] = None,
280
+ input_example: Optional[ModelInputExample] = None,
281
+ pip_requirements: Optional[Union[list[str], str]] = None,
282
+ extra_pip_requirements: Optional[Union[list[str], str]] = None,
283
+ conda_env=None,
284
+ metadata: Optional[dict[str, Any]] = None,
285
+ model_config: Optional[dict[str, Any]] = None,
286
+ prompt_template: Optional[str] = None,
287
+ save_pretrained: bool = True,
288
+ **kwargs, # pylint: disable=unused-argument
289
+ ) -> None:
290
+ """
291
+ Save a trained transformers model to a path on the local file system. Note that
292
+ saving transformers models with custom code (i.e. models that require
293
+ ``trust_remote_code=True``) requires ``transformers >= 4.26.0``.
294
+
295
+ Args:
296
+ transformers_model:
297
+ The transformers model to save. This can be one of the following format:
298
+
299
+ 1. A transformers `Pipeline` instance.
300
+ 2. A dictionary that maps required components of a pipeline to the named keys
301
+ of ["model", "image_processor", "tokenizer", "feature_extractor"].
302
+ The `model` key in the dictionary must map to a value that inherits from
303
+ `PreTrainedModel`, `TFPreTrainedModel`, or `FlaxPreTrainedModel`.
304
+ All other component entries in the dictionary must support the defined task
305
+ type that is associated with the base model type configuration.
306
+ 3. A string that represents a path to a local/DBFS directory containing a model
307
+ checkpoint. The directory must contain a `config.json` file that is required
308
+ for loading the transformers model. This is particularly useful when logging
309
+ a model that cannot be loaded into memory for serialization.
310
+
311
+ An example of specifying a `Pipeline` from a default pipeline instantiation:
312
+
313
+ .. code-block:: python
314
+
315
+ from transformers import pipeline
316
+
317
+ qa_pipe = pipeline("question-answering", "csarron/mobilebert-uncased-squad-v2")
318
+
319
+ with mlflow.start_run():
320
+ mlflow.transformers.save_model(
321
+ transformers_model=qa_pipe,
322
+ path="path/to/save/model",
323
+ )
324
+
325
+ An example of specifying component-level parts of a transformers model is shown below:
326
+
327
+ .. code-block:: python
328
+
329
+ from transformers import MobileBertForQuestionAnswering, AutoTokenizer
330
+
331
+ architecture = "csarron/mobilebert-uncased-squad-v2"
332
+ tokenizer = AutoTokenizer.from_pretrained(architecture)
333
+ model = MobileBertForQuestionAnswering.from_pretrained(architecture)
334
+
335
+ with mlflow.start_run():
336
+ components = {
337
+ "model": model,
338
+ "tokenizer": tokenizer,
339
+ }
340
+ mlflow.transformers.save_model(
341
+ transformers_model=components,
342
+ path="path/to/save/model",
343
+ )
344
+
345
+ An example of specifying a local checkpoint path is shown below:
346
+
347
+ .. code-block:: python
348
+
349
+ with mlflow.start_run():
350
+ mlflow.transformers.save_model(
351
+ transformers_model="path/to/local/checkpoint",
352
+ path="path/to/save/model",
353
+ )
354
+
355
+ path: Local path destination for the serialized model to be saved.
356
+ processor: An optional ``Processor`` subclass object. Some model architectures,
357
+ particularly multi-modal types, utilize Processors to combine text
358
+ encoding and image or audio encoding in a single entrypoint.
359
+
360
+ .. Note:: If a processor is supplied when saving a model, the
361
+ model will be unavailable for loading as a ``Pipeline`` or for
362
+ usage with pyfunc inference.
363
+ task: The transformers-specific task type of the model, or MLflow inference task type.
364
+ If provided a transformers-specific task type, these strings are utilized so
365
+ that a pipeline can be created with the appropriate internal call architecture
366
+ to meet the needs of a given model.
367
+ If this argument is provided as a inference task type or not specified, the
368
+ pipeline utilities within the transformers library will be used to infer the
369
+ correct task type. If the value specified is not a supported type,
370
+ an Exception will be thrown.
371
+ torch_dtype: The Pytorch dtype applied to the model when loading back. This is useful
372
+ when you want to save the model with a specific dtype that is different from the
373
+ dtype of the model when it was trained. If not specified, the current dtype of the
374
+ model instance will be used.
375
+ model_card: An Optional `ModelCard` instance from `huggingface-hub`. If provided, the
376
+ contents of the model card will be saved along with the provided
377
+ `transformers_model`. If not provided, an attempt will be made to fetch
378
+ the card from the base pretrained model that is provided (or the one that is
379
+ included within a provided `Pipeline`).
380
+
381
+ .. Note:: In order for a ModelCard to be fetched (if not provided),
382
+ the huggingface_hub package must be installed and the version
383
+ must be >=0.10.0
384
+
385
+ code_paths: {{ code_paths }}
386
+ mlflow_model: An MLflow model object that specifies the flavor that this model is being
387
+ added to.
388
+ signature: A Model Signature object that describes the input and output Schema of the
389
+ model. The model signature can be inferred using `infer_signature` function
390
+ of `mlflow.models.signature`.
391
+
392
+ .. code-block:: python
393
+ :caption: Example
394
+
395
+ from mlflow.models import infer_signature
396
+ from mlflow.transformers import generate_signature_output
397
+ from transformers import pipeline
398
+
399
+ en_to_de = pipeline("translation_en_to_de")
400
+
401
+ data = "MLflow is great!"
402
+ output = generate_signature_output(en_to_de, data)
403
+ signature = infer_signature(data, output)
404
+
405
+ mlflow.transformers.save_model(
406
+ transformers_model=en_to_de,
407
+ path="/path/to/save/model",
408
+ signature=signature,
409
+ input_example=data,
410
+ )
411
+
412
+ loaded = mlflow.pyfunc.load_model("/path/to/save/model")
413
+ print(loaded.predict(data))
414
+ # MLflow ist großartig!
415
+
416
+ If an input_example is provided and the signature is not, a signature will
417
+ be inferred automatically and applied to the MLmodel file iff the
418
+ pipeline type is a text-based model (NLP). If the pipeline type is not
419
+ a supported type, this inference functionality will not function correctly
420
+ and a warning will be issued. In order to ensure that a precise signature
421
+ is logged, it is recommended to explicitly provide one.
422
+ input_example: {{ input_example }}
423
+ pip_requirements: {{ pip_requirements }}
424
+ extra_pip_requirements: {{ extra_pip_requirements }}
425
+ conda_env: {{ conda_env }}
426
+ metadata: {{ metadata }}
427
+ model_config:
428
+ A dict of valid overrides that can be applied to a pipeline instance during inference.
429
+ These arguments are used exclusively for the case of loading the model as a ``pyfunc``
430
+ Model or for use in Spark.
431
+ These values are not applied to a returned Pipeline from a call to
432
+ ``mlflow.transformers.load_model()``
433
+
434
+ .. Warning:: If the key provided is not compatible with either the
435
+ Pipeline instance for the task provided or is not a valid
436
+ override to any arguments available in the Model, an
437
+ Exception will be raised at runtime. It is very important
438
+ to validate the entries in this dictionary to ensure
439
+ that they are valid prior to saving or logging.
440
+
441
+ An example of providing overrides for a question generation model:
442
+
443
+ .. code-block:: python
444
+
445
+ from transformers import pipeline, AutoTokenizer
446
+
447
+ task = "text-generation"
448
+ architecture = "gpt2"
449
+
450
+ sentence_pipeline = pipeline(
451
+ task=task,
452
+ tokenizer=AutoTokenizer.from_pretrained(architecture),
453
+ model=architecture,
454
+ )
455
+
456
+ # Validate that the overrides function
457
+ prompts = ["Generative models are", "I'd like a coconut so that I can"]
458
+
459
+ # validation of config prior to save or log
460
+ model_config = {
461
+ "top_k": 2,
462
+ "num_beams": 5,
463
+ "max_length": 30,
464
+ "temperature": 0.62,
465
+ "top_p": 0.85,
466
+ "repetition_penalty": 1.15,
467
+ }
468
+
469
+ # Verify that no exceptions are thrown
470
+ sentence_pipeline(prompts, **model_config)
471
+
472
+ mlflow.transformers.save_model(
473
+ transformers_model=sentence_pipeline,
474
+ path="/path/for/model",
475
+ task=task,
476
+ model_config=model_config,
477
+ )
478
+ prompt_template: {{ prompt_template }}
479
+ save_pretrained: {{ save_pretrained }}
480
+ kwargs: Optional additional configurations for transformers serialization.
481
+
482
+ """
483
+ import transformers
484
+
485
+ _validate_env_arguments(conda_env, pip_requirements, extra_pip_requirements)
486
+
487
+ path = pathlib.Path(path).absolute()
488
+
489
+ _validate_and_prepare_target_save_path(str(path))
490
+
491
+ code_dir_subpath = _validate_and_copy_code_paths(code_paths, str(path))
492
+
493
+ if isinstance(transformers_model, transformers.Pipeline):
494
+ _validate_transformers_model_dict(transformers_model)
495
+ built_pipeline = transformers_model
496
+ elif isinstance(transformers_model, dict):
497
+ _validate_transformers_model_dict(transformers_model)
498
+ built_pipeline = _build_pipeline_from_model_input(transformers_model, task=task)
499
+ elif isinstance(transformers_model, str):
500
+ # When a string is passed, it should be a path to model checkpoint in local storage or DBFS
501
+ if transformers_model.startswith("dbfs:"):
502
+ # Replace the DBFS URI to the actual mount point
503
+ transformers_model = transformers_model.replace("dbfs:", "/dbfs", 1)
504
+
505
+ if task is None:
506
+ raise MlflowException(
507
+ "The `task` argument must be specified when logging a model from a local "
508
+ "checkpoint. Please provide the task type of the pipeline.",
509
+ error_code=INVALID_PARAMETER_VALUE,
510
+ )
511
+
512
+ if not save_pretrained:
513
+ raise MlflowException(
514
+ "The `save_pretrained` argument must be set to True when logging a model from a "
515
+ "local checkpoint. Please set `save_pretrained=True`.",
516
+ error_code=INVALID_PARAMETER_VALUE,
517
+ )
518
+
519
+ # Create a dummy pipeline object to be used for saving the model
520
+ DummyModel = namedtuple("DummyModel", ["name_or_path"])
521
+ DummyPipeline = namedtuple("DummyPipeline", ["task", "model"])
522
+ built_pipeline = DummyPipeline(task=task, model=DummyModel(name_or_path=transformers_model))
523
+ else:
524
+ raise MlflowException(
525
+ "The `transformers_model` must be one of the following types: \n"
526
+ " (1) a transformers Pipeline\n"
527
+ " (2) a dictionary of components for a transformers Pipeline\n"
528
+ " (3) a path to a local/DBFS directory containing a transformers model checkpoint.\n"
529
+ f"received: {type(transformers_model)}",
530
+ error_code=INVALID_PARAMETER_VALUE,
531
+ )
532
+
533
+ # Verify that the model has not been loaded to distributed memory
534
+ # NB: transformers does not correctly save a model whose weights have been loaded
535
+ # using accelerate iff the model weights have been loaded using a device_map that is
536
+ # heterogeneous. There is a distinct possibility for a partial write to occur, causing an
537
+ # invalid state of the model's weights in this scenario. Hence, we raise.
538
+ # We might be able to remove this check once this PR is merged to transformers:
539
+ # https://github.com/huggingface/transformers/issues/20072
540
+ if _is_model_distributed_in_memory(built_pipeline.model):
541
+ raise MlflowException(
542
+ "The model that is attempting to be saved has been loaded into memory "
543
+ "with an incompatible configuration. If you are using the accelerate "
544
+ "library to load your model, please ensure that it is saved only after "
545
+ "loading with the default device mapping. Do not specify `device_map` "
546
+ "and please try again."
547
+ )
548
+
549
+ if mlflow_model is None:
550
+ mlflow_model = Model()
551
+
552
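+ # NB: "llm/v1/..." values are MLflow-level inference task types layered on top of the
+ # native Transformers task types. As an illustrative (assumed) example, task="llm/v1/chat"
+ # typically maps to the "text-generation" pipeline task under the hood.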
+ if task and task.startswith(_LLM_INFERENCE_TASK_PREFIX):
553
+ llm_inference_task = task
554
+
555
+ # For local checkpoint saving, we set built_pipeline.task to the original `task`
556
+ # argument value earlier, which is the LLM v1 task. Here we update it to the
557
+ # corresponding Transformers task type.
558
+ if isinstance(transformers_model, str):
559
+ default_task = _get_default_task_for_llm_inference_task(llm_inference_task)
560
+ built_pipeline = built_pipeline._replace(task=default_task)
561
+
562
+ _validate_llm_inference_task_type(llm_inference_task, built_pipeline.task)
563
+ else:
564
+ llm_inference_task = None
565
+
566
+ if llm_inference_task:
567
+ mlflow_model.signature = infer_signature_from_llm_inference_task(
568
+ llm_inference_task, signature
569
+ )
570
+ elif signature is not None:
571
+ mlflow_model.signature = signature
572
+
573
+ if input_example is not None:
574
+ input_example = format_input_example_for_special_cases(input_example, built_pipeline)
575
+ _save_example(mlflow_model, input_example, str(path))
576
+
577
+ if metadata is not None:
578
+ mlflow_model.metadata = metadata
579
+
580
+ # Check task consistency between model metadata and task argument
581
+ # NB: Using mlflow_model.metadata instead of passed metadata argument directly, because
582
+ # metadata argument is not directly propagated from log_model() to save_model(), instead
583
+ # via the mlflow_model object attribute.
584
+ if (
585
+ mlflow_model.metadata is not None
586
+ and (metadata_task := mlflow_model.metadata.get(_METADATA_LLM_INFERENCE_TASK_KEY))
587
+ and metadata_task != task
588
+ ):
589
+ raise MlflowException(
590
+ f"LLM v1 task type '{metadata_task}' is specified in "
591
+ "metadata, but it doesn't match the task type provided in the `task` argument: "
592
+ f"'{task}'. The mismatched task type may cause incorrect model inference behavior. "
593
+ "Please provide the correct LLM v1 task type in the `task` argument. E.g. "
594
+ f'`mlflow.transformers.save_model(task="{metadata_task}", ...)`',
595
+ error_code=INVALID_PARAMETER_VALUE,
596
+ )
597
+
598
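+ # NB: A prompt template is a format string that wraps user input before it is passed
+ # to the pipeline. Illustrative (assumed) example:
+ #   prompt_template="Answer the following question concisely: {prompt}"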
+ if prompt_template is not None:
599
+ # prevent saving prompt templates for unsupported pipeline types
600
+ if built_pipeline.task not in _SUPPORTED_PROMPT_TEMPLATING_TASK_TYPES:
601
+ raise MlflowException(
602
+ f"Prompt templating is not supported for the `{built_pipeline.task}` task type. "
603
+ f"Supported task types are: {_SUPPORTED_PROMPT_TEMPLATING_TASK_TYPES}."
604
+ )
605
+
606
+ _validate_prompt_template(prompt_template)
607
+ if mlflow_model.metadata:
608
+ mlflow_model.metadata[FlavorKey.PROMPT_TEMPLATE] = prompt_template
609
+ else:
610
+ mlflow_model.metadata = {FlavorKey.PROMPT_TEMPLATE: prompt_template}
611
+
612
+ if is_peft_model(built_pipeline.model):
613
+ _logger.info(
614
+ "Overriding save_pretrained to False for PEFT models, following the Transformers "
615
+ "behavior. The PEFT adaptor and config will be saved, but the base model weights "
616
+ "will not and reference to the HuggingFace Hub repository will be logged instead."
617
+ )
618
+ # This will only save PEFT adaptor weights and config, not the base model weights
619
+ built_pipeline.model.save_pretrained(path.joinpath(_PEFT_ADAPTOR_DIR_NAME))
620
+ save_pretrained = False
621
+
622
+ if not save_pretrained and not is_valid_hf_repo_id(built_pipeline.model.name_or_path):
623
+ _logger.warning(
624
+ "The save_pretrained parameter is set to False, but the specified model does not "
625
+ "have a valid HuggingFace Hub repository identifier. Therefore, the weights will "
626
+ "be saved to disk anyway."
627
+ )
628
+ save_pretrained = True
629
+
630
+ # Create the flavor configuration
631
+ if isinstance(transformers_model, str):
632
+ flavor_conf = build_flavor_config_from_local_checkpoint(
633
+ transformers_model, built_pipeline.task, processor, torch_dtype
634
+ )
635
+ else:
636
+ flavor_conf = build_flavor_config(built_pipeline, processor, torch_dtype, save_pretrained)
637
+
638
+ if llm_inference_task:
639
+ flavor_conf.update({_LLM_INFERENCE_TASK_KEY: llm_inference_task})
640
+ if mlflow_model.metadata:
641
+ mlflow_model.metadata[_METADATA_LLM_INFERENCE_TASK_KEY] = llm_inference_task
642
+ else:
643
+ mlflow_model.metadata = {_METADATA_LLM_INFERENCE_TASK_KEY: llm_inference_task}
644
+
645
+ mlflow_model.add_flavor(
646
+ FLAVOR_NAME,
647
+ transformers_version=transformers.__version__,
648
+ code=code_dir_subpath,
649
+ **flavor_conf,
650
+ )
651
+
652
+ # Flavor config should not be mutated after being added to MLModel
653
+ flavor_conf = MappingProxyType(flavor_conf)
654
+
655
+ # Save pipeline model and components weights
656
+ if save_pretrained:
657
+ if isinstance(transformers_model, str):
658
+ save_local_checkpoint(path, transformers_model, flavor_conf, processor)
659
+ else:
660
+ save_pipeline_pretrained_weights(path, built_pipeline, flavor_conf, processor)
661
+ else:
662
+ repo = built_pipeline.model.name_or_path
663
+ _logger.info(
664
+ "Skipping saving pretrained model weights to disk as the save_pretrained argument"
665
+ f"is set to False. The reference to the HuggingFace Hub repository {repo} "
666
+ "will be logged instead."
667
+ )
668
+
669
+ model_name = built_pipeline.model.name_or_path
670
+
671
+ # Get the model card from either the argument or the HuggingFace marketplace
672
+ card_data = model_card or _fetch_model_card(model_name)
673
+
674
+ # If the card data can be acquired, save the text and the data separately
675
+ _write_card_data(card_data, path)
676
+
677
+ # Write the license information (or guidance) along with the model
678
+ _write_license_information(model_name, card_data, path)
679
+
680
+ # Only allow a subset of task types to have a pyfunc definition.
681
+ # Currently supported types are NLP-based language tasks which have a pipeline definition
682
+ # consisting exclusively of a Model and a Tokenizer.
683
+ if (
684
+ # TODO: when a local checkpoint path is provided as a model, we assume it is eligible
685
+ # for pyfunc prediction. This may not be true for all cases, so we should revisit this.
686
+ isinstance(transformers_model, str) or _should_add_pyfunc_to_model(built_pipeline)
687
+ ):
688
+ if mlflow_model.signature is None:
689
+ mlflow_model.signature = infer_or_get_default_signature(
690
+ pipeline=built_pipeline,
691
+ example=input_example,
692
+ model_config=model_config,
693
+ flavor_config=flavor_conf,
694
+ )
695
+
696
+ # if pipeline is text-generation and a prompt template is specified,
697
+ # provide the return_full_text=False config by default to avoid confusing
698
+ # end-users with extra prompt text
699
+ if prompt_template is not None and built_pipeline.task == "text-generation":
700
+ return_full_text_key = "return_full_text"
701
+ model_config = model_config or {}
702
+ if return_full_text_key not in model_config:
703
+ model_config[return_full_text_key] = False
704
+ _logger.info(_PROMPT_TEMPLATE_RETURN_FULL_TEXT_INFO)
705
+
706
+ pyfunc.add_to_model(
707
+ mlflow_model,
708
+ loader_module="mlflow.transformers",
709
+ conda_env=_CONDA_ENV_FILE_NAME,
710
+ python_env=_PYTHON_ENV_FILE_NAME,
711
+ code=code_dir_subpath,
712
+ model_config=model_config,
713
+ )
714
+ else:
715
+ if processor:
716
+ reason = "the model has been saved with a 'processor' argument supplied."
717
+ else:
718
+ reason = (
719
+ "the model is not a language-based model and requires a complex input type "
720
+ "that is currently not supported."
721
+ )
722
+ _logger.warning(
723
+ f"This model is unable to be used for pyfunc prediction because {reason} "
724
+ f"The pyfunc flavor will not be added to the Model."
725
+ )
726
+
727
+ if size := get_total_file_size(path):
728
+ mlflow_model.model_size_bytes = size
729
+
730
+ mlflow_model.save(str(path.joinpath(MLMODEL_FILE_NAME)))
731
+
732
+ if conda_env is None:
733
+ if pip_requirements is None:
734
+ default_reqs = get_default_pip_requirements(built_pipeline.model)
735
+ if isinstance(transformers_model, str) or is_peft_model(built_pipeline.model):
736
+ _logger.info(
737
+ "A local checkpoint path or PEFT model is given as the `transformers_model`. "
738
+ "To avoid loading the full model into memory, we don't infer the pip "
739
+ "requirement for the model. Instead, we will use the default requirements, "
740
+ "but it may not capture all required pip libraries for the model. Consider "
741
+ "providing the pip requirements explicitly."
742
+ )
743
+ else:
744
+ # Infer the pip requirements with a timeout to avoid hanging at prediction
745
+ inferred_reqs = infer_pip_requirements(
746
+ model_uri=str(path),
747
+ flavor=FLAVOR_NAME,
748
+ fallback=default_reqs,
749
+ timeout=MLFLOW_INPUT_EXAMPLE_INFERENCE_TIMEOUT.get(),
750
+ )
751
+ default_reqs = set(inferred_reqs).union(default_reqs)
752
+ default_reqs = sorted(default_reqs)
753
+ else:
754
+ default_reqs = None
755
+ conda_env, pip_requirements, pip_constraints = _process_pip_requirements(
756
+ default_reqs, pip_requirements, extra_pip_requirements
757
+ )
758
+ else:
759
+ conda_env, pip_requirements, pip_constraints = _process_conda_env(conda_env)
760
+
761
+ with path.joinpath(_CONDA_ENV_FILE_NAME).open("w") as f:
762
+ yaml.safe_dump(conda_env, stream=f, default_flow_style=False)
763
+
764
+ if pip_constraints:
765
+ write_to(str(path.joinpath(_CONSTRAINTS_FILE_NAME)), "\n".join(pip_constraints))
766
+
767
+ write_to(str(path.joinpath(_REQUIREMENTS_FILE_NAME)), "\n".join(pip_requirements))
768
+
769
+ _PythonEnv.current().to_yaml(str(path.joinpath(_PYTHON_ENV_FILE_NAME)))
770
+
771
+
772
+ @docstring_version_compatibility_warning(integration_name=FLAVOR_NAME)
773
+ @format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name=FLAVOR_NAME))
774
+ def log_model(
775
+ transformers_model,
776
+ artifact_path: Optional[str] = None,
777
+ processor=None,
778
+ task: Optional[str] = None,
779
+ torch_dtype: Optional[torch.dtype] = None,
780
+ model_card=None,
781
+ code_paths: Optional[list[str]] = None,
782
+ registered_model_name: Optional[str] = None,
783
+ signature: Optional[ModelSignature] = None,
784
+ input_example: Optional[ModelInputExample] = None,
785
+ await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
786
+ pip_requirements: Optional[Union[list[str], str]] = None,
787
+ extra_pip_requirements: Optional[Union[list[str], str]] = None,
788
+ conda_env=None,
789
+ metadata: Optional[dict[str, Any]] = None,
790
+ model_config: Optional[dict[str, Any]] = None,
791
+ prompt_template: Optional[str] = None,
792
+ save_pretrained: bool = True,
793
+ prompts: Optional[list[Union[str, Prompt]]] = None,
794
+ name: Optional[str] = None,
795
+ params: Optional[dict[str, Any]] = None,
796
+ tags: Optional[dict[str, Any]] = None,
797
+ model_type: Optional[str] = None,
798
+ step: int = 0,
799
+ model_id: Optional[str] = None,
800
+ **kwargs,
801
+ ):
802
+ """
803
+ Log a ``transformers`` object as an MLflow artifact for the current run. Note that
804
+ logging transformers models with custom code (i.e. models that require
805
+ ``trust_remote_code=True``) requires ``transformers >= 4.26.0``.
806
+
807
+ Args:
808
+ transformers_model:
809
+ The transformers model to save. This can be one of the following formats:
810
+
811
+ 1. A transformers `Pipeline` instance.
812
+ 2. A dictionary that maps required components of a pipeline to the named keys
813
+ of ["model", "image_processor", "tokenizer", "feature_extractor"].
814
+ The `model` key in the dictionary must map to a value that inherits from
815
+ `PreTrainedModel`, `TFPreTrainedModel`, or `FlaxPreTrainedModel`.
816
+ All other component entries in the dictionary must support the defined task
817
+ type that is associated with the base model type configuration.
818
+ 3. A string that represents a path to a local/DBFS directory containing a model
819
+ checkpoint. The directory must contain a `config.json` file that is required
820
+ for loading the transformers model. This is particularly useful when logging
821
+ a model that cannot be loaded into memory for serialization.
822
+
823
+ An example of specifying a `Pipeline` from a default pipeline instantiation:
824
+
825
+ .. code-block:: python
826
+
827
+ from transformers import pipeline
828
+
829
+ qa_pipe = pipeline("question-answering", "csarron/mobilebert-uncased-squad-v2")
830
+
831
+ with mlflow.start_run():
832
+ mlflow.transformers.log_model(
833
+ transformers_model=qa_pipe,
834
+ name="model",
835
+ )
836
+
837
+ An example of specifying component-level parts of a transformers model is shown below:
838
+
839
+ .. code-block:: python
840
+
841
+ from transformers import MobileBertForQuestionAnswering, AutoTokenizer
842
+
843
+ architecture = "csarron/mobilebert-uncased-squad-v2"
844
+ tokenizer = AutoTokenizer.from_pretrained(architecture)
845
+ model = MobileBertForQuestionAnswering.from_pretrained(architecture)
846
+
847
+ with mlflow.start_run():
848
+ components = {
849
+ "model": model,
850
+ "tokenizer": tokenizer,
851
+ }
852
+ mlflow.transformers.log_model(
853
+ transformers_model=components,
854
+ name="model",
855
+ )
856
+
857
+ An example of specifying a local checkpoint path is shown below:
858
+
859
+ .. code-block:: python
860
+
861
+ with mlflow.start_run():
862
+ mlflow.transformers.log_model(
863
+ transformers_model="path/to/local/checkpoint",
864
+ name="model",
865
+ )
866
+
867
+ artifact_path: Deprecated. Use `name` instead.
868
+ processor: An optional ``Processor`` subclass object. Some model architectures,
869
+ particularly multi-modal types, utilize Processors to combine text
870
+ encoding and image or audio encoding in a single entrypoint.
871
+
872
+ .. Note:: If a processor is supplied when logging a model, the
873
+ model will be unavailable for loading as a ``Pipeline`` or for usage
874
+ with pyfunc inference.
875
+ task: The transformers-specific task type of the model. These strings are utilized so
876
+ that a pipeline can be created with the appropriate internal call architecture
877
+ to meet the needs of a given model. If this argument is not specified, the
878
+ pipeline utilities within the transformers library will be used to infer the
879
+ correct task type. If the value specified is not a supported type within the
880
+ version of transformers that is currently installed, an Exception will be thrown.
881
+ torch_dtype: The PyTorch dtype applied to the model when loading back. This is useful
882
+ when you want to save the model with a specific dtype that is different from the
883
+ dtype of the model when it was trained. If not specified, the current dtype of the
884
+ model instance will be used.
885
+ model_card: An Optional `ModelCard` instance from `huggingface-hub`. If provided, the
886
+ contents of the model card will be saved along with the provided
887
+ `transformers_model`. If not provided, an attempt will be made to fetch
888
+ the card from the base pretrained model that is provided (or the one that is
889
+ included within a provided `Pipeline`).
890
+
891
+ .. Note:: In order for a ModelCard to be fetched (if not provided),
892
+ the huggingface_hub package must be installed and the version
893
+ must be >=0.10.0
894
+
895
+ code_paths: {{ code_paths }}
896
+ registered_model_name: If given, create a model
897
+ version under ``registered_model_name``, also creating a
898
+ registered model if one with the given name does not exist.
899
+ signature: A Model Signature object that describes the input and output Schema of the
900
+ model. The model signature can be inferred using `infer_signature` function
901
+ of `mlflow.models.signature`.
902
+
903
+ .. code-block:: python
904
+ :caption: Example
905
+
906
+ from mlflow.models import infer_signature
907
+ from mlflow.transformers import generate_signature_output
908
+ from transformers import pipeline
909
+
910
+ en_to_de = pipeline("translation_en_to_de")
911
+
912
+ data = "MLflow is great!"
913
+ output = generate_signature_output(en_to_de, data)
914
+ signature = infer_signature(data, output)
915
+
916
+ with mlflow.start_run() as run:
917
+ mlflow.transformers.log_model(
918
+ transformers_model=en_to_de,
919
+ name="english_to_german_translator",
920
+ signature=signature,
921
+ input_example=data,
922
+ )
923
+
924
+ model_uri = f"runs:/{run.info.run_id}/english_to_german_translator"
925
+ loaded = mlflow.pyfunc.load_model(model_uri)
926
+
927
+ print(loaded.predict(data))
928
+ # MLflow ist großartig!
929
+
930
+ If an input_example is provided and the signature is not, a signature will
931
+ be inferred automatically and applied to the MLmodel file iff the
932
+ pipeline type is a text-based model (NLP). If the pipeline type is not
933
+ a supported type, this inference functionality will not function correctly
934
+ and a warning will be issued. In order to ensure that a precise signature
935
+ is logged, it is recommended to explicitly provide one.
936
+ input_example: {{ input_example }}
937
+ await_registration_for: Number of seconds to wait for the model version
938
+ to finish being created and to reach the ``READY`` status.
939
+ By default, the function waits for five minutes.
940
+ Specify 0 or None to skip waiting.
941
+ pip_requirements: {{ pip_requirements }}
942
+ extra_pip_requirements: {{ extra_pip_requirements }}
943
+ conda_env: {{ conda_env }}
944
+ metadata: {{ metadata }}
945
+ model_config:
946
+ A dict of valid overrides that can be applied to a pipeline instance during inference.
947
+ These arguments are used exclusively for the case of loading the model as a ``pyfunc``
948
+ Model or for use in Spark. These values are not applied to a returned Pipeline from a
949
+ call to ``mlflow.transformers.load_model()``
950
+
951
+ .. Warning:: If the key provided is not compatible with either the
952
+ Pipeline instance for the task provided or is not a valid
953
+ override to any arguments available in the Model, an
954
+ Exception will be raised at runtime. It is very important
955
+ to validate the entries in this dictionary to ensure
956
+ that they are valid prior to saving or logging.
957
+
958
+ An example of providing overrides for a question generation model:
959
+
960
+ .. code-block:: python
961
+
962
+ from transformers import pipeline, AutoTokenizer
963
+
964
+ task = "text-generation"
965
+ architecture = "gpt2"
966
+
967
+ sentence_pipeline = pipeline(
968
+ task=task,
969
+ tokenizer=AutoTokenizer.from_pretrained(architecture),
970
+ model=architecture,
971
+ )
972
+
973
+ # Validate that the overrides function
974
+ prompts = ["Generative models are", "I'd like a coconut so that I can"]
975
+
976
+ # validation of config prior to save or log
977
+ model_config = {
978
+ "top_k": 2,
979
+ "num_beams": 5,
980
+ "max_length": 30,
981
+ "temperature": 0.62,
982
+ "top_p": 0.85,
983
+ "repetition_penalty": 1.15,
984
+ }
985
+
986
+ # Verify that no exceptions are thrown
987
+ sentence_pipeline(prompts, **model_config)
988
+
989
+ with mlflow.start_run():
990
+ mlflow.transformers.log_model(
991
+ transformers_model=sentence_pipeline,
992
+ name="my_sentence_generator",
993
+ task=task,
994
+ model_config=model_config,
995
+ )
996
+ prompt_template: {{ prompt_template }}
997
+ save_pretrained: {{ save_pretrained }}
998
+ prompts: {{ prompts }}
999
+ name: {{ name }}
1000
+ params: {{ params }}
1001
+ tags: {{ tags }}
1002
+ model_type: {{ model_type }}
1003
+ step: {{ step }}
1004
+ model_id: {{ model_id }}
1005
+ kwargs: Additional arguments for :py:class:`mlflow.models.model.Model`
1006
+ """
1007
+ return Model.log(
1008
+ artifact_path=artifact_path,
1009
+ name=name,
1010
+ flavor=sys.modules[__name__], # Get the current module.
1011
+ registered_model_name=registered_model_name,
1012
+ await_registration_for=await_registration_for,
1013
+ metadata=metadata,
1014
+ transformers_model=transformers_model,
1015
+ processor=processor,
1016
+ task=task,
1017
+ torch_dtype=torch_dtype,
1018
+ model_card=model_card,
1019
+ conda_env=conda_env,
1020
+ code_paths=code_paths,
1021
+ signature=signature,
1022
+ input_example=input_example,
1023
+ # NB: We don't validate the serving input if the provided model is a path
1024
+ # to a local checkpoint. This is because the purpose of supporting that
1025
+ # input format is to avoid loading a large model into memory. Serving input
1026
+ # validation loads the model into memory and makes a prediction, which is
1027
+ # expensive and can cause OOM errors.
1028
+ validate_serving_input=not isinstance(transformers_model, str),
1029
+ pip_requirements=pip_requirements,
1030
+ extra_pip_requirements=extra_pip_requirements,
1031
+ model_config=model_config,
1032
+ prompt_template=prompt_template,
1033
+ save_pretrained=save_pretrained,
1034
+ prompts=prompts,
1035
+ params=params,
1036
+ tags=tags,
1037
+ model_type=model_type,
1038
+ step=step,
1039
+ model_id=model_id,
1040
+ **kwargs,
1041
+ )
1042
+
1043
+
1044
+ @docstring_version_compatibility_warning(integration_name=FLAVOR_NAME)
1045
+ def load_model(
1046
+ model_uri: str, dst_path: Optional[str] = None, return_type="pipeline", device=None, **kwargs
1047
+ ):
1048
+ """
1049
+ Load a ``transformers`` object from a local file or a run.
1050
+
1051
+ Args:
1052
+ model_uri: The location, in URI format, of the MLflow model. For example:
1053
+
1054
+ - ``/Users/me/path/to/local/model``
1055
+ - ``relative/path/to/local/model``
1056
+ - ``s3://my_bucket/path/to/model``
1057
+ - ``runs:/<mlflow_run_id>/run-relative/path/to/model``
1058
+ - ``mlflow-artifacts:/path/to/model``
1059
+
1060
+ For more information about supported URI schemes, see
1061
+ `Referencing Artifacts <https://www.mlflow.org/docs/latest/tracking.html#
1062
+ artifact-locations>`_.
1063
+ dst_path: The local filesystem path to utilize for downloading the model artifact.
1064
+ This directory must already exist if provided. If unspecified, a local output
1065
+ path will be created.
1066
+ return_type: A return type modifier for the stored ``transformers`` object.
1067
+ If set as "components", the return type will be a dictionary of the saved
1068
+ individual components of either the ``Pipeline`` or the pre-trained model.
1069
+ The components for NLP-focused models will typically consist of a
1070
+ return representation as shown below with a text-classification example:
1071
+
1072
+ .. code-block:: python
1073
+
1074
+ {"model": BertForSequenceClassification, "tokenizer": BertTokenizerFast}
1075
+
1076
+ Vision models will return an ``ImageProcessor`` instance of the appropriate
1077
+ type, while multi-modal models will return both a ``FeatureExtractor`` and
1078
+ a ``Tokenizer`` along with the model.
1079
+ Returning "components" can be useful for certain model types that do not
1080
+ have the desired pipeline return types for certain use cases.
1081
+ If set as "pipeline", the model, along with any and all required
1082
+ ``Tokenizer``, ``FeatureExtractor``, ``Processor``, or ``ImageProcessor``
1083
+ objects will be returned within a ``Pipeline`` object of the appropriate
1084
+ type defined by the ``task`` set by the model instance type. To override
1085
+ this behavior, supply a valid ``task`` argument during model logging or
1086
+ saving. Default is "pipeline".
1087
+ device: The device on which to load the model. Default is None. Use 0 to
1088
+ load to the default GPU.
1089
+ kwargs: Optional configuration options for loading of a ``transformers`` object.
1090
+ For information on parameters and their usage, see
1091
+ `transformers documentation <https://huggingface.co/docs/transformers/index>`_.
1092
+
1093
+ Returns:
1094
+ A ``transformers`` model instance or a dictionary of components
1095
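+
+ Example (an illustrative sketch; the run ID in ``model_uri`` is assumed):
+
+ .. code-block:: python
+
+ import mlflow
+
+ # Load the fully assembled inference pipeline
+ pipe = mlflow.transformers.load_model("runs:/<run_id>/model")
+
+ # Or load the individual components (model, tokenizer, etc.) as a dictionary
+ components = mlflow.transformers.load_model(
+ "runs:/<run_id>/model", return_type="components"
+ )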
+ """
1096
+
1097
+ if return_type not in _SUPPORTED_RETURN_TYPES:
1098
+ raise MlflowException(
1099
+ f"The specified return_type mode '{return_type}' is unsupported. "
1100
+ "Please select one of: 'pipeline' or 'components'.",
1101
+ error_code=INVALID_PARAMETER_VALUE,
1102
+ )
1103
+
1104
+ model_uri = str(model_uri)
1105
+
1106
+ local_model_path = _download_artifact_from_uri(artifact_uri=model_uri, output_path=dst_path)
1107
+
1108
+ flavor_config = _get_flavor_configuration_from_uri(model_uri, FLAVOR_NAME, _logger)
1109
+
1110
+ if return_type == "pipeline" and FlavorKey.PROCESSOR_TYPE in flavor_config:
1111
+ raise MlflowException(
1112
+ "This model has been saved with a processor. Processor objects are "
1113
+ "not compatible with Pipelines. Please load this model by specifying "
1114
+ "the 'return_type'='components'.",
1115
+ error_code=BAD_REQUEST,
1116
+ )
1117
+
1118
+ _add_code_from_conf_to_system_path(local_model_path, flavor_config)
1119
+
1120
+ return _load_model(local_model_path, flavor_config, return_type, device, **kwargs)
1121
+
1122
+
1123
+ def persist_pretrained_model(model_uri: str) -> None:
1124
+ """
1125
+ Persist Transformers pretrained model weights to the artifacts directory of the specified
1126
+ model_uri. This API is primarily used for updating an MLflow Model that was logged or saved
1127
+ with save_pretrained=False. Such models cannot be registered to the Databricks Workspace
1128
+ Model Registry, due to the full pretrained model weights being absent in the artifacts.
1129
+ Transformers models saved in this mode store only the reference to the HuggingFace Hub
1130
+ repository. This API will download the model weights from the HuggingFace Hub repository
1131
+ and save them in the artifacts of the given model_uri so that the model can be registered
1132
+ to Databricks Workspace Model Registry.
1133
+
1134
+ Args:
1135
+ model_uri: The URI of the existing MLflow Model of the Transformers flavor.
1136
+ It must be logged/saved with save_pretrained=False.
1137
+
1138
+ Examples:
1139
+
1140
+ .. code-block:: python
1141
+
1142
+ import mlflow
1143
+
1144
+ # Saving a model with save_pretrained=False
1145
+ with mlflow.start_run() as run:
1146
+ model = pipeline("question-answering", "csarron/mobilebert-uncased-squad-v2")
1147
+ mlflow.transformers.log_model(
1148
+ transformers_model=model, name="pipeline", save_pretrained=False
1149
+ )
1150
+
1151
+ # The model cannot be registered to the Model Registry as it is
1152
+ try:
1153
+ mlflow.register_model(f"runs:/{run.info.run_id}/pipeline", "qa_pipeline")
1154
+ except MlflowException as e:
1155
+ print(e.message)
1156
+
1157
+ # Use this API to persist the pretrained model weights
1158
+ mlflow.transformers.persist_pretrained_model(f"runs:/{run.info.run_id}/pipeline")
1159
+
1160
+ # Now the model can be registered to the Model Registry
1161
+ mlflow.register_model(f"runs:/{run.info.run_id}/pipeline", "qa_pipeline")
1162
+ """
1163
+ # Check if the model weight already exists in the model artifact before downloading
1164
+ root_uri, artifact_path = _get_root_uri_and_artifact_path(model_uri)
1165
+ artifact_repo = get_artifact_repository(root_uri)
1166
+
1167
+ file_names = [os.path.basename(f.path) for f in artifact_repo.list_artifacts(artifact_path)]
1168
+ if MLMODEL_FILE_NAME in file_names and _MODEL_BINARY_FILE_NAME in file_names:
1169
+ _logger.info(
1170
+ "The full pretrained model weight already exists in the artifact directory of the "
1171
+ f"specified model_uri: {model_uri}. No action is needed."
1172
+ )
1173
+ return
1174
+
1175
+ with TempDir() as tmp_dir:
1176
+ local_model_path = artifact_repo.download_artifacts(artifact_path, dst_path=tmp_dir.path())
1177
+ pipeline = load_model(local_model_path, return_type="pipeline")
1178
+
1179
+ # Update MLModel flavor config
1180
+ mlmodel_path = os.path.join(local_model_path, MLMODEL_FILE_NAME)
1181
+ model_conf = Model.load(mlmodel_path)
1182
+ updated_flavor_conf = update_flavor_conf_to_persist_pretrained_model(
1183
+ model_conf.flavors[FLAVOR_NAME]
1184
+ )
1185
+ model_conf.add_flavor(FLAVOR_NAME, **updated_flavor_conf)
1186
+ model_conf.save(mlmodel_path)
1187
+
1188
+ # Save pretrained weights
1189
+ save_pipeline_pretrained_weights(
1190
+ pathlib.Path(local_model_path), pipeline, updated_flavor_conf
1191
+ )
1192
+
1193
+ # Upload updated local artifacts to MLflow
1194
+ for dir_to_upload in (_MODEL_BINARY_FILE_NAME, _COMPONENTS_BINARY_DIR_NAME):
1195
+ local_dir = os.path.join(local_model_path, dir_to_upload)
1196
+ if not os.path.isdir(local_dir):
1197
+ continue
1198
+
1199
+ try:
1200
+ artifact_repo.log_artifacts(local_dir, os.path.join(artifact_path, dir_to_upload))
1201
+ except Exception as e:
1202
+ # NB: the log_artifacts method doesn't support rollback, so a partial upload may persist.
1203
+ raise MlflowException(
1204
+ f"Failed to upload {local_dir} to the existing model_uri due to {e}."
1205
+ "Some other files may have been uploaded."
1206
+ ) from e
1207
+
1208
+ # Upload MLModel file
1209
+ artifact_repo.log_artifact(mlmodel_path, artifact_path)
1210
+
1211
+ _logger.info(f"The pretrained model has been successfully persisted in {model_uri}.")
1212
+
1213
+
1214
+ def _is_model_distributed_in_memory(transformers_model):
1215
+ """Check if the model is distributed across multiple devices in memory."""
1216
+
1217
+ # Check if the model attribute exists. If not, accelerate was not used and the model can
1218
+ # be safely saved
1219
+ if not hasattr(transformers_model, "hf_device_map"):
1220
+ return False
1221
+ # If the device map has more than one unique value entry, then the weights are not within
1222
+ # a contiguous memory system (VRAM, SYS, or DISK) and thus cannot be safely saved.
1223
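+ # Illustrative (assumed) example: hf_device_map == {"encoder": 0, "decoder": "cpu"}
+ # spans two placements and is unsafe to save, whereas {"": 0} is contiguous and safe.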
+ return len(set(transformers_model.hf_device_map.values())) > 1
1224
+
1225
+
1226
+ # This function attempts to determine if a GPU is available for the PyTorch and TensorFlow libraries
1227
+ def is_gpu_available():
1228
+ # try pytorch and if it fails, try tf
1229
+ is_gpu = None
1230
+ try:
1231
+ import torch
1232
+
1233
+ is_gpu = torch.cuda.is_available()
1234
+ except ImportError:
1235
+ pass
1236
+ if is_gpu is None:
1237
+ try:
1238
+ import tensorflow as tf
1239
+
1240
+ is_gpu = tf.test.is_gpu_available()
1241
+ except ImportError:
1242
+ pass
1243
+ if is_gpu is None:
1244
+ is_gpu = False
1245
+ return is_gpu
1246
+
1247
+
1248
+ def _load_model(path: str, flavor_config, return_type: str, device=None, **kwargs):
1249
+ """
1250
+ Loads components from a locally serialized ``Pipeline`` object.
1251
+ """
1252
+ import transformers
1253
+
1254
+ conf = {
1255
+ "task": flavor_config[FlavorKey.TASK],
1256
+ }
1257
+ if framework := flavor_config.get(FlavorKey.FRAMEWORK):
1258
+ conf["framework"] = framework
1259
+
1260
+ # Note that we don't set the device in the conf yet because device is
1261
+ # incompatible with device_map.
1262
+ accelerate_model_conf = {}
1263
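+ # NB: device_map and device are mutually exclusive pipeline arguments. Illustrative
+ # (assumed) behavior: with MLFLOW_HUGGINGFACE_USE_DEVICE_MAP=true the pipeline is
+ # built with a device_map strategy (e.g. "auto") and `device` must remain unset;
+ # otherwise an explicit integer device ID such as device=0 (first GPU) may be used.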
+ if MLFLOW_HUGGINGFACE_USE_DEVICE_MAP.get():
1264
+ device_map_strategy = MLFLOW_HUGGINGFACE_DEVICE_MAP_STRATEGY.get()
1265
+ conf["device_map"] = device_map_strategy
1266
+ accelerate_model_conf["device_map"] = device_map_strategy
1267
+ # Cannot use device with device_map
1268
+ if device is not None:
1269
+ raise MlflowException.invalid_parameter_value(
1270
+ "The environment variable MLFLOW_HUGGINGFACE_USE_DEVICE_MAP is set to True, but "
1271
+ f"the `device` argument is provided with value {device}. The device_map and "
1272
+ "`device` argument cannot be used together. Set MLFLOW_HUGGINGFACE_USE_DEVICE_MAP "
1273
+ "to False to specify a particular device ID, or pass None for the `device` "
1274
+ "argument to use device_map."
1275
+ )
1276
+ device = None
1277
+ elif device is None:
1278
+ if device_value := MLFLOW_DEFAULT_PREDICTION_DEVICE.get():
1279
+ try:
1280
+ device = int(device_value)
1281
+ except ValueError:
1282
+ _logger.warning(
1283
+ f"Invalid value for {MLFLOW_DEFAULT_PREDICTION_DEVICE}: {device_value}. "
1284
+ f"{MLFLOW_DEFAULT_PREDICTION_DEVICE} value must be an integer. "
1285
+ f"Setting to: {_TRANSFORMERS_DEFAULT_CPU_DEVICE_ID}."
1286
+ )
1287
+ device = _TRANSFORMERS_DEFAULT_CPU_DEVICE_ID
1288
+ elif is_gpu_available():
1289
+ device = _TRANSFORMERS_DEFAULT_GPU_DEVICE_ID
1290
+
1291
+ if device is not None:
1292
+ conf["device"] = device
1293
+ accelerate_model_conf["device"] = device
1294
+
1295
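+ # NB: the torch dtype may be persisted as a string in the flavor config, e.g.
+ # "torch.float16" (illustrative), and is deserialized back to a torch.dtype below.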
+ if dtype_val := kwargs.get(_TORCH_DTYPE_KEY) or flavor_config.get(FlavorKey.TORCH_DTYPE):
1296
+ if isinstance(dtype_val, str):
1297
+ dtype_val = _deserialize_torch_dtype(dtype_val)
1298
+ conf[_TORCH_DTYPE_KEY] = dtype_val
1299
+ flavor_config[_TORCH_DTYPE_KEY] = dtype_val
1300
+ accelerate_model_conf[_TORCH_DTYPE_KEY] = dtype_val
1301
+
1302
+ accelerate_model_conf["low_cpu_mem_usage"] = MLFLOW_HUGGINGFACE_USE_LOW_CPU_MEM_USAGE.get()
1303
+
1304
+ # Load model and components either from local or from HuggingFace Hub. We check for the
1305
+ # presence of the model revision (a commit hash of the hub repository) that is only present
1306
+ # in models logged with `save_pretrained=False`.
1307
+ if FlavorKey.MODEL_REVISION not in flavor_config:
1308
+ model_and_components = load_model_and_components_from_local(
1309
+ path=pathlib.Path(path),
1310
+ flavor_conf=flavor_config,
1311
+ accelerate_conf=accelerate_model_conf,
1312
+ device=device,
1313
+ )
1314
+ else:
1315
+ model_and_components = load_model_and_components_from_huggingface_hub(
1316
+ flavor_conf=flavor_config, accelerate_conf=accelerate_model_conf, device=device
1317
+ )
1318
+
1319
+ # Load and apply PEFT adaptor if saved
1320
+ if peft_adapter_dir := flavor_config.get(FlavorKey.PEFT, None):
1321
+ model_and_components[FlavorKey.MODEL] = get_model_with_peft_adapter(
1322
+ base_model=model_and_components[FlavorKey.MODEL],
1323
+ peft_adapter_path=os.path.join(path, peft_adapter_dir),
1324
+ )
1325
+
1326
+ conf = {**conf, **model_and_components}
1327
+
1328
+ if return_type == "pipeline":
1329
+ conf.update(**kwargs)
1330
+ with suppress_logs("transformers.pipelines.base", filter_regex=_PEFT_PIPELINE_ERROR_MSG):
1331
+ return transformers.pipeline(**conf)
1332
+ elif return_type == "components":
1333
+ return conf
1334
+
1335
+
1336
+ def _fetch_model_card(model_name):
1337
+ """
1338
+ Attempts to retrieve the model card for the specified model architecture iff the
1339
+ `huggingface_hub` library is installed. If a card cannot be found in the registry or
1340
+ the library is not installed, returns None.
1341
+ """
1342
+ try:
1343
+ import huggingface_hub as hub
1344
+ except ImportError:
1345
+ _logger.warning(
1346
+ "Unable to store ModelCard data with the saved artifact. In order to "
1347
+ "preserve this information, please install the huggingface_hub package "
1348
+ "by running 'pip install huggingingface_hub>0.10.0'"
1349
+ )
1350
+ return
1351
+
1352
+ if hasattr(hub, "ModelCard"):
1353
+ try:
1354
+ return hub.ModelCard.load(model_name)
1355
+ except Exception as e:
1356
+ _logger.warning(f"The model card could not be retrieved from the hub due to {e}")
1357
+ else:
1358
+ _logger.warning(
1359
+ "The version of huggingface_hub that is installed does not provide "
1360
+ f"ModelCard functionality. You have version {hub.__version__} installed. "
1361
+ "Update huggingface_hub to >= '0.10.0' to retrieve the ModelCard data."
1362
+ )
1363
+
1364
+
1365
+ def _write_card_data(card_data, path):
1366
+ """
1367
+ Writes the card data, if specified or available, to the provided path in two separate files
1368
+ """
1369
+ if card_data:
1370
+ try:
1371
+ path.joinpath(_CARD_TEXT_FILE_NAME).write_text(card_data.text, encoding="utf-8")
1372
+ except UnicodeError as e:
1373
+ _logger.warning(f"Unable to save the model card text due to: {e}")
1374
+
1375
+ with path.joinpath(_CARD_DATA_FILE_NAME).open("w") as file:
1376
+ yaml.safe_dump(
1377
+ card_data.data.to_dict(), stream=file, default_flow_style=False, encoding="utf-8"
1378
+ )
1379
+
1380
+
1381
+ def _extract_license_file_from_repository(model_name):
1382
+ """Returns the top-level file inventory of `RepoFile` objects from the huggingface hub"""
1383
+ try:
1384
+ import huggingface_hub as hub
1385
+ except ImportError:
1386
+ _logger.debug(
1387
+ f"Unable to list repository contents for the model repo {model_name}. In order "
1388
+ "to enable repository listing functionality, please install the huggingface_hub "
1389
+ "package by running `pip install huggingface_hub>0.10.0"
1390
+ )
1391
+ return
1392
+ try:
1393
+ files = hub.list_repo_files(model_name)
1394
+ return next(file for file in files if _LICENSE_FILE_PATTERN.search(file))
1395
+ except Exception as e:
1396
+ _logger.debug(
1397
+ f"Failed to retrieve repository file listing data for {model_name} due to {e}"
1398
+ )
1399
+
1400
+
1401
+ def _write_license_information(model_name, card_data, path):
1402
+ """Writes the license file or instructions to retrieve license information."""
1403
+
1404
+ fallback = (
1405
+ f"A license file could not be found for the '{model_name}' repository. \n"
1406
+ "To ensure that you are in compliance with the license requirements for this "
1407
+ f"model, please visit the model repository here: https://huggingface.co/{model_name}"
1408
+ )
1409
+
1410
+ if license_file := _extract_license_file_from_repository(model_name):
1411
+ try:
1412
+ import huggingface_hub as hub
1413
+
1414
+ license_location = hub.hf_hub_download(repo_id=model_name, filename=license_file)
1415
+ except Exception as e:
1416
+ _logger.warning(f"Failed to download the license file due to: {e}")
1417
+ else:
1418
+ local_license_path = pathlib.Path(license_location)
1419
+ target_path = path.joinpath(local_license_path.name)
1420
+ try:
1421
+ shutil.copy(local_license_path, target_path)
1422
+ return
1423
+ except Exception as e:
1424
+ _logger.warning(f"The license file could not be copied due to: {e}")
1425
+
1426
+ # Fallback or card data license info
1427
+ if card_data and card_data.data.license != "other":
1428
+ fallback = f"{fallback}\nThe declared license type is: '{card_data.data.license}'"
1429
+ else:
1430
+ _logger.warning(
1431
+ "Unable to find license information for this model. Please verify "
1432
+ "permissible usage for the model you are storing prior to use."
1433
+ )
1434
+ path.joinpath(_LICENSE_FILE_NAME).write_text(fallback, encoding="utf-8")
1435
+
1436
+
1437
+ def _get_supported_pretrained_model_types():
1438
+ """
1439
+ Users might not have all the necessary libraries installed to import every supported model type, so only the types whose imports succeed are returned.
1440
+ """
1441
+
1442
+ supported_model_types = ()
1443
+
1444
+ try:
1445
+ from transformers import FlaxPreTrainedModel
1446
+
1447
+ supported_model_types += (FlaxPreTrainedModel,)
1448
+ except Exception:
1449
+ pass
1450
+
1451
+ try:
1452
+ from transformers import PreTrainedModel
1453
+
1454
+ supported_model_types += (PreTrainedModel,)
1455
+ except Exception:
1456
+ pass
1457
+
1458
+ try:
1459
+ from transformers import TFPreTrainedModel
1460
+
1461
+ supported_model_types += (TFPreTrainedModel,)
1462
+ except Exception:
1463
+ pass
1464
+
1465
+ return supported_model_types
1466
+
1467
+
1468
+ def _build_pipeline_from_model_input(model_dict: dict[str, Any], task: Optional[str]) -> Pipeline:
1469
+ """
1470
+ Utility for generating a pipeline from component parts. If required components are not
1471
+ specified, use the transformers library pipeline component validation to force raising an
1472
+ exception. The underlying Exception thrown in transformers is verbose enough for diagnosis.
1473
+ """
1474
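+ # Illustrative (assumed) call:
+ #   _build_pipeline_from_model_input(
+ #       {"model": model, "tokenizer": tokenizer}, task="question-answering")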
+
1475
+ from transformers import pipeline
1476
+
1477
+ model = model_dict[FlavorKey.MODEL]
1478
+
1479
+ if not (isinstance(model, _get_supported_pretrained_model_types()) or is_peft_model(model)):
1480
+ raise MlflowException(
1481
+ "The supplied model type is unsupported. The model must be one of: "
1482
+ "PreTrainedModel, TFPreTrainedModel, FlaxPreTrainedModel, or PeftModel",
1483
+ error_code=INVALID_PARAMETER_VALUE,
1484
+ )
1485
+
1486
+ if task is None or task.startswith(_LLM_INFERENCE_TASK_PREFIX):
1487
+ default_task = _get_default_task_for_llm_inference_task(task)
1488
+ task = _get_task_for_model(model.name_or_path, default_task=default_task)
1489
+
1490
+ try:
1491
+ with suppress_logs("transformers.pipelines.base", filter_regex=_PEFT_PIPELINE_ERROR_MSG):
1492
+ return pipeline(task=task, **model_dict)
1493
+ except Exception as e:
1494
+ raise MlflowException(
1495
+ "The provided model configuration cannot be created as a Pipeline. "
1496
+ "Please verify that all required and compatible components are "
1497
+ "specified with the correct keys.",
1498
+ error_code=INVALID_PARAMETER_VALUE,
1499
+ ) from e
1500
+
1501
+
1502
+ def _get_task_for_model(model_name_or_path: str, default_task=None) -> str:
1503
+ """
1504
+ Get the Transformers pipeline task type from the model instance.
1505
+
1506
+ NB: The get_task() function only works for remote models available in the Hugging
1507
+ Face hub, so the default task should be supplied when using a custom local model.
1508
+ """
1509
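+ # Illustrative (assumed) example: get_task("gpt2") returns "text-generation", which is
+ # in get_supported_tasks() and is used directly; a hub lookup failure for a purely
+ # local model raises RuntimeError and falls through to `default_task` below.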
+ from transformers.pipelines import get_supported_tasks, get_task
1510
+
1511
+ try:
1512
+ model_task = get_task(model_name_or_path)
1513
+ if model_task in get_supported_tasks():
1514
+ return model_task
1515
+ elif default_task is not None:
1516
+ _logger.warning(
1517
+ f"The task '{model_task}' inferred from the model is not"
1518
+ "supported by the transformers pipeline. MLflow will "
1519
+ f"construct the pipeline with the fallback task {default_task} "
1520
+ "inferred from the specified 'llm/v1/xxx' task."
1521
+ )
1522
+ return default_task
1523
+ else:
1524
+ raise MlflowException(
1525
+ f"Cannot construct transformers pipeline because the task '{model_task}' "
1526
+ "inferred from the model is not supported by the transformers pipeline. "
1527
+ "Please construct the pipeline instance manually and pass it to the "
1528
+ "`log_model` or `save_model` function."
1529
+ )
1530
+
1531
+ except RuntimeError as e:
1532
+ if default_task:
1533
+ return default_task
1534
+ raise MlflowException(
1535
+ "The task could not be inferred from the model. If you are saving a custom "
1536
+ "local model that is not available in the Hugging Face hub, please provide "
1537
+ "the `task` argument to the `log_model` or `save_model` function.",
1538
+ error_code=INVALID_PARAMETER_VALUE,
1539
+ ) from e
1540
+
1541
+
1542
+ def _validate_llm_inference_task_type(llm_inference_task: str, pipeline_task: str) -> None:
1543
+ """
1544
+ Validates that an ``inference_task`` type is supported by ``transformers`` pipeline type.
1545
+ """
1546
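+ # Illustrative (assumed) mapping: a "text-generation" pipeline typically supports the
+ # "llm/v1/chat" and "llm/v1/completions" tasks, while "feature-extraction" supports
+ # "llm/v1/embeddings".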
+ supported_llm_inference_tasks = _SUPPORTED_LLM_INFERENCE_TASK_TYPES_BY_PIPELINE_TASK.get(
1547
+ pipeline_task, []
1548
+ )
1549
+
1550
+ if llm_inference_task not in supported_llm_inference_tasks:
1551
+ raise MlflowException(
1552
+ f"The task provided is invalid. '{llm_inference_task}' is not a supported task for "
1553
+ f"the {pipeline_task} pipeline. Must be one of {supported_llm_inference_tasks}",
1554
+ error_code=INVALID_PARAMETER_VALUE,
1555
+ )
1556
+
1557
+
1558
+ def _get_engine_type(model):
1559
+ """
1560
+ Determines the underlying execution engine for the model based on the 3 currently supported
1561
+ deep learning framework backends: ``tensorflow``, ``torch``, or ``flax``.
1562
+ """
1563
+ from transformers import FlaxPreTrainedModel, PreTrainedModel, TFPreTrainedModel
1564
+ from transformers.utils import is_torch_available
1565
+
1566
+ if is_peft_model(model):
1567
+ model = get_peft_base_model(model)
1568
+
1569
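+ # Illustrative (assumed) example: for a BertForSequenceClassification instance, the MRO
+ # includes PreTrainedModel, so the walk below returns "torch".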
+ for cls in model.__class__.__mro__:
1570
+ if issubclass(cls, TFPreTrainedModel):
1571
+ return "tensorflow"
1572
+ elif issubclass(cls, PreTrainedModel):
1573
+ return "torch"
1574
+ elif issubclass(cls, FlaxPreTrainedModel):
1575
+ return "flax"
1576
+
1577
+ # As a fallback, we check current environment to determine the engine type
1578
+ return "torch" if is_torch_available() else "tensorflow"
1579
+
1580
+
1581
+ def _should_add_pyfunc_to_model(pipeline) -> bool:
1582
+ """
1583
+ Discriminator for determining whether a particular task type and model instance from within
1584
+ a ``Pipeline`` is currently supported for the pyfunc flavor.
1585
+
1586
+ Image and Video pipelines can still be logged and used, but are not available for
1587
+ loading as pyfunc.
1588
+ Similarly, esoteric model types (Graph Models, Timeseries Models, and Reinforcement Learning
1589
+ Models) are not permitted for loading as pyfunc due to the complex input types that, in
1590
+ order to support, will require significant modifications (breaking changes) to the pyfunc
1591
+ contract.
1592
+ """
1593
+ import transformers
1594
+
1595
+ exclusion_model_types = {
1596
+ "GraphormerPreTrainedModel",
1597
+ "InformerPreTrainedModel",
1598
+ "TimeSeriesTransformerPreTrainedModel",
1599
+ "DecisionTransformerPreTrainedModel",
1600
+ }
1601
+
1602
+ # NB: When pyfunc functionality is added for these pipeline types over time, remove the
1603
+ # entries from the following list.
1604
+ exclusion_pipeline_types = [
1605
+ "DocumentQuestionAnsweringPipeline",
1606
+ "ImageToTextPipeline",
1607
+ "VisualQuestionAnsweringPipeline",
1608
+ "ImageSegmentationPipeline",
1609
+ "DepthEstimationPipeline",
1610
+ "ObjectDetectionPipeline",
1611
+ "VideoClassificationPipeline",
1612
+ "ZeroShotImageClassificationPipeline",
1613
+ "ZeroShotObjectDetectionPipeline",
1614
+ "ZeroShotAudioClassificationPipeline",
1615
+ ]
1616
+
1617
+ for model_type in exclusion_model_types:
1618
+ if hasattr(transformers, model_type):
1619
+ if isinstance(pipeline.model, getattr(transformers, model_type)):
1620
+ return False
1621
+ if type(pipeline).__name__ in exclusion_pipeline_types:
1622
+ return False
1623
+ return True
1624
+
1625
+
1626
+ def _get_model_config(local_path, pyfunc_config):
1627
+ """
1628
+ Load the model configuration if it was provided for use in the `_TransformersWrapper` pyfunc
1629
+ Model wrapper.
1630
+ """
1631
+ config_path = local_path.joinpath("inference_config.txt")
1632
+ if config_path.exists():
1633
+ _logger.warning(
1634
+ "Inference config stored in file ``inference_config.txt`` is deprecated. New logged "
1635
+ "models will store the model configuration in the ``pyfunc`` flavor configuration."
1636
+ )
1637
+ return json.loads(config_path.read_text())
1638
+ else:
1639
+ return pyfunc_config or {}
1640
+
1641
+
1642
+ def _load_pyfunc(path, model_config: Optional[dict[str, Any]] = None):
1643
+ """
1644
+ Loads the model as a pyfunc model.
1645
+ """
1646
+ local_path = pathlib.Path(path)
1647
+ flavor_configuration = _get_flavor_configuration(local_path, FLAVOR_NAME)
1648
+ model_config = _get_model_config(local_path.joinpath(_COMPONENTS_BINARY_DIR_NAME), model_config)
1649
+ prompt_template = _get_prompt_template(local_path)
1650
+
1651
+ return _TransformersWrapper(
1652
+ _load_model(str(local_path), flavor_configuration, "pipeline"),
1653
+ flavor_configuration,
1654
+ model_config,
1655
+ prompt_template,
1656
+ )
1657
+
1658
+
1659
+ def _is_conversational_pipeline(pipeline):
1660
+ """
1661
+ Checks if the pipeline is a ConversationalPipeline.
1662
+ """
1663
+ if cp := _try_import_conversational_pipeline():
1664
+ return isinstance(pipeline, cp)
1665
+ return False
1666
+
1667
+
1668
+ def _try_import_conversational_pipeline():
1669
+ """
1670
+ Try importing ConversationalPipeline, because for transformers versions > 4.41.2
1671
+ it is removed from the transformers package.
1672
+ """
1673
+ try:
1674
+ from transformers import ConversationalPipeline
1675
+
1676
+ return ConversationalPipeline
1677
+ except ImportError:
1678
+ return
1679
+
1680
+
1681
+ def generate_signature_output(pipeline, data, model_config=None, params=None, flavor_config=None):
1682
+ """
1683
+ Utility for generating the response output for the purposes of extracting an output signature
1684
+ for model saving and logging. This function simulates loading of a saved model or pipeline
1685
+ as a ``pyfunc`` model without having to incur a write to disk.
1686
+
1687
+ Args:
1688
+ pipeline: A ``transformers`` pipeline object. Note that component-level or model-level
1689
+ inputs are not permitted for extracting an output example.
1690
+ data: An example input that is compatible with the given pipeline
1691
+ model_config: Any additional model configuration, provided as kwargs, to inform
1692
+ the format of the output type from a pipeline inference call.
1693
+ params: A dictionary of additional parameters to pass to the pipeline for inference.
1694
+ flavor_config: The flavor configuration for the model.
1695
+
1696
+ Returns:
1697
+ The output from the ``pyfunc`` pipeline wrapper's ``predict`` method
1698
+ """
1699
+ import transformers
1700
+
1701
+ from mlflow.transformers import signature
1702
+
1703
+ if not isinstance(pipeline, transformers.Pipeline):
1704
+ raise MlflowException(
1705
+ f"The pipeline type submitted is not a valid transformers Pipeline. "
1706
+ f"The type {type(pipeline).__name__} is not supported.",
1707
+ error_code=INVALID_PARAMETER_VALUE,
1708
+ )
1709
+
1710
+ return signature.generate_signature_output(pipeline, data, model_config, params)
1711
+
1712
+
1713
+ class _TransformersWrapper:
1714
+ def __init__(self, pipeline, flavor_config=None, model_config=None, prompt_template=None):
1715
+ self.pipeline = pipeline
1716
+ self.flavor_config = flavor_config
1717
+ # The predict method updates the model_config several times. This should be done over a
1718
+ # deep copy of the original model_config that was specified by the user, otherwise the
1719
+ # prediction won't be idempotent. Hence we create an immutable dictionary of the original
1720
+ # model config here and enforce creating a deep copy at every predict call.
1721
+ self.model_config = MappingProxyType(model_config or {})
1722
+
1723
+ self.prompt_template = prompt_template
1724
+ self._conversation = None
1725
+ # NB: Current special-case custom pipeline types that have not been added to
1726
+ # the native-supported transformers package but require custom parsing:
1727
+ # InstructionTextGenerationPipeline [Dolly] https://huggingface.co/databricks/dolly-v2-12b
1728
+ # (and all variants)
1729
+ self._supported_custom_generator_types = {"InstructionTextGenerationPipeline"}
1730
+ self.llm_inference_task = (
1731
+ self.flavor_config.get(_LLM_INFERENCE_TASK_KEY) if self.flavor_config else None
1732
+ )
1733
+
1734
+ def get_raw_model(self):
1735
+ """
1736
+ Returns the underlying pipeline object.
1737
+ """
1738
+ return self.pipeline
1739
+
1740
+ def _convert_pandas_to_dict(self, data):
1741
+ import transformers
1742
+
1743
+ if not isinstance(self.pipeline, transformers.ZeroShotClassificationPipeline):
1744
+ return data.to_dict(orient="records")
1745
+ else:
1746
+ # NB: The ZeroShotClassificationPipeline requires an input in the form of
1747
+ # Dict[str, Union[str, List[str]]] and will throw if an additional nested
1748
+ # List is present within the List value (which is what the duplicated values
1749
+ # within the orient="list" conversion in Pandas will do. This parser will
1750
+ # deduplicate label lists to a single list.
1751
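+ # Illustrative (assumed) example of the transformation performed below:
+ #   pd.DataFrame({"sequences": ["a", "b"], "candidate_labels": [["x", "y"], ["x", "y"]]})
+ #   is collapsed to {"sequences": ["a", "b"], "candidate_labels": ["x", "y"]}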
+ unpacked = data.to_dict(orient="list")
1752
+ parsed = {}
1753
+ for key, value in unpacked.items():
1754
+ if isinstance(value, list):
1755
+ contents = []
1756
+ for item in value:
1757
+ # Deduplication logic
1758
+ if item not in contents:
1759
+ contents.append(item)
1760
+ # Collapse nested lists to return the correct data structure for the
1761
+ # ZeroShotClassificationPipeline input structure
1762
+ parsed[key] = (
1763
+ contents
1764
+ if all(isinstance(item, str) for item in contents) and len(contents) > 1
1765
+ else contents[0]
1766
+ )
1767
+ return parsed
1768
+
1769
+ def _merge_model_config_with_params(self, model_config, params):
1770
+ if params:
1771
+ _logger.warning(
1772
+ "params provided to the `predict` method will override the inference "
1773
+ "configuration saved with the model. If the params provided are not "
1774
+ "valid for the pipeline, MlflowException will be raised."
1775
+ )
1776
+ # Override the inference configuration with any additional kwargs provided by the user.
1777
+ return {**model_config, **params}
1778
+ else:
1779
+ return model_config
1780
+
1781
+ def _validate_model_config_and_return_output(self, data, model_config, return_tensors=False):
1782
+ import transformers
1783
+
1784
+ if return_tensors:
1785
+ model_config["return_tensors"] = True
1786
+ if model_config.get("return_full_text", None) is not None:
1787
+ _logger.warning(
1788
+ "The `return_full_text` parameter is mutually exclusive with the "
1789
+ "`return_tensors` parameter set when a MLflow inference task is provided. "
1790
+ "The `return_full_text` parameter will be ignored."
1791
+ )
1792
+ # `return_full_text` is mutually exclusive with `return_tensors`
1793
+ model_config["return_full_text"] = None
1794
+
1795
+ try:
1796
+ if isinstance(data, dict):
1797
+ return self.pipeline(**data, **model_config)
1798
+ return self.pipeline(data, **model_config)
1799
+ except ValueError as e:
1800
+ if "The following `model_kwargs` are not used by the model" in str(e):
1801
+ raise MlflowException.invalid_parameter_value(
1802
+ "The params provided to the `predict` method are not valid "
1803
+ f"for pipeline {type(self.pipeline).__name__}.",
1804
+ ) from e
1805
+ if isinstance(
1806
+ self.pipeline,
1807
+ (
1808
+ transformers.AutomaticSpeechRecognitionPipeline,
1809
+ transformers.AudioClassificationPipeline,
1810
+ ),
1811
+ ) and (
1812
+ # transformers <= 4.33.3
1813
+ "Malformed soundfile" in str(e)
1814
+ # transformers > 4.33.3
1815
+ or "Soundfile is either not in the correct format or is malformed" in str(e)
1816
+ ):
1817
+ raise MlflowException.invalid_parameter_value(
1818
+ "Failed to process the input audio data. Either the audio file is "
1819
+ "corrupted or a uri was passed in without overriding the default model "
1820
+ "signature. If submitting a string uri, please ensure that the model has "
1821
+ "been saved with a signature that defines a string input type.",
1822
+ ) from e
1823
+ raise
1824
+
1825
+ def predict(self, data, params: Optional[dict[str, Any]] = None):
1826
+ """
1827
+ Args:
1828
+ data: Model input data.
1829
+ params: Additional parameters to pass to the model for inference.
1830
+
1831
+ Returns:
1832
+ Model predictions.
1833
+ """
1834
+ # NB: This `predict` method updates the model_config several times. To make the predict
1835
+ # call idempotent, we keep the original self.model_config immutable and create a deep
1836
+ # copy of it at every predict call.
1837
+ model_config = copy.deepcopy(dict(self.model_config))
1838
+ params = self._merge_model_config_with_params(model_config, params)
1839
+
1840
+ if self.llm_inference_task == _LLM_INFERENCE_TASK_CHAT:
1841
+ data, params = preprocess_llm_inference_input(data, params, self.flavor_config)
1842
+ data = [convert_messages_to_prompt(msgs, self.pipeline.tokenizer) for msgs in data]
1843
+ elif self.llm_inference_task == _LLM_INFERENCE_TASK_COMPLETIONS:
1844
+ data, params = preprocess_llm_inference_input(data, params, self.flavor_config)
1845
+ elif self.llm_inference_task == _LLM_INFERENCE_TASK_EMBEDDING:
1846
+ data, params = preprocess_llm_embedding_params(data)
1847
+
1848
+ if isinstance(data, pd.DataFrame):
1849
+ input_data = self._convert_pandas_to_dict(data)
1850
+ elif isinstance(data, (dict, str, bytes, np.ndarray)):
1851
+ input_data = data
1852
+ elif isinstance(data, list):
1853
+ if not all(isinstance(entry, (str, dict)) for entry in data):
1854
+ raise MlflowException(
1855
+ "Invalid data submission. Ensure all elements in the list are strings "
1856
+ "or dictionaries. If dictionaries are supplied, all keys in the "
1857
+ "dictionaries must be strings and values must be either str or List[str].",
1858
+ error_code=INVALID_PARAMETER_VALUE,
1859
+ )
1860
+ input_data = data
1861
+ else:
1862
+ raise MlflowException(
1863
+ "Input data must be either a pandas.DataFrame, a string, bytes, List[str], "
1864
+ "List[Dict[str, str]], List[Dict[str, Union[str, List[str]]]], "
1865
+ "or Dict[str, Union[str, List[str]]].",
1866
+ error_code=INVALID_PARAMETER_VALUE,
1867
+ )
1868
+ input_data = self._parse_raw_pipeline_input(input_data)
1869
+ # Validate resolved or input dict types
1870
+ if isinstance(input_data, dict):
1871
+ _validate_input_dictionary_contains_only_strings_and_lists_of_strings(input_data)
1872
+ elif isinstance(input_data, list) and all(isinstance(entry, dict) for entry in input_data):
1873
+ # Validate each dict inside an input List[Dict]
1874
+ all(
1875
+ _validate_input_dictionary_contains_only_strings_and_lists_of_strings(x)
1876
+ for x in input_data
1877
+ )
1878
+ return self._predict(input_data, params)
1879
+
1880
+     def _predict(self, data, model_config):
+         import transformers
+
+         # NB: the ordering of these conditional statements matters. TranslationPipeline and
+         # SummarizationPipeline both inherit from Text2TextGenerationPipeline (they are
+         # subclasses in which the return data structure from the __call__ implementation is
+         # modified), so they must be checked before their parent class.
+         if isinstance(self.pipeline, transformers.TranslationPipeline):
+             self._validate_str_or_list_str(data)
+             output_key = "translation_text"
+         elif isinstance(self.pipeline, transformers.SummarizationPipeline):
+             self._validate_str_or_list_str(data)
+             data = self._format_prompt_template(data)
+             output_key = "summary_text"
+         elif isinstance(self.pipeline, transformers.Text2TextGenerationPipeline):
+             data = self._parse_text2text_input(data)
+             data = self._format_prompt_template(data)
+             output_key = "generated_text"
+         elif isinstance(self.pipeline, transformers.TextGenerationPipeline):
+             self._validate_str_or_list_str(data)
+             data = self._format_prompt_template(data)
+             output_key = "generated_text"
+         elif isinstance(self.pipeline, transformers.QuestionAnsweringPipeline):
+             data = self._parse_question_answer_input(data)
+             output_key = "answer"
+         elif isinstance(self.pipeline, transformers.FillMaskPipeline):
+             self._validate_str_or_list_str(data)
+             data = self._format_prompt_template(data)
+             output_key = "token_str"
+         elif isinstance(self.pipeline, transformers.TextClassificationPipeline):
+             output_key = "label"
+         elif isinstance(self.pipeline, transformers.ImageClassificationPipeline):
+             data = self._convert_image_input(data)
+             output_key = "label"
+         elif isinstance(self.pipeline, transformers.ZeroShotClassificationPipeline):
+             output_key = "labels"
+             data = self._parse_json_encoded_list(data, "candidate_labels")
+         elif isinstance(self.pipeline, transformers.TableQuestionAnsweringPipeline):
+             output_key = "answer"
+             data = self._parse_json_encoded_dict_payload_to_dict(data, "table")
+         elif isinstance(self.pipeline, transformers.TokenClassificationPipeline):
+             output_key = {"entity_group", "entity"}
+         elif isinstance(self.pipeline, transformers.FeatureExtractionPipeline):
+             output_key = None
+             data = self._parse_feature_extraction_input(data)
+             data = self._format_prompt_template(data)
+         elif _is_conversational_pipeline(self.pipeline):
+             output_key = None
+             if not self._conversation:
+                 # transformers.Conversation is importable whenever a conversational
+                 # pipeline type is available
+                 self._conversation = transformers.Conversation()
+             self._conversation.add_user_input(data)
+         elif type(self.pipeline).__name__ in self._supported_custom_generator_types:
+             self._validate_str_or_list_str(data)
+             output_key = "generated_text"
+         elif isinstance(self.pipeline, transformers.AutomaticSpeechRecognitionPipeline):
+             if model_config.get("return_timestamps", None) in ["word", "char"]:
+                 output_key = None
+             else:
+                 output_key = "text"
+             data = self._convert_audio_input(data)
+         elif isinstance(self.pipeline, transformers.AudioClassificationPipeline):
+             data = self._convert_audio_input(data)
+             output_key = None
+         else:
+             raise MlflowException(
+                 f"The loaded pipeline type {type(self.pipeline).__name__} is "
+                 "not enabled for pyfunc predict functionality.",
+                 error_code=BAD_REQUEST,
+             )
+
+         # Optional input preservation for specific pipeline types. This defaults to True
+         # (include the raw formatted output), but if `include_prompt` is set to False in the
+         # `model_config` option during model saving, excess newline characters and the fed-in
+         # prompt will be trimmed out from the start of the response.
+         include_prompt = model_config.pop("include_prompt", True)
+         # Optional stripping out of `\n` for specific generator pipelines.
+         collapse_whitespace = model_config.pop("collapse_whitespace", False)
+
+         data = self._convert_cast_lists_from_np_back_to_list(data)
+
+         # Generate inference data with the pipeline object
+         if _is_conversational_pipeline(self.pipeline):
+             conversation_output = self.pipeline(self._conversation)
+             return conversation_output.generated_responses[-1]
+         else:
+             # If an inference task is defined, return tensors internally to get usage
+             # information
+             return_tensors = False
+             if self.llm_inference_task:
+                 return_tensors = True
+                 output_key = "generated_token_ids"
+
+             raw_output = self._validate_model_config_and_return_output(
+                 data, model_config=model_config, return_tensors=return_tensors
+             )
+
+         # Handle the pipeline outputs
+         if type(self.pipeline).__name__ in self._supported_custom_generator_types or isinstance(
+             self.pipeline, transformers.TextGenerationPipeline
+         ):
+             output = self._strip_input_from_response_in_instruction_pipelines(
+                 data,
+                 raw_output,
+                 output_key,
+                 self.flavor_config,
+                 include_prompt,
+                 collapse_whitespace,
+             )
+
+             if self.llm_inference_task:
+                 output = postprocess_output_for_llm_inference_task(
+                     data,
+                     output,
+                     self.pipeline,
+                     self.flavor_config,
+                     model_config,
+                     self.llm_inference_task,
+                 )
+
+         elif isinstance(self.pipeline, transformers.FeatureExtractionPipeline):
+             if self.llm_inference_task:
+                 output = [np.array(tensor[0][0]) for tensor in raw_output]
+                 output = postprocess_output_for_llm_v1_embedding_task(
+                     data, output, self.pipeline.tokenizer
+                 )
+             else:
+                 return self._parse_feature_extraction_output(raw_output)
+         elif isinstance(self.pipeline, transformers.FillMaskPipeline):
+             output = self._parse_list_of_multiple_dicts(raw_output, output_key)
+         elif isinstance(self.pipeline, transformers.ZeroShotClassificationPipeline):
+             return self._flatten_zero_shot_text_classifier_output_to_df(raw_output)
+         elif isinstance(self.pipeline, transformers.TokenClassificationPipeline):
+             output = self._parse_tokenizer_output(raw_output, output_key)
+         elif isinstance(
+             self.pipeline, transformers.AutomaticSpeechRecognitionPipeline
+         ) and model_config.get("return_timestamps", None) in ["word", "char"]:
+             output = json.dumps(raw_output)
+         elif isinstance(
+             self.pipeline,
+             (
+                 transformers.AudioClassificationPipeline,
+                 transformers.TextClassificationPipeline,
+                 transformers.ImageClassificationPipeline,
+             ),
+         ):
+             return pd.DataFrame(raw_output)
+         else:
+             output = self._parse_lists_of_dict_to_list_of_str(raw_output, output_key)
+
+         sanitized = self._sanitize_output(output, data)
+         return self._wrap_strings_as_list_if_scalar(sanitized)
+
+     def _parse_raw_pipeline_input(self, data):
+         """
+         Converts inputs to the expected types for specific Pipeline types.
+         Specific logic for individual pipeline types is called via the respective methods if
+         the input isn't a basic str or List[str] input type of Pipeline.
+         These parsers are required due to the conversion that occurs within schema validation
+         to a Pandas DataFrame encapsulation, a format which is unsupported for the
+         `transformers` library.
+         """
+         import transformers
+
+         if isinstance(self.pipeline, transformers.TableQuestionAnsweringPipeline):
+             data = self._coerce_exploded_dict_to_single_dict(data)
+             return self._parse_input_for_table_question_answering(data)
+         elif _is_conversational_pipeline(self.pipeline):
+             return self._parse_conversation_input(data)
+         elif (  # noqa: SIM114
+             isinstance(
+                 self.pipeline,
+                 (
+                     transformers.FillMaskPipeline,
+                     transformers.TextGenerationPipeline,
+                     transformers.TranslationPipeline,
+                     transformers.SummarizationPipeline,
+                     transformers.TokenClassificationPipeline,
+                 ),
+             )
+             and isinstance(data, list)
+             and all(isinstance(entry, dict) for entry in data)
+         ):
+             return [list(entry.values())[0] for entry in data]
+         # NB: For Text2TextGenerationPipeline, dictionary handling is more complex because we
+         # allow both a single string input and a dictionary input (or a list of either). Both
+         # are wrapped into a Pandas DataFrame during schema enforcement and converted back to
+         # a dictionary afterwards. The difference between the two is the columns of the
+         # DataFrame: the first case (string) will have auto-generated columns like 0, 1, ...,
+         # while the latter (dict) will keep the original keys as the columns. When converting
+         # back to a dictionary, those columns become the keys of the dictionary.
+         #
+         # E.g.
+         #  1. If the user's input is a string like model.predict("foo")
+         #     -> Raw input: "foo"
+         #     -> Pandas DataFrame has column 0, with single row "foo"
+         #     -> Derived dictionary will be {0: "foo"}
+         #  2. If the user's input is a dictionary like model.predict({"text": "foo"})
+         #     -> Raw input: {"text": "foo"}
+         #     -> Pandas DataFrame has column "text", with single row "foo"
+         #     -> Derived dictionary will be {"text": "foo"}
+         #
+         # For the first case, we want to extract the values only, similar to other pipelines.
+         # For the second case, however, we want to keep the key-value pairs as they are.
+         # In the long term, we should change the upstream handling to avoid this complexity,
+         # but here we just make it work by checking whether the key is auto-generated.
+         elif (
+             isinstance(self.pipeline, transformers.Text2TextGenerationPipeline)
+             and isinstance(data, list)
+             and all(isinstance(entry, dict) for entry in data)
+             # A Pandas DataFrame-derived dictionary will have an integer key (row index)
+             and 0 in data[0].keys()
+         ):
+             return [list(entry.values())[0] for entry in data]
+         elif isinstance(self.pipeline, transformers.TextClassificationPipeline):
+             return self._validate_text_classification_input(data)
+         else:
+             return data
+
+     @staticmethod
+     def _validate_text_classification_input(data):
+         """
+         Perform input type validation for TextClassification pipelines and casting of data
+         that is manipulated internally by the MLflow model server back to a structure that
+         can be used for pipeline inference.
+
+         To illustrate the inputs and outputs of this function, for the following inputs to
+         the pyfunc.predict() call for this pipeline type:
+
+         "text to classify"
+         ["text to classify", "other text to classify"]
+         {"text": "text to classify", "text_pair": "pair text"}
+         [{"text": "text", "text_pair": "pair"}, {"text": "t", "text_pair": "tp"}]
+
+         Pyfunc processing will convert these to the following structures:
+
+         [{0: "text to classify"}]
+         [{0: "text to classify"}, {0: "other text to classify"}]
+         [{"text": "text to classify", "text_pair": "pair text"}]
+         [{"text": "text", "text_pair": "pair"}, {"text": "t", "text_pair": "tp"}]
+
+         The purpose of this function is to convert them into the correct format for input
+         to the pipeline (wrapping as a list has no bearing on the correctness of the
+         inferred classifications):
+
+         ["text to classify"]
+         ["text to classify", "other text to classify"]
+         [{"text": "text to classify", "text_pair": "pair text"}]
+         [{"text": "text", "text_pair": "pair"}, {"text": "t", "text_pair": "tp"}]
+
+         Additionally, for dict input types (the 'text' & 'text_pair' input example), the dict
+         input will be JSON stringified within MLflow model serving. In order to convert this
+         structure back into the appropriate type, we use ast.literal_eval() to convert back
+         to a dict. We avoid using json.loads() due to pandas DataFrame conversions that invert
+         single and double quotes with escape sequences that are not consistent if the string
+         contains escaped quotes.
+         """
+
+         def _check_keys(payload):
+             """Check if a dictionary contains only allowable keys."""
+             allowable_str_keys = {"text", "text_pair"}
+             if set(payload) - allowable_str_keys and not all(
+                 isinstance(key, int) for key in payload.keys()
+             ):
+                 raise MlflowException(
+                     "Text Classification pipelines may only define dictionary inputs with keys "
+                     f"defined as {allowable_str_keys}"
+                 )
+
+         if isinstance(data, str):
+             return data
+         elif isinstance(data, dict):
+             _check_keys(data)
+             return data
+         elif isinstance(data, list):
+             if all(isinstance(item, str) for item in data):
+                 return data
+             elif all(isinstance(item, dict) for item in data):
+                 for payload in data:
+                     _check_keys(payload)
+                 if list(data[0].keys())[0] == 0:
+                     data = [item[0] for item in data]
+                 try:
+                     # NB: To support MLflow serving signature validation, the value within dict
+                     # inputs is JSON encoded. In order to support the proper data structure
+                     # input for a {"text": "a", "text_pair": "b"} payload (or a list of such a
+                     # structure), we have to convert the string-encoded dict back to a dict.
+                     # Due to how unescaped characters (such as "'") are encoded, an explicit
+                     # json.loads() cast can result in invalid input data to the pipeline.
+                     # ast.literal_eval() performs the correct conversion, as validated in
+                     # unit tests.
+                     return [ast.literal_eval(s) for s in data]
+                 except (ValueError, SyntaxError):
+                     return data
+             else:
+                 raise MlflowException(
+                     "An unsupported data type has been passed for Text Classification "
+                     "inference. Only str, list of str, dict, and list of dict are supported."
+                 )
+         else:
+             raise MlflowException(
+                 "An unsupported data type has been passed for Text Classification inference. "
+                 "Only str, list of str, dict, and list of dict are supported."
+             )
+
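The preference for ast.literal_eval() over json.loads() described in the docstring above can be verified standalone; this sketch uses a repr-style stringified payload of the kind produced during serving:

    import ast
    import json

    payload = "{'text': 'text to classify', 'text_pair': 'pair text'}"

    ast.literal_eval(payload)  # {'text': 'text to classify', 'text_pair': 'pair text'}

    try:
        json.loads(payload)  # single-quoted keys/values are not valid JSON
    except json.JSONDecodeError:
        pass
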
+     def _parse_conversation_input(self, data) -> str:
+         if isinstance(data, str):
+             return data
+         elif isinstance(data, list) and all(isinstance(elem, dict) for elem in data):
+             return next(iter(data[0].values()))
+         elif isinstance(data, dict):
+             # The conversation pipeline can only accept a single string at a time
+             return next(iter(data.values()))
+
+     def _parse_input_for_table_question_answering(self, data):
+         if "table" not in data:
+             raise MlflowException(
+                 "The input dictionary must have the 'table' key.",
+                 error_code=INVALID_PARAMETER_VALUE,
+             )
+         elif isinstance(data["table"], dict):
+             data["table"] = json.dumps(data["table"])
+             return data
+         else:
+             return data
+
+     def _coerce_exploded_dict_to_single_dict(
+         self, data: list[dict[str, Any]]
+     ) -> dict[str, list[Any]]:
+         """
+         Parses the result of Pandas DataFrame.to_dict(orient="records") from pyfunc
+         signature validation to coerce the output to the required format for a
+         Pipeline that requires a single dict with list elements such as
+         TableQuestionAnsweringPipeline.
+
+         Example input:
+
+         [
+             {"answer": "We should order more pizzas to meet the demand."},
+             {"answer": "The venue size should be updated to handle the number of guests."},
+         ]
+
+         Output:
+
+         {
+             "answer": [
+                 "We should order more pizzas to meet the demand.",
+                 "The venue size should be updated to handle the number of guests.",
+             ]
+         }
+         """
+         if isinstance(data, list) and all(isinstance(item, dict) for item in data):
+             collection = data.copy()
+             parsed = collection[0]
+             for coll in collection:
+                 for key, value in coll.items():
+                     if key not in parsed:
+                         raise MlflowException(
+                             "Unable to parse the input. The keys within each "
+                             "dictionary of the parsed input are not consistent "
+                             "among the dictionaries.",
+                             error_code=INVALID_PARAMETER_VALUE,
+                         )
+                     if value != parsed[key]:
+                         value_type = type(parsed[key])
+                         if value_type == str:
+                             parsed[key] = [parsed[key], value]
+                         elif value_type == list:
+                             if all(len(entry) == 1 for entry in value):
+                                 # This conversion is required solely for model serving.
+                                 # In the parsing logic that occurs internally, strings that
+                                 # contain single quotes `'` are cast to a List[char] instead
+                                 # of a str type. Converting such single-character entries
+                                 # back through str() reconstructs the original input string
+                                 # before it is appended to the List[str].
+                                 parsed[key].append(str(value))
+                             else:
+                                 # NB: list.append() mutates in place and returns None;
+                                 # assigning its return value back to parsed[key] would
+                                 # silently overwrite the entire list with None.
+                                 parsed[key].append(value)
+                         else:
+                             parsed[key] = value
+             return parsed
+         else:
+             return data
+
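A simplified standalone sketch of the same records-to-single-dict coercion. Unlike the method above, which preserves scalars until a second distinct value appears, this version always accumulates values into lists:

    import pandas as pd

    records = pd.DataFrame(
        {"answer": ["Order more pizzas.", "Update the venue size."]}
    ).to_dict(orient="records")
    # [{'answer': 'Order more pizzas.'}, {'answer': 'Update the venue size.'}]

    coerced = {}
    for record in records:
        for key, value in record.items():
            coerced.setdefault(key, []).append(value)
    # {'answer': ['Order more pizzas.', 'Update the venue size.']}
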
+     def _flatten_zero_shot_text_classifier_output_to_df(self, data):
+         """
+         Converts the output of sequences, labels, and scores to a Pandas DataFrame output.
+
+         Example input:
+
+         [{'sequence': 'My dog loves to eat spaghetti',
+           'labels': ['happy', 'sad'],
+           'scores': [0.9896970987319946, 0.010302911512553692]},
+          {'sequence': 'My dog hates going to the vet',
+           'labels': ['sad', 'happy'],
+           'scores': [0.957074761390686, 0.042925238609313965]}]
+
+         Output:
+
+         pd.DataFrame in a fully normalized (flattened) format with each sequence, label, and
+         score having a row entry. For example, here is the DataFrame output:
+
+                                  sequence labels    scores
+         0   My dog loves to eat spaghetti  happy  0.989697
+         1   My dog loves to eat spaghetti    sad  0.010303
+         2   My dog hates going to the vet    sad  0.957075
+         3   My dog hates going to the vet  happy  0.042925
+         """
+         if isinstance(data, list) and not all(isinstance(item, dict) for item in data):
+             raise MlflowException(
+                 "Encountered an unknown return type from the pipeline type "
+                 f"{type(self.pipeline).__name__}. Expecting a List[Dict]",
+                 error_code=BAD_REQUEST,
+             )
+         if isinstance(data, dict):
+             data = [data]
+
+         flattened_data = []
+         for entry in data:
+             for label, score in zip(entry["labels"], entry["scores"]):
+                 flattened_data.append(
+                     {"sequence": entry["sequence"], "labels": label, "scores": score}
+                 )
+         return pd.DataFrame(flattened_data)
+
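The flattening above reduces to a single comprehension; a sketch using one of the docstring's entries:

    import pandas as pd

    raw = [
        {
            "sequence": "My dog loves to eat spaghetti",
            "labels": ["happy", "sad"],
            "scores": [0.989697, 0.010303],
        }
    ]
    rows = [
        {"sequence": entry["sequence"], "labels": label, "scores": score}
        for entry in raw
        for label, score in zip(entry["labels"], entry["scores"])
    ]
    pd.DataFrame(rows)  # one row per (sequence, label, score) triple
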
+     def _strip_input_from_response_in_instruction_pipelines(
+         self,
+         input_data,
+         output,
+         output_key,
+         flavor_config,
+         include_prompt=True,
+         collapse_whitespace=False,
+     ):
+         """
+         Parse the output from instruction pipelines to conform with other text generator
+         pipeline types and remove newline characters and other confusing outputs.
+         """
+
+         def extract_response_data(data_out):
+             if all(isinstance(x, dict) for x in data_out):
+                 return [elem[output_key] for elem in data_out][0]
+             elif all(isinstance(x, list) for x in data_out):
+                 return [elem[output_key] for coll in data_out for elem in coll]
+             else:
+                 raise MlflowException(
+                     "Unable to parse the pipeline output. Expected List[Dict[str,str]] or "
+                     f"List[List[Dict[str,str]]] but got {type(data_out)} instead."
+                 )
+
+         output = extract_response_data(output)
+
+         def trim_input(data_in, data_out):
+             # NB: the '\n\n' pattern is exclusive to specific InstructionalTextGenerationPipeline
+             # types that have been loaded as a plain TextGenerator. The structure of these
+             # pipelines will precisely repeat the input question, immediately followed by two
+             # newline characters and then the start of the response to the prompt. We only
+             # want to left-trim these types of pipeline output values if the user has requested
+             # removal of the input prompt from the returned str or List[str] by applying the
+             # optional model_config entry of `{"include_prompt": False}`.
+             # By default, the prompt is included in the response.
+             # Stripping out additional newlines (\n) is another optional flag that can be set
+             # for these generator pipelines. It is off by default (False).
+             if (
+                 not include_prompt
+                 and flavor_config[FlavorKey.INSTANCE_TYPE] in self._supported_custom_generator_types
+                 and data_out.startswith(data_in + "\n\n")
+             ):
+                 # If the user has indicated to not preserve the prompt input in the response,
+                 # split the response output and trim the input prompt from the response.
+                 data_out = data_out[len(data_in) :].lstrip()
+                 if data_out.startswith("A:"):
+                     data_out = data_out[2:].lstrip()
+
+             # If the user has indicated to remove newlines and extra spaces from the generated
+             # text, replace them with a single space.
+             if collapse_whitespace:
+                 data_out = re.sub(r"\s+", " ", data_out).strip()
+             return data_out
+
+         if isinstance(input_data, list) and isinstance(output, list):
+             return [trim_input(data_in, data_out) for data_in, data_out in zip(input_data, output)]
+         elif isinstance(input_data, str) and isinstance(output, str):
+             return trim_input(input_data, output)
+         else:
+             raise MlflowException(
+                 "Unknown data structure after parsing output. Expected str or List[str]. "
+                 f"Got {type(output)} instead."
+             )
+
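The trimming behavior of `trim_input` can be reproduced in isolation; the prompt and response strings here are invented for illustration:

    import re

    prompt = "What is MLflow?"
    generated = prompt + "\n\nA: An open source platform for the ML lifecycle."

    # include_prompt=False: left-trim the echoed prompt and a leading "A:" marker
    trimmed = generated
    if trimmed.startswith(prompt + "\n\n"):
        trimmed = trimmed[len(prompt):].lstrip()
        if trimmed.startswith("A:"):
            trimmed = trimmed[2:].lstrip()

    # collapse_whitespace=True: squeeze any remaining whitespace runs
    re.sub(r"\s+", " ", trimmed).strip()
    # 'An open source platform for the ML lifecycle.'
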
+     def _sanitize_output(self, output, input_data):
+         # Some pipelines and their underlying models leave leading or trailing whitespace.
+         # This method removes that whitespace.
+         import transformers
+
+         if (
+             not isinstance(self.pipeline, transformers.TokenClassificationPipeline)
+             and isinstance(input_data, str)
+             and isinstance(output, list)
+         ):
+             # Retrieve the first output for return types that are List[str] of only a single
+             # element.
+             output = output[0]
+         if isinstance(output, str):
+             return output.strip()
+         elif isinstance(output, list):
+             if all(isinstance(elem, str) for elem in output):
+                 cleaned = [text.strip() for text in output]
+                 # If the list has only a single string, return it as a string.
+                 return cleaned if len(cleaned) > 1 else cleaned[0]
+             else:
+                 return [self._sanitize_output(coll, input_data) for coll in output]
+         elif isinstance(output, dict) and all(
+             isinstance(key, str) and isinstance(value, str) for key, value in output.items()
+         ):
+             return {k: v.strip() for k, v in output.items()}
+         else:
+             return output
+
+     @staticmethod
+     def _wrap_strings_as_list_if_scalar(output_data):
+         """
+         Wraps single string outputs in a list to support batch processing logic in serving.
+         Scalar values are not supported for processing in batch logic as they cannot be
+         coerced to DataFrame representations.
+         """
+         if isinstance(output_data, str):
+             return [output_data]
+         else:
+             return output_data
+
+     def _parse_lists_of_dict_to_list_of_str(self, output_data, target_dict_key) -> list[str]:
+         """
+         Parses the output results from select Pipeline types to extract specific values from
+         a target key.
+
+         Examples (with "a" as the `target_dict_key`):
+
+         Input: [{"a": "valid", "b": "invalid"}, {"a": "another valid", "c": "invalid"}]
+         Output: ["valid", "another valid"]
+
+         Input: [{"a": "valid", "b": [{"a": "another valid"}, {"b": "invalid"}]},
+                 {"a": "valid 2", "b": [{"a": "another valid 2"}, {"c": "invalid"}]}]
+         Output: ["valid", "another valid", "valid 2", "another valid 2"]
+         """
+         if isinstance(output_data, list):
+             output_coll = []
+             for output in output_data:
+                 if isinstance(output, dict):
+                     for key, value in output.items():
+                         if key == target_dict_key:
+                             output_coll.append(output[target_dict_key])
+                         elif isinstance(value, list) and all(
+                             isinstance(elem, dict) for elem in value
+                         ):
+                             output_coll.extend(
+                                 self._parse_lists_of_dict_to_list_of_str(value, target_dict_key)
+                             )
+                 elif isinstance(output, list):
+                     output_coll.extend(
+                         self._parse_lists_of_dict_to_list_of_str(output, target_dict_key)
+                     )
+             return output_coll
+         elif target_dict_key:
+             return output_data[target_dict_key]
+         else:
+             return output_data
+
+     @staticmethod
+     def _parse_feature_extraction_input(input_data):
+         if isinstance(input_data, list) and isinstance(input_data[0], dict):
+             return [list(data.values())[0] for data in input_data]
+         else:
+             return input_data
+
+     @staticmethod
+     def _parse_feature_extraction_output(output_data):
+         """
+         Parse the return type from a FeatureExtractionPipeline output. The mixed input types
+         are present depending on how the pyfunc is instantiated: for model serving usage, the
+         type returned from MLServer will be a numpy.ndarray, while for a manually executed
+         pyfunc (i.e., for udf usage), the return will be a collection of nested lists.
+
+         Examples:
+
+         Input: [[[0.11, 0.98, 0.76]]] or np.array([0.11, 0.98, 0.76])
+         Output: np.array([0.11, 0.98, 0.76])
+
+         Input: [[[[0.1, 0.2], [0.3, 0.4]]]] or
+                np.array([np.array([0.1, 0.2]), np.array([0.3, 0.4])])
+         Output: np.array([np.array([0.1, 0.2]), np.array([0.3, 0.4])])
+         """
+         if isinstance(output_data, np.ndarray):
+             return output_data
+         else:
+             return np.array(output_data[0][0])
+
+     def _parse_tokenizer_output(self, output_data, target_set):
+         """
+         Parses the tokenizer pipeline output.
+
+         Examples:
+
+         Input: [{"entity": "PRON", "score": 0.95}, {"entity": "NOUN", "score": 0.998}]
+         Output: "PRON,NOUN"
+
+         Input: [[{"entity": "PRON", "score": 0.95}, {"entity": "NOUN", "score": 0.998}],
+                 [{"entity": "PRON", "score": 0.95}, {"entity": "NOUN", "score": 0.998}]]
+         Output: ["PRON,NOUN", "PRON,NOUN"]
+         """
+         # NB: We're collapsing the results here to a comma-separated string for each inference
+         # input string. This is to simplify having to otherwise make extensive changes to
+         # ColSpec in order to support schema enforcement of List[List[str]].
+         if isinstance(output_data[0], list):
+             return [self._parse_tokenizer_output(coll, target_set) for coll in output_data]
+         else:
+             # NB: Since there are no attributes accessible from the pipeline object that
+             # indicate which key names the return structure uses within its dictionaries,
+             # determine which one is present in the output to extract the correct entries.
+             target = target_set.intersection(output_data[0].keys()).pop()
+             return ",".join([coll[target] for coll in output_data])
+
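A standalone sketch of the comma-joined collapse for a token classification result; the entity values are invented for illustration:

    entities = [
        {"entity": "PRON", "score": 0.95},
        {"entity": "NOUN", "score": 0.998},
    ]
    target = {"entity_group", "entity"}.intersection(entities[0].keys()).pop()
    ",".join(entry[target] for entry in entities)  # 'PRON,NOUN'
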
+     @staticmethod
+     def _parse_list_of_multiple_dicts(output_data, target_dict_key):
+         """
+         Returns the first value of the `target_dict_key` that matches in the first dictionary
+         in a list of dictionaries.
+         """
+
+         def fetch_target_key_value(data, key):
+             if isinstance(data[0], dict):
+                 return data[0][key]
+             return [item[0][key] for item in data]
+
+         if isinstance(output_data[0], list):
+             return [
+                 fetch_target_key_value(collection, target_dict_key) for collection in output_data
+             ]
+         else:
+             return [output_data[0][target_dict_key]]
+
+     def _parse_question_answer_input(self, data):
+         """
+         Parses the single string input representation for a question answering pipeline into
+         the required dict format for a `question-answering` pipeline.
+         """
+         if isinstance(data, list):
+             return [self._parse_question_answer_input(entry) for entry in data]
+         elif isinstance(data, dict):
+             expected_keys = {"question", "context"}
+             if expected_keys.intersection(set(data.keys())) != expected_keys:
+                 raise MlflowException(
+                     f"Invalid keys were submitted. Keys must include {expected_keys}"
+                 )
+             return data
+         else:
+             raise MlflowException(
+                 "An invalid type has been supplied. Must be either List[Dict[str, str]] or "
+                 f"Dict[str, str]. {type(data)} is not supported.",
+                 error_code=INVALID_PARAMETER_VALUE,
+             )
+
+     def _parse_text2text_input(self, data):
+         """
+         Parses the mixed input types that can be submitted into a text2text Pipeline.
+         Valid examples:
+
+         Input:
+             {"context": "abc", "answer": "def"}
+         Output:
+             "context: abc answer: def"
+         Input:
+             [{"context": "abc", "answer": "def"}, {"context": "ghi", "answer": "jkl"}]
+         Output:
+             ["context: abc answer: def", "context: ghi answer: jkl"]
+         Input:
+             "abc"
+         Output:
+             "abc"
+         Input:
+             ["abc", "def"]
+         Output:
+             ["abc", "def"]
+         """
+         if isinstance(data, dict) and all(isinstance(value, str) for value in data.values()):
+             if all(isinstance(key, str) for key in data) and "inputs" not in data:
+                 # NB: Text2Text Pipelines require submission of text in a pseudo-string-based
+                 # dict formatting.
+                 # As an example, for the input of:
+                 # data = {"context": "The sky is blue", "answer": "blue"}
+                 # this method will return the Pipeline-required format of:
+                 # "context: The sky is blue answer: blue"
+                 return " ".join(f"{key}: {value}" for key, value in data.items())
+             else:
+                 return list(data.values())
+         elif isinstance(data, list) and all(isinstance(value, dict) for value in data):
+             return [self._parse_text2text_input(entry) for entry in data]
+         elif isinstance(data, str) or (
+             isinstance(data, list) and all(isinstance(value, str) for value in data)
+         ):
+             return data
+         else:
+             raise MlflowException(
+                 f"An invalid type has been supplied: {_truncate_and_ellipsize(data, 100)} "
+                 f"(type: {type(data).__name__}). Please supply a Dict[str, str], str, List[str], "
+                 "or a List[Dict[str, str]] for a Text2Text Pipeline.",
+                 error_code=INVALID_PARAMETER_VALUE,
+             )
+
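The pseudo-string dict formatting for text2text inputs, reproduced standalone with the docstring's example:

    data = {"context": "The sky is blue", "answer": "blue"}
    " ".join(f"{key}: {value}" for key, value in data.items())
    # 'context: The sky is blue answer: blue'
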
+     def _parse_json_encoded_list(self, data, key_to_unpack):
+         """
+         Parses the complex input types for pipelines such as ZeroShotClassification in which
+         the required input type is Dict[str, Union[str, List[str]]], wherein the list
+         provided is encoded as JSON. This method unpacks that string to the required
+         elements.
+         """
+         if isinstance(data, list):
+             return [self._parse_json_encoded_list(entry, key_to_unpack) for entry in data]
+         elif isinstance(data, dict):
+             if key_to_unpack not in data:
+                 raise MlflowException(
+                     "Invalid key in inference payload. The expected inference data key "
+                     f"is: {key_to_unpack}",
+                     error_code=INVALID_PARAMETER_VALUE,
+                 )
+             if isinstance(data[key_to_unpack], str):
+                 try:
+                     return {
+                         k: (json.loads(v) if k == key_to_unpack else v) for k, v in data.items()
+                     }
+                 except json.JSONDecodeError:
+                     return data
+             elif isinstance(data[key_to_unpack], list):
+                 return data
+
+     @staticmethod
+     def _parse_json_encoded_dict_payload_to_dict(data, key_to_unpack):
+         """
+         Parses complex dict input types that have been json encoded. Pipelines like
+         TableQuestionAnswering require such input types.
+         """
+         if isinstance(data, list):
+             return [
+                 {
+                     key: (
+                         json.loads(value)
+                         if key == key_to_unpack and isinstance(value, str)
+                         else value
+                     )
+                     for key, value in entry.items()
+                 }
+                 for entry in data
+             ]
+         elif isinstance(data, dict):
+             # This handles serving use cases, as the DataFrame encapsulation converts
+             # collections within rows to np.array types. In order to process this data
+             # through the transformers.Pipeline API, we need to cast these arrays back to
+             # scalars and decode the json-encoded `table` entry (a pandas DF) into a dict
+             # that the TableQuestionAnsweringPipeline can accept and cast to a Pandas
+             # DataFrame.
+             #
+             # An example of the casting that occurs for this case when input to model serving
+             # is the conversion of a user input of:
+             #   '{"inputs": {"query": "What is the longest distance?",
+             #                "table": {"Distance": ["1000", "10", "1"]}}}'
+             # which is converted to:
+             #   [{'query': array('What is the longest distance?', dtype='<U29'),
+             #     'table': array('{\'Distance\': [\'1000\', \'10\', \'1\']}', dtype='<U204')}]
+             # and is an invalid input to the pipeline.
+             # This method converts the input to:
+             #   {'query': 'What is the longest distance?',
+             #    'table': {'Distance': ['1000', '10', '1']}}
+             # which is a valid input to the TableQuestionAnsweringPipeline.
+             output = {}
+             for key, value in data.items():
+                 if key == key_to_unpack:
+                     if isinstance(value, np.ndarray):
+                         output[key] = ast.literal_eval(value.item())
+                     else:
+                         output[key] = ast.literal_eval(value)
+                 else:
+                     if isinstance(value, np.ndarray):
+                         # This cast to np.ndarray occurs when more than one question is asked.
+                         output[key] = value.item()
+                     else:
+                         # Otherwise, the entry does not need casting from an np.ndarray type
+                         # as it is already a scalar string.
+                         output[key] = value
+             return output
+         else:
+             return {
+                 key: (
+                     json.loads(value) if key == key_to_unpack and isinstance(value, str) else value
+                 )
+                 for key, value in data.items()
+             }
+
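The np.ndarray unwrapping and `table` decoding above, sketched standalone for a record shaped like the serving example in the comments:

    import ast

    import numpy as np

    record = {
        "query": np.array("What is the longest distance?"),
        "table": np.array("{'Distance': ['1000', '10', '1']}"),
    }
    decoded = {
        key: ast.literal_eval(value.item()) if key == "table" else value.item()
        for key, value in record.items()
    }
    # {'query': 'What is the longest distance?',
    #  'table': {'Distance': ['1000', '10', '1']}}
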
+     @staticmethod
+     def _validate_str_or_list_str(data):
+         if not isinstance(data, (str, list)):
+             raise MlflowException(
+                 f"The input data is of an incorrect type. {type(data)} is invalid. "
+                 "Must be either string or List[str]",
+                 error_code=INVALID_PARAMETER_VALUE,
+             )
+         elif isinstance(data, list) and not all(isinstance(entry, str) for entry in data):
+             raise MlflowException(
+                 "If supplying a list, all values must be of string type.",
+                 error_code=INVALID_PARAMETER_VALUE,
+             )
+
+     @staticmethod
+     def _convert_cast_lists_from_np_back_to_list(data):
+         """
+         This handles the casting of dicts within lists from Pandas DF conversion within model
+         serving back into the required Dict[str, List[str]] if this type matching occurs.
+         Otherwise, it's a noop.
+         """
+         if not isinstance(data, list):
+             # NB: applying a short-circuit return here to not incur runtime overhead with
+             # type validation if the input is not a list
+             return data
+         elif not all(isinstance(value, dict) for value in data):
+             return data
+         else:
+             parsed_data = []
+             for entry in data:
+                 if all(isinstance(value, np.ndarray) for value in entry.values()):
+                     parsed_data.append({key: value.tolist() for key, value in entry.items()})
+                 else:
+                     parsed_data.append(entry)
+             return parsed_data
+
+     @staticmethod
+     def is_base64_image(image):
+         """Check whether the input image is base64 encoded."""
+         try:
+             b64_decoded_image = base64.b64decode(image)
+             return (
+                 base64.b64encode(b64_decoded_image).decode("utf-8") == image
+                 or base64.encodebytes(b64_decoded_image).decode("utf-8") == image
+             )
+         except binascii.Error:
+             return False
+
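The round-trip check above distinguishes encoded image payloads from plain path strings; a minimal sketch of the same idea:

    import base64
    import binascii

    def looks_base64(value: str) -> bool:
        try:
            decoded = base64.b64decode(value)
            return base64.b64encode(decoded).decode("utf-8") == value
        except binascii.Error:
            return False

    looks_base64(base64.b64encode(b"\x89PNG...").decode("utf-8"))  # True
    looks_base64("relative/path/to/image.jpg")                     # False
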
+     def _convert_image_input(self, input_data):
+         """
+         Conversion utility for decoding the base64 encoded bytes data of a raw image file when
+         parsed through model serving, if applicable. Direct usage of the pyfunc implementation
+         outside of model serving will treat this utility as a noop.
+
+         For reference, the expected encoding for input to Model Serving will be:
+
+         import requests
+         import base64
+
+         response = requests.get("https://www.my.images/a/sound/file.jpg")
+         encoded_image = base64.b64encode(response.content).decode("utf-8")
+
+         inference_data = json.dumps({"inputs": [encoded_image]})
+
+         or
+
+         inference_df = pd.DataFrame(
+             pd.Series([encoded_image], name="image_file")
+         )
+         split_dict = {"dataframe_split": inference_df.to_dict(orient="split")}
+         split_json = json.dumps(split_dict)
+
+         or
+
+         records_dict = {"dataframe_records": inference_df.to_dict(orient="records")}
+         records_json = json.dumps(records_dict)
+
+         This utility will convert this JSON encoded, base64 encoded text back into bytes for
+         input into the Image pipelines for inference.
+         """
+
+         def process_input_element(input_element):
+             input_value = next(iter(input_element.values()))
+             if isinstance(input_value, str) and not self.is_base64_image(input_value):
+                 self._validate_str_input_uri_or_file(input_value)
+             return input_value
+
+         if isinstance(input_data, list) and all(
+             isinstance(element, dict) for element in input_data
+         ):
+             # NB: A list comprehension is used here for readability and to eliminate empty
+             # collection declarations.
+             return [process_input_element(element) for element in input_data]
+         elif isinstance(input_data, str) and not self.is_base64_image(input_data):
+             self._validate_str_input_uri_or_file(input_data)
+
+         return input_data
+
+     def _convert_audio_input(
+         self, data: Union[AudioInput, list[dict[int, list[AudioInput]]]]
+     ) -> Union[AudioInput, list[AudioInput]]:
+         """
+         Convert the input data into the format that the Transformers pipeline expects.
+
+         Args:
+             data: The input data to be converted. This can be one of the following:
+                 1. A single audio input (bytes, numpy array, or a path or URI to an audio
+                    file)
+                 2. A list of dictionaries, derived from a Pandas DataFrame with
+                    `orient="records"`. This is the outcome of the pyfunc signature validation
+                    for the audio input. E.g. [{0: <audio data>}, {1: <audio data>}]
+
+         Returns:
+             A single audio input or a list of audio inputs.
+         """
+         if isinstance(data, list):
+             data = [list(element.values())[0] for element in data]
+             decoded = [self._decode_audio(audio) for audio in data]
+             # Signature validation converts a single audio input into a list (via a Pandas
+             # Series). We have to unwrap it back so it is not confused with batch processing.
+             return decoded if len(decoded) > 1 else decoded[0]
+         else:
+             return self._decode_audio(data)
+
+     def _decode_audio(self, audio: AudioInput) -> AudioInput:
+         """
+         Decode the audio data if it is base64 encoded bytes, otherwise no-op.
+         """
+         if isinstance(audio, str):
+             # Input is a URI to the audio file to be processed.
+             self._validate_str_input_uri_or_file(audio)
+             return audio
+         elif isinstance(audio, np.ndarray):
+             # Input is a numpy array that contains a floating point time series of the audio.
+             return audio
+         elif isinstance(audio, bytes):
+             # Input is a bytes object. In model serving, the input audio data is b64encoded.
+             # It is typically decoded before reaching here, but if the inference payload
+             # contains raw bytes in the key 'inputs', the upstream code will not decode the
+             # bytes. Therefore, we need to decode the bytes here. For other cases like
+             # 'dataframe_records' or 'dataframe_split', the bytes should be already decoded.
+             if self.is_base64_audio(audio):
+                 return base64.b64decode(audio)
+             else:
+                 return audio
+         else:
+             raise MlflowException(
+                 "Invalid audio data. Must be either bytes, str, or np.ndarray.",
+                 error_code=INVALID_PARAMETER_VALUE,
+             )
+
+     @staticmethod
+     def is_base64_audio(audio: bytes) -> bool:
+         """Check whether the input audio is base64 encoded."""
+         try:
+             return base64.b64encode(base64.b64decode(audio)) == audio
+         except binascii.Error:
+             return False
+
+     @staticmethod
+     def _validate_str_input_uri_or_file(input_str):
+         """
+         Validation of blob references to either audio or image files. If a string is input to
+         the ``predict`` method, validate the string contents by checking for a valid uri or
+         filesystem reference instead of surfacing the cryptic stack trace that is otherwise
+         raised for an invalid uri input.
+         """
+
+         def is_uri(s):
+             try:
+                 result = urlparse(s)
+                 return all([result.scheme, result.netloc])
+             except ValueError:
+                 return False
+
+         valid_uri = os.path.isfile(input_str) or is_uri(input_str)
+
+         if not valid_uri:
+             if len(input_str) <= 20:
+                 data_str = f"Received: {input_str}"
+             else:
+                 data_str = f"Received (truncated): {input_str[:20]}..."
+             raise MlflowException(
+                 "An invalid string input was provided. String inputs to audio or image "
+                 f"files must be either a file location or a uri. {data_str}",
+                 error_code=BAD_REQUEST,
+             )
+
+     def _format_prompt_template(self, input_data):
+         """
+         Wraps the input data in the specified prompt template. If no template is
+         specified, or if the pipeline is an unsupported type, or if the input type
+         is not a string or list of strings, then the input data is returned unchanged.
+         """
+         if not self.prompt_template:
+             return input_data
+
+         if self.pipeline.task not in _SUPPORTED_PROMPT_TEMPLATING_TASK_TYPES:
+             raise MlflowException(
+                 "_format_prompt_template called on an unexpected pipeline type. "
+                 f"Expected one of: {_SUPPORTED_PROMPT_TEMPLATING_TASK_TYPES}. "
+                 f"Received: {self.pipeline.task}"
+             )
+
+         if isinstance(input_data, str):
+             return self.prompt_template.format(prompt=input_data)
+         elif isinstance(input_data, list):
+             # if every item is a string, then apply formatting to every item
+             if all(isinstance(data, str) for data in input_data):
+                 return [self.prompt_template.format(prompt=data) for data in input_data]
+
+         # throw for unsupported types
+         raise MlflowException.invalid_parameter_value(
+             "Prompt templating is only supported for data of type str or List[str]. "
+             f"Got {type(input_data)} instead."
+         )
+
+
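Prompt template application, as performed above, reduces to str.format with a single `prompt` placeholder; the template text here is invented for illustration:

    template = "Answer the question in a friendly tone. Q: {prompt} A:"

    template.format(prompt="What is MLflow?")
    # 'Answer the question in a friendly tone. Q: What is MLflow? A:'

    [template.format(prompt=text) for text in ["Q one?", "Q two?"]]
    # one templated string per input
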
+ @autologging_integration(FLAVOR_NAME)
+ def autolog(
+     log_input_examples=False,
+     log_model_signatures=False,
+     log_models=False,
+     log_datasets=False,
+     disable=False,
+     exclusive=False,
+     disable_for_unsupported_versions=False,
+     silent=False,
+     extra_tags=None,
+ ):
+     """
+     This autologging integration is solely used for disabling spurious autologging of
+     irrelevant sub-models that are created during the training and evaluation of
+     transformers-based models. Autologging functionality is not implemented fully for the
+     transformers flavor.
+     """
+     # A list of other flavors whose base autologging config would be automatically logged due
+     # to training a model that would otherwise create a run and be logged internally within
+     # the transformers-supported trainer calls.
+     DISABLED_ANCILLARY_FLAVOR_AUTOLOGGING = ["sklearn", "tensorflow", "pytorch"]
+
+     def train(original, *args, **kwargs):
+         with disable_discrete_autologging(DISABLED_ANCILLARY_FLAVOR_AUTOLOGGING):
+             return original(*args, **kwargs)
+
+     with contextlib.suppress(ImportError):
+         import setfit
+
+         safe_patch(
+             FLAVOR_NAME,
+             (
+                 setfit.SetFitTrainer
+                 if Version(setfit.__version__) < Version("1.0.0")
+                 else setfit.Trainer
+             ),
+             "train",
+             functools.partial(train),
+             manage_run=False,
+         )
+
+     with contextlib.suppress(ImportError):
+         import transformers
+
+         classes = [transformers.Trainer, transformers.Seq2SeqTrainer]
+         methods = ["train"]
+         for clazz in classes:
+             for method in methods:
+                 safe_patch(FLAVOR_NAME, clazz, method, functools.partial(train), manage_run=False)
+
+
+ def _get_prompt_template(model_path):
+     if not os.path.exists(model_path):
+         raise MlflowException(
+             f'Could not find an "{MLMODEL_FILE_NAME}" configuration file at "{model_path}"',
+             RESOURCE_DOES_NOT_EXIST,
+         )
+
+     model_conf = Model.load(model_path)
+     if model_conf.metadata:
+         return model_conf.metadata.get(FlavorKey.PROMPT_TEMPLATE)
+
+     return None
+
+
+ def _validate_prompt_template(prompt_template):
+     if prompt_template is None:
+         return
+
+     if not isinstance(prompt_template, str):
+         raise MlflowException(
+             f"Argument `prompt_template` must be a string, received {type(prompt_template)}",
+             INVALID_PARAMETER_VALUE,
+         )
+
+     format_args = [
+         tup[1] for tup in string.Formatter().parse(prompt_template) if tup[1] is not None
+     ]
+
+     # expect there to only be one format arg, and for that arg to be "prompt"
+     if format_args != ["prompt"]:
+         raise MlflowException.invalid_parameter_value(
+             "Argument `prompt_template` must be a string with a single format arg, 'prompt'. "
+             "For example: 'Answer the following question in a friendly tone. Q: {prompt}. A:'\n"
+             f"Received {prompt_template}. "
+         )
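
The `string.Formatter` check above accepts exactly one named field, `prompt`; a sketch of what the validation sees for a few templates:

    import string

    def format_args(template: str) -> list[str]:
        return [tup[1] for tup in string.Formatter().parse(template) if tup[1] is not None]

    format_args("Q: {prompt}. A:")           # ['prompt'] -> accepted
    format_args("{prompt} given {context}")  # ['prompt', 'context'] -> rejected
    format_args("no placeholders at all")    # [] -> rejected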