genesis_flow-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (645)
  1. genesis_flow-1.0.0.dist-info/METADATA +822 -0
  2. genesis_flow-1.0.0.dist-info/RECORD +645 -0
  3. genesis_flow-1.0.0.dist-info/WHEEL +5 -0
  4. genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
  5. genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
  6. genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
  7. mlflow/__init__.py +367 -0
  8. mlflow/__main__.py +3 -0
  9. mlflow/ag2/__init__.py +56 -0
  10. mlflow/ag2/ag2_logger.py +294 -0
  11. mlflow/anthropic/__init__.py +40 -0
  12. mlflow/anthropic/autolog.py +129 -0
  13. mlflow/anthropic/chat.py +144 -0
  14. mlflow/artifacts/__init__.py +268 -0
  15. mlflow/autogen/__init__.py +144 -0
  16. mlflow/autogen/chat.py +142 -0
  17. mlflow/azure/__init__.py +26 -0
  18. mlflow/azure/auth_handler.py +257 -0
  19. mlflow/azure/client.py +319 -0
  20. mlflow/azure/config.py +120 -0
  21. mlflow/azure/connection_factory.py +340 -0
  22. mlflow/azure/exceptions.py +27 -0
  23. mlflow/azure/stores.py +327 -0
  24. mlflow/azure/utils.py +183 -0
  25. mlflow/bedrock/__init__.py +45 -0
  26. mlflow/bedrock/_autolog.py +202 -0
  27. mlflow/bedrock/chat.py +122 -0
  28. mlflow/bedrock/stream.py +160 -0
  29. mlflow/bedrock/utils.py +43 -0
  30. mlflow/cli.py +707 -0
  31. mlflow/client.py +12 -0
  32. mlflow/config/__init__.py +56 -0
  33. mlflow/crewai/__init__.py +79 -0
  34. mlflow/crewai/autolog.py +253 -0
  35. mlflow/crewai/chat.py +29 -0
  36. mlflow/data/__init__.py +75 -0
  37. mlflow/data/artifact_dataset_sources.py +170 -0
  38. mlflow/data/code_dataset_source.py +40 -0
  39. mlflow/data/dataset.py +123 -0
  40. mlflow/data/dataset_registry.py +168 -0
  41. mlflow/data/dataset_source.py +110 -0
  42. mlflow/data/dataset_source_registry.py +219 -0
  43. mlflow/data/delta_dataset_source.py +167 -0
  44. mlflow/data/digest_utils.py +108 -0
  45. mlflow/data/evaluation_dataset.py +562 -0
  46. mlflow/data/filesystem_dataset_source.py +81 -0
  47. mlflow/data/http_dataset_source.py +145 -0
  48. mlflow/data/huggingface_dataset.py +258 -0
  49. mlflow/data/huggingface_dataset_source.py +118 -0
  50. mlflow/data/meta_dataset.py +104 -0
  51. mlflow/data/numpy_dataset.py +223 -0
  52. mlflow/data/pandas_dataset.py +231 -0
  53. mlflow/data/polars_dataset.py +352 -0
  54. mlflow/data/pyfunc_dataset_mixin.py +31 -0
  55. mlflow/data/schema.py +76 -0
  56. mlflow/data/sources.py +1 -0
  57. mlflow/data/spark_dataset.py +406 -0
  58. mlflow/data/spark_dataset_source.py +74 -0
  59. mlflow/data/spark_delta_utils.py +118 -0
  60. mlflow/data/tensorflow_dataset.py +350 -0
  61. mlflow/data/uc_volume_dataset_source.py +81 -0
  62. mlflow/db.py +27 -0
  63. mlflow/dspy/__init__.py +17 -0
  64. mlflow/dspy/autolog.py +197 -0
  65. mlflow/dspy/callback.py +398 -0
  66. mlflow/dspy/constant.py +1 -0
  67. mlflow/dspy/load.py +93 -0
  68. mlflow/dspy/save.py +393 -0
  69. mlflow/dspy/util.py +109 -0
  70. mlflow/dspy/wrapper.py +226 -0
  71. mlflow/entities/__init__.py +104 -0
  72. mlflow/entities/_mlflow_object.py +52 -0
  73. mlflow/entities/assessment.py +545 -0
  74. mlflow/entities/assessment_error.py +80 -0
  75. mlflow/entities/assessment_source.py +141 -0
  76. mlflow/entities/dataset.py +92 -0
  77. mlflow/entities/dataset_input.py +51 -0
  78. mlflow/entities/dataset_summary.py +62 -0
  79. mlflow/entities/document.py +48 -0
  80. mlflow/entities/experiment.py +109 -0
  81. mlflow/entities/experiment_tag.py +35 -0
  82. mlflow/entities/file_info.py +45 -0
  83. mlflow/entities/input_tag.py +35 -0
  84. mlflow/entities/lifecycle_stage.py +35 -0
  85. mlflow/entities/logged_model.py +228 -0
  86. mlflow/entities/logged_model_input.py +26 -0
  87. mlflow/entities/logged_model_output.py +32 -0
  88. mlflow/entities/logged_model_parameter.py +46 -0
  89. mlflow/entities/logged_model_status.py +74 -0
  90. mlflow/entities/logged_model_tag.py +33 -0
  91. mlflow/entities/metric.py +200 -0
  92. mlflow/entities/model_registry/__init__.py +29 -0
  93. mlflow/entities/model_registry/_model_registry_entity.py +13 -0
  94. mlflow/entities/model_registry/model_version.py +243 -0
  95. mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
  96. mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
  97. mlflow/entities/model_registry/model_version_search.py +25 -0
  98. mlflow/entities/model_registry/model_version_stages.py +25 -0
  99. mlflow/entities/model_registry/model_version_status.py +35 -0
  100. mlflow/entities/model_registry/model_version_tag.py +35 -0
  101. mlflow/entities/model_registry/prompt.py +73 -0
  102. mlflow/entities/model_registry/prompt_version.py +244 -0
  103. mlflow/entities/model_registry/registered_model.py +175 -0
  104. mlflow/entities/model_registry/registered_model_alias.py +35 -0
  105. mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
  106. mlflow/entities/model_registry/registered_model_search.py +25 -0
  107. mlflow/entities/model_registry/registered_model_tag.py +35 -0
  108. mlflow/entities/multipart_upload.py +74 -0
  109. mlflow/entities/param.py +49 -0
  110. mlflow/entities/run.py +97 -0
  111. mlflow/entities/run_data.py +84 -0
  112. mlflow/entities/run_info.py +188 -0
  113. mlflow/entities/run_inputs.py +59 -0
  114. mlflow/entities/run_outputs.py +43 -0
  115. mlflow/entities/run_status.py +41 -0
  116. mlflow/entities/run_tag.py +36 -0
  117. mlflow/entities/source_type.py +31 -0
  118. mlflow/entities/span.py +774 -0
  119. mlflow/entities/span_event.py +96 -0
  120. mlflow/entities/span_status.py +102 -0
  121. mlflow/entities/trace.py +317 -0
  122. mlflow/entities/trace_data.py +71 -0
  123. mlflow/entities/trace_info.py +220 -0
  124. mlflow/entities/trace_info_v2.py +162 -0
  125. mlflow/entities/trace_location.py +173 -0
  126. mlflow/entities/trace_state.py +39 -0
  127. mlflow/entities/trace_status.py +68 -0
  128. mlflow/entities/view_type.py +51 -0
  129. mlflow/environment_variables.py +866 -0
  130. mlflow/evaluation/__init__.py +16 -0
  131. mlflow/evaluation/assessment.py +369 -0
  132. mlflow/evaluation/evaluation.py +411 -0
  133. mlflow/evaluation/evaluation_tag.py +61 -0
  134. mlflow/evaluation/fluent.py +48 -0
  135. mlflow/evaluation/utils.py +201 -0
  136. mlflow/exceptions.py +213 -0
  137. mlflow/experiments.py +140 -0
  138. mlflow/gemini/__init__.py +81 -0
  139. mlflow/gemini/autolog.py +186 -0
  140. mlflow/gemini/chat.py +261 -0
  141. mlflow/genai/__init__.py +71 -0
  142. mlflow/genai/datasets/__init__.py +67 -0
  143. mlflow/genai/datasets/evaluation_dataset.py +131 -0
  144. mlflow/genai/evaluation/__init__.py +3 -0
  145. mlflow/genai/evaluation/base.py +411 -0
  146. mlflow/genai/evaluation/constant.py +23 -0
  147. mlflow/genai/evaluation/utils.py +244 -0
  148. mlflow/genai/judges/__init__.py +21 -0
  149. mlflow/genai/judges/databricks.py +404 -0
  150. mlflow/genai/label_schemas/__init__.py +153 -0
  151. mlflow/genai/label_schemas/label_schemas.py +209 -0
  152. mlflow/genai/labeling/__init__.py +159 -0
  153. mlflow/genai/labeling/labeling.py +250 -0
  154. mlflow/genai/optimize/__init__.py +13 -0
  155. mlflow/genai/optimize/base.py +198 -0
  156. mlflow/genai/optimize/optimizers/__init__.py +4 -0
  157. mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
  158. mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
  159. mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
  160. mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
  161. mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
  162. mlflow/genai/optimize/types.py +75 -0
  163. mlflow/genai/optimize/util.py +30 -0
  164. mlflow/genai/prompts/__init__.py +206 -0
  165. mlflow/genai/scheduled_scorers.py +431 -0
  166. mlflow/genai/scorers/__init__.py +26 -0
  167. mlflow/genai/scorers/base.py +492 -0
  168. mlflow/genai/scorers/builtin_scorers.py +765 -0
  169. mlflow/genai/scorers/scorer_utils.py +138 -0
  170. mlflow/genai/scorers/validation.py +165 -0
  171. mlflow/genai/utils/data_validation.py +146 -0
  172. mlflow/genai/utils/enum_utils.py +23 -0
  173. mlflow/genai/utils/trace_utils.py +211 -0
  174. mlflow/groq/__init__.py +42 -0
  175. mlflow/groq/_groq_autolog.py +74 -0
  176. mlflow/johnsnowlabs/__init__.py +888 -0
  177. mlflow/langchain/__init__.py +24 -0
  178. mlflow/langchain/api_request_parallel_processor.py +330 -0
  179. mlflow/langchain/autolog.py +147 -0
  180. mlflow/langchain/chat_agent_langgraph.py +340 -0
  181. mlflow/langchain/constant.py +1 -0
  182. mlflow/langchain/constants.py +1 -0
  183. mlflow/langchain/databricks_dependencies.py +444 -0
  184. mlflow/langchain/langchain_tracer.py +597 -0
  185. mlflow/langchain/model.py +919 -0
  186. mlflow/langchain/output_parsers.py +142 -0
  187. mlflow/langchain/retriever_chain.py +153 -0
  188. mlflow/langchain/runnables.py +527 -0
  189. mlflow/langchain/utils/chat.py +402 -0
  190. mlflow/langchain/utils/logging.py +671 -0
  191. mlflow/langchain/utils/serialization.py +36 -0
  192. mlflow/legacy_databricks_cli/__init__.py +0 -0
  193. mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
  194. mlflow/legacy_databricks_cli/configure/provider.py +482 -0
  195. mlflow/litellm/__init__.py +175 -0
  196. mlflow/llama_index/__init__.py +22 -0
  197. mlflow/llama_index/autolog.py +55 -0
  198. mlflow/llama_index/chat.py +43 -0
  199. mlflow/llama_index/constant.py +1 -0
  200. mlflow/llama_index/model.py +577 -0
  201. mlflow/llama_index/pyfunc_wrapper.py +332 -0
  202. mlflow/llama_index/serialize_objects.py +188 -0
  203. mlflow/llama_index/tracer.py +561 -0
  204. mlflow/metrics/__init__.py +479 -0
  205. mlflow/metrics/base.py +39 -0
  206. mlflow/metrics/genai/__init__.py +25 -0
  207. mlflow/metrics/genai/base.py +101 -0
  208. mlflow/metrics/genai/genai_metric.py +771 -0
  209. mlflow/metrics/genai/metric_definitions.py +450 -0
  210. mlflow/metrics/genai/model_utils.py +371 -0
  211. mlflow/metrics/genai/prompt_template.py +68 -0
  212. mlflow/metrics/genai/prompts/__init__.py +0 -0
  213. mlflow/metrics/genai/prompts/v1.py +422 -0
  214. mlflow/metrics/genai/utils.py +6 -0
  215. mlflow/metrics/metric_definitions.py +619 -0
  216. mlflow/mismatch.py +34 -0
  217. mlflow/mistral/__init__.py +34 -0
  218. mlflow/mistral/autolog.py +71 -0
  219. mlflow/mistral/chat.py +135 -0
  220. mlflow/ml_package_versions.py +452 -0
  221. mlflow/models/__init__.py +97 -0
  222. mlflow/models/auth_policy.py +83 -0
  223. mlflow/models/cli.py +354 -0
  224. mlflow/models/container/__init__.py +294 -0
  225. mlflow/models/container/scoring_server/__init__.py +0 -0
  226. mlflow/models/container/scoring_server/nginx.conf +39 -0
  227. mlflow/models/dependencies_schemas.py +287 -0
  228. mlflow/models/display_utils.py +158 -0
  229. mlflow/models/docker_utils.py +211 -0
  230. mlflow/models/evaluation/__init__.py +23 -0
  231. mlflow/models/evaluation/_shap_patch.py +64 -0
  232. mlflow/models/evaluation/artifacts.py +194 -0
  233. mlflow/models/evaluation/base.py +1811 -0
  234. mlflow/models/evaluation/calibration_curve.py +109 -0
  235. mlflow/models/evaluation/default_evaluator.py +996 -0
  236. mlflow/models/evaluation/deprecated.py +23 -0
  237. mlflow/models/evaluation/evaluator_registry.py +80 -0
  238. mlflow/models/evaluation/evaluators/classifier.py +704 -0
  239. mlflow/models/evaluation/evaluators/default.py +233 -0
  240. mlflow/models/evaluation/evaluators/regressor.py +96 -0
  241. mlflow/models/evaluation/evaluators/shap.py +296 -0
  242. mlflow/models/evaluation/lift_curve.py +178 -0
  243. mlflow/models/evaluation/utils/metric.py +123 -0
  244. mlflow/models/evaluation/utils/trace.py +179 -0
  245. mlflow/models/evaluation/validation.py +434 -0
  246. mlflow/models/flavor_backend.py +93 -0
  247. mlflow/models/flavor_backend_registry.py +53 -0
  248. mlflow/models/model.py +1639 -0
  249. mlflow/models/model_config.py +150 -0
  250. mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
  251. mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
  252. mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
  253. mlflow/models/python_api.py +369 -0
  254. mlflow/models/rag_signatures.py +128 -0
  255. mlflow/models/resources.py +321 -0
  256. mlflow/models/signature.py +662 -0
  257. mlflow/models/utils.py +2054 -0
  258. mlflow/models/wheeled_model.py +280 -0
  259. mlflow/openai/__init__.py +57 -0
  260. mlflow/openai/_agent_tracer.py +364 -0
  261. mlflow/openai/api_request_parallel_processor.py +131 -0
  262. mlflow/openai/autolog.py +509 -0
  263. mlflow/openai/constant.py +1 -0
  264. mlflow/openai/model.py +824 -0
  265. mlflow/openai/utils/chat_schema.py +367 -0
  266. mlflow/optuna/__init__.py +3 -0
  267. mlflow/optuna/storage.py +646 -0
  268. mlflow/plugins/__init__.py +72 -0
  269. mlflow/plugins/base.py +358 -0
  270. mlflow/plugins/builtin/__init__.py +24 -0
  271. mlflow/plugins/builtin/pytorch_plugin.py +150 -0
  272. mlflow/plugins/builtin/sklearn_plugin.py +158 -0
  273. mlflow/plugins/builtin/transformers_plugin.py +187 -0
  274. mlflow/plugins/cli.py +321 -0
  275. mlflow/plugins/discovery.py +340 -0
  276. mlflow/plugins/manager.py +465 -0
  277. mlflow/plugins/registry.py +316 -0
  278. mlflow/plugins/templates/framework_plugin_template.py +329 -0
  279. mlflow/prompt/constants.py +20 -0
  280. mlflow/prompt/promptlab_model.py +197 -0
  281. mlflow/prompt/registry_utils.py +248 -0
  282. mlflow/promptflow/__init__.py +495 -0
  283. mlflow/protos/__init__.py +0 -0
  284. mlflow/protos/assessments_pb2.py +174 -0
  285. mlflow/protos/databricks_artifacts_pb2.py +489 -0
  286. mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
  287. mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
  288. mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
  289. mlflow/protos/databricks_pb2.py +267 -0
  290. mlflow/protos/databricks_trace_server_pb2.py +374 -0
  291. mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
  292. mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
  293. mlflow/protos/facet_feature_statistics_pb2.py +296 -0
  294. mlflow/protos/internal_pb2.py +77 -0
  295. mlflow/protos/mlflow_artifacts_pb2.py +336 -0
  296. mlflow/protos/model_registry_pb2.py +1073 -0
  297. mlflow/protos/scalapb/__init__.py +0 -0
  298. mlflow/protos/scalapb/scalapb_pb2.py +104 -0
  299. mlflow/protos/service_pb2.py +2600 -0
  300. mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
  301. mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
  302. mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
  303. mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
  304. mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
  305. mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
  306. mlflow/py.typed +0 -0
  307. mlflow/pydantic_ai/__init__.py +57 -0
  308. mlflow/pydantic_ai/autolog.py +173 -0
  309. mlflow/pyfunc/__init__.py +3844 -0
  310. mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
  311. mlflow/pyfunc/backend.py +523 -0
  312. mlflow/pyfunc/context.py +78 -0
  313. mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
  314. mlflow/pyfunc/loaders/__init__.py +7 -0
  315. mlflow/pyfunc/loaders/chat_agent.py +117 -0
  316. mlflow/pyfunc/loaders/chat_model.py +125 -0
  317. mlflow/pyfunc/loaders/code_model.py +31 -0
  318. mlflow/pyfunc/loaders/responses_agent.py +112 -0
  319. mlflow/pyfunc/mlserver.py +46 -0
  320. mlflow/pyfunc/model.py +1473 -0
  321. mlflow/pyfunc/scoring_server/__init__.py +604 -0
  322. mlflow/pyfunc/scoring_server/app.py +7 -0
  323. mlflow/pyfunc/scoring_server/client.py +146 -0
  324. mlflow/pyfunc/spark_model_cache.py +48 -0
  325. mlflow/pyfunc/stdin_server.py +44 -0
  326. mlflow/pyfunc/utils/__init__.py +3 -0
  327. mlflow/pyfunc/utils/data_validation.py +224 -0
  328. mlflow/pyfunc/utils/environment.py +22 -0
  329. mlflow/pyfunc/utils/input_converter.py +47 -0
  330. mlflow/pyfunc/utils/serving_data_parser.py +11 -0
  331. mlflow/pytorch/__init__.py +1171 -0
  332. mlflow/pytorch/_lightning_autolog.py +580 -0
  333. mlflow/pytorch/_pytorch_autolog.py +50 -0
  334. mlflow/pytorch/pickle_module.py +35 -0
  335. mlflow/rfunc/__init__.py +42 -0
  336. mlflow/rfunc/backend.py +134 -0
  337. mlflow/runs.py +89 -0
  338. mlflow/server/__init__.py +302 -0
  339. mlflow/server/auth/__init__.py +1224 -0
  340. mlflow/server/auth/__main__.py +4 -0
  341. mlflow/server/auth/basic_auth.ini +6 -0
  342. mlflow/server/auth/cli.py +11 -0
  343. mlflow/server/auth/client.py +537 -0
  344. mlflow/server/auth/config.py +34 -0
  345. mlflow/server/auth/db/__init__.py +0 -0
  346. mlflow/server/auth/db/cli.py +18 -0
  347. mlflow/server/auth/db/migrations/__init__.py +0 -0
  348. mlflow/server/auth/db/migrations/alembic.ini +110 -0
  349. mlflow/server/auth/db/migrations/env.py +76 -0
  350. mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
  351. mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
  352. mlflow/server/auth/db/models.py +67 -0
  353. mlflow/server/auth/db/utils.py +37 -0
  354. mlflow/server/auth/entities.py +165 -0
  355. mlflow/server/auth/logo.py +14 -0
  356. mlflow/server/auth/permissions.py +65 -0
  357. mlflow/server/auth/routes.py +18 -0
  358. mlflow/server/auth/sqlalchemy_store.py +263 -0
  359. mlflow/server/graphql/__init__.py +0 -0
  360. mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
  361. mlflow/server/graphql/graphql_custom_scalars.py +24 -0
  362. mlflow/server/graphql/graphql_errors.py +15 -0
  363. mlflow/server/graphql/graphql_no_batching.py +89 -0
  364. mlflow/server/graphql/graphql_schema_extensions.py +74 -0
  365. mlflow/server/handlers.py +3217 -0
  366. mlflow/server/prometheus_exporter.py +17 -0
  367. mlflow/server/validation.py +30 -0
  368. mlflow/shap/__init__.py +691 -0
  369. mlflow/sklearn/__init__.py +1994 -0
  370. mlflow/sklearn/utils.py +1041 -0
  371. mlflow/smolagents/__init__.py +66 -0
  372. mlflow/smolagents/autolog.py +139 -0
  373. mlflow/smolagents/chat.py +29 -0
  374. mlflow/store/__init__.py +10 -0
  375. mlflow/store/_unity_catalog/__init__.py +1 -0
  376. mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
  377. mlflow/store/_unity_catalog/lineage/constants.py +2 -0
  378. mlflow/store/_unity_catalog/registry/__init__.py +6 -0
  379. mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
  380. mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
  381. mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
  382. mlflow/store/_unity_catalog/registry/utils.py +121 -0
  383. mlflow/store/artifact/__init__.py +0 -0
  384. mlflow/store/artifact/artifact_repo.py +472 -0
  385. mlflow/store/artifact/artifact_repository_registry.py +154 -0
  386. mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
  387. mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
  388. mlflow/store/artifact/cli.py +141 -0
  389. mlflow/store/artifact/cloud_artifact_repo.py +332 -0
  390. mlflow/store/artifact/databricks_artifact_repo.py +729 -0
  391. mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
  392. mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
  393. mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
  394. mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
  395. mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
  396. mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
  397. mlflow/store/artifact/ftp_artifact_repo.py +132 -0
  398. mlflow/store/artifact/gcs_artifact_repo.py +296 -0
  399. mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
  400. mlflow/store/artifact/http_artifact_repo.py +218 -0
  401. mlflow/store/artifact/local_artifact_repo.py +142 -0
  402. mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
  403. mlflow/store/artifact/models_artifact_repo.py +259 -0
  404. mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
  405. mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
  406. mlflow/store/artifact/r2_artifact_repo.py +70 -0
  407. mlflow/store/artifact/runs_artifact_repo.py +265 -0
  408. mlflow/store/artifact/s3_artifact_repo.py +330 -0
  409. mlflow/store/artifact/sftp_artifact_repo.py +141 -0
  410. mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
  411. mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
  412. mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
  413. mlflow/store/artifact/utils/__init__.py +0 -0
  414. mlflow/store/artifact/utils/models.py +148 -0
  415. mlflow/store/db/__init__.py +0 -0
  416. mlflow/store/db/base_sql_model.py +3 -0
  417. mlflow/store/db/db_types.py +10 -0
  418. mlflow/store/db/utils.py +314 -0
  419. mlflow/store/db_migrations/__init__.py +0 -0
  420. mlflow/store/db_migrations/alembic.ini +74 -0
  421. mlflow/store/db_migrations/env.py +84 -0
  422. mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
  423. mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
  424. mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
  425. mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
  426. mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
  427. mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
  428. mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
  429. mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
  430. mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
  431. mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
  432. mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
  433. mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
  434. mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
  435. mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
  436. mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
  437. mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
  438. mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
  439. mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
  440. mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
  441. mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
  442. mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
  443. mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
  444. mlflow/store/db_migrations/versions/__init__.py +0 -0
  445. mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
  446. mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
  447. mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
  448. mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
  449. mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
  450. mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
  451. mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
  452. mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
  453. mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
  454. mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
  455. mlflow/store/entities/__init__.py +3 -0
  456. mlflow/store/entities/paged_list.py +18 -0
  457. mlflow/store/model_registry/__init__.py +10 -0
  458. mlflow/store/model_registry/abstract_store.py +1081 -0
  459. mlflow/store/model_registry/base_rest_store.py +44 -0
  460. mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
  461. mlflow/store/model_registry/dbmodels/__init__.py +0 -0
  462. mlflow/store/model_registry/dbmodels/models.py +206 -0
  463. mlflow/store/model_registry/file_store.py +1091 -0
  464. mlflow/store/model_registry/rest_store.py +481 -0
  465. mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
  466. mlflow/store/tracking/__init__.py +23 -0
  467. mlflow/store/tracking/abstract_store.py +816 -0
  468. mlflow/store/tracking/dbmodels/__init__.py +0 -0
  469. mlflow/store/tracking/dbmodels/initial_models.py +243 -0
  470. mlflow/store/tracking/dbmodels/models.py +1073 -0
  471. mlflow/store/tracking/file_store.py +2438 -0
  472. mlflow/store/tracking/postgres_managed_identity.py +146 -0
  473. mlflow/store/tracking/rest_store.py +1131 -0
  474. mlflow/store/tracking/sqlalchemy_store.py +2785 -0
  475. mlflow/system_metrics/__init__.py +61 -0
  476. mlflow/system_metrics/metrics/__init__.py +0 -0
  477. mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
  478. mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
  479. mlflow/system_metrics/metrics/disk_monitor.py +21 -0
  480. mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
  481. mlflow/system_metrics/metrics/network_monitor.py +34 -0
  482. mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
  483. mlflow/system_metrics/system_metrics_monitor.py +198 -0
  484. mlflow/tracing/__init__.py +16 -0
  485. mlflow/tracing/assessment.py +356 -0
  486. mlflow/tracing/client.py +531 -0
  487. mlflow/tracing/config.py +125 -0
  488. mlflow/tracing/constant.py +105 -0
  489. mlflow/tracing/destination.py +81 -0
  490. mlflow/tracing/display/__init__.py +40 -0
  491. mlflow/tracing/display/display_handler.py +196 -0
  492. mlflow/tracing/export/async_export_queue.py +186 -0
  493. mlflow/tracing/export/inference_table.py +138 -0
  494. mlflow/tracing/export/mlflow_v3.py +137 -0
  495. mlflow/tracing/export/utils.py +70 -0
  496. mlflow/tracing/fluent.py +1417 -0
  497. mlflow/tracing/processor/base_mlflow.py +199 -0
  498. mlflow/tracing/processor/inference_table.py +175 -0
  499. mlflow/tracing/processor/mlflow_v3.py +47 -0
  500. mlflow/tracing/processor/otel.py +73 -0
  501. mlflow/tracing/provider.py +487 -0
  502. mlflow/tracing/trace_manager.py +200 -0
  503. mlflow/tracing/utils/__init__.py +616 -0
  504. mlflow/tracing/utils/artifact_utils.py +28 -0
  505. mlflow/tracing/utils/copy.py +55 -0
  506. mlflow/tracing/utils/environment.py +55 -0
  507. mlflow/tracing/utils/exception.py +21 -0
  508. mlflow/tracing/utils/once.py +35 -0
  509. mlflow/tracing/utils/otlp.py +63 -0
  510. mlflow/tracing/utils/processor.py +54 -0
  511. mlflow/tracing/utils/search.py +292 -0
  512. mlflow/tracing/utils/timeout.py +250 -0
  513. mlflow/tracing/utils/token.py +19 -0
  514. mlflow/tracing/utils/truncation.py +124 -0
  515. mlflow/tracing/utils/warning.py +76 -0
  516. mlflow/tracking/__init__.py +39 -0
  517. mlflow/tracking/_model_registry/__init__.py +1 -0
  518. mlflow/tracking/_model_registry/client.py +764 -0
  519. mlflow/tracking/_model_registry/fluent.py +853 -0
  520. mlflow/tracking/_model_registry/registry.py +67 -0
  521. mlflow/tracking/_model_registry/utils.py +251 -0
  522. mlflow/tracking/_tracking_service/__init__.py +0 -0
  523. mlflow/tracking/_tracking_service/client.py +883 -0
  524. mlflow/tracking/_tracking_service/registry.py +56 -0
  525. mlflow/tracking/_tracking_service/utils.py +275 -0
  526. mlflow/tracking/artifact_utils.py +179 -0
  527. mlflow/tracking/client.py +5900 -0
  528. mlflow/tracking/context/__init__.py +0 -0
  529. mlflow/tracking/context/abstract_context.py +35 -0
  530. mlflow/tracking/context/databricks_cluster_context.py +15 -0
  531. mlflow/tracking/context/databricks_command_context.py +15 -0
  532. mlflow/tracking/context/databricks_job_context.py +49 -0
  533. mlflow/tracking/context/databricks_notebook_context.py +41 -0
  534. mlflow/tracking/context/databricks_repo_context.py +43 -0
  535. mlflow/tracking/context/default_context.py +51 -0
  536. mlflow/tracking/context/git_context.py +32 -0
  537. mlflow/tracking/context/registry.py +98 -0
  538. mlflow/tracking/context/system_environment_context.py +15 -0
  539. mlflow/tracking/default_experiment/__init__.py +1 -0
  540. mlflow/tracking/default_experiment/abstract_context.py +43 -0
  541. mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
  542. mlflow/tracking/default_experiment/registry.py +75 -0
  543. mlflow/tracking/fluent.py +3595 -0
  544. mlflow/tracking/metric_value_conversion_utils.py +93 -0
  545. mlflow/tracking/multimedia.py +206 -0
  546. mlflow/tracking/registry.py +86 -0
  547. mlflow/tracking/request_auth/__init__.py +0 -0
  548. mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
  549. mlflow/tracking/request_auth/registry.py +60 -0
  550. mlflow/tracking/request_header/__init__.py +0 -0
  551. mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
  552. mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
  553. mlflow/tracking/request_header/default_request_header_provider.py +17 -0
  554. mlflow/tracking/request_header/registry.py +79 -0
  555. mlflow/transformers/__init__.py +2982 -0
  556. mlflow/transformers/flavor_config.py +258 -0
  557. mlflow/transformers/hub_utils.py +83 -0
  558. mlflow/transformers/llm_inference_utils.py +468 -0
  559. mlflow/transformers/model_io.py +301 -0
  560. mlflow/transformers/peft.py +51 -0
  561. mlflow/transformers/signature.py +183 -0
  562. mlflow/transformers/torch_utils.py +55 -0
  563. mlflow/types/__init__.py +21 -0
  564. mlflow/types/agent.py +270 -0
  565. mlflow/types/chat.py +240 -0
  566. mlflow/types/llm.py +935 -0
  567. mlflow/types/responses.py +139 -0
  568. mlflow/types/responses_helpers.py +416 -0
  569. mlflow/types/schema.py +1505 -0
  570. mlflow/types/type_hints.py +647 -0
  571. mlflow/types/utils.py +753 -0
  572. mlflow/utils/__init__.py +283 -0
  573. mlflow/utils/_capture_modules.py +256 -0
  574. mlflow/utils/_capture_transformers_modules.py +75 -0
  575. mlflow/utils/_spark_utils.py +201 -0
  576. mlflow/utils/_unity_catalog_oss_utils.py +97 -0
  577. mlflow/utils/_unity_catalog_utils.py +479 -0
  578. mlflow/utils/annotations.py +218 -0
  579. mlflow/utils/arguments_utils.py +16 -0
  580. mlflow/utils/async_logging/__init__.py +1 -0
  581. mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
  582. mlflow/utils/async_logging/async_logging_queue.py +366 -0
  583. mlflow/utils/async_logging/run_artifact.py +38 -0
  584. mlflow/utils/async_logging/run_batch.py +58 -0
  585. mlflow/utils/async_logging/run_operations.py +49 -0
  586. mlflow/utils/autologging_utils/__init__.py +737 -0
  587. mlflow/utils/autologging_utils/client.py +432 -0
  588. mlflow/utils/autologging_utils/config.py +33 -0
  589. mlflow/utils/autologging_utils/events.py +294 -0
  590. mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
  591. mlflow/utils/autologging_utils/metrics_queue.py +71 -0
  592. mlflow/utils/autologging_utils/safety.py +1104 -0
  593. mlflow/utils/autologging_utils/versioning.py +95 -0
  594. mlflow/utils/checkpoint_utils.py +206 -0
  595. mlflow/utils/class_utils.py +6 -0
  596. mlflow/utils/cli_args.py +257 -0
  597. mlflow/utils/conda.py +354 -0
  598. mlflow/utils/credentials.py +231 -0
  599. mlflow/utils/data_utils.py +17 -0
  600. mlflow/utils/databricks_utils.py +1436 -0
  601. mlflow/utils/docstring_utils.py +477 -0
  602. mlflow/utils/doctor.py +133 -0
  603. mlflow/utils/download_cloud_file_chunk.py +43 -0
  604. mlflow/utils/env_manager.py +16 -0
  605. mlflow/utils/env_pack.py +131 -0
  606. mlflow/utils/environment.py +1009 -0
  607. mlflow/utils/exception_utils.py +14 -0
  608. mlflow/utils/file_utils.py +978 -0
  609. mlflow/utils/git_utils.py +77 -0
  610. mlflow/utils/gorilla.py +797 -0
  611. mlflow/utils/import_hooks/__init__.py +363 -0
  612. mlflow/utils/lazy_load.py +51 -0
  613. mlflow/utils/logging_utils.py +168 -0
  614. mlflow/utils/mime_type_utils.py +58 -0
  615. mlflow/utils/mlflow_tags.py +103 -0
  616. mlflow/utils/model_utils.py +486 -0
  617. mlflow/utils/name_utils.py +346 -0
  618. mlflow/utils/nfs_on_spark.py +62 -0
  619. mlflow/utils/openai_utils.py +164 -0
  620. mlflow/utils/os.py +12 -0
  621. mlflow/utils/oss_registry_utils.py +29 -0
  622. mlflow/utils/plugins.py +17 -0
  623. mlflow/utils/process.py +182 -0
  624. mlflow/utils/promptlab_utils.py +146 -0
  625. mlflow/utils/proto_json_utils.py +743 -0
  626. mlflow/utils/pydantic_utils.py +54 -0
  627. mlflow/utils/request_utils.py +279 -0
  628. mlflow/utils/requirements_utils.py +704 -0
  629. mlflow/utils/rest_utils.py +673 -0
  630. mlflow/utils/search_logged_model_utils.py +127 -0
  631. mlflow/utils/search_utils.py +2111 -0
  632. mlflow/utils/secure_loading.py +221 -0
  633. mlflow/utils/security_validation.py +384 -0
  634. mlflow/utils/server_cli_utils.py +61 -0
  635. mlflow/utils/spark_utils.py +15 -0
  636. mlflow/utils/string_utils.py +138 -0
  637. mlflow/utils/thread_utils.py +63 -0
  638. mlflow/utils/time.py +54 -0
  639. mlflow/utils/timeout.py +42 -0
  640. mlflow/utils/uri.py +572 -0
  641. mlflow/utils/validation.py +662 -0
  642. mlflow/utils/virtualenv.py +458 -0
  643. mlflow/utils/warnings_utils.py +25 -0
  644. mlflow/utils/yaml_utils.py +179 -0
  645. mlflow/version.py +24 -0
mlflow/store/tracking/sqlalchemy_store.py
@@ -0,0 +1,2785 @@
+ import json
+ import logging
+ import math
+ import random
+ import threading
+ import time
+ import uuid
+ from collections import defaultdict
+ from functools import reduce
+ from typing import Any, Optional, TypedDict
+
+ import sqlalchemy
+ import sqlalchemy.orm
+ import sqlalchemy.sql.expression as sql
+ from sqlalchemy import and_, func, sql, text
+ from sqlalchemy.future import select
+
+ import mlflow.store.db.utils
+ from mlflow.entities import (
+     DatasetInput,
+     Experiment,
+     Run,
+     RunInputs,
+     RunOutputs,
+     RunStatus,
+     RunTag,
+     SourceType,
+     TraceInfo,
+     ViewType,
+     _DatasetSummary,
+ )
+ from mlflow.entities.lifecycle_stage import LifecycleStage
+ from mlflow.entities.logged_model import LoggedModel
+ from mlflow.entities.logged_model_input import LoggedModelInput
+ from mlflow.entities.logged_model_output import LoggedModelOutput
+ from mlflow.entities.logged_model_parameter import LoggedModelParameter
+ from mlflow.entities.logged_model_status import LoggedModelStatus
+ from mlflow.entities.logged_model_tag import LoggedModelTag
+ from mlflow.entities.metric import Metric, MetricWithRunId
+ from mlflow.entities.trace_info_v2 import TraceInfoV2
+ from mlflow.entities.trace_status import TraceStatus
+ from mlflow.exceptions import MlflowException
+ from mlflow.protos.databricks_pb2 import (
+     INTERNAL_ERROR,
+     INVALID_PARAMETER_VALUE,
+     INVALID_STATE,
+     RESOURCE_ALREADY_EXISTS,
+     RESOURCE_DOES_NOT_EXIST,
+ )
+ from mlflow.store.db.db_types import MSSQL, MYSQL
+ from mlflow.store.entities.paged_list import PagedList
+ from mlflow.store.tracking import (
+     SEARCH_LOGGED_MODEL_MAX_RESULTS_DEFAULT,
+     SEARCH_MAX_RESULTS_DEFAULT,
+     SEARCH_MAX_RESULTS_THRESHOLD,
+     SEARCH_TRACES_DEFAULT_MAX_RESULTS,
+ )
+ from mlflow.store.tracking.abstract_store import AbstractStore
+ from mlflow.store.tracking.dbmodels.models import (
+     SqlDataset,
+     SqlExperiment,
+     SqlExperimentTag,
+     SqlInput,
+     SqlInputTag,
+     SqlLatestMetric,
+     SqlLoggedModel,
+     SqlLoggedModelMetric,
+     SqlLoggedModelParam,
+     SqlLoggedModelTag,
+     SqlMetric,
+     SqlParam,
+     SqlRun,
+     SqlTag,
+     SqlTraceInfo,
+     SqlTraceMetadata,
+     SqlTraceTag,
+ )
+ from mlflow.tracing.utils import generate_request_id_v2
+ from mlflow.tracking.fluent import _get_experiment_id
+ from mlflow.utils.file_utils import local_file_uri_to_path, mkdir
+ from mlflow.utils.mlflow_tags import (
+     MLFLOW_ARTIFACT_LOCATION,
+     MLFLOW_DATASET_CONTEXT,
+     MLFLOW_LOGGED_MODELS,
+     MLFLOW_RUN_NAME,
+     _get_run_name_from_tags,
+ )
+ from mlflow.utils.name_utils import _generate_random_name
+ from mlflow.utils.search_utils import (
+     SearchExperimentsUtils,
+     SearchLoggedModelsPaginationToken,
+     SearchTraceUtils,
+     SearchUtils,
+ )
+ from mlflow.utils.string_utils import is_string_type
+ from mlflow.utils.time import get_current_time_millis
+ from mlflow.utils.uri import (
+     append_to_uri_path,
+     extract_db_type_from_uri,
+     is_local_uri,
+     resolve_uri_if_local,
+ )
+ from mlflow.utils.validation import (
+     _validate_batch_log_data,
+     _validate_batch_log_limits,
+     _validate_dataset_inputs,
+     _validate_experiment_artifact_location_length,
+     _validate_experiment_name,
+     _validate_experiment_tag,
+     _validate_logged_model_name,
+     _validate_metric,
+     _validate_param,
+     _validate_param_keys_unique,
+     _validate_run_id,
+     _validate_tag,
+     _validate_trace_tag,
+ )
+
+ _logger = logging.getLogger(__name__)
+
+ # For each database table, fetch its columns and define an appropriate attribute for each column
+ # on the table's associated object representation (Mapper). This is necessary to ensure that
+ # columns defined via backreference are available as Mapper instance attributes (e.g.,
+ # ``SqlExperiment.tags`` and ``SqlRun.params``). For more information, see
+ # https://docs.sqlalchemy.org/en/latest/orm/mapping_api.html#sqlalchemy.orm.configure_mappers
+ # and https://docs.sqlalchemy.org/en/latest/orm/mapping_api.html#sqlalchemy.orm.mapper.Mapper
+ sqlalchemy.orm.configure_mappers()
+
+
+ class DatasetFilter(TypedDict, total=False):
+     """
+     Dataset filter used for search_logged_models.
+     """
+
+     dataset_name: str
+     dataset_digest: str
+
+
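+ # A DatasetFilter is a plain dict. A minimal sketch of the shape consumed by
+ # search_logged_models (the values here are hypothetical):
+ #
+ #     dataset_filter: DatasetFilter = {"dataset_name": "train", "dataset_digest": "d1a2b3"}
+ #
+ # Both keys are optional (total=False), so {"dataset_name": "train"} alone is also valid.
+
+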
+ class SqlAlchemyStore(AbstractStore):
+     """
+     SQLAlchemy compliant backend store for tracking metadata for MLflow entities. MLflow
+     supports the database dialects ``mysql``, ``mssql``, ``sqlite``, and ``postgresql``.
+     As specified in the
+     `SQLAlchemy docs <https://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls>`_,
+     the database URI is expected in the format
+     ``<dialect>+<driver>://<username>:<password>@<host>:<port>/<database>``. If you do not
+     specify a driver, SQLAlchemy uses a dialect's default driver.
+
+     This store interacts with the SQL store using SQLAlchemy abstractions defined for MLflow
+     entities:
+     :py:class:`mlflow.store.dbmodels.models.SqlExperiment`,
+     :py:class:`mlflow.store.dbmodels.models.SqlRun`,
+     :py:class:`mlflow.store.dbmodels.models.SqlTag`,
+     :py:class:`mlflow.store.dbmodels.models.SqlMetric`, and
+     :py:class:`mlflow.store.dbmodels.models.SqlParam`.
+
+     Run artifacts are stored in a separate location using artifact stores conforming to
+     :py:class:`mlflow.store.artifact_repo.ArtifactRepository`. Default artifact locations for
+     user experiments are stored in the database along with metadata. Each run artifact location
+     is recorded in :py:class:`mlflow.store.dbmodels.models.SqlRun` and stored in the backend DB.
+     """
+
+     ARTIFACTS_FOLDER_NAME = "artifacts"
+     MODELS_FOLDER_NAME = "models"
+     TRACE_FOLDER_NAME = "traces"
+     DEFAULT_EXPERIMENT_ID = "0"
+     _db_uri_sql_alchemy_engine_map = {}
+     _db_uri_sql_alchemy_engine_map_lock = threading.Lock()
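+
+     # Construction sketch (illustrative values, assuming a local SQLite file and a
+     # local ./mlruns artifact directory):
+     #
+     #     store = SqlAlchemyStore("sqlite:///mlflow.db", "./mlruns")
+     #     experiment_id = store.create_experiment("demo")
+     #
+     # Note that _db_uri_sql_alchemy_engine_map is shared across instances, so a second
+     # store created with the same db_uri reuses the cached SQLAlchemy engine.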
+
+     def __init__(self, db_uri, default_artifact_root):
+         """
+         Create a database backed store.
+
+         Args:
+             db_uri: The SQLAlchemy database URI string to connect to the database. See
+                 the `SQLAlchemy docs
+                 <https://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls>`_
+                 for format specifications. MLflow supports the dialects ``mysql``,
+                 ``mssql``, ``sqlite``, and ``postgresql``.
+             default_artifact_root: Path/URI to location suitable for large data (such as a blob
+                 store object, DBFS path, or shared NFS file system).
+         """
+         super().__init__()
+         self.db_uri = db_uri
+         self.db_type = extract_db_type_from_uri(db_uri)
+         self.artifact_root_uri = resolve_uri_if_local(default_artifact_root)
+         # Quick check to see if the respective SQLAlchemy database engine has already been created.
+         if db_uri not in SqlAlchemyStore._db_uri_sql_alchemy_engine_map:
+             with SqlAlchemyStore._db_uri_sql_alchemy_engine_map_lock:
+                 # Repeat check to prevent race conditions where one thread checks for an existing
+                 # engine while another is creating the respective one, resulting in multiple
+                 # engines being created. It isn't combined with the above check to prevent
+                 # inefficiency from multiple threads waiting for the lock to check for engine
+                 # existence if it has already been created.
+                 if db_uri not in SqlAlchemyStore._db_uri_sql_alchemy_engine_map:
+                     SqlAlchemyStore._db_uri_sql_alchemy_engine_map[db_uri] = (
+                         mlflow.store.db.utils.create_sqlalchemy_engine_with_retry(db_uri)
+                     )
+         self.engine = SqlAlchemyStore._db_uri_sql_alchemy_engine_map[db_uri]
+         # On a completely fresh MLflow installation against an empty database (verify database
+         # emptiness by checking that 'experiments' etc aren't in the list of table names), run all
+         # DB migrations
+         if not mlflow.store.db.utils._all_tables_exist(self.engine):
+             mlflow.store.db.utils._initialize_tables(self.engine)
+         SessionMaker = sqlalchemy.orm.sessionmaker(bind=self.engine)
+         self.ManagedSessionMaker = mlflow.store.db.utils._get_managed_session_maker(
+             SessionMaker, self.db_type
+         )
+         mlflow.store.db.utils._verify_schema(self.engine)
+
+         if is_local_uri(default_artifact_root):
+             mkdir(local_file_uri_to_path(default_artifact_root))
+
+         if len(self.search_experiments(view_type=ViewType.ALL)) == 0:
+             with self.ManagedSessionMaker() as session:
+                 self._create_default_experiment(session)
+
+     def _get_dialect(self):
+         return self.engine.dialect.name
+
+     def _dispose_engine(self):
+         self.engine.dispose()
+
+     # DB helper methods to allow zero values for columns with auto increments
+     def _set_zero_value_insertion_for_autoincrement_column(self, session):
+         if self.db_type == MYSQL:
+             # config letting MySQL override default
+             # to allow 0 value for experiment ID (auto increment column)
+             session.execute(sql.text("SET @@SESSION.sql_mode='NO_AUTO_VALUE_ON_ZERO';"))
+         if self.db_type == MSSQL:
+             # config letting MSSQL override default
+             # to allow any manual value inserted into IDENTITY column
+             session.execute(sql.text("SET IDENTITY_INSERT experiments ON;"))
+
+     def _unset_zero_value_insertion_for_autoincrement_column(self, session):
+         if self.db_type == MYSQL:
+             session.execute(sql.text("SET @@SESSION.sql_mode='';"))
+         if self.db_type == MSSQL:
+             session.execute(sql.text("SET IDENTITY_INSERT experiments OFF;"))
+
+     def _create_default_experiment(self, session):
+         """
+         MLflow UI and client code expect a default experiment with ID 0.
+         This method uses a raw SQL insert statement to create the default experiment as a
+         workaround: the experiment table's 'experiment_id' column is a primary key that is
+         also set to auto increment, and MySQL and other implementations do not allow
+         inserting the value '0' into such a column.
+
+         ToDo: Identify a less hacky mechanism to create default experiment 0
+         """
+         table = SqlExperiment.__tablename__
+         creation_time = get_current_time_millis()
+         default_experiment = {
+             SqlExperiment.experiment_id.name: int(SqlAlchemyStore.DEFAULT_EXPERIMENT_ID),
+             SqlExperiment.name.name: Experiment.DEFAULT_EXPERIMENT_NAME,
+             SqlExperiment.artifact_location.name: str(self._get_artifact_location(0)),
+             SqlExperiment.lifecycle_stage.name: LifecycleStage.ACTIVE,
+             SqlExperiment.creation_time.name: creation_time,
+             SqlExperiment.last_update_time.name: creation_time,
+         }
+
+         def decorate(s):
+             if is_string_type(s):
+                 return repr(s)
+             else:
+                 return str(s)
+
+         # Get a list of keys to ensure we have a deterministic ordering
+         columns = list(default_experiment.keys())
+         values = ", ".join([decorate(default_experiment.get(c)) for c in columns])
+
+         try:
+             self._set_zero_value_insertion_for_autoincrement_column(session)
+             session.execute(
+                 sql.text(f"INSERT INTO {table} ({', '.join(columns)}) VALUES ({values});")
+             )
+         finally:
+             self._unset_zero_value_insertion_for_autoincrement_column(session)
+
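+     # For a SQLite backend the rendered statement is roughly the following
+     # (timestamps and artifact location are illustrative):
+     #
+     #     INSERT INTO experiments (experiment_id, name, artifact_location,
+     #         lifecycle_stage, creation_time, last_update_time)
+     #     VALUES (0, 'Default', './mlruns/0', 'active', 1700000000000, 1700000000000);
+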
+     def _get_or_create(self, session, model, **kwargs):
+         instance = session.query(model).filter_by(**kwargs).first()
+         created = False
+
+         if instance:
+             return instance, created
+         else:
+             instance = model(**kwargs)
+             session.add(instance)
+             created = True
+
+         return instance, created
+
+     def _get_artifact_location(self, experiment_id):
+         return append_to_uri_path(self.artifact_root_uri, str(experiment_id))
+
+     def create_experiment(self, name, artifact_location=None, tags=None):
+         _validate_experiment_name(name)
+
+         # Genesis-Flow: Use MLFLOW_ARTIFACT_LOCATION if no artifact location is provided
+         if not artifact_location:
+             from mlflow.environment_variables import MLFLOW_ARTIFACT_LOCATION
+
+             if MLFLOW_ARTIFACT_LOCATION.defined:
+                 artifact_location = MLFLOW_ARTIFACT_LOCATION.get()
+
+         if artifact_location:
+             artifact_location = resolve_uri_if_local(artifact_location)
+         _validate_experiment_artifact_location_length(artifact_location)
+         with self.ManagedSessionMaker() as session:
+             try:
+                 creation_time = get_current_time_millis()
+                 experiment = SqlExperiment(
+                     name=name,
+                     lifecycle_stage=LifecycleStage.ACTIVE,
+                     artifact_location=artifact_location,
+                     creation_time=creation_time,
+                     last_update_time=creation_time,
+                 )
+                 experiment.tags = (
+                     [SqlExperimentTag(key=tag.key, value=tag.value) for tag in tags] if tags else []
+                 )
+                 session.add(experiment)
+                 if not artifact_location:
+                     # This requires a double write: the first write generates the
+                     # auto-incremented experiment ID, which is then used to derive
+                     # the artifact location.
+                     eid = session.query(SqlExperiment).filter_by(name=name).first().experiment_id
+                     experiment.artifact_location = self._get_artifact_location(eid)
+             except sqlalchemy.exc.IntegrityError as e:
+                 raise MlflowException(
+                     f"Experiment(name={name}) already exists. Error: {e}",
+                     RESOURCE_ALREADY_EXISTS,
+                 )
+
+             session.flush()
+             return str(experiment.experiment_id)
+
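+     # Genesis-Flow sketch: with the environment variable set, e.g.
+     #
+     #     export MLFLOW_ARTIFACT_LOCATION=wasbs://artifacts@myaccount.blob.core.windows.net/mlflow
+     #
+     # create_experiment("demo") records that URI as the new experiment's artifact
+     # location instead of deriving one under default_artifact_root. (URI is illustrative.)
+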
+     def _search_experiments(
+         self,
+         view_type,
+         max_results,
+         filter_string,
+         order_by,
+         page_token,
+     ):
+         def compute_next_token(current_size):
+             next_token = None
+             if max_results + 1 == current_size:
+                 final_offset = offset + max_results
+                 next_token = SearchExperimentsUtils.create_page_token(final_offset)
+
+             return next_token
+
+         self._validate_max_results_param(max_results)
+         with self.ManagedSessionMaker() as session:
+             parsed_filters = SearchExperimentsUtils.parse_search_filter(filter_string)
+             attribute_filters, non_attribute_filters = _get_search_experiments_filter_clauses(
+                 parsed_filters, self._get_dialect()
+             )
+
+             order_by_clauses = _get_search_experiments_order_by_clauses(order_by)
+             offset = SearchUtils.parse_start_offset_from_page_token(page_token)
+             lifecycle_stages = set(LifecycleStage.view_type_to_stages(view_type))
+
+             stmt = (
+                 reduce(lambda s, f: s.join(f), non_attribute_filters, select(SqlExperiment))
+                 .options(*self._get_eager_experiment_query_options())
+                 .filter(
+                     *attribute_filters,
+                     SqlExperiment.lifecycle_stage.in_(lifecycle_stages),
+                 )
+                 .order_by(*order_by_clauses)
+                 .offset(offset)
+                 .limit(max_results + 1)
+             )
+             queried_experiments = session.execute(stmt).scalars(SqlExperiment).all()
+             experiments = [e.to_mlflow_entity() for e in queried_experiments]
+             next_page_token = compute_next_token(len(experiments))
+
+         return experiments[:max_results], next_page_token
+
+     def search_experiments(
+         self,
+         view_type=ViewType.ACTIVE_ONLY,
+         max_results=SEARCH_MAX_RESULTS_DEFAULT,
+         filter_string=None,
+         order_by=None,
+         page_token=None,
+     ):
+         experiments, next_page_token = self._search_experiments(
+             view_type, max_results, filter_string, order_by, page_token
+         )
+         return PagedList(experiments, next_page_token)
+
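+     # Paging sketch (filter syntax per SearchExperimentsUtils; values illustrative):
+     #
+     #     page = store.search_experiments(filter_string="name LIKE 'demo-%'")
+     #     while page.token is not None:
+     #         page = store.search_experiments(
+     #             filter_string="name LIKE 'demo-%'", page_token=page.token
+     #         )
+     #
+     # _search_experiments fetches max_results + 1 rows; the extra row only signals
+     # that another page exists and is trimmed before the PagedList is returned.
+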
+     def _get_experiment(self, session, experiment_id, view_type, eager=False):  # noqa: D417
+         """
+         Args:
+             eager: If ``True``, eagerly loads the experiment's tags. If ``False``, these tags
+                 are not eagerly loaded and will be loaded if/when their corresponding
+                 object properties are accessed from the resulting ``SqlExperiment`` object.
+         """
+         experiment_id = experiment_id or SqlAlchemyStore.DEFAULT_EXPERIMENT_ID
+         stages = LifecycleStage.view_type_to_stages(view_type)
+         query_options = self._get_eager_experiment_query_options() if eager else []
+
+         experiment = (
+             session.query(SqlExperiment)
+             .options(*query_options)
+             .filter(
+                 SqlExperiment.experiment_id == experiment_id,
+                 SqlExperiment.lifecycle_stage.in_(stages),
+             )
+             .one_or_none()
+         )
+
+         if experiment is None:
+             raise MlflowException(
+                 f"No Experiment with id={experiment_id} exists", RESOURCE_DOES_NOT_EXIST
+             )
+
+         return experiment
+
+     @staticmethod
+     def _get_eager_experiment_query_options():
+         """
+         A list of SQLAlchemy query options that can be used to eagerly load the following
+         experiment attributes when fetching an experiment: ``tags``.
+         """
+         return [
+             # Use a subquery load rather than a joined load in order to minimize the memory
+             # overhead of the eager loading procedure. For more information about relationship
+             # loading techniques, see https://docs.sqlalchemy.org/en/13/orm/
+             # loading_relationships.html#relationship-loading-techniques
+             sqlalchemy.orm.subqueryload(SqlExperiment.tags),
+         ]
+
433
+ with self.ManagedSessionMaker() as session:
434
+ return self._get_experiment(
435
+ session, experiment_id, ViewType.ALL, eager=True
436
+ ).to_mlflow_entity()
437
+
438
+ def get_experiment_by_name(self, experiment_name):
439
+ """
440
+ Specialized implementation for SQL backed store.
441
+ """
442
+ with self.ManagedSessionMaker() as session:
443
+ stages = LifecycleStage.view_type_to_stages(ViewType.ALL)
444
+ experiment = (
445
+ session.query(SqlExperiment)
446
+ .options(*self._get_eager_experiment_query_options())
447
+ .filter(
448
+ SqlExperiment.name == experiment_name,
449
+ SqlExperiment.lifecycle_stage.in_(stages),
450
+ )
451
+ .one_or_none()
452
+ )
453
+ return experiment.to_mlflow_entity() if experiment is not None else None
454
+
+     def delete_experiment(self, experiment_id):
+         with self.ManagedSessionMaker() as session:
+             experiment = self._get_experiment(session, experiment_id, ViewType.ACTIVE_ONLY)
+             experiment.lifecycle_stage = LifecycleStage.DELETED
+             experiment.last_update_time = get_current_time_millis()
+             runs = self._list_run_infos(session, experiment_id)
+             for run in runs:
+                 self._mark_run_deleted(session, run)
+             session.add(experiment)
+
+     def _hard_delete_experiment(self, experiment_id):
+         """
+         Permanently delete an experiment (metadata, metrics, tags, and parameters).
+         This is used by the ``mlflow gc`` command line and is not intended to be used elsewhere.
+         """
+         with self.ManagedSessionMaker() as session:
+             experiment = self._get_experiment(
+                 experiment_id=experiment_id,
+                 session=session,
+                 view_type=ViewType.DELETED_ONLY,
+             )
+             session.delete(experiment)
+
+     def _mark_run_deleted(self, session, run):
+         run.lifecycle_stage = LifecycleStage.DELETED
+         run.deleted_time = get_current_time_millis()
+         session.add(run)
+
+     def _mark_run_active(self, session, run):
+         run.lifecycle_stage = LifecycleStage.ACTIVE
+         run.deleted_time = None
+         session.add(run)
+
+     def _list_run_infos(self, session, experiment_id):
+         return session.query(SqlRun).filter(SqlRun.experiment_id == experiment_id).all()
+
+     def restore_experiment(self, experiment_id):
+         with self.ManagedSessionMaker() as session:
+             experiment = self._get_experiment(session, experiment_id, ViewType.DELETED_ONLY)
+             experiment.lifecycle_stage = LifecycleStage.ACTIVE
+             experiment.last_update_time = get_current_time_millis()
+             runs = self._list_run_infos(session, experiment_id)
+             for run in runs:
+                 self._mark_run_active(session, run)
+             session.add(experiment)
+
+     def rename_experiment(self, experiment_id, new_name):
+         with self.ManagedSessionMaker() as session:
+             experiment = self._get_experiment(session, experiment_id, ViewType.ALL)
+             if experiment.lifecycle_stage != LifecycleStage.ACTIVE:
+                 raise MlflowException("Cannot rename a non-active experiment.", INVALID_STATE)
+
+             experiment.name = new_name
+             experiment.last_update_time = get_current_time_millis()
+             session.add(experiment)
+
+     def create_run(self, experiment_id, user_id, start_time, tags, run_name):
+         with self.ManagedSessionMaker() as session:
+             experiment = self.get_experiment(experiment_id)
+             self._check_experiment_is_active(experiment)
+
+             # Note: we need to ensure the generated "run_id" only contains digits and
+             # lowercase letters, because some query filters contain an "IN" clause, and in
+             # MySQL the "IN" clause is case-insensitive; we rely on a trick that filters out
+             # comparison values containing uppercase letters when parsing the "IN" clause
+             # inside a query filter.
+             run_id = uuid.uuid4().hex
+             artifact_location = append_to_uri_path(
+                 experiment.artifact_location,
+                 run_id,
+                 SqlAlchemyStore.ARTIFACTS_FOLDER_NAME,
+             )
+             tags = tags.copy() if tags else []
+             run_name_tag = _get_run_name_from_tags(tags)
+             if run_name and run_name_tag and (run_name != run_name_tag):
+                 raise MlflowException(
+                     "Both 'run_name' argument and 'mlflow.runName' tag are specified, but with "
+                     f"different values (run_name='{run_name}', mlflow.runName='{run_name_tag}').",
+                     INVALID_PARAMETER_VALUE,
+                 )
+             run_name = run_name or run_name_tag or _generate_random_name()
+             if not run_name_tag:
+                 tags.append(RunTag(key=MLFLOW_RUN_NAME, value=run_name))
+             run = SqlRun(
+                 name=run_name,
+                 artifact_uri=artifact_location,
+                 run_uuid=run_id,
+                 experiment_id=experiment_id,
+                 source_type=SourceType.to_string(SourceType.UNKNOWN),
+                 source_name="",
+                 entry_point_name="",
+                 user_id=user_id,
+                 status=RunStatus.to_string(RunStatus.RUNNING),
+                 start_time=start_time,
+                 end_time=None,
+                 deleted_time=None,
+                 source_version="",
+                 lifecycle_stage=LifecycleStage.ACTIVE,
+             )
+
+             run.tags = [SqlTag(key=tag.key, value=tag.value) for tag in tags]
+             session.add(run)
+
+             run = run.to_mlflow_entity()
+             inputs_list = self._get_run_inputs(session, [run_id])
+             dataset_inputs = inputs_list[0] if inputs_list else []
+             return Run(run.info, run.data, RunInputs(dataset_inputs=dataset_inputs))
+
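+     # Name-resolution sketch for ``create_run`` (illustrative only; assumes a
+     # SqlAlchemyStore instance ``store`` and an existing experiment "0"):
+     #
+     #     run = store.create_run(
+     #         experiment_id="0",
+     #         user_id="alice",
+     #         start_time=get_current_time_millis(),
+     #         tags=[RunTag(key=MLFLOW_RUN_NAME, value="my-run")],
+     #         run_name=None,
+     #     )
+     #     assert run.info.run_name == "my-run"  # the tag wins when run_name is omitted;
+     #     # a conflicting explicit run_name would raise INVALID_PARAMETER_VALUE
+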
+     def _get_run(self, session, run_uuid, eager=False):  # noqa: D417
+         """
+         Args:
+             eager: If ``True``, eagerly loads the run's summary metrics (``latest_metrics``),
+                 params, and tags when fetching the run. If ``False``, these attributes
+                 are not eagerly loaded and will be loaded when their corresponding
+                 object properties are accessed from the resulting ``SqlRun`` object.
+         """
+         query_options = self._get_eager_run_query_options() if eager else []
+         runs = (
+             session.query(SqlRun).options(*query_options).filter(SqlRun.run_uuid == run_uuid).all()
+         )
+
+         if len(runs) == 0:
+             raise MlflowException(f"Run with id={run_uuid} not found", RESOURCE_DOES_NOT_EXIST)
+         if len(runs) > 1:
+             raise MlflowException(
+                 f"Expected only 1 run with id={run_uuid}. Found {len(runs)}.",
+                 INVALID_STATE,
+             )
+
+         return runs[0]
+
+     def _get_run_inputs(self, session, run_uuids):
+         datasets_with_tags = (
+             session.query(
+                 SqlInput.input_uuid,
+                 SqlInput.destination_id.label("run_uuid"),
+                 SqlDataset,
+                 SqlInputTag,
+             )
+             .select_from(SqlInput)
+             .join(SqlDataset, SqlInput.source_id == SqlDataset.dataset_uuid)
+             .outerjoin(SqlInputTag, SqlInputTag.input_uuid == SqlInput.input_uuid)
+             .filter(SqlInput.destination_type == "RUN", SqlInput.destination_id.in_(run_uuids))
+             .order_by("run_uuid")
+         ).all()
+
+         dataset_inputs_per_run = defaultdict(dict)
+         for input_uuid, run_uuid, dataset_sql, tag_sql in datasets_with_tags:
+             dataset_inputs = dataset_inputs_per_run[run_uuid]
+             dataset_uuid = dataset_sql.dataset_uuid
+             dataset_input = dataset_inputs.get(dataset_uuid)
+             if dataset_input is None:
+                 dataset_entity = dataset_sql.to_mlflow_entity()
+                 dataset_input = DatasetInput(dataset=dataset_entity, tags=[])
+                 dataset_inputs[dataset_uuid] = dataset_input
+             if tag_sql is not None:
+                 dataset_input.tags.append(tag_sql.to_mlflow_entity())
+         return [list(dataset_inputs_per_run[run_uuid].values()) for run_uuid in run_uuids]
+
+     @staticmethod
+     def _get_eager_run_query_options():
+         """
+         A list of SQLAlchemy query options that can be used to eagerly load the following
+         run attributes when fetching a run: ``latest_metrics``, ``params``, and ``tags``.
+         """
+         return [
+             # Use a select in load rather than a joined load in order to minimize the memory
+             # overhead of the eager loading procedure. For more information about relationship
+             # loading techniques, see https://docs.sqlalchemy.org/en/13/orm/
+             # loading_relationships.html#relationship-loading-techniques
+             sqlalchemy.orm.selectinload(SqlRun.latest_metrics),
+             sqlalchemy.orm.selectinload(SqlRun.params),
+             sqlalchemy.orm.selectinload(SqlRun.tags),
+         ]
+
+     def _check_run_is_active(self, run):
+         if run.lifecycle_stage != LifecycleStage.ACTIVE:
+             raise MlflowException(
+                 (
+                     f"The run {run.run_uuid} must be in the 'active' state. "
+                     f"Current state is {run.lifecycle_stage}."
+                 ),
+                 INVALID_PARAMETER_VALUE,
+             )
+
+     def _check_experiment_is_active(self, experiment):
+         if experiment.lifecycle_stage != LifecycleStage.ACTIVE:
+             raise MlflowException(
+                 (
+                     f"The experiment {experiment.experiment_id} must be in the 'active' state. "
+                     f"Current state is {experiment.lifecycle_stage}."
+                 ),
+                 INVALID_PARAMETER_VALUE,
+             )
+
+     def update_run_info(self, run_id, run_status, end_time, run_name):
+         with self.ManagedSessionMaker() as session:
+             run = self._get_run(run_uuid=run_id, session=session)
+             self._check_run_is_active(run)
+             if run_status is not None:
+                 run.status = RunStatus.to_string(run_status)
+             if end_time is not None:
+                 run.end_time = end_time
+             if run_name:
+                 run.name = run_name
+                 run_name_tag = self._try_get_run_tag(session, run_id, MLFLOW_RUN_NAME)
+                 if run_name_tag is None:
+                     run.tags.append(SqlTag(key=MLFLOW_RUN_NAME, value=run_name))
+                 else:
+                     run_name_tag.value = run_name
+
+             session.add(run)
+             run = run.to_mlflow_entity()
+
+             return run.info
+
+     def _try_get_run_tag(self, session, run_id, tagKey, eager=False):
+         query_options = self._get_eager_run_query_options() if eager else []
+         return (
+             session.query(SqlTag)
+             .options(*query_options)
+             .filter(SqlTag.run_uuid == run_id, SqlTag.key == tagKey)
+             .one_or_none()
+         )
+
+     def get_run(self, run_id):
+         with self.ManagedSessionMaker() as session:
+             # Load the run with the specified id and eagerly load its summary metrics, params, and
+             # tags. These attributes are referenced during the invocation of
+             # ``run.to_mlflow_entity()``, so eager loading helps avoid additional database queries
+             # that are otherwise executed at attribute access time under a lazy loading model.
+             run = self._get_run(run_uuid=run_id, session=session, eager=True)
+             mlflow_run = run.to_mlflow_entity()
+             # Get the run inputs and add them to the run
+             inputs = self._get_run_inputs(run_uuids=[run_id], session=session)[0]
+             model_inputs = self._get_model_inputs(run_id, session)
+             model_outputs = self._get_model_outputs(run_id, session)
+             return Run(
+                 mlflow_run.info,
+                 mlflow_run.data,
+                 RunInputs(dataset_inputs=inputs, model_inputs=model_inputs),
+                 RunOutputs(model_outputs),
+             )
+
+     def restore_run(self, run_id):
+         with self.ManagedSessionMaker() as session:
+             run = self._get_run(run_uuid=run_id, session=session)
+             run.lifecycle_stage = LifecycleStage.ACTIVE
+             run.deleted_time = None
+             session.add(run)
+
+     def delete_run(self, run_id):
+         with self.ManagedSessionMaker() as session:
+             run = self._get_run(run_uuid=run_id, session=session)
+             run.lifecycle_stage = LifecycleStage.DELETED
+             run.deleted_time = get_current_time_millis()
+             session.add(run)
+
+     def _hard_delete_run(self, run_id):
+         """
+         Permanently delete a run (metadata and metrics, tags, parameters).
+         This is used by the ``mlflow gc`` command line and is not intended to be used elsewhere.
+         """
+         with self.ManagedSessionMaker() as session:
+             run = self._get_run(run_uuid=run_id, session=session)
+             session.delete(run)
+
+     def _get_deleted_runs(self, older_than=0):
+         """
+         Get all deleted run ids.
+
+         Args:
+             older_than: Only return runs that have been deleted for more than this many
+                 milliseconds. Defaults to 0 ms, i.e. all deleted runs are returned.
+         """
+         current_time = get_current_time_millis()
+         with self.ManagedSessionMaker() as session:
+             runs = (
+                 session.query(SqlRun)
+                 .filter(
+                     SqlRun.lifecycle_stage == LifecycleStage.DELETED,
+                     SqlRun.deleted_time <= (current_time - older_than),
+                 )
+                 .all()
+             )
+             return [run.run_uuid for run in runs]
+
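+     # GC sketch (illustrative; mirrors what the ``mlflow gc`` command does with the two
+     # helpers above, assuming a SqlAlchemyStore instance ``store``):
+     #
+     #     one_day_ms = 24 * 60 * 60 * 1000
+     #     for run_id in store._get_deleted_runs(older_than=one_day_ms):
+     #         store._hard_delete_run(run_id)
+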
+     def log_metric(self, run_id, metric):
+         # Simply call _log_metrics and let it handle the rest
+         self._log_metrics(run_id, [metric])
+         self._log_model_metrics(run_id, [metric])
+
+     def sanitize_metric_value(self, metric_value: float) -> tuple[bool, float]:
+         """
+         Returns a tuple of two values:
+             - A boolean indicating whether the metric is NaN.
+             - The metric value, which is set to 0 if the metric is NaN.
+         """
+         is_nan = math.isnan(metric_value)
+         if is_nan:
+             value = 0
+         elif math.isinf(metric_value):
+             # NB: SQL cannot represent Infs, so we replace +/- Inf with the max/min
+             # 64-bit float value
+             value = 1.7976931348623157e308 if metric_value > 0 else -1.7976931348623157e308
+         else:
+             value = metric_value
+         return is_nan, value
+
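+     # Behavior sketch for ``sanitize_metric_value`` (values follow directly from the
+     # branches above; ``store`` is an assumed SqlAlchemyStore instance):
+     #
+     #     store.sanitize_metric_value(float("nan"))   # -> (True, 0)
+     #     store.sanitize_metric_value(float("inf"))   # -> (False, 1.7976931348623157e308)
+     #     store.sanitize_metric_value(float("-inf"))  # -> (False, -1.7976931348623157e308)
+     #     store.sanitize_metric_value(0.5)            # -> (False, 0.5)
+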
+     def _log_metrics(self, run_id, metrics):
+         # Duplicate metric values are eliminated here to maintain
+         # the same behavior as log_metric
+         metric_instances = []
+         seen = set()
+         is_single_metric = len(metrics) == 1
+         for idx, metric in enumerate(metrics):
+             _validate_metric(
+                 metric.key,
+                 metric.value,
+                 metric.timestamp,
+                 metric.step,
+                 path="" if is_single_metric else f"metrics[{idx}]",
+             )
+             if metric not in seen:
+                 is_nan, value = self.sanitize_metric_value(metric.value)
+                 metric_instances.append(
+                     SqlMetric(
+                         run_uuid=run_id,
+                         key=metric.key,
+                         value=value,
+                         timestamp=metric.timestamp,
+                         step=metric.step,
+                         is_nan=is_nan,
+                     )
+                 )
+                 seen.add(metric)
+
+         with self.ManagedSessionMaker() as session:
+             run = self._get_run(run_uuid=run_id, session=session)
+             self._check_run_is_active(run)
+
+             def _insert_metrics(metric_instances):
+                 session.add_all(metric_instances)
+                 self._update_latest_metrics_if_necessary(metric_instances, session)
+                 session.commit()
+
+             try:
+                 _insert_metrics(metric_instances)
+             except sqlalchemy.exc.IntegrityError:
+                 # The primary key can be violated if we try to log a metric with the same
+                 # value, timestamp, step, and key within the same run.
+                 # Roll back the current session to make it usable for further transactions. In
+                 # the event of an error during "commit", a rollback is required in order to
+                 # continue using the session. In this case, we re-use the session to query
+                 # SqlMetric.
+                 session.rollback()
+                 # Divide metric keys into batches of 100 to avoid loading too much metric
+                 # history data into memory at once
+                 metric_keys = [m.key for m in metric_instances]
+                 metric_key_batches = [
+                     metric_keys[i : i + 100] for i in range(0, len(metric_keys), 100)
+                 ]
+                 for metric_key_batch in metric_key_batches:
+                     # obtain the metric history corresponding to the given metrics
+                     metric_history = (
+                         session.query(SqlMetric)
+                         .filter(
+                             SqlMetric.run_uuid == run_id,
+                             SqlMetric.key.in_(metric_key_batch),
+                         )
+                         .all()
+                     )
+                     # convert to a set of Metric instances to take advantage of their
+                     # hashability, then obtain the metrics that were not logged earlier
+                     # within this run_id
+                     metric_history = {m.to_mlflow_entity() for m in metric_history}
+                     non_existing_metrics = [
+                         m for m in metric_instances if m.to_mlflow_entity() not in metric_history
+                     ]
+                     # if there are metrics whose insertion was rolled back even though
+                     # they did not violate the PK, log them
+                     _insert_metrics(non_existing_metrics)
+
+     def _log_model_metrics(
+         self,
+         run_id: str,
+         metrics: list[Metric],
+         dataset_uuid: Optional[str] = None,
+         experiment_id: Optional[str] = None,
+     ) -> None:
+         if not metrics:
+             return
+
+         metric_instances: list[SqlLoggedModelMetric] = []
+         is_single_metric = len(metrics) == 1
+         seen: set[Metric] = set()
+         for idx, metric in enumerate(metrics):
+             if metric.model_id is None:
+                 continue
+
+             if metric in seen:
+                 continue
+             seen.add(metric)
+
+             _validate_metric(
+                 metric.key,
+                 metric.value,
+                 metric.timestamp,
+                 metric.step,
+                 path="" if is_single_metric else f"metrics[{idx}]",
+             )
+             is_nan, value = self.sanitize_metric_value(metric.value)
+             metric_instances.append(
+                 SqlLoggedModelMetric(
+                     model_id=metric.model_id,
+                     metric_name=metric.key,
+                     metric_timestamp_ms=metric.timestamp,
+                     metric_step=metric.step,
+                     metric_value=value,
+                     experiment_id=experiment_id or _get_experiment_id(),
+                     run_id=run_id,
+                     dataset_uuid=dataset_uuid,
+                     dataset_name=metric.dataset_name,
+                     dataset_digest=metric.dataset_digest,
+                 )
+             )
+
+         with self.ManagedSessionMaker() as session:
+             try:
+                 session.add_all(metric_instances)
+                 session.commit()
+             except sqlalchemy.exc.IntegrityError:
+                 # The primary key can be violated if we try to log a metric with the same
+                 # value, timestamp, step, and key within the same run.
+                 session.rollback()
+                 metric_keys = [m.metric_name for m in metric_instances]
+                 metric_key_batches = (
+                     metric_keys[i : i + 100] for i in range(0, len(metric_keys), 100)
+                 )
+                 for batch in metric_key_batches:
+                     existing_metrics = (
+                         session.query(SqlLoggedModelMetric)
+                         .filter(
+                             SqlLoggedModelMetric.run_id == run_id,
+                             SqlLoggedModelMetric.metric_name.in_(batch),
+                         )
+                         .all()
+                     )
+                     existing_metrics = {m.to_mlflow_entity() for m in existing_metrics}
+                     non_existing_metrics = [
+                         m for m in metric_instances if m.to_mlflow_entity() not in existing_metrics
+                     ]
+                     session.add_all(non_existing_metrics)
+
+     def _update_latest_metrics_if_necessary(self, logged_metrics, session):
+         def _compare_metrics(metric_a, metric_b):
+             """
+             Returns:
+                 True if ``metric_a`` is strictly more recent than ``metric_b``, as determined
+                 by ``step``, ``timestamp``, and ``value``. False otherwise.
+             """
+             return (metric_a.step, metric_a.timestamp, metric_a.value) > (
+                 metric_b.step,
+                 metric_b.timestamp,
+                 metric_b.value,
+             )
+
+         def _overwrite_metric(new_metric, old_metric):
+             """
+             Writes the content of new_metric over old_metric. The content is `value`, `step`,
+             `timestamp`, and `is_nan`.
+
+             Returns:
+                 old_metric with its content updated.
+             """
+             old_metric.value = new_metric.value
+             old_metric.step = new_metric.step
+             old_metric.timestamp = new_metric.timestamp
+             old_metric.is_nan = new_metric.is_nan
+             return old_metric
+
+         if not logged_metrics:
+             return
+
+         # Fetch the latest metric value corresponding to the specified run_id and metric keys and
+         # lock their associated rows for the remainder of the transaction in order to ensure
+         # isolation
+         latest_metrics = {}
+         metric_keys = [m.key for m in logged_metrics]
+         # Divide metric keys into batches of 500 to avoid binding too many parameters to the SQL
+         # query, which may produce limit exceeded errors or poor performance on certain database
+         # platforms
+         metric_key_batches = [metric_keys[i : i + 500] for i in range(0, len(metric_keys), 500)]
+         for metric_key_batch in metric_key_batches:
+             # First, determine which metric keys are present in the database
+             latest_metrics_key_records_from_db = (
+                 session.query(SqlLatestMetric.key)
+                 .filter(
+                     SqlLatestMetric.run_uuid == logged_metrics[0].run_uuid,
+                     SqlLatestMetric.key.in_(metric_key_batch),
+                 )
+                 .all()
+             )
+             # Then, take a write lock on the rows corresponding to metric keys that are present,
+             # ensuring that they aren't modified by another transaction until they can be
+             # compared to the metric values logged by this transaction while avoiding gap locking
+             # and next-key locking which may otherwise occur when issuing a `SELECT FOR UPDATE`
+             # against nonexistent rows
+             if len(latest_metrics_key_records_from_db) > 0:
+                 latest_metric_keys_from_db = [
+                     record[0] for record in latest_metrics_key_records_from_db
+                 ]
+                 latest_metrics_batch = (
+                     session.query(SqlLatestMetric)
+                     .filter(
+                         SqlLatestMetric.run_uuid == logged_metrics[0].run_uuid,
+                         SqlLatestMetric.key.in_(latest_metric_keys_from_db),
+                     )
+                     # Order by the metric run ID and key to ensure a consistent locking order
+                     # across transactions, reducing deadlock likelihood
+                     .order_by(SqlLatestMetric.run_uuid, SqlLatestMetric.key)
+                     .with_for_update()
+                     .all()
+                 )
+                 latest_metrics.update({m.key: m for m in latest_metrics_batch})
+
+         # Iterate over all logged metrics and compare them with the corresponding
+         # SqlLatestMetric entries. If there's no SqlLatestMetric entry for the current
+         # metric key, create a new SqlLatestMetric instance and put it in
+         # new_latest_metric_dict so that it can be saved later.
+         new_latest_metric_dict = {}
+         for logged_metric in logged_metrics:
+             latest_metric = latest_metrics.get(logged_metric.key)
+             # A metric key can be passed more than once within logged metrics, with
+             # different step/timestamp/value. However, SqlLatestMetric entries are
+             # inserted only after this loop completes, so retrieve the instances that
+             # were just created and use them for comparison.
+             new_latest_metric = new_latest_metric_dict.get(logged_metric.key)
+
+             # Create a new SqlLatestMetric instance, since neither a latest_metric row
+             # nor a recently created instance exists
+             if not latest_metric and not new_latest_metric:
+                 new_latest_metric = SqlLatestMetric(
+                     run_uuid=logged_metric.run_uuid,
+                     key=logged_metric.key,
+                     value=logged_metric.value,
+                     timestamp=logged_metric.timestamp,
+                     step=logged_metric.step,
+                     is_nan=logged_metric.is_nan,
+                 )
+                 new_latest_metric_dict[logged_metric.key] = new_latest_metric
+
+             # There's no row, but a new instance was recently created, so update the
+             # recent instance in new_latest_metric_dict if the metric comparison succeeds
+             elif not latest_metric and new_latest_metric:
+                 if _compare_metrics(logged_metric, new_latest_metric):
+                     new_latest_metric = _overwrite_metric(logged_metric, new_latest_metric)
+                     new_latest_metric_dict[logged_metric.key] = new_latest_metric
+
+             # compare with the row
+             elif _compare_metrics(logged_metric, latest_metric):
+                 # Editing the attributes of latest_metric, which is a SqlLatestMetric
+                 # instance, results in an UPDATE on the DB side.
+                 latest_metric = _overwrite_metric(logged_metric, latest_metric)
+
+         if new_latest_metric_dict:
+             session.add_all(new_latest_metric_dict.values())
+
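+     # Recency sketch: ``_compare_metrics`` compares the tuples (step, timestamp, value)
+     # lexicographically, so a later step supersedes an earlier one regardless of
+     # timestamp or value, e.g. (2, 100, 0.1) > (1, 999, 0.9).
+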
+     def get_metric_history(self, run_id, metric_key, max_results=None, page_token=None):
+         """
+         Return all logged values for a given metric.
+
+         Args:
+             run_id: Unique identifier for run.
+             metric_key: Metric name within the run.
+             max_results: The maximum number of metric values to return per page, used for
+                 pagination.
+             page_token: Token indicating the page of metric history to fetch.
+
+         Returns:
+             A :py:class:`mlflow.store.entities.paged_list.PagedList` of
+             :py:class:`mlflow.entities.Metric` entities if ``metric_key`` values
+             have been logged to the ``run_id``, else an empty list.
+         """
+         with self.ManagedSessionMaker() as session:
+             query = session.query(SqlMetric).filter_by(run_uuid=run_id, key=metric_key)
+
+             # Parse offset from page_token for pagination
+             offset = SearchUtils.parse_start_offset_from_page_token(page_token)
+
+             # Add ORDER BY clause to satisfy MSSQL requirement for OFFSET
+             query = query.order_by(SqlMetric.timestamp, SqlMetric.step, SqlMetric.value)
+             query = query.offset(offset)
+
+             if max_results is not None:
+                 query = query.limit(max_results + 1)
+
+             metrics = query.all()
+
+             # Compute next token if more results are available
+             next_token = None
+             if max_results is not None and len(metrics) == max_results + 1:
+                 final_offset = offset + max_results
+                 next_token = SearchUtils.create_page_token(final_offset)
+                 metrics = metrics[:max_results]
+
+             return PagedList([metric.to_mlflow_entity() for metric in metrics], next_token)
+
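+     # Pagination sketch (illustrative; the token simply encodes the next offset):
+     #
+     #     page = store.get_metric_history(run_id, "loss", max_results=100)
+     #     while page.token is not None:
+     #         page = store.get_metric_history(
+     #             run_id, "loss", max_results=100, page_token=page.token
+     #         )
+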
+     def get_metric_history_bulk(self, run_ids, metric_key, max_results):
+         """
+         Return all logged values for a given metric across the specified runs.
+
+         Args:
+             run_ids: Unique identifiers of the runs from which to fetch the metric histories for
+                 the specified key.
+             metric_key: Metric name within the runs.
+             max_results: The maximum number of results to return.
+
+         Returns:
+             A List of SqlAlchemyStore.MetricWithRunId objects if metric_key values have been logged
+             to one or more of the specified run_ids, else an empty list. Results are sorted by run
+             ID in lexicographically ascending order, followed by timestamp, step, and value in
+             numerically ascending order.
+         """
+         # NB: The SQLAlchemyStore does not currently support pagination for this API;
+         # paged queries are not implemented for this method.
+         with self.ManagedSessionMaker() as session:
+             metrics = (
+                 session.query(SqlMetric)
+                 .filter(
+                     SqlMetric.key == metric_key,
+                     SqlMetric.run_uuid.in_(run_ids),
+                 )
+                 .order_by(
+                     SqlMetric.run_uuid,
+                     SqlMetric.timestamp,
+                     SqlMetric.step,
+                     SqlMetric.value,
+                 )
+                 .limit(max_results)
+                 .all()
+             )
+             return [
+                 MetricWithRunId(
+                     run_id=metric.run_uuid,
+                     metric=metric.to_mlflow_entity(),
+                 )
+                 for metric in metrics
+             ]
+
+     def get_max_step_for_metric(self, run_id, metric_key):
+         with self.ManagedSessionMaker() as session:
+             max_step = (
+                 session.query(func.max(SqlMetric.step))
+                 .filter(SqlMetric.run_uuid == run_id, SqlMetric.key == metric_key)
+                 .scalar()
+             )
+             return max_step or 0
+
+     def get_metric_history_bulk_interval_from_steps(self, run_id, metric_key, steps, max_results):
+         with self.ManagedSessionMaker() as session:
+             metrics = (
+                 session.query(SqlMetric)
+                 .filter(
+                     SqlMetric.key == metric_key,
+                     SqlMetric.run_uuid == run_id,
+                     SqlMetric.step.in_(steps),
+                 )
+                 .order_by(
+                     SqlMetric.run_uuid,
+                     SqlMetric.step,
+                     SqlMetric.timestamp,
+                     SqlMetric.value,
+                 )
+                 .limit(max_results)
+                 .all()
+             )
+             return [
+                 MetricWithRunId(
+                     run_id=metric.run_uuid,
+                     metric=metric.to_mlflow_entity(),
+                 )
+                 for metric in metrics
+             ]
+
+     def _search_datasets(self, experiment_ids):
+         """
+         Return all dataset summaries associated with the given experiments.
+
+         Args:
+             experiment_ids: List of experiment ids to scope the search
+
+         Returns:
+             A List of :py:class:`SqlAlchemyStore.DatasetSummary` entities.
+         """
+
+         MAX_DATASET_SUMMARIES_RESULTS = 1000
+         with self.ManagedSessionMaker() as session:
+             # Note that the join with the input tag table is a left join. This is required so if an
+             # input does not have the MLFLOW_DATASET_CONTEXT tag, we still return that entry as part
+             # of the final result with the context set to None.
+             summaries = (
+                 session.query(
+                     SqlDataset.experiment_id,
+                     SqlDataset.name,
+                     SqlDataset.digest,
+                     SqlInputTag.value,
+                 )
+                 .select_from(SqlDataset)
+                 .distinct()
+                 .join(SqlInput, SqlInput.source_id == SqlDataset.dataset_uuid)
+                 .join(
+                     SqlInputTag,
+                     and_(
+                         SqlInput.input_uuid == SqlInputTag.input_uuid,
+                         SqlInputTag.name == MLFLOW_DATASET_CONTEXT,
+                     ),
+                     isouter=True,
+                 )
+                 .filter(SqlDataset.experiment_id.in_(experiment_ids))
+                 .limit(MAX_DATASET_SUMMARIES_RESULTS)
+                 .all()
+             )
+
+             return [
+                 _DatasetSummary(
+                     experiment_id=str(summary.experiment_id),
+                     name=summary.name,
+                     digest=summary.digest,
+                     context=summary.value,
+                 )
+                 for summary in summaries
+             ]
+
+     def log_param(self, run_id, param):
+         param = _validate_param(param.key, param.value)
+         with self.ManagedSessionMaker() as session:
+             run = self._get_run(run_uuid=run_id, session=session)
+             self._check_run_is_active(run)
+             # If we try to update the value of an existing param, this will fail
+             # because it will try to create a new row with the same run_uuid and param key.
+             try:
+                 # This will check various integrity constraints for the params table.
+                 # TODO: Consider prior checks for null, type, param name validations, etc.
+                 self._get_or_create(
+                     model=SqlParam,
+                     session=session,
+                     run_uuid=run_id,
+                     key=param.key,
+                     value=param.value,
+                 )
+                 # Explicitly commit the session in order to catch potential integrity errors
+                 # while maintaining the current managed session scope ("commit" checks that
+                 # a transaction satisfies uniqueness constraints and throws integrity errors
+                 # when they are violated; "get_or_create()" does not perform these checks). It is
+                 # important that we maintain the same session scope because, in the case of
+                 # an integrity error, we want to examine the uniqueness of parameter values using
+                 # the same database state that the session uses during "commit". Creating a new
+                 # session synchronizes the state with the database. As a result, if the conflicting
+                 # parameter value were to be removed prior to the creation of a new session,
+                 # we would be unable to determine the cause of failure for the first session's
+                 # "commit" operation.
+                 session.commit()
+             except sqlalchemy.exc.IntegrityError:
+                 # Roll back the current session to make it usable for further transactions. In the
+                 # event of an error during "commit", a rollback is required in order to continue
+                 # using the session. In this case, we re-use the session because the SqlRun, `run`,
+                 # is lazily evaluated during the invocation of `run.params`.
+                 session.rollback()
+                 existing_params = [p.value for p in run.params if p.key == param.key]
+                 if len(existing_params) > 0:
+                     old_value = existing_params[0]
+                     if old_value != param.value:
+                         raise MlflowException(
+                             "Changing param values is not allowed. Param with key='{}' was already"
+                             " logged with value='{}' for run ID='{}'. Attempted logging new value"
+                             " '{}'.".format(param.key, old_value, run_id, param.value),
+                             INVALID_PARAMETER_VALUE,
+                         )
+                 else:
+                     raise
+
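+     # Immutability sketch for ``log_param`` (illustrative; ``Param`` is
+     # ``mlflow.entities.Param``):
+     #
+     #     store.log_param(run_id, Param("lr", "0.01"))
+     #     store.log_param(run_id, Param("lr", "0.01"))  # no-op: same key, same value
+     #     store.log_param(run_id, Param("lr", "0.02"))  # raises INVALID_PARAMETER_VALUE
+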
+     def _log_params(self, run_id, params):
+         if not params:
+             return
+
+         with self.ManagedSessionMaker() as session:
+             run = self._get_run(run_uuid=run_id, session=session)
+             self._check_run_is_active(run)
+             existing_params = {p.key: p.value for p in run.params}
+             new_params = []
+             non_matching_params = []
+             for param in params:
+                 if param.key in existing_params:
+                     if param.value != existing_params[param.key]:
+                         non_matching_params.append(
+                             {
+                                 "key": param.key,
+                                 "old_value": existing_params[param.key],
+                                 "new_value": param.value,
+                             }
+                         )
+                     continue
+                 new_params.append(SqlParam(run_uuid=run_id, key=param.key, value=param.value))
+
+             if non_matching_params:
+                 raise MlflowException(
+                     "Changing param values is not allowed. Params were already"
+                     f" logged='{non_matching_params}' for run ID='{run_id}'.",
+                     INVALID_PARAMETER_VALUE,
+                 )
+
+             if not new_params:
+                 return
+
+             session.add_all(new_params)
+
+     def set_experiment_tag(self, experiment_id, tag):
+         """
+         Set a tag for the specified experiment
+
+         Args:
+             experiment_id: String ID of the experiment
+             tag: ExperimentTag instance to log
+         """
+         _validate_experiment_tag(tag.key, tag.value)
+         with self.ManagedSessionMaker() as session:
+             tag = _validate_tag(tag.key, tag.value)
+             experiment = self._get_experiment(
+                 session, experiment_id, ViewType.ALL
+             ).to_mlflow_entity()
+             self._check_experiment_is_active(experiment)
+             session.merge(
+                 SqlExperimentTag(experiment_id=experiment_id, key=tag.key, value=tag.value)
+             )
+
+     def set_tag(self, run_id, tag):
+         """
+         Set a tag on a run.
+
+         Args:
+             run_id: String ID of the run.
+             tag: RunTag instance to log.
+         """
+         with self.ManagedSessionMaker() as session:
+             tag = _validate_tag(tag.key, tag.value)
+             run = self._get_run(run_uuid=run_id, session=session)
+             self._check_run_is_active(run)
+             if tag.key == MLFLOW_RUN_NAME:
+                 run_status = RunStatus.from_string(run.status)
+                 self.update_run_info(run_id, run_status, run.end_time, tag.value)
+             else:
+                 # NB: Updating the run_info will set the tag. No need to do it twice.
+                 session.merge(SqlTag(run_uuid=run_id, key=tag.key, value=tag.value))
+
+     def _set_tags(self, run_id, tags):
+         """
+         Set multiple tags on a run
+
+         Args:
+             run_id: String ID of the run
+             tags: List of RunTag instances to log
+         """
+         if not tags:
+             return
+
+         tags = [_validate_tag(t.key, t.value, path=f"tags[{idx}]") for (idx, t) in enumerate(tags)]
+
+         with self.ManagedSessionMaker() as session:
+             run = self._get_run(run_uuid=run_id, session=session)
+             self._check_run_is_active(run)
+
+             def _try_insert_tags(attempt_number, max_retries):
+                 try:
+                     current_tags = (
+                         session.query(SqlTag)
+                         .filter(
+                             SqlTag.run_uuid == run_id,
+                             SqlTag.key.in_([t.key for t in tags]),
+                         )
+                         .all()
+                     )
+                     current_tags = {t.key: t for t in current_tags}
+
+                     new_tag_dict = {}
+                     for tag in tags:
+                         # NB: If the run name tag is explicitly set, update the run info attribute
+                         # and do not resubmit the tag for overwrite as the tag will be set within
+                         # `set_tag()` with a call to `update_run_info()`
+                         if tag.key == MLFLOW_RUN_NAME:
+                             self.set_tag(run_id, tag)
+                         else:
+                             current_tag = current_tags.get(tag.key)
+                             new_tag = new_tag_dict.get(tag.key)
+
+                             # update the SqlTag if it is already present in DB
+                             if current_tag:
+                                 current_tag.value = tag.value
+                                 continue
+
+                             # if a SqlTag instance is already present in `new_tag_dict`,
+                             # this means that multiple tags with the same key were passed to
+                             # `set_tags`.
+                             # In this case, we resolve potential conflicts by updating the value
+                             # of the existing instance to the value of `tag`
+                             if new_tag:
+                                 new_tag.value = tag.value
+                             # otherwise, put it into the dict
+                             else:
+                                 new_tag = SqlTag(run_uuid=run_id, key=tag.key, value=tag.value)
+
+                             new_tag_dict[tag.key] = new_tag
+
+                     # finally, save new entries to DB.
+                     session.add_all(new_tag_dict.values())
+                     session.commit()
+                 except sqlalchemy.exc.IntegrityError:
+                     session.rollback()
+                     # Two concurrent operations may attempt to insert tags;
+                     # apply retries here.
+                     if attempt_number > max_retries:
+                         raise MlflowException(
+                             "Failed to set tags within {} retries. Keys: {}".format(
+                                 max_retries, [t.key for t in tags]
+                             )
+                         )
+                     sleep_duration = (2**attempt_number) - 1
+                     sleep_duration += random.uniform(0, 1)
+                     time.sleep(sleep_duration)
+                     _try_insert_tags(attempt_number + 1, max_retries=max_retries)
+
+             _try_insert_tags(attempt_number=0, max_retries=3)
+
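+     # Backoff sketch: retry ``n`` sleeps ``(2**n - 1) + uniform(0, 1)`` seconds, i.e.
+     # roughly 0-1s, 1-2s, 3-4s, and 7-8s across the four attempts allowed by
+     # ``max_retries=3``.
+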
+     def delete_tag(self, run_id, key):
+         """
+         Delete a tag from a run. This is irreversible.
+
+         Args:
+             run_id: String ID of the run
+             key: Name of the tag
+         """
+         with self.ManagedSessionMaker() as session:
+             run = self._get_run(run_uuid=run_id, session=session)
+             self._check_run_is_active(run)
+             filtered_tags = session.query(SqlTag).filter_by(run_uuid=run_id, key=key).all()
+             if len(filtered_tags) == 0:
+                 raise MlflowException(
+                     f"No tag with name: {key} in run with id {run_id}",
+                     error_code=RESOURCE_DOES_NOT_EXIST,
+                 )
+             elif len(filtered_tags) > 1:
+                 raise MlflowException(
+                     "Bad data in database - tags for a specific run must have "
+                     "a single unique value. "
+                     "See https://mlflow.org/docs/latest/tracking.html#adding-tags-to-runs",
+                     error_code=INVALID_STATE,
+                 )
+             session.delete(filtered_tags[0])
+
+     def _search_runs(
+         self,
+         experiment_ids,
+         filter_string,
+         run_view_type,
+         max_results,
+         order_by,
+         page_token,
+     ):
+         def compute_next_token(current_size):
+             next_token = None
+             if max_results == current_size:
+                 final_offset = offset + max_results
+                 next_token = SearchUtils.create_page_token(final_offset)
+
+             return next_token
+
+         self._validate_max_results_param(max_results, allow_null=True)
+
+         stages = set(LifecycleStage.view_type_to_stages(run_view_type))
+
+         with self.ManagedSessionMaker() as session:
+             # Fetch the appropriate runs and eagerly load their summary metrics, params, and
+             # tags. These run attributes are referenced during the invocation of
+             # ``run.to_mlflow_entity()``, so eager loading helps avoid additional database queries
+             # that are otherwise executed at attribute access time under a lazy loading model.
+             parsed_filters = SearchUtils.parse_search_filter(filter_string)
+             cases_orderby, parsed_orderby, sorting_joins = _get_orderby_clauses(order_by, session)
+
+             stmt = select(SqlRun, *cases_orderby)
+             (
+                 attribute_filters,
+                 non_attribute_filters,
+                 dataset_filters,
+             ) = _get_sqlalchemy_filter_clauses(parsed_filters, session, self._get_dialect())
+             for non_attr_filter in non_attribute_filters:
+                 stmt = stmt.join(non_attr_filter)
+             for idx, dataset_filter in enumerate(dataset_filters):
+                 # need to reference the anon table in the join condition
+                 anon_table_name = f"anon_{idx + 1}"
+                 stmt = stmt.join(
+                     dataset_filter,
+                     text(f"runs.run_uuid = {anon_table_name}.destination_id"),
+                 )
+             # Using an outer join is necessary here because we want to be able to sort
+             # on a column (tag, metric, or param) without removing the rows that
+             # do not have a value for this column (which is what an inner join would do)
+             for j in sorting_joins:
+                 stmt = stmt.outerjoin(j)
+
+             offset = SearchUtils.parse_start_offset_from_page_token(page_token)
+             stmt = (
+                 stmt.distinct()
+                 .options(*self._get_eager_run_query_options())
+                 .filter(
+                     SqlRun.experiment_id.in_(experiment_ids),
+                     SqlRun.lifecycle_stage.in_(stages),
+                     *attribute_filters,
+                 )
+                 .order_by(*parsed_orderby)
+                 .offset(offset)
+                 .limit(max_results)
+             )
+             queried_runs = session.execute(stmt).scalars(SqlRun).all()
+
+             runs = [run.to_mlflow_entity() for run in queried_runs]
+             run_ids = [run.info.run_id for run in runs]
+
+             # add inputs to runs
+             inputs = self._get_run_inputs(run_uuids=run_ids, session=session)
+             runs_with_inputs = []
+             for i, run in enumerate(runs):
+                 runs_with_inputs.append(
+                     Run(run.info, run.data, RunInputs(dataset_inputs=inputs[i]))
+                 )
+
+             next_page_token = compute_next_token(len(runs_with_inputs))
+
+         return runs_with_inputs, next_page_token
+
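+     # Paging sketch for ``_search_runs``: the next token encodes ``offset + max_results``
+     # and is only produced when a full page was returned (illustrative call):
+     #
+     #     runs, token = store._search_runs(
+     #         ["0"], "metrics.loss < 1", ViewType.ACTIVE_ONLY, 50, None, None
+     #     )
+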
+     def log_batch(self, run_id, metrics, params, tags):
+         _validate_run_id(run_id)
+         metrics, params, tags = _validate_batch_log_data(metrics, params, tags)
+         _validate_batch_log_limits(metrics, params, tags)
+         _validate_param_keys_unique(params)
+
+         with self.ManagedSessionMaker() as session:
+             run = self._get_run(run_uuid=run_id, session=session)
+             self._check_run_is_active(run)
+             try:
+                 self._log_params(run_id, params)
+                 self._log_metrics(run_id, metrics)
+                 self._log_model_metrics(run_id, metrics)
+                 self._set_tags(run_id, tags)
+             except MlflowException as e:
+                 raise e
+             except Exception as e:
+                 raise MlflowException(e, INTERNAL_ERROR)
+
+     def record_logged_model(self, run_id, mlflow_model):
+         from mlflow.models import Model
+
+         if not isinstance(mlflow_model, Model):
+             raise TypeError(
+                 f"Argument 'mlflow_model' should be mlflow.models.Model, got '{type(mlflow_model)}'"
+             )
+         model_dict = mlflow_model.get_tags_dict()
+         with self.ManagedSessionMaker() as session:
+             run = self._get_run(run_uuid=run_id, session=session)
+             self._check_run_is_active(run)
+             previous_tag = [t for t in run.tags if t.key == MLFLOW_LOGGED_MODELS]
+             if previous_tag:
+                 value = json.dumps(json.loads(previous_tag[0].value) + [model_dict])
+             else:
+                 value = json.dumps([model_dict])
+             _validate_tag(MLFLOW_LOGGED_MODELS, value)
+             session.merge(SqlTag(key=MLFLOW_LOGGED_MODELS, value=value, run_uuid=run_id))
+
+     def log_inputs(
+         self,
+         run_id: str,
+         datasets: Optional[list[DatasetInput]] = None,
+         models: Optional[list[LoggedModelInput]] = None,
+     ):
+         """
+         Log inputs, such as datasets, to the specified run.
+
+         Args:
+             run_id: String id for the run
+             datasets: List of :py:class:`mlflow.entities.DatasetInput` instances to log
+                 as inputs to the run.
+             models: List of :py:class:`mlflow.entities.LoggedModelInput` instances to log
+                 as inputs to the run.
+
+         Returns:
+             None.
+         """
+         _validate_run_id(run_id)
+         if datasets is not None:
+             if not isinstance(datasets, list):
+                 raise TypeError(f"Argument 'datasets' should be a list, got '{type(datasets)}'")
+             _validate_dataset_inputs(datasets)
+
+         with self.ManagedSessionMaker() as session:
+             run = self._get_run(run_uuid=run_id, session=session)
+             experiment_id = run.experiment_id
+             self._check_run_is_active(run)
+             try:
+                 self._log_inputs_impl(experiment_id, run_id, datasets, models)
+             except MlflowException as e:
+                 raise e
+             except Exception as e:
+                 raise MlflowException(e, INTERNAL_ERROR)
+
+     def _log_inputs_impl(
+         self,
+         experiment_id,
+         run_id,
+         dataset_inputs: Optional[list[DatasetInput]] = None,
+         models: Optional[list[LoggedModelInput]] = None,
+     ):
+         dataset_inputs = dataset_inputs or []
+         for dataset_input in dataset_inputs:
+             if dataset_input.dataset is None:
+                 raise MlflowException(
+                     "Dataset input must have a dataset associated with it.",
+                     INTERNAL_ERROR,
+                 )
+
+         # dedup the dataset_inputs list if two dataset inputs have the same name and digest,
+         # keeping the first occurrence
+         name_digest_keys = {}
+         for dataset_input in dataset_inputs:
+             key = (dataset_input.dataset.name, dataset_input.dataset.digest)
+             if key not in name_digest_keys:
+                 name_digest_keys[key] = dataset_input
+         dataset_inputs = list(name_digest_keys.values())
+
+         with self.ManagedSessionMaker() as session:
+             dataset_names_to_check = [
+                 dataset_input.dataset.name for dataset_input in dataset_inputs
+             ]
+             dataset_digests_to_check = [
+                 dataset_input.dataset.digest for dataset_input in dataset_inputs
+             ]
+             # find all datasets with the same name and digest;
+             # if the dataset already exists, use the existing dataset uuid
+             existing_datasets = (
+                 session.query(SqlDataset)
+                 .filter(SqlDataset.name.in_(dataset_names_to_check))
+                 .filter(SqlDataset.digest.in_(dataset_digests_to_check))
+                 .all()
+             )
+             dataset_uuids = {}
+             for existing_dataset in existing_datasets:
+                 dataset_uuids[(existing_dataset.name, existing_dataset.digest)] = (
+                     existing_dataset.dataset_uuid
+                 )
+
+             # collect all objects to write to DB in a single list
+             objs_to_write = []
+
+             # add datasets to objs_to_write
+             for dataset_input in dataset_inputs:
+                 if (
+                     dataset_input.dataset.name,
+                     dataset_input.dataset.digest,
+                 ) not in dataset_uuids:
+                     new_dataset_uuid = uuid.uuid4().hex
+                     dataset_uuids[(dataset_input.dataset.name, dataset_input.dataset.digest)] = (
+                         new_dataset_uuid
+                     )
+                     objs_to_write.append(
+                         SqlDataset(
+                             dataset_uuid=new_dataset_uuid,
+                             experiment_id=experiment_id,
+                             name=dataset_input.dataset.name,
+                             digest=dataset_input.dataset.digest,
+                             dataset_source_type=dataset_input.dataset.source_type,
+                             dataset_source=dataset_input.dataset.source,
+                             dataset_schema=dataset_input.dataset.schema,
+                             dataset_profile=dataset_input.dataset.profile,
+                         )
+                     )
+
+             # find all inputs with the same source_id and destination_id;
+             # if the input already exists, use the existing input uuid
+             existing_inputs = (
+                 session.query(SqlInput)
+                 .filter(SqlInput.source_type == "DATASET")
+                 .filter(SqlInput.source_id.in_(dataset_uuids.values()))
+                 .filter(SqlInput.destination_type == "RUN")
+                 .filter(SqlInput.destination_id == run_id)
+                 .all()
+             )
+             input_uuids = {}
+             for existing_input in existing_inputs:
+                 input_uuids[(existing_input.source_id, existing_input.destination_id)] = (
+                     existing_input.input_uuid
+                 )
+
+             # add input edges to objs_to_write
+             for dataset_input in dataset_inputs:
+                 dataset_uuid = dataset_uuids[
+                     (dataset_input.dataset.name, dataset_input.dataset.digest)
+                 ]
+                 if (dataset_uuid, run_id) not in input_uuids:
+                     new_input_uuid = uuid.uuid4().hex
+                     # NB: keyed by (dataset_uuid, run_id) to match the membership check above
+                     input_uuids[(dataset_uuid, run_id)] = new_input_uuid
+                     objs_to_write.append(
+                         SqlInput(
+                             input_uuid=new_input_uuid,
+                             source_type="DATASET",
+                             source_id=dataset_uuid,
+                             destination_type="RUN",
+                             destination_id=run_id,
+                         )
+                     )
+                     # add input tags to objs_to_write
+                     for input_tag in dataset_input.tags:
+                         objs_to_write.append(
+                             SqlInputTag(
+                                 input_uuid=new_input_uuid,
+                                 name=input_tag.key,
+                                 value=input_tag.value,
+                             )
+                         )
+
+             if models:
+                 for model in models:
+                     session.merge(
+                         SqlInput(
+                             input_uuid=uuid.uuid4().hex,
+                             source_type="RUN_INPUT",
+                             source_id=run_id,
+                             destination_type="MODEL_INPUT",
+                             destination_id=model.model_id,
+                         )
+                     )
+
+             session.add_all(objs_to_write)
+
+     def log_outputs(self, run_id: str, models: list[LoggedModelOutput]):
+         with self.ManagedSessionMaker() as session:
+             run = self._get_run(run_uuid=run_id, session=session)
+             self._check_run_is_active(run)
+             session.add_all(
+                 SqlInput(
+                     input_uuid=uuid.uuid4().hex,
+                     source_type="RUN_OUTPUT",
+                     source_id=run_id,
+                     destination_type="MODEL_OUTPUT",
+                     destination_id=model.model_id,
+                     step=model.step,
+                 )
+                 for model in models
+             )
+
+     def _get_model_inputs(
+         self,
+         run_id: str,
+         session: sqlalchemy.orm.Session,
+     ) -> list[LoggedModelInput]:
+         return [
+             LoggedModelInput(model_id=input.destination_id)
+             for input in (
+                 session.query(SqlInput)
+                 .filter(
+                     SqlInput.source_type == "RUN_INPUT",
+                     SqlInput.source_id == run_id,
+                     SqlInput.destination_type == "MODEL_INPUT",
+                 )
+                 .all()
+             )
+         ]
+
+     def _get_model_outputs(
+         self,
+         run_id: str,
+         session: sqlalchemy.orm.Session,
+     ) -> list[LoggedModelOutput]:
+         return [
+             LoggedModelOutput(model_id=output.destination_id, step=output.step)
+             for output in session.query(SqlInput)
+             .filter(
+                 SqlInput.source_type == "RUN_OUTPUT",
+                 SqlInput.source_id == run_id,
+                 SqlInput.destination_type == "MODEL_OUTPUT",
+             )
+             .all()
+         ]
+
+     #######################################################################################
+     # Logged models
+     #######################################################################################
+     def create_logged_model(
+         self,
+         experiment_id: str,
+         name: Optional[str] = None,
+         source_run_id: Optional[str] = None,
+         tags: Optional[list[LoggedModelTag]] = None,
+         params: Optional[list[LoggedModelParameter]] = None,
+         model_type: Optional[str] = None,
+     ) -> LoggedModel:
+         _validate_logged_model_name(name)
+         with self.ManagedSessionMaker() as session:
+             experiment = self.get_experiment(experiment_id)
+             self._check_experiment_is_active(experiment)
+             model_id = f"m-{str(uuid.uuid4()).replace('-', '')}"
+             artifact_location = append_to_uri_path(
+                 experiment.artifact_location,
+                 SqlAlchemyStore.MODELS_FOLDER_NAME,
+                 model_id,
+                 SqlAlchemyStore.ARTIFACTS_FOLDER_NAME,
+             )
+             name = name or _generate_random_name()
+             creation_timestamp = get_current_time_millis()
+             logged_model = SqlLoggedModel(
+                 model_id=model_id,
+                 experiment_id=experiment_id,
+                 name=name,
+                 artifact_location=artifact_location,
+                 creation_timestamp_ms=creation_timestamp,
+                 last_updated_timestamp_ms=creation_timestamp,
+                 model_type=model_type,
+                 status=LoggedModelStatus.PENDING.to_int(),
+                 lifecycle_stage=LifecycleStage.ACTIVE,
+                 source_run_id=source_run_id,
+             )
+             session.add(logged_model)
+
+             if params:
+                 session.add_all(
+                     SqlLoggedModelParam(
+                         model_id=logged_model.model_id,
+                         experiment_id=experiment_id,
+                         param_key=param.key,
+                         param_value=param.value,
+                     )
+                     for param in params
+                 )
+
+             if tags:
+                 session.add_all(
+                     SqlLoggedModelTag(
+                         model_id=logged_model.model_id,
+                         experiment_id=experiment_id,
+                         tag_key=tag.key,
+                         tag_value=tag.value,
+                     )
+                     for tag in tags
+                 )
+
+             session.commit()
+             return logged_model.to_mlflow_entity()
+
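+     # Layout sketch: a new model id has the form ``m-<32 hex chars>``, and, assuming the
+     # MODELS_FOLDER_NAME and ARTIFACTS_FOLDER_NAME constants are "models" and "artifacts",
+     # its artifact root resolves to
+     # ``<experiment artifact location>/models/m-<32 hex chars>/artifacts``.
+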
+     def log_logged_model_params(self, model_id: str, params: list[LoggedModelParameter]):
+         with self.ManagedSessionMaker() as session:
+             logged_model = session.query(SqlLoggedModel).get(model_id)
+             if not logged_model:
+                 self._raise_model_not_found(model_id)
+
+             session.add_all(
+                 SqlLoggedModelParam(
+                     model_id=model_id,
+                     experiment_id=logged_model.experiment_id,
+                     param_key=param.key,
+                     param_value=param.value,
+                 )
+                 for param in params
+             )
+
+     def _raise_model_not_found(self, model_id: str):
+         raise MlflowException(
+             f"Logged model with ID '{model_id}' not found.",
+             RESOURCE_DOES_NOT_EXIST,
+         )
+
+     def get_logged_model(self, model_id: str) -> LoggedModel:
+         with self.ManagedSessionMaker() as session:
+             logged_model = (
+                 session.query(SqlLoggedModel)
+                 .filter(
+                     SqlLoggedModel.model_id == model_id,
+                     SqlLoggedModel.lifecycle_stage != LifecycleStage.DELETED,
+                 )
+                 .first()
+             )
+             if not logged_model:
+                 self._raise_model_not_found(model_id)
+
+             return logged_model.to_mlflow_entity()
+
+     def delete_logged_model(self, model_id):
+         with self.ManagedSessionMaker() as session:
+             logged_model = session.query(SqlLoggedModel).get(model_id)
+             if not logged_model:
+                 self._raise_model_not_found(model_id)
+
+             logged_model.lifecycle_stage = LifecycleStage.DELETED
+             logged_model.last_updated_timestamp_ms = get_current_time_millis()
+             session.commit()
+
+     def finalize_logged_model(self, model_id: str, status: LoggedModelStatus) -> LoggedModel:
+         with self.ManagedSessionMaker() as session:
+             logged_model = session.query(SqlLoggedModel).get(model_id)
+             if not logged_model:
+                 self._raise_model_not_found(model_id)
+
+             logged_model.status = status.to_int()
+             logged_model.last_updated_timestamp_ms = get_current_time_millis()
+             session.commit()
+             return logged_model.to_mlflow_entity()
+
+     def set_logged_model_tags(self, model_id: str, tags: list[LoggedModelTag]) -> None:
+         with self.ManagedSessionMaker() as session:
+             logged_model = session.query(SqlLoggedModel).get(model_id)
+             if not logged_model:
+                 self._raise_model_not_found(model_id)
+
+             # TODO: Consider upserting tags in a single transaction for performance
+             for tag in tags:
+                 session.merge(
+                     SqlLoggedModelTag(
+                         model_id=model_id,
+                         experiment_id=logged_model.experiment_id,
+                         tag_key=tag.key,
+                         tag_value=tag.value,
+                     )
+                 )
+
+     def delete_logged_model_tag(self, model_id: str, key: str) -> None:
+         with self.ManagedSessionMaker() as session:
+             logged_model = session.query(SqlLoggedModel).get(model_id)
+             if not logged_model:
+                 self._raise_model_not_found(model_id)
+
+             count = (
+                 session.query(SqlLoggedModelTag)
+                 .filter(
+                     SqlLoggedModelTag.model_id == model_id,
+                     SqlLoggedModelTag.tag_key == key,
+                 )
+                 .delete()
+             )
+             if count == 0:
+                 raise MlflowException(
+                     f"No tag with key {key!r} found for model with ID {model_id!r}.",
+                     RESOURCE_DOES_NOT_EXIST,
+                 )
+
+     def _apply_order_by_search_logged_models(
+         self,
+         models: sqlalchemy.orm.Query,
+         session: sqlalchemy.orm.Session,
+         order_by: Optional[list[dict[str, Any]]] = None,
+     ) -> sqlalchemy.orm.Query:
+         order_by_clauses = []
+         has_creation_timestamp = False
+         for ob in order_by or []:
+             field_name = ob.get("field_name")
+             ascending = ob.get("ascending", True)
+             if "." not in field_name:
+                 name = SqlLoggedModel.ALIASES.get(field_name, field_name)
+                 if name == "creation_timestamp_ms":
+                     has_creation_timestamp = True
+                 try:
+                     col = getattr(SqlLoggedModel, name)
+                 except AttributeError:
+                     raise MlflowException.invalid_parameter_value(
+                         f"Invalid order by field name: {field_name}"
+                     )
+                 # Why not use `nulls_last`? Because it's not supported by all dialects (e.g., MySQL)
+                 order_by_clauses.extend(
+                     [
+                         # Sort nulls last
+                         sqlalchemy.case((col.is_(None), 1), else_=0).asc(),
+                         col.asc() if ascending else col.desc(),
+                     ]
+                 )
+                 continue
+
+             entity, name = field_name.split(".", 1)
+             # TODO: Support filtering by other entities such as params if needed
+             if entity != "metrics":
+                 raise MlflowException.invalid_parameter_value(
+                     f"Invalid order by field name: {field_name}. Only metrics are supported."
+                 )
+
+             # Subquery to get the latest metric value for each (model_id, metric_name) pair
+             dataset_filter = []
+             if dataset_name := ob.get("dataset_name"):
+                 dataset_filter.append(SqlLoggedModelMetric.dataset_name == dataset_name)
+             if dataset_digest := ob.get("dataset_digest"):
+                 dataset_filter.append(SqlLoggedModelMetric.dataset_digest == dataset_digest)
+
+             subquery = (
+                 session.query(
+                     SqlLoggedModelMetric.model_id,
+                     SqlLoggedModelMetric.metric_value,
+                     func.rank()
+                     .over(
+                         partition_by=[
+                             SqlLoggedModelMetric.model_id,
+                             SqlLoggedModelMetric.metric_name,
+                         ],
+                         order_by=[
+                             SqlLoggedModelMetric.metric_timestamp_ms.desc(),
+                             SqlLoggedModelMetric.metric_step.desc(),
+                         ],
+                     )
+                     .label("rank"),
+                 )
+                 .filter(
+                     SqlLoggedModelMetric.metric_name == name,
+                     *dataset_filter,
+                 )
+                 .subquery()
+             )
+             subquery = select(subquery.c).where(subquery.c.rank == 1).subquery()
+
+             models = models.outerjoin(subquery)
+             # Why not use `nulls_last`? Because it's not supported by all dialects (e.g., MySQL)
+             order_by_clauses.extend(
+                 [
+                     # Sort nulls last
+                     sqlalchemy.case((subquery.c.metric_value.is_(None), 1), else_=0).asc(),
+                     subquery.c.metric_value.asc() if ascending else subquery.c.metric_value.desc(),
+                 ]
+             )
+
+         if not has_creation_timestamp:
+             order_by_clauses.append(SqlLoggedModel.creation_timestamp_ms.desc())
+
+         return models.order_by(*order_by_clauses)
+
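+     # Ordering sketch: each requested clause becomes a portable nulls-last pair, roughly
+     # ``ORDER BY CASE WHEN col IS NULL THEN 1 ELSE 0 END ASC, col ASC`` in SQL, with a
+     # final ``creation_timestamp_ms DESC`` tiebreaker appended unless the caller already
+     # ordered by creation timestamp.
+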
+     def _apply_filter_string_datasets_search_logged_models(
+         self,
+         models: sqlalchemy.orm.Query,
+         session: sqlalchemy.orm.Session,
+         experiment_ids: list[str],
+         filter_string: Optional[str],
+         datasets: Optional[list[dict[str, Any]]],
+     ):
+         from mlflow.utils.search_logged_model_utils import EntityType, parse_filter_string
+
+         comparisons = parse_filter_string(filter_string)
+         dialect = self._get_dialect()
+         attr_filters: list[sqlalchemy.BinaryExpression] = []
+         non_attr_filters: list[sqlalchemy.BinaryExpression] = []
+
+         dataset_filters = []
+         if datasets:
+             for dataset in datasets:
+                 dataset_filter = SqlLoggedModelMetric.dataset_name == dataset["dataset_name"]
+                 if "dataset_digest" in dataset:
+                     dataset_filter = dataset_filter & (
+                         SqlLoggedModelMetric.dataset_digest == dataset["dataset_digest"]
+                     )
+                 dataset_filters.append(dataset_filter)
+
+         has_metric_filters = False
+         for comp in comparisons:
+             comp_func = SearchUtils.get_sql_comparison_func(comp.op, dialect)
+             if comp.entity.type == EntityType.ATTRIBUTE:
+                 attr_filters.append(comp_func(getattr(SqlLoggedModel, comp.entity.key), comp.value))
+             elif comp.entity.type == EntityType.METRIC:
+                 has_metric_filters = True
+                 metric_filters = [
+                     SqlLoggedModelMetric.metric_name == comp.entity.key,
+                     comp_func(SqlLoggedModelMetric.metric_value, comp.value),
+                 ]
+                 if dataset_filters:
+                     metric_filters.append(sqlalchemy.or_(*dataset_filters))
+                 non_attr_filters.append(
+                     session.query(SqlLoggedModelMetric).filter(*metric_filters).subquery()
+                 )
+             elif comp.entity.type == EntityType.PARAM:
+                 non_attr_filters.append(
+                     session.query(SqlLoggedModelParam)
+                     .filter(
+                         SqlLoggedModelParam.param_key == comp.entity.key,
+                         comp_func(SqlLoggedModelParam.param_value, comp.value),
+                     )
+                     .subquery()
+                 )
+             elif comp.entity.type == EntityType.TAG:
+                 non_attr_filters.append(
+                     session.query(SqlLoggedModelTag)
+                     .filter(
+                         SqlLoggedModelTag.tag_key == comp.entity.key,
+                         comp_func(SqlLoggedModelTag.tag_value, comp.value),
+                     )
+                     .subquery()
+                 )
+
+         for f in non_attr_filters:
+             models = models.join(f)
+
+         # If there are dataset filters but no metric filters,
+         # filter for models that have any metrics on the datasets
+         if dataset_filters and not has_metric_filters:
+             subquery = (
+                 session.query(SqlLoggedModelMetric.model_id)
+                 .filter(sqlalchemy.or_(*dataset_filters))
+                 .distinct()
+                 .subquery()
+             )
+             models = models.join(subquery)
+
+         return models.filter(
+             SqlLoggedModel.lifecycle_stage != LifecycleStage.DELETED,
+             SqlLoggedModel.experiment_id.in_(experiment_ids),
+             *attr_filters,
+         )
+
2076
+ def search_logged_models(
2077
+ self,
2078
+ experiment_ids: list[str],
2079
+ filter_string: Optional[str] = None,
2080
+ datasets: Optional[list[DatasetFilter]] = None,
2081
+ max_results: Optional[int] = None,
2082
+ order_by: Optional[list[dict[str, Any]]] = None,
2083
+ page_token: Optional[str] = None,
2084
+ ) -> PagedList[LoggedModel]:
2085
+ if datasets and not all(d.get("dataset_name") for d in datasets):
2086
+ raise MlflowException(
2087
+ "`dataset_name` in the `datasets` clause must be specified.",
2088
+ INVALID_PARAMETER_VALUE,
2089
+ )
2090
+ if page_token:
2091
+ token = SearchLoggedModelsPaginationToken.decode(page_token)
2092
+ token.validate(experiment_ids, filter_string, order_by)
2093
+ offset = token.offset
2094
+ else:
2095
+ offset = 0
2096
+
2097
+ max_results = max_results or SEARCH_LOGGED_MODEL_MAX_RESULTS_DEFAULT
2098
+ with self.ManagedSessionMaker() as session:
2099
+ models = session.query(SqlLoggedModel)
2100
+ models = self._apply_filter_string_datasets_search_logged_models(
2101
+ models, session, experiment_ids, filter_string, datasets
2102
+ )
2103
+ models = self._apply_order_by_search_logged_models(models, session, order_by)
2104
+ models = models.offset(offset).limit(max_results + 1).all()
2105
+
2106
+ if len(models) > max_results:
2107
+ token = SearchLoggedModelsPaginationToken(
2108
+ offset=offset + max_results,
2109
+ experiment_ids=experiment_ids,
2110
+ filter_string=filter_string,
2111
+ order_by=order_by,
2112
+ ).encode()
2113
+ else:
2114
+ token = None
2115
+
2116
+ return PagedList([lm.to_mlflow_entity() for lm in models[:max_results]], token=token)
2117
+
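The pagination token stores only an offset plus the original search arguments, so callers must resend identical arguments on each page. A minimal paging loop, assuming a store constructed as in stock MLflow and a filter grammar mirroring run search:

from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore

store = SqlAlchemyStore("sqlite:///mlflow.db", "./mlruns")  # assumed setup
page_token = None
while True:
    page = store.search_logged_models(
        experiment_ids=["0"],
        filter_string="metrics.accuracy > 0.9",  # grammar assumed
        datasets=[{"dataset_name": "eval"}],     # scopes metric filters
        page_token=page_token,
    )
    for model in page:
        print(model.model_id)
    if page.token is None:
        break
    page_token = page.token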
2118
+ #######################################################################################
2119
+ # Below are Tracing APIs. We may refactor them to be in a separate class in the future.
2120
+ #######################################################################################
2121
+ def _get_trace_artifact_location_tag(self, experiment, trace_id: str) -> SqlTraceTag:
2122
+ # Trace data is stored as file artifacts regardless of the tracking backend choice.
2123
+ # We use the "/traces" subdirectory under the experiment's artifact location to isolate
2124
+ # them from run artifacts.
2125
+ artifact_uri = append_to_uri_path(
2126
+ experiment.artifact_location,
2127
+ SqlAlchemyStore.TRACE_FOLDER_NAME,
2128
+ trace_id,
2129
+ SqlAlchemyStore.ARTIFACTS_FOLDER_NAME,
2130
+ )
2131
+ return SqlTraceTag(request_id=trace_id, key=MLFLOW_ARTIFACT_LOCATION, value=artifact_uri)
2132
+
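Assuming the two folder constants are "traces" and "artifacts", the tag value is just the experiment artifact root with two extra path segments:

from mlflow.utils.uri import append_to_uri_path

uri = append_to_uri_path("s3://bucket/1", "traces", "tr-123", "artifacts")
print(uri)  # s3://bucket/1/traces/tr-123/artifacts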
2133
+ def start_trace(self, trace_info: "TraceInfo") -> TraceInfo:
2134
+ """
2135
+ Create a trace using the V3 API format from the given TraceInfo object.
2136
+
2137
+ Args:
2138
+ trace_info: The TraceInfo object to create in the backend.
2139
+
2140
+ Returns:
2141
+ The created TraceInfo object from the backend.
2142
+ """
2143
+ with self.ManagedSessionMaker() as session:
2144
+ experiment = self.get_experiment(trace_info.experiment_id)
2145
+ self._check_experiment_is_active(experiment)
2146
+
2147
+ # Use the provided trace_id
2148
+ trace_id = trace_info.trace_id
2149
+
2150
+ # Create SqlTraceInfo with V3 fields directly
2151
+ sql_trace_info = SqlTraceInfo(
2152
+ request_id=trace_id,
2153
+ experiment_id=trace_info.experiment_id,
2154
+ timestamp_ms=trace_info.request_time,
2155
+ execution_time_ms=trace_info.execution_duration,
2156
+ status=trace_info.state.value,
2157
+ client_request_id=trace_info.client_request_id,
2158
+ request_preview=trace_info.request_preview,
2159
+ response_preview=trace_info.response_preview,
2160
+ )
2161
+
2162
+ sql_trace_info.tags = [
2163
+ SqlTraceTag(request_id=trace_id, key=k, value=v) for k, v in trace_info.tags.items()
2164
+ ]
2165
+ sql_trace_info.tags.append(self._get_trace_artifact_location_tag(experiment, trace_id))
2166
+
2167
+ sql_trace_info.request_metadata = [
2168
+ SqlTraceMetadata(request_id=trace_id, key=k, value=v)
2169
+ for k, v in trace_info.trace_metadata.items()
2170
+ ]
2171
+ session.add(sql_trace_info)
2172
+ return sql_trace_info.to_mlflow_entity()
2173
+
2174
+ def get_trace_info(self, trace_id: str) -> TraceInfo:
2175
+ """
2176
+ Fetch the trace info for the given trace id.
2177
+
2178
+ Args:
2179
+ trace_id: Unique string identifier of the trace.
2180
+
2181
+ Returns:
2182
+ The TraceInfo object.
2183
+ """
2184
+ with self.ManagedSessionMaker() as session:
2185
+ sql_trace_info = self._get_sql_trace_info(session, trace_id)
2186
+ return sql_trace_info.to_mlflow_entity()
2187
+
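A small round-trip sketch for the pair above; the store handle and trace ID are assumptions:

# Assumes `store` is a SqlAlchemyStore and "tr-123" was created via start_trace.
info = store.get_trace_info("tr-123")
print(info.trace_id, info.state, info.request_time)
# An unknown ID raises MlflowException(RESOURCE_DOES_NOT_EXIST), per
# _get_sql_trace_info below.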
2188
+ def _get_sql_trace_info(self, session, trace_id) -> SqlTraceInfo:
2189
+ sql_trace_info = (
2190
+ session.query(SqlTraceInfo).filter(SqlTraceInfo.request_id == trace_id).one_or_none()
2191
+ )
2192
+ if sql_trace_info is None:
2193
+ raise MlflowException(
2194
+ f"Trace with ID '{trace_id}' not found.",
2195
+ RESOURCE_DOES_NOT_EXIST,
2196
+ )
2197
+ return sql_trace_info
2198
+
2199
+ def search_traces(
2200
+ self,
2201
+ experiment_ids: list[str],
2202
+ filter_string: Optional[str] = None,
2203
+ max_results: int = SEARCH_TRACES_DEFAULT_MAX_RESULTS,
2204
+ order_by: Optional[list[str]] = None,
2205
+ page_token: Optional[str] = None,
2206
+ model_id: Optional[str] = None,
2207
+ sql_warehouse_id: Optional[str] = None,
2208
+ ) -> tuple[list[TraceInfo], Optional[str]]:
2209
+ """
2210
+ Return traces that match the given list of search expressions within the experiments.
2211
+
2212
+ Args:
2213
+ experiment_ids: List of experiment ids to scope the search.
2214
+ filter_string: A search filter string.
2215
+ max_results: Maximum number of traces desired.
2216
+ order_by: List of order_by clauses.
2217
+ page_token: Token specifying the next page of results. It should be obtained from
2218
+ a ``search_traces`` call.
2219
+ model_id: If specified, search traces associated with the given model ID.
2220
+ sql_warehouse_id: Only used in Databricks. The ID of the SQL warehouse to use for
2221
+ searching traces in inference tables.
2222
+
2223
+ Returns:
2224
+ A tuple of a list of :py:class:`TraceInfo <mlflow.entities.TraceInfo>` objects that
2225
+ satisfy the search expressions and a pagination token for the next page of results.
2226
+ """
2227
+ self._validate_max_results_param(max_results)
2228
+
2229
+ with self.ManagedSessionMaker() as session:
2230
+ cases_orderby, parsed_orderby, sorting_joins = _get_orderby_clauses_for_search_traces(
2231
+ order_by or [], session
2232
+ )
2233
+ stmt = select(SqlTraceInfo, *cases_orderby)
2234
+
2235
+ attribute_filters, non_attribute_filters = _get_filter_clauses_for_search_traces(
2236
+ filter_string, session, self._get_dialect()
2237
+ )
2238
+ for non_attr_filter in non_attribute_filters:
2239
+ stmt = stmt.join(non_attr_filter)
2240
+
2241
+ # Using an outer join is necessary here because we want to be able to sort
2242
+ # on a column (tag, metric, or param) without dropping the rows that
2243
+ # have no value for that column (which is what an inner join would do).
2244
+ for j in sorting_joins:
2245
+ stmt = stmt.outerjoin(j)
2246
+
2247
+ offset = SearchTraceUtils.parse_start_offset_from_page_token(page_token)
2248
+ stmt = (
2249
+ # NB: We don't need to deduplicate the join results because the right-hand
2250
+ # tables of the joins are unique on the join key, trace_id.
2251
+ # This is because each subquery joined on the right side is conditioned
2252
+ # on a key/value pair of tags/metadata, and the combination of key and
2253
+ # trace_id is unique in those tables.
2254
+ # Be careful when changing the query building logic, as it may break this
2255
+ # uniqueness property and require deduplication, which can be expensive.
2256
+ stmt.filter(
2257
+ SqlTraceInfo.experiment_id.in_(experiment_ids),
2258
+ *attribute_filters,
2259
+ )
2260
+ .order_by(*parsed_orderby)
2261
+ .offset(offset)
2262
+ .limit(max_results)
2263
+ )
2264
+ queried_traces = session.execute(stmt).scalars(SqlTraceInfo).all()
2265
+ trace_infos = [t.to_mlflow_entity() for t in queried_traces]
2266
+
2267
+ # Compute next search token
2268
+ if max_results == len(trace_infos):
2269
+ final_offset = offset + max_results
2270
+ next_token = SearchTraceUtils.create_page_token(final_offset)
2271
+ else:
2272
+ next_token = None
2273
+
2274
+ return trace_infos, next_token
2275
+
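Pagination here mirrors the logged-model search: a full page (len(results) == max_results) produces a token, a short page ends the scan, so a total that is an exact multiple of max_results costs one extra empty request. A hedged usage sketch:

token = None
while True:
    infos, token = store.search_traces(      # `store` as in earlier sketches
        experiment_ids=["0"],
        filter_string="tags.env = 'prod'",   # grammar assumed
        max_results=100,
        order_by=["timestamp_ms DESC"],
        page_token=token,
    )
    for info in infos:
        print(info.trace_id)
    if token is None:
        break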
2276
+ def _validate_max_results_param(self, max_results: int, allow_null=False):
2277
+ if (not allow_null and max_results is None) or (max_results is not None and max_results < 1):
2278
+ raise MlflowException(
2279
+ f"Invalid value {max_results} for parameter 'max_results' supplied. It must be "
2280
+ f"a positive integer",
2281
+ INVALID_PARAMETER_VALUE,
2282
+ )
2283
+
2284
+ if max_results > SEARCH_MAX_RESULTS_THRESHOLD:
2285
+ raise MlflowException(
2286
+ f"Invalid value {max_results} for parameter 'max_results' supplied. It must be at "
2287
+ f"most {SEARCH_MAX_RESULTS_THRESHOLD}",
2288
+ INVALID_PARAMETER_VALUE,
2289
+ )
2290
+
2291
+ def set_trace_tag(self, trace_id: str, key: str, value: str):
2292
+ """
2293
+ Set a tag on the trace with the given trace_id.
2294
+
2295
+ Args:
2296
+ trace_id: The ID of the trace.
2297
+ key: The string key of the tag.
2298
+ value: The string value of the tag.
2299
+ """
2300
+ with self.ManagedSessionMaker() as session:
2301
+ key, value = _validate_trace_tag(key, value)
2302
+ session.merge(SqlTraceTag(request_id=trace_id, key=key, value=value))
2303
+
2304
+ def delete_trace_tag(self, trace_id: str, key: str):
2305
+ """
2306
+ Delete a tag on the trace with the given trace_id.
2307
+
2308
+ Args:
2309
+ trace_id: The ID of the trace.
2310
+ key: The string key of the tag.
2311
+ """
2312
+ with self.ManagedSessionMaker() as session:
2313
+ tags = session.query(SqlTraceTag).filter_by(request_id=trace_id, key=key)
2314
+ if tags.count() == 0:
2315
+ raise MlflowException(
2316
+ f"No trace tag with key '{key}' for trace with ID '{trace_id}'",
2317
+ RESOURCE_DOES_NOT_EXIST,
2318
+ )
2319
+ tags.delete()
2320
+
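Usage of the tag pair above (IDs hypothetical); note that set_trace_tag upserts via session.merge, while delete_trace_tag raises if the key is absent:

store.set_trace_tag("tr-123", "reviewed", "true")
store.delete_trace_tag("tr-123", "reviewed")  # MlflowException if missing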
2321
+ def _delete_traces(
2322
+ self,
2323
+ experiment_id: str,
2324
+ max_timestamp_millis: Optional[int] = None,
2325
+ max_traces: Optional[int] = None,
2326
+ trace_ids: Optional[list[str]] = None,
2327
+ ) -> int:
2328
+ """
2329
+ Delete traces based on the specified criteria.
2330
+
2331
+ Args:
2332
+ experiment_id: ID of the associated experiment.
2333
+ max_timestamp_millis: The maximum timestamp in milliseconds since the UNIX epoch for
2334
+ deleting traces. Traces older than or equal to this timestamp will be deleted.
2335
+ max_traces: The maximum number of traces to delete.
2336
+ trace_ids: A list of trace IDs to delete.
2337
+
2338
+ Returns:
2339
+ The number of traces deleted.
2340
+ """
2341
+ with self.ManagedSessionMaker() as session:
2342
+ filters = [SqlTraceInfo.experiment_id == experiment_id]
2343
+ if max_timestamp_millis:
2344
+ filters.append(SqlTraceInfo.timestamp_ms <= max_timestamp_millis)
2345
+ if trace_ids:
2346
+ filters.append(SqlTraceInfo.request_id.in_(trace_ids))
2347
+ if max_traces:
2348
+ filters.append(
2349
+ SqlTraceInfo.request_id.in_(
2350
+ session.query(SqlTraceInfo.request_id)
2351
+ .filter(*filters)
2352
+ # Delete the oldest traces first
2353
+ .order_by(SqlTraceInfo.timestamp_ms)
2354
+ .limit(max_traces)
2355
+ .subquery()
2356
+ )
2357
+ )
2358
+
2359
+ return (
2360
+ session.query(SqlTraceInfo)
2361
+ .filter(and_(*filters))
2362
+ .delete(synchronize_session="fetch")
2363
+ )
2364
+
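The criteria compose: the timestamp and ID filters narrow the candidates, then max_traces keeps only the oldest N of them. A hedged retention sketch:

import time

# Delete up to 500 of the oldest traces in experiment "0" older than 30 days.
cutoff_ms = int(time.time() * 1000) - 30 * 24 * 60 * 60 * 1000
deleted = store._delete_traces(
    experiment_id="0",
    max_timestamp_millis=cutoff_ms,
    max_traces=500,
)
print(f"deleted {deleted} traces")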
2365
+ #######################################################################################
2366
+ # Below are legacy V2 Tracing APIs. DO NOT USE. Use the V3 APIs instead.
2367
+ #######################################################################################
2368
+ def deprecated_start_trace_v2(
2369
+ self,
2370
+ experiment_id: str,
2371
+ timestamp_ms: int,
2372
+ request_metadata: dict[str, str],
2373
+ tags: dict[str, str],
2374
+ ) -> TraceInfoV2:
2375
+ """
2376
+ DEPRECATED. DO NOT USE.
2377
+
2378
+ Create an initial TraceInfo object in the database.
2379
+
2380
+ Args:
2381
+ experiment_id: String id of the experiment for this run.
2382
+ timestamp_ms: Start time of the trace, in milliseconds since the UNIX epoch.
2383
+ request_metadata: Metadata of the trace.
2384
+ tags: Tags of the trace.
2385
+
2386
+ Returns:
2387
+ The created TraceInfo object.
2388
+ """
2389
+ with self.ManagedSessionMaker() as session:
2390
+ experiment = self.get_experiment(experiment_id)
2391
+ self._check_experiment_is_active(experiment)
2392
+
2393
+ request_id = generate_request_id_v2()
2394
+ trace_info = SqlTraceInfo(
2395
+ request_id=request_id,
2396
+ experiment_id=experiment_id,
2397
+ timestamp_ms=timestamp_ms,
2398
+ execution_time_ms=None,
2399
+ status=TraceStatus.IN_PROGRESS,
2400
+ )
2401
+
2402
+ trace_info.tags = [SqlTraceTag(key=k, value=v) for k, v in tags.items()]
2403
+ trace_info.tags.append(self._get_trace_artifact_location_tag(experiment, request_id))
2404
+
2405
+ trace_info.request_metadata = [
2406
+ SqlTraceMetadata(key=k, value=v) for k, v in request_metadata.items()
2407
+ ]
2408
+ session.add(trace_info)
2409
+
2410
+ return TraceInfoV2.from_v3(trace_info.to_mlflow_entity())
2411
+
2412
+ def deprecated_end_trace_v2(
2413
+ self,
2414
+ request_id: str,
2415
+ timestamp_ms: int,
2416
+ status: TraceStatus,
2417
+ request_metadata: dict[str, str],
2418
+ tags: dict[str, str],
2419
+ ) -> TraceInfoV2:
2420
+ """
2421
+ DEPRECATED. DO NOT USE.
2422
+
2423
+ Update the TraceInfo object in the database with the completed trace info.
2424
+
2425
+ Args:
2426
+ request_id: Unique string identifier of the trace.
2427
+ timestamp_ms: End time of the trace, in milliseconds. The execution time field
2428
+ in the TraceInfo will be calculated by subtracting the start time from this.
2429
+ status: Status of the trace.
2430
+ request_metadata: Metadata of the trace. This will be merged with the existing
2431
+ metadata logged during the start_trace call.
2432
+ tags: Tags of the trace. This will be merged with the existing tags logged
2433
+ during the start_trace or set_trace_tag calls.
2434
+
2435
+ Returns:
2436
+ The updated TraceInfo object.
2437
+ """
2438
+ with self.ManagedSessionMaker() as session:
2439
+ sql_trace_info = self._get_sql_trace_info(session, request_id)
2440
+ trace_start_time_ms = sql_trace_info.timestamp_ms
2441
+ execution_time_ms = timestamp_ms - trace_start_time_ms
2442
+ sql_trace_info.execution_time_ms = execution_time_ms
2443
+ sql_trace_info.status = status
2444
+ session.merge(sql_trace_info)
2445
+ for k, v in request_metadata.items():
2446
+ session.merge(SqlTraceMetadata(request_id=request_id, key=k, value=v))
2447
+ for k, v in tags.items():
2448
+ session.merge(SqlTraceTag(request_id=request_id, key=k, value=v))
2449
+ return TraceInfoV2.from_v3(sql_trace_info.to_mlflow_entity())
2450
+
2451
+
2452
+ def _get_sqlalchemy_filter_clauses(parsed, session, dialect):
2453
+ """
2454
+ Creates run attribute filters and subqueries that will be inner-joined to SqlRun to act as
2455
+ multi-clause filters and return them as a tuple.
2456
+ """
2457
+ attribute_filters = []
2458
+ non_attribute_filters = []
2459
+ dataset_filters = []
2460
+
2461
+ for sql_statement in parsed:
2462
+ key_type = sql_statement.get("type")
2463
+ key_name = sql_statement.get("key")
2464
+ value = sql_statement.get("value")
2465
+ comparator = sql_statement.get("comparator").upper()
2466
+
2467
+ key_name = SearchUtils.translate_key_alias(key_name)
2468
+
2469
+ if SearchUtils.is_string_attribute(
2470
+ key_type, key_name, comparator
2471
+ ) or SearchUtils.is_numeric_attribute(key_type, key_name, comparator):
2472
+ if key_name == "run_name":
2473
+ # Treat "attributes.run_name == <value>" as "tags.`mlflow.runName` == <value>".
2474
+ # The name column in the runs table is empty for runs logged in MLflow <= 1.29.0.
2475
+ key_filter = SearchUtils.get_sql_comparison_func("=", dialect)(
2476
+ SqlTag.key, MLFLOW_RUN_NAME
2477
+ )
2478
+ val_filter = SearchUtils.get_sql_comparison_func(comparator, dialect)(
2479
+ SqlTag.value, value
2480
+ )
2481
+ non_attribute_filters.append(
2482
+ session.query(SqlTag).filter(key_filter, val_filter).subquery()
2483
+ )
2484
+ else:
2485
+ attribute = getattr(SqlRun, SqlRun.get_attribute_name(key_name))
2486
+ attr_filter = SearchUtils.get_sql_comparison_func(comparator, dialect)(
2487
+ attribute, value
2488
+ )
2489
+ attribute_filters.append(attr_filter)
2490
+ else:
2491
+ if SearchUtils.is_metric(key_type, comparator):
2492
+ entity = SqlLatestMetric
2493
+ value = float(value)
2494
+ elif SearchUtils.is_param(key_type, comparator):
2495
+ entity = SqlParam
2496
+ elif SearchUtils.is_tag(key_type, comparator):
2497
+ entity = SqlTag
2498
+ elif SearchUtils.is_dataset(key_type, comparator):
2499
+ entity = SqlDataset
2500
+ else:
2501
+ raise MlflowException(
2502
+ f"Invalid search expression type '{key_type}'",
2503
+ error_code=INVALID_PARAMETER_VALUE,
2504
+ )
2505
+
2506
+ if entity == SqlDataset:
2507
+ if key_name == "context":
2508
+ dataset_filters.append(
2509
+ session.query(entity, SqlInput, SqlInputTag)
2510
+ .join(SqlInput, SqlInput.source_id == SqlDataset.dataset_uuid)
2511
+ .join(
2512
+ SqlInputTag,
2513
+ and_(
2514
+ SqlInputTag.input_uuid == SqlInput.input_uuid,
2515
+ SqlInputTag.name == MLFLOW_DATASET_CONTEXT,
2516
+ SearchUtils.get_sql_comparison_func(comparator, dialect)(
2517
+ getattr(SqlInputTag, "value"), value
2518
+ ),
2519
+ ),
2520
+ )
2521
+ .subquery()
2522
+ )
2523
+ else:
2524
+ dataset_attr_filter = SearchUtils.get_sql_comparison_func(comparator, dialect)(
2525
+ getattr(SqlDataset, key_name), value
2526
+ )
2527
+ dataset_filters.append(
2528
+ session.query(entity, SqlInput)
2529
+ .join(SqlInput, SqlInput.source_id == SqlDataset.dataset_uuid)
2530
+ .filter(dataset_attr_filter)
2531
+ .subquery()
2532
+ )
2533
+ else:
2534
+ key_filter = SearchUtils.get_sql_comparison_func("=", dialect)(entity.key, key_name)
2535
+ val_filter = SearchUtils.get_sql_comparison_func(comparator, dialect)(
2536
+ entity.value, value
2537
+ )
2538
+ non_attribute_filters.append(
2539
+ session.query(entity).filter(key_filter, val_filter).subquery()
2540
+ )
2541
+
2542
+ return attribute_filters, non_attribute_filters, dataset_filters
2543
+
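The parsed statements are plain dicts; the shape below is inferred from the .get() calls above, with hypothetical values:

parsed = [
    {"type": "metric", "key": "accuracy", "comparator": ">=", "value": "0.9"},
    {"type": "attribute", "key": "run_name", "comparator": "=", "value": "best"},
]
# The first entry becomes a SqlLatestMetric subquery (value cast to float);
# the second is rewritten to a SqlTag subquery on mlflow.runName, per the
# run_name special case above.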
2544
+
2545
+ def _get_orderby_clauses(order_by_list, session):
2546
+ """Sorts a set of runs based on their natural ordering and an overriding set of order_bys.
2547
+ Runs are naturally ordered first by start time descending, then by run id for tie-breaking.
2548
+ """
2549
+
2550
+ clauses = []
2551
+ ordering_joins = []
2552
+ clause_id = 0
2553
+ observed_order_by_clauses = set()
2554
+ select_clauses = []
2555
+ # Unlike filters, sorting on attributes and on joined tables cannot easily
2556
+ # be handled separately, as all clauses must be kept in the same order.
2557
+ if order_by_list:
2558
+ for order_by_clause in order_by_list:
2559
+ clause_id += 1
2560
+ (key_type, key, ascending) = SearchUtils.parse_order_by_for_search_runs(order_by_clause)
2561
+ key = SearchUtils.translate_key_alias(key)
2562
+ if SearchUtils.is_string_attribute(
2563
+ key_type, key, "="
2564
+ ) or SearchUtils.is_numeric_attribute(key_type, key, "="):
2565
+ order_value = getattr(SqlRun, SqlRun.get_attribute_name(key))
2566
+ else:
2567
+ if SearchUtils.is_metric(key_type, "="): # any valid comparator
2568
+ entity = SqlLatestMetric
2569
+ elif SearchUtils.is_tag(key_type, "="):
2570
+ entity = SqlTag
2571
+ elif SearchUtils.is_param(key_type, "="):
2572
+ entity = SqlParam
2573
+ else:
2574
+ raise MlflowException(
2575
+ f"Invalid identifier type '{key_type}'",
2576
+ error_code=INVALID_PARAMETER_VALUE,
2577
+ )
2578
+
2579
+ # build a subquery first because we will join it in the main request so that the
2580
+ # metric we want to sort on is available when we apply the sorting clause
2581
+ subquery = session.query(entity).filter(entity.key == key).subquery()
2582
+
2583
+ ordering_joins.append(subquery)
2584
+ order_value = subquery.c.value
2585
+
2586
+ # MySQL does not support NULLS LAST expression, so we sort first by
2587
+ # presence of the field (and is_nan for metrics), then by actual value
2588
+ # As the subqueries are created independently and used later in the
2589
+ # same main query, the CASE WHEN columns need to have unique names to
2590
+ # avoid ambiguity
2591
+ if SearchUtils.is_metric(key_type, "="):
2592
+ case = sql.case(
2593
+ # Ideally the use of "IS" is preferred here but owing to sqlalchemy
2594
+ # translation in MSSQL we are forced to use "=" instead.
2595
+ # These 2 options are functionally identical / unchanged because
2596
+ # the column (is_nan) is not nullable. However it could become an issue
2597
+ # if this precondition changes in the future.
2598
+ (subquery.c.is_nan == sqlalchemy.true(), 1),
2599
+ (order_value.is_(None), 2),
2600
+ else_=0,
2601
+ ).label(f"clause_{clause_id}")
2602
+
2603
+ else: # other entities do not have an 'is_nan' field
2604
+ case = sql.case((order_value.is_(None), 1), else_=0).label(f"clause_{clause_id}")
2605
+ clauses.append(case.name)
2606
+ select_clauses.append(case)
2607
+ select_clauses.append(order_value)
2608
+
2609
+ if (key_type, key) in observed_order_by_clauses:
2610
+ raise MlflowException(f"`order_by` contains duplicate fields: {order_by_list}")
2611
+ observed_order_by_clauses.add((key_type, key))
2612
+
2613
+ if ascending:
2614
+ clauses.append(order_value)
2615
+ else:
2616
+ clauses.append(order_value.desc())
2617
+
2618
+ if (
2619
+ SearchUtils._ATTRIBUTE_IDENTIFIER,
2620
+ SqlRun.start_time.key,
2621
+ ) not in observed_order_by_clauses:
2622
+ clauses.append(SqlRun.start_time.desc())
2623
+ clauses.append(SqlRun.run_uuid)
2624
+ return select_clauses, clauses, ordering_joins
2625
+
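A standalone sketch of the portable NULLS LAST emulation described above, against a hypothetical table rather than this store's models:

import sqlalchemy
from sqlalchemy import Column, Float, Integer, MetaData, Table, select

metadata = MetaData()
metrics = Table(
    "metrics_example",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("value", Float),
)
# Rank missing values with a CASE, then sort by that rank before the value,
# so NULLs land last even on backends without NULLS LAST (e.g. MySQL).
presence = sqlalchemy.case((metrics.c.value.is_(None), 1), else_=0).label("clause_1")
stmt = select(metrics, presence).order_by(presence, metrics.c.value.desc())
print(stmt)  # ORDER BY CASE WHEN ... END, metrics_example.value DESC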
2626
+
2627
+ def _get_search_experiments_filter_clauses(parsed_filters, dialect):
2628
+ attribute_filters = []
2629
+ non_attribute_filters = []
2630
+ for f in parsed_filters:
2631
+ type_ = f["type"]
2632
+ key = f["key"]
2633
+ comparator = f["comparator"]
2634
+ value = f["value"]
2635
+ if type_ == "attribute":
2636
+ if SearchExperimentsUtils.is_string_attribute(
2637
+ type_, key, comparator
2638
+ ) and comparator not in ("=", "!=", "LIKE", "ILIKE"):
2639
+ raise MlflowException.invalid_parameter_value(
2640
+ f"Invalid comparator for string attribute: {comparator}"
2641
+ )
2642
+ if SearchExperimentsUtils.is_numeric_attribute(
2643
+ type_, key, comparator
2644
+ ) and comparator not in ("=", "!=", "<", "<=", ">", ">="):
2645
+ raise MlflowException.invalid_parameter_value(
2646
+ f"Invalid comparator for numeric attribute: {comparator}"
2647
+ )
2648
+ attr = getattr(SqlExperiment, key)
2649
+ attr_filter = SearchUtils.get_sql_comparison_func(comparator, dialect)(attr, value)
2650
+ attribute_filters.append(attr_filter)
2651
+ elif type_ == "tag":
2652
+ if comparator not in ("=", "!=", "LIKE", "ILIKE"):
2653
+ raise MlflowException.invalid_parameter_value(
2654
+ f"Invalid comparator for tag: {comparator}"
2655
+ )
2656
+ val_filter = SearchUtils.get_sql_comparison_func(comparator, dialect)(
2657
+ SqlExperimentTag.value, value
2658
+ )
2659
+ key_filter = SearchUtils.get_sql_comparison_func("=", dialect)(
2660
+ SqlExperimentTag.key, key
2661
+ )
2662
+ non_attribute_filters.append(
2663
+ select(SqlExperimentTag).filter(key_filter, val_filter).subquery()
2664
+ )
2665
+ else:
2666
+ raise MlflowException.invalid_parameter_value(f"Invalid token type: {type_}")
2667
+
2668
+ return attribute_filters, non_attribute_filters
2669
+
2670
+
2671
+ def _get_search_experiments_order_by_clauses(order_by):
2672
+ order_by_clauses = []
2673
+ for type_, key, ascending in map(
2674
+ SearchExperimentsUtils.parse_order_by_for_search_experiments,
2675
+ order_by or ["creation_time DESC", "experiment_id ASC"],
2676
+ ):
2677
+ if type_ == "attribute":
2678
+ order_by_clauses.append((getattr(SqlExperiment, key), ascending))
2679
+ else:
2680
+ raise MlflowException.invalid_parameter_value(f"Invalid order_by entity: {type_}")
2681
+
2682
+ # Add a tie-breaker
2683
+ if not any(col == SqlExperiment.experiment_id for col, _ in order_by_clauses):
2684
+ order_by_clauses.append((SqlExperiment.experiment_id, False))
2685
+
2686
+ return [col.asc() if ascending else col.desc() for col, ascending in order_by_clauses]
2687
+
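For illustration, the defaults and the tie-breaker behave as follows (resulting clause lists shown as comments):

_get_search_experiments_order_by_clauses(None)
#   -> [SqlExperiment.creation_time.desc(), SqlExperiment.experiment_id.asc()]
_get_search_experiments_order_by_clauses(["name ASC"])
#   -> [SqlExperiment.name.asc(), SqlExperiment.experiment_id.desc()]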
2688
+
2689
+ def _get_orderby_clauses_for_search_traces(order_by_list: list[str], session):
2690
+ """Sorts a set of traces based on their natural ordering and an overriding set of order_bys.
2691
+ Traces are ordered first by timestamp_ms descending, then by trace_id for tie-breaking.
2692
+ """
2693
+ clauses = []
2694
+ ordering_joins = []
2695
+ observed_order_by_clauses = set()
2696
+ select_clauses = []
2697
+
2698
+ for clause_id, order_by_clause in enumerate(order_by_list):
2699
+ (key_type, key, ascending) = SearchTraceUtils.parse_order_by_for_search_traces(
2700
+ order_by_clause
2701
+ )
2702
+
2703
+ if SearchTraceUtils.is_attribute(key_type, key, "="):
2704
+ order_value = getattr(SqlTraceInfo, key)
2705
+ else:
2706
+ if SearchTraceUtils.is_tag(key_type, "="):
2707
+ entity = SqlTraceTag
2708
+ elif SearchTraceUtils.is_request_metadata(key_type, "="):
2709
+ entity = SqlTraceMetadata
2710
+ else:
2711
+ raise MlflowException(
2712
+ f"Invalid identifier type '{key_type}'",
2713
+ error_code=INVALID_PARAMETER_VALUE,
2714
+ )
2715
+ # Tags and request metadata require a join to the main table (trace_info)
2716
+ subquery = session.query(entity).filter(entity.key == key).subquery()
2717
+ ordering_joins.append(subquery)
2718
+ order_value = subquery.c.value
2719
+
2720
+ case = sql.case((order_value.is_(None), 1), else_=0).label(f"clause_{clause_id}")
2721
+ clauses.append(case.name)
2722
+ select_clauses.append(case)
2723
+ select_clauses.append(order_value)
2724
+
2725
+ if (key_type, key) in observed_order_by_clauses:
2726
+ raise MlflowException(f"`order_by` contains duplicate fields: {order_by_list}")
2727
+ observed_order_by_clauses.add((key_type, key))
2728
+ clauses.append(order_value if ascending else order_value.desc())
2729
+
2730
+ # Add descending trace start time as default ordering and a tie-breaker
2731
+ for attr, ascending in [
2732
+ (SqlTraceInfo.timestamp_ms, False),
2733
+ (SqlTraceInfo.request_id, True),
2734
+ ]:
2735
+ if (
2736
+ SearchTraceUtils._ATTRIBUTE_IDENTIFIER,
2737
+ attr.key,
2738
+ ) not in observed_order_by_clauses:
2739
+ clauses.append(attr if ascending else attr.desc())
2740
+ return select_clauses, clauses, ordering_joins
2741
+
2742
+
2743
+ def _get_filter_clauses_for_search_traces(filter_string, session, dialect):
2744
+ """
2745
+ Creates trace attribute filters and subqueries that will be inner-joined
2746
+ to SqlTraceInfo to act as multi-clause filters and return them as a tuple.
2747
+ """
2748
+ attribute_filters = []
2749
+ non_attribute_filters = []
2750
+
2751
+ parsed_filters = SearchTraceUtils.parse_search_filter_for_search_traces(filter_string)
2752
+ for sql_statement in parsed_filters:
2753
+ key_type = sql_statement.get("type")
2754
+ key_name = sql_statement.get("key")
2755
+ value = sql_statement.get("value")
2756
+ comparator = sql_statement.get("comparator").upper()
2757
+
2758
+ if SearchTraceUtils.is_attribute(key_type, key_name, comparator):
2759
+ attribute = getattr(SqlTraceInfo, key_name)
2760
+ attr_filter = SearchTraceUtils.get_sql_comparison_func(comparator, dialect)(
2761
+ attribute, value
2762
+ )
2763
+ attribute_filters.append(attr_filter)
2764
+ else:
2765
+ if SearchTraceUtils.is_tag(key_type, comparator):
2766
+ entity = SqlTraceTag
2767
+ elif SearchTraceUtils.is_request_metadata(key_type, comparator):
2768
+ entity = SqlTraceMetadata
2769
+ else:
2770
+ raise MlflowException(
2771
+ f"Invalid search expression type '{key_type}'",
2772
+ error_code=INVALID_PARAMETER_VALUE,
2773
+ )
2774
+
2775
+ key_filter = SearchTraceUtils.get_sql_comparison_func("=", dialect)(
2776
+ entity.key, key_name
2777
+ )
2778
+ val_filter = SearchTraceUtils.get_sql_comparison_func(comparator, dialect)(
2779
+ entity.value, value
2780
+ )
2781
+ non_attribute_filters.append(
2782
+ session.query(entity).filter(key_filter, val_filter).subquery()
2783
+ )
2784
+
2785
+ return attribute_filters, non_attribute_filters
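An illustrative routing of a combined filter (grammar assumed): attribute comparisons are applied directly to SqlTraceInfo, everything else becomes a joinable subquery:

# Hypothetical filter string:
#   "attributes.status = 'OK' AND tags.env = 'prod'"
# -> attribute_filters:     [SqlTraceInfo.status == 'OK']
# -> non_attribute_filters: [subquery on SqlTraceTag with key = 'env' and
#                            value = 'prod'], inner-joined by search_traces.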