snowflake-ml-python 1.8.4__py3-none-any.whl → 1.8.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. snowflake/ml/_internal/telemetry.py +42 -13
  2. snowflake/ml/data/data_connector.py +1 -1
  3. snowflake/ml/jobs/_utils/constants.py +9 -0
  4. snowflake/ml/jobs/_utils/interop_utils.py +1 -1
  5. snowflake/ml/jobs/_utils/payload_utils.py +12 -4
  6. snowflake/ml/jobs/_utils/scripts/constants.py +6 -0
  7. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +85 -2
  8. snowflake/ml/jobs/_utils/spec_utils.py +7 -5
  9. snowflake/ml/jobs/decorators.py +7 -3
  10. snowflake/ml/jobs/job.py +158 -25
  11. snowflake/ml/jobs/manager.py +29 -19
  12. snowflake/ml/model/_client/ops/service_ops.py +5 -3
  13. snowflake/ml/model/_client/service/model_deployment_spec.py +11 -0
  14. snowflake/ml/model/_client/sql/model_version.py +1 -1
  15. snowflake/ml/model/_client/sql/service.py +16 -19
  16. snowflake/ml/model/_model_composer/model_composer.py +3 -1
  17. snowflake/ml/model/_packager/model_handlers/sklearn.py +1 -1
  18. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -2
  19. snowflake/ml/monitoring/explain_visualize.py +160 -22
  20. snowflake/ml/utils/connection_params.py +8 -2
  21. snowflake/ml/version.py +1 -1
  22. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.5.dist-info}/METADATA +27 -9
  23. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.5.dist-info}/RECORD +26 -26
  24. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.5.dist-info}/WHEEL +1 -1
  25. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.5.dist-info}/licenses/LICENSE.txt +0 -0
  26. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.5.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/job.py CHANGED
@@ -1,3 +1,5 @@
+import logging
+import os
 import time
 from functools import cached_property
 from typing import Any, Generic, Literal, Optional, TypeVar, Union, cast, overload
@@ -16,6 +18,8 @@ TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "INTERNAL_ERROR"}
 
 T = TypeVar("T")
 
+logger = logging.getLogger(__name__)
+
 
 class MLJob(Generic[T]):
     def __init__(
@@ -36,8 +40,15 @@ class MLJob(Generic[T]):
         return identifier.parse_schema_level_object_identifier(self.id)[-1]
 
     @cached_property
-    def num_instances(self) -> int:
-        return _get_num_instances(self._session, self.id)
+    def target_instances(self) -> int:
+        return _get_target_instances(self._session, self.id)
+
+    @cached_property
+    def min_instances(self) -> int:
+        try:
+            return int(self._container_spec["env"].get(constants.MIN_INSTANCES_ENV_VAR, 1))
+        except TypeError:
+            return 1
 
     @property
     def id(self) -> str:
@@ -52,6 +63,12 @@ class MLJob(Generic[T]):
         self._status = _get_status(self._session, self.id)
         return self._status
 
+    @cached_property
+    def _compute_pool(self) -> str:
+        """Get the job's compute pool name."""
+        row = _get_service_info(self._session, self.id)
+        return cast(str, row["compute_pool"])
+
     @property
     def _service_spec(self) -> dict[str, Any]:
         """Get the job's service spec."""
@@ -82,15 +99,34 @@ class MLJob(Generic[T]):
         return f"{self._stage_path}/{result_path}"
 
     @overload
-    def get_logs(self, limit: int = -1, instance_id: Optional[int] = None, *, as_list: Literal[True]) -> list[str]:
+    def get_logs(
+        self,
+        limit: int = -1,
+        instance_id: Optional[int] = None,
+        *,
+        as_list: Literal[True],
+        verbose: bool = constants.DEFAULT_VERBOSE_LOG,
+    ) -> list[str]:
         ...
 
     @overload
-    def get_logs(self, limit: int = -1, instance_id: Optional[int] = None, *, as_list: Literal[False] = False) -> str:
+    def get_logs(
+        self,
+        limit: int = -1,
+        instance_id: Optional[int] = None,
+        *,
+        as_list: Literal[False] = False,
+        verbose: bool = constants.DEFAULT_VERBOSE_LOG,
+    ) -> str:
         ...
 
     def get_logs(
-        self, limit: int = -1, instance_id: Optional[int] = None, *, as_list: bool = False
+        self,
+        limit: int = -1,
+        instance_id: Optional[int] = None,
+        *,
+        as_list: bool = False,
+        verbose: bool = constants.DEFAULT_VERBOSE_LOG,
     ) -> Union[str, list[str]]:
         """
         Return the job's execution logs.
@@ -100,17 +136,20 @@ class MLJob(Generic[T]):
             instance_id: Optional instance ID to get logs from a specific instance.
                 If not provided, returns logs from the head node.
             as_list: If True, returns logs as a list of lines. Otherwise, returns logs as a single string.
+            verbose: Whether to return the full log or just the user log.
 
         Returns:
             The job's execution logs.
         """
-        logs = _get_logs(self._session, self.id, limit, instance_id)
+        logs = _get_logs(self._session, self.id, limit, instance_id, verbose)
         assert isinstance(logs, str)  # mypy
         if as_list:
             return logs.splitlines()
         return logs
 
-    def show_logs(self, limit: int = -1, instance_id: Optional[int] = None) -> None:
+    def show_logs(
+        self, limit: int = -1, instance_id: Optional[int] = None, verbose: bool = constants.DEFAULT_VERBOSE_LOG
+    ) -> None:
         """
         Display the job's execution logs.
 
@@ -118,8 +157,9 @@ class MLJob(Generic[T]):
         limit: The maximum number of lines to display. Negative values are treated as no limit.
         instance_id: Optional instance ID to get logs from a specific instance.
             If not provided, displays logs from the head node.
+        verbose: Whether to return the full log or just the user log.
         """
-        print(self.get_logs(limit, instance_id, as_list=False))  # noqa: T201: we need to print here.
+        print(self.get_logs(limit, instance_id, as_list=False, verbose=verbose))  # noqa: T201: we need to print here.
 
     @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["timeout"])
     def wait(self, timeout: float = -1) -> types.JOB_STATUS:
@@ -137,7 +177,16 @@ class MLJob(Generic[T]):
         """
         delay = constants.JOB_POLL_INITIAL_DELAY_SECONDS  # Start with 100ms delay
         start_time = time.monotonic()
-        while self.status not in TERMINAL_JOB_STATUSES:
+        warning_shown = False
+        while (status := self.status) not in TERMINAL_JOB_STATUSES:
+            if status == "PENDING" and not warning_shown:
+                pool_info = _get_compute_pool_info(self._session, self._compute_pool)
+                if (pool_info.max_nodes - pool_info.active_nodes) < self.min_instances:
+                    logger.warning(
+                        f"Compute pool busy ({pool_info.active_nodes}/{pool_info.max_nodes} nodes in use)."
+                        " Job execution may be delayed."
+                    )
+                    warning_shown = True
             if timeout >= 0 and (elapsed := time.monotonic() - start_time) >= timeout:
                 raise TimeoutError(f"Job {self.name} did not complete within {elapsed} seconds")
             time.sleep(delay)
@@ -195,7 +244,9 @@ def _get_service_spec(session: snowpark.Session, job_id: str) -> dict[str, Any]:
 
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit", "instance_id"])
-def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1, instance_id: Optional[int] = None) -> str:
+def _get_logs(
+    session: snowpark.Session, job_id: str, limit: int = -1, instance_id: Optional[int] = None, verbose: bool = True
+) -> str:
     """
     Retrieve the job's execution logs.
 
@@ -204,24 +255,20 @@ def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1, instance_
         limit: The maximum number of lines to return. Negative values are treated as no limit.
         session: The Snowpark session to use. If none specified, uses active session.
         instance_id: Optional instance ID to get logs from a specific instance.
+        verbose: Whether to return the full log or just the portion between START and END messages.
 
     Returns:
         The job's execution logs.
 
     Raises:
-        SnowparkSQLException: if the container is pending
        RuntimeError: if failed to get head instance_id
-
    """
    # If instance_id is not specified, try to get the head instance ID
    if instance_id is None:
        try:
            instance_id = _get_head_instance_id(session, job_id)
        except RuntimeError:
-            raise RuntimeError(
-                "Failed to retrieve job logs. "
-                "Logs may be inaccessible due to job expiration and can be retrieved from Event Table instead."
-            )
+            instance_id = None
 
    # Assemble params: [job_id, instance_id, container_name, (optional) limit]
    params: list[Any] = [
@@ -231,7 +278,6 @@ def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1, instance_
    ]
    if limit > 0:
        params.append(limit)
-
    try:
        (row,) = session.sql(
            f"SELECT SYSTEM$GET_SERVICE_LOGS(?, ?, ?{f', ?' if limit > 0 else ''})",
@@ -239,9 +285,43 @@ def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1, instance_
        ).collect()
    except SnowparkSQLException as e:
        if "Container Status: PENDING" in e.message:
-            return "Warning: Waiting for container to start. Logs will be shown when available."
-        raise
-    return str(row[0])
+            logger.warning("Waiting for container to start. Logs will be shown when available.")
+            return ""
+        else:
+            # event table accepts job name, not fully qualified name
+            # cast is to resolve the type check error
+            db, schema, name = identifier.parse_schema_level_object_identifier(job_id)
+            db = cast(str, db or session.get_current_database())
+            schema = cast(str, schema or session.get_current_schema())
+            logs = _get_service_log_from_event_table(
+                session, db, schema, name, limit, instance_id if instance_id else None
+            )
+            if len(logs) == 0:
+                raise RuntimeError(
+                    "No logs were found. Please verify that the database, schema, and job ID are correct."
+                )
+            return os.linesep.join(row[0] for row in logs)
+
+    full_log = str(row[0])
+
+    # If verbose is True, return the complete log
+    if verbose:
+        return full_log
+
+    # Otherwise, extract only the portion between LOG_START_MSG and LOG_END_MSG
+    start_idx = full_log.find(constants.LOG_START_MSG)
+    if start_idx != -1:
+        start_idx += len(constants.LOG_START_MSG)
+    else:
+        # If start message not found, start from the beginning
+        start_idx = 0
+
+    end_idx = full_log.find(constants.LOG_END_MSG, start_idx)
+    if end_idx == -1:
+        # If end message not found, return everything after start
+        end_idx = len(full_log)
+
+    return full_log[start_idx:end_idx].strip()
 
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
@@ -256,13 +336,18 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[in
    Returns:
        Optional[int]: The head instance ID of the job, or None if the head instance has not started yet.
 
-    Raises:
+    Raises:
        RuntimeError: If the instances died or if some instances disappeared.
+
    """
-    rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+    try:
+        rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+    except SnowparkSQLException:
+        # service may be deleted
+        raise RuntimeError("Couldn’t retrieve instances")
    if not rows:
        return None
-    if _get_num_instances(session, job_id) > len(rows):
+    if _get_target_instances(session, job_id) > len(rows):
        raise RuntimeError("Couldn’t retrieve head instance due to missing instances.")
 
    # Sort by start_time first, then by instance_id
@@ -270,7 +355,6 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[in
        sorted_instances = sorted(rows, key=lambda x: (x["start_time"], int(x["instance_id"])))
    except TypeError:
        raise RuntimeError("Job instance information unavailable.")
-
    head_instance = sorted_instances[0]
    if not head_instance["start_time"]:
        # If head instance hasn't started yet, return None
@@ -281,12 +365,61 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[in
        return 0
 
 
+def _get_service_log_from_event_table(
+    session: snowpark.Session, database: str, schema: str, name: str, limit: int, instance_id: Optional[int]
+) -> list[Row]:
+    params: list[Any] = [
+        database,
+        schema,
+        name,
+    ]
+    query = [
+        "SELECT VALUE FROM snowflake.telemetry.events_view",
+        'WHERE RESOURCE_ATTRIBUTES:"snow.database.name" = ?',
+        'AND RESOURCE_ATTRIBUTES:"snow.schema.name" = ?',
+        'AND RESOURCE_ATTRIBUTES:"snow.service.name" = ?',
+    ]
+
+    if instance_id:
+        query.append('AND RESOURCE_ATTRIBUTES:"snow.service.container.instance" = ?')
+        params.append(instance_id)
+
+    query.append("AND RECORD_TYPE = 'LOG'")
+    # sort by TIMESTAMP; although OBSERVED_TIMESTAMP is for log, it is NONE currently when record_type is log
+    query.append("ORDER BY TIMESTAMP")
+
+    if limit > 0:
+        query.append("LIMIT ?")
+        params.append(limit)
+
+    rows = session.sql(
+        "\n".join(line for line in query if line),
+        params=params,
+    ).collect()
+    return rows
+
+
 def _get_service_info(session: snowpark.Session, job_id: str) -> Row:
    (row,) = session.sql("DESCRIBE SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
    return row
 
 
+def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Row:
+    """
+    Check if the compute pool has enough available instances.
+
+    Args:
+        session (Session): The Snowpark session to use.
+        compute_pool (str): The name of the compute pool.
+
+    Returns:
+        Row: The compute pool information.
+    """
+    (pool_info,) = session.sql("SHOW COMPUTE POOLS LIKE ?", params=(compute_pool,)).collect()
+    return pool_info
+
+
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
-def _get_num_instances(session: snowpark.Session, job_id: str) -> int:
+def _get_target_instances(session: snowpark.Session, job_id: str) -> int:
    row = _get_service_info(session, job_id)
    return int(row["target_instances"]) if row["target_instances"] else 0
snowflake/ml/jobs/manager.py CHANGED
@@ -108,7 +108,8 @@ def submit_file(
     external_access_integrations: Optional[list[str]] = None,
     query_warehouse: Optional[str] = None,
     spec_overrides: Optional[dict[str, Any]] = None,
-    num_instances: Optional[int] = None,
+    target_instances: int = 1,
+    min_instances: int = 1,
     enable_metrics: bool = False,
     database: Optional[str] = None,
     schema: Optional[str] = None,
@@ -127,7 +128,8 @@
         external_access_integrations: A list of external access integrations.
         query_warehouse: The query warehouse to use. Defaults to session warehouse.
         spec_overrides: Custom service specification overrides to apply.
-        num_instances: The number of instances to use for the job. If none specified, single node job is created.
+        target_instances: The number of instances to use for the job. If none specified, single node job is created.
+        min_instances: The minimum number of nodes required to start the job. If none specified, defaults to 1.
         enable_metrics: Whether to enable metrics publishing for the job.
         database: The database to use.
         schema: The schema to use.
@@ -146,7 +148,8 @@
         external_access_integrations=external_access_integrations,
         query_warehouse=query_warehouse,
         spec_overrides=spec_overrides,
-        num_instances=num_instances,
+        target_instances=target_instances,
+        min_instances=min_instances,
         enable_metrics=enable_metrics,
         database=database,
         schema=schema,
@@ -167,7 +170,8 @@ def submit_directory(
     external_access_integrations: Optional[list[str]] = None,
     query_warehouse: Optional[str] = None,
     spec_overrides: Optional[dict[str, Any]] = None,
-    num_instances: Optional[int] = None,
+    target_instances: int = 1,
+    min_instances: int = 1,
     enable_metrics: bool = False,
     database: Optional[str] = None,
     schema: Optional[str] = None,
@@ -187,7 +191,8 @@
         external_access_integrations: A list of external access integrations.
         query_warehouse: The query warehouse to use. Defaults to session warehouse.
         spec_overrides: Custom service specification overrides to apply.
-        num_instances: The number of instances to use for the job. If none specified, single node job is created.
+        target_instances: The number of instances to use for the job. If none specified, single node job is created.
+        min_instances: The minimum number of nodes required to start the job. If none specified, defaults to 1.
         enable_metrics: Whether to enable metrics publishing for the job.
         database: The database to use.
         schema: The schema to use.
@@ -207,7 +212,8 @@
         external_access_integrations=external_access_integrations,
         query_warehouse=query_warehouse,
         spec_overrides=spec_overrides,
-        num_instances=num_instances,
+        target_instances=target_instances,
+        min_instances=min_instances,
         enable_metrics=enable_metrics,
         database=database,
         schema=schema,
@@ -228,7 +234,8 @@ def _submit_job(
     external_access_integrations: Optional[list[str]] = None,
     query_warehouse: Optional[str] = None,
     spec_overrides: Optional[dict[str, Any]] = None,
-    num_instances: Optional[int] = None,
+    target_instances: int = 1,
+    min_instances: int = 1,
     enable_metrics: bool = False,
     database: Optional[str] = None,
     schema: Optional[str] = None,
@@ -250,7 +257,8 @@ def _submit_job(
     external_access_integrations: Optional[list[str]] = None,
     query_warehouse: Optional[str] = None,
     spec_overrides: Optional[dict[str, Any]] = None,
-    num_instances: Optional[int] = None,
+    target_instances: int = 1,
+    min_instances: int = 1,
     enable_metrics: bool = False,
     database: Optional[str] = None,
     schema: Optional[str] = None,
@@ -267,7 +275,7 @@ def _submit_job(
         # TODO: Log lengths of args, env_vars, and spec_overrides values
         "pip_requirements",
         "external_access_integrations",
-        "num_instances",
+        "target_instances",
         "enable_metrics",
     ],
 )
@@ -283,7 +291,8 @@ def _submit_job(
     external_access_integrations: Optional[list[str]] = None,
     query_warehouse: Optional[str] = None,
     spec_overrides: Optional[dict[str, Any]] = None,
-    num_instances: Optional[int] = None,
+    target_instances: int = 1,
+    min_instances: int = 1,
     enable_metrics: bool = False,
     database: Optional[str] = None,
     schema: Optional[str] = None,
@@ -303,7 +312,8 @@ def _submit_job(
         external_access_integrations: A list of external access integrations.
         query_warehouse: The query warehouse to use. Defaults to session warehouse.
         spec_overrides: Custom service specification overrides to apply.
-        num_instances: The number of instances to use for the job. If none specified, single node job is created.
+        target_instances: The number of instances to use for the job. If none specified, single node job is created.
+        min_instances: The minimum number of nodes required to start the job. If none specified, defaults to 1.
         enable_metrics: Whether to enable metrics publishing for the job.
         database: The database to use.
         schema: The schema to use.
@@ -316,13 +326,12 @@ def _submit_job(
         RuntimeError: If required Snowflake features are not enabled.
         ValueError: If database or schema value(s) are invalid
     """
-    # Display warning about PrPr parameters
-    if num_instances is not None:
-        logger.warning(
-            "_submit_job() parameter 'num_instances' is in private preview since 1.8.2. Do not use it in production.",
-        )
     if database and not schema:
         raise ValueError("Schema must be specified if database is specified.")
+    if target_instances < 1 or min_instances < 1:
+        raise ValueError("target_instances and min_instances must be greater than 0.")
+    if min_instances > target_instances:
+        raise ValueError("min_instances must be less than or equal to target_instances.")
 
     session = session or get_active_session()
 
@@ -350,7 +359,8 @@ def _submit_job(
         compute_pool=compute_pool,
         payload=uploaded_payload,
         args=args,
-        num_instances=num_instances,
+        target_instances=target_instances,
+        min_instances=min_instances,
         enable_metrics=enable_metrics,
     )
     spec_overrides = spec_utils.generate_spec_overrides(
@@ -381,9 +391,9 @@ def _submit_job(
     if query_warehouse:
         query.append("QUERY_WAREHOUSE = IDENTIFIER(?)")
         params.append(query_warehouse)
-    if num_instances:
+    if target_instances > 1:
         query.append("REPLICAS = ?")
-        params.append(num_instances)
+        params.append(target_instances)
 
     # Submit job
     query_text = "\n".join(line for line in query if line)
snowflake/ml/model/_client/ops/service_ops.py CHANGED
@@ -125,6 +125,7 @@ class ServiceOperator:
             stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
         else:
             stage_path = None
+        self._model_deployment_spec.clear()
         self._model_deployment_spec.add_model_spec(
             database_name=database_name,
             schema_name=schema_name,
@@ -168,7 +169,7 @@ class ServiceOperator:
             schema_name=service_schema_name,
             service_name=service_name,
             service_status_list_if_exists=[
-                service_sql.ServiceStatus.READY,
+                service_sql.ServiceStatus.RUNNING,
                 service_sql.ServiceStatus.SUSPENDING,
                 service_sql.ServiceStatus.SUSPENDED,
             ],
@@ -331,7 +332,7 @@ class ServiceOperator:
                 include_message=True,
                 statement_params=statement_params,
             )
-            if (service_status != service_sql.ServiceStatus.READY) or (
+            if (service_status != service_sql.ServiceStatus.RUNNING) or (
                 service_status != service_log_meta.service_status
             ):
                 service_log_meta.service_status = service_status
@@ -428,7 +429,7 @@ class ServiceOperator:
         if service_status_list_if_exists is None:
             service_status_list_if_exists = [
                 service_sql.ServiceStatus.PENDING,
-                service_sql.ServiceStatus.READY,
+                service_sql.ServiceStatus.RUNNING,
                 service_sql.ServiceStatus.SUSPENDING,
                 service_sql.ServiceStatus.SUSPENDED,
                 service_sql.ServiceStatus.DONE,
@@ -538,6 +539,7 @@ class ServiceOperator:
         )
 
         try:
+            self._model_deployment_spec.clear()
             # save the spec
             self._model_deployment_spec.add_model_spec(
                 database_name=database_name,
snowflake/ml/model/_client/service/model_deployment_spec.py CHANGED
@@ -29,6 +29,17 @@ class ModelDeploymentSpec:
         self.database: Optional[sql_identifier.SqlIdentifier] = None
         self.schema: Optional[sql_identifier.SqlIdentifier] = None
 
+    def clear(self) -> None:
+        """Reset the deployment spec to its initial state."""
+        self._models = []
+        self._image_build = None
+        self._service = None
+        self._job = None
+        self._model_loggings = None
+        self._inference_spec = {}
+        self.database = None
+        self.schema = None
+
     def add_model_spec(
         self,
         database_name: sql_identifier.SqlIdentifier,
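ServiceOperator holds a single ModelDeploymentSpec instance, and 1.8.5 now calls clear() before building each deployment so state accumulated by a previous add_model_spec(...) cannot leak into the next one. A schematic sketch of that reset-before-reuse pattern; the class and fields below are simplified stand-ins, not the library's actual spec:

```python
from typing import Optional


class ReusableSpec:
    """Schematic stand-in for the ModelDeploymentSpec clear()/add pattern."""

    def __init__(self) -> None:
        self.clear()

    def clear(self) -> None:
        # Reset all accumulated state, mirroring ModelDeploymentSpec.clear() above.
        self.models: list[str] = []
        self.service: Optional[str] = None

    def add_model_spec(self, model_name: str) -> None:
        self.models.append(model_name)


spec = ReusableSpec()
for model in ("MODEL_A", "MODEL_B"):
    spec.clear()                   # what ServiceOperator now does before each deployment
    spec.add_model_spec(model)
    assert spec.models == [model]  # no carry-over from the previous iteration
```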
snowflake/ml/model/_client/sql/model_version.py CHANGED
@@ -293,7 +293,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
         if snowpark_utils.is_in_stored_procedure():  # type: ignore[no-untyped-call]
             options = {"parallel": 10}
             cursor = self._session._conn._cursor
-            cursor._download(stage_location_url, str(target_path), options)  # type: ignore[union-attr]
+            cursor._download(stage_location_url, str(target_path), options)
             cursor.fetchall()
         else:
             query_result_checker.SqlResultValidator(
snowflake/ml/model/_client/sql/service.py CHANGED
@@ -1,5 +1,4 @@
 import enum
-import json
 import textwrap
 from typing import Any, Optional, Union
 
@@ -15,22 +14,25 @@ from snowflake.snowpark import dataframe, functions as F, row, types as spt
 from snowflake.snowpark._internal import utils as snowpark_utils
 
 
+# The enum comes from https://docs.snowflake.com/en/sql-reference/sql/show-service-containers-in-service#output
+# except UNKNOWN
 class ServiceStatus(enum.Enum):
     UNKNOWN = "UNKNOWN"  # status is unknown because we have not received enough data from K8s yet.
     PENDING = "PENDING"  # resource set is being created, can't be used yet
-    READY = "READY"  # resource set has been deployed.
     SUSPENDING = "SUSPENDING"  # the service is set to suspended but the resource set is still in deleting state
     SUSPENDED = "SUSPENDED"  # the service is suspended and the resource set is deleted
     DELETING = "DELETING"  # resource set is being deleted
     FAILED = "FAILED"  # resource set has failed and cannot be used anymore
     DONE = "DONE"  # resource set has finished running
-    NOT_FOUND = "NOT_FOUND"  # not found or deleted
     INTERNAL_ERROR = "INTERNAL_ERROR"  # there was an internal service error.
+    RUNNING = "RUNNING"
+    DELETED = "DELETED"
 
 
 class ServiceSQLClient(_base._BaseSQLClient):
     MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME = "name"
     MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME = "ingress_url"
+    SERVICE_STATUS = "service_status"
 
     def build_model_container(
         self,
@@ -199,22 +201,17 @@ class ServiceSQLClient(_base._BaseSQLClient):
         include_message: bool = False,
         statement_params: Optional[dict[str, Any]] = None,
     ) -> tuple[ServiceStatus, Optional[str]]:
-        system_func = "SYSTEM$GET_SERVICE_STATUS"
-        rows = (
-            query_result_checker.SqlResultValidator(
-                self._session,
-                f"CALL {system_func}('{self.fully_qualified_object_name(database_name, schema_name, service_name)}')",
-                statement_params=statement_params,
-            )
-            .has_dimensions(expected_rows=1, expected_cols=1)
-            .validate()
-        )
-        metadata = json.loads(rows[0][system_func])[0]
-        if metadata and metadata["status"]:
-            service_status = ServiceStatus(metadata["status"])
-            message = metadata["message"] if include_message else None
-            return service_status, message
-        return ServiceStatus.UNKNOWN, None
+        fully_qualified_object_name = self.fully_qualified_object_name(database_name, schema_name, service_name)
+        query = f"SHOW SERVICE CONTAINERS IN SERVICE {fully_qualified_object_name}"
+        rows = self._session.sql(query).collect(statement_params=statement_params)
+        if len(rows) == 0:
+            return ServiceStatus.UNKNOWN, None
+        row = rows[0]
+        service_status = row[ServiceSQLClient.SERVICE_STATUS]
+        message = row["message"] if include_message else None
+        if not isinstance(service_status, ServiceStatus):
+            return ServiceStatus.UNKNOWN, message
+        return ServiceStatus(service_status), message
 
     def drop_service(
         self,
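service.py now derives status from SHOW SERVICE CONTAINERS IN SERVICE instead of SYSTEM$GET_SERVICE_STATUS, and the ServiceStatus enum drops READY/NOT_FOUND in favor of RUNNING/DELETED, so any caller that compared against READY must now check RUNNING. A small sketch of mapping the raw service_status column defensively; the fall-back-to-UNKNOWN choice is this sketch's own convention, loosely following the rewritten client method:

```python
from snowflake.ml.model._client.sql.service import ServiceStatus


def to_service_status(raw: str) -> ServiceStatus:
    # SHOW SERVICE CONTAINERS IN SERVICE exposes a service_status string column;
    # map unrecognized values to UNKNOWN instead of raising.
    try:
        return ServiceStatus(raw)
    except ValueError:
        return ServiceStatus.UNKNOWN


# READY no longer exists in 1.8.5; RUNNING is the healthy state to poll for.
assert to_service_status("RUNNING") is ServiceStatus.RUNNING
assert to_service_status("READY") is ServiceStatus.UNKNOWN
```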
snowflake/ml/model/_model_composer/model_composer.py CHANGED
@@ -188,7 +188,9 @@ class ModelComposer:
         if not options:
             options = model_types.BaseModelSaveOption()
 
-        if not snowpark_utils.is_in_stored_procedure():  # type: ignore[no-untyped-call]
+        if not snowpark_utils.is_in_stored_procedure() and target_platforms != [  # type: ignore[no-untyped-call]
+            model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES  # no information schema check for SPCS-only models
+        ]:
             snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
                 self.session,
                 reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={snowml_version.VERSION}")],
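ModelComposer now skips the information-schema lookup for a matching snowflake-ml-python package when the model targets only Snowpark Container Services. A hedged sketch of a registry call that exercises this branch; the session, model object, and names are placeholders, and the target_platforms parameter on log_model is assumed from recent releases rather than shown in this diff:

```python
# Placeholders throughout; only the SPCS-only behavior is taken from this diff.
from snowflake.ml.registry import Registry


def log_spcs_only_model(session, model) -> None:
    reg = Registry(session=session)
    reg.log_model(
        model,
        model_name="MY_SPCS_MODEL",
        version_name="V1",
        # With only SNOWPARK_CONTAINER_SERVICES targeted, 1.8.5 no longer queries
        # the information schema for a matching snowflake-ml-python version.
        target_platforms=["SNOWPARK_CONTAINER_SERVICES"],
    )
```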
snowflake/ml/model/_packager/model_handlers/sklearn.py CHANGED
@@ -216,7 +216,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
                     explain_fn=cls._build_explain_fn(model, background_data, input_signature),
                     output_feature_names=transformed_background_data.columns,
                 )
-            except ValueError:
+            except Exception:
                 if kwargs.get("enable_explainability", None):
                     # user explicitly enabled explainability, so we should raise the error
                     raise ValueError(
snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py CHANGED
@@ -12,7 +12,7 @@ REQUIREMENTS = [
     "importlib_resources>=6.1.1, <7",
     "numpy>=1.23,<2",
     "packaging>=20.9,<25",
-    "pandas>=1.0.0,<3",
+    "pandas>=2.1.4,<3",
     "pyarrow",
     "pydantic>=2.8.2, <3",
     "pyjwt>=2.0.0, <3",
@@ -24,9 +24,10 @@ REQUIREMENTS = [
     "scikit-learn<1.6",
     "scipy>=1.9,<2",
     "shap>=0.46.0,<1",
-    "snowflake-connector-python>=3.14.0,<4",
+    "snowflake-connector-python>=3.15.0,<4",
     "snowflake-snowpark-python>=1.17.0,<2,!=1.26.0",
     "snowflake.core>=1.0.2,<2",
     "sqlparse>=0.4,<1",
     "typing-extensions>=4.1.0,<5",
+    "xgboost>=1.7.3,<3",
 ]