snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +7 -33
- snowflake/ml/_internal/env_utils.py +11 -5
- snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
- snowflake/ml/_internal/telemetry.py +156 -20
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +1 -1
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/_internal/utils/uri.py +2 -2
- snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
- snowflake/ml/data/data_connector.py +88 -9
- snowflake/ml/data/data_ingestor.py +18 -1
- snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/dataset/dataset_metadata.py +3 -1
- snowflake/ml/dataset/dataset_reader.py +9 -3
- snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
- snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
- snowflake/ml/feature_store/examples/example_helper.py +69 -31
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
- snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
- snowflake/ml/feature_store/feature_store.py +100 -41
- snowflake/ml/feature_store/feature_view.py +149 -5
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/model/_client/model/model_impl.py +11 -2
- snowflake/ml/model/_client/model/model_version_impl.py +186 -20
- snowflake/ml/model/_client/ops/model_ops.py +144 -30
- snowflake/ml/model/_client/ops/service_ops.py +312 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
- snowflake/ml/model/_client/sql/model_version.py +13 -4
- snowflake/ml/model/_client/sql/service.py +196 -0
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
- snowflake/ml/model/_model_composer/model_composer.py +5 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
- snowflake/ml/model/_packager/model_env/model_env.py +7 -2
- snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
- snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
- snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
- snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
- snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
- snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
- snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
- snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
- snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
- snowflake/ml/model/_packager/model_packager.py +4 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/models/llm.py +3 -1
- snowflake/ml/model/type_hints.py +10 -4
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
- snowflake/ml/modeling/cluster/birch.py +60 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
- snowflake/ml/modeling/cluster/dbscan.py +60 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
- snowflake/ml/modeling/cluster/k_means.py +60 -21
- snowflake/ml/modeling/cluster/mean_shift.py +60 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
- snowflake/ml/modeling/cluster/optics.py +60 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
- snowflake/ml/modeling/compose/column_transformer.py +60 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
- snowflake/ml/modeling/covariance/oas.py +60 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/pca.py +60 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
- snowflake/ml/modeling/framework/base.py +28 -19
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
- snowflake/ml/modeling/impute/knn_imputer.py +60 -21
- snowflake/ml/modeling/impute/missing_indicator.py +60 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/lars.py +60 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/perceptron.py +60 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ridge.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
- snowflake/ml/modeling/manifold/isomap.py +60 -21
- snowflake/ml/modeling/manifold/mds.py +60 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
- snowflake/ml/modeling/manifold/tsne.py +60 -21
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +4 -12
- snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
- snowflake/ml/modeling/svm/linear_svc.py +60 -21
- snowflake/ml/modeling/svm/linear_svr.py +60 -21
- snowflake/ml/modeling/svm/nu_svc.py +60 -21
- snowflake/ml/modeling/svm/nu_svr.py +60 -21
- snowflake/ml/modeling/svm/svc.py +60 -21
- snowflake/ml/modeling/svm/svr.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
- snowflake/ml/registry/_manager/model_manager.py +20 -2
- snowflake/ml/registry/model_registry.py +1 -1
- snowflake/ml/registry/registry.py +1 -2
- snowflake/ml/utils/sql_client.py +22 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -110,6 +110,15 @@ def huggingface_pipeline_signature_auto_infer(task: str, params: Dict[str, Any])
|
|
110
110
|
# https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.ConversationalPipeline
|
111
111
|
# Needs to convert to conversation object.
|
112
112
|
if task == "conversational":
|
113
|
+
warnings.warn(
|
114
|
+
(
|
115
|
+
"Conversational pipeline is removed from transformers since 4.42.0. "
|
116
|
+
"Support will be removed from snowflake-ml-python soon."
|
117
|
+
),
|
118
|
+
category=DeprecationWarning,
|
119
|
+
stacklevel=1,
|
120
|
+
)
|
121
|
+
|
113
122
|
return core.ModelSignature(
|
114
123
|
inputs=[
|
115
124
|
core.FeatureSpec(name="user_inputs", dtype=core.DataType.STRING, shape=(-1,)),
|
snowflake/ml/model/models/llm.py
CHANGED
@@ -70,7 +70,9 @@ class LLM:
|
|
70
70
|
|
71
71
|
import peft
|
72
72
|
|
73
|
-
peft_config = peft.PeftConfig.from_pretrained(
|
73
|
+
peft_config = peft.PeftConfig.from_pretrained( # type: ignore[no-untyped-call, attr-defined]
|
74
|
+
model_id_or_path, **hub_kwargs
|
75
|
+
)
|
74
76
|
if peft_config.peft_type != peft.PeftType.LORA: # type: ignore[attr-defined]
|
75
77
|
raise ValueError("Only LORA is supported.")
|
76
78
|
if peft_config.task_type != peft.TaskType.CAUSAL_LM: # type: ignore[attr-defined]
|
snowflake/ml/model/type_hints.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
# mypy: disable-error-code="import"
|
2
|
+
from enum import Enum
|
2
3
|
from typing import (
|
3
4
|
TYPE_CHECKING,
|
4
5
|
Any,
|
@@ -232,13 +233,12 @@ class BaseModelSaveOption(TypedDict):
|
|
232
233
|
_legacy_save: NotRequired[bool]
|
233
234
|
function_type: NotRequired[Literal["FUNCTION", "TABLE_FUNCTION"]]
|
234
235
|
method_options: NotRequired[Dict[str, ModelMethodSaveOptions]]
|
235
|
-
|
236
|
+
enable_explainability: NotRequired[bool]
|
236
237
|
|
237
238
|
|
238
239
|
class CatBoostModelSaveOptions(BaseModelSaveOption):
|
239
240
|
target_methods: NotRequired[Sequence[str]]
|
240
241
|
cuda_version: NotRequired[str]
|
241
|
-
enable_explainability: NotRequired[bool]
|
242
242
|
|
243
243
|
|
244
244
|
class CustomModelSaveOption(BaseModelSaveOption):
|
@@ -252,12 +252,10 @@ class SKLModelSaveOptions(BaseModelSaveOption):
|
|
252
252
|
class XGBModelSaveOptions(BaseModelSaveOption):
|
253
253
|
target_methods: NotRequired[Sequence[str]]
|
254
254
|
cuda_version: NotRequired[str]
|
255
|
-
enable_explainability: NotRequired[bool]
|
256
255
|
|
257
256
|
|
258
257
|
class LGBMModelSaveOptions(BaseModelSaveOption):
|
259
258
|
target_methods: NotRequired[Sequence[str]]
|
260
|
-
enable_explainability: NotRequired[bool]
|
261
259
|
|
262
260
|
|
263
261
|
class SNOWModelSaveOptions(BaseModelSaveOption):
|
@@ -433,3 +431,11 @@ class Deployment(TypedDict):
|
|
433
431
|
signature: core.ModelSignature
|
434
432
|
options: Required[DeployOptions]
|
435
433
|
details: NotRequired[DeployDetails]
|
434
|
+
|
435
|
+
|
436
|
+
class ModelObjective(Enum):
|
437
|
+
UNKNOWN = "unknown"
|
438
|
+
BINARY_CLASSIFICATION = "binary_classification"
|
439
|
+
MULTI_CLASSIFICATION = "multi_classification"
|
440
|
+
REGRESSION = "regression"
|
441
|
+
RANKING = "ranking"
|
@@ -166,10 +166,10 @@ class PandasTransformHandlers:
|
|
166
166
|
SnowflakeMLException: The input column list does not have one of `X` and `X_test`.
|
167
167
|
"""
|
168
168
|
assert hasattr(self.estimator, "score") # make type checker happy
|
169
|
-
|
170
|
-
if "X" in argspec.args:
|
169
|
+
params = inspect.signature(self.estimator.score).parameters
|
170
|
+
if "X" in params:
|
171
171
|
score_args = {"X": self.dataset[input_cols]}
|
172
|
-
elif "X_test" in argspec.args:
|
172
|
+
elif "X_test" in params:
|
173
173
|
score_args = {"X_test": self.dataset[input_cols]}
|
174
174
|
else:
|
175
175
|
raise exceptions.SnowflakeMLException(
|
@@ -178,10 +178,10 @@ class PandasTransformHandlers:
|
|
178
178
|
)
|
179
179
|
|
180
180
|
if len(label_cols) > 0:
|
181
|
-
label_arg_name = "Y" if "Y" in argspec.args else "y"
|
181
|
+
label_arg_name = "Y" if "Y" in params else "y"
|
182
182
|
score_args[label_arg_name] = self.dataset[label_cols].squeeze()
|
183
183
|
|
184
|
-
if sample_weight_col is not None and "sample_weight" in argspec.args:
|
184
|
+
if sample_weight_col is not None and "sample_weight" in params:
|
185
185
|
score_args["sample_weight"] = self.dataset[sample_weight_col].squeeze()
|
186
186
|
|
187
187
|
score = self.estimator.score(**score_args)
|
@@ -43,14 +43,14 @@ class PandasModelTrainer:
|
|
43
43
|
Trained model
|
44
44
|
"""
|
45
45
|
assert hasattr(self.estimator, "fit") # Keep mypy happy
|
46
|
-
|
46
|
+
params = inspect.signature(self.estimator.fit).parameters
|
47
47
|
args = {"X": self.dataset[self.input_cols]}
|
48
48
|
|
49
49
|
if self.label_cols:
|
50
|
-
label_arg_name = "Y" if "Y" in argspec.args else "y"
|
50
|
+
label_arg_name = "Y" if "Y" in params else "y"
|
51
51
|
args[label_arg_name] = self.dataset[self.label_cols].squeeze()
|
52
52
|
|
53
|
-
if self.sample_weight_col is not None and "sample_weight" in argspec.args:
|
53
|
+
if self.sample_weight_col is not None and "sample_weight" in params:
|
54
54
|
args["sample_weight"] = self.dataset[self.sample_weight_col].squeeze()
|
55
55
|
|
56
56
|
return self.estimator.fit(**args)
|
@@ -59,6 +59,7 @@ class PandasModelTrainer:
|
|
59
59
|
self,
|
60
60
|
expected_output_cols_list: List[str],
|
61
61
|
drop_input_cols: Optional[bool] = False,
|
62
|
+
example_output_pd_df: Optional[pd.DataFrame] = None,
|
62
63
|
) -> Tuple[pd.DataFrame, object]:
|
63
64
|
"""Trains the model using specified features and target columns from the dataset.
|
64
65
|
This API is different from fit itself because it would also provide the predict
|
@@ -69,6 +70,8 @@ class PandasModelTrainer:
|
|
69
70
|
name as a list. Defaults to None.
|
70
71
|
drop_input_cols (Optional[bool]): Boolean to determine whether to
|
71
72
|
drop the input columns from the output dataset.
|
73
|
+
example_output_pd_df (Optional[pd.DataFrame]): Example output dataframe
|
74
|
+
This is not used in PandasModelTrainer. It is used in SnowparkModelTrainer.
|
72
75
|
|
73
76
|
Returns:
|
74
77
|
Tuple[pd.DataFrame, object]: [predicted dataset, estimator]
|
@@ -108,13 +111,13 @@ class PandasModelTrainer:
|
|
108
111
|
assert hasattr(self.estimator, "fit") # make type checker happy
|
109
112
|
assert hasattr(self.estimator, "fit_transform") # make type checker happy
|
110
113
|
|
111
|
-
|
114
|
+
params = inspect.signature(self.estimator.fit).parameters
|
112
115
|
args = {"X": self.dataset[self.input_cols]}
|
113
116
|
if self.label_cols:
|
114
|
-
label_arg_name = "Y" if "Y" in argspec.args else "y"
|
117
|
+
label_arg_name = "Y" if "Y" in params else "y"
|
115
118
|
args[label_arg_name] = self.dataset[self.label_cols].squeeze()
|
116
119
|
|
117
|
-
if self.sample_weight_col is not None and "sample_weight" in argspec.args:
|
120
|
+
if self.sample_weight_col is not None and "sample_weight" in params:
|
118
121
|
args["sample_weight"] = self.dataset[self.sample_weight_col].squeeze()
|
119
122
|
|
120
123
|
inference_res = self.estimator.fit_transform(**args)
|
@@ -53,11 +53,13 @@ class SKLearnModelSpecifications(ModelSpecifications):
|
|
53
53
|
|
54
54
|
class XGBoostModelSpecifications(ModelSpecifications):
|
55
55
|
def __init__(self) -> None:
|
56
|
+
import sklearn
|
56
57
|
import xgboost
|
57
58
|
|
58
59
|
imports: List[str] = ["xgboost"]
|
59
60
|
pkgDependencies: List[str] = [
|
60
61
|
f"numpy=={np.__version__}",
|
62
|
+
f"scikit-learn=={sklearn.__version__}",
|
61
63
|
f"xgboost=={xgboost.__version__}",
|
62
64
|
f"cloudpickle=={cp.__version__}",
|
63
65
|
]
|
@@ -20,6 +20,7 @@ class ModelTrainer(Protocol):
|
|
20
20
|
self,
|
21
21
|
expected_output_cols_list: List[str],
|
22
22
|
drop_input_cols: Optional[bool] = False,
|
23
|
+
example_output_pd_df: Optional[pd.DataFrame] = None,
|
23
24
|
) -> Tuple[Union[DataFrame, pd.DataFrame], object]:
|
24
25
|
raise NotImplementedError
|
25
26
|
|
@@ -495,7 +495,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
|
|
495
495
|
label_arg_name = "Y" if "Y" in argspec.args else "y"
|
496
496
|
args[label_arg_name] = df[label_cols].squeeze()
|
497
497
|
|
498
|
-
if sample_weight_col is not None and "sample_weight" in argspec.args:
|
498
|
+
if sample_weight_col is not None:
|
499
499
|
args["sample_weight"] = df[sample_weight_col].squeeze()
|
500
500
|
return args, estimator, indices, len(df), params_to_evaluate
|
501
501
|
|
@@ -1061,7 +1061,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
|
|
1061
1061
|
if label_cols:
|
1062
1062
|
label_arg_name = "Y" if "Y" in argspec.args else "y"
|
1063
1063
|
args[label_arg_name] = y
|
1064
|
-
if sample_weight_col is not None and "sample_weight" in argspec.args:
|
1064
|
+
if sample_weight_col is not None:
|
1065
1065
|
args["sample_weight"] = df[sample_weight_col].squeeze()
|
1066
1066
|
# estimator.refit = original_refit
|
1067
1067
|
refit_start_time = time.time()
|
@@ -318,19 +318,19 @@ class SnowparkTransformHandlers:
|
|
318
318
|
with open(local_score_file_name_path, mode="r+b") as local_score_file_obj:
|
319
319
|
estimator = cp.load(local_score_file_obj)
|
320
320
|
|
321
|
-
|
322
|
-
if "X" in argspec.args:
|
321
|
+
params = inspect.signature(estimator.score).parameters
|
322
|
+
if "X" in params:
|
323
323
|
args = {"X": df[input_cols]}
|
324
|
-
elif "X_test" in argspec.args:
|
324
|
+
elif "X_test" in params:
|
325
325
|
args = {"X_test": df[input_cols]}
|
326
326
|
else:
|
327
327
|
raise RuntimeError("Neither 'X' or 'X_test' exist in argument")
|
328
328
|
|
329
329
|
if label_cols:
|
330
|
-
label_arg_name = "Y" if "Y" in argspec.args else "y"
|
330
|
+
label_arg_name = "Y" if "Y" in params else "y"
|
331
331
|
args[label_arg_name] = df[label_cols].squeeze()
|
332
332
|
|
333
|
-
if sample_weight_col is not None and "sample_weight" in argspec.args:
|
333
|
+
if sample_weight_col is not None and "sample_weight" in params:
|
334
334
|
args["sample_weight"] = df[sample_weight_col].squeeze()
|
335
335
|
|
336
336
|
result: float = estimator.score(**args)
|
@@ -35,6 +35,7 @@ cp.register_pickle_by_value(inspect.getmodule(handle_inference_result))
|
|
35
35
|
|
36
36
|
_PROJECT = "ModelDevelopment"
|
37
37
|
_ENABLE_ANONYMOUS_SPROC = False
|
38
|
+
_ENABLE_TRACER = True
|
38
39
|
|
39
40
|
|
40
41
|
class SnowparkModelTrainer:
|
@@ -119,6 +120,8 @@ class SnowparkModelTrainer:
|
|
119
120
|
A callable that can be registered as a stored procedure.
|
120
121
|
"""
|
121
122
|
imports = model_spec.imports # In order for the sproc to not resolve this reference in snowflake.ml
|
123
|
+
method_name = "fit"
|
124
|
+
tracer_name = f"snowpark.ml.modeling.{self._class_name.lower()}.{method_name}"
|
122
125
|
|
123
126
|
def fit_wrapper_function(
|
124
127
|
session: Session,
|
@@ -138,110 +141,98 @@ class SnowparkModelTrainer:
|
|
138
141
|
for import_name in imports:
|
139
142
|
importlib.import_module(import_name)
|
140
143
|
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
144
|
+
def fit_and_return_estimator() -> str:
|
145
|
+
"""This is a helper function within the sproc to download the data, fit the model, and upload the model.
|
146
|
+
|
147
|
+
Returns:
|
148
|
+
The name of the file in session's temp stage (temp_stage_name) that contains the serialized model.
|
149
|
+
"""
|
150
|
+
# Execute snowpark queries and obtain the results as pandas dataframe
|
151
|
+
# NB: this implies that the result data must fit into memory.
|
152
|
+
for query in sql_queries[:-1]:
|
153
|
+
_ = session.sql(query).collect(statement_params=statement_params)
|
154
|
+
sp_df = session.sql(sql_queries[-1])
|
155
|
+
df: pd.DataFrame = sp_df.to_pandas(statement_params=statement_params)
|
156
|
+
df.columns = sp_df.columns
|
157
|
+
|
158
|
+
local_transform_file_name = temp_file_utils.get_temp_file_path()
|
159
|
+
|
160
|
+
session.file.get(
|
161
|
+
stage_location=temp_stage_name,
|
162
|
+
target_directory=local_transform_file_name,
|
163
|
+
statement_params=statement_params,
|
164
|
+
)
|
148
165
|
|
149
|
-
|
166
|
+
local_transform_file_path = os.path.join(
|
167
|
+
local_transform_file_name, os.listdir(local_transform_file_name)[0]
|
168
|
+
)
|
169
|
+
with open(local_transform_file_path, mode="r+b") as local_transform_file_obj:
|
170
|
+
estimator = cp.load(local_transform_file_obj)
|
150
171
|
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
172
|
+
params = inspect.signature(estimator.fit).parameters
|
173
|
+
args = {"X": df[input_cols]}
|
174
|
+
if label_cols:
|
175
|
+
label_arg_name = "Y" if "Y" in params else "y"
|
176
|
+
args[label_arg_name] = df[label_cols].squeeze()
|
156
177
|
|
157
|
-
|
158
|
-
|
159
|
-
)
|
160
|
-
with open(local_transform_file_path, mode="r+b") as local_transform_file_obj:
|
161
|
-
estimator = cp.load(local_transform_file_obj)
|
178
|
+
if sample_weight_col is not None and "sample_weight" in params:
|
179
|
+
args["sample_weight"] = df[sample_weight_col].squeeze()
|
162
180
|
|
163
|
-
|
164
|
-
args = {"X": df[input_cols]}
|
165
|
-
if label_cols:
|
166
|
-
label_arg_name = "Y" if "Y" in argspec.args else "y"
|
167
|
-
args[label_arg_name] = df[label_cols].squeeze()
|
181
|
+
estimator.fit(**args)
|
168
182
|
|
169
|
-
|
170
|
-
args["sample_weight"] = df[sample_weight_col].squeeze()
|
183
|
+
local_result_file_name = temp_file_utils.get_temp_file_path()
|
171
184
|
|
172
|
-
|
185
|
+
with open(local_result_file_name, mode="w+b") as local_result_file_obj:
|
186
|
+
cp.dump(estimator, local_result_file_obj)
|
173
187
|
|
174
|
-
|
188
|
+
session.file.put(
|
189
|
+
local_file_name=local_result_file_name,
|
190
|
+
stage_location=temp_stage_name,
|
191
|
+
auto_compress=False,
|
192
|
+
overwrite=True,
|
193
|
+
statement_params=statement_params,
|
194
|
+
)
|
195
|
+
return local_result_file_name
|
175
196
|
|
176
|
-
|
177
|
-
cp.dump(estimator, local_result_file_obj)
|
197
|
+
if _ENABLE_TRACER:
|
178
198
|
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
auto_compress=False,
|
183
|
-
overwrite=True,
|
184
|
-
statement_params=statement_params,
|
185
|
-
)
|
199
|
+
# Use opentelemetry to trace the dist and span of the fit operation.
|
200
|
+
# This would allow user to see the trace in the Snowflake UI.
|
201
|
+
from opentelemetry import trace
|
186
202
|
|
187
|
-
|
188
|
-
|
189
|
-
|
203
|
+
tracer = trace.get_tracer(tracer_name)
|
204
|
+
with tracer.start_as_current_span("fit"):
|
205
|
+
local_result_file_name = fit_and_return_estimator()
|
206
|
+
# Note: you can add something like + "|" + str(df) to the return string
|
207
|
+
# to pass debug information to the caller.
|
208
|
+
return str(os.path.basename(local_result_file_name))
|
209
|
+
else:
|
210
|
+
local_result_file_name = fit_and_return_estimator()
|
211
|
+
return str(os.path.basename(local_result_file_name))
|
190
212
|
|
191
213
|
return fit_wrapper_function
|
192
214
|
|
193
|
-
def
|
215
|
+
def _get_fit_wrapper_sproc(self, statement_params: Dict[str, str], anonymous: bool) -> StoredProcedure:
|
194
216
|
model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
|
195
|
-
fit_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
|
196
|
-
|
197
|
-
relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
198
|
-
pkg_versions=model_spec.pkgDependencies, session=self.session
|
199
|
-
)
|
200
|
-
|
201
|
-
fit_wrapper_sproc = self.session.sproc.register(
|
202
|
-
func=self._build_fit_wrapper_sproc(model_spec=model_spec),
|
203
|
-
is_permanent=False,
|
204
|
-
name=fit_sproc_name,
|
205
|
-
packages=["snowflake-snowpark-python"] + relaxed_dependencies, # type: ignore[arg-type]
|
206
|
-
replace=True,
|
207
|
-
session=self.session,
|
208
|
-
statement_params=statement_params,
|
209
|
-
anonymous=True,
|
210
|
-
execute_as="caller",
|
211
|
-
)
|
212
|
-
|
213
|
-
return fit_wrapper_sproc
|
214
|
-
|
215
|
-
def _get_fit_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure:
|
216
|
-
# If the sproc already exists, don't register.
|
217
|
-
if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"):
|
218
|
-
self.session._FIT_WRAPPER_SPROCS: Dict[str, StoredProcedure] = {} # type: ignore[attr-defined, misc]
|
219
|
-
|
220
|
-
model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
|
221
|
-
fit_sproc_key = model_spec.__class__.__name__
|
222
|
-
if fit_sproc_key in self.session._FIT_WRAPPER_SPROCS: # type: ignore[attr-defined]
|
223
|
-
fit_sproc: StoredProcedure = self.session._FIT_WRAPPER_SPROCS[fit_sproc_key] # type: ignore[attr-defined]
|
224
|
-
return fit_sproc
|
225
217
|
|
226
218
|
fit_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
|
227
219
|
|
228
220
|
relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
229
221
|
pkg_versions=model_spec.pkgDependencies, session=self.session
|
230
222
|
)
|
223
|
+
packages = ["snowflake-snowpark-python", "snowflake-telemetry-python"] + relaxed_dependencies
|
231
224
|
|
232
225
|
fit_wrapper_sproc = self.session.sproc.register(
|
233
226
|
func=self._build_fit_wrapper_sproc(model_spec=model_spec),
|
234
227
|
is_permanent=False,
|
235
228
|
name=fit_sproc_name,
|
236
|
-
packages=
|
229
|
+
packages=packages, # type: ignore[arg-type]
|
237
230
|
replace=True,
|
238
231
|
session=self.session,
|
239
232
|
statement_params=statement_params,
|
240
233
|
execute_as="caller",
|
234
|
+
anonymous=anonymous,
|
241
235
|
)
|
242
|
-
|
243
|
-
self.session._FIT_WRAPPER_SPROCS[fit_sproc_key] = fit_wrapper_sproc # type: ignore[attr-defined]
|
244
|
-
|
245
236
|
return fit_wrapper_sproc
|
246
237
|
|
247
238
|
def _build_fit_predict_wrapper_sproc(
|
@@ -333,7 +324,9 @@ class SnowparkModelTrainer:
|
|
333
324
|
|
334
325
|
# write into a temp table in sproc and load the table from outside
|
335
326
|
session.write_pandas(
|
336
|
-
fit_predict_result_pd,
|
327
|
+
fit_predict_result_pd,
|
328
|
+
fit_predict_result_name,
|
329
|
+
overwrite=True,
|
337
330
|
)
|
338
331
|
|
339
332
|
# Note: you can add something like + "|" + str(df) to the return string
|
@@ -414,13 +407,13 @@ class SnowparkModelTrainer:
|
|
414
407
|
with open(local_transform_file_path, mode="r+b") as local_transform_file_obj:
|
415
408
|
estimator = cp.load(local_transform_file_obj)
|
416
409
|
|
417
|
-
|
410
|
+
params = inspect.signature(estimator.fit).parameters
|
418
411
|
args = {"X": df[input_cols]}
|
419
412
|
if label_cols:
|
420
|
-
label_arg_name = "Y" if "Y" in
|
413
|
+
label_arg_name = "Y" if "Y" in params else "y"
|
421
414
|
args[label_arg_name] = df[label_cols].squeeze()
|
422
415
|
|
423
|
-
if sample_weight_col is not None and "sample_weight" in
|
416
|
+
if sample_weight_col is not None and "sample_weight" in params:
|
424
417
|
args["sample_weight"] = df[sample_weight_col].squeeze()
|
425
418
|
|
426
419
|
fit_transform_result = estimator.fit_transform(**args)
|
@@ -477,7 +470,7 @@ class SnowparkModelTrainer:
|
|
477
470
|
|
478
471
|
return fit_transform_wrapper_function
|
479
472
|
|
480
|
-
def
|
473
|
+
def _get_fit_predict_wrapper_sproc(self, statement_params: Dict[str, str], anonymous: bool) -> StoredProcedure:
|
481
474
|
model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
|
482
475
|
|
483
476
|
fit_predict_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
|
@@ -494,82 +487,14 @@ class SnowparkModelTrainer:
|
|
494
487
|
replace=True,
|
495
488
|
session=self.session,
|
496
489
|
statement_params=statement_params,
|
497
|
-
anonymous=
|
490
|
+
anonymous=anonymous,
|
498
491
|
execute_as="caller",
|
499
492
|
)
|
500
493
|
|
501
494
|
return fit_predict_wrapper_sproc
|
502
495
|
|
503
|
-
def
|
504
|
-
# If the sproc already exists, don't register.
|
505
|
-
if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"):
|
506
|
-
self.session._FIT_WRAPPER_SPROCS: Dict[str, StoredProcedure] = {} # type: ignore[attr-defined, misc]
|
507
|
-
|
508
|
-
model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
|
509
|
-
fit_predict_sproc_key = model_spec.__class__.__name__ + "_fit_predict"
|
510
|
-
if fit_predict_sproc_key in self.session._FIT_WRAPPER_SPROCS: # type: ignore[attr-defined]
|
511
|
-
fit_sproc: StoredProcedure = self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
|
512
|
-
fit_predict_sproc_key
|
513
|
-
]
|
514
|
-
return fit_sproc
|
515
|
-
|
516
|
-
fit_predict_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
|
517
|
-
|
518
|
-
relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
519
|
-
pkg_versions=model_spec.pkgDependencies, session=self.session
|
520
|
-
)
|
521
|
-
|
522
|
-
fit_predict_wrapper_sproc = self.session.sproc.register(
|
523
|
-
func=self._build_fit_predict_wrapper_sproc(model_spec=model_spec),
|
524
|
-
is_permanent=False,
|
525
|
-
name=fit_predict_sproc_name,
|
526
|
-
packages=["snowflake-snowpark-python"] + relaxed_dependencies, # type: ignore[arg-type]
|
527
|
-
replace=True,
|
528
|
-
session=self.session,
|
529
|
-
statement_params=statement_params,
|
530
|
-
execute_as="caller",
|
531
|
-
)
|
532
|
-
|
533
|
-
self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
|
534
|
-
fit_predict_sproc_key
|
535
|
-
] = fit_predict_wrapper_sproc
|
536
|
-
|
537
|
-
return fit_predict_wrapper_sproc
|
538
|
-
|
539
|
-
def _get_fit_transform_wrapper_sproc_anonymous(self, statement_params: Dict[str, str]) -> StoredProcedure:
|
540
|
-
model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
|
541
|
-
|
542
|
-
fit_transform_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
|
543
|
-
|
544
|
-
relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
545
|
-
pkg_versions=model_spec.pkgDependencies, session=self.session
|
546
|
-
)
|
547
|
-
|
548
|
-
fit_transform_wrapper_sproc = self.session.sproc.register(
|
549
|
-
func=self._build_fit_transform_wrapper_sproc(model_spec=model_spec),
|
550
|
-
is_permanent=False,
|
551
|
-
name=fit_transform_sproc_name,
|
552
|
-
packages=["snowflake-snowpark-python"] + relaxed_dependencies, # type: ignore[arg-type]
|
553
|
-
replace=True,
|
554
|
-
session=self.session,
|
555
|
-
statement_params=statement_params,
|
556
|
-
anonymous=True,
|
557
|
-
execute_as="caller",
|
558
|
-
)
|
559
|
-
return fit_transform_wrapper_sproc
|
560
|
-
|
561
|
-
def _get_fit_transform_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure:
|
562
|
-
# If the sproc already exists, don't register.
|
563
|
-
if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"):
|
564
|
-
self.session._FIT_WRAPPER_SPROCS: Dict[str, StoredProcedure] = {} # type: ignore[attr-defined, misc]
|
565
|
-
|
496
|
+
def _get_fit_transform_wrapper_sproc(self, statement_params: Dict[str, str], anonymous: bool) -> StoredProcedure:
|
566
497
|
model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
|
567
|
-
fit_transform_sproc_key = model_spec.__class__.__name__ + "_fit_transform"
|
568
|
-
if fit_transform_sproc_key in self.session._FIT_WRAPPER_SPROCS: # type: ignore[attr-defined]
|
569
|
-
fit_sproc: StoredProcedure = self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
|
570
|
-
fit_transform_sproc_key
|
571
|
-
]
|
572
|
-
return fit_sproc
|
573
498
|
|
574
499
|
fit_transform_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
|
575
500
|
|
@@ -586,12 +511,9 @@ class SnowparkModelTrainer:
|
|
586
511
|
session=self.session,
|
587
512
|
statement_params=statement_params,
|
588
513
|
execute_as="caller",
|
514
|
+
anonymous=anonymous,
|
589
515
|
)
|
590
516
|
|
591
|
-
self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
|
592
|
-
fit_transform_sproc_key
|
593
|
-
] = fit_transform_wrapper_sproc
|
594
|
-
|
595
517
|
return fit_transform_wrapper_sproc
|
596
518
|
|
597
519
|
def train(self) -> object:
|
@@ -629,9 +551,9 @@ class SnowparkModelTrainer:
|
|
629
551
|
# Call fit sproc
|
630
552
|
|
631
553
|
if _ENABLE_ANONYMOUS_SPROC:
|
632
|
-
fit_wrapper_sproc = self.
|
554
|
+
fit_wrapper_sproc = self._get_fit_wrapper_sproc(statement_params=statement_params, anonymous=True)
|
633
555
|
else:
|
634
|
-
fit_wrapper_sproc = self._get_fit_wrapper_sproc(statement_params=statement_params)
|
556
|
+
fit_wrapper_sproc = self._get_fit_wrapper_sproc(statement_params=statement_params, anonymous=False)
|
635
557
|
|
636
558
|
try:
|
637
559
|
sproc_export_file_name: str = fit_wrapper_sproc(
|
@@ -665,6 +587,7 @@ class SnowparkModelTrainer:
|
|
665
587
|
self,
|
666
588
|
expected_output_cols_list: List[str],
|
667
589
|
drop_input_cols: Optional[bool] = False,
|
590
|
+
example_output_pd_df: Optional[pd.DataFrame] = None,
|
668
591
|
) -> Tuple[Union[DataFrame, pd.DataFrame], object]:
|
669
592
|
"""Trains the model by pushing down the compute into Snowflake using stored procedures.
|
670
593
|
This API is different from fit itself because it would also provide the predict
|
@@ -675,6 +598,11 @@ class SnowparkModelTrainer:
|
|
675
598
|
name as a list. Defaults to None.
|
676
599
|
drop_input_cols (Optional[bool]): Boolean to determine drop
|
677
600
|
the input columns from the output dataset or not
|
601
|
+
example_output_pd_df (Optional[pd.DataFrame]): Example output dataframe
|
602
|
+
This is to create a temp table in the client side with df_one_row. This can maintain the same column
|
603
|
+
name and data type as the output dataframe. Within the sproc, we don't need to create another temp table
|
604
|
+
again - instead, we overwrite into this table without changing the schema.
|
605
|
+
This is not used in PandasModelTrainer.
|
678
606
|
|
679
607
|
Returns:
|
680
608
|
Tuple[Union[DataFrame, pd.DataFrame], object]: [predicted dataset, estimator]
|
@@ -702,12 +630,35 @@ class SnowparkModelTrainer:
|
|
702
630
|
|
703
631
|
# Call fit sproc
|
704
632
|
if _ENABLE_ANONYMOUS_SPROC:
|
705
|
-
fit_predict_wrapper_sproc = self.
|
633
|
+
fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc(
|
634
|
+
statement_params=statement_params, anonymous=True
|
635
|
+
)
|
706
636
|
else:
|
707
|
-
fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc(
|
637
|
+
fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc(
|
638
|
+
statement_params=statement_params, anonymous=False
|
639
|
+
)
|
708
640
|
|
709
641
|
fit_predict_result_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
|
710
642
|
|
643
|
+
# Create a temp table in advance to store the output
|
644
|
+
# This would allow us to use the same table outside the stored procedure
|
645
|
+
if not drop_input_cols:
|
646
|
+
assert example_output_pd_df is not None
|
647
|
+
remove_dataset_col_name_exist_in_output_col = list(set(dataset.columns) - set(example_output_pd_df.columns))
|
648
|
+
pd_df_one_row = (
|
649
|
+
dataset.select(remove_dataset_col_name_exist_in_output_col)
|
650
|
+
.limit(1)
|
651
|
+
.to_pandas(statement_params=statement_params)
|
652
|
+
)
|
653
|
+
example_output_pd_df = pd.concat([pd_df_one_row, example_output_pd_df], axis=1)
|
654
|
+
|
655
|
+
self.session.write_pandas(
|
656
|
+
example_output_pd_df,
|
657
|
+
fit_predict_result_name,
|
658
|
+
auto_create_table=True,
|
659
|
+
table_type="temp",
|
660
|
+
)
|
661
|
+
|
711
662
|
sproc_export_file_name: str = fit_predict_wrapper_sproc(
|
712
663
|
self.session,
|
713
664
|
queries,
|
@@ -769,11 +720,13 @@ class SnowparkModelTrainer:
|
|
769
720
|
|
770
721
|
# Call fit sproc
|
771
722
|
if _ENABLE_ANONYMOUS_SPROC:
|
772
|
-
fit_transform_wrapper_sproc = self.
|
773
|
-
statement_params=statement_params
|
723
|
+
fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc(
|
724
|
+
statement_params=statement_params, anonymous=True
|
774
725
|
)
|
775
726
|
else:
|
776
|
-
fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc(
|
727
|
+
fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc(
|
728
|
+
statement_params=statement_params, anonymous=False
|
729
|
+
)
|
777
730
|
|
778
731
|
fit_transform_result_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
|
779
732
|
|