snowflake-ml-python 1.7.3__py3-none-any.whl → 1.7.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (208)
  1. snowflake/cortex/_complete.py +19 -0
  2. snowflake/ml/_internal/env_utils.py +64 -21
  3. snowflake/ml/_internal/platform_capabilities.py +87 -0
  4. snowflake/ml/_internal/relax_version_strategy.py +16 -0
  5. snowflake/ml/_internal/telemetry.py +21 -0
  6. snowflake/ml/data/_internal/arrow_ingestor.py +1 -1
  7. snowflake/ml/dataset/dataset.py +0 -1
  8. snowflake/ml/feature_store/feature_store.py +18 -0
  9. snowflake/ml/feature_store/feature_view.py +46 -1
  10. snowflake/ml/fileset/fileset.py +6 -0
  11. snowflake/ml/jobs/__init__.py +21 -0
  12. snowflake/ml/jobs/_utils/constants.py +57 -0
  13. snowflake/ml/jobs/_utils/payload_utils.py +438 -0
  14. snowflake/ml/jobs/_utils/spec_utils.py +296 -0
  15. snowflake/ml/jobs/_utils/types.py +39 -0
  16. snowflake/ml/jobs/decorators.py +71 -0
  17. snowflake/ml/jobs/job.py +113 -0
  18. snowflake/ml/jobs/manager.py +298 -0
  19. snowflake/ml/model/_client/ops/model_ops.py +11 -2
  20. snowflake/ml/model/_client/ops/service_ops.py +1 -11
  21. snowflake/ml/model/_client/sql/service.py +13 -6
  22. snowflake/ml/model/_packager/model_env/model_env.py +45 -28
  23. snowflake/ml/model/_packager/model_handlers/_utils.py +19 -6
  24. snowflake/ml/model/_packager/model_handlers/custom.py +1 -2
  25. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +17 -0
  26. snowflake/ml/model/_packager/model_handlers/keras.py +230 -0
  27. snowflake/ml/model/_packager/model_handlers/pytorch.py +1 -0
  28. snowflake/ml/model/_packager/model_handlers/sklearn.py +28 -3
  29. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +74 -21
  30. snowflake/ml/model/_packager/model_handlers/tensorflow.py +27 -49
  31. snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2023_12_01.py +48 -0
  32. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -1
  33. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +3 -0
  34. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
  35. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -1
  36. snowflake/ml/model/_packager/model_task/model_task_utils.py +5 -1
  37. snowflake/ml/model/_signatures/base_handler.py +1 -2
  38. snowflake/ml/model/_signatures/builtins_handler.py +2 -2
  39. snowflake/ml/model/_signatures/core.py +2 -2
  40. snowflake/ml/model/_signatures/numpy_handler.py +11 -12
  41. snowflake/ml/model/_signatures/pandas_handler.py +11 -9
  42. snowflake/ml/model/_signatures/pytorch_handler.py +3 -6
  43. snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
  44. snowflake/ml/model/_signatures/tensorflow_handler.py +2 -7
  45. snowflake/ml/model/model_signature.py +25 -4
  46. snowflake/ml/model/type_hints.py +15 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +14 -1
  48. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +6 -3
  49. snowflake/ml/modeling/cluster/affinity_propagation.py +6 -3
  50. snowflake/ml/modeling/cluster/agglomerative_clustering.py +6 -3
  51. snowflake/ml/modeling/cluster/birch.py +6 -3
  52. snowflake/ml/modeling/cluster/bisecting_k_means.py +6 -3
  53. snowflake/ml/modeling/cluster/dbscan.py +6 -3
  54. snowflake/ml/modeling/cluster/feature_agglomeration.py +6 -3
  55. snowflake/ml/modeling/cluster/k_means.py +6 -3
  56. snowflake/ml/modeling/cluster/mean_shift.py +6 -3
  57. snowflake/ml/modeling/cluster/mini_batch_k_means.py +6 -3
  58. snowflake/ml/modeling/cluster/optics.py +6 -3
  59. snowflake/ml/modeling/cluster/spectral_biclustering.py +6 -3
  60. snowflake/ml/modeling/cluster/spectral_clustering.py +6 -3
  61. snowflake/ml/modeling/cluster/spectral_coclustering.py +6 -3
  62. snowflake/ml/modeling/compose/column_transformer.py +6 -3
  63. snowflake/ml/modeling/compose/transformed_target_regressor.py +6 -3
  64. snowflake/ml/modeling/covariance/elliptic_envelope.py +6 -3
  65. snowflake/ml/modeling/covariance/empirical_covariance.py +6 -3
  66. snowflake/ml/modeling/covariance/graphical_lasso.py +6 -3
  67. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +6 -3
  68. snowflake/ml/modeling/covariance/ledoit_wolf.py +6 -3
  69. snowflake/ml/modeling/covariance/min_cov_det.py +6 -3
  70. snowflake/ml/modeling/covariance/oas.py +6 -3
  71. snowflake/ml/modeling/covariance/shrunk_covariance.py +6 -3
  72. snowflake/ml/modeling/decomposition/dictionary_learning.py +6 -3
  73. snowflake/ml/modeling/decomposition/factor_analysis.py +6 -3
  74. snowflake/ml/modeling/decomposition/fast_ica.py +6 -3
  75. snowflake/ml/modeling/decomposition/incremental_pca.py +6 -3
  76. snowflake/ml/modeling/decomposition/kernel_pca.py +6 -3
  77. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +6 -3
  78. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +6 -3
  79. snowflake/ml/modeling/decomposition/pca.py +6 -3
  80. snowflake/ml/modeling/decomposition/sparse_pca.py +6 -3
  81. snowflake/ml/modeling/decomposition/truncated_svd.py +6 -3
  82. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -3
  83. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +6 -3
  84. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +6 -3
  85. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +6 -3
  86. snowflake/ml/modeling/ensemble/bagging_classifier.py +6 -3
  87. snowflake/ml/modeling/ensemble/bagging_regressor.py +6 -3
  88. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +6 -3
  89. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +6 -3
  90. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +6 -3
  91. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +6 -3
  92. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +6 -3
  93. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +6 -3
  94. snowflake/ml/modeling/ensemble/isolation_forest.py +6 -3
  95. snowflake/ml/modeling/ensemble/random_forest_classifier.py +6 -3
  96. snowflake/ml/modeling/ensemble/random_forest_regressor.py +6 -3
  97. snowflake/ml/modeling/ensemble/stacking_regressor.py +6 -3
  98. snowflake/ml/modeling/ensemble/voting_classifier.py +6 -3
  99. snowflake/ml/modeling/ensemble/voting_regressor.py +6 -3
  100. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +6 -3
  101. snowflake/ml/modeling/feature_selection/select_fdr.py +6 -3
  102. snowflake/ml/modeling/feature_selection/select_fpr.py +6 -3
  103. snowflake/ml/modeling/feature_selection/select_fwe.py +6 -3
  104. snowflake/ml/modeling/feature_selection/select_k_best.py +6 -3
  105. snowflake/ml/modeling/feature_selection/select_percentile.py +6 -3
  106. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +6 -3
  107. snowflake/ml/modeling/feature_selection/variance_threshold.py +6 -3
  108. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +6 -3
  109. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +6 -3
  110. snowflake/ml/modeling/impute/iterative_imputer.py +6 -3
  111. snowflake/ml/modeling/impute/knn_imputer.py +6 -3
  112. snowflake/ml/modeling/impute/missing_indicator.py +6 -3
  113. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +6 -3
  114. snowflake/ml/modeling/kernel_approximation/nystroem.py +6 -3
  115. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +6 -3
  116. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +6 -3
  117. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +6 -3
  118. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +6 -3
  119. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +6 -3
  120. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +6 -3
  121. snowflake/ml/modeling/linear_model/ard_regression.py +6 -3
  122. snowflake/ml/modeling/linear_model/bayesian_ridge.py +6 -3
  123. snowflake/ml/modeling/linear_model/elastic_net.py +6 -3
  124. snowflake/ml/modeling/linear_model/elastic_net_cv.py +6 -3
  125. snowflake/ml/modeling/linear_model/gamma_regressor.py +6 -3
  126. snowflake/ml/modeling/linear_model/huber_regressor.py +6 -3
  127. snowflake/ml/modeling/linear_model/lars.py +6 -3
  128. snowflake/ml/modeling/linear_model/lars_cv.py +6 -3
  129. snowflake/ml/modeling/linear_model/lasso.py +6 -3
  130. snowflake/ml/modeling/linear_model/lasso_cv.py +6 -3
  131. snowflake/ml/modeling/linear_model/lasso_lars.py +6 -3
  132. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +6 -3
  133. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +6 -3
  134. snowflake/ml/modeling/linear_model/linear_regression.py +6 -3
  135. snowflake/ml/modeling/linear_model/logistic_regression.py +6 -3
  136. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +6 -3
  137. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +6 -3
  138. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +6 -3
  139. snowflake/ml/modeling/linear_model/multi_task_lasso.py +6 -3
  140. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +6 -3
  141. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +6 -3
  142. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +6 -3
  143. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +6 -3
  144. snowflake/ml/modeling/linear_model/perceptron.py +6 -3
  145. snowflake/ml/modeling/linear_model/poisson_regressor.py +6 -3
  146. snowflake/ml/modeling/linear_model/ransac_regressor.py +6 -3
  147. snowflake/ml/modeling/linear_model/ridge.py +6 -3
  148. snowflake/ml/modeling/linear_model/ridge_classifier.py +6 -3
  149. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +6 -3
  150. snowflake/ml/modeling/linear_model/ridge_cv.py +6 -3
  151. snowflake/ml/modeling/linear_model/sgd_classifier.py +6 -3
  152. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +6 -3
  153. snowflake/ml/modeling/linear_model/sgd_regressor.py +6 -3
  154. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +6 -3
  155. snowflake/ml/modeling/linear_model/tweedie_regressor.py +6 -3
  156. snowflake/ml/modeling/manifold/isomap.py +6 -3
  157. snowflake/ml/modeling/manifold/mds.py +6 -3
  158. snowflake/ml/modeling/manifold/spectral_embedding.py +6 -3
  159. snowflake/ml/modeling/manifold/tsne.py +6 -3
  160. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +6 -3
  161. snowflake/ml/modeling/mixture/gaussian_mixture.py +6 -3
  162. snowflake/ml/modeling/model_selection/grid_search_cv.py +17 -2
  163. snowflake/ml/modeling/model_selection/randomized_search_cv.py +17 -2
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +6 -3
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +6 -3
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +6 -3
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +6 -3
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +6 -3
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +6 -3
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +6 -3
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +6 -3
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +6 -3
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +6 -3
  174. snowflake/ml/modeling/neighbors/kernel_density.py +6 -3
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +6 -3
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +6 -3
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +6 -3
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +6 -3
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -3
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +6 -3
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +6 -3
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +6 -3
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +6 -3
  184. snowflake/ml/modeling/pipeline/pipeline.py +28 -3
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +8 -5
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +6 -3
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +6 -3
  188. snowflake/ml/modeling/svm/linear_svc.py +6 -3
  189. snowflake/ml/modeling/svm/linear_svr.py +6 -3
  190. snowflake/ml/modeling/svm/nu_svc.py +6 -3
  191. snowflake/ml/modeling/svm/nu_svr.py +6 -3
  192. snowflake/ml/modeling/svm/svc.py +6 -3
  193. snowflake/ml/modeling/svm/svr.py +6 -3
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +6 -3
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +6 -3
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +6 -3
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +6 -3
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +6 -3
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +6 -3
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +6 -3
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +6 -3
  202. snowflake/ml/registry/registry.py +34 -4
  203. snowflake/ml/version.py +1 -1
  204. {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/METADATA +81 -33
  205. {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/RECORD +208 -196
  206. {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/WHEEL +1 -1
  207. {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/LICENSE.txt +0 -0
  208. {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_packager/model_task/model_task_utils.py

```diff
@@ -24,7 +24,11 @@ def get_task_skl(model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pi
     from sklearn.base import is_classifier, is_regressor
 
     if type_utils.LazyType("sklearn.pipeline.Pipeline").isinstance(model):
-        return type_hints.Task.UNKNOWN
+        if hasattr(model, "predict_proba") or hasattr(model, "predict"):
+            model = model.steps[-1][1]  # type: ignore[attr-defined]
+            return _get_model_task(model)
+        else:
+            return type_hints.Task.UNKNOWN
     if is_regressor(model):
         return type_hints.Task.TABULAR_REGRESSION
     if is_classifier(model):
```
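
In plain terms, task inference no longer gives up on scikit-learn pipelines: it now recurses into the pipeline's final estimator. A minimal illustrative sketch (the pipeline below is an example; the exact Task value is inferred from the hunk above, not verified against 1.7.5):

```python
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# A pipeline whose final step is a classifier.
pipe = Pipeline([("scale", StandardScaler()), ("clf", LogisticRegression())])

# 1.7.3: any Pipeline mapped to Task.UNKNOWN.
# 1.7.5: inference recurses into pipe.steps[-1][1] (the LogisticRegression),
# so the task should resolve to a classification task instead.
```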
snowflake/ml/model/_signatures/base_handler.py

```diff
@@ -12,7 +12,6 @@ class BaseDataHandler(ABC, Generic[model_types._DataType]):
     FEATURE_PREFIX: Final[str] = "feature"
     INPUT_PREFIX: Final[str] = "input"
     OUTPUT_PREFIX: Final[str] = "output"
-    SIG_INFER_ROWS_COUNT_LIMIT: Final[int] = 10
 
     @staticmethod
     @abstractmethod
@@ -26,7 +25,7 @@ class BaseDataHandler(ABC, Generic[model_types._DataType]):
 
     @staticmethod
     @abstractmethod
-    def truncate(data: model_types._DataType) -> model_types._DataType:
+    def truncate(data: model_types._DataType, length: int) -> model_types._DataType:
         ...
 
     @staticmethod
```
snowflake/ml/model/_signatures/builtins_handler.py

```diff
@@ -35,8 +35,8 @@ class ListOfBuiltinHandler(base_handler.BaseDataHandler[model_types._SupportedBu
         return len(data)
 
     @staticmethod
-    def truncate(data: model_types._SupportedBuiltinsList) -> model_types._SupportedBuiltinsList:
-        return data[: min(ListOfBuiltinHandler.count(data), ListOfBuiltinHandler.SIG_INFER_ROWS_COUNT_LIMIT)]
+    def truncate(data: model_types._SupportedBuiltinsList, length: int) -> model_types._SupportedBuiltinsList:
+        return data[: min(ListOfBuiltinHandler.count(data), length)]
 
     @staticmethod
     def validate(data: model_types._SupportedBuiltinsList) -> None:
```
snowflake/ml/model/_signatures/core.py

```diff
@@ -282,7 +282,7 @@ class FeatureSpec(BaseFeatureSpec):
             result_type = spt.ArrayType(result_type)
         return result_type
 
-    def as_dtype(self) -> Union[npt.DTypeLike, str, PandasExtensionTypes]:
+    def as_dtype(self, force_numpy_dtype: bool = False) -> Union[npt.DTypeLike, str, PandasExtensionTypes]:
        """Convert to corresponding local Type."""
 
         if not self._shape:
@@ -291,7 +291,7 @@ class FeatureSpec(BaseFeatureSpec):
                 return self._dtype._value
 
             np_type = self._dtype._numpy_type
-            if self._nullable:
+            if self._nullable and not force_numpy_dtype:
                 np_to_pd_dtype_mapping = {
                     np.int8: pd.Int8Dtype(),
                     np.int16: pd.Int16Dtype(),
```
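
The new force_numpy_dtype flag lets callers bypass the pandas nullable extension dtypes for nullable features. A hedged sketch (FeatureSpec and DataType are exported from snowflake.ml.model.model_signature; the printed dtypes are expectations based on the mapping above):

```python
from snowflake.ml.model.model_signature import DataType, FeatureSpec

spec = FeatureSpec(name="age", dtype=DataType.INT64, nullable=True)

print(spec.as_dtype())                        # expected: pandas Int64Dtype() (nullable)
print(spec.as_dtype(force_numpy_dtype=True))  # expected: plain numpy int64
```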
snowflake/ml/model/_signatures/numpy_handler.py

```diff
@@ -23,8 +23,8 @@ class NumpyArrayHandler(base_handler.BaseDataHandler[model_types._SupportedNumpy
         return data.shape[0]
 
     @staticmethod
-    def truncate(data: model_types._SupportedNumpyArray) -> model_types._SupportedNumpyArray:
-        return data[: min(NumpyArrayHandler.count(data), NumpyArrayHandler.SIG_INFER_ROWS_COUNT_LIMIT)]
+    def truncate(data: model_types._SupportedNumpyArray, length: int) -> model_types._SupportedNumpyArray:
+        return data[: min(NumpyArrayHandler.count(data), length)]
 
     @staticmethod
     def validate(data: model_types._SupportedNumpyArray) -> None:
@@ -50,7 +50,7 @@ class NumpyArrayHandler(base_handler.BaseDataHandler[model_types._SupportedNumpy
         dtype = core.DataType.from_numpy_type(data.dtype)
         role_prefix = (NumpyArrayHandler.INPUT_PREFIX if role == "input" else NumpyArrayHandler.OUTPUT_PREFIX) + "_"
         if len(data.shape) == 1:
-            return [core.FeatureSpec(dtype=dtype, name=f"{role_prefix}{feature_prefix}0")]
+            return [core.FeatureSpec(dtype=dtype, name=f"{role_prefix}{feature_prefix}0", nullable=False)]
         else:
             # For high-dimension array, 0-axis is for batch, 1-axis is for column, further more is details of columns.
             features = []
@@ -59,9 +59,9 @@ class NumpyArrayHandler(base_handler.BaseDataHandler[model_types._SupportedNumpy
             for col_data, ft_name in zip(data[0], ft_names):
                 if isinstance(col_data, np.ndarray):
                     ft_shape = np.shape(col_data)
-                    features.append(core.FeatureSpec(dtype=dtype, name=ft_name, shape=ft_shape))
+                    features.append(core.FeatureSpec(dtype=dtype, name=ft_name, shape=ft_shape, nullable=False))
                 else:
-                    features.append(core.FeatureSpec(dtype=dtype, name=ft_name))
+                    features.append(core.FeatureSpec(dtype=dtype, name=ft_name, nullable=False))
             return features
 
     @staticmethod
@@ -94,11 +94,10 @@ class SeqOfNumpyArrayHandler(base_handler.BaseDataHandler[Sequence[model_types._
         return min(NumpyArrayHandler.count(data_col) for data_col in data)
 
     @staticmethod
-    def truncate(data: Sequence[model_types._SupportedNumpyArray]) -> Sequence[model_types._SupportedNumpyArray]:
-        return [
-            data_col[: min(SeqOfNumpyArrayHandler.count(data), SeqOfNumpyArrayHandler.SIG_INFER_ROWS_COUNT_LIMIT)]
-            for data_col in data
-        ]
+    def truncate(
+        data: Sequence[model_types._SupportedNumpyArray], length: int
+    ) -> Sequence[model_types._SupportedNumpyArray]:
+        return [data_col[: min(SeqOfNumpyArrayHandler.count(data), length)] for data_col in data]
 
     @staticmethod
     def validate(data: Sequence[model_types._SupportedNumpyArray]) -> None:
@@ -119,10 +118,10 @@ class SeqOfNumpyArrayHandler(base_handler.BaseDataHandler[Sequence[model_types._
             dtype = core.DataType.from_numpy_type(data_col.dtype)
             ft_name = f"{role_prefix}{feature_prefix}{i}"
             if len(data_col.shape) == 1:
-                features.append(core.FeatureSpec(dtype=dtype, name=ft_name))
+                features.append(core.FeatureSpec(dtype=dtype, name=ft_name, nullable=False))
             else:
                 ft_shape = tuple(data_col.shape[1:])
-                features.append(core.FeatureSpec(dtype=dtype, name=ft_name, shape=ft_shape))
+                features.append(core.FeatureSpec(dtype=dtype, name=ft_name, shape=ft_shape, nullable=False))
         return features
 
     @staticmethod
```
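
Taken together, these hunks mean every feature inferred from a NumPy array is now explicitly marked non-nullable, which is consistent with dense arrays having no representation for missing values. A hedged sketch:

```python
import numpy as np

from snowflake.ml.model import model_signature

sig = model_signature.infer_signature(
    input_data=np.array([[1, 2], [3, 4]]),
    output_data=np.array([0, 1]),
)
# Each FeatureSpec in sig.inputs / sig.outputs should now carry nullable=False.
print(sig)
```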
snowflake/ml/model/_signatures/pandas_handler.py

```diff
@@ -23,8 +23,8 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
         return len(data.index)
 
     @staticmethod
-    def truncate(data: pd.DataFrame) -> pd.DataFrame:
-        return data.head(min(PandasDataFrameHandler.count(data), PandasDataFrameHandler.SIG_INFER_ROWS_COUNT_LIMIT))
+    def truncate(data: pd.DataFrame, length: int) -> pd.DataFrame:
+        return data.head(min(PandasDataFrameHandler.count(data), length))
 
     @staticmethod
     def validate(data: Union[pd.DataFrame, pd.Series]) -> None:
@@ -72,13 +72,6 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
         df_col_dtypes = [data[col].dtype for col in data.columns]
         for df_col, df_col_dtype in zip(df_cols, df_col_dtypes):
             df_col_data = data[df_col]
-            if df_col_data.isnull().all():
-                raise snowml_exceptions.SnowflakeMLException(
-                    error_code=error_codes.INVALID_DATA,
-                    original_exception=ValueError(
-                        f"Data Validation Error: There is no non-null data in column {df_col}."
-                    ),
-                )
             if df_col_data.isnull().any():
                 warnings.warn(
                     (
@@ -163,6 +156,15 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
         specs = []
         for df_col, df_col_dtype, ft_name in zip(df_cols, df_col_dtypes, ft_names):
             df_col_data = data[df_col]
+
+            if df_col_data.isnull().all():
+                raise snowml_exceptions.SnowflakeMLException(
+                    error_code=error_codes.INVALID_DATA,
+                    original_exception=ValueError(
+                        "Data Validation Error: "
+                        f"There is no non-null data in column {df_col} so the signature cannot be inferred."
+                    ),
+                )
             if df_col_data.isnull().any():
                 df_col_data = utils.series_dropna(df_col_data)
                 df_col_dtype = df_col_data.dtype
```
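
Net effect: an all-null column no longer fails generic validation, but it still cannot contribute to signature inference, and the error message now says so. A hedged sketch of the error path (exception type and message per the hunk above):

```python
import pandas as pd

from snowflake.ml.model import model_signature

df = pd.DataFrame({"a": [1.0, 2.0, 3.0], "b": [None, None, None]})

try:
    model_signature.infer_signature(input_data=df, output_data=df[["a"]])
except Exception as exc:
    # Expected: SnowflakeMLException wrapping ValueError("Data Validation Error:
    # There is no non-null data in column b so the signature cannot be inferred.")
    print(type(exc).__name__, exc)
```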
snowflake/ml/model/_signatures/pytorch_handler.py

```diff
@@ -30,14 +30,11 @@ class SeqOfPyTorchTensorHandler(base_handler.BaseDataHandler[Sequence["torch.Ten
 
     @staticmethod
     def count(data: Sequence["torch.Tensor"]) -> int:
-        return min(data_col.shape[0] for data_col in data)  # type: ignore[no-any-return]
+        return min(data_col.shape[0] for data_col in data)
 
     @staticmethod
-    def truncate(data: Sequence["torch.Tensor"]) -> Sequence["torch.Tensor"]:
-        return [
-            data_col[: min(SeqOfPyTorchTensorHandler.count(data), SeqOfPyTorchTensorHandler.SIG_INFER_ROWS_COUNT_LIMIT)]
-            for data_col in data
-        ]
+    def truncate(data: Sequence["torch.Tensor"], length: int) -> Sequence["torch.Tensor"]:
+        return [data_col[: min(SeqOfPyTorchTensorHandler.count(data), length)] for data_col in data]
 
     @staticmethod
     def validate(data: Sequence["torch.Tensor"]) -> None:
```
snowflake/ml/model/_signatures/snowpark_handler.py

```diff
@@ -29,8 +29,8 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
         return data.count()
 
     @staticmethod
-    def truncate(data: snowflake.snowpark.DataFrame) -> snowflake.snowpark.DataFrame:
-        return cast(snowflake.snowpark.DataFrame, data.limit(SnowparkDataFrameHandler.SIG_INFER_ROWS_COUNT_LIMIT))
+    def truncate(data: snowflake.snowpark.DataFrame, length: int) -> snowflake.snowpark.DataFrame:
+        return cast(snowflake.snowpark.DataFrame, data.limit(length))
 
     @staticmethod
     def validate(data: snowflake.snowpark.DataFrame) -> None:
@@ -52,7 +52,7 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
         data: snowflake.snowpark.DataFrame, role: Literal["input", "output"]
     ) -> Sequence[core.BaseFeatureSpec]:
         return pandas_handler.PandasDataFrameHandler.infer_signature(
-            SnowparkDataFrameHandler.convert_to_df(data.limit(n=1)), role=role
+            SnowparkDataFrameHandler.convert_to_df(data), role=role
         )
 
     @staticmethod
```
snowflake/ml/model/_signatures/tensorflow_handler.py

```diff
@@ -60,14 +60,9 @@ class SeqOfTensorflowTensorHandler(
 
     @staticmethod
     def truncate(
-        data: Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]]
+        data: Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]], length: int
     ) -> Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]]:
-        return [
-            data_col[
-                : min(SeqOfTensorflowTensorHandler.count(data), SeqOfTensorflowTensorHandler.SIG_INFER_ROWS_COUNT_LIMIT)
-            ]
-            for data_col in data
-        ]
+        return [data_col[: min(SeqOfTensorflowTensorHandler.count(data), length)] for data_col in data]
 
     @staticmethod
     def validate(data: Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]]) -> None:
```
snowflake/ml/model/model_signature.py

```diff
@@ -21,6 +21,7 @@ from typing_extensions import Never
 import snowflake.snowpark
 import snowflake.snowpark.functions as F
 import snowflake.snowpark.types as spt
+from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.exceptions import (
     error_codes,
     exceptions as snowml_exceptions,
@@ -56,14 +57,22 @@ _LOCAL_DATA_HANDLERS: List[Type[base_handler.BaseDataHandler[Any]]] = [
 ]
 _ALL_DATA_HANDLERS = _LOCAL_DATA_HANDLERS + [snowpark_handler.SnowparkDataFrameHandler]
 
+_TELEMETRY_PROJECT = "MLOps"
+_MODEL_TELEMETRY_SUBPROJECT = "ModelSignature"
+
 
 def _truncate_data(
     data: model_types.SupportedDataType,
+    length: Optional[int] = 100,
 ) -> model_types.SupportedDataType:
     for handler in _ALL_DATA_HANDLERS:
         if handler.can_handle(data):
+            # If length is None, return the original data
+            if length is None:
+                return data
+
             row_count = handler.count(data)
-            if row_count <= handler.SIG_INFER_ROWS_COUNT_LIMIT:
+            if row_count <= length:
                 return data
 
             warnings.warn(
@@ -77,7 +86,7 @@ def _truncate_data(
                 category=UserWarning,
                 stacklevel=1,
             )
-            return handler.truncate(data)
+            return handler.truncate(data, length)
     raise snowml_exceptions.SnowflakeMLException(
         error_code=error_codes.NOT_IMPLEMENTED,
         original_exception=NotImplementedError(
@@ -682,11 +691,17 @@ def _convert_and_validate_local_data(
     return df
 
 
+@telemetry.send_api_usage_telemetry(
+    project=_TELEMETRY_PROJECT,
+    subproject=_MODEL_TELEMETRY_SUBPROJECT,
+)
 def infer_signature(
     input_data: model_types.SupportedLocalDataType,
     output_data: model_types.SupportedLocalDataType,
     input_feature_names: Optional[List[str]] = None,
     output_feature_names: Optional[List[str]] = None,
+    input_data_limit: Optional[int] = 100,
+    output_data_limit: Optional[int] = 100,
 ) -> core.ModelSignature:
     """
     Infer model signature from given input and output sample data.
@@ -710,12 +725,18 @@ def infer_signature(
         output_data: Sample output data for the model.
         input_feature_names: Names for input features. Defaults to None.
         output_feature_names: Names for output features. Defaults to None.
+        input_data_limit: Limit the number of rows to be used in signature inference in the input data. Defaults to 100.
+            If None, all rows are used. If the number of rows in the input data is less than the limit, all rows are
+            used.
+        output_data_limit: Limit the number of rows to be used in signature inference in the output data. Defaults to
+            100. If None, all rows are used. If the number of rows in the output data is less than the limit, all rows
+            are used.
 
     Returns:
         A model signature inferred from the given input and output sample data.
     """
-    inputs = _infer_signature(input_data, role="input")
+    inputs = _infer_signature(_truncate_data(input_data, input_data_limit), role="input")
     inputs = utils.rename_features(inputs, input_feature_names)
-    outputs = _infer_signature(output_data, role="output")
+    outputs = _infer_signature(_truncate_data(output_data, output_data_limit), role="output")
     outputs = utils.rename_features(outputs, output_feature_names)
     return core.ModelSignature(inputs, outputs)
```
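
A usage sketch of the new row limits (parameter names and None semantics per the docstring above):

```python
import pandas as pd

from snowflake.ml.model import model_signature

input_df = pd.DataFrame({"x1": range(10_000), "x2": range(10_000)})
output_df = pd.DataFrame({"y": range(10_000)})

sig = model_signature.infer_signature(
    input_data=input_df,
    output_data=output_df,
    input_data_limit=100,    # only the first 100 rows drive input inference
    output_data_limit=None,  # None disables truncation on the output side
)
print(sig)
```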
snowflake/ml/model/type_hints.py

```diff
@@ -7,6 +7,7 @@ from typing_extensions import NotRequired
 
 if TYPE_CHECKING:
     import catboost
+    import keras
     import lightgbm
     import mlflow
     import numpy as np
@@ -68,6 +69,7 @@ SupportedRequireSignatureModelType = Union[
     "torch.nn.Module",
     "torch.jit.ScriptModule",
     "tensorflow.Module",
+    "keras.Model",
 ]
 
 SupportedNoSignatureRequirementsModelType = Union[
@@ -103,6 +105,7 @@ Here is all acceptable types of Snowflake native model packaging and its handler
 | transformers.Pipeline | huggingface_pipeline.py | _HuggingFacePipelineHandler |
 | huggingface_pipeline.HuggingFacePipelineModel | huggingface_pipeline.py | _HuggingFacePipelineHandler |
 | sentence_transformers.SentenceTransformer | sentence_transformers.py | _SentenceTransformerHandler |
+| keras.Model | keras.py | _KerasHandler |
 """
 
 SupportedModelHandlerType = Literal[
@@ -118,6 +121,7 @@ SupportedModelHandlerType = Literal[
     "tensorflow",
     "torchscript",
     "xgboost",
+    "keras",
 ]
 
 _ModelType = TypeVar("_ModelType", bound=SupportedModelType)
@@ -202,6 +206,11 @@ class SentenceTransformersSaveOptions(BaseModelSaveOption):
     batch_size: NotRequired[int]
 
 
+class KerasSaveOptions(BaseModelSaveOption):
+    target_methods: NotRequired[Sequence[str]]
+    cuda_version: NotRequired[str]
+
+
 ModelSaveOption = Union[
     BaseModelSaveOption,
     CatBoostModelSaveOptions,
@@ -216,6 +225,7 @@ ModelSaveOption = Union[
     MLFlowSaveOptions,
     HuggingFaceSaveOptions,
     SentenceTransformersSaveOptions,
+    KerasSaveOptions,
 ]
 
 
@@ -276,6 +286,10 @@ class SentenceTransformersLoadOptions(BaseModelLoadOption):
     device: NotRequired[str]
 
 
+class KerasLoadOptions(BaseModelLoadOption):
+    use_gpu: NotRequired[bool]
+
+
 ModelLoadOption = Union[
     BaseModelLoadOption,
     CatBoostModelLoadOptions,
@@ -290,6 +304,7 @@ ModelLoadOption = Union[
     MLFlowLoadOptions,
     HuggingFaceLoadOptions,
     SentenceTransformersLoadOptions,
+    KerasLoadOptions,
 ]
 
 
```
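
These type-hint additions back the new Keras handler (keras.py, file 30 in the list above). A hedged sketch of logging a Keras model through the registry; the Registry call shape is the documented snowflake-ml-python interface, but the exact plumbing of the KerasSaveOptions keys in 1.7.5 is an assumption:

```python
import keras

from snowflake.ml.registry import Registry

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

reg = Registry(session=session)  # assumes an existing Snowpark `session`
mv = reg.log_model(
    model,
    model_name="my_keras_model",              # hypothetical name
    sample_input_data=sample_df,              # hypothetical sample frame
    options={"target_methods": ["predict"]},  # keys from KerasSaveOptions above
)
```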
snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py

```diff
@@ -199,8 +199,21 @@ class SnowparkTransformHandlers:
         if expected_output_cols_type == "":
             expected_output_cols_type = "string"
         assert expected_output_cols_type is not None
+
+        # If there is only one output column, the UDF might have generate complex objects (lists, dicts).
+        # In such cases, we attempt to not do explicit cast. (Example: PolynomialFeatures.transform)
+        try_parse_object = len(expected_output_cols) == 1 and expected_output_cols_type != "string"
         for output_feature in expected_output_cols:
-            output_cols.append(F.col(INTERMEDIATE_OBJ_NAME)[output_feature].astype(expected_output_cols_type))
+            column_expr = F.col(INTERMEDIATE_OBJ_NAME)[output_feature]
+
+            if try_parse_object and df_res.count() > 0:
+                # Only do type casting if it's not an array
+                if not df_res.select(F.is_array(column_expr)).first()[0]:
+                    column_expr = column_expr.astype(expected_output_cols_type)
+            else:
+                column_expr = column_expr.astype(expected_output_cols_type)
+
+            output_cols.append(column_expr)
             output_col_names.append(identifier.get_inferred_name(output_feature))
 
         # Extract output from INTERMEDIATE_OBJ_NAME and drop that column
```
snowflake/ml/modeling/calibration/calibrated_classifier_cv.py

```diff
@@ -37,6 +37,7 @@ from snowflake.ml.model.model_signature import (
     FeatureSpec,
     ModelSignature,
     _infer_signature,
+    _truncate_data,
     _rename_signature_with_snowflake_identifiers,
 )
 
@@ -57,6 +58,8 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.calibration".replace("sk
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
+INFER_SIGNATURE_MAX_ROWS = 100
+
 class CalibratedClassifierCV(BaseTransformer):
     r"""Probability calibration with isotonic regression or logistic regression
     For more details on this class, see [sklearn.calibration.CalibratedClassifierCV]
@@ -465,7 +468,7 @@ class CalibratedClassifierCV(BaseTransformer):
         elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
             expected_dtype = "array"
         else:
-            output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+            output_types = [signature.as_snowpark_type() for signature in _infer_signature(_truncate_data(dataset[self.input_cols], INFER_SIGNATURE_MAX_ROWS), "output", use_snowflake_identifiers=True)]
             # We can only infer the output types from the input types if the following two statemetns are true:
             # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
             # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
@@ -1126,7 +1129,7 @@ class CalibratedClassifierCV(BaseTransformer):
 
         PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
 
-        inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
+        inputs = list(_infer_signature(_truncate_data(dataset[self.input_cols], INFER_SIGNATURE_MAX_ROWS), "input", use_snowflake_identifiers=True))
         outputs: List[BaseFeatureSpec] = []
         if hasattr(self, "predict"):
             # keep mypy happy
@@ -1134,7 +1137,7 @@ class CalibratedClassifierCV(BaseTransformer):
             # For classifier, the type of predict is the same as the type of label
             if self._sklearn_object._estimator_type == "classifier":
                 # label columns is the desired type for output
-                outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
+                outputs = list(_infer_signature(_truncate_data(dataset[self.label_cols], INFER_SIGNATURE_MAX_ROWS), "output", use_snowflake_identifiers=True))
                 # rename the output columns
                 outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
                 self._model_signature_dict["predict"] = ModelSignature(
```
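
The same three-part pattern (import _truncate_data, define INFER_SIGNATURE_MAX_ROWS = 100, wrap each _infer_signature call) is stamped across the roughly 150 modeling wrappers changed in this release; the cluster estimators below carry identical hunks. A hedged sketch of the user-visible effect, using one of the patched estimators (column names are placeholders):

```python
import pandas as pd

from snowflake.ml.modeling.cluster import KMeans

df = pd.DataFrame({"X1": range(1_000_000), "X2": range(1_000_000)})

km = KMeans(input_cols=["X1", "X2"], output_cols=["CLUSTER"], n_clusters=3)
# Fitting still uses all rows; only model signature inference is now capped
# at INFER_SIGNATURE_MAX_ROWS (100) rows instead of scanning the whole dataset.
km.fit(df)
```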
snowflake/ml/modeling/cluster/affinity_propagation.py

```diff
@@ -37,6 +37,7 @@ from snowflake.ml.model.model_signature import (
     FeatureSpec,
     ModelSignature,
     _infer_signature,
+    _truncate_data,
     _rename_signature_with_snowflake_identifiers,
 )
 
@@ -57,6 +58,8 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
+INFER_SIGNATURE_MAX_ROWS = 100
+
 class AffinityPropagation(BaseTransformer):
     r"""Perform Affinity Propagation Clustering of data
     For more details on this class, see [sklearn.cluster.AffinityPropagation]
@@ -449,7 +452,7 @@ class AffinityPropagation(BaseTransformer):
         elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
             expected_dtype = "array"
         else:
-            output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+            output_types = [signature.as_snowpark_type() for signature in _infer_signature(_truncate_data(dataset[self.input_cols], INFER_SIGNATURE_MAX_ROWS), "output", use_snowflake_identifiers=True)]
             # We can only infer the output types from the input types if the following two statemetns are true:
             # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
             # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
@@ -1106,7 +1109,7 @@ class AffinityPropagation(BaseTransformer):
 
         PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
 
-        inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
+        inputs = list(_infer_signature(_truncate_data(dataset[self.input_cols], INFER_SIGNATURE_MAX_ROWS), "input", use_snowflake_identifiers=True))
         outputs: List[BaseFeatureSpec] = []
         if hasattr(self, "predict"):
             # keep mypy happy
@@ -1114,7 +1117,7 @@ class AffinityPropagation(BaseTransformer):
             # For classifier, the type of predict is the same as the type of label
             if self._sklearn_object._estimator_type == "classifier":
                 # label columns is the desired type for output
-                outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
+                outputs = list(_infer_signature(_truncate_data(dataset[self.label_cols], INFER_SIGNATURE_MAX_ROWS), "output", use_snowflake_identifiers=True))
                 # rename the output columns
                 outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
                 self._model_signature_dict["predict"] = ModelSignature(
```
snowflake/ml/modeling/cluster/agglomerative_clustering.py

```diff
@@ -37,6 +37,7 @@ from snowflake.ml.model.model_signature import (
     FeatureSpec,
     ModelSignature,
     _infer_signature,
+    _truncate_data,
     _rename_signature_with_snowflake_identifiers,
 )
 
@@ -57,6 +58,8 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
+INFER_SIGNATURE_MAX_ROWS = 100
+
 class AgglomerativeClustering(BaseTransformer):
     r"""Agglomerative Clustering
     For more details on this class, see [sklearn.cluster.AgglomerativeClustering]
@@ -478,7 +481,7 @@ class AgglomerativeClustering(BaseTransformer):
         elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
             expected_dtype = "array"
         else:
-            output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+            output_types = [signature.as_snowpark_type() for signature in _infer_signature(_truncate_data(dataset[self.input_cols], INFER_SIGNATURE_MAX_ROWS), "output", use_snowflake_identifiers=True)]
             # We can only infer the output types from the input types if the following two statemetns are true:
             # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
             # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
@@ -1135,7 +1138,7 @@ class AgglomerativeClustering(BaseTransformer):
 
         PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
 
-        inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
+        inputs = list(_infer_signature(_truncate_data(dataset[self.input_cols], INFER_SIGNATURE_MAX_ROWS), "input", use_snowflake_identifiers=True))
         outputs: List[BaseFeatureSpec] = []
         if hasattr(self, "predict"):
             # keep mypy happy
@@ -1143,7 +1146,7 @@ class AgglomerativeClustering(BaseTransformer):
             # For classifier, the type of predict is the same as the type of label
             if self._sklearn_object._estimator_type == "classifier":
                 # label columns is the desired type for output
-                outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
+                outputs = list(_infer_signature(_truncate_data(dataset[self.label_cols], INFER_SIGNATURE_MAX_ROWS), "output", use_snowflake_identifiers=True))
                 # rename the output columns
                 outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
                 self._model_signature_dict["predict"] = ModelSignature(
```
snowflake/ml/modeling/cluster/birch.py

```diff
@@ -37,6 +37,7 @@ from snowflake.ml.model.model_signature import (
     FeatureSpec,
     ModelSignature,
     _infer_signature,
+    _truncate_data,
     _rename_signature_with_snowflake_identifiers,
 )
 
@@ -57,6 +58,8 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
+INFER_SIGNATURE_MAX_ROWS = 100
+
 class Birch(BaseTransformer):
     r"""Implements the BIRCH clustering algorithm
     For more details on this class, see [sklearn.cluster.Birch]
@@ -442,7 +445,7 @@ class Birch(BaseTransformer):
         elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
             expected_dtype = "array"
         else:
-            output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+            output_types = [signature.as_snowpark_type() for signature in _infer_signature(_truncate_data(dataset[self.input_cols], INFER_SIGNATURE_MAX_ROWS), "output", use_snowflake_identifiers=True)]
             # We can only infer the output types from the input types if the following two statemetns are true:
             # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
             # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
@@ -1101,7 +1104,7 @@ class Birch(BaseTransformer):
 
         PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
 
-        inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
+        inputs = list(_infer_signature(_truncate_data(dataset[self.input_cols], INFER_SIGNATURE_MAX_ROWS), "input", use_snowflake_identifiers=True))
         outputs: List[BaseFeatureSpec] = []
         if hasattr(self, "predict"):
             # keep mypy happy
@@ -1109,7 +1112,7 @@ class Birch(BaseTransformer):
             # For classifier, the type of predict is the same as the type of label
             if self._sklearn_object._estimator_type == "classifier":
                 # label columns is the desired type for output
-                outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
+                outputs = list(_infer_signature(_truncate_data(dataset[self.label_cols], INFER_SIGNATURE_MAX_ROWS), "output", use_snowflake_identifiers=True))
                 # rename the output columns
                 outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
                 self._model_signature_dict["predict"] = ModelSignature(
```
snowflake/ml/modeling/cluster/bisecting_k_means.py

```diff
@@ -37,6 +37,7 @@ from snowflake.ml.model.model_signature import (
     FeatureSpec,
     ModelSignature,
     _infer_signature,
+    _truncate_data,
     _rename_signature_with_snowflake_identifiers,
 )
 
@@ -57,6 +58,8 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
+INFER_SIGNATURE_MAX_ROWS = 100
+
 class BisectingKMeans(BaseTransformer):
     r"""Bisecting K-Means clustering
     For more details on this class, see [sklearn.cluster.BisectingKMeans]
@@ -491,7 +494,7 @@ class BisectingKMeans(BaseTransformer):
         elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
             expected_dtype = "array"
         else:
-            output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+            output_types = [signature.as_snowpark_type() for signature in _infer_signature(_truncate_data(dataset[self.input_cols], INFER_SIGNATURE_MAX_ROWS), "output", use_snowflake_identifiers=True)]
             # We can only infer the output types from the input types if the following two statemetns are true:
             # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
             # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
@@ -1152,7 +1155,7 @@ class BisectingKMeans(BaseTransformer):
 
         PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
 
-        inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
+        inputs = list(_infer_signature(_truncate_data(dataset[self.input_cols], INFER_SIGNATURE_MAX_ROWS), "input", use_snowflake_identifiers=True))
         outputs: List[BaseFeatureSpec] = []
         if hasattr(self, "predict"):
             # keep mypy happy
@@ -1160,7 +1163,7 @@ class BisectingKMeans(BaseTransformer):
             # For classifier, the type of predict is the same as the type of label
             if self._sklearn_object._estimator_type == "classifier":
                 # label columns is the desired type for output
-                outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
+                outputs = list(_infer_signature(_truncate_data(dataset[self.label_cols], INFER_SIGNATURE_MAX_ROWS), "output", use_snowflake_identifiers=True))
                 # rename the output columns
                 outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
                 self._model_signature_dict["predict"] = ModelSignature(
```