snowflake-ml-python 1.5.2__py3-none-any.whl → 1.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +2 -1
- snowflake/cortex/_complete.py +240 -16
- snowflake/cortex/_extract_answer.py +0 -1
- snowflake/cortex/_sentiment.py +0 -1
- snowflake/cortex/_sse_client.py +81 -0
- snowflake/cortex/_summarize.py +0 -1
- snowflake/cortex/_translate.py +0 -1
- snowflake/cortex/_util.py +34 -10
- snowflake/ml/_internal/container_services/image_registry/http_client.py +10 -3
- snowflake/ml/_internal/container_services/image_registry/imagelib.py +23 -10
- snowflake/ml/_internal/container_services/image_registry/registry_client.py +7 -1
- snowflake/ml/_internal/exceptions/dataset_errors.py +7 -7
- snowflake/ml/_internal/exceptions/fileset_errors.py +3 -3
- snowflake/ml/_internal/exceptions/sql_error_codes.py +6 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +34 -25
- snowflake/ml/_internal/telemetry.py +26 -0
- snowflake/ml/_internal/utils/identifier.py +14 -0
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +15 -4
- snowflake/ml/dataset/dataset.py +54 -32
- snowflake/ml/dataset/dataset_factory.py +3 -4
- snowflake/ml/feature_store/feature_store.py +440 -243
- snowflake/ml/feature_store/feature_view.py +61 -9
- snowflake/ml/fileset/embedded_stage_fs.py +25 -21
- snowflake/ml/fileset/fileset.py +2 -2
- snowflake/ml/fileset/snowfs.py +4 -15
- snowflake/ml/fileset/stage_fs.py +6 -8
- snowflake/ml/lineage/__init__.py +3 -0
- snowflake/ml/lineage/lineage_node.py +139 -0
- snowflake/ml/model/_client/model/model_impl.py +47 -14
- snowflake/ml/model/_client/model/model_version_impl.py +82 -2
- snowflake/ml/model/_client/ops/model_ops.py +77 -5
- snowflake/ml/model/_client/sql/model.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +47 -4
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
- snowflake/ml/model/_model_composer/model_composer.py +7 -6
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +7 -1
- snowflake/ml/model/_model_composer/model_method/function_generator.py +17 -1
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +79 -0
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +5 -3
- snowflake/ml/model/_model_composer/model_method/model_method.py +5 -5
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +1 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +2 -2
- snowflake/ml/model/_packager/model_handlers/custom.py +12 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +18 -15
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +2 -2
- snowflake/ml/model/_packager/model_handlers/llm.py +2 -2
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -2
- snowflake/ml/model/_packager/model_handlers/pytorch.py +2 -2
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +2 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -2
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +2 -2
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +2 -2
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +21 -1
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
- snowflake/ml/model/_packager/model_packager.py +9 -4
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
- snowflake/ml/model/_signatures/builtins_handler.py +2 -1
- snowflake/ml/model/_signatures/core.py +13 -1
- snowflake/ml/model/_signatures/pandas_handler.py +2 -0
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
- snowflake/ml/model/custom_model.py +22 -2
- snowflake/ml/model/model_signature.py +2 -0
- snowflake/ml/model/type_hints.py +74 -4
- snowflake/ml/modeling/_internal/estimator_utils.py +58 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +158 -121
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +2 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +39 -18
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +88 -134
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +22 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +5 -3
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +5 -3
- snowflake/ml/modeling/cluster/birch.py +5 -3
- snowflake/ml/modeling/cluster/bisecting_k_means.py +5 -3
- snowflake/ml/modeling/cluster/dbscan.py +5 -3
- snowflake/ml/modeling/cluster/feature_agglomeration.py +5 -3
- snowflake/ml/modeling/cluster/k_means.py +5 -3
- snowflake/ml/modeling/cluster/mean_shift.py +5 -3
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +5 -3
- snowflake/ml/modeling/cluster/optics.py +5 -3
- snowflake/ml/modeling/cluster/spectral_biclustering.py +5 -3
- snowflake/ml/modeling/cluster/spectral_clustering.py +5 -3
- snowflake/ml/modeling/cluster/spectral_coclustering.py +5 -3
- snowflake/ml/modeling/compose/column_transformer.py +5 -3
- snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +5 -3
- snowflake/ml/modeling/covariance/empirical_covariance.py +5 -3
- snowflake/ml/modeling/covariance/graphical_lasso.py +5 -3
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +5 -3
- snowflake/ml/modeling/covariance/ledoit_wolf.py +5 -3
- snowflake/ml/modeling/covariance/min_cov_det.py +5 -3
- snowflake/ml/modeling/covariance/oas.py +5 -3
- snowflake/ml/modeling/covariance/shrunk_covariance.py +5 -3
- snowflake/ml/modeling/decomposition/dictionary_learning.py +5 -3
- snowflake/ml/modeling/decomposition/factor_analysis.py +5 -3
- snowflake/ml/modeling/decomposition/fast_ica.py +5 -3
- snowflake/ml/modeling/decomposition/incremental_pca.py +5 -3
- snowflake/ml/modeling/decomposition/kernel_pca.py +5 -3
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +5 -3
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +5 -3
- snowflake/ml/modeling/decomposition/pca.py +5 -3
- snowflake/ml/modeling/decomposition/sparse_pca.py +5 -3
- snowflake/ml/modeling/decomposition/truncated_svd.py +5 -3
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +5 -3
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +5 -3
- snowflake/ml/modeling/feature_selection/variance_threshold.py +5 -3
- snowflake/ml/modeling/framework/base.py +3 -8
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +5 -3
- snowflake/ml/modeling/impute/knn_imputer.py +5 -3
- snowflake/ml/modeling/impute/missing_indicator.py +5 -3
- snowflake/ml/modeling/impute/simple_imputer.py +8 -4
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +5 -3
- snowflake/ml/modeling/kernel_approximation/nystroem.py +5 -3
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +5 -3
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +5 -3
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +5 -3
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/lars.py +1 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/perceptron.py +1 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ridge.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +5 -3
- snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
- snowflake/ml/modeling/manifold/isomap.py +5 -3
- snowflake/ml/modeling/manifold/mds.py +5 -3
- snowflake/ml/modeling/manifold/spectral_embedding.py +5 -3
- snowflake/ml/modeling/manifold/tsne.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +3 -0
- snowflake/ml/modeling/metrics/regression.py +3 -0
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +5 -3
- snowflake/ml/modeling/mixture/gaussian_mixture.py +5 -3
- snowflake/ml/modeling/model_selection/grid_search_cv.py +1 -5
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +1 -5
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +5 -3
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +5 -3
- snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +5 -3
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +5 -3
- snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +6 -0
- snowflake/ml/modeling/preprocessing/binarizer.py +7 -3
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +7 -2
- snowflake/ml/modeling/preprocessing/label_encoder.py +8 -7
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +7 -3
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +7 -4
- snowflake/ml/modeling/preprocessing/normalizer.py +7 -3
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +53 -11
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +44 -13
- snowflake/ml/modeling/preprocessing/polynomial_features.py +5 -3
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -4
- snowflake/ml/modeling/preprocessing/standard_scaler.py +7 -3
- snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
- snowflake/ml/modeling/svm/linear_svc.py +1 -1
- snowflake/ml/modeling/svm/linear_svr.py +1 -1
- snowflake/ml/modeling/svm/nu_svc.py +1 -1
- snowflake/ml/modeling/svm/nu_svr.py +1 -1
- snowflake/ml/modeling/svm/svc.py +1 -1
- snowflake/ml/modeling/svm/svr.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
- snowflake/ml/registry/_manager/model_manager.py +16 -3
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/METADATA +51 -7
- snowflake_ml_python-1.5.4.dist-info/RECORD +389 -0
- {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/WHEEL +1 -1
- snowflake_ml_python-1.5.2.dist-info/RECORD +0 -384
- {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/top_level.txt +0 -0
@@ -17,6 +17,7 @@ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
|
17
17
|
from snowflake.ml.model._packager.model_meta import (
|
18
18
|
model_blob_meta,
|
19
19
|
model_meta as model_meta_api,
|
20
|
+
model_meta_schema,
|
20
21
|
)
|
21
22
|
|
22
23
|
|
@@ -68,6 +69,11 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
|
|
68
69
|
predictions_df = target_method(model, sample_input_data)
|
69
70
|
return predictions_df
|
70
71
|
|
72
|
+
for func_name in model._get_partitioned_infer_methods():
|
73
|
+
function_properties = model_meta.function_properties.get(func_name, {})
|
74
|
+
function_properties[model_meta_schema.FunctionProperties.PARTITIONED.value] = True
|
75
|
+
model_meta.function_properties[func_name] = function_properties
|
76
|
+
|
71
77
|
if not is_sub_model:
|
72
78
|
model_meta = handlers_utils.validate_signature(
|
73
79
|
model=model,
|
@@ -101,14 +107,16 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
|
|
101
107
|
|
102
108
|
# Make sure that the module where the model is defined get pickled by value as well.
|
103
109
|
cloudpickle.register_pickle_by_value(sys.modules[model.__module__])
|
104
|
-
|
110
|
+
pickled_obj = (model.__class__, model.context)
|
105
111
|
with open(os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR), "wb") as f:
|
106
|
-
cloudpickle.dump(
|
112
|
+
cloudpickle.dump(pickled_obj, f)
|
113
|
+
# model meta will be saved by the context manager
|
107
114
|
model_meta.models[name] = model_blob_meta.ModelBlobMeta(
|
108
115
|
name=name,
|
109
116
|
model_type=cls.HANDLER_TYPE,
|
110
117
|
path=cls.MODELE_BLOB_FILE_OR_DIR,
|
111
118
|
handler_version=cls.HANDLER_VERSION,
|
119
|
+
function_properties=model_meta.function_properties,
|
112
120
|
artifacts={
|
113
121
|
name: pathlib.Path(
|
114
122
|
os.path.join(cls.MODEL_ARTIFACTS_DIR, os.path.basename(os.path.normpath(path=uri)))
|
@@ -128,7 +136,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
|
|
128
136
|
name: str,
|
129
137
|
model_meta: model_meta_api.ModelMetadata,
|
130
138
|
model_blobs_dir_path: str,
|
131
|
-
**kwargs: Unpack[model_types.
|
139
|
+
**kwargs: Unpack[model_types.CustomModelLoadOption],
|
132
140
|
) -> "custom_model.CustomModel":
|
133
141
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
134
142
|
|
@@ -175,6 +183,6 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
|
|
175
183
|
cls,
|
176
184
|
raw_model: custom_model.CustomModel,
|
177
185
|
model_meta: model_meta_api.ModelMetadata,
|
178
|
-
**kwargs: Unpack[model_types.
|
186
|
+
**kwargs: Unpack[model_types.CustomModelLoadOption],
|
179
187
|
) -> custom_model.CustomModel:
|
180
188
|
return raw_model
|
@@ -3,6 +3,7 @@ import os
|
|
3
3
|
import warnings
|
4
4
|
from typing import (
|
5
5
|
TYPE_CHECKING,
|
6
|
+
Any,
|
6
7
|
Callable,
|
7
8
|
Dict,
|
8
9
|
List,
|
@@ -250,9 +251,18 @@ class HuggingFacePipelineHandler(
|
|
250
251
|
model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
|
251
252
|
|
252
253
|
@staticmethod
|
253
|
-
def _get_device_config() -> Dict[str, str]:
|
254
|
-
device_config = {}
|
255
|
-
|
254
|
+
def _get_device_config(**kwargs: Unpack[model_types.HuggingFaceLoadOptions]) -> Dict[str, str]:
|
255
|
+
device_config: Dict[str, Any] = {}
|
256
|
+
if (
|
257
|
+
kwargs.get("use_gpu", False)
|
258
|
+
and kwargs.get("device_map", None) is None
|
259
|
+
and kwargs.get("device", None) is None
|
260
|
+
):
|
261
|
+
device_config["device_map"] = "auto"
|
262
|
+
elif kwargs.get("device_map", None) is not None:
|
263
|
+
device_config["device_map"] = kwargs["device_map"]
|
264
|
+
elif kwargs.get("device", None) is not None:
|
265
|
+
device_config["device"] = kwargs["device"]
|
256
266
|
|
257
267
|
return device_config
|
258
268
|
|
@@ -262,7 +272,7 @@ class HuggingFacePipelineHandler(
|
|
262
272
|
name: str,
|
263
273
|
model_meta: model_meta_api.ModelMetadata,
|
264
274
|
model_blobs_dir_path: str,
|
265
|
-
**kwargs: Unpack[model_types.
|
275
|
+
**kwargs: Unpack[model_types.HuggingFaceLoadOptions],
|
266
276
|
) -> Union[huggingface_pipeline.HuggingFacePipelineModel, "transformers.Pipeline"]:
|
267
277
|
if snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
|
268
278
|
# We need to redirect the some folders to a writable location in the sandbox.
|
@@ -292,10 +302,7 @@ class HuggingFacePipelineHandler(
|
|
292
302
|
) as f:
|
293
303
|
pipeline_params = cloudpickle.load(f)
|
294
304
|
|
295
|
-
|
296
|
-
device_config = cls._get_device_config()
|
297
|
-
else:
|
298
|
-
device_config = {}
|
305
|
+
device_config = cls._get_device_config(**kwargs)
|
299
306
|
|
300
307
|
m = transformers.pipeline(
|
301
308
|
model_blob_options["task"],
|
@@ -310,12 +317,8 @@ class HuggingFacePipelineHandler(
|
|
310
317
|
with open(model_blob_file_or_dir_path, "rb") as f:
|
311
318
|
m = cloudpickle.load(f)
|
312
319
|
assert isinstance(m, huggingface_pipeline.HuggingFacePipelineModel)
|
313
|
-
if (
|
314
|
-
|
315
|
-
and getattr(m, "device_map", None) is None
|
316
|
-
and kwargs.get("use_gpu", False)
|
317
|
-
):
|
318
|
-
m.__dict__.update(cls._get_device_config())
|
320
|
+
if getattr(m, "device", None) is None and getattr(m, "device_map", None) is None:
|
321
|
+
m.__dict__.update(cls._get_device_config(**kwargs))
|
319
322
|
|
320
323
|
if getattr(m, "torch_dtype", None) is None and kwargs.get("use_gpu", False):
|
321
324
|
m.__dict__.update(torch_dtype="auto")
|
@@ -326,7 +329,7 @@ class HuggingFacePipelineHandler(
|
|
326
329
|
cls,
|
327
330
|
raw_model: Union[huggingface_pipeline.HuggingFacePipelineModel, "transformers.Pipeline"],
|
328
331
|
model_meta: model_meta_api.ModelMetadata,
|
329
|
-
**kwargs: Unpack[model_types.
|
332
|
+
**kwargs: Unpack[model_types.HuggingFaceLoadOptions],
|
330
333
|
) -> custom_model.CustomModel:
|
331
334
|
import transformers
|
332
335
|
|
@@ -139,7 +139,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
139
139
|
name: str,
|
140
140
|
model_meta: model_meta_api.ModelMetadata,
|
141
141
|
model_blobs_dir_path: str,
|
142
|
-
**kwargs: Unpack[model_types.
|
142
|
+
**kwargs: Unpack[model_types.LGBMModelLoadOptions],
|
143
143
|
) -> Union["lightgbm.Booster", "lightgbm.LGBMModel"]:
|
144
144
|
import lightgbm
|
145
145
|
|
@@ -169,7 +169,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
169
169
|
cls,
|
170
170
|
raw_model: Union["lightgbm.Booster", "lightgbm.XGBModel"],
|
171
171
|
model_meta: model_meta_api.ModelMetadata,
|
172
|
-
**kwargs: Unpack[model_types.
|
172
|
+
**kwargs: Unpack[model_types.LGBMModelLoadOptions],
|
173
173
|
) -> custom_model.CustomModel:
|
174
174
|
import lightgbm
|
175
175
|
|
@@ -118,7 +118,7 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
|
|
118
118
|
name: str,
|
119
119
|
model_meta: model_meta_api.ModelMetadata,
|
120
120
|
model_blobs_dir_path: str,
|
121
|
-
**kwargs: Unpack[model_types.
|
121
|
+
**kwargs: Unpack[model_types.LLMLoadOptions],
|
122
122
|
) -> llm.LLM:
|
123
123
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
124
124
|
if not hasattr(model_meta, "models"):
|
@@ -143,7 +143,7 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
|
|
143
143
|
cls,
|
144
144
|
raw_model: llm.LLM,
|
145
145
|
model_meta: model_meta_api.ModelMetadata,
|
146
|
-
**kwargs: Unpack[model_types.
|
146
|
+
**kwargs: Unpack[model_types.LLMLoadOptions],
|
147
147
|
) -> custom_model.CustomModel:
|
148
148
|
import gc
|
149
149
|
import tempfile
|
@@ -160,7 +160,7 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
|
|
160
160
|
name: str,
|
161
161
|
model_meta: model_meta_api.ModelMetadata,
|
162
162
|
model_blobs_dir_path: str,
|
163
|
-
**kwargs: Unpack[model_types.
|
163
|
+
**kwargs: Unpack[model_types.MLFlowLoadOptions],
|
164
164
|
) -> "mlflow.pyfunc.PyFuncModel":
|
165
165
|
import mlflow
|
166
166
|
|
@@ -194,7 +194,7 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
|
|
194
194
|
cls,
|
195
195
|
raw_model: "mlflow.pyfunc.PyFuncModel",
|
196
196
|
model_meta: model_meta_api.ModelMetadata,
|
197
|
-
**kwargs: Unpack[model_types.
|
197
|
+
**kwargs: Unpack[model_types.MLFlowLoadOptions],
|
198
198
|
) -> custom_model.CustomModel:
|
199
199
|
from snowflake.ml.model import custom_model
|
200
200
|
|
@@ -137,7 +137,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
137
137
|
name: str,
|
138
138
|
model_meta: model_meta_api.ModelMetadata,
|
139
139
|
model_blobs_dir_path: str,
|
140
|
-
**kwargs: Unpack[model_types.
|
140
|
+
**kwargs: Unpack[model_types.PyTorchLoadOptions],
|
141
141
|
) -> "torch.nn.Module":
|
142
142
|
import torch
|
143
143
|
|
@@ -156,7 +156,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
156
156
|
cls,
|
157
157
|
raw_model: "torch.nn.Module",
|
158
158
|
model_meta: model_meta_api.ModelMetadata,
|
159
|
-
**kwargs: Unpack[model_types.
|
159
|
+
**kwargs: Unpack[model_types.PyTorchLoadOptions],
|
160
160
|
) -> custom_model.CustomModel:
|
161
161
|
import torch
|
162
162
|
|
@@ -126,7 +126,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
126
126
|
name: str,
|
127
127
|
model_meta: model_meta_api.ModelMetadata,
|
128
128
|
model_blobs_dir_path: str,
|
129
|
-
**kwargs: Unpack[model_types.
|
129
|
+
**kwargs: Unpack[model_types.SentenceTransformersLoadOptions], # use_gpu
|
130
130
|
) -> "sentence_transformers.SentenceTransformer":
|
131
131
|
import sentence_transformers
|
132
132
|
|
@@ -154,7 +154,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
154
154
|
cls,
|
155
155
|
raw_model: "sentence_transformers.SentenceTransformer",
|
156
156
|
model_meta: model_meta_api.ModelMetadata,
|
157
|
-
**kwargs: Unpack[model_types.
|
157
|
+
**kwargs: Unpack[model_types.SentenceTransformersLoadOptions],
|
158
158
|
) -> custom_model.CustomModel:
|
159
159
|
import sentence_transformers
|
160
160
|
|
@@ -133,7 +133,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
133
133
|
name: str,
|
134
134
|
model_meta: model_meta_api.ModelMetadata,
|
135
135
|
model_blobs_dir_path: str,
|
136
|
-
**kwargs: Unpack[model_types.
|
136
|
+
**kwargs: Unpack[model_types.SKLModelLoadOptions],
|
137
137
|
) -> Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]:
|
138
138
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
139
139
|
model_blobs_metadata = model_meta.models
|
@@ -153,7 +153,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
153
153
|
cls,
|
154
154
|
raw_model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
|
155
155
|
model_meta: model_meta_api.ModelMetadata,
|
156
|
-
**kwargs: Unpack[model_types.
|
156
|
+
**kwargs: Unpack[model_types.SKLModelLoadOptions],
|
157
157
|
) -> custom_model.CustomModel:
|
158
158
|
from snowflake.ml.model import custom_model
|
159
159
|
|
@@ -127,7 +127,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
127
127
|
name: str,
|
128
128
|
model_meta: model_meta_api.ModelMetadata,
|
129
129
|
model_blobs_dir_path: str,
|
130
|
-
**kwargs: Unpack[model_types.
|
130
|
+
**kwargs: Unpack[model_types.SNOWModelLoadOptions],
|
131
131
|
) -> "BaseEstimator":
|
132
132
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
133
133
|
model_blobs_metadata = model_meta.models
|
@@ -146,7 +146,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
146
146
|
cls,
|
147
147
|
raw_model: "BaseEstimator",
|
148
148
|
model_meta: model_meta_api.ModelMetadata,
|
149
|
-
**kwargs: Unpack[model_types.
|
149
|
+
**kwargs: Unpack[model_types.SNOWModelLoadOptions],
|
150
150
|
) -> custom_model.CustomModel:
|
151
151
|
from snowflake.ml.model import custom_model
|
152
152
|
|
@@ -138,7 +138,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
138
138
|
name: str,
|
139
139
|
model_meta: model_meta_api.ModelMetadata,
|
140
140
|
model_blobs_dir_path: str,
|
141
|
-
**kwargs: Unpack[model_types.
|
141
|
+
**kwargs: Unpack[model_types.TensorflowLoadOptions],
|
142
142
|
) -> "tensorflow.Module":
|
143
143
|
import tensorflow
|
144
144
|
|
@@ -156,7 +156,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
156
156
|
cls,
|
157
157
|
raw_model: "tensorflow.Module",
|
158
158
|
model_meta: model_meta_api.ModelMetadata,
|
159
|
-
**kwargs: Unpack[model_types.
|
159
|
+
**kwargs: Unpack[model_types.TensorflowLoadOptions],
|
160
160
|
) -> custom_model.CustomModel:
|
161
161
|
import tensorflow
|
162
162
|
|
@@ -128,7 +128,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
|
|
128
128
|
name: str,
|
129
129
|
model_meta: model_meta_api.ModelMetadata,
|
130
130
|
model_blobs_dir_path: str,
|
131
|
-
**kwargs: Unpack[model_types.
|
131
|
+
**kwargs: Unpack[model_types.TorchScriptLoadOptions],
|
132
132
|
) -> "torch.jit.ScriptModule": # type:ignore[name-defined]
|
133
133
|
import torch
|
134
134
|
|
@@ -152,7 +152,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
|
|
152
152
|
cls,
|
153
153
|
raw_model: "torch.jit.ScriptModule", # type:ignore[name-defined]
|
154
154
|
model_meta: model_meta_api.ModelMetadata,
|
155
|
-
**kwargs: Unpack[model_types.
|
155
|
+
**kwargs: Unpack[model_types.TorchScriptLoadOptions],
|
156
156
|
) -> custom_model.CustomModel:
|
157
157
|
from snowflake.ml.model import custom_model
|
158
158
|
|
@@ -141,7 +141,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
141
141
|
name: str,
|
142
142
|
model_meta: model_meta_api.ModelMetadata,
|
143
143
|
model_blobs_dir_path: str,
|
144
|
-
**kwargs: Unpack[model_types.
|
144
|
+
**kwargs: Unpack[model_types.XGBModelLoadOptions],
|
145
145
|
) -> Union["xgboost.Booster", "xgboost.XGBModel"]:
|
146
146
|
import xgboost
|
147
147
|
|
@@ -175,7 +175,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
175
175
|
cls,
|
176
176
|
raw_model: Union["xgboost.Booster", "xgboost.XGBModel"],
|
177
177
|
model_meta: model_meta_api.ModelMetadata,
|
178
|
-
**kwargs: Unpack[model_types.
|
178
|
+
**kwargs: Unpack[model_types.XGBModelLoadOptions],
|
179
179
|
) -> custom_model.CustomModel:
|
180
180
|
import xgboost
|
181
181
|
|
@@ -23,6 +23,7 @@ class ModelBlobMeta:
|
|
23
23
|
self.model_type = kwargs["model_type"]
|
24
24
|
self.path = kwargs["path"]
|
25
25
|
self.handler_version = kwargs["handler_version"]
|
26
|
+
self.function_properties = kwargs.get("function_properties", {})
|
26
27
|
|
27
28
|
self.artifacts: Dict[str, str] = {}
|
28
29
|
artifacts = kwargs.get("artifacts", None)
|
@@ -39,6 +40,7 @@ class ModelBlobMeta:
|
|
39
40
|
model_type=self.model_type,
|
40
41
|
path=self.path,
|
41
42
|
handler_version=self.handler_version,
|
43
|
+
function_properties=self.function_properties,
|
42
44
|
artifacts=self.artifacts,
|
43
45
|
options=self.options,
|
44
46
|
)
|
@@ -7,11 +7,12 @@ import zipfile
|
|
7
7
|
from contextlib import contextmanager
|
8
8
|
from datetime import datetime
|
9
9
|
from types import ModuleType
|
10
|
-
from typing import Any, Dict, Generator, List, Optional
|
10
|
+
from typing import Any, Dict, Generator, List, Optional, TypedDict
|
11
11
|
|
12
12
|
import cloudpickle
|
13
13
|
import yaml
|
14
14
|
from packaging import requirements, version
|
15
|
+
from typing_extensions import Required
|
15
16
|
|
16
17
|
from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
|
17
18
|
from snowflake.ml.model import model_signature, type_hints as model_types
|
@@ -47,6 +48,7 @@ def create_model_metadata(
|
|
47
48
|
name: str,
|
48
49
|
model_type: model_types.SupportedModelHandlerType,
|
49
50
|
signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
|
51
|
+
function_properties: Optional[Dict[str, Dict[str, Any]]] = None,
|
50
52
|
metadata: Optional[Dict[str, str]] = None,
|
51
53
|
code_paths: Optional[List[str]] = None,
|
52
54
|
ext_modules: Optional[List[ModuleType]] = None,
|
@@ -64,6 +66,7 @@ def create_model_metadata(
|
|
64
66
|
model_type: Type of the model.
|
65
67
|
signatures: Signatures of the model. If None, it will be inferred after the model meta is created.
|
66
68
|
Defaults to None.
|
69
|
+
function_properties: Dict mapping function names to a dict of properties, mapping property key to value.
|
67
70
|
metadata: User provided key-value metadata of the model. Defaults to None.
|
68
71
|
code_paths: List of paths to additional codes that needs to be packed with. Defaults to None.
|
69
72
|
ext_modules: List of names of modules that need to be pickled with the model. Defaults to None.
|
@@ -127,6 +130,7 @@ def create_model_metadata(
|
|
127
130
|
metadata=metadata,
|
128
131
|
model_type=model_type,
|
129
132
|
signatures=signatures,
|
133
|
+
function_properties=function_properties,
|
130
134
|
)
|
131
135
|
|
132
136
|
code_dir_path = os.path.join(model_dir_path, MODEL_CODE_DIR)
|
@@ -215,6 +219,12 @@ def load_code_path(model_dir_path: str) -> None:
|
|
215
219
|
sys.path.insert(0, code_path)
|
216
220
|
|
217
221
|
|
222
|
+
class ModelMetadataTelemetryDict(TypedDict):
|
223
|
+
model_name: Required[str]
|
224
|
+
framework_type: Required[model_types.SupportedModelHandlerType]
|
225
|
+
number_of_functions: Required[int]
|
226
|
+
|
227
|
+
|
218
228
|
class ModelMetadata:
|
219
229
|
"""Model metadata for Snowflake native model packaged model.
|
220
230
|
|
@@ -224,10 +234,18 @@ class ModelMetadata:
|
|
224
234
|
env: ModelEnv object containing all environment related object
|
225
235
|
models: Dict of model blob metadata
|
226
236
|
signatures: A dict mapping from target function name to input and output signatures.
|
237
|
+
function_properties: A dict mapping function names to dict mapping function property key to value.
|
227
238
|
metadata: User provided key-value metadata of the model. Defaults to None.
|
228
239
|
creation_timestamp: Unix timestamp when the model metadata is created.
|
229
240
|
"""
|
230
241
|
|
242
|
+
def telemetry_metadata(self) -> ModelMetadataTelemetryDict:
|
243
|
+
return ModelMetadataTelemetryDict(
|
244
|
+
model_name=self.name,
|
245
|
+
framework_type=self.model_type,
|
246
|
+
number_of_functions=len(self.signatures.keys()),
|
247
|
+
)
|
248
|
+
|
231
249
|
def __init__(
|
232
250
|
self,
|
233
251
|
*,
|
@@ -236,6 +254,7 @@ class ModelMetadata:
|
|
236
254
|
model_type: model_types.SupportedModelHandlerType,
|
237
255
|
runtimes: Optional[Dict[str, model_runtime.ModelRuntime]] = None,
|
238
256
|
signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
|
257
|
+
function_properties: Optional[Dict[str, Dict[str, Any]]] = None,
|
239
258
|
metadata: Optional[Dict[str, str]] = None,
|
240
259
|
creation_timestamp: Optional[str] = None,
|
241
260
|
min_snowpark_ml_version: Optional[str] = None,
|
@@ -246,6 +265,7 @@ class ModelMetadata:
|
|
246
265
|
self.signatures: Dict[str, model_signature.ModelSignature] = dict()
|
247
266
|
if signatures:
|
248
267
|
self.signatures = signatures
|
268
|
+
self.function_properties = function_properties or {}
|
249
269
|
self.metadata = metadata
|
250
270
|
self.model_type = model_type
|
251
271
|
self.env = env
|
@@ -1,6 +1,6 @@
|
|
1
1
|
# This files contains schema definition of what will be written into model.yml
|
2
2
|
# Changing this file should lead to a change of the schema version.
|
3
|
-
|
3
|
+
from enum import Enum
|
4
4
|
from typing import Any, Dict, List, Optional, TypedDict, Union
|
5
5
|
|
6
6
|
from typing_extensions import NotRequired, Required
|
@@ -11,6 +11,10 @@ MODEL_METADATA_VERSION = "2023-12-01"
|
|
11
11
|
MODEL_METADATA_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
12
12
|
|
13
13
|
|
14
|
+
class FunctionProperties(Enum):
|
15
|
+
PARTITIONED = "PARTITIONED"
|
16
|
+
|
17
|
+
|
14
18
|
class ModelRuntimeDependenciesDict(TypedDict):
|
15
19
|
conda: Required[str]
|
16
20
|
pip: Required[str]
|
@@ -72,6 +76,7 @@ class ModelBlobMetadataDict(TypedDict):
|
|
72
76
|
model_type: Required[type_hints.SupportedModelHandlerType]
|
73
77
|
path: Required[str]
|
74
78
|
handler_version: Required[str]
|
79
|
+
function_properties: NotRequired[Dict[str, Dict[str, Any]]]
|
75
80
|
artifacts: NotRequired[Dict[str, str]]
|
76
81
|
options: NotRequired[ModelBlobOptions]
|
77
82
|
|
@@ -47,12 +47,12 @@ class ModelPackager:
|
|
47
47
|
ext_modules: Optional[List[ModuleType]] = None,
|
48
48
|
code_paths: Optional[List[str]] = None,
|
49
49
|
options: Optional[model_types.ModelSaveOption] = None,
|
50
|
-
) ->
|
50
|
+
) -> model_meta.ModelMetadata:
|
51
51
|
if (signatures is None) and (sample_input_data is None) and not model_handler.is_auto_signature_model(model):
|
52
52
|
raise snowml_exceptions.SnowflakeMLException(
|
53
53
|
error_code=error_codes.INVALID_ARGUMENT,
|
54
54
|
original_exception=ValueError(
|
55
|
-
"
|
55
|
+
"Either of `signatures` or `sample_input_data` must be provided for this kind of model."
|
56
56
|
),
|
57
57
|
)
|
58
58
|
|
@@ -103,6 +103,7 @@ class ModelPackager:
|
|
103
103
|
|
104
104
|
self.model = model
|
105
105
|
self.meta = meta
|
106
|
+
return meta
|
106
107
|
|
107
108
|
def load(
|
108
109
|
self,
|
@@ -110,7 +111,7 @@ class ModelPackager:
|
|
110
111
|
meta_only: bool = False,
|
111
112
|
as_custom_model: bool = False,
|
112
113
|
options: Optional[model_types.ModelLoadOption] = None,
|
113
|
-
) ->
|
114
|
+
) -> model_meta.ModelMetadata:
|
114
115
|
"""Load the model into memory from directory. Used internal only.
|
115
116
|
|
116
117
|
Args:
|
@@ -120,11 +121,14 @@ class ModelPackager:
|
|
120
121
|
|
121
122
|
Raises:
|
122
123
|
SnowflakeMLException: Raised if model is not native format.
|
124
|
+
|
125
|
+
Returns:
|
126
|
+
Metadata of loaded model.
|
123
127
|
"""
|
124
128
|
|
125
129
|
self.meta = model_meta.ModelMetadata.load(self.local_dir_path)
|
126
130
|
if meta_only:
|
127
|
-
return
|
131
|
+
return self.meta
|
128
132
|
|
129
133
|
model_meta.load_code_path(self.local_dir_path)
|
130
134
|
|
@@ -146,3 +150,4 @@ class ModelPackager:
|
|
146
150
|
assert isinstance(m, custom_model.CustomModel)
|
147
151
|
|
148
152
|
self.model = m
|
153
|
+
return self.meta
|
@@ -1,3 +1,4 @@
|
|
1
|
+
import datetime
|
1
2
|
from collections import abc
|
2
3
|
from typing import Literal, Sequence
|
3
4
|
|
@@ -24,7 +25,7 @@ class ListOfBuiltinHandler(base_handler.BaseDataHandler[model_types._SupportedBu
|
|
24
25
|
# String is a Sequence but we take them as an whole
|
25
26
|
if isinstance(element, abc.Sequence) and not isinstance(element, str):
|
26
27
|
can_handle = ListOfBuiltinHandler.can_handle(element)
|
27
|
-
elif not isinstance(element, (int, float, bool, str)):
|
28
|
+
elif not isinstance(element, (int, float, bool, str, datetime.datetime)):
|
28
29
|
can_handle = False
|
29
30
|
break
|
30
31
|
return can_handle
|
@@ -53,6 +53,8 @@ class DataType(Enum):
|
|
53
53
|
STRING = ("string", spt.StringType, np.str_)
|
54
54
|
BYTES = ("bytes", spt.BinaryType, np.bytes_)
|
55
55
|
|
56
|
+
TIMESTAMP_NTZ = ("datetime64[ns]", spt.TimestampType, "datetime64[ns]")
|
57
|
+
|
56
58
|
def as_snowpark_type(self) -> spt.DataType:
|
57
59
|
"""Convert to corresponding Snowpark Type.
|
58
60
|
|
@@ -78,6 +80,13 @@ class DataType(Enum):
|
|
78
80
|
Corresponding DataType.
|
79
81
|
"""
|
80
82
|
np_to_snowml_type_mapping = {i._numpy_type: i for i in DataType}
|
83
|
+
|
84
|
+
# Add datetime types:
|
85
|
+
datetime_res = ["Y", "M", "W", "D", "h", "m", "s", "ms", "us", "ns"]
|
86
|
+
|
87
|
+
for res in datetime_res:
|
88
|
+
np_to_snowml_type_mapping[f"datetime64[{res}]"] = DataType.TIMESTAMP_NTZ
|
89
|
+
|
81
90
|
for potential_type in np_to_snowml_type_mapping.keys():
|
82
91
|
if np.can_cast(np_type, potential_type, casting="no"):
|
83
92
|
# This is used since the same dtype might represented in different ways.
|
@@ -247,9 +256,12 @@ class FeatureSpec(BaseFeatureSpec):
|
|
247
256
|
result_type = spt.ArrayType(result_type)
|
248
257
|
return result_type
|
249
258
|
|
250
|
-
def as_dtype(self) -> npt.DTypeLike:
|
259
|
+
def as_dtype(self) -> Union[npt.DTypeLike, str]:
|
251
260
|
"""Convert to corresponding local Type."""
|
252
261
|
if not self._shape:
|
262
|
+
# scalar dtype: use keys from `np.sctypeDict` to prevent unit-less dtype 'datetime64'
|
263
|
+
if "datetime64" in self._dtype._value:
|
264
|
+
return self._dtype._value
|
253
265
|
return self._dtype._numpy_type
|
254
266
|
return np.object_
|
255
267
|
|
@@ -147,6 +147,8 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
|
|
147
147
|
specs.append(core.FeatureSpec(dtype=core.DataType.STRING, name=ft_name))
|
148
148
|
elif isinstance(data[df_col].iloc[0], bytes):
|
149
149
|
specs.append(core.FeatureSpec(dtype=core.DataType.BYTES, name=ft_name))
|
150
|
+
elif isinstance(data[df_col].iloc[0], np.datetime64):
|
151
|
+
specs.append(core.FeatureSpec(dtype=core.DataType.TIMESTAMP_NTZ, name=ft_name))
|
150
152
|
else:
|
151
153
|
specs.append(core.FeatureSpec(dtype=core.DataType.from_numpy_type(df_col_dtype), name=ft_name))
|
152
154
|
return specs
|
@@ -107,6 +107,9 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
|
|
107
107
|
if not features:
|
108
108
|
features = pandas_handler.PandasDataFrameHandler.infer_signature(df, role="input")
|
109
109
|
# Role will be no effect on the column index. That is to say, the feature name is the actual column name.
|
110
|
+
if keep_order:
|
111
|
+
df = df.reset_index(drop=True)
|
112
|
+
df[infer_template._KEEP_ORDER_COL_NAME] = df.index
|
110
113
|
sp_df = session.create_dataframe(df)
|
111
114
|
column_names = []
|
112
115
|
columns = []
|
@@ -122,7 +125,4 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
|
|
122
125
|
|
123
126
|
sp_df = sp_df.with_columns(column_names, columns)
|
124
127
|
|
125
|
-
if keep_order:
|
126
|
-
sp_df = sp_df.with_column(infer_template._KEEP_ORDER_COL_NAME, F.monotonically_increasing_id())
|
127
|
-
|
128
128
|
return sp_df
|
@@ -1,6 +1,6 @@
|
|
1
1
|
import functools
|
2
2
|
import inspect
|
3
|
-
from typing import Any, Callable, Coroutine, Dict, Generator, Optional
|
3
|
+
from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional
|
4
4
|
|
5
5
|
import anyio
|
6
6
|
import pandas as pd
|
@@ -168,7 +168,7 @@ class CustomModel:
|
|
168
168
|
def _get_infer_methods(
|
169
169
|
self,
|
170
170
|
) -> Generator[Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame], None, None]:
|
171
|
-
"""Returns all methods in CLS with
|
171
|
+
"""Returns all methods in CLS with `inference_api` decorator as the outermost decorator."""
|
172
172
|
for cls_method_str in dir(self):
|
173
173
|
cls_method = getattr(self, cls_method_str)
|
174
174
|
if getattr(cls_method, "_is_inference_api", False):
|
@@ -177,6 +177,18 @@ class CustomModel:
|
|
177
177
|
else:
|
178
178
|
raise TypeError("A non-method inference API function is not supported.")
|
179
179
|
|
180
|
+
def _get_partitioned_infer_methods(self) -> List[str]:
|
181
|
+
"""Returns all methods in CLS with `partitioned_inference_api` as the outermost decorator."""
|
182
|
+
rv = []
|
183
|
+
for cls_method_str in dir(self):
|
184
|
+
cls_method = getattr(self, cls_method_str)
|
185
|
+
if getattr(cls_method, "_is_partitioned_inference_api", False):
|
186
|
+
if inspect.ismethod(cls_method):
|
187
|
+
rv.append(cls_method_str)
|
188
|
+
else:
|
189
|
+
raise TypeError("A non-method inference API function is not supported.")
|
190
|
+
return rv
|
191
|
+
|
180
192
|
|
181
193
|
def _validate_predict_function(func: Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]) -> None:
|
182
194
|
"""Validate the user provided predict method.
|
@@ -219,3 +231,11 @@ def inference_api(
|
|
219
231
|
) -> Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]:
|
220
232
|
func.__dict__["_is_inference_api"] = True
|
221
233
|
return func
|
234
|
+
|
235
|
+
|
236
|
+
def partitioned_inference_api(
|
237
|
+
func: Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]
|
238
|
+
) -> Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]:
|
239
|
+
func.__dict__["_is_inference_api"] = True
|
240
|
+
func.__dict__["_is_partitioned_inference_api"] = True
|
241
|
+
return func
|
@@ -168,6 +168,8 @@ def _validate_numpy_array(
|
|
168
168
|
max_v <= np.finfo(feature_type._numpy_type).max # type: ignore[arg-type]
|
169
169
|
and min_v >= np.finfo(feature_type._numpy_type).min # type: ignore[arg-type]
|
170
170
|
)
|
171
|
+
elif feature_type in [core.DataType.TIMESTAMP_NTZ]:
|
172
|
+
return np.issubdtype(arr.dtype, np.datetime64)
|
171
173
|
else:
|
172
174
|
return np.can_cast(arr.dtype, feature_type._numpy_type, casting="no")
|
173
175
|
|