snowflake-ml-python 1.8.3__py3-none-any.whl → 1.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (190)
  1. snowflake/cortex/__init__.py +7 -1
  2. snowflake/ml/_internal/platform_capabilities.py +13 -11
  3. snowflake/ml/_internal/utils/identifier.py +2 -2
  4. snowflake/ml/jobs/_utils/constants.py +1 -1
  5. snowflake/ml/jobs/_utils/payload_utils.py +39 -30
  6. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
  7. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +1 -1
  8. snowflake/ml/jobs/_utils/spec_utils.py +1 -1
  9. snowflake/ml/jobs/decorators.py +6 -0
  10. snowflake/ml/jobs/job.py +63 -16
  11. snowflake/ml/jobs/manager.py +50 -16
  12. snowflake/ml/model/_client/model/model_version_impl.py +1 -1
  13. snowflake/ml/model/_client/ops/service_ops.py +26 -14
  14. snowflake/ml/model/_client/service/model_deployment_spec.py +340 -170
  15. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +25 -0
  16. snowflake/ml/model/_client/sql/service.py +4 -13
  17. snowflake/ml/model/_model_composer/model_composer.py +41 -18
  18. snowflake/ml/model/_packager/model_handlers/_utils.py +32 -2
  19. snowflake/ml/model/_packager/model_handlers/custom.py +1 -1
  20. snowflake/ml/model/_packager/model_handlers/pytorch.py +1 -2
  21. snowflake/ml/model/_packager/model_handlers/sklearn.py +100 -41
  22. snowflake/ml/model/_packager/model_handlers/tensorflow.py +7 -4
  23. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  24. snowflake/ml/model/_packager/model_handlers/xgboost.py +16 -7
  25. snowflake/ml/model/_packager/model_meta/model_meta.py +2 -1
  26. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
  27. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +4 -4
  28. snowflake/ml/model/_signatures/dmatrix_handler.py +15 -2
  29. snowflake/ml/model/custom_model.py +17 -4
  30. snowflake/ml/model/model_signature.py +3 -3
  31. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
  32. snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
  33. snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
  34. snowflake/ml/modeling/cluster/birch.py +9 -1
  35. snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
  36. snowflake/ml/modeling/cluster/dbscan.py +9 -1
  37. snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
  38. snowflake/ml/modeling/cluster/k_means.py +9 -1
  39. snowflake/ml/modeling/cluster/mean_shift.py +9 -1
  40. snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
  41. snowflake/ml/modeling/cluster/optics.py +9 -1
  42. snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
  43. snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
  44. snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
  45. snowflake/ml/modeling/compose/column_transformer.py +9 -1
  46. snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
  47. snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
  48. snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
  49. snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
  50. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
  51. snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
  52. snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
  53. snowflake/ml/modeling/covariance/oas.py +9 -1
  54. snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
  55. snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
  56. snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
  57. snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
  58. snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
  59. snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
  60. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
  61. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
  62. snowflake/ml/modeling/decomposition/pca.py +9 -1
  63. snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
  64. snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
  65. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
  66. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
  67. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
  68. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
  69. snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
  70. snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
  71. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
  72. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
  73. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
  74. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
  75. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
  76. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
  77. snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
  78. snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
  79. snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
  80. snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
  81. snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
  82. snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
  83. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
  84. snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
  85. snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
  86. snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
  87. snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
  88. snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
  89. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
  90. snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
  91. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
  92. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
  93. snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
  94. snowflake/ml/modeling/impute/knn_imputer.py +9 -1
  95. snowflake/ml/modeling/impute/missing_indicator.py +9 -1
  96. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
  97. snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
  98. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
  99. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
  100. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
  101. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
  102. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
  103. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
  104. snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
  105. snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
  106. snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
  107. snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
  108. snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
  109. snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
  110. snowflake/ml/modeling/linear_model/lars.py +9 -1
  111. snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
  112. snowflake/ml/modeling/linear_model/lasso.py +9 -1
  113. snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
  114. snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
  115. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
  116. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
  117. snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
  118. snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
  119. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
  120. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
  121. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
  122. snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
  123. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
  124. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
  125. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
  126. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
  127. snowflake/ml/modeling/linear_model/perceptron.py +9 -1
  128. snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
  129. snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
  130. snowflake/ml/modeling/linear_model/ridge.py +9 -1
  131. snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
  132. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
  133. snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
  134. snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
  135. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
  136. snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
  137. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
  138. snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
  139. snowflake/ml/modeling/manifold/isomap.py +9 -1
  140. snowflake/ml/modeling/manifold/mds.py +9 -1
  141. snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
  142. snowflake/ml/modeling/manifold/tsne.py +9 -1
  143. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
  144. snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
  145. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
  146. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
  147. snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
  148. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
  149. snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
  150. snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
  151. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
  152. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
  153. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
  154. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
  155. snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
  156. snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
  157. snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
  158. snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
  159. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
  160. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
  161. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
  162. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
  163. snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
  164. snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
  165. snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
  166. snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
  167. snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
  168. snowflake/ml/modeling/svm/linear_svc.py +9 -1
  169. snowflake/ml/modeling/svm/linear_svr.py +9 -1
  170. snowflake/ml/modeling/svm/nu_svc.py +9 -1
  171. snowflake/ml/modeling/svm/nu_svr.py +9 -1
  172. snowflake/ml/modeling/svm/svc.py +9 -1
  173. snowflake/ml/modeling/svm/svr.py +9 -1
  174. snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
  175. snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
  176. snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
  177. snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
  178. snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
  179. snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
  180. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
  181. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
  182. snowflake/ml/monitoring/explain_visualize.py +286 -0
  183. snowflake/ml/registry/_manager/model_manager.py +23 -2
  184. snowflake/ml/registry/registry.py +10 -9
  185. snowflake/ml/version.py +1 -1
  186. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.4.dist-info}/METADATA +40 -8
  187. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.4.dist-info}/RECORD +190 -189
  188. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.4.dist-info}/WHEEL +1 -1
  189. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.4.dist-info}/licenses/LICENSE.txt +0 -0
  190. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.4.dist-info}/top_level.txt +0 -0
snowflake/cortex/__init__.py CHANGED
@@ -1,5 +1,10 @@
 from snowflake.cortex._classify_text import ClassifyText, classify_text
-from snowflake.cortex._complete import Complete, CompleteOptions, complete
+from snowflake.cortex._complete import (
+    Complete,
+    CompleteOptions,
+    ConversationMessage,
+    complete,
+)
 from snowflake.cortex._embed_text_768 import EmbedText768, embed_text_768
 from snowflake.cortex._embed_text_1024 import EmbedText1024, embed_text_1024
 from snowflake.cortex._extract_answer import ExtractAnswer, extract_answer
@@ -14,6 +19,7 @@ __all__ = [
     "Complete",
     "complete",
     "CompleteOptions",
+    "ConversationMessage",
     "EmbedText768",
     "embed_text_768",
     "EmbedText1024",
snowflake/ml/_internal/platform_capabilities.py CHANGED
@@ -11,6 +11,9 @@ from snowflake.snowpark import (
     session as snowpark_session,
 )
 
+LIVE_COMMIT_PARAMETER = "ENABLE_LIVE_VERSION_IN_SDK"
+INLINE_DEPLOYMENT_SPEC_PARAMETER = "ENABLE_INLINE_DEPLOYMENT_SPEC"
+
 
 class PlatformCapabilities:
     """Class that retrieves platform feature values for the currently running server.
@@ -18,12 +21,12 @@ class PlatformCapabilities:
     Example usage:
     ```
     pc = PlatformCapabilities.get_instance(session)
-    if pc.is_nested_function_enabled():
-        # Nested functions are enabled.
-        print("Nested functions are enabled.")
+    if pc.is_inlined_deployment_spec_enabled():
+        # Inline deployment spec is enabled.
+        print("Inline deployment spec is enabled.")
     else:
-        # Nested functions are disabled.
-        print("Nested functions are disabled or not supported.")
+        # Inline deployment spec is disabled.
+        print("Inline deployment spec is disabled or not supported.")
     ```
     """
 
@@ -50,9 +53,11 @@ class PlatformCapabilities:
 
     # For contextmanager, we need to have return type Iterator[Never]. However, Never type is introduced only in
    # Python 3.11. So, we are ignoring the type for this method.
+    _dummy_features: dict[str, Any] = {"dummy": "dummy"}
+
     @classmethod  # type: ignore[arg-type]
     @contextmanager
-    def mock_features(cls, features: dict[str, Any]) -> None:  # type: ignore[misc]
+    def mock_features(cls, features: dict[str, Any] = _dummy_features) -> None:  # type: ignore[misc]
         logging.debug(f"Setting mock features: {features}")
         cls.set_mock_features(features)
         try:
@@ -61,14 +66,11 @@ class PlatformCapabilities:
         logging.debug(f"Clearing mock features: {features}")
         cls.clear_mock_features()
 
-    def is_nested_function_enabled(self) -> bool:
-        return self._get_bool_feature("SPCS_MODEL_ENABLE_EMBEDDED_SERVICE_FUNCTIONS", False)
-
     def is_inlined_deployment_spec_enabled(self) -> bool:
-        return self._get_bool_feature("ENABLE_INLINE_DEPLOYMENT_SPEC", False)
+        return self._get_bool_feature(INLINE_DEPLOYMENT_SPEC_PARAMETER, False)
 
     def is_live_commit_enabled(self) -> bool:
-        return self._get_bool_feature("ENABLE_BUNDLE_MODULE_CHECKOUT", False)
+        return self._get_bool_feature(LIVE_COMMIT_PARAMETER, False)
 
     @staticmethod
     def _get_features(session: snowpark_session.Session) -> dict[str, Any]:
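A hedged sketch of reading these flags through the public helpers, mirroring the class docstring above; the session-builder call assumes configured Snowflake credentials:

```python
# Hedged sketch of the renamed feature-flag lookups in 1.8.4.
from snowflake.snowpark import Session
from snowflake.ml._internal.platform_capabilities import PlatformCapabilities

session = Session.builder.getOrCreate()  # assumes configured credentials
pc = PlatformCapabilities.get_instance(session)

# Backed by the new INLINE_DEPLOYMENT_SPEC_PARAMETER constant:
print(pc.is_inlined_deployment_spec_enabled())
# Now reads ENABLE_LIVE_VERSION_IN_SDK instead of ENABLE_BUNDLE_MODULE_CHECKOUT:
print(pc.is_live_commit_enabled())
```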
snowflake/ml/_internal/utils/identifier.py CHANGED
@@ -12,7 +12,7 @@ SF_IDENTIFIER_RE = re.compile(_SF_IDENTIFIER)
 _SF_SCHEMA_LEVEL_OBJECT = (
     rf"(?:(?:(?P<db>{_SF_IDENTIFIER})\.)?(?P<schema>{_SF_IDENTIFIER})\.)?(?P<object>{_SF_IDENTIFIER})"
 )
-_SF_STAGE_PATH = rf"{_SF_SCHEMA_LEVEL_OBJECT}(?P<path>.*)"
+_SF_STAGE_PATH = rf"@?{_SF_SCHEMA_LEVEL_OBJECT}(?P<path>/.*)?"
 _SF_SCHEMA_LEVEL_OBJECT_RE = re.compile(_SF_SCHEMA_LEVEL_OBJECT)
 _SF_STAGE_PATH_RE = re.compile(_SF_STAGE_PATH)
 
@@ -197,7 +197,7 @@ def parse_snowflake_stage_path(
         res.group("db"),
         res.group("schema"),
         res.group("object"),
-        res.group("path"),
+        res.group("path") or "",
     )
 
 
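The relaxed pattern now tolerates a leading `@` and makes the path suffix optional, with the path group falling back to an empty string. A hedged sketch of the expected behavior (exact group values assume the regex is applied as a full match over the input):

```python
# Hedged sketch of the new stage-path parsing behavior in 1.8.4.
from snowflake.ml._internal.utils import identifier

# A leading "@" and a "/..." suffix are both accepted:
print(identifier.parse_snowflake_stage_path("@MY_DB.MY_SCHEMA.MY_STAGE/models/v1"))
# expected: ('MY_DB', 'MY_SCHEMA', 'MY_STAGE', '/models/v1')

# With no path component, the path group is None, so "" is substituted:
print(identifier.parse_snowflake_stage_path("MY_STAGE"))
# expected: (None, None, 'MY_STAGE', '')
```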
snowflake/ml/jobs/_utils/constants.py CHANGED
@@ -13,7 +13,7 @@ STAGE_VOLUME_MOUNT_PATH = "/mnt/app"
 DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
 DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks"
 DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks"
-DEFAULT_IMAGE_TAG = "1.0.1"
+DEFAULT_IMAGE_TAG = "1.2.3"
 DEFAULT_ENTRYPOINT_PATH = "func.py"
 
 # Percent of container memory to allocate for /dev/shm volume
snowflake/ml/jobs/_utils/payload_utils.py CHANGED
@@ -9,6 +9,7 @@ from pathlib import Path, PurePath
 from typing import Any, Callable, Optional, Union, cast, get_args, get_origin
 
 import cloudpickle as cp
+from packaging import version
 
 from snowflake import snowpark
 from snowflake.ml.jobs._utils import constants, types
@@ -97,11 +98,18 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
 head_info=$(python3 get_instance_ip.py "$SNOWFLAKE_SERVICE_NAME" --head)
 if [ $? -eq 0 ]; then
     # Parse the output using read
-    read head_index head_ip <<< "$head_info"
+    read head_index head_ip head_status <<< "$head_info"
 
     # Use the parsed variables
     echo "Head Instance Index: $head_index"
     echo "Head Instance IP: $head_ip"
+    echo "Head Instance Status: $head_status"
+
+    # If the head status is not "READY" or "PENDING", exit early
+    if [ "$head_status" != "READY" ] && [ "$head_status" != "PENDING" ]; then
+        echo "Head instance status is not READY or PENDING. Exiting."
+        exit 0
+    fi
 
 else
     echo "Error: Failed to get head instance information."
@@ -278,17 +286,19 @@ class JobPayload:
         stage_path = PurePath(stage_path) if isinstance(stage_path, str) else stage_path
         source = resolve_source(self.source)
         entrypoint = resolve_entrypoint(source, self.entrypoint)
+        pip_requirements = self.pip_requirements or []
 
         # Create stage if necessary
         stage_name = stage_path.parts[0].lstrip("@")
         # Explicitly check if stage exists first since we may not have CREATE STAGE privilege
         try:
-            session.sql(f"describe stage {stage_name}").collect()
+            session.sql("describe stage identifier(?)", params=[stage_name]).collect()
         except sp_exceptions.SnowparkSQLException:
             session.sql(
-                f"create stage if not exists {stage_name}"
+                "create stage if not exists identifier(?)"
                 " encryption = ( type = 'SNOWFLAKE_SSE' )"
-                " comment = 'Created by snowflake.ml.jobs Python API'"
+                " comment = 'Created by snowflake.ml.jobs Python API'",
+                params=[stage_name],
             ).collect()
 
         # Upload payload to stage
@@ -301,6 +311,8 @@ class JobPayload:
                 overwrite=True,
             )
             source = Path(entrypoint.file_path.parent)
+            if not any(r.startswith("cloudpickle") for r in pip_requirements):
+                pip_requirements.append(f"cloudpickle~={version.parse(cp.__version__).major}.0")
         elif source.is_dir():
             # Manually traverse the directory and upload each file, since Snowflake PUT
             # can't handle directories. Reduce the number of PUT operations by using
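The generated constraint pins only cloudpickle's major version, presumably because cross-major deserialization can fail (the error handling added to generate_python_code below points at exactly that failure mode). A sketch of the expression in isolation:

```python
# Hedged sketch: how the appended requirement string is built.
import cloudpickle as cp
from packaging import version

pin = f"cloudpickle~={version.parse(cp.__version__).major}.0"
# With cloudpickle 2.2.1 installed locally this yields "cloudpickle~=2.0",
# i.e. any 2.x release is acceptable in the job container.
print(pin)
```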
@@ -325,10 +337,10 @@ class JobPayload:
 
         # Upload requirements
         # TODO: Check if payload includes both a requirements.txt file and pip_requirements
-        if self.pip_requirements:
+        if pip_requirements:
             # Upload requirements.txt to stage
             session.file.put_stream(
-                io.BytesIO("\n".join(self.pip_requirements).encode()),
+                io.BytesIO("\n".join(pip_requirements).encode()),
                 stage_location=stage_path.joinpath("requirements.txt").as_posix(),
                 auto_compress=False,
                 overwrite=True,
@@ -495,13 +507,6 @@ def generate_python_code(func: Callable[..., Any], source_code_display: bool = F
     # https://github.com/snowflakedb/snowpark-python/blob/main/src/snowflake/snowpark/_internal/udf_utils.py
     source_code_comment = _generate_source_code_comment(func) if source_code_display else ""
 
-    func_code = f"""
-{source_code_comment}
-
-import pickle
-{_ENTRYPOINT_FUNC_NAME} = pickle.loads(bytes.fromhex('{_serialize_callable(func).hex()}'))
-"""
-
     arg_dict_name = "kwargs"
     if getattr(func, constants.IS_MLJOB_REMOTE_ATTR, None):
         param_code = f"{arg_dict_name} = {{}}"
@@ -509,25 +514,29 @@ import pickle
         param_code = _generate_param_handler_code(signature, arg_dict_name)
 
     return f"""
-### Version guard to check compatibility across Python versions ###
-import os
 import sys
-import warnings
-
-if sys.version_info.major != {sys.version_info.major} or sys.version_info.minor != {sys.version_info.minor}:
-    warnings.warn(
-        "Python version mismatch: job was created using"
-        " python{sys.version_info.major}.{sys.version_info.minor}"
-        f" but runtime environment uses python{{sys.version_info.major}}.{{sys.version_info.minor}}."
-        " Compatibility across Python versions is not guaranteed and may result in unexpected behavior."
-        " This will be fixed in a future release; for now, please use Python version"
-        f" {{sys.version_info.major}}.{{sys.version_info.minor}}.",
-        RuntimeWarning,
-        stacklevel=0,
-    )
-### End version guard ###
+import pickle
 
-{func_code.strip()}
+try:
+{textwrap.indent(source_code_comment, '    ')}
+    {_ENTRYPOINT_FUNC_NAME} = pickle.loads(bytes.fromhex('{_serialize_callable(func).hex()}'))
+except (TypeError, pickle.PickleError):
+    if sys.version_info.major != {sys.version_info.major} or sys.version_info.minor != {sys.version_info.minor}:
+        raise RuntimeError(
+            "Failed to deserialize function due to Python version mismatch."
+            f" Runtime environment is Python {{sys.version_info.major}}.{{sys.version_info.minor}}"
+            " but function was serialized using Python {sys.version_info.major}.{sys.version_info.minor}."
+        ) from None
+    raise
+except AttributeError as e:
+    if 'cloudpickle' in str(e):
+        import cloudpickle as cp
+        raise RuntimeError(
+            "Failed to deserialize function due to cloudpickle version mismatch."
+            f" Runtime environment uses cloudpickle=={{cp.__version__}}"
+            " but job was serialized using cloudpickle=={cp.__version__}."
+        ) from e
+    raise
 
 if __name__ == '__main__':
{textwrap.indent(param_code, '    ')}
snowflake/ml/jobs/_utils/scripts/get_instance_ip.py CHANGED
@@ -29,7 +29,7 @@ def get_self_ip() -> Optional[str]:
         return None
 
 
-def get_first_instance(service_name: str) -> Optional[tuple[str, str]]:
+def get_first_instance(service_name: str) -> Optional[tuple[str, str, str]]:
     """Get the first instance of a batch job based on start time and instance ID.
 
     Args:
@@ -42,7 +42,7 @@ def get_first_instance(service_name: str) -> Optional[tuple[str, str]]:
 
     session = session_utils.get_session()
     df = session.sql(f"show service instances in service {service_name}")
-    result = df.select('"instance_id"', '"ip_address"', '"start_time"').collect()
+    result = df.select('"instance_id"', '"ip_address"', '"start_time"', '"status"').collect()
 
     if not result:
         return None
@@ -57,7 +57,7 @@ def get_first_instance(service_name: str) -> Optional[tuple[str, str]]:
     ip_address = head_instance["ip_address"]
     try:
         socket.inet_aton(ip_address)  # Validate IPv4 address
-        return (head_instance["instance_id"], ip_address)
+        return (head_instance["instance_id"], ip_address, head_instance["status"])
     except OSError:
         logger.error(f"Error: Invalid IP address format: {ip_address}")
         return None
@@ -110,7 +110,7 @@ def main():
         head_info = get_first_instance(args.service_name)
         if head_info:
             # Print to stdout to allow capture but don't use logger
-            sys.stdout.write(f"{head_info[0]} {head_info[1]}\n")
+            sys.stdout.write(" ".join(head_info) + "\n")
             sys.exit(0)
         time.sleep(args.retry_interval)
     # If we get here, we've timed out
snowflake/ml/jobs/_utils/scripts/mljob_launcher.py CHANGED
@@ -59,7 +59,7 @@ class SimpleJSONEncoder(json.JSONEncoder):
         try:
             return super().default(obj)
         except TypeError:
-            return str(obj)
+            return f"Unserializable object: {repr(obj)}"
 
 
 def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = None) -> Any:
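The encoder's fallback now renders non-JSON-serializable objects via repr() with an explicit marker instead of a bare str(). A self-contained re-creation of the patched behavior for illustration:

```python
# Hedged sketch: minimal re-creation of the patched fallback, for illustration only.
import json
from typing import Any


class SimpleJSONEncoder(json.JSONEncoder):
    def default(self, obj: Any) -> Any:
        try:
            return super().default(obj)
        except TypeError:
            # New behavior: mark the value as unserializable and show its repr.
            return f"Unserializable object: {repr(obj)}"


print(json.dumps({"x": object()}, cls=SimpleJSONEncoder))
# -> {"x": "Unserializable object: <object object at 0x...>"}
```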
snowflake/ml/jobs/_utils/spec_utils.py CHANGED
@@ -11,7 +11,7 @@ from snowflake.ml.jobs._utils import constants, types
 def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
     """Extract resource information for the specified compute pool"""
     # Get the instance family
-    rows = session.sql(f"show compute pools like '{compute_pool}'").collect()
+    rows = session.sql("show compute pools like ?", params=[compute_pool]).collect()
     if not rows:
         raise ValueError(f"Compute pool '{compute_pool}' not found")
     instance_family: str = rows[0]["instance_family"]
snowflake/ml/jobs/decorators.py CHANGED
@@ -26,6 +26,8 @@ def remote(
     env_vars: Optional[dict[str, str]] = None,
     num_instances: Optional[int] = None,
     enable_metrics: bool = False,
+    database: Optional[str] = None,
+    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
 ) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, jb.MLJob[_ReturnValue]]]:
     """
@@ -40,6 +42,8 @@ def remote(
         env_vars: Environment variables to set in container
         num_instances: The number of nodes in the job. If none specified, create a single node job.
         enable_metrics: Whether to enable metrics publishing for the job.
+        database: The database to use for the job.
+        schema: The schema to use for the job.
         session: The Snowpark session to use. If none specified, uses active session.
 
     Returns:
@@ -67,6 +71,8 @@ def remote(
                 env_vars=env_vars,
                 num_instances=num_instances,
                 enable_metrics=enable_metrics,
+                database=database,
+                schema=schema,
                 session=session,
             )
             assert isinstance(job, jb.MLJob), f"Unexpected job type: {type(job)}"
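A hedged sketch of the new arguments on the decorator; the compute pool, stage, database, and schema names are placeholders, and an active Snowpark session is assumed:

```python
# Hedged sketch: submitting a job into an explicit database/schema (new in 1.8.4).
# All identifiers below are placeholders; requires an active Snowpark session.
from snowflake.ml.jobs import remote


@remote(
    "MY_COMPUTE_POOL",
    stage_name="payload_stage",
    database="MY_DB",      # new in 1.8.4
    schema="MY_SCHEMA",    # new in 1.8.4
)
def train(n_rows: int) -> float:
    return 0.0


job = train(10)  # returns an MLJob handle rather than running locally
```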
snowflake/ml/jobs/job.py CHANGED
@@ -1,12 +1,15 @@
 import time
+from functools import cached_property
 from typing import Any, Generic, Literal, Optional, TypeVar, Union, cast, overload
 
 import yaml
 
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
+from snowflake.ml._internal.utils import identifier
 from snowflake.ml.jobs._utils import constants, interop_utils, types
-from snowflake.snowpark import context as sp_context
+from snowflake.snowpark import Row, context as sp_context
+from snowflake.snowpark.exceptions import SnowparkSQLException
 
 _PROJECT = "MLJob"
 TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "INTERNAL_ERROR"}
@@ -28,6 +31,14 @@ class MLJob(Generic[T]):
         self._status: types.JOB_STATUS = "PENDING"
         self._result: Optional[interop_utils.ExecutionResult] = None
 
+    @cached_property
+    def name(self) -> str:
+        return identifier.parse_schema_level_object_identifier(self.id)[-1]
+
+    @cached_property
+    def num_instances(self) -> int:
+        return _get_num_instances(self._session, self.id)
+
     @property
     def id(self) -> str:
         """Get the unique job ID"""
@@ -67,7 +78,7 @@ class MLJob(Generic[T]):
         """Get the job's result file location."""
         result_path = self._container_spec["env"].get(constants.RESULT_PATH_ENV_VAR)
         if result_path is None:
-            raise RuntimeError(f"Job {self.id} doesn't have a result path configured")
+            raise RuntimeError(f"Job {self.name} doesn't have a result path configured")
         return f"{self._stage_path}/{result_path}"
 
     @overload
@@ -128,7 +139,7 @@ class MLJob(Generic[T]):
         start_time = time.monotonic()
         while self.status not in TERMINAL_JOB_STATUSES:
             if timeout >= 0 and (elapsed := time.monotonic() - start_time) >= timeout:
-                raise TimeoutError(f"Job {self.id} did not complete within {elapsed} seconds")
+                raise TimeoutError(f"Job {self.name} did not complete within {elapsed} seconds")
             time.sleep(delay)
             delay = min(delay * 2, constants.JOB_POLL_MAX_DELAY_SECONDS)  # Exponential backoff
         return self.status
@@ -154,11 +165,11 @@ class MLJob(Generic[T]):
         try:
             self._result = interop_utils.fetch_result(self._session, self._result_path)
         except Exception as e:
-            raise RuntimeError(f"Failed to retrieve result for job (id={self.id})") from e
+            raise RuntimeError(f"Failed to retrieve result for job (id={self.name})") from e
 
         if self._result.success:
             return cast(T, self._result.result)
-        raise RuntimeError(f"Job execution failed (id={self.id})") from self._result.exception
+        raise RuntimeError(f"Job execution failed (id={self.name})") from self._result.exception
 
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "instance_id"])
@@ -172,14 +183,14 @@ def _get_status(session: snowpark.Session, job_id: str, instance_id: Optional[in
             return cast(types.JOB_STATUS, row["status"])
         raise ValueError(f"Instance {instance_id} not found in job {job_id}")
     else:
-        (row,) = session.sql("DESCRIBE SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+        row = _get_service_info(session, job_id)
         return cast(types.JOB_STATUS, row["status"])
 
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
 def _get_service_spec(session: snowpark.Session, job_id: str) -> dict[str, Any]:
     """Retrieve job execution service spec."""
-    (row,) = session.sql("DESCRIBE SERVICE IDENTIFIER(?)", params=[job_id]).collect()
+    row = _get_service_info(session, job_id)
     return cast(dict[str, Any], yaml.safe_load(row["spec"]))
 
 
@@ -196,10 +207,21 @@ def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1, instance_
 
     Returns:
         The job's execution logs.
+
+    Raises:
+        SnowparkSQLException: if the container is pending
+        RuntimeError: if failed to get head instance_id
+
     """
     # If instance_id is not specified, try to get the head instance ID
     if instance_id is None:
-        instance_id = _get_head_instance_id(session, job_id)
+        try:
+            instance_id = _get_head_instance_id(session, job_id)
+        except RuntimeError:
+            raise RuntimeError(
+                "Failed to retrieve job logs. "
+                "Logs may be inaccessible due to job expiration and can be retrieved from Event Table instead."
+            )
 
     # Assemble params: [job_id, instance_id, container_name, (optional) limit]
     params: list[Any] = [
@@ -210,10 +232,15 @@ def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1, instance_
     if limit > 0:
         params.append(limit)
 
-    (row,) = session.sql(
-        f"SELECT SYSTEM$GET_SERVICE_LOGS(?, ?, ?{f', ?' if limit > 0 else ''})",
-        params=params,
-    ).collect()
+    try:
+        (row,) = session.sql(
+            f"SELECT SYSTEM$GET_SERVICE_LOGS(?, ?, ?{f', ?' if limit > 0 else ''})",
+            params=params,
+        ).collect()
+    except SnowparkSQLException as e:
+        if "Container Status: PENDING" in e.message:
+            return "Warning: Waiting for container to start. Logs will be shown when available."
+        raise
     return str(row[0])
 
 
@@ -223,18 +250,27 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[in
     Retrieve the head instance ID of a job.
 
     Args:
-        session: The Snowpark session to use.
-        job_id: The job ID.
+        session (Session): The Snowpark session to use.
+        job_id (str): The job ID.
 
     Returns:
-        The head instance ID of the job. Returns None if the head instance has not started yet.
+        Optional[int]: The head instance ID of the job, or None if the head instance has not started yet.
+
+    Raises:
+        RuntimeError: If the instances died or if some instances disappeared.
     """
     rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
     if not rows:
         return None
+    if _get_num_instances(session, job_id) > len(rows):
+        raise RuntimeError("Couldn't retrieve head instance due to missing instances.")
 
     # Sort by start_time first, then by instance_id
-    sorted_instances = sorted(rows, key=lambda x: (x["start_time"], int(x["instance_id"])))
+    try:
+        sorted_instances = sorted(rows, key=lambda x: (x["start_time"], int(x["instance_id"])))
+    except TypeError:
+        raise RuntimeError("Job instance information unavailable.")
+
     head_instance = sorted_instances[0]
     if not head_instance["start_time"]:
         # If head instance hasn't started yet, return None
@@ -243,3 +279,14 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[in
         return int(head_instance["instance_id"])
     except (ValueError, TypeError):
         return 0
+
+
+def _get_service_info(session: snowpark.Session, job_id: str) -> Row:
+    (row,) = session.sql("DESCRIBE SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+    return row
+
+
+@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
+def _get_num_instances(session: snowpark.Session, job_id: str) -> int:
+    row = _get_service_info(session, job_id)
+    return int(row["target_instances"]) if row["target_instances"] else 0