genesis-flow 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genesis_flow-1.0.0.dist-info/METADATA +822 -0
- genesis_flow-1.0.0.dist-info/RECORD +645 -0
- genesis_flow-1.0.0.dist-info/WHEEL +5 -0
- genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
- genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
- genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
- mlflow/__init__.py +367 -0
- mlflow/__main__.py +3 -0
- mlflow/ag2/__init__.py +56 -0
- mlflow/ag2/ag2_logger.py +294 -0
- mlflow/anthropic/__init__.py +40 -0
- mlflow/anthropic/autolog.py +129 -0
- mlflow/anthropic/chat.py +144 -0
- mlflow/artifacts/__init__.py +268 -0
- mlflow/autogen/__init__.py +144 -0
- mlflow/autogen/chat.py +142 -0
- mlflow/azure/__init__.py +26 -0
- mlflow/azure/auth_handler.py +257 -0
- mlflow/azure/client.py +319 -0
- mlflow/azure/config.py +120 -0
- mlflow/azure/connection_factory.py +340 -0
- mlflow/azure/exceptions.py +27 -0
- mlflow/azure/stores.py +327 -0
- mlflow/azure/utils.py +183 -0
- mlflow/bedrock/__init__.py +45 -0
- mlflow/bedrock/_autolog.py +202 -0
- mlflow/bedrock/chat.py +122 -0
- mlflow/bedrock/stream.py +160 -0
- mlflow/bedrock/utils.py +43 -0
- mlflow/cli.py +707 -0
- mlflow/client.py +12 -0
- mlflow/config/__init__.py +56 -0
- mlflow/crewai/__init__.py +79 -0
- mlflow/crewai/autolog.py +253 -0
- mlflow/crewai/chat.py +29 -0
- mlflow/data/__init__.py +75 -0
- mlflow/data/artifact_dataset_sources.py +170 -0
- mlflow/data/code_dataset_source.py +40 -0
- mlflow/data/dataset.py +123 -0
- mlflow/data/dataset_registry.py +168 -0
- mlflow/data/dataset_source.py +110 -0
- mlflow/data/dataset_source_registry.py +219 -0
- mlflow/data/delta_dataset_source.py +167 -0
- mlflow/data/digest_utils.py +108 -0
- mlflow/data/evaluation_dataset.py +562 -0
- mlflow/data/filesystem_dataset_source.py +81 -0
- mlflow/data/http_dataset_source.py +145 -0
- mlflow/data/huggingface_dataset.py +258 -0
- mlflow/data/huggingface_dataset_source.py +118 -0
- mlflow/data/meta_dataset.py +104 -0
- mlflow/data/numpy_dataset.py +223 -0
- mlflow/data/pandas_dataset.py +231 -0
- mlflow/data/polars_dataset.py +352 -0
- mlflow/data/pyfunc_dataset_mixin.py +31 -0
- mlflow/data/schema.py +76 -0
- mlflow/data/sources.py +1 -0
- mlflow/data/spark_dataset.py +406 -0
- mlflow/data/spark_dataset_source.py +74 -0
- mlflow/data/spark_delta_utils.py +118 -0
- mlflow/data/tensorflow_dataset.py +350 -0
- mlflow/data/uc_volume_dataset_source.py +81 -0
- mlflow/db.py +27 -0
- mlflow/dspy/__init__.py +17 -0
- mlflow/dspy/autolog.py +197 -0
- mlflow/dspy/callback.py +398 -0
- mlflow/dspy/constant.py +1 -0
- mlflow/dspy/load.py +93 -0
- mlflow/dspy/save.py +393 -0
- mlflow/dspy/util.py +109 -0
- mlflow/dspy/wrapper.py +226 -0
- mlflow/entities/__init__.py +104 -0
- mlflow/entities/_mlflow_object.py +52 -0
- mlflow/entities/assessment.py +545 -0
- mlflow/entities/assessment_error.py +80 -0
- mlflow/entities/assessment_source.py +141 -0
- mlflow/entities/dataset.py +92 -0
- mlflow/entities/dataset_input.py +51 -0
- mlflow/entities/dataset_summary.py +62 -0
- mlflow/entities/document.py +48 -0
- mlflow/entities/experiment.py +109 -0
- mlflow/entities/experiment_tag.py +35 -0
- mlflow/entities/file_info.py +45 -0
- mlflow/entities/input_tag.py +35 -0
- mlflow/entities/lifecycle_stage.py +35 -0
- mlflow/entities/logged_model.py +228 -0
- mlflow/entities/logged_model_input.py +26 -0
- mlflow/entities/logged_model_output.py +32 -0
- mlflow/entities/logged_model_parameter.py +46 -0
- mlflow/entities/logged_model_status.py +74 -0
- mlflow/entities/logged_model_tag.py +33 -0
- mlflow/entities/metric.py +200 -0
- mlflow/entities/model_registry/__init__.py +29 -0
- mlflow/entities/model_registry/_model_registry_entity.py +13 -0
- mlflow/entities/model_registry/model_version.py +243 -0
- mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
- mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
- mlflow/entities/model_registry/model_version_search.py +25 -0
- mlflow/entities/model_registry/model_version_stages.py +25 -0
- mlflow/entities/model_registry/model_version_status.py +35 -0
- mlflow/entities/model_registry/model_version_tag.py +35 -0
- mlflow/entities/model_registry/prompt.py +73 -0
- mlflow/entities/model_registry/prompt_version.py +244 -0
- mlflow/entities/model_registry/registered_model.py +175 -0
- mlflow/entities/model_registry/registered_model_alias.py +35 -0
- mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
- mlflow/entities/model_registry/registered_model_search.py +25 -0
- mlflow/entities/model_registry/registered_model_tag.py +35 -0
- mlflow/entities/multipart_upload.py +74 -0
- mlflow/entities/param.py +49 -0
- mlflow/entities/run.py +97 -0
- mlflow/entities/run_data.py +84 -0
- mlflow/entities/run_info.py +188 -0
- mlflow/entities/run_inputs.py +59 -0
- mlflow/entities/run_outputs.py +43 -0
- mlflow/entities/run_status.py +41 -0
- mlflow/entities/run_tag.py +36 -0
- mlflow/entities/source_type.py +31 -0
- mlflow/entities/span.py +774 -0
- mlflow/entities/span_event.py +96 -0
- mlflow/entities/span_status.py +102 -0
- mlflow/entities/trace.py +317 -0
- mlflow/entities/trace_data.py +71 -0
- mlflow/entities/trace_info.py +220 -0
- mlflow/entities/trace_info_v2.py +162 -0
- mlflow/entities/trace_location.py +173 -0
- mlflow/entities/trace_state.py +39 -0
- mlflow/entities/trace_status.py +68 -0
- mlflow/entities/view_type.py +51 -0
- mlflow/environment_variables.py +866 -0
- mlflow/evaluation/__init__.py +16 -0
- mlflow/evaluation/assessment.py +369 -0
- mlflow/evaluation/evaluation.py +411 -0
- mlflow/evaluation/evaluation_tag.py +61 -0
- mlflow/evaluation/fluent.py +48 -0
- mlflow/evaluation/utils.py +201 -0
- mlflow/exceptions.py +213 -0
- mlflow/experiments.py +140 -0
- mlflow/gemini/__init__.py +81 -0
- mlflow/gemini/autolog.py +186 -0
- mlflow/gemini/chat.py +261 -0
- mlflow/genai/__init__.py +71 -0
- mlflow/genai/datasets/__init__.py +67 -0
- mlflow/genai/datasets/evaluation_dataset.py +131 -0
- mlflow/genai/evaluation/__init__.py +3 -0
- mlflow/genai/evaluation/base.py +411 -0
- mlflow/genai/evaluation/constant.py +23 -0
- mlflow/genai/evaluation/utils.py +244 -0
- mlflow/genai/judges/__init__.py +21 -0
- mlflow/genai/judges/databricks.py +404 -0
- mlflow/genai/label_schemas/__init__.py +153 -0
- mlflow/genai/label_schemas/label_schemas.py +209 -0
- mlflow/genai/labeling/__init__.py +159 -0
- mlflow/genai/labeling/labeling.py +250 -0
- mlflow/genai/optimize/__init__.py +13 -0
- mlflow/genai/optimize/base.py +198 -0
- mlflow/genai/optimize/optimizers/__init__.py +4 -0
- mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
- mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
- mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
- mlflow/genai/optimize/types.py +75 -0
- mlflow/genai/optimize/util.py +30 -0
- mlflow/genai/prompts/__init__.py +206 -0
- mlflow/genai/scheduled_scorers.py +431 -0
- mlflow/genai/scorers/__init__.py +26 -0
- mlflow/genai/scorers/base.py +492 -0
- mlflow/genai/scorers/builtin_scorers.py +765 -0
- mlflow/genai/scorers/scorer_utils.py +138 -0
- mlflow/genai/scorers/validation.py +165 -0
- mlflow/genai/utils/data_validation.py +146 -0
- mlflow/genai/utils/enum_utils.py +23 -0
- mlflow/genai/utils/trace_utils.py +211 -0
- mlflow/groq/__init__.py +42 -0
- mlflow/groq/_groq_autolog.py +74 -0
- mlflow/johnsnowlabs/__init__.py +888 -0
- mlflow/langchain/__init__.py +24 -0
- mlflow/langchain/api_request_parallel_processor.py +330 -0
- mlflow/langchain/autolog.py +147 -0
- mlflow/langchain/chat_agent_langgraph.py +340 -0
- mlflow/langchain/constant.py +1 -0
- mlflow/langchain/constants.py +1 -0
- mlflow/langchain/databricks_dependencies.py +444 -0
- mlflow/langchain/langchain_tracer.py +597 -0
- mlflow/langchain/model.py +919 -0
- mlflow/langchain/output_parsers.py +142 -0
- mlflow/langchain/retriever_chain.py +153 -0
- mlflow/langchain/runnables.py +527 -0
- mlflow/langchain/utils/chat.py +402 -0
- mlflow/langchain/utils/logging.py +671 -0
- mlflow/langchain/utils/serialization.py +36 -0
- mlflow/legacy_databricks_cli/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/provider.py +482 -0
- mlflow/litellm/__init__.py +175 -0
- mlflow/llama_index/__init__.py +22 -0
- mlflow/llama_index/autolog.py +55 -0
- mlflow/llama_index/chat.py +43 -0
- mlflow/llama_index/constant.py +1 -0
- mlflow/llama_index/model.py +577 -0
- mlflow/llama_index/pyfunc_wrapper.py +332 -0
- mlflow/llama_index/serialize_objects.py +188 -0
- mlflow/llama_index/tracer.py +561 -0
- mlflow/metrics/__init__.py +479 -0
- mlflow/metrics/base.py +39 -0
- mlflow/metrics/genai/__init__.py +25 -0
- mlflow/metrics/genai/base.py +101 -0
- mlflow/metrics/genai/genai_metric.py +771 -0
- mlflow/metrics/genai/metric_definitions.py +450 -0
- mlflow/metrics/genai/model_utils.py +371 -0
- mlflow/metrics/genai/prompt_template.py +68 -0
- mlflow/metrics/genai/prompts/__init__.py +0 -0
- mlflow/metrics/genai/prompts/v1.py +422 -0
- mlflow/metrics/genai/utils.py +6 -0
- mlflow/metrics/metric_definitions.py +619 -0
- mlflow/mismatch.py +34 -0
- mlflow/mistral/__init__.py +34 -0
- mlflow/mistral/autolog.py +71 -0
- mlflow/mistral/chat.py +135 -0
- mlflow/ml_package_versions.py +452 -0
- mlflow/models/__init__.py +97 -0
- mlflow/models/auth_policy.py +83 -0
- mlflow/models/cli.py +354 -0
- mlflow/models/container/__init__.py +294 -0
- mlflow/models/container/scoring_server/__init__.py +0 -0
- mlflow/models/container/scoring_server/nginx.conf +39 -0
- mlflow/models/dependencies_schemas.py +287 -0
- mlflow/models/display_utils.py +158 -0
- mlflow/models/docker_utils.py +211 -0
- mlflow/models/evaluation/__init__.py +23 -0
- mlflow/models/evaluation/_shap_patch.py +64 -0
- mlflow/models/evaluation/artifacts.py +194 -0
- mlflow/models/evaluation/base.py +1811 -0
- mlflow/models/evaluation/calibration_curve.py +109 -0
- mlflow/models/evaluation/default_evaluator.py +996 -0
- mlflow/models/evaluation/deprecated.py +23 -0
- mlflow/models/evaluation/evaluator_registry.py +80 -0
- mlflow/models/evaluation/evaluators/classifier.py +704 -0
- mlflow/models/evaluation/evaluators/default.py +233 -0
- mlflow/models/evaluation/evaluators/regressor.py +96 -0
- mlflow/models/evaluation/evaluators/shap.py +296 -0
- mlflow/models/evaluation/lift_curve.py +178 -0
- mlflow/models/evaluation/utils/metric.py +123 -0
- mlflow/models/evaluation/utils/trace.py +179 -0
- mlflow/models/evaluation/validation.py +434 -0
- mlflow/models/flavor_backend.py +93 -0
- mlflow/models/flavor_backend_registry.py +53 -0
- mlflow/models/model.py +1639 -0
- mlflow/models/model_config.py +150 -0
- mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
- mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
- mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
- mlflow/models/python_api.py +369 -0
- mlflow/models/rag_signatures.py +128 -0
- mlflow/models/resources.py +321 -0
- mlflow/models/signature.py +662 -0
- mlflow/models/utils.py +2054 -0
- mlflow/models/wheeled_model.py +280 -0
- mlflow/openai/__init__.py +57 -0
- mlflow/openai/_agent_tracer.py +364 -0
- mlflow/openai/api_request_parallel_processor.py +131 -0
- mlflow/openai/autolog.py +509 -0
- mlflow/openai/constant.py +1 -0
- mlflow/openai/model.py +824 -0
- mlflow/openai/utils/chat_schema.py +367 -0
- mlflow/optuna/__init__.py +3 -0
- mlflow/optuna/storage.py +646 -0
- mlflow/plugins/__init__.py +72 -0
- mlflow/plugins/base.py +358 -0
- mlflow/plugins/builtin/__init__.py +24 -0
- mlflow/plugins/builtin/pytorch_plugin.py +150 -0
- mlflow/plugins/builtin/sklearn_plugin.py +158 -0
- mlflow/plugins/builtin/transformers_plugin.py +187 -0
- mlflow/plugins/cli.py +321 -0
- mlflow/plugins/discovery.py +340 -0
- mlflow/plugins/manager.py +465 -0
- mlflow/plugins/registry.py +316 -0
- mlflow/plugins/templates/framework_plugin_template.py +329 -0
- mlflow/prompt/constants.py +20 -0
- mlflow/prompt/promptlab_model.py +197 -0
- mlflow/prompt/registry_utils.py +248 -0
- mlflow/promptflow/__init__.py +495 -0
- mlflow/protos/__init__.py +0 -0
- mlflow/protos/assessments_pb2.py +174 -0
- mlflow/protos/databricks_artifacts_pb2.py +489 -0
- mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
- mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
- mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
- mlflow/protos/databricks_pb2.py +267 -0
- mlflow/protos/databricks_trace_server_pb2.py +374 -0
- mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
- mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
- mlflow/protos/facet_feature_statistics_pb2.py +296 -0
- mlflow/protos/internal_pb2.py +77 -0
- mlflow/protos/mlflow_artifacts_pb2.py +336 -0
- mlflow/protos/model_registry_pb2.py +1073 -0
- mlflow/protos/scalapb/__init__.py +0 -0
- mlflow/protos/scalapb/scalapb_pb2.py +104 -0
- mlflow/protos/service_pb2.py +2600 -0
- mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
- mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
- mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
- mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
- mlflow/py.typed +0 -0
- mlflow/pydantic_ai/__init__.py +57 -0
- mlflow/pydantic_ai/autolog.py +173 -0
- mlflow/pyfunc/__init__.py +3844 -0
- mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
- mlflow/pyfunc/backend.py +523 -0
- mlflow/pyfunc/context.py +78 -0
- mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
- mlflow/pyfunc/loaders/__init__.py +7 -0
- mlflow/pyfunc/loaders/chat_agent.py +117 -0
- mlflow/pyfunc/loaders/chat_model.py +125 -0
- mlflow/pyfunc/loaders/code_model.py +31 -0
- mlflow/pyfunc/loaders/responses_agent.py +112 -0
- mlflow/pyfunc/mlserver.py +46 -0
- mlflow/pyfunc/model.py +1473 -0
- mlflow/pyfunc/scoring_server/__init__.py +604 -0
- mlflow/pyfunc/scoring_server/app.py +7 -0
- mlflow/pyfunc/scoring_server/client.py +146 -0
- mlflow/pyfunc/spark_model_cache.py +48 -0
- mlflow/pyfunc/stdin_server.py +44 -0
- mlflow/pyfunc/utils/__init__.py +3 -0
- mlflow/pyfunc/utils/data_validation.py +224 -0
- mlflow/pyfunc/utils/environment.py +22 -0
- mlflow/pyfunc/utils/input_converter.py +47 -0
- mlflow/pyfunc/utils/serving_data_parser.py +11 -0
- mlflow/pytorch/__init__.py +1171 -0
- mlflow/pytorch/_lightning_autolog.py +580 -0
- mlflow/pytorch/_pytorch_autolog.py +50 -0
- mlflow/pytorch/pickle_module.py +35 -0
- mlflow/rfunc/__init__.py +42 -0
- mlflow/rfunc/backend.py +134 -0
- mlflow/runs.py +89 -0
- mlflow/server/__init__.py +302 -0
- mlflow/server/auth/__init__.py +1224 -0
- mlflow/server/auth/__main__.py +4 -0
- mlflow/server/auth/basic_auth.ini +6 -0
- mlflow/server/auth/cli.py +11 -0
- mlflow/server/auth/client.py +537 -0
- mlflow/server/auth/config.py +34 -0
- mlflow/server/auth/db/__init__.py +0 -0
- mlflow/server/auth/db/cli.py +18 -0
- mlflow/server/auth/db/migrations/__init__.py +0 -0
- mlflow/server/auth/db/migrations/alembic.ini +110 -0
- mlflow/server/auth/db/migrations/env.py +76 -0
- mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
- mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
- mlflow/server/auth/db/models.py +67 -0
- mlflow/server/auth/db/utils.py +37 -0
- mlflow/server/auth/entities.py +165 -0
- mlflow/server/auth/logo.py +14 -0
- mlflow/server/auth/permissions.py +65 -0
- mlflow/server/auth/routes.py +18 -0
- mlflow/server/auth/sqlalchemy_store.py +263 -0
- mlflow/server/graphql/__init__.py +0 -0
- mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
- mlflow/server/graphql/graphql_custom_scalars.py +24 -0
- mlflow/server/graphql/graphql_errors.py +15 -0
- mlflow/server/graphql/graphql_no_batching.py +89 -0
- mlflow/server/graphql/graphql_schema_extensions.py +74 -0
- mlflow/server/handlers.py +3217 -0
- mlflow/server/prometheus_exporter.py +17 -0
- mlflow/server/validation.py +30 -0
- mlflow/shap/__init__.py +691 -0
- mlflow/sklearn/__init__.py +1994 -0
- mlflow/sklearn/utils.py +1041 -0
- mlflow/smolagents/__init__.py +66 -0
- mlflow/smolagents/autolog.py +139 -0
- mlflow/smolagents/chat.py +29 -0
- mlflow/store/__init__.py +10 -0
- mlflow/store/_unity_catalog/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/constants.py +2 -0
- mlflow/store/_unity_catalog/registry/__init__.py +6 -0
- mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
- mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
- mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
- mlflow/store/_unity_catalog/registry/utils.py +121 -0
- mlflow/store/artifact/__init__.py +0 -0
- mlflow/store/artifact/artifact_repo.py +472 -0
- mlflow/store/artifact/artifact_repository_registry.py +154 -0
- mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
- mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
- mlflow/store/artifact/cli.py +141 -0
- mlflow/store/artifact/cloud_artifact_repo.py +332 -0
- mlflow/store/artifact/databricks_artifact_repo.py +729 -0
- mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
- mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
- mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
- mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
- mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
- mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
- mlflow/store/artifact/ftp_artifact_repo.py +132 -0
- mlflow/store/artifact/gcs_artifact_repo.py +296 -0
- mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
- mlflow/store/artifact/http_artifact_repo.py +218 -0
- mlflow/store/artifact/local_artifact_repo.py +142 -0
- mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
- mlflow/store/artifact/models_artifact_repo.py +259 -0
- mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
- mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
- mlflow/store/artifact/r2_artifact_repo.py +70 -0
- mlflow/store/artifact/runs_artifact_repo.py +265 -0
- mlflow/store/artifact/s3_artifact_repo.py +330 -0
- mlflow/store/artifact/sftp_artifact_repo.py +141 -0
- mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
- mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
- mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
- mlflow/store/artifact/utils/__init__.py +0 -0
- mlflow/store/artifact/utils/models.py +148 -0
- mlflow/store/db/__init__.py +0 -0
- mlflow/store/db/base_sql_model.py +3 -0
- mlflow/store/db/db_types.py +10 -0
- mlflow/store/db/utils.py +314 -0
- mlflow/store/db_migrations/__init__.py +0 -0
- mlflow/store/db_migrations/alembic.ini +74 -0
- mlflow/store/db_migrations/env.py +84 -0
- mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
- mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
- mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
- mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
- mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
- mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
- mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
- mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
- mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
- mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
- mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
- mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
- mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
- mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
- mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
- mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
- mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
- mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
- mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
- mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
- mlflow/store/db_migrations/versions/__init__.py +0 -0
- mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
- mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
- mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
- mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
- mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
- mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
- mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
- mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
- mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
- mlflow/store/entities/__init__.py +3 -0
- mlflow/store/entities/paged_list.py +18 -0
- mlflow/store/model_registry/__init__.py +10 -0
- mlflow/store/model_registry/abstract_store.py +1081 -0
- mlflow/store/model_registry/base_rest_store.py +44 -0
- mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
- mlflow/store/model_registry/dbmodels/__init__.py +0 -0
- mlflow/store/model_registry/dbmodels/models.py +206 -0
- mlflow/store/model_registry/file_store.py +1091 -0
- mlflow/store/model_registry/rest_store.py +481 -0
- mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
- mlflow/store/tracking/__init__.py +23 -0
- mlflow/store/tracking/abstract_store.py +816 -0
- mlflow/store/tracking/dbmodels/__init__.py +0 -0
- mlflow/store/tracking/dbmodels/initial_models.py +243 -0
- mlflow/store/tracking/dbmodels/models.py +1073 -0
- mlflow/store/tracking/file_store.py +2438 -0
- mlflow/store/tracking/postgres_managed_identity.py +146 -0
- mlflow/store/tracking/rest_store.py +1131 -0
- mlflow/store/tracking/sqlalchemy_store.py +2785 -0
- mlflow/system_metrics/__init__.py +61 -0
- mlflow/system_metrics/metrics/__init__.py +0 -0
- mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
- mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
- mlflow/system_metrics/metrics/disk_monitor.py +21 -0
- mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
- mlflow/system_metrics/metrics/network_monitor.py +34 -0
- mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
- mlflow/system_metrics/system_metrics_monitor.py +198 -0
- mlflow/tracing/__init__.py +16 -0
- mlflow/tracing/assessment.py +356 -0
- mlflow/tracing/client.py +531 -0
- mlflow/tracing/config.py +125 -0
- mlflow/tracing/constant.py +105 -0
- mlflow/tracing/destination.py +81 -0
- mlflow/tracing/display/__init__.py +40 -0
- mlflow/tracing/display/display_handler.py +196 -0
- mlflow/tracing/export/async_export_queue.py +186 -0
- mlflow/tracing/export/inference_table.py +138 -0
- mlflow/tracing/export/mlflow_v3.py +137 -0
- mlflow/tracing/export/utils.py +70 -0
- mlflow/tracing/fluent.py +1417 -0
- mlflow/tracing/processor/base_mlflow.py +199 -0
- mlflow/tracing/processor/inference_table.py +175 -0
- mlflow/tracing/processor/mlflow_v3.py +47 -0
- mlflow/tracing/processor/otel.py +73 -0
- mlflow/tracing/provider.py +487 -0
- mlflow/tracing/trace_manager.py +200 -0
- mlflow/tracing/utils/__init__.py +616 -0
- mlflow/tracing/utils/artifact_utils.py +28 -0
- mlflow/tracing/utils/copy.py +55 -0
- mlflow/tracing/utils/environment.py +55 -0
- mlflow/tracing/utils/exception.py +21 -0
- mlflow/tracing/utils/once.py +35 -0
- mlflow/tracing/utils/otlp.py +63 -0
- mlflow/tracing/utils/processor.py +54 -0
- mlflow/tracing/utils/search.py +292 -0
- mlflow/tracing/utils/timeout.py +250 -0
- mlflow/tracing/utils/token.py +19 -0
- mlflow/tracing/utils/truncation.py +124 -0
- mlflow/tracing/utils/warning.py +76 -0
- mlflow/tracking/__init__.py +39 -0
- mlflow/tracking/_model_registry/__init__.py +1 -0
- mlflow/tracking/_model_registry/client.py +764 -0
- mlflow/tracking/_model_registry/fluent.py +853 -0
- mlflow/tracking/_model_registry/registry.py +67 -0
- mlflow/tracking/_model_registry/utils.py +251 -0
- mlflow/tracking/_tracking_service/__init__.py +0 -0
- mlflow/tracking/_tracking_service/client.py +883 -0
- mlflow/tracking/_tracking_service/registry.py +56 -0
- mlflow/tracking/_tracking_service/utils.py +275 -0
- mlflow/tracking/artifact_utils.py +179 -0
- mlflow/tracking/client.py +5900 -0
- mlflow/tracking/context/__init__.py +0 -0
- mlflow/tracking/context/abstract_context.py +35 -0
- mlflow/tracking/context/databricks_cluster_context.py +15 -0
- mlflow/tracking/context/databricks_command_context.py +15 -0
- mlflow/tracking/context/databricks_job_context.py +49 -0
- mlflow/tracking/context/databricks_notebook_context.py +41 -0
- mlflow/tracking/context/databricks_repo_context.py +43 -0
- mlflow/tracking/context/default_context.py +51 -0
- mlflow/tracking/context/git_context.py +32 -0
- mlflow/tracking/context/registry.py +98 -0
- mlflow/tracking/context/system_environment_context.py +15 -0
- mlflow/tracking/default_experiment/__init__.py +1 -0
- mlflow/tracking/default_experiment/abstract_context.py +43 -0
- mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
- mlflow/tracking/default_experiment/registry.py +75 -0
- mlflow/tracking/fluent.py +3595 -0
- mlflow/tracking/metric_value_conversion_utils.py +93 -0
- mlflow/tracking/multimedia.py +206 -0
- mlflow/tracking/registry.py +86 -0
- mlflow/tracking/request_auth/__init__.py +0 -0
- mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
- mlflow/tracking/request_auth/registry.py +60 -0
- mlflow/tracking/request_header/__init__.py +0 -0
- mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
- mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
- mlflow/tracking/request_header/default_request_header_provider.py +17 -0
- mlflow/tracking/request_header/registry.py +79 -0
- mlflow/transformers/__init__.py +2982 -0
- mlflow/transformers/flavor_config.py +258 -0
- mlflow/transformers/hub_utils.py +83 -0
- mlflow/transformers/llm_inference_utils.py +468 -0
- mlflow/transformers/model_io.py +301 -0
- mlflow/transformers/peft.py +51 -0
- mlflow/transformers/signature.py +183 -0
- mlflow/transformers/torch_utils.py +55 -0
- mlflow/types/__init__.py +21 -0
- mlflow/types/agent.py +270 -0
- mlflow/types/chat.py +240 -0
- mlflow/types/llm.py +935 -0
- mlflow/types/responses.py +139 -0
- mlflow/types/responses_helpers.py +416 -0
- mlflow/types/schema.py +1505 -0
- mlflow/types/type_hints.py +647 -0
- mlflow/types/utils.py +753 -0
- mlflow/utils/__init__.py +283 -0
- mlflow/utils/_capture_modules.py +256 -0
- mlflow/utils/_capture_transformers_modules.py +75 -0
- mlflow/utils/_spark_utils.py +201 -0
- mlflow/utils/_unity_catalog_oss_utils.py +97 -0
- mlflow/utils/_unity_catalog_utils.py +479 -0
- mlflow/utils/annotations.py +218 -0
- mlflow/utils/arguments_utils.py +16 -0
- mlflow/utils/async_logging/__init__.py +1 -0
- mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
- mlflow/utils/async_logging/async_logging_queue.py +366 -0
- mlflow/utils/async_logging/run_artifact.py +38 -0
- mlflow/utils/async_logging/run_batch.py +58 -0
- mlflow/utils/async_logging/run_operations.py +49 -0
- mlflow/utils/autologging_utils/__init__.py +737 -0
- mlflow/utils/autologging_utils/client.py +432 -0
- mlflow/utils/autologging_utils/config.py +33 -0
- mlflow/utils/autologging_utils/events.py +294 -0
- mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
- mlflow/utils/autologging_utils/metrics_queue.py +71 -0
- mlflow/utils/autologging_utils/safety.py +1104 -0
- mlflow/utils/autologging_utils/versioning.py +95 -0
- mlflow/utils/checkpoint_utils.py +206 -0
- mlflow/utils/class_utils.py +6 -0
- mlflow/utils/cli_args.py +257 -0
- mlflow/utils/conda.py +354 -0
- mlflow/utils/credentials.py +231 -0
- mlflow/utils/data_utils.py +17 -0
- mlflow/utils/databricks_utils.py +1436 -0
- mlflow/utils/docstring_utils.py +477 -0
- mlflow/utils/doctor.py +133 -0
- mlflow/utils/download_cloud_file_chunk.py +43 -0
- mlflow/utils/env_manager.py +16 -0
- mlflow/utils/env_pack.py +131 -0
- mlflow/utils/environment.py +1009 -0
- mlflow/utils/exception_utils.py +14 -0
- mlflow/utils/file_utils.py +978 -0
- mlflow/utils/git_utils.py +77 -0
- mlflow/utils/gorilla.py +797 -0
- mlflow/utils/import_hooks/__init__.py +363 -0
- mlflow/utils/lazy_load.py +51 -0
- mlflow/utils/logging_utils.py +168 -0
- mlflow/utils/mime_type_utils.py +58 -0
- mlflow/utils/mlflow_tags.py +103 -0
- mlflow/utils/model_utils.py +486 -0
- mlflow/utils/name_utils.py +346 -0
- mlflow/utils/nfs_on_spark.py +62 -0
- mlflow/utils/openai_utils.py +164 -0
- mlflow/utils/os.py +12 -0
- mlflow/utils/oss_registry_utils.py +29 -0
- mlflow/utils/plugins.py +17 -0
- mlflow/utils/process.py +182 -0
- mlflow/utils/promptlab_utils.py +146 -0
- mlflow/utils/proto_json_utils.py +743 -0
- mlflow/utils/pydantic_utils.py +54 -0
- mlflow/utils/request_utils.py +279 -0
- mlflow/utils/requirements_utils.py +704 -0
- mlflow/utils/rest_utils.py +673 -0
- mlflow/utils/search_logged_model_utils.py +127 -0
- mlflow/utils/search_utils.py +2111 -0
- mlflow/utils/secure_loading.py +221 -0
- mlflow/utils/security_validation.py +384 -0
- mlflow/utils/server_cli_utils.py +61 -0
- mlflow/utils/spark_utils.py +15 -0
- mlflow/utils/string_utils.py +138 -0
- mlflow/utils/thread_utils.py +63 -0
- mlflow/utils/time.py +54 -0
- mlflow/utils/timeout.py +42 -0
- mlflow/utils/uri.py +572 -0
- mlflow/utils/validation.py +662 -0
- mlflow/utils/virtualenv.py +458 -0
- mlflow/utils/warnings_utils.py +25 -0
- mlflow/utils/yaml_utils.py +179 -0
- mlflow/version.py +24 -0
mlflow/types/schema.py
ADDED
@@ -0,0 +1,1505 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import builtins
|
4
|
+
import datetime as dt
|
5
|
+
import json
|
6
|
+
import string
|
7
|
+
from abc import ABC, abstractmethod
|
8
|
+
from copy import deepcopy
|
9
|
+
from dataclasses import is_dataclass
|
10
|
+
from enum import Enum
|
11
|
+
from typing import Any, Optional, TypedDict, Union, get_args, get_origin
|
12
|
+
|
13
|
+
import numpy as np
|
14
|
+
|
15
|
+
from mlflow.exceptions import MlflowException
|
16
|
+
from mlflow.utils.annotations import experimental
|
17
|
+
|
18
|
+
# Tags stored under the "type" key of the dictionary form produced by the composite types'
# `to_dict` methods and consumed by their `from_json_dict` parsers.
ARRAY_TYPE = "array"
OBJECT_TYPE = "object"
MAP_TYPE = "map"
ANY_TYPE = "any"
SPARKML_VECTOR_TYPE = "sparkml_vector"
# Accepted values for a `dtype` / `type` / `value_type` argument: a type instance, or a string
# that is looked up by name in `DataType`.
ALLOWED_DTYPES = Union["Array", "DataType", "Map", "Object", "AnyType", str]
# Error template raised when such an argument resolves to an unsupported object; formatted with
# `arg_name` and `passed_type`.
EXPECTED_TYPE_MESSAGE = (
    "Expected mlflow.types.schema.Datatype, mlflow.types.schema.Array, "
    "mlflow.types.schema.Object, mlflow.types.schema.Map, mlflow.types.schema.AnyType "
    "or str for the '{arg_name}' argument, but got {passed_type}"
)
# Same union as ALLOWED_DTYPES minus `str`, i.e. the already-resolved type instances.
COLSPEC_TYPES = Union["Array", "DataType", "Map", "Object", "AnyType"]

# pyspark is optional: record whether it can be imported so spark-specific checks
# (such as the one in DataType.check_type) can be skipped when it is absent.
try:
    import pyspark  # noqa: F401

    HAS_PYSPARK = True
except ImportError:
    HAS_PYSPARK = False
|
37
|
+
|
38
|
+
|
39
|
+
class DataType(Enum):
    """
    MLflow data types.

    Every member carries its numpy dtype, the name of the matching ``pyspark.sql.types`` class,
    its pandas dtype and its python type; use the ``to_*`` methods to read them.
    """

    def __new__(cls, value, numpy_type, spark_type, pandas_type=None, python_type=None):
        # Enum invokes this once per member, unpacking the member's tuple into the arguments.
        member = object.__new__(cls)
        member._value_ = value
        member._numpy_type = numpy_type
        member._spark_type = spark_type
        # When no dedicated pandas / python equivalent is given, reuse the numpy dtype.
        member._pandas_type = numpy_type if pandas_type is None else pandas_type
        member._python_type = numpy_type if python_type is None else python_type
        return member

    # NB: We only use pandas extension type for strings. There are also pandas extension types for
    # integers and boolean values. We do not use them here for now as most downstream tools are
    # most likely to use / expect native numpy types and would not be compatible with the extension
    # types.
    # Member tuple layout: (ordinal, numpy dtype, spark type name, pandas dtype, python type).
    boolean = (1, np.dtype("bool"), "BooleanType", np.dtype("bool"), bool)
    """Logical data (True, False) ."""
    integer = (2, np.dtype("int32"), "IntegerType", np.dtype("int32"), int)
    """32b signed integer numbers."""
    long = (3, np.dtype("int64"), "LongType", np.dtype("int64"), int)
    """64b signed integer numbers. """
    # `builtins.float` is spelled out because `float` is rebound as a member name in this body.
    float = (4, np.dtype("float32"), "FloatType", np.dtype("float32"), builtins.float)
    """32b floating point numbers. """
    double = (5, np.dtype("float64"), "DoubleType", np.dtype("float64"), builtins.float)
    """64b floating point numbers. """
    string = (6, np.dtype("str"), "StringType", object, str)
    """Text data."""
    binary = (7, np.dtype("bytes"), "BinaryType", object, bytes)
    """Sequence of raw bytes."""
    datetime = (
        8,
        np.dtype("datetime64[ns]"),
        "TimestampType",
        np.dtype("datetime64[ns]"),
        dt.date,
    )
    """64b datetime data."""

    def __repr__(self):
        return self.name

    def to_numpy(self) -> np.dtype:
        """Return the numpy dtype matching this MLflow type."""
        return self._numpy_type

    def to_pandas(self) -> np.dtype:
        """Return the pandas dtype matching this MLflow type (``object`` for string/binary)."""
        return self._pandas_type

    def to_spark(self):
        """Instantiate the matching pyspark type; importing pyspark lazily."""
        if self._spark_type != "VectorUDT":
            import pyspark.sql.types

            return getattr(pyspark.sql.types, self._spark_type)()

        from pyspark.ml.linalg import VectorUDT

        return VectorUDT()

    def to_python(self):
        """Return the python type matching this MLflow type."""
        return self._python_type

    @classmethod
    def check_type(cls, data_type, value):
        """
        Return whether ``type(value)`` is one of the types associated with ``data_type``.

        Falls back to an isinstance check against the pyspark type when pyspark is importable.
        """
        accepted = [data_type.to_numpy(), data_type.to_pandas(), data_type.to_python()]
        # A few MLflow types accept additional python-level representations.
        extra_by_name = {"datetime": [np.datetime64, dt.datetime], "binary": [bytearray]}
        accepted += extra_by_name.get(data_type.name, [])
        if type(value) in accepted:
            return True
        return HAS_PYSPARK and isinstance(value, type(data_type.to_spark()))

    @classmethod
    def all_types(cls):
        """Return every member, in definition order."""
        return list(cls.__members__.values())

    @classmethod
    def get_spark_types(cls):
        """Return an instance of the pyspark type for every member."""
        return [member.to_spark() for member in cls._member_map_.values()]

    @classmethod
    def from_numpy_type(cls, np_type):
        """Return the first member whose numpy dtype equals ``np_type``, or None."""
        for member in cls._member_map_.values():
            if member.to_numpy() == np_type:
                return member
        return None
|
129
|
+
|
130
|
+
|
131
|
+
class BaseType(ABC):
    """
    Abstract interface implemented by the composite schema types.

    Implementations provide equality, a readable ``repr``, a dictionary form and a ``_merge``
    operation that reconciles two compatible types.
    """

    @abstractmethod
    def to_dict(self) -> dict[str, Any]:
        """
        Return the dictionary form of this type.
        """

    @abstractmethod
    def _merge(self, other: BaseType) -> BaseType:
        """
        Return the combination of ``self`` and ``other`` when the two types are compatible.
        """

    @abstractmethod
    def __eq__(self, other) -> bool:
        """
        Return whether ``other`` is equal to this type.
        """

    @abstractmethod
    def __repr__(self) -> str:
        """
        Return a human-readable representation of this type.
        """
|
155
|
+
|
156
|
+
|
157
|
+
class Property(BaseType):
    """
    Specification used to represent a json-convertible object property.

    A property pairs a ``name`` with a data type (a ``DataType`` or one of ``Array``,
    ``Object``, ``Map``, ``AnyType``) and a ``required`` flag.
    """

    def __init__(
        self,
        name: str,
        dtype: ALLOWED_DTYPES,
        required: bool = True,
    ) -> None:
        """
        Args:
            name: The name of the property
            dtype: The data type of the property. A string is looked up by name in ``DataType``.
            required: Whether this property is required

        Raises:
            MlflowException: If ``name`` is not a string, ``dtype`` names an unknown
                ``DataType``, or ``dtype`` is not a supported type instance.
        """
        if not isinstance(name, str):
            raise MlflowException.invalid_parameter_value(
                f"Expected name to be a string, got type {type(name).__name__}"
            )
        self._name = name
        try:
            self._dtype = DataType[dtype] if isinstance(dtype, str) else dtype
        except KeyError:
            raise MlflowException(
                f"Unsupported type '{dtype}', expected instance of DataType, Array, Object, Map or "
                f"one of {[t.name for t in DataType]}"
            )
        if not isinstance(self.dtype, (DataType, Array, Object, Map, AnyType)):
            raise MlflowException(
                EXPECTED_TYPE_MESSAGE.format(arg_name="dtype", passed_type=self.dtype)
            )
        self._required = required

    @property
    def name(self) -> str:
        """The property name."""
        return self._name

    @property
    def dtype(self) -> Union[DataType, "Array", "Object", "Map", "AnyType"]:
        """The property data type."""
        return self._dtype

    @property
    def required(self) -> bool:
        """Whether this property is required"""
        return self._required

    @required.setter
    def required(self, value: bool) -> None:
        self._required = value

    def __eq__(self, other) -> bool:
        if isinstance(other, Property):
            return (
                self.name == other.name
                and self.dtype == other.dtype
                and self.required == other.required
            )
        return False

    def __lt__(self, other) -> bool:
        # Ordering by name lets Object keep its properties sorted.
        return self.name < other.name

    def __repr__(self) -> str:
        required = "required" if self.required else "optional"
        return f"{self.name}: {self.dtype!r} ({required})"

    def to_dict(self):
        d = {"type": self.dtype.name} if isinstance(self.dtype, DataType) else self.dtype.to_dict()
        d["required"] = self.required
        return {self.name: d}

    @classmethod
    def from_json_dict(cls, **kwargs):
        """
        Deserialize from a json loaded dictionary.
        The dictionary is expected to contain only one key as `name`, and
        the value should be a dictionary containing `type` and
        optional `required` (default True) keys. The input dictionary is not modified.
        Example: {"property_name": {"type": "string", "required": True}}

        Raises:
            MlflowException: If there is not exactly one key, the value is not a dictionary,
                or the `type` key is missing.
        """
        if len(kwargs) != 1:
            raise MlflowException(
                f"Expected Property JSON to contain a single key as name, got {len(kwargs)} keys."
            )
        name, dic = kwargs.popitem()
        if not isinstance(dic, dict):
            raise MlflowException(
                f"Expected Property `{name}` to be a dictionary, got type {type(dic).__name__}"
            )
        if not {"type"} <= set(dic.keys()):
            raise MlflowException(f"Missing keys in Property `{name}`. Expected to find key `type`")
        # Read `required` without popping it, so the caller's JSON dict is left untouched;
        # the nested type parsers get a copy that excludes it.
        required = dic.get("required", True)
        spec = {key: value for key, value in dic.items() if key != "required"}
        dtype = spec["type"]
        if dtype == ARRAY_TYPE:
            dtype = Array.from_json_dict(**spec)
        elif dtype == SPARKML_VECTOR_TYPE:
            # Wrap the vector in a Property like every other type, so that
            # Object.from_json_dict receives Property instances only.
            dtype = SparkMLVector()
        elif dtype == OBJECT_TYPE:
            dtype = Object.from_json_dict(**spec)
        elif dtype == MAP_TYPE:
            dtype = Map.from_json_dict(**spec)
        elif dtype == ANY_TYPE:
            dtype = AnyType()
        return cls(name=name, dtype=dtype, required=required)

    def _merge(self, other: BaseType) -> Property:
        """
        Check if current property is compatible with another property and return
        the updated property.
        When two properties have the same name, we need to check if their dtypes
        are compatible or not.
        An example of two compatible properties:

        .. code-block:: python

            prop1 = Property(
                name="a",
                dtype=Object(
                    properties=[Property(name="a", dtype=DataType.string, required=False)]
                ),
            )
            prop2 = Property(
                name="a",
                dtype=Object(
                    properties=[
                        Property(name="a", dtype=DataType.string),
                        Property(name="b", dtype=DataType.double),
                    ]
                ),
            )
            merged_prop = prop1._merge(prop2)
            assert merged_prop == Property(
                name="a",
                dtype=Object(
                    properties=[
                        Property(name="a", dtype=DataType.string, required=False),
                        Property(name="b", dtype=DataType.double, required=False),
                    ]
                ),
            )

        Raises:
            MlflowException: If ``other`` is not a Property/AnyType, the names differ, or the
                dtypes cannot be merged.
        """
        if isinstance(other, AnyType):
            return Property(name=self.name, dtype=self.dtype, required=False)
        if not isinstance(other, Property):
            raise MlflowException(
                f"Can't merge property with non-property type: {type(other).__name__}"
            )
        if self.name != other.name:
            raise MlflowException("Can't merge properties with different names")
        required = self.required and other.required
        if isinstance(self.dtype, DataType) and isinstance(other.dtype, DataType):
            if self.dtype == other.dtype:
                return Property(name=self.name, dtype=self.dtype, required=required)
            raise MlflowException(f"Properties are incompatible for {self.dtype} and {other.dtype}")

        if isinstance(self.dtype, (Array, Object, Map, AnyType)):
            obj = self.dtype._merge(other.dtype)
            return Property(name=self.name, dtype=obj, required=required)

        # Here self.dtype is a DataType. Merging an AnyType-typed property with a
        # DataType-typed one (the branch above) keeps the DataType, so do the same in this
        # order instead of rejecting the pair.
        if isinstance(other.dtype, AnyType):
            return Property(name=self.name, dtype=self.dtype, required=required)

        raise MlflowException("Properties are incompatible")
|
318
|
+
|
319
|
+
|
320
|
+
class Object(BaseType):
    """
    Specification used to represent a json-convertible object.

    The object is described by a non-empty list of uniquely named ``Property`` entries,
    kept sorted by name.
    """

    def __init__(self, properties: list[Property]) -> None:
        self._check_properties(properties)
        # Sorting by name (Property.__lt__) gives a stable order for equality and to_dict.
        self._properties = sorted(properties)

    def _check_properties(self, properties):
        """Validate that ``properties`` is a non-empty list of uniquely named Property."""
        if not isinstance(properties, list):
            raise MlflowException.invalid_parameter_value(
                f"Expected properties to be a list, got type {type(properties).__name__}"
            )
        if not properties:
            raise MlflowException.invalid_parameter_value(
                "Creating Object with empty properties is not allowed."
            )
        if not all(isinstance(candidate, Property) for candidate in properties):
            raise MlflowException.invalid_parameter_value(
                "Expected values to be instance of Property"
            )
        # Collect every name that occurs more than once.
        seen_names = set()
        duplicates = set()
        for candidate in properties:
            if candidate.name in seen_names:
                duplicates.add(candidate.name)
            seen_names.add(candidate.name)
        if duplicates:
            raise MlflowException.invalid_parameter_value(
                f"Found duplicated property names: {duplicates}"
            )

    @property
    def properties(self) -> list[Property]:
        """The object's properties, sorted by name."""
        return self._properties

    @properties.setter
    def properties(self, value: list[Property]) -> None:
        self._check_properties(value)
        self._properties = sorted(value)

    def __eq__(self, other) -> bool:
        if not isinstance(other, Object):
            return False
        return self.properties == other.properties

    def __repr__(self) -> str:
        return "{" + ", ".join(repr(item) for item in self.properties) + "}"

    def to_dict(self):
        flattened = {}
        for item in self.properties:
            # Each Property serializes as a single-entry {name: spec} dictionary.
            flattened.update(item.to_dict())
        return {"type": OBJECT_TYPE, "properties": flattened}

    @classmethod
    def from_json_dict(cls, **kwargs):
        """
        Build an Object from its dictionary form.

        Both `type` (which must equal "object") and `properties` (a mapping of property
        name to property dictionary) are required.
        Example: {"type": "object", "properties": {"property_name": {"type": "string"}}}
        """
        if not {"properties", "type"}.issubset(kwargs):
            raise MlflowException(
                "Missing keys in Object JSON. Expected to find keys `properties` and `type`"
            )
        if kwargs["type"] != OBJECT_TYPE:
            raise MlflowException("Type mismatch, Object expects `object` as the type")
        raw_properties = kwargs["properties"]
        is_valid = isinstance(raw_properties, dict) and all(
            isinstance(spec, dict) for spec in raw_properties.values()
        )
        if not is_valid:
            raise MlflowException("Expected properties to be a dictionary of Property JSON")
        parsed = [
            Property.from_json_dict(**{prop_name: spec})
            for prop_name, spec in raw_properties.items()
        ]
        return cls(parsed)

    def _merge(self, other: BaseType) -> Object:
        """
        Combine this object with ``other`` and return the resulting object.

        Objects inferred from different samples may not share all properties. Properties
        found on only one side become optional (required=False); properties found on both
        sides are merged with ``Property._merge``. An example of two compatible objects:

        .. code-block:: python

            obj1 = Object(
                properties=[
                    Property(name="a", dtype=DataType.string),
                    Property(name="b", dtype=DataType.double),
                ]
            )
            obj2 = Object(
                properties=[
                    Property(name="a", dtype=DataType.string),
                    Property(name="c", dtype=DataType.boolean),
                ]
            )
            updated_obj = obj1._merge(obj2)
            assert updated_obj == Object(
                properties=[
                    Property(name="a", dtype=DataType.string),
                    Property(name="b", dtype=DataType.double, required=False),
                    Property(name="c", dtype=DataType.boolean, required=False),
                ]
            )

        """
        # Combining with AnyType keeps the structure but relaxes every property to optional.
        if isinstance(other, AnyType):
            relaxed = [
                Property(name=item.name, dtype=item.dtype, required=False)
                for item in self.properties
            ]
            return Object(properties=relaxed)
        if not isinstance(other, Object):
            raise MlflowException(
                f"Can't merge object with non-object type: {type(other).__name__}"
            )
        if self == other:
            return deepcopy(self)
        ours = {item.name: item for item in self.properties}
        theirs = {item.name: item for item in other.properties}
        combined = [
            Property(name=key, dtype=ours[key].dtype, required=False)
            for key in ours.keys() - theirs.keys()
        ]
        combined.extend(ours[key]._merge(theirs[key]) for key in ours.keys() & theirs.keys())
        combined.extend(
            Property(name=key, dtype=theirs[key].dtype, required=False)
            for key in theirs.keys() - ours.keys()
        )
        return Object(properties=combined)
|
465
|
+
|
466
|
+
|
467
|
+
class Array(BaseType):
    """
    Specification used to represent a json-convertible array.

    Every element of the array has the same type, ``dtype``.
    """

    def __init__(
        self,
        dtype: ALLOWED_DTYPES,
    ) -> None:
        if isinstance(dtype, str):
            # Resolve a type name such as "string" to the DataType member.
            try:
                dtype = DataType[dtype]
            except KeyError:
                raise MlflowException(
                    f"Unsupported type '{dtype}', expected instance of DataType, Array, Object, "
                    f"Map or one of {[t.name for t in DataType]}"
                )
        self._dtype = dtype
        if not isinstance(self.dtype, (Array, DataType, Object, Map, AnyType)):
            raise MlflowException(
                EXPECTED_TYPE_MESSAGE.format(arg_name="dtype", passed_type=self.dtype)
            )

    @property
    def dtype(self) -> Union["Array", DataType, Object, "Map", "AnyType"]:
        """The element type of the array."""
        return self._dtype

    def __eq__(self, other) -> bool:
        if not isinstance(other, Array):
            return False
        return self.dtype == other.dtype

    def to_dict(self):
        if isinstance(self.dtype, DataType):
            items = {"type": self.dtype.name}
        else:
            items = self.dtype.to_dict()
        return {"type": ARRAY_TYPE, "items": items}

    @classmethod
    def from_json_dict(cls, **kwargs):
        """
        Build an Array from its dictionary form.

        `type` must equal "array" and `items` must be a dictionary describing the element
        type through its own `type` key.
        Example: {"type": "array", "items": {"type": "string"}}
        """
        if not {"items", "type"}.issubset(kwargs):
            raise MlflowException(
                "Missing keys in Array JSON. Expected to find keys `items` and `type`"
            )
        if kwargs["type"] != ARRAY_TYPE:
            raise MlflowException("Type mismatch, Array expects `array` as the type")
        items = kwargs["items"]
        if not isinstance(items, dict):
            raise MlflowException("Expected items to be a dictionary of Object JSON")
        if "type" not in items:
            raise MlflowException("Missing keys in Array's items JSON. Expected to find key `type`")

        element_kind = items["type"]
        if element_kind == OBJECT_TYPE:
            return cls(dtype=Object.from_json_dict(**items))
        if element_kind == ARRAY_TYPE:
            return cls(dtype=Array.from_json_dict(**items))
        if element_kind == SPARKML_VECTOR_TYPE:
            return cls(dtype=SparkMLVector())
        if element_kind == MAP_TYPE:
            return cls(dtype=Map.from_json_dict(**items))
        if element_kind == ANY_TYPE:
            return cls(dtype=AnyType())
        # Anything else is treated as a DataType name and validated by __init__.
        return cls(dtype=element_kind)

    def __repr__(self) -> str:
        return f"Array({self.dtype!r})"

    def _merge(self, other: BaseType) -> Array:
        """Combine with ``other``; element types must be equal scalars or mergeable composites."""
        if isinstance(other, AnyType) or self == other:
            return deepcopy(self)
        if not isinstance(other, Array):
            raise MlflowException(f"Can't merge array with non-array type: {type(other).__name__}")

        element_type = self.dtype
        if isinstance(element_type, DataType):
            if element_type == other.dtype:
                return Array(dtype=element_type)
            raise MlflowException(
                f"Array types are incompatible for {self} with dtype={self.dtype} and "
                f"{other} with dtype={other.dtype}"
            )

        if isinstance(element_type, (Array, Object, Map, AnyType)):
            return Array(dtype=element_type._merge(other.dtype))

        raise MlflowException(f"Array type {self!r} and {other!r} are incompatible")
|
558
|
+
|
559
|
+
|
560
|
+
class SparkMLVector(Array):
    """
    Specification used to represent a vector type in Spark ML.

    Behaves as an ``Array`` of ``DataType.double`` but serializes with the dedicated
    ``sparkml_vector`` type tag and only merges with another SparkMLVector.
    """

    def __init__(self):
        super().__init__(dtype=DataType.double)

    def to_dict(self):
        return {"type": SPARKML_VECTOR_TYPE}

    @classmethod
    def from_json_dict(cls, **kwargs):
        # The dictionary form carries no parameters, so every key is ignored.
        return SparkMLVector()

    def __repr__(self) -> str:
        return "SparkML vector"

    def __eq__(self, other) -> bool:
        return isinstance(other, SparkMLVector)

    def _merge(self, arr: BaseType) -> SparkMLVector:
        if not isinstance(arr, SparkMLVector):
            raise MlflowException("SparkML vector type can't be merged with another Array type.")
        return deepcopy(self)
|
585
|
+
|
586
|
+
|
587
|
+
class Map(BaseType):
    """
    Specification used to represent a json-convertible map with string type keys.

    Every value in the map shares a single ``value_type``.
    """

    def __init__(self, value_type: ALLOWED_DTYPES):
        """
        Args:
            value_type: The type of the map values: a ``DataType`` (or its name as a string),
                ``Array``, ``Object``, ``Map`` or ``AnyType``.

        Raises:
            MlflowException: If ``value_type`` names an unknown ``DataType`` or is not a
                supported type instance.
        """
        try:
            self._value_type = DataType[value_type] if isinstance(value_type, str) else value_type
        except KeyError:
            raise MlflowException(
                f"Unsupported value type '{value_type}', expected instance of DataType, Array, "
                f"Object, Map or one of {[t.name for t in DataType]}"
            )
        if not isinstance(self._value_type, (Array, Map, DataType, Object, AnyType)):
            raise MlflowException.invalid_parameter_value(
                EXPECTED_TYPE_MESSAGE.format(arg_name="value_type", passed_type=self._value_type)
            )

    @property
    def value_type(self):
        """The type shared by all values of the map."""
        return self._value_type

    def __repr__(self) -> str:
        return f"Map(str -> {self._value_type})"

    def __eq__(self, other) -> bool:
        if isinstance(other, Map):
            return self.value_type == other.value_type
        return False

    def to_dict(self):
        values = (
            {"type": self.value_type.name}
            if isinstance(self.value_type, DataType)
            else self.value_type.to_dict()
        )
        return {"type": MAP_TYPE, "values": values}

    @classmethod
    def from_json_dict(cls, **kwargs):
        """
        Deserialize from a json loaded dictionary.
        The dictionary is expected to contain `type` and
        `values` keys, where `values` is a dictionary with its own `type` key.
        Example: {"type": "map", "values": {"type": "string"}}

        Raises:
            MlflowException: If keys are missing, `type` is not "map", or `values` is not a
                dictionary containing `type`.
        """
        if not {"values", "type"} <= set(kwargs.keys()):
            raise MlflowException(
                "Missing keys in Map JSON. Expected to find keys `values` and `type`"
            )
        if kwargs["type"] != MAP_TYPE:
            raise MlflowException("Type mismatch, Map expects `map` as the type")
        values = kwargs["values"]
        if not isinstance(values, dict):
            raise MlflowException("Expected values to be a dictionary of Object JSON")
        if not {"type"} <= set(values.keys()):
            raise MlflowException("Missing keys in Map's values JSON. Expected to find key `type`")
        value_kind = values["type"]
        if value_kind == OBJECT_TYPE:
            return cls(value_type=Object.from_json_dict(**values))
        if value_kind == ARRAY_TYPE:
            return cls(value_type=Array.from_json_dict(**values))
        if value_kind == SPARKML_VECTOR_TYPE:
            # Wrap the vector in a Map like every other value type; returning the bare
            # vector would drop the map level when round-tripping to_dict().
            return cls(value_type=SparkMLVector())
        if value_kind == MAP_TYPE:
            return cls(value_type=Map.from_json_dict(**values))
        if value_kind == ANY_TYPE:
            return cls(value_type=AnyType())
        # Anything else is treated as a DataType name and validated by __init__.
        return cls(value_type=value_kind)

    def _merge(self, other: BaseType) -> Map:
        """Combine with ``other``; value types must be equal scalars or mergeable composites."""
        if isinstance(other, AnyType) or self == other:
            return deepcopy(self)
        if not isinstance(other, Map):
            raise MlflowException(f"Can't merge map with non-map type: {type(other).__name__}")
        if isinstance(self.value_type, DataType):
            if self.value_type == other.value_type:
                return Map(value_type=self.value_type)
            raise MlflowException(
                f"Map types are incompatible for {self} with value_type={self.value_type} and "
                f"{other} with value_type={other.value_type}"
            )

        if isinstance(self.value_type, (Array, Object, Map, AnyType)):
            return Map(value_type=self.value_type._merge(other.value_type))

        raise MlflowException(f"Map type {self!r} and {other!r} are incompatible")
|
672
|
+
|
673
|
+
|
674
|
+
@experimental(version="2.19.0")
class AnyType(BaseType):
    def __init__(self):
        """
        AnyType accepts any json-serializable data, including None values.
        For example:

        .. code-block:: python

            from mlflow.types.schema import AnyType, Schema, ColSpec

            schema = Schema([ColSpec(type=AnyType(), name="id")])

        .. Note::
            Use AnyType when a field is None, when its type is unknown at the time the
            data is created, or when the field can hold several types; for instance, the
            output of GenAI flavors may contain `None` values that `AnyType` can represent.
            AnyType performs no data validation at all, so use it with care.
        """

    def __repr__(self) -> str:
        return "Any"

    def __eq__(self, other) -> bool:
        return isinstance(other, AnyType)

    def to_dict(self):
        return {"type": ANY_TYPE}

    def _merge(self, other: BaseType) -> BaseType:
        """Combine with ``other``; a concrete type wins over AnyType."""
        if self == other:
            return deepcopy(self)
        if isinstance(other, DataType):
            # A scalar type simply replaces AnyType.
            return other
        if isinstance(other, BaseType):
            # Let the concrete type decide how it absorbs AnyType (e.g. becoming optional).
            return other._merge(self)
        raise MlflowException(
            f"Can't merge AnyType with {type(other).__name__}, "
            "it must be a BaseType or DataType"
        )
|
717
|
+
|
718
|
+
|
719
|
+
class ColSpec:
    """
    Specification of name and type of a single column in a dataset.
    """

    def __init__(
        self,
        type: ALLOWED_DTYPES,
        name: Optional[str] = None,
        required: bool = True,
    ):
        """
        Args:
            type: The column type: a ``DataType`` (or its name as a string), ``Array``,
                ``Object``, ``Map`` or ``AnyType``.
            name: The column name, or None for an unnamed column.
            required: Whether the column is required.

        Raises:
            MlflowException: If ``type`` is a string that does not name a ``DataType``.
            TypeError: If ``type`` is not a supported type instance.
        """
        self._name = name

        self._required = required
        try:
            self._type = DataType[type] if isinstance(type, str) else type
        except KeyError:
            raise MlflowException(
                f"Unsupported type '{type}', expected instance of DataType or "
                f"one of {[t.name for t in DataType]}"
            )
        if not isinstance(self.type, (DataType, Array, Object, Map, AnyType)):
            raise TypeError(EXPECTED_TYPE_MESSAGE.format(arg_name="type", passed_type=self.type))

    @property
    def type(self) -> Union[DataType, Array, Object, Map, AnyType]:
        """The column data type."""
        return self._type

    @property
    def name(self) -> Optional[str]:
        """The column name or None if the columns is unnamed."""
        return self._name

    @name.setter
    def name(self, value: Optional[str]) -> None:
        self._name = value

    @property
    def required(self) -> bool:
        """Whether this column is required."""
        return self._required

    def to_dict(self) -> dict[str, Any]:
        d = {"type": self.type.name} if isinstance(self.type, DataType) else self.type.to_dict()
        if self.name is not None:
            d["name"] = self.name
        d["required"] = self.required
        return d

    def __eq__(self, other) -> bool:
        if isinstance(other, ColSpec):
            names_eq = (self.name is None and other.name is None) or self.name == other.name
            return names_eq and self.type == other.type and self.required == other.required
        return False

    def __repr__(self) -> str:
        required = "required" if self.required else "optional"
        if self.name is None:
            return f"{self.type!r} ({required})"
        return f"{self.name!r}: {self.type!r} ({required})"

    @classmethod
    def from_json_dict(cls, **kwargs):
        """
        Deserialize from a json loaded dictionary.
        The dictionary is expected to contain `type` and
        optional `name` and `required` keys. A missing `required` defaults to True for
        every column type.

        Raises:
            MlflowException: If the `type` key is missing or the nested type is malformed.
        """
        if not {"type"} <= set(kwargs.keys()):
            raise MlflowException("Missing keys in ColSpec JSON. Expected to find key `type`")
        if kwargs["type"] not in [ARRAY_TYPE, OBJECT_TYPE, MAP_TYPE, SPARKML_VECTOR_TYPE, ANY_TYPE]:
            return cls(**kwargs)
        name = kwargs.pop("name", None)
        # Default to True, matching __init__ and the scalar path above; defaulting to None
        # would produce a column whose `required` is neither True nor False.
        required = kwargs.pop("required", True)
        if kwargs["type"] == ARRAY_TYPE:
            return cls(name=name, type=Array.from_json_dict(**kwargs), required=required)
        if kwargs["type"] == OBJECT_TYPE:
            return cls(
                name=name,
                type=Object.from_json_dict(**kwargs),
                required=required,
            )
        if kwargs["type"] == MAP_TYPE:
            return cls(name=name, type=Map.from_json_dict(**kwargs), required=required)
        if kwargs["type"] == SPARKML_VECTOR_TYPE:
            return cls(name=name, type=SparkMLVector(), required=required)
        if kwargs["type"] == ANY_TYPE:
            return cls(name=name, type=AnyType(), required=required)
|
808
|
+
|
809
|
+
|
810
|
+
class TensorInfo:
    """
    Representation of the shape and type of a Tensor.

    Holds a numpy dtype (unsized for flexible string/bytes types) and a shape stored as a tuple.
    """

    def __init__(self, dtype: np.dtype, shape: Union[tuple[Any, ...], list[Any]]):
        if not isinstance(dtype, np.dtype):
            raise TypeError(
                f"Expected `dtype` to be instance of `{np.dtype}`, received `{dtype.__class__}`"
            )
        # A sized flexible dtype ("U"/"S" with a length) has trailing digits in its name;
        # only the unsized form is accepted.
        is_sized_flexible = dtype.char in ("U", "S") and not dtype.name.isalpha()
        if is_sized_flexible:
            unsized_name = dtype.name.rstrip(string.digits)
            raise MlflowException(
                "MLflow does not support size information in flexible numpy data types. Use"
                f' np.dtype("{unsized_name}") instead'
            )

        if not isinstance(shape, (tuple, list)):
            raise TypeError(
                f"Expected `shape` to be instance of `{tuple}` or `{list}`, "
                f"received `{shape.__class__}`"
            )
        self._dtype = dtype
        self._shape = tuple(shape)

    @property
    def dtype(self) -> np.dtype:
        """
        A unique character code for each of the 21 different numpy built-in types.
        See https://numpy.org/devdocs/reference/generated/numpy.dtype.html#numpy.dtype for details.
        """
        return self._dtype

    @property
    def shape(self) -> tuple[int, ...]:
        """The tensor shape, as a tuple."""
        return self._shape

    def to_dict(self) -> dict[str, Any]:
        return {"dtype": self._dtype.name, "shape": self._shape}

    @classmethod
    def from_json_dict(cls, **kwargs):
        """
        Build a TensorInfo from its dictionary form, which must contain `dtype` (a numpy dtype
        name) and `shape` (a sequence).
        """
        if not {"dtype", "shape"}.issubset(kwargs):
            raise MlflowException(
                "Missing keys in TensorSpec JSON. Expected to find keys `dtype` and `shape`"
            )
        return cls(np.dtype(kwargs["dtype"]), tuple(kwargs["shape"]))

    def __repr__(self) -> str:
        return f"Tensor({self.dtype.name!r}, {self.shape!r})"
|
868
|
+
|
869
|
+
|
870
|
+
class TensorSpec:
    """
    Specification of a dataset stored as a single, optionally named, tensor.
    """

    def __init__(
        self,
        type: np.dtype,
        shape: Union[tuple[int, ...], list[int]],
        name: Optional[str] = None,
    ):
        """
        Args:
            type: Numpy dtype of the tensor.
            shape: Tensor shape (tuple or list).
            name: Optional tensor name.
        """
        self._name = name
        # dtype/shape validation is delegated to TensorInfo.
        self._tensorInfo = TensorInfo(type, shape)

    @property
    def type(self) -> np.dtype:
        """
        The numpy dtype of the tensor.
        See https://numpy.org/devdocs/reference/generated/numpy.dtype.html#numpy.dtype for details.
        """
        return self._tensorInfo.dtype

    @property
    def name(self) -> Optional[str]:
        """The tensor name, or None for an unnamed tensor."""
        return self._name

    @property
    def shape(self) -> tuple[int, ...]:
        """The tensor shape."""
        return self._tensorInfo.shape

    @property
    def required(self) -> bool:
        """Tensors are always required."""
        return True

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a dict; the `name` key is present only for named tensors."""
        spec = {"type": "tensor", "tensor-spec": self._tensorInfo.to_dict()}
        if self.name is not None:
            return {"name": self.name, **spec}
        return spec

    @classmethod
    def from_json_dict(cls, **kwargs):
        """
        Build a TensorSpec from a json-loaded dictionary with `type` and `tensor-spec` keys
        (and an optional `name`).

        Raises:
            MlflowException: If a key is missing or `type` is not "tensor".
        """
        if "tensor-spec" not in kwargs or "type" not in kwargs:
            raise MlflowException(
                "Missing keys in TensorSpec JSON. Expected to find keys `tensor-spec` and `type`"
            )
        if kwargs["type"] != "tensor":
            raise MlflowException("Type mismatch, TensorSpec expects `tensor` as the type")
        info = TensorInfo.from_json_dict(**kwargs["tensor-spec"])
        return cls(info.dtype, info.shape, kwargs.get("name"))

    def __eq__(self, other) -> bool:
        if not isinstance(other, TensorSpec):
            return False
        return self.name == other.name and self.type == other.type and self.shape == other.shape

    def __repr__(self) -> str:
        info_repr = repr(self._tensorInfo)
        return info_repr if self.name is None else f"{self.name!r}: {info_repr}"
|
941
|
+
|
942
|
+
|
943
|
+
class Schema:
    """
    Specification of a dataset.

    Schema is represented as a list of :py:class:`ColSpec` or :py:class:`TensorSpec`. A combination
    of `ColSpec` and `TensorSpec` is not allowed.

    The dataset represented by a schema can be named, with unique non empty names for every input
    (name uniqueness is not checked by the constructor). In the case of :py:class:`ColSpec`, the
    dataset columns can be unnamed with implicit integer index defined by their list indices.
    Combination of named and unnamed data inputs are not allowed.
    """

    def __init__(self, inputs: list[Union[ColSpec, TensorSpec]]):
        """
        Args:
            inputs: Non-empty list made up entirely of ``ColSpec`` or entirely of ``TensorSpec``.

        Raises:
            MlflowException: If ``inputs`` is not a non-empty list, mixes named and unnamed
                entries, mixes ``ColSpec`` and ``TensorSpec``, holds several unnamed
                ``TensorSpec`` entries, or holds unnamed optional entries.
        """
        if not isinstance(inputs, list):
            raise MlflowException.invalid_parameter_value(
                f"Inputs of Schema must be a list, got type {type(inputs).__name__}"
            )
        if not inputs:
            raise MlflowException.invalid_parameter_value(
                "Creating Schema with empty inputs is not allowed."
            )

        if not (all(x.name is None for x in inputs) or all(x.name is not None for x in inputs)):
            raise MlflowException(
                "Creating Schema with a combination of named and unnamed inputs "
                f"is not allowed. Got input names {[x.name for x in inputs]}"
            )
        if not (
            all(isinstance(x, TensorSpec) for x in inputs)
            or all(isinstance(x, ColSpec) for x in inputs)
        ):
            raise MlflowException(
                f"Creating Schema with a combination of {ColSpec.__name__} and "
                f"{TensorSpec.__name__} is not supported. "
                f"Please choose one of {ColSpec.__name__} or {TensorSpec.__name__}"
            )
        if (
            all(isinstance(x, TensorSpec) for x in inputs)
            and len(inputs) > 1
            and any(x.name is None for x in inputs)
        ):
            raise MlflowException(
                "Creating Schema with multiple unnamed TensorSpecs is not supported. "
                "Please provide names for each TensorSpec."
            )
        if all(x.name is None for x in inputs) and any(x.required is False for x in inputs):
            raise MlflowException(
                "Creating Schema with unnamed optional inputs is not supported. "
                "Please name all inputs or make all inputs required."
            )
        self._inputs = inputs

    def __len__(self):
        return len(self._inputs)

    def __iter__(self):
        return iter(self._inputs)

    @property
    def inputs(self) -> list[Union[ColSpec, TensorSpec]]:
        """Representation of a dataset that defines this schema (the stored list itself)."""
        return self._inputs

    def is_tensor_spec(self) -> bool:
        """Return true iff this schema is specified using TensorSpec"""
        return bool(self.inputs) and isinstance(self.inputs[0], TensorSpec)

    def input_names(self) -> list[Union[str, int]]:
        """Get list of data names or range of indices if the schema has no names."""
        return [x.name or i for i, x in enumerate(self.inputs)]

    def required_input_names(self) -> list[Union[str, int]]:
        """Get list of required data names or range of indices if schema has no names."""
        return [x.name or i for i, x in enumerate(self.inputs) if x.required]

    def optional_input_names(self) -> list[Union[str, int]]:
        """Get list of optional data names or range of indices if schema has no names."""
        return [x.name or i for i, x in enumerate(self.inputs) if not x.required]

    def has_input_names(self) -> bool:
        """Return true iff this schema declares names, false otherwise."""
        return bool(self.inputs) and self.inputs[0].name is not None

    def input_types(self) -> list[Union[DataType, np.dtype, Array, Object]]:
        """Get types for each column in the schema."""
        return [x.type for x in self.inputs]

    def input_types_dict(self) -> dict[str, Union[DataType, np.dtype, Array, Object]]:
        """
        Maps column names to types, iff this schema declares names.

        Raises:
            MlflowException: If the schema has no input names.
        """
        if not self.has_input_names():
            raise MlflowException("Cannot get input types as a dict for schema without names.")
        return {x.name: x.type for x in self.inputs}

    def input_dict(self) -> dict[str, Union[ColSpec, TensorSpec]]:
        """
        Maps column names to inputs, iff this schema declares names.

        Raises:
            MlflowException: If the schema has no input names.
        """
        if not self.has_input_names():
            raise MlflowException("Cannot get input dict for schema without names.")
        return {x.name: x for x in self.inputs}

    def numpy_types(self) -> list[np.dtype]:
        """
        Convenience shortcut to get the datatypes as numpy types.

        Raises:
            MlflowException: If some ColSpec type is not a plain DataType (e.g. Array/Object).
        """
        if self.is_tensor_spec():
            return [x.type for x in self.inputs]
        if all(isinstance(x.type, DataType) for x in self.inputs):
            return [x.type.to_numpy() for x in self.inputs]
        raise MlflowException(
            "Failed to get numpy types as some of the inputs types are not DataType."
        )

    def pandas_types(self) -> list[np.dtype]:
        """
        Convenience shortcut to get the datatypes as pandas types. Unsupported by TensorSpec.

        Raises:
            MlflowException: For tensor schemas, or if some type is not a plain DataType.
        """
        if self.is_tensor_spec():
            raise MlflowException("TensorSpec only supports numpy types, use numpy_types() instead")
        if all(isinstance(x.type, DataType) for x in self.inputs):
            return [x.type.to_pandas() for x in self.inputs]
        raise MlflowException(
            "Failed to get pandas types as some of the inputs types are not DataType."
        )

    def as_spark_schema(self):
        """Convert to Spark schema. If this schema is a single unnamed column, it is converted
        directly the corresponding spark data type, otherwise it's returned as a struct (missing
        column names are filled with an integer sequence).
        Unsupported by TensorSpec.
        """
        if self.is_tensor_spec():
            raise MlflowException("TensorSpec cannot be converted to spark dataframe")
        if len(self.inputs) == 1 and self.inputs[0].name is None:
            return self.inputs[0].type.to_spark()
        from pyspark.sql.types import StructField, StructType

        return StructType(
            [
                StructField(
                    name=col.name or str(i), dataType=col.type.to_spark(), nullable=not col.required
                )
                for i, col in enumerate(self.inputs)
            ]
        )

    def to_json(self) -> str:
        """Serialize into json string."""
        return json.dumps([x.to_dict() for x in self.inputs])

    def to_dict(self) -> list[dict[str, Any]]:
        """Serialize into a jsonable dictionary."""
        return [x.to_dict() for x in self.inputs]

    @classmethod
    def from_json(cls, json_str: str):
        """Deserialize from a json string produced by :py:meth:`to_json`."""

        def read_input(x: dict[str, Any]):
            # Entries whose `type` is "tensor" are TensorSpecs; everything else is a ColSpec.
            return (
                TensorSpec.from_json_dict(**x)
                if x["type"] == "tensor"
                else ColSpec.from_json_dict(**x)
            )

        return cls([read_input(x) for x in json.loads(json_str)])

    def __eq__(self, other) -> bool:
        if isinstance(other, Schema):
            return self.inputs == other.inputs
        else:
            return False

    def __repr__(self) -> str:
        return repr(self.inputs)
|
1112
|
+
|
1113
|
+
|
1114
|
+
class ParamSpec:
    """
    Specification used to represent parameters for the model.
    """

    def __init__(
        self,
        name: str,
        dtype: Union[DataType, Object, str],
        default: Any,
        shape: Optional[tuple[int, ...]] = None,
    ):
        """
        Args:
            name: Parameter name (converted with ``str``).
            dtype: A :py:class:`DataType`, an :py:class:`Object`, or the name of a DataType.
            default: Default value; it is validated against ``dtype``/``shape`` and the value
                returned by validation is what gets stored.
            shape: ``None`` for a scalar/dictionary parameter or ``(-1,)`` for a 1D array.

        Raises:
            MlflowException: If ``dtype`` names an unknown type, is ``binary``, or ``default``
                does not match ``dtype``/``shape``.
            TypeError: If ``dtype`` is not a DataType, Object or str.
        """
        self._name = str(name)
        self._shape = tuple(shape) if shape is not None else None

        try:
            self._dtype = DataType[dtype] if isinstance(dtype, str) else dtype
        except KeyError:
            supported_types = [t.name for t in DataType if t.name != "binary"]
            raise MlflowException.invalid_parameter_value(
                f"Unsupported type '{dtype}', expected instance of DataType or "
                f"one of {supported_types}",
            ) from None
        if not isinstance(self.dtype, (DataType, Object)):
            raise TypeError(f"'dtype' must be DataType, Object or str, got {self.dtype}")
        if self.dtype == DataType.binary:
            raise MlflowException.invalid_parameter_value(
                f"Binary type is not supported for parameters, ParamSpec '{self.name}' "
                "has dtype 'binary'",
            )

        # Assign the raw default first so repr(self), used as the spec label in validation
        # error messages, works before the validated value is stored.
        self._default = default
        self._default = self.validate_type_and_shape(repr(self), default, self.dtype, self.shape)

    @classmethod
    def validate_param_spec(cls, value: Any, param_spec: "ParamSpec"):
        """Validate ``value`` against ``param_spec``'s dtype and shape; see
        :py:meth:`validate_type_and_shape`."""
        return cls.validate_type_and_shape(
            repr(param_spec), value, param_spec.dtype, param_spec.shape
        )

    @classmethod
    def validate_type_and_shape(
        cls,
        spec: str,
        value: Any,
        value_type: Union[DataType, Object],
        shape: Optional[tuple[int, ...]],
    ):
        """
        Validate that the value has the expected type and shape.

        Args:
            spec: Human-readable label of the parameter, used in error messages.
            value: Value to validate.
            value_type: Expected :py:class:`DataType` or :py:class:`Object`.
            shape: ``None`` for a scalar/dictionary value or ``(-1,)`` for a 1D array.

        Returns:
            For a DataType scalar, the result of ``_enforce_param_datatype``; for a 1D array,
            a list of per-element results; for an Object, the original ``value`` unchanged.

        Raises:
            MlflowException: If the value does not match the type or shape, or ``shape`` is
                not one of the supported values.
        """
        from mlflow.models.utils import _enforce_object, _enforce_param_datatype

        def _safe_ndim(candidate):
            # numpy >= 1.24 raises ValueError for ragged nested sequences instead of building
            # an object array; report that as "not 1D" instead of leaking the ValueError.
            try:
                return np.array(candidate).ndim
            except ValueError:
                return None

        if shape == (-1,):
            ndim = _safe_ndim(value)
            if not (isinstance(value, (list, np.ndarray)) and ndim == 1):
                received = "a ragged (inhomogeneous) shape" if ndim is None else f"ndim {ndim}"
                raise MlflowException.invalid_parameter_value(
                    f"Value must be a 1D array with shape (-1,) for param {spec}, "
                    f"received {type(value).__name__} with {received}",
                )

        try:
            if shape is None:
                if isinstance(value_type, DataType):
                    return _enforce_param_datatype(value, value_type)
                elif isinstance(value_type, Object):
                    # deepcopy to make sure the value is not mutated
                    # use _enforce_object to validate that the value matches the object schema.
                    # return the original value to preserve its type, as validation may cast it
                    # to a numpy type, but models require the original parameter type.
                    # TODO: we will drop data conversion for params in the future, including
                    # the current allowed conversions in _enforce_param_datatype
                    _enforce_object(deepcopy(value), value_type)
                    return value
            elif shape == (-1,):
                return [_enforce_param_datatype(v, value_type) for v in value]
        except Exception as e:
            raise MlflowException.invalid_parameter_value(
                f"Failed to validate type and shape for {spec}, error: {e}"
            ) from e

        raise MlflowException.invalid_parameter_value(
            "Shape must be None for scalar or dictionary value, or (-1,) for 1D array value "
            f"for ParamSpec {spec}, received {shape}",
        )

    @property
    def name(self) -> str:
        """The name of the parameter."""
        return self._name

    @property
    def dtype(self) -> Union[DataType, Object]:
        """The parameter data type."""
        return self._dtype

    @property
    def default(self) -> Any:
        """Default value of the parameter (as returned by validation)."""
        return self._default

    @property
    def shape(self) -> Optional[tuple[int, ...]]:
        """
        The parameter shape.
        If shape is None, the parameter is a scalar.
        """
        return self._shape

    class ParamSpecTypedDict(TypedDict):
        name: str
        type: str
        default: Union[DataType, list[DataType], None]
        shape: Optional[tuple[int, ...]]

    def to_dict(self) -> ParamSpecTypedDict:
        """Serialize into a jsonable dictionary; datetime defaults are written as ISO strings."""
        is_datetime = isinstance(self.dtype, DataType) and self.dtype.name == "datetime"
        if self.shape is None:
            # Guard None in case validation ever accepts it as a datetime default.
            if is_datetime and self.default is not None:
                default_value = self.default.isoformat()
            else:
                default_value = self.default
        elif self.shape == (-1,):
            default_value = [v.isoformat() for v in self.default] if is_datetime else self.default
        else:
            # __init__ only accepts None or (-1,); kept so default_value is always bound.
            default_value = self.default
        result = {
            "name": self.name,
            "default": default_value,
            "shape": self.shape,
        }
        if isinstance(self.dtype, DataType):
            type_dict = {"type": self.dtype.name}
        elif isinstance(self.dtype, Object):
            type_dict = self.dtype.to_dict()
        result.update(type_dict)
        return result

    def __eq__(self, other) -> bool:
        if isinstance(other, ParamSpec):
            return (
                self.name == other.name
                and self.dtype == other.dtype
                and self.default == other.default
                and self.shape == other.shape
            )
        return False

    def __repr__(self) -> str:
        shape = f" (shape: {self.shape})" if self.shape is not None else ""
        return f"{self.name!r}: {self.dtype!r} (default: {self.default}){shape}"

    @classmethod
    def from_json_dict(cls, **kwargs):
        """
        Deserialize from a json loaded dictionary.
        The dictionary is expected to contain `name`, `type` and `default` keys.

        Raises:
            MlflowException: If required keys are missing or the type name is unsupported.
        """
        # For backward compatibility, we accept both `type` and `dtype` keys
        required_keys1 = {"name", "dtype", "default"}
        required_keys2 = {"name", "type", "default"}

        if not (required_keys1.issubset(kwargs) or required_keys2.issubset(kwargs)):
            raise MlflowException.invalid_parameter_value(
                "Missing keys in ParamSpec JSON. Expected to find "
                "keys `name`, `type`(or `dtype`) and `default`. "
                f"Received keys: {list(kwargs)}"
            )
        dtype = kwargs.get("type") or kwargs.get("dtype")
        # Plain type names are passed through as strings so that __init__ resolves them and
        # reports unknown names as an MlflowException listing the supported types.
        if dtype == OBJECT_TYPE:
            dtype = Object.from_json_dict(**kwargs)
        return cls(
            name=str(kwargs["name"]),
            dtype=dtype,
            default=kwargs["default"],
            shape=kwargs.get("shape"),
        )
|
1293
|
+
|
1294
|
+
|
1295
|
+
class ParamSchema:
    """
    Specification of parameters applicable to the model.
    ParamSchema is represented as a list of :py:class:`ParamSpec`.
    """

    def __init__(self, params: list[ParamSpec]):
        """
        Args:
            params: Parameter specifications; names must be unique.

        Raises:
            MlflowException: If an element is not a ParamSpec or a name is repeated.
        """
        if not all(isinstance(x, ParamSpec) for x in params):
            raise MlflowException.invalid_parameter_value(
                f"ParamSchema inputs only accept {ParamSpec.__name__}"
            )
        if duplicates := self._find_duplicates(params):
            raise MlflowException.invalid_parameter_value(
                f"Duplicated parameters found in schema: {duplicates}"
            )
        self._params = params

    @staticmethod
    def _find_duplicates(params: list[ParamSpec]) -> list[str]:
        """Return every repeated parameter name once, in the order its first repeat appears."""
        seen = set()
        duplicates = []
        for param_spec in params:
            name = param_spec.name
            if name not in seen:
                seen.add(name)
            elif name not in duplicates:
                duplicates.append(name)
        return duplicates

    def __len__(self):
        return len(self._params)

    def __iter__(self):
        return iter(self._params)

    @property
    def params(self) -> list[ParamSpec]:
        """Representation of ParamSchema as a list of ParamSpec."""
        return self._params

    def to_json(self) -> str:
        """Serialize into json string."""
        return json.dumps(self.to_dict())

    @classmethod
    def from_json(cls, json_str: str):
        """Deserialize from a json string."""
        return cls([ParamSpec.from_json_dict(**x) for x in json.loads(json_str)])

    def to_dict(self) -> list[dict[str, Any]]:
        """Serialize into a jsonable dictionary."""
        return [x.to_dict() for x in self.params]

    def __eq__(self, other) -> bool:
        if isinstance(other, ParamSchema):
            return self.params == other.params
        return False

    def __repr__(self) -> str:
        return repr(self.params)
|
1355
|
+
|
1356
|
+
|
1357
|
+
# Exact-type lookup from Python field annotations to MLflow type names. Because lookup is by
# exact type, ``datetime.datetime`` is listed explicitly: the ``datetime.date`` entry does not
# cover its subclass.
_DATACLASS_FIELD_TYPE_MAPPING = {
    bool: "boolean",
    int: "long",  # int is mapped to long to support 64-bit integers
    builtins.float: "float",
    str: "string",
    bytes: "binary",
    dt.date: "datetime",
    dt.datetime: "datetime",
}


def _map_field_type(field):
    """
    Map a basic Python type annotation to an MLflow type name.

    Args:
        field: The annotation, e.g. ``int`` or ``str``.

    Returns:
        The MLflow type name (e.g. ``"long"``), or ``None`` if the type is not supported.
    """
    return _DATACLASS_FIELD_TYPE_MAPPING.get(field)
|
1367
|
+
|
1368
|
+
|
1369
|
+
def _get_dataclass_annotations(cls) -> dict[str, Any]:
    """
    Given a dataclass or an instance of one, collect annotations from it and all its parent
    dataclasses.

    Annotations are merged from the most basic class to the most derived one, so a subclass that
    re-annotates a field overrides its parent's annotation. Bases that are not dataclasses are
    skipped.

    Raises:
        TypeError: If ``cls`` is neither a dataclass nor an instance of one.
    """
    if not is_dataclass(cls):
        # ``cls`` may be an arbitrary instance, which has no ``__name__`` of its own.
        name = getattr(cls, "__name__", type(cls).__name__)
        raise TypeError(f"{name} is not a dataclass.")

    effective_class = cls if isinstance(cls, type) else type(cls)
    annotations = {}
    # Reverse MRO so subclass overrides are captured last
    for base in reversed(effective_class.__mro__):
        # Only capture supers that are dataclasses
        if is_dataclass(base) and hasattr(base, "__annotations__"):
            annotations.update(base.__annotations__)
    return annotations
|
1386
|
+
|
1387
|
+
|
1388
|
+
@experimental(version="2.13.0")
def convert_dataclass_to_schema(dataclass):
    """
    Converts a given dataclass into a Schema object. The dataclass must include type hints
    for all its fields. Fields can be of basic types, other dataclasses, or Lists/Optional of
    these types. Union types other than Optional are not supported. Only the top-level fields are
    directly converted to ColSpecs, while nested fields are converted into nested Object types.

    ``Optional[X]`` may also be written ``X | None`` (Python 3.10+); optional fields become
    non-required columns. Fields inherited from dataclass base classes are included.

    Args:
        dataclass: A dataclass type or an instance of one.

    Returns:
        A :py:class:`Schema` with one :py:class:`ColSpec` per annotated field.

    Raises:
        MlflowException: If a field uses a non-Optional union, an unparameterized list, or a
            type with no MLflow equivalent.
    """
    import types

    # ``X | None`` produces ``types.UnionType`` (Python 3.10+) rather than ``typing.Union``.
    union_origins = (Union, getattr(types, "UnionType", Union))
    # Instances are accepted too, and they have no ``__name__`` of their own.
    class_name = getattr(dataclass, "__name__", type(dataclass).__name__)

    inputs = []
    for field_name, field_type in _get_dataclass_annotations(dataclass).items():
        # Determine the type and handle Optional and List correctly
        is_optional = False
        effective_type = field_type

        if get_origin(field_type) in union_origins:
            union_args = get_args(field_type)
            if type(None) in union_args and len(union_args) == 2:
                # This is an Optional type; determine the effective type excluding None
                is_optional = True
                effective_type = next(t for t in union_args if t is not type(None))
            else:
                raise MlflowException(
                    "Only Optional[...] is supported as a Union type in dataclass fields"
                )

        if get_origin(effective_type) == list:
            list_args = get_args(effective_type)
            if not list_args:
                # A bare ``List`` annotation has no element type to convert.
                raise MlflowException(
                    f"List field '{field_name}' in dataclass {class_name} must declare its "
                    "element type, e.g. list[int]"
                )
            list_type = list_args[0]
            if is_dataclass(list_type):
                dtype = Array(dtype=_convert_dataclass_to_nested_object(list_type))
            elif element_type := _map_field_type(list_type):
                dtype = Array(dtype=element_type)
            else:
                raise MlflowException(
                    f"List field type {list_type} is not supported in dataclass"
                    f" {class_name}"
                )
        elif is_dataclass(effective_type):
            # It's a nested dataclass
            dtype = _convert_dataclass_to_nested_object(effective_type)
        elif basic_type := _map_field_type(effective_type):
            dtype = basic_type
        else:
            raise MlflowException(
                f"Unsupported field type {effective_type} in dataclass {class_name}"
            )

        inputs.append(ColSpec(type=dtype, name=field_name, required=not is_optional))

    return Schema(inputs=inputs)
|
1462
|
+
|
1463
|
+
|
1464
|
+
def _convert_dataclass_to_nested_object(dataclass):
    """
    Convert a nested dataclass to an Object type used within a ColSpec.

    One Property is produced per annotated field. Fields inherited from dataclass base classes
    are included, matching how top-level fields are collected.
    """
    return Object(
        properties=[
            _convert_field_to_property(field_name, field_type)
            for field_name, field_type in _get_dataclass_annotations(dataclass).items()
        ]
    )
|
1472
|
+
|
1473
|
+
|
1474
|
+
def _convert_field_to_property(field_name, field_type):
    """
    Helper function to convert a single field to a Property object suitable for inclusion in an
    Object.

    Supports basic types, nested dataclasses, lists of either, and ``Optional[X]`` /
    ``X | None`` wrappers (which make the property not required).

    Raises:
        MlflowException: For non-Optional unions, unparameterized lists, or types with no
            MLflow equivalent.
    """
    import types

    # ``X | None`` produces ``types.UnionType`` (Python 3.10+) rather than ``typing.Union``.
    union_origins = (Union, getattr(types, "UnionType", Union))

    is_optional = False
    effective_type = field_type

    if get_origin(field_type) in union_origins:
        union_args = get_args(field_type)
        if type(None) in union_args and len(union_args) == 2:
            is_optional = True
            effective_type = next(t for t in union_args if t is not type(None))
        else:
            raise MlflowException(
                "Only Optional[...] is supported as a Union type in dataclass fields"
            )

    if get_origin(effective_type) == list:
        list_args = get_args(effective_type)
        if not list_args:
            raise MlflowException(
                f"List field '{field_name}' must declare its element type, e.g. list[int]"
            )
        list_type = list_args[0]
        if is_dataclass(list_type):
            dtype = Array(dtype=_convert_dataclass_to_nested_object(list_type))
        elif element_type := _map_field_type(list_type):
            dtype = Array(dtype=element_type)
        else:
            raise MlflowException(
                f"List field type {list_type} is not supported for field '{field_name}'"
            )
    elif is_dataclass(effective_type):
        dtype = _convert_dataclass_to_nested_object(effective_type)
    elif basic_type := _map_field_type(effective_type):
        dtype = basic_type
    else:
        raise MlflowException(
            f"Unsupported field type {effective_type} for field '{field_name}'"
        )

    return Property(name=field_name, dtype=dtype, required=not is_optional)
|