snowflake-ml-python 1.14.0__py3-none-any.whl → 1.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (193)
  1. snowflake/ml/_internal/platform_capabilities.py +13 -7
  2. snowflake/ml/_internal/utils/connection_params.py +5 -3
  3. snowflake/ml/_internal/utils/jwt_generator.py +3 -2
  4. snowflake/ml/_internal/utils/mixins.py +24 -9
  5. snowflake/ml/_internal/utils/temp_file_utils.py +1 -2
  6. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +16 -3
  7. snowflake/ml/experiment/_entities/__init__.py +2 -1
  8. snowflake/ml/experiment/_entities/run.py +0 -15
  9. snowflake/ml/experiment/_entities/run_metadata.py +3 -51
  10. snowflake/ml/experiment/experiment_tracking.py +71 -27
  11. snowflake/ml/jobs/_utils/spec_utils.py +49 -11
  12. snowflake/ml/jobs/manager.py +20 -0
  13. snowflake/ml/model/__init__.py +12 -2
  14. snowflake/ml/model/_client/model/batch_inference_specs.py +16 -4
  15. snowflake/ml/model/_client/model/inference_engine_utils.py +55 -0
  16. snowflake/ml/model/_client/model/model_version_impl.py +30 -62
  17. snowflake/ml/model/_client/ops/service_ops.py +68 -7
  18. snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
  19. snowflake/ml/model/_client/sql/service.py +29 -2
  20. snowflake/ml/model/_client/sql/stage.py +8 -0
  21. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  22. snowflake/ml/model/_model_composer/model_method/model_method.py +25 -2
  23. snowflake/ml/model/_packager/model_env/model_env.py +26 -16
  24. snowflake/ml/model/_packager/model_handlers/_utils.py +4 -2
  25. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -5
  26. snowflake/ml/model/_packager/model_packager.py +4 -3
  27. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
  28. snowflake/ml/model/_signatures/utils.py +0 -21
  29. snowflake/ml/model/models/huggingface_pipeline.py +56 -21
  30. snowflake/ml/model/type_hints.py +13 -0
  31. snowflake/ml/model/volatility.py +34 -0
  32. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
  33. snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
  34. snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
  35. snowflake/ml/modeling/cluster/birch.py +1 -1
  36. snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
  37. snowflake/ml/modeling/cluster/dbscan.py +1 -1
  38. snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
  39. snowflake/ml/modeling/cluster/k_means.py +1 -1
  40. snowflake/ml/modeling/cluster/mean_shift.py +1 -1
  41. snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
  42. snowflake/ml/modeling/cluster/optics.py +1 -1
  43. snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
  44. snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
  45. snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
  46. snowflake/ml/modeling/compose/column_transformer.py +1 -1
  47. snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
  48. snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
  49. snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
  50. snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
  51. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
  52. snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
  53. snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
  54. snowflake/ml/modeling/covariance/oas.py +1 -1
  55. snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
  56. snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
  57. snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
  58. snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
  59. snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
  60. snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
  61. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
  62. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
  63. snowflake/ml/modeling/decomposition/pca.py +1 -1
  64. snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
  65. snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
  66. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
  67. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
  68. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
  69. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
  70. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
  71. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
  72. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
  73. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
  74. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
  75. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
  76. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
  77. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
  78. snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
  79. snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
  80. snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
  81. snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
  82. snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
  83. snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
  84. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
  85. snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
  86. snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
  87. snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
  88. snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
  89. snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
  90. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
  91. snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
  92. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
  93. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
  94. snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
  95. snowflake/ml/modeling/impute/knn_imputer.py +1 -1
  96. snowflake/ml/modeling/impute/missing_indicator.py +1 -1
  97. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
  98. snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
  99. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
  100. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
  101. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
  102. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
  103. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
  104. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
  105. snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
  106. snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
  107. snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
  108. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  109. snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
  110. snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
  111. snowflake/ml/modeling/linear_model/lars.py +1 -1
  112. snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
  113. snowflake/ml/modeling/linear_model/lasso.py +1 -1
  114. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  115. snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
  116. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
  117. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
  118. snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
  119. snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
  120. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
  121. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
  122. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  123. snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
  124. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  125. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
  126. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
  127. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  128. snowflake/ml/modeling/linear_model/perceptron.py +1 -1
  129. snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
  130. snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
  131. snowflake/ml/modeling/linear_model/ridge.py +1 -1
  132. snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
  133. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
  134. snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
  135. snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
  136. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
  137. snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
  138. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
  139. snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
  140. snowflake/ml/modeling/manifold/isomap.py +1 -1
  141. snowflake/ml/modeling/manifold/mds.py +1 -1
  142. snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
  143. snowflake/ml/modeling/manifold/tsne.py +1 -1
  144. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
  145. snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
  146. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
  147. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
  148. snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
  149. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
  150. snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
  151. snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
  152. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
  153. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
  154. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
  155. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
  156. snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
  157. snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
  158. snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
  159. snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
  160. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
  161. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
  162. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  163. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
  164. snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
  165. snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
  166. snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
  167. snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
  168. snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
  169. snowflake/ml/modeling/svm/linear_svc.py +1 -1
  170. snowflake/ml/modeling/svm/linear_svr.py +1 -1
  171. snowflake/ml/modeling/svm/nu_svc.py +1 -1
  172. snowflake/ml/modeling/svm/nu_svr.py +1 -1
  173. snowflake/ml/modeling/svm/svc.py +1 -1
  174. snowflake/ml/modeling/svm/svr.py +1 -1
  175. snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
  176. snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
  177. snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
  178. snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
  179. snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
  180. snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
  181. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
  182. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
  183. snowflake/ml/registry/_manager/model_manager.py +2 -1
  184. snowflake/ml/registry/_manager/model_parameter_reconciler.py +29 -2
  185. snowflake/ml/registry/registry.py +15 -0
  186. snowflake/ml/utils/authentication.py +16 -0
  187. snowflake/ml/utils/connection_params.py +5 -3
  188. snowflake/ml/version.py +1 -1
  189. {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/METADATA +81 -36
  190. {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/RECORD +193 -191
  191. {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/WHEEL +0 -0
  192. {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/licenses/LICENSE.txt +0 -0
  193. {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/top_level.txt +0 -0

snowflake/ml/jobs/_utils/spec_utils.py

@@ -1,5 +1,6 @@
  import logging
  import os
+ import re
  import sys
  from math import ceil
  from pathlib import PurePath
@@ -10,6 +11,8 @@ from snowflake.ml._internal.utils import snowflake_env
  from snowflake.ml.jobs._utils import constants, feature_flags, query_helper, types
  from snowflake.ml.jobs._utils.runtime_env_utils import RuntimeEnvironmentsDict

+ _OCI_TAG_REGEX = re.compile("^[a-zA-Z0-9._-]{1,128}$")
+

  def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
      """Extract resource information for the specified compute pool"""
@@ -56,22 +59,55 @@ def _get_runtime_image(session: snowpark.Session, target_hardware: Literal["CPU"
      return selected_runtime.runtime_container_image if selected_runtime else None


- def _get_image_spec(session: snowpark.Session, compute_pool: str) -> types.ImageSpec:
+ def _check_image_tag_valid(tag: Optional[str]) -> bool:
+     if tag is None:
+         return False
+
+     return _OCI_TAG_REGEX.fullmatch(tag) is not None
+
+
+ def _get_image_spec(
+     session: snowpark.Session, compute_pool: str, runtime_environment: Optional[str] = None
+ ) -> types.ImageSpec:
+     """
+     Resolve image specification (container image and resources) for the job.
+
+     Behavior:
+     - If `runtime_environment` is empty or the feature flag is disabled, use the
+       default image tag and image name.
+     - If `runtime_environment` is a valid image tag, use that tag with the default
+       repository/name.
+     - If `runtime_environment` is a full image URL, use it directly.
+     - If the feature flag is enabled and `runtime_environment` is not provided,
+       select an ML Runtime image matching the local Python major.minor
+     - When multiple inputs are provided, `runtime_environment` takes priority.
+
+     Args:
+         session: Snowflake session.
+         compute_pool: Compute pool used to infer CPU/GPU resources.
+         runtime_environment: Optional image tag or full image URL to override.
+
+     Returns:
+         Image spec including container image and resource requests/limits.
+     """
      # Retrieve compute pool node resources
      resources = _get_node_resources(session, compute_pool=compute_pool)
+     hardware = "GPU" if resources.gpu > 0 else "CPU"
+     image_tag = _get_runtime_image_tag()
+     image_repo = constants.DEFAULT_IMAGE_REPO
+     image_name = constants.DEFAULT_IMAGE_GPU if resources.gpu > 0 else constants.DEFAULT_IMAGE_CPU

      # Use MLRuntime image
-     hardware = "GPU" if resources.gpu > 0 else "CPU"
      container_image = None
-     if feature_flags.FeatureFlags.ENABLE_IMAGE_VERSION_ENV_VAR.is_enabled():
+     if runtime_environment:
+         if _check_image_tag_valid(runtime_environment):
+             image_tag = runtime_environment
+         else:
+             container_image = runtime_environment
+     elif feature_flags.FeatureFlags.ENABLE_IMAGE_VERSION_ENV_VAR.is_enabled():
          container_image = _get_runtime_image(session, hardware)  # type: ignore[arg-type]

-     if not container_image:
-         image_repo = constants.DEFAULT_IMAGE_REPO
-         image_name = constants.DEFAULT_IMAGE_GPU if resources.gpu > 0 else constants.DEFAULT_IMAGE_CPU
-         image_tag = _get_runtime_image_tag()
-         container_image = f"{image_repo}/{image_name}:{image_tag}"
-
+     container_image = container_image or f"{image_repo}/{image_name}:{image_tag}"
      # TODO: Should each instance consume the entire pod?
      return types.ImageSpec(
          resource_requests=resources,
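
For orientation, a minimal standalone sketch of the new tag-vs-URL decision, reusing the `_OCI_TAG_REGEX` pattern added above (the sample values are hypothetical):

    import re
    from typing import Optional

    _OCI_TAG_REGEX = re.compile("^[a-zA-Z0-9._-]{1,128}$")

    def check_image_tag_valid(tag: Optional[str]) -> bool:
        # Mirrors _check_image_tag_valid above: a bare OCI tag never contains "/" or ":".
        if tag is None:
            return False
        return _OCI_TAG_REGEX.fullmatch(tag) is not None

    print(check_image_tag_valid("1.7.1"))               # True  -> used as the image tag
    print(check_image_tag_valid("repo/runtime:1.7.1"))  # False -> used as a full image URL
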
@@ -127,6 +163,7 @@ def generate_service_spec(
      target_instances: int = 1,
      min_instances: int = 1,
      enable_metrics: bool = False,
+     runtime_environment: Optional[str] = None,
  ) -> dict[str, Any]:
      """
      Generate a service specification for a job.
@@ -139,11 +176,12 @@
          target_instances: Number of instances for multi-node job
          enable_metrics: Enable platform metrics for the job
          min_instances: Minimum number of instances required to start the job
+         runtime_environment: The runtime image to use. Only support image tag or full image URL.

      Returns:
          Job service specification
      """
-     image_spec = _get_image_spec(session, compute_pool)
+     image_spec = _get_image_spec(session, compute_pool, runtime_environment)

      # Set resource requests/limits, including nvidia.com/gpu quantity if applicable
      resource_requests: dict[str, Union[str, int]] = {
@@ -317,7 +355,7 @@ def merge_patch(base: Any, patch: Any, display_name: str = "") -> Any:
      Returns:
          The patched object.
      """
-     if not type(base) is type(patch):
+     if type(base) is not type(patch):
          if base is not None:
              logging.warning(f"Type mismatch while merging {display_name} (base={type(base)}, patch={type(patch)})")
          return patch

snowflake/ml/jobs/manager.py

@@ -1,6 +1,7 @@
  import json
  import logging
  import pathlib
+ import sys
  import textwrap
  from pathlib import PurePath
  from typing import Any, Callable, Optional, TypeVar, Union, cast, overload
@@ -344,6 +345,9 @@ def submit_from_stage(
          query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
          spec_overrides (dict): A dictionary of overrides for the service spec.
          imports (list[Union[tuple[str, str], tuple[str]]]): A list of additional payloads used in the job.
+         runtime_environment (str): The runtime image to use. Only support image tag or full image URL,
+             e.g. "1.7.1" or "image_repo/image_name:image_tag". When it refers to a full image URL,
+             it should contain image repository, image name and image tag.

      Returns:
          An object representing the submitted job.
@@ -409,6 +413,7 @@ def _submit_job(
          "min_instances",
          "enable_metrics",
          "query_warehouse",
+         "runtime_environment",
      ],
  )
  def _submit_job(
@@ -459,6 +464,9 @@ def _submit_job(
      )
      imports = kwargs.pop("additional_payloads")

+     if "runtime_environment" in kwargs:
+         logger.warning("'runtime_environment' is in private preview since 1.15.0, do not use it in production.")
+
      # Use kwargs for less common optional parameters
      database = kwargs.pop("database", None)
      schema = kwargs.pop("schema", None)
@@ -470,6 +478,7 @@ def _submit_job(
      enable_metrics = kwargs.pop("enable_metrics", True)
      query_warehouse = kwargs.pop("query_warehouse", session.get_current_warehouse())
      imports = kwargs.pop("imports", None) or imports
+     runtime_environment = kwargs.pop("runtime_environment", None)

      # Warn if there are unknown kwargs
      if kwargs:
@@ -544,6 +553,7 @@ def _submit_job(
              min_instances=min_instances,
              enable_metrics=enable_metrics,
              use_async=True,
+             runtime_environment=runtime_environment,
          )

      # Fall back to v1
@@ -556,6 +566,7 @@
          target_instances=target_instances,
          min_instances=min_instances,
          enable_metrics=enable_metrics,
+         runtime_environment=runtime_environment,
      )

      # Generate spec overrides
@@ -639,6 +650,7 @@ def _do_submit_job_v2(
      min_instances: int = 1,
      enable_metrics: bool = True,
      use_async: bool = True,
+     runtime_environment: Optional[str] = None,
  ) -> jb.MLJob[Any]:
      """
      Generate the SQL query for job submission.
@@ -657,6 +669,7 @@ def _do_submit_job_v2(
          min_instances: Minimum number of instances required to start the job.
          enable_metrics: Whether to enable platform metrics for the job.
          use_async: Whether to run the job asynchronously.
+         runtime_environment: image tag or full image URL to use for the job.

      Returns:
          The job object.
@@ -672,6 +685,13 @@ def _do_submit_job_v2(
          "ENABLE_METRICS": enable_metrics,
          "SPEC_OVERRIDES": spec_overrides,
      }
+     # for the image tag or full image URL, we use that directly
+     if runtime_environment:
+         spec_options["RUNTIME"] = runtime_environment
+     elif feature_flags.FeatureFlags.ENABLE_IMAGE_VERSION_ENV_VAR.is_enabled():
+         # when feature flag is enabled, we get the local python version and wrap it in a dict
+         # in system function, we can know whether it is python version or image tag or full image URL through the format
+         spec_options["RUNTIME"] = json.dumps({"pythonVersion": f"{sys.version_info.major}.{sys.version_info.minor}"})
      job_options = {
          "EXTERNAL_ACCESS_INTEGRATIONS": external_access_integrations,
          "QUERY_WAREHOUSE": query_warehouse,

snowflake/ml/model/__init__.py

@@ -1,10 +1,20 @@
  from snowflake.ml.model._client.model.batch_inference_specs import (
-     InputSpec,
      JobSpec,
      OutputSpec,
+     SaveMode,
  )
  from snowflake.ml.model._client.model.model_impl import Model
  from snowflake.ml.model._client.model.model_version_impl import ExportMode, ModelVersion
  from snowflake.ml.model.models.huggingface_pipeline import HuggingFacePipelineModel
+ from snowflake.ml.model.volatility import Volatility

- __all__ = ["Model", "ModelVersion", "ExportMode", "HuggingFacePipelineModel", "InputSpec", "JobSpec", "OutputSpec"]
+ __all__ = [
+     "Model",
+     "ModelVersion",
+     "ExportMode",
+     "HuggingFacePipelineModel",
+     "JobSpec",
+     "OutputSpec",
+     "SaveMode",
+     "Volatility",
+ ]
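
In other words, InputSpec is no longer exported (batch inference now takes a Snowpark DataFrame directly, as the model_version_impl.py hunks further below show), while SaveMode and Volatility become public. Assuming snowflake-ml-python >= 1.16.0 is installed, the new surface imports as:

    from snowflake.ml.model import JobSpec, OutputSpec, SaveMode, Volatility
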

snowflake/ml/model/_client/model/batch_inference_specs.py

@@ -1,14 +1,26 @@
- from typing import Optional, Union
+ from enum import Enum
+ from typing import Optional

  from pydantic import BaseModel


- class InputSpec(BaseModel):
-     stage_location: str
+ class SaveMode(str, Enum):
+     """Save mode options for batch inference output.
+
+     Determines the behavior when files already exist in the output location.
+
+     OVERWRITE: Remove existing files and write new results.
+
+     ERROR: Raise an error if files already exist in the output location.
+     """
+
+     OVERWRITE = "overwrite"
+     ERROR = "error"


  class OutputSpec(BaseModel):
      stage_location: str
+     mode: SaveMode = SaveMode.ERROR


  class JobSpec(BaseModel):
@@ -16,10 +28,10 @@ class JobSpec(BaseModel):
      job_name: Optional[str] = None
      num_workers: Optional[int] = None
      function_name: Optional[str] = None
-     gpu: Optional[Union[str, int]] = None
      force_rebuild: bool = False
      max_batch_rows: int = 1024
      warehouse: Optional[str] = None
      cpu_requests: Optional[str] = None
      memory_requests: Optional[str] = None
+     gpu_requests: Optional[str] = None
      replicas: Optional[int] = None
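
JobSpec's loosely typed gpu field is replaced by the string-typed gpu_requests, and OutputSpec gains a SaveMode. A hedged construction example (the stage path is a placeholder; JobSpec fields not shown in this hunk are left untouched):

    from snowflake.ml.model import OutputSpec, SaveMode

    # OVERWRITE clears any existing files under the output stage before writing;
    # the default, SaveMode.ERROR, refuses to run if the location is not empty.
    output_spec = OutputSpec(
        stage_location="@my_db.my_schema.my_stage/batch_out/",  # placeholder stage path
        mode=SaveMode.OVERWRITE,
    )
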

snowflake/ml/model/_client/model/inference_engine_utils.py (new file)

@@ -0,0 +1,55 @@
+ from typing import Any, Optional, Union
+
+ from snowflake.ml.model._client.ops import service_ops
+
+
+ def _get_inference_engine_args(
+     experimental_options: Optional[dict[str, Any]],
+ ) -> Optional[service_ops.InferenceEngineArgs]:
+
+     if not experimental_options:
+         return None
+
+     if "inference_engine" not in experimental_options:
+         raise ValueError("inference_engine is required in experimental_options")
+
+     return service_ops.InferenceEngineArgs(
+         inference_engine=experimental_options["inference_engine"],
+         inference_engine_args_override=experimental_options.get("inference_engine_args_override"),
+     )
+
+
+ def _enrich_inference_engine_args(
+     inference_engine_args: service_ops.InferenceEngineArgs,
+     gpu_requests: Optional[Union[str, int]] = None,
+ ) -> Optional[service_ops.InferenceEngineArgs]:
+     """Enrich inference engine args with model path and tensor parallelism settings.
+
+     Args:
+         inference_engine_args: The original inference engine args
+         gpu_requests: The number of GPUs requested
+
+     Returns:
+         Enriched inference engine args
+
+     Raises:
+         ValueError: Invalid gpu_requests
+     """
+     if inference_engine_args.inference_engine_args_override is None:
+         inference_engine_args.inference_engine_args_override = []
+
+     gpu_count = None
+
+     # Set tensor-parallelism if gpu_requests is specified
+     if gpu_requests is not None:
+         # assert gpu_requests is a string or an integer before casting to int
+         try:
+             gpu_count = int(gpu_requests)
+             if gpu_count > 0:
+                 inference_engine_args.inference_engine_args_override.append(f"--tensor-parallel-size={gpu_count}")
+             else:
+                 raise ValueError(f"GPU count must be greater than 0, got {gpu_count}")
+         except ValueError:
+             raise ValueError(f"Invalid gpu_requests: {gpu_requests} with type {type(gpu_requests).__name__}")
+
+     return inference_engine_args
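
A hypothetical walk-through of the two helpers in this new module (the "vllm" engine name is illustrative only, not a value mandated by the module):

    from snowflake.ml.model._client.model import inference_engine_utils

    # Build the args from the caller-facing experimental_options dict.
    args = inference_engine_utils._get_inference_engine_args({"inference_engine": "vllm"})

    # With gpu_requests set, a tensor-parallelism flag is appended to the overrides.
    args = inference_engine_utils._enrich_inference_engine_args(args, gpu_requests=4)
    print(args.inference_engine_args_override)  # ['--tensor-parallel-size=4']

    # gpu_requests=0 or a non-numeric string both raise ValueError.
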

snowflake/ml/model/_client/model/model_version_impl.py

@@ -12,7 +12,10 @@ from snowflake.ml._internal import telemetry
  from snowflake.ml._internal.utils import sql_identifier
  from snowflake.ml.lineage import lineage_node
  from snowflake.ml.model import task, type_hints
- from snowflake.ml.model._client.model import batch_inference_specs
+ from snowflake.ml.model._client.model import (
+     batch_inference_specs,
+     inference_engine_utils,
+ )
  from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
  from snowflake.ml.model._model_composer import model_composer
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
@@ -22,6 +25,7 @@ from snowflake.snowpark import Session, async_job, dataframe
  _TELEMETRY_PROJECT = "MLOps"
  _TELEMETRY_SUBPROJECT = "ModelManagement"
  _BATCH_INFERENCE_JOB_ID_PREFIX = "BATCH_INFERENCE_"
+ _BATCH_INFERENCE_TEMPORARY_FOLDER = "_temporary"


  class ExportMode(enum.Enum):
@@ -547,13 +551,15 @@ class ModelVersion(lineage_node.LineageNode):
          subproject=_TELEMETRY_SUBPROJECT,
          func_params_to_log=[
              "compute_pool",
+             "output_spec",
+             "job_spec",
          ],
      )
      def _run_batch(
          self,
          *,
          compute_pool: str,
-         input_spec: batch_inference_specs.InputSpec,
+         input_spec: dataframe.DataFrame,
          output_spec: batch_inference_specs.OutputSpec,
          job_spec: Optional[batch_inference_specs.JobSpec] = None,
      ) -> jobs.MLJob[Any]:
@@ -569,6 +575,20 @@ class ModelVersion(lineage_node.LineageNode):
          if warehouse is None:
              raise ValueError("Warehouse is not set. Please set the warehouse field in the JobSpec.")

+         # use a temporary folder in the output stage to store the intermediate output from the dataframe
+         output_stage_location = output_spec.stage_location
+         if not output_stage_location.endswith("/"):
+             output_stage_location += "/"
+         input_stage_location = f"{output_stage_location}{_BATCH_INFERENCE_TEMPORARY_FOLDER}/"
+
+         self._service_ops._enforce_save_mode(output_spec.mode, output_stage_location)
+
+         try:
+             input_spec.write.copy_into_location(location=input_stage_location, file_format_type="parquet", header=True)
+         # todo: be specific about the type of errors to provide better error messages.
+         except Exception as e:
+             raise RuntimeError(f"Failed to process input_spec: {e}")
+
          if job_spec.job_name is None:
              # Same as the MLJob ID generation logic with a different prefix
              job_name = f"{_BATCH_INFERENCE_JOB_ID_PREFIX}{str(uuid.uuid4()).replace('-', '_').upper()}"
@@ -589,12 +609,13 @@ class ModelVersion(lineage_node.LineageNode):
              warehouse=sql_identifier.SqlIdentifier(warehouse),
              cpu_requests=job_spec.cpu_requests,
              memory_requests=job_spec.memory_requests,
+             gpu_requests=job_spec.gpu_requests,
              job_name=job_name,
              replicas=job_spec.replicas,
              # input and output
-             input_stage_location=input_spec.stage_location,
+             input_stage_location=input_stage_location,
              input_file_pattern="*",
-             output_stage_location=output_spec.stage_location,
+             output_stage_location=output_stage_location,
              completion_filename="_SUCCESS",
              # misc
              statement_params=statement_params,
@@ -768,60 +789,6 @@ class ModelVersion(lineage_node.LineageNode):
              version_name=sql_identifier.SqlIdentifier(version),
          )

-     def _get_inference_engine_args(
-         self, experimental_options: Optional[dict[str, Any]]
-     ) -> Optional[service_ops.InferenceEngineArgs]:
-
-         if not experimental_options:
-             return None
-
-         if "inference_engine" not in experimental_options:
-             raise ValueError("inference_engine is required in experimental_options")
-
-         return service_ops.InferenceEngineArgs(
-             inference_engine=experimental_options["inference_engine"],
-             inference_engine_args_override=experimental_options.get("inference_engine_args_override"),
-         )
-
-     def _enrich_inference_engine_args(
-         self,
-         inference_engine_args: service_ops.InferenceEngineArgs,
-         gpu_requests: Optional[Union[str, int]] = None,
-     ) -> Optional[service_ops.InferenceEngineArgs]:
-         """Enrich inference engine args with tensor parallelism settings.
-
-         Args:
-             inference_engine_args: The original inference engine args
-             gpu_requests: The number of GPUs requested
-
-         Returns:
-             Enriched inference engine args
-
-         Raises:
-             ValueError: Invalid gpu_requests
-         """
-         if inference_engine_args.inference_engine_args_override is None:
-             inference_engine_args.inference_engine_args_override = []
-
-         gpu_count = None
-
-         # Set tensor-parallelism if gpu_requests is specified
-         if gpu_requests is not None:
-             # assert gpu_requests is a string or an integer before casting to int
-             if isinstance(gpu_requests, str) or isinstance(gpu_requests, int):
-                 try:
-                     gpu_count = int(gpu_requests)
-                 except ValueError:
-                     raise ValueError(f"Invalid gpu_requests: {gpu_requests}")
-
-             if gpu_count is not None:
-                 if gpu_count > 0:
-                     inference_engine_args.inference_engine_args_override.append(f"--tensor-parallel-size={gpu_count}")
-                 else:
-                     raise ValueError(f"Invalid gpu_requests: {gpu_requests}")
-
-         return inference_engine_args
-
      def _check_huggingface_text_generation_model(
          self,
          statement_params: Optional[dict[str, Any]] = None,
@@ -1101,13 +1068,14 @@ class ModelVersion(lineage_node.LineageNode):
          if experimental_options:
              self._check_huggingface_text_generation_model(statement_params)

-             inference_engine_args: Optional[service_ops.InferenceEngineArgs] = self._get_inference_engine_args(
-                 experimental_options
-             )
+             inference_engine_args = inference_engine_utils._get_inference_engine_args(experimental_options)

              # Enrich inference engine args if inference engine is specified
              if inference_engine_args is not None:
-                 inference_engine_args = self._enrich_inference_engine_args(inference_engine_args, gpu_requests)
+                 inference_engine_args = inference_engine_utils._enrich_inference_engine_args(
+                     inference_engine_args,
+                     gpu_requests,
+                 )

          from snowflake.ml.model import event_handler
          from snowflake.snowpark import exceptions
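
A standalone restatement of the staging logic added to _run_batch above: the input DataFrame is written as parquet into a "_temporary/" folder inside the caller's output stage, and that folder becomes the job's input location (the stage path below is a placeholder):

    _BATCH_INFERENCE_TEMPORARY_FOLDER = "_temporary"

    def derive_stage_locations(output_stage_location: str) -> tuple[str, str]:
        # Normalize the output path and carve out the temporary input folder under it.
        if not output_stage_location.endswith("/"):
            output_stage_location += "/"
        input_stage_location = f"{output_stage_location}{_BATCH_INFERENCE_TEMPORARY_FOLDER}/"
        return input_stage_location, output_stage_location

    print(derive_stage_locations("@my_stage/results"))
    # ('@my_stage/results/_temporary/', '@my_stage/results/')
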

snowflake/ml/model/_client/ops/service_ops.py

@@ -7,6 +7,7 @@ import re
  import tempfile
  import threading
  import time
+ import warnings
  from typing import Any, Optional, Union, cast

  from snowflake import snowpark
@@ -14,6 +15,7 @@ from snowflake.ml import jobs
  from snowflake.ml._internal import file_utils, platform_capabilities as pc
  from snowflake.ml._internal.utils import identifier, service_logger, sql_identifier
  from snowflake.ml.model import inference_engine as inference_engine_module, type_hints
+ from snowflake.ml.model._client.model import batch_inference_specs
  from snowflake.ml.model._client.service import model_deployment_spec
  from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
  from snowflake.snowpark import async_job, exceptions, row, session
@@ -155,16 +157,17 @@ class ServiceOperator:
              database_name=database_name,
              schema_name=schema_name,
          )
-         if pc.PlatformCapabilities.get_instance().is_inlined_deployment_spec_enabled():
+         self._stage_client = stage_sql.StageSQLClient(
+             session,
+             database_name=database_name,
+             schema_name=schema_name,
+         )
+         self._use_inlined_deployment_spec = pc.PlatformCapabilities.get_instance().is_inlined_deployment_spec_enabled()
+         if self._use_inlined_deployment_spec:
              self._workspace = None
              self._model_deployment_spec = model_deployment_spec.ModelDeploymentSpec()
          else:
              self._workspace = tempfile.TemporaryDirectory()
-             self._stage_client = stage_sql.StageSQLClient(
-                 session,
-                 database_name=database_name,
-                 schema_name=schema_name,
-             )
              self._model_deployment_spec = model_deployment_spec.ModelDeploymentSpec(
                  workspace_path=pathlib.Path(self._workspace.name)
              )
@@ -264,7 +267,14 @@
              self._model_deployment_spec.add_hf_logger_spec(
                  hf_model_name=hf_model_args.hf_model_name,
                  hf_task=hf_model_args.hf_task,
-                 hf_token=hf_model_args.hf_token,
+                 hf_token=(
+                     # when using inlined deployment spec, we need to use QMARK_RESERVED_TOKEN
+                     # to avoid revealing the token while calling the SYSTEM$DEPLOY_MODEL function
+                     # noop if using file-based deployment spec or token is not provided
+                     service_sql.QMARK_RESERVED_TOKEN
+                     if hf_model_args.hf_token and self._use_inlined_deployment_spec
+                     else hf_model_args.hf_token
+                 ),
                  hf_tokenizer=hf_model_args.hf_tokenizer,
                  hf_revision=hf_model_args.hf_revision,
                  hf_trust_remote_code=hf_model_args.hf_trust_remote_code,
@@ -320,6 +330,14 @@
                  model_deployment_spec.ModelDeploymentSpec.DEPLOY_SPEC_FILE_REL_PATH if self._workspace else None
              ),
              model_deployment_spec_yaml_str=None if self._workspace else spec_yaml_str_or_path,
+             query_params=(
+                 # when using inlined deployment spec, we need to add the token to the query params
+                 # to avoid revealing the token while calling the SYSTEM$DEPLOY_MODEL function
+                 # noop if using file-based deployment spec or token is not provided
+                 [hf_model_args.hf_token]
+                 if (self._use_inlined_deployment_spec and hf_model_args and hf_model_args.hf_token)
+                 else []
+             ),
              statement_params=statement_params,
          )

@@ -635,6 +653,47 @@
          else:
              module_logger.warning(f"Service {service.display_service_name} is done, but not transitioning.")

+     def _enforce_save_mode(self, output_mode: batch_inference_specs.SaveMode, output_stage_location: str) -> None:
+         """Enforce the save mode for the output stage location.
+
+         Args:
+             output_mode: The output mode
+             output_stage_location: The output stage location to check/clean.
+
+         Raises:
+             FileExistsError: When ERROR mode is specified and files exist in the output location.
+             RuntimeError: When operations fail (checking files or removing files).
+             ValueError: When an invalid SaveMode is specified.
+         """
+         list_results = self._stage_client.list_stage(output_stage_location)
+
+         if output_mode == batch_inference_specs.SaveMode.ERROR:
+             if len(list_results) > 0:
+                 raise FileExistsError(
+                     f"Output stage location '{output_stage_location}' is not empty. "
+                     f"Found {len(list_results)} existing files. When using ERROR mode, the output location "
+                     f"must be empty. Please clear the existing files or use OVERWRITE mode."
+                 )
+         elif output_mode == batch_inference_specs.SaveMode.OVERWRITE:
+             if len(list_results) > 0:
+                 warnings.warn(
+                     f"Output stage location '{output_stage_location}' is not empty. "
+                     f"Found {len(list_results)} existing files. OVERWRITE mode will remove all existing files "
+                     f"in the output location before running the batch inference job.",
+                     stacklevel=2,
+                 )
+             try:
+                 self._session.sql(f"REMOVE {output_stage_location}").collect()
+             except Exception as e:
+                 raise RuntimeError(
+                     f"OVERWRITE was specified. However, failed to remove existing files in output stage "
+                     f"{output_stage_location}: {e}. Please clear up the existing files manually and retry "
+                     f"the operation."
+                 )
+         else:
+             valid_modes = list(batch_inference_specs.SaveMode)
+             raise ValueError(f"Invalid SaveMode: {output_mode}. Must be one of {valid_modes}")
+
      def _stream_service_logs(
          self,
          async_job: snowpark.AsyncJob,
@@ -911,6 +970,7 @@
          max_batch_rows: Optional[int],
          cpu_requests: Optional[str],
          memory_requests: Optional[str],
+         gpu_requests: Optional[str],
          replicas: Optional[int],
          statement_params: Optional[dict[str, Any]] = None,
      ) -> jobs.MLJob[Any]:
@@ -945,6 +1005,7 @@
              warehouse=warehouse,
              cpu=cpu_requests,
              memory=memory_requests,
+             gpu=gpu_requests,
              replicas=replicas,
          )


snowflake/ml/model/_client/service/model_deployment_spec.py

@@ -204,7 +204,7 @@ class ModelDeploymentSpec:
          job_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
          cpu: Optional[str] = None,
          memory: Optional[str] = None,
-         gpu: Optional[Union[str, int]] = None,
+         gpu: Optional[str] = None,
          num_workers: Optional[int] = None,
          max_batch_rows: Optional[int] = None,
          replicas: Optional[int] = None,