snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +72 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
- snowflake/ml/_internal/telemetry.py +1 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/sql_identifier.py +14 -1
- snowflake/ml/dataset/__init__.py +11 -0
- snowflake/ml/dataset/dataset.py +455 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +199 -0
- snowflake/ml/feature_store/__init__.py +6 -0
- snowflake/ml/feature_store/access_manager.py +279 -0
- snowflake/ml/feature_store/feature_store.py +544 -358
- snowflake/ml/feature_store/feature_view.py +55 -16
- snowflake/ml/fileset/embedded_stage_fs.py +149 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +160 -0
- snowflake/ml/fileset/stage_fs.py +25 -10
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +65 -31
- snowflake/ml/model/_client/model/model_version_impl.py +159 -2
- snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
- snowflake/ml/model/_client/ops/model_ops.py +268 -83
- snowflake/ml/model/_client/sql/_base.py +34 -0
- snowflake/ml/model/_client/sql/model.py +42 -47
- snowflake/ml/model/_client/sql/model_version.py +164 -39
- snowflake/ml/model/_client/sql/stage.py +6 -32
- snowflake/ml/model/_client/sql/tag.py +32 -56
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +64 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +538 -36
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/_manager/model_manager.py +36 -7
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -1,52 +1,25 @@
|
|
1
1
|
from typing import Any, Dict, List, Optional
|
2
2
|
|
3
|
-
from snowflake.ml._internal.utils import
|
4
|
-
|
5
|
-
|
6
|
-
sql_identifier,
|
7
|
-
)
|
8
|
-
from snowflake.snowpark import row, session
|
3
|
+
from snowflake.ml._internal.utils import query_result_checker, sql_identifier
|
4
|
+
from snowflake.ml.model._client.sql import _base
|
5
|
+
from snowflake.snowpark import row
|
9
6
|
|
10
7
|
|
11
|
-
class ModuleTagSQLClient:
|
12
|
-
def __init__(
|
13
|
-
self,
|
14
|
-
session: session.Session,
|
15
|
-
*,
|
16
|
-
database_name: sql_identifier.SqlIdentifier,
|
17
|
-
schema_name: sql_identifier.SqlIdentifier,
|
18
|
-
) -> None:
|
19
|
-
self._session = session
|
20
|
-
self._database_name = database_name
|
21
|
-
self._schema_name = schema_name
|
22
|
-
|
23
|
-
def __eq__(self, __value: object) -> bool:
|
24
|
-
if not isinstance(__value, ModuleTagSQLClient):
|
25
|
-
return False
|
26
|
-
return self._database_name == __value._database_name and self._schema_name == __value._schema_name
|
27
|
-
|
28
|
-
def fully_qualified_module_name(
|
29
|
-
self,
|
30
|
-
module_name: sql_identifier.SqlIdentifier,
|
31
|
-
) -> str:
|
32
|
-
return identifier.get_schema_level_object_identifier(
|
33
|
-
self._database_name.identifier(), self._schema_name.identifier(), module_name.identifier()
|
34
|
-
)
|
35
|
-
|
8
|
+
class ModuleTagSQLClient(_base._BaseSQLClient):
|
36
9
|
def set_tag_on_model(
|
37
10
|
self,
|
38
|
-
model_name: sql_identifier.SqlIdentifier,
|
39
11
|
*,
|
40
|
-
|
41
|
-
|
12
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
13
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
14
|
+
model_name: sql_identifier.SqlIdentifier,
|
15
|
+
tag_database_name: Optional[sql_identifier.SqlIdentifier],
|
16
|
+
tag_schema_name: Optional[sql_identifier.SqlIdentifier],
|
42
17
|
tag_name: sql_identifier.SqlIdentifier,
|
43
18
|
tag_value: str,
|
44
19
|
statement_params: Optional[Dict[str, Any]] = None,
|
45
20
|
) -> None:
|
46
|
-
fq_model_name = self.
|
47
|
-
fq_tag_name =
|
48
|
-
tag_database_name.identifier(), tag_schema_name.identifier(), tag_name.identifier()
|
49
|
-
)
|
21
|
+
fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
|
22
|
+
fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
|
50
23
|
query_result_checker.SqlResultValidator(
|
51
24
|
self._session,
|
52
25
|
f"ALTER MODEL {fq_model_name} SET TAG {fq_tag_name} = $${tag_value}$$",
|
@@ -55,17 +28,17 @@ class ModuleTagSQLClient:
|
|
55
28
|
|
56
29
|
def unset_tag_on_model(
|
57
30
|
self,
|
58
|
-
model_name: sql_identifier.SqlIdentifier,
|
59
31
|
*,
|
60
|
-
|
61
|
-
|
32
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
33
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
34
|
+
model_name: sql_identifier.SqlIdentifier,
|
35
|
+
tag_database_name: Optional[sql_identifier.SqlIdentifier],
|
36
|
+
tag_schema_name: Optional[sql_identifier.SqlIdentifier],
|
62
37
|
tag_name: sql_identifier.SqlIdentifier,
|
63
38
|
statement_params: Optional[Dict[str, Any]] = None,
|
64
39
|
) -> None:
|
65
|
-
fq_model_name = self.
|
66
|
-
fq_tag_name =
|
67
|
-
tag_database_name.identifier(), tag_schema_name.identifier(), tag_name.identifier()
|
68
|
-
)
|
40
|
+
fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
|
41
|
+
fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
|
69
42
|
query_result_checker.SqlResultValidator(
|
70
43
|
self._session,
|
71
44
|
f"ALTER MODEL {fq_model_name} UNSET TAG {fq_tag_name}",
|
@@ -74,21 +47,21 @@ class ModuleTagSQLClient:
|
|
74
47
|
|
75
48
|
def get_tag_value(
|
76
49
|
self,
|
77
|
-
module_name: sql_identifier.SqlIdentifier,
|
78
50
|
*,
|
79
|
-
|
80
|
-
|
51
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
52
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
53
|
+
model_name: sql_identifier.SqlIdentifier,
|
54
|
+
tag_database_name: Optional[sql_identifier.SqlIdentifier],
|
55
|
+
tag_schema_name: Optional[sql_identifier.SqlIdentifier],
|
81
56
|
tag_name: sql_identifier.SqlIdentifier,
|
82
57
|
statement_params: Optional[Dict[str, Any]] = None,
|
83
58
|
) -> row.Row:
|
84
|
-
|
85
|
-
fq_tag_name =
|
86
|
-
tag_database_name.identifier(), tag_schema_name.identifier(), tag_name.identifier()
|
87
|
-
)
|
59
|
+
fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
|
60
|
+
fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
|
88
61
|
return (
|
89
62
|
query_result_checker.SqlResultValidator(
|
90
63
|
self._session,
|
91
|
-
f"SELECT SYSTEM$GET_TAG($${fq_tag_name}$$, $${
|
64
|
+
f"SELECT SYSTEM$GET_TAG($${fq_tag_name}$$, $${fq_model_name}$$, 'MODULE') AS TAG_VALUE",
|
92
65
|
statement_params=statement_params,
|
93
66
|
)
|
94
67
|
.has_dimensions(expected_rows=1, expected_cols=1)
|
@@ -98,16 +71,19 @@ class ModuleTagSQLClient:
|
|
98
71
|
|
99
72
|
def get_tag_list(
|
100
73
|
self,
|
101
|
-
module_name: sql_identifier.SqlIdentifier,
|
102
74
|
*,
|
75
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
76
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
77
|
+
model_name: sql_identifier.SqlIdentifier,
|
103
78
|
statement_params: Optional[Dict[str, Any]] = None,
|
104
79
|
) -> List[row.Row]:
|
105
|
-
|
80
|
+
fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
|
81
|
+
actual_database_name = database_name or self._database_name
|
106
82
|
return (
|
107
83
|
query_result_checker.SqlResultValidator(
|
108
84
|
self._session,
|
109
85
|
f"""SELECT TAG_DATABASE, TAG_SCHEMA, TAG_NAME, TAG_VALUE
|
110
|
-
FROM TABLE({
|
86
|
+
FROM TABLE({actual_database_name.identifier()}.INFORMATION_SCHEMA.TAG_REFERENCES($${fq_model_name}$$, 'MODULE'))""",
|
111
87
|
statement_params=statement_params,
|
112
88
|
)
|
113
89
|
.has_column("TAG_DATABASE", allow_empty=True)
|
@@ -37,6 +37,7 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
|
|
37
37
|
session: snowpark.Session,
|
38
38
|
artifact_stage_location: str,
|
39
39
|
compute_pool: str,
|
40
|
+
job_name: str,
|
40
41
|
external_access_integrations: List[str],
|
41
42
|
) -> None:
|
42
43
|
"""Initialization
|
@@ -49,6 +50,7 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
|
|
49
50
|
artifact_stage_location: Spec file and future deployment related artifacts will be stored under
|
50
51
|
{stage}/models/{model_id}
|
51
52
|
compute_pool: The compute pool used to run docker image build workload.
|
53
|
+
job_name: job_name to use.
|
52
54
|
external_access_integrations: EAIs for network connection.
|
53
55
|
"""
|
54
56
|
self.context_dir = context_dir
|
@@ -58,6 +60,7 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
|
|
58
60
|
self.artifact_stage_location = artifact_stage_location
|
59
61
|
self.compute_pool = compute_pool
|
60
62
|
self.external_access_integrations = external_access_integrations
|
63
|
+
self.job_name = job_name
|
61
64
|
self.client = snowservice_client.SnowServiceClient(session)
|
62
65
|
|
63
66
|
assert artifact_stage_location.startswith(
|
@@ -203,8 +206,9 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
|
|
203
206
|
)
|
204
207
|
|
205
208
|
def _launch_kaniko_job(self, spec_stage_location: str) -> None:
|
206
|
-
logger.debug("Submitting job for building docker image with kaniko")
|
209
|
+
logger.debug(f"Submitting job {self.job_name} for building docker image with kaniko")
|
207
210
|
self.client.create_job(
|
211
|
+
job_name=self.job_name,
|
208
212
|
compute_pool=self.compute_pool,
|
209
213
|
spec_stage_location=spec_stage_location,
|
210
214
|
external_access_integrations=self.external_access_integrations,
|
@@ -30,6 +30,7 @@ USER mambauser
|
|
30
30
|
|
31
31
|
# Set MAMBA_DOCKERFILE_ACTIVATE=1 to activate the conda environment during build time.
|
32
32
|
ARG MAMBA_DOCKERFILE_ACTIVATE=1
|
33
|
+
ARG MAMBA_NO_LOW_SPEED_LIMIT=1
|
33
34
|
|
34
35
|
# Bitsandbytes uses this ENVVAR to determine CUDA library location
|
35
36
|
ENV CONDA_PREFIX=/opt/conda
|
@@ -346,6 +346,7 @@ class SnowServiceDeployment:
|
|
346
346
|
(db, schema, _, _) = identifier.parse_schema_level_object_identifier(service_func_name)
|
347
347
|
|
348
348
|
self._service_name = identifier.get_schema_level_object_identifier(db, schema, f"service_{model_id}")
|
349
|
+
self._job_name = identifier.get_schema_level_object_identifier(db, schema, f"build_{model_id}")
|
349
350
|
# Spec file and future deployment related artifacts will be stored under {stage}/models/{model_id}
|
350
351
|
self._model_artifact_stage_location = posixpath.join(deployment_stage_path, "models", self.id)
|
351
352
|
self.debug_dir: Optional[str] = None
|
@@ -468,6 +469,7 @@ class SnowServiceDeployment:
|
|
468
469
|
session=self.session,
|
469
470
|
artifact_stage_location=self._model_artifact_stage_location,
|
470
471
|
compute_pool=self.options.compute_pool,
|
472
|
+
job_name=self._job_name,
|
471
473
|
external_access_integrations=self.options.external_access_integrations,
|
472
474
|
)
|
473
475
|
else:
|
@@ -17,11 +17,6 @@ class ResourceStatus(Enum):
|
|
17
17
|
INTERNAL_ERROR = "INTERNAL_ERROR" # there was an internal service error.
|
18
18
|
|
19
19
|
|
20
|
-
RESOURCE_TO_STATUS_FUNCTION_MAPPING = {
|
21
|
-
ResourceType.SERVICE: "SYSTEM$GET_SERVICE_STATUS",
|
22
|
-
ResourceType.JOB: "SYSTEM$GET_JOB_STATUS",
|
23
|
-
}
|
24
|
-
|
25
20
|
PREDICT = "predict"
|
26
21
|
STAGE = "stage"
|
27
22
|
COMPUTE_POOL = "compute_pool"
|
@@ -70,13 +70,16 @@ class SnowServiceClient:
|
|
70
70
|
logger.debug(f"Create service with SQL: \n {sql}")
|
71
71
|
self.session.sql(sql).collect()
|
72
72
|
|
73
|
-
def create_job(
|
73
|
+
def create_job(
|
74
|
+
self, job_name: str, compute_pool: str, spec_stage_location: str, external_access_integrations: List[str]
|
75
|
+
) -> None:
|
74
76
|
"""Execute the job creation SQL command. Note that the job creation is synchronous, hence we execute it in a
|
75
77
|
async way so that we can query the log in the meantime.
|
76
78
|
|
77
79
|
Upon job failure, full job container log will be logged.
|
78
80
|
|
79
81
|
Args:
|
82
|
+
job_name: name of the job
|
80
83
|
compute_pool: name of the compute pool
|
81
84
|
spec_stage_location: path to the stage location where the spec is located at.
|
82
85
|
external_access_integrations: EAIs for network connection.
|
@@ -84,19 +87,18 @@ class SnowServiceClient:
|
|
84
87
|
stage, path = uri.get_stage_and_path(spec_stage_location)
|
85
88
|
sql = textwrap.dedent(
|
86
89
|
f"""
|
87
|
-
EXECUTE SERVICE
|
90
|
+
EXECUTE JOB SERVICE
|
88
91
|
IN COMPUTE POOL {compute_pool}
|
89
92
|
FROM {stage}
|
90
|
-
|
93
|
+
SPECIFICATION_FILE = '{path}'
|
94
|
+
NAME = {job_name}
|
91
95
|
EXTERNAL_ACCESS_INTEGRATIONS = ({', '.join(external_access_integrations)})
|
92
96
|
"""
|
93
97
|
)
|
94
98
|
logger.debug(f"Create job with SQL: \n {sql}")
|
95
|
-
|
96
|
-
cur.execute_async(sql)
|
97
|
-
job_id = cur._sfqid
|
99
|
+
self.session.sql(sql).collect_nowait()
|
98
100
|
self.block_until_resource_is_ready(
|
99
|
-
resource_name=
|
101
|
+
resource_name=job_name,
|
100
102
|
resource_type=constants.ResourceType.JOB,
|
101
103
|
container_name=constants.KANIKO_CONTAINER_NAME,
|
102
104
|
max_retries=240,
|
@@ -182,10 +184,7 @@ class SnowServiceClient:
|
|
182
184
|
"""
|
183
185
|
assert resource_type == constants.ResourceType.SERVICE or resource_type == constants.ResourceType.JOB
|
184
186
|
query_command = ""
|
185
|
-
|
186
|
-
query_command = f"CALL SYSTEM$GET_SERVICE_LOGS('{resource_name}', '0', '{container_name}')"
|
187
|
-
elif resource_type == constants.ResourceType.JOB:
|
188
|
-
query_command = f"CALL SYSTEM$GET_JOB_LOGS('{resource_name}', '{container_name}')"
|
187
|
+
query_command = f"CALL SYSTEM$GET_SERVICE_LOGS('{resource_name}', '0', '{container_name}')"
|
189
188
|
logger.warning(
|
190
189
|
f"Best-effort log streaming from SPCS will be enabled when python logging level is set to INFO."
|
191
190
|
f"Alternatively, you can also query the logs by running the query '{query_command}'"
|
@@ -201,7 +200,7 @@ class SnowServiceClient:
|
|
201
200
|
)
|
202
201
|
lsp.process_new_logs(resource_log, log_level=logging.INFO)
|
203
202
|
|
204
|
-
status = self.get_resource_status(resource_name=resource_name
|
203
|
+
status = self.get_resource_status(resource_name=resource_name)
|
205
204
|
|
206
205
|
if resource_type == constants.ResourceType.JOB and status == constants.ResourceStatus.DONE:
|
207
206
|
return
|
@@ -246,52 +245,24 @@ class SnowServiceClient:
|
|
246
245
|
def get_resource_log(
|
247
246
|
self, resource_name: str, resource_type: constants.ResourceType, container_name: str
|
248
247
|
) -> Optional[str]:
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
try:
|
259
|
-
row = self.session.sql(f"CALL SYSTEM$GET_JOB_LOGS('{resource_name}', '{container_name}')").collect()
|
260
|
-
return str(row[0]["SYSTEM$GET_JOB_LOGS"])
|
261
|
-
except Exception:
|
262
|
-
return None
|
263
|
-
else:
|
264
|
-
raise snowml_exceptions.SnowflakeMLException(
|
265
|
-
error_code=error_codes.NOT_IMPLEMENTED,
|
266
|
-
original_exception=NotImplementedError(
|
267
|
-
f"{resource_type.name} is not yet supported in get_resource_log function"
|
268
|
-
),
|
269
|
-
)
|
270
|
-
|
271
|
-
def get_resource_status(
|
272
|
-
self, resource_name: str, resource_type: constants.ResourceType
|
273
|
-
) -> Optional[constants.ResourceStatus]:
|
248
|
+
try:
|
249
|
+
row = self.session.sql(
|
250
|
+
f"CALL SYSTEM$GET_SERVICE_LOGS('{resource_name}', '0', '{container_name}')"
|
251
|
+
).collect()
|
252
|
+
return str(row[0]["SYSTEM$GET_SERVICE_LOGS"])
|
253
|
+
except Exception:
|
254
|
+
return None
|
255
|
+
|
256
|
+
def get_resource_status(self, resource_name: str) -> Optional[constants.ResourceStatus]:
|
274
257
|
"""Get resource status.
|
275
258
|
|
276
259
|
Args:
|
277
260
|
resource_name: Name of the resource.
|
278
|
-
resource_type: Type of the resource.
|
279
|
-
|
280
|
-
Raises:
|
281
|
-
SnowflakeMLException: If resource type does not have a corresponding system function for querying status.
|
282
|
-
SnowflakeMLException: If corresponding status call failed.
|
283
261
|
|
284
262
|
Returns:
|
285
263
|
Optional[constants.ResourceStatus]: The status of the resource, or None if the resource status is empty.
|
286
264
|
"""
|
287
|
-
|
288
|
-
raise snowml_exceptions.SnowflakeMLException(
|
289
|
-
error_code=error_codes.INVALID_ARGUMENT,
|
290
|
-
original_exception=ValueError(
|
291
|
-
f"Status querying is not supported for resources of type '{resource_type}'."
|
292
|
-
),
|
293
|
-
)
|
294
|
-
status_func = constants.RESOURCE_TO_STATUS_FUNCTION_MAPPING[resource_type]
|
265
|
+
status_func = "SYSTEM$GET_SERVICE_STATUS"
|
295
266
|
try:
|
296
267
|
row = self.session.sql(f"CALL {status_func}('{resource_name}');").collect()
|
297
268
|
except Exception:
|
@@ -8,8 +8,10 @@ from typing import Any, Dict, List, Optional
|
|
8
8
|
|
9
9
|
from absl import logging
|
10
10
|
from packaging import requirements
|
11
|
+
from typing_extensions import deprecated
|
11
12
|
|
12
13
|
from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
|
14
|
+
from snowflake.ml._internal.lineage import data_source, lineage_utils
|
13
15
|
from snowflake.ml.model import model_signature, type_hints as model_types
|
14
16
|
from snowflake.ml.model._model_composer.model_manifest import model_manifest
|
15
17
|
from snowflake.ml.model._packager import model_packager
|
@@ -134,6 +136,7 @@ class ModelComposer:
|
|
134
136
|
model_meta=self.packager.meta,
|
135
137
|
model_file_rel_path=pathlib.PurePosixPath(self.model_file_rel_path),
|
136
138
|
options=options,
|
139
|
+
data_sources=self._get_data_sources(model),
|
137
140
|
)
|
138
141
|
|
139
142
|
file_utils.upload_directory_to_stage(
|
@@ -143,7 +146,8 @@ class ModelComposer:
|
|
143
146
|
statement_params=self._statement_params,
|
144
147
|
)
|
145
148
|
|
146
|
-
|
149
|
+
@deprecated("Only used by PrPr model registry. Use static method version of load instead.")
|
150
|
+
def legacy_load(
|
147
151
|
self,
|
148
152
|
*,
|
149
153
|
meta_only: bool = False,
|
@@ -163,3 +167,20 @@ class ModelComposer:
|
|
163
167
|
with zipfile.ZipFile(self.model_local_path, mode="r", compression=zipfile.ZIP_DEFLATED) as zf:
|
164
168
|
zf.extractall(path=self._packager_workspace_path)
|
165
169
|
self.packager.load(meta_only=meta_only, options=options)
|
170
|
+
|
171
|
+
@staticmethod
|
172
|
+
def load(
|
173
|
+
workspace_path: pathlib.Path,
|
174
|
+
*,
|
175
|
+
meta_only: bool = False,
|
176
|
+
options: Optional[model_types.ModelLoadOption] = None,
|
177
|
+
) -> model_packager.ModelPackager:
|
178
|
+
mp = model_packager.ModelPackager(str(workspace_path / ModelComposer.MODEL_DIR_REL_PATH))
|
179
|
+
mp.load(meta_only=meta_only, options=options)
|
180
|
+
return mp
|
181
|
+
|
182
|
+
def _get_data_sources(self, model: model_types.SupportedModelType) -> Optional[List[data_source.DataSource]]:
|
183
|
+
data_sources = getattr(model, lineage_utils.DATA_SOURCES_ATTR, None)
|
184
|
+
if isinstance(data_sources, list) and all(isinstance(item, data_source.DataSource) for item in data_sources):
|
185
|
+
return data_sources
|
186
|
+
return None
|
@@ -5,6 +5,7 @@ from typing import List, Optional, cast
|
|
5
5
|
|
6
6
|
import yaml
|
7
7
|
|
8
|
+
from snowflake.ml._internal.lineage import data_source
|
8
9
|
from snowflake.ml.model import type_hints
|
9
10
|
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
10
11
|
from snowflake.ml.model._model_composer.model_method import (
|
@@ -36,6 +37,7 @@ class ModelManifest:
|
|
36
37
|
model_meta: model_meta_api.ModelMetadata,
|
37
38
|
model_file_rel_path: pathlib.PurePosixPath,
|
38
39
|
options: Optional[type_hints.ModelSaveOption] = None,
|
40
|
+
data_sources: Optional[List[data_source.DataSource]] = None,
|
39
41
|
) -> None:
|
40
42
|
if options is None:
|
41
43
|
options = {}
|
@@ -90,6 +92,10 @@ class ModelManifest:
|
|
90
92
|
],
|
91
93
|
)
|
92
94
|
|
95
|
+
lineage_sources = self._extract_lineage_info(data_sources)
|
96
|
+
if lineage_sources:
|
97
|
+
manifest_dict["lineage_sources"] = lineage_sources
|
98
|
+
|
93
99
|
with (self.workspace_path / ModelManifest.MANIFEST_FILE_REL_PATH).open("w", encoding="utf-8") as f:
|
94
100
|
# Anchors are not supported in the server, avoid that.
|
95
101
|
yaml.SafeDumper.ignore_aliases = lambda *args: True # type: ignore[method-assign]
|
@@ -108,3 +114,19 @@ class ModelManifest:
|
|
108
114
|
res = cast(model_manifest_schema.ModelManifestDict, raw_input)
|
109
115
|
|
110
116
|
return res
|
117
|
+
|
118
|
+
def _extract_lineage_info(
    self, data_sources: Optional[List[data_source.DataSource]]
) -> List[model_manifest_schema.LineageSourceDict]:
    """Convert data sources into lineage entries for the manifest.

    Args:
        data_sources: Data sources to record; ``None`` or an empty list
            yields an empty result.

    Returns:
        One ``LineageSourceDict`` per data source, in input order, carrying the
        source's ``fully_qualified_name`` as ``entity`` and its ``version``.
    """
    # Currently, we only support lineage from Dataset.
    dataset_type = model_manifest_schema.LineageSourceTypes.DATASET.value
    return [
        model_manifest_schema.LineageSourceDict(
            type=dataset_type,
            entity=source.fully_qualified_name,
            version=source.version,
        )
        for source in (data_sources or [])
    ]
|
@@ -75,8 +75,19 @@ class SnowparkMLDataDict(TypedDict):
|
|
75
75
|
functions: Required[List[ModelFunctionInfoDict]]
|
76
76
|
|
77
77
|
|
78
|
+
class LineageSourceTypes(enum.Enum):
    """Values allowed for the ``type`` field of a lineage source entry."""

    # The only kind of lineage source defined here.
    DATASET = "DATASET"
|
80
|
+
|
81
|
+
|
82
|
+
class LineageSourceDict(TypedDict):
|
83
|
+
type: Required[str]
|
84
|
+
entity: Required[str]
|
85
|
+
version: NotRequired[str]
|
86
|
+
|
87
|
+
|
78
88
|
class ModelManifestDict(TypedDict):
    """Top-level schema of the model manifest."""

    manifest_version: Required[str]
    # Runtime definitions keyed by name.
    runtimes: Required[Dict[str, ModelRuntimeDict]]
    methods: Required[List[ModelMethodDict]]
    # Free-form extra data; may be omitted.
    user_data: NotRequired[Dict[str, Any]]
    # Lineage source entries; may be omitted.
    lineage_sources: NotRequired[List[LineageSourceDict]]
|
@@ -284,6 +284,7 @@ class ModelEnv:
|
|
284
284
|
" This may prevent model deploying to Snowflake Warehouse."
|
285
285
|
),
|
286
286
|
category=UserWarning,
|
287
|
+
stacklevel=2,
|
287
288
|
)
|
288
289
|
if len(channel_dependencies) == 0 and channel not in self._conda_dependencies:
|
289
290
|
warnings.warn(
|
@@ -292,6 +293,7 @@ class ModelEnv:
|
|
292
293
|
" This may prevent model deploying to Snowflake Warehouse."
|
293
294
|
),
|
294
295
|
category=UserWarning,
|
296
|
+
stacklevel=2,
|
295
297
|
)
|
296
298
|
self._conda_dependencies[channel] = []
|
297
299
|
|
@@ -307,6 +309,7 @@ class ModelEnv:
|
|
307
309
|
" This may be unintentional."
|
308
310
|
),
|
309
311
|
category=UserWarning,
|
312
|
+
stacklevel=2,
|
310
313
|
)
|
311
314
|
|
312
315
|
if pip_requirements_list:
|
@@ -316,6 +319,7 @@ class ModelEnv:
|
|
316
319
|
" This may prevent model deploying to Snowflake Warehouse."
|
317
320
|
),
|
318
321
|
category=UserWarning,
|
322
|
+
stacklevel=2,
|
319
323
|
)
|
320
324
|
for pip_dependency in pip_requirements_list:
|
321
325
|
if any(
|
@@ -338,6 +342,7 @@ class ModelEnv:
|
|
338
342
|
" This may prevent model deploying to Snowflake Warehouse."
|
339
343
|
),
|
340
344
|
category=UserWarning,
|
345
|
+
stacklevel=2,
|
341
346
|
)
|
342
347
|
for pip_dependency in pip_requirements_list:
|
343
348
|
if any(
|
@@ -372,3 +377,39 @@ class ModelEnv:
|
|
372
377
|
"cuda_version": self.cuda_version,
|
373
378
|
"snowpark_ml_version": self.snowpark_ml_version,
|
374
379
|
}
|
380
|
+
|
381
|
+
def validate_with_local_env(
    self, check_snowpark_ml_version: bool = False
) -> List[env_utils.IncorrectLocalEnvironmentError]:
    """Compare this environment with the local one and collect every mismatch.

    Each check that raises ``env_utils.IncorrectLocalEnvironmentError`` is recorded
    rather than propagated, so all mismatches are reported together.

    Args:
        check_snowpark_ml_version: When true, also compare the base version of the
            recorded Snowpark ML version with ``snowml_env.VERSION``.

    Returns:
        The collected errors, in check order; empty when nothing mismatched.
    """
    mismatches = []

    # Python runtime version.
    try:
        env_utils.validate_py_runtime_version(str(self._python_version))
    except env_utils.IncorrectLocalEnvironmentError as err:
        mismatches.append(err)

    # Conda dependencies from every channel, each checked as its pip equivalent.
    conda_requirements = (
        requirement
        for channel_requirements in self._conda_dependencies.values()
        for requirement in channel_requirements
    )
    for conda_requirement in conda_requirements:
        try:
            env_utils.validate_local_installed_version_of_pip_package(
                env_utils.try_convert_conda_requirement_to_pip(conda_requirement)
            )
        except env_utils.IncorrectLocalEnvironmentError as err:
            mismatches.append(err)

    # Pip requirements.
    for pip_requirement in self._pip_requirements:
        try:
            env_utils.validate_local_installed_version_of_pip_package(pip_requirement)
        except env_utils.IncorrectLocalEnvironmentError as err:
            mismatches.append(err)

    # For Modeling model
    if check_snowpark_ml_version and self._snowpark_ml_version.base_version != snowml_env.VERSION:
        mismatches.append(
            env_utils.IncorrectLocalEnvironmentError(
                f"The local installed version of Snowpark ML library is {snowml_env.VERSION} "
                f"which differs from required version {self.snowpark_ml_version}."
            )
        )

    return mismatches
|
@@ -1,4 +1,5 @@
|
|
1
1
|
import os
|
2
|
+
import pathlib
|
2
3
|
import tempfile
|
3
4
|
from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
|
4
5
|
|
@@ -45,7 +46,7 @@ def _parse_mlflow_env(model_uri: str, env: model_env.ModelEnv) -> model_env.Mode
|
|
45
46
|
if not os.path.exists(conda_env_file_path):
|
46
47
|
raise ValueError("Cannot load MLFlow model dependencies.")
|
47
48
|
|
48
|
-
env.load_from_conda_file(conda_env_file_path)
|
49
|
+
env.load_from_conda_file(pathlib.Path(conda_env_file_path))
|
49
50
|
|
50
51
|
return env
|
51
52
|
|
@@ -320,11 +320,7 @@ class ModelMetadata:
|
|
320
320
|
|
321
321
|
with open(model_yaml_path, "w", encoding="utf-8") as out:
|
322
322
|
yaml.SafeDumper.ignore_aliases = lambda *args: True # type: ignore[method-assign]
|
323
|
-
yaml.safe_dump(
|
324
|
-
model_dict,
|
325
|
-
stream=out,
|
326
|
-
default_flow_style=False,
|
327
|
-
)
|
323
|
+
yaml.safe_dump(model_dict, stream=out, default_flow_style=False)
|
328
324
|
|
329
325
|
@staticmethod
|
330
326
|
def _validate_model_metadata(loaded_meta: Any) -> model_meta_schema.ModelMetadataDict:
|
@@ -4,7 +4,6 @@ from typing import Dict, List, Optional
|
|
4
4
|
|
5
5
|
from absl import logging
|
6
6
|
|
7
|
-
from snowflake.ml._internal import env_utils
|
8
7
|
from snowflake.ml._internal.exceptions import (
|
9
8
|
error_codes,
|
10
9
|
exceptions as snowml_exceptions,
|
@@ -129,8 +128,6 @@ class ModelPackager:
|
|
129
128
|
|
130
129
|
model_meta.load_code_path(self.local_dir_path)
|
131
130
|
|
132
|
-
env_utils.validate_py_runtime_version(self.meta.env.python_version)
|
133
|
-
|
134
131
|
handler = model_handler.load_handler(self.meta.model_type)
|
135
132
|
if handler is None:
|
136
133
|
raise snowml_exceptions.SnowflakeMLException(
|
@@ -3,6 +3,8 @@ from typing import List, Optional, Tuple
|
|
3
3
|
|
4
4
|
import pandas as pd
|
5
5
|
|
6
|
+
from snowflake.ml.modeling._internal.estimator_utils import handle_inference_result
|
7
|
+
|
6
8
|
|
7
9
|
class PandasModelTrainer:
|
8
10
|
"""
|
@@ -72,11 +74,61 @@ class PandasModelTrainer:
|
|
72
74
|
Tuple[pd.DataFrame, object]: [predicted dataset, estimator]
|
73
75
|
"""
|
74
76
|
assert hasattr(self.estimator, "fit_predict") # make type checker happy
|
75
|
-
|
76
|
-
result = self.estimator.fit_predict(**args)
|
77
|
+
result = self.estimator.fit_predict(X=self.dataset[self.input_cols])
|
77
78
|
result_df = pd.DataFrame(data=result, columns=expected_output_cols_list)
|
78
79
|
if drop_input_cols:
|
79
80
|
result_df = result_df
|
80
81
|
else:
|
81
|
-
|
82
|
+
# in case the output column name overlap with the input column names,
|
83
|
+
# remove the ones in input column names
|
84
|
+
remove_dataset_col_name_exist_in_output_col = list(
|
85
|
+
set(self.dataset.columns) - set(expected_output_cols_list)
|
86
|
+
)
|
87
|
+
result_df = pd.concat([self.dataset[remove_dataset_col_name_exist_in_output_col], result_df], axis=1)
|
88
|
+
return (result_df, self.estimator)
|
89
|
+
|
90
|
+
def train_fit_transform(
    self,
    expected_output_cols_list: List[str],
    drop_input_cols: Optional[bool] = False,
) -> Tuple[pd.DataFrame, object]:
    """Fit the estimator and return its transform output in a single call.

    Unlike a plain fit, this calls the estimator's ``fit_transform`` and returns the
    transformed data together with the estimator.

    Args:
        expected_output_cols_list (List[str]): Requested names for the output columns.
            The names actually used are the ones returned by ``handle_inference_result``.
        drop_input_cols (Optional[bool]): When truthy, only the transformed columns are
            returned. Otherwise the dataset's columns are placed before them, except
            dataset columns whose name matches an output column.

    Returns:
        Tuple[pd.DataFrame, object]: [transformed dataset, estimator]
    """
    assert hasattr(self.estimator, "fit")  # make type checker happy
    assert hasattr(self.estimator, "fit_transform")  # make type checker happy

    # fit's signature decides the label keyword ("Y" vs "y") and whether
    # sample_weight may be passed.
    argspec = inspect.getfullargspec(self.estimator.fit)
    args = {"X": self.dataset[self.input_cols]}
    if self.label_cols:
        label_arg_name = "Y" if "Y" in argspec.args else "y"
        args[label_arg_name] = self.dataset[self.label_cols].squeeze()

    if self.sample_weight_col is not None and "sample_weight" in argspec.args:
        args["sample_weight"] = self.dataset[self.sample_weight_col].squeeze()

    inference_res = self.estimator.fit_transform(**args)

    transformed_numpy_array, output_cols = handle_inference_result(
        inference_res=inference_res, output_cols=expected_output_cols_list, inference_method="fit_transform"
    )

    result_df = pd.DataFrame(data=transformed_numpy_array, columns=output_cols)
    if not drop_input_cols:
        # In case an output column name overlaps with a dataset column name, the
        # dataset column is removed. Walk the de-duplicated columns in their original
        # order instead of taking a set difference, so the resulting column order is
        # the same on every run.
        output_col_names = set(output_cols)
        kept_dataset_cols = [col for col in dict.fromkeys(self.dataset.columns) if col not in output_col_names]
        result_df = pd.concat([self.dataset[kept_dataset_cols], result_df], axis=1)
    return (result_df, self.estimator)
|