snowflake-ml-python 1.3.1__py3-none-any.whl → 1.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219) hide show
  1. snowflake/ml/_internal/env_utils.py +11 -1
  2. snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
  3. snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
  4. snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
  5. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
  6. snowflake/ml/_internal/utils/formatting.py +1 -1
  7. snowflake/ml/_internal/utils/identifier.py +3 -1
  8. snowflake/ml/_internal/utils/sql_identifier.py +2 -6
  9. snowflake/ml/feature_store/feature_store.py +166 -184
  10. snowflake/ml/feature_store/feature_view.py +12 -24
  11. snowflake/ml/fileset/sfcfs.py +56 -50
  12. snowflake/ml/fileset/stage_fs.py +48 -13
  13. snowflake/ml/model/_client/model/model_version_impl.py +6 -49
  14. snowflake/ml/model/_client/ops/model_ops.py +78 -29
  15. snowflake/ml/model/_client/sql/model.py +23 -2
  16. snowflake/ml/model/_client/sql/model_version.py +22 -1
  17. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -3
  18. snowflake/ml/model/_deploy_client/snowservice/deploy.py +5 -2
  19. snowflake/ml/model/_model_composer/model_composer.py +7 -5
  20. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +19 -54
  21. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +8 -1
  22. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +1 -1
  23. snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
  24. snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
  25. snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
  26. snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
  27. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +13 -1
  28. snowflake/ml/model/_packager/model_handlers/xgboost.py +1 -1
  29. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  30. snowflake/ml/model/_packager/model_meta/model_meta.py +36 -6
  31. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
  32. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
  33. snowflake/ml/model/_packager/model_packager.py +2 -2
  34. snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
  35. snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
  36. snowflake/ml/model/custom_model.py +3 -1
  37. snowflake/ml/model/type_hints.py +21 -2
  38. snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
  39. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
  40. snowflake/ml/modeling/_internal/model_specifications.py +3 -1
  41. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +545 -0
  42. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +8 -5
  43. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +195 -123
  44. snowflake/ml/modeling/cluster/affinity_propagation.py +195 -123
  45. snowflake/ml/modeling/cluster/agglomerative_clustering.py +195 -123
  46. snowflake/ml/modeling/cluster/birch.py +195 -123
  47. snowflake/ml/modeling/cluster/bisecting_k_means.py +195 -123
  48. snowflake/ml/modeling/cluster/dbscan.py +195 -123
  49. snowflake/ml/modeling/cluster/feature_agglomeration.py +195 -123
  50. snowflake/ml/modeling/cluster/k_means.py +195 -123
  51. snowflake/ml/modeling/cluster/mean_shift.py +195 -123
  52. snowflake/ml/modeling/cluster/mini_batch_k_means.py +195 -123
  53. snowflake/ml/modeling/cluster/optics.py +195 -123
  54. snowflake/ml/modeling/cluster/spectral_biclustering.py +195 -123
  55. snowflake/ml/modeling/cluster/spectral_clustering.py +195 -123
  56. snowflake/ml/modeling/cluster/spectral_coclustering.py +195 -123
  57. snowflake/ml/modeling/compose/column_transformer.py +195 -123
  58. snowflake/ml/modeling/compose/transformed_target_regressor.py +195 -123
  59. snowflake/ml/modeling/covariance/elliptic_envelope.py +195 -123
  60. snowflake/ml/modeling/covariance/empirical_covariance.py +195 -123
  61. snowflake/ml/modeling/covariance/graphical_lasso.py +195 -123
  62. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +195 -123
  63. snowflake/ml/modeling/covariance/ledoit_wolf.py +195 -123
  64. snowflake/ml/modeling/covariance/min_cov_det.py +195 -123
  65. snowflake/ml/modeling/covariance/oas.py +195 -123
  66. snowflake/ml/modeling/covariance/shrunk_covariance.py +195 -123
  67. snowflake/ml/modeling/decomposition/dictionary_learning.py +195 -123
  68. snowflake/ml/modeling/decomposition/factor_analysis.py +195 -123
  69. snowflake/ml/modeling/decomposition/fast_ica.py +195 -123
  70. snowflake/ml/modeling/decomposition/incremental_pca.py +195 -123
  71. snowflake/ml/modeling/decomposition/kernel_pca.py +195 -123
  72. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +195 -123
  73. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +195 -123
  74. snowflake/ml/modeling/decomposition/pca.py +195 -123
  75. snowflake/ml/modeling/decomposition/sparse_pca.py +195 -123
  76. snowflake/ml/modeling/decomposition/truncated_svd.py +195 -123
  77. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +195 -123
  78. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +195 -123
  79. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +195 -123
  80. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +195 -123
  81. snowflake/ml/modeling/ensemble/bagging_classifier.py +195 -123
  82. snowflake/ml/modeling/ensemble/bagging_regressor.py +195 -123
  83. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +195 -123
  84. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +195 -123
  85. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +195 -123
  86. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +195 -123
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +195 -123
  88. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +195 -123
  89. snowflake/ml/modeling/ensemble/isolation_forest.py +195 -123
  90. snowflake/ml/modeling/ensemble/random_forest_classifier.py +195 -123
  91. snowflake/ml/modeling/ensemble/random_forest_regressor.py +195 -123
  92. snowflake/ml/modeling/ensemble/stacking_regressor.py +195 -123
  93. snowflake/ml/modeling/ensemble/voting_classifier.py +195 -123
  94. snowflake/ml/modeling/ensemble/voting_regressor.py +195 -123
  95. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +195 -123
  96. snowflake/ml/modeling/feature_selection/select_fdr.py +195 -123
  97. snowflake/ml/modeling/feature_selection/select_fpr.py +195 -123
  98. snowflake/ml/modeling/feature_selection/select_fwe.py +195 -123
  99. snowflake/ml/modeling/feature_selection/select_k_best.py +195 -123
  100. snowflake/ml/modeling/feature_selection/select_percentile.py +195 -123
  101. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +195 -123
  102. snowflake/ml/modeling/feature_selection/variance_threshold.py +195 -123
  103. snowflake/ml/modeling/framework/_utils.py +8 -1
  104. snowflake/ml/modeling/framework/base.py +24 -6
  105. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +195 -123
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +195 -123
  107. snowflake/ml/modeling/impute/iterative_imputer.py +195 -123
  108. snowflake/ml/modeling/impute/knn_imputer.py +195 -123
  109. snowflake/ml/modeling/impute/missing_indicator.py +195 -123
  110. snowflake/ml/modeling/impute/simple_imputer.py +4 -15
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +195 -123
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +195 -123
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +195 -123
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +195 -123
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +195 -123
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +195 -123
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +198 -125
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +198 -125
  119. snowflake/ml/modeling/linear_model/ard_regression.py +195 -123
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +195 -123
  121. snowflake/ml/modeling/linear_model/elastic_net.py +195 -123
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +195 -123
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +195 -123
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +195 -123
  125. snowflake/ml/modeling/linear_model/lars.py +195 -123
  126. snowflake/ml/modeling/linear_model/lars_cv.py +195 -123
  127. snowflake/ml/modeling/linear_model/lasso.py +195 -123
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +195 -123
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +195 -123
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +195 -123
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +195 -123
  132. snowflake/ml/modeling/linear_model/linear_regression.py +195 -123
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +195 -123
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +195 -123
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +195 -123
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +195 -123
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +195 -123
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +195 -123
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +195 -123
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +195 -123
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +195 -123
  142. snowflake/ml/modeling/linear_model/perceptron.py +195 -123
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +195 -123
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +195 -123
  145. snowflake/ml/modeling/linear_model/ridge.py +195 -123
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +195 -123
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +195 -123
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +195 -123
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +195 -123
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +195 -123
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +195 -123
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +195 -123
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +195 -123
  154. snowflake/ml/modeling/manifold/isomap.py +195 -123
  155. snowflake/ml/modeling/manifold/mds.py +195 -123
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +195 -123
  157. snowflake/ml/modeling/manifold/tsne.py +195 -123
  158. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +195 -123
  159. snowflake/ml/modeling/mixture/gaussian_mixture.py +195 -123
  160. snowflake/ml/modeling/model_selection/grid_search_cv.py +42 -18
  161. snowflake/ml/modeling/model_selection/randomized_search_cv.py +42 -18
  162. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +195 -123
  163. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +195 -123
  164. snowflake/ml/modeling/multiclass/output_code_classifier.py +195 -123
  165. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +195 -123
  166. snowflake/ml/modeling/naive_bayes/categorical_nb.py +195 -123
  167. snowflake/ml/modeling/naive_bayes/complement_nb.py +195 -123
  168. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +195 -123
  169. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +195 -123
  170. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +195 -123
  171. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +195 -123
  172. snowflake/ml/modeling/neighbors/kernel_density.py +195 -123
  173. snowflake/ml/modeling/neighbors/local_outlier_factor.py +195 -123
  174. snowflake/ml/modeling/neighbors/nearest_centroid.py +195 -123
  175. snowflake/ml/modeling/neighbors/nearest_neighbors.py +195 -123
  176. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +195 -123
  177. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +195 -123
  178. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +195 -123
  179. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +195 -123
  180. snowflake/ml/modeling/neural_network/mlp_classifier.py +195 -123
  181. snowflake/ml/modeling/neural_network/mlp_regressor.py +195 -123
  182. snowflake/ml/modeling/pipeline/pipeline.py +4 -4
  183. snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
  184. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
  185. snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
  186. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
  187. snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
  188. snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
  189. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +1 -5
  190. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
  191. snowflake/ml/modeling/preprocessing/polynomial_features.py +195 -123
  192. snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
  193. snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
  194. snowflake/ml/modeling/semi_supervised/label_propagation.py +195 -123
  195. snowflake/ml/modeling/semi_supervised/label_spreading.py +195 -123
  196. snowflake/ml/modeling/svm/linear_svc.py +195 -123
  197. snowflake/ml/modeling/svm/linear_svr.py +195 -123
  198. snowflake/ml/modeling/svm/nu_svc.py +195 -123
  199. snowflake/ml/modeling/svm/nu_svr.py +195 -123
  200. snowflake/ml/modeling/svm/svc.py +195 -123
  201. snowflake/ml/modeling/svm/svr.py +195 -123
  202. snowflake/ml/modeling/tree/decision_tree_classifier.py +195 -123
  203. snowflake/ml/modeling/tree/decision_tree_regressor.py +195 -123
  204. snowflake/ml/modeling/tree/extra_tree_classifier.py +195 -123
  205. snowflake/ml/modeling/tree/extra_tree_regressor.py +195 -123
  206. snowflake/ml/modeling/xgboost/xgb_classifier.py +195 -123
  207. snowflake/ml/modeling/xgboost/xgb_regressor.py +195 -123
  208. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +195 -123
  209. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +195 -123
  210. snowflake/ml/registry/_manager/model_manager.py +5 -1
  211. snowflake/ml/registry/model_registry.py +99 -26
  212. snowflake/ml/registry/registry.py +3 -2
  213. snowflake/ml/version.py +1 -1
  214. {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/METADATA +94 -55
  215. {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/RECORD +218 -212
  216. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
  217. {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/LICENSE.txt +0 -0
  218. {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/WHEEL +0 -0
  219. {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
33
33
  BatchInferenceKwargsTypedDict,
34
34
  ScoreKwargsTypedDict
35
35
  )
36
+ from snowflake.ml.model._signatures import utils as model_signature_utils
37
+ from snowflake.ml.model.model_signature import (
38
+ BaseFeatureSpec,
39
+ DataType,
40
+ FeatureSpec,
41
+ ModelSignature,
42
+ _infer_signature,
43
+ _rename_signature_with_snowflake_identifiers,
44
+ )
36
45
 
37
46
  from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
38
47
 
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
43
52
  validate_sklearn_args,
44
53
  )
45
54
 
46
- from snowflake.ml.model.model_signature import (
47
- DataType,
48
- FeatureSpec,
49
- ModelSignature,
50
- _infer_signature,
51
- _rename_signature_with_snowflake_identifiers,
52
- BaseFeatureSpec,
53
- )
54
- from snowflake.ml.model._signatures import utils as model_signature_utils
55
-
56
55
  _PROJECT = "ModelDevelopment"
57
56
  # Derive subproject from module name by removing "sklearn"
58
57
  # and converting module name from underscore to CamelCase
@@ -333,12 +332,7 @@ class LogisticRegression(BaseTransformer):
333
332
  )
334
333
  return selected_cols
335
334
 
336
- @telemetry.send_api_usage_telemetry(
337
- project=_PROJECT,
338
- subproject=_SUBPROJECT,
339
- custom_tags=dict([("autogen", True)]),
340
- )
341
- def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "LogisticRegression":
335
+ def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "LogisticRegression":
342
336
  """Fit the model according to the given training data
343
337
  For more details on this function, see [sklearn.linear_model.LogisticRegression.fit]
344
338
  (https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html#sklearn.linear_model.LogisticRegression.fit)
@@ -365,12 +359,14 @@ class LogisticRegression(BaseTransformer):
365
359
 
366
360
  self._snowpark_cols = dataset.select(self.input_cols).columns
367
361
 
368
- # If we are already in a stored procedure, no need to kick off another one.
362
+ # If we are already in a stored procedure, no need to kick off another one.
369
363
  if SNOWML_SPROC_ENV in os.environ:
370
364
  statement_params = telemetry.get_function_usage_statement_params(
371
365
  project=_PROJECT,
372
366
  subproject=_SUBPROJECT,
373
- function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), LogisticRegression.__class__.__name__),
367
+ function_name=telemetry.get_statement_params_full_func_name(
368
+ inspect.currentframe(), LogisticRegression.__class__.__name__
369
+ ),
374
370
  api_calls=[Session.call],
375
371
  custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
376
372
  )
@@ -391,7 +387,7 @@ class LogisticRegression(BaseTransformer):
391
387
  )
392
388
  self._sklearn_object = model_trainer.train()
393
389
  self._is_fitted = True
394
- self._get_model_signatures(dataset)
390
+ self._generate_model_signatures(dataset)
395
391
  return self
396
392
 
397
393
  def _batch_inference_validate_snowpark(
@@ -467,7 +463,9 @@ class LogisticRegression(BaseTransformer):
467
463
  # when it is classifier, infer the datatype from label columns
468
464
  if expected_type_inferred == "" and 'predict' in self.model_signatures:
469
465
  # Batch inference takes a single expected output column type. Use the first columns type for now.
470
- label_cols_signatures = [row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols]
466
+ label_cols_signatures = [
467
+ row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
468
+ ]
471
469
  if len(label_cols_signatures) == 0:
472
470
  error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
473
471
  raise exceptions.SnowflakeMLException(
@@ -475,25 +473,22 @@ class LogisticRegression(BaseTransformer):
475
473
  original_exception=ValueError(error_str),
476
474
  )
477
475
 
478
- expected_type_inferred = convert_sp_to_sf_type(
479
- label_cols_signatures[0].as_snowpark_type()
480
- )
476
+ expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
481
477
 
482
478
  self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
483
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
479
+ assert isinstance(
480
+ dataset._session, Session
481
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
484
482
 
485
483
  transform_kwargs = dict(
486
- session = dataset._session,
487
- dependencies = self._deps,
488
- drop_input_cols = self._drop_input_cols,
489
- expected_output_cols_type = expected_type_inferred,
484
+ session=dataset._session,
485
+ dependencies=self._deps,
486
+ drop_input_cols=self._drop_input_cols,
487
+ expected_output_cols_type=expected_type_inferred,
490
488
  )
491
489
 
492
490
  elif isinstance(dataset, pd.DataFrame):
493
- transform_kwargs = dict(
494
- snowpark_input_cols = self._snowpark_cols,
495
- drop_input_cols = self._drop_input_cols
496
- )
491
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
497
492
 
498
493
  transform_handlers = ModelTransformerBuilder.build(
499
494
  dataset=dataset,
@@ -533,7 +528,7 @@ class LogisticRegression(BaseTransformer):
533
528
  Transformed dataset.
534
529
  """
535
530
  super()._check_dataset_type(dataset)
536
- inference_method="transform"
531
+ inference_method = "transform"
537
532
 
538
533
  # This dictionary contains optional kwargs for batch inference. These kwargs
539
534
  # are specific to the type of dataset used.
@@ -570,17 +565,14 @@ class LogisticRegression(BaseTransformer):
570
565
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
571
566
 
572
567
  transform_kwargs = dict(
573
- session = dataset._session,
574
- dependencies = self._deps,
575
- drop_input_cols = self._drop_input_cols,
576
- expected_output_cols_type = expected_dtype,
568
+ session=dataset._session,
569
+ dependencies=self._deps,
570
+ drop_input_cols=self._drop_input_cols,
571
+ expected_output_cols_type=expected_dtype,
577
572
  )
578
573
 
579
574
  elif isinstance(dataset, pd.DataFrame):
580
- transform_kwargs = dict(
581
- snowpark_input_cols = self._snowpark_cols,
582
- drop_input_cols = self._drop_input_cols
583
- )
575
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
584
576
 
585
577
  transform_handlers = ModelTransformerBuilder.build(
586
578
  dataset=dataset,
@@ -599,7 +591,11 @@ class LogisticRegression(BaseTransformer):
599
591
  return output_df
600
592
 
601
593
  @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
602
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_predict_",) -> Union[DataFrame, pd.DataFrame]:
594
+ def fit_predict(
595
+ self,
596
+ dataset: Union[DataFrame, pd.DataFrame],
597
+ output_cols_prefix: str = "fit_predict_",
598
+ ) -> Union[DataFrame, pd.DataFrame]:
603
599
  """ Method not supported for this class.
604
600
 
605
601
 
@@ -624,7 +620,9 @@ class LogisticRegression(BaseTransformer):
624
620
  )
625
621
  output_result, fitted_estimator = model_trainer.train_fit_predict(
626
622
  drop_input_cols=self._drop_input_cols,
627
- expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
623
+ expected_output_cols_list=(
624
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
625
+ ),
628
626
  )
629
627
  self._sklearn_object = fitted_estimator
630
628
  self._is_fitted = True
@@ -641,6 +639,62 @@ class LogisticRegression(BaseTransformer):
641
639
  assert self._sklearn_object is not None
642
640
  return self._sklearn_object.embedding_
643
641
 
642
+
643
+ def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
644
+ """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
645
+ Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
646
+ """
647
+ output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
648
+ # The following condition is introduced for kneighbors methods, and not used in other methods
649
+ if output_cols:
650
+ output_cols = [
651
+ identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
652
+ for c in output_cols
653
+ ]
654
+ elif getattr(self._sklearn_object, "classes_", None) is None:
655
+ output_cols = [output_cols_prefix]
656
+ elif self._sklearn_object is not None:
657
+ classes = self._sklearn_object.classes_
658
+ if isinstance(classes, numpy.ndarray):
659
+ output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
660
+ elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
661
+ # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
662
+ output_cols = []
663
+ for i, cl in enumerate(classes):
664
+ # For binary classification, there is only one output column for each class
665
+ # ndarray as the two classes are complementary.
666
+ if len(cl) == 2:
667
+ output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
668
+ else:
669
+ output_cols.extend([
670
+ f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
671
+ ])
672
+ else:
673
+ output_cols = []
674
+
675
+ # Make sure column names are valid snowflake identifiers.
676
+ assert output_cols is not None # Make MyPy happy
677
+ rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
678
+
679
+ return rv
680
+
681
+ def _align_expected_output_names(
682
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
683
+ ) -> List[str]:
684
+ # in case the inferred output column names dimension is different
685
+ # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
686
+ output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
687
+ output_df_columns = list(output_df_pd.columns)
688
+ output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
689
+ if self.sample_weight_col:
690
+ output_df_columns_set -= set(self.sample_weight_col)
691
+ # if the dimension of inferred output column names is correct; use it
692
+ if len(expected_output_cols_list) == len(output_df_columns_set):
693
+ return expected_output_cols_list
694
+ # otherwise, use the sklearn estimator's output
695
+ else:
696
+ return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
697
+
644
698
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
645
699
  @telemetry.send_api_usage_telemetry(
646
700
  project=_PROJECT,
@@ -673,24 +727,28 @@ class LogisticRegression(BaseTransformer):
673
727
  # are specific to the type of dataset used.
674
728
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
675
729
 
730
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
731
+
676
732
  if isinstance(dataset, DataFrame):
677
733
  self._deps = self._batch_inference_validate_snowpark(
678
734
  dataset=dataset,
679
735
  inference_method=inference_method,
680
736
  )
681
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
737
+ assert isinstance(
738
+ dataset._session, Session
739
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
682
740
  transform_kwargs = dict(
683
741
  session=dataset._session,
684
742
  dependencies=self._deps,
685
- drop_input_cols = self._drop_input_cols,
743
+ drop_input_cols=self._drop_input_cols,
686
744
  expected_output_cols_type="float",
687
745
  )
746
+ expected_output_cols = self._align_expected_output_names(
747
+ inference_method, dataset, expected_output_cols, output_cols_prefix
748
+ )
688
749
 
689
750
  elif isinstance(dataset, pd.DataFrame):
690
- transform_kwargs = dict(
691
- snowpark_input_cols = self._snowpark_cols,
692
- drop_input_cols = self._drop_input_cols
693
- )
751
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
694
752
 
695
753
  transform_handlers = ModelTransformerBuilder.build(
696
754
  dataset=dataset,
@@ -702,7 +760,7 @@ class LogisticRegression(BaseTransformer):
702
760
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
703
761
  inference_method=inference_method,
704
762
  input_cols=self.input_cols,
705
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
763
+ expected_output_cols=expected_output_cols,
706
764
  **transform_kwargs
707
765
  )
708
766
  return output_df
@@ -734,7 +792,8 @@ class LogisticRegression(BaseTransformer):
734
792
  Output dataset with log probability of the sample for each class in the model.
735
793
  """
736
794
  super()._check_dataset_type(dataset)
737
- inference_method="predict_log_proba"
795
+ inference_method = "predict_log_proba"
796
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
738
797
 
739
798
  # This dictionary contains optional kwargs for batch inference. These kwargs
740
799
  # are specific to the type of dataset used.
@@ -745,18 +804,20 @@ class LogisticRegression(BaseTransformer):
745
804
  dataset=dataset,
746
805
  inference_method=inference_method,
747
806
  )
748
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
807
+ assert isinstance(
808
+ dataset._session, Session
809
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
749
810
  transform_kwargs = dict(
750
811
  session=dataset._session,
751
812
  dependencies=self._deps,
752
- drop_input_cols = self._drop_input_cols,
813
+ drop_input_cols=self._drop_input_cols,
753
814
  expected_output_cols_type="float",
754
815
  )
816
+ expected_output_cols = self._align_expected_output_names(
817
+ inference_method, dataset, expected_output_cols, output_cols_prefix
818
+ )
755
819
  elif isinstance(dataset, pd.DataFrame):
756
- transform_kwargs = dict(
757
- snowpark_input_cols = self._snowpark_cols,
758
- drop_input_cols = self._drop_input_cols
759
- )
820
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
760
821
 
761
822
  transform_handlers = ModelTransformerBuilder.build(
762
823
  dataset=dataset,
@@ -769,7 +830,7 @@ class LogisticRegression(BaseTransformer):
769
830
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
770
831
  inference_method=inference_method,
771
832
  input_cols=self.input_cols,
772
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
833
+ expected_output_cols=expected_output_cols,
773
834
  **transform_kwargs
774
835
  )
775
836
  return output_df
@@ -797,30 +858,34 @@ class LogisticRegression(BaseTransformer):
797
858
  Output dataset with results of the decision function for the samples in input dataset.
798
859
  """
799
860
  super()._check_dataset_type(dataset)
800
- inference_method="decision_function"
861
+ inference_method = "decision_function"
801
862
 
802
863
  # This dictionary contains optional kwargs for batch inference. These kwargs
803
864
  # are specific to the type of dataset used.
804
865
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
805
866
 
867
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
868
+
806
869
  if isinstance(dataset, DataFrame):
807
870
  self._deps = self._batch_inference_validate_snowpark(
808
871
  dataset=dataset,
809
872
  inference_method=inference_method,
810
873
  )
811
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
874
+ assert isinstance(
875
+ dataset._session, Session
876
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
812
877
  transform_kwargs = dict(
813
878
  session=dataset._session,
814
879
  dependencies=self._deps,
815
- drop_input_cols = self._drop_input_cols,
880
+ drop_input_cols=self._drop_input_cols,
816
881
  expected_output_cols_type="float",
817
882
  )
883
+ expected_output_cols = self._align_expected_output_names(
884
+ inference_method, dataset, expected_output_cols, output_cols_prefix
885
+ )
818
886
 
819
887
  elif isinstance(dataset, pd.DataFrame):
820
- transform_kwargs = dict(
821
- snowpark_input_cols = self._snowpark_cols,
822
- drop_input_cols = self._drop_input_cols
823
- )
888
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
824
889
 
825
890
  transform_handlers = ModelTransformerBuilder.build(
826
891
  dataset=dataset,
@@ -833,7 +898,7 @@ class LogisticRegression(BaseTransformer):
833
898
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
834
899
  inference_method=inference_method,
835
900
  input_cols=self.input_cols,
836
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
901
+ expected_output_cols=expected_output_cols,
837
902
  **transform_kwargs
838
903
  )
839
904
  return output_df
@@ -862,12 +927,14 @@ class LogisticRegression(BaseTransformer):
862
927
  Output dataset with probability of the sample for each class in the model.
863
928
  """
864
929
  super()._check_dataset_type(dataset)
865
- inference_method="score_samples"
930
+ inference_method = "score_samples"
866
931
 
867
932
  # This dictionary contains optional kwargs for batch inference. These kwargs
868
933
  # are specific to the type of dataset used.
869
934
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
870
935
 
936
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
937
+
871
938
  if isinstance(dataset, DataFrame):
872
939
  self._deps = self._batch_inference_validate_snowpark(
873
940
  dataset=dataset,
@@ -880,6 +947,9 @@ class LogisticRegression(BaseTransformer):
880
947
  drop_input_cols = self._drop_input_cols,
881
948
  expected_output_cols_type="float",
882
949
  )
950
+ expected_output_cols = self._align_expected_output_names(
951
+ inference_method, dataset, expected_output_cols, output_cols_prefix
952
+ )
883
953
 
884
954
  elif isinstance(dataset, pd.DataFrame):
885
955
  transform_kwargs = dict(
@@ -898,7 +968,7 @@ class LogisticRegression(BaseTransformer):
898
968
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
899
969
  inference_method=inference_method,
900
970
  input_cols=self.input_cols,
901
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
971
+ expected_output_cols=expected_output_cols,
902
972
  **transform_kwargs
903
973
  )
904
974
  return output_df
@@ -1045,50 +1115,84 @@ class LogisticRegression(BaseTransformer):
1045
1115
  )
1046
1116
  return output_df
1047
1117
 
1118
+
1119
+
1120
+ def to_sklearn(self) -> Any:
1121
+ """Get sklearn.linear_model.LogisticRegression object.
1122
+ """
1123
+ if self._sklearn_object is None:
1124
+ self._sklearn_object = self._create_sklearn_object()
1125
+ return self._sklearn_object
1126
+
1127
+ def to_xgboost(self) -> Any:
1128
+ raise exceptions.SnowflakeMLException(
1129
+ error_code=error_codes.METHOD_NOT_ALLOWED,
1130
+ original_exception=AttributeError(
1131
+ modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1132
+ "to_xgboost()",
1133
+ "to_sklearn()"
1134
+ )
1135
+ ),
1136
+ )
1137
+
1138
+ def to_lightgbm(self) -> Any:
1139
+ raise exceptions.SnowflakeMLException(
1140
+ error_code=error_codes.METHOD_NOT_ALLOWED,
1141
+ original_exception=AttributeError(
1142
+ modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1143
+ "to_lightgbm()",
1144
+ "to_sklearn()"
1145
+ )
1146
+ ),
1147
+ )
1048
1148
 
1049
- def _get_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
1149
+ def _get_dependencies(self) -> List[str]:
1150
+ return self._deps
1151
+
1152
+
1153
+ def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
1050
1154
  self._model_signature_dict = dict()
1051
1155
 
1052
1156
  PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
1053
1157
 
1054
- inputs = list(_infer_signature(dataset[self.input_cols], "input"))
1158
+ inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
1055
1159
  outputs: List[BaseFeatureSpec] = []
1056
1160
  if hasattr(self, "predict"):
1057
1161
  # keep mypy happy
1058
- assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
1162
+ assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
1059
1163
  # For classifier, the type of predict is the same as the type of label
1060
- if self._sklearn_object._estimator_type == 'classifier':
1061
- # label columns is the desired type for output
1164
+ if self._sklearn_object._estimator_type == "classifier":
1165
+ # label columns is the desired type for output
1062
1166
  outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
1063
1167
  # rename the output columns
1064
1168
  outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
1065
- self._model_signature_dict["predict"] = ModelSignature(inputs,
1066
- ([] if self._drop_input_cols else inputs)
1067
- + outputs)
1169
+ self._model_signature_dict["predict"] = ModelSignature(
1170
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1171
+ )
1068
1172
  # For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
1069
1173
  # For outlier models, returns -1 for outliers and 1 for inliers.
1070
- # Clusterer returns int64 cluster labels.
1174
+ # Clusterer returns int64 cluster labels.
1071
1175
  elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
1072
1176
  outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
1073
- self._model_signature_dict["predict"] = ModelSignature(inputs,
1074
- ([] if self._drop_input_cols else inputs)
1075
- + outputs)
1076
-
1177
+ self._model_signature_dict["predict"] = ModelSignature(
1178
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1179
+ )
1180
+
1077
1181
  # For regressor, the type of predict is float64
1078
- elif self._sklearn_object._estimator_type == 'regressor':
1182
+ elif self._sklearn_object._estimator_type == "regressor":
1079
1183
  outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
1080
- self._model_signature_dict["predict"] = ModelSignature(inputs,
1081
- ([] if self._drop_input_cols else inputs)
1082
- + outputs)
1083
-
1184
+ self._model_signature_dict["predict"] = ModelSignature(
1185
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1186
+ )
1187
+
1084
1188
  for prob_func in PROB_FUNCTIONS:
1085
1189
  if hasattr(self, prob_func):
1086
1190
  output_cols_prefix: str = f"{prob_func}_"
1087
1191
  output_column_names = self._get_output_column_names(output_cols_prefix)
1088
1192
  outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
1089
- self._model_signature_dict[prob_func] = ModelSignature(inputs,
1090
- ([] if self._drop_input_cols else inputs)
1091
- + outputs)
1193
+ self._model_signature_dict[prob_func] = ModelSignature(
1194
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1195
+ )
1092
1196
 
1093
1197
  # Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
1094
1198
  items = list(self._model_signature_dict.items())
@@ -1101,10 +1205,10 @@ class LogisticRegression(BaseTransformer):
1101
1205
  """Returns model signature of current class.
1102
1206
 
1103
1207
  Raises:
1104
- exceptions.SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
1208
+ SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
1105
1209
 
1106
1210
  Returns:
1107
- Dict[str, ModelSignature]: each method and its input output signature
1211
+ Dict with each method and its input output signature
1108
1212
  """
1109
1213
  if self._model_signature_dict is None:
1110
1214
  raise exceptions.SnowflakeMLException(
@@ -1112,35 +1216,3 @@ class LogisticRegression(BaseTransformer):
1112
1216
  original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
1113
1217
  )
1114
1218
  return self._model_signature_dict
1115
-
1116
- def to_sklearn(self) -> Any:
1117
- """Get sklearn.linear_model.LogisticRegression object.
1118
- """
1119
- if self._sklearn_object is None:
1120
- self._sklearn_object = self._create_sklearn_object()
1121
- return self._sklearn_object
1122
-
1123
- def to_xgboost(self) -> Any:
1124
- raise exceptions.SnowflakeMLException(
1125
- error_code=error_codes.METHOD_NOT_ALLOWED,
1126
- original_exception=AttributeError(
1127
- modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1128
- "to_xgboost()",
1129
- "to_sklearn()"
1130
- )
1131
- ),
1132
- )
1133
-
1134
- def to_lightgbm(self) -> Any:
1135
- raise exceptions.SnowflakeMLException(
1136
- error_code=error_codes.METHOD_NOT_ALLOWED,
1137
- original_exception=AttributeError(
1138
- modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1139
- "to_lightgbm()",
1140
- "to_sklearn()"
1141
- )
1142
- ),
1143
- )
1144
-
1145
- def _get_dependencies(self) -> List[str]:
1146
- return self._deps