snowflake-ml-python 1.16.0__py3-none-any.whl → 1.18.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/human_readable_id/adjectives.txt +5 -5
- snowflake/ml/_internal/human_readable_id/animals.txt +3 -3
- snowflake/ml/_internal/telemetry.py +3 -2
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +17 -12
- snowflake/ml/experiment/callback/keras.py +3 -0
- snowflake/ml/experiment/callback/lightgbm.py +3 -0
- snowflake/ml/experiment/callback/xgboost.py +3 -0
- snowflake/ml/experiment/experiment_tracking.py +19 -7
- snowflake/ml/feature_store/feature_store.py +236 -61
- snowflake/ml/jobs/__init__.py +4 -0
- snowflake/ml/jobs/_interop/__init__.py +0 -0
- snowflake/ml/jobs/_interop/data_utils.py +124 -0
- snowflake/ml/jobs/_interop/dto_schema.py +95 -0
- snowflake/ml/jobs/{_utils/interop_utils.py → _interop/exception_utils.py} +49 -178
- snowflake/ml/jobs/_interop/legacy.py +225 -0
- snowflake/ml/jobs/_interop/protocols.py +471 -0
- snowflake/ml/jobs/_interop/results.py +51 -0
- snowflake/ml/jobs/_interop/utils.py +144 -0
- snowflake/ml/jobs/_utils/constants.py +16 -2
- snowflake/ml/jobs/_utils/feature_flags.py +37 -5
- snowflake/ml/jobs/_utils/payload_utils.py +8 -2
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +139 -102
- snowflake/ml/jobs/_utils/spec_utils.py +2 -1
- snowflake/ml/jobs/_utils/stage_utils.py +4 -0
- snowflake/ml/jobs/_utils/types.py +15 -0
- snowflake/ml/jobs/job.py +186 -40
- snowflake/ml/jobs/manager.py +48 -39
- snowflake/ml/model/__init__.py +19 -0
- snowflake/ml/model/_client/model/batch_inference_specs.py +63 -0
- snowflake/ml/model/_client/model/inference_engine_utils.py +1 -5
- snowflake/ml/model/_client/model/model_version_impl.py +168 -18
- snowflake/ml/model/_client/ops/model_ops.py +4 -0
- snowflake/ml/model/_client/ops/service_ops.py +3 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +3 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +3 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +11 -3
- snowflake/ml/model/_model_composer/model_method/utils.py +28 -0
- snowflake/ml/model/_packager/model_env/model_env.py +22 -5
- snowflake/ml/model/_packager/model_handlers/_utils.py +70 -0
- snowflake/ml/model/_packager/model_handlers/prophet.py +566 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +8 -0
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +7 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
- snowflake/ml/model/type_hints.py +16 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +5 -5
- snowflake/ml/modeling/metrics/metrics_utils.py +9 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/METADATA +50 -4
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/RECORD +54 -45
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/job.py
CHANGED
@@ -12,12 +12,19 @@ from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import identifier
 from snowflake.ml._internal.utils.mixins import SerializableSessionMixin
-from snowflake.ml.jobs.
+from snowflake.ml.jobs._interop import results as interop_result, utils as interop_utils
+from snowflake.ml.jobs._utils import (
+    constants,
+    payload_utils,
+    query_helper,
+    stage_utils,
+    types,
+)
 from snowflake.snowpark import Row, context as sp_context
 from snowflake.snowpark.exceptions import SnowparkSQLException
 
 _PROJECT = "MLJob"
-TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "CANCELLED", "INTERNAL_ERROR"}
+TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "CANCELLED", "INTERNAL_ERROR", "DELETED"}
 
 T = TypeVar("T")
 
@@ -36,7 +43,12 @@ class MLJob(Generic[T], SerializableSessionMixin):
         self._session = session or sp_context.get_active_session()
 
         self._status: types.JOB_STATUS = "PENDING"
-        self._result: Optional[
+        self._result: Optional[interop_result.ExecutionResult] = None
+
+    @cached_property
+    def _service_info(self) -> types.ServiceInfo:
+        """Get the job's service info."""
+        return _resolve_service_info(self.id, self._session)
 
     @cached_property
     def name(self) -> str:
@@ -44,7 +56,7 @@ class MLJob(Generic[T], SerializableSessionMixin):
 
     @cached_property
     def target_instances(self) -> int:
-        return
+        return self._service_info.target_instances
 
     @cached_property
     def min_instances(self) -> int:
@@ -69,8 +81,7 @@ class MLJob(Generic[T], SerializableSessionMixin):
     @cached_property
     def _compute_pool(self) -> str:
         """Get the job's compute pool name."""
-
-        return cast(str, row["compute_pool"])
+        return self._service_info.compute_pool
 
     @property
     def _service_spec(self) -> dict[str, Any]:
@@ -82,7 +93,13 @@ class MLJob(Generic[T], SerializableSessionMixin):
     @property
     def _container_spec(self) -> dict[str, Any]:
         """Get the job's main container spec."""
-
+        try:
+            containers = self._service_spec["spec"]["containers"]
+        except SnowparkSQLException as e:
+            if e.sql_error_code == 2003:
+                # If the job is deleted, the service spec is not available
+                return {}
+            raise
         if len(containers) == 1:
             return cast(dict[str, Any], containers[0])
         try:
@@ -103,24 +120,30 @@ class MLJob(Generic[T], SerializableSessionMixin):
         """Get the job's result file location."""
         result_path_str = self._container_spec["env"].get(constants.RESULT_PATH_ENV_VAR)
         if result_path_str is None:
-            raise
-
-
-
-
-
-
-
+            raise NotImplementedError(f"Job {self.name} doesn't have a result path configured")
+
+        return self._transform_path(result_path_str)
+
+    def _transform_path(self, path_str: str) -> str:
+        """Transform a local path within the container to a stage path."""
+        path = payload_utils.resolve_path(path_str)
+        if isinstance(path, stage_utils.StagePath):
+            # Stage paths need no transformation
+            return path.as_posix()
+        if not path.is_absolute():
+            # Assume relative paths are relative to stage mount path
+            return f"{self._stage_path}/{path.as_posix()}"
+
+        # If result path is absolute, rebase it onto the stage mount path
+        # TODO: Rather than matching by name, use the longest mount path which matches
         volume_mounts = self._container_spec["volumeMounts"]
         stage_mount_str = next(v for v in volume_mounts if v.get("name") == constants.STAGE_VOLUME_NAME)["mountPath"]
         stage_mount = Path(stage_mount_str)
         try:
-            relative_path =
+            relative_path = path.relative_to(stage_mount)
             return f"{self._stage_path}/{relative_path.as_posix()}"
         except ValueError:
-            raise ValueError(
-                f"Result path {result_path} is absolute, but should be relative to stage mount {stage_mount}"
-            )
+            raise ValueError(f"Result path {path} is absolute, but should be relative to stage mount {stage_mount}")
 
     @overload
     def get_logs(
@@ -165,7 +188,14 @@ class MLJob(Generic[T], SerializableSessionMixin):
         Returns:
             The job's execution logs.
         """
-        logs = _get_logs(
+        logs = _get_logs(
+            self._session,
+            self.id,
+            limit,
+            instance_id,
+            self._container_spec["name"] if "name" in self._container_spec else constants.DEFAULT_CONTAINER_NAME,
+            verbose,
+        )
         assert isinstance(logs, str)  # mypy
         if as_list:
             return logs.splitlines()
@@ -199,8 +229,22 @@ class MLJob(Generic[T], SerializableSessionMixin):
         Raises:
             TimeoutError: If the job does not complete within the specified timeout.
         """
-        delay = constants.JOB_POLL_INITIAL_DELAY_SECONDS  # Start with 100ms delay
         start_time = time.monotonic()
+        try:
+            # spcs_wait_for() is a synchronous query, it’s more effective to do polling with exponential
+            # backoff. If the job is running for a long time. We want a hybrid option: use spcs_wait_for()
+            # for the first 30 seconds, then switch to polling for long running jobs
+            min_timeout = (
+                int(min(timeout, constants.JOB_SPCS_TIMEOUT_SECONDS))
+                if timeout >= 0
+                else constants.JOB_SPCS_TIMEOUT_SECONDS
+            )
+            query_helper.run_query(self._session, f"call {self.id}!spcs_wait_for('DONE', {min_timeout})")
+            return self.status
+        except SnowparkSQLException:
+            # if the function does not support for this environment
+            pass
+        delay: float = float(constants.JOB_POLL_INITIAL_DELAY_SECONDS)  # Start with 5s delay
         warning_shown = False
         while (status := self.status) not in TERMINAL_JOB_STATUSES:
             elapsed = time.monotonic() - start_time
@@ -218,7 +262,6 @@ class MLJob(Generic[T], SerializableSessionMixin):
             delay = min(delay * 1.2, constants.JOB_POLL_MAX_DELAY_SECONDS)  # Exponential backoff
         return self.status
 
-    @snowpark._internal.utils.private_preview(version="1.8.2")
     @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["timeout"])
     def result(self, timeout: float = -1) -> T:
         """
@@ -237,13 +280,13 @@ class MLJob(Generic[T], SerializableSessionMixin):
         if self._result is None:
             self.wait(timeout)
             try:
-                self._result = interop_utils.
+                self._result = interop_utils.load_result(
+                    self._result_path, session=self._session, path_transform=self._transform_path
+                )
             except Exception as e:
-                raise RuntimeError(f"Failed to retrieve result for job
+                raise RuntimeError(f"Failed to retrieve result for job, error: {e!r}") from e
 
-
-            return cast(T, self._result.result)
-        raise RuntimeError(f"Job execution failed (id={self.name})") from self._result.exception
+        return cast(T, self._result.get_value())
 
     @telemetry.send_api_usage_telemetry(project=_PROJECT)
     def cancel(self) -> None:
@@ -256,22 +299,28 @@ class MLJob(Generic[T], SerializableSessionMixin):
             self._session.sql(f"CALL {self.id}!spcs_cancel_job()").collect()
             logger.debug(f"Cancellation requested for job {self.id}")
         except SnowparkSQLException as e:
-            raise RuntimeError(f"Failed to cancel job
+            raise RuntimeError(f"Failed to cancel job, error: {e!r}") from e
 
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "instance_id"])
 def _get_status(session: snowpark.Session, job_id: str, instance_id: Optional[int] = None) -> types.JOB_STATUS:
     """Retrieve job or job instance execution status."""
-
-
-
-
-
-
-
-
-
-
+    try:
+        if instance_id is not None:
+            # Get specific instance status
+            rows = query_helper.run_query(session, "SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,))
+            for row in rows:
+                if row["instance_id"] == str(instance_id):
+                    return cast(types.JOB_STATUS, row["status"])
+            raise ValueError(f"Instance {instance_id} not found in job {job_id}")
+        else:
+            row = _get_service_info(session, job_id)
+            return cast(types.JOB_STATUS, row["status"])
+    except SnowparkSQLException as e:
+        if e.sql_error_code == 2003:
+            row = _get_service_info_spcs(session, job_id)
+            return cast(types.JOB_STATUS, row["STATUS"])
+        raise
 
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
@@ -542,8 +591,21 @@ def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Row:
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
 def _get_target_instances(session: snowpark.Session, job_id: str) -> int:
-
-
+    try:
+        row = _get_service_info(session, job_id)
+        return int(row["target_instances"])
+    except SnowparkSQLException as e:
+        if e.sql_error_code == 2003:
+            row = _get_service_info_spcs(session, job_id)
+            try:
+                params = json.loads(row["PARAMETERS"])
+                if isinstance(params, dict):
+                    return int(params.get("REPLICAS", 1))
+                else:
+                    return 1
+            except (json.JSONDecodeError, ValueError):
+                return 1
+        raise
 
 
 def _get_logs_spcs(
@@ -581,3 +643,87 @@ def _get_logs_spcs(
     query.append(f" LIMIT {limit};")
     rows = session.sql("\n".join(query)).collect()
     return rows
+
+
+def _get_service_info_spcs(session: snowpark.Session, job_id: str) -> Any:
+    """
+    Retrieve the service info from the SPCS interface.
+
+    Args:
+        session (Session): The Snowpark session to use.
+        job_id (str): The job ID.
+
+    Returns:
+        Any: The service info.
+
+    Raises:
+        SnowparkSQLException: If the job does not exist or is too old to retrieve.
+    """
+    db, schema, name = identifier.parse_schema_level_object_identifier(job_id)
+    db = db or session.get_current_database()
+    schema = schema or session.get_current_schema()
+    rows = query_helper.run_query(
+        session,
+        """
+        select DATABASE_NAME, SCHEMA_NAME, NAME, STATUS, COMPUTE_POOL_NAME, PARAMETERS
+        from table(snowflake.spcs.get_job_history())
+        where database_name = ? and schema_name = ? and name = ?
+        """,
+        params=(db, schema, name),
+    )
+    if rows:
+        return rows[0]
+    else:
+        raise SnowparkSQLException(f"Job {job_id} does not exist or could not be retrieved", sql_error_code=2003)
+
+
+def _resolve_service_info(id: str, session: snowpark.Session) -> types.ServiceInfo:
+    try:
+        row = _get_service_info(session, id)
+    except SnowparkSQLException as e:
+        if e.sql_error_code == 2003:
+            row = _get_service_info_spcs(session, id)
+        else:
+            raise
+    if not row:
+        raise SnowparkSQLException(f"Job {id} does not exist or could not be retrieved", sql_error_code=2003)
+
+    if "compute_pool" in row:
+        compute_pool = row["compute_pool"]
+    elif "COMPUTE_POOL_NAME" in row:
+        compute_pool = row["COMPUTE_POOL_NAME"]
+    else:
+        raise ValueError(f"compute_pool not found in row: {row}")
+
+    if "status" in row:
+        status = row["status"]
+    elif "STATUS" in row:
+        status = row["STATUS"]
+    else:
+        raise ValueError(f"status not found in row: {row}")
+    # Normalize target_instances
+    target_instances: int
+    if "target_instances" in row and row["target_instances"] is not None:
+        try:
+            target_instances = int(row["target_instances"])
+        except (ValueError, TypeError):
+            target_instances = 1
+    elif "PARAMETERS" in row and row["PARAMETERS"]:
+        try:
+            params = json.loads(row["PARAMETERS"])
+            target_instances = int(params.get("REPLICAS", 1)) if isinstance(params, dict) else 1
+        except (json.JSONDecodeError, ValueError, TypeError):
+            target_instances = 1
+    else:
+        target_instances = 1
+
+    database_name = row["database_name"] if "database_name" in row else row["DATABASE_NAME"]
+    schema_name = row["schema_name"] if "schema_name" in row else row["SCHEMA_NAME"]
+
+    return types.ServiceInfo(
+        database_name=database_name,
+        schema_name=schema_name,
+        status=cast(types.JOB_STATUS, status),
+        compute_pool=cast(str, compute_pool),
+        target_instances=target_instances,
+    )
snowflake/ml/jobs/manager.py
CHANGED
@@ -21,6 +21,7 @@ from snowflake.ml.jobs._utils import (
     spec_utils,
     types,
 )
+from snowflake.snowpark._internal import utils as sp_utils
 from snowflake.snowpark.context import get_active_session
 from snowflake.snowpark.exceptions import SnowparkSQLException
 from snowflake.snowpark.functions import coalesce, col, lit, when
@@ -179,8 +180,10 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob
         _ = job._service_spec
         return job
     except SnowparkSQLException as e:
-        if
-
+        if e.sql_error_code == 2003:
+            job = jb.MLJob[Any](job_id, session=session)
+            _ = job.status
+            return job
         raise
 
 
@@ -446,7 +449,7 @@
     Raises:
         ValueError: If database or schema value(s) are invalid
         RuntimeError: If schema is not specified in session context or job submission
-
+        SnowparkSQLException: if failed to upload payload
     """
     session = _ensure_session(session)
 
@@ -512,49 +515,44 @@
         uploaded_payload = payload_utils.JobPayload(
            source, entrypoint=entrypoint, pip_requirements=pip_requirements, additional_payloads=imports
        ).upload(session, stage_path)
-    except
+    except SnowparkSQLException as e:
        if e.sql_error_code == 90106:
            raise RuntimeError(
                "Please specify a schema, either in the session context or as a parameter in the job submission"
            )
        raise
 
-
-    if target_instances > 1:
-        default_spec_overrides = {
-            "spec": {
-                "endpoints": [
-                    {"name": "ray-dashboard-endpoint", "port": 12003, "protocol": "TCP"},
-                ]
-            },
-        }
-        if spec_overrides:
-            spec_overrides = spec_utils.merge_patch(
-                default_spec_overrides, spec_overrides, display_name="spec_overrides"
-            )
-        else:
-            spec_overrides = default_spec_overrides
-
-    if feature_flags.FeatureFlags.USE_SUBMIT_JOB_V2.is_enabled():
+    if feature_flags.FeatureFlags.USE_SUBMIT_JOB_V2.is_enabled(default=True):
         # Add default env vars (extracted from spec_utils.generate_service_spec)
         combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})}
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        try:
+            return _do_submit_job_v2(
+                session=session,
+                payload=uploaded_payload,
+                args=args,
+                env_vars=combined_env_vars,
+                spec_overrides=spec_overrides,
+                compute_pool=compute_pool,
+                job_id=job_id,
+                external_access_integrations=external_access_integrations,
+                query_warehouse=query_warehouse,
+                target_instances=target_instances,
+                min_instances=min_instances,
+                enable_metrics=enable_metrics,
+                use_async=True,
+                runtime_environment=runtime_environment,
+            )
+        except SnowparkSQLException as e:
+            if not (e.sql_error_code == 90237 and sp_utils.is_in_stored_procedure()):  # type: ignore[no-untyped-call]
+                raise
+            # SNOW-2390287: SYSTEM$EXECUTE_ML_JOB() is erroneously blocked in owner's rights
+            # stored procedures. This will be fixed in an upcoming release.
+            logger.warning(
+                "Job submission using V2 failed with error {}. Falling back to V1.".format(
+                    str(e).split("\n", 1)[0],
+                )
+            )
 
     # Fall back to v1
     # Generate service spec
@@ -688,7 +686,7 @@ def _do_submit_job_v2(
     # for the image tag or full image URL, we use that directly
     if runtime_environment:
         spec_options["RUNTIME"] = runtime_environment
-    elif feature_flags.FeatureFlags.
+    elif feature_flags.FeatureFlags.ENABLE_RUNTIME_VERSIONS.is_enabled():
         # when feature flag is enabled, we get the local python version and wrap it in a dict
         # in system function, we can know whether it is python version or image tag or full image URL through the format
         spec_options["RUNTIME"] = json.dumps({"pythonVersion": f"{sys.version_info.major}.{sys.version_info.minor}"})
@@ -699,10 +697,21 @@
         "MIN_INSTANCES": min_instances,
         "ASYNC": use_async,
     }
+    if payload.payload_name:
+        job_options["GENERATE_SUFFIX"] = True
     job_options = {k: v for k, v in job_options.items() if v is not None}
 
     query_template = "CALL SYSTEM$EXECUTE_ML_JOB(?, ?, ?, ?)"
-
+    if job_id:
+        database, schema, _ = identifier.parse_schema_level_object_identifier(job_id)
+    params = [
+        job_id
+        if payload.payload_name is None
+        else identifier.get_schema_level_object_identifier(database, schema, payload.payload_name) + "_",
+        compute_pool,
+        json.dumps(spec_options),
+        json.dumps(job_options),
+    ]
     actual_job_id = query_helper.run_query(session, query_template, params=params)[0][0]
 
     return get_job(actual_job_id, session=session)
snowflake/ml/model/__init__.py
CHANGED
@@ -1,3 +1,6 @@
+import sys
+import warnings
+
 from snowflake.ml.model._client.model.batch_inference_specs import (
     JobSpec,
     OutputSpec,
@@ -18,3 +21,19 @@ __all__ = [
     "SaveMode",
     "Volatility",
 ]
+
+_deprecation_warning_msg_for_3_9 = (
+    "Python 3.9 is deprecated in snowflake-ml-python. " "Please upgrade to Python 3.10 or greater."
+)
+
+warnings.filterwarnings(
+    "once",
+    message=_deprecation_warning_msg_for_3_9,
+)
+
+if sys.version_info.major == 3 and sys.version_info.minor == 9:
+    warnings.warn(
+        _deprecation_warning_msg_for_3_9,
+        category=DeprecationWarning,
+        stacklevel=2,
+    )
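Since DeprecationWarning is hidden by default in most contexts, a user on Python 3.9 who wants to confirm the new notice can record warnings around the first import. This is an illustrative snippet, not part of the package:

import warnings

# On Python 3.9, the first import of snowflake.ml.model should emit the
# deprecation notice added above. No warning is raised on 3.10+, and a module
# that was already imported earlier in the process will not warn again.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    import snowflake.ml.model  # noqa: F401

for w in caught:
    if issubclass(w.category, DeprecationWarning):
        print(w.message)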
snowflake/ml/model/_client/model/batch_inference_specs.py
CHANGED
@@ -19,11 +19,74 @@ class SaveMode(str, Enum):
 
 
 class OutputSpec(BaseModel):
+    """Specification for batch inference output.
+
+    Defines where the inference results should be written and how to handle
+    existing files at the output location.
+
+    Attributes:
+        stage_location (str): The stage path where batch inference results will be saved.
+            This should be a full path including the stage with @ prefix. For example,
+            '@My_DB.PUBLIC.MY_STAGE/someth/path/'. A non-existent directory will be re-created.
+            Only Snowflake internal stages are supported at this moment.
+        mode (SaveMode): The save mode that determines behavior when files already exist
+            at the output location. Defaults to SaveMode.ERROR which raises an error
+            if files exist. Can be set to SaveMode.OVERWRITE to replace existing files.
+
+    Example:
+        >>> output_spec = OutputSpec(
+        ...     stage_location="@My_DB.PUBLIC.MY_STAGE/someth/path/",
+        ...     mode=SaveMode.OVERWRITE
+        ... )
+    """
+
     stage_location: str
     mode: SaveMode = SaveMode.ERROR
 
 
 class JobSpec(BaseModel):
+    """Specification for batch inference job execution.
+
+    Defines the compute resources, job settings, and execution parameters
+    for running batch inference jobs in Snowflake.
+
+    Attributes:
+        image_repo (Optional[str]): Container image repository for the inference job.
+            If not specified, uses the default repository.
+        job_name (Optional[str]): Custom name for the batch inference job.
+            If not provided, a name will be auto-generated in the form of "BATCH_INFERENCE_<UUID>".
+        num_workers (Optional[int]): The number of workers to run the inference service for handling
+            requests in parallel within an instance of the service. By default, it is set to 2*vCPU+1
+            of the node for CPU based inference and 1 for GPU based inference. For GPU based inference,
+            please see best practices before playing with this value.
+        function_name (Optional[str]): Name of the specific function to call for inference.
+            Required when the model has multiple inference functions.
+        force_rebuild (bool): Whether to force rebuilding the container image even if
+            it already exists. Defaults to False.
+        max_batch_rows (int): Maximum number of rows to process in a single batch.
+            Defaults to 1024. Larger values may improve throughput.
+        warehouse (Optional[str]): Snowflake warehouse to use for the batch inference job.
+            If not specified, uses the session's current warehouse.
+        cpu_requests (Optional[str]): The cpu limit for CPU based inference. Can be an integer,
+            fractional or string values. If None, we attempt to utilize all the vCPU of the node.
+        memory_requests (Optional[str]): The memory limit for inference. Can be an integer
+            or a fractional value, but requires a unit (GiB, MiB). If None, we attempt to utilize all
+            the memory of the node.
+        gpu_requests (Optional[str]): The gpu limit for GPU based inference. Can be integer or
+            string values. Use CPU if None.
+        replicas (Optional[int]): Number of job replicas to run for high availability.
+            If not specified, defaults to 1 replica.
+
+    Example:
+        >>> job_spec = JobSpec(
+        ...     job_name="my_inference_job",
+        ...     num_workers=4,
+        ...     cpu_requests="2",
+        ...     memory_requests="8Gi",
+        ...     max_batch_rows=2048
+        ... )
+    """
+
     image_repo: Optional[str] = None
     job_name: Optional[str] = None
     num_workers: Optional[int] = None
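For reference, the two specs can be constructed exactly as their new docstrings describe. The values below are taken from the docstring examples above and are illustrative only; how the specs are then passed to batch inference is defined elsewhere in this release (e.g. model_version_impl.py) and is not shown here.

from snowflake.ml.model import JobSpec, OutputSpec, SaveMode

# Output destination and overwrite behavior for the batch inference results.
output_spec = OutputSpec(
    stage_location="@My_DB.PUBLIC.MY_STAGE/someth/path/",
    mode=SaveMode.OVERWRITE,
)

# Compute and execution settings for the batch inference job.
job_spec = JobSpec(
    job_name="my_inference_job",
    num_workers=4,
    cpu_requests="2",
    memory_requests="8Gi",
    max_batch_rows=2048,
)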
snowflake/ml/model/_client/model/inference_engine_utils.py
CHANGED
@@ -6,13 +6,9 @@ from snowflake.ml.model._client.ops import service_ops
 def _get_inference_engine_args(
     experimental_options: Optional[dict[str, Any]],
 ) -> Optional[service_ops.InferenceEngineArgs]:
-
-    if not experimental_options:
+    if not experimental_options or "inference_engine" not in experimental_options:
         return None
 
-    if "inference_engine" not in experimental_options:
-        raise ValueError("inference_engine is required in experimental_options")
-
     return service_ops.InferenceEngineArgs(
         inference_engine=experimental_options["inference_engine"],
         inference_engine_args_override=experimental_options.get("inference_engine_args_override"),