genesis-flow 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (645) hide show
  1. genesis_flow-1.0.0.dist-info/METADATA +822 -0
  2. genesis_flow-1.0.0.dist-info/RECORD +645 -0
  3. genesis_flow-1.0.0.dist-info/WHEEL +5 -0
  4. genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
  5. genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
  6. genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
  7. mlflow/__init__.py +367 -0
  8. mlflow/__main__.py +3 -0
  9. mlflow/ag2/__init__.py +56 -0
  10. mlflow/ag2/ag2_logger.py +294 -0
  11. mlflow/anthropic/__init__.py +40 -0
  12. mlflow/anthropic/autolog.py +129 -0
  13. mlflow/anthropic/chat.py +144 -0
  14. mlflow/artifacts/__init__.py +268 -0
  15. mlflow/autogen/__init__.py +144 -0
  16. mlflow/autogen/chat.py +142 -0
  17. mlflow/azure/__init__.py +26 -0
  18. mlflow/azure/auth_handler.py +257 -0
  19. mlflow/azure/client.py +319 -0
  20. mlflow/azure/config.py +120 -0
  21. mlflow/azure/connection_factory.py +340 -0
  22. mlflow/azure/exceptions.py +27 -0
  23. mlflow/azure/stores.py +327 -0
  24. mlflow/azure/utils.py +183 -0
  25. mlflow/bedrock/__init__.py +45 -0
  26. mlflow/bedrock/_autolog.py +202 -0
  27. mlflow/bedrock/chat.py +122 -0
  28. mlflow/bedrock/stream.py +160 -0
  29. mlflow/bedrock/utils.py +43 -0
  30. mlflow/cli.py +707 -0
  31. mlflow/client.py +12 -0
  32. mlflow/config/__init__.py +56 -0
  33. mlflow/crewai/__init__.py +79 -0
  34. mlflow/crewai/autolog.py +253 -0
  35. mlflow/crewai/chat.py +29 -0
  36. mlflow/data/__init__.py +75 -0
  37. mlflow/data/artifact_dataset_sources.py +170 -0
  38. mlflow/data/code_dataset_source.py +40 -0
  39. mlflow/data/dataset.py +123 -0
  40. mlflow/data/dataset_registry.py +168 -0
  41. mlflow/data/dataset_source.py +110 -0
  42. mlflow/data/dataset_source_registry.py +219 -0
  43. mlflow/data/delta_dataset_source.py +167 -0
  44. mlflow/data/digest_utils.py +108 -0
  45. mlflow/data/evaluation_dataset.py +562 -0
  46. mlflow/data/filesystem_dataset_source.py +81 -0
  47. mlflow/data/http_dataset_source.py +145 -0
  48. mlflow/data/huggingface_dataset.py +258 -0
  49. mlflow/data/huggingface_dataset_source.py +118 -0
  50. mlflow/data/meta_dataset.py +104 -0
  51. mlflow/data/numpy_dataset.py +223 -0
  52. mlflow/data/pandas_dataset.py +231 -0
  53. mlflow/data/polars_dataset.py +352 -0
  54. mlflow/data/pyfunc_dataset_mixin.py +31 -0
  55. mlflow/data/schema.py +76 -0
  56. mlflow/data/sources.py +1 -0
  57. mlflow/data/spark_dataset.py +406 -0
  58. mlflow/data/spark_dataset_source.py +74 -0
  59. mlflow/data/spark_delta_utils.py +118 -0
  60. mlflow/data/tensorflow_dataset.py +350 -0
  61. mlflow/data/uc_volume_dataset_source.py +81 -0
  62. mlflow/db.py +27 -0
  63. mlflow/dspy/__init__.py +17 -0
  64. mlflow/dspy/autolog.py +197 -0
  65. mlflow/dspy/callback.py +398 -0
  66. mlflow/dspy/constant.py +1 -0
  67. mlflow/dspy/load.py +93 -0
  68. mlflow/dspy/save.py +393 -0
  69. mlflow/dspy/util.py +109 -0
  70. mlflow/dspy/wrapper.py +226 -0
  71. mlflow/entities/__init__.py +104 -0
  72. mlflow/entities/_mlflow_object.py +52 -0
  73. mlflow/entities/assessment.py +545 -0
  74. mlflow/entities/assessment_error.py +80 -0
  75. mlflow/entities/assessment_source.py +141 -0
  76. mlflow/entities/dataset.py +92 -0
  77. mlflow/entities/dataset_input.py +51 -0
  78. mlflow/entities/dataset_summary.py +62 -0
  79. mlflow/entities/document.py +48 -0
  80. mlflow/entities/experiment.py +109 -0
  81. mlflow/entities/experiment_tag.py +35 -0
  82. mlflow/entities/file_info.py +45 -0
  83. mlflow/entities/input_tag.py +35 -0
  84. mlflow/entities/lifecycle_stage.py +35 -0
  85. mlflow/entities/logged_model.py +228 -0
  86. mlflow/entities/logged_model_input.py +26 -0
  87. mlflow/entities/logged_model_output.py +32 -0
  88. mlflow/entities/logged_model_parameter.py +46 -0
  89. mlflow/entities/logged_model_status.py +74 -0
  90. mlflow/entities/logged_model_tag.py +33 -0
  91. mlflow/entities/metric.py +200 -0
  92. mlflow/entities/model_registry/__init__.py +29 -0
  93. mlflow/entities/model_registry/_model_registry_entity.py +13 -0
  94. mlflow/entities/model_registry/model_version.py +243 -0
  95. mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
  96. mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
  97. mlflow/entities/model_registry/model_version_search.py +25 -0
  98. mlflow/entities/model_registry/model_version_stages.py +25 -0
  99. mlflow/entities/model_registry/model_version_status.py +35 -0
  100. mlflow/entities/model_registry/model_version_tag.py +35 -0
  101. mlflow/entities/model_registry/prompt.py +73 -0
  102. mlflow/entities/model_registry/prompt_version.py +244 -0
  103. mlflow/entities/model_registry/registered_model.py +175 -0
  104. mlflow/entities/model_registry/registered_model_alias.py +35 -0
  105. mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
  106. mlflow/entities/model_registry/registered_model_search.py +25 -0
  107. mlflow/entities/model_registry/registered_model_tag.py +35 -0
  108. mlflow/entities/multipart_upload.py +74 -0
  109. mlflow/entities/param.py +49 -0
  110. mlflow/entities/run.py +97 -0
  111. mlflow/entities/run_data.py +84 -0
  112. mlflow/entities/run_info.py +188 -0
  113. mlflow/entities/run_inputs.py +59 -0
  114. mlflow/entities/run_outputs.py +43 -0
  115. mlflow/entities/run_status.py +41 -0
  116. mlflow/entities/run_tag.py +36 -0
  117. mlflow/entities/source_type.py +31 -0
  118. mlflow/entities/span.py +774 -0
  119. mlflow/entities/span_event.py +96 -0
  120. mlflow/entities/span_status.py +102 -0
  121. mlflow/entities/trace.py +317 -0
  122. mlflow/entities/trace_data.py +71 -0
  123. mlflow/entities/trace_info.py +220 -0
  124. mlflow/entities/trace_info_v2.py +162 -0
  125. mlflow/entities/trace_location.py +173 -0
  126. mlflow/entities/trace_state.py +39 -0
  127. mlflow/entities/trace_status.py +68 -0
  128. mlflow/entities/view_type.py +51 -0
  129. mlflow/environment_variables.py +866 -0
  130. mlflow/evaluation/__init__.py +16 -0
  131. mlflow/evaluation/assessment.py +369 -0
  132. mlflow/evaluation/evaluation.py +411 -0
  133. mlflow/evaluation/evaluation_tag.py +61 -0
  134. mlflow/evaluation/fluent.py +48 -0
  135. mlflow/evaluation/utils.py +201 -0
  136. mlflow/exceptions.py +213 -0
  137. mlflow/experiments.py +140 -0
  138. mlflow/gemini/__init__.py +81 -0
  139. mlflow/gemini/autolog.py +186 -0
  140. mlflow/gemini/chat.py +261 -0
  141. mlflow/genai/__init__.py +71 -0
  142. mlflow/genai/datasets/__init__.py +67 -0
  143. mlflow/genai/datasets/evaluation_dataset.py +131 -0
  144. mlflow/genai/evaluation/__init__.py +3 -0
  145. mlflow/genai/evaluation/base.py +411 -0
  146. mlflow/genai/evaluation/constant.py +23 -0
  147. mlflow/genai/evaluation/utils.py +244 -0
  148. mlflow/genai/judges/__init__.py +21 -0
  149. mlflow/genai/judges/databricks.py +404 -0
  150. mlflow/genai/label_schemas/__init__.py +153 -0
  151. mlflow/genai/label_schemas/label_schemas.py +209 -0
  152. mlflow/genai/labeling/__init__.py +159 -0
  153. mlflow/genai/labeling/labeling.py +250 -0
  154. mlflow/genai/optimize/__init__.py +13 -0
  155. mlflow/genai/optimize/base.py +198 -0
  156. mlflow/genai/optimize/optimizers/__init__.py +4 -0
  157. mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
  158. mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
  159. mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
  160. mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
  161. mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
  162. mlflow/genai/optimize/types.py +75 -0
  163. mlflow/genai/optimize/util.py +30 -0
  164. mlflow/genai/prompts/__init__.py +206 -0
  165. mlflow/genai/scheduled_scorers.py +431 -0
  166. mlflow/genai/scorers/__init__.py +26 -0
  167. mlflow/genai/scorers/base.py +492 -0
  168. mlflow/genai/scorers/builtin_scorers.py +765 -0
  169. mlflow/genai/scorers/scorer_utils.py +138 -0
  170. mlflow/genai/scorers/validation.py +165 -0
  171. mlflow/genai/utils/data_validation.py +146 -0
  172. mlflow/genai/utils/enum_utils.py +23 -0
  173. mlflow/genai/utils/trace_utils.py +211 -0
  174. mlflow/groq/__init__.py +42 -0
  175. mlflow/groq/_groq_autolog.py +74 -0
  176. mlflow/johnsnowlabs/__init__.py +888 -0
  177. mlflow/langchain/__init__.py +24 -0
  178. mlflow/langchain/api_request_parallel_processor.py +330 -0
  179. mlflow/langchain/autolog.py +147 -0
  180. mlflow/langchain/chat_agent_langgraph.py +340 -0
  181. mlflow/langchain/constant.py +1 -0
  182. mlflow/langchain/constants.py +1 -0
  183. mlflow/langchain/databricks_dependencies.py +444 -0
  184. mlflow/langchain/langchain_tracer.py +597 -0
  185. mlflow/langchain/model.py +919 -0
  186. mlflow/langchain/output_parsers.py +142 -0
  187. mlflow/langchain/retriever_chain.py +153 -0
  188. mlflow/langchain/runnables.py +527 -0
  189. mlflow/langchain/utils/chat.py +402 -0
  190. mlflow/langchain/utils/logging.py +671 -0
  191. mlflow/langchain/utils/serialization.py +36 -0
  192. mlflow/legacy_databricks_cli/__init__.py +0 -0
  193. mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
  194. mlflow/legacy_databricks_cli/configure/provider.py +482 -0
  195. mlflow/litellm/__init__.py +175 -0
  196. mlflow/llama_index/__init__.py +22 -0
  197. mlflow/llama_index/autolog.py +55 -0
  198. mlflow/llama_index/chat.py +43 -0
  199. mlflow/llama_index/constant.py +1 -0
  200. mlflow/llama_index/model.py +577 -0
  201. mlflow/llama_index/pyfunc_wrapper.py +332 -0
  202. mlflow/llama_index/serialize_objects.py +188 -0
  203. mlflow/llama_index/tracer.py +561 -0
  204. mlflow/metrics/__init__.py +479 -0
  205. mlflow/metrics/base.py +39 -0
  206. mlflow/metrics/genai/__init__.py +25 -0
  207. mlflow/metrics/genai/base.py +101 -0
  208. mlflow/metrics/genai/genai_metric.py +771 -0
  209. mlflow/metrics/genai/metric_definitions.py +450 -0
  210. mlflow/metrics/genai/model_utils.py +371 -0
  211. mlflow/metrics/genai/prompt_template.py +68 -0
  212. mlflow/metrics/genai/prompts/__init__.py +0 -0
  213. mlflow/metrics/genai/prompts/v1.py +422 -0
  214. mlflow/metrics/genai/utils.py +6 -0
  215. mlflow/metrics/metric_definitions.py +619 -0
  216. mlflow/mismatch.py +34 -0
  217. mlflow/mistral/__init__.py +34 -0
  218. mlflow/mistral/autolog.py +71 -0
  219. mlflow/mistral/chat.py +135 -0
  220. mlflow/ml_package_versions.py +452 -0
  221. mlflow/models/__init__.py +97 -0
  222. mlflow/models/auth_policy.py +83 -0
  223. mlflow/models/cli.py +354 -0
  224. mlflow/models/container/__init__.py +294 -0
  225. mlflow/models/container/scoring_server/__init__.py +0 -0
  226. mlflow/models/container/scoring_server/nginx.conf +39 -0
  227. mlflow/models/dependencies_schemas.py +287 -0
  228. mlflow/models/display_utils.py +158 -0
  229. mlflow/models/docker_utils.py +211 -0
  230. mlflow/models/evaluation/__init__.py +23 -0
  231. mlflow/models/evaluation/_shap_patch.py +64 -0
  232. mlflow/models/evaluation/artifacts.py +194 -0
  233. mlflow/models/evaluation/base.py +1811 -0
  234. mlflow/models/evaluation/calibration_curve.py +109 -0
  235. mlflow/models/evaluation/default_evaluator.py +996 -0
  236. mlflow/models/evaluation/deprecated.py +23 -0
  237. mlflow/models/evaluation/evaluator_registry.py +80 -0
  238. mlflow/models/evaluation/evaluators/classifier.py +704 -0
  239. mlflow/models/evaluation/evaluators/default.py +233 -0
  240. mlflow/models/evaluation/evaluators/regressor.py +96 -0
  241. mlflow/models/evaluation/evaluators/shap.py +296 -0
  242. mlflow/models/evaluation/lift_curve.py +178 -0
  243. mlflow/models/evaluation/utils/metric.py +123 -0
  244. mlflow/models/evaluation/utils/trace.py +179 -0
  245. mlflow/models/evaluation/validation.py +434 -0
  246. mlflow/models/flavor_backend.py +93 -0
  247. mlflow/models/flavor_backend_registry.py +53 -0
  248. mlflow/models/model.py +1639 -0
  249. mlflow/models/model_config.py +150 -0
  250. mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
  251. mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
  252. mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
  253. mlflow/models/python_api.py +369 -0
  254. mlflow/models/rag_signatures.py +128 -0
  255. mlflow/models/resources.py +321 -0
  256. mlflow/models/signature.py +662 -0
  257. mlflow/models/utils.py +2054 -0
  258. mlflow/models/wheeled_model.py +280 -0
  259. mlflow/openai/__init__.py +57 -0
  260. mlflow/openai/_agent_tracer.py +364 -0
  261. mlflow/openai/api_request_parallel_processor.py +131 -0
  262. mlflow/openai/autolog.py +509 -0
  263. mlflow/openai/constant.py +1 -0
  264. mlflow/openai/model.py +824 -0
  265. mlflow/openai/utils/chat_schema.py +367 -0
  266. mlflow/optuna/__init__.py +3 -0
  267. mlflow/optuna/storage.py +646 -0
  268. mlflow/plugins/__init__.py +72 -0
  269. mlflow/plugins/base.py +358 -0
  270. mlflow/plugins/builtin/__init__.py +24 -0
  271. mlflow/plugins/builtin/pytorch_plugin.py +150 -0
  272. mlflow/plugins/builtin/sklearn_plugin.py +158 -0
  273. mlflow/plugins/builtin/transformers_plugin.py +187 -0
  274. mlflow/plugins/cli.py +321 -0
  275. mlflow/plugins/discovery.py +340 -0
  276. mlflow/plugins/manager.py +465 -0
  277. mlflow/plugins/registry.py +316 -0
  278. mlflow/plugins/templates/framework_plugin_template.py +329 -0
  279. mlflow/prompt/constants.py +20 -0
  280. mlflow/prompt/promptlab_model.py +197 -0
  281. mlflow/prompt/registry_utils.py +248 -0
  282. mlflow/promptflow/__init__.py +495 -0
  283. mlflow/protos/__init__.py +0 -0
  284. mlflow/protos/assessments_pb2.py +174 -0
  285. mlflow/protos/databricks_artifacts_pb2.py +489 -0
  286. mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
  287. mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
  288. mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
  289. mlflow/protos/databricks_pb2.py +267 -0
  290. mlflow/protos/databricks_trace_server_pb2.py +374 -0
  291. mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
  292. mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
  293. mlflow/protos/facet_feature_statistics_pb2.py +296 -0
  294. mlflow/protos/internal_pb2.py +77 -0
  295. mlflow/protos/mlflow_artifacts_pb2.py +336 -0
  296. mlflow/protos/model_registry_pb2.py +1073 -0
  297. mlflow/protos/scalapb/__init__.py +0 -0
  298. mlflow/protos/scalapb/scalapb_pb2.py +104 -0
  299. mlflow/protos/service_pb2.py +2600 -0
  300. mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
  301. mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
  302. mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
  303. mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
  304. mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
  305. mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
  306. mlflow/py.typed +0 -0
  307. mlflow/pydantic_ai/__init__.py +57 -0
  308. mlflow/pydantic_ai/autolog.py +173 -0
  309. mlflow/pyfunc/__init__.py +3844 -0
  310. mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
  311. mlflow/pyfunc/backend.py +523 -0
  312. mlflow/pyfunc/context.py +78 -0
  313. mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
  314. mlflow/pyfunc/loaders/__init__.py +7 -0
  315. mlflow/pyfunc/loaders/chat_agent.py +117 -0
  316. mlflow/pyfunc/loaders/chat_model.py +125 -0
  317. mlflow/pyfunc/loaders/code_model.py +31 -0
  318. mlflow/pyfunc/loaders/responses_agent.py +112 -0
  319. mlflow/pyfunc/mlserver.py +46 -0
  320. mlflow/pyfunc/model.py +1473 -0
  321. mlflow/pyfunc/scoring_server/__init__.py +604 -0
  322. mlflow/pyfunc/scoring_server/app.py +7 -0
  323. mlflow/pyfunc/scoring_server/client.py +146 -0
  324. mlflow/pyfunc/spark_model_cache.py +48 -0
  325. mlflow/pyfunc/stdin_server.py +44 -0
  326. mlflow/pyfunc/utils/__init__.py +3 -0
  327. mlflow/pyfunc/utils/data_validation.py +224 -0
  328. mlflow/pyfunc/utils/environment.py +22 -0
  329. mlflow/pyfunc/utils/input_converter.py +47 -0
  330. mlflow/pyfunc/utils/serving_data_parser.py +11 -0
  331. mlflow/pytorch/__init__.py +1171 -0
  332. mlflow/pytorch/_lightning_autolog.py +580 -0
  333. mlflow/pytorch/_pytorch_autolog.py +50 -0
  334. mlflow/pytorch/pickle_module.py +35 -0
  335. mlflow/rfunc/__init__.py +42 -0
  336. mlflow/rfunc/backend.py +134 -0
  337. mlflow/runs.py +89 -0
  338. mlflow/server/__init__.py +302 -0
  339. mlflow/server/auth/__init__.py +1224 -0
  340. mlflow/server/auth/__main__.py +4 -0
  341. mlflow/server/auth/basic_auth.ini +6 -0
  342. mlflow/server/auth/cli.py +11 -0
  343. mlflow/server/auth/client.py +537 -0
  344. mlflow/server/auth/config.py +34 -0
  345. mlflow/server/auth/db/__init__.py +0 -0
  346. mlflow/server/auth/db/cli.py +18 -0
  347. mlflow/server/auth/db/migrations/__init__.py +0 -0
  348. mlflow/server/auth/db/migrations/alembic.ini +110 -0
  349. mlflow/server/auth/db/migrations/env.py +76 -0
  350. mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
  351. mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
  352. mlflow/server/auth/db/models.py +67 -0
  353. mlflow/server/auth/db/utils.py +37 -0
  354. mlflow/server/auth/entities.py +165 -0
  355. mlflow/server/auth/logo.py +14 -0
  356. mlflow/server/auth/permissions.py +65 -0
  357. mlflow/server/auth/routes.py +18 -0
  358. mlflow/server/auth/sqlalchemy_store.py +263 -0
  359. mlflow/server/graphql/__init__.py +0 -0
  360. mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
  361. mlflow/server/graphql/graphql_custom_scalars.py +24 -0
  362. mlflow/server/graphql/graphql_errors.py +15 -0
  363. mlflow/server/graphql/graphql_no_batching.py +89 -0
  364. mlflow/server/graphql/graphql_schema_extensions.py +74 -0
  365. mlflow/server/handlers.py +3217 -0
  366. mlflow/server/prometheus_exporter.py +17 -0
  367. mlflow/server/validation.py +30 -0
  368. mlflow/shap/__init__.py +691 -0
  369. mlflow/sklearn/__init__.py +1994 -0
  370. mlflow/sklearn/utils.py +1041 -0
  371. mlflow/smolagents/__init__.py +66 -0
  372. mlflow/smolagents/autolog.py +139 -0
  373. mlflow/smolagents/chat.py +29 -0
  374. mlflow/store/__init__.py +10 -0
  375. mlflow/store/_unity_catalog/__init__.py +1 -0
  376. mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
  377. mlflow/store/_unity_catalog/lineage/constants.py +2 -0
  378. mlflow/store/_unity_catalog/registry/__init__.py +6 -0
  379. mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
  380. mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
  381. mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
  382. mlflow/store/_unity_catalog/registry/utils.py +121 -0
  383. mlflow/store/artifact/__init__.py +0 -0
  384. mlflow/store/artifact/artifact_repo.py +472 -0
  385. mlflow/store/artifact/artifact_repository_registry.py +154 -0
  386. mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
  387. mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
  388. mlflow/store/artifact/cli.py +141 -0
  389. mlflow/store/artifact/cloud_artifact_repo.py +332 -0
  390. mlflow/store/artifact/databricks_artifact_repo.py +729 -0
  391. mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
  392. mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
  393. mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
  394. mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
  395. mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
  396. mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
  397. mlflow/store/artifact/ftp_artifact_repo.py +132 -0
  398. mlflow/store/artifact/gcs_artifact_repo.py +296 -0
  399. mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
  400. mlflow/store/artifact/http_artifact_repo.py +218 -0
  401. mlflow/store/artifact/local_artifact_repo.py +142 -0
  402. mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
  403. mlflow/store/artifact/models_artifact_repo.py +259 -0
  404. mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
  405. mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
  406. mlflow/store/artifact/r2_artifact_repo.py +70 -0
  407. mlflow/store/artifact/runs_artifact_repo.py +265 -0
  408. mlflow/store/artifact/s3_artifact_repo.py +330 -0
  409. mlflow/store/artifact/sftp_artifact_repo.py +141 -0
  410. mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
  411. mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
  412. mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
  413. mlflow/store/artifact/utils/__init__.py +0 -0
  414. mlflow/store/artifact/utils/models.py +148 -0
  415. mlflow/store/db/__init__.py +0 -0
  416. mlflow/store/db/base_sql_model.py +3 -0
  417. mlflow/store/db/db_types.py +10 -0
  418. mlflow/store/db/utils.py +314 -0
  419. mlflow/store/db_migrations/__init__.py +0 -0
  420. mlflow/store/db_migrations/alembic.ini +74 -0
  421. mlflow/store/db_migrations/env.py +84 -0
  422. mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
  423. mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
  424. mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
  425. mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
  426. mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
  427. mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
  428. mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
  429. mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
  430. mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
  431. mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
  432. mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
  433. mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
  434. mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
  435. mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
  436. mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
  437. mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
  438. mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
  439. mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
  440. mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
  441. mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
  442. mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
  443. mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
  444. mlflow/store/db_migrations/versions/__init__.py +0 -0
  445. mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
  446. mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
  447. mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
  448. mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
  449. mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
  450. mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
  451. mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
  452. mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
  453. mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
  454. mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
  455. mlflow/store/entities/__init__.py +3 -0
  456. mlflow/store/entities/paged_list.py +18 -0
  457. mlflow/store/model_registry/__init__.py +10 -0
  458. mlflow/store/model_registry/abstract_store.py +1081 -0
  459. mlflow/store/model_registry/base_rest_store.py +44 -0
  460. mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
  461. mlflow/store/model_registry/dbmodels/__init__.py +0 -0
  462. mlflow/store/model_registry/dbmodels/models.py +206 -0
  463. mlflow/store/model_registry/file_store.py +1091 -0
  464. mlflow/store/model_registry/rest_store.py +481 -0
  465. mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
  466. mlflow/store/tracking/__init__.py +23 -0
  467. mlflow/store/tracking/abstract_store.py +816 -0
  468. mlflow/store/tracking/dbmodels/__init__.py +0 -0
  469. mlflow/store/tracking/dbmodels/initial_models.py +243 -0
  470. mlflow/store/tracking/dbmodels/models.py +1073 -0
  471. mlflow/store/tracking/file_store.py +2438 -0
  472. mlflow/store/tracking/postgres_managed_identity.py +146 -0
  473. mlflow/store/tracking/rest_store.py +1131 -0
  474. mlflow/store/tracking/sqlalchemy_store.py +2785 -0
  475. mlflow/system_metrics/__init__.py +61 -0
  476. mlflow/system_metrics/metrics/__init__.py +0 -0
  477. mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
  478. mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
  479. mlflow/system_metrics/metrics/disk_monitor.py +21 -0
  480. mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
  481. mlflow/system_metrics/metrics/network_monitor.py +34 -0
  482. mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
  483. mlflow/system_metrics/system_metrics_monitor.py +198 -0
  484. mlflow/tracing/__init__.py +16 -0
  485. mlflow/tracing/assessment.py +356 -0
  486. mlflow/tracing/client.py +531 -0
  487. mlflow/tracing/config.py +125 -0
  488. mlflow/tracing/constant.py +105 -0
  489. mlflow/tracing/destination.py +81 -0
  490. mlflow/tracing/display/__init__.py +40 -0
  491. mlflow/tracing/display/display_handler.py +196 -0
  492. mlflow/tracing/export/async_export_queue.py +186 -0
  493. mlflow/tracing/export/inference_table.py +138 -0
  494. mlflow/tracing/export/mlflow_v3.py +137 -0
  495. mlflow/tracing/export/utils.py +70 -0
  496. mlflow/tracing/fluent.py +1417 -0
  497. mlflow/tracing/processor/base_mlflow.py +199 -0
  498. mlflow/tracing/processor/inference_table.py +175 -0
  499. mlflow/tracing/processor/mlflow_v3.py +47 -0
  500. mlflow/tracing/processor/otel.py +73 -0
  501. mlflow/tracing/provider.py +487 -0
  502. mlflow/tracing/trace_manager.py +200 -0
  503. mlflow/tracing/utils/__init__.py +616 -0
  504. mlflow/tracing/utils/artifact_utils.py +28 -0
  505. mlflow/tracing/utils/copy.py +55 -0
  506. mlflow/tracing/utils/environment.py +55 -0
  507. mlflow/tracing/utils/exception.py +21 -0
  508. mlflow/tracing/utils/once.py +35 -0
  509. mlflow/tracing/utils/otlp.py +63 -0
  510. mlflow/tracing/utils/processor.py +54 -0
  511. mlflow/tracing/utils/search.py +292 -0
  512. mlflow/tracing/utils/timeout.py +250 -0
  513. mlflow/tracing/utils/token.py +19 -0
  514. mlflow/tracing/utils/truncation.py +124 -0
  515. mlflow/tracing/utils/warning.py +76 -0
  516. mlflow/tracking/__init__.py +39 -0
  517. mlflow/tracking/_model_registry/__init__.py +1 -0
  518. mlflow/tracking/_model_registry/client.py +764 -0
  519. mlflow/tracking/_model_registry/fluent.py +853 -0
  520. mlflow/tracking/_model_registry/registry.py +67 -0
  521. mlflow/tracking/_model_registry/utils.py +251 -0
  522. mlflow/tracking/_tracking_service/__init__.py +0 -0
  523. mlflow/tracking/_tracking_service/client.py +883 -0
  524. mlflow/tracking/_tracking_service/registry.py +56 -0
  525. mlflow/tracking/_tracking_service/utils.py +275 -0
  526. mlflow/tracking/artifact_utils.py +179 -0
  527. mlflow/tracking/client.py +5900 -0
  528. mlflow/tracking/context/__init__.py +0 -0
  529. mlflow/tracking/context/abstract_context.py +35 -0
  530. mlflow/tracking/context/databricks_cluster_context.py +15 -0
  531. mlflow/tracking/context/databricks_command_context.py +15 -0
  532. mlflow/tracking/context/databricks_job_context.py +49 -0
  533. mlflow/tracking/context/databricks_notebook_context.py +41 -0
  534. mlflow/tracking/context/databricks_repo_context.py +43 -0
  535. mlflow/tracking/context/default_context.py +51 -0
  536. mlflow/tracking/context/git_context.py +32 -0
  537. mlflow/tracking/context/registry.py +98 -0
  538. mlflow/tracking/context/system_environment_context.py +15 -0
  539. mlflow/tracking/default_experiment/__init__.py +1 -0
  540. mlflow/tracking/default_experiment/abstract_context.py +43 -0
  541. mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
  542. mlflow/tracking/default_experiment/registry.py +75 -0
  543. mlflow/tracking/fluent.py +3595 -0
  544. mlflow/tracking/metric_value_conversion_utils.py +93 -0
  545. mlflow/tracking/multimedia.py +206 -0
  546. mlflow/tracking/registry.py +86 -0
  547. mlflow/tracking/request_auth/__init__.py +0 -0
  548. mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
  549. mlflow/tracking/request_auth/registry.py +60 -0
  550. mlflow/tracking/request_header/__init__.py +0 -0
  551. mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
  552. mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
  553. mlflow/tracking/request_header/default_request_header_provider.py +17 -0
  554. mlflow/tracking/request_header/registry.py +79 -0
  555. mlflow/transformers/__init__.py +2982 -0
  556. mlflow/transformers/flavor_config.py +258 -0
  557. mlflow/transformers/hub_utils.py +83 -0
  558. mlflow/transformers/llm_inference_utils.py +468 -0
  559. mlflow/transformers/model_io.py +301 -0
  560. mlflow/transformers/peft.py +51 -0
  561. mlflow/transformers/signature.py +183 -0
  562. mlflow/transformers/torch_utils.py +55 -0
  563. mlflow/types/__init__.py +21 -0
  564. mlflow/types/agent.py +270 -0
  565. mlflow/types/chat.py +240 -0
  566. mlflow/types/llm.py +935 -0
  567. mlflow/types/responses.py +139 -0
  568. mlflow/types/responses_helpers.py +416 -0
  569. mlflow/types/schema.py +1505 -0
  570. mlflow/types/type_hints.py +647 -0
  571. mlflow/types/utils.py +753 -0
  572. mlflow/utils/__init__.py +283 -0
  573. mlflow/utils/_capture_modules.py +256 -0
  574. mlflow/utils/_capture_transformers_modules.py +75 -0
  575. mlflow/utils/_spark_utils.py +201 -0
  576. mlflow/utils/_unity_catalog_oss_utils.py +97 -0
  577. mlflow/utils/_unity_catalog_utils.py +479 -0
  578. mlflow/utils/annotations.py +218 -0
  579. mlflow/utils/arguments_utils.py +16 -0
  580. mlflow/utils/async_logging/__init__.py +1 -0
  581. mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
  582. mlflow/utils/async_logging/async_logging_queue.py +366 -0
  583. mlflow/utils/async_logging/run_artifact.py +38 -0
  584. mlflow/utils/async_logging/run_batch.py +58 -0
  585. mlflow/utils/async_logging/run_operations.py +49 -0
  586. mlflow/utils/autologging_utils/__init__.py +737 -0
  587. mlflow/utils/autologging_utils/client.py +432 -0
  588. mlflow/utils/autologging_utils/config.py +33 -0
  589. mlflow/utils/autologging_utils/events.py +294 -0
  590. mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
  591. mlflow/utils/autologging_utils/metrics_queue.py +71 -0
  592. mlflow/utils/autologging_utils/safety.py +1104 -0
  593. mlflow/utils/autologging_utils/versioning.py +95 -0
  594. mlflow/utils/checkpoint_utils.py +206 -0
  595. mlflow/utils/class_utils.py +6 -0
  596. mlflow/utils/cli_args.py +257 -0
  597. mlflow/utils/conda.py +354 -0
  598. mlflow/utils/credentials.py +231 -0
  599. mlflow/utils/data_utils.py +17 -0
  600. mlflow/utils/databricks_utils.py +1436 -0
  601. mlflow/utils/docstring_utils.py +477 -0
  602. mlflow/utils/doctor.py +133 -0
  603. mlflow/utils/download_cloud_file_chunk.py +43 -0
  604. mlflow/utils/env_manager.py +16 -0
  605. mlflow/utils/env_pack.py +131 -0
  606. mlflow/utils/environment.py +1009 -0
  607. mlflow/utils/exception_utils.py +14 -0
  608. mlflow/utils/file_utils.py +978 -0
  609. mlflow/utils/git_utils.py +77 -0
  610. mlflow/utils/gorilla.py +797 -0
  611. mlflow/utils/import_hooks/__init__.py +363 -0
  612. mlflow/utils/lazy_load.py +51 -0
  613. mlflow/utils/logging_utils.py +168 -0
  614. mlflow/utils/mime_type_utils.py +58 -0
  615. mlflow/utils/mlflow_tags.py +103 -0
  616. mlflow/utils/model_utils.py +486 -0
  617. mlflow/utils/name_utils.py +346 -0
  618. mlflow/utils/nfs_on_spark.py +62 -0
  619. mlflow/utils/openai_utils.py +164 -0
  620. mlflow/utils/os.py +12 -0
  621. mlflow/utils/oss_registry_utils.py +29 -0
  622. mlflow/utils/plugins.py +17 -0
  623. mlflow/utils/process.py +182 -0
  624. mlflow/utils/promptlab_utils.py +146 -0
  625. mlflow/utils/proto_json_utils.py +743 -0
  626. mlflow/utils/pydantic_utils.py +54 -0
  627. mlflow/utils/request_utils.py +279 -0
  628. mlflow/utils/requirements_utils.py +704 -0
  629. mlflow/utils/rest_utils.py +673 -0
  630. mlflow/utils/search_logged_model_utils.py +127 -0
  631. mlflow/utils/search_utils.py +2111 -0
  632. mlflow/utils/secure_loading.py +221 -0
  633. mlflow/utils/security_validation.py +384 -0
  634. mlflow/utils/server_cli_utils.py +61 -0
  635. mlflow/utils/spark_utils.py +15 -0
  636. mlflow/utils/string_utils.py +138 -0
  637. mlflow/utils/thread_utils.py +63 -0
  638. mlflow/utils/time.py +54 -0
  639. mlflow/utils/timeout.py +42 -0
  640. mlflow/utils/uri.py +572 -0
  641. mlflow/utils/validation.py +662 -0
  642. mlflow/utils/virtualenv.py +458 -0
  643. mlflow/utils/warnings_utils.py +25 -0
  644. mlflow/utils/yaml_utils.py +179 -0
  645. mlflow/version.py +24 -0
@@ -0,0 +1,3217 @@
1
+ # Define all the service endpoint handlers here.
2
+ import bisect
3
+ import io
4
+ import json
5
+ import logging
6
+ import os
7
+ import pathlib
8
+ import posixpath
9
+ import re
10
+ import tempfile
11
+ import time
12
+ import urllib
13
+ from functools import wraps
14
+ from typing import Optional
15
+
16
+ import requests
17
+ from flask import Response, current_app, jsonify, request, send_file
18
+ from google.protobuf import descriptor
19
+ from google.protobuf.json_format import ParseError
20
+
21
+ from mlflow.entities import (
22
+ DatasetInput,
23
+ ExperimentTag,
24
+ FileInfo,
25
+ Metric,
26
+ Param,
27
+ RunTag,
28
+ ViewType,
29
+ )
30
+ from mlflow.entities.logged_model import LoggedModel
31
+ from mlflow.entities.logged_model_input import LoggedModelInput
32
+ from mlflow.entities.logged_model_output import LoggedModelOutput
33
+ from mlflow.entities.logged_model_parameter import LoggedModelParameter
34
+ from mlflow.entities.logged_model_status import LoggedModelStatus
35
+ from mlflow.entities.logged_model_tag import LoggedModelTag
36
+ from mlflow.entities.model_registry import ModelVersionTag, RegisteredModelTag
37
+ from mlflow.entities.model_registry.prompt_version import IS_PROMPT_TAG_KEY
38
+ from mlflow.entities.multipart_upload import MultipartUploadPart
39
+ from mlflow.entities.trace_info import TraceInfo
40
+ from mlflow.entities.trace_info_v2 import TraceInfoV2
41
+ from mlflow.entities.trace_status import TraceStatus
42
+ from mlflow.environment_variables import (
43
+ MLFLOW_CREATE_MODEL_VERSION_SOURCE_VALIDATION_REGEX,
44
+ MLFLOW_DEPLOYMENTS_TARGET,
45
+ )
46
+ from mlflow.exceptions import MlflowException, _UnsupportedMultipartUploadException
47
+ from mlflow.models import Model
48
+ from mlflow.protos import databricks_pb2
49
+ from mlflow.protos.databricks_pb2 import (
50
+ BAD_REQUEST,
51
+ INVALID_PARAMETER_VALUE,
52
+ RESOURCE_DOES_NOT_EXIST,
53
+ )
54
+ from mlflow.protos.mlflow_artifacts_pb2 import (
55
+ AbortMultipartUpload,
56
+ CompleteMultipartUpload,
57
+ CreateMultipartUpload,
58
+ DeleteArtifact,
59
+ DownloadArtifact,
60
+ MlflowArtifactsService,
61
+ UploadArtifact,
62
+ )
63
+ from mlflow.protos.mlflow_artifacts_pb2 import (
64
+ ListArtifacts as ListArtifactsMlflowArtifacts,
65
+ )
66
+ from mlflow.protos.model_registry_pb2 import (
67
+ CreateModelVersion,
68
+ CreateRegisteredModel,
69
+ DeleteModelVersion,
70
+ DeleteModelVersionTag,
71
+ DeleteRegisteredModel,
72
+ DeleteRegisteredModelAlias,
73
+ DeleteRegisteredModelTag,
74
+ GetLatestVersions,
75
+ GetModelVersion,
76
+ GetModelVersionByAlias,
77
+ GetModelVersionDownloadUri,
78
+ GetRegisteredModel,
79
+ ModelRegistryService,
80
+ RenameRegisteredModel,
81
+ SearchModelVersions,
82
+ SearchRegisteredModels,
83
+ SetModelVersionTag,
84
+ SetRegisteredModelAlias,
85
+ SetRegisteredModelTag,
86
+ TransitionModelVersionStage,
87
+ UpdateModelVersion,
88
+ UpdateRegisteredModel,
89
+ )
90
+ from mlflow.protos.service_pb2 import (
91
+ CreateExperiment,
92
+ CreateLoggedModel,
93
+ CreateRun,
94
+ DeleteExperiment,
95
+ DeleteLoggedModel,
96
+ DeleteLoggedModelTag,
97
+ DeleteRun,
98
+ DeleteTag,
99
+ DeleteTraces,
100
+ DeleteTraceTag,
101
+ EndTrace,
102
+ FinalizeLoggedModel,
103
+ GetExperiment,
104
+ GetExperimentByName,
105
+ GetLoggedModel,
106
+ GetMetricHistory,
107
+ GetMetricHistoryBulkInterval,
108
+ GetRun,
109
+ GetTraceInfo,
110
+ GetTraceInfoV3,
111
+ ListArtifacts,
112
+ ListLoggedModelArtifacts,
113
+ LogBatch,
114
+ LogInputs,
115
+ LogLoggedModelParamsRequest,
116
+ LogMetric,
117
+ LogModel,
118
+ LogOutputs,
119
+ LogParam,
120
+ MlflowService,
121
+ RestoreExperiment,
122
+ RestoreRun,
123
+ SearchDatasets,
124
+ SearchExperiments,
125
+ SearchLoggedModels,
126
+ SearchRuns,
127
+ SearchTraces,
128
+ SearchTracesV3,
129
+ SetExperimentTag,
130
+ SetLoggedModelTags,
131
+ SetTag,
132
+ SetTraceTag,
133
+ StartTrace,
134
+ StartTraceV3,
135
+ UpdateExperiment,
136
+ UpdateRun,
137
+ )
138
+ from mlflow.protos.service_pb2 import Trace as ProtoTrace
139
+ from mlflow.server.validation import _validate_content_type
140
+ from mlflow.store.artifact.artifact_repo import MultipartUploadMixin
141
+ from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository
142
+ from mlflow.store.db.db_types import DATABASE_ENGINES
143
+ from mlflow.tracing.utils.artifact_utils import (
144
+ TRACE_DATA_FILE_NAME,
145
+ get_artifact_uri_for_trace,
146
+ )
147
+ from mlflow.tracking._model_registry import utils as registry_utils
148
+ from mlflow.tracking._model_registry.registry import ModelRegistryStoreRegistry
149
+ from mlflow.tracking._tracking_service import utils
150
+ from mlflow.tracking._tracking_service.registry import TrackingStoreRegistry
151
+ from mlflow.tracking.registry import UnsupportedModelRegistryStoreURIException
152
+ from mlflow.utils.file_utils import local_file_uri_to_path
153
+ from mlflow.utils.mime_type_utils import _guess_mime_type
154
+ from mlflow.utils.promptlab_utils import _create_promptlab_run_impl
155
+ from mlflow.utils.proto_json_utils import message_to_json, parse_dict
156
+ from mlflow.utils.security_validation import InputValidator, SecurityValidationError
157
+ from mlflow.utils.string_utils import is_string_type
158
+ from mlflow.utils.uri import is_local_uri, validate_path_is_safe, validate_query_string
159
+ from mlflow.utils.validation import (
160
+ _validate_batch_log_api_req,
161
+ invalid_value,
162
+ missing_value,
163
+ )
164
+
165
_logger = logging.getLogger(__name__)
# Lazily-initialized module-level singletons, populated on first use by
# ``_get_tracking_store``, ``_get_model_registry_store`` and
# ``_get_artifact_repo_mlflow_artifacts`` respectively.
_tracking_store = None
_model_registry_store = None
_artifact_repo = None
# Name of an environment variable; presumably holds a URL prefix for static assets
# (consumer not visible in this section) — TODO confirm.
STATIC_PREFIX_ENV_VAR = "_MLFLOW_STATIC_PREFIX"
# Result-size limits; presumably enforced by the metric-history endpoints defined later in
# this file (not visible in this section) — verify against their handlers.
MAX_RUNS_GET_METRIC_HISTORY_BULK = 100
MAX_RESULTS_PER_RUN = 2500
MAX_RESULTS_GET_METRIC_HISTORY = 25000
173
+
174
+
175
class TrackingStoreRegistryWrapper(TrackingStoreRegistry):
    """
    Tracking store registry used by the server.

    The empty scheme and ``file`` map to the file-based store, every scheme listed in
    ``DATABASE_ENGINES`` maps to the SQLAlchemy store, and ``register_entrypoints()`` is
    called last (presumably to pick up plugin-provided stores).
    """

    def __init__(self):
        super().__init__()
        for file_scheme in ("", "file"):
            self.register(file_scheme, self._get_file_store)
        for db_scheme in DATABASE_ENGINES:
            self.register(db_scheme, self._get_sqlalchemy_store)
        self.register_entrypoints()

    @classmethod
    def _get_file_store(cls, store_uri, artifact_uri):
        """Build a file-backed tracking store; the import is deferred until needed."""
        from mlflow.store.tracking.file_store import FileStore

        return FileStore(store_uri, artifact_uri)

    @classmethod
    def _get_sqlalchemy_store(cls, store_uri, artifact_uri):
        """Build a SQLAlchemy-backed tracking store; the import is deferred until needed."""
        from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore

        return SqlAlchemyStore(store_uri, artifact_uri)
195
+
196
+
197
+
198
class ModelRegistryStoreRegistryWrapper(ModelRegistryStoreRegistry):
    """
    Model registry store registry used by the server.

    The empty scheme and ``file`` map to the file-based registry store, every scheme in
    ``DATABASE_ENGINES`` maps to the SQLAlchemy registry store, and
    ``register_entrypoints()`` is called last (presumably for plugin stores).
    """

    def __init__(self):
        super().__init__()
        for file_scheme in ("", "file"):
            self.register(file_scheme, self._get_file_store)
        for db_scheme in DATABASE_ENGINES:
            self.register(db_scheme, self._get_sqlalchemy_store)
        self.register_entrypoints()

    @classmethod
    def _get_file_store(cls, store_uri):
        """Build a file-backed model registry store; the import is deferred until needed."""
        from mlflow.store.model_registry.file_store import FileStore

        return FileStore(store_uri)

    @classmethod
    def _get_sqlalchemy_store(cls, store_uri):
        """Build a SQLAlchemy-backed model registry store; the import is deferred."""
        from mlflow.store.model_registry.sqlalchemy_store import SqlAlchemyStore

        return SqlAlchemyStore(store_uri)
218
+
219
+
220
# Registries instantiated once at import time; ``_get_tracking_store`` and
# ``_get_model_registry_store`` resolve store URIs through them.
_tracking_store_registry = TrackingStoreRegistryWrapper()
_model_registry_store_registry = ModelRegistryStoreRegistryWrapper()
222
+
223
+
224
def _get_artifact_repo_mlflow_artifacts():
    """
    Return the artifact repository for the ``--artifacts-destination`` of ``mlflow server``.

    On the first call the repository is built from the ``ARTIFACTS_DESTINATION_ENV_VAR``
    environment variable (a ``KeyError`` is raised if it is unset). The result is cached
    in the module-level ``_artifact_repo`` and reused afterwards.
    """
    from mlflow.server import ARTIFACTS_DESTINATION_ENV_VAR

    global _artifact_repo
    if _artifact_repo is not None:
        return _artifact_repo
    destination_uri = os.environ[ARTIFACTS_DESTINATION_ENV_VAR]
    _artifact_repo = get_artifact_repository(destination_uri)
    return _artifact_repo
235
+
236
+
237
def _get_trace_artifact_repo(trace_info: TraceInfo):
    """
    Resolve the artifact repository for fetching data for the given trace.

    If the trace's artifact URI is a proxied location that this server can serve
    (``http``/``https``/``mlflow-artifacts`` with artifact serving enabled), it is rewritten
    to the matching path under the configured artifacts destination first.

    Args:
        trace_info: The trace info object containing metadata about the trace.

    Returns:
        An artifact repository rooted at the trace's (resolved) artifact URI.

    Raises:
        MlflowException: With ``BAD_REQUEST`` if a proxied trace URI has no sub-path.
    """
    artifact_uri = get_artifact_uri_for_trace(trace_info)

    if _is_servable_proxied_run_artifact_root(artifact_uri):
        from mlflow.server import ARTIFACTS_DESTINATION_ENV_VAR

        path = _get_proxied_run_artifact_destination_path(artifact_uri)
        if not path:
            # The two fragments are implicitly concatenated into one message; they must not
            # be separated by a comma, which would pass the second fragment as an extra
            # positional argument alongside ``error_code=``.
            raise MlflowException(
                f"Failed to resolve the proxied run artifact URI: {artifact_uri}. "
                "Trace artifact URI must contain subpath to the trace data directory.",
                error_code=BAD_REQUEST,
            )
        root = os.environ[ARTIFACTS_DESTINATION_ENV_VAR]
        # Unlike the run artifact repo, this one is not cached globally: it has to be created
        # with the full trace artifact URI including the request_id,
        # e.g. s3://<experiment_id>/traces/<request_id>
        artifact_uri = posixpath.join(root, path)

    return get_artifact_repository(artifact_uri)
268
+
269
+
270
def _is_serving_proxied_artifacts():
    """
    Report whether this server proxies artifact upload / download / list operations.

    Returns:
        True when the ``SERVE_ARTIFACTS_ENV_VAR`` environment variable equals ``"true"``
        exactly (as enabled by ``--serve-artifacts``); False otherwise, including when the
        variable is unset.
    """
    from mlflow.server import SERVE_ARTIFACTS_ENV_VAR

    serve_flag = os.environ.get(SERVE_ARTIFACTS_ENV_VAR, "false")
    return serve_flag == "true"
280
+
281
+
282
+ def _is_servable_proxied_run_artifact_root(run_artifact_root):
283
+ """
284
+ Determines whether or not the following are true:
285
+
286
+ - The specified Run artifact root is a proxied artifact root (i.e. an artifact root with scheme
287
+ ``http``, ``https``, or ``mlflow-artifacts``).
288
+
289
+ - The MLflow server is capable of resolving and accessing the underlying storage location
290
+ corresponding to the proxied artifact root, allowing it to fulfill artifact list and
291
+ download requests by using this storage location directly.
292
+
293
+ Args:
294
+ run_artifact_root: The Run artifact root location (URI).
295
+
296
+ Returns:
297
+ True if the specified Run artifact root refers to proxied artifacts that can be
298
+ served by this MLflow server (i.e. the server has access to the destination and
299
+ can respond to list and download requests for the artifact). False otherwise.
300
+ """
301
+ parsed_run_artifact_root = urllib.parse.urlparse(run_artifact_root)
302
+ # NB: If the run artifact root is a proxied artifact root (has scheme `http`, `https`, or
303
+ # `mlflow-artifacts`) *and* the MLflow server is configured to serve artifacts, the MLflow
304
+ # server always assumes that it has access to the underlying storage location for the proxied
305
+ # artifacts. This may not always be accurate. For example:
306
+ #
307
+ # An organization may initially use the MLflow server to serve Tracking API requests and proxy
308
+ # access to artifacts stored in Location A (via `mlflow server --serve-artifacts`). Then, for
309
+ # scalability and / or security purposes, the organization may decide to store artifacts in a
310
+ # new location B and set up a separate server (e.g. `mlflow server --artifacts-only`) to proxy
311
+ # access to artifacts stored in Location B.
312
+ #
313
+ # In this scenario, requests for artifacts stored in Location B that are sent to the original
314
+ # MLflow server will fail if the original MLflow server does not have access to Location B
315
+ # because it will assume that it can serve all proxied artifacts regardless of the underlying
316
+ # location. Such failures can be remediated by granting the original MLflow server access to
317
+ # Location B.
318
+ return (
319
+ parsed_run_artifact_root.scheme in ["http", "https", "mlflow-artifacts"]
320
+ and _is_serving_proxied_artifacts()
321
+ )
322
+
323
+
324
+ def _get_proxied_run_artifact_destination_path(proxied_artifact_root, relative_path=None):
325
+ """
326
+ Resolves the specified proxied artifact location within a Run to a concrete storage location.
327
+
328
+ Args:
329
+ proxied_artifact_root: The Run artifact root location (URI) with scheme ``http``,
330
+ ``https``, or `mlflow-artifacts` that can be resolved by the MLflow server to a
331
+ concrete storage location.
332
+ relative_path: The relative path of the destination within the specified
333
+ ``proxied_artifact_root``. If ``None``, the destination is assumed to be
334
+ the resolved ``proxied_artifact_root``.
335
+
336
+ Returns:
337
+ The storage location of the specified artifact.
338
+ """
339
+ parsed_proxied_artifact_root = urllib.parse.urlparse(proxied_artifact_root)
340
+ assert parsed_proxied_artifact_root.scheme in ["http", "https", "mlflow-artifacts"]
341
+
342
+ if parsed_proxied_artifact_root.scheme == "mlflow-artifacts":
343
+ # If the proxied artifact root is an `mlflow-artifacts` URI, the run artifact root path is
344
+ # simply the path component of the URI, since the fully-qualified format of an
345
+ # `mlflow-artifacts` URI is `mlflow-artifacts://<netloc>/path/to/artifact`
346
+ proxied_run_artifact_root_path = parsed_proxied_artifact_root.path.lstrip("/")
347
+ else:
348
+ # In this case, the proxied artifact root is an HTTP(S) URL referring to an mlflow-artifacts
349
+ # API route that can be used to download the artifact. These routes are always anchored at
350
+ # `/api/2.0/mlflow-artifacts/artifacts`. Accordingly, we split the path on this route anchor
351
+ # and interpret the rest of the path (everything after the route anchor) as the run artifact
352
+ # root path
353
+ mlflow_artifacts_http_route_anchor = "/api/2.0/mlflow-artifacts/artifacts/"
354
+ assert mlflow_artifacts_http_route_anchor in parsed_proxied_artifact_root.path
355
+
356
+ proxied_run_artifact_root_path = parsed_proxied_artifact_root.path.split(
357
+ mlflow_artifacts_http_route_anchor
358
+ )[1].lstrip("/")
359
+
360
+ return (
361
+ posixpath.join(proxied_run_artifact_root_path, relative_path)
362
+ if relative_path is not None
363
+ else proxied_run_artifact_root_path
364
+ )
365
+
366
+
367
def _get_tracking_store(backend_store_uri=None, default_artifact_root=None):
    """
    Return the server's tracking store, creating it on the first call.

    Explicit arguments take precedence over the ``BACKEND_STORE_URI_ENV_VAR`` and
    ``ARTIFACT_ROOT_ENV_VAR`` environment variables. Creating the store also sets the
    tracking URI through ``utils.set_tracking_uri``. Once cached in ``_tracking_store``,
    later arguments are ignored.
    """
    from mlflow.server import ARTIFACT_ROOT_ENV_VAR, BACKEND_STORE_URI_ENV_VAR

    global _tracking_store
    if _tracking_store is not None:
        return _tracking_store

    store_uri = backend_store_uri or os.environ.get(BACKEND_STORE_URI_ENV_VAR)
    artifact_root = default_artifact_root or os.environ.get(ARTIFACT_ROOT_ENV_VAR)
    _tracking_store = _tracking_store_registry.get_store(store_uri, artifact_root)
    utils.set_tracking_uri(store_uri)
    return _tracking_store
377
+
378
+
379
def _get_model_registry_store(registry_store_uri=None):
    """
    Return the server's model registry store, creating it on the first call.

    The store URI is the first truthy value among ``registry_store_uri``,
    ``REGISTRY_STORE_URI_ENV_VAR`` and ``BACKEND_STORE_URI_ENV_VAR``. Creating the store also
    sets the registry URI through ``registry_utils.set_registry_uri``. The store is cached
    in ``_model_registry_store``.
    """
    from mlflow.server import BACKEND_STORE_URI_ENV_VAR, REGISTRY_STORE_URI_ENV_VAR

    global _model_registry_store
    if _model_registry_store is not None:
        return _model_registry_store

    registry_env_uri = os.environ.get(REGISTRY_STORE_URI_ENV_VAR)
    backend_env_uri = os.environ.get(BACKEND_STORE_URI_ENV_VAR)
    store_uri = registry_store_uri or registry_env_uri or backend_env_uri
    _model_registry_store = _model_registry_store_registry.get_store(store_uri)
    registry_utils.set_registry_uri(store_uri)
    return _model_registry_store
392
+
393
+
394
def initialize_backend_stores(
    backend_store_uri=None, registry_store_uri=None, default_artifact_root=None
):
    """
    Eagerly create the tracking store and, if possible, the model registry store.

    An ``UnsupportedModelRegistryStoreURIException`` from the registry lookup is swallowed,
    presumably so the server can still run without a model registry.
    """
    _get_tracking_store(backend_store_uri, default_artifact_root)
    try:
        _get_model_registry_store(registry_store_uri)
    except UnsupportedModelRegistryStoreURIException:
        # No registry store is registered for this URI; continue without one.
        return
402
+
403
+
404
+ def _assert_string(x):
405
+ assert isinstance(x, str)
406
+
407
+
408
+ def _assert_intlike(x):
409
+ try:
410
+ x = int(x)
411
+ except ValueError:
412
+ pass
413
+
414
+ assert isinstance(x, int)
415
+
416
+
417
+ def _assert_bool(x):
418
+ assert isinstance(x, bool)
419
+
420
+
421
+ def _assert_floatlike(x):
422
+ try:
423
+ x = float(x)
424
+ except ValueError:
425
+ pass
426
+
427
+ assert isinstance(x, float)
428
+
429
+
430
+ def _assert_array(x):
431
+ assert isinstance(x, list)
432
+
433
+
434
def _assert_map_key_present(x):
    """
    Raise ``AssertionError`` unless ``x`` is a list of dicts that each have a ``"key"``.

    Each entry's ``"key"`` must be present and neither ``None`` nor ``""``.

    Raises:
        AssertionError: If ``x`` is not a list, an entry is not a dict, or an entry's
            ``"key"`` is missing or empty.
    """
    _assert_array(x)
    for entry in x:
        # A non-dict entry (e.g. ``["a", "b"]``) would otherwise raise AttributeError on
        # ``.get``. Schema validation only converts AssertionError into a client error, so
        # that AttributeError would escape it.
        if not isinstance(entry, dict):
            raise AssertionError()
        _assert_required(entry.get("key"))
438
+
439
+
440
+ def _assert_required(x, path=None):
441
+ if path is None:
442
+ assert x is not None
443
+ # When parsing JSON payloads via proto, absent string fields
444
+ # are expressed as empty strings
445
+ assert x != ""
446
+ else:
447
+ assert x is not None, missing_value(path)
448
+ assert x != "", missing_value(path)
449
+
450
+
451
+ def _assert_less_than_or_equal(x, max_value, message=None):
452
+ if x > max_value:
453
+ raise AssertionError(message) if message else AssertionError()
454
+
455
+
456
+ def _assert_intlike_within_range(x, min_value, max_value, message=None):
457
+ if not min_value <= x <= max_value:
458
+ raise AssertionError(message) if message else AssertionError()
459
+
460
+
461
+ def _assert_item_type_string(x):
462
+ assert all(isinstance(item, str) for item in x)
463
+
464
+
465
# Validators that only check a value's type. ``_validate_param_against_schema`` skips these
# when protobuf parsing of the request succeeded, since a successful parse implies the field
# types were already correct.
_TYPE_VALIDATORS = {
    _assert_intlike,
    _assert_string,
    _assert_bool,
    _assert_floatlike,
    _assert_array,
    _assert_item_type_string,
}
473
+
474
+
475
def _validate_param_against_schema(schema, param, value, proto_parsing_succeeded=False):
    """
    Run each validator in ``schema`` against a single request parameter.

    Called for its side effect: it returns ``None`` when every validator passes.

    Args:
        schema: A list of validator functions; each raises ``AssertionError`` on failure.
        param: The name of the parameter being validated (used in error messages).
        value: The parameter's value.
        proto_parsing_succeeded: Whether the request body parsed into its protobuf message.
            When True the field types are already known to be correct, so validators in
            ``_TYPE_VALIDATORS`` are skipped. See
            https://github.com/mlflow/mlflow/pull/5458#issuecomment-1080880870.

    Raises:
        MlflowException: With ``INVALID_PARAMETER_VALUE`` when a validator fails. The message
            is the assertion's own message if it has one, a "missing value" message for
            ``_assert_required``, or a generic invalid-value message with a type hint.
    """
    for validator in schema:
        if proto_parsing_succeeded and validator in _TYPE_VALIDATORS:
            continue

        try:
            validator(value)
        except AssertionError as err:
            if err.args:
                detail = err.args[0]
            elif validator == _assert_required:
                detail = f"Missing value for required parameter '{param}'."
            else:
                type_hint = f" Hint: Value was of type '{type(value).__name__}'."
                detail = invalid_value(param, value, type_hint)
            raise MlflowException(
                message=(
                    detail + " See the API docs for more information about request parameters."
                ),
                error_code=INVALID_PARAMETER_VALUE,
            )

    return None
516
+
517
+
518
def _get_request_json(flask_request=request):
    """
    Check the request's content type (``application/json``) and parse its body as JSON.

    Returns:
        The decoded body, or ``None`` if it could not be parsed (``silent=True``).
    """
    _validate_content_type(flask_request, ["application/json"])
    body = flask_request.get_json(force=True, silent=True)
    return body
521
+
522
+
523
def _get_request_message(request_message, flask_request=request, schema=None):
    """
    Populate ``request_message`` from the request and validate it against ``schema``.

    For GET requests with query arguments, each field of the message descriptor is read from
    the query string (repeated fields as lists). Otherwise the body is read through
    ``_get_request_json``. A body that decodes to a string is decoded a second time, and a
    missing body is treated as ``{}``. A protobuf ``ParseError`` is not raised; it only
    disables the type-validator skip in ``_validate_param_against_schema``.

    Args:
        request_message: A protobuf message instance, filled in place.
        flask_request: The Flask request to read (defaults to the current request).
        schema: Optional mapping of parameter name to a list of validator functions.

    Returns:
        The populated ``request_message``.

    Raises:
        MlflowException: If a schema validator fails (via ``_validate_param_against_schema``).
    """
    if flask_request.method == "GET" and flask_request.args:
        # Convert atomic values of repeated fields to lists before protobuf deserialization.
        # Query parameters are collected into a dictionary outside of protobuf, since protobuf
        # cannot read query parameters directly, and that step has no type information: a
        # parameter that occurs exactly once looks like an atomic value. Protobuf requires
        # repeated fields to be lists, so ``getlist`` is used for them.
        request_json = {}
        for field in request_message.DESCRIPTOR.fields:
            if field.name not in flask_request.args:
                continue

            if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
                request_json[field.name] = flask_request.args.getlist(field.name)
            else:
                request_json[field.name] = flask_request.args.get(field.name)
    else:
        request_json = _get_request_json(flask_request)

        # Older clients may post their JSON double-encoded as a string, in which case the
        # ``get_json`` call above yields a ``str``. A proper request body should be an
        # object, so a string result is decoded a second time.
        # NOTE(review): if that inner string is not valid JSON, ``json.loads`` raises outside
        # of MlflowException handling — confirm this is acceptable.
        if is_string_type(request_json):
            request_json = json.loads(request_json)

        # If the request has no (parseable) JSON body, treat it as empty.
        if request_json is None:
            request_json = {}

    # A ParseError is swallowed; it only means the type validators will not be skipped
    # below.
    proto_parsing_succeeded = True
    try:
        parse_dict(request_json, request_message)
    except ParseError:
        proto_parsing_succeeded = False

    schema = schema or {}
    for schema_key, schema_validation_fns in schema.items():
        # Validate keys present in the request, plus required keys even when absent.
        if schema_key in request_json or _assert_required in schema_validation_fns:
            value = request_json.get(schema_key)
            # Fall back to ``run_uuid`` when ``run_id`` is absent.
            if schema_key == "run_id" and value is None and "run_uuid" in request_json:
                value = request_json.get("run_uuid")
            _validate_param_against_schema(
                schema=schema_validation_fns,
                param=schema_key,
                value=value,
                proto_parsing_succeeded=proto_parsing_succeeded,
            )

    return request_message
574
+
575
+
576
def _response_with_file_attachment_headers(file_path, response):
    """
    Apply file-download headers for ``file_path`` to ``response`` and return it.

    The following are set:

    - the MIME type guessed from the file name;
    - ``X-Content-Type-Options: nosniff``;
    - an ``attachment`` ``Content-Disposition`` header, unless the response already has one.

    Args:
        file_path: Path of the file being served; only its name and MIME type are used.
        response: The response object to modify in place.

    Returns:
        The same ``response`` object.
    """
    mime_type = _guess_mime_type(file_path)
    filename = pathlib.Path(file_path).name
    response.mimetype = mime_type
    content_disposition_header_name = "Content-Disposition"
    if content_disposition_header_name not in response.headers:
        # Quote the filename (RFC 6266 quoted-string) so names containing spaces, ';' or ','
        # do not produce a malformed header. Non-ASCII names get an ASCII fallback plus an
        # RFC 5987 ``filename*`` parameter carrying the exact UTF-8 name.
        ascii_name = filename.encode("ascii", "replace").decode("ascii")
        escaped_name = ascii_name.replace("\\", "\\\\").replace('"', '\\"')
        disposition = f'attachment; filename="{escaped_name}"'
        if ascii_name != filename:
            disposition += f"; filename*=UTF-8''{urllib.parse.quote(filename)}"
        response.headers[content_disposition_header_name] = disposition
    response.headers["X-Content-Type-Options"] = "nosniff"
    response.headers["Content-Type"] = mime_type
    return response
586
+
587
+
588
def _send_artifact(artifact_repository, path):
    """
    Download ``path`` from ``artifact_repository`` and return it as a file attachment.
    """
    local_copy = artifact_repository.download_artifacts(path)
    file_path = os.path.abspath(local_copy)
    # Artifacts are always sent as attachments so browsers don't render them inline on this
    # server's domain, which could enable XSS.
    attachment = send_file(file_path, mimetype=_guess_mime_type(file_path), as_attachment=True)
    return _response_with_file_attachment_headers(file_path, attachment)
595
+
596
+
597
def catch_mlflow_exception(func):
    """
    Decorator that turns an ``MlflowException`` raised by ``func`` into a JSON response.

    The response body is ``serialize_as_json()`` of the exception and its status code is
    ``get_http_status_code()``. Other exceptions propagate unchanged.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            result = func(*args, **kwargs)
        except MlflowException as exc:
            error_response = Response(mimetype="application/json")
            error_response.set_data(exc.serialize_as_json())
            error_response.status_code = exc.get_http_status_code()
            return error_response
        return result

    return wrapper
609
+
610
+
611
+ def _disable_unless_serve_artifacts(func):
612
+ @wraps(func)
613
+ def wrapper(*args, **kwargs):
614
+ if not _is_serving_proxied_artifacts():
615
+ return Response(
616
+ (
617
+ f"Endpoint: {request.url_rule} disabled due to the mlflow server running "
618
+ "with `--no-serve-artifacts`. To enable artifacts server functionality, "
619
+ "run `mlflow server` with `--serve-artifacts`"
620
+ ),
621
+ 503,
622
+ )
623
+ return func(*args, **kwargs)
624
+
625
+ return wrapper
626
+
627
+
628
+ def _disable_if_artifacts_only(func):
629
+ @wraps(func)
630
+ def wrapper(*args, **kwargs):
631
+ from mlflow.server import ARTIFACTS_ONLY_ENV_VAR
632
+
633
+ if os.environ.get(ARTIFACTS_ONLY_ENV_VAR):
634
+ return Response(
635
+ (
636
+ f"Endpoint: {request.url_rule} disabled due to the mlflow server running "
637
+ "in `--artifacts-only` mode. To enable tracking server functionality, run "
638
+ "`mlflow server` without `--artifacts-only`"
639
+ ),
640
+ 503,
641
+ )
642
+ return func(*args, **kwargs)
643
+
644
+ return wrapper
645
+
646
+
647
@catch_mlflow_exception
def get_artifact_handler():
    """
    Serve a single run artifact as a file download.

    Query args: ``run_id`` (or ``run_uuid``) and ``path``; the path is passed through
    ``validate_path_is_safe``. Runs with a servable proxied artifact root are read from the
    server's artifacts destination; other runs use the run's own artifact repository.
    """
    query = request.args
    run_id = query.get("run_id") or query.get("run_uuid")
    safe_path = validate_path_is_safe(query["path"])
    run = _get_tracking_store().get_run(run_id)
    artifact_uri = run.info.artifact_uri

    if not _is_servable_proxied_run_artifact_root(artifact_uri):
        repository = _get_artifact_repo(run)
        target_path = safe_path
    else:
        repository = _get_artifact_repo_mlflow_artifacts()
        target_path = _get_proxied_run_artifact_destination_path(
            proxied_artifact_root=artifact_uri,
            relative_path=safe_path,
        )

    return _send_artifact(repository, target_path)
665
+
666
+
667
def _not_implemented():
    """Return an empty HTTP 404 response."""
    return Response(status=404)
671
+
672
+
673
+ # Tracking Server APIs
674
+
675
+
676
@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_experiment():
    """
    Handle ``CreateExperiment``: validate the request and create a new experiment.

    Checks performed before the experiment is created:

    - the experiment name, every tag key/value and a non-empty ``artifact_location`` go
      through ``InputValidator``;
    - the artifact location may not contain URL fragments or params;
    - its query string must pass ``validate_query_string``.

    Returns:
        A JSON response containing the new ``experiment_id``.

    Raises:
        MlflowException: ``INVALID_PARAMETER_VALUE`` when the name, a tag, or the artifact
            location fails validation. The original validation error is chained as the cause.
    """
    request_message = _get_request_message(
        CreateExperiment(),
        schema={
            "name": [_assert_required, _assert_string],
            "artifact_location": [_assert_string],
            "tags": [_assert_array],
        },
    )

    # Security validation for experiment name
    try:
        validated_name = InputValidator.validate_experiment_name(request_message.name)
    except SecurityValidationError as e:
        raise MlflowException(
            f"Invalid experiment name: {e}",
            error_code=INVALID_PARAMETER_VALUE,
        ) from e

    # Security validation for tags
    validated_tags = []
    for tag in request_message.tags:
        try:
            validated_key = InputValidator.validate_tag_key(tag.key)
            validated_value = InputValidator.validate_tag_value(tag.value)
            validated_tags.append(ExperimentTag(validated_key, validated_value))
        except SecurityValidationError as e:
            raise MlflowException(
                f"Invalid tag: {e}",
                error_code=INVALID_PARAMETER_VALUE,
            ) from e

    # Security validation for artifact location (an empty location is passed through as-is)
    if request_message.artifact_location:
        try:
            validated_artifact_location = InputValidator.validate_uri(
                request_message.artifact_location
            )
        except SecurityValidationError as e:
            raise MlflowException(
                f"Invalid artifact location: {e}",
                error_code=INVALID_PARAMETER_VALUE,
            ) from e
    else:
        validated_artifact_location = request_message.artifact_location

    # Reject fragments/params and validate the query string of the artifact location
    parsed_artifact_location = urllib.parse.urlparse(validated_artifact_location)
    if parsed_artifact_location.fragment or parsed_artifact_location.params:
        raise MlflowException(
            "'artifact_location' URL can't include fragments or params.",
            error_code=INVALID_PARAMETER_VALUE,
        )
    validate_query_string(parsed_artifact_location.query)

    experiment_id = _get_tracking_store().create_experiment(
        validated_name, validated_artifact_location, validated_tags
    )
    response_message = CreateExperiment.Response()
    response_message.experiment_id = experiment_id
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(response_message))
    return response
738
+
739
+
740
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_experiment():
    """Handle ``GetExperiment``: return the experiment identified by ``experiment_id``."""
    request_message = _get_request_message(
        GetExperiment(),
        schema={"experiment_id": [_assert_required, _assert_string]},
    )
    payload = message_to_json(get_experiment_impl(request_message))
    response = Response(mimetype="application/json")
    response.set_data(payload)
    return response
750
+
751
+
752
def get_experiment_impl(request_message):
    """
    Look up ``request_message.experiment_id`` in the tracking store.

    Returns:
        A ``GetExperiment.Response`` whose ``experiment`` field holds the experiment proto.
    """
    store = _get_tracking_store()
    experiment_proto = store.get_experiment(request_message.experiment_id).to_proto()
    response_message = GetExperiment.Response()
    response_message.experiment.MergeFrom(experiment_proto)
    return response_message
757
+
758
+
759
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_experiment_by_name():
    """
    Handle ``GetExperimentByName``.

    Raises:
        MlflowException: ``RESOURCE_DOES_NOT_EXIST`` if the store has no experiment with
            the requested name.
    """
    request_message = _get_request_message(
        GetExperimentByName(),
        schema={"experiment_name": [_assert_required, _assert_string]},
    )
    experiment_name = request_message.experiment_name
    store_exp = _get_tracking_store().get_experiment_by_name(experiment_name)
    if store_exp is None:
        raise MlflowException(
            f"Could not find experiment with name '{experiment_name}'",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )
    response_message = GetExperimentByName.Response()
    response_message.experiment.MergeFrom(store_exp.to_proto())
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(response_message))
    return response
778
+
779
+
780
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_experiment():
    """Handle ``DeleteExperiment`` by calling the tracking store's ``delete_experiment``."""
    request_message = _get_request_message(
        DeleteExperiment(),
        schema={"experiment_id": [_assert_required, _assert_string]},
    )
    store = _get_tracking_store()
    store.delete_experiment(request_message.experiment_id)
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(DeleteExperiment.Response()))
    return response
791
+
792
+
793
@catch_mlflow_exception
@_disable_if_artifacts_only
def _restore_experiment():
    """Handle ``RestoreExperiment`` by calling the tracking store's ``restore_experiment``."""
    request_message = _get_request_message(
        RestoreExperiment(),
        schema={"experiment_id": [_assert_required, _assert_string]},
    )
    store = _get_tracking_store()
    store.restore_experiment(request_message.experiment_id)
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(RestoreExperiment.Response()))
    return response
805
+
806
+
807
@catch_mlflow_exception
@_disable_if_artifacts_only
def _update_experiment():
    """Handle ``UpdateExperiment``: rename the experiment when ``new_name`` is non-empty."""
    request_message = _get_request_message(
        UpdateExperiment(),
        schema={
            "experiment_id": [_assert_required, _assert_string],
            "new_name": [_assert_string, _assert_required],
        },
    )
    new_name = request_message.new_name
    if new_name:
        _get_tracking_store().rename_experiment(request_message.experiment_id, new_name)
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(UpdateExperiment.Response()))
    return response
825
+
826
+
827
@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_run():
    """
    Handle ``CreateRun``: create a run with the request's experiment, user, start time,
    tags and run name, and return the created run.
    """
    request_message = _get_request_message(
        CreateRun(),
        schema={
            "experiment_id": [_assert_string],
            "start_time": [_assert_intlike],
            "run_name": [_assert_string],
        },
    )

    run_tags = [RunTag(t.key, t.value) for t in request_message.tags]
    created_run = _get_tracking_store().create_run(
        experiment_id=request_message.experiment_id,
        user_id=request_message.user_id,
        start_time=request_message.start_time,
        tags=run_tags,
        run_name=request_message.run_name,
    )

    response_message = CreateRun.Response()
    response_message.run.MergeFrom(created_run.to_proto())
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(response_message))
    return response
853
+
854
+
855
@catch_mlflow_exception
@_disable_if_artifacts_only
def _update_run():
    """
    Handle ``UpdateRun``: update a run's status, end time and/or name.

    Only fields set on the message (per ``HasField``) are forwarded; the others are passed as
    ``None``. ``run_uuid`` is used when ``run_id`` is empty.
    """
    request_message = _get_request_message(
        UpdateRun(),
        schema={
            "run_id": [_assert_required, _assert_string],
            "end_time": [_assert_intlike],
            "status": [_assert_string],
            "run_name": [_assert_string],
        },
    )

    def _field_or_none(field_name):
        if request_message.HasField(field_name):
            return getattr(request_message, field_name)
        return None

    run_id = request_message.run_id or request_message.run_uuid
    run_name = _field_or_none("run_name")
    end_time = _field_or_none("end_time")
    status = _field_or_none("status")
    store = _get_tracking_store()
    updated_info = store.update_run_info(run_id, status, end_time, run_name)
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(UpdateRun.Response(run_info=updated_info.to_proto())))
    return response
876
+
877
+
878
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_run():
    """Delete the run named by the required ``run_id`` field; responds with an empty JSON body."""
    msg = _get_request_message(DeleteRun(), schema={"run_id": [_assert_required, _assert_string]})
    store = _get_tracking_store()
    store.delete_run(msg.run_id)
    return Response(message_to_json(DeleteRun.Response()), mimetype="application/json")
889
+
890
+
891
@catch_mlflow_exception
@_disable_if_artifacts_only
def _restore_run():
    """Restore the run named by the required ``run_id`` field; responds with an empty JSON body."""
    msg = _get_request_message(RestoreRun(), schema={"run_id": [_assert_required, _assert_string]})
    store = _get_tracking_store()
    store.restore_run(msg.run_id)
    return Response(message_to_json(RestoreRun.Response()), mimetype="application/json")
902
+
903
+
904
@catch_mlflow_exception
@_disable_if_artifacts_only
def _log_metric():
    """Log a single metric value for a run.

    The metric key is checked with ``InputValidator.validate_metric_key``. A
    ``SecurityValidationError`` is reported as an ``INVALID_PARAMETER_VALUE``
    ``MlflowException`` that is chained to the original error. Empty optional
    model/dataset fields are stored as ``None``. Responds with an empty JSON body.
    """
    request_message = _get_request_message(
        LogMetric(),
        schema={
            "run_id": [_assert_required, _assert_string],
            "key": [_assert_required, _assert_string],
            "value": [_assert_required, _assert_floatlike],
            "timestamp": [_assert_intlike, _assert_required],
            "step": [_assert_intlike],
            "model_id": [_assert_string],
            "dataset_name": [_assert_string],
            "dataset_digest": [_assert_string],
        },
    )

    # Security validation for metric key
    try:
        validated_key = InputValidator.validate_metric_key(request_message.key)
    except SecurityValidationError as e:
        raise MlflowException(
            f"Invalid metric key: {e}",
            error_code=INVALID_PARAMETER_VALUE,
        ) from e

    # Resolve the run id once (falling back to the legacy ``run_uuid`` field). Use that same
    # id on the Metric entity and for the store call so the two can never disagree.
    run_id = request_message.run_id or request_message.run_uuid
    metric = Metric(
        validated_key,
        request_message.value,
        request_message.timestamp,
        request_message.step,
        request_message.model_id or None,
        request_message.dataset_name or None,
        request_message.dataset_digest or None,
        run_id or None,
    )
    _get_tracking_store().log_metric(run_id, metric)
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(LogMetric.Response()))
    return response
946
+
947
+
948
@catch_mlflow_exception
@_disable_if_artifacts_only
def _log_param():
    """Log a single parameter for a run.

    The key and value are checked with ``InputValidator``. A ``SecurityValidationError``
    is reported as an ``INVALID_PARAMETER_VALUE`` ``MlflowException`` that is chained to
    the original error. The run is identified by ``run_id``, falling back to ``run_uuid``.
    Responds with an empty JSON body.
    """
    request_message = _get_request_message(
        LogParam(),
        schema={
            "run_id": [_assert_required, _assert_string],
            "key": [_assert_required, _assert_string],
            "value": [_assert_string],
        },
    )

    # Security validation for parameter key and value
    try:
        validated_key = InputValidator.validate_param_key(request_message.key)
        validated_value = InputValidator.validate_param_value(request_message.value)
    except SecurityValidationError as e:
        raise MlflowException(
            f"Invalid parameter: {e}",
            error_code=INVALID_PARAMETER_VALUE,
        ) from e

    param = Param(validated_key, validated_value)
    run_id = request_message.run_id or request_message.run_uuid
    _get_tracking_store().log_param(run_id, param)
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(LogParam.Response()))
    return response
977
+
978
+
979
@catch_mlflow_exception
@_disable_if_artifacts_only
def _log_inputs():
    """Record dataset inputs (and optional logged-model inputs) for a run.

    ``datasets`` is always passed as a list. ``models`` is passed as ``None`` when the
    request carries no model inputs. Responds with an empty JSON body.
    """
    msg = _get_request_message(
        LogInputs(),
        schema={
            "run_id": [_assert_required, _assert_string],
            "datasets": [_assert_array],
            "models": [_assert_array],
        },
    )
    dataset_inputs = [DatasetInput.from_proto(d) for d in msg.datasets]
    model_inputs = None
    if msg.models:
        model_inputs = [LoggedModelInput.from_proto(m) for m in msg.models]

    _get_tracking_store().log_inputs(msg.run_id, datasets=dataset_inputs, models=model_inputs)
    return Response(message_to_json(LogInputs.Response()), mimetype="application/json")
1009
+
1010
+
1011
@catch_mlflow_exception
@_disable_if_artifacts_only
def _log_outputs():
    """Record logged-model outputs for a run; both ``run_id`` and ``models`` are required."""
    msg = _get_request_message(
        LogOutputs(),
        schema={
            "run_id": [_assert_required, _assert_string],
            "models": [_assert_required, _assert_array],
        },
    )
    model_outputs = [LoggedModelOutput.from_proto(proto) for proto in msg.models]
    store = _get_tracking_store()
    store.log_outputs(run_id=msg.run_id, models=model_outputs)
    return _wrap_response(LogOutputs.Response())
1025
+
1026
+
1027
@catch_mlflow_exception
@_disable_if_artifacts_only
def _set_experiment_tag():
    """Set a key/value tag on an experiment; responds with an empty JSON body."""
    msg = _get_request_message(
        SetExperimentTag(),
        schema={
            "experiment_id": [_assert_required, _assert_string],
            "key": [_assert_required, _assert_string],
            "value": [_assert_string],
        },
    )
    experiment_tag = ExperimentTag(msg.key, msg.value)
    store = _get_tracking_store()
    store.set_experiment_tag(msg.experiment_id, experiment_tag)
    return Response(message_to_json(SetExperimentTag.Response()), mimetype="application/json")
1044
+
1045
+
1046
@catch_mlflow_exception
@_disable_if_artifacts_only
def _set_tag():
    """Set a key/value tag on a run (``run_id``, falling back to ``run_uuid``)."""
    msg = _get_request_message(
        SetTag(),
        schema={
            "run_id": [_assert_required, _assert_string],
            "key": [_assert_required, _assert_string],
            "value": [_assert_string],
        },
    )
    run_tag = RunTag(msg.key, msg.value)
    target_run_id = msg.run_id or msg.run_uuid
    _get_tracking_store().set_tag(target_run_id, run_tag)
    return Response(message_to_json(SetTag.Response()), mimetype="application/json")
1064
+
1065
+
1066
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_tag():
    """Delete the tag ``key`` from run ``run_id``; responds with an empty JSON body."""
    tag_schema = {
        "run_id": [_assert_required, _assert_string],
        "key": [_assert_required, _assert_string],
    }
    msg = _get_request_message(DeleteTag(), schema=tag_schema)
    store = _get_tracking_store()
    store.delete_tag(msg.run_id, msg.key)
    return Response(message_to_json(DeleteTag.Response()), mimetype="application/json")
1081
+
1082
+
1083
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_run():
    """HTTP handler for GetRun: validates the request and delegates to ``get_run_impl``."""
    msg = _get_request_message(GetRun(), schema={"run_id": [_assert_required, _assert_string]})
    return Response(message_to_json(get_run_impl(msg)), mimetype="application/json")
1093
+
1094
+
1095
def get_run_impl(request_message):
    """Fetch a run from the tracking store and wrap it in a ``GetRun.Response``.

    The run is identified by ``run_id``, falling back to the legacy ``run_uuid`` field.
    """
    target_run_id = request_message.run_id or request_message.run_uuid
    run_proto = _get_tracking_store().get_run(target_run_id).to_proto()
    reply = GetRun.Response()
    reply.run.MergeFrom(run_proto)
    return reply
1100
+
1101
+
1102
@catch_mlflow_exception
@_disable_if_artifacts_only
def _search_runs():
    """HTTP handler for SearchRuns.

    Validates the request (``max_results`` may not exceed 50000) and delegates to
    ``search_runs_impl``.
    """
    search_schema = {
        "experiment_ids": [_assert_array],
        "filter": [_assert_string],
        "max_results": [
            _assert_intlike,
            lambda x: _assert_less_than_or_equal(int(x), 50000),
        ],
        "order_by": [_assert_array, _assert_item_type_string],
    }
    msg = _get_request_message(SearchRuns(), schema=search_schema)
    return Response(message_to_json(search_runs_impl(msg)), mimetype="application/json")
1121
+
1122
+
1123
def search_runs_impl(request_message):
    """Run a search against the tracking store and build a ``SearchRuns.Response``.

    ``run_view_type`` defaults to ``ViewType.ACTIVE_ONLY`` when not present on the request.
    ``next_page_token`` is only set when the store returned a continuation token.
    """
    if request_message.HasField("run_view_type"):
        view = ViewType.from_proto(request_message.run_view_type)
    else:
        view = ViewType.ACTIVE_ONLY

    found_runs = _get_tracking_store().search_runs(
        request_message.experiment_ids,
        request_message.filter,
        view,
        request_message.max_results,
        request_message.order_by,
        request_message.page_token,
    )

    reply = SearchRuns.Response()
    reply.runs.extend([run.to_proto() for run in found_runs])
    if found_runs.token:
        reply.next_page_token = found_runs.token
    return reply
1140
+
1141
+
1142
@catch_mlflow_exception
@_disable_if_artifacts_only
def _list_artifacts():
    """HTTP handler for ListArtifacts: validates the request and delegates to the impl."""
    list_schema = {
        "run_id": [_assert_string, _assert_required],
        "path": [_assert_string],
        "page_token": [_assert_string],
    }
    msg = _get_request_message(ListArtifacts(), schema=list_schema)
    return Response(message_to_json(list_artifacts_impl(msg)), mimetype="application/json")
1157
+
1158
+
1159
def list_artifacts_impl(request_message):
    """List a run's artifacts under an optional relative ``path``.

    When ``path`` is present it is checked with ``validate_path_is_safe`` first. Proxied
    artifact roots are listed through the server's artifact repository; anything else is
    listed through the run's own artifact repository. The response's ``root_uri`` is the
    run's artifact URI.
    """
    path = validate_path_is_safe(request_message.path) if request_message.HasField("path") else None

    run = _get_tracking_store().get_run(request_message.run_id or request_message.run_uuid)
    artifact_root = run.info.artifact_uri

    if not _is_servable_proxied_run_artifact_root(artifact_root):
        entries = _get_artifact_repo(run).list_artifacts(path)
    else:
        entries = _list_artifacts_for_proxied_run_artifact_root(
            proxied_artifact_root=artifact_root,
            relative_path=path,
        )

    reply = ListArtifacts.Response()
    reply.files.extend([entry.to_proto() for entry in entries])
    reply.root_uri = artifact_root
    return reply
1180
+
1181
+
1182
@catch_mlflow_exception
def _list_artifacts_for_proxied_run_artifact_root(proxied_artifact_root, relative_path=None):
    """
    List artifacts under ``relative_path`` within a proxied Run artifact root, i.e. a root
    whose scheme is ``http``, ``https`` or ``mlflow-artifacts``.

    Args:
        proxied_artifact_root: The Run artifact root URI (scheme ``http``, ``https`` or
                               ``mlflow-artifacts``) that the MLflow server resolves to a
                               concrete storage location.
        relative_path: Path inside ``proxied_artifact_root`` whose contents are listed.
                       ``None`` lists the root directory itself.

    Returns:
        A list of ``FileInfo`` objects whose paths are relative to the Run artifact root
        rather than to the underlying storage location.
    """
    scheme = urllib.parse.urlparse(proxied_artifact_root).scheme
    assert scheme in ["http", "https", "mlflow-artifacts"]

    repo = _get_artifact_repo_mlflow_artifacts()
    destination = _get_proxied_run_artifact_destination_path(
        proxied_artifact_root=proxied_artifact_root,
        relative_path=relative_path,
    )

    def _as_run_relative(file_info):
        # Re-root each listed entry onto the caller's relative path.
        name = posixpath.basename(file_info.path)
        run_path = posixpath.join(relative_path, name) if relative_path else name
        return FileInfo(run_path, file_info.is_dir, file_info.file_size)

    return [_as_run_relative(info) for info in repo.list_artifacts(destination)]
1216
+
1217
+
1218
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_metric_history():
    """Return one page of a run's history for a single metric.

    ``max_results`` is forwarded only when the client actually sent it; otherwise ``None``
    is passed so the store applies its own default. An empty ``page_token`` is treated as
    "first page". ``next_page_token`` is set when the store reports more results.
    """
    request_message = _get_request_message(
        GetMetricHistory(),
        schema={
            "run_id": [_assert_string, _assert_required],
            "metric_key": [_assert_string, _assert_required],
            "page_token": [_assert_string],
        },
    )
    response_message = GetMetricHistory.Response()
    run_id = request_message.run_id or request_message.run_uuid

    # An optional proto scalar is never None: when unset it reads as the field default (0).
    # The old ``is not None`` test was therefore always true, and an omitted max_results
    # reached the store as 0. HasField detects whether the client sent a value.
    max_results = (
        request_message.max_results if request_message.HasField("max_results") else None
    )
    page_token = request_message.page_token or None

    metric_entities = _get_tracking_store().get_metric_history(
        run_id, request_message.metric_key, max_results=max_results, page_token=page_token
    )
    response_message.metrics.extend([m.to_proto() for m in metric_entities])

    # Set next_page_token if available
    if next_page_token := metric_entities.token:
        response_message.next_page_token = next_page_token

    response = Response(mimetype="application/json")
    response.set_data(message_to_json(response_message))
    return response
1247
+
1248
+
1249
@catch_mlflow_exception
@_disable_if_artifacts_only
def get_metric_history_bulk_handler():
    """Return the history of one metric across several runs.

    Query parameters:
        run_id: One or more run ids (repeatable; at most 100).
        metric_key: Name of the metric to fetch (required).
        max_results: Optional non-negative integer cap on the number of returned metrics.
            Defaults to, and is clamped at, 25000.

    Returns:
        A dict ``{"metrics": [...]}`` whose entries are metric dicts that include ``run_id``.

    Raises:
        MlflowException: ``INVALID_PARAMETER_VALUE`` in any of these cases:
            - run ids are missing, or there are too many;
            - ``metric_key`` is missing;
            - ``max_results`` is not an integer, or is negative.
    """
    MAX_HISTORY_RESULTS = 25000
    MAX_RUN_IDS_PER_REQUEST = 100
    run_ids = request.args.getlist("run_id")
    if not run_ids:
        raise MlflowException(
            message="GetMetricHistoryBulk request must specify at least one run_id.",
            error_code=INVALID_PARAMETER_VALUE,
        )
    if len(run_ids) > MAX_RUN_IDS_PER_REQUEST:
        raise MlflowException(
            message=(
                f"GetMetricHistoryBulk request cannot specify more than {MAX_RUN_IDS_PER_REQUEST}"
                f" run_ids. Received {len(run_ids)} run_ids."
            ),
            error_code=INVALID_PARAMETER_VALUE,
        )

    metric_key = request.args.get("metric_key")
    if metric_key is None:
        raise MlflowException(
            message="GetMetricHistoryBulk request must specify a metric_key.",
            error_code=INVALID_PARAMETER_VALUE,
        )

    # Query-string values are untrusted text. Reject non-integers up front instead of letting
    # int() raise a ValueError (reported as an internal error). Also reject negative values,
    # which would reach the store and then act as a negative slice that drops trailing results.
    raw_max_results = request.args.get("max_results", MAX_HISTORY_RESULTS)
    try:
        max_results = int(raw_max_results)
    except (TypeError, ValueError):
        raise MlflowException(
            message=(
                "GetMetricHistoryBulk request max_results must be an integer. "
                f"Received {raw_max_results!r}."
            ),
            error_code=INVALID_PARAMETER_VALUE,
        ) from None
    if max_results < 0:
        raise MlflowException(
            message=(
                "GetMetricHistoryBulk request max_results must be non-negative. "
                f"Received {max_results}."
            ),
            error_code=INVALID_PARAMETER_VALUE,
        )
    max_results = min(max_results, MAX_HISTORY_RESULTS)

    store = _get_tracking_store()

    def _default_history_bulk_impl():
        # Fallback for stores without a native bulk query: fetch each run separately, in
        # run-id order, with each run's metrics ordered by (timestamp, step, value).
        metrics_with_run_ids = []
        for run_id in sorted(run_ids):
            metrics_for_run = sorted(
                store.get_metric_history(
                    run_id=run_id,
                    metric_key=metric_key,
                    max_results=max_results,
                ),
                key=lambda metric: (metric.timestamp, metric.step, metric.value),
            )
            metrics_with_run_ids.extend(
                [
                    {
                        "key": metric.key,
                        "value": metric.value,
                        "timestamp": metric.timestamp,
                        "step": metric.step,
                        "run_id": run_id,
                    }
                    for metric in metrics_for_run
                ]
            )
        return metrics_with_run_ids

    if hasattr(store, "get_metric_history_bulk"):
        metrics_with_run_ids = [
            metric.to_dict()
            for metric in store.get_metric_history_bulk(
                run_ids=run_ids,
                metric_key=metric_key,
                max_results=max_results,
            )
        ]
    else:
        metrics_with_run_ids = _default_history_bulk_impl()

    return {
        "metrics": metrics_with_run_ids[:max_results],
    }
1321
+
1322
+
1323
+ def _get_sampled_steps_from_steps(
1324
+ start_step: int, end_step: int, max_results: int, all_steps: list[int]
1325
+ ) -> set[int]:
1326
+ # NOTE: all_steps should be sorted before
1327
+ # being passed to this function
1328
+ start_idx = bisect.bisect_left(all_steps, start_step)
1329
+ end_idx = bisect.bisect_right(all_steps, end_step)
1330
+ if end_idx - start_idx <= max_results:
1331
+ return set(all_steps[start_idx:end_idx])
1332
+
1333
+ num_steps = end_idx - start_idx
1334
+ interval = num_steps / max_results
1335
+ sampled_steps = []
1336
+
1337
+ for i in range(0, max_results):
1338
+ idx = start_idx + int(i * interval)
1339
+ if idx < num_steps:
1340
+ sampled_steps.append(all_steps[idx])
1341
+
1342
+ sampled_steps.append(all_steps[end_idx - 1])
1343
+ return set(sampled_steps)
1344
+
1345
+
1346
@catch_mlflow_exception
@_disable_if_artifacts_only
def get_metric_history_bulk_interval_handler():
    """HTTP handler for GetMetricHistoryBulkInterval.

    Validates the request and delegates to ``get_metric_history_bulk_interval_impl``:
    - ``run_ids`` is a required string array with at most
      ``MAX_RUNS_GET_METRIC_HISTORY_BULK`` entries;
    - ``metric_key`` is required;
    - ``max_results`` must lie in ``[1, MAX_RESULTS_PER_RUN]``.
    """

    def _check_run_id_count(run_ids):
        return _assert_less_than_or_equal(
            len(run_ids),
            MAX_RUNS_GET_METRIC_HISTORY_BULK,
            message=f"GetMetricHistoryBulkInterval request must specify at most "
            f"{MAX_RUNS_GET_METRIC_HISTORY_BULK} run_ids. Received {len(run_ids)} run_ids.",
        )

    def _check_max_results_range(value):
        return _assert_intlike_within_range(
            int(value),
            1,
            MAX_RESULTS_PER_RUN,
            message=f"max_results must be between 1 and {MAX_RESULTS_PER_RUN}.",
        )

    interval_schema = {
        "run_ids": [
            _assert_required,
            _assert_array,
            _assert_item_type_string,
            _check_run_id_count,
        ],
        "metric_key": [_assert_required, _assert_string],
        "start_step": [_assert_intlike],
        "end_step": [_assert_intlike],
        "max_results": [_assert_intlike, _check_max_results_range],
    }
    msg = _get_request_message(GetMetricHistoryBulkInterval(), schema=interval_schema)
    reply = get_metric_history_bulk_interval_impl(msg)
    return Response(message_to_json(reply), mimetype="application/json")
1381
+
1382
+
1383
def get_metric_history_bulk_interval_impl(request_message):
    """Sample one metric's history across several runs at a shared set of steps.

    Args:
        request_message: A parsed ``GetMetricHistoryBulkInterval`` request; ``run_ids`` and
            ``metric_key`` are read from it.

    Returns:
        A ``GetMetricHistoryBulkInterval.Response`` containing, for every requested run, the
        metrics logged at the sampled steps.

    Raises:
        MlflowException: If only one of ``start_step``/``end_step`` is supplied, or if
            ``start_step > end_step`` (equal bounds are accepted).
    """
    # max_results and the step bounds are read from the raw query string, not from
    # request_message (see the comment at the top of _get_sampled_steps).
    args = request.args
    run_ids = request_message.run_ids
    metric_key = request_message.metric_key
    max_results = int(args.get("max_results", MAX_RESULTS_PER_RUN))

    store = _get_tracking_store()

    # NOTE: these parameters shadow the enclosing variables of the same names.
    def _get_sampled_steps(run_ids, metric_key, max_results):
        # cannot fetch from request_message as the default value is 0
        start_step = args.get("start_step")
        end_step = args.get("end_step")

        # perform validation before any data fetching occurs
        if start_step is not None and end_step is not None:
            start_step = int(start_step)
            end_step = int(end_step)
            if start_step > end_step:
                raise MlflowException.invalid_parameter_value(
                    "end_step must be greater than start_step. "
                    f"Found start_step={start_step} and end_step={end_step}."
                )
        elif start_step is not None or end_step is not None:
            raise MlflowException.invalid_parameter_value(
                "If either start step or end step are specified, both must be specified."
            )

        # get a list of all steps for all runs. this is necessary
        # because we can't assume that every step was logged, so
        # sampling needs to be done on the steps that actually exist
        # (this fetches the full metric history of every run).
        all_runs = [
            [m.step for m in store.get_metric_history(run_id, metric_key)] for run_id in run_ids
        ]

        # save mins and maxes to be added back later
        all_mins_and_maxes = {step for run in all_runs if run for step in [min(run), max(run)]}
        all_steps = sorted({step for sublist in all_runs for step in sublist})

        # init start and end step if not provided in args
        # NOTE(review): defaulting start_step to 0 leaves any negative steps out of the
        # sample — confirm this is intended.
        if start_step is None and end_step is None:
            start_step = 0
            end_step = all_steps[-1] if all_steps else 0

        # remove any steps outside of the range
        all_mins_and_maxes = {step for step in all_mins_and_maxes if start_step <= step <= end_step}

        # doing extra iterations here shouldn't badly affect performance,
        # since the number of steps at this point should be relatively small
        # (MAX_RESULTS_PER_RUN + len(all_mins_and_maxes))
        sampled_steps = _get_sampled_steps_from_steps(start_step, end_step, max_results, all_steps)
        return sorted(sampled_steps.union(all_mins_and_maxes))

    def _default_history_bulk_interval_impl():
        # The sampled step list is computed once and the same list is queried for every run.
        steps = _get_sampled_steps(run_ids, metric_key, max_results)
        metrics_with_run_ids = []
        for run_id in run_ids:
            metrics_with_run_ids.extend(
                store.get_metric_history_bulk_interval_from_steps(
                    run_id=run_id,
                    metric_key=metric_key,
                    steps=steps,
                    max_results=MAX_RESULTS_GET_METRIC_HISTORY,
                )
            )
        return metrics_with_run_ids

    metrics_with_run_ids = _default_history_bulk_interval_impl()

    response_message = GetMetricHistoryBulkInterval.Response()
    response_message.metrics.extend([m.to_proto() for m in metrics_with_run_ids])
    return response_message
1454
+
1455
+
1456
@catch_mlflow_exception
@_disable_if_artifacts_only
def search_datasets_handler():
    """HTTP handler for SearchDatasets.

    Delegates to ``search_datasets_impl``. If the impl returns its "not implemented" fallback
    instead of a ``SearchDatasets.Response`` proto, that value is returned as-is instead of
    being serialized.
    """
    request_message = _get_request_message(
        SearchDatasets(),
    )
    response_message = search_datasets_impl(request_message)
    if not isinstance(response_message, SearchDatasets.Response):
        # search_datasets_impl returns ``_not_implemented()`` when the tracking store has no
        # ``_search_datasets``. That value is not a proto message, so message_to_json cannot
        # serialize it; pass it through unchanged.
        return response_message
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(response_message))
    return response
1466
+
1467
+
1468
def search_datasets_impl(request_message):
    """Summarize the datasets logged to 1-20 experiments.

    The request must be JSON and name between 1 and 20 experiment ids. Returns a
    ``SearchDatasets.Response`` of dataset summaries. Returns ``_not_implemented()`` when
    the tracking store has no ``_search_datasets`` method.

    Raises:
        MlflowException: ``INVALID_PARAMETER_VALUE`` when the number of experiment ids is
            out of range.
    """
    MAX_EXPERIMENT_IDS_PER_REQUEST = 20
    _validate_content_type(request, ["application/json"])

    experiment_ids = request_message.experiment_ids or []
    if not experiment_ids:
        raise MlflowException(
            message="SearchDatasets request must specify at least one experiment_id.",
            error_code=INVALID_PARAMETER_VALUE,
        )
    num_ids = len(experiment_ids)
    if num_ids > MAX_EXPERIMENT_IDS_PER_REQUEST:
        raise MlflowException(
            message=(
                f"SearchDatasets request cannot specify more than {MAX_EXPERIMENT_IDS_PER_REQUEST}"
                f" experiment_ids. Received {num_ids} experiment_ids."
            ),
            error_code=INVALID_PARAMETER_VALUE,
        )

    store = _get_tracking_store()
    if not hasattr(store, "_search_datasets"):
        return _not_implemented()

    summaries = store._search_datasets(experiment_ids)
    reply = SearchDatasets.Response()
    reply.dataset_summaries.extend([summary.to_proto() for summary in summaries])
    return reply
1496
+
1497
+
1498
+ def _validate_gateway_path(method: str, gateway_path: str) -> None:
1499
+ if not gateway_path:
1500
+ raise MlflowException(
1501
+ message="Deployments proxy request must specify a gateway_path.",
1502
+ error_code=INVALID_PARAMETER_VALUE,
1503
+ )
1504
+ elif method == "GET":
1505
+ if gateway_path.strip("/") != "api/2.0/endpoints":
1506
+ raise MlflowException(
1507
+ message=f"Invalid gateway_path: {gateway_path} for method: {method}",
1508
+ error_code=INVALID_PARAMETER_VALUE,
1509
+ )
1510
+ elif method == "POST":
1511
+ # For POST, gateway_path must be in the form of "gateway/{name}/invocations"
1512
+ if not re.fullmatch(r"gateway/[^/]+/invocations", gateway_path.strip("/")):
1513
+ raise MlflowException(
1514
+ message=f"Invalid gateway_path: {gateway_path} for method: {method}",
1515
+ error_code=INVALID_PARAMETER_VALUE,
1516
+ )
1517
+
1518
+
1519
@catch_mlflow_exception
def gateway_proxy_handler():
    """Forward a request to the configured deployments target.

    With no target configured, responds as an empty gateway (``{"endpoints": []}``).
    ``GET`` reads ``gateway_path`` from the query string; other methods read it from the
    JSON body. ``json_data`` is forwarded as the upstream JSON payload.

    Raises:
        MlflowException: If the path is not allowed, or if the upstream status is not 200.
    """
    target_uri = MLFLOW_DEPLOYMENTS_TARGET.get()
    if not target_uri:
        # Pretend an empty gateway service is running
        return {"endpoints": []}

    method = request.method
    payload = request.args if method == "GET" else request.json
    gateway_path = payload.get("gateway_path")
    _validate_gateway_path(method, gateway_path)

    # NOTE(review): no timeout is passed, so a stalled upstream blocks this worker.
    upstream = requests.request(
        method, f"{target_uri}/{gateway_path}", json=payload.get("json_data", None)
    )
    if upstream.status_code != 200:
        raise MlflowException(
            message=f"Deployments proxy request failed with error code {upstream.status_code}. "
            f"Error message: {upstream.text}",
            error_code=upstream.status_code,
        )
    return upstream.json()
1539
+
1540
+
1541
@catch_mlflow_exception
@_disable_if_artifacts_only
def create_promptlab_run_handler():
    """Create a prompt-lab run from a JSON request body.

    Required fields: ``experiment_id``, ``prompt_template``, ``prompt_parameters``,
    ``model_route``, ``model_input``, ``mlflow_version``.

    Optional fields: ``run_name``, ``tags``, ``model_parameters``, ``model_output``,
    ``model_output_parameters``, ``user_id`` (default ``"unknown"``), and ``start_time``
    (default: current time in milliseconds).

    Responds with the created run as a ``CreateRun.Response`` in JSON.

    Raises:
        MlflowException: ``INVALID_PARAMETER_VALUE`` when a required field is missing or empty.
    """

    def assert_arg_exists(arg_name, arg):
        if not arg:
            raise MlflowException(
                message=f"CreatePromptlabRun request must specify {arg_name}.",
                error_code=INVALID_PARAMETER_VALUE,
            )

    def _to_params(raw_params):
        # Each entry is expected to be a {"key": ..., "value": ...} mapping.
        return [Param(param.get("key"), param.get("value")) for param in raw_params]

    _validate_content_type(request, ["application/json"])

    args = request.json
    experiment_id = args.get("experiment_id")
    assert_arg_exists("experiment_id", experiment_id)
    run_name = args.get("run_name", None)
    tags = args.get("tags", [])
    prompt_template = args.get("prompt_template")
    assert_arg_exists("prompt_template", prompt_template)
    raw_prompt_parameters = args.get("prompt_parameters")
    assert_arg_exists("prompt_parameters", raw_prompt_parameters)
    # Reuse the value that was just validated instead of looking it up a second time.
    prompt_parameters = _to_params(raw_prompt_parameters)
    model_route = args.get("model_route")
    assert_arg_exists("model_route", model_route)
    model_parameters = _to_params(args.get("model_parameters", []))
    model_input = args.get("model_input")
    assert_arg_exists("model_input", model_input)
    model_output = args.get("model_output", None)
    model_output_parameters = _to_params(args.get("model_output_parameters", []))
    mlflow_version = args.get("mlflow_version")
    assert_arg_exists("mlflow_version", mlflow_version)
    user_id = args.get("user_id", "unknown")

    # use current time if not provided
    start_time = args.get("start_time", int(time.time() * 1000))

    store = _get_tracking_store()

    run = _create_promptlab_run_impl(
        store,
        experiment_id=experiment_id,
        run_name=run_name,
        tags=tags,
        prompt_template=prompt_template,
        prompt_parameters=prompt_parameters,
        model_route=model_route,
        model_parameters=model_parameters,
        model_input=model_input,
        model_output=model_output,
        model_output_parameters=model_output_parameters,
        mlflow_version=mlflow_version,
        user_id=user_id,
        start_time=start_time,
    )
    response_message = CreateRun.Response()
    response_message.run.MergeFrom(run.to_proto())
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(response_message))
    return response
1608
+
1609
+
1610
@catch_mlflow_exception
def upload_artifact_handler():
    """Upload a single artifact (at most 10MB, taken from the raw request body) to a run.

    Query parameters:
        run_uuid: Id of the run that owns the artifact (required).
        path: Artifact path relative to the run's artifact root (required). It is checked
            by ``InputValidator.validate_artifact_path`` and then ``validate_path_is_safe``.

    Raises:
        MlflowException: ``INVALID_PARAMETER_VALUE`` in any of these cases:
            - a query parameter is missing;
            - the path is invalid;
            - the body is empty;
            - the body is larger than 10MB.
    """
    max_artifact_bytes = 10 * 1024 * 1024

    args = request.args
    run_uuid = args.get("run_uuid")
    if not run_uuid:
        raise MlflowException(
            message="Request must specify run_uuid.",
            error_code=INVALID_PARAMETER_VALUE,
        )
    path = args.get("path")
    if not path:
        raise MlflowException(
            message="Request must specify path.",
            error_code=INVALID_PARAMETER_VALUE,
        )

    # Security validation for artifact path
    try:
        validated_path = InputValidator.validate_artifact_path(path)
    except SecurityValidationError as e:
        raise MlflowException(
            f"Invalid artifact path: {e}",
            error_code=INVALID_PARAMETER_VALUE,
        ) from e

    path = validate_path_is_safe(validated_path)

    def _too_large():
        return MlflowException(
            message="Artifact size is too large. Max size is 10MB.",
            error_code=INVALID_PARAMETER_VALUE,
        )

    # Reject early from the declared Content-Length when the client sent one...
    if request.content_length and request.content_length > max_artifact_bytes:
        raise _too_large()

    data = request.data
    if not data:
        raise MlflowException(
            message="Request must specify data.",
            error_code=INVALID_PARAMETER_VALUE,
        )
    # ...and also enforce the limit on the bytes actually received, because the header can
    # be absent (e.g. chunked transfer encoding). Without it the check above is skipped.
    if len(data) > max_artifact_bytes:
        raise _too_large()

    run = _get_tracking_store().get_run(run_uuid)
    artifact_dir = run.info.artifact_uri

    basename = posixpath.basename(path)
    dirname = posixpath.dirname(path)

    def _log_artifact_to_repo(file, run, dirname, artifact_dir):
        if _is_servable_proxied_run_artifact_root(run.info.artifact_uri):
            artifact_repo = _get_artifact_repo_mlflow_artifacts()
            # Artifact-repository paths are always '/'-separated, so build them with
            # posixpath. os.path.join would insert backslashes on Windows.
            run_artifacts_root = posixpath.join(
                run.info.experiment_id, run.info.run_id, "artifacts"
            )
            path_to_log = posixpath.join(run_artifacts_root, dirname) if dirname else run_artifacts_root
        else:
            artifact_repo = get_artifact_repository(artifact_dir)
            path_to_log = dirname

        artifact_repo.log_artifact(file, path_to_log)

    # Stage the upload in a temporary local file so it can be passed to log_artifact.
    with tempfile.TemporaryDirectory() as tmpdir:
        dir_path = os.path.join(tmpdir, dirname) if dirname else tmpdir
        file_path = os.path.join(dir_path, basename)

        os.makedirs(dir_path, exist_ok=True)

        with open(file_path, "wb") as f:
            f.write(data)

        _log_artifact_to_repo(file_path, run, dirname, artifact_dir)

    return Response(mimetype="application/json")
1682
+
1683
+
1684
@catch_mlflow_exception
@_disable_if_artifacts_only
def _search_experiments():
    """Search experiments and respond with one page of results as JSON.

    ``next_page_token`` is only set when the store returned a continuation token.
    """
    msg = _get_request_message(
        SearchExperiments(),
        schema={
            "view_type": [_assert_intlike],
            "max_results": [_assert_intlike],
            "order_by": [_assert_array],
            "filter": [_assert_string],
            "page_token": [_assert_string],
        },
    )
    found = _get_tracking_store().search_experiments(
        view_type=msg.view_type,
        max_results=msg.max_results,
        order_by=msg.order_by,
        filter_string=msg.filter,
        page_token=msg.page_token,
    )
    reply = SearchExperiments.Response()
    reply.experiments.extend([experiment.to_proto() for experiment in found])
    continuation = found.token
    if continuation:
        reply.next_page_token = continuation
    return Response(message_to_json(reply), mimetype="application/json")
1711
+
1712
+
1713
@catch_mlflow_exception
def _get_artifact_repo(run):
    """Return the artifact repository for the artifact URI recorded on ``run``."""
    artifact_uri = run.info.artifact_uri
    repo = get_artifact_repository(artifact_uri)
    return repo
1716
+
1717
+
1718
@catch_mlflow_exception
@_disable_if_artifacts_only
def _log_batch():
    """Log a batch of metrics, params and tags to a single run.

    The raw JSON is first checked by ``_validate_batch_log_api_req``. The schema then
    requires:
    - every metric has ``key``, ``value`` and ``timestamp``;
    - every param has ``key``;
    - every tag has ``key``.

    Responds with an empty JSON body.
    """

    def _require_keys(entries, label):
        # Report missing keys with a path such as ``params[3].key``.
        for position, entry in enumerate(entries):
            _assert_required(entry.get("key"), path=f"{label}[{position}].key")

    def _assert_metrics_fields_present(metrics):
        for position, metric in enumerate(metrics):
            for field_name in ("key", "value", "timestamp"):
                _assert_required(metric.get(field_name), path=f"metrics[{position}].{field_name}")

    def _assert_params_fields_present(params):
        _require_keys(params, "params")

    def _assert_tags_fields_present(tags):
        _require_keys(tags, "tags")

    _validate_batch_log_api_req(_get_request_json())
    msg = _get_request_message(
        LogBatch(),
        schema={
            "run_id": [_assert_string, _assert_required],
            "metrics": [_assert_array, _assert_metrics_fields_present],
            "params": [_assert_array, _assert_params_fields_present],
            "tags": [_assert_array, _assert_tags_fields_present],
        },
    )
    batch_metrics = [Metric.from_proto(m) for m in msg.metrics]
    batch_params = [Param.from_proto(p) for p in msg.params]
    batch_tags = [RunTag.from_proto(t) for t in msg.tags]
    _get_tracking_store().log_batch(
        run_id=msg.run_id, metrics=batch_metrics, params=batch_params, tags=batch_tags
    )
    return Response(message_to_json(LogBatch.Response()), mimetype="application/json")
1755
+
1756
+
1757
@catch_mlflow_exception
@_disable_if_artifacts_only
def _log_model():
    """
    Handle a LogModel request.

    Parses ``model_json``, checks that it is a JSON object carrying the mandatory model
    fields, and records the model against ``run_id`` in the tracking store.

    Raises:
        MlflowException (INVALID_PARAMETER_VALUE): if ``model_json`` is not valid JSON, is
            not a JSON object, or lacks any of ``artifact_path``, ``flavors``,
            ``utc_time_created`` or ``run_id``.
    """
    request_message = _get_request_message(
        LogModel(),
        schema={
            "run_id": [_assert_string, _assert_required],
            "model_json": [_assert_string, _assert_required],
        },
    )
    try:
        model = json.loads(request_message.model_json)
    except (ValueError, RecursionError) as e:
        # json.JSONDecodeError subclasses ValueError; deeply nested input raises RecursionError.
        raise MlflowException(
            f"Malformed model info. \n {request_message.model_json} \n is not a valid JSON.",
            error_code=INVALID_PARAMETER_VALUE,
        ) from e

    # Valid JSON that is not an object (a list, number, string...) has no keys. Reject it
    # as a client error instead of failing with an AttributeError below.
    if not isinstance(model, dict):
        raise MlflowException(
            f"Malformed model info. \n {request_message.model_json} \n is not a JSON object.",
            error_code=INVALID_PARAMETER_VALUE,
        )

    missing_fields = {"artifact_path", "flavors", "utc_time_created", "run_id"} - set(model)

    if missing_fields:
        raise MlflowException(
            f"Model json is missing mandatory fields: {missing_fields}",
            error_code=INVALID_PARAMETER_VALUE,
        )
    _get_tracking_store().record_logged_model(
        run_id=request_message.run_id, mlflow_model=Model.from_dict(model)
    )
    return _wrap_response(LogModel.Response())
1789
+
1790
+
1791
def _wrap_response(response_message):
    """Serialize a protobuf ``response_message`` to JSON and return it as a Flask ``Response``."""
    http_response = Response(mimetype="application/json")
    payload = message_to_json(response_message)
    http_response.set_data(payload)
    return http_response
1795
+
1796
+
1797
+ # Model Registry APIs
1798
+
1799
+
1800
@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_registered_model():
    """Handle CreateRegisteredModel: register ``name`` with optional ``tags`` and ``description``."""
    req = _get_request_message(
        CreateRegisteredModel(),
        schema={
            "name": [_assert_string, _assert_required],
            "tags": [_assert_array],
            "description": [_assert_string],
        },
    )
    registry = _get_model_registry_store()
    created = registry.create_registered_model(
        name=req.name, tags=req.tags, description=req.description
    )
    return _wrap_response(CreateRegisteredModel.Response(registered_model=created.to_proto()))
1818
+
1819
+
1820
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_registered_model():
    """Handle GetRegisteredModel: fetch the registered model identified by ``name``."""
    name_schema = {"name": [_assert_string, _assert_required]}
    req = _get_request_message(GetRegisteredModel(), schema=name_schema)
    found = _get_model_registry_store().get_registered_model(name=req.name)
    return _wrap_response(GetRegisteredModel.Response(registered_model=found.to_proto()))
1829
+
1830
+
1831
@catch_mlflow_exception
@_disable_if_artifacts_only
def _update_registered_model():
    """
    Handle UpdateRegisteredModel: replace the description of registered model ``name``.

    NOTE(review): unlike ``_update_model_version``, this does not check
    ``HasField("description")``. An omitted description is therefore forwarded as the
    field's default value; confirm that is intended.
    """
    req = _get_request_message(
        UpdateRegisteredModel(),
        schema={
            "name": [_assert_string, _assert_required],
            "description": [_assert_string],
        },
    )
    updated = _get_model_registry_store().update_registered_model(
        name=req.name, description=req.description
    )
    return _wrap_response(UpdateRegisteredModel.Response(registered_model=updated.to_proto()))
1848
+
1849
+
1850
@catch_mlflow_exception
@_disable_if_artifacts_only
def _rename_registered_model():
    """Handle RenameRegisteredModel: rename registered model ``name`` to ``new_name``."""
    required_string = [_assert_string, _assert_required]
    req = _get_request_message(
        RenameRegisteredModel(),
        schema={"name": required_string, "new_name": required_string},
    )
    renamed = _get_model_registry_store().rename_registered_model(
        name=req.name, new_name=req.new_name
    )
    return _wrap_response(RenameRegisteredModel.Response(registered_model=renamed.to_proto()))
1867
+
1868
+
1869
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_registered_model():
    """Handle DeleteRegisteredModel: delete the registered model identified by ``name``."""
    req = _get_request_message(
        DeleteRegisteredModel(),
        schema={"name": [_assert_string, _assert_required]},
    )
    registry = _get_model_registry_store()
    registry.delete_registered_model(name=req.name)
    empty_reply = DeleteRegisteredModel.Response()
    return _wrap_response(empty_reply)
1877
+
1878
+
1879
@catch_mlflow_exception
@_disable_if_artifacts_only
def _search_registered_models():
    """
    Handle SearchRegisteredModels.

    ``max_results`` is capped at 1000. A ``next_page_token`` is returned only when the store
    reports another page.
    """
    max_results_checks = [
        _assert_intlike,
        lambda value: _assert_less_than_or_equal(int(value), 1000),
    ]
    req = _get_request_message(
        SearchRegisteredModels(),
        schema={
            "filter": [_assert_string],
            "max_results": max_results_checks,
            "order_by": [_assert_array, _assert_item_type_string],
            "page_token": [_assert_string],
        },
    )
    page = _get_model_registry_store().search_registered_models(
        filter_string=req.filter,
        max_results=req.max_results,
        order_by=req.order_by,
        page_token=req.page_token,
    )
    reply = SearchRegisteredModels.Response()
    reply.registered_models.extend([model.to_proto() for model in page])
    if page.token:
        reply.next_page_token = page.token
    return _wrap_response(reply)
1906
+
1907
+
1908
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_latest_versions():
    """Handle GetLatestVersions: return the latest versions of ``name``, optionally per ``stages``."""
    req = _get_request_message(
        GetLatestVersions(),
        schema={
            "name": [_assert_string, _assert_required],
            "stages": [_assert_array, _assert_item_type_string],
        },
    )
    versions = _get_model_registry_store().get_latest_versions(name=req.name, stages=req.stages)
    reply = GetLatestVersions.Response()
    reply.model_versions.extend([version.to_proto() for version in versions])
    return _wrap_response(reply)
1924
+
1925
+
1926
@catch_mlflow_exception
@_disable_if_artifacts_only
def _set_registered_model_tag():
    """Handle SetRegisteredModelTag: set tag ``key`` = ``value`` on registered model ``name``."""
    required_string = [_assert_string, _assert_required]
    req = _get_request_message(
        SetRegisteredModelTag(),
        schema={"name": required_string, "key": required_string, "value": [_assert_string]},
    )
    new_tag = RegisteredModelTag(key=req.key, value=req.value)
    _get_model_registry_store().set_registered_model_tag(name=req.name, tag=new_tag)
    return _wrap_response(SetRegisteredModelTag.Response())
1940
+
1941
+
1942
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_registered_model_tag():
    """Handle DeleteRegisteredModelTag: remove tag ``key`` from registered model ``name``."""
    required_string = [_assert_string, _assert_required]
    req = _get_request_message(
        DeleteRegisteredModelTag(),
        schema={"name": required_string, "key": required_string},
    )
    registry = _get_model_registry_store()
    registry.delete_registered_model_tag(name=req.name, key=req.key)
    return _wrap_response(DeleteRegisteredModelTag.Response())
1956
+
1957
+
1958
def _validate_non_local_source_contains_relative_paths(source: str):
    """
    Validation check to ensure that sources that are provided that conform to the schemes:
    http, https, or mlflow-artifacts do not contain relative path designations that are intended
    to access local file system paths on the tracking server.

    Example paths that this validation function is intended to find and raise an Exception if
    passed:
    "mlflow-artifacts://host:port/../../../../"
    "http://host:port/api/2.0/mlflow-artifacts/artifacts/../../../../"
    "https://host:port/api/2.0/mlflow-artifacts/artifacts/../../../../"
    "/models/artifacts/../../../"
    "s3:/my_bucket/models/path/../../other/path"
    "file://path/to/../../../../some/where/you/should/not/be"
    "mlflow-artifacts://host:port/..%2f..%2f..%2f..%2f"
    "http://host:port/api/2.0/mlflow-artifacts/artifacts%00"

    Args:
        source: The model version source URI supplied in the request.

    Raises:
        MlflowException (INVALID_PARAMETER_VALUE): if, after repeated URL-decoding, the source
            contains a ``..`` segment or its path contains a NUL byte. Also raised if the path
            component differs from what ``pathlib.Path(...).resolve()`` produces for it
            (e.g. it contains ``.`` segments).
    """
    invalid_source_error_message = (
        f"Invalid model version source: '{source}'. If supplying a source as an http, https, "
        "local file path, ftp, objectstore, or mlflow-artifacts uri, an absolute path must be "
        "provided without relative path references present. "
        "Please provide an absolute path."
    )

    # Decode until the string stops changing, so multiply-encoded sequences such as
    # "%252e%252e" cannot smuggle ".." past the checks below.
    while (unquoted := urllib.parse.unquote_plus(source)) != source:
        source = unquoted
    # Only the URL's path component is compared: trailing slashes are stripped and runs of
    # slashes collapsed to a single "/".
    source_path = re.sub(r"/+", "/", urllib.parse.urlparse(source).path.rstrip("/"))
    # The ".." scan covers the whole decoded source (not only the path component).
    if "\x00" in source_path or any(p == ".." for p in source.split("/")):
        raise MlflowException(invalid_source_error_message, INVALID_PARAMETER_VALUE)
    # Any normalization resolve() applies makes the result differ from source_path, which is
    # rejected below as a relative-path attempt.
    # NOTE(review): resolve() consults the tracking server's own filesystem. A path whose prefix
    # is a symlink on the server (e.g. /tmp on macOS), or an empty path (e.g. "s3://bucket"),
    # will presumably also fail the comparison; confirm this is acceptable.
    resolved_source = pathlib.Path(source_path).resolve().as_posix()
    # NB: drive split is specifically for Windows since WindowsPath.resolve() will append the
    # drive path of the pwd to a given path. We don't care about the drive here, though.
    _, resolved_path = os.path.splitdrive(resolved_source)

    if resolved_path != source_path:
        raise MlflowException(invalid_source_error_message, INVALID_PARAMETER_VALUE)
1994
+
1995
+
1996
def _validate_source_run(source: str, run_id: str) -> None:
    """
    Validate a model version ``source`` supplied together with an optional ``run_id``.

    A local source is accepted only when ``run_id`` is given, the run's artifact URI is also
    local, and the resolved source equals or lies inside the run's artifact directory.
    Otherwise an error is raised. A non-local source is checked for relative-path
    references instead.

    Raises:
        MlflowException (INVALID_PARAMETER_VALUE): if the source fails the applicable check.
    """
    if not is_local_uri(source):
        # Checks if relative paths are present in the source (a security threat). If any are
        # present, raises an Exception.
        _validate_non_local_source_contains_relative_paths(source)
        return

    if run_id:
        run = _get_tracking_store().get_run(run_id)
        # ``source`` is rebound to the resolved path, which is what the error below reports.
        source = pathlib.Path(local_file_uri_to_path(source)).resolve()
        run_artifact_uri = run.info.artifact_uri
        if is_local_uri(run_artifact_uri):
            run_artifact_dir = pathlib.Path(local_file_uri_to_path(run_artifact_uri)).resolve()
            if run_artifact_dir == source or run_artifact_dir in source.parents:
                return

    raise MlflowException(
        f"Invalid model version source: '{source}'. To use a local path as a model version "
        "source, the run_id request parameter has to be specified and the local path has to be "
        "contained within the artifact directory of the run specified by the run_id.",
        INVALID_PARAMETER_VALUE,
    )
2019
+
2020
+
2021
def _validate_source_model(source: str, model_id: str) -> None:
    """
    Validate a model version ``source`` supplied together with a logged-model ``model_id``.

    A local source is accepted only when ``model_id`` is given, the logged model's artifact
    location is also local, and the resolved source equals or lies inside that location.
    A non-local source is checked for relative-path references instead.

    Raises:
        MlflowException (INVALID_PARAMETER_VALUE): if the source fails the applicable check.
    """
    if is_local_uri(source):
        if model_id:
            model = _get_tracking_store().get_logged_model(model_id)
            # ``source`` is rebound to the resolved path, which is what the error reports.
            source = pathlib.Path(local_file_uri_to_path(source)).resolve()
            if is_local_uri(model.artifact_location):
                model_artifact_dir = pathlib.Path(
                    local_file_uri_to_path(model.artifact_location)
                ).resolve()
                if model_artifact_dir in [source, *source.parents]:
                    return

        # The comparison above is against the logged model's artifact location, not a run's,
        # so the message says so.
        raise MlflowException(
            f"Invalid model version source: '{source}'. To use a local path as a model version "
            "source, the model_id request parameter has to be specified and the local path has to "
            "be contained within the artifact location of the logged model specified by the "
            "model_id.",
            INVALID_PARAMETER_VALUE,
        )

    # Checks if relative paths are present in the source (a security threat). If any are present,
    # raises an Exception.
    _validate_non_local_source_contains_relative_paths(source)
2044
+
2045
+
2046
@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_model_version():
    """
    Handle CreateModelVersion.

    Steps:
    1. If a source-validation regex is configured, reject sources that do not match it.
    2. Unless the request is a prompt, validate ``source`` against ``model_id`` (when
       present) or ``run_id``.
    3. Create the version in the registry.
    4. For non-prompt requests carrying a ``model_id``, record the new version on the logged
       model in the tracking store.
    """
    req = _get_request_message(
        CreateModelVersion(),
        schema={
            "name": [_assert_string, _assert_required],
            "source": [_assert_string, _assert_required],
            "run_id": [_assert_string],
            "tags": [_assert_array],
            "run_link": [_assert_string],
            "description": [_assert_string],
            "model_id": [_assert_string],
        },
    )

    # The configured regex is only consulted when a source was provided.
    source_regex = (
        MLFLOW_CREATE_MODEL_VERSION_SOURCE_VALIDATION_REGEX.get() if req.source else None
    )
    if source_regex and not re.search(source_regex, req.source):
        raise MlflowException(
            f"Invalid model version source: '{req.source}'.",
            error_code=INVALID_PARAMETER_VALUE,
        )

    is_prompt = _is_prompt_request(req)
    # If the model version is a prompt, we don't validate the source
    if not is_prompt:
        if req.model_id:
            _validate_source_model(req.source, req.model_id)
        else:
            _validate_source_run(req.source, req.run_id)

    created = _get_model_registry_store().create_model_version(
        name=req.name,
        source=req.source,
        run_id=req.run_id,
        run_link=req.run_link,
        tags=req.tags,
        description=req.description,
        model_id=req.model_id,
    )
    if req.model_id and not is_prompt:
        _get_tracking_store().set_model_versions_tags(
            name=req.name,
            version=created.version,
            model_id=req.model_id,
        )
    return _wrap_response(CreateModelVersion.Response(model_version=created.to_proto()))
2096
+
2097
+
2098
def _is_prompt_request(request_message):
    """Return True if any tag on ``request_message`` has the key ``IS_PROMPT_TAG_KEY``."""
    for tag in request_message.tags:
        if tag.key == IS_PROMPT_TAG_KEY:
            return True
    return False
2100
+
2101
+
2102
@catch_mlflow_exception
@_disable_if_artifacts_only
def get_model_version_artifact_handler():
    """
    Send one artifact file of a registered model version.

    Reads ``name``, ``version`` and ``path`` from the query string. ``path`` is required and
    is checked with ``validate_path_is_safe``. If the version's download URI is a servable
    proxied artifact root, the file is read through the mlflow-artifacts repository;
    otherwise it is read directly from the repository for that URI.
    """
    name = request.args.get("name")
    version = request.args.get("version")
    path = validate_path_is_safe(request.args["path"])
    artifact_uri = _get_model_registry_store().get_model_version_download_uri(name, version)

    if not _is_servable_proxied_run_artifact_root(artifact_uri):
        return _send_artifact(get_artifact_repository(artifact_uri), path)

    proxied_repo = _get_artifact_repo_mlflow_artifacts()
    proxied_path = _get_proxied_run_artifact_destination_path(
        proxied_artifact_root=artifact_uri,
        relative_path=path,
    )
    return _send_artifact(proxied_repo, proxied_path)
2121
+
2122
+
2123
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_model_version():
    """Handle GetModelVersion: fetch version ``version`` of registered model ``name``."""
    required_string = [_assert_string, _assert_required]
    req = _get_request_message(
        GetModelVersion(),
        schema={"name": required_string, "version": required_string},
    )
    found = _get_model_registry_store().get_model_version(name=req.name, version=req.version)
    return _wrap_response(GetModelVersion.Response(model_version=found.to_proto()))
2139
+
2140
+
2141
@catch_mlflow_exception
@_disable_if_artifacts_only
def _update_model_version():
    """
    Handle UpdateModelVersion: update the description of a model version.

    The description is forwarded only when the request explicitly set it; otherwise the
    store receives ``None``.
    """
    req = _get_request_message(
        UpdateModelVersion(),
        schema={
            "name": [_assert_string, _assert_required],
            "version": [_assert_string, _assert_required],
            "description": [_assert_string],
        },
    )
    description = req.description if req.HasField("description") else None
    updated = _get_model_registry_store().update_model_version(
        name=req.name,
        version=req.version,
        description=description,
    )
    reply = UpdateModelVersion.Response(model_version=updated.to_proto())
    return _wrap_response(reply)
2161
+
2162
+
2163
@catch_mlflow_exception
@_disable_if_artifacts_only
def _transition_stage():
    """
    Handle TransitionModelVersionStage: move a model version to ``stage``.

    Existing versions in the target stage are archived when ``archive_existing_versions``
    is set.
    """
    required_string = [_assert_string, _assert_required]
    req = _get_request_message(
        TransitionModelVersionStage(),
        schema={
            "name": required_string,
            "version": required_string,
            "stage": required_string,
            "archive_existing_versions": [_assert_bool],
        },
    )
    transitioned = _get_model_registry_store().transition_model_version_stage(
        name=req.name,
        version=req.version,
        stage=req.stage,
        archive_existing_versions=req.archive_existing_versions,
    )
    reply = TransitionModelVersionStage.Response(model_version=transitioned.to_proto())
    return _wrap_response(reply)
2184
+
2185
+
2186
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_model_version():
    """Handle DeleteModelVersion: delete version ``version`` of registered model ``name``."""
    required_string = [_assert_string, _assert_required]
    req = _get_request_message(
        DeleteModelVersion(),
        schema={"name": required_string, "version": required_string},
    )
    registry = _get_model_registry_store()
    registry.delete_model_version(name=req.name, version=req.version)
    return _wrap_response(DeleteModelVersion.Response())
2200
+
2201
+
2202
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_model_version_download_uri():
    """
    Handle GetModelVersionDownloadUri: return the artifact URI of a model version.

    ``name`` and ``version`` are validated as required strings, matching the other
    model-version handlers (e.g. ``_get_model_version``). A malformed request is therefore
    rejected during request parsing instead of reaching the registry store.
    """
    request_message = _get_request_message(
        GetModelVersionDownloadUri(),
        schema={
            "name": [_assert_string, _assert_required],
            "version": [_assert_string, _assert_required],
        },
    )
    download_uri = _get_model_registry_store().get_model_version_download_uri(
        name=request_message.name, version=request_message.version
    )
    response_message = GetModelVersionDownloadUri.Response(artifact_uri=download_uri)
    return _wrap_response(response_message)
2211
+
2212
+
2213
@catch_mlflow_exception
@_disable_if_artifacts_only
def _search_model_versions():
    """
    Handle SearchModelVersions: validate the request and delegate to
    ``search_model_versions_impl``. ``max_results`` is capped at 200,000.
    """
    max_results_checks = [
        _assert_intlike,
        lambda value: _assert_less_than_or_equal(int(value), 200_000),
    ]
    req = _get_request_message(
        SearchModelVersions(),
        schema={
            "filter": [_assert_string],
            "max_results": max_results_checks,
            "order_by": [_assert_array, _assert_item_type_string],
            "page_token": [_assert_string],
        },
    )
    return _wrap_response(search_model_versions_impl(req))
2230
+
2231
+
2232
def search_model_versions_impl(request_message):
    """
    Run a model-version search described by a parsed SearchModelVersions ``request_message``.

    Returns a ``SearchModelVersions.Response`` protobuf. ``next_page_token`` is set only when
    the store reports another page.
    """
    page = _get_model_registry_store().search_model_versions(
        filter_string=request_message.filter,
        max_results=request_message.max_results,
        order_by=request_message.order_by,
        page_token=request_message.page_token,
    )
    reply = SearchModelVersions.Response()
    reply.model_versions.extend([version.to_proto() for version in page])
    if page.token:
        reply.next_page_token = page.token
    return reply
2245
+
2246
+
2247
@catch_mlflow_exception
@_disable_if_artifacts_only
def _set_model_version_tag():
    """Handle SetModelVersionTag: set tag ``key`` = ``value`` on a model version."""
    required_string = [_assert_string, _assert_required]
    req = _get_request_message(
        SetModelVersionTag(),
        schema={
            "name": required_string,
            "version": required_string,
            "key": required_string,
            "value": [_assert_string],
        },
    )
    new_tag = ModelVersionTag(key=req.key, value=req.value)
    registry = _get_model_registry_store()
    registry.set_model_version_tag(name=req.name, version=req.version, tag=new_tag)
    return _wrap_response(SetModelVersionTag.Response())
2264
+
2265
+
2266
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_model_version_tag():
    """Handle DeleteModelVersionTag: remove tag ``key`` from a model version."""
    required_string = [_assert_string, _assert_required]
    req = _get_request_message(
        DeleteModelVersionTag(),
        schema={"name": required_string, "version": required_string, "key": required_string},
    )
    registry = _get_model_registry_store()
    registry.delete_model_version_tag(name=req.name, version=req.version, key=req.key)
    return _wrap_response(DeleteModelVersionTag.Response())
2283
+
2284
+
2285
@catch_mlflow_exception
@_disable_if_artifacts_only
def _set_registered_model_alias():
    """Handle SetRegisteredModelAlias: point ``alias`` of model ``name`` at ``version``."""
    required_string = [_assert_string, _assert_required]
    req = _get_request_message(
        SetRegisteredModelAlias(),
        schema={"name": required_string, "alias": required_string, "version": required_string},
    )
    registry = _get_model_registry_store()
    registry.set_registered_model_alias(name=req.name, alias=req.alias, version=req.version)
    return _wrap_response(SetRegisteredModelAlias.Response())
2302
+
2303
+
2304
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_registered_model_alias():
    """Handle DeleteRegisteredModelAlias: remove ``alias`` from registered model ``name``."""
    required_string = [_assert_string, _assert_required]
    req = _get_request_message(
        DeleteRegisteredModelAlias(),
        schema={"name": required_string, "alias": required_string},
    )
    registry = _get_model_registry_store()
    registry.delete_registered_model_alias(name=req.name, alias=req.alias)
    return _wrap_response(DeleteRegisteredModelAlias.Response())
2318
+
2319
+
2320
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_model_version_by_alias():
    """Handle GetModelVersionByAlias: fetch the version of ``name`` that ``alias`` points to."""
    required_string = [_assert_string, _assert_required]
    req = _get_request_message(
        GetModelVersionByAlias(),
        schema={"name": required_string, "alias": required_string},
    )
    found = _get_model_registry_store().get_model_version_by_alias(
        name=req.name, alias=req.alias
    )
    return _wrap_response(GetModelVersionByAlias.Response(model_version=found.to_proto()))
2336
+
2337
+
2338
+ # MLflow Artifacts APIs
2339
+
2340
+
2341
@catch_mlflow_exception
@_disable_unless_serve_artifacts
def _download_artifact(artifact_path):
    """
    A request handler for `GET /mlflow-artifacts/artifacts/<artifact_path>` to download an artifact
    from `artifact_path` (a relative path from the root artifact directory).

    The artifact is downloaded into a private temporary directory and streamed back to the
    client. The file handle is closed and the temporary directory removed in three cases:
    the stream finishes, the response iterator is closed early (e.g. the client
    disconnects), or the download/open itself fails.
    """
    artifact_path = validate_path_is_safe(artifact_path)
    tmp_dir = tempfile.TemporaryDirectory()
    try:
        artifact_repo = _get_artifact_repo_mlflow_artifacts()
        dst = artifact_repo.download_artifacts(artifact_path, tmp_dir.name)
        # Ref: https://stackoverflow.com/a/24613980/6943581
        file_handle = open(dst, "rb")  # noqa: SIM115
    except Exception:
        # Nothing will be streamed, so nothing else would clean up the temp directory.
        tmp_dir.cleanup()
        raise

    def stream_and_remove_file():
        # try/finally releases the handle and temp dir even when the server closes this
        # generator before it is exhausted (GeneratorExit is raised at the ``yield``).
        try:
            yield from file_handle
        finally:
            file_handle.close()
            tmp_dir.cleanup()

    file_sender_response = current_app.response_class(stream_and_remove_file())

    return _response_with_file_attachment_headers(artifact_path, file_sender_response)
2364
+
2365
+
2366
@catch_mlflow_exception
@_disable_unless_serve_artifacts
def _upload_artifact(artifact_path):
    """
    A request handler for `PUT /mlflow-artifacts/artifacts/<artifact_path>` to upload an artifact
    to `artifact_path` (a relative path from the root artifact directory).

    The request body is copied in 1 MB chunks to a file in a temporary directory. That file
    is then logged under the parent directory of ``artifact_path`` (or the root when there
    is none).
    """
    artifact_path = validate_path_is_safe(artifact_path)
    parent_dir, file_name = posixpath.split(artifact_path)
    chunk_size = 1024 * 1024  # 1 MB
    with tempfile.TemporaryDirectory() as scratch_dir:
        local_copy = os.path.join(scratch_dir, file_name)
        with open(local_copy, "wb") as sink:
            while chunk := request.stream.read(chunk_size):
                sink.write(chunk)

        repo = _get_artifact_repo_mlflow_artifacts()
        repo.log_artifact(local_copy, artifact_path=parent_dir or None)

    return _wrap_response(UploadArtifact.Response())
2389
+
2390
+
2391
@catch_mlflow_exception
@_disable_unless_serve_artifacts
def _list_artifacts_mlflow_artifacts():
    """
    A request handler for `GET /mlflow-artifacts/artifacts?path=<value>` to list artifacts in `path`
    (a relative path from the root artifact directory).

    Each entry is reported by its base name only, together with its directory flag and size.
    """
    req = _get_request_message(ListArtifactsMlflowArtifacts())
    path = None
    if req.HasField("path"):
        path = validate_path_is_safe(req.path)
    repo = _get_artifact_repo_mlflow_artifacts()
    listing = [
        FileInfo(posixpath.basename(info.path), info.is_dir, info.file_size).to_proto()
        for info in repo.list_artifacts(path)
    ]
    reply = ListArtifacts.Response()
    reply.files.extend(listing)
    return _wrap_response(reply)
2411
+
2412
+
2413
@catch_mlflow_exception
@_disable_unless_serve_artifacts
def _delete_artifact_mlflow_artifacts(artifact_path):
    """
    A request handler that deletes the artifacts at ``artifact_path`` (a relative path from the
    root artifact directory) in the mlflow-artifacts repository.

    The target comes from the ``artifact_path`` argument. The request message is parsed, but
    none of its fields are used.
    """
    artifact_path = validate_path_is_safe(artifact_path)
    _get_request_message(DeleteArtifact())
    repo = _get_artifact_repo_mlflow_artifacts()
    repo.delete_artifacts(artifact_path)
    return _wrap_response(DeleteArtifact.Response())
2428
+
2429
+
2430
@catch_mlflow_exception
def _graphql():
    """
    Execute a GraphQL request against the MLflow GraphQL schema and return the result as JSON.

    The request JSON carries ``query`` (required string), ``variables`` and ``operationName``.
    Before execution, the parsed query goes through ``check_query_safety``. If that returns
    a result, the query is not executed and that result is returned instead.

    Returns:
        A JSON response ``{"data": ..., "errors": [messages] or None}``.

    Raises:
        MlflowException (INVALID_PARAMETER_VALUE): if ``query`` is missing or not a string.
    """
    from graphql import parse

    from mlflow.server.graphql.graphql_no_batching import check_query_safety
    from mlflow.server.graphql.graphql_schema_extensions import schema

    # Extracting the query, variables, and operationName from the request
    request_json = _get_request_json()
    query = request_json.get("query")
    variables = request_json.get("variables")
    operation_name = request_json.get("operationName")

    # ``parse`` expects a string. Without this check, a missing query would presumably
    # surface as an unhandled server error instead of a client error.
    if not isinstance(query, str):
        raise MlflowException(
            "GraphQL request must include a 'query' field of type string.",
            error_code=INVALID_PARAMETER_VALUE,
        )

    node = parse(query)
    if check_result := check_query_safety(node):
        result = check_result
    else:
        # Executing the GraphQL query using the Graphene schema
        result = schema.execute(query, variables=variables, operation_name=operation_name)

    # Convert execution result into json.
    result_data = {
        "data": result.data,
        "errors": [error.message for error in result.errors] if result.errors else None,
    }

    # Return the response
    return jsonify(result_data)
2458
+
2459
+
2460
def _validate_support_multipart_upload(artifact_repo):
    """Raise ``_UnsupportedMultipartUploadException`` unless ``artifact_repo`` is a ``MultipartUploadMixin``."""
    if isinstance(artifact_repo, MultipartUploadMixin):
        return
    raise _UnsupportedMultipartUploadException()
2463
+
2464
+
2465
@catch_mlflow_exception
@_disable_unless_serve_artifacts
def _create_multipart_upload_artifact(artifact_path):
    """
    A request handler for `POST /mlflow-artifacts/mpu/create` to create a multipart upload
    to `artifact_path` (a relative path from the root artifact directory).

    The mlflow-artifacts repository must support multipart uploads. Its create-upload result
    is returned as JSON.
    """
    artifact_path = validate_path_is_safe(artifact_path)

    req = _get_request_message(
        CreateMultipartUpload(),
        schema={
            "path": [_assert_required, _assert_string],
            "num_parts": [_assert_intlike],
        },
    )

    repo = _get_artifact_repo_mlflow_artifacts()
    _validate_support_multipart_upload(repo)

    created = repo.create_multipart_upload(req.path, req.num_parts, artifact_path)
    return _wrap_response(created.to_proto())
2496
+
2497
+
2498
@catch_mlflow_exception
@_disable_unless_serve_artifacts
def _complete_multipart_upload_artifact(artifact_path):
    """
    A request handler for `POST /mlflow-artifacts/mpu/complete` to complete a multipart upload
    to `artifact_path` (a relative path from the root artifact directory).

    The uploaded ``parts`` are converted to ``MultipartUploadPart`` objects and handed to
    the repository together with ``upload_id``.
    """
    artifact_path = validate_path_is_safe(artifact_path)

    req = _get_request_message(
        CompleteMultipartUpload(),
        schema={
            "path": [_assert_required, _assert_string],
            "upload_id": [_assert_string],
            "parts": [_assert_required],
        },
    )
    uploaded_parts = [MultipartUploadPart.from_proto(part_proto) for part_proto in req.parts]

    repo = _get_artifact_repo_mlflow_artifacts()
    _validate_support_multipart_upload(repo)

    repo.complete_multipart_upload(req.path, req.upload_id, uploaded_parts, artifact_path)
    return _wrap_response(CompleteMultipartUpload.Response())
2529
+
2530
+
2531
@catch_mlflow_exception
@_disable_unless_serve_artifacts
def _abort_multipart_upload_artifact(artifact_path):
    """
    A request handler for `POST /mlflow-artifacts/mpu/abort` to abort a multipart upload
    to `artifact_path` (a relative path from the root artifact directory).
    """
    artifact_path = validate_path_is_safe(artifact_path)

    req = _get_request_message(
        AbortMultipartUpload(),
        schema={
            "path": [_assert_required, _assert_string],
            "upload_id": [_assert_string],
        },
    )

    repo = _get_artifact_repo_mlflow_artifacts()
    _validate_support_multipart_upload(repo)

    repo.abort_multipart_upload(req.path, req.upload_id, artifact_path)
    return _wrap_response(AbortMultipartUpload.Response())
2559
+
2560
+
2561
+ # MLflow Tracing APIs
2562
+
2563
+
2564
@catch_mlflow_exception
@_disable_if_artifacts_only
def _start_trace_v3():
    """
    Handle `POST /mlflow/traces`: persist a new TraceInfo record in the tracking store.

    Returns:
        A wrapped ``StartTraceV3.Response`` whose trace carries the TraceInfo returned by the
        store's ``start_trace``.
    """
    start_request = _get_request_message(
        StartTraceV3(),
        schema={"trace": [_assert_required]},
    )
    incoming_info = TraceInfo.from_proto(start_request.trace.trace_info)
    stored_info = _get_tracking_store().start_trace(incoming_info)

    proto_trace = ProtoTrace(trace_info=stored_info.to_proto())
    return _wrap_response(StartTraceV3.Response(trace=proto_trace))
2578
+
2579
+
2580
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_trace_info_v3(trace_id):
    """
    Handle `GET /mlflow/traces/{trace_id}/info`: fetch a stored TraceInfo record.

    Args:
        trace_id: Identifier of the trace to look up.

    Returns:
        A wrapped ``GetTraceInfoV3.Response`` containing the trace's info.
    """
    stored_info = _get_tracking_store().get_trace_info(trace_id)
    proto_trace = ProtoTrace(trace_info=stored_info.to_proto())
    return _wrap_response(GetTraceInfoV3.Response(trace=proto_trace))
2590
+
2591
+
2592
@catch_mlflow_exception
@_disable_if_artifacts_only
def _search_traces_v3():
    """
    Handle `GET /mlflow/traces`: search TraceInfo records in the tracking store.

    Only locations that carry an ``mlflow_experiment`` contribute experiment IDs to the search.
    Any other kind of location in the request is skipped.

    Returns:
        A wrapped ``SearchTracesV3.Response`` with the matching traces and, when the store
        returned one, a ``next_page_token``.
    """
    search_request = _get_request_message(
        SearchTracesV3(),
        schema={
            "locations": [_assert_array, _assert_required],
            "filter": [_assert_string],
            "max_results": [
                _assert_intlike,
                lambda x: _assert_less_than_or_equal(int(x), 500),
            ],
            "order_by": [_assert_array, _assert_item_type_string],
            "page_token": [_assert_string],
        },
    )
    experiment_ids = [
        loc.mlflow_experiment.experiment_id
        for loc in search_request.locations
        if loc.HasField("mlflow_experiment")
    ]

    found_traces, next_token = _get_tracking_store().search_traces(
        experiment_ids=experiment_ids,
        filter_string=search_request.filter,
        max_results=search_request.max_results,
        order_by=search_request.order_by,
        page_token=search_request.page_token,
    )

    search_response = SearchTracesV3.Response()
    search_response.traces.extend([trace.to_proto() for trace in found_traces])
    if next_token:
        search_response.next_page_token = next_token
    return _wrap_response(search_response)
2628
+
2629
+
2630
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_traces():
    """
    Handle `POST /mlflow/traces/delete-traces`: delete TraceInfo records from the tracking store.

    Returns:
        A wrapped ``DeleteTraces.Response`` whose ``traces_deleted`` is the value reported by the
        store's ``delete_traces``.
    """
    delete_request = _get_request_message(
        DeleteTraces(),
        schema={
            "experiment_id": [_assert_string, _assert_required],
            "max_timestamp_millis": [_assert_intlike],
            "max_traces": [_assert_intlike],
            "request_ids": [_assert_array, _assert_item_type_string],
        },
    )

    # Reading an unset optional proto field yields the type's default (0 for an int), not None.
    # An omitted bound and an explicit 0 mean different things to the store. So any field the
    # client did not set (checked with `HasField`) is forwarded as None.
    def _value_or_none(field_name):
        if not delete_request.HasField(field_name):
            return None
        return getattr(delete_request, field_name)

    deleted_count = _get_tracking_store().delete_traces(
        experiment_id=delete_request.experiment_id,
        max_timestamp_millis=_value_or_none("max_timestamp_millis"),
        max_traces=_value_or_none("max_traces"),
        trace_ids=delete_request.request_ids,
    )
    return _wrap_response(DeleteTraces.Response(traces_deleted=deleted_count))
2665
+
2666
+
2667
@catch_mlflow_exception
@_disable_if_artifacts_only
def _set_trace_tag(request_id):
    """
    Handle `PATCH /mlflow/traces/{request_id}/tags`: set one tag on a TraceInfo record.

    Args:
        request_id: Identifier of the trace to tag.

    Returns:
        A wrapped, empty ``SetTraceTag.Response``.
    """
    tag_request = _get_request_message(
        SetTraceTag(),
        schema={
            "key": [_assert_string, _assert_required],
            "value": [_assert_string],
        },
    )
    store = _get_tracking_store()
    store.set_trace_tag(request_id, tag_request.key, tag_request.value)
    return _wrap_response(SetTraceTag.Response())
2682
+
2683
+
2684
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_trace_tag(request_id):
    """
    Handle `DELETE /mlflow/traces/{request_id}/tags`: remove one tag from a TraceInfo record.

    Args:
        request_id: Identifier of the trace whose tag is removed.

    Returns:
        A wrapped, empty ``DeleteTraceTag.Response``.
    """
    untag_request = _get_request_message(
        DeleteTraceTag(),
        schema={
            "key": [_assert_string, _assert_required],
        },
    )
    store = _get_tracking_store()
    store.delete_trace_tag(request_id, untag_request.key)
    return _wrap_response(DeleteTraceTag.Response())
2699
+
2700
+
2701
@catch_mlflow_exception
@_disable_if_artifacts_only
def get_trace_artifact_handler():
    """
    Send a trace's data as a JSON file attachment.

    The trace is selected by the ``request_id`` query parameter. Its data is fetched from the
    artifact repository associated with the trace's info.

    Raises:
        MlflowException: With ``BAD_REQUEST`` when ``request_id`` is missing or empty.
    """
    request_id = request.args.get("request_id")
    if not request_id:
        raise MlflowException(
            'Request must include the "request_id" query parameter.',
            error_code=BAD_REQUEST,
        )

    trace_info = _get_tracking_store().get_trace_info(request_id)
    trace_data = _get_trace_artifact_repo(trace_info).download_trace_data()

    # Serialize into an in-memory buffer (positioned at offset 0) so no temp file is needed.
    payload = io.BytesIO(json.dumps(trace_data).encode())

    file_sender_response = send_file(
        payload,
        mimetype="application/octet-stream",
        as_attachment=True,
        download_name=TRACE_DATA_FILE_NAME,
    )
    return _response_with_file_attachment_headers(TRACE_DATA_FILE_NAME, file_sender_response)
2727
+
2728
+
2729
+ # Deprecated MLflow Tracing APIs. Kept for backward compatibility but do not use.
2730
+
2731
+
2732
@catch_mlflow_exception
@_disable_if_artifacts_only
def _deprecated_start_trace_v2():
    """
    Legacy V2 handler for `POST /mlflow/traces`: create a new TraceInfo record.

    Kept only for backward compatibility with V2 clients.

    Returns:
        A wrapped ``StartTrace.Response`` containing the created trace info.
    """
    start_request = _get_request_message(
        StartTrace(),
        schema={
            "experiment_id": [_assert_string],
            "timestamp_ms": [_assert_intlike],
            "request_metadata": [_assert_map_key_present],
            "tags": [_assert_map_key_present],
        },
    )
    # Repeated key/value entries are flattened into plain dicts for the store.
    metadata = {entry.key: entry.value for entry in start_request.request_metadata}
    trace_tags = {entry.key: entry.value for entry in start_request.tags}

    created_info = _get_tracking_store().deprecated_start_trace_v2(
        experiment_id=start_request.experiment_id,
        timestamp_ms=start_request.timestamp_ms,
        request_metadata=metadata,
        tags=trace_tags,
    )
    return _wrap_response(StartTrace.Response(trace_info=created_info.to_proto()))
2758
+
2759
+
2760
@catch_mlflow_exception
@_disable_if_artifacts_only
def _deprecated_end_trace_v2(request_id):
    """
    Legacy V2 handler for `PATCH /mlflow/traces/{request_id}`: mark a TraceInfo record completed.

    Kept only for backward compatibility with V2 clients.

    Args:
        request_id: Identifier of the trace to end.

    Returns:
        A wrapped ``EndTrace.Response`` containing the trace info in its V2 representation.
    """
    end_request = _get_request_message(
        EndTrace(),
        schema={
            "timestamp_ms": [_assert_intlike],
            "status": [_assert_string],
            "request_metadata": [_assert_map_key_present],
            "tags": [_assert_map_key_present],
        },
    )
    metadata = {entry.key: entry.value for entry in end_request.request_metadata}
    trace_tags = {entry.key: entry.value for entry in end_request.tags}

    ended_info = _get_tracking_store().deprecated_end_trace_v2(
        request_id=request_id,
        timestamp_ms=end_request.timestamp_ms,
        status=TraceStatus.from_proto(end_request.status),
        request_metadata=metadata,
        tags=trace_tags,
    )

    # The store may hand back a V3 TraceInfo; this V2 endpoint must answer in the V2 shape.
    if isinstance(ended_info, TraceInfo):
        ended_info = TraceInfoV2.from_v3(ended_info)

    return _wrap_response(EndTrace.Response(trace_info=ended_info.to_proto()))
2792
+
2793
+
2794
@catch_mlflow_exception
@_disable_if_artifacts_only
def _deprecated_get_trace_info_v2(request_id):
    """
    Legacy V2 handler for `GET /mlflow/traces/{request_id}/info`: fetch a stored TraceInfo.

    The store's (V3) TraceInfo is converted to the V2 representation before it is returned.

    Args:
        request_id: Identifier of the trace to look up.

    Returns:
        A wrapped ``GetTraceInfo.Response``.
    """
    v3_info = _get_tracking_store().get_trace_info(request_id)
    v2_proto = TraceInfoV2.from_v3(v3_info).to_proto()
    return _wrap_response(GetTraceInfo.Response(trace_info=v2_proto))
2805
+
2806
+
2807
@catch_mlflow_exception
@_disable_if_artifacts_only
def _deprecated_search_traces_v2():
    """
    Legacy V2 handler for `GET /mlflow/traces`: search TraceInfo records.

    Each matching trace is converted from the store's V3 TraceInfo to the V2 representation.

    Returns:
        A wrapped ``SearchTraces.Response`` with the matching traces and, when the store
        returned one, a ``next_page_token``.
    """
    search_request = _get_request_message(
        SearchTraces(),
        schema={
            "experiment_ids": [
                _assert_array,
                _assert_item_type_string,
                _assert_required,
            ],
            "filter": [_assert_string],
            "max_results": [
                _assert_intlike,
                lambda x: _assert_less_than_or_equal(int(x), 500),
            ],
            "order_by": [_assert_array, _assert_item_type_string],
            "page_token": [_assert_string],
        },
    )
    found_traces, next_token = _get_tracking_store().search_traces(
        experiment_ids=search_request.experiment_ids,
        filter_string=search_request.filter,
        max_results=search_request.max_results,
        order_by=search_request.order_by,
        page_token=search_request.page_token,
    )
    legacy_traces = [TraceInfoV2.from_v3(trace) for trace in found_traces]

    search_response = SearchTraces.Response()
    search_response.traces.extend([legacy.to_proto() for legacy in legacy_traces])
    if next_token:
        search_response.next_page_token = next_token
    return _wrap_response(search_response)
2843
+
2844
+
2845
+ # Logged Models APIs
2846
+
2847
+
2848
@catch_mlflow_exception
@_disable_if_artifacts_only
def get_logged_model_artifact_handler(model_id: str):
    """
    Serve one artifact file belonging to the logged model ``model_id``.

    The file is named by the ``artifact_file_path`` query parameter and resolved relative to the
    model's artifact location. When that location is a servable proxied artifact root, the file
    is read through the mlflow-artifacts repository. Otherwise it is read from the repository for
    the location itself.

    Args:
        model_id: ID of the logged model whose artifact is requested.

    Returns:
        The response produced by ``_send_artifact`` for the requested file.

    Raises:
        MlflowException: With ``BAD_REQUEST`` when ``artifact_file_path`` is missing. Unsafe
            paths are presumably rejected by ``validate_path_is_safe``.
    """
    artifact_file_path = request.args.get("artifact_file_path")
    if not artifact_file_path:
        raise MlflowException(
            'Request must include the "artifact_file_path" query parameter.',
            error_code=BAD_REQUEST,
        )
    # Keep the validator's return value, as the other artifact handlers do. The path that is
    # served is then exactly the path that was checked, not the raw query-string input.
    artifact_file_path = validate_path_is_safe(artifact_file_path)

    logged_model: LoggedModel = _get_tracking_store().get_logged_model(model_id)
    if _is_servable_proxied_run_artifact_root(logged_model.artifact_location):
        artifact_repo = _get_artifact_repo_mlflow_artifacts()
        artifact_path = _get_proxied_run_artifact_destination_path(
            proxied_artifact_root=logged_model.artifact_location,
            relative_path=artifact_file_path,
        )
    else:
        artifact_repo = get_artifact_repository(logged_model.artifact_location)
        artifact_path = artifact_file_path

    return _send_artifact(artifact_repo, artifact_path)
2871
+
2872
+
2873
@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_logged_model():
    """
    Handle a CreateLoggedModel request: register a new logged model in the tracking store.

    An empty ``name`` is forwarded as None. Empty ``params`` / ``tags`` lists are also forwarded
    as None.

    Returns:
        A wrapped ``CreateLoggedModel.Response`` containing the created model.
    """
    create_request = _get_request_message(
        CreateLoggedModel(),
        schema={
            "experiment_id": [_assert_string, _assert_required],
            "name": [_assert_string],
            "model_type": [_assert_string],
            "source_run_id": [_assert_string],
            "params": [_assert_array],
            "tags": [_assert_array],
        },
    )

    model_params = None
    if create_request.params:
        model_params = [LoggedModelParameter.from_proto(p) for p in create_request.params]

    model_tags = None
    if create_request.tags:
        model_tags = [LoggedModelTag(key=t.key, value=t.value) for t in create_request.tags]

    created = _get_tracking_store().create_logged_model(
        experiment_id=create_request.experiment_id,
        name=create_request.name or None,
        model_type=create_request.model_type,
        source_run_id=create_request.source_run_id,
        params=model_params,
        tags=model_tags,
    )
    return _wrap_response(CreateLoggedModel.Response(model=created.to_proto()))
2906
+
2907
+
2908
@catch_mlflow_exception
@_disable_if_artifacts_only
def _log_logged_model_params(model_id: str):
    """
    Log parameters on the logged model identified by the ``model_id`` path parameter.

    Args:
        model_id: ID of the logged model to update.

    Returns:
        A wrapped, empty ``LogLoggedModelParamsRequest.Response``.
    """
    params_request = _get_request_message(
        LogLoggedModelParamsRequest(),
        schema={
            "model_id": [_assert_string, _assert_required],
            "params": [_assert_array],
        },
    )
    # An empty repeated field simply produces an empty list here.
    params = [LoggedModelParameter.from_proto(p) for p in params_request.params]
    _get_tracking_store().log_logged_model_params(model_id, params)
    return _wrap_response(LogLoggedModelParamsRequest.Response())
2925
+
2926
+
2927
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_logged_model(model_id: str):
    """
    Fetch the logged model identified by ``model_id``.

    Returns:
        A wrapped ``GetLoggedModel.Response`` containing the model.
    """
    found = _get_tracking_store().get_logged_model(model_id)
    return _wrap_response(GetLoggedModel.Response(model=found.to_proto()))
2933
+
2934
+
2935
@catch_mlflow_exception
@_disable_if_artifacts_only
def _finalize_logged_model(model_id: str):
    """
    Set the final status of the logged model identified by the ``model_id`` path parameter.

    The request body must also carry ``model_id`` (schema-required), and it must match the path
    parameter. Sibling logged-model handlers act on the path parameter, so a body that names a
    different model is rejected rather than silently finalizing another model.

    Args:
        model_id: ID of the logged model to finalize, taken from the request path.

    Returns:
        A wrapped ``FinalizeLoggedModel.Response`` containing the updated model.

    Raises:
        MlflowException: With ``BAD_REQUEST`` when the body's ``model_id`` differs from the path.
    """
    finalize_request = _get_request_message(
        FinalizeLoggedModel(),
        schema={
            "model_id": [_assert_string, _assert_required],
            "status": [_assert_intlike, _assert_required],
        },
    )
    if finalize_request.model_id != model_id:
        raise MlflowException(
            f"Model ID in the request body ({finalize_request.model_id!r}) does not match "
            f"the model ID in the request path ({model_id!r}).",
            error_code=BAD_REQUEST,
        )
    model = _get_tracking_store().finalize_logged_model(
        model_id, LoggedModelStatus.from_int(finalize_request.status)
    )
    response_message = FinalizeLoggedModel.Response(model=model.to_proto())
    return _wrap_response(response_message)
2950
+
2951
+
2952
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_logged_model(model_id: str):
    """
    Delete the logged model identified by ``model_id``.

    Returns:
        A wrapped, empty ``DeleteLoggedModel.Response``.
    """
    store = _get_tracking_store()
    store.delete_logged_model(model_id)
    return _wrap_response(DeleteLoggedModel.Response())
2957
+
2958
+
2959
@catch_mlflow_exception
@_disable_if_artifacts_only
def _set_logged_model_tags(model_id: str):
    """
    Set the tags given in the request body on the logged model identified by ``model_id``.

    Returns:
        A wrapped, empty ``SetLoggedModelTags.Response``.
    """
    tags_request = _get_request_message(
        SetLoggedModelTags(),
        schema={"tags": [_assert_array]},
    )
    new_tags = [LoggedModelTag(key=entry.key, value=entry.value) for entry in tags_request.tags]
    store = _get_tracking_store()
    store.set_logged_model_tags(model_id, new_tags)
    return _wrap_response(SetLoggedModelTags.Response())
2969
+
2970
+
2971
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_logged_model_tag(model_id: str, tag_key: str):
    """
    Remove the tag ``tag_key`` from the logged model identified by ``model_id``.

    Returns:
        A wrapped, empty ``DeleteLoggedModelTag.Response``.
    """
    store = _get_tracking_store()
    store.delete_logged_model_tag(model_id, tag_key)
    return _wrap_response(DeleteLoggedModelTag.Response())
2976
+
2977
+
2978
@catch_mlflow_exception
@_disable_if_artifacts_only
def _search_logged_models():
    """
    Search logged models across the requested experiments.

    Protobuf request fields are converted to plain Python values before reaching the store:
    - repeated containers become lists or lists of dicts;
    - empty strings and zero become None;
    - empty dataset / order-by lists become None.

    Returns:
        A wrapped ``SearchLoggedModels.Response`` with the matching models and, when the store's
        result carries one, a ``next_page_token``.
    """
    search_request = _get_request_message(
        SearchLoggedModels(),
        schema={
            "experiment_ids": [
                _assert_array,
                _assert_item_type_string,
                _assert_required,
            ],
            "filter": [_assert_string],
            "datasets": [_assert_array],
            "max_results": [_assert_intlike],
            "order_by": [_assert_array],
            "page_token": [_assert_string],
        },
    )

    dataset_filters = None
    if search_request.datasets:
        dataset_filters = [
            {
                "dataset_name": ds.dataset_name,
                "dataset_digest": ds.dataset_digest or None,
            }
            for ds in search_request.datasets
        ]

    order_by_clauses = None
    if search_request.order_by:
        order_by_clauses = [
            {
                "field_name": clause.field_name,
                "ascending": clause.ascending,
                "dataset_name": clause.dataset_name or None,
                "dataset_digest": clause.dataset_digest or None,
            }
            for clause in search_request.order_by
        ]

    # Plain lists/dicts rather than protobuf repeated containers avoid serialization issues
    # downstream in the store.
    models = _get_tracking_store().search_logged_models(
        experiment_ids=list(search_request.experiment_ids),
        filter_string=search_request.filter or None,
        datasets=dataset_filters,
        max_results=search_request.max_results or None,
        order_by=order_by_clauses,
        page_token=search_request.page_token or None,
    )

    search_response = SearchLoggedModels.Response()
    search_response.models.extend([found.to_proto() for found in models])
    if models.token:
        search_response.next_page_token = models.token
    return _wrap_response(search_response)
3033
+
3034
+
3035
@catch_mlflow_exception
@_disable_if_artifacts_only
def _list_logged_model_artifacts(model_id: str):
    """
    List the artifacts of the logged model identified by ``model_id``.

    If the request sets ``artifact_directory_path``, it is validated with
    ``validate_path_is_safe`` and the listing is scoped to it. Otherwise None is passed through.

    Returns:
        The response built by ``_list_logged_model_artifacts_impl``.
    """
    list_request = _get_request_message(
        ListLoggedModelArtifacts(),
        schema={"artifact_directory_path": [_assert_string]},
    )
    directory = (
        validate_path_is_safe(list_request.artifact_directory_path)
        if list_request.HasField("artifact_directory_path")
        else None
    )
    return _list_logged_model_artifacts_impl(model_id, directory)
3048
+
3049
+
3050
def _list_logged_model_artifacts_impl(
    model_id: str, artifact_directory_path: Optional[str]
) -> Response:
    """
    Build the ListLoggedModelArtifacts response for ``model_id``.

    When the model's artifact location is a servable proxied artifact root (as decided by
    ``_is_servable_proxied_run_artifact_root``), listing goes through the proxied-artifact helper.
    Otherwise the repository for the location itself is queried.

    Args:
        model_id: ID of the logged model.
        artifact_directory_path: Directory to list, relative to the model's artifact location,
            or None.

    Returns:
        A wrapped ``ListLoggedModelArtifacts.Response`` whose ``root_uri`` is the model's
        artifact location.
    """
    logged_model: LoggedModel = _get_tracking_store().get_logged_model(model_id)
    location = logged_model.artifact_location

    if _is_servable_proxied_run_artifact_root(location):
        artifacts = _list_artifacts_for_proxied_run_artifact_root(
            proxied_artifact_root=location,
            relative_path=artifact_directory_path,
        )
    else:
        artifacts = get_artifact_repository(location).list_artifacts(artifact_directory_path)

    listing = ListLoggedModelArtifacts.Response()
    listing.files.extend([artifact.to_proto() for artifact in artifacts])
    listing.root_uri = location
    return _wrap_response(listing)
3068
+
3069
+
3070
+ def _get_rest_path(base_path, version=2):
3071
+ return f"/api/{version}.0{base_path}"
3072
+
3073
+
3074
def _get_ajax_path(base_path, version=2):
    """Return the AJAX route for ``base_path``, with any configured static prefix prepended."""
    ajax_route = f"/ajax-api/{version}.0{base_path}"
    return _add_static_prefix(ajax_route)
3076
+
3077
+
3078
def _add_static_prefix(route: str) -> str:
    """
    Prepend the static prefix from the ``STATIC_PREFIX_ENV_VAR`` environment variable to ``route``.

    Any trailing slash on the prefix is stripped. When the variable is unset or empty, ``route``
    is returned unchanged.
    """
    prefix = os.environ.get(STATIC_PREFIX_ENV_VAR)
    if not prefix:
        return route
    return prefix.rstrip("/") + route
3082
+
3083
+
3084
def _get_paths(base_path, version=2):
    """
    Expand a service endpoint's base path into the Flask routes that should serve it.

    A base path such as ``/mlflow/experiment`` is registered under both the REST prefix
    (``/api/2.0/mlflow/experiment``) and the AJAX prefix (``/ajax-api/2.0/mlflow/experiment``).
    Any ``{param}`` placeholders are first rewritten into Flask's ``<param>`` syntax.

    Returns:
        ``[rest_path, ajax_path]``.
    """
    flask_path = _convert_path_parameter_to_flask_format(base_path)
    rest_path = _get_rest_path(flask_path, version)
    ajax_path = _get_ajax_path(flask_path, version)
    return [rest_path, ajax_path]
3092
+
3093
+
3094
+ def _convert_path_parameter_to_flask_format(path):
3095
+ """
3096
+ Converts path parameter format to Flask compatible format.
3097
+
3098
+ Some protobuf endpoint paths contain parameters like /mlflow/trace/{request_id}.
3099
+ This can be interpreted correctly by gRPC framework like Armeria, but Flask does
3100
+ not understand it. Instead, we need to specify it with a different format,
3101
+ like /mlflow/trace/<request_id>.
3102
+ """
3103
+ return re.sub(r"{(\w+)}", r"<\1>", path)
3104
+
3105
+
3106
def get_handler(request_class):
    """
    Look up the view function registered for a protobuf request class.

    Args:
        request_class: The type of protobuf message.

    Returns:
        The entry from ``HANDLERS``, or ``_not_implemented`` when the class has none.
    """
    try:
        return HANDLERS[request_class]
    except KeyError:
        return _not_implemented
3112
+
3113
+
3114
def get_service_endpoints(service, get_handler):
    """
    Build route entries for every RPC method of a protobuf ``service`` class.

    Each method's ``databricks_pb2.rpc`` option lists its HTTP endpoints. Every endpoint is
    expanded by ``_get_paths`` into its REST and AJAX variants.

    Args:
        service: A generated protobuf service class.
        get_handler: Callable mapping a request message class to its view function.

    Returns:
        List of ``(http_path, handler, [http_method])`` tuples.
    """
    service_instance = service()
    routes = []
    for service_method in service.DESCRIPTOR.methods:
        endpoints = service_method.GetOptions().Extensions[databricks_pb2.rpc].endpoints
        if not endpoints:
            continue
        # The handler depends only on the RPC method, not on the endpoint or path variant.
        # Resolve it once per method instead of once per generated path.
        handler = get_handler(service_instance.GetRequestClass(service_method))
        for endpoint in endpoints:
            for http_path in _get_paths(endpoint.path, version=endpoint.since.major):
                routes.append((http_path, handler, [endpoint.method]))
    return routes
3123
+
3124
+
3125
def get_endpoints(get_handler=get_handler):
    """
    Collect route entries for all MLflow services plus the GraphQL endpoint.

    Args:
        get_handler: Callable mapping a request message class to its view function. Defaults
            to the module-level ``get_handler``.

    Returns:
        List of tuples (path, handler, methods).
    """
    endpoints = []
    for service in (MlflowService, ModelRegistryService, MlflowArtifactsService):
        endpoints.extend(get_service_endpoints(service, get_handler))
    endpoints.append((_add_static_prefix("/graphql"), _graphql, ["GET", "POST"]))
    return endpoints
3136
+
3137
+
3138
# Dispatch table from protobuf request message classes to the view functions that serve them.
# `get_handler` looks request classes up here and falls back to `_not_implemented` for classes
# with no entry. Entries are grouped by API family; keep new entries in the matching section.
HANDLERS = {
    # Tracking Server APIs
    CreateExperiment: _create_experiment,
    GetExperiment: _get_experiment,
    GetExperimentByName: _get_experiment_by_name,
    DeleteExperiment: _delete_experiment,
    RestoreExperiment: _restore_experiment,
    UpdateExperiment: _update_experiment,
    CreateRun: _create_run,
    UpdateRun: _update_run,
    DeleteRun: _delete_run,
    RestoreRun: _restore_run,
    LogParam: _log_param,
    LogMetric: _log_metric,
    SetExperimentTag: _set_experiment_tag,
    SetTag: _set_tag,
    DeleteTag: _delete_tag,
    LogBatch: _log_batch,
    LogModel: _log_model,
    GetRun: _get_run,
    SearchRuns: _search_runs,
    ListArtifacts: _list_artifacts,
    GetMetricHistory: _get_metric_history,
    GetMetricHistoryBulkInterval: get_metric_history_bulk_interval_handler,
    SearchExperiments: _search_experiments,
    LogInputs: _log_inputs,
    LogOutputs: _log_outputs,
    # Model Registry APIs
    CreateRegisteredModel: _create_registered_model,
    GetRegisteredModel: _get_registered_model,
    DeleteRegisteredModel: _delete_registered_model,
    UpdateRegisteredModel: _update_registered_model,
    RenameRegisteredModel: _rename_registered_model,
    SearchRegisteredModels: _search_registered_models,
    GetLatestVersions: _get_latest_versions,
    CreateModelVersion: _create_model_version,
    GetModelVersion: _get_model_version,
    DeleteModelVersion: _delete_model_version,
    UpdateModelVersion: _update_model_version,
    TransitionModelVersionStage: _transition_stage,
    GetModelVersionDownloadUri: _get_model_version_download_uri,
    SearchModelVersions: _search_model_versions,
    SetRegisteredModelTag: _set_registered_model_tag,
    DeleteRegisteredModelTag: _delete_registered_model_tag,
    SetModelVersionTag: _set_model_version_tag,
    DeleteModelVersionTag: _delete_model_version_tag,
    SetRegisteredModelAlias: _set_registered_model_alias,
    DeleteRegisteredModelAlias: _delete_registered_model_alias,
    GetModelVersionByAlias: _get_model_version_by_alias,
    # MLflow Artifacts APIs
    DownloadArtifact: _download_artifact,
    UploadArtifact: _upload_artifact,
    ListArtifactsMlflowArtifacts: _list_artifacts_mlflow_artifacts,
    DeleteArtifact: _delete_artifact_mlflow_artifacts,
    CreateMultipartUpload: _create_multipart_upload_artifact,
    CompleteMultipartUpload: _complete_multipart_upload_artifact,
    AbortMultipartUpload: _abort_multipart_upload_artifact,
    # MLflow Tracing APIs (V3)
    StartTraceV3: _start_trace_v3,
    GetTraceInfoV3: _get_trace_info_v3,
    SearchTracesV3: _search_traces_v3,
    DeleteTraces: _delete_traces,
    SetTraceTag: _set_trace_tag,
    DeleteTraceTag: _delete_trace_tag,
    # Legacy MLflow Tracing V2 APIs. Kept for backward compatibility but do not use.
    StartTrace: _deprecated_start_trace_v2,
    EndTrace: _deprecated_end_trace_v2,
    GetTraceInfo: _deprecated_get_trace_info_v2,
    SearchTraces: _deprecated_search_traces_v2,
    # Logged Models APIs
    CreateLoggedModel: _create_logged_model,
    GetLoggedModel: _get_logged_model,
    FinalizeLoggedModel: _finalize_logged_model,
    DeleteLoggedModel: _delete_logged_model,
    SetLoggedModelTags: _set_logged_model_tags,
    DeleteLoggedModelTag: _delete_logged_model_tag,
    SearchLoggedModels: _search_logged_models,
    ListLoggedModelArtifacts: _list_logged_model_artifacts,
    LogLoggedModelParamsRequest: _log_logged_model_params,
}