genesis-flow 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genesis_flow-1.0.0.dist-info/METADATA +822 -0
- genesis_flow-1.0.0.dist-info/RECORD +645 -0
- genesis_flow-1.0.0.dist-info/WHEEL +5 -0
- genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
- genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
- genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
- mlflow/__init__.py +367 -0
- mlflow/__main__.py +3 -0
- mlflow/ag2/__init__.py +56 -0
- mlflow/ag2/ag2_logger.py +294 -0
- mlflow/anthropic/__init__.py +40 -0
- mlflow/anthropic/autolog.py +129 -0
- mlflow/anthropic/chat.py +144 -0
- mlflow/artifacts/__init__.py +268 -0
- mlflow/autogen/__init__.py +144 -0
- mlflow/autogen/chat.py +142 -0
- mlflow/azure/__init__.py +26 -0
- mlflow/azure/auth_handler.py +257 -0
- mlflow/azure/client.py +319 -0
- mlflow/azure/config.py +120 -0
- mlflow/azure/connection_factory.py +340 -0
- mlflow/azure/exceptions.py +27 -0
- mlflow/azure/stores.py +327 -0
- mlflow/azure/utils.py +183 -0
- mlflow/bedrock/__init__.py +45 -0
- mlflow/bedrock/_autolog.py +202 -0
- mlflow/bedrock/chat.py +122 -0
- mlflow/bedrock/stream.py +160 -0
- mlflow/bedrock/utils.py +43 -0
- mlflow/cli.py +707 -0
- mlflow/client.py +12 -0
- mlflow/config/__init__.py +56 -0
- mlflow/crewai/__init__.py +79 -0
- mlflow/crewai/autolog.py +253 -0
- mlflow/crewai/chat.py +29 -0
- mlflow/data/__init__.py +75 -0
- mlflow/data/artifact_dataset_sources.py +170 -0
- mlflow/data/code_dataset_source.py +40 -0
- mlflow/data/dataset.py +123 -0
- mlflow/data/dataset_registry.py +168 -0
- mlflow/data/dataset_source.py +110 -0
- mlflow/data/dataset_source_registry.py +219 -0
- mlflow/data/delta_dataset_source.py +167 -0
- mlflow/data/digest_utils.py +108 -0
- mlflow/data/evaluation_dataset.py +562 -0
- mlflow/data/filesystem_dataset_source.py +81 -0
- mlflow/data/http_dataset_source.py +145 -0
- mlflow/data/huggingface_dataset.py +258 -0
- mlflow/data/huggingface_dataset_source.py +118 -0
- mlflow/data/meta_dataset.py +104 -0
- mlflow/data/numpy_dataset.py +223 -0
- mlflow/data/pandas_dataset.py +231 -0
- mlflow/data/polars_dataset.py +352 -0
- mlflow/data/pyfunc_dataset_mixin.py +31 -0
- mlflow/data/schema.py +76 -0
- mlflow/data/sources.py +1 -0
- mlflow/data/spark_dataset.py +406 -0
- mlflow/data/spark_dataset_source.py +74 -0
- mlflow/data/spark_delta_utils.py +118 -0
- mlflow/data/tensorflow_dataset.py +350 -0
- mlflow/data/uc_volume_dataset_source.py +81 -0
- mlflow/db.py +27 -0
- mlflow/dspy/__init__.py +17 -0
- mlflow/dspy/autolog.py +197 -0
- mlflow/dspy/callback.py +398 -0
- mlflow/dspy/constant.py +1 -0
- mlflow/dspy/load.py +93 -0
- mlflow/dspy/save.py +393 -0
- mlflow/dspy/util.py +109 -0
- mlflow/dspy/wrapper.py +226 -0
- mlflow/entities/__init__.py +104 -0
- mlflow/entities/_mlflow_object.py +52 -0
- mlflow/entities/assessment.py +545 -0
- mlflow/entities/assessment_error.py +80 -0
- mlflow/entities/assessment_source.py +141 -0
- mlflow/entities/dataset.py +92 -0
- mlflow/entities/dataset_input.py +51 -0
- mlflow/entities/dataset_summary.py +62 -0
- mlflow/entities/document.py +48 -0
- mlflow/entities/experiment.py +109 -0
- mlflow/entities/experiment_tag.py +35 -0
- mlflow/entities/file_info.py +45 -0
- mlflow/entities/input_tag.py +35 -0
- mlflow/entities/lifecycle_stage.py +35 -0
- mlflow/entities/logged_model.py +228 -0
- mlflow/entities/logged_model_input.py +26 -0
- mlflow/entities/logged_model_output.py +32 -0
- mlflow/entities/logged_model_parameter.py +46 -0
- mlflow/entities/logged_model_status.py +74 -0
- mlflow/entities/logged_model_tag.py +33 -0
- mlflow/entities/metric.py +200 -0
- mlflow/entities/model_registry/__init__.py +29 -0
- mlflow/entities/model_registry/_model_registry_entity.py +13 -0
- mlflow/entities/model_registry/model_version.py +243 -0
- mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
- mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
- mlflow/entities/model_registry/model_version_search.py +25 -0
- mlflow/entities/model_registry/model_version_stages.py +25 -0
- mlflow/entities/model_registry/model_version_status.py +35 -0
- mlflow/entities/model_registry/model_version_tag.py +35 -0
- mlflow/entities/model_registry/prompt.py +73 -0
- mlflow/entities/model_registry/prompt_version.py +244 -0
- mlflow/entities/model_registry/registered_model.py +175 -0
- mlflow/entities/model_registry/registered_model_alias.py +35 -0
- mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
- mlflow/entities/model_registry/registered_model_search.py +25 -0
- mlflow/entities/model_registry/registered_model_tag.py +35 -0
- mlflow/entities/multipart_upload.py +74 -0
- mlflow/entities/param.py +49 -0
- mlflow/entities/run.py +97 -0
- mlflow/entities/run_data.py +84 -0
- mlflow/entities/run_info.py +188 -0
- mlflow/entities/run_inputs.py +59 -0
- mlflow/entities/run_outputs.py +43 -0
- mlflow/entities/run_status.py +41 -0
- mlflow/entities/run_tag.py +36 -0
- mlflow/entities/source_type.py +31 -0
- mlflow/entities/span.py +774 -0
- mlflow/entities/span_event.py +96 -0
- mlflow/entities/span_status.py +102 -0
- mlflow/entities/trace.py +317 -0
- mlflow/entities/trace_data.py +71 -0
- mlflow/entities/trace_info.py +220 -0
- mlflow/entities/trace_info_v2.py +162 -0
- mlflow/entities/trace_location.py +173 -0
- mlflow/entities/trace_state.py +39 -0
- mlflow/entities/trace_status.py +68 -0
- mlflow/entities/view_type.py +51 -0
- mlflow/environment_variables.py +866 -0
- mlflow/evaluation/__init__.py +16 -0
- mlflow/evaluation/assessment.py +369 -0
- mlflow/evaluation/evaluation.py +411 -0
- mlflow/evaluation/evaluation_tag.py +61 -0
- mlflow/evaluation/fluent.py +48 -0
- mlflow/evaluation/utils.py +201 -0
- mlflow/exceptions.py +213 -0
- mlflow/experiments.py +140 -0
- mlflow/gemini/__init__.py +81 -0
- mlflow/gemini/autolog.py +186 -0
- mlflow/gemini/chat.py +261 -0
- mlflow/genai/__init__.py +71 -0
- mlflow/genai/datasets/__init__.py +67 -0
- mlflow/genai/datasets/evaluation_dataset.py +131 -0
- mlflow/genai/evaluation/__init__.py +3 -0
- mlflow/genai/evaluation/base.py +411 -0
- mlflow/genai/evaluation/constant.py +23 -0
- mlflow/genai/evaluation/utils.py +244 -0
- mlflow/genai/judges/__init__.py +21 -0
- mlflow/genai/judges/databricks.py +404 -0
- mlflow/genai/label_schemas/__init__.py +153 -0
- mlflow/genai/label_schemas/label_schemas.py +209 -0
- mlflow/genai/labeling/__init__.py +159 -0
- mlflow/genai/labeling/labeling.py +250 -0
- mlflow/genai/optimize/__init__.py +13 -0
- mlflow/genai/optimize/base.py +198 -0
- mlflow/genai/optimize/optimizers/__init__.py +4 -0
- mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
- mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
- mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
- mlflow/genai/optimize/types.py +75 -0
- mlflow/genai/optimize/util.py +30 -0
- mlflow/genai/prompts/__init__.py +206 -0
- mlflow/genai/scheduled_scorers.py +431 -0
- mlflow/genai/scorers/__init__.py +26 -0
- mlflow/genai/scorers/base.py +492 -0
- mlflow/genai/scorers/builtin_scorers.py +765 -0
- mlflow/genai/scorers/scorer_utils.py +138 -0
- mlflow/genai/scorers/validation.py +165 -0
- mlflow/genai/utils/data_validation.py +146 -0
- mlflow/genai/utils/enum_utils.py +23 -0
- mlflow/genai/utils/trace_utils.py +211 -0
- mlflow/groq/__init__.py +42 -0
- mlflow/groq/_groq_autolog.py +74 -0
- mlflow/johnsnowlabs/__init__.py +888 -0
- mlflow/langchain/__init__.py +24 -0
- mlflow/langchain/api_request_parallel_processor.py +330 -0
- mlflow/langchain/autolog.py +147 -0
- mlflow/langchain/chat_agent_langgraph.py +340 -0
- mlflow/langchain/constant.py +1 -0
- mlflow/langchain/constants.py +1 -0
- mlflow/langchain/databricks_dependencies.py +444 -0
- mlflow/langchain/langchain_tracer.py +597 -0
- mlflow/langchain/model.py +919 -0
- mlflow/langchain/output_parsers.py +142 -0
- mlflow/langchain/retriever_chain.py +153 -0
- mlflow/langchain/runnables.py +527 -0
- mlflow/langchain/utils/chat.py +402 -0
- mlflow/langchain/utils/logging.py +671 -0
- mlflow/langchain/utils/serialization.py +36 -0
- mlflow/legacy_databricks_cli/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/provider.py +482 -0
- mlflow/litellm/__init__.py +175 -0
- mlflow/llama_index/__init__.py +22 -0
- mlflow/llama_index/autolog.py +55 -0
- mlflow/llama_index/chat.py +43 -0
- mlflow/llama_index/constant.py +1 -0
- mlflow/llama_index/model.py +577 -0
- mlflow/llama_index/pyfunc_wrapper.py +332 -0
- mlflow/llama_index/serialize_objects.py +188 -0
- mlflow/llama_index/tracer.py +561 -0
- mlflow/metrics/__init__.py +479 -0
- mlflow/metrics/base.py +39 -0
- mlflow/metrics/genai/__init__.py +25 -0
- mlflow/metrics/genai/base.py +101 -0
- mlflow/metrics/genai/genai_metric.py +771 -0
- mlflow/metrics/genai/metric_definitions.py +450 -0
- mlflow/metrics/genai/model_utils.py +371 -0
- mlflow/metrics/genai/prompt_template.py +68 -0
- mlflow/metrics/genai/prompts/__init__.py +0 -0
- mlflow/metrics/genai/prompts/v1.py +422 -0
- mlflow/metrics/genai/utils.py +6 -0
- mlflow/metrics/metric_definitions.py +619 -0
- mlflow/mismatch.py +34 -0
- mlflow/mistral/__init__.py +34 -0
- mlflow/mistral/autolog.py +71 -0
- mlflow/mistral/chat.py +135 -0
- mlflow/ml_package_versions.py +452 -0
- mlflow/models/__init__.py +97 -0
- mlflow/models/auth_policy.py +83 -0
- mlflow/models/cli.py +354 -0
- mlflow/models/container/__init__.py +294 -0
- mlflow/models/container/scoring_server/__init__.py +0 -0
- mlflow/models/container/scoring_server/nginx.conf +39 -0
- mlflow/models/dependencies_schemas.py +287 -0
- mlflow/models/display_utils.py +158 -0
- mlflow/models/docker_utils.py +211 -0
- mlflow/models/evaluation/__init__.py +23 -0
- mlflow/models/evaluation/_shap_patch.py +64 -0
- mlflow/models/evaluation/artifacts.py +194 -0
- mlflow/models/evaluation/base.py +1811 -0
- mlflow/models/evaluation/calibration_curve.py +109 -0
- mlflow/models/evaluation/default_evaluator.py +996 -0
- mlflow/models/evaluation/deprecated.py +23 -0
- mlflow/models/evaluation/evaluator_registry.py +80 -0
- mlflow/models/evaluation/evaluators/classifier.py +704 -0
- mlflow/models/evaluation/evaluators/default.py +233 -0
- mlflow/models/evaluation/evaluators/regressor.py +96 -0
- mlflow/models/evaluation/evaluators/shap.py +296 -0
- mlflow/models/evaluation/lift_curve.py +178 -0
- mlflow/models/evaluation/utils/metric.py +123 -0
- mlflow/models/evaluation/utils/trace.py +179 -0
- mlflow/models/evaluation/validation.py +434 -0
- mlflow/models/flavor_backend.py +93 -0
- mlflow/models/flavor_backend_registry.py +53 -0
- mlflow/models/model.py +1639 -0
- mlflow/models/model_config.py +150 -0
- mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
- mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
- mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
- mlflow/models/python_api.py +369 -0
- mlflow/models/rag_signatures.py +128 -0
- mlflow/models/resources.py +321 -0
- mlflow/models/signature.py +662 -0
- mlflow/models/utils.py +2054 -0
- mlflow/models/wheeled_model.py +280 -0
- mlflow/openai/__init__.py +57 -0
- mlflow/openai/_agent_tracer.py +364 -0
- mlflow/openai/api_request_parallel_processor.py +131 -0
- mlflow/openai/autolog.py +509 -0
- mlflow/openai/constant.py +1 -0
- mlflow/openai/model.py +824 -0
- mlflow/openai/utils/chat_schema.py +367 -0
- mlflow/optuna/__init__.py +3 -0
- mlflow/optuna/storage.py +646 -0
- mlflow/plugins/__init__.py +72 -0
- mlflow/plugins/base.py +358 -0
- mlflow/plugins/builtin/__init__.py +24 -0
- mlflow/plugins/builtin/pytorch_plugin.py +150 -0
- mlflow/plugins/builtin/sklearn_plugin.py +158 -0
- mlflow/plugins/builtin/transformers_plugin.py +187 -0
- mlflow/plugins/cli.py +321 -0
- mlflow/plugins/discovery.py +340 -0
- mlflow/plugins/manager.py +465 -0
- mlflow/plugins/registry.py +316 -0
- mlflow/plugins/templates/framework_plugin_template.py +329 -0
- mlflow/prompt/constants.py +20 -0
- mlflow/prompt/promptlab_model.py +197 -0
- mlflow/prompt/registry_utils.py +248 -0
- mlflow/promptflow/__init__.py +495 -0
- mlflow/protos/__init__.py +0 -0
- mlflow/protos/assessments_pb2.py +174 -0
- mlflow/protos/databricks_artifacts_pb2.py +489 -0
- mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
- mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
- mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
- mlflow/protos/databricks_pb2.py +267 -0
- mlflow/protos/databricks_trace_server_pb2.py +374 -0
- mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
- mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
- mlflow/protos/facet_feature_statistics_pb2.py +296 -0
- mlflow/protos/internal_pb2.py +77 -0
- mlflow/protos/mlflow_artifacts_pb2.py +336 -0
- mlflow/protos/model_registry_pb2.py +1073 -0
- mlflow/protos/scalapb/__init__.py +0 -0
- mlflow/protos/scalapb/scalapb_pb2.py +104 -0
- mlflow/protos/service_pb2.py +2600 -0
- mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
- mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
- mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
- mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
- mlflow/py.typed +0 -0
- mlflow/pydantic_ai/__init__.py +57 -0
- mlflow/pydantic_ai/autolog.py +173 -0
- mlflow/pyfunc/__init__.py +3844 -0
- mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
- mlflow/pyfunc/backend.py +523 -0
- mlflow/pyfunc/context.py +78 -0
- mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
- mlflow/pyfunc/loaders/__init__.py +7 -0
- mlflow/pyfunc/loaders/chat_agent.py +117 -0
- mlflow/pyfunc/loaders/chat_model.py +125 -0
- mlflow/pyfunc/loaders/code_model.py +31 -0
- mlflow/pyfunc/loaders/responses_agent.py +112 -0
- mlflow/pyfunc/mlserver.py +46 -0
- mlflow/pyfunc/model.py +1473 -0
- mlflow/pyfunc/scoring_server/__init__.py +604 -0
- mlflow/pyfunc/scoring_server/app.py +7 -0
- mlflow/pyfunc/scoring_server/client.py +146 -0
- mlflow/pyfunc/spark_model_cache.py +48 -0
- mlflow/pyfunc/stdin_server.py +44 -0
- mlflow/pyfunc/utils/__init__.py +3 -0
- mlflow/pyfunc/utils/data_validation.py +224 -0
- mlflow/pyfunc/utils/environment.py +22 -0
- mlflow/pyfunc/utils/input_converter.py +47 -0
- mlflow/pyfunc/utils/serving_data_parser.py +11 -0
- mlflow/pytorch/__init__.py +1171 -0
- mlflow/pytorch/_lightning_autolog.py +580 -0
- mlflow/pytorch/_pytorch_autolog.py +50 -0
- mlflow/pytorch/pickle_module.py +35 -0
- mlflow/rfunc/__init__.py +42 -0
- mlflow/rfunc/backend.py +134 -0
- mlflow/runs.py +89 -0
- mlflow/server/__init__.py +302 -0
- mlflow/server/auth/__init__.py +1224 -0
- mlflow/server/auth/__main__.py +4 -0
- mlflow/server/auth/basic_auth.ini +6 -0
- mlflow/server/auth/cli.py +11 -0
- mlflow/server/auth/client.py +537 -0
- mlflow/server/auth/config.py +34 -0
- mlflow/server/auth/db/__init__.py +0 -0
- mlflow/server/auth/db/cli.py +18 -0
- mlflow/server/auth/db/migrations/__init__.py +0 -0
- mlflow/server/auth/db/migrations/alembic.ini +110 -0
- mlflow/server/auth/db/migrations/env.py +76 -0
- mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
- mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
- mlflow/server/auth/db/models.py +67 -0
- mlflow/server/auth/db/utils.py +37 -0
- mlflow/server/auth/entities.py +165 -0
- mlflow/server/auth/logo.py +14 -0
- mlflow/server/auth/permissions.py +65 -0
- mlflow/server/auth/routes.py +18 -0
- mlflow/server/auth/sqlalchemy_store.py +263 -0
- mlflow/server/graphql/__init__.py +0 -0
- mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
- mlflow/server/graphql/graphql_custom_scalars.py +24 -0
- mlflow/server/graphql/graphql_errors.py +15 -0
- mlflow/server/graphql/graphql_no_batching.py +89 -0
- mlflow/server/graphql/graphql_schema_extensions.py +74 -0
- mlflow/server/handlers.py +3217 -0
- mlflow/server/prometheus_exporter.py +17 -0
- mlflow/server/validation.py +30 -0
- mlflow/shap/__init__.py +691 -0
- mlflow/sklearn/__init__.py +1994 -0
- mlflow/sklearn/utils.py +1041 -0
- mlflow/smolagents/__init__.py +66 -0
- mlflow/smolagents/autolog.py +139 -0
- mlflow/smolagents/chat.py +29 -0
- mlflow/store/__init__.py +10 -0
- mlflow/store/_unity_catalog/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/constants.py +2 -0
- mlflow/store/_unity_catalog/registry/__init__.py +6 -0
- mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
- mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
- mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
- mlflow/store/_unity_catalog/registry/utils.py +121 -0
- mlflow/store/artifact/__init__.py +0 -0
- mlflow/store/artifact/artifact_repo.py +472 -0
- mlflow/store/artifact/artifact_repository_registry.py +154 -0
- mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
- mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
- mlflow/store/artifact/cli.py +141 -0
- mlflow/store/artifact/cloud_artifact_repo.py +332 -0
- mlflow/store/artifact/databricks_artifact_repo.py +729 -0
- mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
- mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
- mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
- mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
- mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
- mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
- mlflow/store/artifact/ftp_artifact_repo.py +132 -0
- mlflow/store/artifact/gcs_artifact_repo.py +296 -0
- mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
- mlflow/store/artifact/http_artifact_repo.py +218 -0
- mlflow/store/artifact/local_artifact_repo.py +142 -0
- mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
- mlflow/store/artifact/models_artifact_repo.py +259 -0
- mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
- mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
- mlflow/store/artifact/r2_artifact_repo.py +70 -0
- mlflow/store/artifact/runs_artifact_repo.py +265 -0
- mlflow/store/artifact/s3_artifact_repo.py +330 -0
- mlflow/store/artifact/sftp_artifact_repo.py +141 -0
- mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
- mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
- mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
- mlflow/store/artifact/utils/__init__.py +0 -0
- mlflow/store/artifact/utils/models.py +148 -0
- mlflow/store/db/__init__.py +0 -0
- mlflow/store/db/base_sql_model.py +3 -0
- mlflow/store/db/db_types.py +10 -0
- mlflow/store/db/utils.py +314 -0
- mlflow/store/db_migrations/__init__.py +0 -0
- mlflow/store/db_migrations/alembic.ini +74 -0
- mlflow/store/db_migrations/env.py +84 -0
- mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
- mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
- mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
- mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
- mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
- mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
- mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
- mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
- mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
- mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
- mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
- mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
- mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
- mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
- mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
- mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
- mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
- mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
- mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
- mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
- mlflow/store/db_migrations/versions/__init__.py +0 -0
- mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
- mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
- mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
- mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
- mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
- mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
- mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
- mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
- mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
- mlflow/store/entities/__init__.py +3 -0
- mlflow/store/entities/paged_list.py +18 -0
- mlflow/store/model_registry/__init__.py +10 -0
- mlflow/store/model_registry/abstract_store.py +1081 -0
- mlflow/store/model_registry/base_rest_store.py +44 -0
- mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
- mlflow/store/model_registry/dbmodels/__init__.py +0 -0
- mlflow/store/model_registry/dbmodels/models.py +206 -0
- mlflow/store/model_registry/file_store.py +1091 -0
- mlflow/store/model_registry/rest_store.py +481 -0
- mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
- mlflow/store/tracking/__init__.py +23 -0
- mlflow/store/tracking/abstract_store.py +816 -0
- mlflow/store/tracking/dbmodels/__init__.py +0 -0
- mlflow/store/tracking/dbmodels/initial_models.py +243 -0
- mlflow/store/tracking/dbmodels/models.py +1073 -0
- mlflow/store/tracking/file_store.py +2438 -0
- mlflow/store/tracking/postgres_managed_identity.py +146 -0
- mlflow/store/tracking/rest_store.py +1131 -0
- mlflow/store/tracking/sqlalchemy_store.py +2785 -0
- mlflow/system_metrics/__init__.py +61 -0
- mlflow/system_metrics/metrics/__init__.py +0 -0
- mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
- mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
- mlflow/system_metrics/metrics/disk_monitor.py +21 -0
- mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
- mlflow/system_metrics/metrics/network_monitor.py +34 -0
- mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
- mlflow/system_metrics/system_metrics_monitor.py +198 -0
- mlflow/tracing/__init__.py +16 -0
- mlflow/tracing/assessment.py +356 -0
- mlflow/tracing/client.py +531 -0
- mlflow/tracing/config.py +125 -0
- mlflow/tracing/constant.py +105 -0
- mlflow/tracing/destination.py +81 -0
- mlflow/tracing/display/__init__.py +40 -0
- mlflow/tracing/display/display_handler.py +196 -0
- mlflow/tracing/export/async_export_queue.py +186 -0
- mlflow/tracing/export/inference_table.py +138 -0
- mlflow/tracing/export/mlflow_v3.py +137 -0
- mlflow/tracing/export/utils.py +70 -0
- mlflow/tracing/fluent.py +1417 -0
- mlflow/tracing/processor/base_mlflow.py +199 -0
- mlflow/tracing/processor/inference_table.py +175 -0
- mlflow/tracing/processor/mlflow_v3.py +47 -0
- mlflow/tracing/processor/otel.py +73 -0
- mlflow/tracing/provider.py +487 -0
- mlflow/tracing/trace_manager.py +200 -0
- mlflow/tracing/utils/__init__.py +616 -0
- mlflow/tracing/utils/artifact_utils.py +28 -0
- mlflow/tracing/utils/copy.py +55 -0
- mlflow/tracing/utils/environment.py +55 -0
- mlflow/tracing/utils/exception.py +21 -0
- mlflow/tracing/utils/once.py +35 -0
- mlflow/tracing/utils/otlp.py +63 -0
- mlflow/tracing/utils/processor.py +54 -0
- mlflow/tracing/utils/search.py +292 -0
- mlflow/tracing/utils/timeout.py +250 -0
- mlflow/tracing/utils/token.py +19 -0
- mlflow/tracing/utils/truncation.py +124 -0
- mlflow/tracing/utils/warning.py +76 -0
- mlflow/tracking/__init__.py +39 -0
- mlflow/tracking/_model_registry/__init__.py +1 -0
- mlflow/tracking/_model_registry/client.py +764 -0
- mlflow/tracking/_model_registry/fluent.py +853 -0
- mlflow/tracking/_model_registry/registry.py +67 -0
- mlflow/tracking/_model_registry/utils.py +251 -0
- mlflow/tracking/_tracking_service/__init__.py +0 -0
- mlflow/tracking/_tracking_service/client.py +883 -0
- mlflow/tracking/_tracking_service/registry.py +56 -0
- mlflow/tracking/_tracking_service/utils.py +275 -0
- mlflow/tracking/artifact_utils.py +179 -0
- mlflow/tracking/client.py +5900 -0
- mlflow/tracking/context/__init__.py +0 -0
- mlflow/tracking/context/abstract_context.py +35 -0
- mlflow/tracking/context/databricks_cluster_context.py +15 -0
- mlflow/tracking/context/databricks_command_context.py +15 -0
- mlflow/tracking/context/databricks_job_context.py +49 -0
- mlflow/tracking/context/databricks_notebook_context.py +41 -0
- mlflow/tracking/context/databricks_repo_context.py +43 -0
- mlflow/tracking/context/default_context.py +51 -0
- mlflow/tracking/context/git_context.py +32 -0
- mlflow/tracking/context/registry.py +98 -0
- mlflow/tracking/context/system_environment_context.py +15 -0
- mlflow/tracking/default_experiment/__init__.py +1 -0
- mlflow/tracking/default_experiment/abstract_context.py +43 -0
- mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
- mlflow/tracking/default_experiment/registry.py +75 -0
- mlflow/tracking/fluent.py +3595 -0
- mlflow/tracking/metric_value_conversion_utils.py +93 -0
- mlflow/tracking/multimedia.py +206 -0
- mlflow/tracking/registry.py +86 -0
- mlflow/tracking/request_auth/__init__.py +0 -0
- mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
- mlflow/tracking/request_auth/registry.py +60 -0
- mlflow/tracking/request_header/__init__.py +0 -0
- mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
- mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
- mlflow/tracking/request_header/default_request_header_provider.py +17 -0
- mlflow/tracking/request_header/registry.py +79 -0
- mlflow/transformers/__init__.py +2982 -0
- mlflow/transformers/flavor_config.py +258 -0
- mlflow/transformers/hub_utils.py +83 -0
- mlflow/transformers/llm_inference_utils.py +468 -0
- mlflow/transformers/model_io.py +301 -0
- mlflow/transformers/peft.py +51 -0
- mlflow/transformers/signature.py +183 -0
- mlflow/transformers/torch_utils.py +55 -0
- mlflow/types/__init__.py +21 -0
- mlflow/types/agent.py +270 -0
- mlflow/types/chat.py +240 -0
- mlflow/types/llm.py +935 -0
- mlflow/types/responses.py +139 -0
- mlflow/types/responses_helpers.py +416 -0
- mlflow/types/schema.py +1505 -0
- mlflow/types/type_hints.py +647 -0
- mlflow/types/utils.py +753 -0
- mlflow/utils/__init__.py +283 -0
- mlflow/utils/_capture_modules.py +256 -0
- mlflow/utils/_capture_transformers_modules.py +75 -0
- mlflow/utils/_spark_utils.py +201 -0
- mlflow/utils/_unity_catalog_oss_utils.py +97 -0
- mlflow/utils/_unity_catalog_utils.py +479 -0
- mlflow/utils/annotations.py +218 -0
- mlflow/utils/arguments_utils.py +16 -0
- mlflow/utils/async_logging/__init__.py +1 -0
- mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
- mlflow/utils/async_logging/async_logging_queue.py +366 -0
- mlflow/utils/async_logging/run_artifact.py +38 -0
- mlflow/utils/async_logging/run_batch.py +58 -0
- mlflow/utils/async_logging/run_operations.py +49 -0
- mlflow/utils/autologging_utils/__init__.py +737 -0
- mlflow/utils/autologging_utils/client.py +432 -0
- mlflow/utils/autologging_utils/config.py +33 -0
- mlflow/utils/autologging_utils/events.py +294 -0
- mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
- mlflow/utils/autologging_utils/metrics_queue.py +71 -0
- mlflow/utils/autologging_utils/safety.py +1104 -0
- mlflow/utils/autologging_utils/versioning.py +95 -0
- mlflow/utils/checkpoint_utils.py +206 -0
- mlflow/utils/class_utils.py +6 -0
- mlflow/utils/cli_args.py +257 -0
- mlflow/utils/conda.py +354 -0
- mlflow/utils/credentials.py +231 -0
- mlflow/utils/data_utils.py +17 -0
- mlflow/utils/databricks_utils.py +1436 -0
- mlflow/utils/docstring_utils.py +477 -0
- mlflow/utils/doctor.py +133 -0
- mlflow/utils/download_cloud_file_chunk.py +43 -0
- mlflow/utils/env_manager.py +16 -0
- mlflow/utils/env_pack.py +131 -0
- mlflow/utils/environment.py +1009 -0
- mlflow/utils/exception_utils.py +14 -0
- mlflow/utils/file_utils.py +978 -0
- mlflow/utils/git_utils.py +77 -0
- mlflow/utils/gorilla.py +797 -0
- mlflow/utils/import_hooks/__init__.py +363 -0
- mlflow/utils/lazy_load.py +51 -0
- mlflow/utils/logging_utils.py +168 -0
- mlflow/utils/mime_type_utils.py +58 -0
- mlflow/utils/mlflow_tags.py +103 -0
- mlflow/utils/model_utils.py +486 -0
- mlflow/utils/name_utils.py +346 -0
- mlflow/utils/nfs_on_spark.py +62 -0
- mlflow/utils/openai_utils.py +164 -0
- mlflow/utils/os.py +12 -0
- mlflow/utils/oss_registry_utils.py +29 -0
- mlflow/utils/plugins.py +17 -0
- mlflow/utils/process.py +182 -0
- mlflow/utils/promptlab_utils.py +146 -0
- mlflow/utils/proto_json_utils.py +743 -0
- mlflow/utils/pydantic_utils.py +54 -0
- mlflow/utils/request_utils.py +279 -0
- mlflow/utils/requirements_utils.py +704 -0
- mlflow/utils/rest_utils.py +673 -0
- mlflow/utils/search_logged_model_utils.py +127 -0
- mlflow/utils/search_utils.py +2111 -0
- mlflow/utils/secure_loading.py +221 -0
- mlflow/utils/security_validation.py +384 -0
- mlflow/utils/server_cli_utils.py +61 -0
- mlflow/utils/spark_utils.py +15 -0
- mlflow/utils/string_utils.py +138 -0
- mlflow/utils/thread_utils.py +63 -0
- mlflow/utils/time.py +54 -0
- mlflow/utils/timeout.py +42 -0
- mlflow/utils/uri.py +572 -0
- mlflow/utils/validation.py +662 -0
- mlflow/utils/virtualenv.py +458 -0
- mlflow/utils/warnings_utils.py +25 -0
- mlflow/utils/yaml_utils.py +179 -0
- mlflow/version.py +24 -0
mlflow/openai/model.py
ADDED
@@ -0,0 +1,824 @@
|
|
1
|
+
import importlib.metadata
|
2
|
+
import itertools
|
3
|
+
import logging
|
4
|
+
import os
|
5
|
+
import warnings
|
6
|
+
from functools import partial
|
7
|
+
from string import Formatter
|
8
|
+
from typing import Any, Optional, Union
|
9
|
+
|
10
|
+
import yaml
|
11
|
+
from packaging.version import Version
|
12
|
+
|
13
|
+
import mlflow
|
14
|
+
from mlflow import pyfunc
|
15
|
+
from mlflow.entities.model_registry.prompt import Prompt
|
16
|
+
from mlflow.environment_variables import MLFLOW_OPENAI_SECRET_SCOPE
|
17
|
+
from mlflow.exceptions import MlflowException
|
18
|
+
from mlflow.models import Model, ModelInputExample, ModelSignature
|
19
|
+
from mlflow.models.model import MLMODEL_FILE_NAME, _update_active_model_id_based_on_mlflow_model
|
20
|
+
from mlflow.models.signature import _infer_signature_from_input_example
|
21
|
+
from mlflow.models.utils import _save_example
|
22
|
+
from mlflow.openai.constant import FLAVOR_NAME
|
23
|
+
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
|
24
|
+
from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
|
25
|
+
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
|
26
|
+
from mlflow.types import ColSpec, Schema, TensorSpec
|
27
|
+
from mlflow.utils.annotations import experimental
|
28
|
+
from mlflow.utils.databricks_utils import (
|
29
|
+
check_databricks_secret_scope_access,
|
30
|
+
is_in_databricks_runtime,
|
31
|
+
)
|
32
|
+
from mlflow.utils.docstring_utils import LOG_MODEL_PARAM_DOCS, format_docstring
|
33
|
+
from mlflow.utils.environment import (
|
34
|
+
_CONDA_ENV_FILE_NAME,
|
35
|
+
_CONSTRAINTS_FILE_NAME,
|
36
|
+
_PYTHON_ENV_FILE_NAME,
|
37
|
+
_REQUIREMENTS_FILE_NAME,
|
38
|
+
_mlflow_conda_env,
|
39
|
+
_process_conda_env,
|
40
|
+
_process_pip_requirements,
|
41
|
+
_PythonEnv,
|
42
|
+
_validate_env_arguments,
|
43
|
+
)
|
44
|
+
from mlflow.utils.file_utils import write_to
|
45
|
+
from mlflow.utils.model_utils import (
|
46
|
+
_add_code_from_conf_to_system_path,
|
47
|
+
_get_flavor_configuration,
|
48
|
+
_validate_and_copy_code_paths,
|
49
|
+
_validate_and_prepare_target_save_path,
|
50
|
+
)
|
51
|
+
from mlflow.utils.openai_utils import (
|
52
|
+
_OAITokenHolder,
|
53
|
+
_OpenAIApiConfig,
|
54
|
+
_OpenAIEnvVar,
|
55
|
+
_validate_model_params,
|
56
|
+
)
|
57
|
+
from mlflow.utils.requirements_utils import _get_pinned_requirement
|
58
|
+
|
59
|
+
# Name of the YAML file, inside the saved model directory, that holds the OpenAI model
# name, the task, and the task-specific keyword arguments passed to ``save_model``.
MODEL_FILENAME = "model.yaml"
# Tasks for which ``save_model`` adds the pyfunc flavor and which ``_OpenAIWrapper`` serves.
_PYFUNC_SUPPORTED_TASKS = ("chat.completions", "embeddings", "completions")

_logger = logging.getLogger(__name__)
|
63
|
+
|
64
|
+
|
65
|
+
@experimental(version="2.3.0")
def get_default_pip_requirements():
    """
    Returns:
        The requirement strings (as produced by ``_get_pinned_requirement``) for the
        ``openai``, ``tiktoken`` and ``tenacity`` packages. Calls to :func:`save_model()`
        and :func:`log_model()` produce a pip environment that contains at least these
        requirements.
    """
    required_packages = ("openai", "tiktoken", "tenacity")
    return [_get_pinned_requirement(package) for package in required_packages]
|
74
|
+
|
75
|
+
|
76
|
+
@experimental(version="2.3.0")
def get_default_conda_env():
    """
    Returns:
        The default Conda environment for MLflow Models produced by calls to
        :func:`save_model()` and :func:`log_model()`. The pip dependencies come from
        :func:`get_default_pip_requirements()`.
    """
    pip_dependencies = get_default_pip_requirements()
    return _mlflow_conda_env(additional_pip_deps=pip_dependencies)
|
84
|
+
|
85
|
+
|
86
|
+
def _get_obj_to_task_mapping():
    """Map OpenAI SDK resource classes (and ``Images.edit``) to this flavor's task names.

    The ``beta.chat.completions`` resources are included only when they can be imported
    from the installed ``openai`` package.
    """
    from openai import resources

    task_by_obj = {
        resources.Audio: "audio",
        resources.chat.Completions: "chat.completions",
        resources.Completions: "completions",
        resources.Images.edit: "images.edit",
        resources.Embeddings: "embeddings",
        resources.Files: "files",
        resources.Images: "images",
        resources.FineTuning: "fine_tuning",
        resources.Moderations: "moderations",
        resources.Models: "models",
        resources.chat.AsyncCompletions: "chat.completions",
        resources.AsyncCompletions: "completions",
        resources.AsyncEmbeddings: "embeddings",
    }

    try:
        from openai.resources.beta.chat import completions as beta_chat

        task_by_obj.update(
            dict.fromkeys(
                (beta_chat.AsyncCompletions, beta_chat.Completions), "chat.completions"
            )
        )
    except ImportError:
        # The beta chat-completions module is not importable in every ``openai`` release.
        return task_by_obj
    return task_by_obj
|
117
|
+
|
118
|
+
|
119
|
+
def _get_model_name(model):
    """Return the OpenAI model name for ``model``.

    Args:
        model: A model name string, or (only with ``openai<1``) an ``openai.Model``
            instance, whose ``id`` is returned.

    Raises:
        MlflowException: If ``model`` is of any other type.
    """
    import openai

    if isinstance(model, str):
        return model

    # ``openai.Model`` is only consulted on the legacy (<1) SDK; the version check
    # short-circuits before the attribute is touched on newer releases.
    legacy_sdk = Version(_get_openai_package_version()).major < 1
    if legacy_sdk and isinstance(model, openai.Model):
        return model.id

    raise mlflow.MlflowException(
        f"Unsupported model type: {type(model)}", error_code=INVALID_PARAMETER_VALUE
    )
|
131
|
+
|
132
|
+
|
133
|
+
def _get_task_name(task):
    """Resolve ``task`` to one of the task names known to this flavor.

    Args:
        task: Either a task name string (e.g. ``"chat.completions"``), or an OpenAI SDK
            object identifying the task. The object may be a resource class, a resource
            instance such as ``openai.chat.completions``, or a bound method such as
            ``openai.images.edit``.

    Returns:
        The task name string.

    Raises:
        MlflowException: If the string or object does not correspond to a known task.
    """
    mapping = _get_obj_to_task_mapping()
    if isinstance(task, str):
        if task not in mapping.values():
            raise mlflow.MlflowException(
                f"Unsupported task: {task}", error_code=INVALID_PARAMETER_VALUE
            )
        return task

    def _lookup(key):
        try:
            return mapping.get(key)
        except TypeError:
            # Unhashable objects (e.g. a list) can never be mapping keys.
            return None

    # Try the object itself, then its class (resource instances), then the underlying
    # function (bound methods). ``getattr`` needs a default so that objects without
    # ``__func__`` end in the MlflowException below rather than an AttributeError.
    task_name = (
        _lookup(task)
        or _lookup(task.__class__)
        or _lookup(getattr(task, "__func__", None))
    )
    if task_name is None:
        raise mlflow.MlflowException(
            f"Unsupported task object: {task}", error_code=INVALID_PARAMETER_VALUE
        )
    return task_name
|
152
|
+
|
153
|
+
|
154
|
+
def _get_api_config() -> _OpenAIApiConfig:
    """Gets the parameters and configuration of the OpenAI API connected to.

    Settings are read from the environment variables named by ``_OpenAIEnvVar``.
    - The API type and version fall back to ``openai.api_type`` / ``openai.api_version``.
    - The base URL comes from ``_OpenAIEnvVar.OPENAI_API_BASE``, or from
      ``_OpenAIEnvVar.OPENAI_BASE_URL`` when the former is unset or empty.
    - Azure API types get a smaller batch size and tokens-per-minute budget.
    """
    import openai

    def read_env(var, default=None):
        return os.getenv(var.value, default)

    api_type = read_env(_OpenAIEnvVar.OPENAI_API_TYPE, openai.api_type)
    api_version = read_env(_OpenAIEnvVar.OPENAI_API_VERSION, openai.api_version)
    api_base = read_env(_OpenAIEnvVar.OPENAI_API_BASE) or read_env(
        _OpenAIEnvVar.OPENAI_BASE_URL
    )
    deployment_id = read_env(_OpenAIEnvVar.OPENAI_DEPLOYMENT_NAME)
    organization = read_env(_OpenAIEnvVar.OPENAI_ORGANIZATION)

    is_azure = api_type in ("azure", "azure_ad", "azuread")
    # Outside Azure the maximum batch size is 2048:
    # https://github.com/openai/openai-python/blob/b82a3f7e4c462a8a10fa445193301a3cefef9a4a/openai/embeddings_utils.py#L43
    # A smaller batch size is used to be safe.
    batch_size, max_tokens_per_minute = (16, 60_000) if is_azure else (1024, 90_000)

    return _OpenAIApiConfig(
        api_type=api_type,
        batch_size=batch_size,
        max_requests_per_minute=3_500,
        max_tokens_per_minute=max_tokens_per_minute,
        api_base=api_base,
        api_version=api_version,
        deployment_id=deployment_id,
        organization=organization,
    )
|
184
|
+
|
185
|
+
|
186
|
+
def _get_openai_package_version():
|
187
|
+
return importlib.metadata.version("openai")
|
188
|
+
|
189
|
+
|
190
|
+
def _log_secrets_yaml(local_model_dir, scope):
    """Write ``openai.yaml`` into ``local_model_dir``.

    The file maps each ``_OpenAIEnvVar`` value to a ``"<scope>:<secret_key>"`` secret
    reference.
    """
    secrets_file = os.path.join(local_model_dir, "openai.yaml")
    with open(secrets_file, "w") as stream:
        secret_refs = {
            env_var.value: f"{scope}:{env_var.secret_key}" for env_var in _OpenAIEnvVar
        }
        yaml.safe_dump(secret_refs, stream)
|
193
|
+
|
194
|
+
|
195
|
+
def _parse_format_fields(s) -> set[str]:
|
196
|
+
"""Parses format fields from a given string, e.g. "Hello {name}" -> ["name"]."""
|
197
|
+
return {fn for _, fn, _, _ in Formatter().parse(s) if fn is not None}
|
198
|
+
|
199
|
+
|
200
|
+
def _get_input_schema(task, content):
    """Build the input schema for a templated ``completions``/``chat.completions`` model.

    When the template ``content`` has more than one variable, the schema has one named
    string column per variable. Otherwise (no template, or zero or one variables) it
    is a single unnamed string column.
    """
    variables = _ContentFormatter(task, content).variables if content else []
    if len(variables) > 1:
        return Schema([ColSpec(name=v, type="string") for v in variables])
    return Schema([ColSpec(type="string")])
|
212
|
+
|
213
|
+
|
214
|
+
@experimental(version="2.3.0")
@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name=FLAVOR_NAME))
def save_model(
    model,
    task,
    path,
    conda_env=None,
    code_paths=None,
    mlflow_model=None,
    signature: ModelSignature = None,
    input_example: ModelInputExample = None,
    pip_requirements=None,
    extra_pip_requirements=None,
    metadata=None,
    **kwargs,
):
    """
    Save an OpenAI model to a path on the local file system.

    Args:
        model: The OpenAI model name as a string, e.g. ``"gpt-4o-mini"``. Only string
            model names are accepted with ``openai>=1.0``, which this function requires.
        task: The task the model is performing, e.g., ``openai.chat.completions`` or
            ``'chat.completions'``.
        path: Local path where the model is to be saved.
        conda_env: {{ conda_env }}
        code_paths: {{ code_paths }}
        mlflow_model: :py:mod:`mlflow.models.Model` this flavor is being added to.
        signature: :py:class:`ModelSignature <mlflow.models.ModelSignature>`
            describes model input and output :py:class:`Schema <mlflow.types.Schema>`.
            The model signature can be :py:func:`inferred <mlflow.models.infer_signature>`
            from datasets with valid model input (e.g. the training dataset with target
            column omitted) and valid model output (e.g. model predictions generated on
            the training dataset), for example:

            .. code-block:: python

                from mlflow.models import infer_signature

                train = df.drop_column("target_label")
                predictions = ...  # compute model predictions
                signature = infer_signature(train, predictions)
        input_example: {{ input_example }}
        pip_requirements: {{ pip_requirements }}
        extra_pip_requirements: {{ extra_pip_requirements }}
        metadata: {{ metadata }}
        kwargs: Keyword arguments specific to the OpenAI task, such as the ``messages`` (see
            :ref:`mlflow.openai.messages` for more details on this parameter)
            or ``top_p`` value to use for chat completion.

    Raises:
        MlflowException: Raised in any of these cases:

            - ``openai<1.0`` is installed.
            - ``task`` or ``model`` is not supported.
            - No ``signature`` is given for ``chat.completions`` and ``messages`` is
              not a list of dictionaries with ``role`` and ``content`` keys.

    .. code-block:: python

        import mlflow
        import openai

        # Chat
        mlflow.openai.save_model(
            model="gpt-4o-mini",
            task=openai.chat.completions,
            messages=[{"role": "user", "content": "Tell me a joke."}],
            path="model",
        )

        # Completions
        mlflow.openai.save_model(
            model="text-davinci-002",
            task=openai.completions,
            prompt="{text}. The general sentiment of the text is",
            path="model",
        )

        # Embeddings
        mlflow.openai.save_model(
            model="text-embedding-ada-002",
            task=openai.embeddings,
            path="model",
        )
    """
    if Version(_get_openai_package_version()).major < 1:
        raise MlflowException("Only openai>=1.0 is supported.")

    import numpy as np

    _validate_env_arguments(conda_env, pip_requirements, extra_pip_requirements)
    # Resolve (and validate) the task and the model name before anything touches the
    # filesystem, so unsupported arguments do not leave a half-written model directory.
    task = _get_task_name(task)
    model_name = _get_model_name(model)

    path = os.path.abspath(path)
    _validate_and_prepare_target_save_path(path)
    code_dir_subpath = _validate_and_copy_code_paths(code_paths, path)

    if mlflow_model is None:
        mlflow_model = Model()

    if signature is not None:
        if signature.params:
            _validate_model_params(
                task, kwargs, {p.name: p.default for p in signature.params.params}
            )
    elif task == "chat.completions":
        messages = kwargs.get("messages", [])
        if messages and not (
            all(isinstance(m, dict) for m in messages) and all(map(_is_valid_message, messages))
        ):
            raise mlflow.MlflowException.invalid_parameter_value(
                "If `messages` is provided, it must be a list of dictionaries with keys "
                "'role' and 'content'."
            )

        signature = ModelSignature(
            inputs=_get_input_schema(task, messages),
            outputs=Schema([ColSpec(type="string", name=None)]),
        )
    elif task == "completions":
        prompt = kwargs.get("prompt")
        signature = ModelSignature(
            inputs=_get_input_schema(task, prompt),
            outputs=Schema([ColSpec(type="string", name=None)]),
        )
    elif task == "embeddings":
        signature = ModelSignature(
            inputs=Schema([ColSpec(type="string", name=None)]),
            outputs=Schema([TensorSpec(type=np.dtype("float64"), shape=(-1,))]),
        )

    # The model configuration written to MODEL_FILENAME below; it is also the shape of
    # dict that ``_OpenAIWrapper`` is constructed from at load time.
    model_config = {
        "model": model_name,
        "task": task,
        **kwargs,
    }

    saved_example = _save_example(mlflow_model, input_example, path)
    if signature is None and saved_example is not None and task in _PYFUNC_SUPPORTED_TASKS:
        # ``_OpenAIWrapper`` needs the configuration dict (it pops "task"), not the raw
        # model name, and only serves the pyfunc tasks. Pass a copy so ``model_config``
        # keeps its "task" key for the YAML file written below.
        wrapped_model = _OpenAIWrapper(dict(model_config))
        signature = _infer_signature_from_input_example(saved_example, wrapped_model)

    if signature is not None:
        mlflow_model.signature = signature

    if metadata is not None:
        mlflow_model.metadata = metadata
    model_data_path = os.path.join(path, MODEL_FILENAME)
    with open(model_data_path, "w") as f:
        yaml.safe_dump(model_config, f)

    if task in _PYFUNC_SUPPORTED_TASKS:
        pyfunc.add_to_model(
            mlflow_model,
            loader_module="mlflow.openai",
            data=MODEL_FILENAME,
            conda_env=_CONDA_ENV_FILE_NAME,
            python_env=_PYTHON_ENV_FILE_NAME,
            code=code_dir_subpath,
        )
    mlflow_model.add_flavor(
        FLAVOR_NAME,
        openai_version=_get_openai_package_version(),
        data=MODEL_FILENAME,
        code=code_dir_subpath,
    )
    mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME))

    if is_in_databricks_runtime():
        if scope := MLFLOW_OPENAI_SECRET_SCOPE.get():
            url = "https://docs.databricks.com/en/machine-learning/model-serving/store-env-variable-model-serving.html"
            warnings.warn(
                "Specifying secrets for model serving with `MLFLOW_OPENAI_SECRET_SCOPE` is "
                f"deprecated. Use secrets-based environment variables ({url}) instead.",
                FutureWarning,
            )
            check_databricks_secret_scope_access(scope)
            _log_secrets_yaml(path, scope)

    if conda_env is None:
        if pip_requirements is None:
            default_reqs = get_default_pip_requirements()
            inferred_reqs = mlflow.models.infer_pip_requirements(
                path, FLAVOR_NAME, fallback=default_reqs
            )
            default_reqs = sorted(set(inferred_reqs).union(default_reqs))
        else:
            default_reqs = None
        conda_env, pip_requirements, pip_constraints = _process_pip_requirements(
            default_reqs,
            pip_requirements,
            extra_pip_requirements,
        )
    else:
        conda_env, pip_requirements, pip_constraints = _process_conda_env(conda_env)

    with open(os.path.join(path, _CONDA_ENV_FILE_NAME), "w") as f:
        yaml.safe_dump(conda_env, stream=f, default_flow_style=False)

    # Save `constraints.txt` if necessary
    if pip_constraints:
        write_to(os.path.join(path, _CONSTRAINTS_FILE_NAME), "\n".join(pip_constraints))

    # Save `requirements.txt`
    write_to(os.path.join(path, _REQUIREMENTS_FILE_NAME), "\n".join(pip_requirements))

    _PythonEnv.current().to_yaml(os.path.join(path, _PYTHON_ENV_FILE_NAME))
|
411
|
+
|
412
|
+
|
413
|
+
@experimental(version="2.3.0")
@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name=FLAVOR_NAME))
def log_model(
    model,
    task,
    artifact_path: Optional[str] = None,
    conda_env=None,
    code_paths=None,
    registered_model_name=None,
    signature: ModelSignature = None,
    input_example: ModelInputExample = None,
    await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
    pip_requirements=None,
    extra_pip_requirements=None,
    metadata=None,
    prompts: Optional[list[Union[str, Prompt]]] = None,
    name: Optional[str] = None,
    params: Optional[dict[str, Any]] = None,
    tags: Optional[dict[str, Any]] = None,
    model_type: Optional[str] = None,
    step: int = 0,
    model_id: Optional[str] = None,
    **kwargs,
):
    """
    Log an OpenAI model as an MLflow artifact for the current run.

    Args:
        model: The OpenAI model name as a string, e.g., ``"gpt-4o-mini"``. Saving
            requires ``openai>=1.0``, with which only string model names are accepted.
        task: The task the model is performing, e.g., ``openai.chat.completions`` or
            ``'chat.completions'``.
        artifact_path: Deprecated. Use `name` instead.
        conda_env: {{ conda_env }}
        code_paths: {{ code_paths }}
        registered_model_name: If given, create a model version under
            ``registered_model_name``, also creating a registered model if one
            with the given name does not exist.
        signature: :py:class:`ModelSignature <mlflow.models.ModelSignature>`
            describes model input and output :py:class:`Schema <mlflow.types.Schema>`.
            The model signature can be :py:func:`inferred <mlflow.models.infer_signature>`
            from datasets with valid model input (e.g. the training dataset with target
            column omitted) and valid model output (e.g. model predictions generated on
            the training dataset), for example:

            .. code-block:: python

                from mlflow.models import infer_signature

                train = df.drop_column("target_label")
                predictions = ...  # compute model predictions
                signature = infer_signature(train, predictions)

        input_example: {{ input_example }}
        await_registration_for: Number of seconds to wait for the model version to finish
            being created and is in ``READY`` status. By default, the function
            waits for five minutes. Specify 0 or None to skip waiting.
        pip_requirements: {{ pip_requirements }}
        extra_pip_requirements: {{ extra_pip_requirements }}
        metadata: {{ metadata }}
        prompts: {{ prompts }}
        name: {{ name }}
        params: {{ params }}
        tags: {{ tags }}
        model_type: {{ model_type }}
        step: {{ step }}
        model_id: {{ model_id }}
        kwargs: Keyword arguments specific to the OpenAI task, such as the ``messages`` (see
            :ref:`mlflow.openai.messages` for more details on this parameter)
            or ``top_p`` value to use for chat completion.

    Returns:
        A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
        metadata of the logged model.

    .. code-block:: python
        :caption: Example

        import mlflow
        import openai
        import pandas as pd

        # Chat
        with mlflow.start_run():
            info = mlflow.openai.log_model(
                model="gpt-4o-mini",
                task=openai.chat.completions,
                messages=[{"role": "user", "content": "Tell me a joke about {animal}."}],
                name="model",
            )
            model = mlflow.pyfunc.load_model(info.model_uri)
            df = pd.DataFrame({"animal": ["cats", "dogs"]})
            print(model.predict(df))

        # Embeddings
        with mlflow.start_run():
            info = mlflow.openai.log_model(
                model="text-embedding-ada-002",
                task=openai.embeddings,
                name="embeddings",
            )
            model = mlflow.pyfunc.load_model(info.model_uri)
            print(model.predict(["hello", "world"]))
    """
    return Model.log(
        artifact_path=artifact_path,
        name=name,
        flavor=mlflow.openai,
        registered_model_name=registered_model_name,
        model=model,
        task=task,
        conda_env=conda_env,
        code_paths=code_paths,
        signature=signature,
        input_example=input_example,
        await_registration_for=await_registration_for,
        pip_requirements=pip_requirements,
        extra_pip_requirements=extra_pip_requirements,
        metadata=metadata,
        prompts=prompts,
        params=params,
        tags=tags,
        model_type=model_type,
        step=step,
        model_id=model_id,
        **kwargs,
    )
|
540
|
+
|
541
|
+
|
542
|
+
def _load_model(path):
    """Load and return the YAML model configuration stored at ``path``.

    If the directory containing ``path`` exists, the MLflow ``Model`` in that directory
    is loaded first, and the active model ID is updated from it.
    """
    model_dir = os.path.dirname(path)
    if os.path.exists(model_dir):
        _update_active_model_id_based_on_mlflow_model(Model.load(model_dir))
    with open(path) as config_stream:
        return yaml.safe_load(config_stream)
|
549
|
+
|
550
|
+
|
551
|
+
def _is_valid_message(d):
|
552
|
+
return isinstance(d, dict) and "content" in d and "role" in d
|
553
|
+
|
554
|
+
|
555
|
+
class _ContentFormatter:
    """Renders a task template with per-row variable values.

    For ``completions`` the template is a format string (default ``"{prompt}"``). For
    ``chat.completions`` it is a list of message dicts whose ``role`` and ``content``
    are format strings (default: a single user message with content ``"{content}"``).

    Attributes:
        template: The prompt string, or a shallow copy of the message list.
        variables: Sorted list of the format-field names found in the template.
        format_fn: The task-specific rendering method called by :meth:`format`.
    """

    def __init__(self, task, template=None):
        if task not in ("completions", "chat.completions"):
            raise mlflow.MlflowException.invalid_parameter_value(
                f"Task type ``{task}`` is not supported for formatting."
            )

        if task == "completions":
            prompt = template or "{prompt}"
            if not isinstance(prompt, str):
                raise mlflow.MlflowException.invalid_parameter_value(
                    f"Template for task {task} expects type `str`, but got {type(prompt)}."
                )
            self.template = prompt
            self.variables = sorted(_parse_format_fields(prompt))
            self.format_fn = self.format_prompt
            return

        messages = template or [{"role": "user", "content": "{content}"}]
        if not all(map(_is_valid_message, messages)):
            raise mlflow.MlflowException.invalid_parameter_value(
                f"Template for task {task} expects type `dict` with keys 'content' "
                f"and 'role', but got {type(messages)}."
            )

        self.template = messages.copy()
        self.format_fn = self.format_chat
        field_names = set()
        for message in self.template:
            field_names |= _parse_format_fields(message.get("content"))
            field_names |= _parse_format_fields(message.get("role"))
        self.variables = sorted(field_names)
        if not self.variables:
            # Without any placeholder the input rows would be ignored, so append a user
            # message that receives each row's value as ``content``.
            self.template.append({"role": "user", "content": "{content}"})
            self.variables.append("content")

    def format(self, **params):
        """Render the template; every name in ``self.variables`` must be in ``params``."""
        missing_params = set(self.variables) - set(params)
        if missing_params:
            raise mlflow.MlflowException.invalid_parameter_value(
                f"Expected parameters {self.variables} to be provided, "
                f"only got {list(params)}, {list(missing_params)} are missing."
            )
        return self.format_fn(**params)

    def format_prompt(self, **params):
        """Fill the prompt string using only the template's own variables."""
        values = {name: params[name] for name in self.variables}
        return self.template.format(**values)

    def format_chat(self, **params):
        """Return a new message list with every ``role`` and ``content`` filled in."""
        values = {name: params[name] for name in self.variables}
        rendered = []
        for message in self.template:
            rendered.append(
                {
                    "role": message.get("role").format(**values),
                    "content": message.get("content").format(**values),
                }
            )
        return rendered
|
615
|
+
|
616
|
+
|
617
|
+
def _first_string_column(pdf):
|
618
|
+
iter_str_cols = (c for c, v in pdf.iloc[0].items() if isinstance(v, str))
|
619
|
+
col = next(iter_str_cols, None)
|
620
|
+
if col is None:
|
621
|
+
raise mlflow.MlflowException.invalid_parameter_value(
|
622
|
+
f"Could not find a string column in the input data: {pdf.dtypes.to_dict()}"
|
623
|
+
)
|
624
|
+
return col
|
625
|
+
|
626
|
+
|
627
|
+
class _OpenAIWrapper:
|
628
|
+
def __init__(self, model):
|
629
|
+
task = model.pop("task")
|
630
|
+
if task not in _PYFUNC_SUPPORTED_TASKS:
|
631
|
+
raise mlflow.MlflowException.invalid_parameter_value(
|
632
|
+
f"Unsupported task: {task}. Supported tasks: {_PYFUNC_SUPPORTED_TASKS}."
|
633
|
+
)
|
634
|
+
self.model = model
|
635
|
+
self.task = task
|
636
|
+
self.api_config = _get_api_config()
|
637
|
+
self.api_token = _OAITokenHolder(self.api_config.api_type)
|
638
|
+
|
639
|
+
if self.task != "embeddings":
|
640
|
+
self._setup_completions()
|
641
|
+
|
642
|
+
def get_raw_model(self):
|
643
|
+
"""
|
644
|
+
Returns the underlying model.
|
645
|
+
"""
|
646
|
+
return self.model
|
647
|
+
|
648
|
+
def _setup_completions(self):
|
649
|
+
if self.task == "chat.completions":
|
650
|
+
self.template = self.model.get("messages", [])
|
651
|
+
else:
|
652
|
+
self.template = self.model.get("prompt")
|
653
|
+
self.formatter = _ContentFormatter(self.task, self.template)
|
654
|
+
|
655
|
+
def format_completions(self, params_list):
|
656
|
+
return [self.formatter.format(**params) for params in params_list]
|
657
|
+
|
658
|
+
def get_params_list(self, data):
|
659
|
+
if len(self.formatter.variables) == 1:
|
660
|
+
variable = self.formatter.variables[0]
|
661
|
+
if variable in data.columns:
|
662
|
+
return data[[variable]].to_dict(orient="records")
|
663
|
+
else:
|
664
|
+
first_string_column = _first_string_column(data)
|
665
|
+
return [{variable: s} for s in data[first_string_column]]
|
666
|
+
else:
|
667
|
+
return data[self.formatter.variables].to_dict(orient="records")
|
668
|
+
|
669
|
+
def get_client(self, max_retries: int, timeout: float):
|
670
|
+
# with_option method should not be used before v1.3.8: https://github.com/openai/openai-python/issues/865
|
671
|
+
if self.api_config.api_type in ("azure", "azure_ad", "azuread"):
|
672
|
+
from openai import AzureOpenAI
|
673
|
+
|
674
|
+
return AzureOpenAI(
|
675
|
+
api_key=self.api_token.token,
|
676
|
+
azure_endpoint=self.api_config.api_base,
|
677
|
+
api_version=self.api_config.api_version,
|
678
|
+
azure_deployment=self.api_config.deployment_id,
|
679
|
+
max_retries=max_retries,
|
680
|
+
timeout=timeout,
|
681
|
+
)
|
682
|
+
else:
|
683
|
+
from openai import OpenAI
|
684
|
+
|
685
|
+
return OpenAI(
|
686
|
+
api_key=self.api_token.token,
|
687
|
+
base_url=self.api_config.api_base,
|
688
|
+
max_retries=max_retries,
|
689
|
+
timeout=timeout,
|
690
|
+
)
|
691
|
+
|
692
|
+
def _predict_chat(self, data, params):
|
693
|
+
from mlflow.openai.api_request_parallel_processor import process_api_requests
|
694
|
+
|
695
|
+
_validate_model_params(self.task, self.model, params)
|
696
|
+
max_retries = params.pop("max_retries", self.api_config.max_retries)
|
697
|
+
timeout = params.pop("timeout", self.api_config.timeout)
|
698
|
+
|
699
|
+
messages_list = self.format_completions(self.get_params_list(data))
|
700
|
+
client = self.get_client(max_retries=max_retries, timeout=timeout)
|
701
|
+
|
702
|
+
requests = [
|
703
|
+
partial(
|
704
|
+
client.chat.completions.create,
|
705
|
+
messages=messages,
|
706
|
+
model=self.model["model"],
|
707
|
+
**params,
|
708
|
+
)
|
709
|
+
for messages in messages_list
|
710
|
+
]
|
711
|
+
|
712
|
+
results = process_api_requests(request_tasks=requests)
|
713
|
+
|
714
|
+
return [r.choices[0].message.content for r in results]
|
715
|
+
|
716
|
+
def _predict_completions(self, data, params):
|
717
|
+
from mlflow.openai.api_request_parallel_processor import process_api_requests
|
718
|
+
|
719
|
+
_validate_model_params(self.task, self.model, params)
|
720
|
+
prompts_list = self.format_completions(self.get_params_list(data))
|
721
|
+
max_retries = params.pop("max_retries", self.api_config.max_retries)
|
722
|
+
timeout = params.pop("timeout", self.api_config.timeout)
|
723
|
+
batch_size = params.pop("batch_size", self.api_config.batch_size)
|
724
|
+
_logger.debug(f"Requests are being batched by {batch_size} samples.")
|
725
|
+
|
726
|
+
client = self.get_client(max_retries=max_retries, timeout=timeout)
|
727
|
+
|
728
|
+
requests = [
|
729
|
+
partial(
|
730
|
+
client.completions.create,
|
731
|
+
prompt=prompts_list[i : i + batch_size],
|
732
|
+
model=self.model["model"],
|
733
|
+
**params,
|
734
|
+
)
|
735
|
+
for i in range(0, len(prompts_list), batch_size)
|
736
|
+
]
|
737
|
+
|
738
|
+
results = process_api_requests(request_tasks=requests)
|
739
|
+
|
740
|
+
return [row.text for batch in results for row in batch.choices]
|
741
|
+
|
742
|
+
def _predict_embeddings(self, data, params):
|
743
|
+
from mlflow.openai.api_request_parallel_processor import process_api_requests
|
744
|
+
|
745
|
+
_validate_model_params(self.task, self.model, params)
|
746
|
+
max_retries = params.pop("max_retries", self.api_config.max_retries)
|
747
|
+
timeout = params.pop("timeout", self.api_config.timeout)
|
748
|
+
batch_size = params.pop("batch_size", self.api_config.batch_size)
|
749
|
+
_logger.debug(f"Requests are being batched by {batch_size} samples.")
|
750
|
+
|
751
|
+
first_string_column = _first_string_column(data)
|
752
|
+
texts = data[first_string_column].tolist()
|
753
|
+
|
754
|
+
client = self.get_client(max_retries=max_retries, timeout=timeout)
|
755
|
+
|
756
|
+
requests = [
|
757
|
+
partial(
|
758
|
+
client.embeddings.create,
|
759
|
+
input=texts[i : i + batch_size],
|
760
|
+
model=self.model["model"],
|
761
|
+
**params,
|
762
|
+
)
|
763
|
+
for i in range(0, len(texts), batch_size)
|
764
|
+
]
|
765
|
+
|
766
|
+
results = process_api_requests(request_tasks=requests)
|
767
|
+
|
768
|
+
return [row.embedding for batch in results for row in batch.data]
|
769
|
+
|
770
|
+
def predict(self, data, params: Optional[dict[str, Any]] = None):
|
771
|
+
"""
|
772
|
+
Args:
|
773
|
+
data: Model input data.
|
774
|
+
params: Additional parameters to pass to the model for inference.
|
775
|
+
|
776
|
+
Returns:
|
777
|
+
Model predictions.
|
778
|
+
"""
|
779
|
+
self.api_token.refresh()
|
780
|
+
if self.task == "chat.completions":
|
781
|
+
return self._predict_chat(data, params or {})
|
782
|
+
elif self.task == "completions":
|
783
|
+
return self._predict_completions(data, params or {})
|
784
|
+
elif self.task == "embeddings":
|
785
|
+
return self._predict_embeddings(data, params or {})
|
786
|
+
|
787
|
+
|
788
|
+
def _load_pyfunc(path):
    """Loads PyFunc implementation. Called by ``pyfunc.load_model``.

    Args:
        path: Local filesystem path to the ``openai`` flavor's model configuration
            file. It is opened directly as YAML by ``_load_model``.

    Returns:
        An ``_OpenAIWrapper`` built from the loaded configuration.
    """
    model_config = _load_model(path)
    return _OpenAIWrapper(model_config)
|
795
|
+
|
796
|
+
|
797
|
+
@experimental(version="2.3.0")
def load_model(model_uri, dst_path=None):
    """
    Load an OpenAI model from a local file or a run.

    Args:
        model_uri: The location, in URI format, of the MLflow model. For example:

            - ``/Users/me/path/to/local/model``
            - ``relative/path/to/local/model``
            - ``s3://my_bucket/path/to/model``
            - ``runs:/<mlflow_run_id>/run-relative/path/to/model``

            For more information about supported URI schemes, see
            `Referencing Artifacts <https://www.mlflow.org/docs/latest/tracking.html#
            artifact-locations>`_.
        dst_path: The local filesystem path to which to download the model artifact.
            This directory must already exist. If unspecified, a local output
            path will be created.

    Returns:
        A dictionary representing the OpenAI model: the configuration saved by
        :func:`save_model`, read from the flavor's data file.
    """
    local_model_path = _download_artifact_from_uri(artifact_uri=model_uri, output_path=dst_path)
    flavor_conf = _get_flavor_configuration(local_model_path, FLAVOR_NAME)
    _add_code_from_conf_to_system_path(local_model_path, flavor_conf)
    data_file_name = flavor_conf.get("data", MODEL_FILENAME)
    return _load_model(os.path.join(local_model_path, data_file_name))
|