snowflake-ml-python 1.23.0__py3-none-any.whl → 1.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. snowflake/ml/_internal/platform_capabilities.py +0 -4
  2. snowflake/ml/feature_store/__init__.py +2 -0
  3. snowflake/ml/feature_store/aggregation.py +367 -0
  4. snowflake/ml/feature_store/feature.py +366 -0
  5. snowflake/ml/feature_store/feature_store.py +234 -20
  6. snowflake/ml/feature_store/feature_view.py +189 -4
  7. snowflake/ml/feature_store/metadata_manager.py +425 -0
  8. snowflake/ml/feature_store/tile_sql_generator.py +1079 -0
  9. snowflake/ml/model/__init__.py +4 -0
  10. snowflake/ml/model/_client/model/batch_inference_specs.py +38 -2
  11. snowflake/ml/model/_client/model/model_version_impl.py +31 -14
  12. snowflake/ml/model/_client/ops/model_ops.py +2 -8
  13. snowflake/ml/model/_client/ops/service_ops.py +0 -5
  14. snowflake/ml/model/_client/sql/service.py +21 -29
  15. snowflake/ml/model/_model_composer/model_method/model_method.py +2 -1
  16. snowflake/ml/model/_packager/model_handlers/huggingface.py +20 -0
  17. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +70 -14
  18. snowflake/ml/model/_signatures/utils.py +76 -1
  19. snowflake/ml/version.py +1 -1
  20. {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/METADATA +39 -1
  21. {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/RECORD +24 -20
  22. {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/WHEEL +1 -1
  23. {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/licenses/LICENSE.txt +0 -0
  24. {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/top_level.txt +0 -0
@@ -20,7 +20,9 @@ from snowflake.ml._internal.utils.sql_identifier import (
20
20
  to_sql_identifiers,
21
21
  )
22
22
  from snowflake.ml.feature_store import feature_store
23
+ from snowflake.ml.feature_store.aggregation import AggregationSpec
23
24
  from snowflake.ml.feature_store.entity import Entity
25
+ from snowflake.ml.feature_store.feature import Feature
24
26
  from snowflake.ml.lineage import lineage_node
25
27
  from snowflake.snowpark import DataFrame, Session
26
28
  from snowflake.snowpark.exceptions import SnowparkSQLException
@@ -92,6 +94,7 @@ class _FeatureViewMetadata:
92
94
 
93
95
  entities: list[str]
94
96
  timestamp_col: str
97
+ is_tiled: bool = False # Whether FV uses tile-based aggregations
95
98
 
96
99
def to_json(self) -> str:
    """Serialize this metadata record as a JSON object of its dataclass fields.

    Returns:
        The JSON text produced by ``json.dumps`` over ``dataclasses.asdict(self)``.
    """
    fields_as_dict = asdict(self)
    return json.dumps(fields_as_dict)
@@ -99,6 +102,9 @@ class _FeatureViewMetadata:
99
102
@classmethod
def from_json(cls, json_str: str) -> _FeatureViewMetadata:
    """Rebuild a metadata record from the JSON text written by ``to_json``.

    Args:
        json_str: JSON object text whose keys match the record's fields.

    Returns:
        A new instance built from the decoded keys.
    """
    payload = json.loads(json_str)
    # Metadata persisted before tiled feature views existed carries no
    # "is_tiled" key; treat such records as non-tiled.
    is_legacy_payload = "is_tiled" not in payload
    if is_legacy_payload:
        payload["is_tiled"] = False
    return cls(**payload)
103
109
 
104
110
 
@@ -213,6 +219,8 @@ class FeatureView(lineage_node.LineageNode):
213
219
  refresh_mode: str = "AUTO",
214
220
  cluster_by: Optional[list[str]] = None,
215
221
  online_config: Optional[OnlineConfig] = None,
222
+ feature_granularity: Optional[str] = None,
223
+ features: Optional[list[Feature]] = None,
216
224
  **_kwargs: Any,
217
225
  ) -> None:
218
226
  """
@@ -256,6 +264,13 @@ class FeatureView(lineage_node.LineageNode):
256
264
 
257
265
  .. note::
258
266
  This feature is currently in preview.
267
+ feature_granularity: The tile interval for time-series aggregations (e.g., "1h", "1d").
268
+ When specified along with ``features``, enables tile-based aggregation where a
269
+ Dynamic Table stores pre-computed partial aggregations (tiles), and dataset
270
+ generation merges these tiles for point-in-time correct results.
271
+ features: List of aggregation feature definitions using the ``Feature`` class.
272
+ Required when ``feature_granularity`` is specified. Defines the aggregations
273
+ to compute (e.g., SUM, COUNT, LAST_N) with their windows.
259
274
  _kwargs: Reserved kwargs for system generated args.
260
275
 
261
276
  .. caution::
@@ -329,6 +344,19 @@ class FeatureView(lineage_node.LineageNode):
329
344
  )
330
345
  self._online_config: Optional[OnlineConfig] = online_config
331
346
 
347
+ # Tile-based aggregation fields
348
+ self._feature_granularity: Optional[str] = feature_granularity
349
+ self._aggregation_specs: Optional[list[AggregationSpec]] = _kwargs.pop("_aggregation_specs", None)
350
+ if features is not None:
351
+ self._aggregation_specs = [f.to_spec() for f in features]
352
+
353
+ # For tiled FVs, re-initialize feature_descs with output column names (not source column names)
354
+ # This ensures descriptions are keyed by the output columns users will see in datasets
355
+ if self.is_tiled and self._aggregation_specs:
356
+ self._feature_desc = OrderedDict(
357
+ (SqlIdentifier(spec.get_sql_column_name()), "") for spec in self._aggregation_specs
358
+ )
359
+
332
360
  # Validate kwargs
333
361
  if _kwargs:
334
362
  raise TypeError(f"FeatureView.__init__ got an unexpected keyword argument: '{next(iter(_kwargs.keys()))}'")
@@ -535,6 +563,9 @@ class FeatureView(lineage_node.LineageNode):
535
563
 
536
564
@property
def feature_names(self) -> list[SqlIdentifier]:
    """Names of the feature columns exposed by this feature view.

    Returns:
        For a tiled view with aggregation specs, the SQL output column name of
        each spec. Otherwise, the keys of the feature description mapping, or
        an empty list when no descriptions exist.
    """
    if self.is_tiled and self._aggregation_specs:
        # Tiled views surface aggregation outputs, not the source columns.
        return [SqlIdentifier(agg.get_sql_column_name()) for agg in self._aggregation_specs]
    if self._feature_desc is None:
        return []
    return list(self._feature_desc)
539
570
 
540
571
  @property
@@ -556,6 +587,25 @@ class FeatureView(lineage_node.LineageNode):
556
587
  def online_config(self) -> Optional[OnlineConfig]:
557
588
  return self._online_config
558
589
 
590
@property
def is_tiled(self) -> bool:
    """Whether this feature view is configured for tile-based aggregation.

    Returns:
        True only when both a feature granularity and aggregation specs are set.
    """
    return not (self._feature_granularity is None or self._aggregation_specs is None)

@property
def feature_granularity(self) -> Optional[str]:
    """The tile interval used for aggregations (e.g. "1h"), or None if unset."""
    return self._feature_granularity

@property
def aggregation_specs(self) -> Optional[list[AggregationSpec]]:
    """The aggregation specifications backing a tiled view (internal use), or None."""
    return self._aggregation_specs
608
+
559
609
  def fully_qualified_online_table_name(self) -> str:
560
610
  """Get the fully qualified name for the online feature table.
561
611
 
@@ -741,7 +791,7 @@ class FeatureView(lineage_node.LineageNode):
741
791
def _metadata(self) -> _FeatureViewMetadata:
    """Collect the entity names, timestamp column and tiling flag of this view.

    Returns:
        A ``_FeatureViewMetadata`` record. The timestamp column falls back to
        ``_TIMESTAMP_COL_PLACEHOLDER`` when the view has none.
    """
    names_of_entities = [entity.name.identifier() for entity in self.entities]
    if self.timestamp_col is None:
        timestamp_name = _TIMESTAMP_COL_PLACEHOLDER
    else:
        timestamp_name = self.timestamp_col.identifier()
    return _FeatureViewMetadata(names_of_entities, timestamp_name, is_tiled=self.is_tiled)
745
795
 
746
796
  def _get_query(self) -> str:
747
797
  if len(self._feature_df.queries["queries"]) != 1:
@@ -765,7 +815,10 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
765
815
  if k not in df_cols:
766
816
  raise ValueError(f"join_key {k} in Entity {e.name} is not found in input dataframe: {df_cols}")
767
817
 
768
- if self._timestamp_col is not None:
818
+ # For tiled FVs, timestamp_col is used for tiling (not an output column)
819
+ # and cluster_by is adjusted to use TILE_START instead of timestamp_col
820
+ # So skip these validations for tiled FVs
821
+ if self._timestamp_col is not None and not self.is_tiled:
769
822
  ts_col = self._timestamp_col
770
823
  if ts_col == SqlIdentifier(_TIMESTAMP_COL_PLACEHOLDER):
771
824
  raise ValueError(f"Invalid timestamp_col name, cannot be {_TIMESTAMP_COL_PLACEHOLDER}.")
@@ -776,7 +829,7 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
776
829
  if not isinstance(col_type, (DateType, TimeType, TimestampType, _NumericType)):
777
830
  raise ValueError(f"Invalid data type for timestamp_col {ts_col}: {col_type}.")
778
831
 
779
- if self.cluster_by is not None:
832
+ if self.cluster_by is not None and not self.is_tiled:
780
833
  for column in self.cluster_by:
781
834
  if column not in df_cols:
782
835
  raise ValueError(
@@ -790,6 +843,93 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
790
843
  if self._initialize not in ["ON_CREATE", "ON_SCHEDULE"]:
791
844
  raise ValueError("'initialize' only supports ON_CREATE or ON_SCHEDULE.")
792
845
 
846
+ # Validate tiled aggregation configuration
847
+ if self._feature_granularity is not None and self._aggregation_specs is None:
848
+ raise ValueError(
849
+ "feature_granularity requires features to be specified. "
850
+ "Use the Feature class to define aggregations."
851
+ )
852
+ if self._aggregation_specs is not None and self._feature_granularity is None:
853
+ raise ValueError(
854
+ "features requires feature_granularity to be specified. "
855
+ "Specify the tile interval (e.g., '1h', '1d')."
856
+ )
857
+ if self.is_tiled:
858
+ if self._timestamp_col is None:
859
+ raise ValueError(
860
+ "timestamp_col is required for tile-based aggregations. "
861
+ "Specify the timestamp column used for time-series lookups."
862
+ )
863
+ if self._refresh_freq is None:
864
+ raise ValueError(
865
+ "refresh_freq is required for tile-based aggregations. "
866
+ "Tiled feature views must be managed (Dynamic Tables)."
867
+ )
868
+ # Validate window and offset are multiples of granularity
869
+ self._validate_window_offset_alignment()
870
+ # Validate feature aliases are unique
871
+ self._validate_unique_feature_aliases()
872
+
873
+ def _validate_window_offset_alignment(self) -> None:
874
+ """Validate that window and offset are multiples of feature_granularity.
875
+
876
+ This ensures clean tile boundaries for aggregations.
877
+ Lifetime windows are exempt from this validation as they aggregate all tiles.
878
+
879
+ Raises:
880
+ ValueError: If window or offset is not a multiple of feature_granularity.
881
+ """
882
+ from snowflake.ml.feature_store.aggregation import interval_to_seconds
883
+
884
+ granularity_seconds = interval_to_seconds(self._feature_granularity) # type: ignore[arg-type]
885
+
886
+ for spec in self._aggregation_specs: # type: ignore[union-attr]
887
+ # Skip validation for lifetime windows (they aggregate all tiles)
888
+ if spec.is_lifetime():
889
+ continue
890
+
891
+ # Validate window alignment
892
+ window_seconds = spec.get_window_seconds()
893
+ if window_seconds % granularity_seconds != 0:
894
+ raise ValueError(
895
+ f"Window '{spec.window}' for feature '{spec.output_column}' must be a multiple of "
896
+ f"feature_granularity '{self._feature_granularity}'. "
897
+ f"Window ({window_seconds}s) is not divisible by granularity ({granularity_seconds}s)."
898
+ )
899
+
900
+ # Validate offset alignment (if offset is specified)
901
+ offset_seconds = spec.get_offset_seconds()
902
+ if offset_seconds > 0 and offset_seconds % granularity_seconds != 0:
903
+ raise ValueError(
904
+ f"Offset '{spec.offset}' for feature '{spec.output_column}' must be a multiple of "
905
+ f"feature_granularity '{self._feature_granularity}'. "
906
+ f"Offset ({offset_seconds}s) is not divisible by granularity ({granularity_seconds}s)."
907
+ )
908
+
909
+ def _validate_unique_feature_aliases(self) -> None:
910
+ """Validate that all feature aliases are unique.
911
+
912
+ Duplicate aliases would cause SQL errors or ambiguous columns
913
+ in the generated tile table and dataset queries.
914
+
915
+ Raises:
916
+ ValueError: If duplicate feature aliases are found.
917
+ """
918
+ if self._aggregation_specs is None:
919
+ return
920
+
921
+ seen_aliases: dict[str, int] = {}
922
+ for spec in self._aggregation_specs:
923
+ # Normalize to uppercase for comparison (Snowflake default)
924
+ alias = spec.output_column.strip('"').upper()
925
+ if alias in seen_aliases:
926
+ raise ValueError(
927
+ f"Duplicate feature alias '{spec.output_column}' found. "
928
+ f"Each feature must have a unique alias. "
929
+ f"Use .alias() to provide distinct names for features."
930
+ )
931
+ seen_aliases[alias] = 1
932
+
793
933
  def _get_column_names(self) -> Optional[list[SqlIdentifier]]:
794
934
  try:
795
935
  return to_sql_identifiers(self._infer_schema_df.columns)
@@ -861,6 +1001,12 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
861
1001
 
862
1002
  fv_dict["_online_config"] = self._online_config.to_json() if self._online_config is not None else None
863
1003
 
1004
+ # Aggregation fields
1005
+ fv_dict["_feature_granularity"] = self._feature_granularity
1006
+ fv_dict["_aggregation_specs"] = (
1007
+ [spec.to_dict() for spec in self._aggregation_specs] if self._aggregation_specs is not None else None
1008
+ )
1009
+
864
1010
  lineage_node_keys = [key for key in fv_dict if key.startswith("_node") or key == "_session"]
865
1011
 
866
1012
  for key in lineage_node_keys:
@@ -930,6 +1076,11 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
930
1076
  e.owner = e_json["owner"]
931
1077
  entities.append(e)
932
1078
 
1079
+ # Deserialize aggregation specs if present
1080
+ aggregation_specs = None
1081
+ if json_dict.get("_aggregation_specs"):
1082
+ aggregation_specs = [AggregationSpec.from_dict(s) for s in json_dict["_aggregation_specs"]]
1083
+
933
1084
  return FeatureView._construct_feature_view(
934
1085
  name=json_dict["_name"],
935
1086
  entities=entities,
@@ -952,6 +1103,8 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
952
1103
  online_config=OnlineConfig.from_json(json_dict["_online_config"])
953
1104
  if json_dict.get("_online_config")
954
1105
  else None,
1106
+ feature_granularity=json_dict.get("_feature_granularity"),
1107
+ aggregation_specs=aggregation_specs,
955
1108
  )
956
1109
 
957
1110
  def _get_compact_repr(self) -> _CompactRepresentation:
@@ -1025,6 +1178,8 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
1025
1178
  session: Session,
1026
1179
  cluster_by: Optional[list[str]] = None,
1027
1180
  online_config: Optional[OnlineConfig] = None,
1181
+ feature_granularity: Optional[str] = None,
1182
+ aggregation_specs: Optional[list[AggregationSpec]] = None,
1028
1183
  ) -> FeatureView:
1029
1184
  fv = FeatureView(
1030
1185
  name=name,
@@ -1032,13 +1187,15 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
1032
1187
  feature_df=feature_df,
1033
1188
  timestamp_col=timestamp_col,
1034
1189
  desc=desc,
1190
+ refresh_freq=refresh_freq,
1035
1191
  _infer_schema_df=infer_schema_df,
1036
1192
  cluster_by=cluster_by,
1037
1193
  online_config=online_config,
1194
+ feature_granularity=feature_granularity,
1195
+ _aggregation_specs=aggregation_specs,
1038
1196
  )
1039
1197
  fv._version = FeatureViewVersion(version) if version is not None else None
1040
1198
  fv._status = status
1041
- fv._refresh_freq = refresh_freq
1042
1199
  fv._database = SqlIdentifier(database) if database is not None else None
1043
1200
  fv._schema = SqlIdentifier(schema) if schema is not None else None
1044
1201
  fv._warehouse = SqlIdentifier(warehouse) if warehouse is not None else None
@@ -1071,6 +1228,34 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
1071
1228
 
1072
1229
  return default_cluster_by_cols
1073
1230
 
1231
+ def _get_tile_query(self) -> str:
1232
+ """Generate the tiling query for tile-based aggregations.
1233
+
1234
+ This query is used as the source for the Dynamic Table that stores
1235
+ pre-computed partial aggregations (tiles).
1236
+
1237
+ Returns:
1238
+ SQL query string for creating tile table.
1239
+
1240
+ Raises:
1241
+ ValueError: If the feature view is not configured for tiling.
1242
+ """
1243
+ if not self.is_tiled:
1244
+ raise ValueError("_get_tile_query called on non-tiled feature view")
1245
+
1246
+ from snowflake.ml.feature_store.tile_sql_generator import TilingSqlGenerator
1247
+
1248
+ join_keys = [str(k) for e in self._entities for k in e.join_keys]
1249
+
1250
+ generator = TilingSqlGenerator(
1251
+ source_query=self._query,
1252
+ join_keys=join_keys,
1253
+ timestamp_col=str(self._timestamp_col),
1254
+ feature_granularity=self._feature_granularity, # type: ignore[arg-type]
1255
+ features=self._aggregation_specs, # type: ignore[arg-type]
1256
+ )
1257
+ return generator.generate()
1258
+
1074
1259
  @staticmethod
1075
1260
  def _get_online_table_name(
1076
1261
  feature_view_name: Union[SqlIdentifier, str], version: Optional[Union[FeatureViewVersion, str]] = None