snowflake-ml-python 1.8.4__py3-none-any.whl → 1.8.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/telemetry.py +42 -13
- snowflake/ml/data/data_connector.py +1 -1
- snowflake/ml/jobs/_utils/constants.py +9 -0
- snowflake/ml/jobs/_utils/interop_utils.py +1 -1
- snowflake/ml/jobs/_utils/payload_utils.py +12 -4
- snowflake/ml/jobs/_utils/scripts/constants.py +6 -0
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +85 -2
- snowflake/ml/jobs/_utils/spec_utils.py +7 -5
- snowflake/ml/jobs/decorators.py +7 -3
- snowflake/ml/jobs/job.py +158 -25
- snowflake/ml/jobs/manager.py +29 -19
- snowflake/ml/model/_client/ops/service_ops.py +5 -3
- snowflake/ml/model/_client/service/model_deployment_spec.py +11 -0
- snowflake/ml/model/_client/sql/model_version.py +1 -1
- snowflake/ml/model/_client/sql/service.py +16 -19
- snowflake/ml/model/_model_composer/model_composer.py +3 -1
- snowflake/ml/model/_packager/model_handlers/sklearn.py +1 -1
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -2
- snowflake/ml/monitoring/explain_visualize.py +160 -22
- snowflake/ml/utils/connection_params.py +8 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.5.dist-info}/METADATA +27 -9
- {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.5.dist-info}/RECORD +26 -26
- {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.5.dist-info}/WHEEL +1 -1
- {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.5.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.5.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/job.py
CHANGED
@@ -1,3 +1,5 @@
+import logging
+import os
 import time
 from functools import cached_property
 from typing import Any, Generic, Literal, Optional, TypeVar, Union, cast, overload
@@ -16,6 +18,8 @@ TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "INTERNAL_ERROR"}
 
 T = TypeVar("T")
 
+logger = logging.getLogger(__name__)
+
 
 class MLJob(Generic[T]):
     def __init__(
@@ -36,8 +40,15 @@ class MLJob(Generic[T]):
         return identifier.parse_schema_level_object_identifier(self.id)[-1]
 
     @cached_property
-    def
-        return
+    def target_instances(self) -> int:
+        return _get_target_instances(self._session, self.id)
+
+    @cached_property
+    def min_instances(self) -> int:
+        try:
+            return int(self._container_spec["env"].get(constants.MIN_INSTANCES_ENV_VAR, 1))
+        except TypeError:
+            return 1
 
     @property
     def id(self) -> str:
@@ -52,6 +63,12 @@ class MLJob(Generic[T]):
         self._status = _get_status(self._session, self.id)
         return self._status
 
+    @cached_property
+    def _compute_pool(self) -> str:
+        """Get the job's compute pool name."""
+        row = _get_service_info(self._session, self.id)
+        return cast(str, row["compute_pool"])
+
     @property
     def _service_spec(self) -> dict[str, Any]:
         """Get the job's service spec."""
@@ -82,15 +99,34 @@ class MLJob(Generic[T]):
         return f"{self._stage_path}/{result_path}"
 
     @overload
-    def get_logs(
+    def get_logs(
+        self,
+        limit: int = -1,
+        instance_id: Optional[int] = None,
+        *,
+        as_list: Literal[True],
+        verbose: bool = constants.DEFAULT_VERBOSE_LOG,
+    ) -> list[str]:
         ...
 
     @overload
-    def get_logs(
+    def get_logs(
+        self,
+        limit: int = -1,
+        instance_id: Optional[int] = None,
+        *,
+        as_list: Literal[False] = False,
+        verbose: bool = constants.DEFAULT_VERBOSE_LOG,
+    ) -> str:
         ...
 
     def get_logs(
-        self,
+        self,
+        limit: int = -1,
+        instance_id: Optional[int] = None,
+        *,
+        as_list: bool = False,
+        verbose: bool = constants.DEFAULT_VERBOSE_LOG,
     ) -> Union[str, list[str]]:
         """
         Return the job's execution logs.
@@ -100,17 +136,20 @@ class MLJob(Generic[T]):
             instance_id: Optional instance ID to get logs from a specific instance.
                 If not provided, returns logs from the head node.
             as_list: If True, returns logs as a list of lines. Otherwise, returns logs as a single string.
+            verbose: Whether to return the full log or just the user log.
 
         Returns:
             The job's execution logs.
         """
-        logs = _get_logs(self._session, self.id, limit, instance_id)
+        logs = _get_logs(self._session, self.id, limit, instance_id, verbose)
         assert isinstance(logs, str)  # mypy
         if as_list:
             return logs.splitlines()
         return logs
 
-    def show_logs(
+    def show_logs(
+        self, limit: int = -1, instance_id: Optional[int] = None, verbose: bool = constants.DEFAULT_VERBOSE_LOG
+    ) -> None:
         """
         Display the job's execution logs.
 
@@ -118,8 +157,9 @@ class MLJob(Generic[T]):
             limit: The maximum number of lines to display. Negative values are treated as no limit.
             instance_id: Optional instance ID to get logs from a specific instance.
                 If not provided, displays logs from the head node.
+            verbose: Whether to return the full log or just the user log.
         """
-        print(self.get_logs(limit, instance_id, as_list=False))  # noqa: T201: we need to print here.
+        print(self.get_logs(limit, instance_id, as_list=False, verbose=verbose))  # noqa: T201: we need to print here.
 
     @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["timeout"])
     def wait(self, timeout: float = -1) -> types.JOB_STATUS:
@@ -137,7 +177,16 @@ class MLJob(Generic[T]):
         """
         delay = constants.JOB_POLL_INITIAL_DELAY_SECONDS  # Start with 100ms delay
         start_time = time.monotonic()
-
+        warning_shown = False
+        while (status := self.status) not in TERMINAL_JOB_STATUSES:
+            if status == "PENDING" and not warning_shown:
+                pool_info = _get_compute_pool_info(self._session, self._compute_pool)
+                if (pool_info.max_nodes - pool_info.active_nodes) < self.min_instances:
+                    logger.warning(
+                        f"Compute pool busy ({pool_info.active_nodes}/{pool_info.max_nodes} nodes in use)."
+                        " Job execution may be delayed."
+                    )
+                warning_shown = True
             if timeout >= 0 and (elapsed := time.monotonic() - start_time) >= timeout:
                 raise TimeoutError(f"Job {self.name} did not complete within {elapsed} seconds")
             time.sleep(delay)
@@ -195,7 +244,9 @@ def _get_service_spec(session: snowpark.Session, job_id: str) -> dict[str, Any]:
 
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit", "instance_id"])
-def _get_logs(
+def _get_logs(
+    session: snowpark.Session, job_id: str, limit: int = -1, instance_id: Optional[int] = None, verbose: bool = True
+) -> str:
     """
     Retrieve the job's execution logs.
 
@@ -204,24 +255,20 @@ def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1, instance_
         limit: The maximum number of lines to return. Negative values are treated as no limit.
         session: The Snowpark session to use. If none specified, uses active session.
         instance_id: Optional instance ID to get logs from a specific instance.
+        verbose: Whether to return the full log or just the portion between START and END messages.
 
     Returns:
         The job's execution logs.
 
     Raises:
-        SnowparkSQLException: if the container is pending
         RuntimeError: if failed to get head instance_id
-
     """
     # If instance_id is not specified, try to get the head instance ID
     if instance_id is None:
         try:
             instance_id = _get_head_instance_id(session, job_id)
         except RuntimeError:
-
-            "Failed to retrieve job logs. "
-            "Logs may be inaccessible due to job expiration and can be retrieved from Event Table instead."
-            )
+            instance_id = None
 
     # Assemble params: [job_id, instance_id, container_name, (optional) limit]
     params: list[Any] = [
@@ -231,7 +278,6 @@ def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1, instance_
     ]
     if limit > 0:
         params.append(limit)
-
     try:
         (row,) = session.sql(
             f"SELECT SYSTEM$GET_SERVICE_LOGS(?, ?, ?{f', ?' if limit > 0 else ''})",
@@ -239,9 +285,43 @@ def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1, instance_
         ).collect()
     except SnowparkSQLException as e:
         if "Container Status: PENDING" in e.message:
-
-
-
+            logger.warning("Waiting for container to start. Logs will be shown when available.")
+            return ""
+        else:
+            # event table accepts job name, not fully qualified name
+            # cast is to resolve the type check error
+            db, schema, name = identifier.parse_schema_level_object_identifier(job_id)
+            db = cast(str, db or session.get_current_database())
+            schema = cast(str, schema or session.get_current_schema())
+            logs = _get_service_log_from_event_table(
+                session, db, schema, name, limit, instance_id if instance_id else None
+            )
+            if len(logs) == 0:
+                raise RuntimeError(
+                    "No logs were found. Please verify that the database, schema, and job ID are correct."
+                )
+            return os.linesep.join(row[0] for row in logs)
+
+    full_log = str(row[0])
+
+    # If verbose is True, return the complete log
+    if verbose:
+        return full_log
+
+    # Otherwise, extract only the portion between LOG_START_MSG and LOG_END_MSG
+    start_idx = full_log.find(constants.LOG_START_MSG)
+    if start_idx != -1:
+        start_idx += len(constants.LOG_START_MSG)
+    else:
+        # If start message not found, start from the beginning
+        start_idx = 0
+
+    end_idx = full_log.find(constants.LOG_END_MSG, start_idx)
+    if end_idx == -1:
+        # If end message not found, return everything after start
+        end_idx = len(full_log)
+
+    return full_log[start_idx:end_idx].strip()
 
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
@@ -256,13 +336,18 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[in
     Returns:
         Optional[int]: The head instance ID of the job, or None if the head instance has not started yet.
 
-
+    Raises:
         RuntimeError: If the instances died or if some instances disappeared.
+
     """
-
+    try:
+        rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+    except SnowparkSQLException:
+        # service may be deleted
+        raise RuntimeError("Couldn't retrieve instances")
     if not rows:
         return None
-    if
+    if _get_target_instances(session, job_id) > len(rows):
         raise RuntimeError("Couldn't retrieve head instance due to missing instances.")
 
     # Sort by start_time first, then by instance_id
@@ -270,7 +355,6 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[in
         sorted_instances = sorted(rows, key=lambda x: (x["start_time"], int(x["instance_id"])))
     except TypeError:
         raise RuntimeError("Job instance information unavailable.")
-
     head_instance = sorted_instances[0]
     if not head_instance["start_time"]:
         # If head instance hasn't started yet, return None
@@ -281,12 +365,61 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[in
         return 0
 
 
+def _get_service_log_from_event_table(
+    session: snowpark.Session, database: str, schema: str, name: str, limit: int, instance_id: Optional[int]
+) -> list[Row]:
+    params: list[Any] = [
+        database,
+        schema,
+        name,
+    ]
+    query = [
+        "SELECT VALUE FROM snowflake.telemetry.events_view",
+        'WHERE RESOURCE_ATTRIBUTES:"snow.database.name" = ?',
+        'AND RESOURCE_ATTRIBUTES:"snow.schema.name" = ?',
+        'AND RESOURCE_ATTRIBUTES:"snow.service.name" = ?',
+    ]
+
+    if instance_id:
+        query.append('AND RESOURCE_ATTRIBUTES:"snow.service.container.instance" = ?')
+        params.append(instance_id)
+
+    query.append("AND RECORD_TYPE = 'LOG'")
+    # sort by TIMESTAMP; although OBSERVED_TIMESTAMP is for log, it is NONE currently when record_type is log
+    query.append("ORDER BY TIMESTAMP")
+
+    if limit > 0:
+        query.append("LIMIT ?")
+        params.append(limit)
+
+    rows = session.sql(
+        "\n".join(line for line in query if line),
+        params=params,
+    ).collect()
+    return rows
+
+
 def _get_service_info(session: snowpark.Session, job_id: str) -> Row:
     (row,) = session.sql("DESCRIBE SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
     return row
 
 
+def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Row:
+    """
+    Check if the compute pool has enough available instances.
+
+    Args:
+        session (Session): The Snowpark session to use.
+        compute_pool (str): The name of the compute pool.
+
+    Returns:
+        Row: The compute pool information.
+    """
+    (pool_info,) = session.sql("SHOW COMPUTE POOLS LIKE ?", params=(compute_pool,)).collect()
+    return pool_info
+
+
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
-def
+def _get_target_instances(session: snowpark.Session, job_id: str) -> int:
     row = _get_service_info(session, job_id)
     return int(row["target_instances"]) if row["target_instances"] else 0
snowflake/ml/jobs/manager.py
CHANGED
@@ -108,7 +108,8 @@ def submit_file(
     external_access_integrations: Optional[list[str]] = None,
     query_warehouse: Optional[str] = None,
     spec_overrides: Optional[dict[str, Any]] = None,
-
+    target_instances: int = 1,
+    min_instances: int = 1,
     enable_metrics: bool = False,
     database: Optional[str] = None,
     schema: Optional[str] = None,
@@ -127,7 +128,8 @@ def submit_file(
         external_access_integrations: A list of external access integrations.
         query_warehouse: The query warehouse to use. Defaults to session warehouse.
         spec_overrides: Custom service specification overrides to apply.
-
+        target_instances: The number of instances to use for the job. If none specified, single node job is created.
+        min_instances: The minimum number of nodes required to start the job. If none specified, defaults to 1.
         enable_metrics: Whether to enable metrics publishing for the job.
         database: The database to use.
         schema: The schema to use.
@@ -146,7 +148,8 @@ def submit_file(
         external_access_integrations=external_access_integrations,
         query_warehouse=query_warehouse,
         spec_overrides=spec_overrides,
-
+        target_instances=target_instances,
+        min_instances=min_instances,
         enable_metrics=enable_metrics,
         database=database,
         schema=schema,
@@ -167,7 +170,8 @@ def submit_directory(
     external_access_integrations: Optional[list[str]] = None,
     query_warehouse: Optional[str] = None,
     spec_overrides: Optional[dict[str, Any]] = None,
-
+    target_instances: int = 1,
+    min_instances: int = 1,
     enable_metrics: bool = False,
     database: Optional[str] = None,
     schema: Optional[str] = None,
@@ -187,7 +191,8 @@ def submit_directory(
         external_access_integrations: A list of external access integrations.
         query_warehouse: The query warehouse to use. Defaults to session warehouse.
         spec_overrides: Custom service specification overrides to apply.
-
+        target_instances: The number of instances to use for the job. If none specified, single node job is created.
+        min_instances: The minimum number of nodes required to start the job. If none specified, defaults to 1.
         enable_metrics: Whether to enable metrics publishing for the job.
         database: The database to use.
         schema: The schema to use.
@@ -207,7 +212,8 @@ def submit_directory(
         external_access_integrations=external_access_integrations,
         query_warehouse=query_warehouse,
         spec_overrides=spec_overrides,
-
+        target_instances=target_instances,
+        min_instances=min_instances,
         enable_metrics=enable_metrics,
         database=database,
         schema=schema,
@@ -228,7 +234,8 @@ def _submit_job(
     external_access_integrations: Optional[list[str]] = None,
     query_warehouse: Optional[str] = None,
     spec_overrides: Optional[dict[str, Any]] = None,
-
+    target_instances: int = 1,
+    min_instances: int = 1,
     enable_metrics: bool = False,
     database: Optional[str] = None,
     schema: Optional[str] = None,
@@ -250,7 +257,8 @@ def _submit_job(
     external_access_integrations: Optional[list[str]] = None,
     query_warehouse: Optional[str] = None,
     spec_overrides: Optional[dict[str, Any]] = None,
-
+    target_instances: int = 1,
+    min_instances: int = 1,
     enable_metrics: bool = False,
     database: Optional[str] = None,
     schema: Optional[str] = None,
@@ -267,7 +275,7 @@ def _submit_job(
         # TODO: Log lengths of args, env_vars, and spec_overrides values
         "pip_requirements",
         "external_access_integrations",
-        "
+        "target_instances",
         "enable_metrics",
     ],
 )
@@ -283,7 +291,8 @@ def _submit_job(
     external_access_integrations: Optional[list[str]] = None,
     query_warehouse: Optional[str] = None,
     spec_overrides: Optional[dict[str, Any]] = None,
-
+    target_instances: int = 1,
+    min_instances: int = 1,
     enable_metrics: bool = False,
     database: Optional[str] = None,
     schema: Optional[str] = None,
@@ -303,7 +312,8 @@ def _submit_job(
         external_access_integrations: A list of external access integrations.
         query_warehouse: The query warehouse to use. Defaults to session warehouse.
         spec_overrides: Custom service specification overrides to apply.
-
+        target_instances: The number of instances to use for the job. If none specified, single node job is created.
+        min_instances: The minimum number of nodes required to start the job. If none specified, defaults to 1.
         enable_metrics: Whether to enable metrics publishing for the job.
         database: The database to use.
         schema: The schema to use.
@@ -316,13 +326,12 @@ def _submit_job(
         RuntimeError: If required Snowflake features are not enabled.
         ValueError: If database or schema value(s) are invalid
     """
-    # Display warning about PrPr parameters
-    if num_instances is not None:
-        logger.warning(
-            "_submit_job() parameter 'num_instances' is in private preview since 1.8.2. Do not use it in production.",
-        )
     if database and not schema:
         raise ValueError("Schema must be specified if database is specified.")
+    if target_instances < 1 or min_instances < 1:
+        raise ValueError("target_instances and min_instances must be greater than 0.")
+    if min_instances > target_instances:
+        raise ValueError("min_instances must be less than or equal to target_instances.")
 
     session = session or get_active_session()
 
@@ -350,7 +359,8 @@ def _submit_job(
         compute_pool=compute_pool,
         payload=uploaded_payload,
         args=args,
-
+        target_instances=target_instances,
+        min_instances=min_instances,
         enable_metrics=enable_metrics,
     )
     spec_overrides = spec_utils.generate_spec_overrides(
@@ -381,9 +391,9 @@ def _submit_job(
     if query_warehouse:
         query.append("QUERY_WAREHOUSE = IDENTIFIER(?)")
         params.append(query_warehouse)
-    if
+    if target_instances > 1:
         query.append("REPLICAS = ?")
-        params.append(
+        params.append(target_instances)
 
     # Submit job
     query_text = "\n".join(line for line in query if line)
snowflake/ml/model/_client/ops/service_ops.py
CHANGED
@@ -125,6 +125,7 @@ class ServiceOperator:
             stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
         else:
             stage_path = None
+        self._model_deployment_spec.clear()
         self._model_deployment_spec.add_model_spec(
             database_name=database_name,
             schema_name=schema_name,
@@ -168,7 +169,7 @@ class ServiceOperator:
             schema_name=service_schema_name,
             service_name=service_name,
             service_status_list_if_exists=[
-                service_sql.ServiceStatus.
+                service_sql.ServiceStatus.RUNNING,
                 service_sql.ServiceStatus.SUSPENDING,
                 service_sql.ServiceStatus.SUSPENDED,
             ],
@@ -331,7 +332,7 @@ class ServiceOperator:
             include_message=True,
             statement_params=statement_params,
         )
-        if (service_status != service_sql.ServiceStatus.
+        if (service_status != service_sql.ServiceStatus.RUNNING) or (
             service_status != service_log_meta.service_status
         ):
             service_log_meta.service_status = service_status
@@ -428,7 +429,7 @@ class ServiceOperator:
         if service_status_list_if_exists is None:
             service_status_list_if_exists = [
                 service_sql.ServiceStatus.PENDING,
-                service_sql.ServiceStatus.
+                service_sql.ServiceStatus.RUNNING,
                 service_sql.ServiceStatus.SUSPENDING,
                 service_sql.ServiceStatus.SUSPENDED,
                 service_sql.ServiceStatus.DONE,
@@ -538,6 +539,7 @@ class ServiceOperator:
         )
 
         try:
+            self._model_deployment_spec.clear()
             # save the spec
             self._model_deployment_spec.add_model_spec(
                 database_name=database_name,
snowflake/ml/model/_client/service/model_deployment_spec.py
CHANGED
@@ -29,6 +29,17 @@ class ModelDeploymentSpec:
         self.database: Optional[sql_identifier.SqlIdentifier] = None
         self.schema: Optional[sql_identifier.SqlIdentifier] = None
 
+    def clear(self) -> None:
+        """Reset the deployment spec to its initial state."""
+        self._models = []
+        self._image_build = None
+        self._service = None
+        self._job = None
+        self._model_loggings = None
+        self._inference_spec = {}
+        self.database = None
+        self.schema = None
+
     def add_model_spec(
         self,
         database_name: sql_identifier.SqlIdentifier,
snowflake/ml/model/_client/sql/model_version.py
CHANGED
@@ -293,7 +293,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
         if snowpark_utils.is_in_stored_procedure():  # type: ignore[no-untyped-call]
             options = {"parallel": 10}
             cursor = self._session._conn._cursor
-            cursor._download(stage_location_url, str(target_path), options)
+            cursor._download(stage_location_url, str(target_path), options)
             cursor.fetchall()
         else:
             query_result_checker.SqlResultValidator(
snowflake/ml/model/_client/sql/service.py
CHANGED
@@ -1,5 +1,4 @@
 import enum
-import json
 import textwrap
 from typing import Any, Optional, Union
 
@@ -15,22 +14,25 @@ from snowflake.snowpark import dataframe, functions as F, row, types as spt
 from snowflake.snowpark._internal import utils as snowpark_utils
 
 
+# The enum comes from https://docs.snowflake.com/en/sql-reference/sql/show-service-containers-in-service#output
+# except UNKNOWN
 class ServiceStatus(enum.Enum):
     UNKNOWN = "UNKNOWN"  # status is unknown because we have not received enough data from K8s yet.
     PENDING = "PENDING"  # resource set is being created, can't be used yet
-    READY = "READY"  # resource set has been deployed.
     SUSPENDING = "SUSPENDING"  # the service is set to suspended but the resource set is still in deleting state
     SUSPENDED = "SUSPENDED"  # the service is suspended and the resource set is deleted
     DELETING = "DELETING"  # resource set is being deleted
     FAILED = "FAILED"  # resource set has failed and cannot be used anymore
     DONE = "DONE"  # resource set has finished running
-    NOT_FOUND = "NOT_FOUND"  # not found or deleted
     INTERNAL_ERROR = "INTERNAL_ERROR"  # there was an internal service error.
+    RUNNING = "RUNNING"
+    DELETED = "DELETED"
 
 
 class ServiceSQLClient(_base._BaseSQLClient):
     MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME = "name"
     MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME = "ingress_url"
+    SERVICE_STATUS = "service_status"
 
     def build_model_container(
         self,
@@ -199,22 +201,17 @@ class ServiceSQLClient(_base._BaseSQLClient):
         include_message: bool = False,
         statement_params: Optional[dict[str, Any]] = None,
     ) -> tuple[ServiceStatus, Optional[str]]:
-
-
-
-
-
-
-
-
-
-
-
-        if metadata and metadata["status"]:
-            service_status = ServiceStatus(metadata["status"])
-            message = metadata["message"] if include_message else None
-            return service_status, message
-        return ServiceStatus.UNKNOWN, None
+        fully_qualified_object_name = self.fully_qualified_object_name(database_name, schema_name, service_name)
+        query = f"SHOW SERVICE CONTAINERS IN SERVICE {fully_qualified_object_name}"
+        rows = self._session.sql(query).collect(statement_params=statement_params)
+        if len(rows) == 0:
+            return ServiceStatus.UNKNOWN, None
+        row = rows[0]
+        service_status = row[ServiceSQLClient.SERVICE_STATUS]
+        message = row["message"] if include_message else None
+        if not isinstance(service_status, ServiceStatus):
+            return ServiceStatus.UNKNOWN, message
+        return ServiceStatus(service_status), message
 
     def drop_service(
         self,
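get_service_container_statuses now derives status from SHOW SERVICE CONTAINERS IN SERVICE instead of parsing DESCRIBE SERVICE metadata (hence the dropped import json and the new RUNNING / DELETED enum members). A small sketch of the same query issued directly through a Snowpark session; the helper name is illustrative, while the statement and the service_status column are taken from the diff:

from snowflake.snowpark import Session

def container_status(session: Session, fq_service_name: str) -> str:
    # Read the per-container service_status column, as the patched client does.
    rows = session.sql(f"SHOW SERVICE CONTAINERS IN SERVICE {fq_service_name}").collect()
    # Mirror the fallback above: report UNKNOWN when no container rows come back.
    return str(rows[0]["service_status"]) if rows else "UNKNOWN"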
snowflake/ml/model/_model_composer/model_composer.py
CHANGED
@@ -188,7 +188,9 @@ class ModelComposer:
         if not options:
             options = model_types.BaseModelSaveOption()
 
-        if not snowpark_utils.is_in_stored_procedure()
+        if not snowpark_utils.is_in_stored_procedure() and target_platforms != [  # type: ignore[no-untyped-call]
+            model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES  # no information schema check for SPCS-only models
+        ]:
             snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
                 self.session,
                 reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={snowml_version.VERSION}")],
snowflake/ml/model/_packager/model_handlers/sklearn.py
CHANGED
@@ -216,7 +216,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
                     explain_fn=cls._build_explain_fn(model, background_data, input_signature),
                     output_feature_names=transformed_background_data.columns,
                 )
-            except
+            except Exception:
                 if kwargs.get("enable_explainability", None):
                     # user explicitly enabled explainability, so we should raise the error
                     raise ValueError(
snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py
CHANGED
@@ -12,7 +12,7 @@ REQUIREMENTS = [
     "importlib_resources>=6.1.1, <7",
     "numpy>=1.23,<2",
     "packaging>=20.9,<25",
-    "pandas>=1.
+    "pandas>=2.1.4,<3",
     "pyarrow",
     "pydantic>=2.8.2, <3",
     "pyjwt>=2.0.0, <3",
@@ -24,9 +24,10 @@ REQUIREMENTS = [
     "scikit-learn<1.6",
     "scipy>=1.9,<2",
     "shap>=0.46.0,<1",
-    "snowflake-connector-python>=3.
+    "snowflake-connector-python>=3.15.0,<4",
     "snowflake-snowpark-python>=1.17.0,<2,!=1.26.0",
     "snowflake.core>=1.0.2,<2",
     "sqlparse>=0.4,<1",
     "typing-extensions>=4.1.0,<5",
+    "xgboost>=1.7.3,<3",
 ]