snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -0
- snowflake/cortex/_classify_text.py +2 -2
- snowflake/cortex/_embed_text_1024.py +37 -0
- snowflake/cortex/_embed_text_768.py +37 -0
- snowflake/cortex/_extract_answer.py +2 -2
- snowflake/cortex/_sentiment.py +2 -2
- snowflake/cortex/_summarize.py +2 -2
- snowflake/cortex/_translate.py +2 -2
- snowflake/cortex/_util.py +4 -4
- snowflake/ml/_internal/env_utils.py +5 -5
- snowflake/ml/_internal/exceptions/error_codes.py +2 -0
- snowflake/ml/_internal/telemetry.py +142 -20
- snowflake/ml/_internal/utils/db_utils.py +50 -0
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/service_logger.py +63 -0
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +26 -2
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
- snowflake/ml/data/data_connector.py +33 -7
- snowflake/ml/data/ingestor_utils.py +20 -10
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/feature_store/access_manager.py +3 -3
- snowflake/ml/feature_store/feature_store.py +60 -19
- snowflake/ml/feature_store/feature_view.py +84 -30
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/fileset/stage_fs.py +2 -1
- snowflake/ml/lineage/lineage_node.py +7 -2
- snowflake/ml/model/__init__.py +1 -2
- snowflake/ml/model/_client/model/model_version_impl.py +96 -12
- snowflake/ml/model/_client/ops/model_ops.py +124 -6
- snowflake/ml/model/_client/ops/service_ops.py +309 -9
- snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
- snowflake/ml/model/_client/sql/_base.py +5 -0
- snowflake/ml/model/_client/sql/model.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +9 -5
- snowflake/ml/model/_client/sql/service.py +121 -20
- snowflake/ml/model/_model_composer/model_composer.py +11 -39
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
- snowflake/ml/model/_packager/model_env/model_env.py +4 -38
- snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
- snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
- snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
- snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
- snowflake/ml/model/_packager/model_packager.py +14 -8
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/type_hints.py +12 -145
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
- snowflake/ml/modeling/cluster/birch.py +61 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
- snowflake/ml/modeling/cluster/dbscan.py +61 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
- snowflake/ml/modeling/cluster/k_means.py +61 -21
- snowflake/ml/modeling/cluster/mean_shift.py +61 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
- snowflake/ml/modeling/cluster/optics.py +61 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
- snowflake/ml/modeling/compose/column_transformer.py +61 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
- snowflake/ml/modeling/covariance/oas.py +61 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
- snowflake/ml/modeling/decomposition/pca.py +61 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
- snowflake/ml/modeling/impute/knn_imputer.py +61 -21
- snowflake/ml/modeling/impute/missing_indicator.py +61 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/lars.py +61 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/perceptron.py +61 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ridge.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
- snowflake/ml/modeling/manifold/isomap.py +61 -21
- snowflake/ml/modeling/manifold/mds.py +61 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
- snowflake/ml/modeling/manifold/tsne.py +61 -21
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
- snowflake/ml/modeling/metrics/ranking.py +0 -3
- snowflake/ml/modeling/metrics/regression.py +0 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +1 -13
- snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
- snowflake/ml/modeling/svm/linear_svc.py +61 -21
- snowflake/ml/modeling/svm/linear_svr.py +61 -21
- snowflake/ml/modeling/svm/nu_svc.py +61 -21
- snowflake/ml/modeling/svm/nu_svr.py +61 -21
- snowflake/ml/modeling/svm/svc.py +61 -21
- snowflake/ml/modeling/svm/svr.py +61 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
- snowflake/ml/monitoring/_client/model_monitor.py +126 -0
- snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
- snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
- snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
- snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
- snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
- snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
- snowflake/ml/monitoring/entities/output_score_type.py +90 -0
- snowflake/ml/registry/_manager/model_manager.py +4 -0
- snowflake/ml/registry/registry.py +166 -8
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
- snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
- snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
- snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
- snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
- snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
- snowflake/ml/_internal/utils/session_token_manager.py +0 -46
- snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
- snowflake/ml/_internal/utils/uri.py +0 -77
- snowflake/ml/data/torch_dataset.py +0 -33
- snowflake/ml/model/_api.py +0 -568
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
- snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
- snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
- snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
- snowflake/ml/model/deploy_platforms.py +0 -6
- snowflake/ml/model/models/llm.py +0 -104
- snowflake/ml/monitoring/monitor.py +0 -203
- snowflake/ml/registry/_initial_schema.py +0 -142
- snowflake/ml/registry/_schema.py +0 -82
- snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
- snowflake/ml/registry/_schema_version_manager.py +0 -163
- snowflake/ml/registry/model_registry.py +0 -2048
- snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -20,6 +20,7 @@ class ModelTrainer(Protocol):
|
|
20
20
|
self,
|
21
21
|
expected_output_cols_list: List[str],
|
22
22
|
drop_input_cols: Optional[bool] = False,
|
23
|
+
example_output_pd_df: Optional[pd.DataFrame] = None,
|
23
24
|
) -> Tuple[Union[DataFrame, pd.DataFrame], object]:
|
24
25
|
raise NotImplementedError
|
25
26
|
|
@@ -377,7 +377,6 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
|
|
377
377
|
anonymous=True,
|
378
378
|
imports=imports, # type: ignore[arg-type]
|
379
379
|
statement_params=sproc_statement_params,
|
380
|
-
execute_as="caller",
|
381
380
|
)
|
382
381
|
def _distributed_search(
|
383
382
|
session: Session,
|
@@ -495,7 +494,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
|
|
495
494
|
label_arg_name = "Y" if "Y" in argspec.args else "y"
|
496
495
|
args[label_arg_name] = df[label_cols].squeeze()
|
497
496
|
|
498
|
-
if sample_weight_col is not None
|
497
|
+
if sample_weight_col is not None:
|
499
498
|
args["sample_weight"] = df[sample_weight_col].squeeze()
|
500
499
|
return args, estimator, indices, len(df), params_to_evaluate
|
501
500
|
|
@@ -783,7 +782,6 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
|
|
783
782
|
anonymous=True,
|
784
783
|
imports=imports, # type: ignore[arg-type]
|
785
784
|
statement_params=sproc_statement_params,
|
786
|
-
execute_as="caller",
|
787
785
|
)
|
788
786
|
def _distributed_search(
|
789
787
|
session: Session,
|
@@ -1061,7 +1059,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
|
|
1061
1059
|
if label_cols:
|
1062
1060
|
label_arg_name = "Y" if "Y" in argspec.args else "y"
|
1063
1061
|
args[label_arg_name] = y
|
1064
|
-
if sample_weight_col is not None
|
1062
|
+
if sample_weight_col is not None:
|
1065
1063
|
args["sample_weight"] = df[sample_weight_col].squeeze()
|
1066
1064
|
# estimator.refit = original_refit
|
1067
1065
|
refit_start_time = time.time()
|
@@ -318,19 +318,19 @@ class SnowparkTransformHandlers:
|
|
318
318
|
with open(local_score_file_name_path, mode="r+b") as local_score_file_obj:
|
319
319
|
estimator = cp.load(local_score_file_obj)
|
320
320
|
|
321
|
-
|
322
|
-
if "X" in
|
321
|
+
params = inspect.signature(estimator.score).parameters
|
322
|
+
if "X" in params:
|
323
323
|
args = {"X": df[input_cols]}
|
324
|
-
elif "X_test" in
|
324
|
+
elif "X_test" in params:
|
325
325
|
args = {"X_test": df[input_cols]}
|
326
326
|
else:
|
327
327
|
raise RuntimeError("Neither 'X' or 'X_test' exist in argument")
|
328
328
|
|
329
329
|
if label_cols:
|
330
|
-
label_arg_name = "Y" if "Y" in
|
330
|
+
label_arg_name = "Y" if "Y" in params else "y"
|
331
331
|
args[label_arg_name] = df[label_cols].squeeze()
|
332
332
|
|
333
|
-
if sample_weight_col is not None and "sample_weight" in
|
333
|
+
if sample_weight_col is not None and "sample_weight" in params:
|
334
334
|
args["sample_weight"] = df[sample_weight_col].squeeze()
|
335
335
|
|
336
336
|
result: float = estimator.score(**args)
|
@@ -35,6 +35,7 @@ cp.register_pickle_by_value(inspect.getmodule(handle_inference_result))
|
|
35
35
|
|
36
36
|
_PROJECT = "ModelDevelopment"
|
37
37
|
_ENABLE_ANONYMOUS_SPROC = False
|
38
|
+
_ENABLE_TRACER = True
|
38
39
|
|
39
40
|
|
40
41
|
class SnowparkModelTrainer:
|
@@ -119,6 +120,8 @@ class SnowparkModelTrainer:
|
|
119
120
|
A callable that can be registered as a stored procedure.
|
120
121
|
"""
|
121
122
|
imports = model_spec.imports # In order for the sproc to not resolve this reference in snowflake.ml
|
123
|
+
method_name = "fit"
|
124
|
+
tracer_name = f"snowpark.ml.modeling.{self._class_name.lower()}.{method_name}"
|
122
125
|
|
123
126
|
def fit_wrapper_function(
|
124
127
|
session: Session,
|
@@ -138,110 +141,97 @@ class SnowparkModelTrainer:
|
|
138
141
|
for import_name in imports:
|
139
142
|
importlib.import_module(import_name)
|
140
143
|
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
144
|
+
def fit_and_return_estimator() -> str:
|
145
|
+
"""This is a helper function within the sproc to download the data, fit the model, and upload the model.
|
146
|
+
|
147
|
+
Returns:
|
148
|
+
The name of the file in session's temp stage (temp_stage_name) that contains the serialized model.
|
149
|
+
"""
|
150
|
+
# Execute snowpark queries and obtain the results as pandas dataframe
|
151
|
+
# NB: this implies that the result data must fit into memory.
|
152
|
+
for query in sql_queries[:-1]:
|
153
|
+
_ = session.sql(query).collect(statement_params=statement_params)
|
154
|
+
sp_df = session.sql(sql_queries[-1])
|
155
|
+
df: pd.DataFrame = sp_df.to_pandas(statement_params=statement_params)
|
156
|
+
df.columns = sp_df.columns
|
157
|
+
|
158
|
+
local_transform_file_name = temp_file_utils.get_temp_file_path()
|
159
|
+
|
160
|
+
session.file.get(
|
161
|
+
stage_location=temp_stage_name,
|
162
|
+
target_directory=local_transform_file_name,
|
163
|
+
statement_params=statement_params,
|
164
|
+
)
|
148
165
|
|
149
|
-
|
166
|
+
local_transform_file_path = os.path.join(
|
167
|
+
local_transform_file_name, os.listdir(local_transform_file_name)[0]
|
168
|
+
)
|
169
|
+
with open(local_transform_file_path, mode="r+b") as local_transform_file_obj:
|
170
|
+
estimator = cp.load(local_transform_file_obj)
|
150
171
|
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
172
|
+
params = inspect.signature(estimator.fit).parameters
|
173
|
+
args = {"X": df[input_cols]}
|
174
|
+
if label_cols:
|
175
|
+
label_arg_name = "Y" if "Y" in params else "y"
|
176
|
+
args[label_arg_name] = df[label_cols].squeeze()
|
156
177
|
|
157
|
-
|
158
|
-
|
159
|
-
)
|
160
|
-
with open(local_transform_file_path, mode="r+b") as local_transform_file_obj:
|
161
|
-
estimator = cp.load(local_transform_file_obj)
|
178
|
+
if sample_weight_col is not None and "sample_weight" in params:
|
179
|
+
args["sample_weight"] = df[sample_weight_col].squeeze()
|
162
180
|
|
163
|
-
|
164
|
-
args = {"X": df[input_cols]}
|
165
|
-
if label_cols:
|
166
|
-
label_arg_name = "Y" if "Y" in argspec.args else "y"
|
167
|
-
args[label_arg_name] = df[label_cols].squeeze()
|
181
|
+
estimator.fit(**args)
|
168
182
|
|
169
|
-
|
170
|
-
args["sample_weight"] = df[sample_weight_col].squeeze()
|
183
|
+
local_result_file_name = temp_file_utils.get_temp_file_path()
|
171
184
|
|
172
|
-
|
185
|
+
with open(local_result_file_name, mode="w+b") as local_result_file_obj:
|
186
|
+
cp.dump(estimator, local_result_file_obj)
|
173
187
|
|
174
|
-
|
188
|
+
session.file.put(
|
189
|
+
local_file_name=local_result_file_name,
|
190
|
+
stage_location=temp_stage_name,
|
191
|
+
auto_compress=False,
|
192
|
+
overwrite=True,
|
193
|
+
statement_params=statement_params,
|
194
|
+
)
|
195
|
+
return local_result_file_name
|
175
196
|
|
176
|
-
|
177
|
-
cp.dump(estimator, local_result_file_obj)
|
197
|
+
if _ENABLE_TRACER:
|
178
198
|
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
auto_compress=False,
|
183
|
-
overwrite=True,
|
184
|
-
statement_params=statement_params,
|
185
|
-
)
|
199
|
+
# Use opentelemetry to trace the dist and span of the fit operation.
|
200
|
+
# This would allow user to see the trace in the Snowflake UI.
|
201
|
+
from opentelemetry import trace
|
186
202
|
|
187
|
-
|
188
|
-
|
189
|
-
|
203
|
+
tracer = trace.get_tracer(tracer_name)
|
204
|
+
with tracer.start_as_current_span("fit"):
|
205
|
+
local_result_file_name = fit_and_return_estimator()
|
206
|
+
# Note: you can add something like + "|" + str(df) to the return string
|
207
|
+
# to pass debug information to the caller.
|
208
|
+
return str(os.path.basename(local_result_file_name))
|
209
|
+
else:
|
210
|
+
local_result_file_name = fit_and_return_estimator()
|
211
|
+
return str(os.path.basename(local_result_file_name))
|
190
212
|
|
191
213
|
return fit_wrapper_function
|
192
214
|
|
193
|
-
def
|
215
|
+
def _get_fit_wrapper_sproc(self, statement_params: Dict[str, str], anonymous: bool) -> StoredProcedure:
|
194
216
|
model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
|
195
|
-
fit_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
|
196
|
-
|
197
|
-
relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
198
|
-
pkg_versions=model_spec.pkgDependencies, session=self.session
|
199
|
-
)
|
200
|
-
|
201
|
-
fit_wrapper_sproc = self.session.sproc.register(
|
202
|
-
func=self._build_fit_wrapper_sproc(model_spec=model_spec),
|
203
|
-
is_permanent=False,
|
204
|
-
name=fit_sproc_name,
|
205
|
-
packages=["snowflake-snowpark-python"] + relaxed_dependencies, # type: ignore[arg-type]
|
206
|
-
replace=True,
|
207
|
-
session=self.session,
|
208
|
-
statement_params=statement_params,
|
209
|
-
anonymous=True,
|
210
|
-
execute_as="caller",
|
211
|
-
)
|
212
|
-
|
213
|
-
return fit_wrapper_sproc
|
214
|
-
|
215
|
-
def _get_fit_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure:
|
216
|
-
# If the sproc already exists, don't register.
|
217
|
-
if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"):
|
218
|
-
self.session._FIT_WRAPPER_SPROCS: Dict[str, StoredProcedure] = {} # type: ignore[attr-defined, misc]
|
219
|
-
|
220
|
-
model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
|
221
|
-
fit_sproc_key = model_spec.__class__.__name__
|
222
|
-
if fit_sproc_key in self.session._FIT_WRAPPER_SPROCS: # type: ignore[attr-defined]
|
223
|
-
fit_sproc: StoredProcedure = self.session._FIT_WRAPPER_SPROCS[fit_sproc_key] # type: ignore[attr-defined]
|
224
|
-
return fit_sproc
|
225
217
|
|
226
218
|
fit_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
|
227
219
|
|
228
220
|
relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
229
221
|
pkg_versions=model_spec.pkgDependencies, session=self.session
|
230
222
|
)
|
223
|
+
packages = ["snowflake-snowpark-python", "snowflake-telemetry-python"] + relaxed_dependencies
|
231
224
|
|
232
225
|
fit_wrapper_sproc = self.session.sproc.register(
|
233
226
|
func=self._build_fit_wrapper_sproc(model_spec=model_spec),
|
234
227
|
is_permanent=False,
|
235
228
|
name=fit_sproc_name,
|
236
|
-
packages=
|
229
|
+
packages=packages, # type: ignore[arg-type]
|
237
230
|
replace=True,
|
238
231
|
session=self.session,
|
239
232
|
statement_params=statement_params,
|
240
|
-
|
233
|
+
anonymous=anonymous,
|
241
234
|
)
|
242
|
-
|
243
|
-
self.session._FIT_WRAPPER_SPROCS[fit_sproc_key] = fit_wrapper_sproc # type: ignore[attr-defined]
|
244
|
-
|
245
235
|
return fit_wrapper_sproc
|
246
236
|
|
247
237
|
def _build_fit_predict_wrapper_sproc(
|
@@ -333,7 +323,9 @@ class SnowparkModelTrainer:
|
|
333
323
|
|
334
324
|
# write into a temp table in sproc and load the table from outside
|
335
325
|
session.write_pandas(
|
336
|
-
fit_predict_result_pd,
|
326
|
+
fit_predict_result_pd,
|
327
|
+
fit_predict_result_name,
|
328
|
+
overwrite=True,
|
337
329
|
)
|
338
330
|
|
339
331
|
# Note: you can add something like + "|" + str(df) to the return string
|
@@ -414,13 +406,13 @@ class SnowparkModelTrainer:
|
|
414
406
|
with open(local_transform_file_path, mode="r+b") as local_transform_file_obj:
|
415
407
|
estimator = cp.load(local_transform_file_obj)
|
416
408
|
|
417
|
-
|
409
|
+
params = inspect.signature(estimator.fit).parameters
|
418
410
|
args = {"X": df[input_cols]}
|
419
411
|
if label_cols:
|
420
|
-
label_arg_name = "Y" if "Y" in
|
412
|
+
label_arg_name = "Y" if "Y" in params else "y"
|
421
413
|
args[label_arg_name] = df[label_cols].squeeze()
|
422
414
|
|
423
|
-
if sample_weight_col is not None and "sample_weight" in
|
415
|
+
if sample_weight_col is not None and "sample_weight" in params:
|
424
416
|
args["sample_weight"] = df[sample_weight_col].squeeze()
|
425
417
|
|
426
418
|
fit_transform_result = estimator.fit_transform(**args)
|
@@ -468,16 +460,14 @@ class SnowparkModelTrainer:
|
|
468
460
|
session.write_pandas(
|
469
461
|
transformed_pandas_df,
|
470
462
|
fit_transform_result_name,
|
471
|
-
|
472
|
-
table_type="temp",
|
473
|
-
quote_identifiers=False,
|
463
|
+
overwrite=True,
|
474
464
|
)
|
475
465
|
|
476
466
|
return str(os.path.basename(local_result_file_name))
|
477
467
|
|
478
468
|
return fit_transform_wrapper_function
|
479
469
|
|
480
|
-
def
|
470
|
+
def _get_fit_predict_wrapper_sproc(self, statement_params: Dict[str, str], anonymous: bool) -> StoredProcedure:
|
481
471
|
model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
|
482
472
|
|
483
473
|
fit_predict_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
|
@@ -494,49 +484,12 @@ class SnowparkModelTrainer:
|
|
494
484
|
replace=True,
|
495
485
|
session=self.session,
|
496
486
|
statement_params=statement_params,
|
497
|
-
anonymous=
|
498
|
-
execute_as="caller",
|
487
|
+
anonymous=anonymous,
|
499
488
|
)
|
500
489
|
|
501
490
|
return fit_predict_wrapper_sproc
|
502
491
|
|
503
|
-
def
|
504
|
-
# If the sproc already exists, don't register.
|
505
|
-
if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"):
|
506
|
-
self.session._FIT_WRAPPER_SPROCS: Dict[str, StoredProcedure] = {} # type: ignore[attr-defined, misc]
|
507
|
-
|
508
|
-
model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
|
509
|
-
fit_predict_sproc_key = model_spec.__class__.__name__ + "_fit_predict"
|
510
|
-
if fit_predict_sproc_key in self.session._FIT_WRAPPER_SPROCS: # type: ignore[attr-defined]
|
511
|
-
fit_sproc: StoredProcedure = self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
|
512
|
-
fit_predict_sproc_key
|
513
|
-
]
|
514
|
-
return fit_sproc
|
515
|
-
|
516
|
-
fit_predict_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
|
517
|
-
|
518
|
-
relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
519
|
-
pkg_versions=model_spec.pkgDependencies, session=self.session
|
520
|
-
)
|
521
|
-
|
522
|
-
fit_predict_wrapper_sproc = self.session.sproc.register(
|
523
|
-
func=self._build_fit_predict_wrapper_sproc(model_spec=model_spec),
|
524
|
-
is_permanent=False,
|
525
|
-
name=fit_predict_sproc_name,
|
526
|
-
packages=["snowflake-snowpark-python"] + relaxed_dependencies, # type: ignore[arg-type]
|
527
|
-
replace=True,
|
528
|
-
session=self.session,
|
529
|
-
statement_params=statement_params,
|
530
|
-
execute_as="caller",
|
531
|
-
)
|
532
|
-
|
533
|
-
self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
|
534
|
-
fit_predict_sproc_key
|
535
|
-
] = fit_predict_wrapper_sproc
|
536
|
-
|
537
|
-
return fit_predict_wrapper_sproc
|
538
|
-
|
539
|
-
def _get_fit_transform_wrapper_sproc_anonymous(self, statement_params: Dict[str, str]) -> StoredProcedure:
|
492
|
+
def _get_fit_transform_wrapper_sproc(self, statement_params: Dict[str, str], anonymous: bool) -> StoredProcedure:
|
540
493
|
model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
|
541
494
|
|
542
495
|
fit_transform_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
|
@@ -553,44 +506,8 @@ class SnowparkModelTrainer:
|
|
553
506
|
replace=True,
|
554
507
|
session=self.session,
|
555
508
|
statement_params=statement_params,
|
556
|
-
anonymous=True,
|
557
|
-
execute_as="caller",
|
509
|
+
anonymous=anonymous,
|
558
510
|
)
|
559
|
-
return fit_transform_wrapper_sproc
|
560
|
-
|
561
|
-
def _get_fit_transform_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure:
|
562
|
-
# If the sproc already exists, don't register.
|
563
|
-
if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"):
|
564
|
-
self.session._FIT_WRAPPER_SPROCS: Dict[str, StoredProcedure] = {} # type: ignore[attr-defined, misc]
|
565
|
-
|
566
|
-
model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
|
567
|
-
fit_transform_sproc_key = model_spec.__class__.__name__ + "_fit_transform"
|
568
|
-
if fit_transform_sproc_key in self.session._FIT_WRAPPER_SPROCS: # type: ignore[attr-defined]
|
569
|
-
fit_sproc: StoredProcedure = self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
|
570
|
-
fit_transform_sproc_key
|
571
|
-
]
|
572
|
-
return fit_sproc
|
573
|
-
|
574
|
-
fit_transform_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
|
575
|
-
|
576
|
-
relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
577
|
-
pkg_versions=model_spec.pkgDependencies, session=self.session
|
578
|
-
)
|
579
|
-
|
580
|
-
fit_transform_wrapper_sproc = self.session.sproc.register(
|
581
|
-
func=self._build_fit_transform_wrapper_sproc(model_spec=model_spec),
|
582
|
-
is_permanent=False,
|
583
|
-
name=fit_transform_sproc_name,
|
584
|
-
packages=["snowflake-snowpark-python"] + relaxed_dependencies, # type: ignore[arg-type]
|
585
|
-
replace=True,
|
586
|
-
session=self.session,
|
587
|
-
statement_params=statement_params,
|
588
|
-
execute_as="caller",
|
589
|
-
)
|
590
|
-
|
591
|
-
self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
|
592
|
-
fit_transform_sproc_key
|
593
|
-
] = fit_transform_wrapper_sproc
|
594
511
|
|
595
512
|
return fit_transform_wrapper_sproc
|
596
513
|
|
@@ -629,9 +546,9 @@ class SnowparkModelTrainer:
|
|
629
546
|
# Call fit sproc
|
630
547
|
|
631
548
|
if _ENABLE_ANONYMOUS_SPROC:
|
632
|
-
fit_wrapper_sproc = self._get_fit_wrapper_sproc_anonymous(statement_params=statement_params)
|
549
|
+
fit_wrapper_sproc = self._get_fit_wrapper_sproc(statement_params=statement_params, anonymous=True)
|
633
550
|
else:
|
634
|
-
fit_wrapper_sproc = self._get_fit_wrapper_sproc(statement_params=statement_params)
|
551
|
+
fit_wrapper_sproc = self._get_fit_wrapper_sproc(statement_params=statement_params, anonymous=False)
|
635
552
|
|
636
553
|
try:
|
637
554
|
sproc_export_file_name: str = fit_wrapper_sproc(
|
@@ -665,6 +582,7 @@ class SnowparkModelTrainer:
|
|
665
582
|
self,
|
666
583
|
expected_output_cols_list: List[str],
|
667
584
|
drop_input_cols: Optional[bool] = False,
|
585
|
+
example_output_pd_df: Optional[pd.DataFrame] = None,
|
668
586
|
) -> Tuple[Union[DataFrame, pd.DataFrame], object]:
|
669
587
|
"""Trains the model by pushing down the compute into Snowflake using stored procedures.
|
670
588
|
This API is different from fit itself because it would also provide the predict
|
@@ -675,6 +593,11 @@ class SnowparkModelTrainer:
|
|
675
593
|
name as a list. Defaults to None.
|
676
594
|
drop_input_cols (Optional[bool]): Boolean to determine drop
|
677
595
|
the input columns from the output dataset or not
|
596
|
+
example_output_pd_df (Optional[pd.DataFrame]): Example output dataframe
|
597
|
+
This is to create a temp table in the client side with df_one_row. This can maintain the same column
|
598
|
+
name and data type as the output dataframe. Within the sproc, we don't need to create another temp table
|
599
|
+
again - instead, we overwrite into this table without changing the schema.
|
600
|
+
This is not used in PandasModelTrainer.
|
678
601
|
|
679
602
|
Returns:
|
680
603
|
Tuple[Union[DataFrame, pd.DataFrame], object]: [predicted dataset, estimator]
|
@@ -702,12 +625,35 @@ class SnowparkModelTrainer:
|
|
702
625
|
|
703
626
|
# Call fit sproc
|
704
627
|
if _ENABLE_ANONYMOUS_SPROC:
|
705
|
-
fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc_anonymous(statement_params=statement_params)
|
628
|
+
fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc(
|
629
|
+
statement_params=statement_params, anonymous=True
|
630
|
+
)
|
706
631
|
else:
|
707
|
-
fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc(statement_params=statement_params)
|
632
|
+
fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc(
|
633
|
+
statement_params=statement_params, anonymous=False
|
634
|
+
)
|
708
635
|
|
709
636
|
fit_predict_result_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
|
710
637
|
|
638
|
+
# Create a temp table in advance to store the output
|
639
|
+
# This would allow us to use the same table outside the stored procedure
|
640
|
+
if not drop_input_cols:
|
641
|
+
assert example_output_pd_df is not None
|
642
|
+
remove_dataset_col_name_exist_in_output_col = list(set(dataset.columns) - set(example_output_pd_df.columns))
|
643
|
+
pd_df_one_row = (
|
644
|
+
dataset.select(remove_dataset_col_name_exist_in_output_col)
|
645
|
+
.limit(1)
|
646
|
+
.to_pandas(statement_params=statement_params)
|
647
|
+
)
|
648
|
+
example_output_pd_df = pd.concat([pd_df_one_row, example_output_pd_df], axis=1)
|
649
|
+
|
650
|
+
self.session.write_pandas(
|
651
|
+
example_output_pd_df,
|
652
|
+
fit_predict_result_name,
|
653
|
+
auto_create_table=True,
|
654
|
+
table_type="temp",
|
655
|
+
)
|
656
|
+
|
711
657
|
sproc_export_file_name: str = fit_predict_wrapper_sproc(
|
712
658
|
self.session,
|
713
659
|
queries,
|
@@ -769,14 +715,32 @@ class SnowparkModelTrainer:
|
|
769
715
|
|
770
716
|
# Call fit sproc
|
771
717
|
if _ENABLE_ANONYMOUS_SPROC:
|
772
|
-
fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc_anonymous(
|
773
|
-
statement_params=statement_params
|
718
|
+
fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc(
|
719
|
+
statement_params=statement_params, anonymous=True
|
774
720
|
)
|
775
721
|
else:
|
776
|
-
fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc(statement_params=statement_params)
|
722
|
+
fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc(
|
723
|
+
statement_params=statement_params, anonymous=False
|
724
|
+
)
|
777
725
|
|
778
726
|
fit_transform_result_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
|
779
727
|
|
728
|
+
# Create a temp table in advance to store the output
|
729
|
+
# This would allow us to use the same table outside the stored procedure
|
730
|
+
df_one_line = dataset.limit(1).to_pandas(statement_params=statement_params)
|
731
|
+
df_one_line[
|
732
|
+
expected_output_cols_list[0]
|
733
|
+
] = "[0]" # Add one column as the output_col; this is a dummy value to represent the OBJECT type
|
734
|
+
if drop_input_cols:
|
735
|
+
self.session.write_pandas(
|
736
|
+
df_one_line[expected_output_cols_list[0]],
|
737
|
+
fit_transform_result_name,
|
738
|
+
auto_create_table=True,
|
739
|
+
table_type="temp",
|
740
|
+
)
|
741
|
+
else:
|
742
|
+
self.session.write_pandas(df_one_line, fit_transform_result_name, auto_create_table=True, table_type="temp")
|
743
|
+
|
780
744
|
sproc_export_file_name: str = fit_transform_wrapper_sproc(
|
781
745
|
self.session,
|
782
746
|
queries,
|
@@ -303,7 +303,6 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
|
|
303
303
|
statement_params=statement_params,
|
304
304
|
anonymous=True,
|
305
305
|
imports=list(import_file_paths),
|
306
|
-
execute_as="caller",
|
307
306
|
) # type: ignore[misc]
|
308
307
|
def fit_wrapper_sproc(
|
309
308
|
session: Session,
|