genesis_flow-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (645)
  1. genesis_flow-1.0.0.dist-info/METADATA +822 -0
  2. genesis_flow-1.0.0.dist-info/RECORD +645 -0
  3. genesis_flow-1.0.0.dist-info/WHEEL +5 -0
  4. genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
  5. genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
  6. genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
  7. mlflow/__init__.py +367 -0
  8. mlflow/__main__.py +3 -0
  9. mlflow/ag2/__init__.py +56 -0
  10. mlflow/ag2/ag2_logger.py +294 -0
  11. mlflow/anthropic/__init__.py +40 -0
  12. mlflow/anthropic/autolog.py +129 -0
  13. mlflow/anthropic/chat.py +144 -0
  14. mlflow/artifacts/__init__.py +268 -0
  15. mlflow/autogen/__init__.py +144 -0
  16. mlflow/autogen/chat.py +142 -0
  17. mlflow/azure/__init__.py +26 -0
  18. mlflow/azure/auth_handler.py +257 -0
  19. mlflow/azure/client.py +319 -0
  20. mlflow/azure/config.py +120 -0
  21. mlflow/azure/connection_factory.py +340 -0
  22. mlflow/azure/exceptions.py +27 -0
  23. mlflow/azure/stores.py +327 -0
  24. mlflow/azure/utils.py +183 -0
  25. mlflow/bedrock/__init__.py +45 -0
  26. mlflow/bedrock/_autolog.py +202 -0
  27. mlflow/bedrock/chat.py +122 -0
  28. mlflow/bedrock/stream.py +160 -0
  29. mlflow/bedrock/utils.py +43 -0
  30. mlflow/cli.py +707 -0
  31. mlflow/client.py +12 -0
  32. mlflow/config/__init__.py +56 -0
  33. mlflow/crewai/__init__.py +79 -0
  34. mlflow/crewai/autolog.py +253 -0
  35. mlflow/crewai/chat.py +29 -0
  36. mlflow/data/__init__.py +75 -0
  37. mlflow/data/artifact_dataset_sources.py +170 -0
  38. mlflow/data/code_dataset_source.py +40 -0
  39. mlflow/data/dataset.py +123 -0
  40. mlflow/data/dataset_registry.py +168 -0
  41. mlflow/data/dataset_source.py +110 -0
  42. mlflow/data/dataset_source_registry.py +219 -0
  43. mlflow/data/delta_dataset_source.py +167 -0
  44. mlflow/data/digest_utils.py +108 -0
  45. mlflow/data/evaluation_dataset.py +562 -0
  46. mlflow/data/filesystem_dataset_source.py +81 -0
  47. mlflow/data/http_dataset_source.py +145 -0
  48. mlflow/data/huggingface_dataset.py +258 -0
  49. mlflow/data/huggingface_dataset_source.py +118 -0
  50. mlflow/data/meta_dataset.py +104 -0
  51. mlflow/data/numpy_dataset.py +223 -0
  52. mlflow/data/pandas_dataset.py +231 -0
  53. mlflow/data/polars_dataset.py +352 -0
  54. mlflow/data/pyfunc_dataset_mixin.py +31 -0
  55. mlflow/data/schema.py +76 -0
  56. mlflow/data/sources.py +1 -0
  57. mlflow/data/spark_dataset.py +406 -0
  58. mlflow/data/spark_dataset_source.py +74 -0
  59. mlflow/data/spark_delta_utils.py +118 -0
  60. mlflow/data/tensorflow_dataset.py +350 -0
  61. mlflow/data/uc_volume_dataset_source.py +81 -0
  62. mlflow/db.py +27 -0
  63. mlflow/dspy/__init__.py +17 -0
  64. mlflow/dspy/autolog.py +197 -0
  65. mlflow/dspy/callback.py +398 -0
  66. mlflow/dspy/constant.py +1 -0
  67. mlflow/dspy/load.py +93 -0
  68. mlflow/dspy/save.py +393 -0
  69. mlflow/dspy/util.py +109 -0
  70. mlflow/dspy/wrapper.py +226 -0
  71. mlflow/entities/__init__.py +104 -0
  72. mlflow/entities/_mlflow_object.py +52 -0
  73. mlflow/entities/assessment.py +545 -0
  74. mlflow/entities/assessment_error.py +80 -0
  75. mlflow/entities/assessment_source.py +141 -0
  76. mlflow/entities/dataset.py +92 -0
  77. mlflow/entities/dataset_input.py +51 -0
  78. mlflow/entities/dataset_summary.py +62 -0
  79. mlflow/entities/document.py +48 -0
  80. mlflow/entities/experiment.py +109 -0
  81. mlflow/entities/experiment_tag.py +35 -0
  82. mlflow/entities/file_info.py +45 -0
  83. mlflow/entities/input_tag.py +35 -0
  84. mlflow/entities/lifecycle_stage.py +35 -0
  85. mlflow/entities/logged_model.py +228 -0
  86. mlflow/entities/logged_model_input.py +26 -0
  87. mlflow/entities/logged_model_output.py +32 -0
  88. mlflow/entities/logged_model_parameter.py +46 -0
  89. mlflow/entities/logged_model_status.py +74 -0
  90. mlflow/entities/logged_model_tag.py +33 -0
  91. mlflow/entities/metric.py +200 -0
  92. mlflow/entities/model_registry/__init__.py +29 -0
  93. mlflow/entities/model_registry/_model_registry_entity.py +13 -0
  94. mlflow/entities/model_registry/model_version.py +243 -0
  95. mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
  96. mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
  97. mlflow/entities/model_registry/model_version_search.py +25 -0
  98. mlflow/entities/model_registry/model_version_stages.py +25 -0
  99. mlflow/entities/model_registry/model_version_status.py +35 -0
  100. mlflow/entities/model_registry/model_version_tag.py +35 -0
  101. mlflow/entities/model_registry/prompt.py +73 -0
  102. mlflow/entities/model_registry/prompt_version.py +244 -0
  103. mlflow/entities/model_registry/registered_model.py +175 -0
  104. mlflow/entities/model_registry/registered_model_alias.py +35 -0
  105. mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
  106. mlflow/entities/model_registry/registered_model_search.py +25 -0
  107. mlflow/entities/model_registry/registered_model_tag.py +35 -0
  108. mlflow/entities/multipart_upload.py +74 -0
  109. mlflow/entities/param.py +49 -0
  110. mlflow/entities/run.py +97 -0
  111. mlflow/entities/run_data.py +84 -0
  112. mlflow/entities/run_info.py +188 -0
  113. mlflow/entities/run_inputs.py +59 -0
  114. mlflow/entities/run_outputs.py +43 -0
  115. mlflow/entities/run_status.py +41 -0
  116. mlflow/entities/run_tag.py +36 -0
  117. mlflow/entities/source_type.py +31 -0
  118. mlflow/entities/span.py +774 -0
  119. mlflow/entities/span_event.py +96 -0
  120. mlflow/entities/span_status.py +102 -0
  121. mlflow/entities/trace.py +317 -0
  122. mlflow/entities/trace_data.py +71 -0
  123. mlflow/entities/trace_info.py +220 -0
  124. mlflow/entities/trace_info_v2.py +162 -0
  125. mlflow/entities/trace_location.py +173 -0
  126. mlflow/entities/trace_state.py +39 -0
  127. mlflow/entities/trace_status.py +68 -0
  128. mlflow/entities/view_type.py +51 -0
  129. mlflow/environment_variables.py +866 -0
  130. mlflow/evaluation/__init__.py +16 -0
  131. mlflow/evaluation/assessment.py +369 -0
  132. mlflow/evaluation/evaluation.py +411 -0
  133. mlflow/evaluation/evaluation_tag.py +61 -0
  134. mlflow/evaluation/fluent.py +48 -0
  135. mlflow/evaluation/utils.py +201 -0
  136. mlflow/exceptions.py +213 -0
  137. mlflow/experiments.py +140 -0
  138. mlflow/gemini/__init__.py +81 -0
  139. mlflow/gemini/autolog.py +186 -0
  140. mlflow/gemini/chat.py +261 -0
  141. mlflow/genai/__init__.py +71 -0
  142. mlflow/genai/datasets/__init__.py +67 -0
  143. mlflow/genai/datasets/evaluation_dataset.py +131 -0
  144. mlflow/genai/evaluation/__init__.py +3 -0
  145. mlflow/genai/evaluation/base.py +411 -0
  146. mlflow/genai/evaluation/constant.py +23 -0
  147. mlflow/genai/evaluation/utils.py +244 -0
  148. mlflow/genai/judges/__init__.py +21 -0
  149. mlflow/genai/judges/databricks.py +404 -0
  150. mlflow/genai/label_schemas/__init__.py +153 -0
  151. mlflow/genai/label_schemas/label_schemas.py +209 -0
  152. mlflow/genai/labeling/__init__.py +159 -0
  153. mlflow/genai/labeling/labeling.py +250 -0
  154. mlflow/genai/optimize/__init__.py +13 -0
  155. mlflow/genai/optimize/base.py +198 -0
  156. mlflow/genai/optimize/optimizers/__init__.py +4 -0
  157. mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
  158. mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
  159. mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
  160. mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
  161. mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
  162. mlflow/genai/optimize/types.py +75 -0
  163. mlflow/genai/optimize/util.py +30 -0
  164. mlflow/genai/prompts/__init__.py +206 -0
  165. mlflow/genai/scheduled_scorers.py +431 -0
  166. mlflow/genai/scorers/__init__.py +26 -0
  167. mlflow/genai/scorers/base.py +492 -0
  168. mlflow/genai/scorers/builtin_scorers.py +765 -0
  169. mlflow/genai/scorers/scorer_utils.py +138 -0
  170. mlflow/genai/scorers/validation.py +165 -0
  171. mlflow/genai/utils/data_validation.py +146 -0
  172. mlflow/genai/utils/enum_utils.py +23 -0
  173. mlflow/genai/utils/trace_utils.py +211 -0
  174. mlflow/groq/__init__.py +42 -0
  175. mlflow/groq/_groq_autolog.py +74 -0
  176. mlflow/johnsnowlabs/__init__.py +888 -0
  177. mlflow/langchain/__init__.py +24 -0
  178. mlflow/langchain/api_request_parallel_processor.py +330 -0
  179. mlflow/langchain/autolog.py +147 -0
  180. mlflow/langchain/chat_agent_langgraph.py +340 -0
  181. mlflow/langchain/constant.py +1 -0
  182. mlflow/langchain/constants.py +1 -0
  183. mlflow/langchain/databricks_dependencies.py +444 -0
  184. mlflow/langchain/langchain_tracer.py +597 -0
  185. mlflow/langchain/model.py +919 -0
  186. mlflow/langchain/output_parsers.py +142 -0
  187. mlflow/langchain/retriever_chain.py +153 -0
  188. mlflow/langchain/runnables.py +527 -0
  189. mlflow/langchain/utils/chat.py +402 -0
  190. mlflow/langchain/utils/logging.py +671 -0
  191. mlflow/langchain/utils/serialization.py +36 -0
  192. mlflow/legacy_databricks_cli/__init__.py +0 -0
  193. mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
  194. mlflow/legacy_databricks_cli/configure/provider.py +482 -0
  195. mlflow/litellm/__init__.py +175 -0
  196. mlflow/llama_index/__init__.py +22 -0
  197. mlflow/llama_index/autolog.py +55 -0
  198. mlflow/llama_index/chat.py +43 -0
  199. mlflow/llama_index/constant.py +1 -0
  200. mlflow/llama_index/model.py +577 -0
  201. mlflow/llama_index/pyfunc_wrapper.py +332 -0
  202. mlflow/llama_index/serialize_objects.py +188 -0
  203. mlflow/llama_index/tracer.py +561 -0
  204. mlflow/metrics/__init__.py +479 -0
  205. mlflow/metrics/base.py +39 -0
  206. mlflow/metrics/genai/__init__.py +25 -0
  207. mlflow/metrics/genai/base.py +101 -0
  208. mlflow/metrics/genai/genai_metric.py +771 -0
  209. mlflow/metrics/genai/metric_definitions.py +450 -0
  210. mlflow/metrics/genai/model_utils.py +371 -0
  211. mlflow/metrics/genai/prompt_template.py +68 -0
  212. mlflow/metrics/genai/prompts/__init__.py +0 -0
  213. mlflow/metrics/genai/prompts/v1.py +422 -0
  214. mlflow/metrics/genai/utils.py +6 -0
  215. mlflow/metrics/metric_definitions.py +619 -0
  216. mlflow/mismatch.py +34 -0
  217. mlflow/mistral/__init__.py +34 -0
  218. mlflow/mistral/autolog.py +71 -0
  219. mlflow/mistral/chat.py +135 -0
  220. mlflow/ml_package_versions.py +452 -0
  221. mlflow/models/__init__.py +97 -0
  222. mlflow/models/auth_policy.py +83 -0
  223. mlflow/models/cli.py +354 -0
  224. mlflow/models/container/__init__.py +294 -0
  225. mlflow/models/container/scoring_server/__init__.py +0 -0
  226. mlflow/models/container/scoring_server/nginx.conf +39 -0
  227. mlflow/models/dependencies_schemas.py +287 -0
  228. mlflow/models/display_utils.py +158 -0
  229. mlflow/models/docker_utils.py +211 -0
  230. mlflow/models/evaluation/__init__.py +23 -0
  231. mlflow/models/evaluation/_shap_patch.py +64 -0
  232. mlflow/models/evaluation/artifacts.py +194 -0
  233. mlflow/models/evaluation/base.py +1811 -0
  234. mlflow/models/evaluation/calibration_curve.py +109 -0
  235. mlflow/models/evaluation/default_evaluator.py +996 -0
  236. mlflow/models/evaluation/deprecated.py +23 -0
  237. mlflow/models/evaluation/evaluator_registry.py +80 -0
  238. mlflow/models/evaluation/evaluators/classifier.py +704 -0
  239. mlflow/models/evaluation/evaluators/default.py +233 -0
  240. mlflow/models/evaluation/evaluators/regressor.py +96 -0
  241. mlflow/models/evaluation/evaluators/shap.py +296 -0
  242. mlflow/models/evaluation/lift_curve.py +178 -0
  243. mlflow/models/evaluation/utils/metric.py +123 -0
  244. mlflow/models/evaluation/utils/trace.py +179 -0
  245. mlflow/models/evaluation/validation.py +434 -0
  246. mlflow/models/flavor_backend.py +93 -0
  247. mlflow/models/flavor_backend_registry.py +53 -0
  248. mlflow/models/model.py +1639 -0
  249. mlflow/models/model_config.py +150 -0
  250. mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
  251. mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
  252. mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
  253. mlflow/models/python_api.py +369 -0
  254. mlflow/models/rag_signatures.py +128 -0
  255. mlflow/models/resources.py +321 -0
  256. mlflow/models/signature.py +662 -0
  257. mlflow/models/utils.py +2054 -0
  258. mlflow/models/wheeled_model.py +280 -0
  259. mlflow/openai/__init__.py +57 -0
  260. mlflow/openai/_agent_tracer.py +364 -0
  261. mlflow/openai/api_request_parallel_processor.py +131 -0
  262. mlflow/openai/autolog.py +509 -0
  263. mlflow/openai/constant.py +1 -0
  264. mlflow/openai/model.py +824 -0
  265. mlflow/openai/utils/chat_schema.py +367 -0
  266. mlflow/optuna/__init__.py +3 -0
  267. mlflow/optuna/storage.py +646 -0
  268. mlflow/plugins/__init__.py +72 -0
  269. mlflow/plugins/base.py +358 -0
  270. mlflow/plugins/builtin/__init__.py +24 -0
  271. mlflow/plugins/builtin/pytorch_plugin.py +150 -0
  272. mlflow/plugins/builtin/sklearn_plugin.py +158 -0
  273. mlflow/plugins/builtin/transformers_plugin.py +187 -0
  274. mlflow/plugins/cli.py +321 -0
  275. mlflow/plugins/discovery.py +340 -0
  276. mlflow/plugins/manager.py +465 -0
  277. mlflow/plugins/registry.py +316 -0
  278. mlflow/plugins/templates/framework_plugin_template.py +329 -0
  279. mlflow/prompt/constants.py +20 -0
  280. mlflow/prompt/promptlab_model.py +197 -0
  281. mlflow/prompt/registry_utils.py +248 -0
  282. mlflow/promptflow/__init__.py +495 -0
  283. mlflow/protos/__init__.py +0 -0
  284. mlflow/protos/assessments_pb2.py +174 -0
  285. mlflow/protos/databricks_artifacts_pb2.py +489 -0
  286. mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
  287. mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
  288. mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
  289. mlflow/protos/databricks_pb2.py +267 -0
  290. mlflow/protos/databricks_trace_server_pb2.py +374 -0
  291. mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
  292. mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
  293. mlflow/protos/facet_feature_statistics_pb2.py +296 -0
  294. mlflow/protos/internal_pb2.py +77 -0
  295. mlflow/protos/mlflow_artifacts_pb2.py +336 -0
  296. mlflow/protos/model_registry_pb2.py +1073 -0
  297. mlflow/protos/scalapb/__init__.py +0 -0
  298. mlflow/protos/scalapb/scalapb_pb2.py +104 -0
  299. mlflow/protos/service_pb2.py +2600 -0
  300. mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
  301. mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
  302. mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
  303. mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
  304. mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
  305. mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
  306. mlflow/py.typed +0 -0
  307. mlflow/pydantic_ai/__init__.py +57 -0
  308. mlflow/pydantic_ai/autolog.py +173 -0
  309. mlflow/pyfunc/__init__.py +3844 -0
  310. mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
  311. mlflow/pyfunc/backend.py +523 -0
  312. mlflow/pyfunc/context.py +78 -0
  313. mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
  314. mlflow/pyfunc/loaders/__init__.py +7 -0
  315. mlflow/pyfunc/loaders/chat_agent.py +117 -0
  316. mlflow/pyfunc/loaders/chat_model.py +125 -0
  317. mlflow/pyfunc/loaders/code_model.py +31 -0
  318. mlflow/pyfunc/loaders/responses_agent.py +112 -0
  319. mlflow/pyfunc/mlserver.py +46 -0
  320. mlflow/pyfunc/model.py +1473 -0
  321. mlflow/pyfunc/scoring_server/__init__.py +604 -0
  322. mlflow/pyfunc/scoring_server/app.py +7 -0
  323. mlflow/pyfunc/scoring_server/client.py +146 -0
  324. mlflow/pyfunc/spark_model_cache.py +48 -0
  325. mlflow/pyfunc/stdin_server.py +44 -0
  326. mlflow/pyfunc/utils/__init__.py +3 -0
  327. mlflow/pyfunc/utils/data_validation.py +224 -0
  328. mlflow/pyfunc/utils/environment.py +22 -0
  329. mlflow/pyfunc/utils/input_converter.py +47 -0
  330. mlflow/pyfunc/utils/serving_data_parser.py +11 -0
  331. mlflow/pytorch/__init__.py +1171 -0
  332. mlflow/pytorch/_lightning_autolog.py +580 -0
  333. mlflow/pytorch/_pytorch_autolog.py +50 -0
  334. mlflow/pytorch/pickle_module.py +35 -0
  335. mlflow/rfunc/__init__.py +42 -0
  336. mlflow/rfunc/backend.py +134 -0
  337. mlflow/runs.py +89 -0
  338. mlflow/server/__init__.py +302 -0
  339. mlflow/server/auth/__init__.py +1224 -0
  340. mlflow/server/auth/__main__.py +4 -0
  341. mlflow/server/auth/basic_auth.ini +6 -0
  342. mlflow/server/auth/cli.py +11 -0
  343. mlflow/server/auth/client.py +537 -0
  344. mlflow/server/auth/config.py +34 -0
  345. mlflow/server/auth/db/__init__.py +0 -0
  346. mlflow/server/auth/db/cli.py +18 -0
  347. mlflow/server/auth/db/migrations/__init__.py +0 -0
  348. mlflow/server/auth/db/migrations/alembic.ini +110 -0
  349. mlflow/server/auth/db/migrations/env.py +76 -0
  350. mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
  351. mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
  352. mlflow/server/auth/db/models.py +67 -0
  353. mlflow/server/auth/db/utils.py +37 -0
  354. mlflow/server/auth/entities.py +165 -0
  355. mlflow/server/auth/logo.py +14 -0
  356. mlflow/server/auth/permissions.py +65 -0
  357. mlflow/server/auth/routes.py +18 -0
  358. mlflow/server/auth/sqlalchemy_store.py +263 -0
  359. mlflow/server/graphql/__init__.py +0 -0
  360. mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
  361. mlflow/server/graphql/graphql_custom_scalars.py +24 -0
  362. mlflow/server/graphql/graphql_errors.py +15 -0
  363. mlflow/server/graphql/graphql_no_batching.py +89 -0
  364. mlflow/server/graphql/graphql_schema_extensions.py +74 -0
  365. mlflow/server/handlers.py +3217 -0
  366. mlflow/server/prometheus_exporter.py +17 -0
  367. mlflow/server/validation.py +30 -0
  368. mlflow/shap/__init__.py +691 -0
  369. mlflow/sklearn/__init__.py +1994 -0
  370. mlflow/sklearn/utils.py +1041 -0
  371. mlflow/smolagents/__init__.py +66 -0
  372. mlflow/smolagents/autolog.py +139 -0
  373. mlflow/smolagents/chat.py +29 -0
  374. mlflow/store/__init__.py +10 -0
  375. mlflow/store/_unity_catalog/__init__.py +1 -0
  376. mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
  377. mlflow/store/_unity_catalog/lineage/constants.py +2 -0
  378. mlflow/store/_unity_catalog/registry/__init__.py +6 -0
  379. mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
  380. mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
  381. mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
  382. mlflow/store/_unity_catalog/registry/utils.py +121 -0
  383. mlflow/store/artifact/__init__.py +0 -0
  384. mlflow/store/artifact/artifact_repo.py +472 -0
  385. mlflow/store/artifact/artifact_repository_registry.py +154 -0
  386. mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
  387. mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
  388. mlflow/store/artifact/cli.py +141 -0
  389. mlflow/store/artifact/cloud_artifact_repo.py +332 -0
  390. mlflow/store/artifact/databricks_artifact_repo.py +729 -0
  391. mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
  392. mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
  393. mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
  394. mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
  395. mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
  396. mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
  397. mlflow/store/artifact/ftp_artifact_repo.py +132 -0
  398. mlflow/store/artifact/gcs_artifact_repo.py +296 -0
  399. mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
  400. mlflow/store/artifact/http_artifact_repo.py +218 -0
  401. mlflow/store/artifact/local_artifact_repo.py +142 -0
  402. mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
  403. mlflow/store/artifact/models_artifact_repo.py +259 -0
  404. mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
  405. mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
  406. mlflow/store/artifact/r2_artifact_repo.py +70 -0
  407. mlflow/store/artifact/runs_artifact_repo.py +265 -0
  408. mlflow/store/artifact/s3_artifact_repo.py +330 -0
  409. mlflow/store/artifact/sftp_artifact_repo.py +141 -0
  410. mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
  411. mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
  412. mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
  413. mlflow/store/artifact/utils/__init__.py +0 -0
  414. mlflow/store/artifact/utils/models.py +148 -0
  415. mlflow/store/db/__init__.py +0 -0
  416. mlflow/store/db/base_sql_model.py +3 -0
  417. mlflow/store/db/db_types.py +10 -0
  418. mlflow/store/db/utils.py +314 -0
  419. mlflow/store/db_migrations/__init__.py +0 -0
  420. mlflow/store/db_migrations/alembic.ini +74 -0
  421. mlflow/store/db_migrations/env.py +84 -0
  422. mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
  423. mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
  424. mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
  425. mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
  426. mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
  427. mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
  428. mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
  429. mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
  430. mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
  431. mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
  432. mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
  433. mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
  434. mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
  435. mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
  436. mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
  437. mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
  438. mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
  439. mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
  440. mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
  441. mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
  442. mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
  443. mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
  444. mlflow/store/db_migrations/versions/__init__.py +0 -0
  445. mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
  446. mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
  447. mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
  448. mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
  449. mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
  450. mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
  451. mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
  452. mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
  453. mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
  454. mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
  455. mlflow/store/entities/__init__.py +3 -0
  456. mlflow/store/entities/paged_list.py +18 -0
  457. mlflow/store/model_registry/__init__.py +10 -0
  458. mlflow/store/model_registry/abstract_store.py +1081 -0
  459. mlflow/store/model_registry/base_rest_store.py +44 -0
  460. mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
  461. mlflow/store/model_registry/dbmodels/__init__.py +0 -0
  462. mlflow/store/model_registry/dbmodels/models.py +206 -0
  463. mlflow/store/model_registry/file_store.py +1091 -0
  464. mlflow/store/model_registry/rest_store.py +481 -0
  465. mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
  466. mlflow/store/tracking/__init__.py +23 -0
  467. mlflow/store/tracking/abstract_store.py +816 -0
  468. mlflow/store/tracking/dbmodels/__init__.py +0 -0
  469. mlflow/store/tracking/dbmodels/initial_models.py +243 -0
  470. mlflow/store/tracking/dbmodels/models.py +1073 -0
  471. mlflow/store/tracking/file_store.py +2438 -0
  472. mlflow/store/tracking/postgres_managed_identity.py +146 -0
  473. mlflow/store/tracking/rest_store.py +1131 -0
  474. mlflow/store/tracking/sqlalchemy_store.py +2785 -0
  475. mlflow/system_metrics/__init__.py +61 -0
  476. mlflow/system_metrics/metrics/__init__.py +0 -0
  477. mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
  478. mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
  479. mlflow/system_metrics/metrics/disk_monitor.py +21 -0
  480. mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
  481. mlflow/system_metrics/metrics/network_monitor.py +34 -0
  482. mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
  483. mlflow/system_metrics/system_metrics_monitor.py +198 -0
  484. mlflow/tracing/__init__.py +16 -0
  485. mlflow/tracing/assessment.py +356 -0
  486. mlflow/tracing/client.py +531 -0
  487. mlflow/tracing/config.py +125 -0
  488. mlflow/tracing/constant.py +105 -0
  489. mlflow/tracing/destination.py +81 -0
  490. mlflow/tracing/display/__init__.py +40 -0
  491. mlflow/tracing/display/display_handler.py +196 -0
  492. mlflow/tracing/export/async_export_queue.py +186 -0
  493. mlflow/tracing/export/inference_table.py +138 -0
  494. mlflow/tracing/export/mlflow_v3.py +137 -0
  495. mlflow/tracing/export/utils.py +70 -0
  496. mlflow/tracing/fluent.py +1417 -0
  497. mlflow/tracing/processor/base_mlflow.py +199 -0
  498. mlflow/tracing/processor/inference_table.py +175 -0
  499. mlflow/tracing/processor/mlflow_v3.py +47 -0
  500. mlflow/tracing/processor/otel.py +73 -0
  501. mlflow/tracing/provider.py +487 -0
  502. mlflow/tracing/trace_manager.py +200 -0
  503. mlflow/tracing/utils/__init__.py +616 -0
  504. mlflow/tracing/utils/artifact_utils.py +28 -0
  505. mlflow/tracing/utils/copy.py +55 -0
  506. mlflow/tracing/utils/environment.py +55 -0
  507. mlflow/tracing/utils/exception.py +21 -0
  508. mlflow/tracing/utils/once.py +35 -0
  509. mlflow/tracing/utils/otlp.py +63 -0
  510. mlflow/tracing/utils/processor.py +54 -0
  511. mlflow/tracing/utils/search.py +292 -0
  512. mlflow/tracing/utils/timeout.py +250 -0
  513. mlflow/tracing/utils/token.py +19 -0
  514. mlflow/tracing/utils/truncation.py +124 -0
  515. mlflow/tracing/utils/warning.py +76 -0
  516. mlflow/tracking/__init__.py +39 -0
  517. mlflow/tracking/_model_registry/__init__.py +1 -0
  518. mlflow/tracking/_model_registry/client.py +764 -0
  519. mlflow/tracking/_model_registry/fluent.py +853 -0
  520. mlflow/tracking/_model_registry/registry.py +67 -0
  521. mlflow/tracking/_model_registry/utils.py +251 -0
  522. mlflow/tracking/_tracking_service/__init__.py +0 -0
  523. mlflow/tracking/_tracking_service/client.py +883 -0
  524. mlflow/tracking/_tracking_service/registry.py +56 -0
  525. mlflow/tracking/_tracking_service/utils.py +275 -0
  526. mlflow/tracking/artifact_utils.py +179 -0
  527. mlflow/tracking/client.py +5900 -0
  528. mlflow/tracking/context/__init__.py +0 -0
  529. mlflow/tracking/context/abstract_context.py +35 -0
  530. mlflow/tracking/context/databricks_cluster_context.py +15 -0
  531. mlflow/tracking/context/databricks_command_context.py +15 -0
  532. mlflow/tracking/context/databricks_job_context.py +49 -0
  533. mlflow/tracking/context/databricks_notebook_context.py +41 -0
  534. mlflow/tracking/context/databricks_repo_context.py +43 -0
  535. mlflow/tracking/context/default_context.py +51 -0
  536. mlflow/tracking/context/git_context.py +32 -0
  537. mlflow/tracking/context/registry.py +98 -0
  538. mlflow/tracking/context/system_environment_context.py +15 -0
  539. mlflow/tracking/default_experiment/__init__.py +1 -0
  540. mlflow/tracking/default_experiment/abstract_context.py +43 -0
  541. mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
  542. mlflow/tracking/default_experiment/registry.py +75 -0
  543. mlflow/tracking/fluent.py +3595 -0
  544. mlflow/tracking/metric_value_conversion_utils.py +93 -0
  545. mlflow/tracking/multimedia.py +206 -0
  546. mlflow/tracking/registry.py +86 -0
  547. mlflow/tracking/request_auth/__init__.py +0 -0
  548. mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
  549. mlflow/tracking/request_auth/registry.py +60 -0
  550. mlflow/tracking/request_header/__init__.py +0 -0
  551. mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
  552. mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
  553. mlflow/tracking/request_header/default_request_header_provider.py +17 -0
  554. mlflow/tracking/request_header/registry.py +79 -0
  555. mlflow/transformers/__init__.py +2982 -0
  556. mlflow/transformers/flavor_config.py +258 -0
  557. mlflow/transformers/hub_utils.py +83 -0
  558. mlflow/transformers/llm_inference_utils.py +468 -0
  559. mlflow/transformers/model_io.py +301 -0
  560. mlflow/transformers/peft.py +51 -0
  561. mlflow/transformers/signature.py +183 -0
  562. mlflow/transformers/torch_utils.py +55 -0
  563. mlflow/types/__init__.py +21 -0
  564. mlflow/types/agent.py +270 -0
  565. mlflow/types/chat.py +240 -0
  566. mlflow/types/llm.py +935 -0
  567. mlflow/types/responses.py +139 -0
  568. mlflow/types/responses_helpers.py +416 -0
  569. mlflow/types/schema.py +1505 -0
  570. mlflow/types/type_hints.py +647 -0
  571. mlflow/types/utils.py +753 -0
  572. mlflow/utils/__init__.py +283 -0
  573. mlflow/utils/_capture_modules.py +256 -0
  574. mlflow/utils/_capture_transformers_modules.py +75 -0
  575. mlflow/utils/_spark_utils.py +201 -0
  576. mlflow/utils/_unity_catalog_oss_utils.py +97 -0
  577. mlflow/utils/_unity_catalog_utils.py +479 -0
  578. mlflow/utils/annotations.py +218 -0
  579. mlflow/utils/arguments_utils.py +16 -0
  580. mlflow/utils/async_logging/__init__.py +1 -0
  581. mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
  582. mlflow/utils/async_logging/async_logging_queue.py +366 -0
  583. mlflow/utils/async_logging/run_artifact.py +38 -0
  584. mlflow/utils/async_logging/run_batch.py +58 -0
  585. mlflow/utils/async_logging/run_operations.py +49 -0
  586. mlflow/utils/autologging_utils/__init__.py +737 -0
  587. mlflow/utils/autologging_utils/client.py +432 -0
  588. mlflow/utils/autologging_utils/config.py +33 -0
  589. mlflow/utils/autologging_utils/events.py +294 -0
  590. mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
  591. mlflow/utils/autologging_utils/metrics_queue.py +71 -0
  592. mlflow/utils/autologging_utils/safety.py +1104 -0
  593. mlflow/utils/autologging_utils/versioning.py +95 -0
  594. mlflow/utils/checkpoint_utils.py +206 -0
  595. mlflow/utils/class_utils.py +6 -0
  596. mlflow/utils/cli_args.py +257 -0
  597. mlflow/utils/conda.py +354 -0
  598. mlflow/utils/credentials.py +231 -0
  599. mlflow/utils/data_utils.py +17 -0
  600. mlflow/utils/databricks_utils.py +1436 -0
  601. mlflow/utils/docstring_utils.py +477 -0
  602. mlflow/utils/doctor.py +133 -0
  603. mlflow/utils/download_cloud_file_chunk.py +43 -0
  604. mlflow/utils/env_manager.py +16 -0
  605. mlflow/utils/env_pack.py +131 -0
  606. mlflow/utils/environment.py +1009 -0
  607. mlflow/utils/exception_utils.py +14 -0
  608. mlflow/utils/file_utils.py +978 -0
  609. mlflow/utils/git_utils.py +77 -0
  610. mlflow/utils/gorilla.py +797 -0
  611. mlflow/utils/import_hooks/__init__.py +363 -0
  612. mlflow/utils/lazy_load.py +51 -0
  613. mlflow/utils/logging_utils.py +168 -0
  614. mlflow/utils/mime_type_utils.py +58 -0
  615. mlflow/utils/mlflow_tags.py +103 -0
  616. mlflow/utils/model_utils.py +486 -0
  617. mlflow/utils/name_utils.py +346 -0
  618. mlflow/utils/nfs_on_spark.py +62 -0
  619. mlflow/utils/openai_utils.py +164 -0
  620. mlflow/utils/os.py +12 -0
  621. mlflow/utils/oss_registry_utils.py +29 -0
  622. mlflow/utils/plugins.py +17 -0
  623. mlflow/utils/process.py +182 -0
  624. mlflow/utils/promptlab_utils.py +146 -0
  625. mlflow/utils/proto_json_utils.py +743 -0
  626. mlflow/utils/pydantic_utils.py +54 -0
  627. mlflow/utils/request_utils.py +279 -0
  628. mlflow/utils/requirements_utils.py +704 -0
  629. mlflow/utils/rest_utils.py +673 -0
  630. mlflow/utils/search_logged_model_utils.py +127 -0
  631. mlflow/utils/search_utils.py +2111 -0
  632. mlflow/utils/secure_loading.py +221 -0
  633. mlflow/utils/security_validation.py +384 -0
  634. mlflow/utils/server_cli_utils.py +61 -0
  635. mlflow/utils/spark_utils.py +15 -0
  636. mlflow/utils/string_utils.py +138 -0
  637. mlflow/utils/thread_utils.py +63 -0
  638. mlflow/utils/time.py +54 -0
  639. mlflow/utils/timeout.py +42 -0
  640. mlflow/utils/uri.py +572 -0
  641. mlflow/utils/validation.py +662 -0
  642. mlflow/utils/virtualenv.py +458 -0
  643. mlflow/utils/warnings_utils.py +25 -0
  644. mlflow/utils/yaml_utils.py +179 -0
  645. mlflow/version.py +24 -0
@@ -0,0 +1,1994 @@
1
+ """
2
+ The ``mlflow.sklearn`` module provides an API for logging and loading scikit-learn models. This
3
+ module exports scikit-learn models with the following flavors:
4
+
5
+ Python (native) `pickle <https://scikit-learn.org/stable/modules/model_persistence.html>`_ format
6
+ This is the main flavor that can be loaded back into scikit-learn.
7
+
8
+ :py:mod:`mlflow.pyfunc`
9
+ Produced for use by generic pyfunc-based deployment tools and batch inference.
10
+ NOTE: The `mlflow.pyfunc` flavor is only added for scikit-learn models that define `predict()`,
11
+ since `predict()` is required for pyfunc model inference.
12
+ """
13
+
14
+ import functools
15
+ import inspect
16
+ import logging
17
+ import os
18
+ import pickle
19
+ import weakref
20
+ from collections import OrderedDict, defaultdict
21
+ from copy import deepcopy
22
+ from typing import Any, Optional
23
+
24
+ import numpy as np
25
+ import yaml
26
+ from packaging.version import Version
27
+
28
+ import mlflow
29
+ from mlflow import pyfunc
30
+ from mlflow.data.code_dataset_source import CodeDatasetSource
31
+ from mlflow.data.numpy_dataset import from_numpy
32
+ from mlflow.data.pandas_dataset import from_pandas
33
+ from mlflow.entities.dataset_input import DatasetInput
34
+ from mlflow.entities.input_tag import InputTag
35
+ from mlflow.exceptions import MlflowException
36
+ from mlflow.models import Model, ModelInputExample, ModelSignature
37
+ from mlflow.models.model import MLMODEL_FILE_NAME
38
+ from mlflow.models.signature import _infer_signature_from_input_example
39
+ from mlflow.models.utils import _save_example
40
+ from mlflow.protos.databricks_pb2 import INTERNAL_ERROR, INVALID_PARAMETER_VALUE
41
+ from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
42
+ from mlflow.tracking.artifact_utils import _download_artifact_from_uri
43
+ from mlflow.tracking.client import MlflowClient
44
+ from mlflow.utils import _inspect_original_var_name, gorilla
45
+ from mlflow.utils.autologging_utils import (
46
+ INPUT_EXAMPLE_SAMPLE_ROWS,
47
+ MlflowAutologgingQueueingClient,
48
+ _get_new_training_session_class,
49
+ autologging_integration,
50
+ disable_autologging,
51
+ get_autologging_config,
52
+ get_instance_method_first_arg_value,
53
+ resolve_input_example_and_signature,
54
+ safe_patch,
55
+ update_wrapper_extended,
56
+ )
57
+ from mlflow.utils.docstring_utils import LOG_MODEL_PARAM_DOCS, format_docstring
58
+ from mlflow.utils.environment import (
59
+ _CONDA_ENV_FILE_NAME,
60
+ _CONSTRAINTS_FILE_NAME,
61
+ _PYTHON_ENV_FILE_NAME,
62
+ _REQUIREMENTS_FILE_NAME,
63
+ _mlflow_conda_env,
64
+ _process_conda_env,
65
+ _process_pip_requirements,
66
+ _PythonEnv,
67
+ _validate_env_arguments,
68
+ )
69
+ from mlflow.utils.file_utils import get_total_file_size, write_to
70
+ from mlflow.utils.mlflow_tags import (
71
+ MLFLOW_AUTOLOGGING,
72
+ MLFLOW_DATASET_CONTEXT,
73
+ )
74
+ from mlflow.utils.model_utils import (
75
+ _add_code_from_conf_to_system_path,
76
+ _get_flavor_configuration,
77
+ _validate_and_copy_code_paths,
78
+ _validate_and_prepare_target_save_path,
79
+ )
80
+ from mlflow.utils.requirements_utils import _get_pinned_requirement
81
+
82
+ FLAVOR_NAME = "sklearn"
83
+
84
+ SERIALIZATION_FORMAT_PICKLE = "pickle"
85
+ SERIALIZATION_FORMAT_CLOUDPICKLE = "cloudpickle"
86
+
87
+ SUPPORTED_SERIALIZATION_FORMATS = [SERIALIZATION_FORMAT_PICKLE, SERIALIZATION_FORMAT_CLOUDPICKLE]
88
+
89
+ _logger = logging.getLogger(__name__)
90
+ _SklearnTrainingSession = _get_new_training_session_class()
91
+
92
+ _MODEL_DATA_SUBPATH = "model.pkl"
93
+
94
+
95
def _gen_estimators_to_patch():
    """Return the sklearn estimator classes that autologging should patch."""
    from mlflow.sklearn.utils import (
        _all_estimators,
        _get_meta_estimators_for_autologging,
    )

    _, discovered = zip(*_all_estimators())
    # Relevant meta estimators (e.g. GridSearchCV, Pipeline) may be missing from
    # the output of `all_estimators()`, so merge them in explicitly.
    candidates = set(discovered) | set(_get_meta_estimators_for_autologging())

    # Preprocessing & feature-manipulation estimators represent data
    # manipulation routines (e.g., normalization, label encoding) rather than
    # ML algorithms, so we should not create MLflow runs and log
    # parameters / metrics for them, unless they are captured as part of an
    # ML pipeline (via `sklearn.pipeline.Pipeline`).
    excluded_module_names = [
        "sklearn.preprocessing",
        "sklearn.impute",
        "sklearn.feature_extraction",
        "sklearn.feature_selection",
    ]

    excluded_class_names = [
        "sklearn.compose._column_transformer.ColumnTransformer",
    ]

    def _is_excluded(estimator):
        # Exclude by fully-qualified class name or by module prefix.
        qualified_name = estimator.__module__ + "." + estimator.__name__
        if qualified_name in excluded_class_names:
            return True
        return any(
            estimator.__module__.startswith(prefix) for prefix in excluded_module_names
        )

    return [estimator for estimator in candidates if not _is_excluded(estimator)]
132
+
133
+
134
def get_default_pip_requirements(include_cloudpickle=False):
    """
    Returns:
        A list of default pip requirements for MLflow Models produced by this flavor.
        Calls to :func:`save_model()` and :func:`log_model()` produce a pip environment
        that, at minimum, contains these requirements.
    """
    requirements = [_get_pinned_requirement("scikit-learn", module="sklearn")]
    if include_cloudpickle:
        requirements.append(_get_pinned_requirement("cloudpickle"))
    return requirements
146
+
147
+
148
def get_default_conda_env(include_cloudpickle=False):
    """
    Returns:
        The default Conda environment for MLflow Models produced by calls to
        :func:`save_model()` and :func:`log_model()`.
    """
    pip_deps = get_default_pip_requirements(include_cloudpickle=include_cloudpickle)
    return _mlflow_conda_env(additional_pip_deps=pip_deps)
155
+
156
+
157
@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="scikit-learn"))
def save_model(
    sk_model,
    path,
    conda_env=None,
    code_paths=None,
    mlflow_model=None,
    serialization_format=SERIALIZATION_FORMAT_CLOUDPICKLE,
    signature: ModelSignature = None,
    input_example: ModelInputExample = None,
    pip_requirements=None,
    extra_pip_requirements=None,
    pyfunc_predict_fn="predict",
    metadata=None,
):
    """
    Save a scikit-learn model to a path on the local file system. Produces a MLflow Model
    containing the following flavors:

        - :py:mod:`mlflow.sklearn`
        - :py:mod:`mlflow.pyfunc`. NOTE: This flavor is only included for scikit-learn models
          that define `predict()`, since `predict()` is required for pyfunc model inference.

    Args:
        sk_model: scikit-learn model to be saved.
        path: Local path where the model is to be saved.
        conda_env: {{ conda_env }}
        code_paths: {{ code_paths }}
        mlflow_model: :py:mod:`mlflow.models.Model` this flavor is being added to.
        serialization_format: The format in which to serialize the model. This should be one of
            the formats listed in
            ``mlflow.sklearn.SUPPORTED_SERIALIZATION_FORMATS``. The Cloudpickle
            format, ``mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE``,
            provides better cross-system compatibility by identifying and
            packaging code dependencies with the serialized model.

        signature: {{ signature }}
        input_example: {{ input_example }}
        pip_requirements: {{ pip_requirements }}
        extra_pip_requirements: {{ extra_pip_requirements }}
        pyfunc_predict_fn: The name of the prediction function to use for inference with the
            pyfunc representation of the resulting MLflow Model. Current supported functions
            are: ``"predict"``, ``"predict_proba"``, ``"predict_log_proba"``,
            ``"predict_joint_log_proba"``, and ``"score"``.
        metadata: {{ metadata }}

    .. code-block:: python
        :caption: Example

        import mlflow.sklearn
        from sklearn.datasets import load_iris
        from sklearn import tree

        iris = load_iris()
        sk_model = tree.DecisionTreeClassifier()
        sk_model = sk_model.fit(iris.data, iris.target)

        # Save the model in cloudpickle format
        # set path to location for persistence
        sk_path_dir_1 = ...
        mlflow.sklearn.save_model(
            sk_model,
            sk_path_dir_1,
            serialization_format=mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE,
        )

        # save the model in pickle format
        # set path to location for persistence
        sk_path_dir_2 = ...
        mlflow.sklearn.save_model(
            sk_model,
            sk_path_dir_2,
            serialization_format=mlflow.sklearn.SERIALIZATION_FORMAT_PICKLE,
        )
    """
    # Imported lazily so that importing this module does not require sklearn.
    import sklearn

    # `conda_env`, `pip_requirements`, and `extra_pip_requirements` are mutually
    # constrained; fail fast on invalid combinations.
    _validate_env_arguments(conda_env, pip_requirements, extra_pip_requirements)

    if serialization_format not in SUPPORTED_SERIALIZATION_FORMATS:
        raise MlflowException(
            message=(
                f"Unrecognized serialization format: {serialization_format}. Please specify one"
                f" of the following supported formats: {SUPPORTED_SERIALIZATION_FORMATS}."
            ),
            error_code=INVALID_PARAMETER_VALUE,
        )

    _validate_and_prepare_target_save_path(path)
    code_path_subdir = _validate_and_copy_code_paths(code_paths, path)

    if mlflow_model is None:
        mlflow_model = Model()
    saved_example = _save_example(mlflow_model, input_example, path)

    # Infer a signature from the input example when the caller did not supply
    # one; `signature=False` explicitly opts out of signature inference.
    if signature is None and saved_example is not None:
        wrapped_model = _SklearnModelWrapper(sk_model)
        signature = _infer_signature_from_input_example(saved_example, wrapped_model)
    elif signature is False:
        signature = None

    if signature is not None:
        mlflow_model.signature = signature
    if metadata is not None:
        mlflow_model.metadata = metadata

    model_data_subpath = _MODEL_DATA_SUBPATH
    model_data_path = os.path.join(path, model_data_subpath)
    _save_model(
        sk_model=sk_model,
        output_path=model_data_path,
        serialization_format=serialization_format,
    )

    # `PyFuncModel` only works for sklearn models that define a predict function

    if hasattr(sk_model, pyfunc_predict_fn):
        pyfunc.add_to_model(
            mlflow_model,
            loader_module="mlflow.sklearn",
            model_path=model_data_subpath,
            conda_env=_CONDA_ENV_FILE_NAME,
            python_env=_PYTHON_ENV_FILE_NAME,
            code=code_path_subdir,
            predict_fn=pyfunc_predict_fn,
        )
    else:
        _logger.warning(
            f"Model was missing function: {pyfunc_predict_fn}. Not logging python_function flavor!"
        )

    mlflow_model.add_flavor(
        FLAVOR_NAME,
        pickled_model=model_data_subpath,
        sklearn_version=sklearn.__version__,
        serialization_format=serialization_format,
        code=code_path_subdir,
    )
    if size := get_total_file_size(path):
        mlflow_model.model_size_bytes = size
    mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME))

    if conda_env is None:
        if pip_requirements is None:
            include_cloudpickle = serialization_format == SERIALIZATION_FORMAT_CLOUDPICKLE
            default_reqs = get_default_pip_requirements(include_cloudpickle)
            # To ensure `_load_pyfunc` can successfully load the model during the dependency
            # inference, `mlflow_model.save` must be called beforehand to save an MLmodel file.
            inferred_reqs = mlflow.models.infer_pip_requirements(
                model_data_path,
                FLAVOR_NAME,
                fallback=default_reqs,
            )
            default_reqs = sorted(set(inferred_reqs).union(default_reqs))
        else:
            default_reqs = None
        conda_env, pip_requirements, pip_constraints = _process_pip_requirements(
            default_reqs,
            pip_requirements,
            extra_pip_requirements,
        )
    else:
        conda_env, pip_requirements, pip_constraints = _process_conda_env(conda_env)

    # Persist the environment specifications next to the model artifacts.
    with open(os.path.join(path, _CONDA_ENV_FILE_NAME), "w") as f:
        yaml.safe_dump(conda_env, stream=f, default_flow_style=False)

    # Save `constraints.txt` if necessary
    if pip_constraints:
        write_to(os.path.join(path, _CONSTRAINTS_FILE_NAME), "\n".join(pip_constraints))

    # Save `requirements.txt`
    write_to(os.path.join(path, _REQUIREMENTS_FILE_NAME), "\n".join(pip_requirements))

    _PythonEnv.current().to_yaml(os.path.join(path, _PYTHON_ENV_FILE_NAME))
332
+
333
+
334
@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="scikit-learn"))
def log_model(
    sk_model,
    artifact_path: Optional[str] = None,
    conda_env=None,
    code_paths=None,
    serialization_format=SERIALIZATION_FORMAT_CLOUDPICKLE,
    registered_model_name=None,
    signature: ModelSignature = None,
    input_example: ModelInputExample = None,
    await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
    pip_requirements=None,
    extra_pip_requirements=None,
    pyfunc_predict_fn="predict",
    metadata=None,
    # New arguments
    params: Optional[dict[str, Any]] = None,
    tags: Optional[dict[str, Any]] = None,
    model_type: Optional[str] = None,
    step: int = 0,
    model_id: Optional[str] = None,
    name: Optional[str] = None,
):
    """
    Log a scikit-learn model as an MLflow artifact for the current run. Produces an MLflow Model
    containing the following flavors:

        - :py:mod:`mlflow.sklearn`
        - :py:mod:`mlflow.pyfunc`. NOTE: This flavor is only included for scikit-learn models
          that define `predict()`, since `predict()` is required for pyfunc model inference.

    Args:
        sk_model: scikit-learn model to be saved.
        artifact_path: Deprecated. Use `name` instead.
        conda_env: {{ conda_env }}
        code_paths: {{ code_paths }}
        serialization_format: The format in which to serialize the model. This should be one of
            the formats listed in
            ``mlflow.sklearn.SUPPORTED_SERIALIZATION_FORMATS``. The Cloudpickle
            format, ``mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE``,
            provides better cross-system compatibility by identifying and
            packaging code dependencies with the serialized model.
        registered_model_name: If given, create a model version under
            ``registered_model_name``, also creating a registered model if one
            with the given name does not exist.
        signature: {{ signature }}
        input_example: {{ input_example }}
        await_registration_for: Number of seconds to wait for the model version to finish
            being created and is in ``READY`` status. By default, the function
            waits for five minutes. Specify 0 or None to skip waiting.
        pip_requirements: {{ pip_requirements }}
        extra_pip_requirements: {{ extra_pip_requirements }}
        pyfunc_predict_fn: The name of the prediction function to use for inference with the
            pyfunc representation of the resulting MLflow Model. Current supported functions
            are: ``"predict"``, ``"predict_proba"``, ``"predict_log_proba"``,
            ``"predict_joint_log_proba"``, and ``"score"``.
        metadata: {{ metadata }}
        params: {{ params }}
        tags: {{ tags }}
        model_type: {{ model_type }}
        step: {{ step }}
        model_id: {{ model_id }}
        name: {{ name }}

    Returns:
        A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
        metadata of the logged model.

    .. code-block:: python
        :caption: Example

        import mlflow
        import mlflow.sklearn
        from mlflow.models import infer_signature
        from sklearn.datasets import load_iris
        from sklearn import tree

        with mlflow.start_run():
            # load dataset and train model
            iris = load_iris()
            sk_model = tree.DecisionTreeClassifier()
            sk_model = sk_model.fit(iris.data, iris.target)

            # log model params
            mlflow.log_param("criterion", sk_model.criterion)
            mlflow.log_param("splitter", sk_model.splitter)
            signature = infer_signature(iris.data, sk_model.predict(iris.data))

            # log model
            mlflow.sklearn.log_model(sk_model, name="sk_models", signature=signature)

    """
    # Delegates to `Model.log`, which creates the artifact directory, calls
    # this flavor's `save_model` with the flavor-specific kwargs below, and
    # optionally registers the model.
    return Model.log(
        artifact_path=artifact_path,
        name=name,
        flavor=mlflow.sklearn,
        sk_model=sk_model,
        conda_env=conda_env,
        code_paths=code_paths,
        serialization_format=serialization_format,
        registered_model_name=registered_model_name,
        signature=signature,
        input_example=input_example,
        await_registration_for=await_registration_for,
        pip_requirements=pip_requirements,
        extra_pip_requirements=extra_pip_requirements,
        pyfunc_predict_fn=pyfunc_predict_fn,
        metadata=metadata,
        params=params,
        tags=tags,
        model_type=model_type,
        step=step,
        model_id=model_id,
    )
448
+
449
+
450
def _load_model_from_local_file(path, serialization_format):
    """Load a scikit-learn model saved as an MLflow artifact on the local file system.

    Args:
        path: Local filesystem path to the MLflow Model saved with the ``sklearn`` flavor
        serialization_format: The format in which the model was serialized. This should be one of
            the following: ``mlflow.sklearn.SERIALIZATION_FORMAT_PICKLE`` or
            ``mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE``.
    """
    # TODO: we could validate the scikit-learn version here
    if serialization_format not in SUPPORTED_SERIALIZATION_FORMATS:
        raise MlflowException(
            message=(
                f"Unrecognized serialization format: {serialization_format}. Please specify one"
                f" of the following supported formats: {SUPPORTED_SERIALIZATION_FORMATS}."
            ),
            error_code=INVALID_PARAMETER_VALUE,
        )
    # Genesis-Flow: Use secure model loading to prevent code execution attacks
    from mlflow.utils.secure_loading import SecureModelLoader

    # Dispatch on serialization format; both entries exist because the format
    # was validated against SUPPORTED_SERIALIZATION_FORMATS above.
    format_loaders = {
        SERIALIZATION_FORMAT_PICKLE: SecureModelLoader.safe_pickle_load,
        SERIALIZATION_FORMAT_CLOUDPICKLE: SecureModelLoader.safe_cloudpickle_load,
    }
    return format_loaders[serialization_format](path)
475
+
476
+
477
def _load_pyfunc(path):
    """
    Load PyFunc implementation. Called by ``pyfunc.load_model``.

    Args:
        path: Local filesystem path to the MLflow Model with the ``sklearn`` flavor.
    """
    if os.path.isfile(path):
        # Scikit-learn models saved in older versions of MLflow (<= 1.9.1) specify the ``data``
        # field within the pyfunc flavor configuration. For these older models, the ``path``
        # parameter of ``_load_pyfunc()`` refers directly to a serialized scikit-learn model
        # object. In this case, we assume that the serialization format is ``pickle``, since
        # the model loading procedure in older versions of MLflow used ``pickle.load()``.
        serialization_format = SERIALIZATION_FORMAT_PICKLE
    else:
        # In contrast, scikit-learn models saved in versions of MLflow > 1.9.1 do not
        # specify the ``data`` field within the pyfunc flavor configuration. For these newer
        # models, the ``path`` parameter of ``load_pyfunc()`` refers to the top-level MLflow
        # Model directory. In this case, we parse the model path from the MLmodel's pyfunc
        # flavor configuration and attempt to fetch the serialization format from the
        # scikit-learn flavor configuration
        try:
            sklearn_flavor_conf = _get_flavor_configuration(
                model_path=path, flavor_name=FLAVOR_NAME
            )
            serialization_format = sklearn_flavor_conf.get(
                "serialization_format", SERIALIZATION_FORMAT_PICKLE
            )
        except MlflowException:
            # Missing sklearn flavor config is tolerated: fall back to the
            # legacy default rather than failing the load.
            _logger.warning(
                "Could not find scikit-learn flavor configuration during model loading process."
                " Assuming 'pickle' serialization format."
            )
            serialization_format = SERIALIZATION_FORMAT_PICKLE

        # Resolve the serialized model file relative to the model directory
        # using the pyfunc flavor's `model_path` entry.
        pyfunc_flavor_conf = _get_flavor_configuration(
            model_path=path, flavor_name=pyfunc.FLAVOR_NAME
        )
        path = os.path.join(path, pyfunc_flavor_conf["model_path"])

    return _SklearnModelWrapper(
        _load_model_from_local_file(path=path, serialization_format=serialization_format)
    )
520
+
521
+
522
+ class _SklearnModelWrapper:
523
+ _SUPPORTED_CUSTOM_PREDICT_FN = [
524
+ "predict_proba",
525
+ "predict_log_proba",
526
+ "predict_joint_log_proba",
527
+ "score",
528
+ ]
529
+
530
+ def __init__(self, sklearn_model):
531
+ self.sklearn_model = sklearn_model
532
+
533
+ # Patch the model with custom predict functions that can be specified
534
+ # via `pyfunc_predict_fn` argument when saving or logging.
535
+ for predict_fn in self._SUPPORTED_CUSTOM_PREDICT_FN:
536
+ if fn := getattr(self.sklearn_model, predict_fn, None):
537
+ setattr(self, predict_fn, fn)
538
+
539
+ def get_raw_model(self):
540
+ """
541
+ Returns the underlying scikit-learn model.
542
+ """
543
+ return self.sklearn_model
544
+
545
+ def predict(
546
+ self,
547
+ data,
548
+ params: Optional[dict[str, Any]] = None,
549
+ ):
550
+ """
551
+ Args:
552
+ data: Model input data.
553
+ params: Additional parameters to pass to the model for inference.
554
+
555
+ Returns:
556
+ Model predictions.
557
+ """
558
+ return self.sklearn_model.predict(data)
559
+
560
+
561
+ class _SklearnCustomModelPicklingError(pickle.PicklingError):
562
+ """
563
+ Exception for describing error raised during pickling custom sklearn estimator
564
+ """
565
+
566
+ def __init__(self, sk_model, original_exception):
567
+ """
568
+ Args:
569
+ sk_model: The custom sklearn model to be pickled
570
+ original_exception: The original exception raised
571
+ """
572
+ super().__init__(
573
+ f"Pickling custom sklearn model {sk_model.__class__.__name__} failed "
574
+ f"when saving model: {original_exception}"
575
+ )
576
+ self.original_exception = original_exception
577
+
578
+
579
+ def _dump_model(pickle_lib, sk_model, out):
580
+ try:
581
+ # Using python's default protocol to optimize compatibility.
582
+ # Otherwise cloudpickle uses latest protocol leading to incompatibilities.
583
+ # See https://github.com/mlflow/mlflow/issues/5419
584
+ pickle_lib.dump(sk_model, out, protocol=pickle.DEFAULT_PROTOCOL)
585
+ except (pickle.PicklingError, TypeError, AttributeError) as e:
586
+ if sk_model.__class__ not in _gen_estimators_to_patch():
587
+ raise _SklearnCustomModelPicklingError(sk_model, e)
588
+ else:
589
+ raise
590
+
591
+
592
def _save_model(sk_model, output_path, serialization_format):
    """
    Args:
        sk_model: The scikit-learn model to serialize.
        output_path: The file path to which to write the serialized model.
        serialization_format: The format in which to serialize the model. This should be one of
            the following: ``mlflow.sklearn.SERIALIZATION_FORMAT_PICKLE`` or
            ``mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE``.
    """
    # The file is opened before the format check, matching the original
    # behavior; an unrecognized format raises with the file already created.
    with open(output_path, "wb") as out:
        if serialization_format == SERIALIZATION_FORMAT_CLOUDPICKLE:
            # cloudpickle is imported lazily: it is only required when this
            # serialization format is actually requested.
            import cloudpickle

            _dump_model(cloudpickle, sk_model, out)
        elif serialization_format == SERIALIZATION_FORMAT_PICKLE:
            _dump_model(pickle, sk_model, out)
        else:
            raise MlflowException(
                message=f"Unrecognized serialization format: {serialization_format}",
                error_code=INTERNAL_ERROR,
            )
613
+
614
+
615
def load_model(model_uri, dst_path=None):
    """
    Load a scikit-learn model from a local file or a run.

    Args:
        model_uri: The location, in URI format, of the MLflow model, for example:

            - ``/Users/me/path/to/local/model``
            - ``relative/path/to/local/model``
            - ``s3://my_bucket/path/to/model``
            - ``runs:/<mlflow_run_id>/run-relative/path/to/model``
            - ``models:/<model_name>/<model_version>``
            - ``models:/<model_name>/<stage>``

            For more information about supported URI schemes, see
            `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html#
            artifact-locations>`_.
        dst_path: The local filesystem path to which to download the model artifact.
            This directory must already exist. If unspecified, a local output
            path will be created.

    Returns:
        A scikit-learn model.

    .. code-block:: python
        :caption: Example

        import mlflow.sklearn

        sk_model = mlflow.sklearn.load_model("runs:/96771d893a5e46159d9f3b49bf9013e2/sk_models")

        # use Pandas DataFrame to make predictions
        pandas_df = ...
        predictions = sk_model.predict(pandas_df)
    """
    # Fetch the model artifacts locally (no-op for local paths), then resolve
    # the serialized model file and its format from the flavor configuration.
    local_model_path = _download_artifact_from_uri(artifact_uri=model_uri, output_path=dst_path)
    flavor_conf = _get_flavor_configuration(model_path=local_model_path, flavor_name=FLAVOR_NAME)
    _add_code_from_conf_to_system_path(local_model_path, flavor_conf)
    sklearn_model_artifacts_path = os.path.join(local_model_path, flavor_conf["pickled_model"])
    # Older models may lack `serialization_format`; default to plain pickle.
    serialization_format = flavor_conf.get("serialization_format", SERIALIZATION_FORMAT_PICKLE)
    return _load_model_from_local_file(
        path=sklearn_model_artifacts_path, serialization_format=serialization_format
    )
658
+
659
+
660
# `_apis_autologging_disabled` lists sklearn APIs that are incompatible with
# autologging; while a user calls one of these APIs, autologging is temporarily
# disabled.
_apis_autologging_disabled = [
    "cross_validate",
    "cross_val_predict",
    "cross_val_score",
    "learning_curve",
    "permutation_test_score",
    "validation_curve",
]
670
+
671
+
672
+ class _AutologgingMetricsManager:
673
+ """
674
+ This class is designed for holding information which is used by autologging metrics
675
+ It will hold information of:
676
+ (1) a map of "prediction result object id" to a tuple of dataset name(the dataset is
677
+ the one which generate the prediction result) and run_id.
678
+ Note: We need this map instead of setting the run_id into the "prediction result object"
679
+ because the object maybe a numpy array which does not support additional attribute
680
+ assignment.
681
+ (2) _log_post_training_metrics_enabled flag, in the following method scope:
682
+ `model.fit` and `model.score`, in order to avoid nested/duplicated autologging metric, when
683
+ run into these scopes, we need temporarily disable the metric autologging.
684
+ (3) _eval_dataset_info_map, it is a double level map:
685
+ `_eval_dataset_info_map[run_id][eval_dataset_var_name]` will get a list, each
686
+ element in the list is an id of "eval_dataset" instance.
687
+ This data structure is used for:
688
+ * generating unique dataset name key when autologging metric. For each eval dataset object,
689
+ if they have the same eval_dataset_var_name, but object ids are different,
690
+ then they will be assigned different name (via appending index to the
691
+ eval_dataset_var_name) when autologging.
692
+ (4) _metric_api_call_info, it is a double level map:
693
+ `_metric_api_call_info[run_id][metric_name]` will get a list of tuples, each tuple is:
694
+ (logged_metric_key, metric_call_command_string)
695
+ each call command string is like `metric_fn(arg1, arg2, ...)`
696
+ This data structure is used for:
697
+ * storing the call arguments dict for each metric call, we need log them into metric_info
698
+ artifact file.
699
+
700
+ Note: this class is not thread-safe.
701
+ Design rule for this class:
702
+ Because this class instance is a global instance, in order to prevent memory leak, it should
703
+ only holds IDs and other small objects references. This class internal data structure should
704
+ avoid reference to user dataset variables or model variables.
705
+ """
706
+
707
+ def __init__(self):
708
+ self._pred_result_id_mapping = {}
709
+ self._eval_dataset_info_map = defaultdict(lambda: defaultdict(list))
710
+ self._metric_api_call_info = defaultdict(lambda: defaultdict(list))
711
+ self._log_post_training_metrics_enabled = True
712
+ self._metric_info_artifact_need_update = defaultdict(lambda: False)
713
+ self._model_id_mapping = {}
714
+
715
+ def should_log_post_training_metrics(self):
716
+ """
717
+ Check whether we should run patching code for autologging post training metrics.
718
+ This checking should surround the whole patched code due to the safe guard checking,
719
+ See following note.
720
+
721
+ Note: It includes checking `_SklearnTrainingSession.is_active()`, This is a safe guarding
722
+ for meta-estimator (e.g. GridSearchCV) case:
723
+ running GridSearchCV.fit, the nested `estimator.fit` will be called in parallel,
724
+ but, the _autolog_training_status is a global status without thread-safe lock protecting.
725
+ This safe guarding will prevent code run into this case.
726
+ """
727
+ return not _SklearnTrainingSession.is_active() and self._log_post_training_metrics_enabled
728
+
729
+ def disable_log_post_training_metrics(self):
730
+ class LogPostTrainingMetricsDisabledScope:
731
+ def __enter__(inner_self):
732
+ inner_self.old_status = self._log_post_training_metrics_enabled
733
+ self._log_post_training_metrics_enabled = False
734
+
735
+ def __exit__(inner_self, exc_type, exc_val, exc_tb):
736
+ self._log_post_training_metrics_enabled = inner_self.old_status
737
+
738
+ return LogPostTrainingMetricsDisabledScope()
739
+
740
+ @staticmethod
741
+ def get_run_id_for_model(model):
742
+ return getattr(model, "_mlflow_run_id", None)
743
+
744
+ @staticmethod
745
+ def is_metric_value_loggable(metric_value):
746
+ """
747
+ Check whether the specified `metric_value` is a numeric value which can be logged
748
+ as an MLflow metric.
749
+ """
750
+ return isinstance(metric_value, (int, float, np.number)) and not isinstance(
751
+ metric_value, bool
752
+ )
753
+
754
+ def register_model(self, model, run_id):
755
+ """
756
+ In `patched_fit`, we need register the model with the run_id used in `patched_fit`
757
+ So that in following metric autologging, the metric will be logged into the registered
758
+ run_id
759
+ """
760
+ model._mlflow_run_id = run_id
761
+
762
+ def record_model_id(self, model, model_id):
763
+ """
764
+ Record the id(model) -> model_id mapping so that we can log metrics to the
765
+ model later.
766
+ """
767
+ self._model_id_mapping[id(model)] = model_id
768
+
769
+ def get_model_id_for_model(self, model) -> Optional[str]:
770
+ return self._model_id_mapping.get(id(model))
771
+
772
+ @staticmethod
773
+ def gen_name_with_index(name, index):
774
+ assert index >= 0
775
+ if index == 0:
776
+ return name
777
+ else:
778
+ # Use '-' as the separator between name and index,
779
+ # The '-' is not valid character in python var name
780
+ # so it can prevent name conflicts after appending index.
781
+ return f"{name}-{index + 1}"
782
+
783
+ def register_prediction_input_dataset(self, model, eval_dataset):
784
+ """
785
+ Register prediction input dataset into eval_dataset_info_map, it will do:
786
+ 1. inspect eval dataset var name.
787
+ 2. check whether eval_dataset_info_map already registered this eval dataset.
788
+ will check by object id.
789
+ 3. register eval dataset with id.
790
+ 4. return eval dataset name with index.
791
+
792
+ Note: this method include inspecting argument variable name.
793
+ So should be called directly from the "patched method", to ensure it capture
794
+ correct argument variable name.
795
+ """
796
+ eval_dataset_name = _inspect_original_var_name(
797
+ eval_dataset, fallback_name="unknown_dataset"
798
+ )
799
+ eval_dataset_id = id(eval_dataset)
800
+
801
+ run_id = self.get_run_id_for_model(model)
802
+ registered_dataset_list = self._eval_dataset_info_map[run_id][eval_dataset_name]
803
+
804
+ for i, id_i in enumerate(registered_dataset_list):
805
+ if eval_dataset_id == id_i:
806
+ index = i
807
+ break
808
+ else:
809
+ index = len(registered_dataset_list)
810
+
811
+ if index == len(registered_dataset_list):
812
+ # register new eval dataset
813
+ registered_dataset_list.append(eval_dataset_id)
814
+
815
+ return self.gen_name_with_index(eval_dataset_name, index)
816
+
817
+ def register_prediction_result(self, run_id, eval_dataset_name, predict_result, model_id=None):
818
+ """
819
+ Register the relationship
820
+ id(prediction_result) --> (eval_dataset_name, run_id, model_id)
821
+ into map `_pred_result_id_mapping`
822
+ """
823
+ value = (eval_dataset_name, run_id, model_id)
824
+ prediction_result_id = id(predict_result)
825
+ self._pred_result_id_mapping[prediction_result_id] = value
826
+
827
+ def clean_id(id_):
828
+ _AUTOLOGGING_METRICS_MANAGER._pred_result_id_mapping.pop(id_, None)
829
+
830
+ # When the `predict_result` object being GCed, its ID may be reused, so register a finalizer
831
+ # to clear the ID from the dict for preventing wrong ID mapping.
832
+ weakref.finalize(predict_result, clean_id, prediction_result_id)
833
+
834
+ @staticmethod
835
+ def gen_metric_call_command(self_obj, metric_fn, *call_pos_args, **call_kwargs):
836
+ """
837
+ Generate metric function call command string like `metric_fn(arg1, arg2, ...)`
838
+ Note: this method include inspecting argument variable name.
839
+ So should be called directly from the "patched method", to ensure it capture
840
+ correct argument variable name.
841
+
842
+ Args:
843
+ self_obj: If the metric_fn is a method of an instance (e.g. `model.score`),
844
+ the `self_obj` represent the instance.
845
+ metric_fn: metric function.
846
+ call_pos_args: the positional arguments of the metric function call. If `metric_fn`
847
+ is instance method, then the `call_pos_args` should exclude the first `self`
848
+ argument.
849
+ call_kwargs: the keyword arguments of the metric function call.
850
+ """
851
+
852
+ arg_list = []
853
+
854
+ def arg_to_str(arg):
855
+ if arg is None or np.isscalar(arg):
856
+ if isinstance(arg, str) and len(arg) > 32:
857
+ # truncate too long string
858
+ return repr(arg[:32] + "...")
859
+ return repr(arg)
860
+ else:
861
+ # dataset arguments or other non-scalar type argument
862
+ return _inspect_original_var_name(arg, fallback_name=f"<{arg.__class__.__name__}>")
863
+
864
+ param_sig = inspect.signature(metric_fn).parameters
865
+ arg_names = list(param_sig.keys())
866
+
867
+ if self_obj is not None:
868
+ # If metric_fn is a method of an instance, e.g. `model.score`,
869
+ # then the first argument is `self` which we need exclude it.
870
+ arg_names.pop(0)
871
+
872
+ if self_obj is not None:
873
+ call_fn_name = f"{self_obj.__class__.__name__}.{metric_fn.__name__}"
874
+ else:
875
+ call_fn_name = metric_fn.__name__
876
+
877
+ # Attach param signature key for positinal param values
878
+ for arg_name, arg in zip(arg_names, call_pos_args):
879
+ arg_list.append(f"{arg_name}={arg_to_str(arg)}")
880
+
881
+ for arg_name, arg in call_kwargs.items():
882
+ arg_list.append(f"{arg_name}={arg_to_str(arg)}")
883
+
884
+ arg_list_str = ", ".join(arg_list)
885
+
886
+ return f"{call_fn_name}({arg_list_str})"
887
+
888
+ def register_metric_api_call(self, run_id, metric_name, dataset_name, call_command):
889
+ """
890
+ This method will do:
891
+ (1) Generate and return metric key, format is:
892
+ {metric_name}[-{call_index}]_{eval_dataset_name}
893
+ metric_name is generated by metric function name, if multiple calls on the same
894
+ metric API happen, the following calls will be assigned with an increasing "call index".
895
+ (2) Register the metric key with the "call command" information into
896
+ `_AUTOLOGGING_METRICS_MANAGER`. See doc of `gen_metric_call_command` method for
897
+ details of "call command".
898
+ """
899
+
900
+ call_cmd_list = self._metric_api_call_info[run_id][metric_name]
901
+
902
+ index = len(call_cmd_list)
903
+ metric_name_with_index = self.gen_name_with_index(metric_name, index)
904
+ metric_key = f"{metric_name_with_index}_{dataset_name}"
905
+
906
+ call_cmd_list.append((metric_key, call_command))
907
+
908
+ # Set the flag to true, represent the metric info in this run need update.
909
+ # Later when `log_eval_metric` called, it will generate a new metric_info artifact
910
+ # and overwrite the old artifact.
911
+ self._metric_info_artifact_need_update[run_id] = True
912
+ return metric_key
913
+
914
+ def get_info_for_metric_api_call(self, call_pos_args, call_kwargs):
915
+ """
916
+ Given a metric api call (include the called metric function, and call arguments)
917
+ Register the call information (arguments dict) into the `metric_api_call_arg_dict_list_map`
918
+ and return a tuple of (run_id, eval_dataset_name, model_id)
919
+ """
920
+ call_arg_list = list(call_pos_args) + list(call_kwargs.values())
921
+
922
+ dataset_id_list = self._pred_result_id_mapping.keys()
923
+
924
+ # Note: some metric API the arguments is not like `y_true`, `y_pred`
925
+ # e.g.
926
+ # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score
927
+ # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.silhouette_score.html#sklearn.metrics.silhouette_score
928
+ for arg in call_arg_list:
929
+ if arg is not None and not np.isscalar(arg) and id(arg) in dataset_id_list:
930
+ dataset_name, run_id, model_id = self._pred_result_id_mapping[id(arg)]
931
+ break
932
+ else:
933
+ return None, None, None
934
+
935
+ return run_id, dataset_name, model_id
936
+
937
+ def log_post_training_metric(self, run_id, key, value, model_id=None):
938
+ """
939
+ Log the metric into the specified mlflow run.
940
+ and it will also update the metric_info artifact if needed.
941
+ If model_id is not None, metrics are logged into the model as well.
942
+ """
943
+ # Note: if the case log the same metric key multiple times,
944
+ # newer value will overwrite old value
945
+ client = MlflowClient()
946
+ client.log_metric(run_id=run_id, key=key, value=value, model_id=model_id)
947
+ if self._metric_info_artifact_need_update[run_id]:
948
+ call_commands_list = []
949
+ for v in self._metric_api_call_info[run_id].values():
950
+ call_commands_list.extend(v)
951
+
952
+ call_commands_list.sort(key=lambda x: x[0])
953
+ dict_to_log = OrderedDict(call_commands_list)
954
+ client.log_dict(run_id=run_id, dictionary=dict_to_log, artifact_file="metric_info.json")
955
+ self._metric_info_artifact_need_update[run_id] = False
956
+
957
+
958
+ # The global `_AutologgingMetricsManager` instance which holds information used in
959
+ # post-training metric autologging. See doc of class `_AutologgingMetricsManager` for details.
960
+ _AUTOLOGGING_METRICS_MANAGER = _AutologgingMetricsManager()
961
+
962
+
963
+ _metric_api_excluding_list = ["check_scoring", "get_scorer", "make_scorer", "get_scorer_names"]
964
+
965
+
966
def _get_metric_name_list():
    """Return the names of the metric functions in the `sklearn.metrics` module."""
    from sklearn import metrics

    # Exclude scorer-factory helpers (`_metric_api_excluding_list`), classes
    # (e.g. metrics.ConfusionMatrixDisplay), non-callables, and plot_* helpers.
    return [
        name
        for name in metrics.__all__
        if name not in _metric_api_excluding_list
        and not name.startswith("plot_")
        and not inspect.isclass(getattr(metrics, name))
        and callable(getattr(metrics, name))
    ]
985
+
986
+
987
def _patch_estimator_method_if_available(
    flavor_name, class_def, func_name, patched_fn, manage_run, extra_tags=None
):
    """
    Patch ``class_def.<func_name>`` with `patched_fn` via `safe_patch`, if the
    attribute exists and is of a supported kind; otherwise do nothing.

    Two kinds are supported: plain (or property-decorated) methods, which are
    patched in place, and delegating descriptors exposing a ``delegate_names``
    or ``check`` attribute (presumably sklearn's conditional-method
    descriptors — TODO confirm), whose wrapped ``fn`` attribute is patched
    instead.
    """
    if not hasattr(class_def, func_name):
        return

    original = gorilla.get_original_attribute(
        class_def, func_name, bypass_descriptor_protocol=False
    )
    # Retrieve raw attribute while bypassing the descriptor protocol
    raw_original_obj = gorilla.get_original_attribute(
        class_def, func_name, bypass_descriptor_protocol=True
    )
    # When both lookups return the same object, the attribute does not delegate
    # through a custom descriptor.
    if raw_original_obj == original and (callable(original) or isinstance(original, property)):
        # normal method or property decorated method
        safe_patch(
            flavor_name,
            class_def,
            func_name,
            patched_fn,
            manage_run=manage_run,
            extra_tags=extra_tags,
        )
    elif hasattr(raw_original_obj, "delegate_names") or hasattr(raw_original_obj, "check"):
        # sklearn delegated method: patch the wrapped function on the descriptor.
        safe_patch(
            flavor_name,
            raw_original_obj,
            "fn",
            patched_fn,
            manage_run=manage_run,
            extra_tags=extra_tags,
        )
    else:
        # unsupported method type. skip patching
        pass
1023
+
1024
+
1025
@autologging_integration(FLAVOR_NAME)
def autolog(
    log_input_examples=False,
    log_model_signatures=True,
    log_models=True,
    log_datasets=True,
    disable=False,
    exclusive=False,
    disable_for_unsupported_versions=False,
    silent=False,
    max_tuning_runs=5,
    log_post_training_metrics=True,
    serialization_format=SERIALIZATION_FORMAT_CLOUDPICKLE,
    registered_model_name=None,
    pos_label=None,
    extra_tags=None,
):
    """
    Enables (or disables) and configures autologging for scikit-learn estimators.

    **When is autologging performed?**
    Autologging is performed when you call:

    - ``estimator.fit()``
    - ``estimator.fit_predict()``
    - ``estimator.fit_transform()``

    **Logged information**
    **Parameters**
    - Parameters obtained by ``estimator.get_params(deep=True)``. Note that ``get_params``
      is called with ``deep=True``. This means when you fit a meta estimator that chains
      a series of estimators, the parameters of these child estimators are also logged.

    **Training metrics**
    - A training score obtained by ``estimator.score``. Note that the training score is
      computed using parameters given to ``fit()``.
    - Common metrics for classifier:

      - `precision score`_

      .. _precision score:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html

      - `recall score`_

      .. _recall score:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.recall_score.html

      - `f1 score`_

      .. _f1 score:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html

      - `accuracy score`_

      .. _accuracy score:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html

      If the classifier has method ``predict_proba``, we additionally log:

      - `log loss`_

      .. _log loss:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.log_loss.html

      - `roc auc score`_

      .. _roc auc score:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html

    - Common metrics for regressor:

      - `mean squared error`_

      .. _mean squared error:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_squared_error.html

      - root mean squared error

      - `mean absolute error`_

      .. _mean absolute error:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_absolute_error.html

      - `r2 score`_

      .. _r2 score:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.r2_score.html

    .. _post training metrics:

    **Post training metrics**
    When users call metric APIs after model training, MLflow tries to capture the metric API
    results and log them as MLflow metrics to the Run associated with the model. The following
    types of scikit-learn metric APIs are supported:

    - model.score
    - metric APIs defined in the `sklearn.metrics` module

    For post training metrics autologging, the metric key format is:
    "{metric_name}[-{call_index}]_{dataset_name}"

    - If the metric function is from `sklearn.metrics`, the MLflow "metric_name" is the
      metric function name. If the metric function is `model.score`, then "metric_name" is
      "{model_class_name}_score".
    - If multiple calls are made to the same scikit-learn metric API, each subsequent call
      adds a "call_index" (starting from 2) to the metric key.
    - MLflow uses the prediction input dataset variable name as the "dataset_name" in the
      metric key. The "prediction input dataset variable" refers to the variable which was
      used as the first argument of the associated `model.predict` or `model.score` call.
      Note: MLflow captures the "prediction input dataset" instance in the outermost call
      frame and fetches the variable name in the outermost call frame. If the "prediction
      input dataset" instance is an intermediate expression without a defined variable
      name, the dataset name is set to "unknown_dataset". If multiple "prediction input
      dataset" instances have the same variable name, then subsequent ones will append an
      index (starting from 2) to the inspected dataset name.

    **Limitations**
    - MLflow can only map the original prediction result object returned by a model
      prediction API (including predict / predict_proba / predict_log_proba / transform,
      but excluding fit_predict / fit_transform.) to an MLflow run.
      MLflow cannot find run information
      for other objects derived from a given prediction result (e.g. by copying or selecting
      a subset of the prediction result). scikit-learn metric APIs invoked on derived objects
      do not log metrics to MLflow.
    - Autologging must be enabled before scikit-learn metric APIs are imported from
      `sklearn.metrics`. Metric APIs imported before autologging is enabled do not log
      metrics to MLflow runs.
    - If users define a scorer which is not based on metric APIs in `sklearn.metrics`, then
      post training metric autologging for the scorer is invalid.

    **Tags**
    - An estimator class name (e.g. "LinearRegression").
    - A fully qualified estimator class name
      (e.g. "sklearn.linear_model._base.LinearRegression").

    **Artifacts**
    - An MLflow Model with the :py:mod:`mlflow.sklearn` flavor containing a fitted estimator
      (logged by :py:func:`mlflow.sklearn.log_model()`). The Model also contains the
      :py:mod:`mlflow.pyfunc` flavor when the scikit-learn estimator defines `predict()`.
    - For post training metrics API calls, a "metric_info.json" artifact is logged. This is a
      JSON object whose keys are MLflow post training metric names
      (see "Post training metrics" section for the key format) and whose values are the
      corresponding metric call commands that produced the metrics, e.g.
      ``accuracy_score(y_true=test_iris_y, y_pred=pred_iris_y, normalize=False)``.

    **How does autologging work for meta estimators?**
    When a meta estimator (e.g. `Pipeline`_, `GridSearchCV`_) calls ``fit()``, it internally calls
    ``fit()`` on its child estimators. Autologging does NOT perform logging on these constituent
    ``fit()`` calls.

    **Parameter search**
    In addition to recording the information discussed above, autologging for parameter
    search meta estimators (`GridSearchCV`_ and `RandomizedSearchCV`_) records child runs
    with metrics for each set of explored parameters, as well as artifacts and parameters
    for the best model (if available).

    **Supported estimators**
    - All estimators obtained by `sklearn.utils.all_estimators`_ (including meta estimators).
    - `Pipeline`_
    - Parameter search estimators (`GridSearchCV`_ and `RandomizedSearchCV`_)

    .. _sklearn.utils.all_estimators:
        https://scikit-learn.org/stable/modules/generated/sklearn.utils.all_estimators.html

    .. _Pipeline:
        https://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html

    .. _GridSearchCV:
        https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html

    .. _RandomizedSearchCV:
        https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html

    **Example**

    `See more examples <https://github.com/mlflow/mlflow/blob/master/examples/sklearn_autolog>`_

    .. code-block:: python

        from pprint import pprint
        import numpy as np
        from sklearn.linear_model import LinearRegression
        import mlflow
        from mlflow import MlflowClient


        def fetch_logged_data(run_id):
            client = MlflowClient()
            data = client.get_run(run_id).data
            tags = {k: v for k, v in data.tags.items() if not k.startswith("mlflow.")}
            artifacts = [f.path for f in client.list_artifacts(run_id, "model")]
            return data.params, data.metrics, tags, artifacts


        # enable autologging
        mlflow.sklearn.autolog()

        # prepare training data
        X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
        y = np.dot(X, np.array([1, 2])) + 3

        # train a model
        model = LinearRegression()
        with mlflow.start_run() as run:
            model.fit(X, y)

        # fetch logged data
        params, metrics, tags, artifacts = fetch_logged_data(run.info.run_id)

        pprint(params)
        # {'copy_X': 'True',
        #  'fit_intercept': 'True',
        #  'n_jobs': 'None',
        #  'normalize': 'False'}

        pprint(metrics)
        # {'training_score': 1.0,
        #  'training_mean_absolute_error': 2.220446049250313e-16,
        #  'training_mean_squared_error': 1.9721522630525295e-31,
        #  'training_r2_score': 1.0,
        #  'training_root_mean_squared_error': 4.440892098500626e-16}

        pprint(tags)
        # {'estimator_class': 'sklearn.linear_model._base.LinearRegression',
        #  'estimator_name': 'LinearRegression'}

        pprint(artifacts)
        # ['model/MLmodel', 'model/conda.yaml', 'model/model.pkl']

    Args:
        log_input_examples: If ``True``, input examples from training datasets are collected and
            logged along with scikit-learn model artifacts during training. If
            ``False``, input examples are not logged.
            Note: Input examples are MLflow model attributes
            and are only collected if ``log_models`` is also ``True``.
        log_model_signatures: If ``True``,
            :py:class:`ModelSignatures <mlflow.models.ModelSignature>`
            describing model inputs and outputs are collected and logged along
            with scikit-learn model artifacts during training. If ``False``,
            signatures are not logged.
            Note: Model signatures are MLflow model attributes
            and are only collected if ``log_models`` is also ``True``.
        log_models: If ``True``, trained models are logged as MLflow model artifacts.
            If ``False``, trained models are not logged.
            Input examples and model signatures, which are attributes of MLflow models,
            are also omitted when ``log_models`` is ``False``.
        log_datasets: If ``True``, train and validation dataset information is logged to MLflow
            Tracking if applicable. If ``False``, dataset information is not logged.
        disable: If ``True``, disables the scikit-learn autologging integration. If ``False``,
            enables the scikit-learn autologging integration.
        exclusive: If ``True``, autologged content is not logged to user-created fluent runs.
            If ``False``, autologged content is logged to the active fluent run,
            which may be user-created.
        disable_for_unsupported_versions: If ``True``, disable autologging for versions of
            scikit-learn that have not been tested against this version of the MLflow
            client or are incompatible.
        silent: If ``True``, suppress all event logs and warnings from MLflow during scikit-learn
            autologging. If ``False``, show all events and warnings during scikit-learn
            autologging.
        max_tuning_runs: The maximum number of child MLflow runs created for hyperparameter
            search estimators. To create child runs for the best `k` results from
            the search, set `max_tuning_runs` to `k`. The default value is to track
            the best 5 search parameter sets. If `max_tuning_runs=None`, then
            a child run is created for each search parameter set. Note: The best k
            results is based on ordering in `rank_test_score`. In the case of
            multi-metric evaluation with a custom scorer, the first scorer's
            `rank_test_score_<scorer_name>` will be used to select the best k
            results. To change metric used for selecting best k results, change
            ordering of dict passed as `scoring` parameter for estimator.
        log_post_training_metrics: If ``True``, post training metrics are logged. Defaults to
            ``True``. See the `post training metrics`_ section for more
            details.
        serialization_format: The format in which to serialize the model. This should be one of
            the following: ``mlflow.sklearn.SERIALIZATION_FORMAT_PICKLE`` or
            ``mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE``.
        registered_model_name: If given, each time a model is trained, it is registered as a
            new model version of the registered model with this name.
            The registered model is created if it does not already exist.
        pos_label: If given, used as the positive label to compute binary classification
            training metrics such as precision, recall, f1, etc. This parameter should
            only be set for binary classification model. If used for multi-label model,
            the training metrics calculation will fail and the training metrics won't
            be logged. If used for regression model, the parameter will be ignored.
        extra_tags: A dictionary of extra tags to set on each managed run created by autologging.
    """
    # NOTE: `registered_model_name` is intentionally not forwarded here; it is
    # stored by `@autologging_integration` and read back at model-logging time
    # via `get_autologging_config(FLAVOR_NAME, "registered_model_name", None)`.
    _autolog(
        flavor_name=FLAVOR_NAME,
        log_input_examples=log_input_examples,
        log_model_signatures=log_model_signatures,
        log_models=log_models,
        log_datasets=log_datasets,
        disable=disable,
        exclusive=exclusive,
        disable_for_unsupported_versions=disable_for_unsupported_versions,
        silent=silent,
        max_tuning_runs=max_tuning_runs,
        log_post_training_metrics=log_post_training_metrics,
        serialization_format=serialization_format,
        pos_label=pos_label,
        extra_tags=extra_tags,
    )
1327
+
1328
+
1329
+ def _autolog( # noqa: D417
1330
+ flavor_name=FLAVOR_NAME,
1331
+ log_input_examples=False,
1332
+ log_model_signatures=True,
1333
+ log_models=True,
1334
+ log_datasets=True,
1335
+ disable=False,
1336
+ exclusive=False,
1337
+ disable_for_unsupported_versions=False,
1338
+ silent=False,
1339
+ max_tuning_runs=5,
1340
+ log_post_training_metrics=True,
1341
+ serialization_format=SERIALIZATION_FORMAT_CLOUDPICKLE,
1342
+ pos_label=None,
1343
+ extra_tags=None,
1344
+ ):
1345
+ """
1346
+ Internal autologging function for scikit-learn models.
1347
+
1348
+ Args:
1349
+ flavor_name: A string value. Enable a ``mlflow.sklearn`` autologging routine
1350
+ for a flavor. By default it enables autologging for original
1351
+ scikit-learn models, as ``mlflow.sklearn.autolog()`` does. If
1352
+ the argument is `xgboost`, autologging for XGBoost scikit-learn
1353
+ models is enabled.
1354
+ """
1355
+ import pandas as pd
1356
+ import sklearn
1357
+ import sklearn.metrics
1358
+ import sklearn.model_selection
1359
+
1360
+ from mlflow.models import infer_signature
1361
+ from mlflow.sklearn.utils import (
1362
+ _TRAINING_PREFIX,
1363
+ _create_child_runs_for_parameter_search,
1364
+ _gen_lightgbm_sklearn_estimators_to_patch,
1365
+ _gen_xgboost_sklearn_estimators_to_patch,
1366
+ _get_estimator_info_tags,
1367
+ _get_X_y_and_sample_weight,
1368
+ _is_parameter_search_estimator,
1369
+ _log_estimator_content,
1370
+ _log_parameter_search_results_as_artifact,
1371
+ )
1372
+ from mlflow.tracking.context import registry as context_registry
1373
+
1374
+ if max_tuning_runs is not None and max_tuning_runs < 0:
1375
+ raise MlflowException(
1376
+ message=f"`max_tuning_runs` must be non-negative, instead got {max_tuning_runs}.",
1377
+ error_code=INVALID_PARAMETER_VALUE,
1378
+ )
1379
+
1380
+ def fit_mlflow_xgboost_and_lightgbm(original, self, *args, **kwargs):
1381
+ """
1382
+ Autologging function for XGBoost and LightGBM scikit-learn models
1383
+ """
1384
+ # Obtain a copy of a model input example from the training dataset prior to model training
1385
+ # for subsequent use during model logging, ensuring that the input example and inferred
1386
+ # model signature to not include any mutations from model training
1387
+ input_example_exc = None
1388
+ try:
1389
+ input_example = deepcopy(
1390
+ _get_X_y_and_sample_weight(self.fit, args, kwargs)[0][:INPUT_EXAMPLE_SAMPLE_ROWS]
1391
+ )
1392
+ except Exception as e:
1393
+ input_example_exc = e
1394
+
1395
+ def get_input_example():
1396
+ if input_example_exc is not None:
1397
+ raise input_example_exc
1398
+ else:
1399
+ return input_example
1400
+
1401
+ # parameter, metric, and non-model artifact logging are done in
1402
+ # `train()` in `mlflow.xgboost.autolog()` and `mlflow.lightgbm.autolog()`
1403
+ fit_output = original(self, *args, **kwargs)
1404
+ # log models after training
1405
+ if log_models:
1406
+ input_example, signature = resolve_input_example_and_signature(
1407
+ get_input_example,
1408
+ lambda input_example: infer_signature(
1409
+ input_example,
1410
+ # Copy the input example so that it is not mutated by the call to
1411
+ # predict() prior to signature inference
1412
+ self.predict(deepcopy(input_example)),
1413
+ ),
1414
+ log_input_examples,
1415
+ log_model_signatures,
1416
+ _logger,
1417
+ )
1418
+ log_model_func = (
1419
+ mlflow.xgboost.log_model
1420
+ if flavor_name == mlflow.xgboost.FLAVOR_NAME
1421
+ else mlflow.lightgbm.log_model
1422
+ )
1423
+ registered_model_name = get_autologging_config(
1424
+ flavor_name, "registered_model_name", None
1425
+ )
1426
+ if flavor_name == mlflow.xgboost.FLAVOR_NAME:
1427
+ model_format = get_autologging_config(flavor_name, "model_format", "xgb")
1428
+ model_info = log_model_func(
1429
+ self,
1430
+ "model",
1431
+ signature=signature,
1432
+ input_example=input_example,
1433
+ registered_model_name=registered_model_name,
1434
+ model_format=model_format,
1435
+ )
1436
+ else:
1437
+ model_info = log_model_func(
1438
+ self,
1439
+ "model",
1440
+ signature=signature,
1441
+ input_example=input_example,
1442
+ registered_model_name=registered_model_name,
1443
+ )
1444
+ _AUTOLOGGING_METRICS_MANAGER.record_model_id(self, model_info.model_id)
1445
+ return fit_output
1446
+
1447
+ def fit_mlflow(original, self, *args, **kwargs):
1448
+ """
1449
+ Autologging function that performs model training by executing the training method
1450
+ referred to be `func_name` on the instance of `clazz` referred to by `self` & records
1451
+ MLflow parameters, metrics, tags, and artifacts to a corresponding MLflow Run.
1452
+ """
1453
+ # Obtain a copy of the training dataset prior to model training for subsequent
1454
+ # use during model logging & input example extraction, ensuring that we don't
1455
+ # attempt to infer input examples on data that was mutated during training
1456
+ (X, y_true, sample_weight) = _get_X_y_and_sample_weight(self.fit, args, kwargs)
1457
+ autologging_client = MlflowAutologgingQueueingClient()
1458
+ _log_pretraining_metadata(autologging_client, self, X, y_true)
1459
+ params_logging_future = autologging_client.flush(synchronous=False)
1460
+ fit_output = original(self, *args, **kwargs)
1461
+ _log_posttraining_metadata(autologging_client, self, X, y_true, sample_weight)
1462
+ autologging_client.flush(synchronous=True)
1463
+ params_logging_future.await_completion()
1464
+ return fit_output
1465
+
1466
+ def _log_pretraining_metadata(autologging_client, estimator, X, y): # noqa: D417
1467
+ """
1468
+ Records metadata (e.g., params and tags) for a scikit-learn estimator prior to training.
1469
+ This is intended to be invoked within a patched scikit-learn training routine
1470
+ (e.g., `fit()`, `fit_transform()`, ...) and assumes the existence of an active
1471
+ MLflow run that can be referenced via the fluent Tracking API.
1472
+
1473
+ Args:
1474
+ autologging_client: An instance of `MlflowAutologgingQueueingClient` used for
1475
+ efficiently logging run data to MLflow Tracking.
1476
+ estimator: The scikit-learn estimator for which to log metadata.
1477
+ """
1478
+ # Deep parameter logging includes parameters from children of a given
1479
+ # estimator. For some meta estimators (e.g., pipelines), recording
1480
+ # these parameters is desirable. For parameter search estimators,
1481
+ # however, child estimators act as seeds for the parameter search
1482
+ # process; accordingly, we avoid logging initial, untuned parameters
1483
+ # for these seed estimators.
1484
+ should_log_params_deeply = not _is_parameter_search_estimator(estimator)
1485
+ run_id = mlflow.active_run().info.run_id
1486
+ autologging_client.log_params(
1487
+ run_id=mlflow.active_run().info.run_id,
1488
+ params=estimator.get_params(deep=should_log_params_deeply),
1489
+ )
1490
+ autologging_client.set_tags(
1491
+ run_id=run_id,
1492
+ tags=_get_estimator_info_tags(estimator),
1493
+ )
1494
+
1495
+ if log_datasets:
1496
+ try:
1497
+ context_tags = context_registry.resolve_tags()
1498
+ source = CodeDatasetSource(context_tags)
1499
+
1500
+ dataset = _create_dataset(X, source, y)
1501
+ if dataset:
1502
+ tags = [InputTag(key=MLFLOW_DATASET_CONTEXT, value="train")]
1503
+ dataset_input = DatasetInput(dataset=dataset._to_mlflow_entity(), tags=tags)
1504
+
1505
+ autologging_client.log_inputs(
1506
+ run_id=mlflow.active_run().info.run_id, datasets=[dataset_input]
1507
+ )
1508
+ except Exception as e:
1509
+ _logger.warning(
1510
+ "Failed to log training dataset information to MLflow Tracking. Reason: %s", e
1511
+ )
1512
+
1513
    def _log_posttraining_metadata(autologging_client, estimator, X, y, sample_weight):
        """
        Records metadata for a scikit-learn estimator after training has completed.
        This is intended to be invoked within a patched scikit-learn training routine
        (e.g., `fit()`, `fit_transform()`, ...) and assumes the existence of an active
        MLflow run that can be referenced via the fluent Tracking API.

        Logged content (each step is best-effort and independently guarded):
          * the fitted model (when `log_models` is enabled), with optional input
            example and inferred signature;
          * training metrics/artifacts via `_log_estimator_content`;
          * for parameter-search estimators (e.g. GridSearchCV): the best estimator,
            `best_cv_score`, `best_*` params, per-candidate child runs, and the
            `cv_results_` table as an artifact.

        Args:
            autologging_client: An instance of `MlflowAutologgingQueueingClient` used for
                efficiently logging run data to MLflow Tracking.
            estimator: The scikit-learn estimator for which to log metadata.
            X: The training dataset samples passed to the ``estimator.fit()`` function.
            y: The training dataset labels passed to the ``estimator.fit()`` function.
            sample_weight: Sample weights passed to the ``estimator.fit()`` function.
        """
        # Fetch an input example using the first several rows of the array-like
        # training data supplied to the training routine (e.g., `fit()`). Copy the
        # example to avoid mutation during subsequent metric computations
        input_example_exc = None
        try:
            input_example = deepcopy(X[:INPUT_EXAMPLE_SAMPLE_ROWS])
        except Exception as e:
            # Defer the failure: only raise if/when an input example is actually
            # requested by `resolve_input_example_and_signature` below.
            input_example_exc = e

        def get_input_example():
            if input_example_exc is not None:
                raise input_example_exc
            else:
                return input_example

        def infer_model_signature(input_example):
            # Prefer `predict` output for signature inference; fall back to
            # `transform` for transformer-only estimators.
            if hasattr(estimator, "predict"):
                # Copy the input example so that it is not mutated by the call to
                # predict() prior to signature inference
                model_output = estimator.predict(deepcopy(input_example))
            elif hasattr(estimator, "transform"):
                model_output = estimator.transform(deepcopy(input_example))
            else:
                raise Exception(
                    "the trained model does not have a `predict` or `transform` "
                    "function, which is required in order to infer the signature"
                )

            return infer_signature(input_example, model_output)

        def _log_model_with_except_handling(*args, **kwargs):
            # Returns the logged-model info, or None when the model cannot be
            # pickled (the warning is surfaced instead of failing the fit).
            try:
                return log_model(*args, **kwargs)
            except _SklearnCustomModelPicklingError as e:
                _logger.warning(str(e))

        model_id = None
        if log_models:
            # Will only resolve `input_example` and `signature` if `log_models` is `True`.
            input_example, signature = resolve_input_example_and_signature(
                get_input_example,
                infer_model_signature,
                log_input_examples,
                log_model_signatures,
                _logger,
            )
            registered_model_name = get_autologging_config(
                FLAVOR_NAME, "registered_model_name", None
            )
            # For parameter-search estimators, shallow params avoid duplicating the
            # (potentially huge) nested candidate-estimator parameter space.
            should_log_params_deeply = not _is_parameter_search_estimator(estimator)
            params = estimator.get_params(deep=should_log_params_deeply)
            if hasattr(estimator, "best_params_"):
                params |= {
                    f"best_{param_name}": param_value
                    for param_name, param_value in estimator.best_params_.items()
                }
            if logged_model := _log_model_with_except_handling(
                estimator,
                name="model",
                signature=signature,
                input_example=input_example,
                serialization_format=serialization_format,
                registered_model_name=registered_model_name,
                params=params,
            ):
                model_id = logged_model.model_id
                _AUTOLOGGING_METRICS_MANAGER.record_model_id(estimator, logged_model.model_id)

        # log common metrics and artifacts for estimators (classifier, regressor)
        context_tags = context_registry.resolve_tags()
        source = CodeDatasetSource(context_tags)
        try:
            dataset = _create_dataset(X, source, y)
        except Exception:
            # Dataset creation is best-effort; metrics below tolerate dataset=None.
            _logger.debug("Failed to create dataset for logging.", exc_info=True)
            dataset = None
        logged_metrics = _log_estimator_content(
            autologging_client=autologging_client,
            estimator=estimator,
            prefix=_TRAINING_PREFIX,
            run_id=mlflow.active_run().info.run_id,
            X=X,
            y_true=y,
            sample_weight=sample_weight,
            pos_label=pos_label,
            dataset=dataset,
            model_id=model_id,
        )
        if y is None and not logged_metrics:
            _logger.warning(
                "Training metrics will not be recorded because training labels were not specified."
                " To automatically record training metrics, provide training labels as inputs to"
                " the model training function."
            )

        best_estimator_model_id = None
        best_estimator_params = None
        if _is_parameter_search_estimator(estimator):
            # NOTE: `signature` / `input_example` are only defined when `log_models`
            # is True; this branch is guarded accordingly.
            if hasattr(estimator, "best_estimator_") and log_models:
                best_estimator_params = estimator.best_estimator_.get_params(deep=True)
                if model_info := _log_model_with_except_handling(
                    estimator.best_estimator_,
                    name="best_estimator",
                    signature=signature,
                    input_example=input_example,
                    serialization_format=serialization_format,
                    params=best_estimator_params,
                ):
                    best_estimator_model_id = model_info.model_id

            if hasattr(estimator, "best_score_"):
                autologging_client.log_metrics(
                    run_id=mlflow.active_run().info.run_id,
                    metrics={"best_cv_score": estimator.best_score_},
                    dataset=dataset,
                    model_id=model_id,
                )

            if hasattr(estimator, "best_params_"):
                best_params = {
                    f"best_{param_name}": param_value
                    for param_name, param_value in estimator.best_params_.items()
                }
                autologging_client.log_params(
                    run_id=mlflow.active_run().info.run_id,
                    params=best_params,
                )

            if hasattr(estimator, "cv_results_"):
                try:
                    # Fetch environment-specific tags (e.g., user and source) to ensure that lineage
                    # information is consistent with the parent run
                    child_tags = context_registry.resolve_tags()
                    child_tags.update({MLFLOW_AUTOLOGGING: flavor_name})
                    _create_child_runs_for_parameter_search(
                        autologging_client=autologging_client,
                        cv_estimator=estimator,
                        parent_run=mlflow.active_run(),
                        max_tuning_runs=max_tuning_runs,
                        child_tags=child_tags,
                        dataset=dataset,
                        best_estimator_params=best_estimator_params,
                        best_estimator_model_id=best_estimator_model_id,
                    )
                except Exception as e:
                    _logger.warning(
                        "Encountered exception during creation of child runs for parameter search."
                        f" Child runs may be missing. Exception: {e}"
                    )

                try:
                    cv_results_df = pd.DataFrame.from_dict(estimator.cv_results_)
                    _log_parameter_search_results_as_artifact(
                        cv_results_df, mlflow.active_run().info.run_id
                    )
                except Exception as e:
                    _logger.warning(
                        f"Failed to log parameter search results as an artifact. Exception: {e}"
                    )
1688
+ def patched_fit(fit_impl, allow_children_patch, original, self, *args, **kwargs):
1689
+ """
1690
+ Autologging patch function to be applied to a sklearn model class that defines a `fit`
1691
+ method and inherits from `BaseEstimator` (thereby defining the `get_params()` method)
1692
+
1693
+ Args:
1694
+ fit_impl: The patched fit function implementation, the function should be defined as
1695
+ `fit_mlflow(original, self, *args, **kwargs)`, the `original` argument
1696
+ refers to the original `EstimatorClass.fit` method, the `self` argument
1697
+ refers to the estimator instance being patched, the `*args` and
1698
+ `**kwargs` are arguments passed to the original fit method.
1699
+ allow_children_patch: Whether to allow children sklearn session logging or not.
1700
+ original: the original `EstimatorClass.fit` method to be patched.
1701
+ self: the estimator instance being patched.
1702
+ args: positional arguments to be passed to the original fit method.
1703
+ kwargs: keyword arguments to be passed to the original fit method.
1704
+ """
1705
+ should_log_post_training_metrics = (
1706
+ log_post_training_metrics
1707
+ and _AUTOLOGGING_METRICS_MANAGER.should_log_post_training_metrics()
1708
+ )
1709
+
1710
+ with _SklearnTrainingSession(estimator=self, allow_children=allow_children_patch) as t:
1711
+ if t.should_log():
1712
+ # In `fit_mlflow` call, it will also call metric API for computing training metrics
1713
+ # so we need temporarily disable the post_training_metrics patching.
1714
+ with _AUTOLOGGING_METRICS_MANAGER.disable_log_post_training_metrics():
1715
+ result = fit_impl(original, self, *args, **kwargs)
1716
+ if should_log_post_training_metrics:
1717
+ _AUTOLOGGING_METRICS_MANAGER.register_model(
1718
+ self, mlflow.active_run().info.run_id
1719
+ )
1720
+ return result
1721
+ else:
1722
+ return original(self, *args, **kwargs)
1723
+
1724
    def patched_predict(original, self, *args, **kwargs):
        """
        In `patched_predict`, register the prediction result instance with the run id and
        eval dataset name. e.g.
        ```
        prediction_result = model_1.predict(eval_X)
        ```
        then we need register the following relationship into the `_AUTOLOGGING_METRICS_MANAGER`:
        id(prediction_result) --> (eval_dataset_name, run_id)

        Note: we cannot set additional attributes "eval_dataset_name" and "run_id" into
        the prediction_result object, because certain dataset type like numpy does not support
        additional attribute assignment.
        """
        # Only models previously registered by `patched_fit` have an associated run.
        run_id = _AUTOLOGGING_METRICS_MANAGER.get_run_id_for_model(self)
        if _AUTOLOGGING_METRICS_MANAGER.should_log_post_training_metrics() and run_id:
            # Avoid nested patch when nested inference calls happens.
            with _AUTOLOGGING_METRICS_MANAGER.disable_log_post_training_metrics():
                predict_result = original(self, *args, **kwargs)
            # The evaluation dataset is the first argument of the patched
            # inference method (e.g. `X` in `predict(X)`).
            eval_dataset = get_instance_method_first_arg_value(original, args, kwargs)
            eval_dataset_name = _AUTOLOGGING_METRICS_MANAGER.register_prediction_input_dataset(
                self, eval_dataset
            )
            _AUTOLOGGING_METRICS_MANAGER.register_prediction_result(
                run_id,
                eval_dataset_name,
                predict_result,
                model_id=_AUTOLOGGING_METRICS_MANAGER.get_model_id_for_model(self),
            )
            if log_datasets:
                # Best-effort: a failure to log the eval dataset must never break
                # the user's predict() call.
                try:
                    context_tags = context_registry.resolve_tags()
                    source = CodeDatasetSource(context_tags)

                    dataset = _create_dataset(eval_dataset, source)

                    # log the dataset
                    if dataset:
                        tags = [InputTag(key=MLFLOW_DATASET_CONTEXT, value="eval")]
                        dataset_input = DatasetInput(dataset=dataset._to_mlflow_entity(), tags=tags)

                        # log the dataset
                        client = mlflow.MlflowClient()
                        client.log_inputs(run_id=run_id, datasets=[dataset_input])
                except Exception as e:
                    _logger.warning(
                        "Failed to log evaluation dataset information to "
                        "MLflow Tracking. Reason: %s",
                        e,
                    )
            return predict_result
        else:
            return original(self, *args, **kwargs)
1777
+
1778
    def patched_metric_api(original, *args, **kwargs):
        """
        Patch applied to `sklearn.metrics` functions (and scorer `_score_func`s):
        computes the metric as usual, then — when the metric's arguments can be
        matched to a previously registered prediction result — logs the value as a
        post-training metric on the corresponding run.
        """
        if _AUTOLOGGING_METRICS_MANAGER.should_log_post_training_metrics():
            # one metric api may call another metric api,
            # to avoid this, call disable_log_post_training_metrics to avoid nested patch
            with _AUTOLOGGING_METRICS_MANAGER.disable_log_post_training_metrics():
                metric = original(*args, **kwargs)

            # Skip non-scalar / non-loggable metric values (e.g. confusion matrices).
            if _AUTOLOGGING_METRICS_MANAGER.is_metric_value_loggable(metric):
                metric_name = original.__name__
                # A reproducible textual form of the call, recorded alongside the key.
                call_command = _AUTOLOGGING_METRICS_MANAGER.gen_metric_call_command(
                    None, original, *args, **kwargs
                )

                # Resolve which (run, dataset, model) this metric call refers to by
                # matching its arguments against registered prediction results.
                (run_id, dataset_name, model_id) = (
                    _AUTOLOGGING_METRICS_MANAGER.get_info_for_metric_api_call(args, kwargs)
                )
                if run_id and dataset_name:
                    metric_key = _AUTOLOGGING_METRICS_MANAGER.register_metric_api_call(
                        run_id, metric_name, dataset_name, call_command
                    )
                    _AUTOLOGGING_METRICS_MANAGER.log_post_training_metric(
                        run_id, metric_key, metric, model_id=model_id
                    )

            return metric
        else:
            return original(*args, **kwargs)
1805
+
1806
    # we need patch model.score method because:
    # some model.score() implementation won't call metric APIs in `sklearn.metrics`
    # e.g.
    # https://github.com/scikit-learn/scikit-learn/blob/82df48934eba1df9a1ed3be98aaace8eada59e6e/sklearn/covariance/_empirical_covariance.py#L220
    def patched_model_score(original, self, *args, **kwargs):
        """
        Patch for `estimator.score` that logs the returned score as a post-training
        metric (named `<EstimatorClass>_score`) on the run associated with the model.
        """
        run_id = _AUTOLOGGING_METRICS_MANAGER.get_run_id_for_model(self)
        if _AUTOLOGGING_METRICS_MANAGER.should_log_post_training_metrics() and run_id:
            # `model.score` may call metric APIs internally, in order to prevent nested metric call
            # being logged, temporarily disable post_training_metrics patching.
            with _AUTOLOGGING_METRICS_MANAGER.disable_log_post_training_metrics():
                score_value = original(self, *args, **kwargs)

            # Skip values that cannot be logged as metrics (e.g. non-scalars).
            if _AUTOLOGGING_METRICS_MANAGER.is_metric_value_loggable(score_value):
                metric_name = f"{self.__class__.__name__}_score"
                call_command = _AUTOLOGGING_METRICS_MANAGER.gen_metric_call_command(
                    self, original, *args, **kwargs
                )

                # The evaluation dataset is the first argument of `score(X, ...)`.
                eval_dataset = get_instance_method_first_arg_value(original, args, kwargs)
                eval_dataset_name = _AUTOLOGGING_METRICS_MANAGER.register_prediction_input_dataset(
                    self, eval_dataset
                )
                metric_key = _AUTOLOGGING_METRICS_MANAGER.register_metric_api_call(
                    run_id, metric_name, eval_dataset_name, call_command
                )
                model_id = _AUTOLOGGING_METRICS_MANAGER.get_model_id_for_model(self)
                _AUTOLOGGING_METRICS_MANAGER.log_post_training_metric(
                    run_id, metric_key, score_value, model_id=model_id
                )

            return score_value
        else:
            return original(self, *args, **kwargs)
1839
+
1840
    def _apply_sklearn_descriptor_unbound_method_call_fix():
        """
        Hot patch for https://github.com/scikit-learn/scikit-learn/issues/20614:
        on scikit-learn <= 0.24.2, `_IffHasAttrDescriptor.__get__` does not support
        unbound method calls, which breaks monkeypatching of descriptor-decorated
        estimator methods. No-op on newer sklearn versions.
        """
        import sklearn

        if Version(sklearn.__version__) <= Version("0.24.2"):
            import sklearn.utils.metaestimators

            # Defensive: the private class may not exist in every build.
            if not hasattr(sklearn.utils.metaestimators, "_IffHasAttrDescriptor"):
                return

            def patched_IffHasAttrDescriptor__get__(self, obj, type=None):
                """
                For sklearn version <= 0.24.2, `_IffHasAttrDescriptor.__get__` method does not
                support unbound method call.
                See https://github.com/scikit-learn/scikit-learn/issues/20614
                This patched function is for hot patch.
                """

                # raise an AttributeError if the attribute is not present on the object
                if obj is not None:
                    # delegate only on instances, not the classes.
                    # this is to allow access to the docstrings.
                    for delegate_name in self.delegate_names:
                        try:
                            delegate = sklearn.utils.metaestimators.attrgetter(delegate_name)(obj)
                        except AttributeError:
                            continue
                        else:
                            getattr(delegate, self.attribute_name)
                            break
                    else:
                        # No delegate had the attribute: re-run attrgetter on the
                        # last delegate name so its AttributeError propagates.
                        sklearn.utils.metaestimators.attrgetter(self.delegate_names[-1])(obj)

                    def out(*args, **kwargs):
                        return self.fn(obj, *args, **kwargs)

                else:
                    # This makes it possible to use the decorated method as an unbound method,
                    # for instance when monkeypatching.
                    def out(*args, **kwargs):
                        return self.fn(*args, **kwargs)

                # update the docstring of the returned function
                functools.update_wrapper(out, self.fn)
                return out

            update_wrapper_extended(
                patched_IffHasAttrDescriptor__get__,
                sklearn.utils.metaestimators._IffHasAttrDescriptor.__get__,
            )

            sklearn.utils.metaestimators._IffHasAttrDescriptor.__get__ = (
                patched_IffHasAttrDescriptor__get__
            )
1893
+
1894
    _apply_sklearn_descriptor_unbound_method_call_fix()

    # Select the estimator classes to patch and the fit implementation, based on
    # which flavor requested sklearn autologging. XGBoost/LightGBM sklearn-API
    # estimators share a dedicated fit implementation and allow child sessions.
    if flavor_name == mlflow.xgboost.FLAVOR_NAME:
        estimators_to_patch = _gen_xgboost_sklearn_estimators_to_patch()
        patched_fit_impl = fit_mlflow_xgboost_and_lightgbm
        allow_children_patch = True
    elif flavor_name == mlflow.lightgbm.FLAVOR_NAME:
        estimators_to_patch = _gen_lightgbm_sklearn_estimators_to_patch()
        patched_fit_impl = fit_mlflow_xgboost_and_lightgbm
        allow_children_patch = True
    else:
        estimators_to_patch = _gen_estimators_to_patch()
        patched_fit_impl = fit_mlflow
        allow_children_patch = False

    for class_def in estimators_to_patch:
        # Patch fitting methods
        for func_name in ["fit", "fit_transform", "fit_predict"]:
            _patch_estimator_method_if_available(
                flavor_name,
                class_def,
                func_name,
                functools.partial(patched_fit, patched_fit_impl, allow_children_patch),
                manage_run=True,
                extra_tags=extra_tags,
            )

        # Patch inference methods
        for func_name in ["predict", "predict_proba", "transform", "predict_log_proba"]:
            _patch_estimator_method_if_available(
                flavor_name,
                class_def,
                func_name,
                patched_predict,
                manage_run=False,
            )

        # Patch scoring methods
        _patch_estimator_method_if_available(
            flavor_name,
            class_def,
            "score",
            patched_model_score,
            manage_run=False,
            extra_tags=extra_tags,
        )

    if log_post_training_metrics:
        # Patch every function in `sklearn.metrics` so post-training metric calls
        # are captured by `patched_metric_api`.
        for metric_name in _get_metric_name_list():
            safe_patch(
                flavor_name, sklearn.metrics, metric_name, patched_metric_api, manage_run=False
            )

        # `sklearn.metrics.SCORERS` was removed in scikit-learn 1.3
        if hasattr(sklearn.metrics, "get_scorer_names"):
            for scoring in sklearn.metrics.get_scorer_names():
                scorer = sklearn.metrics.get_scorer(scoring)
                safe_patch(flavor_name, scorer, "_score_func", patched_metric_api, manage_run=False)
        else:
            for scorer in sklearn.metrics.SCORERS.values():
                safe_patch(flavor_name, scorer, "_score_func", patched_metric_api, manage_run=False)
1955
+
1956
+ def patched_fn_with_autolog_disabled(original, *args, **kwargs):
1957
+ with disable_autologging():
1958
+ return original(*args, **kwargs)
1959
+
1960
    # Patch each API listed in `_apis_autologging_disabled` (functions in
    # `sklearn.model_selection`) so that it executes with autologging disabled.
    for disable_autolog_func_name in _apis_autologging_disabled:
        safe_patch(
            flavor_name,
            sklearn.model_selection,
            disable_autolog_func_name,
            patched_fn_with_autolog_disabled,
            manage_run=False,
        )
1968
+
1969
+ def _create_dataset(X, source, y=None, dataset_name=None):
1970
+ # create a dataset
1971
+ from scipy.sparse import issparse
1972
+
1973
+ if isinstance(X, pd.DataFrame):
1974
+ dataset = from_pandas(df=X, source=source)
1975
+ elif issparse(X):
1976
+ arr_X = X.toarray()
1977
+ if y is not None:
1978
+ dataset = from_numpy(
1979
+ features=arr_X,
1980
+ targets=y.toarray() if issparse(y) else y,
1981
+ source=source,
1982
+ name=dataset_name,
1983
+ )
1984
+ else:
1985
+ dataset = from_numpy(features=arr_X, source=source, name=dataset_name)
1986
+ elif isinstance(X, np.ndarray):
1987
+ if y is not None:
1988
+ dataset = from_numpy(features=X, targets=y, source=source, name=dataset_name)
1989
+ else:
1990
+ dataset = from_numpy(features=X, source=source, name=dataset_name)
1991
+ else:
1992
+ _logger.warning("Unrecognized dataset type %s. Dataset logging skipped.", type(X))
1993
+ return None
1994
+ return dataset