snowflake-ml-python 1.10.0__py3-none-any.whl → 1.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (205):
  1. snowflake/cortex/_complete.py +3 -2
  2. snowflake/ml/_internal/utils/service_logger.py +26 -1
  3. snowflake/ml/experiment/_client/artifact.py +76 -0
  4. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +64 -1
  5. snowflake/ml/experiment/callback/keras.py +63 -0
  6. snowflake/ml/experiment/callback/lightgbm.py +5 -1
  7. snowflake/ml/experiment/callback/xgboost.py +5 -1
  8. snowflake/ml/experiment/experiment_tracking.py +89 -4
  9. snowflake/ml/feature_store/feature_store.py +1150 -131
  10. snowflake/ml/feature_store/feature_view.py +122 -0
  11. snowflake/ml/jobs/_utils/__init__.py +0 -0
  12. snowflake/ml/jobs/_utils/constants.py +9 -14
  13. snowflake/ml/jobs/_utils/feature_flags.py +16 -0
  14. snowflake/ml/jobs/_utils/payload_utils.py +61 -19
  15. snowflake/ml/jobs/_utils/query_helper.py +5 -1
  16. snowflake/ml/jobs/_utils/runtime_env_utils.py +63 -0
  17. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +18 -7
  18. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +15 -7
  19. snowflake/ml/jobs/_utils/spec_utils.py +44 -13
  20. snowflake/ml/jobs/_utils/stage_utils.py +22 -9
  21. snowflake/ml/jobs/_utils/types.py +7 -8
  22. snowflake/ml/jobs/job.py +34 -18
  23. snowflake/ml/jobs/manager.py +107 -24
  24. snowflake/ml/model/__init__.py +6 -1
  25. snowflake/ml/model/_client/model/batch_inference_specs.py +27 -0
  26. snowflake/ml/model/_client/model/model_version_impl.py +225 -73
  27. snowflake/ml/model/_client/ops/service_ops.py +128 -174
  28. snowflake/ml/model/_client/service/model_deployment_spec.py +123 -64
  29. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +25 -9
  30. snowflake/ml/model/_model_composer/model_composer.py +1 -70
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +2 -43
  32. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +207 -2
  33. snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -1
  34. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -3
  35. snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
  36. snowflake/ml/model/_signatures/utils.py +4 -2
  37. snowflake/ml/model/inference_engine.py +5 -0
  38. snowflake/ml/model/models/huggingface_pipeline.py +4 -3
  39. snowflake/ml/model/openai_signatures.py +57 -0
  40. snowflake/ml/modeling/_internal/estimator_utils.py +43 -1
  41. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +14 -3
  42. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +17 -6
  43. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
  44. snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
  45. snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
  46. snowflake/ml/modeling/cluster/birch.py +1 -1
  47. snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
  48. snowflake/ml/modeling/cluster/dbscan.py +1 -1
  49. snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
  50. snowflake/ml/modeling/cluster/k_means.py +1 -1
  51. snowflake/ml/modeling/cluster/mean_shift.py +1 -1
  52. snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
  53. snowflake/ml/modeling/cluster/optics.py +1 -1
  54. snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
  55. snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
  56. snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
  57. snowflake/ml/modeling/compose/column_transformer.py +1 -1
  58. snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
  59. snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
  60. snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
  61. snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
  62. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
  63. snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
  64. snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
  65. snowflake/ml/modeling/covariance/oas.py +1 -1
  66. snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
  67. snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
  68. snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
  69. snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
  70. snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
  71. snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
  72. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
  73. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
  74. snowflake/ml/modeling/decomposition/pca.py +1 -1
  75. snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
  76. snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
  77. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
  78. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
  79. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
  80. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
  81. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
  82. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
  83. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
  84. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
  85. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
  86. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
  88. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
  89. snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
  90. snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
  91. snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
  92. snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
  93. snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
  94. snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
  95. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
  96. snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
  97. snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
  98. snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
  99. snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
  100. snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
  101. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
  102. snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
  105. snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
  106. snowflake/ml/modeling/impute/knn_imputer.py +1 -1
  107. snowflake/ml/modeling/impute/missing_indicator.py +1 -1
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
  116. snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
  118. snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
  122. snowflake/ml/modeling/linear_model/lars.py +1 -1
  123. snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
  124. snowflake/ml/modeling/linear_model/lasso.py +1 -1
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
  129. snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  139. snowflake/ml/modeling/linear_model/perceptron.py +1 -1
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
  142. snowflake/ml/modeling/linear_model/ridge.py +1 -1
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
  151. snowflake/ml/modeling/manifold/isomap.py +1 -1
  152. snowflake/ml/modeling/manifold/mds.py +1 -1
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
  154. snowflake/ml/modeling/manifold/tsne.py +1 -1
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
  157. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
  158. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
  159. snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
  160. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
  161. snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
  162. snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
  163. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
  164. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
  165. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
  166. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
  167. snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
  168. snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
  169. snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
  170. snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
  171. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
  172. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
  173. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  174. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
  175. snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
  176. snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
  177. snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
  178. snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
  179. snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
  180. snowflake/ml/modeling/svm/linear_svc.py +1 -1
  181. snowflake/ml/modeling/svm/linear_svr.py +1 -1
  182. snowflake/ml/modeling/svm/nu_svc.py +1 -1
  183. snowflake/ml/modeling/svm/nu_svr.py +1 -1
  184. snowflake/ml/modeling/svm/svc.py +1 -1
  185. snowflake/ml/modeling/svm/svr.py +1 -1
  186. snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
  187. snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
  188. snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
  189. snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
  190. snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
  191. snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
  192. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
  193. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
  194. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +91 -6
  195. snowflake/ml/monitoring/_manager/model_monitor_manager.py +3 -0
  196. snowflake/ml/monitoring/entities/model_monitor_config.py +3 -0
  197. snowflake/ml/monitoring/model_monitor.py +26 -0
  198. snowflake/ml/registry/_manager/model_manager.py +7 -35
  199. snowflake/ml/registry/_manager/model_parameter_reconciler.py +194 -5
  200. snowflake/ml/version.py +1 -1
  201. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/METADATA +87 -7
  202. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/RECORD +205 -197
  203. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/WHEEL +0 -0
  204. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/licenses/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,14 @@
1
1
  import logging
2
2
  import os
3
+ import sys
3
4
  from math import ceil
4
5
  from pathlib import PurePath
5
- from typing import Any, Optional, Union
6
+ from typing import Any, Literal, Optional, Union
6
7
 
7
8
  from snowflake import snowpark
8
9
  from snowflake.ml._internal.utils import snowflake_env
9
- from snowflake.ml.jobs._utils import constants, query_helper, types
10
+ from snowflake.ml.jobs._utils import constants, feature_flags, query_helper, types
11
+ from snowflake.ml.jobs._utils.runtime_env_utils import RuntimeEnvironmentsDict
10
12
 
11
13
 
12
14
  def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
@@ -28,22 +30,53 @@ def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.C
28
30
  )
29
31
 
30
32
 
33
+ def _get_runtime_image(session: snowpark.Session, target_hardware: Literal["CPU", "GPU"]) -> Optional[str]:
34
+ rows = query_helper.run_query(session, "CALL SYSTEM$NOTEBOOKS_FIND_LABELED_RUNTIMES()")
35
+ if not rows:
36
+ return None
37
+ try:
38
+ runtime_envs = RuntimeEnvironmentsDict.model_validate_json(rows[0][0])
39
+ spcs_container_runtimes = runtime_envs.get_spcs_container_runtimes()
40
+ except Exception as e:
41
+ logging.warning(f"Failed to parse runtime image name from {rows[0][0]}, error: {e}")
42
+ return None
43
+
44
+ selected_runtime = next(
45
+ (
46
+ runtime
47
+ for runtime in spcs_container_runtimes
48
+ if (
49
+ runtime.hardware_type.lower() == target_hardware.lower()
50
+ and runtime.python_version.major == sys.version_info.major
51
+ and runtime.python_version.minor == sys.version_info.minor
52
+ )
53
+ ),
54
+ None,
55
+ )
56
+ return selected_runtime.runtime_container_image if selected_runtime else None
57
+
58
+
31
59
  def _get_image_spec(session: snowpark.Session, compute_pool: str) -> types.ImageSpec:
32
60
  # Retrieve compute pool node resources
33
61
  resources = _get_node_resources(session, compute_pool=compute_pool)
34
62
 
35
63
  # Use MLRuntime image
36
- image_repo = constants.DEFAULT_IMAGE_REPO
37
- image_name = constants.DEFAULT_IMAGE_GPU if resources.gpu > 0 else constants.DEFAULT_IMAGE_CPU
38
- image_tag = _get_runtime_image_tag()
64
+ hardware = "GPU" if resources.gpu > 0 else "CPU"
65
+ container_image = None
66
+ if feature_flags.FeatureFlags.ENABLE_IMAGE_VERSION_ENV_VAR.is_enabled():
67
+ container_image = _get_runtime_image(session, hardware) # type: ignore[arg-type]
68
+
69
+ if not container_image:
70
+ image_repo = constants.DEFAULT_IMAGE_REPO
71
+ image_name = constants.DEFAULT_IMAGE_GPU if resources.gpu > 0 else constants.DEFAULT_IMAGE_CPU
72
+ image_tag = _get_runtime_image_tag()
73
+ container_image = f"{image_repo}/{image_name}:{image_tag}"
39
74
 
40
75
  # TODO: Should each instance consume the entire pod?
41
76
  return types.ImageSpec(
42
- repo=image_repo,
43
- image_name=image_name,
44
- image_tag=image_tag,
45
77
  resource_requests=resources,
46
78
  resource_limits=resources,
79
+ container_image=container_image,
47
80
  )
48
81
 
49
82
 
@@ -65,6 +98,7 @@ def generate_spec_overrides(
65
98
  container_spec: dict[str, Any] = {
66
99
  "name": constants.DEFAULT_CONTAINER_NAME,
67
100
  }
101
+
68
102
  if environment_vars:
69
103
  # TODO: Validate environment variables
70
104
  container_spec["env"] = environment_vars
@@ -180,10 +214,7 @@ def generate_service_spec(
180
214
 
181
215
  # TODO: Add hooks for endpoints for integration with TensorBoard etc
182
216
 
183
- env_vars = {
184
- constants.PAYLOAD_DIR_ENV_VAR: constants.APP_MOUNT_PATH,
185
- constants.RESULT_PATH_ENV_VAR: constants.RESULT_PATH_DEFAULT_VALUE,
186
- }
217
+ env_vars = payload.env_vars
187
218
  endpoints: list[dict[str, Any]] = []
188
219
 
189
220
  if target_instances > 1:
@@ -220,7 +251,7 @@ def generate_service_spec(
220
251
  "containers": [
221
252
  {
222
253
  "name": constants.DEFAULT_CONTAINER_NAME,
223
- "image": image_spec.full_name,
254
+ "image": image_spec.container_image,
224
255
  "command": ["/usr/local/bin/_entrypoint.sh"],
225
256
  "args": [
226
257
  (stage_mount.joinpath(v).as_posix() if isinstance(v, PurePath) else v) for v in payload.entrypoint
@@ -121,15 +121,28 @@ class StagePath:
121
121
  return self._compose_path(self._path)
122
122
 
123
123
  def joinpath(self, *args: Union[str, PathLike[str]]) -> "StagePath":
124
+ """
125
+ Joins the given path arguments to the current path,
126
+ mimicking the behavior of pathlib.Path.joinpath.
127
+ If the argument is a stage path (i.e., an absolute path),
128
+ it overrides the current path and is returned as the final path.
129
+ If the argument is a normal path, it is joined with the current relative path
130
+ using self._path.joinpath(arg).
131
+
132
+ Args:
133
+ *args: Path components to join.
134
+
135
+ Returns:
136
+ A new StagePath with the joined path.
137
+
138
+ Raises:
139
+ NotImplementedError: the argument is a stage path.
140
+ """
124
141
  path = self
125
142
  for arg in args:
126
- path = path._make_child(arg)
143
+ if isinstance(arg, StagePath):
144
+ raise NotImplementedError
145
+ else:
146
+ # the arg might be an absolute path, so we need to remove the leading '/'
147
+ path = StagePath(f"{path.root}/{path._path.joinpath(arg).as_posix().lstrip('/')}")
127
148
  return path
128
-
129
- def _make_child(self, path: Union[str, PathLike[str]]) -> "StagePath":
130
- stage_path = path if isinstance(path, StagePath) else StagePath(os.fspath(path))
131
- if self.root == stage_path.root:
132
- child_path = self._path.joinpath(stage_path._path)
133
- return StagePath(self._compose_path(child_path))
134
- else:
135
- return stage_path
@@ -1,5 +1,5 @@
1
1
  import os
2
- from dataclasses import dataclass
2
+ from dataclasses import dataclass, field
3
3
  from pathlib import PurePath
4
4
  from typing import Iterator, Literal, Optional, Protocol, Union, runtime_checkable
5
5
 
@@ -30,6 +30,10 @@ class PayloadPath(Protocol):
30
30
  def parent(self) -> "PayloadPath":
31
31
  ...
32
32
 
33
+ @property
34
+ def root(self) -> str:
35
+ ...
36
+
33
37
  def exists(self) -> bool:
34
38
  ...
35
39
 
@@ -86,6 +90,7 @@ class UploadedPayload:
86
90
  # TODO: Include manifest of payload files for validation
87
91
  stage_path: PurePath
88
92
  entrypoint: list[Union[str, PurePath]]
93
+ env_vars: dict[str, str] = field(default_factory=dict)
89
94
 
90
95
 
91
96
  @dataclass(frozen=True)
@@ -98,12 +103,6 @@ class ComputeResources:
98
103
 
99
104
  @dataclass(frozen=True)
100
105
  class ImageSpec:
101
- repo: str
102
- image_name: str
103
- image_tag: str
104
106
  resource_requests: ComputeResources
105
107
  resource_limits: ComputeResources
106
-
107
- @property
108
- def full_name(self) -> str:
109
- return f"{self.repo}/{self.image_name}:{self.image_tag}"
108
+ container_image: str
snowflake/ml/jobs/job.py CHANGED
@@ -99,21 +99,23 @@ class MLJob(Generic[T], SerializableSessionMixin):
99
99
  result_path_str = self._container_spec["env"].get(constants.RESULT_PATH_ENV_VAR)
100
100
  if result_path_str is None:
101
101
  raise RuntimeError(f"Job {self.name} doesn't have a result path configured")
102
- volume_mounts = self._container_spec["volumeMounts"]
103
- stage_mount_str = next(v for v in volume_mounts if v.get("name") == constants.STAGE_VOLUME_NAME)["mountPath"]
104
102
 
103
+ # If result path is relative, it is relative to the stage mount path
105
104
  result_path = Path(result_path_str)
105
+ if not result_path.is_absolute():
106
+ return f"{self._stage_path}/{result_path.as_posix()}"
107
+
108
+ # If result path is absolute, it is relative to the stage mount path
109
+ volume_mounts = self._container_spec["volumeMounts"]
110
+ stage_mount_str = next(v for v in volume_mounts if v.get("name") == constants.STAGE_VOLUME_NAME)["mountPath"]
106
111
  stage_mount = Path(stage_mount_str)
107
112
  try:
108
113
  relative_path = result_path.relative_to(stage_mount)
114
+ return f"{self._stage_path}/{relative_path.as_posix()}"
109
115
  except ValueError:
110
- if result_path.is_absolute():
111
- raise ValueError(
112
- f"Result path {result_path} is absolute, but should be relative to stage mount {stage_mount}"
113
- )
114
- relative_path = result_path
115
-
116
- return f"{self._stage_path}/{relative_path.as_posix()}"
116
+ raise ValueError(
117
+ f"Result path {result_path} is absolute, but should be relative to stage mount {stage_mount}"
118
+ )
117
119
 
118
120
  @overload
119
121
  def get_logs(
@@ -199,7 +201,7 @@ class MLJob(Generic[T], SerializableSessionMixin):
199
201
  elapsed = time.monotonic() - start_time
200
202
  if elapsed >= timeout >= 0:
201
203
  raise TimeoutError(f"Job {self.name} did not complete within {timeout} seconds")
202
- elif status == "PENDING" and not warning_shown and elapsed >= 2: # Only show warning after 2s
204
+ elif status == "PENDING" and not warning_shown and elapsed >= 5: # Only show warning after 5s
203
205
  pool_info = _get_compute_pool_info(self._session, self._compute_pool)
204
206
  if (pool_info.max_nodes - pool_info.active_nodes) < self.min_instances:
205
207
  logger.warning(
@@ -419,15 +421,29 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[in
419
421
  if not rows:
420
422
  return None
421
423
 
422
- if target_instances > len(rows):
423
- raise RuntimeError("Couldn’t retrieve head instance due to missing instances.")
424
+ # we have already integrated with first_instance startup policy,
425
+ # the instance 0 is guaranteed to be the head instance
426
+ head_instance = next(
427
+ (
428
+ row
429
+ for row in rows
430
+ if "instance_id" in row and row["instance_id"] is not None and int(row["instance_id"]) == 0
431
+ ),
432
+ None,
433
+ )
434
+ # fallback to find the first instance if the instance 0 is not found
435
+ if not head_instance:
436
+ if target_instances > len(rows):
437
+ raise RuntimeError(
438
+ f"Couldn’t retrieve head instance due to missing instances. {target_instances} > {len(rows)}"
439
+ )
440
+ # Sort by start_time first, then by instance_id
441
+ try:
442
+ sorted_instances = sorted(rows, key=lambda x: (x["start_time"], int(x["instance_id"])))
443
+ except TypeError:
444
+ raise RuntimeError("Job instance information unavailable.")
445
+ head_instance = sorted_instances[0]
424
446
 
425
- # Sort by start_time first, then by instance_id
426
- try:
427
- sorted_instances = sorted(rows, key=lambda x: (x["start_time"], int(x["instance_id"])))
428
- except TypeError:
429
- raise RuntimeError("Job instance information unavailable.")
430
- head_instance = sorted_instances[0]
431
447
  if not head_instance["start_time"]:
432
448
  # If head instance hasn't started yet, return None
433
449
  return None
@@ -1,6 +1,8 @@
1
+ import json
1
2
  import logging
2
3
  import pathlib
3
4
  import textwrap
5
+ from pathlib import PurePath
4
6
  from typing import Any, Callable, Optional, TypeVar, Union, cast, overload
5
7
  from uuid import uuid4
6
8
 
@@ -11,7 +13,13 @@ from snowflake import snowpark
11
13
  from snowflake.ml._internal import telemetry
12
14
  from snowflake.ml._internal.utils import identifier
13
15
  from snowflake.ml.jobs import job as jb
14
- from snowflake.ml.jobs._utils import payload_utils, query_helper, spec_utils
16
+ from snowflake.ml.jobs._utils import (
17
+ feature_flags,
18
+ payload_utils,
19
+ query_helper,
20
+ spec_utils,
21
+ types,
22
+ )
15
23
  from snowflake.snowpark.context import get_active_session
16
24
  from snowflake.snowpark.exceptions import SnowparkSQLException
17
25
  from snowflake.snowpark.functions import coalesce, col, lit, when
@@ -426,7 +434,6 @@ def _submit_job(
426
434
 
427
435
  Raises:
428
436
  ValueError: If database or schema value(s) are invalid
429
- SnowparkSQLException: If there is an error submitting the job.
430
437
  """
431
438
  session = session or get_active_session()
432
439
 
@@ -446,7 +453,7 @@ def _submit_job(
446
453
  env_vars = kwargs.pop("env_vars", None)
447
454
  spec_overrides = kwargs.pop("spec_overrides", None)
448
455
  enable_metrics = kwargs.pop("enable_metrics", True)
449
- query_warehouse = kwargs.pop("query_warehouse", None)
456
+ query_warehouse = kwargs.pop("query_warehouse", session.get_current_warehouse())
450
457
  additional_payloads = kwargs.pop("additional_payloads", None)
451
458
 
452
459
  if additional_payloads:
@@ -484,6 +491,27 @@ def _submit_job(
484
491
  source, entrypoint=entrypoint, pip_requirements=pip_requirements, additional_payloads=additional_payloads
485
492
  ).upload(session, stage_path)
486
493
 
494
+ if feature_flags.FeatureFlags.USE_SUBMIT_JOB_V2.is_enabled():
495
+ # Add default env vars (extracted from spec_utils.generate_service_spec)
496
+ combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})}
497
+
498
+ return _do_submit_job_v2(
499
+ session=session,
500
+ payload=uploaded_payload,
501
+ args=args,
502
+ env_vars=combined_env_vars,
503
+ spec_overrides=spec_overrides,
504
+ compute_pool=compute_pool,
505
+ job_id=job_id,
506
+ external_access_integrations=external_access_integrations,
507
+ query_warehouse=query_warehouse,
508
+ target_instances=target_instances,
509
+ min_instances=min_instances,
510
+ enable_metrics=enable_metrics,
511
+ use_async=True,
512
+ )
513
+
514
+ # Fall back to v1
487
515
  # Generate service spec
488
516
  spec = spec_utils.generate_service_spec(
489
517
  session,
@@ -494,6 +522,8 @@ def _submit_job(
494
522
  min_instances=min_instances,
495
523
  enable_metrics=enable_metrics,
496
524
  )
525
+
526
+ # Generate spec overrides
497
527
  spec_overrides = spec_utils.generate_spec_overrides(
498
528
  environment_vars=env_vars,
499
529
  custom_overrides=spec_overrides,
@@ -501,37 +531,25 @@ def _submit_job(
501
531
  if spec_overrides:
502
532
  spec = spec_utils.merge_patch(spec, spec_overrides, display_name="spec_overrides")
503
533
 
504
- query_text, params = _generate_submission_query(
505
- spec, external_access_integrations, query_warehouse, target_instances, session, compute_pool, job_id
534
+ return _do_submit_job_v1(
535
+ session, spec, external_access_integrations, query_warehouse, target_instances, compute_pool, job_id
506
536
  )
507
- try:
508
- _ = query_helper.run_query(session, query_text, params=params)
509
- except SnowparkSQLException as e:
510
- if "Invalid spec: unknown option 'resourceManagement' for 'spec'." in e.message:
511
- logger.warning("Dropping 'resourceManagement' from spec because control policy is not enabled.")
512
- spec["spec"].pop("resourceManagement", None)
513
- query_text, params = _generate_submission_query(
514
- spec, external_access_integrations, query_warehouse, target_instances, session, compute_pool, job_id
515
- )
516
- _ = query_helper.run_query(session, query_text, params=params)
517
- else:
518
- raise
519
- return get_job(job_id, session=session)
520
537
 
521
538
 
522
- def _generate_submission_query(
539
+ def _do_submit_job_v1(
540
+ session: snowpark.Session,
523
541
  spec: dict[str, Any],
524
542
  external_access_integrations: list[str],
525
543
  query_warehouse: Optional[str],
526
544
  target_instances: int,
527
- session: snowpark.Session,
528
545
  compute_pool: str,
529
546
  job_id: str,
530
- ) -> tuple[str, list[Any]]:
547
+ ) -> jb.MLJob[Any]:
531
548
  """
532
549
  Generate the SQL query for job submission.
533
550
 
534
551
  Args:
552
+ session: The Snowpark session to use.
535
553
  spec: The service spec for the job.
536
554
  external_access_integrations: The external access integrations for the job.
537
555
  query_warehouse: The query warehouse for the job.
@@ -541,7 +559,7 @@ def _generate_submission_query(
541
559
  job_id: The ID of the job.
542
560
 
543
561
  Returns:
544
- A tuple containing the SQL query text and the parameters for the query.
562
+ The job object.
545
563
  """
546
564
  query_template = textwrap.dedent(
547
565
  """\
@@ -559,12 +577,77 @@ def _generate_submission_query(
559
577
  if external_access_integrations:
560
578
  external_access_integration_list = ",".join(f"{e}" for e in external_access_integrations)
561
579
  query.append(f"EXTERNAL_ACCESS_INTEGRATIONS = ({external_access_integration_list})")
562
- query_warehouse = query_warehouse or session.get_current_warehouse()
563
580
  if query_warehouse:
564
581
  query.append("QUERY_WAREHOUSE = IDENTIFIER(?)")
565
582
  params.append(query_warehouse)
566
583
  if target_instances > 1:
567
584
  query.append("REPLICAS = ?")
568
585
  params.append(target_instances)
586
+
569
587
  query_text = "\n".join(line for line in query if line)
570
- return query_text, params
588
+ _ = query_helper.run_query(session, query_text, params=params)
589
+
590
+ return get_job(job_id, session=session)
591
+
592
+
593
+ def _do_submit_job_v2(
594
+ session: snowpark.Session,
595
+ payload: types.UploadedPayload,
596
+ args: Optional[list[str]],
597
+ env_vars: dict[str, str],
598
+ spec_overrides: dict[str, Any],
599
+ compute_pool: str,
600
+ job_id: Optional[str] = None,
601
+ external_access_integrations: Optional[list[str]] = None,
602
+ query_warehouse: Optional[str] = None,
603
+ target_instances: int = 1,
604
+ min_instances: int = 1,
605
+ enable_metrics: bool = True,
606
+ use_async: bool = True,
607
+ ) -> jb.MLJob[Any]:
608
+ """
609
+ Generate the SQL query for job submission.
610
+
611
+ Args:
612
+ session: The Snowpark session to use.
613
+ payload: The uploaded job payload.
614
+ args: Arguments to pass to the entrypoint script.
615
+ env_vars: Environment variables to set in the job container.
616
+ spec_overrides: Custom service specification overrides.
617
+ compute_pool: The compute pool to use for job execution.
618
+ job_id: The ID of the job.
619
+ external_access_integrations: Optional list of external access integrations.
620
+ query_warehouse: Optional query warehouse to use.
621
+ target_instances: Number of instances for multi-node job.
622
+ min_instances: Minimum number of instances required to start the job.
623
+ enable_metrics: Whether to enable platform metrics for the job.
624
+ use_async: Whether to run the job asynchronously.
625
+
626
+ Returns:
627
+ The job object.
628
+ """
629
+ args = [
630
+ (payload.stage_path.joinpath(v).as_posix() if isinstance(v, PurePath) else v) for v in payload.entrypoint
631
+ ] + (args or [])
632
+ spec_options = {
633
+ "STAGE_PATH": payload.stage_path.as_posix(),
634
+ "ENTRYPOINT": ["/usr/local/bin/_entrypoint.sh"],
635
+ "ARGS": args,
636
+ "ENV_VARS": env_vars,
637
+ "ENABLE_METRICS": enable_metrics,
638
+ "SPEC_OVERRIDES": spec_overrides,
639
+ }
640
+ job_options = {
641
+ "EXTERNAL_ACCESS_INTEGRATIONS": external_access_integrations,
642
+ "QUERY_WAREHOUSE": query_warehouse,
643
+ "TARGET_INSTANCES": target_instances,
644
+ "MIN_INSTANCES": min_instances,
645
+ "ASYNC": use_async,
646
+ }
647
+ job_options = {k: v for k, v in job_options.items() if v is not None}
648
+
649
+ query_template = "CALL SYSTEM$EXECUTE_ML_JOB(?, ?, ?, ?)"
650
+ params = [job_id, compute_pool, json.dumps(spec_options), json.dumps(job_options)]
651
+ actual_job_id = query_helper.run_query(session, query_template, params=params)[0][0]
652
+
653
+ return get_job(actual_job_id, session=session)
@@ -1,5 +1,10 @@
1
+ from snowflake.ml.model._client.model.batch_inference_specs import (
2
+ InputSpec,
3
+ JobSpec,
4
+ OutputSpec,
5
+ )
1
6
  from snowflake.ml.model._client.model.model_impl import Model
2
7
  from snowflake.ml.model._client.model.model_version_impl import ExportMode, ModelVersion
3
8
  from snowflake.ml.model.models.huggingface_pipeline import HuggingFacePipelineModel
4
9
 
5
- __all__ = ["Model", "ModelVersion", "ExportMode", "HuggingFacePipelineModel"]
10
+ __all__ = ["Model", "ModelVersion", "ExportMode", "HuggingFacePipelineModel", "InputSpec", "JobSpec", "OutputSpec"]
@@ -0,0 +1,27 @@
1
+ from typing import Optional, Union
2
+
3
+ from pydantic import BaseModel
4
+
5
+
6
+ class InputSpec(BaseModel):
7
+ input_stage_location: str
8
+ input_file_pattern: str = "*"
9
+
10
+
11
+ class OutputSpec(BaseModel):
12
+ output_stage_location: str
13
+ output_file_prefix: Optional[str] = None
14
+ completion_filename: str = "_SUCCESS"
15
+
16
+
17
+ class JobSpec(BaseModel):
18
+ image_repo: Optional[str] = None
19
+ job_name: Optional[str] = None
20
+ num_workers: Optional[int] = None
21
+ function_name: Optional[str] = None
22
+ gpu: Optional[Union[str, int]] = None
23
+ force_rebuild: bool = False
24
+ max_batch_rows: int = 1024
25
+ warehouse: Optional[str] = None
26
+ cpu_requests: Optional[str] = None
27
+ memory_requests: Optional[str] = None