snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +7 -33
- snowflake/ml/_internal/env_utils.py +11 -5
- snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
- snowflake/ml/_internal/telemetry.py +156 -20
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +1 -1
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/_internal/utils/uri.py +2 -2
- snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
- snowflake/ml/data/data_connector.py +88 -9
- snowflake/ml/data/data_ingestor.py +18 -1
- snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/dataset/dataset_metadata.py +3 -1
- snowflake/ml/dataset/dataset_reader.py +9 -3
- snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
- snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
- snowflake/ml/feature_store/examples/example_helper.py +69 -31
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
- snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
- snowflake/ml/feature_store/feature_store.py +100 -41
- snowflake/ml/feature_store/feature_view.py +149 -5
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/model/_client/model/model_impl.py +11 -2
- snowflake/ml/model/_client/model/model_version_impl.py +186 -20
- snowflake/ml/model/_client/ops/model_ops.py +144 -30
- snowflake/ml/model/_client/ops/service_ops.py +312 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
- snowflake/ml/model/_client/sql/model_version.py +13 -4
- snowflake/ml/model/_client/sql/service.py +196 -0
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
- snowflake/ml/model/_model_composer/model_composer.py +5 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
- snowflake/ml/model/_packager/model_env/model_env.py +7 -2
- snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
- snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
- snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
- snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
- snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
- snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
- snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
- snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
- snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
- snowflake/ml/model/_packager/model_packager.py +4 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/models/llm.py +3 -1
- snowflake/ml/model/type_hints.py +10 -4
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
- snowflake/ml/modeling/cluster/birch.py +60 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
- snowflake/ml/modeling/cluster/dbscan.py +60 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
- snowflake/ml/modeling/cluster/k_means.py +60 -21
- snowflake/ml/modeling/cluster/mean_shift.py +60 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
- snowflake/ml/modeling/cluster/optics.py +60 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
- snowflake/ml/modeling/compose/column_transformer.py +60 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
- snowflake/ml/modeling/covariance/oas.py +60 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/pca.py +60 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
- snowflake/ml/modeling/framework/base.py +28 -19
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
- snowflake/ml/modeling/impute/knn_imputer.py +60 -21
- snowflake/ml/modeling/impute/missing_indicator.py +60 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/lars.py +60 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/perceptron.py +60 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ridge.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
- snowflake/ml/modeling/manifold/isomap.py +60 -21
- snowflake/ml/modeling/manifold/mds.py +60 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
- snowflake/ml/modeling/manifold/tsne.py +60 -21
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +4 -12
- snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
- snowflake/ml/modeling/svm/linear_svc.py +60 -21
- snowflake/ml/modeling/svm/linear_svr.py +60 -21
- snowflake/ml/modeling/svm/nu_svc.py +60 -21
- snowflake/ml/modeling/svm/nu_svr.py +60 -21
- snowflake/ml/modeling/svm/svc.py +60 -21
- snowflake/ml/modeling/svm/svr.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
- snowflake/ml/registry/_manager/model_manager.py +20 -2
- snowflake/ml/registry/model_registry.py +1 -1
- snowflake/ml/registry/registry.py +1 -2
- snowflake/ml/utils/sql_client.py +22 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -280,7 +280,7 @@ def _get_or_create_image_repo(session: Session, *, service_func_name: str, image
|
|
280
280
|
conn = session._conn._conn
|
281
281
|
# We try to use the same db and schema as the service function locates, as we could retrieve those information
|
282
282
|
# if that is a fully qualified one. If not we use the current session one.
|
283
|
-
(_db, _schema, _
|
283
|
+
(_db, _schema, _) = identifier.parse_schema_level_object_identifier(service_func_name)
|
284
284
|
db = _db if _db is not None else conn._database
|
285
285
|
schema = _schema if _schema is not None else conn._schema
|
286
286
|
assert isinstance(db, str) and isinstance(schema, str)
|
@@ -343,7 +343,7 @@ class SnowServiceDeployment:
|
|
343
343
|
self.model_zip_stage_path = model_zip_stage_path
|
344
344
|
self.options = options
|
345
345
|
self.target_method = target_method
|
346
|
-
(db, schema, _
|
346
|
+
(db, schema, _) = identifier.parse_schema_level_object_identifier(service_func_name)
|
347
347
|
|
348
348
|
self._service_name = identifier.get_schema_level_object_identifier(db, schema, f"service_{model_id}")
|
349
349
|
self._job_name = identifier.get_schema_level_object_identifier(db, schema, f"build_{model_id}")
|
@@ -503,7 +503,7 @@ class SnowServiceDeployment:
|
|
503
503
|
norm_stage_path = posixpath.normpath(identifier.remove_prefix(self.model_zip_stage_path, "@"))
|
504
504
|
# Ensure model stage path has root prefix as stage mount will it mount it to root.
|
505
505
|
absolute_model_stage_path = os.path.join("/", norm_stage_path)
|
506
|
-
(db, schema, stage, path) = identifier.
|
506
|
+
(db, schema, stage, path) = identifier.parse_snowflake_stage_path(norm_stage_path)
|
507
507
|
substitutes = {
|
508
508
|
"image": image,
|
509
509
|
"predict_endpoint_name": constants.PREDICT,
|
@@ -10,6 +10,7 @@ from absl import logging
|
|
10
10
|
from packaging import requirements
|
11
11
|
from typing_extensions import deprecated
|
12
12
|
|
13
|
+
from snowflake import snowpark
|
13
14
|
from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
|
14
15
|
from snowflake.ml._internal.lineage import lineage_utils
|
15
16
|
from snowflake.ml.data import data_source
|
@@ -91,6 +92,7 @@ class ModelComposer:
|
|
91
92
|
python_version: Optional[str] = None,
|
92
93
|
ext_modules: Optional[List[ModuleType]] = None,
|
93
94
|
code_paths: Optional[List[str]] = None,
|
95
|
+
model_objective: model_types.ModelObjective = model_types.ModelObjective.UNKNOWN,
|
94
96
|
options: Optional[model_types.ModelSaveOption] = None,
|
95
97
|
) -> model_meta.ModelMetadata:
|
96
98
|
if not options:
|
@@ -119,6 +121,7 @@ class ModelComposer:
|
|
119
121
|
python_version=python_version,
|
120
122
|
ext_modules=ext_modules,
|
121
123
|
code_paths=code_paths,
|
124
|
+
model_objective=model_objective,
|
122
125
|
options=options,
|
123
126
|
)
|
124
127
|
assert self.packager.meta is not None
|
@@ -185,4 +188,6 @@ class ModelComposer:
|
|
185
188
|
data_sources = lineage_utils.get_data_sources(model)
|
186
189
|
if not data_sources and sample_input_data is not None:
|
187
190
|
data_sources = lineage_utils.get_data_sources(sample_input_data)
|
191
|
+
if not data_sources and isinstance(sample_input_data, snowpark.DataFrame):
|
192
|
+
data_sources = [data_source.DataFrameInfo(sample_input_data.queries["queries"][-1])]
|
188
193
|
return data_sources
|
@@ -1,11 +1,11 @@
|
|
1
1
|
import collections
|
2
2
|
import copy
|
3
3
|
import pathlib
|
4
|
-
import warnings
|
5
4
|
from typing import List, Optional, cast
|
6
5
|
|
7
6
|
import yaml
|
8
7
|
|
8
|
+
from snowflake.ml._internal import env_utils
|
9
9
|
from snowflake.ml.data import data_source
|
10
10
|
from snowflake.ml.model import type_hints
|
11
11
|
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
@@ -47,7 +47,9 @@ class ModelManifest:
|
|
47
47
|
runtime_to_use = copy.deepcopy(model_meta.runtimes["cpu"])
|
48
48
|
runtime_to_use.name = self._DEFAULT_RUNTIME_NAME
|
49
49
|
runtime_to_use.imports.append(str(model_rel_path) + "/")
|
50
|
-
runtime_dict = runtime_to_use.save(
|
50
|
+
runtime_dict = runtime_to_use.save(
|
51
|
+
self.workspace_path, default_channel_override=env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
|
52
|
+
)
|
51
53
|
|
52
54
|
self.function_generator = function_generator.FunctionGenerator(model_dir_rel_path=model_rel_path)
|
53
55
|
self.methods: List[model_method.ModelMethod] = []
|
@@ -75,13 +77,9 @@ class ModelManifest:
|
|
75
77
|
)
|
76
78
|
|
77
79
|
dependencies = model_manifest_schema.ModelRuntimeDependenciesDict(conda=runtime_dict["dependencies"]["conda"])
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
"be warehouse-compabible. The model may need to be run in SPCS.",
|
82
|
-
category=UserWarning,
|
83
|
-
stacklevel=1,
|
84
|
-
)
|
80
|
+
|
81
|
+
# We only want to include pip dependencies file if there are any pip requirements.
|
82
|
+
if len(model_meta.env.pip_requirements) > 0:
|
85
83
|
dependencies["pip"] = runtime_dict["dependencies"]["pip"]
|
86
84
|
|
87
85
|
manifest_dict = model_manifest_schema.ModelManifestDict(
|
@@ -137,10 +135,15 @@ class ModelManifest:
|
|
137
135
|
if isinstance(source, data_source.DatasetInfo):
|
138
136
|
result.append(
|
139
137
|
model_manifest_schema.LineageSourceDict(
|
140
|
-
# Currently, we only support lineage from Dataset.
|
141
138
|
type=model_manifest_schema.LineageSourceTypes.DATASET.value,
|
142
139
|
entity=source.fully_qualified_name,
|
143
140
|
version=source.version,
|
144
141
|
)
|
145
142
|
)
|
143
|
+
elif isinstance(source, data_source.DataFrameInfo):
|
144
|
+
result.append(
|
145
|
+
model_manifest_schema.LineageSourceDict(
|
146
|
+
type=model_manifest_schema.LineageSourceTypes.QUERY.value, entity=source.sql
|
147
|
+
)
|
148
|
+
)
|
146
149
|
return result
|
@@ -57,12 +57,14 @@ class ModelFunctionInfo(TypedDict):
|
|
57
57
|
target_method: actual target method name to be called.
|
58
58
|
target_method_function_type: target method function type (FUNCTION or TABLE_FUNCTION).
|
59
59
|
signature: The signature of the model method.
|
60
|
+
is_partitioned: Whether the function is partitioned.
|
60
61
|
"""
|
61
62
|
|
62
63
|
name: Required[str]
|
63
64
|
target_method: Required[str]
|
64
65
|
target_method_function_type: Required[str]
|
65
66
|
signature: Required[model_signature.ModelSignature]
|
67
|
+
is_partitioned: Required[bool]
|
66
68
|
|
67
69
|
|
68
70
|
class ModelFunctionInfoDict(TypedDict):
|
@@ -78,6 +80,7 @@ class SnowparkMLDataDict(TypedDict):
|
|
78
80
|
|
79
81
|
class LineageSourceTypes(enum.Enum):
|
80
82
|
DATASET = "DATASET"
|
83
|
+
QUERY = "QUERY"
|
81
84
|
|
82
85
|
|
83
86
|
class LineageSourceDict(TypedDict):
|
@@ -363,9 +363,14 @@ class ModelEnv:
|
|
363
363
|
self.cuda_version = env_dict.get("cuda_version", None)
|
364
364
|
self.snowpark_ml_version = env_dict["snowpark_ml_version"]
|
365
365
|
|
366
|
-
def save_as_dict(
|
366
|
+
def save_as_dict(
|
367
|
+
self, base_dir: pathlib.Path, default_channel_override: str = env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
|
368
|
+
) -> model_meta_schema.ModelEnvDict:
|
367
369
|
env_utils.save_conda_env_file(
|
368
|
-
pathlib.Path(base_dir / self.conda_env_rel_path),
|
370
|
+
pathlib.Path(base_dir / self.conda_env_rel_path),
|
371
|
+
self._conda_dependencies,
|
372
|
+
self.python_version,
|
373
|
+
default_channel_override=default_channel_override,
|
369
374
|
)
|
370
375
|
env_utils.save_requirements_file(
|
371
376
|
pathlib.Path(base_dir / self.pip_requirements_rel_path), self._pip_requirements
|
@@ -1,7 +1,8 @@
|
|
1
|
+
import os
|
1
2
|
from abc import abstractmethod
|
2
|
-
from enum import Enum
|
3
3
|
from typing import Dict, Generic, Optional, Protocol, Type, final
|
4
4
|
|
5
|
+
import pandas as pd
|
5
6
|
from typing_extensions import TypeGuard, Unpack
|
6
7
|
|
7
8
|
from snowflake.ml.model import custom_model, type_hints as model_types
|
@@ -9,15 +10,6 @@ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
|
9
10
|
from snowflake.ml.model._packager.model_meta import model_meta
|
10
11
|
|
11
12
|
|
12
|
-
class ModelObjective(Enum):
|
13
|
-
# This is not getting stored anywhere as metadata yet so it should be fine to slowly extend it for better coverage
|
14
|
-
UNKNOWN = "unknown"
|
15
|
-
BINARY_CLASSIFICATION = "binary_classification"
|
16
|
-
MULTI_CLASSIFICATION = "multi_classification"
|
17
|
-
REGRESSION = "regression"
|
18
|
-
RANKING = "ranking"
|
19
|
-
|
20
|
-
|
21
13
|
class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
|
22
14
|
HANDLER_TYPE: model_types.SupportedModelHandlerType
|
23
15
|
HANDLER_VERSION: str
|
@@ -106,6 +98,7 @@ class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
|
|
106
98
|
cls,
|
107
99
|
raw_model: model_types._ModelType,
|
108
100
|
model_meta: model_meta.ModelMetadata,
|
101
|
+
background_data: Optional[pd.DataFrame] = None,
|
109
102
|
**kwargs: Unpack[model_types.BaseModelLoadOption],
|
110
103
|
) -> custom_model.CustomModel:
|
111
104
|
"""Create a custom model class wrap for unified interface when being deployed. The predict method will be
|
@@ -114,6 +107,7 @@ class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
|
|
114
107
|
Args:
|
115
108
|
raw_model: original model object,
|
116
109
|
model_meta: The model metadata.
|
110
|
+
background_data: The background data used for the model explanations.
|
117
111
|
kwargs: Options when converting the model.
|
118
112
|
|
119
113
|
Raises:
|
@@ -131,7 +125,8 @@ class BaseModelHandler(Generic[model_types._ModelType], _BaseModelHandlerProtoco
|
|
131
125
|
_MIN_SNOWPARK_ML_VERSION: The minimal version of Snowpark ML library to use the current handler.
|
132
126
|
_HANDLER_MIGRATOR_PLANS: Dict holding handler migrator plans.
|
133
127
|
|
134
|
-
|
128
|
+
MODEL_BLOB_FILE_OR_DIR: Relative path of the model blob file in the model subdir. Default to "model.pkl".
|
129
|
+
BG_DATA_FILE_SUFFIX: Suffix of the background data file. Default to "_background_data.pqt".
|
135
130
|
MODEL_ARTIFACTS_DIR: Relative path of the model artifacts dir in the model subdir. Default to "artifacts"
|
136
131
|
DEFAULT_TARGET_METHODS: Default target methods to be logged if not specified in this kind of model. Default to
|
137
132
|
["predict"]
|
@@ -139,8 +134,10 @@ class BaseModelHandler(Generic[model_types._ModelType], _BaseModelHandlerProtoco
|
|
139
134
|
inputting sample data or model signature. Default to False.
|
140
135
|
"""
|
141
136
|
|
142
|
-
|
137
|
+
MODEL_BLOB_FILE_OR_DIR = "model.pkl"
|
138
|
+
BG_DATA_FILE_SUFFIX = "_background_data.pqt"
|
143
139
|
MODEL_ARTIFACTS_DIR = "artifacts"
|
140
|
+
EXPLAIN_ARTIFACTS_DIR = "explain_artifacts"
|
144
141
|
DEFAULT_TARGET_METHODS = ["predict"]
|
145
142
|
IS_AUTO_SIGNATURE = False
|
146
143
|
|
@@ -169,3 +166,23 @@ class BaseModelHandler(Generic[model_types._ModelType], _BaseModelHandlerProtoco
|
|
169
166
|
model_meta=model_meta,
|
170
167
|
model_blobs_dir_path=model_blobs_dir_path,
|
171
168
|
)
|
169
|
+
|
170
|
+
@classmethod
|
171
|
+
@final
|
172
|
+
def load_background_data(cls, name: str, model_blobs_dir_path: str) -> Optional[pd.DataFrame]:
|
173
|
+
"""Load the model into memory.
|
174
|
+
|
175
|
+
Args:
|
176
|
+
name: Name of the model.
|
177
|
+
model_blobs_dir_path: Directory path to the whole model.
|
178
|
+
|
179
|
+
Returns:
|
180
|
+
Optional[pd.DataFrame], background data as pandas DataFrame, if exists.
|
181
|
+
"""
|
182
|
+
data_blob_path = os.path.join(model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, name + cls.BG_DATA_FILE_SUFFIX)
|
183
|
+
if not os.path.exists(model_blobs_dir_path) or not os.path.isfile(data_blob_path):
|
184
|
+
return None
|
185
|
+
with open(data_blob_path, "rb") as f:
|
186
|
+
background_data = pd.read_parquet(f)
|
187
|
+
|
188
|
+
return background_data
|
@@ -1,9 +1,11 @@
|
|
1
1
|
import json
|
2
|
+
import warnings
|
2
3
|
from typing import Any, Callable, Iterable, Optional, Sequence, cast
|
3
4
|
|
4
5
|
import numpy as np
|
5
6
|
import numpy.typing as npt
|
6
7
|
import pandas as pd
|
8
|
+
from absl import logging
|
7
9
|
|
8
10
|
from snowflake.ml.model import model_signature, type_hints as model_types
|
9
11
|
from snowflake.ml.model._packager.model_meta import model_meta
|
@@ -11,6 +13,17 @@ from snowflake.ml.model._signatures import snowpark_handler
|
|
11
13
|
from snowflake.snowpark import DataFrame as SnowparkDataFrame
|
12
14
|
|
13
15
|
|
16
|
+
class NumpyEncoder(json.JSONEncoder):
|
17
|
+
def default(self, obj: Any) -> Any:
|
18
|
+
if isinstance(obj, np.integer):
|
19
|
+
return int(obj)
|
20
|
+
if isinstance(obj, np.floating):
|
21
|
+
return float(obj)
|
22
|
+
if isinstance(obj, np.ndarray):
|
23
|
+
return obj.tolist()
|
24
|
+
return super().default(obj)
|
25
|
+
|
26
|
+
|
14
27
|
def _is_callable(model: model_types.SupportedModelType, method_name: str) -> bool:
|
15
28
|
return callable(getattr(model, method_name, None))
|
16
29
|
|
@@ -93,23 +106,42 @@ def convert_explanations_to_2D_df(
|
|
93
106
|
return pd.DataFrame(explanations)
|
94
107
|
|
95
108
|
if hasattr(model, "classes_"):
|
96
|
-
classes_list = [cl for cl in model.classes_] # type:ignore[union-attr]
|
109
|
+
classes_list = [str(cl) for cl in model.classes_] # type:ignore[union-attr]
|
97
110
|
len_classes = len(classes_list)
|
98
111
|
if explanations.shape[2] != len_classes:
|
99
112
|
raise ValueError(f"Model has {len_classes} classes but explanations have {explanations.shape[2]}")
|
100
113
|
else:
|
101
|
-
classes_list = [i for i in range(explanations.shape[2])]
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
if isinstance(cl, (int, np.integer)):
|
110
|
-
cl = int(cl)
|
111
|
-
class_explanations[cl] = cl_exp
|
112
|
-
col_list.append(json.dumps(class_explanations))
|
113
|
-
exp_2d.append(col_list)
|
114
|
+
classes_list = [str(i) for i in range(explanations.shape[2])]
|
115
|
+
|
116
|
+
def row_to_dict(row: npt.NDArray[Any]) -> npt.NDArray[Any]:
|
117
|
+
"""Converts a single row to a dictionary."""
|
118
|
+
# convert to object or numpy creates strings of fixed length
|
119
|
+
return np.asarray(json.dumps(dict(zip(classes_list, row)), cls=NumpyEncoder), dtype=object)
|
120
|
+
|
121
|
+
exp_2d = np.apply_along_axis(row_to_dict, -1, explanations)
|
114
122
|
|
115
123
|
return pd.DataFrame(exp_2d)
|
124
|
+
|
125
|
+
|
126
|
+
def validate_model_objective(
|
127
|
+
passed_model_objective: model_types.ModelObjective, inferred_model_objective: model_types.ModelObjective
|
128
|
+
) -> model_types.ModelObjective:
|
129
|
+
if (
|
130
|
+
passed_model_objective != model_types.ModelObjective.UNKNOWN
|
131
|
+
and inferred_model_objective != model_types.ModelObjective.UNKNOWN
|
132
|
+
):
|
133
|
+
if passed_model_objective != inferred_model_objective:
|
134
|
+
warnings.warn(
|
135
|
+
f"Inferred ModelObjective: {inferred_model_objective.name} is used as model objective for this model "
|
136
|
+
f"version and passed argument ModelObjective: {passed_model_objective.name} is ignored",
|
137
|
+
category=UserWarning,
|
138
|
+
stacklevel=1,
|
139
|
+
)
|
140
|
+
return inferred_model_objective
|
141
|
+
elif inferred_model_objective != model_types.ModelObjective.UNKNOWN:
|
142
|
+
logging.info(
|
143
|
+
f"Inferred ModelObjective: {inferred_model_objective.name} is used as model objective for this model "
|
144
|
+
f"version"
|
145
|
+
)
|
146
|
+
return inferred_model_objective
|
147
|
+
return passed_model_objective
|
@@ -30,24 +30,24 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
|
|
30
30
|
_MIN_SNOWPARK_ML_VERSION = "1.3.1"
|
31
31
|
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
|
32
32
|
|
33
|
-
|
33
|
+
MODEL_BLOB_FILE_OR_DIR = "model.bin"
|
34
34
|
DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
|
35
35
|
|
36
36
|
@classmethod
|
37
|
-
def
|
37
|
+
def get_model_objective_and_output_type(cls, model: "catboost.CatBoost") -> model_types.ModelObjective:
|
38
38
|
import catboost
|
39
39
|
|
40
40
|
if isinstance(model, catboost.CatBoostClassifier):
|
41
41
|
num_classes = handlers_utils.get_num_classes_if_exists(model)
|
42
42
|
if num_classes == 2:
|
43
|
-
return
|
44
|
-
return
|
43
|
+
return model_types.ModelObjective.BINARY_CLASSIFICATION
|
44
|
+
return model_types.ModelObjective.MULTI_CLASSIFICATION
|
45
45
|
if isinstance(model, catboost.CatBoostRanker):
|
46
|
-
return
|
46
|
+
return model_types.ModelObjective.RANKING
|
47
47
|
if isinstance(model, catboost.CatBoostRegressor):
|
48
|
-
return
|
48
|
+
return model_types.ModelObjective.REGRESSION
|
49
49
|
# TODO: Find out model type from the generic Catboost Model
|
50
|
-
return
|
50
|
+
return model_types.ModelObjective.UNKNOWN
|
51
51
|
|
52
52
|
@classmethod
|
53
53
|
def can_handle(cls, model: model_types.SupportedModelType) -> TypeGuard["catboost.CatBoost"]:
|
@@ -77,6 +77,8 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
|
|
77
77
|
is_sub_model: Optional[bool] = False,
|
78
78
|
**kwargs: Unpack[model_types.CatBoostModelSaveOptions],
|
79
79
|
) -> None:
|
80
|
+
enable_explainability = kwargs.get("enable_explainability", True)
|
81
|
+
|
80
82
|
import catboost
|
81
83
|
|
82
84
|
assert isinstance(model, catboost.CatBoost)
|
@@ -105,9 +107,14 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
|
|
105
107
|
sample_input_data=sample_input_data,
|
106
108
|
get_prediction_fn=get_prediction,
|
107
109
|
)
|
108
|
-
|
110
|
+
inferred_model_objective = cls.get_model_objective_and_output_type(model)
|
111
|
+
model_meta.model_objective = handlers_utils.validate_model_objective(
|
112
|
+
model_meta.model_objective, inferred_model_objective
|
113
|
+
)
|
114
|
+
model_objective = model_meta.model_objective
|
115
|
+
if enable_explainability:
|
109
116
|
output_type = model_signature.DataType.DOUBLE
|
110
|
-
if
|
117
|
+
if model_objective == model_types.ModelObjective.MULTI_CLASSIFICATION:
|
111
118
|
output_type = model_signature.DataType.STRING
|
112
119
|
model_meta = handlers_utils.add_explain_method_signature(
|
113
120
|
model_meta=model_meta,
|
@@ -115,10 +122,13 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
|
|
115
122
|
target_method="predict",
|
116
123
|
output_return_type=output_type,
|
117
124
|
)
|
125
|
+
model_meta.function_properties = {
|
126
|
+
"explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
|
127
|
+
}
|
118
128
|
|
119
129
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
120
130
|
os.makedirs(model_blob_path, exist_ok=True)
|
121
|
-
model_save_path = os.path.join(model_blob_path, cls.
|
131
|
+
model_save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
|
122
132
|
|
123
133
|
model.save_model(model_save_path)
|
124
134
|
|
@@ -126,7 +136,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
|
|
126
136
|
name=name,
|
127
137
|
model_type=cls.HANDLER_TYPE,
|
128
138
|
handler_version=cls.HANDLER_VERSION,
|
129
|
-
path=cls.
|
139
|
+
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
130
140
|
options=model_meta_schema.CatBoostModelBlobOptions({"catboost_estimator_type": model.__class__.__name__}),
|
131
141
|
)
|
132
142
|
model_meta.models[name] = base_meta
|
@@ -138,11 +148,9 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
|
|
138
148
|
],
|
139
149
|
check_local_version=True,
|
140
150
|
)
|
141
|
-
if
|
142
|
-
model_meta.env.include_if_absent(
|
143
|
-
|
144
|
-
check_local_version=True,
|
145
|
-
)
|
151
|
+
if enable_explainability:
|
152
|
+
model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
|
153
|
+
model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
|
146
154
|
model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
|
147
155
|
|
148
156
|
return None
|
@@ -188,6 +196,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
|
|
188
196
|
cls,
|
189
197
|
raw_model: "catboost.CatBoost",
|
190
198
|
model_meta: model_meta_api.ModelMetadata,
|
199
|
+
background_data: Optional[pd.DataFrame] = None,
|
191
200
|
**kwargs: Unpack[model_types.CatBoostModelLoadOptions],
|
192
201
|
) -> custom_model.CustomModel:
|
193
202
|
import catboost
|
@@ -51,6 +51,9 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
|
|
51
51
|
**kwargs: Unpack[model_types.CustomModelSaveOption],
|
52
52
|
) -> None:
|
53
53
|
assert isinstance(model, custom_model.CustomModel)
|
54
|
+
enable_explainability = kwargs.get("enable_explainability", False)
|
55
|
+
if enable_explainability:
|
56
|
+
raise NotImplementedError("Explainability is not supported for custom model.")
|
54
57
|
|
55
58
|
def get_prediction(
|
56
59
|
target_method_name: str, sample_input_data: model_types.SupportedLocalDataType
|
@@ -108,13 +111,13 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
|
|
108
111
|
# Make sure that the module where the model is defined get pickled by value as well.
|
109
112
|
cloudpickle.register_pickle_by_value(sys.modules[model.__module__])
|
110
113
|
pickled_obj = (model.__class__, model.context)
|
111
|
-
with open(os.path.join(model_blob_path, cls.
|
114
|
+
with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
|
112
115
|
cloudpickle.dump(pickled_obj, f)
|
113
116
|
# model meta will be saved by the context manager
|
114
117
|
model_meta.models[name] = model_blob_meta.ModelBlobMeta(
|
115
118
|
name=name,
|
116
119
|
model_type=cls.HANDLER_TYPE,
|
117
|
-
path=cls.
|
120
|
+
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
118
121
|
handler_version=cls.HANDLER_VERSION,
|
119
122
|
function_properties=model_meta.function_properties,
|
120
123
|
artifacts={
|
@@ -183,6 +186,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
|
|
183
186
|
cls,
|
184
187
|
raw_model: custom_model.CustomModel,
|
185
188
|
model_meta: model_meta_api.ModelMetadata,
|
189
|
+
background_data: Optional[pd.DataFrame] = None,
|
186
190
|
**kwargs: Unpack[model_types.CustomModelLoadOption],
|
187
191
|
) -> custom_model.CustomModel:
|
188
192
|
return raw_model
|
@@ -89,7 +89,7 @@ class HuggingFacePipelineHandler(
|
|
89
89
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
90
90
|
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
|
91
91
|
|
92
|
-
|
92
|
+
MODEL_BLOB_FILE_OR_DIR = "model"
|
93
93
|
ADDITIONAL_CONFIG_FILE = "pipeline_config.pt"
|
94
94
|
DEFAULT_TARGET_METHODS = ["__call__"]
|
95
95
|
IS_AUTO_SIGNATURE = True
|
@@ -133,6 +133,9 @@ class HuggingFacePipelineHandler(
|
|
133
133
|
is_sub_model: Optional[bool] = False,
|
134
134
|
**kwargs: Unpack[model_types.HuggingFaceSaveOptions],
|
135
135
|
) -> None:
|
136
|
+
enable_explainability = kwargs.get("enable_explainability", False)
|
137
|
+
if enable_explainability:
|
138
|
+
raise NotImplementedError("Explainability is not supported for huggingface model.")
|
136
139
|
if type_utils.LazyType("transformers.Pipeline").isinstance(model):
|
137
140
|
task = model.task # type:ignore[attr-defined]
|
138
141
|
framework = model.framework # type:ignore[attr-defined]
|
@@ -193,7 +196,7 @@ class HuggingFacePipelineHandler(
|
|
193
196
|
|
194
197
|
if type_utils.LazyType("transformers.Pipeline").isinstance(model):
|
195
198
|
model.save_pretrained( # type:ignore[attr-defined]
|
196
|
-
os.path.join(model_blob_path, cls.
|
199
|
+
os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
|
197
200
|
)
|
198
201
|
pipeline_params = {
|
199
202
|
"_batch_size": model._batch_size, # type:ignore[attr-defined]
|
@@ -205,7 +208,7 @@ class HuggingFacePipelineHandler(
|
|
205
208
|
with open(
|
206
209
|
os.path.join(
|
207
210
|
model_blob_path,
|
208
|
-
cls.
|
211
|
+
cls.MODEL_BLOB_FILE_OR_DIR,
|
209
212
|
cls.ADDITIONAL_CONFIG_FILE,
|
210
213
|
),
|
211
214
|
"wb",
|
@@ -213,7 +216,7 @@ class HuggingFacePipelineHandler(
|
|
213
216
|
cloudpickle.dump(pipeline_params, f)
|
214
217
|
else:
|
215
218
|
with open(
|
216
|
-
os.path.join(model_blob_path, cls.
|
219
|
+
os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR),
|
217
220
|
"wb",
|
218
221
|
) as f:
|
219
222
|
cloudpickle.dump(model, f)
|
@@ -222,7 +225,7 @@ class HuggingFacePipelineHandler(
|
|
222
225
|
name=name,
|
223
226
|
model_type=cls.HANDLER_TYPE,
|
224
227
|
handler_version=cls.HANDLER_VERSION,
|
225
|
-
path=cls.
|
228
|
+
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
226
229
|
options=model_meta_schema.HuggingFacePipelineModelBlobOptions(
|
227
230
|
{
|
228
231
|
"task": task,
|
@@ -329,6 +332,7 @@ class HuggingFacePipelineHandler(
|
|
329
332
|
cls,
|
330
333
|
raw_model: Union[huggingface_pipeline.HuggingFacePipelineModel, "transformers.Pipeline"],
|
331
334
|
model_meta: model_meta_api.ModelMetadata,
|
335
|
+
background_data: Optional[pd.DataFrame] = None,
|
332
336
|
**kwargs: Unpack[model_types.HuggingFaceLoadOptions],
|
333
337
|
) -> custom_model.CustomModel:
|
334
338
|
import transformers
|
@@ -365,7 +369,9 @@ class HuggingFacePipelineHandler(
|
|
365
369
|
else:
|
366
370
|
# For others, we could offer the whole dataframe as a list.
|
367
371
|
# Some of them may need some conversion
|
368
|
-
if
|
372
|
+
if hasattr(transformers, "ConversationalPipeline") and isinstance(
|
373
|
+
raw_model, transformers.ConversationalPipeline
|
374
|
+
):
|
369
375
|
input_data = [
|
370
376
|
transformers.Conversation(
|
371
377
|
text=conv_data["user_inputs"][0],
|
@@ -387,27 +393,33 @@ class HuggingFacePipelineHandler(
|
|
387
393
|
# Making it not aligned with the auto-inferred signature.
|
388
394
|
# If the output is a dict, we could blindly create a list containing that.
|
389
395
|
# Otherwise, creating pandas DataFrame won't succeed.
|
390
|
-
if
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
396
|
+
if (
|
397
|
+
(hasattr(transformers, "Conversation") and isinstance(temp_res, transformers.Conversation))
|
398
|
+
or isinstance(temp_res, dict)
|
399
|
+
or (
|
400
|
+
# For some pipeline that is expected to generate a list of dict per input
|
401
|
+
# When it omit outer list, it becomes list of dict instead of list of list of dict.
|
402
|
+
# We need to distinguish them from those pipelines that designed to output a dict per input
|
403
|
+
# So we need to check the pipeline type.
|
404
|
+
isinstance(
|
405
|
+
raw_model,
|
406
|
+
(
|
407
|
+
transformers.FillMaskPipeline,
|
408
|
+
transformers.QuestionAnsweringPipeline,
|
409
|
+
),
|
410
|
+
)
|
411
|
+
and X.shape[0] == 1
|
412
|
+
and isinstance(temp_res[0], dict)
|
401
413
|
)
|
402
|
-
and X.shape[0] == 1
|
403
|
-
and isinstance(temp_res[0], dict)
|
404
414
|
):
|
405
415
|
temp_res = [temp_res]
|
406
416
|
|
407
417
|
if len(temp_res) == 0:
|
408
418
|
return pd.DataFrame()
|
409
419
|
|
410
|
-
if
|
420
|
+
if hasattr(transformers, "ConversationalPipeline") and isinstance(
|
421
|
+
raw_model, transformers.ConversationalPipeline
|
422
|
+
):
|
411
423
|
temp_res = [[conv.generated_responses] for conv in temp_res]
|
412
424
|
|
413
425
|
# To concat those who outputs a list with one input.
|