snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +72 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
- snowflake/ml/_internal/telemetry.py +1 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/sql_identifier.py +14 -1
- snowflake/ml/dataset/__init__.py +11 -0
- snowflake/ml/dataset/dataset.py +455 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +199 -0
- snowflake/ml/feature_store/__init__.py +6 -0
- snowflake/ml/feature_store/access_manager.py +279 -0
- snowflake/ml/feature_store/feature_store.py +544 -358
- snowflake/ml/feature_store/feature_view.py +55 -16
- snowflake/ml/fileset/embedded_stage_fs.py +149 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +160 -0
- snowflake/ml/fileset/stage_fs.py +25 -10
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +65 -31
- snowflake/ml/model/_client/model/model_version_impl.py +159 -2
- snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
- snowflake/ml/model/_client/ops/model_ops.py +268 -83
- snowflake/ml/model/_client/sql/_base.py +34 -0
- snowflake/ml/model/_client/sql/model.py +42 -47
- snowflake/ml/model/_client/sql/model_version.py +164 -39
- snowflake/ml/model/_client/sql/stage.py +6 -32
- snowflake/ml/model/_client/sql/tag.py +32 -56
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +64 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +538 -36
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/_manager/model_manager.py +36 -7
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -1,10 +1,11 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import json
|
4
|
+
import re
|
4
5
|
from collections import OrderedDict
|
5
|
-
from dataclasses import dataclass
|
6
|
+
from dataclasses import asdict, dataclass
|
6
7
|
from enum import Enum
|
7
|
-
from typing import Dict, List, Optional
|
8
|
+
from typing import Any, Dict, List, Optional
|
8
9
|
|
9
10
|
from snowflake.ml._internal.exceptions import (
|
10
11
|
error_codes,
|
@@ -26,21 +27,45 @@ from snowflake.snowpark.types import (
|
|
26
27
|
)
|
27
28
|
|
28
29
|
_FEATURE_VIEW_NAME_DELIMITER = "$"
|
29
|
-
|
30
|
+
_LEGACY_TIMESTAMP_COL_PLACEHOLDER_VALS = ["FS_TIMESTAMP_COL_PLACEHOLDER_VAL", "NULL"]
|
31
|
+
_TIMESTAMP_COL_PLACEHOLDER = "NULL"
|
30
32
|
_FEATURE_OBJ_TYPE = "FEATURE_OBJ_TYPE"
|
33
|
+
# Feature view version rule is aligned with dataset version rule in SQL.
|
34
|
+
_FEATURE_VIEW_VERSION_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_.\-]*$")
|
35
|
+
_FEATURE_VIEW_VERSION_MAX_LENGTH = 128
|
31
36
|
|
32
37
|
|
33
|
-
|
38
|
+
@dataclass(frozen=True)
|
39
|
+
class _FeatureViewMetadata:
|
40
|
+
"""Represent metadata tracked on top of FV backend object"""
|
41
|
+
|
42
|
+
entities: List[str]
|
43
|
+
timestamp_col: str
|
44
|
+
|
45
|
+
def to_json(self) -> str:
|
46
|
+
return json.dumps(asdict(self))
|
47
|
+
|
48
|
+
@classmethod
|
49
|
+
def from_json(cls, json_str: str) -> _FeatureViewMetadata:
|
50
|
+
state_dict = json.loads(json_str)
|
51
|
+
return cls(**state_dict)
|
52
|
+
|
53
|
+
|
54
|
+
class FeatureViewVersion(str):
|
34
55
|
def __new__(cls, version: str) -> FeatureViewVersion:
|
35
|
-
if
|
56
|
+
if not _FEATURE_VIEW_VERSION_RE.match(version) or len(version) > _FEATURE_VIEW_VERSION_MAX_LENGTH:
|
36
57
|
raise snowml_exceptions.SnowflakeMLException(
|
37
58
|
error_code=error_codes.INVALID_ARGUMENT,
|
38
|
-
original_exception=ValueError(
|
59
|
+
original_exception=ValueError(
|
60
|
+
f"`{version}` is not a valid feature view version. "
|
61
|
+
"It must start with letter or digit, and followed by letter, digit, '_', '-' or '.'. "
|
62
|
+
f"The length limit is {_FEATURE_VIEW_VERSION_MAX_LENGTH}."
|
63
|
+
),
|
39
64
|
)
|
40
|
-
return super().__new__(cls, version)
|
65
|
+
return super().__new__(cls, version)
|
41
66
|
|
42
67
|
def __init__(self, version: str) -> None:
|
43
|
-
super().__init__(
|
68
|
+
super().__init__()
|
44
69
|
|
45
70
|
|
46
71
|
class FeatureViewStatus(Enum):
|
@@ -97,12 +122,13 @@ class FeatureView:
|
|
97
122
|
timestamp_col: Optional[str] = None,
|
98
123
|
refresh_freq: Optional[str] = None,
|
99
124
|
desc: str = "",
|
125
|
+
**_kwargs: Any,
|
100
126
|
) -> None:
|
101
127
|
"""
|
102
128
|
Create a FeatureView instance.
|
103
129
|
|
104
130
|
Args:
|
105
|
-
name: name of the FeatureView. NOTE:
|
131
|
+
name: name of the FeatureView. NOTE: following Snowflake identifier rule
|
106
132
|
entities: entities that the FeatureView is associated with.
|
107
133
|
feature_df: Snowpark DataFrame containing data source and all feature feature_df logics.
|
108
134
|
Final projection of the DataFrame should contain feature names, join keys and timestamp(if applicable).
|
@@ -116,6 +142,7 @@ class FeatureView:
|
|
116
142
|
NOTE: If refresh_freq is not provided, then FeatureView will be registered as View on Snowflake backend
|
117
143
|
and there won't be extra storage cost.
|
118
144
|
desc: description of the FeatureView.
|
145
|
+
_kwargs: reserved kwargs for system generated args. NOTE: DO NOT USE.
|
119
146
|
"""
|
120
147
|
|
121
148
|
self._name: SqlIdentifier = SqlIdentifier(name)
|
@@ -125,6 +152,7 @@ class FeatureView:
|
|
125
152
|
SqlIdentifier(timestamp_col) if timestamp_col is not None else None
|
126
153
|
)
|
127
154
|
self._desc: str = desc
|
155
|
+
self._infer_schema_df: DataFrame = _kwargs.get("_infer_schema_df", self._feature_df)
|
128
156
|
self._query: str = self._get_query()
|
129
157
|
self._version: Optional[FeatureViewVersion] = None
|
130
158
|
self._status: FeatureViewStatus = FeatureViewStatus.DRAFT
|
@@ -271,7 +299,7 @@ class FeatureView:
|
|
271
299
|
|
272
300
|
@property
|
273
301
|
def output_schema(self) -> StructType:
|
274
|
-
return self.
|
302
|
+
return self._infer_schema_df.schema
|
275
303
|
|
276
304
|
@property
|
277
305
|
def refresh_mode(self) -> Optional[str]:
|
@@ -285,6 +313,11 @@ class FeatureView:
|
|
285
313
|
def owner(self) -> Optional[str]:
|
286
314
|
return self._owner
|
287
315
|
|
316
|
+
def _metadata(self) -> _FeatureViewMetadata:
|
317
|
+
entity_names = [e.name.identifier() for e in self.entities]
|
318
|
+
ts_col = self.timestamp_col.identifier() if self.timestamp_col is not None else _TIMESTAMP_COL_PLACEHOLDER
|
319
|
+
return _FeatureViewMetadata(entity_names, ts_col)
|
320
|
+
|
288
321
|
def _get_query(self) -> str:
|
289
322
|
if len(self._feature_df.queries["queries"]) != 1:
|
290
323
|
raise ValueError(
|
@@ -300,7 +333,7 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
300
333
|
f"FeatureView name `{self._name}` contains invalid character `{_FEATURE_VIEW_NAME_DELIMITER}`."
|
301
334
|
)
|
302
335
|
|
303
|
-
unescaped_df_cols = to_sql_identifiers(self.
|
336
|
+
unescaped_df_cols = to_sql_identifiers(self._infer_schema_df.columns)
|
304
337
|
for e in self._entities:
|
305
338
|
for k in e.join_keys:
|
306
339
|
if k not in unescaped_df_cols:
|
@@ -312,17 +345,17 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
312
345
|
ts_col = self._timestamp_col
|
313
346
|
if ts_col == SqlIdentifier(_TIMESTAMP_COL_PLACEHOLDER):
|
314
347
|
raise ValueError(f"Invalid timestamp_col name, cannot be {_TIMESTAMP_COL_PLACEHOLDER}.")
|
315
|
-
if ts_col not in to_sql_identifiers(self.
|
348
|
+
if ts_col not in to_sql_identifiers(self._infer_schema_df.columns):
|
316
349
|
raise ValueError(f"timestamp_col {ts_col} is not found in input dataframe.")
|
317
350
|
|
318
|
-
col_type = self.
|
351
|
+
col_type = self._infer_schema_df.schema[ts_col].datatype
|
319
352
|
if not isinstance(col_type, (DateType, TimeType, TimestampType, _NumericType)):
|
320
353
|
raise ValueError(f"Invalid data type for timestamp_col {ts_col}: {col_type}.")
|
321
354
|
|
322
355
|
def _get_feature_names(self) -> List[SqlIdentifier]:
|
323
356
|
join_keys = [k for e in self._entities for k in e.join_keys]
|
324
357
|
ts_col = [self._timestamp_col] if self._timestamp_col is not None else []
|
325
|
-
feature_names = to_sql_identifiers(self.
|
358
|
+
feature_names = to_sql_identifiers(self._infer_schema_df.columns, case_sensitive=False)
|
326
359
|
return [c for c in feature_names if c not in join_keys + ts_col]
|
327
360
|
|
328
361
|
def __repr__(self) -> str:
|
@@ -355,6 +388,9 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
355
388
|
fv_dict = self.__dict__.copy()
|
356
389
|
if "_feature_df" in fv_dict:
|
357
390
|
fv_dict.pop("_feature_df")
|
391
|
+
if "_infer_schema_df" in fv_dict:
|
392
|
+
infer_schema_df = fv_dict.pop("_infer_schema_df")
|
393
|
+
fv_dict["_infer_schema_query"] = infer_schema_df.queries["queries"][0]
|
358
394
|
fv_dict["_entities"] = [e._to_dict() for e in self._entities]
|
359
395
|
fv_dict["_status"] = str(self._status)
|
360
396
|
fv_dict["_name"] = str(self._name) if self._name is not None else None
|
@@ -411,6 +447,7 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
411
447
|
refresh_mode=json_dict["_refresh_mode"],
|
412
448
|
refresh_mode_reason=json_dict["_refresh_mode_reason"],
|
413
449
|
owner=json_dict["_owner"],
|
450
|
+
infer_schema_df=session.sql(json_dict.get("_infer_schema_query", None)),
|
414
451
|
)
|
415
452
|
|
416
453
|
@staticmethod
|
@@ -436,12 +473,13 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
436
473
|
status: FeatureViewStatus,
|
437
474
|
feature_descs: Dict[str, str],
|
438
475
|
refresh_freq: Optional[str],
|
439
|
-
database:
|
440
|
-
schema:
|
476
|
+
database: str,
|
477
|
+
schema: str,
|
441
478
|
warehouse: Optional[str],
|
442
479
|
refresh_mode: Optional[str],
|
443
480
|
refresh_mode_reason: Optional[str],
|
444
481
|
owner: Optional[str],
|
482
|
+
infer_schema_df: Optional[DataFrame],
|
445
483
|
) -> FeatureView:
|
446
484
|
fv = FeatureView(
|
447
485
|
name=name,
|
@@ -449,6 +487,7 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
449
487
|
feature_df=feature_df,
|
450
488
|
timestamp_col=timestamp_col,
|
451
489
|
desc=desc,
|
490
|
+
_infer_schema_df=infer_schema_df,
|
452
491
|
)
|
453
492
|
fv._version = FeatureViewVersion(version) if version is not None else None
|
454
493
|
fv._status = status
|
@@ -0,0 +1,149 @@
|
|
1
|
+
import re
|
2
|
+
from collections import defaultdict
|
3
|
+
from typing import Any, List, Optional, Tuple
|
4
|
+
|
5
|
+
from snowflake import snowpark
|
6
|
+
from snowflake.connector import connection
|
7
|
+
from snowflake.ml._internal import telemetry
|
8
|
+
from snowflake.ml._internal.exceptions import (
|
9
|
+
error_codes,
|
10
|
+
exceptions as snowml_exceptions,
|
11
|
+
fileset_errors,
|
12
|
+
)
|
13
|
+
from snowflake.ml._internal.utils import identifier
|
14
|
+
from snowflake.snowpark import exceptions as snowpark_exceptions
|
15
|
+
|
16
|
+
from . import stage_fs
|
17
|
+
|
18
|
+
_SNOWURL_PATH_RE = re.compile(r"versions/(?P<version>[^/]+)(?:/+(?P<filepath>.*))?")
|
19
|
+
|
20
|
+
|
21
|
+
class SFEmbeddedStageFileSystem(stage_fs.SFStageFileSystem):
|
22
|
+
def __init__(
|
23
|
+
self,
|
24
|
+
*,
|
25
|
+
domain: str,
|
26
|
+
name: str,
|
27
|
+
snowpark_session: Optional[snowpark.Session] = None,
|
28
|
+
sf_connection: Optional[connection.SnowflakeConnection] = None,
|
29
|
+
**kwargs: Any,
|
30
|
+
) -> None:
|
31
|
+
|
32
|
+
(db, schema, object_name, _) = identifier.parse_schema_level_object_identifier(name)
|
33
|
+
self._name = name # TODO: Require or resolve FQN
|
34
|
+
self._domain = domain
|
35
|
+
|
36
|
+
super().__init__(
|
37
|
+
db=db,
|
38
|
+
schema=schema,
|
39
|
+
stage=object_name,
|
40
|
+
snowpark_session=snowpark_session,
|
41
|
+
sf_connection=sf_connection,
|
42
|
+
**kwargs,
|
43
|
+
)
|
44
|
+
|
45
|
+
@property
|
46
|
+
def stage_name(self) -> str:
|
47
|
+
"""Get the Snowflake path to this stage.
|
48
|
+
|
49
|
+
Returns:
|
50
|
+
A string in the format of snow://<domain>/<name>
|
51
|
+
Example: snow://dataset/my_dataset
|
52
|
+
|
53
|
+
# noqa: DAR203
|
54
|
+
"""
|
55
|
+
return f"snow://{self._domain}/{self._name}"
|
56
|
+
|
57
|
+
def _stage_path_to_relative_path(self, stage_path: str) -> str:
|
58
|
+
"""Convert a stage file path which comes from the LIST query to a relative file path in that stage.
|
59
|
+
|
60
|
+
The file path returned by LIST query always has the format "versions/<version>/<relative_file_path>".
|
61
|
+
The full "versions/<version>/<relative_file_path>" is returned
|
62
|
+
|
63
|
+
Args:
|
64
|
+
stage_path: A string starting with the name of the stage.
|
65
|
+
|
66
|
+
Returns:
|
67
|
+
A string of the relative stage path.
|
68
|
+
"""
|
69
|
+
return stage_path
|
70
|
+
|
71
|
+
def _fetch_presigned_urls(
|
72
|
+
self, files: List[str], url_lifetime: float = stage_fs._PRESIGNED_URL_LIFETIME_SEC
|
73
|
+
) -> List[Tuple[str, str]]:
|
74
|
+
"""Fetch presigned urls for the given files."""
|
75
|
+
# SnowURL requires full snow://<domain>/<entity>/versions/<version> as the stage path arg to get_presigned_url
|
76
|
+
versions_dict = defaultdict(list)
|
77
|
+
for file in files:
|
78
|
+
match = _SNOWURL_PATH_RE.fullmatch(file)
|
79
|
+
assert match is not None and match.group("filepath") is not None
|
80
|
+
versions_dict[match.group("version")].append(match.group("filepath"))
|
81
|
+
try:
|
82
|
+
async_jobs: List[snowpark.AsyncJob] = []
|
83
|
+
for version, version_files in versions_dict.items():
|
84
|
+
for file in version_files:
|
85
|
+
stage_loc = f"{self.stage_name}/versions/{version}"
|
86
|
+
query_result = self._session.sql(
|
87
|
+
f"select '{version}/{file}' as name,"
|
88
|
+
f" get_presigned_url('{stage_loc}', '{file}', {url_lifetime}) as url"
|
89
|
+
).collect(
|
90
|
+
block=False,
|
91
|
+
statement_params=telemetry.get_function_usage_statement_params(
|
92
|
+
project=stage_fs._PROJECT,
|
93
|
+
api_calls=[snowpark.DataFrame.collect],
|
94
|
+
),
|
95
|
+
)
|
96
|
+
async_jobs.append(query_result)
|
97
|
+
presigned_urls: List[Tuple[str, str]] = [
|
98
|
+
(r["NAME"], r["URL"]) for job in async_jobs for r in stage_fs._resolve_async_job(job)
|
99
|
+
]
|
100
|
+
return presigned_urls
|
101
|
+
except snowpark_exceptions.SnowparkClientException as e:
|
102
|
+
if e.message.startswith(fileset_errors.ERRNO_DOMAIN_NOT_EXIST) or e.message.startswith(
|
103
|
+
fileset_errors.ERRNO_STAGE_NOT_EXIST
|
104
|
+
):
|
105
|
+
raise snowml_exceptions.SnowflakeMLException(
|
106
|
+
error_code=error_codes.SNOWML_NOT_FOUND,
|
107
|
+
original_exception=fileset_errors.StageNotFoundError(
|
108
|
+
f"Stage {self.stage_name} does not exist or is not authorized."
|
109
|
+
),
|
110
|
+
)
|
111
|
+
else:
|
112
|
+
raise snowml_exceptions.SnowflakeMLException(
|
113
|
+
error_code=error_codes.INTERNAL_SNOWML_ERROR,
|
114
|
+
original_exception=fileset_errors.FileSetError(str(e)),
|
115
|
+
)
|
116
|
+
|
117
|
+
@classmethod
|
118
|
+
def _parent(cls, path: str) -> str:
|
119
|
+
"""Get parent of specified path up to minimally valid root path.
|
120
|
+
|
121
|
+
For SnowURL, the minimum valid path is snow://<domain>/<entity>/versions/<version>
|
122
|
+
|
123
|
+
Args:
|
124
|
+
path: File or directory path
|
125
|
+
|
126
|
+
Returns:
|
127
|
+
Parent path
|
128
|
+
|
129
|
+
Examples:
|
130
|
+
----
|
131
|
+
>>> fs._parent("snow://dataset/my_ds/versions/my_version/file.ext")
|
132
|
+
"snow://dataset/my_ds/versions/my_version/"
|
133
|
+
>>> fs._parent("snow://dataset/my_ds/versions/my_version/subdir/file.ext")
|
134
|
+
"snow://dataset/my_ds/versions/my_version/subdir/"
|
135
|
+
>>> fs._parent("snow://dataset/my_ds/versions/my_version/")
|
136
|
+
"snow://dataset/my_ds/versions/my_version/"
|
137
|
+
>>> fs._parent("snow://dataset/my_ds/versions/my_version")
|
138
|
+
"snow://dataset/my_ds/versions/my_version"
|
139
|
+
"""
|
140
|
+
path_match = _SNOWURL_PATH_RE.fullmatch(path)
|
141
|
+
if not path_match:
|
142
|
+
return super()._parent(path) # type: ignore[no-any-return]
|
143
|
+
filepath: str = path_match.group("filepath") or ""
|
144
|
+
root: str = path[: path_match.start("filepath")] if filepath else path
|
145
|
+
if "/" in filepath:
|
146
|
+
parent = filepath.rsplit("/", 1)[0]
|
147
|
+
return root + parent
|
148
|
+
else:
|
149
|
+
return root
|
snowflake/ml/fileset/sfcfs.py
CHANGED
@@ -185,7 +185,6 @@ class SFFileSystem(fsspec.AbstractFileSystem):
|
|
185
185
|
func_params_to_log=["detail"],
|
186
186
|
conn_attr_name="_conn",
|
187
187
|
)
|
188
|
-
@snowpark._internal.utils.private_preview(version="0.2.0")
|
189
188
|
def ls(self, path: str, detail: bool = False, **kwargs: Any) -> Union[List[str], List[Dict[str, Any]]]:
|
190
189
|
"""Override fsspec `ls` method. List single "directory" with or without details.
|
191
190
|
|
@@ -216,7 +215,6 @@ class SFFileSystem(fsspec.AbstractFileSystem):
|
|
216
215
|
project=_PROJECT,
|
217
216
|
conn_attr_name="_conn",
|
218
217
|
)
|
219
|
-
@snowpark._internal.utils.private_preview(version="0.2.0")
|
220
218
|
def optimize_read(self, files: Optional[List[str]] = None) -> None:
|
221
219
|
"""Prefetch and cache the presigned urls for all the given files to speed up the file opening.
|
222
220
|
|
@@ -242,7 +240,6 @@ class SFFileSystem(fsspec.AbstractFileSystem):
|
|
242
240
|
project=_PROJECT,
|
243
241
|
conn_attr_name="_conn",
|
244
242
|
)
|
245
|
-
@snowpark._internal.utils.private_preview(version="0.2.0")
|
246
243
|
def _open(self, path: str, **kwargs: Any) -> fsspec.spec.AbstractBufferedFile:
|
247
244
|
"""Override fsspec `_open` method. Open a file for reading in 'rb' mode.
|
248
245
|
|
@@ -268,7 +265,6 @@ class SFFileSystem(fsspec.AbstractFileSystem):
|
|
268
265
|
project=_PROJECT,
|
269
266
|
conn_attr_name="_conn",
|
270
267
|
)
|
271
|
-
@snowpark._internal.utils.private_preview(version="0.2.0")
|
272
268
|
def info(self, path: str, **kwargs: Any) -> Dict[str, Any]:
|
273
269
|
"""Override fsspec `info` method. Give details of entry at path."""
|
274
270
|
file_path = self._parse_file_path(path)
|
@@ -0,0 +1,160 @@
|
|
1
|
+
import collections
|
2
|
+
import logging
|
3
|
+
import re
|
4
|
+
from typing import Any, Dict, Optional
|
5
|
+
|
6
|
+
import fsspec
|
7
|
+
import packaging.version as pkg_version
|
8
|
+
|
9
|
+
from snowflake import snowpark
|
10
|
+
from snowflake.connector import connection
|
11
|
+
from snowflake.ml._internal.exceptions import (
|
12
|
+
error_codes,
|
13
|
+
exceptions as snowml_exceptions,
|
14
|
+
)
|
15
|
+
from snowflake.ml._internal.utils import identifier, snowflake_env
|
16
|
+
from snowflake.ml.fileset import embedded_stage_fs, sfcfs
|
17
|
+
|
18
|
+
PROTOCOL_NAME = "snow"
|
19
|
+
|
20
|
+
_SFFileEntityPath = collections.namedtuple(
|
21
|
+
"_SFFileEntityPath", ["domain", "name", "filepath", "version", "relative_path"]
|
22
|
+
)
|
23
|
+
_PROJECT = "FileSet"
|
24
|
+
_SNOWURL_PATTERN = re.compile(
|
25
|
+
f"({PROTOCOL_NAME}://)?"
|
26
|
+
r"(?<!@)(?P<domain>\w+)/"
|
27
|
+
rf"(?P<name>(?:{identifier._SF_IDENTIFIER}\.){{,2}}{identifier._SF_IDENTIFIER})/"
|
28
|
+
r"(?P<path>versions/(?:(?P<version>[^/]+)(?:/(?P<relpath>.*))?)?)"
|
29
|
+
)
|
30
|
+
|
31
|
+
# FIXME(dhung): Temporary fix for bug in GS version 8.17
|
32
|
+
_BUG_VERSION_MIN = pkg_version.Version("8.17") # Inclusive minimum version with bugged behavior
|
33
|
+
_BUG_VERSION_MAX = pkg_version.Version("8.18") # Exclusive maximum version with bugged behavior
|
34
|
+
|
35
|
+
|
36
|
+
class SnowFileSystem(sfcfs.SFFileSystem):
    """A filesystem that allows user to access Snowflake embedded stage files with valid Snowflake locations.

    The file system is based on fsspec (https://filesystem-spec.readthedocs.io/). It is a file system wrapper
    built on top of SFStageFileSystem. It takes Snowflake embedded stage path as the input and supports read operation.
    A valid Snowflake location will have the form "snow://{domain}/{entity_name}/versions/{version}/{path_to_file}".

    See `sfcfs.SFFileSystem` documentation for example usage patterns.
    """

    protocol = PROTOCOL_NAME
    # Tri-state, shared by all instances: None until the server version has been probed once,
    # then True/False depending on whether the server falls in the bugged version range.
    _IS_BUGGED_VERSION: Optional[bool] = None

    def __init__(
        self,
        sf_connection: Optional[connection.SnowflakeConnection] = None,
        snowpark_session: Optional[snowpark.Session] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the file system and, on first use, probe the server version for the path bug.

        Args:
            sf_connection: Optional Snowflake connection, forwarded to the parent constructor.
            snowpark_session: Optional Snowpark session, forwarded to the parent constructor.
            **kwargs: Extra options forwarded to the parent constructor.
        """
        super().__init__(sf_connection=sf_connection, snowpark_session=snowpark_session, **kwargs)

        # FIXME(dhung): Temporary fix for bug in GS version 8.17
        if SnowFileSystem._IS_BUGGED_VERSION is None:
            try:
                sf_version = snowflake_env.get_current_snowflake_version(self._session)
                SnowFileSystem._IS_BUGGED_VERSION = _BUG_VERSION_MIN <= sf_version < _BUG_VERSION_MAX
            except Exception:
                # Best effort: if the version cannot be determined, assume the server is not affected,
                # but leave a trace so the fallback is diagnosable instead of silent.
                logging.debug(
                    "Unable to determine Snowflake version; assuming it is not affected by the 8.17 path bug.",
                    exc_info=True,
                )
                SnowFileSystem._IS_BUGGED_VERSION = False

    def info(self, path: str, **kwargs: Any) -> Dict[str, Any]:
        """Override `info` so that directory entries always carry a trailing "/" in their name.

        Args:
            path: The snow URL to describe.
            **kwargs: Extra options forwarded to the parent `info`.

        Returns:
            The parent's info dict; for entries of type "directory" the "name" value ends with "/".
        """
        # FIXME(dhung): Temporary fix for bug in GS version 8.17
        res: Dict[str, Any] = super().info(path, **kwargs)
        if res.get("type") == "directory" and not res["name"].endswith("/"):
            res["name"] += "/"
        return res

    def _get_stage_fs(
        self, sf_file_path: _SFFileEntityPath  # type: ignore[override]
    ) -> embedded_stage_fs.SFEmbeddedStageFileSystem:
        """Get the stage file system for the given snowflake location.

        Stage file systems are created lazily and cached per (domain, name, version).

        Args:
            sf_file_path: The Snowflake path information.

        Returns:
            A SFEmbeddedStageFileSystem object which supports readonly file operations on Snowflake embedded stages.
        """
        stage_fs_key = (sf_file_path.domain, sf_file_path.name, sf_file_path.version)
        if stage_fs_key not in self._stage_fs_set:
            cnt_stage_fs = embedded_stage_fs.SFEmbeddedStageFileSystem(
                snowpark_session=self._session,
                domain=sf_file_path.domain,
                name=sf_file_path.name,
                **self._kwargs,
            )
            self._stage_fs_set[stage_fs_key] = cnt_stage_fs
        return self._stage_fs_set[stage_fs_key]

    def _stage_path_to_absolute_path(self, stage_fs: embedded_stage_fs.SFEmbeddedStageFileSystem, path: str) -> str:
        """Convert the relative path in a stage to an absolute path starts with the location of the stage."""
        # Strip protocol from absolute path, since backend needs snow:// prefix to resolve correctly
        # but fsspec logic strips protocol when doing any searching and globbing
        stage_name = stage_fs.stage_name
        protocol = f"{PROTOCOL_NAME}://"
        if stage_name.startswith(protocol):
            stage_name = stage_name[len(protocol) :]
        abs_path = stage_name + "/" + path
        # FIXME(dhung): Temporary fix for bug in GS version 8.17
        if self._IS_BUGGED_VERSION:
            match = _SNOWURL_PATTERN.fullmatch(abs_path)
            assert match is not None
            relpath = match.group("relpath")
            if relpath:
                # Rewrite only the matched relpath span (which runs to the end of the string).
                # A plain str.replace() would also rewrite any earlier segment that happens to
                # contain the same text, e.g. version "v1" with relpath "/v1".
                abs_path = abs_path[: match.start("relpath")] + relpath.lstrip("/")
        return abs_path

    @classmethod
    def _parse_file_path(cls, path: str) -> _SFFileEntityPath:  # type: ignore[override]
        """Parse a snowflake location path.

        The following properties will be extracted from the path input:
        - embedded stage domain
        - entity name
        - path (in format `versions/{version}/{relative_path}`)
        - entity version (optional)
        - relative file path (optional)

        Args:
            path: A string in the format of "snow://{domain}/{entity_name}/versions/{version}/{path_to_file}".

        Returns:
            A namedtuple consists of domain, entity name, filepath, version, and relative path, where
            filepath = "versions/{version}/{relative_path}"

        Raises:
            SnowflakeMLException: An error occurred when invalid path is given.
        """
        snowurl_match = _SNOWURL_PATTERN.fullmatch(path)
        if not snowurl_match:
            raise snowml_exceptions.SnowflakeMLException(
                error_code=error_codes.SNOWML_INVALID_STAGE,
                original_exception=ValueError(f"Invalid Snow URL: {path}"),
            )

        try:
            domain = snowurl_match.group("domain")
            parsed_name = identifier.parse_schema_level_object_identifier(snowurl_match.group("name"))
            name = identifier.get_schema_level_object_identifier(*parsed_name)
            filepath = snowurl_match.group("path")
            version = snowurl_match.group("version")
            relative_path = snowurl_match.group("relpath") or ""
            logging.debug("Parsed snow URL: %s", snowurl_match.groups())
            # FIXME(dhung): Temporary fix for bug in GS version 8.17
            # Only rewrite when a version was actually parsed; otherwise the f-string would
            # produce the bogus path "versions/None//".
            if cls._IS_BUGGED_VERSION and version is not None:
                filepath = f"versions/{version}//{relative_path}"
            return _SFFileEntityPath(
                domain=domain, name=name, version=version, relative_path=relative_path, filepath=filepath
            )
        except ValueError as e:
            raise snowml_exceptions.SnowflakeMLException(
                error_code=error_codes.SNOWML_INVALID_STAGE,
                original_exception=e,
            )
|
158
|
+
|
159
|
+
|
160
|
+
# Import-time side effect: register SnowFileSystem with fsspec under the PROTOCOL_NAME scheme.
fsspec.register_implementation(PROTOCOL_NAME, SnowFileSystem)
|
snowflake/ml/fileset/stage_fs.py
CHANGED
@@ -2,13 +2,13 @@ import inspect
|
|
2
2
|
import logging
|
3
3
|
import time
|
4
4
|
from dataclasses import dataclass
|
5
|
-
from typing import Any, Dict, List, Optional, Tuple, Union
|
5
|
+
from typing import Any, Dict, List, Optional, Tuple, Union, cast
|
6
6
|
|
7
7
|
import fsspec
|
8
8
|
from fsspec.implementations import http as httpfs
|
9
9
|
|
10
10
|
from snowflake import snowpark
|
11
|
-
from snowflake.connector import connection, errorcode
|
11
|
+
from snowflake.connector import connection, errorcode, errors as snowpark_errors
|
12
12
|
from snowflake.ml._internal import telemetry
|
13
13
|
from snowflake.ml._internal.exceptions import (
|
14
14
|
error_codes,
|
@@ -18,6 +18,7 @@ from snowflake.ml._internal.exceptions import (
|
|
18
18
|
)
|
19
19
|
from snowflake.snowpark import exceptions as snowpark_exceptions
|
20
20
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
21
|
+
from snowflake.snowpark._internal.analyzer import snowflake_plan
|
21
22
|
|
22
23
|
# The default length of how long a presigned url stays active in seconds.
|
23
24
|
# Presigned url here is used to fetch file objects from Snowflake when SFStageFileSystem.open() is called.
|
@@ -144,7 +145,6 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
|
|
144
145
|
project=_PROJECT,
|
145
146
|
func_params_to_log=["detail"],
|
146
147
|
)
|
147
|
-
@snowpark._internal.utils.private_preview(version="0.2.0")
|
148
148
|
def ls(self, path: str, detail: bool = False) -> Union[List[str], List[Dict[str, Any]]]:
|
149
149
|
"""Override fsspec `ls` method. List single "directory" with or without details.
|
150
150
|
|
@@ -168,7 +168,8 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
|
|
168
168
|
try:
|
169
169
|
loc = self.stage_name
|
170
170
|
path = path.lstrip("/")
|
171
|
-
|
171
|
+
async_job: snowpark.AsyncJob = self._session.sql(f"LIST '{loc}/{path}'").collect(block=False)
|
172
|
+
objects: List[snowpark.Row] = _resolve_async_job(async_job)
|
172
173
|
except snowpark_exceptions.SnowparkClientException as e:
|
173
174
|
if e.message.startswith(fileset_errors.ERRNO_DOMAIN_NOT_EXIST):
|
174
175
|
raise snowml_exceptions.SnowflakeMLException(
|
@@ -191,7 +192,6 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
|
|
191
192
|
@telemetry.send_api_usage_telemetry(
|
192
193
|
project=_PROJECT,
|
193
194
|
)
|
194
|
-
@snowpark._internal.utils.private_preview(version="0.2.0")
|
195
195
|
def optimize_read(self, files: Optional[List[str]] = None) -> None:
|
196
196
|
"""Prefetch and cache the presigned urls for all the given files to speed up the read performance.
|
197
197
|
|
@@ -218,7 +218,6 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
|
|
218
218
|
@telemetry.send_api_usage_telemetry(
|
219
219
|
project=_PROJECT,
|
220
220
|
)
|
221
|
-
@snowpark._internal.utils.private_preview(version="0.2.0")
|
222
221
|
def _open(self, path: str, mode: str = "rb", **kwargs: Any) -> fsspec.spec.AbstractBufferedFile:
|
223
222
|
"""Override fsspec `_open` method. Open a file for reading.
|
224
223
|
|
@@ -292,9 +291,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
|
|
292
291
|
original_exception=e,
|
293
292
|
)
|
294
293
|
|
295
|
-
def _parse_list_result(
|
296
|
-
self, list_result: List[Tuple[str, int, str, str]], search_path: str
|
297
|
-
) -> List[Dict[str, Any]]:
|
294
|
+
def _parse_list_result(self, list_result: List[snowpark.Row], search_path: str) -> List[Dict[str, Any]]:
|
298
295
|
"""Convert the result from LIST query to the expected format of fsspec ls() method.
|
299
296
|
|
300
297
|
Note that Snowflake LIST query has different behavior with ls(). LIST query will return all the stage files
|
@@ -315,7 +312,8 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
|
|
315
312
|
"""
|
316
313
|
files: Dict[str, Dict[str, Any]] = {}
|
317
314
|
search_path = search_path.strip("/")
|
318
|
-
for
|
315
|
+
for row in list_result:
|
316
|
+
name, size, md5, last_modified = row["name"], row["size"], row["md5"], row["last_modified"]
|
319
317
|
obj_path = self._stage_path_to_relative_path(name)
|
320
318
|
if obj_path == search_path:
|
321
319
|
# If there is a exact match, then the matched object will always be a file object.
|
@@ -411,3 +409,20 @@ def _match_error_code(ex: snowpark_exceptions.SnowparkSQLException, error_code:
|
|
411
409
|
# Snowpark writes error code to message instead of populating e.error_code
|
412
410
|
error_code_str = str(error_code)
|
413
411
|
return ex.error_code == error_code_str or error_code_str in ex.message
|
412
|
+
|
413
|
+
|
414
|
+
@snowflake_plan.SnowflakePlan.Decorator.wrap_exception  # type: ignore[misc]
def _resolve_async_job(async_job: snowpark.AsyncJob) -> List[snowpark.Row]:
    """Wait for an async Snowpark job and return its rows.

    The call is wrapped by `wrap_exception` so that Snowpark exceptions raised while fetching
    the result are caught and converted by that wrapper.

    Args:
        async_job: The async job whose result should be fetched.

    Returns:
        The job result as a list of rows.

    Raises:
        SnowflakeMLException: With code SNOWML_NOT_FOUND when the connector raises a plain
            `DatabaseError` whose message contains "results are unavailable".
    """
    try:
        return cast(List[snowpark.Row], async_job.result("row"))
    except snowpark_errors.DatabaseError as e:
        # HACK: Snowpark surfaces a generic exception if query doesn't complete immediately
        #       assume it's due to FileNotFound
        # The exact-type check keeps DatabaseError subclasses out of this path.
        is_generic_unavailable = type(e) is snowpark_errors.DatabaseError and "results are unavailable" in str(e)
        if not is_generic_unavailable:
            raise
        raise snowml_exceptions.SnowflakeMLException(
            error_code=error_codes.SNOWML_NOT_FOUND,
            original_exception=fileset_errors.StageNotFoundError("Query failed."),
        ) from e
|
snowflake/ml/model/__init__.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
from snowflake.ml.model._client.model.model_impl import Model
|
2
|
-
from snowflake.ml.model._client.model.model_version_impl import ModelVersion
|
2
|
+
from snowflake.ml.model._client.model.model_version_impl import ExportMode, ModelVersion
|
3
3
|
from snowflake.ml.model.models.huggingface_pipeline import HuggingFacePipelineModel
|
4
4
|
from snowflake.ml.model.models.llm import LLM, LLMOptions
|
5
5
|
|
6
|
-
__all__ = ["Model", "ModelVersion", "HuggingFacePipelineModel", "LLM", "LLMOptions"]
|
6
|
+
__all__ = ["Model", "ModelVersion", "ExportMode", "HuggingFacePipelineModel", "LLM", "LLMOptions"]
|