genesis-flow 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (645):
  1. genesis_flow-1.0.0.dist-info/METADATA +822 -0
  2. genesis_flow-1.0.0.dist-info/RECORD +645 -0
  3. genesis_flow-1.0.0.dist-info/WHEEL +5 -0
  4. genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
  5. genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
  6. genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
  7. mlflow/__init__.py +367 -0
  8. mlflow/__main__.py +3 -0
  9. mlflow/ag2/__init__.py +56 -0
  10. mlflow/ag2/ag2_logger.py +294 -0
  11. mlflow/anthropic/__init__.py +40 -0
  12. mlflow/anthropic/autolog.py +129 -0
  13. mlflow/anthropic/chat.py +144 -0
  14. mlflow/artifacts/__init__.py +268 -0
  15. mlflow/autogen/__init__.py +144 -0
  16. mlflow/autogen/chat.py +142 -0
  17. mlflow/azure/__init__.py +26 -0
  18. mlflow/azure/auth_handler.py +257 -0
  19. mlflow/azure/client.py +319 -0
  20. mlflow/azure/config.py +120 -0
  21. mlflow/azure/connection_factory.py +340 -0
  22. mlflow/azure/exceptions.py +27 -0
  23. mlflow/azure/stores.py +327 -0
  24. mlflow/azure/utils.py +183 -0
  25. mlflow/bedrock/__init__.py +45 -0
  26. mlflow/bedrock/_autolog.py +202 -0
  27. mlflow/bedrock/chat.py +122 -0
  28. mlflow/bedrock/stream.py +160 -0
  29. mlflow/bedrock/utils.py +43 -0
  30. mlflow/cli.py +707 -0
  31. mlflow/client.py +12 -0
  32. mlflow/config/__init__.py +56 -0
  33. mlflow/crewai/__init__.py +79 -0
  34. mlflow/crewai/autolog.py +253 -0
  35. mlflow/crewai/chat.py +29 -0
  36. mlflow/data/__init__.py +75 -0
  37. mlflow/data/artifact_dataset_sources.py +170 -0
  38. mlflow/data/code_dataset_source.py +40 -0
  39. mlflow/data/dataset.py +123 -0
  40. mlflow/data/dataset_registry.py +168 -0
  41. mlflow/data/dataset_source.py +110 -0
  42. mlflow/data/dataset_source_registry.py +219 -0
  43. mlflow/data/delta_dataset_source.py +167 -0
  44. mlflow/data/digest_utils.py +108 -0
  45. mlflow/data/evaluation_dataset.py +562 -0
  46. mlflow/data/filesystem_dataset_source.py +81 -0
  47. mlflow/data/http_dataset_source.py +145 -0
  48. mlflow/data/huggingface_dataset.py +258 -0
  49. mlflow/data/huggingface_dataset_source.py +118 -0
  50. mlflow/data/meta_dataset.py +104 -0
  51. mlflow/data/numpy_dataset.py +223 -0
  52. mlflow/data/pandas_dataset.py +231 -0
  53. mlflow/data/polars_dataset.py +352 -0
  54. mlflow/data/pyfunc_dataset_mixin.py +31 -0
  55. mlflow/data/schema.py +76 -0
  56. mlflow/data/sources.py +1 -0
  57. mlflow/data/spark_dataset.py +406 -0
  58. mlflow/data/spark_dataset_source.py +74 -0
  59. mlflow/data/spark_delta_utils.py +118 -0
  60. mlflow/data/tensorflow_dataset.py +350 -0
  61. mlflow/data/uc_volume_dataset_source.py +81 -0
  62. mlflow/db.py +27 -0
  63. mlflow/dspy/__init__.py +17 -0
  64. mlflow/dspy/autolog.py +197 -0
  65. mlflow/dspy/callback.py +398 -0
  66. mlflow/dspy/constant.py +1 -0
  67. mlflow/dspy/load.py +93 -0
  68. mlflow/dspy/save.py +393 -0
  69. mlflow/dspy/util.py +109 -0
  70. mlflow/dspy/wrapper.py +226 -0
  71. mlflow/entities/__init__.py +104 -0
  72. mlflow/entities/_mlflow_object.py +52 -0
  73. mlflow/entities/assessment.py +545 -0
  74. mlflow/entities/assessment_error.py +80 -0
  75. mlflow/entities/assessment_source.py +141 -0
  76. mlflow/entities/dataset.py +92 -0
  77. mlflow/entities/dataset_input.py +51 -0
  78. mlflow/entities/dataset_summary.py +62 -0
  79. mlflow/entities/document.py +48 -0
  80. mlflow/entities/experiment.py +109 -0
  81. mlflow/entities/experiment_tag.py +35 -0
  82. mlflow/entities/file_info.py +45 -0
  83. mlflow/entities/input_tag.py +35 -0
  84. mlflow/entities/lifecycle_stage.py +35 -0
  85. mlflow/entities/logged_model.py +228 -0
  86. mlflow/entities/logged_model_input.py +26 -0
  87. mlflow/entities/logged_model_output.py +32 -0
  88. mlflow/entities/logged_model_parameter.py +46 -0
  89. mlflow/entities/logged_model_status.py +74 -0
  90. mlflow/entities/logged_model_tag.py +33 -0
  91. mlflow/entities/metric.py +200 -0
  92. mlflow/entities/model_registry/__init__.py +29 -0
  93. mlflow/entities/model_registry/_model_registry_entity.py +13 -0
  94. mlflow/entities/model_registry/model_version.py +243 -0
  95. mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
  96. mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
  97. mlflow/entities/model_registry/model_version_search.py +25 -0
  98. mlflow/entities/model_registry/model_version_stages.py +25 -0
  99. mlflow/entities/model_registry/model_version_status.py +35 -0
  100. mlflow/entities/model_registry/model_version_tag.py +35 -0
  101. mlflow/entities/model_registry/prompt.py +73 -0
  102. mlflow/entities/model_registry/prompt_version.py +244 -0
  103. mlflow/entities/model_registry/registered_model.py +175 -0
  104. mlflow/entities/model_registry/registered_model_alias.py +35 -0
  105. mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
  106. mlflow/entities/model_registry/registered_model_search.py +25 -0
  107. mlflow/entities/model_registry/registered_model_tag.py +35 -0
  108. mlflow/entities/multipart_upload.py +74 -0
  109. mlflow/entities/param.py +49 -0
  110. mlflow/entities/run.py +97 -0
  111. mlflow/entities/run_data.py +84 -0
  112. mlflow/entities/run_info.py +188 -0
  113. mlflow/entities/run_inputs.py +59 -0
  114. mlflow/entities/run_outputs.py +43 -0
  115. mlflow/entities/run_status.py +41 -0
  116. mlflow/entities/run_tag.py +36 -0
  117. mlflow/entities/source_type.py +31 -0
  118. mlflow/entities/span.py +774 -0
  119. mlflow/entities/span_event.py +96 -0
  120. mlflow/entities/span_status.py +102 -0
  121. mlflow/entities/trace.py +317 -0
  122. mlflow/entities/trace_data.py +71 -0
  123. mlflow/entities/trace_info.py +220 -0
  124. mlflow/entities/trace_info_v2.py +162 -0
  125. mlflow/entities/trace_location.py +173 -0
  126. mlflow/entities/trace_state.py +39 -0
  127. mlflow/entities/trace_status.py +68 -0
  128. mlflow/entities/view_type.py +51 -0
  129. mlflow/environment_variables.py +866 -0
  130. mlflow/evaluation/__init__.py +16 -0
  131. mlflow/evaluation/assessment.py +369 -0
  132. mlflow/evaluation/evaluation.py +411 -0
  133. mlflow/evaluation/evaluation_tag.py +61 -0
  134. mlflow/evaluation/fluent.py +48 -0
  135. mlflow/evaluation/utils.py +201 -0
  136. mlflow/exceptions.py +213 -0
  137. mlflow/experiments.py +140 -0
  138. mlflow/gemini/__init__.py +81 -0
  139. mlflow/gemini/autolog.py +186 -0
  140. mlflow/gemini/chat.py +261 -0
  141. mlflow/genai/__init__.py +71 -0
  142. mlflow/genai/datasets/__init__.py +67 -0
  143. mlflow/genai/datasets/evaluation_dataset.py +131 -0
  144. mlflow/genai/evaluation/__init__.py +3 -0
  145. mlflow/genai/evaluation/base.py +411 -0
  146. mlflow/genai/evaluation/constant.py +23 -0
  147. mlflow/genai/evaluation/utils.py +244 -0
  148. mlflow/genai/judges/__init__.py +21 -0
  149. mlflow/genai/judges/databricks.py +404 -0
  150. mlflow/genai/label_schemas/__init__.py +153 -0
  151. mlflow/genai/label_schemas/label_schemas.py +209 -0
  152. mlflow/genai/labeling/__init__.py +159 -0
  153. mlflow/genai/labeling/labeling.py +250 -0
  154. mlflow/genai/optimize/__init__.py +13 -0
  155. mlflow/genai/optimize/base.py +198 -0
  156. mlflow/genai/optimize/optimizers/__init__.py +4 -0
  157. mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
  158. mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
  159. mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
  160. mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
  161. mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
  162. mlflow/genai/optimize/types.py +75 -0
  163. mlflow/genai/optimize/util.py +30 -0
  164. mlflow/genai/prompts/__init__.py +206 -0
  165. mlflow/genai/scheduled_scorers.py +431 -0
  166. mlflow/genai/scorers/__init__.py +26 -0
  167. mlflow/genai/scorers/base.py +492 -0
  168. mlflow/genai/scorers/builtin_scorers.py +765 -0
  169. mlflow/genai/scorers/scorer_utils.py +138 -0
  170. mlflow/genai/scorers/validation.py +165 -0
  171. mlflow/genai/utils/data_validation.py +146 -0
  172. mlflow/genai/utils/enum_utils.py +23 -0
  173. mlflow/genai/utils/trace_utils.py +211 -0
  174. mlflow/groq/__init__.py +42 -0
  175. mlflow/groq/_groq_autolog.py +74 -0
  176. mlflow/johnsnowlabs/__init__.py +888 -0
  177. mlflow/langchain/__init__.py +24 -0
  178. mlflow/langchain/api_request_parallel_processor.py +330 -0
  179. mlflow/langchain/autolog.py +147 -0
  180. mlflow/langchain/chat_agent_langgraph.py +340 -0
  181. mlflow/langchain/constant.py +1 -0
  182. mlflow/langchain/constants.py +1 -0
  183. mlflow/langchain/databricks_dependencies.py +444 -0
  184. mlflow/langchain/langchain_tracer.py +597 -0
  185. mlflow/langchain/model.py +919 -0
  186. mlflow/langchain/output_parsers.py +142 -0
  187. mlflow/langchain/retriever_chain.py +153 -0
  188. mlflow/langchain/runnables.py +527 -0
  189. mlflow/langchain/utils/chat.py +402 -0
  190. mlflow/langchain/utils/logging.py +671 -0
  191. mlflow/langchain/utils/serialization.py +36 -0
  192. mlflow/legacy_databricks_cli/__init__.py +0 -0
  193. mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
  194. mlflow/legacy_databricks_cli/configure/provider.py +482 -0
  195. mlflow/litellm/__init__.py +175 -0
  196. mlflow/llama_index/__init__.py +22 -0
  197. mlflow/llama_index/autolog.py +55 -0
  198. mlflow/llama_index/chat.py +43 -0
  199. mlflow/llama_index/constant.py +1 -0
  200. mlflow/llama_index/model.py +577 -0
  201. mlflow/llama_index/pyfunc_wrapper.py +332 -0
  202. mlflow/llama_index/serialize_objects.py +188 -0
  203. mlflow/llama_index/tracer.py +561 -0
  204. mlflow/metrics/__init__.py +479 -0
  205. mlflow/metrics/base.py +39 -0
  206. mlflow/metrics/genai/__init__.py +25 -0
  207. mlflow/metrics/genai/base.py +101 -0
  208. mlflow/metrics/genai/genai_metric.py +771 -0
  209. mlflow/metrics/genai/metric_definitions.py +450 -0
  210. mlflow/metrics/genai/model_utils.py +371 -0
  211. mlflow/metrics/genai/prompt_template.py +68 -0
  212. mlflow/metrics/genai/prompts/__init__.py +0 -0
  213. mlflow/metrics/genai/prompts/v1.py +422 -0
  214. mlflow/metrics/genai/utils.py +6 -0
  215. mlflow/metrics/metric_definitions.py +619 -0
  216. mlflow/mismatch.py +34 -0
  217. mlflow/mistral/__init__.py +34 -0
  218. mlflow/mistral/autolog.py +71 -0
  219. mlflow/mistral/chat.py +135 -0
  220. mlflow/ml_package_versions.py +452 -0
  221. mlflow/models/__init__.py +97 -0
  222. mlflow/models/auth_policy.py +83 -0
  223. mlflow/models/cli.py +354 -0
  224. mlflow/models/container/__init__.py +294 -0
  225. mlflow/models/container/scoring_server/__init__.py +0 -0
  226. mlflow/models/container/scoring_server/nginx.conf +39 -0
  227. mlflow/models/dependencies_schemas.py +287 -0
  228. mlflow/models/display_utils.py +158 -0
  229. mlflow/models/docker_utils.py +211 -0
  230. mlflow/models/evaluation/__init__.py +23 -0
  231. mlflow/models/evaluation/_shap_patch.py +64 -0
  232. mlflow/models/evaluation/artifacts.py +194 -0
  233. mlflow/models/evaluation/base.py +1811 -0
  234. mlflow/models/evaluation/calibration_curve.py +109 -0
  235. mlflow/models/evaluation/default_evaluator.py +996 -0
  236. mlflow/models/evaluation/deprecated.py +23 -0
  237. mlflow/models/evaluation/evaluator_registry.py +80 -0
  238. mlflow/models/evaluation/evaluators/classifier.py +704 -0
  239. mlflow/models/evaluation/evaluators/default.py +233 -0
  240. mlflow/models/evaluation/evaluators/regressor.py +96 -0
  241. mlflow/models/evaluation/evaluators/shap.py +296 -0
  242. mlflow/models/evaluation/lift_curve.py +178 -0
  243. mlflow/models/evaluation/utils/metric.py +123 -0
  244. mlflow/models/evaluation/utils/trace.py +179 -0
  245. mlflow/models/evaluation/validation.py +434 -0
  246. mlflow/models/flavor_backend.py +93 -0
  247. mlflow/models/flavor_backend_registry.py +53 -0
  248. mlflow/models/model.py +1639 -0
  249. mlflow/models/model_config.py +150 -0
  250. mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
  251. mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
  252. mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
  253. mlflow/models/python_api.py +369 -0
  254. mlflow/models/rag_signatures.py +128 -0
  255. mlflow/models/resources.py +321 -0
  256. mlflow/models/signature.py +662 -0
  257. mlflow/models/utils.py +2054 -0
  258. mlflow/models/wheeled_model.py +280 -0
  259. mlflow/openai/__init__.py +57 -0
  260. mlflow/openai/_agent_tracer.py +364 -0
  261. mlflow/openai/api_request_parallel_processor.py +131 -0
  262. mlflow/openai/autolog.py +509 -0
  263. mlflow/openai/constant.py +1 -0
  264. mlflow/openai/model.py +824 -0
  265. mlflow/openai/utils/chat_schema.py +367 -0
  266. mlflow/optuna/__init__.py +3 -0
  267. mlflow/optuna/storage.py +646 -0
  268. mlflow/plugins/__init__.py +72 -0
  269. mlflow/plugins/base.py +358 -0
  270. mlflow/plugins/builtin/__init__.py +24 -0
  271. mlflow/plugins/builtin/pytorch_plugin.py +150 -0
  272. mlflow/plugins/builtin/sklearn_plugin.py +158 -0
  273. mlflow/plugins/builtin/transformers_plugin.py +187 -0
  274. mlflow/plugins/cli.py +321 -0
  275. mlflow/plugins/discovery.py +340 -0
  276. mlflow/plugins/manager.py +465 -0
  277. mlflow/plugins/registry.py +316 -0
  278. mlflow/plugins/templates/framework_plugin_template.py +329 -0
  279. mlflow/prompt/constants.py +20 -0
  280. mlflow/prompt/promptlab_model.py +197 -0
  281. mlflow/prompt/registry_utils.py +248 -0
  282. mlflow/promptflow/__init__.py +495 -0
  283. mlflow/protos/__init__.py +0 -0
  284. mlflow/protos/assessments_pb2.py +174 -0
  285. mlflow/protos/databricks_artifacts_pb2.py +489 -0
  286. mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
  287. mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
  288. mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
  289. mlflow/protos/databricks_pb2.py +267 -0
  290. mlflow/protos/databricks_trace_server_pb2.py +374 -0
  291. mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
  292. mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
  293. mlflow/protos/facet_feature_statistics_pb2.py +296 -0
  294. mlflow/protos/internal_pb2.py +77 -0
  295. mlflow/protos/mlflow_artifacts_pb2.py +336 -0
  296. mlflow/protos/model_registry_pb2.py +1073 -0
  297. mlflow/protos/scalapb/__init__.py +0 -0
  298. mlflow/protos/scalapb/scalapb_pb2.py +104 -0
  299. mlflow/protos/service_pb2.py +2600 -0
  300. mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
  301. mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
  302. mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
  303. mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
  304. mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
  305. mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
  306. mlflow/py.typed +0 -0
  307. mlflow/pydantic_ai/__init__.py +57 -0
  308. mlflow/pydantic_ai/autolog.py +173 -0
  309. mlflow/pyfunc/__init__.py +3844 -0
  310. mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
  311. mlflow/pyfunc/backend.py +523 -0
  312. mlflow/pyfunc/context.py +78 -0
  313. mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
  314. mlflow/pyfunc/loaders/__init__.py +7 -0
  315. mlflow/pyfunc/loaders/chat_agent.py +117 -0
  316. mlflow/pyfunc/loaders/chat_model.py +125 -0
  317. mlflow/pyfunc/loaders/code_model.py +31 -0
  318. mlflow/pyfunc/loaders/responses_agent.py +112 -0
  319. mlflow/pyfunc/mlserver.py +46 -0
  320. mlflow/pyfunc/model.py +1473 -0
  321. mlflow/pyfunc/scoring_server/__init__.py +604 -0
  322. mlflow/pyfunc/scoring_server/app.py +7 -0
  323. mlflow/pyfunc/scoring_server/client.py +146 -0
  324. mlflow/pyfunc/spark_model_cache.py +48 -0
  325. mlflow/pyfunc/stdin_server.py +44 -0
  326. mlflow/pyfunc/utils/__init__.py +3 -0
  327. mlflow/pyfunc/utils/data_validation.py +224 -0
  328. mlflow/pyfunc/utils/environment.py +22 -0
  329. mlflow/pyfunc/utils/input_converter.py +47 -0
  330. mlflow/pyfunc/utils/serving_data_parser.py +11 -0
  331. mlflow/pytorch/__init__.py +1171 -0
  332. mlflow/pytorch/_lightning_autolog.py +580 -0
  333. mlflow/pytorch/_pytorch_autolog.py +50 -0
  334. mlflow/pytorch/pickle_module.py +35 -0
  335. mlflow/rfunc/__init__.py +42 -0
  336. mlflow/rfunc/backend.py +134 -0
  337. mlflow/runs.py +89 -0
  338. mlflow/server/__init__.py +302 -0
  339. mlflow/server/auth/__init__.py +1224 -0
  340. mlflow/server/auth/__main__.py +4 -0
  341. mlflow/server/auth/basic_auth.ini +6 -0
  342. mlflow/server/auth/cli.py +11 -0
  343. mlflow/server/auth/client.py +537 -0
  344. mlflow/server/auth/config.py +34 -0
  345. mlflow/server/auth/db/__init__.py +0 -0
  346. mlflow/server/auth/db/cli.py +18 -0
  347. mlflow/server/auth/db/migrations/__init__.py +0 -0
  348. mlflow/server/auth/db/migrations/alembic.ini +110 -0
  349. mlflow/server/auth/db/migrations/env.py +76 -0
  350. mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
  351. mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
  352. mlflow/server/auth/db/models.py +67 -0
  353. mlflow/server/auth/db/utils.py +37 -0
  354. mlflow/server/auth/entities.py +165 -0
  355. mlflow/server/auth/logo.py +14 -0
  356. mlflow/server/auth/permissions.py +65 -0
  357. mlflow/server/auth/routes.py +18 -0
  358. mlflow/server/auth/sqlalchemy_store.py +263 -0
  359. mlflow/server/graphql/__init__.py +0 -0
  360. mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
  361. mlflow/server/graphql/graphql_custom_scalars.py +24 -0
  362. mlflow/server/graphql/graphql_errors.py +15 -0
  363. mlflow/server/graphql/graphql_no_batching.py +89 -0
  364. mlflow/server/graphql/graphql_schema_extensions.py +74 -0
  365. mlflow/server/handlers.py +3217 -0
  366. mlflow/server/prometheus_exporter.py +17 -0
  367. mlflow/server/validation.py +30 -0
  368. mlflow/shap/__init__.py +691 -0
  369. mlflow/sklearn/__init__.py +1994 -0
  370. mlflow/sklearn/utils.py +1041 -0
  371. mlflow/smolagents/__init__.py +66 -0
  372. mlflow/smolagents/autolog.py +139 -0
  373. mlflow/smolagents/chat.py +29 -0
  374. mlflow/store/__init__.py +10 -0
  375. mlflow/store/_unity_catalog/__init__.py +1 -0
  376. mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
  377. mlflow/store/_unity_catalog/lineage/constants.py +2 -0
  378. mlflow/store/_unity_catalog/registry/__init__.py +6 -0
  379. mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
  380. mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
  381. mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
  382. mlflow/store/_unity_catalog/registry/utils.py +121 -0
  383. mlflow/store/artifact/__init__.py +0 -0
  384. mlflow/store/artifact/artifact_repo.py +472 -0
  385. mlflow/store/artifact/artifact_repository_registry.py +154 -0
  386. mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
  387. mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
  388. mlflow/store/artifact/cli.py +141 -0
  389. mlflow/store/artifact/cloud_artifact_repo.py +332 -0
  390. mlflow/store/artifact/databricks_artifact_repo.py +729 -0
  391. mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
  392. mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
  393. mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
  394. mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
  395. mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
  396. mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
  397. mlflow/store/artifact/ftp_artifact_repo.py +132 -0
  398. mlflow/store/artifact/gcs_artifact_repo.py +296 -0
  399. mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
  400. mlflow/store/artifact/http_artifact_repo.py +218 -0
  401. mlflow/store/artifact/local_artifact_repo.py +142 -0
  402. mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
  403. mlflow/store/artifact/models_artifact_repo.py +259 -0
  404. mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
  405. mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
  406. mlflow/store/artifact/r2_artifact_repo.py +70 -0
  407. mlflow/store/artifact/runs_artifact_repo.py +265 -0
  408. mlflow/store/artifact/s3_artifact_repo.py +330 -0
  409. mlflow/store/artifact/sftp_artifact_repo.py +141 -0
  410. mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
  411. mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
  412. mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
  413. mlflow/store/artifact/utils/__init__.py +0 -0
  414. mlflow/store/artifact/utils/models.py +148 -0
  415. mlflow/store/db/__init__.py +0 -0
  416. mlflow/store/db/base_sql_model.py +3 -0
  417. mlflow/store/db/db_types.py +10 -0
  418. mlflow/store/db/utils.py +314 -0
  419. mlflow/store/db_migrations/__init__.py +0 -0
  420. mlflow/store/db_migrations/alembic.ini +74 -0
  421. mlflow/store/db_migrations/env.py +84 -0
  422. mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
  423. mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
  424. mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
  425. mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
  426. mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
  427. mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
  428. mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
  429. mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
  430. mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
  431. mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
  432. mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
  433. mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
  434. mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
  435. mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
  436. mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
  437. mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
  438. mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
  439. mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
  440. mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
  441. mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
  442. mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
  443. mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
  444. mlflow/store/db_migrations/versions/__init__.py +0 -0
  445. mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
  446. mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
  447. mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
  448. mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
  449. mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
  450. mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
  451. mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
  452. mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
  453. mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
  454. mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
  455. mlflow/store/entities/__init__.py +3 -0
  456. mlflow/store/entities/paged_list.py +18 -0
  457. mlflow/store/model_registry/__init__.py +10 -0
  458. mlflow/store/model_registry/abstract_store.py +1081 -0
  459. mlflow/store/model_registry/base_rest_store.py +44 -0
  460. mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
  461. mlflow/store/model_registry/dbmodels/__init__.py +0 -0
  462. mlflow/store/model_registry/dbmodels/models.py +206 -0
  463. mlflow/store/model_registry/file_store.py +1091 -0
  464. mlflow/store/model_registry/rest_store.py +481 -0
  465. mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
  466. mlflow/store/tracking/__init__.py +23 -0
  467. mlflow/store/tracking/abstract_store.py +816 -0
  468. mlflow/store/tracking/dbmodels/__init__.py +0 -0
  469. mlflow/store/tracking/dbmodels/initial_models.py +243 -0
  470. mlflow/store/tracking/dbmodels/models.py +1073 -0
  471. mlflow/store/tracking/file_store.py +2438 -0
  472. mlflow/store/tracking/postgres_managed_identity.py +146 -0
  473. mlflow/store/tracking/rest_store.py +1131 -0
  474. mlflow/store/tracking/sqlalchemy_store.py +2785 -0
  475. mlflow/system_metrics/__init__.py +61 -0
  476. mlflow/system_metrics/metrics/__init__.py +0 -0
  477. mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
  478. mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
  479. mlflow/system_metrics/metrics/disk_monitor.py +21 -0
  480. mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
  481. mlflow/system_metrics/metrics/network_monitor.py +34 -0
  482. mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
  483. mlflow/system_metrics/system_metrics_monitor.py +198 -0
  484. mlflow/tracing/__init__.py +16 -0
  485. mlflow/tracing/assessment.py +356 -0
  486. mlflow/tracing/client.py +531 -0
  487. mlflow/tracing/config.py +125 -0
  488. mlflow/tracing/constant.py +105 -0
  489. mlflow/tracing/destination.py +81 -0
  490. mlflow/tracing/display/__init__.py +40 -0
  491. mlflow/tracing/display/display_handler.py +196 -0
  492. mlflow/tracing/export/async_export_queue.py +186 -0
  493. mlflow/tracing/export/inference_table.py +138 -0
  494. mlflow/tracing/export/mlflow_v3.py +137 -0
  495. mlflow/tracing/export/utils.py +70 -0
  496. mlflow/tracing/fluent.py +1417 -0
  497. mlflow/tracing/processor/base_mlflow.py +199 -0
  498. mlflow/tracing/processor/inference_table.py +175 -0
  499. mlflow/tracing/processor/mlflow_v3.py +47 -0
  500. mlflow/tracing/processor/otel.py +73 -0
  501. mlflow/tracing/provider.py +487 -0
  502. mlflow/tracing/trace_manager.py +200 -0
  503. mlflow/tracing/utils/__init__.py +616 -0
  504. mlflow/tracing/utils/artifact_utils.py +28 -0
  505. mlflow/tracing/utils/copy.py +55 -0
  506. mlflow/tracing/utils/environment.py +55 -0
  507. mlflow/tracing/utils/exception.py +21 -0
  508. mlflow/tracing/utils/once.py +35 -0
  509. mlflow/tracing/utils/otlp.py +63 -0
  510. mlflow/tracing/utils/processor.py +54 -0
  511. mlflow/tracing/utils/search.py +292 -0
  512. mlflow/tracing/utils/timeout.py +250 -0
  513. mlflow/tracing/utils/token.py +19 -0
  514. mlflow/tracing/utils/truncation.py +124 -0
  515. mlflow/tracing/utils/warning.py +76 -0
  516. mlflow/tracking/__init__.py +39 -0
  517. mlflow/tracking/_model_registry/__init__.py +1 -0
  518. mlflow/tracking/_model_registry/client.py +764 -0
  519. mlflow/tracking/_model_registry/fluent.py +853 -0
  520. mlflow/tracking/_model_registry/registry.py +67 -0
  521. mlflow/tracking/_model_registry/utils.py +251 -0
  522. mlflow/tracking/_tracking_service/__init__.py +0 -0
  523. mlflow/tracking/_tracking_service/client.py +883 -0
  524. mlflow/tracking/_tracking_service/registry.py +56 -0
  525. mlflow/tracking/_tracking_service/utils.py +275 -0
  526. mlflow/tracking/artifact_utils.py +179 -0
  527. mlflow/tracking/client.py +5900 -0
  528. mlflow/tracking/context/__init__.py +0 -0
  529. mlflow/tracking/context/abstract_context.py +35 -0
  530. mlflow/tracking/context/databricks_cluster_context.py +15 -0
  531. mlflow/tracking/context/databricks_command_context.py +15 -0
  532. mlflow/tracking/context/databricks_job_context.py +49 -0
  533. mlflow/tracking/context/databricks_notebook_context.py +41 -0
  534. mlflow/tracking/context/databricks_repo_context.py +43 -0
  535. mlflow/tracking/context/default_context.py +51 -0
  536. mlflow/tracking/context/git_context.py +32 -0
  537. mlflow/tracking/context/registry.py +98 -0
  538. mlflow/tracking/context/system_environment_context.py +15 -0
  539. mlflow/tracking/default_experiment/__init__.py +1 -0
  540. mlflow/tracking/default_experiment/abstract_context.py +43 -0
  541. mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
  542. mlflow/tracking/default_experiment/registry.py +75 -0
  543. mlflow/tracking/fluent.py +3595 -0
  544. mlflow/tracking/metric_value_conversion_utils.py +93 -0
  545. mlflow/tracking/multimedia.py +206 -0
  546. mlflow/tracking/registry.py +86 -0
  547. mlflow/tracking/request_auth/__init__.py +0 -0
  548. mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
  549. mlflow/tracking/request_auth/registry.py +60 -0
  550. mlflow/tracking/request_header/__init__.py +0 -0
  551. mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
  552. mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
  553. mlflow/tracking/request_header/default_request_header_provider.py +17 -0
  554. mlflow/tracking/request_header/registry.py +79 -0
  555. mlflow/transformers/__init__.py +2982 -0
  556. mlflow/transformers/flavor_config.py +258 -0
  557. mlflow/transformers/hub_utils.py +83 -0
  558. mlflow/transformers/llm_inference_utils.py +468 -0
  559. mlflow/transformers/model_io.py +301 -0
  560. mlflow/transformers/peft.py +51 -0
  561. mlflow/transformers/signature.py +183 -0
  562. mlflow/transformers/torch_utils.py +55 -0
  563. mlflow/types/__init__.py +21 -0
  564. mlflow/types/agent.py +270 -0
  565. mlflow/types/chat.py +240 -0
  566. mlflow/types/llm.py +935 -0
  567. mlflow/types/responses.py +139 -0
  568. mlflow/types/responses_helpers.py +416 -0
  569. mlflow/types/schema.py +1505 -0
  570. mlflow/types/type_hints.py +647 -0
  571. mlflow/types/utils.py +753 -0
  572. mlflow/utils/__init__.py +283 -0
  573. mlflow/utils/_capture_modules.py +256 -0
  574. mlflow/utils/_capture_transformers_modules.py +75 -0
  575. mlflow/utils/_spark_utils.py +201 -0
  576. mlflow/utils/_unity_catalog_oss_utils.py +97 -0
  577. mlflow/utils/_unity_catalog_utils.py +479 -0
  578. mlflow/utils/annotations.py +218 -0
  579. mlflow/utils/arguments_utils.py +16 -0
  580. mlflow/utils/async_logging/__init__.py +1 -0
  581. mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
  582. mlflow/utils/async_logging/async_logging_queue.py +366 -0
  583. mlflow/utils/async_logging/run_artifact.py +38 -0
  584. mlflow/utils/async_logging/run_batch.py +58 -0
  585. mlflow/utils/async_logging/run_operations.py +49 -0
  586. mlflow/utils/autologging_utils/__init__.py +737 -0
  587. mlflow/utils/autologging_utils/client.py +432 -0
  588. mlflow/utils/autologging_utils/config.py +33 -0
  589. mlflow/utils/autologging_utils/events.py +294 -0
  590. mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
  591. mlflow/utils/autologging_utils/metrics_queue.py +71 -0
  592. mlflow/utils/autologging_utils/safety.py +1104 -0
  593. mlflow/utils/autologging_utils/versioning.py +95 -0
  594. mlflow/utils/checkpoint_utils.py +206 -0
  595. mlflow/utils/class_utils.py +6 -0
  596. mlflow/utils/cli_args.py +257 -0
  597. mlflow/utils/conda.py +354 -0
  598. mlflow/utils/credentials.py +231 -0
  599. mlflow/utils/data_utils.py +17 -0
  600. mlflow/utils/databricks_utils.py +1436 -0
  601. mlflow/utils/docstring_utils.py +477 -0
  602. mlflow/utils/doctor.py +133 -0
  603. mlflow/utils/download_cloud_file_chunk.py +43 -0
  604. mlflow/utils/env_manager.py +16 -0
  605. mlflow/utils/env_pack.py +131 -0
  606. mlflow/utils/environment.py +1009 -0
  607. mlflow/utils/exception_utils.py +14 -0
  608. mlflow/utils/file_utils.py +978 -0
  609. mlflow/utils/git_utils.py +77 -0
  610. mlflow/utils/gorilla.py +797 -0
  611. mlflow/utils/import_hooks/__init__.py +363 -0
  612. mlflow/utils/lazy_load.py +51 -0
  613. mlflow/utils/logging_utils.py +168 -0
  614. mlflow/utils/mime_type_utils.py +58 -0
  615. mlflow/utils/mlflow_tags.py +103 -0
  616. mlflow/utils/model_utils.py +486 -0
  617. mlflow/utils/name_utils.py +346 -0
  618. mlflow/utils/nfs_on_spark.py +62 -0
  619. mlflow/utils/openai_utils.py +164 -0
  620. mlflow/utils/os.py +12 -0
  621. mlflow/utils/oss_registry_utils.py +29 -0
  622. mlflow/utils/plugins.py +17 -0
  623. mlflow/utils/process.py +182 -0
  624. mlflow/utils/promptlab_utils.py +146 -0
  625. mlflow/utils/proto_json_utils.py +743 -0
  626. mlflow/utils/pydantic_utils.py +54 -0
  627. mlflow/utils/request_utils.py +279 -0
  628. mlflow/utils/requirements_utils.py +704 -0
  629. mlflow/utils/rest_utils.py +673 -0
  630. mlflow/utils/search_logged_model_utils.py +127 -0
  631. mlflow/utils/search_utils.py +2111 -0
  632. mlflow/utils/secure_loading.py +221 -0
  633. mlflow/utils/security_validation.py +384 -0
  634. mlflow/utils/server_cli_utils.py +61 -0
  635. mlflow/utils/spark_utils.py +15 -0
  636. mlflow/utils/string_utils.py +138 -0
  637. mlflow/utils/thread_utils.py +63 -0
  638. mlflow/utils/time.py +54 -0
  639. mlflow/utils/timeout.py +42 -0
  640. mlflow/utils/uri.py +572 -0
  641. mlflow/utils/validation.py +662 -0
  642. mlflow/utils/virtualenv.py +458 -0
  643. mlflow/utils/warnings_utils.py +25 -0
  644. mlflow/utils/yaml_utils.py +179 -0
  645. mlflow/version.py +24 -0
@@ -0,0 +1,1740 @@
1
+ import base64
2
+ import functools
3
+ import json
4
+ import logging
5
+ import os
6
+ import re
7
+ import shutil
8
+ from contextlib import contextmanager
9
+ from dataclasses import dataclass
10
+ from typing import Optional, Union
11
+
12
+ import google.protobuf.empty_pb2
13
+
14
+ import mlflow
15
+ from mlflow.entities import Run
16
+ from mlflow.entities.logged_model import LoggedModel
17
+ from mlflow.entities.model_registry.prompt import Prompt
18
+ from mlflow.entities.model_registry.prompt_version import PromptVersion
19
+ from mlflow.exceptions import MlflowException, RestException
20
+ from mlflow.protos.databricks_pb2 import (
21
+ INTERNAL_ERROR,
22
+ INVALID_PARAMETER_VALUE,
23
+ RESOURCE_DOES_NOT_EXIST,
24
+ ErrorCode,
25
+ )
26
+ from mlflow.protos.databricks_uc_registry_messages_pb2 import (
27
+ MODEL_VERSION_OPERATION_READ_WRITE,
28
+ CreateModelVersionRequest,
29
+ CreateModelVersionResponse,
30
+ CreateRegisteredModelRequest,
31
+ CreateRegisteredModelResponse,
32
+ DeleteModelVersionRequest,
33
+ DeleteModelVersionResponse,
34
+ DeleteModelVersionTagRequest,
35
+ DeleteModelVersionTagResponse,
36
+ DeleteRegisteredModelAliasRequest,
37
+ DeleteRegisteredModelAliasResponse,
38
+ DeleteRegisteredModelRequest,
39
+ DeleteRegisteredModelResponse,
40
+ DeleteRegisteredModelTagRequest,
41
+ DeleteRegisteredModelTagResponse,
42
+ Entity,
43
+ FinalizeModelVersionRequest,
44
+ FinalizeModelVersionResponse,
45
+ GenerateTemporaryModelVersionCredentialsRequest,
46
+ GenerateTemporaryModelVersionCredentialsResponse,
47
+ GetModelVersionByAliasRequest,
48
+ GetModelVersionByAliasResponse,
49
+ GetModelVersionDownloadUriRequest,
50
+ GetModelVersionDownloadUriResponse,
51
+ GetModelVersionRequest,
52
+ GetModelVersionResponse,
53
+ GetRegisteredModelRequest,
54
+ GetRegisteredModelResponse,
55
+ Job,
56
+ Lineage,
57
+ LineageHeaderInfo,
58
+ Notebook,
59
+ SearchModelVersionsRequest,
60
+ SearchModelVersionsResponse,
61
+ SearchRegisteredModelsRequest,
62
+ SearchRegisteredModelsResponse,
63
+ Securable,
64
+ SetModelVersionTagRequest,
65
+ SetModelVersionTagResponse,
66
+ SetRegisteredModelAliasRequest,
67
+ SetRegisteredModelAliasResponse,
68
+ SetRegisteredModelTagRequest,
69
+ SetRegisteredModelTagResponse,
70
+ StorageMode,
71
+ Table,
72
+ TemporaryCredentials,
73
+ UpdateModelVersionRequest,
74
+ UpdateModelVersionResponse,
75
+ UpdateRegisteredModelRequest,
76
+ UpdateRegisteredModelResponse,
77
+ )
78
+ from mlflow.protos.databricks_uc_registry_service_pb2 import UcModelRegistryService
79
+ from mlflow.protos.service_pb2 import GetRun, MlflowService
80
+ from mlflow.protos.unity_catalog_prompt_messages_pb2 import (
81
+ CreatePromptRequest,
82
+ CreatePromptVersionRequest,
83
+ DeletePromptAliasRequest,
84
+ DeletePromptRequest,
85
+ DeletePromptTagRequest,
86
+ DeletePromptVersionRequest,
87
+ DeletePromptVersionTagRequest,
88
+ GetPromptRequest,
89
+ GetPromptVersionByAliasRequest,
90
+ GetPromptVersionRequest,
91
+ LinkPromptsToTracesRequest,
92
+ LinkPromptVersionsToModelsRequest,
93
+ LinkPromptVersionsToRunsRequest,
94
+ PromptVersionLinkEntry,
95
+ SearchPromptsRequest,
96
+ SearchPromptsResponse,
97
+ SearchPromptVersionsRequest,
98
+ SearchPromptVersionsResponse,
99
+ SetPromptAliasRequest,
100
+ SetPromptTagRequest,
101
+ SetPromptVersionTagRequest,
102
+ UnityCatalogSchema,
103
+ UpdatePromptRequest,
104
+ UpdatePromptVersionRequest,
105
+ )
106
+ from mlflow.protos.unity_catalog_prompt_messages_pb2 import (
107
+ Prompt as ProtoPrompt,
108
+ )
109
+ from mlflow.protos.unity_catalog_prompt_messages_pb2 import (
110
+ PromptVersion as ProtoPromptVersion,
111
+ )
112
+ from mlflow.protos.unity_catalog_prompt_service_pb2 import UnityCatalogPromptService
113
+ from mlflow.store._unity_catalog.lineage.constants import (
114
+ _DATABRICKS_LINEAGE_ID_HEADER,
115
+ _DATABRICKS_ORG_ID_HEADER,
116
+ )
117
+ from mlflow.store._unity_catalog.registry.utils import (
118
+ mlflow_tags_to_proto,
119
+ mlflow_tags_to_proto_version_tags,
120
+ proto_info_to_mlflow_prompt_info,
121
+ proto_to_mlflow_prompt,
122
+ )
123
+ from mlflow.store.artifact.databricks_sdk_models_artifact_repo import (
124
+ DatabricksSDKModelsArtifactRepository,
125
+ )
126
+ from mlflow.store.artifact.presigned_url_artifact_repo import (
127
+ PresignedUrlArtifactRepository,
128
+ )
129
+ from mlflow.store.entities.paged_list import PagedList
130
+ from mlflow.store.model_registry.rest_store import BaseRestStore
131
+ from mlflow.utils._spark_utils import _get_active_spark_session
132
+ from mlflow.utils._unity_catalog_utils import (
133
+ get_artifact_repo_from_storage_info,
134
+ get_full_name_from_sc,
135
+ is_databricks_sdk_models_artifact_repository_enabled,
136
+ model_version_from_uc_proto,
137
+ model_version_search_from_uc_proto,
138
+ registered_model_from_uc_proto,
139
+ registered_model_search_from_uc_proto,
140
+ uc_model_version_tag_from_mlflow_tags,
141
+ uc_registered_model_tag_from_mlflow_tags,
142
+ )
143
+ from mlflow.utils.databricks_utils import (
144
+ _print_databricks_deployment_job_url,
145
+ get_databricks_host_creds,
146
+ is_databricks_uri,
147
+ )
148
+ from mlflow.utils.mlflow_tags import (
149
+ MLFLOW_DATABRICKS_JOB_ID,
150
+ MLFLOW_DATABRICKS_JOB_RUN_ID,
151
+ MLFLOW_DATABRICKS_NOTEBOOK_ID,
152
+ )
153
+ from mlflow.utils.proto_json_utils import message_to_json, parse_dict
154
+ from mlflow.utils.rest_utils import (
155
+ _REST_API_PATH_PREFIX,
156
+ call_endpoint,
157
+ extract_all_api_info_for_service,
158
+ extract_api_info_for_service,
159
+ http_request,
160
+ verify_rest_response,
161
+ )
162
+ from mlflow.utils.uri import is_fuse_or_uc_volumes_uri
163
+
164
# Endpoint lookup for the MLflow tracking service. It is indexed by request message
# class (e.g. ``GetRun``) and each entry is unpacked as ``(endpoint, method)`` below.
_TRACKING_METHOD_TO_INFO = extract_api_info_for_service(MlflowService, _REST_API_PATH_PREFIX)
# Endpoint lookups for the UC model registry service and the UC prompt service, merged
# into one mapping. If both services produced the same key, the prompt service entry
# would win because it is unpacked last.
_METHOD_TO_INFO = {
    **extract_api_info_for_service(UcModelRegistryService, _REST_API_PATH_PREFIX),
    **extract_api_info_for_service(UnityCatalogPromptService, _REST_API_PATH_PREFIX),
}
# Same merge as above, built with ``extract_all_api_info_for_service``.
_METHOD_TO_ALL_INFO = {
    **extract_all_api_info_for_service(UcModelRegistryService, _REST_API_PATH_PREFIX),
    **extract_all_api_info_for_service(UnityCatalogPromptService, _REST_API_PATH_PREFIX),
}

_logger = logging.getLogger(__name__)
# Source-type string compared against ``DeltaDatasetSource._get_source_type()``.
_DELTA_TABLE = "delta_table"
# Upper bound on the number of upstream datasets returned for lineage tracking.
_MAX_LINEAGE_DATA_SOURCES = 10

# Pre-compiled regex patterns for better performance in search operations
# Each matches ``<keyword> = '<value>'`` (single or double quotes, case-insensitive
# keyword) and captures the quoted value.
_CATALOG_PATTERN = re.compile(r"catalog\s*=\s*['\"]([^'\"]+)['\"]", re.IGNORECASE)
_SCHEMA_PATTERN = re.compile(r"schema\s*=\s*['\"]([^'\"]+)['\"]", re.IGNORECASE)
181
+
182
+
183
+ @dataclass
184
+ class _CatalogSchemaFilter:
185
+ """Internal class to hold parsed catalog, schema, and remaining filter."""
186
+
187
+ catalog_name: str
188
+ schema_name: str
189
+ remaining_filter: Optional[str]
190
+
191
+
192
+ def _require_arg_unspecified(arg_name, arg_value, default_values=None, message=None):
193
+ default_values = [None] if default_values is None else default_values
194
+ if arg_value not in default_values:
195
+ _raise_unsupported_arg(arg_name, message)
196
+
197
+
198
def _raise_unsupported_arg(arg_name, message=None):
    """
    Raise an ``MlflowException`` stating that ``arg_name`` is unsupported in Unity Catalog.

    Args:
        arg_name: Name of the unsupported argument.
        message: Optional extra text appended after a single space.
    """
    base_text = f"Argument '{arg_name}' is unsupported for models in the Unity Catalog."
    full_text = base_text if message is None else " ".join((base_text, message))
    raise MlflowException(full_text)
205
+
206
+
207
def _raise_unsupported_method(method, message=None):
    """
    Raise an ``MlflowException`` stating that ``method`` is unsupported in Unity Catalog.

    Args:
        method: Name of the unsupported method.
        message: Optional extra text appended after a single space.
    """
    base_text = f"Method '{method}' is unsupported for models in the Unity Catalog."
    full_text = base_text if message is None else " ".join((base_text, message))
    raise MlflowException(full_text)
214
+
215
+
216
def _load_model(local_model_dir):
    """
    Load MLflow model metadata from ``local_model_dir``.

    Any failure from ``Model.load`` is re-raised as an ``MlflowException`` that explains
    what a valid source directory must contain; the original error is chained as the cause.
    """
    # Model is imported lazily rather than at module level to avoid a circular import:
    # mlflow.models.model imports from MLflow tracking, which in turn imports this file
    # during store registry initialization.
    from mlflow.models.model import Model

    try:
        loaded_model = Model.load(local_model_dir)
    except Exception as e:
        raise MlflowException(
            "Unable to load model metadata. Ensure the source path of the model "
            "being registered points to a valid MLflow model directory "
            "(see https://mlflow.org/docs/latest/models.html#storage-format) containing a "
            "model signature (https://mlflow.org/docs/latest/models.html#model-signature) "
            "specifying both input and output type specifications."
        ) from e
    return loaded_model
232
+
233
+
234
def get_feature_dependencies(model_dir):
    """
    Gets the features which a model depends on. This functionality is only implemented on
    Databricks. In OSS mlflow, the dependencies are always empty ("").

    Raises:
        MlflowException: If the model's ``python_function`` flavor uses the Databricks
            Feature Store loader module.
    """
    flavors = _load_model(model_dir).get_model_info().flavors
    loader_module = flavors.get("python_function", {}).get("loader_module")
    if loader_module == mlflow.models.model._DATABRICKS_FS_LOADER_MODULE:
        raise MlflowException(
            "This model was packaged by Databricks Feature Store and can only be registered on a "
            "Databricks cluster."
        )
    return ""
250
+
251
+
252
def get_model_version_dependencies(model_dir):
    """
    Gets the specified dependencies for a particular model version and formats them
    to be passed into CreateModelVersion.

    Args:
        model_dir: Local directory containing the MLflow model.

    Returns:
        A list of ``{"type": ..., "name": ...}`` dicts, one per dependency found.
    """
    from mlflow.models.resources import ResourceType

    model = _load_model(model_dir)
    model_info = model.get_model_info()

    # Prefer model.auth_policy.system_auth_policy.resources. If that is not found or
    # empty, fall back to model.resources. ``or {}`` also tolerates an explicit
    # ``system_auth_policy: None`` entry.
    databricks_resources = None
    if model.auth_policy:
        system_auth_policy = model.auth_policy.get("system_auth_policy") or {}
        databricks_resources = system_auth_policy.get("resources")
    if not databricks_resources:
        databricks_resources = model.resources

    if databricks_resources:
        databricks_dependencies = databricks_resources.get("databricks", {})
        # (resource key, dependency type reported to Unity Catalog); order is preserved
        # in the returned list.
        resource_kinds = (
            (ResourceType.VECTOR_SEARCH_INDEX, "DATABRICKS_VECTOR_INDEX"),
            (ResourceType.SERVING_ENDPOINT, "DATABRICKS_MODEL_ENDPOINT"),
            (ResourceType.FUNCTION, "DATABRICKS_UC_FUNCTION"),
            (ResourceType.UC_CONNECTION, "DATABRICKS_UC_CONNECTION"),
            (ResourceType.TABLE, "DATABRICKS_TABLE"),
        )
        dependencies = []
        for resource_type, dependency_type in resource_kinds:
            dependencies.extend(
                _fetch_langchain_dependency_from_model_resources(
                    databricks_dependencies,
                    resource_type.value,
                    dependency_type,
                )
            )
        return dependencies

    # These types of dependencies are required for old models that didn't use
    # resources so they can be registered correctly to UC
    _DATABRICKS_VECTOR_SEARCH_INDEX_NAME_KEY = "databricks_vector_search_index_name"
    _DATABRICKS_EMBEDDINGS_ENDPOINT_NAME_KEY = "databricks_embeddings_endpoint_name"
    _DATABRICKS_LLM_ENDPOINT_NAME_KEY = "databricks_llm_endpoint_name"
    _DATABRICKS_CHAT_ENDPOINT_NAME_KEY = "databricks_chat_endpoint_name"
    _DB_DEPENDENCY_KEY = "databricks_dependency"

    databricks_dependencies = model_info.flavors.get("langchain", {}).get(
        _DB_DEPENDENCY_KEY, {}
    )

    dependencies = [
        {"type": "DATABRICKS_VECTOR_INDEX", "name": index_name}
        for index_name in _fetch_langchain_dependency_from_model_info(
            databricks_dependencies, _DATABRICKS_VECTOR_SEARCH_INDEX_NAME_KEY
        )
    ]
    for key in (
        _DATABRICKS_EMBEDDINGS_ENDPOINT_NAME_KEY,
        _DATABRICKS_LLM_ENDPOINT_NAME_KEY,
        _DATABRICKS_CHAT_ENDPOINT_NAME_KEY,
    ):
        dependencies.extend(
            {"type": "DATABRICKS_MODEL_ENDPOINT", "name": endpoint_name}
            for endpoint_name in _fetch_langchain_dependency_from_model_info(
                databricks_dependencies, key
            )
        )
    return dependencies
336
+
337
+
338
+ def _fetch_langchain_dependency_from_model_resources(databricks_dependencies, key, resource_type):
339
+ dependencies = databricks_dependencies.get(key, [])
340
+ deps = []
341
+ for dependency in dependencies:
342
+ if dependency.get("on_behalf_of_user", False):
343
+ continue
344
+ deps.append({"type": resource_type, "name": dependency["name"]})
345
+ return deps
346
+
347
+
348
+ def _fetch_langchain_dependency_from_model_info(databricks_dependencies, key):
349
+ return databricks_dependencies.get(key, [])
350
+
351
+
352
+ class UcModelRegistryStore(BaseRestStore):
353
+ """
354
+ Client for a remote model registry server accessed via REST API calls
355
+
356
+ Args:
357
+ store_uri: URI with scheme 'databricks-uc'
358
+ tracking_uri: URI of the Databricks MLflow tracking server from which to fetch
359
+ run info and download run artifacts, when creating new model
360
+ versions from source artifacts logged to an MLflow run.
361
+ """
362
+
363
def __init__(self, store_uri, tracking_uri):
    """
    Set up host-credential providers for the registry and tracking servers.

    Args:
        store_uri: Registry URI passed to ``get_databricks_host_creds``.
        tracking_uri: Tracking server URI, stored on the instance and used to build
            the tracking host-credential provider.
    """
    super().__init__(get_host_creds=functools.partial(get_databricks_host_creds, store_uri))
    self.tracking_uri = tracking_uri
    self.get_tracking_host_creds = functools.partial(get_databricks_host_creds, tracking_uri)
    # Always define ``self.spark``: the other methods read it unconditionally, so
    # leaving it unset after a failed lookup would surface later as AttributeError.
    self.spark = None
    try:
        self.spark = _get_active_spark_session()
    except Exception:
        # Best effort: the store remains usable without a Spark session.
        _logger.debug("Unable to get the active Spark session", exc_info=True)
371
+
372
def _get_response_from_method(self, method):
    """
    Create an empty response message for the given request message class.

    Args:
        method: Request message class (for example ``GetRegisteredModelRequest``).

    Returns:
        A new instance of the response message class paired with ``method``.

    Raises:
        KeyError: If ``method`` has no registered response class.
    """
    empty_response = google.protobuf.empty_pb2.Empty
    response_classes = {
        # Registered models and model versions
        CreateRegisteredModelRequest: CreateRegisteredModelResponse,
        UpdateRegisteredModelRequest: UpdateRegisteredModelResponse,
        DeleteRegisteredModelRequest: DeleteRegisteredModelResponse,
        CreateModelVersionRequest: CreateModelVersionResponse,
        FinalizeModelVersionRequest: FinalizeModelVersionResponse,
        UpdateModelVersionRequest: UpdateModelVersionResponse,
        DeleteModelVersionRequest: DeleteModelVersionResponse,
        GetModelVersionDownloadUriRequest: GetModelVersionDownloadUriResponse,
        SearchModelVersionsRequest: SearchModelVersionsResponse,
        GetRegisteredModelRequest: GetRegisteredModelResponse,
        GetModelVersionRequest: GetModelVersionResponse,
        SearchRegisteredModelsRequest: SearchRegisteredModelsResponse,
        GenerateTemporaryModelVersionCredentialsRequest: (
            GenerateTemporaryModelVersionCredentialsResponse
        ),
        GetRun: GetRun.Response,
        SetRegisteredModelAliasRequest: SetRegisteredModelAliasResponse,
        DeleteRegisteredModelAliasRequest: DeleteRegisteredModelAliasResponse,
        SetRegisteredModelTagRequest: SetRegisteredModelTagResponse,
        DeleteRegisteredModelTagRequest: DeleteRegisteredModelTagResponse,
        SetModelVersionTagRequest: SetModelVersionTagResponse,
        DeleteModelVersionTagRequest: DeleteModelVersionTagResponse,
        GetModelVersionByAliasRequest: GetModelVersionByAliasResponse,
        # Prompts and prompt versions
        CreatePromptRequest: ProtoPrompt,
        SearchPromptsRequest: SearchPromptsResponse,
        DeletePromptRequest: empty_response,
        SetPromptTagRequest: empty_response,
        DeletePromptTagRequest: empty_response,
        CreatePromptVersionRequest: ProtoPromptVersion,
        GetPromptVersionRequest: ProtoPromptVersion,
        DeletePromptVersionRequest: empty_response,
        GetPromptVersionByAliasRequest: ProtoPromptVersion,
        UpdatePromptRequest: ProtoPrompt,
        GetPromptRequest: ProtoPrompt,
        SearchPromptVersionsRequest: SearchPromptVersionsResponse,
        SetPromptAliasRequest: empty_response,
        DeletePromptAliasRequest: empty_response,
        SetPromptVersionTagRequest: empty_response,
        DeletePromptVersionTagRequest: empty_response,
        UpdatePromptVersionRequest: ProtoPromptVersion,
        LinkPromptVersionsToModelsRequest: empty_response,
        LinkPromptsToTracesRequest: empty_response,
        LinkPromptVersionsToRunsRequest: empty_response,
    }
    response_class = response_classes[method]
    return response_class()
419
+
420
def _get_endpoint_from_method(self, method):
    """Look up ``method`` in ``_METHOD_TO_INFO``; raises ``KeyError`` if it is unknown."""
    endpoint_info = _METHOD_TO_INFO[method]
    return endpoint_info
422
+
423
def _get_all_endpoints_from_method(self, method):
    """Look up ``method`` in ``_METHOD_TO_ALL_INFO``; raises ``KeyError`` if it is unknown."""
    all_endpoint_info = _METHOD_TO_ALL_INFO[method]
    return all_endpoint_info
425
+
426
+ # CRUD API for RegisteredModel objects
427
+
428
def create_registered_model(self, name, tags=None, description=None, deployment_job_id=None):
    """
    Create a new registered model in backend store.

    Args:
        name: Name of the new model. This is expected to be unique in the backend store.
        tags: A list of :py:class:`mlflow.entities.model_registry.RegisteredModelTag`
            instances associated with this registered model.
        description: Description of the model.
        deployment_job_id: Optional deployment job id.

    Returns:
        A single object of :py:class:`mlflow.entities.model_registry.RegisteredModel`
        created in the backend.

    """
    full_name = get_full_name_from_sc(name, self.spark)
    request = CreateRegisteredModelRequest(
        name=full_name,
        description=description,
        tags=uc_registered_model_tag_from_mlflow_tags(tags),
        deployment_job_id=str(deployment_job_id) if deployment_job_id else None,
    )
    req_body = message_to_json(request)
    try:
        response_proto = self._call_endpoint(CreateRegisteredModelRequest, req_body)
    except RestException as e:
        if "specify all three levels" in e.message:
            # The exception is likely due to the user trying to create a registered model
            # in Unity Catalog without specifying a 3-level name (catalog.schema.model).
            # The user may not be intending to use the Unity Catalog Model Registry at all,
            # but rather the legacy Workspace Model Registry. Accordingly, we re-raise with
            # a hint
            legacy_hint = (
                "If you are trying to use the legacy Workspace Model Registry, instead of the"
                " recommended Unity Catalog Model Registry, set the Model Registry URI to"
                " 'databricks' (legacy) instead of 'databricks-uc' (recommended)."
            )
        elif "METASTORE_DOES_NOT_EXIST" in e.message:
            legacy_hint = (
                "If you are trying to use the Model Registry in a Databricks workspace that"
                " does not have Unity Catalog enabled, either enable Unity Catalog in the"
                " workspace (recommended) or set the Model Registry URI to 'databricks' to"
                " use the legacy Workspace Model Registry."
            )
        else:
            # Not one of the recognised failures: propagate the REST error unchanged.
            raise
        # Re-raise with the hint appended, keeping the server's error code.
        raise MlflowException(
            message=e.message.rstrip(".") + f". {legacy_hint}",
            error_code=e.error_code,
        )

    if deployment_job_id:
        _print_databricks_deployment_job_url(
            model_name=full_name,
            job_id=str(deployment_job_id),
        )
    return registered_model_from_uc_proto(response_proto.registered_model)
493
+
494
def update_registered_model(self, name, description=None, deployment_job_id=None):
    """
    Update description of the registered model.

    Args:
        name: Registered model name.
        description: New description.
        deployment_job_id: Optional deployment job id.

    Returns:
        A single updated :py:class:`mlflow.entities.model_registry.RegisteredModel` object.
    """
    full_name = get_full_name_from_sc(name, self.spark)
    request = UpdateRegisteredModelRequest(
        name=full_name,
        description=description,
        # A falsy job id is sent as None rather than as its string form.
        deployment_job_id=str(deployment_job_id) if deployment_job_id else None,
    )
    response_proto = self._call_endpoint(UpdateRegisteredModelRequest, message_to_json(request))
    if deployment_job_id:
        _print_databricks_deployment_job_url(
            model_name=full_name,
            job_id=str(deployment_job_id),
        )
    return registered_model_from_uc_proto(response_proto.registered_model)
521
+
522
def rename_registered_model(self, name, new_name):
    """
    Rename the registered model.

    Args:
        name: Registered model name.
        new_name: New proposed name. It is sent to the backend exactly as given.

    Returns:
        A single updated :py:class:`mlflow.entities.model_registry.RegisteredModel` object.
    """
    request = UpdateRegisteredModelRequest(
        name=get_full_name_from_sc(name, self.spark),
        new_name=new_name,
    )
    response_proto = self._call_endpoint(UpdateRegisteredModelRequest, message_to_json(request))
    return registered_model_from_uc_proto(response_proto.registered_model)
537
+
538
def delete_registered_model(self, name):
    """
    Delete the registered model.
    Backend raises exception if a registered model with given name does not exist.

    Args:
        name: Registered model name.

    Returns:
        None
    """
    request = DeleteRegisteredModelRequest(name=get_full_name_from_sc(name, self.spark))
    self._call_endpoint(DeleteRegisteredModelRequest, message_to_json(request))
552
+
553
def search_registered_models(
    self, filter_string=None, max_results=None, order_by=None, page_token=None
):
    """
    Search for registered models in backend that satisfy the filter criteria.

    Args:
        filter_string: Filter query string. Must be left as None; any other value is
            rejected as unsupported.
        max_results: Maximum number of registered models desired.
        order_by: List of column names with ASC|DESC annotation. Must be left as None;
            any other value is rejected as unsupported.
        page_token: Token specifying the next page of results. It should be obtained from
            a ``search_registered_models`` call.

    Returns:
        A PagedList of :py:class:`mlflow.entities.model_registry.RegisteredModel` objects
        that satisfy the search expressions. The pagination token for the next page can be
        obtained via the ``token`` attribute of the object.

    """
    unsupported_args = (("filter_string", filter_string), ("order_by", order_by))
    for arg_name, arg_value in unsupported_args:
        _require_arg_unspecified(arg_name, arg_value)

    request = SearchRegisteredModelsRequest(max_results=max_results, page_token=page_token)
    response_proto = self._call_endpoint(SearchRegisteredModelsRequest, message_to_json(request))
    registered_models = []
    for proto_model in response_proto.registered_models:
        registered_models.append(registered_model_search_from_uc_proto(proto_model))
    return PagedList(registered_models, response_proto.next_page_token)
587
+
588
def get_registered_model(self, name):
    """
    Get registered model instance by name.

    Args:
        name: Registered model name.

    Returns:
        A single :py:class:`mlflow.entities.model_registry.RegisteredModel` object.
    """
    request = GetRegisteredModelRequest(name=get_full_name_from_sc(name, self.spark))
    response_proto = self._call_endpoint(GetRegisteredModelRequest, message_to_json(request))
    return registered_model_from_uc_proto(response_proto.registered_model)
602
+
603
def get_latest_versions(self, name, stages=None):
    """
    Unsupported for models in Unity Catalog: this method always raises.

    The error message points the caller to aliases, and is more detailed when ``stages``
    is provided.

    Args:
        name: Registered model name (not used).
        stages: List of desired stages; only affects the wording of the error.

    Raises:
        MlflowException: Always.
    """
    docs_link = (
        "https://mlflow.org/docs/latest/model-registry.html"
        "#deploy-and-organize-models-with-aliases-and-tags"
    )
    if stages is not None:
        message = (
            f"Detected attempt to load latest model version in stages {stages}. "
            "You may see this error because:\n"
            "1) You're attempting to load a model version by stage. Setting stages "
            "and loading model versions by stage is unsupported in Unity Catalog. Instead, "
            "use aliases for flexible model deployment. See "
            f"{docs_link} for details.\n"
            "2) You're attempting to load a model version by alias. Use "
            "syntax 'models:/your_model_name@your_alias_name'\n"
            "3) You're attempting load a model version by version number. Verify "
            "that the version number is a valid integer"
        )
    else:
        message = (
            "To load the latest version of a model in Unity Catalog, you can "
            "set an alias on the model version and load it by alias. See "
            f"{docs_link} for details."
        )

    _raise_unsupported_method(method="get_latest_versions", message=message)
641
+
642
def set_registered_model_tag(self, name, tag):
    """
    Set a tag for the registered model.

    Args:
        name: Registered model name.
        tag: :py:class:`mlflow.entities.model_registry.RegisteredModelTag` instance to log.

    Returns:
        None
    """
    request = SetRegisteredModelTagRequest(
        name=get_full_name_from_sc(name, self.spark),
        key=tag.key,
        value=tag.value,
    )
    self._call_endpoint(SetRegisteredModelTagRequest, message_to_json(request))
658
+
659
def delete_registered_model_tag(self, name, key):
    """
    Delete a tag associated with the registered model.

    Args:
        name: Registered model name.
        key: Registered model tag key.

    Returns:
        None
    """
    request = DeleteRegisteredModelTagRequest(
        name=get_full_name_from_sc(name, self.spark),
        key=key,
    )
    self._call_endpoint(DeleteRegisteredModelTagRequest, message_to_json(request))
673
+
674
+ # CRUD API for ModelVersion objects
675
def _finalize_model_version(self, name, version):
    """
    Finalize a UC model version after its files have been written to managed storage,
    updating its status from PENDING_REGISTRATION to READY

    Args:
        name: Registered model name
        version: Model version number

    Returns:
        Protobuf ModelVersion describing the finalized model version
    """
    request = FinalizeModelVersionRequest(name=name, version=version)
    response_proto = self._call_endpoint(FinalizeModelVersionRequest, message_to_json(request))
    return response_proto.model_version
689
+
690
def _get_temporary_model_version_write_credentials(self, name, version) -> TemporaryCredentials:
    """
    Get temporary credentials for uploading model version files

    The credentials are requested with the ``MODEL_VERSION_OPERATION_READ_WRITE`` operation.

    Args:
        name: Registered model name.
        version: Model version number.

    Returns:
        mlflow.protos.databricks_uc_registry_messages_pb2.TemporaryCredentials containing
        temporary model version credentials.
    """
    request = GenerateTemporaryModelVersionCredentialsRequest(
        name=name,
        version=version,
        operation=MODEL_VERSION_OPERATION_READ_WRITE,
    )
    response_proto = self._call_endpoint(
        GenerateTemporaryModelVersionCredentialsRequest, message_to_json(request)
    )
    return response_proto.credentials
710
+
711
+ def _get_run_and_headers(self, run_id):
712
+ if run_id is None or not is_databricks_uri(self.tracking_uri):
713
+ return None, None
714
+ host_creds = self.get_tracking_host_creds()
715
+ endpoint, method = _TRACKING_METHOD_TO_INFO[GetRun]
716
+ response = http_request(
717
+ host_creds=host_creds,
718
+ endpoint=endpoint,
719
+ method=method,
720
+ params={"run_id": run_id},
721
+ )
722
+ try:
723
+ verify_rest_response(response, endpoint)
724
+ except MlflowException:
725
+ _logger.warning(
726
+ f"Unable to fetch model version's source run (with ID {run_id}) "
727
+ "from tracking server. The source run may be deleted or inaccessible to the "
728
+ "current user. No run link will be recorded for the model version."
729
+ )
730
+ return None, None
731
+ headers = response.headers
732
+ js_dict = response.json()
733
+ parsed_response = GetRun.Response()
734
+ parse_dict(js_dict=js_dict, message=parsed_response)
735
+ run = Run.from_proto(parsed_response.run)
736
+ return headers, run
737
+
738
def _get_workspace_id(self, headers):
    """
    Read the workspace (org) ID from response headers.

    Args:
        headers: Response headers, or None.

    Returns:
        The value stored under ``_DATABRICKS_ORG_ID_HEADER``. When ``headers`` is None or
        lacks that header, a warning is logged and None is returned.
    """
    if headers is not None and _DATABRICKS_ORG_ID_HEADER in headers:
        return headers[_DATABRICKS_ORG_ID_HEADER]
    _logger.warning(
        "Unable to get model version source run's workspace ID from request headers. "
        "No run link will be recorded for the model version"
    )
    return None
746
+
747
+ def _get_notebook_id(self, run):
748
+ if run is None:
749
+ return None
750
+ return run.data.tags.get(MLFLOW_DATABRICKS_NOTEBOOK_ID, None)
751
+
752
+ def _get_job_id(self, run):
753
+ if run is None:
754
+ return None
755
+ return run.data.tags.get(MLFLOW_DATABRICKS_JOB_ID, None)
756
+
757
+ def _get_job_run_id(self, run):
758
+ if run is None:
759
+ return None
760
+ return run.data.tags.get(MLFLOW_DATABRICKS_JOB_RUN_ID, None)
761
+
762
+ def _get_lineage_input_sources(self, run):
763
+ from mlflow.data.delta_dataset_source import DeltaDatasetSource
764
+
765
+ if run is None:
766
+ return None
767
+ securable_list = []
768
+ if run.inputs is not None:
769
+ for dataset in run.inputs.dataset_inputs:
770
+ dataset_source = mlflow.data.get_source(dataset)
771
+ if (
772
+ isinstance(dataset_source, DeltaDatasetSource)
773
+ and dataset_source._get_source_type() == _DELTA_TABLE
774
+ ):
775
+ # check if dataset is a uc table and then append
776
+ if dataset_source.delta_table_name and dataset_source.delta_table_id:
777
+ table_entity = Table(
778
+ name=dataset_source.delta_table_name,
779
+ table_id=dataset_source.delta_table_id,
780
+ )
781
+ securable_list.append(Securable(table=table_entity))
782
+ if len(securable_list) > _MAX_LINEAGE_DATA_SOURCES:
783
+ _logger.warning(
784
+ f"Model version has {len(securable_list)!s} upstream datasets, which "
785
+ f"exceeds the max of 10 upstream datasets for lineage tracking. Only "
786
+ f"the first 10 datasets will be propagated to Unity Catalog lineage"
787
+ )
788
+ return securable_list[0:_MAX_LINEAGE_DATA_SOURCES]
789
+ else:
790
+ return None
791
+
792
def _validate_model_signature(self, local_model_path):
    """
    Check that the model at ``local_model_path`` carries a usable signature.

    Args:
        local_model_path: Local directory containing the MLflow model.

    Raises:
        MlflowException: If the model has no signature, or if the signature's
            ``outputs`` is None.
    """
    signature = _load_model(local_model_path).signature
    signature_required_explanation = (
        "All models in the Unity Catalog must be logged with a "
        "model signature containing both input and output "
        "type specifications. See "
        "https://mlflow.org/docs/latest/model/signatures.html#how-to-log-models-with-signatures"
        " for details on how to log a model with a signature"
    )
    if signature is None:
        problem = "Model passed for registration did not contain any signature metadata. "
    elif signature.outputs is None:
        problem = (
            "Model passed for registration contained a signature that includes only inputs. "
        )
    else:
        return
    raise MlflowException(f"{problem}{signature_required_explanation}")
814
+
815
def _download_model_weights_if_not_saved(self, local_model_path):
    """
    Transformers models can be saved without the base model weights by setting
    `save_pretrained=False` when saving or logging the model. Such 'weight-less'
    model cannot be directly deployed to model serving, so here we download the
    weights proactively from the HuggingFace hub and save them to the model directory.

    Models without a ``transformers`` flavor are left untouched.

    Raises:
        MlflowException: With ``INTERNAL_ERROR`` if persisting the weights fails.
    """
    flavor_conf = _load_model(local_model_path).flavors.get("transformers")
    if not flavor_conf:
        return

    from mlflow.transformers.flavor_config import FlavorKey
    from mlflow.transformers.model_io import _MODEL_BINARY_FILE_NAME

    weights_already_saved = (
        FlavorKey.MODEL_BINARY in flavor_conf
        and os.path.exists(os.path.join(local_model_path, _MODEL_BINARY_FILE_NAME))
        and FlavorKey.MODEL_REVISION not in flavor_conf
    )
    if weights_already_saved:
        return

    _logger.info(
        "You are attempting to register a transformers model that does not have persisted "
        "model weights. Attempting to fetch the weights so that the model can be registered "
        "within Unity Catalog."
    )
    try:
        mlflow.transformers.persist_pretrained_model(local_model_path)
    except Exception as e:
        raise MlflowException(
            "Failed to download the model weights from the HuggingFace hub and cannot register "
            "the model in the Unity Catalog. Please ensure that the model was saved with the "
            "correct reference to the HuggingFace hub repository and that you have access to "
            "fetch model weights from the defined repository.",
            error_code=INTERNAL_ERROR,
        ) from e
854
+
855
@contextmanager
def _local_model_dir(self, source, local_model_path):
    """
    Context manager yielding a local directory that holds the model files.

    When ``local_model_path`` is given it is yielded as-is and never cleaned up.
    Otherwise the artifacts at ``source`` are downloaded and the download location is
    yielded; on exit that location is removed if it is considered temporary.

    Args:
        source: URI of the model artifacts.
        local_model_path: Optional local path where the model is already available.

    Raises:
        MlflowException: If the artifacts cannot be downloaded from ``source``.
    """
    if local_model_path is not None:
        yield local_model_path
        return

    try:
        downloaded_dir = mlflow.artifacts.download_artifacts(
            artifact_uri=source, tracking_uri=self.tracking_uri
        )
    except Exception as e:
        raise MlflowException(
            f"Unable to download model artifacts from source artifact location "
            f"'{source}' in order to upload them to Unity Catalog. Please ensure "
            f"the source artifact location exists and that you can download from "
            f"it via mlflow.artifacts.download_artifacts()"
        ) from e

    try:
        yield downloaded_dir
    finally:
        # A temporary directory is assumed to have been created when `source` is not a
        # local path (it had to be fetched from remote) and the download location is not
        # a FUSE-mounted path. The FUSE check matters because download_artifacts() can
        # return a FUSE-mounted equivalent of the remote source in some cases, e.g.
        # /dbfs/some/path for source dbfs:/some/path, which must not be deleted.
        keep_dir = os.path.exists(source) or is_fuse_or_uc_volumes_uri(downloaded_dir)
        if not keep_dir:
            shutil.rmtree(downloaded_dir)
883
+
884
def _get_logged_model_from_model_id(self, model_id) -> Optional[LoggedModel]:
    """
    Look up the MLflow LoggedModel for ``model_id``.

    Args:
        model_id: ID of the logged model, or ``None``.

    Returns:
        The result of ``mlflow.get_logged_model(model_id)``, or ``None`` when no
        ``model_id`` was supplied.
    """
    return None if model_id is None else mlflow.get_logged_model(model_id)
889
+
890
def create_model_version(
    self,
    name,
    source,
    run_id=None,
    tags=None,
    run_link=None,
    description=None,
    local_model_path=None,
    model_id: Optional[str] = None,
):
    """
    Create a new model version from given source and run ID.

    The model files are validated locally (signature check, transformers weights),
    the version is created through the registry endpoint, the files are uploaded via
    the artifact repository for that version, and the version is then finalized.

    Args:
        name: Registered model name.
        source: URI indicating the location of the model artifacts.
        run_id: Run ID from MLflow tracking server that generated the model. Replaced by
            the logged model's ``source_run_id`` when ``model_id`` resolves to a model.
        tags: A list of :py:class:`mlflow.entities.model_registry.ModelVersionTag`
            instances associated with this model version.
        run_link: Link to the run from an MLflow tracking server that generated this model.
            Checked with ``_require_arg_unspecified``.
        description: Description of the version.
        local_model_path: Local path to the MLflow model, if it's already accessible on the
            local filesystem. Can be used by AbstractStores that upload model version files
            to the model registry to avoid a redundant download from the source location when
            logging and registering a model via a single
            mlflow.<flavor>.log_model(..., registered_model_name) call.
        model_id: The ID of the model (from an Experiment) that is being promoted to a
            registered model version, if applicable.

    Returns:
        A single object of :py:class:`mlflow.entities.model_registry.ModelVersion`
        created in the backend.
    """
    _require_arg_unspecified(arg_name="run_link", arg_value=run_link)

    source_model = self._get_logged_model_from_model_id(model_id)
    if source_model:
        run_id = source_model.source_run_id

    run_headers, run = self._get_run_and_headers(run_id)
    workspace_id = self._get_workspace_id(run_headers)
    notebook_id = self._get_notebook_id(run)
    securables = self._get_lineage_input_sources(run)
    job_id = self._get_job_id(run)
    job_run_id = self._get_job_run_id(run)

    # The lineage header is only attached when the run originates from a notebook or job.
    extra_headers = None
    if not (notebook_id is None and job_id is None):
        entities = []
        if notebook_id is not None:
            entities.append(Entity(notebook=Notebook(id=str(notebook_id))))
        if job_id is not None:
            entities.append(Entity(job=Job(id=job_id, job_run_id=job_run_id)))
        lineages = (
            [Lineage(source_securables=securables)] if securables is not None else None
        )
        header_info = LineageHeaderInfo(entities=entities, lineages=lineages)
        # Base64-encode the header value to ensure it's valid ASCII,
        # similar to JWT (see https://stackoverflow.com/a/40347926)
        encoded_header = base64.b64encode(message_to_json(header_info).encode())
        extra_headers = {_DATABRICKS_LINEAGE_ID_HEADER: encoded_header}

    full_name = get_full_name_from_sc(name, self.spark)
    with self._local_model_dir(source, local_model_path) as model_dir:
        self._validate_model_signature(model_dir)
        self._download_model_weights_if_not_saved(model_dir)
        feature_deps = get_feature_dependencies(model_dir)
        model_deps = get_model_version_dependencies(model_dir)

        create_request = CreateModelVersionRequest(
            name=full_name,
            source=source,
            run_id=run_id,
            description=description,
            tags=uc_model_version_tag_from_mlflow_tags(tags),
            run_tracking_server_id=workspace_id,
            feature_deps=feature_deps,
            model_version_dependencies=model_deps,
            model_id=model_id,
        )
        created = self._call_endpoint(
            CreateModelVersionRequest,
            message_to_json(create_request),
            extra_headers=extra_headers,
        ).model_version

        # Upload the local model files, then finalize the newly created version.
        artifact_repo = self._get_artifact_repo(created, full_name)
        artifact_repo.log_artifacts(local_dir=model_dir, artifact_path="")
        finalized = self._finalize_model_version(name=full_name, version=created.version)
        return model_version_from_uc_proto(finalized)
981
+
982
def _get_artifact_repo(self, model_version, model_name=None):
    """
    Select the artifact repository used to write files for ``model_version``.

    Args:
        model_version: Model version proto; its ``name``, ``version`` and
            ``storage_location`` are read.
        model_name: Model name handed to the Databricks SDK repository when that
            repository is enabled.

    Returns:
        A ``DatabricksSDKModelsArtifactRepository`` when enabled for the host
        credentials; otherwise a ``PresignedUrlArtifactRepository`` when the temporary
        credentials report default storage; otherwise the repository built from the
        version's storage location and the temporary credentials.
    """

    def _refresh_write_credentials():
        # Also passed on as the refresh callback for the storage-backed repository.
        return self._get_temporary_model_version_write_credentials(
            name=model_version.name, version=model_version.version
        )

    if is_databricks_sdk_models_artifact_repository_enabled(self.get_host_creds()):
        return DatabricksSDKModelsArtifactRepository(model_name, model_version.version)

    temp_credentials = _refresh_write_credentials()
    if temp_credentials.storage_mode != StorageMode.DEFAULT_STORAGE:
        return get_artifact_repo_from_storage_info(
            storage_location=model_version.storage_location,
            scoped_token=temp_credentials,
            base_credential_refresh_def=_refresh_write_credentials,
        )

    return PresignedUrlArtifactRepository(
        self.get_host_creds(), model_version.name, model_version.version
    )
1002
+
1003
def transition_model_version_stage(self, name, version, stage, archive_existing_versions):
    """
    Update model version stage.

    This store does not act on any of the arguments; it delegates to
    ``_raise_unsupported_method`` with a message recommending aliases instead of stages.

    Args:
        name: Registered model name.
        version: Registered model version.
        stage: New desired stage for this model version.
        archive_existing_versions: If this flag is set to ``True``, all existing model
            versions in the stage will be automatically moved to the "archived" stage. Only
            valid when ``stage`` is ``"staging"`` or ``"production"`` otherwise an error will be
            raised.
    """
    alias_guidance = (
        "We recommend using aliases instead of stages for more flexible model "
        "deployment management. You can set an alias on a registered model using "
        "`MlflowClient().set_registered_model_alias(name, alias, version)` and load a model "
        "version by alias using the URI 'models:/your_model_name@your_alias', e.g. "
        "`mlflow.pyfunc.load_model('models:/your_model_name@your_alias')`."
    )
    _raise_unsupported_method(
        method="transition_model_version_stage",
        message=alias_guidance,
    )
1024
+
1025
def update_model_version(self, name, version, description):
    """
    Update metadata associated with a model version in backend.

    Args:
        name: Registered model name.
        version: Registered model version; sent as a string.
        description: New model description.

    Returns:
        A single :py:class:`mlflow.entities.model_registry.ModelVersion` object.
    """
    update_request = UpdateModelVersionRequest(
        name=get_full_name_from_sc(name, self.spark),
        version=str(version),
        description=description,
    )
    response = self._call_endpoint(UpdateModelVersionRequest, message_to_json(update_request))
    return model_version_from_uc_proto(response.model_version)
1044
+
1045
def delete_model_version(self, name, version):
    """
    Delete model version in backend.

    Args:
        name: Registered model name.
        version: Registered model version; sent as a string.

    Returns:
        None
    """
    delete_request = DeleteModelVersionRequest(
        name=get_full_name_from_sc(name, self.spark), version=str(version)
    )
    self._call_endpoint(DeleteModelVersionRequest, message_to_json(delete_request))
1059
+
1060
def get_model_version(self, name, version):
    """
    Get the model version instance by name and version.

    Args:
        name: Registered model name.
        version: Registered model version; sent as a string.

    Returns:
        A single :py:class:`mlflow.entities.model_registry.ModelVersion` object.
    """
    get_request = GetModelVersionRequest(
        name=get_full_name_from_sc(name, self.spark), version=str(version)
    )
    response = self._call_endpoint(GetModelVersionRequest, message_to_json(get_request))
    return model_version_from_uc_proto(response.model_version)
1075
+
1076
def get_model_version_download_uri(self, name, version):
    """
    Get the download location in Model Registry for this model version.

    Args:
        name: Registered model name.
        version: Registered model version; sent as a string.

    Returns:
        The ``artifact_uri`` reported by the backend for this model version.
    """
    uri_request = GetModelVersionDownloadUriRequest(
        name=get_full_name_from_sc(name, self.spark), version=str(version)
    )
    response = self._call_endpoint(
        GetModelVersionDownloadUriRequest, message_to_json(uri_request)
    )
    return response.artifact_uri
1095
+
1096
def search_model_versions(
    self, filter_string=None, max_results=None, order_by=None, page_token=None
):
    """
    Search for model versions in backend that satisfy the filter criteria.

    Args:
        filter_string: A filter string expression. Currently supports a single filter
            condition either name of model like ``name = 'model_name'`` or
            ``run_id = '...'``.
        max_results: Maximum number of model versions desired.
        order_by: List of column names with ASC|DESC annotation. Checked with
            ``_require_arg_unspecified`` and not forwarded to the backend.
        page_token: Token specifying the next page of results. It should be obtained from
            a ``search_model_versions`` call.

    Returns:
        A PagedList of :py:class:`mlflow.entities.model_registry.ModelVersion`
        objects that satisfy the search expressions. The pagination token for the next
        page can be obtained via the ``token`` attribute of the object.
    """
    _require_arg_unspecified(arg_name="order_by", arg_value=order_by)
    search_request = SearchModelVersionsRequest(
        filter=filter_string, page_token=page_token, max_results=max_results
    )
    response = self._call_endpoint(SearchModelVersionsRequest, message_to_json(search_request))
    # Convert each returned proto into its MLflow entity, preserving backend order.
    found_versions = list(map(model_version_search_from_uc_proto, response.model_versions))
    return PagedList(found_versions, response.next_page_token)
1129
+
1130
def set_model_version_tag(self, name, version, tag):
    """
    Set a tag for the model version.

    Args:
        name: Registered model name.
        version: Registered model version; sent as a string.
        tag: :py:class:`mlflow.entities.model_registry.ModelVersionTag` instance to log;
            its ``key`` and ``value`` are sent.
    """
    tag_request = SetModelVersionTagRequest(
        name=get_full_name_from_sc(name, self.spark),
        version=str(version),
        key=tag.key,
        value=tag.value,
    )
    self._call_endpoint(SetModelVersionTagRequest, message_to_json(tag_request))
1146
+
1147
def delete_model_version_tag(self, name, version, key):
    """
    Delete a tag associated with the model version.

    Args:
        name: Registered model name.
        version: Registered model version. Converted with ``str()`` before being sent,
            so both integer and string versions are accepted.
        key: Tag key.
    """
    full_name = get_full_name_from_sc(name, self.spark)
    # Stringify the version exactly as the sibling model-version methods
    # (set_model_version_tag, get_model_version, delete_model_version, ...) do, so that
    # passing an int behaves the same here as everywhere else in this store.
    req_body = message_to_json(
        DeleteModelVersionTagRequest(name=full_name, version=str(version), key=key)
    )
    self._call_endpoint(DeleteModelVersionTagRequest, req_body)
1161
+
1162
def set_registered_model_alias(self, name, alias, version):
    """
    Set a registered model alias pointing to a model version.

    Args:
        name: Registered model name.
        alias: Name of the alias.
        version: Registered model version number; sent as a string.

    Returns:
        None
    """
    alias_request = SetRegisteredModelAliasRequest(
        name=get_full_name_from_sc(name, self.spark), alias=alias, version=str(version)
    )
    self._call_endpoint(SetRegisteredModelAliasRequest, message_to_json(alias_request))
1179
+
1180
def delete_registered_model_alias(self, name, alias):
    """
    Delete an alias associated with a registered model.

    Args:
        name: Registered model name.
        alias: Name of the alias.

    Returns:
        None
    """
    alias_request = DeleteRegisteredModelAliasRequest(
        name=get_full_name_from_sc(name, self.spark), alias=alias
    )
    self._call_endpoint(DeleteRegisteredModelAliasRequest, message_to_json(alias_request))
1194
+
1195
def get_model_version_by_alias(self, name, alias):
    """
    Get the model version instance by name and alias.

    Args:
        name: Registered model name.
        alias: Name of the alias.

    Returns:
        A single :py:class:`mlflow.entities.model_registry.ModelVersion` object.
    """
    alias_request = GetModelVersionByAliasRequest(
        name=get_full_name_from_sc(name, self.spark), alias=alias
    )
    response = self._call_endpoint(
        GetModelVersionByAliasRequest, message_to_json(alias_request)
    )
    return model_version_from_uc_proto(response.model_version)
1210
+
1211
def _await_model_version_creation(self, mv, await_creation_for):
    """
    Intentionally a no-op: the method body contains nothing but this docstring.

    Does not wait for the model version to become READY as a successful creation will
    immediately place the model version in a READY state.

    Args:
        mv: Model version argument; not used by this implementation.
        await_creation_for: Not used by this implementation.

    Returns:
        None
    """
1216
+
1217
+ # Prompt-related method overrides for UC
1218
+
1219
def create_prompt(
    self,
    name: str,
    description: Optional[str] = None,
    tags: Optional[dict[str, str]] = None,
) -> Prompt:
    """
    Create a new prompt in Unity Catalog (metadata only, no initial version).

    Args:
        name: Name of the prompt.
        description: Optional description; only set on the request when truthy.
        tags: Optional tag mapping; only set on the request when non-empty.

    Returns:
        The prompt built from the backend response together with the supplied tags
        (an empty dict when none were given).
    """
    # Populate the prompt proto field by field; optional fields stay unset when empty.
    proto = ProtoPrompt()
    proto.name = name
    if description:
        proto.description = description
    if tags:
        proto.tags.extend(mlflow_tags_to_proto(tags))

    create_request = CreatePromptRequest(
        name=name,
        prompt=proto,
    )
    response = self._call_endpoint(CreatePromptRequest, message_to_json(create_request))
    return proto_info_to_mlflow_prompt_info(response, tags or {})
1244
+
1245
def search_prompts(
    self,
    filter_string: Optional[str] = None,
    max_results: Optional[int] = None,
    order_by: Optional[list[str]] = None,
    page_token: Optional[str] = None,
) -> PagedList[Prompt]:
    """
    Search for prompts in Unity Catalog.

    Args:
        filter_string: Filter string that must include catalog and schema in the format:
            "catalog = 'catalog_name' AND schema = 'schema_name'"
        max_results: Maximum number of results to return
        order_by: List of fields to order by (not used in current implementation)
        page_token: Token for pagination

    Returns:
        A PagedList of prompts plus the backend's next page token.

    Raises:
        MlflowException: With ``INVALID_PARAMETER_VALUE`` when ``filter_string`` is empty.
    """
    # Catalog and schema are mandatory, so an empty filter is rejected up front.
    if not filter_string:
        raise MlflowException(
            "For Unity Catalog prompt registries, you must specify catalog and schema "
            "in the filter string: \"catalog = 'catalog_name' AND schema = 'schema_name'\"",
            INVALID_PARAMETER_VALUE,
        )
    scope = self._parse_catalog_schema_from_filter(filter_string)

    # Catalog/schema travel in a dedicated field; whatever is left stays in `filter`.
    catalog_schema = UnityCatalogSchema(
        catalog_name=scope.catalog_name, schema_name=scope.schema_name
    )
    search_request = SearchPromptsRequest(
        catalog_schema=catalog_schema,
        filter=scope.remaining_filter,
        max_results=max_results,
        page_token=page_token,
    )
    response = self._call_endpoint(SearchPromptsRequest, message_to_json(search_request))

    # For UC, only use the basic prompt info without extra tag fetching
    found_prompts = [proto_info_to_mlflow_prompt_info(info, {}) for info in response.prompts]
    return PagedList(found_prompts, response.next_page_token)
1292
+
1293
def _parse_catalog_schema_from_filter(
    self, filter_string: Optional[str]
) -> _CatalogSchemaFilter:
    """
    Parse catalog and schema from filter string for Unity Catalog using regex.

    Expects filter format: "catalog = 'catalog_name' AND schema = 'schema_name'"

    Args:
        filter_string: Filter string containing catalog and schema

    Returns:
        _CatalogSchemaFilter object with catalog_name, schema_name, and remaining_filter
        (``None`` when no other clauses are left).

    Raises:
        MlflowException: If filter format is invalid for Unity Catalog
    """
    if not filter_string:
        raise MlflowException(
            "For Unity Catalog prompt registries, you must specify catalog and schema "
            "in the filter string: \"catalog = 'catalog_name' AND schema = 'schema_name'\"",
            INVALID_PARAMETER_VALUE,
        )

    # Module-level compiled patterns locate the two mandatory clauses.
    catalog_hit = _CATALOG_PATTERN.search(filter_string)
    schema_hit = _SCHEMA_PATTERN.search(filter_string)
    if not (catalog_hit and schema_hit):
        raise MlflowException(
            "For Unity Catalog prompt registries, filter string must include both "
            "catalog and schema in the format: "
            "\"catalog = 'catalog_name' AND schema = 'schema_name'\". "
            f"Got: {filter_string}",
            INVALID_PARAMETER_VALUE,
        )

    # Split on AND (case-insensitive) and keep every clause that is neither the
    # catalog nor the schema clause; those form the remaining filter.
    clauses = (piece.strip() for piece in re.split(r"\s+AND\s+", filter_string, flags=re.IGNORECASE))
    other_clauses = [
        clause
        for clause in clauses
        if not _CATALOG_PATTERN.match(clause) and not _SCHEMA_PATTERN.match(clause)
    ]

    if other_clauses:
        remaining_filter = " AND ".join(other_clauses)
    else:
        remaining_filter = None

    return _CatalogSchemaFilter(catalog_hit.group(1), schema_hit.group(1), remaining_filter)
1349
+
1350
def delete_prompt(self, name: str) -> None:
    """
    Delete a prompt from Unity Catalog.

    Args:
        name: Name of the prompt to delete.
    """
    payload = message_to_json(DeletePromptRequest(name=name))
    route, http_method = self._get_endpoint_from_method(DeletePromptRequest)
    self._edit_endpoint_and_call(
        endpoint=route,
        method=http_method,
        req_body=payload,
        name=name,
        proto_name=DeletePromptRequest,
    )
1363
+
1364
def set_prompt_tag(self, name: str, key: str, value: str) -> None:
    """
    Set a tag on a prompt in Unity Catalog.

    Args:
        name: Name of the prompt.
        key: Tag key.
        value: Tag value.
    """
    payload = message_to_json(SetPromptTagRequest(name=name, key=key, value=value))
    route, http_method = self._get_endpoint_from_method(SetPromptTagRequest)
    self._edit_endpoint_and_call(
        endpoint=route,
        method=http_method,
        req_body=payload,
        name=name,
        key=key,
        proto_name=SetPromptTagRequest,
    )
1378
+
1379
def delete_prompt_tag(self, name: str, key: str) -> None:
    """
    Delete a tag from a prompt in Unity Catalog.

    Args:
        name: Name of the prompt.
        key: Key of the tag to delete.
    """
    payload = message_to_json(DeletePromptTagRequest(name=name, key=key))
    route, http_method = self._get_endpoint_from_method(DeletePromptTagRequest)
    self._edit_endpoint_and_call(
        endpoint=route,
        method=http_method,
        req_body=payload,
        name=name,
        key=key,
        proto_name=DeletePromptTagRequest,
    )
1393
+
1394
def get_prompt(self, name: str) -> Optional[Prompt]:
    """
    Get prompt by name from Unity Catalog.

    Args:
        name: Name of the prompt.

    Returns:
        The prompt, or ``None`` when the backend reports ``RESOURCE_DOES_NOT_EXIST``.
        Any other error propagates to the caller.
    """
    try:
        payload = message_to_json(GetPromptRequest(name=name))
        route, http_method = self._get_endpoint_from_method(GetPromptRequest)
        response = self._edit_endpoint_and_call(
            endpoint=route,
            method=http_method,
            req_body=payload,
            name=name,
            proto_name=GetPromptRequest,
        )
        return proto_info_to_mlflow_prompt_info(response, {})
    except MlflowException as e:
        # A missing prompt is reported as None; every other failure is re-raised as-is.
        if e.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST):
            return None
        raise
1415
+
1416
def create_prompt_version(
    self,
    name: str,
    template: str,
    description: Optional[str] = None,
    tags: Optional[dict[str, str]] = None,
) -> PromptVersion:
    """
    Create a new prompt version in Unity Catalog.

    Args:
        name: Name of the prompt.
        template: Prompt template; sent JSON-encoded.
        description: Optional description; only set on the request when truthy.
        tags: Optional tag mapping; only set on the request when non-empty.

    Returns:
        The prompt version built from the backend response.
    """
    version_proto = ProtoPromptVersion()
    version_proto.name = name
    # JSON-encode the template for Unity Catalog server
    version_proto.template = json.dumps(template)

    # The version number is deliberately left unset: it is generated server-side
    # when the new version is created.
    if description:
        version_proto.description = description
    if tags:
        version_proto.tags.extend(mlflow_tags_to_proto_version_tags(tags))

    create_request = CreatePromptVersionRequest(
        name=name,
        prompt_version=version_proto,
    )
    payload = message_to_json(create_request)
    route, http_method = self._get_endpoint_from_method(CreatePromptVersionRequest)
    response = self._edit_endpoint_and_call(
        endpoint=route,
        method=http_method,
        req_body=payload,
        name=name,
        proto_name=CreatePromptVersionRequest,
    )
    return proto_to_mlflow_prompt(response)
1454
+
1455
def get_prompt_version(self, name: str, version: Union[str, int]) -> Optional[PromptVersion]:
    """
    Get a specific prompt version from Unity Catalog.

    Args:
        name: Name of the prompt.
        version: Version of the prompt; sent as a string in the request body.

    Returns:
        The prompt version, or ``None`` when the backend reports
        ``RESOURCE_DOES_NOT_EXIST``. Any other error propagates to the caller.
    """
    try:
        payload = message_to_json(GetPromptVersionRequest(name=name, version=str(version)))
        route, http_method = self._get_endpoint_from_method(GetPromptVersionRequest)
        response = self._edit_endpoint_and_call(
            endpoint=route,
            method=http_method,
            req_body=payload,
            name=name,
            version=version,
            proto_name=GetPromptVersionRequest,
        )
        # Prompt-level tags are not fetched here; they are kept completely separate.
        return proto_to_mlflow_prompt(response)
    except MlflowException as e:
        # A missing version is reported as None; every other failure is re-raised as-is.
        if e.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST):
            return None
        raise
1479
+
1480
def delete_prompt_version(self, name: str, version: Union[str, int]) -> None:
    """
    Delete a prompt version from Unity Catalog.

    Args:
        name: Name of the prompt.
        version: Version to delete; sent as a string in the request body.
    """
    # Only the requested version is targeted by this request.
    payload = message_to_json(DeletePromptVersionRequest(name=name, version=str(version)))
    route, http_method = self._get_endpoint_from_method(DeletePromptVersionRequest)
    self._edit_endpoint_and_call(
        endpoint=route,
        method=http_method,
        req_body=payload,
        name=name,
        version=version,
        proto_name=DeletePromptVersionRequest,
    )
1495
+
1496
def search_prompt_versions(
    self, name: str, max_results: Optional[int] = None, page_token: Optional[str] = None
) -> SearchPromptVersionsResponse:
    """
    Search prompt versions for a given prompt name in Unity Catalog.

    Note: Unity Catalog server uses a non-standard endpoint pattern for this operation.

    Args:
        name: Name of the prompt to search versions for
        max_results: Maximum number of versions to return
        page_token: Token for pagination

    Returns:
        SearchPromptVersionsResponse containing the list of versions
    """
    search_request = SearchPromptVersionsRequest(
        name=name, max_results=max_results, page_token=page_token
    )
    payload = message_to_json(search_request)
    route, http_method = self._get_endpoint_from_method(SearchPromptVersionsRequest)
    # The backend response is handed back unchanged.
    return self._edit_endpoint_and_call(
        endpoint=route,
        method=http_method,
        req_body=payload,
        name=name,
        proto_name=SearchPromptVersionsRequest,
    )
1523
+
1524
def set_prompt_version_tag(
    self, name: str, version: Union[str, int], key: str, value: str
) -> None:
    """
    Set a tag on a prompt version in Unity Catalog.

    Args:
        name: Name of the prompt.
        version: Version of the prompt; sent as a string in the request body.
        key: Tag key.
        value: Tag value.
    """
    tag_request = SetPromptVersionTagRequest(
        name=name, version=str(version), key=key, value=value
    )
    payload = message_to_json(tag_request)
    route, http_method = self._get_endpoint_from_method(SetPromptVersionTagRequest)
    self._edit_endpoint_and_call(
        endpoint=route,
        method=http_method,
        req_body=payload,
        name=name,
        version=version,
        key=key,
        proto_name=SetPromptVersionTagRequest,
    )
1543
+
1544
def delete_prompt_version_tag(self, name: str, version: Union[str, int], key: str) -> None:
    """
    Delete a tag from a prompt version in Unity Catalog.

    Args:
        name: Name of the prompt.
        version: Version of the prompt; sent as a string in the request body.
        key: Key of the tag to delete.
    """
    tag_request = DeletePromptVersionTagRequest(name=name, version=str(version), key=key)
    payload = message_to_json(tag_request)
    route, http_method = self._get_endpoint_from_method(DeletePromptVersionTagRequest)
    self._edit_endpoint_and_call(
        endpoint=route,
        method=http_method,
        req_body=payload,
        name=name,
        version=version,
        key=key,
        proto_name=DeletePromptVersionTagRequest,
    )
1561
+
1562
def get_prompt_version_by_alias(self, name: str, alias: str) -> Optional[PromptVersion]:
    """
    Get a prompt version by alias from Unity Catalog.

    Args:
        name: Name of the prompt.
        alias: Alias identifying the version.

    Returns:
        The prompt version, or ``None`` when the backend reports
        ``RESOURCE_DOES_NOT_EXIST``. Any other error propagates to the caller.
    """
    try:
        payload = message_to_json(GetPromptVersionByAliasRequest(name=name, alias=alias))
        route, http_method = self._get_endpoint_from_method(GetPromptVersionByAliasRequest)
        response = self._edit_endpoint_and_call(
            endpoint=route,
            method=http_method,
            req_body=payload,
            name=name,
            alias=alias,
            proto_name=GetPromptVersionByAliasRequest,
        )
        # Prompt-level tags are not fetched here; they are kept completely separate.
        return proto_to_mlflow_prompt(response)
    except MlflowException as e:
        # A missing alias/version is reported as None; other failures are re-raised as-is.
        if e.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST):
            return None
        raise
1586
+
1587
def set_prompt_alias(self, name: str, alias: str, version: Union[str, int]) -> None:
    """
    Set an alias for a prompt version in Unity Catalog.

    Args:
        name: Name of the prompt.
        alias: Alias to set.
        version: Version the alias should point to; sent as a string in the request body.
    """
    alias_request = SetPromptAliasRequest(name=name, alias=alias, version=str(version))
    payload = message_to_json(alias_request)
    route, http_method = self._get_endpoint_from_method(SetPromptAliasRequest)
    self._edit_endpoint_and_call(
        endpoint=route,
        method=http_method,
        req_body=payload,
        name=name,
        alias=alias,
        version=version,
        proto_name=SetPromptAliasRequest,
    )
1604
+
1605
def delete_prompt_alias(self, name: str, alias: str) -> None:
    """
    Delete an alias from a prompt in Unity Catalog.

    Args:
        name: Name of the prompt.
        alias: Alias to delete.
    """
    payload = message_to_json(DeletePromptAliasRequest(name=name, alias=alias))
    route, http_method = self._get_endpoint_from_method(DeletePromptAliasRequest)
    self._edit_endpoint_and_call(
        endpoint=route,
        method=http_method,
        req_body=payload,
        name=name,
        alias=alias,
        proto_name=DeletePromptAliasRequest,
    )
1619
+
1620
def link_prompt_version_to_model(self, name: str, version: str, model_id: str) -> None:
    """
    Link a prompt version to a model in Unity Catalog.

    The parent implementation is invoked first; afterwards the link is sent to the
    backend on a best-effort basis (failures are logged at debug level, not raised).

    Args:
        name: Name of the prompt.
        version: Version of the prompt to link.
        model_id: ID of the model to link to.
    """
    # Call the default implementation, since the LinkPromptVersionsToModels API
    # will initially be a no-op until the Databricks backend supports it
    super().link_prompt_version_to_model(name=name, version=version, model_id=model_id)

    link_request = LinkPromptVersionsToModelsRequest(
        prompt_versions=[PromptVersionLinkEntry(name=name, version=version)],
        model_ids=[model_id],
    )
    payload = message_to_json(link_request)
    route, http_method = self._get_endpoint_from_method(LinkPromptVersionsToModelsRequest)
    try:
        # NB: This will not raise an exception if the backend does not support linking.
        # We do this to prioritize reduction in errors and log spam while the prompt
        # registry remains experimental
        self._edit_endpoint_and_call(
            endpoint=route,
            method=http_method,
            req_body=payload,
            name=name,
            version=version,
            model_id=model_id,
            proto_name=LinkPromptVersionsToModelsRequest,
        )
    except Exception:
        _logger.debug("Failed to link prompt version to model in unity catalog", exc_info=True)
1655
+
1656
def link_prompts_to_trace(self, prompt_versions: list[PromptVersion], trace_id: str) -> None:
    """
    Link several prompt versions to a single trace in Unity Catalog.

    The parent-class implementation runs first. The link entries are then sent in
    batches of at most 25 per request; a failing request is logged at debug level
    and the remaining batches are still attempted.

    Args:
        prompt_versions: List of PromptVersion objects to link.
        trace_id: Trace ID to link to each prompt version.
    """
    super().link_prompts_to_trace(prompt_versions=prompt_versions, trace_id=trace_id)

    entries = []
    for prompt_version in prompt_versions:
        entries.append(
            PromptVersionLinkEntry(
                name=prompt_version.name, version=str(prompt_version.version)
            )
        )

    max_entries_per_request = 25
    endpoint, method = self._get_endpoint_from_method(LinkPromptsToTracesRequest)

    # Split the entries into consecutive slices of at most max_entries_per_request items.
    chunks = [
        entries[start : start + max_entries_per_request]
        for start in range(0, len(entries), max_entries_per_request)
    ]

    for chunk in chunks:
        request = LinkPromptsToTracesRequest(prompt_versions=chunk, trace_ids=[trace_id])
        req_body = message_to_json(request)
        try:
            self._edit_endpoint_and_call(
                endpoint=endpoint,
                method=method,
                req_body=req_body,
                proto_name=LinkPromptsToTracesRequest,
            )
        except Exception:
            _logger.debug("Failed to link prompts to traces in unity catalog", exc_info=True)
1687
+
1688
def link_prompt_version_to_run(self, name: str, version: str, run_id: str) -> None:
    """
    Link a prompt version to a run in Unity Catalog.

    The parent-class implementation is invoked first. The Unity Catalog request is
    then attempted on a best-effort basis: any exception raised by the call is
    logged at debug level and not propagated.

    Args:
        name: Name of the prompt.
        version: Version of the prompt to link. Converted with ``str()`` before being
            placed in the request message.
        run_id: ID of the run to link to.
    """
    super().link_prompt_version_to_run(name=name, version=version, run_id=run_id)

    # Coerce the version to a string, as the other prompt-version methods of this
    # store do, so a non-string version does not raise while building the entry
    # (which happens outside the best-effort try block below).
    prompt_version_entry = PromptVersionLinkEntry(name=name, version=str(version))
    endpoint, method = self._get_endpoint_from_method(LinkPromptVersionsToRunsRequest)

    req_body = message_to_json(
        LinkPromptVersionsToRunsRequest(
            prompt_versions=[prompt_version_entry], run_ids=[run_id]
        )
    )
    try:
        self._edit_endpoint_and_call(
            endpoint=endpoint,
            method=method,
            req_body=req_body,
            proto_name=LinkPromptVersionsToRunsRequest,
        )
    except Exception:
        _logger.debug("Failed to link prompt version to run in unity catalog", exc_info=True)
1716
+
1717
def _edit_endpoint_and_call(self, endpoint, method, req_body, proto_name, **kwargs):
    """
    Fill in the endpoint template and issue the REST call.

    Args:
        endpoint: URL template that may contain placeholders such as {name} or {key}.
        method: HTTP method.
        req_body: Request body, forwarded as the JSON body of the call.
        proto_name: Protobuf message class passed to ``_get_response_from_method``
            to obtain the response proto for the call.
        **kwargs: Values for the placeholders in ``endpoint``, keyed by placeholder
            name. Entries whose value is None are skipped, leaving the matching
            placeholder untouched.

    Returns:
        Whatever ``call_endpoint`` returns for the request.
    """
    resolved_endpoint = endpoint
    for param, param_value in kwargs.items():
        if param_value is None:
            continue
        placeholder = "{" + param + "}"
        resolved_endpoint = resolved_endpoint.replace(placeholder, str(param_value))

    # Credentials are fetched before the response proto is looked up, matching the
    # order in which the original call evaluated its arguments.
    host_creds = self.get_host_creds()
    response_proto = self._get_response_from_method(proto_name)
    return call_endpoint(
        host_creds,
        endpoint=resolved_endpoint,
        method=method,
        json_body=req_body,
        response_proto=response_proto,
    )