snowflake-ml-python 1.6.4__py3-none-any.whl → 1.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_complete.py +107 -64
  3. snowflake/cortex/_finetune.py +273 -0
  4. snowflake/cortex/_sse_client.py +91 -28
  5. snowflake/cortex/_util.py +30 -1
  6. snowflake/ml/_internal/telemetry.py +4 -2
  7. snowflake/ml/_internal/type_utils.py +3 -3
  8. snowflake/ml/_internal/utils/import_utils.py +31 -0
  9. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +13 -0
  10. snowflake/ml/data/__init__.py +5 -0
  11. snowflake/ml/data/_internal/arrow_ingestor.py +8 -0
  12. snowflake/ml/data/data_connector.py +1 -1
  13. snowflake/ml/data/torch_utils.py +33 -14
  14. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +5 -3
  15. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +7 -5
  16. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +4 -2
  17. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +3 -1
  18. snowflake/ml/feature_store/examples/example_helper.py +6 -3
  19. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +4 -2
  20. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +4 -2
  21. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +3 -1
  22. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +3 -1
  23. snowflake/ml/feature_store/feature_store.py +1 -2
  24. snowflake/ml/feature_store/feature_view.py +5 -1
  25. snowflake/ml/model/_client/model/model_version_impl.py +145 -11
  26. snowflake/ml/model/_client/ops/model_ops.py +56 -16
  27. snowflake/ml/model/_client/ops/service_ops.py +46 -30
  28. snowflake/ml/model/_client/service/model_deployment_spec.py +19 -8
  29. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -1
  30. snowflake/ml/model/_client/sql/service.py +25 -1
  31. snowflake/ml/model/_model_composer/model_composer.py +2 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
  33. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  34. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
  35. snowflake/ml/model/_model_composer/model_method/model_method.py +1 -1
  36. snowflake/ml/model/_packager/model_env/model_env.py +12 -0
  37. snowflake/ml/model/_packager/model_handlers/_utils.py +6 -2
  38. snowflake/ml/model/_packager/model_handlers/catboost.py +4 -7
  39. snowflake/ml/model/_packager/model_handlers/custom.py +5 -1
  40. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +10 -1
  41. snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -7
  42. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -1
  43. snowflake/ml/model/_packager/model_handlers/sklearn.py +51 -7
  44. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +8 -66
  45. snowflake/ml/model/_packager/model_handlers/tensorflow.py +23 -6
  46. snowflake/ml/model/_packager/model_handlers/torchscript.py +14 -14
  47. snowflake/ml/model/_packager/model_handlers/xgboost.py +10 -40
  48. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +2 -3
  49. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +5 -0
  50. snowflake/ml/model/_packager/model_packager.py +0 -11
  51. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -10
  52. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -9
  53. snowflake/ml/model/_packager/{model_handlers/model_objective_utils.py → model_task/model_task_utils.py} +14 -26
  54. snowflake/ml/model/_signatures/core.py +63 -16
  55. snowflake/ml/model/_signatures/pandas_handler.py +87 -27
  56. snowflake/ml/model/_signatures/pytorch_handler.py +2 -2
  57. snowflake/ml/model/_signatures/snowpark_handler.py +2 -1
  58. snowflake/ml/model/_signatures/tensorflow_handler.py +2 -2
  59. snowflake/ml/model/_signatures/utils.py +4 -0
  60. snowflake/ml/model/custom_model.py +47 -7
  61. snowflake/ml/model/model_signature.py +40 -9
  62. snowflake/ml/model/type_hints.py +9 -1
  63. snowflake/ml/modeling/_internal/estimator_utils.py +13 -0
  64. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +7 -2
  65. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +16 -5
  66. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +8 -2
  67. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -3
  68. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -8
  69. snowflake/ml/modeling/cluster/agglomerative_clustering.py +17 -19
  70. snowflake/ml/modeling/cluster/dbscan.py +5 -2
  71. snowflake/ml/modeling/cluster/feature_agglomeration.py +7 -19
  72. snowflake/ml/modeling/cluster/k_means.py +14 -19
  73. snowflake/ml/modeling/cluster/mini_batch_k_means.py +3 -3
  74. snowflake/ml/modeling/cluster/optics.py +6 -6
  75. snowflake/ml/modeling/cluster/spectral_clustering.py +4 -3
  76. snowflake/ml/modeling/compose/column_transformer.py +15 -5
  77. snowflake/ml/modeling/compose/transformed_target_regressor.py +7 -6
  78. snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
  79. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
  80. snowflake/ml/modeling/covariance/min_cov_det.py +2 -2
  81. snowflake/ml/modeling/covariance/oas.py +1 -1
  82. snowflake/ml/modeling/decomposition/kernel_pca.py +2 -2
  83. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +5 -12
  84. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +5 -12
  85. snowflake/ml/modeling/decomposition/pca.py +28 -15
  86. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -0
  87. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -12
  88. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -11
  89. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -8
  90. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -8
  91. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +21 -2
  92. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +18 -2
  93. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +2 -0
  94. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +2 -0
  95. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +21 -8
  96. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +21 -11
  97. snowflake/ml/modeling/ensemble/random_forest_classifier.py +21 -2
  98. snowflake/ml/modeling/ensemble/random_forest_regressor.py +18 -2
  99. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +2 -1
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +5 -3
  101. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +2 -2
  102. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +2 -4
  103. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +2 -4
  104. snowflake/ml/modeling/linear_model/ard_regression.py +5 -10
  105. snowflake/ml/modeling/linear_model/bayesian_ridge.py +5 -11
  106. snowflake/ml/modeling/linear_model/elastic_net.py +3 -0
  107. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  108. snowflake/ml/modeling/linear_model/lars.py +0 -10
  109. snowflake/ml/modeling/linear_model/lars_cv.py +1 -11
  110. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  111. snowflake/ml/modeling/linear_model/lasso_lars.py +0 -10
  112. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -11
  113. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +0 -10
  114. snowflake/ml/modeling/linear_model/logistic_regression.py +28 -22
  115. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +30 -24
  116. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  117. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  118. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +4 -13
  119. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +4 -4
  120. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  121. snowflake/ml/modeling/linear_model/perceptron.py +3 -3
  122. snowflake/ml/modeling/linear_model/ransac_regressor.py +3 -2
  123. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +14 -6
  124. snowflake/ml/modeling/linear_model/ridge_cv.py +17 -11
  125. snowflake/ml/modeling/linear_model/sgd_classifier.py +2 -2
  126. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +5 -1
  127. snowflake/ml/modeling/linear_model/sgd_regressor.py +12 -3
  128. snowflake/ml/modeling/manifold/isomap.py +1 -1
  129. snowflake/ml/modeling/manifold/mds.py +3 -3
  130. snowflake/ml/modeling/manifold/tsne.py +10 -4
  131. snowflake/ml/modeling/metrics/classification.py +12 -16
  132. snowflake/ml/modeling/metrics/ranking.py +3 -3
  133. snowflake/ml/modeling/metrics/regression.py +3 -3
  134. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +3 -3
  135. snowflake/ml/modeling/naive_bayes/categorical_nb.py +3 -3
  136. snowflake/ml/modeling/naive_bayes/complement_nb.py +3 -3
  137. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +3 -3
  138. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +10 -4
  139. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +5 -2
  140. snowflake/ml/modeling/neighbors/local_outlier_factor.py +2 -2
  141. snowflake/ml/modeling/neighbors/nearest_centroid.py +7 -14
  142. snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
  143. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -1
  144. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  145. snowflake/ml/modeling/neural_network/mlp_classifier.py +7 -1
  146. snowflake/ml/modeling/neural_network/mlp_regressor.py +3 -0
  147. snowflake/ml/modeling/pipeline/pipeline.py +16 -14
  148. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +8 -4
  149. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -7
  150. snowflake/ml/modeling/svm/linear_svc.py +25 -16
  151. snowflake/ml/modeling/svm/linear_svr.py +23 -17
  152. snowflake/ml/modeling/svm/nu_svc.py +5 -3
  153. snowflake/ml/modeling/svm/nu_svr.py +3 -1
  154. snowflake/ml/modeling/svm/svc.py +9 -5
  155. snowflake/ml/modeling/svm/svr.py +3 -1
  156. snowflake/ml/modeling/tree/decision_tree_classifier.py +21 -2
  157. snowflake/ml/modeling/tree/decision_tree_regressor.py +18 -2
  158. snowflake/ml/modeling/tree/extra_tree_classifier.py +28 -9
  159. snowflake/ml/modeling/tree/extra_tree_regressor.py +18 -2
  160. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +448 -0
  161. snowflake/ml/monitoring/_manager/model_monitor_manager.py +238 -0
  162. snowflake/ml/monitoring/entities/model_monitor_config.py +10 -10
  163. snowflake/ml/monitoring/model_monitor.py +37 -0
  164. snowflake/ml/registry/_manager/model_manager.py +15 -1
  165. snowflake/ml/registry/registry.py +32 -37
  166. snowflake/ml/version.py +1 -1
  167. {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/METADATA +104 -12
  168. {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/RECORD +172 -171
  169. {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/WHEEL +1 -1
  170. snowflake/ml/monitoring/_client/model_monitor.py +0 -126
  171. snowflake/ml/monitoring/_client/model_monitor_manager.py +0 -361
  172. snowflake/ml/monitoring/_client/monitor_sql_client.py +0 -1335
  173. snowflake/ml/monitoring/entities/model_monitor_interval.py +0 -46
  174. snowflake/ml/monitoring/{_client/model_monitor_version.py → model_monitor_version.py} +0 -0
  175. {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/LICENSE.txt +0 -0
  176. {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,6 @@
1
1
  # mypy: disable-error-code="import"
2
2
  import os
3
3
  import warnings
4
- from importlib import metadata as importlib_metadata
5
4
  from typing import (
6
5
  TYPE_CHECKING,
7
6
  Any,
@@ -16,23 +15,19 @@ from typing import (
16
15
 
17
16
  import numpy as np
18
17
  import pandas as pd
19
- from packaging import version
20
18
  from typing_extensions import TypeGuard, Unpack
21
19
 
22
20
  from snowflake.ml._internal import type_utils
23
21
  from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
24
22
  from snowflake.ml.model._packager.model_env import model_env
25
- from snowflake.ml.model._packager.model_handlers import (
26
- _base,
27
- _utils as handlers_utils,
28
- model_objective_utils,
29
- )
23
+ from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
30
24
  from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
31
25
  from snowflake.ml.model._packager.model_meta import (
32
26
  model_blob_meta,
33
27
  model_meta as model_meta_api,
34
28
  model_meta_schema,
35
29
  )
30
+ from snowflake.ml.model._packager.model_task import model_task_utils
36
31
  from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
37
32
 
38
33
  if TYPE_CHECKING:
@@ -94,23 +89,6 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
94
89
 
95
90
  assert isinstance(model, xgboost.Booster) or isinstance(model, xgboost.XGBModel)
96
91
 
97
- local_xgb_version = None
98
-
99
- try:
100
- local_dist = importlib_metadata.distribution("xgboost")
101
- local_xgb_version = version.parse(local_dist.version)
102
- except importlib_metadata.PackageNotFoundError:
103
- pass
104
-
105
- if local_xgb_version and local_xgb_version >= version.parse("2.1.0") and enable_explainability:
106
- warnings.warn(
107
- f"This version of xgboost {local_xgb_version} does not work with shap 0.42.1."
108
- + "If you want model explanations, lower the xgboost version to <2.1.0.",
109
- category=UserWarning,
110
- stacklevel=1,
111
- )
112
- enable_explainability = False
113
-
114
92
  if not is_sub_model:
115
93
  target_methods = handlers_utils.get_target_methods(
116
94
  model=model,
@@ -139,7 +117,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
139
117
  sample_input_data=sample_input_data,
140
118
  get_prediction_fn=get_prediction,
141
119
  )
142
- model_task_and_output = model_objective_utils.get_model_task_and_output_type(model)
120
+ model_task_and_output = model_task_utils.get_model_task_and_output_type(model)
143
121
  model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output.task)
144
122
  if enable_explainability:
145
123
  model_meta = handlers_utils.add_explain_method_signature(
@@ -187,23 +165,15 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
187
165
  ],
188
166
  check_local_version=True,
189
167
  )
190
- if local_xgb_version and local_xgb_version >= version.parse("2.0.0") and enable_explainability:
191
- model_meta.env.include_if_absent(
192
- [
193
- model_env.ModelDependency(requirement="xgboost==2.0.*", pip_name="xgboost"),
194
- ],
195
- check_local_version=False,
196
- )
197
- else:
198
- model_meta.env.include_if_absent(
199
- [
200
- model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
201
- ],
202
- check_local_version=True,
203
- )
168
+ model_meta.env.include_if_absent(
169
+ [
170
+ model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
171
+ ],
172
+ check_local_version=True,
173
+ )
204
174
 
205
175
  if enable_explainability:
206
- model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
176
+ model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap>=0.46.0", pip_name="shap")])
207
177
  model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
208
178
  model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
209
179
 
@@ -1,3 +1,2 @@
1
- REQUIREMENTS = [
2
- "cloudpickle>=2.0.0"
3
- ]
1
+ REQUIREMENTS = ['cloudpickle>=2.0.0']
2
+ ALL_REQUIREMENTS=['cloudpickle>=2.0.0']
@@ -58,11 +58,16 @@ class XgboostModelBlobOptions(BaseModelBlobOptions):
58
58
  xgb_estimator_type: Required[str]
59
59
 
60
60
 
61
+ class TensorflowModelBlobOptions(BaseModelBlobOptions):
62
+ is_keras_model: Required[bool]
63
+
64
+
61
65
  ModelBlobOptions = Union[
62
66
  BaseModelBlobOptions,
63
67
  HuggingFacePipelineModelBlobOptions,
64
68
  MLFlowModelBlobOptions,
65
69
  XgboostModelBlobOptions,
70
+ TensorflowModelBlobOptions,
66
71
  ]
67
72
 
68
73
 
@@ -61,17 +61,6 @@ class ModelPackager:
61
61
  if not options:
62
62
  options = model_types.BaseModelSaveOption()
63
63
 
64
- # here handling the case of enable_explainability is False/None
65
- enable_explainability = options.get("enable_explainability", None)
66
- if enable_explainability is False or enable_explainability is None:
67
- if (signatures is not None) and (sample_input_data is not None):
68
- raise snowml_exceptions.SnowflakeMLException(
69
- error_code=error_codes.INVALID_ARGUMENT,
70
- original_exception=ValueError(
71
- "Signatures and sample_input_data both cannot be specified at the same time."
72
- ),
73
- )
74
-
75
64
  handler = model_handler.find_handler(model)
76
65
  if handler is None:
77
66
  raise snowml_exceptions.SnowflakeMLException(
@@ -1,10 +1,2 @@
1
- REQUIREMENTS = [
2
- "absl-py>=0.15,<2",
3
- "anyio>=3.5.0,<4",
4
- "numpy>=1.23,<2",
5
- "packaging>=20.9,<24",
6
- "pandas>=1.0.0,<3",
7
- "pyyaml>=6.0,<7",
8
- "snowflake-snowpark-python>=1.17.0,<2",
9
- "typing-extensions>=4.1.0,<5"
10
- ]
1
+ REQUIREMENTS = ['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<4', 'cachetools>=3.1.1,<6', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2022.11,<2024', 'importlib_resources>=6.1.1, <7', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pytimeparse>=1.1.8,<2', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2022.11,<2024', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2', 'sqlparse>=0.4,<1', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
2
+ ALL_REQUIREMENTS=['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<4', 'cachetools>=3.1.1,<6', 'catboost>=1.2.0, <2', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2022.11,<2024', 'importlib_resources>=6.1.1, <7', 'lightgbm>=4.1.0, <5', 'mlflow>=2.1.0,<2.4', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pytimeparse>=1.1.8,<2', 'pytorch>=2.0.1,<2.3.0', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2022.11,<2024', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'sentence-transformers>=2.2.2,<3', 'sentencepiece>=0.1.95,<1', 'shap>=0.46.0,<1', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2', 'sqlparse>=0.4,<1', 'tensorflow>=2.10,<3', 'tokenizers>=0.10,<1', 'torchdata>=0.4,<1', 'transformers>=4.32.1,<5', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
@@ -17,6 +17,8 @@ _SNOWML_INFERENCE_ALTERNATIVE_DEPENDENCIES = [
17
17
  for r in _snowml_inference_alternative_requirements.REQUIREMENTS
18
18
  ]
19
19
 
20
+ PACKAGES_NOT_ALLOWED_IN_WAREHOUSE = ["snowflake-connector-python", "pyarrow"]
21
+
20
22
 
21
23
  class ModelRuntime:
22
24
  """Class to represent runtime in a model, which controls the runtime and version, imports and dependencies.
@@ -61,15 +63,8 @@ class ModelRuntime:
61
63
  ],
62
64
  )
63
65
 
64
- if not is_warehouse and self.embed_local_ml_library:
65
- self.runtime_env.include_if_absent(
66
- [
67
- model_env.ModelDependency(
68
- requirement="pyarrow",
69
- pip_name="pyarrow",
70
- )
71
- ],
72
- )
66
+ if is_warehouse and self.embed_local_ml_library:
67
+ self.runtime_env.remove_if_present_conda(PACKAGES_NOT_ALLOWED_IN_WAREHOUSE)
73
68
 
74
69
  if is_gpu:
75
70
  self.runtime_env.generate_env_for_cuda()
@@ -84,7 +84,7 @@ def get_model_task_lightgbm(model: Union["lightgbm.Booster", "lightgbm.LGBMModel
84
84
  if type_utils.LazyType("lightgbm.Booster").isinstance(model):
85
85
  model_task = model.params["objective"] # type: ignore[attr-defined]
86
86
  elif hasattr(model, "objective_"):
87
- model_task = model.objective_
87
+ model_task = model.objective_ # type: ignore[assignment]
88
88
  if model_task in _BINARY_CLASSIFICATION_OBJECTIVES:
89
89
  return type_hints.Task.TABULAR_BINARY_CLASSIFICATION
90
90
  if model_task in _MULTI_CLASSIFICATION_OBJECTIVES:
@@ -128,42 +128,30 @@ def get_model_task_xgb(model: Union["xgboost.Booster", "xgboost.XGBModel"]) -> t
128
128
  return type_hints.Task.UNKNOWN
129
129
 
130
130
 
131
- def get_model_task_and_output_type(model: Any) -> ModelTaskAndOutputType:
131
+ def _get_model_task(model: Any) -> type_hints.Task:
132
132
  if type_utils.LazyType("xgboost.Booster").isinstance(model) or type_utils.LazyType("xgboost.XGBModel").isinstance(
133
133
  model
134
134
  ):
135
- task = get_model_task_xgb(model)
136
- output_type = model_signature.DataType.DOUBLE
137
- if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION:
138
- output_type = model_signature.DataType.STRING
139
- return ModelTaskAndOutputType(task=task, output_type=output_type)
135
+ return get_model_task_xgb(model)
140
136
 
141
137
  if type_utils.LazyType("lightgbm.Booster").isinstance(model) or type_utils.LazyType(
142
138
  "lightgbm.LGBMModel"
143
139
  ).isinstance(model):
144
- task = get_model_task_lightgbm(model)
145
- output_type = model_signature.DataType.DOUBLE
146
- if task in [
147
- type_hints.Task.TABULAR_BINARY_CLASSIFICATION,
148
- type_hints.Task.TABULAR_MULTI_CLASSIFICATION,
149
- ]:
150
- output_type = model_signature.DataType.STRING
151
- return ModelTaskAndOutputType(task=task, output_type=output_type)
140
+ return get_model_task_lightgbm(model)
152
141
 
153
142
  if type_utils.LazyType("catboost.CatBoost").isinstance(model):
154
- task = get_model_task_catboost(model)
155
- output_type = model_signature.DataType.DOUBLE
156
- if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION:
157
- output_type = model_signature.DataType.STRING
158
- return ModelTaskAndOutputType(task=task, output_type=output_type)
143
+ return get_model_task_catboost(model)
159
144
 
160
145
  if type_utils.LazyType("sklearn.base.BaseEstimator").isinstance(model) or type_utils.LazyType(
161
146
  "sklearn.pipeline.Pipeline"
162
147
  ).isinstance(model):
163
- task = get_task_skl(model)
164
- output_type = model_signature.DataType.DOUBLE
165
- if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION:
166
- output_type = model_signature.DataType.STRING
167
- return ModelTaskAndOutputType(task=task, output_type=output_type)
168
-
148
+ return get_task_skl(model)
169
149
  raise ValueError(f"Model type {type(model)} is not supported")
150
+
151
+
152
+ def get_model_task_and_output_type(model: Any) -> ModelTaskAndOutputType:
153
+ task = _get_model_task(model)
154
+ output_type = model_signature.DataType.DOUBLE
155
+ if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION:
156
+ output_type = model_signature.DataType.STRING
157
+ return ModelTaskAndOutputType(task=task, output_type=output_type)
@@ -14,10 +14,12 @@ from typing import (
14
14
  Type,
15
15
  Union,
16
16
  final,
17
+ get_args,
17
18
  )
18
19
 
19
20
  import numpy as np
20
21
  import numpy.typing as npt
22
+ import pandas as pd
21
23
 
22
24
  import snowflake.snowpark.types as spt
23
25
  from snowflake.ml._internal.exceptions import (
@@ -29,6 +31,21 @@ if TYPE_CHECKING:
29
31
  import mlflow
30
32
  import torch
31
33
 
34
+ PandasExtensionTypes = Union[
35
+ pd.Int8Dtype,
36
+ pd.Int16Dtype,
37
+ pd.Int32Dtype,
38
+ pd.Int64Dtype,
39
+ pd.UInt8Dtype,
40
+ pd.UInt16Dtype,
41
+ pd.UInt32Dtype,
42
+ pd.UInt64Dtype,
43
+ pd.Float32Dtype,
44
+ pd.Float64Dtype,
45
+ pd.BooleanDtype,
46
+ pd.StringDtype,
47
+ ]
48
+
32
49
 
33
50
  class DataType(Enum):
34
51
  def __init__(self, value: str, snowpark_type: Type[spt.DataType], numpy_type: npt.DTypeLike) -> None:
@@ -67,11 +84,11 @@ class DataType(Enum):
67
84
  return f"DataType.{self.name}"
68
85
 
69
86
  @classmethod
70
- def from_numpy_type(cls, np_type: npt.DTypeLike) -> "DataType":
87
+ def from_numpy_type(cls, input_type: Union[npt.DTypeLike, PandasExtensionTypes]) -> "DataType":
71
88
  """Translate numpy dtype to DataType for signature definition.
72
89
 
73
90
  Args:
74
- np_type: The numpy dtype.
91
+ input_type: The numpy dtype or Pandas Extension Dtype
75
92
 
76
93
  Raises:
77
94
  SnowflakeMLException: NotImplementedError: Raised when the given numpy type is not supported.
@@ -79,6 +96,10 @@ class DataType(Enum):
79
96
  Returns:
80
97
  Corresponding DataType.
81
98
  """
99
+ # To support pandas extension dtype
100
+ if isinstance(input_type, get_args(PandasExtensionTypes)):
101
+ input_type = input_type.type
102
+
82
103
  np_to_snowml_type_mapping = {i._numpy_type: i for i in DataType}
83
104
 
84
105
  # Add datetime types:
@@ -88,12 +109,12 @@ class DataType(Enum):
88
109
  np_to_snowml_type_mapping[f"datetime64[{res}]"] = DataType.TIMESTAMP_NTZ
89
110
 
90
111
  for potential_type in np_to_snowml_type_mapping.keys():
91
- if np.can_cast(np_type, potential_type, casting="no"):
112
+ if np.can_cast(input_type, potential_type, casting="no"):
92
113
  # This is used since the same dtype might represented in different ways.
93
114
  return np_to_snowml_type_mapping[potential_type]
94
115
  raise snowml_exceptions.SnowflakeMLException(
95
116
  error_code=error_codes.NOT_IMPLEMENTED,
96
- original_exception=NotImplementedError(f"Type {np_type} is not supported as a DataType."),
117
+ original_exception=NotImplementedError(f"Type {input_type} is not supported as a DataType."),
97
118
  )
98
119
 
99
120
  @classmethod
@@ -212,6 +233,7 @@ class FeatureSpec(BaseFeatureSpec):
212
233
  name: str,
213
234
  dtype: DataType,
214
235
  shape: Optional[Tuple[int, ...]] = None,
236
+ nullable: bool = True,
215
237
  ) -> None:
216
238
  """
217
239
  Initialize a feature.
@@ -219,6 +241,7 @@ class FeatureSpec(BaseFeatureSpec):
219
241
  Args:
220
242
  name: Name of the feature.
221
243
  dtype: Type of the elements in the feature.
244
+ nullable: Whether the feature is nullable. Defaults to True.
222
245
  shape: Used to represent scalar feature, 1-d feature list,
223
246
  or n-d tensor. Use -1 to represent variable length. Defaults to None.
224
247
 
@@ -227,6 +250,7 @@ class FeatureSpec(BaseFeatureSpec):
227
250
  - (2,): 1d list with a fixed length of 2.
228
251
  - (-1,): 1d list with variable length, used for ragged tensor representation.
229
252
  - (d1, d2, d3): 3d tensor.
253
+ nullable: Whether the feature is nullable. Defaults to True.
230
254
 
231
255
  Raises:
232
256
  SnowflakeMLException: TypeError: When the dtype input type is incorrect.
@@ -248,6 +272,8 @@ class FeatureSpec(BaseFeatureSpec):
248
272
  )
249
273
  self._shape = shape
250
274
 
275
+ self._nullable = nullable
276
+
251
277
  def as_snowpark_type(self) -> spt.DataType:
252
278
  result_type = self._dtype.as_snowpark_type()
253
279
  if not self._shape:
@@ -256,13 +282,34 @@ class FeatureSpec(BaseFeatureSpec):
256
282
  result_type = spt.ArrayType(result_type)
257
283
  return result_type
258
284
 
259
- def as_dtype(self) -> Union[npt.DTypeLike, str]:
285
+ def as_dtype(self) -> Union[npt.DTypeLike, str, PandasExtensionTypes]:
260
286
  """Convert to corresponding local Type."""
287
+
261
288
  if not self._shape:
262
289
  # scalar dtype: use keys from `np.sctypeDict` to prevent unit-less dtype 'datetime64'
263
290
  if "datetime64" in self._dtype._value:
264
291
  return self._dtype._value
265
- return self._dtype._numpy_type
292
+
293
+ np_type = self._dtype._numpy_type
294
+ if self._nullable:
295
+ np_to_pd_dtype_mapping = {
296
+ np.int8: pd.Int8Dtype(),
297
+ np.int16: pd.Int16Dtype(),
298
+ np.int32: pd.Int32Dtype(),
299
+ np.int64: pd.Int64Dtype(),
300
+ np.uint8: pd.UInt8Dtype(),
301
+ np.uint16: pd.UInt16Dtype(),
302
+ np.uint32: pd.UInt32Dtype(),
303
+ np.uint64: pd.UInt64Dtype(),
304
+ np.float32: pd.Float32Dtype(),
305
+ np.float64: pd.Float64Dtype(),
306
+ np.bool_: pd.BooleanDtype(),
307
+ np.str_: pd.StringDtype(),
308
+ }
309
+
310
+ return np_to_pd_dtype_mapping.get(np_type, np_type) # type: ignore[arg-type]
311
+
312
+ return np_type
266
313
  return np.object_
267
314
 
268
315
  def __eq__(self, other: object) -> bool:
@@ -273,7 +320,10 @@ class FeatureSpec(BaseFeatureSpec):
273
320
 
274
321
  def __repr__(self) -> str:
275
322
  shape_str = f", shape={repr(self._shape)}" if self._shape else ""
276
- return f"FeatureSpec(dtype={repr(self._dtype)}, name={repr(self._name)}{shape_str})"
323
+ return (
324
+ f"FeatureSpec(dtype={repr(self._dtype)}, "
325
+ f"name={repr(self._name)}{shape_str}, nullable={repr(self._nullable)})"
326
+ )
277
327
 
278
328
  def to_dict(self) -> Dict[str, Any]:
279
329
  """Serialize the feature group into a dict.
@@ -281,10 +331,7 @@ class FeatureSpec(BaseFeatureSpec):
281
331
  Returns:
282
332
  A dict that serializes the feature group.
283
333
  """
284
- base_dict: Dict[str, Any] = {
285
- "type": self._dtype.name,
286
- "name": self._name,
287
- }
334
+ base_dict: Dict[str, Any] = {"type": self._dtype.name, "name": self._name, "nullable": self._nullable}
288
335
  if self._shape is not None:
289
336
  base_dict["shape"] = self._shape
290
337
  return base_dict
@@ -304,7 +351,9 @@ class FeatureSpec(BaseFeatureSpec):
304
351
  if shape:
305
352
  shape = tuple(shape)
306
353
  type = DataType[input_dict["type"]]
307
- return FeatureSpec(name=name, dtype=type, shape=shape)
354
+ # If nullable is not provided, default to False for backward compatibility.
355
+ nullable = input_dict.get("nullable", False)
356
+ return FeatureSpec(name=name, dtype=type, shape=shape, nullable=nullable)
308
357
 
309
358
  @classmethod
310
359
  def from_mlflow_spec(
@@ -475,10 +524,8 @@ class ModelSignature:
475
524
  sig_outs = loaded["outputs"]
476
525
  sig_inputs = loaded["inputs"]
477
526
 
478
- deserialize_spec: Callable[[Dict[str, Any]], BaseFeatureSpec] = (
479
- lambda sig_spec: FeatureGroupSpec.from_dict(sig_spec)
480
- if "feature_group" in sig_spec
481
- else FeatureSpec.from_dict(sig_spec)
527
+ deserialize_spec: Callable[[Dict[str, Any]], BaseFeatureSpec] = lambda sig_spec: (
528
+ FeatureGroupSpec.from_dict(sig_spec) if "feature_group" in sig_spec else FeatureSpec.from_dict(sig_spec)
482
529
  )
483
530
 
484
531
  return ModelSignature(
@@ -1,4 +1,5 @@
1
- from typing import Literal, Sequence
1
+ import warnings
2
+ from typing import Literal, Sequence, Union
2
3
 
3
4
  import numpy as np
4
5
  import pandas as pd
@@ -14,8 +15,8 @@ from snowflake.ml.model._signatures import base_handler, core, utils
14
15
 
15
16
  class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
16
17
  @staticmethod
17
- def can_handle(data: model_types.SupportedDataType) -> TypeGuard[pd.DataFrame]:
18
- return isinstance(data, pd.DataFrame)
18
+ def can_handle(data: model_types.SupportedDataType) -> TypeGuard[Union[pd.DataFrame, pd.Series]]:
19
+ return isinstance(data, pd.DataFrame) or isinstance(data, pd.Series)
19
20
 
20
21
  @staticmethod
21
22
  def count(data: pd.DataFrame) -> int:
@@ -26,7 +27,17 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
26
27
  return data.head(min(PandasDataFrameHandler.count(data), PandasDataFrameHandler.SIG_INFER_ROWS_COUNT_LIMIT))
27
28
 
28
29
  @staticmethod
29
- def validate(data: pd.DataFrame) -> None:
30
+ def validate(data: Union[pd.DataFrame, pd.Series]) -> None:
31
+ if isinstance(data, pd.Series):
32
+ # check if the series is empty and throw error
33
+ if data.empty:
34
+ raise snowml_exceptions.SnowflakeMLException(
35
+ error_code=error_codes.INVALID_DATA,
36
+ original_exception=ValueError("Data Validation Error: Empty data is found."),
37
+ )
38
+ # convert the series to a dataframe
39
+ data = data.to_frame()
40
+
30
41
  df_cols = data.columns
31
42
 
32
43
  if df_cols.has_duplicates: # Rule out categorical index with duplicates
@@ -60,21 +71,44 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
60
71
 
61
72
  df_col_dtypes = [data[col].dtype for col in data.columns]
62
73
  for df_col, df_col_dtype in zip(df_cols, df_col_dtypes):
74
+ df_col_data = data[df_col]
75
+ if df_col_data.isnull().all():
76
+ raise snowml_exceptions.SnowflakeMLException(
77
+ error_code=error_codes.INVALID_DATA,
78
+ original_exception=ValueError(
79
+ f"Data Validation Error: There is no non-null data in column {df_col}."
80
+ ),
81
+ )
82
+ if df_col_data.isnull().any():
83
+ warnings.warn(
84
+ (
85
+ f"Null value detected in column {df_col}, model signature inference might not accurate, "
86
+ "or your prediction might fail if your model does not support null input. If this is not "
87
+ "expected, please check your input dataframe."
88
+ ),
89
+ category=UserWarning,
90
+ stacklevel=2,
91
+ )
92
+
93
+ df_col_data = utils.series_dropna(df_col_data)
94
+ df_col_dtype = df_col_data.dtype
95
+
63
96
  if df_col_dtype == np.dtype("O"):
64
97
  # Check if all objects have the same type
65
- if not all(isinstance(data_row, type(data[df_col].iloc[0])) for data_row in data[df_col]):
98
+ if not all(isinstance(data_row, type(df_col_data.iloc[0])) for data_row in df_col_data):
66
99
  raise snowml_exceptions.SnowflakeMLException(
67
100
  error_code=error_codes.INVALID_DATA,
68
101
  original_exception=ValueError(
69
- f"Data Validation Error: Inconsistent type of object found in column data {data[df_col]}."
102
+ "Data Validation Error: "
103
+ + f"Inconsistent type of element in object found in column data {df_col_data}."
70
104
  ),
71
105
  )
72
106
 
73
- if isinstance(data[df_col].iloc[0], list):
74
- arr = utils.convert_list_to_ndarray(data[df_col].iloc[0])
107
+ if isinstance(df_col_data.iloc[0], list):
108
+ arr = utils.convert_list_to_ndarray(df_col_data.iloc[0])
75
109
  arr_dtype = core.DataType.from_numpy_type(arr.dtype)
76
110
 
77
- converted_data_list = [utils.convert_list_to_ndarray(data_row) for data_row in data[df_col]]
111
+ converted_data_list = [utils.convert_list_to_ndarray(data_row) for data_row in df_col_data]
78
112
 
79
113
  if not all(
80
114
  core.DataType.from_numpy_type(converted_data.dtype) == arr_dtype
@@ -84,32 +118,37 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
84
118
  error_code=error_codes.INVALID_DATA,
85
119
  original_exception=ValueError(
86
120
  "Data Validation Error: "
87
- + f"Inconsistent type of element in object found in column data {data[df_col]}."
121
+ + f"Inconsistent type of element in object found in column data {df_col_data}."
88
122
  ),
89
123
  )
90
124
 
91
- elif isinstance(data[df_col].iloc[0], np.ndarray):
92
- arr_dtype = core.DataType.from_numpy_type(data[df_col].iloc[0].dtype)
125
+ elif isinstance(df_col_data.iloc[0], np.ndarray):
126
+ arr_dtype = core.DataType.from_numpy_type(df_col_data.iloc[0].dtype)
93
127
 
94
- if not all(core.DataType.from_numpy_type(data_row.dtype) == arr_dtype for data_row in data[df_col]):
128
+ if not all(core.DataType.from_numpy_type(data_row.dtype) == arr_dtype for data_row in df_col_data):
95
129
  raise snowml_exceptions.SnowflakeMLException(
96
130
  error_code=error_codes.INVALID_DATA,
97
131
  original_exception=ValueError(
98
132
  "Data Validation Error: "
99
- + f"Inconsistent type of element in object found in column data {data[df_col]}."
133
+ + f"Inconsistent type of element in object found in column data {df_col_data}."
100
134
  ),
101
135
  )
102
- elif not isinstance(data[df_col].iloc[0], (str, bytes)):
136
+ elif not isinstance(df_col_data.iloc[0], (str, bytes)):
103
137
  raise snowml_exceptions.SnowflakeMLException(
104
138
  error_code=error_codes.INVALID_DATA,
105
139
  original_exception=ValueError(
106
- f"Data Validation Error: Unsupported type confronted in {data[df_col]}"
140
+ f"Data Validation Error: Unsupported type confronted in {df_col_data}"
107
141
  ),
108
142
  )
109
143
 
110
144
  @staticmethod
111
- def infer_signature(data: pd.DataFrame, role: Literal["input", "output"]) -> Sequence[core.BaseFeatureSpec]:
145
+ def infer_signature(
146
+ data: Union[pd.DataFrame, pd.Series],
147
+ role: Literal["input", "output"],
148
+ ) -> Sequence[core.BaseFeatureSpec]:
112
149
  feature_prefix = f"{PandasDataFrameHandler.FEATURE_PREFIX}_"
150
+ if isinstance(data, pd.Series):
151
+ data = data.to_frame()
113
152
  df_cols = data.columns
114
153
  role_prefix = (
115
154
  PandasDataFrameHandler.INPUT_PREFIX if role == "input" else PandasDataFrameHandler.OUTPUT_PREFIX
@@ -123,30 +162,51 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
123
162
 
124
163
  specs = []
125
164
  for df_col, df_col_dtype, ft_name in zip(df_cols, df_col_dtypes, ft_names):
165
+ df_col_data = data[df_col]
166
+ if df_col_data.isnull().any():
167
+ df_col_data = utils.series_dropna(df_col_data)
168
+ df_col_dtype = df_col_data.dtype
169
+
126
170
  if df_col_dtype == np.dtype("O"):
127
- if isinstance(data[df_col].iloc[0], list):
128
- arr = utils.convert_list_to_ndarray(data[df_col].iloc[0])
171
+ if isinstance(df_col_data.iloc[0], list):
172
+ arr = utils.convert_list_to_ndarray(df_col_data.iloc[0])
129
173
  arr_dtype = core.DataType.from_numpy_type(arr.dtype)
130
- ft_shape = np.shape(data[df_col].iloc[0])
174
+ ft_shape = np.shape(df_col_data.iloc[0])
131
175
 
132
- converted_data_list = [utils.convert_list_to_ndarray(data_row) for data_row in data[df_col]]
176
+ converted_data_list = [utils.convert_list_to_ndarray(data_row) for data_row in df_col_data]
133
177
 
134
178
  if not all(np.shape(converted_data) == ft_shape for converted_data in converted_data_list):
135
179
  ft_shape = (-1,)
136
180
 
137
181
  specs.append(core.FeatureSpec(dtype=arr_dtype, name=ft_name, shape=ft_shape))
138
- elif isinstance(data[df_col].iloc[0], np.ndarray):
139
- arr_dtype = core.DataType.from_numpy_type(data[df_col].iloc[0].dtype)
140
- ft_shape = np.shape(data[df_col].iloc[0])
182
+ elif isinstance(df_col_data.iloc[0], np.ndarray):
183
+ arr_dtype = core.DataType.from_numpy_type(df_col_data.iloc[0].dtype)
184
+ ft_shape = np.shape(df_col_data.iloc[0])
141
185
 
142
- if not all(np.shape(data_row) == ft_shape for data_row in data[df_col]):
186
+ if not all(np.shape(data_row) == ft_shape for data_row in df_col_data):
143
187
  ft_shape = (-1,)
144
188
 
145
189
  specs.append(core.FeatureSpec(dtype=arr_dtype, name=ft_name, shape=ft_shape))
146
- elif isinstance(data[df_col].iloc[0], str):
190
+ elif isinstance(df_col_data.iloc[0], str):
147
191
  specs.append(core.FeatureSpec(dtype=core.DataType.STRING, name=ft_name))
148
- elif isinstance(data[df_col].iloc[0], bytes):
192
+ elif isinstance(df_col_data.iloc[0], bytes):
149
193
  specs.append(core.FeatureSpec(dtype=core.DataType.BYTES, name=ft_name))
194
+ elif isinstance(df_col_dtype, pd.CategoricalDtype):
195
+ category_dtype = df_col_dtype.categories.dtype
196
+ if category_dtype == np.dtype("O"):
197
+ if isinstance(df_col_dtype.categories[0], str):
198
+ specs.append(core.FeatureSpec(dtype=core.DataType.STRING, name=ft_name))
199
+ elif isinstance(df_col_dtype.categories[0], bytes):
200
+ specs.append(core.FeatureSpec(dtype=core.DataType.BYTES, name=ft_name))
201
+ else:
202
+ raise snowml_exceptions.SnowflakeMLException(
203
+ error_code=error_codes.INVALID_DATA,
204
+ original_exception=ValueError(
205
+ f"Data Validation Error: Unsupported type confronted in {df_col_dtype.categories[0]}"
206
+ ),
207
+ )
208
+ else:
209
+ specs.append(core.FeatureSpec(dtype=core.DataType.from_numpy_type(category_dtype), name=ft_name))
150
210
  elif isinstance(data[df_col].iloc[0], np.datetime64):
151
211
  specs.append(core.FeatureSpec(dtype=core.DataType.TIMESTAMP_NTZ, name=ft_name))
152
212
  else:
@@ -72,10 +72,10 @@ class SeqOfPyTorchTensorHandler(base_handler.BaseDataHandler[Sequence["torch.Ten
72
72
  dtype = core.DataType.from_torch_type(data_col.dtype)
73
73
  ft_name = f"{role_prefix}{feature_prefix}{i}"
74
74
  if len(data_col.shape) == 1:
75
- features.append(core.FeatureSpec(dtype=dtype, name=ft_name))
75
+ features.append(core.FeatureSpec(dtype=dtype, name=ft_name, nullable=False))
76
76
  else:
77
77
  ft_shape = tuple(data_col.shape[1:])
78
- features.append(core.FeatureSpec(dtype=dtype, name=ft_name, shape=ft_shape))
78
+ features.append(core.FeatureSpec(dtype=dtype, name=ft_name, shape=ft_shape, nullable=False))
79
79
  return features
80
80
 
81
81
  @staticmethod
@@ -82,7 +82,8 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
82
82
  identifier.get_unescaped_names(field.name)
83
83
  ].map(json.loads)
84
84
  # Only when the feature is not from inference, we are confident to do the type casting.
85
- # Otherwise, dtype_map will be empty
85
+ # Otherwise, dtype_map will be empty.
86
+ # Errors are ignored to make sure None won't be converted and won't raise Error
86
87
  df_local = df_local.astype(dtype=dtype_map)
87
88
  return df_local
88
89