snowflake-ml-python 1.4.0__py3-none-any.whl → 1.4.1__py3-none-any.whl

This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (203)
  1. snowflake/ml/_internal/env_utils.py +11 -1
  2. snowflake/ml/_internal/utils/identifier.py +3 -1
  3. snowflake/ml/_internal/utils/sql_identifier.py +2 -6
  4. snowflake/ml/feature_store/feature_store.py +151 -78
  5. snowflake/ml/feature_store/feature_view.py +12 -24
  6. snowflake/ml/fileset/sfcfs.py +56 -50
  7. snowflake/ml/fileset/stage_fs.py +48 -13
  8. snowflake/ml/model/_client/model/model_version_impl.py +2 -50
  9. snowflake/ml/model/_client/ops/model_ops.py +78 -29
  10. snowflake/ml/model/_client/sql/model.py +23 -2
  11. snowflake/ml/model/_client/sql/model_version.py +22 -1
  12. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +19 -54
  13. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +8 -1
  14. snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
  15. snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
  16. snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
  17. snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
  18. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  19. snowflake/ml/model/_packager/model_meta/model_meta.py +36 -6
  20. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
  21. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
  22. snowflake/ml/model/_packager/model_packager.py +2 -2
  23. snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
  24. snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
  25. snowflake/ml/model/type_hints.py +21 -2
  26. snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
  27. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
  28. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
  29. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  30. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +195 -123
  31. snowflake/ml/modeling/cluster/affinity_propagation.py +195 -123
  32. snowflake/ml/modeling/cluster/agglomerative_clustering.py +195 -123
  33. snowflake/ml/modeling/cluster/birch.py +195 -123
  34. snowflake/ml/modeling/cluster/bisecting_k_means.py +195 -123
  35. snowflake/ml/modeling/cluster/dbscan.py +195 -123
  36. snowflake/ml/modeling/cluster/feature_agglomeration.py +195 -123
  37. snowflake/ml/modeling/cluster/k_means.py +195 -123
  38. snowflake/ml/modeling/cluster/mean_shift.py +195 -123
  39. snowflake/ml/modeling/cluster/mini_batch_k_means.py +195 -123
  40. snowflake/ml/modeling/cluster/optics.py +195 -123
  41. snowflake/ml/modeling/cluster/spectral_biclustering.py +195 -123
  42. snowflake/ml/modeling/cluster/spectral_clustering.py +195 -123
  43. snowflake/ml/modeling/cluster/spectral_coclustering.py +195 -123
  44. snowflake/ml/modeling/compose/column_transformer.py +195 -123
  45. snowflake/ml/modeling/compose/transformed_target_regressor.py +195 -123
  46. snowflake/ml/modeling/covariance/elliptic_envelope.py +195 -123
  47. snowflake/ml/modeling/covariance/empirical_covariance.py +195 -123
  48. snowflake/ml/modeling/covariance/graphical_lasso.py +195 -123
  49. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +195 -123
  50. snowflake/ml/modeling/covariance/ledoit_wolf.py +195 -123
  51. snowflake/ml/modeling/covariance/min_cov_det.py +195 -123
  52. snowflake/ml/modeling/covariance/oas.py +195 -123
  53. snowflake/ml/modeling/covariance/shrunk_covariance.py +195 -123
  54. snowflake/ml/modeling/decomposition/dictionary_learning.py +195 -123
  55. snowflake/ml/modeling/decomposition/factor_analysis.py +195 -123
  56. snowflake/ml/modeling/decomposition/fast_ica.py +195 -123
  57. snowflake/ml/modeling/decomposition/incremental_pca.py +195 -123
  58. snowflake/ml/modeling/decomposition/kernel_pca.py +195 -123
  59. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +195 -123
  60. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +195 -123
  61. snowflake/ml/modeling/decomposition/pca.py +195 -123
  62. snowflake/ml/modeling/decomposition/sparse_pca.py +195 -123
  63. snowflake/ml/modeling/decomposition/truncated_svd.py +195 -123
  64. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +195 -123
  65. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +195 -123
  66. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +195 -123
  67. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +195 -123
  68. snowflake/ml/modeling/ensemble/bagging_classifier.py +195 -123
  69. snowflake/ml/modeling/ensemble/bagging_regressor.py +195 -123
  70. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +195 -123
  71. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +195 -123
  72. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +195 -123
  73. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +195 -123
  74. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +195 -123
  75. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +195 -123
  76. snowflake/ml/modeling/ensemble/isolation_forest.py +195 -123
  77. snowflake/ml/modeling/ensemble/random_forest_classifier.py +195 -123
  78. snowflake/ml/modeling/ensemble/random_forest_regressor.py +195 -123
  79. snowflake/ml/modeling/ensemble/stacking_regressor.py +195 -123
  80. snowflake/ml/modeling/ensemble/voting_classifier.py +195 -123
  81. snowflake/ml/modeling/ensemble/voting_regressor.py +195 -123
  82. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +195 -123
  83. snowflake/ml/modeling/feature_selection/select_fdr.py +195 -123
  84. snowflake/ml/modeling/feature_selection/select_fpr.py +195 -123
  85. snowflake/ml/modeling/feature_selection/select_fwe.py +195 -123
  86. snowflake/ml/modeling/feature_selection/select_k_best.py +195 -123
  87. snowflake/ml/modeling/feature_selection/select_percentile.py +195 -123
  88. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +195 -123
  89. snowflake/ml/modeling/feature_selection/variance_threshold.py +195 -123
  90. snowflake/ml/modeling/framework/_utils.py +8 -1
  91. snowflake/ml/modeling/framework/base.py +9 -1
  92. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +195 -123
  93. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +195 -123
  94. snowflake/ml/modeling/impute/iterative_imputer.py +195 -123
  95. snowflake/ml/modeling/impute/knn_imputer.py +195 -123
  96. snowflake/ml/modeling/impute/missing_indicator.py +195 -123
  97. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +195 -123
  98. snowflake/ml/modeling/kernel_approximation/nystroem.py +195 -123
  99. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +195 -123
  100. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +195 -123
  101. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +195 -123
  102. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +195 -123
  103. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +195 -123
  104. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +195 -123
  105. snowflake/ml/modeling/linear_model/ard_regression.py +195 -123
  106. snowflake/ml/modeling/linear_model/bayesian_ridge.py +195 -123
  107. snowflake/ml/modeling/linear_model/elastic_net.py +195 -123
  108. snowflake/ml/modeling/linear_model/elastic_net_cv.py +195 -123
  109. snowflake/ml/modeling/linear_model/gamma_regressor.py +195 -123
  110. snowflake/ml/modeling/linear_model/huber_regressor.py +195 -123
  111. snowflake/ml/modeling/linear_model/lars.py +195 -123
  112. snowflake/ml/modeling/linear_model/lars_cv.py +195 -123
  113. snowflake/ml/modeling/linear_model/lasso.py +195 -123
  114. snowflake/ml/modeling/linear_model/lasso_cv.py +195 -123
  115. snowflake/ml/modeling/linear_model/lasso_lars.py +195 -123
  116. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +195 -123
  117. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +195 -123
  118. snowflake/ml/modeling/linear_model/linear_regression.py +195 -123
  119. snowflake/ml/modeling/linear_model/logistic_regression.py +195 -123
  120. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +195 -123
  121. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +195 -123
  122. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +195 -123
  123. snowflake/ml/modeling/linear_model/multi_task_lasso.py +195 -123
  124. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +195 -123
  125. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +195 -123
  126. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +195 -123
  127. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +195 -123
  128. snowflake/ml/modeling/linear_model/perceptron.py +195 -123
  129. snowflake/ml/modeling/linear_model/poisson_regressor.py +195 -123
  130. snowflake/ml/modeling/linear_model/ransac_regressor.py +195 -123
  131. snowflake/ml/modeling/linear_model/ridge.py +195 -123
  132. snowflake/ml/modeling/linear_model/ridge_classifier.py +195 -123
  133. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +195 -123
  134. snowflake/ml/modeling/linear_model/ridge_cv.py +195 -123
  135. snowflake/ml/modeling/linear_model/sgd_classifier.py +195 -123
  136. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +195 -123
  137. snowflake/ml/modeling/linear_model/sgd_regressor.py +195 -123
  138. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +195 -123
  139. snowflake/ml/modeling/linear_model/tweedie_regressor.py +195 -123
  140. snowflake/ml/modeling/manifold/isomap.py +195 -123
  141. snowflake/ml/modeling/manifold/mds.py +195 -123
  142. snowflake/ml/modeling/manifold/spectral_embedding.py +195 -123
  143. snowflake/ml/modeling/manifold/tsne.py +195 -123
  144. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +195 -123
  145. snowflake/ml/modeling/mixture/gaussian_mixture.py +195 -123
  146. snowflake/ml/modeling/model_selection/grid_search_cv.py +42 -18
  147. snowflake/ml/modeling/model_selection/randomized_search_cv.py +42 -18
  148. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +195 -123
  149. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +195 -123
  150. snowflake/ml/modeling/multiclass/output_code_classifier.py +195 -123
  151. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +195 -123
  152. snowflake/ml/modeling/naive_bayes/categorical_nb.py +195 -123
  153. snowflake/ml/modeling/naive_bayes/complement_nb.py +195 -123
  154. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +195 -123
  155. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +195 -123
  156. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +195 -123
  157. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +195 -123
  158. snowflake/ml/modeling/neighbors/kernel_density.py +195 -123
  159. snowflake/ml/modeling/neighbors/local_outlier_factor.py +195 -123
  160. snowflake/ml/modeling/neighbors/nearest_centroid.py +195 -123
  161. snowflake/ml/modeling/neighbors/nearest_neighbors.py +195 -123
  162. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +195 -123
  163. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +195 -123
  164. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +195 -123
  165. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +195 -123
  166. snowflake/ml/modeling/neural_network/mlp_classifier.py +195 -123
  167. snowflake/ml/modeling/neural_network/mlp_regressor.py +195 -123
  168. snowflake/ml/modeling/pipeline/pipeline.py +4 -4
  169. snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
  170. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
  171. snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
  172. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
  173. snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
  174. snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
  175. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +1 -5
  176. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
  177. snowflake/ml/modeling/preprocessing/polynomial_features.py +195 -123
  178. snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
  179. snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
  180. snowflake/ml/modeling/semi_supervised/label_propagation.py +195 -123
  181. snowflake/ml/modeling/semi_supervised/label_spreading.py +195 -123
  182. snowflake/ml/modeling/svm/linear_svc.py +195 -123
  183. snowflake/ml/modeling/svm/linear_svr.py +195 -123
  184. snowflake/ml/modeling/svm/nu_svc.py +195 -123
  185. snowflake/ml/modeling/svm/nu_svr.py +195 -123
  186. snowflake/ml/modeling/svm/svc.py +195 -123
  187. snowflake/ml/modeling/svm/svr.py +195 -123
  188. snowflake/ml/modeling/tree/decision_tree_classifier.py +195 -123
  189. snowflake/ml/modeling/tree/decision_tree_regressor.py +195 -123
  190. snowflake/ml/modeling/tree/extra_tree_classifier.py +195 -123
  191. snowflake/ml/modeling/tree/extra_tree_regressor.py +195 -123
  192. snowflake/ml/modeling/xgboost/xgb_classifier.py +195 -123
  193. snowflake/ml/modeling/xgboost/xgb_regressor.py +195 -123
  194. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +195 -123
  195. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +195 -123
  196. snowflake/ml/registry/registry.py +1 -1
  197. snowflake/ml/version.py +1 -1
  198. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.4.1.dist-info}/METADATA +68 -57
  199. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.4.1.dist-info}/RECORD +202 -200
  200. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
  201. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.4.1.dist-info}/LICENSE.txt +0 -0
  202. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.4.1.dist-info}/WHEEL +0 -0
  203. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.4.1.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/linear_model/theil_sen_regressor.py

@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
     BatchInferenceKwargsTypedDict,
     ScoreKwargsTypedDict
 )
+from snowflake.ml.model._signatures import utils as model_signature_utils
+from snowflake.ml.model.model_signature import (
+    BaseFeatureSpec,
+    DataType,
+    FeatureSpec,
+    ModelSignature,
+    _infer_signature,
+    _rename_signature_with_snowflake_identifiers,
+)
 
 from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
 
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
     validate_sklearn_args,
 )
 
-from snowflake.ml.model.model_signature import (
-    DataType,
-    FeatureSpec,
-    ModelSignature,
-    _infer_signature,
-    _rename_signature_with_snowflake_identifiers,
-    BaseFeatureSpec,
-)
-from snowflake.ml.model._signatures import utils as model_signature_utils
-
 _PROJECT = "ModelDevelopment"
 # Derive subproject from module name by removing "sklearn"
 # and converting module name from underscore to CamelCase
@@ -256,12 +255,7 @@ class TheilSenRegressor(BaseTransformer):
         )
         return selected_cols
 
-    @telemetry.send_api_usage_telemetry(
-        project=_PROJECT,
-        subproject=_SUBPROJECT,
-        custom_tags=dict([("autogen", True)]),
-    )
-    def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "TheilSenRegressor":
+    def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "TheilSenRegressor":
         """Fit linear model
         For more details on this function, see [sklearn.linear_model.TheilSenRegressor.fit]
         (https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.TheilSenRegressor.html#sklearn.linear_model.TheilSenRegressor.fit)
@@ -288,12 +282,14 @@ class TheilSenRegressor(BaseTransformer):
 
         self._snowpark_cols = dataset.select(self.input_cols).columns
 
-        # If we are already in a stored procedure, no need to kick off another one.
+        # If we are already in a stored procedure, no need to kick off another one.
         if SNOWML_SPROC_ENV in os.environ:
             statement_params = telemetry.get_function_usage_statement_params(
                 project=_PROJECT,
                 subproject=_SUBPROJECT,
-                function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), TheilSenRegressor.__class__.__name__),
+                function_name=telemetry.get_statement_params_full_func_name(
+                    inspect.currentframe(), TheilSenRegressor.__class__.__name__
+                ),
                 api_calls=[Session.call],
                 custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
             )
@@ -314,7 +310,7 @@ class TheilSenRegressor(BaseTransformer):
             )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
-        self._get_model_signatures(dataset)
+        self._generate_model_signatures(dataset)
         return self
 
     def _batch_inference_validate_snowpark(
@@ -390,7 +386,9 @@ class TheilSenRegressor(BaseTransformer):
             # when it is classifier, infer the datatype from label columns
             if expected_type_inferred == "" and 'predict' in self.model_signatures:
                 # Batch inference takes a single expected output column type. Use the first columns type for now.
-                label_cols_signatures = [row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols]
+                label_cols_signatures = [
+                    row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
+                ]
                 if len(label_cols_signatures) == 0:
                     error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
                     raise exceptions.SnowflakeMLException(
@@ -398,25 +396,22 @@ class TheilSenRegressor(BaseTransformer):
                         original_exception=ValueError(error_str),
                     )
 
-                expected_type_inferred = convert_sp_to_sf_type(
-                    label_cols_signatures[0].as_snowpark_type()
-                )
+                expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
             self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
-            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
-                session = dataset._session,
-                dependencies = self._deps,
-                drop_input_cols = self._drop_input_cols,
-                expected_output_cols_type = expected_type_inferred,
+                session=dataset._session,
+                dependencies=self._deps,
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_type=expected_type_inferred,
             )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -456,7 +451,7 @@ class TheilSenRegressor(BaseTransformer):
             Transformed dataset.
         """
         super()._check_dataset_type(dataset)
-        inference_method="transform"
+        inference_method = "transform"
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
@@ -493,17 +488,14 @@ class TheilSenRegressor(BaseTransformer):
            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
            transform_kwargs = dict(
-                session = dataset._session,
-                dependencies = self._deps,
-                drop_input_cols = self._drop_input_cols,
-                expected_output_cols_type = expected_dtype,
+                session=dataset._session,
+                dependencies=self._deps,
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_type=expected_dtype,
            )
 
        elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
        transform_handlers = ModelTransformerBuilder.build(
            dataset=dataset,
@@ -522,7 +514,11 @@ class TheilSenRegressor(BaseTransformer):
         return output_df
 
     @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_predict_",) -> Union[DataFrame, pd.DataFrame]:
+    def fit_predict(
+        self,
+        dataset: Union[DataFrame, pd.DataFrame],
+        output_cols_prefix: str = "fit_predict_",
+    ) -> Union[DataFrame, pd.DataFrame]:
         """ Method not supported for this class.
 
 
@@ -547,7 +543,9 @@ class TheilSenRegressor(BaseTransformer):
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
             drop_input_cols=self._drop_input_cols,
-            expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
+            expected_output_cols_list=(
+                self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
+            ),
         )
         self._sklearn_object = fitted_estimator
         self._is_fitted = True
@@ -564,6 +562,62 @@ class TheilSenRegressor(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object.embedding_
 
+
+    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
+        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
+        """
+        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
+        # The following condition is introduced for kneighbors methods, and not used in other methods
+        if output_cols:
+            output_cols = [
+                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
+                for c in output_cols
+            ]
+        elif getattr(self._sklearn_object, "classes_", None) is None:
+            output_cols = [output_cols_prefix]
+        elif self._sklearn_object is not None:
+            classes = self._sklearn_object.classes_
+            if isinstance(classes, numpy.ndarray):
+                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
+            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
+                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
+                output_cols = []
+                for i, cl in enumerate(classes):
+                    # For binary classification, there is only one output column for each class
+                    # ndarray as the two classes are complementary.
+                    if len(cl) == 2:
+                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
+                    else:
+                        output_cols.extend([
+                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
+                        ])
+        else:
+            output_cols = []
+
+        # Make sure column names are valid snowflake identifiers.
+        assert output_cols is not None  # Make MyPy happy
+        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
+
+        return rv
+
+    def _align_expected_output_names(
+        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
+    ) -> List[str]:
+        # in case the inferred output column names dimension is different
+        # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
+        output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
+        output_df_columns = list(output_df_pd.columns)
+        output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
+        if self.sample_weight_col:
+            output_df_columns_set -= set(self.sample_weight_col)
+        # if the dimension of inferred output column names is correct; use it
+        if len(expected_output_cols_list) == len(output_df_columns_set):
+            return expected_output_cols_list
+        # otherwise, use the sklearn estimator's output
+        else:
+            return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
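The two helpers added in the hunk above determine the output column names used by predict_proba(), decision_function(), and the other batch-inference methods further down. As a rough standalone sketch (not code from the package), the naming scheme reduces to the following; the Snowflake-identifier normalization done via identifier.rename_to_valid_snowflake_identifier() is omitted and the class labels are hypothetical:

import numpy as np

def proba_output_columns(prefix, classes=None):
    # No classes_ attribute (regressors, transformers): a single column named after the prefix.
    if classes is None:
        return [prefix]
    # Ordinary classifier: one column per class label, e.g. predict_proba_0, predict_proba_1, ...
    if isinstance(classes, np.ndarray):
        return [f"{prefix}{c}" for c in classes.tolist()]
    # Multioutput classifier: classes is a list of ndarrays; a binary target collapses to one column.
    cols = []
    for i, cl in enumerate(classes):
        if len(cl) == 2:
            cols.append(f"{prefix}{i}_{cl[0]}")
        else:
            cols.extend(f"{prefix}{i}_{c}" for c in cl.tolist())
    return cols

print(proba_output_columns("predict_proba_"))                       # ['predict_proba_']
print(proba_output_columns("predict_proba_", np.array([0, 1, 2])))  # ['predict_proba_0', 'predict_proba_1', 'predict_proba_2']

_align_expected_output_names() then runs the method on a one-row pandas sample and falls back to the columns sklearn actually produced whenever the expected count does not match.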
@@ -594,24 +648,28 @@ class TheilSenRegressor(BaseTransformer):
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
 
         if isinstance(dataset, DataFrame):
             self._deps = self._batch_inference_validate_snowpark(
                 dataset=dataset,
                 inference_method=inference_method,
             )
-            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols = self._drop_input_cols,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -623,7 +681,7 @@ class TheilSenRegressor(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -653,7 +711,8 @@ class TheilSenRegressor(BaseTransformer):
             Output dataset with log probability of the sample for each class in the model.
         """
         super()._check_dataset_type(dataset)
-        inference_method="predict_log_proba"
+        inference_method = "predict_log_proba"
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
@@ -664,18 +723,20 @@ class TheilSenRegressor(BaseTransformer):
                 dataset=dataset,
                 inference_method=inference_method,
             )
-            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols = self._drop_input_cols,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -688,7 +749,7 @@ class TheilSenRegressor(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -714,30 +775,34 @@ class TheilSenRegressor(BaseTransformer):
             Output dataset with results of the decision function for the samples in input dataset.
         """
         super()._check_dataset_type(dataset)
-        inference_method="decision_function"
+        inference_method = "decision_function"
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
 
         if isinstance(dataset, DataFrame):
             self._deps = self._batch_inference_validate_snowpark(
                 dataset=dataset,
                 inference_method=inference_method,
             )
-            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols = self._drop_input_cols,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -750,7 +815,7 @@ class TheilSenRegressor(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -779,12 +844,14 @@ class TheilSenRegressor(BaseTransformer):
             Output dataset with probability of the sample for each class in the model.
         """
         super()._check_dataset_type(dataset)
-        inference_method="score_samples"
+        inference_method = "score_samples"
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
         if isinstance(dataset, DataFrame):
             self._deps = self._batch_inference_validate_snowpark(
                 dataset=dataset,
@@ -797,6 +864,9 @@ class TheilSenRegressor(BaseTransformer):
                 drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
 
         elif isinstance(dataset, pd.DataFrame):
             transform_kwargs = dict(
@@ -815,7 +885,7 @@ class TheilSenRegressor(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -962,50 +1032,84 @@ class TheilSenRegressor(BaseTransformer):
         )
         return output_df
 
+
+
+    def to_sklearn(self) -> Any:
+        """Get sklearn.linear_model.TheilSenRegressor object.
+        """
+        if self._sklearn_object is None:
+            self._sklearn_object = self._create_sklearn_object()
+        return self._sklearn_object
+
+    def to_xgboost(self) -> Any:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.METHOD_NOT_ALLOWED,
+            original_exception=AttributeError(
+                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
+                    "to_xgboost()",
+                    "to_sklearn()"
+                )
+            ),
+        )
+
+    def to_lightgbm(self) -> Any:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.METHOD_NOT_ALLOWED,
+            original_exception=AttributeError(
+                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
+                    "to_lightgbm()",
+                    "to_sklearn()"
+                )
+            ),
+        )
 
-    def _get_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
+    def _get_dependencies(self) -> List[str]:
+        return self._deps
+
+
+    def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         self._model_signature_dict = dict()
 
         PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
 
-        inputs = list(_infer_signature(dataset[self.input_cols], "input"))
+        inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
         outputs: List[BaseFeatureSpec] = []
         if hasattr(self, "predict"):
             # keep mypy happy
-            assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
+            assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
             # For classifier, the type of predict is the same as the type of label
-            if self._sklearn_object._estimator_type == 'classifier':
-                # label columns is the desired type for output
+            if self._sklearn_object._estimator_type == "classifier":
+                # label columns is the desired type for output
                 outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
                 # rename the output columns
                 outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
-                self._model_signature_dict["predict"] = ModelSignature(inputs,
-                                                                       ([] if self._drop_input_cols else inputs)
-                                                                       + outputs)
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
             # For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
             # For outlier models, returns -1 for outliers and 1 for inliers.
-            # Clusterer returns int64 cluster labels.
+            # Clusterer returns int64 cluster labels.
             elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
                 outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
-                self._model_signature_dict["predict"] = ModelSignature(inputs,
-                                                                       ([] if self._drop_input_cols else inputs)
-                                                                       + outputs)
-
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
+
             # For regressor, the type of predict is float64
-            elif self._sklearn_object._estimator_type == 'regressor':
+            elif self._sklearn_object._estimator_type == "regressor":
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-                self._model_signature_dict["predict"] = ModelSignature(inputs,
-                                                                       ([] if self._drop_input_cols else inputs)
-                                                                       + outputs)
-
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
+
 
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs,
-                                                                       ([] if self._drop_input_cols else inputs)
-                                                                       + outputs)
+                self._model_signature_dict[prob_func] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
 
         # Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
         items = list(self._model_signature_dict.items())
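For reference, a minimal sketch of the signature objects that _generate_model_signatures() assembles, using only the constructors visible in the hunk above (the column names here are hypothetical):

from snowflake.ml.model.model_signature import DataType, FeatureSpec, ModelSignature

inputs = [FeatureSpec(dtype=DataType.DOUBLE, name="X1"), FeatureSpec(dtype=DataType.DOUBLE, name="X2")]
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name="OUTPUT_Y")]   # regressor predict output is DOUBLE
predict_signature = ModelSignature(inputs, inputs + outputs)      # inputs are echoed unless _drop_input_cols is set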
@@ -1018,10 +1122,10 @@ class TheilSenRegressor(BaseTransformer):
         """Returns model signature of current class.
 
         Raises:
-            exceptions.SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
+            SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
 
         Returns:
-            Dict[str, ModelSignature]: each method and its input output signature
+            Dict with each method and its input output signature
         """
         if self._model_signature_dict is None:
             raise exceptions.SnowflakeMLException(
@@ -1029,35 +1133,3 @@ class TheilSenRegressor(BaseTransformer):
                 original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
             )
         return self._model_signature_dict
-
-    def to_sklearn(self) -> Any:
-        """Get sklearn.linear_model.TheilSenRegressor object.
-        """
-        if self._sklearn_object is None:
-            self._sklearn_object = self._create_sklearn_object()
-        return self._sklearn_object
-
-    def to_xgboost(self) -> Any:
-        raise exceptions.SnowflakeMLException(
-            error_code=error_codes.METHOD_NOT_ALLOWED,
-            original_exception=AttributeError(
-                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
-                    "to_xgboost()",
-                    "to_sklearn()"
-                )
-            ),
-        )
-
-    def to_lightgbm(self) -> Any:
-        raise exceptions.SnowflakeMLException(
-            error_code=error_codes.METHOD_NOT_ALLOWED,
-            original_exception=AttributeError(
-                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
-                    "to_lightgbm()",
-                    "to_sklearn()"
-                )
-            ),
-        )
-
-    def _get_dependencies(self) -> List[str]:
-        return self._deps
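Taken together, the hunks above only move and rename internals of the autogenerated wrapper. A hedged sketch of the user-facing surface they touch follows; the constructor arguments and the fitting step are assumptions about the public API, which this diff does not show:

from snowflake.ml.modeling.linear_model import TheilSenRegressor

reg = TheilSenRegressor(input_cols=["X1", "X2"], label_cols=["Y"], output_cols=["Y_PRED"])
# ... fit on a Snowpark or pandas DataFrame; the public fit() is assumed to delegate to the
# autogenerated _fit() shown above ...
reg.model_signatures   # raises SnowflakeMLException until the estimator has been fitted
reg.to_sklearn()       # returns the underlying sklearn.linear_model.TheilSenRegressor
reg.to_xgboost()       # raises SnowflakeMLException (METHOD_NOT_ALLOWED): use to_sklearn() instead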