snowflake-ml-python 1.7.4__py3-none-any.whl → 1.7.5__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear there.
- snowflake/ml/_internal/env_utils.py +64 -21
- snowflake/ml/_internal/relax_version_strategy.py +16 -0
- snowflake/ml/_internal/telemetry.py +21 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +1 -1
- snowflake/ml/feature_store/feature_store.py +18 -0
- snowflake/ml/feature_store/feature_view.py +46 -1
- snowflake/ml/jobs/_utils/constants.py +7 -1
- snowflake/ml/jobs/_utils/payload_utils.py +139 -53
- snowflake/ml/jobs/_utils/spec_utils.py +5 -7
- snowflake/ml/jobs/decorators.py +5 -25
- snowflake/ml/jobs/job.py +4 -4
- snowflake/ml/model/_packager/model_env/model_env.py +45 -28
- snowflake/ml/model/_packager/model_handlers/_utils.py +8 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +16 -0
- snowflake/ml/model/_packager/model_handlers/keras.py +230 -0
- snowflake/ml/model/_packager/model_handlers/pytorch.py +1 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +28 -3
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +74 -21
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +27 -49
- snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2023_12_01.py +48 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +3 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -1
- snowflake/ml/model/_packager/model_task/model_task_utils.py +5 -1
- snowflake/ml/model/_signatures/core.py +2 -2
- snowflake/ml/model/_signatures/numpy_handler.py +5 -5
- snowflake/ml/model/_signatures/pandas_handler.py +9 -7
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/model_signature.py +8 -0
- snowflake/ml/model/type_hints.py +15 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +14 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +2 -2
- snowflake/ml/registry/registry.py +34 -4
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.7.5.dist-info}/METADATA +58 -25
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.7.5.dist-info}/RECORD +41 -38
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.7.5.dist-info}/WHEEL +1 -1
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.7.5.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.7.5.dist-info}/top_level.txt +0 -0
snowflake/ml/_internal/env_utils.py

```diff
@@ -12,7 +12,7 @@ import yaml
 from packaging import requirements, specifiers, version
 
 import snowflake.connector
-from snowflake.ml._internal import env as snowml_env
+from snowflake.ml._internal import env as snowml_env, relax_version_strategy
 from snowflake.ml._internal.utils import query_result_checker
 from snowflake.snowpark import context, exceptions, session
 
```
```diff
@@ -56,6 +56,8 @@ def _validate_pip_requirement_string(req_str: str) -> requirements.Requirement:
 
         if r.name == "python":
             raise ValueError("Don't specify python as a dependency, use python version argument instead.")
+        if r.name == "cuda":
+            raise ValueError("Don't specify cuda as a dependency, use cuda version argument instead.")
     except requirements.InvalidRequirement:
         raise ValueError(f"Invalid package requirement {req_str} found.")
 
```
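The new `cuda` guard mirrors the existing `python` guard. A quick sanity check of the behavior (a sketch; `_validate_pip_requirement_string` is an internal helper, so its import path is not a stable API):

```python
from snowflake.ml._internal import env_utils

try:
    env_utils._validate_pip_requirement_string("cuda==12.1")
except ValueError as e:
    print(e)  # Don't specify cuda as a dependency, use cuda version argument instead.
```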
```diff
@@ -313,19 +315,14 @@ def get_package_spec_with_supported_ops_only(req: requirements.Requirement) -> requirements.Requirement:
     return new_req
 
 
-def relax_requirement_version(req: requirements.Requirement) -> requirements.Requirement:
-    """…
-    Returns:
-        A new requirement object after relaxations.
-    """
-    new_req = copy.deepcopy(req)
+def _relax_specifier_set(
+    specifier_set: specifiers.SpecifierSet, strategy: relax_version_strategy.RelaxVersionStrategy
+) -> specifiers.SpecifierSet:
+    if strategy == relax_version_strategy.RelaxVersionStrategy.NO_RELAX:
+        return specifier_set
+    specifier_set = copy.deepcopy(specifier_set)
     relaxed_specifier_set = set()
-    for spec in new_req.specifier._specs:
+    for spec in specifier_set._specs:
         if spec.operator != "==":
             relaxed_specifier_set.add(spec)
             continue
```
```diff
@@ -337,9 +334,40 @@ def relax_requirement_version(req: requirements.Requirement) -> requirements.Requirement:
             relaxed_specifier_set.add(spec)
             continue
         assert pinned_version is not None
-        …
+        if strategy == relax_version_strategy.RelaxVersionStrategy.PATCH:
+            relaxed_specifier_set.add(specifiers.Specifier(f">={pinned_version.major}.{pinned_version.minor}"))
+            relaxed_specifier_set.add(specifiers.Specifier(f"<{pinned_version.major}.{pinned_version.minor+1}"))
+        elif strategy == relax_version_strategy.RelaxVersionStrategy.MINOR:
+            relaxed_specifier_set.add(specifiers.Specifier(f">={pinned_version.major}.{pinned_version.minor}"))
+            relaxed_specifier_set.add(specifiers.Specifier(f"<{pinned_version.major + 1}"))
+        elif strategy == relax_version_strategy.RelaxVersionStrategy.MAJOR:
+            relaxed_specifier_set.add(specifiers.Specifier(f">={pinned_version.major}"))
+            relaxed_specifier_set.add(specifiers.Specifier(f"<{pinned_version.major + 1}"))
+    specifier_set._specs = frozenset(relaxed_specifier_set)
+    return specifier_set
+
+
+def relax_requirement_version(req: requirements.Requirement) -> requirements.Requirement:
+    """Relax version specifier from a requirement. It detects any ==x.y.z in specifiers and replaced with relaxed
+    version specifier based on the strategy defined in RELAX_VERSION_STRATEGY_MAP.
+
+    NO_RELAX: No relaxation.
+    PATCH: >=x.y, <x.(y+1)
+    MINOR (default): >=x.y, <(x+1)
+    MAJOR: >=x, <(x+1)
+
+    Args:
+        req: The requirement that version specifier to be removed.
+
+    Returns:
+        A new requirement object after relaxations.
+    """
+    new_req = copy.deepcopy(req)
+    strategy = relax_version_strategy.RELAX_VERSION_STRATEGY_MAP.get(
+        req.name, relax_version_strategy.RelaxVersionStrategy.MINOR
+    )
+    new_req.specifier = _relax_specifier_set(new_req.specifier, strategy)
     return new_req
 
```
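To see what the strategies do end to end, here is a hedged sketch of `relax_requirement_version` applied to pinned requirements, assuming the mapping from the new `relax_version_strategy` module shown further below (specifier ordering in the printed output may differ):

```python
from packaging import requirements

from snowflake.ml._internal import env_utils

# cloudpickle -> NO_RELAX: the exact pin is preserved.
print(env_utils.relax_requirement_version(requirements.Requirement("cloudpickle==2.2.1")))
# cloudpickle==2.2.1

# scikit-learn -> PATCH: ==1.5.1 becomes >=1.5, <1.6.
print(env_utils.relax_requirement_version(requirements.Requirement("scikit-learn==1.5.1")))
# scikit-learn<1.6,>=1.5

# Unmapped packages -> MINOR (the default): ==1.26.4 becomes >=1.26, <2.
print(env_utils.relax_requirement_version(requirements.Requirement("numpy==1.26.4")))
# numpy<2,>=1.26
```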
```diff
@@ -431,10 +459,11 @@ def save_conda_env_file(
     path: pathlib.Path,
     conda_chan_deps: DefaultDict[str, List[requirements.Requirement]],
     python_version: str,
+    cuda_version: Optional[str] = None,
     default_channel_override: str = SNOWFLAKE_CONDA_CHANNEL_URL,
 ) -> None:
     """Generate conda.yml file given a dict of dependencies after validation.
-    The channels part of conda.yml file will …
+    The channels part of conda.yml file will contain Snowflake Anaconda Channel, nodefaults and all channel names
     in keys of the dict, ordered by the number of the packages which belongs to.
     The dependencies part of conda.yml file will contains requirements specifications. If the requirements is in the
     value list whose key is DEFAULT_CHANNEL_NAME, then the channel won't be specified explicitly. Otherwise, it will be
```
```diff
@@ -443,7 +472,8 @@ def save_conda_env_file(
     Args:
         path: Path to the conda.yml file.
         conda_chan_deps: Dict of conda dependencies after validated.
-        python_version: A string 'major.minor'
+        python_version: A string 'major.minor' for the model's python version.
+        cuda_version: A string 'major.minor' for the model's cuda version.
         default_channel_override: The default channel to be put in the first place of the channels section.
     """
     assert path.suffix in [".yml", ".yaml"], "Conda environment file should have extension of yml or yaml."
```
```diff
@@ -461,6 +491,10 @@ def save_conda_env_file(
 
     env["channels"] = [default_channel_override] + channels + [_NODEFAULTS]
     env["dependencies"] = [f"python=={python_version}.*"]
+
+    if cuda_version is not None:
+        env["dependencies"].extend([f"nvidia::cuda=={cuda_version}.*"])
+
     for chan, reqs in conda_chan_deps.items():
         env["dependencies"].extend(
             [f"{chan}::{str(req)}" if chan != DEFAULT_CHANNEL_NAME else str(req) for req in reqs]
```
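For illustration, with `python_version="3.10"` and `cuda_version="12.1"` the `env` mapping assembled above would be shaped roughly like this before being written out as YAML (a sketch; the channel URL and package entries are illustrative):

```python
env = {
    "channels": ["https://repo.anaconda.com/pkgs/snowflake", "conda-forge", "nodefaults"],
    "dependencies": [
        "python==3.10.*",
        "nvidia::cuda==12.1.*",          # added only when cuda_version is not None
        "scikit-learn==1.5.1",           # DEFAULT_CHANNEL_NAME entries stay unqualified
        "conda-forge::some-pkg==1.2.3",  # other channels get a "channel::" prefix
    ],
}
```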
```diff
@@ -487,7 +521,12 @@ def save_requirements_file(path: pathlib.Path, pip_deps: List[requirements.Requirement]) -> None:
 
 def load_conda_env_file(
     path: pathlib.Path,
-) -> Tuple[…
+) -> Tuple[
+    DefaultDict[str, List[requirements.Requirement]],
+    Optional[List[requirements.Requirement]],
+    Optional[str],
+    Optional[str],
+]:
     """Read conda.yml file to get a dict of dependencies after validation.
     The channels part of conda.yml file will be processed with following rules:
     1. If it is Snowflake Anaconda Channel, ignore as it is default.
```
```diff
@@ -515,7 +554,7 @@ def load_conda_env_file(
     and a string 'major.minor.patchlevel' of python version.
     """
     if not path.exists():
-        return collections.defaultdict(list), None, None
+        return collections.defaultdict(list), None, None, None
 
     with open(path, encoding="utf-8") as f:
         env = yaml.safe_load(stream=f)
```
```diff
@@ -526,6 +565,7 @@ def load_conda_env_file(
     pip_deps = []
 
     python_version = None
+    cuda_version = None
 
     channels = env.get("channels", [])
     if len(channels) >= 1:
```
```diff
@@ -541,6 +581,9 @@ def load_conda_env_file(
                 # ver is str: python w/ specifier
                 if ver:
                     python_version = ver
+            elif dep.startswith("nvidia::cuda"):
+                r = requirements.Requirement(dep.split("nvidia::")[1])
+                cuda_version = list(r.specifier)[0].version.strip(".*")
             elif ver is None:
                 deps.append(dep)
         elif isinstance(dep, dict) and "pip" in dep:
```
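The cuda parsing can be exercised standalone. Note that `str.strip(".*")` strips any leading or trailing `.` and `*` characters (not the literal suffix `.*`), which is harmless for well-formed version strings:

```python
from packaging import requirements

dep = "nvidia::cuda==12.1.*"
r = requirements.Requirement(dep.split("nvidia::")[1])
print(list(r.specifier)[0].version.strip(".*"))  # 12.1
```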
```diff
@@ -555,7 +598,7 @@ def load_conda_env_file(
         if channel not in conda_dep_dict:
             conda_dep_dict[channel] = []
 
-    return conda_dep_dict, pip_deps_list if pip_deps_list else None, python_version
+    return conda_dep_dict, pip_deps_list if pip_deps_list else None, python_version, cuda_version
 
 
 def load_requirements_file(path: pathlib.Path) -> List[requirements.Requirement]:
```
snowflake/ml/_internal/relax_version_strategy.py (new file)

```diff
@@ -0,0 +1,16 @@
+from enum import Enum
+
+
+class RelaxVersionStrategy(Enum):
+    NO_RELAX = "no_relax"
+    PATCH = "patch"
+    MINOR = "minor"
+    MAJOR = "major"
+
+
+RELAX_VERSION_STRATEGY_MAP = {
+    # The version of cloudpickle should not be relaxed as it is used for serialization.
+    "cloudpickle": RelaxVersionStrategy.NO_RELAX,
+    # The version of scikit-learn should be relaxed only in patch version as it has breaking changes in minor version.
+    "scikit-learn": RelaxVersionStrategy.PATCH,
+}
```
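Strategies are looked up by distribution name, and anything absent from the map falls back to `MINOR`, matching the default documented in `relax_requirement_version` above (a small sketch):

```python
from snowflake.ml._internal import relax_version_strategy as rvs

strategy = rvs.RELAX_VERSION_STRATEGY_MAP.get("torch", rvs.RelaxVersionStrategy.MINOR)
print(strategy)  # RelaxVersionStrategy.MINOR
```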
snowflake/ml/_internal/telemetry.py

```diff
@@ -4,6 +4,9 @@ import enum
 import functools
 import inspect
 import operator
+import sys
+import time
+import traceback
 import types
 from typing import (
     Any,
```
```diff
@@ -75,6 +78,8 @@ class TelemetryField(enum.Enum):
     KEY_FUNC_PARAMS = "func_params"
     KEY_ERROR_INFO = "error_info"
     KEY_ERROR_CODE = "error_code"
+    KEY_STACK_TRACE = "stack_trace"
+    KEY_DURATION = "duration"
     KEY_VERSION = "version"
     KEY_PYTHON_VERSION = "python_version"
     KEY_OS = "operating_system"
```
```diff
@@ -435,6 +440,7 @@ def send_api_usage_telemetry(
 
     # noqa: DAR402
     """
+    start_time = time.perf_counter()
 
     if subproject is not None and subproject_extractor is not None:
         raise ValueError("Specifying both subproject and subproject_extractor is not allowed")
```
```diff
@@ -555,8 +561,16 @@ def send_api_usage_telemetry(
                 )
             else:
                 me = e
+
             telemetry_args["error"] = repr(me)
             telemetry_args["error_code"] = me.error_code
+            # exclude telemetry frames
+            excluded_frames = 2
+            tb = traceback.extract_tb(sys.exc_info()[2])
+            formatted_tb = "".join(traceback.format_list(tb[excluded_frames:]))
+            formatted_exception = traceback.format_exception_only(*sys.exc_info()[:2])[0]  # error type + message
+            telemetry_args["stack_trace"] = formatted_tb + formatted_exception
+
             me.original_exception._snowflake_ml_handled = True  # type: ignore[attr-defined]
             if e is not me:
                 raise  # Directly raise non-wrapped exceptions to preserve original stacktrace
```
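The trace trimming is plain `traceback` machinery and can be reproduced in isolation; `excluded_frames = 2` drops the telemetry wrapper's own frames so the reported trace starts at user code (set to 0 in this toy example so there are frames left to show):

```python
import sys
import traceback

def format_trimmed_trace(excluded_frames: int) -> str:
    # Same recipe as the diff above: formatted frames after the cut,
    # followed by "ErrorType: message".
    tb = traceback.extract_tb(sys.exc_info()[2])
    formatted_tb = "".join(traceback.format_list(tb[excluded_frames:]))
    formatted_exception = traceback.format_exception_only(*sys.exc_info()[:2])[0]
    return formatted_tb + formatted_exception

try:
    raise ValueError("boom")
except ValueError:
    print(format_trimmed_trace(excluded_frames=0))
```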
```diff
@@ -565,6 +579,7 @@ def send_api_usage_telemetry(
             else:
                 raise me.original_exception from e
         finally:
+            telemetry_args["duration"] = time.perf_counter() - start_time  # type: ignore[assignment]
             telemetry.send_function_usage_telemetry(**telemetry_args)
             global _log_counter
             _log_counter += 1
```
```diff
@@ -718,12 +733,14 @@ class _SourceTelemetryClient:
         self,
         func_name: str,
         function_category: str,
+        duration: float,
         func_params: Optional[Dict[str, Any]] = None,
         api_calls: Optional[List[Dict[str, Any]]] = None,
         sfqids: Optional[List[Any]] = None,
         custom_tags: Optional[Dict[str, Union[bool, int, str, float]]] = None,
         error: Optional[str] = None,
         error_code: Optional[str] = None,
+        stack_trace: Optional[str] = None,
     ) -> None:
         """
         Send function usage telemetry message.
```
```diff
@@ -731,12 +748,14 @@ class _SourceTelemetryClient:
         Args:
             func_name: Function name.
             function_category: Function category.
+            duration: Function duration.
             func_params: Function parameters.
             api_calls: API calls.
             sfqids: Snowflake query IDs.
             custom_tags: Custom tags.
             error: Error.
             error_code: Error code.
+            stack_trace: Error stack trace.
         """
         data: Dict[str, Any] = {
             TelemetryField.KEY_FUNC_NAME.value: func_name,
```
```diff
@@ -755,11 +774,13 @@ class _SourceTelemetryClient:
         message: Dict[str, Any] = {
             **self._create_basic_telemetry_data(telemetry_type),
             TelemetryField.KEY_DATA.value: data,
+            TelemetryField.KEY_DURATION.value: duration,
         }
 
         if error:
             message[TelemetryField.KEY_ERROR_INFO.value] = error
             message[TelemetryField.KEY_ERROR_CODE.value] = error_code
+            message[TelemetryField.KEY_STACK_TRACE.value] = stack_trace
 
         self._send(message)
 
```
snowflake/ml/data/_internal/arrow_ingestor.py

```diff
@@ -116,7 +116,7 @@ class ArrowIngestor(data_ingestor.DataIngestor):
     def to_pandas(self, limit: Optional[int] = None) -> pd.DataFrame:
         ds = self._get_dataset(shuffle=False)
         table = ds.to_table() if limit is None else ds.head(num_rows=limit)
-        return table.to_pandas()
+        return table.to_pandas(split_blocks=True, self_destruct=True)
 
     def _get_dataset(self, shuffle: bool) -> pds.Dataset:
         format = self._format
```
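`split_blocks=True` keeps columns in separate pandas blocks rather than consolidating them, and `self_destruct=True` frees each Arrow buffer as it is converted; together they reduce peak memory during the Arrow-to-pandas conversion at the cost of consuming the source table. A standalone pyarrow illustration:

```python
import pyarrow as pa

table = pa.table({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})
df = table.to_pandas(split_blocks=True, self_destruct=True)
# `table` has been consumed by self_destruct; only use `df` from here on.
print(df)
```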
snowflake/ml/feature_store/feature_store.py

```diff
@@ -144,6 +144,7 @@ _LIST_FEATURE_VIEW_SCHEMA = StructType(
         StructField("refresh_mode", StringType()),
         StructField("scheduling_state", StringType()),
         StructField("warehouse", StringType()),
+        StructField("cluster_by", StringType()),
     ]
 )
 
```
```diff
@@ -1832,6 +1833,12 @@ class FeatureStore:
             WAREHOUSE = {warehouse}
             REFRESH_MODE = {feature_view.refresh_mode}
             INITIALIZE = {feature_view.initialize}
+            """
+        if feature_view.cluster_by:
+            cluster_by_clause = f"CLUSTER BY ({', '.join(feature_view.cluster_by)})"
+            query += f"{cluster_by_clause}"
+
+        query += f"""
             AS {feature_view.query}
             """
         self._session.sql(query).collect(block=block, statement_params=self._telemetry_stmp)
```
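With a feature view clustered on, say, `CUSTOMER_ID` and `TS`, the DDL assembled above gains a clause of this shape (column names are illustrative):

```python
cluster_by = ["CUSTOMER_ID", "TS"]  # hypothetical feature_view.cluster_by
print(f"CLUSTER BY ({', '.join(cluster_by)})")  # CLUSTER BY (CUSTOMER_ID, TS)
```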
```diff
@@ -2249,6 +2256,7 @@ class FeatureStore:
             values.append(row["refresh_mode"] if "refresh_mode" in row else None)
             values.append(row["scheduling_state"] if "scheduling_state" in row else None)
             values.append(row["warehouse"] if "warehouse" in row else None)
+            values.append(json.dumps(self._extract_cluster_by_columns(row["cluster_by"])) if "cluster_by" in row else None)
             output_values.append(values)
 
     def _lookup_feature_view_metadata(self, row: Row, fv_name: str) -> Tuple[_FeatureViewMetadata, str]:
```
```diff
@@ -2335,6 +2343,7 @@ class FeatureStore:
                 owner=row["owner"],
                 infer_schema_df=infer_schema_df,
                 session=self._session,
+                cluster_by=self._extract_cluster_by_columns(row["cluster_by"]),
             )
             return fv
         else:
```
```diff
@@ -2625,3 +2634,12 @@ class FeatureStore:
         )
 
         return feature_view
+
+    @staticmethod
+    def _extract_cluster_by_columns(cluster_by_clause: str) -> List[str]:
+        # Use regex to extract elements inside the parentheses.
+        match = re.search(r"\((.*?)\)", cluster_by_clause)
+        if match:
+            # Handle both quoted and unquoted column names.
+            return re.findall(identifier.SF_IDENTIFIER_RE, match.group(1))
+        return []
```
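`SF_IDENTIFIER_RE` is an internal pattern, so the sketch below substitutes a simplified stand-in that accepts quoted and unquoted Snowflake identifiers; the real pattern may differ:

```python
import re

SF_IDENTIFIER_RE = r'"(?:[^"]|"")+"|[A-Za-z_][A-Za-z0-9_$]*'  # stand-in, not the library's pattern

clause = 'CLUSTER BY (CUSTOMER_ID, "event ts")'
match = re.search(r"\((.*?)\)", clause)
cols = re.findall(SF_IDENTIFIER_RE, match.group(1)) if match else []
print(cols)  # ['CUSTOMER_ID', '"event ts"']
```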
snowflake/ml/feature_store/feature_view.py

```diff
@@ -170,6 +170,7 @@ class FeatureView(lineage_node.LineageNode):
         warehouse: Optional[str] = None,
         initialize: str = "ON_CREATE",
         refresh_mode: str = "AUTO",
+        cluster_by: Optional[List[str]] = None,
         **_kwargs: Any,
     ) -> None:
         """
```
```diff
@@ -200,6 +201,9 @@ class FeatureView(lineage_node.LineageNode):
             refresh_mode: The refresh mode of managed feature view. The value can be 'AUTO', 'FULL' or 'INCREMENETAL'.
                 For managed feature view, the default value is 'AUTO'. For static feature view it has no effect.
                 Check https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table for for details.
+            cluster_by: Columns to cluster the feature view by.
+                - Defaults to the join keys from entities.
+                - If `timestamp_col` is provided, it is added to the default clustering keys.
             _kwargs: reserved kwargs for system generated args. NOTE: DO NOT USE.
 
         Example::
```
```diff
@@ -224,6 +228,7 @@ class FeatureView(lineage_node.LineageNode):
         >>> print(registered_fv.status)
         FeatureViewStatus.ACTIVE
 
+        # noqa: DAR401
         """
 
         self._name: SqlIdentifier = SqlIdentifier(name)
```
```diff
@@ -233,7 +238,7 @@ class FeatureView(lineage_node.LineageNode):
             SqlIdentifier(timestamp_col) if timestamp_col is not None else None
         )
         self._desc: str = desc
-        self._infer_schema_df: DataFrame = _kwargs.…
+        self._infer_schema_df: DataFrame = _kwargs.pop("_infer_schema_df", self._feature_df)
         self._query: str = self._get_query()
         self._version: Optional[FeatureViewVersion] = None
         self._status: FeatureViewStatus = FeatureViewStatus.DRAFT
```
```diff
@@ -249,6 +254,14 @@ class FeatureView(lineage_node.LineageNode):
         self._refresh_mode: Optional[str] = refresh_mode
         self._refresh_mode_reason: Optional[str] = None
         self._owner: Optional[str] = None
+        self._cluster_by: List[SqlIdentifier] = (
+            [SqlIdentifier(col) for col in cluster_by] if cluster_by is not None else self._get_default_cluster_by()
+        )
+
+        # Validate kwargs
+        if _kwargs:
+            raise TypeError(f"FeatureView.__init__ got an unexpected keyword argument: '{next(iter(_kwargs.keys()))}'")
+
         self._validate()
 
     def slice(self, names: List[str]) -> FeatureViewSlice:
```
```diff
@@ -394,6 +407,10 @@ class FeatureView(lineage_node.LineageNode):
     def timestamp_col(self) -> Optional[SqlIdentifier]:
         return self._timestamp_col
 
+    @property
+    def cluster_by(self) -> Optional[List[SqlIdentifier]]:
+        return self._cluster_by
+
     @property
     def desc(self) -> str:
         return self._desc
```
```diff
@@ -656,6 +673,14 @@ Got {len(self._feature_df.queries['queries'])}: …
         if not isinstance(col_type, (DateType, TimeType, TimestampType, _NumericType)):
             raise ValueError(f"Invalid data type for timestamp_col {ts_col}: {col_type}.")
 
+        if self.cluster_by is not None:
+            for column in self.cluster_by:
+                if column not in df_cols:
+                    raise ValueError(
+                        f"Column '{column}' in `cluster_by` is not in the feature DataFrame schema. "
+                        f"{df_cols}, {self.cluster_by}"
+                    )
+
         if re.match(_RESULT_SCAN_QUERY_PATTERN, self._query) is not None:
             raise ValueError(f"feature_df should not be reading from RESULT_SCAN. Invalid query: {self._query}")
 
```
```diff
@@ -890,6 +915,7 @@ Got {len(self._feature_df.queries['queries'])}: …
         owner: Optional[str],
         infer_schema_df: Optional[DataFrame],
         session: Session,
+        cluster_by: Optional[List[str]] = None,
     ) -> FeatureView:
         fv = FeatureView(
             name=name,
```
```diff
@@ -898,6 +924,7 @@ Got {len(self._feature_df.queries['queries'])}: …
             timestamp_col=timestamp_col,
             desc=desc,
             _infer_schema_df=infer_schema_df,
+            cluster_by=cluster_by,
         )
         fv._version = FeatureViewVersion(version) if version is not None else None
         fv._status = status
```
```diff
@@ -916,5 +943,23 @@ Got {len(self._feature_df.queries['queries'])}: …
         )
         return fv
 
+    def _get_default_cluster_by(self) -> List[SqlIdentifier]:
+        """
+        Get default columns to cluster the feature view by.
+        Default cluster_by columns are join keys from entities and timestamp_col if it exists
+
+        Returns:
+            List of SqlIdentifiers representing the default columns to cluster the feature view by.
+        """
+        # We don't focus on the order of entities here, as users can define a custom 'cluster_by'
+        # if a specific order is required.
+        default_cluster_by_cols = [key for entity in self.entities if entity.join_keys for key in entity.join_keys]
+
+        if self.timestamp_col:
+            default_cluster_by_cols.append(self.timestamp_col)
+
+        return default_cluster_by_cols
+
 
 lineage_node.DOMAIN_LINEAGE_REGISTRY["feature_view"] = FeatureView
```
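The default derivation is simple enough to model with plain data: join keys from every entity, in entity order, plus `timestamp_col` when present (the entity shapes below are illustrative stand-ins for the real `Entity` objects):

```python
entities = [{"join_keys": ["CUSTOMER_ID"]}, {"join_keys": ["REGION_ID"]}]  # stand-ins
timestamp_col = "TS"

default_cluster_by = [key for entity in entities if entity["join_keys"] for key in entity["join_keys"]]
if timestamp_col:
    default_cluster_by.append(timestamp_col)
print(default_cluster_by)  # ['CUSTOMER_ID', 'REGION_ID', 'TS']
```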
snowflake/ml/jobs/_utils/constants.py

```diff
@@ -4,12 +4,15 @@ from snowflake.ml.jobs._utils.types import ComputeResources
 # SPCS specification constants
 DEFAULT_CONTAINER_NAME = "main"
 PAYLOAD_DIR_ENV_VAR = "MLRS_PAYLOAD_DIR"
+MEMORY_VOLUME_NAME = "dshm"
+STAGE_VOLUME_NAME = "stage-volume"
+STAGE_VOLUME_MOUNT_PATH = "/mnt/app"
 
 # Default container image information
 DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
 DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks"
 DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks"
-DEFAULT_IMAGE_TAG = "0.…
+DEFAULT_IMAGE_TAG = "0.9.2"
 DEFAULT_ENTRYPOINT_PATH = "func.py"
 
 # Percent of container memory to allocate for /dev/shm volume
```
```diff
@@ -19,6 +22,9 @@ MEMORY_VOLUME_SIZE = 0.3
 JOB_POLL_INITIAL_DELAY_SECONDS = 0.1
 JOB_POLL_MAX_DELAY_SECONDS = 1
 
+# Magic attributes
+IS_MLJOB_REMOTE_ATTR = "_is_mljob_remote_callable"
+
 # Compute pool resource information
 # TODO: Query Snowflake for resource information instead of relying on this hardcoded
 # table from https://docs.snowflake.com/en/sql-reference/sql/create-compute-pool
```
|