genesis-flow 1.0.0__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (645)
  1. genesis_flow-1.0.0.dist-info/METADATA +822 -0
  2. genesis_flow-1.0.0.dist-info/RECORD +645 -0
  3. genesis_flow-1.0.0.dist-info/WHEEL +5 -0
  4. genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
  5. genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
  6. genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
  7. mlflow/__init__.py +367 -0
  8. mlflow/__main__.py +3 -0
  9. mlflow/ag2/__init__.py +56 -0
  10. mlflow/ag2/ag2_logger.py +294 -0
  11. mlflow/anthropic/__init__.py +40 -0
  12. mlflow/anthropic/autolog.py +129 -0
  13. mlflow/anthropic/chat.py +144 -0
  14. mlflow/artifacts/__init__.py +268 -0
  15. mlflow/autogen/__init__.py +144 -0
  16. mlflow/autogen/chat.py +142 -0
  17. mlflow/azure/__init__.py +26 -0
  18. mlflow/azure/auth_handler.py +257 -0
  19. mlflow/azure/client.py +319 -0
  20. mlflow/azure/config.py +120 -0
  21. mlflow/azure/connection_factory.py +340 -0
  22. mlflow/azure/exceptions.py +27 -0
  23. mlflow/azure/stores.py +327 -0
  24. mlflow/azure/utils.py +183 -0
  25. mlflow/bedrock/__init__.py +45 -0
  26. mlflow/bedrock/_autolog.py +202 -0
  27. mlflow/bedrock/chat.py +122 -0
  28. mlflow/bedrock/stream.py +160 -0
  29. mlflow/bedrock/utils.py +43 -0
  30. mlflow/cli.py +707 -0
  31. mlflow/client.py +12 -0
  32. mlflow/config/__init__.py +56 -0
  33. mlflow/crewai/__init__.py +79 -0
  34. mlflow/crewai/autolog.py +253 -0
  35. mlflow/crewai/chat.py +29 -0
  36. mlflow/data/__init__.py +75 -0
  37. mlflow/data/artifact_dataset_sources.py +170 -0
  38. mlflow/data/code_dataset_source.py +40 -0
  39. mlflow/data/dataset.py +123 -0
  40. mlflow/data/dataset_registry.py +168 -0
  41. mlflow/data/dataset_source.py +110 -0
  42. mlflow/data/dataset_source_registry.py +219 -0
  43. mlflow/data/delta_dataset_source.py +167 -0
  44. mlflow/data/digest_utils.py +108 -0
  45. mlflow/data/evaluation_dataset.py +562 -0
  46. mlflow/data/filesystem_dataset_source.py +81 -0
  47. mlflow/data/http_dataset_source.py +145 -0
  48. mlflow/data/huggingface_dataset.py +258 -0
  49. mlflow/data/huggingface_dataset_source.py +118 -0
  50. mlflow/data/meta_dataset.py +104 -0
  51. mlflow/data/numpy_dataset.py +223 -0
  52. mlflow/data/pandas_dataset.py +231 -0
  53. mlflow/data/polars_dataset.py +352 -0
  54. mlflow/data/pyfunc_dataset_mixin.py +31 -0
  55. mlflow/data/schema.py +76 -0
  56. mlflow/data/sources.py +1 -0
  57. mlflow/data/spark_dataset.py +406 -0
  58. mlflow/data/spark_dataset_source.py +74 -0
  59. mlflow/data/spark_delta_utils.py +118 -0
  60. mlflow/data/tensorflow_dataset.py +350 -0
  61. mlflow/data/uc_volume_dataset_source.py +81 -0
  62. mlflow/db.py +27 -0
  63. mlflow/dspy/__init__.py +17 -0
  64. mlflow/dspy/autolog.py +197 -0
  65. mlflow/dspy/callback.py +398 -0
  66. mlflow/dspy/constant.py +1 -0
  67. mlflow/dspy/load.py +93 -0
  68. mlflow/dspy/save.py +393 -0
  69. mlflow/dspy/util.py +109 -0
  70. mlflow/dspy/wrapper.py +226 -0
  71. mlflow/entities/__init__.py +104 -0
  72. mlflow/entities/_mlflow_object.py +52 -0
  73. mlflow/entities/assessment.py +545 -0
  74. mlflow/entities/assessment_error.py +80 -0
  75. mlflow/entities/assessment_source.py +141 -0
  76. mlflow/entities/dataset.py +92 -0
  77. mlflow/entities/dataset_input.py +51 -0
  78. mlflow/entities/dataset_summary.py +62 -0
  79. mlflow/entities/document.py +48 -0
  80. mlflow/entities/experiment.py +109 -0
  81. mlflow/entities/experiment_tag.py +35 -0
  82. mlflow/entities/file_info.py +45 -0
  83. mlflow/entities/input_tag.py +35 -0
  84. mlflow/entities/lifecycle_stage.py +35 -0
  85. mlflow/entities/logged_model.py +228 -0
  86. mlflow/entities/logged_model_input.py +26 -0
  87. mlflow/entities/logged_model_output.py +32 -0
  88. mlflow/entities/logged_model_parameter.py +46 -0
  89. mlflow/entities/logged_model_status.py +74 -0
  90. mlflow/entities/logged_model_tag.py +33 -0
  91. mlflow/entities/metric.py +200 -0
  92. mlflow/entities/model_registry/__init__.py +29 -0
  93. mlflow/entities/model_registry/_model_registry_entity.py +13 -0
  94. mlflow/entities/model_registry/model_version.py +243 -0
  95. mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
  96. mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
  97. mlflow/entities/model_registry/model_version_search.py +25 -0
  98. mlflow/entities/model_registry/model_version_stages.py +25 -0
  99. mlflow/entities/model_registry/model_version_status.py +35 -0
  100. mlflow/entities/model_registry/model_version_tag.py +35 -0
  101. mlflow/entities/model_registry/prompt.py +73 -0
  102. mlflow/entities/model_registry/prompt_version.py +244 -0
  103. mlflow/entities/model_registry/registered_model.py +175 -0
  104. mlflow/entities/model_registry/registered_model_alias.py +35 -0
  105. mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
  106. mlflow/entities/model_registry/registered_model_search.py +25 -0
  107. mlflow/entities/model_registry/registered_model_tag.py +35 -0
  108. mlflow/entities/multipart_upload.py +74 -0
  109. mlflow/entities/param.py +49 -0
  110. mlflow/entities/run.py +97 -0
  111. mlflow/entities/run_data.py +84 -0
  112. mlflow/entities/run_info.py +188 -0
  113. mlflow/entities/run_inputs.py +59 -0
  114. mlflow/entities/run_outputs.py +43 -0
  115. mlflow/entities/run_status.py +41 -0
  116. mlflow/entities/run_tag.py +36 -0
  117. mlflow/entities/source_type.py +31 -0
  118. mlflow/entities/span.py +774 -0
  119. mlflow/entities/span_event.py +96 -0
  120. mlflow/entities/span_status.py +102 -0
  121. mlflow/entities/trace.py +317 -0
  122. mlflow/entities/trace_data.py +71 -0
  123. mlflow/entities/trace_info.py +220 -0
  124. mlflow/entities/trace_info_v2.py +162 -0
  125. mlflow/entities/trace_location.py +173 -0
  126. mlflow/entities/trace_state.py +39 -0
  127. mlflow/entities/trace_status.py +68 -0
  128. mlflow/entities/view_type.py +51 -0
  129. mlflow/environment_variables.py +866 -0
  130. mlflow/evaluation/__init__.py +16 -0
  131. mlflow/evaluation/assessment.py +369 -0
  132. mlflow/evaluation/evaluation.py +411 -0
  133. mlflow/evaluation/evaluation_tag.py +61 -0
  134. mlflow/evaluation/fluent.py +48 -0
  135. mlflow/evaluation/utils.py +201 -0
  136. mlflow/exceptions.py +213 -0
  137. mlflow/experiments.py +140 -0
  138. mlflow/gemini/__init__.py +81 -0
  139. mlflow/gemini/autolog.py +186 -0
  140. mlflow/gemini/chat.py +261 -0
  141. mlflow/genai/__init__.py +71 -0
  142. mlflow/genai/datasets/__init__.py +67 -0
  143. mlflow/genai/datasets/evaluation_dataset.py +131 -0
  144. mlflow/genai/evaluation/__init__.py +3 -0
  145. mlflow/genai/evaluation/base.py +411 -0
  146. mlflow/genai/evaluation/constant.py +23 -0
  147. mlflow/genai/evaluation/utils.py +244 -0
  148. mlflow/genai/judges/__init__.py +21 -0
  149. mlflow/genai/judges/databricks.py +404 -0
  150. mlflow/genai/label_schemas/__init__.py +153 -0
  151. mlflow/genai/label_schemas/label_schemas.py +209 -0
  152. mlflow/genai/labeling/__init__.py +159 -0
  153. mlflow/genai/labeling/labeling.py +250 -0
  154. mlflow/genai/optimize/__init__.py +13 -0
  155. mlflow/genai/optimize/base.py +198 -0
  156. mlflow/genai/optimize/optimizers/__init__.py +4 -0
  157. mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
  158. mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
  159. mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
  160. mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
  161. mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
  162. mlflow/genai/optimize/types.py +75 -0
  163. mlflow/genai/optimize/util.py +30 -0
  164. mlflow/genai/prompts/__init__.py +206 -0
  165. mlflow/genai/scheduled_scorers.py +431 -0
  166. mlflow/genai/scorers/__init__.py +26 -0
  167. mlflow/genai/scorers/base.py +492 -0
  168. mlflow/genai/scorers/builtin_scorers.py +765 -0
  169. mlflow/genai/scorers/scorer_utils.py +138 -0
  170. mlflow/genai/scorers/validation.py +165 -0
  171. mlflow/genai/utils/data_validation.py +146 -0
  172. mlflow/genai/utils/enum_utils.py +23 -0
  173. mlflow/genai/utils/trace_utils.py +211 -0
  174. mlflow/groq/__init__.py +42 -0
  175. mlflow/groq/_groq_autolog.py +74 -0
  176. mlflow/johnsnowlabs/__init__.py +888 -0
  177. mlflow/langchain/__init__.py +24 -0
  178. mlflow/langchain/api_request_parallel_processor.py +330 -0
  179. mlflow/langchain/autolog.py +147 -0
  180. mlflow/langchain/chat_agent_langgraph.py +340 -0
  181. mlflow/langchain/constant.py +1 -0
  182. mlflow/langchain/constants.py +1 -0
  183. mlflow/langchain/databricks_dependencies.py +444 -0
  184. mlflow/langchain/langchain_tracer.py +597 -0
  185. mlflow/langchain/model.py +919 -0
  186. mlflow/langchain/output_parsers.py +142 -0
  187. mlflow/langchain/retriever_chain.py +153 -0
  188. mlflow/langchain/runnables.py +527 -0
  189. mlflow/langchain/utils/chat.py +402 -0
  190. mlflow/langchain/utils/logging.py +671 -0
  191. mlflow/langchain/utils/serialization.py +36 -0
  192. mlflow/legacy_databricks_cli/__init__.py +0 -0
  193. mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
  194. mlflow/legacy_databricks_cli/configure/provider.py +482 -0
  195. mlflow/litellm/__init__.py +175 -0
  196. mlflow/llama_index/__init__.py +22 -0
  197. mlflow/llama_index/autolog.py +55 -0
  198. mlflow/llama_index/chat.py +43 -0
  199. mlflow/llama_index/constant.py +1 -0
  200. mlflow/llama_index/model.py +577 -0
  201. mlflow/llama_index/pyfunc_wrapper.py +332 -0
  202. mlflow/llama_index/serialize_objects.py +188 -0
  203. mlflow/llama_index/tracer.py +561 -0
  204. mlflow/metrics/__init__.py +479 -0
  205. mlflow/metrics/base.py +39 -0
  206. mlflow/metrics/genai/__init__.py +25 -0
  207. mlflow/metrics/genai/base.py +101 -0
  208. mlflow/metrics/genai/genai_metric.py +771 -0
  209. mlflow/metrics/genai/metric_definitions.py +450 -0
  210. mlflow/metrics/genai/model_utils.py +371 -0
  211. mlflow/metrics/genai/prompt_template.py +68 -0
  212. mlflow/metrics/genai/prompts/__init__.py +0 -0
  213. mlflow/metrics/genai/prompts/v1.py +422 -0
  214. mlflow/metrics/genai/utils.py +6 -0
  215. mlflow/metrics/metric_definitions.py +619 -0
  216. mlflow/mismatch.py +34 -0
  217. mlflow/mistral/__init__.py +34 -0
  218. mlflow/mistral/autolog.py +71 -0
  219. mlflow/mistral/chat.py +135 -0
  220. mlflow/ml_package_versions.py +452 -0
  221. mlflow/models/__init__.py +97 -0
  222. mlflow/models/auth_policy.py +83 -0
  223. mlflow/models/cli.py +354 -0
  224. mlflow/models/container/__init__.py +294 -0
  225. mlflow/models/container/scoring_server/__init__.py +0 -0
  226. mlflow/models/container/scoring_server/nginx.conf +39 -0
  227. mlflow/models/dependencies_schemas.py +287 -0
  228. mlflow/models/display_utils.py +158 -0
  229. mlflow/models/docker_utils.py +211 -0
  230. mlflow/models/evaluation/__init__.py +23 -0
  231. mlflow/models/evaluation/_shap_patch.py +64 -0
  232. mlflow/models/evaluation/artifacts.py +194 -0
  233. mlflow/models/evaluation/base.py +1811 -0
  234. mlflow/models/evaluation/calibration_curve.py +109 -0
  235. mlflow/models/evaluation/default_evaluator.py +996 -0
  236. mlflow/models/evaluation/deprecated.py +23 -0
  237. mlflow/models/evaluation/evaluator_registry.py +80 -0
  238. mlflow/models/evaluation/evaluators/classifier.py +704 -0
  239. mlflow/models/evaluation/evaluators/default.py +233 -0
  240. mlflow/models/evaluation/evaluators/regressor.py +96 -0
  241. mlflow/models/evaluation/evaluators/shap.py +296 -0
  242. mlflow/models/evaluation/lift_curve.py +178 -0
  243. mlflow/models/evaluation/utils/metric.py +123 -0
  244. mlflow/models/evaluation/utils/trace.py +179 -0
  245. mlflow/models/evaluation/validation.py +434 -0
  246. mlflow/models/flavor_backend.py +93 -0
  247. mlflow/models/flavor_backend_registry.py +53 -0
  248. mlflow/models/model.py +1639 -0
  249. mlflow/models/model_config.py +150 -0
  250. mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
  251. mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
  252. mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
  253. mlflow/models/python_api.py +369 -0
  254. mlflow/models/rag_signatures.py +128 -0
  255. mlflow/models/resources.py +321 -0
  256. mlflow/models/signature.py +662 -0
  257. mlflow/models/utils.py +2054 -0
  258. mlflow/models/wheeled_model.py +280 -0
  259. mlflow/openai/__init__.py +57 -0
  260. mlflow/openai/_agent_tracer.py +364 -0
  261. mlflow/openai/api_request_parallel_processor.py +131 -0
  262. mlflow/openai/autolog.py +509 -0
  263. mlflow/openai/constant.py +1 -0
  264. mlflow/openai/model.py +824 -0
  265. mlflow/openai/utils/chat_schema.py +367 -0
  266. mlflow/optuna/__init__.py +3 -0
  267. mlflow/optuna/storage.py +646 -0
  268. mlflow/plugins/__init__.py +72 -0
  269. mlflow/plugins/base.py +358 -0
  270. mlflow/plugins/builtin/__init__.py +24 -0
  271. mlflow/plugins/builtin/pytorch_plugin.py +150 -0
  272. mlflow/plugins/builtin/sklearn_plugin.py +158 -0
  273. mlflow/plugins/builtin/transformers_plugin.py +187 -0
  274. mlflow/plugins/cli.py +321 -0
  275. mlflow/plugins/discovery.py +340 -0
  276. mlflow/plugins/manager.py +465 -0
  277. mlflow/plugins/registry.py +316 -0
  278. mlflow/plugins/templates/framework_plugin_template.py +329 -0
  279. mlflow/prompt/constants.py +20 -0
  280. mlflow/prompt/promptlab_model.py +197 -0
  281. mlflow/prompt/registry_utils.py +248 -0
  282. mlflow/promptflow/__init__.py +495 -0
  283. mlflow/protos/__init__.py +0 -0
  284. mlflow/protos/assessments_pb2.py +174 -0
  285. mlflow/protos/databricks_artifacts_pb2.py +489 -0
  286. mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
  287. mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
  288. mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
  289. mlflow/protos/databricks_pb2.py +267 -0
  290. mlflow/protos/databricks_trace_server_pb2.py +374 -0
  291. mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
  292. mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
  293. mlflow/protos/facet_feature_statistics_pb2.py +296 -0
  294. mlflow/protos/internal_pb2.py +77 -0
  295. mlflow/protos/mlflow_artifacts_pb2.py +336 -0
  296. mlflow/protos/model_registry_pb2.py +1073 -0
  297. mlflow/protos/scalapb/__init__.py +0 -0
  298. mlflow/protos/scalapb/scalapb_pb2.py +104 -0
  299. mlflow/protos/service_pb2.py +2600 -0
  300. mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
  301. mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
  302. mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
  303. mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
  304. mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
  305. mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
  306. mlflow/py.typed +0 -0
  307. mlflow/pydantic_ai/__init__.py +57 -0
  308. mlflow/pydantic_ai/autolog.py +173 -0
  309. mlflow/pyfunc/__init__.py +3844 -0
  310. mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
  311. mlflow/pyfunc/backend.py +523 -0
  312. mlflow/pyfunc/context.py +78 -0
  313. mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
  314. mlflow/pyfunc/loaders/__init__.py +7 -0
  315. mlflow/pyfunc/loaders/chat_agent.py +117 -0
  316. mlflow/pyfunc/loaders/chat_model.py +125 -0
  317. mlflow/pyfunc/loaders/code_model.py +31 -0
  318. mlflow/pyfunc/loaders/responses_agent.py +112 -0
  319. mlflow/pyfunc/mlserver.py +46 -0
  320. mlflow/pyfunc/model.py +1473 -0
  321. mlflow/pyfunc/scoring_server/__init__.py +604 -0
  322. mlflow/pyfunc/scoring_server/app.py +7 -0
  323. mlflow/pyfunc/scoring_server/client.py +146 -0
  324. mlflow/pyfunc/spark_model_cache.py +48 -0
  325. mlflow/pyfunc/stdin_server.py +44 -0
  326. mlflow/pyfunc/utils/__init__.py +3 -0
  327. mlflow/pyfunc/utils/data_validation.py +224 -0
  328. mlflow/pyfunc/utils/environment.py +22 -0
  329. mlflow/pyfunc/utils/input_converter.py +47 -0
  330. mlflow/pyfunc/utils/serving_data_parser.py +11 -0
  331. mlflow/pytorch/__init__.py +1171 -0
  332. mlflow/pytorch/_lightning_autolog.py +580 -0
  333. mlflow/pytorch/_pytorch_autolog.py +50 -0
  334. mlflow/pytorch/pickle_module.py +35 -0
  335. mlflow/rfunc/__init__.py +42 -0
  336. mlflow/rfunc/backend.py +134 -0
  337. mlflow/runs.py +89 -0
  338. mlflow/server/__init__.py +302 -0
  339. mlflow/server/auth/__init__.py +1224 -0
  340. mlflow/server/auth/__main__.py +4 -0
  341. mlflow/server/auth/basic_auth.ini +6 -0
  342. mlflow/server/auth/cli.py +11 -0
  343. mlflow/server/auth/client.py +537 -0
  344. mlflow/server/auth/config.py +34 -0
  345. mlflow/server/auth/db/__init__.py +0 -0
  346. mlflow/server/auth/db/cli.py +18 -0
  347. mlflow/server/auth/db/migrations/__init__.py +0 -0
  348. mlflow/server/auth/db/migrations/alembic.ini +110 -0
  349. mlflow/server/auth/db/migrations/env.py +76 -0
  350. mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
  351. mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
  352. mlflow/server/auth/db/models.py +67 -0
  353. mlflow/server/auth/db/utils.py +37 -0
  354. mlflow/server/auth/entities.py +165 -0
  355. mlflow/server/auth/logo.py +14 -0
  356. mlflow/server/auth/permissions.py +65 -0
  357. mlflow/server/auth/routes.py +18 -0
  358. mlflow/server/auth/sqlalchemy_store.py +263 -0
  359. mlflow/server/graphql/__init__.py +0 -0
  360. mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
  361. mlflow/server/graphql/graphql_custom_scalars.py +24 -0
  362. mlflow/server/graphql/graphql_errors.py +15 -0
  363. mlflow/server/graphql/graphql_no_batching.py +89 -0
  364. mlflow/server/graphql/graphql_schema_extensions.py +74 -0
  365. mlflow/server/handlers.py +3217 -0
  366. mlflow/server/prometheus_exporter.py +17 -0
  367. mlflow/server/validation.py +30 -0
  368. mlflow/shap/__init__.py +691 -0
  369. mlflow/sklearn/__init__.py +1994 -0
  370. mlflow/sklearn/utils.py +1041 -0
  371. mlflow/smolagents/__init__.py +66 -0
  372. mlflow/smolagents/autolog.py +139 -0
  373. mlflow/smolagents/chat.py +29 -0
  374. mlflow/store/__init__.py +10 -0
  375. mlflow/store/_unity_catalog/__init__.py +1 -0
  376. mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
  377. mlflow/store/_unity_catalog/lineage/constants.py +2 -0
  378. mlflow/store/_unity_catalog/registry/__init__.py +6 -0
  379. mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
  380. mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
  381. mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
  382. mlflow/store/_unity_catalog/registry/utils.py +121 -0
  383. mlflow/store/artifact/__init__.py +0 -0
  384. mlflow/store/artifact/artifact_repo.py +472 -0
  385. mlflow/store/artifact/artifact_repository_registry.py +154 -0
  386. mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
  387. mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
  388. mlflow/store/artifact/cli.py +141 -0
  389. mlflow/store/artifact/cloud_artifact_repo.py +332 -0
  390. mlflow/store/artifact/databricks_artifact_repo.py +729 -0
  391. mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
  392. mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
  393. mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
  394. mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
  395. mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
  396. mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
  397. mlflow/store/artifact/ftp_artifact_repo.py +132 -0
  398. mlflow/store/artifact/gcs_artifact_repo.py +296 -0
  399. mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
  400. mlflow/store/artifact/http_artifact_repo.py +218 -0
  401. mlflow/store/artifact/local_artifact_repo.py +142 -0
  402. mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
  403. mlflow/store/artifact/models_artifact_repo.py +259 -0
  404. mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
  405. mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
  406. mlflow/store/artifact/r2_artifact_repo.py +70 -0
  407. mlflow/store/artifact/runs_artifact_repo.py +265 -0
  408. mlflow/store/artifact/s3_artifact_repo.py +330 -0
  409. mlflow/store/artifact/sftp_artifact_repo.py +141 -0
  410. mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
  411. mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
  412. mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
  413. mlflow/store/artifact/utils/__init__.py +0 -0
  414. mlflow/store/artifact/utils/models.py +148 -0
  415. mlflow/store/db/__init__.py +0 -0
  416. mlflow/store/db/base_sql_model.py +3 -0
  417. mlflow/store/db/db_types.py +10 -0
  418. mlflow/store/db/utils.py +314 -0
  419. mlflow/store/db_migrations/__init__.py +0 -0
  420. mlflow/store/db_migrations/alembic.ini +74 -0
  421. mlflow/store/db_migrations/env.py +84 -0
  422. mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
  423. mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
  424. mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
  425. mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
  426. mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
  427. mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
  428. mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
  429. mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
  430. mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
  431. mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
  432. mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
  433. mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
  434. mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
  435. mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
  436. mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
  437. mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
  438. mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
  439. mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
  440. mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
  441. mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
  442. mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
  443. mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
  444. mlflow/store/db_migrations/versions/__init__.py +0 -0
  445. mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
  446. mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
  447. mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
  448. mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
  449. mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
  450. mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
  451. mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
  452. mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
  453. mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
  454. mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
  455. mlflow/store/entities/__init__.py +3 -0
  456. mlflow/store/entities/paged_list.py +18 -0
  457. mlflow/store/model_registry/__init__.py +10 -0
  458. mlflow/store/model_registry/abstract_store.py +1081 -0
  459. mlflow/store/model_registry/base_rest_store.py +44 -0
  460. mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
  461. mlflow/store/model_registry/dbmodels/__init__.py +0 -0
  462. mlflow/store/model_registry/dbmodels/models.py +206 -0
  463. mlflow/store/model_registry/file_store.py +1091 -0
  464. mlflow/store/model_registry/rest_store.py +481 -0
  465. mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
  466. mlflow/store/tracking/__init__.py +23 -0
  467. mlflow/store/tracking/abstract_store.py +816 -0
  468. mlflow/store/tracking/dbmodels/__init__.py +0 -0
  469. mlflow/store/tracking/dbmodels/initial_models.py +243 -0
  470. mlflow/store/tracking/dbmodels/models.py +1073 -0
  471. mlflow/store/tracking/file_store.py +2438 -0
  472. mlflow/store/tracking/postgres_managed_identity.py +146 -0
  473. mlflow/store/tracking/rest_store.py +1131 -0
  474. mlflow/store/tracking/sqlalchemy_store.py +2785 -0
  475. mlflow/system_metrics/__init__.py +61 -0
  476. mlflow/system_metrics/metrics/__init__.py +0 -0
  477. mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
  478. mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
  479. mlflow/system_metrics/metrics/disk_monitor.py +21 -0
  480. mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
  481. mlflow/system_metrics/metrics/network_monitor.py +34 -0
  482. mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
  483. mlflow/system_metrics/system_metrics_monitor.py +198 -0
  484. mlflow/tracing/__init__.py +16 -0
  485. mlflow/tracing/assessment.py +356 -0
  486. mlflow/tracing/client.py +531 -0
  487. mlflow/tracing/config.py +125 -0
  488. mlflow/tracing/constant.py +105 -0
  489. mlflow/tracing/destination.py +81 -0
  490. mlflow/tracing/display/__init__.py +40 -0
  491. mlflow/tracing/display/display_handler.py +196 -0
  492. mlflow/tracing/export/async_export_queue.py +186 -0
  493. mlflow/tracing/export/inference_table.py +138 -0
  494. mlflow/tracing/export/mlflow_v3.py +137 -0
  495. mlflow/tracing/export/utils.py +70 -0
  496. mlflow/tracing/fluent.py +1417 -0
  497. mlflow/tracing/processor/base_mlflow.py +199 -0
  498. mlflow/tracing/processor/inference_table.py +175 -0
  499. mlflow/tracing/processor/mlflow_v3.py +47 -0
  500. mlflow/tracing/processor/otel.py +73 -0
  501. mlflow/tracing/provider.py +487 -0
  502. mlflow/tracing/trace_manager.py +200 -0
  503. mlflow/tracing/utils/__init__.py +616 -0
  504. mlflow/tracing/utils/artifact_utils.py +28 -0
  505. mlflow/tracing/utils/copy.py +55 -0
  506. mlflow/tracing/utils/environment.py +55 -0
  507. mlflow/tracing/utils/exception.py +21 -0
  508. mlflow/tracing/utils/once.py +35 -0
  509. mlflow/tracing/utils/otlp.py +63 -0
  510. mlflow/tracing/utils/processor.py +54 -0
  511. mlflow/tracing/utils/search.py +292 -0
  512. mlflow/tracing/utils/timeout.py +250 -0
  513. mlflow/tracing/utils/token.py +19 -0
  514. mlflow/tracing/utils/truncation.py +124 -0
  515. mlflow/tracing/utils/warning.py +76 -0
  516. mlflow/tracking/__init__.py +39 -0
  517. mlflow/tracking/_model_registry/__init__.py +1 -0
  518. mlflow/tracking/_model_registry/client.py +764 -0
  519. mlflow/tracking/_model_registry/fluent.py +853 -0
  520. mlflow/tracking/_model_registry/registry.py +67 -0
  521. mlflow/tracking/_model_registry/utils.py +251 -0
  522. mlflow/tracking/_tracking_service/__init__.py +0 -0
  523. mlflow/tracking/_tracking_service/client.py +883 -0
  524. mlflow/tracking/_tracking_service/registry.py +56 -0
  525. mlflow/tracking/_tracking_service/utils.py +275 -0
  526. mlflow/tracking/artifact_utils.py +179 -0
  527. mlflow/tracking/client.py +5900 -0
  528. mlflow/tracking/context/__init__.py +0 -0
  529. mlflow/tracking/context/abstract_context.py +35 -0
  530. mlflow/tracking/context/databricks_cluster_context.py +15 -0
  531. mlflow/tracking/context/databricks_command_context.py +15 -0
  532. mlflow/tracking/context/databricks_job_context.py +49 -0
  533. mlflow/tracking/context/databricks_notebook_context.py +41 -0
  534. mlflow/tracking/context/databricks_repo_context.py +43 -0
  535. mlflow/tracking/context/default_context.py +51 -0
  536. mlflow/tracking/context/git_context.py +32 -0
  537. mlflow/tracking/context/registry.py +98 -0
  538. mlflow/tracking/context/system_environment_context.py +15 -0
  539. mlflow/tracking/default_experiment/__init__.py +1 -0
  540. mlflow/tracking/default_experiment/abstract_context.py +43 -0
  541. mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
  542. mlflow/tracking/default_experiment/registry.py +75 -0
  543. mlflow/tracking/fluent.py +3595 -0
  544. mlflow/tracking/metric_value_conversion_utils.py +93 -0
  545. mlflow/tracking/multimedia.py +206 -0
  546. mlflow/tracking/registry.py +86 -0
  547. mlflow/tracking/request_auth/__init__.py +0 -0
  548. mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
  549. mlflow/tracking/request_auth/registry.py +60 -0
  550. mlflow/tracking/request_header/__init__.py +0 -0
  551. mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
  552. mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
  553. mlflow/tracking/request_header/default_request_header_provider.py +17 -0
  554. mlflow/tracking/request_header/registry.py +79 -0
  555. mlflow/transformers/__init__.py +2982 -0
  556. mlflow/transformers/flavor_config.py +258 -0
  557. mlflow/transformers/hub_utils.py +83 -0
  558. mlflow/transformers/llm_inference_utils.py +468 -0
  559. mlflow/transformers/model_io.py +301 -0
  560. mlflow/transformers/peft.py +51 -0
  561. mlflow/transformers/signature.py +183 -0
  562. mlflow/transformers/torch_utils.py +55 -0
  563. mlflow/types/__init__.py +21 -0
  564. mlflow/types/agent.py +270 -0
  565. mlflow/types/chat.py +240 -0
  566. mlflow/types/llm.py +935 -0
  567. mlflow/types/responses.py +139 -0
  568. mlflow/types/responses_helpers.py +416 -0
  569. mlflow/types/schema.py +1505 -0
  570. mlflow/types/type_hints.py +647 -0
  571. mlflow/types/utils.py +753 -0
  572. mlflow/utils/__init__.py +283 -0
  573. mlflow/utils/_capture_modules.py +256 -0
  574. mlflow/utils/_capture_transformers_modules.py +75 -0
  575. mlflow/utils/_spark_utils.py +201 -0
  576. mlflow/utils/_unity_catalog_oss_utils.py +97 -0
  577. mlflow/utils/_unity_catalog_utils.py +479 -0
  578. mlflow/utils/annotations.py +218 -0
  579. mlflow/utils/arguments_utils.py +16 -0
  580. mlflow/utils/async_logging/__init__.py +1 -0
  581. mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
  582. mlflow/utils/async_logging/async_logging_queue.py +366 -0
  583. mlflow/utils/async_logging/run_artifact.py +38 -0
  584. mlflow/utils/async_logging/run_batch.py +58 -0
  585. mlflow/utils/async_logging/run_operations.py +49 -0
  586. mlflow/utils/autologging_utils/__init__.py +737 -0
  587. mlflow/utils/autologging_utils/client.py +432 -0
  588. mlflow/utils/autologging_utils/config.py +33 -0
  589. mlflow/utils/autologging_utils/events.py +294 -0
  590. mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
  591. mlflow/utils/autologging_utils/metrics_queue.py +71 -0
  592. mlflow/utils/autologging_utils/safety.py +1104 -0
  593. mlflow/utils/autologging_utils/versioning.py +95 -0
  594. mlflow/utils/checkpoint_utils.py +206 -0
  595. mlflow/utils/class_utils.py +6 -0
  596. mlflow/utils/cli_args.py +257 -0
  597. mlflow/utils/conda.py +354 -0
  598. mlflow/utils/credentials.py +231 -0
  599. mlflow/utils/data_utils.py +17 -0
  600. mlflow/utils/databricks_utils.py +1436 -0
  601. mlflow/utils/docstring_utils.py +477 -0
  602. mlflow/utils/doctor.py +133 -0
  603. mlflow/utils/download_cloud_file_chunk.py +43 -0
  604. mlflow/utils/env_manager.py +16 -0
  605. mlflow/utils/env_pack.py +131 -0
  606. mlflow/utils/environment.py +1009 -0
  607. mlflow/utils/exception_utils.py +14 -0
  608. mlflow/utils/file_utils.py +978 -0
  609. mlflow/utils/git_utils.py +77 -0
  610. mlflow/utils/gorilla.py +797 -0
  611. mlflow/utils/import_hooks/__init__.py +363 -0
  612. mlflow/utils/lazy_load.py +51 -0
  613. mlflow/utils/logging_utils.py +168 -0
  614. mlflow/utils/mime_type_utils.py +58 -0
  615. mlflow/utils/mlflow_tags.py +103 -0
  616. mlflow/utils/model_utils.py +486 -0
  617. mlflow/utils/name_utils.py +346 -0
  618. mlflow/utils/nfs_on_spark.py +62 -0
  619. mlflow/utils/openai_utils.py +164 -0
  620. mlflow/utils/os.py +12 -0
  621. mlflow/utils/oss_registry_utils.py +29 -0
  622. mlflow/utils/plugins.py +17 -0
  623. mlflow/utils/process.py +182 -0
  624. mlflow/utils/promptlab_utils.py +146 -0
  625. mlflow/utils/proto_json_utils.py +743 -0
  626. mlflow/utils/pydantic_utils.py +54 -0
  627. mlflow/utils/request_utils.py +279 -0
  628. mlflow/utils/requirements_utils.py +704 -0
  629. mlflow/utils/rest_utils.py +673 -0
  630. mlflow/utils/search_logged_model_utils.py +127 -0
  631. mlflow/utils/search_utils.py +2111 -0
  632. mlflow/utils/secure_loading.py +221 -0
  633. mlflow/utils/security_validation.py +384 -0
  634. mlflow/utils/server_cli_utils.py +61 -0
  635. mlflow/utils/spark_utils.py +15 -0
  636. mlflow/utils/string_utils.py +138 -0
  637. mlflow/utils/thread_utils.py +63 -0
  638. mlflow/utils/time.py +54 -0
  639. mlflow/utils/timeout.py +42 -0
  640. mlflow/utils/uri.py +572 -0
  641. mlflow/utils/validation.py +662 -0
  642. mlflow/utils/virtualenv.py +458 -0
  643. mlflow/utils/warnings_utils.py +25 -0
  644. mlflow/utils/yaml_utils.py +179 -0
  645. mlflow/version.py +24 -0
@@ -0,0 +1,301 @@
1
+ import logging
2
+ import shutil
3
+
4
+ from mlflow.environment_variables import (
5
+ MLFLOW_HUGGINGFACE_DISABLE_ACCELERATE_FEATURES,
6
+ MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE,
7
+ )
8
+ from mlflow.exceptions import MlflowException
9
+ from mlflow.protos.databricks_pb2 import INVALID_STATE
10
+ from mlflow.transformers.flavor_config import FlavorKey, get_peft_base_model, is_peft_model
11
+
12
# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)

# File/directory names for saved artifacts
# Sub-directory (relative to the model path) that holds the model weights.
_MODEL_BINARY_FILE_NAME = "model"
# Sub-directory that holds one folder per saved pipeline component.
_COMPONENTS_BINARY_DIR_NAME = "components"
# Folder name used for the processor inside the components sub-directory.
_PROCESSOR_BINARY_DIR_NAME = "processor"
18
+
19
+
20
def save_pipeline_pretrained_weights(path, pipeline, flavor_conf, processor=None):
    """
    Write the pretrained binaries of a pipeline below the given local directory.

    The model is written to ``<path>/model``. Each component named in the flavor
    configuration, and the optional processor, is written below ``<path>/components``.

    Args:
        path: Local destination; must provide ``joinpath`` (e.g. a ``pathlib.Path``).
        pipeline: Transformers pipeline instance whose model and components are saved.
        flavor_conf: The flavor configuration constructed for the pipeline. Its
            ``FlavorKey.COMPONENTS`` entry lists the pipeline attributes to save.
        processor: Optional processor instance to save alongside the pipeline.
    """
    wrapped_model = pipeline.model
    # For a PEFT model only the underlying base model is written here.
    if is_peft_model(wrapped_model):
        target_model = get_peft_base_model(wrapped_model)
    else:
        target_model = wrapped_model

    target_model.save_pretrained(
        save_directory=path.joinpath(_MODEL_BINARY_FILE_NAME),
        max_shard_size=MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.get(),
    )

    components_root = path.joinpath(_COMPONENTS_BINARY_DIR_NAME)
    component_names = flavor_conf.get(FlavorKey.COMPONENTS, [])
    for component_name in component_names:
        component = getattr(pipeline, component_name)
        component.save_pretrained(components_root.joinpath(component_name))

    if not processor:
        return
    processor.save_pretrained(components_root.joinpath(_PROCESSOR_BINARY_DIR_NAME))
43
+
44
+
45
def save_local_checkpoint(path, checkpoint_dir, flavor_conf, processor=None):
    """
    Copy a local checkpoint and persist the pipeline components next to it.

    The checkpoint directory is copied to ``<path>/model``. Every component listed in
    the flavor configuration is then loaded and re-saved under ``<path>/components``.

    Args:
        path: Local destination; must provide ``joinpath`` (e.g. a ``pathlib.Path``).
        checkpoint_dir: The local path to the checkpoint directory.
        flavor_conf: The flavor configuration constructed for the pipeline.
        processor: Optional processor instance to save alongside the pipeline.
    """
    # The checkpoint content becomes the model binary directory as-is.
    model_destination = path.joinpath(_MODEL_BINARY_FILE_NAME)
    shutil.copytree(checkpoint_dir, model_destination)

    component_names = flavor_conf.get(FlavorKey.COMPONENTS, [])
    for component_name in component_names:
        # Components such as the tokenizer may not have been saved in the checkpoint.
        # Try the checkpoint directory first; on any failure fall back to the
        # repository recorded under FlavorKey.MODEL_NAME.
        fallback_repo = None
        try:
            component = _load_component(flavor_conf, component_name, local_path=checkpoint_dir)
        except Exception:
            fallback_repo = flavor_conf[FlavorKey.MODEL_NAME]
            _logger.info(
                f"The {component_name} state file is not found ins the local checkpoint directory. MLflow "
                f"will use the default component state from the base HF repository {fallback_repo}."
            )
        if fallback_repo is not None:
            component = _load_component(flavor_conf, component_name, repo_id=fallback_repo)

        component.save_pretrained(path.joinpath(_COMPONENTS_BINARY_DIR_NAME, component_name))

    if not processor:
        return
    processor_destination = path.joinpath(_COMPONENTS_BINARY_DIR_NAME, _PROCESSOR_BINARY_DIR_NAME)
    processor.save_pretrained(processor_destination)
78
+
79
+
80
def load_model_and_components_from_local(path, flavor_conf, accelerate_conf, device=None):
    """
    Load the model and components of a Transformer pipeline from the specified local path.

    Args:
        path: The local path that contains the MLflow model artifacts; must provide
            ``joinpath`` (e.g. a ``pathlib.Path``).
        flavor_conf: The flavor configuration. It is not modified by this function.
        accelerate_conf: The configuration for the accelerate library.
        device: The device to load the model onto.

    Returns:
        A dict holding the loaded model under ``FlavorKey.MODEL`` and each loaded
        component under its component name.
    """
    loaded = {}

    # NB: Path resolution for models that were saved prior to 2.4.1 release when the patching for
    # the saved pipeline or component artifacts was handled by duplicate entries for components
    # (artifacts/pipeline/* and artifacts/components/*) and pipelines were saved via the
    # "artifacts/pipeline/*" path. In order to load the older formats after the change, the
    # presence of the new path key is checked.
    model_path = path.joinpath(flavor_conf.get(FlavorKey.MODEL_BINARY, "pipeline"))
    loaded[FlavorKey.MODEL] = _load_model(model_path, flavor_conf, accelerate_conf, device)

    # Work on a copy: appending to the list object stored in ``flavor_conf`` would mutate
    # the caller's configuration and add another "processor" entry on every call.
    components = list(flavor_conf.get(FlavorKey.COMPONENTS, []))
    if FlavorKey.PROCESSOR_TYPE in flavor_conf:
        components.append("processor")

    for component_key in components:
        component_path = path.joinpath(_COMPONENTS_BINARY_DIR_NAME, component_key)
        loaded[component_key] = _load_component(
            flavor_conf, component_key, local_path=component_path
        )

    return loaded
111
+
112
+
113
def load_model_and_components_from_huggingface_hub(flavor_conf, accelerate_conf, device=None):
    """
    Load the model and components of a Transformer pipeline from HuggingFace Hub.

    Args:
        flavor_conf: The flavor configuration. It is not modified by this function.
        accelerate_conf: The configuration for the accelerate library.
        device: The device to load the model onto.

    Returns:
        A dict holding the loaded model under ``FlavorKey.MODEL`` and each loaded
        component under its component name.

    Raises:
        MlflowException: With error code ``INVALID_STATE`` when the flavor configuration
            has no model revision recorded.
    """
    loaded = {}

    model_repo = flavor_conf[FlavorKey.MODEL_NAME]
    model_revision = flavor_conf.get(FlavorKey.MODEL_REVISION)

    # Refuse to load from the Hub without a recorded revision.
    if not model_revision:
        raise MlflowException(
            "The model was saved with 'save_pretrained' set to False, but the commit hash is not "
            "found in the saved metadata. Loading the model with the different version may cause "
            "inconsistency issue and security risk.",
            error_code=INVALID_STATE,
        )

    loaded[FlavorKey.MODEL] = _load_model(
        model_repo, flavor_conf, accelerate_conf, device, revision=model_revision
    )

    # Work on a copy: appending to the list object stored in ``flavor_conf`` would mutate
    # the caller's configuration and add another "processor" entry on every call.
    components = list(flavor_conf.get(FlavorKey.COMPONENTS, []))
    if FlavorKey.PROCESSOR_TYPE in flavor_conf:
        components.append("processor")

    for name in components:
        loaded[name] = _load_component(flavor_conf, name)

    return loaded
147
+
148
+
149
def _load_component(flavor_conf, name, local_path=None, repo_id=None):
    """
    Instantiate one pipeline component from a local directory or from a repository.

    Args:
        flavor_conf: The flavor configuration holding the component's type, and for
            remote loading its repository name and revision.
        name: Component name (e.g. the tokenizer key) used to look up the entries above.
        local_path: When given, the component is loaded from this path.
        repo_id: Optional repository overriding the one recorded in ``flavor_conf``
            when loading remotely.

    Returns:
        The object returned by the component class's ``from_pretrained``.

    Raises:
        MlflowException: If the component type is not an attribute of ``transformers``
            and no ``local_path`` is supplied.
    """
    import transformers

    autoclass_by_component = {
        FlavorKey.TOKENIZER: transformers.AutoTokenizer,
        FlavorKey.FEATURE_EXTRACTOR: transformers.AutoFeatureExtractor,
        FlavorKey.PROCESSOR: transformers.AutoProcessor,
        FlavorKey.IMAGE_PROCESSOR: transformers.AutoImageProcessor,
    }

    type_name = flavor_conf[FlavorKey.COMPONENT_TYPE.format(name)]
    is_native = hasattr(transformers, type_name)

    # A custom component can only be resolved through a local copy.
    if not is_native and local_path is None:
        raise MlflowException(
            f"A custom component `{type_name}` was specified, "
            "but no local config file was found to retrieve the "
            "definition. Make sure your model was saved with "
            "save_pretrained=True."
        )

    if is_native:
        component_cls = getattr(transformers, type_name)
    else:
        component_cls = autoclass_by_component[name]
    # Remote code is trusted only for types that are not part of transformers.
    trust_remote = not is_native

    if local_path is None:
        # Load component from HuggingFace Hub
        repo = repo_id or flavor_conf[FlavorKey.COMPONENT_NAME.format(name)]
        revision = flavor_conf.get(FlavorKey.COMPONENT_REVISION.format(name))
        return component_cls.from_pretrained(
            repo, revision=revision, trust_remote_code=trust_remote
        )

    # Load component from local file
    return component_cls.from_pretrained(str(local_path), trust_remote_code=trust_remote)
182
+
183
+
184
def _load_class_from_transformers_config(model_name_or_path, revision=None):
    """
    This method retrieves the Transformers AutoClass from the transformers config.
    Using the correct AutoClass allows us to leverage Transformers' model loading
    machinery, which is necessary for supporting models using custom code.

    Args:
        model_name_or_path: Repository name or local path passed to ``AutoConfig``.
        revision: Optional revision passed to ``AutoConfig.from_pretrained``.

    Returns:
        A ``(cls, trust_remote_code)`` tuple. The flag is False when the class named in
        the config is an attribute of ``transformers``, and True when an AutoClass had
        to be resolved through the config's ``auto_map``.

    Raises:
        MlflowException: If the config declares no architecture, or no ``auto_map``
            entry maps to the declared architecture.
    """
    import transformers
    from transformers import AutoConfig

    # NOTE(security): trust_remote_code=True is passed unconditionally here so that a
    # custom config is loaded as its own class instead of the base class.
    # Callers should only pass trusted model sources.
    config = AutoConfig.from_pretrained(
        model_name_or_path,
        revision=revision,
        trust_remote_code=True,
    )

    # the model's class name (e.g. "MPTForCausalLM") is stored in the `architectures`
    # field. A config without it would otherwise fail with an opaque TypeError.
    architectures = getattr(config, "architectures", None)
    if not architectures:
        raise MlflowException(
            f"The config loaded from '{model_name_or_path}' does not declare any "
            "architecture, so the class to load the model with cannot be determined."
        )
    class_name = architectures[0]

    # if the class is available in transformers natively,
    # then we don't need to execute any custom code.
    if hasattr(transformers, class_name):
        return getattr(transformers, class_name), False

    # else, we need to fetch the correct AutoClass from the `auto_map` field. A missing
    # or empty `auto_map` is treated as "no loader found" rather than an AttributeError,
    # and non-string entries are skipped because they cannot name a class.
    auto_map = getattr(config, "auto_map", None) or {}
    auto_classes = [
        auto_class
        for auto_class, module in auto_map.items()
        if isinstance(module, str) and module.split(".")[-1] == class_name
    ]

    if not auto_classes:
        raise MlflowException(f"Couldn't find a loader class for {class_name}")

    # we will need to trust remote code when loading the model
    return getattr(transformers, auto_classes[0]), True
232
+
233
+
234
def _load_model(model_name_or_path, flavor_conf, accelerate_conf, device, revision=None):
    """
    Try to load a model with various loading strategies.
        1. Try to load the model with accelerate
        2. Try to load the model with the specified device
        3. Load the model without the device

    Args:
        model_name_or_path: Repository name or local path of the model.
        flavor_conf: The flavor configuration; ``FlavorKey.MODEL_TYPE`` names the class.
        accelerate_conf: Keyword arguments merged into the first loading attempt.
        device: Device passed to the second loading attempt.
        revision: Optional revision forwarded to every loading attempt.

    Returns:
        The loaded model instance.
    """
    import transformers

    if hasattr(transformers, flavor_conf[FlavorKey.MODEL_TYPE]):
        cls = getattr(transformers, flavor_conf[FlavorKey.MODEL_TYPE])
        trust_remote = False
    else:
        # The model type is not part of transformers: resolve the class via the config.
        cls, trust_remote = _load_class_from_transformers_config(
            model_name_or_path, revision=revision
        )

    load_kwargs = {"revision": revision} if revision else {}
    if trust_remote:
        load_kwargs["trust_remote_code"] = True

    if model := _try_load_model_with_accelerate(
        cls, model_name_or_path, {**accelerate_conf, **load_kwargs}
    ):
        return model

    load_kwargs["device"] = device
    if torch_dtype := flavor_conf.get(FlavorKey.TORCH_DTYPE):
        load_kwargs[FlavorKey.TORCH_DTYPE] = torch_dtype

    if model := _try_load_model_with_device(cls, model_name_or_path, load_kwargs):
        return model
    # The two sentences are separate literals; the trailing space keeps them from
    # running together in the emitted message.
    _logger.warning(
        "Could not specify device parameter for this pipeline type. "
        "Falling back to loading the model with the default device."
    )

    load_kwargs.pop("device", None)
    return cls.from_pretrained(model_name_or_path, **load_kwargs)
273
+
274
+
275
def _try_load_model_with_accelerate(model_class, model_name_or_path, load_kwargs):
    """
    Attempt ``model_class.from_pretrained`` with the given keyword arguments.

    Returns:
        The loaded model, or None when accelerate features are disabled through
        ``MLFLOW_HUGGINGFACE_DISABLE_ACCELERATE_FEATURES`` or when loading raises one
        of the tolerated exception types.
    """
    accelerate_disabled = MLFLOW_HUGGINGFACE_DISABLE_ACCELERATE_FEATURES.get()
    if accelerate_disabled:
        return None

    tolerated_errors = (ValueError, TypeError, NotImplementedError, ImportError)
    try:
        loaded_model = model_class.from_pretrained(model_name_or_path, **load_kwargs)
    except tolerated_errors:
        # NB: ImportError is caught here in the event that `accelerate` is not installed
        # on the system, which will raise if `low_cpu_mem_usage` is set or the argument
        # `device_map` is set and accelerate is not installed.
        return None
    return loaded_model
286
+
287
+
288
+ def _try_load_model_with_device(model_class, model_name_or_path, load_kwargs):
289
+ try:
290
+ return model_class.from_pretrained(model_name_or_path, **load_kwargs)
291
+ except OSError as e:
292
+ revision = load_kwargs.get("revision")
293
+ if f"{revision} is not a valid git identifier" in str(e):
294
+ raise MlflowException(
295
+ f"The model was saved with a HuggingFace Hub repository name '{model_name_or_path}'"
296
+ f"and a commit hash '{revision}', but the commit is not found in the repository. "
297
+ )
298
+ else:
299
+ raise e
300
+ except (ValueError, TypeError, NotImplementedError):
301
+ pass
@@ -0,0 +1,51 @@
1
+ """
2
+ PEFT (Parameter-Efficient Fine-Tuning) is a library for efficiently adapting large pretrained
3
+ models by fine-tuning only a small number of (extra) parameters instead of all model parameters.
4
+ Users can define a PEFT model that wraps a Transformer model to apply a thin adapter layer on
5
+ top of the base model. The PEFT model provides almost the same APIs as the original model such
6
+ as from_pretrained(), save_pretrained().
7
+ """
8
+
9
# Directory name constant for the PEFT adaptor; not referenced by the functions in this
# module, so presumably consumed by the save/load code elsewhere (verify against callers).
_PEFT_ADAPTOR_DIR_NAME = "peft"
10
+
11
+
12
def is_peft_model(model) -> bool:
    """Return True if ``model`` is a ``peft.PeftModel``; False otherwise or if peft is missing."""
    try:
        from peft import PeftModel
    except ImportError:
        # Without the peft package installed the model cannot be a PEFT model.
        is_peft = False
    else:
        is_peft = isinstance(model, PeftModel)
    return is_peft
19
+
20
+
21
def get_peft_base_model(model):
    """
    Return the base model wrapped by a PEFT model.

    The config of the active adapter decides how deep to unwrap. With a non prompt-learning
    config the result is ``model.base_model.model``; with a prompt-learning config, or when
    no config is found for the active adapter, it is ``model.base_model``.
    """
    configs = model.peft_config
    active_config = configs.get(model.active_adapter) if configs else None
    wrapped = model.base_model

    # PEFT usually wraps the base model twice (PeftModel -> adaptor class such as
    # LoraModel -> base model). Prompt-learning configs have no adaptor class, so the
    # PeftModel wraps the base model directly.
    if not active_config or active_config.is_prompt_learning:
        return wrapped
    return wrapped.model
34
+
35
+
36
def get_model_with_peft_adapter(base_model, peft_adapter_path):
    """
    Build a PEFT model by applying the adapter stored at ``peft_adapter_path``
    to ``base_model``.

    NB: An alternative is the in-place ``base_model.load_adapter(peft_adapter_path)``
    API, which injects the adapter weights into the model and so reduces the memory
    footprint. That returns the base model class rather than the PEFT model, losing
    properties such as peft_config. Since load_model should return the same kind of
    object that was saved, the PEFT model is constructed instead; consistency is
    preferred over the memory saving, which should be small in most cases.
    """
    from peft import PeftModel

    peft_model = PeftModel.from_pretrained(base_model, peft_adapter_path)
    return peft_model
@@ -0,0 +1,183 @@
1
+ import json
2
+ import logging
3
+
4
+ import numpy as np
5
+
6
+ from mlflow.environment_variables import MLFLOW_INPUT_EXAMPLE_INFERENCE_TIMEOUT
7
+ from mlflow.models.signature import ModelSignature, infer_signature
8
+ from mlflow.models.utils import _contains_params
9
+ from mlflow.types.schema import ColSpec, DataType, Schema, TensorSpec
10
+ from mlflow.utils.os import is_windows
11
+ from mlflow.utils.timeout import MlflowTimeoutError, run_with_timeout
12
+
13
# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)


# Signature with one unnamed string input column and one unnamed string output column.
_TEXT2TEXT_SIGNATURE = ModelSignature(
    inputs=Schema([ColSpec("string")]),
    outputs=Schema([ColSpec("string")]),
)
# Signature with one unnamed string input column and "label"/"score" output columns.
_CLASSIFICATION_SIGNATURE = ModelSignature(
    inputs=Schema([ColSpec("string")]),
    outputs=Schema([ColSpec("string", name="label"), ColSpec("double", name="score")]),
)

# Mapping from pipeline task name to the default signature used for that task.
# Order is important here, the first matching task type will be used
# NOTE(review): the lookup visible in this module is an exact-key `.get(task)`, for which
# ordering has no effect; confirm whether another consumer relies on the order.
_DEFAULT_SIGNATURE_FOR_TASK = {
    "token-classification": _TEXT2TEXT_SIGNATURE,
    "translation": _TEXT2TEXT_SIGNATURE,
    "text-generation": _TEXT2TEXT_SIGNATURE,
    "text2text-generation": _TEXT2TEXT_SIGNATURE,
    "text-classification": _CLASSIFICATION_SIGNATURE,
    "conversational": _TEXT2TEXT_SIGNATURE,
    "fill-mask": _TEXT2TEXT_SIGNATURE,
    "summarization": _TEXT2TEXT_SIGNATURE,
    "image-classification": _CLASSIFICATION_SIGNATURE,
    "zero-shot-classification": ModelSignature(
        inputs=Schema(
            [
                ColSpec(DataType.string, name="sequences"),
                ColSpec(DataType.string, name="candidate_labels"),
                ColSpec(DataType.string, name="hypothesis_template"),
            ]
        ),
        outputs=Schema(
            [
                ColSpec(DataType.string, name="sequence"),
                ColSpec(DataType.string, name="labels"),
                ColSpec(DataType.double, name="scores"),
            ]
        ),
    ),
    "automatic-speech-recognition": ModelSignature(
        inputs=Schema([ColSpec(DataType.binary)]),
        outputs=Schema([ColSpec(DataType.string)]),
    ),
    "audio-classification": ModelSignature(
        inputs=Schema([ColSpec(DataType.binary)]),
        outputs=Schema(
            [ColSpec(DataType.double, name="score"), ColSpec(DataType.string, name="label")]
        ),
    ),
    "table-question-answering": ModelSignature(
        inputs=Schema(
            [ColSpec(DataType.string, name="query"), ColSpec(DataType.string, name="table")]
        ),
        outputs=Schema([ColSpec(DataType.string)]),
    ),
    "question-answering": ModelSignature(
        inputs=Schema(
            [ColSpec(DataType.string, name="question"), ColSpec(DataType.string, name="context")]
        ),
        outputs=Schema([ColSpec(DataType.string)]),
    ),
    "feature-extraction": ModelSignature(
        inputs=Schema([ColSpec(DataType.string)]),
        outputs=Schema([TensorSpec(np.dtype("float64"), [-1], "double")]),
    ),
}
79
+
80
+
81
def infer_or_get_default_signature(
    pipeline, example=None, model_config=None, flavor_config=None
) -> ModelSignature:
    """
    Assigns a default ModelSignature for a given Pipeline type that has pyfunc support. These
    default signatures should only be generated and assigned when saving a model iff the user
    has not supplied a signature.
    For signature inference in some Pipelines that support complex input types, an input example
    is needed.

    Args:
        pipeline: The pipeline to produce a signature for. Inference from the example is
            only attempted when it is a ``transformers.Pipeline``.
        example: Optional input example used to infer the signature by running a prediction.
        model_config: Optional model configuration forwarded to the prediction.
        flavor_config: Optional flavor configuration forwarded to the prediction.

    Returns:
        The inferred signature; otherwise the default signature registered for the
        pipeline's task; otherwise None (after logging a warning) when the task is
        missing or has no default signature.
    """
    import transformers

    if example is not None and isinstance(pipeline, transformers.Pipeline):
        try:
            timeout = MLFLOW_INPUT_EXAMPLE_INFERENCE_TIMEOUT.get()
            if timeout and is_windows():
                timeout = None
                _logger.warning(
                    "On Windows, timeout is not supported for model signature inference. "
                    "Therefore, the operation is not bound by a timeout and may hang indefinitely. "
                    "If it hangs, please consider specifying the signature manually."
                )
            return _infer_signature_with_example(
                pipeline, example, model_config, flavor_config, timeout
            )
        except Exception as e:
            # Inference is best effort: any failure falls back to the default signature.
            if isinstance(e, MlflowTimeoutError):
                msg = (
                    "Attempted to generate a signature for the saved pipeline but prediction timed "
                    f"out after {timeout} seconds. Falling back to the default signature for the "
                    "pipeline. You can specify a signature manually or increase the timeout "
                    f"by setting the environment variable {MLFLOW_INPUT_EXAMPLE_INFERENCE_TIMEOUT}"
                )
            else:
                msg = (
                    "Attempted to generate a signature for the saved pipeline but encountered an "
                    f"error. Fall back to the default signature for the pipeline type. Error: {e}"
                )
            _logger.warning(msg)

    # ``task`` may be absent or None; guard the prefix check so that case reaches the
    # "unsupported task" warning below instead of raising AttributeError.
    task = getattr(pipeline, "task", None)
    if isinstance(task, str) and task.startswith("translation_"):
        # Prefixed translation tasks share the generic "translation" signature.
        task = "translation"
    if signature := _DEFAULT_SIGNATURE_FOR_TASK.get(task):
        return signature

    _logger.warning(
        "An unsupported task type was supplied for signature inference. Either provide an "
        "`input_example` or generate a signature manually via `infer_signature` to have a "
        "signature recorded in the MLmodel file."
    )
    return None
132
+
133
+
134
def _infer_signature_with_example(
    pipeline, example, model_config=None, flavor_config=None, timeout=None
) -> ModelSignature:
    """
    Infer a signature by running a prediction on the input example.

    When ``_contains_params(example)`` is true, the example is unpacked into the input
    data and the inference params. A truthy ``timeout`` runs the prediction under
    ``run_with_timeout``.
    """
    if _contains_params(example):
        example, params = example
    else:
        params = None
    example = format_input_example_for_special_cases(example, pipeline)

    def _predict():
        # Single place for the prediction call shared by both branches below.
        return generate_signature_output(pipeline, example, model_config, flavor_config, params)

    if not timeout:
        prediction = _predict()
    else:
        _logger.info(
            "Running model prediction to infer the model output signature with a timeout "
            f"of {timeout} seconds. You can specify a different timeout by setting the "
            f"environment variable {MLFLOW_INPUT_EXAMPLE_INFERENCE_TIMEOUT}."
        )
        with run_with_timeout(timeout):
            prediction = _predict()
    return infer_signature(example, prediction, params)
157
+
158
+
159
def format_input_example_for_special_cases(input_example, pipeline):
    """
    Handles special formatting for specific types of Pipelines so that the displayed example
    reflects the correct example input structure that mirrors the behavior of the input parsing
    for pyfunc.

    For a "zero-shot-classification" pipeline whose input data is a dict with a list under
    ``candidate_labels``, the list is replaced in place by its JSON string. Any other input
    is returned unchanged.

    Args:
        input_example: The input data, or a tuple whose first item is the input data and
            whose second item is passed through untouched.
        pipeline: Object exposing the ``task`` attribute of the pipeline.

    Returns:
        The formatted input data, re-wrapped in a tuple if a tuple was given.
    """
    is_wrapped = isinstance(input_example, tuple)
    input_data = input_example[0] if is_wrapped else input_example

    if (
        pipeline.task == "zero-shot-classification"
        and isinstance(input_data, dict)
        # ``.get`` lets a dict example without ``candidate_labels`` pass through
        # unchanged instead of raising KeyError.
        and isinstance(input_data.get("candidate_labels"), list)
    ):
        input_data["candidate_labels"] = json.dumps(input_data["candidate_labels"])
    return (input_data, input_example[1]) if is_wrapped else input_data
174
+
175
+
176
def generate_signature_output(pipeline, data, model_config=None, flavor_config=None, params=None):
    """
    Run a prediction on ``data`` through ``_TransformersWrapper`` and return its output.
    """
    # Imported lazily to avoid a circular dependency: _TransformersWrapper lives in the
    # package's __init__.py. Ideally it would be moved out of there.
    from mlflow.transformers import _TransformersWrapper

    wrapper = _TransformersWrapper(
        pipeline=pipeline, model_config=model_config, flavor_config=flavor_config
    )
    return wrapper.predict(data, params=params)
@@ -0,0 +1,55 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Optional
4
+
5
+ from mlflow.exceptions import MlflowException
6
+ from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
7
+
8
+ if TYPE_CHECKING:
9
+ import torch
10
+
11
# Key name for the torch dtype setting; not referenced by the functions in this module,
# so presumably used by callers elsewhere (verify against usages).
_TORCH_DTYPE_KEY = "torch_dtype"
12
+
13
+
14
+ def _extract_torch_dtype_if_set(pipeline) -> Optional[torch.dtype]:
15
+ """
16
+ Extract the torch datatype argument if set and return as a string encoded value.
17
+ """
18
+ try:
19
+ import torch
20
+ except ImportError:
21
+ # If torch is not installed, safe to assume the model doesn't have a custom torch_dtype
22
+ return None
23
+
24
+ # Check model dtype as pipeline's torch_dtype field doesn't always reflect the model's dtype
25
+ model_dtype = pipeline.model.dtype if hasattr(pipeline.model, "dtype") else None
26
+
27
+ # If the underlying model is PyTorch model, dtype must be a torch.dtype instance
28
+ return model_dtype if isinstance(model_dtype, torch.dtype) else None
29
+
30
+
31
+ def _deserialize_torch_dtype(dtype_str: str) -> torch.dtype:
32
+ """
33
+ Convert the string-encoded `torch_dtype` pipeline argument back to the correct `torch.dtype`
34
+ instance value for applying to a loaded pipeline instance.
35
+ """
36
+ try:
37
+ import torch
38
+ except ImportError as e:
39
+ raise MlflowException(
40
+ "Unable to determine if the value supplied by the argument "
41
+ "torch_dtype is valid since torch is not installed.",
42
+ error_code=INVALID_PARAMETER_VALUE,
43
+ ) from e
44
+
45
+ if dtype_str.startswith("torch."):
46
+ dtype_str = dtype_str[6:]
47
+
48
+ dtype = getattr(torch, dtype_str, None)
49
+ if isinstance(dtype, torch.dtype):
50
+ return dtype
51
+
52
+ raise MlflowException(
53
+ f"The value '{dtype_str}' is not a valid torch.dtype",
54
+ error_code=INVALID_PARAMETER_VALUE,
55
+ )
@@ -0,0 +1,21 @@
1
+ """
2
+ The :py:mod:`mlflow.types` module defines data types and utilities to be used by other mlflow
3
+ components to describe interface independent of other frameworks or languages.
4
+ """
5
+
6
+ from mlflow.version import IS_TRACING_SDK_ONLY
7
+
8
+ if not IS_TRACING_SDK_ONLY:
9
+ import mlflow.types.llm # noqa: F401
10
+
11
+ # Our typing system depends on numpy, which is not included in mlflow-tracing package
12
+ from mlflow.types.schema import ColSpec, DataType, ParamSchema, ParamSpec, Schema, TensorSpec
13
+
14
+ __all__ = [
15
+ "Schema",
16
+ "ColSpec",
17
+ "DataType",
18
+ "TensorSpec",
19
+ "ParamSchema",
20
+ "ParamSpec",
21
+ ]