snowflake-ml-python 1.18.0__py3-none-any.whl → 1.20.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. snowflake/ml/_internal/env_utils.py +16 -0
  2. snowflake/ml/_internal/telemetry.py +56 -7
  3. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +1 -7
  4. snowflake/ml/experiment/_entities/run.py +15 -0
  5. snowflake/ml/experiment/experiment_tracking.py +61 -73
  6. snowflake/ml/feature_store/access_manager.py +1 -0
  7. snowflake/ml/feature_store/feature_store.py +86 -31
  8. snowflake/ml/feature_store/feature_view.py +12 -6
  9. snowflake/ml/fileset/stage_fs.py +12 -1
  10. snowflake/ml/jobs/_utils/feature_flags.py +1 -0
  11. snowflake/ml/jobs/_utils/payload_utils.py +6 -1
  12. snowflake/ml/jobs/_utils/spec_utils.py +12 -3
  13. snowflake/ml/jobs/job.py +8 -3
  14. snowflake/ml/jobs/manager.py +19 -6
  15. snowflake/ml/model/_client/model/inference_engine_utils.py +8 -4
  16. snowflake/ml/model/_client/model/model_version_impl.py +45 -17
  17. snowflake/ml/model/_client/ops/model_ops.py +11 -4
  18. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
  19. snowflake/ml/model/models/huggingface_pipeline.py +6 -7
  20. snowflake/ml/monitoring/explain_visualize.py +3 -1
  21. snowflake/ml/version.py +1 -1
  22. {snowflake_ml_python-1.18.0.dist-info → snowflake_ml_python-1.20.0.dist-info}/METADATA +68 -5
  23. {snowflake_ml_python-1.18.0.dist-info → snowflake_ml_python-1.20.0.dist-info}/RECORD +26 -26
  24. {snowflake_ml_python-1.18.0.dist-info → snowflake_ml_python-1.20.0.dist-info}/WHEEL +0 -0
  25. {snowflake_ml_python-1.18.0.dist-info → snowflake_ml_python-1.20.0.dist-info}/licenses/LICENSE.txt +0 -0
  26. {snowflake_ml_python-1.18.0.dist-info → snowflake_ml_python-1.20.0.dist-info}/top_level.txt +0 -0
snowflake/ml/feature_store/feature_store.py CHANGED
@@ -474,8 +474,8 @@ class FeatureStore:
  feature_view: FeatureView instance to materialize.
  version: version of the registered FeatureView.
  NOTE: Version only accepts letters, numbers and underscore. Also version will be capitalized.
- block: Specify whether the FeatureView backend materialization should be blocking or not. If blocking then
- the API will wait until the initial FeatureView data is generated. Default to true.
+ block: Deprecated. To make the initial refresh asynchronous, set the `initialize`
+ argument on the `FeatureView` to `"ON_SCHEDULE"`. Default is true.
  overwrite: Overwrite the existing FeatureView with same version. This is the same as dropping the
  FeatureView first then recreate. NOTE: there will be backfill cost associated if the FeatureView is
  being continuously maintained.
@@ -521,6 +521,15 @@ class FeatureStore:
  """
  version = FeatureViewVersion(version)

+ if block is False:
+ raise snowml_exceptions.SnowflakeMLException(
+ error_code=error_codes.INVALID_ARGUMENT,
+ original_exception=ValueError(
+ 'block=False is deprecated. Use FeatureView(..., initialize="ON_SCHEDULE") '
+ "for async initial refresh."
+ ),
+ )
+
  if feature_view.status != FeatureViewStatus.DRAFT:
  try:
  return self._get_feature_view_if_exists(feature_view.name, str(version))
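In practice this turns the old `block=False` call into a hard error. A minimal migration sketch, assuming the public `FeatureStore.register_feature_view` API and the `FeatureView` constructor's `initialize` parameter named in the new error message (the entity and DataFrame variables are placeholders):

    from snowflake.ml.feature_store import FeatureStore, FeatureView

    # Before: async initial refresh was requested at registration time.
    # fs.register_feature_view(fv, version="V1", block=False)  # now raises INVALID_ARGUMENT

    # After: declare async initial refresh on the view itself.
    fv = FeatureView(
        name="MY_FV",
        entities=[my_entity],       # placeholder entity
        feature_df=my_feature_df,   # placeholder Snowpark DataFrame
        refresh_freq="1 hour",
        initialize="ON_SCHEDULE",   # initial refresh runs on the schedule instead of blocking
    )
    fs.register_feature_view(fv, version="V1")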
@@ -1191,7 +1200,7 @@ class FeatureStore:
  {self._config.database}.INFORMATION_SCHEMA.DYNAMIC_TABLE_REFRESH_HISTORY (RESULT_LIMIT => 10000)
  )
  WHERE NAME = '{fv_resolved_name}'
- AND SCHEMA_NAME = '{self._config.schema}'
+ AND SCHEMA_NAME = '{self._config.schema.resolved()}'
  """
  )

@@ -2094,26 +2103,48 @@ class FeatureStore:
  def _plan_online_update(
  self, feature_view: FeatureView, online_config: Optional[fv_mod.OnlineConfig]
  ) -> _OnlineUpdateStrategy:
- """Plan online update operations based on current state and target config."""
+ """Plan online update operations based on current state and target config.
+
+ Handles three cases:
+ - enable is None: Preserve current online state, only update if currently online
+ - enable is True: Enable online storage (create if needed, update if exists)
+ - enable is False: Disable online storage (drop if exists)
+
+ Args:
+ feature_view: The FeatureView object to check current online state.
+ online_config: The OnlineConfig with target enable and lag settings.
+
+ Returns:
+ _OnlineUpdateStrategy containing operations and their rollbacks.
+ """
  if online_config is None:
  return self._OnlineUpdateStrategy([], [], None)

  current_online = feature_view.online
  target_online = online_config.enable

- # Enable online (create table)
+ # Case 1: enable is None - preserve current online state, only update if currently online
+ if target_online is None:
+ if current_online and (online_config.target_lag is not None):
+ # Online is currently enabled and user wants to update lag
+ return self._plan_online_update_existing(feature_view, online_config)
+ else:
+ # No online changes needed (either not online, or lag not specified)
+ return self._OnlineUpdateStrategy([], [], None)
+
+ # Case 2: Enable online (create table)
  if target_online and not current_online:
  return self._plan_online_enable(feature_view, online_config)

- # Disable online (drop table)
+ # Case 3: Disable online (drop table)
  elif not target_online and current_online:
  return self._plan_online_disable(feature_view)

- # Update existing online table
+ # Case 4: Update existing online table
  elif target_online and current_online:
  return self._plan_online_update_existing(feature_view, online_config)

- # No change needed
+ # Case 5: No change needed
  else:
  return self._OnlineUpdateStrategy([], [], online_config)

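The dispatch above reduces to a small decision table over the current state and the requested `enable`. A standalone sketch that mirrors it with string labels in place of the internal `_OnlineUpdateStrategy` (a simplification for illustration, not the package's code):

    from typing import Optional

    def plan_online_update(current_online: bool, enable: Optional[bool],
                           target_lag: Optional[str]) -> str:
        """Mirror of _plan_online_update's case analysis, returning a label per case."""
        if enable is None:
            # Preserve current state; only touch storage if it is online and lag changes.
            return "update_existing" if (current_online and target_lag is not None) else "no_op"
        if enable and not current_online:
            return "enable"          # create the online table
        if not enable and current_online:
            return "disable"         # drop the online table
        return "update_existing" if (enable and current_online) else "no_op"

    assert plan_online_update(True, None, "1 hour") == "update_existing"
    assert plan_online_update(False, None, "1 hour") == "no_op"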
@@ -2621,9 +2652,10 @@ class FeatureStore:

  This method supports feature views with different join keys by:
  1. Creating a spine CTE that includes all possible join keys
- 2. Performing ASOF JOINs for each feature view using only its specific join keys when timestamp columns exist
- 3. Performing LEFT JOINs for each feature view when timestamp columns are missing
- 4. Combining results using INNER JOINs on each feature view's specific join keys
+ 2. For each feature view, creating a deduplicated spine subquery with only that FV's join keys
+ 3. Performing ASOF JOINs on the deduplicated spine when timestamp columns exist
+ 4. Performing LEFT JOINs on the deduplicated spine when timestamp columns are missing
+ 5. Combining results by LEFT JOINing each FV CTE back to the original SPINE

  Args:
  feature_views: A list of feature views to join.
@@ -2633,9 +2665,6 @@ class FeatureStore:
  include_feature_view_timestamp_col: Whether to include the timestamp column of
  the feature view in the result. Default to false.

- Note: This method does NOT work when there are duplicate combinations of join keys and timestamp columns
- in spine.
-
  Returns:
  A SQL query string with CTE structure for joining feature views.
  """
@@ -2659,11 +2688,17 @@ class FeatureStore:
  fv_join_keys = list({k for e in feature_view.entities for k in e.join_keys})
  join_keys_str = ", ".join(fv_join_keys)

- # Build the JOIN condition using only this feature view's join keys
- join_conditions = [f'SPINE."{col}" = FEATURE."{col}"' for col in fv_join_keys]
-
  # Use ASOF JOIN if both spine and feature view have timestamp columns, otherwise use LEFT JOIN
  if spine_timestamp_col is not None and feature_timestamp_col is not None:
+ # Build the deduplicated spine columns set (join keys + timestamp)
+ spine_dedup_cols_set = set(fv_join_keys)
+ if spine_timestamp_col not in spine_dedup_cols_set:
+ spine_dedup_cols_set.add(spine_timestamp_col)
+ spine_dedup_cols_str = ", ".join(f'"{col}"' for col in spine_dedup_cols_set)
+
+ # Build the JOIN condition using only this feature view's join keys
+ join_conditions_dedup = [f'SPINE_DEDUP."{col}" = FEATURE."{col}"' for col in fv_join_keys]
+
  if include_feature_view_timestamp_col:
  f_ts_col_alias = identifier.concat_names(
  [feature_view.name, "_", str(feature_view.version), "_", feature_timestamp_col]
@@ -2674,36 +2709,46 @@ class FeatureStore:
  ctes.append(
  f"""{cte_name} AS (
  SELECT
- SPINE.*,
+ SPINE_DEDUP.*,
  {f_ts_col_str}
  FEATURE.* EXCLUDE ({join_keys_str}, {feature_timestamp_col})
- FROM
- SPINE
+ FROM (
+ SELECT DISTINCT {spine_dedup_cols_str}
+ FROM SPINE
+ ) SPINE_DEDUP
  ASOF JOIN (
  SELECT {join_keys_str}, {feature_timestamp_col}, {feature_columns[i]}
  FROM {feature_view.fully_qualified_name()}
  ) FEATURE
- MATCH_CONDITION (SPINE."{spine_timestamp_col}" >= FEATURE."{feature_timestamp_col}")
- ON {" AND ".join(join_conditions)}
+ MATCH_CONDITION (SPINE_DEDUP."{spine_timestamp_col}" >= FEATURE."{feature_timestamp_col}")
+ ON {" AND ".join(join_conditions_dedup)}
  )"""
  )
  else:
+ # Build the deduplicated spine columns list (just join keys, no timestamp)
+ spine_dedup_cols_str = ", ".join(f'"{col}"' for col in fv_join_keys)
+
+ # Build the JOIN condition using only this feature view's join keys
+ join_conditions_dedup = [f'SPINE_DEDUP."{col}" = FEATURE."{col}"' for col in fv_join_keys]
+
  ctes.append(
  f"""{cte_name} AS (
  SELECT
- SPINE.*,
+ SPINE_DEDUP.*,
  FEATURE.* EXCLUDE ({join_keys_str})
- FROM
- SPINE
+ FROM (
+ SELECT DISTINCT {spine_dedup_cols_str}
+ FROM SPINE
+ ) SPINE_DEDUP
  LEFT JOIN (
  SELECT {join_keys_str}, {feature_columns[i]}
  FROM {feature_view.fully_qualified_name()}
  ) FEATURE
- ON {" AND ".join(join_conditions)}
+ ON {" AND ".join(join_conditions_dedup)}
  )"""
  )

- # Build final SELECT with individual joins to each FV CTE
+ # Build final SELECT with LEFT joins to each FV CTE
  select_columns = []
  join_clauses = []

@@ -2711,19 +2756,29 @@ class FeatureStore:
  feature_view = feature_views[i]
  fv_join_keys = list({k for e in feature_view.entities for k in e.join_keys})
  join_conditions = [f'SPINE."{col}" = {cte_name}."{col}"' for col in fv_join_keys]
- if spine_timestamp_col is not None:
+ # Only include spine timestamp in join condition if both spine and FV have timestamps
+ if spine_timestamp_col is not None and feature_view.timestamp_col is not None:
  join_conditions.append(f'SPINE."{spine_timestamp_col}" = {cte_name}."{spine_timestamp_col}"')
+
  if include_feature_view_timestamp_col and feature_view.timestamp_col is not None:
  f_ts_col_alias = identifier.concat_names(
  [feature_view.name, "_", str(feature_view.version), "_", feature_view.timestamp_col]
  )
  f_ts_col_str = f"{cte_name}.{f_ts_col_alias} AS {f_ts_col_alias}"
  select_columns.append(f_ts_col_str)
- select_columns.append(feature_columns[i])
+
+ # Select features from the CTE
+ # feature_columns[i] is already a comma-separated string of column names
+ feature_cols_from_cte = []
+ for col in feature_columns[i].split(", "):
+ col_clean = col.strip()
+ feature_cols_from_cte.append(f"{cte_name}.{col_clean}")
+ select_columns.extend(feature_cols_from_cte)
+
  # Create join condition using only this feature view's join keys
  join_clauses.append(
  f"""
- INNER JOIN {cte_name}
+ LEFT JOIN {cte_name}
  ON {" AND ".join(join_conditions)}"""
  )

@@ -3388,7 +3443,7 @@ FROM SPINE{' '.join(join_clauses)}
  online_table_name = FeatureView._get_online_table_name(feature_view_name)

  fully_qualified_online_name = self._get_fully_qualified_name(online_table_name)
- source_table_name = feature_view_name
+ source_table_name = self._get_fully_qualified_name(feature_view_name)

  # Extract join keys for PRIMARY KEY (preserve order and ensure unique)
  ordered_join_keys: list[str] = []
snowflake/ml/feature_store/feature_view.py CHANGED
@@ -1,7 +1,7 @@
+ """Feature view module for Snowflake ML Feature Store."""
  from __future__ import annotations

  import json
- import logging
  import re
  import warnings
  from collections import OrderedDict
@@ -52,7 +52,7 @@ _RESULT_SCAN_QUERY_PATTERN = re.compile(
  class OnlineConfig:
  """Configuration for online feature storage."""

- enable: bool = False
+ enable: Optional[bool] = None
  target_lag: Optional[str] = None

  def __post_init__(self) -> None:
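With the default flipped from `False` to `None`, the config now expresses three intents rather than two. A short sketch (the import path is an assumption; `OnlineConfig` may also be re-exported from the package root):

    from snowflake.ml.feature_store.feature_view import OnlineConfig  # path assumed

    keep_current = OnlineConfig(target_lag="30 seconds")          # enable=None: leave on/off state alone
    turn_on = OnlineConfig(enable=True, target_lag="10 seconds")  # create/refresh online storage
    turn_off = OnlineConfig(enable=False)                         # drop online storage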
@@ -248,6 +248,7 @@ class FeatureView(lineage_node.LineageNode):
  - If `timestamp_col` is provided, it is added to the default clustering keys.
  online_config: Optional configuration for online storage. If provided with enable=True,
  online storage will be enabled. Defaults to None (no online storage).
+ NOTE: this feature is currently in Public Preview.
  _kwargs: reserved kwargs for system generated args. NOTE: DO NOT USE.

  Example::
@@ -289,8 +290,6 @@

  # noqa: DAR401
  """
- if online_config is not None:
- logging.warning("'online_config' is in private preview since 1.12.0. Do not use it in production.")

  self._name: SqlIdentifier = SqlIdentifier(name)
  self._entities: list[Entity] = entities
@@ -533,8 +532,15 @@
  return self._feature_desc

  @property
- def online(self) -> bool:
- return self._online_config.enable if self._online_config else False
+ def online(self) -> bool: # noqa: DAR101
+ """Check if online storage is enabled for this feature view.
+
+ Returns:
+ True if online storage is enabled, False otherwise.
+ """
+ if self._online_config and self._online_config.enable is True:
+ return True
+ return False

  @property
  def online_config(self) -> Optional[OnlineConfig]:
snowflake/ml/fileset/stage_fs.py CHANGED
@@ -1,5 +1,6 @@
  import inspect
  import logging
+ import re
  import time
  from dataclasses import dataclass
  from typing import Any, Optional, Union, cast
@@ -27,6 +28,8 @@ _PRESIGNED_URL_LIFETIME_SEC = 14400
  # The threshold of when the presigned url should get refreshed before its expiration.
  _PRESIGNED_URL_HEADROOM_SEC = 3600

+ # Regex pattern to match cloud storage prefixes (s3://, gcs://, azure://) and bucket/container name at start of string
+ _CLOUD_PATH_PREFIX_PATTERN = re.compile(r"^(s3|gcs|azure)://[^/]+/", re.IGNORECASE)

  _PROJECT = "FileSet"

@@ -355,8 +358,16 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):

  Returns:
  A string of the relative stage path.
+
+ Raises:
+ ValueError: If the stage path format is invalid.
  """
- return stage_path[len(self._stage) + 1 :]
+ if stage_path.lower().startswith(self._stage.lower()):
+ return stage_path[len(self._stage) + 1 :]
+ elif match := _CLOUD_PATH_PREFIX_PATTERN.match(stage_path):
+ return stage_path[match.end() :]
+
+ raise ValueError(f"Invalid stage path: {stage_path}")

  def _add_file_info_helper(
  self,
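The new fallback covers stage listings that return fully qualified cloud URLs instead of stage-relative paths. A standalone check of the pattern's behavior, reproducing the regex from the hunk above (the example paths are invented):

    import re

    _CLOUD_PATH_PREFIX_PATTERN = re.compile(r"^(s3|gcs|azure)://[^/]+/", re.IGNORECASE)

    # The scheme and the bucket/container segment are stripped together.
    path = "s3://my-bucket/some/dir/file.parquet"
    match = _CLOUD_PATH_PREFIX_PATTERN.match(path)
    assert match is not None and path[match.end():] == "some/dir/file.parquet"

    # Paths that are neither stage-prefixed nor cloud URLs fall through to the ValueError.
    assert _CLOUD_PATH_PREFIX_PATTERN.match("@my_stage/some/dir") is None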
snowflake/ml/jobs/_utils/feature_flags.py CHANGED
@@ -31,6 +31,7 @@ def parse_bool_env_value(value: Optional[str], default: bool = False) -> bool:
  class FeatureFlags(Enum):
  USE_SUBMIT_JOB_V2 = "MLRS_USE_SUBMIT_JOB_V2"
  ENABLE_RUNTIME_VERSIONS = "MLRS_ENABLE_RUNTIME_VERSIONS"
+ ENABLE_STAGE_MOUNT_V2 = "MLRS_ENABLE_STAGE_MOUNT_V2"

  def is_enabled(self, default: bool = False) -> bool:
  """Check if the feature flag is enabled.
snowflake/ml/jobs/_utils/payload_utils.py CHANGED
@@ -620,7 +620,12 @@ def _serialize_callable(func: Callable[..., Any]) -> bytes:
  try:
  func_bytes: bytes = cp.dumps(func)
  return func_bytes
- except pickle.PicklingError as e:
+ except (pickle.PicklingError, TypeError) as e:
+ if isinstance(e, TypeError) and "_thread.lock" in str(e):
+ raise RuntimeError(
+ "Unable to pickle an object that internally holds a reference to a Session object, "
+ "such as a Snowpark DataFrame."
+ ) from e
  if isinstance(func, functools.partial):
  # Try to find which part of the partial isn't serializable for better debuggability
  objects = [
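Background for the new branch: objects holding a live Session (for example a Snowpark DataFrame captured in a closure) contain a `_thread.lock`, and pickling that raises `TypeError` rather than `PicklingError`, so it previously escaped uncaught. The failure mode reproduces with nothing but the standard library:

    import pickle
    import threading

    try:
        pickle.dumps(threading.Lock())  # stand-in for the lock a live Session holds
    except TypeError as e:
        assert "_thread.lock" in str(e)  # the substring the new handler matches on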
snowflake/ml/jobs/_utils/spec_utils.py CHANGED
@@ -197,7 +197,7 @@ def generate_service_spec(
  resource_limits["nvidia.com/gpu"] = image_spec.resource_limits.gpu

  # Add local volumes for ephemeral logs and artifacts
- volumes: list[dict[str, str]] = []
+ volumes: list[dict[str, Any]] = []
  volume_mounts: list[dict[str, str]] = []
  for volume_name, mount_path in [
  ("system-logs", "/var/log/managedservices/system/mlrs"),
@@ -246,7 +246,16 @@ def generate_service_spec(
  volumes.append(
  {
  "name": constants.STAGE_VOLUME_NAME,
- "source": payload.stage_path.as_posix(),
+ "source": "stage",
+ "stageConfig": {
+ "name": payload.stage_path.as_posix(),
+ "resources": {
+ "requests": {
+ "memory": "0Gi",
+ "cpu": "0",
+ },
+ },
+ },
  }
  )

@@ -286,7 +295,7 @@ def generate_service_spec(
  "storage",
  ]

- spec_dict = {
+ spec_dict: dict[str, Any] = {
  "containers": [
  {
  "name": constants.DEFAULT_CONTAINER_NAME,
snowflake/ml/jobs/job.py CHANGED
@@ -109,11 +109,16 @@ class MLJob(Generic[T], SerializableSessionMixin):
  return cast(dict[str, Any], container_spec)

  @property
- def _stage_path(self) -> str:
+ def _stage_path(self) -> Optional[str]:
  """Get the job's artifact storage stage location."""
  volumes = self._service_spec["spec"]["volumes"]
- stage_path = next(v for v in volumes if v["name"] == constants.STAGE_VOLUME_NAME)["source"]
- return cast(str, stage_path)
+ stage_volume = next((v for v in volumes if v["name"] == constants.STAGE_VOLUME_NAME), None)
+ if stage_volume is None:
+ return None
+ elif "stageConfig" in stage_volume:
+ return cast(str, stage_volume["stageConfig"]["name"])
+ else:
+ return cast(str, stage_volume["source"])

  @property
  def _result_path(self) -> str:
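Reading the stage location back now has to handle both the legacy flat volume and the new `stageConfig` form emitted by spec_utils. A standalone sketch of the lookup over both shapes (the volume dicts are invented to mirror the spec hunks above):

    from typing import Any, Optional

    STAGE_VOLUME_NAME = "stage-volume"  # stand-in for constants.STAGE_VOLUME_NAME

    def stage_path(volumes: list[dict[str, Any]]) -> Optional[str]:
        vol = next((v for v in volumes if v["name"] == STAGE_VOLUME_NAME), None)
        if vol is None:
            return None
        # New specs nest the path under stageConfig; legacy specs keep it in "source".
        return vol["stageConfig"]["name"] if "stageConfig" in vol else vol["source"]

    legacy = [{"name": STAGE_VOLUME_NAME, "source": "@db.schema.stage/jobs/abc"}]
    new = [{"name": STAGE_VOLUME_NAME, "source": "stage",
            "stageConfig": {"name": "@db.schema.stage/jobs/abc"}}]
    assert stage_path(legacy) == stage_path(new) == "@db.schema.stage/jobs/abc"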
snowflake/ml/jobs/manager.py CHANGED
@@ -192,12 +192,12 @@ def delete_job(job: Union[str, jb.MLJob[Any]], session: Optional[snowpark.Sessio
  """Delete a job service from the backend. Status and logs will be lost."""
  job = job if isinstance(job, jb.MLJob) else get_job(job, session=session)
  session = job._session
- try:
- stage_path = job._stage_path
- session.sql(f"REMOVE {stage_path}/").collect()
- logger.info(f"Successfully cleaned up stage files for job {job.id} at {stage_path}")
- except Exception as e:
- logger.warning(f"Failed to clean up stage files for job {job.id}: {e}")
+ if job._stage_path:
+ try:
+ session.sql(f"REMOVE {job._stage_path}/").collect()
+ logger.debug(f"Successfully cleaned up stage files for job {job.id} at {job._stage_path}")
+ except Exception as e:
+ logger.warning(f"Failed to clean up stage files for job {job.id}: {e}")
  query_helper.run_query(session, "DROP SERVICE IDENTIFIER(?)", params=(job.id,))


@@ -520,6 +520,12 @@ def _submit_job(
  raise RuntimeError(
  "Please specify a schema, either in the session context or as a parameter in the job submission"
  )
+ elif e.sql_error_code == 3001 and "schema" in str(e).lower():
+ raise RuntimeError(
+ "please grant privileges on schema before submitting a job, see",
+ "https://docs.snowflake.com/en/developer-guide/snowflake-ml/ml-jobs/access-control-requirements",
+ " for more details",
+ ) from e
  raise

  if feature_flags.FeatureFlags.USE_SUBMIT_JOB_V2.is_enabled(default=True):
@@ -546,6 +552,12 @@
  except SnowparkSQLException as e:
  if not (e.sql_error_code == 90237 and sp_utils.is_in_stored_procedure()): # type: ignore[no-untyped-call]
  raise
+ elif e.sql_error_code == 3001 and "schema" in str(e).lower():
+ raise RuntimeError(
+ "please grant privileges on schema before submitting a job, see",
+ "https://docs.snowflake.com/en/developer-guide/snowflake-ml/ml-jobs/access-control-requirements"
+ " for more details",
+ ) from e
  # SNOW-2390287: SYSTEM$EXECUTE_ML_JOB() is erroneously blocked in owner's rights
  # stored procedures. This will be fixed in an upcoming release.
  logger.warning(
@@ -690,6 +702,7 @@ def _do_submit_job_v2(
  # when feature flag is enabled, we get the local python version and wrap it in a dict
  # in system function, we can know whether it is python version or image tag or full image URL through the format
  spec_options["RUNTIME"] = json.dumps({"pythonVersion": f"{sys.version_info.major}.{sys.version_info.minor}"})
+
  job_options = {
  "EXTERNAL_ACCESS_INTEGRATIONS": external_access_integrations,
  "QUERY_WAREHOUSE": query_warehouse,
snowflake/ml/model/_client/model/inference_engine_utils.py CHANGED
@@ -4,14 +4,18 @@ from snowflake.ml.model._client.ops import service_ops


  def _get_inference_engine_args(
- experimental_options: Optional[dict[str, Any]],
+ inference_engine_options: Optional[dict[str, Any]],
  ) -> Optional[service_ops.InferenceEngineArgs]:
- if not experimental_options or "inference_engine" not in experimental_options:
+
+ if not inference_engine_options:
  return None

+ if "engine" not in inference_engine_options:
+ raise ValueError("'engine' field is required in inference_engine_options")
+
  return service_ops.InferenceEngineArgs(
- inference_engine=experimental_options["inference_engine"],
- inference_engine_args_override=experimental_options.get("inference_engine_args_override"),
+ inference_engine=inference_engine_options["engine"],
+ inference_engine_args_override=inference_engine_options.get("engine_args_override"),
  )

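Callers migrate from keys nested in `experimental_options` to the dedicated argument, and `engine` becomes mandatory whenever the dict is supplied. A before/after sketch, where `mv` is an existing `ModelVersion` handle and the engine name and arguments are illustrative values, not documented defaults:

    # Before 1.20.0: engine settings rode along in experimental_options.
    # mv.create_service(..., experimental_options={
    #     "inference_engine": "vllm",
    #     "inference_engine_args_override": ["--max-model-len", "4096"],
    # })

    # From 1.20.0: a dedicated argument with renamed keys; "engine" is required.
    mv.create_service(
        service_name="my_service",  # remaining required arguments elided
        inference_engine_options={
            "engine": "vllm",       # illustrative engine name
            "engine_args_override": ["--max-model-len", "4096"],
        },
    )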
snowflake/ml/model/_client/model/model_version_impl.py CHANGED
@@ -12,7 +12,7 @@ from snowflake.ml import jobs
  from snowflake.ml._internal import telemetry
  from snowflake.ml._internal.utils import sql_identifier
  from snowflake.ml.lineage import lineage_node
- from snowflake.ml.model import task, type_hints
+ from snowflake.ml.model import openai_signatures, task, type_hints
  from snowflake.ml.model._client.model import (
  batch_inference_specs,
  inference_engine_utils,
@@ -23,6 +23,7 @@ from snowflake.ml.model._model_composer.model_manifest import model_manifest_sch
  from snowflake.ml.model._model_composer.model_method import utils as model_method_utils
  from snowflake.ml.model._packager.model_handlers import snowmlmodel
  from snowflake.ml.model._packager.model_meta import model_meta_schema
+ from snowflake.ml.model._signatures import core
  from snowflake.snowpark import Session, async_job, dataframe

  _TELEMETRY_PROJECT = "MLOps"
@@ -940,14 +941,16 @@ class ModelVersion(lineage_node.LineageNode):
  self,
  statement_params: Optional[dict[str, Any]] = None,
  ) -> None:
- """Check if the model is a HuggingFace pipeline with text-generation task.
+ """Check if the model is a HuggingFace pipeline with text-generation task
+ and is logged with OPENAI_CHAT_SIGNATURE.

  Args:
  statement_params: Optional dictionary of statement parameters to include
  in the SQL command to fetch model spec.

  Raises:
- ValueError: If the model is not a HuggingFace text-generation model.
+ ValueError: If the model is not a HuggingFace text-generation model or
+ if the model is not logged with OPENAI_CHAT_SIGNATURE.
  """
  # Fetch model spec
  model_spec = self._get_model_spec(statement_params)
@@ -983,6 +986,21 @@
  )
  raise ValueError(f"Inference engine is only supported for task 'text-generation'. {found_tasks_str}")

+ # Check if the model is logged with OPENAI_CHAT_SIGNATURE
+ signatures_dict = model_spec.get("signatures", {})
+
+ # Deserialize signatures from model spec to ModelSignature objects for proper semantic comparison.
+ deserialized_signatures = {
+ func_name: core.ModelSignature.from_dict(sig_dict) for func_name, sig_dict in signatures_dict.items()
+ }
+
+ if deserialized_signatures != openai_signatures.OPENAI_CHAT_SIGNATURE:
+ raise ValueError(
+ "Inference engine requires the model to be logged with OPENAI_CHAT_SIGNATURE. "
+ f"Found signatures: {signatures_dict}. "
+ "Please log the model with: signatures=openai_signatures.OPENAI_CHAT_SIGNATURE"
+ )
+
  @overload
  def create_service(
  self,
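The remedy named in the error message happens at logging time. A sketch of that side, assuming the registry's `log_model` accepts a `signatures` mapping (`session` and the text-generation pipeline are placeholders):

    from snowflake.ml.model import openai_signatures
    from snowflake.ml.registry import Registry

    reg = Registry(session=session)
    mv = reg.log_model(
        text_generation_pipeline,   # placeholder HuggingFace text-generation pipeline
        model_name="my_llm",
        version_name="v1",
        signatures=openai_signatures.OPENAI_CHAT_SIGNATURE,  # required for inference engines
    )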
@@ -1001,6 +1019,7 @@
  force_rebuild: bool = False,
  build_external_access_integration: Optional[str] = None,
  block: bool = True,
+ inference_engine_options: Optional[dict[str, Any]] = None,
  experimental_options: Optional[dict[str, Any]] = None,
  ) -> Union[str, async_job.AsyncJob]:
  """Create an inference service with the given spec.
@@ -1034,10 +1053,12 @@
  block: A bool value indicating whether this function will wait until the service is available.
  When it is ``False``, this function executes the underlying service creation asynchronously
  and returns an :class:`AsyncJob`.
- experimental_options: Experimental options for the service creation with custom inference engine.
- Currently, `inference_engine`, `inference_engine_args_override`, and `autocapture` are supported.
- `inference_engine` is the name of the inference engine to use.
- `inference_engine_args_override` is a list of string arguments to pass to the inference engine.
+ inference_engine_options: Options for the service creation with custom inference engine.
+ Supports `engine` and `engine_args_override`.
+ `engine` is the type of the inference engine to use.
+ `engine_args_override` is a list of string arguments to pass to the inference engine.
+ experimental_options: Experimental options for the service creation.
+ Currently only `autocapture` is supported.
  `autocapture` is a boolean to enable/disable inference table.
  """
  ...
@@ -1060,6 +1081,7 @@
  force_rebuild: bool = False,
  build_external_access_integrations: Optional[list[str]] = None,
  block: bool = True,
+ inference_engine_options: Optional[dict[str, Any]] = None,
  experimental_options: Optional[dict[str, Any]] = None,
  ) -> Union[str, async_job.AsyncJob]:
  """Create an inference service with the given spec.
@@ -1093,10 +1115,12 @@
  block: A bool value indicating whether this function will wait until the service is available.
  When it is ``False``, this function executes the underlying service creation asynchronously
  and returns an :class:`AsyncJob`.
- experimental_options: Experimental options for the service creation with custom inference engine.
- Currently, `inference_engine`, `inference_engine_args_override`, and `autocapture` are supported.
- `inference_engine` is the name of the inference engine to use.
- `inference_engine_args_override` is a list of string arguments to pass to the inference engine.
+ inference_engine_options: Options for the service creation with custom inference engine.
+ Supports `engine` and `engine_args_override`.
+ `engine` is the type of the inference engine to use.
+ `engine_args_override` is a list of string arguments to pass to the inference engine.
+ experimental_options: Experimental options for the service creation.
+ Currently only `autocapture` is supported.
  `autocapture` is a boolean to enable/disable inference table.
  """
  ...
@@ -1134,6 +1158,7 @@
  build_external_access_integration: Optional[str] = None,
  build_external_access_integrations: Optional[list[str]] = None,
  block: bool = True,
+ inference_engine_options: Optional[dict[str, Any]] = None,
  experimental_options: Optional[dict[str, Any]] = None,
  ) -> Union[str, async_job.AsyncJob]:
  """Create an inference service with the given spec.
@@ -1169,10 +1194,12 @@
  block: A bool value indicating whether this function will wait until the service is available.
  When it is False, this function executes the underlying service creation asynchronously
  and returns an AsyncJob.
- experimental_options: Experimental options for the service creation with custom inference engine.
- Currently, `inference_engine`, `inference_engine_args_override`, and `autocapture` are supported.
- `inference_engine` is the name of the inference engine to use.
- `inference_engine_args_override` is a list of string arguments to pass to the inference engine.
+ inference_engine_options: Options for the service creation with custom inference engine.
+ Supports `engine` and `engine_args_override`.
+ `engine` is the type of the inference engine to use.
+ `engine_args_override` is a list of string arguments to pass to the inference engine.
+ experimental_options: Experimental options for the service creation.
+ Currently only `autocapture` is supported.
  `autocapture` is a boolean to enable/disable inference table.


@@ -1209,9 +1236,10 @@
  # Validate GPU support if GPU resources are requested
  self._throw_error_if_gpu_is_not_supported(gpu_requests, statement_params)

- inference_engine_args = inference_engine_utils._get_inference_engine_args(experimental_options)
+ inference_engine_args = inference_engine_utils._get_inference_engine_args(inference_engine_options)

- # Check if model is HuggingFace text-generation before doing inference engine checks
+ # Check if model is HuggingFace text-generation and is logged with
+ # OPENAI_CHAT_SIGNATURE before doing inference engine checks
  # Only validate if inference engine is actually specified
  if inference_engine_args is not None:
  self._check_huggingface_text_generation_model(statement_params)
snowflake/ml/model/_client/ops/model_ops.py CHANGED
@@ -515,10 +515,17 @@ class ModelOperator:
  statement_params=statement_params,
  )
  for r in res:
- if alias_name in r[self._model_client.MODEL_VERSION_ALIASES_COL_NAME]:
- return sql_identifier.SqlIdentifier(
- r[self._model_client.MODEL_VERSION_NAME_COL_NAME], case_sensitive=True
- )
+ aliases_data = r[self._model_client.MODEL_VERSION_ALIASES_COL_NAME]
+ if aliases_data:
+ aliases_list = json.loads(aliases_data)
+
+ # Compare using Snowflake identifier semantics for exact match
+ for alias in aliases_list:
+ if sql_identifier.SqlIdentifier(alias) == alias_name:
+ return sql_identifier.SqlIdentifier(
+ r[self._model_client.MODEL_VERSION_NAME_COL_NAME], case_sensitive=True
+ )
+
  return None

  def get_tag_value(
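The old membership test ran `in` against the raw ALIASES column, which is a JSON array serialized as a string, so a short alias could match inside a longer one. A self-contained illustration of the false positive the JSON parsing eliminates (the column value is invented; the real code additionally applies `SqlIdentifier` case and quoting semantics):

    import json

    aliases_col = '["FIRST", "FIRST_V10"]'  # raw column value as a JSON string
    alias = "FIRST_V1"

    # Old behavior: substring containment yields a false positive via "FIRST_V10".
    assert alias in aliases_col

    # New behavior: parse the JSON, then compare entries exactly.
    assert alias not in json.loads(aliases_col)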