snowflake-ml-python 1.5.4__py3-none-any.whl → 1.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +2 -0
- snowflake/cortex/_classify_text.py +36 -0
- snowflake/cortex/_complete.py +66 -35
- snowflake/cortex/_util.py +4 -4
- snowflake/ml/_internal/env_utils.py +11 -5
- snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
- snowflake/ml/_internal/lineage/lineage_utils.py +4 -4
- snowflake/ml/_internal/telemetry.py +26 -2
- snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
- snowflake/ml/data/_internal/arrow_ingestor.py +284 -0
- snowflake/ml/data/data_connector.py +186 -0
- snowflake/ml/data/data_ingestor.py +45 -0
- snowflake/ml/data/data_source.py +23 -0
- snowflake/ml/data/ingestor_utils.py +62 -0
- snowflake/ml/data/torch_dataset.py +33 -0
- snowflake/ml/dataset/dataset.py +1 -13
- snowflake/ml/dataset/dataset_metadata.py +3 -1
- snowflake/ml/dataset/dataset_reader.py +23 -117
- snowflake/ml/feature_store/access_manager.py +7 -1
- snowflake/ml/feature_store/entity.py +19 -2
- snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
- snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +20 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +37 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +30 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +7 -0
- snowflake/ml/feature_store/examples/example_helper.py +278 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +12 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +44 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +9 -0
- snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +36 -0
- snowflake/ml/feature_store/examples/source_data/fraud_transactions.yaml +29 -0
- snowflake/ml/feature_store/examples/source_data/nyc_yellow_trips.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/winequality_red.yaml +32 -0
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +14 -0
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +36 -0
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +24 -0
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +8 -0
- snowflake/ml/feature_store/feature_store.py +637 -76
- snowflake/ml/feature_store/feature_view.py +316 -9
- snowflake/ml/fileset/stage_fs.py +18 -10
- snowflake/ml/lineage/lineage_node.py +1 -1
- snowflake/ml/model/_client/model/model_impl.py +11 -2
- snowflake/ml/model/_client/model/model_version_impl.py +171 -20
- snowflake/ml/model/_client/ops/model_ops.py +105 -27
- snowflake/ml/model/_client/ops/service_ops.py +121 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +95 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +31 -0
- snowflake/ml/model/_client/sql/model_version.py +13 -4
- snowflake/ml/model/_client/sql/service.py +129 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
- snowflake/ml/model/_model_composer/model_composer.py +14 -14
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +33 -17
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -1
- snowflake/ml/model/_model_composer/model_method/function_generator.py +3 -3
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +3 -32
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +3 -27
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -32
- snowflake/ml/model/_model_composer/model_method/model_method.py +5 -2
- snowflake/ml/model/_packager/model_env/model_env.py +7 -2
- snowflake/ml/model/_packager/model_handlers/_base.py +30 -3
- snowflake/ml/model/_packager/model_handlers/_utils.py +58 -1
- snowflake/ml/model/_packager/model_handlers/catboost.py +52 -3
- snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +9 -5
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +80 -3
- snowflake/ml/model/_packager/model_handlers/llm.py +7 -3
- snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sklearn.py +87 -4
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +7 -2
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +8 -3
- snowflake/ml/model/_packager/model_handlers/xgboost.py +71 -3
- snowflake/ml/model/_packager/model_meta/model_meta.py +32 -2
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +19 -0
- snowflake/ml/model/_packager/model_packager.py +2 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +7 -7
- snowflake/ml/model/model_signature.py +4 -4
- snowflake/ml/model/type_hints.py +2 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +1 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +13 -1
- snowflake/ml/modeling/framework/base.py +28 -19
- snowflake/ml/modeling/impute/simple_imputer.py +26 -0
- snowflake/ml/modeling/pipeline/pipeline.py +7 -4
- snowflake/ml/registry/_manager/model_manager.py +16 -2
- snowflake/ml/registry/registry.py +100 -13
- snowflake/ml/utils/sql_client.py +22 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/METADATA +81 -2
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/RECORD +99 -66
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/lineage/data_source.py +0 -10
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/top_level.txt +0 -0
snowflake/cortex/__init__.py
CHANGED
@@ -1,3 +1,4 @@
|
|
1
|
+
from snowflake.cortex._classify_text import ClassifyText
|
1
2
|
from snowflake.cortex._complete import Complete, CompleteOptions
|
2
3
|
from snowflake.cortex._extract_answer import ExtractAnswer
|
3
4
|
from snowflake.cortex._sentiment import Sentiment
|
@@ -5,6 +6,7 @@ from snowflake.cortex._summarize import Summarize
|
|
5
6
|
from snowflake.cortex._translate import Translate
|
6
7
|
|
7
8
|
__all__ = [
|
9
|
+
"ClassifyText",
|
8
10
|
"Complete",
|
9
11
|
"CompleteOptions",
|
10
12
|
"ExtractAnswer",
|
@@ -0,0 +1,36 @@
|
|
1
|
+
from typing import List, Optional, Union
|
2
|
+
|
3
|
+
from snowflake import snowpark
|
4
|
+
from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_function
|
5
|
+
from snowflake.ml._internal import telemetry
|
6
|
+
|
7
|
+
|
8
|
+
@telemetry.send_api_usage_telemetry(
|
9
|
+
project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
|
10
|
+
)
|
11
|
+
def ClassifyText(
|
12
|
+
str_input: Union[str, snowpark.Column],
|
13
|
+
categories: Union[List[str], snowpark.Column],
|
14
|
+
session: Optional[snowpark.Session] = None,
|
15
|
+
) -> Union[str, snowpark.Column]:
|
16
|
+
"""Use the LLM inference service to classify the INPUT text into one of the target CATEGORIES.
|
17
|
+
|
18
|
+
Args:
|
19
|
+
str_input: A Column of strings to classify.
|
20
|
+
categories: A list of candidate categories to classify the INPUT text into.
|
21
|
+
session: The snowpark session to use. Will be inferred by context if not specified.
|
22
|
+
|
23
|
+
Returns:
|
24
|
+
A column of classification responses.
|
25
|
+
"""
|
26
|
+
|
27
|
+
return _classify_text_impl("snowflake.cortex.classify_text", str_input, categories, session=session)
|
28
|
+
|
29
|
+
|
30
|
+
def _classify_text_impl(
|
31
|
+
function: str,
|
32
|
+
str_input: Union[str, snowpark.Column],
|
33
|
+
categories: Union[List[str], snowpark.Column],
|
34
|
+
session: Optional[snowpark.Session] = None,
|
35
|
+
) -> Union[str, snowpark.Column]:
|
36
|
+
return call_sql_function(function, session, str_input, categories)
|
snowflake/cortex/_complete.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
import json
|
2
2
|
import logging
|
3
|
-
|
3
|
+
import time
|
4
|
+
from typing import Any, Callable, Iterator, List, Optional, TypedDict, Union, cast
|
4
5
|
from urllib.parse import urlunparse
|
5
6
|
|
6
7
|
import requests
|
@@ -52,12 +53,43 @@ class ResponseParseException(Exception):
|
|
52
53
|
pass
|
53
54
|
|
54
55
|
|
56
|
+
_MAX_RETRY_SECONDS = 30
|
57
|
+
|
58
|
+
|
59
|
+
def retry(func: Callable[..., requests.Response]) -> Callable[..., requests.Response]:
|
60
|
+
def inner(*args: Any, **kwargs: Any) -> requests.Response:
|
61
|
+
deadline = cast(Optional[float], kwargs["deadline"])
|
62
|
+
kwargs = {key: value for key, value in kwargs.items() if key != "deadline"}
|
63
|
+
expRetrySeconds = 0.5
|
64
|
+
while True:
|
65
|
+
if deadline is not None and time.time() > deadline:
|
66
|
+
raise TimeoutError()
|
67
|
+
response = func(*args, **kwargs)
|
68
|
+
if response.status_code >= 200 and response.status_code < 300:
|
69
|
+
return response
|
70
|
+
retry_status_codes = [429, 503, 504]
|
71
|
+
if response.status_code not in retry_status_codes:
|
72
|
+
response.raise_for_status()
|
73
|
+
logger.debug(f"request failed with status code {response.status_code}, retrying")
|
74
|
+
|
75
|
+
# Formula: delay(i) = max(RetryAfterHeader, min(2^i, _MAX_RETRY_SECONDS)).
|
76
|
+
expRetrySeconds = min(2 * expRetrySeconds, _MAX_RETRY_SECONDS)
|
77
|
+
retrySeconds = expRetrySeconds
|
78
|
+
retryAfterHeader = response.headers.get("retry-after")
|
79
|
+
if retryAfterHeader is not None:
|
80
|
+
retrySeconds = max(retrySeconds, int(retryAfterHeader))
|
81
|
+
logger.debug(f"sleeping for {retrySeconds}s before retrying")
|
82
|
+
time.sleep(retrySeconds)
|
83
|
+
|
84
|
+
return inner
|
85
|
+
|
86
|
+
|
87
|
+
@retry
|
55
88
|
def _call_complete_rest(
|
56
89
|
model: str,
|
57
90
|
prompt: Union[str, List[ConversationMessage]],
|
58
91
|
options: Optional[CompleteOptions] = None,
|
59
92
|
session: Optional[snowpark.Session] = None,
|
60
|
-
stream: bool = False,
|
61
93
|
) -> requests.Response:
|
62
94
|
session = session or context.get_active_session()
|
63
95
|
if session is None:
|
@@ -78,7 +110,7 @@ def _call_complete_rest(
|
|
78
110
|
scheme = "https"
|
79
111
|
if hasattr(session.connection, "scheme"):
|
80
112
|
scheme = session.connection.scheme
|
81
|
-
url = urlunparse((scheme, session.connection.host, "api/v2/cortex/inference
|
113
|
+
url = urlunparse((scheme, session.connection.host, "api/v2/cortex/inference:complete", "", "", ""))
|
82
114
|
|
83
115
|
headers = {
|
84
116
|
"Content-Type": "application/json",
|
@@ -88,7 +120,7 @@ def _call_complete_rest(
|
|
88
120
|
|
89
121
|
data = {
|
90
122
|
"model": model,
|
91
|
-
"stream":
|
123
|
+
"stream": True,
|
92
124
|
}
|
93
125
|
if isinstance(prompt, List):
|
94
126
|
data["messages"] = prompt
|
@@ -104,33 +136,20 @@ def _call_complete_rest(
|
|
104
136
|
if "top_p" in options:
|
105
137
|
data["top_p"] = options["top_p"]
|
106
138
|
|
107
|
-
logger.debug(f"making POST request to {url} (model={model}
|
108
|
-
|
139
|
+
logger.debug(f"making POST request to {url} (model={model})")
|
140
|
+
return requests.post(
|
109
141
|
url,
|
110
142
|
json=data,
|
111
143
|
headers=headers,
|
112
|
-
stream=
|
144
|
+
stream=True,
|
113
145
|
)
|
114
|
-
response.raise_for_status()
|
115
|
-
return response
|
116
146
|
|
117
147
|
|
118
|
-
def
|
119
|
-
if stream:
|
120
|
-
return _return_stream_response(response)
|
121
|
-
|
122
|
-
try:
|
123
|
-
content = response.json()["choices"][0]["message"]["content"]
|
124
|
-
assert isinstance(content, str)
|
125
|
-
return content
|
126
|
-
except (KeyError, IndexError, AssertionError) as e:
|
127
|
-
# Unlike the streaming case, errors are not ignored because a message must be returned.
|
128
|
-
raise ResponseParseException("Failed to parse message from response.") from e
|
129
|
-
|
130
|
-
|
131
|
-
def _return_stream_response(response: requests.Response) -> Iterator[str]:
|
148
|
+
def _return_stream_response(response: requests.Response, deadline: Optional[float]) -> Iterator[str]:
|
132
149
|
client = SSEClient(response)
|
133
150
|
for event in client.events():
|
151
|
+
if deadline is not None and time.time() > deadline:
|
152
|
+
raise TimeoutError()
|
134
153
|
try:
|
135
154
|
yield json.loads(event.data)["choices"][0]["delta"]["content"]
|
136
155
|
except (json.JSONDecodeError, KeyError, IndexError):
|
@@ -206,19 +225,23 @@ def _complete_impl(
|
|
206
225
|
prompt: Union[str, List[ConversationMessage], snowpark.Column],
|
207
226
|
options: Optional[CompleteOptions] = None,
|
208
227
|
session: Optional[snowpark.Session] = None,
|
209
|
-
use_rest_api_experimental: bool = False,
|
210
228
|
stream: bool = False,
|
211
229
|
function: str = "snowflake.cortex.complete",
|
230
|
+
timeout: Optional[float] = None,
|
231
|
+
deadline: Optional[float] = None,
|
212
232
|
) -> Union[str, Iterator[str], snowpark.Column]:
|
213
|
-
if
|
233
|
+
if timeout is not None and deadline is not None:
|
234
|
+
raise ValueError('only one of "timeout" and "deadline" must be set')
|
235
|
+
if timeout is not None:
|
236
|
+
deadline = time.time() + timeout
|
237
|
+
if stream:
|
214
238
|
if not isinstance(model, str):
|
215
239
|
raise ValueError("in REST mode, 'model' must be a string")
|
216
240
|
if not isinstance(prompt, str) and not isinstance(prompt, List):
|
217
241
|
raise ValueError("in REST mode, 'prompt' must be a string or a list of ConversationMessage")
|
218
|
-
response = _call_complete_rest(model, prompt, options, session=session,
|
219
|
-
|
220
|
-
|
221
|
-
raise ValueError("streaming can only be enabled in REST mode, set use_rest_api_experimental=True")
|
242
|
+
response = _call_complete_rest(model, prompt, options, session=session, deadline=deadline)
|
243
|
+
assert response.status_code >= 200 and response.status_code < 300
|
244
|
+
return _return_stream_response(response, deadline)
|
222
245
|
return _complete_sql_impl(function, model, prompt, options, session)
|
223
246
|
|
224
247
|
|
@@ -231,8 +254,9 @@ def Complete(
|
|
231
254
|
*,
|
232
255
|
options: Optional[CompleteOptions] = None,
|
233
256
|
session: Optional[snowpark.Session] = None,
|
234
|
-
use_rest_api_experimental: bool = False,
|
235
257
|
stream: bool = False,
|
258
|
+
timeout: Optional[float] = None,
|
259
|
+
deadline: Optional[float] = None,
|
236
260
|
) -> Union[str, Iterator[str], snowpark.Column]:
|
237
261
|
"""Complete calls into the LLM inference service to perform completion.
|
238
262
|
|
@@ -241,19 +265,26 @@ def Complete(
|
|
241
265
|
prompt: A Column of prompts to send to the LLM.
|
242
266
|
options: A instance of snowflake.cortex.CompleteOptions
|
243
267
|
session: The snowpark session to use. Will be inferred by context if not specified.
|
244
|
-
use_rest_api_experimental (bool): Toggles between the use of SQL and REST implementation. This feature is
|
245
|
-
experimental and can be removed at any time.
|
246
268
|
stream (bool): Enables streaming. When enabled, a generator function is returned that provides the streaming
|
247
269
|
output as it is received. Each update is a string containing the new text content since the previous update.
|
248
|
-
|
270
|
+
timeout (float): Timeout in seconds to retry failed REST requests.
|
271
|
+
deadline (float): Time in seconds since the epoch (as returned by time.time()) to retry failed REST requests.
|
249
272
|
|
250
273
|
Raises:
|
251
|
-
ValueError:
|
274
|
+
ValueError: incorrect argument.
|
252
275
|
|
253
276
|
Returns:
|
254
277
|
A column of string responses.
|
255
278
|
"""
|
256
279
|
try:
|
257
|
-
return _complete_impl(
|
280
|
+
return _complete_impl(
|
281
|
+
model,
|
282
|
+
prompt,
|
283
|
+
options=options,
|
284
|
+
session=session,
|
285
|
+
stream=stream,
|
286
|
+
timeout=timeout,
|
287
|
+
deadline=deadline,
|
288
|
+
)
|
258
289
|
except ValueError as err:
|
259
290
|
raise err
|
snowflake/cortex/_util.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Dict, Optional, Union, cast
|
1
|
+
from typing import Dict, List, Optional, Union, cast
|
2
2
|
|
3
3
|
from snowflake import snowpark
|
4
4
|
from snowflake.snowpark import context, functions
|
@@ -23,7 +23,7 @@ class SnowflakeConfigurationException(Exception):
|
|
23
23
|
def call_sql_function(
|
24
24
|
function: str,
|
25
25
|
session: Optional[snowpark.Session],
|
26
|
-
*args: Union[str, snowpark.Column, Dict[str, Union[int, float]]],
|
26
|
+
*args: Union[str, List[str], snowpark.Column, Dict[str, Union[int, float]]],
|
27
27
|
) -> Union[str, snowpark.Column]:
|
28
28
|
handle_as_column = False
|
29
29
|
|
@@ -40,7 +40,7 @@ def call_sql_function(
|
|
40
40
|
|
41
41
|
|
42
42
|
def _call_sql_function_column(
|
43
|
-
function: str, *args: Union[str, snowpark.Column, Dict[str, Union[int, float]]]
|
43
|
+
function: str, *args: Union[str, List[str], snowpark.Column, Dict[str, Union[int, float]]]
|
44
44
|
) -> snowpark.Column:
|
45
45
|
return cast(snowpark.Column, functions.builtin(function)(*args))
|
46
46
|
|
@@ -48,7 +48,7 @@ def _call_sql_function_column(
|
|
48
48
|
def _call_sql_function_immediate(
|
49
49
|
function: str,
|
50
50
|
session: Optional[snowpark.Session],
|
51
|
-
*args: Union[str, snowpark.Column, Dict[str, Union[int, float]]],
|
51
|
+
*args: Union[str, List[str], snowpark.Column, Dict[str, Union[int, float]]],
|
52
52
|
) -> str:
|
53
53
|
session = session or context.get_active_session()
|
54
54
|
if session is None:
|
@@ -27,7 +27,6 @@ class CONDA_OS(Enum):
|
|
27
27
|
NO_ARCH = "noarch"
|
28
28
|
|
29
29
|
|
30
|
-
_SNOWFLAKE_CONDA_CHANNEL_URL = "https://repo.anaconda.com/pkgs/snowflake"
|
31
30
|
_NODEFAULTS = "nodefaults"
|
32
31
|
_SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE: Dict[str, List[version.Version]] = {}
|
33
32
|
_SNOWFLAKE_CONDA_PACKAGE_CACHE: Dict[str, List[version.Version]] = {}
|
@@ -36,6 +35,7 @@ _SUPPORTED_PACKAGE_SPEC_OPS = ["==", ">=", "<=", ">", "<"]
|
|
36
35
|
DEFAULT_CHANNEL_NAME = ""
|
37
36
|
SNOWML_SPROC_ENV = "IN_SNOWML_SPROC"
|
38
37
|
SNOWPARK_ML_PKG_NAME = "snowflake-ml-python"
|
38
|
+
SNOWFLAKE_CONDA_CHANNEL_URL = "https://repo.anaconda.com/pkgs/snowflake"
|
39
39
|
|
40
40
|
|
41
41
|
def _validate_pip_requirement_string(req_str: str) -> requirements.Requirement:
|
@@ -370,7 +370,7 @@ def get_matched_package_versions_in_snowflake_conda_channel(
|
|
370
370
|
|
371
371
|
assert not snowpark_utils.is_in_stored_procedure() # type: ignore[no-untyped-call]
|
372
372
|
|
373
|
-
url = f"{
|
373
|
+
url = f"{SNOWFLAKE_CONDA_CHANNEL_URL}/{conda_os.value}/repodata.json"
|
374
374
|
|
375
375
|
if req.name not in _SNOWFLAKE_CONDA_PACKAGE_CACHE:
|
376
376
|
try:
|
@@ -477,6 +477,7 @@ def save_conda_env_file(
|
|
477
477
|
path: pathlib.Path,
|
478
478
|
conda_chan_deps: DefaultDict[str, List[requirements.Requirement]],
|
479
479
|
python_version: str,
|
480
|
+
default_channel_override: str = SNOWFLAKE_CONDA_CHANNEL_URL,
|
480
481
|
) -> None:
|
481
482
|
"""Generate conda.yml file given a dict of dependencies after validation.
|
482
483
|
The channels part of conda.yml file will contains Snowflake Anaconda Channel, nodefaults and all channel names
|
@@ -489,6 +490,7 @@ def save_conda_env_file(
|
|
489
490
|
path: Path to the conda.yml file.
|
490
491
|
conda_chan_deps: Dict of conda dependencies after validated.
|
491
492
|
python_version: A string 'major.minor' showing python version relate to model.
|
493
|
+
default_channel_override: The default channel to be put in the first place of the channels section.
|
492
494
|
"""
|
493
495
|
assert path.suffix in [".yml", ".yaml"], "Conda environment file should have extension of yml or yaml."
|
494
496
|
path.parent.mkdir(parents=True, exist_ok=True)
|
@@ -499,7 +501,11 @@ def save_conda_env_file(
|
|
499
501
|
channels = list(dict(sorted(conda_chan_deps.items(), key=lambda item: len(item[1]), reverse=True)).keys())
|
500
502
|
if DEFAULT_CHANNEL_NAME in channels:
|
501
503
|
channels.remove(DEFAULT_CHANNEL_NAME)
|
502
|
-
|
504
|
+
|
505
|
+
if default_channel_override in channels:
|
506
|
+
channels.remove(default_channel_override)
|
507
|
+
|
508
|
+
env["channels"] = [default_channel_override] + channels + [_NODEFAULTS]
|
503
509
|
env["dependencies"] = [f"python=={python_version}.*"]
|
504
510
|
for chan, reqs in conda_chan_deps.items():
|
505
511
|
env["dependencies"].extend(
|
@@ -567,8 +573,8 @@ def load_conda_env_file(
|
|
567
573
|
python_version = None
|
568
574
|
|
569
575
|
channels = env.get("channels", [])
|
570
|
-
if
|
571
|
-
channels
|
576
|
+
if len(channels) >= 1:
|
577
|
+
channels = channels[1:] # Skip the first channel which is the default channel
|
572
578
|
if _NODEFAULTS in channels:
|
573
579
|
channels.remove(_NODEFAULTS)
|
574
580
|
|
@@ -4,7 +4,10 @@ ATTRIBUTE_NOT_SET = (
|
|
4
4
|
"-differences."
|
5
5
|
)
|
6
6
|
SIZE_MISMATCH = "Size mismatch: {}={}, {}={}."
|
7
|
-
INVALID_MODEL_PARAM =
|
7
|
+
INVALID_MODEL_PARAM = (
|
8
|
+
"Invalid parameter {} for model {}. Valid parameters: {}."
|
9
|
+
"Note: Scikit learn params cannot be set until the model has been fit."
|
10
|
+
)
|
8
11
|
UNSUPPORTED_MODEL_CONVERSION = "Object doesn't support {}. Please use {}."
|
9
12
|
INCOMPATIBLE_NEW_SKLEARN_PARAM = "Incompatible scikit-learn version: {} requires scikit-learn>={}. Installed: {}."
|
10
13
|
REMOVED_SKLEARN_PARAM = "Incompatible scikit-learn version: {} is removed in scikit-learn>={}. Installed: {}."
|
@@ -1,9 +1,9 @@
|
|
1
1
|
import copy
|
2
2
|
import functools
|
3
|
-
from typing import Any, Callable, List, Optional
|
3
|
+
from typing import Any, Callable, List, Optional, get_args
|
4
4
|
|
5
5
|
from snowflake import snowpark
|
6
|
-
from snowflake.ml.
|
6
|
+
from snowflake.ml.data import data_source
|
7
7
|
|
8
8
|
_DATA_SOURCES_ATTR = "_data_sources"
|
9
9
|
|
@@ -39,7 +39,7 @@ def get_data_sources(*args: Any) -> Optional[List[data_source.DataSource]]:
|
|
39
39
|
result: Optional[List[data_source.DataSource]] = None
|
40
40
|
for arg in args:
|
41
41
|
srcs = getattr(arg, _DATA_SOURCES_ATTR, None)
|
42
|
-
if isinstance(srcs, list) and all(isinstance(s, data_source.DataSource) for s in srcs):
|
42
|
+
if isinstance(srcs, list) and all(isinstance(s, get_args(data_source.DataSource)) for s in srcs):
|
43
43
|
if result is None:
|
44
44
|
result = []
|
45
45
|
result += srcs
|
@@ -49,7 +49,7 @@ def get_data_sources(*args: Any) -> Optional[List[data_source.DataSource]]:
|
|
49
49
|
def set_data_sources(obj: Any, data_sources: Optional[List[data_source.DataSource]]) -> None:
|
50
50
|
"""Helper method for attaching data sources to an object"""
|
51
51
|
if data_sources:
|
52
|
-
assert all(isinstance(ds, data_source.DataSource) for ds in data_sources)
|
52
|
+
assert all(isinstance(ds, get_args(data_source.DataSource)) for ds in data_sources)
|
53
53
|
setattr(obj, _DATA_SOURCES_ATTR, data_sources)
|
54
54
|
|
55
55
|
|
@@ -44,6 +44,20 @@ _Args = ParamSpec("_Args")
|
|
44
44
|
_ReturnValue = TypeVar("_ReturnValue")
|
45
45
|
|
46
46
|
|
47
|
+
@enum.unique
|
48
|
+
class TelemetryProject(enum.Enum):
|
49
|
+
MLOPS = "MLOps"
|
50
|
+
MODELING = "ModelDevelopment"
|
51
|
+
# TODO: Update with remaining projects.
|
52
|
+
|
53
|
+
|
54
|
+
@enum.unique
|
55
|
+
class TelemetrySubProject(enum.Enum):
|
56
|
+
MONITORING = "Monitoring"
|
57
|
+
REGISTRY = "ModelManagement"
|
58
|
+
# TODO: Update with remaining subprojects.
|
59
|
+
|
60
|
+
|
47
61
|
@enum.unique
|
48
62
|
class TelemetryField(enum.Enum):
|
49
63
|
# constants
|
@@ -277,6 +291,7 @@ def send_api_usage_telemetry(
|
|
277
291
|
]
|
278
292
|
] = None,
|
279
293
|
sfqids_extractor: Optional[Callable[..., List[str]]] = None,
|
294
|
+
subproject_extractor: Optional[Callable[[Any], str]] = None,
|
280
295
|
custom_tags: Optional[Dict[str, Union[bool, int, str, float]]] = None,
|
281
296
|
) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, _ReturnValue]]:
|
282
297
|
"""
|
@@ -290,6 +305,7 @@ def send_api_usage_telemetry(
|
|
290
305
|
conn_attr_name: Name of the SnowflakeConnection attribute in `self`.
|
291
306
|
api_calls_extractor: Extract API calls from `self`.
|
292
307
|
sfqids_extractor: Extract sfqids from `self`.
|
308
|
+
subproject_extractor: Extract subproject at runtime from `self`.
|
293
309
|
custom_tags: Custom tags.
|
294
310
|
|
295
311
|
Returns:
|
@@ -297,10 +313,14 @@ def send_api_usage_telemetry(
|
|
297
313
|
|
298
314
|
Raises:
|
299
315
|
TypeError: If `conn_attr_name` is provided but the conn attribute is not of type SnowflakeConnection.
|
316
|
+
ValueError: If both `subproject` and `subproject_extractor` are provided
|
300
317
|
|
301
318
|
# noqa: DAR402
|
302
319
|
"""
|
303
320
|
|
321
|
+
if subproject is not None and subproject_extractor is not None:
|
322
|
+
raise ValueError("Specifying both subproject and subproject_extractor is not allowed")
|
323
|
+
|
304
324
|
def decorator(func: Callable[_Args, _ReturnValue]) -> Callable[_Args, _ReturnValue]:
|
305
325
|
@functools.wraps(func)
|
306
326
|
def wrap(*args: Any, **kwargs: Any) -> _ReturnValue:
|
@@ -322,9 +342,13 @@ def send_api_usage_telemetry(
|
|
322
342
|
if sfqids_extractor:
|
323
343
|
sfqids = sfqids_extractor(args[0])
|
324
344
|
|
345
|
+
subproject_name = subproject
|
346
|
+
if subproject_extractor is not None:
|
347
|
+
subproject_name = subproject_extractor(args[0])
|
348
|
+
|
325
349
|
statement_params = get_function_usage_statement_params(
|
326
350
|
project=project,
|
327
|
-
subproject=
|
351
|
+
subproject=subproject_name,
|
328
352
|
function_category=TelemetryField.FUNC_CAT_USAGE.value,
|
329
353
|
function_name=_get_full_func_name(func),
|
330
354
|
function_parameters=params,
|
@@ -381,7 +405,7 @@ def send_api_usage_telemetry(
|
|
381
405
|
raise e.original_exception from e
|
382
406
|
|
383
407
|
# TODO(hayu): [SNOW-750287] Optimize telemetry client to a singleton.
|
384
|
-
telemetry = _SourceTelemetryClient(conn=conn, project=project, subproject=
|
408
|
+
telemetry = _SourceTelemetryClient(conn=conn, project=project, subproject=subproject_name)
|
385
409
|
telemetry_args = dict(
|
386
410
|
func_name=_get_full_func_name(func),
|
387
411
|
function_category=TelemetryField.FUNC_CAT_USAGE.value,
|
@@ -26,30 +26,11 @@ def get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
|
26
26
|
pkg_versions: List[str], session: Session, subproject: Optional[str] = None
|
27
27
|
) -> List[str]:
|
28
28
|
if snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
|
29
|
-
return
|
29
|
+
return pkg_versions
|
30
30
|
else:
|
31
31
|
return _get_valid_pkg_versions_supported_in_snowflake_conda_channel_async(pkg_versions, session, subproject)
|
32
32
|
|
33
33
|
|
34
|
-
def _get_valid_pkg_versions_supported_in_snowflake_conda_channel_sync(
|
35
|
-
pkg_versions: List[str], session: Session, subproject: Optional[str] = None
|
36
|
-
) -> List[str]:
|
37
|
-
for pkg_version in pkg_versions:
|
38
|
-
if pkg_version not in cache:
|
39
|
-
pkg_version_list = _query_pkg_version_supported_in_snowflake_conda_channel(
|
40
|
-
pkg_version=pkg_version, session=session, block=True, subproject=subproject
|
41
|
-
)
|
42
|
-
assert isinstance(pkg_version_list, list) # keep mypy happy
|
43
|
-
try:
|
44
|
-
cache[pkg_version] = pkg_version_list[0]["VERSION"]
|
45
|
-
except IndexError:
|
46
|
-
cache[pkg_version] = None
|
47
|
-
|
48
|
-
pkg_version_conda_list = _get_conda_packages_and_emit_warnings(pkg_versions)
|
49
|
-
|
50
|
-
return pkg_version_conda_list
|
51
|
-
|
52
|
-
|
53
34
|
def _get_valid_pkg_versions_supported_in_snowflake_conda_channel_async(
|
54
35
|
pkg_versions: List[str], session: Session, subproject: Optional[str] = None
|
55
36
|
) -> List[str]:
|
@@ -60,7 +41,11 @@ def _get_valid_pkg_versions_supported_in_snowflake_conda_channel_async(
|
|
60
41
|
async_job = _query_pkg_version_supported_in_snowflake_conda_channel(
|
61
42
|
pkg_version=pkg_version, session=session, block=False, subproject=subproject
|
62
43
|
)
|
63
|
-
|
44
|
+
if isinstance(async_job, list):
|
45
|
+
raise RuntimeError(
|
46
|
+
"Async job was expected, executed query was returned. Please contact Snowflake support."
|
47
|
+
)
|
48
|
+
|
64
49
|
pkg_version_async_job_list.append((pkg_version, async_job))
|
65
50
|
|
66
51
|
# Populate the cache.
|
@@ -143,7 +128,8 @@ def _get_conda_packages_and_emit_warnings(pkg_versions: List[str]) -> List[str]:
|
|
143
128
|
warnings.warn(
|
144
129
|
f"Package {', '.join([pkg[0] for pkg in pkg_version_warning_list])} is not supported "
|
145
130
|
f"in snowflake conda channel for python runtime "
|
146
|
-
f"{', '.join([pkg[1] for pkg in pkg_version_warning_list])}."
|
131
|
+
f"{', '.join([pkg[1] for pkg in pkg_version_warning_list])}.",
|
132
|
+
stacklevel=1,
|
147
133
|
)
|
148
134
|
|
149
135
|
return pkg_version_conda_list
|