snowflake-ml-python 1.5.2__py3-none-any.whl → 1.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +26 -5
- snowflake/cortex/_sse_client.py +81 -0
- snowflake/cortex/_util.py +105 -8
- snowflake/ml/_internal/lineage/lineage_utils.py +34 -25
- snowflake/ml/dataset/dataset.py +15 -12
- snowflake/ml/dataset/dataset_factory.py +3 -4
- snowflake/ml/feature_store/feature_store.py +2 -2
- snowflake/ml/model/_client/sql/model_version.py +2 -2
- snowflake/ml/model/_model_composer/model_composer.py +2 -2
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -1
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
- snowflake/ml/model/_signatures/builtins_handler.py +2 -1
- snowflake/ml/model/_signatures/core.py +13 -1
- snowflake/ml/model/_signatures/pandas_handler.py +2 -0
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
- snowflake/ml/model/model_signature.py +2 -0
- snowflake/ml/model/type_hints.py +1 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +58 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +156 -121
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +2 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +38 -18
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +82 -134
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +21 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
- snowflake/ml/modeling/cluster/birch.py +1 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
- snowflake/ml/modeling/cluster/dbscan.py +1 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
- snowflake/ml/modeling/cluster/k_means.py +1 -1
- snowflake/ml/modeling/cluster/mean_shift.py +1 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
- snowflake/ml/modeling/cluster/optics.py +1 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
- snowflake/ml/modeling/compose/column_transformer.py +1 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
- snowflake/ml/modeling/covariance/oas.py +1 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/pca.py +1 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
- snowflake/ml/modeling/framework/base.py +3 -8
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
- snowflake/ml/modeling/impute/knn_imputer.py +1 -1
- snowflake/ml/modeling/impute/missing_indicator.py +1 -1
- snowflake/ml/modeling/impute/simple_imputer.py +8 -4
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/lars.py +1 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/perceptron.py +1 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ridge.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
- snowflake/ml/modeling/manifold/isomap.py +1 -1
- snowflake/ml/modeling/manifold/mds.py +1 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
- snowflake/ml/modeling/manifold/tsne.py +1 -1
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +1 -5
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +1 -5
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +5 -0
- snowflake/ml/modeling/preprocessing/binarizer.py +7 -3
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +7 -2
- snowflake/ml/modeling/preprocessing/label_encoder.py +8 -7
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +7 -3
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +7 -4
- snowflake/ml/modeling/preprocessing/normalizer.py +7 -3
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +10 -2
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +8 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -4
- snowflake/ml/modeling/preprocessing/standard_scaler.py +7 -3
- snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
- snowflake/ml/modeling/svm/linear_svc.py +1 -1
- snowflake/ml/modeling/svm/linear_svr.py +1 -1
- snowflake/ml/modeling/svm/nu_svc.py +1 -1
- snowflake/ml/modeling/svm/nu_svr.py +1 -1
- snowflake/ml/modeling/svm/svc.py +1 -1
- snowflake/ml/modeling/svm/svr.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.3.dist-info}/METADATA +21 -5
- {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.3.dist-info}/RECORD +196 -195
- {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.3.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.3.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.3.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py

@@ -2,6 +2,7 @@ import importlib
 import inspect
 import os
 import posixpath
+import sys
 from typing import Any, Dict, List, Optional
 from uuid import uuid4
 
@@ -13,12 +14,10 @@ from snowflake.ml._internal.utils import (
     identifier,
     pkg_version_utils,
     snowpark_dataframe_utils,
+    temp_file_utils,
 )
 from snowflake.ml._internal.utils.query_result_checker import SqlResultValidator
-from snowflake.ml._internal.utils.temp_file_utils import (
-    cleanup_temp_files,
-    get_temp_file_path,
-)
+from snowflake.ml.modeling._internal import estimator_utils
 from snowflake.ml.modeling._internal.estimator_utils import handle_inference_result
 from snowflake.snowpark import DataFrame, Session, functions as F, types as T
 from snowflake.snowpark._internal.utils import (
@@ -26,7 +25,7 @@ from snowflake.snowpark._internal.utils import (
     random_name_for_temp_object,
 )
 
-cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path))
+cp.register_pickle_by_value(inspect.getmodule(temp_file_utils.get_temp_file_path))
 cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
 cp.register_pickle_by_value(inspect.getmodule(handle_inference_result))
 
@@ -97,7 +96,25 @@ class SnowparkTransformHandlers:
 
         dependencies = self._get_validated_snowpark_dependencies(session, dependencies)
         dataset = self.dataset
-
+
+        statement_params = telemetry.get_function_usage_statement_params(
+            project=_PROJECT,
+            subproject=self._subproject,
+            function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
+            api_calls=[F.pandas_udf],
+            custom_tags={"autogen": True} if self._autogenerated else None,
+        )
+
+        temp_stage_name = estimator_utils.create_temp_stage(session)
+
+        estimator_file_name = estimator_utils.upload_model_to_stage(
+            stage_name=temp_stage_name,
+            estimator=self.estimator,
+            session=session,
+            statement_params=statement_params,
+        )
+        imports = [f"@{temp_stage_name}/{estimator_file_name}"]
+
         # Register vectorized UDF for batch inference
         batch_inference_udf_name = random_name_for_temp_object(TempObjectType.FUNCTION)
 
@@ -113,13 +130,13 @@ class SnowparkTransformHandlers:
         for field in fields:
             input_datatypes.append(field.datatype)
 
-
-
-
-
-
-
-
+        # TODO(xjiang): for optimization, use register_from_file to reduce duplicate loading estimator object
+        # or use cachetools here
+        def load_estimator() -> object:
+            estimator_file_path = os.path.join(sys._xoptions["snowflake_import_directory"], f"{estimator_file_name}")
+            with open(estimator_file_path, mode="rb") as local_estimator_file_obj:
+                estimator_object = cp.load(local_estimator_file_obj)
+            return estimator_object
 
         @F.pandas_udf(  # type: ignore[arg-type, misc]
             is_permanent=False,
@@ -129,6 +146,7 @@ class SnowparkTransformHandlers:
             session=session,
             statement_params=statement_params,
            input_types=[T.PandasDataFrameType(input_datatypes)],
+            imports=imports,  # type: ignore[arg-type]
         )
         def vec_batch_infer(input_df: pd.DataFrame) -> T.PandasSeries[dict]:  # type: ignore[type-arg]
             import numpy as np  # noqa: F401
@@ -136,6 +154,8 @@ class SnowparkTransformHandlers:
 
             input_df.columns = snowpark_cols
 
+            estimator = load_estimator()
+
             if hasattr(estimator, "n_jobs"):
                 # Vectorized UDF cannot handle joblib multiprocessing right now, deactivate the n_jobs
                 estimator.n_jobs = 1
@@ -225,7 +245,7 @@ class SnowparkTransformHandlers:
         queries = dataset.queries["queries"]
 
         # Create a temp file and dump the score to that file.
-        local_score_file_name = get_temp_file_path()
+        local_score_file_name = temp_file_utils.get_temp_file_path()
         with open(local_score_file_name, mode="w+b") as local_score_file:
             cp.dump(estimator, local_score_file)
 
@@ -247,7 +267,7 @@ class SnowparkTransformHandlers:
                 inspect.currentframe(), self.__class__.__name__
             ),
             api_calls=[F.sproc],
-            custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
+            custom_tags={"autogen": True} if self._autogenerated else None,
         )
         # Put locally serialized score on stage.
         session.file.put(
@@ -290,7 +310,7 @@ class SnowparkTransformHandlers:
             df: pd.DataFrame = sp_df.to_pandas(statement_params=score_statement_params)
             df.columns = sp_df.columns
 
-            local_score_file_name = get_temp_file_path()
+            local_score_file_name = temp_file_utils.get_temp_file_path()
             session.file.get(stage_score_file_name, local_score_file_name, statement_params=score_statement_params)
 
             local_score_file_name_path = os.path.join(local_score_file_name, os.listdir(local_score_file_name)[0])
@@ -323,7 +343,7 @@ class SnowparkTransformHandlers:
                 inspect.currentframe(), self.__class__.__name__
             ),
             api_calls=[Session.call],
-            custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
+            custom_tags={"autogen": True} if self._autogenerated else None,
         )
 
         kwargs = telemetry.get_sproc_statement_params_kwargs(score_wrapper_sproc, score_statement_params)
@@ -338,7 +358,7 @@ class SnowparkTransformHandlers:
             **kwargs,
         )
 
-        cleanup_temp_files([local_score_file_name])
+        temp_file_utils.cleanup_temp_files([local_score_file_name])
 
         return score
 
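Both the handlers diff above and the trainer diff below replace per-class stage bookkeeping with two shared helpers from snowflake/ml/modeling/_internal/estimator_utils.py (listed above as +58 -1 but not included in this excerpt). As orientation only, here is a minimal sketch of what create_temp_stage and upload_model_to_stage could look like, with signatures inferred from the call sites in these diffs; the local names, the exact SQL, and the return-value convention are assumptions, not the shipped implementation.

```python
# Hedged sketch: signatures inferred from the call sites in the diffs above and below.
import os
from typing import Any, Dict, Optional

import cloudpickle as cp

from snowflake.ml._internal.utils import temp_file_utils
from snowflake.snowpark import Session
from snowflake.snowpark._internal import utils as snowpark_utils


def create_temp_stage(session: Session) -> str:
    """Create a session-scoped temporary stage and return its name."""
    stage_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
    session.sql(f"CREATE OR REPLACE TEMPORARY STAGE {stage_name}").collect()
    return stage_name


def upload_model_to_stage(
    stage_name: str,
    estimator: Any,
    session: Session,
    statement_params: Optional[Dict[str, str]] = None,
) -> str:
    """Pickle the estimator, PUT it onto the stage, and return the uploaded file's base name."""
    local_file = temp_file_utils.get_temp_file_path()
    with open(local_file, mode="w+b") as f:
        cp.dump(estimator, f)
    session.file.put(
        local_file_name=local_file,
        stage_location=stage_name,
        auto_compress=False,
        overwrite=True,
        statement_params=statement_params,
    )
    temp_file_utils.cleanup_temp_files([local_file])
    # Callers build paths as f"@{stage_name}/{file_name}", so the base name is enough.
    return os.path.basename(local_file)
```

The trainer's removed _create_temp_stage and _upload_model_to_stage methods, visible in the diff below, did essentially the same work per class, so centralizing it removes the duplicated stage-path and telemetry handling.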
snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py

@@ -17,30 +17,19 @@ from snowflake.ml._internal.utils import (
     identifier,
     pkg_version_utils,
     snowpark_dataframe_utils,
+    temp_file_utils,
 )
-from snowflake.ml._internal.utils.query_result_checker import SqlResultValidator
-from snowflake.ml._internal.utils.temp_file_utils import (
-    cleanup_temp_files,
-    get_temp_file_path,
-)
+from snowflake.ml.modeling._internal import estimator_utils
 from snowflake.ml.modeling._internal.estimator_utils import handle_inference_result
 from snowflake.ml.modeling._internal.model_specifications import (
     ModelSpecifications,
     ModelSpecificationsBuilder,
 )
-from snowflake.snowpark import (
-    DataFrame,
-    Session,
-    exceptions as snowpark_exceptions,
-    functions as F,
-)
-from snowflake.snowpark._internal.utils import (
-    TempObjectType,
-    random_name_for_temp_object,
-)
+from snowflake.snowpark import DataFrame, Session, exceptions as snowpark_exceptions
+from snowflake.snowpark._internal import utils as snowpark_utils
 from snowflake.snowpark.stored_procedure import StoredProcedure
 
-cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path))
+cp.register_pickle_by_value(inspect.getmodule(temp_file_utils.get_temp_file_path))
 cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
 cp.register_pickle_by_value(inspect.getmodule(handle_inference_result))
 
@@ -90,60 +79,6 @@ class SnowparkModelTrainer:
         self._subproject = subproject
         self._class_name = estimator.__class__.__name__
 
-    def _create_temp_stage(self) -> str:
-        """
-        Creates temporary stage.
-
-        Returns:
-            Temp stage name.
-        """
-        # Create temp stage to upload pickled model file.
-        transform_stage_name = random_name_for_temp_object(TempObjectType.STAGE)
-        stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
-        SqlResultValidator(session=self.session, query=stage_creation_query).has_dimensions(
-            expected_rows=1, expected_cols=1
-        ).validate()
-        return transform_stage_name
-
-    def _upload_model_to_stage(self, stage_name: str) -> Tuple[str, str]:
-        """
-        Util method to pickle and upload the model to a temp Snowflake stage.
-
-        Args:
-            stage_name: Stage name to save model.
-
-        Returns:
-            a tuple containing stage file paths for pickled input model for training and location to store trained
-            models(response from training sproc).
-        """
-        # Create a temp file and dump the transform to that file.
-        local_transform_file_name = get_temp_file_path()
-        with open(local_transform_file_name, mode="w+b") as local_transform_file:
-            cp.dump(self.estimator, local_transform_file)
-
-        # Use posixpath to construct stage paths
-        stage_transform_file_name = posixpath.join(stage_name, os.path.basename(local_transform_file_name))
-        stage_result_file_name = posixpath.join(stage_name, os.path.basename(local_transform_file_name))
-
-        statement_params = telemetry.get_function_usage_statement_params(
-            project=_PROJECT,
-            subproject=self._subproject,
-            function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
-            api_calls=[F.sproc],
-            custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
-        )
-        # Put locally serialized transform on stage.
-        self.session.file.put(
-            local_transform_file_name,
-            stage_transform_file_name,
-            auto_compress=False,
-            overwrite=True,
-            statement_params=statement_params,
-        )
-
-        cleanup_temp_files([local_transform_file_name])
-        return (stage_transform_file_name, stage_result_file_name)
-
     def _fetch_model_from_stage(self, dir_path: str, file_name: str, statement_params: Dict[str, str]) -> object:
         """
         Downloads the serialized model from a stage location and unpickles it.
@@ -156,7 +91,7 @@ class SnowparkModelTrainer:
         Returns:
             Deserialized model object.
         """
-        local_result_file_name = get_temp_file_path()
+        local_result_file_name = temp_file_utils.get_temp_file_path()
         self.session.file.get(
             posixpath.join(dir_path, file_name),
             local_result_file_name,
@@ -166,13 +101,13 @@ class SnowparkModelTrainer:
         with open(os.path.join(local_result_file_name, file_name), mode="r+b") as result_file_obj:
             fit_estimator = cp.load(result_file_obj)
 
-        cleanup_temp_files([local_result_file_name])
+        temp_file_utils.cleanup_temp_files([local_result_file_name])
         return fit_estimator
 
     def _build_fit_wrapper_sproc(
         self,
         model_spec: ModelSpecifications,
-    ) -> Callable[[Any, List[str], str, str, List[str], List[str], Optional[str], Dict[str, str]], str]:
+    ) -> Callable[[Any, List[str], str, List[str], List[str], Optional[str], Dict[str, str]], str]:
         """
         Constructs and returns a python stored procedure function to be used for training model.
 
@@ -188,8 +123,7 @@ class SnowparkModelTrainer:
         def fit_wrapper_function(
             session: Session,
             sql_queries: List[str],
-            stage_transform_file_name: str,
-            stage_result_file_name: str,
+            temp_stage_name: str,
             input_cols: List[str],
             label_cols: List[str],
             sample_weight_col: Optional[str],
@@ -212,9 +146,13 @@ class SnowparkModelTrainer:
             df: pd.DataFrame = sp_df.to_pandas(statement_params=statement_params)
             df.columns = sp_df.columns
 
-            local_transform_file_name = get_temp_file_path()
+            local_transform_file_name = temp_file_utils.get_temp_file_path()
 
-            session.file.get(
+            session.file.get(
+                stage_location=temp_stage_name,
+                target_directory=local_transform_file_name,
+                statement_params=statement_params,
+            )
 
             local_transform_file_path = os.path.join(
                 local_transform_file_name, os.listdir(local_transform_file_name)[0]
@@ -233,14 +171,14 @@ class SnowparkModelTrainer:
 
             estimator.fit(**args)
 
-            local_result_file_name = get_temp_file_path()
+            local_result_file_name = temp_file_utils.get_temp_file_path()
 
             with open(local_result_file_name, mode="w+b") as local_result_file_obj:
                 cp.dump(estimator, local_result_file_obj)
 
             session.file.put(
-                local_result_file_name,
-                stage_result_file_name,
+                local_file_name=local_result_file_name,
+                stage_location=temp_stage_name,
                 auto_compress=False,
                 overwrite=True,
                 statement_params=statement_params,
@@ -254,7 +192,7 @@ class SnowparkModelTrainer:
 
     def _get_fit_wrapper_sproc_anonymous(self, statement_params: Dict[str, str]) -> StoredProcedure:
         model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
-        fit_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
+        fit_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
 
         relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
             pkg_versions=model_spec.pkgDependencies, session=self.session
@@ -284,7 +222,7 @@ class SnowparkModelTrainer:
             fit_sproc: StoredProcedure = self.session._FIT_WRAPPER_SPROCS[fit_sproc_key]  # type: ignore[attr-defined]
             return fit_sproc
 
-        fit_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
+        fit_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
 
         relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
             pkg_versions=model_spec.pkgDependencies, session=self.session
@@ -307,7 +245,7 @@ class SnowparkModelTrainer:
     def _build_fit_predict_wrapper_sproc(
         self,
         model_spec: ModelSpecifications,
-    ) -> Callable[[Session, List[str], str, str, List[str], Dict[str, str], bool, List[str], str], str]:
+    ) -> Callable[[Session, List[str], str, List[str], Dict[str, str], bool, List[str], str], str]:
         """
         Constructs and returns a python stored procedure function to be used for training model.
 
@@ -323,8 +261,7 @@ class SnowparkModelTrainer:
         def fit_predict_wrapper_function(
             session: Session,
             sql_queries: List[str],
-            stage_transform_file_name: str,
-            stage_result_file_name: str,
+            temp_stage_name: str,
             input_cols: List[str],
             statement_params: Dict[str, str],
             drop_input_cols: bool,
@@ -347,9 +284,13 @@ class SnowparkModelTrainer:
             df: pd.DataFrame = sp_df.to_pandas(statement_params=statement_params)
             df.columns = sp_df.columns
 
-            local_transform_file_name = get_temp_file_path()
+            local_transform_file_name = temp_file_utils.get_temp_file_path()
 
-            session.file.get(
+            session.file.get(
+                stage_location=temp_stage_name,
+                target_directory=local_transform_file_name,
+                statement_params=statement_params,
+            )
 
             local_transform_file_path = os.path.join(
                 local_transform_file_name, os.listdir(local_transform_file_name)[0]
@@ -359,14 +300,14 @@ class SnowparkModelTrainer:
 
             fit_predict_result = estimator.fit_predict(X=df[input_cols])
 
-            local_result_file_name = get_temp_file_path()
+            local_result_file_name = temp_file_utils.get_temp_file_path()
 
             with open(local_result_file_name, mode="w+b") as local_result_file_obj:
                 cp.dump(estimator, local_result_file_obj)
 
             session.file.put(
-                local_result_file_name,
-                stage_result_file_name,
+                local_file_name=local_result_file_name,
+                stage_location=temp_stage_name,
                 auto_compress=False,
                 overwrite=True,
                 statement_params=statement_params,
@@ -407,7 +348,6 @@ class SnowparkModelTrainer:
             Session,
             List[str],
             str,
-            str,
             List[str],
             Optional[List[str]],
             Optional[str],
@@ -433,8 +373,7 @@ class SnowparkModelTrainer:
         def fit_transform_wrapper_function(
             session: Session,
             sql_queries: List[str],
-            stage_transform_file_name: str,
-            stage_result_file_name: str,
+            temp_stage_name: str,
             input_cols: List[str],
             label_cols: Optional[List[str]],
             sample_weight_col: Optional[str],
@@ -459,9 +398,13 @@ class SnowparkModelTrainer:
             df: pd.DataFrame = sp_df.to_pandas(statement_params=statement_params)
             df.columns = sp_df.columns
 
-            local_transform_file_name = get_temp_file_path()
+            local_transform_file_name = temp_file_utils.get_temp_file_path()
 
-            session.file.get(
+            session.file.get(
+                stage_location=temp_stage_name,
+                target_directory=local_transform_file_name,
+                statement_params=statement_params,
+            )
 
             local_transform_file_path = os.path.join(
                 local_transform_file_name, os.listdir(local_transform_file_name)[0]
@@ -480,14 +423,14 @@ class SnowparkModelTrainer:
 
             fit_transform_result = estimator.fit_transform(**args)
 
-            local_result_file_name = get_temp_file_path()
+            local_result_file_name = temp_file_utils.get_temp_file_path()
 
             with open(local_result_file_name, mode="w+b") as local_result_file_obj:
                 cp.dump(estimator, local_result_file_obj)
 
             session.file.put(
-                local_result_file_name,
-                stage_result_file_name,
+                local_file_name=local_result_file_name,
+                stage_location=temp_stage_name,
                 auto_compress=False,
                 overwrite=True,
                 statement_params=statement_params,
@@ -535,7 +478,7 @@ class SnowparkModelTrainer:
     def _get_fit_predict_wrapper_sproc_anonymous(self, statement_params: Dict[str, str]) -> StoredProcedure:
         model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
 
-        fit_predict_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
+        fit_predict_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
 
         relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
             pkg_versions=model_spec.pkgDependencies, session=self.session
@@ -567,7 +510,7 @@ class SnowparkModelTrainer:
             ]
             return fit_sproc
 
-        fit_predict_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
+        fit_predict_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
 
         relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
             pkg_versions=model_spec.pkgDependencies, session=self.session
@@ -592,7 +535,7 @@ class SnowparkModelTrainer:
    def _get_fit_transform_wrapper_sproc_anonymous(self, statement_params: Dict[str, str]) -> StoredProcedure:
         model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
 
-        fit_transform_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
+        fit_transform_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
 
         relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
             pkg_versions=model_spec.pkgDependencies, session=self.session
@@ -623,7 +566,7 @@ class SnowparkModelTrainer:
             ]
             return fit_sproc
 
-        fit_transform_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
+        fit_transform_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
 
         relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
             pkg_versions=model_spec.pkgDependencies, session=self.session
@@ -663,19 +606,21 @@ class SnowparkModelTrainer:
         # Extract query that generated the dataframe. We will need to pass it to the fit procedure.
         queries = dataset.queries["queries"]
 
-        transform_stage_name = self._create_temp_stage()
-        (stage_transform_file_name, stage_result_file_name) = self._upload_model_to_stage(
-            stage_name=transform_stage_name
-        )
-
-        # Call fit sproc
+        temp_stage_name = estimator_utils.create_temp_stage(self.session)
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=self._subproject,
             function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
             api_calls=[Session.call],
-            custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
+            custom_tags={"autogen": True} if self._autogenerated else None,
         )
+        estimator_utils.upload_model_to_stage(
+            stage_name=temp_stage_name,
+            estimator=self.estimator,
+            session=self.session,
+            statement_params=statement_params,
+        )
+        # Call fit sproc
 
         if _ENABLE_ANONYMOUS_SPROC:
             fit_wrapper_sproc = self._get_fit_wrapper_sproc_anonymous(statement_params=statement_params)
@@ -686,8 +631,7 @@ class SnowparkModelTrainer:
         sproc_export_file_name: str = fit_wrapper_sproc(
             self.session,
             queries,
-            stage_transform_file_name,
-            stage_result_file_name,
+            temp_stage_name,
             self.input_cols,
             self.label_cols,
             self.sample_weight_col,
@@ -706,7 +650,7 @@ class SnowparkModelTrainer:
         sproc_export_file_name = fields[0]
 
         return self._fetch_model_from_stage(
-            dir_path=stage_result_file_name,
+            dir_path=temp_stage_name,
             file_name=sproc_export_file_name,
             statement_params=statement_params,
         )
@@ -734,32 +678,34 @@ class SnowparkModelTrainer:
         # Extract query that generated the dataframe. We will need to pass it to the fit procedure.
         queries = dataset.queries["queries"]
 
-        transform_stage_name = self._create_temp_stage()
-        (stage_transform_file_name, stage_result_file_name) = self._upload_model_to_stage(
-            stage_name=transform_stage_name
-        )
-
-        # Call fit sproc
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=self._subproject,
             function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
             api_calls=[Session.call],
-            custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
+            custom_tags={"autogen": True} if self._autogenerated else None,
         )
 
+        temp_stage_name = estimator_utils.create_temp_stage(self.session)
+        estimator_utils.upload_model_to_stage(
+            stage_name=temp_stage_name,
+            estimator=self.estimator,
+            session=self.session,
+            statement_params=statement_params,
+        )
+
+        # Call fit sproc
         if _ENABLE_ANONYMOUS_SPROC:
             fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc_anonymous(statement_params=statement_params)
         else:
             fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc(statement_params=statement_params)
 
-        fit_predict_result_name = random_name_for_temp_object(TempObjectType.TABLE)
+        fit_predict_result_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
 
         sproc_export_file_name: str = fit_predict_wrapper_sproc(
             self.session,
             queries,
-            stage_transform_file_name,
-            stage_result_file_name,
+            temp_stage_name,
             self.input_cols,
             statement_params,
             drop_input_cols,
@@ -769,7 +715,7 @@ class SnowparkModelTrainer:
 
         output_result_sp = self.session.table(fit_predict_result_name)
         fitted_estimator = self._fetch_model_from_stage(
-            dir_path=stage_result_file_name,
+            dir_path=temp_stage_name,
             file_name=sproc_export_file_name,
             statement_params=statement_params,
         )
@@ -799,20 +745,23 @@ class SnowparkModelTrainer:
         # Extract query that generated the dataframe. We will need to pass it to the fit procedure.
         queries = dataset.queries["queries"]
 
-        transform_stage_name = self._create_temp_stage()
-        (stage_transform_file_name, stage_result_file_name) = self._upload_model_to_stage(
-            stage_name=transform_stage_name
-        )
-
-        # Call fit sproc
         statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=self._subproject,
             function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
             api_calls=[Session.call],
-            custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
+            custom_tags={"autogen": True} if self._autogenerated else None,
+        )
+
+        temp_stage_name = estimator_utils.create_temp_stage(self.session)
+        estimator_utils.upload_model_to_stage(
+            stage_name=temp_stage_name,
+            estimator=self.estimator,
+            session=self.session,
+            statement_params=statement_params,
         )
 
+        # Call fit sproc
         if _ENABLE_ANONYMOUS_SPROC:
             fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc_anonymous(
                 statement_params=statement_params
@@ -820,13 +769,12 @@ class SnowparkModelTrainer:
         else:
             fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc(statement_params=statement_params)
 
-        fit_transform_result_name = random_name_for_temp_object(TempObjectType.TABLE)
+        fit_transform_result_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
 
         sproc_export_file_name: str = fit_transform_wrapper_sproc(
             self.session,
             queries,
-            stage_transform_file_name,
-            stage_result_file_name,
+            temp_stage_name,
             self.input_cols,
             self.label_cols,
             self.sample_weight_col,
@@ -838,7 +786,7 @@ class SnowparkModelTrainer:
 
         output_result_sp = self.session.table(fit_transform_result_name)
         fitted_estimator = self._fetch_model_from_stage(
-            dir_path=stage_result_file_name,
+            dir_path=temp_stage_name,
            file_name=sproc_export_file_name,
             statement_params=statement_params,
         )
|