snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,12 @@
1
1
  #!/usr/bin/env python3
2
+ import inspect
3
+ import os
4
+ import posixpath
5
+ import tempfile
2
6
  from itertools import chain
3
7
  from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
4
8
 
9
+ import cloudpickle as cp
5
10
  import numpy as np
6
11
  import pandas as pd
7
12
  from sklearn import __version__ as skversion, pipeline
@@ -10,14 +15,20 @@ from sklearn.preprocessing import FunctionTransformer
10
15
  from sklearn.utils import metaestimators
11
16
 
12
17
  from snowflake import snowpark
13
- from snowflake.ml._internal import telemetry
18
+ from snowflake.ml._internal import file_utils, telemetry
14
19
  from snowflake.ml._internal.exceptions import error_codes, exceptions
15
- from snowflake.ml._internal.utils import snowpark_dataframe_utils
20
+ from snowflake.ml._internal.utils import snowpark_dataframe_utils, temp_file_utils
16
21
  from snowflake.ml.model.model_signature import ModelSignature, _infer_signature
22
+ from snowflake.ml.modeling._internal.model_transformer_builder import (
23
+ ModelTransformerBuilder,
24
+ )
17
25
  from snowflake.ml.modeling.framework import _utils, base
26
+ from snowflake.snowpark import Session, functions as F
27
+ from snowflake.snowpark._internal import utils as snowpark_utils
18
28
 
19
29
  _PROJECT = "ModelDevelopment"
20
30
  _SUBPROJECT = "Framework"
31
+ IN_ML_RUNTIME_ENV_VAR = "IN_SPCS_ML_RUNTIME"
21
32
 
22
33
 
23
34
  def _final_step_has(attr: str) -> Callable[..., bool]:
@@ -113,6 +124,8 @@ class Pipeline(base.BaseTransformer):
113
124
  if isinstance(obj, base.BaseTransformer):
114
125
  deps = deps | set(obj._get_dependencies())
115
126
  self._deps = list(deps)
127
+ self._sklearn_object = None
128
+ self.label_cols = self._get_label_cols()
116
129
 
117
130
  @staticmethod
118
131
  def _is_estimator(obj: object) -> bool:
@@ -147,6 +160,33 @@ class Pipeline(base.BaseTransformer):
147
160
  self._n_features_in = []
148
161
  self._transformers_to_input_indices = {}
149
162
 
163
def _is_convertible_to_sklearn_object(self) -> bool:
    """Checks if the pipeline can be converted to a native sklearn pipeline.

    - We can not create an sklearn pipeline if its label or sample weight columns are
      modified in the pipeline.
    - We can not create an sklearn pipeline if any of its steps cannot be converted to an sklearn object.
    - We can not create an sklearn pipeline if input columns are specified in any step other than
      the first step.

    Steps that do not expose ``_is_convertible_to_sklearn_object`` or ``get_input_cols`` (for example
    native sklearn estimators) are treated as convertible and as having no input columns set.

    Returns:
        True if the pipeline can be converted to a native sklearn pipeline, else False.
    """
    if self._is_pipeline_modifying_label_or_sample_weight():
        return False

    # Steps that can answer the question themselves (e.g. nested pipelines) must also be convertible.
    for _, step in self.steps:
        is_convertible = getattr(step, "_is_convertible_to_sklearn_object", None)
        if callable(is_convertible) and not is_convertible():
            return False

    # No step after the first one may declare its own input columns; only falsy values (None or [])
    # are acceptable. A step without `get_input_cols` cannot declare any, so it is accepted.
    for _, step in self.steps[1:]:
        get_input_cols = getattr(step, "get_input_cols", None)
        if callable(get_input_cols) and get_input_cols():
            return False
    return True
189
+
150
190
  def _is_pipeline_modifying_label_or_sample_weight(self) -> bool:
151
191
  """
152
192
  Checks if pipeline is modifying label or sample_weight columns.
@@ -214,27 +254,167 @@ class Pipeline(base.BaseTransformer):
214
254
  self._append_step_feature_consumption_info(
215
255
  step_name=name, all_cols=transformed_dataset.columns[:], input_cols=trans.get_input_cols()
216
256
  )
217
- if has_callable_attr(trans, "fit_transform"):
218
- transformed_dataset = trans.fit_transform(transformed_dataset)
219
- else:
220
- trans.fit(transformed_dataset)
221
- transformed_dataset = trans.transform(transformed_dataset)
257
+ trans.fit(transformed_dataset)
258
+ transformed_dataset = trans.transform(transformed_dataset)
222
259
 
223
260
  return transformed_dataset
224
261
 
262
def _upload_model_to_stage(self, stage_name: str, estimator: object, session: Session) -> Tuple[str, str]:
    """
    Util method to pickle and upload the model to a temp Snowflake stage.

    Args:
        stage_name: Stage name to save model.
        estimator: the pipeline estimator itself
        session: Session object

    Returns:
        A tuple of stage locations: where the pickled input model was uploaded for training, and where
        the trained model (response from the training sproc) should be stored. Both are currently the
        same location, ``<stage_name>/<basename of the local temp file>``, which is passed to
        ``session.file.put`` as the stage location.
    """
    # Create a temp file and dump the estimator to that file.
    local_transform_file_name = temp_file_utils.get_temp_file_path()
    try:
        with open(local_transform_file_name, mode="w+b") as local_transform_file:
            cp.dump(estimator, local_transform_file)

        # Use posixpath to construct stage paths: stage paths use "/" regardless of the local OS.
        stage_location = posixpath.join(stage_name, os.path.basename(local_transform_file_name))
        stage_transform_file_name = stage_location
        stage_result_file_name = stage_location

        # Put the locally serialized estimator on the stage.
        session.file.put(
            local_transform_file_name,
            stage_transform_file_name,
            auto_compress=False,
            overwrite=True,
        )
    finally:
        # Remove the local pickle even when serialization or the upload fails.
        temp_file_utils.cleanup_temp_files([local_transform_file_name])

    return (stage_transform_file_name, stage_result_file_name)
294
+
295
def _fit_snowpark_dataframe_within_one_sproc(self, session: Session, dataset: snowpark.DataFrame) -> None:
    """Fit the whole pipeline inside a single temporary stored procedure.

    The pipeline is pickled and uploaded to a temporary stage. A temporary, anonymous stored procedure
    then re-runs the SQL queries that produced ``dataset``, fits the unpickled pipeline on the resulting
    pandas DataFrame, and uploads the fitted pipeline back to the stage. Finally the fitted pipeline is
    downloaded and its attributes are copied onto ``self``.

    Args:
        session: Snowpark session used to create the stage, move files, and register/call the procedure.
        dataset: Snowpark DataFrame to fit on. Only its generating SQL queries are sent to the procedure.
    """
    # Extract queries that generated the dataframe. We will need to pass them to the stored procedure.
    sql_queries = dataset.queries["queries"]

    # Zip the current snowml package so the procedure can import it. The zip lives in a temporary
    # directory, so registration (which reads the zip) must happen inside this `with` block.
    with tempfile.TemporaryDirectory() as tmpdir:
        snowml_zip_module_filename = os.path.join(tmpdir, "snowflake-ml-python.zip")
        file_utils.zip_python_package(snowml_zip_module_filename, "snowflake.ml")
        imports = [snowml_zip_module_filename]

        sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
        required_deps = self._deps
        sproc_statement_params = telemetry.get_function_usage_statement_params(
            project=_PROJECT,
            subproject="PIPELINE",
            function_name=telemetry.get_statement_params_full_func_name(
                inspect.currentframe(), self.__class__.__name__
            ),
            api_calls=[F.sproc],
        )
        transform_stage_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
        stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
        session.sql(stage_creation_query).collect()
        (stage_estimator_file_name, stage_result_file_name) = self._upload_model_to_stage(
            transform_stage_name, self, session
        )

        def pipeline_within_one_sproc(
            session: Session,
            sql_queries: List[str],
            stage_estimator_file_name: str,
            stage_result_file_name: str,
            sproc_statement_params: Dict[str, str],
        ) -> str:
            """Procedure body: fit the staged pipeline and stage the fitted result.

            Returns:
                Base name of the fitted-pipeline file uploaded under ``stage_result_file_name``.
            """
            # Imported locally because this function is serialized and executed in the procedure.
            import os

            import cloudpickle as cp
            import pandas as pd

            # Replay every query that built the input DataFrame; only the last one yields the rows.
            for query in sql_queries[:-1]:
                _ = session.sql(query).collect(statement_params=sproc_statement_params)
            sp_df = session.sql(sql_queries[-1])
            df: pd.DataFrame = sp_df.to_pandas(statement_params=sproc_statement_params)
            df.columns = sp_df.columns

            # `get` downloads into a local directory; the pickled pipeline is expected to be its only file.
            local_estimator_file_name = temp_file_utils.get_temp_file_path()
            session.file.get(
                stage_estimator_file_name,
                local_estimator_file_name,
                statement_params=sproc_statement_params,
            )
            local_estimator_file_path = os.path.join(
                local_estimator_file_name, os.listdir(local_estimator_file_name)[0]
            )
            with open(local_estimator_file_path, mode="r+b") as local_estimator_file_obj:
                estimator = cp.load(local_estimator_file_obj)

            estimator.fit(df)

            local_result_file_name = temp_file_utils.get_temp_file_path()
            with open(local_result_file_name, mode="w+b") as local_result_file_obj:
                cp.dump(estimator, local_result_file_obj)

            session.file.put(
                local_result_file_name,
                stage_result_file_name,
                auto_compress=False,
                overwrite=True,
                statement_params=sproc_statement_params,
            )

            return str(os.path.basename(local_result_file_name))

        # The procedure is anonymous, so the handle returned by `register` is the only way to invoke it.
        # Calling `pipeline_within_one_sproc` directly would run the fit on the client instead.
        pipeline_sproc = session.sproc.register(
            func=pipeline_within_one_sproc,
            is_permanent=False,
            name=sproc_name,
            packages=required_deps,  # type: ignore[arg-type]
            replace=True,
            session=session,
            anonymous=True,
            imports=imports,  # type: ignore[arg-type]
            statement_params=sproc_statement_params,
        )

        sproc_export_file_name: str = pipeline_sproc(
            session,
            sql_queries,
            stage_estimator_file_name,
            stage_result_file_name,
            sproc_statement_params,
        )

    # Download the fitted pipeline from where the procedure uploaded it (under `stage_result_file_name`).
    local_result_file_name = temp_file_utils.get_temp_file_path()
    try:
        session.file.get(
            posixpath.join(stage_result_file_name, sproc_export_file_name),
            local_result_file_name,
            statement_params=sproc_statement_params,
        )
        with open(os.path.join(local_result_file_name, sproc_export_file_name), mode="r+b") as result_file_obj:
            fit_estimator = cp.load(result_file_obj)
    finally:
        temp_file_utils.cleanup_temp_files([local_result_file_name])

    # Copy the fitted state of the round-tripped pipeline onto this instance.
    for key, val in vars(fit_estimator).items():
        setattr(self, key, val)
400
+
225
401
  @telemetry.send_api_usage_telemetry(
226
402
  project=_PROJECT,
227
403
  subproject=_SUBPROJECT,
228
404
  )
229
- def fit(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> "Pipeline":
405
+ def fit(self, dataset: Union[snowpark.DataFrame, pd.DataFrame], squash: Optional[bool] = False) -> "Pipeline":
230
406
  """
231
407
  Fit the entire pipeline using the dataset.
232
408
 
233
409
  Args:
234
410
  dataset: Input dataset.
411
+ squash: Run the whole pipeline within a stored procedure
235
412
 
236
413
  Returns:
237
414
  Fitted pipeline.
415
+
416
+ Raises:
417
+ ValueError: A pipeline incompatible with sklearn is used on MLRS
238
418
  """
239
419
 
240
420
  self._validate_steps()
@@ -243,19 +423,33 @@ class Pipeline(base.BaseTransformer):
243
423
  if isinstance(dataset, snowpark.DataFrame)
244
424
  else dataset
245
425
  )
246
- transformed_dataset = self._fit_transform_dataset(dataset)
247
426
 
248
- estimator = self._get_estimator()
249
- if estimator:
250
- all_cols = transformed_dataset.columns[:]
251
- estimator[1].fit(transformed_dataset)
427
+ if self._can_be_trained_in_ml_runtime(dataset):
428
+ if not self._is_convertible_to_sklearn_object():
429
+ raise ValueError("This pipeline cannot be converted to an sklearn pipeline.")
430
+ self._fit_ml_runtime(dataset)
252
431
 
253
- self._append_step_feature_consumption_info(
254
- step_name=estimator[0], all_cols=all_cols, input_cols=estimator[1].get_input_cols()
255
- )
432
+ elif squash and isinstance(dataset, snowpark.DataFrame):
433
+ session = dataset._session
434
+ assert session is not None
435
+ self._fit_snowpark_dataframe_within_one_sproc(session=session, dataset=dataset)
436
+
437
+ else:
438
+ transformed_dataset = self._fit_transform_dataset(dataset)
439
+
440
+ estimator = self._get_estimator()
441
+ if estimator:
442
+ all_cols = transformed_dataset.columns[:]
443
+ estimator[1].fit(transformed_dataset)
444
+
445
+ self._append_step_feature_consumption_info(
446
+ step_name=estimator[0], all_cols=all_cols, input_cols=estimator[1].get_input_cols()
447
+ )
448
+
449
+ self._generate_model_signatures(dataset=dataset)
256
450
 
257
- self._generate_model_signatures(dataset=dataset)
258
451
  self._is_fitted = True
452
+
259
453
  return self
260
454
 
261
455
  @metaestimators.available_if(_final_step_has("transform")) # type: ignore[misc]
@@ -280,6 +474,22 @@ class Pipeline(base.BaseTransformer):
280
474
  else dataset
281
475
  )
282
476
 
477
+ if self._sklearn_object is not None:
478
+ handler = ModelTransformerBuilder.build(
479
+ dataset=dataset,
480
+ estimator=self._sklearn_object,
481
+ class_name="Pipeline",
482
+ subproject="",
483
+ autogenerated=False,
484
+ )
485
+ return handler.batch_inference(
486
+ inference_method="transform",
487
+ input_cols=self.input_cols if self.input_cols else self._infer_input_cols(dataset),
488
+ expected_output_cols=self._infer_output_cols(),
489
+ session=dataset._session,
490
+ dependencies=self._deps,
491
+ )
492
+
283
493
  transformed_dataset = self._transform_dataset(dataset=dataset)
284
494
  estimator = self._get_estimator()
285
495
  if estimator:
@@ -389,8 +599,32 @@ class Pipeline(base.BaseTransformer):
389
599
 
390
600
  Returns:
391
601
  Output dataset.
602
+
603
+ Raises:
604
+ ValueError: An sklearn object has not been fit and stored before calling this function.
392
605
  """
393
- return self._invoke_estimator_func("predict", dataset)
606
+ if os.environ.get(IN_ML_RUNTIME_ENV_VAR):
607
+ if self._sklearn_object is None:
608
+ raise ValueError("Model must be fit before inference.")
609
+
610
+ expected_output_cols = self._infer_output_cols()
611
+ handler = ModelTransformerBuilder.build(
612
+ dataset=dataset,
613
+ estimator=self._sklearn_object,
614
+ class_name="Pipeline",
615
+ subproject="",
616
+ autogenerated=False,
617
+ )
618
+ return handler.batch_inference(
619
+ inference_method="predict",
620
+ input_cols=self.input_cols if self.input_cols else self._infer_input_cols(dataset),
621
+ expected_output_cols=expected_output_cols,
622
+ session=dataset._session,
623
+ dependencies=self._deps,
624
+ )
625
+
626
+ else:
627
+ return self._invoke_estimator_func("predict", dataset)
394
628
 
395
629
  @metaestimators.available_if(_final_step_has("score_samples")) # type: ignore[misc]
396
630
  @telemetry.send_api_usage_telemetry(
@@ -408,8 +642,32 @@ class Pipeline(base.BaseTransformer):
408
642
 
409
643
  Returns:
410
644
  Output dataset.
645
+
646
+ Raises:
647
+ ValueError: An sklearn object has not been fit before calling this function
411
648
  """
412
- return self._invoke_estimator_func("score_samples", dataset)
649
+
650
+ if os.environ.get(IN_ML_RUNTIME_ENV_VAR):
651
+ if self._sklearn_object is None:
652
+ raise ValueError("Model must be fit before inference.")
653
+
654
+ expected_output_cols = self._get_output_column_names("score_samples")
655
+ handler = ModelTransformerBuilder.build(
656
+ dataset=dataset,
657
+ estimator=self._sklearn_object,
658
+ class_name="Pipeline",
659
+ subproject="",
660
+ autogenerated=False,
661
+ )
662
+ return handler.batch_inference(
663
+ inference_method="score_samples",
664
+ input_cols=self.input_cols if self.input_cols else self._infer_input_cols(dataset),
665
+ expected_output_cols=expected_output_cols,
666
+ session=dataset._session,
667
+ dependencies=self._deps,
668
+ )
669
+ else:
670
+ return self._invoke_estimator_func("score_samples", dataset)
413
671
 
414
672
  @metaestimators.available_if(_final_step_has("predict_proba")) # type: ignore[misc]
415
673
  @telemetry.send_api_usage_telemetry(
@@ -427,8 +685,32 @@ class Pipeline(base.BaseTransformer):
427
685
 
428
686
  Returns:
429
687
  Output dataset.
688
+
689
+ Raises:
690
+ ValueError: An sklearn object has not been fit before calling this function
430
691
  """
431
- return self._invoke_estimator_func("predict_proba", dataset)
692
+
693
+ if os.environ.get(IN_ML_RUNTIME_ENV_VAR):
694
+ if self._sklearn_object is None:
695
+ raise ValueError("Model must be fit before inference.")
696
+ expected_output_cols = self._get_output_column_names("predict_proba")
697
+
698
+ handler = ModelTransformerBuilder.build(
699
+ dataset=dataset,
700
+ estimator=self._sklearn_object,
701
+ class_name="Pipeline",
702
+ subproject="",
703
+ autogenerated=False,
704
+ )
705
+ return handler.batch_inference(
706
+ inference_method="predict_proba",
707
+ input_cols=self.input_cols if self.input_cols else self._infer_input_cols(dataset),
708
+ expected_output_cols=expected_output_cols,
709
+ session=dataset._session,
710
+ dependencies=self._deps,
711
+ )
712
+ else:
713
+ return self._invoke_estimator_func("predict_proba", dataset)
432
714
 
433
715
  @metaestimators.available_if(_final_step_has("predict_log_proba")) # type: ignore[misc]
434
716
  @telemetry.send_api_usage_telemetry(
@@ -447,8 +729,31 @@ class Pipeline(base.BaseTransformer):
447
729
 
448
730
  Returns:
449
731
  Output dataset.
732
+
733
+ Raises:
734
+ ValueError: An sklearn object has not been fit before calling this function
450
735
  """
451
- return self._invoke_estimator_func("predict_log_proba", dataset)
736
+ if os.environ.get(IN_ML_RUNTIME_ENV_VAR):
737
+ if self._sklearn_object is None:
738
+ raise ValueError("Model must be fit before inference.")
739
+
740
+ expected_output_cols = self._get_output_column_names("predict_log_proba")
741
+ handler = ModelTransformerBuilder.build(
742
+ dataset=dataset,
743
+ estimator=self._sklearn_object,
744
+ class_name="Pipeline",
745
+ subproject="",
746
+ autogenerated=False,
747
+ )
748
+ return handler.batch_inference(
749
+ inference_method="predict_log_proba",
750
+ input_cols=self.input_cols if self.input_cols else self._infer_input_cols(dataset),
751
+ expected_output_cols=expected_output_cols,
752
+ session=dataset._session,
753
+ dependencies=self._deps,
754
+ )
755
+ else:
756
+ return self._invoke_estimator_func("predict_log_proba", dataset)
452
757
 
453
758
  @metaestimators.available_if(_final_step_has("score")) # type: ignore[misc]
454
759
  @telemetry.send_api_usage_telemetry(
@@ -464,8 +769,30 @@ class Pipeline(base.BaseTransformer):
464
769
 
465
770
  Returns:
466
771
  Output dataset.
772
+
773
+ Raises:
774
+ ValueError: An sklearn object has not been fit before calling this function
467
775
  """
468
- return self._invoke_estimator_func("score", dataset)
776
+
777
+ if os.environ.get(IN_ML_RUNTIME_ENV_VAR):
778
+ if self._sklearn_object is None:
779
+ raise ValueError("Model must be fit before scoreing.")
780
+ handler = ModelTransformerBuilder.build(
781
+ dataset=dataset,
782
+ estimator=self._sklearn_object,
783
+ class_name="Pipeline",
784
+ subproject="",
785
+ autogenerated=False,
786
+ )
787
+ return handler.score(
788
+ input_cols=self._infer_input_cols(),
789
+ label_cols=self._get_label_cols(),
790
+ session=dataset._session,
791
+ dependencies=self._deps,
792
+ score_sproc_imports=[],
793
+ )
794
+ else:
795
+ return self._invoke_estimator_func("score", dataset)
469
796
 
470
797
  def _invoke_estimator_func(
471
798
  self, func_name: str, dataset: Union[snowpark.DataFrame, pd.DataFrame]
@@ -495,15 +822,6 @@ class Pipeline(base.BaseTransformer):
495
822
  res: snowpark.DataFrame = getattr(estimator[1], func_name)(transformed_dataset)
496
823
  return res
497
824
 
498
- def _create_unfitted_sklearn_object(self) -> pipeline.Pipeline:
499
- sksteps = []
500
- for step in self.steps:
501
- if isinstance(step[1], base.BaseTransformer):
502
- sksteps.append(tuple([step[0], _utils.to_native_format(step[1])]))
503
- else:
504
- sksteps.append(tuple([step[0], step[1]]))
505
- return pipeline.Pipeline(steps=sksteps)
506
-
507
825
  def _construct_fitted_column_transformer_object(
508
826
  self,
509
827
  step_name_in_pipeline: str,
@@ -562,6 +880,125 @@ class Pipeline(base.BaseTransformer):
562
880
  ct._name_to_fitted_passthrough = {step_name_in_ct: ft}
563
881
  return ct
564
882
 
883
def _fit_ml_runtime(self, dataset: snowpark.DataFrame) -> None:
    """Train the pipeline in the ML Runtime.

    The pipeline is converted into an unfitted sklearn pipeline and trained through the
    ML Runtime client. The trained result is stored in ``self._sklearn_object``.

    Every column of ``dataset`` that is neither a label column nor the sample weight
    column is passed to the client as an input (feature) column.

    Args:
        dataset: The training Snowpark dataframe

    Raises:
        ModuleNotFoundError: The ML Runtime Client is not installed.
    """
    try:
        from snowflake.ml.runtime import MLRuntimeClient
    except ModuleNotFoundError as e:
        # The snowflake.ml.runtime module should always be present when the ML Runtime
        # environment variable (IN_ML_RUNTIME_ENV_VAR) is set.
        raise ModuleNotFoundError("ML Runtime Python Client is not installed.") from e

    client = MLRuntimeClient()
    ml_runtime_compatible_pipeline = self._create_unfitted_sklearn_object()

    label_cols = self._get_label_cols()
    # Neither the labels nor the per-row sample weights are features. The weight column
    # is handed to the client separately via `sample_weight_col`, so it must not also
    # appear among the input columns, or the model would be trained on it.
    non_feature_cols = set(label_cols)
    if self.sample_weight_col:
        non_feature_cols.add(self.sample_weight_col)
    input_cols = [col for col in dataset.columns if col not in non_feature_cols]

    trained_pipeline = client.train(
        estimator=ml_runtime_compatible_pipeline,
        dataset=dataset,
        input_cols=input_cols,
        label_cols=label_cols,
        sample_weight_col=self.sample_weight_col,
    )

    self._sklearn_object = trained_pipeline
915
+
916
def _get_label_cols(self) -> List[str]:
    """Return the label columns of this pipeline.

    Label columns are read from the pipeline's estimator (obtained through
    ``self._get_estimator()``). Transformer steps are not consulted.

    Returns:
        The estimator's label columns, or an empty list when there is no estimator.
    """
    estimator_step = self._get_estimator()
    if estimator_step is None:
        return []
    # The estimator object itself is stored at index 1 of the returned step.
    return estimator_step[1].get_label_cols()
929
+
930
def _can_be_trained_in_ml_runtime(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> bool:
    """Decide whether training of this pipeline can be pushed down to the ML Runtime.

    All of the following must hold, and they are checked in this order (stopping at the
    first failure):
      - ``dataset`` is a Snowpark DataFrame,
      - the environment variable named by ``IN_ML_RUNTIME_ENV_VAR`` is set to a non-empty value, and
      - the pipeline can be converted to an sklearn pipeline.

    Args:
        dataset: The training dataset

    Returns:
        True if the dataset can be fit in the ML Runtime, else False.
    """
    is_snowpark_df = isinstance(dataset, snowpark.DataFrame)
    if is_snowpark_df and os.environ.get(IN_ML_RUNTIME_ENV_VAR):
        return self._is_convertible_to_sklearn_object()
    return False
951
+
952
@staticmethod
def _wrap_transformer_in_column_transformer(
    transformer_name: str, transformer: base.BaseTransformer
) -> ColumnTransformer:
    """Convert a Snowpark ML transformer to its native object and wrap it in an sklearn ColumnTransformer.

    The wrapped transformer is applied only to the transformer's own input columns. With
    ``remainder="passthrough"``, all other columns are forwarded unchanged.

    Args:
        transformer_name: Name of the transformer to be wrapped.
        transformer: The transformer object to be wrapped.

    Returns:
        A ColumnTransformer that applies the native transformer to ``transformer.get_input_cols()``.
    """
    native_transformer = Pipeline._get_native_object(transformer)
    target_columns = transformer.get_input_cols()
    transformer_spec = (transformer_name, native_transformer, target_columns)
    return ColumnTransformer(transformers=[transformer_spec], remainder="passthrough")
971
+
972
def _create_unfitted_sklearn_object(self) -> pipeline.Pipeline:
    """Build an unfitted sklearn Pipeline that mirrors this Snowpark ML pipeline.

    If the first step declares input columns, it is wrapped in a ColumnTransformer so that
    it only consumes that subset of the pipeline's columns. Every later step is converted
    directly to its native object.

    Returns:
        An unfit pipeline that can be fit using the ML runtime client.
    """
    head_name, head_step = self.steps[0]

    # Only the first step's input columns are honoured here. Later steps are converted
    # as-is; presumably the pipeline only allows input_cols on its first step (TODO confirm).
    if head_step.get_input_cols():
        head_native = Pipeline._wrap_transformer_in_column_transformer(head_name, head_step)
    else:
        head_native = Pipeline._get_native_object(head_step)

    converted_steps = [(head_name, head_native)]
    converted_steps.extend((name, Pipeline._get_native_object(step)) for name, step in self.steps[1:])

    return pipeline.Pipeline(converted_steps)
1001
+
565
1002
  def _create_sklearn_object(self) -> pipeline.Pipeline:
566
1003
  if not self._is_fitted:
567
1004
  return self._create_unfitted_sklearn_object()
@@ -570,7 +1007,7 @@ class Pipeline(base.BaseTransformer):
570
1007
  raise exceptions.SnowflakeMLException(
571
1008
  error_code=error_codes.METHOD_NOT_ALLOWED,
572
1009
  original_exception=ValueError(
573
- "The pipeline can't be converted to SKLearn equivalent because it processing label or "
1010
+ "The pipeline can't be converted to SKLearn equivalent because it modifies processing label or "
574
1011
  "sample_weight columns as part of pipeline preprocessing steps which is not allowed in SKLearn."
575
1012
  ),
576
1013
  )
@@ -631,3 +1068,48 @@ class Pipeline(base.BaseTransformer):
631
1068
  original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
632
1069
  )
633
1070
  return self._model_signature_dict
1071
+
1072
@staticmethod
def _get_native_object(estimator: base.BaseEstimator) -> object:
    """Return the native (sklearn, xgboost, or lightgbm) object behind a Snowpark ML estimator.

    The converters ``to_sklearn``, ``to_xgboost`` and ``to_lightgbm`` are tried in that
    order. The result of the first converter that exists on ``estimator`` and does not raise
    ``SnowflakeMLException`` is returned. Any other exception raised by a converter propagates.
    TODO: tighten the type hints (is there a common base class for all xgb/lgbm estimators?).

    Args:
        estimator: the estimator from which to derive the native object.

    Returns:
        a native estimator object

    Raises:
        ValueError: The estimator is not an sklearn, xgboost, or lightgbm estimator.
    """
    for converter_name in ("to_sklearn", "to_xgboost", "to_lightgbm"):
        if not hasattr(estimator, converter_name):
            continue
        try:
            return getattr(estimator, converter_name)()
        except exceptions.SnowflakeMLException:
            # This converter does not apply to the estimator; fall through to the next one.
            continue
    raise ValueError("The estimator must be an sklearn, xgboost, or lightgbm estimator.")
1096
+
1097
def to_sklearn(self) -> pipeline.Pipeline:
    """Return an sklearn Pipeline representing this pipeline, if possible.

    For a fitted pipeline, the cached ``self._sklearn_object`` is returned when it is set.
    Otherwise one is built with ``self._create_sklearn_object()``, whose own errors
    (e.g. ``SnowflakeMLException``) propagate unchanged. For an unfitted pipeline, an
    unfitted sklearn pipeline is built, provided the pipeline is convertible.

    Returns:
        The cached or newly built sklearn Pipeline.

    Raises:
        ValueError: The pipeline is unfitted and cannot be represented as an sklearn pipeline.
    """
    if not self._is_fitted:
        if not self._is_convertible_to_sklearn_object():
            raise ValueError("This pipeline can not be converted to an sklearn pipeline.")
        return self._create_unfitted_sklearn_object()

    if self._sklearn_object is None:
        return self._create_sklearn_object()
    return self._sklearn_object