snowflake-ml-python 1.6.2__py3-none-any.whl → 1.6.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -0
- snowflake/cortex/_classify_text.py +2 -2
- snowflake/cortex/_embed_text_1024.py +37 -0
- snowflake/cortex/_embed_text_768.py +37 -0
- snowflake/cortex/_extract_answer.py +2 -2
- snowflake/cortex/_sentiment.py +2 -2
- snowflake/cortex/_summarize.py +2 -2
- snowflake/cortex/_translate.py +2 -2
- snowflake/cortex/_util.py +4 -4
- snowflake/ml/_internal/env_utils.py +5 -5
- snowflake/ml/_internal/exceptions/error_codes.py +2 -0
- snowflake/ml/_internal/utils/db_utils.py +50 -0
- snowflake/ml/_internal/utils/service_logger.py +63 -0
- snowflake/ml/_internal/utils/sql_identifier.py +25 -1
- snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
- snowflake/ml/data/ingestor_utils.py +20 -10
- snowflake/ml/feature_store/access_manager.py +3 -3
- snowflake/ml/feature_store/feature_store.py +19 -2
- snowflake/ml/feature_store/feature_view.py +82 -28
- snowflake/ml/fileset/stage_fs.py +2 -1
- snowflake/ml/lineage/lineage_node.py +7 -2
- snowflake/ml/model/__init__.py +1 -2
- snowflake/ml/model/_client/model/model_version_impl.py +78 -9
- snowflake/ml/model/_client/ops/model_ops.py +89 -7
- snowflake/ml/model/_client/ops/service_ops.py +200 -91
- snowflake/ml/model/_client/service/model_deployment_spec.py +4 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
- snowflake/ml/model/_client/sql/_base.py +5 -0
- snowflake/ml/model/_client/sql/model.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +9 -5
- snowflake/ml/model/_client/sql/service.py +35 -13
- snowflake/ml/model/_model_composer/model_composer.py +11 -41
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +29 -4
- snowflake/ml/model/_packager/model_env/model_env.py +4 -38
- snowflake/ml/model/_packager/model_handlers/_utils.py +106 -32
- snowflake/ml/model/_packager/model_handlers/catboost.py +26 -27
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +3 -3
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +21 -6
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +111 -58
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
- snowflake/ml/model/_packager/model_handlers/sklearn.py +50 -66
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +36 -17
- snowflake/ml/model/_packager/model_handlers/xgboost.py +22 -7
- snowflake/ml/model/_packager/model_meta/model_meta.py +16 -45
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -6
- snowflake/ml/model/_packager/model_packager.py +14 -10
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
- snowflake/ml/model/type_hints.py +11 -152
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +0 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +17 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -0
- snowflake/ml/modeling/cluster/affinity_propagation.py +1 -0
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -0
- snowflake/ml/modeling/cluster/birch.py +1 -0
- snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -0
- snowflake/ml/modeling/cluster/dbscan.py +1 -0
- snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -0
- snowflake/ml/modeling/cluster/k_means.py +1 -0
- snowflake/ml/modeling/cluster/mean_shift.py +1 -0
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -0
- snowflake/ml/modeling/cluster/optics.py +1 -0
- snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -0
- snowflake/ml/modeling/cluster/spectral_clustering.py +1 -0
- snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -0
- snowflake/ml/modeling/compose/column_transformer.py +1 -0
- snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -0
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -0
- snowflake/ml/modeling/covariance/empirical_covariance.py +1 -0
- snowflake/ml/modeling/covariance/graphical_lasso.py +1 -0
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -0
- snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -0
- snowflake/ml/modeling/covariance/min_cov_det.py +1 -0
- snowflake/ml/modeling/covariance/oas.py +1 -0
- snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -0
- snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -0
- snowflake/ml/modeling/decomposition/factor_analysis.py +1 -0
- snowflake/ml/modeling/decomposition/fast_ica.py +1 -0
- snowflake/ml/modeling/decomposition/incremental_pca.py +1 -0
- snowflake/ml/modeling/decomposition/kernel_pca.py +1 -0
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -0
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -0
- snowflake/ml/modeling/decomposition/pca.py +1 -0
- snowflake/ml/modeling/decomposition/sparse_pca.py +1 -0
- snowflake/ml/modeling/decomposition/truncated_svd.py +1 -0
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -0
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -0
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/isolation_forest.py +1 -0
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/voting_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/voting_regressor.py +1 -0
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -0
- snowflake/ml/modeling/feature_selection/select_fdr.py +1 -0
- snowflake/ml/modeling/feature_selection/select_fpr.py +1 -0
- snowflake/ml/modeling/feature_selection/select_fwe.py +1 -0
- snowflake/ml/modeling/feature_selection/select_k_best.py +1 -0
- snowflake/ml/modeling/feature_selection/select_percentile.py +1 -0
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -0
- snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -0
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -0
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -0
- snowflake/ml/modeling/impute/iterative_imputer.py +1 -0
- snowflake/ml/modeling/impute/knn_imputer.py +1 -0
- snowflake/ml/modeling/impute/missing_indicator.py +1 -0
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -0
- snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -0
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -0
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -0
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -0
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -0
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -0
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/ard_regression.py +1 -0
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -0
- snowflake/ml/modeling/linear_model/elastic_net.py +1 -0
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -0
- snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/huber_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/lars.py +1 -0
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -0
- snowflake/ml/modeling/linear_model/lasso.py +1 -0
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -0
- snowflake/ml/modeling/linear_model/lasso_lars.py +1 -0
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -0
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -0
- snowflake/ml/modeling/linear_model/linear_regression.py +1 -0
- snowflake/ml/modeling/linear_model/logistic_regression.py +1 -0
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -0
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -0
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -0
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -0
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -0
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -0
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -0
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/perceptron.py +1 -0
- snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/ridge.py +1 -0
- snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -0
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -0
- snowflake/ml/modeling/linear_model/ridge_cv.py +1 -0
- snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -0
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -0
- snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -0
- snowflake/ml/modeling/manifold/isomap.py +1 -0
- snowflake/ml/modeling/manifold/mds.py +1 -0
- snowflake/ml/modeling/manifold/spectral_embedding.py +1 -0
- snowflake/ml/modeling/manifold/tsne.py +1 -0
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
- snowflake/ml/modeling/metrics/ranking.py +0 -3
- snowflake/ml/modeling/metrics/regression.py +0 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -0
- snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -0
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -0
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -0
- snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -0
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -0
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -0
- snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -0
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -0
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -0
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -0
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -0
- snowflake/ml/modeling/neighbors/kernel_density.py +1 -0
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -0
- snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -0
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -0
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -0
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -0
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -0
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -0
- snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -0
- snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -0
- snowflake/ml/modeling/pipeline/pipeline.py +0 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -0
- snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -0
- snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -0
- snowflake/ml/modeling/svm/linear_svc.py +1 -0
- snowflake/ml/modeling/svm/linear_svr.py +1 -0
- snowflake/ml/modeling/svm/nu_svc.py +1 -0
- snowflake/ml/modeling/svm/nu_svr.py +1 -0
- snowflake/ml/modeling/svm/svc.py +1 -0
- snowflake/ml/modeling/svm/svr.py +1 -0
- snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -0
- snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -0
- snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -0
- snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -0
- snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -0
- snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -0
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -0
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -0
- snowflake/ml/monitoring/_client/model_monitor.py +126 -0
- snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
- snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
- snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
- snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
- snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
- snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
- snowflake/ml/monitoring/entities/output_score_type.py +90 -0
- snowflake/ml/registry/_manager/model_manager.py +4 -4
- snowflake/ml/registry/registry.py +165 -6
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.4.dist-info}/METADATA +30 -9
- {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.4.dist-info}/RECORD +225 -249
- {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.4.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
- snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
- snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
- snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
- snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
- snowflake/ml/_internal/utils/session_token_manager.py +0 -46
- snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
- snowflake/ml/_internal/utils/uri.py +0 -77
- snowflake/ml/model/_api.py +0 -568
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
- snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
- snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
- snowflake/ml/model/_packager/model_handlers/llm.py +0 -269
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
- snowflake/ml/model/deploy_platforms.py +0 -6
- snowflake/ml/model/models/llm.py +0 -106
- snowflake/ml/monitoring/monitor.py +0 -203
- snowflake/ml/registry/_initial_schema.py +0 -142
- snowflake/ml/registry/_schema.py +0 -82
- snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
- snowflake/ml/registry/_schema_version_manager.py +0 -163
- snowflake/ml/registry/model_registry.py +0 -2048
- {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.4.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.4.dist-info}/top_level.txt +0 -0
@@ -1,14 +1,11 @@
|
|
1
|
-
import glob
|
2
1
|
import pathlib
|
3
2
|
import tempfile
|
4
3
|
import uuid
|
5
|
-
import zipfile
|
6
4
|
from types import ModuleType
|
7
5
|
from typing import Any, Dict, List, Optional
|
8
6
|
|
9
7
|
from absl import logging
|
10
8
|
from packaging import requirements
|
11
|
-
from typing_extensions import deprecated
|
12
9
|
|
13
10
|
from snowflake import snowpark
|
14
11
|
from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
|
@@ -92,7 +89,7 @@ class ModelComposer:
|
|
92
89
|
python_version: Optional[str] = None,
|
93
90
|
ext_modules: Optional[List[ModuleType]] = None,
|
94
91
|
code_paths: Optional[List[str]] = None,
|
95
|
-
|
92
|
+
task: model_types.Task = model_types.Task.UNKNOWN,
|
96
93
|
options: Optional[model_types.ModelSaveOption] = None,
|
97
94
|
) -> model_meta.ModelMetadata:
|
98
95
|
if not options:
|
@@ -121,25 +118,20 @@ class ModelComposer:
|
|
121
118
|
python_version=python_version,
|
122
119
|
ext_modules=ext_modules,
|
123
120
|
code_paths=code_paths,
|
124
|
-
|
121
|
+
task=task,
|
125
122
|
options=options,
|
126
123
|
)
|
127
124
|
assert self.packager.meta is not None
|
128
125
|
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
)
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
options=options,
|
139
|
-
data_sources=self._get_data_sources(model, sample_input_data),
|
140
|
-
)
|
141
|
-
else:
|
142
|
-
file_utils.make_archive(self.model_local_path, str(self._packager_workspace_path))
|
126
|
+
file_utils.copytree(
|
127
|
+
str(self._packager_workspace_path), str(self.workspace_path / ModelComposer.MODEL_DIR_REL_PATH)
|
128
|
+
)
|
129
|
+
self.manifest.save(
|
130
|
+
model_meta=self.packager.meta,
|
131
|
+
model_rel_path=pathlib.PurePosixPath(ModelComposer.MODEL_DIR_REL_PATH),
|
132
|
+
options=options,
|
133
|
+
data_sources=self._get_data_sources(model, sample_input_data),
|
134
|
+
)
|
143
135
|
|
144
136
|
file_utils.upload_directory_to_stage(
|
145
137
|
self.session,
|
@@ -149,28 +141,6 @@ class ModelComposer:
|
|
149
141
|
)
|
150
142
|
return model_metadata
|
151
143
|
|
152
|
-
@deprecated("Only used by PrPr model registry. Use static method version of load instead.")
|
153
|
-
def legacy_load(
|
154
|
-
self,
|
155
|
-
*,
|
156
|
-
meta_only: bool = False,
|
157
|
-
options: Optional[model_types.ModelLoadOption] = None,
|
158
|
-
) -> None:
|
159
|
-
file_utils.download_directory_from_stage(
|
160
|
-
self.session,
|
161
|
-
stage_path=self.stage_path,
|
162
|
-
local_path=self.workspace_path,
|
163
|
-
statement_params=self._statement_params,
|
164
|
-
)
|
165
|
-
|
166
|
-
# TODO (Server-side Model Rollout): Remove this section.
|
167
|
-
model_zip_path = pathlib.Path(glob.glob(str(self.workspace_path / "*.zip"))[0])
|
168
|
-
self.model_file_rel_path = str(model_zip_path.relative_to(self.workspace_path))
|
169
|
-
|
170
|
-
with zipfile.ZipFile(self.model_local_path, mode="r", compression=zipfile.ZIP_DEFLATED) as zf:
|
171
|
-
zf.extractall(path=self._packager_workspace_path)
|
172
|
-
self.packager.load(meta_only=meta_only, options=options)
|
173
|
-
|
174
144
|
@staticmethod
|
175
145
|
def load(
|
176
146
|
workspace_path: pathlib.Path,
|
@@ -1,6 +1,7 @@
|
|
1
1
|
import collections
|
2
|
-
import
|
2
|
+
import logging
|
3
3
|
import pathlib
|
4
|
+
import warnings
|
4
5
|
from typing import List, Optional, cast
|
5
6
|
|
6
7
|
import yaml
|
@@ -17,6 +18,9 @@ from snowflake.ml.model._packager.model_meta import (
|
|
17
18
|
model_meta as model_meta_api,
|
18
19
|
model_meta_schema,
|
19
20
|
)
|
21
|
+
from snowflake.ml.model._packager.model_runtime import model_runtime
|
22
|
+
|
23
|
+
logger = logging.getLogger(__name__)
|
20
24
|
|
21
25
|
|
22
26
|
class ModelManifest:
|
@@ -44,9 +48,30 @@ class ModelManifest:
|
|
44
48
|
if options is None:
|
45
49
|
options = {}
|
46
50
|
|
47
|
-
|
48
|
-
|
49
|
-
|
51
|
+
if "relax_version" not in options:
|
52
|
+
warnings.warn(
|
53
|
+
(
|
54
|
+
"`relax_version` is not set and therefore defaulted to True. Dependency version constraints relaxed"
|
55
|
+
" from ==x.y.z to >=x.y, <(x+1). To use specific dependency versions for compatibility, "
|
56
|
+
"reproducibility, etc., set `options={'relax_version': False}` when logging the model."
|
57
|
+
),
|
58
|
+
category=UserWarning,
|
59
|
+
stacklevel=2,
|
60
|
+
)
|
61
|
+
relax_version = options.get("relax_version", True)
|
62
|
+
|
63
|
+
runtime_to_use = model_runtime.ModelRuntime(
|
64
|
+
name=self._DEFAULT_RUNTIME_NAME,
|
65
|
+
env=model_meta.env,
|
66
|
+
imports=[str(model_rel_path) + "/"],
|
67
|
+
is_gpu=False,
|
68
|
+
is_warehouse=True,
|
69
|
+
)
|
70
|
+
if relax_version:
|
71
|
+
runtime_to_use.runtime_env.relax_version()
|
72
|
+
logger.info("Relaxing version constraints for dependencies in the model.")
|
73
|
+
logger.info(f"Conda dependencies: {runtime_to_use.runtime_env.conda_dependencies}")
|
74
|
+
logger.info(f"Pip requirements: {runtime_to_use.runtime_env.pip_requirements}")
|
50
75
|
runtime_dict = runtime_to_use.save(
|
51
76
|
self.workspace_path, default_channel_override=env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
|
52
77
|
)
|
@@ -21,7 +21,7 @@ _DEFAULT_PIP_REQUIREMENTS_FILENAME = "requirements.txt"
|
|
21
21
|
# The default CUDA version is chosen based on the driver availability in SPCS.
|
22
22
|
# If changing this version, we need also change the version of default PyTorch in HuggingFace pipeline handler to
|
23
23
|
# make sure they are compatible.
|
24
|
-
DEFAULT_CUDA_VERSION = "11.
|
24
|
+
DEFAULT_CUDA_VERSION = "11.8"
|
25
25
|
|
26
26
|
|
27
27
|
class ModelEnv:
|
@@ -199,50 +199,16 @@ class ModelEnv:
|
|
199
199
|
)
|
200
200
|
if xgboost_spec:
|
201
201
|
self.include_if_absent(
|
202
|
-
[
|
203
|
-
ModelDependency(
|
204
|
-
requirement=f"conda-forge::py-xgboost-gpu{xgboost_spec.specifier}", pip_name="xgboost"
|
205
|
-
)
|
206
|
-
],
|
202
|
+
[ModelDependency(requirement=f"py-xgboost-gpu{xgboost_spec.specifier}", pip_name="xgboost")],
|
207
203
|
check_local_version=False,
|
208
204
|
)
|
209
205
|
|
210
|
-
pytorch_spec = env_utils.find_dep_spec(
|
211
|
-
self._conda_dependencies,
|
212
|
-
self._pip_requirements,
|
213
|
-
conda_pkg_name="pytorch",
|
214
|
-
pip_pkg_name="torch",
|
215
|
-
remove_spec=True,
|
216
|
-
)
|
217
|
-
pytorch_cuda_spec = env_utils.find_dep_spec(
|
218
|
-
self._conda_dependencies,
|
219
|
-
self._pip_requirements,
|
220
|
-
conda_pkg_name="pytorch-cuda",
|
221
|
-
remove_spec=False,
|
222
|
-
)
|
223
|
-
if pytorch_cuda_spec and not pytorch_cuda_spec.specifier.contains(self.cuda_version):
|
224
|
-
raise ValueError(
|
225
|
-
"The Pytorch-CUDA requirement you specified in your conda dependencies or pip requirements is"
|
226
|
-
" conflicting with CUDA version required. Please do not specify Pytorch-CUDA dependency using conda"
|
227
|
-
" dependencies or pip requirements."
|
228
|
-
)
|
229
|
-
if pytorch_spec:
|
230
|
-
self.include_if_absent(
|
231
|
-
[ModelDependency(requirement=f"pytorch::pytorch{pytorch_spec.specifier}", pip_name="torch")],
|
232
|
-
check_local_version=False,
|
233
|
-
)
|
234
|
-
if not pytorch_cuda_spec:
|
235
|
-
self.include_if_absent(
|
236
|
-
[ModelDependency(requirement=f"pytorch::pytorch-cuda=={self.cuda_version}.*", pip_name="torch")],
|
237
|
-
check_local_version=False,
|
238
|
-
)
|
239
|
-
|
240
206
|
tf_spec = env_utils.find_dep_spec(
|
241
207
|
self._conda_dependencies, self._pip_requirements, conda_pkg_name="tensorflow", remove_spec=True
|
242
208
|
)
|
243
209
|
if tf_spec:
|
244
210
|
self.include_if_absent(
|
245
|
-
[ModelDependency(requirement=f"
|
211
|
+
[ModelDependency(requirement=f"tensorflow-gpu{tf_spec.specifier}", pip_name="tensorflow")],
|
246
212
|
check_local_version=False,
|
247
213
|
)
|
248
214
|
|
@@ -252,7 +218,7 @@ class ModelEnv:
|
|
252
218
|
if transformers_spec:
|
253
219
|
self.include_if_absent(
|
254
220
|
[
|
255
|
-
ModelDependency(requirement="
|
221
|
+
ModelDependency(requirement="accelerate>=0.22.0", pip_name="accelerate"),
|
256
222
|
ModelDependency(requirement="scipy>=1.9", pip_name="scipy"),
|
257
223
|
],
|
258
224
|
check_local_version=False,
|
@@ -1,17 +1,26 @@
|
|
1
1
|
import json
|
2
|
+
import os
|
2
3
|
import warnings
|
3
|
-
from typing import Any, Callable, Iterable, Optional, Sequence, cast
|
4
|
+
from typing import Any, Callable, Iterable, List, Optional, Sequence, cast
|
4
5
|
|
5
6
|
import numpy as np
|
6
7
|
import numpy.typing as npt
|
7
8
|
import pandas as pd
|
8
9
|
from absl import logging
|
9
10
|
|
11
|
+
import snowflake.snowpark.dataframe as sp_df
|
12
|
+
from snowflake.ml._internal.utils import identifier
|
10
13
|
from snowflake.ml.model import model_signature, type_hints as model_types
|
11
14
|
from snowflake.ml.model._packager.model_meta import model_meta
|
12
|
-
from snowflake.ml.model._signatures import
|
15
|
+
from snowflake.ml.model._signatures import (
|
16
|
+
core,
|
17
|
+
snowpark_handler,
|
18
|
+
utils as model_signature_utils,
|
19
|
+
)
|
13
20
|
from snowflake.snowpark import DataFrame as SnowparkDataFrame
|
14
21
|
|
22
|
+
EXPLAIN_BACKGROUND_DATA_ROWS_COUNT_LIMIT = 1000
|
23
|
+
|
15
24
|
|
16
25
|
class NumpyEncoder(json.JSONEncoder):
|
17
26
|
def default(self, obj: Any) -> Any:
|
@@ -28,6 +37,18 @@ def _is_callable(model: model_types.SupportedModelType, method_name: str) -> boo
|
|
28
37
|
return callable(getattr(model, method_name, None))
|
29
38
|
|
30
39
|
|
40
|
+
def get_truncated_sample_data(sample_input_data: model_types.SupportedDataType) -> model_types.SupportedLocalDataType:
|
41
|
+
trunc_sample_input = model_signature._truncate_data(sample_input_data)
|
42
|
+
local_sample_input: model_types.SupportedLocalDataType = None
|
43
|
+
if isinstance(sample_input_data, SnowparkDataFrame):
|
44
|
+
# Added because of Any from missing stubs.
|
45
|
+
trunc_sample_input = cast(SnowparkDataFrame, trunc_sample_input)
|
46
|
+
local_sample_input = snowpark_handler.SnowparkDataFrameHandler.convert_to_df(trunc_sample_input)
|
47
|
+
else:
|
48
|
+
local_sample_input = trunc_sample_input
|
49
|
+
return local_sample_input
|
50
|
+
|
51
|
+
|
31
52
|
def validate_signature(
|
32
53
|
model: model_types.SupportedRequireSignatureModelType,
|
33
54
|
model_meta: model_meta.ModelMetadata,
|
@@ -37,19 +58,23 @@ def validate_signature(
|
|
37
58
|
) -> model_meta.ModelMetadata:
|
38
59
|
if model_meta.signatures:
|
39
60
|
validate_target_methods(model, list(model_meta.signatures.keys()))
|
61
|
+
if sample_input_data is not None:
|
62
|
+
local_sample_input = get_truncated_sample_data(sample_input_data)
|
63
|
+
for target_method in model_meta.signatures.keys():
|
64
|
+
|
65
|
+
model_signature_inst = model_meta.signatures.get(target_method)
|
66
|
+
if model_signature_inst is not None:
|
67
|
+
# strict validation the input signature
|
68
|
+
model_signature._convert_and_validate_local_data(
|
69
|
+
local_sample_input, model_signature_inst._inputs, True
|
70
|
+
)
|
40
71
|
return model_meta
|
41
72
|
|
42
73
|
# In this case sample_input_data should be available, because of the check in save_model.
|
43
74
|
assert (
|
44
75
|
sample_input_data is not None
|
45
76
|
), "Model signature and sample input are None at the same time. This should not happen with local model."
|
46
|
-
|
47
|
-
if isinstance(sample_input_data, SnowparkDataFrame):
|
48
|
-
# Added because of Any from missing stubs.
|
49
|
-
trunc_sample_input = cast(SnowparkDataFrame, trunc_sample_input)
|
50
|
-
local_sample_input = snowpark_handler.SnowparkDataFrameHandler.convert_to_df(trunc_sample_input)
|
51
|
-
else:
|
52
|
-
local_sample_input = trunc_sample_input
|
77
|
+
local_sample_input = get_truncated_sample_data(sample_input_data)
|
53
78
|
for target_method in target_methods:
|
54
79
|
predictions_df = get_prediction_fn(target_method, local_sample_input)
|
55
80
|
sig = model_signature.infer_signature(local_sample_input, predictions_df)
|
@@ -58,24 +83,55 @@ def validate_signature(
|
|
58
83
|
return model_meta
|
59
84
|
|
60
85
|
|
86
|
+
def get_input_signature(
|
87
|
+
model_meta: model_meta.ModelMetadata, target_method: Optional[str]
|
88
|
+
) -> Sequence[core.BaseFeatureSpec]:
|
89
|
+
if target_method is None or target_method not in model_meta.signatures:
|
90
|
+
raise ValueError(f"Signature for target method {target_method} is missing or no method to explain.")
|
91
|
+
input_sig = model_meta.signatures[target_method].inputs
|
92
|
+
return input_sig
|
93
|
+
|
94
|
+
|
61
95
|
def add_explain_method_signature(
|
62
96
|
model_meta: model_meta.ModelMetadata,
|
63
97
|
explain_method: str,
|
64
|
-
target_method: str,
|
98
|
+
target_method: Optional[str],
|
65
99
|
output_return_type: model_signature.DataType = model_signature.DataType.DOUBLE,
|
66
100
|
) -> model_meta.ModelMetadata:
|
67
|
-
|
68
|
-
|
69
|
-
|
101
|
+
inputs = get_input_signature(model_meta, target_method)
|
102
|
+
if model_meta.model_type == "snowml":
|
103
|
+
output_feature_names = [identifier.concat_names([spec.name, "_explanation"]) for spec in inputs]
|
104
|
+
else:
|
105
|
+
output_feature_names = [f"{spec.name}_explanation" for spec in inputs]
|
70
106
|
model_meta.signatures[explain_method] = model_signature.ModelSignature(
|
71
107
|
inputs=inputs,
|
72
108
|
outputs=[
|
73
|
-
model_signature.FeatureSpec(dtype=output_return_type, name=
|
109
|
+
model_signature.FeatureSpec(dtype=output_return_type, name=output_name)
|
110
|
+
for output_name in output_feature_names
|
74
111
|
],
|
75
112
|
)
|
76
113
|
return model_meta
|
77
114
|
|
78
115
|
|
116
|
+
def get_explainability_supported_background(
|
117
|
+
sample_input_data: Optional[model_types.SupportedDataType],
|
118
|
+
meta: model_meta.ModelMetadata,
|
119
|
+
explain_target_method: Optional[str],
|
120
|
+
) -> pd.DataFrame:
|
121
|
+
if sample_input_data is None:
|
122
|
+
return None
|
123
|
+
|
124
|
+
if isinstance(sample_input_data, pd.DataFrame):
|
125
|
+
return sample_input_data
|
126
|
+
if isinstance(sample_input_data, sp_df.DataFrame):
|
127
|
+
return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(sample_input_data)
|
128
|
+
|
129
|
+
df = model_signature._convert_local_data_to_df(sample_input_data)
|
130
|
+
input_signature_for_explain = get_input_signature(meta, explain_target_method)
|
131
|
+
df_with_named_cols = model_signature_utils.rename_pandas_df(df, input_signature_for_explain)
|
132
|
+
return df_with_named_cols
|
133
|
+
|
134
|
+
|
79
135
|
def get_target_methods(
|
80
136
|
model: model_types.SupportedModelType,
|
81
137
|
target_methods: Optional[Sequence[str]],
|
@@ -88,6 +144,23 @@ def get_target_methods(
|
|
88
144
|
return target_methods
|
89
145
|
|
90
146
|
|
147
|
+
def save_background_data(
|
148
|
+
model_blobs_dir_path: str,
|
149
|
+
explain_artifact_dir: str,
|
150
|
+
bg_data_file_suffix: str,
|
151
|
+
model_name: str,
|
152
|
+
background_data: pd.DataFrame,
|
153
|
+
) -> None:
|
154
|
+
data_blob_path = os.path.join(model_blobs_dir_path, explain_artifact_dir)
|
155
|
+
os.makedirs(data_blob_path, exist_ok=True)
|
156
|
+
with open(os.path.join(data_blob_path, model_name + bg_data_file_suffix), "wb") as f:
|
157
|
+
# saving only the truncated data
|
158
|
+
trunc_background_data = background_data.head(
|
159
|
+
min(len(background_data.index), EXPLAIN_BACKGROUND_DATA_ROWS_COUNT_LIMIT)
|
160
|
+
)
|
161
|
+
trunc_background_data.to_parquet(f)
|
162
|
+
|
163
|
+
|
91
164
|
def validate_target_methods(model: model_types.SupportedModelType, target_methods: Iterable[str]) -> None:
|
92
165
|
for method_name in target_methods:
|
93
166
|
if not _is_callable(model, method_name):
|
@@ -123,25 +196,26 @@ def convert_explanations_to_2D_df(
|
|
123
196
|
return pd.DataFrame(exp_2d)
|
124
197
|
|
125
198
|
|
126
|
-
def
|
127
|
-
|
128
|
-
|
129
|
-
if (
|
130
|
-
passed_model_objective != model_types.ModelObjective.UNKNOWN
|
131
|
-
and inferred_model_objective != model_types.ModelObjective.UNKNOWN
|
132
|
-
):
|
133
|
-
if passed_model_objective != inferred_model_objective:
|
199
|
+
def validate_model_task(passed_model_task: model_types.Task, inferred_model_task: model_types.Task) -> model_types.Task:
    """Reconcile the task passed by the caller with the task inferred from the model.

    An inferred task other than ``Task.UNKNOWN`` always wins. If the caller also passed a
    task other than ``Task.UNKNOWN`` and it differs, a ``UserWarning`` reports that the
    passed value is ignored. When nothing could be inferred, the passed task is returned.

    Args:
        passed_model_task: Task supplied by the caller (may be ``Task.UNKNOWN``).
        inferred_model_task: Task inferred from the model (may be ``Task.UNKNOWN``).

    Returns:
        The task to record for this model version.
    """
    unknown_task = model_types.Task.UNKNOWN
    passed_is_known = passed_model_task != unknown_task
    inferred_is_known = inferred_model_task != unknown_task

    # Nothing was inferred: fall back to whatever the caller passed.
    if not inferred_is_known:
        return passed_model_task

    if not passed_is_known:
        logging.info(f"Inferred Task: {inferred_model_task.name} is used as task for this model version")
        return inferred_model_task

    # Both are known: the inferred task is used; warn only on disagreement.
    if passed_model_task != inferred_model_task:
        warnings.warn(
            f"Inferred Task: {inferred_model_task.name} is used as task for this model "
            f"version and passed argument Task: {passed_model_task.name} is ignored",
            category=UserWarning,
            stacklevel=1,
        )
    return inferred_model_task
|
213
|
+
|
214
|
+
|
215
|
+
def get_explain_target_method(
    model_metadata: model_meta.ModelMetadata, target_methods_list: List[str]
) -> Optional[str]:
    """Pick the method that the explain function should target.

    Walks the method names in ``model_metadata.signatures`` in their iteration order and
    returns the first one that also appears in ``target_methods_list``.

    Args:
        model_metadata: Model metadata whose ``signatures`` are keyed by method name.
        target_methods_list: Method names that are acceptable explain targets.

    Returns:
        The first matching method name, or None when no signature matches.
    """
    # NOTE(review): assumes ``signatures`` is a dict-like mapping, so iterating it yields its keys.
    matching_methods = (
        method_name for method_name in model_metadata.signatures if method_name in target_methods_list
    )
    return next(matching_methods, None)
|
@@ -1,4 +1,5 @@
|
|
1
1
|
import os
|
2
|
+
import warnings
|
2
3
|
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast, final
|
3
4
|
|
4
5
|
import numpy as np
|
@@ -8,7 +9,11 @@ from typing_extensions import TypeGuard, Unpack
|
|
8
9
|
from snowflake.ml._internal import type_utils
|
9
10
|
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
10
11
|
from snowflake.ml.model._packager.model_env import model_env
|
11
|
-
from snowflake.ml.model._packager.model_handlers import
|
12
|
+
from snowflake.ml.model._packager.model_handlers import (
|
13
|
+
_base,
|
14
|
+
_utils as handlers_utils,
|
15
|
+
model_objective_utils,
|
16
|
+
)
|
12
17
|
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
13
18
|
from snowflake.ml.model._packager.model_meta import (
|
14
19
|
model_blob_meta,
|
@@ -32,22 +37,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
|
|
32
37
|
|
33
38
|
MODEL_BLOB_FILE_OR_DIR = "model.bin"
|
34
39
|
DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
|
35
|
-
|
36
|
-
@classmethod
|
37
|
-
def get_model_objective_and_output_type(cls, model: "catboost.CatBoost") -> model_types.ModelObjective:
|
38
|
-
import catboost
|
39
|
-
|
40
|
-
if isinstance(model, catboost.CatBoostClassifier):
|
41
|
-
num_classes = handlers_utils.get_num_classes_if_exists(model)
|
42
|
-
if num_classes == 2:
|
43
|
-
return model_types.ModelObjective.BINARY_CLASSIFICATION
|
44
|
-
return model_types.ModelObjective.MULTI_CLASSIFICATION
|
45
|
-
if isinstance(model, catboost.CatBoostRanker):
|
46
|
-
return model_types.ModelObjective.RANKING
|
47
|
-
if isinstance(model, catboost.CatBoostRegressor):
|
48
|
-
return model_types.ModelObjective.REGRESSION
|
49
|
-
# TODO: Find out model type from the generic Catboost Model
|
50
|
-
return model_types.ModelObjective.UNKNOWN
|
40
|
+
EXPLAIN_TARGET_METHODS = ["predict", "predict_proba"]
|
51
41
|
|
52
42
|
@classmethod
|
53
43
|
def can_handle(cls, model: model_types.SupportedModelType) -> TypeGuard["catboost.CatBoost"]:
|
@@ -107,25 +97,34 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
|
|
107
97
|
sample_input_data=sample_input_data,
|
108
98
|
get_prediction_fn=get_prediction,
|
109
99
|
)
|
110
|
-
|
111
|
-
model_meta.
|
112
|
-
model_meta.model_objective, inferred_model_objective
|
113
|
-
)
|
114
|
-
model_objective = model_meta.model_objective
|
100
|
+
model_task_and_output = model_objective_utils.get_model_task_and_output_type(model)
|
101
|
+
model_meta.task = model_task_and_output.task
|
115
102
|
if enable_explainability:
|
116
|
-
|
117
|
-
if model_objective == model_types.ModelObjective.MULTI_CLASSIFICATION:
|
118
|
-
output_type = model_signature.DataType.STRING
|
103
|
+
explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
|
119
104
|
model_meta = handlers_utils.add_explain_method_signature(
|
120
105
|
model_meta=model_meta,
|
121
106
|
explain_method="explain",
|
122
|
-
target_method=
|
123
|
-
output_return_type=output_type,
|
107
|
+
target_method=explain_target_method,
|
108
|
+
output_return_type=model_task_and_output.output_type,
|
124
109
|
)
|
125
110
|
model_meta.function_properties = {
|
126
111
|
"explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
|
127
112
|
}
|
128
113
|
|
114
|
+
background_data = handlers_utils.get_explainability_supported_background(
|
115
|
+
sample_input_data, model_meta, explain_target_method
|
116
|
+
)
|
117
|
+
if background_data is not None:
|
118
|
+
handlers_utils.save_background_data(
|
119
|
+
model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, cls.BG_DATA_FILE_SUFFIX, name, background_data
|
120
|
+
)
|
121
|
+
else:
|
122
|
+
warnings.warn(
|
123
|
+
"sample_input_data should be provided for better explainability results",
|
124
|
+
category=UserWarning,
|
125
|
+
stacklevel=1,
|
126
|
+
)
|
127
|
+
|
129
128
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
130
129
|
os.makedirs(model_blob_path, exist_ok=True)
|
131
130
|
model_save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
|
@@ -242,10 +242,10 @@ class HuggingFacePipelineHandler(
|
|
242
242
|
task, spcs_only=(not type_utils.LazyType("transformers.Pipeline").isinstance(model))
|
243
243
|
)
|
244
244
|
if framework is None or framework == "pt":
|
245
|
-
# Since we set default cuda version to be 11.
|
246
|
-
# Pytorch version that works with CUDA 11.
|
245
|
+
# Since we set default cuda version to be 11.8, to make sure it works with GPU, we need to have a default
|
246
|
+
# Pytorch version that works with CUDA 11.8 as well. This is required for huggingface pipelines only as
|
247
247
|
# users are not required to install pytorch locally if they are using the wrapper.
|
248
|
-
pkgs_requirements.append(model_env.ModelDependency(requirement="pytorch
|
248
|
+
pkgs_requirements.append(model_env.ModelDependency(requirement="pytorch", pip_name="torch"))
|
249
249
|
elif framework == "tf":
|
250
250
|
pkgs_requirements.append(model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow"))
|
251
251
|
model_meta.env.include_if_absent(
|
@@ -1,4 +1,5 @@
|
|
1
1
|
import os
|
2
|
+
import warnings
|
2
3
|
from typing import (
|
3
4
|
TYPE_CHECKING,
|
4
5
|
Any,
|
@@ -47,6 +48,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
47
48
|
|
48
49
|
MODEL_BLOB_FILE_OR_DIR = "model.pkl"
|
49
50
|
DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
|
51
|
+
EXPLAIN_TARGET_METHODS = ["predict", "predict_proba"]
|
50
52
|
|
51
53
|
@classmethod
|
52
54
|
def can_handle(
|
@@ -111,21 +113,34 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
111
113
|
sample_input_data=sample_input_data,
|
112
114
|
get_prediction_fn=get_prediction,
|
113
115
|
)
|
114
|
-
|
115
|
-
model_meta.
|
116
|
-
model_meta.model_objective, model_objective_and_output.objective
|
117
|
-
)
|
116
|
+
model_task_and_output = model_objective_utils.get_model_task_and_output_type(model)
|
117
|
+
model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output.task)
|
118
118
|
if enable_explainability:
|
119
|
+
explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
|
119
120
|
model_meta = handlers_utils.add_explain_method_signature(
|
120
121
|
model_meta=model_meta,
|
121
122
|
explain_method="explain",
|
122
|
-
target_method=
|
123
|
-
output_return_type=
|
123
|
+
target_method=explain_target_method,
|
124
|
+
output_return_type=model_task_and_output.output_type,
|
124
125
|
)
|
125
126
|
model_meta.function_properties = {
|
126
127
|
"explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
|
127
128
|
}
|
128
129
|
|
130
|
+
background_data = handlers_utils.get_explainability_supported_background(
|
131
|
+
sample_input_data, model_meta, explain_target_method
|
132
|
+
)
|
133
|
+
if background_data is not None:
|
134
|
+
handlers_utils.save_background_data(
|
135
|
+
model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, cls.BG_DATA_FILE_SUFFIX, name, background_data
|
136
|
+
)
|
137
|
+
else:
|
138
|
+
warnings.warn(
|
139
|
+
"sample_input_data should be provided for better explainability results",
|
140
|
+
category=UserWarning,
|
141
|
+
stacklevel=1,
|
142
|
+
)
|
143
|
+
|
129
144
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
130
145
|
os.makedirs(model_blob_path, exist_ok=True)
|
131
146
|
|
@@ -168,11 +168,6 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
|
|
168
168
|
) -> "mlflow.pyfunc.PyFuncModel":
|
169
169
|
import mlflow
|
170
170
|
|
171
|
-
if snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
|
172
|
-
# We need to redirect the mlruns folder to a writable location in the sandbox.
|
173
|
-
tmpdir = tempfile.TemporaryDirectory(dir="/tmp")
|
174
|
-
mlflow.set_tracking_uri(f"file://{tmpdir}")
|
175
|
-
|
176
171
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
177
172
|
model_blobs_metadata = model_meta.models
|
178
173
|
model_blob_metadata = model_blobs_metadata[name]
|
@@ -183,6 +178,9 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
|
|
183
178
|
model_artifact_path = model_blob_options["artifact_path"]
|
184
179
|
model_blob_filename = model_blob_metadata.path
|
185
180
|
|
181
|
+
if snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
|
182
|
+
return mlflow.pyfunc.load_model(os.path.join(model_blob_path, model_blob_filename, model_artifact_path))
|
183
|
+
|
186
184
|
# This is to make sure the loaded model can be saved again.
|
187
185
|
with mlflow.start_run() as run:
|
188
186
|
mlflow.log_artifacts(
|