snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206) hide show
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -23,20 +23,26 @@ from snowflake.ml._internal.utils.temp_file_utils import (
23
23
  cleanup_temp_files,
24
24
  get_temp_file_path,
25
25
  )
26
+ from snowflake.ml.modeling._internal.estimator_utils import handle_inference_result
26
27
  from snowflake.ml.modeling._internal.model_specifications import (
27
28
  ModelSpecifications,
28
29
  ModelSpecificationsBuilder,
29
30
  )
30
- from snowflake.snowpark import DataFrame, Session, exceptions as snowpark_exceptions
31
+ from snowflake.snowpark import (
32
+ DataFrame,
33
+ Session,
34
+ exceptions as snowpark_exceptions,
35
+ functions as F,
36
+ )
31
37
  from snowflake.snowpark._internal.utils import (
32
38
  TempObjectType,
33
39
  random_name_for_temp_object,
34
40
  )
35
- from snowflake.snowpark.functions import sproc
36
41
  from snowflake.snowpark.stored_procedure import StoredProcedure
37
42
 
38
43
  cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path))
39
44
  cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
45
+ cp.register_pickle_by_value(inspect.getmodule(handle_inference_result))
40
46
 
41
47
  _PROJECT = "ModelDevelopment"
42
48
 
@@ -122,7 +128,7 @@ class SnowparkModelTrainer:
122
128
  project=_PROJECT,
123
129
  subproject=self._subproject,
124
130
  function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
125
- api_calls=[sproc],
131
+ api_calls=[F.sproc],
126
132
  custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
127
133
  )
128
134
  # Put locally serialized transform on stage.
@@ -292,7 +298,7 @@ class SnowparkModelTrainer:
292
298
  """
293
299
  imports = model_spec.imports # In order for the sproc to not resolve this reference in snowflake.ml
294
300
 
295
- def fit_wrapper_function(
301
+ def fit_predict_wrapper_function(
296
302
  session: Session,
297
303
  sql_queries: List[str],
298
304
  stage_transform_file_name: str,
@@ -329,7 +335,7 @@ class SnowparkModelTrainer:
329
335
  with open(local_transform_file_path, mode="r+b") as local_transform_file_obj:
330
336
  estimator = cp.load(local_transform_file_obj)
331
337
 
332
- fit_predict_result = estimator.fit_predict(df[input_cols])
338
+ fit_predict_result = estimator.fit_predict(X=df[input_cols])
333
339
 
334
340
  local_result_file_name = get_temp_file_path()
335
341
 
@@ -349,8 +355,16 @@ class SnowparkModelTrainer:
349
355
  fit_predict_result_pd = pd.DataFrame(data=fit_predict_result, columns=expected_output_cols_list)
350
356
  else:
351
357
  df = df.copy()
352
- fit_predict_result_pd = pd.DataFrame(data=fit_predict_result, columns=expected_output_cols_list)
353
- fit_predict_result_pd = pd.concat([df, fit_predict_result_pd], axis=1)
358
+ # in case the output column name overlap with the input column names,
359
+ # remove the ones in input column names
360
+ remove_dataset_col_name_exist_in_output_col = list(set(df.columns) - set(expected_output_cols_list))
361
+ fit_predict_result_pd = pd.concat(
362
+ [
363
+ df[remove_dataset_col_name_exist_in_output_col],
364
+ pd.DataFrame(data=fit_predict_result, columns=expected_output_cols_list),
365
+ ],
366
+ axis=1,
367
+ )
354
368
 
355
369
  # write into a temp table in sproc and load the table from outside
356
370
  session.write_pandas(
@@ -361,17 +375,150 @@ class SnowparkModelTrainer:
361
375
  # to pass debug information to the caller.
362
376
  return str(os.path.basename(local_result_file_name))
363
377
 
364
- return fit_wrapper_function
378
+ return fit_predict_wrapper_function
379
+
380
+ def _build_fit_transform_wrapper_sproc(
381
+ self,
382
+ model_spec: ModelSpecifications,
383
+ ) -> Callable[
384
+ [
385
+ Session,
386
+ List[str],
387
+ str,
388
+ str,
389
+ List[str],
390
+ Optional[List[str]],
391
+ Optional[str],
392
+ Dict[str, str],
393
+ bool,
394
+ List[str],
395
+ str,
396
+ ],
397
+ str,
398
+ ]:
399
+ """
400
+ Constructs and returns a python stored procedure function to be used for training model.
401
+
402
+ Args:
403
+ model_spec: ModelSpecifications object that contains model specific information
404
+ like required imports, package dependencies, etc.
405
+
406
+ Returns:
407
+ A callable that can be registered as a stored procedure.
408
+ """
409
+ imports = model_spec.imports # In order for the sproc to not resolve this reference in snowflake.ml
410
+
411
+ def fit_transform_wrapper_function(
412
+ session: Session,
413
+ sql_queries: List[str],
414
+ stage_transform_file_name: str,
415
+ stage_result_file_name: str,
416
+ input_cols: List[str],
417
+ label_cols: Optional[List[str]],
418
+ sample_weight_col: Optional[str],
419
+ statement_params: Dict[str, str],
420
+ drop_input_cols: bool,
421
+ expected_output_cols_list: List[str],
422
+ fit_transform_result_name: str,
423
+ ) -> str:
424
+ import os
425
+
426
+ import cloudpickle as cp
427
+ import pandas as pd
428
+
429
+ for import_name in imports:
430
+ importlib.import_module(import_name)
431
+
432
+ # Execute snowpark queries and obtain the results as pandas dataframe
433
+ # NB: this implies that the result data must fit into memory.
434
+ for query in sql_queries[:-1]:
435
+ _ = session.sql(query).collect(statement_params=statement_params)
436
+ sp_df = session.sql(sql_queries[-1])
437
+ df: pd.DataFrame = sp_df.to_pandas(statement_params=statement_params)
438
+ df.columns = sp_df.columns
439
+
440
+ local_transform_file_name = get_temp_file_path()
441
+
442
+ session.file.get(stage_transform_file_name, local_transform_file_name, statement_params=statement_params)
443
+
444
+ local_transform_file_path = os.path.join(
445
+ local_transform_file_name, os.listdir(local_transform_file_name)[0]
446
+ )
447
+ with open(local_transform_file_path, mode="r+b") as local_transform_file_obj:
448
+ estimator = cp.load(local_transform_file_obj)
449
+
450
+ argspec = inspect.getfullargspec(estimator.fit)
451
+ args = {"X": df[input_cols]}
452
+ if label_cols:
453
+ label_arg_name = "Y" if "Y" in argspec.args else "y"
454
+ args[label_arg_name] = df[label_cols].squeeze()
455
+
456
+ if sample_weight_col is not None and "sample_weight" in argspec.args:
457
+ args["sample_weight"] = df[sample_weight_col].squeeze()
458
+
459
+ fit_transform_result = estimator.fit_transform(**args)
460
+
461
+ local_result_file_name = get_temp_file_path()
462
+
463
+ with open(local_result_file_name, mode="w+b") as local_result_file_obj:
464
+ cp.dump(estimator, local_result_file_obj)
465
+
466
+ session.file.put(
467
+ local_result_file_name,
468
+ stage_result_file_name,
469
+ auto_compress=False,
470
+ overwrite=True,
471
+ statement_params=statement_params,
472
+ )
473
+
474
+ transformed_numpy_array, output_cols = handle_inference_result(
475
+ inference_res=fit_transform_result,
476
+ output_cols=expected_output_cols_list,
477
+ inference_method="fit_transform",
478
+ within_udf=True,
479
+ )
480
+
481
+ if len(transformed_numpy_array.shape) > 1:
482
+ if transformed_numpy_array.shape[1] != len(output_cols):
483
+ series = pd.Series(transformed_numpy_array.tolist())
484
+ transformed_pandas_df = pd.DataFrame(series, columns=output_cols)
485
+ else:
486
+ transformed_pandas_df = pd.DataFrame(transformed_numpy_array.tolist(), columns=output_cols)
487
+ else:
488
+ transformed_pandas_df = pd.DataFrame(transformed_numpy_array, columns=output_cols)
489
+
490
+ # store the transform output
491
+ if not drop_input_cols:
492
+ df = df.copy()
493
+ # in case the output column name overlap with the input column names,
494
+ # remove the ones in input column names
495
+ remove_dataset_col_name_exist_in_output_col = list(set(df.columns) - set(output_cols))
496
+ transformed_pandas_df = pd.concat(
497
+ [df[remove_dataset_col_name_exist_in_output_col], transformed_pandas_df], axis=1
498
+ )
499
+
500
+ # write into a temp table in sproc and load the table from outside
501
+ session.write_pandas(
502
+ transformed_pandas_df,
503
+ fit_transform_result_name,
504
+ auto_create_table=True,
505
+ table_type="temp",
506
+ quote_identifiers=False,
507
+ )
508
+
509
+ return str(os.path.basename(local_result_file_name))
510
+
511
+ return fit_transform_wrapper_function
365
512
 
366
513
  def _get_fit_predict_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure:
367
514
  # If the sproc already exists, don't register.
368
- if not hasattr(self.session, "_FIT_PRE_WRAPPER_SPROCS"):
369
- self.session._FIT_PRE_WRAPPER_SPROCS: Dict[str, StoredProcedure] = {} # type: ignore[attr-defined, misc]
515
+ if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"):
516
+ self.session._FIT_WRAPPER_SPROCS: Dict[str, StoredProcedure] = {} # type: ignore[attr-defined, misc]
370
517
 
371
518
  model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
372
- fit_predict_sproc_key = model_spec.__class__.__name__
373
- if fit_predict_sproc_key in self.session._FIT_PRE_WRAPPER_SPROCS: # type: ignore[attr-defined]
374
- fit_sproc: StoredProcedure = self.session._FIT_PRE_WRAPPER_SPROCS[ # type: ignore[attr-defined]
519
+ fit_predict_sproc_key = model_spec.__class__.__name__ + "_fit_predict"
520
+ if fit_predict_sproc_key in self.session._FIT_WRAPPER_SPROCS: # type: ignore[attr-defined]
521
+ fit_sproc: StoredProcedure = self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
375
522
  fit_predict_sproc_key
376
523
  ]
377
524
  return fit_sproc
@@ -392,12 +539,47 @@ class SnowparkModelTrainer:
392
539
  statement_params=statement_params,
393
540
  )
394
541
 
395
- self.session._FIT_PRE_WRAPPER_SPROCS[ # type: ignore[attr-defined]
542
+ self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
396
543
  fit_predict_sproc_key
397
544
  ] = fit_predict_wrapper_sproc
398
545
 
399
546
  return fit_predict_wrapper_sproc
400
547
 
548
+ def _get_fit_transform_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure:
549
+ # If the sproc already exists, don't register.
550
+ if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"):
551
+ self.session._FIT_WRAPPER_SPROCS: Dict[str, StoredProcedure] = {} # type: ignore[attr-defined, misc]
552
+
553
+ model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
554
+ fit_transform_sproc_key = model_spec.__class__.__name__ + "_fit_transform"
555
+ if fit_transform_sproc_key in self.session._FIT_WRAPPER_SPROCS: # type: ignore[attr-defined]
556
+ fit_sproc: StoredProcedure = self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
557
+ fit_transform_sproc_key
558
+ ]
559
+ return fit_sproc
560
+
561
+ fit_transform_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
562
+
563
+ relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
564
+ pkg_versions=model_spec.pkgDependencies, session=self.session
565
+ )
566
+
567
+ fit_transform_wrapper_sproc = self.session.sproc.register(
568
+ func=self._build_fit_transform_wrapper_sproc(model_spec=model_spec),
569
+ is_permanent=False,
570
+ name=fit_transform_sproc_name,
571
+ packages=["snowflake-snowpark-python"] + relaxed_dependencies, # type: ignore[arg-type]
572
+ replace=True,
573
+ session=self.session,
574
+ statement_params=statement_params,
575
+ )
576
+
577
+ self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
578
+ fit_transform_sproc_key
579
+ ] = fit_transform_wrapper_sproc
580
+
581
+ return fit_transform_wrapper_sproc
582
+
401
583
  def train(self) -> object:
402
584
  """
403
585
  Trains the model by pushing down the compute into Snowflake using stored procedures.
@@ -498,10 +680,10 @@ class SnowparkModelTrainer:
498
680
  custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
499
681
  )
500
682
 
501
- fit_wrapper_sproc = self._get_fit_predict_wrapper_sproc(statement_params=statement_params)
683
+ fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc(statement_params=statement_params)
502
684
  fit_predict_result_name = random_name_for_temp_object(TempObjectType.TABLE)
503
685
 
504
- sproc_export_file_name: str = fit_wrapper_sproc(
686
+ sproc_export_file_name: str = fit_predict_wrapper_sproc(
505
687
  self.session,
506
688
  queries,
507
689
  stage_transform_file_name,
@@ -521,3 +703,66 @@ class SnowparkModelTrainer:
521
703
  )
522
704
 
523
705
  return output_result_sp, fitted_estimator
706
+
707
+ def train_fit_transform(
708
+ self,
709
+ expected_output_cols_list: List[str],
710
+ drop_input_cols: Optional[bool] = False,
711
+ ) -> Tuple[Union[DataFrame, pd.DataFrame], object]:
712
+ """Trains the model by pushing down the compute into Snowflake using stored procedures.
713
+ This API is different from fit itself because it would also provide the transform
714
+ output.
715
+
716
+ Args:
717
+ expected_output_cols_list (List[str]): The output columns
718
+ name as a list. Defaults to None.
719
+ drop_input_cols (Optional[bool]): Boolean to determine whether to
720
+ drop the input columns from the output dataset.
721
+
722
+ Returns:
723
+ Tuple[Union[DataFrame, pd.DataFrame], object]: [transformed dataset, estimator]
724
+ """
725
+ dataset = snowpark_dataframe_utils.cast_snowpark_dataframe_column_types(self.dataset)
726
+
727
+ # Extract query that generated the dataframe. We will need to pass it to the fit procedure.
728
+ queries = dataset.queries["queries"]
729
+
730
+ transform_stage_name = self._create_temp_stage()
731
+ (stage_transform_file_name, stage_result_file_name) = self._upload_model_to_stage(
732
+ stage_name=transform_stage_name
733
+ )
734
+
735
+ # Call fit sproc
736
+ statement_params = telemetry.get_function_usage_statement_params(
737
+ project=_PROJECT,
738
+ subproject=self._subproject,
739
+ function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
740
+ api_calls=[Session.call],
741
+ custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
742
+ )
743
+
744
+ fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc(statement_params=statement_params)
745
+ fit_transform_result_name = random_name_for_temp_object(TempObjectType.TABLE)
746
+
747
+ sproc_export_file_name: str = fit_transform_wrapper_sproc(
748
+ self.session,
749
+ queries,
750
+ stage_transform_file_name,
751
+ stage_result_file_name,
752
+ self.input_cols,
753
+ self.label_cols,
754
+ self.sample_weight_col,
755
+ statement_params,
756
+ drop_input_cols,
757
+ expected_output_cols_list,
758
+ fit_transform_result_name,
759
+ )
760
+
761
+ output_result_sp = self.session.table(fit_transform_result_name)
762
+ fitted_estimator = self._fetch_model_from_stage(
763
+ dir_path=stage_result_file_name,
764
+ file_name=sproc_export_file_name,
765
+ statement_params=statement_params,
766
+ )
767
+
768
+ return output_result_sp, fitted_estimator
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.calibration".replace("sk
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class CalibratedClassifierCV(BaseTransformer):
70
64
  r"""Probability calibration with isotonic regression or logistic regression
71
65
  For more details on this class, see [sklearn.calibration.CalibratedClassifierCV]
@@ -328,20 +322,17 @@ class CalibratedClassifierCV(BaseTransformer):
328
322
  self,
329
323
  dataset: DataFrame,
330
324
  inference_method: str,
331
- ) -> List[str]:
332
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
333
- return the available package that exists in the snowflake anaconda channel
325
+ ) -> None:
326
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
334
327
 
335
328
  Args:
336
329
  dataset: snowpark dataframe
337
330
  inference_method: the inference method such as predict, score...
338
-
331
+
339
332
  Raises:
340
333
  SnowflakeMLException: If the estimator is not fitted, raise error
341
334
  SnowflakeMLException: If the session is None, raise error
342
335
 
343
- Returns:
344
- A list of available package that exists in the snowflake anaconda channel
345
336
  """
346
337
  if not self._is_fitted:
347
338
  raise exceptions.SnowflakeMLException(
@@ -359,9 +350,7 @@ class CalibratedClassifierCV(BaseTransformer):
359
350
  "Session must not specified for snowpark dataset."
360
351
  ),
361
352
  )
362
- # Validate that key package version in user workspace are supported in snowflake conda channel
363
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
364
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
353
+
365
354
 
366
355
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
367
356
  @telemetry.send_api_usage_telemetry(
@@ -409,7 +398,8 @@ class CalibratedClassifierCV(BaseTransformer):
409
398
 
410
399
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
411
400
 
412
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
401
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
402
+ self._deps = self._get_dependencies()
413
403
  assert isinstance(
414
404
  dataset._session, Session
415
405
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -492,10 +482,8 @@ class CalibratedClassifierCV(BaseTransformer):
492
482
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
493
483
  expected_dtype = convert_sp_to_sf_type(output_types[0])
494
484
 
495
- self._deps = self._batch_inference_validate_snowpark(
496
- dataset=dataset,
497
- inference_method=inference_method,
498
- )
485
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
486
+ self._deps = self._get_dependencies()
499
487
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
500
488
 
501
489
  transform_kwargs = dict(
@@ -562,16 +550,40 @@ class CalibratedClassifierCV(BaseTransformer):
562
550
  self._is_fitted = True
563
551
  return output_result
564
552
 
553
+
554
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
555
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
556
+ """ Method not supported for this class.
565
557
 
566
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
567
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
568
- """
558
+
559
+ Raises:
560
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
561
+
562
+ Args:
563
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
564
+ Snowpark or Pandas DataFrame.
565
+ output_cols_prefix: Prefix for the response columns
569
566
  Returns:
570
567
  Transformed dataset.
571
568
  """
572
- self.fit(dataset)
573
- assert self._sklearn_object is not None
574
- return self._sklearn_object.embedding_
569
+ self._infer_input_output_cols(dataset)
570
+ super()._check_dataset_type(dataset)
571
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
572
+ estimator=self._sklearn_object,
573
+ dataset=dataset,
574
+ input_cols=self.input_cols,
575
+ label_cols=self.label_cols,
576
+ sample_weight_col=self.sample_weight_col,
577
+ autogenerated=self._autogenerated,
578
+ subproject=_SUBPROJECT,
579
+ )
580
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
581
+ drop_input_cols=self._drop_input_cols,
582
+ expected_output_cols_list=self.output_cols,
583
+ )
584
+ self._sklearn_object = fitted_estimator
585
+ self._is_fitted = True
586
+ return output_result
575
587
 
576
588
 
577
589
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -664,10 +676,8 @@ class CalibratedClassifierCV(BaseTransformer):
664
676
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
665
677
 
666
678
  if isinstance(dataset, DataFrame):
667
- self._deps = self._batch_inference_validate_snowpark(
668
- dataset=dataset,
669
- inference_method=inference_method,
670
- )
679
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
680
+ self._deps = self._get_dependencies()
671
681
  assert isinstance(
672
682
  dataset._session, Session
673
683
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -734,10 +744,8 @@ class CalibratedClassifierCV(BaseTransformer):
734
744
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
735
745
 
736
746
  if isinstance(dataset, DataFrame):
737
- self._deps = self._batch_inference_validate_snowpark(
738
- dataset=dataset,
739
- inference_method=inference_method,
740
- )
747
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
748
+ self._deps = self._get_dependencies()
741
749
  assert isinstance(
742
750
  dataset._session, Session
743
751
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -799,10 +807,8 @@ class CalibratedClassifierCV(BaseTransformer):
799
807
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
800
808
 
801
809
  if isinstance(dataset, DataFrame):
802
- self._deps = self._batch_inference_validate_snowpark(
803
- dataset=dataset,
804
- inference_method=inference_method,
805
- )
810
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
811
+ self._deps = self._get_dependencies()
806
812
  assert isinstance(
807
813
  dataset._session, Session
808
814
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -868,10 +874,8 @@ class CalibratedClassifierCV(BaseTransformer):
868
874
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
869
875
 
870
876
  if isinstance(dataset, DataFrame):
871
- self._deps = self._batch_inference_validate_snowpark(
872
- dataset=dataset,
873
- inference_method=inference_method,
874
- )
877
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
878
+ self._deps = self._get_dependencies()
875
879
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
876
880
  transform_kwargs = dict(
877
881
  session=dataset._session,
@@ -935,17 +939,15 @@ class CalibratedClassifierCV(BaseTransformer):
935
939
  transform_kwargs: ScoreKwargsTypedDict = dict()
936
940
 
937
941
  if isinstance(dataset, DataFrame):
938
- self._deps = self._batch_inference_validate_snowpark(
939
- dataset=dataset,
940
- inference_method="score",
941
- )
942
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
943
+ self._deps = self._get_dependencies()
942
944
  selected_cols = self._get_active_columns()
943
945
  if len(selected_cols) > 0:
944
946
  dataset = dataset.select(selected_cols)
945
947
  assert isinstance(dataset._session, Session) # keep mypy happy
946
948
  transform_kwargs = dict(
947
949
  session=dataset._session,
948
- dependencies=["snowflake-snowpark-python"] + self._deps,
950
+ dependencies=self._deps,
949
951
  score_sproc_imports=['sklearn'],
950
952
  )
951
953
  elif isinstance(dataset, pd.DataFrame):
@@ -1010,11 +1012,8 @@ class CalibratedClassifierCV(BaseTransformer):
1010
1012
 
1011
1013
  if isinstance(dataset, DataFrame):
1012
1014
 
1013
- self._deps = self._batch_inference_validate_snowpark(
1014
- dataset=dataset,
1015
- inference_method=inference_method,
1016
-
1017
- )
1015
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1016
+ self._deps = self._get_dependencies()
1018
1017
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1019
1018
  transform_kwargs = dict(
1020
1019
  session = dataset._session,