snowflake-ml-python 1.7.2__py3-none-any.whl → 1.7.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +16 -8
- snowflake/cortex/_classify_text.py +12 -1
- snowflake/cortex/_complete.py +101 -13
- snowflake/cortex/_embed_text_1024.py +9 -2
- snowflake/cortex/_embed_text_768.py +9 -2
- snowflake/cortex/_extract_answer.py +9 -2
- snowflake/cortex/_sentiment.py +9 -2
- snowflake/cortex/_summarize.py +9 -2
- snowflake/cortex/_translate.py +9 -2
- snowflake/ml/_internal/env_utils.py +7 -52
- snowflake/ml/_internal/platform_capabilities.py +87 -0
- snowflake/ml/_internal/utils/identifier.py +4 -2
- snowflake/ml/data/__init__.py +3 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +4 -4
- snowflake/ml/data/data_connector.py +53 -11
- snowflake/ml/data/data_ingestor.py +2 -1
- snowflake/ml/data/torch_utils.py +18 -5
- snowflake/ml/dataset/dataset.py +0 -1
- snowflake/ml/feature_store/examples/example_helper.py +2 -1
- snowflake/ml/fileset/fileset.py +24 -18
- snowflake/ml/jobs/__init__.py +21 -0
- snowflake/ml/jobs/_utils/constants.py +51 -0
- snowflake/ml/jobs/_utils/payload_utils.py +352 -0
- snowflake/ml/jobs/_utils/spec_utils.py +298 -0
- snowflake/ml/jobs/_utils/types.py +39 -0
- snowflake/ml/jobs/decorators.py +91 -0
- snowflake/ml/jobs/job.py +113 -0
- snowflake/ml/jobs/manager.py +298 -0
- snowflake/ml/model/_client/model/model_version_impl.py +5 -3
- snowflake/ml/model/_client/ops/model_ops.py +13 -8
- snowflake/ml/model/_client/ops/service_ops.py +1 -11
- snowflake/ml/model/_client/sql/model_version.py +11 -0
- snowflake/ml/model/_client/sql/service.py +13 -6
- snowflake/ml/model/_model_composer/model_composer.py +8 -3
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_model_composer/model_method/constants.py +1 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -0
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +9 -1
- snowflake/ml/model/_model_composer/model_user_file/model_user_file.py +27 -0
- snowflake/ml/model/_packager/model_handlers/_utils.py +39 -5
- snowflake/ml/model/_packager/model_handlers/catboost.py +3 -3
- snowflake/ml/model/_packager/model_handlers/custom.py +1 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +6 -1
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +55 -20
- snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -10
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +66 -28
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -17
- snowflake/ml/model/_packager/model_handlers/xgboost.py +3 -3
- snowflake/ml/model/_packager/model_meta/model_meta.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
- snowflake/ml/model/_packager/model_task/model_task_utils.py +3 -2
- snowflake/ml/model/_signatures/base_handler.py +1 -2
- snowflake/ml/model/_signatures/builtins_handler.py +2 -2
- snowflake/ml/model/_signatures/numpy_handler.py +6 -7
- snowflake/ml/model/_signatures/pandas_handler.py +3 -3
- snowflake/ml/model/_signatures/pytorch_handler.py +2 -5
- snowflake/ml/model/_signatures/snowpark_handler.py +11 -5
- snowflake/ml/model/_signatures/tensorflow_handler.py +2 -7
- snowflake/ml/model/model_signature.py +17 -4
- snowflake/ml/model/type_hints.py +1 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +0 -8
- snowflake/ml/modeling/_internal/model_transformer_builder.py +0 -13
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +6 -3
- snowflake/ml/modeling/cluster/affinity_propagation.py +6 -3
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +6 -3
- snowflake/ml/modeling/cluster/birch.py +6 -3
- snowflake/ml/modeling/cluster/bisecting_k_means.py +6 -3
- snowflake/ml/modeling/cluster/dbscan.py +6 -3
- snowflake/ml/modeling/cluster/feature_agglomeration.py +6 -3
- snowflake/ml/modeling/cluster/k_means.py +6 -3
- snowflake/ml/modeling/cluster/mean_shift.py +6 -3
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +6 -3
- snowflake/ml/modeling/cluster/optics.py +6 -3
- snowflake/ml/modeling/cluster/spectral_biclustering.py +6 -3
- snowflake/ml/modeling/cluster/spectral_clustering.py +6 -3
- snowflake/ml/modeling/cluster/spectral_coclustering.py +6 -3
- snowflake/ml/modeling/compose/column_transformer.py +6 -3
- snowflake/ml/modeling/compose/transformed_target_regressor.py +6 -3
- snowflake/ml/modeling/covariance/elliptic_envelope.py +6 -3
- snowflake/ml/modeling/covariance/empirical_covariance.py +6 -3
- snowflake/ml/modeling/covariance/graphical_lasso.py +6 -3
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +6 -3
- snowflake/ml/modeling/covariance/ledoit_wolf.py +6 -3
- snowflake/ml/modeling/covariance/min_cov_det.py +6 -3
- snowflake/ml/modeling/covariance/oas.py +6 -3
- snowflake/ml/modeling/covariance/shrunk_covariance.py +6 -3
- snowflake/ml/modeling/decomposition/dictionary_learning.py +6 -3
- snowflake/ml/modeling/decomposition/factor_analysis.py +6 -3
- snowflake/ml/modeling/decomposition/fast_ica.py +6 -3
- snowflake/ml/modeling/decomposition/incremental_pca.py +6 -3
- snowflake/ml/modeling/decomposition/kernel_pca.py +6 -3
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +6 -3
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +6 -3
- snowflake/ml/modeling/decomposition/pca.py +6 -3
- snowflake/ml/modeling/decomposition/sparse_pca.py +6 -3
- snowflake/ml/modeling/decomposition/truncated_svd.py +6 -3
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -3
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +6 -3
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/bagging_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/bagging_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/isolation_forest.py +6 -3
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/stacking_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/voting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/voting_regressor.py +6 -3
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fdr.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fpr.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fwe.py +6 -3
- snowflake/ml/modeling/feature_selection/select_k_best.py +6 -3
- snowflake/ml/modeling/feature_selection/select_percentile.py +6 -3
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +6 -3
- snowflake/ml/modeling/feature_selection/variance_threshold.py +6 -3
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +6 -3
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +6 -3
- snowflake/ml/modeling/impute/iterative_imputer.py +6 -3
- snowflake/ml/modeling/impute/knn_imputer.py +6 -3
- snowflake/ml/modeling/impute/missing_indicator.py +6 -3
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +6 -3
- snowflake/ml/modeling/kernel_approximation/nystroem.py +6 -3
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +6 -3
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +6 -3
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +6 -3
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +6 -3
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +6 -3
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ard_regression.py +6 -3
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +6 -3
- snowflake/ml/modeling/linear_model/elastic_net.py +6 -3
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +6 -3
- snowflake/ml/modeling/linear_model/gamma_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/huber_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/lars.py +6 -3
- snowflake/ml/modeling/linear_model/lars_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +6 -3
- snowflake/ml/modeling/linear_model/linear_regression.py +6 -3
- snowflake/ml/modeling/linear_model/logistic_regression.py +6 -3
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +6 -3
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +6 -3
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/perceptron.py +6 -3
- snowflake/ml/modeling/linear_model/poisson_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ransac_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ridge.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_cv.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +6 -3
- snowflake/ml/modeling/manifold/isomap.py +6 -3
- snowflake/ml/modeling/manifold/mds.py +6 -3
- snowflake/ml/modeling/manifold/spectral_embedding.py +6 -3
- snowflake/ml/modeling/manifold/tsne.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +6 -3
- snowflake/ml/modeling/mixture/gaussian_mixture.py +6 -3
- snowflake/ml/modeling/model_selection/grid_search_cv.py +17 -2
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +17 -2
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +6 -3
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +6 -3
- snowflake/ml/modeling/multiclass/output_code_classifier.py +6 -3
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/complement_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +6 -3
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +6 -3
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +6 -3
- snowflake/ml/modeling/neighbors/kernel_density.py +6 -3
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +6 -3
- snowflake/ml/modeling/neighbors/nearest_centroid.py +6 -3
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +6 -3
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +6 -3
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -3
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +6 -3
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +6 -3
- snowflake/ml/modeling/neural_network/mlp_classifier.py +6 -3
- snowflake/ml/modeling/neural_network/mlp_regressor.py +6 -3
- snowflake/ml/modeling/pipeline/pipeline.py +16 -178
- snowflake/ml/modeling/preprocessing/polynomial_features.py +6 -3
- snowflake/ml/modeling/semi_supervised/label_propagation.py +6 -3
- snowflake/ml/modeling/semi_supervised/label_spreading.py +6 -3
- snowflake/ml/modeling/svm/linear_svc.py +6 -3
- snowflake/ml/modeling/svm/linear_svr.py +6 -3
- snowflake/ml/modeling/svm/nu_svc.py +6 -3
- snowflake/ml/modeling/svm/nu_svr.py +6 -3
- snowflake/ml/modeling/svm/svc.py +6 -3
- snowflake/ml/modeling/svm/svr.py +6 -3
- snowflake/ml/modeling/tree/decision_tree_classifier.py +6 -3
- snowflake/ml/modeling/tree/decision_tree_regressor.py +6 -3
- snowflake/ml/modeling/tree/extra_tree_classifier.py +6 -3
- snowflake/ml/modeling/tree/extra_tree_regressor.py +6 -3
- snowflake/ml/modeling/xgboost/xgb_classifier.py +167 -91
- snowflake/ml/modeling/xgboost/xgb_regressor.py +166 -88
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +166 -88
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +166 -88
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +4 -4
- snowflake/ml/registry/_manager/model_manager.py +70 -33
- snowflake/ml/registry/registry.py +41 -22
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/METADATA +63 -19
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/RECORD +231 -226
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/utils/retryable_http.py +0 -39
- snowflake/ml/fileset/parquet_parser.py +0 -170
- snowflake/ml/fileset/tf_dataset.py +0 -88
- snowflake/ml/fileset/torch_datapipe.py +0 -57
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +0 -151
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_trainer.py +0 -66
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_packager/model_handlers/snowmlmodel.py
CHANGED

```diff
@@ -68,21 +68,45 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
         return cast("BaseEstimator", model)

     @classmethod
-    def _get_supported_object_for_explainability(cls, estimator: "BaseEstimator") -> Any:
+    def _get_supported_object_for_explainability(
+        cls,
+        estimator: "BaseEstimator",
+        background_data: Optional[model_types.SupportedDataType],
+        enable_explainability: Optional[bool],
+    ) -> Any:
         from snowflake.ml.modeling import pipeline as snowml_pipeline

         # handle pipeline objects separately
         if isinstance(estimator, snowml_pipeline.Pipeline):  # type: ignore[attr-defined]
             return None

-        methods = ["to_xgboost", "to_lightgbm", "to_sklearn"]
-        for method_name in methods:
+        tree_methods = ["to_xgboost", "to_lightgbm", "to_sklearn"]
+        non_tree_methods = ["to_sklearn"]
+        for method_name in tree_methods:
+            if hasattr(estimator, method_name):
+                try:
+                    result = getattr(estimator, method_name)()
+                    return result
+                except exceptions.SnowflakeMLException:
+                    pass  # Do nothing and continue to the next method
+        for method_name in non_tree_methods:
             if hasattr(estimator, method_name):
                 try:
                     result = getattr(estimator, method_name)()
+                    if enable_explainability is None and background_data is None:
+                        return None  # cannot get explain without background data
+                    elif enable_explainability and background_data is None:
+                        raise ValueError(
+                            "Provide `sample_input_data` to generate explanations for sklearn Snowpark ML models."
+                        )
                     return result
                 except exceptions.SnowflakeMLException:
                     pass  # Do nothing and continue to the next method
+
+        if enable_explainability:
+            raise ValueError(
+                "Explain only supported for xgboost, lightgbm and sklearn (not pipeline) Snowpark ML models."
+            )
         return None

     @classmethod
@@ -127,34 +151,39 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
                 raise ValueError(f"Target method {method_name} does not exist in the model.")
         model_meta.signatures = temp_model_signature_dict

-        if enable_explainability or enable_explainability is None:
-            python_base_obj = cls._get_supported_object_for_explainability(model)
-            if python_base_obj is None:
-                if enable_explainability:
-                    raise ValueError(
-                        "Explain only supported for xgboost, lightgbm and sklearn (not pipeline) Snowpark ML models."
-                    )
+        python_base_obj = cls._get_supported_object_for_explainability(model, sample_input_data, enable_explainability)
+        explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
+
+        if enable_explainability:
+            if explain_target_method is None:
+                raise ValueError(
+                    "The model must have one of the following methods to enable explainability: "
+                    + ", ".join(cls.EXPLAIN_TARGET_METHODS)
+                )
+        if enable_explainability is None:
+            if python_base_obj is None or explain_target_method is None:
                 # set None to False so we don't include shap in the environment
                 enable_explainability = False
             else:
-                model_task_and_output_type = model_task_utils.get_model_task_and_output_type(python_base_obj)
-                model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output_type.task)
-                explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
-                model_meta = handlers_utils.add_explain_method_signature(
-                    model_meta=model_meta,
-                    explain_method="explain",
-                    target_method=explain_target_method,
-                    output_return_type=model_task_and_output_type.output_type,
-                )
                 enable_explainability = True
-
-                background_data = handlers_utils.get_explainability_supported_background(
-                    sample_input_data, model_meta, explain_target_method
+        if enable_explainability:
+            model_task_and_output_type = model_task_utils.resolve_model_task_and_output_type(
+                python_base_obj, model_meta.task
+            )
+            model_meta.task = model_task_and_output_type.task
+            model_meta = handlers_utils.add_explain_method_signature(
+                model_meta=model_meta,
+                explain_method="explain",
+                target_method=explain_target_method,
+                output_return_type=model_task_and_output_type.output_type,
+            )
+            background_data = handlers_utils.get_explainability_supported_background(
+                sample_input_data, model_meta, explain_target_method
+            )
+            if background_data is not None:
+                handlers_utils.save_background_data(
+                    model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, cls.BG_DATA_FILE_SUFFIX, name, background_data
                 )
-                if background_data is not None:
-                    handlers_utils.save_background_data(
-                        model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, cls.BG_DATA_FILE_SUFFIX, name, background_data
-                    )

         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
@@ -237,8 +266,17 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
         def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
             import shap

-            methods = ["to_xgboost", "to_lightgbm", "to_sklearn"]
-            for method_name in methods:
+            tree_methods = ["to_xgboost", "to_lightgbm"]
+            non_tree_methods = ["to_sklearn"]
+            for method_name in tree_methods:
+                try:
+                    base_model = getattr(raw_model, method_name)()
+                    explainer = shap.TreeExplainer(base_model)
+                    df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer.shap_values(X))
+                    return model_signature_utils.rename_pandas_df(df, signature.outputs)
+                except exceptions.SnowflakeMLException:
+                    pass  # Do nothing and continue to the next method
+            for method_name in non_tree_methods:
                 try:
                     base_model = getattr(raw_model, method_name)()
                     explainer = shap.Explainer(base_model, masker=background_data)
```
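The practical upshot of the new `background_data` checks: explanations for sklearn-backed Snowpark ML models are only produced when sample data is available, and forcing them on without it now raises. A minimal usage sketch from the registry side, assuming an existing `Registry` object `reg`, a fitted `snowflake.ml.modeling` estimator `clf`, and a frame `train_df` (these names are illustrative, not from the diff):

```python
# Illustrative sketch only: `reg`, `clf`, and `train_df` are assumed to exist.
mv = reg.log_model(
    clf,
    model_name="my_classifier",
    version_name="v1",
    sample_input_data=train_df,  # background data; required when explainability is forced on
    options={"enable_explainability": True},
)
```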
snowflake/ml/model/_packager/model_handlers/tensorflow.py
CHANGED

```diff
@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final

 import numpy as np
 import pandas as pd
+from packaging import version
 from typing_extensions import TypeGuard, Unpack

 from snowflake.ml._internal import type_utils
@@ -73,13 +74,42 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         if enable_explainability:
             raise NotImplementedError("Explainability is not supported for Tensorflow model.")

+        # When tensorflow is installed, keras is also installed.
+        import keras
         import tensorflow

         assert isinstance(model, tensorflow.Module)

         is_keras_model = type_utils.LazyType("tensorflow.keras.Model").isinstance(model) or type_utils.LazyType(
-            "tf_keras.Model"
+            "keras.Model"
         ).isinstance(model)
+        is_tf_keras_model = type_utils.LazyType("tf_keras.Model").isinstance(model)
+        is_keras_functional_or_sequential_model = (
+            getattr(model, "_is_graph_network", False)
+            or type_utils.LazyType("tensorflow.keras.engine.sequential.Sequential").isinstance(model)
+            or type_utils.LazyType("keras.engine.sequential.Sequential").isinstance(model)
+            or type_utils.LazyType("tf_keras.engine.sequential.Sequential").isinstance(model)
+        )
+
+        assert isinstance(model, tensorflow.Module)
+
+        keras_version = version.parse(keras.__version__)
+
+        # Tensorflow and keras model save format is different.
+        # Keras functional or sequential models are saved as keras format
+        # Keras v3 other models are saved using cloudpickle
+        # Keras v2 other models are saved using tensorflow saved model format
+        # Tensorflow models are saved using tensorflow saved model format
+
+        if is_keras_model or is_tf_keras_model:
+            if is_keras_functional_or_sequential_model:
+                save_format = "keras"
+            elif keras_version.major == 2 or is_tf_keras_model:
+                save_format = "keras_tf"
+            else:
+                save_format = "cloudpickle"
+        else:
+            save_format = "tf"

         if is_keras_model:
             default_target_methods = ["predict"]
@@ -93,6 +123,9 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
             default_target_methods=default_target_methods,
         )

+        if is_keras_model and len(target_methods) > 1:
+            raise ValueError("Keras model can only have one target method.")
+
         def get_prediction(
             target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
         ) -> model_types.SupportedLocalDataType:
@@ -122,31 +155,43 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):

         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
-
-
-
-
-
-
-
-
+        save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
+        if save_format == "keras":
+            model.save(save_path, save_format="keras")
+        elif save_format == "keras_tf":
+            model.save(save_path, save_format="tf")
+        elif save_format == "cloudpickle":
+            import cloudpickle
+
+            with open(save_path, "wb") as f:
+                cloudpickle.dump(model, f)
         else:
-            tensorflow.saved_model.save(
+            tensorflow.saved_model.save(
+                model,
+                save_path,
+                options=tensorflow.saved_model.SaveOptions(experimental_custom_gradients=False),
+            )

         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
             path=cls.MODEL_BLOB_FILE_OR_DIR,
-            options=model_meta_schema.TensorflowModelBlobOptions(
+            options=model_meta_schema.TensorflowModelBlobOptions(save_format=save_format),
         )
         model_meta.models[name] = base_meta
         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION

+        dependencies = [
+            model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow"),
+        ]
+        if is_keras_model:
+            dependencies.append(model_env.ModelDependency(requirement="keras", pip_name="keras"))
+        elif is_tf_keras_model:
+            dependencies.append(model_env.ModelDependency(requirement="tf-keras", pip_name="tf-keras"))
+
         model_meta.env.include_if_absent(
-            [
-                model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow"),
-            ],
+            dependencies,
             check_local_version=True,
         )
         model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
@@ -166,10 +211,18 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         model_blob_metadata = model_blobs_metadata[name]
         model_blob_filename = model_blob_metadata.path
         model_blob_options = cast(model_meta_schema.TensorflowModelBlobOptions, model_blob_metadata.options)
-
-
+        load_path = os.path.join(model_blob_path, model_blob_filename)
+        save_format = model_blob_options.get("save_format", "tf")
+        if save_format == "keras" or save_format == "keras_tf":
+            m = tensorflow.keras.models.load_model(load_path)
+        elif save_format == "cloudpickle":
+            import cloudpickle
+
+            with open(load_path, "rb") as f:
+                m = cloudpickle.load(f)
         else:
-            m = tensorflow.saved_model.load(
+            m = tensorflow.saved_model.load(load_path)
+
         return cast(tensorflow.Module, m)

     @classmethod
```
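The save-format comments above amount to a four-way dispatch. A standalone paraphrase of that logic, written here purely for readability (this function is my own summary, not code from the package):

```python
def choose_save_format(
    is_keras_model: bool,
    is_tf_keras_model: bool,
    is_functional_or_sequential: bool,
    keras_major_version: int,
) -> str:
    """Paraphrase of the handler's new save-format dispatch (illustrative only)."""
    if is_keras_model or is_tf_keras_model:
        if is_functional_or_sequential:
            return "keras"  # native Keras format
        if keras_major_version == 2 or is_tf_keras_model:
            return "keras_tf"  # saved via model.save(..., save_format="tf")
        return "cloudpickle"  # Keras v3 subclassed models are pickled
    return "tf"  # plain tensorflow.Module -> tensorflow.saved_model.save


# A subclassed Keras 3 model, for example, is cloudpickled rather than saved as a SavedModel:
assert choose_save_format(True, False, False, 3) == "cloudpickle"
```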
snowflake/ml/model/_packager/model_handlers/xgboost.py
CHANGED

```diff
@@ -117,8 +117,8 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.XGBModel"]]):
             sample_input_data=sample_input_data,
             get_prediction_fn=get_prediction,
         )
-        model_task_and_output = model_task_utils.get_model_task_and_output_type(model)
-        model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output.task)
+        model_task_and_output = model_task_utils.resolve_model_task_and_output_type(model, model_meta.task)
+        model_meta.task = model_task_and_output.task
         if enable_explainability:
             model_meta = handlers_utils.add_explain_method_signature(
                 model_meta=model_meta,
@@ -254,7 +254,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.XGBModel"]]):
             import shap

             explainer = shap.TreeExplainer(raw_model)
-            df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X))
+            df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer.shap_values(X))
             return model_signature_utils.rename_pandas_df(df, signature.outputs)

         if target_method == "explain":
```
snowflake/ml/model/_packager/model_meta/model_meta.py
CHANGED

```diff
@@ -215,6 +215,7 @@ class ModelMetadata:
         function_properties: A dict mapping function names to dict mapping function property key to value.
         metadata: User provided key-value metadata of the model. Defaults to None.
         creation_timestamp: Unix timestamp when the model metadata is created.
+        user_files: Dict mapping subdirectories to extra artifact file paths for files to include in the model.
         task: Model task like TABULAR_REGRESSION, tabular_classification, timeseries_forecasting etc.
     """

@@ -234,6 +235,7 @@ class ModelMetadata:
         runtimes: Optional[Dict[str, model_runtime.ModelRuntime]] = None,
         signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
         function_properties: Optional[Dict[str, Dict[str, Any]]] = None,
+        user_files: Optional[Dict[str, List[str]]] = None,
         metadata: Optional[Dict[str, str]] = None,
         creation_timestamp: Optional[str] = None,
         min_snowpark_ml_version: Optional[str] = None,
@@ -247,6 +249,7 @@ class ModelMetadata:
         if signatures:
             self.signatures = signatures
         self.function_properties = function_properties or {}
+        self.user_files = user_files
         self.metadata = metadata
         self.model_type = model_type
         self.env = env
```
snowflake/ml/model/_packager/model_meta/model_meta_schema.py
CHANGED

```diff
@@ -59,7 +59,11 @@ class XgboostModelBlobOptions(BaseModelBlobOptions):


 class TensorflowModelBlobOptions(BaseModelBlobOptions):
-
+    save_format: Required[str]
+
+
+class SentenceTransformersModelBlobOptions(BaseModelBlobOptions):
+    batch_size: Required[int]


 ModelBlobOptions = Union[
@@ -68,6 +72,7 @@ ModelBlobOptions = Union[
     MLFlowModelBlobOptions,
     XgboostModelBlobOptions,
     TensorflowModelBlobOptions,
+    SentenceTransformersModelBlobOptions,
 ]


```
snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py
CHANGED

```diff
@@ -1,2 +1,2 @@
-REQUIREMENTS = ['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<
-ALL_REQUIREMENTS=['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<
+REQUIREMENTS = ['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<5', 'cachetools>=3.1.1,<6', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2024.6.1,<2026', 'importlib_resources>=6.1.1, <7', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2024.6.1,<2026', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2,!=1.26.0', 'sqlparse>=0.4,<1', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
+ALL_REQUIREMENTS=['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<5', 'cachetools>=3.1.1,<6', 'catboost>=1.2.0, <2', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2024.6.1,<2026', 'huggingface_hub<0.26', 'importlib_resources>=6.1.1, <7', 'lightgbm>=4.1.0, <5', 'mlflow>=2.16.0, <3', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pytorch>=2.0.1,<2.3.0', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2024.6.1,<2026', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'sentence-transformers>=2.2.2,<3', 'sentencepiece>=0.1.95,<1', 'shap>=0.46.0,<1', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2,!=1.26.0', 'sqlparse>=0.4,<1', 'tensorflow>=2.12.0,<3', 'tokenizers>=0.10,<1', 'torchdata>=0.4,<1', 'transformers>=4.32.1,<5', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
```
snowflake/ml/model/_packager/model_task/model_task_utils.py
CHANGED

```diff
@@ -149,8 +149,9 @@ def _get_model_task(model: Any) -> type_hints.Task:
     raise ValueError(f"Model type {type(model)} is not supported")


-def get_model_task_and_output_type(model: Any) -> ModelTaskAndOutputType:
-    task = _get_model_task(model)
+def resolve_model_task_and_output_type(model: Any, passed_model_task: type_hints.Task) -> ModelTaskAndOutputType:
+    inferred_task = _get_model_task(model)
+    task = handlers_utils.validate_model_task(passed_model_task, inferred_task)
     output_type = model_signature.DataType.DOUBLE
     if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION:
         output_type = model_signature.DataType.STRING
```
snowflake/ml/model/_signatures/base_handler.py
CHANGED

```diff
@@ -12,7 +12,6 @@ class BaseDataHandler(ABC, Generic[model_types._DataType]):
     FEATURE_PREFIX: Final[str] = "feature"
     INPUT_PREFIX: Final[str] = "input"
     OUTPUT_PREFIX: Final[str] = "output"
-    SIG_INFER_ROWS_COUNT_LIMIT: Final[int] = 10

     @staticmethod
     @abstractmethod
@@ -26,7 +25,7 @@ class BaseDataHandler(ABC, Generic[model_types._DataType]):

     @staticmethod
     @abstractmethod
-    def truncate(data: model_types._DataType) -> model_types._DataType:
+    def truncate(data: model_types._DataType, length: int) -> model_types._DataType:
         ...

     @staticmethod
```
snowflake/ml/model/_signatures/builtins_handler.py
CHANGED

```diff
@@ -35,8 +35,8 @@ class ListOfBuiltinHandler(base_handler.BaseDataHandler[model_types._SupportedBuiltinsList]):
         return len(data)

     @staticmethod
-    def truncate(data: model_types._SupportedBuiltinsList) -> model_types._SupportedBuiltinsList:
-        return data[: min(ListOfBuiltinHandler.count(data), ListOfBuiltinHandler.SIG_INFER_ROWS_COUNT_LIMIT)]
+    def truncate(data: model_types._SupportedBuiltinsList, length: int) -> model_types._SupportedBuiltinsList:
+        return data[: min(ListOfBuiltinHandler.count(data), length)]

     @staticmethod
     def validate(data: model_types._SupportedBuiltinsList) -> None:
```
snowflake/ml/model/_signatures/numpy_handler.py
CHANGED

```diff
@@ -23,8 +23,8 @@ class NumpyArrayHandler(base_handler.BaseDataHandler[model_types._SupportedNumpyArray]):
         return data.shape[0]

     @staticmethod
-    def truncate(data: model_types._SupportedNumpyArray) -> model_types._SupportedNumpyArray:
-        return data[: min(NumpyArrayHandler.count(data), NumpyArrayHandler.SIG_INFER_ROWS_COUNT_LIMIT)]
+    def truncate(data: model_types._SupportedNumpyArray, length: int) -> model_types._SupportedNumpyArray:
+        return data[: min(NumpyArrayHandler.count(data), length)]

     @staticmethod
     def validate(data: model_types._SupportedNumpyArray) -> None:
@@ -94,11 +94,10 @@ class SeqOfNumpyArrayHandler(base_handler.BaseDataHandler[Sequence[model_types._SupportedNumpyArray]]):
         return min(NumpyArrayHandler.count(data_col) for data_col in data)

     @staticmethod
-    def truncate(data: Sequence[model_types._SupportedNumpyArray]) -> Sequence[model_types._SupportedNumpyArray]:
-        return [
-            data_col[: min(SeqOfNumpyArrayHandler.count(data), SeqOfNumpyArrayHandler.SIG_INFER_ROWS_COUNT_LIMIT)]
-            for data_col in data
-        ]
+    def truncate(
+        data: Sequence[model_types._SupportedNumpyArray], length: int
+    ) -> Sequence[model_types._SupportedNumpyArray]:
+        return [data_col[: min(SeqOfNumpyArrayHandler.count(data), length)] for data_col in data]

     @staticmethod
     def validate(data: Sequence[model_types._SupportedNumpyArray]) -> None:
```
snowflake/ml/model/_signatures/pandas_handler.py
CHANGED

```diff
@@ -23,8 +23,8 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
         return len(data.index)

     @staticmethod
-    def truncate(data: pd.DataFrame) -> pd.DataFrame:
-        return data.head(min(PandasDataFrameHandler.count(data), PandasDataFrameHandler.SIG_INFER_ROWS_COUNT_LIMIT))
+    def truncate(data: pd.DataFrame, length: int) -> pd.DataFrame:
+        return data.head(min(PandasDataFrameHandler.count(data), length))

     @staticmethod
     def validate(data: Union[pd.DataFrame, pd.Series]) -> None:
@@ -224,6 +224,6 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
         df_col_dtypes = [df[col].dtype for col in df.columns]
         for df_col, df_col_dtype in zip(df_cols, df_col_dtypes):
             if df_col_dtype == np.dtype("O"):
-                if isinstance(df[df_col][0], np.ndarray):
+                if isinstance(df[df_col].iloc[0], np.ndarray):
                     df[df_col] = df[df_col].map(np.ndarray.tolist)
         return df
```
snowflake/ml/model/_signatures/pytorch_handler.py
CHANGED

```diff
@@ -33,11 +33,8 @@ class SeqOfPyTorchTensorHandler(base_handler.BaseDataHandler[Sequence["torch.Tensor"]]):
         return min(data_col.shape[0] for data_col in data)  # type: ignore[no-any-return]

     @staticmethod
-    def truncate(data: Sequence["torch.Tensor"]) -> Sequence["torch.Tensor"]:
-        return [
-            data_col[: min(SeqOfPyTorchTensorHandler.count(data), SeqOfPyTorchTensorHandler.SIG_INFER_ROWS_COUNT_LIMIT)]
-            for data_col in data
-        ]
+    def truncate(data: Sequence["torch.Tensor"], length: int) -> Sequence["torch.Tensor"]:
+        return [data_col[: min(SeqOfPyTorchTensorHandler.count(data), 10)] for data_col in data]

     @staticmethod
     def validate(data: Sequence["torch.Tensor"]) -> None:
```
snowflake/ml/model/_signatures/snowpark_handler.py
CHANGED

```diff
@@ -1,5 +1,5 @@
 import json
-from typing import Literal, Optional, Sequence, cast
+from typing import Any, Literal, Optional, Sequence, cast

 import numpy as np
 import pandas as pd
@@ -29,8 +29,8 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.DataFrame]):
         return data.count()

     @staticmethod
-    def truncate(data: snowflake.snowpark.DataFrame) -> snowflake.snowpark.DataFrame:
-        return cast(snowflake.snowpark.DataFrame, data.limit(
+    def truncate(data: snowflake.snowpark.DataFrame, length: int) -> snowflake.snowpark.DataFrame:
+        return cast(snowflake.snowpark.DataFrame, data.limit(length))

     @staticmethod
     def validate(data: snowflake.snowpark.DataFrame) -> None:
@@ -52,7 +52,7 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.DataFrame]):
         data: snowflake.snowpark.DataFrame, role: Literal["input", "output"]
     ) -> Sequence[core.BaseFeatureSpec]:
         return pandas_handler.PandasDataFrameHandler.infer_signature(
-            SnowparkDataFrameHandler.convert_to_df(data
+            SnowparkDataFrameHandler.convert_to_df(data), role=role
         )

     @staticmethod
@@ -73,14 +73,20 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.DataFrame]):
             assert isinstance(feature, core.FeatureSpec), "Invalid feature kind."
             dtype_map[feature.name] = feature.as_dtype()
         df_local = data.to_pandas()
+
         # This is because Array will become string (Even though the correct schema is set)
         # and object will become variant type and requires an additional loads
         # to get correct data otherwise it would be string.
+        def load_if_not_null(x: str) -> Optional[Any]:
+            if x is None:
+                return None
+            return json.loads(x)
+
         for field in data.schema.fields:
             if isinstance(field.datatype, spt.ArrayType):
                 df_local[identifier.get_unescaped_names(field.name)] = df_local[
                     identifier.get_unescaped_names(field.name)
-                ].map(json.loads)
+                ].map(load_if_not_null)
                 # Only when the feature is not from inference, we are confident to do the type casting.
                 # Otherwise, dtype_map will be empty.
                 # Errors are ignored to make sure None won't be converted and won't raise Error
```
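The new `load_if_not_null` wrapper exists because `json.loads` rejects `None`, which is what a SQL NULL in an ARRAY column becomes once the Snowpark frame is pulled into pandas:

```python
import json

json.loads("[1, 2, 3]")  # fine: returns [1, 2, 3]
json.loads(None)  # TypeError: the JSON object must be str, bytes or bytearray, not NoneType
```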
snowflake/ml/model/_signatures/tensorflow_handler.py
CHANGED

```diff
@@ -60,14 +60,9 @@ class SeqOfTensorflowTensorHandler(

     @staticmethod
     def truncate(
-        data: Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]]
+        data: Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]], length: int
     ) -> Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]]:
-        return [
-            data_col[
-                : min(SeqOfTensorflowTensorHandler.count(data), SeqOfTensorflowTensorHandler.SIG_INFER_ROWS_COUNT_LIMIT)
-            ]
-            for data_col in data
-        ]
+        return [data_col[: min(SeqOfTensorflowTensorHandler.count(data), length)] for data_col in data]

     @staticmethod
     def validate(data: Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]]) -> None:
```
snowflake/ml/model/model_signature.py
CHANGED

```diff
@@ -59,11 +59,16 @@ _ALL_DATA_HANDLERS = _LOCAL_DATA_HANDLERS + [snowpark_handler.SnowparkDataFrameHandler]

 def _truncate_data(
     data: model_types.SupportedDataType,
+    length: Optional[int] = 100,
 ) -> model_types.SupportedDataType:
     for handler in _ALL_DATA_HANDLERS:
         if handler.can_handle(data):
+            # If length is None, return the original data
+            if length is None:
+                return data
+
             row_count = handler.count(data)
-            if row_count <=
+            if row_count <= length:
                 return data

             warnings.warn(
@@ -77,7 +82,7 @@ def _truncate_data(
                 category=UserWarning,
                 stacklevel=1,
             )
-            return handler.truncate(data)
+            return handler.truncate(data, length)
     raise snowml_exceptions.SnowflakeMLException(
         error_code=error_codes.NOT_IMPLEMENTED,
         original_exception=NotImplementedError(
@@ -687,6 +692,8 @@ def infer_signature(
     output_data: model_types.SupportedLocalDataType,
     input_feature_names: Optional[List[str]] = None,
     output_feature_names: Optional[List[str]] = None,
+    input_data_limit: Optional[int] = 100,
+    output_data_limit: Optional[int] = 100,
 ) -> core.ModelSignature:
     """
     Infer model signature from given input and output sample data.
@@ -710,12 +717,18 @@ def infer_signature(
         output_data: Sample output data for the model.
         input_feature_names: Names for input features. Defaults to None.
         output_feature_names: Names for output features. Defaults to None.
+        input_data_limit: Limit the number of rows to be used in signature inference in the input data. Defaults to 100.
+            If None, all rows are used. If the number of rows in the input data is less than the limit, all rows are
+            used.
+        output_data_limit: Limit the number of rows to be used in signature inference in the output data. Defaults to
+            100. If None, all rows are used. If the number of rows in the output data is less than the limit, all rows
+            are used.

     Returns:
         A model signature inferred from the given input and output sample data.
     """
-    inputs = _infer_signature(input_data, role="input")
+    inputs = _infer_signature(_truncate_data(input_data, input_data_limit), role="input")
     inputs = utils.rename_features(inputs, input_feature_names)
-    outputs = _infer_signature(output_data, role="output")
+    outputs = _infer_signature(_truncate_data(output_data, output_data_limit), role="output")
     outputs = utils.rename_features(outputs, output_feature_names)
     return core.ModelSignature(inputs, outputs)
```
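A usage sketch of the new limits (the parameters come straight from the signature above; the sample data is made up). Signature inference now inspects at most 100 rows per side by default, and passing `None` opts out of truncation:

```python
import pandas as pd

from snowflake.ml.model import model_signature

big_input = pd.DataFrame({"x": range(1_000_000)})
big_output = pd.DataFrame({"y": range(1_000_000)})

# Only the first 100 rows of each frame are inspected (the new default).
sig = model_signature.infer_signature(big_input, big_output)

# Pass None to scan all rows instead.
sig_full = model_signature.infer_signature(
    big_input, big_output, input_data_limit=None, output_data_limit=None
)
```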
snowflake/ml/model/type_hints.py
CHANGED
```diff
@@ -199,6 +199,7 @@ class HuggingFaceSaveOptions(BaseModelSaveOption):
 class SentenceTransformersSaveOptions(BaseModelSaveOption):
     target_methods: NotRequired[Sequence[str]]
     cuda_version: NotRequired[str]
+    batch_size: NotRequired[int]


 ModelSaveOption = Union[
```
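Given the new `batch_size` field on `SentenceTransformersSaveOptions`, the encoding batch size can presumably be pinned through the save options when logging a sentence-transformers model. A sketch under that assumption (`reg`, `st_model`, and `sample_df` are illustrative, not from the diff):

```python
# Assumed usage based on the TypedDict above; not verbatim from the package docs.
mv = reg.log_model(
    st_model,  # a sentence_transformers.SentenceTransformer instance
    model_name="my_encoder",
    version_name="v1",
    sample_input_data=sample_df,
    options={"batch_size": 32},  # new SentenceTransformersSaveOptions field
)
```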
snowflake/ml/modeling/_internal/model_trainer_builder.py
CHANGED

```diff
@@ -1,11 +1,9 @@
-import os
 from typing import List, Optional, Union

 import pandas as pd
 from sklearn import model_selection

 from snowflake.ml._internal.exceptions import error_codes, exceptions
-from snowflake.ml.modeling._internal.constants import IN_ML_RUNTIME_ENV_VAR
 from snowflake.ml.modeling._internal.estimator_utils import (
     get_module_name,
     is_single_node,
@@ -13,9 +11,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
 from snowflake.ml.modeling._internal.local_implementations.pandas_trainer import (
     PandasModelTrainer,
 )
-from snowflake.ml.modeling._internal.ml_runtime_implementations.ml_runtime_trainer import (
-    MLRuntimeModelTrainer,
-)
 from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
 from snowflake.ml.modeling._internal.snowpark_implementations.distributed_hpo_trainer import (
     DistributedHPOTrainer,
@@ -107,9 +102,6 @@ class ModelTrainerBuilder:
             "autogenerated": autogenerated,
             "subproject": subproject,
         }
-        if os.environ.get(IN_ML_RUNTIME_ENV_VAR):
-            return MLRuntimeModelTrainer(**init_args)  # type: ignore[arg-type, return-value]
-
         trainer_klass = SnowparkModelTrainer

         assert dataset._session is not None  # Make MyPy happy
```