snowflake-ml-python 1.5.1__py3-none-any.whl → 1.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (207)
  1. snowflake/cortex/_complete.py +26 -5
  2. snowflake/cortex/_sentiment.py +7 -4
  3. snowflake/cortex/_sse_client.py +81 -0
  4. snowflake/cortex/_util.py +105 -8
  5. snowflake/ml/_internal/lineage/lineage_utils.py +34 -25
  6. snowflake/ml/_internal/utils/temp_file_utils.py +5 -2
  7. snowflake/ml/dataset/dataset.py +15 -12
  8. snowflake/ml/dataset/dataset_factory.py +3 -4
  9. snowflake/ml/feature_store/access_manager.py +34 -30
  10. snowflake/ml/feature_store/feature_store.py +3 -3
  11. snowflake/ml/feature_store/feature_view.py +12 -11
  12. snowflake/ml/fileset/snowfs.py +2 -31
  13. snowflake/ml/model/_client/ops/model_ops.py +43 -0
  14. snowflake/ml/model/_client/sql/model_version.py +55 -3
  15. snowflake/ml/model/_model_composer/model_composer.py +7 -3
  16. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -1
  17. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  18. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -3
  19. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  20. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -27
  21. snowflake/ml/model/_signatures/builtins_handler.py +2 -1
  22. snowflake/ml/model/_signatures/core.py +13 -1
  23. snowflake/ml/model/_signatures/pandas_handler.py +2 -0
  24. snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
  25. snowflake/ml/model/model_signature.py +2 -0
  26. snowflake/ml/model/type_hints.py +1 -0
  27. snowflake/ml/modeling/_internal/estimator_utils.py +58 -1
  28. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +196 -242
  29. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +161 -0
  30. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +38 -18
  31. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +82 -134
  32. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +21 -17
  33. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -2
  34. snowflake/ml/modeling/cluster/affinity_propagation.py +9 -2
  35. snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -2
  36. snowflake/ml/modeling/cluster/birch.py +9 -2
  37. snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -2
  38. snowflake/ml/modeling/cluster/dbscan.py +9 -2
  39. snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -2
  40. snowflake/ml/modeling/cluster/k_means.py +9 -2
  41. snowflake/ml/modeling/cluster/mean_shift.py +9 -2
  42. snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -2
  43. snowflake/ml/modeling/cluster/optics.py +9 -2
  44. snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -2
  45. snowflake/ml/modeling/cluster/spectral_clustering.py +9 -2
  46. snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -2
  47. snowflake/ml/modeling/compose/column_transformer.py +9 -2
  48. snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -2
  49. snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -2
  50. snowflake/ml/modeling/covariance/empirical_covariance.py +9 -2
  51. snowflake/ml/modeling/covariance/graphical_lasso.py +9 -2
  52. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -2
  53. snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -2
  54. snowflake/ml/modeling/covariance/min_cov_det.py +9 -2
  55. snowflake/ml/modeling/covariance/oas.py +9 -2
  56. snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -2
  57. snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -2
  58. snowflake/ml/modeling/decomposition/factor_analysis.py +9 -2
  59. snowflake/ml/modeling/decomposition/fast_ica.py +9 -2
  60. snowflake/ml/modeling/decomposition/incremental_pca.py +9 -2
  61. snowflake/ml/modeling/decomposition/kernel_pca.py +9 -2
  62. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -2
  63. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -2
  64. snowflake/ml/modeling/decomposition/pca.py +9 -2
  65. snowflake/ml/modeling/decomposition/sparse_pca.py +9 -2
  66. snowflake/ml/modeling/decomposition/truncated_svd.py +9 -2
  67. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -2
  68. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -2
  69. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -2
  70. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -2
  71. snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -2
  72. snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -2
  73. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -2
  74. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -2
  75. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -2
  76. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -2
  77. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -2
  78. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -2
  79. snowflake/ml/modeling/ensemble/isolation_forest.py +9 -2
  80. snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -2
  81. snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -2
  82. snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -2
  83. snowflake/ml/modeling/ensemble/voting_classifier.py +9 -2
  84. snowflake/ml/modeling/ensemble/voting_regressor.py +9 -2
  85. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -2
  86. snowflake/ml/modeling/feature_selection/select_fdr.py +9 -2
  87. snowflake/ml/modeling/feature_selection/select_fpr.py +9 -2
  88. snowflake/ml/modeling/feature_selection/select_fwe.py +9 -2
  89. snowflake/ml/modeling/feature_selection/select_k_best.py +9 -2
  90. snowflake/ml/modeling/feature_selection/select_percentile.py +9 -2
  91. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -2
  92. snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -2
  93. snowflake/ml/modeling/framework/base.py +3 -8
  94. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -2
  95. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -2
  96. snowflake/ml/modeling/impute/iterative_imputer.py +9 -2
  97. snowflake/ml/modeling/impute/knn_imputer.py +9 -2
  98. snowflake/ml/modeling/impute/missing_indicator.py +9 -2
  99. snowflake/ml/modeling/impute/simple_imputer.py +28 -5
  100. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -2
  101. snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -2
  102. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -2
  103. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -2
  104. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -2
  105. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -2
  106. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -2
  107. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -2
  108. snowflake/ml/modeling/linear_model/ard_regression.py +9 -2
  109. snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -2
  110. snowflake/ml/modeling/linear_model/elastic_net.py +9 -2
  111. snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -2
  112. snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -2
  113. snowflake/ml/modeling/linear_model/huber_regressor.py +9 -2
  114. snowflake/ml/modeling/linear_model/lars.py +9 -2
  115. snowflake/ml/modeling/linear_model/lars_cv.py +9 -2
  116. snowflake/ml/modeling/linear_model/lasso.py +9 -2
  117. snowflake/ml/modeling/linear_model/lasso_cv.py +9 -2
  118. snowflake/ml/modeling/linear_model/lasso_lars.py +9 -2
  119. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -2
  120. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -2
  121. snowflake/ml/modeling/linear_model/linear_regression.py +9 -2
  122. snowflake/ml/modeling/linear_model/logistic_regression.py +9 -2
  123. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -2
  124. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -2
  125. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -2
  126. snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -2
  127. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -2
  128. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -2
  129. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -2
  130. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -2
  131. snowflake/ml/modeling/linear_model/perceptron.py +9 -2
  132. snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -2
  133. snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -2
  134. snowflake/ml/modeling/linear_model/ridge.py +9 -2
  135. snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -2
  136. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -2
  137. snowflake/ml/modeling/linear_model/ridge_cv.py +9 -2
  138. snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -2
  139. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -2
  140. snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -2
  141. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -2
  142. snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -2
  143. snowflake/ml/modeling/manifold/isomap.py +9 -2
  144. snowflake/ml/modeling/manifold/mds.py +9 -2
  145. snowflake/ml/modeling/manifold/spectral_embedding.py +9 -2
  146. snowflake/ml/modeling/manifold/tsne.py +9 -2
  147. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -2
  148. snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -2
  149. snowflake/ml/modeling/model_selection/grid_search_cv.py +1 -5
  150. snowflake/ml/modeling/model_selection/randomized_search_cv.py +1 -5
  151. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -2
  152. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -2
  153. snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -2
  154. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -2
  155. snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -2
  156. snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -2
  157. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -2
  158. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -2
  159. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -2
  160. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -2
  161. snowflake/ml/modeling/neighbors/kernel_density.py +9 -2
  162. snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -2
  163. snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -2
  164. snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -2
  165. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -2
  166. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -2
  167. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -2
  168. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -2
  169. snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -2
  170. snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -2
  171. snowflake/ml/modeling/parameters/enable_anonymous_sproc.py +5 -0
  172. snowflake/ml/modeling/pipeline/pipeline.py +5 -0
  173. snowflake/ml/modeling/preprocessing/binarizer.py +7 -3
  174. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +7 -2
  175. snowflake/ml/modeling/preprocessing/label_encoder.py +8 -7
  176. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +7 -3
  177. snowflake/ml/modeling/preprocessing/min_max_scaler.py +7 -4
  178. snowflake/ml/modeling/preprocessing/normalizer.py +7 -3
  179. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +10 -2
  180. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +8 -5
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -2
  182. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -4
  183. snowflake/ml/modeling/preprocessing/standard_scaler.py +7 -3
  184. snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -2
  185. snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -2
  186. snowflake/ml/modeling/svm/linear_svc.py +9 -2
  187. snowflake/ml/modeling/svm/linear_svr.py +9 -2
  188. snowflake/ml/modeling/svm/nu_svc.py +9 -2
  189. snowflake/ml/modeling/svm/nu_svr.py +9 -2
  190. snowflake/ml/modeling/svm/svc.py +9 -2
  191. snowflake/ml/modeling/svm/svr.py +9 -2
  192. snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -2
  193. snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -2
  194. snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -2
  195. snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -2
  196. snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -2
  197. snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -2
  198. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -2
  199. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -2
  200. snowflake/ml/registry/_manager/model_manager.py +59 -1
  201. snowflake/ml/registry/registry.py +10 -1
  202. snowflake/ml/version.py +1 -1
  203. {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/METADATA +32 -4
  204. {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/RECORD +207 -204
  205. {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/LICENSE.txt +0 -0
  206. {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/WHEEL +0 -0
  207. {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py
@@ -0,0 +1,161 @@
1
+ """
2
+ Description:
3
+ This is the helper file for distributed_hpo_trainer.py to create UDTF by `register_from_file`.
4
+ Performance Benefits:
5
+ The performance benefits come from two aspects,
6
+ 1. register_from_file can reduce duplicating loading data by only loading data once in each node
7
+ 2. register_from_file enable user to load data in global variable, whereas writing UDF in python script cannot.
8
+ Developer Tips:
9
+ Because this script is now a string, so there's no type hinting, linting, etc. It is highly recommended
10
+ to develop in a python script, test the type hinting, and then convert it into a string.
11
+ """
12
+
13
+ execute_template = """
14
+ from typing import Tuple, Any, List, Dict, Set, Iterator
15
+ import os
16
+ import sys
17
+ import pandas as pd
18
+ import numpy as np
19
+ import numpy.typing as npt
20
+ import cloudpickle as cp
21
+ import io
22
+
23
+
24
+ def _load_data_into_udf() -> Tuple[
25
+ npt.NDArray[Any],
26
+ npt.NDArray[Any],
27
+ List[List[int]],
28
+ List[Dict[str, Any]],
29
+ object,
30
+ Dict[str, Any],
31
+ Dict[str, Any],
32
+ ]:
33
+ import pyarrow.parquet as pq
34
+
35
+ data_files = [
36
+ filename
37
+ for filename in os.listdir(sys._xoptions["snowflake_import_directory"])
38
+ if filename.startswith("dataset")
39
+ ]
40
+ partial_df = [
41
+ pq.read_table(os.path.join(sys._xoptions["snowflake_import_directory"], file_name)).to_pandas()
42
+ for file_name in data_files
43
+ ]
44
+ df = pd.concat(partial_df, ignore_index=True)
45
+ constant_file_path = None
46
+ for filename in os.listdir(sys._xoptions["snowflake_import_directory"]):
47
+ if filename.startswith("constant"):
48
+ constant_file_path = os.path.join(sys._xoptions["snowflake_import_directory"], f"{filename}")
49
+ if constant_file_path is None:
50
+ raise ValueError("UDTF cannot find the constant location, abort!")
51
+ with open(constant_file_path, mode="rb") as constant_file_obj:
52
+ CONSTANTS = cp.load(constant_file_obj)
53
+ df.columns = CONSTANTS['dataset_snowpark_cols']
54
+
55
+ # load parameter grid
56
+ local_estimator_file_path = os.path.join(
57
+ sys._xoptions["snowflake_import_directory"],
58
+ f"{CONSTANTS['estimator_location']}"
59
+ )
60
+ with open(local_estimator_file_path, mode="rb") as local_estimator_file_obj:
61
+ estimator_objects = cp.load(local_estimator_file_obj)
62
+ params_to_evaluate = estimator_objects["param_grid"]
63
+
64
+ # load indices
65
+ local_indices_file_path = os.path.join(
66
+ sys._xoptions["snowflake_import_directory"],
67
+ f"{CONSTANTS['indices_location']}"
68
+ )
69
+ with open(local_indices_file_path, mode="rb") as local_indices_file_obj:
70
+ indices = cp.load(local_indices_file_obj)
71
+
72
+ # load base estimator
73
+ local_base_estimator_file_path = os.path.join(
74
+ sys._xoptions["snowflake_import_directory"], f"{CONSTANTS['base_estimator_location']}"
75
+ )
76
+ with open(local_base_estimator_file_path, mode="rb") as local_base_estimator_file_obj:
77
+ base_estimator = cp.load(local_base_estimator_file_obj)
78
+
79
+ # load fit_and_score_kwargs
80
+ local_fit_and_score_kwargs_file_path = os.path.join(
81
+ sys._xoptions["snowflake_import_directory"], f"{CONSTANTS['fit_and_score_kwargs_location']}"
82
+ )
83
+ with open(local_fit_and_score_kwargs_file_path, mode="rb") as local_fit_and_score_kwargs_file_obj:
84
+ fit_and_score_kwargs = cp.load(local_fit_and_score_kwargs_file_obj)
85
+
86
+ # convert dataframe to numpy would save memory consumption
87
+ return (
88
+ df[CONSTANTS['input_cols']].to_numpy(),
89
+ df[CONSTANTS['label_cols']].squeeze().to_numpy(),
90
+ indices,
91
+ params_to_evaluate,
92
+ base_estimator,
93
+ fit_and_score_kwargs,
94
+ CONSTANTS
95
+ )
96
+
97
+
98
+ global_load_data = _load_data_into_udf()
99
+
100
+
101
+ # Note Table functions (UDTFs) have a limit of 500 input arguments and 500 output columns.
102
+ class SearchCV:
103
+ def __init__(self) -> None:
104
+ X, y, indices, params_to_evaluate, base_estimator, fit_and_score_kwargs, CONSTANTS = global_load_data
105
+ self.X = X
106
+ self.y = y
107
+ self.test_indices = indices
108
+ self.params_to_evaluate = params_to_evaluate
109
+ self.base_estimator = base_estimator
110
+ self.fit_and_score_kwargs = fit_and_score_kwargs
111
+ self.fit_score_params: List[Any] = []
112
+ self.CONSTANTS = CONSTANTS
113
+ self.cv_indices_set: Set[int] = set()
114
+
115
+ def process(self, idx: int, params_idx: int, cv_idx: int) -> None:
116
+ self.fit_score_params.extend([[idx, params_idx, cv_idx]])
117
+ self.cv_indices_set.add(cv_idx)
118
+
119
+ def end_partition(self) -> Iterator[Tuple[int, str]]:
120
+ from sklearn.base import clone
121
+ from sklearn.model_selection._validation import _fit_and_score
122
+ from sklearn.utils.parallel import Parallel, delayed
123
+
124
+ cached_train_test_indices = {}
125
+ # Calculate the full index here to avoid duplicate calculation (which consumes a lot of memory)
126
+ full_index = np.arange(self.CONSTANTS['DATA_LENGTH'])
127
+ for i in self.cv_indices_set:
128
+ cached_train_test_indices[i] = [
129
+ np.setdiff1d(full_index, self.test_indices[i]),
130
+ self.test_indices[i],
131
+ ]
132
+
133
+ parallel = Parallel(n_jobs=self.CONSTANTS['_N_JOBS'], pre_dispatch=self.CONSTANTS['_PRE_DISPATCH'])
134
+
135
+ out = parallel(
136
+ delayed(_fit_and_score)(
137
+ clone(self.base_estimator),
138
+ self.X,
139
+ self.y,
140
+ train=cached_train_test_indices[split_idx][0],
141
+ test=cached_train_test_indices[split_idx][1],
142
+ parameters=self.params_to_evaluate[cand_idx],
143
+ split_progress=(split_idx, self.CONSTANTS['n_splits']),
144
+ candidate_progress=(cand_idx, self.CONSTANTS['n_candidates']),
145
+ **self.fit_and_score_kwargs, # load sample weight here
146
+ )
147
+ for _, cand_idx, split_idx in self.fit_score_params
148
+ )
149
+
150
+ binary_cv_results = None
151
+ with io.BytesIO() as f:
152
+ cp.dump(out, f)
153
+ f.seek(0)
154
+ binary_cv_results = f.getvalue().hex()
155
+ yield (
156
+ self.fit_score_params[0][0],
157
+ binary_cv_results,
158
+ )
159
+
160
+ SearchCV._sf_node_singleton = True
161
+ """
snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py
@@ -2,6 +2,7 @@ import importlib
2
2
  import inspect
3
3
  import os
4
4
  import posixpath
5
+ import sys
5
6
  from typing import Any, Dict, List, Optional
6
7
  from uuid import uuid4
7
8
 
@@ -13,12 +14,10 @@ from snowflake.ml._internal.utils import (
13
14
  identifier,
14
15
  pkg_version_utils,
15
16
  snowpark_dataframe_utils,
17
+ temp_file_utils,
16
18
  )
17
19
  from snowflake.ml._internal.utils.query_result_checker import SqlResultValidator
18
- from snowflake.ml._internal.utils.temp_file_utils import (
19
- cleanup_temp_files,
20
- get_temp_file_path,
21
- )
20
+ from snowflake.ml.modeling._internal import estimator_utils
22
21
  from snowflake.ml.modeling._internal.estimator_utils import handle_inference_result
23
22
  from snowflake.snowpark import DataFrame, Session, functions as F, types as T
24
23
  from snowflake.snowpark._internal.utils import (
@@ -26,7 +25,7 @@ from snowflake.snowpark._internal.utils import (
26
25
  random_name_for_temp_object,
27
26
  )
28
27
 
29
- cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path))
28
+ cp.register_pickle_by_value(inspect.getmodule(temp_file_utils.get_temp_file_path))
30
29
  cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
31
30
  cp.register_pickle_by_value(inspect.getmodule(handle_inference_result))
32
31
 
@@ -97,7 +96,25 @@ class SnowparkTransformHandlers:
97
96
 
98
97
  dependencies = self._get_validated_snowpark_dependencies(session, dependencies)
99
98
  dataset = self.dataset
100
- estimator = self.estimator
99
+
100
+ statement_params = telemetry.get_function_usage_statement_params(
101
+ project=_PROJECT,
102
+ subproject=self._subproject,
103
+ function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
104
+ api_calls=[F.pandas_udf],
105
+ custom_tags={"autogen": True} if self._autogenerated else None,
106
+ )
107
+
108
+ temp_stage_name = estimator_utils.create_temp_stage(session)
109
+
110
+ estimator_file_name = estimator_utils.upload_model_to_stage(
111
+ stage_name=temp_stage_name,
112
+ estimator=self.estimator,
113
+ session=session,
114
+ statement_params=statement_params,
115
+ )
116
+ imports = [f"@{temp_stage_name}/{estimator_file_name}"]
117
+
101
118
  # Register vectorized UDF for batch inference
102
119
  batch_inference_udf_name = random_name_for_temp_object(TempObjectType.FUNCTION)
103
120
 
@@ -113,13 +130,13 @@ class SnowparkTransformHandlers:
113
130
  for field in fields:
114
131
  input_datatypes.append(field.datatype)
115
132
 
116
- statement_params = telemetry.get_function_usage_statement_params(
117
- project=_PROJECT,
118
- subproject=self._subproject,
119
- function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
120
- api_calls=[F.pandas_udf],
121
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
122
- )
133
+ # TODO(xjiang): for optimization, use register_from_file to reduce duplicate loading estimator object
134
+ # or use cachetools here
135
+ def load_estimator() -> object:
136
+ estimator_file_path = os.path.join(sys._xoptions["snowflake_import_directory"], f"{estimator_file_name}")
137
+ with open(estimator_file_path, mode="rb") as local_estimator_file_obj:
138
+ estimator_object = cp.load(local_estimator_file_obj)
139
+ return estimator_object
123
140
 
124
141
  @F.pandas_udf( # type: ignore[arg-type, misc]
125
142
  is_permanent=False,
@@ -129,6 +146,7 @@ class SnowparkTransformHandlers:
129
146
  session=session,
130
147
  statement_params=statement_params,
131
148
  input_types=[T.PandasDataFrameType(input_datatypes)],
149
+ imports=imports, # type: ignore[arg-type]
132
150
  )
133
151
  def vec_batch_infer(input_df: pd.DataFrame) -> T.PandasSeries[dict]: # type: ignore[type-arg]
134
152
  import numpy as np # noqa: F401
@@ -136,6 +154,8 @@ class SnowparkTransformHandlers:
136
154
 
137
155
  input_df.columns = snowpark_cols
138
156
 
157
+ estimator = load_estimator()
158
+
139
159
  if hasattr(estimator, "n_jobs"):
140
160
  # Vectorized UDF cannot handle joblib multiprocessing right now, deactivate the n_jobs
141
161
  estimator.n_jobs = 1
@@ -225,7 +245,7 @@ class SnowparkTransformHandlers:
225
245
  queries = dataset.queries["queries"]
226
246
 
227
247
  # Create a temp file and dump the score to that file.
228
- local_score_file_name = get_temp_file_path()
248
+ local_score_file_name = temp_file_utils.get_temp_file_path()
229
249
  with open(local_score_file_name, mode="w+b") as local_score_file:
230
250
  cp.dump(estimator, local_score_file)
231
251
 
@@ -247,7 +267,7 @@ class SnowparkTransformHandlers:
247
267
  inspect.currentframe(), self.__class__.__name__
248
268
  ),
249
269
  api_calls=[F.sproc],
250
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
270
+ custom_tags={"autogen": True} if self._autogenerated else None,
251
271
  )
252
272
  # Put locally serialized score on stage.
253
273
  session.file.put(
@@ -290,7 +310,7 @@ class SnowparkTransformHandlers:
290
310
  df: pd.DataFrame = sp_df.to_pandas(statement_params=score_statement_params)
291
311
  df.columns = sp_df.columns
292
312
 
293
- local_score_file_name = get_temp_file_path()
313
+ local_score_file_name = temp_file_utils.get_temp_file_path()
294
314
  session.file.get(stage_score_file_name, local_score_file_name, statement_params=score_statement_params)
295
315
 
296
316
  local_score_file_name_path = os.path.join(local_score_file_name, os.listdir(local_score_file_name)[0])
@@ -323,7 +343,7 @@ class SnowparkTransformHandlers:
323
343
  inspect.currentframe(), self.__class__.__name__
324
344
  ),
325
345
  api_calls=[Session.call],
326
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
346
+ custom_tags={"autogen": True} if self._autogenerated else None,
327
347
  )
328
348
 
329
349
  kwargs = telemetry.get_sproc_statement_params_kwargs(score_wrapper_sproc, score_statement_params)
@@ -338,7 +358,7 @@ class SnowparkTransformHandlers:
338
358
  **kwargs,
339
359
  )
340
360
 
341
- cleanup_temp_files([local_score_file_name])
361
+ temp_file_utils.cleanup_temp_files([local_score_file_name])
342
362
 
343
363
  return score
344
364