genesis-flow 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (645)
  1. genesis_flow-1.0.0.dist-info/METADATA +822 -0
  2. genesis_flow-1.0.0.dist-info/RECORD +645 -0
  3. genesis_flow-1.0.0.dist-info/WHEEL +5 -0
  4. genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
  5. genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
  6. genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
  7. mlflow/__init__.py +367 -0
  8. mlflow/__main__.py +3 -0
  9. mlflow/ag2/__init__.py +56 -0
  10. mlflow/ag2/ag2_logger.py +294 -0
  11. mlflow/anthropic/__init__.py +40 -0
  12. mlflow/anthropic/autolog.py +129 -0
  13. mlflow/anthropic/chat.py +144 -0
  14. mlflow/artifacts/__init__.py +268 -0
  15. mlflow/autogen/__init__.py +144 -0
  16. mlflow/autogen/chat.py +142 -0
  17. mlflow/azure/__init__.py +26 -0
  18. mlflow/azure/auth_handler.py +257 -0
  19. mlflow/azure/client.py +319 -0
  20. mlflow/azure/config.py +120 -0
  21. mlflow/azure/connection_factory.py +340 -0
  22. mlflow/azure/exceptions.py +27 -0
  23. mlflow/azure/stores.py +327 -0
  24. mlflow/azure/utils.py +183 -0
  25. mlflow/bedrock/__init__.py +45 -0
  26. mlflow/bedrock/_autolog.py +202 -0
  27. mlflow/bedrock/chat.py +122 -0
  28. mlflow/bedrock/stream.py +160 -0
  29. mlflow/bedrock/utils.py +43 -0
  30. mlflow/cli.py +707 -0
  31. mlflow/client.py +12 -0
  32. mlflow/config/__init__.py +56 -0
  33. mlflow/crewai/__init__.py +79 -0
  34. mlflow/crewai/autolog.py +253 -0
  35. mlflow/crewai/chat.py +29 -0
  36. mlflow/data/__init__.py +75 -0
  37. mlflow/data/artifact_dataset_sources.py +170 -0
  38. mlflow/data/code_dataset_source.py +40 -0
  39. mlflow/data/dataset.py +123 -0
  40. mlflow/data/dataset_registry.py +168 -0
  41. mlflow/data/dataset_source.py +110 -0
  42. mlflow/data/dataset_source_registry.py +219 -0
  43. mlflow/data/delta_dataset_source.py +167 -0
  44. mlflow/data/digest_utils.py +108 -0
  45. mlflow/data/evaluation_dataset.py +562 -0
  46. mlflow/data/filesystem_dataset_source.py +81 -0
  47. mlflow/data/http_dataset_source.py +145 -0
  48. mlflow/data/huggingface_dataset.py +258 -0
  49. mlflow/data/huggingface_dataset_source.py +118 -0
  50. mlflow/data/meta_dataset.py +104 -0
  51. mlflow/data/numpy_dataset.py +223 -0
  52. mlflow/data/pandas_dataset.py +231 -0
  53. mlflow/data/polars_dataset.py +352 -0
  54. mlflow/data/pyfunc_dataset_mixin.py +31 -0
  55. mlflow/data/schema.py +76 -0
  56. mlflow/data/sources.py +1 -0
  57. mlflow/data/spark_dataset.py +406 -0
  58. mlflow/data/spark_dataset_source.py +74 -0
  59. mlflow/data/spark_delta_utils.py +118 -0
  60. mlflow/data/tensorflow_dataset.py +350 -0
  61. mlflow/data/uc_volume_dataset_source.py +81 -0
  62. mlflow/db.py +27 -0
  63. mlflow/dspy/__init__.py +17 -0
  64. mlflow/dspy/autolog.py +197 -0
  65. mlflow/dspy/callback.py +398 -0
  66. mlflow/dspy/constant.py +1 -0
  67. mlflow/dspy/load.py +93 -0
  68. mlflow/dspy/save.py +393 -0
  69. mlflow/dspy/util.py +109 -0
  70. mlflow/dspy/wrapper.py +226 -0
  71. mlflow/entities/__init__.py +104 -0
  72. mlflow/entities/_mlflow_object.py +52 -0
  73. mlflow/entities/assessment.py +545 -0
  74. mlflow/entities/assessment_error.py +80 -0
  75. mlflow/entities/assessment_source.py +141 -0
  76. mlflow/entities/dataset.py +92 -0
  77. mlflow/entities/dataset_input.py +51 -0
  78. mlflow/entities/dataset_summary.py +62 -0
  79. mlflow/entities/document.py +48 -0
  80. mlflow/entities/experiment.py +109 -0
  81. mlflow/entities/experiment_tag.py +35 -0
  82. mlflow/entities/file_info.py +45 -0
  83. mlflow/entities/input_tag.py +35 -0
  84. mlflow/entities/lifecycle_stage.py +35 -0
  85. mlflow/entities/logged_model.py +228 -0
  86. mlflow/entities/logged_model_input.py +26 -0
  87. mlflow/entities/logged_model_output.py +32 -0
  88. mlflow/entities/logged_model_parameter.py +46 -0
  89. mlflow/entities/logged_model_status.py +74 -0
  90. mlflow/entities/logged_model_tag.py +33 -0
  91. mlflow/entities/metric.py +200 -0
  92. mlflow/entities/model_registry/__init__.py +29 -0
  93. mlflow/entities/model_registry/_model_registry_entity.py +13 -0
  94. mlflow/entities/model_registry/model_version.py +243 -0
  95. mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
  96. mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
  97. mlflow/entities/model_registry/model_version_search.py +25 -0
  98. mlflow/entities/model_registry/model_version_stages.py +25 -0
  99. mlflow/entities/model_registry/model_version_status.py +35 -0
  100. mlflow/entities/model_registry/model_version_tag.py +35 -0
  101. mlflow/entities/model_registry/prompt.py +73 -0
  102. mlflow/entities/model_registry/prompt_version.py +244 -0
  103. mlflow/entities/model_registry/registered_model.py +175 -0
  104. mlflow/entities/model_registry/registered_model_alias.py +35 -0
  105. mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
  106. mlflow/entities/model_registry/registered_model_search.py +25 -0
  107. mlflow/entities/model_registry/registered_model_tag.py +35 -0
  108. mlflow/entities/multipart_upload.py +74 -0
  109. mlflow/entities/param.py +49 -0
  110. mlflow/entities/run.py +97 -0
  111. mlflow/entities/run_data.py +84 -0
  112. mlflow/entities/run_info.py +188 -0
  113. mlflow/entities/run_inputs.py +59 -0
  114. mlflow/entities/run_outputs.py +43 -0
  115. mlflow/entities/run_status.py +41 -0
  116. mlflow/entities/run_tag.py +36 -0
  117. mlflow/entities/source_type.py +31 -0
  118. mlflow/entities/span.py +774 -0
  119. mlflow/entities/span_event.py +96 -0
  120. mlflow/entities/span_status.py +102 -0
  121. mlflow/entities/trace.py +317 -0
  122. mlflow/entities/trace_data.py +71 -0
  123. mlflow/entities/trace_info.py +220 -0
  124. mlflow/entities/trace_info_v2.py +162 -0
  125. mlflow/entities/trace_location.py +173 -0
  126. mlflow/entities/trace_state.py +39 -0
  127. mlflow/entities/trace_status.py +68 -0
  128. mlflow/entities/view_type.py +51 -0
  129. mlflow/environment_variables.py +866 -0
  130. mlflow/evaluation/__init__.py +16 -0
  131. mlflow/evaluation/assessment.py +369 -0
  132. mlflow/evaluation/evaluation.py +411 -0
  133. mlflow/evaluation/evaluation_tag.py +61 -0
  134. mlflow/evaluation/fluent.py +48 -0
  135. mlflow/evaluation/utils.py +201 -0
  136. mlflow/exceptions.py +213 -0
  137. mlflow/experiments.py +140 -0
  138. mlflow/gemini/__init__.py +81 -0
  139. mlflow/gemini/autolog.py +186 -0
  140. mlflow/gemini/chat.py +261 -0
  141. mlflow/genai/__init__.py +71 -0
  142. mlflow/genai/datasets/__init__.py +67 -0
  143. mlflow/genai/datasets/evaluation_dataset.py +131 -0
  144. mlflow/genai/evaluation/__init__.py +3 -0
  145. mlflow/genai/evaluation/base.py +411 -0
  146. mlflow/genai/evaluation/constant.py +23 -0
  147. mlflow/genai/evaluation/utils.py +244 -0
  148. mlflow/genai/judges/__init__.py +21 -0
  149. mlflow/genai/judges/databricks.py +404 -0
  150. mlflow/genai/label_schemas/__init__.py +153 -0
  151. mlflow/genai/label_schemas/label_schemas.py +209 -0
  152. mlflow/genai/labeling/__init__.py +159 -0
  153. mlflow/genai/labeling/labeling.py +250 -0
  154. mlflow/genai/optimize/__init__.py +13 -0
  155. mlflow/genai/optimize/base.py +198 -0
  156. mlflow/genai/optimize/optimizers/__init__.py +4 -0
  157. mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
  158. mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
  159. mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
  160. mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
  161. mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
  162. mlflow/genai/optimize/types.py +75 -0
  163. mlflow/genai/optimize/util.py +30 -0
  164. mlflow/genai/prompts/__init__.py +206 -0
  165. mlflow/genai/scheduled_scorers.py +431 -0
  166. mlflow/genai/scorers/__init__.py +26 -0
  167. mlflow/genai/scorers/base.py +492 -0
  168. mlflow/genai/scorers/builtin_scorers.py +765 -0
  169. mlflow/genai/scorers/scorer_utils.py +138 -0
  170. mlflow/genai/scorers/validation.py +165 -0
  171. mlflow/genai/utils/data_validation.py +146 -0
  172. mlflow/genai/utils/enum_utils.py +23 -0
  173. mlflow/genai/utils/trace_utils.py +211 -0
  174. mlflow/groq/__init__.py +42 -0
  175. mlflow/groq/_groq_autolog.py +74 -0
  176. mlflow/johnsnowlabs/__init__.py +888 -0
  177. mlflow/langchain/__init__.py +24 -0
  178. mlflow/langchain/api_request_parallel_processor.py +330 -0
  179. mlflow/langchain/autolog.py +147 -0
  180. mlflow/langchain/chat_agent_langgraph.py +340 -0
  181. mlflow/langchain/constant.py +1 -0
  182. mlflow/langchain/constants.py +1 -0
  183. mlflow/langchain/databricks_dependencies.py +444 -0
  184. mlflow/langchain/langchain_tracer.py +597 -0
  185. mlflow/langchain/model.py +919 -0
  186. mlflow/langchain/output_parsers.py +142 -0
  187. mlflow/langchain/retriever_chain.py +153 -0
  188. mlflow/langchain/runnables.py +527 -0
  189. mlflow/langchain/utils/chat.py +402 -0
  190. mlflow/langchain/utils/logging.py +671 -0
  191. mlflow/langchain/utils/serialization.py +36 -0
  192. mlflow/legacy_databricks_cli/__init__.py +0 -0
  193. mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
  194. mlflow/legacy_databricks_cli/configure/provider.py +482 -0
  195. mlflow/litellm/__init__.py +175 -0
  196. mlflow/llama_index/__init__.py +22 -0
  197. mlflow/llama_index/autolog.py +55 -0
  198. mlflow/llama_index/chat.py +43 -0
  199. mlflow/llama_index/constant.py +1 -0
  200. mlflow/llama_index/model.py +577 -0
  201. mlflow/llama_index/pyfunc_wrapper.py +332 -0
  202. mlflow/llama_index/serialize_objects.py +188 -0
  203. mlflow/llama_index/tracer.py +561 -0
  204. mlflow/metrics/__init__.py +479 -0
  205. mlflow/metrics/base.py +39 -0
  206. mlflow/metrics/genai/__init__.py +25 -0
  207. mlflow/metrics/genai/base.py +101 -0
  208. mlflow/metrics/genai/genai_metric.py +771 -0
  209. mlflow/metrics/genai/metric_definitions.py +450 -0
  210. mlflow/metrics/genai/model_utils.py +371 -0
  211. mlflow/metrics/genai/prompt_template.py +68 -0
  212. mlflow/metrics/genai/prompts/__init__.py +0 -0
  213. mlflow/metrics/genai/prompts/v1.py +422 -0
  214. mlflow/metrics/genai/utils.py +6 -0
  215. mlflow/metrics/metric_definitions.py +619 -0
  216. mlflow/mismatch.py +34 -0
  217. mlflow/mistral/__init__.py +34 -0
  218. mlflow/mistral/autolog.py +71 -0
  219. mlflow/mistral/chat.py +135 -0
  220. mlflow/ml_package_versions.py +452 -0
  221. mlflow/models/__init__.py +97 -0
  222. mlflow/models/auth_policy.py +83 -0
  223. mlflow/models/cli.py +354 -0
  224. mlflow/models/container/__init__.py +294 -0
  225. mlflow/models/container/scoring_server/__init__.py +0 -0
  226. mlflow/models/container/scoring_server/nginx.conf +39 -0
  227. mlflow/models/dependencies_schemas.py +287 -0
  228. mlflow/models/display_utils.py +158 -0
  229. mlflow/models/docker_utils.py +211 -0
  230. mlflow/models/evaluation/__init__.py +23 -0
  231. mlflow/models/evaluation/_shap_patch.py +64 -0
  232. mlflow/models/evaluation/artifacts.py +194 -0
  233. mlflow/models/evaluation/base.py +1811 -0
  234. mlflow/models/evaluation/calibration_curve.py +109 -0
  235. mlflow/models/evaluation/default_evaluator.py +996 -0
  236. mlflow/models/evaluation/deprecated.py +23 -0
  237. mlflow/models/evaluation/evaluator_registry.py +80 -0
  238. mlflow/models/evaluation/evaluators/classifier.py +704 -0
  239. mlflow/models/evaluation/evaluators/default.py +233 -0
  240. mlflow/models/evaluation/evaluators/regressor.py +96 -0
  241. mlflow/models/evaluation/evaluators/shap.py +296 -0
  242. mlflow/models/evaluation/lift_curve.py +178 -0
  243. mlflow/models/evaluation/utils/metric.py +123 -0
  244. mlflow/models/evaluation/utils/trace.py +179 -0
  245. mlflow/models/evaluation/validation.py +434 -0
  246. mlflow/models/flavor_backend.py +93 -0
  247. mlflow/models/flavor_backend_registry.py +53 -0
  248. mlflow/models/model.py +1639 -0
  249. mlflow/models/model_config.py +150 -0
  250. mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
  251. mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
  252. mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
  253. mlflow/models/python_api.py +369 -0
  254. mlflow/models/rag_signatures.py +128 -0
  255. mlflow/models/resources.py +321 -0
  256. mlflow/models/signature.py +662 -0
  257. mlflow/models/utils.py +2054 -0
  258. mlflow/models/wheeled_model.py +280 -0
  259. mlflow/openai/__init__.py +57 -0
  260. mlflow/openai/_agent_tracer.py +364 -0
  261. mlflow/openai/api_request_parallel_processor.py +131 -0
  262. mlflow/openai/autolog.py +509 -0
  263. mlflow/openai/constant.py +1 -0
  264. mlflow/openai/model.py +824 -0
  265. mlflow/openai/utils/chat_schema.py +367 -0
  266. mlflow/optuna/__init__.py +3 -0
  267. mlflow/optuna/storage.py +646 -0
  268. mlflow/plugins/__init__.py +72 -0
  269. mlflow/plugins/base.py +358 -0
  270. mlflow/plugins/builtin/__init__.py +24 -0
  271. mlflow/plugins/builtin/pytorch_plugin.py +150 -0
  272. mlflow/plugins/builtin/sklearn_plugin.py +158 -0
  273. mlflow/plugins/builtin/transformers_plugin.py +187 -0
  274. mlflow/plugins/cli.py +321 -0
  275. mlflow/plugins/discovery.py +340 -0
  276. mlflow/plugins/manager.py +465 -0
  277. mlflow/plugins/registry.py +316 -0
  278. mlflow/plugins/templates/framework_plugin_template.py +329 -0
  279. mlflow/prompt/constants.py +20 -0
  280. mlflow/prompt/promptlab_model.py +197 -0
  281. mlflow/prompt/registry_utils.py +248 -0
  282. mlflow/promptflow/__init__.py +495 -0
  283. mlflow/protos/__init__.py +0 -0
  284. mlflow/protos/assessments_pb2.py +174 -0
  285. mlflow/protos/databricks_artifacts_pb2.py +489 -0
  286. mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
  287. mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
  288. mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
  289. mlflow/protos/databricks_pb2.py +267 -0
  290. mlflow/protos/databricks_trace_server_pb2.py +374 -0
  291. mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
  292. mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
  293. mlflow/protos/facet_feature_statistics_pb2.py +296 -0
  294. mlflow/protos/internal_pb2.py +77 -0
  295. mlflow/protos/mlflow_artifacts_pb2.py +336 -0
  296. mlflow/protos/model_registry_pb2.py +1073 -0
  297. mlflow/protos/scalapb/__init__.py +0 -0
  298. mlflow/protos/scalapb/scalapb_pb2.py +104 -0
  299. mlflow/protos/service_pb2.py +2600 -0
  300. mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
  301. mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
  302. mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
  303. mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
  304. mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
  305. mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
  306. mlflow/py.typed +0 -0
  307. mlflow/pydantic_ai/__init__.py +57 -0
  308. mlflow/pydantic_ai/autolog.py +173 -0
  309. mlflow/pyfunc/__init__.py +3844 -0
  310. mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
  311. mlflow/pyfunc/backend.py +523 -0
  312. mlflow/pyfunc/context.py +78 -0
  313. mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
  314. mlflow/pyfunc/loaders/__init__.py +7 -0
  315. mlflow/pyfunc/loaders/chat_agent.py +117 -0
  316. mlflow/pyfunc/loaders/chat_model.py +125 -0
  317. mlflow/pyfunc/loaders/code_model.py +31 -0
  318. mlflow/pyfunc/loaders/responses_agent.py +112 -0
  319. mlflow/pyfunc/mlserver.py +46 -0
  320. mlflow/pyfunc/model.py +1473 -0
  321. mlflow/pyfunc/scoring_server/__init__.py +604 -0
  322. mlflow/pyfunc/scoring_server/app.py +7 -0
  323. mlflow/pyfunc/scoring_server/client.py +146 -0
  324. mlflow/pyfunc/spark_model_cache.py +48 -0
  325. mlflow/pyfunc/stdin_server.py +44 -0
  326. mlflow/pyfunc/utils/__init__.py +3 -0
  327. mlflow/pyfunc/utils/data_validation.py +224 -0
  328. mlflow/pyfunc/utils/environment.py +22 -0
  329. mlflow/pyfunc/utils/input_converter.py +47 -0
  330. mlflow/pyfunc/utils/serving_data_parser.py +11 -0
  331. mlflow/pytorch/__init__.py +1171 -0
  332. mlflow/pytorch/_lightning_autolog.py +580 -0
  333. mlflow/pytorch/_pytorch_autolog.py +50 -0
  334. mlflow/pytorch/pickle_module.py +35 -0
  335. mlflow/rfunc/__init__.py +42 -0
  336. mlflow/rfunc/backend.py +134 -0
  337. mlflow/runs.py +89 -0
  338. mlflow/server/__init__.py +302 -0
  339. mlflow/server/auth/__init__.py +1224 -0
  340. mlflow/server/auth/__main__.py +4 -0
  341. mlflow/server/auth/basic_auth.ini +6 -0
  342. mlflow/server/auth/cli.py +11 -0
  343. mlflow/server/auth/client.py +537 -0
  344. mlflow/server/auth/config.py +34 -0
  345. mlflow/server/auth/db/__init__.py +0 -0
  346. mlflow/server/auth/db/cli.py +18 -0
  347. mlflow/server/auth/db/migrations/__init__.py +0 -0
  348. mlflow/server/auth/db/migrations/alembic.ini +110 -0
  349. mlflow/server/auth/db/migrations/env.py +76 -0
  350. mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
  351. mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
  352. mlflow/server/auth/db/models.py +67 -0
  353. mlflow/server/auth/db/utils.py +37 -0
  354. mlflow/server/auth/entities.py +165 -0
  355. mlflow/server/auth/logo.py +14 -0
  356. mlflow/server/auth/permissions.py +65 -0
  357. mlflow/server/auth/routes.py +18 -0
  358. mlflow/server/auth/sqlalchemy_store.py +263 -0
  359. mlflow/server/graphql/__init__.py +0 -0
  360. mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
  361. mlflow/server/graphql/graphql_custom_scalars.py +24 -0
  362. mlflow/server/graphql/graphql_errors.py +15 -0
  363. mlflow/server/graphql/graphql_no_batching.py +89 -0
  364. mlflow/server/graphql/graphql_schema_extensions.py +74 -0
  365. mlflow/server/handlers.py +3217 -0
  366. mlflow/server/prometheus_exporter.py +17 -0
  367. mlflow/server/validation.py +30 -0
  368. mlflow/shap/__init__.py +691 -0
  369. mlflow/sklearn/__init__.py +1994 -0
  370. mlflow/sklearn/utils.py +1041 -0
  371. mlflow/smolagents/__init__.py +66 -0
  372. mlflow/smolagents/autolog.py +139 -0
  373. mlflow/smolagents/chat.py +29 -0
  374. mlflow/store/__init__.py +10 -0
  375. mlflow/store/_unity_catalog/__init__.py +1 -0
  376. mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
  377. mlflow/store/_unity_catalog/lineage/constants.py +2 -0
  378. mlflow/store/_unity_catalog/registry/__init__.py +6 -0
  379. mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
  380. mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
  381. mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
  382. mlflow/store/_unity_catalog/registry/utils.py +121 -0
  383. mlflow/store/artifact/__init__.py +0 -0
  384. mlflow/store/artifact/artifact_repo.py +472 -0
  385. mlflow/store/artifact/artifact_repository_registry.py +154 -0
  386. mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
  387. mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
  388. mlflow/store/artifact/cli.py +141 -0
  389. mlflow/store/artifact/cloud_artifact_repo.py +332 -0
  390. mlflow/store/artifact/databricks_artifact_repo.py +729 -0
  391. mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
  392. mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
  393. mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
  394. mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
  395. mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
  396. mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
  397. mlflow/store/artifact/ftp_artifact_repo.py +132 -0
  398. mlflow/store/artifact/gcs_artifact_repo.py +296 -0
  399. mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
  400. mlflow/store/artifact/http_artifact_repo.py +218 -0
  401. mlflow/store/artifact/local_artifact_repo.py +142 -0
  402. mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
  403. mlflow/store/artifact/models_artifact_repo.py +259 -0
  404. mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
  405. mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
  406. mlflow/store/artifact/r2_artifact_repo.py +70 -0
  407. mlflow/store/artifact/runs_artifact_repo.py +265 -0
  408. mlflow/store/artifact/s3_artifact_repo.py +330 -0
  409. mlflow/store/artifact/sftp_artifact_repo.py +141 -0
  410. mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
  411. mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
  412. mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
  413. mlflow/store/artifact/utils/__init__.py +0 -0
  414. mlflow/store/artifact/utils/models.py +148 -0
  415. mlflow/store/db/__init__.py +0 -0
  416. mlflow/store/db/base_sql_model.py +3 -0
  417. mlflow/store/db/db_types.py +10 -0
  418. mlflow/store/db/utils.py +314 -0
  419. mlflow/store/db_migrations/__init__.py +0 -0
  420. mlflow/store/db_migrations/alembic.ini +74 -0
  421. mlflow/store/db_migrations/env.py +84 -0
  422. mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
  423. mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
  424. mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
  425. mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
  426. mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
  427. mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
  428. mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
  429. mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
  430. mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
  431. mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
  432. mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
  433. mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
  434. mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
  435. mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
  436. mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
  437. mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
  438. mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
  439. mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
  440. mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
  441. mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
  442. mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
  443. mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
  444. mlflow/store/db_migrations/versions/__init__.py +0 -0
  445. mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
  446. mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
  447. mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
  448. mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
  449. mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
  450. mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
  451. mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
  452. mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
  453. mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
  454. mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
  455. mlflow/store/entities/__init__.py +3 -0
  456. mlflow/store/entities/paged_list.py +18 -0
  457. mlflow/store/model_registry/__init__.py +10 -0
  458. mlflow/store/model_registry/abstract_store.py +1081 -0
  459. mlflow/store/model_registry/base_rest_store.py +44 -0
  460. mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
  461. mlflow/store/model_registry/dbmodels/__init__.py +0 -0
  462. mlflow/store/model_registry/dbmodels/models.py +206 -0
  463. mlflow/store/model_registry/file_store.py +1091 -0
  464. mlflow/store/model_registry/rest_store.py +481 -0
  465. mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
  466. mlflow/store/tracking/__init__.py +23 -0
  467. mlflow/store/tracking/abstract_store.py +816 -0
  468. mlflow/store/tracking/dbmodels/__init__.py +0 -0
  469. mlflow/store/tracking/dbmodels/initial_models.py +243 -0
  470. mlflow/store/tracking/dbmodels/models.py +1073 -0
  471. mlflow/store/tracking/file_store.py +2438 -0
  472. mlflow/store/tracking/postgres_managed_identity.py +146 -0
  473. mlflow/store/tracking/rest_store.py +1131 -0
  474. mlflow/store/tracking/sqlalchemy_store.py +2785 -0
  475. mlflow/system_metrics/__init__.py +61 -0
  476. mlflow/system_metrics/metrics/__init__.py +0 -0
  477. mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
  478. mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
  479. mlflow/system_metrics/metrics/disk_monitor.py +21 -0
  480. mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
  481. mlflow/system_metrics/metrics/network_monitor.py +34 -0
  482. mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
  483. mlflow/system_metrics/system_metrics_monitor.py +198 -0
  484. mlflow/tracing/__init__.py +16 -0
  485. mlflow/tracing/assessment.py +356 -0
  486. mlflow/tracing/client.py +531 -0
  487. mlflow/tracing/config.py +125 -0
  488. mlflow/tracing/constant.py +105 -0
  489. mlflow/tracing/destination.py +81 -0
  490. mlflow/tracing/display/__init__.py +40 -0
  491. mlflow/tracing/display/display_handler.py +196 -0
  492. mlflow/tracing/export/async_export_queue.py +186 -0
  493. mlflow/tracing/export/inference_table.py +138 -0
  494. mlflow/tracing/export/mlflow_v3.py +137 -0
  495. mlflow/tracing/export/utils.py +70 -0
  496. mlflow/tracing/fluent.py +1417 -0
  497. mlflow/tracing/processor/base_mlflow.py +199 -0
  498. mlflow/tracing/processor/inference_table.py +175 -0
  499. mlflow/tracing/processor/mlflow_v3.py +47 -0
  500. mlflow/tracing/processor/otel.py +73 -0
  501. mlflow/tracing/provider.py +487 -0
  502. mlflow/tracing/trace_manager.py +200 -0
  503. mlflow/tracing/utils/__init__.py +616 -0
  504. mlflow/tracing/utils/artifact_utils.py +28 -0
  505. mlflow/tracing/utils/copy.py +55 -0
  506. mlflow/tracing/utils/environment.py +55 -0
  507. mlflow/tracing/utils/exception.py +21 -0
  508. mlflow/tracing/utils/once.py +35 -0
  509. mlflow/tracing/utils/otlp.py +63 -0
  510. mlflow/tracing/utils/processor.py +54 -0
  511. mlflow/tracing/utils/search.py +292 -0
  512. mlflow/tracing/utils/timeout.py +250 -0
  513. mlflow/tracing/utils/token.py +19 -0
  514. mlflow/tracing/utils/truncation.py +124 -0
  515. mlflow/tracing/utils/warning.py +76 -0
  516. mlflow/tracking/__init__.py +39 -0
  517. mlflow/tracking/_model_registry/__init__.py +1 -0
  518. mlflow/tracking/_model_registry/client.py +764 -0
  519. mlflow/tracking/_model_registry/fluent.py +853 -0
  520. mlflow/tracking/_model_registry/registry.py +67 -0
  521. mlflow/tracking/_model_registry/utils.py +251 -0
  522. mlflow/tracking/_tracking_service/__init__.py +0 -0
  523. mlflow/tracking/_tracking_service/client.py +883 -0
  524. mlflow/tracking/_tracking_service/registry.py +56 -0
  525. mlflow/tracking/_tracking_service/utils.py +275 -0
  526. mlflow/tracking/artifact_utils.py +179 -0
  527. mlflow/tracking/client.py +5900 -0
  528. mlflow/tracking/context/__init__.py +0 -0
  529. mlflow/tracking/context/abstract_context.py +35 -0
  530. mlflow/tracking/context/databricks_cluster_context.py +15 -0
  531. mlflow/tracking/context/databricks_command_context.py +15 -0
  532. mlflow/tracking/context/databricks_job_context.py +49 -0
  533. mlflow/tracking/context/databricks_notebook_context.py +41 -0
  534. mlflow/tracking/context/databricks_repo_context.py +43 -0
  535. mlflow/tracking/context/default_context.py +51 -0
  536. mlflow/tracking/context/git_context.py +32 -0
  537. mlflow/tracking/context/registry.py +98 -0
  538. mlflow/tracking/context/system_environment_context.py +15 -0
  539. mlflow/tracking/default_experiment/__init__.py +1 -0
  540. mlflow/tracking/default_experiment/abstract_context.py +43 -0
  541. mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
  542. mlflow/tracking/default_experiment/registry.py +75 -0
  543. mlflow/tracking/fluent.py +3595 -0
  544. mlflow/tracking/metric_value_conversion_utils.py +93 -0
  545. mlflow/tracking/multimedia.py +206 -0
  546. mlflow/tracking/registry.py +86 -0
  547. mlflow/tracking/request_auth/__init__.py +0 -0
  548. mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
  549. mlflow/tracking/request_auth/registry.py +60 -0
  550. mlflow/tracking/request_header/__init__.py +0 -0
  551. mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
  552. mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
  553. mlflow/tracking/request_header/default_request_header_provider.py +17 -0
  554. mlflow/tracking/request_header/registry.py +79 -0
  555. mlflow/transformers/__init__.py +2982 -0
  556. mlflow/transformers/flavor_config.py +258 -0
  557. mlflow/transformers/hub_utils.py +83 -0
  558. mlflow/transformers/llm_inference_utils.py +468 -0
  559. mlflow/transformers/model_io.py +301 -0
  560. mlflow/transformers/peft.py +51 -0
  561. mlflow/transformers/signature.py +183 -0
  562. mlflow/transformers/torch_utils.py +55 -0
  563. mlflow/types/__init__.py +21 -0
  564. mlflow/types/agent.py +270 -0
  565. mlflow/types/chat.py +240 -0
  566. mlflow/types/llm.py +935 -0
  567. mlflow/types/responses.py +139 -0
  568. mlflow/types/responses_helpers.py +416 -0
  569. mlflow/types/schema.py +1505 -0
  570. mlflow/types/type_hints.py +647 -0
  571. mlflow/types/utils.py +753 -0
  572. mlflow/utils/__init__.py +283 -0
  573. mlflow/utils/_capture_modules.py +256 -0
  574. mlflow/utils/_capture_transformers_modules.py +75 -0
  575. mlflow/utils/_spark_utils.py +201 -0
  576. mlflow/utils/_unity_catalog_oss_utils.py +97 -0
  577. mlflow/utils/_unity_catalog_utils.py +479 -0
  578. mlflow/utils/annotations.py +218 -0
  579. mlflow/utils/arguments_utils.py +16 -0
  580. mlflow/utils/async_logging/__init__.py +1 -0
  581. mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
  582. mlflow/utils/async_logging/async_logging_queue.py +366 -0
  583. mlflow/utils/async_logging/run_artifact.py +38 -0
  584. mlflow/utils/async_logging/run_batch.py +58 -0
  585. mlflow/utils/async_logging/run_operations.py +49 -0
  586. mlflow/utils/autologging_utils/__init__.py +737 -0
  587. mlflow/utils/autologging_utils/client.py +432 -0
  588. mlflow/utils/autologging_utils/config.py +33 -0
  589. mlflow/utils/autologging_utils/events.py +294 -0
  590. mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
  591. mlflow/utils/autologging_utils/metrics_queue.py +71 -0
  592. mlflow/utils/autologging_utils/safety.py +1104 -0
  593. mlflow/utils/autologging_utils/versioning.py +95 -0
  594. mlflow/utils/checkpoint_utils.py +206 -0
  595. mlflow/utils/class_utils.py +6 -0
  596. mlflow/utils/cli_args.py +257 -0
  597. mlflow/utils/conda.py +354 -0
  598. mlflow/utils/credentials.py +231 -0
  599. mlflow/utils/data_utils.py +17 -0
  600. mlflow/utils/databricks_utils.py +1436 -0
  601. mlflow/utils/docstring_utils.py +477 -0
  602. mlflow/utils/doctor.py +133 -0
  603. mlflow/utils/download_cloud_file_chunk.py +43 -0
  604. mlflow/utils/env_manager.py +16 -0
  605. mlflow/utils/env_pack.py +131 -0
  606. mlflow/utils/environment.py +1009 -0
  607. mlflow/utils/exception_utils.py +14 -0
  608. mlflow/utils/file_utils.py +978 -0
  609. mlflow/utils/git_utils.py +77 -0
  610. mlflow/utils/gorilla.py +797 -0
  611. mlflow/utils/import_hooks/__init__.py +363 -0
  612. mlflow/utils/lazy_load.py +51 -0
  613. mlflow/utils/logging_utils.py +168 -0
  614. mlflow/utils/mime_type_utils.py +58 -0
  615. mlflow/utils/mlflow_tags.py +103 -0
  616. mlflow/utils/model_utils.py +486 -0
  617. mlflow/utils/name_utils.py +346 -0
  618. mlflow/utils/nfs_on_spark.py +62 -0
  619. mlflow/utils/openai_utils.py +164 -0
  620. mlflow/utils/os.py +12 -0
  621. mlflow/utils/oss_registry_utils.py +29 -0
  622. mlflow/utils/plugins.py +17 -0
  623. mlflow/utils/process.py +182 -0
  624. mlflow/utils/promptlab_utils.py +146 -0
  625. mlflow/utils/proto_json_utils.py +743 -0
  626. mlflow/utils/pydantic_utils.py +54 -0
  627. mlflow/utils/request_utils.py +279 -0
  628. mlflow/utils/requirements_utils.py +704 -0
  629. mlflow/utils/rest_utils.py +673 -0
  630. mlflow/utils/search_logged_model_utils.py +127 -0
  631. mlflow/utils/search_utils.py +2111 -0
  632. mlflow/utils/secure_loading.py +221 -0
  633. mlflow/utils/security_validation.py +384 -0
  634. mlflow/utils/server_cli_utils.py +61 -0
  635. mlflow/utils/spark_utils.py +15 -0
  636. mlflow/utils/string_utils.py +138 -0
  637. mlflow/utils/thread_utils.py +63 -0
  638. mlflow/utils/time.py +54 -0
  639. mlflow/utils/timeout.py +42 -0
  640. mlflow/utils/uri.py +572 -0
  641. mlflow/utils/validation.py +662 -0
  642. mlflow/utils/virtualenv.py +458 -0
  643. mlflow/utils/warnings_utils.py +25 -0
  644. mlflow/utils/yaml_utils.py +179 -0
  645. mlflow/version.py +24 -0
@@ -0,0 +1,646 @@
1
+ import copy
2
+ import datetime
3
+ import json
4
+ import threading
5
+ import time
6
+ import uuid
7
+ from collections.abc import Container, Sequence
8
+ from typing import Any, Optional
9
+
10
+ from mlflow import MlflowClient
11
+ from mlflow.entities import Metric, Param, RunTag
12
+ from mlflow.utils.mlflow_tags import MLFLOW_PARENT_RUN_ID
13
+
14
+ try:
15
+ from optuna._typing import JSONSerializable
16
+ from optuna.distributions import (
17
+ BaseDistribution,
18
+ check_distribution_compatibility,
19
+ distribution_to_json,
20
+ json_to_distribution,
21
+ )
22
+ from optuna.storages import BaseStorage
23
+ from optuna.storages._base import DEFAULT_STUDY_NAME_PREFIX
24
+ from optuna.study import StudyDirection
25
+ from optuna.study._frozen import FrozenStudy
26
+ from optuna.trial import FrozenTrial, TrialState
27
+ except ImportError as e:
28
+ raise ImportError("Install optuna to use `mlflow.optuna` module") from e
29
+
30
# Translation table from Optuna trial states to the MLflow run status used to store them.
optuna_mlflow_status_map = {
    TrialState.RUNNING: "RUNNING",
    TrialState.COMPLETE: "FINISHED",
    TrialState.PRUNED: "KILLED",
    TrialState.FAIL: "FAILED",
    TrialState.WAITING: "SCHEDULED",
}

# Reverse lookup (MLflow run status -> Optuna trial state). The mapping above is one-to-one,
# so inverting it loses nothing and keeps the same key order.
mlflow_optuna_status_map = {
    mlflow_status: optuna_state for optuna_state, mlflow_status in optuna_mlflow_status_map.items()
}
45
+
46
+
47
class MlflowStorage(BaseStorage):
    """
    MLflow based storage class with batch processing to avoid REST API throttling.

    Each Optuna study is an MLflow run tagged with ``optuna.study_direction``; each trial
    is an MLflow run whose ``mlflow.parentRunId`` tag holds the study's run ID. Study and
    trial IDs handed back to Optuna are therefore MLflow run IDs (strings). User and
    system attributes are stored as JSON in ``user_*`` / ``sys_*`` tags.

    Metric, param and tag writes are queued per run and sent with
    ``MlflowClient.log_batch`` when a run's queue reaches ``batch_size_threshold`` items,
    when the background worker sees that ``batch_flush_interval`` seconds have elapsed,
    or right before a read that needs up-to-date data.
    """

    def __init__(
        self,
        experiment_id: str,
        name: Optional[str] = None,
        batch_flush_interval: float = 1.0,
        batch_size_threshold: int = 100,
    ):
        """
        Initialize MLFlowStorage with batching capabilities.

        Parameters
        ----------
        experiment_id : str
            MLflow experiment ID under which study and trial runs are created
        name : Optional[str]
            Optional name for the storage
        batch_flush_interval : float
            Time in seconds between automatic batch flushes (default: 1.0)
        batch_size_threshold : int
            Number of queued items for one run that triggers an immediate flush (default: 100)

        Raises
        ------
        ValueError
            If ``experiment_id`` is empty.
        """
        if not experiment_id:
            raise ValueError("No experiment_id provided. MLFlowStorage cannot create experiments.")

        self._experiment_id = experiment_id
        self._mlflow_client = MlflowClient()
        self._name = name

        # Batching configuration
        self._batch_flush_interval = batch_flush_interval
        self._batch_size_threshold = batch_size_threshold

        # Per-run queues: run_id -> {"metrics": [...], "params": [...], "tags": [...]}
        self._batch_queue = {}
        self._batch_lock = threading.RLock()
        self._last_flush_time = time.time()

        # Must exist BEFORE the worker starts, because the worker's loop reads it.
        self._stop_worker = False

        # Background thread for periodic flushing. Note that the bound-method target keeps a
        # reference to ``self``, so the instance is not garbage collected while it runs.
        self._flush_thread = threading.Thread(
            target=self._periodic_flush_worker,
            daemon=True,
            name=f"mlflow_optuna_batch_flush_worker_{uuid.uuid4().hex[:8]}",
        )
        self._flush_thread.start()

    def __getstate__(self):
        """
        Prepare the object for serialization by removing non-picklable components.

        The lock and the flush thread cannot be pickled and are dropped; only a
        ``_thread_running`` flag is recorded. Pending batches are NOT copied into the
        pickle: the instance being pickled still owns and flushes them, so copying them
        would deliver the same data twice.
        """
        state = self.__dict__.copy()
        state.pop("_batch_lock", None)
        flush_thread = state.pop("_flush_thread", None)

        # An unpickled instance carries ``_flush_thread = None``; guard so it can be
        # pickled again (e.g. forwarded to another worker).
        state["_thread_running"] = flush_thread is not None and flush_thread.is_alive()
        state["_batch_queue"] = {}
        return state

    def __setstate__(self, state):
        """
        Restore the object after deserialization by recreating non-picklable components.

        No flush thread is started on the unpickled copy. Data it queues is flushed by
        reads, by reaching the batch size threshold, by ``flush_all_batches`` or on
        destruction.
        """
        self.__dict__.update(state)

        # Recreate the lock
        self._batch_lock = threading.RLock()

        # Don't restart the background thread on workers (would create too many threads).
        self._flush_thread = None
        self._stop_worker = True

    def __del__(self):
        """Stop the flush worker (if any) and flush all queued data, ignoring errors."""
        if hasattr(self, "_stop_worker"):
            self._stop_worker = True

        # ``_flush_thread`` is None on unpickled copies, so check before calling is_alive().
        flush_thread = getattr(self, "_flush_thread", None)
        if flush_thread is not None and flush_thread.is_alive():
            try:
                flush_thread.join(timeout=5.0)
            except Exception:
                pass  # Best-effort cleanup (e.g. joining from the worker thread itself)

        if hasattr(self, "_batch_queue"):
            try:
                self.flush_all_batches()
            except Exception:
                pass  # Best-effort cleanup

    def _periodic_flush_worker(self):
        """Background loop: flush all batches every ``batch_flush_interval`` seconds."""
        while not self._stop_worker:
            try:
                # Sleep in small increments so a stop request is noticed quickly.
                time.sleep(min(0.1, self._batch_flush_interval / 10))

                current_time = time.time()
                if current_time - self._last_flush_time >= self._batch_flush_interval:
                    self.flush_all_batches()
                    self._last_flush_time = current_time
            except Exception:
                # Keep the worker alive; back off a little after an error.
                time.sleep(1.0)

    def _queue_batch_operation(
        self,
        run_id: str,
        metrics: Optional[list[Metric]] = None,
        params: Optional[list[Param]] = None,
        tags: Optional[list[RunTag]] = None,
    ):
        """Queue metrics, params and/or tags for ``run_id``; flush if the threshold is hit."""
        with self._batch_lock:
            batch = self._batch_queue.setdefault(run_id, {"metrics": [], "params": [], "tags": []})

            if metrics:
                batch["metrics"].extend(metrics)
            if params:
                batch["params"].extend(params)
            if tags:
                batch["tags"].extend(tags)

            batch_size = len(batch["metrics"]) + len(batch["params"]) + len(batch["tags"])
            if batch_size >= self._batch_size_threshold:
                self._flush_batch(run_id)

    def _flush_batch(self, run_id: str):
        """
        Send the queued data for ``run_id`` with one ``log_batch`` call.

        The queue entry is removed before sending, so a failing batch is never retried.
        This keeps one bad run from making every later flush fail.

        Raises
        ------
        Exception
            Re-raises the ``log_batch`` error when its message says the run was not found.
            Other errors are logged as warnings and the batch is dropped.
        """
        with self._batch_lock:
            batch = self._batch_queue.pop(run_id, None)
            if batch is None or not (batch["metrics"] or batch["params"] or batch["tags"]):
                return

            try:
                self._mlflow_client.log_batch(
                    run_id, metrics=batch["metrics"], params=batch["params"], tags=batch["tags"]
                )
            except Exception as e:
                message = str(e)
                if "Run with id=" in message and "not found" in message:
                    raise
                import logging

                logging.getLogger(__name__).warning(
                    "Failed to log batched data for MLflow run %s: %s", run_id, e
                )

    def flush_all_batches(self):
        """Flush all pending batches to MLflow."""
        with self._batch_lock:
            for run_id in list(self._batch_queue):
                self._flush_batch(run_id)

    def _search_all_runs(self, filter_string: str = "") -> list:
        """Return every active run in the experiment matching ``filter_string``.

        ``MlflowClient.search_runs`` returns one page (at most 1000 runs by default), so
        keep following the page token until it is exhausted.
        """
        runs = []
        page_token = None
        while True:
            page = self._mlflow_client.search_runs(
                experiment_ids=[self._experiment_id],
                filter_string=filter_string,
                page_token=page_token,
            )
            runs.extend(page)
            page_token = getattr(page, "token", None)
            if not page_token:
                return runs

    @staticmethod
    def _attrs_from_tags(tags: dict[str, str], prefix: str) -> dict[str, Any]:
        """Decode the JSON values of all tags starting with ``prefix`` (prefix stripped)."""
        return {
            key[len(prefix) :]: json.loads(value)
            for key, value in tags.items()
            if key.startswith(prefix)
        }

    def _search_run_by_name(self, run_name: str):
        filter_string = f"tags.mlflow.runName = '{run_name}'"
        return self._mlflow_client.search_runs(
            experiment_ids=[self._experiment_id], filter_string=filter_string
        )

    def create_new_study(
        self, directions: Sequence[StudyDirection], study_name: Optional[str] = None
    ) -> int:
        """Create a new study as an MLflow run and return its run ID."""
        study_name = study_name or DEFAULT_STUDY_NAME_PREFIX + str(uuid.uuid4())
        tags = {
            "mlflow.runName": study_name,
            "optuna.study_direction": ",".join(direction.name for direction in directions),
        }
        study_run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=tags)
        return study_run.info.run_id

    def delete_study(self, study_id) -> None:
        """Delete a study run (pending data for it is flushed first)."""
        self._flush_batch(study_id)
        self._mlflow_client.delete_run(study_id)

    def set_study_user_attr(self, study_id, key: str, value: JSONSerializable) -> None:
        """Register a user-defined attribute as mlflow run tags to a study run."""
        # Fail fast if the run does not exist
        self._mlflow_client.get_run(study_id)
        self._queue_batch_operation(study_id, tags=[RunTag(f"user_{key}", json.dumps(value))])

    def set_study_system_attr(self, study_id, key: str, value: JSONSerializable) -> None:
        """Register a optuna-internal attribute as mlflow run tags to a study run."""
        # Fail fast if the run does not exist
        self._mlflow_client.get_run(study_id)
        self._queue_batch_operation(study_id, tags=[RunTag(f"sys_{key}", json.dumps(value))])

    def get_study_id_from_name(self, study_name: str) -> int:
        """Return the run ID of the study named ``study_name``.

        Raises
        ------
        KeyError
            If no run with that name exists (per Optuna's ``BaseStorage`` contract).
        """
        self.flush_all_batches()

        runs = self._search_run_by_name(study_name)
        if len(runs):
            return runs[0].info.run_id
        raise KeyError(f"Study {study_name} not found")

    def get_study_name_from_id(self, study_id) -> str:
        self._flush_batch(study_id)
        run = self._mlflow_client.get_run(study_id)
        return run.data.tags["mlflow.runName"]

    def get_study_directions(self, study_id) -> list[StudyDirection]:
        self._flush_batch(study_id)
        run = self._mlflow_client.get_run(study_id)
        directions_str = run.data.tags["optuna.study_direction"]
        return [StudyDirection[name] for name in directions_str.split(",")]

    def get_study_user_attrs(self, study_id) -> dict[str, Any]:
        self._flush_batch(study_id)
        run = self._mlflow_client.get_run(study_id)
        return self._attrs_from_tags(run.data.tags, "user_")

    def get_study_system_attrs(self, study_id) -> dict[str, Any]:
        self._flush_batch(study_id)
        run = self._mlflow_client.get_run(study_id)
        return self._attrs_from_tags(run.data.tags, "sys_")

    def get_all_studies(self) -> list[FrozenStudy]:
        """Return all study runs in the experiment (trial runs are skipped)."""
        self.flush_all_batches()

        studies = []
        for run in self._search_all_runs():
            tags = run.data.tags
            directions_str = tags.get("optuna.study_direction")
            if directions_str is None:
                # Trial runs share the experiment but carry no study direction tag.
                continue
            studies.append(
                FrozenStudy(
                    study_name=tags["mlflow.runName"],
                    direction=None,
                    directions=[StudyDirection[name] for name in directions_str.split(",")],
                    user_attrs=self._attrs_from_tags(tags, "user_"),
                    system_attrs=self._attrs_from_tags(tags, "sys_"),
                    study_id=run.info.run_id,
                )
            )
        return studies

    def create_new_trial(self, study_id, template_trial: Optional[FrozenTrial] = None) -> int:
        """Create a trial run under ``study_id`` and return its run ID.

        The trial number is the position of this trial's marker in the study run's
        ``trial_id`` metric history. When ``template_trial`` is given, its state, values,
        intermediate values, params and attributes are copied onto the new run.
        """
        self._flush_batch(study_id)

        if template_trial:
            frozen = copy.deepcopy(template_trial)
        else:
            frozen = FrozenTrial(
                trial_id=-1,  # dummy value.
                number=-1,  # dummy value.
                state=TrialState.RUNNING,
                params={},
                distributions={},
                user_attrs={},
                system_attrs={},
                value=None,
                intermediate_values={},
                datetime_start=datetime.datetime.now(),
                datetime_complete=None,
            )

        distribution_str = json.dumps(
            {k: distribution_to_json(dist) for k, dist in frozen.distributions.items()}
        )
        # Set the parent tag at creation so the trial is immediately visible to
        # get_all_trials/get_n_trials searches, including ones from other processes.
        create_tags = {"param_directions": distribution_str, MLFLOW_PARENT_RUN_ID: study_id}
        trial_run = self._mlflow_client.create_run(
            experiment_id=self._experiment_id, tags=create_tags
        )
        trial_id = trial_run.info.run_id

        # Append a marker for this trial to the study's "trial_id" metric history and
        # flush right away so the history read below already contains it.
        hash_id = float(hash(trial_id))
        self._queue_batch_operation(
            study_id, metrics=[Metric("trial_id", hash_id, int(time.time() * 1000), 1)]
        )
        self._flush_batch(study_id)

        trial_ids = self._mlflow_client.get_metric_history(study_id, "trial_id")
        index = next((i for i, obj in enumerate(trial_ids) if obj.value == hash_id), -1)

        state = frozen.state
        if state.is_finished():
            self._mlflow_client.set_terminated(trial_id, status=optuna_mlflow_status_map[state])
        else:
            self._mlflow_client.update_run(trial_id, status=optuna_mlflow_status_map[state])

        timestamp = int(time.time() * 1000)
        metrics = []
        tags = [RunTag("numbers", str(index))]

        if frozen.values is not None:
            if len(frozen.values) > 1:
                metrics.extend(
                    Metric(f"value_{idx}", val, timestamp, 1)
                    for idx, val in enumerate(frozen.values)
                )
            else:
                metrics.append(Metric("value", frozen.values[0], timestamp, 1))
        elif frozen.value is not None:
            metrics.append(Metric("value", frozen.value, timestamp, 1))

        # Intermediate values use the Optuna step as the MLflow metric step.
        metrics.extend(
            Metric("intermediate_value", val, timestamp, int(step))
            for step, val in frozen.intermediate_values.items()
        )

        params = [Param(k, param) for k, param in frozen.params.items()]

        tags.extend(
            RunTag(f"user_{key}", json.dumps(value)) for key, value in frozen.user_attrs.items()
        )
        tags.extend(
            RunTag(f"sys_{key}", json.dumps(value)) for key, value in frozen.system_attrs.items()
        )
        tags.extend(
            RunTag(
                f"param_internal_val_{k}",
                json.dumps(frozen.distributions[k].to_internal_repr(param)),
            )
            for k, param in frozen.params.items()
        )

        self._queue_batch_operation(trial_id, metrics=metrics, params=params, tags=tags)
        return trial_id

    def set_trial_param(
        self,
        trial_id,
        param_name: str,
        param_value_internal: float,
        distribution: BaseDistribution,
    ) -> None:
        """Record a sampled parameter (external value as param, internal value as tag).

        Raises if the trial is finished or if ``distribution`` is incompatible with the
        one already recorded for ``param_name``.
        """
        self._flush_batch(trial_id)

        trial_run = self._mlflow_client.get_run(trial_id)
        distributions_dict = json.loads(trial_run.data.tags["param_directions"])
        self.check_trial_is_updatable(trial_id, mlflow_optuna_status_map[trial_run.info.status])

        if param_name in trial_run.data.params:
            param_distribution = json_to_distribution(distributions_dict[param_name])
            check_distribution_compatibility(param_distribution, distribution)

        distributions_dict[param_name] = distribution_to_json(distribution)
        self._queue_batch_operation(
            trial_id,
            params=[Param(param_name, distribution.to_external_repr(param_value_internal))],
            tags=[
                RunTag(f"param_internal_val_{param_name}", json.dumps(param_value_internal)),
                RunTag("param_directions", json.dumps(distributions_dict)),
            ],
        )

    def get_trial_id_from_study_id_trial_number(self, study_id, trial_number: int) -> int:
        raise NotImplementedError("This method is not supported in MLflow backend.")

    def get_trial_number_from_id(self, trial_id) -> int:
        self._flush_batch(trial_id)
        trial_run = self._mlflow_client.get_run(trial_id)
        return int(trial_run.data.tags.get("numbers", 0))

    def get_trial_param(self, trial_id, param_name: str) -> float:
        """Return the internal representation of ``param_name`` (KeyError if unset)."""
        self._flush_batch(trial_id)
        trial_run = self._mlflow_client.get_run(trial_id)
        param_value = trial_run.data.tags[f"param_internal_val_{param_name}"]
        return float(json.loads(param_value))

    def set_trial_state_values(
        self, trial_id, state: TrialState, values: Optional[Sequence[float]] = None
    ) -> bool:
        """Update the trial's state (and objective values, if given).

        Returns
        -------
        bool
            False, without changing anything, when asked to set RUNNING on a trial that is
            not currently WAITING (e.g. another worker already picked it up). True otherwise.

        Raises
        ------
        RuntimeError
            Via ``check_trial_is_updatable`` if the trial is already finished.
        """
        current_state = mlflow_optuna_status_map[self._mlflow_client.get_run(trial_id).info.status]
        self.check_trial_is_updatable(trial_id, current_state)
        if state == TrialState.RUNNING and current_state != TrialState.WAITING:
            return False

        if state.is_finished():
            self._mlflow_client.set_terminated(trial_id, status=optuna_mlflow_status_map[state])
        else:
            self._mlflow_client.update_run(trial_id, status=optuna_mlflow_status_map[state])

        if values is not None:
            timestamp = int(time.time() * 1000)
            if len(values) > 1:
                metrics = [
                    Metric(f"value_{idx}", val, timestamp, 1) for idx, val in enumerate(values)
                ]
            else:
                metrics = [Metric("value", values[0], timestamp, 1)]
            self._queue_batch_operation(trial_id, metrics=metrics)

        return True

    def set_trial_intermediate_value(self, trial_id, step: int, intermediate_value: float) -> None:
        self._queue_batch_operation(
            trial_id,
            metrics=[
                Metric("intermediate_value", intermediate_value, int(time.time() * 1000), step)
            ],
        )

    def set_trial_user_attr(self, trial_id, key: str, value: Any) -> None:
        self._queue_batch_operation(trial_id, tags=[RunTag(f"user_{key}", json.dumps(value))])

    def set_trial_system_attr(self, trial_id, key: str, value: Any) -> None:
        self._queue_batch_operation(trial_id, tags=[RunTag(f"sys_{key}", json.dumps(value))])

    def _frozen_trial_from_run(self, trial_run) -> FrozenTrial:
        """Build a FrozenTrial from an already-fetched trial run (one metric-history call)."""
        trial_id = trial_run.info.run_id
        tags = trial_run.data.tags

        param_directions = tags["param_directions"]
        try:
            distributions_dict = json.loads(param_directions)
        except json.decoder.JSONDecodeError as e:
            raise ValueError(f"error with param_directions = {param_directions!r}") from e

        distributions = {
            k: json_to_distribution(distribution) for k, distribution in distributions_dict.items()
        }

        prefix = "param_internal_val_"
        params = {}
        for key, value in tags.items():
            if key.startswith(prefix):
                param_name = key[len(prefix) :]
                params[param_name] = distributions[param_name].to_external_repr(
                    float(json.loads(value))
                )

        # Objective values are stored as "value" (single objective) or "value_0",
        # "value_1", ... (multi-objective). Count only the contiguous value_i keys; the run
        # also has other metrics such as "intermediate_value".
        metrics = trial_run.data.metrics
        values = None
        if "value" in metrics:
            values = [metrics["value"]]
        if "value_0" in metrics:
            values = []
            while f"value_{len(values)}" in metrics:
                values.append(metrics[f"value_{len(values)}"])

        start_time = datetime.datetime.fromtimestamp(trial_run.info.start_time / 1000)
        end_time = (
            datetime.datetime.fromtimestamp(trial_run.info.end_time / 1000)
            if trial_run.info.end_time
            else None
        )

        intermediate_history = self._mlflow_client.get_metric_history(
            trial_id, "intermediate_value"
        )
        return FrozenTrial(
            trial_id=trial_id,
            number=int(tags.get("numbers", 0)),
            state=mlflow_optuna_status_map[trial_run.info.status],
            value=None,
            values=values,
            datetime_start=start_time,
            datetime_complete=end_time,
            params=params,
            distributions=distributions,
            user_attrs=self._attrs_from_tags(tags, "user_"),
            system_attrs=self._attrs_from_tags(tags, "sys_"),
            intermediate_values={m.step: m.value for m in intermediate_history},
        )

    def get_trial(self, trial_id) -> FrozenTrial:
        self._flush_batch(trial_id)
        return self._frozen_trial_from_run(self._mlflow_client.get_run(trial_id))

    def get_trial_user_attrs(self, trial_id) -> dict[str, Any]:
        self._flush_batch(trial_id)
        run = self._mlflow_client.get_run(trial_id)
        return self._attrs_from_tags(run.data.tags, "user_")

    def get_trial_system_attrs(self, trial_id) -> dict[str, Any]:
        self._flush_batch(trial_id)
        run = self._mlflow_client.get_run(trial_id)
        return self._attrs_from_tags(run.data.tags, "sys_")

    def get_all_trials(
        self,
        study_id,
        deepcopy: bool = True,
        states: Optional[Container[TrialState]] = None,
    ) -> list[FrozenTrial]:
        """Return the study's trials, optionally filtered by state, ordered by trial number.

        ``deepcopy`` is accepted for interface compatibility; every call builds fresh
        FrozenTrial objects anyway.
        """
        self.flush_all_batches()

        runs = self._search_all_runs(f"tags.mlflow.parentRunId='{study_id}'")
        trials = [
            self._frozen_trial_from_run(run)
            for run in runs
            # Filter on the run status first to avoid fetching history for skipped trials.
            if states is None or mlflow_optuna_status_map[run.info.status] in states
        ]
        trials.sort(key=lambda trial: trial.number)
        return trials

    def get_n_trials(self, study_id, states=None, *, state=None) -> int:
        """Count the study's trials, optionally only those in ``states``.

        ``states`` (or Optuna's ``state`` keyword) may be a single TrialState or a
        container of them; None counts all trials.
        """
        if states is None:
            states = state
        if isinstance(states, TrialState):
            states = (states,)

        self.flush_all_batches()

        runs = self._search_all_runs(f"tags.mlflow.parentRunId='{study_id}'")
        if states is None:
            return len(runs)
        return sum(1 for run in runs if mlflow_optuna_status_map[run.info.status] in states)