snowflake-ml-python 1.5.2__py3-none-any.whl → 1.5.4__py3-none-any.whl

This diff shows the differences between two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (250)
  1. snowflake/cortex/__init__.py +2 -1
  2. snowflake/cortex/_complete.py +240 -16
  3. snowflake/cortex/_extract_answer.py +0 -1
  4. snowflake/cortex/_sentiment.py +0 -1
  5. snowflake/cortex/_sse_client.py +81 -0
  6. snowflake/cortex/_summarize.py +0 -1
  7. snowflake/cortex/_translate.py +0 -1
  8. snowflake/cortex/_util.py +34 -10
  9. snowflake/ml/_internal/container_services/image_registry/http_client.py +10 -3
  10. snowflake/ml/_internal/container_services/image_registry/imagelib.py +23 -10
  11. snowflake/ml/_internal/container_services/image_registry/registry_client.py +7 -1
  12. snowflake/ml/_internal/exceptions/dataset_errors.py +7 -7
  13. snowflake/ml/_internal/exceptions/fileset_errors.py +3 -3
  14. snowflake/ml/_internal/exceptions/sql_error_codes.py +6 -0
  15. snowflake/ml/_internal/lineage/lineage_utils.py +34 -25
  16. snowflake/ml/_internal/telemetry.py +26 -0
  17. snowflake/ml/_internal/utils/identifier.py +14 -0
  18. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +15 -4
  19. snowflake/ml/dataset/dataset.py +54 -32
  20. snowflake/ml/dataset/dataset_factory.py +3 -4
  21. snowflake/ml/feature_store/feature_store.py +440 -243
  22. snowflake/ml/feature_store/feature_view.py +61 -9
  23. snowflake/ml/fileset/embedded_stage_fs.py +25 -21
  24. snowflake/ml/fileset/fileset.py +2 -2
  25. snowflake/ml/fileset/snowfs.py +4 -15
  26. snowflake/ml/fileset/stage_fs.py +6 -8
  27. snowflake/ml/lineage/__init__.py +3 -0
  28. snowflake/ml/lineage/lineage_node.py +139 -0
  29. snowflake/ml/model/_client/model/model_impl.py +47 -14
  30. snowflake/ml/model/_client/model/model_version_impl.py +82 -2
  31. snowflake/ml/model/_client/ops/model_ops.py +77 -5
  32. snowflake/ml/model/_client/sql/model.py +1 -0
  33. snowflake/ml/model/_client/sql/model_version.py +47 -4
  34. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
  35. snowflake/ml/model/_model_composer/model_composer.py +7 -6
  36. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +7 -1
  37. snowflake/ml/model/_model_composer/model_method/function_generator.py +17 -1
  38. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +79 -0
  39. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +5 -3
  40. snowflake/ml/model/_model_composer/model_method/model_method.py +5 -5
  41. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  42. snowflake/ml/model/_packager/model_handlers/_utils.py +1 -0
  43. snowflake/ml/model/_packager/model_handlers/catboost.py +2 -2
  44. snowflake/ml/model/_packager/model_handlers/custom.py +12 -4
  45. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +18 -15
  46. snowflake/ml/model/_packager/model_handlers/lightgbm.py +2 -2
  47. snowflake/ml/model/_packager/model_handlers/llm.py +2 -2
  48. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -2
  49. snowflake/ml/model/_packager/model_handlers/pytorch.py +2 -2
  50. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +2 -2
  51. snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -2
  52. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +2 -2
  53. snowflake/ml/model/_packager/model_handlers/tensorflow.py +2 -2
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
  56. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  57. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -0
  58. snowflake/ml/model/_packager/model_meta/model_meta.py +21 -1
  59. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
  60. snowflake/ml/model/_packager/model_packager.py +9 -4
  61. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  62. snowflake/ml/model/_signatures/builtins_handler.py +2 -1
  63. snowflake/ml/model/_signatures/core.py +13 -1
  64. snowflake/ml/model/_signatures/pandas_handler.py +2 -0
  65. snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
  66. snowflake/ml/model/custom_model.py +22 -2
  67. snowflake/ml/model/model_signature.py +2 -0
  68. snowflake/ml/model/type_hints.py +74 -4
  69. snowflake/ml/modeling/_internal/estimator_utils.py +58 -1
  70. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +158 -121
  71. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +2 -0
  72. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +39 -18
  73. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +88 -134
  74. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +22 -17
  75. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
  76. snowflake/ml/modeling/cluster/affinity_propagation.py +5 -3
  77. snowflake/ml/modeling/cluster/agglomerative_clustering.py +5 -3
  78. snowflake/ml/modeling/cluster/birch.py +5 -3
  79. snowflake/ml/modeling/cluster/bisecting_k_means.py +5 -3
  80. snowflake/ml/modeling/cluster/dbscan.py +5 -3
  81. snowflake/ml/modeling/cluster/feature_agglomeration.py +5 -3
  82. snowflake/ml/modeling/cluster/k_means.py +5 -3
  83. snowflake/ml/modeling/cluster/mean_shift.py +5 -3
  84. snowflake/ml/modeling/cluster/mini_batch_k_means.py +5 -3
  85. snowflake/ml/modeling/cluster/optics.py +5 -3
  86. snowflake/ml/modeling/cluster/spectral_biclustering.py +5 -3
  87. snowflake/ml/modeling/cluster/spectral_clustering.py +5 -3
  88. snowflake/ml/modeling/cluster/spectral_coclustering.py +5 -3
  89. snowflake/ml/modeling/compose/column_transformer.py +5 -3
  90. snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
  91. snowflake/ml/modeling/covariance/elliptic_envelope.py +5 -3
  92. snowflake/ml/modeling/covariance/empirical_covariance.py +5 -3
  93. snowflake/ml/modeling/covariance/graphical_lasso.py +5 -3
  94. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +5 -3
  95. snowflake/ml/modeling/covariance/ledoit_wolf.py +5 -3
  96. snowflake/ml/modeling/covariance/min_cov_det.py +5 -3
  97. snowflake/ml/modeling/covariance/oas.py +5 -3
  98. snowflake/ml/modeling/covariance/shrunk_covariance.py +5 -3
  99. snowflake/ml/modeling/decomposition/dictionary_learning.py +5 -3
  100. snowflake/ml/modeling/decomposition/factor_analysis.py +5 -3
  101. snowflake/ml/modeling/decomposition/fast_ica.py +5 -3
  102. snowflake/ml/modeling/decomposition/incremental_pca.py +5 -3
  103. snowflake/ml/modeling/decomposition/kernel_pca.py +5 -3
  104. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +5 -3
  105. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +5 -3
  106. snowflake/ml/modeling/decomposition/pca.py +5 -3
  107. snowflake/ml/modeling/decomposition/sparse_pca.py +5 -3
  108. snowflake/ml/modeling/decomposition/truncated_svd.py +5 -3
  109. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
  110. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
  111. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
  112. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
  113. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
  114. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
  115. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
  116. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
  117. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
  118. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
  119. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
  120. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
  121. snowflake/ml/modeling/ensemble/isolation_forest.py +5 -3
  122. snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
  123. snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
  124. snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
  125. snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
  126. snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
  127. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
  128. snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
  129. snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
  130. snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
  131. snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
  132. snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
  133. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +5 -3
  134. snowflake/ml/modeling/feature_selection/variance_threshold.py +5 -3
  135. snowflake/ml/modeling/framework/base.py +3 -8
  136. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
  137. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
  138. snowflake/ml/modeling/impute/iterative_imputer.py +5 -3
  139. snowflake/ml/modeling/impute/knn_imputer.py +5 -3
  140. snowflake/ml/modeling/impute/missing_indicator.py +5 -3
  141. snowflake/ml/modeling/impute/simple_imputer.py +8 -4
  142. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +5 -3
  143. snowflake/ml/modeling/kernel_approximation/nystroem.py +5 -3
  144. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +5 -3
  145. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +5 -3
  146. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +5 -3
  147. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
  148. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
  149. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
  150. snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
  151. snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
  152. snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
  153. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  154. snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
  155. snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
  156. snowflake/ml/modeling/linear_model/lars.py +1 -1
  157. snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
  158. snowflake/ml/modeling/linear_model/lasso.py +1 -1
  159. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  160. snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
  161. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
  162. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
  163. snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
  164. snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
  165. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
  166. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
  167. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  168. snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
  169. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  170. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
  171. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
  172. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  173. snowflake/ml/modeling/linear_model/perceptron.py +1 -1
  174. snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
  175. snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
  176. snowflake/ml/modeling/linear_model/ridge.py +1 -1
  177. snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
  178. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
  179. snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
  180. snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
  181. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +5 -3
  182. snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
  183. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
  184. snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
  185. snowflake/ml/modeling/manifold/isomap.py +5 -3
  186. snowflake/ml/modeling/manifold/mds.py +5 -3
  187. snowflake/ml/modeling/manifold/spectral_embedding.py +5 -3
  188. snowflake/ml/modeling/manifold/tsne.py +5 -3
  189. snowflake/ml/modeling/metrics/ranking.py +3 -0
  190. snowflake/ml/modeling/metrics/regression.py +3 -0
  191. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +5 -3
  192. snowflake/ml/modeling/mixture/gaussian_mixture.py +5 -3
  193. snowflake/ml/modeling/model_selection/grid_search_cv.py +1 -5
  194. snowflake/ml/modeling/model_selection/randomized_search_cv.py +1 -5
  195. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
  196. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
  197. snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
  198. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
  199. snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
  200. snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
  201. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
  202. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
  203. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
  204. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
  205. snowflake/ml/modeling/neighbors/kernel_density.py +5 -3
  206. snowflake/ml/modeling/neighbors/local_outlier_factor.py +5 -3
  207. snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
  208. snowflake/ml/modeling/neighbors/nearest_neighbors.py +5 -3
  209. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
  210. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
  211. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  212. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +5 -3
  213. snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
  214. snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
  215. snowflake/ml/modeling/pipeline/pipeline.py +6 -0
  216. snowflake/ml/modeling/preprocessing/binarizer.py +7 -3
  217. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +7 -2
  218. snowflake/ml/modeling/preprocessing/label_encoder.py +8 -7
  219. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +7 -3
  220. snowflake/ml/modeling/preprocessing/min_max_scaler.py +7 -4
  221. snowflake/ml/modeling/preprocessing/normalizer.py +7 -3
  222. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +53 -11
  223. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +44 -13
  224. snowflake/ml/modeling/preprocessing/polynomial_features.py +5 -3
  225. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -4
  226. snowflake/ml/modeling/preprocessing/standard_scaler.py +7 -3
  227. snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
  228. snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
  229. snowflake/ml/modeling/svm/linear_svc.py +1 -1
  230. snowflake/ml/modeling/svm/linear_svr.py +1 -1
  231. snowflake/ml/modeling/svm/nu_svc.py +1 -1
  232. snowflake/ml/modeling/svm/nu_svr.py +1 -1
  233. snowflake/ml/modeling/svm/svc.py +1 -1
  234. snowflake/ml/modeling/svm/svr.py +1 -1
  235. snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
  236. snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
  237. snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
  238. snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
  239. snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
  240. snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
  241. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
  242. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
  243. snowflake/ml/registry/_manager/model_manager.py +16 -3
  244. snowflake/ml/version.py +1 -1
  245. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/METADATA +51 -7
  246. snowflake_ml_python-1.5.4.dist-info/RECORD +389 -0
  247. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/WHEEL +1 -1
  248. snowflake_ml_python-1.5.2.dist-info/RECORD +0 -384
  249. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/LICENSE.txt +0 -0
  250. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/top_level.txt +0 -0
@@ -2,6 +2,7 @@ import importlib
2
2
  import inspect
3
3
  import os
4
4
  import posixpath
5
+ import sys
5
6
  from typing import Any, Dict, List, Optional
6
7
  from uuid import uuid4
7
8
 
@@ -13,12 +14,10 @@ from snowflake.ml._internal.utils import (
13
14
  identifier,
14
15
  pkg_version_utils,
15
16
  snowpark_dataframe_utils,
17
+ temp_file_utils,
16
18
  )
17
19
  from snowflake.ml._internal.utils.query_result_checker import SqlResultValidator
18
- from snowflake.ml._internal.utils.temp_file_utils import (
19
- cleanup_temp_files,
20
- get_temp_file_path,
21
- )
20
+ from snowflake.ml.modeling._internal import estimator_utils
22
21
  from snowflake.ml.modeling._internal.estimator_utils import handle_inference_result
23
22
  from snowflake.snowpark import DataFrame, Session, functions as F, types as T
24
23
  from snowflake.snowpark._internal.utils import (
@@ -26,7 +25,7 @@ from snowflake.snowpark._internal.utils import (
26
25
  random_name_for_temp_object,
27
26
  )
28
27
 
29
- cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path))
28
+ cp.register_pickle_by_value(inspect.getmodule(temp_file_utils.get_temp_file_path))
30
29
  cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
31
30
  cp.register_pickle_by_value(inspect.getmodule(handle_inference_result))
32
31
 
@@ -97,7 +96,25 @@ class SnowparkTransformHandlers:
97
96
 
98
97
  dependencies = self._get_validated_snowpark_dependencies(session, dependencies)
99
98
  dataset = self.dataset
100
- estimator = self.estimator
99
+
100
+ statement_params = telemetry.get_function_usage_statement_params(
101
+ project=_PROJECT,
102
+ subproject=self._subproject,
103
+ function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
104
+ api_calls=[F.pandas_udf],
105
+ custom_tags={"autogen": True} if self._autogenerated else None,
106
+ )
107
+
108
+ temp_stage_name = estimator_utils.create_temp_stage(session)
109
+
110
+ estimator_file_name = estimator_utils.upload_model_to_stage(
111
+ stage_name=temp_stage_name,
112
+ estimator=self.estimator,
113
+ session=session,
114
+ statement_params=statement_params,
115
+ )
116
+ imports = [f"@{temp_stage_name}/{estimator_file_name}"]
117
+
101
118
  # Register vectorized UDF for batch inference
102
119
  batch_inference_udf_name = random_name_for_temp_object(TempObjectType.FUNCTION)
103
120
 
@@ -113,13 +130,13 @@ class SnowparkTransformHandlers:
113
130
  for field in fields:
114
131
  input_datatypes.append(field.datatype)
115
132
 
116
- statement_params = telemetry.get_function_usage_statement_params(
117
- project=_PROJECT,
118
- subproject=self._subproject,
119
- function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
120
- api_calls=[F.pandas_udf],
121
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
122
- )
133
+ # TODO(xjiang): for optimization, use register_from_file to reduce duplicate loading estimator object
134
+ # or use cachetools here
135
+ def load_estimator() -> object:
136
+ estimator_file_path = os.path.join(sys._xoptions["snowflake_import_directory"], f"{estimator_file_name}")
137
+ with open(estimator_file_path, mode="rb") as local_estimator_file_obj:
138
+ estimator_object = cp.load(local_estimator_file_obj)
139
+ return estimator_object
123
140
 
124
141
  @F.pandas_udf( # type: ignore[arg-type, misc]
125
142
  is_permanent=False,
@@ -129,6 +146,7 @@ class SnowparkTransformHandlers:
129
146
  session=session,
130
147
  statement_params=statement_params,
131
148
  input_types=[T.PandasDataFrameType(input_datatypes)],
149
+ imports=imports, # type: ignore[arg-type]
132
150
  )
133
151
  def vec_batch_infer(input_df: pd.DataFrame) -> T.PandasSeries[dict]: # type: ignore[type-arg]
134
152
  import numpy as np # noqa: F401
@@ -136,6 +154,8 @@ class SnowparkTransformHandlers:
136
154
 
137
155
  input_df.columns = snowpark_cols
138
156
 
157
+ estimator = load_estimator()
158
+
139
159
  if hasattr(estimator, "n_jobs"):
140
160
  # Vectorized UDF cannot handle joblib multiprocessing right now, deactivate the n_jobs
141
161
  estimator.n_jobs = 1
@@ -225,7 +245,7 @@ class SnowparkTransformHandlers:
225
245
  queries = dataset.queries["queries"]
226
246
 
227
247
  # Create a temp file and dump the score to that file.
228
- local_score_file_name = get_temp_file_path()
248
+ local_score_file_name = temp_file_utils.get_temp_file_path()
229
249
  with open(local_score_file_name, mode="w+b") as local_score_file:
230
250
  cp.dump(estimator, local_score_file)
231
251
 
@@ -247,7 +267,7 @@ class SnowparkTransformHandlers:
247
267
  inspect.currentframe(), self.__class__.__name__
248
268
  ),
249
269
  api_calls=[F.sproc],
250
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
270
+ custom_tags={"autogen": True} if self._autogenerated else None,
251
271
  )
252
272
  # Put locally serialized score on stage.
253
273
  session.file.put(
@@ -266,6 +286,7 @@ class SnowparkTransformHandlers:
266
286
  session=session,
267
287
  statement_params=statement_params,
268
288
  anonymous=True,
289
+ execute_as="caller",
269
290
  )
270
291
  def score_wrapper_sproc(
271
292
  session: Session,
@@ -290,7 +311,7 @@ class SnowparkTransformHandlers:
290
311
  df: pd.DataFrame = sp_df.to_pandas(statement_params=score_statement_params)
291
312
  df.columns = sp_df.columns
292
313
 
293
- local_score_file_name = get_temp_file_path()
314
+ local_score_file_name = temp_file_utils.get_temp_file_path()
294
315
  session.file.get(stage_score_file_name, local_score_file_name, statement_params=score_statement_params)
295
316
 
296
317
  local_score_file_name_path = os.path.join(local_score_file_name, os.listdir(local_score_file_name)[0])
@@ -323,7 +344,7 @@ class SnowparkTransformHandlers:
323
344
  inspect.currentframe(), self.__class__.__name__
324
345
  ),
325
346
  api_calls=[Session.call],
326
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
347
+ custom_tags={"autogen": True} if self._autogenerated else None,
327
348
  )
328
349
 
329
350
  kwargs = telemetry.get_sproc_statement_params_kwargs(score_wrapper_sproc, score_statement_params)
@@ -338,7 +359,7 @@ class SnowparkTransformHandlers:
338
359
  **kwargs,
339
360
  )
340
361
 
341
- cleanup_temp_files([local_score_file_name])
362
+ temp_file_utils.cleanup_temp_files([local_score_file_name])
342
363
 
343
364
  return score
344
365
 
@@ -17,30 +17,19 @@ from snowflake.ml._internal.utils import (
17
17
  identifier,
18
18
  pkg_version_utils,
19
19
  snowpark_dataframe_utils,
20
+ temp_file_utils,
20
21
  )
21
- from snowflake.ml._internal.utils.query_result_checker import SqlResultValidator
22
- from snowflake.ml._internal.utils.temp_file_utils import (
23
- cleanup_temp_files,
24
- get_temp_file_path,
25
- )
22
+ from snowflake.ml.modeling._internal import estimator_utils
26
23
  from snowflake.ml.modeling._internal.estimator_utils import handle_inference_result
27
24
  from snowflake.ml.modeling._internal.model_specifications import (
28
25
  ModelSpecifications,
29
26
  ModelSpecificationsBuilder,
30
27
  )
31
- from snowflake.snowpark import (
32
- DataFrame,
33
- Session,
34
- exceptions as snowpark_exceptions,
35
- functions as F,
36
- )
37
- from snowflake.snowpark._internal.utils import (
38
- TempObjectType,
39
- random_name_for_temp_object,
40
- )
28
+ from snowflake.snowpark import DataFrame, Session, exceptions as snowpark_exceptions
29
+ from snowflake.snowpark._internal import utils as snowpark_utils
41
30
  from snowflake.snowpark.stored_procedure import StoredProcedure
42
31
 
43
- cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path))
32
+ cp.register_pickle_by_value(inspect.getmodule(temp_file_utils.get_temp_file_path))
44
33
  cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
45
34
  cp.register_pickle_by_value(inspect.getmodule(handle_inference_result))
46
35
 
@@ -90,60 +79,6 @@ class SnowparkModelTrainer:
90
79
  self._subproject = subproject
91
80
  self._class_name = estimator.__class__.__name__
92
81
 
93
- def _create_temp_stage(self) -> str:
94
- """
95
- Creates temporary stage.
96
-
97
- Returns:
98
- Temp stage name.
99
- """
100
- # Create temp stage to upload pickled model file.
101
- transform_stage_name = random_name_for_temp_object(TempObjectType.STAGE)
102
- stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
103
- SqlResultValidator(session=self.session, query=stage_creation_query).has_dimensions(
104
- expected_rows=1, expected_cols=1
105
- ).validate()
106
- return transform_stage_name
107
-
108
- def _upload_model_to_stage(self, stage_name: str) -> Tuple[str, str]:
109
- """
110
- Util method to pickle and upload the model to a temp Snowflake stage.
111
-
112
- Args:
113
- stage_name: Stage name to save model.
114
-
115
- Returns:
116
- a tuple containing stage file paths for pickled input model for training and location to store trained
117
- models(response from training sproc).
118
- """
119
- # Create a temp file and dump the transform to that file.
120
- local_transform_file_name = get_temp_file_path()
121
- with open(local_transform_file_name, mode="w+b") as local_transform_file:
122
- cp.dump(self.estimator, local_transform_file)
123
-
124
- # Use posixpath to construct stage paths
125
- stage_transform_file_name = posixpath.join(stage_name, os.path.basename(local_transform_file_name))
126
- stage_result_file_name = posixpath.join(stage_name, os.path.basename(local_transform_file_name))
127
-
128
- statement_params = telemetry.get_function_usage_statement_params(
129
- project=_PROJECT,
130
- subproject=self._subproject,
131
- function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
132
- api_calls=[F.sproc],
133
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
134
- )
135
- # Put locally serialized transform on stage.
136
- self.session.file.put(
137
- local_transform_file_name,
138
- stage_transform_file_name,
139
- auto_compress=False,
140
- overwrite=True,
141
- statement_params=statement_params,
142
- )
143
-
144
- cleanup_temp_files([local_transform_file_name])
145
- return (stage_transform_file_name, stage_result_file_name)
146
-
147
82
  def _fetch_model_from_stage(self, dir_path: str, file_name: str, statement_params: Dict[str, str]) -> object:
148
83
  """
149
84
  Downloads the serialized model from a stage location and unpickles it.
@@ -156,7 +91,7 @@ class SnowparkModelTrainer:
156
91
  Returns:
157
92
  Deserialized model object.
158
93
  """
159
- local_result_file_name = get_temp_file_path()
94
+ local_result_file_name = temp_file_utils.get_temp_file_path()
160
95
  self.session.file.get(
161
96
  posixpath.join(dir_path, file_name),
162
97
  local_result_file_name,
@@ -166,13 +101,13 @@ class SnowparkModelTrainer:
166
101
  with open(os.path.join(local_result_file_name, file_name), mode="r+b") as result_file_obj:
167
102
  fit_estimator = cp.load(result_file_obj)
168
103
 
169
- cleanup_temp_files([local_result_file_name])
104
+ temp_file_utils.cleanup_temp_files([local_result_file_name])
170
105
  return fit_estimator
171
106
 
172
107
  def _build_fit_wrapper_sproc(
173
108
  self,
174
109
  model_spec: ModelSpecifications,
175
- ) -> Callable[[Any, List[str], str, str, List[str], List[str], Optional[str], Dict[str, str]], str]:
110
+ ) -> Callable[[Any, List[str], str, List[str], List[str], Optional[str], Dict[str, str]], str]:
176
111
  """
177
112
  Constructs and returns a python stored procedure function to be used for training model.
178
113
 
@@ -188,8 +123,7 @@ class SnowparkModelTrainer:
188
123
  def fit_wrapper_function(
189
124
  session: Session,
190
125
  sql_queries: List[str],
191
- stage_transform_file_name: str,
192
- stage_result_file_name: str,
126
+ temp_stage_name: str,
193
127
  input_cols: List[str],
194
128
  label_cols: List[str],
195
129
  sample_weight_col: Optional[str],
@@ -212,9 +146,13 @@ class SnowparkModelTrainer:
212
146
  df: pd.DataFrame = sp_df.to_pandas(statement_params=statement_params)
213
147
  df.columns = sp_df.columns
214
148
 
215
- local_transform_file_name = get_temp_file_path()
149
+ local_transform_file_name = temp_file_utils.get_temp_file_path()
216
150
 
217
- session.file.get(stage_transform_file_name, local_transform_file_name, statement_params=statement_params)
151
+ session.file.get(
152
+ stage_location=temp_stage_name,
153
+ target_directory=local_transform_file_name,
154
+ statement_params=statement_params,
155
+ )
218
156
 
219
157
  local_transform_file_path = os.path.join(
220
158
  local_transform_file_name, os.listdir(local_transform_file_name)[0]
@@ -233,14 +171,14 @@ class SnowparkModelTrainer:
233
171
 
234
172
  estimator.fit(**args)
235
173
 
236
- local_result_file_name = get_temp_file_path()
174
+ local_result_file_name = temp_file_utils.get_temp_file_path()
237
175
 
238
176
  with open(local_result_file_name, mode="w+b") as local_result_file_obj:
239
177
  cp.dump(estimator, local_result_file_obj)
240
178
 
241
179
  session.file.put(
242
- local_result_file_name,
243
- stage_result_file_name,
180
+ local_file_name=local_result_file_name,
181
+ stage_location=temp_stage_name,
244
182
  auto_compress=False,
245
183
  overwrite=True,
246
184
  statement_params=statement_params,
@@ -254,7 +192,7 @@ class SnowparkModelTrainer:
254
192
 
255
193
  def _get_fit_wrapper_sproc_anonymous(self, statement_params: Dict[str, str]) -> StoredProcedure:
256
194
  model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
257
- fit_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
195
+ fit_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
258
196
 
259
197
  relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
260
198
  pkg_versions=model_spec.pkgDependencies, session=self.session
@@ -269,6 +207,7 @@ class SnowparkModelTrainer:
269
207
  session=self.session,
270
208
  statement_params=statement_params,
271
209
  anonymous=True,
210
+ execute_as="caller",
272
211
  )
273
212
 
274
213
  return fit_wrapper_sproc
@@ -284,7 +223,7 @@ class SnowparkModelTrainer:
284
223
  fit_sproc: StoredProcedure = self.session._FIT_WRAPPER_SPROCS[fit_sproc_key] # type: ignore[attr-defined]
285
224
  return fit_sproc
286
225
 
287
- fit_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
226
+ fit_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
288
227
 
289
228
  relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
290
229
  pkg_versions=model_spec.pkgDependencies, session=self.session
@@ -298,6 +237,7 @@ class SnowparkModelTrainer:
298
237
  replace=True,
299
238
  session=self.session,
300
239
  statement_params=statement_params,
240
+ execute_as="caller",
301
241
  )
302
242
 
303
243
  self.session._FIT_WRAPPER_SPROCS[fit_sproc_key] = fit_wrapper_sproc # type: ignore[attr-defined]
@@ -307,7 +247,7 @@ class SnowparkModelTrainer:
307
247
  def _build_fit_predict_wrapper_sproc(
308
248
  self,
309
249
  model_spec: ModelSpecifications,
310
- ) -> Callable[[Session, List[str], str, str, List[str], Dict[str, str], bool, List[str], str], str]:
250
+ ) -> Callable[[Session, List[str], str, List[str], Dict[str, str], bool, List[str], str], str]:
311
251
  """
312
252
  Constructs and returns a python stored procedure function to be used for training model.
313
253
 
@@ -323,8 +263,7 @@ class SnowparkModelTrainer:
323
263
  def fit_predict_wrapper_function(
324
264
  session: Session,
325
265
  sql_queries: List[str],
326
- stage_transform_file_name: str,
327
- stage_result_file_name: str,
266
+ temp_stage_name: str,
328
267
  input_cols: List[str],
329
268
  statement_params: Dict[str, str],
330
269
  drop_input_cols: bool,
@@ -347,9 +286,13 @@ class SnowparkModelTrainer:
347
286
  df: pd.DataFrame = sp_df.to_pandas(statement_params=statement_params)
348
287
  df.columns = sp_df.columns
349
288
 
350
- local_transform_file_name = get_temp_file_path()
289
+ local_transform_file_name = temp_file_utils.get_temp_file_path()
351
290
 
352
- session.file.get(stage_transform_file_name, local_transform_file_name, statement_params=statement_params)
291
+ session.file.get(
292
+ stage_location=temp_stage_name,
293
+ target_directory=local_transform_file_name,
294
+ statement_params=statement_params,
295
+ )
353
296
 
354
297
  local_transform_file_path = os.path.join(
355
298
  local_transform_file_name, os.listdir(local_transform_file_name)[0]
@@ -359,14 +302,14 @@ class SnowparkModelTrainer:
359
302
 
360
303
  fit_predict_result = estimator.fit_predict(X=df[input_cols])
361
304
 
362
- local_result_file_name = get_temp_file_path()
305
+ local_result_file_name = temp_file_utils.get_temp_file_path()
363
306
 
364
307
  with open(local_result_file_name, mode="w+b") as local_result_file_obj:
365
308
  cp.dump(estimator, local_result_file_obj)
366
309
 
367
310
  session.file.put(
368
- local_result_file_name,
369
- stage_result_file_name,
311
+ local_file_name=local_result_file_name,
312
+ stage_location=temp_stage_name,
370
313
  auto_compress=False,
371
314
  overwrite=True,
372
315
  statement_params=statement_params,
@@ -407,7 +350,6 @@ class SnowparkModelTrainer:
407
350
  Session,
408
351
  List[str],
409
352
  str,
410
- str,
411
353
  List[str],
412
354
  Optional[List[str]],
413
355
  Optional[str],
@@ -433,8 +375,7 @@ class SnowparkModelTrainer:
433
375
  def fit_transform_wrapper_function(
434
376
  session: Session,
435
377
  sql_queries: List[str],
436
- stage_transform_file_name: str,
437
- stage_result_file_name: str,
378
+ temp_stage_name: str,
438
379
  input_cols: List[str],
439
380
  label_cols: Optional[List[str]],
440
381
  sample_weight_col: Optional[str],
@@ -459,9 +400,13 @@ class SnowparkModelTrainer:
459
400
  df: pd.DataFrame = sp_df.to_pandas(statement_params=statement_params)
460
401
  df.columns = sp_df.columns
461
402
 
462
- local_transform_file_name = get_temp_file_path()
403
+ local_transform_file_name = temp_file_utils.get_temp_file_path()
463
404
 
464
- session.file.get(stage_transform_file_name, local_transform_file_name, statement_params=statement_params)
405
+ session.file.get(
406
+ stage_location=temp_stage_name,
407
+ target_directory=local_transform_file_name,
408
+ statement_params=statement_params,
409
+ )
465
410
 
466
411
  local_transform_file_path = os.path.join(
467
412
  local_transform_file_name, os.listdir(local_transform_file_name)[0]
@@ -480,14 +425,14 @@ class SnowparkModelTrainer:
480
425
 
481
426
  fit_transform_result = estimator.fit_transform(**args)
482
427
 
483
- local_result_file_name = get_temp_file_path()
428
+ local_result_file_name = temp_file_utils.get_temp_file_path()
484
429
 
485
430
  with open(local_result_file_name, mode="w+b") as local_result_file_obj:
486
431
  cp.dump(estimator, local_result_file_obj)
487
432
 
488
433
  session.file.put(
489
- local_result_file_name,
490
- stage_result_file_name,
434
+ local_file_name=local_result_file_name,
435
+ stage_location=temp_stage_name,
491
436
  auto_compress=False,
492
437
  overwrite=True,
493
438
  statement_params=statement_params,
@@ -535,7 +480,7 @@ class SnowparkModelTrainer:
535
480
  def _get_fit_predict_wrapper_sproc_anonymous(self, statement_params: Dict[str, str]) -> StoredProcedure:
536
481
  model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
537
482
 
538
- fit_predict_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
483
+ fit_predict_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
539
484
 
540
485
  relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
541
486
  pkg_versions=model_spec.pkgDependencies, session=self.session
@@ -550,6 +495,7 @@ class SnowparkModelTrainer:
550
495
  session=self.session,
551
496
  statement_params=statement_params,
552
497
  anonymous=True,
498
+ execute_as="caller",
553
499
  )
554
500
 
555
501
  return fit_predict_wrapper_sproc
@@ -567,7 +513,7 @@ class SnowparkModelTrainer:
567
513
  ]
568
514
  return fit_sproc
569
515
 
570
- fit_predict_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
516
+ fit_predict_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
571
517
 
572
518
  relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
573
519
  pkg_versions=model_spec.pkgDependencies, session=self.session
@@ -581,6 +527,7 @@ class SnowparkModelTrainer:
581
527
  replace=True,
582
528
  session=self.session,
583
529
  statement_params=statement_params,
530
+ execute_as="caller",
584
531
  )
585
532
 
586
533
  self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
@@ -592,7 +539,7 @@ class SnowparkModelTrainer:
592
539
  def _get_fit_transform_wrapper_sproc_anonymous(self, statement_params: Dict[str, str]) -> StoredProcedure:
593
540
  model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
594
541
 
595
- fit_transform_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
542
+ fit_transform_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
596
543
 
597
544
  relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
598
545
  pkg_versions=model_spec.pkgDependencies, session=self.session
@@ -607,6 +554,7 @@ class SnowparkModelTrainer:
607
554
  session=self.session,
608
555
  statement_params=statement_params,
609
556
  anonymous=True,
557
+ execute_as="caller",
610
558
  )
611
559
  return fit_transform_wrapper_sproc
612
560
 
@@ -623,7 +571,7 @@ class SnowparkModelTrainer:
623
571
  ]
624
572
  return fit_sproc
625
573
 
626
- fit_transform_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
574
+ fit_transform_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
627
575
 
628
576
  relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
629
577
  pkg_versions=model_spec.pkgDependencies, session=self.session
@@ -637,6 +585,7 @@ class SnowparkModelTrainer:
637
585
  replace=True,
638
586
  session=self.session,
639
587
  statement_params=statement_params,
588
+ execute_as="caller",
640
589
  )
641
590
 
642
591
  self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
@@ -663,19 +612,21 @@ class SnowparkModelTrainer:
663
612
  # Extract query that generated the dataframe. We will need to pass it to the fit procedure.
664
613
  queries = dataset.queries["queries"]
665
614
 
666
- transform_stage_name = self._create_temp_stage()
667
- (stage_transform_file_name, stage_result_file_name) = self._upload_model_to_stage(
668
- stage_name=transform_stage_name
669
- )
670
-
671
- # Call fit sproc
615
+ temp_stage_name = estimator_utils.create_temp_stage(self.session)
672
616
  statement_params = telemetry.get_function_usage_statement_params(
673
617
  project=_PROJECT,
674
618
  subproject=self._subproject,
675
619
  function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
676
620
  api_calls=[Session.call],
677
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
621
+ custom_tags={"autogen": True} if self._autogenerated else None,
678
622
  )
623
+ estimator_utils.upload_model_to_stage(
624
+ stage_name=temp_stage_name,
625
+ estimator=self.estimator,
626
+ session=self.session,
627
+ statement_params=statement_params,
628
+ )
629
+ # Call fit sproc
679
630
 
680
631
  if _ENABLE_ANONYMOUS_SPROC:
681
632
  fit_wrapper_sproc = self._get_fit_wrapper_sproc_anonymous(statement_params=statement_params)
@@ -686,8 +637,7 @@ class SnowparkModelTrainer:
686
637
  sproc_export_file_name: str = fit_wrapper_sproc(
687
638
  self.session,
688
639
  queries,
689
- stage_transform_file_name,
690
- stage_result_file_name,
640
+ temp_stage_name,
691
641
  self.input_cols,
692
642
  self.label_cols,
693
643
  self.sample_weight_col,
@@ -706,7 +656,7 @@ class SnowparkModelTrainer:
706
656
  sproc_export_file_name = fields[0]
707
657
 
708
658
  return self._fetch_model_from_stage(
709
- dir_path=stage_result_file_name,
659
+ dir_path=temp_stage_name,
710
660
  file_name=sproc_export_file_name,
711
661
  statement_params=statement_params,
712
662
  )
@@ -734,32 +684,34 @@ class SnowparkModelTrainer:
734
684
  # Extract query that generated the dataframe. We will need to pass it to the fit procedure.
735
685
  queries = dataset.queries["queries"]
736
686
 
737
- transform_stage_name = self._create_temp_stage()
738
- (stage_transform_file_name, stage_result_file_name) = self._upload_model_to_stage(
739
- stage_name=transform_stage_name
740
- )
741
-
742
- # Call fit sproc
743
687
  statement_params = telemetry.get_function_usage_statement_params(
744
688
  project=_PROJECT,
745
689
  subproject=self._subproject,
746
690
  function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
747
691
  api_calls=[Session.call],
748
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
692
+ custom_tags={"autogen": True} if self._autogenerated else None,
749
693
  )
750
694
 
695
+ temp_stage_name = estimator_utils.create_temp_stage(self.session)
696
+ estimator_utils.upload_model_to_stage(
697
+ stage_name=temp_stage_name,
698
+ estimator=self.estimator,
699
+ session=self.session,
700
+ statement_params=statement_params,
701
+ )
702
+
703
+ # Call fit sproc
751
704
  if _ENABLE_ANONYMOUS_SPROC:
752
705
  fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc_anonymous(statement_params=statement_params)
753
706
  else:
754
707
  fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc(statement_params=statement_params)
755
708
 
756
- fit_predict_result_name = random_name_for_temp_object(TempObjectType.TABLE)
709
+ fit_predict_result_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
757
710
 
758
711
  sproc_export_file_name: str = fit_predict_wrapper_sproc(
759
712
  self.session,
760
713
  queries,
761
- stage_transform_file_name,
762
- stage_result_file_name,
714
+ temp_stage_name,
763
715
  self.input_cols,
764
716
  statement_params,
765
717
  drop_input_cols,
@@ -769,7 +721,7 @@ class SnowparkModelTrainer:
769
721
 
770
722
  output_result_sp = self.session.table(fit_predict_result_name)
771
723
  fitted_estimator = self._fetch_model_from_stage(
772
- dir_path=stage_result_file_name,
724
+ dir_path=temp_stage_name,
773
725
  file_name=sproc_export_file_name,
774
726
  statement_params=statement_params,
775
727
  )
@@ -799,20 +751,23 @@ class SnowparkModelTrainer:
799
751
  # Extract query that generated the dataframe. We will need to pass it to the fit procedure.
800
752
  queries = dataset.queries["queries"]
801
753
 
802
- transform_stage_name = self._create_temp_stage()
803
- (stage_transform_file_name, stage_result_file_name) = self._upload_model_to_stage(
804
- stage_name=transform_stage_name
805
- )
806
-
807
- # Call fit sproc
808
754
  statement_params = telemetry.get_function_usage_statement_params(
809
755
  project=_PROJECT,
810
756
  subproject=self._subproject,
811
757
  function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
812
758
  api_calls=[Session.call],
813
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
759
+ custom_tags={"autogen": True} if self._autogenerated else None,
760
+ )
761
+
762
+ temp_stage_name = estimator_utils.create_temp_stage(self.session)
763
+ estimator_utils.upload_model_to_stage(
764
+ stage_name=temp_stage_name,
765
+ estimator=self.estimator,
766
+ session=self.session,
767
+ statement_params=statement_params,
814
768
  )
815
769
 
770
+ # Call fit sproc
816
771
  if _ENABLE_ANONYMOUS_SPROC:
817
772
  fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc_anonymous(
818
773
  statement_params=statement_params
@@ -820,13 +775,12 @@ class SnowparkModelTrainer:
820
775
  else:
821
776
  fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc(statement_params=statement_params)
822
777
 
823
- fit_transform_result_name = random_name_for_temp_object(TempObjectType.TABLE)
778
+ fit_transform_result_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
824
779
 
825
780
  sproc_export_file_name: str = fit_transform_wrapper_sproc(
826
781
  self.session,
827
782
  queries,
828
- stage_transform_file_name,
829
- stage_result_file_name,
783
+ temp_stage_name,
830
784
  self.input_cols,
831
785
  self.label_cols,
832
786
  self.sample_weight_col,
@@ -838,7 +792,7 @@ class SnowparkModelTrainer:
838
792
 
839
793
  output_result_sp = self.session.table(fit_transform_result_name)
840
794
  fitted_estimator = self._fetch_model_from_stage(
841
- dir_path=stage_result_file_name,
795
+ dir_path=temp_stage_name,
842
796
  file_name=sproc_export_file_name,
843
797
  statement_params=statement_params,
844
798
  )