snowflake-ml-python 1.10.0__py3-none-any.whl → 1.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (205)
  1. snowflake/cortex/_complete.py +3 -2
  2. snowflake/ml/_internal/utils/service_logger.py +26 -1
  3. snowflake/ml/experiment/_client/artifact.py +76 -0
  4. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +64 -1
  5. snowflake/ml/experiment/callback/keras.py +63 -0
  6. snowflake/ml/experiment/callback/lightgbm.py +5 -1
  7. snowflake/ml/experiment/callback/xgboost.py +5 -1
  8. snowflake/ml/experiment/experiment_tracking.py +89 -4
  9. snowflake/ml/feature_store/feature_store.py +1150 -131
  10. snowflake/ml/feature_store/feature_view.py +122 -0
  11. snowflake/ml/jobs/_utils/__init__.py +0 -0
  12. snowflake/ml/jobs/_utils/constants.py +9 -14
  13. snowflake/ml/jobs/_utils/feature_flags.py +16 -0
  14. snowflake/ml/jobs/_utils/payload_utils.py +61 -19
  15. snowflake/ml/jobs/_utils/query_helper.py +5 -1
  16. snowflake/ml/jobs/_utils/runtime_env_utils.py +63 -0
  17. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +18 -7
  18. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +15 -7
  19. snowflake/ml/jobs/_utils/spec_utils.py +44 -13
  20. snowflake/ml/jobs/_utils/stage_utils.py +22 -9
  21. snowflake/ml/jobs/_utils/types.py +7 -8
  22. snowflake/ml/jobs/job.py +34 -18
  23. snowflake/ml/jobs/manager.py +107 -24
  24. snowflake/ml/model/__init__.py +6 -1
  25. snowflake/ml/model/_client/model/batch_inference_specs.py +27 -0
  26. snowflake/ml/model/_client/model/model_version_impl.py +225 -73
  27. snowflake/ml/model/_client/ops/service_ops.py +128 -174
  28. snowflake/ml/model/_client/service/model_deployment_spec.py +123 -64
  29. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +25 -9
  30. snowflake/ml/model/_model_composer/model_composer.py +1 -70
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +2 -43
  32. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +207 -2
  33. snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -1
  34. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -3
  35. snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
  36. snowflake/ml/model/_signatures/utils.py +4 -2
  37. snowflake/ml/model/inference_engine.py +5 -0
  38. snowflake/ml/model/models/huggingface_pipeline.py +4 -3
  39. snowflake/ml/model/openai_signatures.py +57 -0
  40. snowflake/ml/modeling/_internal/estimator_utils.py +43 -1
  41. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +14 -3
  42. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +17 -6
  43. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
  44. snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
  45. snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
  46. snowflake/ml/modeling/cluster/birch.py +1 -1
  47. snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
  48. snowflake/ml/modeling/cluster/dbscan.py +1 -1
  49. snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
  50. snowflake/ml/modeling/cluster/k_means.py +1 -1
  51. snowflake/ml/modeling/cluster/mean_shift.py +1 -1
  52. snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
  53. snowflake/ml/modeling/cluster/optics.py +1 -1
  54. snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
  55. snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
  56. snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
  57. snowflake/ml/modeling/compose/column_transformer.py +1 -1
  58. snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
  59. snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
  60. snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
  61. snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
  62. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
  63. snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
  64. snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
  65. snowflake/ml/modeling/covariance/oas.py +1 -1
  66. snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
  67. snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
  68. snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
  69. snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
  70. snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
  71. snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
  72. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
  73. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
  74. snowflake/ml/modeling/decomposition/pca.py +1 -1
  75. snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
  76. snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
  77. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
  78. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
  79. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
  80. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
  81. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
  82. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
  83. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
  84. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
  85. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
  86. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
  88. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
  89. snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
  90. snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
  91. snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
  92. snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
  93. snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
  94. snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
  95. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
  96. snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
  97. snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
  98. snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
  99. snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
  100. snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
  101. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
  102. snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
  105. snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
  106. snowflake/ml/modeling/impute/knn_imputer.py +1 -1
  107. snowflake/ml/modeling/impute/missing_indicator.py +1 -1
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
  116. snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
  118. snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
  122. snowflake/ml/modeling/linear_model/lars.py +1 -1
  123. snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
  124. snowflake/ml/modeling/linear_model/lasso.py +1 -1
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
  129. snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  139. snowflake/ml/modeling/linear_model/perceptron.py +1 -1
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
  142. snowflake/ml/modeling/linear_model/ridge.py +1 -1
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
  151. snowflake/ml/modeling/manifold/isomap.py +1 -1
  152. snowflake/ml/modeling/manifold/mds.py +1 -1
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
  154. snowflake/ml/modeling/manifold/tsne.py +1 -1
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
  157. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
  158. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
  159. snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
  160. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
  161. snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
  162. snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
  163. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
  164. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
  165. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
  166. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
  167. snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
  168. snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
  169. snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
  170. snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
  171. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
  172. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
  173. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  174. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
  175. snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
  176. snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
  177. snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
  178. snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
  179. snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
  180. snowflake/ml/modeling/svm/linear_svc.py +1 -1
  181. snowflake/ml/modeling/svm/linear_svr.py +1 -1
  182. snowflake/ml/modeling/svm/nu_svc.py +1 -1
  183. snowflake/ml/modeling/svm/nu_svr.py +1 -1
  184. snowflake/ml/modeling/svm/svc.py +1 -1
  185. snowflake/ml/modeling/svm/svr.py +1 -1
  186. snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
  187. snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
  188. snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
  189. snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
  190. snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
  191. snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
  192. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
  193. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
  194. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +91 -6
  195. snowflake/ml/monitoring/_manager/model_monitor_manager.py +3 -0
  196. snowflake/ml/monitoring/entities/model_monitor_config.py +3 -0
  197. snowflake/ml/monitoring/model_monitor.py +26 -0
  198. snowflake/ml/registry/_manager/model_manager.py +7 -35
  199. snowflake/ml/registry/_manager/model_parameter_reconciler.py +194 -5
  200. snowflake/ml/version.py +1 -1
  201. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/METADATA +87 -7
  202. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/RECORD +205 -197
  203. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/WHEEL +0 -0
  204. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/licenses/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/top_level.txt +0 -0
@@ -10,13 +10,13 @@ import time
10
10
  from typing import Any, Optional, Union, cast
11
11
 
12
12
  from snowflake import snowpark
13
+ from snowflake.ml import jobs
13
14
  from snowflake.ml._internal import file_utils, platform_capabilities as pc
14
15
  from snowflake.ml._internal.utils import identifier, service_logger, sql_identifier
15
- from snowflake.ml.model import model_signature, type_hints
16
+ from snowflake.ml.model import inference_engine as inference_engine_module, type_hints
16
17
  from snowflake.ml.model._client.service import model_deployment_spec
17
18
  from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
18
- from snowflake.ml.model._signatures import snowpark_handler
19
- from snowflake.snowpark import async_job, dataframe, exceptions, row, session
19
+ from snowflake.snowpark import async_job, exceptions, row, session
20
20
  from snowflake.snowpark._internal import utils as snowpark_utils
21
21
 
22
22
  module_logger = service_logger.get_logger(__name__, service_logger.LogColor.GREY)
@@ -131,6 +131,12 @@ class HFModelArgs:
131
131
  warehouse: Optional[str] = None
132
132
 
133
133
 
134
+ @dataclasses.dataclass
135
+ class InferenceEngineArgs:
136
+ inference_engine: inference_engine_module.InferenceEngine
137
+ inference_engine_args_override: Optional[list[str]] = None
138
+
139
+
134
140
  class ServiceOperator:
135
141
  """Service operator for container services logic."""
136
142
 
@@ -180,7 +186,7 @@ class ServiceOperator:
180
186
  service_name: sql_identifier.SqlIdentifier,
181
187
  image_build_compute_pool_name: sql_identifier.SqlIdentifier,
182
188
  service_compute_pool_name: sql_identifier.SqlIdentifier,
183
- image_repo: str,
189
+ image_repo_name: Optional[str],
184
190
  ingress_enabled: bool,
185
191
  max_instances: int,
186
192
  cpu_requests: Optional[str],
@@ -195,6 +201,8 @@ class ServiceOperator:
195
201
  statement_params: Optional[dict[str, Any]] = None,
196
202
  # hf model
197
203
  hf_model_args: Optional[HFModelArgs] = None,
204
+ # inference engine model
205
+ inference_engine_args: Optional[InferenceEngineArgs] = None,
198
206
  ) -> Union[str, async_job.AsyncJob]:
199
207
 
200
208
  # Generate operation ID for this deployment
@@ -205,15 +213,14 @@ class ServiceOperator:
205
213
  schema_name = schema_name or self._schema_name
206
214
 
207
215
  # Fall back to the model's database and schema if not provided then to the registry's database and schema
208
- service_database_name = service_database_name or database_name or self._database_name
209
- service_schema_name = service_schema_name or schema_name or self._schema_name
216
+ service_database_name = service_database_name or database_name
217
+ service_schema_name = service_schema_name or schema_name
210
218
 
211
- # Parse image repo
212
- image_repo_database_name, image_repo_schema_name, image_repo_name = sql_identifier.parse_fully_qualified_name(
213
- image_repo
214
- )
215
- image_repo_database_name = image_repo_database_name or database_name or self._database_name
216
- image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name
219
+ image_repo_fqn = ServiceOperator._get_image_repo_fqn(image_repo_name, database_name, schema_name)
220
+
221
+ # There may be more conditions to enable image build in the future
222
+ # For now, we only enable image build if inference engine is not specified
223
+ is_enable_image_build = inference_engine_args is None
217
224
 
218
225
  # Step 1: Preparing deployment artifacts
219
226
  progress_status.update("preparing deployment artifacts...")
@@ -230,14 +237,15 @@ class ServiceOperator:
230
237
  model_name=model_name,
231
238
  version_name=version_name,
232
239
  )
233
- self._model_deployment_spec.add_image_build_spec(
234
- image_build_compute_pool_name=image_build_compute_pool_name,
235
- image_repo_database_name=image_repo_database_name,
236
- image_repo_schema_name=image_repo_schema_name,
237
- image_repo_name=image_repo_name,
238
- force_rebuild=force_rebuild,
239
- external_access_integrations=build_external_access_integrations,
240
- )
240
+
241
+ if is_enable_image_build:
242
+ self._model_deployment_spec.add_image_build_spec(
243
+ image_build_compute_pool_name=image_build_compute_pool_name,
244
+ fully_qualified_image_repo_name=image_repo_fqn,
245
+ force_rebuild=force_rebuild,
246
+ external_access_integrations=build_external_access_integrations,
247
+ )
248
+
241
249
  self._model_deployment_spec.add_service_spec(
242
250
  service_database_name=service_database_name,
243
251
  service_schema_name=service_schema_name,
@@ -266,6 +274,13 @@ class ServiceOperator:
266
274
  warehouse=hf_model_args.warehouse,
267
275
  **(hf_model_args.hf_model_kwargs if hf_model_args.hf_model_kwargs else {}),
268
276
  )
277
+
278
+ if inference_engine_args:
279
+ self._model_deployment_spec.add_inference_engine_spec(
280
+ inference_engine=inference_engine_args.inference_engine,
281
+ inference_engine_args=inference_engine_args.inference_engine_args_override,
282
+ )
283
+
269
284
  spec_yaml_str_or_path = self._model_deployment_spec.save()
270
285
 
271
286
  # Step 2: Uploading deployment artifacts
@@ -412,6 +427,29 @@ class ServiceOperator:
412
427
 
413
428
  return async_job
414
429
 
430
+ @staticmethod
431
+ def _get_image_repo_fqn(
432
+ image_repo_name: Optional[str],
433
+ database_name: sql_identifier.SqlIdentifier,
434
+ schema_name: sql_identifier.SqlIdentifier,
435
+ ) -> Optional[str]:
436
+ """Get the fully qualified name of the image repository."""
437
+ if image_repo_name is None or image_repo_name.strip() == "":
438
+ return None
439
+ # Parse image repo
440
+ (
441
+ image_repo_database_name,
442
+ image_repo_schema_name,
443
+ image_repo_name,
444
+ ) = sql_identifier.parse_fully_qualified_name(image_repo_name)
445
+ image_repo_database_name = image_repo_database_name or database_name
446
+ image_repo_schema_name = image_repo_schema_name or schema_name
447
+ return identifier.get_schema_level_object_identifier(
448
+ db=image_repo_database_name.identifier(),
449
+ schema=image_repo_schema_name.identifier(),
450
+ object_name=image_repo_name.identifier(),
451
+ )
452
+
415
453
  def _start_service_log_streaming(
416
454
  self,
417
455
  async_job: snowpark.AsyncJob,
@@ -824,181 +862,97 @@ class ServiceOperator:
824
862
  except exceptions.SnowparkSQLException:
825
863
  return False
826
864
 
827
- def invoke_job_method(
865
+ def invoke_batch_job_method(
828
866
  self,
829
- target_method: str,
830
- signature: model_signature.ModelSignature,
831
- X: Union[type_hints.SupportedDataType, dataframe.DataFrame],
832
- database_name: Optional[sql_identifier.SqlIdentifier],
833
- schema_name: Optional[sql_identifier.SqlIdentifier],
867
+ *,
868
+ function_name: str,
834
869
  model_name: sql_identifier.SqlIdentifier,
835
870
  version_name: sql_identifier.SqlIdentifier,
836
- job_database_name: Optional[sql_identifier.SqlIdentifier],
837
- job_schema_name: Optional[sql_identifier.SqlIdentifier],
838
- job_name: sql_identifier.SqlIdentifier,
871
+ job_name: str,
839
872
  compute_pool_name: sql_identifier.SqlIdentifier,
840
- warehouse_name: sql_identifier.SqlIdentifier,
841
- image_repo: str,
842
- output_table_database_name: Optional[sql_identifier.SqlIdentifier],
843
- output_table_schema_name: Optional[sql_identifier.SqlIdentifier],
844
- output_table_name: sql_identifier.SqlIdentifier,
845
- cpu_requests: Optional[str],
846
- memory_requests: Optional[str],
847
- gpu_requests: Optional[Union[int, str]],
873
+ warehouse: sql_identifier.SqlIdentifier,
874
+ image_repo_name: Optional[str],
875
+ input_stage_location: str,
876
+ input_file_pattern: str,
877
+ output_stage_location: str,
878
+ completion_filename: str,
879
+ force_rebuild: bool,
848
880
  num_workers: Optional[int],
849
881
  max_batch_rows: Optional[int],
850
- force_rebuild: bool,
851
- build_external_access_integrations: Optional[list[sql_identifier.SqlIdentifier]],
882
+ cpu_requests: Optional[str],
883
+ memory_requests: Optional[str],
852
884
  statement_params: Optional[dict[str, Any]] = None,
853
- ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
854
- # fall back to the registry's database and schema if not provided
855
- database_name = database_name or self._database_name
856
- schema_name = schema_name or self._schema_name
885
+ ) -> jobs.MLJob[Any]:
886
+ database_name = self._database_name
887
+ schema_name = self._schema_name
857
888
 
858
- # fall back to the model's database and schema if not provided then to the registry's database and schema
859
- job_database_name = job_database_name or database_name or self._database_name
860
- job_schema_name = job_schema_name or schema_name or self._schema_name
889
+ job_database_name, job_schema_name, job_name = sql_identifier.parse_fully_qualified_name(job_name)
890
+ job_database_name = job_database_name or database_name
891
+ job_schema_name = job_schema_name or schema_name
861
892
 
862
- # Parse image repo
863
- image_repo_database_name, image_repo_schema_name, image_repo_name = sql_identifier.parse_fully_qualified_name(
864
- image_repo
893
+ self._model_deployment_spec.clear()
894
+
895
+ self._model_deployment_spec.add_model_spec(
896
+ database_name=database_name,
897
+ schema_name=schema_name,
898
+ model_name=model_name,
899
+ version_name=version_name,
865
900
  )
866
- image_repo_database_name = image_repo_database_name or database_name or self._database_name
867
- image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name
868
901
 
869
- input_table_database_name = job_database_name
870
- input_table_schema_name = job_schema_name
871
- output_table_database_name = output_table_database_name or database_name or self._database_name
872
- output_table_schema_name = output_table_schema_name or schema_name or self._schema_name
902
+ self._model_deployment_spec.add_job_spec(
903
+ job_database_name=job_database_name,
904
+ job_schema_name=job_schema_name,
905
+ job_name=job_name,
906
+ inference_compute_pool_name=compute_pool_name,
907
+ num_workers=num_workers,
908
+ max_batch_rows=max_batch_rows,
909
+ input_stage_location=input_stage_location,
910
+ input_file_pattern=input_file_pattern,
911
+ output_stage_location=output_stage_location,
912
+ completion_filename=completion_filename,
913
+ function_name=function_name,
914
+ warehouse=warehouse,
915
+ cpu=cpu_requests,
916
+ memory=memory_requests,
917
+ )
918
+
919
+ self._model_deployment_spec.add_image_build_spec(
920
+ image_build_compute_pool_name=compute_pool_name,
921
+ fully_qualified_image_repo_name=self._get_image_repo_fqn(image_repo_name, database_name, schema_name),
922
+ force_rebuild=force_rebuild,
923
+ )
924
+
925
+ spec_yaml_str_or_path = self._model_deployment_spec.save()
873
926
 
874
927
  if self._workspace:
928
+ module_logger.info("using workspace")
875
929
  stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
930
+ file_utils.upload_directory_to_stage(
931
+ self._session,
932
+ local_path=pathlib.Path(self._workspace.name),
933
+ stage_path=pathlib.PurePosixPath(stage_path),
934
+ statement_params=statement_params,
935
+ )
876
936
  else:
937
+ module_logger.info("not using workspace")
877
938
  stage_path = None
878
939
 
879
- # validate and prepare input
880
- if not isinstance(X, dataframe.DataFrame):
881
- keep_order = True
882
- output_with_input_features = False
883
- df = model_signature._convert_and_validate_local_data(X, signature.inputs)
884
- s_df = snowpark_handler.SnowparkDataFrameHandler.convert_from_df(
885
- self._session, df, keep_order=keep_order, features=signature.inputs, statement_params=statement_params
886
- )
887
- else:
888
- keep_order = False
889
- output_with_input_features = True
890
- s_df = X
891
-
892
- # only write the index and feature input columns
893
- cols = [snowpark_handler._KEEP_ORDER_COL_NAME] if snowpark_handler._KEEP_ORDER_COL_NAME in s_df.columns else []
894
- cols += [
895
- sql_identifier.SqlIdentifier(feature.name, case_sensitive=True).identifier() for feature in signature.inputs
896
- ]
897
- s_df = s_df.select(cols)
898
- original_cols = s_df.columns
899
-
900
- # input/output tables
901
- fq_output_table_name = identifier.get_schema_level_object_identifier(
902
- output_table_database_name.identifier(),
903
- output_table_schema_name.identifier(),
904
- output_table_name.identifier(),
905
- )
906
- tmp_input_table_id = sql_identifier.SqlIdentifier(
907
- snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
908
- )
909
- fq_tmp_input_table_name = identifier.get_schema_level_object_identifier(
910
- job_database_name.identifier(),
911
- job_schema_name.identifier(),
912
- tmp_input_table_id.identifier(),
913
- )
914
- s_df.write.save_as_table(
915
- table_name=fq_tmp_input_table_name,
916
- mode="errorifexists",
940
+ _, async_job = self._service_client.deploy_model(
941
+ stage_path=stage_path if self._workspace else None,
942
+ model_deployment_spec_file_rel_path=(
943
+ model_deployment_spec.ModelDeploymentSpec.DEPLOY_SPEC_FILE_REL_PATH if self._workspace else None
944
+ ),
945
+ model_deployment_spec_yaml_str=None if self._workspace else spec_yaml_str_or_path,
917
946
  statement_params=statement_params,
918
947
  )
919
948
 
920
- try:
921
- self._model_deployment_spec.clear()
922
- # save the spec
923
- self._model_deployment_spec.add_model_spec(
924
- database_name=database_name,
925
- schema_name=schema_name,
926
- model_name=model_name,
927
- version_name=version_name,
928
- )
929
- self._model_deployment_spec.add_job_spec(
930
- job_database_name=job_database_name,
931
- job_schema_name=job_schema_name,
932
- job_name=job_name,
933
- inference_compute_pool_name=compute_pool_name,
934
- cpu=cpu_requests,
935
- memory=memory_requests,
936
- gpu=gpu_requests,
937
- num_workers=num_workers,
938
- max_batch_rows=max_batch_rows,
939
- warehouse=warehouse_name,
940
- target_method=target_method,
941
- input_table_database_name=input_table_database_name,
942
- input_table_schema_name=input_table_schema_name,
943
- input_table_name=tmp_input_table_id,
944
- output_table_database_name=output_table_database_name,
945
- output_table_schema_name=output_table_schema_name,
946
- output_table_name=output_table_name,
947
- )
948
-
949
- self._model_deployment_spec.add_image_build_spec(
950
- image_build_compute_pool_name=compute_pool_name,
951
- image_repo_database_name=image_repo_database_name,
952
- image_repo_schema_name=image_repo_schema_name,
953
- image_repo_name=image_repo_name,
954
- force_rebuild=force_rebuild,
955
- external_access_integrations=build_external_access_integrations,
956
- )
949
+ # Block until the async job is done
950
+ async_job.result()
957
951
 
958
- spec_yaml_str_or_path = self._model_deployment_spec.save()
959
- if self._workspace:
960
- assert stage_path is not None
961
- file_utils.upload_directory_to_stage(
962
- self._session,
963
- local_path=pathlib.Path(self._workspace.name),
964
- stage_path=pathlib.PurePosixPath(stage_path),
965
- statement_params=statement_params,
966
- )
967
-
968
- # deploy the job
969
- query_id, async_job = self._service_client.deploy_model(
970
- stage_path=stage_path if self._workspace else None,
971
- model_deployment_spec_file_rel_path=(
972
- model_deployment_spec.ModelDeploymentSpec.DEPLOY_SPEC_FILE_REL_PATH if self._workspace else None
973
- ),
974
- model_deployment_spec_yaml_str=None if self._workspace else spec_yaml_str_or_path,
975
- statement_params=statement_params,
976
- )
977
-
978
- while not async_job.is_done():
979
- time.sleep(5)
980
- finally:
981
- self._session.table(fq_tmp_input_table_name).drop_table()
982
-
983
- # handle the output
984
- df_res = self._session.table(fq_output_table_name)
985
- if keep_order:
986
- df_res = df_res.sort(
987
- snowpark_handler._KEEP_ORDER_COL_NAME,
988
- ascending=True,
989
- )
990
- df_res = df_res.drop(snowpark_handler._KEEP_ORDER_COL_NAME)
991
-
992
- if not output_with_input_features:
993
- df_res = df_res.drop(*original_cols)
994
-
995
- # get final result
996
- if not isinstance(X, dataframe.DataFrame):
997
- return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(
998
- df_res, features=signature.outputs, statement_params=statement_params
999
- )
1000
- else:
1001
- return df_res
952
+ return jobs.MLJob(
953
+ id=sql_identifier.get_fully_qualified_name(job_database_name, job_schema_name, job_name),
954
+ session=self._session,
955
+ )
1002
956
 
1003
957
  def _create_temp_stage(
1004
958
  self,
@@ -1,10 +1,12 @@
1
1
  import json
2
2
  import pathlib
3
+ import warnings
3
4
  from typing import Any, Optional, Union
4
5
 
5
6
  import yaml
6
7
 
7
8
  from snowflake.ml._internal.utils import identifier, sql_identifier
9
+ from snowflake.ml.model import inference_engine as inference_engine_module
8
10
  from snowflake.ml.model._client.service import model_deployment_spec_schema
9
11
 
10
12
 
@@ -24,6 +26,8 @@ class ModelDeploymentSpec:
24
26
  self._service: Optional[model_deployment_spec_schema.Service] = None
25
27
  self._job: Optional[model_deployment_spec_schema.Job] = None
26
28
  self._model_loggings: Optional[list[model_deployment_spec_schema.ModelLogging]] = None
29
+ # this is referring to custom inference engine spec (vllm, sglang, etc)
30
+ self._inference_engine_spec: Optional[model_deployment_spec_schema.InferenceEngineSpec] = None
27
31
  self._inference_spec: dict[str, Any] = {} # Common inference spec for service/job
28
32
 
29
33
  self.database: Optional[sql_identifier.SqlIdentifier] = None
@@ -71,10 +75,8 @@ class ModelDeploymentSpec:
71
75
 
72
76
  def add_image_build_spec(
73
77
  self,
74
- image_build_compute_pool_name: sql_identifier.SqlIdentifier,
75
- image_repo_name: sql_identifier.SqlIdentifier,
76
- image_repo_database_name: Optional[sql_identifier.SqlIdentifier] = None,
77
- image_repo_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
78
+ image_build_compute_pool_name: Optional[sql_identifier.SqlIdentifier] = None,
79
+ fully_qualified_image_repo_name: Optional[str] = None,
78
80
  force_rebuild: bool = False,
79
81
  external_access_integrations: Optional[list[sql_identifier.SqlIdentifier]] = None,
80
82
  ) -> "ModelDeploymentSpec":
@@ -82,33 +84,29 @@ class ModelDeploymentSpec:
82
84
 
83
85
  Args:
84
86
  image_build_compute_pool_name: Compute pool for image building.
85
- image_repo_name: Name of the image repository.
86
- image_repo_database_name: Database name for the image repository.
87
- image_repo_schema_name: Schema name for the image repository.
87
+ fully_qualified_image_repo_name: Fully qualified name of the image repository.
88
88
  force_rebuild: Whether to force rebuilding the image.
89
89
  external_access_integrations: List of external access integrations.
90
90
 
91
91
  Returns:
92
92
  Self for chaining.
93
93
  """
94
- saved_image_repo_database = image_repo_database_name or self.database
95
- saved_image_repo_schema = image_repo_schema_name or self.schema
96
- assert saved_image_repo_database is not None
97
- assert saved_image_repo_schema is not None
98
- fq_image_repo_name = identifier.get_schema_level_object_identifier(
99
- db=saved_image_repo_database.identifier(),
100
- schema=saved_image_repo_schema.identifier(),
101
- object_name=image_repo_name.identifier(),
102
- )
103
-
104
- self._image_build = model_deployment_spec_schema.ImageBuild(
105
- compute_pool=image_build_compute_pool_name.identifier(),
106
- image_repo=fq_image_repo_name,
107
- force_rebuild=force_rebuild,
108
- external_access_integrations=(
109
- [eai.identifier() for eai in external_access_integrations] if external_access_integrations else None
110
- ),
111
- )
94
+ if (
95
+ image_build_compute_pool_name is not None
96
+ or fully_qualified_image_repo_name is not None
97
+ or force_rebuild is True
98
+ or external_access_integrations is not None
99
+ ):
100
+ self._image_build = model_deployment_spec_schema.ImageBuild(
101
+ compute_pool=(
102
+ None if image_build_compute_pool_name is None else image_build_compute_pool_name.identifier()
103
+ ),
104
+ image_repo=fully_qualified_image_repo_name,
105
+ force_rebuild=force_rebuild,
106
+ external_access_integrations=(
107
+ [eai.identifier() for eai in external_access_integrations] if external_access_integrations else None
108
+ ),
109
+ )
112
110
  return self
113
111
 
114
112
  def _add_inference_spec(
@@ -196,16 +194,14 @@ class ModelDeploymentSpec:
196
194
  self,
197
195
  job_name: sql_identifier.SqlIdentifier,
198
196
  inference_compute_pool_name: sql_identifier.SqlIdentifier,
197
+ function_name: str,
198
+ input_stage_location: str,
199
+ output_stage_location: str,
200
+ completion_filename: str,
201
+ input_file_pattern: str,
199
202
  warehouse: sql_identifier.SqlIdentifier,
200
- target_method: str,
201
- input_table_name: sql_identifier.SqlIdentifier,
202
- output_table_name: sql_identifier.SqlIdentifier,
203
203
  job_database_name: Optional[sql_identifier.SqlIdentifier] = None,
204
204
  job_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
205
- input_table_database_name: Optional[sql_identifier.SqlIdentifier] = None,
206
- input_table_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
207
- output_table_database_name: Optional[sql_identifier.SqlIdentifier] = None,
208
- output_table_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
209
205
  cpu: Optional[str] = None,
210
206
  memory: Optional[str] = None,
211
207
  gpu: Optional[Union[str, int]] = None,
@@ -217,16 +213,14 @@ class ModelDeploymentSpec:
217
213
  Args:
218
214
  job_name: Name of the job.
219
215
  inference_compute_pool_name: Compute pool for inference.
216
+ warehouse: Warehouse for the job.
217
+ function_name: Function name.
218
+ input_stage_location: Stage location for input data.
219
+ output_stage_location: Stage location for output data.
220
220
  job_database_name: Database name for the job.
221
221
  job_schema_name: Schema name for the job.
222
- warehouse: Warehouse for the job.
223
- target_method: Target method for inference.
224
- input_table_name: Input table name.
225
- output_table_name: Output table name.
226
- input_table_database_name: Database for input table.
227
- input_table_schema_name: Schema for input table.
228
- output_table_database_name: Database for output table.
229
- output_table_schema_name: Schema for output table.
222
+ input_file_pattern: Pattern for input files (optional).
223
+ completion_filename: Name of completion file (default: "completion.txt").
230
224
  cpu: CPU requirement.
231
225
  memory: Memory requirement.
232
226
  gpu: GPU requirement.
@@ -244,41 +238,28 @@ class ModelDeploymentSpec:
244
238
 
245
239
  saved_job_database = job_database_name or self.database
246
240
  saved_job_schema = job_schema_name or self.schema
247
- input_table_database_name = input_table_database_name or self.database
248
- input_table_schema_name = input_table_schema_name or self.schema
249
- output_table_database_name = output_table_database_name or self.database
250
- output_table_schema_name = output_table_schema_name or self.schema
251
241
 
252
242
  assert saved_job_database is not None
253
243
  assert saved_job_schema is not None
254
- assert input_table_database_name is not None
255
- assert input_table_schema_name is not None
256
- assert output_table_database_name is not None
257
- assert output_table_schema_name is not None
258
244
 
259
245
  fq_job_name = identifier.get_schema_level_object_identifier(
260
246
  saved_job_database.identifier(), saved_job_schema.identifier(), job_name.identifier()
261
247
  )
262
- fq_input_table_name = identifier.get_schema_level_object_identifier(
263
- input_table_database_name.identifier(),
264
- input_table_schema_name.identifier(),
265
- input_table_name.identifier(),
266
- )
267
- fq_output_table_name = identifier.get_schema_level_object_identifier(
268
- output_table_database_name.identifier(),
269
- output_table_schema_name.identifier(),
270
- output_table_name.identifier(),
271
- )
272
248
 
273
249
  self._add_inference_spec(cpu, memory, gpu, num_workers, max_batch_rows)
274
250
 
275
251
  self._job = model_deployment_spec_schema.Job(
276
252
  name=fq_job_name,
277
253
  compute_pool=inference_compute_pool_name.identifier(),
278
- warehouse=warehouse.identifier(),
279
- target_method=target_method,
280
- input_table_name=fq_input_table_name,
281
- output_table_name=fq_output_table_name,
254
+ warehouse=warehouse.identifier() if warehouse else None,
255
+ function_name=function_name,
256
+ input=model_deployment_spec_schema.Input(
257
+ input_stage_location=input_stage_location, input_file_pattern=input_file_pattern
258
+ ),
259
+ output=model_deployment_spec_schema.Output(
260
+ output_stage_location=output_stage_location,
261
+ completion_filename=completion_filename,
262
+ ),
282
263
  **self._inference_spec,
283
264
  )
284
265
  return self
@@ -363,6 +344,86 @@ class ModelDeploymentSpec:
363
344
  self._model_loggings.append(model_logging)
364
345
  return self
365
346
 
347
+ def add_inference_engine_spec(
348
+ self,
349
+ inference_engine: inference_engine_module.InferenceEngine,
350
+ inference_engine_args: Optional[list[str]] = None,
351
+ ) -> "ModelDeploymentSpec":
352
+ """Add inference engine specification. This must be called after self.add_service_spec().
353
+
354
+ Args:
355
+ inference_engine: Inference engine.
356
+ inference_engine_args: Inference engine arguments.
357
+
358
+ Returns:
359
+ Self for chaining.
360
+
361
+ Raises:
362
+ ValueError: If inference engine specification is called before add_service_spec().
363
+ ValueError: If the argument does not have a '--' prefix.
364
+ """
365
+ # TODO: needs to eventually support job deployment spec
366
+ if self._service is None:
367
+ raise ValueError("Inference engine specification must be called after add_service_spec().")
368
+
369
+ if inference_engine_args is None:
370
+ inference_engine_args = []
371
+
372
+ # Validate inference engine
373
+ if inference_engine == inference_engine_module.InferenceEngine.VLLM:
374
+ # Block list for VLLM args that should not be user-configurable
375
+ # make this a set for faster lookup
376
+ block_list = {
377
+ "--host",
378
+ "--port",
379
+ "--allowed-headers",
380
+ "--api-key",
381
+ "--lora-modules",
382
+ "--prompt-adapter",
383
+ "--ssl-keyfile",
384
+ "--ssl-certfile",
385
+ "--ssl-ca-certs",
386
+ "--enable-ssl-refresh",
387
+ "--ssl-cert-reqs",
388
+ "--root-path",
389
+ "--middleware",
390
+ "--disable-frontend-multiprocessing",
391
+ "--enable-request-id-headers",
392
+ "--enable-auto-tool-choice",
393
+ "--tool-call-parser",
394
+ "--tool-parser-plugin",
395
+ "--log-config-file",
396
+ }
397
+
398
+ filtered_args = []
399
+ for arg in inference_engine_args:
400
+ # Check if the argument has a '--' prefix
401
+ if not arg.startswith("--"):
402
+ raise ValueError(
403
+ f"""The argument {arg} is not allowed for configuration in Snowflake ML's
404
+ {inference_engine.value} inference engine. Maybe you forgot to add '--' prefix?""",
405
+ )
406
+
407
+ # Filter out blocked args and warn user
408
+ if arg.split("=")[0] in block_list:
409
+ warnings.warn(
410
+ f"""The argument {arg} is not allowed for configuration in Snowflake ML's
411
+ {inference_engine.value} inference engine. It will be ignored.""",
412
+ UserWarning,
413
+ stacklevel=2,
414
+ )
415
+ else:
416
+ filtered_args.append(arg)
417
+
418
+ inference_engine_args = filtered_args
419
+
420
+ self._service.inference_engine_spec = model_deployment_spec_schema.InferenceEngineSpec(
421
+ # convert to string to be saved in the deployment spec
422
+ inference_engine_name=inference_engine.value,
423
+ inference_engine_args=inference_engine_args,
424
+ )
425
+ return self
426
+
366
427
  def save(self) -> str:
367
428
  """Constructs the final deployment spec from added components and saves it.
368
429
 
@@ -377,8 +438,6 @@ class ModelDeploymentSpec:
377
438
  # Validations
378
439
  if not self._models:
379
440
  raise ValueError("Model specification is required. Call add_model_spec().")
380
- if not self._image_build:
381
- raise ValueError("Image build specification is required. Call add_image_build_spec().")
382
441
  if not self._service and not self._job:
383
442
  raise ValueError(
384
443
  "Either service or job specification is required. Call add_service_spec() or add_job_spec()."