genesis-flow 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (645)
  1. genesis_flow-1.0.0.dist-info/METADATA +822 -0
  2. genesis_flow-1.0.0.dist-info/RECORD +645 -0
  3. genesis_flow-1.0.0.dist-info/WHEEL +5 -0
  4. genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
  5. genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
  6. genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
  7. mlflow/__init__.py +367 -0
  8. mlflow/__main__.py +3 -0
  9. mlflow/ag2/__init__.py +56 -0
  10. mlflow/ag2/ag2_logger.py +294 -0
  11. mlflow/anthropic/__init__.py +40 -0
  12. mlflow/anthropic/autolog.py +129 -0
  13. mlflow/anthropic/chat.py +144 -0
  14. mlflow/artifacts/__init__.py +268 -0
  15. mlflow/autogen/__init__.py +144 -0
  16. mlflow/autogen/chat.py +142 -0
  17. mlflow/azure/__init__.py +26 -0
  18. mlflow/azure/auth_handler.py +257 -0
  19. mlflow/azure/client.py +319 -0
  20. mlflow/azure/config.py +120 -0
  21. mlflow/azure/connection_factory.py +340 -0
  22. mlflow/azure/exceptions.py +27 -0
  23. mlflow/azure/stores.py +327 -0
  24. mlflow/azure/utils.py +183 -0
  25. mlflow/bedrock/__init__.py +45 -0
  26. mlflow/bedrock/_autolog.py +202 -0
  27. mlflow/bedrock/chat.py +122 -0
  28. mlflow/bedrock/stream.py +160 -0
  29. mlflow/bedrock/utils.py +43 -0
  30. mlflow/cli.py +707 -0
  31. mlflow/client.py +12 -0
  32. mlflow/config/__init__.py +56 -0
  33. mlflow/crewai/__init__.py +79 -0
  34. mlflow/crewai/autolog.py +253 -0
  35. mlflow/crewai/chat.py +29 -0
  36. mlflow/data/__init__.py +75 -0
  37. mlflow/data/artifact_dataset_sources.py +170 -0
  38. mlflow/data/code_dataset_source.py +40 -0
  39. mlflow/data/dataset.py +123 -0
  40. mlflow/data/dataset_registry.py +168 -0
  41. mlflow/data/dataset_source.py +110 -0
  42. mlflow/data/dataset_source_registry.py +219 -0
  43. mlflow/data/delta_dataset_source.py +167 -0
  44. mlflow/data/digest_utils.py +108 -0
  45. mlflow/data/evaluation_dataset.py +562 -0
  46. mlflow/data/filesystem_dataset_source.py +81 -0
  47. mlflow/data/http_dataset_source.py +145 -0
  48. mlflow/data/huggingface_dataset.py +258 -0
  49. mlflow/data/huggingface_dataset_source.py +118 -0
  50. mlflow/data/meta_dataset.py +104 -0
  51. mlflow/data/numpy_dataset.py +223 -0
  52. mlflow/data/pandas_dataset.py +231 -0
  53. mlflow/data/polars_dataset.py +352 -0
  54. mlflow/data/pyfunc_dataset_mixin.py +31 -0
  55. mlflow/data/schema.py +76 -0
  56. mlflow/data/sources.py +1 -0
  57. mlflow/data/spark_dataset.py +406 -0
  58. mlflow/data/spark_dataset_source.py +74 -0
  59. mlflow/data/spark_delta_utils.py +118 -0
  60. mlflow/data/tensorflow_dataset.py +350 -0
  61. mlflow/data/uc_volume_dataset_source.py +81 -0
  62. mlflow/db.py +27 -0
  63. mlflow/dspy/__init__.py +17 -0
  64. mlflow/dspy/autolog.py +197 -0
  65. mlflow/dspy/callback.py +398 -0
  66. mlflow/dspy/constant.py +1 -0
  67. mlflow/dspy/load.py +93 -0
  68. mlflow/dspy/save.py +393 -0
  69. mlflow/dspy/util.py +109 -0
  70. mlflow/dspy/wrapper.py +226 -0
  71. mlflow/entities/__init__.py +104 -0
  72. mlflow/entities/_mlflow_object.py +52 -0
  73. mlflow/entities/assessment.py +545 -0
  74. mlflow/entities/assessment_error.py +80 -0
  75. mlflow/entities/assessment_source.py +141 -0
  76. mlflow/entities/dataset.py +92 -0
  77. mlflow/entities/dataset_input.py +51 -0
  78. mlflow/entities/dataset_summary.py +62 -0
  79. mlflow/entities/document.py +48 -0
  80. mlflow/entities/experiment.py +109 -0
  81. mlflow/entities/experiment_tag.py +35 -0
  82. mlflow/entities/file_info.py +45 -0
  83. mlflow/entities/input_tag.py +35 -0
  84. mlflow/entities/lifecycle_stage.py +35 -0
  85. mlflow/entities/logged_model.py +228 -0
  86. mlflow/entities/logged_model_input.py +26 -0
  87. mlflow/entities/logged_model_output.py +32 -0
  88. mlflow/entities/logged_model_parameter.py +46 -0
  89. mlflow/entities/logged_model_status.py +74 -0
  90. mlflow/entities/logged_model_tag.py +33 -0
  91. mlflow/entities/metric.py +200 -0
  92. mlflow/entities/model_registry/__init__.py +29 -0
  93. mlflow/entities/model_registry/_model_registry_entity.py +13 -0
  94. mlflow/entities/model_registry/model_version.py +243 -0
  95. mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
  96. mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
  97. mlflow/entities/model_registry/model_version_search.py +25 -0
  98. mlflow/entities/model_registry/model_version_stages.py +25 -0
  99. mlflow/entities/model_registry/model_version_status.py +35 -0
  100. mlflow/entities/model_registry/model_version_tag.py +35 -0
  101. mlflow/entities/model_registry/prompt.py +73 -0
  102. mlflow/entities/model_registry/prompt_version.py +244 -0
  103. mlflow/entities/model_registry/registered_model.py +175 -0
  104. mlflow/entities/model_registry/registered_model_alias.py +35 -0
  105. mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
  106. mlflow/entities/model_registry/registered_model_search.py +25 -0
  107. mlflow/entities/model_registry/registered_model_tag.py +35 -0
  108. mlflow/entities/multipart_upload.py +74 -0
  109. mlflow/entities/param.py +49 -0
  110. mlflow/entities/run.py +97 -0
  111. mlflow/entities/run_data.py +84 -0
  112. mlflow/entities/run_info.py +188 -0
  113. mlflow/entities/run_inputs.py +59 -0
  114. mlflow/entities/run_outputs.py +43 -0
  115. mlflow/entities/run_status.py +41 -0
  116. mlflow/entities/run_tag.py +36 -0
  117. mlflow/entities/source_type.py +31 -0
  118. mlflow/entities/span.py +774 -0
  119. mlflow/entities/span_event.py +96 -0
  120. mlflow/entities/span_status.py +102 -0
  121. mlflow/entities/trace.py +317 -0
  122. mlflow/entities/trace_data.py +71 -0
  123. mlflow/entities/trace_info.py +220 -0
  124. mlflow/entities/trace_info_v2.py +162 -0
  125. mlflow/entities/trace_location.py +173 -0
  126. mlflow/entities/trace_state.py +39 -0
  127. mlflow/entities/trace_status.py +68 -0
  128. mlflow/entities/view_type.py +51 -0
  129. mlflow/environment_variables.py +866 -0
  130. mlflow/evaluation/__init__.py +16 -0
  131. mlflow/evaluation/assessment.py +369 -0
  132. mlflow/evaluation/evaluation.py +411 -0
  133. mlflow/evaluation/evaluation_tag.py +61 -0
  134. mlflow/evaluation/fluent.py +48 -0
  135. mlflow/evaluation/utils.py +201 -0
  136. mlflow/exceptions.py +213 -0
  137. mlflow/experiments.py +140 -0
  138. mlflow/gemini/__init__.py +81 -0
  139. mlflow/gemini/autolog.py +186 -0
  140. mlflow/gemini/chat.py +261 -0
  141. mlflow/genai/__init__.py +71 -0
  142. mlflow/genai/datasets/__init__.py +67 -0
  143. mlflow/genai/datasets/evaluation_dataset.py +131 -0
  144. mlflow/genai/evaluation/__init__.py +3 -0
  145. mlflow/genai/evaluation/base.py +411 -0
  146. mlflow/genai/evaluation/constant.py +23 -0
  147. mlflow/genai/evaluation/utils.py +244 -0
  148. mlflow/genai/judges/__init__.py +21 -0
  149. mlflow/genai/judges/databricks.py +404 -0
  150. mlflow/genai/label_schemas/__init__.py +153 -0
  151. mlflow/genai/label_schemas/label_schemas.py +209 -0
  152. mlflow/genai/labeling/__init__.py +159 -0
  153. mlflow/genai/labeling/labeling.py +250 -0
  154. mlflow/genai/optimize/__init__.py +13 -0
  155. mlflow/genai/optimize/base.py +198 -0
  156. mlflow/genai/optimize/optimizers/__init__.py +4 -0
  157. mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
  158. mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
  159. mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
  160. mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
  161. mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
  162. mlflow/genai/optimize/types.py +75 -0
  163. mlflow/genai/optimize/util.py +30 -0
  164. mlflow/genai/prompts/__init__.py +206 -0
  165. mlflow/genai/scheduled_scorers.py +431 -0
  166. mlflow/genai/scorers/__init__.py +26 -0
  167. mlflow/genai/scorers/base.py +492 -0
  168. mlflow/genai/scorers/builtin_scorers.py +765 -0
  169. mlflow/genai/scorers/scorer_utils.py +138 -0
  170. mlflow/genai/scorers/validation.py +165 -0
  171. mlflow/genai/utils/data_validation.py +146 -0
  172. mlflow/genai/utils/enum_utils.py +23 -0
  173. mlflow/genai/utils/trace_utils.py +211 -0
  174. mlflow/groq/__init__.py +42 -0
  175. mlflow/groq/_groq_autolog.py +74 -0
  176. mlflow/johnsnowlabs/__init__.py +888 -0
  177. mlflow/langchain/__init__.py +24 -0
  178. mlflow/langchain/api_request_parallel_processor.py +330 -0
  179. mlflow/langchain/autolog.py +147 -0
  180. mlflow/langchain/chat_agent_langgraph.py +340 -0
  181. mlflow/langchain/constant.py +1 -0
  182. mlflow/langchain/constants.py +1 -0
  183. mlflow/langchain/databricks_dependencies.py +444 -0
  184. mlflow/langchain/langchain_tracer.py +597 -0
  185. mlflow/langchain/model.py +919 -0
  186. mlflow/langchain/output_parsers.py +142 -0
  187. mlflow/langchain/retriever_chain.py +153 -0
  188. mlflow/langchain/runnables.py +527 -0
  189. mlflow/langchain/utils/chat.py +402 -0
  190. mlflow/langchain/utils/logging.py +671 -0
  191. mlflow/langchain/utils/serialization.py +36 -0
  192. mlflow/legacy_databricks_cli/__init__.py +0 -0
  193. mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
  194. mlflow/legacy_databricks_cli/configure/provider.py +482 -0
  195. mlflow/litellm/__init__.py +175 -0
  196. mlflow/llama_index/__init__.py +22 -0
  197. mlflow/llama_index/autolog.py +55 -0
  198. mlflow/llama_index/chat.py +43 -0
  199. mlflow/llama_index/constant.py +1 -0
  200. mlflow/llama_index/model.py +577 -0
  201. mlflow/llama_index/pyfunc_wrapper.py +332 -0
  202. mlflow/llama_index/serialize_objects.py +188 -0
  203. mlflow/llama_index/tracer.py +561 -0
  204. mlflow/metrics/__init__.py +479 -0
  205. mlflow/metrics/base.py +39 -0
  206. mlflow/metrics/genai/__init__.py +25 -0
  207. mlflow/metrics/genai/base.py +101 -0
  208. mlflow/metrics/genai/genai_metric.py +771 -0
  209. mlflow/metrics/genai/metric_definitions.py +450 -0
  210. mlflow/metrics/genai/model_utils.py +371 -0
  211. mlflow/metrics/genai/prompt_template.py +68 -0
  212. mlflow/metrics/genai/prompts/__init__.py +0 -0
  213. mlflow/metrics/genai/prompts/v1.py +422 -0
  214. mlflow/metrics/genai/utils.py +6 -0
  215. mlflow/metrics/metric_definitions.py +619 -0
  216. mlflow/mismatch.py +34 -0
  217. mlflow/mistral/__init__.py +34 -0
  218. mlflow/mistral/autolog.py +71 -0
  219. mlflow/mistral/chat.py +135 -0
  220. mlflow/ml_package_versions.py +452 -0
  221. mlflow/models/__init__.py +97 -0
  222. mlflow/models/auth_policy.py +83 -0
  223. mlflow/models/cli.py +354 -0
  224. mlflow/models/container/__init__.py +294 -0
  225. mlflow/models/container/scoring_server/__init__.py +0 -0
  226. mlflow/models/container/scoring_server/nginx.conf +39 -0
  227. mlflow/models/dependencies_schemas.py +287 -0
  228. mlflow/models/display_utils.py +158 -0
  229. mlflow/models/docker_utils.py +211 -0
  230. mlflow/models/evaluation/__init__.py +23 -0
  231. mlflow/models/evaluation/_shap_patch.py +64 -0
  232. mlflow/models/evaluation/artifacts.py +194 -0
  233. mlflow/models/evaluation/base.py +1811 -0
  234. mlflow/models/evaluation/calibration_curve.py +109 -0
  235. mlflow/models/evaluation/default_evaluator.py +996 -0
  236. mlflow/models/evaluation/deprecated.py +23 -0
  237. mlflow/models/evaluation/evaluator_registry.py +80 -0
  238. mlflow/models/evaluation/evaluators/classifier.py +704 -0
  239. mlflow/models/evaluation/evaluators/default.py +233 -0
  240. mlflow/models/evaluation/evaluators/regressor.py +96 -0
  241. mlflow/models/evaluation/evaluators/shap.py +296 -0
  242. mlflow/models/evaluation/lift_curve.py +178 -0
  243. mlflow/models/evaluation/utils/metric.py +123 -0
  244. mlflow/models/evaluation/utils/trace.py +179 -0
  245. mlflow/models/evaluation/validation.py +434 -0
  246. mlflow/models/flavor_backend.py +93 -0
  247. mlflow/models/flavor_backend_registry.py +53 -0
  248. mlflow/models/model.py +1639 -0
  249. mlflow/models/model_config.py +150 -0
  250. mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
  251. mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
  252. mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
  253. mlflow/models/python_api.py +369 -0
  254. mlflow/models/rag_signatures.py +128 -0
  255. mlflow/models/resources.py +321 -0
  256. mlflow/models/signature.py +662 -0
  257. mlflow/models/utils.py +2054 -0
  258. mlflow/models/wheeled_model.py +280 -0
  259. mlflow/openai/__init__.py +57 -0
  260. mlflow/openai/_agent_tracer.py +364 -0
  261. mlflow/openai/api_request_parallel_processor.py +131 -0
  262. mlflow/openai/autolog.py +509 -0
  263. mlflow/openai/constant.py +1 -0
  264. mlflow/openai/model.py +824 -0
  265. mlflow/openai/utils/chat_schema.py +367 -0
  266. mlflow/optuna/__init__.py +3 -0
  267. mlflow/optuna/storage.py +646 -0
  268. mlflow/plugins/__init__.py +72 -0
  269. mlflow/plugins/base.py +358 -0
  270. mlflow/plugins/builtin/__init__.py +24 -0
  271. mlflow/plugins/builtin/pytorch_plugin.py +150 -0
  272. mlflow/plugins/builtin/sklearn_plugin.py +158 -0
  273. mlflow/plugins/builtin/transformers_plugin.py +187 -0
  274. mlflow/plugins/cli.py +321 -0
  275. mlflow/plugins/discovery.py +340 -0
  276. mlflow/plugins/manager.py +465 -0
  277. mlflow/plugins/registry.py +316 -0
  278. mlflow/plugins/templates/framework_plugin_template.py +329 -0
  279. mlflow/prompt/constants.py +20 -0
  280. mlflow/prompt/promptlab_model.py +197 -0
  281. mlflow/prompt/registry_utils.py +248 -0
  282. mlflow/promptflow/__init__.py +495 -0
  283. mlflow/protos/__init__.py +0 -0
  284. mlflow/protos/assessments_pb2.py +174 -0
  285. mlflow/protos/databricks_artifacts_pb2.py +489 -0
  286. mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
  287. mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
  288. mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
  289. mlflow/protos/databricks_pb2.py +267 -0
  290. mlflow/protos/databricks_trace_server_pb2.py +374 -0
  291. mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
  292. mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
  293. mlflow/protos/facet_feature_statistics_pb2.py +296 -0
  294. mlflow/protos/internal_pb2.py +77 -0
  295. mlflow/protos/mlflow_artifacts_pb2.py +336 -0
  296. mlflow/protos/model_registry_pb2.py +1073 -0
  297. mlflow/protos/scalapb/__init__.py +0 -0
  298. mlflow/protos/scalapb/scalapb_pb2.py +104 -0
  299. mlflow/protos/service_pb2.py +2600 -0
  300. mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
  301. mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
  302. mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
  303. mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
  304. mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
  305. mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
  306. mlflow/py.typed +0 -0
  307. mlflow/pydantic_ai/__init__.py +57 -0
  308. mlflow/pydantic_ai/autolog.py +173 -0
  309. mlflow/pyfunc/__init__.py +3844 -0
  310. mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
  311. mlflow/pyfunc/backend.py +523 -0
  312. mlflow/pyfunc/context.py +78 -0
  313. mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
  314. mlflow/pyfunc/loaders/__init__.py +7 -0
  315. mlflow/pyfunc/loaders/chat_agent.py +117 -0
  316. mlflow/pyfunc/loaders/chat_model.py +125 -0
  317. mlflow/pyfunc/loaders/code_model.py +31 -0
  318. mlflow/pyfunc/loaders/responses_agent.py +112 -0
  319. mlflow/pyfunc/mlserver.py +46 -0
  320. mlflow/pyfunc/model.py +1473 -0
  321. mlflow/pyfunc/scoring_server/__init__.py +604 -0
  322. mlflow/pyfunc/scoring_server/app.py +7 -0
  323. mlflow/pyfunc/scoring_server/client.py +146 -0
  324. mlflow/pyfunc/spark_model_cache.py +48 -0
  325. mlflow/pyfunc/stdin_server.py +44 -0
  326. mlflow/pyfunc/utils/__init__.py +3 -0
  327. mlflow/pyfunc/utils/data_validation.py +224 -0
  328. mlflow/pyfunc/utils/environment.py +22 -0
  329. mlflow/pyfunc/utils/input_converter.py +47 -0
  330. mlflow/pyfunc/utils/serving_data_parser.py +11 -0
  331. mlflow/pytorch/__init__.py +1171 -0
  332. mlflow/pytorch/_lightning_autolog.py +580 -0
  333. mlflow/pytorch/_pytorch_autolog.py +50 -0
  334. mlflow/pytorch/pickle_module.py +35 -0
  335. mlflow/rfunc/__init__.py +42 -0
  336. mlflow/rfunc/backend.py +134 -0
  337. mlflow/runs.py +89 -0
  338. mlflow/server/__init__.py +302 -0
  339. mlflow/server/auth/__init__.py +1224 -0
  340. mlflow/server/auth/__main__.py +4 -0
  341. mlflow/server/auth/basic_auth.ini +6 -0
  342. mlflow/server/auth/cli.py +11 -0
  343. mlflow/server/auth/client.py +537 -0
  344. mlflow/server/auth/config.py +34 -0
  345. mlflow/server/auth/db/__init__.py +0 -0
  346. mlflow/server/auth/db/cli.py +18 -0
  347. mlflow/server/auth/db/migrations/__init__.py +0 -0
  348. mlflow/server/auth/db/migrations/alembic.ini +110 -0
  349. mlflow/server/auth/db/migrations/env.py +76 -0
  350. mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
  351. mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
  352. mlflow/server/auth/db/models.py +67 -0
  353. mlflow/server/auth/db/utils.py +37 -0
  354. mlflow/server/auth/entities.py +165 -0
  355. mlflow/server/auth/logo.py +14 -0
  356. mlflow/server/auth/permissions.py +65 -0
  357. mlflow/server/auth/routes.py +18 -0
  358. mlflow/server/auth/sqlalchemy_store.py +263 -0
  359. mlflow/server/graphql/__init__.py +0 -0
  360. mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
  361. mlflow/server/graphql/graphql_custom_scalars.py +24 -0
  362. mlflow/server/graphql/graphql_errors.py +15 -0
  363. mlflow/server/graphql/graphql_no_batching.py +89 -0
  364. mlflow/server/graphql/graphql_schema_extensions.py +74 -0
  365. mlflow/server/handlers.py +3217 -0
  366. mlflow/server/prometheus_exporter.py +17 -0
  367. mlflow/server/validation.py +30 -0
  368. mlflow/shap/__init__.py +691 -0
  369. mlflow/sklearn/__init__.py +1994 -0
  370. mlflow/sklearn/utils.py +1041 -0
  371. mlflow/smolagents/__init__.py +66 -0
  372. mlflow/smolagents/autolog.py +139 -0
  373. mlflow/smolagents/chat.py +29 -0
  374. mlflow/store/__init__.py +10 -0
  375. mlflow/store/_unity_catalog/__init__.py +1 -0
  376. mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
  377. mlflow/store/_unity_catalog/lineage/constants.py +2 -0
  378. mlflow/store/_unity_catalog/registry/__init__.py +6 -0
  379. mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
  380. mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
  381. mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
  382. mlflow/store/_unity_catalog/registry/utils.py +121 -0
  383. mlflow/store/artifact/__init__.py +0 -0
  384. mlflow/store/artifact/artifact_repo.py +472 -0
  385. mlflow/store/artifact/artifact_repository_registry.py +154 -0
  386. mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
  387. mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
  388. mlflow/store/artifact/cli.py +141 -0
  389. mlflow/store/artifact/cloud_artifact_repo.py +332 -0
  390. mlflow/store/artifact/databricks_artifact_repo.py +729 -0
  391. mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
  392. mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
  393. mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
  394. mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
  395. mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
  396. mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
  397. mlflow/store/artifact/ftp_artifact_repo.py +132 -0
  398. mlflow/store/artifact/gcs_artifact_repo.py +296 -0
  399. mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
  400. mlflow/store/artifact/http_artifact_repo.py +218 -0
  401. mlflow/store/artifact/local_artifact_repo.py +142 -0
  402. mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
  403. mlflow/store/artifact/models_artifact_repo.py +259 -0
  404. mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
  405. mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
  406. mlflow/store/artifact/r2_artifact_repo.py +70 -0
  407. mlflow/store/artifact/runs_artifact_repo.py +265 -0
  408. mlflow/store/artifact/s3_artifact_repo.py +330 -0
  409. mlflow/store/artifact/sftp_artifact_repo.py +141 -0
  410. mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
  411. mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
  412. mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
  413. mlflow/store/artifact/utils/__init__.py +0 -0
  414. mlflow/store/artifact/utils/models.py +148 -0
  415. mlflow/store/db/__init__.py +0 -0
  416. mlflow/store/db/base_sql_model.py +3 -0
  417. mlflow/store/db/db_types.py +10 -0
  418. mlflow/store/db/utils.py +314 -0
  419. mlflow/store/db_migrations/__init__.py +0 -0
  420. mlflow/store/db_migrations/alembic.ini +74 -0
  421. mlflow/store/db_migrations/env.py +84 -0
  422. mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
  423. mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
  424. mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
  425. mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
  426. mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
  427. mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
  428. mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
  429. mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
  430. mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
  431. mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
  432. mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
  433. mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
  434. mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
  435. mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
  436. mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
  437. mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
  438. mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
  439. mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
  440. mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
  441. mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
  442. mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
  443. mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
  444. mlflow/store/db_migrations/versions/__init__.py +0 -0
  445. mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
  446. mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
  447. mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
  448. mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
  449. mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
  450. mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
  451. mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
  452. mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
  453. mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
  454. mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
  455. mlflow/store/entities/__init__.py +3 -0
  456. mlflow/store/entities/paged_list.py +18 -0
  457. mlflow/store/model_registry/__init__.py +10 -0
  458. mlflow/store/model_registry/abstract_store.py +1081 -0
  459. mlflow/store/model_registry/base_rest_store.py +44 -0
  460. mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
  461. mlflow/store/model_registry/dbmodels/__init__.py +0 -0
  462. mlflow/store/model_registry/dbmodels/models.py +206 -0
  463. mlflow/store/model_registry/file_store.py +1091 -0
  464. mlflow/store/model_registry/rest_store.py +481 -0
  465. mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
  466. mlflow/store/tracking/__init__.py +23 -0
  467. mlflow/store/tracking/abstract_store.py +816 -0
  468. mlflow/store/tracking/dbmodels/__init__.py +0 -0
  469. mlflow/store/tracking/dbmodels/initial_models.py +243 -0
  470. mlflow/store/tracking/dbmodels/models.py +1073 -0
  471. mlflow/store/tracking/file_store.py +2438 -0
  472. mlflow/store/tracking/postgres_managed_identity.py +146 -0
  473. mlflow/store/tracking/rest_store.py +1131 -0
  474. mlflow/store/tracking/sqlalchemy_store.py +2785 -0
  475. mlflow/system_metrics/__init__.py +61 -0
  476. mlflow/system_metrics/metrics/__init__.py +0 -0
  477. mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
  478. mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
  479. mlflow/system_metrics/metrics/disk_monitor.py +21 -0
  480. mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
  481. mlflow/system_metrics/metrics/network_monitor.py +34 -0
  482. mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
  483. mlflow/system_metrics/system_metrics_monitor.py +198 -0
  484. mlflow/tracing/__init__.py +16 -0
  485. mlflow/tracing/assessment.py +356 -0
  486. mlflow/tracing/client.py +531 -0
  487. mlflow/tracing/config.py +125 -0
  488. mlflow/tracing/constant.py +105 -0
  489. mlflow/tracing/destination.py +81 -0
  490. mlflow/tracing/display/__init__.py +40 -0
  491. mlflow/tracing/display/display_handler.py +196 -0
  492. mlflow/tracing/export/async_export_queue.py +186 -0
  493. mlflow/tracing/export/inference_table.py +138 -0
  494. mlflow/tracing/export/mlflow_v3.py +137 -0
  495. mlflow/tracing/export/utils.py +70 -0
  496. mlflow/tracing/fluent.py +1417 -0
  497. mlflow/tracing/processor/base_mlflow.py +199 -0
  498. mlflow/tracing/processor/inference_table.py +175 -0
  499. mlflow/tracing/processor/mlflow_v3.py +47 -0
  500. mlflow/tracing/processor/otel.py +73 -0
  501. mlflow/tracing/provider.py +487 -0
  502. mlflow/tracing/trace_manager.py +200 -0
  503. mlflow/tracing/utils/__init__.py +616 -0
  504. mlflow/tracing/utils/artifact_utils.py +28 -0
  505. mlflow/tracing/utils/copy.py +55 -0
  506. mlflow/tracing/utils/environment.py +55 -0
  507. mlflow/tracing/utils/exception.py +21 -0
  508. mlflow/tracing/utils/once.py +35 -0
  509. mlflow/tracing/utils/otlp.py +63 -0
  510. mlflow/tracing/utils/processor.py +54 -0
  511. mlflow/tracing/utils/search.py +292 -0
  512. mlflow/tracing/utils/timeout.py +250 -0
  513. mlflow/tracing/utils/token.py +19 -0
  514. mlflow/tracing/utils/truncation.py +124 -0
  515. mlflow/tracing/utils/warning.py +76 -0
  516. mlflow/tracking/__init__.py +39 -0
  517. mlflow/tracking/_model_registry/__init__.py +1 -0
  518. mlflow/tracking/_model_registry/client.py +764 -0
  519. mlflow/tracking/_model_registry/fluent.py +853 -0
  520. mlflow/tracking/_model_registry/registry.py +67 -0
  521. mlflow/tracking/_model_registry/utils.py +251 -0
  522. mlflow/tracking/_tracking_service/__init__.py +0 -0
  523. mlflow/tracking/_tracking_service/client.py +883 -0
  524. mlflow/tracking/_tracking_service/registry.py +56 -0
  525. mlflow/tracking/_tracking_service/utils.py +275 -0
  526. mlflow/tracking/artifact_utils.py +179 -0
  527. mlflow/tracking/client.py +5900 -0
  528. mlflow/tracking/context/__init__.py +0 -0
  529. mlflow/tracking/context/abstract_context.py +35 -0
  530. mlflow/tracking/context/databricks_cluster_context.py +15 -0
  531. mlflow/tracking/context/databricks_command_context.py +15 -0
  532. mlflow/tracking/context/databricks_job_context.py +49 -0
  533. mlflow/tracking/context/databricks_notebook_context.py +41 -0
  534. mlflow/tracking/context/databricks_repo_context.py +43 -0
  535. mlflow/tracking/context/default_context.py +51 -0
  536. mlflow/tracking/context/git_context.py +32 -0
  537. mlflow/tracking/context/registry.py +98 -0
  538. mlflow/tracking/context/system_environment_context.py +15 -0
  539. mlflow/tracking/default_experiment/__init__.py +1 -0
  540. mlflow/tracking/default_experiment/abstract_context.py +43 -0
  541. mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
  542. mlflow/tracking/default_experiment/registry.py +75 -0
  543. mlflow/tracking/fluent.py +3595 -0
  544. mlflow/tracking/metric_value_conversion_utils.py +93 -0
  545. mlflow/tracking/multimedia.py +206 -0
  546. mlflow/tracking/registry.py +86 -0
  547. mlflow/tracking/request_auth/__init__.py +0 -0
  548. mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
  549. mlflow/tracking/request_auth/registry.py +60 -0
  550. mlflow/tracking/request_header/__init__.py +0 -0
  551. mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
  552. mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
  553. mlflow/tracking/request_header/default_request_header_provider.py +17 -0
  554. mlflow/tracking/request_header/registry.py +79 -0
  555. mlflow/transformers/__init__.py +2982 -0
  556. mlflow/transformers/flavor_config.py +258 -0
  557. mlflow/transformers/hub_utils.py +83 -0
  558. mlflow/transformers/llm_inference_utils.py +468 -0
  559. mlflow/transformers/model_io.py +301 -0
  560. mlflow/transformers/peft.py +51 -0
  561. mlflow/transformers/signature.py +183 -0
  562. mlflow/transformers/torch_utils.py +55 -0
  563. mlflow/types/__init__.py +21 -0
  564. mlflow/types/agent.py +270 -0
  565. mlflow/types/chat.py +240 -0
  566. mlflow/types/llm.py +935 -0
  567. mlflow/types/responses.py +139 -0
  568. mlflow/types/responses_helpers.py +416 -0
  569. mlflow/types/schema.py +1505 -0
  570. mlflow/types/type_hints.py +647 -0
  571. mlflow/types/utils.py +753 -0
  572. mlflow/utils/__init__.py +283 -0
  573. mlflow/utils/_capture_modules.py +256 -0
  574. mlflow/utils/_capture_transformers_modules.py +75 -0
  575. mlflow/utils/_spark_utils.py +201 -0
  576. mlflow/utils/_unity_catalog_oss_utils.py +97 -0
  577. mlflow/utils/_unity_catalog_utils.py +479 -0
  578. mlflow/utils/annotations.py +218 -0
  579. mlflow/utils/arguments_utils.py +16 -0
  580. mlflow/utils/async_logging/__init__.py +1 -0
  581. mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
  582. mlflow/utils/async_logging/async_logging_queue.py +366 -0
  583. mlflow/utils/async_logging/run_artifact.py +38 -0
  584. mlflow/utils/async_logging/run_batch.py +58 -0
  585. mlflow/utils/async_logging/run_operations.py +49 -0
  586. mlflow/utils/autologging_utils/__init__.py +737 -0
  587. mlflow/utils/autologging_utils/client.py +432 -0
  588. mlflow/utils/autologging_utils/config.py +33 -0
  589. mlflow/utils/autologging_utils/events.py +294 -0
  590. mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
  591. mlflow/utils/autologging_utils/metrics_queue.py +71 -0
  592. mlflow/utils/autologging_utils/safety.py +1104 -0
  593. mlflow/utils/autologging_utils/versioning.py +95 -0
  594. mlflow/utils/checkpoint_utils.py +206 -0
  595. mlflow/utils/class_utils.py +6 -0
  596. mlflow/utils/cli_args.py +257 -0
  597. mlflow/utils/conda.py +354 -0
  598. mlflow/utils/credentials.py +231 -0
  599. mlflow/utils/data_utils.py +17 -0
  600. mlflow/utils/databricks_utils.py +1436 -0
  601. mlflow/utils/docstring_utils.py +477 -0
  602. mlflow/utils/doctor.py +133 -0
  603. mlflow/utils/download_cloud_file_chunk.py +43 -0
  604. mlflow/utils/env_manager.py +16 -0
  605. mlflow/utils/env_pack.py +131 -0
  606. mlflow/utils/environment.py +1009 -0
  607. mlflow/utils/exception_utils.py +14 -0
  608. mlflow/utils/file_utils.py +978 -0
  609. mlflow/utils/git_utils.py +77 -0
  610. mlflow/utils/gorilla.py +797 -0
  611. mlflow/utils/import_hooks/__init__.py +363 -0
  612. mlflow/utils/lazy_load.py +51 -0
  613. mlflow/utils/logging_utils.py +168 -0
  614. mlflow/utils/mime_type_utils.py +58 -0
  615. mlflow/utils/mlflow_tags.py +103 -0
  616. mlflow/utils/model_utils.py +486 -0
  617. mlflow/utils/name_utils.py +346 -0
  618. mlflow/utils/nfs_on_spark.py +62 -0
  619. mlflow/utils/openai_utils.py +164 -0
  620. mlflow/utils/os.py +12 -0
  621. mlflow/utils/oss_registry_utils.py +29 -0
  622. mlflow/utils/plugins.py +17 -0
  623. mlflow/utils/process.py +182 -0
  624. mlflow/utils/promptlab_utils.py +146 -0
  625. mlflow/utils/proto_json_utils.py +743 -0
  626. mlflow/utils/pydantic_utils.py +54 -0
  627. mlflow/utils/request_utils.py +279 -0
  628. mlflow/utils/requirements_utils.py +704 -0
  629. mlflow/utils/rest_utils.py +673 -0
  630. mlflow/utils/search_logged_model_utils.py +127 -0
  631. mlflow/utils/search_utils.py +2111 -0
  632. mlflow/utils/secure_loading.py +221 -0
  633. mlflow/utils/security_validation.py +384 -0
  634. mlflow/utils/server_cli_utils.py +61 -0
  635. mlflow/utils/spark_utils.py +15 -0
  636. mlflow/utils/string_utils.py +138 -0
  637. mlflow/utils/thread_utils.py +63 -0
  638. mlflow/utils/time.py +54 -0
  639. mlflow/utils/timeout.py +42 -0
  640. mlflow/utils/uri.py +572 -0
  641. mlflow/utils/validation.py +662 -0
  642. mlflow/utils/virtualenv.py +458 -0
  643. mlflow/utils/warnings_utils.py +25 -0
  644. mlflow/utils/yaml_utils.py +179 -0
  645. mlflow/version.py +24 -0
@@ -0,0 +1,1131 @@
1
+ import json
2
+ import logging
3
+ from typing import Any, Optional
4
+
5
+ from mlflow.entities import (
6
+ DatasetInput,
7
+ Experiment,
8
+ LoggedModel,
9
+ LoggedModelInput,
10
+ LoggedModelOutput,
11
+ LoggedModelParameter,
12
+ LoggedModelStatus,
13
+ LoggedModelTag,
14
+ Metric,
15
+ Run,
16
+ RunInfo,
17
+ ViewType,
18
+ )
19
+ from mlflow.entities.assessment import Assessment, Expectation, Feedback
20
+ from mlflow.entities.trace import Trace
21
+ from mlflow.entities.trace_data import TraceData
22
+ from mlflow.entities.trace_info import TraceInfo
23
+ from mlflow.entities.trace_info_v2 import TraceInfoV2
24
+ from mlflow.entities.trace_location import TraceLocation
25
+ from mlflow.entities.trace_status import TraceStatus
26
+ from mlflow.environment_variables import (
27
+ _MLFLOW_CREATE_LOGGED_MODEL_PARAMS_BATCH_SIZE,
28
+ _MLFLOW_LOG_LOGGED_MODEL_PARAMS_BATCH_SIZE,
29
+ MLFLOW_ASYNC_TRACE_LOGGING_RETRY_TIMEOUT,
30
+ )
31
+ from mlflow.exceptions import MlflowException
32
+ from mlflow.protos import databricks_pb2
33
+ from mlflow.protos.service_pb2 import (
34
+ CreateAssessment,
35
+ CreateExperiment,
36
+ CreateLoggedModel,
37
+ CreateRun,
38
+ DeleteAssessment,
39
+ DeleteExperiment,
40
+ DeleteLoggedModel,
41
+ DeleteLoggedModelTag,
42
+ DeleteRun,
43
+ DeleteTag,
44
+ DeleteTraces,
45
+ DeleteTraceTag,
46
+ EndTrace,
47
+ FinalizeLoggedModel,
48
+ GetAssessmentRequest,
49
+ GetExperiment,
50
+ GetExperimentByName,
51
+ GetLoggedModel,
52
+ GetMetricHistory,
53
+ GetOnlineTraceDetails,
54
+ GetRun,
55
+ GetTraceInfo,
56
+ GetTraceInfoV3,
57
+ LogBatch,
58
+ LogInputs,
59
+ LogLoggedModelParamsRequest,
60
+ LogMetric,
61
+ LogModel,
62
+ LogOutputs,
63
+ LogParam,
64
+ MlflowService,
65
+ RestoreExperiment,
66
+ RestoreRun,
67
+ SearchExperiments,
68
+ SearchLoggedModels,
69
+ SearchRuns,
70
+ SearchTraces,
71
+ SearchTracesV3,
72
+ SearchUnifiedTraces,
73
+ SetExperimentTag,
74
+ SetLoggedModelTags,
75
+ SetTag,
76
+ SetTraceTag,
77
+ StartTrace,
78
+ StartTraceV3,
79
+ TraceRequestMetadata,
80
+ TraceTag,
81
+ UpdateAssessment,
82
+ UpdateExperiment,
83
+ UpdateRun,
84
+ )
85
+ from mlflow.store.entities.paged_list import PagedList
86
+ from mlflow.store.tracking import SEARCH_TRACES_DEFAULT_MAX_RESULTS
87
+ from mlflow.store.tracking.abstract_store import AbstractStore
88
+ from mlflow.utils.proto_json_utils import message_to_json
89
+ from mlflow.utils.rest_utils import (
90
+ _REST_API_PATH_PREFIX,
91
+ _V3_TRACE_REST_API_PATH_PREFIX,
92
+ call_endpoint,
93
+ extract_api_info_for_service,
94
+ get_logged_model_endpoint,
95
+ get_single_assessment_endpoint,
96
+ get_single_trace_endpoint,
97
+ get_trace_tag_endpoint,
98
+ )
99
+
100
# Module logger; used below to report V3 -> V2 API fallbacks at debug level.
_logger = logging.getLogger(__name__)
# Lookup table keyed by MlflowService request proto class, yielding the (REST path, HTTP
# method) pair that ``RestStore._call_endpoint`` unpacks for each request.
_METHOD_TO_INFO = extract_api_info_for_service(MlflowService, _REST_API_PATH_PREFIX)
102
+
103
+
104
+ class RestStore(AbstractStore):
105
+ """
106
+ Client for a remote tracking server accessed via REST API calls
107
+
108
+ Args
109
+ get_host_creds: Method to be invoked prior to every REST request to get the
110
+ :py:class:`mlflow.rest_utils.MlflowHostCreds` for the request. Note that this
111
+ is a function so that we can obtain fresh credentials in the case of expiry.
112
+ """
113
+
114
+ def __init__(self, get_host_creds):
115
+ super().__init__()
116
+ self.get_host_creds = get_host_creds
117
+
118
+ def _call_endpoint(
119
+ self,
120
+ api,
121
+ json_body=None,
122
+ endpoint=None,
123
+ retry_timeout_seconds=None,
124
+ ):
125
+ if endpoint:
126
+ # Allow customizing the endpoint for compatibility with dynamic endpoints, such as
127
+ # /mlflow/traces/{trace_id}/info.
128
+ _, method = _METHOD_TO_INFO[api]
129
+ else:
130
+ endpoint, method = _METHOD_TO_INFO[api]
131
+ response_proto = api.Response()
132
+ return call_endpoint(
133
+ self.get_host_creds(),
134
+ endpoint,
135
+ method,
136
+ json_body,
137
+ response_proto,
138
+ retry_timeout_seconds=retry_timeout_seconds,
139
+ )
140
+
141
+ def search_experiments(
142
+ self,
143
+ view_type=ViewType.ACTIVE_ONLY,
144
+ max_results=None,
145
+ filter_string=None,
146
+ order_by=None,
147
+ page_token=None,
148
+ ):
149
+ req_body = message_to_json(
150
+ SearchExperiments(
151
+ view_type=view_type,
152
+ max_results=max_results,
153
+ page_token=page_token,
154
+ order_by=order_by,
155
+ filter=filter_string,
156
+ )
157
+ )
158
+ response_proto = self._call_endpoint(SearchExperiments, req_body)
159
+ experiments = [Experiment.from_proto(x) for x in response_proto.experiments]
160
+ token = (
161
+ response_proto.next_page_token if response_proto.HasField("next_page_token") else None
162
+ )
163
+ return PagedList(experiments, token)
164
+
165
+ def create_experiment(self, name, artifact_location=None, tags=None):
166
+ """
167
+ Create a new experiment.
168
+ If an experiment with the given name already exists, throws exception.
169
+
170
+ Args:
171
+ name: Desired name for an experiment.
172
+ artifact_location: Location to store run artifacts.
173
+ tags: A list of :py:class:`mlflow.entities.ExperimentTag` instances to set for the
174
+ experiment.
175
+
176
+ Returns:
177
+ experiment_id for the newly created experiment if successful, else None
178
+ """
179
+ tag_protos = [tag.to_proto() for tag in tags] if tags else []
180
+ req_body = message_to_json(
181
+ CreateExperiment(name=name, artifact_location=artifact_location, tags=tag_protos)
182
+ )
183
+ response_proto = self._call_endpoint(CreateExperiment, req_body)
184
+ return response_proto.experiment_id
185
+
186
+ def get_experiment(self, experiment_id):
187
+ """
188
+ Fetch the experiment from the backend store.
189
+
190
+ Args:
191
+ experiment_id: String id for the experiment
192
+
193
+ Returns:
194
+ A single :py:class:`mlflow.entities.Experiment` object if it exists,
195
+ otherwise raises an Exception.
196
+ """
197
+ req_body = message_to_json(GetExperiment(experiment_id=str(experiment_id)))
198
+ response_proto = self._call_endpoint(GetExperiment, req_body)
199
+ return Experiment.from_proto(response_proto.experiment)
200
+
201
+ def delete_experiment(self, experiment_id):
202
+ req_body = message_to_json(DeleteExperiment(experiment_id=str(experiment_id)))
203
+ self._call_endpoint(DeleteExperiment, req_body)
204
+
205
+ def restore_experiment(self, experiment_id):
206
+ req_body = message_to_json(RestoreExperiment(experiment_id=str(experiment_id)))
207
+ self._call_endpoint(RestoreExperiment, req_body)
208
+
209
+ def rename_experiment(self, experiment_id, new_name):
210
+ req_body = message_to_json(
211
+ UpdateExperiment(experiment_id=str(experiment_id), new_name=new_name)
212
+ )
213
+ self._call_endpoint(UpdateExperiment, req_body)
214
+
215
+ def get_run(self, run_id):
216
+ """
217
+ Fetch the run from backend store
218
+
219
+ Args:
220
+ run_id: Unique identifier for the run
221
+
222
+ Returns:
223
+ A single Run object if it exists, otherwise raises an Exception
224
+ """
225
+ req_body = message_to_json(GetRun(run_uuid=run_id, run_id=run_id))
226
+ response_proto = self._call_endpoint(GetRun, req_body)
227
+ return Run.from_proto(response_proto.run)
228
+
229
+ def update_run_info(self, run_id, run_status, end_time, run_name):
230
+ """Updates the metadata of the specified run."""
231
+ req_body = message_to_json(
232
+ UpdateRun(
233
+ run_uuid=run_id,
234
+ run_id=run_id,
235
+ status=run_status,
236
+ end_time=end_time,
237
+ run_name=run_name,
238
+ )
239
+ )
240
+ response_proto = self._call_endpoint(UpdateRun, req_body)
241
+ return RunInfo.from_proto(response_proto.run_info)
242
+
243
+ def create_run(self, experiment_id, user_id, start_time, tags, run_name):
244
+ """
245
+ Create a run under the specified experiment ID, setting the run's status to "RUNNING"
246
+ and the start time to the current time.
247
+
248
+ Args:
249
+ experiment_id: ID of the experiment for this run.
250
+ user_id: ID of the user launching this run.
251
+ start_time: timestamp of the initialization of the run.
252
+ tags: tags to apply to this run at initialization.
253
+ run_name: Name of this run.
254
+
255
+ Returns:
256
+ The created Run object.
257
+ """
258
+
259
+ tag_protos = [tag.to_proto() for tag in tags]
260
+ req_body = message_to_json(
261
+ CreateRun(
262
+ experiment_id=str(experiment_id),
263
+ user_id=user_id,
264
+ start_time=start_time,
265
+ tags=tag_protos,
266
+ run_name=run_name,
267
+ )
268
+ )
269
+ response_proto = self._call_endpoint(CreateRun, req_body)
270
+ return Run.from_proto(response_proto.run)
271
+
272
+ def start_trace(self, trace_info: TraceInfo) -> TraceInfo:
273
+ """
274
+ Create a new trace using the V3 API format.
275
+
276
+ NB: The backend API is named "StartTraceV3" for some internal reason, but actually
277
+ it is supposed to be called at the end of the trace.
278
+
279
+ Args:
280
+ trace_info: The TraceInfo object to create in the backend.
281
+
282
+ Returns:
283
+ The returned TraceInfo object from the backend.
284
+ """
285
+ # NB: The Databricks backend expects a Trace object, not a TraceInfo object, although
286
+ # it doesn't use the data field at all. Trace data increases the payload size significantly,
287
+ # so we create a Trace object with an empty data field here.
288
+ trace = Trace(info=trace_info, data=TraceData(spans=[]))
289
+ req_body = message_to_json(StartTraceV3(trace=trace.to_proto()))
290
+
291
+ try:
292
+ response_proto = self._call_endpoint(
293
+ # NB: _call_endpoint doesn't handle versioning between v2 and v3 endpoint
294
+ # yet, so manually passing the v3 endpoint here.
295
+ StartTraceV3,
296
+ req_body,
297
+ endpoint=_V3_TRACE_REST_API_PATH_PREFIX,
298
+ retry_timeout_seconds=MLFLOW_ASYNC_TRACE_LOGGING_RETRY_TIMEOUT.get(),
299
+ )
300
+ return TraceInfo.from_proto(response_proto.trace.trace_info)
301
+ except MlflowException as e:
302
+ if e.error_code == databricks_pb2.ErrorCode.Name(databricks_pb2.ENDPOINT_NOT_FOUND):
303
+ _logger.debug(
304
+ "Server does not support StartTraceV3 API yet. Falling back to V2 API."
305
+ )
306
+ return self._create_trace_v2_fallback(trace_info)
307
+ raise
308
+
309
+ def _create_trace_v2_fallback(self, trace_info: TraceInfo) -> TraceInfo:
310
+ """
311
+ Create a new trace using the V2 API format. This is a fallback for the case where the
312
+ client is v3 but the tracking server does not support v3 yet(<= 3.2.0).
313
+ """
314
+ trace_info_v2 = self.deprecated_start_trace_v2(
315
+ experiment_id=trace_info.experiment_id,
316
+ timestamp_ms=trace_info.request_time,
317
+ request_metadata=trace_info.trace_metadata,
318
+ tags=trace_info.tags,
319
+ )
320
+ self.deprecated_end_trace_v2(
321
+ request_id=trace_info_v2.request_id,
322
+ timestamp_ms=trace_info.request_time + trace_info.execution_duration,
323
+ status=trace_info.status,
324
+ request_metadata=trace_info.trace_metadata,
325
+ tags=trace_info.tags,
326
+ )
327
+ return trace_info_v2.to_v3()
328
+
329
+ def _delete_traces(
330
+ self,
331
+ experiment_id: str,
332
+ max_timestamp_millis: Optional[int] = None,
333
+ max_traces: Optional[int] = None,
334
+ trace_ids: Optional[list[str]] = None,
335
+ ) -> int:
336
+ req_body = message_to_json(
337
+ DeleteTraces(
338
+ experiment_id=experiment_id,
339
+ max_timestamp_millis=max_timestamp_millis,
340
+ max_traces=max_traces,
341
+ request_ids=trace_ids,
342
+ )
343
+ )
344
+ res = self._call_endpoint(DeleteTraces, req_body)
345
+ return res.traces_deleted
346
+
347
+ def get_trace_info(self, trace_id: str) -> TraceInfo:
348
+ """
349
+ Get the trace matching the `trace_id`.
350
+
351
+ Args:
352
+ trace_id: String id of the trace to fetch.
353
+
354
+ Returns:
355
+ The fetched Trace object, of type ``mlflow.entities.TraceInfo``.
356
+ """
357
+ trace_v3_req_body = message_to_json(GetTraceInfoV3(trace_id=trace_id))
358
+ trace_v3_endpoint = get_single_trace_endpoint(trace_id)
359
+ try:
360
+ trace_v3_response_proto = self._call_endpoint(
361
+ GetTraceInfoV3, trace_v3_req_body, endpoint=trace_v3_endpoint
362
+ )
363
+ return TraceInfo.from_proto(trace_v3_response_proto.trace.trace_info)
364
+ except MlflowException as e:
365
+ # If the tracking server does not support V3 trace API yet, fallback to V2 API.
366
+ if e.error_code != databricks_pb2.ErrorCode.Name(databricks_pb2.ENDPOINT_NOT_FOUND):
367
+ raise
368
+ _logger.debug("Server does not support GetTraceInfoV3 API yet. Falling back to V2 API.")
369
+
370
+ req_body = message_to_json(GetTraceInfo(request_id=trace_id))
371
+ endpoint = get_single_trace_endpoint(trace_id, use_v3=False)
372
+ response_proto = self._call_endpoint(GetTraceInfo, req_body, endpoint=endpoint)
373
+ return TraceInfoV2.from_proto(response_proto.trace_info).to_v3()
374
+
375
+ def get_online_trace_details(
376
+ self,
377
+ trace_id: str,
378
+ sql_warehouse_id: str,
379
+ source_inference_table: str,
380
+ source_databricks_request_id: str,
381
+ ):
382
+ req = GetOnlineTraceDetails(
383
+ trace_id=trace_id,
384
+ sql_warehouse_id=sql_warehouse_id,
385
+ source_inference_table=source_inference_table,
386
+ source_databricks_request_id=source_databricks_request_id,
387
+ )
388
+ req_body = message_to_json(req)
389
+ response_proto = self._call_endpoint(GetOnlineTraceDetails, req_body)
390
+ return response_proto.trace_data
391
+
392
+ def search_traces(
393
+ self,
394
+ experiment_ids: list[str],
395
+ filter_string: Optional[str] = None,
396
+ max_results: int = SEARCH_TRACES_DEFAULT_MAX_RESULTS,
397
+ order_by: Optional[list[str]] = None,
398
+ page_token: Optional[str] = None,
399
+ model_id: Optional[str] = None,
400
+ sql_warehouse_id: Optional[str] = None,
401
+ ):
402
+ if sql_warehouse_id is None:
403
+ # Create trace_locations from experiment_ids for the V3 API
404
+ trace_locations = []
405
+ for exp_id in experiment_ids:
406
+ try:
407
+ location = TraceLocation.from_experiment_id(exp_id)
408
+ proto_location = location.to_proto()
409
+ trace_locations.append(proto_location)
410
+ except Exception as e:
411
+ raise MlflowException(
412
+ f"Invalid experiment ID format: {exp_id}. Error: {e!s}"
413
+ ) from e
414
+
415
+ # Create V3 request message using protobuf
416
+ request = SearchTracesV3(
417
+ locations=trace_locations,
418
+ filter=filter_string,
419
+ max_results=max_results,
420
+ order_by=order_by,
421
+ page_token=page_token,
422
+ )
423
+
424
+ req_body = message_to_json(request)
425
+ v3_endpoint = f"{_V3_TRACE_REST_API_PATH_PREFIX}/search"
426
+
427
+ try:
428
+ response_proto = self._call_endpoint(SearchTracesV3, req_body, v3_endpoint)
429
+ except MlflowException as e:
430
+ if e.error_code == databricks_pb2.ErrorCode.Name(databricks_pb2.ENDPOINT_NOT_FOUND):
431
+ _logger.debug(
432
+ "Server does not support SearchTracesV3 API yet. Falling back to V2 API."
433
+ )
434
+ response_proto = self._call_endpoint(SearchTraces, req_body)
435
+ else:
436
+ raise
437
+
438
+ trace_infos = [TraceInfo.from_proto(t) for t in response_proto.traces]
439
+ else:
440
+ response_proto = self._search_unified_traces(
441
+ model_id=model_id,
442
+ sql_warehouse_id=sql_warehouse_id,
443
+ experiment_ids=experiment_ids,
444
+ filter_string=filter_string,
445
+ max_results=max_results,
446
+ order_by=order_by,
447
+ page_token=page_token,
448
+ )
449
+ # Convert TraceInfo (v2) objects to TraceInfoV3 objects for consistency
450
+ trace_infos = [TraceInfo.from_proto(t) for t in response_proto.traces]
451
+ return trace_infos, response_proto.next_page_token or None
452
+
453
+ def _search_unified_traces(
454
+ self,
455
+ model_id: str,
456
+ sql_warehouse_id: str,
457
+ experiment_ids: list[str],
458
+ filter_string: Optional[str] = None,
459
+ max_results: int = SEARCH_TRACES_DEFAULT_MAX_RESULTS,
460
+ order_by: Optional[list[str]] = None,
461
+ page_token: Optional[str] = None,
462
+ ):
463
+ request = SearchUnifiedTraces(
464
+ model_id=model_id,
465
+ sql_warehouse_id=sql_warehouse_id,
466
+ experiment_ids=experiment_ids,
467
+ filter=filter_string,
468
+ max_results=max_results,
469
+ order_by=order_by,
470
+ page_token=page_token,
471
+ )
472
+ req_body = message_to_json(request)
473
+ return self._call_endpoint(SearchUnifiedTraces, req_body)
474
+
475
+ def set_trace_tag(self, trace_id: str, key: str, value: str):
476
+ """
477
+ Set a tag on the trace with the given trace_id.
478
+
479
+ Args:
480
+ trace_id: The ID of the trace.
481
+ key: The string key of the tag.
482
+ value: The string value of the tag.
483
+ """
484
+ # Always use v2 endpoint
485
+ req_body = message_to_json(SetTraceTag(key=key, value=value))
486
+ self._call_endpoint(SetTraceTag, req_body, endpoint=get_trace_tag_endpoint(trace_id))
487
+
488
+ def delete_trace_tag(self, trace_id: str, key: str):
489
+ """
490
+ Delete a tag on the trace with the given trace_id.
491
+
492
+ Args:
493
+ trace_id: The ID of the trace.
494
+ key: The string key of the tag.
495
+ """
496
+ # Always use v2 endpoint
497
+ req_body = message_to_json(DeleteTraceTag(key=key))
498
+ self._call_endpoint(DeleteTraceTag, req_body, endpoint=get_trace_tag_endpoint(trace_id))
499
+
500
+ def get_assessment(self, trace_id: str, assessment_id: str) -> Assessment:
501
+ """
502
+ Get an assessment entity from the backend store.
503
+ """
504
+ req_body = message_to_json(
505
+ GetAssessmentRequest(trace_id=trace_id, assessment_id=assessment_id)
506
+ )
507
+ response_proto = self._call_endpoint(
508
+ GetAssessmentRequest,
509
+ req_body,
510
+ endpoint=get_single_assessment_endpoint(trace_id, assessment_id),
511
+ )
512
+ return Assessment.from_proto(response_proto.assessment)
513
+
514
+ def create_assessment(self, assessment: Assessment) -> Assessment:
515
+ """
516
+ Create an assessment entity in the backend store.
517
+
518
+ Args:
519
+ assessment: The assessment to log (without an assessment_id).
520
+
521
+ Returns:
522
+ The created Assessment object.
523
+ """
524
+ req_body = message_to_json(CreateAssessment(assessment=assessment.to_proto()))
525
+ response_proto = self._call_endpoint(
526
+ CreateAssessment,
527
+ req_body,
528
+ endpoint=f"{_V3_TRACE_REST_API_PATH_PREFIX}/{assessment.trace_id}/assessments",
529
+ )
530
+ return Assessment.from_proto(response_proto.assessment)
531
+
532
+ def update_assessment(
533
+ self,
534
+ trace_id: str,
535
+ assessment_id: str,
536
+ name: Optional[str] = None,
537
+ expectation: Optional[Expectation] = None,
538
+ feedback: Optional[Feedback] = None,
539
+ rationale: Optional[str] = None,
540
+ metadata: Optional[dict[str, str]] = None,
541
+ ) -> Assessment:
542
+ """
543
+ Update an existing assessment entity in the backend store.
544
+
545
+ Args:
546
+ trace_id: The ID of the trace.
547
+ assessment_id: The ID of the assessment to update.
548
+ name: The updated name of the assessment.
549
+ expectation: The updated expectation value of the assessment.
550
+ feedback: The updated feedback value of the assessment.
551
+ rationale: The updated rationale of the feedback. Not applicable for expectations.
552
+ metadata: Additional metadata for the assessment.
553
+ """
554
+ if expectation is not None and feedback is not None:
555
+ raise MlflowException.invalid_parameter_value(
556
+ "Exactly one of `expectation` or `feedback` should be specified."
557
+ )
558
+
559
+ update = UpdateAssessment()
560
+
561
+ # The assessment object to be sent to the backend (only contains fields to update and IDs)
562
+ assessment = update.assessment
563
+ # Field mask specifies which fields to update.
564
+ mask = update.update_mask
565
+
566
+ assessment.assessment_id = assessment_id
567
+ assessment.trace_id = trace_id
568
+
569
+ if name is not None:
570
+ assessment.assessment_name = name
571
+ mask.paths.append("assessment_name")
572
+ if expectation is not None:
573
+ assessment.expectation.CopyFrom(expectation.to_proto())
574
+ mask.paths.append("expectation")
575
+ if feedback is not None:
576
+ assessment.feedback.CopyFrom(feedback.to_proto())
577
+ mask.paths.append("feedback")
578
+ if rationale is not None:
579
+ assessment.rationale = rationale
580
+ mask.paths.append("rationale")
581
+ if metadata is not None:
582
+ assessment.metadata.update(metadata)
583
+ mask.paths.append("metadata")
584
+
585
+ req_body = message_to_json(update)
586
+ response_proto = self._call_endpoint(
587
+ UpdateAssessment,
588
+ req_body,
589
+ endpoint=get_single_assessment_endpoint(trace_id, assessment_id),
590
+ )
591
+ return Assessment.from_proto(response_proto.assessment)
592
+
593
+ def delete_assessment(self, trace_id: str, assessment_id: str):
594
+ """
595
+ Delete an assessment associated with a trace.
596
+
597
+ Args:
598
+ trace_id: String ID of the trace.
599
+ assessment_id: String ID of the assessment to delete.
600
+ """
601
+ req_body = message_to_json(DeleteAssessment(trace_id=trace_id, assessment_id=assessment_id))
602
+ self._call_endpoint(
603
+ DeleteAssessment,
604
+ req_body,
605
+ endpoint=get_single_assessment_endpoint(trace_id, assessment_id),
606
+ )
607
+
608
+ def log_metric(self, run_id: str, metric: Metric):
609
+ """
610
+ Log a metric for the specified run
611
+
612
+ Args:
613
+ run_id: String id for the run
614
+ metric: Metric instance to log
615
+ """
616
+ req_body = message_to_json(
617
+ LogMetric(
618
+ run_uuid=run_id,
619
+ run_id=run_id,
620
+ key=metric.key,
621
+ value=metric.value,
622
+ timestamp=metric.timestamp,
623
+ step=metric.step,
624
+ model_id=metric.model_id,
625
+ dataset_name=metric.dataset_name,
626
+ dataset_digest=metric.dataset_digest,
627
+ )
628
+ )
629
+ self._call_endpoint(LogMetric, req_body)
630
+
631
+ def log_param(self, run_id, param):
632
+ """
633
+ Log a param for the specified run
634
+
635
+ Args:
636
+ run_id: String id for the run
637
+ param: Param instance to log
638
+ """
639
+ req_body = message_to_json(
640
+ LogParam(run_uuid=run_id, run_id=run_id, key=param.key, value=param.value)
641
+ )
642
+ self._call_endpoint(LogParam, req_body)
643
+
644
+ def set_experiment_tag(self, experiment_id, tag):
645
+ """
646
+ Set a tag for the specified experiment
647
+
648
+ Args:
649
+ experiment_id: String ID of the experiment
650
+ tag: ExperimentRunTag instance to log
651
+ """
652
+ req_body = message_to_json(
653
+ SetExperimentTag(experiment_id=experiment_id, key=tag.key, value=tag.value)
654
+ )
655
+ self._call_endpoint(SetExperimentTag, req_body)
656
+
657
+ def set_tag(self, run_id, tag):
658
+ """
659
+ Set a tag for the specified run
660
+
661
+ Args:
662
+ run_id: String ID of the run
663
+ tag: RunTag instance to log
664
+ """
665
+ req_body = message_to_json(
666
+ SetTag(run_uuid=run_id, run_id=run_id, key=tag.key, value=tag.value)
667
+ )
668
+ self._call_endpoint(SetTag, req_body)
669
+
670
+ def delete_tag(self, run_id, key):
671
+ """
672
+ Delete a tag from a run. This is irreversible.
673
+
674
+ Args:
675
+ run_id: String ID of the run.
676
+ key: Name of the tag.
677
+ """
678
+ req_body = message_to_json(DeleteTag(run_id=run_id, key=key))
679
+ self._call_endpoint(DeleteTag, req_body)
680
+
681
+ def get_metric_history(self, run_id, metric_key, max_results=None, page_token=None):
682
+ """
683
+ Return all logged values for a given metric.
684
+
685
+ Args:
686
+ run_id: Unique identifier for run.
687
+ metric_key: Metric name within the run.
688
+ max_results: Maximum number of metric history events (steps) to return per paged
689
+ query. Only supported in 'databricks' backend.
690
+ page_token: A Token specifying the next paginated set of results of metric history.
691
+
692
+ Returns:
693
+ A PagedList of :py:class:`mlflow.entities.Metric` entities if a paginated request
694
+ is made by setting ``max_results`` to a value other than ``None``, a List of
695
+ :py:class:`mlflow.entities.Metric` entities if ``max_results`` is None, else, if no
696
+ metrics of the ``metric_key`` have been logged to the ``run_id``, an empty list.
697
+ """
698
+ req_body = message_to_json(
699
+ GetMetricHistory(
700
+ run_uuid=run_id,
701
+ run_id=run_id,
702
+ metric_key=metric_key,
703
+ max_results=max_results,
704
+ page_token=page_token,
705
+ )
706
+ )
707
+ response_proto = self._call_endpoint(GetMetricHistory, req_body)
708
+
709
+ metric_history = [Metric.from_proto(metric) for metric in response_proto.metrics]
710
+ return PagedList(metric_history, response_proto.next_page_token or None)
711
+
712
+ def _search_runs(
713
+ self, experiment_ids, filter_string, run_view_type, max_results, order_by, page_token
714
+ ):
715
+ experiment_ids = [str(experiment_id) for experiment_id in experiment_ids]
716
+ sr = SearchRuns(
717
+ experiment_ids=experiment_ids,
718
+ filter=filter_string,
719
+ run_view_type=ViewType.to_proto(run_view_type),
720
+ max_results=max_results,
721
+ order_by=order_by,
722
+ page_token=page_token,
723
+ )
724
+ req_body = message_to_json(sr)
725
+ response_proto = self._call_endpoint(SearchRuns, req_body)
726
+ runs = [Run.from_proto(proto_run) for proto_run in response_proto.runs]
727
+ # If next_page_token is not set, we will see it as "". We need to convert this to None.
728
+ next_page_token = None
729
+ if response_proto.next_page_token:
730
+ next_page_token = response_proto.next_page_token
731
+ return runs, next_page_token
732
+
733
+ def delete_run(self, run_id):
734
+ req_body = message_to_json(DeleteRun(run_id=run_id))
735
+ self._call_endpoint(DeleteRun, req_body)
736
+
737
+ def restore_run(self, run_id):
738
+ req_body = message_to_json(RestoreRun(run_id=run_id))
739
+ self._call_endpoint(RestoreRun, req_body)
740
+
741
+ def get_experiment_by_name(self, experiment_name):
742
+ try:
743
+ req_body = message_to_json(GetExperimentByName(experiment_name=experiment_name))
744
+ response_proto = self._call_endpoint(GetExperimentByName, req_body)
745
+ return Experiment.from_proto(response_proto.experiment)
746
+ except MlflowException as e:
747
+ if e.error_code == databricks_pb2.ErrorCode.Name(
748
+ databricks_pb2.RESOURCE_DOES_NOT_EXIST
749
+ ):
750
+ return None
751
+ else:
752
+ raise
753
+
754
+ def log_batch(self, run_id, metrics, params, tags):
755
+ metric_protos = [metric.to_proto() for metric in metrics]
756
+ param_protos = [param.to_proto() for param in params]
757
+ tag_protos = [tag.to_proto() for tag in tags]
758
+ req_body = message_to_json(
759
+ LogBatch(metrics=metric_protos, params=param_protos, tags=tag_protos, run_id=run_id)
760
+ )
761
+ self._call_endpoint(LogBatch, req_body)
762
+
763
+ def record_logged_model(self, run_id, mlflow_model):
764
+ req_body = message_to_json(
765
+ LogModel(run_id=run_id, model_json=json.dumps(mlflow_model.get_tags_dict()))
766
+ )
767
+ self._call_endpoint(LogModel, req_body)
768
+
769
+ def create_logged_model(
770
+ self,
771
+ experiment_id: str,
772
+ name: Optional[str] = None,
773
+ source_run_id: Optional[str] = None,
774
+ tags: Optional[list[LoggedModelTag]] = None,
775
+ params: Optional[list[LoggedModelParameter]] = None,
776
+ model_type: Optional[str] = None,
777
+ ) -> LoggedModel:
778
+ """
779
+ Create a new logged model.
780
+
781
+ Args:
782
+ experiment_id: ID of the experiment to which the model belongs.
783
+ name: Name of the model. If not specified, a random name will be generated.
784
+ source_run_id: ID of the run that produced the model.
785
+ tags: Tags to set on the model.
786
+ params: Parameters to set on the model.
787
+ model_type: Type of the model.
788
+
789
+ Returns:
790
+ The created model.
791
+ """
792
+ # Include the first 100 params in the initial request
793
+ initial_params = []
794
+ remaining_params = []
795
+ if params:
796
+ initial_batch_size = _MLFLOW_CREATE_LOGGED_MODEL_PARAMS_BATCH_SIZE.get()
797
+ initial_params = params[:initial_batch_size]
798
+ remaining_params = params[initial_batch_size:]
799
+
800
+ req_body = message_to_json(
801
+ CreateLoggedModel(
802
+ experiment_id=experiment_id,
803
+ name=name,
804
+ model_type=model_type,
805
+ source_run_id=source_run_id,
806
+ params=[p.to_proto() for p in initial_params],
807
+ tags=[t.to_proto() for t in tags or []],
808
+ )
809
+ )
810
+
811
+ response_proto = self._call_endpoint(CreateLoggedModel, req_body)
812
+ model = LoggedModel.from_proto(response_proto.model)
813
+
814
+ # Log remaining params if there are any
815
+ if remaining_params:
816
+ self.log_logged_model_params(model_id=model.model_id, params=remaining_params)
817
+ model = self.get_logged_model(model_id=model.model_id)
818
+
819
+ return model
820
+
821
+ def log_logged_model_params(self, model_id: str, params: list[LoggedModelParameter]) -> None:
822
+ """
823
+ Log parameters for a logged model in batches of 100.
824
+
825
+ Args:
826
+ model_id: ID of the model to log parameters for.
827
+ params: List of parameters to log.
828
+
829
+ Returns:
830
+ None
831
+ """
832
+ # Process params in batches to avoid exceeding per-request backend limits
833
+ batch_size = _MLFLOW_LOG_LOGGED_MODEL_PARAMS_BATCH_SIZE.get()
834
+ endpoint = get_logged_model_endpoint(model_id)
835
+ for i in range(0, len(params), batch_size):
836
+ batch = params[i : i + batch_size]
837
+ req_body = message_to_json(
838
+ LogLoggedModelParamsRequest(
839
+ model_id=model_id,
840
+ params=[p.to_proto() for p in batch],
841
+ )
842
+ )
843
+ self._call_endpoint(
844
+ LogLoggedModelParamsRequest, json_body=req_body, endpoint=f"{endpoint}/params"
845
+ )
846
+
847
+ def get_logged_model(self, model_id: str) -> LoggedModel:
848
+ """
849
+ Fetch the logged model with the specified ID.
850
+
851
+ Args:
852
+ model_id: ID of the model to fetch.
853
+
854
+ Returns:
855
+ The fetched model.
856
+ """
857
+ endpoint = get_logged_model_endpoint(model_id)
858
+ response_proto = self._call_endpoint(GetLoggedModel, endpoint=endpoint)
859
+ return LoggedModel.from_proto(response_proto.model)
860
+
861
+ def delete_logged_model(self, model_id) -> None:
862
+ request = DeleteLoggedModel(model_id=model_id)
863
+ endpoint = get_logged_model_endpoint(model_id)
864
+ self._call_endpoint(
865
+ DeleteLoggedModel, endpoint=endpoint, json_body=message_to_json(request)
866
+ )
867
+
868
def search_logged_models(
    self,
    experiment_ids: list[str],
    filter_string: Optional[str] = None,
    datasets: Optional[list[dict[str, Any]]] = None,
    max_results: Optional[int] = None,
    order_by: Optional[list[dict[str, Any]]] = None,
    page_token: Optional[str] = None,
) -> PagedList[LoggedModel]:
    """
    Search for logged models matching the given criteria.

    Args:
        experiment_ids: Experiment ids that scope the search.
        filter_string: Optional search filter expression.
        datasets: Optional list of dicts naming datasets on which metric filters apply.
            Supported keys:

            dataset_name (str): Required. Name of the dataset.
            dataset_digest (str): Optional. Digest of the dataset.
        max_results: Maximum number of logged models to return.
        order_by: Optional list of dicts describing the result ordering.
            Supported keys:

            field_name (str): Required. Field to order by, e.g. "metrics.accuracy".
            ascending (bool): Optional. Sort ascending; defaults to True when omitted.
            dataset_name (str): Optional. When ``field_name`` refers to a metric, restricts
                ordering to metrics associated with this dataset name. May only be set if
                ``field_name`` refers to a metric.
            dataset_digest (str): Optional. When ``field_name`` refers to a metric, restricts
                ordering to metrics associated with this dataset name and digest. May only
                be set if ``dataset_name`` is also set.
        page_token: Token identifying the page of results to fetch.

    Returns:
        A :py:class:`PagedList <mlflow.store.entities.PagedList>` of
        :py:class:`LoggedModel <mlflow.entities.LoggedModel>` objects. Its token is
        ``None`` when the response carries an empty ``next_page_token``.
    """
    # Translate the plain-dict dataset filters into request sub-messages.
    dataset_filters = [
        SearchLoggedModels.Dataset(
            dataset_name=spec["dataset_name"],
            dataset_digest=spec.get("dataset_digest"),
        )
        for spec in (datasets or [])
    ]
    # Translate the plain-dict ordering specs; "ascending" falls back to True.
    orderings = [
        SearchLoggedModels.OrderBy(
            field_name=spec["field_name"],
            ascending=spec.get("ascending", True),
            dataset_name=spec.get("dataset_name"),
            dataset_digest=spec.get("dataset_digest"),
        )
        for spec in (order_by or [])
    ]
    request = SearchLoggedModels(
        experiment_ids=experiment_ids,
        filter=filter_string,
        datasets=dataset_filters,
        max_results=max_results,
        order_by=orderings,
        page_token=page_token,
    )
    response_proto = self._call_endpoint(SearchLoggedModels, message_to_json(request))
    found_models = list(map(LoggedModel.from_proto, response_proto.models))
    # An empty token string from the server is normalized to None.
    next_token = response_proto.next_page_token or None
    return PagedList(found_models, next_token)
935
+
936
def finalize_logged_model(self, model_id: str, status: LoggedModelStatus) -> LoggedModel:
    """
    Finalize a logged model by setting its status.

    Sends a ``FinalizeLoggedModel`` request to the model-specific endpoint.

    Args:
        model_id: ID of the model to finalize.
        status: Final status to record on the model.

    Returns:
        The updated model, as decoded from the response.
    """
    model_endpoint = get_logged_model_endpoint(model_id)
    request = FinalizeLoggedModel(model_id=model_id, status=status.to_proto())
    response_proto = self._call_endpoint(
        FinalizeLoggedModel,
        json_body=message_to_json(request),
        endpoint=model_endpoint,
    )
    return LoggedModel.from_proto(response_proto.model)
955
+
956
def set_logged_model_tags(self, model_id: str, tags: list[LoggedModelTag]) -> None:
    """
    Attach tags to a logged model.

    The request is sent to the ``/tags`` sub-path of the model's endpoint. The model id
    is carried in the path, not in the request body.

    Args:
        model_id: ID of the model to tag.
        tags: Tags to set on the model.

    Returns:
        None
    """
    tags_endpoint = f"{get_logged_model_endpoint(model_id)}/tags"
    tag_protos = [tag.to_proto() for tag in tags]
    self._call_endpoint(
        SetLoggedModelTags,
        json_body=message_to_json(SetLoggedModelTags(tags=tag_protos)),
        endpoint=tags_endpoint,
    )
970
+
971
def delete_logged_model_tag(self, model_id: str, key: str) -> None:
    """
    Delete a tag from the specified logged model.

    Issues a ``DeleteLoggedModelTag`` call against ``<model endpoint>/tags/<key>``
    without a request body. The response is discarded.

    Args:
        model_id: ID of the model.
        key: Key of the tag to delete.

    Returns:
        None
    """
    endpoint = get_logged_model_endpoint(model_id)
    # NOTE(review): ``key`` is interpolated into the URL path unescaped; keys containing
    # "/" may not route to the intended tag on the server -- confirm server behavior.
    self._call_endpoint(DeleteLoggedModelTag, endpoint=f"{endpoint}/tags/{key}")
984
+
985
def log_inputs(
    self,
    run_id: str,
    datasets: Optional[list[DatasetInput]] = None,
    models: Optional[list[LoggedModelInput]] = None,
):
    """
    Record inputs (datasets and/or models) against a run.

    Args:
        run_id: String id of the run.
        datasets: Optional :py:class:`mlflow.entities.DatasetInput` instances to log as
            inputs to the run. ``None`` is treated as an empty list.
        models: Optional :py:class:`mlflow.entities.LoggedModelInput` instances to log.
            ``None`` is treated as an empty list.

    Returns:
        None.
    """
    request = LogInputs(
        run_id=run_id,
        datasets=[dataset_input.to_proto() for dataset_input in (datasets or [])],
        models=[model_input.to_proto() for model_input in (models or [])],
    )
    self._call_endpoint(LogInputs, message_to_json(request))
1013
+
1014
def log_outputs(self, run_id: str, models: list[LoggedModelOutput]):
    """
    Record outputs (currently models) produced by a run.

    Args:
        run_id: String id of the run.
        models: :py:class:`mlflow.entities.LoggedModelOutput` instances to log as outputs
            of the run.

    Returns:
        None.
    """
    output_protos = [model_output.to_proto() for model_output in models]
    request = LogOutputs(run_id=run_id, models=output_protos)
    self._call_endpoint(LogOutputs, message_to_json(request))
1028
+
1029
+ ############################################################################################
1030
+ # Deprecated MLflow Tracing APIs. Kept for backward compatibility but do not use.
1031
+ ############################################################################################
1032
def deprecated_start_trace_v2(
    self,
    experiment_id: str,
    timestamp_ms: int,
    request_metadata: dict[str, str],
    tags: dict[str, str],
) -> TraceInfoV2:
    """
    DEPRECATED. DO NOT USE.

    Create the initial TraceInfo record in the backend store.

    Args:
        experiment_id: Id of the experiment the trace belongs to. It is sent as a string.
        timestamp_ms: Trace start time, in milliseconds since the UNIX epoch.
        request_metadata: Trace metadata. Values are converted with ``str()``.
        tags: Trace tags. Values are converted with ``str()``.

    Returns:
        The created TraceInfo object.
    """

    def _key_value_protos(proto_cls, mapping):
        # Build one proto per (key, value) pair, in the mapping's iteration order.
        protos = []
        for item_key, item_value in mapping.items():
            proto = proto_cls()
            proto.key = item_key
            proto.value = str(item_value)
            protos.append(proto)
        return protos

    metadata_protos = _key_value_protos(TraceRequestMetadata, request_metadata)
    tag_protos = _key_value_protos(TraceTag, tags)

    request = StartTrace(
        experiment_id=str(experiment_id),
        timestamp_ms=timestamp_ms,
        request_metadata=metadata_protos,
        tags=tag_protos,
    )
    response_proto = self._call_endpoint(StartTrace, message_to_json(request))
    return TraceInfoV2.from_proto(response_proto.trace_info)
1077
+
1078
def deprecated_end_trace_v2(
    self,
    request_id: str,
    timestamp_ms: int,
    status: TraceStatus,
    request_metadata: dict[str, str],
    tags: dict[str, str],
) -> TraceInfoV2:
    """
    DEPRECATED. DO NOT USE.

    Complete a trace by updating its TraceInfo record in the backend store.

    Args:
        request_id: Unique string identifier of the trace.
        timestamp_ms: Trace end time, in milliseconds. The execution time field in the
            TraceInfo will be calculated by subtracting the start time from this.
        status: Status of the trace.
        request_metadata: Trace metadata; merged with metadata logged at start_trace.
            Values are converted with ``str()``.
        tags: Trace tags; merged with tags logged via start_trace or set_trace_tag.
            Values are converted with ``str()``.

    Returns:
        The updated TraceInfo object.
    """

    def _key_value_protos(proto_cls, mapping):
        # Build one proto per (key, value) pair, in the mapping's iteration order.
        protos = []
        for item_key, item_value in mapping.items():
            proto = proto_cls()
            proto.key = item_key
            proto.value = str(item_value)
            protos.append(proto)
        return protos

    metadata_protos = _key_value_protos(TraceRequestMetadata, request_metadata)
    tag_protos = _key_value_protos(TraceTag, tags)

    request = EndTrace(
        request_id=request_id,
        timestamp_ms=timestamp_ms,
        status=status.to_proto(),
        request_metadata=metadata_protos,
        tags=tag_protos,
    )
    req_body = message_to_json(request)
    # EndTrace uses a dynamic path built from the request_id. It deliberately targets
    # the v2 prefix (not v3) to stay compatible with servers expecting the v2 route.
    trace_endpoint = f"{_REST_API_PATH_PREFIX}/mlflow/traces/{request_id}"
    response_proto = self._call_endpoint(EndTrace, req_body, endpoint=trace_endpoint)
    return TraceInfoV2.from_proto(response_proto.trace_info)