genesis-flow 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (645)
  1. genesis_flow-1.0.0.dist-info/METADATA +822 -0
  2. genesis_flow-1.0.0.dist-info/RECORD +645 -0
  3. genesis_flow-1.0.0.dist-info/WHEEL +5 -0
  4. genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
  5. genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
  6. genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
  7. mlflow/__init__.py +367 -0
  8. mlflow/__main__.py +3 -0
  9. mlflow/ag2/__init__.py +56 -0
  10. mlflow/ag2/ag2_logger.py +294 -0
  11. mlflow/anthropic/__init__.py +40 -0
  12. mlflow/anthropic/autolog.py +129 -0
  13. mlflow/anthropic/chat.py +144 -0
  14. mlflow/artifacts/__init__.py +268 -0
  15. mlflow/autogen/__init__.py +144 -0
  16. mlflow/autogen/chat.py +142 -0
  17. mlflow/azure/__init__.py +26 -0
  18. mlflow/azure/auth_handler.py +257 -0
  19. mlflow/azure/client.py +319 -0
  20. mlflow/azure/config.py +120 -0
  21. mlflow/azure/connection_factory.py +340 -0
  22. mlflow/azure/exceptions.py +27 -0
  23. mlflow/azure/stores.py +327 -0
  24. mlflow/azure/utils.py +183 -0
  25. mlflow/bedrock/__init__.py +45 -0
  26. mlflow/bedrock/_autolog.py +202 -0
  27. mlflow/bedrock/chat.py +122 -0
  28. mlflow/bedrock/stream.py +160 -0
  29. mlflow/bedrock/utils.py +43 -0
  30. mlflow/cli.py +707 -0
  31. mlflow/client.py +12 -0
  32. mlflow/config/__init__.py +56 -0
  33. mlflow/crewai/__init__.py +79 -0
  34. mlflow/crewai/autolog.py +253 -0
  35. mlflow/crewai/chat.py +29 -0
  36. mlflow/data/__init__.py +75 -0
  37. mlflow/data/artifact_dataset_sources.py +170 -0
  38. mlflow/data/code_dataset_source.py +40 -0
  39. mlflow/data/dataset.py +123 -0
  40. mlflow/data/dataset_registry.py +168 -0
  41. mlflow/data/dataset_source.py +110 -0
  42. mlflow/data/dataset_source_registry.py +219 -0
  43. mlflow/data/delta_dataset_source.py +167 -0
  44. mlflow/data/digest_utils.py +108 -0
  45. mlflow/data/evaluation_dataset.py +562 -0
  46. mlflow/data/filesystem_dataset_source.py +81 -0
  47. mlflow/data/http_dataset_source.py +145 -0
  48. mlflow/data/huggingface_dataset.py +258 -0
  49. mlflow/data/huggingface_dataset_source.py +118 -0
  50. mlflow/data/meta_dataset.py +104 -0
  51. mlflow/data/numpy_dataset.py +223 -0
  52. mlflow/data/pandas_dataset.py +231 -0
  53. mlflow/data/polars_dataset.py +352 -0
  54. mlflow/data/pyfunc_dataset_mixin.py +31 -0
  55. mlflow/data/schema.py +76 -0
  56. mlflow/data/sources.py +1 -0
  57. mlflow/data/spark_dataset.py +406 -0
  58. mlflow/data/spark_dataset_source.py +74 -0
  59. mlflow/data/spark_delta_utils.py +118 -0
  60. mlflow/data/tensorflow_dataset.py +350 -0
  61. mlflow/data/uc_volume_dataset_source.py +81 -0
  62. mlflow/db.py +27 -0
  63. mlflow/dspy/__init__.py +17 -0
  64. mlflow/dspy/autolog.py +197 -0
  65. mlflow/dspy/callback.py +398 -0
  66. mlflow/dspy/constant.py +1 -0
  67. mlflow/dspy/load.py +93 -0
  68. mlflow/dspy/save.py +393 -0
  69. mlflow/dspy/util.py +109 -0
  70. mlflow/dspy/wrapper.py +226 -0
  71. mlflow/entities/__init__.py +104 -0
  72. mlflow/entities/_mlflow_object.py +52 -0
  73. mlflow/entities/assessment.py +545 -0
  74. mlflow/entities/assessment_error.py +80 -0
  75. mlflow/entities/assessment_source.py +141 -0
  76. mlflow/entities/dataset.py +92 -0
  77. mlflow/entities/dataset_input.py +51 -0
  78. mlflow/entities/dataset_summary.py +62 -0
  79. mlflow/entities/document.py +48 -0
  80. mlflow/entities/experiment.py +109 -0
  81. mlflow/entities/experiment_tag.py +35 -0
  82. mlflow/entities/file_info.py +45 -0
  83. mlflow/entities/input_tag.py +35 -0
  84. mlflow/entities/lifecycle_stage.py +35 -0
  85. mlflow/entities/logged_model.py +228 -0
  86. mlflow/entities/logged_model_input.py +26 -0
  87. mlflow/entities/logged_model_output.py +32 -0
  88. mlflow/entities/logged_model_parameter.py +46 -0
  89. mlflow/entities/logged_model_status.py +74 -0
  90. mlflow/entities/logged_model_tag.py +33 -0
  91. mlflow/entities/metric.py +200 -0
  92. mlflow/entities/model_registry/__init__.py +29 -0
  93. mlflow/entities/model_registry/_model_registry_entity.py +13 -0
  94. mlflow/entities/model_registry/model_version.py +243 -0
  95. mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
  96. mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
  97. mlflow/entities/model_registry/model_version_search.py +25 -0
  98. mlflow/entities/model_registry/model_version_stages.py +25 -0
  99. mlflow/entities/model_registry/model_version_status.py +35 -0
  100. mlflow/entities/model_registry/model_version_tag.py +35 -0
  101. mlflow/entities/model_registry/prompt.py +73 -0
  102. mlflow/entities/model_registry/prompt_version.py +244 -0
  103. mlflow/entities/model_registry/registered_model.py +175 -0
  104. mlflow/entities/model_registry/registered_model_alias.py +35 -0
  105. mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
  106. mlflow/entities/model_registry/registered_model_search.py +25 -0
  107. mlflow/entities/model_registry/registered_model_tag.py +35 -0
  108. mlflow/entities/multipart_upload.py +74 -0
  109. mlflow/entities/param.py +49 -0
  110. mlflow/entities/run.py +97 -0
  111. mlflow/entities/run_data.py +84 -0
  112. mlflow/entities/run_info.py +188 -0
  113. mlflow/entities/run_inputs.py +59 -0
  114. mlflow/entities/run_outputs.py +43 -0
  115. mlflow/entities/run_status.py +41 -0
  116. mlflow/entities/run_tag.py +36 -0
  117. mlflow/entities/source_type.py +31 -0
  118. mlflow/entities/span.py +774 -0
  119. mlflow/entities/span_event.py +96 -0
  120. mlflow/entities/span_status.py +102 -0
  121. mlflow/entities/trace.py +317 -0
  122. mlflow/entities/trace_data.py +71 -0
  123. mlflow/entities/trace_info.py +220 -0
  124. mlflow/entities/trace_info_v2.py +162 -0
  125. mlflow/entities/trace_location.py +173 -0
  126. mlflow/entities/trace_state.py +39 -0
  127. mlflow/entities/trace_status.py +68 -0
  128. mlflow/entities/view_type.py +51 -0
  129. mlflow/environment_variables.py +866 -0
  130. mlflow/evaluation/__init__.py +16 -0
  131. mlflow/evaluation/assessment.py +369 -0
  132. mlflow/evaluation/evaluation.py +411 -0
  133. mlflow/evaluation/evaluation_tag.py +61 -0
  134. mlflow/evaluation/fluent.py +48 -0
  135. mlflow/evaluation/utils.py +201 -0
  136. mlflow/exceptions.py +213 -0
  137. mlflow/experiments.py +140 -0
  138. mlflow/gemini/__init__.py +81 -0
  139. mlflow/gemini/autolog.py +186 -0
  140. mlflow/gemini/chat.py +261 -0
  141. mlflow/genai/__init__.py +71 -0
  142. mlflow/genai/datasets/__init__.py +67 -0
  143. mlflow/genai/datasets/evaluation_dataset.py +131 -0
  144. mlflow/genai/evaluation/__init__.py +3 -0
  145. mlflow/genai/evaluation/base.py +411 -0
  146. mlflow/genai/evaluation/constant.py +23 -0
  147. mlflow/genai/evaluation/utils.py +244 -0
  148. mlflow/genai/judges/__init__.py +21 -0
  149. mlflow/genai/judges/databricks.py +404 -0
  150. mlflow/genai/label_schemas/__init__.py +153 -0
  151. mlflow/genai/label_schemas/label_schemas.py +209 -0
  152. mlflow/genai/labeling/__init__.py +159 -0
  153. mlflow/genai/labeling/labeling.py +250 -0
  154. mlflow/genai/optimize/__init__.py +13 -0
  155. mlflow/genai/optimize/base.py +198 -0
  156. mlflow/genai/optimize/optimizers/__init__.py +4 -0
  157. mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
  158. mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
  159. mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
  160. mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
  161. mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
  162. mlflow/genai/optimize/types.py +75 -0
  163. mlflow/genai/optimize/util.py +30 -0
  164. mlflow/genai/prompts/__init__.py +206 -0
  165. mlflow/genai/scheduled_scorers.py +431 -0
  166. mlflow/genai/scorers/__init__.py +26 -0
  167. mlflow/genai/scorers/base.py +492 -0
  168. mlflow/genai/scorers/builtin_scorers.py +765 -0
  169. mlflow/genai/scorers/scorer_utils.py +138 -0
  170. mlflow/genai/scorers/validation.py +165 -0
  171. mlflow/genai/utils/data_validation.py +146 -0
  172. mlflow/genai/utils/enum_utils.py +23 -0
  173. mlflow/genai/utils/trace_utils.py +211 -0
  174. mlflow/groq/__init__.py +42 -0
  175. mlflow/groq/_groq_autolog.py +74 -0
  176. mlflow/johnsnowlabs/__init__.py +888 -0
  177. mlflow/langchain/__init__.py +24 -0
  178. mlflow/langchain/api_request_parallel_processor.py +330 -0
  179. mlflow/langchain/autolog.py +147 -0
  180. mlflow/langchain/chat_agent_langgraph.py +340 -0
  181. mlflow/langchain/constant.py +1 -0
  182. mlflow/langchain/constants.py +1 -0
  183. mlflow/langchain/databricks_dependencies.py +444 -0
  184. mlflow/langchain/langchain_tracer.py +597 -0
  185. mlflow/langchain/model.py +919 -0
  186. mlflow/langchain/output_parsers.py +142 -0
  187. mlflow/langchain/retriever_chain.py +153 -0
  188. mlflow/langchain/runnables.py +527 -0
  189. mlflow/langchain/utils/chat.py +402 -0
  190. mlflow/langchain/utils/logging.py +671 -0
  191. mlflow/langchain/utils/serialization.py +36 -0
  192. mlflow/legacy_databricks_cli/__init__.py +0 -0
  193. mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
  194. mlflow/legacy_databricks_cli/configure/provider.py +482 -0
  195. mlflow/litellm/__init__.py +175 -0
  196. mlflow/llama_index/__init__.py +22 -0
  197. mlflow/llama_index/autolog.py +55 -0
  198. mlflow/llama_index/chat.py +43 -0
  199. mlflow/llama_index/constant.py +1 -0
  200. mlflow/llama_index/model.py +577 -0
  201. mlflow/llama_index/pyfunc_wrapper.py +332 -0
  202. mlflow/llama_index/serialize_objects.py +188 -0
  203. mlflow/llama_index/tracer.py +561 -0
  204. mlflow/metrics/__init__.py +479 -0
  205. mlflow/metrics/base.py +39 -0
  206. mlflow/metrics/genai/__init__.py +25 -0
  207. mlflow/metrics/genai/base.py +101 -0
  208. mlflow/metrics/genai/genai_metric.py +771 -0
  209. mlflow/metrics/genai/metric_definitions.py +450 -0
  210. mlflow/metrics/genai/model_utils.py +371 -0
  211. mlflow/metrics/genai/prompt_template.py +68 -0
  212. mlflow/metrics/genai/prompts/__init__.py +0 -0
  213. mlflow/metrics/genai/prompts/v1.py +422 -0
  214. mlflow/metrics/genai/utils.py +6 -0
  215. mlflow/metrics/metric_definitions.py +619 -0
  216. mlflow/mismatch.py +34 -0
  217. mlflow/mistral/__init__.py +34 -0
  218. mlflow/mistral/autolog.py +71 -0
  219. mlflow/mistral/chat.py +135 -0
  220. mlflow/ml_package_versions.py +452 -0
  221. mlflow/models/__init__.py +97 -0
  222. mlflow/models/auth_policy.py +83 -0
  223. mlflow/models/cli.py +354 -0
  224. mlflow/models/container/__init__.py +294 -0
  225. mlflow/models/container/scoring_server/__init__.py +0 -0
  226. mlflow/models/container/scoring_server/nginx.conf +39 -0
  227. mlflow/models/dependencies_schemas.py +287 -0
  228. mlflow/models/display_utils.py +158 -0
  229. mlflow/models/docker_utils.py +211 -0
  230. mlflow/models/evaluation/__init__.py +23 -0
  231. mlflow/models/evaluation/_shap_patch.py +64 -0
  232. mlflow/models/evaluation/artifacts.py +194 -0
  233. mlflow/models/evaluation/base.py +1811 -0
  234. mlflow/models/evaluation/calibration_curve.py +109 -0
  235. mlflow/models/evaluation/default_evaluator.py +996 -0
  236. mlflow/models/evaluation/deprecated.py +23 -0
  237. mlflow/models/evaluation/evaluator_registry.py +80 -0
  238. mlflow/models/evaluation/evaluators/classifier.py +704 -0
  239. mlflow/models/evaluation/evaluators/default.py +233 -0
  240. mlflow/models/evaluation/evaluators/regressor.py +96 -0
  241. mlflow/models/evaluation/evaluators/shap.py +296 -0
  242. mlflow/models/evaluation/lift_curve.py +178 -0
  243. mlflow/models/evaluation/utils/metric.py +123 -0
  244. mlflow/models/evaluation/utils/trace.py +179 -0
  245. mlflow/models/evaluation/validation.py +434 -0
  246. mlflow/models/flavor_backend.py +93 -0
  247. mlflow/models/flavor_backend_registry.py +53 -0
  248. mlflow/models/model.py +1639 -0
  249. mlflow/models/model_config.py +150 -0
  250. mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
  251. mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
  252. mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
  253. mlflow/models/python_api.py +369 -0
  254. mlflow/models/rag_signatures.py +128 -0
  255. mlflow/models/resources.py +321 -0
  256. mlflow/models/signature.py +662 -0
  257. mlflow/models/utils.py +2054 -0
  258. mlflow/models/wheeled_model.py +280 -0
  259. mlflow/openai/__init__.py +57 -0
  260. mlflow/openai/_agent_tracer.py +364 -0
  261. mlflow/openai/api_request_parallel_processor.py +131 -0
  262. mlflow/openai/autolog.py +509 -0
  263. mlflow/openai/constant.py +1 -0
  264. mlflow/openai/model.py +824 -0
  265. mlflow/openai/utils/chat_schema.py +367 -0
  266. mlflow/optuna/__init__.py +3 -0
  267. mlflow/optuna/storage.py +646 -0
  268. mlflow/plugins/__init__.py +72 -0
  269. mlflow/plugins/base.py +358 -0
  270. mlflow/plugins/builtin/__init__.py +24 -0
  271. mlflow/plugins/builtin/pytorch_plugin.py +150 -0
  272. mlflow/plugins/builtin/sklearn_plugin.py +158 -0
  273. mlflow/plugins/builtin/transformers_plugin.py +187 -0
  274. mlflow/plugins/cli.py +321 -0
  275. mlflow/plugins/discovery.py +340 -0
  276. mlflow/plugins/manager.py +465 -0
  277. mlflow/plugins/registry.py +316 -0
  278. mlflow/plugins/templates/framework_plugin_template.py +329 -0
  279. mlflow/prompt/constants.py +20 -0
  280. mlflow/prompt/promptlab_model.py +197 -0
  281. mlflow/prompt/registry_utils.py +248 -0
  282. mlflow/promptflow/__init__.py +495 -0
  283. mlflow/protos/__init__.py +0 -0
  284. mlflow/protos/assessments_pb2.py +174 -0
  285. mlflow/protos/databricks_artifacts_pb2.py +489 -0
  286. mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
  287. mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
  288. mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
  289. mlflow/protos/databricks_pb2.py +267 -0
  290. mlflow/protos/databricks_trace_server_pb2.py +374 -0
  291. mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
  292. mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
  293. mlflow/protos/facet_feature_statistics_pb2.py +296 -0
  294. mlflow/protos/internal_pb2.py +77 -0
  295. mlflow/protos/mlflow_artifacts_pb2.py +336 -0
  296. mlflow/protos/model_registry_pb2.py +1073 -0
  297. mlflow/protos/scalapb/__init__.py +0 -0
  298. mlflow/protos/scalapb/scalapb_pb2.py +104 -0
  299. mlflow/protos/service_pb2.py +2600 -0
  300. mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
  301. mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
  302. mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
  303. mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
  304. mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
  305. mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
  306. mlflow/py.typed +0 -0
  307. mlflow/pydantic_ai/__init__.py +57 -0
  308. mlflow/pydantic_ai/autolog.py +173 -0
  309. mlflow/pyfunc/__init__.py +3844 -0
  310. mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
  311. mlflow/pyfunc/backend.py +523 -0
  312. mlflow/pyfunc/context.py +78 -0
  313. mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
  314. mlflow/pyfunc/loaders/__init__.py +7 -0
  315. mlflow/pyfunc/loaders/chat_agent.py +117 -0
  316. mlflow/pyfunc/loaders/chat_model.py +125 -0
  317. mlflow/pyfunc/loaders/code_model.py +31 -0
  318. mlflow/pyfunc/loaders/responses_agent.py +112 -0
  319. mlflow/pyfunc/mlserver.py +46 -0
  320. mlflow/pyfunc/model.py +1473 -0
  321. mlflow/pyfunc/scoring_server/__init__.py +604 -0
  322. mlflow/pyfunc/scoring_server/app.py +7 -0
  323. mlflow/pyfunc/scoring_server/client.py +146 -0
  324. mlflow/pyfunc/spark_model_cache.py +48 -0
  325. mlflow/pyfunc/stdin_server.py +44 -0
  326. mlflow/pyfunc/utils/__init__.py +3 -0
  327. mlflow/pyfunc/utils/data_validation.py +224 -0
  328. mlflow/pyfunc/utils/environment.py +22 -0
  329. mlflow/pyfunc/utils/input_converter.py +47 -0
  330. mlflow/pyfunc/utils/serving_data_parser.py +11 -0
  331. mlflow/pytorch/__init__.py +1171 -0
  332. mlflow/pytorch/_lightning_autolog.py +580 -0
  333. mlflow/pytorch/_pytorch_autolog.py +50 -0
  334. mlflow/pytorch/pickle_module.py +35 -0
  335. mlflow/rfunc/__init__.py +42 -0
  336. mlflow/rfunc/backend.py +134 -0
  337. mlflow/runs.py +89 -0
  338. mlflow/server/__init__.py +302 -0
  339. mlflow/server/auth/__init__.py +1224 -0
  340. mlflow/server/auth/__main__.py +4 -0
  341. mlflow/server/auth/basic_auth.ini +6 -0
  342. mlflow/server/auth/cli.py +11 -0
  343. mlflow/server/auth/client.py +537 -0
  344. mlflow/server/auth/config.py +34 -0
  345. mlflow/server/auth/db/__init__.py +0 -0
  346. mlflow/server/auth/db/cli.py +18 -0
  347. mlflow/server/auth/db/migrations/__init__.py +0 -0
  348. mlflow/server/auth/db/migrations/alembic.ini +110 -0
  349. mlflow/server/auth/db/migrations/env.py +76 -0
  350. mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
  351. mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
  352. mlflow/server/auth/db/models.py +67 -0
  353. mlflow/server/auth/db/utils.py +37 -0
  354. mlflow/server/auth/entities.py +165 -0
  355. mlflow/server/auth/logo.py +14 -0
  356. mlflow/server/auth/permissions.py +65 -0
  357. mlflow/server/auth/routes.py +18 -0
  358. mlflow/server/auth/sqlalchemy_store.py +263 -0
  359. mlflow/server/graphql/__init__.py +0 -0
  360. mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
  361. mlflow/server/graphql/graphql_custom_scalars.py +24 -0
  362. mlflow/server/graphql/graphql_errors.py +15 -0
  363. mlflow/server/graphql/graphql_no_batching.py +89 -0
  364. mlflow/server/graphql/graphql_schema_extensions.py +74 -0
  365. mlflow/server/handlers.py +3217 -0
  366. mlflow/server/prometheus_exporter.py +17 -0
  367. mlflow/server/validation.py +30 -0
  368. mlflow/shap/__init__.py +691 -0
  369. mlflow/sklearn/__init__.py +1994 -0
  370. mlflow/sklearn/utils.py +1041 -0
  371. mlflow/smolagents/__init__.py +66 -0
  372. mlflow/smolagents/autolog.py +139 -0
  373. mlflow/smolagents/chat.py +29 -0
  374. mlflow/store/__init__.py +10 -0
  375. mlflow/store/_unity_catalog/__init__.py +1 -0
  376. mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
  377. mlflow/store/_unity_catalog/lineage/constants.py +2 -0
  378. mlflow/store/_unity_catalog/registry/__init__.py +6 -0
  379. mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
  380. mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
  381. mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
  382. mlflow/store/_unity_catalog/registry/utils.py +121 -0
  383. mlflow/store/artifact/__init__.py +0 -0
  384. mlflow/store/artifact/artifact_repo.py +472 -0
  385. mlflow/store/artifact/artifact_repository_registry.py +154 -0
  386. mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
  387. mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
  388. mlflow/store/artifact/cli.py +141 -0
  389. mlflow/store/artifact/cloud_artifact_repo.py +332 -0
  390. mlflow/store/artifact/databricks_artifact_repo.py +729 -0
  391. mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
  392. mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
  393. mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
  394. mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
  395. mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
  396. mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
  397. mlflow/store/artifact/ftp_artifact_repo.py +132 -0
  398. mlflow/store/artifact/gcs_artifact_repo.py +296 -0
  399. mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
  400. mlflow/store/artifact/http_artifact_repo.py +218 -0
  401. mlflow/store/artifact/local_artifact_repo.py +142 -0
  402. mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
  403. mlflow/store/artifact/models_artifact_repo.py +259 -0
  404. mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
  405. mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
  406. mlflow/store/artifact/r2_artifact_repo.py +70 -0
  407. mlflow/store/artifact/runs_artifact_repo.py +265 -0
  408. mlflow/store/artifact/s3_artifact_repo.py +330 -0
  409. mlflow/store/artifact/sftp_artifact_repo.py +141 -0
  410. mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
  411. mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
  412. mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
  413. mlflow/store/artifact/utils/__init__.py +0 -0
  414. mlflow/store/artifact/utils/models.py +148 -0
  415. mlflow/store/db/__init__.py +0 -0
  416. mlflow/store/db/base_sql_model.py +3 -0
  417. mlflow/store/db/db_types.py +10 -0
  418. mlflow/store/db/utils.py +314 -0
  419. mlflow/store/db_migrations/__init__.py +0 -0
  420. mlflow/store/db_migrations/alembic.ini +74 -0
  421. mlflow/store/db_migrations/env.py +84 -0
  422. mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
  423. mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
  424. mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
  425. mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
  426. mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
  427. mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
  428. mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
  429. mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
  430. mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
  431. mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
  432. mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
  433. mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
  434. mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
  435. mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
  436. mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
  437. mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
  438. mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
  439. mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
  440. mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
  441. mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
  442. mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
  443. mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
  444. mlflow/store/db_migrations/versions/__init__.py +0 -0
  445. mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
  446. mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
  447. mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
  448. mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
  449. mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
  450. mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
  451. mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
  452. mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
  453. mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
  454. mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
  455. mlflow/store/entities/__init__.py +3 -0
  456. mlflow/store/entities/paged_list.py +18 -0
  457. mlflow/store/model_registry/__init__.py +10 -0
  458. mlflow/store/model_registry/abstract_store.py +1081 -0
  459. mlflow/store/model_registry/base_rest_store.py +44 -0
  460. mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
  461. mlflow/store/model_registry/dbmodels/__init__.py +0 -0
  462. mlflow/store/model_registry/dbmodels/models.py +206 -0
  463. mlflow/store/model_registry/file_store.py +1091 -0
  464. mlflow/store/model_registry/rest_store.py +481 -0
  465. mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
  466. mlflow/store/tracking/__init__.py +23 -0
  467. mlflow/store/tracking/abstract_store.py +816 -0
  468. mlflow/store/tracking/dbmodels/__init__.py +0 -0
  469. mlflow/store/tracking/dbmodels/initial_models.py +243 -0
  470. mlflow/store/tracking/dbmodels/models.py +1073 -0
  471. mlflow/store/tracking/file_store.py +2438 -0
  472. mlflow/store/tracking/postgres_managed_identity.py +146 -0
  473. mlflow/store/tracking/rest_store.py +1131 -0
  474. mlflow/store/tracking/sqlalchemy_store.py +2785 -0
  475. mlflow/system_metrics/__init__.py +61 -0
  476. mlflow/system_metrics/metrics/__init__.py +0 -0
  477. mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
  478. mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
  479. mlflow/system_metrics/metrics/disk_monitor.py +21 -0
  480. mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
  481. mlflow/system_metrics/metrics/network_monitor.py +34 -0
  482. mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
  483. mlflow/system_metrics/system_metrics_monitor.py +198 -0
  484. mlflow/tracing/__init__.py +16 -0
  485. mlflow/tracing/assessment.py +356 -0
  486. mlflow/tracing/client.py +531 -0
  487. mlflow/tracing/config.py +125 -0
  488. mlflow/tracing/constant.py +105 -0
  489. mlflow/tracing/destination.py +81 -0
  490. mlflow/tracing/display/__init__.py +40 -0
  491. mlflow/tracing/display/display_handler.py +196 -0
  492. mlflow/tracing/export/async_export_queue.py +186 -0
  493. mlflow/tracing/export/inference_table.py +138 -0
  494. mlflow/tracing/export/mlflow_v3.py +137 -0
  495. mlflow/tracing/export/utils.py +70 -0
  496. mlflow/tracing/fluent.py +1417 -0
  497. mlflow/tracing/processor/base_mlflow.py +199 -0
  498. mlflow/tracing/processor/inference_table.py +175 -0
  499. mlflow/tracing/processor/mlflow_v3.py +47 -0
  500. mlflow/tracing/processor/otel.py +73 -0
  501. mlflow/tracing/provider.py +487 -0
  502. mlflow/tracing/trace_manager.py +200 -0
  503. mlflow/tracing/utils/__init__.py +616 -0
  504. mlflow/tracing/utils/artifact_utils.py +28 -0
  505. mlflow/tracing/utils/copy.py +55 -0
  506. mlflow/tracing/utils/environment.py +55 -0
  507. mlflow/tracing/utils/exception.py +21 -0
  508. mlflow/tracing/utils/once.py +35 -0
  509. mlflow/tracing/utils/otlp.py +63 -0
  510. mlflow/tracing/utils/processor.py +54 -0
  511. mlflow/tracing/utils/search.py +292 -0
  512. mlflow/tracing/utils/timeout.py +250 -0
  513. mlflow/tracing/utils/token.py +19 -0
  514. mlflow/tracing/utils/truncation.py +124 -0
  515. mlflow/tracing/utils/warning.py +76 -0
  516. mlflow/tracking/__init__.py +39 -0
  517. mlflow/tracking/_model_registry/__init__.py +1 -0
  518. mlflow/tracking/_model_registry/client.py +764 -0
  519. mlflow/tracking/_model_registry/fluent.py +853 -0
  520. mlflow/tracking/_model_registry/registry.py +67 -0
  521. mlflow/tracking/_model_registry/utils.py +251 -0
  522. mlflow/tracking/_tracking_service/__init__.py +0 -0
  523. mlflow/tracking/_tracking_service/client.py +883 -0
  524. mlflow/tracking/_tracking_service/registry.py +56 -0
  525. mlflow/tracking/_tracking_service/utils.py +275 -0
  526. mlflow/tracking/artifact_utils.py +179 -0
  527. mlflow/tracking/client.py +5900 -0
  528. mlflow/tracking/context/__init__.py +0 -0
  529. mlflow/tracking/context/abstract_context.py +35 -0
  530. mlflow/tracking/context/databricks_cluster_context.py +15 -0
  531. mlflow/tracking/context/databricks_command_context.py +15 -0
  532. mlflow/tracking/context/databricks_job_context.py +49 -0
  533. mlflow/tracking/context/databricks_notebook_context.py +41 -0
  534. mlflow/tracking/context/databricks_repo_context.py +43 -0
  535. mlflow/tracking/context/default_context.py +51 -0
  536. mlflow/tracking/context/git_context.py +32 -0
  537. mlflow/tracking/context/registry.py +98 -0
  538. mlflow/tracking/context/system_environment_context.py +15 -0
  539. mlflow/tracking/default_experiment/__init__.py +1 -0
  540. mlflow/tracking/default_experiment/abstract_context.py +43 -0
  541. mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
  542. mlflow/tracking/default_experiment/registry.py +75 -0
  543. mlflow/tracking/fluent.py +3595 -0
  544. mlflow/tracking/metric_value_conversion_utils.py +93 -0
  545. mlflow/tracking/multimedia.py +206 -0
  546. mlflow/tracking/registry.py +86 -0
  547. mlflow/tracking/request_auth/__init__.py +0 -0
  548. mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
  549. mlflow/tracking/request_auth/registry.py +60 -0
  550. mlflow/tracking/request_header/__init__.py +0 -0
  551. mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
  552. mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
  553. mlflow/tracking/request_header/default_request_header_provider.py +17 -0
  554. mlflow/tracking/request_header/registry.py +79 -0
  555. mlflow/transformers/__init__.py +2982 -0
  556. mlflow/transformers/flavor_config.py +258 -0
  557. mlflow/transformers/hub_utils.py +83 -0
  558. mlflow/transformers/llm_inference_utils.py +468 -0
  559. mlflow/transformers/model_io.py +301 -0
  560. mlflow/transformers/peft.py +51 -0
  561. mlflow/transformers/signature.py +183 -0
  562. mlflow/transformers/torch_utils.py +55 -0
  563. mlflow/types/__init__.py +21 -0
  564. mlflow/types/agent.py +270 -0
  565. mlflow/types/chat.py +240 -0
  566. mlflow/types/llm.py +935 -0
  567. mlflow/types/responses.py +139 -0
  568. mlflow/types/responses_helpers.py +416 -0
  569. mlflow/types/schema.py +1505 -0
  570. mlflow/types/type_hints.py +647 -0
  571. mlflow/types/utils.py +753 -0
  572. mlflow/utils/__init__.py +283 -0
  573. mlflow/utils/_capture_modules.py +256 -0
  574. mlflow/utils/_capture_transformers_modules.py +75 -0
  575. mlflow/utils/_spark_utils.py +201 -0
  576. mlflow/utils/_unity_catalog_oss_utils.py +97 -0
  577. mlflow/utils/_unity_catalog_utils.py +479 -0
  578. mlflow/utils/annotations.py +218 -0
  579. mlflow/utils/arguments_utils.py +16 -0
  580. mlflow/utils/async_logging/__init__.py +1 -0
  581. mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
  582. mlflow/utils/async_logging/async_logging_queue.py +366 -0
  583. mlflow/utils/async_logging/run_artifact.py +38 -0
  584. mlflow/utils/async_logging/run_batch.py +58 -0
  585. mlflow/utils/async_logging/run_operations.py +49 -0
  586. mlflow/utils/autologging_utils/__init__.py +737 -0
  587. mlflow/utils/autologging_utils/client.py +432 -0
  588. mlflow/utils/autologging_utils/config.py +33 -0
  589. mlflow/utils/autologging_utils/events.py +294 -0
  590. mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
  591. mlflow/utils/autologging_utils/metrics_queue.py +71 -0
  592. mlflow/utils/autologging_utils/safety.py +1104 -0
  593. mlflow/utils/autologging_utils/versioning.py +95 -0
  594. mlflow/utils/checkpoint_utils.py +206 -0
  595. mlflow/utils/class_utils.py +6 -0
  596. mlflow/utils/cli_args.py +257 -0
  597. mlflow/utils/conda.py +354 -0
  598. mlflow/utils/credentials.py +231 -0
  599. mlflow/utils/data_utils.py +17 -0
  600. mlflow/utils/databricks_utils.py +1436 -0
  601. mlflow/utils/docstring_utils.py +477 -0
  602. mlflow/utils/doctor.py +133 -0
  603. mlflow/utils/download_cloud_file_chunk.py +43 -0
  604. mlflow/utils/env_manager.py +16 -0
  605. mlflow/utils/env_pack.py +131 -0
  606. mlflow/utils/environment.py +1009 -0
  607. mlflow/utils/exception_utils.py +14 -0
  608. mlflow/utils/file_utils.py +978 -0
  609. mlflow/utils/git_utils.py +77 -0
  610. mlflow/utils/gorilla.py +797 -0
  611. mlflow/utils/import_hooks/__init__.py +363 -0
  612. mlflow/utils/lazy_load.py +51 -0
  613. mlflow/utils/logging_utils.py +168 -0
  614. mlflow/utils/mime_type_utils.py +58 -0
  615. mlflow/utils/mlflow_tags.py +103 -0
  616. mlflow/utils/model_utils.py +486 -0
  617. mlflow/utils/name_utils.py +346 -0
  618. mlflow/utils/nfs_on_spark.py +62 -0
  619. mlflow/utils/openai_utils.py +164 -0
  620. mlflow/utils/os.py +12 -0
  621. mlflow/utils/oss_registry_utils.py +29 -0
  622. mlflow/utils/plugins.py +17 -0
  623. mlflow/utils/process.py +182 -0
  624. mlflow/utils/promptlab_utils.py +146 -0
  625. mlflow/utils/proto_json_utils.py +743 -0
  626. mlflow/utils/pydantic_utils.py +54 -0
  627. mlflow/utils/request_utils.py +279 -0
  628. mlflow/utils/requirements_utils.py +704 -0
  629. mlflow/utils/rest_utils.py +673 -0
  630. mlflow/utils/search_logged_model_utils.py +127 -0
  631. mlflow/utils/search_utils.py +2111 -0
  632. mlflow/utils/secure_loading.py +221 -0
  633. mlflow/utils/security_validation.py +384 -0
  634. mlflow/utils/server_cli_utils.py +61 -0
  635. mlflow/utils/spark_utils.py +15 -0
  636. mlflow/utils/string_utils.py +138 -0
  637. mlflow/utils/thread_utils.py +63 -0
  638. mlflow/utils/time.py +54 -0
  639. mlflow/utils/timeout.py +42 -0
  640. mlflow/utils/uri.py +572 -0
  641. mlflow/utils/validation.py +662 -0
  642. mlflow/utils/virtualenv.py +458 -0
  643. mlflow/utils/warnings_utils.py +25 -0
  644. mlflow/utils/yaml_utils.py +179 -0
  645. mlflow/version.py +24 -0
@@ -0,0 +1,1041 @@
1
+ import collections
2
+ import inspect
3
+ import logging
4
+ import pkgutil
5
+ import platform
6
+ import warnings
7
+ from copy import deepcopy
8
+ from importlib import import_module
9
+ from numbers import Number
10
+ from operator import itemgetter
11
+
12
+ import numpy as np
13
+ from packaging.version import Version
14
+
15
+ from mlflow import MlflowClient
16
+ from mlflow.entities.dataset_input import DatasetInput
17
+ from mlflow.entities.input_tag import InputTag
18
+ from mlflow.tracking.fluent import MLFLOW_DATASET_CONTEXT
19
+ from mlflow.utils.arguments_utils import _get_arg_names
20
+ from mlflow.utils.file_utils import TempDir
21
+ from mlflow.utils.mlflow_tags import MLFLOW_PARENT_RUN_ID
22
+ from mlflow.utils.time import get_current_time_millis
23
+
24
+ _logger = logging.getLogger(__name__)
25
+
26
+ # The prefix to note that all calculated metrics and artifacts are solely based on training datasets
27
+ _TRAINING_PREFIX = "training_"
28
+
29
+ _SAMPLE_WEIGHT = "sample_weight"
30
+
31
+ # _SklearnArtifact represents a artifact (e.g confusion matrix) that will be computed and
32
+ # logged during the autologging routine for a particular model type (eg, classifier, regressor).
33
+ _SklearnArtifact = collections.namedtuple(
34
+ "_SklearnArtifact", ["name", "function", "arguments", "title"]
35
+ )
36
+
37
+ # _SklearnMetric represents a metric (e.g, precision_score) that will be computed and
38
+ # logged during the autologging routine for a particular model type (eg, classifier, regressor).
39
+ _SklearnMetric = collections.namedtuple("_SklearnMetric", ["name", "function", "arguments"])
40
+
41
+
42
def _gen_xgboost_sklearn_estimators_to_patch():
    """
    Returns:
        A list of the classes found in ``xgboost.sklearn`` (via ``inspect.getmembers``)
        that are strict subclasses of ``xgboost.sklearn.XGBModel``.
    """
    import xgboost as xgb

    members = inspect.getmembers(xgb.sklearn, inspect.isclass)
    xgb_base = xgb.sklearn.XGBModel
    return [
        candidate
        for _, candidate in members
        if issubclass(candidate, xgb_base) and candidate != xgb_base
    ]
53
+
54
+
55
def _gen_lightgbm_sklearn_estimators_to_patch():
    """
    Returns:
        A list of the classes found in ``lightgbm.sklearn`` that are strict subclasses of
        ``lightgbm.sklearn._LGBMModelBase`` and whose defining top-level package matches the
        MLflow LightGBM flavor name.
    """
    import lightgbm as lgb

    import mlflow.lightgbm

    members = inspect.getmembers(lgb.sklearn, inspect.isclass)
    lgb_base = lgb.sklearn._LGBMModelBase

    def _is_patch_target(candidate):
        # Keep only classes defined under the flavor's top-level package (presumably
        # "lightgbm"), so classes merely imported into `lightgbm.sklearn` are skipped.
        top_level_package = candidate.__module__.split(".")[0]
        return (
            top_level_package == mlflow.lightgbm.FLAVOR_NAME
            and issubclass(candidate, lgb_base)
            and candidate != lgb_base
        )

    return [candidate for _, candidate in members if _is_patch_target(candidate)]
73
+
74
+
75
+ def _get_estimator_info_tags(estimator):
76
+ """
77
+ Returns:
78
+ A dictionary of MLflow run tag keys and values describing the specified estimator.
79
+ """
80
+ return {
81
+ "estimator_name": estimator.__class__.__name__,
82
+ "estimator_class": (estimator.__class__.__module__ + "." + estimator.__class__.__name__),
83
+ }
84
+
85
+
86
def _get_X_y_and_sample_weight(fit_func, fit_args, fit_kwargs):
    """
    Extract ``(X, y, sample_weight)`` from the arguments of a ``fit`` call.

    The first two parameter names of ``fit_func`` are treated as the names of the feature
    and label arguments. These are usually ``X`` and ``y``, but certain scikit-learn models
    use different names; see, e.g.,
    https://scikit-learn.org/stable/modules/generated/sklearn.multioutput.MultiOutputClassifier.html#sklearn.multioutput.MultiOutputClassifier.fit

    Resolution rules:
      * ``model.fit(X, y, ...)``: both are taken positionally.
      * ``model.fit(X, <y_name>=y)``: ``X`` is positional; ``y`` comes from the keyword
        arguments, or is ``None`` when absent.
      * ``model.fit(<X_name>=X, <y_name>=y)``: both come from the keyword arguments;
        ``y`` is ``None`` when absent.
      * ``sample_weight`` is looked up only if ``fit_func`` declares a parameter with that
        name. It is taken positionally when enough positional arguments were given,
        otherwise from the keyword arguments, or is ``None``.

    Args:
        fit_func: A fit function object whose signature describes the arguments.
        fit_args: Positional arguments given to fit_func.
        fit_kwargs: Keyword arguments given to fit_func.

    Returns:
        A 3-tuple ``(X, y, sample_weight)``, where ``y`` and ``sample_weight`` may be
        ``None``. Non-``None`` ``X`` and ``y`` are deep copies, so the returned data is
        independent of the objects passed to ``fit``.

    Raises:
        KeyError: If no positional arguments were given and the feature argument is missing
            from ``fit_kwargs``.
    """
    arg_names = _get_arg_names(fit_func)
    x_name, y_name = arg_names[:2]

    if len(fit_args) >= 2:
        # model.fit(X, y, ...)
        X, y = fit_args[:2]
    elif len(fit_args) == 1:
        # model.fit(X, <y_name>=y)
        X, y = fit_args[0], fit_kwargs.get(y_name)
    else:
        # model.fit(<x_name>=X, <y_name>=y)
        X, y = fit_kwargs[x_name], fit_kwargs.get(y_name)

    X = None if X is None else deepcopy(X)
    y = None if y is None else deepcopy(y)

    sample_weight = None
    if _SAMPLE_WEIGHT in arg_names:
        weight_position = arg_names.index(_SAMPLE_WEIGHT)
        if len(fit_args) > weight_position:
            # model.fit(X, y, ..., sample_weight)
            sample_weight = fit_args[weight_position]
        else:
            # model.fit(X, y, ..., sample_weight=sample_weight), or not supplied at all
            sample_weight = fit_kwargs.get(_SAMPLE_WEIGHT)

    return X, y, sample_weight
149
+
150
+
151
+ def _get_metrics_value_dict(metrics_list):
152
+ metric_value_dict = {}
153
+ for metric in metrics_list:
154
+ try:
155
+ metric_value = metric.function(**metric.arguments)
156
+ except Exception as e:
157
+ _log_warning_for_metrics(metric.name, metric.function, e)
158
+ else:
159
+ metric_value_dict[metric.name] = metric_value
160
+ return metric_value_dict
161
+
162
+
163
def _get_classifier_metrics(fitted_estimator, prefix, X, y_true, sample_weight, pos_label):
    """
    Compute common classification metrics for a fitted classifier.

    Always attempted:
      (1) precision_score, (2) recall_score, (3) f1_score:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.recall_score.html
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html
          ``average`` is ``"weighted"`` when ``pos_label`` is ``None`` and ``"binary"``
          otherwise.
      (4) accuracy_score with ``normalize=True``, i.e. the fraction of correctly predicted
          samples rather than their count:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html

    Additionally, when the estimator has ``predict_proba``:
      (5) log_loss:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.log_loss.html
      (6) roc_auc_score, only on scikit-learn versions that support it, with
          ``average="weighted"`` and ``multi_class="ovo"``. For binary problems the
          probability of the class with the greater label (column 1) is used as the score:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html

    Each metric is computed independently, and one that raises is skipped with a warning.
    If ``predict_proba`` itself raises, metrics (5)-(6) are skipped with a warning and
    metrics (1)-(4) are still returned.

    Args:
        fitted_estimator: The already fitted classifier.
        prefix: String prepended to every metric name (e.g. ``"training_"``).
        X: Samples passed to ``predict`` / ``predict_proba``.
        y_true: Ground-truth labels for ``X``.
        sample_weight: Optional per-sample weights forwarded to every metric.
        pos_label: Positive class for binary metrics, or ``None`` for weighted averaging.

    Returns:
        A dict mapping ``prefix + <metric name>`` to the computed value.
    """
    import sklearn

    average = "weighted" if pos_label is None else "binary"
    y_pred = fitted_estimator.predict(X)

    # precision / recall / f1 take the same argument set.
    averaged_scores = [
        ("precision_score", sklearn.metrics.precision_score),
        ("recall_score", sklearn.metrics.recall_score),
        ("f1_score", sklearn.metrics.f1_score),
    ]
    classifier_metrics = [
        _SklearnMetric(
            name=prefix + metric_name,
            function=metric_function,
            arguments={
                "y_true": y_true,
                "y_pred": y_pred,
                "pos_label": pos_label,
                "average": average,
                "sample_weight": sample_weight,
            },
        )
        for metric_name, metric_function in averaged_scores
    ]
    classifier_metrics.append(
        _SklearnMetric(
            name=prefix + "accuracy_score",
            function=sklearn.metrics.accuracy_score,
            arguments={
                "y_true": y_true,
                "y_pred": y_pred,
                "normalize": True,
                "sample_weight": sample_weight,
            },
        )
    )

    if hasattr(fitted_estimator, "predict_proba"):
        try:
            y_pred_proba = fitted_estimator.predict_proba(X)
        except Exception as e:
            # Probability-based metrics are optional: a failing `predict_proba` must not
            # discard the label-based metrics collected above.
            _logger.warning(
                "%s.predict_proba failed. The metrics %s and %s will not be recorded. "
                "Error: %s",
                fitted_estimator.__class__.__name__,
                prefix + "log_loss",
                prefix + "roc_auc",
                e,
            )
        else:
            classifier_metrics.append(
                _SklearnMetric(
                    name=prefix + "log_loss",
                    function=sklearn.metrics.log_loss,
                    arguments={
                        "y_true": y_true,
                        "y_pred": y_pred_proba,
                        "sample_weight": sample_weight,
                    },
                )
            )

            if _is_metric_supported("roc_auc_score"):
                y_score = y_pred_proba
                # For the binary case `roc_auc_score` expects the scores of the class with
                # the greater label, i.e. the second probability column. Only slice genuine
                # 2-D arrays; other outputs (e.g. per-output lists) are passed through and
                # any resulting failure is reported by the per-metric error handling.
                if getattr(y_pred_proba, "ndim", None) == 2 and y_pred_proba.shape[1] == 2:
                    y_score = y_pred_proba[:, 1]

                classifier_metrics.append(
                    _SklearnMetric(
                        name=prefix + "roc_auc",
                        function=sklearn.metrics.roc_auc_score,
                        arguments={
                            "y_true": y_true,
                            "y_score": y_score,
                            "average": "weighted",
                            "sample_weight": sample_weight,
                            "multi_class": "ovo",
                        },
                    )
                )

    return _get_metrics_value_dict(classifier_metrics)
295
+
296
+
297
+ def _get_class_labels_from_estimator(estimator):
298
+ """
299
+ Extracts class labels from `estimator` if `estimator.classes` is available.
300
+ """
301
+ return estimator.classes_ if hasattr(estimator, "classes_") else None
302
+
303
+
304
def _get_classifier_artifacts(fitted_estimator, prefix, X, y_true, sample_weight):
    """
    Build the list of plot artifacts to log for a fitted classifier.

    For every classifier:
      (1) a normalized confusion matrix
          (``ConfusionMatrixDisplay.from_estimator`` on scikit-learn >= 1.0, otherwise
          ``plot_confusion_matrix``).

    Only when ``y_true`` contains exactly two distinct labels:
      (2) a ROC curve (``RocCurveDisplay.from_estimator`` / ``plot_roc_curve``):
          https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
      (3) a precision-recall curve
          (``PrecisionRecallDisplay.from_estimator`` / ``plot_precision_recall_curve``).

    Nothing is drawn here. Each returned ``_SklearnArtifact`` holds the plotting callable,
    the keyword arguments to call it with, and the plot title.

    Args:
        fitted_estimator: The already fitted classifier.
        prefix: String prepended to every artifact name (e.g. ``"training_"``).
        X: Samples to evaluate the classifier on.
        y_true: Ground-truth labels for ``X``.
        sample_weight: Optional per-sample weights forwarded to every plot.

    Returns:
        A list of ``_SklearnArtifact``. It is empty if the installed scikit-learn does not
        support plotting.
    """
    import sklearn

    if not _is_plotting_supported():
        return []

    # From scikit-learn 1.0 on, use the `*Display.from_estimator` constructors instead of
    # the deprecated `plot_*` functions.
    is_plot_function_deprecated = Version(sklearn.__version__) >= Version("1.0")

    def plot_confusion_matrix(*args, **kwargs):
        import matplotlib
        import matplotlib.pyplot as plt

        class_labels = _get_class_labels_from_estimator(fitted_estimator)
        if class_labels is None:
            class_labels = set(y_true)

        # Shrink the font as the number of classes grows so cell labels stay legible.
        with matplotlib.rc_context(
            {
                "font.size": min(8.0, 50.0 / len(class_labels)),
                "axes.labelsize": 8.0,
                "figure.dpi": 175,
            }
        ):
            fig, ax = plt.subplots(1, 1, figsize=(6.0, 4.0))
            try:
                return (
                    sklearn.metrics.ConfusionMatrixDisplay.from_estimator(*args, **kwargs, ax=ax)
                    if is_plot_function_deprecated
                    else sklearn.metrics.plot_confusion_matrix(*args, **kwargs, ax=ax)
                )
            except Exception:
                # Don't leave the freshly created figure registered with pyplot when
                # plotting fails; the caller only closes figures of successful plots.
                plt.close(fig)
                raise

    # `from_estimator` names the label argument `y`; `plot_confusion_matrix` named it `y_true`.
    y_true_arg_name = "y" if is_plot_function_deprecated else "y_true"
    classifier_artifacts = [
        _SklearnArtifact(
            name=prefix + "confusion_matrix",
            function=plot_confusion_matrix,
            arguments=dict(
                estimator=fitted_estimator,
                X=X,
                sample_weight=sample_weight,
                normalize="true",
                cmap="Blues",
                **{y_true_arg_name: y_true},
            ),
            title="Normalized confusion matrix",
        ),
    ]

    # The ROC and precision-recall curves are only supported for binary classifiers.
    if len(set(y_true)) == 2:
        classifier_artifacts.extend(
            [
                _SklearnArtifact(
                    name=prefix + "roc_curve",
                    function=sklearn.metrics.RocCurveDisplay.from_estimator
                    if is_plot_function_deprecated
                    else sklearn.metrics.plot_roc_curve,
                    arguments={
                        "estimator": fitted_estimator,
                        "X": X,
                        "y": y_true,
                        "sample_weight": sample_weight,
                    },
                    title="ROC curve",
                ),
                _SklearnArtifact(
                    name=prefix + "precision_recall_curve",
                    function=sklearn.metrics.PrecisionRecallDisplay.from_estimator
                    if is_plot_function_deprecated
                    else sklearn.metrics.plot_precision_recall_curve,
                    arguments={
                        "estimator": fitted_estimator,
                        "X": X,
                        "y": y_true,
                        "sample_weight": sample_weight,
                    },
                    title="Precision recall curve",
                ),
            ]
        )

    return classifier_artifacts
414
+
415
+
416
def _get_regressor_metrics(fitted_estimator, prefix, X, y_true, sample_weight):
    """
    Compute common regression metrics for a fitted regressor.

    Metrics, each computed with ``multioutput="uniform_average"`` (outputs averaged with
    uniform weight):
      (1) mean squared error, plus root mean squared error derived from it:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_squared_error.html
      (2) mean absolute error:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_absolute_error.html
      (3) r2 score:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.r2_score.html

    Each metric is computed independently, and one that raises is skipped with a warning.
    The root mean squared error is reported only when the mean squared error succeeded.

    Args:
        fitted_estimator: The already fitted regressor.
        prefix: String prepended to every metric name (e.g. ``"training_"``).
        X: Samples passed to ``predict``.
        y_true: Ground-truth targets for ``X``.
        sample_weight: Optional per-sample weights forwarded to every metric.

    Returns:
        A dict mapping ``prefix + <metric name>`` to the computed value.
    """
    import sklearn

    y_pred = fitted_estimator.predict(X)

    metric_functions = [
        ("mean_squared_error", sklearn.metrics.mean_squared_error),
        ("mean_absolute_error", sklearn.metrics.mean_absolute_error),
        ("r2_score", sklearn.metrics.r2_score),
    ]
    regressor_metrics = [
        _SklearnMetric(
            name=prefix + metric_name,
            function=metric_function,
            arguments={
                "y_true": y_true,
                "y_pred": y_pred,
                "sample_weight": sample_weight,
                "multioutput": "uniform_average",
            },
        )
        for metric_name, metric_function in metric_functions
    ]
    metrics_value_dict = _get_metrics_value_dict(regressor_metrics)

    # RMSE is computed as np.sqrt(<MSE>) rather than through the `squared` parameter of
    # `sklearn.metrics.mean_squared_error`, for compatibility with scikit-learn < 0.22.2.
    # If the MSE failed (already warned about), skip the RMSE instead of raising KeyError,
    # which would lose the metrics that did succeed.
    mse_key = prefix + "mean_squared_error"
    if mse_key in metrics_value_dict:
        metrics_value_dict[prefix + "root_mean_squared_error"] = np.sqrt(
            metrics_value_dict[mse_key]
        )

    return metrics_value_dict
490
+
491
+
492
def _log_warning_for_metrics(func_name, func_call, err):
    """
    Warn that ``func_call`` raised ``err`` while computing the metric ``func_name``, so
    the metric is skipped.
    """
    _logger.warning(
        f"{func_call.__qualname__} failed. The metric {func_name} will not be recorded."
        f" Metric error: {err}"
    )
502
+
503
+
504
def _log_warning_for_artifacts(func_name, func_call, err):
    """
    Warn that ``func_call`` raised ``err`` while producing the artifact ``func_name``, so
    the artifact is skipped.
    """
    _logger.warning(
        f"{func_call.__qualname__} failed. The artifact {func_name} will not be recorded."
        f" Artifact error: {err}"
    )
514
+
515
+
516
def _log_specialized_estimator_content(
    autologging_client,
    fitted_estimator,
    run_id,
    prefix,
    X,
    y_true,
    sample_weight,
    pos_label,
    model_id,
    dataset,
):
    """
    Compute and log the metrics and plot artifacts that depend on the estimator's type.

    * Metrics are computed only when ``y_true`` is given, using the classifier or regressor
      helpers. They are queued on ``autologging_client`` for ``run_id`` along with
      ``model_id`` and ``dataset``.
    * Artifacts are produced only for classifiers with labels available. Each plot is
      rendered to ``<artifact name>.png`` in a temporary directory, which is then uploaded
      to the run via ``MlflowClient().log_artifacts``.

    Failures are reported as warnings instead of being raised. This covers computing the
    metrics, building the artifact list, importing matplotlib, and rendering an individual
    plot.

    Args:
        autologging_client: Client used to queue metric logging.
        fitted_estimator: The already fitted estimator.
        run_id: The run under which content is logged.
        prefix: String prepended to every metric/artifact name.
        X: The data samples.
        y_true: Labels for ``X``, or ``None``.
        sample_weight: Optional per-sample weights.
        pos_label: Positive label for binary classification metrics, or ``None``.
        model_id: Model ID associated with the logged metrics.
        dataset: Dataset associated with the logged metrics.

    Returns:
        A dict of the computed metrics. It is empty if labels are missing or metric
        computation failed.
    """
    import sklearn

    metrics = {}

    if y_true is not None:
        try:
            if sklearn.base.is_classifier(fitted_estimator):
                metrics = _get_classifier_metrics(
                    fitted_estimator, prefix, X, y_true, sample_weight, pos_label
                )
            elif sklearn.base.is_regressor(fitted_estimator):
                metrics = _get_regressor_metrics(fitted_estimator, prefix, X, y_true, sample_weight)
        except Exception as err:
            msg = (
                "Failed to autolog metrics for "
                + fitted_estimator.__class__.__name__
                + ". Logging error: "
                + str(err)
            )
            _logger.warning(msg)
        else:
            autologging_client.log_metrics(
                run_id=run_id,
                metrics=metrics,
                model_id=model_id,
                dataset=dataset,
            )

    # Every classifier plot is drawn against the true labels. Without labels there is
    # nothing to render, and attempting it would only end in a spurious warning.
    if y_true is None or not sklearn.base.is_classifier(fitted_estimator):
        return metrics

    try:
        artifacts = _get_classifier_artifacts(fitted_estimator, prefix, X, y_true, sample_weight)
    except Exception as e:
        msg = (
            "Failed to autolog artifacts for "
            + fitted_estimator.__class__.__name__
            + ". Logging error: "
            + str(e)
        )
        _logger.warning(msg)
        return metrics

    try:
        import matplotlib
        import matplotlib.pyplot as plt
    except ImportError as ie:
        _logger.warning(f"Failed to import matplotlib (error: {ie!r}). Skipping artifact logging.")
        return metrics

    _matplotlib_config = {"savefig.dpi": 175, "figure.autolayout": True, "font.size": 8}
    with TempDir() as tmp_dir:
        for artifact in artifacts:
            try:
                with matplotlib.rc_context(_matplotlib_config):
                    display = artifact.function(**artifact.arguments)
                    try:
                        display.ax_.set_title(artifact.title)
                        filepath = tmp_dir.path(f"{artifact.name}.png")
                        display.figure_.savefig(fname=filepath, format="png")
                    finally:
                        # Close the figure even if titling/saving fails, so pyplot does not
                        # keep accumulating open figures across autologged fits.
                        plt.close(display.figure_)
            except Exception as e:
                _log_warning_for_artifacts(artifact.name, artifact.function, e)

        MlflowClient().log_artifacts(run_id, tmp_dir.path())

    return metrics
597
+
598
+
599
def _is_estimator_html_repr_supported():
    """
    Return whether the installed scikit-learn provides ``estimator_html_repr``, which
    requires version 0.23 or newer.
    """
    import sklearn

    installed_version = Version(sklearn.__version__)
    return installed_version >= Version("0.23.0")
604
+
605
+
606
def _log_estimator_html(run_id, estimator):
    """
    Log the HTML representation of ``estimator`` (rendered by
    ``sklearn.utils.estimator_html_repr``) to ``run_id`` as ``estimator.html``.

    Does nothing when the installed scikit-learn lacks ``estimator_html_repr``.
    """
    if _is_estimator_html_repr_supported():
        from sklearn.utils import estimator_html_repr

        html_body = estimator_html_repr(estimator)
        # The charset is declared explicitly so the triangle toggle buttons are not garbled.
        estimator_html_string = f"""
<!DOCTYPE html>
<html lang="en">
  <head>
    <meta charset="UTF-8"/>
  </head>
  <body>
    {html_body}
  </body>
</html>
    """
        MlflowClient().log_text(run_id, estimator_html_string, artifact_file="estimator.html")
625
+
626
+
627
def _log_estimator_content(
    autologging_client,
    estimator,
    run_id,
    prefix,
    X,
    y_true=None,
    sample_weight=None,
    pos_label=None,
    model_id=None,
    dataset=None,
):
    """
    Logs content for the given estimator, which includes metrics and artifacts that might be
    tailored to the estimator's type (e.g., regression vs classification). Training labels
    are required for metric computation; metrics will be omitted if labels are not available.

    When the estimator has a ``score`` method and labels are given, its result is also
    logged as the ``<prefix>score`` metric. Finally, the estimator's HTML representation is
    logged.

    Args:
        autologging_client: An instance of `MlflowAutologgingQueueingClient` used for
            efficiently logging run data to MLflow Tracking.
        estimator: The estimator used to compute metrics and artifacts.
        run_id: The run under which the content is logged.
        prefix: A prefix used to name the logged content. Typically it's 'training_' for
            training-time content and user-controlled for evaluation-time content.
        X: The data samples.
        y_true: Labels.
        sample_weight: Per-sample weights used in the computation of metrics and artifacts.
        pos_label: The positive label used to compute binary classification metrics such as
            precision, recall, f1, etc. This parameter is only used for classification metrics.
            If set to `None`, the function will calculate metrics for each label and find their
            average weighted by support (number of true instances for each label).
        model_id: Model ID.
        dataset: The dataset used to evaluate the model.

    Returns:
        A dict of the computed metrics.
    """
    metrics = _log_specialized_estimator_content(
        autologging_client=autologging_client,
        fitted_estimator=estimator,
        run_id=run_id,
        prefix=prefix,
        X=X,
        y_true=y_true,
        sample_weight=sample_weight,
        pos_label=pos_label,
        model_id=model_id,
        dataset=dataset,
    )

    if hasattr(estimator, "score") and y_true is not None:
        score_key = prefix + "score"
        try:
            # Forward the sample weights only if `score` declares a `sample_weight`
            # parameter. Pass them by keyword so they still reach `score` when that
            # parameter is keyword-only or not the third positional parameter.
            score_arg_names = _get_arg_names(estimator.score)
            score_kwargs = (
                {_SAMPLE_WEIGHT: sample_weight} if _SAMPLE_WEIGHT in score_arg_names else {}
            )
            score = estimator.score(X, y_true, **score_kwargs)
        except Exception as e:
            # Name the metric actually being skipped; the prefix is not always 'training_'.
            msg = (
                estimator.score.__qualname__
                + f" failed. The '{score_key}' metric will not be recorded. Scoring error: "
                + str(e)
            )
            _logger.warning(msg)
        else:
            autologging_client.log_metrics(
                run_id=run_id,
                metrics={score_key: score},
                model_id=model_id,
                dataset=dataset,
            )
            metrics[score_key] = score
    _log_estimator_html(run_id, estimator)
    return metrics
703
+
704
+
705
def _get_meta_estimators_for_autologging():
    """
    Returns:
        The meta estimator classes whose training functions are also patched for
        autologging: ``GridSearchCV``, ``RandomizedSearchCV`` and ``Pipeline``, in that
        order.
    """
    from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
    from sklearn.pipeline import Pipeline

    return [GridSearchCV, RandomizedSearchCV, Pipeline]
720
+
721
+
722
def _is_parameter_search_estimator(estimator):
    """
    Returns:
        ``True`` if ``estimator`` is an instance of ``GridSearchCV`` or
        ``RandomizedSearchCV`` (including subclasses), ``False`` otherwise.
    """
    from sklearn.model_selection import GridSearchCV, RandomizedSearchCV

    return isinstance(estimator, (GridSearchCV, RandomizedSearchCV))
739
+
740
+
741
def _log_parameter_search_results_as_artifact(cv_results_df, run_id):
    """
    Save parameter search results as ``cv_results.csv`` (without the index) and upload
    the file as an artifact of the given run.

    Args:
        cv_results_df: A Pandas DataFrame with the results of a parameter search, e.g.
            built from the ``cv_results_`` attribute of a trained ``GridSearchCV``.
        run_id: The ID of the MLflow Run that receives the artifact.
    """
    with TempDir() as scratch_dir:
        csv_path = scratch_dir.path("cv_results.csv")
        cv_results_df.to_csv(csv_path, index=False)
        MlflowClient().log_artifact(run_id, csv_path)
757
+
758
+
759
# Log how many child runs will be created vs omitted based on `max_tuning_runs`.
def _log_child_runs_info(max_tuning_runs, total_runs):
    """
    Emit an info log summarizing how many of ``total_runs`` search points get a child run
    (``max_tuning_runs``) and how many are omitted.
    """
    omitted_count = total_runs - max_tuning_runs

    logging_phrase = (
        "no runs"
        if max_tuning_runs == 0
        else "the best run"
        if max_tuning_runs == 1
        else f"the {max_tuning_runs} best runs"
    )
    omitting_phrase = (
        "no runs"
        if omitted_count <= 0
        else "one run"
        if omitted_count == 1
        else f"{omitted_count} runs"
    )

    _logger.info("Logging %s, %s will be omitted.", logging_phrase, omitting_phrase)
780
+
781
+
782
def _create_child_runs_for_parameter_search(
    autologging_client,
    cv_estimator,
    parent_run,
    max_tuning_runs,
    child_tags=None,
    dataset=None,
    best_estimator_params=None,
    best_estimator_model_id=None,
):
    """
    Creates a collection of child runs for a parameter search training session.
    Runs are reconstructed from the `cv_results_` attribute of the specified trained
    parameter search estimator - `cv_estimator`, which provides relevant performance
    metrics for each point in the parameter search space. One child run is created
    for each point in the parameter search space (or for the best `max_tuning_runs`
    points). For additional information, see
    `https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html`_.

    Args:
        autologging_client: An instance of `MlflowAutologgingQueueingClient` used for
            efficiently logging run data to MLflow Tracking.
        cv_estimator: The trained parameter search estimator for which to create
            child runs.
        parent_run: A py:class:`mlflow.entities.Run` object referring to the parent
            parameter search run for which child runs should be created.
        max_tuning_runs: Maximum number of child runs to create, keeping the points with
            the smallest rank. If `None`, a child run is created for every point.
        child_tags: An optional dictionary of MLflow tag keys and values to log
            for each child run.
        dataset: The dataset used to evaluate the model. If `None`, no dataset input is
            logged for the child runs.
        best_estimator_params: The parameters of the best estimator.
        best_estimator_model_id: The model ID of the logged best estimator.
    """
    import pandas as pd

    def first_custom_rank_column(df):
        # Used when there is no `rank_test_score` column: pick the first column whose
        # name contains `rank_test_`.
        for col_name in df.columns.values:
            if "rank_test_" in col_name:
                return col_name
        return None

    # Use the start time of the parent parameter search run as a rough estimate for the
    # start time of child runs, since we cannot precisely determine when each point
    # in the parameter search space was explored
    child_run_start_time = parent_run.info.start_time
    child_run_end_time = get_current_time_millis()

    seed_estimator = cv_estimator.estimator
    # In the unlikely case that a seed of a parameter search estimator is,
    # itself, a parameter search estimator, we should avoid logging the untuned
    # parameters of the seed's seed estimator
    should_log_params_deeply = not _is_parameter_search_estimator(seed_estimator)
    # Each row of `cv_results_` only provides parameters that vary across
    # the user-specified parameter grid. In order to log the complete set
    # of parameters for each child run, we fetch the parameters defined by
    # the seed estimator and update them with parameter subset specified
    # in the result row
    base_params = seed_estimator.get_params(deep=should_log_params_deeply)
    cv_results_df = pd.DataFrame.from_dict(cv_estimator.cv_results_)

    if max_tuning_runs is None:
        cv_results_best_n_df = cv_results_df
    else:
        rank_column_name = "rank_test_score"
        if rank_column_name not in cv_results_df.columns.values:
            rank_column_name = first_custom_rank_column(cv_results_df)
            warnings.warn(
                f"Top {max_tuning_runs} child runs will be created based on ordering in "
                f"{rank_column_name} column. You can choose not to limit the number of "
                "child runs created by setting `max_tuning_runs=None`."
            )
        cv_results_best_n_df = cv_results_df.nsmallest(max_tuning_runs, rank_column_name)
        # Log how many child runs will be created vs omitted.
        _log_child_runs_info(max_tuning_runs, len(cv_results_df))

    # `dataset` is optional (defaults to None); only record it as a run input when given,
    # instead of failing with AttributeError on `None._to_mlflow_entity()`.
    datasets = (
        [
            DatasetInput(
                dataset._to_mlflow_entity(),
                tags=[InputTag(key=MLFLOW_DATASET_CONTEXT, value="train")],
            )
        ]
        if dataset is not None
        else []
    )
    for _, result_row in cv_results_best_n_df.iterrows():
        tags_to_log = dict(child_tags) if child_tags else {}
        tags_to_log.update({MLFLOW_PARENT_RUN_ID: parent_run.info.run_id})
        tags_to_log.update(_get_estimator_info_tags(seed_estimator))
        pending_child_run_id = autologging_client.create_run(
            experiment_id=parent_run.info.experiment_id,
            start_time=child_run_start_time,
            tags=tags_to_log,
        )

        row_params = result_row.get("params", {})
        params_to_log = dict(base_params)
        params_to_log.update(row_params)
        autologging_client.log_params(run_id=pending_child_run_id, params=params_to_log)

        # Parameters values are recorded twice in the set of search `cv_results_`:
        # once within a `params` column with dictionary values and once within
        # a separate dataframe column that is created for each parameter. To prevent
        # duplication of parameters, we log the consolidated values from the parameter
        # dictionary column and filter out the other parameter-specific columns with
        # names of the form `param_{param_name}`. Additionally, `cv_results_` produces
        # metrics for each training split, which is fairly verbose; accordingly, we filter
        # out per-split metrics in favor of aggregate metrics (mean, std, etc.)
        excluded_metric_prefixes = ["param", "split"]
        metrics_to_log = {
            key: value
            for key, value in result_row.items()
            if not any(key.startswith(prefix) for prefix in excluded_metric_prefixes)
            and isinstance(value, Number)
        }
        # Only log metrics to the best_estimator_model when the child run's
        # parameters match the best_estimator's parameters.
        model_id = (
            best_estimator_model_id
            if best_estimator_params and row_params.items() <= best_estimator_params.items()
            else None
        )
        autologging_client.log_metrics(
            run_id=pending_child_run_id,
            metrics=metrics_to_log,
            dataset=dataset,
            model_id=model_id,
        )
        if datasets:
            autologging_client.log_inputs(run_id=pending_child_run_id, datasets=datasets)
        autologging_client.set_terminated(run_id=pending_child_run_id, end_time=child_run_end_time)
905
+
906
+
907
# Util function to check whether a metric is able to be computed in given sklearn version
def _is_metric_supported(metric_name):
    """
    Check whether ``metric_name`` can be computed with the installed scikit-learn.

    Some metrics require a minimum scikit-learn version; those minimums are recorded in
    ``_metric_supported_version`` below. A metric with no recorded minimum version has no
    special version constraint, so it is reported as supported instead of raising
    ``KeyError``.

    Args:
        metric_name: Name of the sklearn metric function, e.g. ``"roc_auc_score"``.

    Returns:
        ``True`` if the installed scikit-learn satisfies the metric's minimum version
        (or the metric has no registered minimum), ``False`` otherwise.
    """
    import sklearn

    # This dict can be extended to store special metrics' specific supported versions
    _metric_supported_version = {"roc_auc_score": "0.22.2"}

    min_version = _metric_supported_version.get(metric_name)
    if min_version is None:
        # No version constraint registered for this metric.
        return True
    return Version(sklearn.__version__) >= Version(min_version)
915
+
916
+
917
def _is_plotting_supported():
    """
    Report whether scikit-learn's artifact-plotting helpers can be used.

    The plotting functions used for autologged artifacts need scikit-learn 0.22.0 or
    newer.

    Returns:
        ``True`` when the installed scikit-learn version is at least 0.22.0.
    """
    import sklearn

    installed_version = Version(sklearn.__version__)
    minimum_version = Version("0.22.0")
    return installed_version >= minimum_version
923
+
924
+
925
def _all_estimators():
    """
    Return scikit-learn's estimator classes as a list of ``(name, class)`` tuples.

    Uses ``sklearn.utils.all_estimators`` when it is available. If importing or calling
    it raises ``ImportError`` (for example on older scikit-learn releases that do not
    provide it), the backported implementation ``_backported_all_estimators`` is used.
    """
    try:
        from sklearn.utils import all_estimators as sklearn_all_estimators

        # Kept inside the ``try`` so an ImportError raised while crawling modules
        # also triggers the fallback.
        estimators = sklearn_all_estimators()
    except ImportError:
        estimators = _backported_all_estimators()
    return estimators
932
+
933
+
934
def _backported_all_estimators(type_filter=None):
    """
    Backported from scikit-learn 0.23.2:
    https://github.com/scikit-learn/scikit-learn/blob/0.23.2/sklearn/utils/__init__.py#L1146

    Use this backported `all_estimators` in old versions of sklearn because:
    1. An inferior version of `all_estimators` that old versions of sklearn use for testing
       might function differently from a newer version.
    2. This backported `all_estimators` works on old versions of sklearn that don't even define
       the testing utility variant of `all_estimators`.

    Deviations from the upstream implementation:
    - FutureWarnings are suppressed with the standard-library ``warnings`` module rather
      than ``sklearn.utils._testing.ignore_warnings``. The private ``_testing`` module is
      not available in the older scikit-learn releases this fallback is meant for, so
      importing it would fail exactly when this function is needed.
    - On PyPy, ``FeatureHasher`` is skipped, as the upstream comment describes. The
      upstream filter instead kept *only* ``FeatureHasher`` from ``feature_extraction``
      modules.

    ========== original docstring ==========
    Get a list of all estimators from sklearn.
    This function crawls the module and gets all classes that inherit
    from BaseEstimator. Classes that are defined in test-modules are not
    included.
    By default meta_estimators such as GridSearchCV are also not included.
    Parameters
    ----------
    type_filter : string, list of string, or None, default=None
        Which kind of estimators should be returned. If None, no filter is
        applied and all estimators are returned. Possible values are
        'classifier', 'regressor', 'cluster' and 'transformer' to get
        estimators only of these specific types, or a list of these to
        get the estimators that fit at least one of the types.

    Returns
    -------
    estimators : list of tuples
        List of (name, class), where ``name`` is the class name as string
        and ``class`` is the actual type of the class.

    Raises
    ------
    ValueError
        If ``type_filter`` contains a value other than 'classifier', 'regressor',
        'transformer' or 'cluster'.
    """
    import warnings

    # lazy import to avoid circular imports from sklearn.base
    import sklearn
    from sklearn.base import (
        BaseEstimator,
        ClassifierMixin,
        ClusterMixin,
        RegressorMixin,
        TransformerMixin,
    )

    IS_PYPY = platform.python_implementation() == "PyPy"

    def is_abstract(c):
        # A class is abstract when it still declares unimplemented abstract methods.
        return bool(getattr(c, "__abstractmethods__", ()))

    all_classes = []
    modules_to_ignore = {"tests", "externals", "setup", "conftest"}
    root = sklearn.__path__[0]  # sklearn package
    # Ignore deprecation warnings triggered at import time and from walking
    # packages. Filters are restored when the ``with`` block exits.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=FutureWarning)
        for _, modname, _ in pkgutil.walk_packages(path=[root], prefix="sklearn."):
            mod_parts = modname.split(".")
            # Skip test/vendored/setup modules and private ("._") submodules.
            if any(part in modules_to_ignore for part in mod_parts) or "._" in modname:
                continue
            module = import_module(modname)
            classes = inspect.getmembers(module, inspect.isclass)
            classes = [(name, est_cls) for name, est_cls in classes if not name.startswith("_")]

            # TODO: Remove when FeatureHasher is implemented in PYPY
            # Skip FeatureHasher on PyPy, where it is not implemented.
            if IS_PYPY and "feature_extraction" in modname:
                classes = [(name, est_cls) for name, est_cls in classes if name != "FeatureHasher"]

            all_classes.extend(classes)

    # The same class is re-exported from several modules; dedupe (name, class) pairs.
    all_classes = set(all_classes)

    estimators = [
        c for c in all_classes if (issubclass(c[1], BaseEstimator) and c[0] != "BaseEstimator")
    ]
    # get rid of abstract base classes
    estimators = [c for c in estimators if not is_abstract(c[1])]

    if type_filter is not None:
        # copy the object if type_filter is a list, so the caller's list is not mutated
        type_filter = list(type_filter) if isinstance(type_filter, list) else [type_filter]
        filtered_estimators = []
        filters = {
            "classifier": ClassifierMixin,
            "regressor": RegressorMixin,
            "transformer": TransformerMixin,
            "cluster": ClusterMixin,
        }
        for name, mixin in filters.items():
            if name in type_filter:
                # Consume recognised names; whatever remains afterwards is invalid.
                type_filter.remove(name)
                filtered_estimators.extend(est for est in estimators if issubclass(est[1], mixin))
        estimators = filtered_estimators
        if type_filter:
            raise ValueError(
                "Parameter type_filter must be 'classifier', "
                "'regressor', 'transformer', 'cluster' or "
                "None, got"
                f" {type_filter!r}"
            )

    # drop duplicates, sort for reproducibility
    # itemgetter is used to ensure the sort does not extend to the 2nd item of
    # the tuple
    return sorted(set(estimators), key=itemgetter(0))