snowflake-ml-python 1.5.3__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -1
- snowflake/cortex/_classify_text.py +36 -0
- snowflake/cortex/_complete.py +281 -21
- snowflake/cortex/_extract_answer.py +0 -1
- snowflake/cortex/_sentiment.py +0 -1
- snowflake/cortex/_summarize.py +0 -1
- snowflake/cortex/_translate.py +0 -1
- snowflake/cortex/_util.py +12 -85
- snowflake/ml/_internal/container_services/image_registry/http_client.py +10 -3
- snowflake/ml/_internal/container_services/image_registry/imagelib.py +23 -10
- snowflake/ml/_internal/container_services/image_registry/registry_client.py +7 -1
- snowflake/ml/_internal/exceptions/dataset_errors.py +7 -7
- snowflake/ml/_internal/exceptions/fileset_errors.py +3 -3
- snowflake/ml/_internal/exceptions/sql_error_codes.py +6 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +4 -4
- snowflake/ml/_internal/telemetry.py +38 -2
- snowflake/ml/_internal/utils/identifier.py +14 -0
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +15 -4
- snowflake/ml/data/_internal/arrow_ingestor.py +228 -0
- snowflake/ml/data/_internal/ingestor_utils.py +58 -0
- snowflake/ml/data/data_connector.py +133 -0
- snowflake/ml/data/data_ingestor.py +28 -0
- snowflake/ml/data/data_source.py +23 -0
- snowflake/ml/dataset/dataset.py +39 -32
- snowflake/ml/dataset/dataset_reader.py +18 -118
- snowflake/ml/feature_store/access_manager.py +7 -1
- snowflake/ml/feature_store/entity.py +19 -2
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +20 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +31 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +24 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +4 -0
- snowflake/ml/feature_store/examples/example_helper.py +240 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +12 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/dropoff_features.py +39 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +58 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +36 -0
- snowflake/ml/feature_store/examples/source_data/fraud_transactions.yaml +29 -0
- snowflake/ml/feature_store/examples/source_data/nyc_yellow_trips.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/winequality_red.yaml +32 -0
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +14 -0
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +29 -0
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +21 -0
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +5 -0
- snowflake/ml/feature_store/feature_store.py +987 -264
- snowflake/ml/feature_store/feature_view.py +228 -13
- snowflake/ml/fileset/embedded_stage_fs.py +25 -21
- snowflake/ml/fileset/fileset.py +2 -2
- snowflake/ml/fileset/snowfs.py +4 -15
- snowflake/ml/fileset/stage_fs.py +24 -18
- snowflake/ml/lineage/__init__.py +3 -0
- snowflake/ml/lineage/lineage_node.py +139 -0
- snowflake/ml/model/_client/model/model_impl.py +47 -14
- snowflake/ml/model/_client/model/model_version_impl.py +82 -2
- snowflake/ml/model/_client/ops/model_ops.py +77 -5
- snowflake/ml/model/_client/sql/model.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +45 -2
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_model_composer/model_composer.py +15 -17
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -17
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
- snowflake/ml/model/_model_composer/model_method/function_generator.py +20 -4
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +3 -32
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +55 -0
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +5 -34
- snowflake/ml/model/_model_composer/model_method/model_method.py +10 -7
- snowflake/ml/model/_packager/model_handlers/_base.py +13 -3
- snowflake/ml/model/_packager/model_handlers/_utils.py +59 -1
- snowflake/ml/model/_packager/model_handlers/catboost.py +44 -2
- snowflake/ml/model/_packager/model_handlers/custom.py +12 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +18 -15
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +70 -2
- snowflake/ml/model/_packager/model_handlers/llm.py +2 -2
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -2
- snowflake/ml/model/_packager/model_handlers/pytorch.py +2 -2
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +2 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -2
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +2 -2
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +2 -2
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +61 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +21 -1
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
- snowflake/ml/model/_packager/model_packager.py +9 -4
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -5
- snowflake/ml/model/custom_model.py +22 -2
- snowflake/ml/model/model_signature.py +4 -4
- snowflake/ml/model/type_hints.py +77 -4
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +3 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +13 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +6 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +1 -0
- snowflake/ml/modeling/cluster/affinity_propagation.py +4 -2
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +4 -2
- snowflake/ml/modeling/cluster/birch.py +4 -2
- snowflake/ml/modeling/cluster/bisecting_k_means.py +4 -2
- snowflake/ml/modeling/cluster/dbscan.py +4 -2
- snowflake/ml/modeling/cluster/feature_agglomeration.py +4 -2
- snowflake/ml/modeling/cluster/k_means.py +4 -2
- snowflake/ml/modeling/cluster/mean_shift.py +4 -2
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +4 -2
- snowflake/ml/modeling/cluster/optics.py +4 -2
- snowflake/ml/modeling/cluster/spectral_biclustering.py +4 -2
- snowflake/ml/modeling/cluster/spectral_clustering.py +4 -2
- snowflake/ml/modeling/cluster/spectral_coclustering.py +4 -2
- snowflake/ml/modeling/compose/column_transformer.py +4 -2
- snowflake/ml/modeling/covariance/elliptic_envelope.py +4 -2
- snowflake/ml/modeling/covariance/empirical_covariance.py +4 -2
- snowflake/ml/modeling/covariance/graphical_lasso.py +4 -2
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +4 -2
- snowflake/ml/modeling/covariance/ledoit_wolf.py +4 -2
- snowflake/ml/modeling/covariance/min_cov_det.py +4 -2
- snowflake/ml/modeling/covariance/oas.py +4 -2
- snowflake/ml/modeling/covariance/shrunk_covariance.py +4 -2
- snowflake/ml/modeling/decomposition/dictionary_learning.py +4 -2
- snowflake/ml/modeling/decomposition/factor_analysis.py +4 -2
- snowflake/ml/modeling/decomposition/fast_ica.py +4 -2
- snowflake/ml/modeling/decomposition/incremental_pca.py +4 -2
- snowflake/ml/modeling/decomposition/kernel_pca.py +4 -2
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +4 -2
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +4 -2
- snowflake/ml/modeling/decomposition/pca.py +4 -2
- snowflake/ml/modeling/decomposition/sparse_pca.py +4 -2
- snowflake/ml/modeling/decomposition/truncated_svd.py +4 -2
- snowflake/ml/modeling/ensemble/isolation_forest.py +4 -2
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +4 -2
- snowflake/ml/modeling/feature_selection/variance_threshold.py +4 -2
- snowflake/ml/modeling/impute/iterative_imputer.py +4 -2
- snowflake/ml/modeling/impute/knn_imputer.py +4 -2
- snowflake/ml/modeling/impute/missing_indicator.py +4 -2
- snowflake/ml/modeling/impute/simple_imputer.py +26 -0
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +4 -2
- snowflake/ml/modeling/kernel_approximation/nystroem.py +4 -2
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +4 -2
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +4 -2
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +4 -2
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +4 -2
- snowflake/ml/modeling/manifold/isomap.py +4 -2
- snowflake/ml/modeling/manifold/mds.py +4 -2
- snowflake/ml/modeling/manifold/spectral_embedding.py +4 -2
- snowflake/ml/modeling/manifold/tsne.py +4 -2
- snowflake/ml/modeling/metrics/ranking.py +3 -0
- snowflake/ml/modeling/metrics/regression.py +3 -0
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +4 -2
- snowflake/ml/modeling/mixture/gaussian_mixture.py +4 -2
- snowflake/ml/modeling/neighbors/kernel_density.py +4 -2
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +4 -2
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +4 -2
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +4 -2
- snowflake/ml/modeling/pipeline/pipeline.py +5 -4
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +43 -9
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +36 -8
- snowflake/ml/modeling/preprocessing/polynomial_features.py +4 -2
- snowflake/ml/registry/_manager/model_manager.py +16 -3
- snowflake/ml/registry/registry.py +100 -13
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/METADATA +81 -7
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/RECORD +165 -139
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/lineage/data_source.py +0 -10
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/top_level.txt +0 -0
snowflake/cortex/__init__.py
CHANGED
@@ -1,11 +1,14 @@
|
|
1
|
-
from snowflake.cortex.
|
1
|
+
from snowflake.cortex._classify_text import ClassifyText
|
2
|
+
from snowflake.cortex._complete import Complete, CompleteOptions
|
2
3
|
from snowflake.cortex._extract_answer import ExtractAnswer
|
3
4
|
from snowflake.cortex._sentiment import Sentiment
|
4
5
|
from snowflake.cortex._summarize import Summarize
|
5
6
|
from snowflake.cortex._translate import Translate
|
6
7
|
|
7
8
|
__all__ = [
|
9
|
+
"ClassifyText",
|
8
10
|
"Complete",
|
11
|
+
"CompleteOptions",
|
9
12
|
"ExtractAnswer",
|
10
13
|
"Sentiment",
|
11
14
|
"Summarize",
|
@@ -0,0 +1,36 @@
|
|
1
|
+
from typing import List, Optional, Union
|
2
|
+
|
3
|
+
from snowflake import snowpark
|
4
|
+
from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_function
|
5
|
+
from snowflake.ml._internal import telemetry
|
6
|
+
|
7
|
+
|
8
|
+
@telemetry.send_api_usage_telemetry(
    project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
)
def ClassifyText(
    str_input: Union[str, snowpark.Column],
    categories: Union[List[str], snowpark.Column],
    session: Optional[snowpark.Session] = None,
) -> Union[str, snowpark.Column]:
    """Use the LLM inference service to classify the input text into one of the target categories.

    Args:
        str_input: The text to classify, either a literal string or a Column of strings.
        categories: The candidate categories, either a list of strings or a Column holding them.
        session: The snowpark session to use. Will be inferred by context if not specified.

    Returns:
        A Column of classification responses when any argument is a Column; otherwise the
        classification result for the single input, as returned by the SQL function.
    """
    return _classify_text_impl("snowflake.cortex.classify_text", str_input, categories, session=session)
|
28
|
+
|
29
|
+
|
30
|
+
def _classify_text_impl(
    function: str,
    str_input: Union[str, snowpark.Column],
    categories: Union[List[str], snowpark.Column],
    session: Optional[snowpark.Session] = None,
) -> Union[str, snowpark.Column]:
    """Invoke the SQL function named ``function`` with the text and its candidate categories.

    Args:
        function: Fully qualified SQL function name to call.
        str_input: Text to classify (literal string or Column).
        categories: Candidate categories (list of strings or Column).
        session: Optional snowpark session; resolution is left to ``call_sql_function``.

    Returns:
        Whatever ``call_sql_function`` produces: a Column or an immediate string result.
    """
    sql_args = (str_input, categories)
    return call_sql_function(function, session, *sql_args)
|
snowflake/cortex/_complete.py
CHANGED
@@ -1,37 +1,299 @@
|
|
1
|
-
|
1
|
+
import json
|
2
|
+
import logging
|
3
|
+
import time
|
4
|
+
from typing import Any, Callable, Iterator, List, Optional, TypedDict, Union, cast
|
5
|
+
from urllib.parse import urlunparse
|
6
|
+
|
7
|
+
import requests
|
8
|
+
from typing_extensions import NotRequired
|
2
9
|
|
3
10
|
from snowflake import snowpark
|
11
|
+
from snowflake.cortex._sse_client import SSEClient
|
4
12
|
from snowflake.cortex._util import (
|
5
13
|
CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
|
6
|
-
|
7
|
-
|
8
|
-
process_rest_response,
|
14
|
+
SnowflakeAuthenticationException,
|
15
|
+
SnowflakeConfigurationException,
|
9
16
|
)
|
10
17
|
from snowflake.ml._internal import telemetry
|
18
|
+
from snowflake.snowpark import context, functions
|
19
|
+
|
20
|
+
logger = logging.getLogger(__name__)
|
21
|
+
|
22
|
+
|
23
|
+
class ConversationMessage(TypedDict):
    """Represents a single message in a conversation with the model."""

    role: str
    """The role of the participant. For example, "user" or "assistant"."""

    content: str
    """The content of the message."""
|
31
|
+
|
32
|
+
|
33
|
+
class CompleteOptions(TypedDict):
    """Options configuring a snowflake.cortex.Complete call.

    Every key is optional (``NotRequired``). In REST mode, each of the keys below that is present
    is copied into the request body, with ``max_tokens`` sent as both ``max_tokens`` and
    ``max_output_tokens``. In SQL mode the whole dict is passed to the SQL function as a literal
    options argument.
    """

    # Upper bound on generated tokens.
    max_tokens: NotRequired[int]
    """ Sets the maximum number of output tokens in the response. Small values can result in
    truncated responses. """
    # Sampling temperature; the attribute docstring below documents the expected 0..1 range.
    temperature: NotRequired[float]
    """ A value from 0 to 1 (inclusive) that controls the randomness of the output of the language
    model. A higher temperature (for example, 0.7) results in more diverse and random output, while a lower
    temperature (such as 0.2) makes the output more deterministic and focused. """

    # Nucleus-sampling cutoff.
    top_p: NotRequired[float]
    """ A value from 0 to 1 (inclusive) that controls the randomness and diversity of the language model,
    generally used as an alternative to temperature. The difference is that top_p restricts the set of possible tokens
    that the model outputs, while temperature influences which tokens are chosen at each step. """
|
48
|
+
|
49
|
+
|
50
|
+
class ResponseParseException(Exception):
    """Raised when a response from the Cortex server cannot be parsed into a message."""
|
54
|
+
|
55
|
+
|
56
|
+
_MAX_RETRY_SECONDS = 30

# HTTP status codes that indicate a transient condition worth retrying.
_RETRY_STATUS_CODES = frozenset({429, 503, 504})


def _parse_retry_after(value: Optional[str]) -> Optional[float]:
    """Parse a delta-seconds ``Retry-After`` header value; return None if absent or unparseable."""
    if value is None:
        return None
    try:
        seconds = float(value)
    except ValueError:
        # Retry-After may also be an HTTP-date; fall back to exponential backoff in that case.
        return None
    return max(seconds, 0.0)


def retry(func: Callable[..., requests.Response]) -> Callable[..., requests.Response]:
    """Retry ``func`` on transient HTTP failures with capped exponential backoff.

    The wrapped callable accepts an extra keyword argument ``deadline`` (seconds since the epoch,
    or None / omitted for no limit). It is consumed here and not forwarded to ``func``.

    Delay before retry i: max(Retry-After header, min(2^i, _MAX_RETRY_SECONDS)).

    Raises:
        TimeoutError: If the deadline has passed, or would pass during the next backoff sleep.
        requests.HTTPError: If the response has a non-2xx status that is not retryable.
    """
    import functools

    @functools.wraps(func)
    def inner(*args: Any, **kwargs: Any) -> requests.Response:
        # kwargs is a fresh dict per call, so popping does not affect the caller.
        deadline = cast(Optional[float], kwargs.pop("deadline", None))
        exp_retry_seconds = 0.5
        while True:
            if deadline is not None and time.time() > deadline:
                raise TimeoutError()
            response = func(*args, **kwargs)
            if 200 <= response.status_code < 300:
                return response
            if response.status_code not in _RETRY_STATUS_CODES:
                response.raise_for_status()
                # raise_for_status() only raises for 4xx/5xx; fail explicitly on anything else
                # (e.g. 1xx/3xx) instead of retrying it forever.
                raise requests.HTTPError(
                    f"unexpected status code {response.status_code}",
                    response=response,
                )
            logger.debug("request failed with status code %s, retrying", response.status_code)

            exp_retry_seconds = min(2 * exp_retry_seconds, _MAX_RETRY_SECONDS)
            retry_seconds = exp_retry_seconds
            retry_after = _parse_retry_after(response.headers.get("retry-after"))
            if retry_after is not None:
                retry_seconds = max(retry_seconds, retry_after)
            if deadline is not None and time.time() + retry_seconds > deadline:
                # The sleep would overrun the deadline; fail now instead of after waiting.
                raise TimeoutError()
            logger.debug("sleeping for %ss before retrying", retry_seconds)
            time.sleep(retry_seconds)

    return inner
|
85
|
+
|
86
|
+
|
87
|
+
@retry
def _call_complete_rest(
    model: str,
    prompt: Union[str, List[ConversationMessage]],
    options: Optional[CompleteOptions] = None,
    session: Optional[snowpark.Session] = None,
    stream: bool = False,
) -> requests.Response:
    """POST a completion request to the Cortex REST inference endpoint.

    Args:
        model: Model name.
        prompt: A single prompt string, or a list of ConversationMessage dicts sent as-is.
        options: Optional CompleteOptions; only max_tokens, temperature and top_p are forwarded.
        session: Snowpark session providing the host and REST token; defaults to the active session.
        stream: Request a server-sent-events stream instead of a single JSON body.

    Returns:
        The raw ``requests.Response`` (the ``@retry`` wrapper handles non-2xx statuses).

    Raises:
        SnowflakeAuthenticationException: If no session is available or its REST token is missing/empty.
        SnowflakeConfigurationException: If the connection does not specify a host.
    """
    session = session or context.get_active_session()
    if session is None:
        raise SnowflakeAuthenticationException(
            """Session required. Provide the session through a session=... argument or ensure an active session is
            available in your environment."""
        )

    if session.connection.host is None or session.connection.host == "":
        raise SnowflakeConfigurationException("Snowflake connection configuration does not specify 'host'")

    if session.connection.rest is None or not hasattr(session.connection.rest, "token"):
        raise SnowflakeAuthenticationException("Snowflake session error: REST token missing.")

    if session.connection.rest.token is None or session.connection.rest.token == "":
        raise SnowflakeAuthenticationException("Snowflake session error: REST token is empty.")

    # Fall back to https when the connection has no usable scheme, so urlunparse never
    # produces a scheme-less "//host/..." URL.
    scheme = getattr(session.connection, "scheme", None) or "https"
    url = urlunparse((scheme, session.connection.host, "api/v2/cortex/inference:complete", "", "", ""))

    headers = {
        "Content-Type": "application/json",
        "Authorization": f'Snowflake Token="{session.connection.rest.token}"',
        "Accept": "application/json, text/event-stream",
    }

    data = {
        "model": model,
        "stream": stream,
    }
    if isinstance(prompt, list):
        data["messages"] = prompt
    else:
        data["messages"] = [{"content": prompt}]

    if options:
        if "max_tokens" in options:
            # Sent under both names so either server-side field spelling is honoured.
            data["max_tokens"] = options["max_tokens"]
            data["max_output_tokens"] = options["max_tokens"]
        if "temperature" in options:
            data["temperature"] = options["temperature"]
        if "top_p" in options:
            data["top_p"] = options["top_p"]

    logger.debug(f"making POST request to {url} (model={model}, stream={stream})")
    return requests.post(
        url,
        json=data,
        headers=headers,
        stream=stream,
    )
|
147
|
+
|
148
|
+
|
149
|
+
def _process_rest_response(
    response: requests.Response,
    stream: bool = False,
    deadline: Optional[float] = None,
) -> Union[str, Iterator[str]]:
    """Turn a successful Cortex REST response into completion text.

    Args:
        response: A 2xx response from the completion endpoint.
        stream: If True, return an iterator over streamed text chunks.
        deadline: Optional epoch deadline enforced while consuming a stream.

    Returns:
        The message content string, or an iterator of content chunks when ``stream`` is True.

    Raises:
        ResponseParseException: If a non-streamed body lacks a string message content.
    """
    if stream:
        return _return_stream_response(response, deadline)

    try:
        content = response.json()["choices"][0]["message"]["content"]
    except (ValueError, KeyError, IndexError, TypeError) as e:
        # ValueError covers a body that is not valid JSON (requests' JSONDecodeError subclasses it);
        # TypeError covers unexpected structure such as a null "choices" entry.
        # Unlike the streaming case, errors are not ignored because a message must be returned.
        raise ResponseParseException("Failed to parse message from response.") from e
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not isinstance(content, str):
        raise ResponseParseException("Failed to parse message from response.")
    return content
|
164
|
+
|
165
|
+
|
166
|
+
def _return_stream_response(response: requests.Response, deadline: Optional[float]) -> Iterator[str]:
    """Yield the text delta from each server-sent event of a streamed completion.

    Args:
        response: A streaming response from the completion endpoint.
        deadline: Optional epoch deadline, checked before each event is processed.

    Yields:
        Each event's ``choices[0].delta.content`` string; malformed events and non-string
        content are skipped.

    Raises:
        TimeoutError: If the deadline passes while the stream is being consumed.
    """
    client = SSEClient(response)
    for event in client.events():
        if deadline is not None and time.time() > deadline:
            raise TimeoutError()
        try:
            content = json.loads(event.data)["choices"][0]["delta"]["content"]
        except (json.JSONDecodeError, KeyError, IndexError, TypeError):
            # For the sake of evolution of the output format,
            # ignore stream messages that don't match the expected format.
            continue
        # Yield outside the try so exceptions thrown into the generator are not swallowed.
        if isinstance(content, str):
            yield content
|
177
|
+
|
178
|
+
|
179
|
+
def _complete_call_sql_function_snowpark(
    function: str, *args: Union[str, snowpark.Column, CompleteOptions]
) -> snowpark.Column:
    """Build a Column expression that applies the SQL builtin ``function`` to ``args``."""
    sql_builtin = functions.builtin(function)
    column_expr = sql_builtin(*args)
    return cast(snowpark.Column, column_expr)
|
183
|
+
|
184
|
+
|
185
|
+
def _complete_call_sql_function_immediate(
    function: str,
    model: str,
    prompt: Union[str, List[ConversationMessage]],
    options: Optional[CompleteOptions],
    session: Optional[snowpark.Session],
) -> str:
    """Evaluate the SQL completion function eagerly on literal arguments and return its result.

    Args:
        function: Fully qualified SQL function name.
        model: Model name literal.
        prompt: A prompt string or a list of ConversationMessage dicts.
        options: Optional CompleteOptions literal.
        session: Snowpark session; defaults to the active session.

    Raises:
        SnowflakeAuthenticationException: If no session is available.
    """
    session = session or context.get_active_session()
    if session is None:
        raise SnowflakeAuthenticationException(
            """Session required. Provide the session through a session=... argument or ensure an active session is
            available in your environment."""
        )

    # https://docs.snowflake.com/en/sql-reference/functions/complete-snowflake-cortex
    # A bare string prompt without options uses the two-argument call; anything else uses
    # (model, conversation list, options dict).
    plain_call = options is None and isinstance(prompt, str)
    if plain_call:
        lit_args = [functions.lit(model), functions.lit(prompt)]
    else:
        conversation = prompt if isinstance(prompt, List) else [{"role": "user", "content": prompt}]
        lit_args = [
            functions.lit(model),
            functions.lit(conversation),
            functions.lit(options or {}),
        ]

    single_row = session.create_dataframe([snowpark.Row()])
    collected = single_row.select(functions.builtin(function)(*lit_args)).collect()
    return cast(str, collected[0][0])
|
220
|
+
|
221
|
+
|
222
|
+
def _complete_sql_impl(
    function: str,
    model: Union[str, snowpark.Column],
    prompt: Union[str, List[ConversationMessage], snowpark.Column],
    options: Optional[Union[CompleteOptions, snowpark.Column]],
    session: Optional[snowpark.Session],
) -> Union[str, snowpark.Column]:
    """Route a SQL-mode completion to the Column builder or the immediate evaluator.

    A Column prompt yields a lazy Column expression. Otherwise the call is evaluated
    immediately, which requires ``model`` and ``options`` to be literals.

    Raises:
        ValueError: If ``model`` or ``options`` is a Column while ``prompt`` is not.
    """
    if not isinstance(prompt, snowpark.Column):
        if isinstance(model, snowpark.Column):
            raise ValueError("'model' cannot be a snowpark.Column when 'prompt' is a string.")
        if isinstance(options, snowpark.Column):
            raise ValueError("'options' cannot be a snowpark.Column when 'prompt' is a string.")
        return _complete_call_sql_function_immediate(function, model, prompt, options, session)

    column_args = (model, prompt) if options is None else (model, prompt, options)
    return _complete_call_sql_function_snowpark(function, *column_args)
|
239
|
+
|
240
|
+
|
241
|
+
def _complete_impl(
    model: Union[str, snowpark.Column],
    prompt: Union[str, List[ConversationMessage], snowpark.Column],
    options: Optional[CompleteOptions] = None,
    session: Optional[snowpark.Session] = None,
    use_rest_api_experimental: bool = False,
    stream: bool = False,
    function: str = "snowflake.cortex.complete",
    timeout: Optional[float] = None,
    deadline: Optional[float] = None,
) -> Union[str, Iterator[str], snowpark.Column]:
    """Dispatch a completion to the REST API or to the SQL function.

    Args:
        model: Model name (must be a string in REST mode).
        prompt: Prompt string, list of ConversationMessage, or Column (SQL mode only).
        options: Optional CompleteOptions.
        session: Optional snowpark session.
        use_rest_api_experimental: Use the REST endpoint instead of SQL.
        stream: Stream the output (REST mode only).
        function: SQL function name used in SQL mode.
        timeout: Seconds from now after which REST retries/streaming stop.
        deadline: Absolute epoch time with the same meaning as ``timeout``.

    Raises:
        ValueError: On conflicting timeout/deadline, unsupported argument types for REST
            mode, or streaming requested outside REST mode.
    """
    if timeout is not None and deadline is not None:
        raise ValueError('only one of "timeout" and "deadline" must be set')
    if timeout is not None:
        deadline = time.time() + timeout
    if use_rest_api_experimental:
        if not isinstance(model, str):
            raise ValueError("in REST mode, 'model' must be a string")
        if not isinstance(prompt, (str, list)):
            raise ValueError("in REST mode, 'prompt' must be a string or a list of ConversationMessage")
        response = _call_complete_rest(model, prompt, options, session=session, stream=stream, deadline=deadline)
        assert response.status_code >= 200 and response.status_code < 300
        # Forward the deadline so a streamed response also honours it while being consumed.
        return _process_rest_response(response, stream=stream, deadline=deadline)
    if stream:
        raise ValueError("streaming can only be enabled in REST mode, set use_rest_api_experimental=True")
    return _complete_sql_impl(function, model, prompt, options, session)
|
11
267
|
|
12
268
|
|
13
|
-
@telemetry.send_api_usage_telemetry(
    project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
)
def Complete(
    model: Union[str, snowpark.Column],
    prompt: Union[str, List[ConversationMessage], snowpark.Column],
    *,
    options: Optional[CompleteOptions] = None,
    session: Optional[snowpark.Session] = None,
    use_rest_api_experimental: bool = False,
    stream: bool = False,
    timeout: Optional[float] = None,
    deadline: Optional[float] = None,
) -> Union[str, Iterator[str], snowpark.Column]:
    """Complete calls into the LLM inference service to perform completion.

    Args:
        model: The model name, as a string or a Column of strings.
        prompt: The prompt: a string, a list of ConversationMessage, or a Column of prompts.
        options: An instance of snowflake.cortex.CompleteOptions.
        session: The snowpark session to use. Will be inferred by context if not specified.
        use_rest_api_experimental (bool): Toggles between the use of SQL and REST implementation. This feature is
            experimental and can be removed at any time.
        stream (bool): Enables streaming. When enabled, a generator function is returned that provides the streaming
            output as it is received. Each update is a string containing the new text content since the previous update.
            The use of streaming requires the experimental use_rest_api_experimental flag to be enabled.
        timeout (float): Timeout in seconds to retry failed REST requests.
        deadline (float): Time in seconds since the epoch (as returned by time.time()) to retry failed REST requests.

    Raises:
        ValueError: If `stream` is set to True and `use_rest_api_experimental` is set to False, if both
            `timeout` and `deadline` are set, or if an argument type is not supported by the selected mode.

    Returns:
        A Column of string responses when `prompt` is a Column; otherwise the completion text,
        or an iterator of text chunks when `stream` is enabled.
    """
    return _complete_impl(
        model,
        prompt,
        options=options,
        session=session,
        use_rest_api_experimental=use_rest_api_experimental,
        stream=stream,
        timeout=timeout,
        deadline=deadline,
    )
|
@@ -5,7 +5,6 @@ from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_
|
|
5
5
|
from snowflake.ml._internal import telemetry
|
6
6
|
|
7
7
|
|
8
|
-
@snowpark._internal.utils.experimental(version="1.0.12")
|
9
8
|
@telemetry.send_api_usage_telemetry(
|
10
9
|
project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
|
11
10
|
)
|
snowflake/cortex/_sentiment.py
CHANGED
@@ -5,7 +5,6 @@ from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_
|
|
5
5
|
from snowflake.ml._internal import telemetry
|
6
6
|
|
7
7
|
|
8
|
-
@snowpark._internal.utils.experimental(version="1.0.12")
|
9
8
|
@telemetry.send_api_usage_telemetry(
|
10
9
|
project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
|
11
10
|
)
|
snowflake/cortex/_summarize.py
CHANGED
@@ -5,7 +5,6 @@ from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_
|
|
5
5
|
from snowflake.ml._internal import telemetry
|
6
6
|
|
7
7
|
|
8
|
-
@snowpark._internal.utils.experimental(version="1.0.12")
|
9
8
|
@telemetry.send_api_usage_telemetry(
|
10
9
|
project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
|
11
10
|
)
|
snowflake/cortex/_translate.py
CHANGED
@@ -5,7 +5,6 @@ from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_
|
|
5
5
|
from snowflake.ml._internal import telemetry
|
6
6
|
|
7
7
|
|
8
|
-
@snowpark._internal.utils.experimental(version="1.0.12")
|
9
8
|
@telemetry.send_api_usage_telemetry(
|
10
9
|
project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
|
11
10
|
)
|
snowflake/cortex/_util.py
CHANGED
@@ -1,24 +1,19 @@
|
|
1
|
-
import
|
2
|
-
from typing import Iterator, Optional, Union, cast
|
3
|
-
from urllib.parse import urljoin, urlparse
|
4
|
-
|
5
|
-
import requests
|
1
|
+
from typing import Dict, List, Optional, Union, cast
|
6
2
|
|
7
3
|
from snowflake import snowpark
|
8
|
-
from snowflake.cortex._sse_client import SSEClient
|
9
4
|
from snowflake.snowpark import context, functions
|
10
5
|
|
11
6
|
CORTEX_FUNCTIONS_TELEMETRY_PROJECT = "CortexFunctions"
|
12
7
|
|
13
8
|
|
14
|
-
class SnowflakeAuthenticationException(Exception):
    """This exception is raised when there is an issue with Snowflake authentication, such as a missing session or REST token."""

    pass
|
18
13
|
|
19
14
|
|
20
|
-
class SnowflakeConfigurationException(Exception):
    """Raised when the Snowflake connection configuration is incomplete or invalid."""
|
24
19
|
|
@@ -28,9 +23,10 @@ class SnowflakeAuthenticationException(Exception):
|
|
28
23
|
def call_sql_function(
|
29
24
|
function: str,
|
30
25
|
session: Optional[snowpark.Session],
|
31
|
-
*args: Union[str, snowpark.Column],
|
26
|
+
*args: Union[str, List[str], snowpark.Column, Dict[str, Union[int, float]]],
|
32
27
|
) -> Union[str, snowpark.Column]:
|
33
28
|
handle_as_column = False
|
29
|
+
|
34
30
|
for arg in args:
|
35
31
|
if isinstance(arg, snowpark.Column):
|
36
32
|
handle_as_column = True
|
@@ -43,17 +39,18 @@ def call_sql_function(
|
|
43
39
|
)
|
44
40
|
|
45
41
|
|
46
|
-
def _call_sql_function_column(
    function: str, *args: Union[str, List[str], snowpark.Column, Dict[str, Union[int, float]]]
) -> snowpark.Column:
    """Return a lazy Column applying the SQL builtin ``function`` to ``args``."""
    expression = functions.builtin(function)(*args)
    return cast(snowpark.Column, expression)
|
48
46
|
|
49
47
|
|
50
48
|
def _call_sql_function_immediate(
|
51
49
|
function: str,
|
52
50
|
session: Optional[snowpark.Session],
|
53
|
-
*args: Union[str, snowpark.Column],
|
51
|
+
*args: Union[str, List[str], snowpark.Column, Dict[str, Union[int, float]]],
|
54
52
|
) -> str:
|
55
|
-
|
56
|
-
session = context.get_active_session()
|
53
|
+
session = session or context.get_active_session()
|
57
54
|
if session is None:
|
58
55
|
raise SnowflakeAuthenticationException(
|
59
56
|
"""Session required. Provide the session through a session=... argument or ensure an active session is
|
@@ -67,73 +64,3 @@ def _call_sql_function_immediate(
|
|
67
64
|
empty_df = session.create_dataframe([snowpark.Row()])
|
68
65
|
df = empty_df.select(functions.builtin(function)(*lit_args))
|
69
66
|
return cast(str, df.collect()[0][0])
|
70
|
-
|
71
|
-
|
72
|
-
def call_rest_function(
    function: str,
    model: Union[str, snowpark.Column],
    prompt: Union[str, snowpark.Column],
    session: Optional[snowpark.Session] = None,
    stream: bool = False,
) -> requests.Response:
    """POST a Cortex inference request for ``function`` over REST and return the raw response.

    NOTE(review): this is the 1.5.3 helper removed by this diff; the new _complete.py
    appears to supersede it with ``_call_complete_rest``, which adds retries, options and
    conversation prompts.

    Raises:
        SnowflakeAuthenticationException: If no session is available or its REST token is missing/empty.
        requests.HTTPError: Via ``raise_for_status`` for 4xx/5xx responses.
    """
    if session is None:
        session = context.get_active_session()
    if session is None:
        raise SnowflakeAuthenticationException(
            """Session required. Provide the session through a session=... argument or ensure an active session is
            available in your environment."""
        )

    if not hasattr(session.connection.rest, "token"):
        raise SnowflakeAuthenticationException("Snowflake session error: REST token missing.")

    if session.connection.rest.token is None or session.connection.rest.token == "":  # type: ignore[union-attr]
        raise SnowflakeAuthenticationException("Snowflake session error: REST token is empty.")

    # NOTE(review): when host has no scheme (e.g. "acct.snowflakecomputing.com"), urljoin treats it
    # as a relative path and replaces it, dropping the host entirely before "https://" is prepended.
    url = urljoin(session.connection.host, f"api/v2/cortex/inference/{function}")
    if urlparse(url).scheme == "":
        url = "https://" + url
    headers = {
        "Content-Type": "application/json",
        "Authorization": f'Snowflake Token="{session.connection.rest.token}"',  # type: ignore[union-attr]
        "Accept": "application/json, text/event-stream",
    }

    data = {
        "model": model,
        "messages": [{"content": prompt}],
        "stream": stream,
    }

    response = requests.post(
        url,
        json=data,
        headers=headers,
        stream=stream,
    )
    # No retry: any 4xx/5xx is raised immediately.
    response.raise_for_status()
    return response
|
116
|
-
|
117
|
-
|
118
|
-
def process_rest_response(response: requests.Response, stream: bool = False) -> Union[str, Iterator[str]]:
    """Extract completion text from a REST response (1.5.3 helper removed by this diff).

    Returns the message content (empty string if absent) or, when ``stream`` is True,
    a generator over streamed deltas.
    """
    if not stream:
        try:
            message = response.json()["choices"][0]["message"]
            output = str(message.get("content", ""))
            return output
        except (KeyError, IndexError) as e:
            # NOTE(review): the message says "streamed" although this is the non-streaming path.
            raise SSEParseException("Failed to parse streamed response.") from e
    else:
        return _return_gen(response)
|
128
|
-
|
129
|
-
|
130
|
-
def _return_gen(response: requests.Response) -> Iterator[str]:
    """Yield the ``delta.content`` text of each server-sent event (1.5.3 helper removed by this diff)."""
    client = SSEClient(response)
    for event in client.events():
        # NOTE(review): a non-JSON event raises json.JSONDecodeError here, outside the try below.
        response_loaded = json.loads(event.data)
        try:
            delta = response_loaded["choices"][0]["delta"]
            output = str(delta.get("content", ""))
            yield output
        except (KeyError, IndexError) as e:
            raise SSEParseException("Failed to parse streamed response.") from e
|
@@ -64,13 +64,20 @@ class ImageRegistryHttpClient:
|
|
64
64
|
operations. For general use of a retryable HTTP client, consider using the "retryable_http" module.
|
65
65
|
"""
|
66
66
|
|
67
|
-
def __init__(self, *, repo_url: str, session: Optional[snowpark.Session] = None, no_cred: bool = False) -> None:
    """Create an HTTP client for the image registry at ``repo_url``.

    Args:
        repo_url: URL of the image repository.
        session: Snowpark session handed to SessionTokenManager for bearer-token auth.
            Required unless ``no_cred`` is True.
        no_cred: When True, requests carry no credentials and no session is needed.

    Raises:
        ValueError: If ``no_cred`` is False and no session is provided.
    """
    self._repo_url = repo_url
    self._retryable_http = retryable_http.get_http_client()
    self._no_cred = no_cred
    # Always define the token attribute so any code path that reads it cannot hit AttributeError.
    self._bearer_token = ""

    if not self._no_cred:
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if session is None:
            raise ValueError("A session is required unless no_cred=True.")
        self._session_token_manager = session_token_manager.SessionTokenManager(session)
|
72
76
|
|
73
77
|
def _with_bearer_token_header(self, headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
|
78
|
+
if self._no_cred:
|
79
|
+
return {} if not headers else headers.copy()
|
80
|
+
|
74
81
|
if not self._bearer_token:
|
75
82
|
self._fetch_bearer_token()
|
76
83
|
assert self._bearer_token
|