snowflake-ml-python 1.14.0__py3-none-any.whl → 1.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (193)
  1. snowflake/ml/_internal/platform_capabilities.py +13 -7
  2. snowflake/ml/_internal/utils/connection_params.py +5 -3
  3. snowflake/ml/_internal/utils/jwt_generator.py +3 -2
  4. snowflake/ml/_internal/utils/mixins.py +24 -9
  5. snowflake/ml/_internal/utils/temp_file_utils.py +1 -2
  6. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +16 -3
  7. snowflake/ml/experiment/_entities/__init__.py +2 -1
  8. snowflake/ml/experiment/_entities/run.py +0 -15
  9. snowflake/ml/experiment/_entities/run_metadata.py +3 -51
  10. snowflake/ml/experiment/experiment_tracking.py +71 -27
  11. snowflake/ml/jobs/_utils/spec_utils.py +49 -11
  12. snowflake/ml/jobs/manager.py +20 -0
  13. snowflake/ml/model/__init__.py +12 -2
  14. snowflake/ml/model/_client/model/batch_inference_specs.py +16 -4
  15. snowflake/ml/model/_client/model/inference_engine_utils.py +55 -0
  16. snowflake/ml/model/_client/model/model_version_impl.py +30 -62
  17. snowflake/ml/model/_client/ops/service_ops.py +68 -7
  18. snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
  19. snowflake/ml/model/_client/sql/service.py +29 -2
  20. snowflake/ml/model/_client/sql/stage.py +8 -0
  21. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  22. snowflake/ml/model/_model_composer/model_method/model_method.py +25 -2
  23. snowflake/ml/model/_packager/model_env/model_env.py +26 -16
  24. snowflake/ml/model/_packager/model_handlers/_utils.py +4 -2
  25. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -5
  26. snowflake/ml/model/_packager/model_packager.py +4 -3
  27. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
  28. snowflake/ml/model/_signatures/utils.py +0 -21
  29. snowflake/ml/model/models/huggingface_pipeline.py +56 -21
  30. snowflake/ml/model/type_hints.py +13 -0
  31. snowflake/ml/model/volatility.py +34 -0
  32. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
  33. snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
  34. snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
  35. snowflake/ml/modeling/cluster/birch.py +1 -1
  36. snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
  37. snowflake/ml/modeling/cluster/dbscan.py +1 -1
  38. snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
  39. snowflake/ml/modeling/cluster/k_means.py +1 -1
  40. snowflake/ml/modeling/cluster/mean_shift.py +1 -1
  41. snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
  42. snowflake/ml/modeling/cluster/optics.py +1 -1
  43. snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
  44. snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
  45. snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
  46. snowflake/ml/modeling/compose/column_transformer.py +1 -1
  47. snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
  48. snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
  49. snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
  50. snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
  51. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
  52. snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
  53. snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
  54. snowflake/ml/modeling/covariance/oas.py +1 -1
  55. snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
  56. snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
  57. snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
  58. snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
  59. snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
  60. snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
  61. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
  62. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
  63. snowflake/ml/modeling/decomposition/pca.py +1 -1
  64. snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
  65. snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
  66. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
  67. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
  68. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
  69. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
  70. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
  71. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
  72. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
  73. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
  74. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
  75. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
  76. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
  77. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
  78. snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
  79. snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
  80. snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
  81. snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
  82. snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
  83. snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
  84. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
  85. snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
  86. snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
  87. snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
  88. snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
  89. snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
  90. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
  91. snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
  92. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
  93. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
  94. snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
  95. snowflake/ml/modeling/impute/knn_imputer.py +1 -1
  96. snowflake/ml/modeling/impute/missing_indicator.py +1 -1
  97. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
  98. snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
  99. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
  100. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
  101. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
  102. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
  103. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
  104. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
  105. snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
  106. snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
  107. snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
  108. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  109. snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
  110. snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
  111. snowflake/ml/modeling/linear_model/lars.py +1 -1
  112. snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
  113. snowflake/ml/modeling/linear_model/lasso.py +1 -1
  114. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  115. snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
  116. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
  117. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
  118. snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
  119. snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
  120. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
  121. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
  122. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  123. snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
  124. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  125. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
  126. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
  127. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  128. snowflake/ml/modeling/linear_model/perceptron.py +1 -1
  129. snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
  130. snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
  131. snowflake/ml/modeling/linear_model/ridge.py +1 -1
  132. snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
  133. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
  134. snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
  135. snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
  136. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
  137. snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
  138. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
  139. snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
  140. snowflake/ml/modeling/manifold/isomap.py +1 -1
  141. snowflake/ml/modeling/manifold/mds.py +1 -1
  142. snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
  143. snowflake/ml/modeling/manifold/tsne.py +1 -1
  144. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
  145. snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
  146. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
  147. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
  148. snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
  149. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
  150. snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
  151. snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
  152. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
  153. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
  154. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
  155. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
  156. snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
  157. snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
  158. snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
  159. snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
  160. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
  161. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
  162. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  163. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
  164. snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
  165. snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
  166. snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
  167. snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
  168. snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
  169. snowflake/ml/modeling/svm/linear_svc.py +1 -1
  170. snowflake/ml/modeling/svm/linear_svr.py +1 -1
  171. snowflake/ml/modeling/svm/nu_svc.py +1 -1
  172. snowflake/ml/modeling/svm/nu_svr.py +1 -1
  173. snowflake/ml/modeling/svm/svc.py +1 -1
  174. snowflake/ml/modeling/svm/svr.py +1 -1
  175. snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
  176. snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
  177. snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
  178. snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
  179. snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
  180. snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
  181. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
  182. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
  183. snowflake/ml/registry/_manager/model_manager.py +2 -1
  184. snowflake/ml/registry/_manager/model_parameter_reconciler.py +29 -2
  185. snowflake/ml/registry/registry.py +15 -0
  186. snowflake/ml/utils/authentication.py +16 -0
  187. snowflake/ml/utils/connection_params.py +5 -3
  188. snowflake/ml/version.py +1 -1
  189. {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/METADATA +81 -36
  190. {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/RECORD +193 -191
  191. {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/WHEEL +0 -0
  192. {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/licenses/LICENSE.txt +0 -0
  193. {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,8 @@
1
1
  import json
2
+ import logging
2
3
  from contextlib import contextmanager
3
4
  from typing import Any, Optional
4
5
 
5
- from absl import logging
6
6
  from packaging import version
7
7
 
8
8
  from snowflake.ml import version as snowml_version
@@ -13,8 +13,11 @@ from snowflake.snowpark import (
13
13
  session as snowpark_session,
14
14
  )
15
15
 
16
+ logger = logging.getLogger(__name__)
17
+
16
18
  LIVE_COMMIT_PARAMETER = "ENABLE_LIVE_VERSION_IN_SDK"
17
19
  INLINE_DEPLOYMENT_SPEC_PARAMETER = "ENABLE_INLINE_DEPLOYMENT_SPEC_FROM_CLIENT_VERSION"
20
+ SET_MODULE_FUNCTIONS_VOLATILITY_FROM_MANIFEST = "SET_MODULE_FUNCTIONS_VOLATILITY_FROM_MANIFEST"
18
21
 
19
22
 
20
23
  class PlatformCapabilities:
@@ -60,17 +63,20 @@ class PlatformCapabilities:
60
63
  @classmethod # type: ignore[arg-type]
61
64
  @contextmanager
62
65
  def mock_features(cls, features: dict[str, Any] = _dummy_features) -> None: # type: ignore[misc]
63
- logging.debug(f"Setting mock features: {features}")
66
+ logger.debug(f"Setting mock features: {features}")
64
67
  cls.set_mock_features(features)
65
68
  try:
66
69
  yield
67
70
  finally:
68
- logging.debug(f"Clearing mock features: {features}")
71
+ logger.debug(f"Clearing mock features: {features}")
69
72
  cls.clear_mock_features()
70
73
 
71
74
  def is_inlined_deployment_spec_enabled(self) -> bool:
72
75
  return self._is_version_feature_enabled(INLINE_DEPLOYMENT_SPEC_PARAMETER)
73
76
 
77
+ def is_set_module_functions_volatility_from_manifest(self) -> bool:
78
+ return self._get_bool_feature(SET_MODULE_FUNCTIONS_VOLATILITY_FROM_MANIFEST, False)
79
+
74
80
  def is_live_commit_enabled(self) -> bool:
75
81
  return self._get_bool_feature(LIVE_COMMIT_PARAMETER, False)
76
82
 
@@ -98,7 +104,7 @@ class PlatformCapabilities:
98
104
  error_code=error_codes.INTERNAL_SNOWML_ERROR, original_exception=RuntimeError(message)
99
105
  )
100
106
  except snowpark_exceptions.SnowparkSQLException as e:
101
- logging.debug(f"Failed to retrieve platform capabilities: {e}")
107
+ logger.debug(f"Failed to retrieve platform capabilities: {e}")
102
108
  # This can happen is server side is older than 9.2. That is fine.
103
109
  return {}
104
110
 
@@ -144,7 +150,7 @@ class PlatformCapabilities:
144
150
 
145
151
  value = self.features.get(feature_name)
146
152
  if value is None:
147
- logging.debug(f"Feature {feature_name} not found, returning large version number")
153
+ logger.debug(f"Feature {feature_name} not found, returning large version number")
148
154
  return large_version
149
155
 
150
156
  try:
@@ -152,7 +158,7 @@ class PlatformCapabilities:
152
158
  version_str = str(value)
153
159
  return version.Version(version_str)
154
160
  except (version.InvalidVersion, ValueError, TypeError) as e:
155
- logging.debug(
161
+ logger.debug(
156
162
  f"Failed to parse version from feature {feature_name} with value '{value}': {e}. "
157
163
  f"Returning large version number"
158
164
  )
@@ -171,7 +177,7 @@ class PlatformCapabilities:
171
177
  feature_version = self._get_version_feature(feature_name)
172
178
 
173
179
  result = current_version >= feature_version
174
- logging.debug(
180
+ logger.debug(
175
181
  f"Version comparison for feature {feature_name}: "
176
182
  f"current={current_version}, feature={feature_version}, enabled={result}"
177
183
  )
@@ -1,11 +1,13 @@
1
1
  import configparser
2
+ import logging
2
3
  import os
3
4
  from typing import Optional, Union
4
5
 
5
- from absl import logging
6
6
  from cryptography.hazmat import backends
7
7
  from cryptography.hazmat.primitives import serialization
8
8
 
9
+ logger = logging.getLogger(__name__)
10
+
9
11
  _DEFAULT_CONNECTION_FILE = "~/.snowsql/config"
10
12
 
11
13
 
@@ -106,7 +108,7 @@ def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -
106
108
  """Loads the dictionary from snowsql config file."""
107
109
  snowsql_config_file = login_file if login_file else os.path.expanduser(_DEFAULT_CONNECTION_FILE)
108
110
  if not os.path.exists(snowsql_config_file):
109
- logging.error(f"Connection name given but snowsql config file is not found at: {snowsql_config_file}")
111
+ logger.error(f"Connection name given but snowsql config file is not found at: {snowsql_config_file}")
110
112
  raise Exception("Snowflake SnowSQL config not found.")
111
113
 
112
114
  config = configparser.ConfigParser(inline_comment_prefixes="#")
@@ -122,7 +124,7 @@ def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -
122
124
  # See https://docs.snowflake.com/en/user-guide/snowsql-start.html#configuring-default-connection-settings
123
125
  connection_name = "connections"
124
126
 
125
- logging.info(f"Reading {snowsql_config_file} for connection parameters defined as {connection_name}")
127
+ logger.info(f"Reading {snowsql_config_file} for connection parameters defined as {connection_name}")
126
128
  config.read(snowsql_config_file)
127
129
  conn_params = dict(config[connection_name])
128
130
  # Remap names to appropriate args in Python Connector API
@@ -110,15 +110,16 @@ class JWTGenerator:
110
110
  }
111
111
 
112
112
  # Regenerate the actual token
113
- token = jwt.encode(payload, key=self.private_key, algorithm=JWTGenerator.ALGORITHM)
113
+ token = jwt.encode(payload, key=self.private_key, algorithm=JWTGenerator.ALGORITHM) # type: ignore[arg-type]
114
114
  # If you are using a version of PyJWT prior to 2.0, jwt.encode returns a byte string instead of a string.
115
115
  # If the token is a byte string, convert it to a string.
116
116
  if isinstance(token, bytes):
117
117
  token = token.decode("utf-8")
118
118
  self.token = token
119
+ public_key = self.private_key.public_key()
119
120
  logger.info(
120
121
  "Generated a JWT with the following payload: %s",
121
- jwt.decode(self.token, key=self.private_key.public_key(), algorithms=[JWTGenerator.ALGORITHM]),
122
+ jwt.decode(self.token, key=public_key, algorithms=[JWTGenerator.ALGORITHM]), # type: ignore[arg-type]
122
123
  )
123
124
 
124
125
  return token
@@ -1,3 +1,4 @@
1
+ from dataclasses import dataclass
1
2
  from typing import Any, Optional
2
3
 
3
4
  from snowflake.ml._internal.utils import identifier
@@ -16,6 +17,14 @@ def _identifiers_match(saved: Optional[str], current: Optional[str]) -> bool:
16
17
  return saved_resolved == current_resolved
17
18
 
18
19
 
20
+ @dataclass(frozen=True)
21
+ class _SessionState:
22
+ account: Optional[str]
23
+ role: Optional[str]
24
+ database: Optional[str]
25
+ schema: Optional[str]
26
+
27
+
19
28
  class SerializableSessionMixin:
20
29
  """Mixin that provides pickling capabilities for objects with Snowpark sessions."""
21
30
 
@@ -40,17 +49,23 @@ class SerializableSessionMixin:
40
49
 
41
50
  def __setstate__(self, state: dict[str, Any]) -> None:
42
51
  """Restore session from context during unpickling."""
43
- saved_account = state.pop(_SESSION_ACCOUNT_KEY, None)
44
- saved_role = state.pop(_SESSION_ROLE_KEY, None)
45
- saved_database = state.pop(_SESSION_DATABASE_KEY, None)
46
- saved_schema = state.pop(_SESSION_SCHEMA_KEY, None)
52
+ session_state = _SessionState(
53
+ account=state.pop(_SESSION_ACCOUNT_KEY, None),
54
+ role=state.pop(_SESSION_ROLE_KEY, None),
55
+ database=state.pop(_SESSION_DATABASE_KEY, None),
56
+ schema=state.pop(_SESSION_SCHEMA_KEY, None),
57
+ )
47
58
 
48
59
  if hasattr(super(), "__setstate__"):
49
60
  super().__setstate__(state) # type: ignore[misc]
50
61
  else:
51
62
  self.__dict__.update(state)
52
63
 
53
- if saved_account is not None:
64
+ self._set_session(session_state)
65
+
66
+ def _set_session(self, session_state: _SessionState) -> None:
67
+
68
+ if session_state.account is not None:
54
69
  active_sessions = snowpark_session._get_active_sessions()
55
70
  if len(active_sessions) == 0:
56
71
  raise RuntimeError("No active Snowpark session available. Please create a session.")
@@ -63,10 +78,10 @@ class SerializableSessionMixin:
63
78
  active_sessions,
64
79
  key=lambda s: sum(
65
80
  (
66
- _identifiers_match(saved_account, s.get_current_account()),
67
- _identifiers_match(saved_role, s.get_current_role()),
68
- _identifiers_match(saved_database, s.get_current_database()),
69
- _identifiers_match(saved_schema, s.get_current_schema()),
81
+ _identifiers_match(session_state.account, s.get_current_account()),
82
+ _identifiers_match(session_state.role, s.get_current_role()),
83
+ _identifiers_match(session_state.database, s.get_current_database()),
84
+ _identifiers_match(session_state.schema, s.get_current_schema()),
70
85
  )
71
86
  ),
72
87
  ),
@@ -1,10 +1,9 @@
1
+ import logging
1
2
  import os
2
3
  import shutil
3
4
  import tempfile
4
5
  from typing import Iterable, Union
5
6
 
6
- from absl.logging import logging
7
-
8
7
  logger = logging.getLogger(__name__)
9
8
 
10
9
 
@@ -76,17 +76,30 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
76
76
  self._session, f"ALTER EXPERIMENT {experiment_fqn} DROP RUN {run_name}"
77
77
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
78
78
 
79
- def modify_run(
79
+ def modify_run_add_metrics(
80
80
  self,
81
81
  *,
82
82
  experiment_name: sql_identifier.SqlIdentifier,
83
83
  run_name: sql_identifier.SqlIdentifier,
84
- run_metadata: str,
84
+ metrics: str,
85
85
  ) -> None:
86
86
  experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
87
87
  query_result_checker.SqlResultValidator(
88
88
  self._session,
89
- f"ALTER EXPERIMENT {experiment_fqn} MODIFY RUN {run_name} SET METADATA=$${run_metadata}$$",
89
+ f"ALTER EXPERIMENT {experiment_fqn} MODIFY RUN {run_name} ADD METRICS=$${metrics}$$",
90
+ ).has_dimensions(expected_rows=1, expected_cols=1).validate()
91
+
92
+ def modify_run_add_params(
93
+ self,
94
+ *,
95
+ experiment_name: sql_identifier.SqlIdentifier,
96
+ run_name: sql_identifier.SqlIdentifier,
97
+ params: str,
98
+ ) -> None:
99
+ experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
100
+ query_result_checker.SqlResultValidator(
101
+ self._session,
102
+ f"ALTER EXPERIMENT {experiment_fqn} MODIFY RUN {run_name} ADD PARAMETERS=$${params}$$",
90
103
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
91
104
 
92
105
  def put_artifact(
@@ -1,4 +1,5 @@
1
1
  from snowflake.ml.experiment._entities.experiment import Experiment
2
2
  from snowflake.ml.experiment._entities.run import Run
3
+ from snowflake.ml.experiment._entities.run_metadata import Metric, Param
3
4
 
4
- __all__ = ["Experiment", "Run"]
5
+ __all__ = ["Experiment", "Run", "Metric", "Param"]
@@ -1,11 +1,8 @@
1
- import json
2
1
  import types
3
2
  from typing import TYPE_CHECKING, Optional
4
3
 
5
4
  from snowflake.ml._internal.utils import sql_identifier
6
5
  from snowflake.ml.experiment import _experiment_info as experiment_info
7
- from snowflake.ml.experiment._client import experiment_tracking_sql_client
8
- from snowflake.ml.experiment._entities import run_metadata
9
6
 
10
7
  if TYPE_CHECKING:
11
8
  from snowflake.ml.experiment import experiment_tracking
@@ -41,18 +38,6 @@ class Run:
41
38
  if self._experiment_tracking._run is self:
42
39
  self._experiment_tracking.end_run()
43
40
 
44
- def _get_metadata(
45
- self,
46
- ) -> run_metadata.RunMetadata:
47
- runs = self._experiment_tracking._sql_client.show_runs_in_experiment(
48
- experiment_name=self.experiment_name, like=str(self.name)
49
- )
50
- if not runs:
51
- raise RuntimeError(f"Run {self.name} not found in experiment {self.experiment_name}.")
52
- return run_metadata.RunMetadata.from_dict(
53
- json.loads(runs[0][experiment_tracking_sql_client.ExperimentTrackingSQLClient.RUN_METADATA_COL_NAME])
54
- )
55
-
56
41
  def _get_experiment_info(self) -> experiment_info.ExperimentInfo:
57
42
  return experiment_info.ExperimentInfo(
58
43
  fully_qualified_name=self._experiment_tracking._sql_client.fully_qualified_object_name(
@@ -1,12 +1,4 @@
1
1
  import dataclasses
2
- import enum
3
- import typing
4
-
5
-
6
- class RunStatus(str, enum.Enum):
7
- UNKNOWN = "UNKNOWN"
8
- RUNNING = "RUNNING"
9
- FINISHED = "FINISHED"
10
2
 
11
3
 
12
4
  @dataclasses.dataclass
@@ -15,54 +7,14 @@ class Metric:
15
7
  value: float
16
8
  step: int
17
9
 
10
+ def to_dict(self) -> dict: # type: ignore[type-arg]
11
+ return dataclasses.asdict(self)
12
+
18
13
 
19
14
  @dataclasses.dataclass
20
15
  class Param:
21
16
  name: str
22
17
  value: str
23
18
 
24
-
25
- @dataclasses.dataclass
26
- class RunMetadata:
27
- status: RunStatus
28
- metrics: list[Metric]
29
- parameters: list[Param]
30
-
31
- @classmethod
32
- def from_dict(
33
- cls,
34
- metadata: dict, # type: ignore[type-arg]
35
- ) -> "RunMetadata":
36
- return RunMetadata(
37
- status=RunStatus(metadata.get("status", RunStatus.UNKNOWN.value)),
38
- metrics=[Metric(**m) for m in metadata.get("metrics", [])],
39
- parameters=[Param(**p) for p in metadata.get("parameters", [])],
40
- )
41
-
42
19
  def to_dict(self) -> dict: # type: ignore[type-arg]
43
20
  return dataclasses.asdict(self)
44
-
45
- def set_metric(
46
- self,
47
- key: str,
48
- value: float,
49
- step: int,
50
- ) -> None:
51
- for metric in self.metrics:
52
- if metric.name == key and metric.step == step:
53
- metric.value = value
54
- break
55
- else:
56
- self.metrics.append(Metric(name=key, value=value, step=step))
57
-
58
- def set_param(
59
- self,
60
- key: str,
61
- value: typing.Any,
62
- ) -> None:
63
- for parameter in self.parameters:
64
- if parameter.name == key:
65
- parameter.value = str(value)
66
- break
67
- else:
68
- self.parameters.append(Param(name=key, value=str(value)))
@@ -1,10 +1,10 @@
1
1
  import functools
2
2
  import json
3
3
  import sys
4
- from typing import Any, Optional, Union
4
+ from typing import Any, Callable, Concatenate, Optional, ParamSpec, TypeVar, Union
5
5
  from urllib.parse import quote
6
6
 
7
- import snowflake.snowpark._internal.utils as snowpark_utils
7
+ from snowflake import snowpark
8
8
  from snowflake.ml import model as ml_model, registry
9
9
  from snowflake.ml._internal.human_readable_id import hrid_generator
10
10
  from snowflake.ml._internal.utils import mixins, sql_identifier
@@ -18,20 +18,40 @@ from snowflake.ml.experiment._client import (
18
18
  )
19
19
  from snowflake.ml.model import type_hints
20
20
  from snowflake.ml.utils import sql_client as sql_client_utils
21
- from snowflake.snowpark import session
22
21
 
23
22
  DEFAULT_EXPERIMENT_NAME = sql_identifier.SqlIdentifier("DEFAULT")
24
23
 
24
+ P = ParamSpec("P")
25
+ T = TypeVar("T")
26
+
27
+
28
+ def _restore_session(
29
+ func: Callable[Concatenate["ExperimentTracking", P], T],
30
+ ) -> Callable[Concatenate["ExperimentTracking", P], T]:
31
+ @functools.wraps(func)
32
+ def wrapper(self: "ExperimentTracking", /, *args: P.args, **kwargs: P.kwargs) -> T:
33
+ if self._session is None:
34
+ if self._session_state is None:
35
+ raise RuntimeError(
36
+ f"Session is not set before calling {func.__name__}, and there is no session state to restore from"
37
+ )
38
+ self._set_session(self._session_state)
39
+ if self._session is None:
40
+ raise RuntimeError(f"Failed to restore session before calling {func.__name__}")
41
+ return func(self, *args, **kwargs)
42
+
43
+ return wrapper
44
+
25
45
 
26
46
  class ExperimentTracking(mixins.SerializableSessionMixin):
27
47
  """
28
48
  Class to manage experiments in Snowflake.
29
49
  """
30
50
 
31
- @snowpark_utils.private_preview(version="1.9.1")
51
+ @snowpark._internal.utils.private_preview(version="1.9.1")
32
52
  def __init__(
33
53
  self,
34
- session: session.Session,
54
+ session: snowpark.Session,
35
55
  *,
36
56
  database_name: Optional[str] = None,
37
57
  schema_name: Optional[str] = None,
@@ -73,7 +93,10 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
73
93
  database_name=self._database_name,
74
94
  schema_name=self._schema_name,
75
95
  )
76
- self._session = session
96
+ self._session: Optional[snowpark.Session] = session
97
+ # Used to store information about the session if the session could not be restored during unpickling
98
+ # _session_state is None if and only if _session is not None
99
+ self._session_state: Optional[mixins._SessionState] = None
77
100
 
78
101
  # The experiment in context
79
102
  self._experiment: Optional[entities.Experiment] = None
@@ -87,20 +110,29 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
87
110
  state["_registry"] = None
88
111
  return state
89
112
 
90
- def __setstate__(self, state: dict[str, Any]) -> None:
91
- super().__setstate__(state)
92
- # Restore unpicklable attributes
93
- self._sql_client = sql_client.ExperimentTrackingSQLClient(
94
- session=self._session,
95
- database_name=self._database_name,
96
- schema_name=self._schema_name,
97
- )
98
- self._registry = registry.Registry(
99
- session=self._session,
100
- database_name=self._database_name,
101
- schema_name=self._schema_name,
102
- )
113
+ def _set_session(self, session_state: mixins._SessionState) -> None:
114
+ try:
115
+ super()._set_session(session_state)
116
+ assert self._session is not None
117
+ except (snowpark.exceptions.SnowparkSessionException, AssertionError):
118
+ # If session was not set, store the session state
119
+ self._session = None
120
+ self._session_state = session_state
121
+ else:
122
+ # If session was set, clear the session state, and reinitialize the SQL client and registry
123
+ self._session_state = None
124
+ self._sql_client = sql_client.ExperimentTrackingSQLClient(
125
+ session=self._session,
126
+ database_name=self._database_name,
127
+ schema_name=self._schema_name,
128
+ )
129
+ self._registry = registry.Registry(
130
+ session=self._session,
131
+ database_name=self._database_name,
132
+ schema_name=self._schema_name,
133
+ )
103
134
 
135
+ @_restore_session
104
136
  def set_experiment(
105
137
  self,
106
138
  experiment_name: str,
@@ -125,6 +157,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
125
157
  self._run = None
126
158
  return self._experiment
127
159
 
160
+ @_restore_session
128
161
  def delete_experiment(
129
162
  self,
130
163
  experiment_name: str,
@@ -141,8 +174,10 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
141
174
  self._run = None
142
175
 
143
176
  @functools.wraps(registry.Registry.log_model)
177
+ @_restore_session
144
178
  def log_model(
145
179
  self,
180
+ /, # self needs to be a positional argument to stop mypy from complaining
146
181
  model: Union[type_hints.SupportedModelType, ml_model.ModelVersion],
147
182
  *,
148
183
  model_name: str,
@@ -152,6 +187,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
152
187
  with experiment_info.ExperimentInfoPatcher(experiment_info=run._get_experiment_info()):
153
188
  return self._registry.log_model(model, model_name=model_name, **kwargs)
154
189
 
190
+ @_restore_session
155
191
  def start_run(
156
192
  self,
157
193
  run_name: Optional[str] = None,
@@ -181,6 +217,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
181
217
  self._run = entities.Run(experiment_tracking=self, experiment_name=experiment.name, run_name=run_name)
182
218
  return self._run
183
219
 
220
+ @_restore_session
184
221
  def end_run(self, run_name: Optional[str] = None) -> None:
185
222
  """
186
223
  End the current run if no run name is provided. Otherwise, the specified run is ended.
@@ -210,6 +247,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
210
247
  self._run = None
211
248
  self._print_urls(experiment_name=experiment_name, run_name=run_name)
212
249
 
250
+ @_restore_session
213
251
  def delete_run(
214
252
  self,
215
253
  run_name: str,
@@ -248,6 +286,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
248
286
  """
249
287
  self.log_metrics(metrics={key: value}, step=step)
250
288
 
289
+ @_restore_session
251
290
  def log_metrics(
252
291
  self,
253
292
  metrics: dict[str, float],
@@ -261,13 +300,13 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
261
300
  step: The step of the metrics. Defaults to 0.
262
301
  """
263
302
  run = self._get_or_start_run()
264
- metadata = run._get_metadata()
303
+ metrics_list = []
265
304
  for key, value in metrics.items():
266
- metadata.set_metric(key, value, step)
267
- self._sql_client.modify_run(
305
+ metrics_list.append(entities.Metric(key, value, step))
306
+ self._sql_client.modify_run_add_metrics(
268
307
  experiment_name=run.experiment_name,
269
308
  run_name=run.name,
270
- run_metadata=json.dumps(metadata.to_dict()),
309
+ metrics=json.dumps([metric.to_dict() for metric in metrics_list]),
271
310
  )
272
311
 
273
312
  def log_param(
@@ -284,6 +323,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
284
323
  """
285
324
  self.log_params({key: value})
286
325
 
326
+ @_restore_session
287
327
  def log_params(
288
328
  self,
289
329
  params: dict[str, Any],
@@ -296,15 +336,16 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
296
336
  to string.
297
337
  """
298
338
  run = self._get_or_start_run()
299
- metadata = run._get_metadata()
339
+ params_list = []
300
340
  for key, value in params.items():
301
- metadata.set_param(key, value)
302
- self._sql_client.modify_run(
341
+ params_list.append(entities.Param(key, str(value)))
342
+ self._sql_client.modify_run_add_params(
303
343
  experiment_name=run.experiment_name,
304
344
  run_name=run.name,
305
- run_metadata=json.dumps(metadata.to_dict()),
345
+ params=json.dumps([param.to_dict() for param in params_list]),
306
346
  )
307
347
 
348
+ @_restore_session
308
349
  def log_artifact(
309
350
  self,
310
351
  local_path: str,
@@ -328,6 +369,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
328
369
  file_path=file_path,
329
370
  )
330
371
 
372
+ @_restore_session
331
373
  def list_artifacts(
332
374
  self,
333
375
  run_name: str,
@@ -356,6 +398,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
356
398
  artifact_path=artifact_path or "",
357
399
  )
358
400
 
401
+ @_restore_session
359
402
  def download_artifacts(
360
403
  self,
361
404
  run_name: str,
@@ -397,6 +440,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
397
440
  return self._run
398
441
  return self.start_run()
399
442
 
443
+ @_restore_session
400
444
  def _generate_run_name(self, experiment: entities.Experiment) -> sql_identifier.SqlIdentifier:
401
445
  generator = hrid_generator.HRID16()
402
446
  existing_runs = self._sql_client.show_runs_in_experiment(experiment_name=experiment.name)