genesis_flow-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (645)
  1. genesis_flow-1.0.0.dist-info/METADATA +822 -0
  2. genesis_flow-1.0.0.dist-info/RECORD +645 -0
  3. genesis_flow-1.0.0.dist-info/WHEEL +5 -0
  4. genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
  5. genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
  6. genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
  7. mlflow/__init__.py +367 -0
  8. mlflow/__main__.py +3 -0
  9. mlflow/ag2/__init__.py +56 -0
  10. mlflow/ag2/ag2_logger.py +294 -0
  11. mlflow/anthropic/__init__.py +40 -0
  12. mlflow/anthropic/autolog.py +129 -0
  13. mlflow/anthropic/chat.py +144 -0
  14. mlflow/artifacts/__init__.py +268 -0
  15. mlflow/autogen/__init__.py +144 -0
  16. mlflow/autogen/chat.py +142 -0
  17. mlflow/azure/__init__.py +26 -0
  18. mlflow/azure/auth_handler.py +257 -0
  19. mlflow/azure/client.py +319 -0
  20. mlflow/azure/config.py +120 -0
  21. mlflow/azure/connection_factory.py +340 -0
  22. mlflow/azure/exceptions.py +27 -0
  23. mlflow/azure/stores.py +327 -0
  24. mlflow/azure/utils.py +183 -0
  25. mlflow/bedrock/__init__.py +45 -0
  26. mlflow/bedrock/_autolog.py +202 -0
  27. mlflow/bedrock/chat.py +122 -0
  28. mlflow/bedrock/stream.py +160 -0
  29. mlflow/bedrock/utils.py +43 -0
  30. mlflow/cli.py +707 -0
  31. mlflow/client.py +12 -0
  32. mlflow/config/__init__.py +56 -0
  33. mlflow/crewai/__init__.py +79 -0
  34. mlflow/crewai/autolog.py +253 -0
  35. mlflow/crewai/chat.py +29 -0
  36. mlflow/data/__init__.py +75 -0
  37. mlflow/data/artifact_dataset_sources.py +170 -0
  38. mlflow/data/code_dataset_source.py +40 -0
  39. mlflow/data/dataset.py +123 -0
  40. mlflow/data/dataset_registry.py +168 -0
  41. mlflow/data/dataset_source.py +110 -0
  42. mlflow/data/dataset_source_registry.py +219 -0
  43. mlflow/data/delta_dataset_source.py +167 -0
  44. mlflow/data/digest_utils.py +108 -0
  45. mlflow/data/evaluation_dataset.py +562 -0
  46. mlflow/data/filesystem_dataset_source.py +81 -0
  47. mlflow/data/http_dataset_source.py +145 -0
  48. mlflow/data/huggingface_dataset.py +258 -0
  49. mlflow/data/huggingface_dataset_source.py +118 -0
  50. mlflow/data/meta_dataset.py +104 -0
  51. mlflow/data/numpy_dataset.py +223 -0
  52. mlflow/data/pandas_dataset.py +231 -0
  53. mlflow/data/polars_dataset.py +352 -0
  54. mlflow/data/pyfunc_dataset_mixin.py +31 -0
  55. mlflow/data/schema.py +76 -0
  56. mlflow/data/sources.py +1 -0
  57. mlflow/data/spark_dataset.py +406 -0
  58. mlflow/data/spark_dataset_source.py +74 -0
  59. mlflow/data/spark_delta_utils.py +118 -0
  60. mlflow/data/tensorflow_dataset.py +350 -0
  61. mlflow/data/uc_volume_dataset_source.py +81 -0
  62. mlflow/db.py +27 -0
  63. mlflow/dspy/__init__.py +17 -0
  64. mlflow/dspy/autolog.py +197 -0
  65. mlflow/dspy/callback.py +398 -0
  66. mlflow/dspy/constant.py +1 -0
  67. mlflow/dspy/load.py +93 -0
  68. mlflow/dspy/save.py +393 -0
  69. mlflow/dspy/util.py +109 -0
  70. mlflow/dspy/wrapper.py +226 -0
  71. mlflow/entities/__init__.py +104 -0
  72. mlflow/entities/_mlflow_object.py +52 -0
  73. mlflow/entities/assessment.py +545 -0
  74. mlflow/entities/assessment_error.py +80 -0
  75. mlflow/entities/assessment_source.py +141 -0
  76. mlflow/entities/dataset.py +92 -0
  77. mlflow/entities/dataset_input.py +51 -0
  78. mlflow/entities/dataset_summary.py +62 -0
  79. mlflow/entities/document.py +48 -0
  80. mlflow/entities/experiment.py +109 -0
  81. mlflow/entities/experiment_tag.py +35 -0
  82. mlflow/entities/file_info.py +45 -0
  83. mlflow/entities/input_tag.py +35 -0
  84. mlflow/entities/lifecycle_stage.py +35 -0
  85. mlflow/entities/logged_model.py +228 -0
  86. mlflow/entities/logged_model_input.py +26 -0
  87. mlflow/entities/logged_model_output.py +32 -0
  88. mlflow/entities/logged_model_parameter.py +46 -0
  89. mlflow/entities/logged_model_status.py +74 -0
  90. mlflow/entities/logged_model_tag.py +33 -0
  91. mlflow/entities/metric.py +200 -0
  92. mlflow/entities/model_registry/__init__.py +29 -0
  93. mlflow/entities/model_registry/_model_registry_entity.py +13 -0
  94. mlflow/entities/model_registry/model_version.py +243 -0
  95. mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
  96. mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
  97. mlflow/entities/model_registry/model_version_search.py +25 -0
  98. mlflow/entities/model_registry/model_version_stages.py +25 -0
  99. mlflow/entities/model_registry/model_version_status.py +35 -0
  100. mlflow/entities/model_registry/model_version_tag.py +35 -0
  101. mlflow/entities/model_registry/prompt.py +73 -0
  102. mlflow/entities/model_registry/prompt_version.py +244 -0
  103. mlflow/entities/model_registry/registered_model.py +175 -0
  104. mlflow/entities/model_registry/registered_model_alias.py +35 -0
  105. mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
  106. mlflow/entities/model_registry/registered_model_search.py +25 -0
  107. mlflow/entities/model_registry/registered_model_tag.py +35 -0
  108. mlflow/entities/multipart_upload.py +74 -0
  109. mlflow/entities/param.py +49 -0
  110. mlflow/entities/run.py +97 -0
  111. mlflow/entities/run_data.py +84 -0
  112. mlflow/entities/run_info.py +188 -0
  113. mlflow/entities/run_inputs.py +59 -0
  114. mlflow/entities/run_outputs.py +43 -0
  115. mlflow/entities/run_status.py +41 -0
  116. mlflow/entities/run_tag.py +36 -0
  117. mlflow/entities/source_type.py +31 -0
  118. mlflow/entities/span.py +774 -0
  119. mlflow/entities/span_event.py +96 -0
  120. mlflow/entities/span_status.py +102 -0
  121. mlflow/entities/trace.py +317 -0
  122. mlflow/entities/trace_data.py +71 -0
  123. mlflow/entities/trace_info.py +220 -0
  124. mlflow/entities/trace_info_v2.py +162 -0
  125. mlflow/entities/trace_location.py +173 -0
  126. mlflow/entities/trace_state.py +39 -0
  127. mlflow/entities/trace_status.py +68 -0
  128. mlflow/entities/view_type.py +51 -0
  129. mlflow/environment_variables.py +866 -0
  130. mlflow/evaluation/__init__.py +16 -0
  131. mlflow/evaluation/assessment.py +369 -0
  132. mlflow/evaluation/evaluation.py +411 -0
  133. mlflow/evaluation/evaluation_tag.py +61 -0
  134. mlflow/evaluation/fluent.py +48 -0
  135. mlflow/evaluation/utils.py +201 -0
  136. mlflow/exceptions.py +213 -0
  137. mlflow/experiments.py +140 -0
  138. mlflow/gemini/__init__.py +81 -0
  139. mlflow/gemini/autolog.py +186 -0
  140. mlflow/gemini/chat.py +261 -0
  141. mlflow/genai/__init__.py +71 -0
  142. mlflow/genai/datasets/__init__.py +67 -0
  143. mlflow/genai/datasets/evaluation_dataset.py +131 -0
  144. mlflow/genai/evaluation/__init__.py +3 -0
  145. mlflow/genai/evaluation/base.py +411 -0
  146. mlflow/genai/evaluation/constant.py +23 -0
  147. mlflow/genai/evaluation/utils.py +244 -0
  148. mlflow/genai/judges/__init__.py +21 -0
  149. mlflow/genai/judges/databricks.py +404 -0
  150. mlflow/genai/label_schemas/__init__.py +153 -0
  151. mlflow/genai/label_schemas/label_schemas.py +209 -0
  152. mlflow/genai/labeling/__init__.py +159 -0
  153. mlflow/genai/labeling/labeling.py +250 -0
  154. mlflow/genai/optimize/__init__.py +13 -0
  155. mlflow/genai/optimize/base.py +198 -0
  156. mlflow/genai/optimize/optimizers/__init__.py +4 -0
  157. mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
  158. mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
  159. mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
  160. mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
  161. mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
  162. mlflow/genai/optimize/types.py +75 -0
  163. mlflow/genai/optimize/util.py +30 -0
  164. mlflow/genai/prompts/__init__.py +206 -0
  165. mlflow/genai/scheduled_scorers.py +431 -0
  166. mlflow/genai/scorers/__init__.py +26 -0
  167. mlflow/genai/scorers/base.py +492 -0
  168. mlflow/genai/scorers/builtin_scorers.py +765 -0
  169. mlflow/genai/scorers/scorer_utils.py +138 -0
  170. mlflow/genai/scorers/validation.py +165 -0
  171. mlflow/genai/utils/data_validation.py +146 -0
  172. mlflow/genai/utils/enum_utils.py +23 -0
  173. mlflow/genai/utils/trace_utils.py +211 -0
  174. mlflow/groq/__init__.py +42 -0
  175. mlflow/groq/_groq_autolog.py +74 -0
  176. mlflow/johnsnowlabs/__init__.py +888 -0
  177. mlflow/langchain/__init__.py +24 -0
  178. mlflow/langchain/api_request_parallel_processor.py +330 -0
  179. mlflow/langchain/autolog.py +147 -0
  180. mlflow/langchain/chat_agent_langgraph.py +340 -0
  181. mlflow/langchain/constant.py +1 -0
  182. mlflow/langchain/constants.py +1 -0
  183. mlflow/langchain/databricks_dependencies.py +444 -0
  184. mlflow/langchain/langchain_tracer.py +597 -0
  185. mlflow/langchain/model.py +919 -0
  186. mlflow/langchain/output_parsers.py +142 -0
  187. mlflow/langchain/retriever_chain.py +153 -0
  188. mlflow/langchain/runnables.py +527 -0
  189. mlflow/langchain/utils/chat.py +402 -0
  190. mlflow/langchain/utils/logging.py +671 -0
  191. mlflow/langchain/utils/serialization.py +36 -0
  192. mlflow/legacy_databricks_cli/__init__.py +0 -0
  193. mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
  194. mlflow/legacy_databricks_cli/configure/provider.py +482 -0
  195. mlflow/litellm/__init__.py +175 -0
  196. mlflow/llama_index/__init__.py +22 -0
  197. mlflow/llama_index/autolog.py +55 -0
  198. mlflow/llama_index/chat.py +43 -0
  199. mlflow/llama_index/constant.py +1 -0
  200. mlflow/llama_index/model.py +577 -0
  201. mlflow/llama_index/pyfunc_wrapper.py +332 -0
  202. mlflow/llama_index/serialize_objects.py +188 -0
  203. mlflow/llama_index/tracer.py +561 -0
  204. mlflow/metrics/__init__.py +479 -0
  205. mlflow/metrics/base.py +39 -0
  206. mlflow/metrics/genai/__init__.py +25 -0
  207. mlflow/metrics/genai/base.py +101 -0
  208. mlflow/metrics/genai/genai_metric.py +771 -0
  209. mlflow/metrics/genai/metric_definitions.py +450 -0
  210. mlflow/metrics/genai/model_utils.py +371 -0
  211. mlflow/metrics/genai/prompt_template.py +68 -0
  212. mlflow/metrics/genai/prompts/__init__.py +0 -0
  213. mlflow/metrics/genai/prompts/v1.py +422 -0
  214. mlflow/metrics/genai/utils.py +6 -0
  215. mlflow/metrics/metric_definitions.py +619 -0
  216. mlflow/mismatch.py +34 -0
  217. mlflow/mistral/__init__.py +34 -0
  218. mlflow/mistral/autolog.py +71 -0
  219. mlflow/mistral/chat.py +135 -0
  220. mlflow/ml_package_versions.py +452 -0
  221. mlflow/models/__init__.py +97 -0
  222. mlflow/models/auth_policy.py +83 -0
  223. mlflow/models/cli.py +354 -0
  224. mlflow/models/container/__init__.py +294 -0
  225. mlflow/models/container/scoring_server/__init__.py +0 -0
  226. mlflow/models/container/scoring_server/nginx.conf +39 -0
  227. mlflow/models/dependencies_schemas.py +287 -0
  228. mlflow/models/display_utils.py +158 -0
  229. mlflow/models/docker_utils.py +211 -0
  230. mlflow/models/evaluation/__init__.py +23 -0
  231. mlflow/models/evaluation/_shap_patch.py +64 -0
  232. mlflow/models/evaluation/artifacts.py +194 -0
  233. mlflow/models/evaluation/base.py +1811 -0
  234. mlflow/models/evaluation/calibration_curve.py +109 -0
  235. mlflow/models/evaluation/default_evaluator.py +996 -0
  236. mlflow/models/evaluation/deprecated.py +23 -0
  237. mlflow/models/evaluation/evaluator_registry.py +80 -0
  238. mlflow/models/evaluation/evaluators/classifier.py +704 -0
  239. mlflow/models/evaluation/evaluators/default.py +233 -0
  240. mlflow/models/evaluation/evaluators/regressor.py +96 -0
  241. mlflow/models/evaluation/evaluators/shap.py +296 -0
  242. mlflow/models/evaluation/lift_curve.py +178 -0
  243. mlflow/models/evaluation/utils/metric.py +123 -0
  244. mlflow/models/evaluation/utils/trace.py +179 -0
  245. mlflow/models/evaluation/validation.py +434 -0
  246. mlflow/models/flavor_backend.py +93 -0
  247. mlflow/models/flavor_backend_registry.py +53 -0
  248. mlflow/models/model.py +1639 -0
  249. mlflow/models/model_config.py +150 -0
  250. mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
  251. mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
  252. mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
  253. mlflow/models/python_api.py +369 -0
  254. mlflow/models/rag_signatures.py +128 -0
  255. mlflow/models/resources.py +321 -0
  256. mlflow/models/signature.py +662 -0
  257. mlflow/models/utils.py +2054 -0
  258. mlflow/models/wheeled_model.py +280 -0
  259. mlflow/openai/__init__.py +57 -0
  260. mlflow/openai/_agent_tracer.py +364 -0
  261. mlflow/openai/api_request_parallel_processor.py +131 -0
  262. mlflow/openai/autolog.py +509 -0
  263. mlflow/openai/constant.py +1 -0
  264. mlflow/openai/model.py +824 -0
  265. mlflow/openai/utils/chat_schema.py +367 -0
  266. mlflow/optuna/__init__.py +3 -0
  267. mlflow/optuna/storage.py +646 -0
  268. mlflow/plugins/__init__.py +72 -0
  269. mlflow/plugins/base.py +358 -0
  270. mlflow/plugins/builtin/__init__.py +24 -0
  271. mlflow/plugins/builtin/pytorch_plugin.py +150 -0
  272. mlflow/plugins/builtin/sklearn_plugin.py +158 -0
  273. mlflow/plugins/builtin/transformers_plugin.py +187 -0
  274. mlflow/plugins/cli.py +321 -0
  275. mlflow/plugins/discovery.py +340 -0
  276. mlflow/plugins/manager.py +465 -0
  277. mlflow/plugins/registry.py +316 -0
  278. mlflow/plugins/templates/framework_plugin_template.py +329 -0
  279. mlflow/prompt/constants.py +20 -0
  280. mlflow/prompt/promptlab_model.py +197 -0
  281. mlflow/prompt/registry_utils.py +248 -0
  282. mlflow/promptflow/__init__.py +495 -0
  283. mlflow/protos/__init__.py +0 -0
  284. mlflow/protos/assessments_pb2.py +174 -0
  285. mlflow/protos/databricks_artifacts_pb2.py +489 -0
  286. mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
  287. mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
  288. mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
  289. mlflow/protos/databricks_pb2.py +267 -0
  290. mlflow/protos/databricks_trace_server_pb2.py +374 -0
  291. mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
  292. mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
  293. mlflow/protos/facet_feature_statistics_pb2.py +296 -0
  294. mlflow/protos/internal_pb2.py +77 -0
  295. mlflow/protos/mlflow_artifacts_pb2.py +336 -0
  296. mlflow/protos/model_registry_pb2.py +1073 -0
  297. mlflow/protos/scalapb/__init__.py +0 -0
  298. mlflow/protos/scalapb/scalapb_pb2.py +104 -0
  299. mlflow/protos/service_pb2.py +2600 -0
  300. mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
  301. mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
  302. mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
  303. mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
  304. mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
  305. mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
  306. mlflow/py.typed +0 -0
  307. mlflow/pydantic_ai/__init__.py +57 -0
  308. mlflow/pydantic_ai/autolog.py +173 -0
  309. mlflow/pyfunc/__init__.py +3844 -0
  310. mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
  311. mlflow/pyfunc/backend.py +523 -0
  312. mlflow/pyfunc/context.py +78 -0
  313. mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
  314. mlflow/pyfunc/loaders/__init__.py +7 -0
  315. mlflow/pyfunc/loaders/chat_agent.py +117 -0
  316. mlflow/pyfunc/loaders/chat_model.py +125 -0
  317. mlflow/pyfunc/loaders/code_model.py +31 -0
  318. mlflow/pyfunc/loaders/responses_agent.py +112 -0
  319. mlflow/pyfunc/mlserver.py +46 -0
  320. mlflow/pyfunc/model.py +1473 -0
  321. mlflow/pyfunc/scoring_server/__init__.py +604 -0
  322. mlflow/pyfunc/scoring_server/app.py +7 -0
  323. mlflow/pyfunc/scoring_server/client.py +146 -0
  324. mlflow/pyfunc/spark_model_cache.py +48 -0
  325. mlflow/pyfunc/stdin_server.py +44 -0
  326. mlflow/pyfunc/utils/__init__.py +3 -0
  327. mlflow/pyfunc/utils/data_validation.py +224 -0
  328. mlflow/pyfunc/utils/environment.py +22 -0
  329. mlflow/pyfunc/utils/input_converter.py +47 -0
  330. mlflow/pyfunc/utils/serving_data_parser.py +11 -0
  331. mlflow/pytorch/__init__.py +1171 -0
  332. mlflow/pytorch/_lightning_autolog.py +580 -0
  333. mlflow/pytorch/_pytorch_autolog.py +50 -0
  334. mlflow/pytorch/pickle_module.py +35 -0
  335. mlflow/rfunc/__init__.py +42 -0
  336. mlflow/rfunc/backend.py +134 -0
  337. mlflow/runs.py +89 -0
  338. mlflow/server/__init__.py +302 -0
  339. mlflow/server/auth/__init__.py +1224 -0
  340. mlflow/server/auth/__main__.py +4 -0
  341. mlflow/server/auth/basic_auth.ini +6 -0
  342. mlflow/server/auth/cli.py +11 -0
  343. mlflow/server/auth/client.py +537 -0
  344. mlflow/server/auth/config.py +34 -0
  345. mlflow/server/auth/db/__init__.py +0 -0
  346. mlflow/server/auth/db/cli.py +18 -0
  347. mlflow/server/auth/db/migrations/__init__.py +0 -0
  348. mlflow/server/auth/db/migrations/alembic.ini +110 -0
  349. mlflow/server/auth/db/migrations/env.py +76 -0
  350. mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
  351. mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
  352. mlflow/server/auth/db/models.py +67 -0
  353. mlflow/server/auth/db/utils.py +37 -0
  354. mlflow/server/auth/entities.py +165 -0
  355. mlflow/server/auth/logo.py +14 -0
  356. mlflow/server/auth/permissions.py +65 -0
  357. mlflow/server/auth/routes.py +18 -0
  358. mlflow/server/auth/sqlalchemy_store.py +263 -0
  359. mlflow/server/graphql/__init__.py +0 -0
  360. mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
  361. mlflow/server/graphql/graphql_custom_scalars.py +24 -0
  362. mlflow/server/graphql/graphql_errors.py +15 -0
  363. mlflow/server/graphql/graphql_no_batching.py +89 -0
  364. mlflow/server/graphql/graphql_schema_extensions.py +74 -0
  365. mlflow/server/handlers.py +3217 -0
  366. mlflow/server/prometheus_exporter.py +17 -0
  367. mlflow/server/validation.py +30 -0
  368. mlflow/shap/__init__.py +691 -0
  369. mlflow/sklearn/__init__.py +1994 -0
  370. mlflow/sklearn/utils.py +1041 -0
  371. mlflow/smolagents/__init__.py +66 -0
  372. mlflow/smolagents/autolog.py +139 -0
  373. mlflow/smolagents/chat.py +29 -0
  374. mlflow/store/__init__.py +10 -0
  375. mlflow/store/_unity_catalog/__init__.py +1 -0
  376. mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
  377. mlflow/store/_unity_catalog/lineage/constants.py +2 -0
  378. mlflow/store/_unity_catalog/registry/__init__.py +6 -0
  379. mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
  380. mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
  381. mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
  382. mlflow/store/_unity_catalog/registry/utils.py +121 -0
  383. mlflow/store/artifact/__init__.py +0 -0
  384. mlflow/store/artifact/artifact_repo.py +472 -0
  385. mlflow/store/artifact/artifact_repository_registry.py +154 -0
  386. mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
  387. mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
  388. mlflow/store/artifact/cli.py +141 -0
  389. mlflow/store/artifact/cloud_artifact_repo.py +332 -0
  390. mlflow/store/artifact/databricks_artifact_repo.py +729 -0
  391. mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
  392. mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
  393. mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
  394. mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
  395. mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
  396. mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
  397. mlflow/store/artifact/ftp_artifact_repo.py +132 -0
  398. mlflow/store/artifact/gcs_artifact_repo.py +296 -0
  399. mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
  400. mlflow/store/artifact/http_artifact_repo.py +218 -0
  401. mlflow/store/artifact/local_artifact_repo.py +142 -0
  402. mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
  403. mlflow/store/artifact/models_artifact_repo.py +259 -0
  404. mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
  405. mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
  406. mlflow/store/artifact/r2_artifact_repo.py +70 -0
  407. mlflow/store/artifact/runs_artifact_repo.py +265 -0
  408. mlflow/store/artifact/s3_artifact_repo.py +330 -0
  409. mlflow/store/artifact/sftp_artifact_repo.py +141 -0
  410. mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
  411. mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
  412. mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
  413. mlflow/store/artifact/utils/__init__.py +0 -0
  414. mlflow/store/artifact/utils/models.py +148 -0
  415. mlflow/store/db/__init__.py +0 -0
  416. mlflow/store/db/base_sql_model.py +3 -0
  417. mlflow/store/db/db_types.py +10 -0
  418. mlflow/store/db/utils.py +314 -0
  419. mlflow/store/db_migrations/__init__.py +0 -0
  420. mlflow/store/db_migrations/alembic.ini +74 -0
  421. mlflow/store/db_migrations/env.py +84 -0
  422. mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
  423. mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
  424. mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
  425. mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
  426. mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
  427. mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
  428. mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
  429. mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
  430. mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
  431. mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
  432. mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
  433. mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
  434. mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
  435. mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
  436. mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
  437. mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
  438. mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
  439. mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
  440. mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
  441. mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
  442. mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
  443. mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
  444. mlflow/store/db_migrations/versions/__init__.py +0 -0
  445. mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
  446. mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
  447. mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
  448. mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
  449. mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
  450. mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
  451. mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
  452. mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
  453. mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
  454. mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
  455. mlflow/store/entities/__init__.py +3 -0
  456. mlflow/store/entities/paged_list.py +18 -0
  457. mlflow/store/model_registry/__init__.py +10 -0
  458. mlflow/store/model_registry/abstract_store.py +1081 -0
  459. mlflow/store/model_registry/base_rest_store.py +44 -0
  460. mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
  461. mlflow/store/model_registry/dbmodels/__init__.py +0 -0
  462. mlflow/store/model_registry/dbmodels/models.py +206 -0
  463. mlflow/store/model_registry/file_store.py +1091 -0
  464. mlflow/store/model_registry/rest_store.py +481 -0
  465. mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
  466. mlflow/store/tracking/__init__.py +23 -0
  467. mlflow/store/tracking/abstract_store.py +816 -0
  468. mlflow/store/tracking/dbmodels/__init__.py +0 -0
  469. mlflow/store/tracking/dbmodels/initial_models.py +243 -0
  470. mlflow/store/tracking/dbmodels/models.py +1073 -0
  471. mlflow/store/tracking/file_store.py +2438 -0
  472. mlflow/store/tracking/postgres_managed_identity.py +146 -0
  473. mlflow/store/tracking/rest_store.py +1131 -0
  474. mlflow/store/tracking/sqlalchemy_store.py +2785 -0
  475. mlflow/system_metrics/__init__.py +61 -0
  476. mlflow/system_metrics/metrics/__init__.py +0 -0
  477. mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
  478. mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
  479. mlflow/system_metrics/metrics/disk_monitor.py +21 -0
  480. mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
  481. mlflow/system_metrics/metrics/network_monitor.py +34 -0
  482. mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
  483. mlflow/system_metrics/system_metrics_monitor.py +198 -0
  484. mlflow/tracing/__init__.py +16 -0
  485. mlflow/tracing/assessment.py +356 -0
  486. mlflow/tracing/client.py +531 -0
  487. mlflow/tracing/config.py +125 -0
  488. mlflow/tracing/constant.py +105 -0
  489. mlflow/tracing/destination.py +81 -0
  490. mlflow/tracing/display/__init__.py +40 -0
  491. mlflow/tracing/display/display_handler.py +196 -0
  492. mlflow/tracing/export/async_export_queue.py +186 -0
  493. mlflow/tracing/export/inference_table.py +138 -0
  494. mlflow/tracing/export/mlflow_v3.py +137 -0
  495. mlflow/tracing/export/utils.py +70 -0
  496. mlflow/tracing/fluent.py +1417 -0
  497. mlflow/tracing/processor/base_mlflow.py +199 -0
  498. mlflow/tracing/processor/inference_table.py +175 -0
  499. mlflow/tracing/processor/mlflow_v3.py +47 -0
  500. mlflow/tracing/processor/otel.py +73 -0
  501. mlflow/tracing/provider.py +487 -0
  502. mlflow/tracing/trace_manager.py +200 -0
  503. mlflow/tracing/utils/__init__.py +616 -0
  504. mlflow/tracing/utils/artifact_utils.py +28 -0
  505. mlflow/tracing/utils/copy.py +55 -0
  506. mlflow/tracing/utils/environment.py +55 -0
  507. mlflow/tracing/utils/exception.py +21 -0
  508. mlflow/tracing/utils/once.py +35 -0
  509. mlflow/tracing/utils/otlp.py +63 -0
  510. mlflow/tracing/utils/processor.py +54 -0
  511. mlflow/tracing/utils/search.py +292 -0
  512. mlflow/tracing/utils/timeout.py +250 -0
  513. mlflow/tracing/utils/token.py +19 -0
  514. mlflow/tracing/utils/truncation.py +124 -0
  515. mlflow/tracing/utils/warning.py +76 -0
  516. mlflow/tracking/__init__.py +39 -0
  517. mlflow/tracking/_model_registry/__init__.py +1 -0
  518. mlflow/tracking/_model_registry/client.py +764 -0
  519. mlflow/tracking/_model_registry/fluent.py +853 -0
  520. mlflow/tracking/_model_registry/registry.py +67 -0
  521. mlflow/tracking/_model_registry/utils.py +251 -0
  522. mlflow/tracking/_tracking_service/__init__.py +0 -0
  523. mlflow/tracking/_tracking_service/client.py +883 -0
  524. mlflow/tracking/_tracking_service/registry.py +56 -0
  525. mlflow/tracking/_tracking_service/utils.py +275 -0
  526. mlflow/tracking/artifact_utils.py +179 -0
  527. mlflow/tracking/client.py +5900 -0
  528. mlflow/tracking/context/__init__.py +0 -0
  529. mlflow/tracking/context/abstract_context.py +35 -0
  530. mlflow/tracking/context/databricks_cluster_context.py +15 -0
  531. mlflow/tracking/context/databricks_command_context.py +15 -0
  532. mlflow/tracking/context/databricks_job_context.py +49 -0
  533. mlflow/tracking/context/databricks_notebook_context.py +41 -0
  534. mlflow/tracking/context/databricks_repo_context.py +43 -0
  535. mlflow/tracking/context/default_context.py +51 -0
  536. mlflow/tracking/context/git_context.py +32 -0
  537. mlflow/tracking/context/registry.py +98 -0
  538. mlflow/tracking/context/system_environment_context.py +15 -0
  539. mlflow/tracking/default_experiment/__init__.py +1 -0
  540. mlflow/tracking/default_experiment/abstract_context.py +43 -0
  541. mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
  542. mlflow/tracking/default_experiment/registry.py +75 -0
  543. mlflow/tracking/fluent.py +3595 -0
  544. mlflow/tracking/metric_value_conversion_utils.py +93 -0
  545. mlflow/tracking/multimedia.py +206 -0
  546. mlflow/tracking/registry.py +86 -0
  547. mlflow/tracking/request_auth/__init__.py +0 -0
  548. mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
  549. mlflow/tracking/request_auth/registry.py +60 -0
  550. mlflow/tracking/request_header/__init__.py +0 -0
  551. mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
  552. mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
  553. mlflow/tracking/request_header/default_request_header_provider.py +17 -0
  554. mlflow/tracking/request_header/registry.py +79 -0
  555. mlflow/transformers/__init__.py +2982 -0
  556. mlflow/transformers/flavor_config.py +258 -0
  557. mlflow/transformers/hub_utils.py +83 -0
  558. mlflow/transformers/llm_inference_utils.py +468 -0
  559. mlflow/transformers/model_io.py +301 -0
  560. mlflow/transformers/peft.py +51 -0
  561. mlflow/transformers/signature.py +183 -0
  562. mlflow/transformers/torch_utils.py +55 -0
  563. mlflow/types/__init__.py +21 -0
  564. mlflow/types/agent.py +270 -0
  565. mlflow/types/chat.py +240 -0
  566. mlflow/types/llm.py +935 -0
  567. mlflow/types/responses.py +139 -0
  568. mlflow/types/responses_helpers.py +416 -0
  569. mlflow/types/schema.py +1505 -0
  570. mlflow/types/type_hints.py +647 -0
  571. mlflow/types/utils.py +753 -0
  572. mlflow/utils/__init__.py +283 -0
  573. mlflow/utils/_capture_modules.py +256 -0
  574. mlflow/utils/_capture_transformers_modules.py +75 -0
  575. mlflow/utils/_spark_utils.py +201 -0
  576. mlflow/utils/_unity_catalog_oss_utils.py +97 -0
  577. mlflow/utils/_unity_catalog_utils.py +479 -0
  578. mlflow/utils/annotations.py +218 -0
  579. mlflow/utils/arguments_utils.py +16 -0
  580. mlflow/utils/async_logging/__init__.py +1 -0
  581. mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
  582. mlflow/utils/async_logging/async_logging_queue.py +366 -0
  583. mlflow/utils/async_logging/run_artifact.py +38 -0
  584. mlflow/utils/async_logging/run_batch.py +58 -0
  585. mlflow/utils/async_logging/run_operations.py +49 -0
  586. mlflow/utils/autologging_utils/__init__.py +737 -0
  587. mlflow/utils/autologging_utils/client.py +432 -0
  588. mlflow/utils/autologging_utils/config.py +33 -0
  589. mlflow/utils/autologging_utils/events.py +294 -0
  590. mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
  591. mlflow/utils/autologging_utils/metrics_queue.py +71 -0
  592. mlflow/utils/autologging_utils/safety.py +1104 -0
  593. mlflow/utils/autologging_utils/versioning.py +95 -0
  594. mlflow/utils/checkpoint_utils.py +206 -0
  595. mlflow/utils/class_utils.py +6 -0
  596. mlflow/utils/cli_args.py +257 -0
  597. mlflow/utils/conda.py +354 -0
  598. mlflow/utils/credentials.py +231 -0
  599. mlflow/utils/data_utils.py +17 -0
  600. mlflow/utils/databricks_utils.py +1436 -0
  601. mlflow/utils/docstring_utils.py +477 -0
  602. mlflow/utils/doctor.py +133 -0
  603. mlflow/utils/download_cloud_file_chunk.py +43 -0
  604. mlflow/utils/env_manager.py +16 -0
  605. mlflow/utils/env_pack.py +131 -0
  606. mlflow/utils/environment.py +1009 -0
  607. mlflow/utils/exception_utils.py +14 -0
  608. mlflow/utils/file_utils.py +978 -0
  609. mlflow/utils/git_utils.py +77 -0
  610. mlflow/utils/gorilla.py +797 -0
  611. mlflow/utils/import_hooks/__init__.py +363 -0
  612. mlflow/utils/lazy_load.py +51 -0
  613. mlflow/utils/logging_utils.py +168 -0
  614. mlflow/utils/mime_type_utils.py +58 -0
  615. mlflow/utils/mlflow_tags.py +103 -0
  616. mlflow/utils/model_utils.py +486 -0
  617. mlflow/utils/name_utils.py +346 -0
  618. mlflow/utils/nfs_on_spark.py +62 -0
  619. mlflow/utils/openai_utils.py +164 -0
  620. mlflow/utils/os.py +12 -0
  621. mlflow/utils/oss_registry_utils.py +29 -0
  622. mlflow/utils/plugins.py +17 -0
  623. mlflow/utils/process.py +182 -0
  624. mlflow/utils/promptlab_utils.py +146 -0
  625. mlflow/utils/proto_json_utils.py +743 -0
  626. mlflow/utils/pydantic_utils.py +54 -0
  627. mlflow/utils/request_utils.py +279 -0
  628. mlflow/utils/requirements_utils.py +704 -0
  629. mlflow/utils/rest_utils.py +673 -0
  630. mlflow/utils/search_logged_model_utils.py +127 -0
  631. mlflow/utils/search_utils.py +2111 -0
  632. mlflow/utils/secure_loading.py +221 -0
  633. mlflow/utils/security_validation.py +384 -0
  634. mlflow/utils/server_cli_utils.py +61 -0
  635. mlflow/utils/spark_utils.py +15 -0
  636. mlflow/utils/string_utils.py +138 -0
  637. mlflow/utils/thread_utils.py +63 -0
  638. mlflow/utils/time.py +54 -0
  639. mlflow/utils/timeout.py +42 -0
  640. mlflow/utils/uri.py +572 -0
  641. mlflow/utils/validation.py +662 -0
  642. mlflow/utils/virtualenv.py +458 -0
  643. mlflow/utils/warnings_utils.py +25 -0
  644. mlflow/utils/yaml_utils.py +179 -0
  645. mlflow/version.py +24 -0
mlflow/models/utils.py ADDED
@@ -0,0 +1,2054 @@
import base64
import datetime as dt
import decimal
import importlib
import json
import logging
import os
import re
import shutil
import sys
import tempfile
import uuid
from contextlib import contextmanager
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import numpy as np
import pandas as pd
import pydantic

import mlflow
from mlflow.entities import LoggedModel
from mlflow.exceptions import INVALID_PARAMETER_VALUE, MlflowException
from mlflow.models import Model
from mlflow.models.model_config import _set_model_config
from mlflow.store.artifact.utils.models import get_model_name_and_version
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.types import DataType, ParamSchema, ParamSpec, Schema, TensorSpec
from mlflow.types.schema import AnyType, Array, Map, Object, Property
from mlflow.types.utils import (
    TensorsNotSupportedException,
    _infer_param_schema,
    _is_none_or_nan,
    clean_tensor_type,
)
from mlflow.utils import IS_PYDANTIC_V2_OR_NEWER
from mlflow.utils.annotations import experimental
from mlflow.utils.databricks_utils import is_in_databricks_runtime
from mlflow.utils.file_utils import create_tmp_dir, get_local_path_or_none
from mlflow.utils.mlflow_tags import MLFLOW_MODEL_IS_EXTERNAL
from mlflow.utils.proto_json_utils import (
    NumpyEncoder,
    dataframe_from_parsed_json,
    parse_inputs_data,
    parse_tf_serving_input,
)
from mlflow.utils.uri import get_databricks_profile_uri_from_artifact_uri

try:
    from scipy.sparse import csc_matrix, csr_matrix

    HAS_SCIPY = True
except ImportError:
    HAS_SCIPY = False

try:
    from pyspark.sql import DataFrame as SparkDataFrame
    from pyspark.sql import Row
    from pyspark.sql.types import (
        ArrayType,
        BinaryType,
        DateType,
        FloatType,
        IntegerType,
        ShortType,
        StructType,
        TimestampType,
    )

    HAS_PYSPARK = True
except ImportError:
    SparkDataFrame = None
    HAS_PYSPARK = False


INPUT_EXAMPLE_PATH = "artifact_path"
EXAMPLE_DATA_KEY = "inputs"
EXAMPLE_PARAMS_KEY = "params"
EXAMPLE_FILENAME = "input_example.json"
SERVING_INPUT_PATH = "serving_input_path"
SERVING_INPUT_FILENAME = "serving_input_example.json"

# TODO: import from scoring_server after refactoring
DF_SPLIT = "dataframe_split"
INPUTS = "inputs"
SERVING_PARAMS_KEY = "params"

ModelInputExample = Union[
    pd.DataFrame, np.ndarray, dict, list, "csr_matrix", "csc_matrix", str, bytes, tuple
]

PyFuncLLMSingleInput = Union[
    dict[str, Any],
    bool,
    bytes,
    float,
    int,
    str,
]

PyFuncLLMOutputChunk = Union[
    dict[str, Any],
    str,
]

PyFuncInput = Union[
    pd.DataFrame,
    pd.Series,
    np.ndarray,
    "csc_matrix",
    "csr_matrix",
    List[Any],  # noqa: UP006
    Dict[str, Any],  # noqa: UP006
    dt.datetime,
    bool,
    bytes,
    float,
    int,
    str,
]
PyFuncOutput = Union[pd.DataFrame, pd.Series, np.ndarray, list, str]

if HAS_PYSPARK:
    PyFuncInput = Union[PyFuncInput, SparkDataFrame]
    PyFuncOutput = Union[PyFuncOutput, SparkDataFrame]

_logger = logging.getLogger(__name__)

_FEATURE_STORE_FLAVOR = "databricks.feature_store.mlflow_model"


def _is_scalar(x):
    return np.isscalar(x) or x is None


def _validate_params(params):
    try:
        _infer_param_schema(params)
    except MlflowException:
        _logger.warning(f"Invalid params found in input example: {params}")
        raise


def _is_ndarray(x):
    return isinstance(x, np.ndarray) or (
        isinstance(x, dict) and all(isinstance(ary, np.ndarray) for ary in x.values())
    )


def _is_sparse_matrix(x):
    if not HAS_SCIPY:
        # we can safely assume that if no scipy is installed,
        # the user won't log scipy sparse matrices
        return False
    return isinstance(x, (csc_matrix, csr_matrix))


def _handle_ndarray_nans(x: np.ndarray):
    if np.issubdtype(x.dtype, np.number):
        return np.where(np.isnan(x), None, x)
    else:
        return x


def _handle_ndarray_input(input_array: Union[np.ndarray, dict[str, Any]]):
    if isinstance(input_array, dict):
        result = {}
        for name in input_array.keys():
            result[name] = _handle_ndarray_nans(input_array[name]).tolist()
        return result
    else:
        return _handle_ndarray_nans(input_array).tolist()

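To make the NaN handling above concrete, here is a short editor-added sketch (the array values are made up); numeric NaNs become None so the example can be written out as JSON null:

import numpy as np

arr = np.array([1.0, np.nan, 3.0])
_handle_ndarray_nans(arr).tolist()   # [1.0, None, 3.0]
_handle_ndarray_input({"x": arr})    # {"x": [1.0, None, 3.0]} -- JSON-serializable
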
def _handle_sparse_matrix(x: Union["csr_matrix", "csc_matrix"]):
    return {
        "data": _handle_ndarray_nans(x.data).tolist(),
        "indices": x.indices.tolist(),
        "indptr": x.indptr.tolist(),
        "shape": list(x.shape),
    }


def _handle_dataframe_nans(df: pd.DataFrame):
    return df.where(df.notnull(), None)


def _coerce_to_pandas_df(input_ex):
    if isinstance(input_ex, dict):
        # We need to be compatible with infer_schema's behavior, where
        # it infers each value's type directly.
        if all(
            isinstance(x, str) or (isinstance(x, list) and all(_is_scalar(y) for y in x))
            for x in input_ex.values()
        ):
            # e.g.
            # data = {"a": "a", "b": ["a", "b", "c"]}
            # >>> pd.DataFrame([data])
            #    a          b
            # 0  a  [a, b, c]
            _logger.info(
                "We convert input dictionaries to pandas DataFrames such that "
                "each key represents a column, collectively constituting a "
                "single row of data. If you would like to save data as "
                "multiple rows, please convert your data to a pandas "
                "DataFrame before passing to input_example."
            )
            input_ex = pd.DataFrame([input_ex])
    elif np.isscalar(input_ex):
        input_ex = pd.DataFrame([input_ex])
    elif not isinstance(input_ex, pd.DataFrame):
        input_ex = None
    return input_ex


def _convert_dataframe_to_split_dict(df):
    result = _handle_dataframe_nans(df).to_dict(orient="split")
    # Do not include row index
    del result["index"]
    if all(df.columns == range(len(df.columns))):
        # No need to write default column index out
        del result["columns"]
    return result

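As an editor-added illustration of the "split" layout produced above (column names and values are hypothetical): the row index is always dropped, and the column list is dropped too when it is just the default 0..n-1 range, which later maps to pandas_orient="values".

import pandas as pd

df = pd.DataFrame({"a": [1, 2], "b": ["x", None]})
_convert_dataframe_to_split_dict(df)
# {'columns': ['a', 'b'], 'data': [[1, 'x'], [2, None]]}
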
def _contains_nd_array(data):
    import numpy as np

    if isinstance(data, np.ndarray):
        return True
    if isinstance(data, list):
        return any(_contains_nd_array(x) for x in data)
    if isinstance(data, dict):
        return any(_contains_nd_array(x) for x in data.values())
    return False


class _Example:
    """
    Represents an input example for an MLflow model.

    Contains jsonable data that can be saved with the model and metadata about the exported format
    that can be saved with :py:class:`Model <mlflow.models.Model>`.

    The _Example is created from example data provided by the user. The example(s) can be provided
    as pandas.DataFrame, numpy.ndarray, python dictionary or python list. The assumption is that
    the example contains jsonable elements (see storage format section below). The input example
    will be saved as a json serializable object if it is a pandas DataFrame or numpy array.
    If the example is a tuple, the first element is considered as the example data and the second
    element is considered as the example params.

    NOTE: serving input example is not supported for sparse matrices yet.

    Metadata:

    The _Example metadata contains the following information:
        - artifact_path: Relative path to the serialized example within the model directory.
        - serving_input_path: Relative path to the serialized example used for model serving
          within the model directory.
        - type: Type of example data provided by the user. Supported types are:
            - ndarray
            - dataframe
            - json_object
            - sparse_matrix_csc
            - sparse_matrix_csr
          If the `type` is `dataframe`, `pandas_orient` is also stored in the metadata. This
          attribute specifies how the dataframe is encoded in json. For example, the "split" value
          signals that the data is stored as an object with columns and data attributes.

    Storage Format:

    The examples are stored as json for portability and readability. Therefore, the contents of
    the example(s) must be jsonable. MLflow will make the following conversions automatically on
    behalf of the user:

        - binary values: :py:class:`bytes` or :py:class:`bytearray` are converted to base64
          encoded strings.
        - numpy types: Numpy types are converted to the corresponding python types or their
          closest equivalent.
        - csc/csr matrix: similar to 2 dims numpy array, csc/csr matrix are converted to
          corresponding python types or their closest equivalent.
    """

    def __init__(self, input_example: ModelInputExample):
        try:
            import pyspark.sql

            if isinstance(input_example, pyspark.sql.DataFrame):
                raise MlflowException(
                    "Examples can not be provided as Spark Dataframe. "
                    "Please make sure your example is of a small size and "
                    "turn it into a pandas DataFrame."
                )
        except ImportError:
            pass

        self.info = {
            INPUT_EXAMPLE_PATH: EXAMPLE_FILENAME,
        }

        self._inference_data, self._inference_params = _split_input_data_and_params(
            deepcopy(input_example)
        )
        if self._inference_params:
            self.info[EXAMPLE_PARAMS_KEY] = "true"
        model_input = deepcopy(self._inference_data)

        if isinstance(model_input, pydantic.BaseModel):
            model_input = (
                model_input.model_dump() if IS_PYDANTIC_V2_OR_NEWER else model_input.dict()
            )

        is_unified_llm_input = False
        if isinstance(model_input, dict):
            """
            Supported types are:
                - Dict[str, Union[DataType, List, Dict]] --> type: json_object
                - Dict[str, numpy.ndarray] --> type: ndarray
            """
            if any(isinstance(values, np.ndarray) for values in model_input.values()):
                if not all(isinstance(values, np.ndarray) for values in model_input.values()):
                    raise MlflowException.invalid_parameter_value(
                        "Mixed types in dictionary are not supported as input examples. "
                        "Found numpy arrays and other types."
                    )
                self.info["type"] = "ndarray"
                model_input = _handle_ndarray_input(model_input)
                self.serving_input = {INPUTS: model_input}
            else:
                from mlflow.pyfunc.utils.serving_data_parser import is_unified_llm_input

                self.info["type"] = "json_object"
                is_unified_llm_input = is_unified_llm_input(model_input)
                if is_unified_llm_input:
                    self.serving_input = model_input
                else:
                    self.serving_input = {INPUTS: model_input}
        elif isinstance(model_input, np.ndarray):
            """type: ndarray"""
            model_input = _handle_ndarray_input(model_input)
            self.info["type"] = "ndarray"
            self.serving_input = {INPUTS: model_input}
        elif isinstance(model_input, list):
            """
            Supported types are:
                - List[DataType]
                - List[Dict[str, Union[DataType, List, Dict]]]
            --> type: json_object
            """
            if _contains_nd_array(model_input):
                raise TensorsNotSupportedException(
                    "Numpy arrays in list are not supported as input examples."
                )
            self.info["type"] = "json_object"
            self.serving_input = {INPUTS: model_input}
        elif _is_sparse_matrix(model_input):
            """
            Supported types are:
                - scipy.sparse.csr_matrix
                - scipy.sparse.csc_matrix
            Note: This type of input is not supported by the scoring server yet
            """
            if isinstance(model_input, csc_matrix):
                example_type = "sparse_matrix_csc"
            else:
                example_type = "sparse_matrix_csr"
            self.info["type"] = example_type
            self.serving_input = {INPUTS: model_input.toarray()}
            model_input = _handle_sparse_matrix(model_input)
        elif isinstance(model_input, pd.DataFrame):
            model_input = _convert_dataframe_to_split_dict(model_input)
            self.serving_input = {DF_SPLIT: model_input}
            orient = "split" if "columns" in model_input else "values"
            self.info.update(
                {
                    "type": "dataframe",
                    "pandas_orient": orient,
                }
            )
        elif np.isscalar(model_input) or isinstance(model_input, dt.datetime):
            self.info["type"] = "json_object"
            self.serving_input = {INPUTS: model_input}
        else:
            raise MlflowException.invalid_parameter_value(
                "Expected one of the following types:\n"
                "- pandas.DataFrame\n"
                "- numpy.ndarray\n"
                "- dictionary of (name -> numpy.ndarray)\n"
                "- scipy.sparse.csr_matrix\n"
                "- scipy.sparse.csc_matrix\n"
                "- dict\n"
                "- list\n"
                "- scalars\n"
                "- datetime.datetime\n"
                "- pydantic model instance\n"
                f"but got '{type(model_input)}'",
            )

        if self._inference_params is not None:
            """
            Save input data and params with their respective keys, so we can load them separately.
            """
            model_input = {
                EXAMPLE_DATA_KEY: model_input,
                EXAMPLE_PARAMS_KEY: self._inference_params,
            }
            if self.serving_input:
                if is_unified_llm_input:
                    self.serving_input = {
                        **(self.serving_input or {}),
                        **self._inference_params,
                    }
                else:
                    self.serving_input = {
                        **(self.serving_input or {}),
                        SERVING_PARAMS_KEY: self._inference_params,
                    }

        self.json_input_example = json.dumps(model_input, cls=NumpyEncoder)
        if self.serving_input:
            self.json_serving_input = json.dumps(self.serving_input, cls=NumpyEncoder, indent=2)
            self.info[SERVING_INPUT_PATH] = SERVING_INPUT_FILENAME
        else:
            self.json_serving_input = None

    def save(self, parent_dir_path: str):
        """
        Save the example as json at ``parent_dir_path``/`self.info['artifact_path']`.
        Save serving input as json at ``parent_dir_path``/`self.info['serving_input_path']`.
        """
        with open(os.path.join(parent_dir_path, self.info[INPUT_EXAMPLE_PATH]), "w") as f:
            f.write(self.json_input_example)
        if self.json_serving_input:
            with open(os.path.join(parent_dir_path, self.info[SERVING_INPUT_PATH]), "w") as f:
                f.write(self.json_serving_input)

    @property
    def inference_data(self):
        """
        Returns the input example in a form that PyFunc wrapped models can score.
        """
        return self._inference_data

    @property
    def inference_params(self):
        """
        Returns the params dictionary that PyFunc wrapped models can use for scoring.
        """
        return self._inference_params

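For orientation (an editor's sketch, not part of the packaged file), constructing _Example from a one-row DataFrame yields metadata and a serving payload along these lines:

import json
import pandas as pd

example = _Example(pd.DataFrame({"question": ["What is MLflow?"]}))
example.info
# {'artifact_path': 'input_example.json', 'type': 'dataframe',
#  'pandas_orient': 'split', 'serving_input_path': 'serving_input_example.json'}
json.loads(example.json_serving_input)
# {'dataframe_split': {'columns': ['question'], 'data': [['What is MLflow?']]}}
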
def _contains_params(input_example):
    # For tuple input, we assume the first item is input_example data
    # and the second item is params dictionary.
    return (
        isinstance(input_example, tuple)
        and len(input_example) == 2
        and isinstance(input_example[1], dict)
    )


def _split_input_data_and_params(input_example):
    if _contains_params(input_example):
        input_data, inference_params = input_example
        _validate_params(inference_params)
        return input_data, inference_params
    return input_example, None

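A brief editor-added sketch of the tuple convention handled above (the data and params values are made up): a 2-tuple whose second element is a dict is split into example data and inference params.

data = {"prompt": "Hello"}
params = {"temperature": 0.2, "max_tokens": 64}

_contains_params((data, params))              # True
_split_input_data_and_params((data, params))  # ({'prompt': 'Hello'}, {'temperature': 0.2, 'max_tokens': 64})
_split_input_data_and_params(data)            # ({'prompt': 'Hello'}, None)
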
@experimental(version="2.16.0")
def convert_input_example_to_serving_input(input_example) -> Optional[str]:
    """
    Helper function to convert a model's input example to a serving input example that
    can be used for model inference in the scoring server.

    Args:
        input_example: model input example. Supported types are pandas.DataFrame, numpy.ndarray,
            dictionary of (name -> numpy.ndarray), list, scalars and dicts with json serializable
            values.

    Returns:
        serving input example as a json string
    """
    if input_example is None:
        return None

    example = _Example(input_example)
    return example.json_serving_input

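An editor-added usage sketch for the helper above, assuming the same mlflow.models re-export as upstream MLflow; the input array is made up:

import numpy as np

from mlflow.models import convert_input_example_to_serving_input  # assumed re-export

serving_example = convert_input_example_to_serving_input(np.array([[1.0, 2.0], [3.0, 4.0]]))
# A JSON string whose parsed form is {"inputs": [[1.0, 2.0], [3.0, 4.0]]}
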
def _save_example(  # noqa: D417
    mlflow_model: Model, input_example: Optional[ModelInputExample], path: str
) -> Optional[_Example]:
    """
    Saves example to a file on the given path and updates passed Model with example metadata.

    The metadata is a dictionary with the following fields:
        - 'artifact_path': example path relative to the model directory.
        - 'type': Type of example. Currently the supported values are 'dataframe' and 'ndarray'
        - One of the following metadata based on the `type`:
            - 'pandas_orient': Used to store dataframes. Determines the json encoding for
              dataframe examples in terms of pandas orient convention. Defaults to 'split'.
            - 'format': Used to store tensors. Determines the standard used to store a tensor
              input example. MLflow uses a JSON-formatted string representation of TF serving
              input.

    Args:
        mlflow_model: Model metadata that will get updated with the example metadata.
        path: Where to store the example file. Should be the model directory.

    Returns:
        _Example object that contains saved input example.
    """
    if input_example is None:
        return None

    example = _Example(input_example)
    example.save(path)
    mlflow_model.saved_input_example_info = example.info
    return example


def _get_mlflow_model_input_example_dict(
    mlflow_model: Model, uri_or_path: str
) -> Optional[dict[str, Any]]:
    """
    Args:
        mlflow_model: Model metadata.
        uri_or_path: Model or run URI, or path to the `model` directory.
            e.g. models://<model_name>/<model_version>, runs:/<run_id>/<artifact_path>
            or /path/to/model

    Returns:
        Input example or None if the model has no example.
    """
    if mlflow_model.saved_input_example_info is None:
        return None
    example_type = mlflow_model.saved_input_example_info["type"]
    if example_type not in [
        "dataframe",
        "ndarray",
        "sparse_matrix_csc",
        "sparse_matrix_csr",
        "json_object",
    ]:
        raise MlflowException(f"This version of mlflow can not load example of type {example_type}")
    return json.loads(
        _read_file_content(uri_or_path, mlflow_model.saved_input_example_info[INPUT_EXAMPLE_PATH])
    )


def _load_serving_input_example(mlflow_model: Model, path: str) -> Optional[str]:
    """
    Load serving input example from a model directory. Returns None if there is no serving input
    example.

    Args:
        mlflow_model: Model metadata.
        path: Path to the model directory.

    Returns:
        Serving input example or None if the model has no serving input example.
    """
    if mlflow_model.saved_input_example_info is None:
        return None
    serving_input_path = mlflow_model.saved_input_example_info.get(SERVING_INPUT_PATH)
    if serving_input_path is None:
        return None
    with open(os.path.join(path, serving_input_path)) as handle:
        return handle.read()


def load_serving_example(model_uri_or_path: str):
    """
    Load serving input example from a model directory or URI.

    Args:
        model_uri_or_path: Model URI or path to the `model` directory.
            e.g. models://<model_name>/<model_version> or /path/to/model
    """
    return _read_file_content(model_uri_or_path, SERVING_INPUT_FILENAME)

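An editor-added sketch of how the loader above can be used (the model URI and server address are hypothetical): the returned string is the JSON payload saved next to the model, ready to post to an MLflow scoring server's /invocations endpoint.

import requests

serving_payload = load_serving_example("runs:/<run_id>/model")  # hypothetical URI
response = requests.post(
    "http://127.0.0.1:5000/invocations",  # assumes `mlflow models serve` running locally
    data=serving_payload,
    headers={"Content-Type": "application/json"},
)
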
585
+ def _read_file_content(uri_or_path: str, file_name: str):
586
+ """
587
+ Read file content from a model directory or URI.
588
+
589
+ Args:
590
+ uri_or_path: Model or run URI, or path to the `model` directory.
591
+ e.g. models://<model_name>/<model_version>, runs:/<run_id>/<artifact_path>
592
+ or /path/to/model
593
+ file_name: Name of the file to read.
594
+ """
595
+ from mlflow.store.artifact.models_artifact_repo import ModelsArtifactRepository
596
+
597
+ if ModelsArtifactRepository._is_logged_model_uri(uri_or_path):
598
+ uri_or_path = ModelsArtifactRepository.get_underlying_uri(uri_or_path)
599
+
600
+ file_path = str(uri_or_path).rstrip("/") + "/" + file_name
601
+ if os.path.exists(file_path):
602
+ with open(file_path) as handle:
603
+ return handle.read()
604
+ else:
605
+ with tempfile.TemporaryDirectory() as tmpdir:
606
+ local_file_path = _download_artifact_from_uri(file_path, output_path=tmpdir)
607
+ with open(local_file_path) as handle:
608
+ return handle.read()
609
+
610
+
611
+ def _read_example(mlflow_model: Model, uri_or_path: str):
612
+ """
613
+ Read example from a model directory. Returns None if there is no example metadata (i.e. the
614
+ model was saved without example). Raises FileNotFoundError if there is model metadata but the
615
+ example file is missing.
616
+
617
+ Args:
618
+ mlflow_model: Model metadata.
619
+ uri_or_path: Model or run URI, or path to the `model` directory.
620
+ e.g. models://<model_name>/<model_version>, runs:/<run_id>/<artifact_path>
621
+ or /path/to/model
622
+
623
+ Returns:
624
+ Input example data or None if the model has no example.
625
+ """
626
+ input_example = _get_mlflow_model_input_example_dict(mlflow_model, uri_or_path)
627
+ if input_example is None:
628
+ return None
629
+
630
+ example_type = mlflow_model.saved_input_example_info["type"]
631
+ input_schema = mlflow_model.signature.inputs if mlflow_model.signature is not None else None
632
+ if mlflow_model.saved_input_example_info.get(EXAMPLE_PARAMS_KEY, None):
633
+ input_example = input_example[EXAMPLE_DATA_KEY]
634
+ if example_type == "json_object":
635
+ return input_example
636
+ if example_type == "ndarray":
637
+ return parse_inputs_data(input_example, schema=input_schema)
638
+ if example_type in ["sparse_matrix_csc", "sparse_matrix_csr"]:
639
+ return _read_sparse_matrix_from_json(input_example, example_type)
640
+ if example_type == "dataframe":
641
+ return dataframe_from_parsed_json(input_example, pandas_orient="split", schema=input_schema)
642
+ raise MlflowException(
643
+ "Malformed input example metadata. The 'type' field must be one of "
644
+ "'dataframe', 'ndarray', 'sparse_matrix_csc', 'sparse_matrix_csr' or 'json_object'."
645
+ )
646
+
647
+
648
+ def _read_example_params(mlflow_model: Model, path: str):
649
+ """
650
+ Read params of input_example from a model directory. Returns None if there are no params
651
+ in the input_example or the model was saved without an example.
652
+ """
653
+ if (
654
+ mlflow_model.saved_input_example_info is None
655
+ or mlflow_model.saved_input_example_info.get(EXAMPLE_PARAMS_KEY, None) is None
656
+ ):
657
+ return None
658
+ input_example_dict = _get_mlflow_model_input_example_dict(mlflow_model, path)
659
+ return input_example_dict[EXAMPLE_PARAMS_KEY]
660
+
661
+
662
+ def _read_tensor_input_from_json(path_or_data, schema=None):
663
+ if isinstance(path_or_data, str) and os.path.exists(path_or_data):
664
+ with open(path_or_data) as handle:
665
+ inp_dict = json.load(handle)
666
+ else:
667
+ inp_dict = path_or_data
668
+ return parse_tf_serving_input(inp_dict, schema)
669
+
670
+
671
+ def _read_sparse_matrix_from_json(path_or_data, example_type):
672
+ if isinstance(path_or_data, str) and os.path.exists(path_or_data):
673
+ with open(path_or_data) as handle:
674
+ matrix_data = json.load(handle)
675
+ else:
676
+ matrix_data = path_or_data
677
+ data = matrix_data["data"]
678
+ indices = matrix_data["indices"]
679
+ indptr = matrix_data["indptr"]
680
+ shape = tuple(matrix_data["shape"])
681
+
682
+ if example_type == "sparse_matrix_csc":
683
+ return csc_matrix((data, indices, indptr), shape=shape)
684
+ else:
685
+ return csr_matrix((data, indices, indptr), shape=shape)
686
+
687
+
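A hedged round-trip sketch of the JSON layout this helper expects for sparse examples; the values are arbitrary, scipy must be installed, and the module-private helper is assumed to be in scope.

.. code-block:: python

    from scipy.sparse import csr_matrix

    dense = [[1, 0, 2], [0, 0, 3]]
    original = csr_matrix(dense)
    payload = {
        "data": original.data.tolist(),
        "indices": original.indices.tolist(),
        "indptr": original.indptr.tolist(),
        "shape": list(original.shape),
    }
    restored = _read_sparse_matrix_from_json(payload, "sparse_matrix_csr")
    assert (restored.toarray() == original.toarray()).all()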
688
+ def plot_lines(data_series, xlabel, ylabel, legend_loc=None, line_kwargs=None, title=None):
689
+ import matplotlib.pyplot as plt
690
+
691
+ fig, ax = plt.subplots()
692
+
693
+ if line_kwargs is None:
694
+ line_kwargs = {}
695
+
696
+ for label, data_x, data_y in data_series:
697
+ ax.plot(data_x, data_y, label=label, **line_kwargs)
698
+
699
+ if legend_loc:
700
+ ax.legend(loc=legend_loc)
701
+
702
+ ax.set(xlabel=xlabel, ylabel=ylabel, title=title)
703
+
704
+ return fig, ax
705
+
706
+
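For illustration, a possible call to this plotting helper (matplotlib required); the series data is made up.

.. code-block:: python

    fig, ax = plot_lines(
        data_series=[
            ("train", [0, 1, 2], [0.9, 0.6, 0.4]),
            ("validation", [0, 1, 2], [1.0, 0.8, 0.7]),
        ],
        xlabel="epoch",
        ylabel="loss",
        legend_loc="upper right",
        title="Loss curves",
    )
    fig.savefig("loss_curves.png")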
707
+ def _enforce_tensor_spec(
708
+ values: Union[np.ndarray, "csc_matrix", "csr_matrix"],
709
+ tensor_spec: TensorSpec,
710
+ ):
711
+ """
712
+ Enforce that the input tensor shape and type match the provided tensor spec.
713
+ """
714
+ expected_shape = tensor_spec.shape
715
+ expected_type = tensor_spec.type
716
+ actual_shape = values.shape
717
+ actual_type = values.dtype if isinstance(values, np.ndarray) else values.data.dtype
718
+
719
+ # This logic is for handling "ragged" arrays. The first check is for a standard numpy shape
720
+ # representation of a ragged array. The second is for handling a more manual specification
721
+ # of shape while supporting an input which is a ragged array.
722
+ if len(expected_shape) == 1 and expected_shape[0] == -1 and expected_type == np.dtype("O"):
723
+ # Sample spec: Tensor('object', (-1,))
724
+ # Will pass on any provided input
725
+ return values
726
+ if (
727
+ len(expected_shape) > 1
728
+ and -1 in expected_shape[1:]
729
+ and len(actual_shape) == 1
730
+ and actual_type == np.dtype("O")
731
+ ):
732
+ # Sample spec: Tensor('float64', (-1, -1, -1, 3))
733
+ # Will pass on inputs which are ragged arrays: shape==(x,), dtype=='object'
734
+ return values
735
+
736
+ if len(expected_shape) != len(actual_shape):
737
+ raise MlflowException(
738
+ f"Shape of input {actual_shape} does not match expected shape {expected_shape}."
739
+ )
740
+ for expected, actual in zip(expected_shape, actual_shape):
741
+ if expected == -1:
742
+ continue
743
+ if expected != actual:
744
+ raise MlflowException(
745
+ f"Shape of input {actual_shape} does not match expected shape {expected_shape}."
746
+ )
747
+ if clean_tensor_type(actual_type) != expected_type:
748
+ raise MlflowException(
749
+ f"dtype of input {actual_type} does not match expected dtype {expected_type}"
750
+ )
751
+ return values
752
+
753
+
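A small sketch of the shape and dtype checks above, assuming `TensorSpec` is importable from `mlflow.types.schema` as in upstream MLflow.

.. code-block:: python

    import numpy as np
    from mlflow.types.schema import TensorSpec

    spec = TensorSpec(np.dtype("float64"), (-1, 3))

    # Any batch size is accepted because the first dimension is variable (-1).
    _enforce_tensor_spec(np.zeros((5, 3), dtype=np.float64), spec)

    # A mismatched trailing dimension or dtype raises MlflowException, e.g.:
    # _enforce_tensor_spec(np.zeros((5, 2), dtype=np.float64), spec)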
754
+ def _enforce_mlflow_datatype(name, values: pd.Series, t: DataType):
755
+ """
756
+ Enforce that the input column type matches the type declared in the model input schema.
757
+
758
+ The following type conversions are allowed:
759
+
760
+ 1. object -> string
761
+ 2. int -> long (upcast)
762
+ 3. float -> double (upcast)
763
+ 4. int -> double (safe conversion)
764
+ 5. np.datetime64[x] -> datetime (any precision)
765
+ 6. object -> datetime
766
+
767
+ NB: pandas does not have a native decimal data type; when a user trains and infers a
768
+ model from a pyspark dataframe that contains a decimal type, the schema will be
769
+ treated as float64.
770
+ 7. decimal -> double
771
+
772
+ Any other type mismatch will raise error.
773
+ """
774
+
775
+ if values.dtype == object and t not in (DataType.binary, DataType.string):
776
+ values = values.infer_objects()
777
+
778
+ if t == DataType.string and values.dtype == object:
779
+ # NB: the object can contain any type and we currently cannot cast to pandas Strings
780
+ # due to how None is cast
781
+ return values
782
+
783
+ # NB: Comparison of pandas and numpy data type fails when numpy data type is on the left hand
784
+ # side of the comparison operator. It works, however, if pandas type is on the left hand side.
785
+ # That is because pandas is aware of numpy.
786
+ if t.to_pandas() == values.dtype or t.to_numpy() == values.dtype:
787
+ # The types are already compatible => conversion is not necessary.
788
+ return values
789
+
790
+ if t == DataType.binary and values.dtype.kind == t.binary.to_numpy().kind:
791
+ # NB: bytes in numpy have variable itemsize depending on the length of the longest
792
+ # element in the array (column). Since MLflow binary type is length agnostic, we ignore
793
+ # itemsize when matching binary columns.
794
+ return values
795
+
796
+ if t == DataType.datetime and values.dtype.kind == t.to_numpy().kind:
797
+ # NB: datetime values have variable precision denoted by brackets, e.g. datetime64[ns]
798
+ # denotes nanosecond precision. Since MLflow datetime type is precision agnostic, we
799
+ # ignore precision when matching datetime columns.
800
+ try:
801
+ return values.astype(np.dtype("datetime64[ns]"))
802
+ except TypeError as e:
803
+ raise MlflowException(
804
+ "Please ensure that the input data of datetime column only contains timezone-naive "
805
+ f"datetime objects. Error: {e}"
806
+ )
807
+
808
+ if t == DataType.datetime and (values.dtype == object or values.dtype == t.to_python()):
809
+ # NB: Pyspark date columns get converted to object when converted to a pandas
810
+ # DataFrame. To respect the original typing, we convert the column to datetime.
811
+ try:
812
+ return values.astype(np.dtype("datetime64[ns]"), errors="raise")
813
+ except ValueError as e:
814
+ raise MlflowException(
815
+ f"Failed to convert column {name} from type {values.dtype} to {t}."
816
+ ) from e
817
+
818
+ if t == DataType.boolean and values.dtype == object:
819
+ # Should not convert type otherwise it converts None to boolean False
820
+ return values
821
+
822
+ if t == DataType.double and values.dtype == decimal.Decimal:
823
+ # NB: Pyspark Decimal columns get converted to decimal.Decimal when converted to a pandas
824
+ # DataFrame. In order to support decimal data training from spark data frame, we add this
825
+ # conversion even though we might lose precision.
826
+ try:
827
+ return pd.to_numeric(values, errors="raise")
828
+ except ValueError:
829
+ raise MlflowException(
830
+ f"Failed to convert column {name} from type {values.dtype} to {t}."
831
+ )
832
+
833
+ numpy_type = t.to_numpy()
834
+ if values.dtype.kind == numpy_type.kind:
835
+ is_upcast = values.dtype.itemsize <= numpy_type.itemsize
836
+ elif values.dtype.kind == "u" and numpy_type.kind == "i":
837
+ is_upcast = values.dtype.itemsize < numpy_type.itemsize
838
+ elif values.dtype.kind in ("i", "u") and numpy_type == np.float64:
839
+ # allow (u)int => double conversion
840
+ is_upcast = values.dtype.itemsize <= 6
841
+ else:
842
+ is_upcast = False
843
+
844
+ if is_upcast:
845
+ return values.astype(numpy_type, errors="raise")
846
+ else:
847
+ # support converting long -> float/double for 0 and 1 values
848
+ def all_zero_or_ones(xs):
849
+ return all(pd.isnull(x) or x in [0, 1] for x in xs)
850
+
851
+ if (
852
+ values.dtype == np.int64
853
+ and numpy_type in (np.float32, np.float64)
854
+ and all_zero_or_ones(values)
855
+ ):
856
+ return values.astype(numpy_type, errors="raise")
857
+
858
+ # NB: conversion between incompatible types (e.g. floats -> ints or
859
+ # double -> float) are not allowed. While supported by pandas and numpy,
860
+ # these conversions alter the values significantly.
861
+ def all_ints(xs):
862
+ return all(pd.isnull(x) or int(x) == x for x in xs)
863
+
864
+ hint = ""
865
+ if (
866
+ values.dtype == np.float64
867
+ and numpy_type.kind in ("i", "u")
868
+ and values.hasnans
869
+ and all_ints(values)
870
+ ):
871
+ hint = (
872
+ " Hint: the type mismatch is likely caused by missing values. "
873
+ "Integer columns in python can not represent missing values and are therefore "
874
+ "encoded as floats. The best way to avoid this problem is to infer the model "
875
+ "schema based on a realistic data sample (training dataset) that includes missing "
876
+ "values. Alternatively, you can declare integer columns as doubles (float64) "
877
+ "whenever these columns may have missing values. See `Handling Integers With "
878
+ "Missing Values <https://www.mlflow.org/docs/latest/models.html#"
879
+ "handling-integers-with-missing-values>`_ for more details."
880
+ )
881
+
882
+ raise MlflowException(
883
+ f"Incompatible input types for column {name}. "
884
+ f"Can not safely convert {values.dtype} to {numpy_type}.{hint}"
885
+ )
886
+
887
+
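A hedged illustration of the conversion rules documented in the docstring above: a narrow integer column upcasts safely to double, while a lossy conversion is rejected. The `DataType` import assumes the upstream MLflow layout.

.. code-block:: python

    import pandas as pd
    from mlflow.types import DataType

    ints = pd.Series([1, 2, 3], dtype="int32")
    doubles = _enforce_mlflow_datatype("age", ints, DataType.double)
    assert doubles.dtype == "float64"

    # Converting float values to an integer type is rejected as unsafe:
    # _enforce_mlflow_datatype("age", pd.Series([1.5]), DataType.long)  # raises MlflowException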
888
+ # dtype -> possible value types mapping
889
+ _ALLOWED_CONVERSIONS_FOR_PARAMS = {
890
+ DataType.long: (DataType.integer,),
891
+ DataType.float: (DataType.integer, DataType.long),
892
+ DataType.double: (DataType.integer, DataType.long, DataType.float),
893
+ }
894
+
895
+
896
+ def _enforce_param_datatype(value: Any, dtype: DataType):
897
+ """
898
+ Enforce that the value matches the data type. This is used to enforce param datatypes.
899
+ The returned data is of python built-in type or a datetime object.
900
+
901
+ The following type conversions are allowed:
902
+
903
+ 1. int -> long, float, double
904
+ 2. long -> float, double
905
+ 3. float -> double
906
+ 4. any -> datetime (try conversion)
907
+
908
+ Any other type mismatch will raise error.
909
+
910
+ Args:
911
+ value: parameter value
912
+ dtype: expected data type
913
+ """
914
+ if value is None:
915
+ return
916
+
917
+ if dtype == DataType.datetime:
918
+ try:
919
+ datetime_value = np.datetime64(value).item()
920
+ if isinstance(datetime_value, int):
921
+ raise MlflowException.invalid_parameter_value(
922
+ f"Failed to convert value to `{dtype}`. "
923
+ f"It must be convertible to datetime.date/datetime, got `{value}`"
924
+ )
925
+ return datetime_value
926
+ except ValueError as e:
927
+ raise MlflowException.invalid_parameter_value(
928
+ f"Failed to convert value `{value}` from type `{type(value)}` to `{dtype}`"
929
+ ) from e
930
+
931
+ # Note that np.isscalar(datetime.date(...)) is False
932
+ if not np.isscalar(value):
933
+ raise MlflowException.invalid_parameter_value(
934
+ f"Value must be a scalar for type `{dtype}`, got `{value}`"
935
+ )
936
+
937
+ # Always convert to python native type for params
938
+ if DataType.check_type(dtype, value):
939
+ return dtype.to_python()(value)
940
+
941
+ if dtype in _ALLOWED_CONVERSIONS_FOR_PARAMS and any(
942
+ DataType.check_type(t, value) for t in _ALLOWED_CONVERSIONS_FOR_PARAMS[dtype]
943
+ ):
944
+ try:
945
+ return dtype.to_python()(value)
946
+ except ValueError as e:
947
+ raise MlflowException.invalid_parameter_value(
948
+ f"Failed to convert value `{value}` from type `{type(value)}` to `{dtype}`"
949
+ ) from e
950
+
951
+ raise MlflowException.invalid_parameter_value(
952
+ f"Can not safely convert `{type(value)}` to `{dtype}` for value `{value}`"
953
+ )
954
+
955
+
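A hedged sketch of the parameter conversions listed above; `DataType` is assumed importable from `mlflow.types` as in upstream MLflow.

.. code-block:: python

    import datetime
    from mlflow.types import DataType

    # int -> double is widened to a Python float
    assert _enforce_param_datatype(3, DataType.double) == 3.0

    # strings are coerced for datetime params
    assert _enforce_param_datatype("2024-01-01", DataType.datetime) == datetime.date(2024, 1, 1)

    # an unconvertible value raises MlflowException.invalid_parameter_value:
    # _enforce_param_datatype("abc", DataType.long)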
956
+ def _enforce_unnamed_col_schema(pf_input: pd.DataFrame, input_schema: Schema):
957
+ """Enforce the input columns conform to the model's column-based signature."""
958
+ input_names = pf_input.columns[: len(input_schema.inputs)]
959
+ input_types = input_schema.input_types()
960
+ new_pf_input = {}
961
+ for i, x in enumerate(input_names):
962
+ if isinstance(input_types[i], DataType):
963
+ new_pf_input[x] = _enforce_mlflow_datatype(x, pf_input[x], input_types[i])
964
+ # If the input_type is objects/arrays/maps, we assume pf_input must be a pandas DataFrame.
965
+ # Otherwise, the schema is not valid.
966
+ else:
967
+ new_pf_input[x] = pd.Series(
968
+ [_enforce_type(obj, input_types[i]) for obj in pf_input[x]], name=x
969
+ )
970
+ return pd.DataFrame(new_pf_input)
971
+
972
+
973
+ def _enforce_named_col_schema(pf_input: pd.DataFrame, input_schema: Schema):
974
+ """Enforce the input columns conform to the model's column-based signature."""
975
+ input_names = input_schema.input_names()
976
+ input_dict = input_schema.input_dict()
977
+ new_pf_input = {}
978
+ for name in input_names:
979
+ input_type = input_dict[name].type
980
+ required = input_dict[name].required
981
+ if name not in pf_input:
982
+ if required:
983
+ raise MlflowException(
984
+ f"The input column '{name}' is required by the model "
985
+ "signature but missing from the input data."
986
+ )
987
+ else:
988
+ continue
989
+ if isinstance(input_type, DataType):
990
+ new_pf_input[name] = _enforce_mlflow_datatype(name, pf_input[name], input_type)
991
+ # If the input_type is objects/arrays/maps, we assume pf_input must be a pandas DataFrame.
992
+ # Otherwise, the schema is not valid.
993
+ else:
994
+ new_pf_input[name] = pd.Series(
995
+ [_enforce_type(obj, input_type, required) for obj in pf_input[name]], name=name
996
+ )
997
+ return pd.DataFrame(new_pf_input)
998
+
999
+
1000
+ def _reshape_and_cast_pandas_column_values(name, pd_series, tensor_spec):
1001
+ if tensor_spec.shape[0] != -1 or -1 in tensor_spec.shape[1:]:
1002
+ raise MlflowException(
1003
+ "For pandas dataframe input, the first dimension of shape must be a variable "
1004
+ "dimension and other dimensions must be fixed, but in model signature the shape "
1005
+ f"of {'input ' + name if name else 'the unnamed input'} is {tensor_spec.shape}."
1006
+ )
1007
+
1008
+ if np.isscalar(pd_series[0]):
1009
+ for shape in [(-1,), (-1, 1)]:
1010
+ if tensor_spec.shape == shape:
1011
+ return _enforce_tensor_spec(
1012
+ np.array(pd_series, dtype=tensor_spec.type).reshape(shape), tensor_spec
1013
+ )
1014
+ raise MlflowException(
1015
+ f"The input pandas dataframe column '{name}' contains scalar "
1016
+ "values, which requires the shape to be (-1,) or (-1, 1), but got tensor spec "
1017
+ f"shape of {tensor_spec.shape}.",
1018
+ error_code=INVALID_PARAMETER_VALUE,
1019
+ )
1020
+ elif isinstance(pd_series[0], list) and np.isscalar(pd_series[0][0]):
1021
+ # If the pandas column contains list values, the shape and type information
+ # is lost, so we do not enforce them; instead, we reshape the list of values
+ # to the required shape and cast them to the required type.
1026
+ reshape_err_msg = (
1027
+ f"The value in the Input DataFrame column '{name}' could not be converted to the "
1028
+ f"expected shape of: '{tensor_spec.shape}'. Ensure that each of the input list "
1029
+ "elements are of uniform length and that the data can be coerced to the tensor "
1030
+ f"type '{tensor_spec.type}'"
1031
+ )
1032
+ try:
1033
+ flattened_numpy_arr = np.vstack(pd_series.tolist())
1034
+ reshaped_numpy_arr = flattened_numpy_arr.reshape(tensor_spec.shape).astype(
1035
+ tensor_spec.type
1036
+ )
1037
+ except ValueError:
1038
+ raise MlflowException(reshape_err_msg, error_code=INVALID_PARAMETER_VALUE)
1039
+ if len(reshaped_numpy_arr) != len(pd_series):
1040
+ raise MlflowException(reshape_err_msg, error_code=INVALID_PARAMETER_VALUE)
1041
+ return reshaped_numpy_arr
1042
+ elif isinstance(pd_series[0], np.ndarray):
1043
+ reshape_err_msg = (
1044
+ f"The value in the Input DataFrame column '{name}' could not be converted to the "
1045
+ f"expected shape of: '{tensor_spec.shape}'. Ensure that each of the input numpy "
1046
+ "array elements are of uniform length and can be reshaped to above expected shape."
1047
+ )
1048
+ try:
1049
+ # Because numpy arrays carry precise type information, we don't convert the
+ # type here; the subsequent schema validation can then perform a strict type
+ # check on the numpy array column.
1052
+ reshaped_numpy_arr = np.vstack(pd_series.tolist()).reshape(tensor_spec.shape)
1053
+ except ValueError:
1054
+ raise MlflowException(reshape_err_msg, error_code=INVALID_PARAMETER_VALUE)
1055
+ if len(reshaped_numpy_arr) != len(pd_series):
1056
+ raise MlflowException(reshape_err_msg, error_code=INVALID_PARAMETER_VALUE)
1057
+ return reshaped_numpy_arr
1058
+ else:
1059
+ raise MlflowException(
1060
+ "Because the model signature requires tensor spec input, the input "
1061
+ "pandas dataframe values should be either scalar value, python list "
1062
+ "containing scalar values or numpy array containing scalar values, "
1063
+ "other types are not supported.",
1064
+ error_code=INVALID_PARAMETER_VALUE,
1065
+ )
1066
+
1067
+
1068
+ def _enforce_tensor_schema(pf_input: PyFuncInput, input_schema: Schema):
1069
+ """Enforce the input tensor(s) conforms to the model's tensor-based signature."""
1070
+
1071
+ def _is_sparse_matrix(x):
1072
+ if not HAS_SCIPY:
1073
+ # we can safely assume that it's not a sparse matrix if scipy is not installed
1074
+ return False
1075
+ return isinstance(x, (csr_matrix, csc_matrix))
1076
+
1077
+ if input_schema.has_input_names():
1078
+ if isinstance(pf_input, dict):
1079
+ new_pf_input = {}
1080
+ for col_name, tensor_spec in zip(input_schema.input_names(), input_schema.inputs):
1081
+ if not isinstance(pf_input[col_name], np.ndarray):
1082
+ raise MlflowException(
1083
+ "This model contains a tensor-based model signature with input names,"
1084
+ " which suggests a dictionary input mapping input name to a numpy"
1085
+ f" array, but a dict with value type {type(pf_input[col_name])} was found.",
1086
+ error_code=INVALID_PARAMETER_VALUE,
1087
+ )
1088
+ new_pf_input[col_name] = _enforce_tensor_spec(pf_input[col_name], tensor_spec)
1089
+ elif isinstance(pf_input, pd.DataFrame):
1090
+ new_pf_input = {}
1091
+ for col_name, tensor_spec in zip(input_schema.input_names(), input_schema.inputs):
1092
+ pd_series = pf_input[col_name]
1093
+ new_pf_input[col_name] = _reshape_and_cast_pandas_column_values(
1094
+ col_name, pd_series, tensor_spec
1095
+ )
1096
+ else:
1097
+ raise MlflowException(
1098
+ "This model contains a tensor-based model signature with input names, which"
1099
+ " suggests a dictionary input mapping input name to tensor, or a pandas"
1100
+ " DataFrame input containing columns mapping input name to flattened list value"
1101
+ f" from tensor, but an input of type {type(pf_input)} was found.",
1102
+ error_code=INVALID_PARAMETER_VALUE,
1103
+ )
1104
+ else:
1105
+ tensor_spec = input_schema.inputs[0]
1106
+ if isinstance(pf_input, pd.DataFrame):
1107
+ num_input_columns = len(pf_input.columns)
1108
+ if pf_input.empty:
1109
+ raise MlflowException("Input DataFrame is empty.")
1110
+ elif num_input_columns == 1:
1111
+ new_pf_input = _reshape_and_cast_pandas_column_values(
1112
+ None, pf_input[pf_input.columns[0]], tensor_spec
1113
+ )
1114
+ else:
1115
+ if tensor_spec.shape != (-1, num_input_columns):
1116
+ raise MlflowException(
1117
+ "This model contains a model signature with an unnamed input. Since the "
1118
+ "input data is a pandas DataFrame containing multiple columns, "
1119
+ "the input shape must be of the structure "
1120
+ "(-1, number_of_dataframe_columns). "
1121
+ f"Instead, the input DataFrame passed had {num_input_columns} columns and "
1122
+ f"an input shape of {tensor_spec.shape} with all values within the "
1123
+ "DataFrame of scalar type. Please adjust the passed in DataFrame to "
1124
+ "match the expected structure",
1125
+ error_code=INVALID_PARAMETER_VALUE,
1126
+ )
1127
+ new_pf_input = _enforce_tensor_spec(pf_input.to_numpy(), tensor_spec)
1128
+ elif isinstance(pf_input, np.ndarray) or _is_sparse_matrix(pf_input):
1129
+ new_pf_input = _enforce_tensor_spec(pf_input, tensor_spec)
1130
+ else:
1131
+ raise MlflowException(
1132
+ "This model contains a tensor-based model signature with no input names,"
1133
+ " which suggests a numpy array input or a pandas dataframe input with"
1134
+ f" proper column values, but an input of type {type(pf_input)} was found.",
1135
+ error_code=INVALID_PARAMETER_VALUE,
1136
+ )
1137
+ return new_pf_input
1138
+
1139
+
1140
+ def _enforce_schema(pf_input: PyFuncInput, input_schema: Schema, flavor: Optional[str] = None):
1141
+ """
1142
+ Enforces that the provided input matches the model's input schema.
1143
+
1144
+ For signatures with input names, we check there are no missing inputs and reorder the inputs to
1145
+ match the ordering declared in schema if necessary. Any extra columns are ignored.
1146
+
1147
+ For column-based signatures, we make sure the types of the input match the type specified in
1148
+ the schema or if it can be safely converted to match the input schema.
1149
+
1150
+ For Pyspark DataFrame inputs, MLflow casts a sample of the PySpark DataFrame into a Pandas
1151
+ DataFrame. MLflow will only enforce the schema on a subset of the data rows.
1152
+
1153
+ For tensor-based signatures, we make sure the shape and type of the input matches the shape
1154
+ and type specified in model's input schema.
1155
+ """
1156
+
1157
+ def _is_scalar(x):
1158
+ return np.isscalar(x) or x is None
1159
+
1160
+ original_pf_input = pf_input
1161
+ if isinstance(pf_input, pd.Series):
1162
+ pf_input = pd.DataFrame(pf_input)
1163
+ if not input_schema.is_tensor_spec():
1164
+ # convert single DataType to pandas DataFrame
1165
+ if np.isscalar(pf_input):
1166
+ pf_input = pd.DataFrame([pf_input])
1167
+ elif isinstance(pf_input, dict):
1168
+ # keys are column names
1169
+ if any(
1170
+ isinstance(col_spec.type, (Array, Object)) for col_spec in input_schema.inputs
1171
+ ) or all(
1172
+ _is_scalar(value)
1173
+ or (isinstance(value, list) and all(isinstance(item, str) for item in value))
1174
+ for value in pf_input.values()
1175
+ ):
1176
+ pf_input = pd.DataFrame([pf_input])
1177
+ else:
1178
+ try:
1179
+ # This check is specifically to handle the serving structural cast for
1180
+ # certain inputs for the transformers implementation. Due to the fact that
1181
+ # specific Pipeline types in transformers support passing input data
1182
+ # of the form Dict[str, str] in which the value is a scalar string, model
1183
+ # serving will cast this entry as a numpy array with shape () and size 1.
1184
+ # This is seen as a scalar input when attempting to create a Pandas
1185
+ # DataFrame from such a numpy structure and requires the array to be
1186
+ # encapsulated in a list in order to prevent a ValueError exception for
1187
+ # requiring an index if passing in all scalar values thrown by Pandas.
1188
+ if all(
1189
+ isinstance(value, np.ndarray)
1190
+ and value.dtype.type == np.str_
1191
+ and value.size == 1
1192
+ and value.shape == ()
1193
+ for value in pf_input.values()
1194
+ ):
1195
+ pf_input = pd.DataFrame([pf_input])
1196
+ elif any(
1197
+ isinstance(value, np.ndarray) and value.ndim > 1
1198
+ for value in pf_input.values()
1199
+ ):
1200
+ # Pandas DataFrames can't be constructed with embedded multi-dimensional
1201
+ # numpy arrays. Accordingly, we convert any multi-dimensional numpy
1202
+ # arrays to lists before constructing a DataFrame. This is safe because
1203
+ # ColSpec model signatures do not support array columns, so subsequent
1204
+ # validation logic will result in a clear "incompatible input types"
1205
+ # exception. This is preferable to a pandas DataFrame construction error.
1206
+ pf_input = pd.DataFrame(
1207
+ {
1208
+ key: (
1209
+ value.tolist()
1210
+ if (isinstance(value, np.ndarray) and value.ndim > 1)
1211
+ else value
1212
+ )
1213
+ for key, value in pf_input.items()
1214
+ }
1215
+ )
1216
+ else:
1217
+ pf_input = pd.DataFrame(pf_input)
1218
+ except Exception as e:
1219
+ raise MlflowException(
1220
+ "This model contains a column-based signature, which suggests a DataFrame"
1221
+ " input. There was an error casting the input data to a DataFrame:"
1222
+ f" {e}"
1223
+ )
1224
+ elif isinstance(pf_input, (list, np.ndarray, pd.Series)):
1225
+ pf_input = pd.DataFrame(pf_input)
1226
+ elif HAS_PYSPARK and isinstance(pf_input, SparkDataFrame):
1227
+ pf_input = pf_input.limit(10).toPandas()
1228
+ for field in original_pf_input.schema.fields:
1229
+ if isinstance(field.dataType, (StructType, ArrayType)):
1230
+ pf_input[field.name] = pf_input[field.name].apply(
1231
+ lambda row: convert_complex_types_pyspark_to_pandas(row, field.dataType)
1232
+ )
1233
+ if not isinstance(pf_input, pd.DataFrame):
1234
+ raise MlflowException(
1235
+ f"Expected input to be DataFrame. Found: {type(pf_input).__name__}"
1236
+ )
1237
+
1238
+ if input_schema.has_input_names():
1239
+ # make sure there are no missing columns
1240
+ input_names = input_schema.required_input_names()
1241
+ optional_names = input_schema.optional_input_names()
1242
+ expected_required_cols = set(input_names)
1243
+ actual_cols = set()
1244
+ optional_cols = set(optional_names)
1245
+ if len(expected_required_cols) == 1 and isinstance(pf_input, np.ndarray):
1246
+ # for schemas with a single column, match input with column
1247
+ pf_input = {input_names[0]: pf_input}
1248
+ actual_cols = expected_required_cols
1249
+ elif isinstance(pf_input, pd.DataFrame):
1250
+ actual_cols = set(pf_input.columns)
1251
+ elif isinstance(pf_input, dict):
1252
+ actual_cols = set(pf_input.keys())
1253
+ missing_cols = expected_required_cols - actual_cols
1254
+ extra_cols = actual_cols - expected_required_cols - optional_cols
1255
+ # Preserve order from the original columns, since missing/extra columns are likely to
1256
+ # be in the same order.
1257
+ missing_cols = [c for c in input_names if c in missing_cols]
1258
+ extra_cols = [c for c in actual_cols if c in extra_cols]
1259
+ if missing_cols:
1260
+ message = f"Model is missing inputs {missing_cols}."
1261
+ if extra_cols:
1262
+ message += f" Note that there were extra inputs: {extra_cols}"
1263
+ raise MlflowException(message)
1264
+ if extra_cols:
1265
+ _logger.warning(
1266
+ "Found extra inputs in the model input that are not defined in the model "
1267
+ f"signature: `{extra_cols}`. These inputs will be ignored."
1268
+ )
1269
+ elif not input_schema.is_tensor_spec():
1270
+ # The model signature does not specify column names => we can only verify column count.
1271
+ num_actual_columns = len(pf_input.columns)
1272
+ if num_actual_columns < len(input_schema.inputs):
1273
+ raise MlflowException(
1274
+ "Model inference is missing inputs. The model signature declares "
1275
+ "{} inputs but the provided value only has "
1276
+ "{} inputs. Note: the inputs were not named in the signature so we can "
1277
+ "only verify their count.".format(len(input_schema.inputs), num_actual_columns)
1278
+ )
1279
+ if input_schema.is_tensor_spec():
1280
+ return _enforce_tensor_schema(pf_input, input_schema)
1281
+ elif HAS_PYSPARK and isinstance(original_pf_input, SparkDataFrame):
1282
+ return _enforce_pyspark_dataframe_schema(
1283
+ original_pf_input, pf_input, input_schema, flavor=flavor
1284
+ )
1285
+ else:
1286
+ # pf_input must be a pandas Dataframe at this point
1287
+ return (
1288
+ _enforce_named_col_schema(pf_input, input_schema)
1289
+ if input_schema.has_input_names()
1290
+ else _enforce_unnamed_col_schema(pf_input, input_schema)
1291
+ )
1292
+
1293
+
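A hedged sketch of the dictionary path described above, assuming the module-private `_enforce_schema` is in scope and that `infer_signature` behaves as in upstream MLflow: the dict is coerced to a one-row DataFrame and any key not declared in the signature is ignored with a warning.

.. code-block:: python

    import pandas as pd
    from mlflow.models import infer_signature

    signature = infer_signature(pd.DataFrame({"text": ["hi"], "count": [1]}))

    enforced = _enforce_schema(
        {"text": "hello", "count": 2, "extra": "ignored"}, signature.inputs
    )
    print(enforced.columns.tolist())  # ['text', 'count']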
1294
+ def _enforce_pyspark_dataframe_schema(
1295
+ original_pf_input: SparkDataFrame,
1296
+ pf_input_as_pandas,
1297
+ input_schema: Schema,
1298
+ flavor: Optional[str] = None,
1299
+ ):
1300
+ """
1301
+ Enforce that the input PySpark DataFrame conforms to the model's input schema.
1302
+
1303
+ This function creates a new DataFrame that only includes the columns from the original
1304
+ DataFrame that are declared in the model's input schema. Any extra columns in the original
1305
+ DataFrame are dropped. Note that this function does not modify the original DataFrame.
1306
+
1307
+ Args:
1308
+ original_pf_input: Original input PySpark DataFrame.
1309
+ pf_input_as_pandas: Input DataFrame converted to pandas.
1310
+ input_schema: Expected schema of the input DataFrame.
1311
+ flavor: Optional model flavor. If specified, it is used to handle specific behaviors
1312
+ for different model flavors. Currently, only the '_FEATURE_STORE_FLAVOR' is
1313
+ handled specially.
1314
+
1315
+ Returns:
1316
+ New PySpark DataFrame that conforms to the model's input schema.
1317
+ """
1318
+ if not HAS_PYSPARK:
1319
+ raise MlflowException("PySpark is not installed. Cannot handle a PySpark DataFrame.")
1320
+ new_pf_input = original_pf_input.alias("pf_input_copy")
1321
+ if input_schema.has_input_names():
1322
+ _enforce_named_col_schema(pf_input_as_pandas, input_schema)
1323
+ input_names = input_schema.input_names()
1324
+
1325
+ else:
1326
+ _enforce_unnamed_col_schema(pf_input_as_pandas, input_schema)
1327
+ input_names = pf_input_as_pandas.columns[: len(input_schema.inputs)]
1328
+ columns_to_drop = []
1329
+ columns_not_dropped_for_feature_store_model = []
1330
+ for col, dtype in new_pf_input.dtypes:
1331
+ if col not in input_names:
1332
+ # to support backwards compatibility with feature store models
1333
+ if any(x in dtype for x in ["array", "map", "struct"]):
1334
+ if flavor == _FEATURE_STORE_FLAVOR:
1335
+ columns_not_dropped_for_feature_store_model.append(col)
1336
+ continue
1337
+ columns_to_drop.append(col)
1338
+ if columns_not_dropped_for_feature_store_model:
1339
+ _logger.warning(
1340
+ "The following columns are not in the model signature but "
1341
+ "are not dropped for feature store model: %s",
1342
+ ", ".join(columns_not_dropped_for_feature_store_model),
1343
+ )
1344
+ return new_pf_input.drop(*columns_to_drop)
1345
+
1346
+
1347
+ def _enforce_datatype(data: Any, dtype: DataType, required=True):
1348
+ if not required and _is_none_or_nan(data):
1349
+ return None
1350
+
1351
+ if not isinstance(dtype, DataType):
1352
+ raise MlflowException(f"Expected dtype to be DataType, got {type(dtype).__name__}")
1353
+ if not np.isscalar(data):
1354
+ raise MlflowException(f"Expected data to be scalar, got {type(data).__name__}")
1355
+ # Reuse logic in _enforce_mlflow_datatype for type conversion
1356
+ pd_series = pd.Series(data)
1357
+ try:
1358
+ pd_series = _enforce_mlflow_datatype("", pd_series, dtype)
1359
+ except MlflowException:
1360
+ raise MlflowException(
1361
+ f"Failed to enforce schema of data `{data}` with dtype `{dtype.name}`"
1362
+ )
1363
+ return pd_series[0]
1364
+
1365
+
1366
+ def _enforce_array(data: Any, arr: Array, required: bool = True):
1367
+ """
1368
+ Enforce data against an Array type.
1369
+ If the field is required, then the data must be provided.
1370
+ If Array's internal dtype is AnyType, then None and empty lists are also accepted.
1371
+ """
1372
+ if not required or isinstance(arr.dtype, AnyType):
1373
+ if data is None or (isinstance(data, (list, np.ndarray)) and len(data) == 0):
1374
+ return data
1375
+
1376
+ if not isinstance(data, (list, np.ndarray)):
1377
+ raise MlflowException(f"Expected data to be list or numpy array, got {type(data).__name__}")
1378
+
1379
+ if isinstance(arr.dtype, DataType):
1380
+ # TODO: this is still significantly slower than direct np.asarray dtype conversion
1381
+ # pd.Series conversion can be removed once we support direct validation on the numpy array
1382
+ data_enforced = (
1383
+ _enforce_mlflow_datatype("", pd.Series(data), arr.dtype).to_numpy(
1384
+ dtype=arr.dtype.to_numpy()
1385
+ )
1386
+ if len(data) > 0
1387
+ else data
1388
+ )
1389
+ else:
1390
+ data_enforced = [_enforce_type(x, arr.dtype, required=required) for x in data]
1391
+
1392
+ if isinstance(data, list) and isinstance(data_enforced, np.ndarray):
1393
+ data_enforced = data_enforced.tolist()
1394
+ elif isinstance(data, np.ndarray) and isinstance(data_enforced, list):
1395
+ data_enforced = np.array(data_enforced)
1396
+
1397
+ return data_enforced
1398
+
1399
+
1400
+ def _enforce_property(data: Any, property: Property):
1401
+ return _enforce_type(data, property.dtype, required=property.required)
1402
+
1403
+
1404
+ def _enforce_object(data: dict[str, Any], obj: Object, required: bool = True):
1405
+ if HAS_PYSPARK and isinstance(data, Row):
1406
+ data = None if len(data) == 0 else data.asDict(True)
1407
+ if not required and (data is None or data == {}):
1408
+ return data
1409
+ if not isinstance(data, dict):
1410
+ raise MlflowException(
1411
+ f"Failed to enforce schema of '{data}' with type '{obj}'. "
1412
+ f"Expected data to be dictionary, got {type(data).__name__}"
1413
+ )
1414
+ if not isinstance(obj, Object):
1415
+ raise MlflowException(
1416
+ f"Failed to enforce schema of '{data}' with type '{obj}'. "
1417
+ f"Expected obj to be Object, got {type(obj).__name__}"
1418
+ )
1419
+ properties = {prop.name: prop for prop in obj.properties}
1420
+ required_props = {k for k, prop in properties.items() if prop.required}
1421
+ missing_props = required_props - set(data.keys())
1422
+ if missing_props:
1423
+ raise MlflowException(f"Missing required properties: {missing_props}")
1424
+ if invalid_props := data.keys() - properties.keys():
1425
+ raise MlflowException(
1426
+ f"Invalid properties not defined in the schema found: {invalid_props}"
1427
+ )
1428
+ for k, v in data.items():
1429
+ try:
1430
+ data[k] = _enforce_property(v, properties[k])
1431
+ except MlflowException as e:
1432
+ raise MlflowException(
1433
+ f"Failed to enforce schema for key `{k}`. "
1434
+ f"Expected type {properties[k].to_dict()[k]['type']}, "
1435
+ f"received type {type(v).__name__}"
1436
+ ) from e
1437
+ return data
1438
+
1439
+
1440
+ def _enforce_map(data: Any, map_type: Map, required: bool = True):
1441
+ if (not required or isinstance(map_type.value_type, AnyType)) and (data is None or data == {}):
1442
+ return data
1443
+
1444
+ if not isinstance(data, dict):
1445
+ raise MlflowException(f"Expected data to be a dict, got {type(data).__name__}")
1446
+
1447
+ if not all(isinstance(k, str) for k in data):
1448
+ raise MlflowException("Expected all keys in the map type data are string type.")
1449
+
1450
+ return {k: _enforce_type(v, map_type.value_type, required=required) for k, v in data.items()}
1451
+
1452
+
1453
+ def _enforce_type(data: Any, data_type: Union[DataType, Array, Object, Map], required=True):
1454
+ if isinstance(data_type, DataType):
1455
+ return _enforce_datatype(data, data_type, required=required)
1456
+ if isinstance(data_type, Array):
1457
+ return _enforce_array(data, data_type, required=required)
1458
+ if isinstance(data_type, Object):
1459
+ return _enforce_object(data, data_type, required=required)
1460
+ if isinstance(data_type, Map):
1461
+ return _enforce_map(data, data_type, required=required)
1462
+ if isinstance(data_type, AnyType):
1463
+ return data
1464
+ raise MlflowException(f"Invalid data type: {data_type!r}")
1465
+
1466
+
1467
+ def validate_schema(data: PyFuncInput, expected_schema: Schema) -> None:
1468
+ """
1469
+ Validate that the input data has the expected schema.
1470
+
1471
+ Args:
1472
+ data: Input data to be validated. Supported types are:
1473
+
1474
+ - pandas.DataFrame
1475
+ - pandas.Series
1476
+ - numpy.ndarray
1477
+ - scipy.sparse.csc_matrix
1478
+ - scipy.sparse.csr_matrix
1479
+ - List[Any]
1480
+ - Dict[str, Any]
1481
+ - str
1482
+
1483
+ expected_schema: Expected Schema of the input data.
1484
+
1485
+ Raises:
1486
+ mlflow.exceptions.MlflowException: when the input data does not match the schema.
1487
+
1488
+ .. code-block:: python
1489
+ :caption: Example usage of validate_schema
1490
+
1491
+ import mlflow.models
1492
+
1493
+ # Suppose you've already got a model_uri
1494
+ model_info = mlflow.models.get_model_info(model_uri)
1495
+ # Get model signature directly
1496
+ model_signature = model_info.signature
1497
+ # validate schema
1498
+ mlflow.models.validate_schema(input_data, model_signature.inputs)
1499
+ """
1500
+
1501
+ _enforce_schema(data, expected_schema)
1502
+
1503
+
1504
+ def add_libraries_to_model(model_uri, run_id=None, registered_model_name=None):
1505
+ """
1506
+ Given a registered model_uri (e.g. models:/<model_name>/<model_version>), this utility
1507
+ re-logs the model along with all the required model libraries back to the Model Registry.
1508
+ The required model libraries are stored along with the model as model artifacts. In
1509
+ addition, supporting files to the model (e.g. conda.yaml, requirements.txt) are modified
1510
+ to use the added libraries.
1511
+
1512
+ By default, this utility creates a new model version under the same registered model specified
1513
+ by ``model_uri``. This behavior can be overridden by specifying the ``registered_model_name``
1514
+ argument.
1515
+
1516
+ Args:
1517
+ model_uri: A registered model uri in the Model Registry of the form
1518
+ models:/<model_name>/<model_version/stage/latest>
1519
+ run_id: The ID of the run to which the model with libraries is logged. If None, the model
1520
+ with libraries is logged to the source run corresponding to model version
1521
+ specified by ``model_uri``; if the model version does not have a source run, a
1522
+ new run created.
1523
+ registered_model_name: The new model version (model with its libraries) is
1524
+ registered under the inputted registered_model_name. If None, a
1525
+ new version is logged to the existing model in the Model Registry.
1526
+
1527
+ .. note::
1528
+ This utility only operates on a model that has been registered to the Model Registry.
1529
+
1530
+ .. note::
1531
+ The libraries are only compatible with the platform on which they are added. Cross platform
1532
+ libraries are not supported.
1533
+
1534
+ .. code-block:: python
1535
+ :caption: Example
1536
+
1537
+ # Create and log a model to the Model Registry
1538
+ import pandas as pd
1539
+ from sklearn import datasets
1540
+ from sklearn.ensemble import RandomForestClassifier
1541
+ import mlflow
1542
+ import mlflow.sklearn
1543
+ from mlflow.models import infer_signature
1544
+
1545
+ with mlflow.start_run():
1546
+ iris = datasets.load_iris()
1547
+ iris_train = pd.DataFrame(iris.data, columns=iris.feature_names)
1548
+ clf = RandomForestClassifier(max_depth=7, random_state=0)
1549
+ clf.fit(iris_train, iris.target)
1550
+ signature = infer_signature(iris_train, clf.predict(iris_train))
1551
+ mlflow.sklearn.log_model(
1552
+ clf,
1553
+ name="iris_rf",
1554
+ signature=signature,
1555
+ registered_model_name="model-with-libs",
1556
+ )
1557
+
1558
+ # model uri for the above model
1559
+ model_uri = "models:/model-with-libs/1"
1560
+
1561
+ # Import utility
1562
+ from mlflow.models.utils import add_libraries_to_model
1563
+
1564
+ # Log libraries to the original run of the model
1565
+ add_libraries_to_model(model_uri)
1566
+
1567
+ # Log libraries to some run_id
1568
+ existing_run_id = "21df94e6bdef4631a9d9cb56f211767f"
1569
+ add_libraries_to_model(model_uri, run_id=existing_run_id)
1570
+
1571
+ # Log libraries to a new run
1572
+ with mlflow.start_run():
1573
+ add_libraries_to_model(model_uri)
1574
+
1575
+ # Log libraries to a new registered model named 'new-model'
1576
+ with mlflow.start_run():
1577
+ add_libraries_to_model(model_uri, registered_model_name="new-model")
1578
+ """
1579
+
1580
+ import mlflow
1581
+ from mlflow.models.wheeled_model import WheeledModel
1582
+
1583
+ if mlflow.active_run() is None:
1584
+ if run_id is None:
1585
+ run_id = get_model_version_from_model_uri(model_uri).run_id
1586
+ with mlflow.start_run(run_id):
1587
+ return WheeledModel.log_model(model_uri, registered_model_name)
1588
+ else:
1589
+ return WheeledModel.log_model(model_uri, registered_model_name)
1590
+
1591
+
1592
+ def get_model_version_from_model_uri(model_uri):
1593
+ """
1594
+ Helper function to fetch a model version from a model uri of the form
1595
+ models:/<model_name>/<model_version/stage/latest>.
1596
+ """
1597
+ import mlflow
1598
+ from mlflow import MlflowClient
1599
+
1600
+ databricks_profile_uri = (
1601
+ get_databricks_profile_uri_from_artifact_uri(model_uri) or mlflow.get_registry_uri()
1602
+ )
1603
+ client = MlflowClient(registry_uri=databricks_profile_uri)
1604
+ (name, version) = get_model_name_and_version(client, model_uri)
1605
+ return client.get_model_version(name, version)
1606
+
1607
+
1608
+ def _enforce_params_schema(params: Optional[dict[str, Any]], schema: Optional[ParamSchema]):
1609
+ if schema is None:
1610
+ if params in [None, {}]:
1611
+ return params
1612
+ params_info = (
1613
+ f"Ignoring provided params: {list(params.keys())}"
1614
+ if isinstance(params, dict)
1615
+ else "Ignoring invalid params (not a dictionary)."
1616
+ )
1617
+ _logger.warning(
1618
+ "`params` can only be specified at inference time if the model signature "
1619
+ f"defines a params schema. This model does not define a params schema. {params_info}",
1620
+ )
1621
+ return {}
1622
+ params = {} if params is None else params
1623
+ if not isinstance(params, dict):
1624
+ raise MlflowException.invalid_parameter_value(
1625
+ f"Parameters must be a dictionary. Got type '{type(params).__name__}'.",
1626
+ )
1627
+ if not isinstance(schema, ParamSchema):
1628
+ raise MlflowException.invalid_parameter_value(
1629
+ "Parameters schema must be an instance of ParamSchema. "
1630
+ f"Got type '{type(schema).__name__}'.",
1631
+ )
1632
+ if any(not isinstance(k, str) for k in params.keys()):
1633
+ _logger.warning(
1634
+ "Keys in parameters should be of type `str`, but received non-string keys."
1635
+ "Converting all keys to string..."
1636
+ )
1637
+ params = {str(k): v for k, v in params.items()}
1638
+
1639
+ allowed_keys = {param.name for param in schema.params}
1640
+ ignored_keys = set(params) - allowed_keys
1641
+ if ignored_keys:
1642
+ _logger.warning(
1643
+ f"Unrecognized params {list(ignored_keys)} are ignored for inference. "
1644
+ f"Supported params are: {allowed_keys}. "
1645
+ "To enable them, please add corresponding schema in ModelSignature."
1646
+ )
1647
+
1648
+ params = {k: params[k] for k in params if k in allowed_keys}
1649
+
1650
+ invalid_params = set()
1651
+ for param_spec in schema.params:
1652
+ if param_spec.name in params:
1653
+ try:
1654
+ params[param_spec.name] = ParamSpec.validate_param_spec(
1655
+ params[param_spec.name], param_spec
1656
+ )
1657
+ except MlflowException as e:
1658
+ invalid_params.add((param_spec.name, e.message))
1659
+ else:
1660
+ params[param_spec.name] = param_spec.default
1661
+
1662
+ if invalid_params:
1663
+ raise MlflowException.invalid_parameter_value(
1664
+ f"Invalid parameters found: {invalid_params!r}",
1665
+ )
1666
+
1667
+ return params
1668
+
1669
+
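A hedged sketch of the behavior above, assuming `ParamSchema` and `ParamSpec` are importable from `mlflow.types.schema` as in upstream MLflow: unknown params are dropped with a warning and missing params fall back to their declared defaults.

.. code-block:: python

    from mlflow.types.schema import ParamSchema, ParamSpec

    schema = ParamSchema(
        [ParamSpec("max_tokens", "long", 128), ParamSpec("stop", "string", "\n")]
    )
    resolved = _enforce_params_schema({"max_tokens": 64, "top_k": 5}, schema)
    assert resolved["max_tokens"] == 64   # validated against the schema
    assert resolved["stop"] == "\n"       # filled in from the declared default
    assert "top_k" not in resolved        # ignored: not declared in the schema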
1670
+ def convert_complex_types_pyspark_to_pandas(value, dataType):
1671
+ # This function is needed because the default `asDict` function in PySpark
1672
+ # converts the data to Python types, which is not compatible with the schema enforcement.
1673
+ type_mapping = {
1674
+ IntegerType: lambda v: np.int32(v),
1675
+ ShortType: lambda v: np.int16(v),
1676
+ FloatType: lambda v: np.float32(v),
1677
+ DateType: lambda v: v.strftime("%Y-%m-%d"),
1678
+ TimestampType: lambda v: v.strftime("%Y-%m-%d %H:%M:%S.%f"),
1679
+ BinaryType: lambda v: np.bytes_(v),
1680
+ }
1681
+ if value is None:
1682
+ return None
1683
+ if isinstance(dataType, StructType):
1684
+ return {
1685
+ field.name: convert_complex_types_pyspark_to_pandas(value[field.name], field.dataType)
1686
+ for field in dataType.fields
1687
+ }
1688
+ elif isinstance(dataType, ArrayType):
1689
+ return [
1690
+ convert_complex_types_pyspark_to_pandas(elem, dataType.elementType) for elem in value
1691
+ ]
1692
+ converter = type_mapping.get(type(dataType))
1693
+ if converter:
1694
+ return converter(value)
1695
+ return value
1696
+
1697
+
1698
+ def _is_in_comment(line, start):
1699
+ """
1700
+ Check if the code at the index "start" of the line is in a comment.
1701
+
1702
+ Limitations: This function does not handle multi-line comments, and the # symbol could be in a
1703
+ string, or otherwise not indicate a comment.
1704
+ """
1705
+ return "#" in line[:start]
1706
+
1707
+
1708
+ def _is_in_string_only(line, search_string):
1709
+ """
1710
+ Check whether the search_string appears only inside quoted strings on the line.
1711
+
1712
+ Limitations: This function does not handle multi-line strings.
1713
+ """
1714
+ # Regex for matching double quotes and everything inside
1715
+ double_quotes_regex = r"\"(\\.|[^\"])*\""
1716
+
1717
+ # Regex for matching single quotes and everything inside
1718
+ single_quotes_regex = r"\'(\\.|[^\'])*\'"
1719
+
1720
+ # Regex for matching search_string exactly
1721
+ search_string_regex = rf"({re.escape(search_string)})"
1722
+
1723
+ # Concatenate the patterns using the OR operator '|'
1724
+ # This will match left to right - on quotes first, search_string last
1725
+ pattern = double_quotes_regex + r"|" + single_quotes_regex + r"|" + search_string_regex
1726
+
1727
+ # Iterate through all matches in the line
1728
+ for match in re.finditer(pattern, line):
1729
+ # If the regex matched on the search_string, we know that it did not match in quotes since
1730
+ # that is the order. So we know that the search_string exists outside of quotes
1731
+ # (at least once).
1732
+ if match.group() == search_string:
1733
+ return False
1734
+ return True
1735
+
1736
+
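Illustrative calls against single source lines, staying within the limitations the docstrings above describe (no multi-line constructs).

.. code-block:: python

    code_line = 'x = 1  # dbutils.fs.ls("/")'
    string_line = 'msg = "call dbutils later"'

    assert _is_in_comment(code_line, code_line.index("dbutils"))  # occurrence is after the '#'
    assert _is_in_string_only(string_line, "dbutils")             # only appears inside quotes
    assert not _is_in_string_only(code_line, "dbutils")           # also appears outside quotes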
1737
+ def _validate_model_code_from_notebook(code):
1738
+ """
1739
+ Validate there isn't any code that would work in a notebook but not as an exported Python file.
1740
+ For now, this checks for dbutils and magic commands.
1741
+ """
1742
+
1743
+ output_code_list = []
1744
+ for line in code.splitlines():
1745
+ for match in re.finditer(r"\bdbutils\b", line):
1746
+ start = match.start()
1747
+ if not _is_in_comment(line, start) and not _is_in_string_only(line, "dbutils"):
1748
+ _logger.warning(
1749
+ "The model file uses 'dbutils' commands which are not supported. To ensure "
1750
+ "your code functions correctly, make sure that it does not rely on these "
1751
+ "dbutils commands for correctness."
1752
+ )
1753
+ # Prefix any line containing MAGIC commands with a comment. When there is better support
1754
+ # for the Databricks workspace export API, we can get rid of this.
1755
+ if line.startswith("%"):
1756
+ output_code_list.append("# MAGIC " + line)
1757
+ else:
1758
+ output_code_list.append(line)
1759
+ output_code = "\n".join(output_code_list)
1760
+
1761
+ magic_regex = r"^# MAGIC %((?!pip)\S+).*"
1762
+ if re.search(magic_regex, output_code, re.MULTILINE):
1763
+ _logger.warning(
1764
+ "The model file uses magic commands which have been commented out. To ensure your code "
1765
+ "functions correctly, make sure that it does not rely on these magic commands for "
1766
+ "correctness."
1767
+ )
1768
+
1769
+ return output_code.encode("utf-8")
1770
+
1771
+
1772
+ def _convert_llm_ndarray_to_list(data):
1773
+ """
1774
+ Convert numpy array in the input data to list, because numpy array is not json serializable.
1775
+ """
1776
+ if isinstance(data, np.ndarray):
1777
+ return data.tolist()
1778
+ if isinstance(data, list):
1779
+ return [_convert_llm_ndarray_to_list(d) for d in data]
1780
+ if isinstance(data, dict):
1781
+ return {k: _convert_llm_ndarray_to_list(v) for k, v in data.items()}
1782
+ # scalar values are also converted to numpy types, but they are
1783
+ # not accepted by the model
1784
+ if np.isscalar(data) and isinstance(data, np.generic):
1785
+ return data.item()
1786
+ return data
1787
+
1788
+
1789
+ def _convert_llm_input_data(data: Any) -> Union[list[Any], dict[str, Any]]:
1790
+ """
1791
+ Convert input data to a format that can be passed to the model with GenAI flavors such as
1792
+ LangChain and LlamaIndex.
1793
+
1794
+ Args:
1795
+ data: Input data to be converted. We assume it is a single request payload, but it can be
1796
+ in any format such as a single scalar value, a dictionary, list (with one element),
1797
+ Pandas DataFrame, etc.
1798
+ """
1799
+ # This handles pyfunc / spark_udf inputs with model signature. Schema enforcement convert
1800
+ # the input data to pandas DataFrame, so we convert it back.
1801
+ if isinstance(data, pd.DataFrame):
1802
+ # if the data only contains a single key as 0, we assume the input
1803
+ # is either a string or list of strings
1804
+ if list(data.columns) == [0]:
1805
+ data = data.to_dict("list")[0]
1806
+ else:
1807
+ data = data.to_dict(orient="records")
1808
+
1809
+ return _convert_llm_ndarray_to_list(data)
1810
+
1811
+
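A hedged sketch of the DataFrame unwrapping described above: a single column named 0 becomes a plain list, while named columns become a list of record dicts.

.. code-block:: python

    import pandas as pd

    single_col = pd.DataFrame({0: ["What is MLflow?", "What is a run?"]})
    assert _convert_llm_input_data(single_col) == ["What is MLflow?", "What is a run?"]

    records = pd.DataFrame({"role": ["user"], "content": ["hi"]})
    assert _convert_llm_input_data(records) == [{"role": "user", "content": "hi"}]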
1812
+ def _databricks_path_exists(path: Path) -> bool:
1813
+ """
1814
+ Check if a path exists in Databricks workspace.
1815
+ """
1816
+ if not is_in_databricks_runtime():
1817
+ return False
1818
+
1819
+ from databricks.sdk import WorkspaceClient
1820
+ from databricks.sdk.errors import ResourceDoesNotExist
1821
+
1822
+ client = WorkspaceClient()
1823
+ try:
1824
+ client.workspace.get_status(str(path))
1825
+ return True
1826
+ except ResourceDoesNotExist:
1827
+ return False
1828
+
1829
+
1830
+ def _validate_and_get_model_code_path(model_code_path: str, temp_dir: str) -> str:
1831
+ """
1832
+ Validate model code path exists. When failing to open the model file on Databricks,
1833
+ creates a temp file in temp_dir and validate its contents if it's a notebook.
1834
+
1835
+ Returns either `model_code_path` or a temp file path with the contents of the notebook.
1836
+ """
1837
+
1838
+ # If the path is not an absolute path then convert it
1839
+ model_code_path = Path(model_code_path).resolve()
1840
+
1841
+ if not (model_code_path.exists() or _databricks_path_exists(model_code_path)):
1842
+ additional_message = (
1843
+ f" Perhaps you meant '{model_code_path}.py'?" if not model_code_path.suffix else ""
1844
+ )
1845
+
1846
+ raise MlflowException.invalid_parameter_value(
1847
+ f"The provided model path '{model_code_path}' does not exist. "
1848
+ f"Ensure the file path is valid and try again.{additional_message}"
1849
+ )
1850
+
1851
+ try:
1852
+ # If `model_code_path` points to a notebook on Databricks, this line throws either
1853
+ # a `FileNotFoundError` or an `OSError`. In this case, try to export the notebook as
1854
+ # a Python file.
1855
+ with open(model_code_path):
1856
+ pass
1857
+
1858
+ return str(model_code_path)
1859
+ except Exception:
1860
+ pass
1861
+
1862
+ try:
1863
+ from databricks.sdk import WorkspaceClient
1864
+ from databricks.sdk.service.workspace import ExportFormat
1865
+
1866
+ w = WorkspaceClient()
1867
+ response = w.workspace.export(path=model_code_path, format=ExportFormat.SOURCE)
1868
+ decoded_content = base64.b64decode(response.content)
1869
+ except Exception:
1870
+ raise MlflowException.invalid_parameter_value(
1871
+ f"The provided model path '{model_code_path}' is not a valid Python file path or a "
1872
+ "Databricks Notebook file path containing the code for defining the chain "
1873
+ "instance. Ensure the file path is valid and try again."
1874
+ )
1875
+
1876
+ _validate_model_code_from_notebook(decoded_content.decode("utf-8"))
1877
+ path = os.path.join(temp_dir, "model.py")
1878
+ with open(path, "wb") as f:
1879
+ f.write(decoded_content)
1880
+ return path
1881
+
1882
+
1883
+ @contextmanager
1884
+ def _config_context(config: Optional[Union[str, dict[str, Any]]] = None):
1885
+ # If config is None, set it to "" so that when the model is loaded the
+ # config path is "", allowing ModelConfig to correctly check whether a
+ # config was provided
1888
+ if config is None:
1889
+ config = ""
1890
+
1891
+ _set_model_config(config)
1892
+ try:
1893
+ yield
1894
+ finally:
1895
+ _set_model_config(None)
1896
+
1897
+
1898
+ class MockDbutils:
1899
+ def __init__(self, real_dbutils=None):
1900
+ self.real_dbutils = real_dbutils
1901
+
1902
+ def __getattr__(self, name):
1903
+ try:
1904
+ if self.real_dbutils:
1905
+ return getattr(self.real_dbutils, name)
1906
+ except AttributeError:
1907
+ pass
1908
+ return MockDbutils()
1909
+
1910
+ def __call__(self, *args, **kwargs):
1911
+ pass
1912
+
1913
+
1914
+ @contextmanager
1915
+ def _mock_dbutils(globals_dict):
1916
+ module_name = "dbutils"
1917
+ original_module = sys.modules.get(module_name)
1918
+ sys.modules[module_name] = MockDbutils(original_module)
1919
+
1920
+ # Inject module directly into the global namespace in case it is referenced without an import
1921
+ original_global = globals_dict.get(module_name)
1922
+ globals_dict[module_name] = MockDbutils(original_module)
1923
+
1924
+ try:
1925
+ yield
1926
+ finally:
1927
+ if original_module is not None:
1928
+ sys.modules[module_name] = original_module
1929
+ else:
1930
+ del sys.modules[module_name]
1931
+
1932
+ if original_global is not None:
1933
+ globals_dict[module_name] = original_global
1934
+ else:
1935
+ del globals_dict[module_name]
1936
+
1937
+
1938
+ # Python's module caching mechanism prevents the re-importation of previously loaded modules by
1939
+ # default. Once a module is imported, it's added to `sys.modules`, and subsequent import attempts
1940
+ # retrieve the cached module rather than re-importing it.
1941
+ # Here, we want to import the `code path` module multiple times during a single runtime session.
1942
+ # This function addresses this by dynamically importing the `code path` module under a unique,
1943
+ # dynamically generated module name. This bypasses the caching mechanism, as each import is
1944
+ # considered a separate module by the Python interpreter.
1945
+ def _load_model_code_path(code_path: str, model_config: Optional[Union[str, dict[str, Any]]]):
1946
+ with _config_context(model_config):
1947
+ try:
1948
+ new_module_name = f"code_model_{uuid.uuid4().hex}"
1949
+ spec = importlib.util.spec_from_file_location(new_module_name, code_path)
1950
+ module = importlib.util.module_from_spec(spec)
1951
+ sys.modules[new_module_name] = module
1952
+ # Since dbutils only works in a Databricks environment, we need to mock it
1953
+ with _mock_dbutils(module.__dict__):
1954
+ spec.loader.exec_module(module)
1955
+ except ImportError as e:
1956
+ raise MlflowException(
1957
+ f"Failed to import code model from {code_path}. Error: {e!s}"
1958
+ ) from e
1959
+ except Exception as e:
1960
+ raise MlflowException(
1961
+ f"Failed to run user code from {code_path}. "
1962
+ f"Error: {e!s}. "
1963
+ "Review the stack trace for more information."
1964
+ ) from e
1965
+
1966
+ if mlflow.models.model.__mlflow_model__ is None:
1967
+ raise MlflowException(
1968
+ "If the model is logged as code, ensure the model is set using "
1969
+ "mlflow.models.set_model() within the code file code file."
1970
+ )
1971
+ return mlflow.models.model.__mlflow_model__
1972
+
1973
+
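A hedged sketch of the models-from-code contract this loader enforces: the referenced file (here a hypothetical `my_model.py`) must call `mlflow.models.set_model(...)` before it finishes executing, otherwise the loader raises.

.. code-block:: python

    # --- my_model.py (hypothetical model-as-code file) ---
    # import mlflow
    #
    # class Echo(mlflow.pyfunc.PythonModel):
    #     def predict(self, context, model_input, params=None):
    #         return model_input
    #
    # mlflow.models.set_model(Echo())
    # ------------------------------------------------------

    model = _load_model_code_path("my_model.py", model_config={"temperature": 0.0})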
1974
+ def _flatten_nested_params(
1975
+ d: dict[str, Any], parent_key: str = "", sep: str = "/"
1976
+ ) -> dict[str, str]:
1977
+ items: dict[str, Any] = {}
1978
+ for k, v in d.items():
1979
+ new_key = f"{parent_key}{sep}{k}" if parent_key else k
1980
+ if isinstance(v, dict):
1981
+ items.update(_flatten_nested_params(v, new_key, sep=sep))
1982
+ else:
1983
+ items[new_key] = v
1984
+ return items
1985
+
1986
+
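A small illustration of the flattening behavior: nested dictionaries are collapsed into slash-joined keys.

.. code-block:: python

    nested = {"optimizer": {"name": "adam", "lr": 0.001}, "epochs": 10}
    assert _flatten_nested_params(nested) == {
        "optimizer/name": "adam",
        "optimizer/lr": 0.001,
        "epochs": 10,
    }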
1987
+ # NB: this function should always be kept in sync with the serving
1988
+ # process in scoring_server invocations.
1989
+ @experimental(version="2.16.0")
1990
+ def validate_serving_input(model_uri: str, serving_input: Union[str, dict[str, Any]]):
1991
+ """
1992
+ Helper function to validate the model can be served and provided input is valid
1993
+ prior to serving the model.
1994
+
1995
+ Args:
1996
+ model_uri: URI of the model to be served.
1997
+ serving_input: Input data to be validated. Should be a dictionary or a JSON string.
1998
+
1999
+ Returns:
2000
+ The prediction result from the model.
2001
+ """
2002
+ from mlflow.pyfunc.scoring_server import _parse_json_data
2003
+ from mlflow.pyfunc.utils.environment import _simulate_serving_environment
2004
+
2005
+ # sklearn model might not have python_function flavor if it
2006
+ # doesn't define a predict function. In such a case the model
2007
+ # cannot be served anyway
2008
+
2009
+ output_dir = None if get_local_path_or_none(model_uri) else create_tmp_dir()
2010
+
2011
+ try:
2012
+ pyfunc_model = mlflow.pyfunc.load_model(model_uri, dst_path=output_dir)
2013
+ parsed_input = _parse_json_data(
2014
+ serving_input,
2015
+ pyfunc_model.metadata,
2016
+ pyfunc_model.metadata.get_input_schema(),
2017
+ )
2018
+ with _simulate_serving_environment():
2019
+ return pyfunc_model.predict(parsed_input.data, params=parsed_input.params)
2020
+ finally:
2021
+ if output_dir and os.path.exists(output_dir):
2022
+ shutil.rmtree(output_dir)
2023
+
2024
+
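A hedged end-to-end sketch combining this helper with `load_serving_example` defined earlier in this module; the `models:/serving-demo/1` URI is a placeholder and the public import paths mirror upstream MLflow.

.. code-block:: python

    from mlflow.models import validate_serving_input
    from mlflow.models.utils import load_serving_example

    model_uri = "models:/serving-demo/1"  # placeholder
    serving_payload = load_serving_example(model_uri)

    # Replays the saved serving payload through the pyfunc model much like the
    # scoring server would, surfacing schema or dependency problems before deploying.
    prediction = validate_serving_input(model_uri, serving_payload)
    print(prediction)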
2025
+ def get_external_mlflow_model_spec(logged_model: LoggedModel) -> Model:
2026
+ """
2027
+ Create the MLflow Model specification for a given logged model whose artifacts
2028
+ (code, weights, etc.) are stored externally outside of MLflow.
2029
+
2030
+ Args:
2031
+ logged_model: The external logged model for which to create an MLflow Model specification.
2032
+
2033
+ Returns:
2034
+ Model: MLflow Model specification for the given logged model with external artifacts.
2035
+ """
2036
+ from mlflow.models.signature import infer_signature
2037
+
2038
+ return Model(
2039
+ artifact_path=logged_model.artifact_location,
2040
+ model_uuid=logged_model.model_id,
2041
+ model_id=logged_model.model_id,
2042
+ run_id=logged_model.source_run_id,
2043
+ # Include a dummy signature so that the model can be registered to the Databricks Unity
2044
+ # Catalog Model Registry.
2045
+ # TODO: Remove this once the Databricks Unity Catalog Model Registry supports registration
2046
+ # of models without signatures
2047
+ signature=infer_signature(model_input=True, model_output=True),
2048
+ metadata={
2049
+ # Add metadata to the logged model indicating that its artifacts are stored externally.
2050
+ # This helps downstream consumers of the model, such as the Model Registry, easily
2051
+ # and consistently identify that the model's artifacts are external
2052
+ MLFLOW_MODEL_IS_EXTERNAL: True,
2053
+ },
2054
+ )