snowflake-ml-python 1.5.2__py3-none-any.whl → 1.5.3__py3-none-any.whl
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +26 -5
- snowflake/cortex/_sse_client.py +81 -0
- snowflake/cortex/_util.py +105 -8
- snowflake/ml/_internal/lineage/lineage_utils.py +34 -25
- snowflake/ml/dataset/dataset.py +15 -12
- snowflake/ml/dataset/dataset_factory.py +3 -4
- snowflake/ml/feature_store/feature_store.py +2 -2
- snowflake/ml/model/_client/sql/model_version.py +2 -2
- snowflake/ml/model/_model_composer/model_composer.py +2 -2
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -1
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
- snowflake/ml/model/_signatures/builtins_handler.py +2 -1
- snowflake/ml/model/_signatures/core.py +13 -1
- snowflake/ml/model/_signatures/pandas_handler.py +2 -0
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
- snowflake/ml/model/model_signature.py +2 -0
- snowflake/ml/model/type_hints.py +1 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +58 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +156 -121
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +2 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +38 -18
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +82 -134
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +21 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
- snowflake/ml/modeling/cluster/birch.py +1 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
- snowflake/ml/modeling/cluster/dbscan.py +1 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
- snowflake/ml/modeling/cluster/k_means.py +1 -1
- snowflake/ml/modeling/cluster/mean_shift.py +1 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
- snowflake/ml/modeling/cluster/optics.py +1 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
- snowflake/ml/modeling/compose/column_transformer.py +1 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
- snowflake/ml/modeling/covariance/oas.py +1 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/pca.py +1 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
- snowflake/ml/modeling/framework/base.py +3 -8
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
- snowflake/ml/modeling/impute/knn_imputer.py +1 -1
- snowflake/ml/modeling/impute/missing_indicator.py +1 -1
- snowflake/ml/modeling/impute/simple_imputer.py +8 -4
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/lars.py +1 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/perceptron.py +1 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ridge.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
- snowflake/ml/modeling/manifold/isomap.py +1 -1
- snowflake/ml/modeling/manifold/mds.py +1 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
- snowflake/ml/modeling/manifold/tsne.py +1 -1
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +1 -5
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +1 -5
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +5 -0
- snowflake/ml/modeling/preprocessing/binarizer.py +7 -3
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +7 -2
- snowflake/ml/modeling/preprocessing/label_encoder.py +8 -7
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +7 -3
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +7 -4
- snowflake/ml/modeling/preprocessing/normalizer.py +7 -3
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +10 -2
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +8 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -4
- snowflake/ml/modeling/preprocessing/standard_scaler.py +7 -3
- snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
- snowflake/ml/modeling/svm/linear_svc.py +1 -1
- snowflake/ml/modeling/svm/linear_svr.py +1 -1
- snowflake/ml/modeling/svm/nu_svc.py +1 -1
- snowflake/ml/modeling/svm/nu_svr.py +1 -1
- snowflake/ml/modeling/svm/svc.py +1 -1
- snowflake/ml/modeling/svm/svr.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.3.dist-info}/METADATA +21 -5
- {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.3.dist-info}/RECORD +196 -195
- {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.3.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.3.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.3.dist-info}/top_level.txt +0 -0
snowflake/cortex/_complete.py
CHANGED
@@ -1,7 +1,12 @@
|
|
1
|
-
from typing import Optional, Union
|
1
|
+
from typing import Iterator, Optional, Union
|
2
2
|
|
3
3
|
from snowflake import snowpark
|
4
|
-
from snowflake.cortex._util import
|
4
|
+
from snowflake.cortex._util import (
|
5
|
+
CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
|
6
|
+
call_rest_function,
|
7
|
+
call_sql_function,
|
8
|
+
process_rest_response,
|
9
|
+
)
|
5
10
|
from snowflake.ml._internal import telemetry
|
6
11
|
|
7
12
|
|
@@ -10,19 +15,35 @@ from snowflake.ml._internal import telemetry
|
|
10
15
|
project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
|
11
16
|
)
|
12
17
|
def Complete(
|
13
|
-
model: Union[str, snowpark.Column],
|
14
|
-
|
18
|
+
model: Union[str, snowpark.Column],
|
19
|
+
prompt: Union[str, snowpark.Column],
|
20
|
+
session: Optional[snowpark.Session] = None,
|
21
|
+
use_rest_api_experimental: bool = False,
|
22
|
+
stream: bool = False,
|
23
|
+
) -> Union[str, Iterator[str], snowpark.Column]:
|
15
24
|
"""Complete calls into the LLM inference service to perform completion.
|
16
25
|
|
17
26
|
Args:
|
18
27
|
model: A Column of strings representing model types.
|
19
28
|
prompt: A Column of prompts to send to the LLM.
|
20
29
|
session: The snowpark session to use. Will be inferred by context if not specified.
|
30
|
+
use_rest_api_experimental (bool): Toggles between the use of SQL and REST implementation. This feature is
|
31
|
+
experimental and can be removed at any time.
|
32
|
+
stream (bool): Enables streaming. When enabled, a generator function is returned that provides the streaming
|
33
|
+
output as it is received. Each update is a string containing the new text content since the previous update.
|
34
|
+
The use of streaming requires the experimental use_rest_api_experimental flag to be enabled.
|
35
|
+
|
36
|
+
Raises:
|
37
|
+
ValueError: If `stream` is set to True and `use_rest_api_experimental` is set to False.
|
21
38
|
|
22
39
|
Returns:
|
23
40
|
A column of string responses.
|
24
41
|
"""
|
25
|
-
|
42
|
+
if stream is True and use_rest_api_experimental is False:
|
43
|
+
raise ValueError("If stream is set to True use_rest_api_experimental must also be set to True")
|
44
|
+
if use_rest_api_experimental:
|
45
|
+
response = call_rest_function("complete", model, prompt, session=session, stream=stream)
|
46
|
+
return process_rest_response(response)
|
26
47
|
return _complete_impl("snowflake.cortex.complete", model, prompt, session=session)
|
27
48
|
|
28
49
|
|
@@ -0,0 +1,81 @@
|
|
1
|
+
from typing import Iterator, cast
|
2
|
+
|
3
|
+
import requests
|
4
|
+
|
5
|
+
|
6
|
+
class Event:
|
7
|
+
def __init__(self, event: str = "message", data: str = "") -> None:
|
8
|
+
self.event = event
|
9
|
+
self.data = data
|
10
|
+
|
11
|
+
def __str__(self) -> str:
|
12
|
+
s = f"{self.event} event"
|
13
|
+
if self.data:
|
14
|
+
s += f", {len(self.data)} bytes"
|
15
|
+
else:
|
16
|
+
s += ", no data"
|
17
|
+
return s
|
18
|
+
|
19
|
+
|
20
|
+
class SSEClient:
|
21
|
+
def __init__(self, response: requests.Response) -> None:
|
22
|
+
|
23
|
+
self.response = response
|
24
|
+
|
25
|
+
def _read(self) -> Iterator[str]:
|
26
|
+
|
27
|
+
lines = b""
|
28
|
+
for chunk in self.response:
|
29
|
+
for line in chunk.splitlines(True):
|
30
|
+
lines += line
|
31
|
+
if lines.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
|
32
|
+
yield cast(str, lines)
|
33
|
+
lines = b""
|
34
|
+
if lines:
|
35
|
+
yield cast(str, lines)
|
36
|
+
|
37
|
+
def events(self) -> Iterator[Event]:
|
38
|
+
for raw_event in self._read():
|
39
|
+
event = Event()
|
40
|
+
# splitlines() only uses \r and \n
|
41
|
+
for line in raw_event.splitlines():
|
42
|
+
|
43
|
+
line = cast(bytes, line).decode("utf-8")
|
44
|
+
|
45
|
+
data = line.split(":", 1)
|
46
|
+
field = data[0]
|
47
|
+
|
48
|
+
if len(data) > 1:
|
49
|
+
# "If value starts with a single U+0020 SPACE character,
|
50
|
+
# remove it from value. .strip() would remove all white spaces"
|
51
|
+
if data[1].startswith(" "):
|
52
|
+
value = data[1][1:]
|
53
|
+
else:
|
54
|
+
value = data[1]
|
55
|
+
else:
|
56
|
+
value = ""
|
57
|
+
|
58
|
+
# The data field may come over multiple lines and their values
|
59
|
+
# are concatenated with each other.
|
60
|
+
if field == "data":
|
61
|
+
event.data += value + "\n"
|
62
|
+
elif field == "event":
|
63
|
+
event.event = value
|
64
|
+
|
65
|
+
if not event.data:
|
66
|
+
continue
|
67
|
+
|
68
|
+
# If the data field ends with a newline, remove it.
|
69
|
+
if event.data.endswith("\n"):
|
70
|
+
event.data = event.data[0:-1] # Replace trailing newline - rstrip would remove multiple.
|
71
|
+
|
72
|
+
# Empty event names default to 'message'
|
73
|
+
event.event = event.event or "message"
|
74
|
+
|
75
|
+
if event.event != "message": # ignore anything but “message” or default event
|
76
|
+
continue
|
77
|
+
|
78
|
+
yield event
|
79
|
+
|
80
|
+
def close(self) -> None:
|
81
|
+
self.response.close()
|
snowflake/cortex/_util.py
CHANGED
@@ -1,15 +1,34 @@
|
|
1
|
-
|
1
|
+
import json
|
2
|
+
from typing import Iterator, Optional, Union, cast
|
3
|
+
from urllib.parse import urljoin, urlparse
|
4
|
+
|
5
|
+
import requests
|
2
6
|
|
3
7
|
from snowflake import snowpark
|
8
|
+
from snowflake.cortex._sse_client import SSEClient
|
4
9
|
from snowflake.snowpark import context, functions
|
5
10
|
|
6
11
|
CORTEX_FUNCTIONS_TELEMETRY_PROJECT = "CortexFunctions"
|
7
12
|
|
8
13
|
|
14
|
+
class SSEParseException(Exception):
|
15
|
+
"""This exception is raised when an invalid server sent event is received from the server."""
|
16
|
+
|
17
|
+
pass
|
18
|
+
|
19
|
+
|
20
|
+
class SnowflakeAuthenticationException(Exception):
|
21
|
+
"""This exception is raised when the session object does not have session.connection.rest.token attribute."""
|
22
|
+
|
23
|
+
pass
|
24
|
+
|
25
|
+
|
9
26
|
# Calls a sql function, handling both immediate (e.g. python types) and batch
|
10
27
|
# (e.g. snowpark column and literal type modes).
|
11
28
|
def call_sql_function(
|
12
|
-
function: str,
|
29
|
+
function: str,
|
30
|
+
session: Optional[snowpark.Session],
|
31
|
+
*args: Union[str, snowpark.Column],
|
13
32
|
) -> Union[str, snowpark.Column]:
|
14
33
|
handle_as_column = False
|
15
34
|
for arg in args:
|
@@ -17,21 +36,29 @@ def call_sql_function(
|
|
17
36
|
handle_as_column = True
|
18
37
|
|
19
38
|
if handle_as_column:
|
20
|
-
return cast(Union[str, snowpark.Column],
|
21
|
-
return cast(
|
39
|
+
return cast(Union[str, snowpark.Column], _call_sql_function_column(function, *args))
|
40
|
+
return cast(
|
41
|
+
Union[str, snowpark.Column],
|
42
|
+
_call_sql_function_immediate(function, session, *args),
|
43
|
+
)
|
22
44
|
|
23
45
|
|
24
|
-
def
|
46
|
+
def _call_sql_function_column(function: str, *args: Union[str, snowpark.Column]) -> snowpark.Column:
|
25
47
|
return cast(snowpark.Column, functions.builtin(function)(*args))
|
26
48
|
|
27
49
|
|
28
|
-
def call_sql_function_immediate(
|
29
|
-
function: str,
|
50
|
+
def _call_sql_function_immediate(
|
51
|
+
function: str,
|
52
|
+
session: Optional[snowpark.Session],
|
53
|
+
*args: Union[str, snowpark.Column],
|
30
54
|
) -> str:
|
31
55
|
if session is None:
|
32
56
|
session = context.get_active_session()
|
33
57
|
if session is None:
|
34
|
-
raise
|
58
|
+
raise SnowflakeAuthenticationException(
|
59
|
+
"""Session required. Provide the session through a session=... argument or ensure an active session is
|
60
|
+
available in your environment."""
|
61
|
+
)
|
35
62
|
|
36
63
|
lit_args = []
|
37
64
|
for arg in args:
|
@@ -40,3 +67,73 @@ def call_sql_function_immediate(
|
|
40
67
|
empty_df = session.create_dataframe([snowpark.Row()])
|
41
68
|
df = empty_df.select(functions.builtin(function)(*lit_args))
|
42
69
|
return cast(str, df.collect()[0][0])
|
70
|
+
|
71
|
+
|
72
|
+
def call_rest_function(
|
73
|
+
function: str,
|
74
|
+
model: Union[str, snowpark.Column],
|
75
|
+
prompt: Union[str, snowpark.Column],
|
76
|
+
session: Optional[snowpark.Session] = None,
|
77
|
+
stream: bool = False,
|
78
|
+
) -> requests.Response:
|
79
|
+
if session is None:
|
80
|
+
session = context.get_active_session()
|
81
|
+
if session is None:
|
82
|
+
raise SnowflakeAuthenticationException(
|
83
|
+
"""Session required. Provide the session through a session=... argument or ensure an active session is
|
84
|
+
available in your environment."""
|
85
|
+
)
|
86
|
+
|
87
|
+
if not hasattr(session.connection.rest, "token"):
|
88
|
+
raise SnowflakeAuthenticationException("Snowflake session error: REST token missing.")
|
89
|
+
|
90
|
+
if session.connection.rest.token is None or session.connection.rest.token == "": # type: ignore[union-attr]
|
91
|
+
raise SnowflakeAuthenticationException("Snowflake session error: REST token is empty.")
|
92
|
+
|
93
|
+
url = urljoin(session.connection.host, f"api/v2/cortex/inference/{function}")
|
94
|
+
if urlparse(url).scheme == "":
|
95
|
+
url = "https://" + url
|
96
|
+
headers = {
|
97
|
+
"Content-Type": "application/json",
|
98
|
+
"Authorization": f'Snowflake Token="{session.connection.rest.token}"', # type: ignore[union-attr]
|
99
|
+
"Accept": "application/json, text/event-stream",
|
100
|
+
}
|
101
|
+
|
102
|
+
data = {
|
103
|
+
"model": model,
|
104
|
+
"messages": [{"content": prompt}],
|
105
|
+
"stream": stream,
|
106
|
+
}
|
107
|
+
|
108
|
+
response = requests.post(
|
109
|
+
url,
|
110
|
+
json=data,
|
111
|
+
headers=headers,
|
112
|
+
stream=stream,
|
113
|
+
)
|
114
|
+
response.raise_for_status()
|
115
|
+
return response
|
116
|
+
|
117
|
+
|
118
|
+
def process_rest_response(response: requests.Response, stream: bool = False) -> Union[str, Iterator[str]]:
|
119
|
+
if not stream:
|
120
|
+
try:
|
121
|
+
message = response.json()["choices"][0]["message"]
|
122
|
+
output = str(message.get("content", ""))
|
123
|
+
return output
|
124
|
+
except (KeyError, IndexError) as e:
|
125
|
+
raise SSEParseException("Failed to parse streamed response.") from e
|
126
|
+
else:
|
127
|
+
return _return_gen(response)
|
128
|
+
|
129
|
+
|
130
|
+
def _return_gen(response: requests.Response) -> Iterator[str]:
|
131
|
+
client = SSEClient(response)
|
132
|
+
for event in client.events():
|
133
|
+
response_loaded = json.loads(event.data)
|
134
|
+
try:
|
135
|
+
delta = response_loaded["choices"][0]["delta"]
|
136
|
+
output = str(delta.get("content", ""))
|
137
|
+
yield output
|
138
|
+
except (KeyError, IndexError) as e:
|
139
|
+
raise SSEParseException("Failed to parse streamed response.") from e
|
@@ -1,21 +1,11 @@
|
|
1
1
|
import copy
|
2
2
|
import functools
|
3
|
-
from typing import Any, Callable, List
|
3
|
+
from typing import Any, Callable, List, Optional
|
4
4
|
|
5
5
|
from snowflake import snowpark
|
6
6
|
from snowflake.ml._internal.lineage import data_source
|
7
7
|
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
def _get_datasources(*args: Any) -> List[data_source.DataSource]:
|
12
|
-
"""Helper method for extracting data sources attribute from DataFrames in an argument list"""
|
13
|
-
result = []
|
14
|
-
for arg in args:
|
15
|
-
srcs = getattr(arg, DATA_SOURCES_ATTR, None)
|
16
|
-
if isinstance(srcs, list) and all(isinstance(s, data_source.DataSource) for s in srcs):
|
17
|
-
result += srcs
|
18
|
-
return result
|
8
|
+
_DATA_SOURCES_ATTR = "_data_sources"
|
19
9
|
|
20
10
|
|
21
11
|
def _wrap_func(
|
@@ -32,6 +22,37 @@ def _wrap_func(
|
|
32
22
|
return wrapped
|
33
23
|
|
34
24
|
|
25
|
+
def _wrap_class_func(fn: Callable[..., snowpark.DataFrame]) -> Callable[..., snowpark.DataFrame]:
|
26
|
+
@functools.wraps(fn)
|
27
|
+
def wrapped(*args: Any, **kwargs: Any) -> snowpark.DataFrame:
|
28
|
+
df = fn(*args, **kwargs)
|
29
|
+
data_sources = get_data_sources(*args, *kwargs.values())
|
30
|
+
if data_sources:
|
31
|
+
patch_dataframe(df, data_sources, inplace=True)
|
32
|
+
return df
|
33
|
+
|
34
|
+
return wrapped
|
35
|
+
|
36
|
+
|
37
|
+
def get_data_sources(*args: Any) -> Optional[List[data_source.DataSource]]:
|
38
|
+
"""Helper method for extracting data sources attribute from DataFrames in an argument list"""
|
39
|
+
result: Optional[List[data_source.DataSource]] = None
|
40
|
+
for arg in args:
|
41
|
+
srcs = getattr(arg, _DATA_SOURCES_ATTR, None)
|
42
|
+
if isinstance(srcs, list) and all(isinstance(s, data_source.DataSource) for s in srcs):
|
43
|
+
if result is None:
|
44
|
+
result = []
|
45
|
+
result += srcs
|
46
|
+
return result
|
47
|
+
|
48
|
+
|
49
|
+
def set_data_sources(obj: Any, data_sources: Optional[List[data_source.DataSource]]) -> None:
|
50
|
+
"""Helper method for attaching data sources to an object"""
|
51
|
+
if data_sources:
|
52
|
+
assert all(isinstance(ds, data_source.DataSource) for ds in data_sources)
|
53
|
+
setattr(obj, _DATA_SOURCES_ATTR, data_sources)
|
54
|
+
|
55
|
+
|
35
56
|
def patch_dataframe(
|
36
57
|
df: snowpark.DataFrame, data_sources: List[data_source.DataSource], inplace: bool = False
|
37
58
|
) -> snowpark.DataFrame:
|
@@ -62,7 +83,7 @@ def patch_dataframe(
|
|
62
83
|
]
|
63
84
|
if not inplace:
|
64
85
|
df = copy.copy(df)
|
65
|
-
|
86
|
+
set_data_sources(df, data_sources)
|
66
87
|
for func in funcs:
|
67
88
|
fn = getattr(df, func, None)
|
68
89
|
if fn is not None:
|
@@ -70,18 +91,6 @@ def patch_dataframe(
|
|
70
91
|
return df
|
71
92
|
|
72
93
|
|
73
|
-
def _wrap_class_func(fn: Callable[..., snowpark.DataFrame]) -> Callable[..., snowpark.DataFrame]:
|
74
|
-
@functools.wraps(fn)
|
75
|
-
def wrapped(*args: Any, **kwargs: Any) -> snowpark.DataFrame:
|
76
|
-
df = fn(*args, **kwargs)
|
77
|
-
data_sources = _get_datasources(*args) + _get_datasources(*kwargs.values())
|
78
|
-
if data_sources:
|
79
|
-
patch_dataframe(df, data_sources, inplace=True)
|
80
|
-
return df
|
81
|
-
|
82
|
-
return wrapped
|
83
|
-
|
84
|
-
|
85
94
|
# Class-level monkey-patches
|
86
95
|
for klass, func_list in {
|
87
96
|
snowpark.DataFrame: [
|
snowflake/ml/dataset/dataset.py
CHANGED
@@ -65,6 +65,20 @@ class DatasetVersion:
|
|
65
65
|
comment: Optional[str] = self._get_property("comment")
|
66
66
|
return comment
|
67
67
|
|
68
|
+
@property
|
69
|
+
def label_cols(self) -> List[str]:
|
70
|
+
metadata = self._get_metadata()
|
71
|
+
if metadata is None or metadata.label_cols is None:
|
72
|
+
return []
|
73
|
+
return metadata.label_cols
|
74
|
+
|
75
|
+
@property
|
76
|
+
def exclude_cols(self) -> List[str]:
|
77
|
+
metadata = self._get_metadata()
|
78
|
+
if metadata is None or metadata.exclude_cols is None:
|
79
|
+
return []
|
80
|
+
return metadata.exclude_cols
|
81
|
+
|
68
82
|
def _get_property(self, property_name: str, default: Any = None) -> Any:
|
69
83
|
if self._properties is None:
|
70
84
|
sql_result = (
|
@@ -91,17 +105,6 @@ class DatasetVersion:
|
|
91
105
|
warnings.warn(f"Metadata parsing failed with error: {e}", UserWarning, stacklevel=2)
|
92
106
|
return self._metadata
|
93
107
|
|
94
|
-
def _get_exclude_cols(self) -> List[str]:
|
95
|
-
metadata = self._get_metadata()
|
96
|
-
if metadata is None:
|
97
|
-
return []
|
98
|
-
cols = []
|
99
|
-
if metadata.exclude_cols:
|
100
|
-
cols.extend(metadata.exclude_cols)
|
101
|
-
if metadata.label_cols:
|
102
|
-
cols.extend(metadata.label_cols)
|
103
|
-
return cols
|
104
|
-
|
105
108
|
def url(self) -> str:
|
106
109
|
"""Returns the URL of the DatasetVersion contents in Snowflake.
|
107
110
|
|
@@ -168,7 +171,7 @@ class Dataset:
|
|
168
171
|
fully_qualified_name=self._fully_qualified_name,
|
169
172
|
version=v.name,
|
170
173
|
url=v.url(),
|
171
|
-
exclude_cols=v._get_exclude_cols(),
|
174
|
+
exclude_cols=(v.label_cols + v.exclude_cols),
|
172
175
|
)
|
173
176
|
],
|
174
177
|
)
|
@@ -16,8 +16,7 @@ def create_from_dataframe(
|
|
16
16
|
**version_kwargs: Any,
|
17
17
|
) -> dataset.Dataset:
|
18
18
|
"""
|
19
|
-
Create a new versioned Dataset from a DataFrame
|
20
|
-
a DatasetReader for the newly created Dataset version.
|
19
|
+
Create a new versioned Dataset from a DataFrame.
|
21
20
|
|
22
21
|
Args:
|
23
22
|
session: The Snowpark Session instance to use.
|
@@ -39,7 +38,7 @@ def create_from_dataframe(
|
|
39
38
|
@telemetry.send_api_usage_telemetry(project=_PROJECT)
|
40
39
|
def load_dataset(session: snowpark.Session, name: str, version: str) -> dataset.Dataset:
|
41
40
|
"""
|
42
|
-
Load a versioned Dataset
|
41
|
+
Load a versioned Dataset.
|
43
42
|
|
44
43
|
Args:
|
45
44
|
session: The Snowpark Session instance to use.
|
@@ -47,7 +46,7 @@ def load_dataset(session: snowpark.Session, name: str, version: str) -> dataset.
|
|
47
46
|
version: The dataset version name.
|
48
47
|
|
49
48
|
Returns:
|
50
|
-
A
|
49
|
+
A Dataset object.
|
51
50
|
"""
|
52
51
|
ds: dataset.Dataset = dataset.Dataset.load(session, name).select_version(version)
|
53
52
|
return ds
|
@@ -920,7 +920,7 @@ class FeatureStore:
|
|
920
920
|
try:
|
921
921
|
if output_type == "table":
|
922
922
|
table_name = f"{name}_{version}"
|
923
|
-
result_df.write.mode("errorifexists").save_as_table(table_name)
|
923
|
+
result_df.write.mode("errorifexists").save_as_table(table_name)
|
924
924
|
ds_df = self._session.table(table_name)
|
925
925
|
return ds_df
|
926
926
|
else:
|
@@ -1762,7 +1762,7 @@ class FeatureStore:
|
|
1762
1762
|
f"""
|
1763
1763
|
SELECT * FROM TABLE(
|
1764
1764
|
{self._config.database}.INFORMATION_SCHEMA.TAG_REFERENCES_INTERNAL(
|
1765
|
-
TAG_NAME => '{_FEATURE_STORE_OBJECT_TAG}'
|
1765
|
+
TAG_NAME => '{self._get_fully_qualified_name(_FEATURE_STORE_OBJECT_TAG)}'
|
1766
1766
|
)
|
1767
1767
|
) LIMIT 1;
|
1768
1768
|
"""
|
@@ -272,7 +272,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
272
272
|
actual_schema_name.identifier(),
|
273
273
|
tmp_table_name,
|
274
274
|
)
|
275
|
-
input_df.write.save_as_table(
|
275
|
+
input_df.write.save_as_table(
|
276
276
|
table_name=INTERMEDIATE_TABLE_NAME,
|
277
277
|
mode="errorifexists",
|
278
278
|
table_type="temporary",
|
@@ -348,7 +348,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
348
348
|
actual_schema_name.identifier(),
|
349
349
|
tmp_table_name,
|
350
350
|
)
|
351
|
-
input_df.write.save_as_table(
|
351
|
+
input_df.write.save_as_table(
|
352
352
|
table_name=INTERMEDIATE_TABLE_NAME,
|
353
353
|
mode="errorifexists",
|
354
354
|
table_type="temporary",
|
@@ -182,9 +182,9 @@ class ModelComposer:
|
|
182
182
|
def _get_data_sources(
|
183
183
|
self, model: model_types.SupportedModelType, sample_input_data: Optional[model_types.SupportedDataType] = None
|
184
184
|
) -> Optional[List[data_source.DataSource]]:
|
185
|
-
data_sources =
|
185
|
+
data_sources = lineage_utils.get_data_sources(model)
|
186
186
|
if not data_sources and sample_input_data is not None:
|
187
|
-
data_sources =
|
187
|
+
data_sources = lineage_utils.get_data_sources(sample_input_data)
|
188
188
|
if isinstance(data_sources, list) and all(isinstance(item, data_source.DataSource) for item in data_sources):
|
189
189
|
return data_sources
|
190
190
|
return None
|
@@ -74,4 +74,6 @@ dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
|
|
74
74
|
class {function_name}:
|
75
75
|
@vectorized(input=pd.DataFrame)
|
76
76
|
def end_partition(self, df: pd.DataFrame) -> pd.DataFrame:
|
77
|
-
|
77
|
+
df.columns = input_cols
|
78
|
+
input_df = df.astype(dtype=dtype_map)
|
79
|
+
return runner(input_df[input_cols])
|
@@ -1,3 +1,4 @@
|
|
1
|
+
import datetime
|
1
2
|
from collections import abc
|
2
3
|
from typing import Literal, Sequence
|
3
4
|
|
@@ -24,7 +25,7 @@ class ListOfBuiltinHandler(base_handler.BaseDataHandler[model_types._SupportedBu
|
|
24
25
|
# String is a Sequence but we take them as an whole
|
25
26
|
if isinstance(element, abc.Sequence) and not isinstance(element, str):
|
26
27
|
can_handle = ListOfBuiltinHandler.can_handle(element)
|
27
|
-
elif not isinstance(element, (int, float, bool, str)):
|
28
|
+
elif not isinstance(element, (int, float, bool, str, datetime.datetime)):
|
28
29
|
can_handle = False
|
29
30
|
break
|
30
31
|
return can_handle
|
@@ -53,6 +53,8 @@ class DataType(Enum):
|
|
53
53
|
STRING = ("string", spt.StringType, np.str_)
|
54
54
|
BYTES = ("bytes", spt.BinaryType, np.bytes_)
|
55
55
|
|
56
|
+
TIMESTAMP_NTZ = ("datetime64[ns]", spt.TimestampType, "datetime64[ns]")
|
57
|
+
|
56
58
|
def as_snowpark_type(self) -> spt.DataType:
|
57
59
|
"""Convert to corresponding Snowpark Type.
|
58
60
|
|
@@ -78,6 +80,13 @@ class DataType(Enum):
|
|
78
80
|
Corresponding DataType.
|
79
81
|
"""
|
80
82
|
np_to_snowml_type_mapping = {i._numpy_type: i for i in DataType}
|
83
|
+
|
84
|
+
# Add datetime types:
|
85
|
+
datetime_res = ["Y", "M", "W", "D", "h", "m", "s", "ms", "us", "ns"]
|
86
|
+
|
87
|
+
for res in datetime_res:
|
88
|
+
np_to_snowml_type_mapping[f"datetime64[{res}]"] = DataType.TIMESTAMP_NTZ
|
89
|
+
|
81
90
|
for potential_type in np_to_snowml_type_mapping.keys():
|
82
91
|
if np.can_cast(np_type, potential_type, casting="no"):
|
83
92
|
# This is used since the same dtype might represented in different ways.
|
@@ -247,9 +256,12 @@ class FeatureSpec(BaseFeatureSpec):
|
|
247
256
|
result_type = spt.ArrayType(result_type)
|
248
257
|
return result_type
|
249
258
|
|
250
|
-
def as_dtype(self) -> npt.DTypeLike:
|
259
|
+
def as_dtype(self) -> Union[npt.DTypeLike, str]:
|
251
260
|
"""Convert to corresponding local Type."""
|
252
261
|
if not self._shape:
|
262
|
+
# scalar dtype: use keys from `np.sctypeDict` to prevent unit-less dtype 'datetime64'
|
263
|
+
if "datetime64" in self._dtype._value:
|
264
|
+
return self._dtype._value
|
253
265
|
return self._dtype._numpy_type
|
254
266
|
return np.object_
|
255
267
|
|
@@ -147,6 +147,8 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
|
|
147
147
|
specs.append(core.FeatureSpec(dtype=core.DataType.STRING, name=ft_name))
|
148
148
|
elif isinstance(data[df_col].iloc[0], bytes):
|
149
149
|
specs.append(core.FeatureSpec(dtype=core.DataType.BYTES, name=ft_name))
|
150
|
+
elif isinstance(data[df_col].iloc[0], np.datetime64):
|
151
|
+
specs.append(core.FeatureSpec(dtype=core.DataType.TIMESTAMP_NTZ, name=ft_name))
|
150
152
|
else:
|
151
153
|
specs.append(core.FeatureSpec(dtype=core.DataType.from_numpy_type(df_col_dtype), name=ft_name))
|
152
154
|
return specs
|
@@ -107,6 +107,9 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
|
|
107
107
|
if not features:
|
108
108
|
features = pandas_handler.PandasDataFrameHandler.infer_signature(df, role="input")
|
109
109
|
# Role will be no effect on the column index. That is to say, the feature name is the actual column name.
|
110
|
+
if keep_order:
|
111
|
+
df = df.reset_index(drop=True)
|
112
|
+
df[infer_template._KEEP_ORDER_COL_NAME] = df.index
|
110
113
|
sp_df = session.create_dataframe(df)
|
111
114
|
column_names = []
|
112
115
|
columns = []
|
@@ -122,7 +125,4 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
|
|
122
125
|
|
123
126
|
sp_df = sp_df.with_columns(column_names, columns)
|
124
127
|
|
125
|
-
if keep_order:
|
126
|
-
sp_df = sp_df.with_column(infer_template._KEEP_ORDER_COL_NAME, F.monotonically_increasing_id())
|
127
|
-
|
128
128
|
return sp_df
|
@@ -168,6 +168,8 @@ def _validate_numpy_array(
|
|
168
168
|
max_v <= np.finfo(feature_type._numpy_type).max # type: ignore[arg-type]
|
169
169
|
and min_v >= np.finfo(feature_type._numpy_type).min # type: ignore[arg-type]
|
170
170
|
)
|
171
|
+
elif feature_type in [core.DataType.TIMESTAMP_NTZ]:
|
172
|
+
return np.issubdtype(arr.dtype, np.datetime64)
|
171
173
|
else:
|
172
174
|
return np.can_cast(arr.dtype, feature_type._numpy_type, casting="no")
|
173
175
|
|