snowflake-ml-python 1.2.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only.
Files changed (176)
  1. snowflake/ml/_internal/telemetry.py +19 -0
  2. snowflake/ml/model/_client/ops/model_ops.py +16 -38
  3. snowflake/ml/model/_client/sql/model.py +1 -7
  4. snowflake/ml/model/_client/sql/model_version.py +20 -15
  5. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +1 -6
  6. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +0 -2
  7. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  8. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -2
  9. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  10. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  11. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  12. snowflake/ml/model/type_hints.py +3 -0
  13. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +63 -95
  14. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  15. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +16 -0
  16. snowflake/ml/modeling/cluster/affinity_propagation.py +16 -0
  17. snowflake/ml/modeling/cluster/agglomerative_clustering.py +16 -0
  18. snowflake/ml/modeling/cluster/birch.py +16 -0
  19. snowflake/ml/modeling/cluster/bisecting_k_means.py +16 -0
  20. snowflake/ml/modeling/cluster/dbscan.py +16 -0
  21. snowflake/ml/modeling/cluster/feature_agglomeration.py +16 -0
  22. snowflake/ml/modeling/cluster/k_means.py +16 -0
  23. snowflake/ml/modeling/cluster/mean_shift.py +16 -0
  24. snowflake/ml/modeling/cluster/mini_batch_k_means.py +16 -0
  25. snowflake/ml/modeling/cluster/optics.py +16 -0
  26. snowflake/ml/modeling/cluster/spectral_biclustering.py +16 -0
  27. snowflake/ml/modeling/cluster/spectral_clustering.py +16 -0
  28. snowflake/ml/modeling/cluster/spectral_coclustering.py +16 -0
  29. snowflake/ml/modeling/compose/column_transformer.py +16 -0
  30. snowflake/ml/modeling/compose/transformed_target_regressor.py +16 -0
  31. snowflake/ml/modeling/covariance/elliptic_envelope.py +16 -0
  32. snowflake/ml/modeling/covariance/empirical_covariance.py +16 -0
  33. snowflake/ml/modeling/covariance/graphical_lasso.py +16 -0
  34. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +16 -0
  35. snowflake/ml/modeling/covariance/ledoit_wolf.py +16 -0
  36. snowflake/ml/modeling/covariance/min_cov_det.py +16 -0
  37. snowflake/ml/modeling/covariance/oas.py +16 -0
  38. snowflake/ml/modeling/covariance/shrunk_covariance.py +16 -0
  39. snowflake/ml/modeling/decomposition/dictionary_learning.py +16 -0
  40. snowflake/ml/modeling/decomposition/factor_analysis.py +16 -0
  41. snowflake/ml/modeling/decomposition/fast_ica.py +16 -0
  42. snowflake/ml/modeling/decomposition/incremental_pca.py +16 -0
  43. snowflake/ml/modeling/decomposition/kernel_pca.py +16 -0
  44. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +16 -0
  45. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +16 -0
  46. snowflake/ml/modeling/decomposition/pca.py +16 -0
  47. snowflake/ml/modeling/decomposition/sparse_pca.py +16 -0
  48. snowflake/ml/modeling/decomposition/truncated_svd.py +16 -0
  49. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +16 -0
  50. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +16 -0
  51. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +16 -0
  52. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +16 -0
  53. snowflake/ml/modeling/ensemble/bagging_classifier.py +16 -0
  54. snowflake/ml/modeling/ensemble/bagging_regressor.py +16 -0
  55. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +16 -0
  56. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +16 -0
  57. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +16 -0
  58. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +16 -0
  59. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +16 -0
  60. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +16 -0
  61. snowflake/ml/modeling/ensemble/isolation_forest.py +16 -0
  62. snowflake/ml/modeling/ensemble/random_forest_classifier.py +16 -0
  63. snowflake/ml/modeling/ensemble/random_forest_regressor.py +16 -0
  64. snowflake/ml/modeling/ensemble/stacking_regressor.py +16 -0
  65. snowflake/ml/modeling/ensemble/voting_classifier.py +16 -0
  66. snowflake/ml/modeling/ensemble/voting_regressor.py +16 -0
  67. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +16 -0
  68. snowflake/ml/modeling/feature_selection/select_fdr.py +16 -0
  69. snowflake/ml/modeling/feature_selection/select_fpr.py +16 -0
  70. snowflake/ml/modeling/feature_selection/select_fwe.py +16 -0
  71. snowflake/ml/modeling/feature_selection/select_k_best.py +16 -0
  72. snowflake/ml/modeling/feature_selection/select_percentile.py +16 -0
  73. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +16 -0
  74. snowflake/ml/modeling/feature_selection/variance_threshold.py +16 -0
  75. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +16 -0
  76. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +16 -0
  77. snowflake/ml/modeling/impute/iterative_imputer.py +16 -0
  78. snowflake/ml/modeling/impute/knn_imputer.py +16 -0
  79. snowflake/ml/modeling/impute/missing_indicator.py +16 -0
  80. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +16 -0
  81. snowflake/ml/modeling/kernel_approximation/nystroem.py +16 -0
  82. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +16 -0
  83. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +16 -0
  84. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +16 -0
  85. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +16 -0
  86. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +16 -0
  87. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +16 -0
  88. snowflake/ml/modeling/linear_model/ard_regression.py +16 -0
  89. snowflake/ml/modeling/linear_model/bayesian_ridge.py +16 -0
  90. snowflake/ml/modeling/linear_model/elastic_net.py +16 -0
  91. snowflake/ml/modeling/linear_model/elastic_net_cv.py +16 -0
  92. snowflake/ml/modeling/linear_model/gamma_regressor.py +16 -0
  93. snowflake/ml/modeling/linear_model/huber_regressor.py +16 -0
  94. snowflake/ml/modeling/linear_model/lars.py +16 -0
  95. snowflake/ml/modeling/linear_model/lars_cv.py +16 -0
  96. snowflake/ml/modeling/linear_model/lasso.py +16 -0
  97. snowflake/ml/modeling/linear_model/lasso_cv.py +16 -0
  98. snowflake/ml/modeling/linear_model/lasso_lars.py +16 -0
  99. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +16 -0
  100. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +16 -0
  101. snowflake/ml/modeling/linear_model/linear_regression.py +16 -0
  102. snowflake/ml/modeling/linear_model/logistic_regression.py +16 -0
  103. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +16 -0
  104. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +16 -0
  105. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +16 -0
  106. snowflake/ml/modeling/linear_model/multi_task_lasso.py +16 -0
  107. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +16 -0
  108. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +16 -0
  109. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +16 -0
  110. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +16 -0
  111. snowflake/ml/modeling/linear_model/perceptron.py +16 -0
  112. snowflake/ml/modeling/linear_model/poisson_regressor.py +16 -0
  113. snowflake/ml/modeling/linear_model/ransac_regressor.py +16 -0
  114. snowflake/ml/modeling/linear_model/ridge.py +16 -0
  115. snowflake/ml/modeling/linear_model/ridge_classifier.py +16 -0
  116. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +16 -0
  117. snowflake/ml/modeling/linear_model/ridge_cv.py +16 -0
  118. snowflake/ml/modeling/linear_model/sgd_classifier.py +16 -0
  119. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +16 -0
  120. snowflake/ml/modeling/linear_model/sgd_regressor.py +16 -0
  121. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +16 -0
  122. snowflake/ml/modeling/linear_model/tweedie_regressor.py +16 -0
  123. snowflake/ml/modeling/manifold/isomap.py +16 -0
  124. snowflake/ml/modeling/manifold/mds.py +16 -0
  125. snowflake/ml/modeling/manifold/spectral_embedding.py +16 -0
  126. snowflake/ml/modeling/manifold/tsne.py +16 -0
  127. snowflake/ml/modeling/metrics/classification.py +5 -6
  128. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  129. snowflake/ml/modeling/metrics/ranking.py +7 -3
  130. snowflake/ml/modeling/metrics/regression.py +6 -3
  131. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +16 -0
  132. snowflake/ml/modeling/mixture/gaussian_mixture.py +16 -0
  133. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +16 -0
  134. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +16 -0
  135. snowflake/ml/modeling/multiclass/output_code_classifier.py +16 -0
  136. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +16 -0
  137. snowflake/ml/modeling/naive_bayes/categorical_nb.py +16 -0
  138. snowflake/ml/modeling/naive_bayes/complement_nb.py +16 -0
  139. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +16 -0
  140. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +16 -0
  141. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +16 -0
  142. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +16 -0
  143. snowflake/ml/modeling/neighbors/kernel_density.py +16 -0
  144. snowflake/ml/modeling/neighbors/local_outlier_factor.py +16 -0
  145. snowflake/ml/modeling/neighbors/nearest_centroid.py +16 -0
  146. snowflake/ml/modeling/neighbors/nearest_neighbors.py +16 -0
  147. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +16 -0
  148. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +16 -0
  149. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +16 -0
  150. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +16 -0
  151. snowflake/ml/modeling/neural_network/mlp_classifier.py +16 -0
  152. snowflake/ml/modeling/neural_network/mlp_regressor.py +16 -0
  153. snowflake/ml/modeling/preprocessing/polynomial_features.py +16 -0
  154. snowflake/ml/modeling/semi_supervised/label_propagation.py +16 -0
  155. snowflake/ml/modeling/semi_supervised/label_spreading.py +16 -0
  156. snowflake/ml/modeling/svm/linear_svc.py +16 -0
  157. snowflake/ml/modeling/svm/linear_svr.py +16 -0
  158. snowflake/ml/modeling/svm/nu_svc.py +16 -0
  159. snowflake/ml/modeling/svm/nu_svr.py +16 -0
  160. snowflake/ml/modeling/svm/svc.py +16 -0
  161. snowflake/ml/modeling/svm/svr.py +16 -0
  162. snowflake/ml/modeling/tree/decision_tree_classifier.py +16 -0
  163. snowflake/ml/modeling/tree/decision_tree_regressor.py +16 -0
  164. snowflake/ml/modeling/tree/extra_tree_classifier.py +16 -0
  165. snowflake/ml/modeling/tree/extra_tree_regressor.py +16 -0
  166. snowflake/ml/modeling/xgboost/xgb_classifier.py +16 -0
  167. snowflake/ml/modeling/xgboost/xgb_regressor.py +16 -0
  168. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +16 -0
  169. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +16 -0
  170. snowflake/ml/registry/registry.py +2 -0
  171. snowflake/ml/version.py +1 -1
  172. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  173. {snowflake_ml_python-1.2.0.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +261 -50
  174. {snowflake_ml_python-1.2.0.dist-info → snowflake_ml_python-1.2.1.dist-info}/RECORD +189 -186
  175. {snowflake_ml_python-1.2.0.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  176. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
@@ -566,6 +566,22 @@ class KernelRidge(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY".
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY".
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
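The identical sixteen-line fallback is inserted into every estimator file listed above; only the class name and hunk offsets differ between the hunks that follow. As a reading aid, here is a minimal standalone sketch of the rule being added. The function name `infer_expected_dtype` and its parameters are illustrative stand-ins, not part of the package's API; in the real code this logic runs inside each BaseTransformer subclass and converts Snowpark types via `convert_sp_to_sf_type`.

    def infer_expected_dtype(sklearn_object, output_cols, input_types, expected_dtype=""):
        # Hypothetical, self-contained version of the dtype fallback added in 1.2.1.
        if expected_dtype:
            return expected_dtype  # the factory already assigned a type
        # Clustering transformers: an output column count differing from n_clusters yields an ARRAY.
        if hasattr(sklearn_object, "n_clusters") and sklearn_object.n_clusters != len(output_cols):
            return "ARRAY"
        # Decomposition transformers: an output column count differing from n_components yields an ARRAY.
        if hasattr(sklearn_object, "n_components") and sklearn_object.n_components != len(output_cols):
            return "ARRAY"
        # Only when all input types agree and column counts match can the input type be reused.
        if input_types and all(t == input_types[0] for t in input_types) and len(input_types) == len(output_cols):
            return input_types[0]
        return ""  # unresolved: the caller falls back to a variant column

For example, a KMeans-like object with n_clusters=8 but a single output column resolves to "ARRAY", while an identity-like transform over three DOUBLE inputs and three output columns keeps DOUBLE.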
@@ -554,6 +554,22 @@ class LGBMClassifier(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY".
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY".
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -554,6 +554,22 @@ class LGBMRegressor(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY".
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY".
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -580,6 +580,22 @@ class ARDRegression(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY".
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY".
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -591,6 +591,22 @@ class BayesianRidge(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY".
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY".
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -590,6 +590,22 @@ class ElasticNet(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY".
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY".
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -626,6 +626,22 @@ class ElasticNetCV(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY".
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY".
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -571,6 +571,22 @@ class GammaRegressor(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY".
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY".
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -554,6 +554,22 @@ class HuberRegressor(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY".
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY".
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -583,6 +583,22 @@ class Lars(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY".
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY".
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -591,6 +591,22 @@ class LarsCV(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY".
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY".
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -584,6 +584,22 @@ class Lasso(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY".
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY".
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -612,6 +612,22 @@ class LassoCV(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY".
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY".
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -604,6 +604,22 @@ class LassoLars(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY".
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY".
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -605,6 +605,22 @@ class LassoLarsCV(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY".
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY".
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -588,6 +588,22 @@ class LassoLarsIC(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY".
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY".
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -541,6 +541,22 @@ class LinearRegression(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY".
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY".
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -655,6 +655,22 @@ class LogisticRegression(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY".
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY".
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -676,6 +676,22 @@ class LogisticRegressionCV(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY".
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY".
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -574,6 +574,22 @@ class MultiTaskElasticNet(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY".
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY".
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -615,6 +615,22 @@ class MultiTaskElasticNetCV(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY".
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY".
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -566,6 +566,22 @@ class MultiTaskLasso(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY".
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY".
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -601,6 +601,22 @@ class MultiTaskLassoCV(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY".
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY".
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -549,6 +549,22 @@ class OrthogonalMatchingPursuit(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY".
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY".
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",