snowflake-ml-python 1.8.3__py3-none-any.whl → 1.8.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +7 -1
- snowflake/ml/_internal/platform_capabilities.py +13 -11
- snowflake/ml/_internal/telemetry.py +42 -13
- snowflake/ml/_internal/utils/identifier.py +2 -2
- snowflake/ml/data/data_connector.py +1 -1
- snowflake/ml/jobs/_utils/constants.py +10 -1
- snowflake/ml/jobs/_utils/interop_utils.py +1 -1
- snowflake/ml/jobs/_utils/payload_utils.py +51 -34
- snowflake/ml/jobs/_utils/scripts/constants.py +6 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +86 -3
- snowflake/ml/jobs/_utils/spec_utils.py +8 -6
- snowflake/ml/jobs/decorators.py +13 -3
- snowflake/ml/jobs/job.py +206 -26
- snowflake/ml/jobs/manager.py +78 -34
- snowflake/ml/model/_client/model/model_version_impl.py +1 -1
- snowflake/ml/model/_client/ops/service_ops.py +31 -17
- snowflake/ml/model/_client/service/model_deployment_spec.py +351 -170
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +25 -0
- snowflake/ml/model/_client/sql/model_version.py +1 -1
- snowflake/ml/model/_client/sql/service.py +20 -32
- snowflake/ml/model/_model_composer/model_composer.py +44 -19
- snowflake/ml/model/_packager/model_handlers/_utils.py +32 -2
- snowflake/ml/model/_packager/model_handlers/custom.py +1 -1
- snowflake/ml/model/_packager/model_handlers/pytorch.py +1 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +100 -41
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +7 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +16 -7
- snowflake/ml/model/_packager/model_meta/model_meta.py +2 -1
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +5 -4
- snowflake/ml/model/_signatures/dmatrix_handler.py +15 -2
- snowflake/ml/model/custom_model.py +17 -4
- snowflake/ml/model/model_signature.py +3 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
- snowflake/ml/modeling/cluster/birch.py +9 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
- snowflake/ml/modeling/cluster/dbscan.py +9 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
- snowflake/ml/modeling/cluster/k_means.py +9 -1
- snowflake/ml/modeling/cluster/mean_shift.py +9 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
- snowflake/ml/modeling/cluster/optics.py +9 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
- snowflake/ml/modeling/compose/column_transformer.py +9 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
- snowflake/ml/modeling/covariance/oas.py +9 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/pca.py +9 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
- snowflake/ml/modeling/impute/knn_imputer.py +9 -1
- snowflake/ml/modeling/impute/missing_indicator.py +9 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/lars.py +9 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/perceptron.py +9 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ridge.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
- snowflake/ml/modeling/manifold/isomap.py +9 -1
- snowflake/ml/modeling/manifold/mds.py +9 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
- snowflake/ml/modeling/manifold/tsne.py +9 -1
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
- snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
- snowflake/ml/modeling/svm/linear_svc.py +9 -1
- snowflake/ml/modeling/svm/linear_svr.py +9 -1
- snowflake/ml/modeling/svm/nu_svc.py +9 -1
- snowflake/ml/modeling/svm/nu_svr.py +9 -1
- snowflake/ml/modeling/svm/svc.py +9 -1
- snowflake/ml/modeling/svm/svr.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
- snowflake/ml/monitoring/explain_visualize.py +424 -0
- snowflake/ml/registry/_manager/model_manager.py +23 -2
- snowflake/ml/registry/registry.py +10 -9
- snowflake/ml/utils/connection_params.py +8 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/METADATA +58 -8
- {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/RECORD +196 -195
- {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/WHEEL +1 -1
- {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_client/sql/service.py

@@ -1,10 +1,8 @@
 import enum
-import json
 import textwrap
 from typing import Any, Optional, Union

 from snowflake import snowpark
-from snowflake.ml._internal import platform_capabilities
 from snowflake.ml._internal.utils import (
     identifier,
     query_result_checker,
@@ -16,22 +14,25 @@ from snowflake.snowpark import dataframe, functions as F, row, types as spt
 from snowflake.snowpark._internal import utils as snowpark_utils


+# The enum comes from https://docs.snowflake.com/en/sql-reference/sql/show-service-containers-in-service#output
+# except UNKNOWN
 class ServiceStatus(enum.Enum):
     UNKNOWN = "UNKNOWN"  # status is unknown because we have not received enough data from K8s yet.
     PENDING = "PENDING"  # resource set is being created, can't be used yet
-    READY = "READY"  # resource set has been deployed.
     SUSPENDING = "SUSPENDING"  # the service is set to suspended but the resource set is still in deleting state
     SUSPENDED = "SUSPENDED"  # the service is suspended and the resource set is deleted
     DELETING = "DELETING"  # resource set is being deleted
     FAILED = "FAILED"  # resource set has failed and cannot be used anymore
     DONE = "DONE"  # resource set has finished running
-    NOT_FOUND = "NOT_FOUND"  # not found or deleted
     INTERNAL_ERROR = "INTERNAL_ERROR"  # there was an internal service error.
+    RUNNING = "RUNNING"
+    DELETED = "DELETED"


 class ServiceSQLClient(_base._BaseSQLClient):
     MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME = "name"
     MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME = "ingress_url"
+    SERVICE_STATUS = "service_status"

     def build_model_container(
         self,
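The enum now mirrors the statuses documented for SHOW SERVICE CONTAINERS IN SERVICE: READY and NOT_FOUND are gone, RUNNING and DELETED are new. A minimal self-contained sketch of defensive parsing against such an enum — the `parse_service_status` helper and the abbreviated enum are illustrative, not part of the package:

```python
import enum

class ServiceStatus(enum.Enum):  # abbreviated copy of the enum above, for illustration
    UNKNOWN = "UNKNOWN"
    PENDING = "PENDING"
    RUNNING = "RUNNING"
    DELETED = "DELETED"

def parse_service_status(raw: str) -> ServiceStatus:
    """Coerce a raw status string from SQL output, falling back to UNKNOWN."""
    try:
        return ServiceStatus(raw)
    except ValueError:
        return ServiceStatus.UNKNOWN

assert parse_service_status("RUNNING") is ServiceStatus.RUNNING
assert parse_service_status("READY") is ServiceStatus.UNKNOWN  # READY no longer exists in 1.8.5
```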
@@ -133,18 +134,10 @@ class ServiceSQLClient(_base._BaseSQLClient):
             input_args_sql = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
             args_sql = f"object_construct_keep_null({input_args_sql})"

-        if platform_capabilities.PlatformCapabilities.get_instance().is_nested_function_enabled():
-            fully_qualified_service_name = self.fully_qualified_object_name(
-                actual_database_name, actual_schema_name, service_name
-            )
-            fully_qualified_function_name = f"{fully_qualified_service_name}!{method_name.identifier()}"
-        else:
-            function_name = identifier.concat_names([service_name.identifier(), "_", method_name.identifier()])
-            fully_qualified_function_name = identifier.get_schema_level_object_identifier(
-                actual_database_name.identifier(),
-                actual_schema_name.identifier(),
-                function_name,
-            )
+        fully_qualified_service_name = self.fully_qualified_object_name(
+            actual_database_name, actual_schema_name, service_name
+        )
+        fully_qualified_function_name = f"{fully_qualified_service_name}!{method_name.identifier()}"

         sql = textwrap.dedent(
             f"""{with_sql}
@@ -208,22 +201,17 @@ class ServiceSQLClient(_base._BaseSQLClient):
         include_message: bool = False,
         statement_params: Optional[dict[str, Any]] = None,
     ) -> tuple[ServiceStatus, Optional[str]]:
-        system_func = "SYSTEM$GET_SERVICE_STATUS"
-        rows = (
-            query_result_checker.SqlResultValidator(
-                self._session,
-                f"CALL {system_func}('{self.fully_qualified_object_name(database_name, schema_name, service_name)}')",
-                statement_params=statement_params,
-            )
-            .has_dimensions(expected_rows=1, expected_cols=1)
-            .validate()
-        )
-        metadata = json.loads(rows[0][system_func])[0]
-        if metadata and metadata["status"]:
-            service_status = ServiceStatus(metadata["status"])
-            message = metadata["message"] if include_message else None
-            return service_status, message
-        return ServiceStatus.UNKNOWN, None
+        fully_qualified_object_name = self.fully_qualified_object_name(database_name, schema_name, service_name)
+        query = f"SHOW SERVICE CONTAINERS IN SERVICE {fully_qualified_object_name}"
+        rows = self._session.sql(query).collect(statement_params=statement_params)
+        if len(rows) == 0:
+            return ServiceStatus.UNKNOWN, None
+        row = rows[0]
+        service_status = row[ServiceSQLClient.SERVICE_STATUS]
+        message = row["message"] if include_message else None
+        if not isinstance(service_status, ServiceStatus):
+            return ServiceStatus.UNKNOWN, message
+        return ServiceStatus(service_status), message

     def drop_service(
         self,
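The rewritten status check replaces the old SYSTEM$GET_SERVICE_STATUS call (whose JSON payload required the now-removed `json` import) with a plain metadata query. A hedged sketch of the equivalent standalone Snowpark call, assuming an open `session` and a fully qualified service name:

```python
from typing import Optional

from snowflake.snowpark import Session

def first_container_status(session: Session, service_fqn: str) -> tuple[str, Optional[str]]:
    # SHOW SERVICE CONTAINERS IN SERVICE returns one row per container; the
    # "service_status" and "message" columns are what the client reads above.
    rows = session.sql(f"SHOW SERVICE CONTAINERS IN SERVICE {service_fqn}").collect()
    if not rows:
        return "UNKNOWN", None
    return rows[0]["service_status"], rows[0]["message"]
```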
snowflake/ml/model/_model_composer/model_composer.py

@@ -142,30 +142,55 @@ class ModelComposer:
         conda_dep_dict = env_utils.validate_conda_dependency_string_list(
             conda_dependencies if conda_dependencies else []
         )
-        is_warehouse_runnable = (
-            not conda_dep_dict
-            or all(
-                chan == env_utils.DEFAULT_CHANNEL_NAME or chan == env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
-                for chan in conda_dep_dict
-            )
-        ) and (not pip_requirements)
-        disable_explainability = (
-            target_platforms and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms
-        ) or (not is_warehouse_runnable)
-
-        if disable_explainability and options and options.get("enable_explainability", False):
-            warnings.warn(
-                ("The model can be deployed to Snowpark Container Services only if `enable_explainability=False`."),
-                category=UserWarning,
-                stacklevel=2,
+
+        enable_explainability = None
+
+        if options:
+            enable_explainability = options.get("enable_explainability", None)
+
+        # skip everything if user said False explicitly
+        if enable_explainability is None or enable_explainability is True:
+            is_warehouse_runnable = (
+                not conda_dep_dict
+                or all(
+                    chan == env_utils.DEFAULT_CHANNEL_NAME or chan == env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
+                    for chan in conda_dep_dict
+                )
+            ) and (not pip_requirements)
+
+            only_spcs = (
+                target_platforms
+                and len(target_platforms) == 1
+                and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms
             )
+            if only_spcs or (not is_warehouse_runnable):
+                # if only SPCS and user asked for explainability we fail
+                if enable_explainability is True:
+                    raise ValueError(
+                        "`enable_explainability` cannot be set to True when the model is not runnable in WH "
+                        "or the target platforms include SPCS."
+                    )
+                elif not options:  # explicitly set flag to false in these cases if not specified
+                    options = model_types.BaseModelSaveOption()
+                    options["enable_explainability"] = False
+            elif (
+                target_platforms
+                and len(target_platforms) > 1
+                and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms
+            ):  # if both then only available for WH
+                if enable_explainability is True:
+                    warnings.warn(
+                        ("Explain function will only be available for model deployed to warehouse."),
+                        category=UserWarning,
+                        stacklevel=2,
+                    )

         if not options:
             options = model_types.BaseModelSaveOption()
-        if disable_explainability:
-            options["enable_explainability"] = False

-        if not snowpark_utils.is_in_stored_procedure():  # type: ignore[no-untyped-call]
+        if not snowpark_utils.is_in_stored_procedure() and target_platforms != [  # type: ignore[no-untyped-call]
+            model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES  # no information schema check for SPCS-only models
+        ]:
             snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
                 self.session,
                 reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={snowml_version.VERSION}")],
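From the caller's side, the new gating means an explicit `enable_explainability=True` on a model that cannot run in a warehouse, or that targets only SPCS, now raises instead of being silently ignored. A hedged sketch against the public registry API; `session` and `model` are assumed to exist:

```python
from snowflake.ml.registry import Registry

reg = Registry(session=session)

# Raises ValueError in 1.8.5: explainability requires a warehouse-runnable model.
reg.log_model(
    model,
    model_name="my_model",
    version_name="v1",
    target_platforms=["SNOWPARK_CONTAINER_SERVICES"],
    options={"enable_explainability": True},
)
```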
snowflake/ml/model/_packager/model_handlers/_utils.py

@@ -109,6 +109,35 @@ def get_input_signature(
     return input_sig


+def add_inferred_explain_method_signature(
+    model_meta: model_meta.ModelMetadata,
+    explain_method: str,
+    target_method: str,
+    background_data: model_types.SupportedDataType,
+    explain_fn: Callable[[model_types.SupportedLocalDataType], model_types.SupportedLocalDataType],
+    output_feature_names: Optional[Sequence[str]] = None,
+) -> model_meta.ModelMetadata:
+    inputs = get_input_signature(model_meta, target_method)
+    if output_feature_names is None:  # If not provided, assume output feature names are the same as input feature names
+        output_feature_names = [spec.name for spec in inputs]
+
+    if model_meta.model_type == "snowml":
+        suffixed_output_names = [identifier.concat_names([name, "_explanation"]) for name in output_feature_names]
+    else:
+        suffixed_output_names = [f"{name}_explanation" for name in output_feature_names]
+
+    truncated_background_data = get_truncated_sample_data(background_data, 5)
+    sig = model_signature.infer_signature(
+        input_data=truncated_background_data,
+        output_data=explain_fn(truncated_background_data),
+        input_feature_names=[spec.name for spec in inputs],
+        output_feature_names=suffixed_output_names,
+    )
+
+    model_meta.signatures[explain_method] = sig
+    return model_meta
+
+
 def add_explain_method_signature(
     model_meta: model_meta.ModelMetadata,
     explain_method: str,

@@ -236,8 +265,9 @@ def validate_model_task(passed_model_task: model_types.Task, inferred_model_task: model_types.Task) -> model_types.Task:
 def get_explain_target_method(
     model_metadata: model_meta.ModelMetadata, target_methods_list: list[str]
 ) -> Optional[str]:
-    for method in target_methods_list:
-        if method in model_metadata.signatures.keys():
+    """Returns the first target method that is found in the model metadata signatures."""
+    for method in target_methods_list:
+        if method in model_metadata.signatures.keys():
             return method
     return None
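`add_inferred_explain_method_signature` derives the explain signature empirically: run the explain function on a five-row sample of the background data and suffix each output column with `_explanation`. The naming convention, as a self-contained sketch:

```python
import pandas as pd

def explanation_column_names(background: pd.DataFrame, explain_fn) -> list:
    sample = background.head(5)      # mirrors get_truncated_sample_data(background_data, 5)
    explained = explain_fn(sample)   # one explanation column per output feature
    return [f"{name}_explanation" for name in explained.columns]

bg = pd.DataFrame({"x1": [1.0, 2.0, 3.0], "x2": [4.0, 5.0, 6.0]})
print(explanation_column_names(bg, lambda df: df * 0.0))
# ['x1_explanation', 'x2_explanation']
```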
snowflake/ml/model/_packager/model_handlers/custom.py

@@ -72,7 +72,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
             predictions_df = target_method(model, sample_input_data)
             return predictions_df

-        for func_name in model._get_partitioned_infer_methods():
+        for func_name in model._get_partitioned_methods():
             function_properties = model_meta.function_properties.get(func_name, {})
             function_properties[model_meta_schema.FunctionProperties.PARTITIONED.value] = True
             model_meta.function_properties[func_name] = function_properties
snowflake/ml/model/_packager/model_handlers/pytorch.py

@@ -82,6 +82,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
         enable_explainability = kwargs.get("enable_explainability", False)
         if enable_explainability:
             raise NotImplementedError("Explainability is not supported for PyTorch model.")
+        multiple_inputs = kwargs.get("multiple_inputs", False)

         import torch

@@ -94,8 +95,6 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
             default_target_methods=cls.DEFAULT_TARGET_METHODS,
         )

-        multiple_inputs = kwargs.get("multiple_inputs", False)
-
         def get_prediction(
             target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
         ) -> model_types.SupportedLocalDataType:
snowflake/ml/model/_packager/model_handlers/sklearn.py

@@ -1,6 +1,6 @@
 import os
 import warnings
-from typing import TYPE_CHECKING, Callable, Optional, Union, cast, final
+from typing import TYPE_CHECKING, Callable, Optional, Sequence, Union, cast, final

 import cloudpickle
 import numpy as np
@@ -38,6 +38,35 @@ def _unpack_container_runtime_pipeline(model: "sklearn.pipeline.Pipeline") -> "sklearn.pipeline.Pipeline":
     return model


+def _apply_transforms_up_to_last_step(
+    model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
+    data: model_types.SupportedDataType,
+    input_feature_names: Optional[list[str]] = None,
+) -> pd.DataFrame:
+    """Apply all transformations in the sklearn pipeline model up to the last step."""
+    transformed_data = data
+    output_features_names = input_feature_names
+
+    if type_utils.LazyType("sklearn.pipeline.Pipeline").isinstance(model):
+        for step_name, step in model.steps[:-1]:  # type: ignore[attr-defined]
+            if not hasattr(step, "transform"):
+                raise ValueError(f"Step '{step_name}' does not have a 'transform' method.")
+            transformed_data = step.transform(transformed_data)
+            if output_features_names is None:
+                continue
+            elif hasattr(step, "get_feature_names_out"):
+                output_features_names = step.get_feature_names_out(output_features_names)
+            else:
+                raise ValueError(
+                    f"Step '{step_name}' in the pipeline does not have a 'get_feature_names_out' method. "
+                    "Feature names cannot be propagated."
+                )
+    if type_utils.LazyType("scipy.sparse.csr_matrix").isinstance(transformed_data):
+        # Convert to dense array if it's a sparse matrix
+        transformed_data = transformed_data.toarray()  # type: ignore[attr-defined]
+    return pd.DataFrame(transformed_data, columns=output_features_names)
+
+
 @final
 class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]]):
     """Handler for scikit-learn based model.
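The helper applies every pipeline step except the final estimator, so SHAP sees the features the estimator was actually fit on. The same mechanics on a plain scikit-learn pipeline, runnable as-is:

```python
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

pipe = Pipeline([("scale", StandardScaler()), ("clf", LogisticRegression())])
X = pd.DataFrame({"a": [0.0, 1.0, 2.0], "b": [1.0, 1.0, 3.0]})
pipe.fit(X, [0, 1, 1])

# Transform up to (but not including) the last step, propagating column
# names via get_feature_names_out, as _apply_transforms_up_to_last_step does.
data, names = X, list(X.columns)
for _, step in pipe.steps[:-1]:
    data = step.transform(data)
    names = list(step.get_feature_names_out(names))
print(pd.DataFrame(data, columns=names))
```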
@@ -58,7 +87,9 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]]):
         "decision_function",
         "score_samples",
     ]
-
+
+    # Prioritize predict_proba as it gives multi-class probabilities
+    EXPLAIN_TARGET_METHODS = ["predict_proba", "predict", "predict_log_proba"]

     @classmethod
     def can_handle(
@@ -160,17 +191,38 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]]):
                 stacklevel=1,
             )
             enable_explainability = False
-        elif model_meta.task == model_types.Task.UNKNOWN or explain_target_method is None:
+        elif model_meta.task == model_types.Task.UNKNOWN:
+            enable_explainability = False
+        elif explain_target_method is None:
             enable_explainability = False
         else:
             enable_explainability = True
         if enable_explainability:
-            explain_target_method = str(explain_target_method)  # mypy complains if we don't cast to str here
-            model_meta = handlers_utils.add_explain_method_signature(
-                model_meta=model_meta,
-                explain_method="explain",
-                target_method=explain_target_method,
+            explain_target_method = str(explain_target_method)  # mypy complains if we don't cast to str here
+
+            input_signature = handlers_utils.get_input_signature(model_meta, explain_target_method)
+            transformed_background_data = _apply_transforms_up_to_last_step(
+                model=model,
+                data=background_data,
+                input_feature_names=[spec.name for spec in input_signature],
             )
+
+            try:
+                model_meta = handlers_utils.add_inferred_explain_method_signature(
+                    model_meta=model_meta,
+                    explain_method="explain",
+                    target_method=explain_target_method,
+                    background_data=background_data,
+                    explain_fn=cls._build_explain_fn(model, background_data, input_signature),
+                    output_feature_names=transformed_background_data.columns,
+                )
+            except Exception:
+                if kwargs.get("enable_explainability", None):
+                    # user explicitly enabled explainability, so we should raise the error
+                    raise ValueError(
+                        "Explainability for this model is not supported. Please set `enable_explainability=False`"
+                    )
+
             handlers_utils.save_background_data(
                 model_blobs_dir_path,
                 cls.EXPLAIN_ARTIFACTS_DIR,
@@ -222,11 +274,13 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]]):
         )

         if enable_explainability:
-            model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
+            model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap>=0.46.0", pip_name="shap")])
             model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP

         model_meta.env.include_if_absent(
-            [model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn")],
+            [
+                model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn"),
+            ],
             check_local_version=True,
         )

@@ -286,37 +340,8 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]]):

         @custom_model.inference_api
         def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
-            import shap
-
-            try:
-                explainer = shap.Explainer(raw_model, background_data)
-                df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
-            except TypeError:
-                try:
-                    dtype_map = {spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in signature.inputs}
-
-                    if isinstance(X, pd.DataFrame):
-                        X = X.astype(dtype_map, copy=False)
-                    if hasattr(raw_model, "predict_proba"):
-                        if isinstance(X, np.ndarray):
-                            explanations = shap.Explainer(
-                                raw_model.predict_proba, background_data.values  # type: ignore[union-attr]
-                            )(X).values
-                        else:
-                            explanations = shap.Explainer(raw_model.predict_proba, background_data)(X).values
-                    elif hasattr(raw_model, "predict"):
-                        if isinstance(X, np.ndarray):
-                            explanations = shap.Explainer(
-                                raw_model.predict, background_data.values  # type: ignore[union-attr]
-                            )(X).values
-                        else:
-                            explanations = shap.Explainer(raw_model.predict, background_data)(X).values
-                    else:
-                        raise ValueError("Missing any supported target method to explain.")
-                    df = handlers_utils.convert_explanations_to_2D_df(raw_model, explanations)
-                except TypeError as e:
-                    raise ValueError(f"Explanation for this model type not supported yet: {str(e)}")
-            return model_signature_utils.rename_pandas_df(df, signature.outputs)
+            fn = cls._build_explain_fn(raw_model, background_data, signature.inputs)
+            return model_signature_utils.rename_pandas_df(fn(X), signature.outputs)

         if target_method == "explain":
             return explain_fn
@@ -339,3 +364,37 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]]):
         skl_model = _SKLModel(custom_model.ModelContext())

         return skl_model
+
+    @classmethod
+    def _build_explain_fn(
+        cls,
+        model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
+        background_data: model_types.SupportedDataType,
+        input_specs: Sequence[model_signature.BaseFeatureSpec],
+    ) -> Callable[[model_types.SupportedDataType], pd.DataFrame]:
+        import shap
+        import sklearn.pipeline
+
+        transformed_bg_data = _apply_transforms_up_to_last_step(model, background_data)
+
+        def explain_fn(data: model_types.SupportedDataType) -> pd.DataFrame:
+            transformed_data = _apply_transforms_up_to_last_step(model, data)
+            predictor = model[-1] if isinstance(model, sklearn.pipeline.Pipeline) else model
+            try:
+                explainer = shap.Explainer(predictor, transformed_bg_data)
+                return handlers_utils.convert_explanations_to_2D_df(model, explainer(transformed_data).values)
+            except TypeError:
+                if isinstance(data, pd.DataFrame):
+                    dtype_map = {spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in input_specs}
+                    transformed_data = _apply_transforms_up_to_last_step(model, data.astype(dtype_map))
+                for explain_target_method in cls.EXPLAIN_TARGET_METHODS:
+                    if not hasattr(predictor, explain_target_method):
+                        continue
+                    explain_target_method_fn = getattr(predictor, explain_target_method)
+                    explanations = shap.Explainer(explain_target_method_fn, transformed_bg_data.values)(
+                        transformed_data.to_numpy()
+                    ).values
+                    return handlers_utils.convert_explanations_to_2D_df(model, explanations)
+                raise ValueError("Missing any supported target method to explain.")
+
+        return explain_fn
snowflake/ml/model/_packager/model_handlers/tensorflow.py

@@ -88,6 +88,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         import tensorflow

         assert isinstance(model, tensorflow.Module)
+        multiple_inputs = kwargs.get("multiple_inputs", False)

         is_keras_model = type_utils.LazyType("keras.Model").isinstance(model)
         is_tf_keras_model = type_utils.LazyType("tf_keras.Model").isinstance(model)

@@ -112,8 +113,6 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
             default_target_methods=default_target_methods,
         )

-        multiple_inputs = kwargs.get("multiple_inputs", False)
-
         if is_keras_model and len(target_methods) > 1:
             raise ValueError("Keras model can only have one target method.")

@@ -198,7 +197,6 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         model_blobs_dir_path: str,
         **kwargs: Unpack[model_types.TensorflowLoadOptions],
     ) -> "tensorflow.Module":
-        os.environ["TF_USE_LEGACY_KERAS"] = "1"
         import tensorflow

         model_blob_path = os.path.join(model_blobs_dir_path, name)

@@ -209,7 +207,12 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         load_path = os.path.join(model_blob_path, model_blob_filename)
         save_format = model_blob_options.get("save_format", "keras_tf")
         if save_format == "keras_tf":
-            m = tensorflow.keras.models.load_model(load_path)
+            if version.parse(tensorflow.keras.__version__) >= version.parse("3.0.0"):
+                import tf_keras
+
+                m = tf_keras.models.load_model(load_path)
+            else:
+                m = tensorflow.keras.models.load_model(load_path)
         else:
             m = tensorflow.saved_model.load(load_path)

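The load-path change accounts for Keras 3: instead of forcing the TF_USE_LEGACY_KERAS environment variable, models saved in the legacy `keras_tf` format are loaded through the `tf_keras` compatibility package whenever the bundled Keras is 3.x. The gate, isolated, assuming `tf_keras` is installed alongside TensorFlow:

```python
import tensorflow
from packaging import version

def load_keras_tf_model(load_path: str):
    # tf_keras is the maintained Keras 2 fork; it keeps the legacy
    # loaders working where Keras 3 cannot read the saved artifact.
    if version.parse(tensorflow.keras.__version__) >= version.parse("3.0.0"):
        import tf_keras
        return tf_keras.models.load_model(load_path)
    return tensorflow.keras.models.load_model(load_path)
```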
snowflake/ml/model/_packager/model_handlers/torchscript.py

@@ -76,6 +76,8 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
         if enable_explainability:
             raise NotImplementedError("Explainability is not supported for Torch Script model.")

+        multiple_inputs = kwargs.get("multiple_inputs", False)
+
         import torch

         assert isinstance(model, torch.jit.ScriptModule)

@@ -87,8 +89,6 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
             default_target_methods=cls.DEFAULT_TARGET_METHODS,
         )

-        multiple_inputs = kwargs.get("multiple_inputs", False)
-
         def get_prediction(
             target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
         ) -> model_types.SupportedLocalDataType:
snowflake/ml/model/_packager/model_handlers/xgboost.py

@@ -144,7 +144,12 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.XGBModel"]]):
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
             path=cls.MODEL_BLOB_FILE_OR_DIR,
-            options=model_meta_schema.XgboostModelBlobOptions({"xgb_estimator_type": model.__class__.__name__}),
+            options=model_meta_schema.XgboostModelBlobOptions(
+                {
+                    "xgb_estimator_type": model.__class__.__name__,
+                    "enable_categorical": getattr(model, "enable_categorical", False),
+                }
+            ),
         )
         model_meta.models[name] = base_meta
         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION

@@ -152,11 +157,6 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.XGBModel"]]):
         model_meta.env.include_if_absent(
             [
                 model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn"),
-            ],
-            check_local_version=True,
-        )
-        model_meta.env.include_if_absent(
-            [
                 model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
             ],
             check_local_version=True,

@@ -190,6 +190,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.XGBModel"]]):
             raise ValueError("Type of XGB estimator is illegal.")
         m = getattr(xgboost, xgb_estimator_type)()
         m.load_model(os.path.join(model_blob_path, model_blob_filename))
+        m.enable_categorical = model_blob_options.get("enable_categorical", False)

         if kwargs.get("use_gpu", False):
             assert type(kwargs.get("use_gpu", False)) == bool

@@ -225,8 +226,16 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.XGBModel"]]):
         ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
             @custom_model.inference_api
             def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
+                enable_categorical = False
+                for col, d_type in X.dtypes.items():
+                    if pd.api.extensions.ExtensionDtype.is_dtype(d_type):
+                        continue
+                    if not np.issubdtype(d_type, np.number):
+                        # categorical columns are converted to numpy's str dtype
+                        X[col] = X[col].astype("category")
+                        enable_categorical = True
                 if isinstance(raw_model, xgboost.Booster):
-                    X = xgboost.DMatrix(X)
+                    X = xgboost.DMatrix(X, enable_categorical=enable_categorical)

                 res = getattr(raw_model, target_method)(X)

snowflake/ml/model/_packager/model_meta/model_meta.py

@@ -65,7 +65,8 @@ def create_model_metadata(
         ext_modules: List of names of modules that need to be pickled with the model. Defaults to None.
         conda_dependencies: List of conda requirements for running the model. Defaults to None.
         pip_requirements: List of pip Python packages requirements for running the model. Defaults to None.
-        artifact_repository_map: A dict mapping from package channel to artifact repository name.
+        artifact_repository_map: A dict mapping from package channel to artifact repository name (e.g.
+            {'pip': 'snowflake.snowpark.pypi_shared_repository'}).
         resource_constraint: Mapping of resource constraint keys and values, e.g. {"architecture": "x86"}.
         target_platforms: List of target platforms to run the model.
         python_version: A string of python version where model is run. Used for user override. If specified as None,
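A usage sketch for the clarified parameter, reusing the repository name from the docstring example; `reg`, `model`, and the package name are assumptions for illustration:

```python
reg.log_model(
    model,
    model_name="my_model",
    version_name="v2",
    pip_requirements=["some-pypi-only-package"],  # illustrative package name
    artifact_repository_map={"pip": "snowflake.snowpark.pypi_shared_repository"},
)
```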
snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py

@@ -6,13 +6,13 @@ REQUIREMENTS = [
     "aiohttp!=4.0.0a0, !=4.0.0a1",
     "anyio>=3.5.0,<5",
     "cachetools>=3.1.1,<6",
-    "cloudpickle>=2.0.0…
+    "cloudpickle>=2.0.0",
     "cryptography",
     "fsspec>=2024.6.1,<2026",
     "importlib_resources>=6.1.1, <7",
     "numpy>=1.23,<2",
     "packaging>=20.9,<25",
-    "pandas>=1.…
+    "pandas>=2.1.4,<3",
     "pyarrow",
     "pydantic>=2.8.2, <3",
     "pyjwt>=2.0.0, <3",

@@ -21,9 +21,10 @@ REQUIREMENTS = [
     "requests",
     "retrying>=1.3.3,<2",
     "s3fs>=2024.6.1,<2026",
-    "scikit-learn…
+    "scikit-learn<1.6",
     "scipy>=1.9,<2",
-    "…
+    "shap>=0.46.0,<1",
+    "snowflake-connector-python>=3.15.0,<4",
     "snowflake-snowpark-python>=1.17.0,<2,!=1.26.0",
     "snowflake.core>=1.0.2,<2",
     "sqlparse>=0.4,<1",
snowflake/ml/model/_signatures/dmatrix_handler.py

@@ -81,8 +81,16 @@ class XGBoostDMatrixHandler(base_handler.BaseDataHandler["xgboost.DMatrix"]):
     ) -> "xgboost.DMatrix":
         import xgboost as xgb

+        enable_categorical = False
+        for col, d_type in df.dtypes.items():
+            if pd.api.extensions.ExtensionDtype.is_dtype(d_type):
+                continue
+            if not np.issubdtype(d_type, np.number):
+                df[col] = df[col].astype("category")
+                enable_categorical = True
+
         if not features:
-            return xgb.DMatrix(df)
+            return xgb.DMatrix(df, enable_categorical=enable_categorical)
         else:
             feature_names = []
             feature_types = []

@@ -95,4 +103,9 @@ class XGBoostDMatrixHandler(base_handler.BaseDataHandler["xgboost.DMatrix"]):
             assert isinstance(feature, core.FeatureSpec), "Invalid feature kind."
             feature_names.append(feature.name)
             feature_types.append(feature._dtype._numpy_type)
-        return xgb.DMatrix(df, feature_names=feature_names, feature_types=feature_types)
+        return xgb.DMatrix(
+            df,
+            feature_names=feature_names,
+            feature_types=feature_types,
+            enable_categorical=enable_categorical,
+        )
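Both the XGBoost handler and this DMatrix converter now run the same dtype scan: plain non-numeric columns become pandas categoricals and `enable_categorical` is forwarded to xgboost. Runnable in isolation:

```python
import numpy as np
import pandas as pd
import xgboost as xgb

df = pd.DataFrame({"num": [1.0, 2.0, 3.0], "color": ["red", "blue", "red"]})

enable_categorical = False
for col, d_type in df.dtypes.items():
    if pd.api.extensions.ExtensionDtype.is_dtype(d_type):
        continue  # already a pandas extension dtype (e.g. an existing category)
    if not np.issubdtype(d_type, np.number):
        df[col] = df[col].astype("category")  # non-numeric numpy dtypes become categorical
        enable_categorical = True

dmat = xgb.DMatrix(df, enable_categorical=enable_categorical)
print(dmat.feature_types)  # the categorical column is typed 'c'
```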