snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -23,20 +23,26 @@ from snowflake.ml._internal.utils.temp_file_utils import (
|
|
23
23
|
cleanup_temp_files,
|
24
24
|
get_temp_file_path,
|
25
25
|
)
|
26
|
+
from snowflake.ml.modeling._internal.estimator_utils import handle_inference_result
|
26
27
|
from snowflake.ml.modeling._internal.model_specifications import (
|
27
28
|
ModelSpecifications,
|
28
29
|
ModelSpecificationsBuilder,
|
29
30
|
)
|
30
|
-
from snowflake.snowpark import
|
31
|
+
from snowflake.snowpark import (
|
32
|
+
DataFrame,
|
33
|
+
Session,
|
34
|
+
exceptions as snowpark_exceptions,
|
35
|
+
functions as F,
|
36
|
+
)
|
31
37
|
from snowflake.snowpark._internal.utils import (
|
32
38
|
TempObjectType,
|
33
39
|
random_name_for_temp_object,
|
34
40
|
)
|
35
|
-
from snowflake.snowpark.functions import sproc
|
36
41
|
from snowflake.snowpark.stored_procedure import StoredProcedure
|
37
42
|
|
38
43
|
cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path))
|
39
44
|
cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
|
45
|
+
cp.register_pickle_by_value(inspect.getmodule(handle_inference_result))
|
40
46
|
|
41
47
|
_PROJECT = "ModelDevelopment"
|
42
48
|
|
@@ -122,7 +128,7 @@ class SnowparkModelTrainer:
|
|
122
128
|
project=_PROJECT,
|
123
129
|
subproject=self._subproject,
|
124
130
|
function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
|
125
|
-
api_calls=[sproc],
|
131
|
+
api_calls=[F.sproc],
|
126
132
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
127
133
|
)
|
128
134
|
# Put locally serialized transform on stage.
|
@@ -292,7 +298,7 @@ class SnowparkModelTrainer:
|
|
292
298
|
"""
|
293
299
|
imports = model_spec.imports # In order for the sproc to not resolve this reference in snowflake.ml
|
294
300
|
|
295
|
-
def
|
301
|
+
def fit_predict_wrapper_function(
|
296
302
|
session: Session,
|
297
303
|
sql_queries: List[str],
|
298
304
|
stage_transform_file_name: str,
|
@@ -329,7 +335,7 @@ class SnowparkModelTrainer:
|
|
329
335
|
with open(local_transform_file_path, mode="r+b") as local_transform_file_obj:
|
330
336
|
estimator = cp.load(local_transform_file_obj)
|
331
337
|
|
332
|
-
fit_predict_result = estimator.fit_predict(df[input_cols])
|
338
|
+
fit_predict_result = estimator.fit_predict(X=df[input_cols])
|
333
339
|
|
334
340
|
local_result_file_name = get_temp_file_path()
|
335
341
|
|
@@ -349,8 +355,16 @@ class SnowparkModelTrainer:
|
|
349
355
|
fit_predict_result_pd = pd.DataFrame(data=fit_predict_result, columns=expected_output_cols_list)
|
350
356
|
else:
|
351
357
|
df = df.copy()
|
352
|
-
|
353
|
-
|
358
|
+
# in case the output column name overlap with the input column names,
|
359
|
+
# remove the ones in input column names
|
360
|
+
remove_dataset_col_name_exist_in_output_col = list(set(df.columns) - set(expected_output_cols_list))
|
361
|
+
fit_predict_result_pd = pd.concat(
|
362
|
+
[
|
363
|
+
df[remove_dataset_col_name_exist_in_output_col],
|
364
|
+
pd.DataFrame(data=fit_predict_result, columns=expected_output_cols_list),
|
365
|
+
],
|
366
|
+
axis=1,
|
367
|
+
)
|
354
368
|
|
355
369
|
# write into a temp table in sproc and load the table from outside
|
356
370
|
session.write_pandas(
|
@@ -361,17 +375,150 @@ class SnowparkModelTrainer:
|
|
361
375
|
# to pass debug information to the caller.
|
362
376
|
return str(os.path.basename(local_result_file_name))
|
363
377
|
|
364
|
-
return
|
378
|
+
return fit_predict_wrapper_function
|
379
|
+
|
380
|
+
def _build_fit_transform_wrapper_sproc(
|
381
|
+
self,
|
382
|
+
model_spec: ModelSpecifications,
|
383
|
+
) -> Callable[
|
384
|
+
[
|
385
|
+
Session,
|
386
|
+
List[str],
|
387
|
+
str,
|
388
|
+
str,
|
389
|
+
List[str],
|
390
|
+
Optional[List[str]],
|
391
|
+
Optional[str],
|
392
|
+
Dict[str, str],
|
393
|
+
bool,
|
394
|
+
List[str],
|
395
|
+
str,
|
396
|
+
],
|
397
|
+
str,
|
398
|
+
]:
|
399
|
+
"""
|
400
|
+
Constructs and returns a python stored procedure function to be used for training model.
|
401
|
+
|
402
|
+
Args:
|
403
|
+
model_spec: ModelSpecifications object that contains model specific information
|
404
|
+
like required imports, package dependencies, etc.
|
405
|
+
|
406
|
+
Returns:
|
407
|
+
A callable that can be registered as a stored procedure.
|
408
|
+
"""
|
409
|
+
imports = model_spec.imports # In order for the sproc to not resolve this reference in snowflake.ml
|
410
|
+
|
411
|
+
def fit_transform_wrapper_function(
|
412
|
+
session: Session,
|
413
|
+
sql_queries: List[str],
|
414
|
+
stage_transform_file_name: str,
|
415
|
+
stage_result_file_name: str,
|
416
|
+
input_cols: List[str],
|
417
|
+
label_cols: Optional[List[str]],
|
418
|
+
sample_weight_col: Optional[str],
|
419
|
+
statement_params: Dict[str, str],
|
420
|
+
drop_input_cols: bool,
|
421
|
+
expected_output_cols_list: List[str],
|
422
|
+
fit_transform_result_name: str,
|
423
|
+
) -> str:
|
424
|
+
import os
|
425
|
+
|
426
|
+
import cloudpickle as cp
|
427
|
+
import pandas as pd
|
428
|
+
|
429
|
+
for import_name in imports:
|
430
|
+
importlib.import_module(import_name)
|
431
|
+
|
432
|
+
# Execute snowpark queries and obtain the results as pandas dataframe
|
433
|
+
# NB: this implies that the result data must fit into memory.
|
434
|
+
for query in sql_queries[:-1]:
|
435
|
+
_ = session.sql(query).collect(statement_params=statement_params)
|
436
|
+
sp_df = session.sql(sql_queries[-1])
|
437
|
+
df: pd.DataFrame = sp_df.to_pandas(statement_params=statement_params)
|
438
|
+
df.columns = sp_df.columns
|
439
|
+
|
440
|
+
local_transform_file_name = get_temp_file_path()
|
441
|
+
|
442
|
+
session.file.get(stage_transform_file_name, local_transform_file_name, statement_params=statement_params)
|
443
|
+
|
444
|
+
local_transform_file_path = os.path.join(
|
445
|
+
local_transform_file_name, os.listdir(local_transform_file_name)[0]
|
446
|
+
)
|
447
|
+
with open(local_transform_file_path, mode="r+b") as local_transform_file_obj:
|
448
|
+
estimator = cp.load(local_transform_file_obj)
|
449
|
+
|
450
|
+
argspec = inspect.getfullargspec(estimator.fit)
|
451
|
+
args = {"X": df[input_cols]}
|
452
|
+
if label_cols:
|
453
|
+
label_arg_name = "Y" if "Y" in argspec.args else "y"
|
454
|
+
args[label_arg_name] = df[label_cols].squeeze()
|
455
|
+
|
456
|
+
if sample_weight_col is not None and "sample_weight" in argspec.args:
|
457
|
+
args["sample_weight"] = df[sample_weight_col].squeeze()
|
458
|
+
|
459
|
+
fit_transform_result = estimator.fit_transform(**args)
|
460
|
+
|
461
|
+
local_result_file_name = get_temp_file_path()
|
462
|
+
|
463
|
+
with open(local_result_file_name, mode="w+b") as local_result_file_obj:
|
464
|
+
cp.dump(estimator, local_result_file_obj)
|
465
|
+
|
466
|
+
session.file.put(
|
467
|
+
local_result_file_name,
|
468
|
+
stage_result_file_name,
|
469
|
+
auto_compress=False,
|
470
|
+
overwrite=True,
|
471
|
+
statement_params=statement_params,
|
472
|
+
)
|
473
|
+
|
474
|
+
transformed_numpy_array, output_cols = handle_inference_result(
|
475
|
+
inference_res=fit_transform_result,
|
476
|
+
output_cols=expected_output_cols_list,
|
477
|
+
inference_method="fit_transform",
|
478
|
+
within_udf=True,
|
479
|
+
)
|
480
|
+
|
481
|
+
if len(transformed_numpy_array.shape) > 1:
|
482
|
+
if transformed_numpy_array.shape[1] != len(output_cols):
|
483
|
+
series = pd.Series(transformed_numpy_array.tolist())
|
484
|
+
transformed_pandas_df = pd.DataFrame(series, columns=output_cols)
|
485
|
+
else:
|
486
|
+
transformed_pandas_df = pd.DataFrame(transformed_numpy_array.tolist(), columns=output_cols)
|
487
|
+
else:
|
488
|
+
transformed_pandas_df = pd.DataFrame(transformed_numpy_array, columns=output_cols)
|
489
|
+
|
490
|
+
# store the transform output
|
491
|
+
if not drop_input_cols:
|
492
|
+
df = df.copy()
|
493
|
+
# in case the output column name overlap with the input column names,
|
494
|
+
# remove the ones in input column names
|
495
|
+
remove_dataset_col_name_exist_in_output_col = list(set(df.columns) - set(output_cols))
|
496
|
+
transformed_pandas_df = pd.concat(
|
497
|
+
[df[remove_dataset_col_name_exist_in_output_col], transformed_pandas_df], axis=1
|
498
|
+
)
|
499
|
+
|
500
|
+
# write into a temp table in sproc and load the table from outside
|
501
|
+
session.write_pandas(
|
502
|
+
transformed_pandas_df,
|
503
|
+
fit_transform_result_name,
|
504
|
+
auto_create_table=True,
|
505
|
+
table_type="temp",
|
506
|
+
quote_identifiers=False,
|
507
|
+
)
|
508
|
+
|
509
|
+
return str(os.path.basename(local_result_file_name))
|
510
|
+
|
511
|
+
return fit_transform_wrapper_function
|
365
512
|
|
366
513
|
def _get_fit_predict_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure:
|
367
514
|
# If the sproc already exists, don't register.
|
368
|
-
if not hasattr(self.session, "
|
369
|
-
self.session.
|
515
|
+
if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"):
|
516
|
+
self.session._FIT_WRAPPER_SPROCS: Dict[str, StoredProcedure] = {} # type: ignore[attr-defined, misc]
|
370
517
|
|
371
518
|
model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
|
372
|
-
fit_predict_sproc_key = model_spec.__class__.__name__
|
373
|
-
if fit_predict_sproc_key in self.session.
|
374
|
-
fit_sproc: StoredProcedure = self.session.
|
519
|
+
fit_predict_sproc_key = model_spec.__class__.__name__ + "_fit_predict"
|
520
|
+
if fit_predict_sproc_key in self.session._FIT_WRAPPER_SPROCS: # type: ignore[attr-defined]
|
521
|
+
fit_sproc: StoredProcedure = self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
|
375
522
|
fit_predict_sproc_key
|
376
523
|
]
|
377
524
|
return fit_sproc
|
@@ -392,12 +539,47 @@ class SnowparkModelTrainer:
|
|
392
539
|
statement_params=statement_params,
|
393
540
|
)
|
394
541
|
|
395
|
-
self.session.
|
542
|
+
self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
|
396
543
|
fit_predict_sproc_key
|
397
544
|
] = fit_predict_wrapper_sproc
|
398
545
|
|
399
546
|
return fit_predict_wrapper_sproc
|
400
547
|
|
548
|
+
def _get_fit_transform_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure:
|
549
|
+
# If the sproc already exists, don't register.
|
550
|
+
if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"):
|
551
|
+
self.session._FIT_WRAPPER_SPROCS: Dict[str, StoredProcedure] = {} # type: ignore[attr-defined, misc]
|
552
|
+
|
553
|
+
model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
|
554
|
+
fit_transform_sproc_key = model_spec.__class__.__name__ + "_fit_transform"
|
555
|
+
if fit_transform_sproc_key in self.session._FIT_WRAPPER_SPROCS: # type: ignore[attr-defined]
|
556
|
+
fit_sproc: StoredProcedure = self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
|
557
|
+
fit_transform_sproc_key
|
558
|
+
]
|
559
|
+
return fit_sproc
|
560
|
+
|
561
|
+
fit_transform_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
|
562
|
+
|
563
|
+
relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
564
|
+
pkg_versions=model_spec.pkgDependencies, session=self.session
|
565
|
+
)
|
566
|
+
|
567
|
+
fit_transform_wrapper_sproc = self.session.sproc.register(
|
568
|
+
func=self._build_fit_transform_wrapper_sproc(model_spec=model_spec),
|
569
|
+
is_permanent=False,
|
570
|
+
name=fit_transform_sproc_name,
|
571
|
+
packages=["snowflake-snowpark-python"] + relaxed_dependencies, # type: ignore[arg-type]
|
572
|
+
replace=True,
|
573
|
+
session=self.session,
|
574
|
+
statement_params=statement_params,
|
575
|
+
)
|
576
|
+
|
577
|
+
self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
|
578
|
+
fit_transform_sproc_key
|
579
|
+
] = fit_transform_wrapper_sproc
|
580
|
+
|
581
|
+
return fit_transform_wrapper_sproc
|
582
|
+
|
401
583
|
def train(self) -> object:
|
402
584
|
"""
|
403
585
|
Trains the model by pushing down the compute into Snowflake using stored procedures.
|
@@ -498,10 +680,10 @@ class SnowparkModelTrainer:
|
|
498
680
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
499
681
|
)
|
500
682
|
|
501
|
-
|
683
|
+
fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc(statement_params=statement_params)
|
502
684
|
fit_predict_result_name = random_name_for_temp_object(TempObjectType.TABLE)
|
503
685
|
|
504
|
-
sproc_export_file_name: str =
|
686
|
+
sproc_export_file_name: str = fit_predict_wrapper_sproc(
|
505
687
|
self.session,
|
506
688
|
queries,
|
507
689
|
stage_transform_file_name,
|
@@ -521,3 +703,66 @@ class SnowparkModelTrainer:
|
|
521
703
|
)
|
522
704
|
|
523
705
|
return output_result_sp, fitted_estimator
|
706
|
+
|
707
|
+
def train_fit_transform(
|
708
|
+
self,
|
709
|
+
expected_output_cols_list: List[str],
|
710
|
+
drop_input_cols: Optional[bool] = False,
|
711
|
+
) -> Tuple[Union[DataFrame, pd.DataFrame], object]:
|
712
|
+
"""Trains the model by pushing down the compute into Snowflake using stored procedures.
|
713
|
+
This API is different from fit itself because it would also provide the transform
|
714
|
+
output.
|
715
|
+
|
716
|
+
Args:
|
717
|
+
expected_output_cols_list (List[str]): The output columns
|
718
|
+
name as a list. Defaults to None.
|
719
|
+
drop_input_cols (Optional[bool]): Boolean to determine whether to
|
720
|
+
drop the input columns from the output dataset.
|
721
|
+
|
722
|
+
Returns:
|
723
|
+
Tuple[Union[DataFrame, pd.DataFrame], object]: [transformed dataset, estimator]
|
724
|
+
"""
|
725
|
+
dataset = snowpark_dataframe_utils.cast_snowpark_dataframe_column_types(self.dataset)
|
726
|
+
|
727
|
+
# Extract query that generated the dataframe. We will need to pass it to the fit procedure.
|
728
|
+
queries = dataset.queries["queries"]
|
729
|
+
|
730
|
+
transform_stage_name = self._create_temp_stage()
|
731
|
+
(stage_transform_file_name, stage_result_file_name) = self._upload_model_to_stage(
|
732
|
+
stage_name=transform_stage_name
|
733
|
+
)
|
734
|
+
|
735
|
+
# Call fit sproc
|
736
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
737
|
+
project=_PROJECT,
|
738
|
+
subproject=self._subproject,
|
739
|
+
function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
|
740
|
+
api_calls=[Session.call],
|
741
|
+
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
742
|
+
)
|
743
|
+
|
744
|
+
fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc(statement_params=statement_params)
|
745
|
+
fit_transform_result_name = random_name_for_temp_object(TempObjectType.TABLE)
|
746
|
+
|
747
|
+
sproc_export_file_name: str = fit_transform_wrapper_sproc(
|
748
|
+
self.session,
|
749
|
+
queries,
|
750
|
+
stage_transform_file_name,
|
751
|
+
stage_result_file_name,
|
752
|
+
self.input_cols,
|
753
|
+
self.label_cols,
|
754
|
+
self.sample_weight_col,
|
755
|
+
statement_params,
|
756
|
+
drop_input_cols,
|
757
|
+
expected_output_cols_list,
|
758
|
+
fit_transform_result_name,
|
759
|
+
)
|
760
|
+
|
761
|
+
output_result_sp = self.session.table(fit_transform_result_name)
|
762
|
+
fitted_estimator = self._fetch_model_from_stage(
|
763
|
+
dir_path=stage_result_file_name,
|
764
|
+
file_name=sproc_export_file_name,
|
765
|
+
statement_params=statement_params,
|
766
|
+
)
|
767
|
+
|
768
|
+
return output_result_sp, fitted_estimator
|