genesis-flow 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genesis_flow-1.0.0.dist-info/METADATA +822 -0
- genesis_flow-1.0.0.dist-info/RECORD +645 -0
- genesis_flow-1.0.0.dist-info/WHEEL +5 -0
- genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
- genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
- genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
- mlflow/__init__.py +367 -0
- mlflow/__main__.py +3 -0
- mlflow/ag2/__init__.py +56 -0
- mlflow/ag2/ag2_logger.py +294 -0
- mlflow/anthropic/__init__.py +40 -0
- mlflow/anthropic/autolog.py +129 -0
- mlflow/anthropic/chat.py +144 -0
- mlflow/artifacts/__init__.py +268 -0
- mlflow/autogen/__init__.py +144 -0
- mlflow/autogen/chat.py +142 -0
- mlflow/azure/__init__.py +26 -0
- mlflow/azure/auth_handler.py +257 -0
- mlflow/azure/client.py +319 -0
- mlflow/azure/config.py +120 -0
- mlflow/azure/connection_factory.py +340 -0
- mlflow/azure/exceptions.py +27 -0
- mlflow/azure/stores.py +327 -0
- mlflow/azure/utils.py +183 -0
- mlflow/bedrock/__init__.py +45 -0
- mlflow/bedrock/_autolog.py +202 -0
- mlflow/bedrock/chat.py +122 -0
- mlflow/bedrock/stream.py +160 -0
- mlflow/bedrock/utils.py +43 -0
- mlflow/cli.py +707 -0
- mlflow/client.py +12 -0
- mlflow/config/__init__.py +56 -0
- mlflow/crewai/__init__.py +79 -0
- mlflow/crewai/autolog.py +253 -0
- mlflow/crewai/chat.py +29 -0
- mlflow/data/__init__.py +75 -0
- mlflow/data/artifact_dataset_sources.py +170 -0
- mlflow/data/code_dataset_source.py +40 -0
- mlflow/data/dataset.py +123 -0
- mlflow/data/dataset_registry.py +168 -0
- mlflow/data/dataset_source.py +110 -0
- mlflow/data/dataset_source_registry.py +219 -0
- mlflow/data/delta_dataset_source.py +167 -0
- mlflow/data/digest_utils.py +108 -0
- mlflow/data/evaluation_dataset.py +562 -0
- mlflow/data/filesystem_dataset_source.py +81 -0
- mlflow/data/http_dataset_source.py +145 -0
- mlflow/data/huggingface_dataset.py +258 -0
- mlflow/data/huggingface_dataset_source.py +118 -0
- mlflow/data/meta_dataset.py +104 -0
- mlflow/data/numpy_dataset.py +223 -0
- mlflow/data/pandas_dataset.py +231 -0
- mlflow/data/polars_dataset.py +352 -0
- mlflow/data/pyfunc_dataset_mixin.py +31 -0
- mlflow/data/schema.py +76 -0
- mlflow/data/sources.py +1 -0
- mlflow/data/spark_dataset.py +406 -0
- mlflow/data/spark_dataset_source.py +74 -0
- mlflow/data/spark_delta_utils.py +118 -0
- mlflow/data/tensorflow_dataset.py +350 -0
- mlflow/data/uc_volume_dataset_source.py +81 -0
- mlflow/db.py +27 -0
- mlflow/dspy/__init__.py +17 -0
- mlflow/dspy/autolog.py +197 -0
- mlflow/dspy/callback.py +398 -0
- mlflow/dspy/constant.py +1 -0
- mlflow/dspy/load.py +93 -0
- mlflow/dspy/save.py +393 -0
- mlflow/dspy/util.py +109 -0
- mlflow/dspy/wrapper.py +226 -0
- mlflow/entities/__init__.py +104 -0
- mlflow/entities/_mlflow_object.py +52 -0
- mlflow/entities/assessment.py +545 -0
- mlflow/entities/assessment_error.py +80 -0
- mlflow/entities/assessment_source.py +141 -0
- mlflow/entities/dataset.py +92 -0
- mlflow/entities/dataset_input.py +51 -0
- mlflow/entities/dataset_summary.py +62 -0
- mlflow/entities/document.py +48 -0
- mlflow/entities/experiment.py +109 -0
- mlflow/entities/experiment_tag.py +35 -0
- mlflow/entities/file_info.py +45 -0
- mlflow/entities/input_tag.py +35 -0
- mlflow/entities/lifecycle_stage.py +35 -0
- mlflow/entities/logged_model.py +228 -0
- mlflow/entities/logged_model_input.py +26 -0
- mlflow/entities/logged_model_output.py +32 -0
- mlflow/entities/logged_model_parameter.py +46 -0
- mlflow/entities/logged_model_status.py +74 -0
- mlflow/entities/logged_model_tag.py +33 -0
- mlflow/entities/metric.py +200 -0
- mlflow/entities/model_registry/__init__.py +29 -0
- mlflow/entities/model_registry/_model_registry_entity.py +13 -0
- mlflow/entities/model_registry/model_version.py +243 -0
- mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
- mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
- mlflow/entities/model_registry/model_version_search.py +25 -0
- mlflow/entities/model_registry/model_version_stages.py +25 -0
- mlflow/entities/model_registry/model_version_status.py +35 -0
- mlflow/entities/model_registry/model_version_tag.py +35 -0
- mlflow/entities/model_registry/prompt.py +73 -0
- mlflow/entities/model_registry/prompt_version.py +244 -0
- mlflow/entities/model_registry/registered_model.py +175 -0
- mlflow/entities/model_registry/registered_model_alias.py +35 -0
- mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
- mlflow/entities/model_registry/registered_model_search.py +25 -0
- mlflow/entities/model_registry/registered_model_tag.py +35 -0
- mlflow/entities/multipart_upload.py +74 -0
- mlflow/entities/param.py +49 -0
- mlflow/entities/run.py +97 -0
- mlflow/entities/run_data.py +84 -0
- mlflow/entities/run_info.py +188 -0
- mlflow/entities/run_inputs.py +59 -0
- mlflow/entities/run_outputs.py +43 -0
- mlflow/entities/run_status.py +41 -0
- mlflow/entities/run_tag.py +36 -0
- mlflow/entities/source_type.py +31 -0
- mlflow/entities/span.py +774 -0
- mlflow/entities/span_event.py +96 -0
- mlflow/entities/span_status.py +102 -0
- mlflow/entities/trace.py +317 -0
- mlflow/entities/trace_data.py +71 -0
- mlflow/entities/trace_info.py +220 -0
- mlflow/entities/trace_info_v2.py +162 -0
- mlflow/entities/trace_location.py +173 -0
- mlflow/entities/trace_state.py +39 -0
- mlflow/entities/trace_status.py +68 -0
- mlflow/entities/view_type.py +51 -0
- mlflow/environment_variables.py +866 -0
- mlflow/evaluation/__init__.py +16 -0
- mlflow/evaluation/assessment.py +369 -0
- mlflow/evaluation/evaluation.py +411 -0
- mlflow/evaluation/evaluation_tag.py +61 -0
- mlflow/evaluation/fluent.py +48 -0
- mlflow/evaluation/utils.py +201 -0
- mlflow/exceptions.py +213 -0
- mlflow/experiments.py +140 -0
- mlflow/gemini/__init__.py +81 -0
- mlflow/gemini/autolog.py +186 -0
- mlflow/gemini/chat.py +261 -0
- mlflow/genai/__init__.py +71 -0
- mlflow/genai/datasets/__init__.py +67 -0
- mlflow/genai/datasets/evaluation_dataset.py +131 -0
- mlflow/genai/evaluation/__init__.py +3 -0
- mlflow/genai/evaluation/base.py +411 -0
- mlflow/genai/evaluation/constant.py +23 -0
- mlflow/genai/evaluation/utils.py +244 -0
- mlflow/genai/judges/__init__.py +21 -0
- mlflow/genai/judges/databricks.py +404 -0
- mlflow/genai/label_schemas/__init__.py +153 -0
- mlflow/genai/label_schemas/label_schemas.py +209 -0
- mlflow/genai/labeling/__init__.py +159 -0
- mlflow/genai/labeling/labeling.py +250 -0
- mlflow/genai/optimize/__init__.py +13 -0
- mlflow/genai/optimize/base.py +198 -0
- mlflow/genai/optimize/optimizers/__init__.py +4 -0
- mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
- mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
- mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
- mlflow/genai/optimize/types.py +75 -0
- mlflow/genai/optimize/util.py +30 -0
- mlflow/genai/prompts/__init__.py +206 -0
- mlflow/genai/scheduled_scorers.py +431 -0
- mlflow/genai/scorers/__init__.py +26 -0
- mlflow/genai/scorers/base.py +492 -0
- mlflow/genai/scorers/builtin_scorers.py +765 -0
- mlflow/genai/scorers/scorer_utils.py +138 -0
- mlflow/genai/scorers/validation.py +165 -0
- mlflow/genai/utils/data_validation.py +146 -0
- mlflow/genai/utils/enum_utils.py +23 -0
- mlflow/genai/utils/trace_utils.py +211 -0
- mlflow/groq/__init__.py +42 -0
- mlflow/groq/_groq_autolog.py +74 -0
- mlflow/johnsnowlabs/__init__.py +888 -0
- mlflow/langchain/__init__.py +24 -0
- mlflow/langchain/api_request_parallel_processor.py +330 -0
- mlflow/langchain/autolog.py +147 -0
- mlflow/langchain/chat_agent_langgraph.py +340 -0
- mlflow/langchain/constant.py +1 -0
- mlflow/langchain/constants.py +1 -0
- mlflow/langchain/databricks_dependencies.py +444 -0
- mlflow/langchain/langchain_tracer.py +597 -0
- mlflow/langchain/model.py +919 -0
- mlflow/langchain/output_parsers.py +142 -0
- mlflow/langchain/retriever_chain.py +153 -0
- mlflow/langchain/runnables.py +527 -0
- mlflow/langchain/utils/chat.py +402 -0
- mlflow/langchain/utils/logging.py +671 -0
- mlflow/langchain/utils/serialization.py +36 -0
- mlflow/legacy_databricks_cli/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/provider.py +482 -0
- mlflow/litellm/__init__.py +175 -0
- mlflow/llama_index/__init__.py +22 -0
- mlflow/llama_index/autolog.py +55 -0
- mlflow/llama_index/chat.py +43 -0
- mlflow/llama_index/constant.py +1 -0
- mlflow/llama_index/model.py +577 -0
- mlflow/llama_index/pyfunc_wrapper.py +332 -0
- mlflow/llama_index/serialize_objects.py +188 -0
- mlflow/llama_index/tracer.py +561 -0
- mlflow/metrics/__init__.py +479 -0
- mlflow/metrics/base.py +39 -0
- mlflow/metrics/genai/__init__.py +25 -0
- mlflow/metrics/genai/base.py +101 -0
- mlflow/metrics/genai/genai_metric.py +771 -0
- mlflow/metrics/genai/metric_definitions.py +450 -0
- mlflow/metrics/genai/model_utils.py +371 -0
- mlflow/metrics/genai/prompt_template.py +68 -0
- mlflow/metrics/genai/prompts/__init__.py +0 -0
- mlflow/metrics/genai/prompts/v1.py +422 -0
- mlflow/metrics/genai/utils.py +6 -0
- mlflow/metrics/metric_definitions.py +619 -0
- mlflow/mismatch.py +34 -0
- mlflow/mistral/__init__.py +34 -0
- mlflow/mistral/autolog.py +71 -0
- mlflow/mistral/chat.py +135 -0
- mlflow/ml_package_versions.py +452 -0
- mlflow/models/__init__.py +97 -0
- mlflow/models/auth_policy.py +83 -0
- mlflow/models/cli.py +354 -0
- mlflow/models/container/__init__.py +294 -0
- mlflow/models/container/scoring_server/__init__.py +0 -0
- mlflow/models/container/scoring_server/nginx.conf +39 -0
- mlflow/models/dependencies_schemas.py +287 -0
- mlflow/models/display_utils.py +158 -0
- mlflow/models/docker_utils.py +211 -0
- mlflow/models/evaluation/__init__.py +23 -0
- mlflow/models/evaluation/_shap_patch.py +64 -0
- mlflow/models/evaluation/artifacts.py +194 -0
- mlflow/models/evaluation/base.py +1811 -0
- mlflow/models/evaluation/calibration_curve.py +109 -0
- mlflow/models/evaluation/default_evaluator.py +996 -0
- mlflow/models/evaluation/deprecated.py +23 -0
- mlflow/models/evaluation/evaluator_registry.py +80 -0
- mlflow/models/evaluation/evaluators/classifier.py +704 -0
- mlflow/models/evaluation/evaluators/default.py +233 -0
- mlflow/models/evaluation/evaluators/regressor.py +96 -0
- mlflow/models/evaluation/evaluators/shap.py +296 -0
- mlflow/models/evaluation/lift_curve.py +178 -0
- mlflow/models/evaluation/utils/metric.py +123 -0
- mlflow/models/evaluation/utils/trace.py +179 -0
- mlflow/models/evaluation/validation.py +434 -0
- mlflow/models/flavor_backend.py +93 -0
- mlflow/models/flavor_backend_registry.py +53 -0
- mlflow/models/model.py +1639 -0
- mlflow/models/model_config.py +150 -0
- mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
- mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
- mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
- mlflow/models/python_api.py +369 -0
- mlflow/models/rag_signatures.py +128 -0
- mlflow/models/resources.py +321 -0
- mlflow/models/signature.py +662 -0
- mlflow/models/utils.py +2054 -0
- mlflow/models/wheeled_model.py +280 -0
- mlflow/openai/__init__.py +57 -0
- mlflow/openai/_agent_tracer.py +364 -0
- mlflow/openai/api_request_parallel_processor.py +131 -0
- mlflow/openai/autolog.py +509 -0
- mlflow/openai/constant.py +1 -0
- mlflow/openai/model.py +824 -0
- mlflow/openai/utils/chat_schema.py +367 -0
- mlflow/optuna/__init__.py +3 -0
- mlflow/optuna/storage.py +646 -0
- mlflow/plugins/__init__.py +72 -0
- mlflow/plugins/base.py +358 -0
- mlflow/plugins/builtin/__init__.py +24 -0
- mlflow/plugins/builtin/pytorch_plugin.py +150 -0
- mlflow/plugins/builtin/sklearn_plugin.py +158 -0
- mlflow/plugins/builtin/transformers_plugin.py +187 -0
- mlflow/plugins/cli.py +321 -0
- mlflow/plugins/discovery.py +340 -0
- mlflow/plugins/manager.py +465 -0
- mlflow/plugins/registry.py +316 -0
- mlflow/plugins/templates/framework_plugin_template.py +329 -0
- mlflow/prompt/constants.py +20 -0
- mlflow/prompt/promptlab_model.py +197 -0
- mlflow/prompt/registry_utils.py +248 -0
- mlflow/promptflow/__init__.py +495 -0
- mlflow/protos/__init__.py +0 -0
- mlflow/protos/assessments_pb2.py +174 -0
- mlflow/protos/databricks_artifacts_pb2.py +489 -0
- mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
- mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
- mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
- mlflow/protos/databricks_pb2.py +267 -0
- mlflow/protos/databricks_trace_server_pb2.py +374 -0
- mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
- mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
- mlflow/protos/facet_feature_statistics_pb2.py +296 -0
- mlflow/protos/internal_pb2.py +77 -0
- mlflow/protos/mlflow_artifacts_pb2.py +336 -0
- mlflow/protos/model_registry_pb2.py +1073 -0
- mlflow/protos/scalapb/__init__.py +0 -0
- mlflow/protos/scalapb/scalapb_pb2.py +104 -0
- mlflow/protos/service_pb2.py +2600 -0
- mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
- mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
- mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
- mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
- mlflow/py.typed +0 -0
- mlflow/pydantic_ai/__init__.py +57 -0
- mlflow/pydantic_ai/autolog.py +173 -0
- mlflow/pyfunc/__init__.py +3844 -0
- mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
- mlflow/pyfunc/backend.py +523 -0
- mlflow/pyfunc/context.py +78 -0
- mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
- mlflow/pyfunc/loaders/__init__.py +7 -0
- mlflow/pyfunc/loaders/chat_agent.py +117 -0
- mlflow/pyfunc/loaders/chat_model.py +125 -0
- mlflow/pyfunc/loaders/code_model.py +31 -0
- mlflow/pyfunc/loaders/responses_agent.py +112 -0
- mlflow/pyfunc/mlserver.py +46 -0
- mlflow/pyfunc/model.py +1473 -0
- mlflow/pyfunc/scoring_server/__init__.py +604 -0
- mlflow/pyfunc/scoring_server/app.py +7 -0
- mlflow/pyfunc/scoring_server/client.py +146 -0
- mlflow/pyfunc/spark_model_cache.py +48 -0
- mlflow/pyfunc/stdin_server.py +44 -0
- mlflow/pyfunc/utils/__init__.py +3 -0
- mlflow/pyfunc/utils/data_validation.py +224 -0
- mlflow/pyfunc/utils/environment.py +22 -0
- mlflow/pyfunc/utils/input_converter.py +47 -0
- mlflow/pyfunc/utils/serving_data_parser.py +11 -0
- mlflow/pytorch/__init__.py +1171 -0
- mlflow/pytorch/_lightning_autolog.py +580 -0
- mlflow/pytorch/_pytorch_autolog.py +50 -0
- mlflow/pytorch/pickle_module.py +35 -0
- mlflow/rfunc/__init__.py +42 -0
- mlflow/rfunc/backend.py +134 -0
- mlflow/runs.py +89 -0
- mlflow/server/__init__.py +302 -0
- mlflow/server/auth/__init__.py +1224 -0
- mlflow/server/auth/__main__.py +4 -0
- mlflow/server/auth/basic_auth.ini +6 -0
- mlflow/server/auth/cli.py +11 -0
- mlflow/server/auth/client.py +537 -0
- mlflow/server/auth/config.py +34 -0
- mlflow/server/auth/db/__init__.py +0 -0
- mlflow/server/auth/db/cli.py +18 -0
- mlflow/server/auth/db/migrations/__init__.py +0 -0
- mlflow/server/auth/db/migrations/alembic.ini +110 -0
- mlflow/server/auth/db/migrations/env.py +76 -0
- mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
- mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
- mlflow/server/auth/db/models.py +67 -0
- mlflow/server/auth/db/utils.py +37 -0
- mlflow/server/auth/entities.py +165 -0
- mlflow/server/auth/logo.py +14 -0
- mlflow/server/auth/permissions.py +65 -0
- mlflow/server/auth/routes.py +18 -0
- mlflow/server/auth/sqlalchemy_store.py +263 -0
- mlflow/server/graphql/__init__.py +0 -0
- mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
- mlflow/server/graphql/graphql_custom_scalars.py +24 -0
- mlflow/server/graphql/graphql_errors.py +15 -0
- mlflow/server/graphql/graphql_no_batching.py +89 -0
- mlflow/server/graphql/graphql_schema_extensions.py +74 -0
- mlflow/server/handlers.py +3217 -0
- mlflow/server/prometheus_exporter.py +17 -0
- mlflow/server/validation.py +30 -0
- mlflow/shap/__init__.py +691 -0
- mlflow/sklearn/__init__.py +1994 -0
- mlflow/sklearn/utils.py +1041 -0
- mlflow/smolagents/__init__.py +66 -0
- mlflow/smolagents/autolog.py +139 -0
- mlflow/smolagents/chat.py +29 -0
- mlflow/store/__init__.py +10 -0
- mlflow/store/_unity_catalog/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/constants.py +2 -0
- mlflow/store/_unity_catalog/registry/__init__.py +6 -0
- mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
- mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
- mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
- mlflow/store/_unity_catalog/registry/utils.py +121 -0
- mlflow/store/artifact/__init__.py +0 -0
- mlflow/store/artifact/artifact_repo.py +472 -0
- mlflow/store/artifact/artifact_repository_registry.py +154 -0
- mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
- mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
- mlflow/store/artifact/cli.py +141 -0
- mlflow/store/artifact/cloud_artifact_repo.py +332 -0
- mlflow/store/artifact/databricks_artifact_repo.py +729 -0
- mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
- mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
- mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
- mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
- mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
- mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
- mlflow/store/artifact/ftp_artifact_repo.py +132 -0
- mlflow/store/artifact/gcs_artifact_repo.py +296 -0
- mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
- mlflow/store/artifact/http_artifact_repo.py +218 -0
- mlflow/store/artifact/local_artifact_repo.py +142 -0
- mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
- mlflow/store/artifact/models_artifact_repo.py +259 -0
- mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
- mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
- mlflow/store/artifact/r2_artifact_repo.py +70 -0
- mlflow/store/artifact/runs_artifact_repo.py +265 -0
- mlflow/store/artifact/s3_artifact_repo.py +330 -0
- mlflow/store/artifact/sftp_artifact_repo.py +141 -0
- mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
- mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
- mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
- mlflow/store/artifact/utils/__init__.py +0 -0
- mlflow/store/artifact/utils/models.py +148 -0
- mlflow/store/db/__init__.py +0 -0
- mlflow/store/db/base_sql_model.py +3 -0
- mlflow/store/db/db_types.py +10 -0
- mlflow/store/db/utils.py +314 -0
- mlflow/store/db_migrations/__init__.py +0 -0
- mlflow/store/db_migrations/alembic.ini +74 -0
- mlflow/store/db_migrations/env.py +84 -0
- mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
- mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
- mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
- mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
- mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
- mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
- mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
- mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
- mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
- mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
- mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
- mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
- mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
- mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
- mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
- mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
- mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
- mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
- mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
- mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
- mlflow/store/db_migrations/versions/__init__.py +0 -0
- mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
- mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
- mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
- mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
- mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
- mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
- mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
- mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
- mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
- mlflow/store/entities/__init__.py +3 -0
- mlflow/store/entities/paged_list.py +18 -0
- mlflow/store/model_registry/__init__.py +10 -0
- mlflow/store/model_registry/abstract_store.py +1081 -0
- mlflow/store/model_registry/base_rest_store.py +44 -0
- mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
- mlflow/store/model_registry/dbmodels/__init__.py +0 -0
- mlflow/store/model_registry/dbmodels/models.py +206 -0
- mlflow/store/model_registry/file_store.py +1091 -0
- mlflow/store/model_registry/rest_store.py +481 -0
- mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
- mlflow/store/tracking/__init__.py +23 -0
- mlflow/store/tracking/abstract_store.py +816 -0
- mlflow/store/tracking/dbmodels/__init__.py +0 -0
- mlflow/store/tracking/dbmodels/initial_models.py +243 -0
- mlflow/store/tracking/dbmodels/models.py +1073 -0
- mlflow/store/tracking/file_store.py +2438 -0
- mlflow/store/tracking/postgres_managed_identity.py +146 -0
- mlflow/store/tracking/rest_store.py +1131 -0
- mlflow/store/tracking/sqlalchemy_store.py +2785 -0
- mlflow/system_metrics/__init__.py +61 -0
- mlflow/system_metrics/metrics/__init__.py +0 -0
- mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
- mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
- mlflow/system_metrics/metrics/disk_monitor.py +21 -0
- mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
- mlflow/system_metrics/metrics/network_monitor.py +34 -0
- mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
- mlflow/system_metrics/system_metrics_monitor.py +198 -0
- mlflow/tracing/__init__.py +16 -0
- mlflow/tracing/assessment.py +356 -0
- mlflow/tracing/client.py +531 -0
- mlflow/tracing/config.py +125 -0
- mlflow/tracing/constant.py +105 -0
- mlflow/tracing/destination.py +81 -0
- mlflow/tracing/display/__init__.py +40 -0
- mlflow/tracing/display/display_handler.py +196 -0
- mlflow/tracing/export/async_export_queue.py +186 -0
- mlflow/tracing/export/inference_table.py +138 -0
- mlflow/tracing/export/mlflow_v3.py +137 -0
- mlflow/tracing/export/utils.py +70 -0
- mlflow/tracing/fluent.py +1417 -0
- mlflow/tracing/processor/base_mlflow.py +199 -0
- mlflow/tracing/processor/inference_table.py +175 -0
- mlflow/tracing/processor/mlflow_v3.py +47 -0
- mlflow/tracing/processor/otel.py +73 -0
- mlflow/tracing/provider.py +487 -0
- mlflow/tracing/trace_manager.py +200 -0
- mlflow/tracing/utils/__init__.py +616 -0
- mlflow/tracing/utils/artifact_utils.py +28 -0
- mlflow/tracing/utils/copy.py +55 -0
- mlflow/tracing/utils/environment.py +55 -0
- mlflow/tracing/utils/exception.py +21 -0
- mlflow/tracing/utils/once.py +35 -0
- mlflow/tracing/utils/otlp.py +63 -0
- mlflow/tracing/utils/processor.py +54 -0
- mlflow/tracing/utils/search.py +292 -0
- mlflow/tracing/utils/timeout.py +250 -0
- mlflow/tracing/utils/token.py +19 -0
- mlflow/tracing/utils/truncation.py +124 -0
- mlflow/tracing/utils/warning.py +76 -0
- mlflow/tracking/__init__.py +39 -0
- mlflow/tracking/_model_registry/__init__.py +1 -0
- mlflow/tracking/_model_registry/client.py +764 -0
- mlflow/tracking/_model_registry/fluent.py +853 -0
- mlflow/tracking/_model_registry/registry.py +67 -0
- mlflow/tracking/_model_registry/utils.py +251 -0
- mlflow/tracking/_tracking_service/__init__.py +0 -0
- mlflow/tracking/_tracking_service/client.py +883 -0
- mlflow/tracking/_tracking_service/registry.py +56 -0
- mlflow/tracking/_tracking_service/utils.py +275 -0
- mlflow/tracking/artifact_utils.py +179 -0
- mlflow/tracking/client.py +5900 -0
- mlflow/tracking/context/__init__.py +0 -0
- mlflow/tracking/context/abstract_context.py +35 -0
- mlflow/tracking/context/databricks_cluster_context.py +15 -0
- mlflow/tracking/context/databricks_command_context.py +15 -0
- mlflow/tracking/context/databricks_job_context.py +49 -0
- mlflow/tracking/context/databricks_notebook_context.py +41 -0
- mlflow/tracking/context/databricks_repo_context.py +43 -0
- mlflow/tracking/context/default_context.py +51 -0
- mlflow/tracking/context/git_context.py +32 -0
- mlflow/tracking/context/registry.py +98 -0
- mlflow/tracking/context/system_environment_context.py +15 -0
- mlflow/tracking/default_experiment/__init__.py +1 -0
- mlflow/tracking/default_experiment/abstract_context.py +43 -0
- mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
- mlflow/tracking/default_experiment/registry.py +75 -0
- mlflow/tracking/fluent.py +3595 -0
- mlflow/tracking/metric_value_conversion_utils.py +93 -0
- mlflow/tracking/multimedia.py +206 -0
- mlflow/tracking/registry.py +86 -0
- mlflow/tracking/request_auth/__init__.py +0 -0
- mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
- mlflow/tracking/request_auth/registry.py +60 -0
- mlflow/tracking/request_header/__init__.py +0 -0
- mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
- mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
- mlflow/tracking/request_header/default_request_header_provider.py +17 -0
- mlflow/tracking/request_header/registry.py +79 -0
- mlflow/transformers/__init__.py +2982 -0
- mlflow/transformers/flavor_config.py +258 -0
- mlflow/transformers/hub_utils.py +83 -0
- mlflow/transformers/llm_inference_utils.py +468 -0
- mlflow/transformers/model_io.py +301 -0
- mlflow/transformers/peft.py +51 -0
- mlflow/transformers/signature.py +183 -0
- mlflow/transformers/torch_utils.py +55 -0
- mlflow/types/__init__.py +21 -0
- mlflow/types/agent.py +270 -0
- mlflow/types/chat.py +240 -0
- mlflow/types/llm.py +935 -0
- mlflow/types/responses.py +139 -0
- mlflow/types/responses_helpers.py +416 -0
- mlflow/types/schema.py +1505 -0
- mlflow/types/type_hints.py +647 -0
- mlflow/types/utils.py +753 -0
- mlflow/utils/__init__.py +283 -0
- mlflow/utils/_capture_modules.py +256 -0
- mlflow/utils/_capture_transformers_modules.py +75 -0
- mlflow/utils/_spark_utils.py +201 -0
- mlflow/utils/_unity_catalog_oss_utils.py +97 -0
- mlflow/utils/_unity_catalog_utils.py +479 -0
- mlflow/utils/annotations.py +218 -0
- mlflow/utils/arguments_utils.py +16 -0
- mlflow/utils/async_logging/__init__.py +1 -0
- mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
- mlflow/utils/async_logging/async_logging_queue.py +366 -0
- mlflow/utils/async_logging/run_artifact.py +38 -0
- mlflow/utils/async_logging/run_batch.py +58 -0
- mlflow/utils/async_logging/run_operations.py +49 -0
- mlflow/utils/autologging_utils/__init__.py +737 -0
- mlflow/utils/autologging_utils/client.py +432 -0
- mlflow/utils/autologging_utils/config.py +33 -0
- mlflow/utils/autologging_utils/events.py +294 -0
- mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
- mlflow/utils/autologging_utils/metrics_queue.py +71 -0
- mlflow/utils/autologging_utils/safety.py +1104 -0
- mlflow/utils/autologging_utils/versioning.py +95 -0
- mlflow/utils/checkpoint_utils.py +206 -0
- mlflow/utils/class_utils.py +6 -0
- mlflow/utils/cli_args.py +257 -0
- mlflow/utils/conda.py +354 -0
- mlflow/utils/credentials.py +231 -0
- mlflow/utils/data_utils.py +17 -0
- mlflow/utils/databricks_utils.py +1436 -0
- mlflow/utils/docstring_utils.py +477 -0
- mlflow/utils/doctor.py +133 -0
- mlflow/utils/download_cloud_file_chunk.py +43 -0
- mlflow/utils/env_manager.py +16 -0
- mlflow/utils/env_pack.py +131 -0
- mlflow/utils/environment.py +1009 -0
- mlflow/utils/exception_utils.py +14 -0
- mlflow/utils/file_utils.py +978 -0
- mlflow/utils/git_utils.py +77 -0
- mlflow/utils/gorilla.py +797 -0
- mlflow/utils/import_hooks/__init__.py +363 -0
- mlflow/utils/lazy_load.py +51 -0
- mlflow/utils/logging_utils.py +168 -0
- mlflow/utils/mime_type_utils.py +58 -0
- mlflow/utils/mlflow_tags.py +103 -0
- mlflow/utils/model_utils.py +486 -0
- mlflow/utils/name_utils.py +346 -0
- mlflow/utils/nfs_on_spark.py +62 -0
- mlflow/utils/openai_utils.py +164 -0
- mlflow/utils/os.py +12 -0
- mlflow/utils/oss_registry_utils.py +29 -0
- mlflow/utils/plugins.py +17 -0
- mlflow/utils/process.py +182 -0
- mlflow/utils/promptlab_utils.py +146 -0
- mlflow/utils/proto_json_utils.py +743 -0
- mlflow/utils/pydantic_utils.py +54 -0
- mlflow/utils/request_utils.py +279 -0
- mlflow/utils/requirements_utils.py +704 -0
- mlflow/utils/rest_utils.py +673 -0
- mlflow/utils/search_logged_model_utils.py +127 -0
- mlflow/utils/search_utils.py +2111 -0
- mlflow/utils/secure_loading.py +221 -0
- mlflow/utils/security_validation.py +384 -0
- mlflow/utils/server_cli_utils.py +61 -0
- mlflow/utils/spark_utils.py +15 -0
- mlflow/utils/string_utils.py +138 -0
- mlflow/utils/thread_utils.py +63 -0
- mlflow/utils/time.py +54 -0
- mlflow/utils/timeout.py +42 -0
- mlflow/utils/uri.py +572 -0
- mlflow/utils/validation.py +662 -0
- mlflow/utils/virtualenv.py +458 -0
- mlflow/utils/warnings_utils.py +25 -0
- mlflow/utils/yaml_utils.py +179 -0
- mlflow/version.py +24 -0
@@ -0,0 +1,258 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import json
|
4
|
+
import os
|
5
|
+
from typing import TYPE_CHECKING, Any, Optional
|
6
|
+
|
7
|
+
from mlflow.exceptions import MlflowException
|
8
|
+
from mlflow.protos.databricks_pb2 import ALREADY_EXISTS, INVALID_PARAMETER_VALUE
|
9
|
+
from mlflow.transformers.hub_utils import get_latest_commit_for_repo
|
10
|
+
from mlflow.transformers.peft import _PEFT_ADAPTOR_DIR_NAME, get_peft_base_model, is_peft_model
|
11
|
+
from mlflow.transformers.torch_utils import _extract_torch_dtype_if_set
|
12
|
+
|
13
|
+
if TYPE_CHECKING:
|
14
|
+
import transformers
|
15
|
+
|
16
|
+
|
17
|
+
# Flavor configuration keys
|
18
|
+
class FlavorKey:
    """Names of the keys used in the "transformers" flavor configuration dictionary."""

    # Pipeline-level settings
    TASK = "task"
    INSTANCE_TYPE = "instance_type"
    TORCH_DTYPE = "torch_dtype"
    FRAMEWORK = "framework"

    # Model settings
    MODEL = "model"
    MODEL_TYPE = "pipeline_model_type"
    MODEL_BINARY = "model_binary"
    MODEL_NAME = "source_model_name"
    MODEL_REVISION = "source_model_revision"

    # PEFT adaptor
    PEFT = "peft_adaptor"

    # Component settings. The "{}" templates are filled in with a component key,
    # e.g. COMPONENT_NAME.format("tokenizer") -> "tokenizer_name".
    COMPONENTS = "components"
    COMPONENT_NAME = "{}_name"
    COMPONENT_REVISION = "{}_revision"
    COMPONENT_TYPE = "{}_type"
    TOKENIZER = "tokenizer"
    FEATURE_EXTRACTOR = "feature_extractor"
    IMAGE_PROCESSOR = "image_processor"
    PROCESSOR = "processor"
    PROCESSOR_TYPE = "processor_type"

    # Prompt template
    PROMPT_TEMPLATE = "prompt_template"
|
43
|
+
|
44
|
+
|
45
|
+
def build_flavor_config(
    pipeline: transformers.Pipeline, processor=None, torch_dtype=None, save_pretrained=True
) -> dict[str, Any]:
    """
    Builds the flavor metadata required to reconstruct a pipeline from its saved components.

    The ``Pipeline`` class has no loader functionality of its own: serializing a Pipeline
    saves the model, configurations, and metadata for ``FeatureExtractor``s, ``Processor``s,
    and ``Tokenizer``s exclusively. This function therefore records key information about the
    submitted objects so that the precise instance types can be loaded back correctly.

    Args:
        pipeline: Transformer pipeline to generate the flavor configuration for.
        processor: Optional processor instance to save alongside the pipeline.
        torch_dtype: Torch tensor data type.
        save_pretrained: Whether to save the pipeline and components weights to local disk.

    Returns:
        A dictionary containing the flavor configuration for the pipeline and its components,
        i.e. the configurations stored in "transformers" key in the MLModel YAML file.
    """
    config = _generate_base_config(pipeline, torch_dtype=torch_dtype)

    # For a PEFT model, record the adaptor directory and describe the underlying base model.
    peft = is_peft_model(pipeline.model)
    if peft:
        config[FlavorKey.PEFT] = _PEFT_ADAPTOR_DIR_NAME
    model = get_peft_base_model(pipeline.model) if peft else pipeline.model

    config.update(_get_model_config(model, save_pretrained))

    components = _get_components_from_pipeline(pipeline, processor)
    for name, component in components.items():
        # Some components don't have name_or_path, then we fallback to the one from the model.
        component_conf = _get_component_config(
            component, name, save_pretrained, default_repo=model.name_or_path
        )
        config.update(component_conf)

    # "components" field doesn't include processor
    config[FlavorKey.COMPONENTS] = [name for name in components if name != FlavorKey.PROCESSOR]

    return config
|
88
|
+
|
89
|
+
|
90
|
+
def _generate_base_config(pipeline, torch_dtype=None):
    """
    Builds the pipeline-level part of the flavor configuration.

    The task and the pipeline class name are always recorded. The framework and the torch
    dtype are only recorded when a truthy value is available for them.

    Args:
        pipeline: Pipeline to describe.
        torch_dtype: Optional dtype supplied by the user.

    Returns:
        A dictionary with the pipeline-level flavor configuration entries.
    """
    base_conf = {
        FlavorKey.TASK: pipeline.task,
        FlavorKey.INSTANCE_TYPE: _get_instance_type(pipeline),
    }

    framework = getattr(pipeline, "framework", None)
    if framework:
        base_conf[FlavorKey.FRAMEWORK] = framework

    # User-provided torch_dtype takes precedence; the pipeline is only inspected without one.
    dtype = torch_dtype or _extract_torch_dtype_if_set(pipeline)
    if dtype:
        base_conf[FlavorKey.TORCH_DTYPE] = str(dtype)

    return base_conf
|
104
|
+
|
105
|
+
|
106
|
+
def _get_model_config(model, save_pretrained=True):
    """
    Builds the model part of the flavor configuration.

    Args:
        model: Model object exposing a ``name_or_path`` attribute.
        save_pretrained: If True, the local model binary location is recorded. Otherwise the
            latest commit hash of the source repository is recorded instead.

    Returns:
        A dictionary with the model class name, its source name, and either the model binary
        entry or the source revision entry.
    """
    conf = {
        FlavorKey.MODEL_TYPE: _get_instance_type(model),
        FlavorKey.MODEL_NAME: model.name_or_path,
    }

    if not save_pretrained:
        # log HuggingFace repo name and commit hash
        conf[FlavorKey.MODEL_REVISION] = get_latest_commit_for_repo(model.name_or_path)
        return conf

    # log local path to model binary file
    from mlflow.transformers.model_io import _MODEL_BINARY_FILE_NAME

    conf[FlavorKey.MODEL_BINARY] = _MODEL_BINARY_FILE_NAME
    return conf
|
122
|
+
|
123
|
+
|
124
|
+
def _get_component_config(
    component: Any,
    key: str,
    save_pretrained: bool = True,
    default_repo: Optional[str] = None,
    commit_sha: Optional[str] = None,
):
    """
    Builds the flavor configuration entries for a single pipeline component.

    Args:
        component: Component instance, e.g. a tokenizer.
        key: Component key used to format the configuration key names, e.g. "tokenizer".
        save_pretrained: If True, only the component type is recorded.
        default_repo: Repository name used when the component has no ``name_or_path``
            attribute.
        commit_sha: Revision to record. When falsy, the latest commit of the repository is
            looked up instead.

    Returns:
        A dictionary with the component type and, when ``save_pretrained`` is False, the
        source repository name and revision of the component.
    """
    type_key = FlavorKey.COMPONENT_TYPE.format(key)
    conf = {type_key: _get_instance_type(component)}

    if save_pretrained:
        return conf

    # Log source repo name and commit sha for the component
    repo = getattr(component, "name_or_path", default_repo)
    revision = commit_sha if commit_sha else get_latest_commit_for_repo(repo)
    conf.update(
        {
            FlavorKey.COMPONENT_NAME.format(key): repo,
            FlavorKey.COMPONENT_REVISION.format(key): revision,
        }
    )
    return conf
|
141
|
+
|
142
|
+
|
143
|
+
def _get_components_from_pipeline(pipeline, processor=None):
    """
    Collects the components attached to a pipeline, keyed by component name.

    The feature extractor, tokenizer and image processor attributes of the pipeline are
    included when they are present and truthy. The ``processor`` argument, when truthy, is
    added under the processor key.

    Args:
        pipeline: Pipeline to read the component attributes from.
        processor: Optional processor instance.

    Returns:
        A dictionary mapping component names to component instances.
    """
    attribute_names = (
        FlavorKey.FEATURE_EXTRACTOR,
        FlavorKey.TOKENIZER,
        FlavorKey.IMAGE_PROCESSOR,
    )
    candidates = ((name, getattr(pipeline, name, None)) for name in attribute_names)
    found = {name: instance for name, instance in candidates if instance}

    if processor:
        found[FlavorKey.PROCESSOR] = processor

    return found
|
159
|
+
|
160
|
+
|
161
|
+
def _get_instance_type(obj):
    """
    Returns the name of the class of ``obj`` as a string, e.g. ``"dict"`` for ``{}``.
    """
    cls = obj.__class__
    return cls.__name__
|
167
|
+
|
168
|
+
|
169
|
+
def build_flavor_config_from_local_checkpoint(
    local_checkpoint_dir: str,
    task: str,
    processor=None,
    torch_dtype=None,
) -> dict[str, Any]:
    """
    Generates the flavor metadata from a local Transformers checkpoint directory, instead of
    the pipeline instance in-memory.

    Args:
        local_checkpoint_dir: Path to the local checkpoint directory. It must contain a
            ``config.json`` file and a tokenizer that ``AutoTokenizer`` can load.
        task: Transformers task name, used to resolve the pipeline class.
        processor: Optional processor instance whose type is recorded in the configuration.
        torch_dtype: Optional torch data type. Recorded as a string, or as ``None`` when it
            is not provided.

    Returns:
        A dictionary containing the flavor configuration for the checkpoint.

    Raises:
        MlflowException: If the directory has no ``config.json`` file, if that file does not
            list any model architecture, or if the tokenizer cannot be loaded from the
            directory.
    """
    from transformers import AutoTokenizer, pipelines
    from transformers.utils import is_torch_available

    from mlflow.transformers.model_io import _MODEL_BINARY_FILE_NAME

    config_path = os.path.join(local_checkpoint_dir, "config.json")
    if not os.path.exists(config_path):
        raise MlflowException(
            f"The provided directory {local_checkpoint_dir} does not contain a config.json file."
            " Please ensure that the directory contains a valid transformers model checkpoint.",
            error_code=INVALID_PARAMETER_VALUE,
        )

    # Read the file as UTF-8 explicitly rather than with the platform's locale encoding.
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    task_metadata = pipelines.check_task(task)
    pipeline_class = task_metadata[1]["impl"].__name__

    # The first listed architecture is recorded as the model type. Fail with an actionable
    # message instead of a bare KeyError/IndexError when the field is missing or empty.
    architectures = config.get("architectures")
    if not architectures:
        raise MlflowException(
            f"The config.json file in {local_checkpoint_dir} does not specify any model "
            "architecture. Please ensure that its `architectures` field is set.",
            error_code=INVALID_PARAMETER_VALUE,
        )

    flavor_conf = {
        FlavorKey.TASK: task,
        FlavorKey.INSTANCE_TYPE: pipeline_class,
        FlavorKey.FRAMEWORK: "pt" if is_torch_available() else "tf",
        FlavorKey.TORCH_DTYPE: str(torch_dtype) if torch_dtype else None,
        FlavorKey.MODEL_TYPE: architectures[0],
        FlavorKey.MODEL_NAME: local_checkpoint_dir,
        FlavorKey.MODEL_BINARY: _MODEL_BINARY_FILE_NAME,
    }

    try:
        tokenizer = AutoTokenizer.from_pretrained(local_checkpoint_dir)
    except OSError as e:
        raise MlflowException(
            f"Error loading tokenizer from {local_checkpoint_dir}. When logging a "
            "Transformers model from a local checkpoint, please make sure that the "
            "checkpoint directory contains a valid tokenizer configuration as well.",
            error_code=INVALID_PARAMETER_VALUE,
        ) from e

    flavor_conf.update(_get_component_config(tokenizer, FlavorKey.TOKENIZER))

    if processor:
        flavor_conf.update(_get_component_config(processor, FlavorKey.PROCESSOR))

    # The tokenizer is the only entry of the "components" field; the processor, when given,
    # is described by its type entry only.
    flavor_conf[FlavorKey.COMPONENTS] = [FlavorKey.TOKENIZER]
    return flavor_conf
|
226
|
+
|
227
|
+
|
228
|
+
def update_flavor_conf_to_persist_pretrained_model(
    original_flavor_conf: dict[str, Any],
) -> dict[str, Any]:
    """
    Updates the flavor configuration that was saved with save_pretrained=False to the one that
    includes the local path to the model binary file.

    The input dictionary is left unmodified, including its "components" list.

    Args:
        original_flavor_conf: Flavor configuration saved without the pretrained weights.

    Returns:
        A new dictionary with the model binary entry added, and with the model revision and
        the per-component repository name and revision entries removed.

    Raises:
        MlflowException: If the configuration already contains a model binary entry.
    """
    if FlavorKey.MODEL_BINARY in original_flavor_conf:
        raise MlflowException(
            "It appears that the pretrained model weight is already saved to the artifact path.",
            error_code=ALREADY_EXISTS,
        )

    from mlflow.transformers.model_io import _MODEL_BINARY_FILE_NAME

    flavor_conf = original_flavor_conf.copy()

    # Replace model commit path with local path
    flavor_conf[FlavorKey.MODEL_BINARY] = _MODEL_BINARY_FILE_NAME
    flavor_conf.pop(FlavorKey.MODEL_REVISION, None)

    # Remove component repo name and commit hash.
    # NB: dict.copy() is shallow, so the "components" list is shared between the input and the
    # returned dictionary. Work on a separate list so that appending the processor key does not
    # add "processor" to the "components" field of either of them.
    components = list(original_flavor_conf.get(FlavorKey.COMPONENTS, []))
    if FlavorKey.PROCESSOR_TYPE in original_flavor_conf:
        components.append(FlavorKey.PROCESSOR)

    for component in components:
        flavor_conf.pop(FlavorKey.COMPONENT_NAME.format(component), None)
        flavor_conf.pop(FlavorKey.COMPONENT_REVISION.format(component), None)

    return flavor_conf
|
@@ -0,0 +1,83 @@
|
|
1
|
+
import functools
|
2
|
+
import logging
|
3
|
+
import os
|
4
|
+
import time
|
5
|
+
from typing import Optional
|
6
|
+
|
7
|
+
from mlflow.environment_variables import _MLFLOW_TESTING
|
8
|
+
from mlflow.exceptions import MlflowException
|
9
|
+
from mlflow.protos.databricks_pb2 import RESOURCE_DOES_NOT_EXIST
|
10
|
+
|
11
|
+
_logger = logging.getLogger(__name__)
|
12
|
+
|
13
|
+
|
14
|
+
# NB: The maxsize=1 is added for encouraging the cache refresh so the user doesn't get stale
|
15
|
+
# commit hash from the cache. This doesn't work perfectly because it only updates cache
|
16
|
+
# when the user calls it with a different repo name, but it's better than nothing.
|
17
|
+
@functools.lru_cache(maxsize=1)
def get_latest_commit_for_repo(repo: str) -> str:
    """
    Looks up the latest commit hash of a repository on the HuggingFace model hub.

    Args:
        repo: Repository name on the HuggingFace model hub.

    Returns:
        The ``sha`` reported by the hub for the repository.

    Raises:
        MlflowException: If the `huggingface-hub` package is not installed, or if every
            attempt was rate limited.
        HfHubHTTPError: If the hub request fails and the failure is not retried.
    """
    try:
        import huggingface_hub as hub
    except ImportError:
        raise MlflowException(
            "Unable to fetch model commit hash from the HuggingFace model hub. "
            "This is required for saving Transformer model without base model "
            "weights, while ensuring the version consistency of the model. "
            "Please install the `huggingface-hub` package and retry.",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )

    from huggingface_hub.errors import HfHubHTTPError

    client = hub.HfApi()
    for i in range(7):
        try:
            return client.model_info(repo).sha
        except HfHubHTTPError as e:
            # Only rate limit (429) errors are retried, and only when the MLflow testing flag
            # is set. Any other failure is propagated to the caller right away.
            retryable = _MLFLOW_TESTING.get() and e.response.status_code == 429
            if not retryable:
                raise

            # Wait twice as long after every rate-limited attempt.
            _logger.warning(
                f"Rate limit exceeded while fetching commit hash for repo {repo}. "
                f"Retrying in {2**i} seconds. Error: {e}",
            )
            time.sleep(2**i)

    raise MlflowException(
        "Unable to fetch model commit hash from the HuggingFace model hub. "
        "This is required for saving Transformer model without base model "
        "weights, while ensuring the version consistency of the model. ",
        error_code=RESOURCE_DOES_NOT_EXIST,
    )
|
59
|
+
|
60
|
+
|
61
|
+
def is_valid_hf_repo_id(maybe_repo_id: Optional[str]) -> bool:
    """
    Tells whether the given string is a valid HuggingFace repo identifier
    e.g. "username/repo_id".

    Args:
        maybe_repo_id: Candidate repository identifier. May be None or empty.

    Returns:
        False for an empty value, for a path to an existing local directory, and for an
        identifier rejected by ``huggingface_hub``. True otherwise.

    Raises:
        MlflowException: If the `huggingface-hub` package is not installed.
    """
    # Empty values and existing local directories are never treated as hub repositories.
    if not maybe_repo_id:
        return False
    if os.path.isdir(maybe_repo_id):
        return False

    try:
        from huggingface_hub.utils import HFValidationError, validate_repo_id
    except ImportError:
        raise MlflowException(
            "Unable to validate the repository identifier for the HuggingFace model hub "
            "because the `huggingface-hub` package is not installed. Please install the "
            "package with `pip install huggingface-hub` command and retry."
        )

    try:
        validate_repo_id(maybe_repo_id)
    except HFValidationError as e:
        _logger.warning(f"The repository identified {maybe_repo_id} is invalid: {e}")
        return False
    else:
        return True
|