snowflake-ml-python 1.15.0__py3-none-any.whl → 1.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (201) hide show
  1. snowflake/ml/_internal/human_readable_id/adjectives.txt +5 -5
  2. snowflake/ml/_internal/human_readable_id/animals.txt +3 -3
  3. snowflake/ml/_internal/platform_capabilities.py +4 -0
  4. snowflake/ml/_internal/utils/mixins.py +24 -9
  5. snowflake/ml/experiment/experiment_tracking.py +63 -19
  6. snowflake/ml/jobs/__init__.py +4 -0
  7. snowflake/ml/jobs/_interop/__init__.py +0 -0
  8. snowflake/ml/jobs/_interop/data_utils.py +124 -0
  9. snowflake/ml/jobs/_interop/dto_schema.py +95 -0
  10. snowflake/ml/jobs/{_utils/interop_utils.py → _interop/exception_utils.py} +49 -178
  11. snowflake/ml/jobs/_interop/legacy.py +225 -0
  12. snowflake/ml/jobs/_interop/protocols.py +471 -0
  13. snowflake/ml/jobs/_interop/results.py +51 -0
  14. snowflake/ml/jobs/_interop/utils.py +144 -0
  15. snowflake/ml/jobs/_utils/constants.py +4 -1
  16. snowflake/ml/jobs/_utils/feature_flags.py +37 -5
  17. snowflake/ml/jobs/_utils/payload_utils.py +1 -1
  18. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +139 -102
  19. snowflake/ml/jobs/_utils/spec_utils.py +50 -11
  20. snowflake/ml/jobs/_utils/types.py +10 -0
  21. snowflake/ml/jobs/job.py +168 -36
  22. snowflake/ml/jobs/manager.py +54 -36
  23. snowflake/ml/model/__init__.py +16 -2
  24. snowflake/ml/model/_client/model/batch_inference_specs.py +18 -2
  25. snowflake/ml/model/_client/model/model_version_impl.py +44 -7
  26. snowflake/ml/model/_client/ops/model_ops.py +4 -0
  27. snowflake/ml/model/_client/ops/service_ops.py +50 -5
  28. snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
  29. snowflake/ml/model/_client/sql/model_version.py +3 -1
  30. snowflake/ml/model/_client/sql/stage.py +8 -0
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  32. snowflake/ml/model/_model_composer/model_method/model_method.py +32 -4
  33. snowflake/ml/model/_model_composer/model_method/utils.py +28 -0
  34. snowflake/ml/model/_packager/model_env/model_env.py +48 -21
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +8 -0
  36. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
  37. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -3
  38. snowflake/ml/model/type_hints.py +13 -0
  39. snowflake/ml/model/volatility.py +34 -0
  40. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +5 -5
  41. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
  42. snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
  43. snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
  44. snowflake/ml/modeling/cluster/birch.py +1 -1
  45. snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
  46. snowflake/ml/modeling/cluster/dbscan.py +1 -1
  47. snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
  48. snowflake/ml/modeling/cluster/k_means.py +1 -1
  49. snowflake/ml/modeling/cluster/mean_shift.py +1 -1
  50. snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
  51. snowflake/ml/modeling/cluster/optics.py +1 -1
  52. snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
  53. snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
  54. snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
  55. snowflake/ml/modeling/compose/column_transformer.py +1 -1
  56. snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
  57. snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
  58. snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
  59. snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
  60. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
  61. snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
  62. snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
  63. snowflake/ml/modeling/covariance/oas.py +1 -1
  64. snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
  65. snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
  66. snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
  67. snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
  68. snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
  69. snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
  70. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
  71. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
  72. snowflake/ml/modeling/decomposition/pca.py +1 -1
  73. snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
  74. snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
  75. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
  76. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
  77. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
  78. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
  79. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
  80. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
  81. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
  82. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
  83. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
  84. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
  85. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
  87. snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
  88. snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
  89. snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
  90. snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
  91. snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
  92. snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
  93. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
  94. snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
  95. snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
  96. snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
  97. snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
  98. snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
  99. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
  100. snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
  101. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
  102. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
  103. snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
  104. snowflake/ml/modeling/impute/knn_imputer.py +1 -1
  105. snowflake/ml/modeling/impute/missing_indicator.py +1 -1
  106. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
  107. snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
  108. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
  109. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
  110. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
  111. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
  112. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
  113. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
  114. snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
  115. snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
  116. snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
  117. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  118. snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
  119. snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
  120. snowflake/ml/modeling/linear_model/lars.py +1 -1
  121. snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
  122. snowflake/ml/modeling/linear_model/lasso.py +1 -1
  123. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  124. snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
  125. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
  126. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
  127. snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
  128. snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
  129. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
  130. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
  131. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  132. snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
  133. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  134. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
  135. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
  136. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  137. snowflake/ml/modeling/linear_model/perceptron.py +1 -1
  138. snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
  139. snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
  140. snowflake/ml/modeling/linear_model/ridge.py +1 -1
  141. snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
  142. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
  143. snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
  144. snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
  145. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
  146. snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
  147. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
  148. snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
  149. snowflake/ml/modeling/manifold/isomap.py +1 -1
  150. snowflake/ml/modeling/manifold/mds.py +1 -1
  151. snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
  152. snowflake/ml/modeling/manifold/tsne.py +1 -1
  153. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
  154. snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
  155. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
  156. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
  157. snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
  158. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
  159. snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
  160. snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
  161. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
  162. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
  163. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
  164. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
  165. snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
  166. snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
  167. snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
  168. snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
  169. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
  170. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
  171. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  172. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
  173. snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
  174. snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
  175. snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
  176. snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
  177. snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
  178. snowflake/ml/modeling/svm/linear_svc.py +1 -1
  179. snowflake/ml/modeling/svm/linear_svr.py +1 -1
  180. snowflake/ml/modeling/svm/nu_svc.py +1 -1
  181. snowflake/ml/modeling/svm/nu_svr.py +1 -1
  182. snowflake/ml/modeling/svm/svc.py +1 -1
  183. snowflake/ml/modeling/svm/svr.py +1 -1
  184. snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
  185. snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
  186. snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
  187. snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
  188. snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
  189. snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
  190. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
  191. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
  192. snowflake/ml/registry/_manager/model_manager.py +1 -0
  193. snowflake/ml/registry/_manager/model_parameter_reconciler.py +27 -0
  194. snowflake/ml/registry/registry.py +15 -0
  195. snowflake/ml/utils/authentication.py +16 -0
  196. snowflake/ml/version.py +1 -1
  197. {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/METADATA +65 -5
  198. {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/RECORD +201 -192
  199. {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/WHEEL +0 -0
  200. {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/licenses/LICENSE.txt +0 -0
  201. {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,4 @@
1
+ aerial
1
2
  afraid
2
3
  ancient
3
4
  angry
@@ -26,7 +27,6 @@ dull
26
27
  empty
27
28
  evil
28
29
  fast
29
- fat
30
30
  fluffy
31
31
  foolish
32
32
  fresh
@@ -57,10 +57,10 @@ lovely
57
57
  lucky
58
58
  massive
59
59
  mean
60
+ metallic
60
61
  mighty
61
62
  modern
62
63
  moody
63
- nasty
64
64
  neat
65
65
  nervous
66
66
  new
@@ -85,7 +85,6 @@ rotten
85
85
  rude
86
86
  selfish
87
87
  serious
88
- shaggy
89
88
  sharp
90
89
  short
91
90
  shy
@@ -96,14 +95,15 @@ slippery
96
95
  smart
97
96
  smooth
98
97
  soft
98
+ solid
99
99
  sour
100
100
  spicy
101
101
  splendid
102
102
  spotty
103
+ squishy
103
104
  stale
104
105
  strange
105
106
  strong
106
- stupid
107
107
  sweet
108
108
  swift
109
109
  tall
@@ -116,7 +116,6 @@ tidy
116
116
  tiny
117
117
  tough
118
118
  tricky
119
- ugly
120
119
  warm
121
120
  weak
122
121
  wet
@@ -124,5 +123,6 @@ wicked
124
123
  wise
125
124
  witty
126
125
  wonderful
126
+ wooden
127
127
  yellow
128
128
  young
@@ -1,10 +1,9 @@
1
1
  anaconda
2
2
  ant
3
- ape
4
- baboon
5
3
  badger
6
4
  bat
7
5
  bear
6
+ beetle
8
7
  bird
9
8
  bobcat
10
9
  bulldog
@@ -73,7 +72,6 @@ lobster
73
72
  mayfly
74
73
  mamba
75
74
  mole
76
- monkey
77
75
  moose
78
76
  moth
79
77
  mouse
@@ -114,6 +112,7 @@ swan
114
112
  termite
115
113
  tiger
116
114
  treefrog
115
+ tuna
117
116
  turkey
118
117
  turtle
119
118
  vampirebat
@@ -126,3 +125,4 @@ worm
126
125
  yak
127
126
  yeti
128
127
  zebra
128
+ zebrafish
@@ -17,6 +17,7 @@ logger = logging.getLogger(__name__)
17
17
 
18
18
  LIVE_COMMIT_PARAMETER = "ENABLE_LIVE_VERSION_IN_SDK"
19
19
  INLINE_DEPLOYMENT_SPEC_PARAMETER = "ENABLE_INLINE_DEPLOYMENT_SPEC_FROM_CLIENT_VERSION"
20
+ SET_MODULE_FUNCTIONS_VOLATILITY_FROM_MANIFEST = "SET_MODULE_FUNCTIONS_VOLATILITY_FROM_MANIFEST"
20
21
 
21
22
 
22
23
  class PlatformCapabilities:
@@ -73,6 +74,9 @@ class PlatformCapabilities:
73
74
  def is_inlined_deployment_spec_enabled(self) -> bool:
74
75
  return self._is_version_feature_enabled(INLINE_DEPLOYMENT_SPEC_PARAMETER)
75
76
 
77
+ def is_set_module_functions_volatility_from_manifest(self) -> bool:
78
+ return self._get_bool_feature(SET_MODULE_FUNCTIONS_VOLATILITY_FROM_MANIFEST, False)
79
+
76
80
  def is_live_commit_enabled(self) -> bool:
77
81
  return self._get_bool_feature(LIVE_COMMIT_PARAMETER, False)
78
82
 
@@ -1,3 +1,4 @@
1
+ from dataclasses import dataclass
1
2
  from typing import Any, Optional
2
3
 
3
4
  from snowflake.ml._internal.utils import identifier
@@ -16,6 +17,14 @@ def _identifiers_match(saved: Optional[str], current: Optional[str]) -> bool:
16
17
  return saved_resolved == current_resolved
17
18
 
18
19
 
20
+ @dataclass(frozen=True)
21
+ class _SessionState:
22
+ account: Optional[str]
23
+ role: Optional[str]
24
+ database: Optional[str]
25
+ schema: Optional[str]
26
+
27
+
19
28
  class SerializableSessionMixin:
20
29
  """Mixin that provides pickling capabilities for objects with Snowpark sessions."""
21
30
 
@@ -40,17 +49,23 @@ class SerializableSessionMixin:
40
49
 
41
50
  def __setstate__(self, state: dict[str, Any]) -> None:
42
51
  """Restore session from context during unpickling."""
43
- saved_account = state.pop(_SESSION_ACCOUNT_KEY, None)
44
- saved_role = state.pop(_SESSION_ROLE_KEY, None)
45
- saved_database = state.pop(_SESSION_DATABASE_KEY, None)
46
- saved_schema = state.pop(_SESSION_SCHEMA_KEY, None)
52
+ session_state = _SessionState(
53
+ account=state.pop(_SESSION_ACCOUNT_KEY, None),
54
+ role=state.pop(_SESSION_ROLE_KEY, None),
55
+ database=state.pop(_SESSION_DATABASE_KEY, None),
56
+ schema=state.pop(_SESSION_SCHEMA_KEY, None),
57
+ )
47
58
 
48
59
  if hasattr(super(), "__setstate__"):
49
60
  super().__setstate__(state) # type: ignore[misc]
50
61
  else:
51
62
  self.__dict__.update(state)
52
63
 
53
- if saved_account is not None:
64
+ self._set_session(session_state)
65
+
66
+ def _set_session(self, session_state: _SessionState) -> None:
67
+
68
+ if session_state.account is not None:
54
69
  active_sessions = snowpark_session._get_active_sessions()
55
70
  if len(active_sessions) == 0:
56
71
  raise RuntimeError("No active Snowpark session available. Please create a session.")
@@ -63,10 +78,10 @@ class SerializableSessionMixin:
63
78
  active_sessions,
64
79
  key=lambda s: sum(
65
80
  (
66
- _identifiers_match(saved_account, s.get_current_account()),
67
- _identifiers_match(saved_role, s.get_current_role()),
68
- _identifiers_match(saved_database, s.get_current_database()),
69
- _identifiers_match(saved_schema, s.get_current_schema()),
81
+ _identifiers_match(session_state.account, s.get_current_account()),
82
+ _identifiers_match(session_state.role, s.get_current_role()),
83
+ _identifiers_match(session_state.database, s.get_current_database()),
84
+ _identifiers_match(session_state.schema, s.get_current_schema()),
70
85
  )
71
86
  ),
72
87
  ),
@@ -1,10 +1,10 @@
1
1
  import functools
2
2
  import json
3
3
  import sys
4
- from typing import Any, Optional, Union
4
+ from typing import Any, Callable, Concatenate, Optional, ParamSpec, TypeVar, Union
5
5
  from urllib.parse import quote
6
6
 
7
- import snowflake.snowpark._internal.utils as snowpark_utils
7
+ from snowflake import snowpark
8
8
  from snowflake.ml import model as ml_model, registry
9
9
  from snowflake.ml._internal.human_readable_id import hrid_generator
10
10
  from snowflake.ml._internal.utils import mixins, sql_identifier
@@ -18,20 +18,40 @@ from snowflake.ml.experiment._client import (
18
18
  )
19
19
  from snowflake.ml.model import type_hints
20
20
  from snowflake.ml.utils import sql_client as sql_client_utils
21
- from snowflake.snowpark import session
22
21
 
23
22
  DEFAULT_EXPERIMENT_NAME = sql_identifier.SqlIdentifier("DEFAULT")
24
23
 
24
+ P = ParamSpec("P")
25
+ T = TypeVar("T")
26
+
27
+
28
+ def _restore_session(
29
+ func: Callable[Concatenate["ExperimentTracking", P], T],
30
+ ) -> Callable[Concatenate["ExperimentTracking", P], T]:
31
+ @functools.wraps(func)
32
+ def wrapper(self: "ExperimentTracking", /, *args: P.args, **kwargs: P.kwargs) -> T:
33
+ if self._session is None:
34
+ if self._session_state is None:
35
+ raise RuntimeError(
36
+ f"Session is not set before calling {func.__name__}, and there is no session state to restore from"
37
+ )
38
+ self._set_session(self._session_state)
39
+ if self._session is None:
40
+ raise RuntimeError(f"Failed to restore session before calling {func.__name__}")
41
+ return func(self, *args, **kwargs)
42
+
43
+ return wrapper
44
+
25
45
 
26
46
  class ExperimentTracking(mixins.SerializableSessionMixin):
27
47
  """
28
48
  Class to manage experiments in Snowflake.
29
49
  """
30
50
 
31
- @snowpark_utils.private_preview(version="1.9.1")
51
+ @snowpark._internal.utils.private_preview(version="1.9.1")
32
52
  def __init__(
33
53
  self,
34
- session: session.Session,
54
+ session: snowpark.Session,
35
55
  *,
36
56
  database_name: Optional[str] = None,
37
57
  schema_name: Optional[str] = None,
@@ -73,7 +93,10 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
73
93
  database_name=self._database_name,
74
94
  schema_name=self._schema_name,
75
95
  )
76
- self._session = session
96
+ self._session: Optional[snowpark.Session] = session
97
+ # Used to store information about the session if the session could not be restored during unpickling
98
+ # _session_state is None if and only if _session is not None
99
+ self._session_state: Optional[mixins._SessionState] = None
77
100
 
78
101
  # The experiment in context
79
102
  self._experiment: Optional[entities.Experiment] = None
@@ -87,20 +110,29 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
87
110
  state["_registry"] = None
88
111
  return state
89
112
 
90
- def __setstate__(self, state: dict[str, Any]) -> None:
91
- super().__setstate__(state)
92
- # Restore unpicklable attributes
93
- self._sql_client = sql_client.ExperimentTrackingSQLClient(
94
- session=self._session,
95
- database_name=self._database_name,
96
- schema_name=self._schema_name,
97
- )
98
- self._registry = registry.Registry(
99
- session=self._session,
100
- database_name=self._database_name,
101
- schema_name=self._schema_name,
102
- )
113
+ def _set_session(self, session_state: mixins._SessionState) -> None:
114
+ try:
115
+ super()._set_session(session_state)
116
+ assert self._session is not None
117
+ except (snowpark.exceptions.SnowparkSessionException, AssertionError):
118
+ # If session was not set, store the session state
119
+ self._session = None
120
+ self._session_state = session_state
121
+ else:
122
+ # If session was set, clear the session state, and reinitialize the SQL client and registry
123
+ self._session_state = None
124
+ self._sql_client = sql_client.ExperimentTrackingSQLClient(
125
+ session=self._session,
126
+ database_name=self._database_name,
127
+ schema_name=self._schema_name,
128
+ )
129
+ self._registry = registry.Registry(
130
+ session=self._session,
131
+ database_name=self._database_name,
132
+ schema_name=self._schema_name,
133
+ )
103
134
 
135
+ @_restore_session
104
136
  def set_experiment(
105
137
  self,
106
138
  experiment_name: str,
@@ -125,6 +157,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
125
157
  self._run = None
126
158
  return self._experiment
127
159
 
160
+ @_restore_session
128
161
  def delete_experiment(
129
162
  self,
130
163
  experiment_name: str,
@@ -141,8 +174,10 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
141
174
  self._run = None
142
175
 
143
176
  @functools.wraps(registry.Registry.log_model)
177
+ @_restore_session
144
178
  def log_model(
145
179
  self,
180
+ /, # self needs to be a positional argument to stop mypy from complaining
146
181
  model: Union[type_hints.SupportedModelType, ml_model.ModelVersion],
147
182
  *,
148
183
  model_name: str,
@@ -152,6 +187,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
152
187
  with experiment_info.ExperimentInfoPatcher(experiment_info=run._get_experiment_info()):
153
188
  return self._registry.log_model(model, model_name=model_name, **kwargs)
154
189
 
190
+ @_restore_session
155
191
  def start_run(
156
192
  self,
157
193
  run_name: Optional[str] = None,
@@ -181,6 +217,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
181
217
  self._run = entities.Run(experiment_tracking=self, experiment_name=experiment.name, run_name=run_name)
182
218
  return self._run
183
219
 
220
+ @_restore_session
184
221
  def end_run(self, run_name: Optional[str] = None) -> None:
185
222
  """
186
223
  End the current run if no run name is provided. Otherwise, the specified run is ended.
@@ -210,6 +247,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
210
247
  self._run = None
211
248
  self._print_urls(experiment_name=experiment_name, run_name=run_name)
212
249
 
250
+ @_restore_session
213
251
  def delete_run(
214
252
  self,
215
253
  run_name: str,
@@ -248,6 +286,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
248
286
  """
249
287
  self.log_metrics(metrics={key: value}, step=step)
250
288
 
289
+ @_restore_session
251
290
  def log_metrics(
252
291
  self,
253
292
  metrics: dict[str, float],
@@ -284,6 +323,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
284
323
  """
285
324
  self.log_params({key: value})
286
325
 
326
+ @_restore_session
287
327
  def log_params(
288
328
  self,
289
329
  params: dict[str, Any],
@@ -305,6 +345,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
305
345
  params=json.dumps([param.to_dict() for param in params_list]),
306
346
  )
307
347
 
348
+ @_restore_session
308
349
  def log_artifact(
309
350
  self,
310
351
  local_path: str,
@@ -328,6 +369,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
328
369
  file_path=file_path,
329
370
  )
330
371
 
372
+ @_restore_session
331
373
  def list_artifacts(
332
374
  self,
333
375
  run_name: str,
@@ -356,6 +398,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
356
398
  artifact_path=artifact_path or "",
357
399
  )
358
400
 
401
+ @_restore_session
359
402
  def download_artifacts(
360
403
  self,
361
404
  run_name: str,
@@ -397,6 +440,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
397
440
  return self._run
398
441
  return self.start_run()
399
442
 
443
+ @_restore_session
400
444
  def _generate_run_name(self, experiment: entities.Experiment) -> sql_identifier.SqlIdentifier:
401
445
  generator = hrid_generator.HRID16()
402
446
  existing_runs = self._sql_client.show_runs_in_experiment(experiment_name=experiment.name)
@@ -1,3 +1,4 @@
1
+ from snowflake.ml.jobs._interop.exception_utils import install_exception_display_hooks
1
2
  from snowflake.ml.jobs._utils.types import JOB_STATUS
2
3
  from snowflake.ml.jobs.decorators import remote
3
4
  from snowflake.ml.jobs.job import MLJob
@@ -10,6 +11,9 @@ from snowflake.ml.jobs.manager import (
10
11
  submit_from_stage,
11
12
  )
12
13
 
14
+ # Initialize exception display hooks for remote job error handling
15
+ install_exception_display_hooks()
16
+
13
17
  __all__ = [
14
18
  "remote",
15
19
  "submit_file",
File without changes
@@ -0,0 +1,124 @@
1
+ import io
2
+ import json
3
+ from typing import Any, Literal, Optional, Protocol, Union, cast, overload
4
+
5
+ from snowflake import snowpark
6
+ from snowflake.ml.jobs._interop import dto_schema
7
+
8
+
9
+ class StageFileWriter(io.IOBase):
10
+ """
11
+ A context manager IOBase implementation that proxies writes to an internal BytesIO
12
+ and uploads to Snowflake stage on close.
13
+ """
14
+
15
+ def __init__(self, session: snowpark.Session, path: str) -> None:
16
+ self._session = session
17
+ self._path = path
18
+ self._buffer = io.BytesIO()
19
+ self._closed = False
20
+ self._exception_occurred = False
21
+
22
+ def write(self, data: Union[bytes, bytearray]) -> int:
23
+ """Write data to the internal buffer."""
24
+ if self._closed:
25
+ raise ValueError("I/O operation on closed file")
26
+ return self._buffer.write(data)
27
+
28
+ def close(self, write_contents: bool = True) -> None:
29
+ """Close the file and upload the buffer contents to the stage."""
30
+ if not self._closed:
31
+ # Only upload if buffer has content and no exception occurred
32
+ if write_contents and self._buffer.tell() > 0:
33
+ self._buffer.seek(0)
34
+ self._session.file.put_stream(self._buffer, self._path)
35
+ self._buffer.close()
36
+ self._closed = True
37
+
38
+ def __enter__(self) -> "StageFileWriter":
39
+ return self
40
+
41
+ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
42
+ exception_occurred = exc_type is not None
43
+ self.close(write_contents=not exception_occurred)
44
+
45
+ @property
46
+ def closed(self) -> bool:
47
+ return self._closed
48
+
49
+ def writable(self) -> bool:
50
+ return not self._closed
51
+
52
+ def readable(self) -> bool:
53
+ return False
54
+
55
+ def seekable(self) -> bool:
56
+ return not self._closed
57
+
58
+
59
+ def _is_stage_path(path: str) -> bool:
60
+ return path.startswith("@") or path.startswith("snow://")
61
+
62
+
63
+ def open_stream(path: str, mode: str = "rb", session: Optional[snowpark.Session] = None) -> io.IOBase:
64
+ if _is_stage_path(path):
65
+ if session is None:
66
+ raise ValueError("Session is required when opening a stage path")
67
+ if "r" in mode:
68
+ stream: io.IOBase = session.file.get_stream(path) # type: ignore[assignment]
69
+ return stream
70
+ elif "w" in mode:
71
+ return StageFileWriter(session, path)
72
+ else:
73
+ raise ValueError(f"Unsupported mode '{mode}' for stage path")
74
+ else:
75
+ result: io.IOBase = open(path, mode) # type: ignore[assignment]
76
+ return result
77
+
78
+
79
+ class DtoCodec(Protocol):
80
+ @overload
81
+ @staticmethod
82
+ def decode(stream: io.IOBase, as_dict: Literal[True]) -> dict[str, Any]:
83
+ ...
84
+
85
+ @overload
86
+ @staticmethod
87
+ def decode(stream: io.IOBase, as_dict: Literal[False] = False) -> dto_schema.ResultDTO:
88
+ ...
89
+
90
+ @staticmethod
91
+ def decode(stream: io.IOBase, as_dict: bool = False) -> Union[dto_schema.ResultDTO, dict[str, Any]]:
92
+ pass
93
+
94
+ @staticmethod
95
+ def encode(dto: dto_schema.ResultDTO) -> bytes:
96
+ pass
97
+
98
+
99
+ class JsonDtoCodec(DtoCodec):
100
+ @overload
101
+ @staticmethod
102
+ def decode(stream: io.IOBase, as_dict: Literal[True]) -> dict[str, Any]:
103
+ ...
104
+
105
+ @overload
106
+ @staticmethod
107
+ def decode(stream: io.IOBase, as_dict: Literal[False] = False) -> dto_schema.ResultDTO:
108
+ ...
109
+
110
+ @staticmethod
111
+ def decode(stream: io.IOBase, as_dict: bool = False) -> Union[dto_schema.ResultDTO, dict[str, Any]]:
112
+ data = cast(dict[str, Any], json.load(stream))
113
+ if as_dict:
114
+ return data
115
+ return dto_schema.ResultDTO.model_validate(data)
116
+
117
+ @staticmethod
118
+ def encode(dto: dto_schema.ResultDTO) -> bytes:
119
+ # Temporarily extract the value to avoid accidentally applying model_dump() on it
120
+ result_value = dto.value
121
+ dto.value = None # Clear value to avoid serializing it in the model_dump
122
+ result_dict = dto.model_dump()
123
+ result_dict["value"] = result_value # Put back the value
124
+ return json.dumps(result_dict).encode("utf-8")
@@ -0,0 +1,95 @@
1
+ from typing import Any, Optional, Union
2
+
3
+ from pydantic import BaseModel, model_validator
4
+ from typing_extensions import NotRequired, TypedDict
5
+
6
+
7
+ class BinaryManifest(TypedDict):
8
+ """
9
+ Binary data manifest schema.
10
+ Contains one of: path, bytes, or base64 for the serialized data.
11
+ """
12
+
13
+ path: NotRequired[str] # Path to file
14
+ bytes: NotRequired[bytes] # In-line byte string (not supported with JSON codec)
15
+ base64: NotRequired[str] # Base64 encoded string
16
+
17
+
18
+ class ParquetManifest(TypedDict):
19
+ """Protocol manifest schema for parquet files."""
20
+
21
+ paths: list[str] # File paths
22
+
23
+
24
+ # Union type for all manifest types, including catch-all dict[str, Any] for backward compatibility
25
+ PayloadManifest = Union[BinaryManifest, ParquetManifest, dict[str, Any]]
26
+
27
+
28
+ class ProtocolInfo(BaseModel):
29
+ """
30
+ The protocol used to serialize the result and the manifest of the result.
31
+ """
32
+
33
+ name: str
34
+ version: Optional[str] = None
35
+ metadata: Optional[dict[str, str]] = None
36
+ manifest: Optional[PayloadManifest] = None
37
+
38
+ def __str__(self) -> str:
39
+ result = self.name
40
+ if self.version:
41
+ result += f"-{self.version}"
42
+ return result
43
+
44
+ def with_manifest(self, manifest: PayloadManifest) -> "ProtocolInfo":
45
+ """
46
+ Return a new ProtocolInfo object with the manifest.
47
+ """
48
+ return ProtocolInfo(
49
+ name=self.name,
50
+ version=self.version,
51
+ metadata=self.metadata,
52
+ manifest=manifest,
53
+ )
54
+
55
+
56
+ class ResultMetadata(BaseModel):
57
+ """
58
+ The metadata of a result.
59
+ """
60
+
61
+ type: str
62
+ repr: str
63
+
64
+
65
+ class ExceptionMetadata(ResultMetadata):
66
+ message: str
67
+ traceback: str
68
+
69
+
70
+ class ResultDTO(BaseModel):
71
+ """
72
+ A JSON representation of an execution result.
73
+
74
+ Args:
75
+ success: Whether the execution was successful.
76
+ value: The value of the execution or the exception if the execution failed.
77
+ protocol: The protocol used to serialize the result.
78
+ metadata: The metadata of the result.
79
+ """
80
+
81
+ success: bool
82
+ value: Optional[Any] = None
83
+ protocol: Optional[ProtocolInfo] = None
84
+ metadata: Optional[Union[ResultMetadata, ExceptionMetadata]] = None
85
+ serialize_error: Optional[str] = None
86
+
87
+ @model_validator(mode="before")
88
+ @classmethod
89
+ def validate_fields(cls, data: Any) -> Any:
90
+ """Ensure at least one of value, protocol, or metadata keys is specified."""
91
+ if isinstance(data, dict):
92
+ required_fields = {"value", "protocol", "metadata"}
93
+ if not any(field in data for field in required_fields):
94
+ raise ValueError("At least one of 'value', 'protocol', or 'metadata' must be specified")
95
+ return data