snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -0
- snowflake/cortex/_classify_text.py +2 -2
- snowflake/cortex/_embed_text_1024.py +37 -0
- snowflake/cortex/_embed_text_768.py +37 -0
- snowflake/cortex/_extract_answer.py +2 -2
- snowflake/cortex/_sentiment.py +2 -2
- snowflake/cortex/_summarize.py +2 -2
- snowflake/cortex/_translate.py +2 -2
- snowflake/cortex/_util.py +4 -4
- snowflake/ml/_internal/env_utils.py +5 -5
- snowflake/ml/_internal/exceptions/error_codes.py +2 -0
- snowflake/ml/_internal/telemetry.py +142 -20
- snowflake/ml/_internal/utils/db_utils.py +50 -0
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/service_logger.py +63 -0
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +26 -2
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
- snowflake/ml/data/data_connector.py +33 -7
- snowflake/ml/data/ingestor_utils.py +20 -10
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/feature_store/access_manager.py +3 -3
- snowflake/ml/feature_store/feature_store.py +60 -19
- snowflake/ml/feature_store/feature_view.py +84 -30
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/fileset/stage_fs.py +2 -1
- snowflake/ml/lineage/lineage_node.py +7 -2
- snowflake/ml/model/__init__.py +1 -2
- snowflake/ml/model/_client/model/model_version_impl.py +96 -12
- snowflake/ml/model/_client/ops/model_ops.py +124 -6
- snowflake/ml/model/_client/ops/service_ops.py +309 -9
- snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
- snowflake/ml/model/_client/sql/_base.py +5 -0
- snowflake/ml/model/_client/sql/model.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +9 -5
- snowflake/ml/model/_client/sql/service.py +121 -20
- snowflake/ml/model/_model_composer/model_composer.py +11 -39
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
- snowflake/ml/model/_packager/model_env/model_env.py +4 -38
- snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
- snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
- snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
- snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
- snowflake/ml/model/_packager/model_packager.py +14 -8
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/type_hints.py +12 -145
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
- snowflake/ml/modeling/cluster/birch.py +61 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
- snowflake/ml/modeling/cluster/dbscan.py +61 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
- snowflake/ml/modeling/cluster/k_means.py +61 -21
- snowflake/ml/modeling/cluster/mean_shift.py +61 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
- snowflake/ml/modeling/cluster/optics.py +61 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
- snowflake/ml/modeling/compose/column_transformer.py +61 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
- snowflake/ml/modeling/covariance/oas.py +61 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
- snowflake/ml/modeling/decomposition/pca.py +61 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
- snowflake/ml/modeling/impute/knn_imputer.py +61 -21
- snowflake/ml/modeling/impute/missing_indicator.py +61 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/lars.py +61 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/perceptron.py +61 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ridge.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
- snowflake/ml/modeling/manifold/isomap.py +61 -21
- snowflake/ml/modeling/manifold/mds.py +61 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
- snowflake/ml/modeling/manifold/tsne.py +61 -21
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
- snowflake/ml/modeling/metrics/ranking.py +0 -3
- snowflake/ml/modeling/metrics/regression.py +0 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +1 -13
- snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
- snowflake/ml/modeling/svm/linear_svc.py +61 -21
- snowflake/ml/modeling/svm/linear_svr.py +61 -21
- snowflake/ml/modeling/svm/nu_svc.py +61 -21
- snowflake/ml/modeling/svm/nu_svr.py +61 -21
- snowflake/ml/modeling/svm/svc.py +61 -21
- snowflake/ml/modeling/svm/svr.py +61 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
- snowflake/ml/monitoring/_client/model_monitor.py +126 -0
- snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
- snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
- snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
- snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
- snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
- snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
- snowflake/ml/monitoring/entities/output_score_type.py +90 -0
- snowflake/ml/registry/_manager/model_manager.py +4 -0
- snowflake/ml/registry/registry.py +166 -8
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
- snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
- snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
- snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
- snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
- snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
- snowflake/ml/_internal/utils/session_token_manager.py +0 -46
- snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
- snowflake/ml/_internal/utils/uri.py +0 -77
- snowflake/ml/data/torch_dataset.py +0 -33
- snowflake/ml/model/_api.py +0 -568
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
- snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
- snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
- snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
- snowflake/ml/model/deploy_platforms.py +0 -6
- snowflake/ml/model/models/llm.py +0 -104
- snowflake/ml/monitoring/monitor.py +0 -203
- snowflake/ml/registry/_initial_schema.py +0 -142
- snowflake/ml/registry/_schema.py +0 -82
- snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
- snowflake/ml/registry/_schema_version_manager.py +0 -163
- snowflake/ml/registry/model_registry.py +0 -2048
- snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,15 @@
|
|
1
|
+
import enum
|
2
|
+
import json
|
1
3
|
import textwrap
|
2
4
|
from typing import Any, Dict, List, Optional, Tuple
|
3
5
|
|
6
|
+
from packaging import version
|
7
|
+
|
8
|
+
from snowflake import snowpark
|
4
9
|
from snowflake.ml._internal.utils import (
|
5
10
|
identifier,
|
6
11
|
query_result_checker,
|
12
|
+
snowflake_env,
|
7
13
|
sql_identifier,
|
8
14
|
)
|
9
15
|
from snowflake.ml.model._client.sql import _base
|
@@ -11,6 +17,17 @@ from snowflake.snowpark import dataframe, functions as F, types as spt
|
|
11
17
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
12
18
|
|
13
19
|
|
20
|
+
class ServiceStatus(enum.Enum):
|
21
|
+
UNKNOWN = "UNKNOWN" # status is unknown because we have not received enough data from K8s yet.
|
22
|
+
PENDING = "PENDING" # resource set is being created, can't be used yet
|
23
|
+
READY = "READY" # resource set has been deployed.
|
24
|
+
DELETING = "DELETING" # resource set is being deleted
|
25
|
+
FAILED = "FAILED" # resource set has failed and cannot be used anymore
|
26
|
+
DONE = "DONE" # resource set has finished running
|
27
|
+
NOT_FOUND = "NOT_FOUND" # not found or deleted
|
28
|
+
INTERNAL_ERROR = "INTERNAL_ERROR" # there was an internal service error.
|
29
|
+
|
30
|
+
|
14
31
|
class ServiceSQLClient(_base._BaseSQLClient):
|
15
32
|
def build_model_container(
|
16
33
|
self,
|
@@ -30,20 +47,21 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
30
47
|
) -> None:
|
31
48
|
actual_image_repo_database = image_repo_database_name or self._database_name
|
32
49
|
actual_image_repo_schema = image_repo_schema_name or self._schema_name
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
50
|
+
actual_model_database = database_name or self._database_name
|
51
|
+
actual_model_schema = schema_name or self._schema_name
|
52
|
+
fq_model_name = self.fully_qualified_object_name(actual_model_database, actual_model_schema, model_name)
|
53
|
+
fq_image_repo_name = identifier.get_schema_level_object_identifier(
|
54
|
+
actual_image_repo_database.identifier(),
|
55
|
+
actual_image_repo_schema.identifier(),
|
56
|
+
image_repo_name.identifier(),
|
40
57
|
)
|
41
|
-
|
58
|
+
is_gpu_str = "TRUE" if gpu else "FALSE"
|
59
|
+
force_rebuild_str = "TRUE" if force_rebuild else "FALSE"
|
42
60
|
query_result_checker.SqlResultValidator(
|
43
61
|
self._session,
|
44
62
|
(
|
45
63
|
f"CALL SYSTEM$BUILD_MODEL_CONTAINER('{fq_model_name}', '{version_name}', '{compute_pool_name}',"
|
46
|
-
f" '{fq_image_repo_name}', '{
|
64
|
+
f" '{fq_image_repo_name}', '{is_gpu_str}', '{force_rebuild_str}', '', '{external_access_integration}')"
|
47
65
|
),
|
48
66
|
statement_params=statement_params,
|
49
67
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
@@ -54,12 +72,12 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
54
72
|
stage_path: str,
|
55
73
|
model_deployment_spec_file_rel_path: str,
|
56
74
|
statement_params: Optional[Dict[str, Any]] = None,
|
57
|
-
) ->
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
75
|
+
) -> Tuple[str, snowpark.AsyncJob]:
|
76
|
+
async_job = self._session.sql(
|
77
|
+
f"CALL SYSTEM$DEPLOY_MODEL('@{stage_path}/{model_deployment_spec_file_rel_path}')"
|
78
|
+
).collect(block=False, statement_params=statement_params)
|
79
|
+
assert isinstance(async_job, snowpark.AsyncJob)
|
80
|
+
return async_job.query_id, async_job
|
63
81
|
|
64
82
|
def invoke_function_method(
|
65
83
|
self,
|
@@ -74,12 +92,13 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
74
92
|
statement_params: Optional[Dict[str, Any]] = None,
|
75
93
|
) -> dataframe.DataFrame:
|
76
94
|
with_statements = []
|
95
|
+
actual_database_name = database_name or self._database_name
|
96
|
+
actual_schema_name = schema_name or self._schema_name
|
97
|
+
|
77
98
|
if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
|
78
|
-
INTERMEDIATE_TABLE_NAME = "SNOWPARK_ML_MODEL_INFERENCE_INPUT"
|
99
|
+
INTERMEDIATE_TABLE_NAME = ServiceSQLClient.get_tmp_name_with_prefix("SNOWPARK_ML_MODEL_INFERENCE_INPUT")
|
79
100
|
with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})")
|
80
101
|
else:
|
81
|
-
actual_database_name = database_name or self._database_name
|
82
|
-
actual_schema_name = schema_name or self._schema_name
|
83
102
|
tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
|
84
103
|
INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier(
|
85
104
|
actual_database_name.identifier(),
|
@@ -93,7 +112,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
93
112
|
statement_params=statement_params,
|
94
113
|
)
|
95
114
|
|
96
|
-
INTERMEDIATE_OBJ_NAME = "TMP_RESULT"
|
115
|
+
INTERMEDIATE_OBJ_NAME = ServiceSQLClient.get_tmp_name_with_prefix("TMP_RESULT")
|
97
116
|
|
98
117
|
with_sql = f"WITH {','.join(with_statements)}" if with_statements else ""
|
99
118
|
args_sql_list = []
|
@@ -101,10 +120,26 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
101
120
|
args_sql_list.append(input_arg_value)
|
102
121
|
args_sql = ", ".join(args_sql_list)
|
103
122
|
|
123
|
+
if snowflake_env.get_current_snowflake_version(
|
124
|
+
self._session, statement_params=statement_params
|
125
|
+
) >= version.parse("8.39.0"):
|
126
|
+
fully_qualified_service_name = self.fully_qualified_object_name(
|
127
|
+
actual_database_name, actual_schema_name, service_name
|
128
|
+
)
|
129
|
+
fully_qualified_function_name = f"{fully_qualified_service_name}!{method_name.identifier()}"
|
130
|
+
|
131
|
+
else:
|
132
|
+
function_name = identifier.concat_names([service_name.identifier(), "_", method_name.identifier()])
|
133
|
+
fully_qualified_function_name = identifier.get_schema_level_object_identifier(
|
134
|
+
actual_database_name.identifier(),
|
135
|
+
actual_schema_name.identifier(),
|
136
|
+
function_name,
|
137
|
+
)
|
138
|
+
|
104
139
|
sql = textwrap.dedent(
|
105
140
|
f"""{with_sql}
|
106
141
|
SELECT *,
|
107
|
-
{
|
142
|
+
{fully_qualified_function_name}({args_sql}) AS {INTERMEDIATE_OBJ_NAME}
|
108
143
|
FROM {INTERMEDIATE_TABLE_NAME}"""
|
109
144
|
)
|
110
145
|
|
@@ -127,3 +162,69 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
127
162
|
output_df._statement_params = statement_params # type: ignore[assignment]
|
128
163
|
|
129
164
|
return output_df
|
165
|
+
|
166
|
+
def get_service_logs(
|
167
|
+
self,
|
168
|
+
*,
|
169
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
170
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
171
|
+
service_name: sql_identifier.SqlIdentifier,
|
172
|
+
instance_id: str = "0",
|
173
|
+
container_name: str,
|
174
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
175
|
+
) -> str:
|
176
|
+
system_func = "SYSTEM$GET_SERVICE_LOGS"
|
177
|
+
rows = (
|
178
|
+
query_result_checker.SqlResultValidator(
|
179
|
+
self._session,
|
180
|
+
(
|
181
|
+
f"CALL {system_func}("
|
182
|
+
f"'{self.fully_qualified_object_name(database_name, schema_name, service_name)}', '{instance_id}', "
|
183
|
+
f"'{container_name}')"
|
184
|
+
),
|
185
|
+
statement_params=statement_params,
|
186
|
+
)
|
187
|
+
.has_dimensions(expected_rows=1, expected_cols=1)
|
188
|
+
.validate()
|
189
|
+
)
|
190
|
+
return str(rows[0][system_func])
|
191
|
+
|
192
|
+
def get_service_status(
|
193
|
+
self,
|
194
|
+
*,
|
195
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
196
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
197
|
+
service_name: sql_identifier.SqlIdentifier,
|
198
|
+
include_message: bool = False,
|
199
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
200
|
+
) -> Tuple[ServiceStatus, Optional[str]]:
|
201
|
+
system_func = "SYSTEM$GET_SERVICE_STATUS"
|
202
|
+
rows = (
|
203
|
+
query_result_checker.SqlResultValidator(
|
204
|
+
self._session,
|
205
|
+
f"CALL {system_func}('{self.fully_qualified_object_name(database_name, schema_name, service_name)}')",
|
206
|
+
statement_params=statement_params,
|
207
|
+
)
|
208
|
+
.has_dimensions(expected_rows=1, expected_cols=1)
|
209
|
+
.validate()
|
210
|
+
)
|
211
|
+
metadata = json.loads(rows[0][system_func])[0]
|
212
|
+
if metadata and metadata["status"]:
|
213
|
+
service_status = ServiceStatus(metadata["status"])
|
214
|
+
message = metadata["message"] if include_message else None
|
215
|
+
return service_status, message
|
216
|
+
return ServiceStatus.UNKNOWN, None
|
217
|
+
|
218
|
+
def drop_service(
|
219
|
+
self,
|
220
|
+
*,
|
221
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
222
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
223
|
+
service_name: sql_identifier.SqlIdentifier,
|
224
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
225
|
+
) -> None:
|
226
|
+
query_result_checker.SqlResultValidator(
|
227
|
+
self._session,
|
228
|
+
f"DROP SERVICE {self.fully_qualified_object_name(database_name, schema_name, service_name)}",
|
229
|
+
statement_params=statement_params,
|
230
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
@@ -1,14 +1,11 @@
|
|
1
|
-
import glob
|
2
1
|
import pathlib
|
3
2
|
import tempfile
|
4
3
|
import uuid
|
5
|
-
import zipfile
|
6
4
|
from types import ModuleType
|
7
5
|
from typing import Any, Dict, List, Optional
|
8
6
|
|
9
7
|
from absl import logging
|
10
8
|
from packaging import requirements
|
11
|
-
from typing_extensions import deprecated
|
12
9
|
|
13
10
|
from snowflake import snowpark
|
14
11
|
from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
|
@@ -92,6 +89,7 @@ class ModelComposer:
|
|
92
89
|
python_version: Optional[str] = None,
|
93
90
|
ext_modules: Optional[List[ModuleType]] = None,
|
94
91
|
code_paths: Optional[List[str]] = None,
|
92
|
+
task: model_types.Task = model_types.Task.UNKNOWN,
|
95
93
|
options: Optional[model_types.ModelSaveOption] = None,
|
96
94
|
) -> model_meta.ModelMetadata:
|
97
95
|
if not options:
|
@@ -120,24 +118,20 @@ class ModelComposer:
|
|
120
118
|
python_version=python_version,
|
121
119
|
ext_modules=ext_modules,
|
122
120
|
code_paths=code_paths,
|
121
|
+
task=task,
|
123
122
|
options=options,
|
124
123
|
)
|
125
124
|
assert self.packager.meta is not None
|
126
125
|
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
)
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
options=options,
|
137
|
-
data_sources=self._get_data_sources(model, sample_input_data),
|
138
|
-
)
|
139
|
-
else:
|
140
|
-
file_utils.make_archive(self.model_local_path, str(self._packager_workspace_path))
|
126
|
+
file_utils.copytree(
|
127
|
+
str(self._packager_workspace_path), str(self.workspace_path / ModelComposer.MODEL_DIR_REL_PATH)
|
128
|
+
)
|
129
|
+
self.manifest.save(
|
130
|
+
model_meta=self.packager.meta,
|
131
|
+
model_rel_path=pathlib.PurePosixPath(ModelComposer.MODEL_DIR_REL_PATH),
|
132
|
+
options=options,
|
133
|
+
data_sources=self._get_data_sources(model, sample_input_data),
|
134
|
+
)
|
141
135
|
|
142
136
|
file_utils.upload_directory_to_stage(
|
143
137
|
self.session,
|
@@ -147,28 +141,6 @@ class ModelComposer:
|
|
147
141
|
)
|
148
142
|
return model_metadata
|
149
143
|
|
150
|
-
@deprecated("Only used by PrPr model registry. Use static method version of load instead.")
|
151
|
-
def legacy_load(
|
152
|
-
self,
|
153
|
-
*,
|
154
|
-
meta_only: bool = False,
|
155
|
-
options: Optional[model_types.ModelLoadOption] = None,
|
156
|
-
) -> None:
|
157
|
-
file_utils.download_directory_from_stage(
|
158
|
-
self.session,
|
159
|
-
stage_path=self.stage_path,
|
160
|
-
local_path=self.workspace_path,
|
161
|
-
statement_params=self._statement_params,
|
162
|
-
)
|
163
|
-
|
164
|
-
# TODO (Server-side Model Rollout): Remove this section.
|
165
|
-
model_zip_path = pathlib.Path(glob.glob(str(self.workspace_path / "*.zip"))[0])
|
166
|
-
self.model_file_rel_path = str(model_zip_path.relative_to(self.workspace_path))
|
167
|
-
|
168
|
-
with zipfile.ZipFile(self.model_local_path, mode="r", compression=zipfile.ZIP_DEFLATED) as zf:
|
169
|
-
zf.extractall(path=self._packager_workspace_path)
|
170
|
-
self.packager.load(meta_only=meta_only, options=options)
|
171
|
-
|
172
144
|
@staticmethod
|
173
145
|
def load(
|
174
146
|
workspace_path: pathlib.Path,
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import collections
|
2
|
-
import
|
2
|
+
import logging
|
3
3
|
import pathlib
|
4
4
|
import warnings
|
5
5
|
from typing import List, Optional, cast
|
@@ -18,6 +18,9 @@ from snowflake.ml.model._packager.model_meta import (
|
|
18
18
|
model_meta as model_meta_api,
|
19
19
|
model_meta_schema,
|
20
20
|
)
|
21
|
+
from snowflake.ml.model._packager.model_runtime import model_runtime
|
22
|
+
|
23
|
+
logger = logging.getLogger(__name__)
|
21
24
|
|
22
25
|
|
23
26
|
class ModelManifest:
|
@@ -45,9 +48,30 @@ class ModelManifest:
|
|
45
48
|
if options is None:
|
46
49
|
options = {}
|
47
50
|
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
+
if "relax_version" not in options:
|
52
|
+
warnings.warn(
|
53
|
+
(
|
54
|
+
"`relax_version` is not set and therefore defaulted to True. Dependency version constraints relaxed"
|
55
|
+
" from ==x.y.z to >=x.y, <(x+1). To use specific dependency versions for compatibility, "
|
56
|
+
"reproducibility, etc., set `options={'relax_version': False}` when logging the model."
|
57
|
+
),
|
58
|
+
category=UserWarning,
|
59
|
+
stacklevel=2,
|
60
|
+
)
|
61
|
+
relax_version = options.get("relax_version", True)
|
62
|
+
|
63
|
+
runtime_to_use = model_runtime.ModelRuntime(
|
64
|
+
name=self._DEFAULT_RUNTIME_NAME,
|
65
|
+
env=model_meta.env,
|
66
|
+
imports=[str(model_rel_path) + "/"],
|
67
|
+
is_gpu=False,
|
68
|
+
is_warehouse=True,
|
69
|
+
)
|
70
|
+
if relax_version:
|
71
|
+
runtime_to_use.runtime_env.relax_version()
|
72
|
+
logger.info("Relaxing version constraints for dependencies in the model.")
|
73
|
+
logger.info(f"Conda dependencies: {runtime_to_use.runtime_env.conda_dependencies}")
|
74
|
+
logger.info(f"Pip requirements: {runtime_to_use.runtime_env.pip_requirements}")
|
51
75
|
runtime_dict = runtime_to_use.save(
|
52
76
|
self.workspace_path, default_channel_override=env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
|
53
77
|
)
|
@@ -78,13 +102,9 @@ class ModelManifest:
|
|
78
102
|
)
|
79
103
|
|
80
104
|
dependencies = model_manifest_schema.ModelRuntimeDependenciesDict(conda=runtime_dict["dependencies"]["conda"])
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
"be warehouse-compabible. The model may need to be run in SPCS.",
|
85
|
-
category=UserWarning,
|
86
|
-
stacklevel=1,
|
87
|
-
)
|
105
|
+
|
106
|
+
# We only want to include pip dependencies file if there are any pip requirements.
|
107
|
+
if len(model_meta.env.pip_requirements) > 0:
|
88
108
|
dependencies["pip"] = runtime_dict["dependencies"]["pip"]
|
89
109
|
|
90
110
|
manifest_dict = model_manifest_schema.ModelManifestDict(
|
@@ -21,7 +21,7 @@ _DEFAULT_PIP_REQUIREMENTS_FILENAME = "requirements.txt"
|
|
21
21
|
# The default CUDA version is chosen based on the driver availability in SPCS.
|
22
22
|
# If changing this version, we need also change the version of default PyTorch in HuggingFace pipeline handler to
|
23
23
|
# make sure they are compatible.
|
24
|
-
DEFAULT_CUDA_VERSION = "11.
|
24
|
+
DEFAULT_CUDA_VERSION = "11.8"
|
25
25
|
|
26
26
|
|
27
27
|
class ModelEnv:
|
@@ -199,50 +199,16 @@ class ModelEnv:
|
|
199
199
|
)
|
200
200
|
if xgboost_spec:
|
201
201
|
self.include_if_absent(
|
202
|
-
[
|
203
|
-
ModelDependency(
|
204
|
-
requirement=f"conda-forge::py-xgboost-gpu{xgboost_spec.specifier}", pip_name="xgboost"
|
205
|
-
)
|
206
|
-
],
|
202
|
+
[ModelDependency(requirement=f"py-xgboost-gpu{xgboost_spec.specifier}", pip_name="xgboost")],
|
207
203
|
check_local_version=False,
|
208
204
|
)
|
209
205
|
|
210
|
-
pytorch_spec = env_utils.find_dep_spec(
|
211
|
-
self._conda_dependencies,
|
212
|
-
self._pip_requirements,
|
213
|
-
conda_pkg_name="pytorch",
|
214
|
-
pip_pkg_name="torch",
|
215
|
-
remove_spec=True,
|
216
|
-
)
|
217
|
-
pytorch_cuda_spec = env_utils.find_dep_spec(
|
218
|
-
self._conda_dependencies,
|
219
|
-
self._pip_requirements,
|
220
|
-
conda_pkg_name="pytorch-cuda",
|
221
|
-
remove_spec=False,
|
222
|
-
)
|
223
|
-
if pytorch_cuda_spec and not pytorch_cuda_spec.specifier.contains(self.cuda_version):
|
224
|
-
raise ValueError(
|
225
|
-
"The Pytorch-CUDA requirement you specified in your conda dependencies or pip requirements is"
|
226
|
-
" conflicting with CUDA version required. Please do not specify Pytorch-CUDA dependency using conda"
|
227
|
-
" dependencies or pip requirements."
|
228
|
-
)
|
229
|
-
if pytorch_spec:
|
230
|
-
self.include_if_absent(
|
231
|
-
[ModelDependency(requirement=f"pytorch::pytorch{pytorch_spec.specifier}", pip_name="torch")],
|
232
|
-
check_local_version=False,
|
233
|
-
)
|
234
|
-
if not pytorch_cuda_spec:
|
235
|
-
self.include_if_absent(
|
236
|
-
[ModelDependency(requirement=f"pytorch::pytorch-cuda=={self.cuda_version}.*", pip_name="torch")],
|
237
|
-
check_local_version=False,
|
238
|
-
)
|
239
|
-
|
240
206
|
tf_spec = env_utils.find_dep_spec(
|
241
207
|
self._conda_dependencies, self._pip_requirements, conda_pkg_name="tensorflow", remove_spec=True
|
242
208
|
)
|
243
209
|
if tf_spec:
|
244
210
|
self.include_if_absent(
|
245
|
-
[ModelDependency(requirement=f"
|
211
|
+
[ModelDependency(requirement=f"tensorflow-gpu{tf_spec.specifier}", pip_name="tensorflow")],
|
246
212
|
check_local_version=False,
|
247
213
|
)
|
248
214
|
|
@@ -252,7 +218,7 @@ class ModelEnv:
|
|
252
218
|
if transformers_spec:
|
253
219
|
self.include_if_absent(
|
254
220
|
[
|
255
|
-
ModelDependency(requirement="
|
221
|
+
ModelDependency(requirement="accelerate>=0.22.0", pip_name="accelerate"),
|
256
222
|
ModelDependency(requirement="scipy>=1.9", pip_name="scipy"),
|
257
223
|
],
|
258
224
|
check_local_version=False,
|
@@ -1,20 +1,54 @@
|
|
1
1
|
import json
|
2
|
-
|
2
|
+
import os
|
3
|
+
import warnings
|
4
|
+
from typing import Any, Callable, Iterable, List, Optional, Sequence, cast
|
3
5
|
|
4
6
|
import numpy as np
|
5
7
|
import numpy.typing as npt
|
6
8
|
import pandas as pd
|
9
|
+
from absl import logging
|
7
10
|
|
11
|
+
import snowflake.snowpark.dataframe as sp_df
|
12
|
+
from snowflake.ml._internal.utils import identifier
|
8
13
|
from snowflake.ml.model import model_signature, type_hints as model_types
|
9
14
|
from snowflake.ml.model._packager.model_meta import model_meta
|
10
|
-
from snowflake.ml.model._signatures import
|
15
|
+
from snowflake.ml.model._signatures import (
|
16
|
+
core,
|
17
|
+
snowpark_handler,
|
18
|
+
utils as model_signature_utils,
|
19
|
+
)
|
11
20
|
from snowflake.snowpark import DataFrame as SnowparkDataFrame
|
12
21
|
|
22
|
+
EXPLAIN_BACKGROUND_DATA_ROWS_COUNT_LIMIT = 1000
|
23
|
+
|
24
|
+
|
25
|
+
class NumpyEncoder(json.JSONEncoder):
|
26
|
+
def default(self, obj: Any) -> Any:
|
27
|
+
if isinstance(obj, np.integer):
|
28
|
+
return int(obj)
|
29
|
+
if isinstance(obj, np.floating):
|
30
|
+
return float(obj)
|
31
|
+
if isinstance(obj, np.ndarray):
|
32
|
+
return obj.tolist()
|
33
|
+
return super().default(obj)
|
34
|
+
|
13
35
|
|
14
36
|
def _is_callable(model: model_types.SupportedModelType, method_name: str) -> bool:
|
15
37
|
return callable(getattr(model, method_name, None))
|
16
38
|
|
17
39
|
|
40
|
+
def get_truncated_sample_data(sample_input_data: model_types.SupportedDataType) -> model_types.SupportedLocalDataType:
|
41
|
+
trunc_sample_input = model_signature._truncate_data(sample_input_data)
|
42
|
+
local_sample_input: model_types.SupportedLocalDataType = None
|
43
|
+
if isinstance(sample_input_data, SnowparkDataFrame):
|
44
|
+
# Added because of Any from missing stubs.
|
45
|
+
trunc_sample_input = cast(SnowparkDataFrame, trunc_sample_input)
|
46
|
+
local_sample_input = snowpark_handler.SnowparkDataFrameHandler.convert_to_df(trunc_sample_input)
|
47
|
+
else:
|
48
|
+
local_sample_input = trunc_sample_input
|
49
|
+
return local_sample_input
|
50
|
+
|
51
|
+
|
18
52
|
def validate_signature(
|
19
53
|
model: model_types.SupportedRequireSignatureModelType,
|
20
54
|
model_meta: model_meta.ModelMetadata,
|
@@ -24,19 +58,23 @@ def validate_signature(
|
|
24
58
|
) -> model_meta.ModelMetadata:
|
25
59
|
if model_meta.signatures:
|
26
60
|
validate_target_methods(model, list(model_meta.signatures.keys()))
|
61
|
+
if sample_input_data is not None:
|
62
|
+
local_sample_input = get_truncated_sample_data(sample_input_data)
|
63
|
+
for target_method in model_meta.signatures.keys():
|
64
|
+
|
65
|
+
model_signature_inst = model_meta.signatures.get(target_method)
|
66
|
+
if model_signature_inst is not None:
|
67
|
+
# strict validation the input signature
|
68
|
+
model_signature._convert_and_validate_local_data(
|
69
|
+
local_sample_input, model_signature_inst._inputs, True
|
70
|
+
)
|
27
71
|
return model_meta
|
28
72
|
|
29
73
|
# In this case sample_input_data should be available, because of the check in save_model.
|
30
74
|
assert (
|
31
75
|
sample_input_data is not None
|
32
76
|
), "Model signature and sample input are None at the same time. This should not happen with local model."
|
33
|
-
|
34
|
-
if isinstance(sample_input_data, SnowparkDataFrame):
|
35
|
-
# Added because of Any from missing stubs.
|
36
|
-
trunc_sample_input = cast(SnowparkDataFrame, trunc_sample_input)
|
37
|
-
local_sample_input = snowpark_handler.SnowparkDataFrameHandler.convert_to_df(trunc_sample_input)
|
38
|
-
else:
|
39
|
-
local_sample_input = trunc_sample_input
|
77
|
+
local_sample_input = get_truncated_sample_data(sample_input_data)
|
40
78
|
for target_method in target_methods:
|
41
79
|
predictions_df = get_prediction_fn(target_method, local_sample_input)
|
42
80
|
sig = model_signature.infer_signature(local_sample_input, predictions_df)
|
@@ -45,24 +83,55 @@ def validate_signature(
|
|
45
83
|
return model_meta
|
46
84
|
|
47
85
|
|
86
|
+
def get_input_signature(
|
87
|
+
model_meta: model_meta.ModelMetadata, target_method: Optional[str]
|
88
|
+
) -> Sequence[core.BaseFeatureSpec]:
|
89
|
+
if target_method is None or target_method not in model_meta.signatures:
|
90
|
+
raise ValueError(f"Signature for target method {target_method} is missing or no method to explain.")
|
91
|
+
input_sig = model_meta.signatures[target_method].inputs
|
92
|
+
return input_sig
|
93
|
+
|
94
|
+
|
48
95
|
def add_explain_method_signature(
|
49
96
|
model_meta: model_meta.ModelMetadata,
|
50
97
|
explain_method: str,
|
51
|
-
target_method: str,
|
98
|
+
target_method: Optional[str],
|
52
99
|
output_return_type: model_signature.DataType = model_signature.DataType.DOUBLE,
|
53
100
|
) -> model_meta.ModelMetadata:
|
54
|
-
|
55
|
-
|
56
|
-
|
101
|
+
inputs = get_input_signature(model_meta, target_method)
|
102
|
+
if model_meta.model_type == "snowml":
|
103
|
+
output_feature_names = [identifier.concat_names([spec.name, "_explanation"]) for spec in inputs]
|
104
|
+
else:
|
105
|
+
output_feature_names = [f"{spec.name}_explanation" for spec in inputs]
|
57
106
|
model_meta.signatures[explain_method] = model_signature.ModelSignature(
|
58
107
|
inputs=inputs,
|
59
108
|
outputs=[
|
60
|
-
model_signature.FeatureSpec(dtype=output_return_type, name=
|
109
|
+
model_signature.FeatureSpec(dtype=output_return_type, name=output_name)
|
110
|
+
for output_name in output_feature_names
|
61
111
|
],
|
62
112
|
)
|
63
113
|
return model_meta
|
64
114
|
|
65
115
|
|
116
|
+
def get_explainability_supported_background(
|
117
|
+
sample_input_data: Optional[model_types.SupportedDataType],
|
118
|
+
meta: model_meta.ModelMetadata,
|
119
|
+
explain_target_method: Optional[str],
|
120
|
+
) -> pd.DataFrame:
|
121
|
+
if sample_input_data is None:
|
122
|
+
return None
|
123
|
+
|
124
|
+
if isinstance(sample_input_data, pd.DataFrame):
|
125
|
+
return sample_input_data
|
126
|
+
if isinstance(sample_input_data, sp_df.DataFrame):
|
127
|
+
return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(sample_input_data)
|
128
|
+
|
129
|
+
df = model_signature._convert_local_data_to_df(sample_input_data)
|
130
|
+
input_signature_for_explain = get_input_signature(meta, explain_target_method)
|
131
|
+
df_with_named_cols = model_signature_utils.rename_pandas_df(df, input_signature_for_explain)
|
132
|
+
return df_with_named_cols
|
133
|
+
|
134
|
+
|
66
135
|
def get_target_methods(
|
67
136
|
model: model_types.SupportedModelType,
|
68
137
|
target_methods: Optional[Sequence[str]],
|
@@ -75,6 +144,23 @@ def get_target_methods(
|
|
75
144
|
return target_methods
|
76
145
|
|
77
146
|
|
147
|
+
def save_background_data(
|
148
|
+
model_blobs_dir_path: str,
|
149
|
+
explain_artifact_dir: str,
|
150
|
+
bg_data_file_suffix: str,
|
151
|
+
model_name: str,
|
152
|
+
background_data: pd.DataFrame,
|
153
|
+
) -> None:
|
154
|
+
data_blob_path = os.path.join(model_blobs_dir_path, explain_artifact_dir)
|
155
|
+
os.makedirs(data_blob_path, exist_ok=True)
|
156
|
+
with open(os.path.join(data_blob_path, model_name + bg_data_file_suffix), "wb") as f:
|
157
|
+
# saving only the truncated data
|
158
|
+
trunc_background_data = background_data.head(
|
159
|
+
min(len(background_data.index), EXPLAIN_BACKGROUND_DATA_ROWS_COUNT_LIMIT)
|
160
|
+
)
|
161
|
+
trunc_background_data.to_parquet(f)
|
162
|
+
|
163
|
+
|
78
164
|
def validate_target_methods(model: model_types.SupportedModelType, target_methods: Iterable[str]) -> None:
|
79
165
|
for method_name in target_methods:
|
80
166
|
if not _is_callable(model, method_name):
|
@@ -93,23 +179,43 @@ def convert_explanations_to_2D_df(
|
|
93
179
|
return pd.DataFrame(explanations)
|
94
180
|
|
95
181
|
if hasattr(model, "classes_"):
|
96
|
-
classes_list = [cl for cl in model.classes_] # type:ignore[union-attr]
|
182
|
+
classes_list = [str(cl) for cl in model.classes_] # type:ignore[union-attr]
|
97
183
|
len_classes = len(classes_list)
|
98
184
|
if explanations.shape[2] != len_classes:
|
99
185
|
raise ValueError(f"Model has {len_classes} classes but explanations have {explanations.shape[2]}")
|
100
186
|
else:
|
101
|
-
classes_list = [i for i in range(explanations.shape[2])]
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
if isinstance(cl, (int, np.integer)):
|
110
|
-
cl = int(cl)
|
111
|
-
class_explanations[cl] = cl_exp
|
112
|
-
col_list.append(json.dumps(class_explanations))
|
113
|
-
exp_2d.append(col_list)
|
187
|
+
classes_list = [str(i) for i in range(explanations.shape[2])]
|
188
|
+
|
189
|
+
def row_to_dict(row: npt.NDArray[Any]) -> npt.NDArray[Any]:
|
190
|
+
"""Converts a single row to a dictionary."""
|
191
|
+
# convert to object or numpy creates strings of fixed length
|
192
|
+
return np.asarray(json.dumps(dict(zip(classes_list, row)), cls=NumpyEncoder), dtype=object)
|
193
|
+
|
194
|
+
exp_2d = np.apply_along_axis(row_to_dict, -1, explanations)
|
114
195
|
|
115
196
|
return pd.DataFrame(exp_2d)
|
197
|
+
|
198
|
+
|
199
|
+
def validate_model_task(passed_model_task: model_types.Task, inferred_model_task: model_types.Task) -> model_types.Task:
|
200
|
+
if passed_model_task != model_types.Task.UNKNOWN and inferred_model_task != model_types.Task.UNKNOWN:
|
201
|
+
if passed_model_task != inferred_model_task:
|
202
|
+
warnings.warn(
|
203
|
+
f"Inferred Task: {inferred_model_task.name} is used as task for this model "
|
204
|
+
f"version and passed argument Task: {passed_model_task.name} is ignored",
|
205
|
+
category=UserWarning,
|
206
|
+
stacklevel=1,
|
207
|
+
)
|
208
|
+
return inferred_model_task
|
209
|
+
elif inferred_model_task != model_types.Task.UNKNOWN:
|
210
|
+
logging.info(f"Inferred Task: {inferred_model_task.name} is used as task for this model " f"version")
|
211
|
+
return inferred_model_task
|
212
|
+
return passed_model_task
|
213
|
+
|
214
|
+
|
215
|
+
def get_explain_target_method(
|
216
|
+
model_metadata: model_meta.ModelMetadata, target_methods_list: List[str]
|
217
|
+
) -> Optional[str]:
|
218
|
+
for method in model_metadata.signatures.keys():
|
219
|
+
if method in target_methods_list:
|
220
|
+
return method
|
221
|
+
return None
|