snowflake-ml-python 1.8.3__py3-none-any.whl → 1.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (190)
  1. snowflake/cortex/__init__.py +7 -1
  2. snowflake/ml/_internal/platform_capabilities.py +13 -11
  3. snowflake/ml/_internal/utils/identifier.py +2 -2
  4. snowflake/ml/jobs/_utils/constants.py +1 -1
  5. snowflake/ml/jobs/_utils/payload_utils.py +39 -30
  6. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
  7. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +1 -1
  8. snowflake/ml/jobs/_utils/spec_utils.py +1 -1
  9. snowflake/ml/jobs/decorators.py +6 -0
  10. snowflake/ml/jobs/job.py +63 -16
  11. snowflake/ml/jobs/manager.py +50 -16
  12. snowflake/ml/model/_client/model/model_version_impl.py +1 -1
  13. snowflake/ml/model/_client/ops/service_ops.py +26 -14
  14. snowflake/ml/model/_client/service/model_deployment_spec.py +340 -170
  15. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +25 -0
  16. snowflake/ml/model/_client/sql/service.py +4 -13
  17. snowflake/ml/model/_model_composer/model_composer.py +41 -18
  18. snowflake/ml/model/_packager/model_handlers/_utils.py +32 -2
  19. snowflake/ml/model/_packager/model_handlers/custom.py +1 -1
  20. snowflake/ml/model/_packager/model_handlers/pytorch.py +1 -2
  21. snowflake/ml/model/_packager/model_handlers/sklearn.py +100 -41
  22. snowflake/ml/model/_packager/model_handlers/tensorflow.py +7 -4
  23. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  24. snowflake/ml/model/_packager/model_handlers/xgboost.py +16 -7
  25. snowflake/ml/model/_packager/model_meta/model_meta.py +2 -1
  26. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
  27. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +4 -4
  28. snowflake/ml/model/_signatures/dmatrix_handler.py +15 -2
  29. snowflake/ml/model/custom_model.py +17 -4
  30. snowflake/ml/model/model_signature.py +3 -3
  31. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
  32. snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
  33. snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
  34. snowflake/ml/modeling/cluster/birch.py +9 -1
  35. snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
  36. snowflake/ml/modeling/cluster/dbscan.py +9 -1
  37. snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
  38. snowflake/ml/modeling/cluster/k_means.py +9 -1
  39. snowflake/ml/modeling/cluster/mean_shift.py +9 -1
  40. snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
  41. snowflake/ml/modeling/cluster/optics.py +9 -1
  42. snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
  43. snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
  44. snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
  45. snowflake/ml/modeling/compose/column_transformer.py +9 -1
  46. snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
  47. snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
  48. snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
  49. snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
  50. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
  51. snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
  52. snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
  53. snowflake/ml/modeling/covariance/oas.py +9 -1
  54. snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
  55. snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
  56. snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
  57. snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
  58. snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
  59. snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
  60. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
  61. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
  62. snowflake/ml/modeling/decomposition/pca.py +9 -1
  63. snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
  64. snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
  65. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
  66. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
  67. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
  68. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
  69. snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
  70. snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
  71. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
  72. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
  73. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
  74. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
  75. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
  76. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
  77. snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
  78. snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
  79. snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
  80. snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
  81. snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
  82. snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
  83. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
  84. snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
  85. snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
  86. snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
  87. snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
  88. snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
  89. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
  90. snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
  91. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
  92. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
  93. snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
  94. snowflake/ml/modeling/impute/knn_imputer.py +9 -1
  95. snowflake/ml/modeling/impute/missing_indicator.py +9 -1
  96. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
  97. snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
  98. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
  99. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
  100. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
  101. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
  102. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
  103. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
  104. snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
  105. snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
  106. snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
  107. snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
  108. snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
  109. snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
  110. snowflake/ml/modeling/linear_model/lars.py +9 -1
  111. snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
  112. snowflake/ml/modeling/linear_model/lasso.py +9 -1
  113. snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
  114. snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
  115. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
  116. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
  117. snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
  118. snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
  119. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
  120. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
  121. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
  122. snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
  123. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
  124. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
  125. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
  126. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
  127. snowflake/ml/modeling/linear_model/perceptron.py +9 -1
  128. snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
  129. snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
  130. snowflake/ml/modeling/linear_model/ridge.py +9 -1
  131. snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
  132. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
  133. snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
  134. snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
  135. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
  136. snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
  137. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
  138. snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
  139. snowflake/ml/modeling/manifold/isomap.py +9 -1
  140. snowflake/ml/modeling/manifold/mds.py +9 -1
  141. snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
  142. snowflake/ml/modeling/manifold/tsne.py +9 -1
  143. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
  144. snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
  145. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
  146. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
  147. snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
  148. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
  149. snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
  150. snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
  151. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
  152. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
  153. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
  154. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
  155. snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
  156. snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
  157. snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
  158. snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
  159. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
  160. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
  161. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
  162. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
  163. snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
  164. snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
  165. snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
  166. snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
  167. snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
  168. snowflake/ml/modeling/svm/linear_svc.py +9 -1
  169. snowflake/ml/modeling/svm/linear_svr.py +9 -1
  170. snowflake/ml/modeling/svm/nu_svc.py +9 -1
  171. snowflake/ml/modeling/svm/nu_svr.py +9 -1
  172. snowflake/ml/modeling/svm/svc.py +9 -1
  173. snowflake/ml/modeling/svm/svr.py +9 -1
  174. snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
  175. snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
  176. snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
  177. snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
  178. snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
  179. snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
  180. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
  181. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
  182. snowflake/ml/monitoring/explain_visualize.py +286 -0
  183. snowflake/ml/registry/_manager/model_manager.py +23 -2
  184. snowflake/ml/registry/registry.py +10 -9
  185. snowflake/ml/version.py +1 -1
  186. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.4.dist-info}/METADATA +40 -8
  187. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.4.dist-info}/RECORD +190 -189
  188. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.4.dist-info}/WHEEL +1 -1
  189. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.4.dist-info}/licenses/LICENSE.txt +0 -0
  190. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.4.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,7 @@
1
1
  import logging
2
2
  import pathlib
3
3
  import textwrap
4
- from typing import Any, Callable, Literal, Optional, TypeVar, Union, overload
4
+ from typing import Any, Callable, Literal, Optional, TypeVar, Union, cast, overload
5
5
  from uuid import uuid4
6
6
 
7
7
  import yaml
@@ -52,7 +52,7 @@ def list_jobs(
52
52
  query += f" LIMIT {limit}"
53
53
  df = session.sql(query)
54
54
  df = df.select(
55
- df['"name"'].alias('"id"'),
55
+ df['"name"'],
56
56
  df['"owner"'],
57
57
  df['"status"'],
58
58
  df['"created_on"'],
@@ -65,16 +65,16 @@ def list_jobs(
65
65
  def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob[Any]:
66
66
  """Retrieve a job service from the backend."""
67
67
  session = session or get_active_session()
68
-
69
68
  try:
70
- # Validate job_id
71
- job_id = identifier.resolve_identifier(job_id)
69
+ database, schema, job_name = identifier.parse_schema_level_object_identifier(job_id)
70
+ database = identifier.resolve_identifier(cast(str, database or session.get_current_database()))
71
+ schema = identifier.resolve_identifier(cast(str, schema or session.get_current_schema()))
72
72
  except ValueError as e:
73
73
  raise ValueError(f"Invalid job ID: {job_id}") from e
74
74
 
75
+ job_id = f"{database}.{schema}.{job_name}"
75
76
  try:
76
77
  # Validate that job exists by doing a status check
77
- # FIXME: Retrieve return path
78
78
  job = jb.MLJob[Any](job_id, session=session)
79
79
  _ = job.status
80
80
  return job
@@ -110,6 +110,8 @@ def submit_file(
110
110
  spec_overrides: Optional[dict[str, Any]] = None,
111
111
  num_instances: Optional[int] = None,
112
112
  enable_metrics: bool = False,
113
+ database: Optional[str] = None,
114
+ schema: Optional[str] = None,
113
115
  session: Optional[snowpark.Session] = None,
114
116
  ) -> jb.MLJob[None]:
115
117
  """
@@ -127,6 +129,8 @@ def submit_file(
127
129
  spec_overrides: Custom service specification overrides to apply.
128
130
  num_instances: The number of instances to use for the job. If none specified, single node job is created.
129
131
  enable_metrics: Whether to enable metrics publishing for the job.
132
+ database: The database to use.
133
+ schema: The schema to use.
130
134
  session: The Snowpark session to use. If none specified, uses active session.
131
135
 
132
136
  Returns:
@@ -144,6 +148,8 @@ def submit_file(
144
148
  spec_overrides=spec_overrides,
145
149
  num_instances=num_instances,
146
150
  enable_metrics=enable_metrics,
151
+ database=database,
152
+ schema=schema,
147
153
  session=session,
148
154
  )
149
155
 
@@ -163,6 +169,8 @@ def submit_directory(
163
169
  spec_overrides: Optional[dict[str, Any]] = None,
164
170
  num_instances: Optional[int] = None,
165
171
  enable_metrics: bool = False,
172
+ database: Optional[str] = None,
173
+ schema: Optional[str] = None,
166
174
  session: Optional[snowpark.Session] = None,
167
175
  ) -> jb.MLJob[None]:
168
176
  """
@@ -181,6 +189,8 @@ def submit_directory(
181
189
  spec_overrides: Custom service specification overrides to apply.
182
190
  num_instances: The number of instances to use for the job. If none specified, single node job is created.
183
191
  enable_metrics: Whether to enable metrics publishing for the job.
192
+ database: The database to use.
193
+ schema: The schema to use.
184
194
  session: The Snowpark session to use. If none specified, uses active session.
185
195
 
186
196
  Returns:
@@ -199,6 +209,8 @@ def submit_directory(
199
209
  spec_overrides=spec_overrides,
200
210
  num_instances=num_instances,
201
211
  enable_metrics=enable_metrics,
212
+ database=database,
213
+ schema=schema,
202
214
  session=session,
203
215
  )
204
216
 
@@ -218,6 +230,8 @@ def _submit_job(
218
230
  spec_overrides: Optional[dict[str, Any]] = None,
219
231
  num_instances: Optional[int] = None,
220
232
  enable_metrics: bool = False,
233
+ database: Optional[str] = None,
234
+ schema: Optional[str] = None,
221
235
  session: Optional[snowpark.Session] = None,
222
236
  ) -> jb.MLJob[None]:
223
237
  ...
@@ -238,6 +252,8 @@ def _submit_job(
238
252
  spec_overrides: Optional[dict[str, Any]] = None,
239
253
  num_instances: Optional[int] = None,
240
254
  enable_metrics: bool = False,
255
+ database: Optional[str] = None,
256
+ schema: Optional[str] = None,
241
257
  session: Optional[snowpark.Session] = None,
242
258
  ) -> jb.MLJob[T]:
243
259
  ...
@@ -269,6 +285,8 @@ def _submit_job(
269
285
  spec_overrides: Optional[dict[str, Any]] = None,
270
286
  num_instances: Optional[int] = None,
271
287
  enable_metrics: bool = False,
288
+ database: Optional[str] = None,
289
+ schema: Optional[str] = None,
272
290
  session: Optional[snowpark.Session] = None,
273
291
  ) -> jb.MLJob[T]:
274
292
  """
@@ -287,6 +305,8 @@ def _submit_job(
287
305
  spec_overrides: Custom service specification overrides to apply.
288
306
  num_instances: The number of instances to use for the job. If none specified, single node job is created.
289
307
  enable_metrics: Whether to enable metrics publishing for the job.
308
+ database: The database to use.
309
+ schema: The schema to use.
290
310
  session: The Snowpark session to use. If none specified, uses active session.
291
311
 
292
312
  Returns:
@@ -294,17 +314,28 @@ def _submit_job(
294
314
 
295
315
  Raises:
296
316
  RuntimeError: If required Snowflake features are not enabled.
317
+ ValueError: If database or schema value(s) are invalid
297
318
  """
298
319
  # Display warning about PrPr parameters
299
320
  if num_instances is not None:
300
321
  logger.warning(
301
322
  "_submit_job() parameter 'num_instances' is in private preview since 1.8.2. Do not use it in production.",
302
323
  )
324
+ if database and not schema:
325
+ raise ValueError("Schema must be specified if database is specified.")
303
326
 
304
327
  session = session or get_active_session()
305
- job_id = f"{JOB_ID_PREFIX}{str(uuid4()).replace('-', '_').upper()}"
306
- stage_name = "@" + stage_name.lstrip("@").rstrip("/")
307
- stage_path = pathlib.PurePosixPath(f"{stage_name}/{job_id}")
328
+
329
+ # Validate database and schema identifiers on client side since
330
+ # SQL parser for EXECUTE JOB SERVICE seems to struggle with this
331
+ database = identifier.resolve_identifier(cast(str, database or session.get_current_database()))
332
+ schema = identifier.resolve_identifier(cast(str, schema or session.get_current_schema()))
333
+
334
+ job_name = f"{JOB_ID_PREFIX}{str(uuid4()).replace('-', '_').upper()}"
335
+ job_id = f"{database}.{schema}.{job_name}"
336
+ stage_path_parts = identifier.parse_snowflake_stage_path(stage_name.lstrip("@"))
337
+ stage_name = f"@{'.'.join(filter(None, stage_path_parts[:3]))}"
338
+ stage_path = pathlib.PurePosixPath(f"{stage_name}{stage_path_parts[-1].rstrip('/')}/{job_name}")
308
339
 
309
340
  # Upload payload
310
341
  uploaded_payload = payload_utils.JobPayload(
@@ -331,31 +362,34 @@ def _submit_job(
331
362
 
332
363
  # Generate SQL command for job submission
333
364
  query_template = textwrap.dedent(
334
- f"""\
365
+ """\
335
366
  EXECUTE JOB SERVICE
336
- IN COMPUTE POOL {compute_pool}
367
+ IN COMPUTE POOL IDENTIFIER(?)
337
368
  FROM SPECIFICATION $$
338
- {{}}
369
+ {}
339
370
  $$
340
- NAME = {job_id}
371
+ NAME = IDENTIFIER(?)
341
372
  ASYNC = TRUE
342
373
  """
343
374
  )
375
+ params: list[Any] = [compute_pool, job_id]
344
376
  query = query_template.format(yaml.dump(spec)).splitlines()
345
377
  if external_access_integrations:
346
378
  external_access_integration_list = ",".join(f"{e}" for e in external_access_integrations)
347
379
  query.append(f"EXTERNAL_ACCESS_INTEGRATIONS = ({external_access_integration_list})")
348
380
  query_warehouse = query_warehouse or session.get_current_warehouse()
349
381
  if query_warehouse:
350
- query.append(f"QUERY_WAREHOUSE = {query_warehouse}")
382
+ query.append("QUERY_WAREHOUSE = IDENTIFIER(?)")
383
+ params.append(query_warehouse)
351
384
  if num_instances:
352
- query.append(f"REPLICAS = {num_instances}")
385
+ query.append("REPLICAS = ?")
386
+ params.append(num_instances)
353
387
 
354
388
  # Submit job
355
389
  query_text = "\n".join(line for line in query if line)
356
390
 
357
391
  try:
358
- _ = session.sql(query_text).collect()
392
+ _ = session.sql(query_text, params=params).collect()
359
393
  except SnowparkSQLException as e:
360
394
  if "invalid property 'ASYNC'" in e.message:
361
395
  raise RuntimeError(
@@ -920,7 +920,7 @@ class ModelVersion(lineage_node.LineageNode):
920
920
  project=_TELEMETRY_PROJECT,
921
921
  subproject=_TELEMETRY_SUBPROJECT,
922
922
  )
923
- def run_job(
923
+ def _run_job(
924
924
  self,
925
925
  X: Union[pd.DataFrame, "dataframe.DataFrame"],
926
926
  *,
@@ -125,19 +125,25 @@ class ServiceOperator:
125
125
  stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
126
126
  else:
127
127
  stage_path = None
128
- spec_yaml_str_or_path = self._model_deployment_spec.save(
128
+ self._model_deployment_spec.add_model_spec(
129
129
  database_name=database_name,
130
130
  schema_name=schema_name,
131
131
  model_name=model_name,
132
132
  version_name=version_name,
133
- service_database_name=service_database_name,
134
- service_schema_name=service_schema_name,
135
- service_name=service_name,
133
+ )
134
+ self._model_deployment_spec.add_image_build_spec(
136
135
  image_build_compute_pool_name=image_build_compute_pool_name,
137
- inference_compute_pool_name=service_compute_pool_name,
138
136
  image_repo_database_name=image_repo_database_name,
139
137
  image_repo_schema_name=image_repo_schema_name,
140
138
  image_repo_name=image_repo_name,
139
+ force_rebuild=force_rebuild,
140
+ external_access_integrations=build_external_access_integrations,
141
+ )
142
+ self._model_deployment_spec.add_service_spec(
143
+ service_database_name=service_database_name,
144
+ service_schema_name=service_schema_name,
145
+ service_name=service_name,
146
+ inference_compute_pool_name=service_compute_pool_name,
141
147
  ingress_enabled=ingress_enabled,
142
148
  max_instances=max_instances,
143
149
  cpu=cpu_requests,
@@ -145,9 +151,8 @@ class ServiceOperator:
145
151
  gpu=gpu_requests,
146
152
  num_workers=num_workers,
147
153
  max_batch_rows=max_batch_rows,
148
- force_rebuild=force_rebuild,
149
- external_access_integrations=build_external_access_integrations,
150
154
  )
155
+ spec_yaml_str_or_path = self._model_deployment_spec.save()
151
156
  if self._workspace:
152
157
  assert stage_path is not None
153
158
  file_utils.upload_directory_to_stage(
@@ -534,26 +539,22 @@ class ServiceOperator:
534
539
 
535
540
  try:
536
541
  # save the spec
537
- spec_yaml_str_or_path = self._model_deployment_spec.save(
542
+ self._model_deployment_spec.add_model_spec(
538
543
  database_name=database_name,
539
544
  schema_name=schema_name,
540
545
  model_name=model_name,
541
546
  version_name=version_name,
547
+ )
548
+ self._model_deployment_spec.add_job_spec(
542
549
  job_database_name=job_database_name,
543
550
  job_schema_name=job_schema_name,
544
551
  job_name=job_name,
545
- image_build_compute_pool_name=compute_pool_name,
546
552
  inference_compute_pool_name=compute_pool_name,
547
- image_repo_database_name=image_repo_database_name,
548
- image_repo_schema_name=image_repo_schema_name,
549
- image_repo_name=image_repo_name,
550
553
  cpu=cpu_requests,
551
554
  memory=memory_requests,
552
555
  gpu=gpu_requests,
553
556
  num_workers=num_workers,
554
557
  max_batch_rows=max_batch_rows,
555
- force_rebuild=force_rebuild,
556
- external_access_integrations=build_external_access_integrations,
557
558
  warehouse=warehouse_name,
558
559
  target_method=target_method,
559
560
  input_table_database_name=input_table_database_name,
@@ -563,6 +564,17 @@ class ServiceOperator:
563
564
  output_table_schema_name=output_table_schema_name,
564
565
  output_table_name=output_table_name,
565
566
  )
567
+
568
+ self._model_deployment_spec.add_image_build_spec(
569
+ image_build_compute_pool_name=compute_pool_name,
570
+ image_repo_database_name=image_repo_database_name,
571
+ image_repo_schema_name=image_repo_schema_name,
572
+ image_repo_name=image_repo_name,
573
+ force_rebuild=force_rebuild,
574
+ external_access_integrations=build_external_access_integrations,
575
+ )
576
+
577
+ spec_yaml_str_or_path = self._model_deployment_spec.save()
566
578
  if self._workspace:
567
579
  assert stage_path is not None
568
580
  file_utils.upload_directory_to_stage(