snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (234) hide show
  1. snowflake/ml/_internal/env_utils.py +77 -32
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/_internal/utils/identifier.py +3 -1
  8. snowflake/ml/_internal/utils/sql_identifier.py +2 -6
  9. snowflake/ml/dataset/__init__.py +10 -0
  10. snowflake/ml/dataset/dataset.py +454 -129
  11. snowflake/ml/dataset/dataset_factory.py +53 -0
  12. snowflake/ml/dataset/dataset_metadata.py +103 -0
  13. snowflake/ml/dataset/dataset_reader.py +202 -0
  14. snowflake/ml/feature_store/feature_store.py +531 -332
  15. snowflake/ml/feature_store/feature_view.py +40 -23
  16. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  17. snowflake/ml/fileset/sfcfs.py +56 -54
  18. snowflake/ml/fileset/snowfs.py +159 -0
  19. snowflake/ml/fileset/stage_fs.py +49 -17
  20. snowflake/ml/model/__init__.py +2 -2
  21. snowflake/ml/model/_api.py +16 -1
  22. snowflake/ml/model/_client/model/model_impl.py +27 -0
  23. snowflake/ml/model/_client/model/model_version_impl.py +137 -50
  24. snowflake/ml/model/_client/ops/model_ops.py +159 -40
  25. snowflake/ml/model/_client/sql/model.py +25 -2
  26. snowflake/ml/model/_client/sql/model_version.py +131 -2
  27. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  28. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  29. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  30. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  31. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  32. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  33. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
  34. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
  35. snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
  36. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  37. snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
  38. snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
  39. snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
  40. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  41. snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
  42. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
  43. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
  44. snowflake/ml/model/_packager/model_packager.py +2 -5
  45. snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
  46. snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
  47. snowflake/ml/model/type_hints.py +21 -2
  48. snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
  49. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
  50. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  51. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  52. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  53. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  54. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
  55. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
  56. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  57. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
  58. snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
  59. snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
  60. snowflake/ml/modeling/cluster/birch.py +248 -175
  61. snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
  62. snowflake/ml/modeling/cluster/dbscan.py +246 -175
  63. snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
  64. snowflake/ml/modeling/cluster/k_means.py +248 -175
  65. snowflake/ml/modeling/cluster/mean_shift.py +246 -175
  66. snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
  67. snowflake/ml/modeling/cluster/optics.py +246 -175
  68. snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
  69. snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
  70. snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
  71. snowflake/ml/modeling/compose/column_transformer.py +248 -175
  72. snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
  73. snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
  74. snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
  75. snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
  76. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
  77. snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
  78. snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
  79. snowflake/ml/modeling/covariance/oas.py +246 -175
  80. snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
  81. snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
  82. snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
  83. snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
  84. snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
  85. snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
  86. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
  87. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
  88. snowflake/ml/modeling/decomposition/pca.py +248 -175
  89. snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
  90. snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
  91. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
  92. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
  93. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
  94. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
  95. snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
  96. snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
  97. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
  98. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
  99. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
  100. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
  101. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
  102. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
  103. snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
  104. snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
  105. snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
  106. snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
  107. snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
  108. snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
  109. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
  110. snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
  111. snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
  112. snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
  113. snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
  114. snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
  115. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
  116. snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
  117. snowflake/ml/modeling/framework/_utils.py +8 -1
  118. snowflake/ml/modeling/framework/base.py +72 -37
  119. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
  120. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
  121. snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
  122. snowflake/ml/modeling/impute/knn_imputer.py +248 -175
  123. snowflake/ml/modeling/impute/missing_indicator.py +248 -175
  124. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
  125. snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
  126. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
  127. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
  128. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
  129. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
  130. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
  131. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
  132. snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
  133. snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
  134. snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
  135. snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
  136. snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
  137. snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
  138. snowflake/ml/modeling/linear_model/lars.py +246 -175
  139. snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
  140. snowflake/ml/modeling/linear_model/lasso.py +246 -175
  141. snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
  142. snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
  143. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
  144. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
  145. snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
  146. snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
  147. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
  148. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
  149. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
  150. snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
  151. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
  152. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
  153. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
  154. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
  155. snowflake/ml/modeling/linear_model/perceptron.py +246 -175
  156. snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
  157. snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
  158. snowflake/ml/modeling/linear_model/ridge.py +246 -175
  159. snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
  160. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
  161. snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
  162. snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
  163. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
  164. snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
  165. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
  166. snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
  167. snowflake/ml/modeling/manifold/isomap.py +248 -175
  168. snowflake/ml/modeling/manifold/mds.py +248 -175
  169. snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
  170. snowflake/ml/modeling/manifold/tsne.py +248 -175
  171. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
  172. snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
  173. snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
  174. snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
  175. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
  176. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
  177. snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
  178. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
  179. snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
  180. snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
  181. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
  182. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
  183. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
  184. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
  185. snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
  186. snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
  187. snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
  188. snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
  189. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
  190. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
  191. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
  192. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
  193. snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
  194. snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
  195. snowflake/ml/modeling/pipeline/pipeline.py +517 -35
  196. snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
  197. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
  198. snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
  199. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
  200. snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
  201. snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
  202. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
  203. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
  204. snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
  205. snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
  206. snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
  207. snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
  208. snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
  209. snowflake/ml/modeling/svm/linear_svc.py +246 -175
  210. snowflake/ml/modeling/svm/linear_svr.py +246 -175
  211. snowflake/ml/modeling/svm/nu_svc.py +246 -175
  212. snowflake/ml/modeling/svm/nu_svr.py +246 -175
  213. snowflake/ml/modeling/svm/svc.py +246 -175
  214. snowflake/ml/modeling/svm/svr.py +246 -175
  215. snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
  216. snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
  217. snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
  218. snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
  219. snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
  220. snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
  221. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
  222. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
  223. snowflake/ml/registry/model_registry.py +3 -149
  224. snowflake/ml/registry/registry.py +1 -1
  225. snowflake/ml/version.py +1 -1
  226. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
  227. snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
  228. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
  229. snowflake/ml/registry/_artifact_manager.py +0 -156
  230. snowflake/ml/registry/artifact.py +0 -46
  231. snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
  232. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  233. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  234. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -23,20 +23,26 @@ from snowflake.ml._internal.utils.temp_file_utils import (
23
23
  cleanup_temp_files,
24
24
  get_temp_file_path,
25
25
  )
26
+ from snowflake.ml.modeling._internal.estimator_utils import handle_inference_result
26
27
  from snowflake.ml.modeling._internal.model_specifications import (
27
28
  ModelSpecifications,
28
29
  ModelSpecificationsBuilder,
29
30
  )
30
- from snowflake.snowpark import DataFrame, Session, exceptions as snowpark_exceptions
31
+ from snowflake.snowpark import (
32
+ DataFrame,
33
+ Session,
34
+ exceptions as snowpark_exceptions,
35
+ functions as F,
36
+ )
31
37
  from snowflake.snowpark._internal.utils import (
32
38
  TempObjectType,
33
39
  random_name_for_temp_object,
34
40
  )
35
- from snowflake.snowpark.functions import sproc
36
41
  from snowflake.snowpark.stored_procedure import StoredProcedure
37
42
 
38
43
  cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path))
39
44
  cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
45
+ cp.register_pickle_by_value(inspect.getmodule(handle_inference_result))
40
46
 
41
47
  _PROJECT = "ModelDevelopment"
42
48
 
@@ -122,7 +128,7 @@ class SnowparkModelTrainer:
122
128
  project=_PROJECT,
123
129
  subproject=self._subproject,
124
130
  function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
125
- api_calls=[sproc],
131
+ api_calls=[F.sproc],
126
132
  custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
127
133
  )
128
134
  # Put locally serialized transform on stage.
@@ -292,7 +298,7 @@ class SnowparkModelTrainer:
292
298
  """
293
299
  imports = model_spec.imports # In order for the sproc to not resolve this reference in snowflake.ml
294
300
 
295
- def fit_wrapper_function(
301
+ def fit_predict_wrapper_function(
296
302
  session: Session,
297
303
  sql_queries: List[str],
298
304
  stage_transform_file_name: str,
@@ -329,7 +335,7 @@ class SnowparkModelTrainer:
329
335
  with open(local_transform_file_path, mode="r+b") as local_transform_file_obj:
330
336
  estimator = cp.load(local_transform_file_obj)
331
337
 
332
- fit_predict_result = estimator.fit_predict(df[input_cols])
338
+ fit_predict_result = estimator.fit_predict(X=df[input_cols])
333
339
 
334
340
  local_result_file_name = get_temp_file_path()
335
341
 
@@ -349,8 +355,16 @@ class SnowparkModelTrainer:
349
355
  fit_predict_result_pd = pd.DataFrame(data=fit_predict_result, columns=expected_output_cols_list)
350
356
  else:
351
357
  df = df.copy()
352
- fit_predict_result_pd = pd.DataFrame(data=fit_predict_result, columns=expected_output_cols_list)
353
- fit_predict_result_pd = pd.concat([df, fit_predict_result_pd], axis=1)
358
+ # in case the output column name overlap with the input column names,
359
+ # remove the ones in input column names
360
+ remove_dataset_col_name_exist_in_output_col = list(set(df.columns) - set(expected_output_cols_list))
361
+ fit_predict_result_pd = pd.concat(
362
+ [
363
+ df[remove_dataset_col_name_exist_in_output_col],
364
+ pd.DataFrame(data=fit_predict_result, columns=expected_output_cols_list),
365
+ ],
366
+ axis=1,
367
+ )
354
368
 
355
369
  # write into a temp table in sproc and load the table from outside
356
370
  session.write_pandas(
@@ -361,17 +375,150 @@ class SnowparkModelTrainer:
361
375
  # to pass debug information to the caller.
362
376
  return str(os.path.basename(local_result_file_name))
363
377
 
364
- return fit_wrapper_function
378
+ return fit_predict_wrapper_function
379
+
380
+ def _build_fit_transform_wrapper_sproc(
381
+ self,
382
+ model_spec: ModelSpecifications,
383
+ ) -> Callable[
384
+ [
385
+ Session,
386
+ List[str],
387
+ str,
388
+ str,
389
+ List[str],
390
+ Optional[List[str]],
391
+ Optional[str],
392
+ Dict[str, str],
393
+ bool,
394
+ List[str],
395
+ str,
396
+ ],
397
+ str,
398
+ ]:
399
+ """
400
+ Constructs and returns a python stored procedure function to be used for training model.
401
+
402
+ Args:
403
+ model_spec: ModelSpecifications object that contains model specific information
404
+ like required imports, package dependencies, etc.
405
+
406
+ Returns:
407
+ A callable that can be registered as a stored procedure.
408
+ """
409
+ imports = model_spec.imports # In order for the sproc to not resolve this reference in snowflake.ml
410
+
411
+ def fit_transform_wrapper_function(
412
+ session: Session,
413
+ sql_queries: List[str],
414
+ stage_transform_file_name: str,
415
+ stage_result_file_name: str,
416
+ input_cols: List[str],
417
+ label_cols: Optional[List[str]],
418
+ sample_weight_col: Optional[str],
419
+ statement_params: Dict[str, str],
420
+ drop_input_cols: bool,
421
+ expected_output_cols_list: List[str],
422
+ fit_transform_result_name: str,
423
+ ) -> str:
424
+ import os
425
+
426
+ import cloudpickle as cp
427
+ import pandas as pd
428
+
429
+ for import_name in imports:
430
+ importlib.import_module(import_name)
431
+
432
+ # Execute snowpark queries and obtain the results as pandas dataframe
433
+ # NB: this implies that the result data must fit into memory.
434
+ for query in sql_queries[:-1]:
435
+ _ = session.sql(query).collect(statement_params=statement_params)
436
+ sp_df = session.sql(sql_queries[-1])
437
+ df: pd.DataFrame = sp_df.to_pandas(statement_params=statement_params)
438
+ df.columns = sp_df.columns
439
+
440
+ local_transform_file_name = get_temp_file_path()
441
+
442
+ session.file.get(stage_transform_file_name, local_transform_file_name, statement_params=statement_params)
443
+
444
+ local_transform_file_path = os.path.join(
445
+ local_transform_file_name, os.listdir(local_transform_file_name)[0]
446
+ )
447
+ with open(local_transform_file_path, mode="r+b") as local_transform_file_obj:
448
+ estimator = cp.load(local_transform_file_obj)
449
+
450
+ argspec = inspect.getfullargspec(estimator.fit)
451
+ args = {"X": df[input_cols]}
452
+ if label_cols:
453
+ label_arg_name = "Y" if "Y" in argspec.args else "y"
454
+ args[label_arg_name] = df[label_cols].squeeze()
455
+
456
+ if sample_weight_col is not None and "sample_weight" in argspec.args:
457
+ args["sample_weight"] = df[sample_weight_col].squeeze()
458
+
459
+ fit_transform_result = estimator.fit_transform(**args)
460
+
461
+ local_result_file_name = get_temp_file_path()
462
+
463
+ with open(local_result_file_name, mode="w+b") as local_result_file_obj:
464
+ cp.dump(estimator, local_result_file_obj)
465
+
466
+ session.file.put(
467
+ local_result_file_name,
468
+ stage_result_file_name,
469
+ auto_compress=False,
470
+ overwrite=True,
471
+ statement_params=statement_params,
472
+ )
473
+
474
+ transformed_numpy_array, output_cols = handle_inference_result(
475
+ inference_res=fit_transform_result,
476
+ output_cols=expected_output_cols_list,
477
+ inference_method="fit_transform",
478
+ within_udf=True,
479
+ )
480
+
481
+ if len(transformed_numpy_array.shape) > 1:
482
+ if transformed_numpy_array.shape[1] != len(output_cols):
483
+ series = pd.Series(transformed_numpy_array.tolist())
484
+ transformed_pandas_df = pd.DataFrame(series, columns=output_cols)
485
+ else:
486
+ transformed_pandas_df = pd.DataFrame(transformed_numpy_array.tolist(), columns=output_cols)
487
+ else:
488
+ transformed_pandas_df = pd.DataFrame(transformed_numpy_array, columns=output_cols)
489
+
490
+ # store the transform output
491
+ if not drop_input_cols:
492
+ df = df.copy()
493
+ # in case the output column name overlap with the input column names,
494
+ # remove the ones in input column names
495
+ remove_dataset_col_name_exist_in_output_col = list(set(df.columns) - set(output_cols))
496
+ transformed_pandas_df = pd.concat(
497
+ [df[remove_dataset_col_name_exist_in_output_col], transformed_pandas_df], axis=1
498
+ )
499
+
500
+ # write into a temp table in sproc and load the table from outside
501
+ session.write_pandas(
502
+ transformed_pandas_df,
503
+ fit_transform_result_name,
504
+ auto_create_table=True,
505
+ table_type="temp",
506
+ quote_identifiers=False,
507
+ )
508
+
509
+ return str(os.path.basename(local_result_file_name))
510
+
511
+ return fit_transform_wrapper_function
365
512
 
366
513
  def _get_fit_predict_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure:
367
514
  # If the sproc already exists, don't register.
368
- if not hasattr(self.session, "_FIT_PRE_WRAPPER_SPROCS"):
369
- self.session._FIT_PRE_WRAPPER_SPROCS: Dict[str, StoredProcedure] = {} # type: ignore[attr-defined, misc]
515
+ if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"):
516
+ self.session._FIT_WRAPPER_SPROCS: Dict[str, StoredProcedure] = {} # type: ignore[attr-defined, misc]
370
517
 
371
518
  model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
372
- fit_predict_sproc_key = model_spec.__class__.__name__
373
- if fit_predict_sproc_key in self.session._FIT_PRE_WRAPPER_SPROCS: # type: ignore[attr-defined]
374
- fit_sproc: StoredProcedure = self.session._FIT_PRE_WRAPPER_SPROCS[ # type: ignore[attr-defined]
519
+ fit_predict_sproc_key = model_spec.__class__.__name__ + "_fit_predict"
520
+ if fit_predict_sproc_key in self.session._FIT_WRAPPER_SPROCS: # type: ignore[attr-defined]
521
+ fit_sproc: StoredProcedure = self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
375
522
  fit_predict_sproc_key
376
523
  ]
377
524
  return fit_sproc
@@ -392,12 +539,47 @@ class SnowparkModelTrainer:
392
539
  statement_params=statement_params,
393
540
  )
394
541
 
395
- self.session._FIT_PRE_WRAPPER_SPROCS[ # type: ignore[attr-defined]
542
+ self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
396
543
  fit_predict_sproc_key
397
544
  ] = fit_predict_wrapper_sproc
398
545
 
399
546
  return fit_predict_wrapper_sproc
400
547
 
548
+ def _get_fit_transform_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure:
549
+ # If the sproc already exists, don't register.
550
+ if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"):
551
+ self.session._FIT_WRAPPER_SPROCS: Dict[str, StoredProcedure] = {} # type: ignore[attr-defined, misc]
552
+
553
+ model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
554
+ fit_transform_sproc_key = model_spec.__class__.__name__ + "_fit_transform"
555
+ if fit_transform_sproc_key in self.session._FIT_WRAPPER_SPROCS: # type: ignore[attr-defined]
556
+ fit_sproc: StoredProcedure = self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
557
+ fit_transform_sproc_key
558
+ ]
559
+ return fit_sproc
560
+
561
+ fit_transform_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
562
+
563
+ relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
564
+ pkg_versions=model_spec.pkgDependencies, session=self.session
565
+ )
566
+
567
+ fit_transform_wrapper_sproc = self.session.sproc.register(
568
+ func=self._build_fit_transform_wrapper_sproc(model_spec=model_spec),
569
+ is_permanent=False,
570
+ name=fit_transform_sproc_name,
571
+ packages=["snowflake-snowpark-python"] + relaxed_dependencies, # type: ignore[arg-type]
572
+ replace=True,
573
+ session=self.session,
574
+ statement_params=statement_params,
575
+ )
576
+
577
+ self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
578
+ fit_transform_sproc_key
579
+ ] = fit_transform_wrapper_sproc
580
+
581
+ return fit_transform_wrapper_sproc
582
+
401
583
  def train(self) -> object:
402
584
  """
403
585
  Trains the model by pushing down the compute into Snowflake using stored procedures.
@@ -498,10 +680,10 @@ class SnowparkModelTrainer:
498
680
  custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
499
681
  )
500
682
 
501
- fit_wrapper_sproc = self._get_fit_predict_wrapper_sproc(statement_params=statement_params)
683
+ fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc(statement_params=statement_params)
502
684
  fit_predict_result_name = random_name_for_temp_object(TempObjectType.TABLE)
503
685
 
504
- sproc_export_file_name: str = fit_wrapper_sproc(
686
+ sproc_export_file_name: str = fit_predict_wrapper_sproc(
505
687
  self.session,
506
688
  queries,
507
689
  stage_transform_file_name,
@@ -521,3 +703,66 @@ class SnowparkModelTrainer:
521
703
  )
522
704
 
523
705
  return output_result_sp, fitted_estimator
706
+
707
+ def train_fit_transform(
708
+ self,
709
+ expected_output_cols_list: List[str],
710
+ drop_input_cols: Optional[bool] = False,
711
+ ) -> Tuple[Union[DataFrame, pd.DataFrame], object]:
712
+ """Trains the model by pushing down the compute into Snowflake using stored procedures.
713
+ This API is different from fit itself because it would also provide the transform
714
+ output.
715
+
716
+ Args:
717
+ expected_output_cols_list (List[str]): The output columns
718
+ name as a list. Defaults to None.
719
+ drop_input_cols (Optional[bool]): Boolean to determine whether to
720
+ drop the input columns from the output dataset.
721
+
722
+ Returns:
723
+ Tuple[Union[DataFrame, pd.DataFrame], object]: [transformed dataset, estimator]
724
+ """
725
+ dataset = snowpark_dataframe_utils.cast_snowpark_dataframe_column_types(self.dataset)
726
+
727
+ # Extract query that generated the dataframe. We will need to pass it to the fit procedure.
728
+ queries = dataset.queries["queries"]
729
+
730
+ transform_stage_name = self._create_temp_stage()
731
+ (stage_transform_file_name, stage_result_file_name) = self._upload_model_to_stage(
732
+ stage_name=transform_stage_name
733
+ )
734
+
735
+ # Call fit sproc
736
+ statement_params = telemetry.get_function_usage_statement_params(
737
+ project=_PROJECT,
738
+ subproject=self._subproject,
739
+ function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
740
+ api_calls=[Session.call],
741
+ custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
742
+ )
743
+
744
+ fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc(statement_params=statement_params)
745
+ fit_transform_result_name = random_name_for_temp_object(TempObjectType.TABLE)
746
+
747
+ sproc_export_file_name: str = fit_transform_wrapper_sproc(
748
+ self.session,
749
+ queries,
750
+ stage_transform_file_name,
751
+ stage_result_file_name,
752
+ self.input_cols,
753
+ self.label_cols,
754
+ self.sample_weight_col,
755
+ statement_params,
756
+ drop_input_cols,
757
+ expected_output_cols_list,
758
+ fit_transform_result_name,
759
+ )
760
+
761
+ output_result_sp = self.session.table(fit_transform_result_name)
762
+ fitted_estimator = self._fetch_model_from_stage(
763
+ dir_path=stage_result_file_name,
764
+ file_name=sproc_export_file_name,
765
+ statement_params=statement_params,
766
+ )
767
+
768
+ return output_result_sp, fitted_estimator