genesis-flow 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genesis_flow-1.0.0.dist-info/METADATA +822 -0
- genesis_flow-1.0.0.dist-info/RECORD +645 -0
- genesis_flow-1.0.0.dist-info/WHEEL +5 -0
- genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
- genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
- genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
- mlflow/__init__.py +367 -0
- mlflow/__main__.py +3 -0
- mlflow/ag2/__init__.py +56 -0
- mlflow/ag2/ag2_logger.py +294 -0
- mlflow/anthropic/__init__.py +40 -0
- mlflow/anthropic/autolog.py +129 -0
- mlflow/anthropic/chat.py +144 -0
- mlflow/artifacts/__init__.py +268 -0
- mlflow/autogen/__init__.py +144 -0
- mlflow/autogen/chat.py +142 -0
- mlflow/azure/__init__.py +26 -0
- mlflow/azure/auth_handler.py +257 -0
- mlflow/azure/client.py +319 -0
- mlflow/azure/config.py +120 -0
- mlflow/azure/connection_factory.py +340 -0
- mlflow/azure/exceptions.py +27 -0
- mlflow/azure/stores.py +327 -0
- mlflow/azure/utils.py +183 -0
- mlflow/bedrock/__init__.py +45 -0
- mlflow/bedrock/_autolog.py +202 -0
- mlflow/bedrock/chat.py +122 -0
- mlflow/bedrock/stream.py +160 -0
- mlflow/bedrock/utils.py +43 -0
- mlflow/cli.py +707 -0
- mlflow/client.py +12 -0
- mlflow/config/__init__.py +56 -0
- mlflow/crewai/__init__.py +79 -0
- mlflow/crewai/autolog.py +253 -0
- mlflow/crewai/chat.py +29 -0
- mlflow/data/__init__.py +75 -0
- mlflow/data/artifact_dataset_sources.py +170 -0
- mlflow/data/code_dataset_source.py +40 -0
- mlflow/data/dataset.py +123 -0
- mlflow/data/dataset_registry.py +168 -0
- mlflow/data/dataset_source.py +110 -0
- mlflow/data/dataset_source_registry.py +219 -0
- mlflow/data/delta_dataset_source.py +167 -0
- mlflow/data/digest_utils.py +108 -0
- mlflow/data/evaluation_dataset.py +562 -0
- mlflow/data/filesystem_dataset_source.py +81 -0
- mlflow/data/http_dataset_source.py +145 -0
- mlflow/data/huggingface_dataset.py +258 -0
- mlflow/data/huggingface_dataset_source.py +118 -0
- mlflow/data/meta_dataset.py +104 -0
- mlflow/data/numpy_dataset.py +223 -0
- mlflow/data/pandas_dataset.py +231 -0
- mlflow/data/polars_dataset.py +352 -0
- mlflow/data/pyfunc_dataset_mixin.py +31 -0
- mlflow/data/schema.py +76 -0
- mlflow/data/sources.py +1 -0
- mlflow/data/spark_dataset.py +406 -0
- mlflow/data/spark_dataset_source.py +74 -0
- mlflow/data/spark_delta_utils.py +118 -0
- mlflow/data/tensorflow_dataset.py +350 -0
- mlflow/data/uc_volume_dataset_source.py +81 -0
- mlflow/db.py +27 -0
- mlflow/dspy/__init__.py +17 -0
- mlflow/dspy/autolog.py +197 -0
- mlflow/dspy/callback.py +398 -0
- mlflow/dspy/constant.py +1 -0
- mlflow/dspy/load.py +93 -0
- mlflow/dspy/save.py +393 -0
- mlflow/dspy/util.py +109 -0
- mlflow/dspy/wrapper.py +226 -0
- mlflow/entities/__init__.py +104 -0
- mlflow/entities/_mlflow_object.py +52 -0
- mlflow/entities/assessment.py +545 -0
- mlflow/entities/assessment_error.py +80 -0
- mlflow/entities/assessment_source.py +141 -0
- mlflow/entities/dataset.py +92 -0
- mlflow/entities/dataset_input.py +51 -0
- mlflow/entities/dataset_summary.py +62 -0
- mlflow/entities/document.py +48 -0
- mlflow/entities/experiment.py +109 -0
- mlflow/entities/experiment_tag.py +35 -0
- mlflow/entities/file_info.py +45 -0
- mlflow/entities/input_tag.py +35 -0
- mlflow/entities/lifecycle_stage.py +35 -0
- mlflow/entities/logged_model.py +228 -0
- mlflow/entities/logged_model_input.py +26 -0
- mlflow/entities/logged_model_output.py +32 -0
- mlflow/entities/logged_model_parameter.py +46 -0
- mlflow/entities/logged_model_status.py +74 -0
- mlflow/entities/logged_model_tag.py +33 -0
- mlflow/entities/metric.py +200 -0
- mlflow/entities/model_registry/__init__.py +29 -0
- mlflow/entities/model_registry/_model_registry_entity.py +13 -0
- mlflow/entities/model_registry/model_version.py +243 -0
- mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
- mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
- mlflow/entities/model_registry/model_version_search.py +25 -0
- mlflow/entities/model_registry/model_version_stages.py +25 -0
- mlflow/entities/model_registry/model_version_status.py +35 -0
- mlflow/entities/model_registry/model_version_tag.py +35 -0
- mlflow/entities/model_registry/prompt.py +73 -0
- mlflow/entities/model_registry/prompt_version.py +244 -0
- mlflow/entities/model_registry/registered_model.py +175 -0
- mlflow/entities/model_registry/registered_model_alias.py +35 -0
- mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
- mlflow/entities/model_registry/registered_model_search.py +25 -0
- mlflow/entities/model_registry/registered_model_tag.py +35 -0
- mlflow/entities/multipart_upload.py +74 -0
- mlflow/entities/param.py +49 -0
- mlflow/entities/run.py +97 -0
- mlflow/entities/run_data.py +84 -0
- mlflow/entities/run_info.py +188 -0
- mlflow/entities/run_inputs.py +59 -0
- mlflow/entities/run_outputs.py +43 -0
- mlflow/entities/run_status.py +41 -0
- mlflow/entities/run_tag.py +36 -0
- mlflow/entities/source_type.py +31 -0
- mlflow/entities/span.py +774 -0
- mlflow/entities/span_event.py +96 -0
- mlflow/entities/span_status.py +102 -0
- mlflow/entities/trace.py +317 -0
- mlflow/entities/trace_data.py +71 -0
- mlflow/entities/trace_info.py +220 -0
- mlflow/entities/trace_info_v2.py +162 -0
- mlflow/entities/trace_location.py +173 -0
- mlflow/entities/trace_state.py +39 -0
- mlflow/entities/trace_status.py +68 -0
- mlflow/entities/view_type.py +51 -0
- mlflow/environment_variables.py +866 -0
- mlflow/evaluation/__init__.py +16 -0
- mlflow/evaluation/assessment.py +369 -0
- mlflow/evaluation/evaluation.py +411 -0
- mlflow/evaluation/evaluation_tag.py +61 -0
- mlflow/evaluation/fluent.py +48 -0
- mlflow/evaluation/utils.py +201 -0
- mlflow/exceptions.py +213 -0
- mlflow/experiments.py +140 -0
- mlflow/gemini/__init__.py +81 -0
- mlflow/gemini/autolog.py +186 -0
- mlflow/gemini/chat.py +261 -0
- mlflow/genai/__init__.py +71 -0
- mlflow/genai/datasets/__init__.py +67 -0
- mlflow/genai/datasets/evaluation_dataset.py +131 -0
- mlflow/genai/evaluation/__init__.py +3 -0
- mlflow/genai/evaluation/base.py +411 -0
- mlflow/genai/evaluation/constant.py +23 -0
- mlflow/genai/evaluation/utils.py +244 -0
- mlflow/genai/judges/__init__.py +21 -0
- mlflow/genai/judges/databricks.py +404 -0
- mlflow/genai/label_schemas/__init__.py +153 -0
- mlflow/genai/label_schemas/label_schemas.py +209 -0
- mlflow/genai/labeling/__init__.py +159 -0
- mlflow/genai/labeling/labeling.py +250 -0
- mlflow/genai/optimize/__init__.py +13 -0
- mlflow/genai/optimize/base.py +198 -0
- mlflow/genai/optimize/optimizers/__init__.py +4 -0
- mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
- mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
- mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
- mlflow/genai/optimize/types.py +75 -0
- mlflow/genai/optimize/util.py +30 -0
- mlflow/genai/prompts/__init__.py +206 -0
- mlflow/genai/scheduled_scorers.py +431 -0
- mlflow/genai/scorers/__init__.py +26 -0
- mlflow/genai/scorers/base.py +492 -0
- mlflow/genai/scorers/builtin_scorers.py +765 -0
- mlflow/genai/scorers/scorer_utils.py +138 -0
- mlflow/genai/scorers/validation.py +165 -0
- mlflow/genai/utils/data_validation.py +146 -0
- mlflow/genai/utils/enum_utils.py +23 -0
- mlflow/genai/utils/trace_utils.py +211 -0
- mlflow/groq/__init__.py +42 -0
- mlflow/groq/_groq_autolog.py +74 -0
- mlflow/johnsnowlabs/__init__.py +888 -0
- mlflow/langchain/__init__.py +24 -0
- mlflow/langchain/api_request_parallel_processor.py +330 -0
- mlflow/langchain/autolog.py +147 -0
- mlflow/langchain/chat_agent_langgraph.py +340 -0
- mlflow/langchain/constant.py +1 -0
- mlflow/langchain/constants.py +1 -0
- mlflow/langchain/databricks_dependencies.py +444 -0
- mlflow/langchain/langchain_tracer.py +597 -0
- mlflow/langchain/model.py +919 -0
- mlflow/langchain/output_parsers.py +142 -0
- mlflow/langchain/retriever_chain.py +153 -0
- mlflow/langchain/runnables.py +527 -0
- mlflow/langchain/utils/chat.py +402 -0
- mlflow/langchain/utils/logging.py +671 -0
- mlflow/langchain/utils/serialization.py +36 -0
- mlflow/legacy_databricks_cli/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/provider.py +482 -0
- mlflow/litellm/__init__.py +175 -0
- mlflow/llama_index/__init__.py +22 -0
- mlflow/llama_index/autolog.py +55 -0
- mlflow/llama_index/chat.py +43 -0
- mlflow/llama_index/constant.py +1 -0
- mlflow/llama_index/model.py +577 -0
- mlflow/llama_index/pyfunc_wrapper.py +332 -0
- mlflow/llama_index/serialize_objects.py +188 -0
- mlflow/llama_index/tracer.py +561 -0
- mlflow/metrics/__init__.py +479 -0
- mlflow/metrics/base.py +39 -0
- mlflow/metrics/genai/__init__.py +25 -0
- mlflow/metrics/genai/base.py +101 -0
- mlflow/metrics/genai/genai_metric.py +771 -0
- mlflow/metrics/genai/metric_definitions.py +450 -0
- mlflow/metrics/genai/model_utils.py +371 -0
- mlflow/metrics/genai/prompt_template.py +68 -0
- mlflow/metrics/genai/prompts/__init__.py +0 -0
- mlflow/metrics/genai/prompts/v1.py +422 -0
- mlflow/metrics/genai/utils.py +6 -0
- mlflow/metrics/metric_definitions.py +619 -0
- mlflow/mismatch.py +34 -0
- mlflow/mistral/__init__.py +34 -0
- mlflow/mistral/autolog.py +71 -0
- mlflow/mistral/chat.py +135 -0
- mlflow/ml_package_versions.py +452 -0
- mlflow/models/__init__.py +97 -0
- mlflow/models/auth_policy.py +83 -0
- mlflow/models/cli.py +354 -0
- mlflow/models/container/__init__.py +294 -0
- mlflow/models/container/scoring_server/__init__.py +0 -0
- mlflow/models/container/scoring_server/nginx.conf +39 -0
- mlflow/models/dependencies_schemas.py +287 -0
- mlflow/models/display_utils.py +158 -0
- mlflow/models/docker_utils.py +211 -0
- mlflow/models/evaluation/__init__.py +23 -0
- mlflow/models/evaluation/_shap_patch.py +64 -0
- mlflow/models/evaluation/artifacts.py +194 -0
- mlflow/models/evaluation/base.py +1811 -0
- mlflow/models/evaluation/calibration_curve.py +109 -0
- mlflow/models/evaluation/default_evaluator.py +996 -0
- mlflow/models/evaluation/deprecated.py +23 -0
- mlflow/models/evaluation/evaluator_registry.py +80 -0
- mlflow/models/evaluation/evaluators/classifier.py +704 -0
- mlflow/models/evaluation/evaluators/default.py +233 -0
- mlflow/models/evaluation/evaluators/regressor.py +96 -0
- mlflow/models/evaluation/evaluators/shap.py +296 -0
- mlflow/models/evaluation/lift_curve.py +178 -0
- mlflow/models/evaluation/utils/metric.py +123 -0
- mlflow/models/evaluation/utils/trace.py +179 -0
- mlflow/models/evaluation/validation.py +434 -0
- mlflow/models/flavor_backend.py +93 -0
- mlflow/models/flavor_backend_registry.py +53 -0
- mlflow/models/model.py +1639 -0
- mlflow/models/model_config.py +150 -0
- mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
- mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
- mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
- mlflow/models/python_api.py +369 -0
- mlflow/models/rag_signatures.py +128 -0
- mlflow/models/resources.py +321 -0
- mlflow/models/signature.py +662 -0
- mlflow/models/utils.py +2054 -0
- mlflow/models/wheeled_model.py +280 -0
- mlflow/openai/__init__.py +57 -0
- mlflow/openai/_agent_tracer.py +364 -0
- mlflow/openai/api_request_parallel_processor.py +131 -0
- mlflow/openai/autolog.py +509 -0
- mlflow/openai/constant.py +1 -0
- mlflow/openai/model.py +824 -0
- mlflow/openai/utils/chat_schema.py +367 -0
- mlflow/optuna/__init__.py +3 -0
- mlflow/optuna/storage.py +646 -0
- mlflow/plugins/__init__.py +72 -0
- mlflow/plugins/base.py +358 -0
- mlflow/plugins/builtin/__init__.py +24 -0
- mlflow/plugins/builtin/pytorch_plugin.py +150 -0
- mlflow/plugins/builtin/sklearn_plugin.py +158 -0
- mlflow/plugins/builtin/transformers_plugin.py +187 -0
- mlflow/plugins/cli.py +321 -0
- mlflow/plugins/discovery.py +340 -0
- mlflow/plugins/manager.py +465 -0
- mlflow/plugins/registry.py +316 -0
- mlflow/plugins/templates/framework_plugin_template.py +329 -0
- mlflow/prompt/constants.py +20 -0
- mlflow/prompt/promptlab_model.py +197 -0
- mlflow/prompt/registry_utils.py +248 -0
- mlflow/promptflow/__init__.py +495 -0
- mlflow/protos/__init__.py +0 -0
- mlflow/protos/assessments_pb2.py +174 -0
- mlflow/protos/databricks_artifacts_pb2.py +489 -0
- mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
- mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
- mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
- mlflow/protos/databricks_pb2.py +267 -0
- mlflow/protos/databricks_trace_server_pb2.py +374 -0
- mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
- mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
- mlflow/protos/facet_feature_statistics_pb2.py +296 -0
- mlflow/protos/internal_pb2.py +77 -0
- mlflow/protos/mlflow_artifacts_pb2.py +336 -0
- mlflow/protos/model_registry_pb2.py +1073 -0
- mlflow/protos/scalapb/__init__.py +0 -0
- mlflow/protos/scalapb/scalapb_pb2.py +104 -0
- mlflow/protos/service_pb2.py +2600 -0
- mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
- mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
- mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
- mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
- mlflow/py.typed +0 -0
- mlflow/pydantic_ai/__init__.py +57 -0
- mlflow/pydantic_ai/autolog.py +173 -0
- mlflow/pyfunc/__init__.py +3844 -0
- mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
- mlflow/pyfunc/backend.py +523 -0
- mlflow/pyfunc/context.py +78 -0
- mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
- mlflow/pyfunc/loaders/__init__.py +7 -0
- mlflow/pyfunc/loaders/chat_agent.py +117 -0
- mlflow/pyfunc/loaders/chat_model.py +125 -0
- mlflow/pyfunc/loaders/code_model.py +31 -0
- mlflow/pyfunc/loaders/responses_agent.py +112 -0
- mlflow/pyfunc/mlserver.py +46 -0
- mlflow/pyfunc/model.py +1473 -0
- mlflow/pyfunc/scoring_server/__init__.py +604 -0
- mlflow/pyfunc/scoring_server/app.py +7 -0
- mlflow/pyfunc/scoring_server/client.py +146 -0
- mlflow/pyfunc/spark_model_cache.py +48 -0
- mlflow/pyfunc/stdin_server.py +44 -0
- mlflow/pyfunc/utils/__init__.py +3 -0
- mlflow/pyfunc/utils/data_validation.py +224 -0
- mlflow/pyfunc/utils/environment.py +22 -0
- mlflow/pyfunc/utils/input_converter.py +47 -0
- mlflow/pyfunc/utils/serving_data_parser.py +11 -0
- mlflow/pytorch/__init__.py +1171 -0
- mlflow/pytorch/_lightning_autolog.py +580 -0
- mlflow/pytorch/_pytorch_autolog.py +50 -0
- mlflow/pytorch/pickle_module.py +35 -0
- mlflow/rfunc/__init__.py +42 -0
- mlflow/rfunc/backend.py +134 -0
- mlflow/runs.py +89 -0
- mlflow/server/__init__.py +302 -0
- mlflow/server/auth/__init__.py +1224 -0
- mlflow/server/auth/__main__.py +4 -0
- mlflow/server/auth/basic_auth.ini +6 -0
- mlflow/server/auth/cli.py +11 -0
- mlflow/server/auth/client.py +537 -0
- mlflow/server/auth/config.py +34 -0
- mlflow/server/auth/db/__init__.py +0 -0
- mlflow/server/auth/db/cli.py +18 -0
- mlflow/server/auth/db/migrations/__init__.py +0 -0
- mlflow/server/auth/db/migrations/alembic.ini +110 -0
- mlflow/server/auth/db/migrations/env.py +76 -0
- mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
- mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
- mlflow/server/auth/db/models.py +67 -0
- mlflow/server/auth/db/utils.py +37 -0
- mlflow/server/auth/entities.py +165 -0
- mlflow/server/auth/logo.py +14 -0
- mlflow/server/auth/permissions.py +65 -0
- mlflow/server/auth/routes.py +18 -0
- mlflow/server/auth/sqlalchemy_store.py +263 -0
- mlflow/server/graphql/__init__.py +0 -0
- mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
- mlflow/server/graphql/graphql_custom_scalars.py +24 -0
- mlflow/server/graphql/graphql_errors.py +15 -0
- mlflow/server/graphql/graphql_no_batching.py +89 -0
- mlflow/server/graphql/graphql_schema_extensions.py +74 -0
- mlflow/server/handlers.py +3217 -0
- mlflow/server/prometheus_exporter.py +17 -0
- mlflow/server/validation.py +30 -0
- mlflow/shap/__init__.py +691 -0
- mlflow/sklearn/__init__.py +1994 -0
- mlflow/sklearn/utils.py +1041 -0
- mlflow/smolagents/__init__.py +66 -0
- mlflow/smolagents/autolog.py +139 -0
- mlflow/smolagents/chat.py +29 -0
- mlflow/store/__init__.py +10 -0
- mlflow/store/_unity_catalog/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/constants.py +2 -0
- mlflow/store/_unity_catalog/registry/__init__.py +6 -0
- mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
- mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
- mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
- mlflow/store/_unity_catalog/registry/utils.py +121 -0
- mlflow/store/artifact/__init__.py +0 -0
- mlflow/store/artifact/artifact_repo.py +472 -0
- mlflow/store/artifact/artifact_repository_registry.py +154 -0
- mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
- mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
- mlflow/store/artifact/cli.py +141 -0
- mlflow/store/artifact/cloud_artifact_repo.py +332 -0
- mlflow/store/artifact/databricks_artifact_repo.py +729 -0
- mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
- mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
- mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
- mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
- mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
- mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
- mlflow/store/artifact/ftp_artifact_repo.py +132 -0
- mlflow/store/artifact/gcs_artifact_repo.py +296 -0
- mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
- mlflow/store/artifact/http_artifact_repo.py +218 -0
- mlflow/store/artifact/local_artifact_repo.py +142 -0
- mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
- mlflow/store/artifact/models_artifact_repo.py +259 -0
- mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
- mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
- mlflow/store/artifact/r2_artifact_repo.py +70 -0
- mlflow/store/artifact/runs_artifact_repo.py +265 -0
- mlflow/store/artifact/s3_artifact_repo.py +330 -0
- mlflow/store/artifact/sftp_artifact_repo.py +141 -0
- mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
- mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
- mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
- mlflow/store/artifact/utils/__init__.py +0 -0
- mlflow/store/artifact/utils/models.py +148 -0
- mlflow/store/db/__init__.py +0 -0
- mlflow/store/db/base_sql_model.py +3 -0
- mlflow/store/db/db_types.py +10 -0
- mlflow/store/db/utils.py +314 -0
- mlflow/store/db_migrations/__init__.py +0 -0
- mlflow/store/db_migrations/alembic.ini +74 -0
- mlflow/store/db_migrations/env.py +84 -0
- mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
- mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
- mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
- mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
- mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
- mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
- mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
- mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
- mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
- mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
- mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
- mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
- mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
- mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
- mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
- mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
- mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
- mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
- mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
- mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
- mlflow/store/db_migrations/versions/__init__.py +0 -0
- mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
- mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
- mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
- mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
- mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
- mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
- mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
- mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
- mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
- mlflow/store/entities/__init__.py +3 -0
- mlflow/store/entities/paged_list.py +18 -0
- mlflow/store/model_registry/__init__.py +10 -0
- mlflow/store/model_registry/abstract_store.py +1081 -0
- mlflow/store/model_registry/base_rest_store.py +44 -0
- mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
- mlflow/store/model_registry/dbmodels/__init__.py +0 -0
- mlflow/store/model_registry/dbmodels/models.py +206 -0
- mlflow/store/model_registry/file_store.py +1091 -0
- mlflow/store/model_registry/rest_store.py +481 -0
- mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
- mlflow/store/tracking/__init__.py +23 -0
- mlflow/store/tracking/abstract_store.py +816 -0
- mlflow/store/tracking/dbmodels/__init__.py +0 -0
- mlflow/store/tracking/dbmodels/initial_models.py +243 -0
- mlflow/store/tracking/dbmodels/models.py +1073 -0
- mlflow/store/tracking/file_store.py +2438 -0
- mlflow/store/tracking/postgres_managed_identity.py +146 -0
- mlflow/store/tracking/rest_store.py +1131 -0
- mlflow/store/tracking/sqlalchemy_store.py +2785 -0
- mlflow/system_metrics/__init__.py +61 -0
- mlflow/system_metrics/metrics/__init__.py +0 -0
- mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
- mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
- mlflow/system_metrics/metrics/disk_monitor.py +21 -0
- mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
- mlflow/system_metrics/metrics/network_monitor.py +34 -0
- mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
- mlflow/system_metrics/system_metrics_monitor.py +198 -0
- mlflow/tracing/__init__.py +16 -0
- mlflow/tracing/assessment.py +356 -0
- mlflow/tracing/client.py +531 -0
- mlflow/tracing/config.py +125 -0
- mlflow/tracing/constant.py +105 -0
- mlflow/tracing/destination.py +81 -0
- mlflow/tracing/display/__init__.py +40 -0
- mlflow/tracing/display/display_handler.py +196 -0
- mlflow/tracing/export/async_export_queue.py +186 -0
- mlflow/tracing/export/inference_table.py +138 -0
- mlflow/tracing/export/mlflow_v3.py +137 -0
- mlflow/tracing/export/utils.py +70 -0
- mlflow/tracing/fluent.py +1417 -0
- mlflow/tracing/processor/base_mlflow.py +199 -0
- mlflow/tracing/processor/inference_table.py +175 -0
- mlflow/tracing/processor/mlflow_v3.py +47 -0
- mlflow/tracing/processor/otel.py +73 -0
- mlflow/tracing/provider.py +487 -0
- mlflow/tracing/trace_manager.py +200 -0
- mlflow/tracing/utils/__init__.py +616 -0
- mlflow/tracing/utils/artifact_utils.py +28 -0
- mlflow/tracing/utils/copy.py +55 -0
- mlflow/tracing/utils/environment.py +55 -0
- mlflow/tracing/utils/exception.py +21 -0
- mlflow/tracing/utils/once.py +35 -0
- mlflow/tracing/utils/otlp.py +63 -0
- mlflow/tracing/utils/processor.py +54 -0
- mlflow/tracing/utils/search.py +292 -0
- mlflow/tracing/utils/timeout.py +250 -0
- mlflow/tracing/utils/token.py +19 -0
- mlflow/tracing/utils/truncation.py +124 -0
- mlflow/tracing/utils/warning.py +76 -0
- mlflow/tracking/__init__.py +39 -0
- mlflow/tracking/_model_registry/__init__.py +1 -0
- mlflow/tracking/_model_registry/client.py +764 -0
- mlflow/tracking/_model_registry/fluent.py +853 -0
- mlflow/tracking/_model_registry/registry.py +67 -0
- mlflow/tracking/_model_registry/utils.py +251 -0
- mlflow/tracking/_tracking_service/__init__.py +0 -0
- mlflow/tracking/_tracking_service/client.py +883 -0
- mlflow/tracking/_tracking_service/registry.py +56 -0
- mlflow/tracking/_tracking_service/utils.py +275 -0
- mlflow/tracking/artifact_utils.py +179 -0
- mlflow/tracking/client.py +5900 -0
- mlflow/tracking/context/__init__.py +0 -0
- mlflow/tracking/context/abstract_context.py +35 -0
- mlflow/tracking/context/databricks_cluster_context.py +15 -0
- mlflow/tracking/context/databricks_command_context.py +15 -0
- mlflow/tracking/context/databricks_job_context.py +49 -0
- mlflow/tracking/context/databricks_notebook_context.py +41 -0
- mlflow/tracking/context/databricks_repo_context.py +43 -0
- mlflow/tracking/context/default_context.py +51 -0
- mlflow/tracking/context/git_context.py +32 -0
- mlflow/tracking/context/registry.py +98 -0
- mlflow/tracking/context/system_environment_context.py +15 -0
- mlflow/tracking/default_experiment/__init__.py +1 -0
- mlflow/tracking/default_experiment/abstract_context.py +43 -0
- mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
- mlflow/tracking/default_experiment/registry.py +75 -0
- mlflow/tracking/fluent.py +3595 -0
- mlflow/tracking/metric_value_conversion_utils.py +93 -0
- mlflow/tracking/multimedia.py +206 -0
- mlflow/tracking/registry.py +86 -0
- mlflow/tracking/request_auth/__init__.py +0 -0
- mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
- mlflow/tracking/request_auth/registry.py +60 -0
- mlflow/tracking/request_header/__init__.py +0 -0
- mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
- mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
- mlflow/tracking/request_header/default_request_header_provider.py +17 -0
- mlflow/tracking/request_header/registry.py +79 -0
- mlflow/transformers/__init__.py +2982 -0
- mlflow/transformers/flavor_config.py +258 -0
- mlflow/transformers/hub_utils.py +83 -0
- mlflow/transformers/llm_inference_utils.py +468 -0
- mlflow/transformers/model_io.py +301 -0
- mlflow/transformers/peft.py +51 -0
- mlflow/transformers/signature.py +183 -0
- mlflow/transformers/torch_utils.py +55 -0
- mlflow/types/__init__.py +21 -0
- mlflow/types/agent.py +270 -0
- mlflow/types/chat.py +240 -0
- mlflow/types/llm.py +935 -0
- mlflow/types/responses.py +139 -0
- mlflow/types/responses_helpers.py +416 -0
- mlflow/types/schema.py +1505 -0
- mlflow/types/type_hints.py +647 -0
- mlflow/types/utils.py +753 -0
- mlflow/utils/__init__.py +283 -0
- mlflow/utils/_capture_modules.py +256 -0
- mlflow/utils/_capture_transformers_modules.py +75 -0
- mlflow/utils/_spark_utils.py +201 -0
- mlflow/utils/_unity_catalog_oss_utils.py +97 -0
- mlflow/utils/_unity_catalog_utils.py +479 -0
- mlflow/utils/annotations.py +218 -0
- mlflow/utils/arguments_utils.py +16 -0
- mlflow/utils/async_logging/__init__.py +1 -0
- mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
- mlflow/utils/async_logging/async_logging_queue.py +366 -0
- mlflow/utils/async_logging/run_artifact.py +38 -0
- mlflow/utils/async_logging/run_batch.py +58 -0
- mlflow/utils/async_logging/run_operations.py +49 -0
- mlflow/utils/autologging_utils/__init__.py +737 -0
- mlflow/utils/autologging_utils/client.py +432 -0
- mlflow/utils/autologging_utils/config.py +33 -0
- mlflow/utils/autologging_utils/events.py +294 -0
- mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
- mlflow/utils/autologging_utils/metrics_queue.py +71 -0
- mlflow/utils/autologging_utils/safety.py +1104 -0
- mlflow/utils/autologging_utils/versioning.py +95 -0
- mlflow/utils/checkpoint_utils.py +206 -0
- mlflow/utils/class_utils.py +6 -0
- mlflow/utils/cli_args.py +257 -0
- mlflow/utils/conda.py +354 -0
- mlflow/utils/credentials.py +231 -0
- mlflow/utils/data_utils.py +17 -0
- mlflow/utils/databricks_utils.py +1436 -0
- mlflow/utils/docstring_utils.py +477 -0
- mlflow/utils/doctor.py +133 -0
- mlflow/utils/download_cloud_file_chunk.py +43 -0
- mlflow/utils/env_manager.py +16 -0
- mlflow/utils/env_pack.py +131 -0
- mlflow/utils/environment.py +1009 -0
- mlflow/utils/exception_utils.py +14 -0
- mlflow/utils/file_utils.py +978 -0
- mlflow/utils/git_utils.py +77 -0
- mlflow/utils/gorilla.py +797 -0
- mlflow/utils/import_hooks/__init__.py +363 -0
- mlflow/utils/lazy_load.py +51 -0
- mlflow/utils/logging_utils.py +168 -0
- mlflow/utils/mime_type_utils.py +58 -0
- mlflow/utils/mlflow_tags.py +103 -0
- mlflow/utils/model_utils.py +486 -0
- mlflow/utils/name_utils.py +346 -0
- mlflow/utils/nfs_on_spark.py +62 -0
- mlflow/utils/openai_utils.py +164 -0
- mlflow/utils/os.py +12 -0
- mlflow/utils/oss_registry_utils.py +29 -0
- mlflow/utils/plugins.py +17 -0
- mlflow/utils/process.py +182 -0
- mlflow/utils/promptlab_utils.py +146 -0
- mlflow/utils/proto_json_utils.py +743 -0
- mlflow/utils/pydantic_utils.py +54 -0
- mlflow/utils/request_utils.py +279 -0
- mlflow/utils/requirements_utils.py +704 -0
- mlflow/utils/rest_utils.py +673 -0
- mlflow/utils/search_logged_model_utils.py +127 -0
- mlflow/utils/search_utils.py +2111 -0
- mlflow/utils/secure_loading.py +221 -0
- mlflow/utils/security_validation.py +384 -0
- mlflow/utils/server_cli_utils.py +61 -0
- mlflow/utils/spark_utils.py +15 -0
- mlflow/utils/string_utils.py +138 -0
- mlflow/utils/thread_utils.py +63 -0
- mlflow/utils/time.py +54 -0
- mlflow/utils/timeout.py +42 -0
- mlflow/utils/uri.py +572 -0
- mlflow/utils/validation.py +662 -0
- mlflow/utils/virtualenv.py +458 -0
- mlflow/utils/warnings_utils.py +25 -0
- mlflow/utils/yaml_utils.py +179 -0
- mlflow/version.py +24 -0
mlflow/types/utils.py
ADDED
@@ -0,0 +1,753 @@
|
|
1
|
+
import logging
|
2
|
+
import warnings
|
3
|
+
from collections import defaultdict
|
4
|
+
from copy import deepcopy
|
5
|
+
from typing import Any, Dict, List, Optional, Union
|
6
|
+
|
7
|
+
import numpy as np
|
8
|
+
import pandas as pd
|
9
|
+
import pydantic
|
10
|
+
|
11
|
+
from mlflow.exceptions import MlflowException
|
12
|
+
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
|
13
|
+
from mlflow.types import DataType
|
14
|
+
from mlflow.types.schema import (
|
15
|
+
HAS_PYSPARK,
|
16
|
+
AnyType,
|
17
|
+
Array,
|
18
|
+
ColSpec,
|
19
|
+
Map,
|
20
|
+
Object,
|
21
|
+
ParamSchema,
|
22
|
+
ParamSpec,
|
23
|
+
Property,
|
24
|
+
Schema,
|
25
|
+
SparkMLVector,
|
26
|
+
TensorSpec,
|
27
|
+
)
|
28
|
+
|
29
|
+
# Shared error message raised whenever a list's elements do not share a single type.
MULTIPLE_TYPES_ERROR_MSG = (
    "Expected all values in the list to be of the same type. To specify a model signature "
    "with a list containing elements of multiple types, define the signature manually "
    "using the Array(AnyType()) type from mlflow.models.schema."
)
# Module-level logger (standard `logging.getLogger(__name__)` convention).
_logger = logging.getLogger(__name__)
|
35
|
+
|
36
|
+
|
37
|
+
class TensorsNotSupportedException(MlflowException):
    """Raised when a multidimensional (tensor) input is used where tensors are unsupported."""

    def __init__(self, msg):
        full_message = f"Multidimensional arrays (aka tensors) are not supported. {msg}"
        super().__init__(full_message)
|
40
|
+
|
41
|
+
|
42
|
+
def _get_tensor_shape(data, variable_dimension: Optional[int] = 0) -> tuple[int, ...]:
|
43
|
+
"""Infer the shape of the inputted data.
|
44
|
+
|
45
|
+
This method creates the shape of the tensor to store in the TensorSpec. The variable dimension
|
46
|
+
is assumed to be the first dimension by default. This assumption can be overridden by inputting
|
47
|
+
a different variable dimension or `None` to represent that the input tensor does not contain a
|
48
|
+
variable dimension.
|
49
|
+
|
50
|
+
Args:
|
51
|
+
data: Dataset to infer from.
|
52
|
+
variable_dimension: An optional integer representing a variable dimension.
|
53
|
+
|
54
|
+
Returns:
|
55
|
+
tuple: Shape of the inputted data (including a variable dimension)
|
56
|
+
"""
|
57
|
+
from scipy.sparse import csc_matrix, csr_matrix
|
58
|
+
|
59
|
+
if not isinstance(data, (np.ndarray, csr_matrix, csc_matrix)):
|
60
|
+
raise TypeError(f"Expected numpy.ndarray or csc/csr matrix, got '{type(data)}'.")
|
61
|
+
variable_input_data_shape = data.shape
|
62
|
+
if variable_dimension is not None:
|
63
|
+
try:
|
64
|
+
variable_input_data_shape = list(variable_input_data_shape)
|
65
|
+
variable_input_data_shape[variable_dimension] = -1
|
66
|
+
except IndexError:
|
67
|
+
raise MlflowException(
|
68
|
+
f"The specified variable_dimension {variable_dimension} is out of bounds with "
|
69
|
+
f"respect to the number of dimensions {data.ndim} in the input dataset"
|
70
|
+
)
|
71
|
+
return tuple(variable_input_data_shape)
|
72
|
+
|
73
|
+
|
74
|
+
def clean_tensor_type(dtype: np.dtype):
    """Normalize flexible numpy dtypes by dropping their size information.

    Sized unicode (``<U10``) and bytes (``S5``) dtypes are collapsed to the
    generic ``str`` / ``bytes`` dtypes; every other dtype passes through
    unchanged.

    Args:
        dtype: Numpy dtype of a tensor.

    Returns:
        The cleaned numpy dtype.

    Raises:
        TypeError: if ``dtype`` is not a ``numpy.dtype`` instance.
    """
    if not isinstance(dtype, np.dtype):
        raise TypeError(
            f"Expected `type` to be instance of `{np.dtype}`, received `{dtype.__class__}`"
        )

    # "U" (unicode) and "S" (bytes) dtypes carry a length; map them to the
    # generic equivalents, leaving everything else untouched.
    flexible_dtypes = {"U": np.dtype("str"), "S": np.dtype("bytes")}
    return flexible_dtypes.get(dtype.char, dtype)
|
96
|
+
|
97
|
+
|
98
|
+
def _infer_colspec_type(data: Any) -> Union[DataType, Array, Object, AnyType]:
    """Infer the ColSpec-compatible MLflow type of ``data``.

    Thin wrapper over ``_infer_datatype`` that turns an undetermined result
    (which ``_infer_datatype`` signals with ``None``, e.g. for an empty numpy
    array) into an error.

    Args:
        data: data to infer from.

    Returns:
        The inferred DataType / Array / Object / AnyType.

    Raises:
        MlflowException: when the type cannot be determined.
    """
    inferred = _infer_datatype(data)
    if inferred is None:
        raise MlflowException(
            f"Numpy array must include at least one non-empty item. Invalid input `{data}`."
        )
    return inferred
|
116
|
+
|
117
|
+
|
118
|
+
class InvalidDataForSignatureInferenceError(MlflowException):
    """Raised when signature inference is attempted on unsupported input data."""

    def __init__(self, message):
        # Always reported as an invalid-parameter error to callers.
        super().__init__(message=message, error_code=INVALID_PARAMETER_VALUE)
|
121
|
+
|
122
|
+
|
123
|
+
def _infer_datatype(data: Any) -> Optional[Union[DataType, Array, Object, AnyType]]:
    """
    Infer the datatype of input data.
    Data type and inferred schema type mapping:
        - dict -> Object
        - list -> Array
        - numpy.ndarray -> Array
        - scalar -> DataType
        - None, empty dictionary/list -> AnyType

    .. Note::
        Empty numpy arrays are inferred as None to keep the backward compatibility, as numpy
        arrays are used by some traditional ML flavors.
        e.g. numpy.array([]) -> None, numpy.array([[], []]) -> None
        While empty lists are inferred as AnyType instead of None after the support of AnyType.
        e.g. [] -> AnyType, [[], []] -> Array(Any)
    """
    # Pydantic models are rejected outright; signature inference for them goes
    # through type hints on the PythonModel's `predict` method instead.
    if isinstance(data, pydantic.BaseModel):
        raise InvalidDataForSignatureInferenceError(
            message="MLflow does not support inferring model signature from input example "
            "with Pydantic objects. To use Pydantic objects, define your PythonModel's "
            "`predict` method with a Pydantic type hint, and model signature will be automatically "
            "inferred when logging the model. e.g. "
            "`def predict(self, model_input: list[PydanticType])`. Check "
            "https://mlflow.org/docs/latest/model/python_model.html#type-hint-usage-in-pythonmodel "
            "for more details."
        )

    # None/NaN and empty list/dict carry no type information -> AnyType.
    if _is_none_or_nan(data) or (isinstance(data, (list, dict)) and not data):
        return AnyType()

    if isinstance(data, dict):
        properties = []
        for k, v in data.items():
            dtype = _infer_datatype(v)
            # None here means the value was an empty numpy array (see docstring note).
            if dtype is None:
                raise MlflowException("Dictionary value must not be an empty numpy array.")
            properties.append(
                # AnyType-valued properties are marked optional since their type is unknown.
                Property(name=k, dtype=dtype, required=not isinstance(dtype, AnyType))
            )
        return Object(properties=properties)

    if isinstance(data, (list, np.ndarray)):
        return _infer_array_datatype(data)

    return _infer_scalar_datatype(data)
|
169
|
+
|
170
|
+
|
171
|
+
def _infer_array_datatype(data: Union[list[Any], np.ndarray]) -> Optional[Array]:
    """Infer schema from an array. This tries to infer type if there is at least one
    non-null item in the list, assuming the list has a homogeneous type. However,
    if the list is empty or all items are null, returns None as a sign of undetermined.

    E.g.
        ["a", "b"] => Array(string)
        ["a", None] => Array(string)
        [["a", "b"], []] => Array(Array(string))
        [["a", "b"], None] => Array(Array(string))
        [] => None
        [None] => Array(Any)

    Args:
        data: data to infer from.

    Returns:
        Array(dtype) or None if undetermined
    """
    result = None
    for item in data:
        dtype = _infer_datatype(item)

        # Skip item with undetermined type
        if dtype is None:
            continue

        if result is None:
            # First determined item fixes the provisional element type.
            result = Array(dtype)
        elif isinstance(result.dtype, (Array, Object, Map, AnyType)):
            # Composite element types are merged item by item; _merge raises when
            # two items' types are incompatible.
            try:
                result = Array(result.dtype._merge(dtype))
            except MlflowException as e:
                raise MlflowException.invalid_parameter_value(MULTIPLE_TYPES_ERROR_MSG) from e
        elif isinstance(result.dtype, DataType):
            # Scalar element types must match exactly (AnyType items are tolerated).
            if not isinstance(dtype, AnyType) and dtype != result.dtype:
                raise MlflowException.invalid_parameter_value(MULTIPLE_TYPES_ERROR_MSG)
        else:
            raise MlflowException.invalid_parameter_value(
                f"{dtype} is not a valid type for an item of a list or numpy array."
            )
    return result
|
213
|
+
|
214
|
+
|
215
|
+
# datetime is not included here
# Lookup from Python/numpy scalar classes to MLflow DataType. Keyed by exact
# class — _infer_scalar_datatype looks up `type(value)`, so subclasses do not
# match through this table.
SCALAR_TO_DATATYPE_MAPPING = {
    bool: DataType.boolean,
    np.bool_: DataType.boolean,
    int: DataType.long,
    np.int64: DataType.long,
    np.int32: DataType.integer,
    float: DataType.double,
    np.float64: DataType.double,
    np.float32: DataType.float,
    str: DataType.string,
    np.str_: DataType.string,
    object: DataType.string,
    bytes: DataType.binary,
    np.bytes_: DataType.binary,
    bytearray: DataType.binary,
}
|
232
|
+
|
233
|
+
|
234
|
+
def _infer_scalar_datatype(data) -> DataType:
    """Infer the MLflow DataType of a scalar value.

    Resolution order: exact-type lookup in ``SCALAR_TO_DATATYPE_MAPPING``, then
    a datetime check, then (only when pyspark is installed) a scan over the
    spark-native equivalents of each DataType.

    Raises:
        MlflowException: when the value matches no supported DataType.
    """
    if mapped := SCALAR_TO_DATATYPE_MAPPING.get(type(data)):
        return mapped
    if DataType.check_type(DataType.datetime, data):
        return DataType.datetime
    if HAS_PYSPARK:
        for candidate in DataType.all_types():
            if isinstance(data, type(candidate.to_spark())):
                return candidate
    raise MlflowException.invalid_parameter_value(
        f"Data {data} is not one of the supported DataType"
    )
|
246
|
+
|
247
|
+
|
248
|
+
def _infer_schema(data: Any) -> Schema:
    """
    Infer an MLflow schema from a dataset.

    Data inputted as a numpy array or a dictionary is represented by :py:class:`TensorSpec`.
    All other inputted data types are specified by :py:class:`ColSpec`.

    A `TensorSpec` captures the data shape (default variable axis is 0), the data type (numpy.dtype)
    and an optional name for each individual tensor of the dataset.
    A `ColSpec` captures the data type (defined in :py:class:`DataType`) and an optional name for
    each individual column of the dataset.

    This method will raise an exception if the user data contains incompatible types or is not
    passed in one of the supported formats (containers).

    The input should be one of these:
      - pandas.DataFrame
      - pandas.Series
      - numpy.ndarray
      - dictionary of (name -> numpy.ndarray)
      - pyspark.sql.DataFrame
      - scipy.sparse.csr_matrix/csc_matrix
      - DataType
      - List[DataType]
      - Dict[str, Union[DataType, List, Dict]]
      - List[Dict[str, Union[DataType, List, Dict]]]

    The last two formats are used to represent complex data structures. For example,

    Input Data:
        [
            {
                'text': 'some sentence',
                'ids': ['id1'],
                'dict': {'key': 'value'}
            },
            {
                'text': 'some sentence',
                'ids': ['id1', 'id2'],
                'dict': {'key': 'value', 'key2': 'value2'}
            },
        ]

    The corresponding pandas DataFrame representation should look like this:

        output ids dict
        0 some sentence [id1, id2] {'key': 'value'}
        1 some sentence [id1, id2] {'key': 'value', 'key2': 'value2'}

    The inferred schema should look like this:

        Schema([
            ColSpec(type=DataType.string, name='output'),
            ColSpec(type=Array(dtype=DataType.string), name='ids'),
            ColSpec(
                type=Object([
                    Property(name='key', dtype=DataType.string),
                    Property(name='key2', dtype=DataType.string, required=False)
                ]),
                name='dict')]
            ),
        ])

    The element types should be mappable to one of :py:class:`mlflow.models.signature.DataType` for
    dataframes and to one of numpy types for tensors.

    Args:
        data: Dataset to infer from.

    Returns:
        Schema
    """
    from scipy.sparse import csc_matrix, csr_matrix

    # To keep backward compatibility with < 2.9.0, an empty list is inferred as string.
    # ref: https://github.com/mlflow/mlflow/pull/10125#discussion_r1372751487
    if isinstance(data, list) and data == []:
        return Schema([ColSpec(DataType.string)])

    # List[Dict[...]]: each dict is a row; values are grouped per column name and
    # a single merged type is inferred per column.
    if isinstance(data, list) and all(isinstance(value, dict) for value in data):
        col_data_mapping = defaultdict(list)
        for item in data:
            for k, v in item.items():
                col_data_mapping[k].append(v)
        requiredness = {}
        for col in col_data_mapping:
            # if col exists in item but its value is None, then it is not required
            requiredness[col] = all(item.get(col) is not None for item in data)

        schema = Schema(
            [
                ColSpec(_infer_colspec_type(values).dtype, name=name, required=requiredness[name])
                for name, values in col_data_mapping.items()
            ]
        )

    elif isinstance(data, dict):
        # dictionary of (name -> numpy.ndarray)
        if all(isinstance(values, np.ndarray) for values in data.values()):
            schema = Schema(
                [
                    TensorSpec(
                        type=clean_tensor_type(ndarray.dtype),
                        shape=_get_tensor_shape(ndarray),
                        name=name,
                    )
                    for name, ndarray in data.items()
                ]
            )
        # Dict[str, Union[DataType, List, Dict]]
        else:
            if any(not isinstance(key, str) for key in data):
                raise MlflowException("The dictionary keys are not all strings.")
            schema = Schema(
                [
                    ColSpec(
                        _infer_colspec_type(value),
                        name=name,
                        required=_infer_required(value),
                    )
                    for name, value in data.items()
                ]
            )
    # pandas.Series
    elif isinstance(data, pd.Series):
        name = getattr(data, "name", None)
        schema = Schema(
            [
                ColSpec(
                    type=_infer_pandas_column(data),
                    name=name,
                    required=_infer_required(data),
                )
            ]
        )
    # pandas.DataFrame
    elif isinstance(data, pd.DataFrame):
        schema = Schema(
            [
                ColSpec(
                    type=_infer_pandas_column(data[col]),
                    name=col,
                    required=_infer_required(data[col]),
                )
                for col in data.columns
            ]
        )
    # numpy.ndarray
    elif isinstance(data, np.ndarray):
        schema = Schema(
            [TensorSpec(type=clean_tensor_type(data.dtype), shape=_get_tensor_shape(data))]
        )
    # scipy.sparse.csr_matrix/csc_matrix
    elif isinstance(data, (csc_matrix, csr_matrix)):
        schema = Schema(
            [TensorSpec(type=clean_tensor_type(data.data.dtype), shape=_get_tensor_shape(data))]
        )
    # pyspark.sql.DataFrame
    elif _is_spark_df(data):
        schema = Schema(
            [
                ColSpec(
                    type=_infer_spark_type(field.dataType, data, field.name),
                    name=field.name,
                    # Avoid setting required field for spark dataframe
                    # as the default value for spark df nullable is True
                    # which counterparts to default required=True in ColSpec
                )
                for field in data.schema.fields
            ]
        )
    elif isinstance(data, list):
        # Assume list as a single column
        # List[DataType]
        # e.g. ['some sentence', 'some sentence'] -> Schema([ColSpec(type=DataType.string)])
        # The corresponding pandas DataFrame representation should be pd.DataFrame(data)
        # We set required=True as unnamed optional inputs is not allowed
        schema = Schema([ColSpec(_infer_colspec_type(data).dtype)])
    else:
        # DataType
        # e.g. "some sentence" -> Schema([ColSpec(type=DataType.string)])
        try:
            # We set required=True as unnamed optional inputs is not allowed
            schema = Schema([ColSpec(_infer_colspec_type(data))])
        except MlflowException as e:
            raise MlflowException.invalid_parameter_value(
                "Failed to infer schema. Expected one of the following types:\n"
                "- pandas.DataFrame\n"
                "- pandas.Series\n"
                "- numpy.ndarray\n"
                "- dictionary of (name -> numpy.ndarray)\n"
                "- pyspark.sql.DataFrame\n"
                "- scipy.sparse.csr_matrix\n"
                "- scipy.sparse.csc_matrix\n"
                "- DataType\n"
                "- List[DataType]\n"
                "- Dict[str, Union[DataType, List, Dict]]\n"
                "- List[Dict[str, Union[DataType, List, Dict]]]\n"
                f"but got '{data}'.\n"
                f"Error: {e}",
            )
    # Warn on integer columns: they cannot represent missing values at inference time.
    if not schema.is_tensor_spec() and any(
        t in (DataType.integer, DataType.long) for t in schema.input_types()
    ):
        warnings.warn(
            "Hint: Inferred schema contains integer column(s). Integer columns in "
            "Python cannot represent missing values. If your input data contains "
            "missing values at inference time, it will be encoded as floats and will "
            "cause a schema enforcement error. The best way to avoid this problem is "
            "to infer the model schema based on a realistic data sample (training "
            "dataset) that includes missing values. Alternatively, you can declare "
            "integer columns as doubles (float64) whenever these columns may have "
            "missing values. See `Handling Integers With Missing Values "
            "<https://www.mlflow.org/docs/latest/models.html#"
            "handling-integers-with-missing-values>`_ for more details."
        )
    return schema
|
465
|
+
|
466
|
+
|
467
|
+
def _infer_numpy_dtype(dtype) -> DataType:
    """Map a numpy (or pandas extension) dtype to the corresponding MLflow DataType.

    Raises:
        TypeError: if ``dtype`` is not a numpy dtype / pandas ExtensionDtype instance.
        Exception: for object ("O") dtypes, which cannot be inferred without values.
        MlflowException: for dtypes with no MLflow equivalent (e.g. uint64 falls
            through the integer branch below and reaches the final raise).
    """
    supported_types = np.dtype

    # noinspection PyBroadException
    try:
        from pandas.core.dtypes.base import ExtensionDtype

        supported_types = (np.dtype, ExtensionDtype)
    except ImportError:
        # This version of pandas does not support extension types
        pass
    if not isinstance(dtype, supported_types):
        raise TypeError(f"Expected numpy.dtype or pandas.ExtensionDtype, got '{type(dtype)}'.")

    if dtype.kind == "b":
        return DataType.boolean
    elif dtype.kind == "i" or dtype.kind == "u":
        # Unsigned ints are promoted to the next wider signed type; an 8-byte
        # unsigned int matches neither branch and hits the final raise.
        if dtype.itemsize < 4 or (dtype.kind == "i" and dtype.itemsize == 4):
            return DataType.integer
        elif dtype.itemsize < 8 or (dtype.kind == "i" and dtype.itemsize == 8):
            return DataType.long
    elif dtype.kind == "f":
        if dtype.itemsize <= 4:
            return DataType.float
        elif dtype.itemsize <= 8:
            return DataType.double

    elif dtype.kind == "U":
        return DataType.string
    elif dtype.kind == "S":
        return DataType.binary
    elif dtype.kind == "O":
        raise Exception(
            "Can not infer object without looking at the values, call _map_numpy_array instead."
        )
    elif dtype.kind == "M":
        return DataType.datetime
    raise MlflowException(f"Unsupported numpy data type '{dtype}', kind '{dtype.kind}'")
|
505
|
+
|
506
|
+
|
507
|
+
def _is_none_or_nan(x):
|
508
|
+
if isinstance(x, float):
|
509
|
+
return np.isnan(x)
|
510
|
+
# NB: We can't use pd.isna() because the input can be a series.
|
511
|
+
return x is None or x is pd.NA or x is pd.NaT
|
512
|
+
|
513
|
+
|
514
|
+
def _infer_required(col) -> bool:
    """Return True when the column/value contains no missing entries.

    A list or pandas Series is required only if every element is non-missing;
    a scalar is required if it is itself non-missing.
    """
    values = col if isinstance(col, (list, pd.Series)) else [col]
    return all(not _is_none_or_nan(x) for x in values)
|
518
|
+
|
519
|
+
|
520
|
+
def _infer_pandas_column(col: pd.Series) -> DataType:
    """Infer the MLflow type of a single 1-D pandas Series.

    Object-dtyped columns are inferred from their values (with a string
    fallback for backward compatibility); other dtypes map directly through
    ``_infer_numpy_dtype``.

    Raises:
        TypeError: if ``col`` is not a pandas Series.
        MlflowException: if the series is multi-dimensional or inference fails.
    """
    if not isinstance(col, pd.Series):
        raise TypeError(f"Expected pandas.Series, got '{type(col)}'.")
    if len(col.values.shape) > 1:
        raise MlflowException(f"Expected 1d array, got array with shape {col.shape}")

    if col.dtype.kind == "O":
        # Let pandas narrow the dtype (e.g. object -> int64) before inspecting values.
        col = col.infer_objects()
    if col.dtype.kind == "O":
        try:
            # We convert pandas Series into list and infer the schema.
            # The real schema for internal field should be the Array's dtype
            arr_type = _infer_colspec_type(col.to_list())
            return arr_type.dtype
        except Exception as e:
            # For backwards compatibility, we fall back to string
            # if the provided array is of string type
            # This is for diviner test where df field is ('key2', 'key1', 'key0')
            if pd.api.types.is_string_dtype(col):
                return DataType.string
            raise MlflowException(f"Failed to infer schema for pandas.Series {col}. Error: {e}")
    else:
        # NB: The following works for numpy types as well as pandas extension types.
        return _infer_numpy_dtype(col.dtype)
|
544
|
+
|
545
|
+
|
546
|
+
def _infer_spark_type(x, data=None, col_name=None) -> DataType:
    """Map a pyspark sql type to the corresponding MLflow schema type.

    Args:
        x: A ``pyspark.sql.types`` type instance to convert.
        data: The spark DataFrame; required only when ``x`` is a MapType, so the
            column's keys can be collected.
        col_name: Name of the column being inferred; required only for MapType.

    Raises:
        MlflowException: for unsupported spark types, a MapType without
            ``data``/``col_name``, or a MapType with MapType values.
    """
    import pyspark.sql.types
    from pyspark.ml.linalg import VectorUDT
    from pyspark.sql.functions import col, collect_list

    if isinstance(x, pyspark.sql.types.NumericType):
        if isinstance(x, pyspark.sql.types.IntegralType):
            if isinstance(x, pyspark.sql.types.LongType):
                return DataType.long
            else:
                return DataType.integer
        elif isinstance(x, pyspark.sql.types.FloatType):
            return DataType.float
        elif isinstance(x, pyspark.sql.types.DoubleType):
            return DataType.double
    elif isinstance(x, pyspark.sql.types.BooleanType):
        return DataType.boolean
    elif isinstance(x, pyspark.sql.types.StringType):
        return DataType.string
    elif isinstance(x, pyspark.sql.types.BinaryType):
        return DataType.binary
    # NB: Spark differentiates date and timestamps, so we coerce both to TimestampType.
    elif isinstance(x, (pyspark.sql.types.DateType, pyspark.sql.types.TimestampType)):
        return DataType.datetime
    elif isinstance(x, pyspark.sql.types.ArrayType):
        # Recurse on the element type; arrays map to MLflow Array.
        return Array(_infer_spark_type(x.elementType))
    elif isinstance(x, pyspark.sql.types.StructType):
        return Object(
            properties=[
                Property(
                    name=f.name,
                    dtype=_infer_spark_type(f.dataType),
                    required=not f.nullable,
                )
                for f in x.fields
            ]
        )
    elif isinstance(x, pyspark.sql.types.MapType):
        if data is None or col_name is None:
            raise MlflowException("Cannot infer schema for MapType without data and column name.")
        # Map MapType to StructType
        # Note that MapType assumes all values are of same type,
        # if they're not then spark picks the first item's type
        # and tries to convert rest to that type.
        # e.g.
        # >>> spark.createDataFrame([{"col": {"a": 1, "b": "b"}}]).show()
        # +-------------------+
        # | col|
        # +-------------------+
        # |{a -> 1, b -> null}|
        # +-------------------+
        if isinstance(x.valueType, pyspark.sql.types.MapType):
            raise MlflowException(
                "Please construct spark DataFrame with schema using StructType "
                "for dictionary/map fields, MLflow schema inference only supports "
                "scalar, array and struct types."
            )

        # Collect every key appearing anywhere in the column so the resulting
        # Object covers the union of keys across rows.
        merged_keys = (
            data.selectExpr(f"map_keys({col_name}) as keys")
            .agg(collect_list(col("keys")).alias("merged_keys"))
            .head()
            .merged_keys
        )
        keys = {key for sublist in merged_keys for key in sublist}
        return Object(
            properties=[
                Property(
                    name=k,
                    dtype=_infer_spark_type(x.valueType),
                )
                for k in keys
            ]
        )
    elif isinstance(x, VectorUDT):
        return SparkMLVector()

    else:
        raise MlflowException.invalid_parameter_value(
            f"Unsupported Spark Type '{type(x)}' for MLflow schema."
        )
|
627
|
+
|
628
|
+
|
629
|
+
def _is_spark_df(x) -> bool:
|
630
|
+
try:
|
631
|
+
import pyspark.sql.dataframe
|
632
|
+
|
633
|
+
if isinstance(x, pyspark.sql.dataframe.DataFrame):
|
634
|
+
return True
|
635
|
+
except ImportError:
|
636
|
+
return False
|
637
|
+
# For spark 4.0
|
638
|
+
try:
|
639
|
+
import pyspark.sql.connect.dataframe
|
640
|
+
|
641
|
+
return isinstance(x, pyspark.sql.connect.dataframe.DataFrame)
|
642
|
+
except ImportError:
|
643
|
+
return False
|
644
|
+
|
645
|
+
|
646
|
+
def _validate_input_dictionary_contains_only_strings_and_lists_of_strings(data) -> None:
|
647
|
+
# isinstance(True, int) is True
|
648
|
+
invalid_keys = [
|
649
|
+
key for key in data.keys() if not isinstance(key, (str, int)) or isinstance(key, bool)
|
650
|
+
]
|
651
|
+
if invalid_keys:
|
652
|
+
raise MlflowException(
|
653
|
+
f"The dictionary keys are not all strings or indexes. Invalid keys: {invalid_keys}"
|
654
|
+
)
|
655
|
+
if any(isinstance(value, np.ndarray) for value in data.values()) and not all(
|
656
|
+
isinstance(value, np.ndarray) for value in data.values()
|
657
|
+
):
|
658
|
+
raise MlflowException("The dictionary values are not all numpy.ndarray.")
|
659
|
+
|
660
|
+
invalid_values = [
|
661
|
+
key
|
662
|
+
for key, value in data.items()
|
663
|
+
if (isinstance(value, list) and not all(isinstance(item, (str, bytes)) for item in value))
|
664
|
+
or (not isinstance(value, (np.ndarray, list, str, bytes)))
|
665
|
+
]
|
666
|
+
if invalid_values:
|
667
|
+
raise MlflowException.invalid_parameter_value(
|
668
|
+
"Invalid values in dictionary. If passing a dictionary containing strings, all "
|
669
|
+
"values must be either strings or lists of strings. If passing a dictionary containing "
|
670
|
+
"numeric values, the data must be enclosed in a numpy.ndarray. The following keys "
|
671
|
+
f"in the input dictionary are invalid: {invalid_values}",
|
672
|
+
)
|
673
|
+
|
674
|
+
|
675
|
+
def _is_list_str(type_hint: Any) -> bool:
|
676
|
+
return type_hint in [
|
677
|
+
List[str], # noqa: UP006
|
678
|
+
list[str],
|
679
|
+
]
|
680
|
+
|
681
|
+
|
682
|
+
def _is_list_dict_str(type_hint: Any) -> bool:
|
683
|
+
return type_hint in [
|
684
|
+
List[Dict[str, str]], # noqa: UP006
|
685
|
+
list[Dict[str, str]], # noqa: UP006
|
686
|
+
List[dict[str, str]], # noqa: UP006
|
687
|
+
list[dict[str, str]],
|
688
|
+
]
|
689
|
+
|
690
|
+
|
691
|
+
def _get_array_depth(l: Any) -> int:
|
692
|
+
if isinstance(l, np.ndarray):
|
693
|
+
return l.ndim
|
694
|
+
if isinstance(l, list):
|
695
|
+
return max(_get_array_depth(item) for item in l) + 1 if l else 1
|
696
|
+
return 0
|
697
|
+
|
698
|
+
|
699
|
+
def _infer_type_and_shape(value):
    """Infer the (dtype, shape) pair for a single parameter value.

    Returns:
        ``(DataType, (-1,))`` for 1-D lists/arrays, ``(DataType, None)`` for
        scalars and datetimes, and ``(Object, None)`` for dicts.

    Raises:
        MlflowException: for multi-dimensional arrays or uninferable values.
    """
    if isinstance(value, (list, np.ndarray)):
        ndim = _get_array_depth(value)
        if ndim != 1:
            raise MlflowException.invalid_parameter_value(
                f"Expected parameters to be 1D array or scalar, got {ndim}D array",
            )
        if all(DataType.check_type(DataType.datetime, v) for v in value):
            return DataType.datetime, (-1,)
        value_type = _infer_numpy_dtype(np.array(value).dtype)
        return value_type, (-1,)
    elif DataType.check_type(DataType.datetime, value):
        return DataType.datetime, None
    elif np.isscalar(value):
        try:
            value_type = _infer_numpy_dtype(np.array(value).dtype)
            return value_type, None
        # NOTE(review): MlflowException subclasses Exception, so listing both is
        # redundant — kept as-is to preserve the original code exactly.
        except (Exception, MlflowException) as e:
            raise MlflowException.invalid_parameter_value(
                f"Failed to infer schema for parameter {value}: {e!r}"
            )
    elif isinstance(value, dict):
        # reuse _infer_schema to infer schema for dict, wrapping it in a dictionary is
        # necessary to make sure value is inferred as Object
        schema = _infer_schema({"value": value})
        object_type = schema.inputs[0].type
        return object_type, None
    raise MlflowException.invalid_parameter_value(
        f"Expected parameters to be 1D array or scalar, got {type(value).__name__}",
    )
|
729
|
+
|
730
|
+
|
731
|
+
def _infer_param_schema(parameters: dict[str, Any]):
    """Build a ParamSchema from a mapping of parameter name to default value.

    Each value's type and shape are inferred via ``_infer_type_and_shape``;
    defaults are deep-copied into the resulting ParamSpecs. All inference
    failures are collected and reported together.

    Raises:
        MlflowException: if ``parameters`` is not a dict, or if any value's
            schema cannot be inferred.
    """
    if not isinstance(parameters, dict):
        raise MlflowException.invalid_parameter_value(
            f"Expected parameters to be dict, got {type(parameters).__name__}",
        )

    param_specs = []
    invalid_params = []
    for name, value in parameters.items():
        try:
            dtype, shape = _infer_type_and_shape(value)
            spec = ParamSpec(name=name, dtype=dtype, default=deepcopy(value), shape=shape)
        except Exception as e:
            # Collect the failure; everything is reported in a single error below.
            invalid_params.append((name, value, e))
        else:
            param_specs.append(spec)

    if invalid_params:
        raise MlflowException.invalid_parameter_value(
            f"Failed to infer schema for parameters: {invalid_params}",
        )

    return ParamSchema(param_specs)
|