genesis-flow 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (645)
  1. genesis_flow-1.0.0.dist-info/METADATA +822 -0
  2. genesis_flow-1.0.0.dist-info/RECORD +645 -0
  3. genesis_flow-1.0.0.dist-info/WHEEL +5 -0
  4. genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
  5. genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
  6. genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
  7. mlflow/__init__.py +367 -0
  8. mlflow/__main__.py +3 -0
  9. mlflow/ag2/__init__.py +56 -0
  10. mlflow/ag2/ag2_logger.py +294 -0
  11. mlflow/anthropic/__init__.py +40 -0
  12. mlflow/anthropic/autolog.py +129 -0
  13. mlflow/anthropic/chat.py +144 -0
  14. mlflow/artifacts/__init__.py +268 -0
  15. mlflow/autogen/__init__.py +144 -0
  16. mlflow/autogen/chat.py +142 -0
  17. mlflow/azure/__init__.py +26 -0
  18. mlflow/azure/auth_handler.py +257 -0
  19. mlflow/azure/client.py +319 -0
  20. mlflow/azure/config.py +120 -0
  21. mlflow/azure/connection_factory.py +340 -0
  22. mlflow/azure/exceptions.py +27 -0
  23. mlflow/azure/stores.py +327 -0
  24. mlflow/azure/utils.py +183 -0
  25. mlflow/bedrock/__init__.py +45 -0
  26. mlflow/bedrock/_autolog.py +202 -0
  27. mlflow/bedrock/chat.py +122 -0
  28. mlflow/bedrock/stream.py +160 -0
  29. mlflow/bedrock/utils.py +43 -0
  30. mlflow/cli.py +707 -0
  31. mlflow/client.py +12 -0
  32. mlflow/config/__init__.py +56 -0
  33. mlflow/crewai/__init__.py +79 -0
  34. mlflow/crewai/autolog.py +253 -0
  35. mlflow/crewai/chat.py +29 -0
  36. mlflow/data/__init__.py +75 -0
  37. mlflow/data/artifact_dataset_sources.py +170 -0
  38. mlflow/data/code_dataset_source.py +40 -0
  39. mlflow/data/dataset.py +123 -0
  40. mlflow/data/dataset_registry.py +168 -0
  41. mlflow/data/dataset_source.py +110 -0
  42. mlflow/data/dataset_source_registry.py +219 -0
  43. mlflow/data/delta_dataset_source.py +167 -0
  44. mlflow/data/digest_utils.py +108 -0
  45. mlflow/data/evaluation_dataset.py +562 -0
  46. mlflow/data/filesystem_dataset_source.py +81 -0
  47. mlflow/data/http_dataset_source.py +145 -0
  48. mlflow/data/huggingface_dataset.py +258 -0
  49. mlflow/data/huggingface_dataset_source.py +118 -0
  50. mlflow/data/meta_dataset.py +104 -0
  51. mlflow/data/numpy_dataset.py +223 -0
  52. mlflow/data/pandas_dataset.py +231 -0
  53. mlflow/data/polars_dataset.py +352 -0
  54. mlflow/data/pyfunc_dataset_mixin.py +31 -0
  55. mlflow/data/schema.py +76 -0
  56. mlflow/data/sources.py +1 -0
  57. mlflow/data/spark_dataset.py +406 -0
  58. mlflow/data/spark_dataset_source.py +74 -0
  59. mlflow/data/spark_delta_utils.py +118 -0
  60. mlflow/data/tensorflow_dataset.py +350 -0
  61. mlflow/data/uc_volume_dataset_source.py +81 -0
  62. mlflow/db.py +27 -0
  63. mlflow/dspy/__init__.py +17 -0
  64. mlflow/dspy/autolog.py +197 -0
  65. mlflow/dspy/callback.py +398 -0
  66. mlflow/dspy/constant.py +1 -0
  67. mlflow/dspy/load.py +93 -0
  68. mlflow/dspy/save.py +393 -0
  69. mlflow/dspy/util.py +109 -0
  70. mlflow/dspy/wrapper.py +226 -0
  71. mlflow/entities/__init__.py +104 -0
  72. mlflow/entities/_mlflow_object.py +52 -0
  73. mlflow/entities/assessment.py +545 -0
  74. mlflow/entities/assessment_error.py +80 -0
  75. mlflow/entities/assessment_source.py +141 -0
  76. mlflow/entities/dataset.py +92 -0
  77. mlflow/entities/dataset_input.py +51 -0
  78. mlflow/entities/dataset_summary.py +62 -0
  79. mlflow/entities/document.py +48 -0
  80. mlflow/entities/experiment.py +109 -0
  81. mlflow/entities/experiment_tag.py +35 -0
  82. mlflow/entities/file_info.py +45 -0
  83. mlflow/entities/input_tag.py +35 -0
  84. mlflow/entities/lifecycle_stage.py +35 -0
  85. mlflow/entities/logged_model.py +228 -0
  86. mlflow/entities/logged_model_input.py +26 -0
  87. mlflow/entities/logged_model_output.py +32 -0
  88. mlflow/entities/logged_model_parameter.py +46 -0
  89. mlflow/entities/logged_model_status.py +74 -0
  90. mlflow/entities/logged_model_tag.py +33 -0
  91. mlflow/entities/metric.py +200 -0
  92. mlflow/entities/model_registry/__init__.py +29 -0
  93. mlflow/entities/model_registry/_model_registry_entity.py +13 -0
  94. mlflow/entities/model_registry/model_version.py +243 -0
  95. mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
  96. mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
  97. mlflow/entities/model_registry/model_version_search.py +25 -0
  98. mlflow/entities/model_registry/model_version_stages.py +25 -0
  99. mlflow/entities/model_registry/model_version_status.py +35 -0
  100. mlflow/entities/model_registry/model_version_tag.py +35 -0
  101. mlflow/entities/model_registry/prompt.py +73 -0
  102. mlflow/entities/model_registry/prompt_version.py +244 -0
  103. mlflow/entities/model_registry/registered_model.py +175 -0
  104. mlflow/entities/model_registry/registered_model_alias.py +35 -0
  105. mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
  106. mlflow/entities/model_registry/registered_model_search.py +25 -0
  107. mlflow/entities/model_registry/registered_model_tag.py +35 -0
  108. mlflow/entities/multipart_upload.py +74 -0
  109. mlflow/entities/param.py +49 -0
  110. mlflow/entities/run.py +97 -0
  111. mlflow/entities/run_data.py +84 -0
  112. mlflow/entities/run_info.py +188 -0
  113. mlflow/entities/run_inputs.py +59 -0
  114. mlflow/entities/run_outputs.py +43 -0
  115. mlflow/entities/run_status.py +41 -0
  116. mlflow/entities/run_tag.py +36 -0
  117. mlflow/entities/source_type.py +31 -0
  118. mlflow/entities/span.py +774 -0
  119. mlflow/entities/span_event.py +96 -0
  120. mlflow/entities/span_status.py +102 -0
  121. mlflow/entities/trace.py +317 -0
  122. mlflow/entities/trace_data.py +71 -0
  123. mlflow/entities/trace_info.py +220 -0
  124. mlflow/entities/trace_info_v2.py +162 -0
  125. mlflow/entities/trace_location.py +173 -0
  126. mlflow/entities/trace_state.py +39 -0
  127. mlflow/entities/trace_status.py +68 -0
  128. mlflow/entities/view_type.py +51 -0
  129. mlflow/environment_variables.py +866 -0
  130. mlflow/evaluation/__init__.py +16 -0
  131. mlflow/evaluation/assessment.py +369 -0
  132. mlflow/evaluation/evaluation.py +411 -0
  133. mlflow/evaluation/evaluation_tag.py +61 -0
  134. mlflow/evaluation/fluent.py +48 -0
  135. mlflow/evaluation/utils.py +201 -0
  136. mlflow/exceptions.py +213 -0
  137. mlflow/experiments.py +140 -0
  138. mlflow/gemini/__init__.py +81 -0
  139. mlflow/gemini/autolog.py +186 -0
  140. mlflow/gemini/chat.py +261 -0
  141. mlflow/genai/__init__.py +71 -0
  142. mlflow/genai/datasets/__init__.py +67 -0
  143. mlflow/genai/datasets/evaluation_dataset.py +131 -0
  144. mlflow/genai/evaluation/__init__.py +3 -0
  145. mlflow/genai/evaluation/base.py +411 -0
  146. mlflow/genai/evaluation/constant.py +23 -0
  147. mlflow/genai/evaluation/utils.py +244 -0
  148. mlflow/genai/judges/__init__.py +21 -0
  149. mlflow/genai/judges/databricks.py +404 -0
  150. mlflow/genai/label_schemas/__init__.py +153 -0
  151. mlflow/genai/label_schemas/label_schemas.py +209 -0
  152. mlflow/genai/labeling/__init__.py +159 -0
  153. mlflow/genai/labeling/labeling.py +250 -0
  154. mlflow/genai/optimize/__init__.py +13 -0
  155. mlflow/genai/optimize/base.py +198 -0
  156. mlflow/genai/optimize/optimizers/__init__.py +4 -0
  157. mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
  158. mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
  159. mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
  160. mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
  161. mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
  162. mlflow/genai/optimize/types.py +75 -0
  163. mlflow/genai/optimize/util.py +30 -0
  164. mlflow/genai/prompts/__init__.py +206 -0
  165. mlflow/genai/scheduled_scorers.py +431 -0
  166. mlflow/genai/scorers/__init__.py +26 -0
  167. mlflow/genai/scorers/base.py +492 -0
  168. mlflow/genai/scorers/builtin_scorers.py +765 -0
  169. mlflow/genai/scorers/scorer_utils.py +138 -0
  170. mlflow/genai/scorers/validation.py +165 -0
  171. mlflow/genai/utils/data_validation.py +146 -0
  172. mlflow/genai/utils/enum_utils.py +23 -0
  173. mlflow/genai/utils/trace_utils.py +211 -0
  174. mlflow/groq/__init__.py +42 -0
  175. mlflow/groq/_groq_autolog.py +74 -0
  176. mlflow/johnsnowlabs/__init__.py +888 -0
  177. mlflow/langchain/__init__.py +24 -0
  178. mlflow/langchain/api_request_parallel_processor.py +330 -0
  179. mlflow/langchain/autolog.py +147 -0
  180. mlflow/langchain/chat_agent_langgraph.py +340 -0
  181. mlflow/langchain/constant.py +1 -0
  182. mlflow/langchain/constants.py +1 -0
  183. mlflow/langchain/databricks_dependencies.py +444 -0
  184. mlflow/langchain/langchain_tracer.py +597 -0
  185. mlflow/langchain/model.py +919 -0
  186. mlflow/langchain/output_parsers.py +142 -0
  187. mlflow/langchain/retriever_chain.py +153 -0
  188. mlflow/langchain/runnables.py +527 -0
  189. mlflow/langchain/utils/chat.py +402 -0
  190. mlflow/langchain/utils/logging.py +671 -0
  191. mlflow/langchain/utils/serialization.py +36 -0
  192. mlflow/legacy_databricks_cli/__init__.py +0 -0
  193. mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
  194. mlflow/legacy_databricks_cli/configure/provider.py +482 -0
  195. mlflow/litellm/__init__.py +175 -0
  196. mlflow/llama_index/__init__.py +22 -0
  197. mlflow/llama_index/autolog.py +55 -0
  198. mlflow/llama_index/chat.py +43 -0
  199. mlflow/llama_index/constant.py +1 -0
  200. mlflow/llama_index/model.py +577 -0
  201. mlflow/llama_index/pyfunc_wrapper.py +332 -0
  202. mlflow/llama_index/serialize_objects.py +188 -0
  203. mlflow/llama_index/tracer.py +561 -0
  204. mlflow/metrics/__init__.py +479 -0
  205. mlflow/metrics/base.py +39 -0
  206. mlflow/metrics/genai/__init__.py +25 -0
  207. mlflow/metrics/genai/base.py +101 -0
  208. mlflow/metrics/genai/genai_metric.py +771 -0
  209. mlflow/metrics/genai/metric_definitions.py +450 -0
  210. mlflow/metrics/genai/model_utils.py +371 -0
  211. mlflow/metrics/genai/prompt_template.py +68 -0
  212. mlflow/metrics/genai/prompts/__init__.py +0 -0
  213. mlflow/metrics/genai/prompts/v1.py +422 -0
  214. mlflow/metrics/genai/utils.py +6 -0
  215. mlflow/metrics/metric_definitions.py +619 -0
  216. mlflow/mismatch.py +34 -0
  217. mlflow/mistral/__init__.py +34 -0
  218. mlflow/mistral/autolog.py +71 -0
  219. mlflow/mistral/chat.py +135 -0
  220. mlflow/ml_package_versions.py +452 -0
  221. mlflow/models/__init__.py +97 -0
  222. mlflow/models/auth_policy.py +83 -0
  223. mlflow/models/cli.py +354 -0
  224. mlflow/models/container/__init__.py +294 -0
  225. mlflow/models/container/scoring_server/__init__.py +0 -0
  226. mlflow/models/container/scoring_server/nginx.conf +39 -0
  227. mlflow/models/dependencies_schemas.py +287 -0
  228. mlflow/models/display_utils.py +158 -0
  229. mlflow/models/docker_utils.py +211 -0
  230. mlflow/models/evaluation/__init__.py +23 -0
  231. mlflow/models/evaluation/_shap_patch.py +64 -0
  232. mlflow/models/evaluation/artifacts.py +194 -0
  233. mlflow/models/evaluation/base.py +1811 -0
  234. mlflow/models/evaluation/calibration_curve.py +109 -0
  235. mlflow/models/evaluation/default_evaluator.py +996 -0
  236. mlflow/models/evaluation/deprecated.py +23 -0
  237. mlflow/models/evaluation/evaluator_registry.py +80 -0
  238. mlflow/models/evaluation/evaluators/classifier.py +704 -0
  239. mlflow/models/evaluation/evaluators/default.py +233 -0
  240. mlflow/models/evaluation/evaluators/regressor.py +96 -0
  241. mlflow/models/evaluation/evaluators/shap.py +296 -0
  242. mlflow/models/evaluation/lift_curve.py +178 -0
  243. mlflow/models/evaluation/utils/metric.py +123 -0
  244. mlflow/models/evaluation/utils/trace.py +179 -0
  245. mlflow/models/evaluation/validation.py +434 -0
  246. mlflow/models/flavor_backend.py +93 -0
  247. mlflow/models/flavor_backend_registry.py +53 -0
  248. mlflow/models/model.py +1639 -0
  249. mlflow/models/model_config.py +150 -0
  250. mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
  251. mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
  252. mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
  253. mlflow/models/python_api.py +369 -0
  254. mlflow/models/rag_signatures.py +128 -0
  255. mlflow/models/resources.py +321 -0
  256. mlflow/models/signature.py +662 -0
  257. mlflow/models/utils.py +2054 -0
  258. mlflow/models/wheeled_model.py +280 -0
  259. mlflow/openai/__init__.py +57 -0
  260. mlflow/openai/_agent_tracer.py +364 -0
  261. mlflow/openai/api_request_parallel_processor.py +131 -0
  262. mlflow/openai/autolog.py +509 -0
  263. mlflow/openai/constant.py +1 -0
  264. mlflow/openai/model.py +824 -0
  265. mlflow/openai/utils/chat_schema.py +367 -0
  266. mlflow/optuna/__init__.py +3 -0
  267. mlflow/optuna/storage.py +646 -0
  268. mlflow/plugins/__init__.py +72 -0
  269. mlflow/plugins/base.py +358 -0
  270. mlflow/plugins/builtin/__init__.py +24 -0
  271. mlflow/plugins/builtin/pytorch_plugin.py +150 -0
  272. mlflow/plugins/builtin/sklearn_plugin.py +158 -0
  273. mlflow/plugins/builtin/transformers_plugin.py +187 -0
  274. mlflow/plugins/cli.py +321 -0
  275. mlflow/plugins/discovery.py +340 -0
  276. mlflow/plugins/manager.py +465 -0
  277. mlflow/plugins/registry.py +316 -0
  278. mlflow/plugins/templates/framework_plugin_template.py +329 -0
  279. mlflow/prompt/constants.py +20 -0
  280. mlflow/prompt/promptlab_model.py +197 -0
  281. mlflow/prompt/registry_utils.py +248 -0
  282. mlflow/promptflow/__init__.py +495 -0
  283. mlflow/protos/__init__.py +0 -0
  284. mlflow/protos/assessments_pb2.py +174 -0
  285. mlflow/protos/databricks_artifacts_pb2.py +489 -0
  286. mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
  287. mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
  288. mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
  289. mlflow/protos/databricks_pb2.py +267 -0
  290. mlflow/protos/databricks_trace_server_pb2.py +374 -0
  291. mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
  292. mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
  293. mlflow/protos/facet_feature_statistics_pb2.py +296 -0
  294. mlflow/protos/internal_pb2.py +77 -0
  295. mlflow/protos/mlflow_artifacts_pb2.py +336 -0
  296. mlflow/protos/model_registry_pb2.py +1073 -0
  297. mlflow/protos/scalapb/__init__.py +0 -0
  298. mlflow/protos/scalapb/scalapb_pb2.py +104 -0
  299. mlflow/protos/service_pb2.py +2600 -0
  300. mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
  301. mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
  302. mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
  303. mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
  304. mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
  305. mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
  306. mlflow/py.typed +0 -0
  307. mlflow/pydantic_ai/__init__.py +57 -0
  308. mlflow/pydantic_ai/autolog.py +173 -0
  309. mlflow/pyfunc/__init__.py +3844 -0
  310. mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
  311. mlflow/pyfunc/backend.py +523 -0
  312. mlflow/pyfunc/context.py +78 -0
  313. mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
  314. mlflow/pyfunc/loaders/__init__.py +7 -0
  315. mlflow/pyfunc/loaders/chat_agent.py +117 -0
  316. mlflow/pyfunc/loaders/chat_model.py +125 -0
  317. mlflow/pyfunc/loaders/code_model.py +31 -0
  318. mlflow/pyfunc/loaders/responses_agent.py +112 -0
  319. mlflow/pyfunc/mlserver.py +46 -0
  320. mlflow/pyfunc/model.py +1473 -0
  321. mlflow/pyfunc/scoring_server/__init__.py +604 -0
  322. mlflow/pyfunc/scoring_server/app.py +7 -0
  323. mlflow/pyfunc/scoring_server/client.py +146 -0
  324. mlflow/pyfunc/spark_model_cache.py +48 -0
  325. mlflow/pyfunc/stdin_server.py +44 -0
  326. mlflow/pyfunc/utils/__init__.py +3 -0
  327. mlflow/pyfunc/utils/data_validation.py +224 -0
  328. mlflow/pyfunc/utils/environment.py +22 -0
  329. mlflow/pyfunc/utils/input_converter.py +47 -0
  330. mlflow/pyfunc/utils/serving_data_parser.py +11 -0
  331. mlflow/pytorch/__init__.py +1171 -0
  332. mlflow/pytorch/_lightning_autolog.py +580 -0
  333. mlflow/pytorch/_pytorch_autolog.py +50 -0
  334. mlflow/pytorch/pickle_module.py +35 -0
  335. mlflow/rfunc/__init__.py +42 -0
  336. mlflow/rfunc/backend.py +134 -0
  337. mlflow/runs.py +89 -0
  338. mlflow/server/__init__.py +302 -0
  339. mlflow/server/auth/__init__.py +1224 -0
  340. mlflow/server/auth/__main__.py +4 -0
  341. mlflow/server/auth/basic_auth.ini +6 -0
  342. mlflow/server/auth/cli.py +11 -0
  343. mlflow/server/auth/client.py +537 -0
  344. mlflow/server/auth/config.py +34 -0
  345. mlflow/server/auth/db/__init__.py +0 -0
  346. mlflow/server/auth/db/cli.py +18 -0
  347. mlflow/server/auth/db/migrations/__init__.py +0 -0
  348. mlflow/server/auth/db/migrations/alembic.ini +110 -0
  349. mlflow/server/auth/db/migrations/env.py +76 -0
  350. mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
  351. mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
  352. mlflow/server/auth/db/models.py +67 -0
  353. mlflow/server/auth/db/utils.py +37 -0
  354. mlflow/server/auth/entities.py +165 -0
  355. mlflow/server/auth/logo.py +14 -0
  356. mlflow/server/auth/permissions.py +65 -0
  357. mlflow/server/auth/routes.py +18 -0
  358. mlflow/server/auth/sqlalchemy_store.py +263 -0
  359. mlflow/server/graphql/__init__.py +0 -0
  360. mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
  361. mlflow/server/graphql/graphql_custom_scalars.py +24 -0
  362. mlflow/server/graphql/graphql_errors.py +15 -0
  363. mlflow/server/graphql/graphql_no_batching.py +89 -0
  364. mlflow/server/graphql/graphql_schema_extensions.py +74 -0
  365. mlflow/server/handlers.py +3217 -0
  366. mlflow/server/prometheus_exporter.py +17 -0
  367. mlflow/server/validation.py +30 -0
  368. mlflow/shap/__init__.py +691 -0
  369. mlflow/sklearn/__init__.py +1994 -0
  370. mlflow/sklearn/utils.py +1041 -0
  371. mlflow/smolagents/__init__.py +66 -0
  372. mlflow/smolagents/autolog.py +139 -0
  373. mlflow/smolagents/chat.py +29 -0
  374. mlflow/store/__init__.py +10 -0
  375. mlflow/store/_unity_catalog/__init__.py +1 -0
  376. mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
  377. mlflow/store/_unity_catalog/lineage/constants.py +2 -0
  378. mlflow/store/_unity_catalog/registry/__init__.py +6 -0
  379. mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
  380. mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
  381. mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
  382. mlflow/store/_unity_catalog/registry/utils.py +121 -0
  383. mlflow/store/artifact/__init__.py +0 -0
  384. mlflow/store/artifact/artifact_repo.py +472 -0
  385. mlflow/store/artifact/artifact_repository_registry.py +154 -0
  386. mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
  387. mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
  388. mlflow/store/artifact/cli.py +141 -0
  389. mlflow/store/artifact/cloud_artifact_repo.py +332 -0
  390. mlflow/store/artifact/databricks_artifact_repo.py +729 -0
  391. mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
  392. mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
  393. mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
  394. mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
  395. mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
  396. mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
  397. mlflow/store/artifact/ftp_artifact_repo.py +132 -0
  398. mlflow/store/artifact/gcs_artifact_repo.py +296 -0
  399. mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
  400. mlflow/store/artifact/http_artifact_repo.py +218 -0
  401. mlflow/store/artifact/local_artifact_repo.py +142 -0
  402. mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
  403. mlflow/store/artifact/models_artifact_repo.py +259 -0
  404. mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
  405. mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
  406. mlflow/store/artifact/r2_artifact_repo.py +70 -0
  407. mlflow/store/artifact/runs_artifact_repo.py +265 -0
  408. mlflow/store/artifact/s3_artifact_repo.py +330 -0
  409. mlflow/store/artifact/sftp_artifact_repo.py +141 -0
  410. mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
  411. mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
  412. mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
  413. mlflow/store/artifact/utils/__init__.py +0 -0
  414. mlflow/store/artifact/utils/models.py +148 -0
  415. mlflow/store/db/__init__.py +0 -0
  416. mlflow/store/db/base_sql_model.py +3 -0
  417. mlflow/store/db/db_types.py +10 -0
  418. mlflow/store/db/utils.py +314 -0
  419. mlflow/store/db_migrations/__init__.py +0 -0
  420. mlflow/store/db_migrations/alembic.ini +74 -0
  421. mlflow/store/db_migrations/env.py +84 -0
  422. mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
  423. mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
  424. mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
  425. mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
  426. mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
  427. mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
  428. mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
  429. mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
  430. mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
  431. mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
  432. mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
  433. mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
  434. mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
  435. mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
  436. mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
  437. mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
  438. mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
  439. mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
  440. mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
  441. mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
  442. mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
  443. mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
  444. mlflow/store/db_migrations/versions/__init__.py +0 -0
  445. mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
  446. mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
  447. mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
  448. mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
  449. mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
  450. mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
  451. mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
  452. mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
  453. mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
  454. mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
  455. mlflow/store/entities/__init__.py +3 -0
  456. mlflow/store/entities/paged_list.py +18 -0
  457. mlflow/store/model_registry/__init__.py +10 -0
  458. mlflow/store/model_registry/abstract_store.py +1081 -0
  459. mlflow/store/model_registry/base_rest_store.py +44 -0
  460. mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
  461. mlflow/store/model_registry/dbmodels/__init__.py +0 -0
  462. mlflow/store/model_registry/dbmodels/models.py +206 -0
  463. mlflow/store/model_registry/file_store.py +1091 -0
  464. mlflow/store/model_registry/rest_store.py +481 -0
  465. mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
  466. mlflow/store/tracking/__init__.py +23 -0
  467. mlflow/store/tracking/abstract_store.py +816 -0
  468. mlflow/store/tracking/dbmodels/__init__.py +0 -0
  469. mlflow/store/tracking/dbmodels/initial_models.py +243 -0
  470. mlflow/store/tracking/dbmodels/models.py +1073 -0
  471. mlflow/store/tracking/file_store.py +2438 -0
  472. mlflow/store/tracking/postgres_managed_identity.py +146 -0
  473. mlflow/store/tracking/rest_store.py +1131 -0
  474. mlflow/store/tracking/sqlalchemy_store.py +2785 -0
  475. mlflow/system_metrics/__init__.py +61 -0
  476. mlflow/system_metrics/metrics/__init__.py +0 -0
  477. mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
  478. mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
  479. mlflow/system_metrics/metrics/disk_monitor.py +21 -0
  480. mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
  481. mlflow/system_metrics/metrics/network_monitor.py +34 -0
  482. mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
  483. mlflow/system_metrics/system_metrics_monitor.py +198 -0
  484. mlflow/tracing/__init__.py +16 -0
  485. mlflow/tracing/assessment.py +356 -0
  486. mlflow/tracing/client.py +531 -0
  487. mlflow/tracing/config.py +125 -0
  488. mlflow/tracing/constant.py +105 -0
  489. mlflow/tracing/destination.py +81 -0
  490. mlflow/tracing/display/__init__.py +40 -0
  491. mlflow/tracing/display/display_handler.py +196 -0
  492. mlflow/tracing/export/async_export_queue.py +186 -0
  493. mlflow/tracing/export/inference_table.py +138 -0
  494. mlflow/tracing/export/mlflow_v3.py +137 -0
  495. mlflow/tracing/export/utils.py +70 -0
  496. mlflow/tracing/fluent.py +1417 -0
  497. mlflow/tracing/processor/base_mlflow.py +199 -0
  498. mlflow/tracing/processor/inference_table.py +175 -0
  499. mlflow/tracing/processor/mlflow_v3.py +47 -0
  500. mlflow/tracing/processor/otel.py +73 -0
  501. mlflow/tracing/provider.py +487 -0
  502. mlflow/tracing/trace_manager.py +200 -0
  503. mlflow/tracing/utils/__init__.py +616 -0
  504. mlflow/tracing/utils/artifact_utils.py +28 -0
  505. mlflow/tracing/utils/copy.py +55 -0
  506. mlflow/tracing/utils/environment.py +55 -0
  507. mlflow/tracing/utils/exception.py +21 -0
  508. mlflow/tracing/utils/once.py +35 -0
  509. mlflow/tracing/utils/otlp.py +63 -0
  510. mlflow/tracing/utils/processor.py +54 -0
  511. mlflow/tracing/utils/search.py +292 -0
  512. mlflow/tracing/utils/timeout.py +250 -0
  513. mlflow/tracing/utils/token.py +19 -0
  514. mlflow/tracing/utils/truncation.py +124 -0
  515. mlflow/tracing/utils/warning.py +76 -0
  516. mlflow/tracking/__init__.py +39 -0
  517. mlflow/tracking/_model_registry/__init__.py +1 -0
  518. mlflow/tracking/_model_registry/client.py +764 -0
  519. mlflow/tracking/_model_registry/fluent.py +853 -0
  520. mlflow/tracking/_model_registry/registry.py +67 -0
  521. mlflow/tracking/_model_registry/utils.py +251 -0
  522. mlflow/tracking/_tracking_service/__init__.py +0 -0
  523. mlflow/tracking/_tracking_service/client.py +883 -0
  524. mlflow/tracking/_tracking_service/registry.py +56 -0
  525. mlflow/tracking/_tracking_service/utils.py +275 -0
  526. mlflow/tracking/artifact_utils.py +179 -0
  527. mlflow/tracking/client.py +5900 -0
  528. mlflow/tracking/context/__init__.py +0 -0
  529. mlflow/tracking/context/abstract_context.py +35 -0
  530. mlflow/tracking/context/databricks_cluster_context.py +15 -0
  531. mlflow/tracking/context/databricks_command_context.py +15 -0
  532. mlflow/tracking/context/databricks_job_context.py +49 -0
  533. mlflow/tracking/context/databricks_notebook_context.py +41 -0
  534. mlflow/tracking/context/databricks_repo_context.py +43 -0
  535. mlflow/tracking/context/default_context.py +51 -0
  536. mlflow/tracking/context/git_context.py +32 -0
  537. mlflow/tracking/context/registry.py +98 -0
  538. mlflow/tracking/context/system_environment_context.py +15 -0
  539. mlflow/tracking/default_experiment/__init__.py +1 -0
  540. mlflow/tracking/default_experiment/abstract_context.py +43 -0
  541. mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
  542. mlflow/tracking/default_experiment/registry.py +75 -0
  543. mlflow/tracking/fluent.py +3595 -0
  544. mlflow/tracking/metric_value_conversion_utils.py +93 -0
  545. mlflow/tracking/multimedia.py +206 -0
  546. mlflow/tracking/registry.py +86 -0
  547. mlflow/tracking/request_auth/__init__.py +0 -0
  548. mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
  549. mlflow/tracking/request_auth/registry.py +60 -0
  550. mlflow/tracking/request_header/__init__.py +0 -0
  551. mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
  552. mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
  553. mlflow/tracking/request_header/default_request_header_provider.py +17 -0
  554. mlflow/tracking/request_header/registry.py +79 -0
  555. mlflow/transformers/__init__.py +2982 -0
  556. mlflow/transformers/flavor_config.py +258 -0
  557. mlflow/transformers/hub_utils.py +83 -0
  558. mlflow/transformers/llm_inference_utils.py +468 -0
  559. mlflow/transformers/model_io.py +301 -0
  560. mlflow/transformers/peft.py +51 -0
  561. mlflow/transformers/signature.py +183 -0
  562. mlflow/transformers/torch_utils.py +55 -0
  563. mlflow/types/__init__.py +21 -0
  564. mlflow/types/agent.py +270 -0
  565. mlflow/types/chat.py +240 -0
  566. mlflow/types/llm.py +935 -0
  567. mlflow/types/responses.py +139 -0
  568. mlflow/types/responses_helpers.py +416 -0
  569. mlflow/types/schema.py +1505 -0
  570. mlflow/types/type_hints.py +647 -0
  571. mlflow/types/utils.py +753 -0
  572. mlflow/utils/__init__.py +283 -0
  573. mlflow/utils/_capture_modules.py +256 -0
  574. mlflow/utils/_capture_transformers_modules.py +75 -0
  575. mlflow/utils/_spark_utils.py +201 -0
  576. mlflow/utils/_unity_catalog_oss_utils.py +97 -0
  577. mlflow/utils/_unity_catalog_utils.py +479 -0
  578. mlflow/utils/annotations.py +218 -0
  579. mlflow/utils/arguments_utils.py +16 -0
  580. mlflow/utils/async_logging/__init__.py +1 -0
  581. mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
  582. mlflow/utils/async_logging/async_logging_queue.py +366 -0
  583. mlflow/utils/async_logging/run_artifact.py +38 -0
  584. mlflow/utils/async_logging/run_batch.py +58 -0
  585. mlflow/utils/async_logging/run_operations.py +49 -0
  586. mlflow/utils/autologging_utils/__init__.py +737 -0
  587. mlflow/utils/autologging_utils/client.py +432 -0
  588. mlflow/utils/autologging_utils/config.py +33 -0
  589. mlflow/utils/autologging_utils/events.py +294 -0
  590. mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
  591. mlflow/utils/autologging_utils/metrics_queue.py +71 -0
  592. mlflow/utils/autologging_utils/safety.py +1104 -0
  593. mlflow/utils/autologging_utils/versioning.py +95 -0
  594. mlflow/utils/checkpoint_utils.py +206 -0
  595. mlflow/utils/class_utils.py +6 -0
  596. mlflow/utils/cli_args.py +257 -0
  597. mlflow/utils/conda.py +354 -0
  598. mlflow/utils/credentials.py +231 -0
  599. mlflow/utils/data_utils.py +17 -0
  600. mlflow/utils/databricks_utils.py +1436 -0
  601. mlflow/utils/docstring_utils.py +477 -0
  602. mlflow/utils/doctor.py +133 -0
  603. mlflow/utils/download_cloud_file_chunk.py +43 -0
  604. mlflow/utils/env_manager.py +16 -0
  605. mlflow/utils/env_pack.py +131 -0
  606. mlflow/utils/environment.py +1009 -0
  607. mlflow/utils/exception_utils.py +14 -0
  608. mlflow/utils/file_utils.py +978 -0
  609. mlflow/utils/git_utils.py +77 -0
  610. mlflow/utils/gorilla.py +797 -0
  611. mlflow/utils/import_hooks/__init__.py +363 -0
  612. mlflow/utils/lazy_load.py +51 -0
  613. mlflow/utils/logging_utils.py +168 -0
  614. mlflow/utils/mime_type_utils.py +58 -0
  615. mlflow/utils/mlflow_tags.py +103 -0
  616. mlflow/utils/model_utils.py +486 -0
  617. mlflow/utils/name_utils.py +346 -0
  618. mlflow/utils/nfs_on_spark.py +62 -0
  619. mlflow/utils/openai_utils.py +164 -0
  620. mlflow/utils/os.py +12 -0
  621. mlflow/utils/oss_registry_utils.py +29 -0
  622. mlflow/utils/plugins.py +17 -0
  623. mlflow/utils/process.py +182 -0
  624. mlflow/utils/promptlab_utils.py +146 -0
  625. mlflow/utils/proto_json_utils.py +743 -0
  626. mlflow/utils/pydantic_utils.py +54 -0
  627. mlflow/utils/request_utils.py +279 -0
  628. mlflow/utils/requirements_utils.py +704 -0
  629. mlflow/utils/rest_utils.py +673 -0
  630. mlflow/utils/search_logged_model_utils.py +127 -0
  631. mlflow/utils/search_utils.py +2111 -0
  632. mlflow/utils/secure_loading.py +221 -0
  633. mlflow/utils/security_validation.py +384 -0
  634. mlflow/utils/server_cli_utils.py +61 -0
  635. mlflow/utils/spark_utils.py +15 -0
  636. mlflow/utils/string_utils.py +138 -0
  637. mlflow/utils/thread_utils.py +63 -0
  638. mlflow/utils/time.py +54 -0
  639. mlflow/utils/timeout.py +42 -0
  640. mlflow/utils/uri.py +572 -0
  641. mlflow/utils/validation.py +662 -0
  642. mlflow/utils/virtualenv.py +458 -0
  643. mlflow/utils/warnings_utils.py +25 -0
  644. mlflow/utils/yaml_utils.py +179 -0
  645. mlflow/version.py +24 -0
@@ -0,0 +1,1286 @@
1
+ import logging
2
+ import urllib
3
+ from typing import Any, Optional, Union
4
+
5
+ import sqlalchemy
6
+ from sqlalchemy.future import select
7
+
8
+ import mlflow.store.db.utils
9
+ from mlflow.entities.model_registry.model_version_stages import (
10
+ ALL_STAGES,
11
+ DEFAULT_STAGES_FOR_GET_LATEST_VERSIONS,
12
+ STAGE_ARCHIVED,
13
+ STAGE_DELETED_INTERNAL,
14
+ get_canonical_stage,
15
+ )
16
+ from mlflow.entities.model_registry.prompt_version import IS_PROMPT_TAG_KEY
17
+ from mlflow.exceptions import MlflowException
18
+ from mlflow.prompt.registry_utils import handle_resource_already_exist_error, has_prompt_tag
19
+ from mlflow.protos.databricks_pb2 import (
20
+ INVALID_PARAMETER_VALUE,
21
+ INVALID_STATE,
22
+ RESOURCE_ALREADY_EXISTS,
23
+ RESOURCE_DOES_NOT_EXIST,
24
+ )
25
+ from mlflow.store.artifact.utils.models import _parse_model_uri
26
+ from mlflow.store.entities.paged_list import PagedList
27
+ from mlflow.store.model_registry import (
28
+ SEARCH_MODEL_VERSION_MAX_RESULTS_DEFAULT,
29
+ SEARCH_MODEL_VERSION_MAX_RESULTS_THRESHOLD,
30
+ SEARCH_REGISTERED_MODEL_MAX_RESULTS_DEFAULT,
31
+ SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD,
32
+ )
33
+ from mlflow.store.model_registry.abstract_store import AbstractStore
34
+ from mlflow.store.model_registry.dbmodels.models import (
35
+ SqlModelVersion,
36
+ SqlModelVersionTag,
37
+ SqlRegisteredModel,
38
+ SqlRegisteredModelAlias,
39
+ SqlRegisteredModelTag,
40
+ )
41
+ from mlflow.tracking.client import MlflowClient
42
+ from mlflow.utils.search_utils import SearchModelUtils, SearchModelVersionUtils, SearchUtils
43
+ from mlflow.utils.time import get_current_time_millis
44
+ from mlflow.utils.uri import extract_db_type_from_uri
45
+ from mlflow.utils.validation import (
46
+ _validate_model_alias_name,
47
+ _validate_model_name,
48
+ _validate_model_renaming,
49
+ _validate_model_version,
50
+ _validate_model_version_tag,
51
+ _validate_registered_model_tag,
52
+ _validate_tag_name,
53
+ )
54
+
55
+ _logger = logging.getLogger(__name__)
56
+
57
+ # For each database table, fetch its columns and define an appropriate attribute for each column
58
+ # on the table's associated object representation (Mapper). This is necessary to ensure that
59
+ # columns defined via backreference are available as Mapper instance attributes (e.g.,
60
+ # ``SqlRegisteredModel.model_versions``). For more information, see
61
+ # https://docs.sqlalchemy.org/en/latest/orm/mapping_api.html#sqlalchemy.orm.configure_mappers
62
+ # and https://docs.sqlalchemy.org/en/latest/orm/mapping_api.html#sqlalchemy.orm.mapper.Mapper
63
+ sqlalchemy.orm.configure_mappers()
64
+
65
+
66
+ class SqlAlchemyStore(AbstractStore):
67
+ """
68
+ This entity may change or be removed in a future release without warning.
69
+ SQLAlchemy compliant backend store for tracking meta data for MLflow entities. MLflow
70
+ supports the database dialects ``mysql``, ``mssql``, ``sqlite``, and ``postgresql``.
71
+ As specified in the
72
+ `SQLAlchemy docs <https://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls>`_ ,
73
+ the database URI is expected in the format
74
+ ``<dialect>+<driver>://<username>:<password>@<host>:<port>/<database>``. If you do not
75
+ specify a driver, SQLAlchemy uses a dialect's default driver.
76
+
77
+ This store interacts with SQL store using SQLAlchemy abstractions defined for MLflow entities.
78
+ :py:class:`mlflow.store.model_registry.models.RegisteredModel` and
79
+ :py:class:`mlflow.store.model_registry.models.ModelVersion`
80
+ """
81
+
82
+ CREATE_MODEL_VERSION_RETRIES = 3
83
+
84
def __init__(self, db_uri):
    """
    Create a database backed model registry store.

    The constructor builds the SQLAlchemy engine for ``db_uri``, initializes the MLflow tables
    when ``_all_tables_exist`` reports that they are not all present, verifies that the model
    registry tables exist, and creates the managed session factory used by the store's methods.

    Args:
        db_uri: The SQLAlchemy database URI string to connect to the database. See
            the `SQLAlchemy docs
            <https://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls>`_
            for format specifications. MLflow supports the dialects ``mysql``,
            ``mssql``, ``sqlite``, and ``postgresql``.

    Raises:
        MlflowException: If the registered model or model version table is still missing
            after table initialization.
    """
    super().__init__()
    self.db_uri = db_uri
    self.db_type = extract_db_type_from_uri(db_uri)
    self.engine = mlflow.store.db.utils.create_sqlalchemy_engine_with_retry(db_uri)
    # Bootstrap the schema on a database that does not yet have every MLflow table.
    if not mlflow.store.db.utils._all_tables_exist(self.engine):
        mlflow.store.db.utils._initialize_tables(self.engine)
    # Verify that all model registry tables exist.
    SqlAlchemyStore._verify_registry_tables_exist(self.engine)
    SessionMaker = sqlalchemy.orm.sessionmaker(bind=self.engine)
    # Every store method opens its sessions through this managed factory.
    self.ManagedSessionMaker = mlflow.store.db.utils._get_managed_session_maker(
        SessionMaker, self.db_type
    )
    # TODO: verify schema here once we add logic to initialize the registry tables if they
    # don't exist (schema verification will fail in tests otherwise)
    # mlflow.store.db.utils._verify_schema(self.engine)
+ # mlflow.store.db.utils._verify_schema(self.engine)
112
+
113
+ def _get_dialect(self):
114
+ return self.engine.dialect.name
115
+
116
+ def _dispose_engine(self):
117
+ self.engine.dispose()
118
+
119
@staticmethod
def _verify_registry_tables_exist(engine):
    """
    Check that the core model registry tables are present in the database behind ``engine``.

    Raises:
        MlflowException: If the registered model table or the model version table is absent.
    """
    present_tables = set(sqlalchemy.inspect(engine).get_table_names())
    required_tables = {SqlRegisteredModel.__tablename__, SqlModelVersion.__tablename__}
    if not required_tables <= present_tables:
        # TODO: Replace the MlflowException with the following line once it's possible to run
        # the registry against a different DB than the tracking server:
        # mlflow.store.db.utils._initialize_tables(self.engine)
        raise MlflowException("Database migration in unexpected state. Run manual upgrade.")
132
+
133
@staticmethod
def _get_eager_registered_model_query_options():
    """
    Return loader options that eagerly load a registered model's ``registered_model_tags``.

    A subquery load is chosen over a joined load to keep the memory overhead of eager loading
    low. For background on relationship loading techniques, see
    https://docs.sqlalchemy.org/en/13/orm/loading_relationships.html#relationship-loading-techniques

    Returns:
        A one-element list of loader options, suitable for ``.options(*options)``.
    """
    tags_loader = sqlalchemy.orm.subqueryload(SqlRegisteredModel.registered_model_tags)
    return [tags_loader]
145
+
146
@staticmethod
def _get_eager_model_version_query_options():
    """
    Return loader options that eagerly load a model version's ``model_version_tags``.

    A subquery load is chosen over a joined load to keep the memory overhead of eager loading
    low. For background on relationship loading techniques, see
    https://docs.sqlalchemy.org/en/13/orm/loading_relationships.html#relationship-loading-techniques

    Returns:
        A one-element list of loader options, suitable for ``.options(*options)``.
    """
    tags_loader = sqlalchemy.orm.subqueryload(SqlModelVersion.model_version_tags)
    return [tags_loader]
158
+
159
def create_registered_model(self, name, tags=None, description=None, deployment_job_id=None):
    """
    Create a new registered model in backend store.

    Args:
        name: Name of the new model. This is expected to be unique in the backend store.
        tags: A list of :py:class:`mlflow.entities.model_registry.RegisteredModelTag`
            instances associated with this registered model. When several tags share a key,
            the last one wins.
        description: Description of the registered model.
        deployment_job_id: Optional deployment job ID. Accepted for interface compatibility;
            this store does not persist it.

    Returns:
        A single object of :py:class:`mlflow.entities.model_registry.RegisteredModel`
        created in the backend.

    Raises:
        Whatever ``handle_resource_already_exist_error`` raises when inserting the model hits
        an integrity error (e.g. the name is already taken). NOTE(review): if that helper ever
        returns normally, this method returns ``None``.
    """
    _validate_model_name(name)
    for tag in tags or []:
        _validate_registered_model_tag(tag.key, tag.value)
    with self.ManagedSessionMaker() as session:
        try:
            creation_time = get_current_time_millis()
            registered_model = SqlRegisteredModel(
                name=name,
                creation_time=creation_time,
                last_updated_time=creation_time,
                description=description,
            )
            # Collapse duplicate keys so only the last value for each key is stored.
            tags_dict = {tag.key: tag.value for tag in tags or []}
            registered_model.registered_model_tags = [
                SqlRegisteredModelTag(key=key, value=value) for key, value in tags_dict.items()
            ]
            session.add(registered_model)
            session.flush()
            return registered_model.to_mlflow_entity()
        except sqlalchemy.exc.IntegrityError:
            # Look up the conflicting model (get_registered_model opens its own session) so
            # the helper can compare the prompt tag of the existing and the requested model.
            existing_model = self.get_registered_model(name)
            handle_resource_already_exist_error(
                name, has_prompt_tag(existing_model._tags), has_prompt_tag(tags)
            )
200
+
201
@classmethod
def _get_registered_model(cls, session, name, eager=False):
    """
    Fetch the single ``SqlRegisteredModel`` row named ``name`` using ``session``.

    Args:
        session: Active SQLAlchemy session to query with.
        name: Registered model name; validated before querying.
        eager: If ``True``, eagerly loads the registered model's tags. If ``False``, these
            attributes are not eagerly loaded and will be loaded when their corresponding object
            properties are accessed from the resulting ``SqlRegisteredModel`` object.

    Raises:
        MlflowException: ``RESOURCE_DOES_NOT_EXIST`` when no row matches, ``INVALID_STATE``
            when more than one row matches.
    """
    _validate_model_name(name)
    loader_options = cls._get_eager_registered_model_query_options() if eager else []
    matches = (
        session.query(SqlRegisteredModel)
        .options(*loader_options)
        .filter(SqlRegisteredModel.name == name)
        .all()
    )
    match_count = len(matches)
    if not match_count:
        raise MlflowException(
            f"Registered Model with name={name} not found", RESOURCE_DOES_NOT_EXIST
        )
    if match_count > 1:
        raise MlflowException(
            f"Expected only 1 registered model with name={name}. Found {match_count}.",
            INVALID_STATE,
        )
    (registered_model,) = matches
    return registered_model
228
+
229
def update_registered_model(self, name, description, deployment_job_id=None):
    """
    Replace a registered model's description and refresh its last-updated timestamp.

    Args:
        name: Registered model name.
        description: New description.
        deployment_job_id: Optional deployment job ID. Not used by this store.

    Returns:
        A single updated :py:class:`mlflow.entities.model_registry.RegisteredModel` object.
    """
    with self.ManagedSessionMaker() as session:
        registered_model = self._get_registered_model(session, name)
        now_millis = get_current_time_millis()
        registered_model.description = description
        registered_model.last_updated_time = now_millis
        session.add(registered_model)
        session.flush()
        return registered_model.to_mlflow_entity()
250
+
251
def rename_registered_model(self, name, new_name):
    """
    Rename the registered model together with all of its model versions.

    Each model version's ``name`` is rewritten to ``new_name``, and the registered model and
    its versions all receive the same fresh ``last_updated_time``.

    Args:
        name: Registered model name.
        new_name: New proposed name.

    Returns:
        A single updated :py:class:`mlflow.entities.model_registry.RegisteredModel` object.

    Raises:
        MlflowException: ``RESOURCE_ALREADY_EXISTS`` if flushing the rename violates an
            integrity constraint (e.g. a registered model named ``new_name`` already exists).
    """
    _validate_model_renaming(new_name)
    with self.ManagedSessionMaker() as session:
        sql_registered_model = self._get_registered_model(session, name)
        try:
            updated_time = get_current_time_millis()
            sql_registered_model.name = new_name
            for sql_model_version in sql_registered_model.model_versions:
                sql_model_version.name = new_name
                sql_model_version.last_updated_time = updated_time
            sql_registered_model.last_updated_time = updated_time
            session.add_all([sql_registered_model] + sql_registered_model.model_versions)
            session.flush()
            return sql_registered_model.to_mlflow_entity()
        except sqlalchemy.exc.IntegrityError as e:
            # Chain explicitly so the underlying database error is kept as ``__cause__``.
            raise MlflowException(
                f"Registered Model (name={new_name}) already exists. Error: {e}",
                RESOURCE_ALREADY_EXISTS,
            ) from e
281
+
282
def delete_registered_model(self, name):
    """
    Delete the registered model named ``name``.

    Args:
        name: Registered model name.

    Returns:
        None

    Raises:
        MlflowException: If no registered model with the given name exists.
    """
    with self.ManagedSessionMaker() as session:
        doomed = self._get_registered_model(session, name)
        session.delete(doomed)
296
+
297
+ def _compute_next_token(self, max_results_for_query, current_size, offset, max_results):
298
+ next_token = None
299
+ if max_results_for_query == current_size:
300
+ final_offset = offset + max_results
301
+ next_token = SearchUtils.create_page_token(final_offset)
302
+ return next_token
303
+
304
def search_registered_models(
    self,
    filter_string=None,
    max_results=SEARCH_REGISTERED_MODEL_MAX_RESULTS_DEFAULT,
    order_by=None,
    page_token=None,
):
    """
    Return one page of registered models matching ``filter_string``.

    Args:
        filter_string: Filter query string, defaults to searching all registered models.
        max_results: Maximum number of registered models desired.
        order_by: List of column names with ASC|DESC annotation, to be used for ordering
            matching search results.
        page_token: Token specifying the next page of results. It should be obtained from
            a ``search_registered_models`` call.

    Returns:
        A PagedList of :py:class:`mlflow.entities.model_registry.RegisteredModel` objects
        that satisfy the search expressions. The pagination token for the next page can be
        obtained via the ``token`` attribute of the object.

    Raises:
        MlflowException: ``INVALID_PARAMETER_VALUE`` when ``max_results`` exceeds
            ``SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD``.
    """
    if max_results > SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD:
        raise MlflowException(
            "Invalid value for request parameter max_results. It must be at most "
            f"{SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD}, but got value {max_results}",
            INVALID_PARAMETER_VALUE,
        )

    filters = SearchModelUtils.parse_search_filter(filter_string)
    base_query = self._get_search_registered_model_filter_query(
        filters, self.engine.dialect.name
    )
    ordering = self._parse_search_registered_models_order_by(order_by)
    start = SearchUtils.parse_start_offset_from_page_token(page_token)
    # Over-fetch by one row: if that extra row comes back, another page exists, which spares
    # a follow-up query that would return nothing.
    rows_to_fetch = max_results + 1

    with self.ManagedSessionMaker() as session:
        statement = base_query.options(*self._get_eager_registered_model_query_options())
        statement = statement.order_by(*ordering).limit(rows_to_fetch)
        if page_token:
            statement = statement.offset(start)
        rows = session.execute(statement).scalars(SqlRegisteredModel).all()
        token = self._compute_next_token(rows_to_fetch, len(rows), start, max_results)
        entities = [row.to_mlflow_entity() for row in rows]
        return PagedList(entities[:max_results], token)
360
+
361
@classmethod
def _get_search_registered_model_filter_query(cls, parsed_filters, dialect):
    """
    Translate parsed search filters into a ``select`` over ``SqlRegisteredModel``.

    Args:
        parsed_filters: Filter dicts with ``type``, ``key``, ``comparator`` and ``value``
            entries (the caller builds them with ``SearchModelUtils.parse_search_filter``).
        dialect: Database dialect name, forwarded to ``SearchUtils.get_sql_comparison_func``.

    Returns:
        A SQLAlchemy ``select`` statement. Unless ``_is_querying_prompt`` says the filters
        explicitly target prompts, rows tagged as prompts are excluded.

    Raises:
        MlflowException: For an attribute other than ``name``, an unsupported comparator, or
            an unknown filter type.
    """
    attribute_filters = []
    # tag key -> [key-equality clause, value clause, ...]; the clauses of one key are AND-ed.
    tag_filters = {}
    for f in parsed_filters:
        type_ = f["type"]
        key = f["key"]
        comparator = f["comparator"]
        value = f["value"]
        if type_ == "attribute":
            # Only the model name is searchable as an attribute here.
            if key != "name":
                raise MlflowException(
                    f"Invalid attribute name: {key}", error_code=INVALID_PARAMETER_VALUE
                )
            if comparator not in ("=", "!=", "LIKE", "ILIKE"):
                raise MlflowException(
                    f"Invalid comparator for attribute: {comparator}",
                    error_code=INVALID_PARAMETER_VALUE,
                )
            attr = getattr(SqlRegisteredModel, key)
            attr_filter = SearchUtils.get_sql_comparison_func(comparator, dialect)(attr, value)
            attribute_filters.append(attr_filter)
        elif type_ == "tag":
            if comparator not in ("=", "!=", "LIKE", "ILIKE"):
                raise MlflowException.invalid_parameter_value(
                    f"Invalid comparator for tag: {comparator}"
                )
            # First filter on this key: seed its clause list with the key match.
            if key not in tag_filters:
                key_filter = SearchUtils.get_sql_comparison_func("=", dialect)(
                    SqlRegisteredModelTag.key, key
                )
                tag_filters[key] = [key_filter]

            val_filter = SearchUtils.get_sql_comparison_func(comparator, dialect)(
                SqlRegisteredModelTag.value, value
            )
            tag_filters[key].append(val_filter)
        else:
            raise MlflowException(
                f"Invalid token type: {type_}", error_code=INVALID_PARAMETER_VALUE
            )

    rm_query = select(SqlRegisteredModel).filter(*attribute_filters)

    # Order matters: _update_query_to_exclude_prompts pops the prompt tag key out of
    # ``tag_filters`` in place, so it must run before the tag subquery below is built.
    if not cls._is_querying_prompt(parsed_filters):
        rm_query = cls._update_query_to_exclude_prompts(
            rm_query, tag_filters, dialect, SqlRegisteredModel, SqlRegisteredModelTag
        )

    if tag_filters:
        # OR the per-key clauses, group matching tag rows by model name, and keep the names
        # whose matching-row count equals the number of distinct tag keys filtered on
        # (assumes at most one tag row per (name, key) — TODO confirm against the schema).
        sql_tag_filters = (sqlalchemy.and_(*x) for x in tag_filters.values())
        tag_filter_query = (
            select(SqlRegisteredModelTag.name)
            .filter(sqlalchemy.or_(*sql_tag_filters))
            .group_by(SqlRegisteredModelTag.name)
            .having(sqlalchemy.func.count(sqlalchemy.literal(1)) == len(tag_filters))
            .subquery()
        )

        return rm_query.join(
            tag_filter_query, SqlRegisteredModel.name == tag_filter_query.c.name
        )
    else:
        return rm_query
425
+
426
@classmethod
def _get_search_model_versions_filter_clauses(cls, parsed_filters, dialect):
    """
    Translate parsed search filters into a ``select`` over ``SqlModelVersion``.

    Args:
        parsed_filters: Filter dicts with ``type``, ``key``, ``comparator`` and ``value``
            entries.
        dialect: Database dialect name, forwarded to ``SearchUtils.get_sql_comparison_func``.

    Returns:
        A SQLAlchemy ``select`` statement. Unless ``_is_querying_prompt`` says the filters
        explicitly target prompts, versions of prompt-tagged models are excluded.

    Raises:
        MlflowException: For an unsupported attribute, an attribute/comparator combination
            that is not allowed, an unsupported tag comparator, or an unknown filter type.
    """
    attribute_filters = []
    # tag key -> [key-equality clause, value clause, ...]; the clauses of one key are AND-ed.
    tag_filters = {}
    for f in parsed_filters:
        type_ = f["type"]
        key = f["key"]
        comparator = f["comparator"]
        value = f["value"]
        if type_ == "attribute":
            if key not in SearchModelVersionUtils.VALID_SEARCH_ATTRIBUTE_KEYS:
                raise MlflowException(
                    f"Invalid attribute name: {key}", error_code=INVALID_PARAMETER_VALUE
                )
            # Numeric and string attributes accept different comparator sets; ``IN`` is only
            # permitted for ``run_id``.
            if key in SearchModelVersionUtils.NUMERIC_ATTRIBUTES:
                if (
                    comparator
                    not in SearchModelVersionUtils.VALID_NUMERIC_ATTRIBUTE_COMPARATORS
                ):
                    raise MlflowException(
                        f"Invalid comparator for attribute {key}: {comparator}",
                        error_code=INVALID_PARAMETER_VALUE,
                    )
            elif (
                comparator not in SearchModelVersionUtils.VALID_STRING_ATTRIBUTE_COMPARATORS
                or (comparator == "IN" and key != "run_id")
            ):
                raise MlflowException(
                    f"Invalid comparator for attribute: {comparator}",
                    error_code=INVALID_PARAMETER_VALUE,
                )
            # Map public search keys onto the ORM column names.
            if key == "source_path":
                key_name = "source"
            elif key == "version_number":
                key_name = "version"
            else:
                key_name = key
            attr = getattr(SqlModelVersion, key_name)
            if comparator == "IN":
                # Note: Here the run_id values in databases contain only lower case letters,
                # so we already filter out comparison values containing upper case letters
                # in `SearchModelUtils._get_value`. This addresses MySQL IN clause case
                # in-sensitive issue.
                val_filter = attr.in_(value)
            else:
                val_filter = SearchUtils.get_sql_comparison_func(comparator, dialect)(
                    attr, value
                )
            attribute_filters.append(val_filter)
        elif type_ == "tag":
            if comparator not in ("=", "!=", "LIKE", "ILIKE"):
                raise MlflowException.invalid_parameter_value(
                    f"Invalid comparator for tag: {comparator}",
                )
            # First filter on this key: seed its clause list with the key match.
            if key not in tag_filters:
                key_filter = SearchUtils.get_sql_comparison_func("=", dialect)(
                    SqlModelVersionTag.key, key
                )
                tag_filters[key] = [key_filter]

            val_filter = SearchUtils.get_sql_comparison_func(comparator, dialect)(
                SqlModelVersionTag.value, value
            )
            tag_filters[key].append(val_filter)
        else:
            raise MlflowException(
                f"Invalid token type: {type_}", error_code=INVALID_PARAMETER_VALUE
            )

    mv_query = select(SqlModelVersion).filter(*attribute_filters)

    # Order matters: _update_query_to_exclude_prompts pops the prompt tag key out of
    # ``tag_filters`` in place, so it must run before the tag subquery below is built.
    if not cls._is_querying_prompt(parsed_filters):
        mv_query = cls._update_query_to_exclude_prompts(
            mv_query, tag_filters, dialect, SqlModelVersion, SqlModelVersionTag
        )

    if tag_filters:
        # OR the per-key clauses, group matching tag rows by (name, version), and keep the
        # versions whose matching-row count equals the number of distinct tag keys filtered on.
        sql_tag_filters = (sqlalchemy.and_(*x) for x in tag_filters.values())
        tag_filter_query = (
            select(SqlModelVersionTag.name, SqlModelVersionTag.version)
            .filter(sqlalchemy.or_(*sql_tag_filters))
            .group_by(SqlModelVersionTag.name, SqlModelVersionTag.version)
            .having(sqlalchemy.func.count(sqlalchemy.literal(1)) == len(tag_filters))
            .subquery()
        )
        return mv_query.join(
            tag_filter_query,
            sqlalchemy.and_(
                SqlModelVersion.name == tag_filter_query.c.name,
                SqlModelVersion.version == tag_filter_query.c.version,
            ),
        )
    else:
        return mv_query
520
+
521
@classmethod
def _update_query_to_exclude_prompts(
    cls,
    query: Any,
    tag_filters: dict[str, list[Any]],
    dialect: str,
    main_db_model: type[Union[SqlModelVersion, SqlRegisteredModel]],
    tag_db_model: type[Union[SqlModelVersionTag, SqlRegisteredModelTag]],
):
    """
    Restrict ``query`` to normal models (or model versions), excluding all prompt rows.

    Prompts and normal models are distinguished by the ``IS_PROMPT_TAG_KEY`` tag. Search APIs
    should return only normal models by default. Ordinary tag filters cannot express that,
    because normal models simply do not carry the prompt tag, so neither of these would match
    them::

        tags.`mlflow.prompt.is_prompt` != 'true'
        tags.`mlflow.prompt.is_prompt` = 'false'

    Instead, a subquery collects the names of every row tagged ``IS_PROMPT_TAG_KEY = 'true'``,
    and an anti-join (left outer join + ``IS NULL``) removes those names from ``query``.

    Args:
        query: The ``select`` statement to restrict.
        tag_filters: Per-key tag clauses built by the caller. Mutated in place: any entry for
            ``IS_PROMPT_TAG_KEY`` is removed so it is not also applied as a regular tag filter.
        dialect: Database dialect name, forwarded to ``SearchUtils.get_sql_comparison_func``.
        main_db_model: Mapped class being searched (``SqlRegisteredModel`` or
            ``SqlModelVersion``).
        tag_db_model: Tag class that belongs to ``main_db_model``.

    Returns:
        ``query`` anti-joined against the prompt-name subquery.
    """
    # Prompt exclusion is handled by the anti-join below, so drop any explicit prompt-tag filter.
    tag_filters.pop(IS_PROMPT_TAG_KEY, None)

    equal = SearchUtils.get_sql_comparison_func("=", dialect)
    prompts_subquery = (
        select(tag_db_model.name)
        .filter(
            equal(tag_db_model.key, IS_PROMPT_TAG_KEY),
            equal(tag_db_model.value, "true"),
        )
        .group_by(tag_db_model.name)
        .subquery()
    )
    # Rows without a matching prompt name get NULL from the outer join; keep only those.
    return query.join(
        prompts_subquery, main_db_model.name == prompts_subquery.c.name, isouter=True
    ).filter(prompts_subquery.c.name.is_(None))
560
+
561
@classmethod
def _is_querying_prompt(cls, parsed_filters: list[dict[str, Any]]) -> bool:
    """
    Tell whether the parsed filters explicitly ask for prompts.

    Only the first tag filter on ``IS_PROMPT_TAG_KEY`` is consulted. The search targets
    prompts when that filter is ``= 'true'`` or ``!= 'false'`` (value compared
    case-insensitively). If there is no such filter, normal models are searched.
    """
    prompt_filter = next(
        (f for f in parsed_filters if f["type"] == "tag" and f["key"] == IS_PROMPT_TAG_KEY),
        None,
    )
    if prompt_filter is None:
        # Query should return only normal models by default
        return False
    comparator = prompt_filter["comparator"]
    if comparator == "=":
        return prompt_filter["value"].lower() == "true"
    if comparator == "!=":
        return prompt_filter["value"].lower() == "false"
    return False
573
+
574
@classmethod
def _parse_search_registered_models_order_by(cls, order_by_list):
    """
    Build the ORDER BY clauses for a registered model search.

    Registered models are naturally ordered by name ascending. The clauses requested in
    ``order_by_list`` come first, and ``name ASC`` is appended as a tie-breaker unless the
    caller already ordered by name.

    Args:
        order_by_list: Optional list of order-by strings (column name with an optional
            ASC/DESC annotation).

    Returns:
        A list of SQLAlchemy ordering clauses.

    Raises:
        MlflowException: ``INVALID_PARAMETER_VALUE`` for an unsupported key or when the same
            field is requested more than once.
    """
    clauses = []
    observed_order_by_clauses = set()
    for order_by_clause in order_by_list or []:
        attribute_token, ascending = SearchUtils.parse_order_by_for_search_registered_models(
            order_by_clause
        )
        if attribute_token == SqlRegisteredModel.name.key:
            field = SqlRegisteredModel.name
        elif attribute_token in SearchUtils.VALID_TIMESTAMP_ORDER_BY_KEYS:
            # Every accepted timestamp key sorts on the last-updated time.
            field = SqlRegisteredModel.last_updated_time
        else:
            raise MlflowException(
                f"Invalid order by key '{attribute_token}' specified. "
                "Valid keys are "
                f"'{SearchUtils.RECOMMENDED_ORDER_BY_KEYS_REGISTERED_MODELS}'",
                error_code=INVALID_PARAMETER_VALUE,
            )
        if field.key in observed_order_by_clauses:
            # A user-supplied duplicate is an invalid parameter, not an internal error.
            raise MlflowException(
                f"`order_by` contains duplicate fields: {order_by_list}",
                error_code=INVALID_PARAMETER_VALUE,
            )
        observed_order_by_clauses.add(field.key)
        clauses.append(field.asc() if ascending else field.desc())

    if SqlRegisteredModel.name.key not in observed_order_by_clauses:
        clauses.append(SqlRegisteredModel.name.asc())
    return clauses
609
+
610
def get_registered_model(self, name):
    """
    Look up a registered model by name, with its tags eagerly loaded.

    Args:
        name: Registered model name.

    Returns:
        A single :py:class:`mlflow.entities.model_registry.RegisteredModel` object.
    """
    with self.ManagedSessionMaker() as session:
        sql_registered_model = self._get_registered_model(session, name, eager=True)
        return sql_registered_model.to_mlflow_entity()
622
+
623
def get_latest_versions(self, name, stages=None):
    """
    Return the latest model version for each requested stage.

    When ``stages`` is ``None`` or empty, every stage in ``ALL_STAGES`` is considered.

    Args:
        name: Registered model name.
        stages: List of desired stages. If input list is None, return latest versions for
            each stage.

    Returns:
        List of :py:class:`mlflow.entities.model_registry.ModelVersion` objects, each with its
        ``aliases`` populated.
    """
    with self.ManagedSessionMaker() as session:
        registered_model = self._get_registered_model(session, name)
        # The RegisteredModel entity already carries the latest version per stage.
        candidates = registered_model.to_mlflow_entity().latest_versions
        wanted = ALL_STAGES if stages is None or len(stages) == 0 else stages
        canonical_stages = {get_canonical_stage(stage) for stage in wanted}
        selected = [mv for mv in candidates if mv.current_stage in canonical_stages]

        # Attach the aliases that point at each selected version.
        for mv in selected:
            mv.aliases = [
                entry.alias
                for entry in registered_model.registered_model_aliases
                if entry.version == mv.version
            ]
        return selected
653
+
654
@classmethod
def _get_registered_model_tag(cls, session, name, key):
    """
    Return the ``SqlRegisteredModelTag`` for (``name``, ``key``), or ``None`` if there is none.

    Raises:
        MlflowException: ``INVALID_STATE`` if more than one matching tag row exists.
    """
    matches = (
        session.query(SqlRegisteredModelTag)
        .filter(SqlRegisteredModelTag.name == name, SqlRegisteredModelTag.key == key)
        .all()
    )
    if not matches:
        return None
    if len(matches) > 1:
        raise MlflowException(
            f"Expected only 1 registered model tag with name={name}, key={key}. "
            f"Found {len(matches)}.",
            INVALID_STATE,
        )
    return matches[0]
670
+
671
def set_registered_model_tag(self, name, tag):
    """
    Create or overwrite a tag on an existing registered model.

    Args:
        name: Registered model name.
        tag: :py:class:`mlflow.entities.model_registry.RegisteredModelTag` instance to log.

    Returns:
        None
    """
    _validate_model_name(name)
    _validate_registered_model_tag(tag.key, tag.value)
    with self.ManagedSessionMaker() as session:
        # Raises if the registered model does not exist.
        self._get_registered_model(session, name)
        # merge() reconciles by primary key, presumably (name, key), so an existing tag with
        # the same key is overwritten rather than duplicated.
        tag_row = SqlRegisteredModelTag(name=name, key=tag.key, value=tag.value)
        session.merge(tag_row)
688
+
689
def delete_registered_model_tag(self, name, key):
    """
    Remove a tag from a registered model; a missing tag is silently ignored.

    Args:
        name: Registered model name.
        key: Registered model tag key.

    Returns:
        None
    """
    _validate_model_name(name)
    _validate_tag_name(key)
    with self.ManagedSessionMaker() as session:
        # Fail fast if the registered model itself does not exist.
        self._get_registered_model(session, name)
        existing_tag = self._get_registered_model_tag(session, name, key)
        if existing_tag is None:
            return
        session.delete(existing_tag)
708
+
709
+ # CRUD API for ModelVersion objects
710
+
711
def create_model_version(
    self,
    name,
    source,
    run_id=None,
    tags=None,
    run_link=None,
    description=None,
    local_model_path=None,
    model_id: Optional[str] = None,
):
    """
    Create a new model version from given source and run ID.

    The new version number is one greater than the highest existing version of the registered
    model (1 for the first version). If flushing the new row raises an ``IntegrityError``
    (presumably a concurrent writer took the same version number), the session is rolled back
    and creation is retried, up to ``CREATE_MODEL_VERSION_RETRIES`` attempts in total.

    Args:
        name: Registered model name.
        source: URI indicating the location of the model artifacts. A ``models:/`` URI is
            resolved to the underlying storage location.
        run_id: Run ID from MLflow tracking server that generated the model.
        tags: A list of :py:class:`mlflow.entities.model_registry.ModelVersionTag`
            instances associated with this model version.
        run_link: Link to the run from an MLflow tracking server that generated this model.
        description: Description of the version.
        local_model_path: Unused.
        model_id: The ID of the model (from an Experiment) that is being promoted to a
            registered model version, if applicable. NOTE(review): this implementation does
            not reference it.

    Returns:
        A single object of :py:class:`mlflow.entities.model_registry.ModelVersion`
        created in the backend.

    Raises:
        MlflowException: If a ``models:/`` source cannot be resolved, or if the version still
            cannot be created after all retries.
    """

    def next_version(sql_registered_model):
        # One past the highest version ever created for this model; 1 for the first one.
        if sql_registered_model.model_versions:
            return max(mv.version for mv in sql_registered_model.model_versions) + 1
        else:
            return 1

    _validate_model_name(name)
    for tag in tags or []:
        _validate_model_version_tag(tag.key, tag.value)
    storage_location = source
    if urllib.parse.urlparse(source).scheme == "models":
        parsed_model_uri = _parse_model_uri(source)
        try:
            if parsed_model_uri.model_id is not None:
                # TODO: Propagate tracking URI to file sqlalchemy directly, rather than relying
                # on global URI (individual MlflowClient instances may have different tracking
                # URIs)
                model = MlflowClient().get_logged_model(parsed_model_uri.model_id)
                storage_location = model.artifact_location
                run_id = run_id or model.source_run_id
            else:
                storage_location = self.get_model_version_download_uri(
                    parsed_model_uri.name, parsed_model_uri.version
                )
        except Exception as e:
            raise MlflowException(
                f"Unable to fetch model from model URI source artifact location '{source}'. "
                f"Error: {e}"
            ) from e
    with self.ManagedSessionMaker() as session:
        creation_time = get_current_time_millis()
        for attempt in range(self.CREATE_MODEL_VERSION_RETRIES):
            try:
                sql_registered_model = self._get_registered_model(session, name)
                sql_registered_model.last_updated_time = creation_time
                version = next_version(sql_registered_model)
                model_version = SqlModelVersion(
                    name=name,
                    version=version,
                    creation_time=creation_time,
                    last_updated_time=creation_time,
                    source=source,
                    storage_location=storage_location,
                    run_id=run_id,
                    run_link=run_link,
                    description=description,
                )
                # Collapse duplicate keys so only the last value for each key is stored.
                tags_dict = {tag.key: tag.value for tag in tags or []}
                model_version.model_version_tags = [
                    SqlModelVersionTag(key=key, value=value) for key, value in tags_dict.items()
                ]
                session.add_all([sql_registered_model, model_version])
                session.flush()
                return self._populate_model_version_aliases(
                    session, name, model_version.to_mlflow_entity()
                )
            except sqlalchemy.exc.IntegrityError:
                # A failed flush leaves the session's transaction inactive; without a rollback
                # the next attempt's query would raise instead of retrying.
                session.rollback()
                more_retries = self.CREATE_MODEL_VERSION_RETRIES - attempt - 1
                _logger.info(
                    "Model Version creation error (name=%s) Retrying %s more time%s.",
                    name,
                    str(more_retries),
                    "s" if more_retries != 1 else "",
                )
        raise MlflowException(
            f"Model Version creation error (name={name}). Giving up after "
            f"{self.CREATE_MODEL_VERSION_RETRIES} attempts."
        )
813
+
814
@classmethod
def _populate_model_version_aliases(cls, session, name, version):
    """
    Set ``version.aliases`` to every alias of registered model ``name`` that points at
    ``version.version``, then return ``version``.
    """
    registered_model = cls._get_registered_model(session, name)
    version.aliases = [
        entry.alias
        for entry in registered_model.registered_model_aliases
        if entry.version == version.version
    ]
    return version
821
+
822
@classmethod
def _get_model_version_from_db(cls, session, name, version, conditions, query_options=None):
    """
    Run a ``SqlModelVersion`` query filtered by ``conditions`` and return its single row.

    Args:
        session: Active SQLAlchemy session to query with.
        name: Registered model name; used only in error messages.
        version: Model version; used only in error messages.
        conditions: SQLAlchemy filter expressions that identify the version.
        query_options: Optional loader options; none are applied when omitted.

    Raises:
        MlflowException: ``RESOURCE_DOES_NOT_EXIST`` when no row matches, ``INVALID_STATE``
            when several rows match.
    """
    loader_options = [] if query_options is None else query_options
    matches = (
        session.query(SqlModelVersion).options(*loader_options).filter(*conditions).all()
    )
    if not matches:
        raise MlflowException(
            f"Model Version (name={name}, version={version}) not found",
            RESOURCE_DOES_NOT_EXIST,
        )
    if len(matches) > 1:
        raise MlflowException(
            f"Expected only 1 model version with (name={name}, version={version}). "
            f"Found {len(matches)}.",
            INVALID_STATE,
        )
    return matches[0]
840
+
841
@classmethod
def _get_sql_model_version(cls, session, name, version, eager=False):
    """
    Fetch a model version row that is not in the internal "deleted" stage.

    Args:
        session: Active SQLAlchemy session.
        name: Registered model name; validated before querying.
        version: Registered model version; validated before querying.
        eager: If ``True``, eagerly loads the model version's tags.
            If ``False``, these attributes are not eagerly loaded and
            will be loaded when their corresponding object properties
            are accessed from the resulting ``SqlModelVersion`` object.

    Returns:
        The matching ``SqlModelVersion`` row.

    Raises:
        MlflowException: Propagated from ``_get_model_version_from_db`` when zero rows,
            or more than one row, match.
    """
    _validate_model_name(name)
    _validate_model_version(version)
    if eager:
        options = cls._get_eager_model_version_query_options()
    else:
        options = []
    # Rows soft-deleted by ``delete_model_version`` are excluded here.
    filters = [
        SqlModelVersion.name == name,
        SqlModelVersion.version == version,
        SqlModelVersion.current_stage != STAGE_DELETED_INTERNAL,
    ]
    return cls._get_model_version_from_db(session, name, version, filters, options)
859
+
860
def _get_sql_model_version_including_deleted(self, name, version):
    """
    Fetch a model version without excluding internally deleted ones.

    Private helper used by tests to inspect how deletion redacts a model version.

    Args:
        name: Registered model name.
        version: Registered model version.

    Returns:
        A single :py:class:`mlflow.entities.model_registry.ModelVersion` object.
    """
    # Unlike ``_get_sql_model_version`` there is no stage filter here.
    filters = [SqlModelVersion.name == name, SqlModelVersion.version == version]
    with self.ManagedSessionMaker() as session:
        row = self._get_model_version_from_db(session, name, version, filters)
        entity = row.to_mlflow_entity()
        return self._populate_model_version_aliases(session, name, entity)
881
+
882
def update_model_version(self, name, version, description=None):
    """
    Update metadata associated with a model version in backend.

    Only the description is written; the model version's ``last_updated_time`` is set to
    the current time in milliseconds.

    Args:
        name: Registered model name.
        version: Registered model version.
        description: New model description.

    Returns:
        A single :py:class:`mlflow.entities.model_registry.ModelVersion` object.
    """
    with self.ManagedSessionMaker() as session:
        now_ms = get_current_time_millis()
        mv_row = self._get_sql_model_version(session, name=name, version=version)
        mv_row.description = description
        mv_row.last_updated_time = now_ms
        session.add(mv_row)
        entity = mv_row.to_mlflow_entity()
        return self._populate_model_version_aliases(session, name, entity)
904
+
905
def transition_model_version_stage(self, name, version, stage, archive_existing_versions):
    """
    Update model version stage.

    Args:
        name: Registered model name.
        version: Registered model version.
        stage: New desired stage for this model version.
        archive_existing_versions: If this flag is set to ``True``, all existing model
            versions in the stage will be automatically moved to the "archived" stage. Only
            valid when ``stage`` is ``"staging"`` or ``"production"`` otherwise an error will
            be raised.

    Returns:
        A single :py:class:`mlflow.entities.model_registry.ModelVersion` object.

    Raises:
        MlflowException: With ``INVALID_PARAMETER_VALUE`` when ``archive_existing_versions``
            is set but the canonical form of ``stage`` is not in
            ``DEFAULT_STAGES_FOR_GET_LATEST_VERSIONS``. Errors from
            ``_get_sql_model_version`` (e.g. the version does not exist) propagate.
    """
    # Canonicalize once and reuse the result for validation, archiving and the update.
    canonical_stage = get_canonical_stage(stage)
    is_active_stage = canonical_stage in DEFAULT_STAGES_FOR_GET_LATEST_VERSIONS
    if archive_existing_versions and not is_active_stage:
        msg_tpl = (
            "Model version transition cannot archive existing model versions "
            "because '{}' is not an Active stage. Valid stages are {}"
        )
        # This is bad caller input, so tag it INVALID_PARAMETER_VALUE instead of leaving
        # the error code unspecified.
        raise MlflowException(
            msg_tpl.format(stage, DEFAULT_STAGES_FOR_GET_LATEST_VERSIONS),
            INVALID_PARAMETER_VALUE,
        )

    with self.ManagedSessionMaker() as session:
        last_updated_time = get_current_time_millis()

        # When requested, other versions of this model already in the target stage are
        # moved to the archived stage.
        model_versions = []
        if archive_existing_versions:
            conditions = [
                SqlModelVersion.name == name,
                SqlModelVersion.version != version,
                SqlModelVersion.current_stage == canonical_stage,
            ]
            model_versions = session.query(SqlModelVersion).filter(*conditions).all()
            for mv in model_versions:
                mv.current_stage = STAGE_ARCHIVED
                mv.last_updated_time = last_updated_time

        sql_model_version = self._get_sql_model_version(
            session=session, name=name, version=version
        )
        sql_model_version.current_stage = canonical_stage
        sql_model_version.last_updated_time = last_updated_time
        sql_registered_model = sql_model_version.registered_model
        sql_registered_model.last_updated_time = last_updated_time
        session.add_all([*model_versions, sql_model_version, sql_registered_model])
        return self._populate_model_version_aliases(
            session, name, sql_model_version.to_mlflow_entity()
        )
956
+
957
def delete_model_version(self, name, version):
    """
    Delete model version in backend.

    This is a soft delete: the row is kept, but it is moved to ``STAGE_DELETED_INTERNAL``.
    Its description, user id and status message are cleared, and its source, run id and
    run link are replaced with redaction placeholders. Aliases pointing at this version
    are deleted, and the parent registered model's ``last_updated_time`` is bumped. Tags
    associated with the version are currently kept.

    Args:
        name: Registered model name.
        version: Registered model version.

    Returns:
        None
    """
    with self.ManagedSessionMaker() as session:
        updated_time = get_current_time_millis()
        sql_model_version = self._get_sql_model_version(session, name, version)
        sql_registered_model = sql_model_version.registered_model
        sql_registered_model.last_updated_time = updated_time
        # ``version`` is used exactly as the caller passed it (possibly a string such as
        # "1"), while ``alias.version`` is loaded from the database and may be an integer.
        # Compare string forms so that a "1" vs 1 mismatch cannot leave aliases pointing
        # at the deleted version.
        target_version = str(version)
        for alias in sql_registered_model.registered_model_aliases:
            if str(alias.version) == target_version:
                session.delete(alias)
        sql_model_version.current_stage = STAGE_DELETED_INTERNAL
        sql_model_version.last_updated_time = updated_time
        sql_model_version.description = None
        sql_model_version.user_id = None
        sql_model_version.source = "REDACTED-SOURCE-PATH"
        sql_model_version.run_id = "REDACTED-RUN-ID"
        sql_model_version.run_link = "REDACTED-RUN-LINK"
        sql_model_version.status_message = None
        session.add_all([sql_registered_model, sql_model_version])
987
+
988
def get_model_version(self, name, version):
    """
    Look up a model version by registered model name and version number.

    The row's tags are loaded eagerly, and the returned entity carries the aliases
    that point at it.

    Args:
        name: Registered model name.
        version: Registered model version.

    Returns:
        A single :py:class:`mlflow.entities.model_registry.ModelVersion` object.
    """
    with self.ManagedSessionMaker() as session:
        sql_row = self._get_sql_model_version(session, name, version, eager=True)
        entity = sql_row.to_mlflow_entity()
        return self._populate_model_version_aliases(session, name, entity)
1004
+
1005
def get_model_version_download_uri(self, name, version):
    """
    Return the URI from which this model version's artifacts can be downloaded.

    Models are not copied elsewhere by the registry. The URI is therefore the model
    version's ``storage_location`` when that is set (truthy), and its ``source`` otherwise.

    Args:
        name: Registered model name.
        version: Registered model version.

    Returns:
        A single URI location that allows reads for downloading.
    """
    with self.ManagedSessionMaker() as session:
        mv_row = self._get_sql_model_version(session, name, version)
        storage_location = mv_row.storage_location
        if storage_location:
            return storage_location
        return mv_row.source
1021
+
1022
def search_model_versions(
    self,
    filter_string=None,
    max_results=SEARCH_MODEL_VERSION_MAX_RESULTS_DEFAULT,
    order_by=None,
    page_token=None,
):
    """
    Return a page of non-deleted model versions matching a filter.

    Args:
        filter_string: Filter expression, e.g. ``name = 'model_name'`` or
            ``run_id = '...'``.
        max_results: Page size; must be an ``int`` between 1 and
            ``SEARCH_MODEL_VERSION_MAX_RESULTS_THRESHOLD`` inclusive.
        order_by: Order-by strings with an optional ASC/DESC suffix. When omitted, rows are
            sorted by ``last_updated_timestamp DESC``, ``name ASC``, ``version_number DESC``.
        page_token: Token returned by a previous ``search_model_versions`` call.

    Returns:
        A PagedList of :py:class:`mlflow.entities.model_registry.ModelVersion`
        objects; its ``token`` attribute holds the token for the next page.

    Raises:
        MlflowException: ``INVALID_PARAMETER_VALUE`` if ``max_results`` is out of range.
    """
    if not isinstance(max_results, int) or max_results < 1:
        raise MlflowException(
            "Invalid value for max_results. It must be a positive integer,"
            f" but got {max_results}",
            INVALID_PARAMETER_VALUE,
        )
    if max_results > SEARCH_MODEL_VERSION_MAX_RESULTS_THRESHOLD:
        raise MlflowException(
            "Invalid value for request parameter max_results. It must be at most "
            f"{SEARCH_MODEL_VERSION_MAX_RESULTS_THRESHOLD}, but got value {max_results}",
            INVALID_PARAMETER_VALUE,
        )

    parsed_filters = SearchModelVersionUtils.parse_search_filter(filter_string)
    base_query = self._get_search_model_versions_filter_clauses(
        parsed_filters, self.engine.dialect.name
    )
    default_order = ["last_updated_timestamp DESC", "name ASC", "version_number DESC"]
    order_clauses = self._parse_search_model_versions_order_by(order_by or default_order)
    offset = SearchUtils.parse_start_offset_from_page_token(page_token)
    # Ask for one row beyond the page size: if it comes back, another page exists, and no
    # follow-up query returning an empty page is needed to find that out.
    fetch_limit = max_results + 1

    with self.ManagedSessionMaker() as session:
        stmt = base_query.options(*self._get_eager_model_version_query_options())
        stmt = stmt.filter(SqlModelVersion.current_stage != STAGE_DELETED_INTERNAL)
        stmt = stmt.order_by(*order_clauses).limit(fetch_limit)
        if page_token:
            stmt = stmt.offset(offset)
        rows = session.execute(stmt).scalars(SqlModelVersion).all()
        next_page_token = self._compute_next_token(fetch_limit, len(rows), offset, max_results)
        entities = [row.to_mlflow_entity() for row in rows]
        return PagedList(entities[:max_results], next_page_token)
1091
+
1092
@classmethod
def _parse_search_model_versions_order_by(cls, order_by_list):
    """
    Translate order-by strings into SQLAlchemy ordering clauses for model versions.

    Each entry is parsed with
    ``SearchModelVersionUtils.parse_order_by_for_search_model_versions``. The keys
    ``version_number``, ``creation_timestamp`` and ``last_updated_timestamp`` map to the
    ``version``, ``creation_time`` and ``last_updated_time`` columns. Any other valid key
    is used as the ``SqlModelVersion`` attribute of the same name. Unless the caller
    already ordered by them, tie-breakers are appended: name ascending, then version
    descending.

    Args:
        order_by_list: Iterable of order-by strings (e.g. ``"name ASC"``), or a falsy value
            for no explicit ordering.

    Returns:
        A list of SQLAlchemy ordering clauses.

    Raises:
        MlflowException: With ``INVALID_PARAMETER_VALUE`` if a key is not in
            ``SearchModelVersionUtils.VALID_ORDER_BY_ATTRIBUTE_KEYS``, or if the same column
            appears more than once.
    """
    # Public order-by keys whose column attribute has a different name.
    key_to_column = {
        "version_number": SqlModelVersion.version,
        "creation_timestamp": SqlModelVersion.creation_time,
        "last_updated_timestamp": SqlModelVersion.last_updated_time,
    }
    clauses = []
    observed_order_by_clauses = set()
    for order_by_clause in order_by_list or []:
        _, key, ascending = SearchModelVersionUtils.parse_order_by_for_search_model_versions(
            order_by_clause
        )
        if key not in SearchModelVersionUtils.VALID_ORDER_BY_ATTRIBUTE_KEYS:
            raise MlflowException(
                f"Invalid order by key '{key}' specified. "
                "Valid keys are "
                f"{SearchModelVersionUtils.VALID_ORDER_BY_ATTRIBUTE_KEYS}",
                error_code=INVALID_PARAMETER_VALUE,
            )
        if key in key_to_column:
            field = key_to_column[key]
        else:
            field = getattr(SqlModelVersion, key)
        if field.key in observed_order_by_clauses:
            # Repeating a column is a caller error, like an invalid key above.
            raise MlflowException(
                f"`order_by` contains duplicate fields: {order_by_list}",
                error_code=INVALID_PARAMETER_VALUE,
            )
        observed_order_by_clauses.add(field.key)
        clauses.append(field.asc() if ascending else field.desc())

    # Tie-breakers: name ascending, then version descending (newest version first).
    if SqlModelVersion.name.key not in observed_order_by_clauses:
        clauses.append(SqlModelVersion.name.asc())
    if SqlModelVersion.version.key not in observed_order_by_clauses:
        clauses.append(SqlModelVersion.version.desc())
    return clauses
1138
+
1139
@classmethod
def _get_model_version_tag(cls, session, name, version, key):
    """
    Return the tag ``key`` of model version (``name``, ``version``), or ``None`` if absent.

    Raises:
        MlflowException: ``INVALID_STATE`` if more than one matching tag row exists.
    """
    criteria = (
        SqlModelVersionTag.name == name,
        SqlModelVersionTag.version == version,
        SqlModelVersionTag.key == key,
    )
    matching_tags = session.query(SqlModelVersionTag).filter(*criteria).all()
    if not matching_tags:
        return None
    if len(matching_tags) > 1:
        raise MlflowException(
            f"Expected only 1 model version tag with name={name}, version={version}, "
            f"key={key}. Found {len(matching_tags)}.",
            INVALID_STATE,
        )
    return matching_tags[0]
1159
+
1160
def set_model_version_tag(self, name, version, tag):
    """
    Set a tag on an existing, non-deleted model version.

    Name, version and tag are validated first. The tag row is then written with
    ``session.merge``.

    Args:
        name: Registered model name.
        version: Registered model version.
        tag: :py:class:`mlflow.entities.model_registry.ModelVersionTag` instance to log.

    Returns:
        None
    """
    _validate_model_name(name)
    _validate_model_version(version)
    _validate_model_version_tag(tag.key, tag.value)
    with self.ManagedSessionMaker() as session:
        # Raises if the model version is missing; the returned row itself is unused.
        self._get_sql_model_version(session, name, version)
        tag_row = SqlModelVersionTag(name=name, version=version, key=tag.key, value=tag.value)
        session.merge(tag_row)
1181
+
1182
def delete_model_version_tag(self, name, version, key):
    """
    Remove tag ``key`` from a model version.

    The model version must exist. Deleting a tag that is not set is a no-op.

    Args:
        name: Registered model name.
        version: Registered model version.
        key: Tag key.

    Returns:
        None
    """
    _validate_model_name(name)
    _validate_model_version(version)
    _validate_tag_name(key)
    with self.ManagedSessionMaker() as session:
        # Raises if the model version is missing; the returned row itself is unused.
        self._get_sql_model_version(session, name, version)
        tag_row = self._get_model_version_tag(session, name, version, key)
        if tag_row is None:
            return
        session.delete(tag_row)
1203
+
1204
@classmethod
def _get_registered_model_alias(cls, session, name, alias):
    """
    Return the first alias row named ``alias`` on registered model ``name``, or ``None``.
    """
    alias_query = session.query(SqlRegisteredModelAlias).filter(
        SqlRegisteredModelAlias.name == name,
        SqlRegisteredModelAlias.alias == alias,
    )
    return alias_query.first()
1214
+
1215
def set_registered_model_alias(self, name, alias, version):
    """
    Point alias ``alias`` of registered model ``name`` at model version ``version``.

    The target model version must exist and must not be internally deleted. The alias
    row is written with ``session.merge``. NOTE(review): this presumably re-points an
    existing alias of the same name; confirm against the alias table's primary key.

    Args:
        name: Registered model name.
        alias: Name of the alias.
        version: Registered model version number.

    Returns:
        None
    """
    _validate_model_name(name)
    _validate_model_alias_name(alias)
    _validate_model_version(version)
    with self.ManagedSessionMaker() as session:
        # Raises if the target model version is missing; the returned row itself is unused.
        self._get_sql_model_version(session, name, version)
        alias_row = SqlRegisteredModelAlias(name=name, alias=alias, version=version)
        session.merge(alias_row)
1234
+
1235
def delete_registered_model_alias(self, name, alias):
    """
    Remove alias ``alias`` from registered model ``name``.

    Deleting an alias that does not exist is a no-op.

    Args:
        name: Registered model name.
        alias: Name of the alias.

    Returns:
        None
    """
    _validate_model_name(name)
    _validate_model_alias_name(alias)
    with self.ManagedSessionMaker() as session:
        # Existence check for the registered model (the helper is expected to raise if it
        # is missing); the returned object itself is unused.
        self._get_registered_model(session, name)
        alias_row = self._get_registered_model_alias(session, name, alias)
        if alias_row is None:
            return
        session.delete(alias_row)
1254
+
1255
def get_model_version_by_alias(self, name, alias):
    """
    Resolve alias ``alias`` of registered model ``name`` to its model version.

    Args:
        name: Registered model name.
        alias: Name of the alias.

    Returns:
        A single :py:class:`mlflow.entities.model_registry.ModelVersion` object.

    Raises:
        MlflowException: ``INVALID_PARAMETER_VALUE`` if the alias does not exist.
    """
    _validate_model_name(name)
    _validate_model_alias_name(alias)
    with self.ManagedSessionMaker() as session:
        alias_row = self._get_registered_model_alias(session, name, alias)
        if alias_row is None:
            raise MlflowException(
                f"Registered model alias {alias} not found.", INVALID_PARAMETER_VALUE
            )
        mv_row = self._get_sql_model_version(session, alias_row.name, alias_row.version)
        entity = mv_row.to_mlflow_entity()
        return self._populate_model_version_aliases(session, name, entity)
1281
+
1282
+ def _await_model_version_creation(self, mv, await_creation_for):
1283
+ """
1284
+ Does not wait for the model version to become READY as a successful creation will
1285
+ immediately place the model version in a READY state.
1286
+ """