snowflake-ml-python 1.22.0__py3-none-any.whl → 1.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/platform_capabilities.py +0 -4
- snowflake/ml/feature_store/__init__.py +2 -0
- snowflake/ml/feature_store/aggregation.py +367 -0
- snowflake/ml/feature_store/feature.py +366 -0
- snowflake/ml/feature_store/feature_store.py +234 -20
- snowflake/ml/feature_store/feature_view.py +189 -4
- snowflake/ml/feature_store/metadata_manager.py +425 -0
- snowflake/ml/feature_store/tile_sql_generator.py +1079 -0
- snowflake/ml/jobs/__init__.py +2 -0
- snowflake/ml/jobs/_utils/constants.py +1 -0
- snowflake/ml/jobs/_utils/payload_utils.py +38 -18
- snowflake/ml/jobs/_utils/query_helper.py +8 -1
- snowflake/ml/jobs/_utils/runtime_env_utils.py +117 -0
- snowflake/ml/jobs/_utils/stage_utils.py +2 -2
- snowflake/ml/jobs/_utils/types.py +22 -2
- snowflake/ml/jobs/job_definition.py +232 -0
- snowflake/ml/jobs/manager.py +16 -177
- snowflake/ml/model/__init__.py +4 -0
- snowflake/ml/model/_client/model/batch_inference_specs.py +38 -2
- snowflake/ml/model/_client/model/model_version_impl.py +120 -89
- snowflake/ml/model/_client/ops/model_ops.py +4 -26
- snowflake/ml/model/_client/ops/param_utils.py +124 -0
- snowflake/ml/model/_client/ops/service_ops.py +63 -23
- snowflake/ml/model/_client/service/model_deployment_spec.py +12 -5
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
- snowflake/ml/model/_client/sql/service.py +25 -54
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -1
- snowflake/ml/model/_packager/model_handlers/huggingface.py +74 -10
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +121 -29
- snowflake/ml/model/_signatures/utils.py +130 -0
- snowflake/ml/model/openai_signatures.py +97 -0
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/METADATA +105 -1
- {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/RECORD +41 -35
- {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/WHEEL +1 -1
- snowflake/ml/experiment/callback/__init__.py +0 -0
- {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/top_level.txt +0 -0
|
@@ -20,7 +20,9 @@ from snowflake.ml._internal.utils.sql_identifier import (
|
|
|
20
20
|
to_sql_identifiers,
|
|
21
21
|
)
|
|
22
22
|
from snowflake.ml.feature_store import feature_store
|
|
23
|
+
from snowflake.ml.feature_store.aggregation import AggregationSpec
|
|
23
24
|
from snowflake.ml.feature_store.entity import Entity
|
|
25
|
+
from snowflake.ml.feature_store.feature import Feature
|
|
24
26
|
from snowflake.ml.lineage import lineage_node
|
|
25
27
|
from snowflake.snowpark import DataFrame, Session
|
|
26
28
|
from snowflake.snowpark.exceptions import SnowparkSQLException
|
|
@@ -92,6 +94,7 @@ class _FeatureViewMetadata:
|
|
|
92
94
|
|
|
93
95
|
entities: list[str]
|
|
94
96
|
timestamp_col: str
|
|
97
|
+
is_tiled: bool = False # Whether FV uses tile-based aggregations
|
|
95
98
|
|
|
96
99
|
def to_json(self) -> str:
|
|
97
100
|
return json.dumps(asdict(self))
|
|
@@ -99,6 +102,9 @@ class _FeatureViewMetadata:
|
|
|
99
102
|
@classmethod
|
|
100
103
|
def from_json(cls, json_str: str) -> _FeatureViewMetadata:
|
|
101
104
|
state_dict = json.loads(json_str)
|
|
105
|
+
# Backward compatibility: old FVs don't have is_tiled
|
|
106
|
+
if "is_tiled" not in state_dict:
|
|
107
|
+
state_dict["is_tiled"] = False
|
|
102
108
|
return cls(**state_dict)
|
|
103
109
|
|
|
104
110
|
|
|
@@ -213,6 +219,8 @@ class FeatureView(lineage_node.LineageNode):
|
|
|
213
219
|
refresh_mode: str = "AUTO",
|
|
214
220
|
cluster_by: Optional[list[str]] = None,
|
|
215
221
|
online_config: Optional[OnlineConfig] = None,
|
|
222
|
+
feature_granularity: Optional[str] = None,
|
|
223
|
+
features: Optional[list[Feature]] = None,
|
|
216
224
|
**_kwargs: Any,
|
|
217
225
|
) -> None:
|
|
218
226
|
"""
|
|
@@ -256,6 +264,13 @@ class FeatureView(lineage_node.LineageNode):
|
|
|
256
264
|
|
|
257
265
|
.. note::
|
|
258
266
|
This feature is currently in preview.
|
|
267
|
+
feature_granularity: The tile interval for time-series aggregations (e.g., "1h", "1d").
|
|
268
|
+
When specified along with ``features``, enables tile-based aggregation where a
|
|
269
|
+
Dynamic Table stores pre-computed partial aggregations (tiles), and dataset
|
|
270
|
+
generation merges these tiles for point-in-time correct results.
|
|
271
|
+
features: List of aggregation feature definitions using the ``Feature`` class.
|
|
272
|
+
Required when ``feature_granularity`` is specified. Defines the aggregations
|
|
273
|
+
to compute (e.g., SUM, COUNT, LAST_N) with their windows.
|
|
259
274
|
_kwargs: Reserved kwargs for system generated args.
|
|
260
275
|
|
|
261
276
|
.. caution::
|
|
@@ -329,6 +344,19 @@ class FeatureView(lineage_node.LineageNode):
|
|
|
329
344
|
)
|
|
330
345
|
self._online_config: Optional[OnlineConfig] = online_config
|
|
331
346
|
|
|
347
|
+
# Tile-based aggregation fields
|
|
348
|
+
self._feature_granularity: Optional[str] = feature_granularity
|
|
349
|
+
self._aggregation_specs: Optional[list[AggregationSpec]] = _kwargs.pop("_aggregation_specs", None)
|
|
350
|
+
if features is not None:
|
|
351
|
+
self._aggregation_specs = [f.to_spec() for f in features]
|
|
352
|
+
|
|
353
|
+
# For tiled FVs, re-initialize feature_descs with output column names (not source column names)
|
|
354
|
+
# This ensures descriptions are keyed by the output columns users will see in datasets
|
|
355
|
+
if self.is_tiled and self._aggregation_specs:
|
|
356
|
+
self._feature_desc = OrderedDict(
|
|
357
|
+
(SqlIdentifier(spec.get_sql_column_name()), "") for spec in self._aggregation_specs
|
|
358
|
+
)
|
|
359
|
+
|
|
332
360
|
# Validate kwargs
|
|
333
361
|
if _kwargs:
|
|
334
362
|
raise TypeError(f"FeatureView.__init__ got an unexpected keyword argument: '{next(iter(_kwargs.keys()))}'")
|
|
@@ -535,6 +563,9 @@ class FeatureView(lineage_node.LineageNode):
|
|
|
535
563
|
|
|
536
564
|
@property
|
|
537
565
|
def feature_names(self) -> list[SqlIdentifier]:
|
|
566
|
+
# For tiled FVs, return output column names from aggregation specs
|
|
567
|
+
if self.is_tiled and self._aggregation_specs:
|
|
568
|
+
return [SqlIdentifier(spec.get_sql_column_name()) for spec in self._aggregation_specs]
|
|
538
569
|
return list(self._feature_desc.keys()) if self._feature_desc is not None else []
|
|
539
570
|
|
|
540
571
|
@property
|
|
@@ -556,6 +587,25 @@ class FeatureView(lineage_node.LineageNode):
|
|
|
556
587
|
def online_config(self) -> Optional[OnlineConfig]:
|
|
557
588
|
return self._online_config
|
|
558
589
|
|
|
590
|
+
@property
|
|
591
|
+
def is_tiled(self) -> bool:
|
|
592
|
+
"""Check if this feature view uses tile-based aggregation.
|
|
593
|
+
|
|
594
|
+
Returns:
|
|
595
|
+
True if feature_granularity and features are configured.
|
|
596
|
+
"""
|
|
597
|
+
return self._feature_granularity is not None and self._aggregation_specs is not None
|
|
598
|
+
|
|
599
|
+
@property
|
|
600
|
+
def feature_granularity(self) -> Optional[str]:
|
|
601
|
+
"""Get the tile interval for aggregations."""
|
|
602
|
+
return self._feature_granularity
|
|
603
|
+
|
|
604
|
+
@property
|
|
605
|
+
def aggregation_specs(self) -> Optional[list[AggregationSpec]]:
|
|
606
|
+
"""Get the aggregation specifications (internal use)."""
|
|
607
|
+
return self._aggregation_specs
|
|
608
|
+
|
|
559
609
|
def fully_qualified_online_table_name(self) -> str:
|
|
560
610
|
"""Get the fully qualified name for the online feature table.
|
|
561
611
|
|
|
@@ -741,7 +791,7 @@ class FeatureView(lineage_node.LineageNode):
|
|
|
741
791
|
def _metadata(self) -> _FeatureViewMetadata:
|
|
742
792
|
entity_names = [e.name.identifier() for e in self.entities]
|
|
743
793
|
ts_col = self.timestamp_col.identifier() if self.timestamp_col is not None else _TIMESTAMP_COL_PLACEHOLDER
|
|
744
|
-
return _FeatureViewMetadata(entity_names, ts_col)
|
|
794
|
+
return _FeatureViewMetadata(entity_names, ts_col, is_tiled=self.is_tiled)
|
|
745
795
|
|
|
746
796
|
def _get_query(self) -> str:
|
|
747
797
|
if len(self._feature_df.queries["queries"]) != 1:
|
|
@@ -765,7 +815,10 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
|
765
815
|
if k not in df_cols:
|
|
766
816
|
raise ValueError(f"join_key {k} in Entity {e.name} is not found in input dataframe: {df_cols}")
|
|
767
817
|
|
|
768
|
-
|
|
818
|
+
# For tiled FVs, timestamp_col is used for tiling (not an output column)
|
|
819
|
+
# and cluster_by is adjusted to use TILE_START instead of timestamp_col
|
|
820
|
+
# So skip these validations for tiled FVs
|
|
821
|
+
if self._timestamp_col is not None and not self.is_tiled:
|
|
769
822
|
ts_col = self._timestamp_col
|
|
770
823
|
if ts_col == SqlIdentifier(_TIMESTAMP_COL_PLACEHOLDER):
|
|
771
824
|
raise ValueError(f"Invalid timestamp_col name, cannot be {_TIMESTAMP_COL_PLACEHOLDER}.")
|
|
@@ -776,7 +829,7 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
|
776
829
|
if not isinstance(col_type, (DateType, TimeType, TimestampType, _NumericType)):
|
|
777
830
|
raise ValueError(f"Invalid data type for timestamp_col {ts_col}: {col_type}.")
|
|
778
831
|
|
|
779
|
-
if self.cluster_by is not None:
|
|
832
|
+
if self.cluster_by is not None and not self.is_tiled:
|
|
780
833
|
for column in self.cluster_by:
|
|
781
834
|
if column not in df_cols:
|
|
782
835
|
raise ValueError(
|
|
@@ -790,6 +843,93 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
|
790
843
|
if self._initialize not in ["ON_CREATE", "ON_SCHEDULE"]:
|
|
791
844
|
raise ValueError("'initialize' only supports ON_CREATE or ON_SCHEDULE.")
|
|
792
845
|
|
|
846
|
+
# Validate tiled aggregation configuration
|
|
847
|
+
if self._feature_granularity is not None and self._aggregation_specs is None:
|
|
848
|
+
raise ValueError(
|
|
849
|
+
"feature_granularity requires features to be specified. "
|
|
850
|
+
"Use the Feature class to define aggregations."
|
|
851
|
+
)
|
|
852
|
+
if self._aggregation_specs is not None and self._feature_granularity is None:
|
|
853
|
+
raise ValueError(
|
|
854
|
+
"features requires feature_granularity to be specified. "
|
|
855
|
+
"Specify the tile interval (e.g., '1h', '1d')."
|
|
856
|
+
)
|
|
857
|
+
if self.is_tiled:
|
|
858
|
+
if self._timestamp_col is None:
|
|
859
|
+
raise ValueError(
|
|
860
|
+
"timestamp_col is required for tile-based aggregations. "
|
|
861
|
+
"Specify the timestamp column used for time-series lookups."
|
|
862
|
+
)
|
|
863
|
+
if self._refresh_freq is None:
|
|
864
|
+
raise ValueError(
|
|
865
|
+
"refresh_freq is required for tile-based aggregations. "
|
|
866
|
+
"Tiled feature views must be managed (Dynamic Tables)."
|
|
867
|
+
)
|
|
868
|
+
# Validate window and offset are multiples of granularity
|
|
869
|
+
self._validate_window_offset_alignment()
|
|
870
|
+
# Validate feature aliases are unique
|
|
871
|
+
self._validate_unique_feature_aliases()
|
|
872
|
+
|
|
873
|
+
def _validate_window_offset_alignment(self) -> None:
|
|
874
|
+
"""Validate that window and offset are multiples of feature_granularity.
|
|
875
|
+
|
|
876
|
+
This ensures clean tile boundaries for aggregations.
|
|
877
|
+
Lifetime windows are exempt from this validation as they aggregate all tiles.
|
|
878
|
+
|
|
879
|
+
Raises:
|
|
880
|
+
ValueError: If window or offset is not a multiple of feature_granularity.
|
|
881
|
+
"""
|
|
882
|
+
from snowflake.ml.feature_store.aggregation import interval_to_seconds
|
|
883
|
+
|
|
884
|
+
granularity_seconds = interval_to_seconds(self._feature_granularity) # type: ignore[arg-type]
|
|
885
|
+
|
|
886
|
+
for spec in self._aggregation_specs: # type: ignore[union-attr]
|
|
887
|
+
# Skip validation for lifetime windows (they aggregate all tiles)
|
|
888
|
+
if spec.is_lifetime():
|
|
889
|
+
continue
|
|
890
|
+
|
|
891
|
+
# Validate window alignment
|
|
892
|
+
window_seconds = spec.get_window_seconds()
|
|
893
|
+
if window_seconds % granularity_seconds != 0:
|
|
894
|
+
raise ValueError(
|
|
895
|
+
f"Window '{spec.window}' for feature '{spec.output_column}' must be a multiple of "
|
|
896
|
+
f"feature_granularity '{self._feature_granularity}'. "
|
|
897
|
+
f"Window ({window_seconds}s) is not divisible by granularity ({granularity_seconds}s)."
|
|
898
|
+
)
|
|
899
|
+
|
|
900
|
+
# Validate offset alignment (if offset is specified)
|
|
901
|
+
offset_seconds = spec.get_offset_seconds()
|
|
902
|
+
if offset_seconds > 0 and offset_seconds % granularity_seconds != 0:
|
|
903
|
+
raise ValueError(
|
|
904
|
+
f"Offset '{spec.offset}' for feature '{spec.output_column}' must be a multiple of "
|
|
905
|
+
f"feature_granularity '{self._feature_granularity}'. "
|
|
906
|
+
f"Offset ({offset_seconds}s) is not divisible by granularity ({granularity_seconds}s)."
|
|
907
|
+
)
|
|
908
|
+
|
|
909
|
+
def _validate_unique_feature_aliases(self) -> None:
|
|
910
|
+
"""Validate that all feature aliases are unique.
|
|
911
|
+
|
|
912
|
+
Duplicate aliases would cause SQL errors or ambiguous columns
|
|
913
|
+
in the generated tile table and dataset queries.
|
|
914
|
+
|
|
915
|
+
Raises:
|
|
916
|
+
ValueError: If duplicate feature aliases are found.
|
|
917
|
+
"""
|
|
918
|
+
if self._aggregation_specs is None:
|
|
919
|
+
return
|
|
920
|
+
|
|
921
|
+
seen_aliases: dict[str, int] = {}
|
|
922
|
+
for spec in self._aggregation_specs:
|
|
923
|
+
# Normalize to uppercase for comparison (Snowflake default)
|
|
924
|
+
alias = spec.output_column.strip('"').upper()
|
|
925
|
+
if alias in seen_aliases:
|
|
926
|
+
raise ValueError(
|
|
927
|
+
f"Duplicate feature alias '{spec.output_column}' found. "
|
|
928
|
+
f"Each feature must have a unique alias. "
|
|
929
|
+
f"Use .alias() to provide distinct names for features."
|
|
930
|
+
)
|
|
931
|
+
seen_aliases[alias] = 1
|
|
932
|
+
|
|
793
933
|
def _get_column_names(self) -> Optional[list[SqlIdentifier]]:
|
|
794
934
|
try:
|
|
795
935
|
return to_sql_identifiers(self._infer_schema_df.columns)
|
|
@@ -861,6 +1001,12 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
|
861
1001
|
|
|
862
1002
|
fv_dict["_online_config"] = self._online_config.to_json() if self._online_config is not None else None
|
|
863
1003
|
|
|
1004
|
+
# Aggregation fields
|
|
1005
|
+
fv_dict["_feature_granularity"] = self._feature_granularity
|
|
1006
|
+
fv_dict["_aggregation_specs"] = (
|
|
1007
|
+
[spec.to_dict() for spec in self._aggregation_specs] if self._aggregation_specs is not None else None
|
|
1008
|
+
)
|
|
1009
|
+
|
|
864
1010
|
lineage_node_keys = [key for key in fv_dict if key.startswith("_node") or key == "_session"]
|
|
865
1011
|
|
|
866
1012
|
for key in lineage_node_keys:
|
|
@@ -930,6 +1076,11 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
|
930
1076
|
e.owner = e_json["owner"]
|
|
931
1077
|
entities.append(e)
|
|
932
1078
|
|
|
1079
|
+
# Deserialize aggregation specs if present
|
|
1080
|
+
aggregation_specs = None
|
|
1081
|
+
if json_dict.get("_aggregation_specs"):
|
|
1082
|
+
aggregation_specs = [AggregationSpec.from_dict(s) for s in json_dict["_aggregation_specs"]]
|
|
1083
|
+
|
|
933
1084
|
return FeatureView._construct_feature_view(
|
|
934
1085
|
name=json_dict["_name"],
|
|
935
1086
|
entities=entities,
|
|
@@ -952,6 +1103,8 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
|
952
1103
|
online_config=OnlineConfig.from_json(json_dict["_online_config"])
|
|
953
1104
|
if json_dict.get("_online_config")
|
|
954
1105
|
else None,
|
|
1106
|
+
feature_granularity=json_dict.get("_feature_granularity"),
|
|
1107
|
+
aggregation_specs=aggregation_specs,
|
|
955
1108
|
)
|
|
956
1109
|
|
|
957
1110
|
def _get_compact_repr(self) -> _CompactRepresentation:
|
|
@@ -1025,6 +1178,8 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
|
1025
1178
|
session: Session,
|
|
1026
1179
|
cluster_by: Optional[list[str]] = None,
|
|
1027
1180
|
online_config: Optional[OnlineConfig] = None,
|
|
1181
|
+
feature_granularity: Optional[str] = None,
|
|
1182
|
+
aggregation_specs: Optional[list[AggregationSpec]] = None,
|
|
1028
1183
|
) -> FeatureView:
|
|
1029
1184
|
fv = FeatureView(
|
|
1030
1185
|
name=name,
|
|
@@ -1032,13 +1187,15 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
|
1032
1187
|
feature_df=feature_df,
|
|
1033
1188
|
timestamp_col=timestamp_col,
|
|
1034
1189
|
desc=desc,
|
|
1190
|
+
refresh_freq=refresh_freq,
|
|
1035
1191
|
_infer_schema_df=infer_schema_df,
|
|
1036
1192
|
cluster_by=cluster_by,
|
|
1037
1193
|
online_config=online_config,
|
|
1194
|
+
feature_granularity=feature_granularity,
|
|
1195
|
+
_aggregation_specs=aggregation_specs,
|
|
1038
1196
|
)
|
|
1039
1197
|
fv._version = FeatureViewVersion(version) if version is not None else None
|
|
1040
1198
|
fv._status = status
|
|
1041
|
-
fv._refresh_freq = refresh_freq
|
|
1042
1199
|
fv._database = SqlIdentifier(database) if database is not None else None
|
|
1043
1200
|
fv._schema = SqlIdentifier(schema) if schema is not None else None
|
|
1044
1201
|
fv._warehouse = SqlIdentifier(warehouse) if warehouse is not None else None
|
|
@@ -1071,6 +1228,34 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
|
1071
1228
|
|
|
1072
1229
|
return default_cluster_by_cols
|
|
1073
1230
|
|
|
1231
|
+
def _get_tile_query(self) -> str:
|
|
1232
|
+
"""Generate the tiling query for tile-based aggregations.
|
|
1233
|
+
|
|
1234
|
+
This query is used as the source for the Dynamic Table that stores
|
|
1235
|
+
pre-computed partial aggregations (tiles).
|
|
1236
|
+
|
|
1237
|
+
Returns:
|
|
1238
|
+
SQL query string for creating tile table.
|
|
1239
|
+
|
|
1240
|
+
Raises:
|
|
1241
|
+
ValueError: If the feature view is not configured for tiling.
|
|
1242
|
+
"""
|
|
1243
|
+
if not self.is_tiled:
|
|
1244
|
+
raise ValueError("_get_tile_query called on non-tiled feature view")
|
|
1245
|
+
|
|
1246
|
+
from snowflake.ml.feature_store.tile_sql_generator import TilingSqlGenerator
|
|
1247
|
+
|
|
1248
|
+
join_keys = [str(k) for e in self._entities for k in e.join_keys]
|
|
1249
|
+
|
|
1250
|
+
generator = TilingSqlGenerator(
|
|
1251
|
+
source_query=self._query,
|
|
1252
|
+
join_keys=join_keys,
|
|
1253
|
+
timestamp_col=str(self._timestamp_col),
|
|
1254
|
+
feature_granularity=self._feature_granularity, # type: ignore[arg-type]
|
|
1255
|
+
features=self._aggregation_specs, # type: ignore[arg-type]
|
|
1256
|
+
)
|
|
1257
|
+
return generator.generate()
|
|
1258
|
+
|
|
1074
1259
|
@staticmethod
|
|
1075
1260
|
def _get_online_table_name(
|
|
1076
1261
|
feature_view_name: Union[SqlIdentifier, str], version: Optional[Union[FeatureViewVersion, str]] = None
|