genesis_flow-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (645)
  1. genesis_flow-1.0.0.dist-info/METADATA +822 -0
  2. genesis_flow-1.0.0.dist-info/RECORD +645 -0
  3. genesis_flow-1.0.0.dist-info/WHEEL +5 -0
  4. genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
  5. genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
  6. genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
  7. mlflow/__init__.py +367 -0
  8. mlflow/__main__.py +3 -0
  9. mlflow/ag2/__init__.py +56 -0
  10. mlflow/ag2/ag2_logger.py +294 -0
  11. mlflow/anthropic/__init__.py +40 -0
  12. mlflow/anthropic/autolog.py +129 -0
  13. mlflow/anthropic/chat.py +144 -0
  14. mlflow/artifacts/__init__.py +268 -0
  15. mlflow/autogen/__init__.py +144 -0
  16. mlflow/autogen/chat.py +142 -0
  17. mlflow/azure/__init__.py +26 -0
  18. mlflow/azure/auth_handler.py +257 -0
  19. mlflow/azure/client.py +319 -0
  20. mlflow/azure/config.py +120 -0
  21. mlflow/azure/connection_factory.py +340 -0
  22. mlflow/azure/exceptions.py +27 -0
  23. mlflow/azure/stores.py +327 -0
  24. mlflow/azure/utils.py +183 -0
  25. mlflow/bedrock/__init__.py +45 -0
  26. mlflow/bedrock/_autolog.py +202 -0
  27. mlflow/bedrock/chat.py +122 -0
  28. mlflow/bedrock/stream.py +160 -0
  29. mlflow/bedrock/utils.py +43 -0
  30. mlflow/cli.py +707 -0
  31. mlflow/client.py +12 -0
  32. mlflow/config/__init__.py +56 -0
  33. mlflow/crewai/__init__.py +79 -0
  34. mlflow/crewai/autolog.py +253 -0
  35. mlflow/crewai/chat.py +29 -0
  36. mlflow/data/__init__.py +75 -0
  37. mlflow/data/artifact_dataset_sources.py +170 -0
  38. mlflow/data/code_dataset_source.py +40 -0
  39. mlflow/data/dataset.py +123 -0
  40. mlflow/data/dataset_registry.py +168 -0
  41. mlflow/data/dataset_source.py +110 -0
  42. mlflow/data/dataset_source_registry.py +219 -0
  43. mlflow/data/delta_dataset_source.py +167 -0
  44. mlflow/data/digest_utils.py +108 -0
  45. mlflow/data/evaluation_dataset.py +562 -0
  46. mlflow/data/filesystem_dataset_source.py +81 -0
  47. mlflow/data/http_dataset_source.py +145 -0
  48. mlflow/data/huggingface_dataset.py +258 -0
  49. mlflow/data/huggingface_dataset_source.py +118 -0
  50. mlflow/data/meta_dataset.py +104 -0
  51. mlflow/data/numpy_dataset.py +223 -0
  52. mlflow/data/pandas_dataset.py +231 -0
  53. mlflow/data/polars_dataset.py +352 -0
  54. mlflow/data/pyfunc_dataset_mixin.py +31 -0
  55. mlflow/data/schema.py +76 -0
  56. mlflow/data/sources.py +1 -0
  57. mlflow/data/spark_dataset.py +406 -0
  58. mlflow/data/spark_dataset_source.py +74 -0
  59. mlflow/data/spark_delta_utils.py +118 -0
  60. mlflow/data/tensorflow_dataset.py +350 -0
  61. mlflow/data/uc_volume_dataset_source.py +81 -0
  62. mlflow/db.py +27 -0
  63. mlflow/dspy/__init__.py +17 -0
  64. mlflow/dspy/autolog.py +197 -0
  65. mlflow/dspy/callback.py +398 -0
  66. mlflow/dspy/constant.py +1 -0
  67. mlflow/dspy/load.py +93 -0
  68. mlflow/dspy/save.py +393 -0
  69. mlflow/dspy/util.py +109 -0
  70. mlflow/dspy/wrapper.py +226 -0
  71. mlflow/entities/__init__.py +104 -0
  72. mlflow/entities/_mlflow_object.py +52 -0
  73. mlflow/entities/assessment.py +545 -0
  74. mlflow/entities/assessment_error.py +80 -0
  75. mlflow/entities/assessment_source.py +141 -0
  76. mlflow/entities/dataset.py +92 -0
  77. mlflow/entities/dataset_input.py +51 -0
  78. mlflow/entities/dataset_summary.py +62 -0
  79. mlflow/entities/document.py +48 -0
  80. mlflow/entities/experiment.py +109 -0
  81. mlflow/entities/experiment_tag.py +35 -0
  82. mlflow/entities/file_info.py +45 -0
  83. mlflow/entities/input_tag.py +35 -0
  84. mlflow/entities/lifecycle_stage.py +35 -0
  85. mlflow/entities/logged_model.py +228 -0
  86. mlflow/entities/logged_model_input.py +26 -0
  87. mlflow/entities/logged_model_output.py +32 -0
  88. mlflow/entities/logged_model_parameter.py +46 -0
  89. mlflow/entities/logged_model_status.py +74 -0
  90. mlflow/entities/logged_model_tag.py +33 -0
  91. mlflow/entities/metric.py +200 -0
  92. mlflow/entities/model_registry/__init__.py +29 -0
  93. mlflow/entities/model_registry/_model_registry_entity.py +13 -0
  94. mlflow/entities/model_registry/model_version.py +243 -0
  95. mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
  96. mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
  97. mlflow/entities/model_registry/model_version_search.py +25 -0
  98. mlflow/entities/model_registry/model_version_stages.py +25 -0
  99. mlflow/entities/model_registry/model_version_status.py +35 -0
  100. mlflow/entities/model_registry/model_version_tag.py +35 -0
  101. mlflow/entities/model_registry/prompt.py +73 -0
  102. mlflow/entities/model_registry/prompt_version.py +244 -0
  103. mlflow/entities/model_registry/registered_model.py +175 -0
  104. mlflow/entities/model_registry/registered_model_alias.py +35 -0
  105. mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
  106. mlflow/entities/model_registry/registered_model_search.py +25 -0
  107. mlflow/entities/model_registry/registered_model_tag.py +35 -0
  108. mlflow/entities/multipart_upload.py +74 -0
  109. mlflow/entities/param.py +49 -0
  110. mlflow/entities/run.py +97 -0
  111. mlflow/entities/run_data.py +84 -0
  112. mlflow/entities/run_info.py +188 -0
  113. mlflow/entities/run_inputs.py +59 -0
  114. mlflow/entities/run_outputs.py +43 -0
  115. mlflow/entities/run_status.py +41 -0
  116. mlflow/entities/run_tag.py +36 -0
  117. mlflow/entities/source_type.py +31 -0
  118. mlflow/entities/span.py +774 -0
  119. mlflow/entities/span_event.py +96 -0
  120. mlflow/entities/span_status.py +102 -0
  121. mlflow/entities/trace.py +317 -0
  122. mlflow/entities/trace_data.py +71 -0
  123. mlflow/entities/trace_info.py +220 -0
  124. mlflow/entities/trace_info_v2.py +162 -0
  125. mlflow/entities/trace_location.py +173 -0
  126. mlflow/entities/trace_state.py +39 -0
  127. mlflow/entities/trace_status.py +68 -0
  128. mlflow/entities/view_type.py +51 -0
  129. mlflow/environment_variables.py +866 -0
  130. mlflow/evaluation/__init__.py +16 -0
  131. mlflow/evaluation/assessment.py +369 -0
  132. mlflow/evaluation/evaluation.py +411 -0
  133. mlflow/evaluation/evaluation_tag.py +61 -0
  134. mlflow/evaluation/fluent.py +48 -0
  135. mlflow/evaluation/utils.py +201 -0
  136. mlflow/exceptions.py +213 -0
  137. mlflow/experiments.py +140 -0
  138. mlflow/gemini/__init__.py +81 -0
  139. mlflow/gemini/autolog.py +186 -0
  140. mlflow/gemini/chat.py +261 -0
  141. mlflow/genai/__init__.py +71 -0
  142. mlflow/genai/datasets/__init__.py +67 -0
  143. mlflow/genai/datasets/evaluation_dataset.py +131 -0
  144. mlflow/genai/evaluation/__init__.py +3 -0
  145. mlflow/genai/evaluation/base.py +411 -0
  146. mlflow/genai/evaluation/constant.py +23 -0
  147. mlflow/genai/evaluation/utils.py +244 -0
  148. mlflow/genai/judges/__init__.py +21 -0
  149. mlflow/genai/judges/databricks.py +404 -0
  150. mlflow/genai/label_schemas/__init__.py +153 -0
  151. mlflow/genai/label_schemas/label_schemas.py +209 -0
  152. mlflow/genai/labeling/__init__.py +159 -0
  153. mlflow/genai/labeling/labeling.py +250 -0
  154. mlflow/genai/optimize/__init__.py +13 -0
  155. mlflow/genai/optimize/base.py +198 -0
  156. mlflow/genai/optimize/optimizers/__init__.py +4 -0
  157. mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
  158. mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
  159. mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
  160. mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
  161. mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
  162. mlflow/genai/optimize/types.py +75 -0
  163. mlflow/genai/optimize/util.py +30 -0
  164. mlflow/genai/prompts/__init__.py +206 -0
  165. mlflow/genai/scheduled_scorers.py +431 -0
  166. mlflow/genai/scorers/__init__.py +26 -0
  167. mlflow/genai/scorers/base.py +492 -0
  168. mlflow/genai/scorers/builtin_scorers.py +765 -0
  169. mlflow/genai/scorers/scorer_utils.py +138 -0
  170. mlflow/genai/scorers/validation.py +165 -0
  171. mlflow/genai/utils/data_validation.py +146 -0
  172. mlflow/genai/utils/enum_utils.py +23 -0
  173. mlflow/genai/utils/trace_utils.py +211 -0
  174. mlflow/groq/__init__.py +42 -0
  175. mlflow/groq/_groq_autolog.py +74 -0
  176. mlflow/johnsnowlabs/__init__.py +888 -0
  177. mlflow/langchain/__init__.py +24 -0
  178. mlflow/langchain/api_request_parallel_processor.py +330 -0
  179. mlflow/langchain/autolog.py +147 -0
  180. mlflow/langchain/chat_agent_langgraph.py +340 -0
  181. mlflow/langchain/constant.py +1 -0
  182. mlflow/langchain/constants.py +1 -0
  183. mlflow/langchain/databricks_dependencies.py +444 -0
  184. mlflow/langchain/langchain_tracer.py +597 -0
  185. mlflow/langchain/model.py +919 -0
  186. mlflow/langchain/output_parsers.py +142 -0
  187. mlflow/langchain/retriever_chain.py +153 -0
  188. mlflow/langchain/runnables.py +527 -0
  189. mlflow/langchain/utils/chat.py +402 -0
  190. mlflow/langchain/utils/logging.py +671 -0
  191. mlflow/langchain/utils/serialization.py +36 -0
  192. mlflow/legacy_databricks_cli/__init__.py +0 -0
  193. mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
  194. mlflow/legacy_databricks_cli/configure/provider.py +482 -0
  195. mlflow/litellm/__init__.py +175 -0
  196. mlflow/llama_index/__init__.py +22 -0
  197. mlflow/llama_index/autolog.py +55 -0
  198. mlflow/llama_index/chat.py +43 -0
  199. mlflow/llama_index/constant.py +1 -0
  200. mlflow/llama_index/model.py +577 -0
  201. mlflow/llama_index/pyfunc_wrapper.py +332 -0
  202. mlflow/llama_index/serialize_objects.py +188 -0
  203. mlflow/llama_index/tracer.py +561 -0
  204. mlflow/metrics/__init__.py +479 -0
  205. mlflow/metrics/base.py +39 -0
  206. mlflow/metrics/genai/__init__.py +25 -0
  207. mlflow/metrics/genai/base.py +101 -0
  208. mlflow/metrics/genai/genai_metric.py +771 -0
  209. mlflow/metrics/genai/metric_definitions.py +450 -0
  210. mlflow/metrics/genai/model_utils.py +371 -0
  211. mlflow/metrics/genai/prompt_template.py +68 -0
  212. mlflow/metrics/genai/prompts/__init__.py +0 -0
  213. mlflow/metrics/genai/prompts/v1.py +422 -0
  214. mlflow/metrics/genai/utils.py +6 -0
  215. mlflow/metrics/metric_definitions.py +619 -0
  216. mlflow/mismatch.py +34 -0
  217. mlflow/mistral/__init__.py +34 -0
  218. mlflow/mistral/autolog.py +71 -0
  219. mlflow/mistral/chat.py +135 -0
  220. mlflow/ml_package_versions.py +452 -0
  221. mlflow/models/__init__.py +97 -0
  222. mlflow/models/auth_policy.py +83 -0
  223. mlflow/models/cli.py +354 -0
  224. mlflow/models/container/__init__.py +294 -0
  225. mlflow/models/container/scoring_server/__init__.py +0 -0
  226. mlflow/models/container/scoring_server/nginx.conf +39 -0
  227. mlflow/models/dependencies_schemas.py +287 -0
  228. mlflow/models/display_utils.py +158 -0
  229. mlflow/models/docker_utils.py +211 -0
  230. mlflow/models/evaluation/__init__.py +23 -0
  231. mlflow/models/evaluation/_shap_patch.py +64 -0
  232. mlflow/models/evaluation/artifacts.py +194 -0
  233. mlflow/models/evaluation/base.py +1811 -0
  234. mlflow/models/evaluation/calibration_curve.py +109 -0
  235. mlflow/models/evaluation/default_evaluator.py +996 -0
  236. mlflow/models/evaluation/deprecated.py +23 -0
  237. mlflow/models/evaluation/evaluator_registry.py +80 -0
  238. mlflow/models/evaluation/evaluators/classifier.py +704 -0
  239. mlflow/models/evaluation/evaluators/default.py +233 -0
  240. mlflow/models/evaluation/evaluators/regressor.py +96 -0
  241. mlflow/models/evaluation/evaluators/shap.py +296 -0
  242. mlflow/models/evaluation/lift_curve.py +178 -0
  243. mlflow/models/evaluation/utils/metric.py +123 -0
  244. mlflow/models/evaluation/utils/trace.py +179 -0
  245. mlflow/models/evaluation/validation.py +434 -0
  246. mlflow/models/flavor_backend.py +93 -0
  247. mlflow/models/flavor_backend_registry.py +53 -0
  248. mlflow/models/model.py +1639 -0
  249. mlflow/models/model_config.py +150 -0
  250. mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
  251. mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
  252. mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
  253. mlflow/models/python_api.py +369 -0
  254. mlflow/models/rag_signatures.py +128 -0
  255. mlflow/models/resources.py +321 -0
  256. mlflow/models/signature.py +662 -0
  257. mlflow/models/utils.py +2054 -0
  258. mlflow/models/wheeled_model.py +280 -0
  259. mlflow/openai/__init__.py +57 -0
  260. mlflow/openai/_agent_tracer.py +364 -0
  261. mlflow/openai/api_request_parallel_processor.py +131 -0
  262. mlflow/openai/autolog.py +509 -0
  263. mlflow/openai/constant.py +1 -0
  264. mlflow/openai/model.py +824 -0
  265. mlflow/openai/utils/chat_schema.py +367 -0
  266. mlflow/optuna/__init__.py +3 -0
  267. mlflow/optuna/storage.py +646 -0
  268. mlflow/plugins/__init__.py +72 -0
  269. mlflow/plugins/base.py +358 -0
  270. mlflow/plugins/builtin/__init__.py +24 -0
  271. mlflow/plugins/builtin/pytorch_plugin.py +150 -0
  272. mlflow/plugins/builtin/sklearn_plugin.py +158 -0
  273. mlflow/plugins/builtin/transformers_plugin.py +187 -0
  274. mlflow/plugins/cli.py +321 -0
  275. mlflow/plugins/discovery.py +340 -0
  276. mlflow/plugins/manager.py +465 -0
  277. mlflow/plugins/registry.py +316 -0
  278. mlflow/plugins/templates/framework_plugin_template.py +329 -0
  279. mlflow/prompt/constants.py +20 -0
  280. mlflow/prompt/promptlab_model.py +197 -0
  281. mlflow/prompt/registry_utils.py +248 -0
  282. mlflow/promptflow/__init__.py +495 -0
  283. mlflow/protos/__init__.py +0 -0
  284. mlflow/protos/assessments_pb2.py +174 -0
  285. mlflow/protos/databricks_artifacts_pb2.py +489 -0
  286. mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
  287. mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
  288. mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
  289. mlflow/protos/databricks_pb2.py +267 -0
  290. mlflow/protos/databricks_trace_server_pb2.py +374 -0
  291. mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
  292. mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
  293. mlflow/protos/facet_feature_statistics_pb2.py +296 -0
  294. mlflow/protos/internal_pb2.py +77 -0
  295. mlflow/protos/mlflow_artifacts_pb2.py +336 -0
  296. mlflow/protos/model_registry_pb2.py +1073 -0
  297. mlflow/protos/scalapb/__init__.py +0 -0
  298. mlflow/protos/scalapb/scalapb_pb2.py +104 -0
  299. mlflow/protos/service_pb2.py +2600 -0
  300. mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
  301. mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
  302. mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
  303. mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
  304. mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
  305. mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
  306. mlflow/py.typed +0 -0
  307. mlflow/pydantic_ai/__init__.py +57 -0
  308. mlflow/pydantic_ai/autolog.py +173 -0
  309. mlflow/pyfunc/__init__.py +3844 -0
  310. mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
  311. mlflow/pyfunc/backend.py +523 -0
  312. mlflow/pyfunc/context.py +78 -0
  313. mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
  314. mlflow/pyfunc/loaders/__init__.py +7 -0
  315. mlflow/pyfunc/loaders/chat_agent.py +117 -0
  316. mlflow/pyfunc/loaders/chat_model.py +125 -0
  317. mlflow/pyfunc/loaders/code_model.py +31 -0
  318. mlflow/pyfunc/loaders/responses_agent.py +112 -0
  319. mlflow/pyfunc/mlserver.py +46 -0
  320. mlflow/pyfunc/model.py +1473 -0
  321. mlflow/pyfunc/scoring_server/__init__.py +604 -0
  322. mlflow/pyfunc/scoring_server/app.py +7 -0
  323. mlflow/pyfunc/scoring_server/client.py +146 -0
  324. mlflow/pyfunc/spark_model_cache.py +48 -0
  325. mlflow/pyfunc/stdin_server.py +44 -0
  326. mlflow/pyfunc/utils/__init__.py +3 -0
  327. mlflow/pyfunc/utils/data_validation.py +224 -0
  328. mlflow/pyfunc/utils/environment.py +22 -0
  329. mlflow/pyfunc/utils/input_converter.py +47 -0
  330. mlflow/pyfunc/utils/serving_data_parser.py +11 -0
  331. mlflow/pytorch/__init__.py +1171 -0
  332. mlflow/pytorch/_lightning_autolog.py +580 -0
  333. mlflow/pytorch/_pytorch_autolog.py +50 -0
  334. mlflow/pytorch/pickle_module.py +35 -0
  335. mlflow/rfunc/__init__.py +42 -0
  336. mlflow/rfunc/backend.py +134 -0
  337. mlflow/runs.py +89 -0
  338. mlflow/server/__init__.py +302 -0
  339. mlflow/server/auth/__init__.py +1224 -0
  340. mlflow/server/auth/__main__.py +4 -0
  341. mlflow/server/auth/basic_auth.ini +6 -0
  342. mlflow/server/auth/cli.py +11 -0
  343. mlflow/server/auth/client.py +537 -0
  344. mlflow/server/auth/config.py +34 -0
  345. mlflow/server/auth/db/__init__.py +0 -0
  346. mlflow/server/auth/db/cli.py +18 -0
  347. mlflow/server/auth/db/migrations/__init__.py +0 -0
  348. mlflow/server/auth/db/migrations/alembic.ini +110 -0
  349. mlflow/server/auth/db/migrations/env.py +76 -0
  350. mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
  351. mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
  352. mlflow/server/auth/db/models.py +67 -0
  353. mlflow/server/auth/db/utils.py +37 -0
  354. mlflow/server/auth/entities.py +165 -0
  355. mlflow/server/auth/logo.py +14 -0
  356. mlflow/server/auth/permissions.py +65 -0
  357. mlflow/server/auth/routes.py +18 -0
  358. mlflow/server/auth/sqlalchemy_store.py +263 -0
  359. mlflow/server/graphql/__init__.py +0 -0
  360. mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
  361. mlflow/server/graphql/graphql_custom_scalars.py +24 -0
  362. mlflow/server/graphql/graphql_errors.py +15 -0
  363. mlflow/server/graphql/graphql_no_batching.py +89 -0
  364. mlflow/server/graphql/graphql_schema_extensions.py +74 -0
  365. mlflow/server/handlers.py +3217 -0
  366. mlflow/server/prometheus_exporter.py +17 -0
  367. mlflow/server/validation.py +30 -0
  368. mlflow/shap/__init__.py +691 -0
  369. mlflow/sklearn/__init__.py +1994 -0
  370. mlflow/sklearn/utils.py +1041 -0
  371. mlflow/smolagents/__init__.py +66 -0
  372. mlflow/smolagents/autolog.py +139 -0
  373. mlflow/smolagents/chat.py +29 -0
  374. mlflow/store/__init__.py +10 -0
  375. mlflow/store/_unity_catalog/__init__.py +1 -0
  376. mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
  377. mlflow/store/_unity_catalog/lineage/constants.py +2 -0
  378. mlflow/store/_unity_catalog/registry/__init__.py +6 -0
  379. mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
  380. mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
  381. mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
  382. mlflow/store/_unity_catalog/registry/utils.py +121 -0
  383. mlflow/store/artifact/__init__.py +0 -0
  384. mlflow/store/artifact/artifact_repo.py +472 -0
  385. mlflow/store/artifact/artifact_repository_registry.py +154 -0
  386. mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
  387. mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
  388. mlflow/store/artifact/cli.py +141 -0
  389. mlflow/store/artifact/cloud_artifact_repo.py +332 -0
  390. mlflow/store/artifact/databricks_artifact_repo.py +729 -0
  391. mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
  392. mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
  393. mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
  394. mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
  395. mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
  396. mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
  397. mlflow/store/artifact/ftp_artifact_repo.py +132 -0
  398. mlflow/store/artifact/gcs_artifact_repo.py +296 -0
  399. mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
  400. mlflow/store/artifact/http_artifact_repo.py +218 -0
  401. mlflow/store/artifact/local_artifact_repo.py +142 -0
  402. mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
  403. mlflow/store/artifact/models_artifact_repo.py +259 -0
  404. mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
  405. mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
  406. mlflow/store/artifact/r2_artifact_repo.py +70 -0
  407. mlflow/store/artifact/runs_artifact_repo.py +265 -0
  408. mlflow/store/artifact/s3_artifact_repo.py +330 -0
  409. mlflow/store/artifact/sftp_artifact_repo.py +141 -0
  410. mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
  411. mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
  412. mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
  413. mlflow/store/artifact/utils/__init__.py +0 -0
  414. mlflow/store/artifact/utils/models.py +148 -0
  415. mlflow/store/db/__init__.py +0 -0
  416. mlflow/store/db/base_sql_model.py +3 -0
  417. mlflow/store/db/db_types.py +10 -0
  418. mlflow/store/db/utils.py +314 -0
  419. mlflow/store/db_migrations/__init__.py +0 -0
  420. mlflow/store/db_migrations/alembic.ini +74 -0
  421. mlflow/store/db_migrations/env.py +84 -0
  422. mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
  423. mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
  424. mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
  425. mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
  426. mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
  427. mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
  428. mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
  429. mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
  430. mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
  431. mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
  432. mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
  433. mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
  434. mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
  435. mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
  436. mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
  437. mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
  438. mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
  439. mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
  440. mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
  441. mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
  442. mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
  443. mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
  444. mlflow/store/db_migrations/versions/__init__.py +0 -0
  445. mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
  446. mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
  447. mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
  448. mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
  449. mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
  450. mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
  451. mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
  452. mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
  453. mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
  454. mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
  455. mlflow/store/entities/__init__.py +3 -0
  456. mlflow/store/entities/paged_list.py +18 -0
  457. mlflow/store/model_registry/__init__.py +10 -0
  458. mlflow/store/model_registry/abstract_store.py +1081 -0
  459. mlflow/store/model_registry/base_rest_store.py +44 -0
  460. mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
  461. mlflow/store/model_registry/dbmodels/__init__.py +0 -0
  462. mlflow/store/model_registry/dbmodels/models.py +206 -0
  463. mlflow/store/model_registry/file_store.py +1091 -0
  464. mlflow/store/model_registry/rest_store.py +481 -0
  465. mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
  466. mlflow/store/tracking/__init__.py +23 -0
  467. mlflow/store/tracking/abstract_store.py +816 -0
  468. mlflow/store/tracking/dbmodels/__init__.py +0 -0
  469. mlflow/store/tracking/dbmodels/initial_models.py +243 -0
  470. mlflow/store/tracking/dbmodels/models.py +1073 -0
  471. mlflow/store/tracking/file_store.py +2438 -0
  472. mlflow/store/tracking/postgres_managed_identity.py +146 -0
  473. mlflow/store/tracking/rest_store.py +1131 -0
  474. mlflow/store/tracking/sqlalchemy_store.py +2785 -0
  475. mlflow/system_metrics/__init__.py +61 -0
  476. mlflow/system_metrics/metrics/__init__.py +0 -0
  477. mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
  478. mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
  479. mlflow/system_metrics/metrics/disk_monitor.py +21 -0
  480. mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
  481. mlflow/system_metrics/metrics/network_monitor.py +34 -0
  482. mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
  483. mlflow/system_metrics/system_metrics_monitor.py +198 -0
  484. mlflow/tracing/__init__.py +16 -0
  485. mlflow/tracing/assessment.py +356 -0
  486. mlflow/tracing/client.py +531 -0
  487. mlflow/tracing/config.py +125 -0
  488. mlflow/tracing/constant.py +105 -0
  489. mlflow/tracing/destination.py +81 -0
  490. mlflow/tracing/display/__init__.py +40 -0
  491. mlflow/tracing/display/display_handler.py +196 -0
  492. mlflow/tracing/export/async_export_queue.py +186 -0
  493. mlflow/tracing/export/inference_table.py +138 -0
  494. mlflow/tracing/export/mlflow_v3.py +137 -0
  495. mlflow/tracing/export/utils.py +70 -0
  496. mlflow/tracing/fluent.py +1417 -0
  497. mlflow/tracing/processor/base_mlflow.py +199 -0
  498. mlflow/tracing/processor/inference_table.py +175 -0
  499. mlflow/tracing/processor/mlflow_v3.py +47 -0
  500. mlflow/tracing/processor/otel.py +73 -0
  501. mlflow/tracing/provider.py +487 -0
  502. mlflow/tracing/trace_manager.py +200 -0
  503. mlflow/tracing/utils/__init__.py +616 -0
  504. mlflow/tracing/utils/artifact_utils.py +28 -0
  505. mlflow/tracing/utils/copy.py +55 -0
  506. mlflow/tracing/utils/environment.py +55 -0
  507. mlflow/tracing/utils/exception.py +21 -0
  508. mlflow/tracing/utils/once.py +35 -0
  509. mlflow/tracing/utils/otlp.py +63 -0
  510. mlflow/tracing/utils/processor.py +54 -0
  511. mlflow/tracing/utils/search.py +292 -0
  512. mlflow/tracing/utils/timeout.py +250 -0
  513. mlflow/tracing/utils/token.py +19 -0
  514. mlflow/tracing/utils/truncation.py +124 -0
  515. mlflow/tracing/utils/warning.py +76 -0
  516. mlflow/tracking/__init__.py +39 -0
  517. mlflow/tracking/_model_registry/__init__.py +1 -0
  518. mlflow/tracking/_model_registry/client.py +764 -0
  519. mlflow/tracking/_model_registry/fluent.py +853 -0
  520. mlflow/tracking/_model_registry/registry.py +67 -0
  521. mlflow/tracking/_model_registry/utils.py +251 -0
  522. mlflow/tracking/_tracking_service/__init__.py +0 -0
  523. mlflow/tracking/_tracking_service/client.py +883 -0
  524. mlflow/tracking/_tracking_service/registry.py +56 -0
  525. mlflow/tracking/_tracking_service/utils.py +275 -0
  526. mlflow/tracking/artifact_utils.py +179 -0
  527. mlflow/tracking/client.py +5900 -0
  528. mlflow/tracking/context/__init__.py +0 -0
  529. mlflow/tracking/context/abstract_context.py +35 -0
  530. mlflow/tracking/context/databricks_cluster_context.py +15 -0
  531. mlflow/tracking/context/databricks_command_context.py +15 -0
  532. mlflow/tracking/context/databricks_job_context.py +49 -0
  533. mlflow/tracking/context/databricks_notebook_context.py +41 -0
  534. mlflow/tracking/context/databricks_repo_context.py +43 -0
  535. mlflow/tracking/context/default_context.py +51 -0
  536. mlflow/tracking/context/git_context.py +32 -0
  537. mlflow/tracking/context/registry.py +98 -0
  538. mlflow/tracking/context/system_environment_context.py +15 -0
  539. mlflow/tracking/default_experiment/__init__.py +1 -0
  540. mlflow/tracking/default_experiment/abstract_context.py +43 -0
  541. mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
  542. mlflow/tracking/default_experiment/registry.py +75 -0
  543. mlflow/tracking/fluent.py +3595 -0
  544. mlflow/tracking/metric_value_conversion_utils.py +93 -0
  545. mlflow/tracking/multimedia.py +206 -0
  546. mlflow/tracking/registry.py +86 -0
  547. mlflow/tracking/request_auth/__init__.py +0 -0
  548. mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
  549. mlflow/tracking/request_auth/registry.py +60 -0
  550. mlflow/tracking/request_header/__init__.py +0 -0
  551. mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
  552. mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
  553. mlflow/tracking/request_header/default_request_header_provider.py +17 -0
  554. mlflow/tracking/request_header/registry.py +79 -0
  555. mlflow/transformers/__init__.py +2982 -0
  556. mlflow/transformers/flavor_config.py +258 -0
  557. mlflow/transformers/hub_utils.py +83 -0
  558. mlflow/transformers/llm_inference_utils.py +468 -0
  559. mlflow/transformers/model_io.py +301 -0
  560. mlflow/transformers/peft.py +51 -0
  561. mlflow/transformers/signature.py +183 -0
  562. mlflow/transformers/torch_utils.py +55 -0
  563. mlflow/types/__init__.py +21 -0
  564. mlflow/types/agent.py +270 -0
  565. mlflow/types/chat.py +240 -0
  566. mlflow/types/llm.py +935 -0
  567. mlflow/types/responses.py +139 -0
  568. mlflow/types/responses_helpers.py +416 -0
  569. mlflow/types/schema.py +1505 -0
  570. mlflow/types/type_hints.py +647 -0
  571. mlflow/types/utils.py +753 -0
  572. mlflow/utils/__init__.py +283 -0
  573. mlflow/utils/_capture_modules.py +256 -0
  574. mlflow/utils/_capture_transformers_modules.py +75 -0
  575. mlflow/utils/_spark_utils.py +201 -0
  576. mlflow/utils/_unity_catalog_oss_utils.py +97 -0
  577. mlflow/utils/_unity_catalog_utils.py +479 -0
  578. mlflow/utils/annotations.py +218 -0
  579. mlflow/utils/arguments_utils.py +16 -0
  580. mlflow/utils/async_logging/__init__.py +1 -0
  581. mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
  582. mlflow/utils/async_logging/async_logging_queue.py +366 -0
  583. mlflow/utils/async_logging/run_artifact.py +38 -0
  584. mlflow/utils/async_logging/run_batch.py +58 -0
  585. mlflow/utils/async_logging/run_operations.py +49 -0
  586. mlflow/utils/autologging_utils/__init__.py +737 -0
  587. mlflow/utils/autologging_utils/client.py +432 -0
  588. mlflow/utils/autologging_utils/config.py +33 -0
  589. mlflow/utils/autologging_utils/events.py +294 -0
  590. mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
  591. mlflow/utils/autologging_utils/metrics_queue.py +71 -0
  592. mlflow/utils/autologging_utils/safety.py +1104 -0
  593. mlflow/utils/autologging_utils/versioning.py +95 -0
  594. mlflow/utils/checkpoint_utils.py +206 -0
  595. mlflow/utils/class_utils.py +6 -0
  596. mlflow/utils/cli_args.py +257 -0
  597. mlflow/utils/conda.py +354 -0
  598. mlflow/utils/credentials.py +231 -0
  599. mlflow/utils/data_utils.py +17 -0
  600. mlflow/utils/databricks_utils.py +1436 -0
  601. mlflow/utils/docstring_utils.py +477 -0
  602. mlflow/utils/doctor.py +133 -0
  603. mlflow/utils/download_cloud_file_chunk.py +43 -0
  604. mlflow/utils/env_manager.py +16 -0
  605. mlflow/utils/env_pack.py +131 -0
  606. mlflow/utils/environment.py +1009 -0
  607. mlflow/utils/exception_utils.py +14 -0
  608. mlflow/utils/file_utils.py +978 -0
  609. mlflow/utils/git_utils.py +77 -0
  610. mlflow/utils/gorilla.py +797 -0
  611. mlflow/utils/import_hooks/__init__.py +363 -0
  612. mlflow/utils/lazy_load.py +51 -0
  613. mlflow/utils/logging_utils.py +168 -0
  614. mlflow/utils/mime_type_utils.py +58 -0
  615. mlflow/utils/mlflow_tags.py +103 -0
  616. mlflow/utils/model_utils.py +486 -0
  617. mlflow/utils/name_utils.py +346 -0
  618. mlflow/utils/nfs_on_spark.py +62 -0
  619. mlflow/utils/openai_utils.py +164 -0
  620. mlflow/utils/os.py +12 -0
  621. mlflow/utils/oss_registry_utils.py +29 -0
  622. mlflow/utils/plugins.py +17 -0
  623. mlflow/utils/process.py +182 -0
  624. mlflow/utils/promptlab_utils.py +146 -0
  625. mlflow/utils/proto_json_utils.py +743 -0
  626. mlflow/utils/pydantic_utils.py +54 -0
  627. mlflow/utils/request_utils.py +279 -0
  628. mlflow/utils/requirements_utils.py +704 -0
  629. mlflow/utils/rest_utils.py +673 -0
  630. mlflow/utils/search_logged_model_utils.py +127 -0
  631. mlflow/utils/search_utils.py +2111 -0
  632. mlflow/utils/secure_loading.py +221 -0
  633. mlflow/utils/security_validation.py +384 -0
  634. mlflow/utils/server_cli_utils.py +61 -0
  635. mlflow/utils/spark_utils.py +15 -0
  636. mlflow/utils/string_utils.py +138 -0
  637. mlflow/utils/thread_utils.py +63 -0
  638. mlflow/utils/time.py +54 -0
  639. mlflow/utils/timeout.py +42 -0
  640. mlflow/utils/uri.py +572 -0
  641. mlflow/utils/validation.py +662 -0
  642. mlflow/utils/virtualenv.py +458 -0
  643. mlflow/utils/warnings_utils.py +25 -0
  644. mlflow/utils/yaml_utils.py +179 -0
  645. mlflow/version.py +24 -0
@@ -0,0 +1,1436 @@
+ import functools
+ import getpass
+ import json
+ import logging
+ import os
+ import platform
+ import subprocess
+ import time
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING, NamedTuple, Optional, TypeVar
+
+ from mlflow.utils.logging_utils import eprint
+ from mlflow.utils.request_utils import augmented_raise_for_status
+
+ if TYPE_CHECKING:
+     from pyspark.sql.connect.session import SparkSession as SparkConnectSession
+
+
+ import mlflow.utils
+ from mlflow.environment_variables import (
+     MLFLOW_ENABLE_DB_SDK,
+     MLFLOW_TRACKING_URI,
+ )
+ from mlflow.exceptions import MlflowException
+ from mlflow.legacy_databricks_cli.configure.provider import (
+     DatabricksConfig,
+     DatabricksConfigProvider,
+     DatabricksModelServingConfigProvider,
+     EnvironmentVariableConfigProvider,
+     ProfileConfigProvider,
+     SparkTaskContextConfigProvider,
+ )
+ from mlflow.utils._spark_utils import _get_active_spark_session
+ from mlflow.utils.rest_utils import MlflowHostCreds, http_request
+ from mlflow.utils.uri import (
+     _DATABRICKS_UNITY_CATALOG_SCHEME,
+     get_db_info_from_uri,
+     is_databricks_uri,
+ )
+
+ _logger = logging.getLogger(__name__)
+
+
+ _MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH = "/var/credentials-secret/model-dependencies-oauth-token"
+
+
+ def _use_repl_context_if_available(
+     name: str,
+     *,
+     ignore_none: bool = False,
+ ):
+     """Creates a decorator to insert a short circuit that returns the specified REPL context
+     attribute if it's available.
+
+     Args:
+         name: Attribute name (e.g. "apiUrl").
+         ignore_none: If True, use the original function if the REPL context attribute exists but
+             is None.
+
+     Returns:
+         Decorator to insert the short circuit.
+     """
+
+     def decorator(f):
+         @functools.wraps(f)
+         def wrapper(*args, **kwargs):
+             try:
+                 from dbruntime.databricks_repl_context import get_context
+
+                 context = get_context()
+                 if context is not None and hasattr(context, name):
+                     attr = getattr(context, name)
+                     if attr is None and ignore_none:
+                         # do nothing and continue to the original function
+                         pass
+                     else:
+                         return attr
+             except Exception:
+                 pass
+             return f(*args, **kwargs)
+
+         return wrapper
+
+     return decorator
+
+
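For orientation, here is a minimal sketch of applying this decorator (illustrative only, not part of the packaged file; the attribute name mirrors the real usages further down):

    @_use_repl_context_if_available("jobId")
    def _job_id_or_none():
        # Fallback used outside a Databricks REPL, or when the REPL
        # context does not expose the requested attribute.
        return None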
+ def get_mlflow_credential_context_by_run_id(run_id):
+     from mlflow.tracking.artifact_utils import get_artifact_uri
+     from mlflow.utils.uri import get_databricks_profile_uri_from_artifact_uri
+
+     run_root_artifact_uri = get_artifact_uri(run_id=run_id)
+     profile = get_databricks_profile_uri_from_artifact_uri(run_root_artifact_uri)
+     return MlflowCredentialContext(profile)
+
+
+ class MlflowCredentialContext:
+     """Sets and clears credentials on a context using the provided profile URL."""
+
+     def __init__(self, databricks_profile_url):
+         self.databricks_profile_url = databricks_profile_url or "databricks"
+         self.db_utils = _get_dbutils()
+
+     def __enter__(self):
+         db_creds = _get_databricks_creds_config(self.databricks_profile_url)
+         self.db_utils.notebook.entry_point.putMlflowProperties(
+             db_creds.host,
+             db_creds.insecure,
+             db_creds.token,
+             db_creds.username,
+             db_creds.password,
+         )
+
+     def __exit__(self, exc_type, exc_value, exc_traceback):
+         self.db_utils.notebook.entry_point.clearMlflowProperties()
+
+
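Credentials set through this class are scoped to a `with` block. A minimal usage sketch (illustrative only; the run ID is hypothetical):

    with get_mlflow_credential_context_by_run_id("0123456789abcdef"):
        # Inside the block, MLflow properties carry the credentials resolved
        # from the run's artifact-URI profile; they are cleared on exit.
        ...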
+ def _get_dbutils():
+     try:
+         import IPython
+
+         ip_shell = IPython.get_ipython()
+         if ip_shell is None:
+             raise _NoDbutilsError
+         return ip_shell.ns_table["user_global"]["dbutils"]
+     except ImportError:
+         raise _NoDbutilsError
+     except KeyError:
+         raise _NoDbutilsError
+
+
+ class _NoDbutilsError(Exception):
+     pass
+
+
+ def _get_java_dbutils():
+     dbutils = _get_dbutils()
+     return dbutils.notebook.entry_point.getDbutils()
+
+
+ def _get_command_context():
+     return _get_java_dbutils().notebook().getContext()
+
+
+ def _get_extra_context(context_key):
+     opt = _get_command_context().extraContext().get(context_key)
+     return opt.get() if opt.isDefined() else None
+
+
+ def _get_context_tag(context_tag_key):
+     try:
+         tag_opt = _get_command_context().tags().get(context_tag_key)
+         if tag_opt.isDefined():
+             return tag_opt.get()
+     except Exception:
+         pass
+
+     return None
+
+
+ @_use_repl_context_if_available("aclPathOfAclRoot")
+ def acl_path_of_acl_root():
+     try:
+         return _get_command_context().aclPathOfAclRoot().get()
+     except Exception:
+         return _get_extra_context("aclPathOfAclRoot")
+
+
+ def _get_property_from_spark_context(key):
+     try:
+         from pyspark import TaskContext
+
+         task_context = TaskContext.get()
+         if task_context:
+             return task_context.getLocalProperty(key)
+     except Exception:
+         return None
+
+
+ def is_databricks_default_tracking_uri(tracking_uri):
+     return tracking_uri.lower().strip() == "databricks"
+
+
+ @_use_repl_context_if_available("isInNotebook")
+ def is_in_databricks_notebook():
+     if _get_property_from_spark_context("spark.databricks.notebook.id") is not None:
+         return True
+     try:
+         return path.startswith("/workspace") if (path := acl_path_of_acl_root()) else False
+     except Exception:
+         return False
+
+
+ @_use_repl_context_if_available("isInJob")
+ def is_in_databricks_job():
+     try:
+         return get_job_id() is not None and get_job_run_id() is not None
+     except Exception:
+         return False
+
+
+ def is_in_databricks_model_serving_environment():
+     """
+     Check if the code is running in a Databricks Model Serving environment.
+     The environment variable is set by Databricks when starting the serving container.
+     """
+     val = (
+         os.environ.get("IS_IN_DB_MODEL_SERVING_ENV")
+         # Checking the old env var name for backward compatibility. The env var was renamed once
+         # to fix a model loading issue, but we still need to support it for a while.
+         # TODO: Remove this once the new env var is fully rolled out.
+         or os.environ.get("IS_IN_DATABRICKS_MODEL_SERVING_ENV")
+         or "false"
+     )
+     return val.lower() == "true"
+
+
+ def is_mlflow_tracing_enabled_in_model_serving() -> bool:
+     """
+     The ENABLE_MLFLOW_TRACING environment variable guards tracing behavior for models in
+     Databricks Model Serving. Tracing in serving is only enabled when this env var is true.
+     """
+     return os.environ.get("ENABLE_MLFLOW_TRACING", "false").lower() == "true"
+
+
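Both checks reduce to case-insensitive string comparisons on environment variables, so they can be exercised anywhere. A minimal sketch (illustrative only, not part of the packaged file):

    import os

    # Simulate a serving container that also enables tracing.
    os.environ["IS_IN_DB_MODEL_SERVING_ENV"] = "true"
    os.environ["ENABLE_MLFLOW_TRACING"] = "true"
    assert is_in_databricks_model_serving_environment()
    assert is_mlflow_tracing_enabled_in_model_serving()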
+ # This should only be the case when we are in a model serving environment
+ # and the OAuth token file exists at the specified path.
+ def should_fetch_model_serving_environment_oauth():
+     return (
+         is_in_databricks_model_serving_environment()
+         and os.path.exists(_MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH)
+         and os.path.isfile(_MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH)
+     )
+
+
+ def is_in_databricks_repo():
+     try:
+         return get_git_repo_relative_path() is not None
+     except Exception:
+         return False
+
+
+ def is_in_databricks_repo_notebook():
+     try:
+         path = get_notebook_path()
+         return path is not None and path.startswith("/Repos")
+     except Exception:
+         return False
+
+
+ _DATABRICKS_VERSION_FILE_PATH = "/databricks/DBR_VERSION"
+
+
+ def get_databricks_runtime_version():
+     if ver := os.environ.get("DATABRICKS_RUNTIME_VERSION"):
+         return ver
+     if os.path.exists(_DATABRICKS_VERSION_FILE_PATH):
+         # Databricks DCS clusters do not set the DATABRICKS_RUNTIME_VERSION
+         # environment variable, so we have to read the version from the version file.
+         with open(_DATABRICKS_VERSION_FILE_PATH) as f:
+             return f.read().strip()
+     return None
+
+
+ def is_in_databricks_runtime():
+     return get_databricks_runtime_version() is not None
+
+
+ def is_in_databricks_serverless_runtime():
+     dbr_version = get_databricks_runtime_version()
+     return dbr_version and dbr_version.startswith("client.")
+
+
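The version string distinguishes the environments: classic runtimes report versions like '15.4', while serverless clients report 'client.<major>.<minor>'. A minimal sketch (illustrative only; the values are hypothetical):

    import os

    os.environ["DATABRICKS_RUNTIME_VERSION"] = "client.1.13"
    assert is_in_databricks_serverless_runtime()

    os.environ["DATABRICKS_RUNTIME_VERSION"] = "15.4"
    assert not is_in_databricks_serverless_runtime()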
+ def is_in_databricks_shared_cluster_runtime():
+     from mlflow.utils.spark_utils import is_spark_connect_mode
+
+     return (
+         is_in_databricks_runtime()
+         and is_spark_connect_mode()
+         and not is_in_databricks_serverless_runtime()
+     )
+
+
+ def is_databricks_connect(spark=None):
+     """
+     Return True if the current Spark Connect client is connected to a Databricks cluster.
+     """
+     from mlflow.utils.spark_utils import is_spark_connect_mode
+
+     if is_in_databricks_serverless_runtime() or is_in_databricks_shared_cluster_runtime():
+         return True
+
+     spark = spark or _get_active_spark_session()
+     if spark is None:
+         return False
+
+     if not is_spark_connect_mode():
+         return False
+
+     if hasattr(spark.client, "metadata"):
+         metadata = spark.client.metadata()
+     else:
+         metadata = spark.client._builder.metadata()
+
+     return any(k in ["x-databricks-session-id", "x-databricks-cluster-id"] for k, v in metadata)
+
+
+ @dataclass
+ class DBConnectUDFSandboxInfo:
+     spark: "SparkConnectSession"
+     image_version: str
+     runtime_version: str
+     platform_machine: str
+     mlflow_version: str
+
+
+ _dbconnect_udf_sandbox_info_cache: Optional[DBConnectUDFSandboxInfo] = None
+
+
+ def get_dbconnect_udf_sandbox_info(spark):
+     """
+     Get Databricks UDF sandbox info, which includes the following fields:
+     - image_version like
+       '{major_version}.{minor_version}' or 'client.{major_version}.{minor_version}'
+     - runtime_version like '{major_version}.{minor_version}'
+     - platform_machine like 'x86_64' or 'aarch64'
+     - mlflow_version
+     """
+     global _dbconnect_udf_sandbox_info_cache
+     from pyspark.sql.functions import pandas_udf
+
+     if (
+         _dbconnect_udf_sandbox_info_cache is not None
+         and spark is _dbconnect_udf_sandbox_info_cache.spark
+     ):
+         return _dbconnect_udf_sandbox_info_cache
+
+     # version is like '15.4.x-scala2.12'
+     version = spark.sql("SELECT current_version().dbr_version").collect()[0][0]
+     major, minor, *_rest = version.split(".")
+     runtime_version = f"{major}.{minor}"
+
+     # For the Databricks Serverless Python REPL, the UDF sandbox runs on a client
+     # image, which has a version like 'client.1.1'; in other cases, the UDF sandbox
+     # runs on a Databricks runtime image with a version like '15.4'.
+     if is_in_databricks_runtime():
+         _dbconnect_udf_sandbox_info_cache = DBConnectUDFSandboxInfo(
+             spark=_get_active_spark_session(),
+             runtime_version=runtime_version,
+             image_version=get_databricks_runtime_version(),
+             platform_machine=platform.machine(),
+             # In the Databricks runtime, driver and executor should have the
+             # same version.
+             mlflow_version=mlflow.__version__,
+         )
+     else:
+         image_version = runtime_version
+
+         @pandas_udf("string")
+         def f(_):
+             import pandas as pd
+
+             platform_machine = platform.machine()
+
+             try:
+                 import mlflow
+
+                 mlflow_version = mlflow.__version__
+             except ImportError:
+                 mlflow_version = ""
+
+             return pd.Series([f"{platform_machine}\n{mlflow_version}"])
+
+         platform_machine, mlflow_version = (
+             spark.range(1).select(f("id")).collect()[0][0].split("\n")
+         )
+         if mlflow_version == "":
+             mlflow_version = None
+         _dbconnect_udf_sandbox_info_cache = DBConnectUDFSandboxInfo(
+             spark=spark,
+             image_version=image_version,
+             runtime_version=runtime_version,
+             platform_machine=platform_machine,
+             mlflow_version=mlflow_version,
+         )
+
+     return _dbconnect_udf_sandbox_info_cache
+
+
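The else branch probes the executor environment by running a tiny pandas UDF over a one-row DataFrame and parsing the newline-separated result; the value is cached per Spark session. A minimal usage sketch (illustrative only, assuming an active Databricks Connect session named `spark`):

    info = get_dbconnect_udf_sandbox_info(spark)
    print(info.image_version, info.runtime_version)
    print(info.platform_machine, info.mlflow_version)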
+ def is_databricks_serverless(spark):
+     """
+     Return True if running on a Databricks Serverless notebook or
+     on a Databricks Connect client that connects to Databricks Serverless.
+     """
+     from mlflow.utils.spark_utils import is_spark_connect_mode
+
+     if not is_spark_connect_mode():
+         return False
+
+     if hasattr(spark.client, "metadata"):
+         metadata = spark.client.metadata()
+     else:
+         metadata = spark.client._builder.metadata()
+
+     return any(k == "x-databricks-session-id" for k, v in metadata)
+
+
+ def is_dbfs_fuse_available():
+     if not is_in_databricks_runtime():
+         return False
+
+     try:
+         return (
+             subprocess.call(
+                 ["mountpoint", "/dbfs"],
+                 stderr=subprocess.DEVNULL,
+                 stdout=subprocess.DEVNULL,
+             )
+             == 0
+         )
+     except Exception:
+         return False
+
+
+ def is_uc_volume_fuse_available():
+     try:
+         return (
+             subprocess.call(
+                 ["mountpoint", "/Volumes"],
+                 stderr=subprocess.DEVNULL,
+                 stdout=subprocess.DEVNULL,
+             )
+             == 0
+         )
+     except Exception:
+         return False
+
+
+ @_use_repl_context_if_available("isInCluster")
+ def is_in_cluster():
+     try:
+         spark_session = _get_active_spark_session()
+         return (
+             spark_session is not None
+             and spark_session.conf.get("spark.databricks.clusterUsageTags.clusterId", None)
+             is not None
+         )
+     except Exception:
+         return False
+
+
+ @_use_repl_context_if_available("notebookId")
+ def get_notebook_id():
+     """Should only be called if is_in_databricks_notebook is true"""
+     if notebook_id := _get_property_from_spark_context("spark.databricks.notebook.id"):
+         return notebook_id
+     if (path := acl_path_of_acl_root()) and path.startswith("/workspace"):
+         return path.split("/")[-1]
+     return None
+
+
+ @_use_repl_context_if_available("notebookPath")
+ def get_notebook_path():
+     """Should only be called if is_in_databricks_notebook is true"""
+     path = _get_property_from_spark_context("spark.databricks.notebook.path")
+     if path is not None:
+         return path
+     try:
+         return _get_command_context().notebookPath().get()
+     except Exception:
+         return _get_extra_context("notebook_path")
+
+
+ @_use_repl_context_if_available("clusterId")
+ def get_cluster_id():
+     spark_session = _get_active_spark_session()
+     if spark_session is None:
+         return None
+     return spark_session.conf.get("spark.databricks.clusterUsageTags.clusterId", None)
+
+
+ @_use_repl_context_if_available("jobGroupId")
+ def get_job_group_id():
+     try:
+         dbutils = _get_dbutils()
+         job_group_id = dbutils.entry_point.getJobGroupId()
+         if job_group_id is not None:
+             return job_group_id
+     except Exception:
+         return None
+
+
+ @_use_repl_context_if_available("replId")
+ def get_repl_id():
+     """
+     Returns:
+         The ID of the current Databricks Python REPL.
+     """
+     # Attempt to fetch the REPL ID from the Python REPL's entrypoint object. This REPL ID
+     # is guaranteed to be set upon REPL startup in DBR / MLR 9.0
+     try:
+         dbutils = _get_dbutils()
+         repl_id = dbutils.entry_point.getReplId()
+         if repl_id is not None:
+             return repl_id
+     except Exception:
+         pass
+
+     # If the REPL ID entrypoint property is unavailable due to an older runtime version (< 9.0),
+     # attempt to fetch the REPL ID from the Spark Context. This property may not be available
+     # until several seconds after REPL startup
+     try:
+         from pyspark import SparkContext
+
+         repl_id = SparkContext.getOrCreate().getLocalProperty("spark.databricks.replId")
+         if repl_id is not None:
+             return repl_id
+     except Exception:
+         pass
+
+
+ @_use_repl_context_if_available("jobId")
+ def get_job_id():
+     try:
+         return _get_command_context().jobId().get()
+     except Exception:
+         return _get_context_tag("jobId")
+
+
+ @_use_repl_context_if_available("idInJob")
+ def get_job_run_id():
+     try:
+         return _get_command_context().idInJob().get()
+     except Exception:
+         return _get_context_tag("idInJob")
+
+
+ @_use_repl_context_if_available("jobTaskType")
+ def get_job_type():
+     """Should only be called if is_in_databricks_job is true"""
+     try:
+         return _get_command_context().jobTaskType().get()
+     except Exception:
+         return _get_context_tag("jobTaskType")
+
+
+ @_use_repl_context_if_available("jobType")
+ def get_job_type_info():
+     try:
+         return _get_context_tag("jobType")
+     except Exception:
+         return None
+
+
+ @_use_repl_context_if_available("commandRunId")
+ def get_command_run_id():
+     try:
+         return _get_command_context().commandRunId().get()
+     except Exception:
+         # Older runtimes may not have the commandRunId available
+         return None
+
+
+ @_use_repl_context_if_available("workloadId")
+ def get_workload_id():
+     try:
+         return _get_command_context().workloadId().get()
+     except Exception:
+         return _get_context_tag("workloadId")
+
+
+ @_use_repl_context_if_available("workloadClass")
+ def get_workload_class():
+     try:
+         return _get_command_context().workloadClass().get()
+     except Exception:
+         return _get_context_tag("workloadClass")
+
+
+ @_use_repl_context_if_available("apiUrl")
+ def get_webapp_url():
+     """Should only be called if is_in_databricks_notebook or is_in_databricks_job is true"""
+     url = _get_property_from_spark_context("spark.databricks.api.url")
+     if url is not None:
+         return url
+     try:
+         return _get_command_context().apiUrl().get()
+     except Exception:
+         return _get_extra_context("api_url")
+
+
+ @_use_repl_context_if_available("workspaceId")
+ def get_workspace_id():
+     try:
+         return _get_command_context().workspaceId().get()
+     except Exception:
+         return _get_context_tag("orgId")
+
+
+ @_use_repl_context_if_available("browserHostName")
+ def get_browser_hostname():
+     try:
+         return _get_command_context().browserHostName().get()
+     except Exception:
+         return _get_context_tag("browserHostName")
+
+
+ def get_workspace_info_from_dbutils():
+     try:
+         dbutils = _get_dbutils()
+         if dbutils:
+             browser_hostname = get_browser_hostname()
+             workspace_host = "https://" + browser_hostname if browser_hostname else get_webapp_url()
+             workspace_id = get_workspace_id()
+             return workspace_host, workspace_id
+     except Exception:
+         pass
+     return None, None
+
+
+ @_use_repl_context_if_available("workspaceUrl", ignore_none=True)
+ def _get_workspace_url():
+     try:
+         if spark_session := _get_active_spark_session():
+             if workspace_url := spark_session.conf.get("spark.databricks.workspaceUrl", None):
+                 return workspace_url
+     except Exception:
+         return None
+
+
+ def get_workspace_url():
+     if url := _get_workspace_url():
+         return f"https://{url}" if not url.startswith("https://") else url
+     return None
+
+
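Taken together, these helpers resolve a browsable workspace URL, preferring the browser hostname over the internal API URL and normalizing to an https prefix. A minimal sketch (illustrative only; the outputs shown are hypothetical):

    host, workspace_id = get_workspace_info_from_dbutils()
    # e.g. ("https://my-workspace.cloud.databricks.com", "1234567890") in a notebook,
    # or (None, None) outside Databricks.
    url = get_workspace_url()  # https-prefixed when a workspace URL is found, else None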
636
+ def warn_on_deprecated_cross_workspace_registry_uri(registry_uri):
637
+ workspace_host, workspace_id = get_workspace_info_from_databricks_secrets(
638
+ tracking_uri=registry_uri
639
+ )
640
+ if workspace_host is not None or workspace_id is not None:
641
+ _logger.warning(
642
+ "Accessing remote workspace model registries using registry URIs of the form "
643
+ "'databricks://scope:prefix', or by loading models via URIs of the form "
644
+ "'models://scope:prefix@databricks/model-name/stage-or-version', is deprecated. "
645
+ "Use Models in Unity Catalog instead for easy cross-workspace model access, with "
646
+ "granular per-user audit logging and no extra setup required. See "
647
+ "https://docs.databricks.com/machine-learning/manage-model-lifecycle/index.html "
648
+ "for more details."
649
+ )
650
+
651
+
652
+ def get_workspace_info_from_databricks_secrets(tracking_uri):
653
+ profile, key_prefix = get_db_info_from_uri(tracking_uri)
654
+ if key_prefix:
655
+ dbutils = _get_dbutils()
656
+ if dbutils:
657
+ workspace_id = dbutils.secrets.get(scope=profile, key=key_prefix + "-workspace-id")
658
+ workspace_host = dbutils.secrets.get(scope=profile, key=key_prefix + "-host")
659
+ return workspace_host, workspace_id
660
+ return None, None
661
+
662
+
663
+ def _fail_malformed_databricks_auth(uri):
+     if uri and uri.startswith(_DATABRICKS_UNITY_CATALOG_SCHEME):
+         uri_name = "registry URI"
+         uri_scheme = _DATABRICKS_UNITY_CATALOG_SCHEME
+     else:
+         uri_name = "tracking URI"
+         uri_scheme = "databricks"
+     if is_in_databricks_model_serving_environment():
+         raise MlflowException(
+             f"Reading Databricks credential configuration in model serving failed. "
+             f"Most commonly, this happens because the model currently "
+             f"being served was logged without Databricks resource dependencies "
+             f"properly specified. Re-log your model, specifying resource dependencies as "
+             f"described in "
+             f"https://docs.databricks.com/en/generative-ai/agent-framework/log-agent.html"
+             f"#specify-resources-for-pyfunc-or-langchain-agent "
+             f"and then register and attempt to serve it again. Alternatively, you can explicitly "
+             f"configure authentication by setting environment variables as described in "
+             f"https://docs.databricks.com/en/generative-ai/agent-framework/deploy-agent.html"
+             f"#manual-authentication. "
+             f"Additional debug info: the MLflow {uri_name} was set to '{uri}'"
+         )
+     raise MlflowException(
+         f"Reading Databricks credential configuration failed with MLflow {uri_name} '{uri}'. "
+         "Please ensure that the 'databricks-sdk' PyPI library is installed, the tracking "
+         "URI is set correctly, and Databricks authentication is properly configured. "
+         f"The {uri_name} can be either '{uri_scheme}' "
+         f"(using the profile name specified by the 'DATABRICKS_CONFIG_PROFILE' environment "
+         f"variable, or the 'DEFAULT' authentication profile if 'DATABRICKS_CONFIG_PROFILE' "
+         f"is not set) or '{uri_scheme}://{{profile}}'. "
+         "You can configure Databricks authentication in several ways, for example by "
+         "specifying environment variables (e.g. DATABRICKS_HOST + DATABRICKS_TOKEN) or "
+         "logging in using 'databricks auth login'. \n"
+         "For details on configuring Databricks authentication, please refer to "
+         "'https://docs.databricks.com/en/dev-tools/auth/index.html#unified-auth'."
+     )
+
+
+ # Helper function to attempt to read the OAuth token from a
+ # mounted file in the Databricks Model Serving environment
+ def get_model_dependency_oauth_token(should_retry=True):
+     try:
+         with open(_MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH) as f:
+             oauth_dict = json.load(f)
+             return oauth_dict["OAUTH_TOKEN"][0]["oauthTokenValue"]
+     except Exception as e:
+         # sleep and retry in case of any race conditions with OAuth refreshing
+         if should_retry:
+             time.sleep(0.5)
+             return get_model_dependency_oauth_token(should_retry=False)
+         else:
+             raise MlflowException(
+                 "Failed to read OAuth credentials from the file mount for the "
+                 "Databricks Model Serving dependency"
+             ) from e
+
+
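The file layout the parser expects can be inferred from the access pattern `oauth_dict["OAUTH_TOKEN"][0]["oauthTokenValue"]`; a sketch with a placeholder token value:

```python
import json

# Placeholder payload shaped like the mounted OAuth file read above.
payload = '{"OAUTH_TOKEN": [{"oauthTokenValue": "dapi-example-token"}]}'
assert json.loads(payload)["OAUTH_TOKEN"][0]["oauthTokenValue"] == "dapi-example-token"
```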
+ class TrackingURIConfigProvider(DatabricksConfigProvider):
+     """
+     TrackingURIConfigProvider extracts `scope` and `key_prefix` from a tracking URI
+     of the form `databricks://scope:key_prefix`, then reads the host and token
+     values from dbutils secrets under the keys "{key_prefix}-host" and
+     "{key_prefix}-token".
+
+     This provider only works in the Databricks runtime, and it is deprecated:
+     in the Databricks runtime you can simply use 'databricks' as the tracking URI,
+     and MLflow automatically reads the dynamically generated token.
+     """
+
+     def __init__(self, tracking_uri):
+         self.tracking_uri = tracking_uri
+
+     def get_config(self):
+         scope, key_prefix = get_db_info_from_uri(self.tracking_uri)
+
+         if scope and key_prefix:
+             dbutils = _get_dbutils()
+             if dbutils:
+                 # The prefix differentiates users and is provided as path information in the URI
+                 host = dbutils.secrets.get(scope=scope, key=key_prefix + "-host")
+                 token = dbutils.secrets.get(scope=scope, key=key_prefix + "-token")
+                 return DatabricksConfig.from_token(host=host, token=token, insecure=False)
+
+         return None
+
+
+ def get_databricks_host_creds(server_uri=None):
+     """
+     Reads in the configuration necessary to make HTTP requests to a Databricks server,
+     using the Databricks SDK WorkspaceClient API.
+     If no credential configuration is available for the server URI, this function
+     attempts to retrieve credentials from the Databricks Secret Manager. For that to work,
+     the server URI needs to be of the following format: "databricks://scope:prefix". In the
+     Databricks Secret Manager, we will query the scope "<scope>" for secrets with
+     keys of the form "<prefix>-host" and "<prefix>-token". Note that this prefix *cannot* be empty
+     if trying to authenticate with this method. If found, those host credentials will be used. This
+     method will throw an exception if sufficient auth cannot be found.
+
+     Args:
+         server_uri: A URI that specifies the Databricks profile you want to use for making
+             requests.
+
+     Returns:
+         MlflowHostCreds which includes the hostname if databricks-sdk authentication is available,
+         otherwise includes the hostname and authentication information necessary to
+         talk to the Databricks server.
+
+     .. Warning:: This API is deprecated and may be removed in the future.
+     """
+
+     if MLFLOW_ENABLE_DB_SDK.get():
+         from databricks.sdk import WorkspaceClient
+
+         profile, key_prefix = get_db_info_from_uri(server_uri)
+         profile = profile or os.environ.get("DATABRICKS_CONFIG_PROFILE")
+         if key_prefix is not None:
+             try:
+                 config = TrackingURIConfigProvider(server_uri).get_config()
+                 WorkspaceClient(host=config.host, token=config.token)
+                 return MlflowHostCreds(
+                     config.host,
+                     token=config.token,
+                     use_databricks_sdk=True,
+                     use_secret_scope_token=True,
+                 )
+             except Exception as e:
+                 raise MlflowException(
+                     f"The hostname and credentials configured by {server_uri} are invalid. "
+                     "Please create a valid hostname secret with the command "
+                     f"'databricks secrets put-secret {profile} {key_prefix}-host' and "
+                     "a valid token secret with the command "
+                     f"'databricks secrets put-secret {profile} {key_prefix}-token'."
+                 ) from e
+         try:
+             # Use databricks-sdk to create a Databricks WorkspaceClient instance.
+             # If authentication fails, MLflow falls back to legacy authentication methods;
+             # see `SparkTaskContextConfigProvider`, `DatabricksModelServingConfigProvider`,
+             # and `TrackingURIConfigProvider`.
+             # databricks-sdk supports many authentication methods and tries to read
+             # authentication information in the following order:
+             # 1. Read the dynamically generated token via Databricks `dbutils`.
+             # 2. Parse relevant environment variables (such as DATABRICKS_HOST + DATABRICKS_TOKEN
+             #    or DATABRICKS_HOST + DATABRICKS_CLIENT_ID + DATABRICKS_CLIENT_SECRET).
+             # 3. Parse the ~/.databrickscfg file (generated by the Databricks CLI).
+             # databricks-sdk is designed to hide authentication details and
+             # support various authentication methods, so it does not provide an API
+             # to get credential values. Instead, we can use the ``WorkspaceClient``
+             # API to invoke Databricks shard REST APIs.
+             WorkspaceClient(profile=profile)
+             use_databricks_sdk = True
+             databricks_auth_profile = profile
+         except Exception as e:
+             _logger.debug(f"Failed to create databricks SDK workspace client, error: {e!r}")
+             use_databricks_sdk = False
+             databricks_auth_profile = None
+     else:
+         use_databricks_sdk = False
+         databricks_auth_profile = None
+
+     config = _get_databricks_creds_config(server_uri)
+
+     if not config:
+         _fail_malformed_databricks_auth(server_uri)
+
+     return MlflowHostCreds(
+         config.host,
+         username=config.username,
+         password=config.password,
+         ignore_tls_verification=config.insecure,
+         token=config.token,
+         client_id=config.client_id,
+         client_secret=config.client_secret,
+         use_databricks_sdk=use_databricks_sdk,
+         databricks_auth_profile=databricks_auth_profile,
+     )
+
+
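A hedged usage sketch, assuming this module is importable as `mlflow.utils.databricks_utils` and that a `my-profile` entry exists in `~/.databrickscfg` (both assumptions, not shown above):

```python
from mlflow.utils.databricks_utils import get_databricks_host_creds

# Resolve credentials for a hypothetical profile-based tracking URI.
creds = get_databricks_host_creds("databricks://my-profile")
print(creds.host)  # e.g. "https://my-workspace.cloud.databricks.com"
```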
+ @_use_repl_context_if_available("mlflowGitRepoUrl")
844
+ def get_git_repo_url():
845
+ try:
846
+ return _get_command_context().mlflowGitRepoUrl().get()
847
+ except Exception:
848
+ return _get_extra_context("mlflowGitUrl")
849
+
850
+
851
+ @_use_repl_context_if_available("mlflowGitRepoProvider")
852
+ def get_git_repo_provider():
853
+ try:
854
+ return _get_command_context().mlflowGitRepoProvider().get()
855
+ except Exception:
856
+ return _get_extra_context("mlflowGitProvider")
857
+
858
+
859
+ @_use_repl_context_if_available("mlflowGitRepoCommit")
860
+ def get_git_repo_commit():
861
+ try:
862
+ return _get_command_context().mlflowGitRepoCommit().get()
863
+ except Exception:
864
+ return _get_extra_context("mlflowGitCommit")
865
+
866
+
867
+ @_use_repl_context_if_available("mlflowGitRelativePath")
868
+ def get_git_repo_relative_path():
869
+ try:
870
+ return _get_command_context().mlflowGitRelativePath().get()
871
+ except Exception:
872
+ return _get_extra_context("mlflowGitRelativePath")
873
+
874
+
875
+ @_use_repl_context_if_available("mlflowGitRepoReference")
876
+ def get_git_repo_reference():
877
+ try:
878
+ return _get_command_context().mlflowGitRepoReference().get()
879
+ except Exception:
880
+ return _get_extra_context("mlflowGitReference")
881
+
882
+
883
+ @_use_repl_context_if_available("mlflowGitRepoReferenceType")
884
+ def get_git_repo_reference_type():
885
+ try:
886
+ return _get_command_context().mlflowGitRepoReferenceType().get()
887
+ except Exception:
888
+ return _get_extra_context("mlflowGitReferenceType")
889
+
890
+
891
+ @_use_repl_context_if_available("mlflowGitRepoStatus")
892
+ def get_git_repo_status():
893
+ try:
894
+ return _get_command_context().mlflowGitRepoStatus().get()
895
+ except Exception:
896
+ return _get_extra_context("mlflowGitStatus")
897
+
898
+
899
+ def is_running_in_ipython_environment():
900
+ try:
901
+ from IPython import get_ipython
902
+
903
+ return get_ipython() is not None
904
+ except (ImportError, ModuleNotFoundError):
905
+ return False
906
+
907
+
908
+ def get_databricks_run_url(tracking_uri: str, run_id: str, artifact_path=None) -> Optional[str]:
+     """
+     Obtains a Databricks URL corresponding to the specified MLflow Run, optionally referring
+     to an artifact within the run.
+
+     Args:
+         tracking_uri: The URI of the MLflow Tracking server containing the Run.
+         run_id: The ID of the MLflow Run for which to obtain a Databricks URL.
+         artifact_path: An optional relative artifact path within the Run to which the URL
+             should refer.
+
+     Returns:
+         A Databricks URL corresponding to the specified MLflow Run
+         (and artifact path, if specified), or None if the MLflow Run does not belong to a
+         Databricks Workspace.
+     """
+     from mlflow.tracking.client import MlflowClient
+
+     try:
+         workspace_info = (
+             DatabricksWorkspaceInfo.from_environment()
+             or get_databricks_workspace_info_from_uri(tracking_uri)
+         )
+         if workspace_info is not None:
+             experiment_id = MlflowClient(tracking_uri).get_run(run_id).info.experiment_id
+             return _construct_databricks_run_url(
+                 host=workspace_info.host,
+                 experiment_id=experiment_id,
+                 run_id=run_id,
+                 workspace_id=workspace_info.workspace_id,
+                 artifact_path=artifact_path,
+             )
+     except Exception:
+         return None
+
+
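A usage sketch under the same import assumption; the run ID is hypothetical, and the call returns None outside a Databricks workspace:

```python
from mlflow.utils.databricks_utils import get_databricks_run_url

url = get_databricks_run_url(tracking_uri="databricks", run_id="0123456789abcdef")
if url:
    print(f"View the run at {url}")
```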
+ def get_databricks_model_version_url(registry_uri: str, name: str, version: str) -> Optional[str]:
+     """Obtains a Databricks URL corresponding to the specified Model Version.
+
+     Args:
+         registry_uri: The URI of the Model Registry server containing the Model Version.
+         name: The name of the registered model containing the Model Version.
+         version: Version number of the Model Version.
+
+     Returns:
+         A Databricks URL corresponding to the specified Model Version, or None if the
+         Model Version does not belong to a Databricks Workspace.
+     """
+     try:
+         workspace_info = (
+             DatabricksWorkspaceInfo.from_environment()
+             or get_databricks_workspace_info_from_uri(registry_uri)
+         )
+         if workspace_info is not None:
+             return _construct_databricks_model_version_url(
+                 host=workspace_info.host,
+                 name=name,
+                 version=version,
+                 workspace_id=workspace_info.workspace_id,
+             )
+     except Exception:
+         return None
+
+
+ DatabricksWorkspaceInfoType = TypeVar("DatabricksWorkspaceInfo", bound="DatabricksWorkspaceInfo")
+
+
+ class DatabricksWorkspaceInfo:
+     WORKSPACE_HOST_ENV_VAR = "_DATABRICKS_WORKSPACE_HOST"
+     WORKSPACE_ID_ENV_VAR = "_DATABRICKS_WORKSPACE_ID"
+
+     def __init__(self, host: str, workspace_id: Optional[str] = None):
+         self.host = host
+         self.workspace_id = workspace_id
+
+     @classmethod
+     def from_environment(cls) -> Optional[DatabricksWorkspaceInfoType]:
+         if DatabricksWorkspaceInfo.WORKSPACE_HOST_ENV_VAR in os.environ:
+             return DatabricksWorkspaceInfo(
+                 host=os.environ[DatabricksWorkspaceInfo.WORKSPACE_HOST_ENV_VAR],
+                 workspace_id=os.environ.get(DatabricksWorkspaceInfo.WORKSPACE_ID_ENV_VAR),
+             )
+         else:
+             return None
+
+     def to_environment(self):
+         env = {
+             DatabricksWorkspaceInfo.WORKSPACE_HOST_ENV_VAR: self.host,
+         }
+         if self.workspace_id is not None:
+             env[DatabricksWorkspaceInfo.WORKSPACE_ID_ENV_VAR] = self.workspace_id
+
+         return env
+
+
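A round-trip sketch of the environment-variable serialization above; the host and workspace ID are placeholders:

```python
import os

info = DatabricksWorkspaceInfo(host="https://example.cloud.databricks.com", workspace_id="12345")
# Serialize into env vars for a child process, then rehydrate.
os.environ.update(info.to_environment())
restored = DatabricksWorkspaceInfo.from_environment()
assert (restored.host, restored.workspace_id) == (info.host, info.workspace_id)
```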
+ def get_databricks_workspace_info_from_uri(tracking_uri: str) -> Optional[DatabricksWorkspaceInfo]:
+     if not is_databricks_uri(tracking_uri):
+         return None
+
+     if is_databricks_default_tracking_uri(tracking_uri) and (
+         is_in_databricks_notebook() or is_in_databricks_job()
+     ):
+         workspace_host, workspace_id = get_workspace_info_from_dbutils()
+     else:
+         workspace_host, workspace_id = get_workspace_info_from_databricks_secrets(tracking_uri)
+         if not workspace_id:
+             _logger.info(
+                 "No workspace ID specified; if your Databricks workspaces share the same"
+                 " host URL, you may want to specify the workspace ID (along with the host"
+                 " information in the secret manager) for run lineage tracking. For more"
+                 " details on how to specify this information in the secret manager,"
+                 " please refer to the Databricks MLflow documentation."
+             )
+
+     if workspace_host:
+         return DatabricksWorkspaceInfo(host=workspace_host, workspace_id=workspace_id)
+     else:
+         return None
+
+
+ def check_databricks_secret_scope_access(scope_name):
+     dbutils = _get_dbutils()
+     if dbutils:
+         try:
+             dbutils.secrets.list(scope_name)
+         except Exception as e:
+             _logger.warning(
+                 f"Unable to access Databricks secret scope '{scope_name}' for OpenAI credentials "
+                 "that will be used to deploy the model to Databricks Model Serving. "
+                 "Please verify that the current Databricks user has 'READ' permission for "
+                 "this scope. For more information, see "
+                 "https://mlflow.org/docs/latest/python_api/openai/index.html#credential-management-for-openai-on-databricks. "  # noqa: E501
+                 f"Error: {e}"
+             )
+
+
+ def _construct_databricks_run_url(
+     host: str,
+     experiment_id: str,
+     run_id: str,
+     workspace_id: Optional[str] = None,
+     artifact_path: Optional[str] = None,
+ ) -> str:
+     run_url = host
+     if workspace_id and workspace_id != "0":
+         run_url += "?o=" + str(workspace_id)
+
+     run_url += f"#mlflow/experiments/{experiment_id}/runs/{run_id}"
+
+     if artifact_path is not None:
+         run_url += f"/artifactPath/{artifact_path.lstrip('/')}"
+
+     return run_url
+
+
+ def _construct_databricks_model_version_url(
+     host: str, name: str, version: str, workspace_id: Optional[str] = None
+ ) -> str:
+     model_version_url = host
+     if workspace_id and workspace_id != "0":
+         model_version_url += "?o=" + str(workspace_id)
+
+     model_version_url += f"#mlflow/models/{name}/versions/{version}"
+
+     return model_version_url
+
+
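A worked example of the run-URL constructor above; all IDs are made up. Note that the `?o=<workspace_id>` query is appended before the `#mlflow/...` fragment:

```python
url = _construct_databricks_run_url(
    host="https://example.cloud.databricks.com",
    experiment_id="42",
    run_id="abc123",
    workspace_id="6789",
)
assert url == "https://example.cloud.databricks.com?o=6789#mlflow/experiments/42/runs/abc123"
```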
+ def _construct_databricks_logged_model_url(
+     workspace_url: str, experiment_id: str, model_id: str, workspace_id: Optional[str] = None
+ ) -> str:
+     """
+     Get a Databricks URL for a given logged model.
+
+     Args:
+         workspace_url: The URL of the workspace the logged model is in.
+         experiment_id: The ID of the experiment the model is logged to.
+         model_id: The ID of the logged model to create the URL for.
+         workspace_id: The ID of the workspace to include as a query parameter (if provided).
+
+     Returns:
+         The Databricks URL for the logged model.
+     """
+     query = f"?o={workspace_id}" if (workspace_id and workspace_id != "0") else ""
+     return f"{workspace_url}/ml/experiments/{experiment_id}/models/{model_id}{query}"
+
+
+ def _construct_databricks_uc_registered_model_url(
+     workspace_url: str, registered_model_name: str, version: str, workspace_id: Optional[str] = None
+ ) -> str:
+     """
+     Get a Databricks URL for a given registered model version in Unity Catalog.
+
+     Args:
+         workspace_url: The URL of the workspace the registered model is in.
+         registered_model_name: The full name of the registered model containing the version.
+         version: The version of the registered model to create the URL for.
+         workspace_id: The ID of the workspace to include as a query parameter (if provided).
+
+     Returns:
+         The Databricks URL for the registered model version in Unity Catalog.
+     """
+     path = registered_model_name.replace(".", "/")
+     query = f"?o={workspace_id}" if (workspace_id and workspace_id != "0") else ""
+     return f"{workspace_url}/explore/data/models/{path}/version/{version}{query}"
+
+
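A worked example showing how the three-level Unity Catalog name maps onto the Catalog Explorer path; the model name and workspace URL are made up:

```python
url = _construct_databricks_uc_registered_model_url(
    workspace_url="https://example.cloud.databricks.com",
    registered_model_name="main.analytics.churn_model",
    version="3",
)
assert url == (
    "https://example.cloud.databricks.com"
    "/explore/data/models/main/analytics/churn_model/version/3"
)
```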
+ def _print_databricks_deployment_job_url(
+     model_name: str,
+     job_id: str,
+     workspace_url: Optional[str] = None,
+     workspace_id: Optional[str] = None,
+ ) -> Optional[str]:
+     if not workspace_url:
+         workspace_url = get_workspace_url()
+     if not workspace_id:
+         workspace_id = get_workspace_id()
+     # If there is no workspace_url, we cannot print the job URL
+     if not workspace_url:
+         return None
+
+     query = f"?o={workspace_id}" if (workspace_id and workspace_id != "0") else ""
+     job_url = f"{workspace_url}/jobs/{job_id}{query}"
+     eprint(f"🔗 Linked deployment job to '{model_name}': {job_url}")
+     return job_url
+
+
+ def _get_databricks_creds_config(tracking_uri):
+     # Note:
+     # `_get_databricks_creds_config` reads credential token values (or a password) and
+     # returns a `DatabricksConfig` object.
+     # The Databricks SDK API doesn't support reading credential token values,
+     # so in this function we still have to use the configuration providers
+     # defined in the legacy Databricks CLI Python library to read token values.
+     profile, key_prefix = get_db_info_from_uri(tracking_uri)
+
+     config = None
+
+     if profile and key_prefix:
+         # legacy way to read credentials by setting `tracking_uri` to 'databricks://scope:prefix'
+         providers = [TrackingURIConfigProvider(tracking_uri)]
+     elif profile:
+         # If `tracking_uri` is 'databricks://<profile>',
+         # MLflow should only read credentials from this profile.
+         providers = [ProfileConfigProvider(profile)]
+     else:
+         providers = [
+             # `EnvironmentVariableConfigProvider` should be prioritized at the highest level,
+             # to align with Databricks SDK behavior.
+             EnvironmentVariableConfigProvider(),
+             _dynamic_token_config_provider,
+             ProfileConfigProvider(None),
+             SparkTaskContextConfigProvider(),
+             DatabricksModelServingConfigProvider(),
+         ]
+
+     for provider in providers:
+         if provider:
+             _config = provider.get_config()
+             if _config is not None and _config.is_valid:
+                 config = _config
+                 break
+
+     if not config or not config.host:
+         _fail_malformed_databricks_auth(tracking_uri)
+
+     return config
+
+
+ def get_databricks_env_vars(tracking_uri):
+     if not mlflow.utils.uri.is_databricks_uri(tracking_uri):
+         return {}
+
+     config = _get_databricks_creds_config(tracking_uri)
+
+     if config.auth_type == "databricks-cli":
+         raise MlflowException(
+             "The authentication type is configured to 'databricks-cli'. In this case, MLflow "
+             "cannot read credential values, so it cannot construct the Databricks environment "
+             "variables required for child process authentication."
+         )
+
+     # We set these via environment variables so that only the current profile is exposed, rather
+     # than all profiles in ~/.databrickscfg; maybe better would be to mount the necessary
+     # part of ~/.databrickscfg into the container
+     env_vars = {}
+     env_vars[MLFLOW_TRACKING_URI.name] = "databricks"
+     env_vars["DATABRICKS_HOST"] = config.host
+     if config.username:
+         env_vars["DATABRICKS_USERNAME"] = config.username
+     if config.password:
+         env_vars["DATABRICKS_PASSWORD"] = config.password
+     if config.token:
+         env_vars["DATABRICKS_TOKEN"] = config.token
+     if config.insecure:
+         env_vars["DATABRICKS_INSECURE"] = str(config.insecure)
+     if config.client_id:
+         env_vars["DATABRICKS_CLIENT_ID"] = config.client_id
+     if config.client_secret:
+         env_vars["DATABRICKS_CLIENT_SECRET"] = config.client_secret
+
+     workspace_info = get_databricks_workspace_info_from_uri(tracking_uri)
+     if workspace_info is not None:
+         env_vars.update(workspace_info.to_environment())
+
+     return env_vars
+
+
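A sketch of the intended use: propagating only the active profile's credentials to a child process. The subprocess command is illustrative, and the import path is the same assumption as above:

```python
import os
import subprocess

from mlflow.utils.databricks_utils import get_databricks_env_vars

# Merge the Databricks credential env vars into the child's environment.
env = {**os.environ, **get_databricks_env_vars("databricks")}
subprocess.run(["python", "-c", "import mlflow; print(mlflow.get_tracking_uri())"], env=env)
```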
+ def _get_databricks_serverless_env_vars() -> dict[str, str]:
+     """
+     Returns the environment variables required to initialize a WorkspaceClient in a subprocess
+     with serverless compute.
+
+     Note:
+         Databricks authentication related environment variables such as DATABRICKS_HOST are
+         set in the _capture_imported_modules function.
+     """
+     envs = {}
+     if "SPARK_REMOTE" in os.environ:
+         envs["SPARK_LOCAL_REMOTE"] = os.environ["SPARK_REMOTE"]
+     else:
+         _logger.warning(
+             "Missing required environment variable `SPARK_LOCAL_REMOTE` or `SPARK_REMOTE`. "
+             "These are necessary to initialize the WorkspaceClient with serverless compute in "
+             "a subprocess in Databricks for UC function execution. Setting the value to 'true'."
+         )
+         envs["SPARK_LOCAL_REMOTE"] = "true"
+     return envs
+
+
+ class DatabricksRuntimeVersion(NamedTuple):
+     is_client_image: bool
+     major: int
+     minor: int
+
+     @classmethod
+     def parse(cls, databricks_runtime: Optional[str] = None):
+         dbr_version = databricks_runtime or get_databricks_runtime_version()
+         try:
+             dbr_version_splits = dbr_version.split(".", maxsplit=2)
+             if dbr_version_splits[0] == "client":
+                 is_client_image = True
+                 major = int(dbr_version_splits[1])
+                 minor = int(dbr_version_splits[2]) if len(dbr_version_splits) > 2 else 0
+             else:
+                 is_client_image = False
+                 major = int(dbr_version_splits[0])
+                 minor = int(dbr_version_splits[1])
+             return cls(is_client_image, major, minor)
+         except Exception:
+             raise MlflowException(f"Failed to parse databricks runtime version '{dbr_version}'.")
+
+
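Parse examples covering both version-string forms handled above; the version strings themselves are illustrative:

```python
# Standard runtime version string.
v = DatabricksRuntimeVersion.parse("13.3")
assert (v.is_client_image, v.major, v.minor) == (False, 13, 3)

# Client-image version string.
v = DatabricksRuntimeVersion.parse("client.1.5")
assert (v.is_client_image, v.major, v.minor) == (True, 1, 5)
```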
+ def get_databricks_runtime_major_minor_version():
+     return DatabricksRuntimeVersion.parse()
+
+
+ _dynamic_token_config_provider = None
+
+
+ def _init_databricks_dynamic_token_config_provider(entry_point):
+     """
+     Set a custom DatabricksConfigProvider with the hostname and token of the
+     user running the current command (achieved by looking at
+     PythonAccessibleThreadLocals.commandContext, via the already-exposed
+     NotebookUtils.getContext API).
+     """
+     global _dynamic_token_config_provider
+
+     notebook_utils = entry_point.getDbutils().notebook()
+
+     dbr_version = get_databricks_runtime_major_minor_version()
+     dbr_major_minor_version = (dbr_version.major, dbr_version.minor)
+
+     # the CLI code in client-branch-1.0 is the same as in the 15.0 runtime branch
+     if dbr_version.is_client_image or dbr_major_minor_version >= (13, 2):
+
+         class DynamicConfigProvider(DatabricksConfigProvider):
+             def get_config(self):
+                 logger = entry_point.getLogger()
+                 try:
+                     from dbruntime.databricks_repl_context import get_context
+
+                     ctx = get_context()
+                     if ctx and ctx.apiUrl and ctx.apiToken:
+                         return DatabricksConfig.from_token(
+                             host=ctx.apiUrl, token=ctx.apiToken, insecure=ctx.sslTrustAll
+                         )
+                 except Exception as e:
+                     _logger.debug(
+                         "Unexpected internal error while constructing `DatabricksConfig` "
+                         f"from REPL context: {e}",
+                     )
+                 # Invoking getContext() will attempt to find the credentials related to the
+                 # current command execution, so it's critical that we execute it on every
+                 # get_config().
+                 api_url_option = notebook_utils.getContext().apiUrl()
+                 api_url = api_url_option.get() if api_url_option.isDefined() else None
+                 # Invoking getNonUcApiToken() will attempt to find the credentials related
+                 # to the current command execution and refresh them automatically if expired,
+                 # so it's critical that we execute it on every get_config().
+                 api_token = None
+                 try:
+                     api_token = entry_point.getNonUcApiToken()
+                 except Exception:
+                     # Using apiToken from the command context would return a token that is
+                     # not refreshed.
+                     fallback_api_token_option = notebook_utils.getContext().apiToken()
+                     logger.logUsage(
+                         "refreshableTokenNotFound",
+                         {"api_url": api_url},
+                         None,
+                     )
+                     if fallback_api_token_option.isDefined():
+                         api_token = fallback_api_token_option.get()
+
+                 ssl_trust_all = entry_point.getDriverConf().workflowSslTrustAll()
+
+                 if api_token is None or api_url is None:
+                     return None
+
+                 return DatabricksConfig.from_token(
+                     host=api_url, token=api_token, insecure=ssl_trust_all
+                 )
+     elif dbr_major_minor_version >= (10, 3):
+
+         class DynamicConfigProvider(DatabricksConfigProvider):
+             def get_config(self):
+                 try:
+                     from dbruntime.databricks_repl_context import get_context
+
+                     ctx = get_context()
+                     if ctx and ctx.apiUrl and ctx.apiToken:
+                         return DatabricksConfig.from_token(
+                             host=ctx.apiUrl, token=ctx.apiToken, insecure=ctx.sslTrustAll
+                         )
+                 except Exception as e:
+                     _logger.debug(
+                         "Unexpected internal error while constructing `DatabricksConfig` "
+                         f"from REPL context: {e}",
+                     )
+                 # Invoking getContext() will attempt to find the credentials related to the
+                 # current command execution, so it's critical that we execute it on every
+                 # get_config().
+                 api_token_option = notebook_utils.getContext().apiToken()
+                 api_url_option = notebook_utils.getContext().apiUrl()
+                 ssl_trust_all = entry_point.getDriverConf().workflowSslTrustAll()
+
+                 if not api_token_option.isDefined() or not api_url_option.isDefined():
+                     return None
+
+                 return DatabricksConfig.from_token(
+                     host=api_url_option.get(), token=api_token_option.get(), insecure=ssl_trust_all
+                 )
+     else:
+
+         class DynamicConfigProvider(DatabricksConfigProvider):
+             def get_config(self):
+                 # Invoking getContext() will attempt to find the credentials related to the
+                 # current command execution, so it's critical that we execute it on every
+                 # get_config().
+                 api_token_option = notebook_utils.getContext().apiToken()
+                 api_url_option = notebook_utils.getContext().apiUrl()
+                 ssl_trust_all = entry_point.getDriverConf().workflowSslTrustAll()
+
+                 if not api_token_option.isDefined() or not api_url_option.isDefined():
+                     return None
+
+                 return DatabricksConfig.from_token(
+                     host=api_url_option.get(), token=api_token_option.get(), insecure=ssl_trust_all
+                 )
+
+     _dynamic_token_config_provider = DynamicConfigProvider()
+
+
+ if is_in_databricks_runtime():
+     try:
+         dbutils = _get_dbutils()
+         _init_databricks_dynamic_token_config_provider(dbutils.entry_point)
+     except _NoDbutilsError:
+         # If no dbutils is available, the code is running in the Databricks driver-local
+         # suite; in this case, we don't need to initialize a Databricks token because
+         # there is no backend MLflow service available.
+         pass
+
+
+ def get_databricks_nfs_temp_dir():
+     entry_point = _get_dbutils().entry_point
+     if getpass.getuser().lower() == "root":
+         return entry_point.getReplNFSTempDir()
+     else:
+         try:
+             # If the user is not root, the code is running in Safe Spark.
+             # In this case, we should get the temporary directory of the current user,
+             # and `getReplNFSTempDir` will be deprecated for this case.
+             return entry_point.getUserNFSTempDir()
+         except Exception:
+             # fallback
+             return entry_point.getReplNFSTempDir()
+
+
+ def get_databricks_local_temp_dir():
+     entry_point = _get_dbutils().entry_point
+     if getpass.getuser().lower() == "root":
+         return entry_point.getReplLocalTempDir()
+     else:
+         try:
+             # If the user is not root, the code is running in Safe Spark.
+             # In this case, we should get the temporary directory of the current user,
+             # and `getReplLocalTempDir` will be deprecated for this case.
+             return entry_point.getUserLocalTempDir()
+         except Exception:
+             # fallback
+             return entry_point.getReplLocalTempDir()
+
+
+ def stage_model_for_databricks_model_serving(model_name: str, model_version: str):
+     response = http_request(
+         host_creds=get_databricks_host_creds(),
+         endpoint="/api/2.0/serving-endpoints:stageDeployment",
+         method="POST",
+         raise_on_status=False,
+         json={
+             "model_name": model_name,
+             "model_version": model_version,
+         },
+     )
+     augmented_raise_for_status(response)
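
A hedged usage sketch under the same import assumption; the model name and version are hypothetical, and the call presumes ambient Databricks credentials are configured:

```python
from mlflow.utils.databricks_utils import stage_model_for_databricks_model_serving

# Pre-stage a (hypothetical) model version for Databricks Model Serving.
stage_model_for_databricks_model_serving(
    model_name="main.analytics.churn_model", model_version="3"
)
```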