snowflake-ml-python 1.16.0__py3-none-any.whl → 1.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/human_readable_id/adjectives.txt +5 -5
- snowflake/ml/_internal/human_readable_id/animals.txt +3 -3
- snowflake/ml/jobs/__init__.py +4 -0
- snowflake/ml/jobs/_interop/__init__.py +0 -0
- snowflake/ml/jobs/_interop/data_utils.py +124 -0
- snowflake/ml/jobs/_interop/dto_schema.py +95 -0
- snowflake/ml/jobs/{_utils/interop_utils.py → _interop/exception_utils.py} +49 -178
- snowflake/ml/jobs/_interop/legacy.py +225 -0
- snowflake/ml/jobs/_interop/protocols.py +471 -0
- snowflake/ml/jobs/_interop/results.py +51 -0
- snowflake/ml/jobs/_interop/utils.py +144 -0
- snowflake/ml/jobs/_utils/constants.py +4 -1
- snowflake/ml/jobs/_utils/feature_flags.py +37 -5
- snowflake/ml/jobs/_utils/payload_utils.py +1 -1
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +139 -102
- snowflake/ml/jobs/_utils/spec_utils.py +2 -1
- snowflake/ml/jobs/_utils/types.py +10 -0
- snowflake/ml/jobs/job.py +168 -36
- snowflake/ml/jobs/manager.py +36 -38
- snowflake/ml/model/_client/model/model_version_impl.py +39 -7
- snowflake/ml/model/_client/ops/model_ops.py +4 -0
- snowflake/ml/model/_client/sql/model_version.py +3 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +7 -2
- snowflake/ml/model/_model_composer/model_method/utils.py +28 -0
- snowflake/ml/model/_packager/model_env/model_env.py +22 -5
- snowflake/ml/model/_packager/model_meta/model_meta.py +8 -0
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +5 -5
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/METADATA +26 -4
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/RECORD +35 -27
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/job.py
CHANGED

@@ -12,12 +12,19 @@ from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import identifier
 from snowflake.ml._internal.utils.mixins import SerializableSessionMixin
-from snowflake.ml.jobs.
+from snowflake.ml.jobs._interop import results as interop_result, utils as interop_utils
+from snowflake.ml.jobs._utils import (
+    constants,
+    payload_utils,
+    query_helper,
+    stage_utils,
+    types,
+)
 from snowflake.snowpark import Row, context as sp_context
 from snowflake.snowpark.exceptions import SnowparkSQLException
 
 _PROJECT = "MLJob"
-TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "CANCELLED", "INTERNAL_ERROR"}
+TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "CANCELLED", "INTERNAL_ERROR", "DELETED"}
 
 T = TypeVar("T")
 
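
Note: "DELETED" is newly treated as a terminal status. A minimal sketch (not from the package) of how such a terminal-status set is typically consumed by a polling loop; get_status here is a hypothetical stand-in for the real status query:

    import time

    TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "CANCELLED", "INTERNAL_ERROR", "DELETED"}

    def wait_until_terminal(get_status, delay: float = 1.0) -> str:
        # Poll until the job reaches any terminal state, backing off between checks.
        status = get_status()
        while status not in TERMINAL_JOB_STATUSES:
            time.sleep(delay)
            delay = min(delay * 1.2, 30.0)  # capped exponential backoff, as in MLJob.wait()
            status = get_status()
        return status
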
@@ -36,7 +43,12 @@ class MLJob(Generic[T], SerializableSessionMixin):
         self._session = session or sp_context.get_active_session()
 
         self._status: types.JOB_STATUS = "PENDING"
-        self._result: Optional[
+        self._result: Optional[interop_result.ExecutionResult] = None
+
+    @cached_property
+    def _service_info(self) -> types.ServiceInfo:
+        """Get the job's service info."""
+        return _resolve_service_info(self.id, self._session)
 
     @cached_property
     def name(self) -> str:
@@ -44,7 +56,7 @@ class MLJob(Generic[T], SerializableSessionMixin):
 
     @cached_property
     def target_instances(self) -> int:
-        return
+        return self._service_info.target_instances
 
     @cached_property
    def min_instances(self) -> int:
@@ -69,8 +81,7 @@
     @cached_property
     def _compute_pool(self) -> str:
         """Get the job's compute pool name."""
-
-        return cast(str, row["compute_pool"])
+        return self._service_info.compute_pool
 
     @property
     def _service_spec(self) -> dict[str, Any]:
@@ -82,7 +93,13 @@
     @property
     def _container_spec(self) -> dict[str, Any]:
         """Get the job's main container spec."""
-
+        try:
+            containers = self._service_spec["spec"]["containers"]
+        except SnowparkSQLException as e:
+            if e.sql_error_code == 2003:
+                # If the job is deleted, the service spec is not available
+                return {}
+            raise
         if len(containers) == 1:
             return cast(dict[str, Any], containers[0])
         try:
@@ -105,22 +122,28 @@
         if result_path_str is None:
             raise RuntimeError(f"Job {self.name} doesn't have a result path configured")
 
-
-        result_path = Path(result_path_str)
-        if not result_path.is_absolute():
-            return f"{self._stage_path}/{result_path.as_posix()}"
+        return self._transform_path(result_path_str)
 
-
+    def _transform_path(self, path_str: str) -> str:
+        """Transform a local path within the container to a stage path."""
+        path = payload_utils.resolve_path(path_str)
+        if isinstance(path, stage_utils.StagePath):
+            # Stage paths need no transformation
+            return path.as_posix()
+        if not path.is_absolute():
+            # Assume relative paths are relative to stage mount path
+            return f"{self._stage_path}/{path.as_posix()}"
+
+        # If result path is absolute, rebase it onto the stage mount path
+        # TODO: Rather than matching by name, use the longest mount path which matches
         volume_mounts = self._container_spec["volumeMounts"]
         stage_mount_str = next(v for v in volume_mounts if v.get("name") == constants.STAGE_VOLUME_NAME)["mountPath"]
         stage_mount = Path(stage_mount_str)
         try:
-            relative_path =
+            relative_path = path.relative_to(stage_mount)
             return f"{self._stage_path}/{relative_path.as_posix()}"
         except ValueError:
-            raise ValueError(
-                f"Result path {result_path} is absolute, but should be relative to stage mount {stage_mount}"
-            )
+            raise ValueError(f"Result path {path} is absolute, but should be relative to stage mount {stage_mount}")
 
     @overload
     def get_logs(
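
Note: a standalone sketch of the three-way logic in _transform_path, with hypothetical stage_path/stage_mount values and a simple "@" prefix check standing in for stage_utils.StagePath:

    from pathlib import PurePosixPath

    def transform_path(path_str: str, stage_path: str = "@MY_STAGE/job", stage_mount: str = "/mnt/job_stage") -> str:
        if path_str.startswith("@"):
            return path_str  # already a stage path; no transformation needed
        path = PurePosixPath(path_str)
        if not path.is_absolute():
            return f"{stage_path}/{path.as_posix()}"  # relative paths are relative to the stage mount
        # Absolute container paths are rebased onto the stage; ValueError if outside the mount.
        return f"{stage_path}/{path.relative_to(stage_mount).as_posix()}"

    assert transform_path("output/result.pkl") == "@MY_STAGE/job/output/result.pkl"
    assert transform_path("/mnt/job_stage/output/result.pkl") == "@MY_STAGE/job/output/result.pkl"
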
@@ -165,7 +188,14 @@
         Returns:
             The job's execution logs.
         """
-        logs = _get_logs(
+        logs = _get_logs(
+            self._session,
+            self.id,
+            limit,
+            instance_id,
+            self._container_spec["name"] if "name" in self._container_spec else constants.DEFAULT_CONTAINER_NAME,
+            verbose,
+        )
         assert isinstance(logs, str)  # mypy
         if as_list:
             return logs.splitlines()
@@ -218,7 +248,6 @@
             delay = min(delay * 1.2, constants.JOB_POLL_MAX_DELAY_SECONDS)  # Exponential backoff
         return self.status
 
-    @snowpark._internal.utils.private_preview(version="1.8.2")
     @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["timeout"])
     def result(self, timeout: float = -1) -> T:
         """
@@ -237,13 +266,13 @@
         if self._result is None:
             self.wait(timeout)
             try:
-                self._result = interop_utils.
+                self._result = interop_utils.load_result(
+                    self._result_path, session=self._session, path_transform=self._transform_path
+                )
             except Exception as e:
-                raise RuntimeError(f"Failed to retrieve result for job
+                raise RuntimeError(f"Failed to retrieve result for job, error: {e!r}") from e
 
-
-            return cast(T, self._result.result)
-        raise RuntimeError(f"Job execution failed (id={self.name})") from self._result.exception
+        return cast(T, self._result.get_value())
 
     @telemetry.send_api_usage_telemetry(project=_PROJECT)
     def cancel(self) -> None:
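
Note: with the private-preview decorator dropped and load errors now chained, a caller's view looks roughly like this (identifier illustrative; assumes get_job is exported from snowflake.ml.jobs):

    from snowflake.ml.jobs import get_job

    job = get_job("MY_DB.MY_SCHEMA.MLJOB_ABC123")  # illustrative job identifier
    job.wait()
    try:
        value = job.result(timeout=60)  # load failures re-raise as RuntimeError, chained from the cause
    except RuntimeError as e:
        print(f"Job failed or result could not be loaded: {e}")
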
@@ -256,22 +285,28 @@
             self._session.sql(f"CALL {self.id}!spcs_cancel_job()").collect()
             logger.debug(f"Cancellation requested for job {self.id}")
         except SnowparkSQLException as e:
-            raise RuntimeError(f"Failed to cancel job
+            raise RuntimeError(f"Failed to cancel job, error: {e!r}") from e
 
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "instance_id"])
 def _get_status(session: snowpark.Session, job_id: str, instance_id: Optional[int] = None) -> types.JOB_STATUS:
     """Retrieve job or job instance execution status."""
-
-
-
-
-
-
-
-
-
-
+    try:
+        if instance_id is not None:
+            # Get specific instance status
+            rows = query_helper.run_query(session, "SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,))
+            for row in rows:
+                if row["instance_id"] == str(instance_id):
+                    return cast(types.JOB_STATUS, row["status"])
+            raise ValueError(f"Instance {instance_id} not found in job {job_id}")
+        else:
+            row = _get_service_info(session, job_id)
+            return cast(types.JOB_STATUS, row["status"])
+    except SnowparkSQLException as e:
+        if e.sql_error_code == 2003:
+            row = _get_service_info_spcs(session, job_id)
+            return cast(types.JOB_STATUS, row["STATUS"])
+        raise
 
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
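
Note: the shape of the new fallback, reduced to a sketch: query the live service first, and only on SQL error 2003 (object does not exist) consult SPCS job history, whose rows use upper-case keys. fetch_live and fetch_history are hypothetical stand-ins:

    from snowflake.snowpark.exceptions import SnowparkSQLException

    def status_with_fallback(fetch_live, fetch_history) -> str:
        try:
            return fetch_live()["status"]
        except SnowparkSQLException as e:
            if e.sql_error_code == 2003:  # service already dropped from SHOW output
                return fetch_history()["STATUS"]
            raise
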
@@ -542,8 +577,21 @@ def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Row:
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
 def _get_target_instances(session: snowpark.Session, job_id: str) -> int:
-
-
+    try:
+        row = _get_service_info(session, job_id)
+        return int(row["target_instances"])
+    except SnowparkSQLException as e:
+        if e.sql_error_code == 2003:
+            row = _get_service_info_spcs(session, job_id)
+            try:
+                params = json.loads(row["PARAMETERS"])
+                if isinstance(params, dict):
+                    return int(params.get("REPLICAS", 1))
+                else:
+                    return 1
+            except (json.JSONDecodeError, ValueError):
+                return 1
+        raise
 
 
 def _get_logs_spcs(
@@ -581,3 +629,87 @@
         query.append(f" LIMIT {limit};")
     rows = session.sql("\n".join(query)).collect()
     return rows
+
+
+def _get_service_info_spcs(session: snowpark.Session, job_id: str) -> Any:
+    """
+    Retrieve the service info from the SPCS interface.
+
+    Args:
+        session (Session): The Snowpark session to use.
+        job_id (str): The job ID.
+
+    Returns:
+        Any: The service info.
+
+    Raises:
+        SnowparkSQLException: If the job does not exist or is too old to retrieve.
+    """
+    db, schema, name = identifier.parse_schema_level_object_identifier(job_id)
+    db = db or session.get_current_database()
+    schema = schema or session.get_current_schema()
+    rows = query_helper.run_query(
+        session,
+        """
+        select DATABASE_NAME, SCHEMA_NAME, NAME, STATUS, COMPUTE_POOL_NAME, PARAMETERS
+        from table(snowflake.spcs.get_job_history())
+        where database_name = ? and schema_name = ? and name = ?
+        """,
+        params=(db, schema, name),
+    )
+    if rows:
+        return rows[0]
+    else:
+        raise SnowparkSQLException(f"Job {job_id} does not exist or could not be retrieved", sql_error_code=2003)
+
+
+def _resolve_service_info(id: str, session: snowpark.Session) -> types.ServiceInfo:
+    try:
+        row = _get_service_info(session, id)
+    except SnowparkSQLException as e:
+        if e.sql_error_code == 2003:
+            row = _get_service_info_spcs(session, id)
+        else:
+            raise
+    if not row:
+        raise SnowparkSQLException(f"Job {id} does not exist or could not be retrieved", sql_error_code=2003)
+
+    if "compute_pool" in row:
+        compute_pool = row["compute_pool"]
+    elif "COMPUTE_POOL_NAME" in row:
+        compute_pool = row["COMPUTE_POOL_NAME"]
+    else:
+        raise ValueError(f"compute_pool not found in row: {row}")
+
+    if "status" in row:
+        status = row["status"]
+    elif "STATUS" in row:
+        status = row["STATUS"]
+    else:
+        raise ValueError(f"status not found in row: {row}")
+    # Normalize target_instances
+    target_instances: int
+    if "target_instances" in row and row["target_instances"] is not None:
+        try:
+            target_instances = int(row["target_instances"])
+        except (ValueError, TypeError):
+            target_instances = 1
+    elif "PARAMETERS" in row and row["PARAMETERS"]:
+        try:
+            params = json.loads(row["PARAMETERS"])
+            target_instances = int(params.get("REPLICAS", 1)) if isinstance(params, dict) else 1
+        except (json.JSONDecodeError, ValueError, TypeError):
+            target_instances = 1
+    else:
+        target_instances = 1
+
+    database_name = row["database_name"] if "database_name" in row else row["DATABASE_NAME"]
+    schema_name = row["schema_name"] if "schema_name" in row else row["SCHEMA_NAME"]
+
+    return types.ServiceInfo(
+        database_name=database_name,
+        schema_name=schema_name,
+        status=cast(types.JOB_STATUS, status),
+        compute_pool=cast(str, compute_pool),
+        target_instances=target_instances,
+    )
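
Note: a reduced sketch of the PARAMETERS normalization used above, assuming the job-history row carries a JSON object in its PARAMETERS column:

    import json
    from typing import Any, Mapping

    def replicas_from_parameters(row: Mapping[str, Any]) -> int:
        # Missing, malformed, or non-dict PARAMETERS normalize to a single instance.
        try:
            params = json.loads(row.get("PARAMETERS") or "{}")
            return int(params.get("REPLICAS", 1)) if isinstance(params, dict) else 1
        except (json.JSONDecodeError, ValueError, TypeError):
            return 1

    assert replicas_from_parameters({"PARAMETERS": '{"REPLICAS": 4}'}) == 4
    assert replicas_from_parameters({"PARAMETERS": None}) == 1
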
snowflake/ml/jobs/manager.py
CHANGED

@@ -21,6 +21,7 @@ from snowflake.ml.jobs._utils import (
     spec_utils,
     types,
 )
+from snowflake.snowpark._internal import utils as sp_utils
 from snowflake.snowpark.context import get_active_session
 from snowflake.snowpark.exceptions import SnowparkSQLException
 from snowflake.snowpark.functions import coalesce, col, lit, when
@@ -179,8 +180,10 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob
         _ = job._service_spec
         return job
     except SnowparkSQLException as e:
-        if
-
+        if e.sql_error_code == 2003:
+            job = jb.MLJob[Any](job_id, session=session)
+            _ = job.status
+            return job
         raise
 
 
@@ -446,7 +449,7 @@
     Raises:
         ValueError: If database or schema value(s) are invalid
         RuntimeError: If schema is not specified in session context or job submission
-
+        SnowparkSQLException: if failed to upload payload
     """
     session = _ensure_session(session)
 
@@ -512,49 +515,44 @@
         uploaded_payload = payload_utils.JobPayload(
             source, entrypoint=entrypoint, pip_requirements=pip_requirements, additional_payloads=imports
         ).upload(session, stage_path)
-    except
+    except SnowparkSQLException as e:
         if e.sql_error_code == 90106:
             raise RuntimeError(
                 "Please specify a schema, either in the session context or as a parameter in the job submission"
             )
         raise
 
-
-    if target_instances > 1:
-        default_spec_overrides = {
-            "spec": {
-                "endpoints": [
-                    {"name": "ray-dashboard-endpoint", "port": 12003, "protocol": "TCP"},
-                ]
-            },
-        }
-        if spec_overrides:
-            spec_overrides = spec_utils.merge_patch(
-                default_spec_overrides, spec_overrides, display_name="spec_overrides"
-            )
-        else:
-            spec_overrides = default_spec_overrides
-
-    if feature_flags.FeatureFlags.USE_SUBMIT_JOB_V2.is_enabled():
+    if feature_flags.FeatureFlags.USE_SUBMIT_JOB_V2.is_enabled(default=True):
         # Add default env vars (extracted from spec_utils.generate_service_spec)
         combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})}
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        try:
+            return _do_submit_job_v2(
+                session=session,
+                payload=uploaded_payload,
+                args=args,
+                env_vars=combined_env_vars,
+                spec_overrides=spec_overrides,
+                compute_pool=compute_pool,
+                job_id=job_id,
+                external_access_integrations=external_access_integrations,
+                query_warehouse=query_warehouse,
+                target_instances=target_instances,
+                min_instances=min_instances,
+                enable_metrics=enable_metrics,
+                use_async=True,
+                runtime_environment=runtime_environment,
+            )
+        except SnowparkSQLException as e:
+            if not (e.sql_error_code == 90237 and sp_utils.is_in_stored_procedure()):  # type: ignore[no-untyped-call]
+                raise
+            # SNOW-2390287: SYSTEM$EXECUTE_ML_JOB() is erroneously blocked in owner's rights
+            # stored procedures. This will be fixed in an upcoming release.
+            logger.warning(
+                "Job submission using V2 failed with error {}. Falling back to V1.".format(
+                    str(e).split("\n", 1)[0],
+                )
+            )
 
     # Fall back to v1
     # Generate service spec
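
Note: the fallback warning keeps only the first line of the (often multi-line) Snowpark error; the trimming in isolation, with an illustrative message:

    e_text = "SYSTEM$EXECUTE_ML_JOB() is not allowed here\nFile ..., line ...\n(details)"
    print("Job submission using V2 failed with error {}. Falling back to V1.".format(e_text.split("\n", 1)[0]))
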
@@ -688,7 +686,7 @@ def _do_submit_job_v2(
     # for the image tag or full image URL, we use that directly
     if runtime_environment:
         spec_options["RUNTIME"] = runtime_environment
-    elif feature_flags.FeatureFlags.
+    elif feature_flags.FeatureFlags.ENABLE_RUNTIME_VERSIONS.is_enabled():
         # when feature flag is enabled, we get the local python version and wrap it in a dict
         # in system function, we can know whether it is python version or image tag or full image URL through the format
         spec_options["RUNTIME"] = json.dumps({"pythonVersion": f"{sys.version_info.major}.{sys.version_info.minor}"})
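
Note: the RUNTIME option wraps the local interpreter version in JSON so the server-side function can tell a Python version apart from an image tag or full image URL; for example:

    import json
    import sys

    runtime_option = json.dumps({"pythonVersion": f"{sys.version_info.major}.{sys.version_info.minor}"})
    print(runtime_option)  # e.g. {"pythonVersion": "3.11"}
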

snowflake/ml/model/_client/model/model_version_impl.py
CHANGED

@@ -19,7 +19,9 @@ from snowflake.ml.model._client.model import (
 from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
 from snowflake.ml.model._model_composer import model_composer
 from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
+from snowflake.ml.model._model_composer.model_method import utils as model_method_utils
 from snowflake.ml.model._packager.model_handlers import snowmlmodel
+from snowflake.ml.model._packager.model_meta import model_meta_schema
 from snowflake.snowpark import Session, async_job, dataframe
 
 _TELEMETRY_PROJECT = "MLOps"
@@ -41,6 +43,7 @@ class ModelVersion(lineage_node.LineageNode):
     _model_name: sql_identifier.SqlIdentifier
     _version_name: sql_identifier.SqlIdentifier
     _functions: list[model_manifest_schema.ModelFunctionInfo]
+    _model_spec: Optional[model_meta_schema.ModelMetadataDict]
 
     def __init__(self) -> None:
         raise RuntimeError("ModelVersion's initializer is not meant to be used. Use `version` from model instead.")
@@ -150,6 +153,7 @@ class ModelVersion(lineage_node.LineageNode):
         self._model_name = model_name
         self._version_name = version_name
         self._functions = self._get_functions()
+        self._model_spec = None
         super(cls, cls).__init__(
             self,
             session=model_ops._session,
@@ -437,6 +441,26 @@ class ModelVersion(lineage_node.LineageNode):
         """
         return self._functions
 
+    def _get_model_spec(self, statement_params: Optional[dict[str, Any]] = None) -> model_meta_schema.ModelMetadataDict:
+        """Fetch and cache the model spec for this model version.
+
+        Args:
+            statement_params: Optional dictionary of statement parameters to include
+                in the SQL command to fetch the model spec.
+
+        Returns:
+            The model spec as a dictionary for this model version.
+        """
+        if self._model_spec is None:
+            self._model_spec = self._model_ops._fetch_model_spec(
+                database_name=None,
+                schema_name=None,
+                model_name=self._model_name,
+                version_name=self._version_name,
+                statement_params=statement_params,
+            )
+        return self._model_spec
+
     @overload
     def run(
         self,
@@ -531,6 +555,8 @@ class ModelVersion(lineage_node.LineageNode):
                 statement_params=statement_params,
             )
         else:
+            explain_case_sensitive = self._determine_explain_case_sensitivity(target_function_info, statement_params)
+
             return self._model_ops.invoke_method(
                 method_name=sql_identifier.SqlIdentifier(target_function_info["name"]),
                 method_function_type=target_function_info["target_method_function_type"],
@@ -544,8 +570,20 @@ class ModelVersion(lineage_node.LineageNode):
                 partition_column=partition_column,
                 statement_params=statement_params,
                 is_partitioned=target_function_info["is_partitioned"],
+                explain_case_sensitive=explain_case_sensitive,
             )
 
+    def _determine_explain_case_sensitivity(
+        self,
+        target_function_info: model_manifest_schema.ModelFunctionInfo,
+        statement_params: Optional[dict[str, Any]] = None,
+    ) -> bool:
+        model_spec = self._get_model_spec(statement_params)
+        method_options = model_spec.get("method_options", {})
+        return model_method_utils.determine_explain_case_sensitive_from_method_options(
+            method_options, target_function_info["name"]
+        )
+
     @telemetry.send_api_usage_telemetry(
         project=_TELEMETRY_PROJECT,
         subproject=_TELEMETRY_SUBPROJECT,
@@ -803,13 +841,7 @@ class ModelVersion(lineage_node.LineageNode):
             ValueError: If the model is not a HuggingFace text-generation model.
         """
         # Fetch model spec
-        model_spec = self.
-            database_name=None,
-            schema_name=None,
-            model_name=self._model_name,
-            version_name=self._version_name,
-            statement_params=statement_params,
-        )
+        model_spec = self._get_model_spec(statement_params)
 
         # Check if model_type is huggingface_pipeline
         model_type = model_spec.get("model_type")

snowflake/ml/model/_client/ops/model_ops.py
CHANGED

@@ -952,6 +952,7 @@ class ModelOperator:
         partition_column: Optional[sql_identifier.SqlIdentifier] = None,
         statement_params: Optional[dict[str, str]] = None,
         is_partitioned: Optional[bool] = None,
+        explain_case_sensitive: bool = False,
     ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
         ...
 
@@ -967,6 +968,7 @@ class ModelOperator:
         service_name: sql_identifier.SqlIdentifier,
         strict_input_validation: bool = False,
         statement_params: Optional[dict[str, str]] = None,
+        explain_case_sensitive: bool = False,
     ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
         ...
 
@@ -986,6 +988,7 @@ class ModelOperator:
         partition_column: Optional[sql_identifier.SqlIdentifier] = None,
         statement_params: Optional[dict[str, str]] = None,
         is_partitioned: Optional[bool] = None,
+        explain_case_sensitive: bool = False,
     ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
         identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED
 
@@ -1068,6 +1071,7 @@ class ModelOperator:
             version_name=version_name,
             statement_params=statement_params,
             is_partitioned=is_partitioned or False,
+            explain_case_sensitive=explain_case_sensitive,
         )
 
         if keep_order:

snowflake/ml/model/_client/sql/model_version.py
CHANGED

@@ -438,6 +438,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
         partition_column: Optional[sql_identifier.SqlIdentifier],
         statement_params: Optional[dict[str, Any]] = None,
         is_partitioned: bool = True,
+        explain_case_sensitive: bool = False,
     ) -> dataframe.DataFrame:
         with_statements = []
         if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
@@ -505,7 +506,8 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
         cols_to_drop = []
 
         for output_name, output_type, output_col_name in returns:
-
+            case_sensitive = "explain" in method_name.resolved().lower() and explain_case_sensitive
+            output_identifier = sql_identifier.SqlIdentifier(output_name, case_sensitive=case_sensitive).identifier()
             if output_identifier != output_col_name:
                 cols_to_drop.append(output_identifier)
             output_cols.append(F.col(output_identifier).astype(output_type))
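
Note: case sensitivity matters here because Snowflake folds unquoted identifiers to upper case while quoted identifiers keep their casing; a hypothetical reduction of the SqlIdentifier behavior relied on above:

    def to_sql_identifier(name: str, case_sensitive: bool = False) -> str:
        # Quoted identifiers preserve casing; unquoted ones resolve upper-case.
        return f'"{name}"' if case_sensitive else name.upper()

    assert to_sql_identifier("featureImportance", case_sensitive=True) == '"featureImportance"'
    assert to_sql_identifier("featureImportance") == "FEATUREIMPORTANCE"
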

snowflake/ml/model/_model_composer/model_method/model_method.py
CHANGED

@@ -11,6 +11,7 @@ from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
 from snowflake.ml.model._model_composer.model_method import (
     constants,
     function_generator,
+    utils,
 )
 from snowflake.ml.model._packager.model_meta import model_meta as model_meta_api
 from snowflake.ml.model.volatility import Volatility
@@ -34,9 +35,13 @@ def get_model_method_options_from_options(
     options: type_hints.ModelSaveOption, target_method: str
 ) -> ModelMethodOptions:
     default_function_type = model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value
+    method_option = options.get("method_options", {}).get(target_method, {})
+    case_sensitive = method_option.get("case_sensitive", False)
     if target_method == "explain":
         default_function_type = model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
-
+        case_sensitive = utils.determine_explain_case_sensitive_from_method_options(
+            options.get("method_options", {}), target_method
+        )
     global_function_type = options.get("function_type", default_function_type)
     function_type = method_option.get("function_type", global_function_type)
     if function_type not in [function_type.value for function_type in model_manifest_schema.ModelMethodFunctionTypes]:
|
|
|
48
53
|
|
|
49
54
|
# Only include volatility if explicitly provided in method options
|
|
50
55
|
result: ModelMethodOptions = ModelMethodOptions(
|
|
51
|
-
case_sensitive=
|
|
56
|
+
case_sensitive=case_sensitive,
|
|
52
57
|
function_type=function_type,
|
|
53
58
|
)
|
|
54
59
|
if resolved_volatility:
|
|

snowflake/ml/model/_model_composer/model_method/utils.py
ADDED

@@ -0,0 +1,28 @@
+from __future__ import annotations
+
+from typing import Any, Mapping, Optional
+
+
+def determine_explain_case_sensitive_from_method_options(
+    method_options: Mapping[str, Optional[Mapping[str, Any]]],
+    target_method: str,
+) -> bool:
+    """Determine explain method case sensitivity from related predict methods.
+
+    Args:
+        method_options: Mapping from method name to its options. Each option may
+            contain ``"case_sensitive"`` to indicate SQL identifier sensitivity.
+        target_method: The target method name being resolved (e.g., an ``explain_*``
+            method).
+
+    Returns:
+        True if the explain method should be treated as case sensitive; otherwise False.
+    """
+    if "explain" not in target_method:
+        return False
+    predict_priority_methods = ["predict_proba", "predict", "predict_log_proba"]
+    for src_method in predict_priority_methods:
+        src_opts = method_options.get(src_method)
+        if src_opts is not None:
+            return bool(src_opts.get("case_sensitive", False))
+    return False
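
Note: an example of how the new helper resolves an explain method's case sensitivity from its sibling predict methods (option values illustrative):

    from snowflake.ml.model._model_composer.model_method import utils

    opts = {"predict": {"case_sensitive": True}}
    assert utils.determine_explain_case_sensitive_from_method_options(opts, "explain") is True
    assert utils.determine_explain_case_sensitive_from_method_options(opts, "predict") is False  # not an explain method
    assert utils.determine_explain_case_sensitive_from_method_options({}, "explain") is False  # no predict options set
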

snowflake/ml/model/_packager/model_env/model_env.py
CHANGED

@@ -240,14 +240,31 @@ class ModelEnv:
         self._conda_dependencies[channel].remove(spec)
 
     def generate_env_for_cuda(self) -> None:
+
+        # Insert py-xgboost-gpu only for XGBoost versions < 3.0.0
         xgboost_spec = env_utils.find_dep_spec(
-            self._conda_dependencies, self._pip_requirements, conda_pkg_name="xgboost", remove_spec=
+            self._conda_dependencies, self._pip_requirements, conda_pkg_name="xgboost", remove_spec=False
         )
         if xgboost_spec:
-
-
-
-
+            # Only handle explicitly pinned versions. Insert GPU variant iff pinned major < 3.
+            pinned_major: Optional[int] = None
+            for spec in xgboost_spec.specifier:
+                if spec.operator in ("==", "===", ">", ">="):
+                    try:
+                        pinned_major = version.parse(spec.version).major
+                    except version.InvalidVersion:
+                        pinned_major = None
+                    break
+
+            if pinned_major is not None and pinned_major < 3:
+                xgboost_spec = env_utils.find_dep_spec(
+                    self._conda_dependencies, self._pip_requirements, conda_pkg_name="xgboost", remove_spec=True
+                )
+                if xgboost_spec:
+                    self.include_if_absent(
+                        [ModelDependency(requirement=f"py-xgboost-gpu{xgboost_spec.specifier}", pip_name="xgboost")],
+                        check_local_version=False,
+                    )
 
         tf_spec = env_utils.find_dep_spec(
             self._conda_dependencies, self._pip_requirements, conda_pkg_name="tensorflow", remove_spec=True