snowflake-ml-python 1.8.3__py3-none-any.whl → 1.8.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +7 -1
- snowflake/ml/_internal/platform_capabilities.py +13 -11
- snowflake/ml/_internal/utils/identifier.py +2 -2
- snowflake/ml/jobs/_utils/constants.py +1 -1
- snowflake/ml/jobs/_utils/payload_utils.py +39 -30
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +1 -1
- snowflake/ml/jobs/_utils/spec_utils.py +1 -1
- snowflake/ml/jobs/decorators.py +6 -0
- snowflake/ml/jobs/job.py +63 -16
- snowflake/ml/jobs/manager.py +50 -16
- snowflake/ml/model/_client/model/model_version_impl.py +1 -1
- snowflake/ml/model/_client/ops/service_ops.py +26 -14
- snowflake/ml/model/_client/service/model_deployment_spec.py +340 -170
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +25 -0
- snowflake/ml/model/_client/sql/service.py +4 -13
- snowflake/ml/model/_model_composer/model_composer.py +41 -18
- snowflake/ml/model/_packager/model_handlers/_utils.py +32 -2
- snowflake/ml/model/_packager/model_handlers/custom.py +1 -1
- snowflake/ml/model/_packager/model_handlers/pytorch.py +1 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +100 -41
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +7 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +16 -7
- snowflake/ml/model/_packager/model_meta/model_meta.py +2 -1
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +4 -4
- snowflake/ml/model/_signatures/dmatrix_handler.py +15 -2
- snowflake/ml/model/custom_model.py +17 -4
- snowflake/ml/model/model_signature.py +3 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
- snowflake/ml/modeling/cluster/birch.py +9 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
- snowflake/ml/modeling/cluster/dbscan.py +9 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
- snowflake/ml/modeling/cluster/k_means.py +9 -1
- snowflake/ml/modeling/cluster/mean_shift.py +9 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
- snowflake/ml/modeling/cluster/optics.py +9 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
- snowflake/ml/modeling/compose/column_transformer.py +9 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
- snowflake/ml/modeling/covariance/oas.py +9 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/pca.py +9 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
- snowflake/ml/modeling/impute/knn_imputer.py +9 -1
- snowflake/ml/modeling/impute/missing_indicator.py +9 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/lars.py +9 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/perceptron.py +9 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ridge.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
- snowflake/ml/modeling/manifold/isomap.py +9 -1
- snowflake/ml/modeling/manifold/mds.py +9 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
- snowflake/ml/modeling/manifold/tsne.py +9 -1
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
- snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
- snowflake/ml/modeling/svm/linear_svc.py +9 -1
- snowflake/ml/modeling/svm/linear_svr.py +9 -1
- snowflake/ml/modeling/svm/nu_svc.py +9 -1
- snowflake/ml/modeling/svm/nu_svr.py +9 -1
- snowflake/ml/modeling/svm/svc.py +9 -1
- snowflake/ml/modeling/svm/svr.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
- snowflake/ml/monitoring/explain_visualize.py +286 -0
- snowflake/ml/registry/_manager/model_manager.py +23 -2
- snowflake/ml/registry/registry.py +10 -9
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.4.dist-info}/METADATA +40 -8
- {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.4.dist-info}/RECORD +190 -189
- {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.4.dist-info}/WHEEL +1 -1
- {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.4.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.4.dist-info}/top_level.txt +0 -0
@@ -11,7 +11,7 @@ import cloudpickle as cp
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
from numpy import typing as npt
|
14
|
-
|
14
|
+
from packaging import version
|
15
15
|
|
16
16
|
import numpy
|
17
17
|
import sklearn
|
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
60
60
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
62
62
|
|
63
|
+
SKLEARN_LOWER = '1.4'
SKLEARN_UPPER = '1.6'
# Modeling-library estimators support only a narrower scikit-learn range, so refuse to import
# unless SKLEARN_LOWER <= installed scikit-learn < SKLEARN_UPPER (upper bound exclusive).
if (
    version.Version(sklearn.__version__) < version.Version(SKLEARN_LOWER)
    or version.Version(sklearn.__version__) >= version.Version(SKLEARN_UPPER)
):
    raise Exception(
        f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
    )
|
69
|
+
|
70
|
+
|
63
71
|
class LinearSVC(BaseTransformer):
|
64
72
|
r"""Linear Support Vector Classification
|
65
73
|
For more details on this class, see [sklearn.svm.LinearSVC]
|
@@ -11,7 +11,7 @@ import cloudpickle as cp
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
from numpy import typing as npt
|
14
|
-
|
14
|
+
from packaging import version
|
15
15
|
|
16
16
|
import numpy
|
17
17
|
import sklearn
|
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
60
60
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
62
62
|
|
63
|
+
SKLEARN_LOWER = '1.4'
SKLEARN_UPPER = '1.6'
# Modeling-library estimators support only a narrower scikit-learn range, so refuse to import
# unless SKLEARN_LOWER <= installed scikit-learn < SKLEARN_UPPER (upper bound exclusive).
if (
    version.Version(sklearn.__version__) < version.Version(SKLEARN_LOWER)
    or version.Version(sklearn.__version__) >= version.Version(SKLEARN_UPPER)
):
    raise Exception(
        f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
    )
|
69
|
+
|
70
|
+
|
63
71
|
class LinearSVR(BaseTransformer):
|
64
72
|
r"""Linear Support Vector Regression
|
65
73
|
For more details on this class, see [sklearn.svm.LinearSVR]
|
@@ -11,7 +11,7 @@ import cloudpickle as cp
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
from numpy import typing as npt
|
14
|
-
|
14
|
+
from packaging import version
|
15
15
|
|
16
16
|
import numpy
|
17
17
|
import sklearn
|
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
60
60
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
62
62
|
|
63
|
+
SKLEARN_LOWER = '1.4'
SKLEARN_UPPER = '1.6'
# Modeling-library estimators support only a narrower scikit-learn range, so refuse to import
# unless SKLEARN_LOWER <= installed scikit-learn < SKLEARN_UPPER (upper bound exclusive).
if (
    version.Version(sklearn.__version__) < version.Version(SKLEARN_LOWER)
    or version.Version(sklearn.__version__) >= version.Version(SKLEARN_UPPER)
):
    raise Exception(
        f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
    )
|
69
|
+
|
70
|
+
|
63
71
|
class NuSVC(BaseTransformer):
|
64
72
|
r"""Nu-Support Vector Classification
|
65
73
|
For more details on this class, see [sklearn.svm.NuSVC]
|
@@ -11,7 +11,7 @@ import cloudpickle as cp
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
from numpy import typing as npt
|
14
|
-
|
14
|
+
from packaging import version
|
15
15
|
|
16
16
|
import numpy
|
17
17
|
import sklearn
|
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
60
60
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
62
62
|
|
63
|
+
SKLEARN_LOWER = '1.4'
SKLEARN_UPPER = '1.6'
# Modeling-library estimators support only a narrower scikit-learn range, so refuse to import
# unless SKLEARN_LOWER <= installed scikit-learn < SKLEARN_UPPER (upper bound exclusive).
if (
    version.Version(sklearn.__version__) < version.Version(SKLEARN_LOWER)
    or version.Version(sklearn.__version__) >= version.Version(SKLEARN_UPPER)
):
    raise Exception(
        f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
    )
|
69
|
+
|
70
|
+
|
63
71
|
class NuSVR(BaseTransformer):
|
64
72
|
r"""Nu Support Vector Regression
|
65
73
|
For more details on this class, see [sklearn.svm.NuSVR]
|
snowflake/ml/modeling/svm/svc.py
CHANGED
@@ -11,7 +11,7 @@ import cloudpickle as cp
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
from numpy import typing as npt
|
14
|
-
|
14
|
+
from packaging import version
|
15
15
|
|
16
16
|
import numpy
|
17
17
|
import sklearn
|
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
60
60
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
62
62
|
|
63
|
+
SKLEARN_LOWER = '1.4'
SKLEARN_UPPER = '1.6'
# Modeling-library estimators support only a narrower scikit-learn range, so refuse to import
# unless SKLEARN_LOWER <= installed scikit-learn < SKLEARN_UPPER (upper bound exclusive).
if (
    version.Version(sklearn.__version__) < version.Version(SKLEARN_LOWER)
    or version.Version(sklearn.__version__) >= version.Version(SKLEARN_UPPER)
):
    raise Exception(
        f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
    )
|
69
|
+
|
70
|
+
|
63
71
|
class SVC(BaseTransformer):
|
64
72
|
r"""C-Support Vector Classification
|
65
73
|
For more details on this class, see [sklearn.svm.SVC]
|
snowflake/ml/modeling/svm/svr.py
CHANGED
@@ -11,7 +11,7 @@ import cloudpickle as cp
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
from numpy import typing as npt
|
14
|
-
|
14
|
+
from packaging import version
|
15
15
|
|
16
16
|
import numpy
|
17
17
|
import sklearn
|
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
60
60
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
62
62
|
|
63
|
+
SKLEARN_LOWER = '1.4'
SKLEARN_UPPER = '1.6'
# Modeling-library estimators support only a narrower scikit-learn range, so refuse to import
# unless SKLEARN_LOWER <= installed scikit-learn < SKLEARN_UPPER (upper bound exclusive).
if (
    version.Version(sklearn.__version__) < version.Version(SKLEARN_LOWER)
    or version.Version(sklearn.__version__) >= version.Version(SKLEARN_UPPER)
):
    raise Exception(
        f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
    )
|
69
|
+
|
70
|
+
|
63
71
|
class SVR(BaseTransformer):
|
64
72
|
r"""Epsilon-Support Vector Regression
|
65
73
|
For more details on this class, see [sklearn.svm.SVR]
|
@@ -11,7 +11,7 @@ import cloudpickle as cp
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
from numpy import typing as npt
|
14
|
-
|
14
|
+
from packaging import version
|
15
15
|
|
16
16
|
import numpy
|
17
17
|
import sklearn
|
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
60
60
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
62
62
|
|
63
|
+
SKLEARN_LOWER = '1.4'
SKLEARN_UPPER = '1.6'
# Modeling-library estimators support only a narrower scikit-learn range, so refuse to import
# unless SKLEARN_LOWER <= installed scikit-learn < SKLEARN_UPPER (upper bound exclusive).
if (
    version.Version(sklearn.__version__) < version.Version(SKLEARN_LOWER)
    or version.Version(sklearn.__version__) >= version.Version(SKLEARN_UPPER)
):
    raise Exception(
        f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
    )
|
69
|
+
|
70
|
+
|
63
71
|
class DecisionTreeClassifier(BaseTransformer):
|
64
72
|
r"""A decision tree classifier
|
65
73
|
For more details on this class, see [sklearn.tree.DecisionTreeClassifier]
|
@@ -11,7 +11,7 @@ import cloudpickle as cp
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
from numpy import typing as npt
|
14
|
-
|
14
|
+
from packaging import version
|
15
15
|
|
16
16
|
import numpy
|
17
17
|
import sklearn
|
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
60
60
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
62
62
|
|
63
|
+
SKLEARN_LOWER = '1.4'
SKLEARN_UPPER = '1.6'
# Modeling-library estimators support only a narrower scikit-learn range, so refuse to import
# unless SKLEARN_LOWER <= installed scikit-learn < SKLEARN_UPPER (upper bound exclusive).
if (
    version.Version(sklearn.__version__) < version.Version(SKLEARN_LOWER)
    or version.Version(sklearn.__version__) >= version.Version(SKLEARN_UPPER)
):
    raise Exception(
        f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
    )
|
69
|
+
|
70
|
+
|
63
71
|
class DecisionTreeRegressor(BaseTransformer):
|
64
72
|
r"""A decision tree regressor
|
65
73
|
For more details on this class, see [sklearn.tree.DecisionTreeRegressor]
|
@@ -11,7 +11,7 @@ import cloudpickle as cp
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
from numpy import typing as npt
|
14
|
-
|
14
|
+
from packaging import version
|
15
15
|
|
16
16
|
import numpy
|
17
17
|
import sklearn
|
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
60
60
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
62
62
|
|
63
|
+
SKLEARN_LOWER = '1.4'
SKLEARN_UPPER = '1.6'
# Modeling-library estimators support only a narrower scikit-learn range, so refuse to import
# unless SKLEARN_LOWER <= installed scikit-learn < SKLEARN_UPPER (upper bound exclusive).
if (
    version.Version(sklearn.__version__) < version.Version(SKLEARN_LOWER)
    or version.Version(sklearn.__version__) >= version.Version(SKLEARN_UPPER)
):
    raise Exception(
        f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
    )
|
69
|
+
|
70
|
+
|
63
71
|
class ExtraTreeClassifier(BaseTransformer):
|
64
72
|
r"""An extremely randomized tree classifier
|
65
73
|
For more details on this class, see [sklearn.tree.ExtraTreeClassifier]
|
@@ -11,7 +11,7 @@ import cloudpickle as cp
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
from numpy import typing as npt
|
14
|
-
|
14
|
+
from packaging import version
|
15
15
|
|
16
16
|
import numpy
|
17
17
|
import sklearn
|
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
60
60
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
62
62
|
|
63
|
+
SKLEARN_LOWER = '1.4'
SKLEARN_UPPER = '1.6'
# Modeling-library estimators support only a narrower scikit-learn range, so refuse to import
# unless SKLEARN_LOWER <= installed scikit-learn < SKLEARN_UPPER (upper bound exclusive).
if (
    version.Version(sklearn.__version__) < version.Version(SKLEARN_LOWER)
    or version.Version(sklearn.__version__) >= version.Version(SKLEARN_UPPER)
):
    raise Exception(
        f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
    )
|
69
|
+
|
70
|
+
|
63
71
|
class ExtraTreeRegressor(BaseTransformer):
|
64
72
|
r"""An extremely randomized tree regressor
|
65
73
|
For more details on this class, see [sklearn.tree.ExtraTreeRegressor]
|
@@ -11,7 +11,7 @@ import cloudpickle as cp
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
from numpy import typing as npt
|
14
|
-
|
14
|
+
from packaging import version
|
15
15
|
|
16
16
|
import numpy
|
17
17
|
import sklearn
|
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
60
60
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
62
62
|
|
63
|
+
SKLEARN_LOWER = '1.4'
SKLEARN_UPPER = '1.6'
# Modeling-library estimators support only a narrower scikit-learn range, so refuse to import
# unless SKLEARN_LOWER <= installed scikit-learn < SKLEARN_UPPER (upper bound exclusive).
if (
    version.Version(sklearn.__version__) < version.Version(SKLEARN_LOWER)
    or version.Version(sklearn.__version__) >= version.Version(SKLEARN_UPPER)
):
    raise Exception(
        f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
    )
|
69
|
+
|
70
|
+
|
63
71
|
class XGBClassifier(BaseTransformer):
|
64
72
|
r"""Implementation of the scikit-learn API for XGBoost classification
|
65
73
|
For more details on this class, see [xgboost.XGBClassifier]
|
@@ -11,7 +11,7 @@ import cloudpickle as cp
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
from numpy import typing as npt
|
14
|
-
|
14
|
+
from packaging import version
|
15
15
|
|
16
16
|
import numpy
|
17
17
|
import sklearn
|
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
60
60
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
62
62
|
|
63
|
+
SKLEARN_LOWER = '1.4'
SKLEARN_UPPER = '1.6'
# Modeling-library estimators support only a narrower scikit-learn range, so refuse to import
# unless SKLEARN_LOWER <= installed scikit-learn < SKLEARN_UPPER (upper bound exclusive).
if (
    version.Version(sklearn.__version__) < version.Version(SKLEARN_LOWER)
    or version.Version(sklearn.__version__) >= version.Version(SKLEARN_UPPER)
):
    raise Exception(
        f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
    )
|
69
|
+
|
70
|
+
|
63
71
|
class XGBRegressor(BaseTransformer):
|
64
72
|
r"""Implementation of the scikit-learn API for XGBoost regression
|
65
73
|
For more details on this class, see [xgboost.XGBRegressor]
|
@@ -11,7 +11,7 @@ import cloudpickle as cp
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
from numpy import typing as npt
|
14
|
-
|
14
|
+
from packaging import version
|
15
15
|
|
16
16
|
import numpy
|
17
17
|
import sklearn
|
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
60
60
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
62
62
|
|
63
|
+
SKLEARN_LOWER = '1.4'
SKLEARN_UPPER = '1.6'
# Modeling-library estimators support only a narrower scikit-learn range, so refuse to import
# unless SKLEARN_LOWER <= installed scikit-learn < SKLEARN_UPPER (upper bound exclusive).
if (
    version.Version(sklearn.__version__) < version.Version(SKLEARN_LOWER)
    or version.Version(sklearn.__version__) >= version.Version(SKLEARN_UPPER)
):
    raise Exception(
        f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
    )
|
69
|
+
|
70
|
+
|
63
71
|
class XGBRFClassifier(BaseTransformer):
|
64
72
|
r"""scikit-learn API for XGBoost random forest classification
|
65
73
|
For more details on this class, see [xgboost.XGBRFClassifier]
|
@@ -11,7 +11,7 @@ import cloudpickle as cp
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
from numpy import typing as npt
|
14
|
-
|
14
|
+
from packaging import version
|
15
15
|
|
16
16
|
import numpy
|
17
17
|
import sklearn
|
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
|
60
60
|
|
61
61
|
INFER_SIGNATURE_MAX_ROWS = 100
|
62
62
|
|
63
|
+
SKLEARN_LOWER = '1.4'
SKLEARN_UPPER = '1.6'
# Modeling-library estimators support only a narrower scikit-learn range, so refuse to import
# unless SKLEARN_LOWER <= installed scikit-learn < SKLEARN_UPPER (upper bound exclusive).
if (
    version.Version(sklearn.__version__) < version.Version(SKLEARN_LOWER)
    or version.Version(sklearn.__version__) >= version.Version(SKLEARN_UPPER)
):
    raise Exception(
        f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
    )
|
69
|
+
|
70
|
+
|
63
71
|
class XGBRFRegressor(BaseTransformer):
|
64
72
|
r"""scikit-learn API for XGBoost random forest regression
|
65
73
|
For more details on this class, see [xgboost.XGBRFRegressor]
|
@@ -0,0 +1,286 @@
|
|
1
|
+
from typing import Union, cast, overload
|
2
|
+
|
3
|
+
import altair as alt
|
4
|
+
import numpy as np
|
5
|
+
import pandas as pd
|
6
|
+
|
7
|
+
import snowflake.snowpark.dataframe as sp_df
|
8
|
+
from snowflake import snowpark
|
9
|
+
from snowflake.ml.model import model_signature, type_hints
|
10
|
+
from snowflake.ml.model._signatures import snowpark_handler
|
11
|
+
|
12
|
+
|
13
|
+
@overload
def plot_force(
    shap_row: snowpark.Row,
    features_row: snowpark.Row,
    base_value: float = 0.0,
    figsize: tuple[float, float] = (600, 200),
    contribution_threshold: float = 0.05,
) -> alt.LayerChart:
    ...


@overload
def plot_force(
    shap_row: pd.Series,
    features_row: pd.Series,
    base_value: float = 0.0,
    figsize: tuple[float, float] = (600, 200),
    contribution_threshold: float = 0.05,
) -> alt.LayerChart:
    ...


def plot_force(
    shap_row: Union[pd.Series, snowpark.Row],
    features_row: Union[pd.Series, snowpark.Row],
    base_value: float = 0.0,
    figsize: tuple[float, float] = (600, 200),
    contribution_threshold: float = 0.05,
) -> alt.LayerChart:
    """
    Create a force plot for SHAP values with stacked bars based on influence direction.

    Bars are stacked around the model output ``base_value + sum(SHAP values)``. Features with a
    non-negative SHAP value (drawn red) are stacked leftwards from that point, and features with a
    negative SHAP value (drawn blue) are stacked rightwards. Features are processed largest |SHAP| first.
    A triangle marker is drawn at the start of each bar, and a dashed vertical rule marks ``base_value``.

    Features are paired with SHAP values by position, so ``shap_row`` and ``features_row`` must list
    the same features in the same order.

    Args:
        shap_row: pandas Series or snowpark Row containing SHAP values for a specific instance
        features_row: pandas Series or snowpark Row containing the feature values for the same instance
        base_value: base value of the predictions. Defaults to 0, but is usually the model's average prediction
        figsize: tuple of (width, height) for the plot
        contribution_threshold:
            Only features with magnitude greater than contribution_threshold as a percentage of the
            total absolute SHAP values will be plotted. Defaults to 0.05 (5%)

    Returns:
        Altair layered chart (bars + feature labels + base-value rule + arrow markers)
    """
    if isinstance(shap_row, snowpark.Row):
        shap_row = pd.Series(shap_row.as_dict())
    if isinstance(features_row, snowpark.Row):
        features_row = pd.Series(features_row.as_dict())

    positive_label = "Positive"
    negative_label = "Negative"
    blue_color = "#1f77b4"
    red_color = "#d62728"
    # One direction -> color mapping shared by the bar and arrow layers so the two always agree
    # (positive influence is red, negative influence is blue).
    direction_domain = [positive_label, negative_label]
    direction_colors = [red_color, blue_color]

    # One row per feature, paired with its SHAP value by position.
    plot_df = pd.DataFrame(
        [
            {
                "feature": feature,
                "feature_value": features_row.iloc[index],
                "feature_annotated": f"{feature}: {features_row.iloc[index]}",
                "influence_value": shap_row.iloc[index],
                "bar_direction": positive_label if shap_row.iloc[index] >= 0 else negative_label,
            }
            for index, feature in enumerate(features_row.index)
        ]
    )

    # Stacking starts at the model output (base_value + sum of all SHAP values) so the bars line up
    # with the base-value rule drawn below; previously base_value was ignored here.
    shap_sum = np.sum(shap_row)
    output_value = base_value + shap_sum
    current_position_pos = output_value
    current_position_neg = output_value
    positions = []

    total_abs_value_sum = np.sum(plot_df["influence_value"].abs())
    max_abs_value = plot_df["influence_value"].abs().max()
    # Gap left at the start of each bar (where its arrow is drawn): 7% of the largest |SHAP| value.
    spacing = max_abs_value * 0.07

    # Sort by absolute value to have largest impacts first
    plot_df = plot_df.reindex(plot_df["influence_value"].abs().sort_values(ascending=False).index)
    for _, row in plot_df.iterrows():
        row_influence_value = row["influence_value"]
        # Skip features whose share of the total |SHAP| mass is below the threshold. If every SHAP
        # value is zero there is no share to compute: nothing is skipped and no 0/0 is evaluated.
        if total_abs_value_sum > 0 and abs(row_influence_value) / total_abs_value_sum < contribution_threshold:
            continue

        if row_influence_value >= 0:
            # Non-negative influences stack leftwards from the output value.
            start = current_position_pos - spacing
            end = current_position_pos - row_influence_value
            current_position_pos = end
        else:
            # Negative influences stack rightwards from the output value.
            start = current_position_neg + spacing
            end = current_position_neg + abs(row_influence_value)
            current_position_neg = end

        positions.append(
            {
                "start": start,
                "end": end,
                "avg": (start + end) / 2,
                "influence_value": row_influence_value,
                "influence_annotated": f"Influence: {row_influence_value}",
                "feature_value": row["feature_value"],
                "feature_annotated": row["feature_annotated"],
                "bar_direction": row["bar_direction"],
            }
        )

    position_df = pd.DataFrame(positions)

    # Create force plot using Altair
    width, height = figsize
    bars: alt.Chart = (
        alt.Chart(position_df)
        .mark_bar(size=10)
        .encode(
            x=alt.X("start:Q", title="Feature Impact"),
            x2=alt.X2("end:Q"),
            color=alt.Color(
                "bar_direction:N",
                scale=alt.Scale(domain=direction_domain, range=direction_colors),
                legend=alt.Legend(title="Influence Direction"),
            ),
            tooltip=["influence_value", "feature_value"],
        )
        .properties(title="Feature Influence (SHAP values)", width=width, height=height)
    ).interactive()

    # Triangle marker at the start of each bar, rotated according to the influence direction and
    # colored with the same mapping as the bars.
    arrow: alt.Chart = (
        alt.Chart(position_df)
        .mark_point(shape="triangle", filled=True, fillOpacity=1)
        .encode(
            x=alt.X("start:Q"),
            angle=alt.Angle("bar_direction:N", scale=alt.Scale(domain=direction_domain, range=[90, -90])),
            color=alt.Color("bar_direction:N", scale=alt.Scale(domain=direction_domain, range=direction_colors)),
            size=alt.SizeValue(300),
            tooltip=alt.value(None),
        )
    )

    # Add a vertical line at the base value
    zero_line: alt.Chart = alt.Chart(pd.DataFrame({"x": [base_value]})).mark_rule(strokeDash=[3, 3]).encode(x="x:Q")

    # Label each bar with its "feature: value" text, centered on the bar and offset below it.
    feature_labels = (
        alt.Chart(position_df)
        .mark_text(align="center", baseline="line-bottom", dy=30, fontSize=11)
        .encode(
            x=alt.X("avg:Q"),
            text=alt.Text("feature_annotated:N"),  # pre-formatted "feature: value" string
            color=alt.value("grey"),  # same label color for both directions
            tooltip=["feature_value"],
        )
    )

    return cast(alt.LayerChart, bars + feature_labels + zero_line + arrow)
|
170
|
+
|
171
|
+
|
172
|
+
def plot_influence_sensitivity(
    feature_values: pd.Series, shap_values: pd.Series, figsize: tuple[float, float] = (600, 400)
) -> alt.Chart:
    """
    Create a SHAP dependence scatter plot for a specific feature.

    When feature values repeat heavily, the feature is drawn on a nominal x axis.
    "Repeat heavily" means the most frequent value, or the average value, occurs
    more than 10 times. In that case each value gets its own color and random
    horizontal jitter. Otherwise a quantitative x axis is used.

    Args:
        feature_values: pandas Series containing the feature values for a specific feature
        shap_values: pandas Series containing the SHAP values for the same feature
        figsize: tuple of (width, height) for the plot

    Returns:
        Altair chart object

    Raises:
        ValueError: If ``feature_values`` is empty.
    """
    raw_values = feature_values.values
    if len(raw_values) == 0:
        # Fail clearly instead of with numpy's "zero-size array to reduction operation" error.
        raise ValueError("feature_values must contain at least one value to plot.")

    # np.unique already returns sorted unique values; return_counts yields the number of
    # occurrences of each unique value directly.
    _, counts = np.unique(raw_values, return_counts=True)
    max_points_per_unique_value = float(np.max(counts))
    points_per_value = len(raw_values) / len(counts)
    is_categorical = max(max_points_per_unique_value, points_per_value) > 10

    kwargs = (
        {
            "x": alt.X("feature_value:N", title="Feature Value"),
            "color": alt.Color("feature_value:N").legend(None),
            # Spread overlapping points of the same category horizontally.
            "xOffset": "jitter:Q",
        }
        if is_categorical
        else {"x": alt.X("feature_value:Q", title="Feature Value")}
    )

    # Create a dataframe for plotting
    plot_df = pd.DataFrame({"feature_value": feature_values, "shap_value": shap_values})

    width, height = figsize

    # Create scatter plot
    scatter = (
        alt.Chart(plot_df)
        .transform_calculate(jitter="random()")
        .mark_circle(size=60, opacity=0.7)
        .encode(
            y=alt.Y("shap_value:Q", title="SHAP Value"),
            tooltip=["feature_value", "shap_value"],
            **kwargs,
        )
        .properties(title="SHAP Dependence Scatter Plot", width=width, height=height)
    )

    return cast(alt.Chart, scatter)
|
222
|
+
|
223
|
+
|
224
|
+
def plot_violin(
    shap_df: type_hints.SupportedDataType,
    feature_df: type_hints.SupportedDataType,
    figsize: tuple[float, float] = (600, 200),
) -> alt.Chart:
    """
    Create a violin plot per feature showing the distribution of SHAP values.

    One row is drawn per feature. Rows are ordered by descending sum of absolute
    SHAP values. Columns of ``shap_df`` and ``feature_df`` are paired by position,
    and the row labels are taken from the column names of ``feature_df``.

    Args:
        shap_df: 2D array containing SHAP values for multiple features
        feature_df: 2D array containing the corresponding feature values; only its
            column names are used, to label the rows
        figsize: tuple of (width, height) for the plot

    Returns:
        Altair chart object

    Raises:
        ValueError: If either input is not 2D, or if the two inputs have a different
            number of columns.
    """

    shap_df_pd = _convert_to_pandas_df(shap_df)
    feature_df_pd = _convert_to_pandas_df(feature_df)

    # Validate with explicit exceptions: `assert` statements are stripped under `python -O`.
    if len(shap_df_pd.shape) != 2:
        raise ValueError(f"shap_df must be 2D, but got shape {shap_df_pd.shape}")
    if len(feature_df_pd.shape) != 2:
        raise ValueError(f"feature_df must be 2D, but got shape {feature_df_pd.shape}")
    # Feature names are matched to SHAP columns by position, so the column counts must agree.
    if shap_df_pd.shape[1] != feature_df_pd.shape[1]:
        raise ValueError(
            "shap_df and feature_df must have the same number of columns, "
            f"but got {shap_df_pd.shape[1]} and {feature_df_pd.shape[1]}"
        )

    # Prepare long-format data: transposing then flattening lists all rows of column 0,
    # then all rows of column 1, ..., which lines up with `columns.repeat(n_rows)`.
    plot_data = pd.DataFrame(
        {
            "feature_name": feature_df_pd.columns.repeat(shap_df_pd.shape[0]),
            "shap_value": shap_df_pd.transpose().values.flatten(),
        }
    )

    # Order the rows by the absolute sum of SHAP values per feature
    feature_abs_sum = shap_df_pd.abs().sum(axis=0)
    sorted_features = feature_abs_sum.sort_values(ascending=False).index
    # Translate SHAP column names to the positionally matching feature column names.
    column_sort_order = [feature_df_pd.columns[shap_df_pd.columns.get_loc(col)] for col in sorted_features]

    # Create the violin plot
    width, height = figsize
    violin = (
        alt.Chart(plot_data)
        .transform_density(density="shap_value", groupby=["feature_name"], as_=["shap_value", "density"])
        .mark_area(orient="vertical")
        .encode(
            y=alt.Y("density:Q", title=None).stack("center").impute(None).axis(labels=False, grid=False, ticks=True),
            x=alt.X("shap_value:Q", title="SHAP Value"),
            row=alt.Row("feature_name:N", sort=column_sort_order).spacing(0),
            color=alt.Color("feature_name:N", legend=None),
            tooltip=["feature_name", "shap_value"],
        )
        .properties(width=width, height=height)
    ).interactive()

    return cast(alt.Chart, violin)
|
278
|
+
|
279
|
+
|
280
|
+
def _convert_to_pandas_df(
    data: type_hints.SupportedDataType,
) -> pd.DataFrame:
    """Normalize any supported input into a pandas DataFrame.

    Snowpark DataFrames are converted through the Snowpark DataFrame handler.
    Every other supported type is delegated to model_signature's local-data converter.
    """
    is_snowpark_frame = isinstance(data, sp_df.DataFrame)
    if not is_snowpark_frame:
        return model_signature._convert_local_data_to_df(data)
    return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(data)
|
@@ -12,8 +12,10 @@ from snowflake.ml.model import model_signature, type_hints as model_types
|
|
12
12
|
from snowflake.ml.model._client.model import model_impl, model_version_impl
|
13
13
|
from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
|
14
14
|
from snowflake.ml.model._model_composer import model_composer
|
15
|
+
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
15
16
|
from snowflake.ml.model._packager.model_meta import model_meta
|
16
17
|
from snowflake.snowpark import exceptions as snowpark_exceptions, session
|
18
|
+
from snowflake.snowpark._internal import utils as snowpark_utils
|
17
19
|
|
18
20
|
logger = logging.getLogger(__name__)
|
19
21
|
|
@@ -169,7 +171,10 @@ class ModelManager:
|
|
169
171
|
database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
|
170
172
|
version_name_id = sql_identifier.SqlIdentifier(version_name)
|
171
173
|
|
172
|
-
|
174
|
+
# TODO(SNOW-2091317): Remove this when the snowpark enables file PUT operation for snowurls
|
175
|
+
use_live_commit = (
|
176
|
+
not snowpark_utils.is_in_stored_procedure() # type: ignore[no-untyped-call]
|
177
|
+
) and platform_capabilities.PlatformCapabilities.get_instance().is_live_commit_enabled()
|
173
178
|
if use_live_commit:
|
174
179
|
logger.info("Using live commit model version")
|
175
180
|
else:
|
@@ -212,8 +217,24 @@ class ModelManager:
|
|
212
217
|
# Convert any string target platforms to TargetPlatform objects
|
213
218
|
platforms = [model_types.TargetPlatform(platform) for platform in target_platforms]
|
214
219
|
else:
|
220
|
+
# Default the target platform to warehouse if not specified and any table function exists
|
221
|
+
if options and (
|
222
|
+
options.get("function_type") == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
|
223
|
+
or (
|
224
|
+
any(
|
225
|
+
opt.get("function_type") == "TABLE_FUNCTION"
|
226
|
+
for opt in options.get("method_options", {}).values()
|
227
|
+
)
|
228
|
+
)
|
229
|
+
):
|
230
|
+
logger.info(
|
231
|
+
"Logging a partitioned model with a table function without specifying `target_platforms`. "
|
232
|
+
'Default to `target_platforms=["WAREHOUSE"]`.'
|
233
|
+
)
|
234
|
+
platforms = [model_types.TargetPlatform.WAREHOUSE]
|
235
|
+
|
215
236
|
# Default the target platform to SPCS if not specified when running in ML runtime
|
216
|
-
if env.IN_ML_RUNTIME:
|
237
|
+
if not platforms and env.IN_ML_RUNTIME:
|
217
238
|
logger.info(
|
218
239
|
"Logging the model on Container Runtime for ML without specifying `target_platforms`. "
|
219
240
|
'Default to `target_platforms=["SNOWPARK_CONTAINER_SERVICES"]`.'
|