genesis-flow 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genesis_flow-1.0.0.dist-info/METADATA +822 -0
- genesis_flow-1.0.0.dist-info/RECORD +645 -0
- genesis_flow-1.0.0.dist-info/WHEEL +5 -0
- genesis_flow-1.0.0.dist-info/entry_points.txt +19 -0
- genesis_flow-1.0.0.dist-info/licenses/LICENSE.txt +202 -0
- genesis_flow-1.0.0.dist-info/top_level.txt +1 -0
- mlflow/__init__.py +367 -0
- mlflow/__main__.py +3 -0
- mlflow/ag2/__init__.py +56 -0
- mlflow/ag2/ag2_logger.py +294 -0
- mlflow/anthropic/__init__.py +40 -0
- mlflow/anthropic/autolog.py +129 -0
- mlflow/anthropic/chat.py +144 -0
- mlflow/artifacts/__init__.py +268 -0
- mlflow/autogen/__init__.py +144 -0
- mlflow/autogen/chat.py +142 -0
- mlflow/azure/__init__.py +26 -0
- mlflow/azure/auth_handler.py +257 -0
- mlflow/azure/client.py +319 -0
- mlflow/azure/config.py +120 -0
- mlflow/azure/connection_factory.py +340 -0
- mlflow/azure/exceptions.py +27 -0
- mlflow/azure/stores.py +327 -0
- mlflow/azure/utils.py +183 -0
- mlflow/bedrock/__init__.py +45 -0
- mlflow/bedrock/_autolog.py +202 -0
- mlflow/bedrock/chat.py +122 -0
- mlflow/bedrock/stream.py +160 -0
- mlflow/bedrock/utils.py +43 -0
- mlflow/cli.py +707 -0
- mlflow/client.py +12 -0
- mlflow/config/__init__.py +56 -0
- mlflow/crewai/__init__.py +79 -0
- mlflow/crewai/autolog.py +253 -0
- mlflow/crewai/chat.py +29 -0
- mlflow/data/__init__.py +75 -0
- mlflow/data/artifact_dataset_sources.py +170 -0
- mlflow/data/code_dataset_source.py +40 -0
- mlflow/data/dataset.py +123 -0
- mlflow/data/dataset_registry.py +168 -0
- mlflow/data/dataset_source.py +110 -0
- mlflow/data/dataset_source_registry.py +219 -0
- mlflow/data/delta_dataset_source.py +167 -0
- mlflow/data/digest_utils.py +108 -0
- mlflow/data/evaluation_dataset.py +562 -0
- mlflow/data/filesystem_dataset_source.py +81 -0
- mlflow/data/http_dataset_source.py +145 -0
- mlflow/data/huggingface_dataset.py +258 -0
- mlflow/data/huggingface_dataset_source.py +118 -0
- mlflow/data/meta_dataset.py +104 -0
- mlflow/data/numpy_dataset.py +223 -0
- mlflow/data/pandas_dataset.py +231 -0
- mlflow/data/polars_dataset.py +352 -0
- mlflow/data/pyfunc_dataset_mixin.py +31 -0
- mlflow/data/schema.py +76 -0
- mlflow/data/sources.py +1 -0
- mlflow/data/spark_dataset.py +406 -0
- mlflow/data/spark_dataset_source.py +74 -0
- mlflow/data/spark_delta_utils.py +118 -0
- mlflow/data/tensorflow_dataset.py +350 -0
- mlflow/data/uc_volume_dataset_source.py +81 -0
- mlflow/db.py +27 -0
- mlflow/dspy/__init__.py +17 -0
- mlflow/dspy/autolog.py +197 -0
- mlflow/dspy/callback.py +398 -0
- mlflow/dspy/constant.py +1 -0
- mlflow/dspy/load.py +93 -0
- mlflow/dspy/save.py +393 -0
- mlflow/dspy/util.py +109 -0
- mlflow/dspy/wrapper.py +226 -0
- mlflow/entities/__init__.py +104 -0
- mlflow/entities/_mlflow_object.py +52 -0
- mlflow/entities/assessment.py +545 -0
- mlflow/entities/assessment_error.py +80 -0
- mlflow/entities/assessment_source.py +141 -0
- mlflow/entities/dataset.py +92 -0
- mlflow/entities/dataset_input.py +51 -0
- mlflow/entities/dataset_summary.py +62 -0
- mlflow/entities/document.py +48 -0
- mlflow/entities/experiment.py +109 -0
- mlflow/entities/experiment_tag.py +35 -0
- mlflow/entities/file_info.py +45 -0
- mlflow/entities/input_tag.py +35 -0
- mlflow/entities/lifecycle_stage.py +35 -0
- mlflow/entities/logged_model.py +228 -0
- mlflow/entities/logged_model_input.py +26 -0
- mlflow/entities/logged_model_output.py +32 -0
- mlflow/entities/logged_model_parameter.py +46 -0
- mlflow/entities/logged_model_status.py +74 -0
- mlflow/entities/logged_model_tag.py +33 -0
- mlflow/entities/metric.py +200 -0
- mlflow/entities/model_registry/__init__.py +29 -0
- mlflow/entities/model_registry/_model_registry_entity.py +13 -0
- mlflow/entities/model_registry/model_version.py +243 -0
- mlflow/entities/model_registry/model_version_deployment_job_run_state.py +44 -0
- mlflow/entities/model_registry/model_version_deployment_job_state.py +70 -0
- mlflow/entities/model_registry/model_version_search.py +25 -0
- mlflow/entities/model_registry/model_version_stages.py +25 -0
- mlflow/entities/model_registry/model_version_status.py +35 -0
- mlflow/entities/model_registry/model_version_tag.py +35 -0
- mlflow/entities/model_registry/prompt.py +73 -0
- mlflow/entities/model_registry/prompt_version.py +244 -0
- mlflow/entities/model_registry/registered_model.py +175 -0
- mlflow/entities/model_registry/registered_model_alias.py +35 -0
- mlflow/entities/model_registry/registered_model_deployment_job_state.py +39 -0
- mlflow/entities/model_registry/registered_model_search.py +25 -0
- mlflow/entities/model_registry/registered_model_tag.py +35 -0
- mlflow/entities/multipart_upload.py +74 -0
- mlflow/entities/param.py +49 -0
- mlflow/entities/run.py +97 -0
- mlflow/entities/run_data.py +84 -0
- mlflow/entities/run_info.py +188 -0
- mlflow/entities/run_inputs.py +59 -0
- mlflow/entities/run_outputs.py +43 -0
- mlflow/entities/run_status.py +41 -0
- mlflow/entities/run_tag.py +36 -0
- mlflow/entities/source_type.py +31 -0
- mlflow/entities/span.py +774 -0
- mlflow/entities/span_event.py +96 -0
- mlflow/entities/span_status.py +102 -0
- mlflow/entities/trace.py +317 -0
- mlflow/entities/trace_data.py +71 -0
- mlflow/entities/trace_info.py +220 -0
- mlflow/entities/trace_info_v2.py +162 -0
- mlflow/entities/trace_location.py +173 -0
- mlflow/entities/trace_state.py +39 -0
- mlflow/entities/trace_status.py +68 -0
- mlflow/entities/view_type.py +51 -0
- mlflow/environment_variables.py +866 -0
- mlflow/evaluation/__init__.py +16 -0
- mlflow/evaluation/assessment.py +369 -0
- mlflow/evaluation/evaluation.py +411 -0
- mlflow/evaluation/evaluation_tag.py +61 -0
- mlflow/evaluation/fluent.py +48 -0
- mlflow/evaluation/utils.py +201 -0
- mlflow/exceptions.py +213 -0
- mlflow/experiments.py +140 -0
- mlflow/gemini/__init__.py +81 -0
- mlflow/gemini/autolog.py +186 -0
- mlflow/gemini/chat.py +261 -0
- mlflow/genai/__init__.py +71 -0
- mlflow/genai/datasets/__init__.py +67 -0
- mlflow/genai/datasets/evaluation_dataset.py +131 -0
- mlflow/genai/evaluation/__init__.py +3 -0
- mlflow/genai/evaluation/base.py +411 -0
- mlflow/genai/evaluation/constant.py +23 -0
- mlflow/genai/evaluation/utils.py +244 -0
- mlflow/genai/judges/__init__.py +21 -0
- mlflow/genai/judges/databricks.py +404 -0
- mlflow/genai/label_schemas/__init__.py +153 -0
- mlflow/genai/label_schemas/label_schemas.py +209 -0
- mlflow/genai/labeling/__init__.py +159 -0
- mlflow/genai/labeling/labeling.py +250 -0
- mlflow/genai/optimize/__init__.py +13 -0
- mlflow/genai/optimize/base.py +198 -0
- mlflow/genai/optimize/optimizers/__init__.py +4 -0
- mlflow/genai/optimize/optimizers/base_optimizer.py +38 -0
- mlflow/genai/optimize/optimizers/dspy_mipro_optimizer.py +221 -0
- mlflow/genai/optimize/optimizers/dspy_optimizer.py +91 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_callback.py +76 -0
- mlflow/genai/optimize/optimizers/utils/dspy_mipro_utils.py +18 -0
- mlflow/genai/optimize/types.py +75 -0
- mlflow/genai/optimize/util.py +30 -0
- mlflow/genai/prompts/__init__.py +206 -0
- mlflow/genai/scheduled_scorers.py +431 -0
- mlflow/genai/scorers/__init__.py +26 -0
- mlflow/genai/scorers/base.py +492 -0
- mlflow/genai/scorers/builtin_scorers.py +765 -0
- mlflow/genai/scorers/scorer_utils.py +138 -0
- mlflow/genai/scorers/validation.py +165 -0
- mlflow/genai/utils/data_validation.py +146 -0
- mlflow/genai/utils/enum_utils.py +23 -0
- mlflow/genai/utils/trace_utils.py +211 -0
- mlflow/groq/__init__.py +42 -0
- mlflow/groq/_groq_autolog.py +74 -0
- mlflow/johnsnowlabs/__init__.py +888 -0
- mlflow/langchain/__init__.py +24 -0
- mlflow/langchain/api_request_parallel_processor.py +330 -0
- mlflow/langchain/autolog.py +147 -0
- mlflow/langchain/chat_agent_langgraph.py +340 -0
- mlflow/langchain/constant.py +1 -0
- mlflow/langchain/constants.py +1 -0
- mlflow/langchain/databricks_dependencies.py +444 -0
- mlflow/langchain/langchain_tracer.py +597 -0
- mlflow/langchain/model.py +919 -0
- mlflow/langchain/output_parsers.py +142 -0
- mlflow/langchain/retriever_chain.py +153 -0
- mlflow/langchain/runnables.py +527 -0
- mlflow/langchain/utils/chat.py +402 -0
- mlflow/langchain/utils/logging.py +671 -0
- mlflow/langchain/utils/serialization.py +36 -0
- mlflow/legacy_databricks_cli/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/__init__.py +0 -0
- mlflow/legacy_databricks_cli/configure/provider.py +482 -0
- mlflow/litellm/__init__.py +175 -0
- mlflow/llama_index/__init__.py +22 -0
- mlflow/llama_index/autolog.py +55 -0
- mlflow/llama_index/chat.py +43 -0
- mlflow/llama_index/constant.py +1 -0
- mlflow/llama_index/model.py +577 -0
- mlflow/llama_index/pyfunc_wrapper.py +332 -0
- mlflow/llama_index/serialize_objects.py +188 -0
- mlflow/llama_index/tracer.py +561 -0
- mlflow/metrics/__init__.py +479 -0
- mlflow/metrics/base.py +39 -0
- mlflow/metrics/genai/__init__.py +25 -0
- mlflow/metrics/genai/base.py +101 -0
- mlflow/metrics/genai/genai_metric.py +771 -0
- mlflow/metrics/genai/metric_definitions.py +450 -0
- mlflow/metrics/genai/model_utils.py +371 -0
- mlflow/metrics/genai/prompt_template.py +68 -0
- mlflow/metrics/genai/prompts/__init__.py +0 -0
- mlflow/metrics/genai/prompts/v1.py +422 -0
- mlflow/metrics/genai/utils.py +6 -0
- mlflow/metrics/metric_definitions.py +619 -0
- mlflow/mismatch.py +34 -0
- mlflow/mistral/__init__.py +34 -0
- mlflow/mistral/autolog.py +71 -0
- mlflow/mistral/chat.py +135 -0
- mlflow/ml_package_versions.py +452 -0
- mlflow/models/__init__.py +97 -0
- mlflow/models/auth_policy.py +83 -0
- mlflow/models/cli.py +354 -0
- mlflow/models/container/__init__.py +294 -0
- mlflow/models/container/scoring_server/__init__.py +0 -0
- mlflow/models/container/scoring_server/nginx.conf +39 -0
- mlflow/models/dependencies_schemas.py +287 -0
- mlflow/models/display_utils.py +158 -0
- mlflow/models/docker_utils.py +211 -0
- mlflow/models/evaluation/__init__.py +23 -0
- mlflow/models/evaluation/_shap_patch.py +64 -0
- mlflow/models/evaluation/artifacts.py +194 -0
- mlflow/models/evaluation/base.py +1811 -0
- mlflow/models/evaluation/calibration_curve.py +109 -0
- mlflow/models/evaluation/default_evaluator.py +996 -0
- mlflow/models/evaluation/deprecated.py +23 -0
- mlflow/models/evaluation/evaluator_registry.py +80 -0
- mlflow/models/evaluation/evaluators/classifier.py +704 -0
- mlflow/models/evaluation/evaluators/default.py +233 -0
- mlflow/models/evaluation/evaluators/regressor.py +96 -0
- mlflow/models/evaluation/evaluators/shap.py +296 -0
- mlflow/models/evaluation/lift_curve.py +178 -0
- mlflow/models/evaluation/utils/metric.py +123 -0
- mlflow/models/evaluation/utils/trace.py +179 -0
- mlflow/models/evaluation/validation.py +434 -0
- mlflow/models/flavor_backend.py +93 -0
- mlflow/models/flavor_backend_registry.py +53 -0
- mlflow/models/model.py +1639 -0
- mlflow/models/model_config.py +150 -0
- mlflow/models/notebook_resources/agent_evaluation_template.html +235 -0
- mlflow/models/notebook_resources/eval_with_dataset_example.py +22 -0
- mlflow/models/notebook_resources/eval_with_synthetic_example.py +22 -0
- mlflow/models/python_api.py +369 -0
- mlflow/models/rag_signatures.py +128 -0
- mlflow/models/resources.py +321 -0
- mlflow/models/signature.py +662 -0
- mlflow/models/utils.py +2054 -0
- mlflow/models/wheeled_model.py +280 -0
- mlflow/openai/__init__.py +57 -0
- mlflow/openai/_agent_tracer.py +364 -0
- mlflow/openai/api_request_parallel_processor.py +131 -0
- mlflow/openai/autolog.py +509 -0
- mlflow/openai/constant.py +1 -0
- mlflow/openai/model.py +824 -0
- mlflow/openai/utils/chat_schema.py +367 -0
- mlflow/optuna/__init__.py +3 -0
- mlflow/optuna/storage.py +646 -0
- mlflow/plugins/__init__.py +72 -0
- mlflow/plugins/base.py +358 -0
- mlflow/plugins/builtin/__init__.py +24 -0
- mlflow/plugins/builtin/pytorch_plugin.py +150 -0
- mlflow/plugins/builtin/sklearn_plugin.py +158 -0
- mlflow/plugins/builtin/transformers_plugin.py +187 -0
- mlflow/plugins/cli.py +321 -0
- mlflow/plugins/discovery.py +340 -0
- mlflow/plugins/manager.py +465 -0
- mlflow/plugins/registry.py +316 -0
- mlflow/plugins/templates/framework_plugin_template.py +329 -0
- mlflow/prompt/constants.py +20 -0
- mlflow/prompt/promptlab_model.py +197 -0
- mlflow/prompt/registry_utils.py +248 -0
- mlflow/promptflow/__init__.py +495 -0
- mlflow/protos/__init__.py +0 -0
- mlflow/protos/assessments_pb2.py +174 -0
- mlflow/protos/databricks_artifacts_pb2.py +489 -0
- mlflow/protos/databricks_filesystem_service_pb2.py +196 -0
- mlflow/protos/databricks_managed_catalog_messages_pb2.py +95 -0
- mlflow/protos/databricks_managed_catalog_service_pb2.py +86 -0
- mlflow/protos/databricks_pb2.py +267 -0
- mlflow/protos/databricks_trace_server_pb2.py +374 -0
- mlflow/protos/databricks_uc_registry_messages_pb2.py +1249 -0
- mlflow/protos/databricks_uc_registry_service_pb2.py +170 -0
- mlflow/protos/facet_feature_statistics_pb2.py +296 -0
- mlflow/protos/internal_pb2.py +77 -0
- mlflow/protos/mlflow_artifacts_pb2.py +336 -0
- mlflow/protos/model_registry_pb2.py +1073 -0
- mlflow/protos/scalapb/__init__.py +0 -0
- mlflow/protos/scalapb/scalapb_pb2.py +104 -0
- mlflow/protos/service_pb2.py +2600 -0
- mlflow/protos/unity_catalog_oss_messages_pb2.py +457 -0
- mlflow/protos/unity_catalog_oss_service_pb2.py +130 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2.py +447 -0
- mlflow/protos/unity_catalog_prompt_messages_pb2_grpc.py +24 -0
- mlflow/protos/unity_catalog_prompt_service_pb2.py +164 -0
- mlflow/protos/unity_catalog_prompt_service_pb2_grpc.py +785 -0
- mlflow/py.typed +0 -0
- mlflow/pydantic_ai/__init__.py +57 -0
- mlflow/pydantic_ai/autolog.py +173 -0
- mlflow/pyfunc/__init__.py +3844 -0
- mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py +61 -0
- mlflow/pyfunc/backend.py +523 -0
- mlflow/pyfunc/context.py +78 -0
- mlflow/pyfunc/dbconnect_artifact_cache.py +144 -0
- mlflow/pyfunc/loaders/__init__.py +7 -0
- mlflow/pyfunc/loaders/chat_agent.py +117 -0
- mlflow/pyfunc/loaders/chat_model.py +125 -0
- mlflow/pyfunc/loaders/code_model.py +31 -0
- mlflow/pyfunc/loaders/responses_agent.py +112 -0
- mlflow/pyfunc/mlserver.py +46 -0
- mlflow/pyfunc/model.py +1473 -0
- mlflow/pyfunc/scoring_server/__init__.py +604 -0
- mlflow/pyfunc/scoring_server/app.py +7 -0
- mlflow/pyfunc/scoring_server/client.py +146 -0
- mlflow/pyfunc/spark_model_cache.py +48 -0
- mlflow/pyfunc/stdin_server.py +44 -0
- mlflow/pyfunc/utils/__init__.py +3 -0
- mlflow/pyfunc/utils/data_validation.py +224 -0
- mlflow/pyfunc/utils/environment.py +22 -0
- mlflow/pyfunc/utils/input_converter.py +47 -0
- mlflow/pyfunc/utils/serving_data_parser.py +11 -0
- mlflow/pytorch/__init__.py +1171 -0
- mlflow/pytorch/_lightning_autolog.py +580 -0
- mlflow/pytorch/_pytorch_autolog.py +50 -0
- mlflow/pytorch/pickle_module.py +35 -0
- mlflow/rfunc/__init__.py +42 -0
- mlflow/rfunc/backend.py +134 -0
- mlflow/runs.py +89 -0
- mlflow/server/__init__.py +302 -0
- mlflow/server/auth/__init__.py +1224 -0
- mlflow/server/auth/__main__.py +4 -0
- mlflow/server/auth/basic_auth.ini +6 -0
- mlflow/server/auth/cli.py +11 -0
- mlflow/server/auth/client.py +537 -0
- mlflow/server/auth/config.py +34 -0
- mlflow/server/auth/db/__init__.py +0 -0
- mlflow/server/auth/db/cli.py +18 -0
- mlflow/server/auth/db/migrations/__init__.py +0 -0
- mlflow/server/auth/db/migrations/alembic.ini +110 -0
- mlflow/server/auth/db/migrations/env.py +76 -0
- mlflow/server/auth/db/migrations/versions/8606fa83a998_initial_migration.py +51 -0
- mlflow/server/auth/db/migrations/versions/__init__.py +0 -0
- mlflow/server/auth/db/models.py +67 -0
- mlflow/server/auth/db/utils.py +37 -0
- mlflow/server/auth/entities.py +165 -0
- mlflow/server/auth/logo.py +14 -0
- mlflow/server/auth/permissions.py +65 -0
- mlflow/server/auth/routes.py +18 -0
- mlflow/server/auth/sqlalchemy_store.py +263 -0
- mlflow/server/graphql/__init__.py +0 -0
- mlflow/server/graphql/autogenerated_graphql_schema.py +353 -0
- mlflow/server/graphql/graphql_custom_scalars.py +24 -0
- mlflow/server/graphql/graphql_errors.py +15 -0
- mlflow/server/graphql/graphql_no_batching.py +89 -0
- mlflow/server/graphql/graphql_schema_extensions.py +74 -0
- mlflow/server/handlers.py +3217 -0
- mlflow/server/prometheus_exporter.py +17 -0
- mlflow/server/validation.py +30 -0
- mlflow/shap/__init__.py +691 -0
- mlflow/sklearn/__init__.py +1994 -0
- mlflow/sklearn/utils.py +1041 -0
- mlflow/smolagents/__init__.py +66 -0
- mlflow/smolagents/autolog.py +139 -0
- mlflow/smolagents/chat.py +29 -0
- mlflow/store/__init__.py +10 -0
- mlflow/store/_unity_catalog/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/__init__.py +1 -0
- mlflow/store/_unity_catalog/lineage/constants.py +2 -0
- mlflow/store/_unity_catalog/registry/__init__.py +6 -0
- mlflow/store/_unity_catalog/registry/prompt_info.py +75 -0
- mlflow/store/_unity_catalog/registry/rest_store.py +1740 -0
- mlflow/store/_unity_catalog/registry/uc_oss_rest_store.py +507 -0
- mlflow/store/_unity_catalog/registry/utils.py +121 -0
- mlflow/store/artifact/__init__.py +0 -0
- mlflow/store/artifact/artifact_repo.py +472 -0
- mlflow/store/artifact/artifact_repository_registry.py +154 -0
- mlflow/store/artifact/azure_blob_artifact_repo.py +275 -0
- mlflow/store/artifact/azure_data_lake_artifact_repo.py +295 -0
- mlflow/store/artifact/cli.py +141 -0
- mlflow/store/artifact/cloud_artifact_repo.py +332 -0
- mlflow/store/artifact/databricks_artifact_repo.py +729 -0
- mlflow/store/artifact/databricks_artifact_repo_resources.py +301 -0
- mlflow/store/artifact/databricks_logged_model_artifact_repo.py +93 -0
- mlflow/store/artifact/databricks_models_artifact_repo.py +216 -0
- mlflow/store/artifact/databricks_sdk_artifact_repo.py +134 -0
- mlflow/store/artifact/databricks_sdk_models_artifact_repo.py +97 -0
- mlflow/store/artifact/dbfs_artifact_repo.py +240 -0
- mlflow/store/artifact/ftp_artifact_repo.py +132 -0
- mlflow/store/artifact/gcs_artifact_repo.py +296 -0
- mlflow/store/artifact/hdfs_artifact_repo.py +209 -0
- mlflow/store/artifact/http_artifact_repo.py +218 -0
- mlflow/store/artifact/local_artifact_repo.py +142 -0
- mlflow/store/artifact/mlflow_artifacts_repo.py +94 -0
- mlflow/store/artifact/models_artifact_repo.py +259 -0
- mlflow/store/artifact/optimized_s3_artifact_repo.py +356 -0
- mlflow/store/artifact/presigned_url_artifact_repo.py +173 -0
- mlflow/store/artifact/r2_artifact_repo.py +70 -0
- mlflow/store/artifact/runs_artifact_repo.py +265 -0
- mlflow/store/artifact/s3_artifact_repo.py +330 -0
- mlflow/store/artifact/sftp_artifact_repo.py +141 -0
- mlflow/store/artifact/uc_volume_artifact_repo.py +76 -0
- mlflow/store/artifact/unity_catalog_models_artifact_repo.py +168 -0
- mlflow/store/artifact/unity_catalog_oss_models_artifact_repo.py +168 -0
- mlflow/store/artifact/utils/__init__.py +0 -0
- mlflow/store/artifact/utils/models.py +148 -0
- mlflow/store/db/__init__.py +0 -0
- mlflow/store/db/base_sql_model.py +3 -0
- mlflow/store/db/db_types.py +10 -0
- mlflow/store/db/utils.py +314 -0
- mlflow/store/db_migrations/__init__.py +0 -0
- mlflow/store/db_migrations/alembic.ini +74 -0
- mlflow/store/db_migrations/env.py +84 -0
- mlflow/store/db_migrations/versions/0584bdc529eb_add_cascading_deletion_to_datasets_from_experiments.py +88 -0
- mlflow/store/db_migrations/versions/0a8213491aaa_drop_duplicate_killed_constraint.py +49 -0
- mlflow/store/db_migrations/versions/0c779009ac13_add_deleted_time_field_to_runs_table.py +24 -0
- mlflow/store/db_migrations/versions/181f10493468_allow_nulls_for_metric_values.py +35 -0
- mlflow/store/db_migrations/versions/27a6a02d2cf1_add_model_version_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/2b4d017a5e9b_add_model_registry_tables_to_db.py +77 -0
- mlflow/store/db_migrations/versions/2d6e25af4d3e_increase_max_param_val_length.py +33 -0
- mlflow/store/db_migrations/versions/3500859a5d39_add_model_aliases_table.py +50 -0
- mlflow/store/db_migrations/versions/39d1c3be5f05_add_is_nan_constraint_for_metrics_tables_if_necessary.py +41 -0
- mlflow/store/db_migrations/versions/400f98739977_add_logged_model_tables.py +123 -0
- mlflow/store/db_migrations/versions/4465047574b1_increase_max_dataset_schema_size.py +38 -0
- mlflow/store/db_migrations/versions/451aebb31d03_add_metric_step.py +35 -0
- mlflow/store/db_migrations/versions/5b0e9adcef9c_add_cascade_deletion_to_trace_tables_fk.py +40 -0
- mlflow/store/db_migrations/versions/6953534de441_add_step_to_inputs_table.py +25 -0
- mlflow/store/db_migrations/versions/728d730b5ebd_add_registered_model_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/7ac759974ad8_update_run_tags_with_larger_limit.py +36 -0
- mlflow/store/db_migrations/versions/7f2a7d5fae7d_add_datasets_inputs_input_tags_tables.py +82 -0
- mlflow/store/db_migrations/versions/84291f40a231_add_run_link_to_model_version.py +26 -0
- mlflow/store/db_migrations/versions/867495a8f9d4_add_trace_tables.py +90 -0
- mlflow/store/db_migrations/versions/89d4b8295536_create_latest_metrics_table.py +169 -0
- mlflow/store/db_migrations/versions/90e64c465722_migrate_user_column_to_tags.py +64 -0
- mlflow/store/db_migrations/versions/97727af70f4d_creation_time_last_update_time_experiments.py +25 -0
- mlflow/store/db_migrations/versions/__init__.py +0 -0
- mlflow/store/db_migrations/versions/a8c4a736bde6_allow_nulls_for_run_id.py +27 -0
- mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py +29 -0
- mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py +26 -0
- mlflow/store/db_migrations/versions/bda7b8c39065_increase_model_version_tag_value_limit.py +38 -0
- mlflow/store/db_migrations/versions/c48cb773bb87_reset_default_value_for_is_nan_in_metrics_table_for_mysql.py +41 -0
- mlflow/store/db_migrations/versions/cbc13b556ace_add_v3_trace_schema_columns.py +31 -0
- mlflow/store/db_migrations/versions/cc1f77228345_change_param_value_length_to_500.py +34 -0
- mlflow/store/db_migrations/versions/cfd24bdc0731_update_run_status_constraint_with_killed.py +78 -0
- mlflow/store/db_migrations/versions/df50e92ffc5e_add_experiment_tags_table.py +38 -0
- mlflow/store/db_migrations/versions/f5a4f2784254_increase_run_tag_value_limit.py +36 -0
- mlflow/store/entities/__init__.py +3 -0
- mlflow/store/entities/paged_list.py +18 -0
- mlflow/store/model_registry/__init__.py +10 -0
- mlflow/store/model_registry/abstract_store.py +1081 -0
- mlflow/store/model_registry/base_rest_store.py +44 -0
- mlflow/store/model_registry/databricks_workspace_model_registry_rest_store.py +37 -0
- mlflow/store/model_registry/dbmodels/__init__.py +0 -0
- mlflow/store/model_registry/dbmodels/models.py +206 -0
- mlflow/store/model_registry/file_store.py +1091 -0
- mlflow/store/model_registry/rest_store.py +481 -0
- mlflow/store/model_registry/sqlalchemy_store.py +1286 -0
- mlflow/store/tracking/__init__.py +23 -0
- mlflow/store/tracking/abstract_store.py +816 -0
- mlflow/store/tracking/dbmodels/__init__.py +0 -0
- mlflow/store/tracking/dbmodels/initial_models.py +243 -0
- mlflow/store/tracking/dbmodels/models.py +1073 -0
- mlflow/store/tracking/file_store.py +2438 -0
- mlflow/store/tracking/postgres_managed_identity.py +146 -0
- mlflow/store/tracking/rest_store.py +1131 -0
- mlflow/store/tracking/sqlalchemy_store.py +2785 -0
- mlflow/system_metrics/__init__.py +61 -0
- mlflow/system_metrics/metrics/__init__.py +0 -0
- mlflow/system_metrics/metrics/base_metrics_monitor.py +32 -0
- mlflow/system_metrics/metrics/cpu_monitor.py +23 -0
- mlflow/system_metrics/metrics/disk_monitor.py +21 -0
- mlflow/system_metrics/metrics/gpu_monitor.py +71 -0
- mlflow/system_metrics/metrics/network_monitor.py +34 -0
- mlflow/system_metrics/metrics/rocm_monitor.py +123 -0
- mlflow/system_metrics/system_metrics_monitor.py +198 -0
- mlflow/tracing/__init__.py +16 -0
- mlflow/tracing/assessment.py +356 -0
- mlflow/tracing/client.py +531 -0
- mlflow/tracing/config.py +125 -0
- mlflow/tracing/constant.py +105 -0
- mlflow/tracing/destination.py +81 -0
- mlflow/tracing/display/__init__.py +40 -0
- mlflow/tracing/display/display_handler.py +196 -0
- mlflow/tracing/export/async_export_queue.py +186 -0
- mlflow/tracing/export/inference_table.py +138 -0
- mlflow/tracing/export/mlflow_v3.py +137 -0
- mlflow/tracing/export/utils.py +70 -0
- mlflow/tracing/fluent.py +1417 -0
- mlflow/tracing/processor/base_mlflow.py +199 -0
- mlflow/tracing/processor/inference_table.py +175 -0
- mlflow/tracing/processor/mlflow_v3.py +47 -0
- mlflow/tracing/processor/otel.py +73 -0
- mlflow/tracing/provider.py +487 -0
- mlflow/tracing/trace_manager.py +200 -0
- mlflow/tracing/utils/__init__.py +616 -0
- mlflow/tracing/utils/artifact_utils.py +28 -0
- mlflow/tracing/utils/copy.py +55 -0
- mlflow/tracing/utils/environment.py +55 -0
- mlflow/tracing/utils/exception.py +21 -0
- mlflow/tracing/utils/once.py +35 -0
- mlflow/tracing/utils/otlp.py +63 -0
- mlflow/tracing/utils/processor.py +54 -0
- mlflow/tracing/utils/search.py +292 -0
- mlflow/tracing/utils/timeout.py +250 -0
- mlflow/tracing/utils/token.py +19 -0
- mlflow/tracing/utils/truncation.py +124 -0
- mlflow/tracing/utils/warning.py +76 -0
- mlflow/tracking/__init__.py +39 -0
- mlflow/tracking/_model_registry/__init__.py +1 -0
- mlflow/tracking/_model_registry/client.py +764 -0
- mlflow/tracking/_model_registry/fluent.py +853 -0
- mlflow/tracking/_model_registry/registry.py +67 -0
- mlflow/tracking/_model_registry/utils.py +251 -0
- mlflow/tracking/_tracking_service/__init__.py +0 -0
- mlflow/tracking/_tracking_service/client.py +883 -0
- mlflow/tracking/_tracking_service/registry.py +56 -0
- mlflow/tracking/_tracking_service/utils.py +275 -0
- mlflow/tracking/artifact_utils.py +179 -0
- mlflow/tracking/client.py +5900 -0
- mlflow/tracking/context/__init__.py +0 -0
- mlflow/tracking/context/abstract_context.py +35 -0
- mlflow/tracking/context/databricks_cluster_context.py +15 -0
- mlflow/tracking/context/databricks_command_context.py +15 -0
- mlflow/tracking/context/databricks_job_context.py +49 -0
- mlflow/tracking/context/databricks_notebook_context.py +41 -0
- mlflow/tracking/context/databricks_repo_context.py +43 -0
- mlflow/tracking/context/default_context.py +51 -0
- mlflow/tracking/context/git_context.py +32 -0
- mlflow/tracking/context/registry.py +98 -0
- mlflow/tracking/context/system_environment_context.py +15 -0
- mlflow/tracking/default_experiment/__init__.py +1 -0
- mlflow/tracking/default_experiment/abstract_context.py +43 -0
- mlflow/tracking/default_experiment/databricks_notebook_experiment_provider.py +44 -0
- mlflow/tracking/default_experiment/registry.py +75 -0
- mlflow/tracking/fluent.py +3595 -0
- mlflow/tracking/metric_value_conversion_utils.py +93 -0
- mlflow/tracking/multimedia.py +206 -0
- mlflow/tracking/registry.py +86 -0
- mlflow/tracking/request_auth/__init__.py +0 -0
- mlflow/tracking/request_auth/abstract_request_auth_provider.py +34 -0
- mlflow/tracking/request_auth/registry.py +60 -0
- mlflow/tracking/request_header/__init__.py +0 -0
- mlflow/tracking/request_header/abstract_request_header_provider.py +36 -0
- mlflow/tracking/request_header/databricks_request_header_provider.py +38 -0
- mlflow/tracking/request_header/default_request_header_provider.py +17 -0
- mlflow/tracking/request_header/registry.py +79 -0
- mlflow/transformers/__init__.py +2982 -0
- mlflow/transformers/flavor_config.py +258 -0
- mlflow/transformers/hub_utils.py +83 -0
- mlflow/transformers/llm_inference_utils.py +468 -0
- mlflow/transformers/model_io.py +301 -0
- mlflow/transformers/peft.py +51 -0
- mlflow/transformers/signature.py +183 -0
- mlflow/transformers/torch_utils.py +55 -0
- mlflow/types/__init__.py +21 -0
- mlflow/types/agent.py +270 -0
- mlflow/types/chat.py +240 -0
- mlflow/types/llm.py +935 -0
- mlflow/types/responses.py +139 -0
- mlflow/types/responses_helpers.py +416 -0
- mlflow/types/schema.py +1505 -0
- mlflow/types/type_hints.py +647 -0
- mlflow/types/utils.py +753 -0
- mlflow/utils/__init__.py +283 -0
- mlflow/utils/_capture_modules.py +256 -0
- mlflow/utils/_capture_transformers_modules.py +75 -0
- mlflow/utils/_spark_utils.py +201 -0
- mlflow/utils/_unity_catalog_oss_utils.py +97 -0
- mlflow/utils/_unity_catalog_utils.py +479 -0
- mlflow/utils/annotations.py +218 -0
- mlflow/utils/arguments_utils.py +16 -0
- mlflow/utils/async_logging/__init__.py +1 -0
- mlflow/utils/async_logging/async_artifacts_logging_queue.py +258 -0
- mlflow/utils/async_logging/async_logging_queue.py +366 -0
- mlflow/utils/async_logging/run_artifact.py +38 -0
- mlflow/utils/async_logging/run_batch.py +58 -0
- mlflow/utils/async_logging/run_operations.py +49 -0
- mlflow/utils/autologging_utils/__init__.py +737 -0
- mlflow/utils/autologging_utils/client.py +432 -0
- mlflow/utils/autologging_utils/config.py +33 -0
- mlflow/utils/autologging_utils/events.py +294 -0
- mlflow/utils/autologging_utils/logging_and_warnings.py +328 -0
- mlflow/utils/autologging_utils/metrics_queue.py +71 -0
- mlflow/utils/autologging_utils/safety.py +1104 -0
- mlflow/utils/autologging_utils/versioning.py +95 -0
- mlflow/utils/checkpoint_utils.py +206 -0
- mlflow/utils/class_utils.py +6 -0
- mlflow/utils/cli_args.py +257 -0
- mlflow/utils/conda.py +354 -0
- mlflow/utils/credentials.py +231 -0
- mlflow/utils/data_utils.py +17 -0
- mlflow/utils/databricks_utils.py +1436 -0
- mlflow/utils/docstring_utils.py +477 -0
- mlflow/utils/doctor.py +133 -0
- mlflow/utils/download_cloud_file_chunk.py +43 -0
- mlflow/utils/env_manager.py +16 -0
- mlflow/utils/env_pack.py +131 -0
- mlflow/utils/environment.py +1009 -0
- mlflow/utils/exception_utils.py +14 -0
- mlflow/utils/file_utils.py +978 -0
- mlflow/utils/git_utils.py +77 -0
- mlflow/utils/gorilla.py +797 -0
- mlflow/utils/import_hooks/__init__.py +363 -0
- mlflow/utils/lazy_load.py +51 -0
- mlflow/utils/logging_utils.py +168 -0
- mlflow/utils/mime_type_utils.py +58 -0
- mlflow/utils/mlflow_tags.py +103 -0
- mlflow/utils/model_utils.py +486 -0
- mlflow/utils/name_utils.py +346 -0
- mlflow/utils/nfs_on_spark.py +62 -0
- mlflow/utils/openai_utils.py +164 -0
- mlflow/utils/os.py +12 -0
- mlflow/utils/oss_registry_utils.py +29 -0
- mlflow/utils/plugins.py +17 -0
- mlflow/utils/process.py +182 -0
- mlflow/utils/promptlab_utils.py +146 -0
- mlflow/utils/proto_json_utils.py +743 -0
- mlflow/utils/pydantic_utils.py +54 -0
- mlflow/utils/request_utils.py +279 -0
- mlflow/utils/requirements_utils.py +704 -0
- mlflow/utils/rest_utils.py +673 -0
- mlflow/utils/search_logged_model_utils.py +127 -0
- mlflow/utils/search_utils.py +2111 -0
- mlflow/utils/secure_loading.py +221 -0
- mlflow/utils/security_validation.py +384 -0
- mlflow/utils/server_cli_utils.py +61 -0
- mlflow/utils/spark_utils.py +15 -0
- mlflow/utils/string_utils.py +138 -0
- mlflow/utils/thread_utils.py +63 -0
- mlflow/utils/time.py +54 -0
- mlflow/utils/timeout.py +42 -0
- mlflow/utils/uri.py +572 -0
- mlflow/utils/validation.py +662 -0
- mlflow/utils/virtualenv.py +458 -0
- mlflow/utils/warnings_utils.py +25 -0
- mlflow/utils/yaml_utils.py +179 -0
- mlflow/version.py +24 -0
@@ -0,0 +1,671 @@
|
|
1
|
+
"""Utility functions for mlflow.langchain."""
|
2
|
+
|
3
|
+
import contextlib
|
4
|
+
import functools
|
5
|
+
import importlib
|
6
|
+
import json
|
7
|
+
import logging
|
8
|
+
import os
|
9
|
+
import re
|
10
|
+
import shutil
|
11
|
+
import types
|
12
|
+
import warnings
|
13
|
+
from functools import lru_cache
|
14
|
+
from importlib.util import find_spec
|
15
|
+
from typing import Any, Callable, NamedTuple
|
16
|
+
|
17
|
+
import cloudpickle
|
18
|
+
import yaml
|
19
|
+
from packaging import version
|
20
|
+
from packaging.version import Version
|
21
|
+
|
22
|
+
import mlflow
|
23
|
+
from mlflow.exceptions import MlflowException
|
24
|
+
from mlflow.models.utils import _validate_and_get_model_code_path
|
25
|
+
from mlflow.protos.databricks_pb2 import INTERNAL_ERROR
|
26
|
+
from mlflow.utils.class_utils import _get_class_from_string
|
27
|
+
|
28
|
+
_AGENT_PRIMITIVES_FILE_NAME = "agent_primitive_args.json"
|
29
|
+
_AGENT_PRIMITIVES_DATA_KEY = "agent_primitive_data"
|
30
|
+
_AGENT_DATA_FILE_NAME = "agent.yaml"
|
31
|
+
_AGENT_DATA_KEY = "agent_data"
|
32
|
+
_TOOLS_DATA_FILE_NAME = "tools.pkl"
|
33
|
+
_TOOLS_DATA_KEY = "tools_data"
|
34
|
+
_LOADER_FN_FILE_NAME = "loader_fn.pkl"
|
35
|
+
_LOADER_FN_KEY = "loader_fn"
|
36
|
+
_LOADER_ARG_KEY = "loader_arg"
|
37
|
+
_PERSIST_DIR_NAME = "persist_dir_data"
|
38
|
+
_PERSIST_DIR_KEY = "persist_dir"
|
39
|
+
_MODEL_DATA_YAML_FILE_NAME = "model.yaml"
|
40
|
+
_MODEL_DATA_PKL_FILE_NAME = "model.pkl"
|
41
|
+
_MODEL_DATA_FOLDER_NAME = "model"
|
42
|
+
_MODEL_DATA_KEY = "model_data"
|
43
|
+
_MODEL_TYPE_KEY = "model_type"
|
44
|
+
_RUNNABLE_LOAD_KEY = "runnable_load"
|
45
|
+
_BASE_LOAD_KEY = "base_load"
|
46
|
+
_CONFIG_LOAD_KEY = "config_load"
|
47
|
+
_PICKLE_LOAD_KEY = "pickle_load"
|
48
|
+
_MODEL_LOAD_KEY = "model_load"
|
49
|
+
_UNSUPPORTED_MODEL_WARNING_MESSAGE = (
|
50
|
+
"MLflow does not guarantee support for Chains outside of the subclasses of LLMChain, found %s"
|
51
|
+
)
|
52
|
+
_UNSUPPORTED_LLM_WARNING_MESSAGE = (
|
53
|
+
"MLflow does not guarantee support for LLMs outside of HuggingFacePipeline and OpenAI, found %s"
|
54
|
+
)
|
55
|
+
|
56
|
+
|
57
|
+
_CHAT_MODELS_ERROR_MSG = re.compile("Loading (openai-chat|azure-openai-chat) LLM not supported")
|
58
|
+
|
59
|
+
|
60
|
+
try:
|
61
|
+
import langchain_community
|
62
|
+
|
63
|
+
# Since langchain-community 0.0.27, saving or loading a module that relies on the pickle
|
64
|
+
# deserialization requires passing `allow_dangerous_deserialization=True`.
|
65
|
+
IS_PICKLE_SERIALIZATION_RESTRICTED = Version(langchain_community.__version__) >= Version(
|
66
|
+
"0.0.27"
|
67
|
+
)
|
68
|
+
except ImportError:
|
69
|
+
IS_PICKLE_SERIALIZATION_RESTRICTED = False
|
70
|
+
|
71
|
+
logger = logging.getLogger(__name__)
|
72
|
+
|
73
|
+
|
74
|
+
@lru_cache
def base_lc_types():
    """Return the core LangChain base classes accepted by the flavor.

    The tuple is ``(Chain, AgentExecutor, BaseRetriever)`` and is cached after the first call.
    """
    # Importing ``langchain.agents`` first avoids a missing module error for the submodules.
    import langchain.agents  # noqa: F401
    from langchain.agents.agent import AgentExecutor
    from langchain.chains.base import Chain
    from langchain.schema import BaseRetriever

    return Chain, AgentExecutor, BaseRetriever
|
87
|
+
|
88
|
+
|
89
|
+
@lru_cache
def picklable_runnable_types():
    """
    Runnable types that can be pickled and unpickled by cloudpickle.

    Cached after the first call.
    """
    import langchain.chat_models.base as lc_chat_base
    import langchain.prompts as lc_prompts
    import langchain.schema.runnable as lc_runnable

    return (
        lc_chat_base.SimpleChatModel,
        lc_prompts.ChatPromptTemplate,
        lc_runnable.RunnablePassthrough,
        lc_runnable.RunnableLambda,
    )
|
104
|
+
|
105
|
+
|
106
|
+
@lru_cache
def lc_runnable_with_steps_types():
    """Return the runnable classes composed of nested steps: parallel and sequence runnables."""
    from langchain.schema.runnable import RunnableParallel as _Parallel
    from langchain.schema.runnable import RunnableSequence as _Sequence

    step_types = (_Parallel, _Sequence)
    return step_types
|
111
|
+
|
112
|
+
|
113
|
+
@lru_cache
def lc_runnable_assign_types():
    """Return ``(RunnableAssign,)``.

    Cached like the sibling ``lc_runnable_with_steps_types`` so repeated calls from
    ``lc_runnables_types`` / ``supported_lc_types`` do not redo the import lookup.
    """
    from langchain.schema.runnable.passthrough import RunnableAssign

    return (RunnableAssign,)
|
117
|
+
|
118
|
+
|
119
|
+
@lru_cache
def lc_runnable_branch_types():
    """Return ``(RunnableBranch,)``.

    Cached for consistency with the other runnable type getters in this module.
    """
    from langchain.schema.runnable import RunnableBranch

    return (RunnableBranch,)
|
123
|
+
|
124
|
+
|
125
|
+
@lru_cache
def lc_runnable_binding_types():
    """Return ``(RunnableBinding,)``.

    Cached for consistency with the other runnable type getters in this module.
    """
    from langchain.schema.runnable import RunnableBinding

    return (RunnableBinding,)
|
129
|
+
|
130
|
+
|
131
|
+
def lc_runnables_types():
    """Return every supported LangChain runnable class as one flat tuple.

    The groups are concatenated in a fixed order: picklable types, then the types with steps,
    branch, assign and binding types.
    """
    getters = (
        picklable_runnable_types,
        lc_runnable_with_steps_types,
        lc_runnable_branch_types,
        lc_runnable_assign_types,
        lc_runnable_binding_types,
    )
    return tuple(cls for get_group in getters for cls in get_group())
|
139
|
+
|
140
|
+
|
141
|
+
def langgraph_types():
    """Return ``(CompiledStateGraph,)`` when langgraph is importable, otherwise ``()``."""
    try:
        from langgraph.graph.state import CompiledStateGraph
    except ImportError:
        return ()
    return (CompiledStateGraph,)
|
148
|
+
|
149
|
+
|
150
|
+
def supported_lc_types():
    """Return all LangChain/LangGraph classes the flavor accepts, as a single tuple."""
    return (*base_lc_types(), *lc_runnables_types(), *langgraph_types())
|
152
|
+
|
153
|
+
|
154
|
+
# Wrapped as a function to avoid calling supported_lc_types() at import time.
def get_unsupported_model_message(model_type):
    """Build the error text listing the supported LangChain types and the rejected ``model_type``."""
    supported_types = supported_lc_types()
    return (
        f"MLflow langchain flavor only supports subclasses of {supported_types}, "
        f"found {model_type}."
    )
|
160
|
+
|
161
|
+
|
162
|
+
@lru_cache
def custom_type_to_loader_dict():
    """Return the custom loaders MLflow registers, keyed by type name.

    Currently only ``"default"`` is registered; it rebuilds a ``StrOutputParser``.
    """

    # helper function to load output_parsers from config
    def _load_output_parser(config: dict[str, Any]) -> Any:
        """Load output parser.

        Raises:
            ValueError: If ``config["_type"]`` is not ``"default"``.
        """
        # Work on a copy so the caller's config dict is not mutated by the pop below.
        config = dict(config)
        output_parser_type = config.pop("_type", None)
        if output_parser_type != "default":
            # Fail fast without requiring langchain to be importable.
            raise ValueError(f"Unsupported output parser {output_parser_type}")

        from langchain.schema.output_parser import StrOutputParser

        return StrOutputParser(**config)

    return {"default": _load_output_parser}
|
176
|
+
|
177
|
+
|
178
|
+
class _SpecialChainInfo(NamedTuple):
|
179
|
+
loader_arg: str
|
180
|
+
|
181
|
+
|
182
|
+
def _get_special_chain_info_or_none(chain):
    """Return a ``_SpecialChainInfo`` for the first registered special class ``chain`` is an
    instance of, or ``None`` when it matches none of them."""
    class_to_arg = _get_map_of_special_chain_class_to_loader_arg()
    for chain_cls in class_to_arg:
        if isinstance(chain, chain_cls):
            return _SpecialChainInfo(loader_arg=class_to_arg[chain_cls])
    return None
|
189
|
+
|
190
|
+
|
191
|
+
@lru_cache
def _get_map_of_special_chain_class_to_loader_arg():
    """Map each "special" chain class to the keyword argument its loader_fn result fills.

    Classes are resolved from dotted paths. A path that fails to import is logged with a
    warning and skipped rather than raising. The result is cached.
    """
    import langchain

    from mlflow.langchain.retriever_chain import _RetrieverChain

    dotted_path_to_arg = {
        "langchain.chains.RetrievalQA": "retriever",
        "langchain.chains.APIChain": "requests_wrapper",
        "langchain.chains.HypotheticalDocumentEmbedder": "embeddings",
    }
    # NB: SQLDatabaseChain was migrated to langchain_experimental beginning with version 0.0.247
    is_legacy_langchain = version.parse(langchain.__version__) <= version.parse("0.0.246")
    if is_legacy_langchain:
        dotted_path_to_arg["langchain.chains.SQLDatabaseChain"] = "database"
    elif find_spec("langchain_experimental"):
        # Register the experimental chain only when langchain_experimental is installed.
        dotted_path_to_arg["langchain_experimental.sql.SQLDatabaseChain"] = "database"

    resolved = {_RetrieverChain: "retriever"}
    for dotted_path, arg_name in dotted_path_to_arg.items():
        try:
            resolved[_get_class_from_string(dotted_path)] = arg_name
        except Exception:
            logger.warning(
                "Unexpected import failure for class '%s'. Please file an issue at"
                " https://github.com/mlflow/mlflow/issues/.",
                dotted_path,
                exc_info=True,
            )

    return resolved
|
226
|
+
|
227
|
+
|
228
|
+
@lru_cache
|
229
|
+
def _get_supported_llms():
|
230
|
+
supported_llms = set()
|
231
|
+
|
232
|
+
def try_adding_llm(module, class_name):
|
233
|
+
if cls := getattr(module, class_name, None):
|
234
|
+
supported_llms.add(cls)
|
235
|
+
|
236
|
+
def safe_import_and_add(module_name, class_name):
|
237
|
+
"""Add conditional support for `partner` and `community` APIs in langchain"""
|
238
|
+
try:
|
239
|
+
module = importlib.import_module(module_name)
|
240
|
+
try_adding_llm(module, class_name)
|
241
|
+
except ImportError:
|
242
|
+
pass
|
243
|
+
|
244
|
+
safe_import_and_add("langchain.llms.openai", "OpenAI")
|
245
|
+
# HuggingFacePipeline is moved to langchain_huggingface since langchain 0.2.0
|
246
|
+
safe_import_and_add("langchain.llms", "HuggingFacePipeline")
|
247
|
+
safe_import_and_add("langchain.langchain_huggingface", "HuggingFacePipeline")
|
248
|
+
safe_import_and_add("langchain_openai", "OpenAI")
|
249
|
+
safe_import_and_add("langchain_databricks", "ChatDatabricks")
|
250
|
+
safe_import_and_add("databricks_langchain", "ChatDatabricks")
|
251
|
+
|
252
|
+
for llm_name in ["Databricks", "Mlflow"]:
|
253
|
+
safe_import_and_add("langchain.llms", llm_name)
|
254
|
+
|
255
|
+
for chat_model_name in [
|
256
|
+
"ChatDatabricks",
|
257
|
+
"ChatMlflow",
|
258
|
+
"ChatOpenAI",
|
259
|
+
"AzureChatOpenAI",
|
260
|
+
]:
|
261
|
+
safe_import_and_add("langchain.chat_models", chat_model_name)
|
262
|
+
|
263
|
+
return supported_llms
|
264
|
+
|
265
|
+
|
266
|
+
def _agent_executor_contains_unsupported_llm(lc_model, _SUPPORTED_LLMS):
    """Return True when ``lc_model`` is an AgentExecutor whose agent's ``llm_chain.llm`` is not
    an instance of any class in ``_SUPPORTED_LLMS``; False otherwise."""
    import langchain.agents.agent

    if not isinstance(lc_model, langchain.agents.agent.AgentExecutor):
        return False
    # 'RunnableMultiActionAgent' object has no attribute 'llm_chain'
    if not hasattr(lc_model.agent, "llm_chain"):
        return False
    return not any(
        isinstance(lc_model.agent.llm_chain.llm, llm_cls) for llm_cls in _SUPPORTED_LLMS
    )
|
278
|
+
|
279
|
+
|
280
|
+
def _require_loader_fn(lc_model, loader_fn, returned_object):
    """Check that a ``loader_fn`` was supplied and is a plain Python function.

    Args:
        lc_model: The model being validated; only its type name is used in the message.
        loader_fn: The user-supplied loader callable.
        returned_object: Description of what ``loader_fn`` must return, used in the message.

    Raises:
        MlflowException: If ``loader_fn`` is None or is not a ``types.FunctionType``.
    """
    if loader_fn is None:
        raise mlflow.MlflowException.invalid_parameter_value(
            f"For {type(lc_model).__name__} models, a `loader_fn` must be provided."
        )
    if not isinstance(loader_fn, types.FunctionType):
        raise mlflow.MlflowException.invalid_parameter_value(
            f"The `loader_fn` must be a function that returns a {returned_object}."
        )


# temp_dir is only required when lc_model could be a file path
def _validate_and_prepare_lc_model_or_path(lc_model, loader_fn, temp_dir=None):
    """Validate ``lc_model`` and return the object that should actually be saved.

    Args:
        lc_model: A supported LangChain object, or a str path to a model-code file.
        loader_fn: Function required for "special" chains and retrievers.
        temp_dir: Passed through to ``_validate_and_get_model_code_path`` for str paths.

    Returns:
        The validated code path for str input. For a retriever, it is wrapped in a
        ``_RetrieverChain``. Otherwise ``lc_model`` is returned unchanged.

    Raises:
        MlflowException: For unsupported model types or a missing/invalid ``loader_fn``.
    """
    import langchain.agents.agent  # noqa: F401
    import langchain.chains.base  # noqa: F401
    import langchain.chains.llm
    import langchain.schema

    # lc_model is a file path
    if isinstance(lc_model, str):
        return _validate_and_get_model_code_path(lc_model, temp_dir)

    if not isinstance(lc_model, supported_lc_types()):
        raise mlflow.MlflowException.invalid_parameter_value(
            get_unsupported_model_message(type(lc_model).__name__)
        )

    _SUPPORTED_LLMS = _get_supported_llms()
    if isinstance(lc_model, langchain.chains.llm.LLMChain) and not any(
        isinstance(lc_model.llm, supported_llm) for supported_llm in _SUPPORTED_LLMS
    ):
        logger.warning(
            _UNSUPPORTED_LLM_WARNING_MESSAGE,
            type(lc_model.llm).__name__,
        )

    if _agent_executor_contains_unsupported_llm(lc_model, _SUPPORTED_LLMS):
        logger.warning(
            _UNSUPPORTED_LLM_WARNING_MESSAGE,
            type(lc_model.agent.llm_chain.llm).__name__,
        )

    if special_chain_info := _get_special_chain_info_or_none(lc_model):
        _require_loader_fn(lc_model, loader_fn, special_chain_info.loader_arg)

    # If lc_model is a retriever, wrap it in a _RetrieverChain
    if isinstance(lc_model, langchain.schema.BaseRetriever):
        from mlflow.langchain.retriever_chain import _RetrieverChain

        _require_loader_fn(lc_model, loader_fn, "retriever")
        lc_model = _RetrieverChain(retriever=lc_model)

    return lc_model
|
340
|
+
|
341
|
+
|
342
|
+
def _save_agent_executor_lcs(model, path, model_data_path, model_data_kwargs):
    """Write an AgentExecutor's llm chain, agent, tools and primitive fields under ``path``.

    Records each written file name in ``model_data_kwargs`` (mutated in place).
    """
    agent = model.agent
    if agent and getattr(agent, "llm_chain", None):
        agent.llm_chain.save(model_data_path)

    if agent:
        model.save_agent(os.path.join(path, _AGENT_DATA_FILE_NAME))
        model_data_kwargs[_AGENT_DATA_KEY] = _AGENT_DATA_FILE_NAME

    if not model.tools:
        raise mlflow.MlflowException.invalid_parameter_value(
            "For initializing the AgentExecutor, tools must be provided."
        )
    tools_data_path = os.path.join(path, _TOOLS_DATA_FILE_NAME)
    try:
        with open(tools_data_path, "wb") as tools_file:
            cloudpickle.dump(model.tools, tools_file)
    except Exception as e:
        raise mlflow.MlflowException(
            "Error when attempting to pickle the AgentExecutor tools. "
            "This model likely does not support serialization."
        ) from e
    model_data_kwargs[_TOOLS_DATA_KEY] = _TOOLS_DATA_FILE_NAME

    # Remaining instance attributes are stored as JSON and fed back to initialize_agent.
    excluded_keys = {"llm_chain", "agent", "tools", "callback_manager"}
    primitives = {k: v for k, v in model.__dict__.items() if k not in excluded_keys}
    with open(os.path.join(path, _AGENT_PRIMITIVES_FILE_NAME), "w") as config_file:
        json.dump(primitives, config_file, indent=4)
    model_data_kwargs[_AGENT_PRIMITIVES_DATA_KEY] = _AGENT_PRIMITIVES_FILE_NAME


def _save_special_chain_lcs(
    model, path, model_data_path, model_data_kwargs, special_chain_info, loader_fn, persist_dir
):
    """Write a "special" chain together with its pickled ``loader_fn`` and optional persist_dir.

    Records each written file name in ``model_data_kwargs`` (mutated in place).
    """
    # Save loader_fn by pickling
    with open(os.path.join(path, _LOADER_FN_FILE_NAME), "wb") as loader_file:
        cloudpickle.dump(loader_fn, loader_file)
    model_data_kwargs[_LOADER_FN_KEY] = _LOADER_FN_FILE_NAME
    model_data_kwargs[_LOADER_ARG_KEY] = special_chain_info.loader_arg

    if persist_dir is not None:
        if not os.path.exists(persist_dir):
            raise mlflow.MlflowException.invalid_parameter_value(
                "The directory provided for persist_dir does not exist."
            )
        # Save persist_dir by copying into subdir _PERSIST_DIR_NAME
        shutil.copytree(persist_dir, os.path.join(path, _PERSIST_DIR_NAME))
        model_data_kwargs[_PERSIST_DIR_KEY] = _PERSIST_DIR_NAME

    model.save(model_data_path)


def _save_base_lcs(model, path, loader_fn=None, persist_dir=None):
    """Serialize a "base" LangChain object into directory ``path``.

    Returns:
        A dict of keys describing the written artifacts, later read back by the loader.

    Raises:
        MlflowException: For unsupported model types, an AgentExecutor without tools,
            unpicklable tools, or a nonexistent ``persist_dir``.
    """
    from langchain.agents.agent import AgentExecutor
    from langchain.chains.base import Chain
    from langchain.chains.llm import LLMChain
    from langchain.chat_models.base import BaseChatModel

    model_data_path = os.path.join(path, _MODEL_DATA_YAML_FILE_NAME)
    model_data_kwargs = {
        _MODEL_DATA_KEY: _MODEL_DATA_YAML_FILE_NAME,
        _MODEL_LOAD_KEY: _BASE_LOAD_KEY,
    }

    if isinstance(model, (LLMChain, BaseChatModel)):
        model.save(model_data_path)
    elif isinstance(model, AgentExecutor):
        _save_agent_executor_lcs(model, path, model_data_path, model_data_kwargs)
    elif special_chain_info := _get_special_chain_info_or_none(model):
        _save_special_chain_lcs(
            model,
            path,
            model_data_path,
            model_data_kwargs,
            special_chain_info,
            loader_fn,
            persist_dir,
        )
    elif isinstance(model, Chain):
        logger.warning(get_unsupported_model_message(type(model).__name__))
        model.save(model_data_path)
    else:
        raise mlflow.MlflowException.invalid_parameter_value(
            get_unsupported_model_message(type(model).__name__)
        )

    return model_data_kwargs
|
420
|
+
|
421
|
+
|
422
|
+
def _load_from_pickle(path):
    """Deserialize and return the object stored at ``path`` with cloudpickle.

    SECURITY: unpickling can execute arbitrary code; only call this on artifacts from a
    trusted source.
    """
    with open(path, "rb") as pickle_file:
        loaded = cloudpickle.load(pickle_file)
    return loaded
|
425
|
+
|
426
|
+
|
427
|
+
def _load_from_json(path):
|
428
|
+
with open(path) as f:
|
429
|
+
return json.load(f)
|
430
|
+
|
431
|
+
|
432
|
+
def _load_from_yaml(path):
    """Parse and return the YAML document at ``path`` using the safe loader."""
    with open(path) as yaml_file:
        parsed = yaml.safe_load(yaml_file)
    return parsed
|
435
|
+
|
436
|
+
|
437
|
+
def _get_path_by_key(root_path, key, conf):
|
438
|
+
key_path = conf.get(key)
|
439
|
+
return os.path.join(root_path, key_path) if key_path else None
|
440
|
+
|
441
|
+
|
442
|
+
def _patch_loader(loader_func: Callable[..., Any]) -> Callable[..., Any]:
    """
    Patch LangChain loader function like load_chain() to handle the breaking change introduced in
    LangChain 0.1.12.

    Since langchain-community 0.0.27, loading a module that relies on the pickle deserialization
    requires the `allow_dangerous_deserialization` flag to be set to True, for security reasons.
    However, this flag could not be specified via the LangChain's loading API like load_chain(),
    load_llm(), until LangChain 0.1.14. As a result, such module cannot be loaded with MLflow
    with earlier version of LangChain and we have to tell the user to upgrade LangChain to 0.1.14
    or above.

    Args:
        loader_func: The LangChain loader function to be patched e.g. load_chain().

    Returns:
        The patched loader function, or ``loader_func`` itself when pickle deserialization is
        not restricted by the installed langchain-community.
    """
    if not IS_PICKLE_SERIALIZATION_RESTRICTED:
        return loader_func

    import langchain

    if Version(langchain.__version__) >= Version("0.1.14"):
        # For LangChain 0.1.14 and above, we can pass `allow_dangerous_deserialization` flag
        # via the loader APIs. Since the model is serialized by the user (or someone who has
        # access to the tracking server), it is safe to set this flag to True.
        @functools.wraps(loader_func)
        def patched_loader(*args, **kwargs):
            # setdefault: honor an explicit caller value instead of failing with
            # "got multiple values for keyword argument".
            kwargs.setdefault("allow_dangerous_deserialization", True)
            return loader_func(*args, **kwargs)

    else:

        @functools.wraps(loader_func)
        def patched_loader(*args, **kwargs):
            try:
                return loader_func(*args, **kwargs)
            except ValueError as e:
                if "This code relies on the pickle module" in str(e):
                    raise MlflowException(
                        "Since langchain-community 0.0.27, loading a module that relies on "
                        "the pickle deserialization requires the `allow_dangerous_deserialization` "
                        "flag to be set to True when loading. However, this flag is not supported "
                        "by the installed version of LangChain. Please upgrade LangChain to 0.1.14 "
                        "or above by running `pip install langchain>=0.1.14`.",
                        error_code=INTERNAL_ERROR,
                    ) from e
                else:
                    raise

    return patched_loader
|
490
|
+
|
491
|
+
|
492
|
+
def _load_base_lcs(
    local_model_path,
    conf,
):
    """Rebuild a "base" LangChain object saved under ``local_model_path``.

    ``conf`` holds the keys that were written at save time (file names relative to
    ``local_model_path``, the optional loader argument name and model type).

    Returns:
        The reconstructed chain, retriever or agent.

    Raises:
        MlflowException: If the loader_fn file or the tools file required for the selected
            reconstruction path is missing.
    """
    lc_model_path = os.path.join(
        local_model_path, conf.get(_MODEL_DATA_KEY, _MODEL_DATA_YAML_FILE_NAME)
    )

    agent_path = _get_path_by_key(local_model_path, _AGENT_DATA_KEY, conf)
    tools_path = _get_path_by_key(local_model_path, _TOOLS_DATA_KEY, conf)
    agent_primitive_path = _get_path_by_key(local_model_path, _AGENT_PRIMITIVES_DATA_KEY, conf)
    loader_fn_path = _get_path_by_key(local_model_path, _LOADER_FN_KEY, conf)
    persist_dir = _get_path_by_key(local_model_path, _PERSIST_DIR_KEY, conf)

    model_type = conf.get(_MODEL_TYPE_KEY)
    loader_arg = conf.get(_LOADER_ARG_KEY)

    from langchain.chains.loading import load_chain

    from mlflow.langchain.retriever_chain import _RetrieverChain

    if loader_arg is not None:
        if loader_fn_path is None:
            raise mlflow.MlflowException.invalid_parameter_value(
                "Missing file for loader_fn which is required to build the model."
            )
        loader_fn = _load_from_pickle(loader_fn_path)
        kwargs = {loader_arg: loader_fn(persist_dir)}
        if model_type == _RetrieverChain.__name__:
            model = _RetrieverChain.load(lc_model_path, **kwargs).retriever
        else:
            model = _patch_loader(load_chain)(lc_model_path, **kwargs)
    elif agent_path is None and tools_path is None:
        model = _patch_loader(load_chain)(lc_model_path)
    else:
        from langchain.agents import initialize_agent

        llm = _patch_loader(load_chain)(lc_model_path)
        kwargs = {}

        # Either path may be None when its key is absent from conf; os.path.exists(None)
        # raises TypeError, so treat a missing key the same as a missing file.
        if tools_path is not None and os.path.exists(tools_path):
            tools = _load_from_pickle(tools_path)
        else:
            raise mlflow.MlflowException(
                "Missing file for tools which is required to build the AgentExecutor object."
            )

        if agent_primitive_path is not None and os.path.exists(agent_primitive_path):
            kwargs = _load_from_json(agent_primitive_path)

        model = initialize_agent(tools=tools, llm=llm, agent_path=agent_path, **kwargs)
    return model
|
545
|
+
|
546
|
+
|
547
|
+
def patch_langchain_type_to_cls_dict(func):
    """Decorator that temporarily extends LangChain's ``get_type_to_cls_dict`` registries.

    While ``func`` runs, every importable module in the patch list gets its
    ``get_type_to_cls_dict`` replaced by a version that also maps "openai-chat",
    "azure-openai-chat" and "chat-databricks". The originals are restored afterwards. A
    ValueError matching the unsupported chat-model message is re-raised as an
    MlflowException with upgrade advice.
    """

    def _load_chat_openai():
        from langchain_community.chat_models import ChatOpenAI

        return ChatOpenAI

    def _load_azure_chat_openai():
        from langchain_community.chat_models import AzureChatOpenAI

        return AzureChatOpenAI

    def _load_chat_databricks():
        from databricks_langchain import ChatDatabricks

        return ChatDatabricks

    def _extend_registry(original_getter):
        def _extended():
            return {
                **original_getter(),
                "openai-chat": _load_chat_openai,
                "azure-openai-chat": _load_azure_chat_openai,
                "chat-databricks": _load_chat_databricks,
            }

        return _extended

    target_module_names = (
        "langchain_databricks",
        "langchain.llms",
        "langchain_community.llms.loading",
    )

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # (module, original getter) pairs, recorded so the patch can be undone.
        patched = []
        for module_name in target_module_names:
            try:
                target = importlib.import_module(module_name)
                original_getter = target.get_type_to_cls_dict
            except (ImportError, AttributeError):
                continue
            patched.append((target, original_getter))
            target.get_type_to_cls_dict = _extend_registry(original_getter)

        try:
            return func(*args, **kwargs)
        except ValueError as e:
            match = _CHAT_MODELS_ERROR_MSG.search(str(e))
            if not match:
                raise
            model_name = "ChatOpenAI" if match.group(1) == "openai-chat" else "AzureChatOpenAI"
            raise mlflow.MlflowException(
                f"Loading {model_name} chat model is not supported in MLflow with the "
                "current version of LangChain. Please upgrade LangChain to 0.0.307 or above "
                "by running `pip install langchain>=0.0.307`."
            ) from e
        finally:
            # Clean up the patch
            for target, original_getter in patched:
                target.get_type_to_cls_dict = original_getter

    return wrapper
|
609
|
+
|
610
|
+
|
611
|
+
def register_pydantic_serializer():
    """
    Register a cloudpickle reducer for pydantic v1 ``ModelField`` objects.

    Pydantic's Cython validators are not serializable, so each field is rebuilt from its
    plain constructor arguments instead. Does nothing on pydantic >= 2.0.0.
    https://github.com/cloudpipe/cloudpickle/issues/408
    """
    import pydantic

    if Version(pydantic.__version__) >= Version("2.0.0"):
        return

    import pydantic.fields

    def custom_serializer(obj):
        field_kwargs = {}
        field_kwargs["name"] = obj.name
        # outer_type_ is the original type for ModelFields, while type_ can be updated later
        # with the nested type like int for List[int].
        field_kwargs["type_"] = obj.outer_type_
        field_kwargs["class_validators"] = obj.class_validators
        field_kwargs["model_config"] = obj.model_config
        field_kwargs["default"] = obj.default
        field_kwargs["default_factory"] = obj.default_factory
        field_kwargs["required"] = obj.required
        field_kwargs["final"] = obj.final
        field_kwargs["alias"] = obj.alias
        field_kwargs["field_info"] = obj.field_info
        return field_kwargs

    def custom_deserializer(kwargs):
        return pydantic.fields.ModelField(**kwargs)

    def _reduce_model_field(obj):
        return custom_deserializer, (custom_serializer(obj),)

    warnings.warn(
        "Using custom serializer to pickle pydantic.fields.ModelField classes, "
        "this might miss some fields and validators. To avoid this, "
        "please upgrade pydantic to v2 using `pip install pydantic -U` with "
        "langchain 0.0.267 and above."
    )
    cloudpickle.CloudPickler.dispatch[pydantic.fields.ModelField] = _reduce_model_field
|
654
|
+
|
655
|
+
|
656
|
+
def unregister_pydantic_serializer():
    """Remove the cloudpickle dispatch entry for pydantic v1 ``ModelField``.

    This is the key ``register_pydantic_serializer`` installs. It does nothing on
    pydantic >= 2.0.0, or when no entry was registered.
    """
    import pydantic

    if Version(pydantic.__version__) >= Version("2.0.0"):
        return

    # Import the submodule explicitly instead of relying on ``import pydantic`` having loaded
    # ``pydantic.fields`` as a side effect.
    import pydantic.fields

    cloudpickle.CloudPickler.dispatch.pop(pydantic.fields.ModelField, None)
|
663
|
+
|
664
|
+
|
665
|
+
@contextlib.contextmanager
def register_pydantic_v1_serializer_cm():
    """Register the pydantic v1 ``ModelField`` serializer for the duration of a ``with`` block.

    ``unregister_pydantic_serializer`` always runs on exit. This includes the cases where
    registration itself or the body of the ``with`` statement raises.
    """
    with contextlib.ExitStack() as cleanup:
        cleanup.callback(unregister_pydantic_serializer)
        register_pydantic_serializer()
        yield
|
@@ -0,0 +1,36 @@
|
|
1
|
+
import inspect
|
2
|
+
|
3
|
+
from packaging.version import Version
|
4
|
+
|
5
|
+
|
6
|
+
def convert_to_serializable(response):
    """
    Convert the response to a JSON serializable format.

    LangChain response objects often contain Pydantic objects, which cause a serialization
    error when the model is served behind a REST endpoint.

    Conversion rules:
      - Pydantic models are dumped to dicts.
      - Dicts, lists and tuples are converted recursively.
      - Generators are wrapped in a generator that converts each chunk lazily.
      - Anything else is returned unchanged.
    """
    import langchain

    # LangChain >= 0.3.0 uses Pydantic 2.x while < 0.3.0 is based on Pydantic 1.x.
    # Resolve the base class once per top-level call rather than re-importing LangChain and
    # re-parsing its version for every nested element.
    if Version(langchain.__version__) >= Version("0.3.0"):
        from pydantic import BaseModel

        def _dump(model):
            return model.model_dump()

    else:
        from langchain_core.pydantic_v1 import BaseModel

        def _dump(model):
            return model.dict()

    def _convert(obj):
        if isinstance(obj, BaseModel):
            return _dump(obj)
        if inspect.isgenerator(obj):
            return (_convert(chunk) for chunk in obj)
        if isinstance(obj, dict):
            return {k: _convert(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [_convert(v) for v in obj]
        if isinstance(obj, tuple):
            return tuple(_convert(v) for v in obj)
        return obj

    return _convert(response)
|
File without changes
|