snowflake-ml-python 1.16.0__py3-none-any.whl → 1.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. snowflake/ml/_internal/human_readable_id/adjectives.txt +5 -5
  2. snowflake/ml/_internal/human_readable_id/animals.txt +3 -3
  3. snowflake/ml/_internal/telemetry.py +3 -2
  4. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +17 -12
  5. snowflake/ml/experiment/callback/keras.py +3 -0
  6. snowflake/ml/experiment/callback/lightgbm.py +3 -0
  7. snowflake/ml/experiment/callback/xgboost.py +3 -0
  8. snowflake/ml/experiment/experiment_tracking.py +19 -7
  9. snowflake/ml/feature_store/feature_store.py +236 -61
  10. snowflake/ml/jobs/__init__.py +4 -0
  11. snowflake/ml/jobs/_interop/__init__.py +0 -0
  12. snowflake/ml/jobs/_interop/data_utils.py +124 -0
  13. snowflake/ml/jobs/_interop/dto_schema.py +95 -0
  14. snowflake/ml/jobs/{_utils/interop_utils.py → _interop/exception_utils.py} +49 -178
  15. snowflake/ml/jobs/_interop/legacy.py +225 -0
  16. snowflake/ml/jobs/_interop/protocols.py +471 -0
  17. snowflake/ml/jobs/_interop/results.py +51 -0
  18. snowflake/ml/jobs/_interop/utils.py +144 -0
  19. snowflake/ml/jobs/_utils/constants.py +16 -2
  20. snowflake/ml/jobs/_utils/feature_flags.py +37 -5
  21. snowflake/ml/jobs/_utils/payload_utils.py +8 -2
  22. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +139 -102
  23. snowflake/ml/jobs/_utils/spec_utils.py +2 -1
  24. snowflake/ml/jobs/_utils/stage_utils.py +4 -0
  25. snowflake/ml/jobs/_utils/types.py +15 -0
  26. snowflake/ml/jobs/job.py +186 -40
  27. snowflake/ml/jobs/manager.py +48 -39
  28. snowflake/ml/model/__init__.py +19 -0
  29. snowflake/ml/model/_client/model/batch_inference_specs.py +63 -0
  30. snowflake/ml/model/_client/model/inference_engine_utils.py +1 -5
  31. snowflake/ml/model/_client/model/model_version_impl.py +168 -18
  32. snowflake/ml/model/_client/ops/model_ops.py +4 -0
  33. snowflake/ml/model/_client/ops/service_ops.py +3 -0
  34. snowflake/ml/model/_client/service/model_deployment_spec.py +3 -0
  35. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
  36. snowflake/ml/model/_client/sql/model_version.py +3 -1
  37. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -1
  38. snowflake/ml/model/_model_composer/model_method/model_method.py +11 -3
  39. snowflake/ml/model/_model_composer/model_method/utils.py +28 -0
  40. snowflake/ml/model/_packager/model_env/model_env.py +22 -5
  41. snowflake/ml/model/_packager/model_handlers/_utils.py +70 -0
  42. snowflake/ml/model/_packager/model_handlers/prophet.py +566 -0
  43. snowflake/ml/model/_packager/model_meta/model_meta.py +8 -0
  44. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +7 -0
  45. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
  46. snowflake/ml/model/type_hints.py +16 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +5 -5
  48. snowflake/ml/modeling/metrics/metrics_utils.py +9 -2
  49. snowflake/ml/version.py +1 -1
  50. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/METADATA +50 -4
  51. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/RECORD +54 -45
  52. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/WHEEL +0 -0
  53. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/licenses/LICENSE.txt +0 -0
  54. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/top_level.txt +0 -0
snowflake/ml/feature_store/feature_store.py
@@ -1199,10 +1199,10 @@ class FeatureStore:
  """Get refresh history for online feature table."""
  online_table_name = FeatureView._get_online_table_name(feature_view.name, feature_view.version)
  select_cols = "*" if verbose else "name, state, refresh_start_time, refresh_end_time, refresh_action"
- prefix = (
- f"{self._config.database.resolved()}."
- f"{self._config.schema.resolved()}."
- f"{online_table_name.resolved()}"
+ name = (
+ f"{self._config.database.identifier()}."
+ f"{self._config.schema.identifier()}."
+ f"{online_table_name.identifier()}"
  )
  return self._session.sql(
  f"""
@@ -1210,9 +1210,8 @@ class FeatureStore:
  {select_cols}
  FROM TABLE (
  {self._config.database}.INFORMATION_SCHEMA.ONLINE_FEATURE_TABLE_REFRESH_HISTORY (
- NAME_PREFIX => '{prefix}'
+ NAME => '{name}'
  )
-
  )
  """
  )
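Net effect of the two hunks above: the refresh-history lookup now renders the fully qualified online table name with identifier() and passes it through the table function's NAME argument instead of NAME_PREFIX. A rough sketch of the query this produces (MY_DB, MY_SCHEMA, and the online table name are hypothetical placeholders):

    # Sketch only: MY_DB, MY_SCHEMA, and MY_FV_ONLINE are hypothetical identifiers.
    name = "MY_DB.MY_SCHEMA.MY_FV_ONLINE"
    query = f"""
        SELECT name, state, refresh_start_time, refresh_end_time, refresh_action
        FROM TABLE (
            MY_DB.INFORMATION_SCHEMA.ONLINE_FEATURE_TABLE_REFRESH_HISTORY (
                NAME => '{name}'
            )
        )
    """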
@@ -1591,6 +1590,7 @@ class FeatureStore:
  spine_timestamp_col: Optional[str] = None,
  exclude_columns: Optional[list[str]] = None,
  include_feature_view_timestamp_col: bool = False,
+ join_method: Literal["sequential", "cte"] = "sequential",
  ) -> DataFrame:
  """
  Enrich spine dataframe with feature values. Mainly used to generate inference data input.
@@ -1604,6 +1604,8 @@ class FeatureStore:
  exclude_columns: Column names to exclude from the result dataframe.
  include_feature_view_timestamp_col: Generated dataset will include timestamp column of feature view
  (if feature view has timestamp column) if set true. Default to false.
+ join_method: Method for feature joins. "sequential" for layer-by-layer joins (default),
+ "cte" for CTE method. (Internal use only - subject to change)

  Returns:
  Snowpark DataFrame containing the joined results.
@@ -1641,6 +1643,7 @@ class FeatureStore:
  cast(list[Union[FeatureView, FeatureViewSlice]], features),
  spine_timestamp_col,
  include_feature_view_timestamp_col,
+ join_method,
  )

  if exclude_columns is not None:
@@ -1659,6 +1662,7 @@ class FeatureStore:
  spine_label_cols: Optional[list[str]] = None,
  exclude_columns: Optional[list[str]] = None,
  include_feature_view_timestamp_col: bool = False,
+ join_method: Literal["sequential", "cte"] = "sequential",
  ) -> DataFrame:
  """
  Generate a training set from the specified Spine DataFrame and Feature Views. Result is
@@ -1676,6 +1680,8 @@ class FeatureStore:
  exclude_columns: Name of column(s) to exclude from the resulting training set.
  include_feature_view_timestamp_col: Generated dataset will include timestamp column of feature view
  (if feature view has timestamp column) if set true. Default to false.
+ join_method: Method for feature joins. "sequential" for layer-by-layer joins (default),
+ "cte" for CTE method. (Internal use only - subject to change)

  Returns:
  Returns a Snowpark DataFrame representing the training set.
@@ -1709,7 +1715,7 @@ class FeatureStore:
  spine_label_cols = to_sql_identifiers(spine_label_cols) # type: ignore[assignment]

  result_df, join_keys = self._join_features(
- spine_df, features, spine_timestamp_col, include_feature_view_timestamp_col
+ spine_df, features, spine_timestamp_col, include_feature_view_timestamp_col, join_method
  )

  if exclude_columns is not None:
@@ -1757,6 +1763,7 @@ class FeatureStore:
  include_feature_view_timestamp_col: bool = False,
  desc: str = "",
  output_type: Literal["dataset"] = "dataset",
+ join_method: Literal["sequential", "cte"] = "sequential",
  ) -> dataset.Dataset:
  ...

@@ -1774,6 +1781,7 @@ class FeatureStore:
  exclude_columns: Optional[list[str]] = None,
  include_feature_view_timestamp_col: bool = False,
  desc: str = "",
+ join_method: Literal["sequential", "cte"] = "sequential",
  ) -> DataFrame:
  ...

@@ -1791,6 +1799,7 @@ class FeatureStore:
  include_feature_view_timestamp_col: bool = False,
  desc: str = "",
  output_type: Literal["dataset", "table"] = "dataset",
+ join_method: Literal["sequential", "cte"] = "sequential",
  ) -> Union[dataset.Dataset, DataFrame]:
  """
  Generate dataset by given source table and feature views.
@@ -1811,6 +1820,8 @@ class FeatureStore:
  (if feature view has timestamp column) if set true. Default to false.
  desc: A description about this dataset.
  output_type: (Deprecated) The type of Snowflake storage to use for the generated training data.
+ join_method: Method for feature joins. "sequential" for layer-by-layer joins (default),
+ "cte" for CTE method. (Internal use only - subject to change)

  Returns:
  If output_type is "dataset" (default), returns a Dataset object.
@@ -1874,6 +1885,7 @@ class FeatureStore:
  exclude_columns=exclude_columns,
  include_feature_view_timestamp_col=include_feature_view_timestamp_col,
  save_as=table_name,
+ join_method=join_method,
  )
  if output_type == "table":
  warnings.warn(
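The join_method flag added above is threaded through retrieve_feature_values, generate_training_set, and generate_dataset, and is forwarded to _join_features (see the hunk below). A minimal sketch of opting into the CTE path; fs, spine_df, and fv are hypothetical names, and the docstrings mark the flag as internal use only and subject to change:

    # Sketch only: assumes an existing FeatureStore `fs`, a Snowpark DataFrame `spine_df`,
    # and a registered FeatureView `fv` with timestamp column "TS" -- all hypothetical.
    training_set = fs.generate_training_set(
        spine_df,
        features=[fv],
        spine_timestamp_col="TS",
        join_method="cte",  # default is "sequential" (layer-by-layer joins)
    )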
@@ -2596,91 +2608,254 @@ class FeatureStore:
  found_rows = self._find_object("TAGS", full_entity_tag_name)
  return len(found_rows) == 1

+ def _build_cte_query(
+ self,
+ feature_views: list[FeatureView],
+ feature_columns: list[str],
+ spine_ref: str,
+ spine_timestamp_col: Optional[SqlIdentifier],
+ include_feature_view_timestamp_col: bool = False,
+ ) -> str:
+ """
+ Build a CTE query with the spine query and the feature views.
+
+ This method supports feature views with different join keys by:
+ 1. Creating a spine CTE that includes all possible join keys
+ 2. Performing ASOF JOINs for each feature view using only its specific join keys when timestamp columns exist
+ 3. Performing LEFT JOINs for each feature view when timestamp columns are missing
+ 4. Combining results using INNER JOINs on each feature view's specific join keys
+
+ Args:
+ feature_views: A list of feature views to join.
+ feature_columns: A list of feature column strings for each feature view.
+ spine_ref: The spine query.
+ spine_timestamp_col: The timestamp column from spine. Can be None if spine has no timestamp column.
+ include_feature_view_timestamp_col: Whether to include the timestamp column of
+ the feature view in the result. Default to false.
+
+ Note: This method does NOT work when there are duplicate combinations of join keys and timestamp columns
+ in spine.
+
+ Returns:
+ A SQL query string with CTE structure for joining feature views.
+ """
+ if not feature_views:
+ return f"SELECT * FROM ({spine_ref})"
+
+ # Create spine CTE with the spine query for reuse
+ spine_cte = f"""SPINE AS (
+ SELECT * FROM ({spine_ref})
+ )"""
+
+ ctes = [spine_cte]
+ cte_names = []
+ for i, feature_view in enumerate(feature_views):
+ cte_name = f"FV{i:03d}"
+ cte_names.append(cte_name)
+
+ feature_timestamp_col = feature_view.timestamp_col
+
+ # Get the specific join keys for this feature view
+ fv_join_keys = list({k for e in feature_view.entities for k in e.join_keys})
+ join_keys_str = ", ".join(fv_join_keys)
+
+ # Build the JOIN condition using only this feature view's join keys
+ join_conditions = [f'SPINE."{col}" = FEATURE."{col}"' for col in fv_join_keys]
+
+ # Use ASOF JOIN if both spine and feature view have timestamp columns, otherwise use LEFT JOIN
+ if spine_timestamp_col is not None and feature_timestamp_col is not None:
+ if include_feature_view_timestamp_col:
+ f_ts_col_alias = identifier.concat_names(
+ [feature_view.name, "_", str(feature_view.version), "_", feature_timestamp_col]
+ )
+ f_ts_col_str = f"FEATURE.{feature_timestamp_col} AS {f_ts_col_alias},"
+ else:
+ f_ts_col_str = ""
+ ctes.append(
+ f"""{cte_name} AS (
+ SELECT
+ SPINE.*,
+ {f_ts_col_str}
+ FEATURE.* EXCLUDE ({join_keys_str}, {feature_timestamp_col})
+ FROM
+ SPINE
+ ASOF JOIN (
+ SELECT {join_keys_str}, {feature_timestamp_col}, {feature_columns[i]}
+ FROM {feature_view.fully_qualified_name()}
+ ) FEATURE
+ MATCH_CONDITION (SPINE."{spine_timestamp_col}" >= FEATURE."{feature_timestamp_col}")
+ ON {" AND ".join(join_conditions)}
+ )"""
+ )
+ else:
+ ctes.append(
+ f"""{cte_name} AS (
+ SELECT
+ SPINE.*,
+ FEATURE.* EXCLUDE ({join_keys_str})
+ FROM
+ SPINE
+ LEFT JOIN (
+ SELECT {join_keys_str}, {feature_columns[i]}
+ FROM {feature_view.fully_qualified_name()}
+ ) FEATURE
+ ON {" AND ".join(join_conditions)}
+ )"""
+ )
+
+ # Build final SELECT with individual joins to each FV CTE
+ select_columns = []
+ join_clauses = []
+
+ for i, cte_name in enumerate(cte_names):
+ feature_view = feature_views[i]
+ fv_join_keys = list({k for e in feature_view.entities for k in e.join_keys})
+ join_conditions = [f'SPINE."{col}" = {cte_name}."{col}"' for col in fv_join_keys]
+ if spine_timestamp_col is not None:
+ join_conditions.append(f'SPINE."{spine_timestamp_col}" = {cte_name}."{spine_timestamp_col}"')
+ if include_feature_view_timestamp_col and feature_view.timestamp_col is not None:
+ f_ts_col_alias = identifier.concat_names(
+ [feature_view.name, "_", str(feature_view.version), "_", feature_view.timestamp_col]
+ )
+ f_ts_col_str = f"{cte_name}.{f_ts_col_alias} AS {f_ts_col_alias}"
+ select_columns.append(f_ts_col_str)
+ select_columns.append(feature_columns[i])
+ # Create join condition using only this feature view's join keys
+ join_clauses.append(
+ f"""
+ INNER JOIN {cte_name}
+ ON {" AND ".join(join_conditions)}"""
+ )
+
+ query = f"""WITH
+ {', '.join(ctes)}
+ SELECT
+ SPINE.*,
+ {', '.join(select_columns)}
+ FROM SPINE{' '.join(join_clauses)}
+ """
+
+ return query
+
  def _join_features(
  self,
  spine_df: DataFrame,
  features: list[Union[FeatureView, FeatureViewSlice]],
  spine_timestamp_col: Optional[SqlIdentifier],
  include_feature_view_timestamp_col: bool,
+ join_method: Literal["sequential", "cte"] = "sequential",
  ) -> tuple[DataFrame, list[SqlIdentifier]]:
- for f in features:
- f = f.feature_view_ref if isinstance(f, FeatureViewSlice) else f
- if f.status == FeatureViewStatus.DRAFT:
+ # Validate join_method parameter
+ if join_method not in ["sequential", "cte"]:
+ raise ValueError(f"Invalid join_method '{join_method}'. Must be 'sequential' or 'cte'.")
+
+ feature_views: list[FeatureView] = []
+ # Extract column selections for each feature view
+ feature_columns: list[str] = []
+ for feature in features:
+ fv = feature.feature_view_ref if isinstance(feature, FeatureViewSlice) else feature
+ if fv.status == FeatureViewStatus.DRAFT:
  raise snowml_exceptions.SnowflakeMLException(
  error_code=error_codes.NOT_FOUND,
- original_exception=ValueError(f"FeatureView {f.name} has not been registered."),
+ original_exception=ValueError(f"FeatureView {fv.name} has not been registered."),
  )
- for e in f.entities:
+ for e in fv.entities:
  for k in e.join_keys:
  if k not in to_sql_identifiers(spine_df.columns):
  raise snowml_exceptions.SnowflakeMLException(
  error_code=error_codes.INVALID_ARGUMENT,
  original_exception=ValueError(
- f"join_key {k} from Entity {e.name} in FeatureView {f.name} is not found in spine_df."
+ f"join_key {k} from Entity {e.name} in FeatureView {fv.name} "
+ "is not found in spine_df."
  ),
  )
-
+ feature_views.append(fv)
+ if isinstance(feature, FeatureViewSlice):
+ cols = feature.names
+ else:
+ cols = feature.feature_names
+ feature_columns.append(", ".join(col.resolved() for col in cols))
+ # TODO (SNOW-2396184): remove this check and the non-ASOF join path as ASOF join is enabled by default now.
  if self._asof_join_enabled is None:
  self._asof_join_enabled = self._is_asof_join_enabled()

  # TODO: leverage Snowpark dataframe for more concise syntax once it supports AsOfJoin
  query = spine_df.queries["queries"][-1]
- layer = 0
- for f in features:
- if isinstance(f, FeatureViewSlice):
- cols = f.names
- f = f.feature_view_ref
- else:
- cols = f.feature_names
-
- join_keys = list({k for e in f.entities for k in e.join_keys})
- join_keys_str = ", ".join(join_keys)
- assert f.version is not None
- join_table_name = f.fully_qualified_name()
-
- if spine_timestamp_col is not None and f.timestamp_col is not None:
- if self._asof_join_enabled:
- if include_feature_view_timestamp_col:
- f_ts_col_alias = identifier.concat_names([f.name, "_", f.version, "_", f.timestamp_col])
- f_ts_col_str = f"r_{layer}.{f.timestamp_col} AS {f_ts_col_alias},"
+ join_keys: list[SqlIdentifier] = []
+
+ if join_method == "cte":
+
+ logger.info(f"Using the CTE method with {len(features)} feature views")
+
+ query = self._build_cte_query(
+ feature_views,
+ feature_columns,
+ spine_df.queries["queries"][-1],
+ spine_timestamp_col,
+ include_feature_view_timestamp_col,
+ )
+ else:
+ # Use sequential joins layer by layer
+ logger.info(f"Using the sequential join method with {len(features)} feature views")
+ layer = 0
+ for feature in features:
+ if isinstance(feature, FeatureViewSlice):
+ cols = feature.names
+ feature = feature.feature_view_ref
+ else:
+ cols = feature.feature_names
+
+ join_keys = list({k for e in feature.entities for k in e.join_keys})
+ join_keys_str = ", ".join(join_keys)
+ assert feature.version is not None
+ join_table_name = feature.fully_qualified_name()
+
+ if spine_timestamp_col is not None and feature.timestamp_col is not None:
+ if self._asof_join_enabled:
+ if include_feature_view_timestamp_col:
+ f_ts_col_alias = identifier.concat_names(
+ [feature.name, "_", feature.version, "_", feature.timestamp_col]
+ )
+ f_ts_col_str = f"r_{layer}.{feature.timestamp_col} AS {f_ts_col_alias},"
+ else:
+ f_ts_col_str = ""
+ query = f"""
+ SELECT
+ l_{layer}.*,
+ {f_ts_col_str}
+ r_{layer}.* EXCLUDE ({join_keys_str}, {feature.timestamp_col})
+ FROM ({query}) l_{layer}
+ ASOF JOIN (
+ SELECT {join_keys_str}, {feature.timestamp_col},
+ {', '.join(col.resolved() for col in cols)}
+ FROM {join_table_name}
+ ) r_{layer}
+ MATCH_CONDITION (l_{layer}.{spine_timestamp_col} >= r_{layer}.{feature.timestamp_col})
+ ON {' AND '.join([f'l_{layer}.{k} = r_{layer}.{k}' for k in join_keys])}
+ """
  else:
- f_ts_col_str = ""
+ query = self._composed_union_window_join_query(
+ layer=layer,
+ s_query=query,
+ s_ts_col=spine_timestamp_col,
+ f_df=feature.feature_df,
+ f_table_name=join_table_name,
+ f_ts_col=feature.timestamp_col,
+ join_keys=join_keys,
+ )
+ else:
  query = f"""
  SELECT
  l_{layer}.*,
- {f_ts_col_str}
- r_{layer}.* EXCLUDE ({join_keys_str}, {f.timestamp_col})
+ r_{layer}.* EXCLUDE ({join_keys_str})
  FROM ({query}) l_{layer}
- ASOF JOIN (
- SELECT {join_keys_str}, {f.timestamp_col}, {', '.join(cols)}
+ LEFT JOIN (
+ SELECT {join_keys_str}, {', '.join(col.resolved() for col in cols)}
  FROM {join_table_name}
  ) r_{layer}
- MATCH_CONDITION (l_{layer}.{spine_timestamp_col} >= r_{layer}.{f.timestamp_col})
  ON {' AND '.join([f'l_{layer}.{k} = r_{layer}.{k}' for k in join_keys])}
  """
- else:
- query = self._composed_union_window_join_query(
- layer=layer,
- s_query=query,
- s_ts_col=spine_timestamp_col,
- f_df=f.feature_df,
- f_table_name=join_table_name,
- f_ts_col=f.timestamp_col,
- join_keys=join_keys,
- )
- else:
- query = f"""
- SELECT
- l_{layer}.*,
- r_{layer}.* EXCLUDE ({join_keys_str})
- FROM ({query}) l_{layer}
- LEFT JOIN (
- SELECT {join_keys_str}, {', '.join(cols)}
- FROM {join_table_name}
- ) r_{layer}
- ON {' AND '.join([f'l_{layer}.{k} = r_{layer}.{k}' for k in join_keys])}
- """
- layer += 1
+ layer += 1

  # TODO: construct result dataframe with datframe APIs once ASOF join is supported natively.
  # Below code manually construct result dataframe from private members of spine dataframe, which
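For orientation, with a single feature view that has a timestamp column, the query assembled by _build_cte_query has roughly the shape below. This is a sketch, not captured output; the spine query, the feature view's fully qualified name, and the ID/TS/FEATURE_1 columns are placeholders:

    # Illustrative shape of the SQL built by _build_cte_query for one feature view.
    sketch = """
    WITH SPINE AS (
        SELECT * FROM (<spine query>)
    ), FV000 AS (
        SELECT
            SPINE.*,
            FEATURE.* EXCLUDE (ID, TS)
        FROM SPINE
        ASOF JOIN (
            SELECT ID, TS, FEATURE_1
            FROM <feature view fully qualified name>
        ) FEATURE
        MATCH_CONDITION (SPINE."TS" >= FEATURE."TS")
        ON SPINE."ID" = FEATURE."ID"
    )
    SELECT
        SPINE.*,
        FEATURE_1
    FROM SPINE
        INNER JOIN FV000
        ON SPINE."ID" = FV000."ID" AND SPINE."TS" = FV000."TS"
    """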
snowflake/ml/jobs/__init__.py
@@ -1,3 +1,4 @@
+ from snowflake.ml.jobs._interop.exception_utils import install_exception_display_hooks
  from snowflake.ml.jobs._utils.types import JOB_STATUS
  from snowflake.ml.jobs.decorators import remote
  from snowflake.ml.jobs.job import MLJob
@@ -10,6 +11,9 @@ from snowflake.ml.jobs.manager import (
  submit_from_stage,
  )

+ # Initialize exception display hooks for remote job error handling
+ install_exception_display_hooks()
+
  __all__ = [
  "remote",
  "submit_file",
snowflake/ml/jobs/_interop/__init__.py: file without changes (+0 -0)
snowflake/ml/jobs/_interop/data_utils.py (new file)
@@ -0,0 +1,124 @@
+ import io
+ import json
+ from typing import Any, Literal, Optional, Protocol, Union, cast, overload
+
+ from snowflake import snowpark
+ from snowflake.ml.jobs._interop import dto_schema
+
+
+ class StageFileWriter(io.IOBase):
+ """
+ A context manager IOBase implementation that proxies writes to an internal BytesIO
+ and uploads to Snowflake stage on close.
+ """
+
+ def __init__(self, session: snowpark.Session, path: str) -> None:
+ self._session = session
+ self._path = path
+ self._buffer = io.BytesIO()
+ self._closed = False
+ self._exception_occurred = False
+
+ def write(self, data: Union[bytes, bytearray]) -> int:
+ """Write data to the internal buffer."""
+ if self._closed:
+ raise ValueError("I/O operation on closed file")
+ return self._buffer.write(data)
+
+ def close(self, write_contents: bool = True) -> None:
+ """Close the file and upload the buffer contents to the stage."""
+ if not self._closed:
+ # Only upload if buffer has content and no exception occurred
+ if write_contents and self._buffer.tell() > 0:
+ self._buffer.seek(0)
+ self._session.file.put_stream(self._buffer, self._path)
+ self._buffer.close()
+ self._closed = True
+
+ def __enter__(self) -> "StageFileWriter":
+ return self
+
+ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+ exception_occurred = exc_type is not None
+ self.close(write_contents=not exception_occurred)
+
+ @property
+ def closed(self) -> bool:
+ return self._closed
+
+ def writable(self) -> bool:
+ return not self._closed
+
+ def readable(self) -> bool:
+ return False
+
+ def seekable(self) -> bool:
+ return not self._closed
+
+
+ def _is_stage_path(path: str) -> bool:
+ return path.startswith("@") or path.startswith("snow://")
+
+
+ def open_stream(path: str, mode: str = "rb", session: Optional[snowpark.Session] = None) -> io.IOBase:
+ if _is_stage_path(path):
+ if session is None:
+ raise ValueError("Session is required when opening a stage path")
+ if "r" in mode:
+ stream: io.IOBase = session.file.get_stream(path) # type: ignore[assignment]
+ return stream
+ elif "w" in mode:
+ return StageFileWriter(session, path)
+ else:
+ raise ValueError(f"Unsupported mode '{mode}' for stage path")
+ else:
+ result: io.IOBase = open(path, mode) # type: ignore[assignment]
+ return result
+
+
+ class DtoCodec(Protocol):
+ @overload
+ @staticmethod
+ def decode(stream: io.IOBase, as_dict: Literal[True]) -> dict[str, Any]:
+ ...
+
+ @overload
+ @staticmethod
+ def decode(stream: io.IOBase, as_dict: Literal[False] = False) -> dto_schema.ResultDTO:
+ ...
+
+ @staticmethod
+ def decode(stream: io.IOBase, as_dict: bool = False) -> Union[dto_schema.ResultDTO, dict[str, Any]]:
+ pass
+
+ @staticmethod
+ def encode(dto: dto_schema.ResultDTO) -> bytes:
+ pass
+
+
+ class JsonDtoCodec(DtoCodec):
+ @overload
+ @staticmethod
+ def decode(stream: io.IOBase, as_dict: Literal[True]) -> dict[str, Any]:
+ ...
+
+ @overload
+ @staticmethod
+ def decode(stream: io.IOBase, as_dict: Literal[False] = False) -> dto_schema.ResultDTO:
+ ...
+
+ @staticmethod
+ def decode(stream: io.IOBase, as_dict: bool = False) -> Union[dto_schema.ResultDTO, dict[str, Any]]:
+ data = cast(dict[str, Any], json.load(stream))
+ if as_dict:
+ return data
+ return dto_schema.ResultDTO.model_validate(data)
+
+ @staticmethod
+ def encode(dto: dto_schema.ResultDTO) -> bytes:
+ # Temporarily extract the value to avoid accidentally applying model_dump() on it
+ result_value = dto.value
+ dto.value = None # Clear value to avoid serializing it in the model_dump
+ result_dict = dto.model_dump()
+ result_dict["value"] = result_value # Put back the value
+ return json.dumps(result_dict).encode("utf-8")
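open_stream gives the interop layer a single read/write entry point for local paths and stage paths (paths starting with "@" or "snow://"); stage writes are buffered by StageFileWriter and uploaded when the context exits cleanly. A small usage sketch (the session and both paths are hypothetical):

    # Sketch only: `session` is an existing snowpark.Session; both paths are placeholders.
    from snowflake.ml.jobs._interop.data_utils import open_stream

    with open_stream("@my_stage/results/result.json", mode="w", session=session) as f:
        f.write(b'{"success": true}')  # buffered in memory, uploaded to the stage on clean exit

    with open_stream("/tmp/result.json", mode="rb") as f:  # local paths fall back to builtin open()
        payload = f.read()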
snowflake/ml/jobs/_interop/dto_schema.py (new file)
@@ -0,0 +1,95 @@
+ from typing import Any, Optional, Union
+
+ from pydantic import BaseModel, model_validator
+ from typing_extensions import NotRequired, TypedDict
+
+
+ class BinaryManifest(TypedDict):
+ """
+ Binary data manifest schema.
+ Contains one of: path, bytes, or base64 for the serialized data.
+ """
+
+ path: NotRequired[str] # Path to file
+ bytes: NotRequired[bytes] # In-line byte string (not supported with JSON codec)
+ base64: NotRequired[str] # Base64 encoded string
+
+
+ class ParquetManifest(TypedDict):
+ """Protocol manifest schema for parquet files."""
+
+ paths: list[str] # File paths
+
+
+ # Union type for all manifest types, including catch-all dict[str, Any] for backward compatibility
+ PayloadManifest = Union[BinaryManifest, ParquetManifest, dict[str, Any]]
+
+
+ class ProtocolInfo(BaseModel):
+ """
+ The protocol used to serialize the result and the manifest of the result.
+ """
+
+ name: str
+ version: Optional[str] = None
+ metadata: Optional[dict[str, str]] = None
+ manifest: Optional[PayloadManifest] = None
+
+ def __str__(self) -> str:
+ result = self.name
+ if self.version:
+ result += f"-{self.version}"
+ return result
+
+ def with_manifest(self, manifest: PayloadManifest) -> "ProtocolInfo":
+ """
+ Return a new ProtocolInfo object with the manifest.
+ """
+ return ProtocolInfo(
+ name=self.name,
+ version=self.version,
+ metadata=self.metadata,
+ manifest=manifest,
+ )
+
+
+ class ResultMetadata(BaseModel):
+ """
+ The metadata of a result.
+ """
+
+ type: str
+ repr: str
+
+
+ class ExceptionMetadata(ResultMetadata):
+ message: str
+ traceback: str
+
+
+ class ResultDTO(BaseModel):
+ """
+ A JSON representation of an execution result.
+
+ Args:
+ success: Whether the execution was successful.
+ value: The value of the execution or the exception if the execution failed.
+ protocol: The protocol used to serialize the result.
+ metadata: The metadata of the result.
+ """
+
+ success: bool
+ value: Optional[Any] = None
+ protocol: Optional[ProtocolInfo] = None
+ metadata: Optional[Union[ResultMetadata, ExceptionMetadata]] = None
+ serialize_error: Optional[str] = None
+
+ @model_validator(mode="before")
+ @classmethod
+ def validate_fields(cls, data: Any) -> Any:
+ """Ensure at least one of value, protocol, or metadata keys is specified."""
+ if isinstance(data, dict):
+ required_fields = {"value", "protocol", "metadata"}
+ if not any(field in data for field in required_fields):
+ raise ValueError("At least one of 'value', 'protocol', or 'metadata' must be specified")
+ return data
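ResultDTO is the envelope the new interop modules exchange: the job side encodes one describing the execution outcome and the client decodes it with a DtoCodec. A round-trip sketch with the JSON codec (the values are illustrative, not output from a real job):

    # Sketch only: illustrative values.
    import io
    from snowflake.ml.jobs._interop.data_utils import JsonDtoCodec
    from snowflake.ml.jobs._interop.dto_schema import ProtocolInfo, ResultDTO, ResultMetadata

    dto = ResultDTO(
        success=True,
        value=42,
        protocol=ProtocolInfo(name="json", version="1"),
        metadata=ResultMetadata(type="int", repr="42"),
    )
    encoded = JsonDtoCodec.encode(dto)                   # bytes; `value` is re-attached after model_dump()
    decoded = JsonDtoCodec.decode(io.BytesIO(encoded))   # back to a validated ResultDTO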