genesis-flow 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (645) hide show
  1. genesis_flow-1.0.0.dist-info/METADATA +822 -0
  2. genesis_flow-1.0.0.dist-info/RECORD +645 -0
  3. genesis_flow-1.0.0.dist-info/WHEEL +5 -0
  4. genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
  5. genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
  6. genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
  7. mlflow/__init__.py +367 -0
  8. mlflow/__main__.py +3 -0
  9. mlflow/ag2/__init__.py +56 -0
  10. mlflow/ag2/ag2_logger.py +294 -0
  11. mlflow/anthropic/__init__.py +40 -0
  12. mlflow/anthropic/autolog.py +129 -0
  13. mlflow/anthropic/chat.py +144 -0
  14. mlflow/artifacts/__init__.py +268 -0
  15. mlflow/autogen/__init__.py +144 -0
  16. mlflow/autogen/chat.py +142 -0
  17. mlflow/azure/__init__.py +26 -0
  18. mlflow/azure/auth_handler.py +257 -0
  19. mlflow/azure/client.py +319 -0
  20. mlflow/azure/config.py +120 -0
  21. mlflow/azure/connection_factory.py +340 -0
  22. mlflow/azure/exceptions.py +27 -0
  23. mlflow/azure/stores.py +327 -0
  24. mlflow/azure/utils.py +183 -0
  25. mlflow/bedrock/__init__.py +45 -0
  26. mlflow/bedrock/_autolog.py +202 -0
  27. mlflow/bedrock/chat.py +122 -0
  28. mlflow/bedrock/stream.py +160 -0
  29. mlflow/bedrock/utils.py +43 -0
  30. mlflow/cli.py +707 -0
  31. mlflow/client.py +12 -0
  32. mlflow/config/__init__.py +56 -0
  33. mlflow/crewai/__init__.py +79 -0
  34. mlflow/crewai/autolog.py +253 -0
  35. mlflow/crewai/chat.py +29 -0
  36. mlflow/data/__init__.py +75 -0
  37. mlflow/data/artifact_dataset_sources.py +170 -0
  38. mlflow/data/code_dataset_source.py +40 -0
  39. mlflow/data/dataset.py +123 -0
  40. mlflow/data/dataset_registry.py +168 -0
  41. mlflow/data/dataset_source.py +110 -0
  42. mlflow/data/dataset_source_registry.py +219 -0
  43. mlflow/data/delta_dataset_source.py +167 -0
  44. mlflow/data/digest_utils.py +108 -0
  45. mlflow/data/evaluation_dataset.py +562 -0
  46. mlflow/data/filesystem_dataset_source.py +81 -0
  47. mlflow/data/http_dataset_source.py +145 -0
  48. mlflow/data/huggingface_dataset.py +258 -0
  49. mlflow/data/huggingface_dataset_source.py +118 -0
  50. mlflow/data/meta_dataset.py +104 -0
  51. mlflow/data/numpy_dataset.py +223 -0
  52. mlflow/data/pandas_dataset.py +231 -0
  53. mlflow/data/polars_dataset.py +352 -0
  54. mlflow/data/pyfunc_dataset_mixin.py +31 -0
  55. mlflow/data/schema.py +76 -0
  56. mlflow/data/sources.py +1 -0
  57. mlflow/data/spark_dataset.py +406 -0
  58. mlflow/data/spark_dataset_source.py +74 -0
  59. mlflow/data/spark_delta_utils.py +118 -0
  60. mlflow/data/tensorflow_dataset.py +350 -0
  61. mlflow/data/uc_volume_dataset_source.py +81 -0
  62. mlflow/db.py +27 -0
  63. mlflow/dspy/__init__.py +17 -0
  64. mlflow/dspy/autolog.py +197 -0
  65. mlflow/dspy/callback.py +398 -0
  66. mlflow/dspy/constant.py +1 -0
  67. mlflow/dspy/load.py +93 -0
  68. mlflow/dspy/save.py +393 -0
  69. mlflow/dspy/util.py +109 -0
  70. mlflow/dspy/wrapper.py +226 -0
  71. mlflow/entities/__init__.py +104 -0
  72. mlflow/entities/_mlflow_object.py +52 -0
  73. mlflow/entities/assessment.py +545 -0
  74. mlflow/entities/assessment_error.py +80 -0
  75. mlflow/entities/assessment_source.py +141 -0
  76. mlflow/entities/dataset.py +92 -0
  77. mlflow/entities/dataset_input.py +51 -0
  78. mlflow/entities/dataset_summary.py +62 -0
  79. mlflow/entities/document.py +48 -0
  80. mlflow/entities/experiment.py +109 -0
  81. mlflow/entities/experiment_tag.py +35 -0
  82. mlflow/entities/file_info.py +45 -0
  83. mlflow/entities/input_tag.py +35 -0
  84. mlflow/entities/lifecycle_stage.py +35 -0
  85. mlflow/entities/logged_model.py +228 -0
  86. mlflow/entities/logged_model_input.py +26 -0
  87. mlflow/entities/logged_model_output.py +32 -0
  88. mlflow/entities/logged_model_parameter.py +46 -0
  89. mlflow/entities/logged_model_status.py +74 -0
  90. mlflow/entities/logged_model_tag.py +33 -0
  91. mlflow/entities/metric.py +200 -0
  92. mlflow/entities/model_registry/__init__.py +29 -0
  93. mlflow/entities/model_registry/_model_registry_entity.py +13 -0
  94. mlflow/entities/model_registry/model_version.py +243 -0
  95. mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
  96. mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
  97. mlflow/entities/model_registry/model_version_search.py +25 -0
  98. mlflow/entities/model_registry/model_version_stages.py +25 -0
  99. mlflow/entities/model_registry/model_version_status.py +35 -0
  100. mlflow/entities/model_registry/model_version_tag.py +35 -0
  101. mlflow/entities/model_registry/prompt.py +73 -0
  102. mlflow/entities/model_registry/prompt_version.py +244 -0
  103. mlflow/entities/model_registry/registered_model.py +175 -0
  104. mlflow/entities/model_registry/registered_model_alias.py +35 -0
  105. mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
  106. mlflow/entities/model_registry/registered_model_search.py +25 -0
  107. mlflow/entities/model_registry/registered_model_tag.py +35 -0
  108. mlflow/entities/multipart_upload.py +74 -0
  109. mlflow/entities/param.py +49 -0
  110. mlflow/entities/run.py +97 -0
  111. mlflow/entities/run_data.py +84 -0
  112. mlflow/entities/run_info.py +188 -0
  113. mlflow/entities/run_inputs.py +59 -0
  114. mlflow/entities/run_outputs.py +43 -0
  115. mlflow/entities/run_status.py +41 -0
  116. mlflow/entities/run_tag.py +36 -0
  117. mlflow/entities/source_type.py +31 -0
  118. mlflow/entities/span.py +774 -0
  119. mlflow/entities/span_event.py +96 -0
  120. mlflow/entities/span_status.py +102 -0
  121. mlflow/entities/trace.py +317 -0
  122. mlflow/entities/trace_data.py +71 -0
  123. mlflow/entities/trace_info.py +220 -0
  124. mlflow/entities/trace_info_v2.py +162 -0
  125. mlflow/entities/trace_location.py +173 -0
  126. mlflow/entities/trace_state.py +39 -0
  127. mlflow/entities/trace_status.py +68 -0
  128. mlflow/entities/view_type.py +51 -0
  129. mlflow/environment_variables.py +866 -0
  130. mlflow/evaluation/__init__.py +16 -0
  131. mlflow/evaluation/assessment.py +369 -0
  132. mlflow/evaluation/evaluation.py +411 -0
  133. mlflow/evaluation/evaluation_tag.py +61 -0
  134. mlflow/evaluation/fluent.py +48 -0
  135. mlflow/evaluation/utils.py +201 -0
  136. mlflow/exceptions.py +213 -0
  137. mlflow/experiments.py +140 -0
  138. mlflow/gemini/__init__.py +81 -0
  139. mlflow/gemini/autolog.py +186 -0
  140. mlflow/gemini/chat.py +261 -0
  141. mlflow/genai/__init__.py +71 -0
  142. mlflow/genai/datasets/__init__.py +67 -0
  143. mlflow/genai/datasets/evaluation_dataset.py +131 -0
  144. mlflow/genai/evaluation/__init__.py +3 -0
  145. mlflow/genai/evaluation/base.py +411 -0
  146. mlflow/genai/evaluation/constant.py +23 -0
  147. mlflow/genai/evaluation/utils.py +244 -0
  148. mlflow/genai/judges/__init__.py +21 -0
  149. mlflow/genai/judges/databricks.py +404 -0
  150. mlflow/genai/label_schemas/__init__.py +153 -0
  151. mlflow/genai/label_schemas/label_schemas.py +209 -0
  152. mlflow/genai/labeling/__init__.py +159 -0
  153. mlflow/genai/labeling/labeling.py +250 -0
  154. mlflow/genai/optimize/__init__.py +13 -0
  155. mlflow/genai/optimize/base.py +198 -0
  156. mlflow/genai/optimize/optimizers/__init__.py +4 -0
  157. mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
  158. mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
  159. mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
  160. mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
  161. mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
  162. mlflow/genai/optimize/types.py +75 -0
  163. mlflow/genai/optimize/util.py +30 -0
  164. mlflow/genai/prompts/__init__.py +206 -0
  165. mlflow/genai/scheduled_scorers.py +431 -0
  166. mlflow/genai/scorers/__init__.py +26 -0
  167. mlflow/genai/scorers/base.py +492 -0
  168. mlflow/genai/scorers/builtin_scorers.py +765 -0
  169. mlflow/genai/scorers/scorer_utils.py +138 -0
  170. mlflow/genai/scorers/validation.py +165 -0
  171. mlflow/genai/utils/data_validation.py +146 -0
  172. mlflow/genai/utils/enum_utils.py +23 -0
  173. mlflow/genai/utils/trace_utils.py +211 -0
  174. mlflow/groq/__init__.py +42 -0
  175. mlflow/groq/_groq_autolog.py +74 -0
  176. mlflow/johnsnowlabs/__init__.py +888 -0
  177. mlflow/langchain/__init__.py +24 -0
  178. mlflow/langchain/api_request_parallel_processor.py +330 -0
  179. mlflow/langchain/autolog.py +147 -0
  180. mlflow/langchain/chat_agent_langgraph.py +340 -0
  181. mlflow/langchain/constant.py +1 -0
  182. mlflow/langchain/constants.py +1 -0
  183. mlflow/langchain/databricks_dependencies.py +444 -0
  184. mlflow/langchain/langchain_tracer.py +597 -0
  185. mlflow/langchain/model.py +919 -0
  186. mlflow/langchain/output_parsers.py +142 -0
  187. mlflow/langchain/retriever_chain.py +153 -0
  188. mlflow/langchain/runnables.py +527 -0
  189. mlflow/langchain/utils/chat.py +402 -0
  190. mlflow/langchain/utils/logging.py +671 -0
  191. mlflow/langchain/utils/serialization.py +36 -0
  192. mlflow/legacy_databricks_cli/__init__.py +0 -0
  193. mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
  194. mlflow/legacy_databricks_cli/configure/provider.py +482 -0
  195. mlflow/litellm/__init__.py +175 -0
  196. mlflow/llama_index/__init__.py +22 -0
  197. mlflow/llama_index/autolog.py +55 -0
  198. mlflow/llama_index/chat.py +43 -0
  199. mlflow/llama_index/constant.py +1 -0
  200. mlflow/llama_index/model.py +577 -0
  201. mlflow/llama_index/pyfunc_wrapper.py +332 -0
  202. mlflow/llama_index/serialize_objects.py +188 -0
  203. mlflow/llama_index/tracer.py +561 -0
  204. mlflow/metrics/__init__.py +479 -0
  205. mlflow/metrics/base.py +39 -0
  206. mlflow/metrics/genai/__init__.py +25 -0
  207. mlflow/metrics/genai/base.py +101 -0
  208. mlflow/metrics/genai/genai_metric.py +771 -0
  209. mlflow/metrics/genai/metric_definitions.py +450 -0
  210. mlflow/metrics/genai/model_utils.py +371 -0
  211. mlflow/metrics/genai/prompt_template.py +68 -0
  212. mlflow/metrics/genai/prompts/__init__.py +0 -0
  213. mlflow/metrics/genai/prompts/v1.py +422 -0
  214. mlflow/metrics/genai/utils.py +6 -0
  215. mlflow/metrics/metric_definitions.py +619 -0
  216. mlflow/mismatch.py +34 -0
  217. mlflow/mistral/__init__.py +34 -0
  218. mlflow/mistral/autolog.py +71 -0
  219. mlflow/mistral/chat.py +135 -0
  220. mlflow/ml_package_versions.py +452 -0
  221. mlflow/models/__init__.py +97 -0
  222. mlflow/models/auth_policy.py +83 -0
  223. mlflow/models/cli.py +354 -0
  224. mlflow/models/container/__init__.py +294 -0
  225. mlflow/models/container/scoring_server/__init__.py +0 -0
  226. mlflow/models/container/scoring_server/nginx.conf +39 -0
  227. mlflow/models/dependencies_schemas.py +287 -0
  228. mlflow/models/display_utils.py +158 -0
  229. mlflow/models/docker_utils.py +211 -0
  230. mlflow/models/evaluation/__init__.py +23 -0
  231. mlflow/models/evaluation/_shap_patch.py +64 -0
  232. mlflow/models/evaluation/artifacts.py +194 -0
  233. mlflow/models/evaluation/base.py +1811 -0
  234. mlflow/models/evaluation/calibration_curve.py +109 -0
  235. mlflow/models/evaluation/default_evaluator.py +996 -0
  236. mlflow/models/evaluation/deprecated.py +23 -0
  237. mlflow/models/evaluation/evaluator_registry.py +80 -0
  238. mlflow/models/evaluation/evaluators/classifier.py +704 -0
  239. mlflow/models/evaluation/evaluators/default.py +233 -0
  240. mlflow/models/evaluation/evaluators/regressor.py +96 -0
  241. mlflow/models/evaluation/evaluators/shap.py +296 -0
  242. mlflow/models/evaluation/lift_curve.py +178 -0
  243. mlflow/models/evaluation/utils/metric.py +123 -0
  244. mlflow/models/evaluation/utils/trace.py +179 -0
  245. mlflow/models/evaluation/validation.py +434 -0
  246. mlflow/models/flavor_backend.py +93 -0
  247. mlflow/models/flavor_backend_registry.py +53 -0
  248. mlflow/models/model.py +1639 -0
  249. mlflow/models/model_config.py +150 -0
  250. mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
  251. mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
  252. mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
  253. mlflow/models/python_api.py +369 -0
  254. mlflow/models/rag_signatures.py +128 -0
  255. mlflow/models/resources.py +321 -0
  256. mlflow/models/signature.py +662 -0
  257. mlflow/models/utils.py +2054 -0
  258. mlflow/models/wheeled_model.py +280 -0
  259. mlflow/openai/__init__.py +57 -0
  260. mlflow/openai/_agent_tracer.py +364 -0
  261. mlflow/openai/api_request_parallel_processor.py +131 -0
  262. mlflow/openai/autolog.py +509 -0
  263. mlflow/openai/constant.py +1 -0
  264. mlflow/openai/model.py +824 -0
  265. mlflow/openai/utils/chat_schema.py +367 -0
  266. mlflow/optuna/__init__.py +3 -0
  267. mlflow/optuna/storage.py +646 -0
  268. mlflow/plugins/__init__.py +72 -0
  269. mlflow/plugins/base.py +358 -0
  270. mlflow/plugins/builtin/__init__.py +24 -0
  271. mlflow/plugins/builtin/pytorch_plugin.py +150 -0
  272. mlflow/plugins/builtin/sklearn_plugin.py +158 -0
  273. mlflow/plugins/builtin/transformers_plugin.py +187 -0
  274. mlflow/plugins/cli.py +321 -0
  275. mlflow/plugins/discovery.py +340 -0
  276. mlflow/plugins/manager.py +465 -0
  277. mlflow/plugins/registry.py +316 -0
  278. mlflow/plugins/templates/framework_plugin_template.py +329 -0
  279. mlflow/prompt/constants.py +20 -0
  280. mlflow/prompt/promptlab_model.py +197 -0
  281. mlflow/prompt/registry_utils.py +248 -0
  282. mlflow/promptflow/__init__.py +495 -0
  283. mlflow/protos/__init__.py +0 -0
  284. mlflow/protos/assessments_pb2.py +174 -0
  285. mlflow/protos/databricks_artifacts_pb2.py +489 -0
  286. mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
  287. mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
  288. mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
  289. mlflow/protos/databricks_pb2.py +267 -0
  290. mlflow/protos/databricks_trace_server_pb2.py +374 -0
  291. mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
  292. mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
  293. mlflow/protos/facet_feature_statistics_pb2.py +296 -0
  294. mlflow/protos/internal_pb2.py +77 -0
  295. mlflow/protos/mlflow_artifacts_pb2.py +336 -0
  296. mlflow/protos/model_registry_pb2.py +1073 -0
  297. mlflow/protos/scalapb/__init__.py +0 -0
  298. mlflow/protos/scalapb/scalapb_pb2.py +104 -0
  299. mlflow/protos/service_pb2.py +2600 -0
  300. mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
  301. mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
  302. mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
  303. mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
  304. mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
  305. mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
  306. mlflow/py.typed +0 -0
  307. mlflow/pydantic_ai/__init__.py +57 -0
  308. mlflow/pydantic_ai/autolog.py +173 -0
  309. mlflow/pyfunc/__init__.py +3844 -0
  310. mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
  311. mlflow/pyfunc/backend.py +523 -0
  312. mlflow/pyfunc/context.py +78 -0
  313. mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
  314. mlflow/pyfunc/loaders/__init__.py +7 -0
  315. mlflow/pyfunc/loaders/chat_agent.py +117 -0
  316. mlflow/pyfunc/loaders/chat_model.py +125 -0
  317. mlflow/pyfunc/loaders/code_model.py +31 -0
  318. mlflow/pyfunc/loaders/responses_agent.py +112 -0
  319. mlflow/pyfunc/mlserver.py +46 -0
  320. mlflow/pyfunc/model.py +1473 -0
  321. mlflow/pyfunc/scoring_server/__init__.py +604 -0
  322. mlflow/pyfunc/scoring_server/app.py +7 -0
  323. mlflow/pyfunc/scoring_server/client.py +146 -0
  324. mlflow/pyfunc/spark_model_cache.py +48 -0
  325. mlflow/pyfunc/stdin_server.py +44 -0
  326. mlflow/pyfunc/utils/__init__.py +3 -0
  327. mlflow/pyfunc/utils/data_validation.py +224 -0
  328. mlflow/pyfunc/utils/environment.py +22 -0
  329. mlflow/pyfunc/utils/input_converter.py +47 -0
  330. mlflow/pyfunc/utils/serving_data_parser.py +11 -0
  331. mlflow/pytorch/__init__.py +1171 -0
  332. mlflow/pytorch/_lightning_autolog.py +580 -0
  333. mlflow/pytorch/_pytorch_autolog.py +50 -0
  334. mlflow/pytorch/pickle_module.py +35 -0
  335. mlflow/rfunc/__init__.py +42 -0
  336. mlflow/rfunc/backend.py +134 -0
  337. mlflow/runs.py +89 -0
  338. mlflow/server/__init__.py +302 -0
  339. mlflow/server/auth/__init__.py +1224 -0
  340. mlflow/server/auth/__main__.py +4 -0
  341. mlflow/server/auth/basic_auth.ini +6 -0
  342. mlflow/server/auth/cli.py +11 -0
  343. mlflow/server/auth/client.py +537 -0
  344. mlflow/server/auth/config.py +34 -0
  345. mlflow/server/auth/db/__init__.py +0 -0
  346. mlflow/server/auth/db/cli.py +18 -0
  347. mlflow/server/auth/db/migrations/__init__.py +0 -0
  348. mlflow/server/auth/db/migrations/alembic.ini +110 -0
  349. mlflow/server/auth/db/migrations/env.py +76 -0
  350. mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
  351. mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
  352. mlflow/server/auth/db/models.py +67 -0
  353. mlflow/server/auth/db/utils.py +37 -0
  354. mlflow/server/auth/entities.py +165 -0
  355. mlflow/server/auth/logo.py +14 -0
  356. mlflow/server/auth/permissions.py +65 -0
  357. mlflow/server/auth/routes.py +18 -0
  358. mlflow/server/auth/sqlalchemy_store.py +263 -0
  359. mlflow/server/graphql/__init__.py +0 -0
  360. mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
  361. mlflow/server/graphql/graphql_custom_scalars.py +24 -0
  362. mlflow/server/graphql/graphql_errors.py +15 -0
  363. mlflow/server/graphql/graphql_no_batching.py +89 -0
  364. mlflow/server/graphql/graphql_schema_extensions.py +74 -0
  365. mlflow/server/handlers.py +3217 -0
  366. mlflow/server/prometheus_exporter.py +17 -0
  367. mlflow/server/validation.py +30 -0
  368. mlflow/shap/__init__.py +691 -0
  369. mlflow/sklearn/__init__.py +1994 -0
  370. mlflow/sklearn/utils.py +1041 -0
  371. mlflow/smolagents/__init__.py +66 -0
  372. mlflow/smolagents/autolog.py +139 -0
  373. mlflow/smolagents/chat.py +29 -0
  374. mlflow/store/__init__.py +10 -0
  375. mlflow/store/_unity_catalog/__init__.py +1 -0
  376. mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
  377. mlflow/store/_unity_catalog/lineage/constants.py +2 -0
  378. mlflow/store/_unity_catalog/registry/__init__.py +6 -0
  379. mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
  380. mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
  381. mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
  382. mlflow/store/_unity_catalog/registry/utils.py +121 -0
  383. mlflow/store/artifact/__init__.py +0 -0
  384. mlflow/store/artifact/artifact_repo.py +472 -0
  385. mlflow/store/artifact/artifact_repository_registry.py +154 -0
  386. mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
  387. mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
  388. mlflow/store/artifact/cli.py +141 -0
  389. mlflow/store/artifact/cloud_artifact_repo.py +332 -0
  390. mlflow/store/artifact/databricks_artifact_repo.py +729 -0
  391. mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
  392. mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
  393. mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
  394. mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
  395. mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
  396. mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
  397. mlflow/store/artifact/ftp_artifact_repo.py +132 -0
  398. mlflow/store/artifact/gcs_artifact_repo.py +296 -0
  399. mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
  400. mlflow/store/artifact/http_artifact_repo.py +218 -0
  401. mlflow/store/artifact/local_artifact_repo.py +142 -0
  402. mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
  403. mlflow/store/artifact/models_artifact_repo.py +259 -0
  404. mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
  405. mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
  406. mlflow/store/artifact/r2_artifact_repo.py +70 -0
  407. mlflow/store/artifact/runs_artifact_repo.py +265 -0
  408. mlflow/store/artifact/s3_artifact_repo.py +330 -0
  409. mlflow/store/artifact/sftp_artifact_repo.py +141 -0
  410. mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
  411. mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
  412. mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
  413. mlflow/store/artifact/utils/__init__.py +0 -0
  414. mlflow/store/artifact/utils/models.py +148 -0
  415. mlflow/store/db/__init__.py +0 -0
  416. mlflow/store/db/base_sql_model.py +3 -0
  417. mlflow/store/db/db_types.py +10 -0
  418. mlflow/store/db/utils.py +314 -0
  419. mlflow/store/db_migrations/__init__.py +0 -0
  420. mlflow/store/db_migrations/alembic.ini +74 -0
  421. mlflow/store/db_migrations/env.py +84 -0
  422. mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
  423. mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
  424. mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
  425. mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
  426. mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
  427. mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
  428. mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
  429. mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
  430. mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
  431. mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
  432. mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
  433. mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
  434. mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
  435. mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
  436. mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
  437. mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
  438. mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
  439. mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
  440. mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
  441. mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
  442. mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
  443. mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
  444. mlflow/store/db_migrations/versions/__init__.py +0 -0
  445. mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
  446. mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
  447. mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
  448. mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
  449. mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
  450. mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
  451. mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
  452. mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
  453. mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
  454. mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
  455. mlflow/store/entities/__init__.py +3 -0
  456. mlflow/store/entities/paged_list.py +18 -0
  457. mlflow/store/model_registry/__init__.py +10 -0
  458. mlflow/store/model_registry/abstract_store.py +1081 -0
  459. mlflow/store/model_registry/base_rest_store.py +44 -0
  460. mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
  461. mlflow/store/model_registry/dbmodels/__init__.py +0 -0
  462. mlflow/store/model_registry/dbmodels/models.py +206 -0
  463. mlflow/store/model_registry/file_store.py +1091 -0
  464. mlflow/store/model_registry/rest_store.py +481 -0
  465. mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
  466. mlflow/store/tracking/__init__.py +23 -0
  467. mlflow/store/tracking/abstract_store.py +816 -0
  468. mlflow/store/tracking/dbmodels/__init__.py +0 -0
  469. mlflow/store/tracking/dbmodels/initial_models.py +243 -0
  470. mlflow/store/tracking/dbmodels/models.py +1073 -0
  471. mlflow/store/tracking/file_store.py +2438 -0
  472. mlflow/store/tracking/postgres_managed_identity.py +146 -0
  473. mlflow/store/tracking/rest_store.py +1131 -0
  474. mlflow/store/tracking/sqlalchemy_store.py +2785 -0
  475. mlflow/system_metrics/__init__.py +61 -0
  476. mlflow/system_metrics/metrics/__init__.py +0 -0
  477. mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
  478. mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
  479. mlflow/system_metrics/metrics/disk_monitor.py +21 -0
  480. mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
  481. mlflow/system_metrics/metrics/network_monitor.py +34 -0
  482. mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
  483. mlflow/system_metrics/system_metrics_monitor.py +198 -0
  484. mlflow/tracing/__init__.py +16 -0
  485. mlflow/tracing/assessment.py +356 -0
  486. mlflow/tracing/client.py +531 -0
  487. mlflow/tracing/config.py +125 -0
  488. mlflow/tracing/constant.py +105 -0
  489. mlflow/tracing/destination.py +81 -0
  490. mlflow/tracing/display/__init__.py +40 -0
  491. mlflow/tracing/display/display_handler.py +196 -0
  492. mlflow/tracing/export/async_export_queue.py +186 -0
  493. mlflow/tracing/export/inference_table.py +138 -0
  494. mlflow/tracing/export/mlflow_v3.py +137 -0
  495. mlflow/tracing/export/utils.py +70 -0
  496. mlflow/tracing/fluent.py +1417 -0
  497. mlflow/tracing/processor/base_mlflow.py +199 -0
  498. mlflow/tracing/processor/inference_table.py +175 -0
  499. mlflow/tracing/processor/mlflow_v3.py +47 -0
  500. mlflow/tracing/processor/otel.py +73 -0
  501. mlflow/tracing/provider.py +487 -0
  502. mlflow/tracing/trace_manager.py +200 -0
  503. mlflow/tracing/utils/__init__.py +616 -0
  504. mlflow/tracing/utils/artifact_utils.py +28 -0
  505. mlflow/tracing/utils/copy.py +55 -0
  506. mlflow/tracing/utils/environment.py +55 -0
  507. mlflow/tracing/utils/exception.py +21 -0
  508. mlflow/tracing/utils/once.py +35 -0
  509. mlflow/tracing/utils/otlp.py +63 -0
  510. mlflow/tracing/utils/processor.py +54 -0
  511. mlflow/tracing/utils/search.py +292 -0
  512. mlflow/tracing/utils/timeout.py +250 -0
  513. mlflow/tracing/utils/token.py +19 -0
  514. mlflow/tracing/utils/truncation.py +124 -0
  515. mlflow/tracing/utils/warning.py +76 -0
  516. mlflow/tracking/__init__.py +39 -0
  517. mlflow/tracking/_model_registry/__init__.py +1 -0
  518. mlflow/tracking/_model_registry/client.py +764 -0
  519. mlflow/tracking/_model_registry/fluent.py +853 -0
  520. mlflow/tracking/_model_registry/registry.py +67 -0
  521. mlflow/tracking/_model_registry/utils.py +251 -0
  522. mlflow/tracking/_tracking_service/__init__.py +0 -0
  523. mlflow/tracking/_tracking_service/client.py +883 -0
  524. mlflow/tracking/_tracking_service/registry.py +56 -0
  525. mlflow/tracking/_tracking_service/utils.py +275 -0
  526. mlflow/tracking/artifact_utils.py +179 -0
  527. mlflow/tracking/client.py +5900 -0
  528. mlflow/tracking/context/__init__.py +0 -0
  529. mlflow/tracking/context/abstract_context.py +35 -0
  530. mlflow/tracking/context/databricks_cluster_context.py +15 -0
  531. mlflow/tracking/context/databricks_command_context.py +15 -0
  532. mlflow/tracking/context/databricks_job_context.py +49 -0
  533. mlflow/tracking/context/databricks_notebook_context.py +41 -0
  534. mlflow/tracking/context/databricks_repo_context.py +43 -0
  535. mlflow/tracking/context/default_context.py +51 -0
  536. mlflow/tracking/context/git_context.py +32 -0
  537. mlflow/tracking/context/registry.py +98 -0
  538. mlflow/tracking/context/system_environment_context.py +15 -0
  539. mlflow/tracking/default_experiment/__init__.py +1 -0
  540. mlflow/tracking/default_experiment/abstract_context.py +43 -0
  541. mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
  542. mlflow/tracking/default_experiment/registry.py +75 -0
  543. mlflow/tracking/fluent.py +3595 -0
  544. mlflow/tracking/metric_value_conversion_utils.py +93 -0
  545. mlflow/tracking/multimedia.py +206 -0
  546. mlflow/tracking/registry.py +86 -0
  547. mlflow/tracking/request_auth/__init__.py +0 -0
  548. mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
  549. mlflow/tracking/request_auth/registry.py +60 -0
  550. mlflow/tracking/request_header/__init__.py +0 -0
  551. mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
  552. mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
  553. mlflow/tracking/request_header/default_request_header_provider.py +17 -0
  554. mlflow/tracking/request_header/registry.py +79 -0
  555. mlflow/transformers/__init__.py +2982 -0
  556. mlflow/transformers/flavor_config.py +258 -0
  557. mlflow/transformers/hub_utils.py +83 -0
  558. mlflow/transformers/llm_inference_utils.py +468 -0
  559. mlflow/transformers/model_io.py +301 -0
  560. mlflow/transformers/peft.py +51 -0
  561. mlflow/transformers/signature.py +183 -0
  562. mlflow/transformers/torch_utils.py +55 -0
  563. mlflow/types/__init__.py +21 -0
  564. mlflow/types/agent.py +270 -0
  565. mlflow/types/chat.py +240 -0
  566. mlflow/types/llm.py +935 -0
  567. mlflow/types/responses.py +139 -0
  568. mlflow/types/responses_helpers.py +416 -0
  569. mlflow/types/schema.py +1505 -0
  570. mlflow/types/type_hints.py +647 -0
  571. mlflow/types/utils.py +753 -0
  572. mlflow/utils/__init__.py +283 -0
  573. mlflow/utils/_capture_modules.py +256 -0
  574. mlflow/utils/_capture_transformers_modules.py +75 -0
  575. mlflow/utils/_spark_utils.py +201 -0
  576. mlflow/utils/_unity_catalog_oss_utils.py +97 -0
  577. mlflow/utils/_unity_catalog_utils.py +479 -0
  578. mlflow/utils/annotations.py +218 -0
  579. mlflow/utils/arguments_utils.py +16 -0
  580. mlflow/utils/async_logging/__init__.py +1 -0
  581. mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
  582. mlflow/utils/async_logging/async_logging_queue.py +366 -0
  583. mlflow/utils/async_logging/run_artifact.py +38 -0
  584. mlflow/utils/async_logging/run_batch.py +58 -0
  585. mlflow/utils/async_logging/run_operations.py +49 -0
  586. mlflow/utils/autologging_utils/__init__.py +737 -0
  587. mlflow/utils/autologging_utils/client.py +432 -0
  588. mlflow/utils/autologging_utils/config.py +33 -0
  589. mlflow/utils/autologging_utils/events.py +294 -0
  590. mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
  591. mlflow/utils/autologging_utils/metrics_queue.py +71 -0
  592. mlflow/utils/autologging_utils/safety.py +1104 -0
  593. mlflow/utils/autologging_utils/versioning.py +95 -0
  594. mlflow/utils/checkpoint_utils.py +206 -0
  595. mlflow/utils/class_utils.py +6 -0
  596. mlflow/utils/cli_args.py +257 -0
  597. mlflow/utils/conda.py +354 -0
  598. mlflow/utils/credentials.py +231 -0
  599. mlflow/utils/data_utils.py +17 -0
  600. mlflow/utils/databricks_utils.py +1436 -0
  601. mlflow/utils/docstring_utils.py +477 -0
  602. mlflow/utils/doctor.py +133 -0
  603. mlflow/utils/download_cloud_file_chunk.py +43 -0
  604. mlflow/utils/env_manager.py +16 -0
  605. mlflow/utils/env_pack.py +131 -0
  606. mlflow/utils/environment.py +1009 -0
  607. mlflow/utils/exception_utils.py +14 -0
  608. mlflow/utils/file_utils.py +978 -0
  609. mlflow/utils/git_utils.py +77 -0
  610. mlflow/utils/gorilla.py +797 -0
  611. mlflow/utils/import_hooks/__init__.py +363 -0
  612. mlflow/utils/lazy_load.py +51 -0
  613. mlflow/utils/logging_utils.py +168 -0
  614. mlflow/utils/mime_type_utils.py +58 -0
  615. mlflow/utils/mlflow_tags.py +103 -0
  616. mlflow/utils/model_utils.py +486 -0
  617. mlflow/utils/name_utils.py +346 -0
  618. mlflow/utils/nfs_on_spark.py +62 -0
  619. mlflow/utils/openai_utils.py +164 -0
  620. mlflow/utils/os.py +12 -0
  621. mlflow/utils/oss_registry_utils.py +29 -0
  622. mlflow/utils/plugins.py +17 -0
  623. mlflow/utils/process.py +182 -0
  624. mlflow/utils/promptlab_utils.py +146 -0
  625. mlflow/utils/proto_json_utils.py +743 -0
  626. mlflow/utils/pydantic_utils.py +54 -0
  627. mlflow/utils/request_utils.py +279 -0
  628. mlflow/utils/requirements_utils.py +704 -0
  629. mlflow/utils/rest_utils.py +673 -0
  630. mlflow/utils/search_logged_model_utils.py +127 -0
  631. mlflow/utils/search_utils.py +2111 -0
  632. mlflow/utils/secure_loading.py +221 -0
  633. mlflow/utils/security_validation.py +384 -0
  634. mlflow/utils/server_cli_utils.py +61 -0
  635. mlflow/utils/spark_utils.py +15 -0
  636. mlflow/utils/string_utils.py +138 -0
  637. mlflow/utils/thread_utils.py +63 -0
  638. mlflow/utils/time.py +54 -0
  639. mlflow/utils/timeout.py +42 -0
  640. mlflow/utils/uri.py +572 -0
  641. mlflow/utils/validation.py +662 -0
  642. mlflow/utils/virtualenv.py +458 -0
  643. mlflow/utils/warnings_utils.py +25 -0
  644. mlflow/utils/yaml_utils.py +179 -0
  645. mlflow/version.py +24 -0
@@ -0,0 +1,3844 @@
1
+ """
2
+ The ``python_function`` model flavor serves as a default model interface for MLflow Python models.
3
+ Any MLflow Python model is expected to be loadable as a ``python_function`` model.
4
+
5
+ In addition, the ``mlflow.pyfunc`` module defines a generic :ref:`filesystem format
6
+ <pyfunc-filesystem-format>` for Python models and provides utilities for saving to and loading from
7
+ this format. The format is self contained in the sense that it includes all necessary information
8
+ for anyone to load it and use it. Dependencies are either stored directly with the model or
9
+ referenced via a Conda environment.
10
+
11
+ The ``mlflow.pyfunc`` module also defines utilities for creating custom ``pyfunc`` models
12
+ using frameworks and inference logic that may not be natively included in MLflow. See
13
+ :ref:`pyfunc-create-custom`.
14
+
15
+ .. _pyfunc-inference-api:
16
+
17
+ *************
18
+ Inference API
19
+ *************
20
+
21
+ Python function models are loaded as an instance of :py:class:`PyFuncModel
22
+ <mlflow.pyfunc.PyFuncModel>`, which is an MLflow wrapper around the model implementation and model
23
+ metadata (MLmodel file). You can score the model by calling the :py:func:`predict()
24
+ <mlflow.pyfunc.PyFuncModel.predict>` method, which has the following signature::
25
+
26
+ predict(
27
+ model_input: [pandas.DataFrame, numpy.ndarray, scipy.sparse.(csc_matrix | csr_matrix),
28
+ List[Any], Dict[str, Any], pyspark.sql.DataFrame]
29
+ ) -> [numpy.ndarray | pandas.(Series | DataFrame) | List | Dict | pyspark.sql.DataFrame]
30
+
31
+ All PyFunc models will support `pandas.DataFrame` as input and PyFunc deep learning models will
32
+ also support tensor inputs in the form of Dict[str, numpy.ndarray] (named tensors) and
33
+ `numpy.ndarrays` (unnamed tensors).
34
+
35
+ Here are some examples of supported inference types, assuming we have the correct ``model`` object
36
+ loaded.
37
+
38
+ .. list-table::
39
+ :widths: 30 70
40
+ :header-rows: 1
41
+ :class: wrap-table
42
+
43
+ * - Input Type
44
+ - Example
45
+ * - ``pandas.DataFrame``
46
+ -
47
+ .. code-block:: python
48
+
49
+ import pandas as pd
50
+
51
+ x_new = pd.DataFrame(dict(x1=[1, 2, 3], x2=[4, 5, 6]))
52
+ model.predict(x_new)
53
+
54
+ * - ``numpy.ndarray``
55
+ -
56
+ .. code-block:: python
57
+
58
+ import numpy as np
59
+
60
+             x_new = np.array([[1, 4], [2, 5], [3, 6]])
61
+ model.predict(x_new)
62
+
63
+ * - ``scipy.sparse.csc_matrix`` or ``scipy.sparse.csr_matrix``
64
+ -
65
+ .. code-block:: python
66
+
67
+ import scipy
68
+
69
+ x_new = scipy.sparse.csc_matrix([[1, 2, 3], [4, 5, 6]])
70
+ model.predict(x_new)
71
+
72
+ x_new = scipy.sparse.csr_matrix([[1, 2, 3], [4, 5, 6]])
73
+ model.predict(x_new)
74
+
75
+ * - python ``List``
76
+ -
77
+ .. code-block:: python
78
+
79
+ x_new = [[1, 4], [2, 5], [3, 6]]
80
+ model.predict(x_new)
81
+
82
+ * - python ``Dict``
83
+ -
84
+ .. code-block:: python
85
+
86
+ x_new = dict(x1=[1, 2, 3], x2=[4, 5, 6])
87
+ model.predict(x_new)
88
+
89
+ * - ``pyspark.sql.DataFrame``
90
+ -
91
+ .. code-block:: python
92
+
93
+ from pyspark.sql import SparkSession
94
+
95
+ spark = SparkSession.builder.getOrCreate()
96
+
97
+ data = [(1, 4), (2, 5), (3, 6)] # List of tuples
98
+ x_new = spark.createDataFrame(data, ["x1", "x2"]) # Specify column name
99
+ model.predict(x_new)
100
+
101
+ .. _pyfunc-filesystem-format:
102
+
103
+ *****************
104
+ Filesystem format
105
+ *****************
106
+
107
+ The Pyfunc format is defined as a directory structure containing all required data, code, and
108
+ configuration::
109
+
110
+ ./dst-path/
111
+ ./MLmodel: configuration
112
+ <code>: code packaged with the model (specified in the MLmodel file)
113
+ <data>: data packaged with the model (specified in the MLmodel file)
114
+ <env>: Conda environment definition (specified in the MLmodel file)
115
+
116
+ The directory structure may contain additional contents that can be referenced by the ``MLmodel``
117
+ configuration.
118
+
119
+ .. _pyfunc-model-config:
120
+
121
+ MLModel configuration
122
+ #####################
123
+
124
+ A Python model contains an ``MLmodel`` file in **python_function** format in its root with the
125
+ following parameters:
126
+
127
+ - loader_module [required]:
128
+ Python module that can load the model. Expected as module identifier
129
+ e.g. ``mlflow.sklearn``, it will be imported using ``importlib.import_module``.
130
+ The imported module must contain a function with the following signature::
131
+
132
+ _load_pyfunc(path: string) -> <pyfunc model implementation>
133
+
134
+ The path argument is specified by the ``data`` parameter and may refer to a file or
135
+ directory. The model implementation is expected to be an object with a
136
+ ``predict`` method with the following signature::
137
+
138
+ predict(
139
+ model_input: [pandas.DataFrame, numpy.ndarray,
140
+ scipy.sparse.(csc_matrix | csr_matrix), List[Any], Dict[str, Any]],
141
+ pyspark.sql.DataFrame
142
+ ) -> [numpy.ndarray | pandas.(Series | DataFrame) | List | Dict | pyspark.sql.DataFrame]
143
+
144
+ - code [optional]:
145
+ Relative path to a directory containing the code packaged with this model.
146
+ All files and directories inside this directory are added to the Python path
147
+ prior to importing the model loader.
148
+
149
+ - data [optional]:
150
+ Relative path to a file or directory containing model data.
151
+ The path is passed to the model loader.
152
+
153
+ - env [optional]:
154
+ Relative path to an exported Conda environment. If present this environment
155
+ should be activated prior to running the model.
156
+
157
+ - Optionally, any additional parameters necessary for interpreting the serialized model in
158
+ ``pyfunc`` format.
159
+
160
+ .. rubric:: Example
161
+
162
+ ::
163
+
164
+ tree example/sklearn_iris/mlruns/run1/outputs/linear-lr
165
+
166
+ ::
167
+
168
+ ├── MLmodel
169
+ ├── code
170
+ │ ├── sklearn_iris.py
171
+
172
+ ├── data
173
+ │ └── model.pkl
174
+ └── mlflow_env.yml
175
+
176
+ ::
177
+
178
+ cat example/sklearn_iris/mlruns/run1/outputs/linear-lr/MLmodel
179
+
180
+ ::
181
+
182
+ python_function:
183
+ code: code
184
+ data: data/model.pkl
185
+ loader_module: mlflow.sklearn
186
+ env: mlflow_env.yml
187
+ main: sklearn_iris
188
+
189
+ .. _pyfunc-create-custom:
190
+
191
+ **********************************
192
+ Models From Code for Custom Models
193
+ **********************************
194
+
195
+ .. tip::
196
+
197
+ MLflow 2.12.2 introduced the feature "models from code", which greatly simplifies the process
198
+ of serializing and deploying custom models through the use of script serialization. It is
199
+ strongly recommended to migrate custom model implementations to this new paradigm to avoid the
200
+ limitations and complexity of serializing with cloudpickle.
201
+ You can learn more about models from code within the
202
+ `Models From Code Guide <../model/models-from-code.html>`_.
203
+
204
+ The section below illustrates the process of using the legacy serializer for custom Pyfunc models.
205
+ Models from code will provide a far simpler experience for logging of your models.
206
+
207
+ ******************************
208
+ Creating custom Pyfunc models
209
+ ******************************
210
+
211
+ MLflow's persistence modules provide convenience functions for creating models with the
212
+ ``pyfunc`` flavor in a variety of machine learning frameworks (scikit-learn, Keras, Pytorch, and
213
+ more); however, they do not cover every use case. For example, you may want to create an MLflow
214
+ model with the ``pyfunc`` flavor using a framework that MLflow does not natively support.
215
+ Alternatively, you may want to build an MLflow model that executes custom logic when evaluating
216
+ queries, such as preprocessing and postprocessing routines. Therefore, ``mlflow.pyfunc``
217
+ provides utilities for creating ``pyfunc`` models from arbitrary code and model data.
218
+
219
+ The :meth:`save_model()` and :meth:`log_model()` methods are designed to support multiple workflows
220
+ for creating custom ``pyfunc`` models that incorporate custom inference logic and artifacts
221
+ that the logic may require.
222
+
223
+ An `artifact` is a file or directory, such as a serialized model or a CSV. For example, a
224
+ serialized TensorFlow graph is an artifact. An MLflow model directory is also an artifact.
225
+
226
+ .. _pyfunc-create-custom-workflows:
227
+
228
+ Workflows
229
+ #########
230
+
231
+ :meth:`save_model()` and :meth:`log_model()` support the following workflows:
232
+
233
+ 1. Programmatically defining a new MLflow model, including its attributes and artifacts.
234
+
235
+ Given a set of artifact URIs, :meth:`save_model()` and :meth:`log_model()` can
236
+ automatically download artifacts from their URIs and create an MLflow model directory.
237
+
238
+ In this case, you must define a Python class which inherits from :class:`~PythonModel`,
239
+ defining ``predict()`` and, optionally, ``load_context()``. An instance of this class is
240
+ specified via the ``python_model`` parameter; it is automatically serialized and deserialized
241
+ as a Python class, including all of its attributes.
242
+
243
+ 2. Interpreting pre-existing data as an MLflow model.
244
+
245
+ If you already have a directory containing model data, :meth:`save_model()` and
246
+ :meth:`log_model()` can import the data as an MLflow model. The ``data_path`` parameter
247
+ specifies the local filesystem path to the directory containing model data.
248
+
249
+ In this case, you must provide a Python module, called a `loader module`. The
250
+ loader module defines a ``_load_pyfunc()`` method that performs the following tasks:
251
+
252
+ - Load data from the specified ``data_path``. For example, this process may include
253
+ deserializing pickled Python objects or models or parsing CSV files.
254
+
255
+ - Construct and return a pyfunc-compatible model wrapper. As in the first
256
+ use case, this wrapper must define a ``predict()`` method that is used to evaluate
257
+ queries. ``predict()`` must adhere to the :ref:`pyfunc-inference-api`.
258
+
259
+ The ``loader_module`` parameter specifies the name of your loader module.
260
+
261
+ For an example loader module implementation, refer to the `loader module
262
+ implementation in mlflow.sklearn <https://github.com/mlflow/mlflow/blob/
263
+ 74d75109aaf2975f5026104d6125bb30f4e3f744/mlflow/sklearn.py#L200-L205>`_.
264
+
265
+ .. _pyfunc-create-custom-selecting-workflow:
266
+
267
+ Which workflow is right for my use case?
268
+ ########################################
269
+
270
+ We consider the first workflow to be more user-friendly and generally recommend it for the
271
+ following reasons:
272
+
273
+ - It automatically resolves and collects specified model artifacts.
274
+
275
+ - It automatically serializes and deserializes the ``python_model`` instance and all of
276
+ its attributes, reducing the amount of user logic that is required to load the model
277
+
278
+ - You can create Models using logic that is defined in the ``__main__`` scope. This allows
279
+ custom models to be constructed in interactive environments, such as notebooks and the Python
280
+ REPL.
281
+
282
+ You may prefer the second, lower-level workflow for the following reasons:
283
+
284
+ - Inference logic is always persisted as code, rather than a Python object. This makes logic
285
+ easier to inspect and modify later.
286
+
287
+ - If you have already collected all of your model data in a single location, the second
288
+ workflow allows it to be saved in MLflow format directly, without enumerating constituent
289
+ artifacts.
290
+
291
+ ******************************************
292
+ Function-based Model vs Class-based Model
293
+ ******************************************
294
+
295
+ When creating custom PyFunc models, you can choose between two different interfaces:
296
+ a function-based model and a class-based model. In short, a function-based model is simply a
297
+ python function that does not take additional params. The class-based model, on the other hand,
298
+ is a subclass of ``PythonModel`` that supports several required and optional
299
+ methods. If your use case is simple and fits within a single predict function, a function-based
300
+ approach is recommended. If you need more power, such as custom serialization, custom data
301
+ processing, or to override additional methods, you should use the class-based implementation.
302
+
303
+ Before looking at code examples, it's important to note that both methods are serialized via
304
+ `cloudpickle <https://github.com/cloudpipe/cloudpickle>`_. cloudpickle can serialize Python
305
+ functions, lambda functions, and locally defined classes and functions inside other functions. This
306
+ makes cloudpickle especially useful for parallel and distributed computing where code objects need
307
+ to be sent over the network to execute on remote workers, which is a common deployment paradigm for
308
+ MLflow.
309
+
310
+ That said, cloudpickle has some limitations.
311
+
312
+ - **Environment Dependency**: cloudpickle does not capture the full execution environment, so in
313
+ MLflow we must pass ``pip_requirements``, ``extra_pip_requirements``, or an ``input_example``,
314
+ the latter of which is used to infer environment dependencies. For more, refer to
315
+ `the model dependency docs <https://mlflow.org/docs/latest/model/dependencies.html>`_.
316
+
317
+ - **Object Support**: cloudpickle does not serialize objects outside of the Python data model.
318
+ Some relevant examples include raw files and database connections. If your program depends on
319
+ these, be sure to log ways to reference these objects along with your model.
320
+
321
+ Function-based Model
322
+ ####################
323
+ If you're looking to serialize a simple python function without additional dependent methods, you
324
+ can simply log a predict method via the keyword argument ``python_model``.
325
+
326
+ .. note::
327
+
328
+ Function-based model only supports a function with a single input argument. If you would like
329
+ to pass more arguments or additional inference parameters, please use the class-based model
330
+ below.
331
+
332
+ .. code-block:: python
333
+
334
+ import mlflow
335
+ import pandas as pd
336
+
337
+
338
+ # Define a simple function to log
339
+ def predict(model_input):
340
+ return model_input.apply(lambda x: x * 2)
341
+
342
+
343
+ # Save the function as a model
344
+ with mlflow.start_run():
345
+ mlflow.pyfunc.log_model(
346
+ name="model", python_model=predict, pip_requirements=["pandas"]
347
+ )
348
+ run_id = mlflow.active_run().info.run_id
349
+
350
+ # Load the model from the tracking server and perform inference
351
+ model = mlflow.pyfunc.load_model(f"runs:/{run_id}/model")
352
+ x_new = pd.Series([1, 2, 3])
353
+
354
+ prediction = model.predict(x_new)
355
+ print(prediction)
356
+
357
+
358
+ Class-based Model
359
+ #################
360
+ If you're looking to serialize a more complex object, for instance a class that handles
361
+ preprocessing, complex prediction logic, or custom serialization, you should subclass the
362
+ ``PythonModel`` class. MLflow has tutorials on building custom PyFunc models, as shown
363
+ `here <https://mlflow.org/docs/latest/traditional-ml/creating-custom-pyfunc/index.html>`_,
364
+ so instead of duplicating that information, in this example we'll recreate the above functionality
365
+ to highlight the differences. Note that this PythonModel implementation is overly complex and
366
+ we would recommend using the functional-based Model instead for this simple case.
367
+
368
+ .. code-block:: python
369
+
370
+ import mlflow
371
+ import pandas as pd
372
+
373
+
374
+ class MyModel(mlflow.pyfunc.PythonModel):
375
+ def predict(self, context, model_input, params=None):
376
+ return [x * 2 for x in model_input]
377
+
378
+
379
+ # Save the function as a model
380
+ with mlflow.start_run():
381
+ mlflow.pyfunc.log_model(
382
+ name="model", python_model=MyModel(), pip_requirements=["pandas"]
383
+ )
384
+ run_id = mlflow.active_run().info.run_id
385
+
386
+ # Load the model from the tracking server and perform inference
387
+ model = mlflow.pyfunc.load_model(f"runs:/{run_id}/model")
388
+ x_new = pd.Series([1, 2, 3])
389
+
390
+ print(f"Prediction:\n\t{model.predict(x_new)}")
391
+
392
+ The primary difference between this implementation and the function-based implementation above
393
+ is that the predict method is wrapped with a class, has the ``self`` parameter,
394
+ and has the ``params`` parameter that defaults to None. Note that function-based models don't
395
+ support additional params.
396
+
397
+ In summary, use the function-based Model when you have a simple function to serialize.
398
+ If you need more power, use the class-based model.
399
+ """
400
+
401
+ import collections
402
+ import functools
403
+ import hashlib
404
+ import importlib
405
+ import inspect
406
+ import json
407
+ import logging
408
+ import os
409
+ import shutil
410
+ import signal
411
+ import subprocess
412
+ import sys
413
+ import tempfile
414
+ import threading
415
+ import uuid
416
+ from copy import deepcopy
417
+ from pathlib import Path
418
+ from typing import Any, Iterator, Optional, Tuple, Union
419
+ from urllib.parse import urlparse
420
+
421
+ import numpy as np
422
+ import pandas
423
+ import pydantic
424
+ import yaml
425
+ from packaging.version import Version
426
+
427
+ import mlflow
428
+ import mlflow.models.signature
429
+ import mlflow.pyfunc.loaders
430
+ import mlflow.pyfunc.model
431
+ from mlflow.entities.model_registry.prompt import Prompt
432
+ from mlflow.environment_variables import (
433
+ _MLFLOW_IN_CAPTURE_MODULE_PROCESS,
434
+ _MLFLOW_TESTING,
435
+ MLFLOW_MODEL_ENV_DOWNLOADING_TEMP_DIR,
436
+ MLFLOW_SCORING_SERVER_REQUEST_TIMEOUT,
437
+ )
438
+ from mlflow.exceptions import MlflowException
439
+ from mlflow.models import Model, ModelInputExample, ModelSignature
440
+ from mlflow.models.auth_policy import AuthPolicy
441
+ from mlflow.models.dependencies_schemas import (
442
+ _clear_dependencies_schemas,
443
+ _get_dependencies_schema_from_model,
444
+ _get_dependencies_schemas,
445
+ )
446
+ from mlflow.models.flavor_backend_registry import get_flavor_backend
447
+ from mlflow.models.model import (
448
+ _DATABRICKS_FS_LOADER_MODULE,
449
+ MLMODEL_FILE_NAME,
450
+ MODEL_CODE_PATH,
451
+ MODEL_CONFIG,
452
+ )
453
+ from mlflow.models.resources import Resource, _ResourceBuilder
454
+ from mlflow.models.signature import (
455
+ _extract_type_hints,
456
+ _infer_signature_from_input_example,
457
+ _infer_signature_from_type_hints,
458
+ )
459
+ from mlflow.models.utils import (
460
+ PyFuncInput,
461
+ PyFuncLLMOutputChunk,
462
+ PyFuncLLMSingleInput,
463
+ PyFuncOutput,
464
+ _convert_llm_input_data,
465
+ _enforce_params_schema,
466
+ _enforce_schema,
467
+ _load_model_code_path,
468
+ _save_example,
469
+ _split_input_data_and_params,
470
+ _validate_and_get_model_code_path,
471
+ )
472
+ from mlflow.protos.databricks_pb2 import (
473
+ BAD_REQUEST,
474
+ INTERNAL_ERROR,
475
+ INVALID_PARAMETER_VALUE,
476
+ RESOURCE_DOES_NOT_EXIST,
477
+ )
478
+ from mlflow.protos.databricks_uc_registry_messages_pb2 import (
479
+ Entity,
480
+ Job,
481
+ LineageHeaderInfo,
482
+ Notebook,
483
+ )
484
+ from mlflow.pyfunc.context import Context, set_prediction_context
485
+ from mlflow.pyfunc.dbconnect_artifact_cache import (
486
+ DBConnectArtifactCache,
487
+ archive_directory,
488
+ extract_archive_to_dir,
489
+ )
490
+ from mlflow.pyfunc.model import (
491
+ _DEFAULT_CHAT_AGENT_METADATA_TASK,
492
+ _DEFAULT_CHAT_MODEL_METADATA_TASK,
493
+ _DEFAULT_RESPONSES_AGENT_METADATA_TASK,
494
+ ChatAgent,
495
+ ChatModel,
496
+ PythonModel,
497
+ PythonModelContext,
498
+ _FunctionPythonModel,
499
+ _log_warning_if_params_not_in_predict_signature,
500
+ _PythonModelPyfuncWrapper,
501
+ get_default_conda_env, # noqa: F401
502
+ get_default_pip_requirements,
503
+ )
504
+
505
+ try:
506
+ from mlflow.pyfunc.model import ResponsesAgent
507
+
508
+ IS_RESPONSES_AGENT_AVAILABLE = True
509
+ except ImportError:
510
+ IS_RESPONSES_AGENT_AVAILABLE = False
511
+ from mlflow.tracing.provider import trace_disabled
512
+ from mlflow.tracing.utils import _try_get_prediction_context
513
+ from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
514
+ from mlflow.tracking.artifact_utils import _download_artifact_from_uri
515
+ from mlflow.types.agent import (
516
+ CHAT_AGENT_INPUT_EXAMPLE,
517
+ CHAT_AGENT_INPUT_SCHEMA,
518
+ CHAT_AGENT_OUTPUT_SCHEMA,
519
+ ChatAgentRequest,
520
+ ChatAgentResponse,
521
+ )
522
+ from mlflow.types.llm import (
523
+ CHAT_MODEL_INPUT_EXAMPLE,
524
+ CHAT_MODEL_INPUT_SCHEMA,
525
+ CHAT_MODEL_OUTPUT_SCHEMA,
526
+ ChatCompletionResponse,
527
+ ChatMessage,
528
+ ChatParams,
529
+ )
530
+ from mlflow.types.type_hints import (
531
+ _convert_dataframe_to_example_format,
532
+ _is_example_valid_for_type_from_example,
533
+ _is_type_hint_from_example,
534
+ _signature_cannot_be_inferred_from_type_hint,
535
+ model_validate,
536
+ )
537
+ from mlflow.utils import (
538
+ PYTHON_VERSION,
539
+ _is_in_ipython_notebook,
540
+ check_port_connectivity,
541
+ databricks_utils,
542
+ find_free_port,
543
+ get_major_minor_py_version,
544
+ )
545
+ from mlflow.utils import env_manager as _EnvManager
546
+ from mlflow.utils._spark_utils import modified_environ
547
+ from mlflow.utils.annotations import deprecated, developer_stable, experimental
548
+ from mlflow.utils.databricks_utils import (
549
+ _get_databricks_serverless_env_vars,
550
+ get_dbconnect_udf_sandbox_info,
551
+ is_databricks_connect,
552
+ is_in_databricks_runtime,
553
+ is_in_databricks_serverless_runtime,
554
+ is_in_databricks_shared_cluster_runtime,
555
+ )
556
+ from mlflow.utils.docstring_utils import LOG_MODEL_PARAM_DOCS, format_docstring
557
+ from mlflow.utils.environment import (
558
+ _CONDA_ENV_FILE_NAME,
559
+ _CONSTRAINTS_FILE_NAME,
560
+ _PYTHON_ENV_FILE_NAME,
561
+ _REQUIREMENTS_FILE_NAME,
562
+ _process_conda_env,
563
+ _process_pip_requirements,
564
+ _PythonEnv,
565
+ _validate_env_arguments,
566
+ )
567
+ from mlflow.utils.file_utils import (
568
+ _copy_file_or_tree,
569
+ get_or_create_nfs_tmp_dir,
570
+ get_or_create_tmp_dir,
571
+ get_total_file_size,
572
+ write_to,
573
+ )
574
+ from mlflow.utils.mlflow_tags import MLFLOW_MODEL_IS_EXTERNAL
575
+ from mlflow.utils.model_utils import (
576
+ _add_code_from_conf_to_system_path,
577
+ _get_flavor_configuration,
578
+ _get_flavor_configuration_from_ml_model_file,
579
+ _get_overridden_pyfunc_model_config,
580
+ _validate_and_copy_file_to_directory,
581
+ _validate_and_get_model_config_from_file,
582
+ _validate_and_prepare_target_save_path,
583
+ _validate_infer_and_copy_code_paths,
584
+ _validate_pyfunc_model_config,
585
+ )
586
+ from mlflow.utils.nfs_on_spark import get_nfs_cache_root_dir
587
+ from mlflow.utils.pydantic_utils import model_dump_compat
588
+ from mlflow.utils.requirements_utils import (
589
+ _parse_requirements,
590
+ warn_dependency_requirement_mismatches,
591
+ )
592
+ from mlflow.utils.spark_utils import is_spark_connect_mode
593
+ from mlflow.utils.virtualenv import _get_python_env, _get_virtualenv_name
594
+ from mlflow.utils.warnings_utils import color_warning
595
+
596
+ try:
597
+ from pyspark.sql import DataFrame as SparkDataFrame
598
+
599
+ HAS_PYSPARK = True
600
+ except ImportError:
601
+ HAS_PYSPARK = False
602
+ FLAVOR_NAME = "python_function"
603
+ MAIN = "loader_module"
604
+ CODE = "code"
605
+ DATA = "data"
606
+ ENV = "env"
607
+ TASK = "task"
608
+
609
+ _MODEL_DATA_SUBPATH = "data"
610
+ _CHAT_PARAMS_WARNING_MESSAGE = (
611
+ "Default values for temperature, n and stream in ChatParams will be removed in the "
612
+ "next release. Specify them in the input example explicitly if needed."
613
+ )
614
+ _TYPE_FROM_EXAMPLE_ERROR_MESSAGE = (
615
+ "Input example must be provided when using TypeFromExample as type hint. "
616
+ "Fix this by passing `input_example` when logging your model. Check "
617
+ "https://mlflow.org/docs/latest/model/python_model.html#typefromexample-type-hint-usage "
618
+ "for more details."
619
+ )
620
+
621
+
622
class EnvType:
    """Namespace of supported environment types for the pyfunc ``env`` field.

    Serves purely as a constant holder; instantiation is intentionally
    disabled.
    """

    # Keys stored under the flavor configuration's ``env`` mapping.
    VIRTUALENV = "virtualenv"
    CONDA = "conda"

    def __init__(self):
        # Constant namespace only -- constructing an instance is a usage error.
        raise NotImplementedError("This class is not meant to be instantiated.")
628
+
629
+
630
+ PY_VERSION = "python_version"
631
+
632
+
633
+ _logger = logging.getLogger(__name__)
634
+
635
+
636
def add_to_model(
    model,
    loader_module,
    data=None,
    code=None,
    conda_env=None,
    python_env=None,
    model_config=None,
    model_code_path=None,
    **kwargs,
):
    """
    Add a ``pyfunc`` spec to the model configuration.

    Defines the ``pyfunc`` configuration schema. Callers can use this to turn an
    existing directory structure into a valid ``pyfunc`` model flavor; for
    example, other model flavors can use it to declare how their output is
    consumed as a ``pyfunc``.

    NOTE:

        All paths are relative to the exported model root directory.

    Args:
        model: Existing model.
        loader_module: The module to be used to load the model.
        data: Path to the model data.
        code: Path to the code dependencies.
        conda_env: Conda environment.
        python_env: Python environment.
        model_config: The model configuration to apply to the model. This configuration
            is available during model loading.

            .. Note:: Experimental: This parameter may change or be removed in a future
                                    release without warning.

        model_code_path: Path to the model code.
        kwargs: Additional key-value pairs to include in the ``pyfunc`` flavor specification.
            Values must be YAML-serializable.

    Returns:
        Updated model configuration.
    """
    # Deep-copy so the caller's kwargs dict is never mutated.
    flavor_params = deepcopy(kwargs)
    flavor_params[MAIN] = loader_module
    flavor_params[PY_VERSION] = PYTHON_VERSION

    # Optional simple entries are only recorded when truthy.
    for key, value in ((CODE, code), (DATA, data)):
        if value:
            flavor_params[key] = value

    # The env entry is a mapping keyed by environment type; omit it entirely
    # when neither environment kind was provided.
    env_spec = {}
    if conda_env:
        env_spec[EnvType.CONDA] = conda_env
    if python_env:
        env_spec[EnvType.VIRTUALENV] = python_env
    if env_spec:
        flavor_params[ENV] = env_spec

    if model_config:
        flavor_params[MODEL_CONFIG] = model_config
    if model_code_path:
        flavor_params[MODEL_CODE_PATH] = model_code_path
    return model.add_flavor(FLAVOR_NAME, **flavor_params)
696
+
697
+
698
+ def _extract_conda_env(env):
699
+ # In MLflow < 2.0.0, the 'env' field in a pyfunc configuration is a string containing the path
700
+ # to a conda.yaml file.
701
+ return env if isinstance(env, str) else env[EnvType.CONDA]
702
+
703
+
704
def _load_model_env(path):
    """Return the model-relative path to the Conda environment file.

    Reads the ``python_function`` flavor configuration stored under ``path``
    and returns its ``env`` entry, or None when no environment was specified
    at model save time.
    """
    flavor_conf = _get_flavor_configuration(model_path=path, flavor_name=FLAVOR_NAME)
    return flavor_conf.get(ENV, None)
711
+
712
+
713
+ def _validate_params(params, model_metadata):
714
+ if hasattr(model_metadata, "get_params_schema"):
715
+ params_schema = model_metadata.get_params_schema()
716
+ return _enforce_params_schema(params, params_schema)
717
+ if params:
718
+ raise MlflowException.invalid_parameter_value(
719
+ "This model was not logged with a params schema and does not support "
720
+ "providing the params argument."
721
+ "Please log the model with mlflow >= 2.6.0 and specify a params schema.",
722
+ )
723
+ return
724
+
725
+
726
def _validate_prediction_input(data: PyFuncInput, params, input_schema, params_schema, flavor=None):
    """Transform and validate input data and params prior to prediction.

    Enforces the model's input schema (when one exists) and its params
    schema. Any additional transformation logic for input data and params
    belongs here. A warning is emitted for Spark DataFrame inputs since
    their handling is model dependent.
    """
    if input_schema is not None:
        try:
            data = _enforce_schema(data, input_schema, flavor)
        except Exception as e:
            # Embed the underlying error text in the message for backwards
            # compatibility with callers that parse it.
            raise MlflowException.invalid_parameter_value(
                f"Failed to enforce schema of data '{data}' "
                f"with schema '{input_schema}'. "
                f"Error: {e}",
            )
    params = _enforce_params_schema(params, params_schema)
    is_spark_input = HAS_PYSPARK and isinstance(data, SparkDataFrame)
    if is_spark_input:
        _logger.warning(
            "Input data is a Spark DataFrame. Note that behaviour for "
            "Spark DataFrames is model dependent."
        )
    return data, params
748
+
749
+
750
+ class PyFuncModel:
751
+ """
752
+ MLflow 'python function' model.
753
+
754
+ Wrapper around model implementation and metadata. This class is not meant to be constructed
755
+ directly. Instead, instances of this class are constructed and returned from
756
+ :py:func:`load_model() <mlflow.pyfunc.load_model>`.
757
+
758
+ ``model_impl`` can be any Python object that implements the `Pyfunc interface
759
+ <https://mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#pyfunc-inference-api>`_, and is
760
+ returned by invoking the model's ``loader_module``.
761
+
762
+ ``model_meta`` contains model metadata loaded from the MLmodel file.
763
+ """
764
+
765
+ def __init__(
766
+ self,
767
+ model_meta: Model,
768
+ model_impl: Any,
769
+ predict_fn: str = "predict",
770
+ predict_stream_fn: Optional[str] = None,
771
+ model_id: Optional[str] = None,
772
+ ):
773
+ if not hasattr(model_impl, predict_fn):
774
+ raise MlflowException(f"Model implementation is missing required {predict_fn} method.")
775
+ if not model_meta:
776
+ raise MlflowException("Model is missing metadata.")
777
+ self._model_meta = model_meta
778
+ self.__model_impl = model_impl
779
+ self._predict_fn = getattr(model_impl, predict_fn)
780
+ if predict_stream_fn:
781
+ if not hasattr(model_impl, predict_stream_fn):
782
+ raise MlflowException(
783
+ f"Model implementation is missing required {predict_stream_fn} method."
784
+ )
785
+ self._predict_stream_fn = getattr(model_impl, predict_stream_fn)
786
+ else:
787
+ self._predict_stream_fn = None
788
+ self._model_id = model_id
789
+ self._input_example = None
790
+
791
+ @property
792
+ @developer_stable
793
+ def _model_impl(self) -> Any:
794
+ """
795
+ The underlying model implementation object.
796
+
797
+ NOTE: This is a stable developer API.
798
+ """
799
+ return self.__model_impl
800
+
801
+ @property
802
+ def model_id(self) -> Optional[str]:
803
+ """
804
+ The model ID of the model.
805
+
806
+ Returns:
807
+ The model ID of the model.
808
+ """
809
+ return self._model_id
810
+
811
+ def _update_dependencies_schemas_in_prediction_context(self, context: Context):
812
+ if self._model_meta and self._model_meta.metadata:
813
+ dependencies_schemas = self._model_meta.metadata.get("dependencies_schemas", {})
814
+ context.update(
815
+ dependencies_schemas={
816
+ dependency: json.dumps(schema)
817
+ for dependency, schema in dependencies_schemas.items()
818
+ }
819
+ )
820
+
821
+ @property
822
+ def input_example(self) -> Optional[Any]:
823
+ """
824
+ The input example provided when the model was saved.
825
+ """
826
+ return self._input_example
827
+
828
+ @input_example.setter
829
+ def input_example(self, value: Any) -> None:
830
+ self._input_example = value
831
+
832
+ def predict(self, data: PyFuncInput, params: Optional[dict[str, Any]] = None) -> PyFuncOutput:
833
+ context = _try_get_prediction_context() or Context()
834
+ with set_prediction_context(context):
835
+ if schema := _get_dependencies_schema_from_model(self._model_meta):
836
+ context.update(**schema)
837
+
838
+ if self.model_id:
839
+ context.update(model_id=self.model_id)
840
+ return self._predict(data, params)
841
+
842
    def _predict(self, data: PyFuncInput, params: Optional[dict[str, Any]] = None) -> PyFuncOutput:
        """
        Generates model predictions.

        If the model contains signature, enforce the input schema first before calling the model
        implementation with the sanitized input. If the pyfunc model does not include model schema,
        the input is passed to the model implementation as is. See `Model Signature Enforcement
        <https://www.mlflow.org/docs/latest/models.html#signature-enforcement>`_ for more details.

        Args:
            data: LLM Model single input as one of pandas.DataFrame, numpy.ndarray,
                scipy.sparse.(csc_matrix | csr_matrix), List[Any], or
                Dict[str, numpy.ndarray].
                For model signatures with tensor spec inputs
                (e.g. the Tensorflow core / Keras model), the input data type must be one of
                `numpy.ndarray`, `List[numpy.ndarray]`, `Dict[str, numpy.ndarray]` or
                `pandas.DataFrame`. If data is of `pandas.DataFrame` type and the model
                contains a signature with tensor spec inputs, the corresponding column values
                in the pandas DataFrame will be reshaped to the required shape with 'C' order
                (i.e. read / write the elements using C-like index order), and DataFrame
                column values will be cast as the required tensor spec type. For Pyspark
                DataFrame inputs, MLflow will only enforce the schema on a subset
                of the data rows.
            params: Additional parameters to pass to the model for inference.

        Returns:
            Model predictions as one of pandas.DataFrame, pandas.Series, numpy.ndarray or list.
        """
        # fetch the schema from metadata to avoid signature change after model is loaded
        self.input_schema = self.metadata.get_input_schema()
        self.params_schema = self.metadata.get_params_schema()
        # signature can only be inferred from type hints if the model is PythonModel
        if self.metadata._is_signature_from_type_hint():
            # we don't need to validate on data as data validation
            # will be done during PythonModel's predict call
            params = _enforce_params_schema(params, self.params_schema)
        else:
            data, params = _validate_prediction_input(
                data, params, self.input_schema, self.params_schema, self.loader_module
            )
            # When the signature was derived from a logged input example and the
            # example is available, convert the validated DataFrame back into the
            # example's original format before handing it to the model.
            if (
                isinstance(data, pandas.DataFrame)
                and self.metadata._is_type_hint_from_example()
                and self.input_example is not None
            ):
                data = _convert_dataframe_to_example_format(data, self.input_example)
        # Forward `params` only when the predict function declares a concrete
        # `params` parameter; a bare **kwargs does not count.
        params_arg = inspect.signature(self._predict_fn).parameters.get("params")
        if params_arg and params_arg.kind != inspect.Parameter.VAR_KEYWORD:
            return self._predict_fn(data, params=params)

        _log_warning_if_params_not_in_predict_signature(_logger, params)
        return self._predict_fn(data)
894
+
895
+ def predict_stream(
896
+ self, data: PyFuncLLMSingleInput, params: Optional[dict[str, Any]] = None
897
+ ) -> Iterator[PyFuncLLMOutputChunk]:
898
+ context = _try_get_prediction_context() or Context()
899
+
900
+ if schema := _get_dependencies_schema_from_model(self._model_meta):
901
+ context.update(**schema)
902
+
903
+ if self.model_id:
904
+ context.update(model_id=self.model_id)
905
+
906
+ # NB: The prediction context must be applied during iterating over the stream,
907
+ # hence, simply wrapping the self._predict_stream call with the context manager
908
+ # is not sufficient.
909
+ def _gen_with_context(*args, **kwargs):
910
+ with set_prediction_context(context):
911
+ yield from self._predict_stream(*args, **kwargs)
912
+
913
+ return _gen_with_context(data, params)
914
+
915
    def _predict_stream(
        self, data: PyFuncLLMSingleInput, params: Optional[dict[str, Any]] = None
    ) -> Iterator[PyFuncLLMOutputChunk]:
        """
        Generates streaming model predictions. Only LLM supports this method.

        If the model contains signature, enforce the input schema first before calling the model
        implementation with the sanitized input. If the pyfunc model does not include model schema,
        the input is passed to the model implementation as is. See `Model Signature Enforcement
        <https://www.mlflow.org/docs/latest/models.html#signature-enforcement>`_ for more details.

        Args:
            data: LLM Model single input as one of dict, str, bool, bytes, float, int, str type.
            params: Additional parameters to pass to the model for inference.

        Returns:
            Model predictions as an iterator of chunks. The chunks in the iterator must be type of
            dict or string. Chunk dict fields are determined by the model implementation.
        """

        # _predict_stream_fn is only wired up when the flavor config declared
        # itself streamable at load time (see load_model).
        if self._predict_stream_fn is None:
            raise MlflowException("This model does not support predict_stream method.")

        # fetch the schemas from metadata to avoid signature change after the model is loaded
        self.input_schema = self.metadata.get_input_schema()
        self.params_schema = self.metadata.get_params_schema()
        data, params = _validate_prediction_input(
            data, params, self.input_schema, self.params_schema, self.loader_module
        )
        data = _convert_llm_input_data(data)
        if isinstance(data, list):
            # `predict_stream` only accepts single input.
            # but `enforce_schema` might convert single input into a list like `[single_input]`
            # so extract the first element in the list.
            if len(data) != 1:
                raise MlflowException(
                    f"'predict_stream' requires single input, but it got input data {data}"
                )
            data = data[0]

        # Forward `params` only when the streaming predict function accepts it.
        if "params" in inspect.signature(self._predict_stream_fn).parameters:
            return self._predict_stream_fn(data, params=params)

        _log_warning_if_params_not_in_predict_signature(_logger, params)
        return self._predict_stream_fn(data)
959
+
960
+ def unwrap_python_model(self):
961
+ """
962
+ Unwrap the underlying Python model object.
963
+
964
+ This method is useful for accessing custom model functions, while still being able to
965
+ leverage the MLflow designed workflow through the `predict()` method.
966
+
967
+ Returns:
968
+ The underlying wrapped model object
969
+
970
+ .. code-block:: python
971
+ :test:
972
+ :caption: Example
973
+
974
+ import mlflow
975
+
976
+
977
+ # define a custom model
978
+ class MyModel(mlflow.pyfunc.PythonModel):
979
+ def predict(self, context, model_input, params=None):
980
+ return self.my_custom_function(model_input, params)
981
+
982
+ def my_custom_function(self, model_input, params=None):
983
+ # do something with the model input
984
+ return 0
985
+
986
+
987
+ some_input = 1
988
+ # save the model
989
+ with mlflow.start_run():
990
+ model_info = mlflow.pyfunc.log_model(name="model", python_model=MyModel())
991
+
992
+ # load the model
993
+ loaded_model = mlflow.pyfunc.load_model(model_uri=model_info.model_uri)
994
+ print(type(loaded_model)) # <class 'mlflow.pyfunc.model.PyFuncModel'>
995
+ unwrapped_model = loaded_model.unwrap_python_model()
996
+ print(type(unwrapped_model)) # <class '__main__.MyModel'>
997
+
998
+ # does not work, only predict() is exposed
999
+ # print(loaded_model.my_custom_function(some_input))
1000
+ print(unwrapped_model.my_custom_function(some_input)) # works
1001
+ print(loaded_model.predict(some_input)) # works
1002
+
1003
+ # works, but None is needed for context arg
1004
+ print(unwrapped_model.predict(None, some_input))
1005
+ """
1006
+ try:
1007
+ python_model = self._model_impl.python_model
1008
+ if python_model is None:
1009
+ raise AttributeError("Expected python_model attribute not to be None.")
1010
+ except AttributeError as e:
1011
+ raise MlflowException("Unable to retrieve base model object from pyfunc.") from e
1012
+ return python_model
1013
+
1014
+ def __eq__(self, other):
1015
+ if not isinstance(other, PyFuncModel):
1016
+ return False
1017
+ return self._model_meta == other._model_meta
1018
+
1019
+ @property
1020
+ def metadata(self) -> Model:
1021
+ """Model metadata."""
1022
+ if self._model_meta is None:
1023
+ raise MlflowException("Model is missing metadata.")
1024
+ return self._model_meta
1025
+
1026
+ @property
1027
+ def model_config(self):
1028
+ """Model's flavor configuration"""
1029
+ return self._model_meta.flavors[FLAVOR_NAME].get(MODEL_CONFIG, {})
1030
+
1031
    @property
    def loader_module(self):
        """
        The loader module name recorded in the model's pyfunc flavor config,
        or ``None`` when the model has no pyfunc flavor.

        NOTE: the previous docstring ("Model's flavor configuration") was
        copy-pasted from ``model_config``; this property returns the flavor's
        main loader-module entry, not the configuration.
        """
        if self._model_meta.flavors.get(FLAVOR_NAME) is None:
            return None
        return self._model_meta.flavors[FLAVOR_NAME].get(MAIN)
1037
+
1038
+ def __repr__(self):
1039
+ info = {}
1040
+ if self._model_meta is not None:
1041
+ if hasattr(self._model_meta, "run_id") and self._model_meta.run_id is not None:
1042
+ info["run_id"] = self._model_meta.run_id
1043
+ if (
1044
+ hasattr(self._model_meta, "artifact_path")
1045
+ and self._model_meta.artifact_path is not None
1046
+ ):
1047
+ info["artifact_path"] = self._model_meta.artifact_path
1048
+ info["flavor"] = self._model_meta.flavors[FLAVOR_NAME]["loader_module"]
1049
+ return yaml.safe_dump({"mlflow.pyfunc.loaded_model": info}, default_flow_style=False)
1050
+
1051
+ @experimental(version="2.16.0")
1052
+ def get_raw_model(self):
1053
+ """
1054
+ Get the underlying raw model if the model wrapper implemented `get_raw_model` function.
1055
+ """
1056
+ if hasattr(self._model_impl, "get_raw_model"):
1057
+ return self._model_impl.get_raw_model()
1058
+ raise NotImplementedError("`get_raw_model` is not implemented by the underlying model")
1059
+
1060
+
1061
def _get_pip_requirements_from_model_path(model_path: str):
    """
    Read the pip requirements saved with a model.

    Args:
        model_path: Local filesystem path of the downloaded model directory.

    Returns:
        A list of requirement strings; empty when the model directory contains
        no requirements file.
    """
    requirements_path = os.path.join(model_path, _REQUIREMENTS_FILE_NAME)
    if not os.path.exists(requirements_path):
        return []
    parsed = _parse_requirements(requirements_path, is_constraint=False)
    return [requirement.req_str for requirement in parsed]
1067
+
1068
+
1069
@trace_disabled  # Suppress traces while loading model
def load_model(
    model_uri: str,
    suppress_warnings: bool = False,
    dst_path: Optional[str] = None,
    model_config: Optional[Union[str, Path, dict[str, Any]]] = None,
) -> PyFuncModel:
    """
    Load a model stored in Python function format.

    Args:
        model_uri: The location, in URI format, of the MLflow model. For example:

            - ``/Users/me/path/to/local/model``
            - ``relative/path/to/local/model``
            - ``s3://my_bucket/path/to/model``
            - ``runs:/<mlflow_run_id>/run-relative/path/to/model``
            - ``models:/<model_name>/<model_version>``
            - ``models:/<model_name>/<stage>``
            - ``mlflow-artifacts:/path/to/model``

            For more information about supported URI schemes, see
            `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html#
            artifact-locations>`_.
        suppress_warnings: If ``True``, non-fatal warning messages associated with the model
            loading process will be suppressed. If ``False``, these warning messages will be
            emitted.
        dst_path: The local filesystem path to which to download the model artifact.
            This directory must already exist. If unspecified, a local output
            path will be created.
        model_config: The model configuration to apply to the model. The configuration will
            be available as the ``model_config`` property of the ``context`` parameter
            in :func:`PythonModel.load_context() <mlflow.pyfunc.PythonModel.load_context>`
            and :func:`PythonModel.predict() <mlflow.pyfunc.PythonModel.predict>`.
            The configuration can be passed as a file path, or a dict with string keys.

            .. Note:: Experimental: This parameter may change or be removed in a future
                release without warning.
    """

    # On Databricks, collect the notebook/job identity that is loading the
    # model so the artifact download can carry lineage information. This is
    # skipped inside the dependency-capture subprocess.
    lineage_header_info = None
    if (
        not _MLFLOW_IN_CAPTURE_MODULE_PROCESS.get()
    ) and databricks_utils.is_in_databricks_runtime():
        entity_list = []
        # Get notebook id and job id, pack them into lineage_header_info
        if databricks_utils.is_in_databricks_notebook() and (
            notebook_id := databricks_utils.get_notebook_id()
        ):
            notebook_entity = Notebook(id=notebook_id)
            entity_list.append(Entity(notebook=notebook_entity))

        if databricks_utils.is_in_databricks_job() and (job_id := databricks_utils.get_job_id()):
            job_entity = Job(id=job_id)
            entity_list.append(Entity(job=job_entity))

        lineage_header_info = LineageHeaderInfo(entities=entity_list) if entity_list else None

    local_path = _download_artifact_from_uri(
        artifact_uri=model_uri, output_path=dst_path, lineage_header_info=lineage_header_info
    )

    # Warn when the current environment's installed packages diverge from the
    # requirements the model was saved with.
    if not suppress_warnings:
        model_requirements = _get_pip_requirements_from_model_path(local_path)
        warn_dependency_requirement_mismatches(model_requirements)

    model_meta = Model.load(os.path.join(local_path, MLMODEL_FILE_NAME))

    # Externally-stored models carry only metadata in the model directory and
    # therefore cannot be loaded here.
    if model_meta.metadata and model_meta.metadata.get(MLFLOW_MODEL_IS_EXTERNAL, False) is True:
        raise MlflowException(
            "This model's artifacts are external and are not stored in the model directory."
            " This model cannot be loaded with MLflow.",
            BAD_REQUEST,
        )

    conf = model_meta.flavors.get(FLAVOR_NAME)
    if conf is None:
        raise MlflowException(
            f'Model does not have the "{FLAVOR_NAME}" flavor',
            RESOURCE_DOES_NOT_EXIST,
        )
    model_py_version = conf.get(PY_VERSION)
    if not suppress_warnings:
        _warn_potentially_incompatible_py_version_if_necessary(model_py_version=model_py_version)

    _add_code_from_conf_to_system_path(local_path, conf, code_key=CODE)
    data_path = os.path.join(local_path, conf[DATA]) if (DATA in conf) else local_path

    if isinstance(model_config, str):
        model_config = _validate_and_get_model_config_from_file(model_config)

    # Merge the caller-supplied model_config over the config stored with the model.
    model_config = _get_overridden_pyfunc_model_config(
        conf.get(MODEL_CONFIG, None), model_config, _logger
    )

    try:
        if model_config:
            model_impl = importlib.import_module(conf[MAIN])._load_pyfunc(data_path, model_config)
        else:
            model_impl = importlib.import_module(conf[MAIN])._load_pyfunc(data_path)
    except ModuleNotFoundError as e:
        # This error message is particularly for the case when the error is caused by module
        # "databricks.feature_store.mlflow_model". But depending on the environment, the offending
        # module might be "databricks", "databricks.feature_store" or full package. So we will
        # raise the error with the following note if "databricks" presents in the error. All non-
        # databricks module errors will just be re-raised.
        if conf[MAIN] == _DATABRICKS_FS_LOADER_MODULE and e.name.startswith("databricks"):
            raise MlflowException(
                f"{e.msg}; "
                "Note: mlflow.pyfunc.load_model is not supported for Feature Store models. "
                "spark_udf() and predict() will not work as expected. Use "
                "score_batch for offline predictions.",
                BAD_REQUEST,
            ) from None
        raise e
    finally:
        # clean up the dependencies schema which is set to global state after loading the model.
        # This avoids the schema being used by other models loaded in the same process.
        _clear_dependencies_schemas()
    predict_fn = conf.get("predict_fn", "predict")
    streamable = conf.get("streamable", False)
    # predict_stream is only exposed when the flavor declares itself streamable.
    predict_stream_fn = conf.get("predict_stream_fn", "predict_stream") if streamable else None

    pyfunc_model = PyFuncModel(
        model_meta=model_meta,
        model_impl=model_impl,
        predict_fn=predict_fn,
        predict_stream_fn=predict_stream_fn,
        model_id=model_meta.model_id,
    )

    # Best-effort: attach the saved input example so models whose signature was
    # inferred from an example can convert incoming data back to that format.
    try:
        model_input_example = model_meta.load_input_example(path=local_path)
        pyfunc_model.input_example = model_input_example
    except Exception as e:
        _logger.debug(f"Failed to load input example from model metadata: {e}.")

    return pyfunc_model
1207
+
1208
+
1209
class _ServedPyFuncModel(PyFuncModel):
    """A PyFuncModel that forwards predictions to an out-of-process scoring server."""

    def __init__(self, model_meta: Model, client: Any, server_pid: int, env_manager="local"):
        super().__init__(model_meta=model_meta, model_impl=client, predict_fn="invoke")
        self._client = client
        self._server_pid = server_pid
        # Databricks runtime evaluate usage logging reads the `env_manager`
        # attribute to log the 'env_manager' tag in `_evaluate` function patching.
        self._env_manager = env_manager

    def predict(self, data, params=None):
        """
        Score `data` by invoking the scoring-server client.

        Args:
            data: Model input data.
            params: Additional parameters to pass to the model for inference.

        Returns:
            Model predictions.
        """
        invoke = self._client.invoke
        if "params" in inspect.signature(invoke).parameters:
            result = invoke(data, params=params).get_predictions()
        else:
            _log_warning_if_params_not_in_predict_signature(_logger, params)
            result = invoke(data).get_predictions()
        if isinstance(result, pandas.DataFrame):
            # NOTE(review): single-column DataFrame results are collapsed to a
            # Series — presumably to mirror local pyfunc output; confirm.
            result = result[result.columns[0]]
        return result

    @property
    def pid(self):
        """Process ID of the scoring server; raises if no server PID was recorded."""
        if self._server_pid is None:
            raise MlflowException("Served PyFunc Model is missing server process ID.")
        return self._server_pid

    @property
    def env_manager(self):
        return self._env_manager

    @env_manager.setter
    def env_manager(self, value):
        self._env_manager = value
1249
+
1250
+
1251
def _load_model_or_server(
    model_uri: str, env_manager: str, model_config: Optional[dict[str, Any]] = None
):
    """
    Load a model with env restoration. If a non-local ``env_manager`` is specified, prepare an
    independent Python environment with the training time dependencies of the specified model
    installed and start a MLflow Model Scoring Server process with that model in that environment.
    Return a _ServedPyFuncModel that invokes the scoring server for prediction. Otherwise, load and
    return the model locally as a PyFuncModel using :py:func:`mlflow.pyfunc.load_model`.

    Args:
        model_uri: The uri of the model.
        env_manager: The environment manager to load the model.
        model_config: The model configuration to use by the model, only if the model
            accepts it.

    Returns:
        A _ServedPyFuncModel for non-local ``env_manager``s or a PyFuncModel otherwise.
    """
    from mlflow.pyfunc.scoring_server.client import (
        ScoringServerClient,
        StdinScoringServerClient,
    )

    # Local env manager: no environment restoration needed, load in-process.
    if env_manager == _EnvManager.LOCAL:
        return load_model(model_uri, model_config=model_config)

    _logger.info("Starting model server for model environment restoration.")

    local_path = _download_artifact_from_uri(artifact_uri=model_uri)
    model_meta = Model.load(os.path.join(local_path, MLMODEL_FILE_NAME))

    # When no TCP port can be opened, fall back to stdin/stdout communication
    # with the scoring server subprocess.
    is_port_connectable = check_port_connectivity()
    pyfunc_backend = get_flavor_backend(
        local_path,
        env_manager=env_manager,
        install_mlflow=os.environ.get("MLFLOW_HOME") is not None,
        create_env_root_dir=not is_port_connectable,
    )
    _logger.info("Restoring model environment. This can take a few minutes.")
    # Set capture_output to True in Databricks so that when environment preparation fails, the
    # exception message of the notebook cell output will include child process command execution
    # stdout/stderr output.
    pyfunc_backend.prepare_env(model_uri=local_path, capture_output=is_in_databricks_runtime())
    if is_port_connectable:
        server_port = find_free_port()
        scoring_server_proc = pyfunc_backend.serve(
            model_uri=local_path,
            port=server_port,
            host="127.0.0.1",
            timeout=MLFLOW_SCORING_SERVER_REQUEST_TIMEOUT.get(),
            enable_mlserver=False,
            synchronous=False,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            model_config=model_config,
        )
        client = ScoringServerClient("127.0.0.1", server_port)
    else:
        scoring_server_proc = pyfunc_backend.serve_stdin(local_path, model_config=model_config)
        client = StdinScoringServerClient(scoring_server_proc)

    _logger.info(f"Scoring server process started at PID: {scoring_server_proc.pid}")
    try:
        client.wait_server_ready(timeout=90, scoring_server_proc=scoring_server_proc)
    except Exception as e:
        if scoring_server_proc.poll() is None:
            # the scoring server is still running but client can't connect to it.
            # kill the server.
            scoring_server_proc.kill()
        # Surface the server's combined stdout/stderr so launch failures are debuggable.
        server_output, _ = scoring_server_proc.communicate(timeout=15)
        if isinstance(server_output, bytes):
            server_output = server_output.decode("UTF-8")
        raise MlflowException(
            "MLflow model server failed to launch, server process stdout and stderr are:\n"
            + server_output
        ) from e

    return _ServedPyFuncModel(
        model_meta=model_meta,
        client=client,
        server_pid=scoring_server_proc.pid,
        env_manager=env_manager,
    )
1335
+
1336
+
1337
def _get_model_dependencies(model_uri, format="pip"):
    """
    Download a model's artifacts and return a local dependency-file path.

    Args:
        model_uri: URI of the model whose dependencies are requested.
        format: Either ``"pip"`` (returns a requirements.txt path, extracted
            from conda.yaml when no requirements file was logged) or
            ``"conda"`` (returns the conda.yaml path).

    Returns:
        Local filesystem path to the dependency file.

    Raises:
        MlflowException: For an unknown ``format``, or when falling back to
            conda.yaml and it contains no pip section.
    """
    model_dir = _download_artifact_from_uri(model_uri)

    def get_conda_yaml_path():
        model_config = _get_flavor_configuration_from_ml_model_file(
            os.path.join(model_dir, MLMODEL_FILE_NAME), flavor_name=FLAVOR_NAME
        )
        return os.path.join(model_dir, _extract_conda_env(model_config[ENV]))

    if format == "conda":
        return get_conda_yaml_path()
    if format != "pip":
        raise MlflowException(
            f"Illegal format argument '{format}'.", error_code=INVALID_PARAMETER_VALUE
        )

    requirements_file = os.path.join(model_dir, _REQUIREMENTS_FILE_NAME)
    if os.path.exists(requirements_file):
        return requirements_file

    _logger.info(
        f"{_REQUIREMENTS_FILE_NAME} is not found in the model directory. Falling back to"
        f" extracting pip requirements from the model's 'conda.yaml' file. Conda"
        " dependencies will be ignored."
    )

    with open(get_conda_yaml_path()) as yf:
        conda_yaml = yaml.safe_load(yf)

    conda_deps = conda_yaml.get("dependencies", [])
    pip_deps_index = next(
        (i for i, dep in enumerate(conda_deps) if isinstance(dep, dict) and "pip" in dep),
        None,
    )
    if pip_deps_index is None:
        raise MlflowException(
            "No pip section found in conda.yaml file in the model directory.",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )

    pip_deps = conda_deps.pop(pip_deps_index)["pip"]
    tmp_dir = tempfile.mkdtemp()
    pip_file_path = os.path.join(tmp_dir, _REQUIREMENTS_FILE_NAME)
    with open(pip_file_path, "w") as f:
        f.write("\n".join(pip_deps) + "\n")

    if len(conda_deps) > 0:
        _logger.warning(
            "The following conda dependencies have been excluded from the environment file:"
            f" {', '.join(conda_deps)}."
        )

    return pip_file_path
1391
+
1392
+
1393
def get_model_dependencies(model_uri, format="pip"):
    """
    Downloads the model dependencies and returns the path to requirements.txt or conda.yaml file.

    .. warning::
        This API downloads all the model artifacts to the local filesystem. This may take
        a long time for large models. To avoid this overhead, use
        ``mlflow.artifacts.download_artifacts("<model_uri>/requirements.txt")`` or
        ``mlflow.artifacts.download_artifacts("<model_uri>/conda.yaml")`` instead.

    Args:
        model_uri: The uri of the model to get dependencies from.
        format: ``"pip"`` returns the path to a pip ``requirements.txt`` file;
            ``"conda"`` returns the path to the model's ``conda.yaml`` file.
            With ``"pip"``, if the model was saved without a requirements file,
            the pip section of the model's ``conda.yaml`` is extracted instead
            and any additional conda dependencies are ignored. Default: ``"pip"``.

    Returns:
        The local filesystem path to either a pip ``requirements.txt`` file
        (if ``format="pip"``) or a ``conda.yaml`` file (if ``format="conda"``)
        specifying the model's dependencies.
    """
    dependency_file = _get_model_dependencies(model_uri, format)

    if format == "pip":
        # In IPython notebooks `%pip` is the magic form of the install command.
        command_prefix = "%" if _is_in_ipython_notebook() else ""
        _logger.info(
            "To install the dependencies that were used to train the model, run the "
            f"following command: '{command_prefix}pip install -r {dependency_file}'."
        )
    return dependency_file
1427
+
1428
+
1429
@deprecated("mlflow.pyfunc.load_model", 1.0)
def load_pyfunc(model_uri, suppress_warnings=False):
    """
    Load a model stored in Python function format.

    .. note:: Deprecated — use :py:func:`mlflow.pyfunc.load_model` instead.

    Args:
        model_uri: The location, in URI format, of the MLflow model. For example:

            - ``/Users/me/path/to/local/model``
            - ``relative/path/to/local/model``
            - ``s3://my_bucket/path/to/model``
            - ``runs:/<mlflow_run_id>/run-relative/path/to/model``
            - ``models:/<model_name>/<model_version>``
            - ``models:/<model_name>/<stage>``
            - ``mlflow-artifacts:/path/to/model``

            For more information about supported URI schemes, see
            `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html#
            artifact-locations>`_.

        suppress_warnings: If ``True``, non-fatal warning messages associated with the model
            loading process will be suppressed. If ``False``, these warning messages will be
            emitted.
    """
    return load_model(model_uri, suppress_warnings=suppress_warnings)
1454
+
1455
+
1456
def _warn_potentially_incompatible_py_version_if_necessary(model_py_version=None):
    """
    Compare the Python version the model was saved with against the currently
    running interpreter and log a warning when they may be incompatible.

    Args:
        model_py_version: Python version string recorded at save time, or
            ``None`` when the model did not record one.
    """
    if model_py_version is None:
        _logger.warning(
            "The specified model does not have a specified Python version. It may be"
            " incompatible with the version of Python that is currently running: Python %s",
            PYTHON_VERSION,
        )
        return
    saved_version = get_major_minor_py_version(model_py_version)
    running_version = get_major_minor_py_version(PYTHON_VERSION)
    if saved_version != running_version:
        _logger.warning(
            "The version of Python that the model was saved in, `Python %s`, differs"
            " from the version of Python that is currently running, `Python %s`,"
            " and may be incompatible",
            model_py_version,
            PYTHON_VERSION,
        )
1476
+
1477
+
1478
def _create_model_downloading_tmp_dir(should_use_nfs):
    """
    Create a fresh temporary directory for downloading a model.

    Args:
        should_use_nfs: When True, place the directory under the NFS-backed tmp
            root; otherwise use the local tmp root.

    Returns:
        Path of the newly created temporary model directory.
    """
    if should_use_nfs:
        root_tmp_dir = get_or_create_nfs_tmp_dir()
    else:
        root_tmp_dir = get_or_create_tmp_dir()

    root_model_cache_dir = os.path.join(root_tmp_dir, "models")
    os.makedirs(root_model_cache_dir, exist_ok=True)

    tmp_model_dir = tempfile.mkdtemp(dir=root_model_cache_dir)
    # mkdtemp creates a directory with permission 0o700;
    # widen to 0o770 to ensure it can be seen in spark UDF processes.
    os.chmod(tmp_model_dir, 0o770)
    return tmp_model_dir
1489
+
1490
+
1491
+ _MLFLOW_SERVER_OUTPUT_TAIL_LINES_TO_KEEP = 200
1492
+
1493
+
1494
+ def _is_variant_type(spark_type):
1495
+ try:
1496
+ from pyspark.sql.types import VariantType
1497
+
1498
+ return isinstance(spark_type, VariantType)
1499
+ except ImportError:
1500
+ return False
1501
+
1502
+
1503
def _convert_spec_type_to_spark_type(spec_type):
    """
    Map an MLflow schema type (DataType, AnyType, Array, Object, or Map) to
    the equivalent Spark SQL type, recursing into nested element types.

    Raises:
        MlflowException: When ``AnyType`` is used with a pyspark that lacks
            ``VariantType``, or when the type cannot be converted.
    """
    from pyspark.sql.types import ArrayType, MapType, StringType, StructField, StructType

    from mlflow.types.schema import AnyType, Array, DataType, Map, Object

    if isinstance(spec_type, DataType):
        return spec_type.to_spark()

    if isinstance(spec_type, AnyType):
        # AnyType maps to VariantType, which only exists in newer pyspark (4.0+).
        try:
            from pyspark.sql.types import VariantType

            return VariantType()
        except ImportError:
            raise MlflowException.invalid_parameter_value(
                "`AnyType` is not supported in PySpark versions older than 4.0.0. "
                "Upgrade your PySpark version to use this feature.",
            )

    if isinstance(spec_type, Array):
        # Convert the element type recursively.
        return ArrayType(_convert_spec_type_to_spark_type(spec_type.dtype))

    if isinstance(spec_type, Object):
        return StructType(
            [
                StructField(
                    property.name,
                    _convert_spec_type_to_spark_type(property.dtype),
                    # we set nullable to True for all properties
                    # to avoid some errors like java.lang.NullPointerException
                    # when the signature is not inferred based on correct data.
                )
                for property in spec_type.properties
            ]
        )

    # Map only supports string as key
    if isinstance(spec_type, Map):
        return MapType(
            keyType=StringType(), valueType=_convert_spec_type_to_spark_type(spec_type.value_type)
        )

    raise MlflowException(f"Failed to convert schema type `{spec_type}` to spark type.")
1546
+
1547
+
1548
def _cast_output_spec_to_spark_type(spec):
    """
    Convert one model output spec (ColSpec or TensorSpec) into the Spark data
    type used for spark_udf return values.

    Raises:
        MlflowException: For unsupported tensor dtypes, tensors with more than
            two dimensions, or unknown spec kinds.
    """
    from pyspark.sql.types import ArrayType

    from mlflow.types.schema import ColSpec, DataType, TensorSpec

    # TODO: handle optional output columns.
    if isinstance(spec, ColSpec):
        return _convert_spec_type_to_spark_type(spec.type)

    if isinstance(spec, TensorSpec):
        data_type = DataType.from_numpy_type(spec.type)
        if data_type is None:
            raise MlflowException(
                f"Model output tensor spec type {spec.type} is not supported in spark_udf.",
                error_code=INVALID_PARAMETER_VALUE,
            )

        ndim = len(spec.shape)
        if ndim == 1:
            return ArrayType(data_type.to_spark())
        if ndim == 2:
            return ArrayType(ArrayType(data_type.to_spark()))
        raise MlflowException(
            "Only 1D or 2D tensors are supported as spark_udf "
            f"return value, but model output '{spec.name}' has shape {spec.shape}.",
            error_code=INVALID_PARAMETER_VALUE,
        )

    raise MlflowException(
        f"Unknown schema output spec {spec}.", error_code=INVALID_PARAMETER_VALUE
    )
1578
+
1579
+
1580
def _infer_spark_udf_return_type(model_output_schema):
    """
    Infer the Spark return type of a spark_udf from the model output schema.

    A single-spec schema maps directly to that spec's Spark type; multiple
    specs become a StructType whose field names fall back to the positional
    index when a spec is unnamed.
    """
    from pyspark.sql.types import StructField, StructType

    specs = model_output_schema.inputs
    if len(specs) == 1:
        return _cast_output_spec_to_spark_type(specs[0])

    fields = []
    for index, spec in enumerate(specs):
        field_name = spec.name or str(index)
        fields.append(
            StructField(name=field_name, dataType=_cast_output_spec_to_spark_type(spec))
        )
    return StructType(fields)
1592
+
1593
+
1594
def _parse_spark_datatype(datatype: str):
    """
    Parse a Spark data type string (e.g. ``"array<double>"``) into a Spark
    DataType instance, working around version-specific pyspark behavior.
    """
    from pyspark.sql.functions import udf
    from pyspark.sql.session import SparkSession

    # Spark uses "boolean" rather than Python's "bool" name.
    return_type = "boolean" if datatype == "bool" else datatype
    parsed_datatype = udf(lambda x: x, returnType=return_type).returnType

    if parsed_datatype.typeName() == "unparseddata":
        # For spark 3.5.x, `udf(lambda x: x, returnType=return_type).returnType`
        # returns UnparsedDataType, which is not compatible with signature inference.
        # Note: SparkSession.active only exists for spark >= 3.5.0
        schema = (
            SparkSession.active()
            .range(0)
            .select(udf(lambda x: x, returnType=return_type)("id"))
            .schema
        )
        return schema[0].dataType

    return parsed_datatype
1614
+
1615
+
1616
+ def _is_none_or_nan(value):
1617
+ # The condition `isinstance(value, float)` is needed to avoid error
1618
+ # from `np.isnan(value)` if value is a non-numeric type.
1619
+ return value is None or isinstance(value, float) and np.isnan(value)
1620
+
1621
+
1622
def _convert_array_values(values, result_type):
    """
    Convert list or numpy array values to spark dataframe column values.
    """
    from pyspark.sql.types import ArrayType, StructType

    if not isinstance(result_type, ArrayType):
        raise MlflowException.invalid_parameter_value(
            f"result_type must be ArrayType, got {result_type.simpleString()}",
        )

    element_type = result_type.elementType
    np_type = _get_spark_primitive_type_to_np_type().get(type(element_type))

    if np_type is not None:
        # For array type result values, if provided value is None or NaN, regard it as a null array.
        # see https://github.com/mlflow/mlflow/issues/8986
        if _is_none_or_nan(values):
            return None
        return np.array(values, dtype=np_type)
    if isinstance(element_type, ArrayType):
        return [_convert_array_values(item, element_type) for item in values]
    if isinstance(element_type, StructType):
        return [_convert_struct_values(item, element_type) for item in values]
    if _is_variant_type(element_type):
        return values

    raise MlflowException.invalid_parameter_value(
        "Unsupported array type field with element type "
        f"{result_type.elementType.simpleString()} in Array type.",
    )
1651
+
1652
+
1653
def _get_spark_primitive_types():
    """Return the tuple of Spark primitive type classes supported by spark_udf."""
    from pyspark.sql import types

    supported = (
        types.IntegerType,
        types.LongType,
        types.FloatType,
        types.DoubleType,
        types.StringType,
        types.BooleanType,
    )
    return supported
1664
+
1665
+
1666
def _get_spark_primitive_type_to_np_type():
    """Map each supported Spark primitive type class to its numpy dtype."""
    from pyspark.sql import types

    type_mapping = {
        types.IntegerType: np.int32,
        types.LongType: np.int64,
        types.FloatType: np.float32,
        types.DoubleType: np.float64,
        types.BooleanType: np.bool_,
        types.StringType: np.str_,
    }
    return type_mapping
1677
+
1678
+
1679
def _get_spark_primitive_type_to_python_type():
    """Map each supported Spark primitive type class to its builtin Python type."""
    from pyspark.sql import types

    type_mapping = {
        types.IntegerType: int,
        types.LongType: int,
        types.FloatType: float,
        types.DoubleType: float,
        types.BooleanType: bool,
        types.StringType: str,
    }
    return type_mapping
1690
+
1691
+
1692
def _check_udf_return_type(data_type):
    """Return True if ``data_type`` is a supported spark_udf return type.

    Supported types are the primitives from ``_get_spark_primitive_types``,
    plus arrays and structs composed (recursively) of supported types, and
    maps with string keys and supported value types.
    """
    from pyspark.sql.types import ArrayType, MapType, StringType, StructType

    if isinstance(data_type, _get_spark_primitive_types()):
        return True

    if isinstance(data_type, ArrayType):
        return _check_udf_return_type(data_type.elementType)

    if isinstance(data_type, StructType):
        return all(_check_udf_return_type(field.dataType) for field in data_type.fields)

    if isinstance(data_type, MapType):
        if not isinstance(data_type.keyType, StringType):
            return False
        return _check_udf_return_type(data_type.valueType)

    return False
1711
+
1712
+
1713
def _convert_struct_values(
    result: Union[pandas.DataFrame, dict[str, Any]],
    result_type,
):
    """
    Convert spark StructType values to spark dataframe column values.

    Args:
        result: Either a pandas DataFrame whose columns are the struct fields,
            or a dict mapping field names to values for a single struct instance.
        result_type: The target ``pyspark.sql.types.StructType``.

    Returns:
        A pandas DataFrame when ``result`` is a DataFrame, otherwise a dict,
        with every field converted according to its Spark field type.

    Raises:
        MlflowException: If ``result_type`` is not a StructType, ``result`` is
            an unsupported container type, or a field has an unsupported type.
    """

    from pyspark.sql.types import ArrayType, MapType, StructType

    if not isinstance(result_type, StructType):
        raise MlflowException.invalid_parameter_value(
            f"result_type must be StructType, got {result_type.simpleString()}",
        )

    if not isinstance(result, (dict, pandas.DataFrame)):
        raise MlflowException.invalid_parameter_value(
            f"Unsupported result type {type(result)}, expected dict or pandas DataFrame",
        )

    spark_primitive_type_to_np_type = _get_spark_primitive_type_to_np_type()
    is_pandas_df = isinstance(result, pandas.DataFrame)
    result_dict = {}
    for field_name in result_type.fieldNames():
        field_type = result_type[field_name].dataType
        field_values = result[field_name]

        if type(field_type) in spark_primitive_type_to_np_type:
            np_type = spark_primitive_type_to_np_type[type(field_type)]
            if is_pandas_df:
                # it's possible that field_values contain only Nones
                # in this case, we don't need to cast the type
                if not all(_is_none_or_nan(field_value) for field_value in field_values):
                    field_values = field_values.astype(np_type)
            else:
                field_values = (
                    None
                    if _is_none_or_nan(field_values)
                    else np.array(field_values, dtype=np_type).item()
                )
        elif isinstance(field_type, ArrayType):
            if is_pandas_df:
                field_values = pandas.Series(
                    _convert_array_values(field_value, field_type) for field_value in field_values
                )
            else:
                field_values = _convert_array_values(field_values, field_type)
        elif isinstance(field_type, StructType):
            if is_pandas_df:
                field_values = pandas.Series(
                    [
                        _convert_struct_values(field_value, field_type)
                        for field_value in field_values
                    ]
                )
            else:
                # Fix: test the *value* (not the Spark type object) for pydantic
                # models. A Spark StructType instance is never a pydantic
                # BaseModel, so the previous `isinstance(field_type, ...)` check
                # could never trigger and pydantic values reached the recursive
                # call unconverted.
                if isinstance(field_values, pydantic.BaseModel):
                    field_values = model_dump_compat(field_values)
                field_values = _convert_struct_values(field_values, field_type)
        elif isinstance(field_type, MapType):
            if is_pandas_df:
                field_values = pandas.Series(
                    [
                        {
                            key: _convert_value_based_on_spark_type(value, field_type.valueType)
                            for key, value in field_value.items()
                        }
                        for field_value in field_values
                    ]
                ).astype(object)
            else:
                field_values = {
                    key: _convert_value_based_on_spark_type(value, field_type.valueType)
                    for key, value in field_values.items()
                }
        elif _is_variant_type(field_type):
            # Fix: variant values pass through unchanged. This branch previously
            # `return`ed field_values, which aborted the loop and dropped every
            # remaining struct field from the converted result.
            pass
        else:
            raise MlflowException.invalid_parameter_value(
                f"Unsupported field type {field_type.simpleString()} in struct type.",
            )
        result_dict[field_name] = field_values

    if is_pandas_df:
        return pandas.DataFrame(result_dict)
    return result_dict
1799
+
1800
+
1801
def _convert_value_based_on_spark_type(value, spark_type):
    """
    Convert value to python types based on the given spark type.
    """

    from pyspark.sql.types import ArrayType, MapType, StructType

    python_type = _get_spark_primitive_type_to_python_type().get(type(spark_type))

    if python_type is not None:
        return None if _is_none_or_nan(value) else python_type(value)
    if isinstance(spark_type, StructType):
        return _convert_struct_values(value, spark_type)
    if isinstance(spark_type, ArrayType):
        element_type = spark_type.elementType
        return [_convert_value_based_on_spark_type(item, element_type) for item in value]
    if isinstance(spark_type, MapType):
        value_type = spark_type.valueType
        return {
            map_key: _convert_value_based_on_spark_type(value[map_key], value_type)
            for map_key in value
        }
    if _is_variant_type(spark_type):
        return value
    raise MlflowException.invalid_parameter_value(
        f"Unsupported type {spark_type} for value {value}"
    )
1827
+
1828
+
1829
# Root directory under which model python environments are prebuilt in
# Databricks runtime. The built env is archived from here, uploaded to NFS,
# and mounted into the Serverless UDF sandbox. "/tmp" is used rather than
# "/local_disk0" because the serverless client image does not have a
# "/local_disk0" directory.
_PREBUILD_ENV_ROOT_LOCATION = "/tmp"
1834
+
1835
+
1836
def _gen_prebuilt_env_archive_name(spark, local_model_path):
    """
    Generate prebuilt env archive file name.
    The format is:
    'mlflow-{sha of python env config and dependencies}-{runtime version}-{platform machine}'
    Note: The runtime version and platform machine information are included in the
    archive name because the prebuilt env might not be compatible across different
    runtime versions or platform machines.
    """
    python_env = _get_python_env(Path(local_model_path))
    env_name = _get_virtualenv_name(python_env, local_model_path)
    sandbox_info = get_dbconnect_udf_sandbox_info(spark)
    image_version = sandbox_info.image_version
    machine = sandbox_info.platform_machine
    return f"{env_name}-{image_version}-{machine}"
1852
+
1853
+
1854
def _verify_prebuilt_env(spark, local_model_path, env_archive_path):
    """Verify a prebuilt env archive matches the model and the UDF sandbox.

    The archive name ends with '-{env sha}-{runtime version}-{platform machine}';
    each component must match what the current model and sandbox require,
    otherwise an MlflowException is raised.
    """
    # Drop the trailing ".tar.gz" (7 characters) from the archive file name.
    archive_name = os.path.basename(env_archive_path)[:-7]
    prebuilt_env_sha, prebuilt_runtime_version, prebuilt_platform_machine = archive_name.split("-")[
        -3:
    ]

    python_env = _get_python_env(Path(local_model_path))
    env_sha = _get_virtualenv_name(python_env, local_model_path).split("-")[-1]
    sandbox_info = get_dbconnect_udf_sandbox_info(spark)
    runtime_version = sandbox_info.image_version
    platform_machine = sandbox_info.platform_machine

    if prebuilt_env_sha != env_sha:
        raise MlflowException(
            f"The prebuilt env '{env_archive_path}' does not match the model required environment."
        )
    if prebuilt_runtime_version != runtime_version:
        raise MlflowException(
            f"The prebuilt env '{env_archive_path}' runtime version '{prebuilt_runtime_version}' "
            f"does not match UDF sandbox runtime version {runtime_version}."
        )
    if prebuilt_platform_machine != platform_machine:
        raise MlflowException(
            f"The prebuilt env '{env_archive_path}' platform machine '{prebuilt_platform_machine}' "
            f"does not match UDF sandbox platform machine {platform_machine}."
        )
1881
+
1882
+
1883
def _prebuild_env_internal(local_model_path, archive_name, save_path, env_manager):
    """Build the model's python environment and archive it as a tar.gz file.

    Returns the path of the generated '<archive_name>.tar.gz' under save_path.
    The temporary env directory is always removed, even on failure.
    """
    env_root_dir = os.path.join(_PREBUILD_ENV_ROOT_LOCATION, archive_name)
    archive_path = os.path.join(save_path, f"{archive_name}.tar.gz")

    # Remove stale leftovers from a previous (possibly failed) build.
    if os.path.exists(env_root_dir):
        shutil.rmtree(env_root_dir)
    if os.path.exists(archive_path):
        os.remove(archive_path)

    try:
        backend = get_flavor_backend(
            local_model_path,
            env_manager=env_manager,
            install_mlflow=False,
            create_env_root_dir=False,
            env_root_dir=env_root_dir,
        )
        backend.prepare_env(model_uri=local_model_path, capture_output=False)

        # exclude pip cache from the archive file.
        pip_cache_dir = os.path.join(env_root_dir, "pip_cache_pkgs")
        if os.path.exists(pip_cache_dir):
            shutil.rmtree(pip_cache_dir)

        return archive_directory(env_root_dir, archive_path)
    finally:
        shutil.rmtree(env_root_dir, ignore_errors=True)
1909
+
1910
+
1911
def _download_prebuilt_env_if_needed(prebuilt_env_uri):
    """Resolve ``prebuilt_env_uri`` to a local file path, downloading if needed.

    Local paths (empty or 'file' scheme) are returned as-is. 'dbfs' URIs are
    downloaded via the Databricks SDK to a temporary directory and reused on
    subsequent calls if the file already exists there. Any other scheme raises
    MlflowException.
    """
    from mlflow.utils.file_utils import get_or_create_tmp_dir

    parsed_url = urlparse(prebuilt_env_uri)
    if parsed_url.scheme == "" or parsed_url.scheme == "file":
        # local path
        return parsed_url.path
    if parsed_url.scheme == "dbfs":
        # MLFLOW_MODEL_ENV_DOWNLOADING_TEMP_DIR overrides the default tmp dir.
        tmp_dir = MLFLOW_MODEL_ENV_DOWNLOADING_TEMP_DIR.get() or get_or_create_tmp_dir()
        model_env_uc_path = parsed_url.path

        # download file from DBFS.
        local_model_env_path = os.path.join(tmp_dir, os.path.basename(model_env_uc_path))
        if os.path.exists(local_model_env_path):
            # file is already downloaded.
            return local_model_env_path

        try:
            from databricks.sdk import WorkspaceClient

            ws = WorkspaceClient()
            # Download model env file from UC volume.
            # Streamed in 4MB chunks to bound memory usage.
            with (
                ws.files.download(model_env_uc_path).contents as rf,
                open(local_model_env_path, "wb") as wf,
            ):
                while chunk := rf.read(4096 * 1024):
                    wf.write(chunk)
            return local_model_env_path
        except (Exception, KeyboardInterrupt):
            if os.path.exists(local_model_env_path):
                # clean the partially saved file if downloading fails.
                os.remove(local_model_env_path)
            raise

    raise MlflowException(
        f"Unsupported prebuilt env file path '{prebuilt_env_uri}', "
        f"invalid scheme: '{parsed_url.scheme}'."
    )
1950
+
1951
+
1952
def build_model_env(model_uri, save_path, env_manager=_EnvManager.VIRTUALENV):
    """
    Prebuild model python environment and generate an archive file saved to provided
    `save_path`.

    Typical usages:
     - Pre-build a model's environment in Databricks Runtime and then download the prebuilt
       python environment archive file. This pre-built environment archive can then be used
       in `mlflow.pyfunc.spark_udf` for remote inference execution when using Databricks Connect
       to remotely connect to a Databricks environment for code execution.

    .. note::
        The `build_model_env` API is intended to only work when executed within Databricks runtime,
        serving the purpose of capturing the required execution environment that is needed for
        remote code execution when using DBConnect. The environment archive is designed to be used
        when performing remote execution using `mlflow.pyfunc.spark_udf` in
        Databricks runtime or Databricks Connect client and has no other purpose.
        The prebuilt env archive file cannot be used across different Databricks runtime
        versions or different platform machines. As such, if you connect to a different cluster
        that is running a different runtime version on Databricks, you will need to execute this
        API in a notebook and retrieve the generated archive to your local machine. Each
        environment snapshot is unique to the model, the runtime version of your remote
        Databricks cluster, and the specification of the udf execution environment.
        When using the prebuilt env in `mlflow.pyfunc.spark_udf`, MLflow will verify
        whether the spark UDF sandbox environment matches the prebuilt env requirements and will
        raise Exceptions if there are compatibility issues. If these occur, simply re-run this API
        in the cluster that you are attempting to attach to.

    .. code-block:: python
        :caption: Example

        from mlflow.pyfunc import build_model_env

        # Create a python environment archive file at the path `prebuilt_env_uri`
        prebuilt_env_uri = build_model_env(f"runs:/{run_id}/model", "/path/to/save_directory")

    Args:
        model_uri: URI to the model that is used to build the python environment.
        save_path: The directory path that is used to save the prebuilt model environment
            archive file path.
            The path can be either local directory path or
            mounted DBFS path such as '/dbfs/...' or
            mounted UC volume path such as '/Volumes/...'.
        env_manager: The environment manager to use in order to create the python environment
            for model inference, the value can be either 'virtualenv' or 'uv', the default
            value is 'virtualenv'.

    Returns:
        Return the path of an archive file containing the python environment data.
    """
    from mlflow.utils._spark_utils import _get_active_spark_session

    if not is_in_databricks_runtime():
        raise RuntimeError("'build_model_env' only support running in Databricks runtime.")

    if os.path.isfile(save_path):
        raise RuntimeError(f"The saving path '{save_path}' must be a directory.")
    os.makedirs(save_path, exist_ok=True)

    # Download the model locally first; the archive name is derived from the
    # model's python env spec plus the sandbox runtime/machine info.
    local_model_path = _download_artifact_from_uri(
        artifact_uri=model_uri, output_path=_create_model_downloading_tmp_dir(should_use_nfs=False)
    )
    archive_name = _gen_prebuilt_env_archive_name(_get_active_spark_session(), local_model_path)
    dest_path = os.path.join(save_path, archive_name + ".tar.gz")
    if os.path.exists(dest_path):
        raise RuntimeError(
            "A pre-built model python environment already exists "
            f"in '{dest_path}'. To rebuild it, please remove "
            "the existing one first."
        )

    # Archive the environment directory as a `tar.gz` format archive file,
    # and then move the archive file to the destination directory.
    # Note:
    # - all symlink files in the input directory are kept as it is in the
    #   archive file.
    # - the destination directory could be UC-volume fuse mounted directory
    #   which only supports limited filesystem operations, so to ensure it works,
    #   we generate the archive file under /tmp and then move it into the
    #   destination directory.
    tmp_archive_path = None
    try:
        tmp_archive_path = _prebuild_env_internal(
            local_model_path, archive_name, _PREBUILD_ENV_ROOT_LOCATION, env_manager
        )
        shutil.move(tmp_archive_path, save_path)
        return dest_path
    finally:
        # Always clean up the downloaded model and any leftover tmp archive
        # (the tmp archive remains only if the move above failed).
        shutil.rmtree(local_model_path, ignore_errors=True)
        if tmp_archive_path and os.path.exists(tmp_archive_path):
            os.remove(tmp_archive_path)
2043
+
2044
+
2045
+ def spark_udf(
2046
+ spark,
2047
+ model_uri,
2048
+ result_type=None,
2049
+ env_manager=None,
2050
+ params: Optional[dict[str, Any]] = None,
2051
+ extra_env: Optional[dict[str, str]] = None,
2052
+ prebuilt_env_uri: Optional[str] = None,
2053
+ model_config: Optional[Union[str, Path, dict[str, Any]]] = None,
2054
+ ):
2055
+ """
2056
+ A Spark UDF that can be used to invoke the Python function formatted model.
2057
+
2058
+ Parameters passed to the UDF are forwarded to the model as a DataFrame where the column names
2059
+ are ordinals (0, 1, ...). On some versions of Spark (3.0 and above), it is also possible to
2060
+ wrap the input in a struct. In that case, the data will be passed as a DataFrame with column
2061
+ names given by the struct definition (e.g. when invoked as my_udf(struct('x', 'y')), the model
2062
+ will get the data as a pandas DataFrame with 2 columns 'x' and 'y').
2063
+
2064
+ If a model contains a signature with tensor spec inputs, you will need to pass a column of
2065
+ array type as a corresponding UDF argument. The column values of which must be one dimensional
2066
+ arrays. The UDF will reshape the column values to the required shape with 'C' order
2067
+ (i.e. read / write the elements using C-like index order) and cast the values as the required
2068
+ tensor spec type.
2069
+
2070
+ If a model contains a signature, the UDF can be called without specifying column name
2071
+ arguments. In this case, the UDF will be called with column names from signature, so the
2072
+ evaluation dataframe's column names must match the model signature's column names.
2073
+
2074
+ The predictions are filtered to contain only the columns that can be represented as the
2075
+ ``result_type``. If the ``result_type`` is string or array of strings, all predictions are
2076
+ converted to string. If the result type is not an array type, the left most column with
2077
+ matching type is returned.
2078
+
2079
+ .. note::
2080
+ Inputs of type ``pyspark.sql.types.DateType`` are not supported on earlier versions of
2081
+ Spark (2.4 and below).
2082
+
2083
+ .. note::
2084
+ When using Databricks Connect to connect to a remote Databricks cluster,
2085
+ the Databricks cluster must use runtime version >= 16, and if the 'prebuilt_env_uri'
2086
+ parameter is set, 'env_manager' parameter should not be set.
2087
+ the Databricks cluster must use runtime version >= 15.4,and if the 'prebuilt_env_uri'
2088
+ parameter is set, 'env_manager' parameter should not be set,
2089
+ if the runtime version is 15.4 and the cluster is
2090
+ standard access mode, the cluster need to configure
2091
+ "spark.databricks.safespark.archive.artifact.unpack.disabled" to "false".
2092
+
2093
+ .. note::
2094
+ Please be aware that when operating in Databricks Serverless,
2095
+ spark tasks run within the confines of the Databricks Serverless UDF sandbox.
2096
+ This environment has a total capacity limit of 1GB, combining both available
2097
+ memory and local disk capacity. Furthermore, there are no GPU devices available
2098
+ in this setup. Therefore, any deep-learning models that contain large weights
2099
+ or require a GPU are not suitable for deployment on Databricks Serverless.
2100
+
2101
+ .. code-block:: python
2102
+ :caption: Example
2103
+
2104
+ from pyspark.sql.functions import struct
2105
+
2106
+ predict = mlflow.pyfunc.spark_udf(spark, "/my/local/model")
2107
+ df.withColumn("prediction", predict(struct("name", "age"))).show()
2108
+
2109
+ Args:
2110
+ spark: A SparkSession object.
2111
+ model_uri: The location, in URI format, of the MLflow model with the
2112
+ :py:mod:`mlflow.pyfunc` flavor. For example:
2113
+
2114
+ - ``/Users/me/path/to/local/model``
2115
+ - ``relative/path/to/local/model``
2116
+ - ``s3://my_bucket/path/to/model``
2117
+ - ``runs:/<mlflow_run_id>/run-relative/path/to/model``
2118
+ - ``models:/<model_name>/<model_version>``
2119
+ - ``models:/<model_name>/<stage>``
2120
+ - ``mlflow-artifacts:/path/to/model``
2121
+
2122
+ For more information about supported URI schemes, see
2123
+ `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html#
2124
+ artifact-locations>`_.
2125
+
2126
+ result_type: the return type of the user-defined function. The value can be either a
2127
+ ``pyspark.sql.types.DataType`` object or a DDL-formatted type string. Only a primitive
2128
+ type, an array ``pyspark.sql.types.ArrayType`` of primitive type, or a struct type
2129
+ containing fields of above 2 kinds of types are allowed.
2130
+ If unspecified, it tries to infer result type from model signature
2131
+ output schema, if model output schema is not available, it fallbacks to use ``double``
2132
+ type.
2133
+
2134
+ The following classes of result type are supported:
2135
+
2136
+ - "int" or ``pyspark.sql.types.IntegerType``: The leftmost integer that can fit in an
2137
+ ``int32`` or an exception if there is none.
2138
+
2139
+ - "long" or ``pyspark.sql.types.LongType``: The leftmost long integer that can fit in an
2140
+ ``int64`` or an exception if there is none.
2141
+
2142
+ - ``ArrayType(IntegerType|LongType)``: All integer columns that can fit into the
2143
+ requested size.
2144
+
2145
+ - "float" or ``pyspark.sql.types.FloatType``: The leftmost numeric result cast to
2146
+ ``float32`` or an exception if there is none.
2147
+
2148
+ - "double" or ``pyspark.sql.types.DoubleType``: The leftmost numeric result cast to
2149
+ ``double`` or an exception if there is none.
2150
+
2151
+ - ``ArrayType(FloatType|DoubleType)``: All numeric columns cast to the requested type or
2152
+ an exception if there are no numeric columns.
2153
+
2154
+ - "string" or ``pyspark.sql.types.StringType``: The leftmost column converted to
2155
+ ``string``.
2156
+
2157
+ - "boolean" or "bool" or ``pyspark.sql.types.BooleanType``: The leftmost column
2158
+ converted to ``bool`` or an exception if there is none.
2159
+
2160
+ - ``ArrayType(StringType)``: All columns converted to ``string``.
2161
+
2162
+ - "field1 FIELD1_TYPE, field2 FIELD2_TYPE, ...": A struct type containing multiple
2163
+ fields separated by comma, each field type must be one of types listed above.
2164
+
2165
+ env_manager: The environment manager to use in order to create the python environment
2166
+ for model inference. Note that environment is only restored in the context
2167
+ of the PySpark UDF; the software environment outside of the UDF is
2168
+ unaffected. If `prebuilt_env_uri` parameter is not set, the default value
2169
+ is ``local``, and the following values are supported:
2170
+
2171
+ - ``virtualenv``: Use virtualenv to restore the python environment that
2172
+ was used to train the model. This is the default option if ``env_manager``
2173
+ is not set.
2174
+ - ``uv`` : Use uv to restore the python environment that
2175
+ was used to train the model.
2176
+ - ``conda``: Use Conda to restore the software environment
2177
+ that was used to train the model.
2178
+ - ``local``: Use the current Python environment for model inference, which
2179
+ may differ from the environment used to train the model and may lead to
2180
+ errors or invalid predictions.
2181
+
2182
+ If the `prebuilt_env_uri` parameter is set, `env_manager` parameter should not
2183
+ be set.
2184
+
2185
+ params: Additional parameters to pass to the model for inference.
2186
+
2187
+ extra_env: Extra environment variables to pass to the UDF executors.
2188
+ For overrides that need to propagate to the Spark workers (i.e.,
2189
+ overriding the scoring server timeout via `MLFLOW_SCORING_SERVER_REQUEST_TIMEOUT`).
2190
+
2191
+ prebuilt_env_uri: The path of the prebuilt env archive file created by
2192
+ `mlflow.pyfunc.build_model_env` API.
2193
+ This parameter can only be used in Databricks Serverless notebook REPL,
2194
+ Databricks Shared cluster notebook REPL, and Databricks Connect client
2195
+ environment.
2196
+ The path can be either local file path or DBFS path such as
2197
+ 'dbfs:/Volumes/...', in this case, MLflow automatically downloads it
2198
+ to local temporary directory, "MLFLOW_MODEL_ENV_DOWNLOADING_TEMP_DIR"
2199
+ environmental variable can be set to specify the temporary directory
2200
+ to use.
2201
+
2202
+ If this parameter is set, `env_manger` parameter must not be set.
2203
+
2204
+ model_config: The model configuration to set when loading the model.
2205
+ See 'model_config' argument in `mlflow.pyfunc.load_model` API for details.
2206
+
2207
+ Returns:
2208
+ Spark UDF that applies the model's ``predict`` method to the data and returns a
2209
+ type specified by ``result_type``, which by default is a double.
2210
+ """
2211
+
2212
+ # Scope Spark import to this method so users don't need pyspark to use non-Spark-related
2213
+ # functionality.
2214
+ from pyspark.sql.functions import pandas_udf
2215
+ from pyspark.sql.types import (
2216
+ ArrayType,
2217
+ BooleanType,
2218
+ DoubleType,
2219
+ FloatType,
2220
+ IntegerType,
2221
+ LongType,
2222
+ MapType,
2223
+ StringType,
2224
+ )
2225
+ from pyspark.sql.types import StructType as SparkStructType
2226
+
2227
+ from mlflow.pyfunc.spark_model_cache import SparkModelCache
2228
+ from mlflow.utils._spark_utils import _SparkDirectoryDistributor
2229
+
2230
+ is_spark_connect = is_spark_connect_mode()
2231
+ # Used in test to force install local version of mlflow when starting a model server
2232
+ mlflow_home = os.environ.get("MLFLOW_HOME")
2233
+ openai_env_vars = mlflow.openai.model._OpenAIEnvVar.read_environ()
2234
+ mlflow_testing = _MLFLOW_TESTING.get_raw()
2235
+
2236
+ if prebuilt_env_uri:
2237
+ if env_manager not in (None, _EnvManager.VIRTUALENV, _EnvManager.UV):
2238
+ raise MlflowException(
2239
+ "If 'prebuilt_env_uri' parameter is set, 'env_manager' parameter must "
2240
+ "be either None, 'virtualenv', or 'uv'."
2241
+ )
2242
+ env_manager = _EnvManager.VIRTUALENV
2243
+ else:
2244
+ env_manager = env_manager or _EnvManager.LOCAL
2245
+
2246
+ _EnvManager.validate(env_manager)
2247
+
2248
+ if is_spark_connect:
2249
+ is_spark_in_local_mode = False
2250
+ else:
2251
+ # Check whether spark is in local or local-cluster mode
2252
+ # this case all executors and driver share the same filesystem
2253
+ is_spark_in_local_mode = spark.conf.get("spark.master").startswith("local")
2254
+
2255
+ is_dbconnect_mode = is_databricks_connect(spark)
2256
+ if prebuilt_env_uri is not None and not is_dbconnect_mode:
2257
+ raise RuntimeError(
2258
+ "'prebuilt_env' parameter can only be used in Databricks Serverless "
2259
+ "notebook REPL, atabricks Shared cluster notebook REPL, and Databricks Connect client "
2260
+ "environment."
2261
+ )
2262
+
2263
+ if prebuilt_env_uri is None and is_dbconnect_mode and not is_in_databricks_runtime():
2264
+ raise RuntimeError(
2265
+ "'prebuilt_env_uri' param is required if using Databricks Connect to connect "
2266
+ "to Databricks cluster from your own machine."
2267
+ )
2268
+
2269
+ # Databricks connect can use `spark.addArtifact` to upload artifact to NFS.
2270
+ # But for Databricks shared cluster runtime, it can directly write to NFS, so exclude it
2271
+ # Note for Databricks Serverless runtime (notebook REPL), it runs on Servereless VM that
2272
+ # can't access NFS, so it needs to use `spark.addArtifact`.
2273
+ use_dbconnect_artifact = is_dbconnect_mode and not is_in_databricks_shared_cluster_runtime()
2274
+
2275
+ if use_dbconnect_artifact:
2276
+ udf_sandbox_info = get_dbconnect_udf_sandbox_info(spark)
2277
+ if Version(udf_sandbox_info.mlflow_version) < Version("2.18.0"):
2278
+ raise MlflowException(
2279
+ "Using 'mlflow.pyfunc.spark_udf' in Databricks Serverless or in remote "
2280
+ "Databricks Connect requires UDF sandbox image installed with MLflow "
2281
+ "of version >= 2.18.0"
2282
+ )
2283
+ # `udf_sandbox_info.runtime_version` format is like '<major_version>.<minor_version>'.
2284
+ # It's safe to apply `Version`.
2285
+ dbr_runtime_version = Version(udf_sandbox_info.runtime_version)
2286
+ if dbr_runtime_version < Version("15.4"):
2287
+ raise MlflowException(
2288
+ "Using 'mlflow.pyfunc.spark_udf' in Databricks Serverless or in remote "
2289
+ "Databricks Connect requires Databricks runtime version >= 15.4."
2290
+ )
2291
+ if dbr_runtime_version == Version("15.4"):
2292
+ if spark.conf.get("spark.databricks.pyspark.udf.isolation.enabled").lower() == "true":
2293
+ # The connected cluster is standard (shared) mode.
2294
+ if (
2295
+ spark.conf.get(
2296
+ "spark.databricks.safespark.archive.artifact.unpack.disabled"
2297
+ ).lower()
2298
+ != "false"
2299
+ ):
2300
+ raise MlflowException(
2301
+ "Using 'mlflow.pyfunc.spark_udf' in remote Databricks Connect requires "
2302
+ "Databricks cluster setting "
2303
+ "'spark.databricks.safespark.archive.artifact.unpack.disabled' to 'false' "
2304
+ "if Databricks runtime version is 15.4"
2305
+ )
2306
+
2307
+ nfs_root_dir = get_nfs_cache_root_dir()
2308
+ should_use_nfs = nfs_root_dir is not None
2309
+
2310
+ should_use_spark_to_broadcast_file = not (
2311
+ is_spark_in_local_mode or should_use_nfs or is_spark_connect or use_dbconnect_artifact
2312
+ )
2313
+
2314
+ # For spark connect mode,
2315
+ # If client code is executed in databricks runtime and NFS is available,
2316
+ # we save model to NFS temp directory in the driver
2317
+ # and load the model in the executor.
2318
+ should_spark_connect_use_nfs = is_in_databricks_runtime() and should_use_nfs
2319
+
2320
+ if (
2321
+ is_spark_connect
2322
+ and not is_dbconnect_mode
2323
+ and env_manager in (_EnvManager.VIRTUALENV, _EnvManager.CONDA, _EnvManager.UV)
2324
+ ):
2325
+ raise MlflowException.invalid_parameter_value(
2326
+ f"Environment manager {env_manager!r} is not supported in Spark Connect "
2327
+ "client environment if it connects to non-Databricks Spark cluster.",
2328
+ )
2329
+
2330
+ local_model_path = _download_artifact_from_uri(
2331
+ artifact_uri=model_uri,
2332
+ output_path=_create_model_downloading_tmp_dir(should_use_nfs),
2333
+ )
2334
+
2335
+ if prebuilt_env_uri:
2336
+ prebuilt_env_uri = _download_prebuilt_env_if_needed(prebuilt_env_uri)
2337
+ _verify_prebuilt_env(spark, local_model_path, prebuilt_env_uri)
2338
+ if use_dbconnect_artifact and env_manager == _EnvManager.CONDA:
2339
+ raise MlflowException(
2340
+ "Databricks connect mode or Databricks Serverless python REPL doesn't "
2341
+ "support env_manager 'conda'."
2342
+ )
2343
+
2344
+ if env_manager == _EnvManager.LOCAL:
2345
+ # Assume spark executor python environment is the same with spark driver side.
2346
+ model_requirements = _get_pip_requirements_from_model_path(local_model_path)
2347
+ warn_dependency_requirement_mismatches(model_requirements)
2348
+ _logger.warning(
2349
+ 'Calling `spark_udf()` with `env_manager="local"` does not recreate the same '
2350
+ "environment that was used during training, which may lead to errors or inaccurate "
2351
+ 'predictions. We recommend specifying `env_manager="conda"`, which automatically '
2352
+ "recreates the environment that was used to train the model and performs inference "
2353
+ "in the recreated environment."
2354
+ )
2355
+ else:
2356
+ _logger.info(
2357
+ f"This UDF will use {env_manager} to recreate the model's software environment for "
2358
+ "inference. This may take extra time during execution."
2359
+ )
2360
+ if not sys.platform.startswith("linux"):
2361
+ # TODO: support killing mlflow server launched in UDF task when spark job canceled
2362
+ # for non-linux system.
2363
+ # https://stackoverflow.com/questions/53208/how-do-i-automatically-destroy-child-processes-in-windows
2364
+ _logger.warning(
2365
+ "In order to run inference code in restored python environment, PySpark UDF "
2366
+ "processes spawn MLflow Model servers as child processes. Due to system "
2367
+ "limitations with handling SIGKILL signals, these MLflow Model server child "
2368
+ "processes cannot be cleaned up if the Spark Job is canceled."
2369
+ )
2370
+
2371
+ if prebuilt_env_uri:
2372
+ env_cache_key = os.path.basename(prebuilt_env_uri)[:-7]
2373
+ elif use_dbconnect_artifact:
2374
+ env_cache_key = _gen_prebuilt_env_archive_name(spark, local_model_path)
2375
+ else:
2376
+ env_cache_key = None
2377
+
2378
+ if use_dbconnect_artifact or prebuilt_env_uri is not None:
2379
+ prebuilt_env_root_dir = os.path.join(_PREBUILD_ENV_ROOT_LOCATION, env_cache_key)
2380
+ pyfunc_backend_env_root_config = {
2381
+ "create_env_root_dir": False,
2382
+ "env_root_dir": prebuilt_env_root_dir,
2383
+ }
2384
+ else:
2385
+ pyfunc_backend_env_root_config = {"create_env_root_dir": True}
2386
+ pyfunc_backend = get_flavor_backend(
2387
+ local_model_path,
2388
+ env_manager=env_manager,
2389
+ install_mlflow=os.environ.get("MLFLOW_HOME") is not None,
2390
+ **pyfunc_backend_env_root_config,
2391
+ )
2392
+ dbconnect_artifact_cache = DBConnectArtifactCache.get_or_create(spark)
2393
+
2394
+ if use_dbconnect_artifact:
2395
+ # Upload model artifacts and python environment to NFS as DBConncet artifacts.
2396
+ if env_manager in (_EnvManager.VIRTUALENV, _EnvManager.UV):
2397
+ if not dbconnect_artifact_cache.has_cache_key(env_cache_key):
2398
+ if prebuilt_env_uri:
2399
+ env_archive_path = prebuilt_env_uri
2400
+ else:
2401
+ env_archive_path = _prebuild_env_internal(
2402
+ local_model_path, env_cache_key, get_or_create_tmp_dir(), env_manager
2403
+ )
2404
+ dbconnect_artifact_cache.add_artifact_archive(env_cache_key, env_archive_path)
2405
+
2406
+ if not dbconnect_artifact_cache.has_cache_key(model_uri):
2407
+ model_archive_path = os.path.join(
2408
+ os.path.dirname(local_model_path), f"model-{uuid.uuid4()}.tar.gz"
2409
+ )
2410
+ archive_directory(local_model_path, model_archive_path)
2411
+ dbconnect_artifact_cache.add_artifact_archive(model_uri, model_archive_path)
2412
+
2413
+ elif not should_use_spark_to_broadcast_file:
2414
+ if prebuilt_env_uri:
2415
+ # Extract prebuilt env archive file to NFS directory.
2416
+ prebuilt_env_nfs_dir = os.path.join(
2417
+ get_or_create_nfs_tmp_dir(), "prebuilt_env", env_cache_key
2418
+ )
2419
+ if not os.path.exists(prebuilt_env_nfs_dir):
2420
+ extract_archive_to_dir(prebuilt_env_uri, prebuilt_env_nfs_dir)
2421
+ else:
2422
+ # Prepare restored environment in driver side if possible.
2423
+ # Note: In databricks runtime, because databricks notebook cell output cannot capture
2424
+ # child process output, so that set capture_output to be True so that when `conda
2425
+ # prepare env` command failed, the exception message will include command stdout/stderr
2426
+ # output. Otherwise user have to check cluster driver log to find command stdout/stderr
2427
+ # output.
2428
+ # In non-databricks runtime, set capture_output to be False, because the benefit of
2429
+ # "capture_output=False" is the output will be printed immediately, otherwise you have
2430
+ # to wait conda command fail and suddenly get all output printed (included in error
2431
+ # message).
2432
+ if env_manager != _EnvManager.LOCAL:
2433
+ pyfunc_backend.prepare_env(
2434
+ model_uri=local_model_path, capture_output=is_in_databricks_runtime()
2435
+ )
2436
+ else:
2437
+ # Broadcast local model directory to remote worker if needed.
2438
+ archive_path = SparkModelCache.add_local_model(spark, local_model_path)
2439
+
2440
+ model_metadata = Model.load(os.path.join(local_model_path, MLMODEL_FILE_NAME))
2441
+
2442
+ if result_type is None:
2443
+ if model_output_schema := model_metadata.get_output_schema():
2444
+ result_type = _infer_spark_udf_return_type(model_output_schema)
2445
+ else:
2446
+ _logger.warning(
2447
+ "No 'result_type' provided for spark_udf and the model does not "
2448
+ "have an output schema. 'result_type' is set to 'double' type."
2449
+ )
2450
+ result_type = DoubleType()
2451
+ else:
2452
+ if isinstance(result_type, str):
2453
+ result_type = _parse_spark_datatype(result_type)
2454
+
2455
+ # if result type is inferred by MLflow, we don't need to check it
2456
+ if not _check_udf_return_type(result_type):
2457
+ raise MlflowException.invalid_parameter_value(
2458
+ f"""Invalid 'spark_udf' result type: {result_type}.
2459
+ It must be one of the following types:
2460
+ Primitive types:
2461
+ - int
2462
+ - long
2463
+ - float
2464
+ - double
2465
+ - string
2466
+ - boolean
2467
+ Compound types:
2468
+ - ND array of primitives / structs.
2469
+ - struct<field: primitive | array<primitive> | array<array<primitive>>, ...>:
2470
+ A struct with primitive, ND array<primitive/structs>,
2471
+ e.g., struct<a:int, b:array<int>>.
2472
+ """
2473
+ )
2474
+ params = _validate_params(params, model_metadata)
2475
+
2476
+ def _predict_row_batch(predict_fn, args):
2477
+ input_schema = model_metadata.get_input_schema()
2478
+ args = list(args)
2479
+ if len(args) == 1 and isinstance(args[0], pandas.DataFrame):
2480
+ pdf = args[0]
2481
+ else:
2482
+ if input_schema is None:
2483
+ names = [str(i) for i in range(len(args))]
2484
+ else:
2485
+ names = input_schema.input_names()
2486
+ required_names = input_schema.required_input_names()
2487
+ if len(args) > len(names):
2488
+ args = args[: len(names)]
2489
+ if len(args) < len(required_names):
2490
+ raise MlflowException(
2491
+ f"Model input is missing required columns. Expected {len(names)} required"
2492
+ f" input columns {names}, but the model received only {len(args)} "
2493
+ "unnamed input columns (Since the columns were passed unnamed they are"
2494
+ " expected to be in the order specified by the schema)."
2495
+ )
2496
+ pdf = pandas.DataFrame(
2497
+ data={
2498
+ names[i]: arg
2499
+ if isinstance(arg, pandas.Series)
2500
+ # pandas_udf receives a StructType column as a pandas DataFrame.
2501
+ # We need to convert it back to a dict of pandas Series.
2502
+ else arg.apply(lambda row: row.to_dict(), axis=1)
2503
+ for i, arg in enumerate(args)
2504
+ },
2505
+ columns=names,
2506
+ )
2507
+
2508
+ result = predict_fn(pdf, params)
2509
+
2510
+ if isinstance(result, dict):
2511
+ result = {k: list(v) for k, v in result.items()}
2512
+
2513
+ if isinstance(result_type, ArrayType) and isinstance(result_type.elementType, ArrayType):
2514
+ result_values = _convert_array_values(result, result_type)
2515
+ return pandas.Series(result_values)
2516
+
2517
+ if isinstance(result_type, SparkStructType):
2518
+ if (
2519
+ isinstance(result, list)
2520
+ and len(result) > 0
2521
+ and isinstance(result[0], pydantic.BaseModel)
2522
+ ):
2523
+ result = pandas.DataFrame([model_dump_compat(r) for r in result])
2524
+ else:
2525
+ result = pandas.DataFrame(result)
2526
+ return _convert_struct_values(result, result_type)
2527
+
2528
+ if not isinstance(result, pandas.DataFrame):
2529
+ if isinstance(result_type, MapType):
2530
+ # list of dicts should be converted into a single column
2531
+ result = pandas.DataFrame([result])
2532
+ else:
2533
+ result = (
2534
+ pandas.DataFrame([result]) if np.isscalar(result) else pandas.DataFrame(result)
2535
+ )
2536
+
2537
+ elem_type = result_type.elementType if isinstance(result_type, ArrayType) else result_type
2538
+ if type(elem_type) == IntegerType:
2539
+ result = result.select_dtypes(
2540
+ [np.byte, np.ubyte, np.short, np.ushort, np.int32]
2541
+ ).astype(np.int32)
2542
+
2543
+ elif type(elem_type) == LongType:
2544
+ result = result.select_dtypes([np.byte, np.ubyte, np.short, np.ushort, int]).astype(
2545
+ np.int64
2546
+ )
2547
+
2548
+ elif type(elem_type) == FloatType:
2549
+ result = result.select_dtypes(include=(np.number,)).astype(np.float32)
2550
+
2551
+ elif type(elem_type) == DoubleType:
2552
+ result = result.select_dtypes(include=(np.number,)).astype(np.float64)
2553
+
2554
+ elif type(elem_type) == BooleanType:
2555
+ result = result.select_dtypes([bool, np.bool_]).astype(bool)
2556
+
2557
+ if len(result.columns) == 0:
2558
+ raise MlflowException(
2559
+ message="The model did not produce any values compatible with the requested "
2560
+ f"type '{elem_type}'. Consider requesting udf with StringType or "
2561
+ "Arraytype(StringType).",
2562
+ error_code=INVALID_PARAMETER_VALUE,
2563
+ )
2564
+
2565
+ if type(elem_type) == StringType:
2566
+ if Version(pandas.__version__) >= Version("2.1.0"):
2567
+ result = result.map(str)
2568
+ else:
2569
+ result = result.applymap(str)
2570
+
2571
+ if type(result_type) == ArrayType:
2572
+ return pandas.Series(result.to_numpy().tolist())
2573
+ else:
2574
+ return result[result.columns[0]]
2575
+
2576
+ result_type_hint = (
2577
+ pandas.DataFrame if isinstance(result_type, SparkStructType) else pandas.Series
2578
+ )
2579
+
2580
+ tracking_uri = mlflow.get_tracking_uri()
2581
+
2582
+ @pandas_udf(result_type)
2583
+ def udf(
2584
+ iterator: Iterator[Tuple[Union[pandas.Series, pandas.DataFrame], ...]], # noqa: UP006
2585
+ ) -> Iterator[result_type_hint]:
2586
+ # importing here to prevent circular import
2587
+ from mlflow.pyfunc.scoring_server.client import (
2588
+ ScoringServerClient,
2589
+ StdinScoringServerClient,
2590
+ )
2591
+
2592
+ # Note: this is a pandas udf function in iteration style, which takes an iterator of
2593
+ # tuple of pandas.Series and outputs an iterator of pandas.Series.
2594
+ update_envs = {}
2595
+ if mlflow_home is not None:
2596
+ update_envs["MLFLOW_HOME"] = mlflow_home
2597
+ if openai_env_vars:
2598
+ update_envs.update(openai_env_vars)
2599
+ if mlflow_testing:
2600
+ update_envs[_MLFLOW_TESTING.name] = mlflow_testing
2601
+ if extra_env:
2602
+ update_envs.update(extra_env)
2603
+
2604
+ # use `modified_environ` to temporarily set the envs and restore them finally
2605
+ with modified_environ(update=update_envs):
2606
+ scoring_server_proc = None
2607
+ # set tracking_uri inside udf so that with spark_connect
2608
+ # we can load the model from correct path
2609
+ mlflow.set_tracking_uri(tracking_uri)
2610
+
2611
+ if env_manager != _EnvManager.LOCAL:
2612
+ if use_dbconnect_artifact:
2613
+ local_model_path_on_executor = (
2614
+ dbconnect_artifact_cache.get_unpacked_artifact_dir(model_uri)
2615
+ )
2616
+ env_src_dir = dbconnect_artifact_cache.get_unpacked_artifact_dir(env_cache_key)
2617
+
2618
+ # Create symlink if it does not exist
2619
+ if not os.path.exists(prebuilt_env_root_dir):
2620
+ os.symlink(env_src_dir, prebuilt_env_root_dir)
2621
+ elif prebuilt_env_uri is not None:
2622
+ # prebuilt env is extracted to `prebuilt_env_nfs_dir` directory,
2623
+ # and model is downloaded to `local_model_path` which points to an NFS
2624
+ # directory too.
2625
+ local_model_path_on_executor = None
2626
+
2627
+ # Create symlink if it does not exist
2628
+ if not os.path.exists(prebuilt_env_root_dir):
2629
+ os.symlink(prebuilt_env_nfs_dir, prebuilt_env_root_dir)
2630
+ elif should_use_spark_to_broadcast_file:
2631
+ local_model_path_on_executor = _SparkDirectoryDistributor.get_or_extract(
2632
+ archive_path
2633
+ )
2634
+ # Call "prepare_env" in advance in order to reduce scoring server launch time.
2635
+ # So that we can use a shorter timeout when call `client.wait_server_ready`,
2636
+ # otherwise we have to set a long timeout for `client.wait_server_ready` time,
2637
+ # this prevents spark UDF task failing fast if other exception raised
2638
+ # when scoring server launching.
2639
+ # Set "capture_output" so that if "conda env create" command failed, the command
2640
+ # stdout/stderr output will be attached to the exception message and included in
2641
+ # driver side exception.
2642
+ pyfunc_backend.prepare_env(
2643
+ model_uri=local_model_path_on_executor, capture_output=True
2644
+ )
2645
+ else:
2646
+ local_model_path_on_executor = None
2647
+
2648
+ if check_port_connectivity():
2649
+ # launch scoring server
2650
+ server_port = find_free_port()
2651
+ host = "127.0.0.1"
2652
+ scoring_server_proc = pyfunc_backend.serve(
2653
+ model_uri=local_model_path_on_executor or local_model_path,
2654
+ port=server_port,
2655
+ host=host,
2656
+ timeout=MLFLOW_SCORING_SERVER_REQUEST_TIMEOUT.get(),
2657
+ enable_mlserver=False,
2658
+ synchronous=False,
2659
+ stdout=subprocess.PIPE,
2660
+ stderr=subprocess.STDOUT,
2661
+ model_config=model_config,
2662
+ )
2663
+
2664
+ client = ScoringServerClient(host, server_port)
2665
+ else:
2666
+ scoring_server_proc = pyfunc_backend.serve_stdin(
2667
+ model_uri=local_model_path_on_executor or local_model_path,
2668
+ stdout=subprocess.PIPE,
2669
+ stderr=subprocess.STDOUT,
2670
+ model_config=model_config,
2671
+ )
2672
+ client = StdinScoringServerClient(scoring_server_proc)
2673
+
2674
+ _logger.info("Using %s", client.__class__.__name__)
2675
+
2676
+ server_tail_logs = collections.deque(
2677
+ maxlen=_MLFLOW_SERVER_OUTPUT_TAIL_LINES_TO_KEEP
2678
+ )
2679
+
2680
+ def server_redirect_log_thread_func(child_stdout):
2681
+ for line in child_stdout:
2682
+ decoded = line.decode() if isinstance(line, bytes) else line
2683
+ server_tail_logs.append(decoded)
2684
+ sys.stdout.write("[model server] " + decoded)
2685
+
2686
+ server_redirect_log_thread = threading.Thread(
2687
+ target=server_redirect_log_thread_func,
2688
+ args=(scoring_server_proc.stdout,),
2689
+ daemon=True,
2690
+ name=f"mlflow_pyfunc_model_server_log_redirector_{uuid.uuid4().hex[:8]}",
2691
+ )
2692
+ server_redirect_log_thread.start()
2693
+
2694
+ try:
2695
+ client.wait_server_ready(timeout=90, scoring_server_proc=scoring_server_proc)
2696
+ except Exception as e:
2697
+ err_msg = (
2698
+ "During spark UDF task execution, mlflow model server failed to launch. "
2699
+ )
2700
+ if len(server_tail_logs) == _MLFLOW_SERVER_OUTPUT_TAIL_LINES_TO_KEEP:
2701
+ err_msg += (
2702
+ f"Last {_MLFLOW_SERVER_OUTPUT_TAIL_LINES_TO_KEEP} "
2703
+ "lines of MLflow model server output:\n"
2704
+ )
2705
+ else:
2706
+ err_msg += "MLflow model server output:\n"
2707
+ err_msg += "".join(server_tail_logs)
2708
+ raise MlflowException(err_msg) from e
2709
+
2710
+ def batch_predict_fn(pdf, params=None):
2711
+ if "params" in inspect.signature(client.invoke).parameters:
2712
+ return client.invoke(pdf, params=params).get_predictions()
2713
+ _log_warning_if_params_not_in_predict_signature(_logger, params)
2714
+ return client.invoke(pdf).get_predictions()
2715
+
2716
+ elif env_manager == _EnvManager.LOCAL:
2717
+ if use_dbconnect_artifact:
2718
+ model_path = dbconnect_artifact_cache.get_unpacked_artifact_dir(model_uri)
2719
+ loaded_model = mlflow.pyfunc.load_model(model_path, model_config=model_config)
2720
+ elif is_spark_connect and not should_spark_connect_use_nfs:
2721
+ model_path = os.path.join(
2722
+ tempfile.gettempdir(),
2723
+ "mlflow",
2724
+ hashlib.sha1(model_uri.encode(), usedforsecurity=False).hexdigest(),
2725
+ # Use pid to avoid conflict when multiple spark UDF tasks
2726
+ str(os.getpid()),
2727
+ )
2728
+ try:
2729
+ loaded_model = mlflow.pyfunc.load_model(
2730
+ model_path, model_config=model_config
2731
+ )
2732
+ except Exception:
2733
+ os.makedirs(model_path, exist_ok=True)
2734
+ loaded_model = mlflow.pyfunc.load_model(
2735
+ model_uri, dst_path=model_path, model_config=model_config
2736
+ )
2737
+ elif should_use_spark_to_broadcast_file:
2738
+ loaded_model, _ = SparkModelCache.get_or_load(archive_path)
2739
+ else:
2740
+ loaded_model = mlflow.pyfunc.load_model(
2741
+ local_model_path, model_config=model_config
2742
+ )
2743
+
2744
+ def batch_predict_fn(pdf, params=None):
2745
+ if "params" in inspect.signature(loaded_model.predict).parameters:
2746
+ return loaded_model.predict(pdf, params=params)
2747
+ _log_warning_if_params_not_in_predict_signature(_logger, params)
2748
+ return loaded_model.predict(pdf)
2749
+
2750
+ try:
2751
+ for input_batch in iterator:
2752
+ # If the UDF is called with only multiple arguments,
2753
+ # the `input_batch` is a tuple which composes of several pd.Series/pd.DataFrame
2754
+ # objects.
2755
+ # If the UDF is called with only one argument,
2756
+ # the `input_batch` instance will be an instance of `pd.Series`/`pd.DataFrame`,
2757
+ if isinstance(input_batch, (pandas.Series, pandas.DataFrame)):
2758
+ # UDF is called with only one argument
2759
+ row_batch_args = (input_batch,)
2760
+ else:
2761
+ row_batch_args = input_batch
2762
+
2763
+ if len(row_batch_args[0]) > 0:
2764
+ yield _predict_row_batch(batch_predict_fn, row_batch_args)
2765
+ except SystemError as e:
2766
+ if "error return without exception set" in str(e):
2767
+ raise MlflowException(
2768
+ "A system error related to the Python C extension has occurred. "
2769
+ "This is usually caused by an incompatible Python library that uses the "
2770
+ "C extension. To address this, we recommend you to log the model "
2771
+ "with fixed version python libraries that use the C extension "
2772
+ "(such as 'numpy' library), and set spark_udf `env_manager` argument "
2773
+ "to 'virtualenv' or 'uv' so that spark_udf can restore the original "
2774
+ "python library version before running model inference."
2775
+ ) from e
2776
+ finally:
2777
+ if scoring_server_proc is not None:
2778
+ os.kill(scoring_server_proc.pid, signal.SIGTERM)
2779
+
2780
+ udf.metadata = model_metadata
2781
+
2782
+ @functools.wraps(udf)
2783
+ def udf_with_default_cols(*args):
2784
+ if len(args) == 0:
2785
+ input_schema = model_metadata.get_input_schema()
2786
+ if input_schema and len(input_schema.optional_input_names()) > 0:
2787
+ raise MlflowException(
2788
+ message="Cannot apply UDF without column names specified when"
2789
+ " model signature contains optional columns.",
2790
+ error_code=INVALID_PARAMETER_VALUE,
2791
+ )
2792
+ if input_schema and len(input_schema.inputs) > 0:
2793
+ if input_schema.has_input_names():
2794
+ input_names = input_schema.input_names()
2795
+ return udf(*input_names)
2796
+ else:
2797
+ raise MlflowException(
2798
+ message="Cannot apply udf because no column names specified. The udf "
2799
+ f"expects {len(input_schema.inputs)} columns with types: "
2800
+ "{input_schema.inputs}. Input column names could not be inferred from the"
2801
+ " model signature (column names not found).",
2802
+ error_code=INVALID_PARAMETER_VALUE,
2803
+ )
2804
+ else:
2805
+ raise MlflowException(
2806
+ "Attempting to apply udf on zero columns because no column names were "
2807
+ "specified as arguments or inferred from the model signature.",
2808
+ error_code=INVALID_PARAMETER_VALUE,
2809
+ )
2810
+ else:
2811
+ return udf(*args)
2812
+
2813
+ return udf_with_default_cols
2814
+
2815
+
2816
def _validate_function_python_model(python_model):
    """Validate that ``python_model`` is usable as a pyfunc model definition.

    Accepts either a :class:`~PythonModel` instance or a callable object. A
    callable is additionally required to accept exactly one argument (the
    model input).

    Args:
        python_model: The candidate model object supplied by the user.

    Raises:
        MlflowException: With ``INVALID_PARAMETER_VALUE`` if the object is
            neither a PythonModel instance nor callable, or if it is a
            callable whose signature does not take exactly one argument.
    """
    is_model_instance = isinstance(python_model, PythonModel)
    is_callable = callable(python_model)

    if not is_model_instance and not is_callable:
        raise MlflowException(
            "`python_model` must be a PythonModel instance, callable object, or path to a script "
            "that uses set_model() to set a PythonModel instance or callable object.",
            error_code=INVALID_PARAMETER_VALUE,
        )

    if is_callable:
        # Functional models receive a single positional input; reject any
        # other arity up front so the failure is clear at save time.
        num_args = len(inspect.signature(python_model).parameters)
        if num_args != 1:
            raise MlflowException(
                "When `python_model` is a callable object, it must accept exactly one argument. "
                f"Found {num_args} arguments.",
                error_code=INVALID_PARAMETER_VALUE,
            )
2832
+
2833
+
2834
+ @format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="scikit-learn"))
2835
+ @trace_disabled # Suppress traces for internal predict calls while saving model
2836
+ def save_model(
2837
+ path,
2838
+ loader_module=None,
2839
+ data_path=None,
2840
+ code_paths=None,
2841
+ infer_code_paths=False,
2842
+ conda_env=None,
2843
+ mlflow_model=None,
2844
+ python_model=None,
2845
+ artifacts=None,
2846
+ signature: ModelSignature = None,
2847
+ input_example: ModelInputExample = None,
2848
+ pip_requirements=None,
2849
+ extra_pip_requirements=None,
2850
+ metadata=None,
2851
+ model_config=None,
2852
+ streamable=None,
2853
+ resources: Optional[Union[str, list[Resource]]] = None,
2854
+ auth_policy: Optional[AuthPolicy] = None,
2855
+ **kwargs,
2856
+ ):
2857
+ """
2858
+ Save a Pyfunc model with custom inference logic and optional data dependencies to a path on the
2859
+ local filesystem.
2860
+
2861
+ For information about the workflows that this method supports, please see :ref:`"workflows for
2862
+ creating custom pyfunc models" <pyfunc-create-custom-workflows>` and
2863
+ :ref:`"which workflow is right for my use case?" <pyfunc-create-custom-selecting-workflow>`.
2864
+ Note that the parameters for the second workflow: ``loader_module``, ``data_path`` and the
2865
+ parameters for the first workflow: ``python_model``, ``artifacts``, cannot be
2866
+ specified together.
2867
+
2868
+ Args:
2869
+ path: The path to which to save the Python model.
2870
+ loader_module: The name of the Python module that is used to load the model
2871
+ from ``data_path``. This module must define a method with the prototype
2872
+ ``_load_pyfunc(data_path)``. If not ``None``, this module and its
2873
+ dependencies must be included in one of the following locations:
2874
+
2875
+ - The MLflow library.
2876
+ - Package(s) listed in the model's Conda environment, specified by
2877
+ the ``conda_env`` parameter.
2878
+ - One or more of the files specified by the ``code_paths`` parameter.
2879
+
2880
+ data_path: Path to a file or directory containing model data.
2881
+ code_paths: {{ code_paths_pyfunc }}
2882
+ infer_code_paths: {{ infer_code_paths }}
2883
+ conda_env: {{ conda_env }}
2884
+ mlflow_model: :py:mod:`mlflow.models.Model` configuration to which to add the
2885
+ **python_function** flavor.
2886
+ python_model:
2887
+ An instance of a subclass of :class:`~PythonModel` or a callable object with a single
2888
+ argument (see the examples below). The passed-in object is serialized using the
2889
+ CloudPickle library. The python_model can also be a file path to the PythonModel
2890
+ which defines the model from code artifact rather than serializing the model object.
2891
+ Any dependencies of the class should be included in one of the
2892
+ following locations:
2893
+
2894
+ - The MLflow library.
2895
+ - Package(s) listed in the model's Conda environment, specified by the ``conda_env``
2896
+ parameter.
2897
+ - One or more of the files specified by the ``code_paths`` parameter.
2898
+
2899
+ Note: If the class is imported from another module, as opposed to being defined in the
2900
+ ``__main__`` scope, the defining module should also be included in one of the listed
2901
+ locations.
2902
+
2903
+ **Examples**
2904
+
2905
+ Class model
2906
+
2907
+ .. code-block:: python
2908
+
2909
+ from typing import List, Dict
2910
+ import mlflow
2911
+
2912
+
2913
+ class MyModel(mlflow.pyfunc.PythonModel):
2914
+ def predict(self, context, model_input: List[str], params=None) -> List[str]:
2915
+ return [i.upper() for i in model_input]
2916
+
2917
+
2918
+ mlflow.pyfunc.save_model("model", python_model=MyModel(), input_example=["a"])
2919
+ model = mlflow.pyfunc.load_model("model")
2920
+ print(model.predict(["a", "b", "c"])) # -> ["A", "B", "C"]
2921
+
2922
+ Functional model
2923
+
2924
+ .. note::
2925
+ Experimental: Functional model support is experimental and may change or be removed
2926
+ in a future release without warning.
2927
+
2928
+ .. code-block:: python
2929
+
2930
+ from typing import List
2931
+ import mlflow
2932
+
2933
+
2934
+ def predict(model_input: List[str]) -> List[str]:
2935
+ return [i.upper() for i in model_input]
2936
+
2937
+
2938
+ mlflow.pyfunc.save_model("model", python_model=predict, input_example=["a"])
2939
+ model = mlflow.pyfunc.load_model("model")
2940
+ print(model.predict(["a", "b", "c"])) # -> ["A", "B", "C"]
2941
+
2942
+ Model from code
2943
+
2944
+ .. note::
2945
+ Experimental: Model from code model support is experimental and may change or
2946
+ be removed in a future release without warning.
2947
+
2948
+ .. code-block:: python
2949
+
2950
+ # code.py
2951
+ from typing import List
2952
+ import mlflow
2953
+
2954
+
2955
+ class MyModel(mlflow.pyfunc.PythonModel):
2956
+ def predict(self, context, model_input: List[str], params=None) -> List[str]:
2957
+ return [i.upper() for i in model_input]
2958
+
2959
+
2960
+ mlflow.models.set_model(MyModel())
2961
+
2962
+ # log_model.py
2963
+ import mlflow
2964
+
2965
+ with mlflow.start_run():
2966
+ model_info = mlflow.pyfunc.log_model(
2967
+ name="model",
2968
+ python_model="code.py",
2969
+ )
2970
+
2971
+ If the `predict` method or function has type annotations, MLflow automatically
2972
+ constructs a model signature based on the type annotations (unless the ``signature``
2973
+ argument is explicitly specified), and converts the input value to the specified type
2974
+ before passing it to the function. Currently, the following type annotations are
2975
+ supported:
2976
+
2977
+ - ``List[str]``
2978
+ - ``List[Dict[str, str]]``
2979
+
2980
+ artifacts: A dictionary containing ``<name, artifact_uri>`` entries. Remote artifact URIs
2981
+ are resolved to absolute filesystem paths, producing a dictionary of
2982
+ ``<name, absolute_path>`` entries. ``python_model`` can reference these
2983
+ resolved entries as the ``artifacts`` property of the ``context`` parameter
2984
+ in :func:`PythonModel.load_context() <mlflow.pyfunc.PythonModel.load_context>`
2985
+ and :func:`PythonModel.predict() <mlflow.pyfunc.PythonModel.predict>`.
2986
+ For example, consider the following ``artifacts`` dictionary::
2987
+
2988
+ {"my_file": "s3://my-bucket/path/to/my/file"}
2989
+
2990
+ In this case, the ``"my_file"`` artifact is downloaded from S3. The
2991
+ ``python_model`` can then refer to ``"my_file"`` as an absolute filesystem
2992
+ path via ``context.artifacts["my_file"]``.
2993
+
2994
+ If ``None``, no artifacts are added to the model.
2995
+
2996
+ signature: :py:class:`ModelSignature <mlflow.models.ModelSignature>`
2997
+ describes model input and output :py:class:`Schema <mlflow.types.Schema>`.
2998
+ The model signature can be :py:func:`inferred <mlflow.models.infer_signature>`
2999
+ from datasets with valid model input (e.g. the training dataset with target
3000
+ column omitted) and valid model output (e.g. model predictions generated on
3001
+ the training dataset), for example:
3002
+
3003
+ .. code-block:: python
3004
+
3005
+ from mlflow.models import infer_signature
3006
+
3007
+ train = df.drop_column("target_label")
3008
+ predictions = ... # compute model predictions
3009
+ signature = infer_signature(train, predictions)
3010
+ input_example: {{ input_example }}
3011
+ pip_requirements: {{ pip_requirements }}
3012
+ extra_pip_requirements: {{ extra_pip_requirements }}
3013
+ metadata: {{ metadata }}
3014
+ model_config: The model configuration to apply to the model. The configuration will
3015
+ be available as the ``model_config`` property of the ``context`` parameter
3016
+ in :func:`PythonModel.load_context() <mlflow.pyfunc.PythonModel.load_context>`
3017
+ and :func:`PythonModel.predict() <mlflow.pyfunc.PythonModel.predict>`.
3018
+ The configuration can be passed as a file path, or a dict with string keys.
3019
+
3020
+ .. Note:: Experimental: This parameter may change or be removed in a future
3021
+ release without warning.
3022
+ streamable: A boolean value indicating if the model supports streaming prediction,
3023
+ If None, MLflow will try to inspect if the model supports streaming
3024
+ by checking if `predict_stream` method exists. Default None.
3025
+ resources: A list of model resources or a resources.yaml file containing a list of
3026
+ resources required to serve the model.
3027
+
3028
+ .. Note:: Experimental: This parameter may change or be removed in a future
3029
+ release without warning.
3030
+ auth_policy: {{ auth_policy }}
3031
+ kwargs: Extra keyword arguments.
3032
+ """
3033
+ _validate_env_arguments(conda_env, pip_requirements, extra_pip_requirements)
3034
+ _validate_pyfunc_model_config(model_config)
3035
+ _validate_and_prepare_target_save_path(path)
3036
+
3037
+ with tempfile.TemporaryDirectory() as temp_dir:
3038
+ model_code_path = None
3039
+ if python_model:
3040
+ if isinstance(model_config, Path):
3041
+ model_config = os.fspath(model_config)
3042
+
3043
+ if isinstance(model_config, str):
3044
+ model_config = _validate_and_get_model_config_from_file(model_config)
3045
+
3046
+ if isinstance(python_model, Path):
3047
+ python_model = os.fspath(python_model)
3048
+
3049
+ if isinstance(python_model, str):
3050
+ model_code_path = _validate_and_get_model_code_path(python_model, temp_dir)
3051
+ _validate_and_copy_file_to_directory(model_code_path, path, "code")
3052
+ python_model = _load_model_code_path(model_code_path, model_config)
3053
+
3054
+ _validate_function_python_model(python_model)
3055
+ if callable(python_model) and all(
3056
+ a is None for a in (input_example, pip_requirements, extra_pip_requirements)
3057
+ ):
3058
+ raise MlflowException(
3059
+ "If `python_model` is a callable object, at least one of `input_example`, "
3060
+ "`pip_requirements`, or `extra_pip_requirements` must be specified."
3061
+ )
3062
+
3063
+ mlflow_model = kwargs.pop("model", mlflow_model)
3064
+ if len(kwargs) > 0:
3065
+ raise TypeError(f"save_model() got unexpected keyword arguments: {kwargs}")
3066
+
3067
+ if code_paths is not None:
3068
+ if not isinstance(code_paths, list):
3069
+ raise TypeError(f"Argument code_paths should be a list, not {type(code_paths)}")
3070
+
3071
+ first_argument_set = {
3072
+ "loader_module": loader_module,
3073
+ "data_path": data_path,
3074
+ }
3075
+ second_argument_set = {
3076
+ "artifacts": artifacts,
3077
+ "python_model": python_model,
3078
+ }
3079
+ first_argument_set_specified = any(item is not None for item in first_argument_set.values())
3080
+ second_argument_set_specified = any(item is not None for item in second_argument_set.values())
3081
+ if first_argument_set_specified and second_argument_set_specified:
3082
+ raise MlflowException(
3083
+ message=(
3084
+ f"The following sets of parameters cannot be specified together:"
3085
+ f" {first_argument_set.keys()} and {second_argument_set.keys()}."
3086
+ " All parameters in one set must be `None`. Instead, found"
3087
+ f" the following values: {first_argument_set} and {second_argument_set}"
3088
+ ),
3089
+ error_code=INVALID_PARAMETER_VALUE,
3090
+ )
3091
+ elif (loader_module is None) and (python_model is None):
3092
+ msg = (
3093
+ "Either `loader_module` or `python_model` must be specified. A `loader_module` "
3094
+ "should be a python module. A `python_model` should be a subclass of PythonModel"
3095
+ )
3096
+ raise MlflowException(message=msg, error_code=INVALID_PARAMETER_VALUE)
3097
+ if mlflow_model is None:
3098
+ mlflow_model = Model()
3099
+ saved_example = None
3100
+ signature_from_type_hints = None
3101
+ type_hint_from_example = None
3102
+ if isinstance(python_model, ChatModel):
3103
+ if signature is not None:
3104
+ raise MlflowException(
3105
+ "ChatModel subclasses have a standard signature that is set "
3106
+ "automatically. Please remove the `signature` parameter from "
3107
+ "the call to log_model() or save_model().",
3108
+ error_code=INVALID_PARAMETER_VALUE,
3109
+ )
3110
+ mlflow_model.signature = ModelSignature(
3111
+ CHAT_MODEL_INPUT_SCHEMA,
3112
+ CHAT_MODEL_OUTPUT_SCHEMA,
3113
+ )
3114
+ # For ChatModel we set default metadata to indicate its task
3115
+ default_metadata = {TASK: _DEFAULT_CHAT_MODEL_METADATA_TASK}
3116
+ mlflow_model.metadata = default_metadata | (mlflow_model.metadata or {})
3117
+
3118
+ if input_example:
3119
+ input_example, input_params = _split_input_data_and_params(input_example)
3120
+ valid_params = {}
3121
+ if isinstance(input_example, list):
3122
+ messages = [
3123
+ message if isinstance(message, ChatMessage) else ChatMessage.from_dict(message)
3124
+ for message in input_example
3125
+ ]
3126
+ else:
3127
+ # If the input example is a dictionary, convert it to ChatMessage format
3128
+ messages = [
3129
+ ChatMessage.from_dict(m) if isinstance(m, dict) else m
3130
+ for m in input_example["messages"]
3131
+ ]
3132
+ valid_params = {
3133
+ k: v
3134
+ for k, v in input_example.items()
3135
+ if k != "messages" and k in ChatParams.keys()
3136
+ }
3137
+ if valid_params or input_params:
3138
+ _logger.warning(_CHAT_PARAMS_WARNING_MESSAGE)
3139
+ input_example = {
3140
+ "messages": [m.to_dict() for m in messages],
3141
+ **valid_params,
3142
+ **(input_params or {}),
3143
+ }
3144
+ else:
3145
+ input_example = CHAT_MODEL_INPUT_EXAMPLE
3146
+ _logger.warning(_CHAT_PARAMS_WARNING_MESSAGE)
3147
+ messages = [ChatMessage.from_dict(m) for m in input_example["messages"]]
3148
+ # extra params introduced by ChatParams will not be included in the
3149
+ # logged input example file to avoid confusion
3150
+ _save_example(mlflow_model, input_example, path)
3151
+ params = ChatParams.from_dict(input_example)
3152
+
3153
+ # call load_context() first, as predict may depend on it
3154
+ _logger.info("Predicting on input example to validate output")
3155
+ context = PythonModelContext(artifacts, model_config)
3156
+ python_model.load_context(context)
3157
+ if "context" in inspect.signature(python_model.predict).parameters:
3158
+ output = python_model.predict(context, messages, params)
3159
+ else:
3160
+ output = python_model.predict(messages, params)
3161
+ if not isinstance(output, ChatCompletionResponse):
3162
+ raise MlflowException(
3163
+ "Failed to save ChatModel. Please ensure that the model's predict() method "
3164
+ "returns a ChatCompletionResponse object. If your predict() method currently "
3165
+ "returns a dict, you can instantiate a ChatCompletionResponse using "
3166
+ "`from_dict()`, e.g. `ChatCompletionResponse.from_dict(output)`",
3167
+ )
3168
+ elif isinstance(python_model, ChatAgent):
3169
+ input_example = _save_model_chat_agent_helper(
3170
+ python_model, mlflow_model, signature, input_example
3171
+ )
3172
+ elif IS_RESPONSES_AGENT_AVAILABLE and isinstance(python_model, ResponsesAgent):
3173
+ input_example = _save_model_responses_agent_helper(
3174
+ python_model, mlflow_model, signature, input_example
3175
+ )
3176
+ elif callable(python_model) or isinstance(python_model, PythonModel):
3177
+ model_for_signature_inference = None
3178
+ predict_func = None
3179
+ if callable(python_model):
3180
+ # first argument is the model input
3181
+ type_hints = _extract_type_hints(python_model, input_arg_index=0)
3182
+ pyfunc_decorator_used = getattr(python_model, "_is_pyfunc", False)
3183
+ # only show the warning here if @pyfunc is not applied on the function
3184
+ # since @pyfunc will trigger the warning instead
3185
+ if type_hints.input is None and not pyfunc_decorator_used:
3186
+ color_warning(
3187
+ "Add type hints to the `predict` method to enable "
3188
+ "data validation and automatic signature inference. Check "
3189
+ "https://mlflow.org/docs/latest/model/python_model.html#type-hint-usage-in-pythonmodel"
3190
+ " for more details.",
3191
+ stacklevel=1,
3192
+ color="yellow",
3193
+ )
3194
+ model_for_signature_inference = _FunctionPythonModel(python_model)
3195
+ predict_func = python_model
3196
+ elif isinstance(python_model, PythonModel):
3197
+ type_hints = python_model.predict_type_hints
3198
+ model_for_signature_inference = python_model
3199
+ predict_func = python_model.predict
3200
+ # Load context before calling predict to ensure necessary artifacts are available
3201
+ context = PythonModelContext(artifacts, model_config)
3202
+ model_for_signature_inference.load_context(context)
3203
+ type_hint_from_example = _is_type_hint_from_example(type_hints.input)
3204
+ if type_hint_from_example:
3205
+ should_infer_signature_from_type_hints = False
3206
+ else:
3207
+ should_infer_signature_from_type_hints = (
3208
+ not _signature_cannot_be_inferred_from_type_hint(type_hints.input)
3209
+ )
3210
+ if should_infer_signature_from_type_hints:
3211
+ signature_from_type_hints = _infer_signature_from_type_hints(
3212
+ func=predict_func,
3213
+ type_hints=type_hints,
3214
+ input_example=input_example,
3215
+ )
3216
+ # only infer signature based on input example when signature
3217
+ # and type hints are not provided
3218
+ if signature is None and signature_from_type_hints is None:
3219
+ saved_example = _save_example(mlflow_model, input_example, path)
3220
+ if saved_example is not None:
3221
+ _logger.info("Inferring model signature from input example")
3222
+ try:
3223
+ mlflow_model.signature = _infer_signature_from_input_example(
3224
+ saved_example,
3225
+ _PythonModelPyfuncWrapper(model_for_signature_inference, None, None),
3226
+ )
3227
+ except Exception as e:
3228
+ _logger.warning(
3229
+ f"Failed to infer model signature from input example, error: {e}",
3230
+ )
3231
+ else:
3232
+ if type_hint_from_example and mlflow_model.signature:
3233
+ update_signature_for_type_hint_from_example(
3234
+ input_example, mlflow_model.signature
3235
+ )
3236
+ else:
3237
+ if type_hint_from_example:
3238
+ _logger.warning(
3239
+ _TYPE_FROM_EXAMPLE_ERROR_MESSAGE,
3240
+ extra={"color": "red"},
3241
+ )
3242
+ # if signature is inferred from type hints, warnings are emitted
3243
+ # in _infer_signature_from_type_hints
3244
+ elif not should_infer_signature_from_type_hints:
3245
+ _logger.warning(
3246
+ "Failed to infer model signature: "
3247
+ f"Type hint {type_hints} cannot be used to infer model signature and "
3248
+ "input example is not provided, model signature cannot be inferred."
3249
+ )
3250
+
3251
+ if metadata is not None:
3252
+ mlflow_model.metadata = metadata
3253
+ if saved_example is None:
3254
+ saved_example = _save_example(mlflow_model, input_example, path)
3255
+
3256
+ if signature_from_type_hints:
3257
+ if signature and signature_from_type_hints != signature:
3258
+ # TODO: drop this support and raise exception in the next minor release since this
3259
+ # is a behavior change
3260
+ _logger.warning(
3261
+ "Provided signature does not match the signature inferred from the Python model's "
3262
+ "`predict` function type hint. Signature inferred from type hint will be used:\n"
3263
+ f"{signature_from_type_hints}\nRemove the `signature` parameter or ensure it "
3264
+ "matches the inferred signature. In a future release, this warning will become an "
3265
+ "exception, and the signature must align with the type hint.",
3266
+ extra={"color": "red"},
3267
+ )
3268
+ mlflow_model.signature = signature_from_type_hints
3269
+ elif signature:
3270
+ mlflow_model.signature = signature
3271
+ if type_hint_from_example:
3272
+ if saved_example is None:
3273
+ _logger.warning(
3274
+ _TYPE_FROM_EXAMPLE_ERROR_MESSAGE,
3275
+ extra={"color": "red"},
3276
+ )
3277
+ else:
3278
+ # TODO: validate input example against signature
3279
+ update_signature_for_type_hint_from_example(input_example, mlflow_model.signature)
3280
+ else:
3281
+ if saved_example is None:
3282
+ color_warning(
3283
+ message="An input example was not provided when logging the model. To ensure "
3284
+ "the model signature functions correctly, specify the `input_example` "
3285
+ "parameter. See "
3286
+ "https://mlflow.org/docs/latest/model/signatures.html#model-input-example "
3287
+ "for more details about the benefits of using input_example.",
3288
+ stacklevel=1,
3289
+ color="yellow_bold",
3290
+ )
3291
+ else:
3292
+ _logger.info("Validating input example against model signature")
3293
+ try:
3294
+ _validate_prediction_input(
3295
+ data=saved_example.inference_data,
3296
+ params=saved_example.inference_params,
3297
+ input_schema=signature.inputs,
3298
+ params_schema=signature.params,
3299
+ )
3300
+ except Exception as e:
3301
+ raise MlflowException.invalid_parameter_value(
3302
+ f"Input example does not match the model signature. {e}"
3303
+ )
3304
+
3305
+ with _get_dependencies_schemas() as dependencies_schemas:
3306
+ schema = dependencies_schemas.to_dict()
3307
+ if schema is not None:
3308
+ if mlflow_model.metadata is None:
3309
+ mlflow_model.metadata = {}
3310
+ mlflow_model.metadata.update(schema)
3311
+
3312
+ if resources is not None:
3313
+ if isinstance(resources, (Path, str)):
3314
+ serialized_resource = _ResourceBuilder.from_yaml_file(resources)
3315
+ else:
3316
+ serialized_resource = _ResourceBuilder.from_resources(resources)
3317
+
3318
+ mlflow_model.resources = serialized_resource
3319
+
3320
+ if auth_policy is not None:
3321
+ mlflow_model.auth_policy = auth_policy
3322
+
3323
+ if first_argument_set_specified:
3324
+ return _save_model_with_loader_module_and_data_path(
3325
+ path=path,
3326
+ loader_module=loader_module,
3327
+ data_path=data_path,
3328
+ code_paths=code_paths,
3329
+ conda_env=conda_env,
3330
+ mlflow_model=mlflow_model,
3331
+ pip_requirements=pip_requirements,
3332
+ extra_pip_requirements=extra_pip_requirements,
3333
+ model_config=model_config,
3334
+ streamable=streamable,
3335
+ infer_code_paths=infer_code_paths,
3336
+ )
3337
+ elif second_argument_set_specified:
3338
+ return mlflow.pyfunc.model._save_model_with_class_artifacts_params(
3339
+ path=path,
3340
+ signature=signature,
3341
+ python_model=python_model,
3342
+ artifacts=artifacts,
3343
+ conda_env=conda_env,
3344
+ code_paths=code_paths,
3345
+ mlflow_model=mlflow_model,
3346
+ pip_requirements=pip_requirements,
3347
+ extra_pip_requirements=extra_pip_requirements,
3348
+ model_config=model_config,
3349
+ streamable=streamable,
3350
+ model_code_path=model_code_path,
3351
+ infer_code_paths=infer_code_paths,
3352
+ )
3353
+
3354
+
3355
def update_signature_for_type_hint_from_example(input_example: Any, signature: ModelSignature):
    """Flag *signature* as derived from a ``TypeFromExample`` type hint.

    When the provided ``input_example`` is a kind that ``TypeFromExample`` supports,
    the private marker ``_is_type_hint_from_example`` is set on the signature so that
    downstream validation knows the schema came from the example. Otherwise a warning
    is logged and the signature is left untouched.

    Args:
        input_example: The input example supplied at save/log time.
        signature: The model signature to annotate in place.
    """
    # Unsupported example types cannot back a TypeFromExample hint; warn and bail out.
    if not _is_example_valid_for_type_from_example(input_example):
        _logger.warning(
            "Input example must be one of pandas.DataFrame, pandas.Series "
            f"or list when using TypeFromExample as type hint, got {type(input_example)}. "
            "Check https://mlflow.org/docs/latest/model/python_model.html#typefromexample-type-hint-usage"
            " for more details.",
        )
        return
    signature._is_type_hint_from_example = True
3365
+
3366
+
3367
@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="scikit-learn"))
@trace_disabled  # Suppress traces for internal predict calls while logging model
def log_model(
    artifact_path=None,
    loader_module=None,
    data_path=None,
    code_paths=None,
    infer_code_paths=False,
    conda_env=None,
    python_model=None,
    artifacts=None,
    registered_model_name=None,
    signature: ModelSignature = None,
    input_example: ModelInputExample = None,
    await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
    pip_requirements=None,
    extra_pip_requirements=None,
    metadata=None,
    model_config=None,
    streamable=None,
    resources: Optional[Union[str, list[Resource]]] = None,
    auth_policy: Optional[AuthPolicy] = None,
    prompts: Optional[list[Union[str, Prompt]]] = None,
    name=None,
    params: Optional[dict[str, Any]] = None,
    tags: Optional[dict[str, Any]] = None,
    model_type: Optional[str] = None,
    step: int = 0,
    model_id: Optional[str] = None,
):
    """
    Log a Pyfunc model with custom inference logic and optional data dependencies as an MLflow
    artifact for the current run.

    For information about the workflows that this method supports, see :ref:`Workflows for
    creating custom pyfunc models <pyfunc-create-custom-workflows>` and
    :ref:`Which workflow is right for my use case? <pyfunc-create-custom-selecting-workflow>`.
    You cannot specify the parameters for the second workflow: ``loader_module``, ``data_path``
    and the parameters for the first workflow: ``python_model``, ``artifacts`` together.

    Args:
        artifact_path: Deprecated. Use `name` instead.
        loader_module: The name of the Python module that is used to load the model
            from ``data_path``. This module must define a method with the prototype
            ``_load_pyfunc(data_path)``. If not ``None``, this module and its
            dependencies must be included in one of the following locations:

            - The MLflow library.
            - Package(s) listed in the model's Conda environment, specified by
              the ``conda_env`` parameter.
            - One or more of the files specified by the ``code_paths`` parameter.

        data_path: Path to a file or directory containing model data.
        code_paths: {{ code_paths_pyfunc }}
        infer_code_paths: {{ infer_code_paths }}
        conda_env: {{ conda_env }}
        python_model:
            An instance of a subclass of :class:`~PythonModel` or a callable object with a single
            argument (see the examples below). The passed-in object is serialized using the
            CloudPickle library. The python_model can also be a file path to the PythonModel
            which defines the model from code artifact rather than serializing the model object.
            Any dependencies of the class should be included in one of the
            following locations:

            - The MLflow library.
            - Package(s) listed in the model's Conda environment, specified by the ``conda_env``
              parameter.
            - One or more of the files specified by the ``code_paths`` parameter.

            Note: If the class is imported from another module, as opposed to being defined in the
            ``__main__`` scope, the defining module should also be included in one of the listed
            locations.

            **Examples**

            Class model

            .. code-block:: python

                from typing import List
                import mlflow


                class MyModel(mlflow.pyfunc.PythonModel):
                    def predict(self, context, model_input: List[str], params=None) -> List[str]:
                        return [i.upper() for i in model_input]


                with mlflow.start_run():
                    model_info = mlflow.pyfunc.log_model(
                        name="model",
                        python_model=MyModel(),
                    )


                loaded_model = mlflow.pyfunc.load_model(model_uri=model_info.model_uri)
                print(loaded_model.predict(["a", "b", "c"]))  # -> ["A", "B", "C"]

            Functional model

            .. note::
                Experimental: Functional model support is experimental and may change or be removed
                in a future release without warning.

            .. code-block:: python

                from typing import List
                import mlflow


                def predict(model_input: List[str]) -> List[str]:
                    return [i.upper() for i in model_input]


                with mlflow.start_run():
                    model_info = mlflow.pyfunc.log_model(
                        name="model", python_model=predict, input_example=["a"]
                    )


                loaded_model = mlflow.pyfunc.load_model(model_uri=model_info.model_uri)
                print(loaded_model.predict(["a", "b", "c"]))  # -> ["A", "B", "C"]

            Model from code

            .. note::
                Experimental: Model from code model support is experimental and may change or
                be removed in a future release without warning.

            .. code-block:: python

                # code.py
                from typing import List
                import mlflow


                class MyModel(mlflow.pyfunc.PythonModel):
                    def predict(self, context, model_input: List[str], params=None) -> List[str]:
                        return [i.upper() for i in model_input]


                mlflow.models.set_model(MyModel())

                # log_model.py
                import mlflow

                with mlflow.start_run():
                    model_info = mlflow.pyfunc.log_model(
                        name="model",
                        python_model="code.py",
                    )

            If the `predict` method or function has type annotations, MLflow automatically
            constructs a model signature based on the type annotations (unless the ``signature``
            argument is explicitly specified), and converts the input value to the specified type
            before passing it to the function. Currently, the following type annotations are
            supported:

            - ``List[str]``
            - ``List[Dict[str, str]]``

        artifacts: A dictionary containing ``<name, artifact_uri>`` entries. Remote artifact URIs
            are resolved to absolute filesystem paths, producing a dictionary of
            ``<name, absolute_path>`` entries. ``python_model`` can reference these
            resolved entries as the ``artifacts`` property of the ``context`` parameter
            in :func:`PythonModel.load_context() <mlflow.pyfunc.PythonModel.load_context>`
            and :func:`PythonModel.predict() <mlflow.pyfunc.PythonModel.predict>`.
            For example, consider the following ``artifacts`` dictionary::

                {"my_file": "s3://my-bucket/path/to/my/file"}

            In this case, the ``"my_file"`` artifact is downloaded from S3. The
            ``python_model`` can then refer to ``"my_file"`` as an absolute filesystem
            path via ``context.artifacts["my_file"]``.

            If ``None``, no artifacts are added to the model.
        registered_model_name: If given, create a model
            version under ``registered_model_name``, also creating a
            registered model if one with the given name does not exist.

        signature: :py:class:`ModelSignature <mlflow.models.ModelSignature>`
            describes model input and output :py:class:`Schema <mlflow.types.Schema>`.
            The model signature can be :py:func:`inferred <mlflow.models.infer_signature>`
            from datasets with valid model input (e.g. the training dataset with target
            column omitted) and valid model output (e.g. model predictions generated on
            the training dataset), for example:

            .. code-block:: python

                from mlflow.models import infer_signature

                train = df.drop_column("target_label")
                predictions = ...  # compute model predictions
                signature = infer_signature(train, predictions)

        input_example: {{ input_example }}
        await_registration_for: Number of seconds to wait for the model version to finish
            being created and is in ``READY`` status. By default, the function
            waits for five minutes. Specify 0 or None to skip waiting.
        pip_requirements: {{ pip_requirements }}
        extra_pip_requirements: {{ extra_pip_requirements }}
        metadata: {{ metadata }}
        model_config: The model configuration to apply to the model. The configuration will
            be available as the ``model_config`` property of the ``context`` parameter
            in :func:`PythonModel.load_context() <mlflow.pyfunc.PythonModel.load_context>`
            and :func:`PythonModel.predict() <mlflow.pyfunc.PythonModel.predict>`.
            The configuration can be passed as a file path, or a dict with string keys.

            .. Note:: Experimental: This parameter may change or be removed in a future
                release without warning.
        streamable: A boolean value indicating if the model supports streaming prediction,
            If None, MLflow will try to inspect if the model supports streaming
            by checking if `predict_stream` method exists. Default None.
        resources: A list of model resources or a resources.yaml file containing a list of
            resources required to serve the model.

            .. Note:: Experimental: This parameter may change or be removed in a future
                release without warning.
        auth_policy: {{ auth_policy }}
        prompts: {{ prompts }}
        name: {{ name }}
        params: {{ params }}
        tags: {{ tags }}
        model_type: {{ model_type }}
        step: {{ step }}
        model_id: {{ model_id }}

    Returns:
        A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
        metadata of the logged model.
    """
    # This function is a thin front door: every argument is forwarded verbatim to
    # Model.log with the pyfunc flavor module, which performs the actual save,
    # artifact logging, and (optional) model registration.
    log_kwargs = {
        "artifact_path": artifact_path,
        "name": name,
        "flavor": mlflow.pyfunc,
        "loader_module": loader_module,
        "data_path": data_path,
        "code_paths": code_paths,
        "python_model": python_model,
        "artifacts": artifacts,
        "conda_env": conda_env,
        "registered_model_name": registered_model_name,
        "signature": signature,
        "input_example": input_example,
        "await_registration_for": await_registration_for,
        "pip_requirements": pip_requirements,
        "extra_pip_requirements": extra_pip_requirements,
        "metadata": metadata,
        "prompts": prompts,
        "model_config": model_config,
        "streamable": streamable,
        "resources": resources,
        "infer_code_paths": infer_code_paths,
        "auth_policy": auth_policy,
        "params": params,
        "tags": tags,
        "model_type": model_type,
        "step": step,
        "model_id": model_id,
    }
    return Model.log(**log_kwargs)
3627
+
3628
+
3629
def _save_model_with_loader_module_and_data_path(  # noqa: D417
    path,
    loader_module,
    data_path=None,
    code_paths=None,
    conda_env=None,
    mlflow_model=None,
    pip_requirements=None,
    extra_pip_requirements=None,
    model_config=None,
    streamable=None,
    infer_code_paths=False,
):
    """
    Export model as a generic Python function model.

    Args:
        path: The path to which to save the Python model.
        loader_module: The name of the Python module that is used to load the model
            from ``data_path``. This module must define a method with the prototype
            ``_load_pyfunc(data_path)``.
        data_path: Path to a file or directory containing model data.
        code_paths: A list of local filesystem paths to Python file dependencies (or directories
            containing file dependencies). These files are *prepended* to the system
            path before the model is loaded.
        conda_env: Either a dictionary representation of a Conda environment or the path to a
            Conda environment yaml file. If provided, this describes the environment
            this model should be run in.
        streamable: A boolean value indicating if the model supports streaming prediction,
            None value also means not streamable.

    Returns:
        Model configuration containing model info.
    """
    # Copy raw model data (if supplied) into the model directory under "data/" and
    # remember its relative subpath for the MLmodel flavor configuration.
    data_subpath = None
    if data_path is not None:
        data_subpath = _copy_file_or_tree(src=data_path, dst=path, dst_dir="data")

    if mlflow_model is None:
        mlflow_model = Model()

    mlflow.pyfunc.add_to_model(
        mlflow_model,
        loader_module=loader_module,
        code=None,
        data=data_subpath,
        conda_env=_CONDA_ENV_FILE_NAME,
        python_env=_PYTHON_ENV_FILE_NAME,
        model_config=model_config,
        # None is normalized to False: an unset value means "not streamable".
        streamable=streamable or False,
    )

    total_size = get_total_file_size(path)
    if total_size:
        mlflow_model.model_size_bytes = total_size
    # An MLmodel file must exist on disk before dependency inference below, since
    # `_load_pyfunc` may be invoked during that inference.
    mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME))

    code_dir_subpath = _validate_infer_and_copy_code_paths(
        code_paths, path, infer_code_paths, FLAVOR_NAME
    )
    mlflow_model.flavors[FLAVOR_NAME][CODE] = code_dir_subpath

    # `mlflow_model.code` is updated, re-generate `MLmodel` file.
    mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME))

    if conda_env is not None:
        # An explicit conda environment wins; derive pip requirements from it.
        conda_env, pip_requirements, pip_constraints = _process_conda_env(conda_env)
    else:
        if pip_requirements is None:
            base_reqs = get_default_pip_requirements()
            extra_env_vars = (
                _get_databricks_serverless_env_vars()
                if is_in_databricks_serverless_runtime()
                else None
            )
            # Infer requirements from the saved model; fall back to the defaults on failure.
            inferred = mlflow.models.infer_pip_requirements(
                path,
                FLAVOR_NAME,
                fallback=base_reqs,
                extra_env_vars=extra_env_vars,
            )
            base_reqs = sorted(set(inferred).union(base_reqs))
        else:
            base_reqs = None
        conda_env, pip_requirements, pip_constraints = _process_pip_requirements(
            base_reqs,
            pip_requirements,
            extra_pip_requirements,
        )

    with open(os.path.join(path, _CONDA_ENV_FILE_NAME), "w") as conda_file:
        yaml.safe_dump(conda_env, stream=conda_file, default_flow_style=False)

    # Save `constraints.txt` if necessary
    if pip_constraints:
        write_to(os.path.join(path, _CONSTRAINTS_FILE_NAME), "\n".join(pip_constraints))

    # Save `requirements.txt`
    write_to(os.path.join(path, _REQUIREMENTS_FILE_NAME), "\n".join(pip_requirements))

    _PythonEnv.current().to_yaml(os.path.join(path, _PYTHON_ENV_FILE_NAME))
    return mlflow_model
3735
+
3736
+
3737
def _save_model_chat_agent_helper(python_model, mlflow_model, signature, input_example):
    """Helper method for save_model for ChatAgent models.

    Sets the standard ChatAgent signature and default task metadata on
    ``mlflow_model``, validates the provided input example (or falls back to the
    default one), then runs a sanity-check prediction and validates that the
    output conforms to the ChatAgentResponse schema.

    Args:
        python_model: The ChatAgent instance being saved.
        mlflow_model: The :class:`Model` configuration to update in place.
        signature: Must be ``None``; ChatAgent signatures are fixed.
        input_example: Optional ChatAgentRequest object or dict with its schema.

    Returns:
        A dict input_example conforming to the ChatAgentRequest schema.

    Raises:
        MlflowException: If ``signature`` is provided, the input example is
            invalid, or the model's output fails ChatAgentResponse validation.
    """
    if signature is not None:
        raise MlflowException(
            "ChatAgent subclasses have a standard signature that is set "
            "automatically. Please remove the `signature` parameter from "
            "the call to log_model() or save_model().",
            error_code=INVALID_PARAMETER_VALUE,
        )
    mlflow_model.signature = ModelSignature(
        inputs=CHAT_AGENT_INPUT_SCHEMA,
        outputs=CHAT_AGENT_OUTPUT_SCHEMA,
    )
    # For ChatAgent we set default metadata to indicate its task; user-supplied
    # metadata (right operand of `|`) takes precedence over the default.
    default_metadata = {TASK: _DEFAULT_CHAT_AGENT_METADATA_TASK}
    mlflow_model.metadata = default_metadata | (mlflow_model.metadata or {})

    # We accept a dict with ChatAgentRequest schema
    if input_example:
        try:
            model_validate(ChatAgentRequest, input_example)
        except pydantic.ValidationError as e:
            raise MlflowException(
                message=(
                    "Invalid input example. Expected a ChatAgentRequest object or dictionary with"
                    f" its schema. Pydantic validation error: {e}"
                ),
                error_code=INTERNAL_ERROR,
            ) from e
        if isinstance(input_example, ChatAgentRequest):
            input_example = input_example.model_dump_compat(exclude_none=True)
    else:
        input_example = CHAT_AGENT_INPUT_EXAMPLE

    _logger.info("Predicting on input example to validate output")
    request = ChatAgentRequest(**input_example)
    output = python_model.predict(request.messages, request.context, request.custom_inputs)
    try:
        model_validate(ChatAgentResponse, output)
    except Exception as e:
        # Fix: the original message concatenated "...same schema." directly with
        # "Pydantic validation error", producing "schema.Pydantic". Keep a space.
        raise MlflowException(
            "Failed to save ChatAgent. Ensure your model's predict() method returns a "
            "ChatAgentResponse object or a dict with the same schema. "
            f"Pydantic validation error: {e}"
        ) from e
    return input_example
3786
+
3787
+
3788
def _save_model_responses_agent_helper(python_model, mlflow_model, signature, input_example):
    """Helper method for save_model for ResponsesAgent models.

    Sets the standard ResponsesAgent signature and default task metadata on
    ``mlflow_model``, validates the provided input example (or falls back to the
    default one), then runs a sanity-check prediction and validates that the
    output conforms to the ResponsesAgentResponse schema.

    Args:
        python_model: The ResponsesAgent instance being saved.
        mlflow_model: The :class:`Model` configuration to update in place.
        signature: Must be ``None``; ResponsesAgent signatures are fixed.
        input_example: Optional ResponsesAgentRequest object or dict with its schema.

    Returns:
        A dictionary input example conforming to the ResponsesAgentRequest schema.

    Raises:
        MlflowException: If ``signature`` is provided, the input example is
            invalid, or the model's output fails ResponsesAgentResponse validation.
    """
    # Imported lazily because responses support may not be available in all
    # environments (see IS_RESPONSES_AGENT_AVAILABLE at the call site).
    from mlflow.types.responses import (
        RESPONSES_AGENT_INPUT_EXAMPLE,
        RESPONSES_AGENT_INPUT_SCHEMA,
        RESPONSES_AGENT_OUTPUT_SCHEMA,
        ResponsesAgentRequest,
        ResponsesAgentResponse,
    )

    if signature is not None:
        raise MlflowException(
            "ResponsesAgent subclasses have a standard signature that is set "
            "automatically. Please remove the `signature` parameter from "
            "the call to log_model() or save_model().",
            error_code=INVALID_PARAMETER_VALUE,
        )
    mlflow_model.signature = ModelSignature(
        inputs=RESPONSES_AGENT_INPUT_SCHEMA,
        outputs=RESPONSES_AGENT_OUTPUT_SCHEMA,
    )

    # For ResponsesAgent we set default metadata to indicate its task; user-supplied
    # metadata (right operand of `|`) takes precedence over the default.
    default_metadata = {TASK: _DEFAULT_RESPONSES_AGENT_METADATA_TASK}
    mlflow_model.metadata = default_metadata | (mlflow_model.metadata or {})

    # We accept either a dict or a ResponsesAgentRequest object as input
    if input_example:
        try:
            model_validate(ResponsesAgentRequest, input_example)
        except pydantic.ValidationError as e:
            # Fix: the original message referenced a nonexistent "ResponsesRequest"
            # class; the actual request type is ResponsesAgentRequest.
            raise MlflowException(
                message=(
                    "Invalid input example. Expected a ResponsesAgentRequest object or"
                    f" dictionary with its schema. Pydantic validation error: {e}"
                ),
                error_code=INTERNAL_ERROR,
            ) from e
        if isinstance(input_example, ResponsesAgentRequest):
            input_example = input_example.model_dump_compat(exclude_none=True)
    else:
        input_example = RESPONSES_AGENT_INPUT_EXAMPLE
    _logger.info("Predicting on input example to validate output")
    request = ResponsesAgentRequest(**input_example)
    output = python_model.predict(request)
    try:
        model_validate(ResponsesAgentResponse, output)
    except Exception as e:
        # Fixes: (1) missing space before "Pydantic" in the concatenated message;
        # (2) the response type is ResponsesAgentResponse, not "ResponsesResponse".
        raise MlflowException(
            "Failed to save ResponsesAgent. Ensure your model's predict() method returns a "
            "ResponsesAgentResponse object or a dict with the same schema. "
            f"Pydantic validation error: {e}"
        ) from e
    return input_example