snowflake-ml-python 1.1.1__py3-none-any.whl → 1.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +1 -1
- snowflake/cortex/_extract_answer.py +1 -1
- snowflake/cortex/_sentiment.py +1 -1
- snowflake/cortex/_summarize.py +1 -1
- snowflake/cortex/_translate.py +1 -1
- snowflake/ml/_internal/env_utils.py +68 -6
- snowflake/ml/_internal/file_utils.py +34 -4
- snowflake/ml/_internal/telemetry.py +79 -91
- snowflake/ml/_internal/utils/retryable_http.py +16 -4
- snowflake/ml/_internal/utils/spcs_attribution_utils.py +122 -0
- snowflake/ml/dataset/dataset.py +1 -1
- snowflake/ml/model/_api.py +21 -14
- snowflake/ml/model/_client/model/model_impl.py +176 -0
- snowflake/ml/model/_client/model/model_method_info.py +19 -0
- snowflake/ml/model/_client/model/model_version_impl.py +291 -0
- snowflake/ml/model/_client/ops/metadata_ops.py +107 -0
- snowflake/ml/model/_client/ops/model_ops.py +308 -0
- snowflake/ml/model/_client/sql/model.py +75 -0
- snowflake/ml/model/_client/sql/model_version.py +213 -0
- snowflake/ml/model/_client/sql/stage.py +40 -0
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -4
- snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +24 -8
- snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +23 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +14 -2
- snowflake/ml/model/_deploy_client/utils/constants.py +1 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +2 -2
- snowflake/ml/model/_model_composer/model_composer.py +31 -9
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +25 -10
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -2
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +34 -3
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +1 -1
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +3 -1
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +10 -28
- snowflake/ml/model/_packager/model_meta/model_meta.py +18 -16
- snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
- snowflake/ml/model/model_signature.py +108 -53
- snowflake/ml/model/type_hints.py +1 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +554 -0
- snowflake/ml/modeling/_internal/estimator_protocols.py +1 -60
- snowflake/ml/modeling/_internal/model_specifications.py +146 -0
- snowflake/ml/modeling/_internal/model_trainer.py +13 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +78 -0
- snowflake/ml/modeling/_internal/pandas_trainer.py +54 -0
- snowflake/ml/modeling/_internal/snowpark_handlers.py +6 -760
- snowflake/ml/modeling/_internal/snowpark_trainer.py +331 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +96 -124
- snowflake/ml/modeling/cluster/affinity_propagation.py +94 -124
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +94 -124
- snowflake/ml/modeling/cluster/birch.py +94 -124
- snowflake/ml/modeling/cluster/bisecting_k_means.py +94 -124
- snowflake/ml/modeling/cluster/dbscan.py +94 -124
- snowflake/ml/modeling/cluster/feature_agglomeration.py +94 -124
- snowflake/ml/modeling/cluster/k_means.py +93 -124
- snowflake/ml/modeling/cluster/mean_shift.py +94 -124
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +93 -124
- snowflake/ml/modeling/cluster/optics.py +94 -124
- snowflake/ml/modeling/cluster/spectral_biclustering.py +94 -124
- snowflake/ml/modeling/cluster/spectral_clustering.py +94 -124
- snowflake/ml/modeling/cluster/spectral_coclustering.py +94 -124
- snowflake/ml/modeling/compose/column_transformer.py +94 -124
- snowflake/ml/modeling/compose/transformed_target_regressor.py +96 -124
- snowflake/ml/modeling/covariance/elliptic_envelope.py +94 -124
- snowflake/ml/modeling/covariance/empirical_covariance.py +80 -110
- snowflake/ml/modeling/covariance/graphical_lasso.py +94 -124
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +94 -124
- snowflake/ml/modeling/covariance/ledoit_wolf.py +85 -115
- snowflake/ml/modeling/covariance/min_cov_det.py +94 -124
- snowflake/ml/modeling/covariance/oas.py +80 -110
- snowflake/ml/modeling/covariance/shrunk_covariance.py +84 -114
- snowflake/ml/modeling/decomposition/dictionary_learning.py +94 -124
- snowflake/ml/modeling/decomposition/factor_analysis.py +94 -124
- snowflake/ml/modeling/decomposition/fast_ica.py +94 -124
- snowflake/ml/modeling/decomposition/incremental_pca.py +94 -124
- snowflake/ml/modeling/decomposition/kernel_pca.py +94 -124
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +94 -124
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +94 -124
- snowflake/ml/modeling/decomposition/pca.py +94 -124
- snowflake/ml/modeling/decomposition/sparse_pca.py +94 -124
- snowflake/ml/modeling/decomposition/truncated_svd.py +94 -124
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +96 -124
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +91 -119
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/bagging_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/bagging_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/isolation_forest.py +94 -124
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/stacking_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/voting_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/voting_regressor.py +91 -119
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +82 -110
- snowflake/ml/modeling/feature_selection/select_fdr.py +80 -108
- snowflake/ml/modeling/feature_selection/select_fpr.py +80 -108
- snowflake/ml/modeling/feature_selection/select_fwe.py +80 -108
- snowflake/ml/modeling/feature_selection/select_k_best.py +81 -109
- snowflake/ml/modeling/feature_selection/select_percentile.py +80 -108
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +94 -124
- snowflake/ml/modeling/feature_selection/variance_threshold.py +76 -106
- snowflake/ml/modeling/framework/base.py +2 -2
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +96 -124
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +96 -124
- snowflake/ml/modeling/impute/iterative_imputer.py +94 -124
- snowflake/ml/modeling/impute/knn_imputer.py +94 -124
- snowflake/ml/modeling/impute/missing_indicator.py +94 -124
- snowflake/ml/modeling/impute/simple_imputer.py +1 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +77 -107
- snowflake/ml/modeling/kernel_approximation/nystroem.py +94 -124
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +94 -124
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +86 -116
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +84 -114
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +96 -124
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +71 -100
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +71 -100
- snowflake/ml/modeling/linear_model/ard_regression.py +96 -124
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +96 -124
- snowflake/ml/modeling/linear_model/elastic_net.py +96 -124
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +96 -124
- snowflake/ml/modeling/linear_model/gamma_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/huber_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/lars.py +96 -124
- snowflake/ml/modeling/linear_model/lars_cv.py +96 -124
- snowflake/ml/modeling/linear_model/lasso.py +96 -124
- snowflake/ml/modeling/linear_model/lasso_cv.py +96 -124
- snowflake/ml/modeling/linear_model/lasso_lars.py +96 -124
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +96 -124
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +96 -124
- snowflake/ml/modeling/linear_model/linear_regression.py +91 -119
- snowflake/ml/modeling/linear_model/logistic_regression.py +96 -124
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +96 -124
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +96 -124
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +96 -124
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +96 -124
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +96 -124
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +96 -124
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +96 -124
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +95 -124
- snowflake/ml/modeling/linear_model/perceptron.py +95 -124
- snowflake/ml/modeling/linear_model/poisson_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/ransac_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/ridge.py +96 -124
- snowflake/ml/modeling/linear_model/ridge_classifier.py +96 -124
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +96 -124
- snowflake/ml/modeling/linear_model/ridge_cv.py +96 -124
- snowflake/ml/modeling/linear_model/sgd_classifier.py +96 -124
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +94 -124
- snowflake/ml/modeling/linear_model/sgd_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +96 -124
- snowflake/ml/modeling/manifold/isomap.py +94 -124
- snowflake/ml/modeling/manifold/mds.py +94 -124
- snowflake/ml/modeling/manifold/spectral_embedding.py +94 -124
- snowflake/ml/modeling/manifold/tsne.py +94 -124
- snowflake/ml/modeling/metrics/classification.py +187 -52
- snowflake/ml/modeling/metrics/correlation.py +4 -2
- snowflake/ml/modeling/metrics/covariance.py +7 -4
- snowflake/ml/modeling/metrics/ranking.py +32 -16
- snowflake/ml/modeling/metrics/regression.py +60 -32
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +94 -124
- snowflake/ml/modeling/mixture/gaussian_mixture.py +94 -124
- snowflake/ml/modeling/model_selection/grid_search_cv.py +88 -138
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +90 -144
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +86 -114
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +93 -121
- snowflake/ml/modeling/multiclass/output_code_classifier.py +94 -122
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +92 -120
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +96 -124
- snowflake/ml/modeling/naive_bayes/complement_nb.py +92 -120
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -107
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +88 -116
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +96 -124
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +96 -124
- snowflake/ml/modeling/neighbors/kernel_density.py +94 -124
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +94 -124
- snowflake/ml/modeling/neighbors/nearest_centroid.py +89 -117
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +94 -124
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +96 -124
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +96 -124
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +96 -124
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +94 -124
- snowflake/ml/modeling/neural_network/mlp_classifier.py +96 -124
- snowflake/ml/modeling/neural_network/mlp_regressor.py +96 -124
- snowflake/ml/modeling/parameters/disable_distributed_hpo.py +2 -6
- snowflake/ml/modeling/preprocessing/binarizer.py +14 -9
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +0 -4
- snowflake/ml/modeling/preprocessing/label_encoder.py +21 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +20 -14
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +35 -19
- snowflake/ml/modeling/preprocessing/normalizer.py +6 -9
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +20 -13
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +25 -13
- snowflake/ml/modeling/preprocessing/polynomial_features.py +94 -124
- snowflake/ml/modeling/preprocessing/robust_scaler.py +28 -14
- snowflake/ml/modeling/preprocessing/standard_scaler.py +25 -13
- snowflake/ml/modeling/semi_supervised/label_propagation.py +96 -124
- snowflake/ml/modeling/semi_supervised/label_spreading.py +96 -124
- snowflake/ml/modeling/svm/linear_svc.py +96 -124
- snowflake/ml/modeling/svm/linear_svr.py +96 -124
- snowflake/ml/modeling/svm/nu_svc.py +96 -124
- snowflake/ml/modeling/svm/nu_svr.py +96 -124
- snowflake/ml/modeling/svm/svc.py +96 -124
- snowflake/ml/modeling/svm/svr.py +96 -124
- snowflake/ml/modeling/tree/decision_tree_classifier.py +96 -124
- snowflake/ml/modeling/tree/decision_tree_regressor.py +96 -124
- snowflake/ml/modeling/tree/extra_tree_classifier.py +96 -124
- snowflake/ml/modeling/tree/extra_tree_regressor.py +96 -124
- snowflake/ml/modeling/xgboost/xgb_classifier.py +96 -125
- snowflake/ml/modeling/xgboost/xgb_regressor.py +96 -125
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +96 -125
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +96 -125
- snowflake/ml/registry/model_registry.py +2 -0
- snowflake/ml/registry/registry.py +215 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.1.1.dist-info → snowflake_ml_python-1.1.2.dist-info}/METADATA +21 -3
- snowflake_ml_python-1.1.2.dist-info/RECORD +347 -0
- snowflake_ml_python-1.1.1.dist-info/RECORD +0 -331
- {snowflake_ml_python-1.1.1.dist-info → snowflake_ml_python-1.1.2.dist-info}/WHEEL +0 -0
@@ -0,0 +1,146 @@
|
|
1
|
+
import inspect
|
2
|
+
from typing import List
|
3
|
+
|
4
|
+
import cloudpickle as cp
|
5
|
+
import numpy as np
|
6
|
+
|
7
|
+
from snowflake.ml._internal.exceptions import error_codes, exceptions
|
8
|
+
|
9
|
+
|
10
|
+
class ModelSpecifications:
    """
    Plain container for model based specifications needed by Sprocs/UDFs.

    Attributes are assigned directly from the constructor arguments:
        imports: names of the modules the stored procedure / UDF has to import.
        pkgDependencies: package requirement strings (e.g. ``"numpy==1.24.3"``).
    """

    def __init__(self, imports: List[str], pkgDependencies: List[str]) -> None:
        # Both lists are stored as given (no copy is made).
        self.imports, self.pkgDependencies = imports, pkgDependencies
|
18
|
+
|
19
|
+
|
20
|
+
class SKLearnModelSpecifications(ModelSpecifications):
    """
    Specifications for scikit-learn models.

    Pins numpy, scikit-learn and cloudpickle to the locally installed versions, and additionally
    pins xgboost and lightgbm whenever those packages can be imported locally.
    """

    def __init__(self) -> None:
        import sklearn

        # TODO(snandamuri): Replace cloudpickle with joblib after latest version of joblib is added to snowflake conda.
        dependencies = [
            f"numpy=={np.__version__}",
            f"scikit-learn=={sklearn.__version__}",
            f"cloudpickle=={cp.__version__}",
        ]

        # A change from previous implementation.
        # When reusing the Sprocs for all the fit() calls in the session, the static dependencies list
        # should include all the possible dependencies required during the lifetime.

        # XGBoost is optional: pin it only when it is installed.
        try:
            import xgboost
        except ModuleNotFoundError:
            pass
        else:
            dependencies.append(f"xgboost=={xgboost.__version__}")

        # lightgbm is optional as well: pin it only when it is installed.
        try:
            import lightgbm
        except ModuleNotFoundError:
            pass
        else:
            dependencies.append(f"lightgbm=={lightgbm.__version__}")

        super().__init__(imports=["sklearn"], pkgDependencies=dependencies)
|
53
|
+
|
54
|
+
|
55
|
+
class XGBoostModelSpecifications(ModelSpecifications):
    """
    Specifications for xgboost models.

    Imports ``xgboost`` and pins numpy, xgboost and cloudpickle to the locally installed versions.
    """

    def __init__(self) -> None:
        import xgboost

        super().__init__(
            imports=["xgboost"],
            pkgDependencies=[
                f"numpy=={np.__version__}",
                f"xgboost=={xgboost.__version__}",
                f"cloudpickle=={cp.__version__}",
            ],
        )
|
66
|
+
|
67
|
+
|
68
|
+
class LightGBMModelSpecifications(ModelSpecifications):
    """
    Specifications for lightgbm models.

    Imports ``lightgbm`` and pins numpy, lightgbm and cloudpickle to the locally installed versions.
    """

    def __init__(self) -> None:
        import lightgbm

        super().__init__(
            imports=["lightgbm"],
            pkgDependencies=[
                f"numpy=={np.__version__}",
                f"lightgbm=={lightgbm.__version__}",
                f"cloudpickle=={cp.__version__}",
            ],
        )
|
79
|
+
|
80
|
+
|
81
|
+
class SklearnModelSelectionModelSpecifications(ModelSpecifications):
    """
    Specifications for scikit-learn model selection estimators.

    sklearn and xgboost are always imported and pinned (together with numpy and cloudpickle);
    lightgbm is added to both the imports and the dependencies only when it is installed locally.
    """

    def __init__(self) -> None:
        import sklearn
        import xgboost

        module_names: List[str] = ["sklearn", "xgboost"]
        requirements: List[str] = [
            f"numpy=={np.__version__}",
            f"scikit-learn=={sklearn.__version__}",
            f"cloudpickle=={cp.__version__}",
            f"xgboost=={xgboost.__version__}",
        ]

        try:
            import lightgbm
        except ModuleNotFoundError:
            # lightgbm is optional; without it the sklearn/xgboost specification is used as is.
            pass
        else:
            module_names.append("lightgbm")
            requirements.append(f"lightgbm=={lightgbm.__version__}")

        super().__init__(imports=module_names, pkgDependencies=requirements)
|
104
|
+
|
105
|
+
|
106
|
+
class ModelSpecificationsBuilder:
    """
    A factory class to build ModelSpecifications object for different types of models.
    """

    @classmethod
    def build(cls, model: object) -> ModelSpecifications:
        """
        Build the ModelSpecifications object matching the root module of the native model object.

        Args:
            model: Native model object to be trained.

        Returns:
            Appropriate ModelSpecification object

        Raises:
            SnowflakeMLException: Raised when the module of the given model can't be determined.
            TypeError: Raised for unsupported modules.
        """
        module = inspect.getmodule(model)
        if module is None:
            raise exceptions.SnowflakeMLException(
                error_code=error_codes.INVALID_TYPE,
                original_exception=ValueError("Unable to infer model type of the given native model object."),
            )

        # Only the top-level package decides the flavor, e.g. "sklearn.linear_model._base" -> "sklearn".
        root_module_name = module.__name__.split(".")[0]
        if root_module_name == "sklearn":
            # Imported lazily so that sklearn is only needed when an sklearn model is actually passed in.
            from sklearn.model_selection import GridSearchCV, RandomizedSearchCV

            if isinstance(model, (GridSearchCV, RandomizedSearchCV)):
                return SklearnModelSelectionModelSpecifications()
            return SKLearnModelSpecifications()
        if root_module_name == "xgboost":
            return XGBoostModelSpecifications()
        if root_module_name == "lightgbm":
            return LightGBMModelSpecifications()

        # The two sentences were previously concatenated without a separating space.
        raise TypeError(
            f"Unexpected module type: {root_module_name}. Supported module types: sklearn, xgboost, lightgbm."
        )
|
@@ -0,0 +1,13 @@
|
|
1
|
+
from typing import Protocol
|
2
|
+
|
3
|
+
|
4
|
+
class ModelTrainer(Protocol):
    """
    Structural interface shared by the model trainer implementations.

    Training comes in several flavors: training with pandas datasets, training with Snowpark
    datasets using sprocs, out of core training with Snowpark datasets, etc.
    """

    def train(self) -> object:
        """Placeholder for the training entry point; this default always raises NotImplementedError."""
        raise NotImplementedError()
|
@@ -0,0 +1,78 @@
|
|
1
|
+
from typing import List, Optional, Union
|
2
|
+
|
3
|
+
import pandas as pd
|
4
|
+
from sklearn import model_selection
|
5
|
+
|
6
|
+
from snowflake.ml.modeling._internal.distributed_hpo_trainer import (
|
7
|
+
DistributedHPOTrainer,
|
8
|
+
)
|
9
|
+
from snowflake.ml.modeling._internal.estimator_utils import is_single_node
|
10
|
+
from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
|
11
|
+
from snowflake.ml.modeling._internal.pandas_trainer import PandasModelTrainer
|
12
|
+
from snowflake.ml.modeling._internal.snowpark_trainer import SnowparkModelTrainer
|
13
|
+
from snowflake.snowpark import DataFrame, Session
|
14
|
+
|
15
|
+
# Project label. NOTE(review): not referenced anywhere in the visible part of this module —
# presumably a telemetry tag; confirm it is still needed.
_PROJECT = "ModelDevelopment"
|
16
|
+
|
17
|
+
|
18
|
+
class ModelTrainerBuilder:
    """
    A builder class to create instances of ModelTrainer for different models and training conditions.

    This class provides methods to build instances of ModelTrainer tailored to specific machine learning
    models and training configurations like dataset's location etc. It abstracts the creation process,
    allowing the user to obtain a configured ModelTrainer for a particular model architecture or configuration.
    """

    # Class-level switch consulted by _check_if_distributed_hpo_enabled().
    _ENABLE_DISTRIBUTED = True

    @classmethod
    def _check_if_distributed_hpo_enabled(cls, session: Session) -> bool:
        """Return True when the session is not single-node and the class-level switch is on."""
        return not is_single_node(session) and ModelTrainerBuilder._ENABLE_DISTRIBUTED is True

    @classmethod
    def build(
        cls,
        estimator: object,
        dataset: Union[DataFrame, pd.DataFrame],
        input_cols: Optional[List[str]] = None,
        label_cols: Optional[List[str]] = None,
        sample_weight_col: Optional[str] = None,
        autogenerated: bool = False,
        subproject: str = "",
    ) -> ModelTrainer:
        """
        Builder method that creates an appropriate ModelTrainer instance based on the given params.

        Args:
            estimator: Estimator object handed to the selected trainer.
            dataset: Training data; a pandas DataFrame or a Snowpark DataFrame.
            input_cols: Input column names. Must not be None.
            label_cols: Label column names, if any.
            sample_weight_col: Name of the sample weight column, if any.
            autogenerated: Forwarded to the Snowpark based trainers only.
            subproject: Forwarded to the Snowpark based trainers only.

        Returns:
            A PandasModelTrainer for pandas datasets. For Snowpark datasets, a DistributedHPOTrainer
            when the estimator is a GridSearchCV/RandomizedSearchCV and distributed HPO is enabled
            for the dataset's session, otherwise a SnowparkModelTrainer.

        Raises:
            TypeError: If the dataset is neither a pandas nor a Snowpark DataFrame.
        """
        assert input_cols is not None  # Make MyPy happy
        if isinstance(dataset, pd.DataFrame):
            return PandasModelTrainer(
                estimator=estimator,
                dataset=dataset,
                input_cols=input_cols,
                label_cols=label_cols,
                sample_weight_col=sample_weight_col,
            )
        elif isinstance(dataset, DataFrame):
            trainer_klass = SnowparkModelTrainer
            assert dataset._session is not None  # Make MyPy happy
            is_hpo_estimator = isinstance(
                estimator, (model_selection.GridSearchCV, model_selection.RandomizedSearchCV)
            )
            # The session check is only evaluated for hyper-parameter search estimators.
            if is_hpo_estimator and ModelTrainerBuilder._check_if_distributed_hpo_enabled(session=dataset._session):
                trainer_klass = DistributedHPOTrainer
            return trainer_klass(
                estimator=estimator,
                dataset=dataset,
                session=dataset._session,
                input_cols=input_cols,
                label_cols=label_cols,
                sample_weight_col=sample_weight_col,
                autogenerated=autogenerated,
                subproject=subproject,
            )
        else:
            # The two sentences were previously concatenated without a separating space.
            raise TypeError(
                f"Unexpected dataset type: {type(dataset)}. "
                "Supported dataset types: snowpark.DataFrame, pandas.DataFrame."
            )
|
@@ -0,0 +1,54 @@
|
|
1
|
+
import inspect
|
2
|
+
from typing import List, Optional
|
3
|
+
|
4
|
+
import pandas as pd
|
5
|
+
|
6
|
+
|
7
|
+
class PandasModelTrainer:
    """
    A class for training machine learning models using Pandas datasets.
    """

    def __init__(
        self,
        estimator: object,
        dataset: pd.DataFrame,
        input_cols: List[str],
        label_cols: Optional[List[str]],
        sample_weight_col: Optional[str],
    ) -> None:
        """
        Initializes the PandasModelTrainer with a model, a Pandas DataFrame, feature, and label column names.

        Args:
            estimator: SKLearn compatible estimator or transformer object.
            dataset: The dataset used for training the model.
            input_cols: The name(s) of one or more columns in a DataFrame containing a feature to be used for training.
            label_cols: The name(s) of one or more columns in a DataFrame representing the target variable(s) to learn.
            sample_weight_col: The column name representing the weight of training examples.
        """
        self.estimator = estimator
        self.dataset = dataset
        self.input_cols = input_cols
        self.label_cols = label_cols
        self.sample_weight_col = sample_weight_col

    def train(self) -> object:
        """
        Trains the model using specified features and target columns from the dataset.

        The label argument is passed as ``Y`` when the estimator's ``fit`` declares a parameter with
        that name and as ``y`` otherwise. Sample weights are passed only when a sample weight column
        is configured and ``fit`` declares a ``sample_weight`` parameter.

        Returns:
            Trained model
        """
        assert hasattr(self.estimator, "fit")  # Keep mypy happy

        # inspect.signature() also reports keyword-only parameters (and follows functools.wraps
        # wrappers), which getfullargspec().args leaves out. Without this, estimators declaring
        # fit(X, y, *, sample_weight=None) silently lose their sample weights.
        fit_params = inspect.signature(self.estimator.fit).parameters
        args = {"X": self.dataset[self.input_cols]}

        if self.label_cols:
            label_arg_name = "Y" if "Y" in fit_params else "y"
            args[label_arg_name] = self.dataset[self.label_cols].squeeze()

        if self.sample_weight_col is not None and "sample_weight" in fit_params:
            args["sample_weight"] = self.dataset[self.sample_weight_col].squeeze()

        return self.estimator.fit(**args)
|