snowflake-ml-python 1.5.1__py3-none-any.whl → 1.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +26 -5
- snowflake/cortex/_sentiment.py +7 -4
- snowflake/cortex/_sse_client.py +81 -0
- snowflake/cortex/_util.py +105 -8
- snowflake/ml/_internal/lineage/lineage_utils.py +34 -25
- snowflake/ml/_internal/utils/temp_file_utils.py +5 -2
- snowflake/ml/dataset/dataset.py +15 -12
- snowflake/ml/dataset/dataset_factory.py +3 -4
- snowflake/ml/feature_store/access_manager.py +34 -30
- snowflake/ml/feature_store/feature_store.py +3 -3
- snowflake/ml/feature_store/feature_view.py +12 -11
- snowflake/ml/fileset/snowfs.py +2 -31
- snowflake/ml/model/_client/ops/model_ops.py +43 -0
- snowflake/ml/model/_client/sql/model_version.py +55 -3
- snowflake/ml/model/_model_composer/model_composer.py +7 -3
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -1
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -3
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -27
- snowflake/ml/model/_signatures/builtins_handler.py +2 -1
- snowflake/ml/model/_signatures/core.py +13 -1
- snowflake/ml/model/_signatures/pandas_handler.py +2 -0
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
- snowflake/ml/model/model_signature.py +2 -0
- snowflake/ml/model/type_hints.py +1 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +58 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +196 -242
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +161 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +38 -18
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +82 -134
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +21 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -2
- snowflake/ml/modeling/cluster/affinity_propagation.py +9 -2
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -2
- snowflake/ml/modeling/cluster/birch.py +9 -2
- snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -2
- snowflake/ml/modeling/cluster/dbscan.py +9 -2
- snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -2
- snowflake/ml/modeling/cluster/k_means.py +9 -2
- snowflake/ml/modeling/cluster/mean_shift.py +9 -2
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -2
- snowflake/ml/modeling/cluster/optics.py +9 -2
- snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -2
- snowflake/ml/modeling/cluster/spectral_clustering.py +9 -2
- snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -2
- snowflake/ml/modeling/compose/column_transformer.py +9 -2
- snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -2
- snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -2
- snowflake/ml/modeling/covariance/empirical_covariance.py +9 -2
- snowflake/ml/modeling/covariance/graphical_lasso.py +9 -2
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -2
- snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -2
- snowflake/ml/modeling/covariance/min_cov_det.py +9 -2
- snowflake/ml/modeling/covariance/oas.py +9 -2
- snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -2
- snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -2
- snowflake/ml/modeling/decomposition/factor_analysis.py +9 -2
- snowflake/ml/modeling/decomposition/fast_ica.py +9 -2
- snowflake/ml/modeling/decomposition/incremental_pca.py +9 -2
- snowflake/ml/modeling/decomposition/kernel_pca.py +9 -2
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -2
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -2
- snowflake/ml/modeling/decomposition/pca.py +9 -2
- snowflake/ml/modeling/decomposition/sparse_pca.py +9 -2
- snowflake/ml/modeling/decomposition/truncated_svd.py +9 -2
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -2
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -2
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -2
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -2
- snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -2
- snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -2
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -2
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -2
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -2
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -2
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -2
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -2
- snowflake/ml/modeling/ensemble/isolation_forest.py +9 -2
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -2
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -2
- snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -2
- snowflake/ml/modeling/ensemble/voting_classifier.py +9 -2
- snowflake/ml/modeling/ensemble/voting_regressor.py +9 -2
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -2
- snowflake/ml/modeling/feature_selection/select_fdr.py +9 -2
- snowflake/ml/modeling/feature_selection/select_fpr.py +9 -2
- snowflake/ml/modeling/feature_selection/select_fwe.py +9 -2
- snowflake/ml/modeling/feature_selection/select_k_best.py +9 -2
- snowflake/ml/modeling/feature_selection/select_percentile.py +9 -2
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -2
- snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -2
- snowflake/ml/modeling/framework/base.py +3 -8
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -2
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -2
- snowflake/ml/modeling/impute/iterative_imputer.py +9 -2
- snowflake/ml/modeling/impute/knn_imputer.py +9 -2
- snowflake/ml/modeling/impute/missing_indicator.py +9 -2
- snowflake/ml/modeling/impute/simple_imputer.py +28 -5
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -2
- snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -2
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -2
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -2
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -2
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -2
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -2
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -2
- snowflake/ml/modeling/linear_model/ard_regression.py +9 -2
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -2
- snowflake/ml/modeling/linear_model/elastic_net.py +9 -2
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -2
- snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -2
- snowflake/ml/modeling/linear_model/huber_regressor.py +9 -2
- snowflake/ml/modeling/linear_model/lars.py +9 -2
- snowflake/ml/modeling/linear_model/lars_cv.py +9 -2
- snowflake/ml/modeling/linear_model/lasso.py +9 -2
- snowflake/ml/modeling/linear_model/lasso_cv.py +9 -2
- snowflake/ml/modeling/linear_model/lasso_lars.py +9 -2
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -2
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -2
- snowflake/ml/modeling/linear_model/linear_regression.py +9 -2
- snowflake/ml/modeling/linear_model/logistic_regression.py +9 -2
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -2
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -2
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -2
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -2
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -2
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -2
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -2
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -2
- snowflake/ml/modeling/linear_model/perceptron.py +9 -2
- snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -2
- snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -2
- snowflake/ml/modeling/linear_model/ridge.py +9 -2
- snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -2
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -2
- snowflake/ml/modeling/linear_model/ridge_cv.py +9 -2
- snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -2
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -2
- snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -2
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -2
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -2
- snowflake/ml/modeling/manifold/isomap.py +9 -2
- snowflake/ml/modeling/manifold/mds.py +9 -2
- snowflake/ml/modeling/manifold/spectral_embedding.py +9 -2
- snowflake/ml/modeling/manifold/tsne.py +9 -2
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -2
- snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -2
- snowflake/ml/modeling/model_selection/grid_search_cv.py +1 -5
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +1 -5
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -2
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -2
- snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -2
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -2
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -2
- snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -2
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -2
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -2
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -2
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -2
- snowflake/ml/modeling/neighbors/kernel_density.py +9 -2
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -2
- snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -2
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -2
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -2
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -2
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -2
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -2
- snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -2
- snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -2
- snowflake/ml/modeling/parameters/enable_anonymous_sproc.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +5 -0
- snowflake/ml/modeling/preprocessing/binarizer.py +7 -3
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +7 -2
- snowflake/ml/modeling/preprocessing/label_encoder.py +8 -7
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +7 -3
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +7 -4
- snowflake/ml/modeling/preprocessing/normalizer.py +7 -3
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +10 -2
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +8 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -2
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -4
- snowflake/ml/modeling/preprocessing/standard_scaler.py +7 -3
- snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -2
- snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -2
- snowflake/ml/modeling/svm/linear_svc.py +9 -2
- snowflake/ml/modeling/svm/linear_svr.py +9 -2
- snowflake/ml/modeling/svm/nu_svc.py +9 -2
- snowflake/ml/modeling/svm/nu_svr.py +9 -2
- snowflake/ml/modeling/svm/svc.py +9 -2
- snowflake/ml/modeling/svm/svr.py +9 -2
- snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -2
- snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -2
- snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -2
- snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -2
- snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -2
- snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -2
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -2
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -2
- snowflake/ml/registry/_manager/model_manager.py +59 -1
- snowflake/ml/registry/registry.py +10 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/METADATA +32 -4
- {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/RECORD +207 -204
- {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/top_level.txt +0 -0
snowflake/cortex/_complete.py
CHANGED
@@ -1,7 +1,12 @@
|
|
1
|
-
from typing import Optional, Union
|
1
|
+
from typing import Iterator, Optional, Union
|
2
2
|
|
3
3
|
from snowflake import snowpark
|
4
|
-
from snowflake.cortex._util import
|
4
|
+
from snowflake.cortex._util import (
|
5
|
+
CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
|
6
|
+
call_rest_function,
|
7
|
+
call_sql_function,
|
8
|
+
process_rest_response,
|
9
|
+
)
|
5
10
|
from snowflake.ml._internal import telemetry
|
6
11
|
|
7
12
|
|
@@ -10,19 +15,35 @@ from snowflake.ml._internal import telemetry
|
|
10
15
|
project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
|
11
16
|
)
|
12
17
|
def Complete(
|
13
|
-
model: Union[str, snowpark.Column],
|
14
|
-
|
18
|
+
model: Union[str, snowpark.Column],
|
19
|
+
prompt: Union[str, snowpark.Column],
|
20
|
+
session: Optional[snowpark.Session] = None,
|
21
|
+
use_rest_api_experimental: bool = False,
|
22
|
+
stream: bool = False,
|
23
|
+
) -> Union[str, Iterator[str], snowpark.Column]:
|
15
24
|
"""Complete calls into the LLM inference service to perform completion.
|
16
25
|
|
17
26
|
Args:
|
18
27
|
model: A Column of strings representing model types.
|
19
28
|
prompt: A Column of prompts to send to the LLM.
|
20
29
|
session: The snowpark session to use. Will be inferred by context if not specified.
|
30
|
+
use_rest_api_experimental (bool): Toggles between the use of SQL and REST implementation. This feature is
|
31
|
+
experimental and can be removed at any time.
|
32
|
+
stream (bool): Enables streaming. When enabled, a generator function is returned that provides the streaming
|
33
|
+
output as it is received. Each update is a string containing the new text content since the previous update.
|
34
|
+
The use of streaming requires the experimental use_rest_api_experimental flag to be enabled.
|
35
|
+
|
36
|
+
Raises:
|
37
|
+
ValueError: If `stream` is set to True and `use_rest_api_experimental` is set to False.
|
21
38
|
|
22
39
|
Returns:
|
23
40
|
A column of string responses.
|
24
41
|
"""
|
25
|
-
|
42
|
+
if stream is True and use_rest_api_experimental is False:
|
43
|
+
raise ValueError("If stream is set to True use_rest_api_experimental must also be set to True")
|
44
|
+
if use_rest_api_experimental:
|
45
|
+
response = call_rest_function("complete", model, prompt, session=session, stream=stream)
|
46
|
+
return process_rest_response(response)
|
26
47
|
return _complete_impl("snowflake.cortex.complete", model, prompt, session=session)
|
27
48
|
|
28
49
|
|
snowflake/cortex/_sentiment.py
CHANGED
@@ -11,7 +11,7 @@ from snowflake.ml._internal import telemetry
|
|
11
11
|
)
|
12
12
|
def Sentiment(
|
13
13
|
text: Union[str, snowpark.Column], session: Optional[snowpark.Session] = None
|
14
|
-
) -> Union[
|
14
|
+
) -> Union[float, snowpark.Column]:
|
15
15
|
"""Sentiment calls into the LLM inference service to perform sentiment analysis on the input text.
|
16
16
|
|
17
17
|
Args:
|
@@ -21,11 +21,14 @@ def Sentiment(
|
|
21
21
|
Returns:
|
22
22
|
A column of floats. 1 represents positive sentiment, -1 represents negative sentiment.
|
23
23
|
"""
|
24
|
-
|
25
24
|
return _sentiment_impl("snowflake.cortex.sentiment", text, session=session)
|
26
25
|
|
27
26
|
|
28
27
|
def _sentiment_impl(
|
29
28
|
function: str, text: Union[str, snowpark.Column], session: Optional[snowpark.Session] = None
|
30
|
-
) -> Union[
|
31
|
-
|
29
|
+
) -> Union[float, snowpark.Column]:
|
30
|
+
|
31
|
+
output = call_sql_function(function, session, text)
|
32
|
+
if isinstance(output, snowpark.Column):
|
33
|
+
return output
|
34
|
+
return float(output)
|
@@ -0,0 +1,81 @@
|
|
1
|
+
from typing import Iterator, cast
|
2
|
+
|
3
|
+
import requests
|
4
|
+
|
5
|
+
|
6
|
+
class Event:
|
7
|
+
def __init__(self, event: str = "message", data: str = "") -> None:
|
8
|
+
self.event = event
|
9
|
+
self.data = data
|
10
|
+
|
11
|
+
def __str__(self) -> str:
|
12
|
+
s = f"{self.event} event"
|
13
|
+
if self.data:
|
14
|
+
s += f", {len(self.data)} bytes"
|
15
|
+
else:
|
16
|
+
s += ", no data"
|
17
|
+
return s
|
18
|
+
|
19
|
+
|
20
|
+
class SSEClient:
|
21
|
+
def __init__(self, response: requests.Response) -> None:
|
22
|
+
|
23
|
+
self.response = response
|
24
|
+
|
25
|
+
def _read(self) -> Iterator[str]:
|
26
|
+
|
27
|
+
lines = b""
|
28
|
+
for chunk in self.response:
|
29
|
+
for line in chunk.splitlines(True):
|
30
|
+
lines += line
|
31
|
+
if lines.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
|
32
|
+
yield cast(str, lines)
|
33
|
+
lines = b""
|
34
|
+
if lines:
|
35
|
+
yield cast(str, lines)
|
36
|
+
|
37
|
+
def events(self) -> Iterator[Event]:
|
38
|
+
for raw_event in self._read():
|
39
|
+
event = Event()
|
40
|
+
# splitlines() only uses \r and \n
|
41
|
+
for line in raw_event.splitlines():
|
42
|
+
|
43
|
+
line = cast(bytes, line).decode("utf-8")
|
44
|
+
|
45
|
+
data = line.split(":", 1)
|
46
|
+
field = data[0]
|
47
|
+
|
48
|
+
if len(data) > 1:
|
49
|
+
# "If value starts with a single U+0020 SPACE character,
|
50
|
+
# remove it from value. .strip() would remove all white spaces"
|
51
|
+
if data[1].startswith(" "):
|
52
|
+
value = data[1][1:]
|
53
|
+
else:
|
54
|
+
value = data[1]
|
55
|
+
else:
|
56
|
+
value = ""
|
57
|
+
|
58
|
+
# The data field may come over multiple lines and their values
|
59
|
+
# are concatenated with each other.
|
60
|
+
if field == "data":
|
61
|
+
event.data += value + "\n"
|
62
|
+
elif field == "event":
|
63
|
+
event.event = value
|
64
|
+
|
65
|
+
if not event.data:
|
66
|
+
continue
|
67
|
+
|
68
|
+
# If the data field ends with a newline, remove it.
|
69
|
+
if event.data.endswith("\n"):
|
70
|
+
event.data = event.data[0:-1] # Replace trailing newline - rstrip would remove multiple.
|
71
|
+
|
72
|
+
# Empty event names default to 'message'
|
73
|
+
event.event = event.event or "message"
|
74
|
+
|
75
|
+
if event.event != "message": # ignore anything but “message” or default event
|
76
|
+
continue
|
77
|
+
|
78
|
+
yield event
|
79
|
+
|
80
|
+
def close(self) -> None:
|
81
|
+
self.response.close()
|
snowflake/cortex/_util.py
CHANGED
@@ -1,15 +1,34 @@
|
|
1
|
-
|
1
|
+
import json
|
2
|
+
from typing import Iterator, Optional, Union, cast
|
3
|
+
from urllib.parse import urljoin, urlparse
|
4
|
+
|
5
|
+
import requests
|
2
6
|
|
3
7
|
from snowflake import snowpark
|
8
|
+
from snowflake.cortex._sse_client import SSEClient
|
4
9
|
from snowflake.snowpark import context, functions
|
5
10
|
|
6
11
|
CORTEX_FUNCTIONS_TELEMETRY_PROJECT = "CortexFunctions"
|
7
12
|
|
8
13
|
|
14
|
+
class SSEParseException(Exception):
|
15
|
+
"""This exception is raised when an invalid server sent event is received from the server."""
|
16
|
+
|
17
|
+
pass
|
18
|
+
|
19
|
+
|
20
|
+
class SnowflakeAuthenticationException(Exception):
|
21
|
+
"""This exception is raised when the session object does not have session.connection.rest.token attribute."""
|
22
|
+
|
23
|
+
pass
|
24
|
+
|
25
|
+
|
9
26
|
# Calls a sql function, handling both immediate (e.g. python types) and batch
|
10
27
|
# (e.g. snowpark column and literal type modes).
|
11
28
|
def call_sql_function(
|
12
|
-
function: str,
|
29
|
+
function: str,
|
30
|
+
session: Optional[snowpark.Session],
|
31
|
+
*args: Union[str, snowpark.Column],
|
13
32
|
) -> Union[str, snowpark.Column]:
|
14
33
|
handle_as_column = False
|
15
34
|
for arg in args:
|
@@ -17,21 +36,29 @@ def call_sql_function(
|
|
17
36
|
handle_as_column = True
|
18
37
|
|
19
38
|
if handle_as_column:
|
20
|
-
return cast(Union[str, snowpark.Column],
|
21
|
-
return cast(
|
39
|
+
return cast(Union[str, snowpark.Column], _call_sql_function_column(function, *args))
|
40
|
+
return cast(
|
41
|
+
Union[str, snowpark.Column],
|
42
|
+
_call_sql_function_immediate(function, session, *args),
|
43
|
+
)
|
22
44
|
|
23
45
|
|
24
|
-
def
|
46
|
+
def _call_sql_function_column(function: str, *args: Union[str, snowpark.Column]) -> snowpark.Column:
|
25
47
|
return cast(snowpark.Column, functions.builtin(function)(*args))
|
26
48
|
|
27
49
|
|
28
|
-
def
|
29
|
-
function: str,
|
50
|
+
def _call_sql_function_immediate(
|
51
|
+
function: str,
|
52
|
+
session: Optional[snowpark.Session],
|
53
|
+
*args: Union[str, snowpark.Column],
|
30
54
|
) -> str:
|
31
55
|
if session is None:
|
32
56
|
session = context.get_active_session()
|
33
57
|
if session is None:
|
34
|
-
raise
|
58
|
+
raise SnowflakeAuthenticationException(
|
59
|
+
"""Session required. Provide the session through a session=... argument or ensure an active session is
|
60
|
+
available in your environment."""
|
61
|
+
)
|
35
62
|
|
36
63
|
lit_args = []
|
37
64
|
for arg in args:
|
@@ -40,3 +67,73 @@ def call_sql_function_immediate(
|
|
40
67
|
empty_df = session.create_dataframe([snowpark.Row()])
|
41
68
|
df = empty_df.select(functions.builtin(function)(*lit_args))
|
42
69
|
return cast(str, df.collect()[0][0])
|
70
|
+
|
71
|
+
|
72
|
+
def call_rest_function(
|
73
|
+
function: str,
|
74
|
+
model: Union[str, snowpark.Column],
|
75
|
+
prompt: Union[str, snowpark.Column],
|
76
|
+
session: Optional[snowpark.Session] = None,
|
77
|
+
stream: bool = False,
|
78
|
+
) -> requests.Response:
|
79
|
+
if session is None:
|
80
|
+
session = context.get_active_session()
|
81
|
+
if session is None:
|
82
|
+
raise SnowflakeAuthenticationException(
|
83
|
+
"""Session required. Provide the session through a session=... argument or ensure an active session is
|
84
|
+
available in your environment."""
|
85
|
+
)
|
86
|
+
|
87
|
+
if not hasattr(session.connection.rest, "token"):
|
88
|
+
raise SnowflakeAuthenticationException("Snowflake session error: REST token missing.")
|
89
|
+
|
90
|
+
if session.connection.rest.token is None or session.connection.rest.token == "": # type: ignore[union-attr]
|
91
|
+
raise SnowflakeAuthenticationException("Snowflake session error: REST token is empty.")
|
92
|
+
|
93
|
+
url = urljoin(session.connection.host, f"api/v2/cortex/inference/{function}")
|
94
|
+
if urlparse(url).scheme == "":
|
95
|
+
url = "https://" + url
|
96
|
+
headers = {
|
97
|
+
"Content-Type": "application/json",
|
98
|
+
"Authorization": f'Snowflake Token="{session.connection.rest.token}"', # type: ignore[union-attr]
|
99
|
+
"Accept": "application/json, text/event-stream",
|
100
|
+
}
|
101
|
+
|
102
|
+
data = {
|
103
|
+
"model": model,
|
104
|
+
"messages": [{"content": prompt}],
|
105
|
+
"stream": stream,
|
106
|
+
}
|
107
|
+
|
108
|
+
response = requests.post(
|
109
|
+
url,
|
110
|
+
json=data,
|
111
|
+
headers=headers,
|
112
|
+
stream=stream,
|
113
|
+
)
|
114
|
+
response.raise_for_status()
|
115
|
+
return response
|
116
|
+
|
117
|
+
|
118
|
+
def process_rest_response(response: requests.Response, stream: bool = False) -> Union[str, Iterator[str]]:
|
119
|
+
if not stream:
|
120
|
+
try:
|
121
|
+
message = response.json()["choices"][0]["message"]
|
122
|
+
output = str(message.get("content", ""))
|
123
|
+
return output
|
124
|
+
except (KeyError, IndexError) as e:
|
125
|
+
raise SSEParseException("Failed to parse streamed response.") from e
|
126
|
+
else:
|
127
|
+
return _return_gen(response)
|
128
|
+
|
129
|
+
|
130
|
+
def _return_gen(response: requests.Response) -> Iterator[str]:
|
131
|
+
client = SSEClient(response)
|
132
|
+
for event in client.events():
|
133
|
+
response_loaded = json.loads(event.data)
|
134
|
+
try:
|
135
|
+
delta = response_loaded["choices"][0]["delta"]
|
136
|
+
output = str(delta.get("content", ""))
|
137
|
+
yield output
|
138
|
+
except (KeyError, IndexError) as e:
|
139
|
+
raise SSEParseException("Failed to parse streamed response.") from e
|
@@ -1,21 +1,11 @@
|
|
1
1
|
import copy
|
2
2
|
import functools
|
3
|
-
from typing import Any, Callable, List
|
3
|
+
from typing import Any, Callable, List, Optional
|
4
4
|
|
5
5
|
from snowflake import snowpark
|
6
6
|
from snowflake.ml._internal.lineage import data_source
|
7
7
|
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
def _get_datasources(*args: Any) -> List[data_source.DataSource]:
|
12
|
-
"""Helper method for extracting data sources attribute from DataFrames in an argument list"""
|
13
|
-
result = []
|
14
|
-
for arg in args:
|
15
|
-
srcs = getattr(arg, DATA_SOURCES_ATTR, None)
|
16
|
-
if isinstance(srcs, list) and all(isinstance(s, data_source.DataSource) for s in srcs):
|
17
|
-
result += srcs
|
18
|
-
return result
|
8
|
+
_DATA_SOURCES_ATTR = "_data_sources"
|
19
9
|
|
20
10
|
|
21
11
|
def _wrap_func(
|
@@ -32,6 +22,37 @@ def _wrap_func(
|
|
32
22
|
return wrapped
|
33
23
|
|
34
24
|
|
25
|
+
def _wrap_class_func(fn: Callable[..., snowpark.DataFrame]) -> Callable[..., snowpark.DataFrame]:
|
26
|
+
@functools.wraps(fn)
|
27
|
+
def wrapped(*args: Any, **kwargs: Any) -> snowpark.DataFrame:
|
28
|
+
df = fn(*args, **kwargs)
|
29
|
+
data_sources = get_data_sources(*args, *kwargs.values())
|
30
|
+
if data_sources:
|
31
|
+
patch_dataframe(df, data_sources, inplace=True)
|
32
|
+
return df
|
33
|
+
|
34
|
+
return wrapped
|
35
|
+
|
36
|
+
|
37
|
+
def get_data_sources(*args: Any) -> Optional[List[data_source.DataSource]]:
|
38
|
+
"""Helper method for extracting data sources attribute from DataFrames in an argument list"""
|
39
|
+
result: Optional[List[data_source.DataSource]] = None
|
40
|
+
for arg in args:
|
41
|
+
srcs = getattr(arg, _DATA_SOURCES_ATTR, None)
|
42
|
+
if isinstance(srcs, list) and all(isinstance(s, data_source.DataSource) for s in srcs):
|
43
|
+
if result is None:
|
44
|
+
result = []
|
45
|
+
result += srcs
|
46
|
+
return result
|
47
|
+
|
48
|
+
|
49
|
+
def set_data_sources(obj: Any, data_sources: Optional[List[data_source.DataSource]]) -> None:
|
50
|
+
"""Helper method for attaching data sources to an object"""
|
51
|
+
if data_sources:
|
52
|
+
assert all(isinstance(ds, data_source.DataSource) for ds in data_sources)
|
53
|
+
setattr(obj, _DATA_SOURCES_ATTR, data_sources)
|
54
|
+
|
55
|
+
|
35
56
|
def patch_dataframe(
|
36
57
|
df: snowpark.DataFrame, data_sources: List[data_source.DataSource], inplace: bool = False
|
37
58
|
) -> snowpark.DataFrame:
|
@@ -62,7 +83,7 @@ def patch_dataframe(
|
|
62
83
|
]
|
63
84
|
if not inplace:
|
64
85
|
df = copy.copy(df)
|
65
|
-
|
86
|
+
set_data_sources(df, data_sources)
|
66
87
|
for func in funcs:
|
67
88
|
fn = getattr(df, func, None)
|
68
89
|
if fn is not None:
|
@@ -70,18 +91,6 @@ def patch_dataframe(
|
|
70
91
|
return df
|
71
92
|
|
72
93
|
|
73
|
-
def _wrap_class_func(fn: Callable[..., snowpark.DataFrame]) -> Callable[..., snowpark.DataFrame]:
|
74
|
-
@functools.wraps(fn)
|
75
|
-
def wrapped(*args: Any, **kwargs: Any) -> snowpark.DataFrame:
|
76
|
-
df = fn(*args, **kwargs)
|
77
|
-
data_sources = _get_datasources(*args) + _get_datasources(*kwargs.values())
|
78
|
-
if data_sources:
|
79
|
-
patch_dataframe(df, data_sources, inplace=True)
|
80
|
-
return df
|
81
|
-
|
82
|
-
return wrapped
|
83
|
-
|
84
|
-
|
85
94
|
# Class-level monkey-patches
|
86
95
|
for klass, func_list in {
|
87
96
|
snowpark.DataFrame: [
|
@@ -8,14 +8,17 @@ from absl.logging import logging
|
|
8
8
|
logger = logging.getLogger(__name__)
|
9
9
|
|
10
10
|
|
11
|
-
def get_temp_file_path() -> str:
|
11
|
+
def get_temp_file_path(prefix: str = "") -> str:
|
12
12
|
"""Returns a new random temp file path.
|
13
13
|
|
14
|
+
Args:
|
15
|
+
prefix: A prefix to the temp file path, this can help add stored file information. Defaults to None.
|
16
|
+
|
14
17
|
Returns:
|
15
18
|
A new temp file path.
|
16
19
|
"""
|
17
20
|
# TODO(snandamuri): Use in-memory filesystem for temp files.
|
18
|
-
local_file = tempfile.NamedTemporaryFile(delete=True)
|
21
|
+
local_file = tempfile.NamedTemporaryFile(prefix=prefix, delete=True)
|
19
22
|
local_file_name = local_file.name
|
20
23
|
local_file.close()
|
21
24
|
return local_file_name
|
snowflake/ml/dataset/dataset.py
CHANGED
@@ -65,6 +65,20 @@ class DatasetVersion:
|
|
65
65
|
comment: Optional[str] = self._get_property("comment")
|
66
66
|
return comment
|
67
67
|
|
68
|
+
@property
|
69
|
+
def label_cols(self) -> List[str]:
|
70
|
+
metadata = self._get_metadata()
|
71
|
+
if metadata is None or metadata.label_cols is None:
|
72
|
+
return []
|
73
|
+
return metadata.label_cols
|
74
|
+
|
75
|
+
@property
|
76
|
+
def exclude_cols(self) -> List[str]:
|
77
|
+
metadata = self._get_metadata()
|
78
|
+
if metadata is None or metadata.exclude_cols is None:
|
79
|
+
return []
|
80
|
+
return metadata.exclude_cols
|
81
|
+
|
68
82
|
def _get_property(self, property_name: str, default: Any = None) -> Any:
|
69
83
|
if self._properties is None:
|
70
84
|
sql_result = (
|
@@ -91,17 +105,6 @@ class DatasetVersion:
|
|
91
105
|
warnings.warn(f"Metadata parsing failed with error: {e}", UserWarning, stacklevel=2)
|
92
106
|
return self._metadata
|
93
107
|
|
94
|
-
def _get_exclude_cols(self) -> List[str]:
|
95
|
-
metadata = self._get_metadata()
|
96
|
-
if metadata is None:
|
97
|
-
return []
|
98
|
-
cols = []
|
99
|
-
if metadata.exclude_cols:
|
100
|
-
cols.extend(metadata.exclude_cols)
|
101
|
-
if metadata.label_cols:
|
102
|
-
cols.extend(metadata.label_cols)
|
103
|
-
return cols
|
104
|
-
|
105
108
|
def url(self) -> str:
|
106
109
|
"""Returns the URL of the DatasetVersion contents in Snowflake.
|
107
110
|
|
@@ -168,7 +171,7 @@ class Dataset:
|
|
168
171
|
fully_qualified_name=self._fully_qualified_name,
|
169
172
|
version=v.name,
|
170
173
|
url=v.url(),
|
171
|
-
exclude_cols=v.
|
174
|
+
exclude_cols=(v.label_cols + v.exclude_cols),
|
172
175
|
)
|
173
176
|
],
|
174
177
|
)
|
@@ -16,8 +16,7 @@ def create_from_dataframe(
|
|
16
16
|
**version_kwargs: Any,
|
17
17
|
) -> dataset.Dataset:
|
18
18
|
"""
|
19
|
-
Create a new versioned Dataset from a DataFrame
|
20
|
-
a DatasetReader for the newly created Dataset version.
|
19
|
+
Create a new versioned Dataset from a DataFrame.
|
21
20
|
|
22
21
|
Args:
|
23
22
|
session: The Snowpark Session instance to use.
|
@@ -39,7 +38,7 @@ def create_from_dataframe(
|
|
39
38
|
@telemetry.send_api_usage_telemetry(project=_PROJECT)
|
40
39
|
def load_dataset(session: snowpark.Session, name: str, version: str) -> dataset.Dataset:
|
41
40
|
"""
|
42
|
-
Load a versioned Dataset
|
41
|
+
Load a versioned Dataset.
|
43
42
|
|
44
43
|
Args:
|
45
44
|
session: The Snowpark Session instance to use.
|
@@ -47,7 +46,7 @@ def load_dataset(session: snowpark.Session, name: str, version: str) -> dataset.
|
|
47
46
|
version: The dataset version name.
|
48
47
|
|
49
48
|
Returns:
|
50
|
-
A
|
49
|
+
A Dataset object.
|
51
50
|
"""
|
52
51
|
ds: dataset.Dataset = dataset.Dataset.load(session, name).select_version(version)
|
53
52
|
return ds
|
@@ -42,6 +42,8 @@ class _SessionInfo:
|
|
42
42
|
# Lists of permissions as tuples of (OBJECT_TYPE, [PRIVILEGES, ...])
|
43
43
|
_PRE_INIT_PRIVILEGES: Dict[_FeatureStoreRole, List[_Privilege]] = {
|
44
44
|
_FeatureStoreRole.PRODUCER: [
|
45
|
+
_Privilege("DATABASE", "{database}", ["USAGE"]),
|
46
|
+
_Privilege("SCHEMA", "{database}.{schema}", ["USAGE"]),
|
45
47
|
_Privilege(
|
46
48
|
"SCHEMA",
|
47
49
|
"{database}.{schema}",
|
@@ -69,8 +71,7 @@ _PRE_INIT_PRIVILEGES: Dict[_FeatureStoreRole, List[_Privilege]] = {
|
|
69
71
|
_Privilege("DYNAMIC TABLE", _ALL_OBJECTS, ["SELECT", "MONITOR"], "SCHEMA {database}.{schema}"),
|
70
72
|
_Privilege("VIEW", _ALL_OBJECTS, ["SELECT", "REFERENCES"], "SCHEMA {database}.{schema}"),
|
71
73
|
_Privilege("TABLE", _ALL_OBJECTS, ["SELECT", "REFERENCES"], "SCHEMA {database}.{schema}"),
|
72
|
-
|
73
|
-
# _Privilege("DATASET", _ALL_OBJECTS, ["USAGE"], "SCHEMA {database}.{schema}"),
|
74
|
+
_Privilege("DATASET", _ALL_OBJECTS, ["USAGE"], "SCHEMA {database}.{schema}"),
|
74
75
|
# User should decide whether they want to grant warehouse usage to CONSUMER
|
75
76
|
# _Privilege("WAREHOUSE", "{warehouse}", ["USAGE"]),
|
76
77
|
],
|
@@ -128,8 +129,7 @@ def _grant_privileges(
|
|
128
129
|
def _configure_pre_init_privileges(
|
129
130
|
session: Session,
|
130
131
|
session_info: _SessionInfo,
|
131
|
-
|
132
|
-
consumer_role: str = "SNOWML_FEATURE_STORE_CONSUMER_RL",
|
132
|
+
roles_to_create: Dict[_FeatureStoreRole, str],
|
133
133
|
) -> None:
|
134
134
|
"""
|
135
135
|
Configure Feature Store role privileges. Must be run with ACCOUNTADMIN
|
@@ -141,8 +141,7 @@ def _configure_pre_init_privileges(
|
|
141
141
|
Args:
|
142
142
|
session: Snowpark Session to interact with Snowflake backend.
|
143
143
|
session_info: Session info like database and schema for the FeatureStore instance.
|
144
|
-
|
145
|
-
consumer_role: Name of consumer role to be configured.
|
144
|
+
roles_to_create: Producer and optional consumer roles to create.
|
146
145
|
"""
|
147
146
|
|
148
147
|
# Create schema if not already exists
|
@@ -159,29 +158,30 @@ def _configure_pre_init_privileges(
|
|
159
158
|
|
160
159
|
# Pass schema ownership from admin to PRODUCER
|
161
160
|
if schema_created:
|
161
|
+
# TODO: we are missing a test case for this code path
|
162
162
|
session.sql(
|
163
|
-
f"GRANT OWNERSHIP ON SCHEMA {session_info.database}.{session_info.schema}
|
163
|
+
f"GRANT OWNERSHIP ON SCHEMA {session_info.database}.{session_info.schema} "
|
164
|
+
f"TO ROLE {roles_to_create[_FeatureStoreRole.PRODUCER]}"
|
164
165
|
).collect()
|
165
166
|
|
166
167
|
# Grant privileges to roles
|
167
|
-
|
168
|
-
|
168
|
+
for role_type, role in roles_to_create.items():
|
169
|
+
_grant_privileges(session, role, _PRE_INIT_PRIVILEGES[role_type], session_info)
|
169
170
|
|
170
171
|
|
171
172
|
def _configure_post_init_privileges(
|
172
173
|
session: Session,
|
173
174
|
session_info: _SessionInfo,
|
174
|
-
|
175
|
-
consumer_role: str = "FS_CONSUMER",
|
175
|
+
roles_to_create: Dict[_FeatureStoreRole, str],
|
176
176
|
) -> None:
|
177
|
-
|
178
|
-
|
177
|
+
for role_type, role in roles_to_create.items():
|
178
|
+
_grant_privileges(session, role, _POST_INIT_PRIVILEGES[role_type], session_info)
|
179
179
|
|
180
180
|
|
181
181
|
def _configure_role_hierarchy(
|
182
182
|
session: Session,
|
183
183
|
producer_role: str,
|
184
|
-
consumer_role: str,
|
184
|
+
consumer_role: Optional[str],
|
185
185
|
) -> None:
|
186
186
|
"""
|
187
187
|
Create Feature Store roles and configure role hierarchy hierarchy. Must be run with
|
@@ -195,18 +195,17 @@ def _configure_role_hierarchy(
|
|
195
195
|
producer_role: Name of producer role to be configured.
|
196
196
|
consumer_role: Name of consumer role to be configured.
|
197
197
|
"""
|
198
|
+
# Create the necessary roles and build role hierarchy
|
198
199
|
producer_role = SqlIdentifier(producer_role)
|
199
|
-
consumer_role = SqlIdentifier(consumer_role)
|
200
|
-
|
201
|
-
# Create the necessary roles
|
202
200
|
session.sql(f"CREATE ROLE IF NOT EXISTS {producer_role}").collect()
|
203
|
-
session.sql(f"CREATE ROLE IF NOT EXISTS {consumer_role}").collect()
|
204
|
-
|
205
|
-
# Build role hierarchy
|
206
|
-
session.sql(f"GRANT ROLE {consumer_role} TO ROLE {producer_role}").collect()
|
207
201
|
session.sql(f"GRANT ROLE {producer_role} TO ROLE SYSADMIN").collect()
|
208
202
|
session.sql(f"GRANT ROLE {producer_role} TO ROLE {session.get_current_role()}").collect()
|
209
203
|
|
204
|
+
if consumer_role is not None:
|
205
|
+
consumer_role = SqlIdentifier(consumer_role)
|
206
|
+
session.sql(f"CREATE ROLE IF NOT EXISTS {consumer_role}").collect()
|
207
|
+
session.sql(f"GRANT ROLE {consumer_role} TO ROLE {producer_role}").collect()
|
208
|
+
|
210
209
|
|
211
210
|
@telemetry.send_api_usage_telemetry(project=_PROJECT)
|
212
211
|
def setup_feature_store(
|
@@ -215,7 +214,7 @@ def setup_feature_store(
|
|
215
214
|
schema: str,
|
216
215
|
warehouse: str,
|
217
216
|
producer_role: str = "FS_PRODUCER",
|
218
|
-
consumer_role: str =
|
217
|
+
consumer_role: Optional[str] = None,
|
219
218
|
) -> FeatureStore:
|
220
219
|
"""
|
221
220
|
Sets up a new Feature Store including role/privilege setup. Must be run with ACCOUNTADMIN
|
@@ -230,7 +229,7 @@ def setup_feature_store(
|
|
230
229
|
schema: Schema to create the FeatureStore instance.
|
231
230
|
warehouse: Default warehouse for Feature Store compute.
|
232
231
|
producer_role: Name of producer role to be configured.
|
233
|
-
consumer_role: Name of consumer role to be configured.
|
232
|
+
consumer_role: Name of consumer role to be configured. If not specified, consumer role won't be created.
|
234
233
|
|
235
234
|
Returns:
|
236
235
|
Feature Store instance.
|
@@ -249,20 +248,25 @@ def setup_feature_store(
|
|
249
248
|
)
|
250
249
|
|
251
250
|
try:
|
251
|
+
roles_to_create = {_FeatureStoreRole.PRODUCER: producer_role}
|
252
|
+
if consumer_role is not None:
|
253
|
+
roles_to_create.update({_FeatureStoreRole.CONSUMER: consumer_role})
|
252
254
|
_configure_role_hierarchy(session, producer_role=producer_role, consumer_role=consumer_role)
|
253
255
|
except exceptions.SnowparkSQLException:
|
254
256
|
# Error can be safely ignored if roles already exist and hierarchy is already built
|
255
|
-
for role in (
|
257
|
+
for _, role in roles_to_create.items():
|
256
258
|
# Ensure roles already exist
|
257
259
|
if session.sql(f"SHOW ROLES LIKE '{role}' STARTS WITH '{role}'").count() == 0:
|
258
260
|
raise
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
261
|
+
|
262
|
+
if consumer_role is not None:
|
263
|
+
# Ensure hierarchy already configured
|
264
|
+
consumer_grants = session.sql(f"SHOW GRANTS ON ROLE {consumer_role}").collect()
|
265
|
+
if not any(r["granted_to"] == "ROLE" and r["grantee_name"] == producer_role for r in consumer_grants):
|
266
|
+
raise
|
263
267
|
|
264
268
|
# Do any pre-FeatureStore.__init__() privilege setup
|
265
|
-
_configure_pre_init_privileges(session, session_info,
|
269
|
+
_configure_pre_init_privileges(session, session_info, roles_to_create)
|
266
270
|
|
267
271
|
# Use PRODUCER role to create and operate new Feature Store
|
268
272
|
current_role = session.get_current_role()
|
@@ -274,6 +278,6 @@ def setup_feature_store(
|
|
274
278
|
session.use_role(current_role)
|
275
279
|
|
276
280
|
# Do any post-FeatureStore.__init__() privilege setup
|
277
|
-
_configure_post_init_privileges(session, session_info,
|
281
|
+
_configure_post_init_privileges(session, session_info, roles_to_create)
|
278
282
|
|
279
283
|
return fs
|
@@ -920,7 +920,7 @@ class FeatureStore:
|
|
920
920
|
try:
|
921
921
|
if output_type == "table":
|
922
922
|
table_name = f"{name}_{version}"
|
923
|
-
result_df.write.mode("errorifexists").save_as_table(table_name)
|
923
|
+
result_df.write.mode("errorifexists").save_as_table(table_name)
|
924
924
|
ds_df = self._session.table(table_name)
|
925
925
|
return ds_df
|
926
926
|
else:
|
@@ -1761,8 +1761,8 @@ class FeatureStore:
|
|
1761
1761
|
self._session.sql(
|
1762
1762
|
f"""
|
1763
1763
|
SELECT * FROM TABLE(
|
1764
|
-
INFORMATION_SCHEMA.TAG_REFERENCES_INTERNAL(
|
1765
|
-
TAG_NAME => '{_FEATURE_STORE_OBJECT_TAG}'
|
1764
|
+
{self._config.database}.INFORMATION_SCHEMA.TAG_REFERENCES_INTERNAL(
|
1765
|
+
TAG_NAME => '{self._get_fully_qualified_name(_FEATURE_STORE_OBJECT_TAG)}'
|
1766
1766
|
)
|
1767
1767
|
) LIMIT 1;
|
1768
1768
|
"""
|