snowflake-ml-python 1.15.0__py3-none-any.whl → 1.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (201)
  1. snowflake/ml/_internal/human_readable_id/adjectives.txt +5 -5
  2. snowflake/ml/_internal/human_readable_id/animals.txt +3 -3
  3. snowflake/ml/_internal/platform_capabilities.py +4 -0
  4. snowflake/ml/_internal/utils/mixins.py +24 -9
  5. snowflake/ml/experiment/experiment_tracking.py +63 -19
  6. snowflake/ml/jobs/__init__.py +4 -0
  7. snowflake/ml/jobs/_interop/__init__.py +0 -0
  8. snowflake/ml/jobs/_interop/data_utils.py +124 -0
  9. snowflake/ml/jobs/_interop/dto_schema.py +95 -0
  10. snowflake/ml/jobs/{_utils/interop_utils.py → _interop/exception_utils.py} +49 -178
  11. snowflake/ml/jobs/_interop/legacy.py +225 -0
  12. snowflake/ml/jobs/_interop/protocols.py +471 -0
  13. snowflake/ml/jobs/_interop/results.py +51 -0
  14. snowflake/ml/jobs/_interop/utils.py +144 -0
  15. snowflake/ml/jobs/_utils/constants.py +4 -1
  16. snowflake/ml/jobs/_utils/feature_flags.py +37 -5
  17. snowflake/ml/jobs/_utils/payload_utils.py +1 -1
  18. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +139 -102
  19. snowflake/ml/jobs/_utils/spec_utils.py +50 -11
  20. snowflake/ml/jobs/_utils/types.py +10 -0
  21. snowflake/ml/jobs/job.py +168 -36
  22. snowflake/ml/jobs/manager.py +54 -36
  23. snowflake/ml/model/__init__.py +16 -2
  24. snowflake/ml/model/_client/model/batch_inference_specs.py +18 -2
  25. snowflake/ml/model/_client/model/model_version_impl.py +44 -7
  26. snowflake/ml/model/_client/ops/model_ops.py +4 -0
  27. snowflake/ml/model/_client/ops/service_ops.py +50 -5
  28. snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
  29. snowflake/ml/model/_client/sql/model_version.py +3 -1
  30. snowflake/ml/model/_client/sql/stage.py +8 -0
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  32. snowflake/ml/model/_model_composer/model_method/model_method.py +32 -4
  33. snowflake/ml/model/_model_composer/model_method/utils.py +28 -0
  34. snowflake/ml/model/_packager/model_env/model_env.py +48 -21
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +8 -0
  36. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
  37. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -3
  38. snowflake/ml/model/type_hints.py +13 -0
  39. snowflake/ml/model/volatility.py +34 -0
  40. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +5 -5
  41. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
  42. snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
  43. snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
  44. snowflake/ml/modeling/cluster/birch.py +1 -1
  45. snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
  46. snowflake/ml/modeling/cluster/dbscan.py +1 -1
  47. snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
  48. snowflake/ml/modeling/cluster/k_means.py +1 -1
  49. snowflake/ml/modeling/cluster/mean_shift.py +1 -1
  50. snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
  51. snowflake/ml/modeling/cluster/optics.py +1 -1
  52. snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
  53. snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
  54. snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
  55. snowflake/ml/modeling/compose/column_transformer.py +1 -1
  56. snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
  57. snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
  58. snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
  59. snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
  60. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
  61. snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
  62. snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
  63. snowflake/ml/modeling/covariance/oas.py +1 -1
  64. snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
  65. snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
  66. snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
  67. snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
  68. snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
  69. snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
  70. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
  71. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
  72. snowflake/ml/modeling/decomposition/pca.py +1 -1
  73. snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
  74. snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
  75. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
  76. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
  77. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
  78. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
  79. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
  80. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
  81. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
  82. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
  83. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
  84. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
  85. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
  87. snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
  88. snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
  89. snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
  90. snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
  91. snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
  92. snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
  93. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
  94. snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
  95. snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
  96. snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
  97. snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
  98. snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
  99. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
  100. snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
  101. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
  102. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
  103. snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
  104. snowflake/ml/modeling/impute/knn_imputer.py +1 -1
  105. snowflake/ml/modeling/impute/missing_indicator.py +1 -1
  106. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
  107. snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
  108. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
  109. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
  110. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
  111. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
  112. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
  113. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
  114. snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
  115. snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
  116. snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
  117. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  118. snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
  119. snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
  120. snowflake/ml/modeling/linear_model/lars.py +1 -1
  121. snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
  122. snowflake/ml/modeling/linear_model/lasso.py +1 -1
  123. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  124. snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
  125. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
  126. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
  127. snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
  128. snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
  129. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
  130. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
  131. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  132. snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
  133. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  134. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
  135. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
  136. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  137. snowflake/ml/modeling/linear_model/perceptron.py +1 -1
  138. snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
  139. snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
  140. snowflake/ml/modeling/linear_model/ridge.py +1 -1
  141. snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
  142. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
  143. snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
  144. snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
  145. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
  146. snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
  147. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
  148. snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
  149. snowflake/ml/modeling/manifold/isomap.py +1 -1
  150. snowflake/ml/modeling/manifold/mds.py +1 -1
  151. snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
  152. snowflake/ml/modeling/manifold/tsne.py +1 -1
  153. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
  154. snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
  155. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
  156. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
  157. snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
  158. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
  159. snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
  160. snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
  161. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
  162. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
  163. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
  164. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
  165. snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
  166. snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
  167. snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
  168. snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
  169. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
  170. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
  171. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  172. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
  173. snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
  174. snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
  175. snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
  176. snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
  177. snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
  178. snowflake/ml/modeling/svm/linear_svc.py +1 -1
  179. snowflake/ml/modeling/svm/linear_svr.py +1 -1
  180. snowflake/ml/modeling/svm/nu_svc.py +1 -1
  181. snowflake/ml/modeling/svm/nu_svr.py +1 -1
  182. snowflake/ml/modeling/svm/svc.py +1 -1
  183. snowflake/ml/modeling/svm/svr.py +1 -1
  184. snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
  185. snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
  186. snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
  187. snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
  188. snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
  189. snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
  190. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
  191. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
  192. snowflake/ml/registry/_manager/model_manager.py +1 -0
  193. snowflake/ml/registry/_manager/model_parameter_reconciler.py +27 -0
  194. snowflake/ml/registry/registry.py +15 -0
  195. snowflake/ml/utils/authentication.py +16 -0
  196. snowflake/ml/version.py +1 -1
  197. {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/METADATA +65 -5
  198. {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/RECORD +201 -192
  199. {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/WHEEL +0 -0
  200. {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/licenses/LICENSE.txt +0 -0
  201. {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/job.py CHANGED
@@ -12,12 +12,19 @@ from snowflake import snowpark
12
12
  from snowflake.ml._internal import telemetry
13
13
  from snowflake.ml._internal.utils import identifier
14
14
  from snowflake.ml._internal.utils.mixins import SerializableSessionMixin
15
- from snowflake.ml.jobs._utils import constants, interop_utils, query_helper, types
15
+ from snowflake.ml.jobs._interop import results as interop_result, utils as interop_utils
16
+ from snowflake.ml.jobs._utils import (
17
+ constants,
18
+ payload_utils,
19
+ query_helper,
20
+ stage_utils,
21
+ types,
22
+ )
16
23
  from snowflake.snowpark import Row, context as sp_context
17
24
  from snowflake.snowpark.exceptions import SnowparkSQLException
18
25
 
19
26
  _PROJECT = "MLJob"
20
- TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "CANCELLED", "INTERNAL_ERROR"}
27
+ TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "CANCELLED", "INTERNAL_ERROR", "DELETED"}
21
28
 
22
29
  T = TypeVar("T")
23
30
 
@@ -36,7 +43,12 @@ class MLJob(Generic[T], SerializableSessionMixin):
36
43
  self._session = session or sp_context.get_active_session()
37
44
 
38
45
  self._status: types.JOB_STATUS = "PENDING"
39
- self._result: Optional[interop_utils.ExecutionResult] = None
46
+ self._result: Optional[interop_result.ExecutionResult] = None
47
+
48
+ @cached_property
49
+ def _service_info(self) -> types.ServiceInfo:
50
+ """Get the job's service info."""
51
+ return _resolve_service_info(self.id, self._session)
40
52
 
41
53
  @cached_property
42
54
  def name(self) -> str:
@@ -44,7 +56,7 @@ class MLJob(Generic[T], SerializableSessionMixin):
44
56
 
45
57
  @cached_property
46
58
  def target_instances(self) -> int:
47
- return _get_target_instances(self._session, self.id)
59
+ return self._service_info.target_instances
48
60
 
49
61
  @cached_property
50
62
  def min_instances(self) -> int:
@@ -69,8 +81,7 @@ class MLJob(Generic[T], SerializableSessionMixin):
69
81
  @cached_property
70
82
  def _compute_pool(self) -> str:
71
83
  """Get the job's compute pool name."""
72
- row = _get_service_info(self._session, self.id)
73
- return cast(str, row["compute_pool"])
84
+ return self._service_info.compute_pool
74
85
 
75
86
  @property
76
87
  def _service_spec(self) -> dict[str, Any]:
@@ -82,7 +93,13 @@ class MLJob(Generic[T], SerializableSessionMixin):
82
93
  @property
83
94
  def _container_spec(self) -> dict[str, Any]:
84
95
  """Get the job's main container spec."""
85
- containers = self._service_spec["spec"]["containers"]
96
+ try:
97
+ containers = self._service_spec["spec"]["containers"]
98
+ except SnowparkSQLException as e:
99
+ if e.sql_error_code == 2003:
100
+ # If the job is deleted, the service spec is not available
101
+ return {}
102
+ raise
86
103
  if len(containers) == 1:
87
104
  return cast(dict[str, Any], containers[0])
88
105
  try:
@@ -105,22 +122,28 @@ class MLJob(Generic[T], SerializableSessionMixin):
105
122
  if result_path_str is None:
106
123
  raise RuntimeError(f"Job {self.name} doesn't have a result path configured")
107
124
 
108
- # If result path is relative, it is relative to the stage mount path
109
- result_path = Path(result_path_str)
110
- if not result_path.is_absolute():
111
- return f"{self._stage_path}/{result_path.as_posix()}"
125
+ return self._transform_path(result_path_str)
112
126
 
113
- # If result path is absolute, it is relative to the stage mount path
127
+ def _transform_path(self, path_str: str) -> str:
128
+ """Transform a local path within the container to a stage path."""
129
+ path = payload_utils.resolve_path(path_str)
130
+ if isinstance(path, stage_utils.StagePath):
131
+ # Stage paths need no transformation
132
+ return path.as_posix()
133
+ if not path.is_absolute():
134
+ # Assume relative paths are relative to stage mount path
135
+ return f"{self._stage_path}/{path.as_posix()}"
136
+
137
+ # If result path is absolute, rebase it onto the stage mount path
138
+ # TODO: Rather than matching by name, use the longest mount path which matches
114
139
  volume_mounts = self._container_spec["volumeMounts"]
115
140
  stage_mount_str = next(v for v in volume_mounts if v.get("name") == constants.STAGE_VOLUME_NAME)["mountPath"]
116
141
  stage_mount = Path(stage_mount_str)
117
142
  try:
118
- relative_path = result_path.relative_to(stage_mount)
143
+ relative_path = path.relative_to(stage_mount)
119
144
  return f"{self._stage_path}/{relative_path.as_posix()}"
120
145
  except ValueError:
121
- raise ValueError(
122
- f"Result path {result_path} is absolute, but should be relative to stage mount {stage_mount}"
123
- )
146
+ raise ValueError(f"Result path {path} is absolute, but should be relative to stage mount {stage_mount}")
124
147
 
125
148
  @overload
126
149
  def get_logs(
@@ -165,7 +188,14 @@ class MLJob(Generic[T], SerializableSessionMixin):
165
188
  Returns:
166
189
  The job's execution logs.
167
190
  """
168
- logs = _get_logs(self._session, self.id, limit, instance_id, self._container_spec["name"], verbose)
191
+ logs = _get_logs(
192
+ self._session,
193
+ self.id,
194
+ limit,
195
+ instance_id,
196
+ self._container_spec["name"] if "name" in self._container_spec else constants.DEFAULT_CONTAINER_NAME,
197
+ verbose,
198
+ )
169
199
  assert isinstance(logs, str) # mypy
170
200
  if as_list:
171
201
  return logs.splitlines()
@@ -218,7 +248,6 @@ class MLJob(Generic[T], SerializableSessionMixin):
218
248
  delay = min(delay * 1.2, constants.JOB_POLL_MAX_DELAY_SECONDS) # Exponential backoff
219
249
  return self.status
220
250
 
221
- @snowpark._internal.utils.private_preview(version="1.8.2")
222
251
  @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["timeout"])
223
252
  def result(self, timeout: float = -1) -> T:
224
253
  """
@@ -237,13 +266,13 @@ class MLJob(Generic[T], SerializableSessionMixin):
237
266
  if self._result is None:
238
267
  self.wait(timeout)
239
268
  try:
240
- self._result = interop_utils.fetch_result(self._session, self._result_path)
269
+ self._result = interop_utils.load_result(
270
+ self._result_path, session=self._session, path_transform=self._transform_path
271
+ )
241
272
  except Exception as e:
242
- raise RuntimeError(f"Failed to retrieve result for job (id={self.name})") from e
273
+ raise RuntimeError(f"Failed to retrieve result for job, error: {e!r}") from e
243
274
 
244
- if self._result.success:
245
- return cast(T, self._result.result)
246
- raise RuntimeError(f"Job execution failed (id={self.name})") from self._result.exception
275
+ return cast(T, self._result.get_value())
247
276
 
248
277
  @telemetry.send_api_usage_telemetry(project=_PROJECT)
249
278
  def cancel(self) -> None:
@@ -256,22 +285,28 @@ class MLJob(Generic[T], SerializableSessionMixin):
256
285
  self._session.sql(f"CALL {self.id}!spcs_cancel_job()").collect()
257
286
  logger.debug(f"Cancellation requested for job {self.id}")
258
287
  except SnowparkSQLException as e:
259
- raise RuntimeError(f"Failed to cancel job {self.id}: {e.message}") from e
288
+ raise RuntimeError(f"Failed to cancel job, error: {e!r}") from e
260
289
 
261
290
 
262
291
  @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "instance_id"])
263
292
  def _get_status(session: snowpark.Session, job_id: str, instance_id: Optional[int] = None) -> types.JOB_STATUS:
264
293
  """Retrieve job or job instance execution status."""
265
- if instance_id is not None:
266
- # Get specific instance status
267
- rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
268
- for row in rows:
269
- if row["instance_id"] == str(instance_id):
270
- return cast(types.JOB_STATUS, row["status"])
271
- raise ValueError(f"Instance {instance_id} not found in job {job_id}")
272
- else:
273
- row = _get_service_info(session, job_id)
274
- return cast(types.JOB_STATUS, row["status"])
294
+ try:
295
+ if instance_id is not None:
296
+ # Get specific instance status
297
+ rows = query_helper.run_query(session, "SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,))
298
+ for row in rows:
299
+ if row["instance_id"] == str(instance_id):
300
+ return cast(types.JOB_STATUS, row["status"])
301
+ raise ValueError(f"Instance {instance_id} not found in job {job_id}")
302
+ else:
303
+ row = _get_service_info(session, job_id)
304
+ return cast(types.JOB_STATUS, row["status"])
305
+ except SnowparkSQLException as e:
306
+ if e.sql_error_code == 2003:
307
+ row = _get_service_info_spcs(session, job_id)
308
+ return cast(types.JOB_STATUS, row["STATUS"])
309
+ raise
275
310
 
276
311
 
277
312
  @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
@@ -542,8 +577,21 @@ def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Row:
542
577
 
543
578
  @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
544
579
  def _get_target_instances(session: snowpark.Session, job_id: str) -> int:
545
- row = _get_service_info(session, job_id)
546
- return int(row["target_instances"])
580
+ try:
581
+ row = _get_service_info(session, job_id)
582
+ return int(row["target_instances"])
583
+ except SnowparkSQLException as e:
584
+ if e.sql_error_code == 2003:
585
+ row = _get_service_info_spcs(session, job_id)
586
+ try:
587
+ params = json.loads(row["PARAMETERS"])
588
+ if isinstance(params, dict):
589
+ return int(params.get("REPLICAS", 1))
590
+ else:
591
+ return 1
592
+ except (json.JSONDecodeError, ValueError):
593
+ return 1
594
+ raise
547
595
 
548
596
 
549
597
  def _get_logs_spcs(
@@ -581,3 +629,87 @@ def _get_logs_spcs(
581
629
  query.append(f" LIMIT {limit};")
582
630
  rows = session.sql("\n".join(query)).collect()
583
631
  return rows
632
+
633
+
634
+ def _get_service_info_spcs(session: snowpark.Session, job_id: str) -> Any:
635
+ """
636
+ Retrieve the service info from the SPCS interface.
637
+
638
+ Args:
639
+ session (Session): The Snowpark session to use.
640
+ job_id (str): The job ID.
641
+
642
+ Returns:
643
+ Any: The service info.
644
+
645
+ Raises:
646
+ SnowparkSQLException: If the job does not exist or is too old to retrieve.
647
+ """
648
+ db, schema, name = identifier.parse_schema_level_object_identifier(job_id)
649
+ db = db or session.get_current_database()
650
+ schema = schema or session.get_current_schema()
651
+ rows = query_helper.run_query(
652
+ session,
653
+ """
654
+ select DATABASE_NAME, SCHEMA_NAME, NAME, STATUS, COMPUTE_POOL_NAME, PARAMETERS
655
+ from table(snowflake.spcs.get_job_history())
656
+ where database_name = ? and schema_name = ? and name = ?
657
+ """,
658
+ params=(db, schema, name),
659
+ )
660
+ if rows:
661
+ return rows[0]
662
+ else:
663
+ raise SnowparkSQLException(f"Job {job_id} does not exist or could not be retrieved", sql_error_code=2003)
664
+
665
+
666
+ def _resolve_service_info(id: str, session: snowpark.Session) -> types.ServiceInfo:
667
+ try:
668
+ row = _get_service_info(session, id)
669
+ except SnowparkSQLException as e:
670
+ if e.sql_error_code == 2003:
671
+ row = _get_service_info_spcs(session, id)
672
+ else:
673
+ raise
674
+ if not row:
675
+ raise SnowparkSQLException(f"Job {id} does not exist or could not be retrieved", sql_error_code=2003)
676
+
677
+ if "compute_pool" in row:
678
+ compute_pool = row["compute_pool"]
679
+ elif "COMPUTE_POOL_NAME" in row:
680
+ compute_pool = row["COMPUTE_POOL_NAME"]
681
+ else:
682
+ raise ValueError(f"compute_pool not found in row: {row}")
683
+
684
+ if "status" in row:
685
+ status = row["status"]
686
+ elif "STATUS" in row:
687
+ status = row["STATUS"]
688
+ else:
689
+ raise ValueError(f"status not found in row: {row}")
690
+ # Normalize target_instances
691
+ target_instances: int
692
+ if "target_instances" in row and row["target_instances"] is not None:
693
+ try:
694
+ target_instances = int(row["target_instances"])
695
+ except (ValueError, TypeError):
696
+ target_instances = 1
697
+ elif "PARAMETERS" in row and row["PARAMETERS"]:
698
+ try:
699
+ params = json.loads(row["PARAMETERS"])
700
+ target_instances = int(params.get("REPLICAS", 1)) if isinstance(params, dict) else 1
701
+ except (json.JSONDecodeError, ValueError, TypeError):
702
+ target_instances = 1
703
+ else:
704
+ target_instances = 1
705
+
706
+ database_name = row["database_name"] if "database_name" in row else row["DATABASE_NAME"]
707
+ schema_name = row["schema_name"] if "schema_name" in row else row["SCHEMA_NAME"]
708
+
709
+ return types.ServiceInfo(
710
+ database_name=database_name,
711
+ schema_name=schema_name,
712
+ status=cast(types.JOB_STATUS, status),
713
+ compute_pool=cast(str, compute_pool),
714
+ target_instances=target_instances,
715
+ )
@@ -1,6 +1,7 @@
1
1
  import json
2
2
  import logging
3
3
  import pathlib
4
+ import sys
4
5
  import textwrap
5
6
  from pathlib import PurePath
6
7
  from typing import Any, Callable, Optional, TypeVar, Union, cast, overload
@@ -20,6 +21,7 @@ from snowflake.ml.jobs._utils import (
20
21
  spec_utils,
21
22
  types,
22
23
  )
24
+ from snowflake.snowpark._internal import utils as sp_utils
23
25
  from snowflake.snowpark.context import get_active_session
24
26
  from snowflake.snowpark.exceptions import SnowparkSQLException
25
27
  from snowflake.snowpark.functions import coalesce, col, lit, when
@@ -178,8 +180,10 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob
178
180
  _ = job._service_spec
179
181
  return job
180
182
  except SnowparkSQLException as e:
181
- if "does not exist" in e.message:
182
- raise ValueError(f"Job does not exist: {job_id}") from e
183
+ if e.sql_error_code == 2003:
184
+ job = jb.MLJob[Any](job_id, session=session)
185
+ _ = job.status
186
+ return job
183
187
  raise
184
188
 
185
189
 
@@ -344,6 +348,9 @@ def submit_from_stage(
344
348
  query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
345
349
  spec_overrides (dict): A dictionary of overrides for the service spec.
346
350
  imports (list[Union[tuple[str, str], tuple[str]]]): A list of additional payloads used in the job.
351
+ runtime_environment (str): The runtime image to use. Only support image tag or full image URL,
352
+ e.g. "1.7.1" or "image_repo/image_name:image_tag". When it refers to a full image URL,
353
+ it should contain image repository, image name and image tag.
347
354
 
348
355
  Returns:
349
356
  An object representing the submitted job.
@@ -409,6 +416,7 @@ def _submit_job(
409
416
  "min_instances",
410
417
  "enable_metrics",
411
418
  "query_warehouse",
419
+ "runtime_environment",
412
420
  ],
413
421
  )
414
422
  def _submit_job(
@@ -441,7 +449,7 @@ def _submit_job(
441
449
  Raises:
442
450
  ValueError: If database or schema value(s) are invalid
443
451
  RuntimeError: If schema is not specified in session context or job submission
444
- snowpark.exceptions.SnowparkSQLException: if failed to upload payload
452
+ SnowparkSQLException: if failed to upload payload
445
453
  """
446
454
  session = _ensure_session(session)
447
455
 
@@ -459,6 +467,9 @@ def _submit_job(
459
467
  )
460
468
  imports = kwargs.pop("additional_payloads")
461
469
 
470
+ if "runtime_environment" in kwargs:
471
+ logger.warning("'runtime_environment' is in private preview since 1.15.0, do not use it in production.")
472
+
462
473
  # Use kwargs for less common optional parameters
463
474
  database = kwargs.pop("database", None)
464
475
  schema = kwargs.pop("schema", None)
@@ -470,6 +481,7 @@ def _submit_job(
470
481
  enable_metrics = kwargs.pop("enable_metrics", True)
471
482
  query_warehouse = kwargs.pop("query_warehouse", session.get_current_warehouse())
472
483
  imports = kwargs.pop("imports", None) or imports
484
+ runtime_environment = kwargs.pop("runtime_environment", None)
473
485
 
474
486
  # Warn if there are unknown kwargs
475
487
  if kwargs:
@@ -503,48 +515,44 @@ def _submit_job(
503
515
  uploaded_payload = payload_utils.JobPayload(
504
516
  source, entrypoint=entrypoint, pip_requirements=pip_requirements, additional_payloads=imports
505
517
  ).upload(session, stage_path)
506
- except snowpark.exceptions.SnowparkSQLException as e:
518
+ except SnowparkSQLException as e:
507
519
  if e.sql_error_code == 90106:
508
520
  raise RuntimeError(
509
521
  "Please specify a schema, either in the session context or as a parameter in the job submission"
510
522
  )
511
523
  raise
512
524
 
513
- # FIXME: Temporary patches, remove this after v1 is deprecated
514
- if target_instances > 1:
515
- default_spec_overrides = {
516
- "spec": {
517
- "endpoints": [
518
- {"name": "ray-dashboard-endpoint", "port": 12003, "protocol": "TCP"},
519
- ]
520
- },
521
- }
522
- if spec_overrides:
523
- spec_overrides = spec_utils.merge_patch(
524
- default_spec_overrides, spec_overrides, display_name="spec_overrides"
525
- )
526
- else:
527
- spec_overrides = default_spec_overrides
528
-
529
- if feature_flags.FeatureFlags.USE_SUBMIT_JOB_V2.is_enabled():
525
+ if feature_flags.FeatureFlags.USE_SUBMIT_JOB_V2.is_enabled(default=True):
530
526
  # Add default env vars (extracted from spec_utils.generate_service_spec)
531
527
  combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})}
532
528
 
533
- return _do_submit_job_v2(
534
- session=session,
535
- payload=uploaded_payload,
536
- args=args,
537
- env_vars=combined_env_vars,
538
- spec_overrides=spec_overrides,
539
- compute_pool=compute_pool,
540
- job_id=job_id,
541
- external_access_integrations=external_access_integrations,
542
- query_warehouse=query_warehouse,
543
- target_instances=target_instances,
544
- min_instances=min_instances,
545
- enable_metrics=enable_metrics,
546
- use_async=True,
547
- )
529
+ try:
530
+ return _do_submit_job_v2(
531
+ session=session,
532
+ payload=uploaded_payload,
533
+ args=args,
534
+ env_vars=combined_env_vars,
535
+ spec_overrides=spec_overrides,
536
+ compute_pool=compute_pool,
537
+ job_id=job_id,
538
+ external_access_integrations=external_access_integrations,
539
+ query_warehouse=query_warehouse,
540
+ target_instances=target_instances,
541
+ min_instances=min_instances,
542
+ enable_metrics=enable_metrics,
543
+ use_async=True,
544
+ runtime_environment=runtime_environment,
545
+ )
546
+ except SnowparkSQLException as e:
547
+ if not (e.sql_error_code == 90237 and sp_utils.is_in_stored_procedure()): # type: ignore[no-untyped-call]
548
+ raise
549
+ # SNOW-2390287: SYSTEM$EXECUTE_ML_JOB() is erroneously blocked in owner's rights
550
+ # stored procedures. This will be fixed in an upcoming release.
551
+ logger.warning(
552
+ "Job submission using V2 failed with error {}. Falling back to V1.".format(
553
+ str(e).split("\n", 1)[0],
554
+ )
555
+ )
548
556
 
549
557
  # Fall back to v1
550
558
  # Generate service spec
@@ -556,6 +564,7 @@ def _submit_job(
556
564
  target_instances=target_instances,
557
565
  min_instances=min_instances,
558
566
  enable_metrics=enable_metrics,
567
+ runtime_environment=runtime_environment,
559
568
  )
560
569
 
561
570
  # Generate spec overrides
@@ -639,6 +648,7 @@ def _do_submit_job_v2(
639
648
  min_instances: int = 1,
640
649
  enable_metrics: bool = True,
641
650
  use_async: bool = True,
651
+ runtime_environment: Optional[str] = None,
642
652
  ) -> jb.MLJob[Any]:
643
653
  """
644
654
  Generate the SQL query for job submission.
@@ -657,6 +667,7 @@ def _do_submit_job_v2(
657
667
  min_instances: Minimum number of instances required to start the job.
658
668
  enable_metrics: Whether to enable platform metrics for the job.
659
669
  use_async: Whether to run the job asynchronously.
670
+ runtime_environment: image tag or full image URL to use for the job.
660
671
 
661
672
  Returns:
662
673
  The job object.
@@ -672,6 +683,13 @@ def _do_submit_job_v2(
672
683
  "ENABLE_METRICS": enable_metrics,
673
684
  "SPEC_OVERRIDES": spec_overrides,
674
685
  }
686
+ # for the image tag or full image URL, we use that directly
687
+ if runtime_environment:
688
+ spec_options["RUNTIME"] = runtime_environment
689
+ elif feature_flags.FeatureFlags.ENABLE_RUNTIME_VERSIONS.is_enabled():
690
+ # when feature flag is enabled, we get the local python version and wrap it in a dict
691
+ # in system function, we can know whether it is python version or image tag or full image URL through the format
692
+ spec_options["RUNTIME"] = json.dumps({"pythonVersion": f"{sys.version_info.major}.{sys.version_info.minor}"})
675
693
  job_options = {
676
694
  "EXTERNAL_ACCESS_INTEGRATIONS": external_access_integrations,
677
695
  "QUERY_WAREHOUSE": query_warehouse,
@@ -1,6 +1,20 @@
1
- from snowflake.ml.model._client.model.batch_inference_specs import JobSpec, OutputSpec
1
+ from snowflake.ml.model._client.model.batch_inference_specs import (
2
+ JobSpec,
3
+ OutputSpec,
4
+ SaveMode,
5
+ )
2
6
  from snowflake.ml.model._client.model.model_impl import Model
3
7
  from snowflake.ml.model._client.model.model_version_impl import ExportMode, ModelVersion
4
8
  from snowflake.ml.model.models.huggingface_pipeline import HuggingFacePipelineModel
9
+ from snowflake.ml.model.volatility import Volatility
5
10
 
6
- __all__ = ["Model", "ModelVersion", "ExportMode", "HuggingFacePipelineModel", "JobSpec", "OutputSpec"]
11
+ __all__ = [
12
+ "Model",
13
+ "ModelVersion",
14
+ "ExportMode",
15
+ "HuggingFacePipelineModel",
16
+ "JobSpec",
17
+ "OutputSpec",
18
+ "SaveMode",
19
+ "Volatility",
20
+ ]
@@ -1,10 +1,26 @@
1
- from typing import Optional, Union
1
+ from enum import Enum
2
+ from typing import Optional
2
3
 
3
4
  from pydantic import BaseModel
4
5
 
5
6
 
7
+ class SaveMode(str, Enum):
8
+ """Save mode options for batch inference output.
9
+
10
+ Determines the behavior when files already exist in the output location.
11
+
12
+ OVERWRITE: Remove existing files and write new results.
13
+
14
+ ERROR: Raise an error if files already exist in the output location.
15
+ """
16
+
17
+ OVERWRITE = "overwrite"
18
+ ERROR = "error"
19
+
20
+
6
21
  class OutputSpec(BaseModel):
7
22
  stage_location: str
23
+ mode: SaveMode = SaveMode.ERROR
8
24
 
9
25
 
10
26
  class JobSpec(BaseModel):
@@ -12,10 +28,10 @@ class JobSpec(BaseModel):
12
28
  job_name: Optional[str] = None
13
29
  num_workers: Optional[int] = None
14
30
  function_name: Optional[str] = None
15
- gpu: Optional[Union[str, int]] = None
16
31
  force_rebuild: bool = False
17
32
  max_batch_rows: int = 1024
18
33
  warehouse: Optional[str] = None
19
34
  cpu_requests: Optional[str] = None
20
35
  memory_requests: Optional[str] = None
36
+ gpu_requests: Optional[str] = None
21
37
  replicas: Optional[int] = None
@@ -19,7 +19,9 @@ from snowflake.ml.model._client.model import (
19
19
  from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
20
20
  from snowflake.ml.model._model_composer import model_composer
21
21
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
22
+ from snowflake.ml.model._model_composer.model_method import utils as model_method_utils
22
23
  from snowflake.ml.model._packager.model_handlers import snowmlmodel
24
+ from snowflake.ml.model._packager.model_meta import model_meta_schema
23
25
  from snowflake.snowpark import Session, async_job, dataframe
24
26
 
25
27
  _TELEMETRY_PROJECT = "MLOps"
@@ -41,6 +43,7 @@ class ModelVersion(lineage_node.LineageNode):
41
43
  _model_name: sql_identifier.SqlIdentifier
42
44
  _version_name: sql_identifier.SqlIdentifier
43
45
  _functions: list[model_manifest_schema.ModelFunctionInfo]
46
+ _model_spec: Optional[model_meta_schema.ModelMetadataDict]
44
47
 
45
48
  def __init__(self) -> None:
46
49
  raise RuntimeError("ModelVersion's initializer is not meant to be used. Use `version` from model instead.")
@@ -150,6 +153,7 @@ class ModelVersion(lineage_node.LineageNode):
150
153
  self._model_name = model_name
151
154
  self._version_name = version_name
152
155
  self._functions = self._get_functions()
156
+ self._model_spec = None
153
157
  super(cls, cls).__init__(
154
158
  self,
155
159
  session=model_ops._session,
@@ -437,6 +441,26 @@ class ModelVersion(lineage_node.LineageNode):
437
441
  """
438
442
  return self._functions
439
443
 
444
+ def _get_model_spec(self, statement_params: Optional[dict[str, Any]] = None) -> model_meta_schema.ModelMetadataDict:
445
+ """Fetch and cache the model spec for this model version.
446
+
447
+ Args:
448
+ statement_params: Optional dictionary of statement parameters to include
449
+ in the SQL command to fetch the model spec.
450
+
451
+ Returns:
452
+ The model spec as a dictionary for this model version.
453
+ """
454
+ if self._model_spec is None:
455
+ self._model_spec = self._model_ops._fetch_model_spec(
456
+ database_name=None,
457
+ schema_name=None,
458
+ model_name=self._model_name,
459
+ version_name=self._version_name,
460
+ statement_params=statement_params,
461
+ )
462
+ return self._model_spec
463
+
440
464
  @overload
441
465
  def run(
442
466
  self,
@@ -531,6 +555,8 @@ class ModelVersion(lineage_node.LineageNode):
531
555
  statement_params=statement_params,
532
556
  )
533
557
  else:
558
+ explain_case_sensitive = self._determine_explain_case_sensitivity(target_function_info, statement_params)
559
+
534
560
  return self._model_ops.invoke_method(
535
561
  method_name=sql_identifier.SqlIdentifier(target_function_info["name"]),
536
562
  method_function_type=target_function_info["target_method_function_type"],
@@ -544,13 +570,27 @@ class ModelVersion(lineage_node.LineageNode):
544
570
  partition_column=partition_column,
545
571
  statement_params=statement_params,
546
572
  is_partitioned=target_function_info["is_partitioned"],
573
+ explain_case_sensitive=explain_case_sensitive,
547
574
  )
548
575
 
576
+ def _determine_explain_case_sensitivity(
577
+ self,
578
+ target_function_info: model_manifest_schema.ModelFunctionInfo,
579
+ statement_params: Optional[dict[str, Any]] = None,
580
+ ) -> bool:
581
+ model_spec = self._get_model_spec(statement_params)
582
+ method_options = model_spec.get("method_options", {})
583
+ return model_method_utils.determine_explain_case_sensitive_from_method_options(
584
+ method_options, target_function_info["name"]
585
+ )
586
+
549
587
  @telemetry.send_api_usage_telemetry(
550
588
  project=_TELEMETRY_PROJECT,
551
589
  subproject=_TELEMETRY_SUBPROJECT,
552
590
  func_params_to_log=[
553
591
  "compute_pool",
592
+ "output_spec",
593
+ "job_spec",
554
594
  ],
555
595
  )
556
596
  def _run_batch(
@@ -579,6 +619,8 @@ class ModelVersion(lineage_node.LineageNode):
579
619
  output_stage_location += "/"
580
620
  input_stage_location = f"{output_stage_location}{_BATCH_INFERENCE_TEMPORARY_FOLDER}/"
581
621
 
622
+ self._service_ops._enforce_save_mode(output_spec.mode, output_stage_location)
623
+
582
624
  try:
583
625
  input_spec.write.copy_into_location(location=input_stage_location, file_format_type="parquet", header=True)
584
626
  # todo: be specific about the type of errors to provide better error messages.
@@ -605,6 +647,7 @@ class ModelVersion(lineage_node.LineageNode):
605
647
  warehouse=sql_identifier.SqlIdentifier(warehouse),
606
648
  cpu_requests=job_spec.cpu_requests,
607
649
  memory_requests=job_spec.memory_requests,
650
+ gpu_requests=job_spec.gpu_requests,
608
651
  job_name=job_name,
609
652
  replicas=job_spec.replicas,
610
653
  # input and output
@@ -798,13 +841,7 @@ class ModelVersion(lineage_node.LineageNode):
798
841
  ValueError: If the model is not a HuggingFace text-generation model.
799
842
  """
800
843
  # Fetch model spec
801
- model_spec = self._model_ops._fetch_model_spec(
802
- database_name=None,
803
- schema_name=None,
804
- model_name=self._model_name,
805
- version_name=self._version_name,
806
- statement_params=statement_params,
807
- )
844
+ model_spec = self._get_model_spec(statement_params)
808
845
 
809
846
  # Check if model_type is huggingface_pipeline
810
847
  model_type = model_spec.get("model_type")