snowflake-ml-python 1.2.0__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (194)
  1. snowflake/ml/_internal/env_utils.py +16 -13
  2. snowflake/ml/_internal/exceptions/modeling_error_messages.py +5 -1
  3. snowflake/ml/_internal/telemetry.py +19 -0
  4. snowflake/ml/feature_store/__init__.py +9 -0
  5. snowflake/ml/feature_store/entity.py +73 -0
  6. snowflake/ml/feature_store/feature_store.py +1657 -0
  7. snowflake/ml/feature_store/feature_view.py +459 -0
  8. snowflake/ml/model/_client/ops/model_ops.py +16 -38
  9. snowflake/ml/model/_client/sql/model.py +1 -7
  10. snowflake/ml/model/_client/sql/model_version.py +20 -15
  11. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +9 -1
  12. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  13. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +12 -2
  14. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +7 -3
  15. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +1 -6
  16. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +0 -2
  17. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  18. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -2
  19. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  20. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  21. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  22. snowflake/ml/model/model_signature.py +72 -16
  23. snowflake/ml/model/type_hints.py +12 -0
  24. snowflake/ml/modeling/_internal/estimator_protocols.py +1 -41
  25. snowflake/ml/modeling/_internal/model_trainer_builder.py +13 -9
  26. snowflake/ml/modeling/_internal/{distributed_hpo_trainer.py → snowpark_implementations/distributed_hpo_trainer.py} +66 -96
  27. snowflake/ml/modeling/_internal/{snowpark_handlers.py → snowpark_implementations/snowpark_handlers.py} +9 -6
  28. snowflake/ml/modeling/_internal/{xgboost_external_memory_trainer.py → snowpark_implementations/xgboost_external_memory_trainer.py} +3 -1
  29. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +19 -3
  30. snowflake/ml/modeling/cluster/affinity_propagation.py +19 -3
  31. snowflake/ml/modeling/cluster/agglomerative_clustering.py +19 -3
  32. snowflake/ml/modeling/cluster/birch.py +19 -3
  33. snowflake/ml/modeling/cluster/bisecting_k_means.py +19 -3
  34. snowflake/ml/modeling/cluster/dbscan.py +19 -3
  35. snowflake/ml/modeling/cluster/feature_agglomeration.py +19 -3
  36. snowflake/ml/modeling/cluster/k_means.py +19 -3
  37. snowflake/ml/modeling/cluster/mean_shift.py +19 -3
  38. snowflake/ml/modeling/cluster/mini_batch_k_means.py +19 -3
  39. snowflake/ml/modeling/cluster/optics.py +19 -3
  40. snowflake/ml/modeling/cluster/spectral_biclustering.py +19 -3
  41. snowflake/ml/modeling/cluster/spectral_clustering.py +19 -3
  42. snowflake/ml/modeling/cluster/spectral_coclustering.py +19 -3
  43. snowflake/ml/modeling/compose/column_transformer.py +19 -3
  44. snowflake/ml/modeling/compose/transformed_target_regressor.py +19 -3
  45. snowflake/ml/modeling/covariance/elliptic_envelope.py +19 -3
  46. snowflake/ml/modeling/covariance/empirical_covariance.py +19 -3
  47. snowflake/ml/modeling/covariance/graphical_lasso.py +19 -3
  48. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +19 -3
  49. snowflake/ml/modeling/covariance/ledoit_wolf.py +19 -3
  50. snowflake/ml/modeling/covariance/min_cov_det.py +19 -3
  51. snowflake/ml/modeling/covariance/oas.py +19 -3
  52. snowflake/ml/modeling/covariance/shrunk_covariance.py +19 -3
  53. snowflake/ml/modeling/decomposition/dictionary_learning.py +19 -3
  54. snowflake/ml/modeling/decomposition/factor_analysis.py +19 -3
  55. snowflake/ml/modeling/decomposition/fast_ica.py +19 -3
  56. snowflake/ml/modeling/decomposition/incremental_pca.py +19 -3
  57. snowflake/ml/modeling/decomposition/kernel_pca.py +19 -3
  58. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +19 -3
  59. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +19 -3
  60. snowflake/ml/modeling/decomposition/pca.py +19 -3
  61. snowflake/ml/modeling/decomposition/sparse_pca.py +19 -3
  62. snowflake/ml/modeling/decomposition/truncated_svd.py +19 -3
  63. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +19 -3
  64. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +19 -3
  65. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +19 -3
  66. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +19 -3
  67. snowflake/ml/modeling/ensemble/bagging_classifier.py +19 -3
  68. snowflake/ml/modeling/ensemble/bagging_regressor.py +19 -3
  69. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +19 -3
  70. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +19 -3
  71. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +19 -3
  72. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +19 -3
  73. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +19 -3
  74. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +19 -3
  75. snowflake/ml/modeling/ensemble/isolation_forest.py +19 -3
  76. snowflake/ml/modeling/ensemble/random_forest_classifier.py +19 -3
  77. snowflake/ml/modeling/ensemble/random_forest_regressor.py +19 -3
  78. snowflake/ml/modeling/ensemble/stacking_regressor.py +19 -3
  79. snowflake/ml/modeling/ensemble/voting_classifier.py +19 -3
  80. snowflake/ml/modeling/ensemble/voting_regressor.py +19 -3
  81. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +19 -3
  82. snowflake/ml/modeling/feature_selection/select_fdr.py +19 -3
  83. snowflake/ml/modeling/feature_selection/select_fpr.py +19 -3
  84. snowflake/ml/modeling/feature_selection/select_fwe.py +19 -3
  85. snowflake/ml/modeling/feature_selection/select_k_best.py +19 -3
  86. snowflake/ml/modeling/feature_selection/select_percentile.py +19 -3
  87. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +19 -3
  88. snowflake/ml/modeling/feature_selection/variance_threshold.py +19 -3
  89. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +19 -3
  90. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +19 -3
  91. snowflake/ml/modeling/impute/iterative_imputer.py +19 -3
  92. snowflake/ml/modeling/impute/knn_imputer.py +19 -3
  93. snowflake/ml/modeling/impute/missing_indicator.py +19 -3
  94. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +19 -3
  95. snowflake/ml/modeling/kernel_approximation/nystroem.py +19 -3
  96. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +19 -3
  97. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +19 -3
  98. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +19 -3
  99. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +19 -3
  100. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +19 -3
  101. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +19 -3
  102. snowflake/ml/modeling/linear_model/ard_regression.py +19 -3
  103. snowflake/ml/modeling/linear_model/bayesian_ridge.py +19 -3
  104. snowflake/ml/modeling/linear_model/elastic_net.py +19 -3
  105. snowflake/ml/modeling/linear_model/elastic_net_cv.py +19 -3
  106. snowflake/ml/modeling/linear_model/gamma_regressor.py +19 -3
  107. snowflake/ml/modeling/linear_model/huber_regressor.py +19 -3
  108. snowflake/ml/modeling/linear_model/lars.py +19 -3
  109. snowflake/ml/modeling/linear_model/lars_cv.py +19 -3
  110. snowflake/ml/modeling/linear_model/lasso.py +19 -3
  111. snowflake/ml/modeling/linear_model/lasso_cv.py +19 -3
  112. snowflake/ml/modeling/linear_model/lasso_lars.py +19 -3
  113. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +19 -3
  114. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +19 -3
  115. snowflake/ml/modeling/linear_model/linear_regression.py +19 -3
  116. snowflake/ml/modeling/linear_model/logistic_regression.py +19 -3
  117. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +19 -3
  118. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +19 -3
  119. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +19 -3
  120. snowflake/ml/modeling/linear_model/multi_task_lasso.py +19 -3
  121. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +19 -3
  122. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +19 -3
  123. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +19 -3
  124. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +19 -3
  125. snowflake/ml/modeling/linear_model/perceptron.py +19 -3
  126. snowflake/ml/modeling/linear_model/poisson_regressor.py +19 -3
  127. snowflake/ml/modeling/linear_model/ransac_regressor.py +19 -3
  128. snowflake/ml/modeling/linear_model/ridge.py +19 -3
  129. snowflake/ml/modeling/linear_model/ridge_classifier.py +19 -3
  130. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +19 -3
  131. snowflake/ml/modeling/linear_model/ridge_cv.py +19 -3
  132. snowflake/ml/modeling/linear_model/sgd_classifier.py +19 -3
  133. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +19 -3
  134. snowflake/ml/modeling/linear_model/sgd_regressor.py +19 -3
  135. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +19 -3
  136. snowflake/ml/modeling/linear_model/tweedie_regressor.py +19 -3
  137. snowflake/ml/modeling/manifold/isomap.py +19 -3
  138. snowflake/ml/modeling/manifold/mds.py +19 -3
  139. snowflake/ml/modeling/manifold/spectral_embedding.py +19 -3
  140. snowflake/ml/modeling/manifold/tsne.py +19 -3
  141. snowflake/ml/modeling/metrics/classification.py +5 -6
  142. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  143. snowflake/ml/modeling/metrics/ranking.py +7 -3
  144. snowflake/ml/modeling/metrics/regression.py +6 -3
  145. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +19 -3
  146. snowflake/ml/modeling/mixture/gaussian_mixture.py +19 -3
  147. snowflake/ml/modeling/model_selection/grid_search_cv.py +3 -13
  148. snowflake/ml/modeling/model_selection/randomized_search_cv.py +3 -13
  149. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +19 -3
  150. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +19 -3
  151. snowflake/ml/modeling/multiclass/output_code_classifier.py +19 -3
  152. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +19 -3
  153. snowflake/ml/modeling/naive_bayes/categorical_nb.py +19 -3
  154. snowflake/ml/modeling/naive_bayes/complement_nb.py +19 -3
  155. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +19 -3
  156. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +19 -3
  157. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +19 -3
  158. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +19 -3
  159. snowflake/ml/modeling/neighbors/kernel_density.py +19 -3
  160. snowflake/ml/modeling/neighbors/local_outlier_factor.py +19 -3
  161. snowflake/ml/modeling/neighbors/nearest_centroid.py +19 -3
  162. snowflake/ml/modeling/neighbors/nearest_neighbors.py +19 -3
  163. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +19 -3
  164. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +19 -3
  165. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +19 -3
  166. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +19 -3
  167. snowflake/ml/modeling/neural_network/mlp_classifier.py +19 -3
  168. snowflake/ml/modeling/neural_network/mlp_regressor.py +19 -3
  169. snowflake/ml/modeling/preprocessing/polynomial_features.py +19 -3
  170. snowflake/ml/modeling/semi_supervised/label_propagation.py +19 -3
  171. snowflake/ml/modeling/semi_supervised/label_spreading.py +19 -3
  172. snowflake/ml/modeling/svm/linear_svc.py +19 -3
  173. snowflake/ml/modeling/svm/linear_svr.py +19 -3
  174. snowflake/ml/modeling/svm/nu_svc.py +19 -3
  175. snowflake/ml/modeling/svm/nu_svr.py +19 -3
  176. snowflake/ml/modeling/svm/svc.py +19 -3
  177. snowflake/ml/modeling/svm/svr.py +19 -3
  178. snowflake/ml/modeling/tree/decision_tree_classifier.py +19 -3
  179. snowflake/ml/modeling/tree/decision_tree_regressor.py +19 -3
  180. snowflake/ml/modeling/tree/extra_tree_classifier.py +19 -3
  181. snowflake/ml/modeling/tree/extra_tree_regressor.py +19 -3
  182. snowflake/ml/modeling/xgboost/xgb_classifier.py +19 -3
  183. snowflake/ml/modeling/xgboost/xgb_regressor.py +19 -3
  184. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +19 -3
  185. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +19 -3
  186. snowflake/ml/registry/registry.py +2 -0
  187. snowflake/ml/version.py +1 -1
  188. snowflake_ml_python-1.2.2.dist-info/LICENSE.txt +202 -0
  189. {snowflake_ml_python-1.2.0.dist-info → snowflake_ml_python-1.2.2.dist-info}/METADATA +276 -50
  190. {snowflake_ml_python-1.2.0.dist-info → snowflake_ml_python-1.2.2.dist-info}/RECORD +204 -197
  191. {snowflake_ml_python-1.2.0.dist-info → snowflake_ml_python-1.2.2.dist-info}/WHEEL +2 -1
  192. snowflake_ml_python-1.2.2.dist-info/top_level.txt +1 -0
  193. /snowflake/ml/modeling/_internal/{pandas_trainer.py → local_implementations/pandas_trainer.py} +0 -0
  194. /snowflake/ml/modeling/_internal/{snowpark_trainer.py → snowpark_implementations/snowpark_trainer.py} +0 -0
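
Items 26-28 and 193-194 above are pure module relocations: the trainer and handler implementations move under new local_implementations/ and snowpark_implementations/ subpackages. The matching import rewrites recur throughout the hunks below; both paths in this before/after pair are taken verbatim from those hunks:

    # before (1.2.0)
    from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl
    # after (1.2.2)
    from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import SnowparkHandlers as HandlersImpl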
snowflake/ml/modeling/metrics/regression.py
@@ -108,7 +108,8 @@ def d2_absolute_error_score(
         result_module = cloudpickle.loads(pickled_snowflake_result)
         return result_module.serialize(session, score)  # type: ignore[no-any-return]

-    result_object = result.deserialize(session, d2_absolute_error_score_anon_sproc(session))
+    kwargs = telemetry.get_sproc_statement_params_kwargs(d2_absolute_error_score_anon_sproc, statement_params)
+    result_object = result.deserialize(session, d2_absolute_error_score_anon_sproc(session, **kwargs))
     score: Union[float, npt.NDArray[np.float_]] = result_object
     return score

@@ -205,7 +206,8 @@ def d2_pinball_score(
         result_module = cloudpickle.loads(pickled_result_module)
         return result_module.serialize(session, score)  # type: ignore[no-any-return]

-    result_object = result.deserialize(session, d2_pinball_score_anon_sproc(session))
+    kwargs = telemetry.get_sproc_statement_params_kwargs(d2_pinball_score_anon_sproc, statement_params)
+    result_object = result.deserialize(session, d2_pinball_score_anon_sproc(session, **kwargs))

     score: Union[float, npt.NDArray[np.float_]] = result_object
     return score
@@ -319,7 +321,8 @@ def explained_variance_score(
         result_module = cloudpickle.loads(pickled_result_module)
         return result_module.serialize(session, score)  # type: ignore[no-any-return]

-    result_object = result.deserialize(session, explained_variance_score_anon_sproc(session))
+    kwargs = telemetry.get_sproc_statement_params_kwargs(explained_variance_score_anon_sproc, statement_params)
+    result_object = result.deserialize(session, explained_variance_score_anon_sproc(session, **kwargs))
     score: Union[float, npt.NDArray[np.float_]] = result_object
     return score

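The three regression-metric hunks above share one pattern: each anonymous stored procedure is now invoked with statement_params threaded through, so the queries it issues carry telemetry attribution. The helper itself is the 19-line addition to snowflake/ml/_internal/telemetry.py (item 3 above). A minimal sketch of what it plausibly does, assuming it simply forwards statement_params when the sproc handler accepts that keyword and the sproc is introspectable like a plain callable; this is a reconstruction from the call sites, not the actual source:

    import inspect
    from typing import Any, Callable, Dict

    def get_sproc_statement_params_kwargs(
        sproc: Callable[..., Any], statement_params: Dict[str, Any]
    ) -> Dict[str, Any]:
        # Forward statement_params only if the handler declares that argument;
        # otherwise return no extra kwargs so older sprocs keep working.
        if "statement_params" in inspect.signature(sproc).parameters:
            return {"statement_params": statement_params}
        return {}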
snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py
@@ -26,7 +26,7 @@ from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
 from snowflake.ml._internal.utils import pkg_version_utils, identifier
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
-from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl
+from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import SnowparkHandlers as HandlersImpl
 from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
 from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
 from snowflake.ml.modeling._internal.estimator_utils import (
@@ -35,7 +35,7 @@ from snowflake.ml.modeling._internal.estimator_utils import (
     transform_snowml_obj_to_sklearn_obj,
     validate_sklearn_args,
 )
-from snowflake.ml.modeling._internal.estimator_protocols import FitPredictHandlers
+from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers

 from snowflake.ml.model.model_signature import (
     DataType,
@@ -287,7 +287,7 @@ class BayesianGaussianMixture(BaseTransformer):
         self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
         # If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols
         self._snowpark_cols: Optional[List[str]] = self.input_cols
-        self._handlers: FitPredictHandlers = HandlersImpl(class_name=BayesianGaussianMixture.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
+        self._handlers: TransformerHandlers = HandlersImpl(class_name=BayesianGaussianMixture.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
         self._autogenerated = True

     def _get_rand_id(self) -> str:
@@ -647,6 +647,22 @@ class BayesianGaussianMixture(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
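
This new fallback block recurs in every autogenerated estimator below; the classes differ only in name and line offsets. Distilled out of its method context, the decision it encodes looks roughly like this. The function name and flat signature are illustrative, not part of the package, and the extra empty-list guard is mine:

    from typing import Sequence

    def infer_expected_dtype(sklearn_obj: object, output_cols: Sequence[str], input_types: Sequence[str]) -> str:
        # Clustering transform: output columns != n_clusters means each row
        # packs one value per cluster, so the column must be an ARRAY.
        if hasattr(sklearn_obj, "n_clusters") and sklearn_obj.n_clusters != len(output_cols):
            return "ARRAY"
        # Decomposition transform: same reasoning with n_components.
        if hasattr(sklearn_obj, "n_components") and sklearn_obj.n_components != len(output_cols):
            return "ARRAY"
        # Scalar passthrough: homogeneous input types, one per output column.
        if input_types and all(t == input_types[0] for t in input_types) and len(input_types) == len(output_cols):
            return input_types[0]
        return ""  # still unknown; the caller keeps its variant fallback

For example, a PCA with n_components=2 writing two output columns from two DOUBLE input columns yields DOUBLE per column, while a KMeans with n_clusters=8 and a single output column falls back to ARRAY.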
snowflake/ml/modeling/mixture/gaussian_mixture.py
@@ -26,7 +26,7 @@ from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
 from snowflake.ml._internal.utils import pkg_version_utils, identifier
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
-from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl
+from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import SnowparkHandlers as HandlersImpl
 from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
 from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
 from snowflake.ml.modeling._internal.estimator_utils import (
@@ -35,7 +35,7 @@ from snowflake.ml.modeling._internal.estimator_utils import (
     transform_snowml_obj_to_sklearn_obj,
     validate_sklearn_args,
 )
-from snowflake.ml.modeling._internal.estimator_protocols import FitPredictHandlers
+from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers

 from snowflake.ml.model.model_signature import (
     DataType,
@@ -260,7 +260,7 @@ class GaussianMixture(BaseTransformer):
         self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
         # If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols
         self._snowpark_cols: Optional[List[str]] = self.input_cols
-        self._handlers: FitPredictHandlers = HandlersImpl(class_name=GaussianMixture.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
+        self._handlers: TransformerHandlers = HandlersImpl(class_name=GaussianMixture.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
         self._autogenerated = True

     def _get_rand_id(self) -> str:
@@ -620,6 +620,22 @@ class GaussianMixture(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
snowflake/ml/modeling/model_selection/grid_search_cv.py
@@ -3,7 +3,6 @@
 # Do not modify the auto-generated code(except automatic reformatting by precommit hooks).
 #
 from typing import Any, Dict, Iterable, List, Optional, Set, Union
-from uuid import uuid4

 import cloudpickle as cp
 import numpy as np
@@ -22,7 +21,7 @@ from snowflake.ml.model.model_signature import (
     ModelSignature,
     _infer_signature,
 )
-from snowflake.ml.modeling._internal.estimator_protocols import CVHandlers
+from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers
 from snowflake.ml.modeling._internal.estimator_utils import (
     gather_dependencies,
     original_estimator_has_callable,
@@ -30,7 +29,7 @@ from snowflake.ml.modeling._internal.estimator_utils import (
     validate_sklearn_args,
 )
 from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
-from snowflake.ml.modeling._internal.snowpark_handlers import (
+from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import (
     SnowparkHandlers as HandlersImpl,
 )
 from snowflake.ml.modeling.framework.base import BaseTransformer
@@ -266,20 +265,11 @@ class GridSearchCV(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)
         self.set_passthrough_cols(passthrough_cols)
-        self._handlers: CVHandlers = HandlersImpl(
+        self._handlers: TransformerHandlers = HandlersImpl(
             class_name=self.__class__.__name__,
             subproject=_SUBPROJECT,
         )

-    def _get_rand_id(self) -> str:
-        """
-        Generate random id to be used in sproc and stage names.
-
-        Returns:
-            Random id string usable in sproc, table, and stage names.
-        """
-        return str(uuid4()).replace("-", "_").upper()
-
     def _get_active_columns(self) -> List[str]:
         """ "Get the list of columns that are relevant to the transformer."""
         selected_cols = (
snowflake/ml/modeling/model_selection/randomized_search_cv.py
@@ -1,5 +1,4 @@
 from typing import Any, Dict, Iterable, List, Optional, Set, Union
-from uuid import uuid4

 import cloudpickle as cp
 import numpy as np
@@ -19,7 +18,7 @@ from snowflake.ml.model.model_signature import (
     ModelSignature,
     _infer_signature,
 )
-from snowflake.ml.modeling._internal.estimator_protocols import CVHandlers
+from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers
 from snowflake.ml.modeling._internal.estimator_utils import (
     gather_dependencies,
     original_estimator_has_callable,
@@ -27,7 +26,7 @@ from snowflake.ml.modeling._internal.estimator_utils import (
     validate_sklearn_args,
 )
 from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
-from snowflake.ml.modeling._internal.snowpark_handlers import (
+from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import (
     SnowparkHandlers as HandlersImpl,
 )
 from snowflake.ml.modeling.framework.base import BaseTransformer
@@ -278,20 +277,11 @@ class RandomizedSearchCV(BaseTransformer):
         self.set_drop_input_cols(drop_input_cols)
         self.set_sample_weight_col(sample_weight_col)
         self.set_passthrough_cols(passthrough_cols)
-        self._handlers: CVHandlers = HandlersImpl(
+        self._handlers: TransformerHandlers = HandlersImpl(
             class_name=self.__class__.__name__,
             subproject=_SUBPROJECT,
         )

-    def _get_rand_id(self) -> str:
-        """
-        Generate random id to be used in sproc and stage names.
-
-        Returns:
-            Random id string usable in sproc, table, and stage names.
-        """
-        return str(uuid4()).replace("-", "_").upper()
-
     def _get_active_columns(self) -> List[str]:
         """ "Get the list of columns that are relevant to the transformer."""
         selected_cols = (
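
Both search-CV classes drop the now-unused _get_rand_id helper along with its uuid4 import, presumably because stage and sproc naming moved into the trainer implementations. For reference, the deleted helper's behavior, reproduced verbatim from the removed lines:

    from uuid import uuid4

    # A UUID made safe for use in sproc, table, and stage names:
    # hyphens become underscores, result is upper-cased.
    def _get_rand_id() -> str:
        return str(uuid4()).replace("-", "_").upper()

    print(_get_rand_id())  # e.g. 'D0A1F6E2_3B4C_4D5E_8F90_1A2B3C4D5E6F'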
snowflake/ml/modeling/multiclass/one_vs_one_classifier.py
@@ -26,7 +26,7 @@ from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
 from snowflake.ml._internal.utils import pkg_version_utils, identifier
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
-from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl
+from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import SnowparkHandlers as HandlersImpl
 from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
 from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
 from snowflake.ml.modeling._internal.estimator_utils import (
@@ -35,7 +35,7 @@ from snowflake.ml.modeling._internal.estimator_utils import (
     transform_snowml_obj_to_sklearn_obj,
     validate_sklearn_args,
 )
-from snowflake.ml.modeling._internal.estimator_protocols import FitPredictHandlers
+from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers

 from snowflake.ml.model.model_signature import (
     DataType,
@@ -172,7 +172,7 @@ class OneVsOneClassifier(BaseTransformer):
         self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
         # If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols
         self._snowpark_cols: Optional[List[str]] = self.input_cols
-        self._handlers: FitPredictHandlers = HandlersImpl(class_name=OneVsOneClassifier.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
+        self._handlers: TransformerHandlers = HandlersImpl(class_name=OneVsOneClassifier.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
         self._autogenerated = True

     def _get_rand_id(self) -> str:
@@ -532,6 +532,22 @@ class OneVsOneClassifier(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py
@@ -26,7 +26,7 @@ from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
 from snowflake.ml._internal.utils import pkg_version_utils, identifier
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
-from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl
+from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import SnowparkHandlers as HandlersImpl
 from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
 from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
 from snowflake.ml.modeling._internal.estimator_utils import (
@@ -35,7 +35,7 @@ from snowflake.ml.modeling._internal.estimator_utils import (
     transform_snowml_obj_to_sklearn_obj,
     validate_sklearn_args,
 )
-from snowflake.ml.modeling._internal.estimator_protocols import FitPredictHandlers
+from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers

 from snowflake.ml.model.model_signature import (
     DataType,
@@ -181,7 +181,7 @@ class OneVsRestClassifier(BaseTransformer):
         self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
         # If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols
         self._snowpark_cols: Optional[List[str]] = self.input_cols
-        self._handlers: FitPredictHandlers = HandlersImpl(class_name=OneVsRestClassifier.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
+        self._handlers: TransformerHandlers = HandlersImpl(class_name=OneVsRestClassifier.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
         self._autogenerated = True

     def _get_rand_id(self) -> str:
@@ -541,6 +541,22 @@ class OneVsRestClassifier(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
snowflake/ml/modeling/multiclass/output_code_classifier.py
@@ -26,7 +26,7 @@ from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
 from snowflake.ml._internal.utils import pkg_version_utils, identifier
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
-from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl
+from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import SnowparkHandlers as HandlersImpl
 from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
 from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
 from snowflake.ml.modeling._internal.estimator_utils import (
@@ -35,7 +35,7 @@ from snowflake.ml.modeling._internal.estimator_utils import (
     transform_snowml_obj_to_sklearn_obj,
     validate_sklearn_args,
 )
-from snowflake.ml.modeling._internal.estimator_protocols import FitPredictHandlers
+from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers

 from snowflake.ml.model.model_signature import (
     DataType,
@@ -184,7 +184,7 @@ class OutputCodeClassifier(BaseTransformer):
         self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
         # If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols
         self._snowpark_cols: Optional[List[str]] = self.input_cols
-        self._handlers: FitPredictHandlers = HandlersImpl(class_name=OutputCodeClassifier.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
+        self._handlers: TransformerHandlers = HandlersImpl(class_name=OutputCodeClassifier.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
         self._autogenerated = True

     def _get_rand_id(self) -> str:
@@ -544,6 +544,22 @@ class OutputCodeClassifier(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
snowflake/ml/modeling/naive_bayes/bernoulli_nb.py
@@ -26,7 +26,7 @@ from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
 from snowflake.ml._internal.utils import pkg_version_utils, identifier
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
-from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl
+from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import SnowparkHandlers as HandlersImpl
 from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
 from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
 from snowflake.ml.modeling._internal.estimator_utils import (
@@ -35,7 +35,7 @@ from snowflake.ml.modeling._internal.estimator_utils import (
     transform_snowml_obj_to_sklearn_obj,
     validate_sklearn_args,
 )
-from snowflake.ml.modeling._internal.estimator_protocols import FitPredictHandlers
+from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers

 from snowflake.ml.model.model_signature import (
     DataType,
@@ -184,7 +184,7 @@ class BernoulliNB(BaseTransformer):
         self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
         # If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols
         self._snowpark_cols: Optional[List[str]] = self.input_cols
-        self._handlers: FitPredictHandlers = HandlersImpl(class_name=BernoulliNB.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
+        self._handlers: TransformerHandlers = HandlersImpl(class_name=BernoulliNB.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
         self._autogenerated = True

     def _get_rand_id(self) -> str:
@@ -544,6 +544,22 @@ class BernoulliNB(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
snowflake/ml/modeling/naive_bayes/categorical_nb.py
@@ -26,7 +26,7 @@ from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
 from snowflake.ml._internal.utils import pkg_version_utils, identifier
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
-from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl
+from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import SnowparkHandlers as HandlersImpl
 from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
 from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
 from snowflake.ml.modeling._internal.estimator_utils import (
@@ -35,7 +35,7 @@ from snowflake.ml.modeling._internal.estimator_utils import (
     transform_snowml_obj_to_sklearn_obj,
     validate_sklearn_args,
 )
-from snowflake.ml.modeling._internal.estimator_protocols import FitPredictHandlers
+from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers

 from snowflake.ml.model.model_signature import (
     DataType,
@@ -190,7 +190,7 @@ class CategoricalNB(BaseTransformer):
         self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
         # If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols
         self._snowpark_cols: Optional[List[str]] = self.input_cols
-        self._handlers: FitPredictHandlers = HandlersImpl(class_name=CategoricalNB.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
+        self._handlers: TransformerHandlers = HandlersImpl(class_name=CategoricalNB.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
         self._autogenerated = True

     def _get_rand_id(self) -> str:
@@ -550,6 +550,22 @@ class CategoricalNB(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
snowflake/ml/modeling/naive_bayes/complement_nb.py
@@ -26,7 +26,7 @@ from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
 from snowflake.ml._internal.utils import pkg_version_utils, identifier
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
-from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl
+from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import SnowparkHandlers as HandlersImpl
 from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
 from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
 from snowflake.ml.modeling._internal.estimator_utils import (
@@ -35,7 +35,7 @@ from snowflake.ml.modeling._internal.estimator_utils import (
     transform_snowml_obj_to_sklearn_obj,
     validate_sklearn_args,
 )
-from snowflake.ml.modeling._internal.estimator_protocols import FitPredictHandlers
+from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers

 from snowflake.ml.model.model_signature import (
     DataType,
@@ -184,7 +184,7 @@ class ComplementNB(BaseTransformer):
         self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
         # If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols
         self._snowpark_cols: Optional[List[str]] = self.input_cols
-        self._handlers: FitPredictHandlers = HandlersImpl(class_name=ComplementNB.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
+        self._handlers: TransformerHandlers = HandlersImpl(class_name=ComplementNB.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
         self._autogenerated = True

     def _get_rand_id(self) -> str:
@@ -544,6 +544,22 @@ class ComplementNB(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
snowflake/ml/modeling/naive_bayes/gaussian_nb.py
@@ -26,7 +26,7 @@ from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
 from snowflake.ml._internal.utils import pkg_version_utils, identifier
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
-from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl
+from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import SnowparkHandlers as HandlersImpl
 from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
 from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
 from snowflake.ml.modeling._internal.estimator_utils import (
@@ -35,7 +35,7 @@ from snowflake.ml.modeling._internal.estimator_utils import (
     transform_snowml_obj_to_sklearn_obj,
     validate_sklearn_args,
 )
-from snowflake.ml.modeling._internal.estimator_protocols import FitPredictHandlers
+from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers

 from snowflake.ml.model.model_signature import (
     DataType,
@@ -165,7 +165,7 @@ class GaussianNB(BaseTransformer):
         self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
         # If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols
         self._snowpark_cols: Optional[List[str]] = self.input_cols
-        self._handlers: FitPredictHandlers = HandlersImpl(class_name=GaussianNB.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
+        self._handlers: TransformerHandlers = HandlersImpl(class_name=GaussianNB.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
         self._autogenerated = True

     def _get_rand_id(self) -> str:
@@ -525,6 +525,22 @@ class GaussianNB(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
snowflake/ml/modeling/naive_bayes/multinomial_nb.py
@@ -26,7 +26,7 @@ from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
 from snowflake.ml._internal.utils import pkg_version_utils, identifier
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
-from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl
+from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import SnowparkHandlers as HandlersImpl
 from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
 from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
 from snowflake.ml.modeling._internal.estimator_utils import (
@@ -35,7 +35,7 @@ from snowflake.ml.modeling._internal.estimator_utils import (
     transform_snowml_obj_to_sklearn_obj,
     validate_sklearn_args,
 )
-from snowflake.ml.modeling._internal.estimator_protocols import FitPredictHandlers
+from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers

 from snowflake.ml.model.model_signature import (
     DataType,
@@ -178,7 +178,7 @@ class MultinomialNB(BaseTransformer):
         self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
         # If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols
         self._snowpark_cols: Optional[List[str]] = self.input_cols
-        self._handlers: FitPredictHandlers = HandlersImpl(class_name=MultinomialNB.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
+        self._handlers: TransformerHandlers = HandlersImpl(class_name=MultinomialNB.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
         self._autogenerated = True

     def _get_rand_id(self) -> str:
@@ -538,6 +538,22 @@ class MultinomialNB(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",