snowflake-ml-python 1.5.3__py3-none-any.whl → 1.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +2 -1
- snowflake/cortex/_complete.py +224 -21
- snowflake/cortex/_extract_answer.py +0 -1
- snowflake/cortex/_sentiment.py +0 -1
- snowflake/cortex/_summarize.py +0 -1
- snowflake/cortex/_translate.py +0 -1
- snowflake/cortex/_util.py +12 -85
- snowflake/ml/_internal/container_services/image_registry/http_client.py +10 -3
- snowflake/ml/_internal/container_services/image_registry/imagelib.py +23 -10
- snowflake/ml/_internal/container_services/image_registry/registry_client.py +7 -1
- snowflake/ml/_internal/exceptions/dataset_errors.py +7 -7
- snowflake/ml/_internal/exceptions/fileset_errors.py +3 -3
- snowflake/ml/_internal/exceptions/sql_error_codes.py +6 -0
- snowflake/ml/_internal/telemetry.py +26 -0
- snowflake/ml/_internal/utils/identifier.py +14 -0
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +15 -4
- snowflake/ml/dataset/dataset.py +39 -20
- snowflake/ml/feature_store/feature_store.py +440 -243
- snowflake/ml/feature_store/feature_view.py +61 -9
- snowflake/ml/fileset/embedded_stage_fs.py +25 -21
- snowflake/ml/fileset/fileset.py +2 -2
- snowflake/ml/fileset/snowfs.py +4 -15
- snowflake/ml/fileset/stage_fs.py +6 -8
- snowflake/ml/lineage/__init__.py +3 -0
- snowflake/ml/lineage/lineage_node.py +139 -0
- snowflake/ml/model/_client/model/model_impl.py +47 -14
- snowflake/ml/model/_client/model/model_version_impl.py +82 -2
- snowflake/ml/model/_client/ops/model_ops.py +77 -5
- snowflake/ml/model/_client/sql/model.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +45 -2
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
- snowflake/ml/model/_model_composer/model_composer.py +5 -4
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +7 -1
- snowflake/ml/model/_model_composer/model_method/function_generator.py +17 -1
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +79 -0
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -2
- snowflake/ml/model/_model_composer/model_method/model_method.py +5 -5
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +1 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +2 -2
- snowflake/ml/model/_packager/model_handlers/custom.py +12 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +18 -15
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +2 -2
- snowflake/ml/model/_packager/model_handlers/llm.py +2 -2
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -2
- snowflake/ml/model/_packager/model_handlers/pytorch.py +2 -2
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +2 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -2
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +2 -2
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +2 -2
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +21 -1
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
- snowflake/ml/model/_packager/model_packager.py +9 -4
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
- snowflake/ml/model/custom_model.py +22 -2
- snowflake/ml/model/type_hints.py +73 -4
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +6 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +1 -0
- snowflake/ml/modeling/cluster/affinity_propagation.py +4 -2
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +4 -2
- snowflake/ml/modeling/cluster/birch.py +4 -2
- snowflake/ml/modeling/cluster/bisecting_k_means.py +4 -2
- snowflake/ml/modeling/cluster/dbscan.py +4 -2
- snowflake/ml/modeling/cluster/feature_agglomeration.py +4 -2
- snowflake/ml/modeling/cluster/k_means.py +4 -2
- snowflake/ml/modeling/cluster/mean_shift.py +4 -2
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +4 -2
- snowflake/ml/modeling/cluster/optics.py +4 -2
- snowflake/ml/modeling/cluster/spectral_biclustering.py +4 -2
- snowflake/ml/modeling/cluster/spectral_clustering.py +4 -2
- snowflake/ml/modeling/cluster/spectral_coclustering.py +4 -2
- snowflake/ml/modeling/compose/column_transformer.py +4 -2
- snowflake/ml/modeling/covariance/elliptic_envelope.py +4 -2
- snowflake/ml/modeling/covariance/empirical_covariance.py +4 -2
- snowflake/ml/modeling/covariance/graphical_lasso.py +4 -2
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +4 -2
- snowflake/ml/modeling/covariance/ledoit_wolf.py +4 -2
- snowflake/ml/modeling/covariance/min_cov_det.py +4 -2
- snowflake/ml/modeling/covariance/oas.py +4 -2
- snowflake/ml/modeling/covariance/shrunk_covariance.py +4 -2
- snowflake/ml/modeling/decomposition/dictionary_learning.py +4 -2
- snowflake/ml/modeling/decomposition/factor_analysis.py +4 -2
- snowflake/ml/modeling/decomposition/fast_ica.py +4 -2
- snowflake/ml/modeling/decomposition/incremental_pca.py +4 -2
- snowflake/ml/modeling/decomposition/kernel_pca.py +4 -2
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +4 -2
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +4 -2
- snowflake/ml/modeling/decomposition/pca.py +4 -2
- snowflake/ml/modeling/decomposition/sparse_pca.py +4 -2
- snowflake/ml/modeling/decomposition/truncated_svd.py +4 -2
- snowflake/ml/modeling/ensemble/isolation_forest.py +4 -2
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +4 -2
- snowflake/ml/modeling/feature_selection/variance_threshold.py +4 -2
- snowflake/ml/modeling/impute/iterative_imputer.py +4 -2
- snowflake/ml/modeling/impute/knn_imputer.py +4 -2
- snowflake/ml/modeling/impute/missing_indicator.py +4 -2
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +4 -2
- snowflake/ml/modeling/kernel_approximation/nystroem.py +4 -2
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +4 -2
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +4 -2
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +4 -2
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +4 -2
- snowflake/ml/modeling/manifold/isomap.py +4 -2
- snowflake/ml/modeling/manifold/mds.py +4 -2
- snowflake/ml/modeling/manifold/spectral_embedding.py +4 -2
- snowflake/ml/modeling/manifold/tsne.py +4 -2
- snowflake/ml/modeling/metrics/ranking.py +3 -0
- snowflake/ml/modeling/metrics/regression.py +3 -0
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +4 -2
- snowflake/ml/modeling/mixture/gaussian_mixture.py +4 -2
- snowflake/ml/modeling/neighbors/kernel_density.py +4 -2
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +4 -2
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +4 -2
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +4 -2
- snowflake/ml/modeling/pipeline/pipeline.py +1 -0
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +43 -9
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +36 -8
- snowflake/ml/modeling/preprocessing/polynomial_features.py +4 -2
- snowflake/ml/registry/_manager/model_manager.py +16 -3
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.5.4.dist-info}/METADATA +35 -7
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.5.4.dist-info}/RECORD +131 -127
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.5.4.dist-info}/WHEEL +1 -1
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.5.4.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.5.4.dist-info}/top_level.txt +0 -0
snowflake/cortex/__init__.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
from snowflake.cortex._complete import Complete
|
1
|
+
from snowflake.cortex._complete import Complete, CompleteOptions
|
2
2
|
from snowflake.cortex._extract_answer import ExtractAnswer
|
3
3
|
from snowflake.cortex._sentiment import Sentiment
|
4
4
|
from snowflake.cortex._summarize import Summarize
|
@@ -6,6 +6,7 @@ from snowflake.cortex._translate import Translate
|
|
6
6
|
|
7
7
|
__all__ = [
|
8
8
|
"Complete",
|
9
|
+
"CompleteOptions",
|
9
10
|
"ExtractAnswer",
|
10
11
|
"Sentiment",
|
11
12
|
"Summarize",
|
snowflake/cortex/_complete.py
CHANGED
@@ -1,22 +1,235 @@
|
|
1
|
-
|
1
|
+
import json
|
2
|
+
import logging
|
3
|
+
from typing import Iterator, List, Optional, TypedDict, Union, cast
|
4
|
+
from urllib.parse import urlunparse
|
5
|
+
|
6
|
+
import requests
|
7
|
+
from typing_extensions import NotRequired
|
2
8
|
|
3
9
|
from snowflake import snowpark
|
10
|
+
from snowflake.cortex._sse_client import SSEClient
|
4
11
|
from snowflake.cortex._util import (
|
5
12
|
CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
|
6
|
-
|
7
|
-
|
8
|
-
process_rest_response,
|
13
|
+
SnowflakeAuthenticationException,
|
14
|
+
SnowflakeConfigurationException,
|
9
15
|
)
|
10
16
|
from snowflake.ml._internal import telemetry
|
17
|
+
from snowflake.snowpark import context, functions
|
18
|
+
|
19
|
+
logger = logging.getLogger(__name__)
|
20
|
+
|
21
|
+
|
22
|
+
class ConversationMessage(TypedDict):
|
23
|
+
"""Represents an conversation interaction."""
|
24
|
+
|
25
|
+
role: str
|
26
|
+
"""The role of the participant. For example, "user" or "assistant"."""
|
27
|
+
|
28
|
+
content: str
|
29
|
+
"""The content of the message."""
|
30
|
+
|
31
|
+
|
32
|
+
class CompleteOptions(TypedDict):
|
33
|
+
"""Options configuring a snowflake.cortex.Complete call."""
|
34
|
+
|
35
|
+
max_tokens: NotRequired[int]
|
36
|
+
""" Sets the maximum number of output tokens in the response. Small values can result in
|
37
|
+
truncated responses. """
|
38
|
+
temperature: NotRequired[float]
|
39
|
+
""" A value from 0 to 1 (inclusive) that controls the randomness of the output of the language
|
40
|
+
model. A higher temperature (for example, 0.7) results in more diverse and random output, while a lower
|
41
|
+
temperature (such as 0.2) makes the output more deterministic and focused. """
|
42
|
+
|
43
|
+
top_p: NotRequired[float]
|
44
|
+
""" A value from 0 to 1 (inclusive) that controls the randomness and diversity of the language model,
|
45
|
+
generally used as an alternative to temperature. The difference is that top_p restricts the set of possible tokens
|
46
|
+
that the model outputs, while temperature influences which tokens are chosen at each step. """
|
47
|
+
|
48
|
+
|
49
|
+
class ResponseParseException(Exception):
|
50
|
+
"""This exception is raised when the server response cannot be parsed."""
|
51
|
+
|
52
|
+
pass
|
53
|
+
|
54
|
+
|
55
|
+
def _call_complete_rest(
|
56
|
+
model: str,
|
57
|
+
prompt: Union[str, List[ConversationMessage]],
|
58
|
+
options: Optional[CompleteOptions] = None,
|
59
|
+
session: Optional[snowpark.Session] = None,
|
60
|
+
stream: bool = False,
|
61
|
+
) -> requests.Response:
|
62
|
+
session = session or context.get_active_session()
|
63
|
+
if session is None:
|
64
|
+
raise SnowflakeAuthenticationException(
|
65
|
+
"""Session required. Provide the session through a session=... argument or ensure an active session is
|
66
|
+
available in your environment."""
|
67
|
+
)
|
68
|
+
|
69
|
+
if session.connection.host is None or session.connection.host == "":
|
70
|
+
raise SnowflakeConfigurationException("Snowflake connection configuration does not specify 'host'")
|
71
|
+
|
72
|
+
if session.connection.rest is None or not hasattr(session.connection.rest, "token"):
|
73
|
+
raise SnowflakeAuthenticationException("Snowflake session error: REST token missing.")
|
74
|
+
|
75
|
+
if session.connection.rest.token is None or session.connection.rest.token == "":
|
76
|
+
raise SnowflakeAuthenticationException("Snowflake session error: REST token is empty.")
|
77
|
+
|
78
|
+
scheme = "https"
|
79
|
+
if hasattr(session.connection, "scheme"):
|
80
|
+
scheme = session.connection.scheme
|
81
|
+
url = urlunparse((scheme, session.connection.host, "api/v2/cortex/inference/complete", "", "", ""))
|
82
|
+
|
83
|
+
headers = {
|
84
|
+
"Content-Type": "application/json",
|
85
|
+
"Authorization": f'Snowflake Token="{session.connection.rest.token}"',
|
86
|
+
"Accept": "application/json, text/event-stream",
|
87
|
+
}
|
88
|
+
|
89
|
+
data = {
|
90
|
+
"model": model,
|
91
|
+
"stream": stream,
|
92
|
+
}
|
93
|
+
if isinstance(prompt, List):
|
94
|
+
data["messages"] = prompt
|
95
|
+
else:
|
96
|
+
data["messages"] = [{"content": prompt}]
|
97
|
+
|
98
|
+
if options:
|
99
|
+
if "max_tokens" in options:
|
100
|
+
data["max_tokens"] = options["max_tokens"]
|
101
|
+
data["max_output_tokens"] = options["max_tokens"]
|
102
|
+
if "temperature" in options:
|
103
|
+
data["temperature"] = options["temperature"]
|
104
|
+
if "top_p" in options:
|
105
|
+
data["top_p"] = options["top_p"]
|
106
|
+
|
107
|
+
logger.debug(f"making POST request to {url} (model={model}, stream={stream})")
|
108
|
+
response = requests.post(
|
109
|
+
url,
|
110
|
+
json=data,
|
111
|
+
headers=headers,
|
112
|
+
stream=stream,
|
113
|
+
)
|
114
|
+
response.raise_for_status()
|
115
|
+
return response
|
116
|
+
|
117
|
+
|
118
|
+
def _process_rest_response(response: requests.Response, stream: bool = False) -> Union[str, Iterator[str]]:
|
119
|
+
if stream:
|
120
|
+
return _return_stream_response(response)
|
121
|
+
|
122
|
+
try:
|
123
|
+
content = response.json()["choices"][0]["message"]["content"]
|
124
|
+
assert isinstance(content, str)
|
125
|
+
return content
|
126
|
+
except (KeyError, IndexError, AssertionError) as e:
|
127
|
+
# Unlike the streaming case, errors are not ignored because a message must be returned.
|
128
|
+
raise ResponseParseException("Failed to parse message from response.") from e
|
129
|
+
|
130
|
+
|
131
|
+
def _return_stream_response(response: requests.Response) -> Iterator[str]:
|
132
|
+
client = SSEClient(response)
|
133
|
+
for event in client.events():
|
134
|
+
try:
|
135
|
+
yield json.loads(event.data)["choices"][0]["delta"]["content"]
|
136
|
+
except (json.JSONDecodeError, KeyError, IndexError):
|
137
|
+
# For the sake of evolution of the output format,
|
138
|
+
# ignore stream messages that don't match the expected format.
|
139
|
+
pass
|
140
|
+
|
141
|
+
|
142
|
+
def _complete_call_sql_function_snowpark(
|
143
|
+
function: str, *args: Union[str, snowpark.Column, CompleteOptions]
|
144
|
+
) -> snowpark.Column:
|
145
|
+
return cast(snowpark.Column, functions.builtin(function)(*args))
|
146
|
+
|
147
|
+
|
148
|
+
def _complete_call_sql_function_immediate(
|
149
|
+
function: str,
|
150
|
+
model: str,
|
151
|
+
prompt: Union[str, List[ConversationMessage]],
|
152
|
+
options: Optional[CompleteOptions],
|
153
|
+
session: Optional[snowpark.Session],
|
154
|
+
) -> str:
|
155
|
+
session = session or context.get_active_session()
|
156
|
+
if session is None:
|
157
|
+
raise SnowflakeAuthenticationException(
|
158
|
+
"""Session required. Provide the session through a session=... argument or ensure an active session is
|
159
|
+
available in your environment."""
|
160
|
+
)
|
161
|
+
|
162
|
+
# https://docs.snowflake.com/en/sql-reference/functions/complete-snowflake-cortex
|
163
|
+
if options is not None or not isinstance(prompt, str):
|
164
|
+
if isinstance(prompt, List):
|
165
|
+
prompt_arg = prompt
|
166
|
+
else:
|
167
|
+
prompt_arg = [{"role": "user", "content": prompt}]
|
168
|
+
options = options or {}
|
169
|
+
lit_args = [
|
170
|
+
functions.lit(model),
|
171
|
+
functions.lit(prompt_arg),
|
172
|
+
functions.lit(options),
|
173
|
+
]
|
174
|
+
else:
|
175
|
+
lit_args = [
|
176
|
+
functions.lit(model),
|
177
|
+
functions.lit(prompt),
|
178
|
+
]
|
179
|
+
|
180
|
+
empty_df = session.create_dataframe([snowpark.Row()])
|
181
|
+
df = empty_df.select(functions.builtin(function)(*lit_args))
|
182
|
+
return cast(str, df.collect()[0][0])
|
183
|
+
|
184
|
+
|
185
|
+
def _complete_sql_impl(
|
186
|
+
function: str,
|
187
|
+
model: Union[str, snowpark.Column],
|
188
|
+
prompt: Union[str, List[ConversationMessage], snowpark.Column],
|
189
|
+
options: Optional[Union[CompleteOptions, snowpark.Column]],
|
190
|
+
session: Optional[snowpark.Session],
|
191
|
+
) -> Union[str, snowpark.Column]:
|
192
|
+
if isinstance(prompt, snowpark.Column):
|
193
|
+
if options is not None:
|
194
|
+
return _complete_call_sql_function_snowpark(function, model, prompt, options)
|
195
|
+
else:
|
196
|
+
return _complete_call_sql_function_snowpark(function, model, prompt)
|
197
|
+
if isinstance(model, snowpark.Column):
|
198
|
+
raise ValueError("'model' cannot be a snowpark.Column when 'prompt' is a string.")
|
199
|
+
if isinstance(options, snowpark.Column):
|
200
|
+
raise ValueError("'options' cannot be a snowpark.Column when 'prompt' is a string.")
|
201
|
+
return _complete_call_sql_function_immediate(function, model, prompt, options, session)
|
202
|
+
|
203
|
+
|
204
|
+
def _complete_impl(
|
205
|
+
model: Union[str, snowpark.Column],
|
206
|
+
prompt: Union[str, List[ConversationMessage], snowpark.Column],
|
207
|
+
options: Optional[CompleteOptions] = None,
|
208
|
+
session: Optional[snowpark.Session] = None,
|
209
|
+
use_rest_api_experimental: bool = False,
|
210
|
+
stream: bool = False,
|
211
|
+
function: str = "snowflake.cortex.complete",
|
212
|
+
) -> Union[str, Iterator[str], snowpark.Column]:
|
213
|
+
if use_rest_api_experimental:
|
214
|
+
if not isinstance(model, str):
|
215
|
+
raise ValueError("in REST mode, 'model' must be a string")
|
216
|
+
if not isinstance(prompt, str) and not isinstance(prompt, List):
|
217
|
+
raise ValueError("in REST mode, 'prompt' must be a string or a list of ConversationMessage")
|
218
|
+
response = _call_complete_rest(model, prompt, options, session=session, stream=stream)
|
219
|
+
return _process_rest_response(response, stream=stream)
|
220
|
+
if stream is True:
|
221
|
+
raise ValueError("streaming can only be enabled in REST mode, set use_rest_api_experimental=True")
|
222
|
+
return _complete_sql_impl(function, model, prompt, options, session)
|
11
223
|
|
12
224
|
|
13
|
-
@snowpark._internal.utils.experimental(version="1.0.12")
|
14
225
|
@telemetry.send_api_usage_telemetry(
|
15
226
|
project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
|
16
227
|
)
|
17
228
|
def Complete(
|
18
229
|
model: Union[str, snowpark.Column],
|
19
|
-
prompt: Union[str, snowpark.Column],
|
230
|
+
prompt: Union[str, List[ConversationMessage], snowpark.Column],
|
231
|
+
*,
|
232
|
+
options: Optional[CompleteOptions] = None,
|
20
233
|
session: Optional[snowpark.Session] = None,
|
21
234
|
use_rest_api_experimental: bool = False,
|
22
235
|
stream: bool = False,
|
@@ -26,6 +239,7 @@ def Complete(
|
|
26
239
|
Args:
|
27
240
|
model: A Column of strings representing model types.
|
28
241
|
prompt: A Column of prompts to send to the LLM.
|
242
|
+
options: A instance of snowflake.cortex.CompleteOptions
|
29
243
|
session: The snowpark session to use. Will be inferred by context if not specified.
|
30
244
|
use_rest_api_experimental (bool): Toggles between the use of SQL and REST implementation. This feature is
|
31
245
|
experimental and can be removed at any time.
|
@@ -39,18 +253,7 @@ def Complete(
|
|
39
253
|
Returns:
|
40
254
|
A column of string responses.
|
41
255
|
"""
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
return process_rest_response(response)
|
47
|
-
return _complete_impl("snowflake.cortex.complete", model, prompt, session=session)
|
48
|
-
|
49
|
-
|
50
|
-
def _complete_impl(
|
51
|
-
function: str,
|
52
|
-
model: Union[str, snowpark.Column],
|
53
|
-
prompt: Union[str, snowpark.Column],
|
54
|
-
session: Optional[snowpark.Session] = None,
|
55
|
-
) -> Union[str, snowpark.Column]:
|
56
|
-
return call_sql_function(function, session, model, prompt)
|
256
|
+
try:
|
257
|
+
return _complete_impl(model, prompt, options, session, use_rest_api_experimental, stream)
|
258
|
+
except ValueError as err:
|
259
|
+
raise err
|
@@ -5,7 +5,6 @@ from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_
|
|
5
5
|
from snowflake.ml._internal import telemetry
|
6
6
|
|
7
7
|
|
8
|
-
@snowpark._internal.utils.experimental(version="1.0.12")
|
9
8
|
@telemetry.send_api_usage_telemetry(
|
10
9
|
project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
|
11
10
|
)
|
snowflake/cortex/_sentiment.py
CHANGED
@@ -5,7 +5,6 @@ from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_
|
|
5
5
|
from snowflake.ml._internal import telemetry
|
6
6
|
|
7
7
|
|
8
|
-
@snowpark._internal.utils.experimental(version="1.0.12")
|
9
8
|
@telemetry.send_api_usage_telemetry(
|
10
9
|
project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
|
11
10
|
)
|
snowflake/cortex/_summarize.py
CHANGED
@@ -5,7 +5,6 @@ from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_
|
|
5
5
|
from snowflake.ml._internal import telemetry
|
6
6
|
|
7
7
|
|
8
|
-
@snowpark._internal.utils.experimental(version="1.0.12")
|
9
8
|
@telemetry.send_api_usage_telemetry(
|
10
9
|
project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
|
11
10
|
)
|
snowflake/cortex/_translate.py
CHANGED
@@ -5,7 +5,6 @@ from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_
|
|
5
5
|
from snowflake.ml._internal import telemetry
|
6
6
|
|
7
7
|
|
8
|
-
@snowpark._internal.utils.experimental(version="1.0.12")
|
9
8
|
@telemetry.send_api_usage_telemetry(
|
10
9
|
project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
|
11
10
|
)
|
snowflake/cortex/_util.py
CHANGED
@@ -1,24 +1,19 @@
|
|
1
|
-
import
|
2
|
-
from typing import Iterator, Optional, Union, cast
|
3
|
-
from urllib.parse import urljoin, urlparse
|
4
|
-
|
5
|
-
import requests
|
1
|
+
from typing import Dict, Optional, Union, cast
|
6
2
|
|
7
3
|
from snowflake import snowpark
|
8
|
-
from snowflake.cortex._sse_client import SSEClient
|
9
4
|
from snowflake.snowpark import context, functions
|
10
5
|
|
11
6
|
CORTEX_FUNCTIONS_TELEMETRY_PROJECT = "CortexFunctions"
|
12
7
|
|
13
8
|
|
14
|
-
class
|
15
|
-
"""This exception is raised when
|
9
|
+
class SnowflakeAuthenticationException(Exception):
|
10
|
+
"""This exception is raised when there is an issue with Snowflake's configuration."""
|
16
11
|
|
17
12
|
pass
|
18
13
|
|
19
14
|
|
20
|
-
class
|
21
|
-
"""This exception is raised when
|
15
|
+
class SnowflakeConfigurationException(Exception):
|
16
|
+
"""This exception is raised when there is an issue with Snowflake's configuration."""
|
22
17
|
|
23
18
|
pass
|
24
19
|
|
@@ -28,9 +23,10 @@ class SnowflakeAuthenticationException(Exception):
|
|
28
23
|
def call_sql_function(
|
29
24
|
function: str,
|
30
25
|
session: Optional[snowpark.Session],
|
31
|
-
*args: Union[str, snowpark.Column],
|
26
|
+
*args: Union[str, snowpark.Column, Dict[str, Union[int, float]]],
|
32
27
|
) -> Union[str, snowpark.Column]:
|
33
28
|
handle_as_column = False
|
29
|
+
|
34
30
|
for arg in args:
|
35
31
|
if isinstance(arg, snowpark.Column):
|
36
32
|
handle_as_column = True
|
@@ -43,17 +39,18 @@ def call_sql_function(
|
|
43
39
|
)
|
44
40
|
|
45
41
|
|
46
|
-
def _call_sql_function_column(
|
42
|
+
def _call_sql_function_column(
|
43
|
+
function: str, *args: Union[str, snowpark.Column, Dict[str, Union[int, float]]]
|
44
|
+
) -> snowpark.Column:
|
47
45
|
return cast(snowpark.Column, functions.builtin(function)(*args))
|
48
46
|
|
49
47
|
|
50
48
|
def _call_sql_function_immediate(
|
51
49
|
function: str,
|
52
50
|
session: Optional[snowpark.Session],
|
53
|
-
*args: Union[str, snowpark.Column],
|
51
|
+
*args: Union[str, snowpark.Column, Dict[str, Union[int, float]]],
|
54
52
|
) -> str:
|
55
|
-
|
56
|
-
session = context.get_active_session()
|
53
|
+
session = session or context.get_active_session()
|
57
54
|
if session is None:
|
58
55
|
raise SnowflakeAuthenticationException(
|
59
56
|
"""Session required. Provide the session through a session=... argument or ensure an active session is
|
@@ -67,73 +64,3 @@ def _call_sql_function_immediate(
|
|
67
64
|
empty_df = session.create_dataframe([snowpark.Row()])
|
68
65
|
df = empty_df.select(functions.builtin(function)(*lit_args))
|
69
66
|
return cast(str, df.collect()[0][0])
|
70
|
-
|
71
|
-
|
72
|
-
def call_rest_function(
|
73
|
-
function: str,
|
74
|
-
model: Union[str, snowpark.Column],
|
75
|
-
prompt: Union[str, snowpark.Column],
|
76
|
-
session: Optional[snowpark.Session] = None,
|
77
|
-
stream: bool = False,
|
78
|
-
) -> requests.Response:
|
79
|
-
if session is None:
|
80
|
-
session = context.get_active_session()
|
81
|
-
if session is None:
|
82
|
-
raise SnowflakeAuthenticationException(
|
83
|
-
"""Session required. Provide the session through a session=... argument or ensure an active session is
|
84
|
-
available in your environment."""
|
85
|
-
)
|
86
|
-
|
87
|
-
if not hasattr(session.connection.rest, "token"):
|
88
|
-
raise SnowflakeAuthenticationException("Snowflake session error: REST token missing.")
|
89
|
-
|
90
|
-
if session.connection.rest.token is None or session.connection.rest.token == "": # type: ignore[union-attr]
|
91
|
-
raise SnowflakeAuthenticationException("Snowflake session error: REST token is empty.")
|
92
|
-
|
93
|
-
url = urljoin(session.connection.host, f"api/v2/cortex/inference/{function}")
|
94
|
-
if urlparse(url).scheme == "":
|
95
|
-
url = "https://" + url
|
96
|
-
headers = {
|
97
|
-
"Content-Type": "application/json",
|
98
|
-
"Authorization": f'Snowflake Token="{session.connection.rest.token}"', # type: ignore[union-attr]
|
99
|
-
"Accept": "application/json, text/event-stream",
|
100
|
-
}
|
101
|
-
|
102
|
-
data = {
|
103
|
-
"model": model,
|
104
|
-
"messages": [{"content": prompt}],
|
105
|
-
"stream": stream,
|
106
|
-
}
|
107
|
-
|
108
|
-
response = requests.post(
|
109
|
-
url,
|
110
|
-
json=data,
|
111
|
-
headers=headers,
|
112
|
-
stream=stream,
|
113
|
-
)
|
114
|
-
response.raise_for_status()
|
115
|
-
return response
|
116
|
-
|
117
|
-
|
118
|
-
def process_rest_response(response: requests.Response, stream: bool = False) -> Union[str, Iterator[str]]:
|
119
|
-
if not stream:
|
120
|
-
try:
|
121
|
-
message = response.json()["choices"][0]["message"]
|
122
|
-
output = str(message.get("content", ""))
|
123
|
-
return output
|
124
|
-
except (KeyError, IndexError) as e:
|
125
|
-
raise SSEParseException("Failed to parse streamed response.") from e
|
126
|
-
else:
|
127
|
-
return _return_gen(response)
|
128
|
-
|
129
|
-
|
130
|
-
def _return_gen(response: requests.Response) -> Iterator[str]:
|
131
|
-
client = SSEClient(response)
|
132
|
-
for event in client.events():
|
133
|
-
response_loaded = json.loads(event.data)
|
134
|
-
try:
|
135
|
-
delta = response_loaded["choices"][0]["delta"]
|
136
|
-
output = str(delta.get("content", ""))
|
137
|
-
yield output
|
138
|
-
except (KeyError, IndexError) as e:
|
139
|
-
raise SSEParseException("Failed to parse streamed response.") from e
|
@@ -64,13 +64,20 @@ class ImageRegistryHttpClient:
|
|
64
64
|
operations. For general use of a retryable HTTP client, consider using the "retryable_http" module.
|
65
65
|
"""
|
66
66
|
|
67
|
-
def __init__(self, *, session: snowpark.Session,
|
67
|
+
def __init__(self, *, repo_url: str, session: Optional[snowpark.Session] = None, no_cred: bool = False) -> None:
|
68
68
|
self._repo_url = repo_url
|
69
|
-
self._session_token_manager = session_token_manager.SessionTokenManager(session)
|
70
69
|
self._retryable_http = retryable_http.get_http_client()
|
71
|
-
self.
|
70
|
+
self._no_cred = no_cred
|
71
|
+
|
72
|
+
if not self._no_cred:
|
73
|
+
self._bearer_token = ""
|
74
|
+
assert session is not None
|
75
|
+
self._session_token_manager = session_token_manager.SessionTokenManager(session)
|
72
76
|
|
73
77
|
def _with_bearer_token_header(self, headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
|
78
|
+
if self._no_cred:
|
79
|
+
return {} if not headers else headers.copy()
|
80
|
+
|
74
81
|
if not self._bearer_token:
|
75
82
|
self._fetch_bearer_token()
|
76
83
|
assert self._bearer_token
|
@@ -13,6 +13,7 @@ This library only supports a limited set of features:
|
|
13
13
|
It's recommended to use this library to copy previously tested images using sha256 to avoid surprises
|
14
14
|
with respect to compatibility.
|
15
15
|
"""
|
16
|
+
|
16
17
|
import dataclasses
|
17
18
|
import hashlib
|
18
19
|
import io
|
@@ -152,7 +153,8 @@ class BlobTransfer:
|
|
152
153
|
src_image: ImageDescriptor
|
153
154
|
dest_image: ImageDescriptor
|
154
155
|
manifest: Manifest
|
155
|
-
|
156
|
+
src_image_registry_http_client: image_registry_http_client.ImageRegistryHttpClient
|
157
|
+
dest_image_registry_http_client: image_registry_http_client.ImageRegistryHttpClient
|
156
158
|
|
157
159
|
def upload_all_blobs(self) -> None:
|
158
160
|
blob_digests = self.manifest.get_blob_digests()
|
@@ -169,7 +171,7 @@ class BlobTransfer:
|
|
169
171
|
"""
|
170
172
|
Check if the blob already exists in the destination registry.
|
171
173
|
"""
|
172
|
-
resp = self.
|
174
|
+
resp = self.dest_image_registry_http_client.head(self.dest_image.blob_link(blob_digest), headers={})
|
173
175
|
return resp.status_code != 200
|
174
176
|
|
175
177
|
def _fetch_blob(self, blob_digest: str) -> Tuple[io.BytesIO, int]:
|
@@ -178,7 +180,7 @@ class BlobTransfer:
|
|
178
180
|
"""
|
179
181
|
src_blob_link = self.src_image.blob_link(blob_digest)
|
180
182
|
headers = {_CONTENT_LENGTH_HEADER: "0"}
|
181
|
-
resp = self.
|
183
|
+
resp = self.src_image_registry_http_client.get(src_blob_link, headers=headers)
|
182
184
|
|
183
185
|
assert resp.status_code == 200, f"Blob GET failed with code {resp.status_code}"
|
184
186
|
assert _CONTENT_LENGTH_HEADER in resp.headers, f"Blob does not contain {_CONTENT_LENGTH_HEADER}"
|
@@ -189,7 +191,7 @@ class BlobTransfer:
|
|
189
191
|
"""
|
190
192
|
Obtain the upload URL from the destination registry.
|
191
193
|
"""
|
192
|
-
response = self.
|
194
|
+
response = self.dest_image_registry_http_client.post(self.dest_image.blob_upload_link())
|
193
195
|
assert (
|
194
196
|
response.status_code == 202
|
195
197
|
), f"Failed to get the upload URL to destination. Status {response.status_code}. {str(response.content)}"
|
@@ -216,14 +218,14 @@ class BlobTransfer:
|
|
216
218
|
headers[_CONTENT_RANGE_HEADER] = f"{start_byte}-{end_byte}"
|
217
219
|
headers[_CONTENT_LENGTH_HEADER] = str(chunk_length)
|
218
220
|
|
219
|
-
resp = self.
|
221
|
+
resp = self.dest_image_registry_http_client.patch(next_loc, headers=headers, data=chunk)
|
220
222
|
assert resp.status_code == 202, f"Blob PATCH failed with code {resp.status_code}"
|
221
223
|
|
222
224
|
next_loc = resp.headers[_LOCATION_HEADER]
|
223
225
|
start_byte += chunk_length
|
224
226
|
|
225
227
|
# Finalize the upload
|
226
|
-
resp = self.
|
228
|
+
resp = self.dest_image_registry_http_client.put(f"{next_loc}&digest={blob_digest}")
|
227
229
|
assert resp.status_code == 201, f"Blob PUT failed with code {resp.status_code}"
|
228
230
|
|
229
231
|
def _transfer(self, blob_digest: str) -> None:
|
@@ -340,21 +342,32 @@ def copy_image(
|
|
340
342
|
src_image: ImageDescriptor,
|
341
343
|
dest_image: ImageDescriptor,
|
342
344
|
arch: _Arch,
|
343
|
-
|
345
|
+
src_retryable_http: image_registry_http_client.ImageRegistryHttpClient,
|
346
|
+
dest_retryable_http: image_registry_http_client.ImageRegistryHttpClient,
|
344
347
|
) -> None:
|
345
348
|
logger.debug(f"Pulling image manifest for {src_image}")
|
346
349
|
|
347
350
|
# 1. Get the manifest
|
348
|
-
manifest = get_manifest(src_image, arch,
|
351
|
+
manifest = get_manifest(src_image, arch, src_retryable_http)
|
349
352
|
logger.debug(f"Manifest pulled for {src_image} with digest {manifest.manifest_digest}")
|
350
353
|
|
351
354
|
# 2: Retrieve all blob digests from manifest; fetch blob based on blob digest, then upload blob.
|
352
|
-
blob_transfer = BlobTransfer(
|
355
|
+
blob_transfer = BlobTransfer(
|
356
|
+
src_image,
|
357
|
+
dest_image,
|
358
|
+
manifest,
|
359
|
+
src_image_registry_http_client=src_retryable_http,
|
360
|
+
dest_image_registry_http_client=dest_retryable_http,
|
361
|
+
)
|
353
362
|
blob_transfer.upload_all_blobs()
|
354
363
|
|
355
364
|
# 3. Upload the manifest
|
356
365
|
logger.debug(f"All blobs copied successfully. Copying manifest for {src_image} to {dest_image}")
|
357
|
-
put_manifest(
|
366
|
+
put_manifest(
|
367
|
+
dest_image,
|
368
|
+
manifest,
|
369
|
+
dest_retryable_http,
|
370
|
+
)
|
358
371
|
|
359
372
|
logger.debug(f"Image {src_image} copied to {dest_image}")
|
360
373
|
|
@@ -201,6 +201,12 @@ class ImageRegistryClient:
|
|
201
201
|
)
|
202
202
|
# TODO[shchen]: Remove the imagelib, instead rely on the copy image system function later.
|
203
203
|
imagelib.copy_image(
|
204
|
-
src_image=src_image,
|
204
|
+
src_image=src_image,
|
205
|
+
dest_image=dest_image,
|
206
|
+
arch=arch,
|
207
|
+
src_retryable_http=image_registry_http_client.ImageRegistryHttpClient(
|
208
|
+
repo_url=src_image.registry_name, no_cred=True
|
209
|
+
),
|
210
|
+
dest_retryable_http=self.image_registry_http_client,
|
205
211
|
)
|
206
212
|
logger.info("Image copy completed successfully")
|
@@ -1,11 +1,11 @@
|
|
1
1
|
# Error code from Snowflake Python Connector.
|
2
|
-
ERRNO_OBJECT_ALREADY_EXISTS =
|
3
|
-
ERRNO_OBJECT_NOT_EXIST =
|
4
|
-
ERRNO_FILES_ALREADY_EXISTING =
|
5
|
-
ERRNO_VERSION_ALREADY_EXISTS =
|
6
|
-
ERRNO_DATASET_NOT_EXIST =
|
7
|
-
ERRNO_DATASET_VERSION_NOT_EXIST =
|
8
|
-
ERRNO_DATASET_VERSION_ALREADY_EXISTS =
|
2
|
+
ERRNO_OBJECT_ALREADY_EXISTS = 2002
|
3
|
+
ERRNO_OBJECT_NOT_EXIST = 2043
|
4
|
+
ERRNO_FILES_ALREADY_EXISTING = 1030
|
5
|
+
ERRNO_VERSION_ALREADY_EXISTS = 92917
|
6
|
+
ERRNO_DATASET_NOT_EXIST = 399019
|
7
|
+
ERRNO_DATASET_VERSION_NOT_EXIST = 399012
|
8
|
+
ERRNO_DATASET_VERSION_ALREADY_EXISTS = 399020
|
9
9
|
|
10
10
|
|
11
11
|
class DatasetError(Exception):
|
@@ -1,7 +1,7 @@
|
|
1
1
|
# Error code from Snowflake Python Connector.
|
2
|
-
ERRNO_FILE_EXIST_IN_STAGE =
|
3
|
-
ERRNO_DOMAIN_NOT_EXIST =
|
4
|
-
ERRNO_STAGE_NOT_EXIST =
|
2
|
+
ERRNO_FILE_EXIST_IN_STAGE = 1030
|
3
|
+
ERRNO_DOMAIN_NOT_EXIST = 2003
|
4
|
+
ERRNO_STAGE_NOT_EXIST = 391707
|
5
5
|
|
6
6
|
|
7
7
|
class FileSetError(Exception):
|