genesis-flow 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genesis_flow-1.0.0.dist-info/METADATA +822 -0
- genesis_flow-1.0.0.dist-info/RECORD +645 -0
- genesis_flow-1.0.0.dist-info/WHEEL +5 -0
- genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
- genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
- genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
- mlflow/__init__.py +367 -0
- mlflow/__main__.py +3 -0
- mlflow/ag2/__init__.py +56 -0
- mlflow/ag2/ag2_logger.py +294 -0
- mlflow/anthropic/__init__.py +40 -0
- mlflow/anthropic/autolog.py +129 -0
- mlflow/anthropic/chat.py +144 -0
- mlflow/artifacts/__init__.py +268 -0
- mlflow/autogen/__init__.py +144 -0
- mlflow/autogen/chat.py +142 -0
- mlflow/azure/__init__.py +26 -0
- mlflow/azure/auth_handler.py +257 -0
- mlflow/azure/client.py +319 -0
- mlflow/azure/config.py +120 -0
- mlflow/azure/connection_factory.py +340 -0
- mlflow/azure/exceptions.py +27 -0
- mlflow/azure/stores.py +327 -0
- mlflow/azure/utils.py +183 -0
- mlflow/bedrock/__init__.py +45 -0
- mlflow/bedrock/_autolog.py +202 -0
- mlflow/bedrock/chat.py +122 -0
- mlflow/bedrock/stream.py +160 -0
- mlflow/bedrock/utils.py +43 -0
- mlflow/cli.py +707 -0
- mlflow/client.py +12 -0
- mlflow/config/__init__.py +56 -0
- mlflow/crewai/__init__.py +79 -0
- mlflow/crewai/autolog.py +253 -0
- mlflow/crewai/chat.py +29 -0
- mlflow/data/__init__.py +75 -0
- mlflow/data/artifact_dataset_sources.py +170 -0
- mlflow/data/code_dataset_source.py +40 -0
- mlflow/data/dataset.py +123 -0
- mlflow/data/dataset_registry.py +168 -0
- mlflow/data/dataset_source.py +110 -0
- mlflow/data/dataset_source_registry.py +219 -0
- mlflow/data/delta_dataset_source.py +167 -0
- mlflow/data/digest_utils.py +108 -0
- mlflow/data/evaluation_dataset.py +562 -0
- mlflow/data/filesystem_dataset_source.py +81 -0
- mlflow/data/http_dataset_source.py +145 -0
- mlflow/data/huggingface_dataset.py +258 -0
- mlflow/data/huggingface_dataset_source.py +118 -0
- mlflow/data/meta_dataset.py +104 -0
- mlflow/data/numpy_dataset.py +223 -0
- mlflow/data/pandas_dataset.py +231 -0
- mlflow/data/polars_dataset.py +352 -0
- mlflow/data/pyfunc_dataset_mixin.py +31 -0
- mlflow/data/schema.py +76 -0
- mlflow/data/sources.py +1 -0
- mlflow/data/spark_dataset.py +406 -0
- mlflow/data/spark_dataset_source.py +74 -0
- mlflow/data/spark_delta_utils.py +118 -0
- mlflow/data/tensorflow_dataset.py +350 -0
- mlflow/data/uc_volume_dataset_source.py +81 -0
- mlflow/db.py +27 -0
- mlflow/dspy/__init__.py +17 -0
- mlflow/dspy/autolog.py +197 -0
- mlflow/dspy/callback.py +398 -0
- mlflow/dspy/constant.py +1 -0
- mlflow/dspy/load.py +93 -0
- mlflow/dspy/save.py +393 -0
- mlflow/dspy/util.py +109 -0
- mlflow/dspy/wrapper.py +226 -0
- mlflow/entities/__init__.py +104 -0
- mlflow/entities/_mlflow_object.py +52 -0
- mlflow/entities/assessment.py +545 -0
- mlflow/entities/assessment_error.py +80 -0
- mlflow/entities/assessment_source.py +141 -0
- mlflow/entities/dataset.py +92 -0
- mlflow/entities/dataset_input.py +51 -0
- mlflow/entities/dataset_summary.py +62 -0
- mlflow/entities/document.py +48 -0
- mlflow/entities/experiment.py +109 -0
- mlflow/entities/experiment_tag.py +35 -0
- mlflow/entities/file_info.py +45 -0
- mlflow/entities/input_tag.py +35 -0
- mlflow/entities/lifecycle_stage.py +35 -0
- mlflow/entities/logged_model.py +228 -0
- mlflow/entities/logged_model_input.py +26 -0
- mlflow/entities/logged_model_output.py +32 -0
- mlflow/entities/logged_model_parameter.py +46 -0
- mlflow/entities/logged_model_status.py +74 -0
- mlflow/entities/logged_model_tag.py +33 -0
- mlflow/entities/metric.py +200 -0
- mlflow/entities/model_registry/__init__.py +29 -0
- mlflow/entities/model_registry/_model_registry_entity.py +13 -0
- mlflow/entities/model_registry/model_version.py +243 -0
- mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
- mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
- mlflow/entities/model_registry/model_version_search.py +25 -0
- mlflow/entities/model_registry/model_version_stages.py +25 -0
- mlflow/entities/model_registry/model_version_status.py +35 -0
- mlflow/entities/model_registry/model_version_tag.py +35 -0
- mlflow/entities/model_registry/prompt.py +73 -0
- mlflow/entities/model_registry/prompt_version.py +244 -0
- mlflow/entities/model_registry/registered_model.py +175 -0
- mlflow/entities/model_registry/registered_model_alias.py +35 -0
- mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
- mlflow/entities/model_registry/registered_model_search.py +25 -0
- mlflow/entities/model_registry/registered_model_tag.py +35 -0
- mlflow/entities/multipart_upload.py +74 -0
- mlflow/entities/param.py +49 -0
- mlflow/entities/run.py +97 -0
- mlflow/entities/run_data.py +84 -0
- mlflow/entities/run_info.py +188 -0
- mlflow/entities/run_inputs.py +59 -0
- mlflow/entities/run_outputs.py +43 -0
- mlflow/entities/run_status.py +41 -0
- mlflow/entities/run_tag.py +36 -0
- mlflow/entities/source_type.py +31 -0
- mlflow/entities/span.py +774 -0
- mlflow/entities/span_event.py +96 -0
- mlflow/entities/span_status.py +102 -0
- mlflow/entities/trace.py +317 -0
- mlflow/entities/trace_data.py +71 -0
- mlflow/entities/trace_info.py +220 -0
- mlflow/entities/trace_info_v2.py +162 -0
- mlflow/entities/trace_location.py +173 -0
- mlflow/entities/trace_state.py +39 -0
- mlflow/entities/trace_status.py +68 -0
- mlflow/entities/view_type.py +51 -0
- mlflow/environment_variables.py +866 -0
- mlflow/evaluation/__init__.py +16 -0
- mlflow/evaluation/assessment.py +369 -0
- mlflow/evaluation/evaluation.py +411 -0
- mlflow/evaluation/evaluation_tag.py +61 -0
- mlflow/evaluation/fluent.py +48 -0
- mlflow/evaluation/utils.py +201 -0
- mlflow/exceptions.py +213 -0
- mlflow/experiments.py +140 -0
- mlflow/gemini/__init__.py +81 -0
- mlflow/gemini/autolog.py +186 -0
- mlflow/gemini/chat.py +261 -0
- mlflow/genai/__init__.py +71 -0
- mlflow/genai/datasets/__init__.py +67 -0
- mlflow/genai/datasets/evaluation_dataset.py +131 -0
- mlflow/genai/evaluation/__init__.py +3 -0
- mlflow/genai/evaluation/base.py +411 -0
- mlflow/genai/evaluation/constant.py +23 -0
- mlflow/genai/evaluation/utils.py +244 -0
- mlflow/genai/judges/__init__.py +21 -0
- mlflow/genai/judges/databricks.py +404 -0
- mlflow/genai/label_schemas/__init__.py +153 -0
- mlflow/genai/label_schemas/label_schemas.py +209 -0
- mlflow/genai/labeling/__init__.py +159 -0
- mlflow/genai/labeling/labeling.py +250 -0
- mlflow/genai/optimize/__init__.py +13 -0
- mlflow/genai/optimize/base.py +198 -0
- mlflow/genai/optimize/optimizers/__init__.py +4 -0
- mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
- mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
- mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
- mlflow/genai/optimize/types.py +75 -0
- mlflow/genai/optimize/util.py +30 -0
- mlflow/genai/prompts/__init__.py +206 -0
- mlflow/genai/scheduled_scorers.py +431 -0
- mlflow/genai/scorers/__init__.py +26 -0
- mlflow/genai/scorers/base.py +492 -0
- mlflow/genai/scorers/builtin_scorers.py +765 -0
- mlflow/genai/scorers/scorer_utils.py +138 -0
- mlflow/genai/scorers/validation.py +165 -0
- mlflow/genai/utils/data_validation.py +146 -0
- mlflow/genai/utils/enum_utils.py +23 -0
- mlflow/genai/utils/trace_utils.py +211 -0
- mlflow/groq/__init__.py +42 -0
- mlflow/groq/_groq_autolog.py +74 -0
- mlflow/johnsnowlabs/__init__.py +888 -0
- mlflow/langchain/__init__.py +24 -0
- mlflow/langchain/api_request_parallel_processor.py +330 -0
- mlflow/langchain/autolog.py +147 -0
- mlflow/langchain/chat_agent_langgraph.py +340 -0
- mlflow/langchain/constant.py +1 -0
- mlflow/langchain/constants.py +1 -0
- mlflow/langchain/databricks_dependencies.py +444 -0
- mlflow/langchain/langchain_tracer.py +597 -0
- mlflow/langchain/model.py +919 -0
- mlflow/langchain/output_parsers.py +142 -0
- mlflow/langchain/retriever_chain.py +153 -0
- mlflow/langchain/runnables.py +527 -0
- mlflow/langchain/utils/chat.py +402 -0
- mlflow/langchain/utils/logging.py +671 -0
- mlflow/langchain/utils/serialization.py +36 -0
- mlflow/legacy_databricks_cli/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/provider.py +482 -0
- mlflow/litellm/__init__.py +175 -0
- mlflow/llama_index/__init__.py +22 -0
- mlflow/llama_index/autolog.py +55 -0
- mlflow/llama_index/chat.py +43 -0
- mlflow/llama_index/constant.py +1 -0
- mlflow/llama_index/model.py +577 -0
- mlflow/llama_index/pyfunc_wrapper.py +332 -0
- mlflow/llama_index/serialize_objects.py +188 -0
- mlflow/llama_index/tracer.py +561 -0
- mlflow/metrics/__init__.py +479 -0
- mlflow/metrics/base.py +39 -0
- mlflow/metrics/genai/__init__.py +25 -0
- mlflow/metrics/genai/base.py +101 -0
- mlflow/metrics/genai/genai_metric.py +771 -0
- mlflow/metrics/genai/metric_definitions.py +450 -0
- mlflow/metrics/genai/model_utils.py +371 -0
- mlflow/metrics/genai/prompt_template.py +68 -0
- mlflow/metrics/genai/prompts/__init__.py +0 -0
- mlflow/metrics/genai/prompts/v1.py +422 -0
- mlflow/metrics/genai/utils.py +6 -0
- mlflow/metrics/metric_definitions.py +619 -0
- mlflow/mismatch.py +34 -0
- mlflow/mistral/__init__.py +34 -0
- mlflow/mistral/autolog.py +71 -0
- mlflow/mistral/chat.py +135 -0
- mlflow/ml_package_versions.py +452 -0
- mlflow/models/__init__.py +97 -0
- mlflow/models/auth_policy.py +83 -0
- mlflow/models/cli.py +354 -0
- mlflow/models/container/__init__.py +294 -0
- mlflow/models/container/scoring_server/__init__.py +0 -0
- mlflow/models/container/scoring_server/nginx.conf +39 -0
- mlflow/models/dependencies_schemas.py +287 -0
- mlflow/models/display_utils.py +158 -0
- mlflow/models/docker_utils.py +211 -0
- mlflow/models/evaluation/__init__.py +23 -0
- mlflow/models/evaluation/_shap_patch.py +64 -0
- mlflow/models/evaluation/artifacts.py +194 -0
- mlflow/models/evaluation/base.py +1811 -0
- mlflow/models/evaluation/calibration_curve.py +109 -0
- mlflow/models/evaluation/default_evaluator.py +996 -0
- mlflow/models/evaluation/deprecated.py +23 -0
- mlflow/models/evaluation/evaluator_registry.py +80 -0
- mlflow/models/evaluation/evaluators/classifier.py +704 -0
- mlflow/models/evaluation/evaluators/default.py +233 -0
- mlflow/models/evaluation/evaluators/regressor.py +96 -0
- mlflow/models/evaluation/evaluators/shap.py +296 -0
- mlflow/models/evaluation/lift_curve.py +178 -0
- mlflow/models/evaluation/utils/metric.py +123 -0
- mlflow/models/evaluation/utils/trace.py +179 -0
- mlflow/models/evaluation/validation.py +434 -0
- mlflow/models/flavor_backend.py +93 -0
- mlflow/models/flavor_backend_registry.py +53 -0
- mlflow/models/model.py +1639 -0
- mlflow/models/model_config.py +150 -0
- mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
- mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
- mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
- mlflow/models/python_api.py +369 -0
- mlflow/models/rag_signatures.py +128 -0
- mlflow/models/resources.py +321 -0
- mlflow/models/signature.py +662 -0
- mlflow/models/utils.py +2054 -0
- mlflow/models/wheeled_model.py +280 -0
- mlflow/openai/__init__.py +57 -0
- mlflow/openai/_agent_tracer.py +364 -0
- mlflow/openai/api_request_parallel_processor.py +131 -0
- mlflow/openai/autolog.py +509 -0
- mlflow/openai/constant.py +1 -0
- mlflow/openai/model.py +824 -0
- mlflow/openai/utils/chat_schema.py +367 -0
- mlflow/optuna/__init__.py +3 -0
- mlflow/optuna/storage.py +646 -0
- mlflow/plugins/__init__.py +72 -0
- mlflow/plugins/base.py +358 -0
- mlflow/plugins/builtin/__init__.py +24 -0
- mlflow/plugins/builtin/pytorch_plugin.py +150 -0
- mlflow/plugins/builtin/sklearn_plugin.py +158 -0
- mlflow/plugins/builtin/transformers_plugin.py +187 -0
- mlflow/plugins/cli.py +321 -0
- mlflow/plugins/discovery.py +340 -0
- mlflow/plugins/manager.py +465 -0
- mlflow/plugins/registry.py +316 -0
- mlflow/plugins/templates/framework_plugin_template.py +329 -0
- mlflow/prompt/constants.py +20 -0
- mlflow/prompt/promptlab_model.py +197 -0
- mlflow/prompt/registry_utils.py +248 -0
- mlflow/promptflow/__init__.py +495 -0
- mlflow/protos/__init__.py +0 -0
- mlflow/protos/assessments_pb2.py +174 -0
- mlflow/protos/databricks_artifacts_pb2.py +489 -0
- mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
- mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
- mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
- mlflow/protos/databricks_pb2.py +267 -0
- mlflow/protos/databricks_trace_server_pb2.py +374 -0
- mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
- mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
- mlflow/protos/facet_feature_statistics_pb2.py +296 -0
- mlflow/protos/internal_pb2.py +77 -0
- mlflow/protos/mlflow_artifacts_pb2.py +336 -0
- mlflow/protos/model_registry_pb2.py +1073 -0
- mlflow/protos/scalapb/__init__.py +0 -0
- mlflow/protos/scalapb/scalapb_pb2.py +104 -0
- mlflow/protos/service_pb2.py +2600 -0
- mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
- mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
- mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
- mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
- mlflow/py.typed +0 -0
- mlflow/pydantic_ai/__init__.py +57 -0
- mlflow/pydantic_ai/autolog.py +173 -0
- mlflow/pyfunc/__init__.py +3844 -0
- mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
- mlflow/pyfunc/backend.py +523 -0
- mlflow/pyfunc/context.py +78 -0
- mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
- mlflow/pyfunc/loaders/__init__.py +7 -0
- mlflow/pyfunc/loaders/chat_agent.py +117 -0
- mlflow/pyfunc/loaders/chat_model.py +125 -0
- mlflow/pyfunc/loaders/code_model.py +31 -0
- mlflow/pyfunc/loaders/responses_agent.py +112 -0
- mlflow/pyfunc/mlserver.py +46 -0
- mlflow/pyfunc/model.py +1473 -0
- mlflow/pyfunc/scoring_server/__init__.py +604 -0
- mlflow/pyfunc/scoring_server/app.py +7 -0
- mlflow/pyfunc/scoring_server/client.py +146 -0
- mlflow/pyfunc/spark_model_cache.py +48 -0
- mlflow/pyfunc/stdin_server.py +44 -0
- mlflow/pyfunc/utils/__init__.py +3 -0
- mlflow/pyfunc/utils/data_validation.py +224 -0
- mlflow/pyfunc/utils/environment.py +22 -0
- mlflow/pyfunc/utils/input_converter.py +47 -0
- mlflow/pyfunc/utils/serving_data_parser.py +11 -0
- mlflow/pytorch/__init__.py +1171 -0
- mlflow/pytorch/_lightning_autolog.py +580 -0
- mlflow/pytorch/_pytorch_autolog.py +50 -0
- mlflow/pytorch/pickle_module.py +35 -0
- mlflow/rfunc/__init__.py +42 -0
- mlflow/rfunc/backend.py +134 -0
- mlflow/runs.py +89 -0
- mlflow/server/__init__.py +302 -0
- mlflow/server/auth/__init__.py +1224 -0
- mlflow/server/auth/__main__.py +4 -0
- mlflow/server/auth/basic_auth.ini +6 -0
- mlflow/server/auth/cli.py +11 -0
- mlflow/server/auth/client.py +537 -0
- mlflow/server/auth/config.py +34 -0
- mlflow/server/auth/db/__init__.py +0 -0
- mlflow/server/auth/db/cli.py +18 -0
- mlflow/server/auth/db/migrations/__init__.py +0 -0
- mlflow/server/auth/db/migrations/alembic.ini +110 -0
- mlflow/server/auth/db/migrations/env.py +76 -0
- mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
- mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
- mlflow/server/auth/db/models.py +67 -0
- mlflow/server/auth/db/utils.py +37 -0
- mlflow/server/auth/entities.py +165 -0
- mlflow/server/auth/logo.py +14 -0
- mlflow/server/auth/permissions.py +65 -0
- mlflow/server/auth/routes.py +18 -0
- mlflow/server/auth/sqlalchemy_store.py +263 -0
- mlflow/server/graphql/__init__.py +0 -0
- mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
- mlflow/server/graphql/graphql_custom_scalars.py +24 -0
- mlflow/server/graphql/graphql_errors.py +15 -0
- mlflow/server/graphql/graphql_no_batching.py +89 -0
- mlflow/server/graphql/graphql_schema_extensions.py +74 -0
- mlflow/server/handlers.py +3217 -0
- mlflow/server/prometheus_exporter.py +17 -0
- mlflow/server/validation.py +30 -0
- mlflow/shap/__init__.py +691 -0
- mlflow/sklearn/__init__.py +1994 -0
- mlflow/sklearn/utils.py +1041 -0
- mlflow/smolagents/__init__.py +66 -0
- mlflow/smolagents/autolog.py +139 -0
- mlflow/smolagents/chat.py +29 -0
- mlflow/store/__init__.py +10 -0
- mlflow/store/_unity_catalog/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/constants.py +2 -0
- mlflow/store/_unity_catalog/registry/__init__.py +6 -0
- mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
- mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
- mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
- mlflow/store/_unity_catalog/registry/utils.py +121 -0
- mlflow/store/artifact/__init__.py +0 -0
- mlflow/store/artifact/artifact_repo.py +472 -0
- mlflow/store/artifact/artifact_repository_registry.py +154 -0
- mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
- mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
- mlflow/store/artifact/cli.py +141 -0
- mlflow/store/artifact/cloud_artifact_repo.py +332 -0
- mlflow/store/artifact/databricks_artifact_repo.py +729 -0
- mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
- mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
- mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
- mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
- mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
- mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
- mlflow/store/artifact/ftp_artifact_repo.py +132 -0
- mlflow/store/artifact/gcs_artifact_repo.py +296 -0
- mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
- mlflow/store/artifact/http_artifact_repo.py +218 -0
- mlflow/store/artifact/local_artifact_repo.py +142 -0
- mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
- mlflow/store/artifact/models_artifact_repo.py +259 -0
- mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
- mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
- mlflow/store/artifact/r2_artifact_repo.py +70 -0
- mlflow/store/artifact/runs_artifact_repo.py +265 -0
- mlflow/store/artifact/s3_artifact_repo.py +330 -0
- mlflow/store/artifact/sftp_artifact_repo.py +141 -0
- mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
- mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
- mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
- mlflow/store/artifact/utils/__init__.py +0 -0
- mlflow/store/artifact/utils/models.py +148 -0
- mlflow/store/db/__init__.py +0 -0
- mlflow/store/db/base_sql_model.py +3 -0
- mlflow/store/db/db_types.py +10 -0
- mlflow/store/db/utils.py +314 -0
- mlflow/store/db_migrations/__init__.py +0 -0
- mlflow/store/db_migrations/alembic.ini +74 -0
- mlflow/store/db_migrations/env.py +84 -0
- mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
- mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
- mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
- mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
- mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
- mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
- mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
- mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
- mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
- mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
- mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
- mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
- mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
- mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
- mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
- mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
- mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
- mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
- mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
- mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
- mlflow/store/db_migrations/versions/__init__.py +0 -0
- mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
- mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
- mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
- mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
- mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
- mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
- mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
- mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
- mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
- mlflow/store/entities/__init__.py +3 -0
- mlflow/store/entities/paged_list.py +18 -0
- mlflow/store/model_registry/__init__.py +10 -0
- mlflow/store/model_registry/abstract_store.py +1081 -0
- mlflow/store/model_registry/base_rest_store.py +44 -0
- mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
- mlflow/store/model_registry/dbmodels/__init__.py +0 -0
- mlflow/store/model_registry/dbmodels/models.py +206 -0
- mlflow/store/model_registry/file_store.py +1091 -0
- mlflow/store/model_registry/rest_store.py +481 -0
- mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
- mlflow/store/tracking/__init__.py +23 -0
- mlflow/store/tracking/abstract_store.py +816 -0
- mlflow/store/tracking/dbmodels/__init__.py +0 -0
- mlflow/store/tracking/dbmodels/initial_models.py +243 -0
- mlflow/store/tracking/dbmodels/models.py +1073 -0
- mlflow/store/tracking/file_store.py +2438 -0
- mlflow/store/tracking/postgres_managed_identity.py +146 -0
- mlflow/store/tracking/rest_store.py +1131 -0
- mlflow/store/tracking/sqlalchemy_store.py +2785 -0
- mlflow/system_metrics/__init__.py +61 -0
- mlflow/system_metrics/metrics/__init__.py +0 -0
- mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
- mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
- mlflow/system_metrics/metrics/disk_monitor.py +21 -0
- mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
- mlflow/system_metrics/metrics/network_monitor.py +34 -0
- mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
- mlflow/system_metrics/system_metrics_monitor.py +198 -0
- mlflow/tracing/__init__.py +16 -0
- mlflow/tracing/assessment.py +356 -0
- mlflow/tracing/client.py +531 -0
- mlflow/tracing/config.py +125 -0
- mlflow/tracing/constant.py +105 -0
- mlflow/tracing/destination.py +81 -0
- mlflow/tracing/display/__init__.py +40 -0
- mlflow/tracing/display/display_handler.py +196 -0
- mlflow/tracing/export/async_export_queue.py +186 -0
- mlflow/tracing/export/inference_table.py +138 -0
- mlflow/tracing/export/mlflow_v3.py +137 -0
- mlflow/tracing/export/utils.py +70 -0
- mlflow/tracing/fluent.py +1417 -0
- mlflow/tracing/processor/base_mlflow.py +199 -0
- mlflow/tracing/processor/inference_table.py +175 -0
- mlflow/tracing/processor/mlflow_v3.py +47 -0
- mlflow/tracing/processor/otel.py +73 -0
- mlflow/tracing/provider.py +487 -0
- mlflow/tracing/trace_manager.py +200 -0
- mlflow/tracing/utils/__init__.py +616 -0
- mlflow/tracing/utils/artifact_utils.py +28 -0
- mlflow/tracing/utils/copy.py +55 -0
- mlflow/tracing/utils/environment.py +55 -0
- mlflow/tracing/utils/exception.py +21 -0
- mlflow/tracing/utils/once.py +35 -0
- mlflow/tracing/utils/otlp.py +63 -0
- mlflow/tracing/utils/processor.py +54 -0
- mlflow/tracing/utils/search.py +292 -0
- mlflow/tracing/utils/timeout.py +250 -0
- mlflow/tracing/utils/token.py +19 -0
- mlflow/tracing/utils/truncation.py +124 -0
- mlflow/tracing/utils/warning.py +76 -0
- mlflow/tracking/__init__.py +39 -0
- mlflow/tracking/_model_registry/__init__.py +1 -0
- mlflow/tracking/_model_registry/client.py +764 -0
- mlflow/tracking/_model_registry/fluent.py +853 -0
- mlflow/tracking/_model_registry/registry.py +67 -0
- mlflow/tracking/_model_registry/utils.py +251 -0
- mlflow/tracking/_tracking_service/__init__.py +0 -0
- mlflow/tracking/_tracking_service/client.py +883 -0
- mlflow/tracking/_tracking_service/registry.py +56 -0
- mlflow/tracking/_tracking_service/utils.py +275 -0
- mlflow/tracking/artifact_utils.py +179 -0
- mlflow/tracking/client.py +5900 -0
- mlflow/tracking/context/__init__.py +0 -0
- mlflow/tracking/context/abstract_context.py +35 -0
- mlflow/tracking/context/databricks_cluster_context.py +15 -0
- mlflow/tracking/context/databricks_command_context.py +15 -0
- mlflow/tracking/context/databricks_job_context.py +49 -0
- mlflow/tracking/context/databricks_notebook_context.py +41 -0
- mlflow/tracking/context/databricks_repo_context.py +43 -0
- mlflow/tracking/context/default_context.py +51 -0
- mlflow/tracking/context/git_context.py +32 -0
- mlflow/tracking/context/registry.py +98 -0
- mlflow/tracking/context/system_environment_context.py +15 -0
- mlflow/tracking/default_experiment/__init__.py +1 -0
- mlflow/tracking/default_experiment/abstract_context.py +43 -0
- mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
- mlflow/tracking/default_experiment/registry.py +75 -0
- mlflow/tracking/fluent.py +3595 -0
- mlflow/tracking/metric_value_conversion_utils.py +93 -0
- mlflow/tracking/multimedia.py +206 -0
- mlflow/tracking/registry.py +86 -0
- mlflow/tracking/request_auth/__init__.py +0 -0
- mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
- mlflow/tracking/request_auth/registry.py +60 -0
- mlflow/tracking/request_header/__init__.py +0 -0
- mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
- mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
- mlflow/tracking/request_header/default_request_header_provider.py +17 -0
- mlflow/tracking/request_header/registry.py +79 -0
- mlflow/transformers/__init__.py +2982 -0
- mlflow/transformers/flavor_config.py +258 -0
- mlflow/transformers/hub_utils.py +83 -0
- mlflow/transformers/llm_inference_utils.py +468 -0
- mlflow/transformers/model_io.py +301 -0
- mlflow/transformers/peft.py +51 -0
- mlflow/transformers/signature.py +183 -0
- mlflow/transformers/torch_utils.py +55 -0
- mlflow/types/__init__.py +21 -0
- mlflow/types/agent.py +270 -0
- mlflow/types/chat.py +240 -0
- mlflow/types/llm.py +935 -0
- mlflow/types/responses.py +139 -0
- mlflow/types/responses_helpers.py +416 -0
- mlflow/types/schema.py +1505 -0
- mlflow/types/type_hints.py +647 -0
- mlflow/types/utils.py +753 -0
- mlflow/utils/__init__.py +283 -0
- mlflow/utils/_capture_modules.py +256 -0
- mlflow/utils/_capture_transformers_modules.py +75 -0
- mlflow/utils/_spark_utils.py +201 -0
- mlflow/utils/_unity_catalog_oss_utils.py +97 -0
- mlflow/utils/_unity_catalog_utils.py +479 -0
- mlflow/utils/annotations.py +218 -0
- mlflow/utils/arguments_utils.py +16 -0
- mlflow/utils/async_logging/__init__.py +1 -0
- mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
- mlflow/utils/async_logging/async_logging_queue.py +366 -0
- mlflow/utils/async_logging/run_artifact.py +38 -0
- mlflow/utils/async_logging/run_batch.py +58 -0
- mlflow/utils/async_logging/run_operations.py +49 -0
- mlflow/utils/autologging_utils/__init__.py +737 -0
- mlflow/utils/autologging_utils/client.py +432 -0
- mlflow/utils/autologging_utils/config.py +33 -0
- mlflow/utils/autologging_utils/events.py +294 -0
- mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
- mlflow/utils/autologging_utils/metrics_queue.py +71 -0
- mlflow/utils/autologging_utils/safety.py +1104 -0
- mlflow/utils/autologging_utils/versioning.py +95 -0
- mlflow/utils/checkpoint_utils.py +206 -0
- mlflow/utils/class_utils.py +6 -0
- mlflow/utils/cli_args.py +257 -0
- mlflow/utils/conda.py +354 -0
- mlflow/utils/credentials.py +231 -0
- mlflow/utils/data_utils.py +17 -0
- mlflow/utils/databricks_utils.py +1436 -0
- mlflow/utils/docstring_utils.py +477 -0
- mlflow/utils/doctor.py +133 -0
- mlflow/utils/download_cloud_file_chunk.py +43 -0
- mlflow/utils/env_manager.py +16 -0
- mlflow/utils/env_pack.py +131 -0
- mlflow/utils/environment.py +1009 -0
- mlflow/utils/exception_utils.py +14 -0
- mlflow/utils/file_utils.py +978 -0
- mlflow/utils/git_utils.py +77 -0
- mlflow/utils/gorilla.py +797 -0
- mlflow/utils/import_hooks/__init__.py +363 -0
- mlflow/utils/lazy_load.py +51 -0
- mlflow/utils/logging_utils.py +168 -0
- mlflow/utils/mime_type_utils.py +58 -0
- mlflow/utils/mlflow_tags.py +103 -0
- mlflow/utils/model_utils.py +486 -0
- mlflow/utils/name_utils.py +346 -0
- mlflow/utils/nfs_on_spark.py +62 -0
- mlflow/utils/openai_utils.py +164 -0
- mlflow/utils/os.py +12 -0
- mlflow/utils/oss_registry_utils.py +29 -0
- mlflow/utils/plugins.py +17 -0
- mlflow/utils/process.py +182 -0
- mlflow/utils/promptlab_utils.py +146 -0
- mlflow/utils/proto_json_utils.py +743 -0
- mlflow/utils/pydantic_utils.py +54 -0
- mlflow/utils/request_utils.py +279 -0
- mlflow/utils/requirements_utils.py +704 -0
- mlflow/utils/rest_utils.py +673 -0
- mlflow/utils/search_logged_model_utils.py +127 -0
- mlflow/utils/search_utils.py +2111 -0
- mlflow/utils/secure_loading.py +221 -0
- mlflow/utils/security_validation.py +384 -0
- mlflow/utils/server_cli_utils.py +61 -0
- mlflow/utils/spark_utils.py +15 -0
- mlflow/utils/string_utils.py +138 -0
- mlflow/utils/thread_utils.py +63 -0
- mlflow/utils/time.py +54 -0
- mlflow/utils/timeout.py +42 -0
- mlflow/utils/uri.py +572 -0
- mlflow/utils/validation.py +662 -0
- mlflow/utils/virtualenv.py +458 -0
- mlflow/utils/warnings_utils.py +25 -0
- mlflow/utils/yaml_utils.py +179 -0
- mlflow/version.py +24 -0
@@ -0,0 +1,280 @@
|
|
1
|
+
import os
|
2
|
+
import platform
|
3
|
+
import shutil
|
4
|
+
import subprocess
|
5
|
+
import sys
|
6
|
+
|
7
|
+
import yaml
|
8
|
+
|
9
|
+
import mlflow
|
10
|
+
from mlflow import MlflowClient
|
11
|
+
from mlflow.environment_variables import MLFLOW_WHEELED_MODEL_PIP_DOWNLOAD_OPTIONS
|
12
|
+
from mlflow.exceptions import MlflowException
|
13
|
+
from mlflow.protos.databricks_pb2 import BAD_REQUEST
|
14
|
+
from mlflow.pyfunc.model import MLMODEL_FILE_NAME, Model
|
15
|
+
from mlflow.store.artifact.utils.models import _parse_model_uri, get_model_name_and_version
|
16
|
+
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
|
17
|
+
from mlflow.utils.environment import (
|
18
|
+
_REQUIREMENTS_FILE_NAME,
|
19
|
+
_get_pip_deps,
|
20
|
+
_mlflow_additional_pip_env,
|
21
|
+
_overwrite_pip_deps,
|
22
|
+
)
|
23
|
+
from mlflow.utils.model_utils import _validate_and_prepare_target_save_path
|
24
|
+
from mlflow.utils.uri import get_databricks_profile_uri_from_artifact_uri
|
25
|
+
|
26
|
+
_WHEELS_FOLDER_NAME = "wheels"
_ORIGINAL_REQ_FILE_NAME = "original_requirements.txt"
_PLATFORM = "platform"


class WheeledModel:
    """
    Helper class to create a model with added dependency wheels from an existing registered model.
    The `wheeled` model contains all the model dependencies as wheels stored as model artifacts.

    .. note::
        This utility only operates on a model that has been registered to the Model Registry.
    """

    def __init__(self, model_uri):
        """
        Args:
            model_uri: A registered model URI of the form
                ``models:/<model_name>/<model_version/stage/latest>``.
        """
        self._model_uri = model_uri
        databricks_profile_uri = (
            get_databricks_profile_uri_from_artifact_uri(model_uri) or mlflow.get_registry_uri()
        )
        client = MlflowClient(registry_uri=databricks_profile_uri)
        self._model_name, _ = get_model_name_and_version(client, model_uri)

    @classmethod
    def log_model(cls, model_uri, registered_model_name=None):
        """
        Logs a registered model as an MLflow artifact for the current run. This only operates on
        a model which has been registered to the Model Registry. Given a registered model_uri (
        e.g. models:/<model_name>/<model_version>), this utility re-logs the model along with all
        the required model libraries back to the Model Registry. The required model libraries are
        stored along with the model as model artifacts. In addition, supporting files to the
        model (e.g. conda.yaml, requirements.txt) are modified to use the added libraries.

        By default, this utility creates a new model version under the same registered model
        specified by ``model_uri``. This behavior can be overridden by specifying the
        ``registered_model_name`` argument.

        Args:
            model_uri: A registered model uri in the Model Registry of the form
                models:/<model_name>/<model_version/stage/latest>
            registered_model_name: The new model version (model with its libraries) is
                registered under the inputted registered_model_name. If None,
                a new version is logged to the existing model in the Model
                Registry.

        .. code-block:: python
            :caption: Example

            # Given a model uri, log the wheeled model
            with mlflow.start_run():
                WheeledModel.log_model(model_uri)
        """
        parsed_uri = _parse_model_uri(model_uri)
        return Model.log(
            artifact_path=None,
            flavor=cls(model_uri),
            registered_model_name=registered_model_name or parsed_uri.name,
        )

    def save_model(self, path, mlflow_model=None):
        """
        Given an existing registered model, saves the model along with its dependencies stored as
        wheels to a path on the local file system.

        This does not modify existing model behavior or existing model flavors. It simply downloads
        the model dependencies as wheels and modifies the requirements.txt and conda.yaml file to
        point to the downloaded wheels.

        The download_command defaults to downloading only binary packages using the
        `--only-binary=:all:` option. This behavior can be overridden using an environment
        variable `MLFLOW_WHEELED_MODEL_PIP_DOWNLOAD_OPTIONS`, which allows setting
        different options such as `--prefer-binary`, `--no-binary`, etc.

        Args:
            path: Local path where the model is to be saved.
            mlflow_model: The new :py:mod:`mlflow.models.Model` metadata file to store the
                updated model metadata.

        Returns:
            The updated :py:mod:`mlflow.models.Model` that was written to the MLmodel file.

        Raises:
            MlflowException: If the model already has wheels added, or if it declares neither a
                pyfunc conda environment nor a requirements.txt file.
        """
        from mlflow.pyfunc import ENV, FLAVOR_NAME, _extract_conda_env

        path = os.path.abspath(path)
        _validate_and_prepare_target_save_path(path)

        local_model_path = _download_artifact_from_uri(self._model_uri, output_path=path)

        wheels_dir = os.path.join(local_model_path, _WHEELS_FOLDER_NAME)
        pip_requirements_path = os.path.join(local_model_path, _REQUIREMENTS_FILE_NAME)
        model_metadata_path = os.path.join(local_model_path, MLMODEL_FILE_NAME)

        model_metadata = Model.load(model_metadata_path)

        # A model that was already wheeled carries a `wheels` entry in its MLmodel file.
        if model_metadata.__dict__.get(_WHEELS_FOLDER_NAME, None) is not None:
            raise MlflowException("Model libraries are already added", BAD_REQUEST)

        conda_env = _extract_conda_env(model_metadata.flavors.get(FLAVOR_NAME, {}).get(ENV, None))
        # Only build a conda.yaml path when the model declares one: os.path.join(..., None)
        # would raise TypeError before the explicit error below could be reported.
        conda_env_path = (
            os.path.join(local_model_path, conda_env) if conda_env is not None else None
        )
        has_requirements_file = os.path.isfile(pip_requirements_path)
        if conda_env_path is None and not has_requirements_file:
            raise MlflowException(
                "Cannot add libraries for model with no logged dependencies.", BAD_REQUEST
            )

        if not has_requirements_file:
            # Reaching here implies conda_env_path is set (checked above).
            self._create_pip_requirement(conda_env_path, pip_requirements_path)

        WheeledModel._download_wheels(
            pip_requirements_path=pip_requirements_path, dst_path=wheels_dir
        )

        # Keep a copy of the original requirement.txt
        shutil.copy2(pip_requirements_path, os.path.join(local_model_path, _ORIGINAL_REQ_FILE_NAME))

        # Update requirements.txt with wheels
        pip_deps = self._overwrite_pip_requirements_with_wheels(
            pip_requirements_path=pip_requirements_path, wheels_dir=wheels_dir
        )

        # Update conda.yaml with wheels, if the model has one
        if conda_env_path is not None:
            self._update_conda_env(pip_deps, conda_env_path)

        # Update MLModel File
        mlflow_model = self._update_mlflow_model(
            original_model_metadata=model_metadata, mlflow_model=mlflow_model
        )
        mlflow_model.save(model_metadata_path)
        return mlflow_model

    def _update_conda_env(self, new_pip_deps, conda_env_path):
        """
        Updates the list pip packages in the conda.yaml file to the list of wheels in the wheels
        directory.
        {
            "name": "env",
            "channels": [...],
            "dependencies": [
                ...,
                "pip",
                {"pip": [...]},  <- Overwrite this with list of wheels
            ],
        }

        Args:
            new_pip_deps: List of pip dependencies as wheels
            conda_env_path: Path to conda.yaml file in the model directory
        """
        with open(conda_env_path) as f:
            conda_env = yaml.safe_load(f)

        new_conda_env = _overwrite_pip_deps(conda_env, new_pip_deps)

        with open(conda_env_path, "w") as out:
            yaml.safe_dump(new_conda_env, stream=out, default_flow_style=False)

    def _update_mlflow_model(self, original_model_metadata, mlflow_model):
        """
        Modifies the MLModel file to reflect updated information such as the run_id,
        utc_time_created. Additionally, this also adds `wheels` to the MLModel file to indicate that
        this is a `wheeled` model.

        Args:
            original_model_metadata: The model metadata stored in the original MLmodel file.
            mlflow_model: :py:mod:`mlflow.models.Model` configuration of the newly created
                wheeled model

        Returns:
            ``mlflow_model`` (created if None) carrying the merged metadata.
        """

        run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
        if mlflow_model is None:
            mlflow_model = Model(run_id=run_id)

        # Truthy fields of the new model override the original metadata; the merged field set
        # is then copied back onto the new model.
        original_model_metadata.__dict__.update(
            {k: v for k, v in mlflow_model.__dict__.items() if v}
        )
        mlflow_model.__dict__.update(original_model_metadata.__dict__)
        mlflow_model.artifact_path = WheeledModel.get_wheel_artifact_path(
            mlflow_model.artifact_path
        )

        mlflow_model.wheels = {_PLATFORM: platform.platform()}
        return mlflow_model

    @staticmethod
    def _get_pip_wheel_command(pip_wheel_options, dst_path, pip_requirements_path):
        """
        Builds the ``pip wheel`` argument list used by ``_download_wheels``.

        Args:
            pip_wheel_options: Extra pip options as one string (e.g. ``"--prefer-binary
                --no-deps"``). It is tokenized with :py:func:`shlex.split` so each option
                reaches pip as a separate argument; empty or None contributes nothing.
            dst_path: Directory the wheels are written to.
            pip_requirements_path: Path to the requirements.txt to resolve.

        Returns:
            The argv list for :py:func:`subprocess.run`.
        """
        import shlex

        # Non-POSIX splitting on Windows keeps backslashes in paths (e.g. --find-links=C:\wheels).
        extra_options = (
            shlex.split(pip_wheel_options, posix=os.name != "nt") if pip_wheel_options else []
        )
        return [
            sys.executable,
            "-m",
            "pip",
            "wheel",
            *extra_options,
            "--wheel-dir",
            dst_path,
            "-r",
            pip_requirements_path,
            "--no-cache-dir",
            "--progress-bar=off",
        ]

    @classmethod
    def _download_wheels(cls, pip_requirements_path, dst_path):
        """
        Downloads all the wheels of the dependencies specified in the requirements.txt file.
        The pip wheel download_command defaults to downloading only binary packages using
        the `--only-binary=:all:` option. This behavior can be overridden using an
        environment variable `MLFLOW_WHEELED_MODEL_PIP_DOWNLOAD_OPTIONS`, which allows
        setting different options such as `--prefer-binary`, `--no-binary`, etc.

        Args:
            pip_requirements_path: Path to requirements.txt in the model directory
            dst_path: Path to the directory where the wheels are to be downloaded

        Raises:
            MlflowException: If ``pip wheel`` exits with a non-zero status; the message
                includes pip's combined stdout/stderr.
        """
        os.makedirs(dst_path, exist_ok=True)

        pip_wheel_options = MLFLOW_WHEELED_MODEL_PIP_DOWNLOAD_OPTIONS.get()
        command = cls._get_pip_wheel_command(pip_wheel_options, dst_path, pip_requirements_path)

        try:
            subprocess.run(
                command,
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                # Decode pip's output so the error message is readable text rather than bytes.
                text=True,
                errors="replace",
            )
        except subprocess.CalledProcessError as e:
            raise MlflowException(
                f"An error occurred while downloading the dependency wheels: {e.stdout}"
            ) from e

    def _overwrite_pip_requirements_with_wheels(self, pip_requirements_path, wheels_dir):
        """
        Overwrites the requirements.txt with the wheels of the required dependencies.

        Args:
            pip_requirements_path: Path to requirements.txt in the model directory.
            wheels_dir: Path to directory where wheels are stored.

        Returns:
            The ``wheels/<file>.whl`` entries written, in sorted (deterministic) order.
        """
        wheels = [
            os.path.join(_WHEELS_FOLDER_NAME, wheel_file)
            for wheel_file in sorted(os.listdir(wheels_dir))
            if wheel_file.endswith(".whl")
        ]
        with open(pip_requirements_path, "w") as wheels_requirements:
            wheels_requirements.writelines(wheel + "\n" for wheel in wheels)
        return wheels

    def _create_pip_requirement(self, conda_env_path, pip_requirements_path):
        """
        This method creates a requirements.txt file for the model dependencies if the file does not
        already exist. It uses the pip dependencies found in the conda.yaml env file.

        Args:
            conda_env_path: Path to conda.yaml env file which contains the required pip
                dependencies
            pip_requirements_path: Path where the new requirements.txt will be created.
        """
        with open(conda_env_path) as f:
            conda_env = yaml.safe_load(f)
        pip_deps = _get_pip_deps(conda_env)
        _mlflow_additional_pip_env(pip_deps, pip_requirements_path)

    @classmethod
    def get_wheel_artifact_path(cls, original_artifact_path):
        """Returns the artifact path of the wheeled model: ``<original>_wheels``."""
        return original_artifact_path + "_" + _WHEELS_FOLDER_NAME
|
@@ -0,0 +1,57 @@
|
|
1
|
+
"""
|
2
|
+
The ``mlflow.openai`` module provides an API for logging and loading OpenAI models.
|
3
|
+
|
4
|
+
Credential management for OpenAI on Databricks
|
5
|
+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
6
|
+
|
7
|
+
.. warning::
|
8
|
+
|
9
|
+
Specifying secrets for model serving with ``MLFLOW_OPENAI_SECRET_SCOPE`` is deprecated.
|
10
|
+
Use `secrets-based environment variables <https://docs.databricks.com/en/machine-learning/model-serving/store-env-variable-model-serving.html>`_
|
11
|
+
instead.
|
12
|
+
|
13
|
+
When this flavor logs a model on Databricks, it saves a YAML file with the following contents as
|
14
|
+
``openai.yaml`` if the ``MLFLOW_OPENAI_SECRET_SCOPE`` environment variable is set.
|
15
|
+
|
16
|
+
.. code-block:: yaml
|
17
|
+
|
18
|
+
OPENAI_API_BASE: {scope}:openai_api_base
|
19
|
+
OPENAI_API_KEY: {scope}:openai_api_key
|
20
|
+
OPENAI_API_KEY_PATH: {scope}:openai_api_key_path
|
21
|
+
OPENAI_API_TYPE: {scope}:openai_api_type
|
22
|
+
OPENAI_ORGANIZATION: {scope}:openai_organization
|
23
|
+
|
24
|
+
- ``{scope}`` is the value of the ``MLFLOW_OPENAI_SECRET_SCOPE`` environment variable.
|
25
|
+
- The keys are the environment variables that the ``openai-python`` package uses to
|
26
|
+
configure the API client.
|
27
|
+
- The values are the references to the secrets that store the values of the environment
|
28
|
+
variables.
|
29
|
+
|
30
|
+
When the logged model is served on Databricks, each secret will be resolved and set as the
|
31
|
+
corresponding environment variable. See https://docs.databricks.com/security/secrets/index.html
|
32
|
+
for how to set up secrets on Databricks.
|
33
|
+
"""
|
34
|
+
|
35
|
+
from mlflow.openai.autolog import autolog
|
36
|
+
from mlflow.openai.constant import FLAVOR_NAME
|
37
|
+
from mlflow.version import IS_TRACING_SDK_ONLY
|
38
|
+
|
39
|
+
__all__ = ["autolog", "FLAVOR_NAME"]


# The model persistence APIs (load/log/save) need the skinny or full mlflow distribution,
# so they are only exposed when this is not the tracing-only (mlflow-tracing) SDK.
if not IS_TRACING_SDK_ONLY:
    from mlflow.openai.model import _load_pyfunc, load_model, log_model, save_model

    __all__.extend(["load_model", "log_model", "save_model", "_load_pyfunc"])
|
@@ -0,0 +1,364 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import json
|
4
|
+
import logging
|
5
|
+
from typing import Any, Optional
|
6
|
+
|
7
|
+
import agents.tracing as oai
|
8
|
+
from agents import add_trace_processor
|
9
|
+
from agents._run_impl import TraceCtxManager
|
10
|
+
from agents.tracing.setup import GLOBAL_TRACE_PROVIDER
|
11
|
+
from pydantic import BaseModel
|
12
|
+
|
13
|
+
from mlflow.entities.span import LiveSpan, SpanType
|
14
|
+
from mlflow.entities.span_event import SpanEvent
|
15
|
+
from mlflow.entities.span_status import SpanStatus, SpanStatusCode
|
16
|
+
from mlflow.openai import FLAVOR_NAME
|
17
|
+
from mlflow.tracing.constant import SpanAttributeKey
|
18
|
+
from mlflow.tracing.fluent import start_span_no_context
|
19
|
+
from mlflow.types.chat import (
|
20
|
+
ChatMessage,
|
21
|
+
ChatTool,
|
22
|
+
Function,
|
23
|
+
FunctionToolDefinition,
|
24
|
+
TextContentPart,
|
25
|
+
ToolCall,
|
26
|
+
)
|
27
|
+
from mlflow.utils.autologging_utils.safety import safe_patch
|
28
|
+
|
29
|
+
_logger = logging.getLogger(__name__)


class OpenAISpanType:
    """
    String identifiers of the span kinds produced by the OpenAI Agents SDK tracer, matching the
    ``type`` values of its span data classes:
    https://github.com/openai/openai-agents-python/blob/ca8e8bed5d0f33e8a0bc3eabd5f1b0a183e73765/src/agents/tracing/span_data.py#L11
    """

    # Agent runs and the steps they are composed of
    AGENT = "agent"
    FUNCTION = "function"
    HANDOFF = "handoff"
    GUARDRAIL = "guardrail"
    # Model invocations
    GENERATION = "generation"
    RESPONSE = "response"
    # User-defined spans
    CUSTOM = "custom"
|
44
|
+
|
45
|
+
|
46
|
+
# OpenAI Agents span kind -> MLflow span type. Kinds missing from this map (handoff, custom)
# are looked up with SpanType.CHAIN as the fallback.
_SPAN_TYPE_MAP = {
    # Agent-level spans
    OpenAISpanType.AGENT: SpanType.AGENT,
    # Model calls
    OpenAISpanType.GENERATION: SpanType.CHAT_MODEL,
    OpenAISpanType.RESPONSE: SpanType.CHAT_MODEL,
    # Tool-like steps
    OpenAISpanType.FUNCTION: SpanType.TOOL,
    OpenAISpanType.GUARDRAIL: SpanType.TOOL,
}
|
54
|
+
|
55
|
+
|
56
|
+
def add_mlflow_trace_processor():
    """
    Register an ``MlflowOpenAgentTracingProcessor`` with the OpenAI Agents global trace provider.

    If one is already registered, nothing happens, so repeated calls never add a second one.
    """
    registered = GLOBAL_TRACE_PROVIDER._multi_processor._processors
    already_registered = any(
        isinstance(processor, MlflowOpenAgentTracingProcessor) for processor in registered
    )
    if not already_registered:
        add_trace_processor(MlflowOpenAgentTracingProcessor())
|
63
|
+
|
64
|
+
|
65
|
+
def remove_mlflow_trace_processor():
    """
    Drop every ``MlflowOpenAgentTracingProcessor`` from the OpenAI Agents global trace provider,
    leaving all other registered processors in their original order.
    """
    multi_processor = GLOBAL_TRACE_PROVIDER._multi_processor
    multi_processor._processors = list(
        filter(
            lambda processor: not isinstance(processor, MlflowOpenAgentTracingProcessor),
            multi_processor._processors,
        )
    )
|
71
|
+
|
72
|
+
|
73
|
+
class MlflowOpenAgentTracingProcessor(oai.TracingProcessor):
    """
    OpenAI Agents SDK tracing processor that records agent traces and spans as MLflow spans.

    Each Agents trace becomes a root MLflow span of type ``AGENT`` and each Agents span becomes a
    child MLflow span. Live MLflow spans are tracked in ``_span_id_to_mlflow_span``, keyed by the
    Agents trace ID (root spans) or span ID (child spans), until the matching end callback.
    """

    def __init__(
        self,
        project_name: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self._span_id_to_mlflow_span: dict[str, LiveSpan] = {}
        self._project_name = project_name
        # Trace classes whose ``__exit__`` this processor already patched. The patch is
        # class-level, so applying it once per class suffices; re-applying would stack wrappers.
        self._patched_trace_classes: set[type] = set()

        # Patch TraceCtxManager to handle exceptions from the agent properly.
        # The original implementation does not propagate exceptions to the root span,
        # resulting in the trace having status OK even if there is an exception.
        def _patched_exit(original, instance, exc_type, exc_val, exc_tb):
            try:
                if exc_val and instance.trace:
                    self._record_exception(instance.trace.trace_id, exc_val)
            except Exception:
                _logger.debug("Failed to handle exception in MLflow trace", exc_info=True)

            return original(instance, exc_type, exc_val, exc_tb)

        safe_patch(
            FLAVOR_NAME,
            TraceCtxManager,
            "__exit__",
            _patched_exit,
        )

    def _record_exception(self, trace_id, exc_val) -> None:
        """
        Add ``exc_val`` as an exception event to the root MLflow span of ``trace_id`` and set its
        status to ERROR. Does nothing if no MLflow span is tracked for ``trace_id``.
        """
        span = self._span_id_to_mlflow_span.get(trace_id)
        if span is None:
            return
        span.add_event(SpanEvent.from_exception(exc_val))
        span.set_status(SpanStatusCode.ERROR)

    def _patch_trace_exit(self, trace_cls: type) -> None:
        """
        Patch ``trace_cls.__exit__`` (once per class) so an exception leaving a ``with trace(...)``
        block is recorded on the matching root MLflow span before the trace finishes.

        The span is resolved from the exiting instance's ``trace_id`` at call time, so a single
        class-level patch serves every trace of that class.
        """
        if trace_cls in self._patched_trace_classes:
            return

        def _patched_trace_exit(original, instance, exc_type, exc_val, exc_tb):
            try:
                if exc_val:
                    self._record_exception(getattr(instance, "trace_id", None), exc_val)
            except Exception:
                _logger.debug("Failed to handle exception in MLflow trace", exc_info=True)

            return original(instance, exc_type, exc_val, exc_tb)

        safe_patch(FLAVOR_NAME, trace_cls, "__exit__", _patched_trace_exit)
        self._patched_trace_classes.add(trace_cls)

    def on_trace_start(self, trace: oai.Trace) -> None:
        try:
            mlflow_span = start_span_no_context(
                name=trace.name,
                span_type=SpanType.AGENT,
                # TODO: Trace object doesn't contain input/output. Can we get it somehow?
                inputs="",
                attributes=trace.metadata,
            )
            # NB: Trace ID has different prefix as span ID so will not conflict
            self._span_id_to_mlflow_span[trace.trace_id] = mlflow_span

            # Patch before tagging so a tagging failure cannot skip exception handling.
            self._patch_trace_exit(type(trace))

            if trace.group_id:
                # Group ID is used for grouping multiple agent executions together
                mlflow_span.set_tag("group_id", trace.group_id)
        except Exception:
            _logger.debug("Failed to start MLflow trace", exc_info=True)

    def on_trace_end(self, trace: oai.Trace) -> None:
        try:
            mlflow_span = self._span_id_to_mlflow_span.pop(trace.trace_id, None)
            if mlflow_span is None:
                # on_trace_start failed or never ran for this trace; nothing to end.
                _logger.debug("No MLflow span found for trace %s", trace.trace_id)
                return
            # Keep whatever status was set while the trace ran (e.g. ERROR from a patched exit).
            mlflow_span.end(status=mlflow_span.status, outputs="")
        except Exception:
            _logger.debug("Failed to end MLflow trace", exc_info=True)

    def on_span_start(self, span: oai.Span[Any]) -> None:
        try:
            parent_mlflow_span = self._span_id_to_mlflow_span.get(span.parent_id)

            # Parent might be a trace
            if not parent_mlflow_span:
                parent_mlflow_span = self._span_id_to_mlflow_span.get(span.trace_id)

            inputs, _, attributes = _parse_span_data(span.span_data)

            mlflow_span = start_span_no_context(
                name=_get_span_name(span.span_data),
                span_type=_SPAN_TYPE_MAP.get(span.span_data.type, SpanType.CHAIN),
                parent_span=parent_mlflow_span,
                inputs=inputs,
                attributes=attributes,
            )
            self._span_id_to_mlflow_span[span.span_id] = mlflow_span
        except Exception:
            _logger.debug("Failed to start MLflow span", exc_info=True)

    @staticmethod
    def _record_span_error(mlflow_span: LiveSpan, error: dict[str, Any]) -> SpanStatus:
        """
        Build the ERROR status for an Agents span error and attach an exception event to
        ``mlflow_span``. The status is returned even if attaching the event fails.
        """
        message = error.get("message", "")
        status = SpanStatus(status_code=SpanStatusCode.ERROR, description=message)
        try:
            mlflow_span.add_event(
                SpanEvent(
                    name="exception",
                    attributes={
                        "exception.message": message,
                        "exception.type": "",
                        # default=str keeps a non-JSON-serializable payload from aborting this.
                        "exception.stacktrace": json.dumps(error.get("data"), default=str),
                    },
                )
            )
        except Exception:
            _logger.debug("Failed to record error event on MLflow span", exc_info=True)
        return status

    def on_span_end(self, span: oai.Span[Any]) -> None:
        try:
            mlflow_span = self._span_id_to_mlflow_span.pop(span.span_id, None)
            if mlflow_span is None:
                # on_span_start failed or never ran for this span; nothing to end.
                _logger.debug("No MLflow span found for span %s", span.span_id)
                return

            status = SpanStatusCode.OK
            try:
                if span.error:
                    status = self._record_span_error(mlflow_span, span.error)

                inputs, outputs, attributes = _parse_span_data(span.span_data)
                mlflow_span.set_inputs(inputs)
                mlflow_span.set_outputs(outputs)
                mlflow_span.set_attributes(attributes)
            except Exception:
                _logger.debug("Failed to record data on MLflow span", exc_info=True)

            # End the span even if recording its data failed, so it is never left open.
            mlflow_span.end(status=status)
        except Exception:
            _logger.debug("Failed to end MLflow span", exc_info=True)

    def force_flush(self) -> None:
        # MLflow doesn't need flush but this method is required by the interface
        pass

    def shutdown(self) -> None:
        self.force_flush()
|
207
|
+
|
208
|
+
|
209
|
+
def _get_span_name(span_data: oai.SpanData) -> str:
    """
    Pick the display name for an MLflow span from OpenAI Agents span data.

    Span data exposing a ``name`` attribute uses it directly. Generation, response and handoff
    data get a fixed label, and anything else is named ``"Unknown"``.
    """
    if hasattr(span_data, "name"):
        return span_data.name

    for data_cls, label in (
        (oai.GenerationSpanData, "Generation"),
        (oai.ResponseSpanData, "Response"),
        (oai.HandoffSpanData, "Handoff"),
    ):
        if isinstance(span_data, data_cls):
            return label
    return "Unknown"
|
220
|
+
|
221
|
+
|
222
|
+
def _parse_span_data(span_data: oai.SpanData) -> tuple[Any, Any, dict[str, Any]]:
    """
    Derive the MLflow ``(inputs, outputs, attributes)`` triple for an OpenAI Agents span.

    The fields read depend on ``span_data.type``; response spans are delegated to
    ``_parse_response_span_data``. Span types not handled here yield ``(None, None, {})``.
    """
    span_kind = span_data.type

    if span_kind == OpenAISpanType.AGENT:
        agent_attributes = {
            "handoffs": span_data.handoffs,
            "tools": span_data.tools,
            "output_type": span_data.output_type,
        }
        return None, {"output_type": span_data.output_type}, agent_attributes

    if span_kind == OpenAISpanType.FUNCTION:
        # Decode the tool input as JSON when possible; otherwise keep the raw value.
        try:
            tool_input = json.loads(span_data.input)
        except Exception:
            tool_input = span_data.input
        return tool_input, span_data.output, {}

    if span_kind == OpenAISpanType.GENERATION:
        model_input = span_data.input
        model_output = span_data.output
        generation_attributes = {
            "model": span_data.model,
            "model_config": span_data.model_config,
            "usage": span_data.usage,
        }
        return model_input, model_output, generation_attributes

    if span_kind == OpenAISpanType.RESPONSE:
        return _parse_response_span_data(span_data)

    if span_kind == OpenAISpanType.HANDOFF:
        return {"from_agent": span_data.from_agent}, {"to_agent": span_data.to_agent}, {}

    if span_kind == OpenAISpanType.CUSTOM:
        return None, span_data.data, {}

    if span_kind == OpenAISpanType.GUARDRAIL:
        return None, {"triggered": span_data.triggered}, {}

    return None, None, {}
|
265
|
+
|
266
|
+
|
267
|
+
def _parse_response_span_data(span_data: oai.ResponseSpanData) -> tuple[Any, Any, dict[str, Any]]:
    """
    Extract MLflow inputs, outputs and attributes from an OpenAI Agents ``ResponseSpanData``.

    Returns:
        - inputs: ``span_data.input`` unchanged.
        - outputs: the ``output`` field of the dumped response (None without a response).
        - attributes: every other dumped response field, plus the normalized chat messages under
          ``SpanAttributeKey.CHAT_MESSAGES`` and, when any tool definition parses, the chat tools
          under ``SpanAttributeKey.CHAT_TOOLS``.
    """
    inputs = span_data.input
    response = span_data.response
    response_dict = response.model_dump() if response else {}
    outputs = response_dict.get("output")
    attributes = {k: v for k, v in response_dict.items() if k != "output"}

    # Extract chat messages
    messages = []
    if response and response.instructions:
        messages.append(ChatMessage(role="system", content=response.instructions))
    if isinstance(inputs, str):
        # A bare string input stands for a single user message; iterating it would instead
        # feed individual characters to the message parser.
        if inputs:
            messages.append(ChatMessage(role="user", content=inputs))
    elif inputs:
        messages.extend(m for m in map(_parse_message_like, inputs) if m is not None)
    if response and response.output:
        # Items the chat schema cannot represent (e.g. reasoning) parse to None; drop them.
        messages.extend(m for m in map(_parse_message_like, response.output) if m is not None)
    attributes[SpanAttributeKey.CHAT_MESSAGES] = [m.model_dump_compat() for m in messages]

    # Extract chat tools; `tools` may be present but None in the dumped response.
    chat_tools = []
    for tool_dict in response_dict.get("tools") or []:
        try:
            chat_tools.append(
                ChatTool(
                    type="function",
                    function=FunctionToolDefinition(
                        name=tool_dict["name"],
                        description=tool_dict.get("description"),
                        parameters=tool_dict.get("parameters"),
                        strict=tool_dict.get("strict"),
                    ),
                )
            )
        except Exception as e:
            _logger.debug(f"Failed to parse chat tool: {tool_dict}. Error: {e}")

    if chat_tools:
        attributes[SpanAttributeKey.CHAT_TOOLS] = chat_tools

    return inputs, outputs, attributes
|
307
|
+
|
308
|
+
|
309
|
+
def _parse_message_like(message_like: Any) -> Optional[ChatMessage]:
    """
    Convert one Responses API input/output item into a ``ChatMessage``.

    The item is first validated directly as a ``ChatMessage``. Failing that, pydantic models are
    dumped to dicts and dispatched on their ``type``: ``message``, ``function_call`` and
    ``function_call_output`` are converted. Any other item, including non-mapping items and
    items without a ``type``, is logged at debug level and ``None`` is returned.
    """
    try:
        return ChatMessage.validate_compat(message_like)
    except Exception:
        pass

    if isinstance(message_like, BaseModel):
        message_like = message_like.model_dump()

    if not isinstance(message_like, dict):
        _logger.debug(f"Unsupported message item: {message_like!r}")
        return None

    msg_type = message_like.get("type")
    if msg_type == "message":
        raw_content = message_like.get("content") or []
        if isinstance(raw_content, str):
            # Plain string content is already valid ChatMessage content; don't iterate chars.
            return ChatMessage(role=message_like["role"], content=raw_content)

        content = []
        refusal = None
        for content_block in raw_content:
            # Content is a list of either text or refusal https://github.com/openai/openai-python/blob/9dea82fb8cdd06683f9e8033b54cff219789af7f/src/openai/types/responses/response_output_message.py#L13C38-L13C56
            if "text" in content_block:
                content.append(TextContentPart(type="text", text=content_block["text"]))
            elif "refusal" in content_block:
                refusal = content_block["refusal"]
            else:
                _logger.debug(f"Unknown content type in message: {content_block}")
        return ChatMessage(
            role=message_like["role"],
            content=content,
            refusal=refusal,
        )
    elif msg_type == "function_call":
        return ChatMessage(
            role="assistant",
            content="",
            tool_calls=[
                ToolCall(
                    id=message_like["call_id"],
                    function=Function(
                        name=message_like["name"],
                        arguments=message_like["arguments"],
                    ),
                )
            ],
        )
    elif msg_type == "function_call_output":
        return ChatMessage(
            role="tool",
            content=message_like["output"],
            tool_call_id=message_like["call_id"],
        )

    # Ignore unknown message types.
    # The Responses API supports the following additional item types, which are not
    # supported by our chat standard schema yet:
    # https://github.com/openai/openai-python/blob/9dea82fb8cdd06683f9e8033b54cff219789af7f/src/openai/types/responses/response_output_item.py#L16
    # - File search tool call
    # - Web search tool call
    # - Computer tool call
    # - Reasoning
    _logger.debug(f"Unknown message type: {msg_type}")
    return None
|