snowflake-ml-python 1.17.0__py3-none-any.whl → 1.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. snowflake/ml/_internal/telemetry.py +3 -2
  2. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +17 -12
  3. snowflake/ml/experiment/callback/keras.py +3 -0
  4. snowflake/ml/experiment/callback/lightgbm.py +3 -0
  5. snowflake/ml/experiment/callback/xgboost.py +3 -0
  6. snowflake/ml/experiment/experiment_tracking.py +19 -7
  7. snowflake/ml/feature_store/feature_store.py +236 -61
  8. snowflake/ml/jobs/_utils/constants.py +12 -1
  9. snowflake/ml/jobs/_utils/payload_utils.py +7 -1
  10. snowflake/ml/jobs/_utils/stage_utils.py +4 -0
  11. snowflake/ml/jobs/_utils/types.py +5 -0
  12. snowflake/ml/jobs/job.py +16 -2
  13. snowflake/ml/jobs/manager.py +12 -1
  14. snowflake/ml/model/__init__.py +19 -0
  15. snowflake/ml/model/_client/model/batch_inference_specs.py +63 -0
  16. snowflake/ml/model/_client/model/inference_engine_utils.py +1 -5
  17. snowflake/ml/model/_client/model/model_version_impl.py +129 -11
  18. snowflake/ml/model/_client/ops/service_ops.py +3 -0
  19. snowflake/ml/model/_client/service/model_deployment_spec.py +3 -0
  20. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
  21. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -1
  22. snowflake/ml/model/_model_composer/model_method/model_method.py +4 -1
  23. snowflake/ml/model/_packager/model_handlers/_utils.py +70 -0
  24. snowflake/ml/model/_packager/model_handlers/prophet.py +566 -0
  25. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -0
  26. snowflake/ml/model/type_hints.py +16 -0
  27. snowflake/ml/modeling/metrics/metrics_utils.py +9 -2
  28. snowflake/ml/version.py +1 -1
  29. {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/METADATA +25 -1
  30. {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/RECORD +33 -32
  31. {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/WHEEL +0 -0
  32. {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/licenses/LICENSE.txt +0 -0
  33. {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/top_level.txt +0 -0
snowflake/ml/feature_store/feature_store.py CHANGED
@@ -1199,10 +1199,10 @@ class FeatureStore:
         """Get refresh history for online feature table."""
         online_table_name = FeatureView._get_online_table_name(feature_view.name, feature_view.version)
         select_cols = "*" if verbose else "name, state, refresh_start_time, refresh_end_time, refresh_action"
-        prefix = (
-            f"{self._config.database.resolved()}."
-            f"{self._config.schema.resolved()}."
-            f"{online_table_name.resolved()}"
+        name = (
+            f"{self._config.database.identifier()}."
+            f"{self._config.schema.identifier()}."
+            f"{online_table_name.identifier()}"
         )
         return self._session.sql(
             f"""
@@ -1210,9 +1210,8 @@ class FeatureStore:
                 {select_cols}
             FROM TABLE (
                 {self._config.database}.INFORMATION_SCHEMA.ONLINE_FEATURE_TABLE_REFRESH_HISTORY (
-                    NAME_PREFIX => '{prefix}'
+                    NAME => '{name}'
                 )
-
             )
             """
         )
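The lookup now passes NAME (an exact match on the fully qualified online table) instead of NAME_PREFIX, and builds that name with identifier() rather than resolved(), which keeps the quoting a case-sensitive identifier needs in SQL text. A placeholder sketch of the query shape this now issues (the online table name format is illustrative; the real one comes from FeatureView._get_online_table_name()):

```python
# Placeholder illustration only; DB/schema/table names are invented.
refresh_history_sql = """
SELECT name, state, refresh_start_time, refresh_end_time, refresh_action
FROM TABLE (
    MY_DB.INFORMATION_SCHEMA.ONLINE_FEATURE_TABLE_REFRESH_HISTORY (
        NAME => 'MY_DB.MY_SCHEMA.MY_FV$V1$ONLINE'  -- exact match, not a prefix scan
    )
)
"""
```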
@@ -1591,6 +1590,7 @@ class FeatureStore:
         spine_timestamp_col: Optional[str] = None,
         exclude_columns: Optional[list[str]] = None,
         include_feature_view_timestamp_col: bool = False,
+        join_method: Literal["sequential", "cte"] = "sequential",
     ) -> DataFrame:
         """
         Enrich spine dataframe with feature values. Mainly used to generate inference data input.
@@ -1604,6 +1604,8 @@ class FeatureStore:
             exclude_columns: Column names to exclude from the result dataframe.
             include_feature_view_timestamp_col: Generated dataset will include timestamp column of feature view
                 (if feature view has timestamp column) if set true. Default to false.
+            join_method: Method for feature joins. "sequential" for layer-by-layer joins (default),
+                "cte" for CTE method. (Internal use only - subject to change)

         Returns:
             Snowpark DataFrame containing the joined results.
@@ -1641,6 +1643,7 @@ class FeatureStore:
             cast(list[Union[FeatureView, FeatureViewSlice]], features),
             spine_timestamp_col,
             include_feature_view_timestamp_col,
+            join_method,
         )

         if exclude_columns is not None:
@@ -1659,6 +1662,7 @@ class FeatureStore:
         spine_label_cols: Optional[list[str]] = None,
         exclude_columns: Optional[list[str]] = None,
         include_feature_view_timestamp_col: bool = False,
+        join_method: Literal["sequential", "cte"] = "sequential",
     ) -> DataFrame:
         """
         Generate a training set from the specified Spine DataFrame and Feature Views. Result is
@@ -1676,6 +1680,8 @@ class FeatureStore:
             exclude_columns: Name of column(s) to exclude from the resulting training set.
             include_feature_view_timestamp_col: Generated dataset will include timestamp column of feature view
                 (if feature view has timestamp column) if set true. Default to false.
+            join_method: Method for feature joins. "sequential" for layer-by-layer joins (default),
+                "cte" for CTE method. (Internal use only - subject to change)

         Returns:
             Returns a Snowpark DataFrame representing the training set.
@@ -1709,7 +1715,7 @@ class FeatureStore:
             spine_label_cols = to_sql_identifiers(spine_label_cols)  # type: ignore[assignment]

         result_df, join_keys = self._join_features(
-            spine_df, features, spine_timestamp_col, include_feature_view_timestamp_col
+            spine_df, features, spine_timestamp_col, include_feature_view_timestamp_col, join_method
         )

         if exclude_columns is not None:
@@ -1757,6 +1763,7 @@ class FeatureStore:
         include_feature_view_timestamp_col: bool = False,
         desc: str = "",
         output_type: Literal["dataset"] = "dataset",
+        join_method: Literal["sequential", "cte"] = "sequential",
     ) -> dataset.Dataset:
         ...

@@ -1774,6 +1781,7 @@ class FeatureStore:
         exclude_columns: Optional[list[str]] = None,
         include_feature_view_timestamp_col: bool = False,
         desc: str = "",
+        join_method: Literal["sequential", "cte"] = "sequential",
     ) -> DataFrame:
         ...

@@ -1791,6 +1799,7 @@ class FeatureStore:
         include_feature_view_timestamp_col: bool = False,
         desc: str = "",
         output_type: Literal["dataset", "table"] = "dataset",
+        join_method: Literal["sequential", "cte"] = "sequential",
     ) -> Union[dataset.Dataset, DataFrame]:
         """
         Generate dataset by given source table and feature views.
@@ -1811,6 +1820,8 @@ class FeatureStore:
                 (if feature view has timestamp column) if set true. Default to false.
             desc: A description about this dataset.
             output_type: (Deprecated) The type of Snowflake storage to use for the generated training data.
+            join_method: Method for feature joins. "sequential" for layer-by-layer joins (default),
+                "cte" for CTE method. (Internal use only - subject to change)

         Returns:
             If output_type is "dataset" (default), returns a Dataset object.
@@ -1874,6 +1885,7 @@ class FeatureStore:
             exclude_columns=exclude_columns,
             include_feature_view_timestamp_col=include_feature_view_timestamp_col,
             save_as=table_name,
+            join_method=join_method,
         )
         if output_type == "table":
             warnings.warn(
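Taken together, these hunks thread the new join_method flag from retrieve_feature_values, generate_training_set, and generate_dataset down into _join_features. A hedged usage sketch (the feature store, spine dataframe, and feature views are assumed to already exist, and the flag is documented above as internal-only):

```python
# Assumes fs = FeatureStore(...), spine_df is a Snowpark DataFrame, and
# fv_customer / fv_orders are registered FeatureViews; names are invented.
training_df = fs.generate_training_set(
    spine_df,
    [fv_customer, fv_orders],
    spine_timestamp_col="EVENT_TS",
    join_method="cte",  # default remains "sequential"
)
```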
@@ -2596,91 +2608,254 @@ class FeatureStore:
         found_rows = self._find_object("TAGS", full_entity_tag_name)
         return len(found_rows) == 1

+    def _build_cte_query(
+        self,
+        feature_views: list[FeatureView],
+        feature_columns: list[str],
+        spine_ref: str,
+        spine_timestamp_col: Optional[SqlIdentifier],
+        include_feature_view_timestamp_col: bool = False,
+    ) -> str:
+        """
+        Build a CTE query with the spine query and the feature views.
+
+        This method supports feature views with different join keys by:
+        1. Creating a spine CTE that includes all possible join keys
+        2. Performing ASOF JOINs for each feature view using only its specific join keys when timestamp columns exist
+        3. Performing LEFT JOINs for each feature view when timestamp columns are missing
+        4. Combining results using INNER JOINs on each feature view's specific join keys
+
+        Args:
+            feature_views: A list of feature views to join.
+            feature_columns: A list of feature column strings for each feature view.
+            spine_ref: The spine query.
+            spine_timestamp_col: The timestamp column from spine. Can be None if spine has no timestamp column.
+            include_feature_view_timestamp_col: Whether to include the timestamp column of
+                the feature view in the result. Default to false.
+
+        Note: This method does NOT work when there are duplicate combinations of join keys and timestamp columns
+            in spine.
+
+        Returns:
+            A SQL query string with CTE structure for joining feature views.
+        """
+        if not feature_views:
+            return f"SELECT * FROM ({spine_ref})"
+
+        # Create spine CTE with the spine query for reuse
+        spine_cte = f"""SPINE AS (
+            SELECT * FROM ({spine_ref})
+        )"""
+
+        ctes = [spine_cte]
+        cte_names = []
+        for i, feature_view in enumerate(feature_views):
+            cte_name = f"FV{i:03d}"
+            cte_names.append(cte_name)
+
+            feature_timestamp_col = feature_view.timestamp_col
+
+            # Get the specific join keys for this feature view
+            fv_join_keys = list({k for e in feature_view.entities for k in e.join_keys})
+            join_keys_str = ", ".join(fv_join_keys)
+
+            # Build the JOIN condition using only this feature view's join keys
+            join_conditions = [f'SPINE."{col}" = FEATURE."{col}"' for col in fv_join_keys]
+
+            # Use ASOF JOIN if both spine and feature view have timestamp columns, otherwise use LEFT JOIN
+            if spine_timestamp_col is not None and feature_timestamp_col is not None:
+                if include_feature_view_timestamp_col:
+                    f_ts_col_alias = identifier.concat_names(
+                        [feature_view.name, "_", str(feature_view.version), "_", feature_timestamp_col]
+                    )
+                    f_ts_col_str = f"FEATURE.{feature_timestamp_col} AS {f_ts_col_alias},"
+                else:
+                    f_ts_col_str = ""
+                ctes.append(
+                    f"""{cte_name} AS (
+                    SELECT
+                        SPINE.*,
+                        {f_ts_col_str}
+                        FEATURE.* EXCLUDE ({join_keys_str}, {feature_timestamp_col})
+                    FROM
+                        SPINE
+                    ASOF JOIN (
+                        SELECT {join_keys_str}, {feature_timestamp_col}, {feature_columns[i]}
+                        FROM {feature_view.fully_qualified_name()}
+                    ) FEATURE
+                    MATCH_CONDITION (SPINE."{spine_timestamp_col}" >= FEATURE."{feature_timestamp_col}")
+                    ON {" AND ".join(join_conditions)}
+                )"""
+                )
+            else:
+                ctes.append(
+                    f"""{cte_name} AS (
+                    SELECT
+                        SPINE.*,
+                        FEATURE.* EXCLUDE ({join_keys_str})
+                    FROM
+                        SPINE
+                    LEFT JOIN (
+                        SELECT {join_keys_str}, {feature_columns[i]}
+                        FROM {feature_view.fully_qualified_name()}
+                    ) FEATURE
+                    ON {" AND ".join(join_conditions)}
+                )"""
+                )
+
+        # Build final SELECT with individual joins to each FV CTE
+        select_columns = []
+        join_clauses = []
+
+        for i, cte_name in enumerate(cte_names):
+            feature_view = feature_views[i]
+            fv_join_keys = list({k for e in feature_view.entities for k in e.join_keys})
+            join_conditions = [f'SPINE."{col}" = {cte_name}."{col}"' for col in fv_join_keys]
+            if spine_timestamp_col is not None:
+                join_conditions.append(f'SPINE."{spine_timestamp_col}" = {cte_name}."{spine_timestamp_col}"')
+            if include_feature_view_timestamp_col and feature_view.timestamp_col is not None:
+                f_ts_col_alias = identifier.concat_names(
+                    [feature_view.name, "_", str(feature_view.version), "_", feature_view.timestamp_col]
+                )
+                f_ts_col_str = f"{cte_name}.{f_ts_col_alias} AS {f_ts_col_alias}"
+                select_columns.append(f_ts_col_str)
+            select_columns.append(feature_columns[i])
+            # Create join condition using only this feature view's join keys
+            join_clauses.append(
+                f"""
+            INNER JOIN {cte_name}
+                ON {" AND ".join(join_conditions)}"""
+            )
+
+        query = f"""WITH
+            {', '.join(ctes)}
+        SELECT
+            SPINE.*,
+            {', '.join(select_columns)}
+        FROM SPINE{' '.join(join_clauses)}
+        """
+
+        return query
+
     def _join_features(
         self,
         spine_df: DataFrame,
         features: list[Union[FeatureView, FeatureViewSlice]],
         spine_timestamp_col: Optional[SqlIdentifier],
         include_feature_view_timestamp_col: bool,
+        join_method: Literal["sequential", "cte"] = "sequential",
     ) -> tuple[DataFrame, list[SqlIdentifier]]:
-        for f in features:
-            f = f.feature_view_ref if isinstance(f, FeatureViewSlice) else f
-            if f.status == FeatureViewStatus.DRAFT:
+        # Validate join_method parameter
+        if join_method not in ["sequential", "cte"]:
+            raise ValueError(f"Invalid join_method '{join_method}'. Must be 'sequential' or 'cte'.")
+
+        feature_views: list[FeatureView] = []
+        # Extract column selections for each feature view
+        feature_columns: list[str] = []
+        for feature in features:
+            fv = feature.feature_view_ref if isinstance(feature, FeatureViewSlice) else feature
+            if fv.status == FeatureViewStatus.DRAFT:
                 raise snowml_exceptions.SnowflakeMLException(
                     error_code=error_codes.NOT_FOUND,
-                    original_exception=ValueError(f"FeatureView {f.name} has not been registered."),
+                    original_exception=ValueError(f"FeatureView {fv.name} has not been registered."),
                 )
-            for e in f.entities:
+            for e in fv.entities:
                 for k in e.join_keys:
                     if k not in to_sql_identifiers(spine_df.columns):
                         raise snowml_exceptions.SnowflakeMLException(
                             error_code=error_codes.INVALID_ARGUMENT,
                             original_exception=ValueError(
-                                f"join_key {k} from Entity {e.name} in FeatureView {f.name} is not found in spine_df."
+                                f"join_key {k} from Entity {e.name} in FeatureView {fv.name} "
+                                "is not found in spine_df."
                             ),
                         )
-
+            feature_views.append(fv)
+            if isinstance(feature, FeatureViewSlice):
+                cols = feature.names
+            else:
+                cols = feature.feature_names
+            feature_columns.append(", ".join(col.resolved() for col in cols))
+        # TODO (SNOW-2396184): remove this check and the non-ASOF join path as ASOF join is enabled by default now.
         if self._asof_join_enabled is None:
             self._asof_join_enabled = self._is_asof_join_enabled()

         # TODO: leverage Snowpark dataframe for more concise syntax once it supports AsOfJoin
         query = spine_df.queries["queries"][-1]
-        layer = 0
-        for f in features:
-            if isinstance(f, FeatureViewSlice):
-                cols = f.names
-                f = f.feature_view_ref
-            else:
-                cols = f.feature_names
-
-            join_keys = list({k for e in f.entities for k in e.join_keys})
-            join_keys_str = ", ".join(join_keys)
-            assert f.version is not None
-            join_table_name = f.fully_qualified_name()
-
-            if spine_timestamp_col is not None and f.timestamp_col is not None:
-                if self._asof_join_enabled:
-                    if include_feature_view_timestamp_col:
-                        f_ts_col_alias = identifier.concat_names([f.name, "_", f.version, "_", f.timestamp_col])
-                        f_ts_col_str = f"r_{layer}.{f.timestamp_col} AS {f_ts_col_alias},"
+        join_keys: list[SqlIdentifier] = []
+
+        if join_method == "cte":
+
+            logger.info(f"Using the CTE method with {len(features)} feature views")
+
+            query = self._build_cte_query(
+                feature_views,
+                feature_columns,
+                spine_df.queries["queries"][-1],
+                spine_timestamp_col,
+                include_feature_view_timestamp_col,
+            )
+        else:
+            # Use sequential joins layer by layer
+            logger.info(f"Using the sequential join method with {len(features)} feature views")
+            layer = 0
+            for feature in features:
+                if isinstance(feature, FeatureViewSlice):
+                    cols = feature.names
+                    feature = feature.feature_view_ref
+                else:
+                    cols = feature.feature_names
+
+                join_keys = list({k for e in feature.entities for k in e.join_keys})
+                join_keys_str = ", ".join(join_keys)
+                assert feature.version is not None
+                join_table_name = feature.fully_qualified_name()
+
+                if spine_timestamp_col is not None and feature.timestamp_col is not None:
+                    if self._asof_join_enabled:
+                        if include_feature_view_timestamp_col:
+                            f_ts_col_alias = identifier.concat_names(
+                                [feature.name, "_", feature.version, "_", feature.timestamp_col]
+                            )
+                            f_ts_col_str = f"r_{layer}.{feature.timestamp_col} AS {f_ts_col_alias},"
+                        else:
+                            f_ts_col_str = ""
+                        query = f"""
+                            SELECT
+                                l_{layer}.*,
+                                {f_ts_col_str}
+                                r_{layer}.* EXCLUDE ({join_keys_str}, {feature.timestamp_col})
+                            FROM ({query}) l_{layer}
+                            ASOF JOIN (
+                                SELECT {join_keys_str}, {feature.timestamp_col},
+                                    {', '.join(col.resolved() for col in cols)}
+                                FROM {join_table_name}
+                            ) r_{layer}
+                            MATCH_CONDITION (l_{layer}.{spine_timestamp_col} >= r_{layer}.{feature.timestamp_col})
+                            ON {' AND '.join([f'l_{layer}.{k} = r_{layer}.{k}' for k in join_keys])}
+                        """
                     else:
-                        f_ts_col_str = ""
+                        query = self._composed_union_window_join_query(
+                            layer=layer,
+                            s_query=query,
+                            s_ts_col=spine_timestamp_col,
+                            f_df=feature.feature_df,
+                            f_table_name=join_table_name,
+                            f_ts_col=feature.timestamp_col,
+                            join_keys=join_keys,
+                        )
+                else:
                     query = f"""
                         SELECT
                             l_{layer}.*,
-                            {f_ts_col_str}
-                            r_{layer}.* EXCLUDE ({join_keys_str}, {f.timestamp_col})
+                            r_{layer}.* EXCLUDE ({join_keys_str})
                         FROM ({query}) l_{layer}
-                        ASOF JOIN (
-                            SELECT {join_keys_str}, {f.timestamp_col}, {', '.join(cols)}
+                        LEFT JOIN (
+                            SELECT {join_keys_str}, {', '.join(col.resolved() for col in cols)}
                             FROM {join_table_name}
                         ) r_{layer}
-                        MATCH_CONDITION (l_{layer}.{spine_timestamp_col} >= r_{layer}.{f.timestamp_col})
                         ON {' AND '.join([f'l_{layer}.{k} = r_{layer}.{k}' for k in join_keys])}
                     """
-                else:
-                    query = self._composed_union_window_join_query(
-                        layer=layer,
-                        s_query=query,
-                        s_ts_col=spine_timestamp_col,
-                        f_df=f.feature_df,
-                        f_table_name=join_table_name,
-                        f_ts_col=f.timestamp_col,
-                        join_keys=join_keys,
-                    )
-            else:
-                query = f"""
-                    SELECT
-                        l_{layer}.*,
-                        r_{layer}.* EXCLUDE ({join_keys_str})
-                    FROM ({query}) l_{layer}
-                    LEFT JOIN (
-                        SELECT {join_keys_str}, {', '.join(cols)}
-                        FROM {join_table_name}
-                    ) r_{layer}
-                    ON {' AND '.join([f'l_{layer}.{k} = r_{layer}.{k}' for k in join_keys])}
-                """
-            layer += 1
+                layer += 1

         # TODO: construct result dataframe with dataframe APIs once ASOF join is supported natively.
         # Below code manually construct result dataframe from private members of spine dataframe, which
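For orientation, the CTE path assembles one SPINE CTE, one FVnnn CTE per feature view (ASOF JOIN when both sides have timestamp columns, LEFT JOIN otherwise), and a final SELECT that INNER JOINs every FVnnn back to SPINE. A sketch of the SQL shape _build_cte_query emits for a single feature view with a timestamp column; CUSTOMER_ID, EVENT_TS, TS, and F1 are placeholders:

```python
# Shape only; derived from the builder above, not a captured query.
generated_sql_shape = """
WITH
    SPINE AS (SELECT * FROM (<spine query>)),
    FV000 AS (
        SELECT SPINE.*, FEATURE.* EXCLUDE (CUSTOMER_ID, TS)
        FROM SPINE
        ASOF JOIN (SELECT CUSTOMER_ID, TS, F1 FROM <fv fully qualified name>) FEATURE
        MATCH_CONDITION (SPINE."EVENT_TS" >= FEATURE."TS")
        ON SPINE."CUSTOMER_ID" = FEATURE."CUSTOMER_ID"
    )
SELECT SPINE.*, F1
FROM SPINE
INNER JOIN FV000
    ON SPINE."CUSTOMER_ID" = FV000."CUSTOMER_ID" AND SPINE."EVENT_TS" = FV000."EVENT_TS"
"""
```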
snowflake/ml/jobs/_utils/constants.py CHANGED
@@ -56,8 +56,9 @@ ENABLE_HEALTH_CHECKS_ENV_VAR = "ENABLE_HEALTH_CHECKS"
 ENABLE_HEALTH_CHECKS = "false"

 # Job status polling constants
-JOB_POLL_INITIAL_DELAY_SECONDS = 0.1
+JOB_POLL_INITIAL_DELAY_SECONDS = 5
 JOB_POLL_MAX_DELAY_SECONDS = 30
+JOB_SPCS_TIMEOUT_SECONDS = 30

 # Log start and end messages
 LOG_START_MSG = "--------------------------------\nML job started\n--------------------------------"
@@ -73,6 +74,7 @@ COMMON_INSTANCE_FAMILIES = {
     "CPU_X64_XS": ComputeResources(cpu=1, memory=6),
     "CPU_X64_S": ComputeResources(cpu=3, memory=13),
     "CPU_X64_M": ComputeResources(cpu=6, memory=28),
+    "CPU_X64_SL": ComputeResources(cpu=14, memory=54),
     "CPU_X64_L": ComputeResources(cpu=28, memory=116),
     "HIGHMEM_X64_S": ComputeResources(cpu=6, memory=58),
 }
@@ -85,6 +87,7 @@ AWS_INSTANCE_FAMILIES = {
 }
 AZURE_INSTANCE_FAMILIES = {
     "HIGHMEM_X64_M": ComputeResources(cpu=28, memory=244),
+    "HIGHMEM_X64_SL": ComputeResources(cpu=92, memory=654),
     "HIGHMEM_X64_L": ComputeResources(cpu=92, memory=654),
     "GPU_NV_XS": ComputeResources(cpu=3, memory=26, gpu=1, gpu_type="T4"),
     "GPU_NV_SM": ComputeResources(cpu=32, memory=424, gpu=1, gpu_type="A10"),
@@ -92,7 +95,15 @@ AZURE_INSTANCE_FAMILIES = {
     "GPU_NV_3M": ComputeResources(cpu=44, memory=424, gpu=2, gpu_type="A100"),
     "GPU_NV_SL": ComputeResources(cpu=92, memory=858, gpu=4, gpu_type="A100"),
 }
+GCP_INSTANCE_FAMILIES = {
+    "HIGHMEM_X64_M": ComputeResources(cpu=28, memory=244),
+    "HIGHMEM_X64_SL": ComputeResources(cpu=92, memory=654),
+    "GPU_GCP_NV_L4_1_24G": ComputeResources(cpu=6, memory=28, gpu=1, gpu_type="L4"),
+    "GPU_GCP_NV_L4_4_24G": ComputeResources(cpu=44, memory=178, gpu=4, gpu_type="L4"),
+    "GPU_GCP_NV_A100_8_40G": ComputeResources(cpu=92, memory=654, gpu=8, gpu_type="A100"),
+}
 CLOUD_INSTANCE_FAMILIES = {
     SnowflakeCloudType.AWS: AWS_INSTANCE_FAMILIES,
     SnowflakeCloudType.AZURE: AZURE_INSTANCE_FAMILIES,
+    SnowflakeCloudType.GCP: GCP_INSTANCE_FAMILIES,
 }
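With GCP wired into CLOUD_INSTANCE_FAMILIES, instance-family resource lookups now resolve for GCP-hosted accounts too. A small sketch using the names from the hunk above (the import paths are assumptions; SnowflakeCloudType may be re-exported from elsewhere in the package):

```python
# Import paths are assumptions based on the file list in this diff.
from snowflake.ml.jobs._utils.constants import CLOUD_INSTANCE_FAMILIES, SnowflakeCloudType

l4_single = CLOUD_INSTANCE_FAMILIES[SnowflakeCloudType.GCP]["GPU_GCP_NV_L4_1_24G"]
print(l4_single.cpu, l4_single.memory, l4_single.gpu, l4_single.gpu_type)  # 6 28 1 L4
```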
snowflake/ml/jobs/_utils/payload_utils.py CHANGED
@@ -488,10 +488,13 @@ class JobPayload:
             " comment = 'Created by snowflake.ml.jobs Python API'",
             params=[stage_name],
         )
-
+        payload_name = None
         # Upload payload to stage - organize into app/ subdirectory
         app_stage_path = stage_path.joinpath(constants.APP_STAGE_SUBPATH)
         if not isinstance(source, types.PayloadPath):
+            if isinstance(source, function_payload_utils.FunctionPayload):
+                payload_name = source.function.__name__
+
             source_code = generate_python_code(source, source_code_display=True)
             _ = session.file.put_stream(
                 io.BytesIO(source_code.encode()),
@@ -502,12 +505,14 @@ class JobPayload:
             source = Path(entrypoint.file_path.parent)

         elif isinstance(source, stage_utils.StagePath):
+            payload_name = entrypoint.file_path.stem
             # copy payload to stage
             if source == entrypoint.file_path:
                 source = source.parent
             upload_payloads(session, app_stage_path, types.PayloadSpec(source, None))

         elif isinstance(source, Path):
+            payload_name = entrypoint.file_path.stem
             upload_payloads(session, app_stage_path, types.PayloadSpec(source, None))
             if source.is_file():
                 source = source.parent
@@ -562,6 +567,7 @@ class JobPayload:
                 *python_entrypoint,
             ],
             env_vars=env_vars,
+            payload_name=payload_name,
         )
snowflake/ml/jobs/_utils/stage_utils.py CHANGED
@@ -32,6 +32,10 @@ class StagePath:
         self._root = self._raw_path[0:start].rstrip("/") if relpath else self._raw_path.rstrip("/")
         self._path = Path(relpath or "")

+    @property
+    def stem(self) -> str:
+        return self._path.stem
+
     @property
     def parts(self) -> tuple[str, ...]:
         return self._path.parts
snowflake/ml/jobs/_utils/types.py CHANGED
@@ -23,6 +23,10 @@ class PayloadPath(Protocol):
     def name(self) -> str:
         ...

+    @property
+    def stem(self) -> str:
+        ...
+
     @property
     def suffix(self) -> str:
         ...
@@ -92,6 +96,7 @@ class UploadedPayload:
     stage_path: PurePath
     entrypoint: list[Union[str, PurePath]]
     env_vars: dict[str, str] = field(default_factory=dict)
+    payload_name: Optional[str] = None


 @dataclass(frozen=True)
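The new stem property on both StagePath and the PayloadPath protocol mirrors pathlib.PurePath.stem, which is what lets payload_utils derive payload_name from an entrypoint above. A minimal sketch, assuming StagePath accepts a raw "@stage/..." string as its __init__ suggests:

```python
from pathlib import Path

from snowflake.ml.jobs._utils import stage_utils  # path from this diff's file list

print(Path("app/train.py").stem)  # train
# Assumption: StagePath parses the raw stage path into root + relative path.
print(stage_utils.StagePath("@payload_stage/app/train.py").stem)  # train
```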
snowflake/ml/jobs/job.py CHANGED
@@ -120,7 +120,7 @@ class MLJob(Generic[T], SerializableSessionMixin):
         """Get the job's result file location."""
         result_path_str = self._container_spec["env"].get(constants.RESULT_PATH_ENV_VAR)
         if result_path_str is None:
-            raise RuntimeError(f"Job {self.name} doesn't have a result path configured")
+            raise NotImplementedError(f"Job {self.name} doesn't have a result path configured")

         return self._transform_path(result_path_str)

@@ -229,8 +229,22 @@ class MLJob(Generic[T], SerializableSessionMixin):
         Raises:
             TimeoutError: If the job does not complete within the specified timeout.
         """
-        delay = constants.JOB_POLL_INITIAL_DELAY_SECONDS  # Start with 100ms delay
         start_time = time.monotonic()
+        try:
+            # spcs_wait_for() is a synchronous query, so polling with exponential backoff is more
+            # effective if the job runs for a long time. We want a hybrid option: use spcs_wait_for()
+            # for the first 30 seconds, then switch to polling for long-running jobs.
+            min_timeout = (
+                int(min(timeout, constants.JOB_SPCS_TIMEOUT_SECONDS))
+                if timeout >= 0
+                else constants.JOB_SPCS_TIMEOUT_SECONDS
+            )
+            query_helper.run_query(self._session, f"call {self.id}!spcs_wait_for('DONE', {min_timeout})")
+            return self.status
+        except SnowparkSQLException:
+            # The function may not be supported in this environment; fall back to polling.
+            pass
+        delay: float = float(constants.JOB_POLL_INITIAL_DELAY_SECONDS)  # Start with 5s delay
         warning_shown = False
         while (status := self.status) not in TERMINAL_JOB_STATUSES:
             elapsed = time.monotonic() - start_time
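Net effect: wait() first blocks on the server-side spcs_wait_for call for at most JOB_SPCS_TIMEOUT_SECONDS (30s, or the caller's smaller timeout), and falls back to client-side polling, now starting at 5s instead of 100ms, when that call raises. A hedged usage sketch (pool, stage, and script names are invented, and submit_file's exact signature is assumed):

```python
from snowflake.ml import jobs

# Assumes an active Snowpark session plus an existing compute pool and stage.
job = jobs.submit_file("train.py", "MY_POOL", stage_name="payload_stage", session=session)
status = job.wait(timeout=600)  # server-side wait first, then 5s-to-30s backoff polling
```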
snowflake/ml/jobs/manager.py CHANGED
@@ -697,10 +697,21 @@ def _do_submit_job_v2(
         "MIN_INSTANCES": min_instances,
         "ASYNC": use_async,
     }
+    if payload.payload_name:
+        job_options["GENERATE_SUFFIX"] = True
     job_options = {k: v for k, v in job_options.items() if v is not None}

     query_template = "CALL SYSTEM$EXECUTE_ML_JOB(?, ?, ?, ?)"
-    params = [job_id, compute_pool, json.dumps(spec_options), json.dumps(job_options)]
+    if job_id:
+        database, schema, _ = identifier.parse_schema_level_object_identifier(job_id)
+        params = [
+            job_id
+            if payload.payload_name is None
+            else identifier.get_schema_level_object_identifier(database, schema, payload.payload_name) + "_",
+            compute_pool,
+            json.dumps(spec_options),
+            json.dumps(job_options),
+        ]
     actual_job_id = query_helper.run_query(session, query_template, params=params)[0][0]

     return get_job(actual_job_id, session=session)
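When a payload name is available, the job is submitted under that name plus a trailing underscore, with GENERATE_SUFFIX=True asking the server to append a unique suffix. A sketch of the name computation from the hunk above (the job id and payload name are invented; the identifier module path is an assumption based on this package's internals):

```python
from snowflake.ml._internal.utils import identifier  # assumed module path

job_id = "MY_DB.MY_SCHEMA.MLJOB_ABC123"  # invented placeholder
payload_name = "train"                    # e.g. the entrypoint file's stem
db, schema, _ = identifier.parse_schema_level_object_identifier(job_id)
name_param = identifier.get_schema_level_object_identifier(db, schema, payload_name) + "_"
print(name_param)  # e.g. MY_DB.MY_SCHEMA.train_ ; the server appends a unique suffix
```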
snowflake/ml/model/__init__.py CHANGED
@@ -1,3 +1,6 @@
+import sys
+import warnings
+
 from snowflake.ml.model._client.model.batch_inference_specs import (
     JobSpec,
     OutputSpec,
@@ -18,3 +21,19 @@ __all__ = [
     "SaveMode",
     "Volatility",
 ]
+
+_deprecation_warning_msg_for_3_9 = (
+    "Python 3.9 is deprecated in snowflake-ml-python. " "Please upgrade to Python 3.10 or greater."
+)
+
+warnings.filterwarnings(
+    "once",
+    message=_deprecation_warning_msg_for_3_9,
+)
+
+if sys.version_info.major == 3 and sys.version_info.minor == 9:
+    warnings.warn(
+        _deprecation_warning_msg_for_3_9,
+        category=DeprecationWarning,
+        stacklevel=2,
+    )