snowflake-ml-python 1.5.2__py3-none-any.whl → 1.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +2 -1
- snowflake/cortex/_complete.py +240 -16
- snowflake/cortex/_extract_answer.py +0 -1
- snowflake/cortex/_sentiment.py +0 -1
- snowflake/cortex/_sse_client.py +81 -0
- snowflake/cortex/_summarize.py +0 -1
- snowflake/cortex/_translate.py +0 -1
- snowflake/cortex/_util.py +34 -10
- snowflake/ml/_internal/container_services/image_registry/http_client.py +10 -3
- snowflake/ml/_internal/container_services/image_registry/imagelib.py +23 -10
- snowflake/ml/_internal/container_services/image_registry/registry_client.py +7 -1
- snowflake/ml/_internal/exceptions/dataset_errors.py +7 -7
- snowflake/ml/_internal/exceptions/fileset_errors.py +3 -3
- snowflake/ml/_internal/exceptions/sql_error_codes.py +6 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +34 -25
- snowflake/ml/_internal/telemetry.py +26 -0
- snowflake/ml/_internal/utils/identifier.py +14 -0
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +15 -4
- snowflake/ml/dataset/dataset.py +54 -32
- snowflake/ml/dataset/dataset_factory.py +3 -4
- snowflake/ml/feature_store/feature_store.py +440 -243
- snowflake/ml/feature_store/feature_view.py +61 -9
- snowflake/ml/fileset/embedded_stage_fs.py +25 -21
- snowflake/ml/fileset/fileset.py +2 -2
- snowflake/ml/fileset/snowfs.py +4 -15
- snowflake/ml/fileset/stage_fs.py +6 -8
- snowflake/ml/lineage/__init__.py +3 -0
- snowflake/ml/lineage/lineage_node.py +139 -0
- snowflake/ml/model/_client/model/model_impl.py +47 -14
- snowflake/ml/model/_client/model/model_version_impl.py +82 -2
- snowflake/ml/model/_client/ops/model_ops.py +77 -5
- snowflake/ml/model/_client/sql/model.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +47 -4
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
- snowflake/ml/model/_model_composer/model_composer.py +7 -6
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +7 -1
- snowflake/ml/model/_model_composer/model_method/function_generator.py +17 -1
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +79 -0
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +5 -3
- snowflake/ml/model/_model_composer/model_method/model_method.py +5 -5
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +1 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +2 -2
- snowflake/ml/model/_packager/model_handlers/custom.py +12 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +18 -15
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +2 -2
- snowflake/ml/model/_packager/model_handlers/llm.py +2 -2
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -2
- snowflake/ml/model/_packager/model_handlers/pytorch.py +2 -2
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +2 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -2
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +2 -2
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +2 -2
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +21 -1
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
- snowflake/ml/model/_packager/model_packager.py +9 -4
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
- snowflake/ml/model/_signatures/builtins_handler.py +2 -1
- snowflake/ml/model/_signatures/core.py +13 -1
- snowflake/ml/model/_signatures/pandas_handler.py +2 -0
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
- snowflake/ml/model/custom_model.py +22 -2
- snowflake/ml/model/model_signature.py +2 -0
- snowflake/ml/model/type_hints.py +74 -4
- snowflake/ml/modeling/_internal/estimator_utils.py +58 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +158 -121
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +2 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +39 -18
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +88 -134
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +22 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +5 -3
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +5 -3
- snowflake/ml/modeling/cluster/birch.py +5 -3
- snowflake/ml/modeling/cluster/bisecting_k_means.py +5 -3
- snowflake/ml/modeling/cluster/dbscan.py +5 -3
- snowflake/ml/modeling/cluster/feature_agglomeration.py +5 -3
- snowflake/ml/modeling/cluster/k_means.py +5 -3
- snowflake/ml/modeling/cluster/mean_shift.py +5 -3
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +5 -3
- snowflake/ml/modeling/cluster/optics.py +5 -3
- snowflake/ml/modeling/cluster/spectral_biclustering.py +5 -3
- snowflake/ml/modeling/cluster/spectral_clustering.py +5 -3
- snowflake/ml/modeling/cluster/spectral_coclustering.py +5 -3
- snowflake/ml/modeling/compose/column_transformer.py +5 -3
- snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +5 -3
- snowflake/ml/modeling/covariance/empirical_covariance.py +5 -3
- snowflake/ml/modeling/covariance/graphical_lasso.py +5 -3
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +5 -3
- snowflake/ml/modeling/covariance/ledoit_wolf.py +5 -3
- snowflake/ml/modeling/covariance/min_cov_det.py +5 -3
- snowflake/ml/modeling/covariance/oas.py +5 -3
- snowflake/ml/modeling/covariance/shrunk_covariance.py +5 -3
- snowflake/ml/modeling/decomposition/dictionary_learning.py +5 -3
- snowflake/ml/modeling/decomposition/factor_analysis.py +5 -3
- snowflake/ml/modeling/decomposition/fast_ica.py +5 -3
- snowflake/ml/modeling/decomposition/incremental_pca.py +5 -3
- snowflake/ml/modeling/decomposition/kernel_pca.py +5 -3
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +5 -3
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +5 -3
- snowflake/ml/modeling/decomposition/pca.py +5 -3
- snowflake/ml/modeling/decomposition/sparse_pca.py +5 -3
- snowflake/ml/modeling/decomposition/truncated_svd.py +5 -3
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +5 -3
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +5 -3
- snowflake/ml/modeling/feature_selection/variance_threshold.py +5 -3
- snowflake/ml/modeling/framework/base.py +3 -8
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +5 -3
- snowflake/ml/modeling/impute/knn_imputer.py +5 -3
- snowflake/ml/modeling/impute/missing_indicator.py +5 -3
- snowflake/ml/modeling/impute/simple_imputer.py +8 -4
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +5 -3
- snowflake/ml/modeling/kernel_approximation/nystroem.py +5 -3
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +5 -3
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +5 -3
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +5 -3
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/lars.py +1 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/perceptron.py +1 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ridge.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +5 -3
- snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
- snowflake/ml/modeling/manifold/isomap.py +5 -3
- snowflake/ml/modeling/manifold/mds.py +5 -3
- snowflake/ml/modeling/manifold/spectral_embedding.py +5 -3
- snowflake/ml/modeling/manifold/tsne.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +3 -0
- snowflake/ml/modeling/metrics/regression.py +3 -0
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +5 -3
- snowflake/ml/modeling/mixture/gaussian_mixture.py +5 -3
- snowflake/ml/modeling/model_selection/grid_search_cv.py +1 -5
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +1 -5
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +5 -3
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +5 -3
- snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +5 -3
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +5 -3
- snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +6 -0
- snowflake/ml/modeling/preprocessing/binarizer.py +7 -3
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +7 -2
- snowflake/ml/modeling/preprocessing/label_encoder.py +8 -7
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +7 -3
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +7 -4
- snowflake/ml/modeling/preprocessing/normalizer.py +7 -3
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +53 -11
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +44 -13
- snowflake/ml/modeling/preprocessing/polynomial_features.py +5 -3
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -4
- snowflake/ml/modeling/preprocessing/standard_scaler.py +7 -3
- snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
- snowflake/ml/modeling/svm/linear_svc.py +1 -1
- snowflake/ml/modeling/svm/linear_svr.py +1 -1
- snowflake/ml/modeling/svm/nu_svc.py +1 -1
- snowflake/ml/modeling/svm/nu_svr.py +1 -1
- snowflake/ml/modeling/svm/svc.py +1 -1
- snowflake/ml/modeling/svm/svr.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
- snowflake/ml/registry/_manager/model_manager.py +16 -3
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/METADATA +51 -7
- snowflake_ml_python-1.5.4.dist-info/RECORD +389 -0
- {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/WHEEL +1 -1
- snowflake_ml_python-1.5.2.dist-info/RECORD +0 -384
- {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/top_level.txt +0 -0
@@ -1,21 +1,11 @@
|
|
1
1
|
import copy
|
2
2
|
import functools
|
3
|
-
from typing import Any, Callable, List
|
3
|
+
from typing import Any, Callable, List, Optional
|
4
4
|
|
5
5
|
from snowflake import snowpark
|
6
6
|
from snowflake.ml._internal.lineage import data_source
|
7
7
|
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
def _get_datasources(*args: Any) -> List[data_source.DataSource]:
|
12
|
-
"""Helper method for extracting data sources attribute from DataFrames in an argument list"""
|
13
|
-
result = []
|
14
|
-
for arg in args:
|
15
|
-
srcs = getattr(arg, DATA_SOURCES_ATTR, None)
|
16
|
-
if isinstance(srcs, list) and all(isinstance(s, data_source.DataSource) for s in srcs):
|
17
|
-
result += srcs
|
18
|
-
return result
|
8
|
+
_DATA_SOURCES_ATTR = "_data_sources"
|
19
9
|
|
20
10
|
|
21
11
|
def _wrap_func(
|
@@ -32,6 +22,37 @@ def _wrap_func(
|
|
32
22
|
return wrapped
|
33
23
|
|
34
24
|
|
25
|
+
def _wrap_class_func(fn: Callable[..., snowpark.DataFrame]) -> Callable[..., snowpark.DataFrame]:
|
26
|
+
@functools.wraps(fn)
|
27
|
+
def wrapped(*args: Any, **kwargs: Any) -> snowpark.DataFrame:
|
28
|
+
df = fn(*args, **kwargs)
|
29
|
+
data_sources = get_data_sources(*args, *kwargs.values())
|
30
|
+
if data_sources:
|
31
|
+
patch_dataframe(df, data_sources, inplace=True)
|
32
|
+
return df
|
33
|
+
|
34
|
+
return wrapped
|
35
|
+
|
36
|
+
|
37
|
+
def get_data_sources(*args: Any) -> Optional[List[data_source.DataSource]]:
|
38
|
+
"""Helper method for extracting data sources attribute from DataFrames in an argument list"""
|
39
|
+
result: Optional[List[data_source.DataSource]] = None
|
40
|
+
for arg in args:
|
41
|
+
srcs = getattr(arg, _DATA_SOURCES_ATTR, None)
|
42
|
+
if isinstance(srcs, list) and all(isinstance(s, data_source.DataSource) for s in srcs):
|
43
|
+
if result is None:
|
44
|
+
result = []
|
45
|
+
result += srcs
|
46
|
+
return result
|
47
|
+
|
48
|
+
|
49
|
+
def set_data_sources(obj: Any, data_sources: Optional[List[data_source.DataSource]]) -> None:
|
50
|
+
"""Helper method for attaching data sources to an object"""
|
51
|
+
if data_sources:
|
52
|
+
assert all(isinstance(ds, data_source.DataSource) for ds in data_sources)
|
53
|
+
setattr(obj, _DATA_SOURCES_ATTR, data_sources)
|
54
|
+
|
55
|
+
|
35
56
|
def patch_dataframe(
|
36
57
|
df: snowpark.DataFrame, data_sources: List[data_source.DataSource], inplace: bool = False
|
37
58
|
) -> snowpark.DataFrame:
|
@@ -62,7 +83,7 @@ def patch_dataframe(
|
|
62
83
|
]
|
63
84
|
if not inplace:
|
64
85
|
df = copy.copy(df)
|
65
|
-
|
86
|
+
set_data_sources(df, data_sources)
|
66
87
|
for func in funcs:
|
67
88
|
fn = getattr(df, func, None)
|
68
89
|
if fn is not None:
|
@@ -70,18 +91,6 @@ def patch_dataframe(
|
|
70
91
|
return df
|
71
92
|
|
72
93
|
|
73
|
-
def _wrap_class_func(fn: Callable[..., snowpark.DataFrame]) -> Callable[..., snowpark.DataFrame]:
|
74
|
-
@functools.wraps(fn)
|
75
|
-
def wrapped(*args: Any, **kwargs: Any) -> snowpark.DataFrame:
|
76
|
-
df = fn(*args, **kwargs)
|
77
|
-
data_sources = _get_datasources(*args) + _get_datasources(*kwargs.values())
|
78
|
-
if data_sources:
|
79
|
-
patch_dataframe(df, data_sources, inplace=True)
|
80
|
-
return df
|
81
|
-
|
82
|
-
return wrapped
|
83
|
-
|
84
|
-
|
85
94
|
# Class-level monkey-patches
|
86
95
|
for klass, func_list in {
|
87
96
|
snowpark.DataFrame: [
|
@@ -10,6 +10,7 @@ from typing import (
|
|
10
10
|
Dict,
|
11
11
|
Iterable,
|
12
12
|
List,
|
13
|
+
Mapping,
|
13
14
|
Optional,
|
14
15
|
Tuple,
|
15
16
|
TypeVar,
|
@@ -92,6 +93,31 @@ def get_statement_params(
|
|
92
93
|
)
|
93
94
|
|
94
95
|
|
96
|
+
def add_statement_params_custom_tags(
|
97
|
+
statement_params: Optional[Dict[str, Any]], custom_tags: Mapping[str, Any]
|
98
|
+
) -> Dict[str, Any]:
|
99
|
+
"""
|
100
|
+
Add custom_tags to existing statement_params. Overwrite keys in custom_tags dict that already exist.
|
101
|
+
If existing statement_params are not provided, do nothing as the information cannot be effectively tracked.
|
102
|
+
|
103
|
+
Args:
|
104
|
+
statement_params: Existing statement_params dictionary.
|
105
|
+
custom_tags: Dictionary of existing k/v pairs to add as custom_tags
|
106
|
+
|
107
|
+
Returns:
|
108
|
+
new statement_params dictionary with all keys and an updated custom_tags field.
|
109
|
+
"""
|
110
|
+
if not statement_params:
|
111
|
+
return {}
|
112
|
+
existing_custom_tags: Dict[str, Any] = statement_params.pop(TelemetryField.KEY_CUSTOM_TAGS.value, {})
|
113
|
+
existing_custom_tags.update(custom_tags)
|
114
|
+
# NOTE: This can be done with | operator after upgrade from py3.8
|
115
|
+
return {
|
116
|
+
**statement_params,
|
117
|
+
TelemetryField.KEY_CUSTOM_TAGS.value: existing_custom_tags,
|
118
|
+
}
|
119
|
+
|
120
|
+
|
95
121
|
# TODO: we can merge this with get_statement_params after code clean up
|
96
122
|
def get_statement_params_full_func_name(frame: Optional[types.FrameType], class_name: Optional[str] = None) -> str:
|
97
123
|
"""
|
@@ -165,6 +165,20 @@ def parse_schema_level_object_identifier(
|
|
165
165
|
)
|
166
166
|
|
167
167
|
|
168
|
+
def is_fully_qualified_name(name: str) -> bool:
|
169
|
+
"""
|
170
|
+
Checks if a given name is a fully qualified name, which is in the format '<db>.<schema>.<object_name>'.
|
171
|
+
|
172
|
+
Args:
|
173
|
+
name: The name to be checked.
|
174
|
+
|
175
|
+
Returns:
|
176
|
+
bool: True if the name is fully qualified, False otherwise.
|
177
|
+
"""
|
178
|
+
res = parse_schema_level_object_identifier(name)
|
179
|
+
return res[0] is not None and res[1] is not None and res[2] is not None and not res[3]
|
180
|
+
|
181
|
+
|
168
182
|
def get_schema_level_object_identifier(
|
169
183
|
db: Optional[str],
|
170
184
|
schema: Optional[str],
|
@@ -1,22 +1,27 @@
|
|
1
1
|
import logging
|
2
2
|
import warnings
|
3
|
+
from typing import List, Optional
|
3
4
|
|
4
5
|
from snowflake import snowpark
|
6
|
+
from snowflake.ml._internal.utils import sql_identifier
|
5
7
|
from snowflake.snowpark import functions, types
|
6
8
|
|
7
9
|
|
8
|
-
def cast_snowpark_dataframe(df: snowpark.DataFrame) -> snowpark.DataFrame:
|
10
|
+
def cast_snowpark_dataframe(df: snowpark.DataFrame, ignore_columns: Optional[List[str]] = None) -> snowpark.DataFrame:
|
9
11
|
"""Cast columns in the dataframe to types that are compatible with tensor.
|
10
12
|
|
11
13
|
It assists FileSet.make() in performing implicit data casting.
|
12
14
|
|
13
15
|
Args:
|
14
16
|
df: A snowpark dataframe.
|
17
|
+
ignore_columns: Columns to exclude from casting. These columns will be propagated unchanged.
|
15
18
|
|
16
19
|
Returns:
|
17
20
|
A snowpark dataframe whose data type has been casted.
|
18
21
|
"""
|
19
22
|
|
23
|
+
ignore_cols_set = {sql_identifier.SqlIdentifier(c).identifier() for c in ignore_columns} if ignore_columns else {}
|
24
|
+
|
20
25
|
fields = df.schema.fields
|
21
26
|
selected_cols = []
|
22
27
|
for field in fields:
|
@@ -40,7 +45,9 @@ def cast_snowpark_dataframe(df: snowpark.DataFrame) -> snowpark.DataFrame:
|
|
40
45
|
dest = field.datatype
|
41
46
|
selected_cols.append(functions.cast(functions.col(src), dest).alias(src))
|
42
47
|
else:
|
43
|
-
if field.
|
48
|
+
if field.column_identifier.name in ignore_cols_set:
|
49
|
+
pass
|
50
|
+
elif field.datatype in (types.DateType(), types.TimestampType(), types.TimeType()):
|
44
51
|
logging.warning(
|
45
52
|
"A Column with DATE or TIMESTAMP data type detected. "
|
46
53
|
"It might not be able to get converted to tensors. "
|
@@ -90,7 +97,9 @@ def cast_snowpark_dataframe_column_types(df: snowpark.DataFrame) -> snowpark.Dat
|
|
90
97
|
" is being automatically converted to DoubleType in the Snowpark DataFrame. "
|
91
98
|
"This automatic conversion may lead to potential precision loss and rounding errors. "
|
92
99
|
"If you wish to prevent this conversion, you should manually perform "
|
93
|
-
"the necessary data type conversion."
|
100
|
+
"the necessary data type conversion.",
|
101
|
+
UserWarning,
|
102
|
+
stacklevel=2,
|
94
103
|
)
|
95
104
|
else:
|
96
105
|
# IntegerType default as NUMBER(38, 0), but
|
@@ -102,7 +111,9 @@ def cast_snowpark_dataframe_column_types(df: snowpark.DataFrame) -> snowpark.Dat
|
|
102
111
|
" is being automatically converted to LongType in the Snowpark DataFrame. "
|
103
112
|
"This automatic conversion may lead to potential precision loss and rounding errors. "
|
104
113
|
"If you wish to prevent this conversion, you should manually perform "
|
105
|
-
"the necessary data type conversion."
|
114
|
+
"the necessary data type conversion.",
|
115
|
+
UserWarning,
|
116
|
+
stacklevel=2,
|
106
117
|
)
|
107
118
|
selected_cols.append(functions.cast(functions.col(src), dest_dtype).alias(src))
|
108
119
|
# TODO: add more type handling or error message
|
snowflake/ml/dataset/dataset.py
CHANGED
@@ -19,6 +19,7 @@ from snowflake.ml._internal.utils import (
|
|
19
19
|
snowpark_dataframe_utils,
|
20
20
|
)
|
21
21
|
from snowflake.ml.dataset import dataset_metadata, dataset_reader
|
22
|
+
from snowflake.ml.lineage import lineage_node
|
22
23
|
from snowflake.snowpark import exceptions as snowpark_exceptions, functions
|
23
24
|
|
24
25
|
_PROJECT = "Dataset"
|
@@ -65,6 +66,20 @@ class DatasetVersion:
|
|
65
66
|
comment: Optional[str] = self._get_property("comment")
|
66
67
|
return comment
|
67
68
|
|
69
|
+
@property
|
70
|
+
def label_cols(self) -> List[str]:
|
71
|
+
metadata = self._get_metadata()
|
72
|
+
if metadata is None or metadata.label_cols is None:
|
73
|
+
return []
|
74
|
+
return metadata.label_cols
|
75
|
+
|
76
|
+
@property
|
77
|
+
def exclude_cols(self) -> List[str]:
|
78
|
+
metadata = self._get_metadata()
|
79
|
+
if metadata is None or metadata.exclude_cols is None:
|
80
|
+
return []
|
81
|
+
return metadata.exclude_cols
|
82
|
+
|
68
83
|
def _get_property(self, property_name: str, default: Any = None) -> Any:
|
69
84
|
if self._properties is None:
|
70
85
|
sql_result = (
|
@@ -91,17 +106,6 @@ class DatasetVersion:
|
|
91
106
|
warnings.warn(f"Metadata parsing failed with error: {e}", UserWarning, stacklevel=2)
|
92
107
|
return self._metadata
|
93
108
|
|
94
|
-
def _get_exclude_cols(self) -> List[str]:
|
95
|
-
metadata = self._get_metadata()
|
96
|
-
if metadata is None:
|
97
|
-
return []
|
98
|
-
cols = []
|
99
|
-
if metadata.exclude_cols:
|
100
|
-
cols.extend(metadata.exclude_cols)
|
101
|
-
if metadata.label_cols:
|
102
|
-
cols.extend(metadata.label_cols)
|
103
|
-
return cols
|
104
|
-
|
105
109
|
def url(self) -> str:
|
106
110
|
"""Returns the URL of the DatasetVersion contents in Snowflake.
|
107
111
|
|
@@ -122,7 +126,7 @@ class DatasetVersion:
|
|
122
126
|
return f"{self.__class__.__name__}(dataset='{self._parent.fully_qualified_name}', version='{self.name}')"
|
123
127
|
|
124
128
|
|
125
|
-
class Dataset:
|
129
|
+
class Dataset(lineage_node.LineageNode):
|
126
130
|
"""Represents a Snowflake Dataset which is organized into versions."""
|
127
131
|
|
128
132
|
@telemetry.send_api_usage_telemetry(project=_PROJECT)
|
@@ -135,18 +139,31 @@ class Dataset:
|
|
135
139
|
selected_version: Optional[str] = None,
|
136
140
|
) -> None:
|
137
141
|
"""Initialize a lazily evaluated Dataset object"""
|
138
|
-
self._session = session
|
139
142
|
self._db = database
|
140
143
|
self._schema = schema
|
141
144
|
self._name = name
|
142
|
-
|
145
|
+
|
146
|
+
super().__init__(
|
147
|
+
session,
|
148
|
+
identifier.get_schema_level_object_identifier(database, schema, name),
|
149
|
+
domain="dataset",
|
150
|
+
version=selected_version,
|
151
|
+
)
|
143
152
|
|
144
153
|
self._version = DatasetVersion(self, selected_version) if selected_version else None
|
145
154
|
self._reader: Optional[dataset_reader.DatasetReader] = None
|
146
155
|
|
156
|
+
def __repr__(self) -> str:
|
157
|
+
return (
|
158
|
+
f"{self.__class__.__name__}(\n"
|
159
|
+
f" name='{self._lineage_node_name}',\n"
|
160
|
+
f" version='{self._version._version if self._version else None}',\n"
|
161
|
+
f")"
|
162
|
+
)
|
163
|
+
|
147
164
|
@property
|
148
165
|
def fully_qualified_name(self) -> str:
|
149
|
-
return self.
|
166
|
+
return self._lineage_node_name
|
150
167
|
|
151
168
|
@property
|
152
169
|
def selected_version(self) -> Optional[DatasetVersion]:
|
@@ -165,10 +182,10 @@ class Dataset:
|
|
165
182
|
self._session,
|
166
183
|
[
|
167
184
|
data_source.DataSource(
|
168
|
-
fully_qualified_name=self.
|
185
|
+
fully_qualified_name=self._lineage_node_name,
|
169
186
|
version=v.name,
|
170
187
|
url=v.url(),
|
171
|
-
exclude_cols=v.
|
188
|
+
exclude_cols=(v.label_cols + v.exclude_cols),
|
172
189
|
)
|
173
190
|
],
|
174
191
|
)
|
@@ -227,9 +244,8 @@ class Dataset:
|
|
227
244
|
try:
|
228
245
|
session.sql(query).collect(statement_params=_TELEMETRY_STATEMENT_PARAMS)
|
229
246
|
return Dataset(session, db, schema, ds_name)
|
230
|
-
except snowpark_exceptions.
|
231
|
-
|
232
|
-
if e.message.startswith(dataset_errors.ERRNO_OBJECT_ALREADY_EXISTS):
|
247
|
+
except snowpark_exceptions.SnowparkSQLException as e:
|
248
|
+
if e.sql_error_code == dataset_errors.ERRNO_OBJECT_ALREADY_EXISTS:
|
233
249
|
raise snowml_exceptions.SnowflakeMLException(
|
234
250
|
error_code=error_codes.OBJECT_ALREADY_EXISTS,
|
235
251
|
original_exception=dataset_errors.DatasetExistError(
|
@@ -293,7 +309,7 @@ class Dataset:
|
|
293
309
|
Raises:
|
294
310
|
SnowflakeMLException: The Dataset no longer exists.
|
295
311
|
SnowflakeMLException: The specified Dataset version already exists.
|
296
|
-
snowpark_exceptions.
|
312
|
+
snowpark_exceptions.SnowparkSQLException: An error occurred during Dataset creation.
|
297
313
|
|
298
314
|
Note: During the generation of stage files, data casting will occur. The casting rules are as follows::
|
299
315
|
- Data casting:
|
@@ -318,7 +334,8 @@ class Dataset:
|
|
318
334
|
- DateType(DATE): Not supported. A warning will be logged.
|
319
335
|
- VariantType(VARIANT): Not supported. A warning will be logged.
|
320
336
|
"""
|
321
|
-
|
337
|
+
cast_ignore_cols = (exclude_cols or []) + (label_cols or [])
|
338
|
+
casted_df = snowpark_dataframe_utils.cast_snowpark_dataframe(input_dataframe, ignore_columns=cast_ignore_cols)
|
322
339
|
|
323
340
|
if shuffle:
|
324
341
|
casted_df = casted_df.order_by(functions.random())
|
@@ -364,19 +381,19 @@ class Dataset:
|
|
364
381
|
|
365
382
|
return Dataset(self._session, self._db, self._schema, self._name, version)
|
366
383
|
|
367
|
-
except snowpark_exceptions.
|
368
|
-
if e.
|
384
|
+
except snowpark_exceptions.SnowparkSQLException as e:
|
385
|
+
if e.sql_error_code == dataset_errors.ERRNO_DATASET_NOT_EXIST:
|
369
386
|
raise snowml_exceptions.SnowflakeMLException(
|
370
387
|
error_code=error_codes.NOT_FOUND,
|
371
388
|
original_exception=dataset_errors.DatasetNotExistError(
|
372
389
|
dataset_error_messages.DATASET_NOT_EXIST.format(self.fully_qualified_name)
|
373
390
|
),
|
374
391
|
) from e
|
375
|
-
elif
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
392
|
+
elif e.sql_error_code in {
|
393
|
+
dataset_errors.ERRNO_DATASET_VERSION_ALREADY_EXISTS,
|
394
|
+
dataset_errors.ERRNO_VERSION_ALREADY_EXISTS,
|
395
|
+
dataset_errors.ERRNO_FILES_ALREADY_EXISTING,
|
396
|
+
}:
|
380
397
|
raise snowml_exceptions.SnowflakeMLException(
|
381
398
|
error_code=error_codes.OBJECT_ALREADY_EXISTS,
|
382
399
|
original_exception=dataset_errors.DatasetExistError(
|
@@ -432,9 +449,8 @@ class Dataset:
|
|
432
449
|
.has_column(_DATASET_VERSION_NAME_COL, allow_empty=True)
|
433
450
|
.validate()
|
434
451
|
)
|
435
|
-
except snowpark_exceptions.
|
436
|
-
|
437
|
-
if e.message.startswith(dataset_errors.ERRNO_OBJECT_NOT_EXIST):
|
452
|
+
except snowpark_exceptions.SnowparkSQLException as e:
|
453
|
+
if e.sql_error_code == dataset_errors.ERRNO_OBJECT_NOT_EXIST:
|
438
454
|
raise snowml_exceptions.SnowflakeMLException(
|
439
455
|
error_code=error_codes.NOT_FOUND,
|
440
456
|
original_exception=dataset_errors.DatasetNotExistError(
|
@@ -456,6 +472,12 @@ class Dataset:
|
|
456
472
|
),
|
457
473
|
)
|
458
474
|
|
475
|
+
@staticmethod
|
476
|
+
def _load_from_lineage_node(session: snowpark.Session, name: str, version: str) -> "Dataset":
|
477
|
+
return Dataset.load(session, name).select_version(version)
|
478
|
+
|
479
|
+
|
480
|
+
lineage_node.DOMAIN_LINEAGE_REGISTRY["dataset"] = Dataset
|
459
481
|
|
460
482
|
# Utility methods
|
461
483
|
|
@@ -16,8 +16,7 @@ def create_from_dataframe(
|
|
16
16
|
**version_kwargs: Any,
|
17
17
|
) -> dataset.Dataset:
|
18
18
|
"""
|
19
|
-
Create a new versioned Dataset from a DataFrame
|
20
|
-
a DatasetReader for the newly created Dataset version.
|
19
|
+
Create a new versioned Dataset from a DataFrame.
|
21
20
|
|
22
21
|
Args:
|
23
22
|
session: The Snowpark Session instance to use.
|
@@ -39,7 +38,7 @@ def create_from_dataframe(
|
|
39
38
|
@telemetry.send_api_usage_telemetry(project=_PROJECT)
|
40
39
|
def load_dataset(session: snowpark.Session, name: str, version: str) -> dataset.Dataset:
|
41
40
|
"""
|
42
|
-
Load a versioned Dataset
|
41
|
+
Load a versioned Dataset.
|
43
42
|
|
44
43
|
Args:
|
45
44
|
session: The Snowpark Session instance to use.
|
@@ -47,7 +46,7 @@ def load_dataset(session: snowpark.Session, name: str, version: str) -> dataset.
|
|
47
46
|
version: The dataset version name.
|
48
47
|
|
49
48
|
Returns:
|
50
|
-
A
|
49
|
+
A Dataset object.
|
51
50
|
"""
|
52
51
|
ds: dataset.Dataset = dataset.Dataset.load(session, name).select_version(version)
|
53
52
|
return ds
|