genesis_flow-1.0.0-py3-none-any.whl
- genesis_flow-1.0.0.dist-info/METADATA +822 -0
- genesis_flow-1.0.0.dist-info/RECORD +645 -0
- genesis_flow-1.0.0.dist-info/WHEEL +5 -0
- genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
- genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
- genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
- mlflow/__init__.py +367 -0
- mlflow/__main__.py +3 -0
- mlflow/ag2/__init__.py +56 -0
- mlflow/ag2/ag2_logger.py +294 -0
- mlflow/anthropic/__init__.py +40 -0
- mlflow/anthropic/autolog.py +129 -0
- mlflow/anthropic/chat.py +144 -0
- mlflow/artifacts/__init__.py +268 -0
- mlflow/autogen/__init__.py +144 -0
- mlflow/autogen/chat.py +142 -0
- mlflow/azure/__init__.py +26 -0
- mlflow/azure/auth_handler.py +257 -0
- mlflow/azure/client.py +319 -0
- mlflow/azure/config.py +120 -0
- mlflow/azure/connection_factory.py +340 -0
- mlflow/azure/exceptions.py +27 -0
- mlflow/azure/stores.py +327 -0
- mlflow/azure/utils.py +183 -0
- mlflow/bedrock/__init__.py +45 -0
- mlflow/bedrock/_autolog.py +202 -0
- mlflow/bedrock/chat.py +122 -0
- mlflow/bedrock/stream.py +160 -0
- mlflow/bedrock/utils.py +43 -0
- mlflow/cli.py +707 -0
- mlflow/client.py +12 -0
- mlflow/config/__init__.py +56 -0
- mlflow/crewai/__init__.py +79 -0
- mlflow/crewai/autolog.py +253 -0
- mlflow/crewai/chat.py +29 -0
- mlflow/data/__init__.py +75 -0
- mlflow/data/artifact_dataset_sources.py +170 -0
- mlflow/data/code_dataset_source.py +40 -0
- mlflow/data/dataset.py +123 -0
- mlflow/data/dataset_registry.py +168 -0
- mlflow/data/dataset_source.py +110 -0
- mlflow/data/dataset_source_registry.py +219 -0
- mlflow/data/delta_dataset_source.py +167 -0
- mlflow/data/digest_utils.py +108 -0
- mlflow/data/evaluation_dataset.py +562 -0
- mlflow/data/filesystem_dataset_source.py +81 -0
- mlflow/data/http_dataset_source.py +145 -0
- mlflow/data/huggingface_dataset.py +258 -0
- mlflow/data/huggingface_dataset_source.py +118 -0
- mlflow/data/meta_dataset.py +104 -0
- mlflow/data/numpy_dataset.py +223 -0
- mlflow/data/pandas_dataset.py +231 -0
- mlflow/data/polars_dataset.py +352 -0
- mlflow/data/pyfunc_dataset_mixin.py +31 -0
- mlflow/data/schema.py +76 -0
- mlflow/data/sources.py +1 -0
- mlflow/data/spark_dataset.py +406 -0
- mlflow/data/spark_dataset_source.py +74 -0
- mlflow/data/spark_delta_utils.py +118 -0
- mlflow/data/tensorflow_dataset.py +350 -0
- mlflow/data/uc_volume_dataset_source.py +81 -0
- mlflow/db.py +27 -0
- mlflow/dspy/__init__.py +17 -0
- mlflow/dspy/autolog.py +197 -0
- mlflow/dspy/callback.py +398 -0
- mlflow/dspy/constant.py +1 -0
- mlflow/dspy/load.py +93 -0
- mlflow/dspy/save.py +393 -0
- mlflow/dspy/util.py +109 -0
- mlflow/dspy/wrapper.py +226 -0
- mlflow/entities/__init__.py +104 -0
- mlflow/entities/_mlflow_object.py +52 -0
- mlflow/entities/assessment.py +545 -0
- mlflow/entities/assessment_error.py +80 -0
- mlflow/entities/assessment_source.py +141 -0
- mlflow/entities/dataset.py +92 -0
- mlflow/entities/dataset_input.py +51 -0
- mlflow/entities/dataset_summary.py +62 -0
- mlflow/entities/document.py +48 -0
- mlflow/entities/experiment.py +109 -0
- mlflow/entities/experiment_tag.py +35 -0
- mlflow/entities/file_info.py +45 -0
- mlflow/entities/input_tag.py +35 -0
- mlflow/entities/lifecycle_stage.py +35 -0
- mlflow/entities/logged_model.py +228 -0
- mlflow/entities/logged_model_input.py +26 -0
- mlflow/entities/logged_model_output.py +32 -0
- mlflow/entities/logged_model_parameter.py +46 -0
- mlflow/entities/logged_model_status.py +74 -0
- mlflow/entities/logged_model_tag.py +33 -0
- mlflow/entities/metric.py +200 -0
- mlflow/entities/model_registry/__init__.py +29 -0
- mlflow/entities/model_registry/_model_registry_entity.py +13 -0
- mlflow/entities/model_registry/model_version.py +243 -0
- mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
- mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
- mlflow/entities/model_registry/model_version_search.py +25 -0
- mlflow/entities/model_registry/model_version_stages.py +25 -0
- mlflow/entities/model_registry/model_version_status.py +35 -0
- mlflow/entities/model_registry/model_version_tag.py +35 -0
- mlflow/entities/model_registry/prompt.py +73 -0
- mlflow/entities/model_registry/prompt_version.py +244 -0
- mlflow/entities/model_registry/registered_model.py +175 -0
- mlflow/entities/model_registry/registered_model_alias.py +35 -0
- mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
- mlflow/entities/model_registry/registered_model_search.py +25 -0
- mlflow/entities/model_registry/registered_model_tag.py +35 -0
- mlflow/entities/multipart_upload.py +74 -0
- mlflow/entities/param.py +49 -0
- mlflow/entities/run.py +97 -0
- mlflow/entities/run_data.py +84 -0
- mlflow/entities/run_info.py +188 -0
- mlflow/entities/run_inputs.py +59 -0
- mlflow/entities/run_outputs.py +43 -0
- mlflow/entities/run_status.py +41 -0
- mlflow/entities/run_tag.py +36 -0
- mlflow/entities/source_type.py +31 -0
- mlflow/entities/span.py +774 -0
- mlflow/entities/span_event.py +96 -0
- mlflow/entities/span_status.py +102 -0
- mlflow/entities/trace.py +317 -0
- mlflow/entities/trace_data.py +71 -0
- mlflow/entities/trace_info.py +220 -0
- mlflow/entities/trace_info_v2.py +162 -0
- mlflow/entities/trace_location.py +173 -0
- mlflow/entities/trace_state.py +39 -0
- mlflow/entities/trace_status.py +68 -0
- mlflow/entities/view_type.py +51 -0
- mlflow/environment_variables.py +866 -0
- mlflow/evaluation/__init__.py +16 -0
- mlflow/evaluation/assessment.py +369 -0
- mlflow/evaluation/evaluation.py +411 -0
- mlflow/evaluation/evaluation_tag.py +61 -0
- mlflow/evaluation/fluent.py +48 -0
- mlflow/evaluation/utils.py +201 -0
- mlflow/exceptions.py +213 -0
- mlflow/experiments.py +140 -0
- mlflow/gemini/__init__.py +81 -0
- mlflow/gemini/autolog.py +186 -0
- mlflow/gemini/chat.py +261 -0
- mlflow/genai/__init__.py +71 -0
- mlflow/genai/datasets/__init__.py +67 -0
- mlflow/genai/datasets/evaluation_dataset.py +131 -0
- mlflow/genai/evaluation/__init__.py +3 -0
- mlflow/genai/evaluation/base.py +411 -0
- mlflow/genai/evaluation/constant.py +23 -0
- mlflow/genai/evaluation/utils.py +244 -0
- mlflow/genai/judges/__init__.py +21 -0
- mlflow/genai/judges/databricks.py +404 -0
- mlflow/genai/label_schemas/__init__.py +153 -0
- mlflow/genai/label_schemas/label_schemas.py +209 -0
- mlflow/genai/labeling/__init__.py +159 -0
- mlflow/genai/labeling/labeling.py +250 -0
- mlflow/genai/optimize/__init__.py +13 -0
- mlflow/genai/optimize/base.py +198 -0
- mlflow/genai/optimize/optimizers/__init__.py +4 -0
- mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
- mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
- mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
- mlflow/genai/optimize/types.py +75 -0
- mlflow/genai/optimize/util.py +30 -0
- mlflow/genai/prompts/__init__.py +206 -0
- mlflow/genai/scheduled_scorers.py +431 -0
- mlflow/genai/scorers/__init__.py +26 -0
- mlflow/genai/scorers/base.py +492 -0
- mlflow/genai/scorers/builtin_scorers.py +765 -0
- mlflow/genai/scorers/scorer_utils.py +138 -0
- mlflow/genai/scorers/validation.py +165 -0
- mlflow/genai/utils/data_validation.py +146 -0
- mlflow/genai/utils/enum_utils.py +23 -0
- mlflow/genai/utils/trace_utils.py +211 -0
- mlflow/groq/__init__.py +42 -0
- mlflow/groq/_groq_autolog.py +74 -0
- mlflow/johnsnowlabs/__init__.py +888 -0
- mlflow/langchain/__init__.py +24 -0
- mlflow/langchain/api_request_parallel_processor.py +330 -0
- mlflow/langchain/autolog.py +147 -0
- mlflow/langchain/chat_agent_langgraph.py +340 -0
- mlflow/langchain/constant.py +1 -0
- mlflow/langchain/constants.py +1 -0
- mlflow/langchain/databricks_dependencies.py +444 -0
- mlflow/langchain/langchain_tracer.py +597 -0
- mlflow/langchain/model.py +919 -0
- mlflow/langchain/output_parsers.py +142 -0
- mlflow/langchain/retriever_chain.py +153 -0
- mlflow/langchain/runnables.py +527 -0
- mlflow/langchain/utils/chat.py +402 -0
- mlflow/langchain/utils/logging.py +671 -0
- mlflow/langchain/utils/serialization.py +36 -0
- mlflow/legacy_databricks_cli/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/provider.py +482 -0
- mlflow/litellm/__init__.py +175 -0
- mlflow/llama_index/__init__.py +22 -0
- mlflow/llama_index/autolog.py +55 -0
- mlflow/llama_index/chat.py +43 -0
- mlflow/llama_index/constant.py +1 -0
- mlflow/llama_index/model.py +577 -0
- mlflow/llama_index/pyfunc_wrapper.py +332 -0
- mlflow/llama_index/serialize_objects.py +188 -0
- mlflow/llama_index/tracer.py +561 -0
- mlflow/metrics/__init__.py +479 -0
- mlflow/metrics/base.py +39 -0
- mlflow/metrics/genai/__init__.py +25 -0
- mlflow/metrics/genai/base.py +101 -0
- mlflow/metrics/genai/genai_metric.py +771 -0
- mlflow/metrics/genai/metric_definitions.py +450 -0
- mlflow/metrics/genai/model_utils.py +371 -0
- mlflow/metrics/genai/prompt_template.py +68 -0
- mlflow/metrics/genai/prompts/__init__.py +0 -0
- mlflow/metrics/genai/prompts/v1.py +422 -0
- mlflow/metrics/genai/utils.py +6 -0
- mlflow/metrics/metric_definitions.py +619 -0
- mlflow/mismatch.py +34 -0
- mlflow/mistral/__init__.py +34 -0
- mlflow/mistral/autolog.py +71 -0
- mlflow/mistral/chat.py +135 -0
- mlflow/ml_package_versions.py +452 -0
- mlflow/models/__init__.py +97 -0
- mlflow/models/auth_policy.py +83 -0
- mlflow/models/cli.py +354 -0
- mlflow/models/container/__init__.py +294 -0
- mlflow/models/container/scoring_server/__init__.py +0 -0
- mlflow/models/container/scoring_server/nginx.conf +39 -0
- mlflow/models/dependencies_schemas.py +287 -0
- mlflow/models/display_utils.py +158 -0
- mlflow/models/docker_utils.py +211 -0
- mlflow/models/evaluation/__init__.py +23 -0
- mlflow/models/evaluation/_shap_patch.py +64 -0
- mlflow/models/evaluation/artifacts.py +194 -0
- mlflow/models/evaluation/base.py +1811 -0
- mlflow/models/evaluation/calibration_curve.py +109 -0
- mlflow/models/evaluation/default_evaluator.py +996 -0
- mlflow/models/evaluation/deprecated.py +23 -0
- mlflow/models/evaluation/evaluator_registry.py +80 -0
- mlflow/models/evaluation/evaluators/classifier.py +704 -0
- mlflow/models/evaluation/evaluators/default.py +233 -0
- mlflow/models/evaluation/evaluators/regressor.py +96 -0
- mlflow/models/evaluation/evaluators/shap.py +296 -0
- mlflow/models/evaluation/lift_curve.py +178 -0
- mlflow/models/evaluation/utils/metric.py +123 -0
- mlflow/models/evaluation/utils/trace.py +179 -0
- mlflow/models/evaluation/validation.py +434 -0
- mlflow/models/flavor_backend.py +93 -0
- mlflow/models/flavor_backend_registry.py +53 -0
- mlflow/models/model.py +1639 -0
- mlflow/models/model_config.py +150 -0
- mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
- mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
- mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
- mlflow/models/python_api.py +369 -0
- mlflow/models/rag_signatures.py +128 -0
- mlflow/models/resources.py +321 -0
- mlflow/models/signature.py +662 -0
- mlflow/models/utils.py +2054 -0
- mlflow/models/wheeled_model.py +280 -0
- mlflow/openai/__init__.py +57 -0
- mlflow/openai/_agent_tracer.py +364 -0
- mlflow/openai/api_request_parallel_processor.py +131 -0
- mlflow/openai/autolog.py +509 -0
- mlflow/openai/constant.py +1 -0
- mlflow/openai/model.py +824 -0
- mlflow/openai/utils/chat_schema.py +367 -0
- mlflow/optuna/__init__.py +3 -0
- mlflow/optuna/storage.py +646 -0
- mlflow/plugins/__init__.py +72 -0
- mlflow/plugins/base.py +358 -0
- mlflow/plugins/builtin/__init__.py +24 -0
- mlflow/plugins/builtin/pytorch_plugin.py +150 -0
- mlflow/plugins/builtin/sklearn_plugin.py +158 -0
- mlflow/plugins/builtin/transformers_plugin.py +187 -0
- mlflow/plugins/cli.py +321 -0
- mlflow/plugins/discovery.py +340 -0
- mlflow/plugins/manager.py +465 -0
- mlflow/plugins/registry.py +316 -0
- mlflow/plugins/templates/framework_plugin_template.py +329 -0
- mlflow/prompt/constants.py +20 -0
- mlflow/prompt/promptlab_model.py +197 -0
- mlflow/prompt/registry_utils.py +248 -0
- mlflow/promptflow/__init__.py +495 -0
- mlflow/protos/__init__.py +0 -0
- mlflow/protos/assessments_pb2.py +174 -0
- mlflow/protos/databricks_artifacts_pb2.py +489 -0
- mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
- mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
- mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
- mlflow/protos/databricks_pb2.py +267 -0
- mlflow/protos/databricks_trace_server_pb2.py +374 -0
- mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
- mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
- mlflow/protos/facet_feature_statistics_pb2.py +296 -0
- mlflow/protos/internal_pb2.py +77 -0
- mlflow/protos/mlflow_artifacts_pb2.py +336 -0
- mlflow/protos/model_registry_pb2.py +1073 -0
- mlflow/protos/scalapb/__init__.py +0 -0
- mlflow/protos/scalapb/scalapb_pb2.py +104 -0
- mlflow/protos/service_pb2.py +2600 -0
- mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
- mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
- mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
- mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
- mlflow/py.typed +0 -0
- mlflow/pydantic_ai/__init__.py +57 -0
- mlflow/pydantic_ai/autolog.py +173 -0
- mlflow/pyfunc/__init__.py +3844 -0
- mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
- mlflow/pyfunc/backend.py +523 -0
- mlflow/pyfunc/context.py +78 -0
- mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
- mlflow/pyfunc/loaders/__init__.py +7 -0
- mlflow/pyfunc/loaders/chat_agent.py +117 -0
- mlflow/pyfunc/loaders/chat_model.py +125 -0
- mlflow/pyfunc/loaders/code_model.py +31 -0
- mlflow/pyfunc/loaders/responses_agent.py +112 -0
- mlflow/pyfunc/mlserver.py +46 -0
- mlflow/pyfunc/model.py +1473 -0
- mlflow/pyfunc/scoring_server/__init__.py +604 -0
- mlflow/pyfunc/scoring_server/app.py +7 -0
- mlflow/pyfunc/scoring_server/client.py +146 -0
- mlflow/pyfunc/spark_model_cache.py +48 -0
- mlflow/pyfunc/stdin_server.py +44 -0
- mlflow/pyfunc/utils/__init__.py +3 -0
- mlflow/pyfunc/utils/data_validation.py +224 -0
- mlflow/pyfunc/utils/environment.py +22 -0
- mlflow/pyfunc/utils/input_converter.py +47 -0
- mlflow/pyfunc/utils/serving_data_parser.py +11 -0
- mlflow/pytorch/__init__.py +1171 -0
- mlflow/pytorch/_lightning_autolog.py +580 -0
- mlflow/pytorch/_pytorch_autolog.py +50 -0
- mlflow/pytorch/pickle_module.py +35 -0
- mlflow/rfunc/__init__.py +42 -0
- mlflow/rfunc/backend.py +134 -0
- mlflow/runs.py +89 -0
- mlflow/server/__init__.py +302 -0
- mlflow/server/auth/__init__.py +1224 -0
- mlflow/server/auth/__main__.py +4 -0
- mlflow/server/auth/basic_auth.ini +6 -0
- mlflow/server/auth/cli.py +11 -0
- mlflow/server/auth/client.py +537 -0
- mlflow/server/auth/config.py +34 -0
- mlflow/server/auth/db/__init__.py +0 -0
- mlflow/server/auth/db/cli.py +18 -0
- mlflow/server/auth/db/migrations/__init__.py +0 -0
- mlflow/server/auth/db/migrations/alembic.ini +110 -0
- mlflow/server/auth/db/migrations/env.py +76 -0
- mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
- mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
- mlflow/server/auth/db/models.py +67 -0
- mlflow/server/auth/db/utils.py +37 -0
- mlflow/server/auth/entities.py +165 -0
- mlflow/server/auth/logo.py +14 -0
- mlflow/server/auth/permissions.py +65 -0
- mlflow/server/auth/routes.py +18 -0
- mlflow/server/auth/sqlalchemy_store.py +263 -0
- mlflow/server/graphql/__init__.py +0 -0
- mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
- mlflow/server/graphql/graphql_custom_scalars.py +24 -0
- mlflow/server/graphql/graphql_errors.py +15 -0
- mlflow/server/graphql/graphql_no_batching.py +89 -0
- mlflow/server/graphql/graphql_schema_extensions.py +74 -0
- mlflow/server/handlers.py +3217 -0
- mlflow/server/prometheus_exporter.py +17 -0
- mlflow/server/validation.py +30 -0
- mlflow/shap/__init__.py +691 -0
- mlflow/sklearn/__init__.py +1994 -0
- mlflow/sklearn/utils.py +1041 -0
- mlflow/smolagents/__init__.py +66 -0
- mlflow/smolagents/autolog.py +139 -0
- mlflow/smolagents/chat.py +29 -0
- mlflow/store/__init__.py +10 -0
- mlflow/store/_unity_catalog/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/constants.py +2 -0
- mlflow/store/_unity_catalog/registry/__init__.py +6 -0
- mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
- mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
- mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
- mlflow/store/_unity_catalog/registry/utils.py +121 -0
- mlflow/store/artifact/__init__.py +0 -0
- mlflow/store/artifact/artifact_repo.py +472 -0
- mlflow/store/artifact/artifact_repository_registry.py +154 -0
- mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
- mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
- mlflow/store/artifact/cli.py +141 -0
- mlflow/store/artifact/cloud_artifact_repo.py +332 -0
- mlflow/store/artifact/databricks_artifact_repo.py +729 -0
- mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
- mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
- mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
- mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
- mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
- mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
- mlflow/store/artifact/ftp_artifact_repo.py +132 -0
- mlflow/store/artifact/gcs_artifact_repo.py +296 -0
- mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
- mlflow/store/artifact/http_artifact_repo.py +218 -0
- mlflow/store/artifact/local_artifact_repo.py +142 -0
- mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
- mlflow/store/artifact/models_artifact_repo.py +259 -0
- mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
- mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
- mlflow/store/artifact/r2_artifact_repo.py +70 -0
- mlflow/store/artifact/runs_artifact_repo.py +265 -0
- mlflow/store/artifact/s3_artifact_repo.py +330 -0
- mlflow/store/artifact/sftp_artifact_repo.py +141 -0
- mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
- mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
- mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
- mlflow/store/artifact/utils/__init__.py +0 -0
- mlflow/store/artifact/utils/models.py +148 -0
- mlflow/store/db/__init__.py +0 -0
- mlflow/store/db/base_sql_model.py +3 -0
- mlflow/store/db/db_types.py +10 -0
- mlflow/store/db/utils.py +314 -0
- mlflow/store/db_migrations/__init__.py +0 -0
- mlflow/store/db_migrations/alembic.ini +74 -0
- mlflow/store/db_migrations/env.py +84 -0
- mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
- mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
- mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
- mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
- mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
- mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
- mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
- mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
- mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
- mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
- mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
- mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
- mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
- mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
- mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
- mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
- mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
- mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
- mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
- mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
- mlflow/store/db_migrations/versions/__init__.py +0 -0
- mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
- mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
- mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
- mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
- mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
- mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
- mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
- mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
- mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
- mlflow/store/entities/__init__.py +3 -0
- mlflow/store/entities/paged_list.py +18 -0
- mlflow/store/model_registry/__init__.py +10 -0
- mlflow/store/model_registry/abstract_store.py +1081 -0
- mlflow/store/model_registry/base_rest_store.py +44 -0
- mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
- mlflow/store/model_registry/dbmodels/__init__.py +0 -0
- mlflow/store/model_registry/dbmodels/models.py +206 -0
- mlflow/store/model_registry/file_store.py +1091 -0
- mlflow/store/model_registry/rest_store.py +481 -0
- mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
- mlflow/store/tracking/__init__.py +23 -0
- mlflow/store/tracking/abstract_store.py +816 -0
- mlflow/store/tracking/dbmodels/__init__.py +0 -0
- mlflow/store/tracking/dbmodels/initial_models.py +243 -0
- mlflow/store/tracking/dbmodels/models.py +1073 -0
- mlflow/store/tracking/file_store.py +2438 -0
- mlflow/store/tracking/postgres_managed_identity.py +146 -0
- mlflow/store/tracking/rest_store.py +1131 -0
- mlflow/store/tracking/sqlalchemy_store.py +2785 -0
- mlflow/system_metrics/__init__.py +61 -0
- mlflow/system_metrics/metrics/__init__.py +0 -0
- mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
- mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
- mlflow/system_metrics/metrics/disk_monitor.py +21 -0
- mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
- mlflow/system_metrics/metrics/network_monitor.py +34 -0
- mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
- mlflow/system_metrics/system_metrics_monitor.py +198 -0
- mlflow/tracing/__init__.py +16 -0
- mlflow/tracing/assessment.py +356 -0
- mlflow/tracing/client.py +531 -0
- mlflow/tracing/config.py +125 -0
- mlflow/tracing/constant.py +105 -0
- mlflow/tracing/destination.py +81 -0
- mlflow/tracing/display/__init__.py +40 -0
- mlflow/tracing/display/display_handler.py +196 -0
- mlflow/tracing/export/async_export_queue.py +186 -0
- mlflow/tracing/export/inference_table.py +138 -0
- mlflow/tracing/export/mlflow_v3.py +137 -0
- mlflow/tracing/export/utils.py +70 -0
- mlflow/tracing/fluent.py +1417 -0
- mlflow/tracing/processor/base_mlflow.py +199 -0
- mlflow/tracing/processor/inference_table.py +175 -0
- mlflow/tracing/processor/mlflow_v3.py +47 -0
- mlflow/tracing/processor/otel.py +73 -0
- mlflow/tracing/provider.py +487 -0
- mlflow/tracing/trace_manager.py +200 -0
- mlflow/tracing/utils/__init__.py +616 -0
- mlflow/tracing/utils/artifact_utils.py +28 -0
- mlflow/tracing/utils/copy.py +55 -0
- mlflow/tracing/utils/environment.py +55 -0
- mlflow/tracing/utils/exception.py +21 -0
- mlflow/tracing/utils/once.py +35 -0
- mlflow/tracing/utils/otlp.py +63 -0
- mlflow/tracing/utils/processor.py +54 -0
- mlflow/tracing/utils/search.py +292 -0
- mlflow/tracing/utils/timeout.py +250 -0
- mlflow/tracing/utils/token.py +19 -0
- mlflow/tracing/utils/truncation.py +124 -0
- mlflow/tracing/utils/warning.py +76 -0
- mlflow/tracking/__init__.py +39 -0
- mlflow/tracking/_model_registry/__init__.py +1 -0
- mlflow/tracking/_model_registry/client.py +764 -0
- mlflow/tracking/_model_registry/fluent.py +853 -0
- mlflow/tracking/_model_registry/registry.py +67 -0
- mlflow/tracking/_model_registry/utils.py +251 -0
- mlflow/tracking/_tracking_service/__init__.py +0 -0
- mlflow/tracking/_tracking_service/client.py +883 -0
- mlflow/tracking/_tracking_service/registry.py +56 -0
- mlflow/tracking/_tracking_service/utils.py +275 -0
- mlflow/tracking/artifact_utils.py +179 -0
- mlflow/tracking/client.py +5900 -0
- mlflow/tracking/context/__init__.py +0 -0
- mlflow/tracking/context/abstract_context.py +35 -0
- mlflow/tracking/context/databricks_cluster_context.py +15 -0
- mlflow/tracking/context/databricks_command_context.py +15 -0
- mlflow/tracking/context/databricks_job_context.py +49 -0
- mlflow/tracking/context/databricks_notebook_context.py +41 -0
- mlflow/tracking/context/databricks_repo_context.py +43 -0
- mlflow/tracking/context/default_context.py +51 -0
- mlflow/tracking/context/git_context.py +32 -0
- mlflow/tracking/context/registry.py +98 -0
- mlflow/tracking/context/system_environment_context.py +15 -0
- mlflow/tracking/default_experiment/__init__.py +1 -0
- mlflow/tracking/default_experiment/abstract_context.py +43 -0
- mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
- mlflow/tracking/default_experiment/registry.py +75 -0
- mlflow/tracking/fluent.py +3595 -0
- mlflow/tracking/metric_value_conversion_utils.py +93 -0
- mlflow/tracking/multimedia.py +206 -0
- mlflow/tracking/registry.py +86 -0
- mlflow/tracking/request_auth/__init__.py +0 -0
- mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
- mlflow/tracking/request_auth/registry.py +60 -0
- mlflow/tracking/request_header/__init__.py +0 -0
- mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
- mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
- mlflow/tracking/request_header/default_request_header_provider.py +17 -0
- mlflow/tracking/request_header/registry.py +79 -0
- mlflow/transformers/__init__.py +2982 -0
- mlflow/transformers/flavor_config.py +258 -0
- mlflow/transformers/hub_utils.py +83 -0
- mlflow/transformers/llm_inference_utils.py +468 -0
- mlflow/transformers/model_io.py +301 -0
- mlflow/transformers/peft.py +51 -0
- mlflow/transformers/signature.py +183 -0
- mlflow/transformers/torch_utils.py +55 -0
- mlflow/types/__init__.py +21 -0
- mlflow/types/agent.py +270 -0
- mlflow/types/chat.py +240 -0
- mlflow/types/llm.py +935 -0
- mlflow/types/responses.py +139 -0
- mlflow/types/responses_helpers.py +416 -0
- mlflow/types/schema.py +1505 -0
- mlflow/types/type_hints.py +647 -0
- mlflow/types/utils.py +753 -0
- mlflow/utils/__init__.py +283 -0
- mlflow/utils/_capture_modules.py +256 -0
- mlflow/utils/_capture_transformers_modules.py +75 -0
- mlflow/utils/_spark_utils.py +201 -0
- mlflow/utils/_unity_catalog_oss_utils.py +97 -0
- mlflow/utils/_unity_catalog_utils.py +479 -0
- mlflow/utils/annotations.py +218 -0
- mlflow/utils/arguments_utils.py +16 -0
- mlflow/utils/async_logging/__init__.py +1 -0
- mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
- mlflow/utils/async_logging/async_logging_queue.py +366 -0
- mlflow/utils/async_logging/run_artifact.py +38 -0
- mlflow/utils/async_logging/run_batch.py +58 -0
- mlflow/utils/async_logging/run_operations.py +49 -0
- mlflow/utils/autologging_utils/__init__.py +737 -0
- mlflow/utils/autologging_utils/client.py +432 -0
- mlflow/utils/autologging_utils/config.py +33 -0
- mlflow/utils/autologging_utils/events.py +294 -0
- mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
- mlflow/utils/autologging_utils/metrics_queue.py +71 -0
- mlflow/utils/autologging_utils/safety.py +1104 -0
- mlflow/utils/autologging_utils/versioning.py +95 -0
- mlflow/utils/checkpoint_utils.py +206 -0
- mlflow/utils/class_utils.py +6 -0
- mlflow/utils/cli_args.py +257 -0
- mlflow/utils/conda.py +354 -0
- mlflow/utils/credentials.py +231 -0
- mlflow/utils/data_utils.py +17 -0
- mlflow/utils/databricks_utils.py +1436 -0
- mlflow/utils/docstring_utils.py +477 -0
- mlflow/utils/doctor.py +133 -0
- mlflow/utils/download_cloud_file_chunk.py +43 -0
- mlflow/utils/env_manager.py +16 -0
- mlflow/utils/env_pack.py +131 -0
- mlflow/utils/environment.py +1009 -0
- mlflow/utils/exception_utils.py +14 -0
- mlflow/utils/file_utils.py +978 -0
- mlflow/utils/git_utils.py +77 -0
- mlflow/utils/gorilla.py +797 -0
- mlflow/utils/import_hooks/__init__.py +363 -0
- mlflow/utils/lazy_load.py +51 -0
- mlflow/utils/logging_utils.py +168 -0
- mlflow/utils/mime_type_utils.py +58 -0
- mlflow/utils/mlflow_tags.py +103 -0
- mlflow/utils/model_utils.py +486 -0
- mlflow/utils/name_utils.py +346 -0
- mlflow/utils/nfs_on_spark.py +62 -0
- mlflow/utils/openai_utils.py +164 -0
- mlflow/utils/os.py +12 -0
- mlflow/utils/oss_registry_utils.py +29 -0
- mlflow/utils/plugins.py +17 -0
- mlflow/utils/process.py +182 -0
- mlflow/utils/promptlab_utils.py +146 -0
- mlflow/utils/proto_json_utils.py +743 -0
- mlflow/utils/pydantic_utils.py +54 -0
- mlflow/utils/request_utils.py +279 -0
- mlflow/utils/requirements_utils.py +704 -0
- mlflow/utils/rest_utils.py +673 -0
- mlflow/utils/search_logged_model_utils.py +127 -0
- mlflow/utils/search_utils.py +2111 -0
- mlflow/utils/secure_loading.py +221 -0
- mlflow/utils/security_validation.py +384 -0
- mlflow/utils/server_cli_utils.py +61 -0
- mlflow/utils/spark_utils.py +15 -0
- mlflow/utils/string_utils.py +138 -0
- mlflow/utils/thread_utils.py +63 -0
- mlflow/utils/time.py +54 -0
- mlflow/utils/timeout.py +42 -0
- mlflow/utils/uri.py +572 -0
- mlflow/utils/validation.py +662 -0
- mlflow/utils/virtualenv.py +458 -0
- mlflow/utils/warnings_utils.py +25 -0
- mlflow/utils/yaml_utils.py +179 -0
- mlflow/version.py +24 -0
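A file listing like the one above can be reproduced locally, since a wheel is a standard zip archive. A minimal sketch, assuming the wheel file has been downloaded into the working directory:

# List the contents of the wheel; a .whl file is a plain zip archive.
import zipfile

with zipfile.ZipFile("genesis_flow-1.0.0-py3-none-any.whl") as whl:
    for info in whl.infolist():
        print(f"{info.filename} ({info.file_size} bytes)")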
mlflow/models/evaluation/default_evaluator.py
@@ -0,0 +1,996 @@
import copy
import inspect
import json
import logging
import pathlib
import pickle
import shutil
import tempfile
import traceback
from abc import abstractmethod
from typing import Any, Callable, NamedTuple, Optional, Union

import numpy as np
import pandas as pd

import mlflow
from mlflow import MlflowClient, MlflowException
from mlflow.data.evaluation_dataset import EvaluationDataset
from mlflow.entities.metric import Metric
from mlflow.metrics.base import MetricValue
from mlflow.models.evaluation.artifacts import (
    CsvEvaluationArtifact,
    ImageEvaluationArtifact,
    JsonEvaluationArtifact,
    NumpyEvaluationArtifact,
    _infer_artifact_type_and_ext,
)
from mlflow.models.evaluation.base import EvaluationMetric, EvaluationResult, ModelEvaluator
from mlflow.models.evaluation.utils.metric import MetricDefinition
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.pyfunc import _ServedPyFuncModel
from mlflow.utils.file_utils import TempDir
from mlflow.utils.proto_json_utils import NumpyEncoder
from mlflow.utils.time import get_current_time_millis

_logger = logging.getLogger(__name__)

_EVAL_TABLE_FILE_NAME = "eval_results_table.json"
_TOKEN_COUNT_METRIC_NAME = "token_count"
_LATENCY_METRIC_NAME = "latency"


def _extract_raw_model(model):
    if not getattr(model, "metadata", None):
        return None, None

    model_loader_module = model.metadata.flavors["python_function"]["loader_module"]
    # If we load a model with mlflow.pyfunc.load_model, the model will be wrapped
    # with a pyfunc wrapper. We need to extract the raw model so that the shap
    # explainer uses the raw model instead of the wrapper and skips data schema validation.
    if model_loader_module in [
        "mlflow.catboost",
        "mlflow.sklearn",
        "mlflow.xgboost",
    ] and not isinstance(model, _ServedPyFuncModel):
        if hasattr(model._model_impl, "get_raw_model"):
            return model_loader_module, model._model_impl.get_raw_model()
        return model_loader_module, model._model_impl
    else:
        return model_loader_module, None


def _extract_output_and_other_columns(
    model_predictions: Union[list[Any], dict[str, Any], pd.DataFrame, pd.Series],
    output_column_name: Optional[str],
) -> tuple[pd.Series, Optional[pd.DataFrame], str]:
    y_pred = None
    other_output_columns = None
    ERROR_MISSING_OUTPUT_COLUMN_NAME = (
        "Output column name is not specified for the multi-output model. "
        "Please set the correct output column name using the `predictions` parameter."
    )

    if isinstance(model_predictions, list) and all(isinstance(p, dict) for p in model_predictions):
        # Extract 'y_pred' and 'other_output_columns' from list of dictionaries
        if output_column_name in model_predictions[0]:
            y_pred = pd.Series(
                [p.get(output_column_name) for p in model_predictions], name=output_column_name
            )
            other_output_columns = pd.DataFrame(
                [{k: v for k, v in p.items() if k != output_column_name} for p in model_predictions]
            )
        elif len(model_predictions[0]) == 1:
            # Set the only key as self.predictions and its value as self.y_pred
            key, value = list(model_predictions[0].items())[0]
            y_pred = pd.Series(value, name=key)
            output_column_name = key
        elif output_column_name is None:
            raise MlflowException(
                ERROR_MISSING_OUTPUT_COLUMN_NAME,
                error_code=INVALID_PARAMETER_VALUE,
            )
        else:
            raise MlflowException(
                f"Output column name '{output_column_name}' is not found in the model "
                f"predictions list: {model_predictions}. Please set the correct output column "
                "name using the `predictions` parameter.",
                error_code=INVALID_PARAMETER_VALUE,
            )
    elif isinstance(model_predictions, pd.DataFrame):
        if output_column_name in model_predictions.columns:
            y_pred = model_predictions[output_column_name]
            other_output_columns = model_predictions.drop(columns=output_column_name)
        elif len(model_predictions.columns) == 1:
            output_column_name = model_predictions.columns[0]
            y_pred = model_predictions[output_column_name]
        elif output_column_name is None:
            raise MlflowException(
                ERROR_MISSING_OUTPUT_COLUMN_NAME,
                error_code=INVALID_PARAMETER_VALUE,
            )
        else:
            raise MlflowException(
                f"Output column name '{output_column_name}' is not found in the model "
                f"predictions dataframe {model_predictions.columns}. Please set the correct "
                "output column name using the `predictions` parameter.",
                error_code=INVALID_PARAMETER_VALUE,
            )
    elif isinstance(model_predictions, dict):
        if output_column_name in model_predictions:
            y_pred = pd.Series(model_predictions[output_column_name], name=output_column_name)
            other_output_columns = pd.DataFrame(
                {k: v for k, v in model_predictions.items() if k != output_column_name}
            )
        elif len(model_predictions) == 1:
            key, value = list(model_predictions.items())[0]
            y_pred = pd.Series(value, name=key)
            output_column_name = key
        elif output_column_name is None:
            raise MlflowException(
                ERROR_MISSING_OUTPUT_COLUMN_NAME,
                error_code=INVALID_PARAMETER_VALUE,
            )
        else:
            raise MlflowException(
                f"Output column name '{output_column_name}' is not found in the "
                f"model predictions dict {model_predictions}. Please set the correct "
                "output column name using the `predictions` parameter.",
                error_code=INVALID_PARAMETER_VALUE,
            )

    return (
        y_pred if y_pred is not None else model_predictions,
        other_output_columns,
        output_column_name,
    )

def _extract_predict_fn(model: Any) -> Optional[Callable[..., Any]]:
    """
    Extracts the predict function from the given model.

    Precedence order:
    1. If a raw model can be extracted from the pyfunc model, its predict function is used.
    2. Otherwise, if the model is specified, its predict function is used.
    3. If none of the above, the predict function is None.

    Args:
        model: A model object that has a predict method.

    Returns: The predict function.
    """
    _, raw_model = _extract_raw_model(model)
    predict_fn = None

    if raw_model is not None:
        predict_fn = raw_model.predict
        try:
            from mlflow.xgboost import _wrapped_xgboost_model_predict_fn

            # Because shap evaluation will pass evaluation data in ndarray format
            # (without feature names), setting validate_features=True would raise an error.
            predict_fn = _wrapped_xgboost_model_predict_fn(raw_model, validate_features=False)
        except ImportError:
            pass

    elif model is not None:
        predict_fn = model.predict

    return predict_fn

def _get_dataframe_with_renamed_columns(x, new_column_names):
    """
    Downstream inference functions may expect a pd.DataFrame to be created from x. However,
    if x is already a pd.DataFrame, and new_column_names != x.columns, we cannot simply call
    pd.DataFrame(x, columns=new_column_names) because the resulting pd.DataFrame will contain
    NaNs for every column in new_column_names that does not exist in x.columns. This function
    instead creates a new pd.DataFrame object from x, and then explicitly renames the columns
    to avoid NaNs.

    Args:
        x: A data object, such as a Pandas DataFrame, NumPy array, or list
        new_column_names: Column names for the output Pandas DataFrame

    Returns:
        A pd.DataFrame with x as data, with columns new_column_names
    """
    df = pd.DataFrame(x)
    return df.rename(columns=dict(zip(df.columns, new_column_names)))


def _get_aggregate_metrics_values(metrics):
    return {name: MetricValue(aggregate_results={name: value}) for name, value in metrics.items()}


_matplotlib_config = {
    "figure.dpi": 175,
    "figure.figsize": [6.0, 4.0],
    "figure.autolayout": True,
    "font.size": 8,
}


class _CustomArtifact(NamedTuple):
    """
    A namedtuple representing a custom artifact function and its properties.

    function : the custom artifact function
    name : the name of the custom artifact function
    index : the index of the function in the ``custom_artifacts`` argument of mlflow.evaluate
    artifacts_dir : the path to a temporary directory to store produced artifacts of the function
    """

    function: Callable[..., Any]
    name: str
    index: int
    artifacts_dir: str


def _is_valid_artifacts(artifacts):
    return isinstance(artifacts, dict) and all(isinstance(k, str) for k in artifacts.keys())

def _evaluate_custom_artifacts(custom_artifact_tuple, eval_df, builtin_metrics):
    """
    This function calls the `custom_artifact` function and performs validations on the returned
    result to ensure that it is in the expected format. If the result is not in the expected
    format, a warning is logged and None is returned.

    Args:
        custom_artifact_tuple: Containing a user provided function and its index in the
            ``custom_artifacts`` parameter of ``mlflow.evaluate``
        eval_df: A Pandas dataframe object containing a prediction and a target column.
        builtin_metrics: A dictionary of metrics produced by the default evaluator.

    Returns:
        A dictionary of artifacts.
    """
    exception_header = (
        f"Custom artifact function '{custom_artifact_tuple.name}' "
        f"at index {custom_artifact_tuple.index} "
        "in the `custom_artifacts` parameter"
    )
    artifacts = custom_artifact_tuple.function(
        eval_df, builtin_metrics, custom_artifact_tuple.artifacts_dir
    )

    if artifacts is None:
        _logger.warning(f"{exception_header} returned None.")
        return

    if not _is_valid_artifacts(artifacts):
        _logger.warning(
            f"{exception_header} did not return artifacts as a dictionary of string artifact "
            "names with their corresponding objects."
        )
        return

    return artifacts

274
|
+
# TODO: Move this to the /evaluators directory
|
275
|
+
class BuiltInEvaluator(ModelEvaluator):
|
276
|
+
"""
|
277
|
+
The base class for all evaluators that are built-in to MLflow.
|
278
|
+
|
279
|
+
Each evaluator is responsible for implementing the `_evaluate()` method, which is called by
|
280
|
+
the `evaluate()` method of this base class. This class contains many helper methods used
|
281
|
+
across built-in evaluators, such as logging metrics, artifacts, and ordering metrics.
|
282
|
+
"""
|
283
|
+
|
284
|
+
def __init__(self):
|
285
|
+
self.client = MlflowClient()
|
286
|
+
|
287
|
+
@abstractmethod
|
288
|
+
def _evaluate(
|
289
|
+
self,
|
290
|
+
model: Optional["mlflow.pyfunc.PyFuncModel"],
|
291
|
+
extra_metrics: list[EvaluationMetric],
|
292
|
+
custom_artifacts=None,
|
293
|
+
**kwargs,
|
294
|
+
) -> Optional[EvaluationResult]:
|
295
|
+
"""Implement the evaluation logic for each evaluator."""
|
296
|
+
|
297
|
+
def log_metrics(self):
|
298
|
+
"""
|
299
|
+
Helper method to log metrics into specified run.
|
300
|
+
"""
|
301
|
+
self._add_prefix_to_metrics()
|
302
|
+
|
303
|
+
timestamp = get_current_time_millis()
|
304
|
+
self.client.log_batch(
|
305
|
+
self.run_id,
|
306
|
+
metrics=[
|
307
|
+
Metric(
|
308
|
+
key=key,
|
309
|
+
value=value,
|
310
|
+
timestamp=timestamp,
|
311
|
+
step=0,
|
312
|
+
model_id=self.model_id,
|
313
|
+
dataset_name=self.dataset.name,
|
314
|
+
dataset_digest=self.dataset.digest,
|
315
|
+
run_id=self.run_id,
|
316
|
+
)
|
317
|
+
for key, value in self.aggregate_metrics.items()
|
318
|
+
],
|
319
|
+
)
|
320
|
+
|
321
|
+
def _log_image_artifact(
|
322
|
+
self,
|
323
|
+
do_plot,
|
324
|
+
artifact_name,
|
325
|
+
):
|
326
|
+
from matplotlib import pyplot
|
327
|
+
|
328
|
+
prefix = self.evaluator_config.get("metric_prefix", "")
|
329
|
+
artifact_file_name = f"{prefix}{artifact_name}.png"
|
330
|
+
artifact_file_local_path = self.temp_dir.path(artifact_file_name)
|
331
|
+
|
332
|
+
try:
|
333
|
+
pyplot.clf()
|
334
|
+
do_plot()
|
335
|
+
pyplot.savefig(artifact_file_local_path, bbox_inches="tight")
|
336
|
+
except Exception as e:
|
337
|
+
_logger.warning(f"Failed to log image artifact {artifact_name!r}: {e!r}")
|
338
|
+
else:
|
339
|
+
mlflow.log_artifact(artifact_file_local_path)
|
340
|
+
artifact = ImageEvaluationArtifact(uri=mlflow.get_artifact_uri(artifact_file_name))
|
341
|
+
artifact._load(artifact_file_local_path)
|
342
|
+
self.artifacts[artifact_name] = artifact
|
343
|
+
finally:
|
344
|
+
pyplot.close(pyplot.gcf())
|
345
|
+
|
346
|
+
def _evaluate_sklearn_model_score_if_scorable(self, model, y_true, sample_weights):
|
347
|
+
model_loader_module, raw_model = _extract_raw_model(model)
|
348
|
+
if model_loader_module == "mlflow.sklearn" and raw_model is not None:
|
349
|
+
try:
|
350
|
+
score = raw_model.score(
|
351
|
+
self.X.copy_to_avoid_mutation(), y_true, sample_weight=sample_weights
|
352
|
+
)
|
353
|
+
self.metrics_values.update(_get_aggregate_metrics_values({"score": score}))
|
354
|
+
except Exception as e:
|
355
|
+
_logger.warning(
|
356
|
+
f"Computing sklearn model score failed: {e!r}. Set logging level to "
|
357
|
+
"DEBUG to see the full traceback."
|
358
|
+
)
|
359
|
+
_logger.debug("", exc_info=True)
|
360
|
+
|
361
|
+
def _log_custom_metric_artifact(self, artifact_name, raw_artifact, custom_metric_tuple):
|
362
|
+
"""
|
363
|
+
This function logs and returns a custom metric artifact. Two cases:
|
364
|
+
- The provided artifact is a path to a file, the function will make a copy of it with
|
365
|
+
a formatted name in a temporary directory and call mlflow.log_artifact.
|
366
|
+
- Otherwise: will attempt to save the artifact to an temporary path with an inferred
|
367
|
+
type. Then call mlflow.log_artifact.
|
368
|
+
|
369
|
+
Args:
|
370
|
+
artifact_name: the name of the artifact
|
371
|
+
raw_artifact: the object representing the artifact
|
372
|
+
custom_metric_tuple: an instance of the _CustomMetric namedtuple
|
373
|
+
|
374
|
+
Returns:
|
375
|
+
EvaluationArtifact
|
376
|
+
"""
|
377
|
+
|
378
|
+
exception_and_warning_header = (
|
379
|
+
f"Custom artifact function '{custom_metric_tuple.name}' at index "
|
380
|
+
f"{custom_metric_tuple.index} in the `custom_artifacts` parameter"
|
381
|
+
)
|
382
|
+
|
383
|
+
inferred_from_path, inferred_type, inferred_ext = _infer_artifact_type_and_ext(
|
384
|
+
artifact_name, raw_artifact, custom_metric_tuple
|
385
|
+
)
|
386
|
+
artifact_file_local_path = self.temp_dir.path(artifact_name + inferred_ext)
|
387
|
+
|
388
|
+
if pathlib.Path(artifact_file_local_path).exists():
|
389
|
+
raise MlflowException(
|
390
|
+
f"{exception_and_warning_header} produced an artifact '{artifact_name}' that "
|
391
|
+
"cannot be logged because there already exists an artifact with the same name."
|
392
|
+
)
|
393
|
+
|
394
|
+
# ParquetEvaluationArtifact isn't explicitly stated here because such artifacts can only
|
395
|
+
# be supplied through file. Which is handled by the first if clause. This is because
|
396
|
+
# DataFrame objects default to be stored as CsvEvaluationArtifact.
|
397
|
+
if inferred_from_path:
|
398
|
+
shutil.copy2(raw_artifact, artifact_file_local_path)
|
399
|
+
elif inferred_type is JsonEvaluationArtifact:
|
400
|
+
with open(artifact_file_local_path, "w") as f:
|
401
|
+
if isinstance(raw_artifact, str):
|
402
|
+
f.write(raw_artifact)
|
403
|
+
else:
|
404
|
+
json.dump(raw_artifact, f, cls=NumpyEncoder)
|
405
|
+
elif inferred_type is CsvEvaluationArtifact:
|
406
|
+
raw_artifact.to_csv(artifact_file_local_path, index=False)
|
407
|
+
elif inferred_type is NumpyEvaluationArtifact:
|
408
|
+
np.save(artifact_file_local_path, raw_artifact, allow_pickle=False)
|
409
|
+
elif inferred_type is ImageEvaluationArtifact:
|
410
|
+
raw_artifact.savefig(artifact_file_local_path)
|
411
|
+
else:
|
412
|
+
# storing as pickle
|
413
|
+
try:
|
414
|
+
with open(artifact_file_local_path, "wb") as f:
|
415
|
+
pickle.dump(raw_artifact, f)
|
416
|
+
_logger.warning(
|
417
|
+
f"{exception_and_warning_header} produced an artifact '{artifact_name}'"
|
418
|
+
f" with type '{type(raw_artifact)}' that is logged as a pickle artifact."
|
419
|
+
)
|
420
|
+
except pickle.PickleError:
|
421
|
+
raise MlflowException(
|
422
|
+
f"{exception_and_warning_header} produced an unsupported artifact "
|
423
|
+
f"'{artifact_name}' with type '{type(raw_artifact)}' that cannot be pickled. "
|
424
|
+
"Supported object types for artifacts are:\n"
|
425
|
+
"- A string uri representing the file path to the artifact. MLflow"
|
426
|
+
" will infer the type of the artifact based on the file extension.\n"
|
427
|
+
"- A string representation of a JSON object. This will be saved as a "
|
428
|
+
".json artifact.\n"
|
429
|
+
"- Pandas DataFrame. This will be saved as a .csv artifact."
|
430
|
+
"- Numpy array. This will be saved as a .npy artifact."
|
431
|
+
"- Matplotlib Figure. This will be saved as an .png image artifact."
|
432
|
+
"- Other objects will be attempted to be pickled with default protocol."
|
433
|
+
)
|
434
|
+
|
435
|
+
mlflow.log_artifact(artifact_file_local_path)
|
436
|
+
artifact = inferred_type(uri=mlflow.get_artifact_uri(artifact_name + inferred_ext))
|
437
|
+
artifact._load(artifact_file_local_path)
|
438
|
+
return artifact
|
439
|
+
|
440
|
+
def _get_column_in_metrics_values(self, column):
|
441
|
+
for metric_name, metric_value in self.metrics_values.items():
|
442
|
+
if metric_name.split("/")[0] == column:
|
443
|
+
return metric_value
|
444
|
+
|
445
|
+
def _get_args_for_metrics(
|
446
|
+
self,
|
447
|
+
metric: MetricDefinition,
|
448
|
+
eval_df: pd.DataFrame,
|
449
|
+
input_df: pd.DataFrame,
|
450
|
+
other_output_df: Optional[pd.DataFrame],
|
451
|
+
) -> tuple[bool, list[Union[str, pd.DataFrame]]]:
|
452
|
+
"""
|
453
|
+
Given a metric_tuple, read the signature of the metric function and get the appropriate
|
454
|
+
arguments from the input/output columns, other calculated metrics, and evaluator_config.
|
455
|
+
|
456
|
+
Args:
|
457
|
+
metric: The metric definition containing a user provided function and its index
|
458
|
+
in the ``extra_metrics`` parameter of ``mlflow.evaluate``.
|
459
|
+
eval_df: The evaluation dataframe containing the prediction and target columns.
|
460
|
+
input_df: The input dataframe containing the features used to make predictions.
|
461
|
+
other_output_df: A dataframe containing all model output columns but the predictions.
|
462
|
+
|
463
|
+
Returns:
|
464
|
+
tuple: A tuple of (bool, list) where the bool indicates if the given metric can
|
465
|
+
be calculated with the given eval_df, metrics, and input_df.
|
466
|
+
- If the user is missing "targets" or "predictions" parameters when needed, or we
|
467
|
+
cannot find a column or metric for a parameter to the metric, return
|
468
|
+
(False, list of missing parameters)
|
469
|
+
- If all arguments to the metric function were found, return
|
470
|
+
(True, list of arguments).
|
471
|
+
"""
|
472
|
+
# deepcopying eval_df and builtin_metrics for each custom metric function call,
|
473
|
+
# in case the user modifies them inside their function(s).
|
474
|
+
eval_df_copy = eval_df.copy()
|
475
|
+
parameters = inspect.signature(metric.function).parameters
|
476
|
+
eval_fn_args = []
|
477
|
+
params_not_found = []
|
478
|
+
if len(parameters) == 2:
|
479
|
+
param_0_name, param_1_name = parameters.keys()
|
480
|
+
|
481
|
+
# eval_fn has parameters (eval_df, builtin_metrics) for backwards compatibility
|
482
|
+
if len(parameters) == 2 and param_0_name != "predictions" and param_1_name != "targets":
|
483
|
+
eval_fn_args.append(eval_df_copy)
|
484
|
+
self._update_aggregate_metrics()
|
485
|
+
eval_fn_args.append(copy.deepcopy(self.aggregate_metrics))
|
486
|
+
# eval_fn can have parameters like (predictions, targets, metrics, random_col)
|
487
|
+
else:
|
488
|
+
for param_name, param in parameters.items():
|
489
|
+
column = self.col_mapping.get(param_name, param_name)
|
490
|
+
|
491
|
+
if (
|
492
|
+
column == "predictions"
|
493
|
+
or column == self.predictions
|
494
|
+
or column == self.dataset.predictions_name
|
495
|
+
):
|
496
|
+
eval_fn_args.append(eval_df_copy["prediction"])
|
497
|
+
elif column == "targets" or column == self.dataset.targets_name:
|
498
|
+
if "target" in eval_df_copy:
|
499
|
+
eval_fn_args.append(eval_df_copy["target"])
|
500
|
+
else:
|
501
|
+
if param.default == inspect.Parameter.empty:
|
502
|
+
params_not_found.append(param_name)
|
503
|
+
else:
|
504
|
+
eval_fn_args.append(param.default)
|
505
|
+
elif column == "metrics":
|
506
|
+
eval_fn_args.append(copy.deepcopy(self.metrics_values))
|
507
|
+
else:
|
508
|
+
# case when column passed in col_mapping contains the entire column
|
509
|
+
if not isinstance(column, str):
|
510
|
+
eval_fn_args.append(column)
|
511
|
+
|
512
|
+
# case column in col_mapping is string and the column value
|
513
|
+
# is part of the input_df
|
514
|
+
elif column in input_df.columns:
|
515
|
+
eval_fn_args.append(input_df[column])
|
516
|
+
|
517
|
+
# case column in col_mapping is string and the column value
|
518
|
+
# is part of the output_df(other than predictions)
|
519
|
+
elif other_output_df is not None and column in other_output_df.columns:
|
520
|
+
self.other_output_columns_for_eval.add(column)
|
521
|
+
eval_fn_args.append(other_output_df[column])
|
522
|
+
|
523
|
+
# case where the param is defined as part of the evaluator_config
|
524
|
+
elif column in self.evaluator_config:
|
525
|
+
eval_fn_args.append(self.evaluator_config.get(column))
|
526
|
+
|
527
|
+
# case where this is the name of another metric
|
528
|
+
elif metric_value := self._get_column_in_metrics_values(column):
|
529
|
+
eval_fn_args.append(metric_value)
|
530
|
+
|
531
|
+
# in the case that:
|
532
|
+
# the metric has not been calculated yet, but is scheduled to be calculated
|
533
|
+
# "before" this metric in self.ordered_metrics, we append None to indicate
|
534
|
+
# that there is not an error in the dependencies
|
535
|
+
elif column in [metric_tuple.name for metric_tuple in self.ordered_metrics]:
|
536
|
+
eval_fn_args.append(None)
|
537
|
+
|
538
|
+
elif param.default == inspect.Parameter.empty:
|
539
|
+
params_not_found.append(param_name)
|
540
|
+
else:
|
541
|
+
eval_fn_args.append(param.default)
|
542
|
+
|
543
|
+
if len(params_not_found) > 0:
|
544
|
+
return False, params_not_found
|
545
|
+
return True, eval_fn_args
|
546
|
+
|
547
|
+
def evaluate_and_log_custom_artifacts(
|
548
|
+
self,
|
549
|
+
custom_artifacts: list[_CustomArtifact],
|
550
|
+
prediction: pd.Series,
|
551
|
+
target: Optional[np.array] = None,
|
552
|
+
):
|
553
|
+
"""Evaluate custom artifacts provided by users."""
|
554
|
+
if not custom_artifacts:
|
555
|
+
return
|
556
|
+
|
557
|
+
eval_df = self._get_eval_df(prediction, target)
|
558
|
+
for index, custom_artifact in enumerate(custom_artifacts):
|
559
|
+
with tempfile.TemporaryDirectory() as artifacts_dir:
|
560
|
+
# deepcopying eval_df and builtin_metrics for each custom artifact function call,
|
561
|
+
# in case the user modifies them inside their function(s).
|
562
|
+
custom_artifact_tuple = _CustomArtifact(
|
563
|
+
function=custom_artifact,
|
564
|
+
index=index,
|
565
|
+
name=getattr(custom_artifact, "__name__", repr(custom_artifact)),
|
566
|
+
artifacts_dir=artifacts_dir,
|
567
|
+
)
|
568
|
+
artifact_results = _evaluate_custom_artifacts(
|
569
|
+
custom_artifact_tuple,
|
570
|
+
eval_df.copy(),
|
571
|
+
copy.deepcopy(self.metrics_values),
|
572
|
+
)
|
573
|
+
if artifact_results:
|
574
|
+
for artifact_name, raw_artifact in artifact_results.items():
|
575
|
+
self.artifacts[artifact_name] = self._log_custom_metric_artifact(
|
576
|
+
artifact_name,
|
577
|
+
raw_artifact,
|
578
|
+
custom_artifact_tuple,
|
579
|
+
)
|
580
|
+
|
581
|
+
    def _get_error_message_missing_columns(self, metric_name, param_names):
        error_message_parts = [f"Metric '{metric_name}' requires the following:"]

        special_params = ["targets", "predictions"]
        for param in special_params:
            if param in param_names:
                error_message_parts.append(f" - the '{param}' parameter needs to be specified")

        remaining_params = [param for param in param_names if param not in special_params]

        if remaining_params:
            error_message_parts.append(
                f" - missing columns {remaining_params} need to be defined or mapped"
            )

        return "\n".join(error_message_parts)

    def _construct_error_message_for_malformed_metrics(
        self, malformed_results, input_columns, output_columns
    ):
        error_messages = [
            self._get_error_message_missing_columns(metric_name, param_names)
            for metric_name, param_names in malformed_results
        ]
        joined_error_message = "\n".join(error_messages)

        full_message = f"""Error: Metric calculation failed for the following metrics:
{joined_error_message}

Below are the existing column names for the input/output data:
Input Columns: {input_columns}
Output Columns: {output_columns}

To resolve this issue, you may need to:
- specify any required parameters
- if you are missing columns, check that there are no circular dependencies among your
metrics, and you may want to map them to an existing column using the following
configuration:
evaluator_config={{'col_mapping': {{<missing column name>: <existing column name>}}}}"""

        return "\n".join(l.lstrip() for l in full_message.splitlines())

    def _raise_exception_for_malformed_metrics(self, malformed_results, eval_df, other_output_df):
        output_columns = [] if other_output_df is None else list(other_output_df.columns)
        if self.predictions:
            output_columns.append(self.predictions)
        elif self.dataset.predictions_name:
            output_columns.append(self.dataset.predictions_name)
        else:
            output_columns.append("predictions")

        input_columns = list(self.X.copy_to_avoid_mutation().columns)
        if "target" in eval_df:
            if self.dataset.targets_name:
                input_columns.append(self.dataset.targets_name)
            else:
                input_columns.append("targets")

        error_message = self._construct_error_message_for_malformed_metrics(
            malformed_results, input_columns, output_columns
        )

        raise MlflowException(error_message, error_code=INVALID_PARAMETER_VALUE)

    def _get_eval_df(self, prediction: pd.Series, target: Optional[np.ndarray] = None):
        """
        Create a DataFrame with "prediction" and "target" columns.

        This is a standard format that can be passed to the metric functions.
        """
        eval_df = pd.DataFrame({"prediction": copy.deepcopy(prediction)})
        if target is not None:
            eval_df["target"] = target
        return eval_df

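    # Editor's note (illustrative): _get_eval_df normalizes inputs into the two-column
    # frame consumed by the metric functions, e.g.
    #     self._get_eval_df(pd.Series(["a", "b"]), np.array(["a", "c"]))
    # yields a DataFrame with columns "prediction" and "target".
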
    def _order_metrics(
        self,
        metrics: list[EvaluationMetric],
        eval_df: pd.DataFrame,
        other_output_df: Optional[pd.DataFrame],
    ):
        """
        Order the list of metrics so they can be computed in sequence.

        Some metrics might use the results of other metrics to compute their own results. This
        function iteratively resolves these dependencies by checking whether each metric can be
        computed with the currently available columns and metrics values.
        """
        remaining_metrics = metrics
        input_df = self.X.copy_to_avoid_mutation()

        while len(remaining_metrics) > 0:
            pending_metrics = []
            failed_results = []
            did_append_metric = False
            for metric_tuple in remaining_metrics:
                can_calculate, eval_fn_args = self._get_args_for_metrics(
                    metric_tuple, eval_df, input_df, other_output_df
                )
                if can_calculate:
                    self.ordered_metrics.append(metric_tuple)
                    did_append_metric = True
                else:  # cannot calculate the metric yet
                    pending_metrics.append(metric_tuple)
                    failed_results.append((metric_tuple.name, eval_fn_args))

            # can't calculate any more metrics
            if not did_append_metric:
                self._raise_exception_for_malformed_metrics(
                    failed_results, eval_df, other_output_df
                )

            remaining_metrics = pending_metrics

        return self.ordered_metrics

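    # Editor's note (illustrative, not part of the packaged source): _order_metrics is a
    # fixed-point iteration that behaves like a topological sort over metric dependencies.
    # Each pass moves every currently-computable metric into self.ordered_metrics; if a
    # full pass schedules nothing, the remaining metrics have unsatisfiable (possibly
    # circular) dependencies and _raise_exception_for_malformed_metrics aborts with the
    # col_mapping guidance shown above.
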
    def _test_first_row(
        self,
        metrics: list[MetricDefinition],
        eval_df: pd.DataFrame,
        other_output_df: Optional[pd.DataFrame],
    ):
        # test calculations on first row of eval_df
        _logger.info("Testing metrics on first row...")
        exceptions = []
        first_row_df = eval_df.iloc[[0]]
        first_row_input_df = self.X.copy_to_avoid_mutation().iloc[[0]]
        for metric in metrics:
            try:
                _, eval_fn_args = self._get_args_for_metrics(
                    metric, first_row_df, first_row_input_df, other_output_df
                )
                metric_value = metric.evaluate(eval_fn_args)
                if metric_value:
                    name = f"{metric.name}/{metric.version}" if metric.version else metric.name
                    self.metrics_values.update({name: metric_value})
            except Exception as e:
                stacktrace_str = traceback.format_exc()
                if isinstance(e, MlflowException):
                    exceptions.append(
                        f"Metric '{metric.name}': Error:\n{e.message}\n{stacktrace_str}"
                    )
                else:
                    exceptions.append(f"Metric '{metric.name}': Error:\n{e!r}\n{stacktrace_str}")

        if len(exceptions) > 0:
            raise MlflowException("\n".join(exceptions))

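    # Editor's note: _test_first_row is a fail-fast dry run. Evaluating every metric on a
    # single row surfaces misconfigured metrics cheaply before the full (potentially
    # expensive, e.g. LLM-judged) pass over eval_df. Values computed here are recorded in
    # self.metrics_values and simply overwritten by the full pass in evaluate_metrics.
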
    def evaluate_metrics(
        self,
        metrics: list[EvaluationMetric],
        prediction: pd.Series,
        target: Optional[np.ndarray] = None,
        other_output_df: Optional[pd.DataFrame] = None,
    ):
        """
        Evaluate the metrics on the given prediction and target data.

        Args:
            metrics: A list of metrics to evaluate.
            prediction: A Pandas Series containing the predictions.
            target: A NumPy array containing the target values.
            other_output_df: A Pandas DataFrame containing other output columns from the model.

        Returns:
            None, the metrics values are recorded in the self.metrics_values dictionary.
        """

        eval_df = self._get_eval_df(prediction, target)
        metrics = [
            MetricDefinition.from_index_and_metric(i, metric) for i, metric in enumerate(metrics)
        ]
        metrics = self._order_metrics(metrics, eval_df, other_output_df)

        self._test_first_row(metrics, eval_df, other_output_df)

        # calculate metrics for the full eval_df
        input_df = self.X.copy_to_avoid_mutation()
        for metric in metrics:
            _, eval_fn_args = self._get_args_for_metrics(metric, eval_df, input_df, other_output_df)
            metric_value = metric.evaluate(eval_fn_args)

            if metric_value:
                name = f"{metric.name}/{metric.version}" if metric.version else metric.name
                self.metrics_values.update({name: metric_value})

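    # Editor's note (illustrative): metric values are keyed as "<name>/<version>" when the
    # metric defines a version, so a hypothetical metric named "similarity" at version "v1"
    # is stored under "similarity/v1"; unversioned metrics are stored under their bare name.
    # The overall flow is: order by dependency, dry-run on the first row, then evaluate the
    # full DataFrame.
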
    def log_eval_table(self, y_pred, other_output_columns=None):
        # only log the eval table if there are per-row metrics recorded
        if not any(
            metric_value.scores is not None or metric_value.justifications is not None
            for _, metric_value in self.metrics_values.items()
        ):
            return

        metric_prefix = self.evaluator_config.get("metric_prefix", "")
        if not isinstance(metric_prefix, str):
            metric_prefix = ""
        if isinstance(self.dataset.features_data, pd.DataFrame):
            # Handle DataFrame case
            if self.dataset.has_targets:
                data = self.dataset.features_data.assign(
                    **{
                        self.dataset.targets_name or "target": self.dataset.labels_data,
                        self.dataset.predictions_name or self.predictions or "outputs": y_pred,
                    }
                )
            else:
                data = self.dataset.features_data.assign(outputs=y_pred)
        else:
            # Handle NumPy array case, converting it to a DataFrame
            data = pd.DataFrame(self.dataset.features_data, columns=self.dataset.feature_names)
            if self.dataset.has_targets:
                data = data.assign(
                    **{
                        self.dataset.targets_name or "target": self.dataset.labels_data,
                        self.dataset.predictions_name or self.predictions or "outputs": y_pred,
                    }
                )
            else:
                data = data.assign(outputs=y_pred)

        # Include other_output_columns used in evaluation in the eval table
        if other_output_columns is not None and len(self.other_output_columns_for_eval) > 0:
            for column in self.other_output_columns_for_eval:
                data[column] = other_output_columns[column]

        columns = {}
        for metric_name, metric_value in self.metrics_values.items():
            scores = metric_value.scores
            justifications = metric_value.justifications

            if scores:
                if metric_name.startswith(metric_prefix) and metric_name[len(metric_prefix) :] in [
                    _TOKEN_COUNT_METRIC_NAME,
                    _LATENCY_METRIC_NAME,
                ]:
                    columns[metric_name] = scores
                else:
                    columns[f"{metric_name}/score"] = scores
            if justifications:
                columns[f"{metric_name}/justification"] = justifications
        data = data.assign(**columns)
        artifact_file_name = f"{metric_prefix}{_EVAL_TABLE_FILE_NAME}"
        mlflow.log_table(data, artifact_file=artifact_file_name)
        if self.eval_results_path:
            eval_table_spark = self.spark_session.createDataFrame(data)
            try:
                eval_table_spark.write.mode(self.eval_results_mode).option(
                    "mergeSchema", "true"
                ).format("delta").saveAsTable(self.eval_results_path)
            except Exception as e:
                _logger.info(f"Saving eval table to delta table failed. Reason: {e}")

        name = _EVAL_TABLE_FILE_NAME.split(".", 1)[0]
        self.artifacts[name] = JsonEvaluationArtifact(
            uri=mlflow.get_artifact_uri(artifact_file_name)
        )

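    # Editor's note: per-row results land in the eval table as "<metric>/score" and
    # "<metric>/justification" columns (token-count and latency scores keep their bare
    # names), and the optional Delta write above deliberately logs rather than raises on
    # failure, so a broken eval_results_path does not lose the in-run table artifact.
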
    def _update_aggregate_metrics(self):
        self.aggregate_metrics = {}
        for metric_name, metric_value in self.metrics_values.items():
            if metric_value.aggregate_results:
                for agg_name, agg_value in metric_value.aggregate_results.items():
                    if agg_value is not None:
                        if agg_name == metric_name.split("/")[0]:
                            self.aggregate_metrics[metric_name] = agg_value
                        else:
                            self.aggregate_metrics[f"{metric_name}/{agg_name}"] = agg_value

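    # Editor's note (illustrative): an aggregate "mean" for metric "toxicity/v1" is
    # recorded as "toxicity/v1/mean", while an aggregate whose name matches the metric's
    # base name (e.g. a hypothetical "exact_match" aggregate of metric "exact_match/v1")
    # keeps the metric name itself, avoiding keys like "exact_match/v1/exact_match".
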
    def _add_prefix_to_metrics(self):
        def _prefix_value(value):
            aggregate = (
                {f"{prefix}{k}": v for k, v in value.aggregate_results.items()}
                if value.aggregate_results
                else None
            )
            return MetricValue(value.scores, value.justifications, aggregate)

        if prefix := self.evaluator_config.get("metric_prefix"):
            self.metrics_values = {
                f"{prefix}{k}": _prefix_value(v) for k, v in self.metrics_values.items()
            }

        self._update_aggregate_metrics()

    def evaluate(
        self,
        *,
        model_type,
        dataset,
        run_id,
        evaluator_config,
        model: "mlflow.pyfunc.PyFuncModel" = None,
        extra_metrics=None,
        custom_artifacts=None,
        predictions=None,
        model_id=None,
        **kwargs,
    ) -> EvaluationResult:
        if model is None and predictions is None and dataset.predictions_data is None:
            raise MlflowException(
                message=(
                    "Either a model or set of predictions must be specified in order to use the"
                    " default evaluator. Either specify the `model` parameter, the `predictions`"
                    " parameter, an MLflow dataset containing the `predictions` column name"
                    " (via the `data` parameter), or a different evaluator (via the `evaluators`"
                    " parameter)."
                ),
                error_code=INVALID_PARAMETER_VALUE,
            )

        self.artifacts = {}
        self.aggregate_metrics = {}
        self.metrics_values = {}
        self.ordered_metrics = []
        self.other_output_columns_for_eval = set()

        self.dataset: EvaluationDataset = dataset
        self.run_id = run_id
        self.model_type = model_type
        self.model_id = model_id
        self.evaluator_config = evaluator_config

        self.predictions = predictions
        self.col_mapping = self.evaluator_config.get("col_mapping", {})
        self.eval_results_path = self.evaluator_config.get("eval_results_path")
        self.eval_results_mode = self.evaluator_config.get("eval_results_mode", "overwrite")

        if self.eval_results_path:
            from mlflow.utils._spark_utils import _get_active_spark_session

            self.spark_session = _get_active_spark_session()
            if not self.spark_session:
                raise MlflowException(
                    message="eval_results_path is only supported in a Spark environment.",
                    error_code=INVALID_PARAMETER_VALUE,
                )

            if self.eval_results_mode not in ["overwrite", "append"]:
                raise MlflowException(
                    message="eval_results_mode can only be 'overwrite' or 'append'.",
                    error_code=INVALID_PARAMETER_VALUE,
                )

        if extra_metrics is None:
            extra_metrics = []

        bad_metrics = []
        for metric in extra_metrics:
            if not isinstance(metric, EvaluationMetric):
                bad_metrics.append(metric)
        if len(bad_metrics) > 0:
            message = "\n".join(
                [f"- Metric '{m}' has type '{type(m).__name__}'" for m in bad_metrics]
            )
            raise MlflowException(
                "In the 'extra_metrics' parameter, the following metrics have the wrong type:\n"
                f"{message}\n"
                "Please ensure that all extra metrics are instances of "
                "mlflow.metrics.EvaluationMetric."
            )

        import matplotlib

        with TempDir() as temp_dir, matplotlib.rc_context(_matplotlib_config):
            self.temp_dir = temp_dir
            return self._evaluate(model, extra_metrics, custom_artifacts)

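    # Editor's note (illustrative, config keys as read in evaluate() above): the built-in
    # evaluator consumes "col_mapping", "eval_results_path", "eval_results_mode", and
    # "metric_prefix" from evaluator_config, e.g. a hypothetical
    #     evaluator_config={
    #         "metric_prefix": "eval_",
    #         "eval_results_path": "catalog.schema.eval_table",  # requires a Spark session
    #         "eval_results_mode": "append",
    #     }
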
    @property
    def X(self) -> "BuiltInEvaluator._MutationGuardedData":
        """
        The features (`X`) portion of the dataset, guarded against accidental mutations.
        """
        return BuiltInEvaluator._MutationGuardedData(
            _get_dataframe_with_renamed_columns(
                self.dataset.features_data, self.dataset.feature_names
            )
        )

    class _MutationGuardedData:
        """
        Wrapper around a data object that requires explicit API calls to obtain either a copy
        of the data object or, in cases where the caller can guarantee that the object will
        not be mutated, the original data object.
        """

        def __init__(self, data):
            """
            Args:
                data: A data object, such as a Pandas DataFrame, NumPy array, or list.
            """
            self._data = data

        def copy_to_avoid_mutation(self):
            """
            Obtain a copy of the data. This method should be called every time the data needs
            to be used in a context where it may be subsequently mutated, guarding against
            accidental reuse after mutation.

            Returns:
                A copy of the data object.
            """
            if isinstance(self._data, pd.DataFrame):
                return self._data.copy(deep=True)
            else:
                return copy.deepcopy(self._data)

        def get_original(self):
            """
            Obtain the original data object. This method should only be called if the caller
            can guarantee that it will not mutate the data during subsequent operations.

            Returns:
                The original data object.
            """
            return self._data
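
    # Editor's note (illustrative usage): callers pick the safe-but-copying or
    # cheap-but-unguarded accessor explicitly, e.g.
    #     guarded = BuiltInEvaluator._MutationGuardedData(df)
    #     df_copy = guarded.copy_to_avoid_mutation()  # safe to mutate
    #     df_view = guarded.get_original()            # caller promises not to mutate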