genesis-flow 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genesis_flow-1.0.0.dist-info/METADATA +822 -0
- genesis_flow-1.0.0.dist-info/RECORD +645 -0
- genesis_flow-1.0.0.dist-info/WHEEL +5 -0
- genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
- genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
- genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
- mlflow/__init__.py +367 -0
- mlflow/__main__.py +3 -0
- mlflow/ag2/__init__.py +56 -0
- mlflow/ag2/ag2_logger.py +294 -0
- mlflow/anthropic/__init__.py +40 -0
- mlflow/anthropic/autolog.py +129 -0
- mlflow/anthropic/chat.py +144 -0
- mlflow/artifacts/__init__.py +268 -0
- mlflow/autogen/__init__.py +144 -0
- mlflow/autogen/chat.py +142 -0
- mlflow/azure/__init__.py +26 -0
- mlflow/azure/auth_handler.py +257 -0
- mlflow/azure/client.py +319 -0
- mlflow/azure/config.py +120 -0
- mlflow/azure/connection_factory.py +340 -0
- mlflow/azure/exceptions.py +27 -0
- mlflow/azure/stores.py +327 -0
- mlflow/azure/utils.py +183 -0
- mlflow/bedrock/__init__.py +45 -0
- mlflow/bedrock/_autolog.py +202 -0
- mlflow/bedrock/chat.py +122 -0
- mlflow/bedrock/stream.py +160 -0
- mlflow/bedrock/utils.py +43 -0
- mlflow/cli.py +707 -0
- mlflow/client.py +12 -0
- mlflow/config/__init__.py +56 -0
- mlflow/crewai/__init__.py +79 -0
- mlflow/crewai/autolog.py +253 -0
- mlflow/crewai/chat.py +29 -0
- mlflow/data/__init__.py +75 -0
- mlflow/data/artifact_dataset_sources.py +170 -0
- mlflow/data/code_dataset_source.py +40 -0
- mlflow/data/dataset.py +123 -0
- mlflow/data/dataset_registry.py +168 -0
- mlflow/data/dataset_source.py +110 -0
- mlflow/data/dataset_source_registry.py +219 -0
- mlflow/data/delta_dataset_source.py +167 -0
- mlflow/data/digest_utils.py +108 -0
- mlflow/data/evaluation_dataset.py +562 -0
- mlflow/data/filesystem_dataset_source.py +81 -0
- mlflow/data/http_dataset_source.py +145 -0
- mlflow/data/huggingface_dataset.py +258 -0
- mlflow/data/huggingface_dataset_source.py +118 -0
- mlflow/data/meta_dataset.py +104 -0
- mlflow/data/numpy_dataset.py +223 -0
- mlflow/data/pandas_dataset.py +231 -0
- mlflow/data/polars_dataset.py +352 -0
- mlflow/data/pyfunc_dataset_mixin.py +31 -0
- mlflow/data/schema.py +76 -0
- mlflow/data/sources.py +1 -0
- mlflow/data/spark_dataset.py +406 -0
- mlflow/data/spark_dataset_source.py +74 -0
- mlflow/data/spark_delta_utils.py +118 -0
- mlflow/data/tensorflow_dataset.py +350 -0
- mlflow/data/uc_volume_dataset_source.py +81 -0
- mlflow/db.py +27 -0
- mlflow/dspy/__init__.py +17 -0
- mlflow/dspy/autolog.py +197 -0
- mlflow/dspy/callback.py +398 -0
- mlflow/dspy/constant.py +1 -0
- mlflow/dspy/load.py +93 -0
- mlflow/dspy/save.py +393 -0
- mlflow/dspy/util.py +109 -0
- mlflow/dspy/wrapper.py +226 -0
- mlflow/entities/__init__.py +104 -0
- mlflow/entities/_mlflow_object.py +52 -0
- mlflow/entities/assessment.py +545 -0
- mlflow/entities/assessment_error.py +80 -0
- mlflow/entities/assessment_source.py +141 -0
- mlflow/entities/dataset.py +92 -0
- mlflow/entities/dataset_input.py +51 -0
- mlflow/entities/dataset_summary.py +62 -0
- mlflow/entities/document.py +48 -0
- mlflow/entities/experiment.py +109 -0
- mlflow/entities/experiment_tag.py +35 -0
- mlflow/entities/file_info.py +45 -0
- mlflow/entities/input_tag.py +35 -0
- mlflow/entities/lifecycle_stage.py +35 -0
- mlflow/entities/logged_model.py +228 -0
- mlflow/entities/logged_model_input.py +26 -0
- mlflow/entities/logged_model_output.py +32 -0
- mlflow/entities/logged_model_parameter.py +46 -0
- mlflow/entities/logged_model_status.py +74 -0
- mlflow/entities/logged_model_tag.py +33 -0
- mlflow/entities/metric.py +200 -0
- mlflow/entities/model_registry/__init__.py +29 -0
- mlflow/entities/model_registry/_model_registry_entity.py +13 -0
- mlflow/entities/model_registry/model_version.py +243 -0
- mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
- mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
- mlflow/entities/model_registry/model_version_search.py +25 -0
- mlflow/entities/model_registry/model_version_stages.py +25 -0
- mlflow/entities/model_registry/model_version_status.py +35 -0
- mlflow/entities/model_registry/model_version_tag.py +35 -0
- mlflow/entities/model_registry/prompt.py +73 -0
- mlflow/entities/model_registry/prompt_version.py +244 -0
- mlflow/entities/model_registry/registered_model.py +175 -0
- mlflow/entities/model_registry/registered_model_alias.py +35 -0
- mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
- mlflow/entities/model_registry/registered_model_search.py +25 -0
- mlflow/entities/model_registry/registered_model_tag.py +35 -0
- mlflow/entities/multipart_upload.py +74 -0
- mlflow/entities/param.py +49 -0
- mlflow/entities/run.py +97 -0
- mlflow/entities/run_data.py +84 -0
- mlflow/entities/run_info.py +188 -0
- mlflow/entities/run_inputs.py +59 -0
- mlflow/entities/run_outputs.py +43 -0
- mlflow/entities/run_status.py +41 -0
- mlflow/entities/run_tag.py +36 -0
- mlflow/entities/source_type.py +31 -0
- mlflow/entities/span.py +774 -0
- mlflow/entities/span_event.py +96 -0
- mlflow/entities/span_status.py +102 -0
- mlflow/entities/trace.py +317 -0
- mlflow/entities/trace_data.py +71 -0
- mlflow/entities/trace_info.py +220 -0
- mlflow/entities/trace_info_v2.py +162 -0
- mlflow/entities/trace_location.py +173 -0
- mlflow/entities/trace_state.py +39 -0
- mlflow/entities/trace_status.py +68 -0
- mlflow/entities/view_type.py +51 -0
- mlflow/environment_variables.py +866 -0
- mlflow/evaluation/__init__.py +16 -0
- mlflow/evaluation/assessment.py +369 -0
- mlflow/evaluation/evaluation.py +411 -0
- mlflow/evaluation/evaluation_tag.py +61 -0
- mlflow/evaluation/fluent.py +48 -0
- mlflow/evaluation/utils.py +201 -0
- mlflow/exceptions.py +213 -0
- mlflow/experiments.py +140 -0
- mlflow/gemini/__init__.py +81 -0
- mlflow/gemini/autolog.py +186 -0
- mlflow/gemini/chat.py +261 -0
- mlflow/genai/__init__.py +71 -0
- mlflow/genai/datasets/__init__.py +67 -0
- mlflow/genai/datasets/evaluation_dataset.py +131 -0
- mlflow/genai/evaluation/__init__.py +3 -0
- mlflow/genai/evaluation/base.py +411 -0
- mlflow/genai/evaluation/constant.py +23 -0
- mlflow/genai/evaluation/utils.py +244 -0
- mlflow/genai/judges/__init__.py +21 -0
- mlflow/genai/judges/databricks.py +404 -0
- mlflow/genai/label_schemas/__init__.py +153 -0
- mlflow/genai/label_schemas/label_schemas.py +209 -0
- mlflow/genai/labeling/__init__.py +159 -0
- mlflow/genai/labeling/labeling.py +250 -0
- mlflow/genai/optimize/__init__.py +13 -0
- mlflow/genai/optimize/base.py +198 -0
- mlflow/genai/optimize/optimizers/__init__.py +4 -0
- mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
- mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
- mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
- mlflow/genai/optimize/types.py +75 -0
- mlflow/genai/optimize/util.py +30 -0
- mlflow/genai/prompts/__init__.py +206 -0
- mlflow/genai/scheduled_scorers.py +431 -0
- mlflow/genai/scorers/__init__.py +26 -0
- mlflow/genai/scorers/base.py +492 -0
- mlflow/genai/scorers/builtin_scorers.py +765 -0
- mlflow/genai/scorers/scorer_utils.py +138 -0
- mlflow/genai/scorers/validation.py +165 -0
- mlflow/genai/utils/data_validation.py +146 -0
- mlflow/genai/utils/enum_utils.py +23 -0
- mlflow/genai/utils/trace_utils.py +211 -0
- mlflow/groq/__init__.py +42 -0
- mlflow/groq/_groq_autolog.py +74 -0
- mlflow/johnsnowlabs/__init__.py +888 -0
- mlflow/langchain/__init__.py +24 -0
- mlflow/langchain/api_request_parallel_processor.py +330 -0
- mlflow/langchain/autolog.py +147 -0
- mlflow/langchain/chat_agent_langgraph.py +340 -0
- mlflow/langchain/constant.py +1 -0
- mlflow/langchain/constants.py +1 -0
- mlflow/langchain/databricks_dependencies.py +444 -0
- mlflow/langchain/langchain_tracer.py +597 -0
- mlflow/langchain/model.py +919 -0
- mlflow/langchain/output_parsers.py +142 -0
- mlflow/langchain/retriever_chain.py +153 -0
- mlflow/langchain/runnables.py +527 -0
- mlflow/langchain/utils/chat.py +402 -0
- mlflow/langchain/utils/logging.py +671 -0
- mlflow/langchain/utils/serialization.py +36 -0
- mlflow/legacy_databricks_cli/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/provider.py +482 -0
- mlflow/litellm/__init__.py +175 -0
- mlflow/llama_index/__init__.py +22 -0
- mlflow/llama_index/autolog.py +55 -0
- mlflow/llama_index/chat.py +43 -0
- mlflow/llama_index/constant.py +1 -0
- mlflow/llama_index/model.py +577 -0
- mlflow/llama_index/pyfunc_wrapper.py +332 -0
- mlflow/llama_index/serialize_objects.py +188 -0
- mlflow/llama_index/tracer.py +561 -0
- mlflow/metrics/__init__.py +479 -0
- mlflow/metrics/base.py +39 -0
- mlflow/metrics/genai/__init__.py +25 -0
- mlflow/metrics/genai/base.py +101 -0
- mlflow/metrics/genai/genai_metric.py +771 -0
- mlflow/metrics/genai/metric_definitions.py +450 -0
- mlflow/metrics/genai/model_utils.py +371 -0
- mlflow/metrics/genai/prompt_template.py +68 -0
- mlflow/metrics/genai/prompts/__init__.py +0 -0
- mlflow/metrics/genai/prompts/v1.py +422 -0
- mlflow/metrics/genai/utils.py +6 -0
- mlflow/metrics/metric_definitions.py +619 -0
- mlflow/mismatch.py +34 -0
- mlflow/mistral/__init__.py +34 -0
- mlflow/mistral/autolog.py +71 -0
- mlflow/mistral/chat.py +135 -0
- mlflow/ml_package_versions.py +452 -0
- mlflow/models/__init__.py +97 -0
- mlflow/models/auth_policy.py +83 -0
- mlflow/models/cli.py +354 -0
- mlflow/models/container/__init__.py +294 -0
- mlflow/models/container/scoring_server/__init__.py +0 -0
- mlflow/models/container/scoring_server/nginx.conf +39 -0
- mlflow/models/dependencies_schemas.py +287 -0
- mlflow/models/display_utils.py +158 -0
- mlflow/models/docker_utils.py +211 -0
- mlflow/models/evaluation/__init__.py +23 -0
- mlflow/models/evaluation/_shap_patch.py +64 -0
- mlflow/models/evaluation/artifacts.py +194 -0
- mlflow/models/evaluation/base.py +1811 -0
- mlflow/models/evaluation/calibration_curve.py +109 -0
- mlflow/models/evaluation/default_evaluator.py +996 -0
- mlflow/models/evaluation/deprecated.py +23 -0
- mlflow/models/evaluation/evaluator_registry.py +80 -0
- mlflow/models/evaluation/evaluators/classifier.py +704 -0
- mlflow/models/evaluation/evaluators/default.py +233 -0
- mlflow/models/evaluation/evaluators/regressor.py +96 -0
- mlflow/models/evaluation/evaluators/shap.py +296 -0
- mlflow/models/evaluation/lift_curve.py +178 -0
- mlflow/models/evaluation/utils/metric.py +123 -0
- mlflow/models/evaluation/utils/trace.py +179 -0
- mlflow/models/evaluation/validation.py +434 -0
- mlflow/models/flavor_backend.py +93 -0
- mlflow/models/flavor_backend_registry.py +53 -0
- mlflow/models/model.py +1639 -0
- mlflow/models/model_config.py +150 -0
- mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
- mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
- mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
- mlflow/models/python_api.py +369 -0
- mlflow/models/rag_signatures.py +128 -0
- mlflow/models/resources.py +321 -0
- mlflow/models/signature.py +662 -0
- mlflow/models/utils.py +2054 -0
- mlflow/models/wheeled_model.py +280 -0
- mlflow/openai/__init__.py +57 -0
- mlflow/openai/_agent_tracer.py +364 -0
- mlflow/openai/api_request_parallel_processor.py +131 -0
- mlflow/openai/autolog.py +509 -0
- mlflow/openai/constant.py +1 -0
- mlflow/openai/model.py +824 -0
- mlflow/openai/utils/chat_schema.py +367 -0
- mlflow/optuna/__init__.py +3 -0
- mlflow/optuna/storage.py +646 -0
- mlflow/plugins/__init__.py +72 -0
- mlflow/plugins/base.py +358 -0
- mlflow/plugins/builtin/__init__.py +24 -0
- mlflow/plugins/builtin/pytorch_plugin.py +150 -0
- mlflow/plugins/builtin/sklearn_plugin.py +158 -0
- mlflow/plugins/builtin/transformers_plugin.py +187 -0
- mlflow/plugins/cli.py +321 -0
- mlflow/plugins/discovery.py +340 -0
- mlflow/plugins/manager.py +465 -0
- mlflow/plugins/registry.py +316 -0
- mlflow/plugins/templates/framework_plugin_template.py +329 -0
- mlflow/prompt/constants.py +20 -0
- mlflow/prompt/promptlab_model.py +197 -0
- mlflow/prompt/registry_utils.py +248 -0
- mlflow/promptflow/__init__.py +495 -0
- mlflow/protos/__init__.py +0 -0
- mlflow/protos/assessments_pb2.py +174 -0
- mlflow/protos/databricks_artifacts_pb2.py +489 -0
- mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
- mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
- mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
- mlflow/protos/databricks_pb2.py +267 -0
- mlflow/protos/databricks_trace_server_pb2.py +374 -0
- mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
- mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
- mlflow/protos/facet_feature_statistics_pb2.py +296 -0
- mlflow/protos/internal_pb2.py +77 -0
- mlflow/protos/mlflow_artifacts_pb2.py +336 -0
- mlflow/protos/model_registry_pb2.py +1073 -0
- mlflow/protos/scalapb/__init__.py +0 -0
- mlflow/protos/scalapb/scalapb_pb2.py +104 -0
- mlflow/protos/service_pb2.py +2600 -0
- mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
- mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
- mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
- mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
- mlflow/py.typed +0 -0
- mlflow/pydantic_ai/__init__.py +57 -0
- mlflow/pydantic_ai/autolog.py +173 -0
- mlflow/pyfunc/__init__.py +3844 -0
- mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
- mlflow/pyfunc/backend.py +523 -0
- mlflow/pyfunc/context.py +78 -0
- mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
- mlflow/pyfunc/loaders/__init__.py +7 -0
- mlflow/pyfunc/loaders/chat_agent.py +117 -0
- mlflow/pyfunc/loaders/chat_model.py +125 -0
- mlflow/pyfunc/loaders/code_model.py +31 -0
- mlflow/pyfunc/loaders/responses_agent.py +112 -0
- mlflow/pyfunc/mlserver.py +46 -0
- mlflow/pyfunc/model.py +1473 -0
- mlflow/pyfunc/scoring_server/__init__.py +604 -0
- mlflow/pyfunc/scoring_server/app.py +7 -0
- mlflow/pyfunc/scoring_server/client.py +146 -0
- mlflow/pyfunc/spark_model_cache.py +48 -0
- mlflow/pyfunc/stdin_server.py +44 -0
- mlflow/pyfunc/utils/__init__.py +3 -0
- mlflow/pyfunc/utils/data_validation.py +224 -0
- mlflow/pyfunc/utils/environment.py +22 -0
- mlflow/pyfunc/utils/input_converter.py +47 -0
- mlflow/pyfunc/utils/serving_data_parser.py +11 -0
- mlflow/pytorch/__init__.py +1171 -0
- mlflow/pytorch/_lightning_autolog.py +580 -0
- mlflow/pytorch/_pytorch_autolog.py +50 -0
- mlflow/pytorch/pickle_module.py +35 -0
- mlflow/rfunc/__init__.py +42 -0
- mlflow/rfunc/backend.py +134 -0
- mlflow/runs.py +89 -0
- mlflow/server/__init__.py +302 -0
- mlflow/server/auth/__init__.py +1224 -0
- mlflow/server/auth/__main__.py +4 -0
- mlflow/server/auth/basic_auth.ini +6 -0
- mlflow/server/auth/cli.py +11 -0
- mlflow/server/auth/client.py +537 -0
- mlflow/server/auth/config.py +34 -0
- mlflow/server/auth/db/__init__.py +0 -0
- mlflow/server/auth/db/cli.py +18 -0
- mlflow/server/auth/db/migrations/__init__.py +0 -0
- mlflow/server/auth/db/migrations/alembic.ini +110 -0
- mlflow/server/auth/db/migrations/env.py +76 -0
- mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
- mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
- mlflow/server/auth/db/models.py +67 -0
- mlflow/server/auth/db/utils.py +37 -0
- mlflow/server/auth/entities.py +165 -0
- mlflow/server/auth/logo.py +14 -0
- mlflow/server/auth/permissions.py +65 -0
- mlflow/server/auth/routes.py +18 -0
- mlflow/server/auth/sqlalchemy_store.py +263 -0
- mlflow/server/graphql/__init__.py +0 -0
- mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
- mlflow/server/graphql/graphql_custom_scalars.py +24 -0
- mlflow/server/graphql/graphql_errors.py +15 -0
- mlflow/server/graphql/graphql_no_batching.py +89 -0
- mlflow/server/graphql/graphql_schema_extensions.py +74 -0
- mlflow/server/handlers.py +3217 -0
- mlflow/server/prometheus_exporter.py +17 -0
- mlflow/server/validation.py +30 -0
- mlflow/shap/__init__.py +691 -0
- mlflow/sklearn/__init__.py +1994 -0
- mlflow/sklearn/utils.py +1041 -0
- mlflow/smolagents/__init__.py +66 -0
- mlflow/smolagents/autolog.py +139 -0
- mlflow/smolagents/chat.py +29 -0
- mlflow/store/__init__.py +10 -0
- mlflow/store/_unity_catalog/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/constants.py +2 -0
- mlflow/store/_unity_catalog/registry/__init__.py +6 -0
- mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
- mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
- mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
- mlflow/store/_unity_catalog/registry/utils.py +121 -0
- mlflow/store/artifact/__init__.py +0 -0
- mlflow/store/artifact/artifact_repo.py +472 -0
- mlflow/store/artifact/artifact_repository_registry.py +154 -0
- mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
- mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
- mlflow/store/artifact/cli.py +141 -0
- mlflow/store/artifact/cloud_artifact_repo.py +332 -0
- mlflow/store/artifact/databricks_artifact_repo.py +729 -0
- mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
- mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
- mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
- mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
- mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
- mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
- mlflow/store/artifact/ftp_artifact_repo.py +132 -0
- mlflow/store/artifact/gcs_artifact_repo.py +296 -0
- mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
- mlflow/store/artifact/http_artifact_repo.py +218 -0
- mlflow/store/artifact/local_artifact_repo.py +142 -0
- mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
- mlflow/store/artifact/models_artifact_repo.py +259 -0
- mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
- mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
- mlflow/store/artifact/r2_artifact_repo.py +70 -0
- mlflow/store/artifact/runs_artifact_repo.py +265 -0
- mlflow/store/artifact/s3_artifact_repo.py +330 -0
- mlflow/store/artifact/sftp_artifact_repo.py +141 -0
- mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
- mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
- mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
- mlflow/store/artifact/utils/__init__.py +0 -0
- mlflow/store/artifact/utils/models.py +148 -0
- mlflow/store/db/__init__.py +0 -0
- mlflow/store/db/base_sql_model.py +3 -0
- mlflow/store/db/db_types.py +10 -0
- mlflow/store/db/utils.py +314 -0
- mlflow/store/db_migrations/__init__.py +0 -0
- mlflow/store/db_migrations/alembic.ini +74 -0
- mlflow/store/db_migrations/env.py +84 -0
- mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
- mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
- mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
- mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
- mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
- mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
- mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
- mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
- mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
- mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
- mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
- mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
- mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
- mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
- mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
- mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
- mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
- mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
- mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
- mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
- mlflow/store/db_migrations/versions/__init__.py +0 -0
- mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
- mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
- mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
- mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
- mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
- mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
- mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
- mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
- mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
- mlflow/store/entities/__init__.py +3 -0
- mlflow/store/entities/paged_list.py +18 -0
- mlflow/store/model_registry/__init__.py +10 -0
- mlflow/store/model_registry/abstract_store.py +1081 -0
- mlflow/store/model_registry/base_rest_store.py +44 -0
- mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
- mlflow/store/model_registry/dbmodels/__init__.py +0 -0
- mlflow/store/model_registry/dbmodels/models.py +206 -0
- mlflow/store/model_registry/file_store.py +1091 -0
- mlflow/store/model_registry/rest_store.py +481 -0
- mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
- mlflow/store/tracking/__init__.py +23 -0
- mlflow/store/tracking/abstract_store.py +816 -0
- mlflow/store/tracking/dbmodels/__init__.py +0 -0
- mlflow/store/tracking/dbmodels/initial_models.py +243 -0
- mlflow/store/tracking/dbmodels/models.py +1073 -0
- mlflow/store/tracking/file_store.py +2438 -0
- mlflow/store/tracking/postgres_managed_identity.py +146 -0
- mlflow/store/tracking/rest_store.py +1131 -0
- mlflow/store/tracking/sqlalchemy_store.py +2785 -0
- mlflow/system_metrics/__init__.py +61 -0
- mlflow/system_metrics/metrics/__init__.py +0 -0
- mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
- mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
- mlflow/system_metrics/metrics/disk_monitor.py +21 -0
- mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
- mlflow/system_metrics/metrics/network_monitor.py +34 -0
- mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
- mlflow/system_metrics/system_metrics_monitor.py +198 -0
- mlflow/tracing/__init__.py +16 -0
- mlflow/tracing/assessment.py +356 -0
- mlflow/tracing/client.py +531 -0
- mlflow/tracing/config.py +125 -0
- mlflow/tracing/constant.py +105 -0
- mlflow/tracing/destination.py +81 -0
- mlflow/tracing/display/__init__.py +40 -0
- mlflow/tracing/display/display_handler.py +196 -0
- mlflow/tracing/export/async_export_queue.py +186 -0
- mlflow/tracing/export/inference_table.py +138 -0
- mlflow/tracing/export/mlflow_v3.py +137 -0
- mlflow/tracing/export/utils.py +70 -0
- mlflow/tracing/fluent.py +1417 -0
- mlflow/tracing/processor/base_mlflow.py +199 -0
- mlflow/tracing/processor/inference_table.py +175 -0
- mlflow/tracing/processor/mlflow_v3.py +47 -0
- mlflow/tracing/processor/otel.py +73 -0
- mlflow/tracing/provider.py +487 -0
- mlflow/tracing/trace_manager.py +200 -0
- mlflow/tracing/utils/__init__.py +616 -0
- mlflow/tracing/utils/artifact_utils.py +28 -0
- mlflow/tracing/utils/copy.py +55 -0
- mlflow/tracing/utils/environment.py +55 -0
- mlflow/tracing/utils/exception.py +21 -0
- mlflow/tracing/utils/once.py +35 -0
- mlflow/tracing/utils/otlp.py +63 -0
- mlflow/tracing/utils/processor.py +54 -0
- mlflow/tracing/utils/search.py +292 -0
- mlflow/tracing/utils/timeout.py +250 -0
- mlflow/tracing/utils/token.py +19 -0
- mlflow/tracing/utils/truncation.py +124 -0
- mlflow/tracing/utils/warning.py +76 -0
- mlflow/tracking/__init__.py +39 -0
- mlflow/tracking/_model_registry/__init__.py +1 -0
- mlflow/tracking/_model_registry/client.py +764 -0
- mlflow/tracking/_model_registry/fluent.py +853 -0
- mlflow/tracking/_model_registry/registry.py +67 -0
- mlflow/tracking/_model_registry/utils.py +251 -0
- mlflow/tracking/_tracking_service/__init__.py +0 -0
- mlflow/tracking/_tracking_service/client.py +883 -0
- mlflow/tracking/_tracking_service/registry.py +56 -0
- mlflow/tracking/_tracking_service/utils.py +275 -0
- mlflow/tracking/artifact_utils.py +179 -0
- mlflow/tracking/client.py +5900 -0
- mlflow/tracking/context/__init__.py +0 -0
- mlflow/tracking/context/abstract_context.py +35 -0
- mlflow/tracking/context/databricks_cluster_context.py +15 -0
- mlflow/tracking/context/databricks_command_context.py +15 -0
- mlflow/tracking/context/databricks_job_context.py +49 -0
- mlflow/tracking/context/databricks_notebook_context.py +41 -0
- mlflow/tracking/context/databricks_repo_context.py +43 -0
- mlflow/tracking/context/default_context.py +51 -0
- mlflow/tracking/context/git_context.py +32 -0
- mlflow/tracking/context/registry.py +98 -0
- mlflow/tracking/context/system_environment_context.py +15 -0
- mlflow/tracking/default_experiment/__init__.py +1 -0
- mlflow/tracking/default_experiment/abstract_context.py +43 -0
- mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
- mlflow/tracking/default_experiment/registry.py +75 -0
- mlflow/tracking/fluent.py +3595 -0
- mlflow/tracking/metric_value_conversion_utils.py +93 -0
- mlflow/tracking/multimedia.py +206 -0
- mlflow/tracking/registry.py +86 -0
- mlflow/tracking/request_auth/__init__.py +0 -0
- mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
- mlflow/tracking/request_auth/registry.py +60 -0
- mlflow/tracking/request_header/__init__.py +0 -0
- mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
- mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
- mlflow/tracking/request_header/default_request_header_provider.py +17 -0
- mlflow/tracking/request_header/registry.py +79 -0
- mlflow/transformers/__init__.py +2982 -0
- mlflow/transformers/flavor_config.py +258 -0
- mlflow/transformers/hub_utils.py +83 -0
- mlflow/transformers/llm_inference_utils.py +468 -0
- mlflow/transformers/model_io.py +301 -0
- mlflow/transformers/peft.py +51 -0
- mlflow/transformers/signature.py +183 -0
- mlflow/transformers/torch_utils.py +55 -0
- mlflow/types/__init__.py +21 -0
- mlflow/types/agent.py +270 -0
- mlflow/types/chat.py +240 -0
- mlflow/types/llm.py +935 -0
- mlflow/types/responses.py +139 -0
- mlflow/types/responses_helpers.py +416 -0
- mlflow/types/schema.py +1505 -0
- mlflow/types/type_hints.py +647 -0
- mlflow/types/utils.py +753 -0
- mlflow/utils/__init__.py +283 -0
- mlflow/utils/_capture_modules.py +256 -0
- mlflow/utils/_capture_transformers_modules.py +75 -0
- mlflow/utils/_spark_utils.py +201 -0
- mlflow/utils/_unity_catalog_oss_utils.py +97 -0
- mlflow/utils/_unity_catalog_utils.py +479 -0
- mlflow/utils/annotations.py +218 -0
- mlflow/utils/arguments_utils.py +16 -0
- mlflow/utils/async_logging/__init__.py +1 -0
- mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
- mlflow/utils/async_logging/async_logging_queue.py +366 -0
- mlflow/utils/async_logging/run_artifact.py +38 -0
- mlflow/utils/async_logging/run_batch.py +58 -0
- mlflow/utils/async_logging/run_operations.py +49 -0
- mlflow/utils/autologging_utils/__init__.py +737 -0
- mlflow/utils/autologging_utils/client.py +432 -0
- mlflow/utils/autologging_utils/config.py +33 -0
- mlflow/utils/autologging_utils/events.py +294 -0
- mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
- mlflow/utils/autologging_utils/metrics_queue.py +71 -0
- mlflow/utils/autologging_utils/safety.py +1104 -0
- mlflow/utils/autologging_utils/versioning.py +95 -0
- mlflow/utils/checkpoint_utils.py +206 -0
- mlflow/utils/class_utils.py +6 -0
- mlflow/utils/cli_args.py +257 -0
- mlflow/utils/conda.py +354 -0
- mlflow/utils/credentials.py +231 -0
- mlflow/utils/data_utils.py +17 -0
- mlflow/utils/databricks_utils.py +1436 -0
- mlflow/utils/docstring_utils.py +477 -0
- mlflow/utils/doctor.py +133 -0
- mlflow/utils/download_cloud_file_chunk.py +43 -0
- mlflow/utils/env_manager.py +16 -0
- mlflow/utils/env_pack.py +131 -0
- mlflow/utils/environment.py +1009 -0
- mlflow/utils/exception_utils.py +14 -0
- mlflow/utils/file_utils.py +978 -0
- mlflow/utils/git_utils.py +77 -0
- mlflow/utils/gorilla.py +797 -0
- mlflow/utils/import_hooks/__init__.py +363 -0
- mlflow/utils/lazy_load.py +51 -0
- mlflow/utils/logging_utils.py +168 -0
- mlflow/utils/mime_type_utils.py +58 -0
- mlflow/utils/mlflow_tags.py +103 -0
- mlflow/utils/model_utils.py +486 -0
- mlflow/utils/name_utils.py +346 -0
- mlflow/utils/nfs_on_spark.py +62 -0
- mlflow/utils/openai_utils.py +164 -0
- mlflow/utils/os.py +12 -0
- mlflow/utils/oss_registry_utils.py +29 -0
- mlflow/utils/plugins.py +17 -0
- mlflow/utils/process.py +182 -0
- mlflow/utils/promptlab_utils.py +146 -0
- mlflow/utils/proto_json_utils.py +743 -0
- mlflow/utils/pydantic_utils.py +54 -0
- mlflow/utils/request_utils.py +279 -0
- mlflow/utils/requirements_utils.py +704 -0
- mlflow/utils/rest_utils.py +673 -0
- mlflow/utils/search_logged_model_utils.py +127 -0
- mlflow/utils/search_utils.py +2111 -0
- mlflow/utils/secure_loading.py +221 -0
- mlflow/utils/security_validation.py +384 -0
- mlflow/utils/server_cli_utils.py +61 -0
- mlflow/utils/spark_utils.py +15 -0
- mlflow/utils/string_utils.py +138 -0
- mlflow/utils/thread_utils.py +63 -0
- mlflow/utils/time.py +54 -0
- mlflow/utils/timeout.py +42 -0
- mlflow/utils/uri.py +572 -0
- mlflow/utils/validation.py +662 -0
- mlflow/utils/virtualenv.py +458 -0
- mlflow/utils/warnings_utils.py +25 -0
- mlflow/utils/yaml_utils.py +179 -0
- mlflow/version.py +24 -0
@@ -0,0 +1,619 @@
|
|
1
|
+
import functools
|
2
|
+
import logging
|
3
|
+
import os
|
4
|
+
import subprocess
|
5
|
+
import tempfile
|
6
|
+
from pathlib import Path
|
7
|
+
|
8
|
+
import numpy as np
|
9
|
+
|
10
|
+
from mlflow.environment_variables import _MLFLOW_TESTING
|
11
|
+
from mlflow.metrics.base import MetricValue, standard_aggregations
|
12
|
+
|
13
|
+
_logger = logging.getLogger(__name__)
|
14
|
+
|
15
|
+
|
16
|
+
# used to silently fail with invalid metric params
def noop(*args, **kwargs):
    """Accept any arguments, do nothing, and return ``None``.

    Used as a stand-in eval function so that metrics constructed with
    invalid parameters fail silently instead of raising.
    """
    del args, kwargs  # intentionally ignored
    return None
|
19
|
+
|
20
|
+
|
21
|
+
# Human-readable column descriptions interpolated into the warning
# messages emitted by the validation helpers below.
targets_col_specifier = "the column specified by the `targets` parameter"
predictions_col_specifier = (
    "the column specified by the `predictions` parameter or the model output column"
)
|
25
|
+
|
26
|
+
|
27
|
+
def _validate_text_data(data, metric_name, col_specifier):
|
28
|
+
"""Validates that the data is a list of strs and is non-empty"""
|
29
|
+
if data is None or len(data) == 0:
|
30
|
+
_logger.warning(
|
31
|
+
f"Cannot calculate {metric_name} for empty inputs: "
|
32
|
+
f"{col_specifier} is empty or the parameter is not specified. Skipping metric logging."
|
33
|
+
)
|
34
|
+
return False
|
35
|
+
|
36
|
+
for row, line in enumerate(data):
|
37
|
+
if not isinstance(line, str):
|
38
|
+
_logger.warning(
|
39
|
+
f"Cannot calculate {metric_name} for non-string inputs. "
|
40
|
+
f"Non-string found for {col_specifier} on row {row}. Skipping metric logging."
|
41
|
+
)
|
42
|
+
return False
|
43
|
+
|
44
|
+
return True
|
45
|
+
|
46
|
+
|
47
|
+
def _validate_array_like_id_data(data, metric_name, col_specifier):
|
48
|
+
"""Validates that the data is a list of lists/np.ndarrays of strings/ints and is non-empty"""
|
49
|
+
if data is None or len(data) == 0:
|
50
|
+
return False
|
51
|
+
|
52
|
+
for index, value in data.items():
|
53
|
+
if not (
|
54
|
+
(isinstance(value, list) and all(isinstance(val, (str, int)) for val in value))
|
55
|
+
or (
|
56
|
+
isinstance(value, np.ndarray)
|
57
|
+
and (np.issubdtype(value.dtype, str) or np.issubdtype(value.dtype, int))
|
58
|
+
)
|
59
|
+
):
|
60
|
+
_logger.warning(
|
61
|
+
f"Cannot calculate metric '{metric_name}' for non-arraylike of string or int "
|
62
|
+
f"inputs. Non-arraylike of strings/ints found for {col_specifier} on row "
|
63
|
+
f"{index}, value {value}. Skipping metric logging."
|
64
|
+
)
|
65
|
+
return False
|
66
|
+
|
67
|
+
return True
|
68
|
+
|
69
|
+
|
70
|
+
def _token_count_eval_fn(predictions, targets=None, metrics=None):
    """Count tokens in each prediction using tiktoken's cl100k_base encoding.

    Non-string predictions receive a score of None. No aggregate results are
    produced for this metric.
    """
    import tiktoken

    # ref: https://github.com/openai/tiktoken/issues/75
    # Respect a user-provided TIKTOKEN_CACHE_DIR; otherwise disable the cache.
    os.environ.setdefault("TIKTOKEN_CACHE_DIR", "")
    encoding = tiktoken.get_encoding("cl100k_base")

    num_tokens = [
        len(encoding.encode(text)) if isinstance(text, str) else None for text in predictions
    ]

    return MetricValue(scores=num_tokens, aggregate_results={})
|
90
|
+
|
91
|
+
|
92
|
+
def _load_from_github(path: str, module_type: str = "metric"):
    """Load an ``evaluate`` module directly from the huggingface/evaluate repo.

    Performs a blobless, no-checkout clone followed by a sparse checkout of
    only the requested module into a temporary directory, then loads it from
    disk. Used as a fallback when the HuggingFace hub is unreachable.
    """
    import evaluate

    with tempfile.TemporaryDirectory() as tmpdir:
        repo_dir = Path(tmpdir)
        subprocess.check_call(
            [
                "git",
                "clone",
                "--filter=blob:none",
                "--no-checkout",
                "https://github.com/huggingface/evaluate.git",
                repo_dir,
            ]
        )
        module_path = f"{module_type}s/{path}"
        subprocess.check_call(["git", "sparse-checkout", "set", module_path], cwd=repo_dir)
        subprocess.check_call(["git", "checkout"], cwd=repo_dir)
        return evaluate.load(str(repo_dir / module_path))
|
111
|
+
|
112
|
+
|
113
|
+
@functools.lru_cache(maxsize=8)
def _cached_evaluate_load(path: str, module_type: str = "metric"):
    """Load and memoize a HuggingFace ``evaluate`` module.

    ``evaluate.load`` hits the network and is flaky (hub outages, transient
    errors); under MLflow testing we fall back to loading the module straight
    from the GitHub repository instead of failing.
    """
    import evaluate

    try:
        return evaluate.load(path, module_type=module_type)
    except (FileNotFoundError, OSError):
        if not _MLFLOW_TESTING.get():
            raise
        # In testing, avoid hub instability by loading from the GitHub repo.
        return _load_from_github(path, module_type=module_type)
|
126
|
+
|
127
|
+
|
128
|
+
def _toxicity_eval_fn(predictions, targets=None, metrics=None):
    """Score each prediction with the HuggingFace ``toxicity`` measurement.

    Produces per-row toxicity scores plus the standard aggregations and the
    ratio of rows classified as toxic. Returns None when validation fails or
    the measurement cannot be loaded.
    """
    if not _validate_text_data(predictions, "toxicity", predictions_col_specifier):
        return
    try:
        toxicity = _cached_evaluate_load("toxicity", module_type="measurement")
    except Exception as e:
        _logger.warning(
            f"Failed to load 'toxicity' metric (error: {e!r}), skipping metric logging."
        )
        return

    scores = toxicity.compute(predictions=predictions)["toxicity"]
    ratio = toxicity.compute(predictions=predictions, aggregation="ratio")["toxicity_ratio"]
    aggregates = dict(standard_aggregations(scores))
    aggregates["ratio"] = ratio
    return MetricValue(scores=scores, aggregate_results=aggregates)
|
150
|
+
|
151
|
+
|
152
|
+
def _flesch_kincaid_eval_fn(predictions, targets=None, metrics=None):
    """Compute the Flesch-Kincaid grade level for each prediction.

    Requires the optional ``textstat`` package; if it is missing the metric
    is skipped with a warning and None is returned.
    """
    if not _validate_text_data(predictions, "flesch_kincaid", predictions_col_specifier):
        return

    try:
        import textstat
    except ImportError:
        _logger.warning(
            "Failed to import textstat for flesch kincaid metric, skipping metric logging. "
            "Please install textstat using 'pip install textstat'."
        )
        return

    grades = [textstat.flesch_kincaid_grade(text) for text in predictions]
    return MetricValue(scores=grades, aggregate_results=standard_aggregations(grades))
|
170
|
+
|
171
|
+
|
172
|
+
def _ari_eval_fn(predictions, targets=None, metrics=None):
    """Compute the automated readability index (ARI) for each prediction.

    Requires the optional ``textstat`` package; if it is missing the metric
    is skipped with a warning and None is returned.
    """
    if not _validate_text_data(predictions, "ari", predictions_col_specifier):
        return

    try:
        import textstat
    except ImportError:
        _logger.warning(
            "Failed to import textstat for automated readability index metric, "
            "skipping metric logging. "
            "Please install textstat using 'pip install textstat'."
        )
        return

    readability = [textstat.automated_readability_index(text) for text in predictions]
    return MetricValue(scores=readability, aggregate_results=standard_aggregations(readability))
|
191
|
+
|
192
|
+
|
193
|
+
def _accuracy_eval_fn(predictions, targets=None, metrics=None, sample_weight=None):
    """Compute exact-match accuracy against ``targets``.

    Returns a MetricValue with the ``exact_match`` aggregate, or None when no
    targets are supplied.
    """
    if targets is None or len(targets) == 0:
        return
    from sklearn.metrics import accuracy_score

    acc = accuracy_score(y_true=targets, y_pred=predictions, sample_weight=sample_weight)
    return MetricValue(aggregate_results={"exact_match": acc})
|
199
|
+
|
200
|
+
|
201
|
+
def _rouge1_eval_fn(predictions, targets=None, metrics=None):
    """Compute per-row ROUGE-1 (unigram overlap) scores with aggregations.

    Returns None when either input fails validation or the ``rouge`` metric
    cannot be loaded.
    """
    if not _validate_text_data(targets, "rouge1", targets_col_specifier):
        return
    if not _validate_text_data(predictions, "rouge1", predictions_col_specifier):
        return

    try:
        rouge = _cached_evaluate_load("rouge")
    except Exception as e:
        _logger.warning(f"Failed to load 'rouge' metric (error: {e!r}), skipping metric logging.")
        return

    per_row = rouge.compute(
        predictions=predictions,
        references=targets,
        rouge_types=["rouge1"],
        use_aggregator=False,
    )["rouge1"]
    return MetricValue(scores=per_row, aggregate_results=standard_aggregations(per_row))
|
223
|
+
|
224
|
+
|
225
|
+
def _rouge2_eval_fn(predictions, targets=None, metrics=None):
    """Compute per-row ROUGE-2 (bigram overlap) scores with aggregations.

    Returns None when either input fails validation or the ``rouge`` metric
    cannot be loaded.
    """
    if not _validate_text_data(targets, "rouge2", targets_col_specifier):
        return
    if not _validate_text_data(predictions, "rouge2", predictions_col_specifier):
        return

    try:
        rouge = _cached_evaluate_load("rouge")
    except Exception as e:
        _logger.warning(f"Failed to load 'rouge' metric (error: {e!r}), skipping metric logging.")
        return

    per_row = rouge.compute(
        predictions=predictions,
        references=targets,
        rouge_types=["rouge2"],
        use_aggregator=False,
    )["rouge2"]
    return MetricValue(scores=per_row, aggregate_results=standard_aggregations(per_row))
|
247
|
+
|
248
|
+
|
249
|
+
def _rougeL_eval_fn(predictions, targets=None, metrics=None):
    """Compute per-row ROUGE-L (longest common subsequence) scores.

    Returns None when either input fails validation or the ``rouge`` metric
    cannot be loaded.
    """
    if not _validate_text_data(targets, "rougeL", targets_col_specifier):
        return
    if not _validate_text_data(predictions, "rougeL", predictions_col_specifier):
        return

    try:
        rouge = _cached_evaluate_load("rouge")
    except Exception as e:
        _logger.warning(f"Failed to load 'rouge' metric (error: {e!r}), skipping metric logging.")
        return

    per_row = rouge.compute(
        predictions=predictions,
        references=targets,
        rouge_types=["rougeL"],
        use_aggregator=False,
    )["rougeL"]
    return MetricValue(scores=per_row, aggregate_results=standard_aggregations(per_row))
|
271
|
+
|
272
|
+
|
273
|
+
def _rougeLsum_eval_fn(predictions, targets=None, metrics=None):
    """Compute per-row ROUGE-Lsum (summary-level LCS) scores.

    Returns None when either input fails validation or the ``rouge`` metric
    cannot be loaded.
    """
    if not _validate_text_data(targets, "rougeLsum", targets_col_specifier):
        return
    if not _validate_text_data(predictions, "rougeLsum", predictions_col_specifier):
        return

    try:
        rouge = _cached_evaluate_load("rouge")
    except Exception as e:
        _logger.warning(f"Failed to load 'rouge' metric (error: {e!r}), skipping metric logging.")
        return

    per_row = rouge.compute(
        predictions=predictions,
        references=targets,
        rouge_types=["rougeLsum"],
        use_aggregator=False,
    )["rougeLsum"]
    return MetricValue(scores=per_row, aggregate_results=standard_aggregations(per_row))
|
295
|
+
|
296
|
+
|
297
|
+
def _mae_eval_fn(predictions, targets=None, metrics=None, sample_weight=None):
    """Compute mean absolute error against ``targets``; None if no targets."""
    if targets is None or len(targets) == 0:
        return
    from sklearn.metrics import mean_absolute_error

    mae = mean_absolute_error(targets, predictions, sample_weight=sample_weight)
    return MetricValue(aggregate_results={"mean_absolute_error": mae})
|
303
|
+
|
304
|
+
|
305
|
+
def _mse_eval_fn(predictions, targets=None, metrics=None, sample_weight=None):
    """Compute mean squared error against ``targets``; None if no targets."""
    if targets is None or len(targets) == 0:
        return
    from sklearn.metrics import mean_squared_error

    mse = mean_squared_error(targets, predictions, sample_weight=sample_weight)
    return MetricValue(aggregate_results={"mean_squared_error": mse})
|
311
|
+
|
312
|
+
|
313
|
+
def _root_mean_squared_error(*, y_true, y_pred, sample_weight):
    """Compute RMSE, preferring scikit-learn's ``root_mean_squared_error``.

    Older scikit-learn releases do not ship ``root_mean_squared_error``; for
    those we fall back to ``mean_squared_error(..., squared=False)``, which is
    deprecated in scikit-learn >= 1.4.
    """
    try:
        from sklearn.metrics import root_mean_squared_error
    except ImportError:
        from sklearn.metrics import mean_squared_error

        return mean_squared_error(
            y_true=y_true, y_pred=y_pred, sample_weight=sample_weight, squared=False
        )
    return root_mean_squared_error(y_true=y_true, y_pred=y_pred, sample_weight=sample_weight)
|
326
|
+
|
327
|
+
|
328
|
+
def _rmse_eval_fn(predictions, targets=None, metrics=None, sample_weight=None):
    """Compute root mean squared error against ``targets``; None if no targets."""
    if targets is None or len(targets) == 0:
        return
    rmse = _root_mean_squared_error(
        y_true=targets, y_pred=predictions, sample_weight=sample_weight
    )
    return MetricValue(aggregate_results={"root_mean_squared_error": rmse})
|
334
|
+
|
335
|
+
|
336
|
+
def _r2_score_eval_fn(predictions, targets=None, metrics=None, sample_weight=None):
    """Compute the R^2 coefficient against ``targets``; None if no targets."""
    if targets is None or len(targets) == 0:
        return
    from sklearn.metrics import r2_score

    r2 = r2_score(targets, predictions, sample_weight=sample_weight)
    return MetricValue(aggregate_results={"r2_score": r2})
|
342
|
+
|
343
|
+
|
344
|
+
def _max_error_eval_fn(predictions, targets=None, metrics=None):
    """Compute the maximum residual error against ``targets``; None if no targets."""
    if targets is None or len(targets) == 0:
        return
    from sklearn.metrics import max_error

    error = max_error(targets, predictions)
    return MetricValue(aggregate_results={"max_error": error})
|
350
|
+
|
351
|
+
|
352
|
+
def _mape_eval_fn(predictions, targets=None, metrics=None, sample_weight=None):
    """Compute mean absolute percentage error against ``targets``; None if no targets."""
    if targets is None or len(targets) == 0:
        return
    from sklearn.metrics import mean_absolute_percentage_error

    mape = mean_absolute_percentage_error(targets, predictions, sample_weight=sample_weight)
    return MetricValue(aggregate_results={"mean_absolute_percentage_error": mape})
|
358
|
+
|
359
|
+
|
360
|
+
def _recall_eval_fn(
    predictions, targets=None, metrics=None, pos_label=1, average="binary", sample_weight=None
):
    """Compute recall against ``targets``; None if no targets.

    ``pos_label`` and ``average`` are forwarded to sklearn's ``recall_score``.
    """
    if targets is None or len(targets) == 0:
        return
    from sklearn.metrics import recall_score

    recall = recall_score(
        targets,
        predictions,
        pos_label=pos_label,
        average=average,
        sample_weight=sample_weight,
    )
    return MetricValue(aggregate_results={"recall_score": recall})
|
370
|
+
|
371
|
+
|
372
|
+
def _precision_eval_fn(
    predictions, targets=None, metrics=None, pos_label=1, average="binary", sample_weight=None
):
    """Compute precision against ``targets``; None if no targets.

    ``pos_label`` and ``average`` are forwarded to sklearn's ``precision_score``.
    """
    if targets is None or len(targets) == 0:
        return
    from sklearn.metrics import precision_score

    precision = precision_score(
        targets, predictions, pos_label=pos_label, average=average, sample_weight=sample_weight
    )
    return MetricValue(aggregate_results={"precision_score": precision})
|
386
|
+
|
387
|
+
|
388
|
+
def _f1_score_eval_fn(
    predictions, targets=None, metrics=None, pos_label=1, average="binary", sample_weight=None
):
    """Compute the F1 score against ``targets``; None if no targets.

    ``pos_label`` and ``average`` are forwarded to sklearn's ``f1_score``.
    """
    if targets is None or len(targets) == 0:
        return
    from sklearn.metrics import f1_score

    f1 = f1_score(
        targets, predictions, pos_label=pos_label, average=average, sample_weight=sample_weight
    )
    return MetricValue(aggregate_results={"f1_score": f1})
|
402
|
+
|
403
|
+
|
404
|
+
def _precision_at_k_eval_fn(k):
|
405
|
+
if not (isinstance(k, int) and k > 0):
|
406
|
+
_logger.warning(
|
407
|
+
f"Cannot calculate 'precision_at_k' for invalid parameter 'k'. "
|
408
|
+
f"'k' should be a positive integer; found: {k}. Skipping metric logging."
|
409
|
+
)
|
410
|
+
return noop
|
411
|
+
|
412
|
+
def _fn(predictions, targets):
|
413
|
+
if not _validate_array_like_id_data(
|
414
|
+
predictions, "precision_at_k", predictions_col_specifier
|
415
|
+
) or not _validate_array_like_id_data(targets, "precision_at_k", targets_col_specifier):
|
416
|
+
return
|
417
|
+
|
418
|
+
scores = []
|
419
|
+
for target, prediction in zip(targets, predictions):
|
420
|
+
# only include the top k retrieved chunks
|
421
|
+
ground_truth = set(target)
|
422
|
+
retrieved = prediction[:k]
|
423
|
+
relevant_doc_count = sum(1 for doc in retrieved if doc in ground_truth)
|
424
|
+
if len(retrieved) > 0:
|
425
|
+
scores.append(relevant_doc_count / len(retrieved))
|
426
|
+
else:
|
427
|
+
# when no documents are retrieved, precision is 0
|
428
|
+
scores.append(0)
|
429
|
+
|
430
|
+
return MetricValue(scores=scores, aggregate_results=standard_aggregations(scores))
|
431
|
+
|
432
|
+
return _fn
|
433
|
+
|
434
|
+
|
435
|
+
def _expand_duplicate_retrieved_docs(predictions, targets):
|
436
|
+
counter = {}
|
437
|
+
expanded_predictions = []
|
438
|
+
expanded_targets = targets
|
439
|
+
for doc_id in predictions:
|
440
|
+
if doc_id not in counter:
|
441
|
+
counter[doc_id] = 1
|
442
|
+
expanded_predictions.append(doc_id)
|
443
|
+
else:
|
444
|
+
counter[doc_id] += 1
|
445
|
+
new_doc_id = (
|
446
|
+
f"{doc_id}_bc574ae_{counter[doc_id]}" # adding a random string to avoid collisions
|
447
|
+
)
|
448
|
+
expanded_predictions.append(new_doc_id)
|
449
|
+
if doc_id in expanded_targets:
|
450
|
+
expanded_targets.add(new_doc_id)
|
451
|
+
return expanded_predictions, expanded_targets
|
452
|
+
|
453
|
+
|
454
|
+
def _prepare_row_for_ndcg(predictions, targets):
    """Prepare one row of predictions/targets as (y_score, y_true) for NDCG.

    Args:
        predictions: A list of strings of at most k doc IDs retrieved.
        targets: A list of strings of ground-truth doc IDs.

    Returns:
        y_score: ndarray of shape (1, n_docs) representing the retrieved docs,
            where n_docs is the number of unique docs in the union of
            predictions and targets.
        y_true: ndarray of shape (1, n_docs) representing the ground-truth
            relevant docs.
    """
    # sklearn sorts y_score internally; nudging each retrieved doc's relevance
    # down slightly by rank preserves the original retrieval order.
    eps = 1e-6

    # support predictions containing duplicate doc IDs
    targets = set(targets)
    predictions, targets = _expand_duplicate_retrieved_docs(predictions, targets)

    all_docs = targets.union(predictions)
    doc_id_to_index = {doc_id: i for i, doc_id in enumerate(all_docs)}
    # sklearn.metrics.ndcg_score requires at least 2 labels
    n_labels = max(len(doc_id_to_index), 2)
    y_true = np.zeros((1, n_labels), dtype=np.float32)
    y_score = np.zeros((1, n_labels), dtype=np.float32)
    for rank, doc_id in enumerate(predictions):
        # higher-ranked docs get slightly larger scores, all approximately 1
        y_score[0, doc_id_to_index[doc_id]] = 1 - rank * eps
    for doc_id in targets:
        y_true[0, doc_id_to_index[doc_id]] = 1
    return y_score, y_true
|
487
|
+
|
488
|
+
|
489
|
+
def _ndcg_at_k_eval_fn(k):
|
490
|
+
if not (isinstance(k, int) and k > 0):
|
491
|
+
_logger.warning(
|
492
|
+
f"Cannot calculate 'ndcg_at_k' for invalid parameter 'k'. "
|
493
|
+
f"'k' should be a positive integer; found: {k}. Skipping metric logging."
|
494
|
+
)
|
495
|
+
return noop
|
496
|
+
|
497
|
+
def _fn(predictions, targets):
|
498
|
+
from sklearn.metrics import ndcg_score
|
499
|
+
|
500
|
+
if not _validate_array_like_id_data(
|
501
|
+
predictions, "ndcg_at_k", predictions_col_specifier
|
502
|
+
) or not _validate_array_like_id_data(targets, "ndcg_at_k", targets_col_specifier):
|
503
|
+
return
|
504
|
+
|
505
|
+
scores = []
|
506
|
+
for ground_truth, retrieved in zip(targets, predictions):
|
507
|
+
# 1. If no ground truth doc IDs are provided and no documents are retrieved,
|
508
|
+
# the score is 1.
|
509
|
+
if len(retrieved) == 0 and len(ground_truth) == 0:
|
510
|
+
scores.append(1) # no error is made
|
511
|
+
continue
|
512
|
+
# 2. If no ground truth doc IDs are provided and documents are retrieved,
|
513
|
+
# the score is 0.
|
514
|
+
# 3. If ground truth doc IDs are provided and no documents are retrieved,
|
515
|
+
# the score is 0.
|
516
|
+
if len(retrieved) == 0 or len(ground_truth) == 0:
|
517
|
+
scores.append(0)
|
518
|
+
continue
|
519
|
+
|
520
|
+
# only include the top k retrieved chunks
|
521
|
+
y_score, y_true = _prepare_row_for_ndcg(retrieved[:k], ground_truth)
|
522
|
+
score = ndcg_score(y_true, y_score, k=len(retrieved[:k]), ignore_ties=True)
|
523
|
+
scores.append(score)
|
524
|
+
|
525
|
+
return MetricValue(scores=scores, aggregate_results=standard_aggregations(scores))
|
526
|
+
|
527
|
+
return _fn
|
528
|
+
|
529
|
+
|
530
|
+
def _recall_at_k_eval_fn(k):
|
531
|
+
if not (isinstance(k, int) and k > 0):
|
532
|
+
_logger.warning(
|
533
|
+
f"Cannot calculate 'recall_at_k' for invalid parameter 'k'. "
|
534
|
+
f"'k' should be a positive integer; found: {k}. Skipping metric logging."
|
535
|
+
)
|
536
|
+
return noop
|
537
|
+
|
538
|
+
def _fn(predictions, targets):
|
539
|
+
if not _validate_array_like_id_data(
|
540
|
+
predictions, "recall_at_k", predictions_col_specifier
|
541
|
+
) or not _validate_array_like_id_data(targets, "recall_at_k", targets_col_specifier):
|
542
|
+
return
|
543
|
+
|
544
|
+
scores = []
|
545
|
+
for target, prediction in zip(targets, predictions):
|
546
|
+
# only include the top k retrieved chunks
|
547
|
+
ground_truth = set(target)
|
548
|
+
retrieved = set(prediction[:k])
|
549
|
+
relevant_doc_count = len(ground_truth.intersection(retrieved))
|
550
|
+
if len(ground_truth) > 0:
|
551
|
+
scores.append(relevant_doc_count / len(ground_truth))
|
552
|
+
elif len(retrieved) == 0:
|
553
|
+
# there are 0 retrieved and ground truth docs, so reward for the match
|
554
|
+
scores.append(1)
|
555
|
+
else:
|
556
|
+
# there are > 0 retrieved, but 0 ground truth, so penalize
|
557
|
+
scores.append(0)
|
558
|
+
|
559
|
+
return MetricValue(scores=scores, aggregate_results=standard_aggregations(scores))
|
560
|
+
|
561
|
+
return _fn
|
562
|
+
|
563
|
+
|
564
|
+
def _bleu_eval_fn(predictions, targets=None, metrics=None):
    """Compute per-row BLEU scores between predictions and targets.

    Rows with an empty prediction or target receive a placeholder score of 0
    and are reported in a single warning. Returns None when validation fails,
    the ``bleu`` metric cannot be loaded, or no scores were produced.
    """
    # Validate input data
    if not _validate_text_data(targets, "bleu", targets_col_specifier):
        _logger.error(
            """Target validation failed.
            Ensure targets are valid for BLEU computation."""
        )
        return
    if not _validate_text_data(predictions, "bleu", predictions_col_specifier):
        _logger.error(
            """Prediction validation failed.
            Ensure predictions are valid for BLEU computation."""
        )
        return

    # Load BLEU metric
    try:
        bleu = _cached_evaluate_load("bleu")
    except Exception as e:
        _logger.warning(f"Failed to load 'bleu' metric (error: {e!r}), skipping metric logging.")
        return

    # Calculate BLEU scores for each prediction-target pair
    result = []
    invalid_indices = []

    for i, (prediction, target) in enumerate(zip(predictions, targets)):
        # Empty strings cannot be scored; record the row and keep list lengths
        # aligned with the input by appending a placeholder.
        if len(target) == 0 or len(prediction) == 0:
            invalid_indices.append(i)
            result.append(0)  # Append 0 as a placeholder for invalid entries
            continue

        try:
            # BLEU is computed row-by-row: one prediction vs. one reference.
            score = bleu.compute(predictions=[prediction], references=[[target]])
            result.append(score["bleu"])
        except Exception as e:
            _logger.warning(f"Failed to calculate BLEU for row {i} (error: {e!r}). Skipping.")
            result.append(0)  # Append 0 for consistency if an unexpected error occurs

    # Log warning for any invalid indices
    if invalid_indices:
        _logger.warning(
            f"BLEU score calculation skipped for the following indices "
            f"due to empty target or prediction: {invalid_indices}. "
            f"A score of 0 was appended for these entries."
        )

    # Return results
    if not result:
        _logger.warning("No BLEU scores were calculated due to input errors.")
        return

    return MetricValue(
        scores=result,
        aggregate_results=standard_aggregations(result),
    )
|
mlflow/mismatch.py
ADDED
@@ -0,0 +1,34 @@
|
|
1
|
+
import importlib.metadata
|
2
|
+
import warnings
|
3
|
+
from typing import Optional
|
4
|
+
|
5
|
+
|
6
|
+
def _get_version(package_name: str) -> Optional[str]:
|
7
|
+
try:
|
8
|
+
return importlib.metadata.version(package_name)
|
9
|
+
except importlib.metadata.PackageNotFoundError:
|
10
|
+
return None
|
11
|
+
|
12
|
+
|
13
|
+
def _check_version_mismatch() -> None:
    """
    Warns if both mlflow and mlflow-skinny are installed but their versions are different.

    Dev builds are ignored, and nothing happens when either package is absent.

    Reference: https://github.com/pypa/pip/issues/4625
    """
    mlflow_ver = _get_version("mlflow")
    if not mlflow_ver or "dev" in mlflow_ver:
        return
    skinny_ver = _get_version("mlflow-skinny")
    if not skinny_ver or "dev" in skinny_ver:
        return
    if mlflow_ver == skinny_ver:
        return
    warnings.warn(
        (
            f"Versions of mlflow ({mlflow_ver}) and mlflow-skinny ({skinny_ver}) "
            "are different. This may lead to unexpected behavior. "
            "Please install the same version of both packages."
        ),
        stacklevel=2,
        category=UserWarning,
    )
|
@@ -0,0 +1,34 @@
|
|
1
|
+
from mlflow.mistral.autolog import patched_class_call
|
2
|
+
from mlflow.utils.annotations import experimental
|
3
|
+
from mlflow.utils.autologging_utils import autologging_integration, safe_patch
|
4
|
+
|
5
|
+
FLAVOR_NAME = "mistral"


@experimental(version="2.21.0")
@autologging_integration(FLAVOR_NAME)
def autolog(
    log_traces: bool = True,
    disable: bool = False,
    silent: bool = False,
):
    """
    Enables (or disables) and configures autologging from Mistral AI to MLflow.
    Only synchronous calls to the Text generation API are supported.
    Asynchronous APIs and streaming are not recorded.

    Args:
        log_traces: If ``True``, traces are logged for Mistral AI models.
            If ``False``, no traces are collected during inference. Default to ``True``.
        disable: If ``True``, disables the Mistral AI autologging. Default to ``False``.
        silent: If ``True``, suppress all event logs and warnings from MLflow during Mistral AI
            autologging. If ``False``, show all events and warnings.
    """
    from mistralai.chat import Chat

    # Patch the synchronous chat-completion entry point so calls are traced.
    safe_patch(FLAVOR_NAME, Chat, "complete", patched_class_call)
|