genesis-flow 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genesis_flow-1.0.0.dist-info/METADATA +822 -0
- genesis_flow-1.0.0.dist-info/RECORD +645 -0
- genesis_flow-1.0.0.dist-info/WHEEL +5 -0
- genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
- genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
- genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
- mlflow/__init__.py +367 -0
- mlflow/__main__.py +3 -0
- mlflow/ag2/__init__.py +56 -0
- mlflow/ag2/ag2_logger.py +294 -0
- mlflow/anthropic/__init__.py +40 -0
- mlflow/anthropic/autolog.py +129 -0
- mlflow/anthropic/chat.py +144 -0
- mlflow/artifacts/__init__.py +268 -0
- mlflow/autogen/__init__.py +144 -0
- mlflow/autogen/chat.py +142 -0
- mlflow/azure/__init__.py +26 -0
- mlflow/azure/auth_handler.py +257 -0
- mlflow/azure/client.py +319 -0
- mlflow/azure/config.py +120 -0
- mlflow/azure/connection_factory.py +340 -0
- mlflow/azure/exceptions.py +27 -0
- mlflow/azure/stores.py +327 -0
- mlflow/azure/utils.py +183 -0
- mlflow/bedrock/__init__.py +45 -0
- mlflow/bedrock/_autolog.py +202 -0
- mlflow/bedrock/chat.py +122 -0
- mlflow/bedrock/stream.py +160 -0
- mlflow/bedrock/utils.py +43 -0
- mlflow/cli.py +707 -0
- mlflow/client.py +12 -0
- mlflow/config/__init__.py +56 -0
- mlflow/crewai/__init__.py +79 -0
- mlflow/crewai/autolog.py +253 -0
- mlflow/crewai/chat.py +29 -0
- mlflow/data/__init__.py +75 -0
- mlflow/data/artifact_dataset_sources.py +170 -0
- mlflow/data/code_dataset_source.py +40 -0
- mlflow/data/dataset.py +123 -0
- mlflow/data/dataset_registry.py +168 -0
- mlflow/data/dataset_source.py +110 -0
- mlflow/data/dataset_source_registry.py +219 -0
- mlflow/data/delta_dataset_source.py +167 -0
- mlflow/data/digest_utils.py +108 -0
- mlflow/data/evaluation_dataset.py +562 -0
- mlflow/data/filesystem_dataset_source.py +81 -0
- mlflow/data/http_dataset_source.py +145 -0
- mlflow/data/huggingface_dataset.py +258 -0
- mlflow/data/huggingface_dataset_source.py +118 -0
- mlflow/data/meta_dataset.py +104 -0
- mlflow/data/numpy_dataset.py +223 -0
- mlflow/data/pandas_dataset.py +231 -0
- mlflow/data/polars_dataset.py +352 -0
- mlflow/data/pyfunc_dataset_mixin.py +31 -0
- mlflow/data/schema.py +76 -0
- mlflow/data/sources.py +1 -0
- mlflow/data/spark_dataset.py +406 -0
- mlflow/data/spark_dataset_source.py +74 -0
- mlflow/data/spark_delta_utils.py +118 -0
- mlflow/data/tensorflow_dataset.py +350 -0
- mlflow/data/uc_volume_dataset_source.py +81 -0
- mlflow/db.py +27 -0
- mlflow/dspy/__init__.py +17 -0
- mlflow/dspy/autolog.py +197 -0
- mlflow/dspy/callback.py +398 -0
- mlflow/dspy/constant.py +1 -0
- mlflow/dspy/load.py +93 -0
- mlflow/dspy/save.py +393 -0
- mlflow/dspy/util.py +109 -0
- mlflow/dspy/wrapper.py +226 -0
- mlflow/entities/__init__.py +104 -0
- mlflow/entities/_mlflow_object.py +52 -0
- mlflow/entities/assessment.py +545 -0
- mlflow/entities/assessment_error.py +80 -0
- mlflow/entities/assessment_source.py +141 -0
- mlflow/entities/dataset.py +92 -0
- mlflow/entities/dataset_input.py +51 -0
- mlflow/entities/dataset_summary.py +62 -0
- mlflow/entities/document.py +48 -0
- mlflow/entities/experiment.py +109 -0
- mlflow/entities/experiment_tag.py +35 -0
- mlflow/entities/file_info.py +45 -0
- mlflow/entities/input_tag.py +35 -0
- mlflow/entities/lifecycle_stage.py +35 -0
- mlflow/entities/logged_model.py +228 -0
- mlflow/entities/logged_model_input.py +26 -0
- mlflow/entities/logged_model_output.py +32 -0
- mlflow/entities/logged_model_parameter.py +46 -0
- mlflow/entities/logged_model_status.py +74 -0
- mlflow/entities/logged_model_tag.py +33 -0
- mlflow/entities/metric.py +200 -0
- mlflow/entities/model_registry/__init__.py +29 -0
- mlflow/entities/model_registry/_model_registry_entity.py +13 -0
- mlflow/entities/model_registry/model_version.py +243 -0
- mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
- mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
- mlflow/entities/model_registry/model_version_search.py +25 -0
- mlflow/entities/model_registry/model_version_stages.py +25 -0
- mlflow/entities/model_registry/model_version_status.py +35 -0
- mlflow/entities/model_registry/model_version_tag.py +35 -0
- mlflow/entities/model_registry/prompt.py +73 -0
- mlflow/entities/model_registry/prompt_version.py +244 -0
- mlflow/entities/model_registry/registered_model.py +175 -0
- mlflow/entities/model_registry/registered_model_alias.py +35 -0
- mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
- mlflow/entities/model_registry/registered_model_search.py +25 -0
- mlflow/entities/model_registry/registered_model_tag.py +35 -0
- mlflow/entities/multipart_upload.py +74 -0
- mlflow/entities/param.py +49 -0
- mlflow/entities/run.py +97 -0
- mlflow/entities/run_data.py +84 -0
- mlflow/entities/run_info.py +188 -0
- mlflow/entities/run_inputs.py +59 -0
- mlflow/entities/run_outputs.py +43 -0
- mlflow/entities/run_status.py +41 -0
- mlflow/entities/run_tag.py +36 -0
- mlflow/entities/source_type.py +31 -0
- mlflow/entities/span.py +774 -0
- mlflow/entities/span_event.py +96 -0
- mlflow/entities/span_status.py +102 -0
- mlflow/entities/trace.py +317 -0
- mlflow/entities/trace_data.py +71 -0
- mlflow/entities/trace_info.py +220 -0
- mlflow/entities/trace_info_v2.py +162 -0
- mlflow/entities/trace_location.py +173 -0
- mlflow/entities/trace_state.py +39 -0
- mlflow/entities/trace_status.py +68 -0
- mlflow/entities/view_type.py +51 -0
- mlflow/environment_variables.py +866 -0
- mlflow/evaluation/__init__.py +16 -0
- mlflow/evaluation/assessment.py +369 -0
- mlflow/evaluation/evaluation.py +411 -0
- mlflow/evaluation/evaluation_tag.py +61 -0
- mlflow/evaluation/fluent.py +48 -0
- mlflow/evaluation/utils.py +201 -0
- mlflow/exceptions.py +213 -0
- mlflow/experiments.py +140 -0
- mlflow/gemini/__init__.py +81 -0
- mlflow/gemini/autolog.py +186 -0
- mlflow/gemini/chat.py +261 -0
- mlflow/genai/__init__.py +71 -0
- mlflow/genai/datasets/__init__.py +67 -0
- mlflow/genai/datasets/evaluation_dataset.py +131 -0
- mlflow/genai/evaluation/__init__.py +3 -0
- mlflow/genai/evaluation/base.py +411 -0
- mlflow/genai/evaluation/constant.py +23 -0
- mlflow/genai/evaluation/utils.py +244 -0
- mlflow/genai/judges/__init__.py +21 -0
- mlflow/genai/judges/databricks.py +404 -0
- mlflow/genai/label_schemas/__init__.py +153 -0
- mlflow/genai/label_schemas/label_schemas.py +209 -0
- mlflow/genai/labeling/__init__.py +159 -0
- mlflow/genai/labeling/labeling.py +250 -0
- mlflow/genai/optimize/__init__.py +13 -0
- mlflow/genai/optimize/base.py +198 -0
- mlflow/genai/optimize/optimizers/__init__.py +4 -0
- mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
- mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
- mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
- mlflow/genai/optimize/types.py +75 -0
- mlflow/genai/optimize/util.py +30 -0
- mlflow/genai/prompts/__init__.py +206 -0
- mlflow/genai/scheduled_scorers.py +431 -0
- mlflow/genai/scorers/__init__.py +26 -0
- mlflow/genai/scorers/base.py +492 -0
- mlflow/genai/scorers/builtin_scorers.py +765 -0
- mlflow/genai/scorers/scorer_utils.py +138 -0
- mlflow/genai/scorers/validation.py +165 -0
- mlflow/genai/utils/data_validation.py +146 -0
- mlflow/genai/utils/enum_utils.py +23 -0
- mlflow/genai/utils/trace_utils.py +211 -0
- mlflow/groq/__init__.py +42 -0
- mlflow/groq/_groq_autolog.py +74 -0
- mlflow/johnsnowlabs/__init__.py +888 -0
- mlflow/langchain/__init__.py +24 -0
- mlflow/langchain/api_request_parallel_processor.py +330 -0
- mlflow/langchain/autolog.py +147 -0
- mlflow/langchain/chat_agent_langgraph.py +340 -0
- mlflow/langchain/constant.py +1 -0
- mlflow/langchain/constants.py +1 -0
- mlflow/langchain/databricks_dependencies.py +444 -0
- mlflow/langchain/langchain_tracer.py +597 -0
- mlflow/langchain/model.py +919 -0
- mlflow/langchain/output_parsers.py +142 -0
- mlflow/langchain/retriever_chain.py +153 -0
- mlflow/langchain/runnables.py +527 -0
- mlflow/langchain/utils/chat.py +402 -0
- mlflow/langchain/utils/logging.py +671 -0
- mlflow/langchain/utils/serialization.py +36 -0
- mlflow/legacy_databricks_cli/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/provider.py +482 -0
- mlflow/litellm/__init__.py +175 -0
- mlflow/llama_index/__init__.py +22 -0
- mlflow/llama_index/autolog.py +55 -0
- mlflow/llama_index/chat.py +43 -0
- mlflow/llama_index/constant.py +1 -0
- mlflow/llama_index/model.py +577 -0
- mlflow/llama_index/pyfunc_wrapper.py +332 -0
- mlflow/llama_index/serialize_objects.py +188 -0
- mlflow/llama_index/tracer.py +561 -0
- mlflow/metrics/__init__.py +479 -0
- mlflow/metrics/base.py +39 -0
- mlflow/metrics/genai/__init__.py +25 -0
- mlflow/metrics/genai/base.py +101 -0
- mlflow/metrics/genai/genai_metric.py +771 -0
- mlflow/metrics/genai/metric_definitions.py +450 -0
- mlflow/metrics/genai/model_utils.py +371 -0
- mlflow/metrics/genai/prompt_template.py +68 -0
- mlflow/metrics/genai/prompts/__init__.py +0 -0
- mlflow/metrics/genai/prompts/v1.py +422 -0
- mlflow/metrics/genai/utils.py +6 -0
- mlflow/metrics/metric_definitions.py +619 -0
- mlflow/mismatch.py +34 -0
- mlflow/mistral/__init__.py +34 -0
- mlflow/mistral/autolog.py +71 -0
- mlflow/mistral/chat.py +135 -0
- mlflow/ml_package_versions.py +452 -0
- mlflow/models/__init__.py +97 -0
- mlflow/models/auth_policy.py +83 -0
- mlflow/models/cli.py +354 -0
- mlflow/models/container/__init__.py +294 -0
- mlflow/models/container/scoring_server/__init__.py +0 -0
- mlflow/models/container/scoring_server/nginx.conf +39 -0
- mlflow/models/dependencies_schemas.py +287 -0
- mlflow/models/display_utils.py +158 -0
- mlflow/models/docker_utils.py +211 -0
- mlflow/models/evaluation/__init__.py +23 -0
- mlflow/models/evaluation/_shap_patch.py +64 -0
- mlflow/models/evaluation/artifacts.py +194 -0
- mlflow/models/evaluation/base.py +1811 -0
- mlflow/models/evaluation/calibration_curve.py +109 -0
- mlflow/models/evaluation/default_evaluator.py +996 -0
- mlflow/models/evaluation/deprecated.py +23 -0
- mlflow/models/evaluation/evaluator_registry.py +80 -0
- mlflow/models/evaluation/evaluators/classifier.py +704 -0
- mlflow/models/evaluation/evaluators/default.py +233 -0
- mlflow/models/evaluation/evaluators/regressor.py +96 -0
- mlflow/models/evaluation/evaluators/shap.py +296 -0
- mlflow/models/evaluation/lift_curve.py +178 -0
- mlflow/models/evaluation/utils/metric.py +123 -0
- mlflow/models/evaluation/utils/trace.py +179 -0
- mlflow/models/evaluation/validation.py +434 -0
- mlflow/models/flavor_backend.py +93 -0
- mlflow/models/flavor_backend_registry.py +53 -0
- mlflow/models/model.py +1639 -0
- mlflow/models/model_config.py +150 -0
- mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
- mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
- mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
- mlflow/models/python_api.py +369 -0
- mlflow/models/rag_signatures.py +128 -0
- mlflow/models/resources.py +321 -0
- mlflow/models/signature.py +662 -0
- mlflow/models/utils.py +2054 -0
- mlflow/models/wheeled_model.py +280 -0
- mlflow/openai/__init__.py +57 -0
- mlflow/openai/_agent_tracer.py +364 -0
- mlflow/openai/api_request_parallel_processor.py +131 -0
- mlflow/openai/autolog.py +509 -0
- mlflow/openai/constant.py +1 -0
- mlflow/openai/model.py +824 -0
- mlflow/openai/utils/chat_schema.py +367 -0
- mlflow/optuna/__init__.py +3 -0
- mlflow/optuna/storage.py +646 -0
- mlflow/plugins/__init__.py +72 -0
- mlflow/plugins/base.py +358 -0
- mlflow/plugins/builtin/__init__.py +24 -0
- mlflow/plugins/builtin/pytorch_plugin.py +150 -0
- mlflow/plugins/builtin/sklearn_plugin.py +158 -0
- mlflow/plugins/builtin/transformers_plugin.py +187 -0
- mlflow/plugins/cli.py +321 -0
- mlflow/plugins/discovery.py +340 -0
- mlflow/plugins/manager.py +465 -0
- mlflow/plugins/registry.py +316 -0
- mlflow/plugins/templates/framework_plugin_template.py +329 -0
- mlflow/prompt/constants.py +20 -0
- mlflow/prompt/promptlab_model.py +197 -0
- mlflow/prompt/registry_utils.py +248 -0
- mlflow/promptflow/__init__.py +495 -0
- mlflow/protos/__init__.py +0 -0
- mlflow/protos/assessments_pb2.py +174 -0
- mlflow/protos/databricks_artifacts_pb2.py +489 -0
- mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
- mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
- mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
- mlflow/protos/databricks_pb2.py +267 -0
- mlflow/protos/databricks_trace_server_pb2.py +374 -0
- mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
- mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
- mlflow/protos/facet_feature_statistics_pb2.py +296 -0
- mlflow/protos/internal_pb2.py +77 -0
- mlflow/protos/mlflow_artifacts_pb2.py +336 -0
- mlflow/protos/model_registry_pb2.py +1073 -0
- mlflow/protos/scalapb/__init__.py +0 -0
- mlflow/protos/scalapb/scalapb_pb2.py +104 -0
- mlflow/protos/service_pb2.py +2600 -0
- mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
- mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
- mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
- mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
- mlflow/py.typed +0 -0
- mlflow/pydantic_ai/__init__.py +57 -0
- mlflow/pydantic_ai/autolog.py +173 -0
- mlflow/pyfunc/__init__.py +3844 -0
- mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
- mlflow/pyfunc/backend.py +523 -0
- mlflow/pyfunc/context.py +78 -0
- mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
- mlflow/pyfunc/loaders/__init__.py +7 -0
- mlflow/pyfunc/loaders/chat_agent.py +117 -0
- mlflow/pyfunc/loaders/chat_model.py +125 -0
- mlflow/pyfunc/loaders/code_model.py +31 -0
- mlflow/pyfunc/loaders/responses_agent.py +112 -0
- mlflow/pyfunc/mlserver.py +46 -0
- mlflow/pyfunc/model.py +1473 -0
- mlflow/pyfunc/scoring_server/__init__.py +604 -0
- mlflow/pyfunc/scoring_server/app.py +7 -0
- mlflow/pyfunc/scoring_server/client.py +146 -0
- mlflow/pyfunc/spark_model_cache.py +48 -0
- mlflow/pyfunc/stdin_server.py +44 -0
- mlflow/pyfunc/utils/__init__.py +3 -0
- mlflow/pyfunc/utils/data_validation.py +224 -0
- mlflow/pyfunc/utils/environment.py +22 -0
- mlflow/pyfunc/utils/input_converter.py +47 -0
- mlflow/pyfunc/utils/serving_data_parser.py +11 -0
- mlflow/pytorch/__init__.py +1171 -0
- mlflow/pytorch/_lightning_autolog.py +580 -0
- mlflow/pytorch/_pytorch_autolog.py +50 -0
- mlflow/pytorch/pickle_module.py +35 -0
- mlflow/rfunc/__init__.py +42 -0
- mlflow/rfunc/backend.py +134 -0
- mlflow/runs.py +89 -0
- mlflow/server/__init__.py +302 -0
- mlflow/server/auth/__init__.py +1224 -0
- mlflow/server/auth/__main__.py +4 -0
- mlflow/server/auth/basic_auth.ini +6 -0
- mlflow/server/auth/cli.py +11 -0
- mlflow/server/auth/client.py +537 -0
- mlflow/server/auth/config.py +34 -0
- mlflow/server/auth/db/__init__.py +0 -0
- mlflow/server/auth/db/cli.py +18 -0
- mlflow/server/auth/db/migrations/__init__.py +0 -0
- mlflow/server/auth/db/migrations/alembic.ini +110 -0
- mlflow/server/auth/db/migrations/env.py +76 -0
- mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
- mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
- mlflow/server/auth/db/models.py +67 -0
- mlflow/server/auth/db/utils.py +37 -0
- mlflow/server/auth/entities.py +165 -0
- mlflow/server/auth/logo.py +14 -0
- mlflow/server/auth/permissions.py +65 -0
- mlflow/server/auth/routes.py +18 -0
- mlflow/server/auth/sqlalchemy_store.py +263 -0
- mlflow/server/graphql/__init__.py +0 -0
- mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
- mlflow/server/graphql/graphql_custom_scalars.py +24 -0
- mlflow/server/graphql/graphql_errors.py +15 -0
- mlflow/server/graphql/graphql_no_batching.py +89 -0
- mlflow/server/graphql/graphql_schema_extensions.py +74 -0
- mlflow/server/handlers.py +3217 -0
- mlflow/server/prometheus_exporter.py +17 -0
- mlflow/server/validation.py +30 -0
- mlflow/shap/__init__.py +691 -0
- mlflow/sklearn/__init__.py +1994 -0
- mlflow/sklearn/utils.py +1041 -0
- mlflow/smolagents/__init__.py +66 -0
- mlflow/smolagents/autolog.py +139 -0
- mlflow/smolagents/chat.py +29 -0
- mlflow/store/__init__.py +10 -0
- mlflow/store/_unity_catalog/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/constants.py +2 -0
- mlflow/store/_unity_catalog/registry/__init__.py +6 -0
- mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
- mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
- mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
- mlflow/store/_unity_catalog/registry/utils.py +121 -0
- mlflow/store/artifact/__init__.py +0 -0
- mlflow/store/artifact/artifact_repo.py +472 -0
- mlflow/store/artifact/artifact_repository_registry.py +154 -0
- mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
- mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
- mlflow/store/artifact/cli.py +141 -0
- mlflow/store/artifact/cloud_artifact_repo.py +332 -0
- mlflow/store/artifact/databricks_artifact_repo.py +729 -0
- mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
- mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
- mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
- mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
- mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
- mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
- mlflow/store/artifact/ftp_artifact_repo.py +132 -0
- mlflow/store/artifact/gcs_artifact_repo.py +296 -0
- mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
- mlflow/store/artifact/http_artifact_repo.py +218 -0
- mlflow/store/artifact/local_artifact_repo.py +142 -0
- mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
- mlflow/store/artifact/models_artifact_repo.py +259 -0
- mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
- mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
- mlflow/store/artifact/r2_artifact_repo.py +70 -0
- mlflow/store/artifact/runs_artifact_repo.py +265 -0
- mlflow/store/artifact/s3_artifact_repo.py +330 -0
- mlflow/store/artifact/sftp_artifact_repo.py +141 -0
- mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
- mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
- mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
- mlflow/store/artifact/utils/__init__.py +0 -0
- mlflow/store/artifact/utils/models.py +148 -0
- mlflow/store/db/__init__.py +0 -0
- mlflow/store/db/base_sql_model.py +3 -0
- mlflow/store/db/db_types.py +10 -0
- mlflow/store/db/utils.py +314 -0
- mlflow/store/db_migrations/__init__.py +0 -0
- mlflow/store/db_migrations/alembic.ini +74 -0
- mlflow/store/db_migrations/env.py +84 -0
- mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
- mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
- mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
- mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
- mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
- mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
- mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
- mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
- mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
- mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
- mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
- mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
- mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
- mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
- mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
- mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
- mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
- mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
- mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
- mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
- mlflow/store/db_migrations/versions/__init__.py +0 -0
- mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
- mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
- mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
- mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
- mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
- mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
- mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
- mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
- mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
- mlflow/store/entities/__init__.py +3 -0
- mlflow/store/entities/paged_list.py +18 -0
- mlflow/store/model_registry/__init__.py +10 -0
- mlflow/store/model_registry/abstract_store.py +1081 -0
- mlflow/store/model_registry/base_rest_store.py +44 -0
- mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
- mlflow/store/model_registry/dbmodels/__init__.py +0 -0
- mlflow/store/model_registry/dbmodels/models.py +206 -0
- mlflow/store/model_registry/file_store.py +1091 -0
- mlflow/store/model_registry/rest_store.py +481 -0
- mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
- mlflow/store/tracking/__init__.py +23 -0
- mlflow/store/tracking/abstract_store.py +816 -0
- mlflow/store/tracking/dbmodels/__init__.py +0 -0
- mlflow/store/tracking/dbmodels/initial_models.py +243 -0
- mlflow/store/tracking/dbmodels/models.py +1073 -0
- mlflow/store/tracking/file_store.py +2438 -0
- mlflow/store/tracking/postgres_managed_identity.py +146 -0
- mlflow/store/tracking/rest_store.py +1131 -0
- mlflow/store/tracking/sqlalchemy_store.py +2785 -0
- mlflow/system_metrics/__init__.py +61 -0
- mlflow/system_metrics/metrics/__init__.py +0 -0
- mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
- mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
- mlflow/system_metrics/metrics/disk_monitor.py +21 -0
- mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
- mlflow/system_metrics/metrics/network_monitor.py +34 -0
- mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
- mlflow/system_metrics/system_metrics_monitor.py +198 -0
- mlflow/tracing/__init__.py +16 -0
- mlflow/tracing/assessment.py +356 -0
- mlflow/tracing/client.py +531 -0
- mlflow/tracing/config.py +125 -0
- mlflow/tracing/constant.py +105 -0
- mlflow/tracing/destination.py +81 -0
- mlflow/tracing/display/__init__.py +40 -0
- mlflow/tracing/display/display_handler.py +196 -0
- mlflow/tracing/export/async_export_queue.py +186 -0
- mlflow/tracing/export/inference_table.py +138 -0
- mlflow/tracing/export/mlflow_v3.py +137 -0
- mlflow/tracing/export/utils.py +70 -0
- mlflow/tracing/fluent.py +1417 -0
- mlflow/tracing/processor/base_mlflow.py +199 -0
- mlflow/tracing/processor/inference_table.py +175 -0
- mlflow/tracing/processor/mlflow_v3.py +47 -0
- mlflow/tracing/processor/otel.py +73 -0
- mlflow/tracing/provider.py +487 -0
- mlflow/tracing/trace_manager.py +200 -0
- mlflow/tracing/utils/__init__.py +616 -0
- mlflow/tracing/utils/artifact_utils.py +28 -0
- mlflow/tracing/utils/copy.py +55 -0
- mlflow/tracing/utils/environment.py +55 -0
- mlflow/tracing/utils/exception.py +21 -0
- mlflow/tracing/utils/once.py +35 -0
- mlflow/tracing/utils/otlp.py +63 -0
- mlflow/tracing/utils/processor.py +54 -0
- mlflow/tracing/utils/search.py +292 -0
- mlflow/tracing/utils/timeout.py +250 -0
- mlflow/tracing/utils/token.py +19 -0
- mlflow/tracing/utils/truncation.py +124 -0
- mlflow/tracing/utils/warning.py +76 -0
- mlflow/tracking/__init__.py +39 -0
- mlflow/tracking/_model_registry/__init__.py +1 -0
- mlflow/tracking/_model_registry/client.py +764 -0
- mlflow/tracking/_model_registry/fluent.py +853 -0
- mlflow/tracking/_model_registry/registry.py +67 -0
- mlflow/tracking/_model_registry/utils.py +251 -0
- mlflow/tracking/_tracking_service/__init__.py +0 -0
- mlflow/tracking/_tracking_service/client.py +883 -0
- mlflow/tracking/_tracking_service/registry.py +56 -0
- mlflow/tracking/_tracking_service/utils.py +275 -0
- mlflow/tracking/artifact_utils.py +179 -0
- mlflow/tracking/client.py +5900 -0
- mlflow/tracking/context/__init__.py +0 -0
- mlflow/tracking/context/abstract_context.py +35 -0
- mlflow/tracking/context/databricks_cluster_context.py +15 -0
- mlflow/tracking/context/databricks_command_context.py +15 -0
- mlflow/tracking/context/databricks_job_context.py +49 -0
- mlflow/tracking/context/databricks_notebook_context.py +41 -0
- mlflow/tracking/context/databricks_repo_context.py +43 -0
- mlflow/tracking/context/default_context.py +51 -0
- mlflow/tracking/context/git_context.py +32 -0
- mlflow/tracking/context/registry.py +98 -0
- mlflow/tracking/context/system_environment_context.py +15 -0
- mlflow/tracking/default_experiment/__init__.py +1 -0
- mlflow/tracking/default_experiment/abstract_context.py +43 -0
- mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
- mlflow/tracking/default_experiment/registry.py +75 -0
- mlflow/tracking/fluent.py +3595 -0
- mlflow/tracking/metric_value_conversion_utils.py +93 -0
- mlflow/tracking/multimedia.py +206 -0
- mlflow/tracking/registry.py +86 -0
- mlflow/tracking/request_auth/__init__.py +0 -0
- mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
- mlflow/tracking/request_auth/registry.py +60 -0
- mlflow/tracking/request_header/__init__.py +0 -0
- mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
- mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
- mlflow/tracking/request_header/default_request_header_provider.py +17 -0
- mlflow/tracking/request_header/registry.py +79 -0
- mlflow/transformers/__init__.py +2982 -0
- mlflow/transformers/flavor_config.py +258 -0
- mlflow/transformers/hub_utils.py +83 -0
- mlflow/transformers/llm_inference_utils.py +468 -0
- mlflow/transformers/model_io.py +301 -0
- mlflow/transformers/peft.py +51 -0
- mlflow/transformers/signature.py +183 -0
- mlflow/transformers/torch_utils.py +55 -0
- mlflow/types/__init__.py +21 -0
- mlflow/types/agent.py +270 -0
- mlflow/types/chat.py +240 -0
- mlflow/types/llm.py +935 -0
- mlflow/types/responses.py +139 -0
- mlflow/types/responses_helpers.py +416 -0
- mlflow/types/schema.py +1505 -0
- mlflow/types/type_hints.py +647 -0
- mlflow/types/utils.py +753 -0
- mlflow/utils/__init__.py +283 -0
- mlflow/utils/_capture_modules.py +256 -0
- mlflow/utils/_capture_transformers_modules.py +75 -0
- mlflow/utils/_spark_utils.py +201 -0
- mlflow/utils/_unity_catalog_oss_utils.py +97 -0
- mlflow/utils/_unity_catalog_utils.py +479 -0
- mlflow/utils/annotations.py +218 -0
- mlflow/utils/arguments_utils.py +16 -0
- mlflow/utils/async_logging/__init__.py +1 -0
- mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
- mlflow/utils/async_logging/async_logging_queue.py +366 -0
- mlflow/utils/async_logging/run_artifact.py +38 -0
- mlflow/utils/async_logging/run_batch.py +58 -0
- mlflow/utils/async_logging/run_operations.py +49 -0
- mlflow/utils/autologging_utils/__init__.py +737 -0
- mlflow/utils/autologging_utils/client.py +432 -0
- mlflow/utils/autologging_utils/config.py +33 -0
- mlflow/utils/autologging_utils/events.py +294 -0
- mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
- mlflow/utils/autologging_utils/metrics_queue.py +71 -0
- mlflow/utils/autologging_utils/safety.py +1104 -0
- mlflow/utils/autologging_utils/versioning.py +95 -0
- mlflow/utils/checkpoint_utils.py +206 -0
- mlflow/utils/class_utils.py +6 -0
- mlflow/utils/cli_args.py +257 -0
- mlflow/utils/conda.py +354 -0
- mlflow/utils/credentials.py +231 -0
- mlflow/utils/data_utils.py +17 -0
- mlflow/utils/databricks_utils.py +1436 -0
- mlflow/utils/docstring_utils.py +477 -0
- mlflow/utils/doctor.py +133 -0
- mlflow/utils/download_cloud_file_chunk.py +43 -0
- mlflow/utils/env_manager.py +16 -0
- mlflow/utils/env_pack.py +131 -0
- mlflow/utils/environment.py +1009 -0
- mlflow/utils/exception_utils.py +14 -0
- mlflow/utils/file_utils.py +978 -0
- mlflow/utils/git_utils.py +77 -0
- mlflow/utils/gorilla.py +797 -0
- mlflow/utils/import_hooks/__init__.py +363 -0
- mlflow/utils/lazy_load.py +51 -0
- mlflow/utils/logging_utils.py +168 -0
- mlflow/utils/mime_type_utils.py +58 -0
- mlflow/utils/mlflow_tags.py +103 -0
- mlflow/utils/model_utils.py +486 -0
- mlflow/utils/name_utils.py +346 -0
- mlflow/utils/nfs_on_spark.py +62 -0
- mlflow/utils/openai_utils.py +164 -0
- mlflow/utils/os.py +12 -0
- mlflow/utils/oss_registry_utils.py +29 -0
- mlflow/utils/plugins.py +17 -0
- mlflow/utils/process.py +182 -0
- mlflow/utils/promptlab_utils.py +146 -0
- mlflow/utils/proto_json_utils.py +743 -0
- mlflow/utils/pydantic_utils.py +54 -0
- mlflow/utils/request_utils.py +279 -0
- mlflow/utils/requirements_utils.py +704 -0
- mlflow/utils/rest_utils.py +673 -0
- mlflow/utils/search_logged_model_utils.py +127 -0
- mlflow/utils/search_utils.py +2111 -0
- mlflow/utils/secure_loading.py +221 -0
- mlflow/utils/security_validation.py +384 -0
- mlflow/utils/server_cli_utils.py +61 -0
- mlflow/utils/spark_utils.py +15 -0
- mlflow/utils/string_utils.py +138 -0
- mlflow/utils/thread_utils.py +63 -0
- mlflow/utils/time.py +54 -0
- mlflow/utils/timeout.py +42 -0
- mlflow/utils/uri.py +572 -0
- mlflow/utils/validation.py +662 -0
- mlflow/utils/virtualenv.py +458 -0
- mlflow/utils/warnings_utils.py +25 -0
- mlflow/utils/yaml_utils.py +179 -0
- mlflow/version.py +24 -0
@@ -0,0 +1,479 @@
|
|
1
|
+
import logging
|
2
|
+
from typing import Callable, Optional
|
3
|
+
|
4
|
+
from mlflow.entities.logged_model_parameter import LoggedModelParameter as ModelParam
|
5
|
+
from mlflow.entities.metric import Metric
|
6
|
+
from mlflow.entities.model_registry import (
|
7
|
+
ModelVersion,
|
8
|
+
ModelVersionDeploymentJobState,
|
9
|
+
ModelVersionTag,
|
10
|
+
RegisteredModel,
|
11
|
+
RegisteredModelAlias,
|
12
|
+
RegisteredModelDeploymentJobState,
|
13
|
+
RegisteredModelTag,
|
14
|
+
)
|
15
|
+
from mlflow.entities.model_registry.model_version_search import ModelVersionSearch
|
16
|
+
from mlflow.entities.model_registry.registered_model_search import RegisteredModelSearch
|
17
|
+
from mlflow.environment_variables import MLFLOW_USE_DATABRICKS_SDK_MODEL_ARTIFACTS_REPO_FOR_UC
|
18
|
+
from mlflow.exceptions import MlflowException
|
19
|
+
from mlflow.protos.databricks_uc_registry_messages_pb2 import (
|
20
|
+
EmitModelVersionLineageRequest,
|
21
|
+
EmitModelVersionLineageResponse,
|
22
|
+
IsDatabricksSdkModelsArtifactRepositoryEnabledRequest,
|
23
|
+
IsDatabricksSdkModelsArtifactRepositoryEnabledResponse,
|
24
|
+
ModelVersionLineageInfo,
|
25
|
+
SseEncryptionAlgorithm,
|
26
|
+
TemporaryCredentials,
|
27
|
+
)
|
28
|
+
from mlflow.protos.databricks_uc_registry_messages_pb2 import ModelVersion as ProtoModelVersion
|
29
|
+
from mlflow.protos.databricks_uc_registry_messages_pb2 import (
|
30
|
+
ModelVersionStatus as ProtoModelVersionStatus,
|
31
|
+
)
|
32
|
+
from mlflow.protos.databricks_uc_registry_messages_pb2 import (
|
33
|
+
ModelVersionTag as ProtoModelVersionTag,
|
34
|
+
)
|
35
|
+
from mlflow.protos.databricks_uc_registry_messages_pb2 import (
|
36
|
+
RegisteredModel as ProtoRegisteredModel,
|
37
|
+
)
|
38
|
+
from mlflow.protos.databricks_uc_registry_messages_pb2 import (
|
39
|
+
RegisteredModelTag as ProtoRegisteredModelTag,
|
40
|
+
)
|
41
|
+
from mlflow.protos.databricks_uc_registry_service_pb2 import UcModelRegistryService
|
42
|
+
from mlflow.protos.unity_catalog_oss_messages_pb2 import (
|
43
|
+
TemporaryCredentials as TemporaryCredentialsOSS,
|
44
|
+
)
|
45
|
+
from mlflow.store.artifact.artifact_repo import ArtifactRepository
|
46
|
+
from mlflow.utils.proto_json_utils import message_to_json
|
47
|
+
from mlflow.utils.rest_utils import (
|
48
|
+
_REST_API_PATH_PREFIX,
|
49
|
+
call_endpoint,
|
50
|
+
extract_api_info_for_service,
|
51
|
+
)
|
52
|
+
|
53
|
+
_logger = logging.getLogger(__name__)
# Maps each UcModelRegistryService request type to its (endpoint, method) pair, as unpacked by
# the REST helpers below.
_METHOD_TO_INFO = extract_api_info_for_service(UcModelRegistryService, _REST_API_PATH_PREFIX)
# Two-way lookup between ModelVersionStatus enum names and their proto enum values.
_STRING_TO_STATUS = {k: ProtoModelVersionStatus.Value(k) for k in ProtoModelVersionStatus.keys()}
_STATUS_TO_STRING = {value: key for key, value in _STRING_TO_STATUS.items()}
# Spark SQL queries used by get_full_name_from_sc to resolve the session's active catalog/schema.
_ACTIVE_CATALOG_QUERY = "SELECT current_catalog() AS catalog"
_ACTIVE_SCHEMA_QUERY = "SELECT current_database() AS schema"
|
59
|
+
|
60
|
+
|
61
|
+
def uc_model_version_status_to_string(status):
    """
    Translate a UC ``ModelVersionStatus`` enum value into its symbolic name.

    Args:
        status: Enum value taken from a UC model version proto.

    Returns:
        The enum name registered in ``ProtoModelVersionStatus``.

    Raises:
        KeyError: If ``status`` is not a known ModelVersionStatus value.
    """
    status_name = _STATUS_TO_STRING[status]
    return status_name
|
63
|
+
|
64
|
+
|
65
|
+
def model_version_from_uc_proto(uc_proto: ProtoModelVersion) -> ModelVersion:
    """
    Convert a Unity Catalog ``ModelVersion`` proto into an MLflow ``ModelVersion`` entity.

    Aliases are reduced to their alias names. Tags, model params and model metrics are converted
    into their MLflow entity counterparts. A falsy repeated field is treated as empty.

    Args:
        uc_proto: The UC model version proto to convert.

    Returns:
        The equivalent MLflow ``ModelVersion``.
    """
    status_name = uc_model_version_status_to_string(uc_proto.status)
    alias_names = [entry.alias for entry in (uc_proto.aliases or [])]
    version_tags = [
        ModelVersionTag(key=entry.key, value=entry.value) for entry in (uc_proto.tags or [])
    ]
    model_params = [
        ModelParam(key=entry.name, value=entry.value) for entry in (uc_proto.model_params or [])
    ]
    model_metrics = [
        Metric(
            key=entry.key,
            value=entry.value,
            timestamp=entry.timestamp,
            step=entry.step,
            dataset_name=entry.dataset_name,
            dataset_digest=entry.dataset_digest,
            model_id=entry.model_id,
            run_id=entry.run_id,
        )
        for entry in (uc_proto.model_metrics or [])
    ]
    job_state = ModelVersionDeploymentJobState.from_proto(uc_proto.deployment_job_state)

    return ModelVersion(
        name=uc_proto.name,
        version=uc_proto.version,
        creation_timestamp=uc_proto.creation_timestamp,
        last_updated_timestamp=uc_proto.last_updated_timestamp,
        description=uc_proto.description,
        user_id=uc_proto.user_id,
        source=uc_proto.source,
        run_id=uc_proto.run_id,
        status=status_name,
        status_message=uc_proto.status_message,
        aliases=alias_names,
        tags=version_tags,
        model_id=uc_proto.model_id,
        params=model_params,
        metrics=model_metrics,
        deployment_job_state=job_state,
    )
|
100
|
+
|
101
|
+
|
102
|
+
def model_version_search_from_uc_proto(uc_proto: ProtoModelVersion) -> ModelVersionSearch:
    """
    Convert a Unity Catalog ``ModelVersion`` proto into a ``ModelVersionSearch`` entity.

    Unlike ``model_version_from_uc_proto``, the result always has empty aliases and tags.

    Args:
        uc_proto: The UC model version proto to convert.

    Returns:
        The equivalent ``ModelVersionSearch``.
    """
    status_name = uc_model_version_status_to_string(uc_proto.status)
    job_state = ModelVersionDeploymentJobState.from_proto(uc_proto.deployment_job_state)
    return ModelVersionSearch(
        name=uc_proto.name,
        version=uc_proto.version,
        creation_timestamp=uc_proto.creation_timestamp,
        last_updated_timestamp=uc_proto.last_updated_timestamp,
        description=uc_proto.description,
        user_id=uc_proto.user_id,
        source=uc_proto.source,
        run_id=uc_proto.run_id,
        status=status_name,
        status_message=uc_proto.status_message,
        aliases=[],
        tags=[],
        deployment_job_state=job_state,
    )
|
120
|
+
|
121
|
+
|
122
|
+
def registered_model_from_uc_proto(uc_proto: ProtoRegisteredModel) -> RegisteredModel:
    """
    Convert a Unity Catalog ``RegisteredModel`` proto into an MLflow ``RegisteredModel`` entity.

    Aliases keep both the alias name and its target version. Tags become ``RegisteredModelTag``
    entities. The deployment job state is rendered through
    ``RegisteredModelDeploymentJobState.to_string``.

    Args:
        uc_proto: The UC registered model proto to convert.

    Returns:
        The equivalent MLflow ``RegisteredModel``.
    """
    model_aliases = [
        RegisteredModelAlias(alias=entry.alias, version=entry.version)
        for entry in (uc_proto.aliases or [])
    ]
    model_tags = [
        RegisteredModelTag(key=entry.key, value=entry.value) for entry in (uc_proto.tags or [])
    ]
    job_state = RegisteredModelDeploymentJobState.to_string(uc_proto.deployment_job_state)

    return RegisteredModel(
        name=uc_proto.name,
        creation_timestamp=uc_proto.creation_timestamp,
        last_updated_timestamp=uc_proto.last_updated_timestamp,
        description=uc_proto.description,
        aliases=model_aliases,
        tags=model_tags,
        deployment_job_id=uc_proto.deployment_job_id,
        deployment_job_state=job_state,
    )
|
138
|
+
|
139
|
+
|
140
|
+
def registered_model_search_from_uc_proto(uc_proto: ProtoRegisteredModel) -> RegisteredModelSearch:
    """
    Convert a Unity Catalog ``RegisteredModel`` proto into a ``RegisteredModelSearch`` entity.

    Only the name, timestamps and description are copied; aliases and tags are always empty.

    Args:
        uc_proto: The UC registered model proto to convert.

    Returns:
        The equivalent ``RegisteredModelSearch``.
    """
    copied_fields = {
        "name": uc_proto.name,
        "creation_timestamp": uc_proto.creation_timestamp,
        "last_updated_timestamp": uc_proto.last_updated_timestamp,
        "description": uc_proto.description,
    }
    return RegisteredModelSearch(aliases=[], tags=[], **copied_fields)
|
149
|
+
|
150
|
+
|
151
|
+
def uc_registered_model_tag_from_mlflow_tags(
    tags: Optional[list[RegisteredModelTag]],
) -> list[ProtoRegisteredModelTag]:
    """
    Convert MLflow ``RegisteredModelTag`` entities into UC ``RegisteredModelTag`` protos.

    Args:
        tags: Tags to convert, or ``None``.

    Returns:
        One proto per input tag, in order. An empty list when ``tags`` is ``None``.
    """
    if tags is not None:
        return [ProtoRegisteredModelTag(key=tag.key, value=tag.value) for tag in tags]
    return []
|
157
|
+
|
158
|
+
|
159
|
+
def uc_model_version_tag_from_mlflow_tags(
    tags: Optional[list[ModelVersionTag]],
) -> list[ProtoModelVersionTag]:
    """
    Convert MLflow ``ModelVersionTag`` entities into UC ``ModelVersionTag`` protos.

    Args:
        tags: Tags to convert, or ``None``.

    Returns:
        One proto per input tag, in order. An empty list when ``tags`` is ``None``.
    """
    if tags is not None:
        return [ProtoModelVersionTag(key=tag.key, value=tag.value) for tag in tags]
    return []
|
165
|
+
|
166
|
+
|
167
|
+
def get_artifact_repo_from_storage_info(
    storage_location: str,
    scoped_token: TemporaryCredentials,
    base_credential_refresh_def: Callable[[], TemporaryCredentials],
    is_oss: bool = False,
) -> ArtifactRepository:
    """
    Get an ArtifactRepository instance capable of reading/writing to a UC model version's
    file storage location.

    Args:
        storage_location: Storage location of the model version.
        scoped_token: Protobuf scoped token to use to authenticate to blob storage.
        base_credential_refresh_def: Function that returns temporary credentials for accessing
            blob storage. It is first used to determine the type of blob storage and to access
            it. It is then passed to the relevant ArtifactRepository implementation to refresh
            credentials as needed.
        is_oss: Whether the user is using the OSS version of Unity Catalog.

    Returns:
        The ArtifactRepository built by the OSS or Databricks-specific helper.

    Raises:
        MlflowException: If a cloud SDK needed for the detected storage type cannot be imported.
            Errors raised by the helpers themselves propagate unchanged.
    """
    # Pick the helper matching the Unity Catalog flavour; both share the same keyword interface.
    build_repo = (
        _get_artifact_repo_from_storage_info_oss if is_oss else _get_artifact_repo_from_storage_info
    )
    try:
        return build_repo(
            storage_location=storage_location,
            scoped_token=scoped_token,
            base_credential_refresh_def=base_credential_refresh_def,
        )
    except ImportError as e:
        # The helpers import cloud SDKs lazily; turn a missing optional dependency into guidance.
        raise MlflowException(
            "Unable to import necessary dependencies to access model version files in "
            "Unity Catalog. Please ensure you have the necessary dependencies installed, "
            "e.g. by running 'pip install mlflow[databricks]' or "
            "'pip install mlflow-skinny[databricks]'"
        ) from e
|
206
|
+
|
207
|
+
|
208
|
+
def _get_artifact_repo_from_storage_info(
    storage_location: str,
    scoped_token: TemporaryCredentials,
    base_credential_refresh_def: Callable[[], TemporaryCredentials],
) -> ArtifactRepository:
    """
    Build an ArtifactRepository for a (non-OSS) Unity Catalog storage location.

    The storage backend is chosen by which member of the ``credentials`` oneof is set on
    ``scoped_token``: AWS temporary credentials, an Azure user delegation SAS, a GCP OAuth token,
    or R2 temporary credentials. Cloud SDKs are imported lazily inside each branch, so a missing
    optional dependency surfaces as an ImportError to the caller.

    Args:
        storage_location: Storage location of the model version.
        scoped_token: Temporary credentials used for the initial access.
        base_credential_refresh_def: Function returning fresh temporary credentials. Each branch
            wraps it in a backend-specific refresh callback passed to the repository.

    Returns:
        An S3, Azure Data Lake, GCS, or R2 ArtifactRepository.

    Raises:
        MlflowException: If the credential type is unset or not one of the handled values.
    """
    credential_type = scoped_token.WhichOneof("credentials")
    if credential_type == "aws_temp_credentials":
        # Verify upfront that boto3 is importable
        import boto3  # noqa: F401

        from mlflow.store.artifact.optimized_s3_artifact_repo import OptimizedS3ArtifactRepository

        aws_creds = scoped_token.aws_temp_credentials
        # Server-side-encryption settings (possibly empty) forwarded as extra S3 upload args.
        s3_upload_extra_args = _parse_aws_sse_credential(scoped_token)

        def aws_credential_refresh():
            # Fetch a new token and re-derive the SSE upload args from it as well.
            new_scoped_token = base_credential_refresh_def()
            new_aws_creds = new_scoped_token.aws_temp_credentials
            new_s3_upload_extra_args = _parse_aws_sse_credential(new_scoped_token)
            return {
                "access_key_id": new_aws_creds.access_key_id,
                "secret_access_key": new_aws_creds.secret_access_key,
                "session_token": new_aws_creds.session_token,
                "s3_upload_extra_args": new_s3_upload_extra_args,
            }

        return OptimizedS3ArtifactRepository(
            artifact_uri=storage_location,
            access_key_id=aws_creds.access_key_id,
            secret_access_key=aws_creds.secret_access_key,
            session_token=aws_creds.session_token,
            credential_refresh_def=aws_credential_refresh,
            s3_upload_extra_args=s3_upload_extra_args,
        )
    elif credential_type == "azure_user_delegation_sas":
        from azure.core.credentials import AzureSasCredential

        from mlflow.store.artifact.azure_data_lake_artifact_repo import (
            AzureDataLakeArtifactRepository,
        )

        sas_token = scoped_token.azure_user_delegation_sas.sas_token

        def azure_credential_refresh():
            # Wrap the newly fetched SAS token in a fresh AzureSasCredential.
            new_scoped_token = base_credential_refresh_def()
            new_sas_token = new_scoped_token.azure_user_delegation_sas.sas_token
            return {
                "credential": AzureSasCredential(new_sas_token),
            }

        return AzureDataLakeArtifactRepository(
            artifact_uri=storage_location,
            credential=AzureSasCredential(sas_token),
            credential_refresh_def=azure_credential_refresh,
        )

    elif credential_type == "gcp_oauth_token":
        from google.cloud.storage import Client
        from google.oauth2.credentials import Credentials

        from mlflow.store.artifact.gcs_artifact_repo import GCSArtifactRepository

        credentials = Credentials(scoped_token.gcp_oauth_token.oauth_token)

        def gcp_credential_refresh():
            new_scoped_token = base_credential_refresh_def()
            new_gcp_creds = new_scoped_token.gcp_oauth_token
            return {
                "oauth_token": new_gcp_creds.oauth_token,
            }

        # NOTE(review): "mlflow" looks like a placeholder project id; access is presumably
        # governed by the OAuth token rather than the project — confirm.
        client = Client(project="mlflow", credentials=credentials)
        return GCSArtifactRepository(
            artifact_uri=storage_location,
            client=client,
            credential_refresh_def=gcp_credential_refresh,
        )
    elif credential_type == "r2_temp_credentials":
        from mlflow.store.artifact.r2_artifact_repo import R2ArtifactRepository

        r2_creds = scoped_token.r2_temp_credentials

        def r2_credential_refresh():
            new_scoped_token = base_credential_refresh_def()
            new_r2_creds = new_scoped_token.r2_temp_credentials
            return {
                "access_key_id": new_r2_creds.access_key_id,
                "secret_access_key": new_r2_creds.secret_access_key,
                "session_token": new_r2_creds.session_token,
            }

        return R2ArtifactRepository(
            artifact_uri=storage_location,
            access_key_id=r2_creds.access_key_id,
            secret_access_key=r2_creds.secret_access_key,
            session_token=r2_creds.session_token,
            credential_refresh_def=r2_credential_refresh,
        )
    else:
        # Unset oneof (None) or a credential type this client version does not know about.
        raise MlflowException(
            f"Got unexpected credential type {credential_type} when attempting to "
            "access model version files in Unity Catalog. Try upgrading to the latest "
            "version of the MLflow Python client."
        )
|
312
|
+
|
313
|
+
|
314
|
+
def _get_artifact_repo_from_storage_info_oss(
    storage_location: str,
    scoped_token: TemporaryCredentialsOSS,
    base_credential_refresh_def: Callable[[], TemporaryCredentialsOSS],
) -> ArtifactRepository:
    """
    Build an ArtifactRepository for a storage location managed by OSS Unity Catalog.

    Args:
        storage_location: Storage location of the model version.
        scoped_token: OSS temporary credentials used for the initial access.
        base_credential_refresh_def: Function returning fresh OSS temporary credentials. Each
            branch wraps it in a backend-specific refresh callback passed to the repository, so
            expired credentials can be renewed for every backend (AWS, Azure, and GCP).

    Returns:
        An S3, Azure Data Lake, or GCS ArtifactRepository.

    Raises:
        MlflowException: If none of the supported credential fields is populated.
    """
    # OSS Temp Credential doesn't have a oneof credential field
    # So, we must check for the individual cloud credentials
    if len(scoped_token.aws_temp_credentials.access_key_id) > 0:
        # Verify upfront that boto3 is importable
        import boto3  # noqa: F401

        from mlflow.store.artifact.optimized_s3_artifact_repo import OptimizedS3ArtifactRepository

        aws_creds = scoped_token.aws_temp_credentials

        def aws_credential_refresh():
            new_scoped_token = base_credential_refresh_def()
            new_aws_creds = new_scoped_token.aws_temp_credentials
            return {
                "access_key_id": new_aws_creds.access_key_id,
                "secret_access_key": new_aws_creds.secret_access_key,
                "session_token": new_aws_creds.session_token,
            }

        return OptimizedS3ArtifactRepository(
            artifact_uri=storage_location,
            access_key_id=aws_creds.access_key_id,
            secret_access_key=aws_creds.secret_access_key,
            session_token=aws_creds.session_token,
            credential_refresh_def=aws_credential_refresh,
        )
    elif len(scoped_token.azure_user_delegation_sas.sas_token) > 0:
        from azure.core.credentials import AzureSasCredential

        from mlflow.store.artifact.azure_data_lake_artifact_repo import (
            AzureDataLakeArtifactRepository,
        )

        sas_token = scoped_token.azure_user_delegation_sas.sas_token

        def azure_credential_refresh():
            new_scoped_token = base_credential_refresh_def()
            new_sas_token = new_scoped_token.azure_user_delegation_sas.sas_token
            return {
                "credential": AzureSasCredential(new_sas_token),
            }

        return AzureDataLakeArtifactRepository(
            artifact_uri=storage_location,
            credential=AzureSasCredential(sas_token),
            credential_refresh_def=azure_credential_refresh,
        )

    elif len(scoped_token.gcp_oauth_token.oauth_token) > 0:
        from google.cloud.storage import Client
        from google.oauth2.credentials import Credentials

        from mlflow.store.artifact.gcs_artifact_repo import GCSArtifactRepository

        credentials = Credentials(scoped_token.gcp_oauth_token.oauth_token)

        # Previously no refresh callback was passed here, so an expired OAuth token could never
        # be renewed for GCS. This mirrors the refresh used by the non-OSS GCS branch.
        def gcp_credential_refresh():
            new_scoped_token = base_credential_refresh_def()
            new_gcp_creds = new_scoped_token.gcp_oauth_token
            return {
                "oauth_token": new_gcp_creds.oauth_token,
            }

        client = Client(project="mlflow", credentials=credentials)
        return GCSArtifactRepository(
            artifact_uri=storage_location,
            client=client,
            credential_refresh_def=gcp_credential_refresh,
        )
    else:
        raise MlflowException(
            "Got no credential type when attempting to "
            "access model version files in Unity Catalog. Try upgrading to the latest "
            "version of the MLflow Python client."
        )
|
382
|
+
|
383
|
+
|
384
|
+
def _parse_aws_sse_credential(scoped_token: TemporaryCredentials):
    """
    Translate the server-side-encryption settings on a scoped token into extra S3 upload args.

    Args:
        scoped_token: Temporary credentials whose ``encryption_details`` are inspected.

    Returns:
        ``{"ServerSideEncryption": "AES256"}`` for SSE-S3.
        ``{"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": <last "/" segment of the KMS key ARN>}``
        for SSE-KMS. An empty dict when there are no SSE details or the algorithm is not handled.
    """
    details = scoped_token.encryption_details
    if not details or details.WhichOneof("encryption_details_type") != "sse_encryption_details":
        return {}

    sse = details.sse_encryption_details
    algorithm = sse.algorithm
    if algorithm == SseEncryptionAlgorithm.AWS_SSE_S3:
        return {"ServerSideEncryption": "AES256"}
    if algorithm == SseEncryptionAlgorithm.AWS_SSE_KMS:
        # Only the trailing key id of the ARN is passed on.
        kms_key_id = sse.aws_kms_key_arn.rpartition("/")[2]
        return {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": kms_key_id}
    return {}
|
406
|
+
|
407
|
+
|
408
|
+
def get_full_name_from_sc(name, spark) -> str:
    """
    Qualify a registered model name with the active catalog and schema of a Spark session.

    Args:
        name: The model name provided by the user (one, two, or three dot-separated levels).
        spark: The active spark session, or ``None``.

    Returns:
        ``name`` unchanged in any of these cases: it already has three or more levels, no session
        is given, or the active catalog is a Hive metastore default. Otherwise ``name`` is
        prefixed with the active catalog (two-level names) or the catalog and schema
        (one-level names).
    """
    depth = len(name.split("."))
    if spark is None or depth >= 3:
        return name

    active_catalog = spark.sql(_ACTIVE_CATALOG_QUERY).collect()[0]["catalog"]
    # Hive metastore defaults mean there is no UC catalog to qualify with.
    if active_catalog in {"spark_catalog", "hive_metastore"}:
        return name
    if depth == 2:
        return f"{active_catalog}.{name}"

    active_schema = spark.sql(_ACTIVE_SCHEMA_QUERY).collect()[0]["schema"]
    return f"{active_catalog}.{active_schema}.{name}"
|
428
|
+
|
429
|
+
|
430
|
+
def is_databricks_sdk_models_artifact_repository_enabled(host_creds):
    """
    Decide whether the Databricks SDK models artifact repository should be used for UC models.

    An explicitly defined ``MLFLOW_USE_DATABRICKS_SDK_MODEL_ARTIFACTS_REPO_FOR_UC`` environment
    variable wins. Otherwise the UC model registry service is queried. If that call raises, a
    warning is logged and ``False`` is returned.

    Args:
        host_creds: Host credentials forwarded to ``call_endpoint``.

    Returns:
        The environment variable's value when defined, else the service's answer, else ``False``.
    """
    env_override = MLFLOW_USE_DATABRICKS_SDK_MODEL_ARTIFACTS_REPO_FOR_UC
    if env_override.defined:
        return env_override.get()

    request_type = IsDatabricksSdkModelsArtifactRepositoryEnabledRequest
    endpoint, method = _METHOD_TO_INFO[request_type]
    payload = message_to_json(request_type())
    response_proto = IsDatabricksSdkModelsArtifactRepositoryEnabledResponse()

    try:
        response = call_endpoint(
            host_creds=host_creds,
            endpoint=endpoint,
            method=method,
            json_body=payload,
            response_proto=response_proto,
        )
        return response.is_databricks_sdk_models_artifact_repository_enabled
    except Exception as e:
        # Best-effort probe: any failure falls back to the default repository.
        _logger.warning(
            "Failed to confirm if DatabricksSDKModelsArtifactRepository should be used; "
            f"falling back to default. Error: {e}"
        )
        return False
|
454
|
+
|
455
|
+
|
456
|
+
def emit_model_version_lineage(host_creds, name, version, entities, direction):
    """
    Send lineage information for a UC model version to the model registry service.

    The endpoint call is best-effort: any exception it raises is logged as a warning and not
    re-raised. Nothing is returned.

    Args:
        host_creds: Host credentials forwarded to ``call_endpoint``.
        name: Registered model name.
        version: Model version.
        entities: Lineage entities placed in ``ModelVersionLineageInfo``.
        direction: Lineage direction placed in ``ModelVersionLineageInfo``.
    """
    endpoint, method = _METHOD_TO_INFO[EmitModelVersionLineageRequest]

    lineage_info = ModelVersionLineageInfo(entities=entities, direction=direction)
    request = EmitModelVersionLineageRequest(
        name=name,
        version=version,
        model_version_lineage_info=lineage_info,
    )
    payload = message_to_json(request)
    response_proto = EmitModelVersionLineageResponse()

    try:
        call_endpoint(
            host_creds=host_creds,
            endpoint=endpoint,
            method=method,
            json_body=payload,
            response_proto=response_proto,
        )
    except Exception as e:
        _logger.warning(f"Failed to emit best-effort model version lineage. Error: {e}")
|
@@ -0,0 +1,218 @@
|
|
1
|
+
import inspect
|
2
|
+
import re
|
3
|
+
import types
|
4
|
+
import warnings
|
5
|
+
from functools import wraps
|
6
|
+
from typing import Callable, Optional, TypeVar
|
7
|
+
|
8
|
+
from typing_extensions import ParamSpec
|
9
|
+
|
10
|
+
|
11
|
+
def _get_min_indent_of_docstring(docstring_str: str) -> str:
|
12
|
+
"""
|
13
|
+
Get the minimum indentation string of a docstring, based on the assumption
|
14
|
+
that the closing triple quote for multiline comments must be on a new line.
|
15
|
+
Note that based on ruff rule D209, the closing triple quote for multiline
|
16
|
+
comments must be on a new line.
|
17
|
+
|
18
|
+
Args:
|
19
|
+
docstring_str: string with docstring
|
20
|
+
|
21
|
+
Returns:
|
22
|
+
Whitespace corresponding to the indent of a docstring.
|
23
|
+
"""
|
24
|
+
|
25
|
+
if not docstring_str or "\n" not in docstring_str:
|
26
|
+
return ""
|
27
|
+
|
28
|
+
return re.match(r"^\s*", docstring_str.rsplit("\n", 1)[-1]).group()
|
29
|
+
|
30
|
+
|
31
|
+
# Type variables for the `experimental` decorators: P captures the decorated callable's
# parameters and R its return type, so the decorated signature is preserved for type checkers.
P = ParamSpec("P")
R = TypeVar("R")
|
33
|
+
|
34
|
+
|
35
|
+
def experimental(
    f: Optional[Callable[P, R]] = None,
    version: Optional[str] = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Decorator / decorator creator for marking APIs experimental in the docstring.

    Works both bare (``@experimental``) and with arguments (``@experimental(version="x.y.z")``).

    Args:
        f: The function to be decorated. When omitted, a decorator is returned instead.
        version: The version in which the API was introduced as experimental.
            The version is used to determine whether the API should be considered
            as stable or not when releasing a new version of MLflow. It is not read at
            runtime by this function.

    Returns:
        The decorated API (with an experimental note added to its docstring) when ``f`` is
        given. Otherwise, a decorator that applies that note to the API it decorates.
    """
    if not f:

        def decorator(f: Callable[P, R]) -> Callable[P, R]:
            return _experimental(f)

        return decorator

    return _experimental(f)
|
58
|
+
|
59
|
+
|
60
|
+
def _experimental(api: Callable[P, R]) -> Callable[P, R]:
    """Add an "Experimental" note to ``api.__doc__`` in place and return ``api``.

    The note names the API kind: "class", "function", "property" (also used for bound methods),
    or ``str(type(api))`` for anything else. It is indented like the existing docstring. For
    properties the note is appended after the docstring; for everything else it is prepended.
    """
    if inspect.isclass(api):
        kind = "class"
    elif inspect.isfunction(api):
        kind = "function"
    elif isinstance(api, (property, types.MethodType)):
        kind = "property"
    else:
        kind = str(type(api))

    existing_doc = api.__doc__
    prefix = _get_min_indent_of_docstring(existing_doc) if existing_doc else ""
    note = prefix + (
        f".. Note:: Experimental: This {kind} may change or "
        "be removed in a future release without warning.\n\n"
    )

    if not existing_doc:
        api.__doc__ = note
    elif kind == "property":
        api.__doc__ = existing_doc + "\n\n" + note
    else:
        api.__doc__ = note + existing_doc
    return api
|
80
|
+
|
81
|
+
|
82
|
+
def developer_stable(func):
    """
    Tag an API as ``@developer_stable``. The decorated object is returned unchanged.

    APIs carrying this marker come with protections on future development. Marking a class
    implicitly extends the status to every method it contains.

    Except as noted below, a ``@developer_stable`` API is guaranteed to:
      - remain backwards compatible, so earlier versions of any MLflow client, cli, or server
        keep working with it from an interface perspective.
      - keep a consistent contract for existing named arguments: no named argument is altered
        or removed.
      - keep the implied or declared types of the arguments in its signature.
      - keep consistent behavior with respect to return types.

    Note: If such an API needs to change to enable enhanced functionality, a deprecation
    warning will be added to it well in advance of the modification.

    Note: If such an API needs to be patched for a security reason, advance notice is not
    guaranteed, and its stable labeling will be ignored for the sake of that security patch.
    """
    return func
|
107
|
+
|
108
|
+
|
109
|
+
# Name of the attribute that mark_deprecated sets to True and is_marked_deprecated reads back.
_DEPRECATED_MARK_ATTR_NAME = "__deprecated"
|
110
|
+
|
111
|
+
|
112
|
+
def mark_deprecated(func):
    """
    Mark a function (or class) as deprecated by setting a private attribute on it.

    The flag can be read back with ``is_marked_deprecated``.

    Args:
        func: The object to flag. It must support attribute assignment.

    Returns:
        ``func`` itself. This lets the helper also be used as a decorator (``@mark_deprecated``)
        without replacing the decorated object with ``None``. Callers that ignore the return
        value are unaffected.
    """
    setattr(func, _DEPRECATED_MARK_ATTR_NAME, True)
    return func
|
117
|
+
|
118
|
+
|
119
|
+
def is_marked_deprecated(func):
    """
    Report whether ``func`` carries the deprecation flag set by ``mark_deprecated``.

    Returns:
        The stored flag value, or ``False`` if the attribute is absent.
    """
    try:
        return getattr(func, _DEPRECATED_MARK_ATTR_NAME)
    except AttributeError:
        return False
|
124
|
+
|
125
|
+
|
126
|
+
def deprecated(
    alternative: Optional[str] = None, since: Optional[str] = None, impact: Optional[str] = None
):
    """Annotation decorator for marking APIs as deprecated in docstrings and raising a warning if
    called.

    Args:
        alternative: The name of a superseded replacement function, method,
            or class to use in place of the deprecated one.
        since: A version designator defining during which release the function,
            method, or class was marked as deprecated.
        impact: Indication of whether the method, function, or class will be
            removed in a future release.

    Returns:
        A decorator. Applied to a class, it wraps ``__init__`` so instantiation emits a
        ``FutureWarning``. Applied to a function or method, it returns a wrapper that emits a
        ``FutureWarning`` on each call. In both cases a ``.. Warning::`` note is added to the
        docstring and the result is flagged via ``mark_deprecated``. Any other object is returned
        unchanged.
    """

    def _build_notice(obj) -> str:
        # Message shared by the runtime warning and the docstring note.
        since_str = f" since {since}" if since else ""
        impact_str = impact if impact else "This method will be removed in a future release."

        qual_name = f"{obj.__module__}.{obj.__qualname__}"
        notice = f"``{qual_name}`` is deprecated{since_str}. {impact_str}"
        if alternative and alternative.strip():
            notice += f" Use ``{alternative}`` instead."
        return notice

    def deprecated_decorator(obj):
        if inspect.isclass(obj):
            notice = _build_notice(obj)
            original_init = obj.__init__

            @wraps(original_init)
            def new_init(self, *args, **kwargs):
                # stacklevel=2 attributes the warning to the code instantiating the class.
                warnings.warn(notice, category=FutureWarning, stacklevel=2)
                original_init(self, *args, **kwargs)

            obj.__init__ = new_init

            if obj.__doc__:
                obj.__doc__ = f".. Warning:: {notice}\n{obj.__doc__}"
            else:
                obj.__doc__ = f".. Warning:: {notice}"

            mark_deprecated(obj)
            return obj

        if isinstance(obj, (types.FunctionType, types.MethodType)):
            notice = _build_notice(obj)

            @wraps(obj)
            def deprecated_func(*args, **kwargs):
                # stacklevel=2 attributes the warning to the caller of the deprecated function.
                warnings.warn(notice, category=FutureWarning, stacklevel=2)
                return obj(*args, **kwargs)

            if obj.__doc__:
                indent = _get_min_indent_of_docstring(obj.__doc__)
                deprecated_func.__doc__ = f"{indent}.. Warning:: {notice}\n{obj.__doc__}"
            else:
                deprecated_func.__doc__ = f".. Warning:: {notice}"

            mark_deprecated(deprecated_func)
            return deprecated_func

        # Unsupported objects (e.g. functools.partial instances, plain values) are returned
        # untouched. The notice is built only in the supported branches because such objects may
        # lack __module__/__qualname__, which would otherwise raise AttributeError here.
        return obj

    return deprecated_decorator
|
191
|
+
|
192
|
+
|
193
|
+
def keyword_only(func):
    """Decorator that rejects positional arguments, so ``func`` can only be called by keyword.

    A ``.. note::`` line (indented like the existing docstring) is prepended to the wrapper's
    docstring. If ``func`` has no docstring, the note becomes the docstring.

    Raises:
        TypeError: From the returned wrapper, when it is called with any positional argument.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        if args:
            raise TypeError(f"Method {func.__name__} only takes keyword arguments.")
        return func(**kwargs)

    wrapped_doc = wrapper.__doc__
    note = ".. note:: This method requires all argument be specified by keyword.\n"
    if wrapped_doc:
        wrapper.__doc__ = _get_min_indent_of_docstring(wrapped_doc) + note + wrapped_doc
    else:
        wrapper.__doc__ = note

    return wrapper
|
207
|
+
|
208
|
+
|
209
|
+
def filter_user_warnings_once(func):
    """Decorator that shows each matching UserWarning only once while ``func`` runs.

    The "once" filter is installed inside ``warnings.catch_warnings()``, so the previous filter
    configuration is restored when the wrapped call returns or raises.
    """

    @wraps(func)
    def run_with_once_filter(*args, **kwargs):
        with warnings.catch_warnings():
            warnings.simplefilter("once", category=UserWarning)
            return func(*args, **kwargs)

    return run_with_once_filter
|