genesis-flow 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genesis_flow-1.0.0.dist-info/METADATA +822 -0
- genesis_flow-1.0.0.dist-info/RECORD +645 -0
- genesis_flow-1.0.0.dist-info/WHEEL +5 -0
- genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
- genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
- genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
- mlflow/__init__.py +367 -0
- mlflow/__main__.py +3 -0
- mlflow/ag2/__init__.py +56 -0
- mlflow/ag2/ag2_logger.py +294 -0
- mlflow/anthropic/__init__.py +40 -0
- mlflow/anthropic/autolog.py +129 -0
- mlflow/anthropic/chat.py +144 -0
- mlflow/artifacts/__init__.py +268 -0
- mlflow/autogen/__init__.py +144 -0
- mlflow/autogen/chat.py +142 -0
- mlflow/azure/__init__.py +26 -0
- mlflow/azure/auth_handler.py +257 -0
- mlflow/azure/client.py +319 -0
- mlflow/azure/config.py +120 -0
- mlflow/azure/connection_factory.py +340 -0
- mlflow/azure/exceptions.py +27 -0
- mlflow/azure/stores.py +327 -0
- mlflow/azure/utils.py +183 -0
- mlflow/bedrock/__init__.py +45 -0
- mlflow/bedrock/_autolog.py +202 -0
- mlflow/bedrock/chat.py +122 -0
- mlflow/bedrock/stream.py +160 -0
- mlflow/bedrock/utils.py +43 -0
- mlflow/cli.py +707 -0
- mlflow/client.py +12 -0
- mlflow/config/__init__.py +56 -0
- mlflow/crewai/__init__.py +79 -0
- mlflow/crewai/autolog.py +253 -0
- mlflow/crewai/chat.py +29 -0
- mlflow/data/__init__.py +75 -0
- mlflow/data/artifact_dataset_sources.py +170 -0
- mlflow/data/code_dataset_source.py +40 -0
- mlflow/data/dataset.py +123 -0
- mlflow/data/dataset_registry.py +168 -0
- mlflow/data/dataset_source.py +110 -0
- mlflow/data/dataset_source_registry.py +219 -0
- mlflow/data/delta_dataset_source.py +167 -0
- mlflow/data/digest_utils.py +108 -0
- mlflow/data/evaluation_dataset.py +562 -0
- mlflow/data/filesystem_dataset_source.py +81 -0
- mlflow/data/http_dataset_source.py +145 -0
- mlflow/data/huggingface_dataset.py +258 -0
- mlflow/data/huggingface_dataset_source.py +118 -0
- mlflow/data/meta_dataset.py +104 -0
- mlflow/data/numpy_dataset.py +223 -0
- mlflow/data/pandas_dataset.py +231 -0
- mlflow/data/polars_dataset.py +352 -0
- mlflow/data/pyfunc_dataset_mixin.py +31 -0
- mlflow/data/schema.py +76 -0
- mlflow/data/sources.py +1 -0
- mlflow/data/spark_dataset.py +406 -0
- mlflow/data/spark_dataset_source.py +74 -0
- mlflow/data/spark_delta_utils.py +118 -0
- mlflow/data/tensorflow_dataset.py +350 -0
- mlflow/data/uc_volume_dataset_source.py +81 -0
- mlflow/db.py +27 -0
- mlflow/dspy/__init__.py +17 -0
- mlflow/dspy/autolog.py +197 -0
- mlflow/dspy/callback.py +398 -0
- mlflow/dspy/constant.py +1 -0
- mlflow/dspy/load.py +93 -0
- mlflow/dspy/save.py +393 -0
- mlflow/dspy/util.py +109 -0
- mlflow/dspy/wrapper.py +226 -0
- mlflow/entities/__init__.py +104 -0
- mlflow/entities/_mlflow_object.py +52 -0
- mlflow/entities/assessment.py +545 -0
- mlflow/entities/assessment_error.py +80 -0
- mlflow/entities/assessment_source.py +141 -0
- mlflow/entities/dataset.py +92 -0
- mlflow/entities/dataset_input.py +51 -0
- mlflow/entities/dataset_summary.py +62 -0
- mlflow/entities/document.py +48 -0
- mlflow/entities/experiment.py +109 -0
- mlflow/entities/experiment_tag.py +35 -0
- mlflow/entities/file_info.py +45 -0
- mlflow/entities/input_tag.py +35 -0
- mlflow/entities/lifecycle_stage.py +35 -0
- mlflow/entities/logged_model.py +228 -0
- mlflow/entities/logged_model_input.py +26 -0
- mlflow/entities/logged_model_output.py +32 -0
- mlflow/entities/logged_model_parameter.py +46 -0
- mlflow/entities/logged_model_status.py +74 -0
- mlflow/entities/logged_model_tag.py +33 -0
- mlflow/entities/metric.py +200 -0
- mlflow/entities/model_registry/__init__.py +29 -0
- mlflow/entities/model_registry/_model_registry_entity.py +13 -0
- mlflow/entities/model_registry/model_version.py +243 -0
- mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
- mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
- mlflow/entities/model_registry/model_version_search.py +25 -0
- mlflow/entities/model_registry/model_version_stages.py +25 -0
- mlflow/entities/model_registry/model_version_status.py +35 -0
- mlflow/entities/model_registry/model_version_tag.py +35 -0
- mlflow/entities/model_registry/prompt.py +73 -0
- mlflow/entities/model_registry/prompt_version.py +244 -0
- mlflow/entities/model_registry/registered_model.py +175 -0
- mlflow/entities/model_registry/registered_model_alias.py +35 -0
- mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
- mlflow/entities/model_registry/registered_model_search.py +25 -0
- mlflow/entities/model_registry/registered_model_tag.py +35 -0
- mlflow/entities/multipart_upload.py +74 -0
- mlflow/entities/param.py +49 -0
- mlflow/entities/run.py +97 -0
- mlflow/entities/run_data.py +84 -0
- mlflow/entities/run_info.py +188 -0
- mlflow/entities/run_inputs.py +59 -0
- mlflow/entities/run_outputs.py +43 -0
- mlflow/entities/run_status.py +41 -0
- mlflow/entities/run_tag.py +36 -0
- mlflow/entities/source_type.py +31 -0
- mlflow/entities/span.py +774 -0
- mlflow/entities/span_event.py +96 -0
- mlflow/entities/span_status.py +102 -0
- mlflow/entities/trace.py +317 -0
- mlflow/entities/trace_data.py +71 -0
- mlflow/entities/trace_info.py +220 -0
- mlflow/entities/trace_info_v2.py +162 -0
- mlflow/entities/trace_location.py +173 -0
- mlflow/entities/trace_state.py +39 -0
- mlflow/entities/trace_status.py +68 -0
- mlflow/entities/view_type.py +51 -0
- mlflow/environment_variables.py +866 -0
- mlflow/evaluation/__init__.py +16 -0
- mlflow/evaluation/assessment.py +369 -0
- mlflow/evaluation/evaluation.py +411 -0
- mlflow/evaluation/evaluation_tag.py +61 -0
- mlflow/evaluation/fluent.py +48 -0
- mlflow/evaluation/utils.py +201 -0
- mlflow/exceptions.py +213 -0
- mlflow/experiments.py +140 -0
- mlflow/gemini/__init__.py +81 -0
- mlflow/gemini/autolog.py +186 -0
- mlflow/gemini/chat.py +261 -0
- mlflow/genai/__init__.py +71 -0
- mlflow/genai/datasets/__init__.py +67 -0
- mlflow/genai/datasets/evaluation_dataset.py +131 -0
- mlflow/genai/evaluation/__init__.py +3 -0
- mlflow/genai/evaluation/base.py +411 -0
- mlflow/genai/evaluation/constant.py +23 -0
- mlflow/genai/evaluation/utils.py +244 -0
- mlflow/genai/judges/__init__.py +21 -0
- mlflow/genai/judges/databricks.py +404 -0
- mlflow/genai/label_schemas/__init__.py +153 -0
- mlflow/genai/label_schemas/label_schemas.py +209 -0
- mlflow/genai/labeling/__init__.py +159 -0
- mlflow/genai/labeling/labeling.py +250 -0
- mlflow/genai/optimize/__init__.py +13 -0
- mlflow/genai/optimize/base.py +198 -0
- mlflow/genai/optimize/optimizers/__init__.py +4 -0
- mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
- mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
- mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
- mlflow/genai/optimize/types.py +75 -0
- mlflow/genai/optimize/util.py +30 -0
- mlflow/genai/prompts/__init__.py +206 -0
- mlflow/genai/scheduled_scorers.py +431 -0
- mlflow/genai/scorers/__init__.py +26 -0
- mlflow/genai/scorers/base.py +492 -0
- mlflow/genai/scorers/builtin_scorers.py +765 -0
- mlflow/genai/scorers/scorer_utils.py +138 -0
- mlflow/genai/scorers/validation.py +165 -0
- mlflow/genai/utils/data_validation.py +146 -0
- mlflow/genai/utils/enum_utils.py +23 -0
- mlflow/genai/utils/trace_utils.py +211 -0
- mlflow/groq/__init__.py +42 -0
- mlflow/groq/_groq_autolog.py +74 -0
- mlflow/johnsnowlabs/__init__.py +888 -0
- mlflow/langchain/__init__.py +24 -0
- mlflow/langchain/api_request_parallel_processor.py +330 -0
- mlflow/langchain/autolog.py +147 -0
- mlflow/langchain/chat_agent_langgraph.py +340 -0
- mlflow/langchain/constant.py +1 -0
- mlflow/langchain/constants.py +1 -0
- mlflow/langchain/databricks_dependencies.py +444 -0
- mlflow/langchain/langchain_tracer.py +597 -0
- mlflow/langchain/model.py +919 -0
- mlflow/langchain/output_parsers.py +142 -0
- mlflow/langchain/retriever_chain.py +153 -0
- mlflow/langchain/runnables.py +527 -0
- mlflow/langchain/utils/chat.py +402 -0
- mlflow/langchain/utils/logging.py +671 -0
- mlflow/langchain/utils/serialization.py +36 -0
- mlflow/legacy_databricks_cli/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/provider.py +482 -0
- mlflow/litellm/__init__.py +175 -0
- mlflow/llama_index/__init__.py +22 -0
- mlflow/llama_index/autolog.py +55 -0
- mlflow/llama_index/chat.py +43 -0
- mlflow/llama_index/constant.py +1 -0
- mlflow/llama_index/model.py +577 -0
- mlflow/llama_index/pyfunc_wrapper.py +332 -0
- mlflow/llama_index/serialize_objects.py +188 -0
- mlflow/llama_index/tracer.py +561 -0
- mlflow/metrics/__init__.py +479 -0
- mlflow/metrics/base.py +39 -0
- mlflow/metrics/genai/__init__.py +25 -0
- mlflow/metrics/genai/base.py +101 -0
- mlflow/metrics/genai/genai_metric.py +771 -0
- mlflow/metrics/genai/metric_definitions.py +450 -0
- mlflow/metrics/genai/model_utils.py +371 -0
- mlflow/metrics/genai/prompt_template.py +68 -0
- mlflow/metrics/genai/prompts/__init__.py +0 -0
- mlflow/metrics/genai/prompts/v1.py +422 -0
- mlflow/metrics/genai/utils.py +6 -0
- mlflow/metrics/metric_definitions.py +619 -0
- mlflow/mismatch.py +34 -0
- mlflow/mistral/__init__.py +34 -0
- mlflow/mistral/autolog.py +71 -0
- mlflow/mistral/chat.py +135 -0
- mlflow/ml_package_versions.py +452 -0
- mlflow/models/__init__.py +97 -0
- mlflow/models/auth_policy.py +83 -0
- mlflow/models/cli.py +354 -0
- mlflow/models/container/__init__.py +294 -0
- mlflow/models/container/scoring_server/__init__.py +0 -0
- mlflow/models/container/scoring_server/nginx.conf +39 -0
- mlflow/models/dependencies_schemas.py +287 -0
- mlflow/models/display_utils.py +158 -0
- mlflow/models/docker_utils.py +211 -0
- mlflow/models/evaluation/__init__.py +23 -0
- mlflow/models/evaluation/_shap_patch.py +64 -0
- mlflow/models/evaluation/artifacts.py +194 -0
- mlflow/models/evaluation/base.py +1811 -0
- mlflow/models/evaluation/calibration_curve.py +109 -0
- mlflow/models/evaluation/default_evaluator.py +996 -0
- mlflow/models/evaluation/deprecated.py +23 -0
- mlflow/models/evaluation/evaluator_registry.py +80 -0
- mlflow/models/evaluation/evaluators/classifier.py +704 -0
- mlflow/models/evaluation/evaluators/default.py +233 -0
- mlflow/models/evaluation/evaluators/regressor.py +96 -0
- mlflow/models/evaluation/evaluators/shap.py +296 -0
- mlflow/models/evaluation/lift_curve.py +178 -0
- mlflow/models/evaluation/utils/metric.py +123 -0
- mlflow/models/evaluation/utils/trace.py +179 -0
- mlflow/models/evaluation/validation.py +434 -0
- mlflow/models/flavor_backend.py +93 -0
- mlflow/models/flavor_backend_registry.py +53 -0
- mlflow/models/model.py +1639 -0
- mlflow/models/model_config.py +150 -0
- mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
- mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
- mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
- mlflow/models/python_api.py +369 -0
- mlflow/models/rag_signatures.py +128 -0
- mlflow/models/resources.py +321 -0
- mlflow/models/signature.py +662 -0
- mlflow/models/utils.py +2054 -0
- mlflow/models/wheeled_model.py +280 -0
- mlflow/openai/__init__.py +57 -0
- mlflow/openai/_agent_tracer.py +364 -0
- mlflow/openai/api_request_parallel_processor.py +131 -0
- mlflow/openai/autolog.py +509 -0
- mlflow/openai/constant.py +1 -0
- mlflow/openai/model.py +824 -0
- mlflow/openai/utils/chat_schema.py +367 -0
- mlflow/optuna/__init__.py +3 -0
- mlflow/optuna/storage.py +646 -0
- mlflow/plugins/__init__.py +72 -0
- mlflow/plugins/base.py +358 -0
- mlflow/plugins/builtin/__init__.py +24 -0
- mlflow/plugins/builtin/pytorch_plugin.py +150 -0
- mlflow/plugins/builtin/sklearn_plugin.py +158 -0
- mlflow/plugins/builtin/transformers_plugin.py +187 -0
- mlflow/plugins/cli.py +321 -0
- mlflow/plugins/discovery.py +340 -0
- mlflow/plugins/manager.py +465 -0
- mlflow/plugins/registry.py +316 -0
- mlflow/plugins/templates/framework_plugin_template.py +329 -0
- mlflow/prompt/constants.py +20 -0
- mlflow/prompt/promptlab_model.py +197 -0
- mlflow/prompt/registry_utils.py +248 -0
- mlflow/promptflow/__init__.py +495 -0
- mlflow/protos/__init__.py +0 -0
- mlflow/protos/assessments_pb2.py +174 -0
- mlflow/protos/databricks_artifacts_pb2.py +489 -0
- mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
- mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
- mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
- mlflow/protos/databricks_pb2.py +267 -0
- mlflow/protos/databricks_trace_server_pb2.py +374 -0
- mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
- mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
- mlflow/protos/facet_feature_statistics_pb2.py +296 -0
- mlflow/protos/internal_pb2.py +77 -0
- mlflow/protos/mlflow_artifacts_pb2.py +336 -0
- mlflow/protos/model_registry_pb2.py +1073 -0
- mlflow/protos/scalapb/__init__.py +0 -0
- mlflow/protos/scalapb/scalapb_pb2.py +104 -0
- mlflow/protos/service_pb2.py +2600 -0
- mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
- mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
- mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
- mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
- mlflow/py.typed +0 -0
- mlflow/pydantic_ai/__init__.py +57 -0
- mlflow/pydantic_ai/autolog.py +173 -0
- mlflow/pyfunc/__init__.py +3844 -0
- mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
- mlflow/pyfunc/backend.py +523 -0
- mlflow/pyfunc/context.py +78 -0
- mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
- mlflow/pyfunc/loaders/__init__.py +7 -0
- mlflow/pyfunc/loaders/chat_agent.py +117 -0
- mlflow/pyfunc/loaders/chat_model.py +125 -0
- mlflow/pyfunc/loaders/code_model.py +31 -0
- mlflow/pyfunc/loaders/responses_agent.py +112 -0
- mlflow/pyfunc/mlserver.py +46 -0
- mlflow/pyfunc/model.py +1473 -0
- mlflow/pyfunc/scoring_server/__init__.py +604 -0
- mlflow/pyfunc/scoring_server/app.py +7 -0
- mlflow/pyfunc/scoring_server/client.py +146 -0
- mlflow/pyfunc/spark_model_cache.py +48 -0
- mlflow/pyfunc/stdin_server.py +44 -0
- mlflow/pyfunc/utils/__init__.py +3 -0
- mlflow/pyfunc/utils/data_validation.py +224 -0
- mlflow/pyfunc/utils/environment.py +22 -0
- mlflow/pyfunc/utils/input_converter.py +47 -0
- mlflow/pyfunc/utils/serving_data_parser.py +11 -0
- mlflow/pytorch/__init__.py +1171 -0
- mlflow/pytorch/_lightning_autolog.py +580 -0
- mlflow/pytorch/_pytorch_autolog.py +50 -0
- mlflow/pytorch/pickle_module.py +35 -0
- mlflow/rfunc/__init__.py +42 -0
- mlflow/rfunc/backend.py +134 -0
- mlflow/runs.py +89 -0
- mlflow/server/__init__.py +302 -0
- mlflow/server/auth/__init__.py +1224 -0
- mlflow/server/auth/__main__.py +4 -0
- mlflow/server/auth/basic_auth.ini +6 -0
- mlflow/server/auth/cli.py +11 -0
- mlflow/server/auth/client.py +537 -0
- mlflow/server/auth/config.py +34 -0
- mlflow/server/auth/db/__init__.py +0 -0
- mlflow/server/auth/db/cli.py +18 -0
- mlflow/server/auth/db/migrations/__init__.py +0 -0
- mlflow/server/auth/db/migrations/alembic.ini +110 -0
- mlflow/server/auth/db/migrations/env.py +76 -0
- mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
- mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
- mlflow/server/auth/db/models.py +67 -0
- mlflow/server/auth/db/utils.py +37 -0
- mlflow/server/auth/entities.py +165 -0
- mlflow/server/auth/logo.py +14 -0
- mlflow/server/auth/permissions.py +65 -0
- mlflow/server/auth/routes.py +18 -0
- mlflow/server/auth/sqlalchemy_store.py +263 -0
- mlflow/server/graphql/__init__.py +0 -0
- mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
- mlflow/server/graphql/graphql_custom_scalars.py +24 -0
- mlflow/server/graphql/graphql_errors.py +15 -0
- mlflow/server/graphql/graphql_no_batching.py +89 -0
- mlflow/server/graphql/graphql_schema_extensions.py +74 -0
- mlflow/server/handlers.py +3217 -0
- mlflow/server/prometheus_exporter.py +17 -0
- mlflow/server/validation.py +30 -0
- mlflow/shap/__init__.py +691 -0
- mlflow/sklearn/__init__.py +1994 -0
- mlflow/sklearn/utils.py +1041 -0
- mlflow/smolagents/__init__.py +66 -0
- mlflow/smolagents/autolog.py +139 -0
- mlflow/smolagents/chat.py +29 -0
- mlflow/store/__init__.py +10 -0
- mlflow/store/_unity_catalog/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/constants.py +2 -0
- mlflow/store/_unity_catalog/registry/__init__.py +6 -0
- mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
- mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
- mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
- mlflow/store/_unity_catalog/registry/utils.py +121 -0
- mlflow/store/artifact/__init__.py +0 -0
- mlflow/store/artifact/artifact_repo.py +472 -0
- mlflow/store/artifact/artifact_repository_registry.py +154 -0
- mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
- mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
- mlflow/store/artifact/cli.py +141 -0
- mlflow/store/artifact/cloud_artifact_repo.py +332 -0
- mlflow/store/artifact/databricks_artifact_repo.py +729 -0
- mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
- mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
- mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
- mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
- mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
- mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
- mlflow/store/artifact/ftp_artifact_repo.py +132 -0
- mlflow/store/artifact/gcs_artifact_repo.py +296 -0
- mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
- mlflow/store/artifact/http_artifact_repo.py +218 -0
- mlflow/store/artifact/local_artifact_repo.py +142 -0
- mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
- mlflow/store/artifact/models_artifact_repo.py +259 -0
- mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
- mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
- mlflow/store/artifact/r2_artifact_repo.py +70 -0
- mlflow/store/artifact/runs_artifact_repo.py +265 -0
- mlflow/store/artifact/s3_artifact_repo.py +330 -0
- mlflow/store/artifact/sftp_artifact_repo.py +141 -0
- mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
- mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
- mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
- mlflow/store/artifact/utils/__init__.py +0 -0
- mlflow/store/artifact/utils/models.py +148 -0
- mlflow/store/db/__init__.py +0 -0
- mlflow/store/db/base_sql_model.py +3 -0
- mlflow/store/db/db_types.py +10 -0
- mlflow/store/db/utils.py +314 -0
- mlflow/store/db_migrations/__init__.py +0 -0
- mlflow/store/db_migrations/alembic.ini +74 -0
- mlflow/store/db_migrations/env.py +84 -0
- mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
- mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
- mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
- mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
- mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
- mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
- mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
- mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
- mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
- mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
- mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
- mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
- mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
- mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
- mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
- mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
- mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
- mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
- mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
- mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
- mlflow/store/db_migrations/versions/__init__.py +0 -0
- mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
- mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
- mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
- mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
- mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
- mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
- mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
- mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
- mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
- mlflow/store/entities/__init__.py +3 -0
- mlflow/store/entities/paged_list.py +18 -0
- mlflow/store/model_registry/__init__.py +10 -0
- mlflow/store/model_registry/abstract_store.py +1081 -0
- mlflow/store/model_registry/base_rest_store.py +44 -0
- mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
- mlflow/store/model_registry/dbmodels/__init__.py +0 -0
- mlflow/store/model_registry/dbmodels/models.py +206 -0
- mlflow/store/model_registry/file_store.py +1091 -0
- mlflow/store/model_registry/rest_store.py +481 -0
- mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
- mlflow/store/tracking/__init__.py +23 -0
- mlflow/store/tracking/abstract_store.py +816 -0
- mlflow/store/tracking/dbmodels/__init__.py +0 -0
- mlflow/store/tracking/dbmodels/initial_models.py +243 -0
- mlflow/store/tracking/dbmodels/models.py +1073 -0
- mlflow/store/tracking/file_store.py +2438 -0
- mlflow/store/tracking/postgres_managed_identity.py +146 -0
- mlflow/store/tracking/rest_store.py +1131 -0
- mlflow/store/tracking/sqlalchemy_store.py +2785 -0
- mlflow/system_metrics/__init__.py +61 -0
- mlflow/system_metrics/metrics/__init__.py +0 -0
- mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
- mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
- mlflow/system_metrics/metrics/disk_monitor.py +21 -0
- mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
- mlflow/system_metrics/metrics/network_monitor.py +34 -0
- mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
- mlflow/system_metrics/system_metrics_monitor.py +198 -0
- mlflow/tracing/__init__.py +16 -0
- mlflow/tracing/assessment.py +356 -0
- mlflow/tracing/client.py +531 -0
- mlflow/tracing/config.py +125 -0
- mlflow/tracing/constant.py +105 -0
- mlflow/tracing/destination.py +81 -0
- mlflow/tracing/display/__init__.py +40 -0
- mlflow/tracing/display/display_handler.py +196 -0
- mlflow/tracing/export/async_export_queue.py +186 -0
- mlflow/tracing/export/inference_table.py +138 -0
- mlflow/tracing/export/mlflow_v3.py +137 -0
- mlflow/tracing/export/utils.py +70 -0
- mlflow/tracing/fluent.py +1417 -0
- mlflow/tracing/processor/base_mlflow.py +199 -0
- mlflow/tracing/processor/inference_table.py +175 -0
- mlflow/tracing/processor/mlflow_v3.py +47 -0
- mlflow/tracing/processor/otel.py +73 -0
- mlflow/tracing/provider.py +487 -0
- mlflow/tracing/trace_manager.py +200 -0
- mlflow/tracing/utils/__init__.py +616 -0
- mlflow/tracing/utils/artifact_utils.py +28 -0
- mlflow/tracing/utils/copy.py +55 -0
- mlflow/tracing/utils/environment.py +55 -0
- mlflow/tracing/utils/exception.py +21 -0
- mlflow/tracing/utils/once.py +35 -0
- mlflow/tracing/utils/otlp.py +63 -0
- mlflow/tracing/utils/processor.py +54 -0
- mlflow/tracing/utils/search.py +292 -0
- mlflow/tracing/utils/timeout.py +250 -0
- mlflow/tracing/utils/token.py +19 -0
- mlflow/tracing/utils/truncation.py +124 -0
- mlflow/tracing/utils/warning.py +76 -0
- mlflow/tracking/__init__.py +39 -0
- mlflow/tracking/_model_registry/__init__.py +1 -0
- mlflow/tracking/_model_registry/client.py +764 -0
- mlflow/tracking/_model_registry/fluent.py +853 -0
- mlflow/tracking/_model_registry/registry.py +67 -0
- mlflow/tracking/_model_registry/utils.py +251 -0
- mlflow/tracking/_tracking_service/__init__.py +0 -0
- mlflow/tracking/_tracking_service/client.py +883 -0
- mlflow/tracking/_tracking_service/registry.py +56 -0
- mlflow/tracking/_tracking_service/utils.py +275 -0
- mlflow/tracking/artifact_utils.py +179 -0
- mlflow/tracking/client.py +5900 -0
- mlflow/tracking/context/__init__.py +0 -0
- mlflow/tracking/context/abstract_context.py +35 -0
- mlflow/tracking/context/databricks_cluster_context.py +15 -0
- mlflow/tracking/context/databricks_command_context.py +15 -0
- mlflow/tracking/context/databricks_job_context.py +49 -0
- mlflow/tracking/context/databricks_notebook_context.py +41 -0
- mlflow/tracking/context/databricks_repo_context.py +43 -0
- mlflow/tracking/context/default_context.py +51 -0
- mlflow/tracking/context/git_context.py +32 -0
- mlflow/tracking/context/registry.py +98 -0
- mlflow/tracking/context/system_environment_context.py +15 -0
- mlflow/tracking/default_experiment/__init__.py +1 -0
- mlflow/tracking/default_experiment/abstract_context.py +43 -0
- mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
- mlflow/tracking/default_experiment/registry.py +75 -0
- mlflow/tracking/fluent.py +3595 -0
- mlflow/tracking/metric_value_conversion_utils.py +93 -0
- mlflow/tracking/multimedia.py +206 -0
- mlflow/tracking/registry.py +86 -0
- mlflow/tracking/request_auth/__init__.py +0 -0
- mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
- mlflow/tracking/request_auth/registry.py +60 -0
- mlflow/tracking/request_header/__init__.py +0 -0
- mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
- mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
- mlflow/tracking/request_header/default_request_header_provider.py +17 -0
- mlflow/tracking/request_header/registry.py +79 -0
- mlflow/transformers/__init__.py +2982 -0
- mlflow/transformers/flavor_config.py +258 -0
- mlflow/transformers/hub_utils.py +83 -0
- mlflow/transformers/llm_inference_utils.py +468 -0
- mlflow/transformers/model_io.py +301 -0
- mlflow/transformers/peft.py +51 -0
- mlflow/transformers/signature.py +183 -0
- mlflow/transformers/torch_utils.py +55 -0
- mlflow/types/__init__.py +21 -0
- mlflow/types/agent.py +270 -0
- mlflow/types/chat.py +240 -0
- mlflow/types/llm.py +935 -0
- mlflow/types/responses.py +139 -0
- mlflow/types/responses_helpers.py +416 -0
- mlflow/types/schema.py +1505 -0
- mlflow/types/type_hints.py +647 -0
- mlflow/types/utils.py +753 -0
- mlflow/utils/__init__.py +283 -0
- mlflow/utils/_capture_modules.py +256 -0
- mlflow/utils/_capture_transformers_modules.py +75 -0
- mlflow/utils/_spark_utils.py +201 -0
- mlflow/utils/_unity_catalog_oss_utils.py +97 -0
- mlflow/utils/_unity_catalog_utils.py +479 -0
- mlflow/utils/annotations.py +218 -0
- mlflow/utils/arguments_utils.py +16 -0
- mlflow/utils/async_logging/__init__.py +1 -0
- mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
- mlflow/utils/async_logging/async_logging_queue.py +366 -0
- mlflow/utils/async_logging/run_artifact.py +38 -0
- mlflow/utils/async_logging/run_batch.py +58 -0
- mlflow/utils/async_logging/run_operations.py +49 -0
- mlflow/utils/autologging_utils/__init__.py +737 -0
- mlflow/utils/autologging_utils/client.py +432 -0
- mlflow/utils/autologging_utils/config.py +33 -0
- mlflow/utils/autologging_utils/events.py +294 -0
- mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
- mlflow/utils/autologging_utils/metrics_queue.py +71 -0
- mlflow/utils/autologging_utils/safety.py +1104 -0
- mlflow/utils/autologging_utils/versioning.py +95 -0
- mlflow/utils/checkpoint_utils.py +206 -0
- mlflow/utils/class_utils.py +6 -0
- mlflow/utils/cli_args.py +257 -0
- mlflow/utils/conda.py +354 -0
- mlflow/utils/credentials.py +231 -0
- mlflow/utils/data_utils.py +17 -0
- mlflow/utils/databricks_utils.py +1436 -0
- mlflow/utils/docstring_utils.py +477 -0
- mlflow/utils/doctor.py +133 -0
- mlflow/utils/download_cloud_file_chunk.py +43 -0
- mlflow/utils/env_manager.py +16 -0
- mlflow/utils/env_pack.py +131 -0
- mlflow/utils/environment.py +1009 -0
- mlflow/utils/exception_utils.py +14 -0
- mlflow/utils/file_utils.py +978 -0
- mlflow/utils/git_utils.py +77 -0
- mlflow/utils/gorilla.py +797 -0
- mlflow/utils/import_hooks/__init__.py +363 -0
- mlflow/utils/lazy_load.py +51 -0
- mlflow/utils/logging_utils.py +168 -0
- mlflow/utils/mime_type_utils.py +58 -0
- mlflow/utils/mlflow_tags.py +103 -0
- mlflow/utils/model_utils.py +486 -0
- mlflow/utils/name_utils.py +346 -0
- mlflow/utils/nfs_on_spark.py +62 -0
- mlflow/utils/openai_utils.py +164 -0
- mlflow/utils/os.py +12 -0
- mlflow/utils/oss_registry_utils.py +29 -0
- mlflow/utils/plugins.py +17 -0
- mlflow/utils/process.py +182 -0
- mlflow/utils/promptlab_utils.py +146 -0
- mlflow/utils/proto_json_utils.py +743 -0
- mlflow/utils/pydantic_utils.py +54 -0
- mlflow/utils/request_utils.py +279 -0
- mlflow/utils/requirements_utils.py +704 -0
- mlflow/utils/rest_utils.py +673 -0
- mlflow/utils/search_logged_model_utils.py +127 -0
- mlflow/utils/search_utils.py +2111 -0
- mlflow/utils/secure_loading.py +221 -0
- mlflow/utils/security_validation.py +384 -0
- mlflow/utils/server_cli_utils.py +61 -0
- mlflow/utils/spark_utils.py +15 -0
- mlflow/utils/string_utils.py +138 -0
- mlflow/utils/thread_utils.py +63 -0
- mlflow/utils/time.py +54 -0
- mlflow/utils/timeout.py +42 -0
- mlflow/utils/uri.py +572 -0
- mlflow/utils/validation.py +662 -0
- mlflow/utils/virtualenv.py +458 -0
- mlflow/utils/warnings_utils.py +25 -0
- mlflow/utils/yaml_utils.py +179 -0
- mlflow/version.py +24 -0
@@ -0,0 +1,580 @@
|
|
1
|
+
import logging
|
2
|
+
import os
|
3
|
+
import tempfile
|
4
|
+
import warnings
|
5
|
+
|
6
|
+
from packaging.version import Version
|
7
|
+
|
8
|
+
import mlflow.pytorch
|
9
|
+
from mlflow.exceptions import MlflowException
|
10
|
+
from mlflow.ml_package_versions import _ML_PACKAGE_VERSIONS
|
11
|
+
from mlflow.utils.autologging_utils import (
|
12
|
+
BatchMetricsLogger,
|
13
|
+
ExceptionSafeAbstractClass,
|
14
|
+
MlflowAutologgingQueueingClient,
|
15
|
+
disable_autologging,
|
16
|
+
get_autologging_config,
|
17
|
+
)
|
18
|
+
from mlflow.utils.checkpoint_utils import MlflowModelCheckpointCallbackBase
|
19
|
+
|
20
|
+
logging.basicConfig(level=logging.ERROR)
|
21
|
+
MIN_REQ_VERSION = Version(_ML_PACKAGE_VERSIONS["pytorch-lightning"]["autologging"]["minimum"])
|
22
|
+
MAX_REQ_VERSION = Version(_ML_PACKAGE_VERSIONS["pytorch-lightning"]["autologging"]["maximum"])
|
23
|
+
|
24
|
+
import pytorch_lightning as pl
|
25
|
+
from pytorch_lightning.utilities import rank_zero_only
|
26
|
+
|
27
|
+
# The following are the downsides of using PyTorch Lightning's built-in MlflowLogger.
|
28
|
+
# 1. MlflowLogger doesn't provide a mechanism to store an entire model into mlflow.
|
29
|
+
# Only model checkpoint is saved.
|
30
|
+
# 2. For storing the model into mlflow `mlflow.pytorch` library is used
|
31
|
+
# and the library expects `mlflow` object to be instantiated.
|
32
|
+
# In case of MlflowLogger, Run management is completely controlled by the class and
|
33
|
+
# hence mlflow object needs to be reinstantiated by setting
|
34
|
+
# tracking uri, experiment_id and run_id which may lead to a race condition.
|
35
|
+
# TODO: Replace __MlflowPLCallback with Pytorch Lightning's built-in MlflowLogger
|
36
|
+
# once the above mentioned issues have been addressed
|
37
|
+
|
38
|
+
_logger = logging.getLogger(__name__)
|
39
|
+
|
40
|
+
_pl_version = Version(pl.__version__)
|
41
|
+
if _pl_version < Version("1.5.0"):
|
42
|
+
from pytorch_lightning.core.memory import ModelSummary
|
43
|
+
else:
|
44
|
+
from pytorch_lightning.utilities.model_summary import ModelSummary
|
45
|
+
|
46
|
+
|
47
|
+
def _get_optimizer_name(optimizer):
    """
    Return the class name of the optimizer, looking through Lightning's wrapper if present.

    In pytorch-lightning 1.1.0, `LightningOptimizer` was introduced:
    https://github.com/PyTorchLightning/pytorch-lightning/pull/4658

    If a user sets `enable_pl_optimizer` to True when instantiating a `Trainer` object,
    each optimizer will be wrapped by `LightningOptimizer`:
    https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.html
    #pytorch_lightning.trainer.trainer.Trainer.params.enable_pl_optimizer

    Args:
        optimizer: The optimizer object (possibly a `LightningOptimizer` wrapper).

    Returns:
        The `__name__` of the underlying optimizer class.
    """
    # Before 1.1.0 there is no wrapper class to import, so the object is the optimizer itself.
    if Version(pl.__version__) < Version("1.1.0"):
        return optimizer.__class__.__name__

    from pytorch_lightning.core.optimizer import LightningOptimizer

    # The wrapper keeps the real optimizer on its `_optimizer` attribute.
    unwrapped = optimizer._optimizer if isinstance(optimizer, LightningOptimizer) else optimizer
    return unwrapped.__class__.__name__
|
67
|
+
|
68
|
+
|
69
|
+
class __MlflowPLCallback(pl.Callback, metaclass=ExceptionSafeAbstractClass):
    """
    Callback for auto-logging metrics and parameters.

    Epoch-level metrics are read from ``trainer.callback_metrics``. When
    ``log_every_n_step`` is set, step-level metrics are read through
    ``_get_step_metrics``. A metric name is logged either per step or per epoch,
    never both: the two sets ``_step_metrics`` and ``_epoch_metrics`` record which
    names have been claimed by which cadence.
    """

    def __init__(
        self, client, metrics_logger, run_id, log_models, log_every_n_epoch, log_every_n_step
    ):
        """
        Args:
            client: Object used for tags and params; ``set_tags``, ``log_params`` and
                ``flush`` are called on it.
            metrics_logger: Object used for metrics; ``record_metrics`` and ``flush``
                are called on it.
            run_id: ID of the MLflow run that tags and params are logged to.
            log_models: Stored on the instance; not read by this class.
            log_every_n_epoch: Epoch metrics are logged when
                ``(current_epoch + 1) % log_every_n_epoch == 0``.
            log_every_n_step: If falsy, per-step logging is disabled. Otherwise step
                metrics are logged every ``log_every_n_step`` training steps.

        Raises:
            MlflowException: If ``log_every_n_step`` is set and pytorch-lightning is
                older than 1.1.0.
        """
        if log_every_n_step and _pl_version < Version("1.1.0"):
            raise MlflowException(
                "log_every_n_step is only supported for PyTorch-Lightning >= 1.1.0"
            )
        self.early_stopping = False
        self.client = client
        self.metrics_logger = metrics_logger
        self.run_id = run_id
        self.log_models = log_models
        self.log_every_n_epoch = log_every_n_epoch
        self.log_every_n_step = log_every_n_step
        # Number of `trainer.global_step` increments per training step; updated in
        # `on_train_start` once the optimizers are known.
        self._global_steps_per_training_step = 1
        # Sets for tracking which metrics are logged on steps and which are logged on epochs
        self._step_metrics = set()
        self._epoch_metrics = set()

    def _log_metrics(self, trainer, step, metric_items):
        """
        Record ``metric_items`` (an iterable of ``(name, value)`` pairs) at ``step``.

        Nothing is recorded while the trainer is running its sanity check.
        """
        # pytorch-lightning runs a few steps of validation in the beginning of training
        # as a sanity check to catch bugs without having to wait for the training routine
        # to complete. During this check, we should skip logging metrics.
        # https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#num-sanity-val-steps
        sanity_checking = (
            # `running_sanity_check` has been renamed to `sanity_checking`:
            # https://github.com/PyTorchLightning/pytorch-lightning/pull/9209
            trainer.sanity_checking
            if _pl_version > Version("1.4.5")
            else trainer.running_sanity_check
        )
        if sanity_checking:
            return

        # Cast metric value as float before passing into logger.
        metrics = {name: float(value) for name, value in metric_items}
        self.metrics_logger.record_metrics(metrics, step)

    def _log_epoch_metrics(self, trainer, pl_module):
        """
        Log the epoch-level subset of ``trainer.callback_metrics`` for the current epoch.

        Metrics already logged per step are excluded. The remaining names are
        remembered as epoch metrics even when this epoch is skipped by
        ``log_every_n_epoch``.
        """
        # `trainer.callback_metrics` contains both training and validation metrics
        # and includes metrics logged on steps and epochs.
        # If we have logged any metrics on a step basis in mlflow, we exclude these from the
        # epoch level metrics to prevent mixing epoch and step based values.
        metric_items = [
            (name, val)
            for (name, val) in trainer.callback_metrics.items()
            if name not in self._step_metrics
        ]
        # Record which metrics are logged on epochs, so we don't try to log these on steps
        self._epoch_metrics.update(name for (name, _) in metric_items)
        if (pl_module.current_epoch + 1) % self.log_every_n_epoch == 0:
            self._log_metrics(trainer, pl_module.current_epoch, metric_items)

    # Class-level copy of the installed version, used by the conditional hook
    # definitions below. Kept as a class attribute for backward compatibility.
    _pl_version = Version(pl.__version__)

    # In pytorch-lightning >= 1.4.0, validation is run inside the training epoch and
    # `trainer.callback_metrics` contains both training and validation metrics of the
    # current training epoch when `on_train_epoch_end` is called:
    # https://github.com/PyTorchLightning/pytorch-lightning/pull/7357
    if _pl_version >= Version("1.4.0dev"):

        @rank_zero_only
        def on_train_epoch_end(self, trainer, pl_module, *args):
            """Log epoch metrics at the end of each training epoch."""
            self._log_epoch_metrics(trainer, pl_module)

    # In pytorch-lightning >= 1.2.0, logging metrics in `on_epoch_end` results in duplicate
    # metrics records because `on_epoch_end` is called after both train and validation
    # epochs (related PR: https://github.com/PyTorchLightning/pytorch-lightning/pull/5986)
    # As a workaround, use `on_train_epoch_end` and `on_validation_epoch_end` instead
    # in pytorch-lightning >= 1.2.0.
    elif _pl_version >= Version("1.2.0"):
        # NB: Override `on_train_epoch_end` with an additional `*args` parameter for
        # compatibility with versions of pytorch-lightning <= 1.2.0, which required an
        # `outputs` argument that was not used and is no longer defined in
        # pytorch-lightning >= 1.3.0

        @rank_zero_only
        def on_train_epoch_end(self, trainer, pl_module, *args):
            """
            Log loss and other metrics values after each train epoch

            Args:
                trainer: pytorch lightning trainer instance
                pl_module: pytorch lightning base module
                args: additional positional arguments
            """
            # If validation loop is enabled (meaning `validation_step` is overridden),
            # log metrics in `on_validation_epoch_end` to avoid logging the same metrics
            # records twice
            if not trainer.enable_validation:
                self._log_epoch_metrics(trainer, pl_module)

        @rank_zero_only
        def on_validation_epoch_end(self, trainer, pl_module):
            """
            Log loss and other metrics values after each validation epoch

            Args:
                trainer: pytorch lightning trainer instance
                pl_module: pytorch lightning base module
            """
            self._log_epoch_metrics(trainer, pl_module)

    else:

        @rank_zero_only
        def on_epoch_end(self, trainer, pl_module):
            """
            Log loss and other metrics values after each epoch

            Args:
                trainer: pytorch lightning trainer instance
                pl_module: pytorch lightning base module
            """
            self._log_epoch_metrics(trainer, pl_module)

    @rank_zero_only
    def on_train_batch_end(self, trainer, pl_module, *args):
        """
        Log metric values after each step

        Args:
            trainer: pytorch lightning trainer instance
            pl_module: pytorch lightning base module
            args: additional positional arguments
        """
        if not self.log_every_n_step:
            return
        # When logging at the end of a batch step, we only want to log metrics that are logged
        # on steps. For forked metrics (metrics logged on both steps and epochs), we exclude the
        # metric with the non-forked name (eg. "loss" when we have "loss", "loss_step" and
        # "loss_epoch") so that this is only logged on epochs. We also record which metrics
        # we've logged per step, so we can later exclude these from metrics logged on epochs.
        metrics = _get_step_metrics(trainer)
        metric_items = [
            (name, val)
            for (name, val) in metrics.items()
            if (name not in self._epoch_metrics) and (f"{name}_step" not in metrics)
        ]
        self._step_metrics.update(name for (name, _) in metric_items)
        step = trainer.global_step
        # Convert the global step into a training-step count before applying the interval.
        if ((step // self._global_steps_per_training_step) + 1) % self.log_every_n_step == 0:
            self._log_metrics(trainer, step, metric_items)

    @rank_zero_only
    def on_train_start(self, trainer, pl_module):
        """
        Logs Optimizer related metrics when the train begins

        Args:
            trainer: pytorch lightning trainer instance
            pl_module: pytorch lightning base module
        """
        self.client.set_tags(self.run_id, {"Mode": "training"})

        params = {"epochs": trainer.max_epochs}

        # TODO For logging optimizer params - Following scenarios are to be revisited.
        # 1. In the current scenario, only the first optimizer details are logged.
        #    Code to be enhanced to log params when multiple optimizers are used.
        # 2. mlflow.log_params is used to store optimizer default values into mlflow.
        #    The keys in default dictionary are too short, Ex: (lr - learning_rate).
        #    Efficient mapping technique needs to be introduced
        #    to rename the optimizer parameters based on keys in default dictionary.

        # Guard on a non-empty list: with no optimizers, indexing `[0]` would raise and
        # a step divisor of 0 would break the modulo check in `on_train_batch_end`.
        optimizers = getattr(trainer, "optimizers", None)
        if optimizers:
            # Lightning >= 1.6.0 increments the global step every time an optimizer is stepped.
            # We assume every optimizer will be stepped in each training step.
            if _pl_version >= Version("1.6.0"):
                self._global_steps_per_training_step = len(optimizers)
            optimizer = optimizers[0]
            params["optimizer_name"] = _get_optimizer_name(optimizer)

            if hasattr(optimizer, "defaults"):
                params.update(optimizer.defaults)

        self.client.log_params(self.run_id, params)
        self.client.flush(synchronous=True)

    @rank_zero_only
    def on_train_end(self, trainer, pl_module):
        """
        Flushes pending metrics and client operations when training ends

        Args:
            trainer: pytorch lightning trainer instance
            pl_module: pytorch lightning base module
        """
        # manually flush any remaining metadata from training
        self.metrics_logger.flush()
        self.client.flush(synchronous=True)

    @rank_zero_only
    def on_test_end(self, trainer, pl_module):
        """
        Logs accuracy and other relevant metrics on the testing end

        Args:
            trainer: pytorch lightning trainer instance
            pl_module: pytorch lightning base module
        """
        self.client.set_tags(self.run_id, {"Mode": "testing"})
        self.client.flush(synchronous=True)

        self.metrics_logger.record_metrics(
            {key: float(value) for key, value in trainer.callback_metrics.items()}
        )
        self.metrics_logger.flush()
|
283
|
+
|
284
|
+
|
285
|
+
class MlflowModelCheckpointCallback(pl.Callback, MlflowModelCheckpointCallbackBase):
    """Callback for auto-logging pytorch-lightning model checkpoints to MLflow.
    This callback implementation only supports pytorch-lightning >= 1.6.0.

    Args:
        monitor: In automatic model checkpointing, the metric name to monitor if
            you set `model_checkpoint_save_best_only` to True.
        save_best_only: If True, automatic model checkpointing only saves when
            the model is considered the "best" model according to the quantity
            monitored and previous checkpoint model is overwritten.
        mode: one of {"min", "max"}. In automatic model checkpointing,
            if save_best_only=True, the decision to overwrite the current save file is made
            based on either the maximization or the minimization of the monitored quantity.
        save_weights_only: In automatic model checkpointing, if True, then
            only the model's weights will be saved. Otherwise, the optimizer states,
            lr-scheduler states, etc are added in the checkpoint too.
        save_freq: `"epoch"` or integer. When using `"epoch"`, the callback
            saves the model after each epoch. When using integer, the callback
            saves the model at end of this many batches. Note that if the saving isn't
            aligned to epochs, the monitored metric may potentially be less reliable (it
            could reflect as little as 1 batch, since the metrics get reset
            every epoch). Defaults to `"epoch"`.

    .. code-block:: python
        :caption: Example

        import mlflow
        from mlflow.pytorch import MlflowModelCheckpointCallback
        from pytorch_lightning import Trainer

        mlflow.pytorch.autolog(checkpoint=True)

        model = MyLightningModuleNet()  # A custom-pytorch lightning model
        train_loader = create_train_dataset_loader()

        mlflow_checkpoint_callback = MlflowModelCheckpointCallback()

        trainer = Trainer(callbacks=[mlflow_checkpoint_callback])

        with mlflow.start_run() as run:
            trainer.fit(model, train_loader)

    """

    def __init__(
        self,
        monitor="val_loss",
        mode="min",
        save_best_only=True,
        save_weights_only=False,
        save_freq="epoch",
    ):
        # Forward the configuration to the base implementation; checkpoints use ".pth".
        base_options = {
            "checkpoint_file_suffix": ".pth",
            "monitor": monitor,
            "mode": mode,
            "save_best_only": save_best_only,
            "save_weights_only": save_weights_only,
            "save_freq": save_freq,
        }
        super().__init__(**base_options)
        # Filled in by `on_fit_start`; `save_checkpoint` needs it.
        self.trainer = None

    def save_checkpoint(self, filepath: str):
        """Dump the trainer state and write it to ``filepath`` via the trainer's strategy."""
        # Note: `trainer.save_checkpoint` implementation contains invocation of
        # `self.strategy.barrier("Trainer.save_checkpoint")`,
        # in DDP training, this callback is only invoked in rank 0 process,
        # the `barrier` invocation causes deadlock,
        # so `save_checkpoint` is implemented here instead of
        # calling `trainer.save_checkpoint`.
        active_trainer = self.trainer
        state = active_trainer._checkpoint_connector.dump_checkpoint(self.save_weights_only)
        active_trainer.strategy.save_checkpoint(state, filepath)

    @rank_zero_only
    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Remember the trainer so later checkpoint saves can reach it."""
        self.trainer = trainer

    @rank_zero_only
    def on_train_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs,
        batch,
        batch_idx,
    ) -> None:
        """Run the checkpoint check every ``save_freq`` global steps (integer ``save_freq`` only)."""
        if not isinstance(self.save_freq, int):
            return
        # Step 0 never triggers a save; afterwards only exact multiples of save_freq do.
        if not (trainer.global_step > 0 and trainer.global_step % self.save_freq == 0):
            return
        epoch_now = trainer.current_epoch
        step_now = trainer.global_step
        float_metrics = {k: float(v) for k, v in trainer.callback_metrics.items()}
        self.check_and_save_checkpoint_if_needed(
            current_epoch=epoch_now,
            global_step=step_now,
            metric_dict=float_metrics,
        )

    @rank_zero_only
    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Run the checkpoint check at the end of each epoch when ``save_freq`` is ``"epoch"``."""
        if self.save_freq != "epoch":
            return
        epoch_now = trainer.current_epoch
        step_now = trainer.global_step
        float_metrics = {k: float(v) for k, v in trainer.callback_metrics.items()}
        self.check_and_save_checkpoint_if_needed(
            current_epoch=epoch_now,
            global_step=step_now,
            metric_dict=float_metrics,
        )
|
387
|
+
|
388
|
+
|
389
|
+
# PyTorch-Lightning refactored the LoggerConnector class in version 1.4.0 and made metrics
|
390
|
+
# update on demand. Prior to this, the metrics from the current step were not available to
|
391
|
+
# callbacks immediately, so the view of metrics was off by one step.
|
392
|
+
# To avoid this problem, we access the metrics via the logger_connector for older versions.
|
393
|
+
def _get_step_metrics(trainer):
    """
    Return the mapping of metrics for the current training step.

    For pytorch-lightning >= 1.4.0 this is `trainer.callback_metrics`; for older
    versions the latest batch metrics are read from the trainer's logger connector.
    """
    # `_pl_version` is fixed at import time, so the same branch is taken on every call.
    if _pl_version >= Version("1.4.0"):
        return trainer.callback_metrics
    return trainer.logger_connector.cached_results.get_latest_batch_log_metrics()
|
402
|
+
|
403
|
+
|
404
|
+
def _log_early_stop_params(early_stop_callback, client, run_id):
|
405
|
+
"""
|
406
|
+
Logs early stopping configuration parameters to MLflow.
|
407
|
+
|
408
|
+
Args:
|
409
|
+
early_stop_callback: The early stopping callback instance used during training.
|
410
|
+
client: An `MlflowAutologgingQueueingClient` instance used for MLflow logging.
|
411
|
+
run_id: The ID of the MLflow Run to which to log configuration parameters.
|
412
|
+
"""
|
413
|
+
client.log_params(
|
414
|
+
run_id,
|
415
|
+
{
|
416
|
+
p: getattr(early_stop_callback, p)
|
417
|
+
for p in ["monitor", "mode", "patience", "min_delta", "stopped_epoch"]
|
418
|
+
if hasattr(early_stop_callback, p)
|
419
|
+
},
|
420
|
+
)
|
421
|
+
|
422
|
+
|
423
|
+
def _log_early_stop_metrics(early_stop_callback, client, run_id, model_id=None):
|
424
|
+
"""
|
425
|
+
Logs early stopping behavior results (e.g. stopped epoch) as metrics to MLflow.
|
426
|
+
|
427
|
+
Args:
|
428
|
+
early_stop_callback: The early stopping callback instance used during training.
|
429
|
+
client: An `MlflowAutologgingQueueingClient` instance used for MLflow logging.
|
430
|
+
run_id: The ID of the MLflow Run to which to log configuration parameters.
|
431
|
+
model_id: The ID of the LoggedModel to which the metrics are associated.
|
432
|
+
"""
|
433
|
+
if early_stop_callback.stopped_epoch == 0:
|
434
|
+
return
|
435
|
+
|
436
|
+
metrics = {
|
437
|
+
"stopped_epoch": early_stop_callback.stopped_epoch,
|
438
|
+
"restored_epoch": early_stop_callback.stopped_epoch - max(1, early_stop_callback.patience),
|
439
|
+
}
|
440
|
+
|
441
|
+
if hasattr(early_stop_callback, "best_score"):
|
442
|
+
metrics["best_score"] = float(early_stop_callback.best_score)
|
443
|
+
|
444
|
+
if hasattr(early_stop_callback, "wait_count"):
|
445
|
+
metrics["wait_count"] = early_stop_callback.wait_count
|
446
|
+
|
447
|
+
client.log_metrics(run_id, metrics, model_id=model_id)
|
448
|
+
|
449
|
+
|
450
|
+
def patched_fit(original, self, *args, **kwargs):
    """
    A patched implementation of `pytorch_lightning.Trainer.fit` which enables logging the
    following parameters, metrics and artifacts:

    - Training epochs
    - Optimizer parameters
    - `EarlyStoppingCallback`_ parameters
    - Metrics stored in `trainer.callback_metrics`
    - Model checkpoints
    - Trained model

    .. _EarlyStoppingCallback:
        https://pytorch-lightning.readthedocs.io/en/latest/early_stopping.html

    Args:
        original: The unpatched `fit` implementation; called as
            ``original(self, *args, **kwargs)``.
        self: The trainer instance. Its ``callbacks``, ``model`` and
            ``checkpoint_callback`` attributes are used.
        args: Positional arguments forwarded to ``original``.
        kwargs: Keyword arguments forwarded to ``original``.

    Returns:
        Whatever ``original`` returns.
    """
    if not MIN_REQ_VERSION <= _pl_version <= MAX_REQ_VERSION:
        warnings.warn(
            "Autologging is known to be compatible with pytorch-lightning versions between "
            f"{MIN_REQ_VERSION} and {MAX_REQ_VERSION} and may not succeed with packages "
            "outside this range."
        )

    with disable_autologging():
        run_id = mlflow.active_run().info.run_id
        tracking_uri = mlflow.get_tracking_uri()
        client = MlflowAutologgingQueueingClient(tracking_uri)

        log_models = get_autologging_config(mlflow.pytorch.FLAVOR_NAME, "log_models", True)
        model_id = None
        if log_models:
            # Create the logged model up front so metrics can be associated with it.
            model_id = mlflow.initialize_logged_model(name="model").model_id
        metrics_logger = BatchMetricsLogger(run_id, tracking_uri, model_id=model_id)

        log_every_n_epoch = get_autologging_config(
            mlflow.pytorch.FLAVOR_NAME, "log_every_n_epoch", 1
        )
        log_every_n_step = get_autologging_config(
            mlflow.pytorch.FLAVOR_NAME, "log_every_n_step", None
        )

        # Params are logged for every EarlyStopping callback found; the last one found is
        # the one whose metrics are logged after training.
        early_stop_callback = None
        for callback in self.callbacks:
            if isinstance(callback, pl.callbacks.early_stopping.EarlyStopping):
                early_stop_callback = callback
                _log_early_stop_params(early_stop_callback, client, run_id)

        # Attach the metrics/params callback once per trainer.
        if not any(isinstance(callbacks, __MlflowPLCallback) for callbacks in self.callbacks):
            self.callbacks += [
                __MlflowPLCallback(
                    client, metrics_logger, run_id, log_models, log_every_n_epoch, log_every_n_step
                )
            ]

        model_checkpoint = get_autologging_config(mlflow.pytorch.FLAVOR_NAME, "checkpoint", True)
        if model_checkpoint:
            # __MLflowModelCheckpoint only supports pytorch-lightning >= 1.6.0
            if _pl_version >= Version("1.6.0"):
                checkpoint_monitor = get_autologging_config(
                    mlflow.pytorch.FLAVOR_NAME, "checkpoint_monitor", "val_loss"
                )
                checkpoint_mode = get_autologging_config(
                    mlflow.pytorch.FLAVOR_NAME, "checkpoint_mode", "min"
                )
                checkpoint_save_best_only = get_autologging_config(
                    mlflow.pytorch.FLAVOR_NAME, "checkpoint_save_best_only", True
                )
                checkpoint_save_weights_only = get_autologging_config(
                    mlflow.pytorch.FLAVOR_NAME, "checkpoint_save_weights_only", False
                )
                checkpoint_save_freq = get_autologging_config(
                    mlflow.pytorch.FLAVOR_NAME, "checkpoint_save_freq", "epoch"
                )

                # Attach the checkpoint callback once per trainer.
                if not any(
                    isinstance(callbacks, MlflowModelCheckpointCallback)
                    for callbacks in self.callbacks
                ):
                    self.callbacks += [
                        MlflowModelCheckpointCallback(
                            monitor=checkpoint_monitor,
                            mode=checkpoint_mode,
                            save_best_only=checkpoint_save_best_only,
                            save_weights_only=checkpoint_save_weights_only,
                            save_freq=checkpoint_save_freq,
                        )
                    ]
            else:
                warnings.warn(
                    "Automatic model checkpointing is disabled because this feature only "
                    "supports pytorch-lightning >= 1.6.0."
                )

        client.flush(synchronous=False)

        result = original(self, *args, **kwargs)

        if early_stop_callback is not None:
            _log_early_stop_metrics(early_stop_callback, client, run_id, model_id=model_id)

        # The keyword for requesting a full-depth summary differs across versions.
        if _pl_version < Version("1.4.0"):
            summary = str(ModelSummary(self.model, mode="full"))
        else:
            summary = str(ModelSummary(self.model, max_depth=-1))

        with tempfile.TemporaryDirectory() as tempdir:
            summary_file = os.path.join(tempdir, "model_summary.txt")
            # Explicit encoding so the write does not depend on the platform default.
            with open(summary_file, "w", encoding="utf-8") as f:
                f.write(summary)

            mlflow.log_artifact(local_path=summary_file)

        if log_models:
            registered_model_name = get_autologging_config(
                mlflow.pytorch.FLAVOR_NAME, "registered_model_name", None
            )
            mlflow.pytorch.log_model(
                self.model,
                name="model",
                registered_model_name=registered_model_name,
                model_id=model_id,
            )

            if early_stop_callback is not None:
                # The trainer may have no checkpoint callback (attribute missing or None),
                # in which case there is no best checkpoint to log.
                checkpoint_callback = getattr(self, "checkpoint_callback", None)
                best_model_path = getattr(checkpoint_callback, "best_model_path", None)
                if best_model_path:
                    mlflow.log_artifact(
                        local_path=best_model_path,
                        artifact_path="restored_model_checkpoint",
                    )

        client.flush(synchronous=True)

    return result
|
@@ -0,0 +1,50 @@
|
|
1
|
+
import time
|
2
|
+
|
3
|
+
import mlflow
|
4
|
+
from mlflow.entities import Metric, Param
|
5
|
+
from mlflow.tracking import MlflowClient
|
6
|
+
from mlflow.utils.autologging_utils.metrics_queue import (
|
7
|
+
add_to_metrics_queue,
|
8
|
+
flush_metrics_queue,
|
9
|
+
)
|
10
|
+
|
11
|
+
|
12
|
+
def patched_add_hparams(original, self, hparam_dict, metric_dict, *args, **kwargs):
    """
    Log hyperparameters and their metrics to the active MLflow run, then delegate.

    A synchronous call is used here since this is going to get called very infrequently.
    Nothing is logged to MLflow when there is no active run or ``hparam_dict`` is empty.

    Args:
        original: The unpatched ``add_hparams`` implementation.
        self: The writer instance the patched method is bound to.
        hparam_dict: Mapping of hyperparameter name to value; values are stringified.
        metric_dict: Mapping of metric name to value. ``None`` is treated as empty here
            and passed unchanged to ``original``.
        args: Extra positional arguments forwarded to ``original``.
        kwargs: Extra keyword arguments forwarded to ``original``.

    Returns:
        Whatever ``original`` returns.
    """
    run = mlflow.active_run()

    if run is not None and hparam_dict:
        run_id = run.info.run_id
        # str() is required by mlflow
        params_arr = [Param(key, str(value)) for key, value in hparam_dict.items()]
        # One timestamp for the whole batch so all metrics from this call agree.
        timestamp_ms = int(time.time() * 1000)
        # Tolerate a missing metric_dict so the wrapped call gets to report it itself.
        metrics_arr = [
            Metric(key, value, timestamp_ms, 0) for key, value in (metric_dict or {}).items()
        ]
        MlflowClient().log_batch(run_id=run_id, metrics=metrics_arr, params=params_arr, tags=[])

    return original(self, hparam_dict, metric_dict, *args, **kwargs)
|
27
|
+
|
28
|
+
|
29
|
+
def patched_add_event(original, self, event, *args, mlflow_log_every_n_step, **kwargs):
    """
    Queue scalar summary values from ``event`` as MLflow metrics, then delegate.

    Values are queued only when there is an active run, ``mlflow_log_every_n_step`` is
    truthy, the event carries a summary, and the step is a multiple of
    ``mlflow_log_every_n_step``. A missing step is treated as 0.

    Args:
        original: The unpatched ``add_event`` implementation.
        self: The writer instance the patched method is bound to.
        event: Event object; ``WhichOneof("what")``, ``summary`` and ``wall_time`` are read.
        args: Extra positional arguments forwarded to ``original``; the first one, if
            present, is used as the step.
        mlflow_log_every_n_step: Keyword-only logging interval; not forwarded to ``original``.
        kwargs: Extra keyword arguments forwarded to ``original``.

    Returns:
        Whatever ``original`` returns.
    """
    run = mlflow.active_run()
    if run is not None and event.WhichOneof("what") == "summary" and mlflow_log_every_n_step:
        if args:
            global_step = args[0]
        else:
            global_step = kwargs.get("global_step")
            if global_step is None:
                # NOTE(review): torch.utils.tensorboard's FileWriter.add_event spells this
                # keyword `step`; accept it too so a keyword-passed step is not read as 0.
                # Confirm against the writer class this patch is applied to.
                global_step = kwargs.get("step")
        global_step = global_step or 0

        # The interval check and timestamp do not depend on the individual value.
        if global_step % mlflow_log_every_n_step == 0:
            timestamp_ms = int((event.wall_time or time.time()) * 1000)
            run_id = run.info.run_id
            for v in event.summary.value:
                # Only plain scalar values are forwarded as metrics.
                if v.HasField("simple_value"):
                    add_to_metrics_queue(
                        key=v.tag,
                        value=v.simple_value,
                        step=global_step,
                        time=timestamp_ms,
                        run_id=run_id,
                    )

    return original(self, event, *args, **kwargs)
|
46
|
+
|
47
|
+
|
48
|
+
def patched_add_summary(original, self, *args, **kwargs):
    """Flush the queued MLflow metrics, then run the unpatched ``add_summary``."""
    # Flush first so queued metrics are sent before the wrapped call runs.
    flush_metrics_queue()
    outcome = original(self, *args, **kwargs)
    return outcome
|
@@ -0,0 +1,35 @@
|
|
1
|
+
"""
|
2
|
+
This module imports contents from CloudPickle in a way that is compatible with the
|
3
|
+
``pickle_module`` parameter of PyTorch's model persistence function: ``torch.save``
|
4
|
+
(see https://github.com/pytorch/pytorch/blob/692898fe379c9092f5e380797c32305145cd06e1/torch/
|
5
|
+
serialization.py#L192). It is included as a distinct module from :mod:`mlflow.pytorch` to avoid
|
6
|
+
polluting the namespace with wildcard imports.
|
7
|
+
|
8
|
+
Calling ``torch.save(..., pickle_module=mlflow.pytorch.pickle_module)`` will persist PyTorch model
|
9
|
+
definitions using CloudPickle, leveraging improved pickling functionality such as the ability
|
10
|
+
to capture class definitions in the "__main__" scope.
|
11
|
+
|
12
|
+
TODO: Remove this module or make it an alias of CloudPickle when CloudPickle and PyTorch have
|
13
|
+
compatible pickling APIs.
|
14
|
+
"""
|
15
|
+
|
16
|
+
# Import all contents of the CloudPickle module in an attempt to include all functions required
|
17
|
+
# by ``torch.save``.
|
18
|
+
|
19
|
+
# CloudPickle does not include `Unpickler` in its namespace, which is required by PyTorch for
|
20
|
+
# deserialization. Noting that CloudPickle's `load()` and `loads()` routines are aliases for
|
21
|
+
# `pickle.load()` and `pickle.loads()`, we therefore import Unpickler from the native
|
22
|
+
# Python pickle library.
|
23
|
+
from pickle import Unpickler # noqa: F401
|
24
|
+
|
25
|
+
from cloudpickle import * # noqa: F403
|
26
|
+
|
27
|
+
# PyTorch uses the ``Pickler`` class of the specified ``pickle_module``
|
28
|
+
# (https://github.com/pytorch/pytorch/blob/692898fe379c9092f5e380797c32305145cd06e1/torch/
|
29
|
+
# serialization.py#L290). Unfortunately, ``cloudpickle.Pickler`` is an alias for Python's native
|
30
|
+
# pickling class: ``pickle.Pickler``, instead of ``cloudpickle.CloudPickler``.
|
31
|
+
# https://github.com/cloudpipe/cloudpickle/pull/235 has been filed to correct the issue,
|
32
|
+
# but this import renaming is necessary until either the requested change has been incorporated
|
33
|
+
# into a CloudPickle release or the ``torch.save`` API has been updated to be compatible with
|
34
|
+
# the existing CloudPickle API.
|
35
|
+
from cloudpickle import CloudPickler as Pickler # noqa: F401
|