snowflake-ml-python 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. snowflake/ml/_internal/env_utils.py +2 -1
  2. snowflake/ml/_internal/file_utils.py +35 -40
  3. snowflake/ml/_internal/telemetry.py +5 -8
  4. snowflake/ml/_internal/utils/identifier.py +74 -7
  5. snowflake/ml/_internal/utils/uri.py +7 -2
  6. snowflake/ml/model/_core_requirements.py +1 -1
  7. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +15 -0
  8. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +259 -0
  9. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +89 -0
  10. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +24 -0
  11. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +118 -0
  12. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +40 -0
  13. snowflake/ml/model/_deploy_client/snowservice/deploy.py +199 -0
  14. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +88 -0
  15. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +24 -0
  16. snowflake/ml/model/_deploy_client/utils/constants.py +47 -0
  17. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +178 -0
  18. snowflake/ml/model/_deploy_client/warehouse/deploy.py +25 -28
  19. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +7 -4
  20. snowflake/ml/model/_deployer.py +14 -27
  21. snowflake/ml/model/_env.py +4 -4
  22. snowflake/ml/model/_handlers/_base.py +3 -1
  23. snowflake/ml/model/_handlers/custom.py +14 -2
  24. snowflake/ml/model/_handlers/pytorch.py +186 -0
  25. snowflake/ml/model/_handlers/sklearn.py +14 -8
  26. snowflake/ml/model/_handlers/snowmlmodel.py +14 -9
  27. snowflake/ml/model/_handlers/torchscript.py +180 -0
  28. snowflake/ml/model/_handlers/xgboost.py +19 -9
  29. snowflake/ml/model/_model.py +27 -21
  30. snowflake/ml/model/_model_meta.py +33 -19
  31. snowflake/ml/model/model_signature.py +446 -66
  32. snowflake/ml/model/type_hints.py +28 -15
  33. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +79 -43
  34. snowflake/ml/modeling/cluster/affinity_propagation.py +79 -43
  35. snowflake/ml/modeling/cluster/agglomerative_clustering.py +79 -43
  36. snowflake/ml/modeling/cluster/birch.py +79 -43
  37. snowflake/ml/modeling/cluster/bisecting_k_means.py +79 -43
  38. snowflake/ml/modeling/cluster/dbscan.py +79 -43
  39. snowflake/ml/modeling/cluster/feature_agglomeration.py +79 -43
  40. snowflake/ml/modeling/cluster/k_means.py +79 -43
  41. snowflake/ml/modeling/cluster/mean_shift.py +79 -43
  42. snowflake/ml/modeling/cluster/mini_batch_k_means.py +79 -43
  43. snowflake/ml/modeling/cluster/optics.py +79 -43
  44. snowflake/ml/modeling/cluster/spectral_biclustering.py +79 -43
  45. snowflake/ml/modeling/cluster/spectral_clustering.py +79 -43
  46. snowflake/ml/modeling/cluster/spectral_coclustering.py +79 -43
  47. snowflake/ml/modeling/compose/column_transformer.py +79 -43
  48. snowflake/ml/modeling/compose/transformed_target_regressor.py +79 -43
  49. snowflake/ml/modeling/covariance/elliptic_envelope.py +79 -43
  50. snowflake/ml/modeling/covariance/empirical_covariance.py +79 -43
  51. snowflake/ml/modeling/covariance/graphical_lasso.py +79 -43
  52. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +79 -43
  53. snowflake/ml/modeling/covariance/ledoit_wolf.py +79 -43
  54. snowflake/ml/modeling/covariance/min_cov_det.py +79 -43
  55. snowflake/ml/modeling/covariance/oas.py +79 -43
  56. snowflake/ml/modeling/covariance/shrunk_covariance.py +79 -43
  57. snowflake/ml/modeling/decomposition/dictionary_learning.py +79 -43
  58. snowflake/ml/modeling/decomposition/factor_analysis.py +79 -43
  59. snowflake/ml/modeling/decomposition/fast_ica.py +79 -43
  60. snowflake/ml/modeling/decomposition/incremental_pca.py +79 -43
  61. snowflake/ml/modeling/decomposition/kernel_pca.py +79 -43
  62. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +79 -43
  63. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +79 -43
  64. snowflake/ml/modeling/decomposition/pca.py +79 -43
  65. snowflake/ml/modeling/decomposition/sparse_pca.py +79 -43
  66. snowflake/ml/modeling/decomposition/truncated_svd.py +79 -43
  67. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +79 -43
  68. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +79 -43
  69. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +79 -43
  70. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +79 -43
  71. snowflake/ml/modeling/ensemble/bagging_classifier.py +79 -43
  72. snowflake/ml/modeling/ensemble/bagging_regressor.py +79 -43
  73. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +79 -43
  74. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +79 -43
  75. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +79 -43
  76. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +79 -43
  77. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +79 -43
  78. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +79 -43
  79. snowflake/ml/modeling/ensemble/isolation_forest.py +79 -43
  80. snowflake/ml/modeling/ensemble/random_forest_classifier.py +79 -43
  81. snowflake/ml/modeling/ensemble/random_forest_regressor.py +79 -43
  82. snowflake/ml/modeling/ensemble/stacking_regressor.py +79 -43
  83. snowflake/ml/modeling/ensemble/voting_classifier.py +79 -43
  84. snowflake/ml/modeling/ensemble/voting_regressor.py +79 -43
  85. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +79 -43
  86. snowflake/ml/modeling/feature_selection/select_fdr.py +79 -43
  87. snowflake/ml/modeling/feature_selection/select_fpr.py +79 -43
  88. snowflake/ml/modeling/feature_selection/select_fwe.py +79 -43
  89. snowflake/ml/modeling/feature_selection/select_k_best.py +79 -43
  90. snowflake/ml/modeling/feature_selection/select_percentile.py +79 -43
  91. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +79 -43
  92. snowflake/ml/modeling/feature_selection/variance_threshold.py +79 -43
  93. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +79 -43
  94. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +79 -43
  95. snowflake/ml/modeling/impute/iterative_imputer.py +79 -43
  96. snowflake/ml/modeling/impute/knn_imputer.py +79 -43
  97. snowflake/ml/modeling/impute/missing_indicator.py +79 -43
  98. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +79 -43
  99. snowflake/ml/modeling/kernel_approximation/nystroem.py +79 -43
  100. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +79 -43
  101. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +79 -43
  102. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +79 -43
  103. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +79 -43
  104. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +79 -43
  105. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +79 -43
  106. snowflake/ml/modeling/linear_model/ard_regression.py +79 -43
  107. snowflake/ml/modeling/linear_model/bayesian_ridge.py +79 -43
  108. snowflake/ml/modeling/linear_model/elastic_net.py +79 -43
  109. snowflake/ml/modeling/linear_model/elastic_net_cv.py +79 -43
  110. snowflake/ml/modeling/linear_model/gamma_regressor.py +79 -43
  111. snowflake/ml/modeling/linear_model/huber_regressor.py +79 -43
  112. snowflake/ml/modeling/linear_model/lars.py +79 -43
  113. snowflake/ml/modeling/linear_model/lars_cv.py +79 -43
  114. snowflake/ml/modeling/linear_model/lasso.py +79 -43
  115. snowflake/ml/modeling/linear_model/lasso_cv.py +79 -43
  116. snowflake/ml/modeling/linear_model/lasso_lars.py +79 -43
  117. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +79 -43
  118. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +79 -43
  119. snowflake/ml/modeling/linear_model/linear_regression.py +79 -43
  120. snowflake/ml/modeling/linear_model/logistic_regression.py +79 -43
  121. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +79 -43
  122. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +79 -43
  123. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +79 -43
  124. snowflake/ml/modeling/linear_model/multi_task_lasso.py +79 -43
  125. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +79 -43
  126. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +79 -43
  127. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +79 -43
  128. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +79 -43
  129. snowflake/ml/modeling/linear_model/perceptron.py +79 -43
  130. snowflake/ml/modeling/linear_model/poisson_regressor.py +79 -43
  131. snowflake/ml/modeling/linear_model/ransac_regressor.py +79 -43
  132. snowflake/ml/modeling/linear_model/ridge.py +79 -43
  133. snowflake/ml/modeling/linear_model/ridge_classifier.py +79 -43
  134. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +79 -43
  135. snowflake/ml/modeling/linear_model/ridge_cv.py +79 -43
  136. snowflake/ml/modeling/linear_model/sgd_classifier.py +79 -43
  137. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +79 -43
  138. snowflake/ml/modeling/linear_model/sgd_regressor.py +79 -43
  139. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +79 -43
  140. snowflake/ml/modeling/linear_model/tweedie_regressor.py +79 -43
  141. snowflake/ml/modeling/manifold/isomap.py +79 -43
  142. snowflake/ml/modeling/manifold/mds.py +79 -43
  143. snowflake/ml/modeling/manifold/spectral_embedding.py +79 -43
  144. snowflake/ml/modeling/manifold/tsne.py +79 -43
  145. snowflake/ml/modeling/metrics/classification.py +6 -1
  146. snowflake/ml/modeling/metrics/regression.py +517 -9
  147. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +79 -43
  148. snowflake/ml/modeling/mixture/gaussian_mixture.py +79 -43
  149. snowflake/ml/modeling/model_selection/grid_search_cv.py +79 -43
  150. snowflake/ml/modeling/model_selection/randomized_search_cv.py +79 -43
  151. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +79 -43
  152. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +79 -43
  153. snowflake/ml/modeling/multiclass/output_code_classifier.py +79 -43
  154. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +79 -43
  155. snowflake/ml/modeling/naive_bayes/categorical_nb.py +79 -43
  156. snowflake/ml/modeling/naive_bayes/complement_nb.py +79 -43
  157. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -43
  158. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +79 -43
  159. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +79 -43
  160. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +79 -43
  161. snowflake/ml/modeling/neighbors/kernel_density.py +79 -43
  162. snowflake/ml/modeling/neighbors/local_outlier_factor.py +79 -43
  163. snowflake/ml/modeling/neighbors/nearest_centroid.py +79 -43
  164. snowflake/ml/modeling/neighbors/nearest_neighbors.py +79 -43
  165. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +79 -43
  166. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +79 -43
  167. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +79 -43
  168. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +79 -43
  169. snowflake/ml/modeling/neural_network/mlp_classifier.py +79 -43
  170. snowflake/ml/modeling/neural_network/mlp_regressor.py +79 -43
  171. snowflake/ml/modeling/pipeline/pipeline.py +24 -0
  172. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +18 -19
  173. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +2 -0
  174. snowflake/ml/modeling/preprocessing/polynomial_features.py +79 -43
  175. snowflake/ml/modeling/semi_supervised/label_propagation.py +79 -43
  176. snowflake/ml/modeling/semi_supervised/label_spreading.py +79 -43
  177. snowflake/ml/modeling/svm/linear_svc.py +79 -43
  178. snowflake/ml/modeling/svm/linear_svr.py +79 -43
  179. snowflake/ml/modeling/svm/nu_svc.py +79 -43
  180. snowflake/ml/modeling/svm/nu_svr.py +79 -43
  181. snowflake/ml/modeling/svm/svc.py +79 -43
  182. snowflake/ml/modeling/svm/svr.py +79 -43
  183. snowflake/ml/modeling/tree/decision_tree_classifier.py +79 -43
  184. snowflake/ml/modeling/tree/decision_tree_regressor.py +79 -43
  185. snowflake/ml/modeling/tree/extra_tree_classifier.py +79 -43
  186. snowflake/ml/modeling/tree/extra_tree_regressor.py +79 -43
  187. snowflake/ml/modeling/xgboost/xgb_classifier.py +79 -43
  188. snowflake/ml/modeling/xgboost/xgb_regressor.py +79 -43
  189. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +79 -43
  190. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +79 -43
  191. snowflake/ml/registry/model_registry.py +123 -121
  192. snowflake/ml/version.py +1 -1
  193. {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/METADATA +50 -8
  194. snowflake_ml_python-1.0.3.dist-info/RECORD +259 -0
  195. snowflake_ml_python-1.0.1.dist-info/RECORD +0 -246
  196. {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/WHEEL +0 -0
@@ -7,6 +7,7 @@
7
7
  #
8
8
  import inspect
9
9
  import os
10
+ import posixpath
10
11
  from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
11
12
  from uuid import uuid4
12
13
 
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
27
28
  from snowflake.snowpark import DataFrame, Session
28
29
  from snowflake.snowpark.functions import pandas_udf, sproc
29
30
  from snowflake.snowpark.types import PandasSeries
31
+ from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
32
 
31
33
  from snowflake.ml.model.model_signature import (
32
34
  DataType,
@@ -352,7 +354,6 @@ class MLPClassifier(BaseTransformer):
352
354
  sample_weight_col: Optional[str] = None,
353
355
  ) -> None:
354
356
  super().__init__()
355
- self.id = str(uuid4()).replace("-", "_").upper()
356
357
  deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
357
358
 
358
359
  self._deps = list(deps)
@@ -394,6 +395,15 @@ class MLPClassifier(BaseTransformer):
394
395
  self.set_drop_input_cols(drop_input_cols)
395
396
  self.set_sample_weight_col(sample_weight_col)
396
397
 
398
+ def _get_rand_id(self) -> str:
399
+ """
400
+ Generate random id to be used in sproc and stage names.
401
+
402
+ Returns:
403
+ Random id string usable in sproc, table, and stage names.
404
+ """
405
+ return str(uuid4()).replace("-", "_").upper()
406
+
397
407
  def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
398
408
  """
399
409
  Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
@@ -472,7 +482,7 @@ class MLPClassifier(BaseTransformer):
472
482
  cp.dump(self._sklearn_object, local_transform_file)
473
483
 
474
484
  # Create temp stage to run fit.
475
- transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.id)
485
+ transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
476
486
  stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
477
487
  SqlResultValidator(
478
488
  session=session,
@@ -485,11 +495,12 @@ class MLPClassifier(BaseTransformer):
485
495
  expected_value=f"Stage area {transform_stage_name} successfully created."
486
496
  ).validate()
487
497
 
488
- stage_transform_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
498
+ # Use posixpath to construct stage paths
499
+ stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
500
+ stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
489
501
  local_result_file_name = get_temp_file_path()
490
- stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
491
502
 
492
- fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.id)
503
+ fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
493
504
  statement_params = telemetry.get_function_usage_statement_params(
494
505
  project=_PROJECT,
495
506
  subproject=_SUBPROJECT,
@@ -515,6 +526,7 @@ class MLPClassifier(BaseTransformer):
515
526
  replace=True,
516
527
  session=session,
517
528
  statement_params=statement_params,
529
+ anonymous=True
518
530
  )
519
531
  def fit_wrapper_sproc(
520
532
  session: Session,
@@ -523,7 +535,8 @@ class MLPClassifier(BaseTransformer):
523
535
  stage_result_file_name: str,
524
536
  input_cols: List[str],
525
537
  label_cols: List[str],
526
- sample_weight_col: Optional[str]
538
+ sample_weight_col: Optional[str],
539
+ statement_params: Dict[str, str]
527
540
  ) -> str:
528
541
  import cloudpickle as cp
529
542
  import numpy as np
@@ -590,15 +603,15 @@ class MLPClassifier(BaseTransformer):
590
603
  api_calls=[Session.call],
591
604
  custom_tags=dict([("autogen", True)]),
592
605
  )
593
- sproc_export_file_name = session.call(
594
- fit_sproc_name,
606
+ sproc_export_file_name = fit_wrapper_sproc(
607
+ session,
595
608
  query,
596
609
  stage_transform_file_name,
597
610
  stage_result_file_name,
598
611
  identifier.get_unescaped_names(self.input_cols),
599
612
  identifier.get_unescaped_names(self.label_cols),
600
613
  identifier.get_unescaped_names(self.sample_weight_col),
601
- statement_params=statement_params,
614
+ statement_params,
602
615
  )
603
616
 
604
617
  if "|" in sproc_export_file_name:
@@ -608,7 +621,7 @@ class MLPClassifier(BaseTransformer):
608
621
  print("\n".join(fields[1:]))
609
622
 
610
623
  session.file.get(
611
- os.path.join(stage_result_file_name, sproc_export_file_name),
624
+ posixpath.join(stage_result_file_name, sproc_export_file_name),
612
625
  local_result_file_name,
613
626
  statement_params=statement_params
614
627
  )
@@ -654,7 +667,7 @@ class MLPClassifier(BaseTransformer):
654
667
 
655
668
  # Register vectorized UDF for batch inference
656
669
  batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
657
- safe_id=self.id, method=inference_method)
670
+ safe_id=self._get_rand_id(), method=inference_method)
658
671
 
659
672
  # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
660
673
  # will try to pickle all of self which fails.
@@ -746,7 +759,7 @@ class MLPClassifier(BaseTransformer):
746
759
  return transformed_pandas_df.to_dict("records")
747
760
 
748
761
  batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
749
- safe_id=self.id
762
+ safe_id=self._get_rand_id()
750
763
  )
751
764
 
752
765
  pass_through_columns = self._get_pass_through_columns(dataset)
@@ -802,26 +815,37 @@ class MLPClassifier(BaseTransformer):
802
815
  # input cols need to match unquoted / quoted
803
816
  input_cols = self.input_cols
804
817
  unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
818
+ quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
805
819
 
806
820
  estimator = self._sklearn_object
807
821
 
808
- input_df = dataset[input_cols] # Select input columns with quoted column names.
809
- if hasattr(estimator, "feature_names_in_"):
810
- missing_features = []
811
- for i, f in enumerate(getattr(estimator, "feature_names_in_")):
812
- if i >= len(input_cols) or (input_cols[i] != f and unquoted_input_cols[i] != f):
813
- missing_features.append(f)
814
-
815
- if len(missing_features) > 0:
816
- raise ValueError(
817
- "The feature names should match with those that were passed during fit.\n"
818
- f"Features seen during fit call but not present in the input: {missing_features}\n"
819
- f"Features in the input dataframe : {input_cols}\n"
820
- )
821
- input_df.columns = getattr(estimator, "feature_names_in_")
822
- else:
823
- # Just rename the column names to unquoted identifiers.
824
- input_df.columns = unquoted_input_cols # Replace the quoted columns identifier with unquoted column ids.
822
+ features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
823
+ missing_features = []
824
+ features_in_dataset = set(dataset.columns)
825
+ columns_to_select = []
826
+ for i, f in enumerate(features_required_by_estimator):
827
+ if (
828
+ i >= len(input_cols)
829
+ or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
830
+ or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
831
+ and quoted_input_cols[i] not in features_in_dataset)
832
+ ):
833
+ missing_features.append(f)
834
+ elif input_cols[i] in features_in_dataset:
835
+ columns_to_select.append(input_cols[i])
836
+ elif unquoted_input_cols[i] in features_in_dataset:
837
+ columns_to_select.append(unquoted_input_cols[i])
838
+ else:
839
+ columns_to_select.append(quoted_input_cols[i])
840
+
841
+ if len(missing_features) > 0:
842
+ raise ValueError(
843
+ "The feature names should match with those that were passed during fit.\n"
844
+ f"Features seen during fit call but not present in the input: {missing_features}\n"
845
+ f"Features in the input dataframe : {input_cols}\n"
846
+ )
847
+ input_df = dataset[columns_to_select]
848
+ input_df.columns = features_required_by_estimator
825
849
 
826
850
  transformed_numpy_array = getattr(estimator, inference_method)(
827
851
  input_df
@@ -902,11 +926,18 @@ class MLPClassifier(BaseTransformer):
902
926
  Transformed dataset.
903
927
  """
904
928
  if isinstance(dataset, DataFrame):
929
+ expected_type_inferred = ""
930
+ # when it is classifier, infer the datatype from label columns
931
+ if expected_type_inferred == "" and 'predict' in self.model_signatures:
932
+ expected_type_inferred = convert_sp_to_sf_type(
933
+ self.model_signatures['predict'].outputs[0].as_snowpark_type()
934
+ )
935
+
905
936
  output_df = self._batch_inference(
906
937
  dataset=dataset,
907
938
  inference_method="predict",
908
939
  expected_output_cols_list=self.output_cols,
909
- expected_output_cols_type="",
940
+ expected_output_cols_type=expected_type_inferred,
910
941
  )
911
942
  elif isinstance(dataset, pd.DataFrame):
912
943
  output_df = self._sklearn_inference(
@@ -977,10 +1008,10 @@ class MLPClassifier(BaseTransformer):
977
1008
 
978
1009
  def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
979
1010
  """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
980
- Returns an empty list if current object is not a classifier or not yet fitted.
1011
+ Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
981
1012
  """
982
1013
  if getattr(self._sklearn_object, "classes_", None) is None:
983
- return []
1014
+ return [output_cols_prefix]
984
1015
 
985
1016
  classes = self._sklearn_object.classes_
986
1017
  if isinstance(classes, numpy.ndarray):
@@ -1209,7 +1240,7 @@ class MLPClassifier(BaseTransformer):
1209
1240
  cp.dump(self._sklearn_object, local_score_file)
1210
1241
 
1211
1242
  # Create temp stage to run score.
1212
- score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
1243
+ score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
1213
1244
  session = dataset._session
1214
1245
  stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
1215
1246
  SqlResultValidator(
@@ -1223,8 +1254,9 @@ class MLPClassifier(BaseTransformer):
1223
1254
  expected_value=f"Stage area {score_stage_name} successfully created."
1224
1255
  ).validate()
1225
1256
 
1226
- stage_score_file_name = os.path.join(score_stage_name, os.path.basename(local_score_file_name))
1227
- score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
1257
+ # Use posixpath to construct stage paths
1258
+ stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
1259
+ score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
1228
1260
  statement_params = telemetry.get_function_usage_statement_params(
1229
1261
  project=_PROJECT,
1230
1262
  subproject=_SUBPROJECT,
@@ -1250,6 +1282,7 @@ class MLPClassifier(BaseTransformer):
1250
1282
  replace=True,
1251
1283
  session=session,
1252
1284
  statement_params=statement_params,
1285
+ anonymous=True
1253
1286
  )
1254
1287
  def score_wrapper_sproc(
1255
1288
  session: Session,
@@ -1257,7 +1290,8 @@ class MLPClassifier(BaseTransformer):
1257
1290
  stage_score_file_name: str,
1258
1291
  input_cols: List[str],
1259
1292
  label_cols: List[str],
1260
- sample_weight_col: Optional[str]
1293
+ sample_weight_col: Optional[str],
1294
+ statement_params: Dict[str, str]
1261
1295
  ) -> float:
1262
1296
  import cloudpickle as cp
1263
1297
  import numpy as np
@@ -1307,14 +1341,14 @@ class MLPClassifier(BaseTransformer):
1307
1341
  api_calls=[Session.call],
1308
1342
  custom_tags=dict([("autogen", True)]),
1309
1343
  )
1310
- score = session.call(
1311
- score_sproc_name,
1344
+ score = score_wrapper_sproc(
1345
+ session,
1312
1346
  query,
1313
1347
  stage_score_file_name,
1314
1348
  identifier.get_unescaped_names(self.input_cols),
1315
1349
  identifier.get_unescaped_names(self.label_cols),
1316
1350
  identifier.get_unescaped_names(self.sample_weight_col),
1317
- statement_params=statement_params,
1351
+ statement_params,
1318
1352
  )
1319
1353
 
1320
1354
  cleanup_temp_files([local_score_file_name])
@@ -1332,18 +1366,20 @@ class MLPClassifier(BaseTransformer):
1332
1366
  if self._sklearn_object._estimator_type == 'classifier':
1333
1367
  outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
1334
1368
  outputs = _rename_features(outputs, self.output_cols) # rename the output columns
1335
- self._model_signature_dict["predict"] = ModelSignature(inputs, outputs)
1369
+ self._model_signature_dict["predict"] = ModelSignature(inputs,
1370
+ ([] if self._drop_input_cols else inputs) + outputs)
1336
1371
  # For regressor, the type of predict is float64
1337
1372
  elif self._sklearn_object._estimator_type == 'regressor':
1338
1373
  outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
1339
- self._model_signature_dict["predict"] = ModelSignature(inputs, outputs)
1340
-
1374
+ self._model_signature_dict["predict"] = ModelSignature(inputs,
1375
+ ([] if self._drop_input_cols else inputs) + outputs)
1341
1376
  for prob_func in PROB_FUNCTIONS:
1342
1377
  if hasattr(self, prob_func):
1343
1378
  output_cols_prefix: str = f"{prob_func}_"
1344
1379
  output_column_names = self._get_output_column_names(output_cols_prefix)
1345
1380
  outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
1346
- self._model_signature_dict[prob_func] = ModelSignature(inputs, outputs)
1381
+ self._model_signature_dict[prob_func] = ModelSignature(inputs,
1382
+ ([] if self._drop_input_cols else inputs) + outputs)
1347
1383
 
1348
1384
  @property
1349
1385
  def model_signatures(self) -> Dict[str, ModelSignature]:
@@ -7,6 +7,7 @@
7
7
  #
8
8
  import inspect
9
9
  import os
10
+ import posixpath
10
11
  from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
11
12
  from uuid import uuid4
12
13
 
@@ -27,6 +28,7 @@ from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get
27
28
  from snowflake.snowpark import DataFrame, Session
28
29
  from snowflake.snowpark.functions import pandas_udf, sproc
29
30
  from snowflake.snowpark.types import PandasSeries
31
+ from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
32
 
31
33
  from snowflake.ml.model.model_signature import (
32
34
  DataType,
@@ -348,7 +350,6 @@ class MLPRegressor(BaseTransformer):
348
350
  sample_weight_col: Optional[str] = None,
349
351
  ) -> None:
350
352
  super().__init__()
351
- self.id = str(uuid4()).replace("-", "_").upper()
352
353
  deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
353
354
 
354
355
  self._deps = list(deps)
@@ -390,6 +391,15 @@ class MLPRegressor(BaseTransformer):
390
391
  self.set_drop_input_cols(drop_input_cols)
391
392
  self.set_sample_weight_col(sample_weight_col)
392
393
 
394
+ def _get_rand_id(self) -> str:
395
+ """
396
+ Generate random id to be used in sproc and stage names.
397
+
398
+ Returns:
399
+ Random id string usable in sproc, table, and stage names.
400
+ """
401
+ return str(uuid4()).replace("-", "_").upper()
402
+
393
403
  def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
394
404
  """
395
405
  Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
@@ -468,7 +478,7 @@ class MLPRegressor(BaseTransformer):
468
478
  cp.dump(self._sklearn_object, local_transform_file)
469
479
 
470
480
  # Create temp stage to run fit.
471
- transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self.id)
481
+ transform_stage_name = "SNOWML_TRANSFORM_{safe_id}".format(safe_id=self._get_rand_id())
472
482
  stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
473
483
  SqlResultValidator(
474
484
  session=session,
@@ -481,11 +491,12 @@ class MLPRegressor(BaseTransformer):
481
491
  expected_value=f"Stage area {transform_stage_name} successfully created."
482
492
  ).validate()
483
493
 
484
- stage_transform_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
494
+ # Use posixpath to construct stage paths
495
+ stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
496
+ stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
485
497
  local_result_file_name = get_temp_file_path()
486
- stage_result_file_name = os.path.join(transform_stage_name, os.path.basename(local_transform_file_name))
487
498
 
488
- fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self.id)
499
+ fit_sproc_name = "SNOWML_FIT_{safe_id}".format(safe_id=self._get_rand_id())
489
500
  statement_params = telemetry.get_function_usage_statement_params(
490
501
  project=_PROJECT,
491
502
  subproject=_SUBPROJECT,
@@ -511,6 +522,7 @@ class MLPRegressor(BaseTransformer):
511
522
  replace=True,
512
523
  session=session,
513
524
  statement_params=statement_params,
525
+ anonymous=True
514
526
  )
515
527
  def fit_wrapper_sproc(
516
528
  session: Session,
@@ -519,7 +531,8 @@ class MLPRegressor(BaseTransformer):
519
531
  stage_result_file_name: str,
520
532
  input_cols: List[str],
521
533
  label_cols: List[str],
522
- sample_weight_col: Optional[str]
534
+ sample_weight_col: Optional[str],
535
+ statement_params: Dict[str, str]
523
536
  ) -> str:
524
537
  import cloudpickle as cp
525
538
  import numpy as np
@@ -586,15 +599,15 @@ class MLPRegressor(BaseTransformer):
586
599
  api_calls=[Session.call],
587
600
  custom_tags=dict([("autogen", True)]),
588
601
  )
589
- sproc_export_file_name = session.call(
590
- fit_sproc_name,
602
+ sproc_export_file_name = fit_wrapper_sproc(
603
+ session,
591
604
  query,
592
605
  stage_transform_file_name,
593
606
  stage_result_file_name,
594
607
  identifier.get_unescaped_names(self.input_cols),
595
608
  identifier.get_unescaped_names(self.label_cols),
596
609
  identifier.get_unescaped_names(self.sample_weight_col),
597
- statement_params=statement_params,
610
+ statement_params,
598
611
  )
599
612
 
600
613
  if "|" in sproc_export_file_name:
@@ -604,7 +617,7 @@ class MLPRegressor(BaseTransformer):
604
617
  print("\n".join(fields[1:]))
605
618
 
606
619
  session.file.get(
607
- os.path.join(stage_result_file_name, sproc_export_file_name),
620
+ posixpath.join(stage_result_file_name, sproc_export_file_name),
608
621
  local_result_file_name,
609
622
  statement_params=statement_params
610
623
  )
@@ -650,7 +663,7 @@ class MLPRegressor(BaseTransformer):
650
663
 
651
664
  # Register vectorized UDF for batch inference
652
665
  batch_inference_udf_name = "SNOWML_BATCH_INFERENCE_{safe_id}_{method}".format(
653
- safe_id=self.id, method=inference_method)
666
+ safe_id=self._get_rand_id(), method=inference_method)
654
667
 
655
668
  # Need to do this since if we use self._sklearn_object directly in the UDF, Snowpark
656
669
  # will try to pickle all of self which fails.
@@ -742,7 +755,7 @@ class MLPRegressor(BaseTransformer):
742
755
  return transformed_pandas_df.to_dict("records")
743
756
 
744
757
  batch_inference_table_name = "SNOWML_BATCH_INFERENCE_INPUT_TABLE_{safe_id}".format(
745
- safe_id=self.id
758
+ safe_id=self._get_rand_id()
746
759
  )
747
760
 
748
761
  pass_through_columns = self._get_pass_through_columns(dataset)
@@ -798,26 +811,37 @@ class MLPRegressor(BaseTransformer):
798
811
  # input cols need to match unquoted / quoted
799
812
  input_cols = self.input_cols
800
813
  unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
814
+ quoted_input_cols = identifier.get_escaped_names(unquoted_input_cols)
801
815
 
802
816
  estimator = self._sklearn_object
803
817
 
804
- input_df = dataset[input_cols] # Select input columns with quoted column names.
805
- if hasattr(estimator, "feature_names_in_"):
806
- missing_features = []
807
- for i, f in enumerate(getattr(estimator, "feature_names_in_")):
808
- if i >= len(input_cols) or (input_cols[i] != f and unquoted_input_cols[i] != f):
809
- missing_features.append(f)
810
-
811
- if len(missing_features) > 0:
812
- raise ValueError(
813
- "The feature names should match with those that were passed during fit.\n"
814
- f"Features seen during fit call but not present in the input: {missing_features}\n"
815
- f"Features in the input dataframe : {input_cols}\n"
816
- )
817
- input_df.columns = getattr(estimator, "feature_names_in_")
818
- else:
819
- # Just rename the column names to unquoted identifiers.
820
- input_df.columns = unquoted_input_cols # Replace the quoted columns identifier with unquoted column ids.
818
+ features_required_by_estimator = getattr(estimator, "feature_names_in_") if hasattr(estimator, "feature_names_in_") else unquoted_input_cols
819
+ missing_features = []
820
+ features_in_dataset = set(dataset.columns)
821
+ columns_to_select = []
822
+ for i, f in enumerate(features_required_by_estimator):
823
+ if (
824
+ i >= len(input_cols)
825
+ or (input_cols[i] != f and unquoted_input_cols[i] != f and quoted_input_cols[i] != f)
826
+ or (input_cols[i] not in features_in_dataset and unquoted_input_cols[i] not in features_in_dataset
827
+ and quoted_input_cols[i] not in features_in_dataset)
828
+ ):
829
+ missing_features.append(f)
830
+ elif input_cols[i] in features_in_dataset:
831
+ columns_to_select.append(input_cols[i])
832
+ elif unquoted_input_cols[i] in features_in_dataset:
833
+ columns_to_select.append(unquoted_input_cols[i])
834
+ else:
835
+ columns_to_select.append(quoted_input_cols[i])
836
+
837
+ if len(missing_features) > 0:
838
+ raise ValueError(
839
+ "The feature names should match with those that were passed during fit.\n"
840
+ f"Features seen during fit call but not present in the input: {missing_features}\n"
841
+ f"Features in the input dataframe : {input_cols}\n"
842
+ )
843
+ input_df = dataset[columns_to_select]
844
+ input_df.columns = features_required_by_estimator
821
845
 
822
846
  transformed_numpy_array = getattr(estimator, inference_method)(
823
847
  input_df
@@ -898,11 +922,18 @@ class MLPRegressor(BaseTransformer):
898
922
  Transformed dataset.
899
923
  """
900
924
  if isinstance(dataset, DataFrame):
925
+ expected_type_inferred = "float"
926
+ # when it is classifier, infer the datatype from label columns
927
+ if expected_type_inferred == "" and 'predict' in self.model_signatures:
928
+ expected_type_inferred = convert_sp_to_sf_type(
929
+ self.model_signatures['predict'].outputs[0].as_snowpark_type()
930
+ )
931
+
901
932
  output_df = self._batch_inference(
902
933
  dataset=dataset,
903
934
  inference_method="predict",
904
935
  expected_output_cols_list=self.output_cols,
905
- expected_output_cols_type="float",
936
+ expected_output_cols_type=expected_type_inferred,
906
937
  )
907
938
  elif isinstance(dataset, pd.DataFrame):
908
939
  output_df = self._sklearn_inference(
@@ -973,10 +1004,10 @@ class MLPRegressor(BaseTransformer):
973
1004
 
974
1005
  def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
975
1006
  """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
976
- Returns an empty list if current object is not a classifier or not yet fitted.
1007
+ Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
977
1008
  """
978
1009
  if getattr(self._sklearn_object, "classes_", None) is None:
979
- return []
1010
+ return [output_cols_prefix]
980
1011
 
981
1012
  classes = self._sklearn_object.classes_
982
1013
  if isinstance(classes, numpy.ndarray):
@@ -1201,7 +1232,7 @@ class MLPRegressor(BaseTransformer):
1201
1232
  cp.dump(self._sklearn_object, local_score_file)
1202
1233
 
1203
1234
  # Create temp stage to run score.
1204
- score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
1235
+ score_stage_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
1205
1236
  session = dataset._session
1206
1237
  stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {score_stage_name};"
1207
1238
  SqlResultValidator(
@@ -1215,8 +1246,9 @@ class MLPRegressor(BaseTransformer):
1215
1246
  expected_value=f"Stage area {score_stage_name} successfully created."
1216
1247
  ).validate()
1217
1248
 
1218
- stage_score_file_name = os.path.join(score_stage_name, os.path.basename(local_score_file_name))
1219
- score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self.id)
1249
+ # Use posixpath to construct stage paths
1250
+ stage_score_file_name = posixpath.join(score_stage_name, os.path.basename(local_score_file_name))
1251
+ score_sproc_name = "SNOWML_SCORE_{safe_id}".format(safe_id=self._get_rand_id())
1220
1252
  statement_params = telemetry.get_function_usage_statement_params(
1221
1253
  project=_PROJECT,
1222
1254
  subproject=_SUBPROJECT,
@@ -1242,6 +1274,7 @@ class MLPRegressor(BaseTransformer):
1242
1274
  replace=True,
1243
1275
  session=session,
1244
1276
  statement_params=statement_params,
1277
+ anonymous=True
1245
1278
  )
1246
1279
  def score_wrapper_sproc(
1247
1280
  session: Session,
@@ -1249,7 +1282,8 @@ class MLPRegressor(BaseTransformer):
1249
1282
  stage_score_file_name: str,
1250
1283
  input_cols: List[str],
1251
1284
  label_cols: List[str],
1252
- sample_weight_col: Optional[str]
1285
+ sample_weight_col: Optional[str],
1286
+ statement_params: Dict[str, str]
1253
1287
  ) -> float:
1254
1288
  import cloudpickle as cp
1255
1289
  import numpy as np
@@ -1299,14 +1333,14 @@ class MLPRegressor(BaseTransformer):
1299
1333
  api_calls=[Session.call],
1300
1334
  custom_tags=dict([("autogen", True)]),
1301
1335
  )
1302
- score = session.call(
1303
- score_sproc_name,
1336
+ score = score_wrapper_sproc(
1337
+ session,
1304
1338
  query,
1305
1339
  stage_score_file_name,
1306
1340
  identifier.get_unescaped_names(self.input_cols),
1307
1341
  identifier.get_unescaped_names(self.label_cols),
1308
1342
  identifier.get_unescaped_names(self.sample_weight_col),
1309
- statement_params=statement_params,
1343
+ statement_params,
1310
1344
  )
1311
1345
 
1312
1346
  cleanup_temp_files([local_score_file_name])
@@ -1324,18 +1358,20 @@ class MLPRegressor(BaseTransformer):
1324
1358
  if self._sklearn_object._estimator_type == 'classifier':
1325
1359
  outputs = _infer_signature(dataset[self.label_cols], "output") # label columns is the desired type for output
1326
1360
  outputs = _rename_features(outputs, self.output_cols) # rename the output columns
1327
- self._model_signature_dict["predict"] = ModelSignature(inputs, outputs)
1361
+ self._model_signature_dict["predict"] = ModelSignature(inputs,
1362
+ ([] if self._drop_input_cols else inputs) + outputs)
1328
1363
  # For regressor, the type of predict is float64
1329
1364
  elif self._sklearn_object._estimator_type == 'regressor':
1330
1365
  outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
1331
- self._model_signature_dict["predict"] = ModelSignature(inputs, outputs)
1332
-
1366
+ self._model_signature_dict["predict"] = ModelSignature(inputs,
1367
+ ([] if self._drop_input_cols else inputs) + outputs)
1333
1368
  for prob_func in PROB_FUNCTIONS:
1334
1369
  if hasattr(self, prob_func):
1335
1370
  output_cols_prefix: str = f"{prob_func}_"
1336
1371
  output_column_names = self._get_output_column_names(output_cols_prefix)
1337
1372
  outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
1338
- self._model_signature_dict[prob_func] = ModelSignature(inputs, outputs)
1373
+ self._model_signature_dict[prob_func] = ModelSignature(inputs,
1374
+ ([] if self._drop_input_cols else inputs) + outputs)
1339
1375
 
1340
1376
  @property
1341
1377
  def model_signatures(self) -> Dict[str, ModelSignature]:
@@ -14,6 +14,7 @@ from sklearn.utils import metaestimators
14
14
 
15
15
  from snowflake import snowpark
16
16
  from snowflake.ml._internal import telemetry
17
+ from snowflake.ml.model.model_signature import ModelSignature, _infer_signature
17
18
  from snowflake.ml.modeling.framework import _utils, base
18
19
 
19
20
  _PROJECT = "ModelDevelopment"
@@ -103,6 +104,8 @@ class Pipeline(base.BaseTransformer):
103
104
  self._transformers_to_input_indices: Dict[str, List[int]] = {}
104
105
  self._is_convertable_to_sklearn = True
105
106
 
107
+ self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
108
+
106
109
  deps: Set[str] = {f"pandas=={pd.__version__}", f"scikit-learn=={skversion}"}
107
110
  for _, obj in steps:
108
111
  if isinstance(obj, base.BaseTransformer):
@@ -241,6 +244,7 @@ class Pipeline(base.BaseTransformer):
241
244
  step_name=estimator[0], all_cols=all_cols, input_cols=estimator[1].get_input_cols()
242
245
  )
243
246
 
247
+ self._get_model_signatures(dataset=dataset)
244
248
  self._is_fitted = True
245
249
  return self
246
250
 
@@ -309,6 +313,7 @@ class Pipeline(base.BaseTransformer):
309
313
  res = estimator[1].fit(transformed_dataset).transform(transformed_dataset)
310
314
  return res
311
315
 
316
+ self._get_model_signatures(dataset=dataset)
312
317
  self._is_fitted = True
313
318
  return transformed_dataset
314
319
 
@@ -346,6 +351,7 @@ class Pipeline(base.BaseTransformer):
346
351
  else:
347
352
  transformed_dataset = estimator[1].fit(transformed_dataset).predict(transformed_dataset)
348
353
 
354
+ self._get_model_signatures(dataset=dataset)
349
355
  self._is_fitted = True
350
356
  return transformed_dataset
351
357
 
@@ -559,3 +565,21 @@ class Pipeline(base.BaseTransformer):
559
565
 
560
566
  def _get_dependencies(self) -> List[str]:
561
567
  return self._deps
568
+
569
+ def _get_model_signatures(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> None:
570
+ self._model_signature_dict = dict()
571
+
572
+ input_columns = self._get_sanitized_list_of_columns(dataset.columns)
573
+ inputs_signature = _infer_signature(dataset[input_columns], "input")
574
+
575
+ estimator_step = self._get_estimator()
576
+ if estimator_step:
577
+ estimator_signatures = estimator_step[1].model_signatures
578
+ for method, signature in estimator_signatures.items():
579
+ self._model_signature_dict[method] = ModelSignature(inputs=inputs_signature, outputs=signature.outputs)
580
+
581
+ @property
582
+ def model_signatures(self) -> Dict[str, ModelSignature]:
583
+ if self._model_signature_dict is None:
584
+ raise RuntimeError("Estimator not fitted before accessing property model_signatures! ")
585
+ return self._model_signature_dict