genesis-flow 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (645)
  1. genesis_flow-1.0.0.dist-info/METADATA +822 -0
  2. genesis_flow-1.0.0.dist-info/RECORD +645 -0
  3. genesis_flow-1.0.0.dist-info/WHEEL +5 -0
  4. genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
  5. genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
  6. genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
  7. mlflow/__init__.py +367 -0
  8. mlflow/__main__.py +3 -0
  9. mlflow/ag2/__init__.py +56 -0
  10. mlflow/ag2/ag2_logger.py +294 -0
  11. mlflow/anthropic/__init__.py +40 -0
  12. mlflow/anthropic/autolog.py +129 -0
  13. mlflow/anthropic/chat.py +144 -0
  14. mlflow/artifacts/__init__.py +268 -0
  15. mlflow/autogen/__init__.py +144 -0
  16. mlflow/autogen/chat.py +142 -0
  17. mlflow/azure/__init__.py +26 -0
  18. mlflow/azure/auth_handler.py +257 -0
  19. mlflow/azure/client.py +319 -0
  20. mlflow/azure/config.py +120 -0
  21. mlflow/azure/connection_factory.py +340 -0
  22. mlflow/azure/exceptions.py +27 -0
  23. mlflow/azure/stores.py +327 -0
  24. mlflow/azure/utils.py +183 -0
  25. mlflow/bedrock/__init__.py +45 -0
  26. mlflow/bedrock/_autolog.py +202 -0
  27. mlflow/bedrock/chat.py +122 -0
  28. mlflow/bedrock/stream.py +160 -0
  29. mlflow/bedrock/utils.py +43 -0
  30. mlflow/cli.py +707 -0
  31. mlflow/client.py +12 -0
  32. mlflow/config/__init__.py +56 -0
  33. mlflow/crewai/__init__.py +79 -0
  34. mlflow/crewai/autolog.py +253 -0
  35. mlflow/crewai/chat.py +29 -0
  36. mlflow/data/__init__.py +75 -0
  37. mlflow/data/artifact_dataset_sources.py +170 -0
  38. mlflow/data/code_dataset_source.py +40 -0
  39. mlflow/data/dataset.py +123 -0
  40. mlflow/data/dataset_registry.py +168 -0
  41. mlflow/data/dataset_source.py +110 -0
  42. mlflow/data/dataset_source_registry.py +219 -0
  43. mlflow/data/delta_dataset_source.py +167 -0
  44. mlflow/data/digest_utils.py +108 -0
  45. mlflow/data/evaluation_dataset.py +562 -0
  46. mlflow/data/filesystem_dataset_source.py +81 -0
  47. mlflow/data/http_dataset_source.py +145 -0
  48. mlflow/data/huggingface_dataset.py +258 -0
  49. mlflow/data/huggingface_dataset_source.py +118 -0
  50. mlflow/data/meta_dataset.py +104 -0
  51. mlflow/data/numpy_dataset.py +223 -0
  52. mlflow/data/pandas_dataset.py +231 -0
  53. mlflow/data/polars_dataset.py +352 -0
  54. mlflow/data/pyfunc_dataset_mixin.py +31 -0
  55. mlflow/data/schema.py +76 -0
  56. mlflow/data/sources.py +1 -0
  57. mlflow/data/spark_dataset.py +406 -0
  58. mlflow/data/spark_dataset_source.py +74 -0
  59. mlflow/data/spark_delta_utils.py +118 -0
  60. mlflow/data/tensorflow_dataset.py +350 -0
  61. mlflow/data/uc_volume_dataset_source.py +81 -0
  62. mlflow/db.py +27 -0
  63. mlflow/dspy/__init__.py +17 -0
  64. mlflow/dspy/autolog.py +197 -0
  65. mlflow/dspy/callback.py +398 -0
  66. mlflow/dspy/constant.py +1 -0
  67. mlflow/dspy/load.py +93 -0
  68. mlflow/dspy/save.py +393 -0
  69. mlflow/dspy/util.py +109 -0
  70. mlflow/dspy/wrapper.py +226 -0
  71. mlflow/entities/__init__.py +104 -0
  72. mlflow/entities/_mlflow_object.py +52 -0
  73. mlflow/entities/assessment.py +545 -0
  74. mlflow/entities/assessment_error.py +80 -0
  75. mlflow/entities/assessment_source.py +141 -0
  76. mlflow/entities/dataset.py +92 -0
  77. mlflow/entities/dataset_input.py +51 -0
  78. mlflow/entities/dataset_summary.py +62 -0
  79. mlflow/entities/document.py +48 -0
  80. mlflow/entities/experiment.py +109 -0
  81. mlflow/entities/experiment_tag.py +35 -0
  82. mlflow/entities/file_info.py +45 -0
  83. mlflow/entities/input_tag.py +35 -0
  84. mlflow/entities/lifecycle_stage.py +35 -0
  85. mlflow/entities/logged_model.py +228 -0
  86. mlflow/entities/logged_model_input.py +26 -0
  87. mlflow/entities/logged_model_output.py +32 -0
  88. mlflow/entities/logged_model_parameter.py +46 -0
  89. mlflow/entities/logged_model_status.py +74 -0
  90. mlflow/entities/logged_model_tag.py +33 -0
  91. mlflow/entities/metric.py +200 -0
  92. mlflow/entities/model_registry/__init__.py +29 -0
  93. mlflow/entities/model_registry/_model_registry_entity.py +13 -0
  94. mlflow/entities/model_registry/model_version.py +243 -0
  95. mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
  96. mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
  97. mlflow/entities/model_registry/model_version_search.py +25 -0
  98. mlflow/entities/model_registry/model_version_stages.py +25 -0
  99. mlflow/entities/model_registry/model_version_status.py +35 -0
  100. mlflow/entities/model_registry/model_version_tag.py +35 -0
  101. mlflow/entities/model_registry/prompt.py +73 -0
  102. mlflow/entities/model_registry/prompt_version.py +244 -0
  103. mlflow/entities/model_registry/registered_model.py +175 -0
  104. mlflow/entities/model_registry/registered_model_alias.py +35 -0
  105. mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
  106. mlflow/entities/model_registry/registered_model_search.py +25 -0
  107. mlflow/entities/model_registry/registered_model_tag.py +35 -0
  108. mlflow/entities/multipart_upload.py +74 -0
  109. mlflow/entities/param.py +49 -0
  110. mlflow/entities/run.py +97 -0
  111. mlflow/entities/run_data.py +84 -0
  112. mlflow/entities/run_info.py +188 -0
  113. mlflow/entities/run_inputs.py +59 -0
  114. mlflow/entities/run_outputs.py +43 -0
  115. mlflow/entities/run_status.py +41 -0
  116. mlflow/entities/run_tag.py +36 -0
  117. mlflow/entities/source_type.py +31 -0
  118. mlflow/entities/span.py +774 -0
  119. mlflow/entities/span_event.py +96 -0
  120. mlflow/entities/span_status.py +102 -0
  121. mlflow/entities/trace.py +317 -0
  122. mlflow/entities/trace_data.py +71 -0
  123. mlflow/entities/trace_info.py +220 -0
  124. mlflow/entities/trace_info_v2.py +162 -0
  125. mlflow/entities/trace_location.py +173 -0
  126. mlflow/entities/trace_state.py +39 -0
  127. mlflow/entities/trace_status.py +68 -0
  128. mlflow/entities/view_type.py +51 -0
  129. mlflow/environment_variables.py +866 -0
  130. mlflow/evaluation/__init__.py +16 -0
  131. mlflow/evaluation/assessment.py +369 -0
  132. mlflow/evaluation/evaluation.py +411 -0
  133. mlflow/evaluation/evaluation_tag.py +61 -0
  134. mlflow/evaluation/fluent.py +48 -0
  135. mlflow/evaluation/utils.py +201 -0
  136. mlflow/exceptions.py +213 -0
  137. mlflow/experiments.py +140 -0
  138. mlflow/gemini/__init__.py +81 -0
  139. mlflow/gemini/autolog.py +186 -0
  140. mlflow/gemini/chat.py +261 -0
  141. mlflow/genai/__init__.py +71 -0
  142. mlflow/genai/datasets/__init__.py +67 -0
  143. mlflow/genai/datasets/evaluation_dataset.py +131 -0
  144. mlflow/genai/evaluation/__init__.py +3 -0
  145. mlflow/genai/evaluation/base.py +411 -0
  146. mlflow/genai/evaluation/constant.py +23 -0
  147. mlflow/genai/evaluation/utils.py +244 -0
  148. mlflow/genai/judges/__init__.py +21 -0
  149. mlflow/genai/judges/databricks.py +404 -0
  150. mlflow/genai/label_schemas/__init__.py +153 -0
  151. mlflow/genai/label_schemas/label_schemas.py +209 -0
  152. mlflow/genai/labeling/__init__.py +159 -0
  153. mlflow/genai/labeling/labeling.py +250 -0
  154. mlflow/genai/optimize/__init__.py +13 -0
  155. mlflow/genai/optimize/base.py +198 -0
  156. mlflow/genai/optimize/optimizers/__init__.py +4 -0
  157. mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
  158. mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
  159. mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
  160. mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
  161. mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
  162. mlflow/genai/optimize/types.py +75 -0
  163. mlflow/genai/optimize/util.py +30 -0
  164. mlflow/genai/prompts/__init__.py +206 -0
  165. mlflow/genai/scheduled_scorers.py +431 -0
  166. mlflow/genai/scorers/__init__.py +26 -0
  167. mlflow/genai/scorers/base.py +492 -0
  168. mlflow/genai/scorers/builtin_scorers.py +765 -0
  169. mlflow/genai/scorers/scorer_utils.py +138 -0
  170. mlflow/genai/scorers/validation.py +165 -0
  171. mlflow/genai/utils/data_validation.py +146 -0
  172. mlflow/genai/utils/enum_utils.py +23 -0
  173. mlflow/genai/utils/trace_utils.py +211 -0
  174. mlflow/groq/__init__.py +42 -0
  175. mlflow/groq/_groq_autolog.py +74 -0
  176. mlflow/johnsnowlabs/__init__.py +888 -0
  177. mlflow/langchain/__init__.py +24 -0
  178. mlflow/langchain/api_request_parallel_processor.py +330 -0
  179. mlflow/langchain/autolog.py +147 -0
  180. mlflow/langchain/chat_agent_langgraph.py +340 -0
  181. mlflow/langchain/constant.py +1 -0
  182. mlflow/langchain/constants.py +1 -0
  183. mlflow/langchain/databricks_dependencies.py +444 -0
  184. mlflow/langchain/langchain_tracer.py +597 -0
  185. mlflow/langchain/model.py +919 -0
  186. mlflow/langchain/output_parsers.py +142 -0
  187. mlflow/langchain/retriever_chain.py +153 -0
  188. mlflow/langchain/runnables.py +527 -0
  189. mlflow/langchain/utils/chat.py +402 -0
  190. mlflow/langchain/utils/logging.py +671 -0
  191. mlflow/langchain/utils/serialization.py +36 -0
  192. mlflow/legacy_databricks_cli/__init__.py +0 -0
  193. mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
  194. mlflow/legacy_databricks_cli/configure/provider.py +482 -0
  195. mlflow/litellm/__init__.py +175 -0
  196. mlflow/llama_index/__init__.py +22 -0
  197. mlflow/llama_index/autolog.py +55 -0
  198. mlflow/llama_index/chat.py +43 -0
  199. mlflow/llama_index/constant.py +1 -0
  200. mlflow/llama_index/model.py +577 -0
  201. mlflow/llama_index/pyfunc_wrapper.py +332 -0
  202. mlflow/llama_index/serialize_objects.py +188 -0
  203. mlflow/llama_index/tracer.py +561 -0
  204. mlflow/metrics/__init__.py +479 -0
  205. mlflow/metrics/base.py +39 -0
  206. mlflow/metrics/genai/__init__.py +25 -0
  207. mlflow/metrics/genai/base.py +101 -0
  208. mlflow/metrics/genai/genai_metric.py +771 -0
  209. mlflow/metrics/genai/metric_definitions.py +450 -0
  210. mlflow/metrics/genai/model_utils.py +371 -0
  211. mlflow/metrics/genai/prompt_template.py +68 -0
  212. mlflow/metrics/genai/prompts/__init__.py +0 -0
  213. mlflow/metrics/genai/prompts/v1.py +422 -0
  214. mlflow/metrics/genai/utils.py +6 -0
  215. mlflow/metrics/metric_definitions.py +619 -0
  216. mlflow/mismatch.py +34 -0
  217. mlflow/mistral/__init__.py +34 -0
  218. mlflow/mistral/autolog.py +71 -0
  219. mlflow/mistral/chat.py +135 -0
  220. mlflow/ml_package_versions.py +452 -0
  221. mlflow/models/__init__.py +97 -0
  222. mlflow/models/auth_policy.py +83 -0
  223. mlflow/models/cli.py +354 -0
  224. mlflow/models/container/__init__.py +294 -0
  225. mlflow/models/container/scoring_server/__init__.py +0 -0
  226. mlflow/models/container/scoring_server/nginx.conf +39 -0
  227. mlflow/models/dependencies_schemas.py +287 -0
  228. mlflow/models/display_utils.py +158 -0
  229. mlflow/models/docker_utils.py +211 -0
  230. mlflow/models/evaluation/__init__.py +23 -0
  231. mlflow/models/evaluation/_shap_patch.py +64 -0
  232. mlflow/models/evaluation/artifacts.py +194 -0
  233. mlflow/models/evaluation/base.py +1811 -0
  234. mlflow/models/evaluation/calibration_curve.py +109 -0
  235. mlflow/models/evaluation/default_evaluator.py +996 -0
  236. mlflow/models/evaluation/deprecated.py +23 -0
  237. mlflow/models/evaluation/evaluator_registry.py +80 -0
  238. mlflow/models/evaluation/evaluators/classifier.py +704 -0
  239. mlflow/models/evaluation/evaluators/default.py +233 -0
  240. mlflow/models/evaluation/evaluators/regressor.py +96 -0
  241. mlflow/models/evaluation/evaluators/shap.py +296 -0
  242. mlflow/models/evaluation/lift_curve.py +178 -0
  243. mlflow/models/evaluation/utils/metric.py +123 -0
  244. mlflow/models/evaluation/utils/trace.py +179 -0
  245. mlflow/models/evaluation/validation.py +434 -0
  246. mlflow/models/flavor_backend.py +93 -0
  247. mlflow/models/flavor_backend_registry.py +53 -0
  248. mlflow/models/model.py +1639 -0
  249. mlflow/models/model_config.py +150 -0
  250. mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
  251. mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
  252. mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
  253. mlflow/models/python_api.py +369 -0
  254. mlflow/models/rag_signatures.py +128 -0
  255. mlflow/models/resources.py +321 -0
  256. mlflow/models/signature.py +662 -0
  257. mlflow/models/utils.py +2054 -0
  258. mlflow/models/wheeled_model.py +280 -0
  259. mlflow/openai/__init__.py +57 -0
  260. mlflow/openai/_agent_tracer.py +364 -0
  261. mlflow/openai/api_request_parallel_processor.py +131 -0
  262. mlflow/openai/autolog.py +509 -0
  263. mlflow/openai/constant.py +1 -0
  264. mlflow/openai/model.py +824 -0
  265. mlflow/openai/utils/chat_schema.py +367 -0
  266. mlflow/optuna/__init__.py +3 -0
  267. mlflow/optuna/storage.py +646 -0
  268. mlflow/plugins/__init__.py +72 -0
  269. mlflow/plugins/base.py +358 -0
  270. mlflow/plugins/builtin/__init__.py +24 -0
  271. mlflow/plugins/builtin/pytorch_plugin.py +150 -0
  272. mlflow/plugins/builtin/sklearn_plugin.py +158 -0
  273. mlflow/plugins/builtin/transformers_plugin.py +187 -0
  274. mlflow/plugins/cli.py +321 -0
  275. mlflow/plugins/discovery.py +340 -0
  276. mlflow/plugins/manager.py +465 -0
  277. mlflow/plugins/registry.py +316 -0
  278. mlflow/plugins/templates/framework_plugin_template.py +329 -0
  279. mlflow/prompt/constants.py +20 -0
  280. mlflow/prompt/promptlab_model.py +197 -0
  281. mlflow/prompt/registry_utils.py +248 -0
  282. mlflow/promptflow/__init__.py +495 -0
  283. mlflow/protos/__init__.py +0 -0
  284. mlflow/protos/assessments_pb2.py +174 -0
  285. mlflow/protos/databricks_artifacts_pb2.py +489 -0
  286. mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
  287. mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
  288. mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
  289. mlflow/protos/databricks_pb2.py +267 -0
  290. mlflow/protos/databricks_trace_server_pb2.py +374 -0
  291. mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
  292. mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
  293. mlflow/protos/facet_feature_statistics_pb2.py +296 -0
  294. mlflow/protos/internal_pb2.py +77 -0
  295. mlflow/protos/mlflow_artifacts_pb2.py +336 -0
  296. mlflow/protos/model_registry_pb2.py +1073 -0
  297. mlflow/protos/scalapb/__init__.py +0 -0
  298. mlflow/protos/scalapb/scalapb_pb2.py +104 -0
  299. mlflow/protos/service_pb2.py +2600 -0
  300. mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
  301. mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
  302. mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
  303. mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
  304. mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
  305. mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
  306. mlflow/py.typed +0 -0
  307. mlflow/pydantic_ai/__init__.py +57 -0
  308. mlflow/pydantic_ai/autolog.py +173 -0
  309. mlflow/pyfunc/__init__.py +3844 -0
  310. mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
  311. mlflow/pyfunc/backend.py +523 -0
  312. mlflow/pyfunc/context.py +78 -0
  313. mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
  314. mlflow/pyfunc/loaders/__init__.py +7 -0
  315. mlflow/pyfunc/loaders/chat_agent.py +117 -0
  316. mlflow/pyfunc/loaders/chat_model.py +125 -0
  317. mlflow/pyfunc/loaders/code_model.py +31 -0
  318. mlflow/pyfunc/loaders/responses_agent.py +112 -0
  319. mlflow/pyfunc/mlserver.py +46 -0
  320. mlflow/pyfunc/model.py +1473 -0
  321. mlflow/pyfunc/scoring_server/__init__.py +604 -0
  322. mlflow/pyfunc/scoring_server/app.py +7 -0
  323. mlflow/pyfunc/scoring_server/client.py +146 -0
  324. mlflow/pyfunc/spark_model_cache.py +48 -0
  325. mlflow/pyfunc/stdin_server.py +44 -0
  326. mlflow/pyfunc/utils/__init__.py +3 -0
  327. mlflow/pyfunc/utils/data_validation.py +224 -0
  328. mlflow/pyfunc/utils/environment.py +22 -0
  329. mlflow/pyfunc/utils/input_converter.py +47 -0
  330. mlflow/pyfunc/utils/serving_data_parser.py +11 -0
  331. mlflow/pytorch/__init__.py +1171 -0
  332. mlflow/pytorch/_lightning_autolog.py +580 -0
  333. mlflow/pytorch/_pytorch_autolog.py +50 -0
  334. mlflow/pytorch/pickle_module.py +35 -0
  335. mlflow/rfunc/__init__.py +42 -0
  336. mlflow/rfunc/backend.py +134 -0
  337. mlflow/runs.py +89 -0
  338. mlflow/server/__init__.py +302 -0
  339. mlflow/server/auth/__init__.py +1224 -0
  340. mlflow/server/auth/__main__.py +4 -0
  341. mlflow/server/auth/basic_auth.ini +6 -0
  342. mlflow/server/auth/cli.py +11 -0
  343. mlflow/server/auth/client.py +537 -0
  344. mlflow/server/auth/config.py +34 -0
  345. mlflow/server/auth/db/__init__.py +0 -0
  346. mlflow/server/auth/db/cli.py +18 -0
  347. mlflow/server/auth/db/migrations/__init__.py +0 -0
  348. mlflow/server/auth/db/migrations/alembic.ini +110 -0
  349. mlflow/server/auth/db/migrations/env.py +76 -0
  350. mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
  351. mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
  352. mlflow/server/auth/db/models.py +67 -0
  353. mlflow/server/auth/db/utils.py +37 -0
  354. mlflow/server/auth/entities.py +165 -0
  355. mlflow/server/auth/logo.py +14 -0
  356. mlflow/server/auth/permissions.py +65 -0
  357. mlflow/server/auth/routes.py +18 -0
  358. mlflow/server/auth/sqlalchemy_store.py +263 -0
  359. mlflow/server/graphql/__init__.py +0 -0
  360. mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
  361. mlflow/server/graphql/graphql_custom_scalars.py +24 -0
  362. mlflow/server/graphql/graphql_errors.py +15 -0
  363. mlflow/server/graphql/graphql_no_batching.py +89 -0
  364. mlflow/server/graphql/graphql_schema_extensions.py +74 -0
  365. mlflow/server/handlers.py +3217 -0
  366. mlflow/server/prometheus_exporter.py +17 -0
  367. mlflow/server/validation.py +30 -0
  368. mlflow/shap/__init__.py +691 -0
  369. mlflow/sklearn/__init__.py +1994 -0
  370. mlflow/sklearn/utils.py +1041 -0
  371. mlflow/smolagents/__init__.py +66 -0
  372. mlflow/smolagents/autolog.py +139 -0
  373. mlflow/smolagents/chat.py +29 -0
  374. mlflow/store/__init__.py +10 -0
  375. mlflow/store/_unity_catalog/__init__.py +1 -0
  376. mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
  377. mlflow/store/_unity_catalog/lineage/constants.py +2 -0
  378. mlflow/store/_unity_catalog/registry/__init__.py +6 -0
  379. mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
  380. mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
  381. mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
  382. mlflow/store/_unity_catalog/registry/utils.py +121 -0
  383. mlflow/store/artifact/__init__.py +0 -0
  384. mlflow/store/artifact/artifact_repo.py +472 -0
  385. mlflow/store/artifact/artifact_repository_registry.py +154 -0
  386. mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
  387. mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
  388. mlflow/store/artifact/cli.py +141 -0
  389. mlflow/store/artifact/cloud_artifact_repo.py +332 -0
  390. mlflow/store/artifact/databricks_artifact_repo.py +729 -0
  391. mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
  392. mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
  393. mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
  394. mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
  395. mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
  396. mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
  397. mlflow/store/artifact/ftp_artifact_repo.py +132 -0
  398. mlflow/store/artifact/gcs_artifact_repo.py +296 -0
  399. mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
  400. mlflow/store/artifact/http_artifact_repo.py +218 -0
  401. mlflow/store/artifact/local_artifact_repo.py +142 -0
  402. mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
  403. mlflow/store/artifact/models_artifact_repo.py +259 -0
  404. mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
  405. mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
  406. mlflow/store/artifact/r2_artifact_repo.py +70 -0
  407. mlflow/store/artifact/runs_artifact_repo.py +265 -0
  408. mlflow/store/artifact/s3_artifact_repo.py +330 -0
  409. mlflow/store/artifact/sftp_artifact_repo.py +141 -0
  410. mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
  411. mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
  412. mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
  413. mlflow/store/artifact/utils/__init__.py +0 -0
  414. mlflow/store/artifact/utils/models.py +148 -0
  415. mlflow/store/db/__init__.py +0 -0
  416. mlflow/store/db/base_sql_model.py +3 -0
  417. mlflow/store/db/db_types.py +10 -0
  418. mlflow/store/db/utils.py +314 -0
  419. mlflow/store/db_migrations/__init__.py +0 -0
  420. mlflow/store/db_migrations/alembic.ini +74 -0
  421. mlflow/store/db_migrations/env.py +84 -0
  422. mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
  423. mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
  424. mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
  425. mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
  426. mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
  427. mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
  428. mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
  429. mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
  430. mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
  431. mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
  432. mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
  433. mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
  434. mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
  435. mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
  436. mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
  437. mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
  438. mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
  439. mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
  440. mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
  441. mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
  442. mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
  443. mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
  444. mlflow/store/db_migrations/versions/__init__.py +0 -0
  445. mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
  446. mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
  447. mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
  448. mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
  449. mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
  450. mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
  451. mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
  452. mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
  453. mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
  454. mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
  455. mlflow/store/entities/__init__.py +3 -0
  456. mlflow/store/entities/paged_list.py +18 -0
  457. mlflow/store/model_registry/__init__.py +10 -0
  458. mlflow/store/model_registry/abstract_store.py +1081 -0
  459. mlflow/store/model_registry/base_rest_store.py +44 -0
  460. mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
  461. mlflow/store/model_registry/dbmodels/__init__.py +0 -0
  462. mlflow/store/model_registry/dbmodels/models.py +206 -0
  463. mlflow/store/model_registry/file_store.py +1091 -0
  464. mlflow/store/model_registry/rest_store.py +481 -0
  465. mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
  466. mlflow/store/tracking/__init__.py +23 -0
  467. mlflow/store/tracking/abstract_store.py +816 -0
  468. mlflow/store/tracking/dbmodels/__init__.py +0 -0
  469. mlflow/store/tracking/dbmodels/initial_models.py +243 -0
  470. mlflow/store/tracking/dbmodels/models.py +1073 -0
  471. mlflow/store/tracking/file_store.py +2438 -0
  472. mlflow/store/tracking/postgres_managed_identity.py +146 -0
  473. mlflow/store/tracking/rest_store.py +1131 -0
  474. mlflow/store/tracking/sqlalchemy_store.py +2785 -0
  475. mlflow/system_metrics/__init__.py +61 -0
  476. mlflow/system_metrics/metrics/__init__.py +0 -0
  477. mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
  478. mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
  479. mlflow/system_metrics/metrics/disk_monitor.py +21 -0
  480. mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
  481. mlflow/system_metrics/metrics/network_monitor.py +34 -0
  482. mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
  483. mlflow/system_metrics/system_metrics_monitor.py +198 -0
  484. mlflow/tracing/__init__.py +16 -0
  485. mlflow/tracing/assessment.py +356 -0
  486. mlflow/tracing/client.py +531 -0
  487. mlflow/tracing/config.py +125 -0
  488. mlflow/tracing/constant.py +105 -0
  489. mlflow/tracing/destination.py +81 -0
  490. mlflow/tracing/display/__init__.py +40 -0
  491. mlflow/tracing/display/display_handler.py +196 -0
  492. mlflow/tracing/export/async_export_queue.py +186 -0
  493. mlflow/tracing/export/inference_table.py +138 -0
  494. mlflow/tracing/export/mlflow_v3.py +137 -0
  495. mlflow/tracing/export/utils.py +70 -0
  496. mlflow/tracing/fluent.py +1417 -0
  497. mlflow/tracing/processor/base_mlflow.py +199 -0
  498. mlflow/tracing/processor/inference_table.py +175 -0
  499. mlflow/tracing/processor/mlflow_v3.py +47 -0
  500. mlflow/tracing/processor/otel.py +73 -0
  501. mlflow/tracing/provider.py +487 -0
  502. mlflow/tracing/trace_manager.py +200 -0
  503. mlflow/tracing/utils/__init__.py +616 -0
  504. mlflow/tracing/utils/artifact_utils.py +28 -0
  505. mlflow/tracing/utils/copy.py +55 -0
  506. mlflow/tracing/utils/environment.py +55 -0
  507. mlflow/tracing/utils/exception.py +21 -0
  508. mlflow/tracing/utils/once.py +35 -0
  509. mlflow/tracing/utils/otlp.py +63 -0
  510. mlflow/tracing/utils/processor.py +54 -0
  511. mlflow/tracing/utils/search.py +292 -0
  512. mlflow/tracing/utils/timeout.py +250 -0
  513. mlflow/tracing/utils/token.py +19 -0
  514. mlflow/tracing/utils/truncation.py +124 -0
  515. mlflow/tracing/utils/warning.py +76 -0
  516. mlflow/tracking/__init__.py +39 -0
  517. mlflow/tracking/_model_registry/__init__.py +1 -0
  518. mlflow/tracking/_model_registry/client.py +764 -0
  519. mlflow/tracking/_model_registry/fluent.py +853 -0
  520. mlflow/tracking/_model_registry/registry.py +67 -0
  521. mlflow/tracking/_model_registry/utils.py +251 -0
  522. mlflow/tracking/_tracking_service/__init__.py +0 -0
  523. mlflow/tracking/_tracking_service/client.py +883 -0
  524. mlflow/tracking/_tracking_service/registry.py +56 -0
  525. mlflow/tracking/_tracking_service/utils.py +275 -0
  526. mlflow/tracking/artifact_utils.py +179 -0
  527. mlflow/tracking/client.py +5900 -0
  528. mlflow/tracking/context/__init__.py +0 -0
  529. mlflow/tracking/context/abstract_context.py +35 -0
  530. mlflow/tracking/context/databricks_cluster_context.py +15 -0
  531. mlflow/tracking/context/databricks_command_context.py +15 -0
  532. mlflow/tracking/context/databricks_job_context.py +49 -0
  533. mlflow/tracking/context/databricks_notebook_context.py +41 -0
  534. mlflow/tracking/context/databricks_repo_context.py +43 -0
  535. mlflow/tracking/context/default_context.py +51 -0
  536. mlflow/tracking/context/git_context.py +32 -0
  537. mlflow/tracking/context/registry.py +98 -0
  538. mlflow/tracking/context/system_environment_context.py +15 -0
  539. mlflow/tracking/default_experiment/__init__.py +1 -0
  540. mlflow/tracking/default_experiment/abstract_context.py +43 -0
  541. mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
  542. mlflow/tracking/default_experiment/registry.py +75 -0
  543. mlflow/tracking/fluent.py +3595 -0
  544. mlflow/tracking/metric_value_conversion_utils.py +93 -0
  545. mlflow/tracking/multimedia.py +206 -0
  546. mlflow/tracking/registry.py +86 -0
  547. mlflow/tracking/request_auth/__init__.py +0 -0
  548. mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
  549. mlflow/tracking/request_auth/registry.py +60 -0
  550. mlflow/tracking/request_header/__init__.py +0 -0
  551. mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
  552. mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
  553. mlflow/tracking/request_header/default_request_header_provider.py +17 -0
  554. mlflow/tracking/request_header/registry.py +79 -0
  555. mlflow/transformers/__init__.py +2982 -0
  556. mlflow/transformers/flavor_config.py +258 -0
  557. mlflow/transformers/hub_utils.py +83 -0
  558. mlflow/transformers/llm_inference_utils.py +468 -0
  559. mlflow/transformers/model_io.py +301 -0
  560. mlflow/transformers/peft.py +51 -0
  561. mlflow/transformers/signature.py +183 -0
  562. mlflow/transformers/torch_utils.py +55 -0
  563. mlflow/types/__init__.py +21 -0
  564. mlflow/types/agent.py +270 -0
  565. mlflow/types/chat.py +240 -0
  566. mlflow/types/llm.py +935 -0
  567. mlflow/types/responses.py +139 -0
  568. mlflow/types/responses_helpers.py +416 -0
  569. mlflow/types/schema.py +1505 -0
  570. mlflow/types/type_hints.py +647 -0
  571. mlflow/types/utils.py +753 -0
  572. mlflow/utils/__init__.py +283 -0
  573. mlflow/utils/_capture_modules.py +256 -0
  574. mlflow/utils/_capture_transformers_modules.py +75 -0
  575. mlflow/utils/_spark_utils.py +201 -0
  576. mlflow/utils/_unity_catalog_oss_utils.py +97 -0
  577. mlflow/utils/_unity_catalog_utils.py +479 -0
  578. mlflow/utils/annotations.py +218 -0
  579. mlflow/utils/arguments_utils.py +16 -0
  580. mlflow/utils/async_logging/__init__.py +1 -0
  581. mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
  582. mlflow/utils/async_logging/async_logging_queue.py +366 -0
  583. mlflow/utils/async_logging/run_artifact.py +38 -0
  584. mlflow/utils/async_logging/run_batch.py +58 -0
  585. mlflow/utils/async_logging/run_operations.py +49 -0
  586. mlflow/utils/autologging_utils/__init__.py +737 -0
  587. mlflow/utils/autologging_utils/client.py +432 -0
  588. mlflow/utils/autologging_utils/config.py +33 -0
  589. mlflow/utils/autologging_utils/events.py +294 -0
  590. mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
  591. mlflow/utils/autologging_utils/metrics_queue.py +71 -0
  592. mlflow/utils/autologging_utils/safety.py +1104 -0
  593. mlflow/utils/autologging_utils/versioning.py +95 -0
  594. mlflow/utils/checkpoint_utils.py +206 -0
  595. mlflow/utils/class_utils.py +6 -0
  596. mlflow/utils/cli_args.py +257 -0
  597. mlflow/utils/conda.py +354 -0
  598. mlflow/utils/credentials.py +231 -0
  599. mlflow/utils/data_utils.py +17 -0
  600. mlflow/utils/databricks_utils.py +1436 -0
  601. mlflow/utils/docstring_utils.py +477 -0
  602. mlflow/utils/doctor.py +133 -0
  603. mlflow/utils/download_cloud_file_chunk.py +43 -0
  604. mlflow/utils/env_manager.py +16 -0
  605. mlflow/utils/env_pack.py +131 -0
  606. mlflow/utils/environment.py +1009 -0
  607. mlflow/utils/exception_utils.py +14 -0
  608. mlflow/utils/file_utils.py +978 -0
  609. mlflow/utils/git_utils.py +77 -0
  610. mlflow/utils/gorilla.py +797 -0
  611. mlflow/utils/import_hooks/__init__.py +363 -0
  612. mlflow/utils/lazy_load.py +51 -0
  613. mlflow/utils/logging_utils.py +168 -0
  614. mlflow/utils/mime_type_utils.py +58 -0
  615. mlflow/utils/mlflow_tags.py +103 -0
  616. mlflow/utils/model_utils.py +486 -0
  617. mlflow/utils/name_utils.py +346 -0
  618. mlflow/utils/nfs_on_spark.py +62 -0
  619. mlflow/utils/openai_utils.py +164 -0
  620. mlflow/utils/os.py +12 -0
  621. mlflow/utils/oss_registry_utils.py +29 -0
  622. mlflow/utils/plugins.py +17 -0
  623. mlflow/utils/process.py +182 -0
  624. mlflow/utils/promptlab_utils.py +146 -0
  625. mlflow/utils/proto_json_utils.py +743 -0
  626. mlflow/utils/pydantic_utils.py +54 -0
  627. mlflow/utils/request_utils.py +279 -0
  628. mlflow/utils/requirements_utils.py +704 -0
  629. mlflow/utils/rest_utils.py +673 -0
  630. mlflow/utils/search_logged_model_utils.py +127 -0
  631. mlflow/utils/search_utils.py +2111 -0
  632. mlflow/utils/secure_loading.py +221 -0
  633. mlflow/utils/security_validation.py +384 -0
  634. mlflow/utils/server_cli_utils.py +61 -0
  635. mlflow/utils/spark_utils.py +15 -0
  636. mlflow/utils/string_utils.py +138 -0
  637. mlflow/utils/thread_utils.py +63 -0
  638. mlflow/utils/time.py +54 -0
  639. mlflow/utils/timeout.py +42 -0
  640. mlflow/utils/uri.py +572 -0
  641. mlflow/utils/validation.py +662 -0
  642. mlflow/utils/virtualenv.py +458 -0
  643. mlflow/utils/warnings_utils.py +25 -0
  644. mlflow/utils/yaml_utils.py +179 -0
  645. mlflow/version.py +24 -0
@@ -0,0 +1,704 @@
1
+ import logging
2
+ import math
3
+ from collections import namedtuple
4
+ from contextlib import contextmanager
5
+ from typing import Optional
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+ from sklearn import metrics as sk_metrics
10
+
11
+ import mlflow
12
+ from mlflow import MlflowException
13
+ from mlflow.environment_variables import _MLFLOW_EVALUATE_SUPPRESS_CLASSIFICATION_ERRORS
14
+ from mlflow.models.evaluation.artifacts import CsvEvaluationArtifact
15
+ from mlflow.models.evaluation.base import EvaluationMetric, EvaluationResult, _ModelType
16
+ from mlflow.models.evaluation.default_evaluator import (
17
+ BuiltInEvaluator,
18
+ _extract_raw_model,
19
+ _get_aggregate_metrics_values,
20
+ )
21
+ from mlflow.models.utils import plot_lines
22
+
23
+ _logger = logging.getLogger(__name__)
24
+
25
+
26
+ _Curve = namedtuple("_Curve", ["plot_fn", "plot_fn_args", "auc"])
27
+
28
+
29
class ClassifierEvaluator(BuiltInEvaluator):
    """
    A built-in evaluator for classifier models.

    Computes binary or multiclass classification metrics from the model's predictions and
    logs the associated artifacts to the current run:

    * binary: ROC, precision-recall, lift and calibration curves (when class probabilities
      are available);
    * multiclass: calibration curve, a per-class metrics table and, for at most
      ``max_classes_for_multiclass_roc_pr`` classes, ROC and precision-recall curves;
    * both: the normalized confusion matrix.
    """

    name = "classifier"

    @classmethod
    def can_evaluate(cls, *, model_type, evaluator_config, **kwargs):
        """Return True only when ``model_type`` is the classifier model type."""
        # TODO: Also the model needs to be pyfunc model, not function or endpoint URI
        return model_type == _ModelType.CLASSIFIER

    def _evaluate(
        self,
        model: Optional["mlflow.pyfunc.PyFuncModel"],
        extra_metrics: list[EvaluationMetric],
        custom_artifacts=None,
        **kwargs,
    ) -> Optional[EvaluationResult]:
        """
        Run the classifier evaluation end to end.

        Args:
            model: The pyfunc model to evaluate. When ``None``, predictions are taken from
                ``self.dataset.predictions_data`` instead of being computed.
            extra_metrics: Additional metrics evaluated on the predictions and targets.
            custom_artifacts: Optional custom artifact functions to evaluate and log.
            kwargs: Not used by this method.

        Returns:
            An ``EvaluationResult`` holding the aggregated metrics, the logged artifacts
            and the run ID.

        Raises:
            MlflowException: If ``pos_label`` is not contained in ``label_list``, or if
                fewer than two distinct labels are available.
        """
        # Get classification config
        self.y_true = self.dataset.labels_data
        self.label_list = self.evaluator_config.get("label_list")
        self.pos_label = self.evaluator_config.get("pos_label")
        self.sample_weights = self.evaluator_config.get("sample_weights")
        # Compare against None explicitly: a falsy ``pos_label`` such as 0 must still be
        # validated, and an array-like ``label_list`` has an ambiguous truth value.
        if (
            self.pos_label is not None
            and self.label_list is not None
            and self.pos_label not in self.label_list
        ):
            raise MlflowException.invalid_parameter_value(
                f"'pos_label' {self.pos_label} must exist in 'label_list' {self.label_list}."
            )

        # Check if the model_type is consistent with ground truth labels
        inferred_model_type = _infer_model_type_by_labels(self.y_true)
        if _ModelType.CLASSIFIER != inferred_model_type:
            _logger.warning(
                f"According to the evaluation dataset label values, the model type looks like "
                f"{inferred_model_type}, but you specified model type 'classifier'. Please "
                f"verify that you set the `model_type` and `dataset` arguments correctly."
            )

        # Run model prediction
        input_df = self.X.copy_to_avoid_mutation()
        self.y_pred, self.y_probs = self._generate_model_predictions(model, input_df)

        self._validate_label_list()

        self._compute_builtin_metrics(model)
        self.evaluate_metrics(extra_metrics, prediction=self.y_pred, target=self.y_true)
        self.evaluate_and_log_custom_artifacts(
            custom_artifacts, prediction=self.y_pred, target=self.y_true
        )

        # Log metrics and artifacts
        self.log_metrics()
        self.log_eval_table(self.y_pred)

        if len(self.label_list) == 2:
            self._log_binary_classifier_artifacts()
        else:
            self._log_multiclass_classifier_artifacts()
        self._log_confusion_matrix()

        return EvaluationResult(
            metrics=self.aggregate_metrics, artifacts=self.artifacts, run_id=self.run_id
        )

    def _generate_model_predictions(self, model, input_df):
        """
        Return ``(y_pred, y_probs)`` for ``input_df``.

        ``y_pred`` comes from the dataset's stored predictions when ``model`` is None,
        otherwise from the resolved predict function. ``y_probs`` is None when no
        ``predict_proba`` function could be resolved.
        """
        predict_fn, predict_proba_fn = _extract_predict_fn_and_prodict_proba_fn(model)
        # Classifier model is guaranteed to output single column of predictions
        y_pred = self.dataset.predictions_data if model is None else predict_fn(input_df)

        # Predict class probabilities if the model supports it
        y_probs = predict_proba_fn(input_df) if predict_proba_fn is not None else None
        return y_pred, y_probs

    def _validate_label_list(self):
        """
        Normalize ``self.label_list`` into a sorted numpy array and resolve ``pos_label``.

        When no label list was configured, it is inferred from the union of true and
        predicted labels. For binary problems, ``self.pos_label`` defaults to the largest
        label, and the positive label is moved to the last position.

        Raises:
            MlflowException: If fewer than two distinct labels are available.
        """
        if self.label_list is None:
            # If label list is not specified, infer label list from model output
            self.label_list = np.unique(np.concatenate([self.y_true, self.y_pred]))
        else:
            # np.where only works for numpy array, not list
            self.label_list = np.array(self.label_list)

        if len(self.label_list) < 2:
            # A single message string: a second positional argument would be interpreted
            # as ``error_code`` and the hint would be lost.
            raise MlflowException(
                "Evaluation dataset for classification must contain at least two unique "
                f"labels, but only {len(self.label_list)} unique labels were found. "
                "Please provide a 'label_list' parameter in 'evaluator_config' with all "
                "possible classes, e.g., evaluator_config={'label_list': [0, 1]}."
            )

        # sort label_list ASC, for binary classification it makes sure the last one is pos label
        self.label_list.sort()

        if len(self.label_list) == 2:
            if self.pos_label is None:
                self.pos_label = self.label_list[-1]
            else:
                if self.pos_label in self.label_list:
                    self.label_list = np.delete(
                        self.label_list, np.where(self.label_list == self.pos_label)
                    )
                self.label_list = np.append(self.label_list, self.pos_label)
            with _suppress_class_imbalance_errors(IndexError, log_warning=False):
                _logger.info(
                    "The evaluation dataset is inferred as binary dataset, positive label is "
                    f"{self.label_list[1]}, negative label is {self.label_list[0]}."
                )
        else:
            _logger.info(
                "The evaluation dataset is inferred as multiclass dataset, number of classes "
                f"is inferred as {len(self.label_list)}. If this is incorrect, please specify the "
                "`label_list` parameter in `evaluator_config`."
            )

    def _compute_builtin_metrics(self, model):
        """Compute the built-in binary or multiclass metrics into ``self.metrics_values``."""
        self._evaluate_sklearn_model_score_if_scorable(model, self.y_true, self.sample_weights)

        if len(self.label_list) == 2:
            metrics = _get_binary_classifier_metrics(
                y_true=self.y_true,
                y_pred=self.y_pred,
                y_proba=self.y_probs,
                labels=self.label_list,
                pos_label=self.pos_label,
                sample_weights=self.sample_weights,
            )
            # ``metrics`` is None when an error was suppressed by
            # ``_suppress_class_imbalance_errors``.
            if metrics:
                self.metrics_values.update(_get_aggregate_metrics_values(metrics))
            self._compute_roc_and_pr_curve()
        else:
            average = self.evaluator_config.get("average", "weighted")
            metrics = _get_multiclass_classifier_metrics(
                y_true=self.y_true,
                y_pred=self.y_pred,
                y_proba=self.y_probs,
                labels=self.label_list,
                average=average,
                sample_weights=self.sample_weights,
            )
            if metrics:
                self.metrics_values.update(_get_aggregate_metrics_values(metrics))

    def _compute_roc_and_pr_curve(self):
        """
        For binary classifiers with probabilities, build ``self.roc_curve`` and
        ``self.pr_curve`` and record ``roc_auc`` / ``precision_recall_auc`` metrics.
        """
        if self.y_probs is not None:
            with _suppress_class_imbalance_errors(ValueError, log_warning=False):
                self.roc_curve = _gen_classifier_curve(
                    is_binomial=True,
                    y=self.y_true,
                    y_probs=self.y_probs[:, 1],
                    labels=self.label_list,
                    pos_label=self.pos_label,
                    curve_type="roc",
                    sample_weights=self.sample_weights,
                )

                self.metrics_values.update(
                    _get_aggregate_metrics_values({"roc_auc": self.roc_curve.auc})
                )
            with _suppress_class_imbalance_errors(ValueError, log_warning=False):
                self.pr_curve = _gen_classifier_curve(
                    is_binomial=True,
                    y=self.y_true,
                    y_probs=self.y_probs[:, 1],
                    labels=self.label_list,
                    pos_label=self.pos_label,
                    curve_type="pr",
                    sample_weights=self.sample_weights,
                )

                self.metrics_values.update(
                    _get_aggregate_metrics_values({"precision_recall_auc": self.pr_curve.auc})
                )

    def _log_pandas_df_artifact(self, pandas_df, artifact_name):
        """Write ``pandas_df`` to ``<artifact_name>.csv``, log it and register the artifact."""
        artifact_file_name = f"{artifact_name}.csv"
        artifact_file_local_path = self.temp_dir.path(artifact_file_name)
        pandas_df.to_csv(artifact_file_local_path, index=False)
        mlflow.log_artifact(artifact_file_local_path)
        artifact = CsvEvaluationArtifact(
            uri=mlflow.get_artifact_uri(artifact_file_name),
            content=pandas_df,
        )
        artifact._load(artifact_file_local_path)
        self.artifacts[artifact_name] = artifact

    def _log_multiclass_classifier_artifacts(self):
        """
        Log the multiclass artifacts: the calibration curve and, when the number of
        classes does not exceed ``max_classes_for_multiclass_roc_pr`` (default 10), the
        one-vs-rest ROC and precision-recall curves. Always logs the per-class metrics table.
        """
        per_class_metrics_collection_df = _get_classifier_per_class_metrics_collection_df(
            y=self.y_true,
            y_pred=self.y_pred,
            labels=self.label_list,
            sample_weights=self.sample_weights,
        )

        log_roc_pr_curve = False
        if self.y_probs is not None:
            with _suppress_class_imbalance_errors(TypeError, log_warning=False):
                self._log_calibration_curve()

            max_classes_for_multiclass_roc_pr = self.evaluator_config.get(
                "max_classes_for_multiclass_roc_pr", 10
            )
            if len(self.label_list) <= max_classes_for_multiclass_roc_pr:
                log_roc_pr_curve = True
            else:
                _logger.warning(
                    f"The classifier num_classes > {max_classes_for_multiclass_roc_pr}, skip "
                    f"logging ROC curve and Precision-Recall curve. You can add evaluator config "
                    f"'max_classes_for_multiclass_roc_pr' to increase the threshold."
                )

        if log_roc_pr_curve:
            roc_curve = _gen_classifier_curve(
                is_binomial=False,
                y=self.y_true,
                y_probs=self.y_probs,
                labels=self.label_list,
                pos_label=self.pos_label,
                curve_type="roc",
                sample_weights=self.sample_weights,
            )

            def plot_roc_curve():
                roc_curve.plot_fn(**roc_curve.plot_fn_args)

            self._log_image_artifact(plot_roc_curve, "roc_curve_plot")
            # ``auc`` holds one value per class, matching the table's rows.
            per_class_metrics_collection_df["roc_auc"] = roc_curve.auc

            pr_curve = _gen_classifier_curve(
                is_binomial=False,
                y=self.y_true,
                y_probs=self.y_probs,
                labels=self.label_list,
                pos_label=self.pos_label,
                curve_type="pr",
                sample_weights=self.sample_weights,
            )

            def plot_pr_curve():
                pr_curve.plot_fn(**pr_curve.plot_fn_args)

            self._log_image_artifact(plot_pr_curve, "precision_recall_curve_plot")
            per_class_metrics_collection_df["precision_recall_auc"] = pr_curve.auc

        self._log_pandas_df_artifact(per_class_metrics_collection_df, "per_class_metrics")

    def _log_roc_curve(self):
        """Log the binary ROC curve previously stored in ``self.roc_curve``."""

        def _plot_roc_curve():
            self.roc_curve.plot_fn(**self.roc_curve.plot_fn_args)

        self._log_image_artifact(_plot_roc_curve, "roc_curve_plot")

    def _log_precision_recall_curve(self):
        """Log the binary precision-recall curve previously stored in ``self.pr_curve``."""

        def _plot_pr_curve():
            self.pr_curve.plot_fn(**self.pr_curve.plot_fn_args)

        self._log_image_artifact(_plot_pr_curve, "precision_recall_curve_plot")

    def _log_lift_curve(self):
        """Log the lift curve plot for the positive label."""
        from mlflow.models.evaluation.lift_curve import plot_lift_curve

        def _plot_lift_curve():
            return plot_lift_curve(self.y_true, self.y_probs, pos_label=self.pos_label)

        self._log_image_artifact(_plot_lift_curve, "lift_curve_plot")

    def _log_calibration_curve(self):
        """
        Log the calibration curve plot. Evaluator config entries whose keys start with
        ``calibration_`` are forwarded as the calibration configuration.
        """
        from mlflow.models.evaluation.calibration_curve import plot_calibration_curve

        def _plot_calibration_curve():
            return plot_calibration_curve(
                y_true=self.y_true,
                y_probs=self.y_probs,
                pos_label=self.pos_label,
                calibration_config={
                    k: v for k, v in self.evaluator_config.items() if k.startswith("calibration_")
                },
                label_list=self.label_list,
            )

        self._log_image_artifact(_plot_calibration_curve, "calibration_curve_plot")

    def _log_binary_classifier_artifacts(self):
        """Log ROC, precision-recall, lift and calibration curves when probabilities exist."""
        if self.y_probs is not None:
            with _suppress_class_imbalance_errors(log_warning=False):
                self._log_roc_curve()
            with _suppress_class_imbalance_errors(log_warning=False):
                self._log_precision_recall_curve()
            with _suppress_class_imbalance_errors(ValueError, log_warning=False):
                self._log_lift_curve()
            with _suppress_class_imbalance_errors(TypeError, log_warning=False):
                self._log_calibration_curve()

    def _log_confusion_matrix(self):
        """
        Helper method for logging confusion matrix

        The matrix is row-normalized (``normalize="true"``) and is only logged when the
        installed scikit-learn provides ``ConfusionMatrixDisplay``.
        """
        # normalize the confusion matrix, keep consistent with sklearn autologging.
        confusion_matrix = sk_metrics.confusion_matrix(
            self.y_true,
            self.y_pred,
            labels=self.label_list,
            normalize="true",
            sample_weight=self.sample_weights,
        )

        def plot_confusion_matrix():
            import matplotlib
            import matplotlib.pyplot as plt

            # Shrink the font as the number of classes grows so labels stay readable.
            with matplotlib.rc_context(
                {
                    "font.size": min(8, math.ceil(50.0 / len(self.label_list))),
                    "axes.labelsize": 8,
                }
            ):
                _, ax = plt.subplots(1, 1, figsize=(6.0, 4.0), dpi=175)
                disp = sk_metrics.ConfusionMatrixDisplay(
                    confusion_matrix=confusion_matrix,
                    display_labels=self.label_list,
                ).plot(cmap="Blues", ax=ax)
                disp.ax_.set_title("Normalized confusion matrix")

        if hasattr(sk_metrics, "ConfusionMatrixDisplay"):
            self._log_image_artifact(
                plot_confusion_matrix,
                "confusion_matrix",
            )
356
+
357
+
358
+ def _is_categorical(values):
359
+ """
360
+ Infer whether input values are categorical on best effort.
361
+ Return True represent they are categorical, return False represent we cannot determine result.
362
+ """
363
+ dtype_name = pd.Series(values).convert_dtypes().dtype.name.lower()
364
+ return dtype_name in ["category", "string", "boolean"]
365
+
366
+
367
+ def _is_continuous(values):
368
+ """
369
+ Infer whether input values is continuous on best effort.
370
+ Return True represent they are continuous, return False represent we cannot determine result.
371
+ """
372
+ dtype_name = pd.Series(values).convert_dtypes().dtype.name.lower()
373
+ return dtype_name.startswith("float")
374
+
375
+
376
def _infer_model_type_by_labels(labels):
    """
    Infer model type by target values.

    Returns ``_ModelType.CLASSIFIER`` for categorical-looking labels and
    ``_ModelType.REGRESSOR`` for floating-point labels. Returns None when neither
    heuristic is conclusive.
    """
    if _is_categorical(labels):
        return _ModelType.CLASSIFIER
    if _is_continuous(labels):
        return _ModelType.REGRESSOR
    # Neither heuristic matched: the model type is unknown.
    return None
386
+
387
+
388
def _extract_predict_fn_and_prodict_proba_fn(model):
    """
    Resolve the prediction callables to use for ``model``.

    When ``_extract_raw_model`` yields an underlying raw model, its ``predict`` and (if
    present) ``predict_proba`` methods are used. If ``mlflow.xgboost`` is importable, the
    function then tries to swap them for its wrapped variants built with
    ``validate_features=False``. Without a raw model, the pyfunc ``model.predict`` is used
    and no probability function is returned.

    Returns:
        A ``(predict_fn, predict_proba_fn)`` tuple. Either element may be None; both are
        None when ``model`` is None and no raw model is found.
    """
    _, raw_model = _extract_raw_model(model)

    if raw_model is None:
        fallback_predict = model.predict if model is not None else None
        return fallback_predict, None

    predict_fn = raw_model.predict
    predict_proba_fn = getattr(raw_model, "predict_proba", None)
    try:
        from mlflow.xgboost import (
            _wrapped_xgboost_model_predict_fn,
            _wrapped_xgboost_model_predict_proba_fn,
        )

        # Shap evaluation passes ndarray input without feature names, so feature
        # validation must be disabled or prediction would raise.
        predict_fn = _wrapped_xgboost_model_predict_fn(raw_model, validate_features=False)
        predict_proba_fn = _wrapped_xgboost_model_predict_proba_fn(
            raw_model, validate_features=False
        )
    except ImportError:
        # xgboost support unavailable: keep the callables resolved so far.
        pass
    return predict_fn, predict_proba_fn
415
+
416
+
417
+ @contextmanager
418
+ def _suppress_class_imbalance_errors(exception_type=Exception, log_warning=True):
419
+ """
420
+ Exception handler context manager to suppress Exceptions if the private environment
421
+ variable `_MLFLOW_EVALUATE_SUPPRESS_CLASSIFICATION_ERRORS` is set to `True`.
422
+ The purpose of this handler is to prevent an evaluation call for a binary or multiclass
423
+ classification automl run from aborting due to an extreme minority class imbalance
424
+ encountered during iterative training cycles due to the non deterministic sampling
425
+ behavior of Spark's DataFrame.sample() API.
426
+ The Exceptions caught in the usage of this are broad and are designed purely to not
427
+ interrupt the iterative hyperparameter tuning process. Final evaluations are done
428
+ in a more deterministic (but expensive) fashion.
429
+ """
430
+ try:
431
+ yield
432
+ except exception_type as e:
433
+ if _MLFLOW_EVALUATE_SUPPRESS_CLASSIFICATION_ERRORS.get():
434
+ if log_warning:
435
+ _logger.warning(
436
+ "Failed to calculate metrics due to class imbalance. "
437
+ "This is expected when the dataset is imbalanced."
438
+ )
439
+ else:
440
+ raise e
441
+
442
+
443
+ def _get_binary_sum_up_label_pred_prob(positive_class_index, positive_class, y, y_pred, y_probs):
444
+ y = np.array(y)
445
+ y_bin = np.where(y == positive_class, 1, 0)
446
+ y_pred_bin = None
447
+ y_prob_bin = None
448
+ if y_pred is not None:
449
+ y_pred = np.array(y_pred)
450
+ y_pred_bin = np.where(y_pred == positive_class, 1, 0)
451
+
452
+ if y_probs is not None:
453
+ y_probs = np.array(y_probs)
454
+ y_prob_bin = y_probs[:, positive_class_index]
455
+
456
+ return y_bin, y_pred_bin, y_prob_bin
457
+
458
+
459
def _get_common_classifier_metrics(
    *, y_true, y_pred, y_proba, labels, average, pos_label, sample_weights
):
    """
    Compute the metrics shared by binary and multiclass classifiers.

    Returns:
        A dict with ``example_count``, ``accuracy_score``, ``recall_score``,
        ``precision_score`` and ``f1_score``. It also includes ``log_loss`` when
        ``y_proba`` is given, unless a ``ValueError`` from it was suppressed by
        ``_suppress_class_imbalance_errors``.
    """
    # recall / precision / f1 share the same averaging configuration.
    averaged_kwargs = {
        "average": average,
        "pos_label": pos_label,
        "sample_weight": sample_weights,
    }
    metrics = {
        "example_count": len(y_true),
        "accuracy_score": sk_metrics.accuracy_score(y_true, y_pred, sample_weight=sample_weights),
        "recall_score": sk_metrics.recall_score(y_true, y_pred, **averaged_kwargs),
        "precision_score": sk_metrics.precision_score(y_true, y_pred, **averaged_kwargs),
        "f1_score": sk_metrics.f1_score(y_true, y_pred, **averaged_kwargs),
    }

    if y_proba is None:
        return metrics

    with _suppress_class_imbalance_errors(ValueError):
        metrics["log_loss"] = sk_metrics.log_loss(
            y_true, y_proba, labels=labels, sample_weight=sample_weights
        )
    return metrics
494
+
495
+
496
def _get_binary_classifier_metrics(
    *, y_true, y_pred, y_proba=None, labels=None, pos_label=1, sample_weights=None
):
    """
    Compute binary classification metrics.

    ``labels`` orders the confusion matrix as ``[negative, positive]``, which is what the
    ``tn, fp, fn, tp`` unpacking relies on. The confusion-matrix counts are computed
    without ``sample_weights``.

    Returns:
        A dict with ``true_negatives``, ``false_positives``, ``false_negatives`` and
        ``true_positives``, merged with the common classifier metrics. Returns None when a
        ``ValueError`` was suppressed by ``_suppress_class_imbalance_errors``.
    """
    with _suppress_class_imbalance_errors(ValueError):
        counts = sk_metrics.confusion_matrix(y_true, y_pred, labels=labels).ravel()
        tn, fp, fn, tp = counts
        common_metrics = _get_common_classifier_metrics(
            y_true=y_true,
            y_pred=y_pred,
            y_proba=y_proba,
            labels=labels,
            average="binary",
            pos_label=pos_label,
            sample_weights=sample_weights,
        )
        return {
            "true_negatives": tn,
            "false_positives": fp,
            "false_negatives": fn,
            "true_positives": tp,
            **common_metrics,
        }
516
+
517
+
518
def _get_multiclass_classifier_metrics(
    *,
    y_true,
    y_pred,
    y_proba=None,
    labels=None,
    average="weighted",
    sample_weights=None,
):
    """
    Compute multiclass classification metrics.

    Args:
        y_true: True labels.
        y_pred: Predicted labels.
        y_proba: Optional per-class predicted probabilities.
        labels: Optional list of all class labels (forwarded to ``log_loss``).
        average: Averaging strategy for recall / precision / F1 (and ROC AUC).
        sample_weights: Optional sample weights.

    Returns:
        The common classifier metrics, computed with ``pos_label=None``. When ``average``
        is ``"macro"`` or ``"weighted"`` and ``y_proba`` is given, a one-vs-rest
        ``roc_auc`` is added too, unless its ``ValueError`` was suppressed.
    """
    metrics = _get_common_classifier_metrics(
        y_true=y_true,
        y_pred=y_pred,
        y_proba=y_proba,
        labels=labels,
        average=average,
        pos_label=None,
        sample_weights=sample_weights,
    )
    if average in ("macro", "weighted") and y_proba is not None:
        # Like ``log_loss``, ROC AUC can raise ValueError on degenerate label
        # distributions. Honour the class-imbalance suppression setting instead of
        # always aborting the whole evaluation.
        with _suppress_class_imbalance_errors(ValueError):
            metrics["roc_auc"] = sk_metrics.roc_auc_score(
                y_true=y_true,
                y_score=y_proba,
                sample_weight=sample_weights,
                average=average,
                multi_class="ovr",
            )
    return metrics
547
+
548
+
549
def _get_classifier_per_class_metrics_collection_df(y, y_pred, labels, sample_weights):
    """
    Build a DataFrame of one-vs-rest binary metrics, one row per label.

    Each row has a ``positive_class`` column plus the binary classifier metrics computed
    after binarizing ``y`` / ``y_pred`` against that label. Rows contain only
    ``positive_class`` when the binary metrics could not be computed.
    """
    rows = []
    for class_index, class_label in enumerate(labels):
        y_bin, y_pred_bin, _ = _get_binary_sum_up_label_pred_prob(
            class_index, class_label, y, y_pred, None
        )
        row = {"positive_class": class_label}
        class_metrics = _get_binary_classifier_metrics(
            y_true=y_bin,
            y_pred=y_pred_bin,
            labels=[0, 1],  # Use binary labels for per-class metrics
            pos_label=1,
            sample_weights=sample_weights,
        )
        row.update(class_metrics or {})
        rows.append(row)

    return pd.DataFrame(rows)
572
+
573
+
574
def _gen_classifier_curve(
    is_binomial,
    y,
    y_probs,
    labels,
    pos_label,
    curve_type,
    sample_weights,
):
    """
    Generate precision-recall curve or ROC curve for classifier.

    Args:
        is_binomial: True if it is binary classifier otherwise False
        y: True label values
        y_probs: if binary classifier, the predicted probability for positive class.
            if multiclass classifier, the predicted probabilities for all classes; the
            i-th column is paired with the i-th entry of ``labels``.
        labels: The set of labels.
        pos_label: The label of the positive class.
        curve_type: "pr" or "roc"
        sample_weights: Optional sample weights.

    Returns:
        An instance of "_Curve" which includes attributes "plot_fn", "plot_fn_args", "auc".
        ``auc`` is the ROC AUC for ROC curves and the average precision for
        precision-recall curves. It is a single value for binary classifiers and a list
        with one value per label for multiclass classifiers.

    Raises:
        MlflowException: If ``curve_type`` is neither "roc" nor "pr".
    """
    if curve_type == "roc":

        def gen_line_x_y_label_auc(_y, _y_prob, _pos_label):
            fpr, tpr, _ = sk_metrics.roc_curve(
                _y,
                _y_prob,
                sample_weight=sample_weights,
                # For multiclass classification where a one-vs-rest ROC curve is produced for each
                # class, the positive label is binarized and should not be included in the plot
                # legend
                pos_label=_pos_label if _pos_label == pos_label else None,
            )

            auc = sk_metrics.roc_auc_score(y_true=_y, y_score=_y_prob, sample_weight=sample_weights)
            return fpr, tpr, f"AUC={auc:.3f}", auc

        xlabel = "False Positive Rate"
        ylabel = "True Positive Rate"
        title = "ROC curve"
        # Explicit None check so falsy labels such as 0 are still shown.
        if pos_label is not None:
            xlabel = f"False Positive Rate (Positive label: {pos_label})"
            ylabel = f"True Positive Rate (Positive label: {pos_label})"
    elif curve_type == "pr":

        def gen_line_x_y_label_auc(_y, _y_prob, _pos_label):
            precision, recall, _ = sk_metrics.precision_recall_curve(
                _y,
                _y_prob,
                sample_weight=sample_weights,
                # For multiclass classification where a one-vs-rest precision-recall curve is
                # produced for each class, the positive label is binarized and should not be
                # included in the plot legend
                pos_label=_pos_label if _pos_label == pos_label else None,
            )
            # NB: We return average precision score (AP) instead of AUC because AP is more
            # appropriate for summarizing a precision-recall curve
            ap = sk_metrics.average_precision_score(
                y_true=_y, y_score=_y_prob, pos_label=_pos_label, sample_weight=sample_weights
            )
            return recall, precision, f"AP={ap:.3f}", ap

        xlabel = "Recall"
        ylabel = "Precision"
        title = "Precision recall curve"
        if pos_label is not None:
            xlabel = f"Recall (Positive label: {pos_label})"
            ylabel = f"Precision (Positive label: {pos_label})"
    else:
        # An explicit raise (unlike ``assert``) is not stripped under ``python -O``.
        raise MlflowException.invalid_parameter_value(
            f"Illegal curve type {curve_type!r}; expected 'roc' or 'pr'."
        )

    if is_binomial:
        x_data, y_data, line_label, auc = gen_line_x_y_label_auc(y, y_probs, pos_label)
        data_series = [(line_label, x_data, y_data)]
    else:
        curve_list = []
        for positive_class_index, positive_class in enumerate(labels):
            # Only the binarized truth and the class's probability column are needed.
            y_bin, _, y_prob_bin = _get_binary_sum_up_label_pred_prob(
                positive_class_index, positive_class, y, None, y_probs
            )

            x_data, y_data, line_label, auc = gen_line_x_y_label_auc(
                y_bin, y_prob_bin, _pos_label=1
            )
            curve_list.append((positive_class, x_data, y_data, line_label, auc))

        data_series = [
            (f"label={positive_class},{line_label}", x_data, y_data)
            for positive_class, x_data, y_data, line_label, _ in curve_list
        ]
        auc = [auc for _, _, _, _, auc in curve_list]

    def _do_plot(**kwargs):
        from matplotlib import pyplot

        _, ax = plot_lines(**kwargs)
        dash_line_args = {
            "color": "gray",
            "alpha": 0.3,
            "drawstyle": "default",
            "linestyle": "dashed",
        }
        # Reference diagonal for the curve type.
        if curve_type == "pr":
            ax.plot([0, 1], [1, 0], **dash_line_args)
        elif curve_type == "roc":
            ax.plot([0, 1], [0, 1], **dash_line_args)

        if is_binomial:
            ax.legend(loc="best")
        else:
            # Many per-class lines: place the legend outside the axes.
            ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
            pyplot.subplots_adjust(right=0.6, bottom=0.25)

    return _Curve(
        plot_fn=_do_plot,
        plot_fn_args={
            "data_series": data_series,
            "xlabel": xlabel,
            "ylabel": ylabel,
            "line_kwargs": {"drawstyle": "steps-post", "linewidth": 1},
            "title": title,
        },
        auc=auc,
    )