snowflake-ml-python 1.5.2__py3-none-any.whl → 1.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196):
  1. snowflake/cortex/_complete.py +26 -5
  2. snowflake/cortex/_sse_client.py +81 -0
  3. snowflake/cortex/_util.py +105 -8
  4. snowflake/ml/_internal/lineage/lineage_utils.py +34 -25
  5. snowflake/ml/dataset/dataset.py +15 -12
  6. snowflake/ml/dataset/dataset_factory.py +3 -4
  7. snowflake/ml/feature_store/feature_store.py +2 -2
  8. snowflake/ml/model/_client/sql/model_version.py +2 -2
  9. snowflake/ml/model/_model_composer/model_composer.py +2 -2
  10. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -1
  11. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  12. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  13. snowflake/ml/model/_signatures/builtins_handler.py +2 -1
  14. snowflake/ml/model/_signatures/core.py +13 -1
  15. snowflake/ml/model/_signatures/pandas_handler.py +2 -0
  16. snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
  17. snowflake/ml/model/model_signature.py +2 -0
  18. snowflake/ml/model/type_hints.py +1 -0
  19. snowflake/ml/modeling/_internal/estimator_utils.py +58 -1
  20. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +156 -121
  21. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +2 -0
  22. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +38 -18
  23. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +82 -134
  24. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +21 -17
  25. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
  26. snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
  27. snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
  28. snowflake/ml/modeling/cluster/birch.py +1 -1
  29. snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
  30. snowflake/ml/modeling/cluster/dbscan.py +1 -1
  31. snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
  32. snowflake/ml/modeling/cluster/k_means.py +1 -1
  33. snowflake/ml/modeling/cluster/mean_shift.py +1 -1
  34. snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
  35. snowflake/ml/modeling/cluster/optics.py +1 -1
  36. snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
  37. snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
  38. snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
  39. snowflake/ml/modeling/compose/column_transformer.py +1 -1
  40. snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
  41. snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
  42. snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
  43. snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
  44. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
  45. snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
  46. snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
  47. snowflake/ml/modeling/covariance/oas.py +1 -1
  48. snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
  49. snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
  50. snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
  51. snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
  52. snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
  53. snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
  54. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
  55. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
  56. snowflake/ml/modeling/decomposition/pca.py +1 -1
  57. snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
  58. snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
  59. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
  60. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
  61. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
  62. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
  63. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
  64. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
  65. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
  66. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
  67. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
  68. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
  69. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
  70. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
  71. snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
  72. snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
  73. snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
  74. snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
  75. snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
  76. snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
  77. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
  78. snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
  79. snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
  80. snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
  81. snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
  82. snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
  83. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
  84. snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
  85. snowflake/ml/modeling/framework/base.py +3 -8
  86. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
  87. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
  88. snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
  89. snowflake/ml/modeling/impute/knn_imputer.py +1 -1
  90. snowflake/ml/modeling/impute/missing_indicator.py +1 -1
  91. snowflake/ml/modeling/impute/simple_imputer.py +8 -4
  92. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
  93. snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
  94. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
  95. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
  96. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
  97. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
  98. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
  99. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
  100. snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
  101. snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
  102. snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
  103. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  104. snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
  105. snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
  106. snowflake/ml/modeling/linear_model/lars.py +1 -1
  107. snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
  108. snowflake/ml/modeling/linear_model/lasso.py +1 -1
  109. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  110. snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
  111. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
  112. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
  113. snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
  114. snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
  115. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
  116. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
  117. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  118. snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
  119. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  120. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
  121. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
  122. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  123. snowflake/ml/modeling/linear_model/perceptron.py +1 -1
  124. snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
  125. snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
  126. snowflake/ml/modeling/linear_model/ridge.py +1 -1
  127. snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
  128. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
  129. snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
  130. snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
  131. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
  132. snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
  133. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
  134. snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
  135. snowflake/ml/modeling/manifold/isomap.py +1 -1
  136. snowflake/ml/modeling/manifold/mds.py +1 -1
  137. snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
  138. snowflake/ml/modeling/manifold/tsne.py +1 -1
  139. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
  140. snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
  141. snowflake/ml/modeling/model_selection/grid_search_cv.py +1 -5
  142. snowflake/ml/modeling/model_selection/randomized_search_cv.py +1 -5
  143. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
  144. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
  145. snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
  146. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
  147. snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
  148. snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
  149. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
  150. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
  151. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
  152. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
  153. snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
  154. snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
  155. snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
  156. snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
  157. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
  158. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
  159. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  160. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
  161. snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
  162. snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
  163. snowflake/ml/modeling/pipeline/pipeline.py +5 -0
  164. snowflake/ml/modeling/preprocessing/binarizer.py +7 -3
  165. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +7 -2
  166. snowflake/ml/modeling/preprocessing/label_encoder.py +8 -7
  167. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +7 -3
  168. snowflake/ml/modeling/preprocessing/min_max_scaler.py +7 -4
  169. snowflake/ml/modeling/preprocessing/normalizer.py +7 -3
  170. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +10 -2
  171. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +8 -5
  172. snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
  173. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -4
  174. snowflake/ml/modeling/preprocessing/standard_scaler.py +7 -3
  175. snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
  176. snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
  177. snowflake/ml/modeling/svm/linear_svc.py +1 -1
  178. snowflake/ml/modeling/svm/linear_svr.py +1 -1
  179. snowflake/ml/modeling/svm/nu_svc.py +1 -1
  180. snowflake/ml/modeling/svm/nu_svr.py +1 -1
  181. snowflake/ml/modeling/svm/svc.py +1 -1
  182. snowflake/ml/modeling/svm/svr.py +1 -1
  183. snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
  184. snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
  185. snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
  186. snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
  187. snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
  188. snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
  189. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
  190. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
  191. snowflake/ml/version.py +1 -1
  192. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.3.dist-info}/METADATA +21 -5
  193. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.3.dist-info}/RECORD +196 -195
  194. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.3.dist-info}/LICENSE.txt +0 -0
  195. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.3.dist-info}/WHEEL +0 -0
  196. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.3.dist-info}/top_level.txt +0 -0
@@ -2,6 +2,7 @@ import importlib
2
2
  import inspect
3
3
  import os
4
4
  import posixpath
5
+ import sys
5
6
  from typing import Any, Dict, List, Optional
6
7
  from uuid import uuid4
7
8
 
@@ -13,12 +14,10 @@ from snowflake.ml._internal.utils import (
13
14
  identifier,
14
15
  pkg_version_utils,
15
16
  snowpark_dataframe_utils,
17
+ temp_file_utils,
16
18
  )
17
19
  from snowflake.ml._internal.utils.query_result_checker import SqlResultValidator
18
- from snowflake.ml._internal.utils.temp_file_utils import (
19
- cleanup_temp_files,
20
- get_temp_file_path,
21
- )
20
+ from snowflake.ml.modeling._internal import estimator_utils
22
21
  from snowflake.ml.modeling._internal.estimator_utils import handle_inference_result
23
22
  from snowflake.snowpark import DataFrame, Session, functions as F, types as T
24
23
  from snowflake.snowpark._internal.utils import (
@@ -26,7 +25,7 @@ from snowflake.snowpark._internal.utils import (
26
25
  random_name_for_temp_object,
27
26
  )
28
27
 
29
- cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path))
28
+ cp.register_pickle_by_value(inspect.getmodule(temp_file_utils.get_temp_file_path))
30
29
  cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
31
30
  cp.register_pickle_by_value(inspect.getmodule(handle_inference_result))
32
31
 
@@ -97,7 +96,25 @@ class SnowparkTransformHandlers:
97
96
 
98
97
  dependencies = self._get_validated_snowpark_dependencies(session, dependencies)
99
98
  dataset = self.dataset
100
- estimator = self.estimator
99
+
100
+ statement_params = telemetry.get_function_usage_statement_params(
101
+ project=_PROJECT,
102
+ subproject=self._subproject,
103
+ function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
104
+ api_calls=[F.pandas_udf],
105
+ custom_tags={"autogen": True} if self._autogenerated else None,
106
+ )
107
+
108
+ temp_stage_name = estimator_utils.create_temp_stage(session)
109
+
110
+ estimator_file_name = estimator_utils.upload_model_to_stage(
111
+ stage_name=temp_stage_name,
112
+ estimator=self.estimator,
113
+ session=session,
114
+ statement_params=statement_params,
115
+ )
116
+ imports = [f"@{temp_stage_name}/{estimator_file_name}"]
117
+
101
118
  # Register vectorized UDF for batch inference
102
119
  batch_inference_udf_name = random_name_for_temp_object(TempObjectType.FUNCTION)
103
120
 
@@ -113,13 +130,13 @@ class SnowparkTransformHandlers:
113
130
  for field in fields:
114
131
  input_datatypes.append(field.datatype)
115
132
 
116
- statement_params = telemetry.get_function_usage_statement_params(
117
- project=_PROJECT,
118
- subproject=self._subproject,
119
- function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
120
- api_calls=[F.pandas_udf],
121
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
122
- )
133
+ # TODO(xjiang): for optimization, use register_from_file to reduce duplicate loading estimator object
134
+ # or use cachetools here
135
+ def load_estimator() -> object:
136
+ estimator_file_path = os.path.join(sys._xoptions["snowflake_import_directory"], f"{estimator_file_name}")
137
+ with open(estimator_file_path, mode="rb") as local_estimator_file_obj:
138
+ estimator_object = cp.load(local_estimator_file_obj)
139
+ return estimator_object
123
140
 
124
141
  @F.pandas_udf( # type: ignore[arg-type, misc]
125
142
  is_permanent=False,
@@ -129,6 +146,7 @@ class SnowparkTransformHandlers:
129
146
  session=session,
130
147
  statement_params=statement_params,
131
148
  input_types=[T.PandasDataFrameType(input_datatypes)],
149
+ imports=imports, # type: ignore[arg-type]
132
150
  )
133
151
  def vec_batch_infer(input_df: pd.DataFrame) -> T.PandasSeries[dict]: # type: ignore[type-arg]
134
152
  import numpy as np # noqa: F401
@@ -136,6 +154,8 @@ class SnowparkTransformHandlers:
136
154
 
137
155
  input_df.columns = snowpark_cols
138
156
 
157
+ estimator = load_estimator()
158
+
139
159
  if hasattr(estimator, "n_jobs"):
140
160
  # Vectorized UDF cannot handle joblib multiprocessing right now, deactivate the n_jobs
141
161
  estimator.n_jobs = 1
@@ -225,7 +245,7 @@ class SnowparkTransformHandlers:
225
245
  queries = dataset.queries["queries"]
226
246
 
227
247
  # Create a temp file and dump the score to that file.
228
- local_score_file_name = get_temp_file_path()
248
+ local_score_file_name = temp_file_utils.get_temp_file_path()
229
249
  with open(local_score_file_name, mode="w+b") as local_score_file:
230
250
  cp.dump(estimator, local_score_file)
231
251
 
@@ -247,7 +267,7 @@ class SnowparkTransformHandlers:
247
267
  inspect.currentframe(), self.__class__.__name__
248
268
  ),
249
269
  api_calls=[F.sproc],
250
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
270
+ custom_tags={"autogen": True} if self._autogenerated else None,
251
271
  )
252
272
  # Put locally serialized score on stage.
253
273
  session.file.put(
@@ -290,7 +310,7 @@ class SnowparkTransformHandlers:
290
310
  df: pd.DataFrame = sp_df.to_pandas(statement_params=score_statement_params)
291
311
  df.columns = sp_df.columns
292
312
 
293
- local_score_file_name = get_temp_file_path()
313
+ local_score_file_name = temp_file_utils.get_temp_file_path()
294
314
  session.file.get(stage_score_file_name, local_score_file_name, statement_params=score_statement_params)
295
315
 
296
316
  local_score_file_name_path = os.path.join(local_score_file_name, os.listdir(local_score_file_name)[0])
@@ -323,7 +343,7 @@ class SnowparkTransformHandlers:
323
343
  inspect.currentframe(), self.__class__.__name__
324
344
  ),
325
345
  api_calls=[Session.call],
326
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
346
+ custom_tags={"autogen": True} if self._autogenerated else None,
327
347
  )
328
348
 
329
349
  kwargs = telemetry.get_sproc_statement_params_kwargs(score_wrapper_sproc, score_statement_params)
@@ -338,7 +358,7 @@ class SnowparkTransformHandlers:
338
358
  **kwargs,
339
359
  )
340
360
 
341
- cleanup_temp_files([local_score_file_name])
361
+ temp_file_utils.cleanup_temp_files([local_score_file_name])
342
362
 
343
363
  return score
344
364
 
@@ -17,30 +17,19 @@ from snowflake.ml._internal.utils import (
17
17
  identifier,
18
18
  pkg_version_utils,
19
19
  snowpark_dataframe_utils,
20
+ temp_file_utils,
20
21
  )
21
- from snowflake.ml._internal.utils.query_result_checker import SqlResultValidator
22
- from snowflake.ml._internal.utils.temp_file_utils import (
23
- cleanup_temp_files,
24
- get_temp_file_path,
25
- )
22
+ from snowflake.ml.modeling._internal import estimator_utils
26
23
  from snowflake.ml.modeling._internal.estimator_utils import handle_inference_result
27
24
  from snowflake.ml.modeling._internal.model_specifications import (
28
25
  ModelSpecifications,
29
26
  ModelSpecificationsBuilder,
30
27
  )
31
- from snowflake.snowpark import (
32
- DataFrame,
33
- Session,
34
- exceptions as snowpark_exceptions,
35
- functions as F,
36
- )
37
- from snowflake.snowpark._internal.utils import (
38
- TempObjectType,
39
- random_name_for_temp_object,
40
- )
28
+ from snowflake.snowpark import DataFrame, Session, exceptions as snowpark_exceptions
29
+ from snowflake.snowpark._internal import utils as snowpark_utils
41
30
  from snowflake.snowpark.stored_procedure import StoredProcedure
42
31
 
43
- cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path))
32
+ cp.register_pickle_by_value(inspect.getmodule(temp_file_utils.get_temp_file_path))
44
33
  cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
45
34
  cp.register_pickle_by_value(inspect.getmodule(handle_inference_result))
46
35
 
@@ -90,60 +79,6 @@ class SnowparkModelTrainer:
90
79
  self._subproject = subproject
91
80
  self._class_name = estimator.__class__.__name__
92
81
 
93
- def _create_temp_stage(self) -> str:
94
- """
95
- Creates temporary stage.
96
-
97
- Returns:
98
- Temp stage name.
99
- """
100
- # Create temp stage to upload pickled model file.
101
- transform_stage_name = random_name_for_temp_object(TempObjectType.STAGE)
102
- stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
103
- SqlResultValidator(session=self.session, query=stage_creation_query).has_dimensions(
104
- expected_rows=1, expected_cols=1
105
- ).validate()
106
- return transform_stage_name
107
-
108
- def _upload_model_to_stage(self, stage_name: str) -> Tuple[str, str]:
109
- """
110
- Util method to pickle and upload the model to a temp Snowflake stage.
111
-
112
- Args:
113
- stage_name: Stage name to save model.
114
-
115
- Returns:
116
- a tuple containing stage file paths for pickled input model for training and location to store trained
117
- models(response from training sproc).
118
- """
119
- # Create a temp file and dump the transform to that file.
120
- local_transform_file_name = get_temp_file_path()
121
- with open(local_transform_file_name, mode="w+b") as local_transform_file:
122
- cp.dump(self.estimator, local_transform_file)
123
-
124
- # Use posixpath to construct stage paths
125
- stage_transform_file_name = posixpath.join(stage_name, os.path.basename(local_transform_file_name))
126
- stage_result_file_name = posixpath.join(stage_name, os.path.basename(local_transform_file_name))
127
-
128
- statement_params = telemetry.get_function_usage_statement_params(
129
- project=_PROJECT,
130
- subproject=self._subproject,
131
- function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
132
- api_calls=[F.sproc],
133
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
134
- )
135
- # Put locally serialized transform on stage.
136
- self.session.file.put(
137
- local_transform_file_name,
138
- stage_transform_file_name,
139
- auto_compress=False,
140
- overwrite=True,
141
- statement_params=statement_params,
142
- )
143
-
144
- cleanup_temp_files([local_transform_file_name])
145
- return (stage_transform_file_name, stage_result_file_name)
146
-
147
82
  def _fetch_model_from_stage(self, dir_path: str, file_name: str, statement_params: Dict[str, str]) -> object:
148
83
  """
149
84
  Downloads the serialized model from a stage location and unpickles it.
@@ -156,7 +91,7 @@ class SnowparkModelTrainer:
156
91
  Returns:
157
92
  Deserialized model object.
158
93
  """
159
- local_result_file_name = get_temp_file_path()
94
+ local_result_file_name = temp_file_utils.get_temp_file_path()
160
95
  self.session.file.get(
161
96
  posixpath.join(dir_path, file_name),
162
97
  local_result_file_name,
@@ -166,13 +101,13 @@ class SnowparkModelTrainer:
166
101
  with open(os.path.join(local_result_file_name, file_name), mode="r+b") as result_file_obj:
167
102
  fit_estimator = cp.load(result_file_obj)
168
103
 
169
- cleanup_temp_files([local_result_file_name])
104
+ temp_file_utils.cleanup_temp_files([local_result_file_name])
170
105
  return fit_estimator
171
106
 
172
107
  def _build_fit_wrapper_sproc(
173
108
  self,
174
109
  model_spec: ModelSpecifications,
175
- ) -> Callable[[Any, List[str], str, str, List[str], List[str], Optional[str], Dict[str, str]], str]:
110
+ ) -> Callable[[Any, List[str], str, List[str], List[str], Optional[str], Dict[str, str]], str]:
176
111
  """
177
112
  Constructs and returns a python stored procedure function to be used for training model.
178
113
 
@@ -188,8 +123,7 @@ class SnowparkModelTrainer:
188
123
  def fit_wrapper_function(
189
124
  session: Session,
190
125
  sql_queries: List[str],
191
- stage_transform_file_name: str,
192
- stage_result_file_name: str,
126
+ temp_stage_name: str,
193
127
  input_cols: List[str],
194
128
  label_cols: List[str],
195
129
  sample_weight_col: Optional[str],
@@ -212,9 +146,13 @@ class SnowparkModelTrainer:
212
146
  df: pd.DataFrame = sp_df.to_pandas(statement_params=statement_params)
213
147
  df.columns = sp_df.columns
214
148
 
215
- local_transform_file_name = get_temp_file_path()
149
+ local_transform_file_name = temp_file_utils.get_temp_file_path()
216
150
 
217
- session.file.get(stage_transform_file_name, local_transform_file_name, statement_params=statement_params)
151
+ session.file.get(
152
+ stage_location=temp_stage_name,
153
+ target_directory=local_transform_file_name,
154
+ statement_params=statement_params,
155
+ )
218
156
 
219
157
  local_transform_file_path = os.path.join(
220
158
  local_transform_file_name, os.listdir(local_transform_file_name)[0]
@@ -233,14 +171,14 @@ class SnowparkModelTrainer:
233
171
 
234
172
  estimator.fit(**args)
235
173
 
236
- local_result_file_name = get_temp_file_path()
174
+ local_result_file_name = temp_file_utils.get_temp_file_path()
237
175
 
238
176
  with open(local_result_file_name, mode="w+b") as local_result_file_obj:
239
177
  cp.dump(estimator, local_result_file_obj)
240
178
 
241
179
  session.file.put(
242
- local_result_file_name,
243
- stage_result_file_name,
180
+ local_file_name=local_result_file_name,
181
+ stage_location=temp_stage_name,
244
182
  auto_compress=False,
245
183
  overwrite=True,
246
184
  statement_params=statement_params,
@@ -254,7 +192,7 @@ class SnowparkModelTrainer:
254
192
 
255
193
  def _get_fit_wrapper_sproc_anonymous(self, statement_params: Dict[str, str]) -> StoredProcedure:
256
194
  model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
257
- fit_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
195
+ fit_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
258
196
 
259
197
  relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
260
198
  pkg_versions=model_spec.pkgDependencies, session=self.session
@@ -284,7 +222,7 @@ class SnowparkModelTrainer:
284
222
  fit_sproc: StoredProcedure = self.session._FIT_WRAPPER_SPROCS[fit_sproc_key] # type: ignore[attr-defined]
285
223
  return fit_sproc
286
224
 
287
- fit_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
225
+ fit_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
288
226
 
289
227
  relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
290
228
  pkg_versions=model_spec.pkgDependencies, session=self.session
@@ -307,7 +245,7 @@ class SnowparkModelTrainer:
307
245
  def _build_fit_predict_wrapper_sproc(
308
246
  self,
309
247
  model_spec: ModelSpecifications,
310
- ) -> Callable[[Session, List[str], str, str, List[str], Dict[str, str], bool, List[str], str], str]:
248
+ ) -> Callable[[Session, List[str], str, List[str], Dict[str, str], bool, List[str], str], str]:
311
249
  """
312
250
  Constructs and returns a python stored procedure function to be used for training model.
313
251
 
@@ -323,8 +261,7 @@ class SnowparkModelTrainer:
323
261
  def fit_predict_wrapper_function(
324
262
  session: Session,
325
263
  sql_queries: List[str],
326
- stage_transform_file_name: str,
327
- stage_result_file_name: str,
264
+ temp_stage_name: str,
328
265
  input_cols: List[str],
329
266
  statement_params: Dict[str, str],
330
267
  drop_input_cols: bool,
@@ -347,9 +284,13 @@ class SnowparkModelTrainer:
347
284
  df: pd.DataFrame = sp_df.to_pandas(statement_params=statement_params)
348
285
  df.columns = sp_df.columns
349
286
 
350
- local_transform_file_name = get_temp_file_path()
287
+ local_transform_file_name = temp_file_utils.get_temp_file_path()
351
288
 
352
- session.file.get(stage_transform_file_name, local_transform_file_name, statement_params=statement_params)
289
+ session.file.get(
290
+ stage_location=temp_stage_name,
291
+ target_directory=local_transform_file_name,
292
+ statement_params=statement_params,
293
+ )
353
294
 
354
295
  local_transform_file_path = os.path.join(
355
296
  local_transform_file_name, os.listdir(local_transform_file_name)[0]
@@ -359,14 +300,14 @@ class SnowparkModelTrainer:
359
300
 
360
301
  fit_predict_result = estimator.fit_predict(X=df[input_cols])
361
302
 
362
- local_result_file_name = get_temp_file_path()
303
+ local_result_file_name = temp_file_utils.get_temp_file_path()
363
304
 
364
305
  with open(local_result_file_name, mode="w+b") as local_result_file_obj:
365
306
  cp.dump(estimator, local_result_file_obj)
366
307
 
367
308
  session.file.put(
368
- local_result_file_name,
369
- stage_result_file_name,
309
+ local_file_name=local_result_file_name,
310
+ stage_location=temp_stage_name,
370
311
  auto_compress=False,
371
312
  overwrite=True,
372
313
  statement_params=statement_params,
@@ -407,7 +348,6 @@ class SnowparkModelTrainer:
407
348
  Session,
408
349
  List[str],
409
350
  str,
410
- str,
411
351
  List[str],
412
352
  Optional[List[str]],
413
353
  Optional[str],
@@ -433,8 +373,7 @@ class SnowparkModelTrainer:
433
373
  def fit_transform_wrapper_function(
434
374
  session: Session,
435
375
  sql_queries: List[str],
436
- stage_transform_file_name: str,
437
- stage_result_file_name: str,
376
+ temp_stage_name: str,
438
377
  input_cols: List[str],
439
378
  label_cols: Optional[List[str]],
440
379
  sample_weight_col: Optional[str],
@@ -459,9 +398,13 @@ class SnowparkModelTrainer:
459
398
  df: pd.DataFrame = sp_df.to_pandas(statement_params=statement_params)
460
399
  df.columns = sp_df.columns
461
400
 
462
- local_transform_file_name = get_temp_file_path()
401
+ local_transform_file_name = temp_file_utils.get_temp_file_path()
463
402
 
464
- session.file.get(stage_transform_file_name, local_transform_file_name, statement_params=statement_params)
403
+ session.file.get(
404
+ stage_location=temp_stage_name,
405
+ target_directory=local_transform_file_name,
406
+ statement_params=statement_params,
407
+ )
465
408
 
466
409
  local_transform_file_path = os.path.join(
467
410
  local_transform_file_name, os.listdir(local_transform_file_name)[0]
@@ -480,14 +423,14 @@ class SnowparkModelTrainer:
480
423
 
481
424
  fit_transform_result = estimator.fit_transform(**args)
482
425
 
483
- local_result_file_name = get_temp_file_path()
426
+ local_result_file_name = temp_file_utils.get_temp_file_path()
484
427
 
485
428
  with open(local_result_file_name, mode="w+b") as local_result_file_obj:
486
429
  cp.dump(estimator, local_result_file_obj)
487
430
 
488
431
  session.file.put(
489
- local_result_file_name,
490
- stage_result_file_name,
432
+ local_file_name=local_result_file_name,
433
+ stage_location=temp_stage_name,
491
434
  auto_compress=False,
492
435
  overwrite=True,
493
436
  statement_params=statement_params,
@@ -535,7 +478,7 @@ class SnowparkModelTrainer:
535
478
  def _get_fit_predict_wrapper_sproc_anonymous(self, statement_params: Dict[str, str]) -> StoredProcedure:
536
479
  model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
537
480
 
538
- fit_predict_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
481
+ fit_predict_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
539
482
 
540
483
  relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
541
484
  pkg_versions=model_spec.pkgDependencies, session=self.session
@@ -567,7 +510,7 @@ class SnowparkModelTrainer:
567
510
  ]
568
511
  return fit_sproc
569
512
 
570
- fit_predict_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
513
+ fit_predict_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
571
514
 
572
515
  relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
573
516
  pkg_versions=model_spec.pkgDependencies, session=self.session
@@ -592,7 +535,7 @@ class SnowparkModelTrainer:
592
535
  def _get_fit_transform_wrapper_sproc_anonymous(self, statement_params: Dict[str, str]) -> StoredProcedure:
593
536
  model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
594
537
 
595
- fit_transform_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
538
+ fit_transform_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
596
539
 
597
540
  relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
598
541
  pkg_versions=model_spec.pkgDependencies, session=self.session
@@ -623,7 +566,7 @@ class SnowparkModelTrainer:
623
566
  ]
624
567
  return fit_sproc
625
568
 
626
- fit_transform_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
569
+ fit_transform_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
627
570
 
628
571
  relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
629
572
  pkg_versions=model_spec.pkgDependencies, session=self.session
@@ -663,19 +606,21 @@ class SnowparkModelTrainer:
663
606
  # Extract query that generated the dataframe. We will need to pass it to the fit procedure.
664
607
  queries = dataset.queries["queries"]
665
608
 
666
- transform_stage_name = self._create_temp_stage()
667
- (stage_transform_file_name, stage_result_file_name) = self._upload_model_to_stage(
668
- stage_name=transform_stage_name
669
- )
670
-
671
- # Call fit sproc
609
+ temp_stage_name = estimator_utils.create_temp_stage(self.session)
672
610
  statement_params = telemetry.get_function_usage_statement_params(
673
611
  project=_PROJECT,
674
612
  subproject=self._subproject,
675
613
  function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
676
614
  api_calls=[Session.call],
677
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
615
+ custom_tags={"autogen": True} if self._autogenerated else None,
678
616
  )
617
+ estimator_utils.upload_model_to_stage(
618
+ stage_name=temp_stage_name,
619
+ estimator=self.estimator,
620
+ session=self.session,
621
+ statement_params=statement_params,
622
+ )
623
+ # Call fit sproc
679
624
 
680
625
  if _ENABLE_ANONYMOUS_SPROC:
681
626
  fit_wrapper_sproc = self._get_fit_wrapper_sproc_anonymous(statement_params=statement_params)
@@ -686,8 +631,7 @@ class SnowparkModelTrainer:
686
631
  sproc_export_file_name: str = fit_wrapper_sproc(
687
632
  self.session,
688
633
  queries,
689
- stage_transform_file_name,
690
- stage_result_file_name,
634
+ temp_stage_name,
691
635
  self.input_cols,
692
636
  self.label_cols,
693
637
  self.sample_weight_col,
@@ -706,7 +650,7 @@ class SnowparkModelTrainer:
706
650
  sproc_export_file_name = fields[0]
707
651
 
708
652
  return self._fetch_model_from_stage(
709
- dir_path=stage_result_file_name,
653
+ dir_path=temp_stage_name,
710
654
  file_name=sproc_export_file_name,
711
655
  statement_params=statement_params,
712
656
  )
@@ -734,32 +678,34 @@ class SnowparkModelTrainer:
734
678
  # Extract query that generated the dataframe. We will need to pass it to the fit procedure.
735
679
  queries = dataset.queries["queries"]
736
680
 
737
- transform_stage_name = self._create_temp_stage()
738
- (stage_transform_file_name, stage_result_file_name) = self._upload_model_to_stage(
739
- stage_name=transform_stage_name
740
- )
741
-
742
- # Call fit sproc
743
681
  statement_params = telemetry.get_function_usage_statement_params(
744
682
  project=_PROJECT,
745
683
  subproject=self._subproject,
746
684
  function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
747
685
  api_calls=[Session.call],
748
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
686
+ custom_tags={"autogen": True} if self._autogenerated else None,
749
687
  )
750
688
 
689
+ temp_stage_name = estimator_utils.create_temp_stage(self.session)
690
+ estimator_utils.upload_model_to_stage(
691
+ stage_name=temp_stage_name,
692
+ estimator=self.estimator,
693
+ session=self.session,
694
+ statement_params=statement_params,
695
+ )
696
+
697
+ # Call fit sproc
751
698
  if _ENABLE_ANONYMOUS_SPROC:
752
699
  fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc_anonymous(statement_params=statement_params)
753
700
  else:
754
701
  fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc(statement_params=statement_params)
755
702
 
756
- fit_predict_result_name = random_name_for_temp_object(TempObjectType.TABLE)
703
+ fit_predict_result_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
757
704
 
758
705
  sproc_export_file_name: str = fit_predict_wrapper_sproc(
759
706
  self.session,
760
707
  queries,
761
- stage_transform_file_name,
762
- stage_result_file_name,
708
+ temp_stage_name,
763
709
  self.input_cols,
764
710
  statement_params,
765
711
  drop_input_cols,
@@ -769,7 +715,7 @@ class SnowparkModelTrainer:
769
715
 
770
716
  output_result_sp = self.session.table(fit_predict_result_name)
771
717
  fitted_estimator = self._fetch_model_from_stage(
772
- dir_path=stage_result_file_name,
718
+ dir_path=temp_stage_name,
773
719
  file_name=sproc_export_file_name,
774
720
  statement_params=statement_params,
775
721
  )
@@ -799,20 +745,23 @@ class SnowparkModelTrainer:
799
745
  # Extract query that generated the dataframe. We will need to pass it to the fit procedure.
800
746
  queries = dataset.queries["queries"]
801
747
 
802
- transform_stage_name = self._create_temp_stage()
803
- (stage_transform_file_name, stage_result_file_name) = self._upload_model_to_stage(
804
- stage_name=transform_stage_name
805
- )
806
-
807
- # Call fit sproc
808
748
  statement_params = telemetry.get_function_usage_statement_params(
809
749
  project=_PROJECT,
810
750
  subproject=self._subproject,
811
751
  function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
812
752
  api_calls=[Session.call],
813
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
753
+ custom_tags={"autogen": True} if self._autogenerated else None,
754
+ )
755
+
756
+ temp_stage_name = estimator_utils.create_temp_stage(self.session)
757
+ estimator_utils.upload_model_to_stage(
758
+ stage_name=temp_stage_name,
759
+ estimator=self.estimator,
760
+ session=self.session,
761
+ statement_params=statement_params,
814
762
  )
815
763
 
764
+ # Call fit sproc
816
765
  if _ENABLE_ANONYMOUS_SPROC:
817
766
  fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc_anonymous(
818
767
  statement_params=statement_params
@@ -820,13 +769,12 @@ class SnowparkModelTrainer:
820
769
  else:
821
770
  fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc(statement_params=statement_params)
822
771
 
823
- fit_transform_result_name = random_name_for_temp_object(TempObjectType.TABLE)
772
+ fit_transform_result_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
824
773
 
825
774
  sproc_export_file_name: str = fit_transform_wrapper_sproc(
826
775
  self.session,
827
776
  queries,
828
- stage_transform_file_name,
829
- stage_result_file_name,
777
+ temp_stage_name,
830
778
  self.input_cols,
831
779
  self.label_cols,
832
780
  self.sample_weight_col,
@@ -838,7 +786,7 @@ class SnowparkModelTrainer:
838
786
 
839
787
  output_result_sp = self.session.table(fit_transform_result_name)
840
788
  fitted_estimator = self._fetch_model_from_stage(
841
- dir_path=stage_result_file_name,
789
+ dir_path=temp_stage_name,
842
790
  file_name=sproc_export_file_name,
843
791
  statement_params=statement_params,
844
792
  )