snowflake-ml-python 1.6.4__py3-none-any.whl → 1.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -0
- snowflake/cortex/_complete.py +107 -64
- snowflake/cortex/_finetune.py +273 -0
- snowflake/cortex/_sse_client.py +91 -28
- snowflake/cortex/_util.py +30 -1
- snowflake/ml/_internal/telemetry.py +4 -2
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/import_utils.py +31 -0
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +13 -0
- snowflake/ml/data/__init__.py +5 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +8 -0
- snowflake/ml/data/data_connector.py +1 -1
- snowflake/ml/data/torch_utils.py +33 -14
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +5 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +7 -5
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +4 -2
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +3 -1
- snowflake/ml/feature_store/examples/example_helper.py +6 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +4 -2
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +4 -2
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +3 -1
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +3 -1
- snowflake/ml/feature_store/feature_store.py +1 -2
- snowflake/ml/feature_store/feature_view.py +5 -1
- snowflake/ml/model/_client/model/model_version_impl.py +145 -11
- snowflake/ml/model/_client/ops/model_ops.py +56 -16
- snowflake/ml/model/_client/ops/service_ops.py +46 -30
- snowflake/ml/model/_client/service/model_deployment_spec.py +19 -8
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -1
- snowflake/ml/model/_client/sql/service.py +25 -1
- snowflake/ml/model/_model_composer/model_composer.py +2 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +1 -1
- snowflake/ml/model/_packager/model_env/model_env.py +12 -0
- snowflake/ml/model/_packager/model_handlers/_utils.py +6 -2
- snowflake/ml/model/_packager/model_handlers/catboost.py +4 -7
- snowflake/ml/model/_packager/model_handlers/custom.py +5 -1
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +10 -1
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -7
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -1
- snowflake/ml/model/_packager/model_handlers/sklearn.py +51 -7
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +8 -66
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +23 -6
- snowflake/ml/model/_packager/model_handlers/torchscript.py +14 -14
- snowflake/ml/model/_packager/model_handlers/xgboost.py +10 -40
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +2 -3
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +5 -0
- snowflake/ml/model/_packager/model_packager.py +0 -11
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -10
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -9
- snowflake/ml/model/_packager/{model_handlers/model_objective_utils.py → model_task/model_task_utils.py} +14 -26
- snowflake/ml/model/_signatures/core.py +63 -16
- snowflake/ml/model/_signatures/pandas_handler.py +87 -27
- snowflake/ml/model/_signatures/pytorch_handler.py +2 -2
- snowflake/ml/model/_signatures/snowpark_handler.py +2 -1
- snowflake/ml/model/_signatures/tensorflow_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +4 -0
- snowflake/ml/model/custom_model.py +47 -7
- snowflake/ml/model/model_signature.py +40 -9
- snowflake/ml/model/type_hints.py +9 -1
- snowflake/ml/modeling/_internal/estimator_utils.py +13 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +7 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +16 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +8 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -8
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +17 -19
- snowflake/ml/modeling/cluster/dbscan.py +5 -2
- snowflake/ml/modeling/cluster/feature_agglomeration.py +7 -19
- snowflake/ml/modeling/cluster/k_means.py +14 -19
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +3 -3
- snowflake/ml/modeling/cluster/optics.py +6 -6
- snowflake/ml/modeling/cluster/spectral_clustering.py +4 -3
- snowflake/ml/modeling/compose/column_transformer.py +15 -5
- snowflake/ml/modeling/compose/transformed_target_regressor.py +7 -6
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +2 -2
- snowflake/ml/modeling/covariance/oas.py +1 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +2 -2
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +5 -12
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +5 -12
- snowflake/ml/modeling/decomposition/pca.py +28 -15
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -0
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -12
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -11
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -8
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -8
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +21 -2
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +18 -2
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +2 -0
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +2 -0
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +21 -8
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +21 -11
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +21 -2
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +18 -2
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +2 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +5 -3
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +2 -2
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +2 -4
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +2 -4
- snowflake/ml/modeling/linear_model/ard_regression.py +5 -10
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +5 -11
- snowflake/ml/modeling/linear_model/elastic_net.py +3 -0
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lars.py +0 -10
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -11
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +0 -10
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -11
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +0 -10
- snowflake/ml/modeling/linear_model/logistic_regression.py +28 -22
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +30 -24
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +4 -13
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +4 -4
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/perceptron.py +3 -3
- snowflake/ml/modeling/linear_model/ransac_regressor.py +3 -2
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +14 -6
- snowflake/ml/modeling/linear_model/ridge_cv.py +17 -11
- snowflake/ml/modeling/linear_model/sgd_classifier.py +2 -2
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +5 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +12 -3
- snowflake/ml/modeling/manifold/isomap.py +1 -1
- snowflake/ml/modeling/manifold/mds.py +3 -3
- snowflake/ml/modeling/manifold/tsne.py +10 -4
- snowflake/ml/modeling/metrics/classification.py +12 -16
- snowflake/ml/modeling/metrics/ranking.py +3 -3
- snowflake/ml/modeling/metrics/regression.py +3 -3
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +3 -3
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +3 -3
- snowflake/ml/modeling/naive_bayes/complement_nb.py +3 -3
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +3 -3
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +10 -4
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +5 -2
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +2 -2
- snowflake/ml/modeling/neighbors/nearest_centroid.py +7 -14
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +7 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +3 -0
- snowflake/ml/modeling/pipeline/pipeline.py +16 -14
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +8 -4
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -7
- snowflake/ml/modeling/svm/linear_svc.py +25 -16
- snowflake/ml/modeling/svm/linear_svr.py +23 -17
- snowflake/ml/modeling/svm/nu_svc.py +5 -3
- snowflake/ml/modeling/svm/nu_svr.py +3 -1
- snowflake/ml/modeling/svm/svc.py +9 -5
- snowflake/ml/modeling/svm/svr.py +3 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +21 -2
- snowflake/ml/modeling/tree/decision_tree_regressor.py +18 -2
- snowflake/ml/modeling/tree/extra_tree_classifier.py +28 -9
- snowflake/ml/modeling/tree/extra_tree_regressor.py +18 -2
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +448 -0
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +238 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +10 -10
- snowflake/ml/monitoring/model_monitor.py +37 -0
- snowflake/ml/registry/_manager/model_manager.py +15 -1
- snowflake/ml/registry/registry.py +32 -37
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/METADATA +104 -12
- {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/RECORD +172 -171
- {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/WHEEL +1 -1
- snowflake/ml/monitoring/_client/model_monitor.py +0 -126
- snowflake/ml/monitoring/_client/model_monitor_manager.py +0 -361
- snowflake/ml/monitoring/_client/monitor_sql_client.py +0 -1335
- snowflake/ml/monitoring/entities/model_monitor_interval.py +0 -46
- /snowflake/ml/monitoring/{_client/model_monitor_version.py → model_monitor_version.py} +0 -0
- {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,6 @@
|
|
1
1
|
# mypy: disable-error-code="import"
|
2
2
|
import os
|
3
3
|
import warnings
|
4
|
-
from importlib import metadata as importlib_metadata
|
5
4
|
from typing import (
|
6
5
|
TYPE_CHECKING,
|
7
6
|
Any,
|
@@ -16,23 +15,19 @@ from typing import (
|
|
16
15
|
|
17
16
|
import numpy as np
|
18
17
|
import pandas as pd
|
19
|
-
from packaging import version
|
20
18
|
from typing_extensions import TypeGuard, Unpack
|
21
19
|
|
22
20
|
from snowflake.ml._internal import type_utils
|
23
21
|
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
24
22
|
from snowflake.ml.model._packager.model_env import model_env
|
25
|
-
from snowflake.ml.model._packager.model_handlers import
|
26
|
-
_base,
|
27
|
-
_utils as handlers_utils,
|
28
|
-
model_objective_utils,
|
29
|
-
)
|
23
|
+
from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
|
30
24
|
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
31
25
|
from snowflake.ml.model._packager.model_meta import (
|
32
26
|
model_blob_meta,
|
33
27
|
model_meta as model_meta_api,
|
34
28
|
model_meta_schema,
|
35
29
|
)
|
30
|
+
from snowflake.ml.model._packager.model_task import model_task_utils
|
36
31
|
from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
|
37
32
|
|
38
33
|
if TYPE_CHECKING:
|
@@ -94,23 +89,6 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
94
89
|
|
95
90
|
assert isinstance(model, xgboost.Booster) or isinstance(model, xgboost.XGBModel)
|
96
91
|
|
97
|
-
local_xgb_version = None
|
98
|
-
|
99
|
-
try:
|
100
|
-
local_dist = importlib_metadata.distribution("xgboost")
|
101
|
-
local_xgb_version = version.parse(local_dist.version)
|
102
|
-
except importlib_metadata.PackageNotFoundError:
|
103
|
-
pass
|
104
|
-
|
105
|
-
if local_xgb_version and local_xgb_version >= version.parse("2.1.0") and enable_explainability:
|
106
|
-
warnings.warn(
|
107
|
-
f"This version of xgboost {local_xgb_version} does not work with shap 0.42.1."
|
108
|
-
+ "If you want model explanations, lower the xgboost version to <2.1.0.",
|
109
|
-
category=UserWarning,
|
110
|
-
stacklevel=1,
|
111
|
-
)
|
112
|
-
enable_explainability = False
|
113
|
-
|
114
92
|
if not is_sub_model:
|
115
93
|
target_methods = handlers_utils.get_target_methods(
|
116
94
|
model=model,
|
@@ -139,7 +117,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
139
117
|
sample_input_data=sample_input_data,
|
140
118
|
get_prediction_fn=get_prediction,
|
141
119
|
)
|
142
|
-
model_task_and_output =
|
120
|
+
model_task_and_output = model_task_utils.get_model_task_and_output_type(model)
|
143
121
|
model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output.task)
|
144
122
|
if enable_explainability:
|
145
123
|
model_meta = handlers_utils.add_explain_method_signature(
|
@@ -187,23 +165,15 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
187
165
|
],
|
188
166
|
check_local_version=True,
|
189
167
|
)
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
)
|
197
|
-
else:
|
198
|
-
model_meta.env.include_if_absent(
|
199
|
-
[
|
200
|
-
model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
|
201
|
-
],
|
202
|
-
check_local_version=True,
|
203
|
-
)
|
168
|
+
model_meta.env.include_if_absent(
|
169
|
+
[
|
170
|
+
model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
|
171
|
+
],
|
172
|
+
check_local_version=True,
|
173
|
+
)
|
204
174
|
|
205
175
|
if enable_explainability:
|
206
|
-
model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
|
176
|
+
model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap>=0.46.0", pip_name="shap")])
|
207
177
|
model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
|
208
178
|
model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
|
209
179
|
|
@@ -1,3 +1,2 @@
|
|
1
|
-
REQUIREMENTS = [
|
2
|
-
|
3
|
-
]
|
1
|
+
REQUIREMENTS = ['cloudpickle>=2.0.0']
|
2
|
+
ALL_REQUIREMENTS=['cloudpickle>=2.0.0']
|
@@ -58,11 +58,16 @@ class XgboostModelBlobOptions(BaseModelBlobOptions):
|
|
58
58
|
xgb_estimator_type: Required[str]
|
59
59
|
|
60
60
|
|
61
|
+
class TensorflowModelBlobOptions(BaseModelBlobOptions):
|
62
|
+
is_keras_model: Required[bool]
|
63
|
+
|
64
|
+
|
61
65
|
ModelBlobOptions = Union[
|
62
66
|
BaseModelBlobOptions,
|
63
67
|
HuggingFacePipelineModelBlobOptions,
|
64
68
|
MLFlowModelBlobOptions,
|
65
69
|
XgboostModelBlobOptions,
|
70
|
+
TensorflowModelBlobOptions,
|
66
71
|
]
|
67
72
|
|
68
73
|
|
@@ -61,17 +61,6 @@ class ModelPackager:
|
|
61
61
|
if not options:
|
62
62
|
options = model_types.BaseModelSaveOption()
|
63
63
|
|
64
|
-
# here handling the case of enable_explainability is False/None
|
65
|
-
enable_explainability = options.get("enable_explainability", None)
|
66
|
-
if enable_explainability is False or enable_explainability is None:
|
67
|
-
if (signatures is not None) and (sample_input_data is not None):
|
68
|
-
raise snowml_exceptions.SnowflakeMLException(
|
69
|
-
error_code=error_codes.INVALID_ARGUMENT,
|
70
|
-
original_exception=ValueError(
|
71
|
-
"Signatures and sample_input_data both cannot be specified at the same time."
|
72
|
-
),
|
73
|
-
)
|
74
|
-
|
75
64
|
handler = model_handler.find_handler(model)
|
76
65
|
if handler is None:
|
77
66
|
raise snowml_exceptions.SnowflakeMLException(
|
@@ -1,10 +1,2 @@
|
|
1
|
-
REQUIREMENTS = [
|
2
|
-
|
3
|
-
"anyio>=3.5.0,<4",
|
4
|
-
"numpy>=1.23,<2",
|
5
|
-
"packaging>=20.9,<24",
|
6
|
-
"pandas>=1.0.0,<3",
|
7
|
-
"pyyaml>=6.0,<7",
|
8
|
-
"snowflake-snowpark-python>=1.17.0,<2",
|
9
|
-
"typing-extensions>=4.1.0,<5"
|
10
|
-
]
|
1
|
+
REQUIREMENTS = ['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<4', 'cachetools>=3.1.1,<6', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2022.11,<2024', 'importlib_resources>=6.1.1, <7', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pytimeparse>=1.1.8,<2', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2022.11,<2024', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2', 'sqlparse>=0.4,<1', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
|
2
|
+
ALL_REQUIREMENTS=['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<4', 'cachetools>=3.1.1,<6', 'catboost>=1.2.0, <2', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2022.11,<2024', 'importlib_resources>=6.1.1, <7', 'lightgbm>=4.1.0, <5', 'mlflow>=2.1.0,<2.4', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pytimeparse>=1.1.8,<2', 'pytorch>=2.0.1,<2.3.0', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2022.11,<2024', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'sentence-transformers>=2.2.2,<3', 'sentencepiece>=0.1.95,<1', 'shap>=0.46.0,<1', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2', 'sqlparse>=0.4,<1', 'tensorflow>=2.10,<3', 'tokenizers>=0.10,<1', 'torchdata>=0.4,<1', 'transformers>=4.32.1,<5', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
|
@@ -17,6 +17,8 @@ _SNOWML_INFERENCE_ALTERNATIVE_DEPENDENCIES = [
|
|
17
17
|
for r in _snowml_inference_alternative_requirements.REQUIREMENTS
|
18
18
|
]
|
19
19
|
|
20
|
+
PACKAGES_NOT_ALLOWED_IN_WAREHOUSE = ["snowflake-connector-python", "pyarrow"]
|
21
|
+
|
20
22
|
|
21
23
|
class ModelRuntime:
|
22
24
|
"""Class to represent runtime in a model, which controls the runtime and version, imports and dependencies.
|
@@ -61,15 +63,8 @@ class ModelRuntime:
|
|
61
63
|
],
|
62
64
|
)
|
63
65
|
|
64
|
-
if
|
65
|
-
self.runtime_env.
|
66
|
-
[
|
67
|
-
model_env.ModelDependency(
|
68
|
-
requirement="pyarrow",
|
69
|
-
pip_name="pyarrow",
|
70
|
-
)
|
71
|
-
],
|
72
|
-
)
|
66
|
+
if is_warehouse and self.embed_local_ml_library:
|
67
|
+
self.runtime_env.remove_if_present_conda(PACKAGES_NOT_ALLOWED_IN_WAREHOUSE)
|
73
68
|
|
74
69
|
if is_gpu:
|
75
70
|
self.runtime_env.generate_env_for_cuda()
|
@@ -84,7 +84,7 @@ def get_model_task_lightgbm(model: Union["lightgbm.Booster", "lightgbm.LGBMModel
|
|
84
84
|
if type_utils.LazyType("lightgbm.Booster").isinstance(model):
|
85
85
|
model_task = model.params["objective"] # type: ignore[attr-defined]
|
86
86
|
elif hasattr(model, "objective_"):
|
87
|
-
model_task = model.objective_
|
87
|
+
model_task = model.objective_ # type: ignore[assignment]
|
88
88
|
if model_task in _BINARY_CLASSIFICATION_OBJECTIVES:
|
89
89
|
return type_hints.Task.TABULAR_BINARY_CLASSIFICATION
|
90
90
|
if model_task in _MULTI_CLASSIFICATION_OBJECTIVES:
|
@@ -128,42 +128,30 @@ def get_model_task_xgb(model: Union["xgboost.Booster", "xgboost.XGBModel"]) -> t
|
|
128
128
|
return type_hints.Task.UNKNOWN
|
129
129
|
|
130
130
|
|
131
|
-
def
|
131
|
+
def _get_model_task(model: Any) -> type_hints.Task:
|
132
132
|
if type_utils.LazyType("xgboost.Booster").isinstance(model) or type_utils.LazyType("xgboost.XGBModel").isinstance(
|
133
133
|
model
|
134
134
|
):
|
135
|
-
|
136
|
-
output_type = model_signature.DataType.DOUBLE
|
137
|
-
if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION:
|
138
|
-
output_type = model_signature.DataType.STRING
|
139
|
-
return ModelTaskAndOutputType(task=task, output_type=output_type)
|
135
|
+
return get_model_task_xgb(model)
|
140
136
|
|
141
137
|
if type_utils.LazyType("lightgbm.Booster").isinstance(model) or type_utils.LazyType(
|
142
138
|
"lightgbm.LGBMModel"
|
143
139
|
).isinstance(model):
|
144
|
-
|
145
|
-
output_type = model_signature.DataType.DOUBLE
|
146
|
-
if task in [
|
147
|
-
type_hints.Task.TABULAR_BINARY_CLASSIFICATION,
|
148
|
-
type_hints.Task.TABULAR_MULTI_CLASSIFICATION,
|
149
|
-
]:
|
150
|
-
output_type = model_signature.DataType.STRING
|
151
|
-
return ModelTaskAndOutputType(task=task, output_type=output_type)
|
140
|
+
return get_model_task_lightgbm(model)
|
152
141
|
|
153
142
|
if type_utils.LazyType("catboost.CatBoost").isinstance(model):
|
154
|
-
|
155
|
-
output_type = model_signature.DataType.DOUBLE
|
156
|
-
if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION:
|
157
|
-
output_type = model_signature.DataType.STRING
|
158
|
-
return ModelTaskAndOutputType(task=task, output_type=output_type)
|
143
|
+
return get_model_task_catboost(model)
|
159
144
|
|
160
145
|
if type_utils.LazyType("sklearn.base.BaseEstimator").isinstance(model) or type_utils.LazyType(
|
161
146
|
"sklearn.pipeline.Pipeline"
|
162
147
|
).isinstance(model):
|
163
|
-
|
164
|
-
output_type = model_signature.DataType.DOUBLE
|
165
|
-
if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION:
|
166
|
-
output_type = model_signature.DataType.STRING
|
167
|
-
return ModelTaskAndOutputType(task=task, output_type=output_type)
|
168
|
-
|
148
|
+
return get_task_skl(model)
|
169
149
|
raise ValueError(f"Model type {type(model)} is not supported")
|
150
|
+
|
151
|
+
|
152
|
+
def get_model_task_and_output_type(model: Any) -> ModelTaskAndOutputType:
|
153
|
+
task = _get_model_task(model)
|
154
|
+
output_type = model_signature.DataType.DOUBLE
|
155
|
+
if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION:
|
156
|
+
output_type = model_signature.DataType.STRING
|
157
|
+
return ModelTaskAndOutputType(task=task, output_type=output_type)
|
@@ -14,10 +14,12 @@ from typing import (
|
|
14
14
|
Type,
|
15
15
|
Union,
|
16
16
|
final,
|
17
|
+
get_args,
|
17
18
|
)
|
18
19
|
|
19
20
|
import numpy as np
|
20
21
|
import numpy.typing as npt
|
22
|
+
import pandas as pd
|
21
23
|
|
22
24
|
import snowflake.snowpark.types as spt
|
23
25
|
from snowflake.ml._internal.exceptions import (
|
@@ -29,6 +31,21 @@ if TYPE_CHECKING:
|
|
29
31
|
import mlflow
|
30
32
|
import torch
|
31
33
|
|
34
|
+
PandasExtensionTypes = Union[
|
35
|
+
pd.Int8Dtype,
|
36
|
+
pd.Int16Dtype,
|
37
|
+
pd.Int32Dtype,
|
38
|
+
pd.Int64Dtype,
|
39
|
+
pd.UInt8Dtype,
|
40
|
+
pd.UInt16Dtype,
|
41
|
+
pd.UInt32Dtype,
|
42
|
+
pd.UInt64Dtype,
|
43
|
+
pd.Float32Dtype,
|
44
|
+
pd.Float64Dtype,
|
45
|
+
pd.BooleanDtype,
|
46
|
+
pd.StringDtype,
|
47
|
+
]
|
48
|
+
|
32
49
|
|
33
50
|
class DataType(Enum):
|
34
51
|
def __init__(self, value: str, snowpark_type: Type[spt.DataType], numpy_type: npt.DTypeLike) -> None:
|
@@ -67,11 +84,11 @@ class DataType(Enum):
|
|
67
84
|
return f"DataType.{self.name}"
|
68
85
|
|
69
86
|
@classmethod
|
70
|
-
def from_numpy_type(cls,
|
87
|
+
def from_numpy_type(cls, input_type: Union[npt.DTypeLike, PandasExtensionTypes]) -> "DataType":
|
71
88
|
"""Translate numpy dtype to DataType for signature definition.
|
72
89
|
|
73
90
|
Args:
|
74
|
-
|
91
|
+
input_type: The numpy dtype or Pandas Extension Dtype
|
75
92
|
|
76
93
|
Raises:
|
77
94
|
SnowflakeMLException: NotImplementedError: Raised when the given numpy type is not supported.
|
@@ -79,6 +96,10 @@ class DataType(Enum):
|
|
79
96
|
Returns:
|
80
97
|
Corresponding DataType.
|
81
98
|
"""
|
99
|
+
# To support pandas extension dtype
|
100
|
+
if isinstance(input_type, get_args(PandasExtensionTypes)):
|
101
|
+
input_type = input_type.type
|
102
|
+
|
82
103
|
np_to_snowml_type_mapping = {i._numpy_type: i for i in DataType}
|
83
104
|
|
84
105
|
# Add datetime types:
|
@@ -88,12 +109,12 @@ class DataType(Enum):
|
|
88
109
|
np_to_snowml_type_mapping[f"datetime64[{res}]"] = DataType.TIMESTAMP_NTZ
|
89
110
|
|
90
111
|
for potential_type in np_to_snowml_type_mapping.keys():
|
91
|
-
if np.can_cast(
|
112
|
+
if np.can_cast(input_type, potential_type, casting="no"):
|
92
113
|
# This is used since the same dtype might represented in different ways.
|
93
114
|
return np_to_snowml_type_mapping[potential_type]
|
94
115
|
raise snowml_exceptions.SnowflakeMLException(
|
95
116
|
error_code=error_codes.NOT_IMPLEMENTED,
|
96
|
-
original_exception=NotImplementedError(f"Type {
|
117
|
+
original_exception=NotImplementedError(f"Type {input_type} is not supported as a DataType."),
|
97
118
|
)
|
98
119
|
|
99
120
|
@classmethod
|
@@ -212,6 +233,7 @@ class FeatureSpec(BaseFeatureSpec):
|
|
212
233
|
name: str,
|
213
234
|
dtype: DataType,
|
214
235
|
shape: Optional[Tuple[int, ...]] = None,
|
236
|
+
nullable: bool = True,
|
215
237
|
) -> None:
|
216
238
|
"""
|
217
239
|
Initialize a feature.
|
@@ -219,6 +241,7 @@ class FeatureSpec(BaseFeatureSpec):
|
|
219
241
|
Args:
|
220
242
|
name: Name of the feature.
|
221
243
|
dtype: Type of the elements in the feature.
|
244
|
+
nullable: Whether the feature is nullable. Defaults to True.
|
222
245
|
shape: Used to represent scalar feature, 1-d feature list,
|
223
246
|
or n-d tensor. Use -1 to represent variable length. Defaults to None.
|
224
247
|
|
@@ -227,6 +250,7 @@ class FeatureSpec(BaseFeatureSpec):
|
|
227
250
|
- (2,): 1d list with a fixed length of 2.
|
228
251
|
- (-1,): 1d list with variable length, used for ragged tensor representation.
|
229
252
|
- (d1, d2, d3): 3d tensor.
|
253
|
+
nullable: Whether the feature is nullable. Defaults to True.
|
230
254
|
|
231
255
|
Raises:
|
232
256
|
SnowflakeMLException: TypeError: When the dtype input type is incorrect.
|
@@ -248,6 +272,8 @@ class FeatureSpec(BaseFeatureSpec):
|
|
248
272
|
)
|
249
273
|
self._shape = shape
|
250
274
|
|
275
|
+
self._nullable = nullable
|
276
|
+
|
251
277
|
def as_snowpark_type(self) -> spt.DataType:
|
252
278
|
result_type = self._dtype.as_snowpark_type()
|
253
279
|
if not self._shape:
|
@@ -256,13 +282,34 @@ class FeatureSpec(BaseFeatureSpec):
|
|
256
282
|
result_type = spt.ArrayType(result_type)
|
257
283
|
return result_type
|
258
284
|
|
259
|
-
def as_dtype(self) -> Union[npt.DTypeLike, str]:
|
285
|
+
def as_dtype(self) -> Union[npt.DTypeLike, str, PandasExtensionTypes]:
|
260
286
|
"""Convert to corresponding local Type."""
|
287
|
+
|
261
288
|
if not self._shape:
|
262
289
|
# scalar dtype: use keys from `np.sctypeDict` to prevent unit-less dtype 'datetime64'
|
263
290
|
if "datetime64" in self._dtype._value:
|
264
291
|
return self._dtype._value
|
265
|
-
|
292
|
+
|
293
|
+
np_type = self._dtype._numpy_type
|
294
|
+
if self._nullable:
|
295
|
+
np_to_pd_dtype_mapping = {
|
296
|
+
np.int8: pd.Int8Dtype(),
|
297
|
+
np.int16: pd.Int16Dtype(),
|
298
|
+
np.int32: pd.Int32Dtype(),
|
299
|
+
np.int64: pd.Int64Dtype(),
|
300
|
+
np.uint8: pd.UInt8Dtype(),
|
301
|
+
np.uint16: pd.UInt16Dtype(),
|
302
|
+
np.uint32: pd.UInt32Dtype(),
|
303
|
+
np.uint64: pd.UInt64Dtype(),
|
304
|
+
np.float32: pd.Float32Dtype(),
|
305
|
+
np.float64: pd.Float64Dtype(),
|
306
|
+
np.bool_: pd.BooleanDtype(),
|
307
|
+
np.str_: pd.StringDtype(),
|
308
|
+
}
|
309
|
+
|
310
|
+
return np_to_pd_dtype_mapping.get(np_type, np_type) # type: ignore[arg-type]
|
311
|
+
|
312
|
+
return np_type
|
266
313
|
return np.object_
|
267
314
|
|
268
315
|
def __eq__(self, other: object) -> bool:
|
@@ -273,7 +320,10 @@ class FeatureSpec(BaseFeatureSpec):
|
|
273
320
|
|
274
321
|
def __repr__(self) -> str:
|
275
322
|
shape_str = f", shape={repr(self._shape)}" if self._shape else ""
|
276
|
-
return
|
323
|
+
return (
|
324
|
+
f"FeatureSpec(dtype={repr(self._dtype)}, "
|
325
|
+
f"name={repr(self._name)}{shape_str}, nullable={repr(self._nullable)})"
|
326
|
+
)
|
277
327
|
|
278
328
|
def to_dict(self) -> Dict[str, Any]:
|
279
329
|
"""Serialize the feature group into a dict.
|
@@ -281,10 +331,7 @@ class FeatureSpec(BaseFeatureSpec):
|
|
281
331
|
Returns:
|
282
332
|
A dict that serializes the feature group.
|
283
333
|
"""
|
284
|
-
base_dict: Dict[str, Any] = {
|
285
|
-
"type": self._dtype.name,
|
286
|
-
"name": self._name,
|
287
|
-
}
|
334
|
+
base_dict: Dict[str, Any] = {"type": self._dtype.name, "name": self._name, "nullable": self._nullable}
|
288
335
|
if self._shape is not None:
|
289
336
|
base_dict["shape"] = self._shape
|
290
337
|
return base_dict
|
@@ -304,7 +351,9 @@ class FeatureSpec(BaseFeatureSpec):
|
|
304
351
|
if shape:
|
305
352
|
shape = tuple(shape)
|
306
353
|
type = DataType[input_dict["type"]]
|
307
|
-
|
354
|
+
# If nullable is not provided, default to False for backward compatibility.
|
355
|
+
nullable = input_dict.get("nullable", False)
|
356
|
+
return FeatureSpec(name=name, dtype=type, shape=shape, nullable=nullable)
|
308
357
|
|
309
358
|
@classmethod
|
310
359
|
def from_mlflow_spec(
|
@@ -475,10 +524,8 @@ class ModelSignature:
|
|
475
524
|
sig_outs = loaded["outputs"]
|
476
525
|
sig_inputs = loaded["inputs"]
|
477
526
|
|
478
|
-
deserialize_spec: Callable[[Dict[str, Any]], BaseFeatureSpec] = (
|
479
|
-
|
480
|
-
if "feature_group" in sig_spec
|
481
|
-
else FeatureSpec.from_dict(sig_spec)
|
527
|
+
deserialize_spec: Callable[[Dict[str, Any]], BaseFeatureSpec] = lambda sig_spec: (
|
528
|
+
FeatureGroupSpec.from_dict(sig_spec) if "feature_group" in sig_spec else FeatureSpec.from_dict(sig_spec)
|
482
529
|
)
|
483
530
|
|
484
531
|
return ModelSignature(
|
@@ -1,4 +1,5 @@
|
|
1
|
-
|
1
|
+
import warnings
|
2
|
+
from typing import Literal, Sequence, Union
|
2
3
|
|
3
4
|
import numpy as np
|
4
5
|
import pandas as pd
|
@@ -14,8 +15,8 @@ from snowflake.ml.model._signatures import base_handler, core, utils
|
|
14
15
|
|
15
16
|
class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
|
16
17
|
@staticmethod
|
17
|
-
def can_handle(data: model_types.SupportedDataType) -> TypeGuard[pd.DataFrame]:
|
18
|
-
return isinstance(data, pd.DataFrame)
|
18
|
+
def can_handle(data: model_types.SupportedDataType) -> TypeGuard[Union[pd.DataFrame, pd.Series]]:
|
19
|
+
return isinstance(data, pd.DataFrame) or isinstance(data, pd.Series)
|
19
20
|
|
20
21
|
@staticmethod
|
21
22
|
def count(data: pd.DataFrame) -> int:
|
@@ -26,7 +27,17 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
|
|
26
27
|
return data.head(min(PandasDataFrameHandler.count(data), PandasDataFrameHandler.SIG_INFER_ROWS_COUNT_LIMIT))
|
27
28
|
|
28
29
|
@staticmethod
|
29
|
-
def validate(data: pd.DataFrame) -> None:
|
30
|
+
def validate(data: Union[pd.DataFrame, pd.Series]) -> None:
|
31
|
+
if isinstance(data, pd.Series):
|
32
|
+
# check if the series is empty and throw error
|
33
|
+
if data.empty:
|
34
|
+
raise snowml_exceptions.SnowflakeMLException(
|
35
|
+
error_code=error_codes.INVALID_DATA,
|
36
|
+
original_exception=ValueError("Data Validation Error: Empty data is found."),
|
37
|
+
)
|
38
|
+
# convert the series to a dataframe
|
39
|
+
data = data.to_frame()
|
40
|
+
|
30
41
|
df_cols = data.columns
|
31
42
|
|
32
43
|
if df_cols.has_duplicates: # Rule out categorical index with duplicates
|
@@ -60,21 +71,44 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
|
|
60
71
|
|
61
72
|
df_col_dtypes = [data[col].dtype for col in data.columns]
|
62
73
|
for df_col, df_col_dtype in zip(df_cols, df_col_dtypes):
|
74
|
+
df_col_data = data[df_col]
|
75
|
+
if df_col_data.isnull().all():
|
76
|
+
raise snowml_exceptions.SnowflakeMLException(
|
77
|
+
error_code=error_codes.INVALID_DATA,
|
78
|
+
original_exception=ValueError(
|
79
|
+
f"Data Validation Error: There is no non-null data in column {df_col}."
|
80
|
+
),
|
81
|
+
)
|
82
|
+
if df_col_data.isnull().any():
|
83
|
+
warnings.warn(
|
84
|
+
(
|
85
|
+
f"Null value detected in column {df_col}, model signature inference might not accurate, "
|
86
|
+
"or your prediction might fail if your model does not support null input. If this is not "
|
87
|
+
"expected, please check your input dataframe."
|
88
|
+
),
|
89
|
+
category=UserWarning,
|
90
|
+
stacklevel=2,
|
91
|
+
)
|
92
|
+
|
93
|
+
df_col_data = utils.series_dropna(df_col_data)
|
94
|
+
df_col_dtype = df_col_data.dtype
|
95
|
+
|
63
96
|
if df_col_dtype == np.dtype("O"):
|
64
97
|
# Check if all objects have the same type
|
65
|
-
if not all(isinstance(data_row, type(
|
98
|
+
if not all(isinstance(data_row, type(df_col_data.iloc[0])) for data_row in df_col_data):
|
66
99
|
raise snowml_exceptions.SnowflakeMLException(
|
67
100
|
error_code=error_codes.INVALID_DATA,
|
68
101
|
original_exception=ValueError(
|
69
|
-
|
102
|
+
"Data Validation Error: "
|
103
|
+
+ f"Inconsistent type of element in object found in column data {df_col_data}."
|
70
104
|
),
|
71
105
|
)
|
72
106
|
|
73
|
-
if isinstance(
|
74
|
-
arr = utils.convert_list_to_ndarray(
|
107
|
+
if isinstance(df_col_data.iloc[0], list):
|
108
|
+
arr = utils.convert_list_to_ndarray(df_col_data.iloc[0])
|
75
109
|
arr_dtype = core.DataType.from_numpy_type(arr.dtype)
|
76
110
|
|
77
|
-
converted_data_list = [utils.convert_list_to_ndarray(data_row) for data_row in
|
111
|
+
converted_data_list = [utils.convert_list_to_ndarray(data_row) for data_row in df_col_data]
|
78
112
|
|
79
113
|
if not all(
|
80
114
|
core.DataType.from_numpy_type(converted_data.dtype) == arr_dtype
|
@@ -84,32 +118,37 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
|
|
84
118
|
error_code=error_codes.INVALID_DATA,
|
85
119
|
original_exception=ValueError(
|
86
120
|
"Data Validation Error: "
|
87
|
-
+ f"Inconsistent type of element in object found in column data {
|
121
|
+
+ f"Inconsistent type of element in object found in column data {df_col_data}."
|
88
122
|
),
|
89
123
|
)
|
90
124
|
|
91
|
-
elif isinstance(
|
92
|
-
arr_dtype = core.DataType.from_numpy_type(
|
125
|
+
elif isinstance(df_col_data.iloc[0], np.ndarray):
|
126
|
+
arr_dtype = core.DataType.from_numpy_type(df_col_data.iloc[0].dtype)
|
93
127
|
|
94
|
-
if not all(core.DataType.from_numpy_type(data_row.dtype) == arr_dtype for data_row in
|
128
|
+
if not all(core.DataType.from_numpy_type(data_row.dtype) == arr_dtype for data_row in df_col_data):
|
95
129
|
raise snowml_exceptions.SnowflakeMLException(
|
96
130
|
error_code=error_codes.INVALID_DATA,
|
97
131
|
original_exception=ValueError(
|
98
132
|
"Data Validation Error: "
|
99
|
-
+ f"Inconsistent type of element in object found in column data {
|
133
|
+
+ f"Inconsistent type of element in object found in column data {df_col_data}."
|
100
134
|
),
|
101
135
|
)
|
102
|
-
elif not isinstance(
|
136
|
+
elif not isinstance(df_col_data.iloc[0], (str, bytes)):
|
103
137
|
raise snowml_exceptions.SnowflakeMLException(
|
104
138
|
error_code=error_codes.INVALID_DATA,
|
105
139
|
original_exception=ValueError(
|
106
|
-
f"Data Validation Error: Unsupported type confronted in {
|
140
|
+
f"Data Validation Error: Unsupported type confronted in {df_col_data}"
|
107
141
|
),
|
108
142
|
)
|
109
143
|
|
110
144
|
@staticmethod
|
111
|
-
def infer_signature(
|
145
|
+
def infer_signature(
|
146
|
+
data: Union[pd.DataFrame, pd.Series],
|
147
|
+
role: Literal["input", "output"],
|
148
|
+
) -> Sequence[core.BaseFeatureSpec]:
|
112
149
|
feature_prefix = f"{PandasDataFrameHandler.FEATURE_PREFIX}_"
|
150
|
+
if isinstance(data, pd.Series):
|
151
|
+
data = data.to_frame()
|
113
152
|
df_cols = data.columns
|
114
153
|
role_prefix = (
|
115
154
|
PandasDataFrameHandler.INPUT_PREFIX if role == "input" else PandasDataFrameHandler.OUTPUT_PREFIX
|
@@ -123,30 +162,51 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
|
|
123
162
|
|
124
163
|
specs = []
|
125
164
|
for df_col, df_col_dtype, ft_name in zip(df_cols, df_col_dtypes, ft_names):
|
165
|
+
df_col_data = data[df_col]
|
166
|
+
if df_col_data.isnull().any():
|
167
|
+
df_col_data = utils.series_dropna(df_col_data)
|
168
|
+
df_col_dtype = df_col_data.dtype
|
169
|
+
|
126
170
|
if df_col_dtype == np.dtype("O"):
|
127
|
-
if isinstance(
|
128
|
-
arr = utils.convert_list_to_ndarray(
|
171
|
+
if isinstance(df_col_data.iloc[0], list):
|
172
|
+
arr = utils.convert_list_to_ndarray(df_col_data.iloc[0])
|
129
173
|
arr_dtype = core.DataType.from_numpy_type(arr.dtype)
|
130
|
-
ft_shape = np.shape(
|
174
|
+
ft_shape = np.shape(df_col_data.iloc[0])
|
131
175
|
|
132
|
-
converted_data_list = [utils.convert_list_to_ndarray(data_row) for data_row in
|
176
|
+
converted_data_list = [utils.convert_list_to_ndarray(data_row) for data_row in df_col_data]
|
133
177
|
|
134
178
|
if not all(np.shape(converted_data) == ft_shape for converted_data in converted_data_list):
|
135
179
|
ft_shape = (-1,)
|
136
180
|
|
137
181
|
specs.append(core.FeatureSpec(dtype=arr_dtype, name=ft_name, shape=ft_shape))
|
138
|
-
elif isinstance(
|
139
|
-
arr_dtype = core.DataType.from_numpy_type(
|
140
|
-
ft_shape = np.shape(
|
182
|
+
elif isinstance(df_col_data.iloc[0], np.ndarray):
|
183
|
+
arr_dtype = core.DataType.from_numpy_type(df_col_data.iloc[0].dtype)
|
184
|
+
ft_shape = np.shape(df_col_data.iloc[0])
|
141
185
|
|
142
|
-
if not all(np.shape(data_row) == ft_shape for data_row in
|
186
|
+
if not all(np.shape(data_row) == ft_shape for data_row in df_col_data):
|
143
187
|
ft_shape = (-1,)
|
144
188
|
|
145
189
|
specs.append(core.FeatureSpec(dtype=arr_dtype, name=ft_name, shape=ft_shape))
|
146
|
-
elif isinstance(
|
190
|
+
elif isinstance(df_col_data.iloc[0], str):
|
147
191
|
specs.append(core.FeatureSpec(dtype=core.DataType.STRING, name=ft_name))
|
148
|
-
elif isinstance(
|
192
|
+
elif isinstance(df_col_data.iloc[0], bytes):
|
149
193
|
specs.append(core.FeatureSpec(dtype=core.DataType.BYTES, name=ft_name))
|
194
|
+
elif isinstance(df_col_dtype, pd.CategoricalDtype):
|
195
|
+
category_dtype = df_col_dtype.categories.dtype
|
196
|
+
if category_dtype == np.dtype("O"):
|
197
|
+
if isinstance(df_col_dtype.categories[0], str):
|
198
|
+
specs.append(core.FeatureSpec(dtype=core.DataType.STRING, name=ft_name))
|
199
|
+
elif isinstance(df_col_dtype.categories[0], bytes):
|
200
|
+
specs.append(core.FeatureSpec(dtype=core.DataType.BYTES, name=ft_name))
|
201
|
+
else:
|
202
|
+
raise snowml_exceptions.SnowflakeMLException(
|
203
|
+
error_code=error_codes.INVALID_DATA,
|
204
|
+
original_exception=ValueError(
|
205
|
+
f"Data Validation Error: Unsupported type confronted in {df_col_dtype.categories[0]}"
|
206
|
+
),
|
207
|
+
)
|
208
|
+
else:
|
209
|
+
specs.append(core.FeatureSpec(dtype=core.DataType.from_numpy_type(category_dtype), name=ft_name))
|
150
210
|
elif isinstance(data[df_col].iloc[0], np.datetime64):
|
151
211
|
specs.append(core.FeatureSpec(dtype=core.DataType.TIMESTAMP_NTZ, name=ft_name))
|
152
212
|
else:
|
@@ -72,10 +72,10 @@ class SeqOfPyTorchTensorHandler(base_handler.BaseDataHandler[Sequence["torch.Ten
|
|
72
72
|
dtype = core.DataType.from_torch_type(data_col.dtype)
|
73
73
|
ft_name = f"{role_prefix}{feature_prefix}{i}"
|
74
74
|
if len(data_col.shape) == 1:
|
75
|
-
features.append(core.FeatureSpec(dtype=dtype, name=ft_name))
|
75
|
+
features.append(core.FeatureSpec(dtype=dtype, name=ft_name, nullable=False))
|
76
76
|
else:
|
77
77
|
ft_shape = tuple(data_col.shape[1:])
|
78
|
-
features.append(core.FeatureSpec(dtype=dtype, name=ft_name, shape=ft_shape))
|
78
|
+
features.append(core.FeatureSpec(dtype=dtype, name=ft_name, shape=ft_shape, nullable=False))
|
79
79
|
return features
|
80
80
|
|
81
81
|
@staticmethod
|
@@ -82,7 +82,8 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
|
|
82
82
|
identifier.get_unescaped_names(field.name)
|
83
83
|
].map(json.loads)
|
84
84
|
# Only when the feature is not from inference, we are confident to do the type casting.
|
85
|
-
# Otherwise, dtype_map will be empty
|
85
|
+
# Otherwise, dtype_map will be empty.
|
86
|
+
# Errors are ignored to make sure None won't be converted and won't raise Error
|
86
87
|
df_local = df_local.astype(dtype=dtype_map)
|
87
88
|
return df_local
|
88
89
|
|