snowflake-ml-python 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +2 -1
- snowflake/ml/_internal/file_utils.py +35 -40
- snowflake/ml/_internal/telemetry.py +5 -8
- snowflake/ml/_internal/utils/identifier.py +74 -7
- snowflake/ml/_internal/utils/uri.py +7 -2
- snowflake/ml/model/_core_requirements.py +1 -1
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +15 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +259 -0
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +89 -0
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +24 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +40 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +199 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +88 -0
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +24 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +47 -0
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +178 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +25 -28
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +7 -4
- snowflake/ml/model/_deployer.py +14 -27
- snowflake/ml/model/_env.py +4 -4
- snowflake/ml/model/_handlers/_base.py +3 -1
- snowflake/ml/model/_handlers/custom.py +14 -2
- snowflake/ml/model/_handlers/pytorch.py +186 -0
- snowflake/ml/model/_handlers/sklearn.py +14 -8
- snowflake/ml/model/_handlers/snowmlmodel.py +14 -9
- snowflake/ml/model/_handlers/torchscript.py +180 -0
- snowflake/ml/model/_handlers/xgboost.py +19 -9
- snowflake/ml/model/_model.py +27 -21
- snowflake/ml/model/_model_meta.py +33 -19
- snowflake/ml/model/model_signature.py +446 -66
- snowflake/ml/model/type_hints.py +28 -15
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +79 -43
- snowflake/ml/modeling/cluster/affinity_propagation.py +79 -43
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +79 -43
- snowflake/ml/modeling/cluster/birch.py +79 -43
- snowflake/ml/modeling/cluster/bisecting_k_means.py +79 -43
- snowflake/ml/modeling/cluster/dbscan.py +79 -43
- snowflake/ml/modeling/cluster/feature_agglomeration.py +79 -43
- snowflake/ml/modeling/cluster/k_means.py +79 -43
- snowflake/ml/modeling/cluster/mean_shift.py +79 -43
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +79 -43
- snowflake/ml/modeling/cluster/optics.py +79 -43
- snowflake/ml/modeling/cluster/spectral_biclustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_clustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_coclustering.py +79 -43
- snowflake/ml/modeling/compose/column_transformer.py +79 -43
- snowflake/ml/modeling/compose/transformed_target_regressor.py +79 -43
- snowflake/ml/modeling/covariance/elliptic_envelope.py +79 -43
- snowflake/ml/modeling/covariance/empirical_covariance.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +79 -43
- snowflake/ml/modeling/covariance/ledoit_wolf.py +79 -43
- snowflake/ml/modeling/covariance/min_cov_det.py +79 -43
- snowflake/ml/modeling/covariance/oas.py +79 -43
- snowflake/ml/modeling/covariance/shrunk_covariance.py +79 -43
- snowflake/ml/modeling/decomposition/dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/factor_analysis.py +79 -43
- snowflake/ml/modeling/decomposition/fast_ica.py +79 -43
- snowflake/ml/modeling/decomposition/incremental_pca.py +79 -43
- snowflake/ml/modeling/decomposition/kernel_pca.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/pca.py +79 -43
- snowflake/ml/modeling/decomposition/sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/truncated_svd.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/isolation_forest.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/stacking_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/voting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/voting_regressor.py +79 -43
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fdr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fpr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fwe.py +79 -43
- snowflake/ml/modeling/feature_selection/select_k_best.py +79 -43
- snowflake/ml/modeling/feature_selection/select_percentile.py +79 -43
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +79 -43
- snowflake/ml/modeling/feature_selection/variance_threshold.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +79 -43
- snowflake/ml/modeling/impute/iterative_imputer.py +79 -43
- snowflake/ml/modeling/impute/knn_imputer.py +79 -43
- snowflake/ml/modeling/impute/missing_indicator.py +79 -43
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/nystroem.py +79 -43
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +79 -43
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ard_regression.py +79 -43
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/gamma_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/huber_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/lars.py +79 -43
- snowflake/ml/modeling/linear_model/lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +79 -43
- snowflake/ml/modeling/linear_model/linear_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/perceptron.py +79 -43
- snowflake/ml/modeling/linear_model/poisson_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ransac_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ridge.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_cv.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +79 -43
- snowflake/ml/modeling/manifold/isomap.py +79 -43
- snowflake/ml/modeling/manifold/mds.py +79 -43
- snowflake/ml/modeling/manifold/spectral_embedding.py +79 -43
- snowflake/ml/modeling/manifold/tsne.py +79 -43
- snowflake/ml/modeling/metrics/classification.py +6 -1
- snowflake/ml/modeling/metrics/regression.py +517 -9
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +79 -43
- snowflake/ml/modeling/mixture/gaussian_mixture.py +79 -43
- snowflake/ml/modeling/model_selection/grid_search_cv.py +79 -43
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/output_code_classifier.py +79 -43
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/complement_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neighbors/kernel_density.py +79 -43
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_centroid.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +79 -43
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_classifier.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_regressor.py +79 -43
- snowflake/ml/modeling/pipeline/pipeline.py +24 -0
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +18 -19
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +2 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_propagation.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_spreading.py +79 -43
- snowflake/ml/modeling/svm/linear_svc.py +79 -43
- snowflake/ml/modeling/svm/linear_svr.py +79 -43
- snowflake/ml/modeling/svm/nu_svc.py +79 -43
- snowflake/ml/modeling/svm/nu_svr.py +79 -43
- snowflake/ml/modeling/svm/svc.py +79 -43
- snowflake/ml/modeling/svm/svr.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_regressor.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +79 -43
- snowflake/ml/registry/model_registry.py +123 -121
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/METADATA +50 -8
- snowflake_ml_python-1.0.3.dist-info/RECORD +259 -0
- snowflake_ml_python-1.0.1.dist-info/RECORD +0 -246
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/WHEEL +0 -0
--- snowflake/ml/model/model_signature.py (1.0.1)
+++ snowflake/ml/model/model_signature.py (1.0.3)
@@ -1,8 +1,10 @@
+import json
 import textwrap
 import warnings
 from abc import ABC, abstractmethod
 from enum import Enum
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     Dict,
@@ -26,8 +28,14 @@ from typing_extensions import TypeGuard
 
 import snowflake.snowpark
 import snowflake.snowpark.types as spt
+from snowflake.ml._internal import type_utils
 from snowflake.ml._internal.utils import formatting, identifier
 from snowflake.ml.model import type_hints as model_types
+from snowflake.ml.model._deploy_client.warehouse import infer_template
+
+if TYPE_CHECKING:
+    import tensorflow
+    import torch
 
 
 class DataType(Enum):
@@ -36,22 +44,22 @@ class DataType(Enum):
         self._snowpark_type = snowpark_type
         self._numpy_type = numpy_type
 
-    INT8 = ("int8", spt.
-    INT16 = ("int16", spt.
+    INT8 = ("int8", spt.ByteType, np.int8)
+    INT16 = ("int16", spt.ShortType, np.int16)
     INT32 = ("int32", spt.IntegerType, np.int32)
-    INT64 = ("int64", spt.
+    INT64 = ("int64", spt.LongType, np.int64)
 
     FLOAT = ("float", spt.FloatType, np.float32)
     DOUBLE = ("double", spt.DoubleType, np.float64)
 
-    UINT8 = ("uint8", spt.
-    UINT16 = ("uint16", spt.
+    UINT8 = ("uint8", spt.ByteType, np.uint8)
+    UINT16 = ("uint16", spt.ShortType, np.uint16)
     UINT32 = ("uint32", spt.IntegerType, np.uint32)
-    UINT64 = ("uint64", spt.
+    UINT64 = ("uint64", spt.LongType, np.uint64)
 
-    BOOL = ("bool", spt.BooleanType, np.
-    STRING = ("string", spt.StringType, np.
-    BYTES = ("bytes", spt.BinaryType, np.
+    BOOL = ("bool", spt.BooleanType, np.bool_)
+    STRING = ("string", spt.StringType, np.str_)
+    BYTES = ("bytes", spt.BinaryType, np.bytes_)
 
     def as_snowpark_type(self) -> spt.DataType:
         """Convert to corresponding Snowpark Type.
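
The remapped enum now resolves each integer width to its exact Snowpark type (ByteType/ShortType/LongType) rather than a catch-all. A minimal sketch of the resulting round trip, assuming these names are importable from `snowflake.ml.model.model_signature` as this diff shows:

    import numpy as np
    from snowflake.ml.model.model_signature import DataType

    dt = DataType.from_numpy_type(np.dtype(np.int8))
    print(dt)                     # DataType.INT8
    print(dt.as_snowpark_type())  # ByteType() after this change, not a generic integer type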
@@ -84,6 +92,30 @@ class DataType(Enum):
             return np_to_snowml_type_mapping[potential_type]
         raise NotImplementedError(f"Type {np_type} is not supported as a DataType.")
 
+    @classmethod
+    def from_torch_type(cls, torch_type: "torch.dtype") -> "DataType":
+        import torch
+
+        """Translate torch dtype to DataType for signature definition.
+
+        Args:
+            torch_type: The torch dtype.
+
+        Returns:
+            Corresponding DataType.
+        """
+        torch_dtype_to_numpy_dtype_mapping = {
+            torch.uint8: np.uint8,
+            torch.int8: np.int8,
+            torch.int16: np.int16,
+            torch.int32: np.int32,
+            torch.int64: np.int64,
+            torch.float32: np.float32,
+            torch.float64: np.float64,
+            torch.bool: np.bool_,
+        }
+        return cls.from_numpy_type(torch_dtype_to_numpy_dtype_mapping[torch_type])
+
     @classmethod
     def from_snowpark_type(cls, snowpark_type: spt.DataType) -> "DataType":
         """Translate snowpark type to DataType for signature definition.
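
The new `from_torch_type` classmethod simply bridges torch dtypes through the numpy mapping above. A usage sketch (assumes `torch` is installed alongside the package; names as in this diff):

    import torch
    from snowflake.ml.model.model_signature import DataType

    assert DataType.from_torch_type(torch.float32) == DataType.FLOAT
    assert DataType.from_torch_type(torch.int64) == DataType.INT64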
@@ -97,30 +129,45 @@ class DataType(Enum):
         Returns:
             Corresponding DataType.
         """
+        if isinstance(snowpark_type, spt.ArrayType):
+            actual_sp_type = snowpark_type.element_type
+        else:
+            actual_sp_type = snowpark_type
+
         snowpark_to_snowml_type_mapping: Dict[Type[spt.DataType], DataType] = {
-
-
+            i._snowpark_type: i
+            for i in DataType
+            # We by default infer as signed integer.
+            if i not in [DataType.UINT8, DataType.UINT16, DataType.UINT32, DataType.UINT64]
         }
         for potential_type in snowpark_to_snowml_type_mapping.keys():
-            if isinstance(
+            if isinstance(actual_sp_type, potential_type):
                 return snowpark_to_snowml_type_mapping[potential_type]
+        # Fallback for decimal type.
+        if isinstance(snowpark_type, spt.DecimalType):
+            if snowpark_type.scale == 0:
+                return DataType.INT64
         raise NotImplementedError(f"Type {snowpark_type} is not supported as a DataType.")
 
     def is_same_snowpark_type(self, incoming_snowpark_type: spt.DataType) -> bool:
         """Check if provided snowpark type is the same as Data Type.
-        Since for Snowflake all integer types are same, thus when datatype is a integer type, the incoming snowpark
-        type can be any type inherit from _IntegralType.
 
         Args:
             incoming_snowpark_type: The snowpark type.
 
+        Raises:
+            NotImplementedError: Raised when the given numpy type is not supported.
+
         Returns:
             If the provided snowpark type is the same as the DataType.
         """
-
-
-
-
+        # Special handle for Decimal Type.
+        if isinstance(incoming_snowpark_type, spt.DecimalType):
+            if incoming_snowpark_type.scale == 0:
+                return self == DataType.INT64 or self == DataType.UINT64
+            raise NotImplementedError(f"Type {incoming_snowpark_type} is not supported as a DataType.")
+
+        return isinstance(incoming_snowpark_type, self._snowpark_type)
 
 
 class BaseFeatureSpec(ABC):
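
With the new fallback, scale-0 NUMBER columns (Snowpark `DecimalType`) are accepted as 64-bit integers instead of being rejected. A sketch of the behavior, under the same import assumption as above:

    import snowflake.snowpark.types as spt
    from snowflake.ml.model.model_signature import DataType

    assert DataType.from_snowpark_type(spt.DecimalType(38, 0)) == DataType.INT64
    assert DataType.INT64.is_same_snowpark_type(spt.DecimalType(10, 0))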
@@ -174,9 +221,19 @@ class FeatureSpec(BaseFeatureSpec):
                 (2,): 1d list with fixed len of 2.
                 (-1,): 1d list with variable length. Used for ragged tensor representation.
                 (d1, d2, d3): 3d tensor.
+
+        Raises:
+            TypeError: Raised when the dtype input type is incorrect.
+            TypeError: Raised when the shape input type is incorrect.
         """
         super().__init__(name=name)
+
+        if not isinstance(dtype, DataType):
+            raise TypeError("dtype should be a model signature datatype.")
         self._dtype = dtype
+
+        if shape and not isinstance(shape, tuple):
+            raise TypeError("Shape should be a tuple if presented.")
         self._shape = shape
 
     def as_snowpark_type(self) -> spt.DataType:
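
FeatureSpec now rejects malformed arguments up front instead of failing later during serialization. For illustration only (names as in this diff):

    from snowflake.ml.model.model_signature import DataType, FeatureSpec

    ok = FeatureSpec(name="embedding", dtype=DataType.FLOAT, shape=(128,))
    try:
        FeatureSpec(name="bad", dtype="float32")  # no longer silently accepted
    except TypeError as e:
        print(e)  # dtype should be a model signature datatype.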
@@ -191,7 +248,7 @@ class FeatureSpec(BaseFeatureSpec):
         """Convert to corresponding local Type."""
         if not self._shape:
             return self._dtype._numpy_type
-        return np.
+        return np.object_
 
     def __eq__(self, other: object) -> bool:
         if isinstance(other, FeatureSpec):
@@ -229,6 +286,8 @@ class FeatureSpec(BaseFeatureSpec):
         """
         name = input_dict["name"]
         shape = input_dict.get("shape", None)
+        if shape:
+            shape = tuple(shape)
         type = DataType[input_dict["type"]]
         return FeatureSpec(name=name, dtype=type, shape=shape)
 
@@ -421,7 +480,7 @@ class _BaseDataHandler(ABC, Generic[model_types._DataType]):
 
     @staticmethod
     @abstractmethod
-    def convert_to_df(data: model_types._DataType) ->
+    def convert_to_df(data: model_types._DataType, ensure_serializable: bool = True) -> pd.DataFrame:
         ...
 
 
@@ -454,7 +513,7 @@ class _PandasDataFrameHandler(_BaseDataHandler[pd.DataFrame]):
             np.int64,
             np.uint64,
             np.float64,
-            np.
+            np.object_,
         ]:  # To keep compatibility with Pandas 2.x and 1.x
             raise ValueError("Data Validation Error: Unsupported column index type is found.")
 
@@ -538,7 +597,17 @@ class _PandasDataFrameHandler(_BaseDataHandler[pd.DataFrame]):
         return specs
 
     @staticmethod
-    def convert_to_df(data: pd.DataFrame) -> pd.DataFrame:
+    def convert_to_df(data: pd.DataFrame, ensure_serializable: bool = True) -> pd.DataFrame:
+        if not ensure_serializable:
+            return data
+        # This convert is necessary since numpy dataframe cannot be correctly handled when provided as an element of
+        # a list when creating Snowpark Dataframe.
+        df_cols = data.columns
+        df_col_dtypes = [data[col].dtype for col in data.columns]
+        for df_col, df_col_dtype in zip(df_cols, df_col_dtypes):
+            if df_col_dtype == np.dtype("O"):
+                if isinstance(data[df_col][0], np.ndarray):
+                    data[df_col] = data[df_col].map(np.ndarray.tolist)
         return data
 
 
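
The `ensure_serializable` flag added across the handlers controls whether object cells holding numpy arrays are flattened to plain lists, which is needed before the data can feed a Snowpark DataFrame or UDF. A sketch against the pandas handler (a private class, shown only to clarify the flag):

    import numpy as np
    import pandas as pd
    from snowflake.ml.model.model_signature import _PandasDataFrameHandler

    df = pd.DataFrame({"a": [np.array([1, 2]), np.array([3, 4])]})
    out = _PandasDataFrameHandler.convert_to_df(df)  # default ensure_serializable=True
    print(type(out["a"][0]))  # <class 'list'>; with ensure_serializable=False the ndarray survives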
@@ -569,7 +638,7 @@ class _NumpyArrayHandler(_BaseDataHandler[model_types._SupportedNumpyArray]):
     def infer_signature(
         data: model_types._SupportedNumpyArray, role: Literal["input", "output"]
     ) -> Sequence[BaseFeatureSpec]:
-        feature_prefix = f"{
+        feature_prefix = f"{_NumpyArrayHandler.FEATURE_PREFIX}_"
         dtype = DataType.from_numpy_type(data.dtype)
         role_prefix = (_NumpyArrayHandler.INPUT_PREFIX if role == "input" else _NumpyArrayHandler.OUTPUT_PREFIX) + "_"
         if len(data.shape) == 1:
@@ -588,68 +657,269 @@ class _NumpyArrayHandler(_BaseDataHandler[model_types._SupportedNumpyArray]):
         return features
 
     @staticmethod
-    def convert_to_df(data: model_types._SupportedNumpyArray) -> pd.DataFrame:
+    def convert_to_df(data: model_types._SupportedNumpyArray, ensure_serializable: bool = True) -> pd.DataFrame:
         if len(data.shape) == 1:
             data = np.expand_dims(data, axis=1)
         n_cols = data.shape[1]
         if len(data.shape) == 2:
-            return pd.DataFrame(data
+            return pd.DataFrame(data)
         else:
             n_rows = data.shape[0]
-
+            if ensure_serializable:
+                return pd.DataFrame(data={i: [data[k, i].tolist() for k in range(n_rows)] for i in range(n_cols)})
+            return pd.DataFrame(data={i: [list(data[k, i]) for k in range(n_rows)] for i in range(n_cols)})
 
 
-class
+class _SeqOfNumpyArrayHandler(_BaseDataHandler[Sequence[model_types._SupportedNumpyArray]]):
     @staticmethod
-    def can_handle(data: model_types.SupportedDataType) -> TypeGuard[
-
-
-
-
-    )
+    def can_handle(data: model_types.SupportedDataType) -> TypeGuard[Sequence[model_types._SupportedNumpyArray]]:
+        if not isinstance(data, list):
+            return False
+        if len(data) == 0:
+            return False
+        if isinstance(data[0], np.ndarray):
+            return all(isinstance(data_col, np.ndarray) for data_col in data)
+        return False
 
     @staticmethod
-    def count(data:
+    def count(data: Sequence[model_types._SupportedNumpyArray]) -> int:
         return min(_NumpyArrayHandler.count(data_col) for data_col in data)
 
     @staticmethod
-    def truncate(data:
+    def truncate(data: Sequence[model_types._SupportedNumpyArray]) -> Sequence[model_types._SupportedNumpyArray]:
         return [
-            data_col[: min(
+            data_col[: min(_SeqOfNumpyArrayHandler.count(data), _SeqOfNumpyArrayHandler.SIG_INFER_ROWS_COUNT_LIMIT)]
             for data_col in data
         ]
 
     @staticmethod
-    def validate(data:
+    def validate(data: Sequence[model_types._SupportedNumpyArray]) -> None:
         for data_col in data:
             _NumpyArrayHandler.validate(data_col)
 
     @staticmethod
     def infer_signature(
-        data:
+        data: Sequence[model_types._SupportedNumpyArray], role: Literal["input", "output"]
     ) -> Sequence[BaseFeatureSpec]:
+        feature_prefix = f"{_SeqOfNumpyArrayHandler.FEATURE_PREFIX}_"
         features: List[BaseFeatureSpec] = []
         role_prefix = (
-
+            _SeqOfNumpyArrayHandler.INPUT_PREFIX if role == "input" else _SeqOfNumpyArrayHandler.OUTPUT_PREFIX
         ) + "_"
 
         for i, data_col in enumerate(data):
-
-
-
-
+            dtype = DataType.from_numpy_type(data_col.dtype)
+            ft_name = f"{role_prefix}{feature_prefix}{i}"
+            if len(data_col.shape) == 1:
+                features.append(FeatureSpec(dtype=dtype, name=ft_name))
+            else:
+                ft_shape = tuple(data_col.shape[1:])
+                features.append(FeatureSpec(dtype=dtype, name=ft_name, shape=ft_shape))
         return features
 
     @staticmethod
-    def convert_to_df(
-
+    def convert_to_df(
+        data: Sequence[model_types._SupportedNumpyArray], ensure_serializable: bool = True
+    ) -> pd.DataFrame:
+        if ensure_serializable:
+            return pd.DataFrame(data={i: data_col.tolist() for i, data_col in enumerate(data)})
+        return pd.DataFrame(data={i: list(data_col) for i, data_col in enumerate(data)})
+
+
+class _SeqOfPyTorchTensorHandler(_BaseDataHandler[Sequence["torch.Tensor"]]):
+    @staticmethod
+    def can_handle(data: model_types.SupportedDataType) -> TypeGuard[Sequence["torch.Tensor"]]:
+        if not isinstance(data, list):
+            return False
+        if len(data) == 0:
+            return False
+        if type_utils.LazyType("torch.Tensor").isinstance(data[0]):
+            return all(type_utils.LazyType("torch.Tensor").isinstance(data_col) for data_col in data)
+        return False
+
+    @staticmethod
+    def count(data: Sequence["torch.Tensor"]) -> int:
+        return min(data_col.shape[0] for data_col in data)
+
+    @staticmethod
+    def truncate(data: Sequence["torch.Tensor"]) -> Sequence["torch.Tensor"]:
+        return [
+            data_col[
+                : min(_SeqOfPyTorchTensorHandler.count(data), _SeqOfPyTorchTensorHandler.SIG_INFER_ROWS_COUNT_LIMIT)
+            ]
+            for data_col in data
+        ]
+
+    @staticmethod
+    def validate(data: Sequence["torch.Tensor"]) -> None:
+        import torch
+
+        for data_col in data:
+            if data_col.shape == torch.Size([0]):
+                # Empty array
+                raise ValueError("Data Validation Error: Empty data is found.")
+
+            if data_col.shape == torch.Size([1]):
+                # scalar
+                raise ValueError("Data Validation Error: Scalar data is found.")
+
+    @staticmethod
+    def infer_signature(data: Sequence["torch.Tensor"], role: Literal["input", "output"]) -> Sequence[BaseFeatureSpec]:
+        feature_prefix = f"{_SeqOfPyTorchTensorHandler.FEATURE_PREFIX}_"
+        features: List[BaseFeatureSpec] = []
+        role_prefix = (
+            _SeqOfPyTorchTensorHandler.INPUT_PREFIX if role == "input" else _SeqOfPyTorchTensorHandler.OUTPUT_PREFIX
+        ) + "_"
+
+        for i, data_col in enumerate(data):
+            dtype = DataType.from_torch_type(data_col.dtype)
+            ft_name = f"{role_prefix}{feature_prefix}{i}"
+            if len(data_col.shape) == 1:
+                features.append(FeatureSpec(dtype=dtype, name=ft_name))
+            else:
+                ft_shape = tuple(data_col.shape[1:])
+                features.append(FeatureSpec(dtype=dtype, name=ft_name, shape=ft_shape))
+        return features
+
+    @staticmethod
+    def convert_to_df(data: Sequence["torch.Tensor"], ensure_serializable: bool = True) -> pd.DataFrame:
+        # Use list(...) instead of .tolist() to ensure that
+        # the content is still numpy array so that the type could be preserved.
+        # But that would not serializable and cannot use as UDF input and output.
+        if ensure_serializable:
+            return pd.DataFrame({i: data_col.detach().to("cpu").numpy().tolist() for i, data_col in enumerate(data)})
+        return pd.DataFrame({i: list(data_col.detach().to("cpu").numpy()) for i, data_col in enumerate(data)})
+
+    @staticmethod
+    def convert_from_df(
+        df: pd.DataFrame, features: Optional[Sequence[BaseFeatureSpec]] = None
+    ) -> Sequence["torch.Tensor"]:
+        import torch
+
+        res = []
+        if features:
+            for feature in features:
+                if isinstance(feature, FeatureGroupSpec):
+                    raise NotImplementedError("FeatureGroupSpec is not supported.")
+                assert isinstance(feature, FeatureSpec), "Invalid feature kind."
+                res.append(torch.from_numpy(np.stack(df[feature.name].to_numpy()).astype(feature._dtype._numpy_type)))
+            return res
+        return [torch.from_numpy(np.stack(df[col].to_numpy())) for col in df]
+
+
+class _SeqOfTensorflowTensorHandler(_BaseDataHandler[Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]]]):
+    @staticmethod
+    def can_handle(
+        data: model_types.SupportedDataType,
+    ) -> TypeGuard[Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]]]:
+        if not isinstance(data, list):
+            return False
+        if len(data) == 0:
+            return False
+        if type_utils.LazyType("tensorflow.Tensor").isinstance(data[0]) or type_utils.LazyType(
+            "tensorflow.Variable"
+        ).isinstance(data[0]):
+            return all(
+                type_utils.LazyType("tensorflow.Tensor").isinstance(data_col)
+                or type_utils.LazyType("tensorflow.Variable").isinstance(data_col)
+                for data_col in data
+            )
+        return False
+
+    @staticmethod
+    def count(data: Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]]) -> int:
+        import tensorflow as tf
+
+        rows = []
+        for data_col in data:
+            shapes = data_col.shape.as_list()
+            if data_col.shape == tf.TensorShape(None) or (not shapes) or (shapes[0] is None):
+                # Unknown shape array
+                raise ValueError("Data Validation Error: Unknown shape data is found.")
+            # Make mypy happy
+            assert isinstance(shapes[0], int)
+
+            rows.append(shapes[0])
+
+        return min(rows)
+
+    @staticmethod
+    def truncate(
+        data: Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]]
+    ) -> Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]]:
+        return [
+            data_col[
+                : min(
+                    _SeqOfTensorflowTensorHandler.count(data), _SeqOfTensorflowTensorHandler.SIG_INFER_ROWS_COUNT_LIMIT
+                )
+            ]
+            for data_col in data
+        ]
+
+    @staticmethod
+    def validate(data: Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]]) -> None:
+        import tensorflow as tf
+
         for data_col in data:
+            if data_col.shape == tf.TensorShape(None) or any(dim is None for dim in data_col.shape.as_list()):
+                # Unknown shape array
+                raise ValueError("Data Validation Error: Unknown shape data is found.")
+
+            if data_col.shape == tf.TensorShape([0]):
+                # Empty array
+                raise ValueError("Data Validation Error: Empty data is found.")
+
+            if data_col.shape == tf.TensorShape([1]) or data_col.shape == tf.TensorShape([]):
+                # scalar
+                raise ValueError("Data Validation Error: Scalar data is found.")
+
+    @staticmethod
+    def infer_signature(
+        data: Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]], role: Literal["input", "output"]
+    ) -> Sequence[BaseFeatureSpec]:
+        feature_prefix = f"{_SeqOfTensorflowTensorHandler.FEATURE_PREFIX}_"
+        features: List[BaseFeatureSpec] = []
+        role_prefix = (
+            _SeqOfTensorflowTensorHandler.INPUT_PREFIX
+            if role == "input"
+            else _SeqOfTensorflowTensorHandler.OUTPUT_PREFIX
+        ) + "_"
+
+        for i, data_col in enumerate(data):
+            dtype = DataType.from_numpy_type(data_col.dtype.as_numpy_dtype)
+            ft_name = f"{role_prefix}{feature_prefix}{i}"
             if len(data_col.shape) == 1:
-
+                features.append(FeatureSpec(dtype=dtype, name=ft_name))
             else:
-
-
-        return
+                ft_shape = tuple(data_col.shape[1:])
+                features.append(FeatureSpec(dtype=dtype, name=ft_name, shape=ft_shape))
+        return features
+
+    @staticmethod
+    def convert_to_df(
+        data: Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]], ensure_serializable: bool = True
+    ) -> pd.DataFrame:
+        if ensure_serializable:
+            return pd.DataFrame({i: data_col.numpy().tolist() for i, data_col in enumerate(iterable=data)})
+        return pd.DataFrame({i: list(data_col.numpy()) for i, data_col in enumerate(iterable=data)})
+
+    @staticmethod
+    def convert_from_df(
+        df: pd.DataFrame, features: Optional[Sequence[BaseFeatureSpec]] = None
+    ) -> Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]]:
+        import tensorflow as tf
+
+        res = []
+        if features:
+            for feature in features:
+                if isinstance(feature, FeatureGroupSpec):
+                    raise NotImplementedError("FeatureGroupSpec is not supported.")
+                assert isinstance(feature, FeatureSpec), "Invalid feature kind."
+                res.append(
+                    tf.convert_to_tensor(np.stack(df[feature.name].to_numpy()).astype(feature._dtype._numpy_type))
+                )
+            return res
+        return [tf.convert_to_tensor(np.stack(df[col].to_numpy())) for col in df]
 
 
 class _ListOfBuiltinHandler(_BaseDataHandler[model_types._SupportedBuiltinsList]):
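
Together with the renamed `_SeqOfNumpyArrayHandler`, the two new handlers let signature inference run directly on lists of framework tensors. A hypothetical walk-through of the PyTorch path (private classes, `torch` assumed installed):

    import torch
    from snowflake.ml.model.model_signature import _SeqOfPyTorchTensorHandler

    data = [torch.rand(4, 3), torch.randint(0, 10, (4,))]
    assert _SeqOfPyTorchTensorHandler.can_handle(data)
    for ft in _SeqOfPyTorchTensorHandler.infer_signature(data, role="input"):
        print(ft)  # e.g. a FLOAT feature with shape (3,) and a 1-d INT64 feature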
@@ -684,7 +954,10 @@ class _ListOfBuiltinHandler(_BaseDataHandler[model_types._SupportedBuiltinsList]
         return _PandasDataFrameHandler.infer_signature(pd.DataFrame(data), role)
 
     @staticmethod
-    def convert_to_df(
+    def convert_to_df(
+        data: model_types._SupportedBuiltinsList,
+        ensure_serializable: bool = True,
+    ) -> pd.DataFrame:
         return pd.DataFrame(data)
 
 
@@ -705,7 +978,12 @@ class _SnowparkDataFrameHandler(_BaseDataHandler[snowflake.snowpark.DataFrame]):
     def validate(data: snowflake.snowpark.DataFrame) -> None:
         schema = data.schema
         for field in schema.fields:
-
+            data_type = field.datatype
+            if isinstance(data_type, spt.ArrayType):
+                actual_data_type = data_type.element_type
+            else:
+                actual_data_type = data_type
+            if not any(type.is_same_snowpark_type(actual_data_type) for type in DataType):
                 raise ValueError(
                     f"Data Validation Error: Unsupported data type {field.datatype} in column {field.name}."
                 )
@@ -718,19 +996,91 @@ class _SnowparkDataFrameHandler(_BaseDataHandler[snowflake.snowpark.DataFrame]):
         schema = data.schema
         for field in schema.fields:
             name = identifier.get_unescaped_names(field.name)
-
+            if isinstance(field.datatype, spt.ArrayType):
+                raise NotImplementedError("Cannot infer model signature from Snowpark DataFrame with Array Type.")
+            else:
+                features.append(FeatureSpec(name=name, dtype=DataType.from_snowpark_type(field.datatype)))
         return features
 
     @staticmethod
-    def convert_to_df(
-
+    def convert_to_df(
+        data: snowflake.snowpark.DataFrame,
+        ensure_serializable: bool = True,
+        features: Optional[Sequence[BaseFeatureSpec]] = None,
+    ) -> pd.DataFrame:
+        # This method do things on top of to_pandas, to make sure the local dataframe got is in correct shape.
+        dtype_map = {}
+        if features:
+            for feature in features:
+                if isinstance(feature, FeatureGroupSpec):
+                    raise NotImplementedError("FeatureGroupSpec is not supported.")
+                assert isinstance(feature, FeatureSpec), "Invalid feature kind."
+                dtype_map[feature.name] = feature.as_dtype()
+        df_local = data.to_pandas()
+        # This is because Array will become string (Even though the correct schema is set)
+        # and object will become variant type and requires an additional loads
+        # to get correct data otherwise it would be string.
+        for field in data.schema.fields:
+            if isinstance(field.datatype, spt.ArrayType):
+                df_local[identifier.get_unescaped_names(field.name)] = df_local[
+                    identifier.get_unescaped_names(field.name)
+                ].map(json.loads)
+        # Only when the feature is not from inference, we are confident to do the type casting.
+        # Otherwise, dtype_map will be empty
+        df_local = df_local.astype(dtype=dtype_map)
+        return df_local
+
+    @staticmethod
+    def convert_from_df(
+        session: snowflake.snowpark.Session, df: pd.DataFrame, keep_order: bool = True
+    ) -> snowflake.snowpark.DataFrame:
+        # This method is necessary to create the Snowpark Dataframe in correct schema.
+        # Snowpark ignore the schema argument when providing a pandas DataFrame.
+        # However, in this case, if a cell of the original Dataframe is some array type,
+        # they will be inferred as VARIANT.
+        # To make sure Snowpark get the correct schema, we have to provide in a list of records.
+        # However, in this case, the order could not be preserved. Thus, a _ID column has to be added,
+        # if keep_order is True.
+        # Although in this case, the column with array type can get correct ARRAY type, however, the element
+        # type is not preserved, and will become string type. This affect the implementation of convert_from_df.
+        df = _PandasDataFrameHandler.convert_to_df(df)
+        df_cols = df.columns
+        if df_cols.dtype != np.object_:
+            raise ValueError("Cannot convert a Pandas DataFrame whose column index is not a string")
+        features = _PandasDataFrameHandler.infer_signature(df, role="input")
+        # Role will be no effect on the column index. That is to say, the feature name is the actual column name.
+        schema_list = []
+        for feature in features:
+            if isinstance(feature, FeatureGroupSpec):
+                raise NotImplementedError("FeatureGroupSpec is not supported.")
+            assert isinstance(feature, FeatureSpec), "Invalid feature kind."
+            schema_list.append(
+                spt.StructField(
+                    identifier.get_inferred_name(feature.name),
+                    feature.as_snowpark_type(),
+                    nullable=df[feature.name].isnull().any(),
+                )
+            )
+
+        data = df.rename(columns=identifier.get_inferred_name).to_dict("records")
+        if keep_order:
+            for idx, data_item in enumerate(data):
+                data_item[infer_template._KEEP_ORDER_COL_NAME] = idx
+            schema_list.append(spt.StructField(infer_template._KEEP_ORDER_COL_NAME, spt.LongType(), nullable=False))
+        sp_df = session.create_dataframe(
+            data,  # To make sure the schema can be used, otherwise, array will become variant.
+            spt.StructType(schema_list),
+        )
+        return sp_df
 
 
 _LOCAL_DATA_HANDLERS: List[Type[_BaseDataHandler[Any]]] = [
     _PandasDataFrameHandler,
     _NumpyArrayHandler,
-    _ListOfNumpyArrayHandler,
     _ListOfBuiltinHandler,
+    _SeqOfNumpyArrayHandler,
+    _SeqOfPyTorchTensorHandler,
+    _SeqOfTensorflowTensorHandler,
 ]
 _ALL_DATA_HANDLERS = _LOCAL_DATA_HANDLERS + [_SnowparkDataFrameHandler]
 
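
The new `convert_from_df` builds the Snowpark DataFrame from a list of records plus an explicit `StructType`, so ARRAY columns keep their type, and optionally appends a row-index column so order can be restored after inference. A sketch (requires a live `snowflake.snowpark.Session`, represented here by a placeholder):

    import pandas as pd
    import snowflake.snowpark
    from snowflake.ml.model.model_signature import _SnowparkDataFrameHandler

    session: snowflake.snowpark.Session = ...  # an established session (placeholder)
    df = pd.DataFrame({"x": [1, 2], "emb": [[0.1, 0.2], [0.3, 0.4]]})
    sp_df = _SnowparkDataFrameHandler.convert_from_df(session, df, keep_order=True)
    # sp_df carries an extra order column named by infer_template._KEEP_ORDER_COL_NAME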
@@ -1007,22 +1357,36 @@ def _validate_snowpark_data(data: snowflake.snowpark.DataFrame, features: Sequen
                 raise NotImplementedError("FeatureGroupSpec is not supported.")
                 assert isinstance(feature, FeatureSpec), "Invalid feature kind."
                 ft_type = feature._dtype
-
-
-
-
+                field_data_type = field.datatype
+                if isinstance(field_data_type, spt.ArrayType):
+                    if feature._shape is None:
+                        raise ValueError(
+                            f"Data Validation Error in feature {ft_name}: "
+                            + f"Feature is a array feature, while {field.name} is not."
+                        )
+                    warnings.warn(
+                        f"Warn in feature {ft_name}: Feature is a array feature," + " type validation cannot happen.",
+                        category=RuntimeWarning,
                     )
+                else:
+                    if feature._shape:
+                        raise ValueError(
+                            f"Data Validation Error in feature {ft_name}: "
+                            + f"Feature is a scalar feature, while {field.name} is not."
+                        )
+                    if not ft_type.is_same_snowpark_type(field_data_type):
+                        raise ValueError(
+                            f"Data Validation Error in feature {ft_name}: "
+                            + f"Feature type {ft_type} is not met by column {field.name}."
+                        )
         if not found:
             raise ValueError(f"Data Validation Error: feature {ft_name} does not exist in data.")
 
 
-def
-data
-) -> pd.DataFrame:
-    """Validate the data with features in model signature and convert to DataFrame
+def _convert_local_data_to_df(data: model_types.SupportedLocalDataType) -> pd.DataFrame:
+    """Convert local data to pandas DataFrame or Snowpark DataFrame
 
     Args:
-        features: A list of feature specs that the data should follow.
         data: The provided data.
 
     Raises:
@@ -1035,13 +1399,29 @@ def _convert_and_validate_local_data(
     for handler in _LOCAL_DATA_HANDLERS:
         if handler.can_handle(data):
             handler.validate(data)
-            df = handler.convert_to_df(data)
+            df = handler.convert_to_df(data, ensure_serializable=False)
             break
     if df is None:
         raise ValueError(f"Data Validation Error: Un-supported type {type(data)} provided.")
-
+    return df
+
+
+def _convert_and_validate_local_data(
+    data: model_types.SupportedLocalDataType, features: Sequence[BaseFeatureSpec]
+) -> pd.DataFrame:
+    """Validate the data with features in model signature and convert to DataFrame
+
+    Args:
+        features: A list of feature specs that the data should follow.
+        data: The provided data.
+
+    Returns:
+        The converted dataframe with renamed column index.
+    """
+    df = _convert_local_data_to_df(data)
     df = _rename_pandas_df(df, features)
     _validate_pandas_df(df, features)
+    df = _PandasDataFrameHandler.convert_to_df(df, ensure_serializable=True)
 
     return df
 
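
Net effect of this last refactor: conversion is factored out into `_convert_local_data_to_df`, and the serializable pass happens once, after renaming and validation. An end-to-end sketch using the private helpers shown above (illustration only):

    import numpy as np
    from snowflake.ml.model import model_signature as ms

    data = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]
    features = ms._SeqOfNumpyArrayHandler.infer_signature(data, role="input")
    df = ms._convert_and_validate_local_data(data, features)
    print(df.columns.tolist())  # columns renamed to the inferred feature names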