snowflake-ml-python 1.7.4__py3-none-any.whl → 1.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +58 -3
- snowflake/ml/_internal/env_utils.py +64 -21
- snowflake/ml/_internal/file_utils.py +18 -4
- snowflake/ml/_internal/platform_capabilities.py +3 -0
- snowflake/ml/_internal/relax_version_strategy.py +16 -0
- snowflake/ml/_internal/telemetry.py +25 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +1 -1
- snowflake/ml/feature_store/feature_store.py +18 -0
- snowflake/ml/feature_store/feature_view.py +46 -1
- snowflake/ml/fileset/fileset.py +0 -1
- snowflake/ml/jobs/_utils/constants.py +31 -1
- snowflake/ml/jobs/_utils/payload_utils.py +232 -72
- snowflake/ml/jobs/_utils/spec_utils.py +78 -38
- snowflake/ml/jobs/decorators.py +8 -25
- snowflake/ml/jobs/job.py +4 -4
- snowflake/ml/jobs/manager.py +5 -0
- snowflake/ml/model/_client/model/model_version_impl.py +1 -1
- snowflake/ml/model/_client/ops/model_ops.py +107 -14
- snowflake/ml/model/_client/ops/service_ops.py +1 -1
- snowflake/ml/model/_client/service/model_deployment_spec.py +7 -3
- snowflake/ml/model/_client/sql/model_version.py +58 -0
- snowflake/ml/model/_client/sql/service.py +8 -2
- snowflake/ml/model/_model_composer/model_composer.py +50 -3
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +0 -1
- snowflake/ml/model/_packager/model_env/model_env.py +49 -29
- snowflake/ml/model/_packager/model_handlers/_utils.py +8 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +44 -24
- snowflake/ml/model/_packager/model_handlers/keras.py +226 -0
- snowflake/ml/model/_packager/model_handlers/pytorch.py +51 -20
- snowflake/ml/model/_packager/model_handlers/sklearn.py +25 -3
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +73 -21
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -72
- snowflake/ml/model/_packager/model_handlers/torchscript.py +49 -20
- snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
- snowflake/ml/model/_packager/model_handlers_migrator/pytorch_migrator_2023_12_01.py +20 -0
- snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2023_12_01.py +48 -0
- snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2025_01_01.py +19 -0
- snowflake/ml/model/_packager/model_handlers_migrator/torchscript_migrator_2023_12_01.py +20 -0
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +6 -2
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +16 -0
- snowflake/ml/model/_packager/model_packager.py +3 -5
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +8 -1
- snowflake/ml/model/_packager/model_task/model_task_utils.py +5 -1
- snowflake/ml/model/_signatures/builtins_handler.py +20 -9
- snowflake/ml/model/_signatures/core.py +54 -33
- snowflake/ml/model/_signatures/dmatrix_handler.py +98 -0
- snowflake/ml/model/_signatures/numpy_handler.py +12 -20
- snowflake/ml/model/_signatures/pandas_handler.py +28 -37
- snowflake/ml/model/_signatures/pytorch_handler.py +57 -41
- snowflake/ml/model/_signatures/snowpark_handler.py +0 -12
- snowflake/ml/model/_signatures/tensorflow_handler.py +61 -67
- snowflake/ml/model/_signatures/utils.py +120 -8
- snowflake/ml/model/custom_model.py +13 -4
- snowflake/ml/model/model_signature.py +39 -13
- snowflake/ml/model/type_hints.py +28 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +14 -1
- snowflake/ml/modeling/metrics/ranking.py +3 -0
- snowflake/ml/modeling/metrics/regression.py +3 -0
- snowflake/ml/modeling/pipeline/pipeline.py +18 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +2 -2
- snowflake/ml/registry/_manager/model_manager.py +55 -7
- snowflake/ml/registry/registry.py +52 -4
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/METADATA +336 -27
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/RECORD +73 -66
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/WHEEL +1 -1
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info/licenses}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/top_level.txt +0 -0
snowflake/cortex/_complete.py
CHANGED
@@ -23,6 +23,15 @@ logger = logging.getLogger(__name__)
 _REST_COMPLETE_URL = "/api/v2/cortex/inference:complete"
 
 
+class ResponseFormat(TypedDict):
+    """Represents an object describing response format config for structured-output mode"""
+
+    type: str
+    """The response format type (e.g. "json")"""
+    schema: Dict[str, Any]
+    """The schema defining the structure of the response. For json it should be a valid json schema object"""
+
+
 class ConversationMessage(TypedDict):
     """Represents an conversation interaction."""
 
@@ -53,6 +62,9 @@ class CompleteOptions(TypedDict):
     """ A boolean value that controls whether Cortex Guard filters unsafe or harmful responses
     from the language model. """
 
+    response_format: NotRequired[ResponseFormat]
+    """ An object describing response format config for structured-output mode """
+
 
 class ResponseParseException(Exception):
     """This exception is raised when the server response cannot be parsed."""
@@ -108,6 +120,32 @@ def _make_common_request_headers() -> Dict[str, str]:
     return headers
 
 
+def _validate_response_format_object(options: CompleteOptions) -> None:
+    """Validate the response format object for structured-output mode.
+
+    More details can be found in:
+    docs.snowflake.com/en/user-guide/snowflake-cortex/complete-structured-outputs#using-complete-structured-outputs
+
+    Args:
+        options: The complete options object.
+
+    Raises:
+        ValueError: If the response format object is invalid or missing required fields.
+    """
+    if options is not None and options.get("response_format") is not None:
+        options_obj = options.get("response_format")
+        if not isinstance(options_obj, dict):
+            raise ValueError("'response_format' should be an object")
+        if options_obj.get("type") is None:
+            raise ValueError("'type' cannot be empty for 'response_format' object")
+        if not isinstance(options_obj.get("type"), str):
+            raise ValueError("'type' needs to be a str for 'response_format' object")
+        if options_obj.get("schema") is None:
+            raise ValueError("'schema' cannot be empty for 'response_format' object")
+        if not isinstance(options_obj.get("schema"), dict):
+            raise ValueError("'schema' needs to be a dict for 'response_format' object")
+
+
 def _make_request_body(
     model: str,
     prompt: Union[str, List[ConversationMessage]],
@@ -136,12 +174,16 @@ def _make_request_body(
             "response_when_unsafe": "Response filtered by Cortex Guard",
         }
         data["guardrails"] = guardrails_options
+    if "response_format" in options:
+        data["response_format"] = options["response_format"]
+
     return data
 
 
 # XP endpoint returns a dict response which needs to be converted to a format which can
 # be consumed by the SSEClient. This method does that.
 def _xp_dict_to_response(raw_resp: Dict[str, Any]) -> requests.Response:
+
     response = requests.Response()
     response.status_code = int(raw_resp["status"])
     response.headers = raw_resp["headers"]
@@ -159,7 +201,6 @@ def _xp_dict_to_response(raw_resp: Dict[str, Any]) -> requests.Response:
         data = json.loads(data)
     except json.JSONDecodeError:
         raise ValueError(f"Request failed (request id: {request_id})")
-
     if response.status_code < 200 or response.status_code >= 300:
         if "message" not in data:
             raise ValueError(f"Request failed (request id: {request_id})")
@@ -241,11 +282,21 @@ def _return_stream_response(response: requests.Response, deadline: Optional[floa
         if deadline is not None and time.time() > deadline:
             raise TimeoutError()
         try:
-
+            parsed_resp = json.loads(event.data)
+        except json.JSONDecodeError:
+            raise ResponseParseException("Server response cannot be parsed")
+        try:
+            yield parsed_resp["choices"][0]["delta"]["content"]
         except (json.JSONDecodeError, KeyError, IndexError):
            # For the sake of evolution of the output format,
            # ignore stream messages that don't match the expected format.
-
+
+            # This is the case of midstream errors which were introduced specifically for structured output.
+            # TODO: discuss during code review
+            if parsed_resp.get("error"):
+                yield json.dumps(parsed_resp)
+            else:
+                pass
 
 
 def _complete_call_sql_function_snowpark(
@@ -291,6 +342,8 @@ def _complete_non_streaming_impl(
         raise ValueError("'model' cannot be a snowpark.Column when 'prompt' is a string.")
     if isinstance(options, snowpark.Column):
         raise ValueError("'options' cannot be a snowpark.Column when 'prompt' is a string.")
+    if options and not isinstance(options, snowpark.Column):
+        _validate_response_format_object(options)
     return _complete_non_streaming_immediate(
         snow_api_xp_request_handler=snow_api_xp_request_handler,
         model=model,
@@ -309,6 +362,8 @@ def _complete_rest(
     session: Optional[snowpark.Session] = None,
     deadline: Optional[float] = None,
 ) -> Iterator[str]:
+    if options:
+        _validate_response_format_object(options)
     if snow_api_xp_request_handler is not None:
         response = _call_complete_xp(
             snow_api_xp_request_handler=snow_api_xp_request_handler,
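Together, these changes let a caller ask COMPLETE for structured JSON output. A minimal usage sketch; the model name and the JSON schema here are illustrative, not taken from this diff:

from snowflake.cortex import Complete, CompleteOptions

options = CompleteOptions(
    response_format={
        "type": "json",
        "schema": {  # any valid JSON schema object
            "type": "object",
            "properties": {"sentiment": {"type": "string"}},
            "required": ["sentiment"],
        },
    }
)

# Assumes an active Snowflake connection/Snowpark session is available.
result = Complete(
    model="mistral-large2",
    prompt="Classify the sentiment of this review: 'Great update!'",
    options=options,
)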
snowflake/ml/_internal/env_utils.py
CHANGED
@@ -12,7 +12,7 @@ import yaml
 from packaging import requirements, specifiers, version
 
 import snowflake.connector
-from snowflake.ml._internal import env as snowml_env
+from snowflake.ml._internal import env as snowml_env, relax_version_strategy
 from snowflake.ml._internal.utils import query_result_checker
 from snowflake.snowpark import context, exceptions, session
 
@@ -56,6 +56,8 @@ def _validate_pip_requirement_string(req_str: str) -> requirements.Requirement:
 
         if r.name == "python":
             raise ValueError("Don't specify python as a dependency, use python version argument instead.")
+        if r.name == "cuda":
+            raise ValueError("Don't specify cuda as a dependency, use cuda version argument instead.")
     except requirements.InvalidRequirement:
         raise ValueError(f"Invalid package requirement {req_str} found.")
 
@@ -313,19 +315,14 @@ def get_package_spec_with_supported_ops_only(req: requirements.Requirement) -> r
     return new_req
 
 
-def
-
-
-
-
-
-
-    Returns:
-        A new requirement object after relaxations.
-    """
-    new_req = copy.deepcopy(req)
+def _relax_specifier_set(
+    specifier_set: specifiers.SpecifierSet, strategy: relax_version_strategy.RelaxVersionStrategy
+) -> specifiers.SpecifierSet:
+    if strategy == relax_version_strategy.RelaxVersionStrategy.NO_RELAX:
+        return specifier_set
+    specifier_set = copy.deepcopy(specifier_set)
     relaxed_specifier_set = set()
-    for spec in
+    for spec in specifier_set._specs:
         if spec.operator != "==":
             relaxed_specifier_set.add(spec)
             continue
@@ -337,9 +334,40 @@ def relax_requirement_version(req: requirements.Requirement) -> requirements.Req
             relaxed_specifier_set.add(spec)
             continue
         assert pinned_version is not None
-
-
-
+        if strategy == relax_version_strategy.RelaxVersionStrategy.PATCH:
+            relaxed_specifier_set.add(specifiers.Specifier(f">={pinned_version.major}.{pinned_version.minor}"))
+            relaxed_specifier_set.add(specifiers.Specifier(f"<{pinned_version.major}.{pinned_version.minor+1}"))
+        elif strategy == relax_version_strategy.RelaxVersionStrategy.MINOR:
+            relaxed_specifier_set.add(specifiers.Specifier(f">={pinned_version.major}.{pinned_version.minor}"))
+            relaxed_specifier_set.add(specifiers.Specifier(f"<{pinned_version.major + 1}"))
+        elif strategy == relax_version_strategy.RelaxVersionStrategy.MAJOR:
+            relaxed_specifier_set.add(specifiers.Specifier(f">={pinned_version.major}"))
+            relaxed_specifier_set.add(specifiers.Specifier(f"<{pinned_version.major + 1}"))
+    specifier_set._specs = frozenset(relaxed_specifier_set)
+    return specifier_set
+
+
+def relax_requirement_version(req: requirements.Requirement) -> requirements.Requirement:
+    """Relax version specifier from a requirement. It detects any ==x.y.z in specifiers and replaced with relaxed
+    version specifier based on the strategy defined in RELAX_VERSION_STRATEGY_MAP.
+
+    NO_RELAX: No relaxation.
+    PATCH: >=x.y, <x.(y+1)
+    MINOR (default): >=x.y, <(x+1)
+    MAJOR: >=x, <(x+1)
+
+
+    Args:
+        req: The requirement that version specifier to be removed.
+
+    Returns:
+        A new requirement object after relaxations.
+    """
+    new_req = copy.deepcopy(req)
+    strategy = relax_version_strategy.RELAX_VERSION_STRATEGY_MAP.get(
+        req.name, relax_version_strategy.RelaxVersionStrategy.MINOR
+    )
+    new_req.specifier = _relax_specifier_set(new_req.specifier, strategy)
     return new_req
 
 
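As a rough illustration of what the relaxation now produces for a pinned requirement (the xgboost pin is an arbitrary example; only cloudpickle and scikit-learn have entries in RELAX_VERSION_STRATEGY_MAP):

from packaging import requirements

from snowflake.ml._internal import env_utils

# Default (MINOR): ==2.0.3 relaxes to >=2.0, <3
env_utils.relax_requirement_version(requirements.Requirement("xgboost==2.0.3"))

# scikit-learn is mapped to PATCH: ==1.5.1 relaxes to >=1.5, <1.6
env_utils.relax_requirement_version(requirements.Requirement("scikit-learn==1.5.1"))

# cloudpickle is mapped to NO_RELAX: the == pin is kept unchanged
env_utils.relax_requirement_version(requirements.Requirement("cloudpickle==2.2.1"))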
@@ -431,10 +459,11 @@ save_conda_env_file(
     path: pathlib.Path,
     conda_chan_deps: DefaultDict[str, List[requirements.Requirement]],
     python_version: str,
+    cuda_version: Optional[str] = None,
     default_channel_override: str = SNOWFLAKE_CONDA_CHANNEL_URL,
 ) -> None:
     """Generate conda.yml file given a dict of dependencies after validation.
-    The channels part of conda.yml file will
+    The channels part of conda.yml file will contain Snowflake Anaconda Channel, nodefaults and all channel names
     in keys of the dict, ordered by the number of the packages which belongs to.
     The dependencies part of conda.yml file will contains requirements specifications. If the requirements is in the
     value list whose key is DEFAULT_CHANNEL_NAME, then the channel won't be specified explicitly. Otherwise, it will be
@@ -443,7 +472,8 @@ save_conda_env_file(
     Args:
         path: Path to the conda.yml file.
         conda_chan_deps: Dict of conda dependencies after validated.
-        python_version: A string 'major.minor'
+        python_version: A string 'major.minor' for the model's python version.
+        cuda_version: A string 'major.minor' for the model's cuda version.
         default_channel_override: The default channel to be put in the first place of the channels section.
     """
     assert path.suffix in [".yml", ".yaml"], "Conda environment file should have extension of yml or yaml."
@@ -461,6 +491,10 @@ save_conda_env_file(
 
     env["channels"] = [default_channel_override] + channels + [_NODEFAULTS]
     env["dependencies"] = [f"python=={python_version}.*"]
+
+    if cuda_version is not None:
+        env["dependencies"].extend([f"nvidia::cuda=={cuda_version}.*"])
+
     for chan, reqs in conda_chan_deps.items():
         env["dependencies"].extend(
             [f"{chan}::{str(req)}" if chan != DEFAULT_CHANNEL_NAME else str(req) for req in reqs]
@@ -487,7 +521,12 @@ def save_requirements_file(path: pathlib.Path, pip_deps: List[requirements.Requi
 
 def load_conda_env_file(
     path: pathlib.Path,
-) -> Tuple[
+) -> Tuple[
+    DefaultDict[str, List[requirements.Requirement]],
+    Optional[List[requirements.Requirement]],
+    Optional[str],
+    Optional[str],
+]:
     """Read conda.yml file to get a dict of dependencies after validation.
     The channels part of conda.yml file will be processed with following rules:
     1. If it is Snowflake Anaconda Channel, ignore as it is default.
@@ -515,7 +554,7 @@ load_conda_env_file(
     and a string 'major.minor.patchlevel' of python version.
     """
     if not path.exists():
-        return collections.defaultdict(list), None, None
+        return collections.defaultdict(list), None, None, None
 
     with open(path, encoding="utf-8") as f:
         env = yaml.safe_load(stream=f)
@@ -526,6 +565,7 @@ load_conda_env_file(
     pip_deps = []
 
     python_version = None
+    cuda_version = None
 
     channels = env.get("channels", [])
     if len(channels) >= 1:
@@ -541,6 +581,9 @@ load_conda_env_file(
             # ver is str: python w/ specifier
             if ver:
                 python_version = ver
+            elif dep.startswith("nvidia::cuda"):
+                r = requirements.Requirement(dep.split("nvidia::")[1])
+                cuda_version = list(r.specifier)[0].version.strip(".*")
             elif ver is None:
                 deps.append(dep)
         elif isinstance(dep, dict) and "pip" in dep:
@@ -555,7 +598,7 @@ load_conda_env_file(
         if channel not in conda_dep_dict:
             conda_dep_dict[channel] = []
 
-    return conda_dep_dict, pip_deps_list if pip_deps_list else None, python_version
+    return conda_dep_dict, pip_deps_list if pip_deps_list else None, python_version, cuda_version
 
 
 def load_requirements_file(path: pathlib.Path) -> List[requirements.Requirement]:
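A sketch of the new CUDA round trip through these two functions; the path and version numbers are illustrative:

import collections
import pathlib

from snowflake.ml._internal import env_utils

env_path = pathlib.Path("conda.yml")
env_utils.save_conda_env_file(
    path=env_path,
    conda_chan_deps=collections.defaultdict(list),
    python_version="3.10",
    cuda_version="12.4",  # written out as "nvidia::cuda==12.4.*"
)

# load_conda_env_file now returns a fourth element, the recovered cuda version.
conda_deps, pip_deps, python_version, cuda_version = env_utils.load_conda_env_file(env_path)
assert cuda_version == "12.4"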
snowflake/ml/_internal/file_utils.py
CHANGED
@@ -23,6 +23,7 @@ from typing import (
     Tuple,
     Union,
 )
+from urllib import parse
 
 import cloudpickle
 
@@ -294,7 +295,7 @@ def _retry_on_sql_error(exception: Exception) -> bool:
 def upload_directory_to_stage(
     session: snowpark.Session,
     local_path: pathlib.Path,
-    stage_path: pathlib.PurePosixPath,
+    stage_path: Union[pathlib.PurePosixPath, parse.ParseResult],
     *,
     statement_params: Optional[Dict[str, Any]] = None,
 ) -> None:
@@ -314,9 +315,22 @@
         root_path = pathlib.Path(root)
         for filename in filenames:
             local_file_path = root_path / filename
-
-
-            )
+            relative_path = pathlib.PurePosixPath(local_file_path.relative_to(local_path).as_posix())
+
+            if isinstance(stage_path, parse.ParseResult):
+                relative_stage_path = (pathlib.PosixPath(stage_path.path) / relative_path).parent
+                new_url = parse.ParseResult(
+                    scheme=stage_path.scheme,
+                    netloc=stage_path.netloc,
+                    path=str(relative_stage_path),
+                    params=stage_path.params,
+                    query=stage_path.query,
+                    fragment=stage_path.fragment,
+                )
+                stage_dir_path = parse.urlunparse(new_url)
+            else:
+                stage_dir_path = str((stage_path / relative_path).parent)
+
             retrying.retry(
                 retry_on_exception=_retry_on_sql_error,
                 stop_max_attempt_number=5,
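With this change, stage_path may be a plain stage path or a pre-parsed URL; a sketch, where the stage name and the snow:// URL are illustrative and session is assumed to be an existing snowpark.Session:

import pathlib
from urllib import parse

from snowflake.ml._internal import file_utils

# Plain stage path, as before:
file_utils.upload_directory_to_stage(
    session, pathlib.Path("/tmp/artifacts"), pathlib.PurePosixPath("@MY_STAGE/model")
)

# URL form: scheme, netloc, query, and fragment are preserved while the path
# component is rewritten per uploaded file:
file_utils.upload_directory_to_stage(
    session, pathlib.Path("/tmp/artifacts"), parse.urlparse("snow://model/MY_MODEL/versions/V1")
)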
snowflake/ml/_internal/platform_capabilities.py
CHANGED
@@ -37,6 +37,9 @@ class PlatformCapabilities:
     def is_nested_function_enabled(self) -> bool:
         return self._get_bool_feature("SPCS_MODEL_ENABLE_EMBEDDED_SERVICE_FUNCTIONS", False)
 
+    def is_live_commit_enabled(self) -> bool:
+        return self._get_bool_feature("ENABLE_BUNDLE_MODULE_CHECKOUT", False)
+
     @staticmethod
     def _get_features(session: snowpark_session.Session) -> Dict[str, Any]:
         try:
snowflake/ml/_internal/relax_version_strategy.py
ADDED
@@ -0,0 +1,16 @@
+from enum import Enum
+
+
+class RelaxVersionStrategy(Enum):
+    NO_RELAX = "no_relax"
+    PATCH = "patch"
+    MINOR = "minor"
+    MAJOR = "major"
+
+
+RELAX_VERSION_STRATEGY_MAP = {
+    # The version of cloudpickle should not be relaxed as it is used for serialization.
+    "cloudpickle": RelaxVersionStrategy.NO_RELAX,
+    # The version of scikit-learn should be relaxed only in patch version as it has breaking changes in minor version.
+    "scikit-learn": RelaxVersionStrategy.PATCH,
+}
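Consumers look a package up in the map and fall back to MINOR when it has no entry, which is exactly how env_utils.relax_requirement_version uses it:

from snowflake.ml._internal.relax_version_strategy import (
    RELAX_VERSION_STRATEGY_MAP,
    RelaxVersionStrategy,
)

strategy = RELAX_VERSION_STRATEGY_MAP.get("xgboost", RelaxVersionStrategy.MINOR)
assert strategy is RelaxVersionStrategy.MINOR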
snowflake/ml/_internal/telemetry.py
CHANGED
@@ -4,6 +4,9 @@ import enum
 import functools
 import inspect
 import operator
+import sys
+import time
+import traceback
 import types
 from typing import (
     Any,
@@ -75,6 +78,8 @@ class TelemetryField(enum.Enum):
     KEY_FUNC_PARAMS = "func_params"
     KEY_ERROR_INFO = "error_info"
     KEY_ERROR_CODE = "error_code"
+    KEY_STACK_TRACE = "stack_trace"
+    KEY_DURATION = "duration"
     KEY_VERSION = "version"
     KEY_PYTHON_VERSION = "python_version"
     KEY_OS = "operating_system"
@@ -348,6 +353,10 @@ def get_function_usage_statement_params(
         statement_params[TelemetryField.KEY_API_CALLS.value].append({TelemetryField.NAME.value: api_call})
     if custom_tags:
         statement_params[TelemetryField.KEY_CUSTOM_TAGS.value] = custom_tags
+    # Snowpark doesn't support None value in statement_params from version 1.29
+    for k in statement_params:
+        if statement_params[k] is None:
+            statement_params[k] = ""
     return statement_params
 
 
@@ -435,6 +444,7 @@ def send_api_usage_telemetry(
 
     # noqa: DAR402
     """
+    start_time = time.perf_counter()
 
    if subproject is not None and subproject_extractor is not None:
        raise ValueError("Specifying both subproject and subproject_extractor is not allowed")
@@ -555,8 +565,16 @@ def send_api_usage_telemetry(
                 )
             else:
                 me = e
+
             telemetry_args["error"] = repr(me)
             telemetry_args["error_code"] = me.error_code
+            # exclude telemetry frames
+            excluded_frames = 2
+            tb = traceback.extract_tb(sys.exc_info()[2])
+            formatted_tb = "".join(traceback.format_list(tb[excluded_frames:]))
+            formatted_exception = traceback.format_exception_only(*sys.exc_info()[:2])[0]  # error type + message
+            telemetry_args["stack_trace"] = formatted_tb + formatted_exception
+
             me.original_exception._snowflake_ml_handled = True  # type: ignore[attr-defined]
             if e is not me:
                 raise  # Directly raise non-wrapped exceptions to preserve original stacktrace
@@ -565,6 +583,7 @@ def send_api_usage_telemetry(
             else:
                 raise me.original_exception from e
         finally:
+            telemetry_args["duration"] = time.perf_counter() - start_time  # type: ignore[assignment]
             telemetry.send_function_usage_telemetry(**telemetry_args)
             global _log_counter
             _log_counter += 1
@@ -718,12 +737,14 @@ class _SourceTelemetryClient:
         self,
         func_name: str,
         function_category: str,
+        duration: float,
         func_params: Optional[Dict[str, Any]] = None,
         api_calls: Optional[List[Dict[str, Any]]] = None,
         sfqids: Optional[List[Any]] = None,
         custom_tags: Optional[Dict[str, Union[bool, int, str, float]]] = None,
         error: Optional[str] = None,
         error_code: Optional[str] = None,
+        stack_trace: Optional[str] = None,
     ) -> None:
         """
         Send function usage telemetry message.
@@ -731,12 +752,14 @@ class _SourceTelemetryClient:
         Args:
             func_name: Function name.
             function_category: Function category.
+            duration: Function duration.
             func_params: Function parameters.
             api_calls: API calls.
             sfqids: Snowflake query IDs.
             custom_tags: Custom tags.
             error: Error.
             error_code: Error code.
+            stack_trace: Error stack trace.
         """
         data: Dict[str, Any] = {
             TelemetryField.KEY_FUNC_NAME.value: func_name,
@@ -755,11 +778,13 @@ class _SourceTelemetryClient:
         message: Dict[str, Any] = {
             **self._create_basic_telemetry_data(telemetry_type),
             TelemetryField.KEY_DATA.value: data,
+            TelemetryField.KEY_DURATION.value: duration,
         }
 
         if error:
             message[TelemetryField.KEY_ERROR_INFO.value] = error
             message[TelemetryField.KEY_ERROR_CODE.value] = error_code
+            message[TelemetryField.KEY_STACK_TRACE.value] = stack_trace
 
         self._send(message)
 
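The stack-trace capture added above is plain standard-library traceback handling. The same pattern in isolation; excluded_frames is 0 here so the toy frame survives, while the decorator uses 2 to drop its own wrapper frames:

import sys
import traceback


def format_active_exception(excluded_frames: int = 0) -> str:
    """Format the currently handled exception, dropping the outermost frames."""
    exc_type, exc_value, exc_tb = sys.exc_info()
    frames = traceback.extract_tb(exc_tb)
    formatted_tb = "".join(traceback.format_list(frames[excluded_frames:]))
    formatted_exception = traceback.format_exception_only(exc_type, exc_value)[0]
    return formatted_tb + formatted_exception


try:
    raise ValueError("boom")
except ValueError:
    print(format_active_exception())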
snowflake/ml/data/_internal/arrow_ingestor.py
CHANGED
@@ -116,7 +116,7 @@ class ArrowIngestor(data_ingestor.DataIngestor):
     def to_pandas(self, limit: Optional[int] = None) -> pd.DataFrame:
         ds = self._get_dataset(shuffle=False)
         table = ds.to_table() if limit is None else ds.head(num_rows=limit)
-        return table.to_pandas()
+        return table.to_pandas(split_blocks=True, self_destruct=True)
 
     def _get_dataset(self, shuffle: bool) -> pds.Dataset:
         format = self._format
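split_blocks and self_destruct are standard pyarrow Table.to_pandas options; they let Arrow release column buffers during conversion, lowering peak memory, at the cost that the table must not be reused afterwards. For example:

import pyarrow as pa

table = pa.table({"a": [1, 2, 3], "b": ["x", "y", "z"]})
df = table.to_pandas(split_blocks=True, self_destruct=True)
# `table` is consumed at this point; only `df` should be used from here on.
print(df)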
snowflake/ml/feature_store/feature_store.py
CHANGED
@@ -144,6 +144,7 @@ _LIST_FEATURE_VIEW_SCHEMA = StructType(
         StructField("refresh_mode", StringType()),
         StructField("scheduling_state", StringType()),
         StructField("warehouse", StringType()),
+        StructField("cluster_by", StringType()),
     ]
 )
 
@@ -1832,6 +1833,12 @@ class FeatureStore:
             WAREHOUSE = {warehouse}
             REFRESH_MODE = {feature_view.refresh_mode}
             INITIALIZE = {feature_view.initialize}
+            """
+            if feature_view.cluster_by:
+                cluster_by_clause = f"CLUSTER BY ({', '.join(feature_view.cluster_by)})"
+                query += f"{cluster_by_clause}"
+
+            query += f"""
             AS {feature_view.query}
             """
             self._session.sql(query).collect(block=block, statement_params=self._telemetry_stmp)
@@ -2249,6 +2256,7 @@ class FeatureStore:
             values.append(row["refresh_mode"] if "refresh_mode" in row else None)
             values.append(row["scheduling_state"] if "scheduling_state" in row else None)
             values.append(row["warehouse"] if "warehouse" in row else None)
+            values.append(json.dumps(self._extract_cluster_by_columns(row["cluster_by"])) if "cluster_by" in row else None)
             output_values.append(values)
 
     def _lookup_feature_view_metadata(self, row: Row, fv_name: str) -> Tuple[_FeatureViewMetadata, str]:
@@ -2335,6 +2343,7 @@ class FeatureStore:
                 owner=row["owner"],
                 infer_schema_df=infer_schema_df,
                 session=self._session,
+                cluster_by=self._extract_cluster_by_columns(row["cluster_by"]),
             )
             return fv
         else:
@@ -2625,3 +2634,12 @@ class FeatureStore:
         )
 
         return feature_view
+
+    @staticmethod
+    def _extract_cluster_by_columns(cluster_by_clause: str) -> List[str]:
+        # Use regex to extract elements inside the parentheses.
+        match = re.search(r"\((.*?)\)", cluster_by_clause)
+        if match:
+            # Handle both quoted and unquoted column names.
+            return re.findall(identifier.SF_IDENTIFIER_RE, match.group(1))
+        return []
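The new helper pulls the parenthesized column list out of the clustering-key text returned by SHOW (e.g. 'LINEAR("ZIPCODE", TS)'). A standalone approximation, with a simplified pattern standing in for identifier.SF_IDENTIFIER_RE:

import re
from typing import List

_IDENTIFIER_RE = r'"[^"]+"|[A-Za-z_][A-Za-z0-9_$]*'  # simplified stand-in


def extract_cluster_by_columns(cluster_by_clause: str) -> List[str]:
    match = re.search(r"\((.*?)\)", cluster_by_clause)
    if match:
        return re.findall(_IDENTIFIER_RE, match.group(1))
    return []


extract_cluster_by_columns('LINEAR("ZIPCODE", TS)')  # ['"ZIPCODE"', 'TS']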
snowflake/ml/feature_store/feature_view.py
CHANGED
@@ -170,6 +170,7 @@ class FeatureView(lineage_node.LineageNode):
         warehouse: Optional[str] = None,
         initialize: str = "ON_CREATE",
         refresh_mode: str = "AUTO",
+        cluster_by: Optional[List[str]] = None,
         **_kwargs: Any,
     ) -> None:
         """
@@ -200,6 +201,9 @@ class FeatureView(lineage_node.LineageNode):
             refresh_mode: The refresh mode of managed feature view. The value can be 'AUTO', 'FULL' or 'INCREMENETAL'.
                 For managed feature view, the default value is 'AUTO'. For static feature view it has no effect.
                 Check https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table for for details.
+            cluster_by: Columns to cluster the feature view by.
+                - Defaults to the join keys from entities.
+                - If `timestamp_col` is provided, it is added to the default clustering keys.
             _kwargs: reserved kwargs for system generated args. NOTE: DO NOT USE.
 
         Example::
@@ -224,6 +228,7 @@ class FeatureView(lineage_node.LineageNode):
         >>> print(registered_fv.status)
         FeatureViewStatus.ACTIVE
 
+        # noqa: DAR401
         """
 
         self._name: SqlIdentifier = SqlIdentifier(name)
@@ -233,7 +238,7 @@ class FeatureView(lineage_node.LineageNode):
             SqlIdentifier(timestamp_col) if timestamp_col is not None else None
         )
         self._desc: str = desc
-        self._infer_schema_df: DataFrame = _kwargs.
+        self._infer_schema_df: DataFrame = _kwargs.pop("_infer_schema_df", self._feature_df)
         self._query: str = self._get_query()
         self._version: Optional[FeatureViewVersion] = None
         self._status: FeatureViewStatus = FeatureViewStatus.DRAFT
@@ -249,6 +254,14 @@ class FeatureView(lineage_node.LineageNode):
         self._refresh_mode: Optional[str] = refresh_mode
         self._refresh_mode_reason: Optional[str] = None
         self._owner: Optional[str] = None
+        self._cluster_by: List[SqlIdentifier] = (
+            [SqlIdentifier(col) for col in cluster_by] if cluster_by is not None else self._get_default_cluster_by()
+        )
+
+        # Validate kwargs
+        if _kwargs:
+            raise TypeError(f"FeatureView.__init__ got an unexpected keyword argument: '{next(iter(_kwargs.keys()))}'")
+
         self._validate()
 
     def slice(self, names: List[str]) -> FeatureViewSlice:
@@ -394,6 +407,10 @@ class FeatureView(lineage_node.LineageNode):
     def timestamp_col(self) -> Optional[SqlIdentifier]:
         return self._timestamp_col
 
+    @property
+    def cluster_by(self) -> Optional[List[SqlIdentifier]]:
+        return self._cluster_by
+
     @property
     def desc(self) -> str:
         return self._desc
@@ -656,6 +673,14 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
             if not isinstance(col_type, (DateType, TimeType, TimestampType, _NumericType)):
                 raise ValueError(f"Invalid data type for timestamp_col {ts_col}: {col_type}.")
 
+        if self.cluster_by is not None:
+            for column in self.cluster_by:
+                if column not in df_cols:
+                    raise ValueError(
+                        f"Column '{column}' in `cluster_by` is not in the feature DataFrame schema. "
+                        f"{df_cols}, {self.cluster_by}"
+                    )
+
         if re.match(_RESULT_SCAN_QUERY_PATTERN, self._query) is not None:
             raise ValueError(f"feature_df should not be reading from RESULT_SCAN. Invalid query: {self._query}")
 
@@ -890,6 +915,7 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
         owner: Optional[str],
         infer_schema_df: Optional[DataFrame],
         session: Session,
+        cluster_by: Optional[List[str]] = None,
     ) -> FeatureView:
         fv = FeatureView(
             name=name,
@@ -898,6 +924,7 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
             timestamp_col=timestamp_col,
             desc=desc,
             _infer_schema_df=infer_schema_df,
+            cluster_by=cluster_by,
         )
         fv._version = FeatureViewVersion(version) if version is not None else None
         fv._status = status
@@ -916,5 +943,23 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
         )
         return fv
 
+    #
+    def _get_default_cluster_by(self) -> List[SqlIdentifier]:
+        """
+        Get default columns to cluster the feature view by.
+        Default cluster_by columns are join keys from entities and timestamp_col if it exists
+
+        Returns:
+            List of SqlIdentifiers representing the default columns to cluster the feature view by.
+        """
+        # We don't focus on the order of entities here, as users can define a custom 'cluster_by'
+        # if a specific order is required.
+        default_cluster_by_cols = [key for entity in self.entities if entity.join_keys for key in entity.join_keys]
+
+        if self.timestamp_col:
+            default_cluster_by_cols.append(self.timestamp_col)
+
+        return default_cluster_by_cols
+
 
 lineage_node.DOMAIN_LINEAGE_REGISTRY["feature_view"] = FeatureView
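At the public API level, the new parameter looks like this; entity and DataFrame construction are elided, and the names and refresh schedule are illustrative:

from snowflake.ml.feature_store import FeatureView

# customer_entity: a registered Entity; features_df: a Snowpark DataFrame (both assumed).
fv = FeatureView(
    name="CUSTOMER_FV",
    entities=[customer_entity],
    feature_df=features_df,
    timestamp_col="TS",
    refresh_freq="1 day",
    cluster_by=["CUSTOMER_ID", "TS"],  # overrides the default of join keys plus timestamp_col
)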
snowflake/ml/fileset/fileset.py
CHANGED
@@ -257,7 +257,6 @@ class FileSet:
                     function_name=telemetry.get_statement_params_full_func_name(
                         inspect.currentframe(), cls.__class__.__name__
                     ),
-                    api_calls=[snowpark.DataFrameWriter.copy_into_location],
                 ),
             )
         except snowpark_exceptions.SnowparkSQLException as e:
|