snowflake-ml-python 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196) hide show
  1. snowflake/ml/_internal/env_utils.py +2 -1
  2. snowflake/ml/_internal/file_utils.py +35 -40
  3. snowflake/ml/_internal/telemetry.py +5 -8
  4. snowflake/ml/_internal/utils/identifier.py +74 -7
  5. snowflake/ml/_internal/utils/uri.py +7 -2
  6. snowflake/ml/model/_core_requirements.py +1 -1
  7. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +15 -0
  8. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +259 -0
  9. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +89 -0
  10. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +24 -0
  11. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +118 -0
  12. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +40 -0
  13. snowflake/ml/model/_deploy_client/snowservice/deploy.py +199 -0
  14. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +88 -0
  15. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +24 -0
  16. snowflake/ml/model/_deploy_client/utils/constants.py +47 -0
  17. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +178 -0
  18. snowflake/ml/model/_deploy_client/warehouse/deploy.py +25 -28
  19. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +7 -4
  20. snowflake/ml/model/_deployer.py +14 -27
  21. snowflake/ml/model/_env.py +4 -4
  22. snowflake/ml/model/_handlers/_base.py +3 -1
  23. snowflake/ml/model/_handlers/custom.py +14 -2
  24. snowflake/ml/model/_handlers/pytorch.py +186 -0
  25. snowflake/ml/model/_handlers/sklearn.py +14 -8
  26. snowflake/ml/model/_handlers/snowmlmodel.py +14 -9
  27. snowflake/ml/model/_handlers/torchscript.py +180 -0
  28. snowflake/ml/model/_handlers/xgboost.py +19 -9
  29. snowflake/ml/model/_model.py +27 -21
  30. snowflake/ml/model/_model_meta.py +33 -19
  31. snowflake/ml/model/model_signature.py +446 -66
  32. snowflake/ml/model/type_hints.py +28 -15
  33. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +79 -43
  34. snowflake/ml/modeling/cluster/affinity_propagation.py +79 -43
  35. snowflake/ml/modeling/cluster/agglomerative_clustering.py +79 -43
  36. snowflake/ml/modeling/cluster/birch.py +79 -43
  37. snowflake/ml/modeling/cluster/bisecting_k_means.py +79 -43
  38. snowflake/ml/modeling/cluster/dbscan.py +79 -43
  39. snowflake/ml/modeling/cluster/feature_agglomeration.py +79 -43
  40. snowflake/ml/modeling/cluster/k_means.py +79 -43
  41. snowflake/ml/modeling/cluster/mean_shift.py +79 -43
  42. snowflake/ml/modeling/cluster/mini_batch_k_means.py +79 -43
  43. snowflake/ml/modeling/cluster/optics.py +79 -43
  44. snowflake/ml/modeling/cluster/spectral_biclustering.py +79 -43
  45. snowflake/ml/modeling/cluster/spectral_clustering.py +79 -43
  46. snowflake/ml/modeling/cluster/spectral_coclustering.py +79 -43
  47. snowflake/ml/modeling/compose/column_transformer.py +79 -43
  48. snowflake/ml/modeling/compose/transformed_target_regressor.py +79 -43
  49. snowflake/ml/modeling/covariance/elliptic_envelope.py +79 -43
  50. snowflake/ml/modeling/covariance/empirical_covariance.py +79 -43
  51. snowflake/ml/modeling/covariance/graphical_lasso.py +79 -43
  52. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +79 -43
  53. snowflake/ml/modeling/covariance/ledoit_wolf.py +79 -43
  54. snowflake/ml/modeling/covariance/min_cov_det.py +79 -43
  55. snowflake/ml/modeling/covariance/oas.py +79 -43
  56. snowflake/ml/modeling/covariance/shrunk_covariance.py +79 -43
  57. snowflake/ml/modeling/decomposition/dictionary_learning.py +79 -43
  58. snowflake/ml/modeling/decomposition/factor_analysis.py +79 -43
  59. snowflake/ml/modeling/decomposition/fast_ica.py +79 -43
  60. snowflake/ml/modeling/decomposition/incremental_pca.py +79 -43
  61. snowflake/ml/modeling/decomposition/kernel_pca.py +79 -43
  62. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +79 -43
  63. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +79 -43
  64. snowflake/ml/modeling/decomposition/pca.py +79 -43
  65. snowflake/ml/modeling/decomposition/sparse_pca.py +79 -43
  66. snowflake/ml/modeling/decomposition/truncated_svd.py +79 -43
  67. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +79 -43
  68. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +79 -43
  69. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +79 -43
  70. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +79 -43
  71. snowflake/ml/modeling/ensemble/bagging_classifier.py +79 -43
  72. snowflake/ml/modeling/ensemble/bagging_regressor.py +79 -43
  73. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +79 -43
  74. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +79 -43
  75. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +79 -43
  76. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +79 -43
  77. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +79 -43
  78. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +79 -43
  79. snowflake/ml/modeling/ensemble/isolation_forest.py +79 -43
  80. snowflake/ml/modeling/ensemble/random_forest_classifier.py +79 -43
  81. snowflake/ml/modeling/ensemble/random_forest_regressor.py +79 -43
  82. snowflake/ml/modeling/ensemble/stacking_regressor.py +79 -43
  83. snowflake/ml/modeling/ensemble/voting_classifier.py +79 -43
  84. snowflake/ml/modeling/ensemble/voting_regressor.py +79 -43
  85. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +79 -43
  86. snowflake/ml/modeling/feature_selection/select_fdr.py +79 -43
  87. snowflake/ml/modeling/feature_selection/select_fpr.py +79 -43
  88. snowflake/ml/modeling/feature_selection/select_fwe.py +79 -43
  89. snowflake/ml/modeling/feature_selection/select_k_best.py +79 -43
  90. snowflake/ml/modeling/feature_selection/select_percentile.py +79 -43
  91. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +79 -43
  92. snowflake/ml/modeling/feature_selection/variance_threshold.py +79 -43
  93. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +79 -43
  94. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +79 -43
  95. snowflake/ml/modeling/impute/iterative_imputer.py +79 -43
  96. snowflake/ml/modeling/impute/knn_imputer.py +79 -43
  97. snowflake/ml/modeling/impute/missing_indicator.py +79 -43
  98. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +79 -43
  99. snowflake/ml/modeling/kernel_approximation/nystroem.py +79 -43
  100. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +79 -43
  101. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +79 -43
  102. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +79 -43
  103. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +79 -43
  104. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +79 -43
  105. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +79 -43
  106. snowflake/ml/modeling/linear_model/ard_regression.py +79 -43
  107. snowflake/ml/modeling/linear_model/bayesian_ridge.py +79 -43
  108. snowflake/ml/modeling/linear_model/elastic_net.py +79 -43
  109. snowflake/ml/modeling/linear_model/elastic_net_cv.py +79 -43
  110. snowflake/ml/modeling/linear_model/gamma_regressor.py +79 -43
  111. snowflake/ml/modeling/linear_model/huber_regressor.py +79 -43
  112. snowflake/ml/modeling/linear_model/lars.py +79 -43
  113. snowflake/ml/modeling/linear_model/lars_cv.py +79 -43
  114. snowflake/ml/modeling/linear_model/lasso.py +79 -43
  115. snowflake/ml/modeling/linear_model/lasso_cv.py +79 -43
  116. snowflake/ml/modeling/linear_model/lasso_lars.py +79 -43
  117. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +79 -43
  118. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +79 -43
  119. snowflake/ml/modeling/linear_model/linear_regression.py +79 -43
  120. snowflake/ml/modeling/linear_model/logistic_regression.py +79 -43
  121. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +79 -43
  122. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +79 -43
  123. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +79 -43
  124. snowflake/ml/modeling/linear_model/multi_task_lasso.py +79 -43
  125. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +79 -43
  126. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +79 -43
  127. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +79 -43
  128. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +79 -43
  129. snowflake/ml/modeling/linear_model/perceptron.py +79 -43
  130. snowflake/ml/modeling/linear_model/poisson_regressor.py +79 -43
  131. snowflake/ml/modeling/linear_model/ransac_regressor.py +79 -43
  132. snowflake/ml/modeling/linear_model/ridge.py +79 -43
  133. snowflake/ml/modeling/linear_model/ridge_classifier.py +79 -43
  134. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +79 -43
  135. snowflake/ml/modeling/linear_model/ridge_cv.py +79 -43
  136. snowflake/ml/modeling/linear_model/sgd_classifier.py +79 -43
  137. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +79 -43
  138. snowflake/ml/modeling/linear_model/sgd_regressor.py +79 -43
  139. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +79 -43
  140. snowflake/ml/modeling/linear_model/tweedie_regressor.py +79 -43
  141. snowflake/ml/modeling/manifold/isomap.py +79 -43
  142. snowflake/ml/modeling/manifold/mds.py +79 -43
  143. snowflake/ml/modeling/manifold/spectral_embedding.py +79 -43
  144. snowflake/ml/modeling/manifold/tsne.py +79 -43
  145. snowflake/ml/modeling/metrics/classification.py +6 -1
  146. snowflake/ml/modeling/metrics/regression.py +517 -9
  147. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +79 -43
  148. snowflake/ml/modeling/mixture/gaussian_mixture.py +79 -43
  149. snowflake/ml/modeling/model_selection/grid_search_cv.py +79 -43
  150. snowflake/ml/modeling/model_selection/randomized_search_cv.py +79 -43
  151. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +79 -43
  152. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +79 -43
  153. snowflake/ml/modeling/multiclass/output_code_classifier.py +79 -43
  154. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +79 -43
  155. snowflake/ml/modeling/naive_bayes/categorical_nb.py +79 -43
  156. snowflake/ml/modeling/naive_bayes/complement_nb.py +79 -43
  157. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -43
  158. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +79 -43
  159. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +79 -43
  160. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +79 -43
  161. snowflake/ml/modeling/neighbors/kernel_density.py +79 -43
  162. snowflake/ml/modeling/neighbors/local_outlier_factor.py +79 -43
  163. snowflake/ml/modeling/neighbors/nearest_centroid.py +79 -43
  164. snowflake/ml/modeling/neighbors/nearest_neighbors.py +79 -43
  165. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +79 -43
  166. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +79 -43
  167. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +79 -43
  168. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +79 -43
  169. snowflake/ml/modeling/neural_network/mlp_classifier.py +79 -43
  170. snowflake/ml/modeling/neural_network/mlp_regressor.py +79 -43
  171. snowflake/ml/modeling/pipeline/pipeline.py +24 -0
  172. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +18 -19
  173. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +2 -0
  174. snowflake/ml/modeling/preprocessing/polynomial_features.py +79 -43
  175. snowflake/ml/modeling/semi_supervised/label_propagation.py +79 -43
  176. snowflake/ml/modeling/semi_supervised/label_spreading.py +79 -43
  177. snowflake/ml/modeling/svm/linear_svc.py +79 -43
  178. snowflake/ml/modeling/svm/linear_svr.py +79 -43
  179. snowflake/ml/modeling/svm/nu_svc.py +79 -43
  180. snowflake/ml/modeling/svm/nu_svr.py +79 -43
  181. snowflake/ml/modeling/svm/svc.py +79 -43
  182. snowflake/ml/modeling/svm/svr.py +79 -43
  183. snowflake/ml/modeling/tree/decision_tree_classifier.py +79 -43
  184. snowflake/ml/modeling/tree/decision_tree_regressor.py +79 -43
  185. snowflake/ml/modeling/tree/extra_tree_classifier.py +79 -43
  186. snowflake/ml/modeling/tree/extra_tree_regressor.py +79 -43
  187. snowflake/ml/modeling/xgboost/xgb_classifier.py +79 -43
  188. snowflake/ml/modeling/xgboost/xgb_regressor.py +79 -43
  189. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +79 -43
  190. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +79 -43
  191. snowflake/ml/registry/model_registry.py +123 -121
  192. snowflake/ml/version.py +1 -1
  193. {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/METADATA +50 -8
  194. snowflake_ml_python-1.0.3.dist-info/RECORD +259 -0
  195. snowflake_ml_python-1.0.1.dist-info/RECORD +0 -246
  196. {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/WHEEL +0 -0
@@ -1,8 +1,10 @@
1
+ import json
1
2
  import textwrap
2
3
  import warnings
3
4
  from abc import ABC, abstractmethod
4
5
  from enum import Enum
5
6
  from typing import (
7
+ TYPE_CHECKING,
6
8
  Any,
7
9
  Callable,
8
10
  Dict,
@@ -26,8 +28,14 @@ from typing_extensions import TypeGuard
26
28
 
27
29
  import snowflake.snowpark
28
30
  import snowflake.snowpark.types as spt
31
+ from snowflake.ml._internal import type_utils
29
32
  from snowflake.ml._internal.utils import formatting, identifier
30
33
  from snowflake.ml.model import type_hints as model_types
34
+ from snowflake.ml.model._deploy_client.warehouse import infer_template
35
+
36
+ if TYPE_CHECKING:
37
+ import tensorflow
38
+ import torch
31
39
 
32
40
 
33
41
  class DataType(Enum):
@@ -36,22 +44,22 @@ class DataType(Enum):
36
44
  self._snowpark_type = snowpark_type
37
45
  self._numpy_type = numpy_type
38
46
 
39
- INT8 = ("int8", spt.IntegerType, np.int8)
40
- INT16 = ("int16", spt.IntegerType, np.int16)
47
+ INT8 = ("int8", spt.ByteType, np.int8)
48
+ INT16 = ("int16", spt.ShortType, np.int16)
41
49
  INT32 = ("int32", spt.IntegerType, np.int32)
42
- INT64 = ("int64", spt.IntegerType, np.int64)
50
+ INT64 = ("int64", spt.LongType, np.int64)
43
51
 
44
52
  FLOAT = ("float", spt.FloatType, np.float32)
45
53
  DOUBLE = ("double", spt.DoubleType, np.float64)
46
54
 
47
- UINT8 = ("uint8", spt.IntegerType, np.uint8)
48
- UINT16 = ("uint16", spt.IntegerType, np.uint16)
55
+ UINT8 = ("uint8", spt.ByteType, np.uint8)
56
+ UINT16 = ("uint16", spt.ShortType, np.uint16)
49
57
  UINT32 = ("uint32", spt.IntegerType, np.uint32)
50
- UINT64 = ("uint64", spt.IntegerType, np.uint64)
58
+ UINT64 = ("uint64", spt.LongType, np.uint64)
51
59
 
52
- BOOL = ("bool", spt.BooleanType, np.bool8)
53
- STRING = ("string", spt.StringType, np.str0)
54
- BYTES = ("bytes", spt.BinaryType, np.bytes0)
60
+ BOOL = ("bool", spt.BooleanType, np.bool_)
61
+ STRING = ("string", spt.StringType, np.str_)
62
+ BYTES = ("bytes", spt.BinaryType, np.bytes_)
55
63
 
56
64
  def as_snowpark_type(self) -> spt.DataType:
57
65
  """Convert to corresponding Snowpark Type.
@@ -84,6 +92,30 @@ class DataType(Enum):
84
92
  return np_to_snowml_type_mapping[potential_type]
85
93
  raise NotImplementedError(f"Type {np_type} is not supported as a DataType.")
86
94
 
95
+ @classmethod
96
+ def from_torch_type(cls, torch_type: "torch.dtype") -> "DataType":
97
+ import torch
98
+
99
+ """Translate torch dtype to DataType for signature definition.
100
+
101
+ Args:
102
+ torch_type: The torch dtype.
103
+
104
+ Returns:
105
+ Corresponding DataType.
106
+ """
107
+ torch_dtype_to_numpy_dtype_mapping = {
108
+ torch.uint8: np.uint8,
109
+ torch.int8: np.int8,
110
+ torch.int16: np.int16,
111
+ torch.int32: np.int32,
112
+ torch.int64: np.int64,
113
+ torch.float32: np.float32,
114
+ torch.float64: np.float64,
115
+ torch.bool: np.bool_,
116
+ }
117
+ return cls.from_numpy_type(torch_dtype_to_numpy_dtype_mapping[torch_type])
118
+
87
119
  @classmethod
88
120
  def from_snowpark_type(cls, snowpark_type: spt.DataType) -> "DataType":
89
121
  """Translate snowpark type to DataType for signature definition.
@@ -97,30 +129,45 @@ class DataType(Enum):
97
129
  Returns:
98
130
  Corresponding DataType.
99
131
  """
132
+ if isinstance(snowpark_type, spt.ArrayType):
133
+ actual_sp_type = snowpark_type.element_type
134
+ else:
135
+ actual_sp_type = snowpark_type
136
+
100
137
  snowpark_to_snowml_type_mapping: Dict[Type[spt.DataType], DataType] = {
101
- spt._IntegralType: DataType.INT64,
102
- **{i._snowpark_type: i for i in DataType if i._snowpark_type != spt.IntegerType},
138
+ i._snowpark_type: i
139
+ for i in DataType
140
+ # We by default infer as signed integer.
141
+ if i not in [DataType.UINT8, DataType.UINT16, DataType.UINT32, DataType.UINT64]
103
142
  }
104
143
  for potential_type in snowpark_to_snowml_type_mapping.keys():
105
- if isinstance(snowpark_type, potential_type):
144
+ if isinstance(actual_sp_type, potential_type):
106
145
  return snowpark_to_snowml_type_mapping[potential_type]
146
+ # Fallback for decimal type.
147
+ if isinstance(snowpark_type, spt.DecimalType):
148
+ if snowpark_type.scale == 0:
149
+ return DataType.INT64
107
150
  raise NotImplementedError(f"Type {snowpark_type} is not supported as a DataType.")
108
151
 
109
152
  def is_same_snowpark_type(self, incoming_snowpark_type: spt.DataType) -> bool:
110
153
  """Check if provided snowpark type is the same as Data Type.
111
- Since for Snowflake all integer types are same, thus when datatype is a integer type, the incoming snowpark
112
- type can be any type inherit from _IntegralType.
113
154
 
114
155
  Args:
115
156
  incoming_snowpark_type: The snowpark type.
116
157
 
158
+ Raises:
159
+ NotImplementedError: Raised when the given snowpark type is a DecimalType with a non-zero scale.
160
+
117
161
  Returns:
118
162
  If the provided snowpark type is the same as the DataType.
119
163
  """
120
- if self._snowpark_type == spt.IntegerType:
121
- return isinstance(incoming_snowpark_type, spt._IntegralType)
122
- else:
123
- return isinstance(incoming_snowpark_type, self._snowpark_type)
164
+ # Special handle for Decimal Type.
165
+ if isinstance(incoming_snowpark_type, spt.DecimalType):
166
+ if incoming_snowpark_type.scale == 0:
167
+ return self == DataType.INT64 or self == DataType.UINT64
168
+ raise NotImplementedError(f"Type {incoming_snowpark_type} is not supported as a DataType.")
169
+
170
+ return isinstance(incoming_snowpark_type, self._snowpark_type)
124
171
 
125
172
 
126
173
  class BaseFeatureSpec(ABC):
@@ -174,9 +221,19 @@ class FeatureSpec(BaseFeatureSpec):
174
221
  (2,): 1d list with fixed len of 2.
175
222
  (-1,): 1d list with variable length. Used for ragged tensor representation.
176
223
  (d1, d2, d3): 3d tensor.
224
+
225
+ Raises:
226
+ TypeError: Raised when the dtype input type is incorrect.
227
+ TypeError: Raised when the shape input type is incorrect.
177
228
  """
178
229
  super().__init__(name=name)
230
+
231
+ if not isinstance(dtype, DataType):
232
+ raise TypeError("dtype should be a model signature datatype.")
179
233
  self._dtype = dtype
234
+
235
+ if shape and not isinstance(shape, tuple):
236
+ raise TypeError("Shape should be a tuple if presented.")
180
237
  self._shape = shape
181
238
 
182
239
  def as_snowpark_type(self) -> spt.DataType:
@@ -191,7 +248,7 @@ class FeatureSpec(BaseFeatureSpec):
191
248
  """Convert to corresponding local Type."""
192
249
  if not self._shape:
193
250
  return self._dtype._numpy_type
194
- return np.object0
251
+ return np.object_
195
252
 
196
253
  def __eq__(self, other: object) -> bool:
197
254
  if isinstance(other, FeatureSpec):
@@ -229,6 +286,8 @@ class FeatureSpec(BaseFeatureSpec):
229
286
  """
230
287
  name = input_dict["name"]
231
288
  shape = input_dict.get("shape", None)
289
+ if shape:
290
+ shape = tuple(shape)
232
291
  type = DataType[input_dict["type"]]
233
292
  return FeatureSpec(name=name, dtype=type, shape=shape)
234
293
 
@@ -421,7 +480,7 @@ class _BaseDataHandler(ABC, Generic[model_types._DataType]):
421
480
 
422
481
  @staticmethod
423
482
  @abstractmethod
424
- def convert_to_df(data: model_types._DataType) -> Union[pd.DataFrame, snowflake.snowpark.DataFrame]:
483
+ def convert_to_df(data: model_types._DataType, ensure_serializable: bool = True) -> pd.DataFrame:
425
484
  ...
426
485
 
427
486
 
@@ -454,7 +513,7 @@ class _PandasDataFrameHandler(_BaseDataHandler[pd.DataFrame]):
454
513
  np.int64,
455
514
  np.uint64,
456
515
  np.float64,
457
- np.object0,
516
+ np.object_,
458
517
  ]: # To keep compatibility with Pandas 2.x and 1.x
459
518
  raise ValueError("Data Validation Error: Unsupported column index type is found.")
460
519
 
@@ -538,7 +597,17 @@ class _PandasDataFrameHandler(_BaseDataHandler[pd.DataFrame]):
538
597
  return specs
539
598
 
540
599
  @staticmethod
541
- def convert_to_df(data: pd.DataFrame) -> pd.DataFrame:
600
+ def convert_to_df(data: pd.DataFrame, ensure_serializable: bool = True) -> pd.DataFrame:
601
+ if not ensure_serializable:
602
+ return data
603
+ # This conversion is necessary since a numpy array cannot be correctly handled when provided as an element of
604
+ # a list when creating Snowpark Dataframe.
605
+ df_cols = data.columns
606
+ df_col_dtypes = [data[col].dtype for col in data.columns]
607
+ for df_col, df_col_dtype in zip(df_cols, df_col_dtypes):
608
+ if df_col_dtype == np.dtype("O"):
609
+ if isinstance(data[df_col][0], np.ndarray):
610
+ data[df_col] = data[df_col].map(np.ndarray.tolist)
542
611
  return data
543
612
 
544
613
 
@@ -569,7 +638,7 @@ class _NumpyArrayHandler(_BaseDataHandler[model_types._SupportedNumpyArray]):
569
638
  def infer_signature(
570
639
  data: model_types._SupportedNumpyArray, role: Literal["input", "output"]
571
640
  ) -> Sequence[BaseFeatureSpec]:
572
- feature_prefix = f"{_PandasDataFrameHandler.FEATURE_PREFIX}_"
641
+ feature_prefix = f"{_NumpyArrayHandler.FEATURE_PREFIX}_"
573
642
  dtype = DataType.from_numpy_type(data.dtype)
574
643
  role_prefix = (_NumpyArrayHandler.INPUT_PREFIX if role == "input" else _NumpyArrayHandler.OUTPUT_PREFIX) + "_"
575
644
  if len(data.shape) == 1:
@@ -588,68 +657,269 @@ class _NumpyArrayHandler(_BaseDataHandler[model_types._SupportedNumpyArray]):
588
657
  return features
589
658
 
590
659
  @staticmethod
591
- def convert_to_df(data: model_types._SupportedNumpyArray) -> pd.DataFrame:
660
+ def convert_to_df(data: model_types._SupportedNumpyArray, ensure_serializable: bool = True) -> pd.DataFrame:
592
661
  if len(data.shape) == 1:
593
662
  data = np.expand_dims(data, axis=1)
594
663
  n_cols = data.shape[1]
595
664
  if len(data.shape) == 2:
596
- return pd.DataFrame(data={i: data[:, i] for i in range(n_cols)})
665
+ return pd.DataFrame(data)
597
666
  else:
598
667
  n_rows = data.shape[0]
599
- return pd.DataFrame(data={i: [np.array(data[k, i]) for k in range(n_rows)] for i in range(n_cols)})
668
+ if ensure_serializable:
669
+ return pd.DataFrame(data={i: [data[k, i].tolist() for k in range(n_rows)] for i in range(n_cols)})
670
+ return pd.DataFrame(data={i: [list(data[k, i]) for k in range(n_rows)] for i in range(n_cols)})
600
671
 
601
672
 
602
- class _ListOfNumpyArrayHandler(_BaseDataHandler[List[model_types._SupportedNumpyArray]]):
673
+ class _SeqOfNumpyArrayHandler(_BaseDataHandler[Sequence[model_types._SupportedNumpyArray]]):
603
674
  @staticmethod
604
- def can_handle(data: model_types.SupportedDataType) -> TypeGuard[List[model_types._SupportedNumpyArray]]:
605
- return (
606
- isinstance(data, list)
607
- and len(data) > 0
608
- and all(_NumpyArrayHandler.can_handle(data_col) for data_col in data)
609
- )
675
+ def can_handle(data: model_types.SupportedDataType) -> TypeGuard[Sequence[model_types._SupportedNumpyArray]]:
676
+ if not isinstance(data, list):
677
+ return False
678
+ if len(data) == 0:
679
+ return False
680
+ if isinstance(data[0], np.ndarray):
681
+ return all(isinstance(data_col, np.ndarray) for data_col in data)
682
+ return False
610
683
 
611
684
  @staticmethod
612
- def count(data: List[model_types._SupportedNumpyArray]) -> int:
685
+ def count(data: Sequence[model_types._SupportedNumpyArray]) -> int:
613
686
  return min(_NumpyArrayHandler.count(data_col) for data_col in data)
614
687
 
615
688
  @staticmethod
616
- def truncate(data: List[model_types._SupportedNumpyArray]) -> List[model_types._SupportedNumpyArray]:
689
+ def truncate(data: Sequence[model_types._SupportedNumpyArray]) -> Sequence[model_types._SupportedNumpyArray]:
617
690
  return [
618
- data_col[: min(_ListOfNumpyArrayHandler.count(data), _ListOfNumpyArrayHandler.SIG_INFER_ROWS_COUNT_LIMIT)]
691
+ data_col[: min(_SeqOfNumpyArrayHandler.count(data), _SeqOfNumpyArrayHandler.SIG_INFER_ROWS_COUNT_LIMIT)]
619
692
  for data_col in data
620
693
  ]
621
694
 
622
695
  @staticmethod
623
- def validate(data: List[model_types._SupportedNumpyArray]) -> None:
696
+ def validate(data: Sequence[model_types._SupportedNumpyArray]) -> None:
624
697
  for data_col in data:
625
698
  _NumpyArrayHandler.validate(data_col)
626
699
 
627
700
  @staticmethod
628
701
  def infer_signature(
629
- data: List[model_types._SupportedNumpyArray], role: Literal["input", "output"]
702
+ data: Sequence[model_types._SupportedNumpyArray], role: Literal["input", "output"]
630
703
  ) -> Sequence[BaseFeatureSpec]:
704
+ feature_prefix = f"{_SeqOfNumpyArrayHandler.FEATURE_PREFIX}_"
631
705
  features: List[BaseFeatureSpec] = []
632
706
  role_prefix = (
633
- _ListOfNumpyArrayHandler.INPUT_PREFIX if role == "input" else _ListOfNumpyArrayHandler.OUTPUT_PREFIX
707
+ _SeqOfNumpyArrayHandler.INPUT_PREFIX if role == "input" else _SeqOfNumpyArrayHandler.OUTPUT_PREFIX
634
708
  ) + "_"
635
709
 
636
710
  for i, data_col in enumerate(data):
637
- inferred_res = _NumpyArrayHandler.infer_signature(data_col, role)
638
- for ft in inferred_res:
639
- ft._name = f"{role_prefix}{i}_{ft._name[len(role_prefix):]}"
640
- features.extend(inferred_res)
711
+ dtype = DataType.from_numpy_type(data_col.dtype)
712
+ ft_name = f"{role_prefix}{feature_prefix}{i}"
713
+ if len(data_col.shape) == 1:
714
+ features.append(FeatureSpec(dtype=dtype, name=ft_name))
715
+ else:
716
+ ft_shape = tuple(data_col.shape[1:])
717
+ features.append(FeatureSpec(dtype=dtype, name=ft_name, shape=ft_shape))
641
718
  return features
642
719
 
643
720
  @staticmethod
644
- def convert_to_df(data: List[model_types._SupportedNumpyArray]) -> pd.DataFrame:
645
- l_data = []
721
+ def convert_to_df(
722
+ data: Sequence[model_types._SupportedNumpyArray], ensure_serializable: bool = True
723
+ ) -> pd.DataFrame:
724
+ if ensure_serializable:
725
+ return pd.DataFrame(data={i: data_col.tolist() for i, data_col in enumerate(data)})
726
+ return pd.DataFrame(data={i: list(data_col) for i, data_col in enumerate(data)})
727
+
728
+
729
+ class _SeqOfPyTorchTensorHandler(_BaseDataHandler[Sequence["torch.Tensor"]]):
730
+ @staticmethod
731
+ def can_handle(data: model_types.SupportedDataType) -> TypeGuard[Sequence["torch.Tensor"]]:
732
+ if not isinstance(data, list):
733
+ return False
734
+ if len(data) == 0:
735
+ return False
736
+ if type_utils.LazyType("torch.Tensor").isinstance(data[0]):
737
+ return all(type_utils.LazyType("torch.Tensor").isinstance(data_col) for data_col in data)
738
+ return False
739
+
740
+ @staticmethod
741
+ def count(data: Sequence["torch.Tensor"]) -> int:
742
+ return min(data_col.shape[0] for data_col in data)
743
+
744
+ @staticmethod
745
+ def truncate(data: Sequence["torch.Tensor"]) -> Sequence["torch.Tensor"]:
746
+ return [
747
+ data_col[
748
+ : min(_SeqOfPyTorchTensorHandler.count(data), _SeqOfPyTorchTensorHandler.SIG_INFER_ROWS_COUNT_LIMIT)
749
+ ]
750
+ for data_col in data
751
+ ]
752
+
753
+ @staticmethod
754
+ def validate(data: Sequence["torch.Tensor"]) -> None:
755
+ import torch
756
+
757
+ for data_col in data:
758
+ if data_col.shape == torch.Size([0]):
759
+ # Empty array
760
+ raise ValueError("Data Validation Error: Empty data is found.")
761
+
762
+ if data_col.shape == torch.Size([1]):
763
+ # scalar
764
+ raise ValueError("Data Validation Error: Scalar data is found.")
765
+
766
+ @staticmethod
767
+ def infer_signature(data: Sequence["torch.Tensor"], role: Literal["input", "output"]) -> Sequence[BaseFeatureSpec]:
768
+ feature_prefix = f"{_SeqOfPyTorchTensorHandler.FEATURE_PREFIX}_"
769
+ features: List[BaseFeatureSpec] = []
770
+ role_prefix = (
771
+ _SeqOfPyTorchTensorHandler.INPUT_PREFIX if role == "input" else _SeqOfPyTorchTensorHandler.OUTPUT_PREFIX
772
+ ) + "_"
773
+
774
+ for i, data_col in enumerate(data):
775
+ dtype = DataType.from_torch_type(data_col.dtype)
776
+ ft_name = f"{role_prefix}{feature_prefix}{i}"
777
+ if len(data_col.shape) == 1:
778
+ features.append(FeatureSpec(dtype=dtype, name=ft_name))
779
+ else:
780
+ ft_shape = tuple(data_col.shape[1:])
781
+ features.append(FeatureSpec(dtype=dtype, name=ft_name, shape=ft_shape))
782
+ return features
783
+
784
+ @staticmethod
785
+ def convert_to_df(data: Sequence["torch.Tensor"], ensure_serializable: bool = True) -> pd.DataFrame:
786
+ # Use list(...) instead of .tolist() to ensure that
787
+ # the content is still numpy array so that the type could be preserved.
788
+ # But that would not be serializable and could not be used as UDF input or output.
789
+ if ensure_serializable:
790
+ return pd.DataFrame({i: data_col.detach().to("cpu").numpy().tolist() for i, data_col in enumerate(data)})
791
+ return pd.DataFrame({i: list(data_col.detach().to("cpu").numpy()) for i, data_col in enumerate(data)})
792
+
793
+ @staticmethod
794
+ def convert_from_df(
795
+ df: pd.DataFrame, features: Optional[Sequence[BaseFeatureSpec]] = None
796
+ ) -> Sequence["torch.Tensor"]:
797
+ import torch
798
+
799
+ res = []
800
+ if features:
801
+ for feature in features:
802
+ if isinstance(feature, FeatureGroupSpec):
803
+ raise NotImplementedError("FeatureGroupSpec is not supported.")
804
+ assert isinstance(feature, FeatureSpec), "Invalid feature kind."
805
+ res.append(torch.from_numpy(np.stack(df[feature.name].to_numpy()).astype(feature._dtype._numpy_type)))
806
+ return res
807
+ return [torch.from_numpy(np.stack(df[col].to_numpy())) for col in df]
808
+
809
+
810
+ class _SeqOfTensorflowTensorHandler(_BaseDataHandler[Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]]]):
811
+ @staticmethod
812
+ def can_handle(
813
+ data: model_types.SupportedDataType,
814
+ ) -> TypeGuard[Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]]]:
815
+ if not isinstance(data, list):
816
+ return False
817
+ if len(data) == 0:
818
+ return False
819
+ if type_utils.LazyType("tensorflow.Tensor").isinstance(data[0]) or type_utils.LazyType(
820
+ "tensorflow.Variable"
821
+ ).isinstance(data[0]):
822
+ return all(
823
+ type_utils.LazyType("tensorflow.Tensor").isinstance(data_col)
824
+ or type_utils.LazyType("tensorflow.Variable").isinstance(data_col)
825
+ for data_col in data
826
+ )
827
+ return False
828
+
829
+ @staticmethod
830
+ def count(data: Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]]) -> int:
831
+ import tensorflow as tf
832
+
833
+ rows = []
834
+ for data_col in data:
835
+ shapes = data_col.shape.as_list()
836
+ if data_col.shape == tf.TensorShape(None) or (not shapes) or (shapes[0] is None):
837
+ # Unknown shape array
838
+ raise ValueError("Data Validation Error: Unknown shape data is found.")
839
+ # Make mypy happy
840
+ assert isinstance(shapes[0], int)
841
+
842
+ rows.append(shapes[0])
843
+
844
+ return min(rows)
845
+
846
+ @staticmethod
847
+ def truncate(
848
+ data: Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]]
849
+ ) -> Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]]:
850
+ return [
851
+ data_col[
852
+ : min(
853
+ _SeqOfTensorflowTensorHandler.count(data), _SeqOfTensorflowTensorHandler.SIG_INFER_ROWS_COUNT_LIMIT
854
+ )
855
+ ]
856
+ for data_col in data
857
+ ]
858
+
859
+ @staticmethod
860
+ def validate(data: Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]]) -> None:
861
+ import tensorflow as tf
862
+
646
863
  for data_col in data:
864
+ if data_col.shape == tf.TensorShape(None) or any(dim is None for dim in data_col.shape.as_list()):
865
+ # Unknown shape array
866
+ raise ValueError("Data Validation Error: Unknown shape data is found.")
867
+
868
+ if data_col.shape == tf.TensorShape([0]):
869
+ # Empty array
870
+ raise ValueError("Data Validation Error: Empty data is found.")
871
+
872
+ if data_col.shape == tf.TensorShape([1]) or data_col.shape == tf.TensorShape([]):
873
+ # scalar
874
+ raise ValueError("Data Validation Error: Scalar data is found.")
875
+
876
+ @staticmethod
877
+ def infer_signature(
878
+ data: Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]], role: Literal["input", "output"]
879
+ ) -> Sequence[BaseFeatureSpec]:
880
+ feature_prefix = f"{_SeqOfTensorflowTensorHandler.FEATURE_PREFIX}_"
881
+ features: List[BaseFeatureSpec] = []
882
+ role_prefix = (
883
+ _SeqOfTensorflowTensorHandler.INPUT_PREFIX
884
+ if role == "input"
885
+ else _SeqOfTensorflowTensorHandler.OUTPUT_PREFIX
886
+ ) + "_"
887
+
888
+ for i, data_col in enumerate(data):
889
+ dtype = DataType.from_numpy_type(data_col.dtype.as_numpy_dtype)
890
+ ft_name = f"{role_prefix}{feature_prefix}{i}"
647
891
  if len(data_col.shape) == 1:
648
- l_data.append(np.expand_dims(data_col, axis=1))
892
+ features.append(FeatureSpec(dtype=dtype, name=ft_name))
649
893
  else:
650
- l_data.append(data_col)
651
- arr = np.concatenate(l_data, axis=1)
652
- return _NumpyArrayHandler.convert_to_df(arr)
894
+ ft_shape = tuple(data_col.shape[1:])
895
+ features.append(FeatureSpec(dtype=dtype, name=ft_name, shape=ft_shape))
896
+ return features
897
+
898
+ @staticmethod
899
+ def convert_to_df(
900
+ data: Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]], ensure_serializable: bool = True
901
+ ) -> pd.DataFrame:
902
+ if ensure_serializable:
903
+ return pd.DataFrame({i: data_col.numpy().tolist() for i, data_col in enumerate(iterable=data)})
904
+ return pd.DataFrame({i: list(data_col.numpy()) for i, data_col in enumerate(iterable=data)})
905
+
906
+ @staticmethod
907
+ def convert_from_df(
908
+ df: pd.DataFrame, features: Optional[Sequence[BaseFeatureSpec]] = None
909
+ ) -> Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]]:
910
+ import tensorflow as tf
911
+
912
+ res = []
913
+ if features:
914
+ for feature in features:
915
+ if isinstance(feature, FeatureGroupSpec):
916
+ raise NotImplementedError("FeatureGroupSpec is not supported.")
917
+ assert isinstance(feature, FeatureSpec), "Invalid feature kind."
918
+ res.append(
919
+ tf.convert_to_tensor(np.stack(df[feature.name].to_numpy()).astype(feature._dtype._numpy_type))
920
+ )
921
+ return res
922
+ return [tf.convert_to_tensor(np.stack(df[col].to_numpy())) for col in df]
653
923
 
654
924
 
655
925
  class _ListOfBuiltinHandler(_BaseDataHandler[model_types._SupportedBuiltinsList]):
@@ -684,7 +954,10 @@ class _ListOfBuiltinHandler(_BaseDataHandler[model_types._SupportedBuiltinsList]
684
954
  return _PandasDataFrameHandler.infer_signature(pd.DataFrame(data), role)
685
955
 
686
956
  @staticmethod
687
- def convert_to_df(data: model_types._SupportedBuiltinsList) -> pd.DataFrame:
957
def convert_to_df(
    data: model_types._SupportedBuiltinsList,
    ensure_serializable: bool = True,
) -> pd.DataFrame:
    """Wrap a list of builtin values in a pandas DataFrame.

    Args:
        data: The list handed directly to the ``pd.DataFrame`` constructor.
        ensure_serializable: Accepted but not used by this implementation.

    Returns:
        The DataFrame built from ``data``.
    """
    frame = pd.DataFrame(data)
    return frame
689
962
 
690
963
 
@@ -705,7 +978,12 @@ class _SnowparkDataFrameHandler(_BaseDataHandler[snowflake.snowpark.DataFrame]):
705
978
  def validate(data: snowflake.snowpark.DataFrame) -> None:
706
979
  schema = data.schema
707
980
  for field in schema.fields:
708
- if not any(type.is_same_snowpark_type(field.datatype) for type in DataType):
981
+ data_type = field.datatype
982
+ if isinstance(data_type, spt.ArrayType):
983
+ actual_data_type = data_type.element_type
984
+ else:
985
+ actual_data_type = data_type
986
+ if not any(type.is_same_snowpark_type(actual_data_type) for type in DataType):
709
987
  raise ValueError(
710
988
  f"Data Validation Error: Unsupported data type {field.datatype} in column {field.name}."
711
989
  )
@@ -718,19 +996,91 @@ class _SnowparkDataFrameHandler(_BaseDataHandler[snowflake.snowpark.DataFrame]):
718
996
  schema = data.schema
719
997
  for field in schema.fields:
720
998
  name = identifier.get_unescaped_names(field.name)
721
- features.append(FeatureSpec(name=name, dtype=DataType.from_snowpark_type(field.datatype)))
999
+ if isinstance(field.datatype, spt.ArrayType):
1000
+ raise NotImplementedError("Cannot infer model signature from Snowpark DataFrame with Array Type.")
1001
+ else:
1002
+ features.append(FeatureSpec(name=name, dtype=DataType.from_snowpark_type(field.datatype)))
722
1003
  return features
723
1004
 
724
1005
  @staticmethod
725
- def convert_to_df(data: snowflake.snowpark.DataFrame) -> snowflake.snowpark.DataFrame:
726
- return data
1006
def convert_to_df(
    data: snowflake.snowpark.DataFrame,
    ensure_serializable: bool = True,
    features: Optional[Sequence[BaseFeatureSpec]] = None,
) -> pd.DataFrame:
    """Materialize a Snowpark DataFrame locally and repair its column contents.

    On top of ``to_pandas()``, every ARRAY column is decoded with
    ``json.loads`` and, when feature specs are supplied, columns are cast to
    the dtype reported by each spec.

    Args:
        data: The Snowpark DataFrame to convert.
        ensure_serializable: Accepted but not used by this implementation.
        features: Optional feature specs used to build the dtype mapping for
            the final ``astype`` call. When omitted or empty, no casting
            happens.

    Returns:
        The local pandas DataFrame.

    Raises:
        NotImplementedError: If any spec is a FeatureGroupSpec.
    """
    dtype_map = {}
    if features:
        for feature in features:
            if isinstance(feature, FeatureGroupSpec):
                raise NotImplementedError("FeatureGroupSpec is not supported.")
            assert isinstance(feature, FeatureSpec), "Invalid feature kind."
            dtype_map[feature.name] = feature.as_dtype()

    def _load_if_not_null(cell: Optional[str]) -> Optional[Any]:
        # json.loads(None) raises TypeError, so NULL cells are passed through unchanged.
        if cell is None:
            return None
        return json.loads(cell)

    df_local = data.to_pandas()
    # Array columns come back as strings (even though the correct schema is set),
    # so an additional JSON load is required to recover the actual data.
    for field in data.schema.fields:
        if isinstance(field.datatype, spt.ArrayType):
            col_name = identifier.get_unescaped_names(field.name)
            df_local[col_name] = df_local[col_name].map(_load_if_not_null)
    # Casting only happens when feature specs were supplied; otherwise
    # dtype_map is empty and astype changes nothing.
    df_local = df_local.astype(dtype=dtype_map)
    return df_local
1032
+
1033
+ @staticmethod
1034
def convert_from_df(
    session: snowflake.snowpark.Session, df: pd.DataFrame, keep_order: bool = True
) -> snowflake.snowpark.DataFrame:
    """Create a Snowpark DataFrame from a pandas DataFrame with an explicit schema.

    The data is handed to Snowpark as a list of records together with a
    StructType built from the inferred signature. Per the original author's
    notes, Snowpark ignores the schema argument when given a pandas DataFrame
    directly, so array cells would otherwise be inferred as VARIANT. Those
    notes also state that records do not preserve row order, and that array
    element types are not preserved and come back as strings.

    Args:
        session: Snowpark session used to create the DataFrame.
        df: The pandas DataFrame to upload. Its column index must have object dtype.
        keep_order: When True, an extra non-nullable LongType column named
            ``infer_template._KEEP_ORDER_COL_NAME`` carrying the row position
            is added to every record and to the schema.

    Returns:
        The created Snowpark DataFrame.

    Raises:
        ValueError: If the column index dtype of the converted frame is not object.
        NotImplementedError: If the inferred signature contains a FeatureGroupSpec.
    """
    df = _PandasDataFrameHandler.convert_to_df(df)
    if df.columns.dtype != np.object_:
        raise ValueError("Cannot convert a Pandas DataFrame whose column index is not a string")

    # The role has no effect on the column index here: feature names are the actual column names.
    specs = _PandasDataFrameHandler.infer_signature(df, role="input")
    struct_fields = []
    for spec in specs:
        if isinstance(spec, FeatureGroupSpec):
            raise NotImplementedError("FeatureGroupSpec is not supported.")
        assert isinstance(spec, FeatureSpec), "Invalid feature kind."
        struct_fields.append(
            spt.StructField(
                identifier.get_inferred_name(spec.name),
                spec.as_snowpark_type(),
                nullable=df[spec.name].isnull().any(),
            )
        )

    records = df.rename(columns=identifier.get_inferred_name).to_dict("records")
    if keep_order:
        for row_number, record in enumerate(records):
            record[infer_template._KEEP_ORDER_COL_NAME] = row_number
        struct_fields.append(spt.StructField(infer_template._KEEP_ORDER_COL_NAME, spt.LongType(), nullable=False))

    # Passing records plus an explicit schema keeps array columns from becoming VARIANT.
    return session.create_dataframe(records, spt.StructType(struct_fields))
727
1075
 
728
1076
 
729
1077
  _LOCAL_DATA_HANDLERS: List[Type[_BaseDataHandler[Any]]] = [
730
1078
  _PandasDataFrameHandler,
731
1079
  _NumpyArrayHandler,
732
- _ListOfNumpyArrayHandler,
733
1080
  _ListOfBuiltinHandler,
1081
+ _SeqOfNumpyArrayHandler,
1082
+ _SeqOfPyTorchTensorHandler,
1083
+ _SeqOfTensorflowTensorHandler,
734
1084
  ]
735
1085
  _ALL_DATA_HANDLERS = _LOCAL_DATA_HANDLERS + [_SnowparkDataFrameHandler]
736
1086
 
@@ -1007,22 +1357,36 @@ def _validate_snowpark_data(data: snowflake.snowpark.DataFrame, features: Sequen
1007
1357
  raise NotImplementedError("FeatureGroupSpec is not supported.")
1008
1358
  assert isinstance(feature, FeatureSpec), "Invalid feature kind."
1009
1359
  ft_type = feature._dtype
1010
- if not ft_type.is_same_snowpark_type(field.datatype):
1011
- raise ValueError(
1012
- f"Data Validation Error in feature {ft_name}: "
1013
- + f"Feature type {ft_type} is not met by column {field.name}."
1360
+ field_data_type = field.datatype
1361
+ if isinstance(field_data_type, spt.ArrayType):
1362
+ if feature._shape is None:
1363
+ raise ValueError(
1364
+ f"Data Validation Error in feature {ft_name}: "
1365
+ + f"Feature is a array feature, while {field.name} is not."
1366
+ )
1367
+ warnings.warn(
1368
+ f"Warn in feature {ft_name}: Feature is a array feature," + " type validation cannot happen.",
1369
+ category=RuntimeWarning,
1014
1370
  )
1371
+ else:
1372
+ if feature._shape:
1373
+ raise ValueError(
1374
+ f"Data Validation Error in feature {ft_name}: "
1375
+ + f"Feature is a scalar feature, while {field.name} is not."
1376
+ )
1377
+ if not ft_type.is_same_snowpark_type(field_data_type):
1378
+ raise ValueError(
1379
+ f"Data Validation Error in feature {ft_name}: "
1380
+ + f"Feature type {ft_type} is not met by column {field.name}."
1381
+ )
1015
1382
  if not found:
1016
1383
  raise ValueError(f"Data Validation Error: feature {ft_name} does not exist in data.")
1017
1384
 
1018
1385
 
1019
- def _convert_and_validate_local_data(
1020
- data: model_types.SupportedDataType, features: Sequence[BaseFeatureSpec]
1021
- ) -> pd.DataFrame:
1022
- """Validate the data with features in model signature and convert to DataFrame
1386
+ def _convert_local_data_to_df(data: model_types.SupportedLocalDataType) -> pd.DataFrame:
1387
+ """Convert local data to pandas DataFrame or Snowpark DataFrame
1023
1388
 
1024
1389
  Args:
1025
- features: A list of feature specs that the data should follow.
1026
1390
  data: The provided data.
1027
1391
 
1028
1392
  Raises:
@@ -1035,13 +1399,29 @@ def _convert_and_validate_local_data(
1035
1399
  for handler in _LOCAL_DATA_HANDLERS:
1036
1400
  if handler.can_handle(data):
1037
1401
  handler.validate(data)
1038
- df = handler.convert_to_df(data)
1402
+ df = handler.convert_to_df(data, ensure_serializable=False)
1039
1403
  break
1040
1404
  if df is None:
1041
1405
  raise ValueError(f"Data Validation Error: Un-supported type {type(data)} provided.")
1042
- assert isinstance(df, pd.DataFrame)
1406
+ return df
1407
+
1408
+
1409
def _convert_and_validate_local_data(
    data: model_types.SupportedLocalDataType, features: Sequence[BaseFeatureSpec]
) -> pd.DataFrame:
    """Convert local data to a DataFrame and validate it against the model signature.

    The data is first converted without forcing serializable cells, then its
    columns are renamed according to the features, validated, and finally
    passed through the pandas handler with ``ensure_serializable=True``.

    Args:
        features: A list of feature specs that the data should follow.
        data: The provided data.

    Returns:
        The converted dataframe with renamed column index.
    """
    raw_df = _convert_local_data_to_df(data)
    renamed_df = _rename_pandas_df(raw_df, features)
    _validate_pandas_df(renamed_df, features)
    return _PandasDataFrameHandler.convert_to_df(renamed_df, ensure_serializable=True)
1047
1427