snowflake-ml-python 1.17.0__py3-none-any.whl → 1.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/telemetry.py +3 -2
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +18 -19
- snowflake/ml/experiment/callback/keras.py +3 -0
- snowflake/ml/experiment/callback/lightgbm.py +3 -0
- snowflake/ml/experiment/callback/xgboost.py +3 -0
- snowflake/ml/experiment/experiment_tracking.py +50 -70
- snowflake/ml/feature_store/feature_store.py +299 -69
- snowflake/ml/feature_store/feature_view.py +12 -6
- snowflake/ml/fileset/stage_fs.py +12 -1
- snowflake/ml/jobs/_utils/constants.py +12 -1
- snowflake/ml/jobs/_utils/payload_utils.py +7 -1
- snowflake/ml/jobs/_utils/stage_utils.py +4 -0
- snowflake/ml/jobs/_utils/types.py +5 -0
- snowflake/ml/jobs/job.py +19 -5
- snowflake/ml/jobs/manager.py +18 -7
- snowflake/ml/model/__init__.py +19 -0
- snowflake/ml/model/_client/model/batch_inference_specs.py +63 -0
- snowflake/ml/model/_client/model/inference_engine_utils.py +1 -5
- snowflake/ml/model/_client/model/model_version_impl.py +129 -11
- snowflake/ml/model/_client/ops/model_ops.py +11 -4
- snowflake/ml/model/_client/ops/service_ops.py +3 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +3 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +4 -1
- snowflake/ml/model/_packager/model_handlers/_utils.py +70 -0
- snowflake/ml/model/_packager/model_handlers/prophet.py +566 -0
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
- snowflake/ml/model/type_hints.py +16 -0
- snowflake/ml/modeling/metrics/metrics_utils.py +9 -2
- snowflake/ml/monitoring/explain_visualize.py +3 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.19.0.dist-info}/METADATA +50 -4
- {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.19.0.dist-info}/RECORD +38 -37
- {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.19.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.19.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.19.0.dist-info}/top_level.txt +0 -0

snowflake/ml/jobs/_utils/payload_utils.py
CHANGED

@@ -488,10 +488,13 @@ class JobPayload:
                 " comment = 'Created by snowflake.ml.jobs Python API'",
                 params=[stage_name],
             )
-
+        payload_name = None
         # Upload payload to stage - organize into app/ subdirectory
         app_stage_path = stage_path.joinpath(constants.APP_STAGE_SUBPATH)
         if not isinstance(source, types.PayloadPath):
+            if isinstance(source, function_payload_utils.FunctionPayload):
+                payload_name = source.function.__name__
+
             source_code = generate_python_code(source, source_code_display=True)
             _ = session.file.put_stream(
                 io.BytesIO(source_code.encode()),
@@ -502,12 +505,14 @@ class JobPayload:
             source = Path(entrypoint.file_path.parent)

         elif isinstance(source, stage_utils.StagePath):
+            payload_name = entrypoint.file_path.stem
             # copy payload to stage
             if source == entrypoint.file_path:
                 source = source.parent
             upload_payloads(session, app_stage_path, types.PayloadSpec(source, None))

         elif isinstance(source, Path):
+            payload_name = entrypoint.file_path.stem
             upload_payloads(session, app_stage_path, types.PayloadSpec(source, None))
             if source.is_file():
                 source = source.parent
@@ -562,6 +567,7 @@ class JobPayload:
                 *python_entrypoint,
             ],
             env_vars=env_vars,
+            payload_name=payload_name,
         )

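Taken together, the three hunks above thread a human-readable payload name through JobPayload.upload. A condensed, hedged sketch of the naming rule they implement (derive_payload_name is a hypothetical stand-in; the real logic lives inline in upload):

from pathlib import Path
from typing import Callable, Optional, Union

def derive_payload_name(source: Union[Callable, Path], entrypoint: Path) -> Optional[str]:
    # Function payloads are named after the wrapped function; file- and
    # stage-based payloads after the entrypoint file's stem.
    if callable(source):
        return source.__name__
    return entrypoint.stem if entrypoint else None

assert derive_payload_name(Path("src/"), Path("src/train_model.py")) == "train_model"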

snowflake/ml/jobs/_utils/stage_utils.py
CHANGED

@@ -32,6 +32,10 @@ class StagePath:
         self._root = self._raw_path[0:start].rstrip("/") if relpath else self._raw_path.rstrip("/")
         self._path = Path(relpath or "")

+    @property
+    def stem(self) -> str:
+        return self._path.stem
+
     @property
     def parts(self) -> tuple[str, ...]:
         return self._path.parts
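The new StagePath.stem simply delegates to the wrapped pathlib.Path, so it inherits pathlib's stem semantics; a quick illustration with plain pathlib (the paths here are made up):

from pathlib import Path

# stem = final path component minus its last suffix, exactly what StagePath returns
assert Path("jobs/app/entry_point.py").stem == "entry_point"
assert Path("payload.tar.gz").stem == "payload.tar"  # only the last suffix is dropped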

snowflake/ml/jobs/_utils/types.py
CHANGED

@@ -23,6 +23,10 @@ class PayloadPath(Protocol):
     def name(self) -> str:
         ...

+    @property
+    def stem(self) -> str:
+        ...
+
     @property
     def suffix(self) -> str:
         ...
@@ -92,6 +96,7 @@ class UploadedPayload:
     stage_path: PurePath
     entrypoint: list[Union[str, PurePath]]
     env_vars: dict[str, str] = field(default_factory=dict)
+    payload_name: Optional[str] = None


 @dataclass(frozen=True)
snowflake/ml/jobs/job.py
CHANGED

@@ -109,18 +109,18 @@ class MLJob(Generic[T], SerializableSessionMixin):
         return cast(dict[str, Any], container_spec)

     @property
-    def _stage_path(self) -> str:
+    def _stage_path(self) -> Optional[str]:
         """Get the job's artifact storage stage location."""
         volumes = self._service_spec["spec"]["volumes"]
-
-        return cast(str,
+        stage_volume = next((v for v in volumes if v["name"] == constants.STAGE_VOLUME_NAME), None)
+        return cast(str, stage_volume["source"]) if stage_volume else None

     @property
     def _result_path(self) -> str:
         """Get the job's result file location."""
         result_path_str = self._container_spec["env"].get(constants.RESULT_PATH_ENV_VAR)
         if result_path_str is None:
-            raise
+            raise NotImplementedError(f"Job {self.name} doesn't have a result path configured")

         return self._transform_path(result_path_str)
@@ -229,8 +229,22 @@ class MLJob(Generic[T], SerializableSessionMixin):
         Raises:
             TimeoutError: If the job does not complete within the specified timeout.
         """
-        delay = constants.JOB_POLL_INITIAL_DELAY_SECONDS  # Start with 100ms delay
         start_time = time.monotonic()
+        try:
+            # spcs_wait_for() is a synchronous query, so for long-running jobs it is more
+            # effective to poll with exponential backoff. Use a hybrid approach: call
+            # spcs_wait_for() for the first 30 seconds, then switch to polling.
+            min_timeout = (
+                int(min(timeout, constants.JOB_SPCS_TIMEOUT_SECONDS))
+                if timeout >= 0
+                else constants.JOB_SPCS_TIMEOUT_SECONDS
+            )
+            query_helper.run_query(self._session, f"call {self.id}!spcs_wait_for('DONE', {min_timeout})")
+            return self.status
+        except SnowparkSQLException:
+            # the function is not supported in this environment
+            pass
+        delay: float = float(constants.JOB_POLL_INITIAL_DELAY_SECONDS)  # Start with 5s delay
         warning_shown = False
         while (status := self.status) not in TERMINAL_JOB_STATUSES:
             elapsed = time.monotonic() - start_time
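Callers of MLJob.wait are unchanged by the hybrid strategy; a hedged usage sketch (the submission call and names below are illustrative, assuming an active Snowpark session):

from snowflake.ml import jobs

# The first ~30s of the wait are served by the blocking spcs_wait_for() call;
# after that the client falls back to exponential-backoff polling.
job = jobs.submit_file("train.py", "MY_COMPUTE_POOL", stage_name="payload_stage")
status = job.wait(timeout=600)  # raises TimeoutError if not done in 10 minutes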
snowflake/ml/jobs/manager.py
CHANGED

@@ -192,12 +192,12 @@ def delete_job(job: Union[str, jb.MLJob[Any]], session: Optional[snowpark.Sessio
     """Delete a job service from the backend. Status and logs will be lost."""
     job = job if isinstance(job, jb.MLJob) else get_job(job, session=session)
     session = job._session
-
-
-
-
-
-
+    if job._stage_path:
+        try:
+            session.sql(f"REMOVE {job._stage_path}/").collect()
+            logger.debug(f"Successfully cleaned up stage files for job {job.id} at {job._stage_path}")
+        except Exception as e:
+            logger.warning(f"Failed to clean up stage files for job {job.id}: {e}")
     query_helper.run_query(session, "DROP SERVICE IDENTIFIER(?)", params=(job.id,))

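With this change, deleting a job best-effort removes its staged artifacts before dropping the service; the public call is unchanged (the identifier below is illustrative, assuming an active Snowpark session):

from snowflake.ml import jobs

# Removes @<stage path>/ for the job (failure is only logged), then drops the service.
jobs.delete_job("MY_DB.MY_SCHEMA.MLJOB_ABC123", session=session)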
@@ -697,10 +697,21 @@ def _do_submit_job_v2(
         "MIN_INSTANCES": min_instances,
         "ASYNC": use_async,
     }
+    if payload.payload_name:
+        job_options["GENERATE_SUFFIX"] = True
     job_options = {k: v for k, v in job_options.items() if v is not None}

     query_template = "CALL SYSTEM$EXECUTE_ML_JOB(?, ?, ?, ?)"
-
+    if job_id:
+        database, schema, _ = identifier.parse_schema_level_object_identifier(job_id)
+    params = [
+        job_id
+        if payload.payload_name is None
+        else identifier.get_schema_level_object_identifier(database, schema, payload.payload_name) + "_",
+        compute_pool,
+        json.dumps(spec_options),
+        json.dumps(job_options),
+    ]
     actual_job_id = query_helper.run_query(session, query_template, params=params)[0][0]

     return get_job(actual_job_id, session=session)
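A small sketch of the name the submitter now passes to SYSTEM$EXECUTE_ML_JOB when a payload name is available: the trailing underscore pairs with the GENERATE_SUFFIX option so the server appends a unique suffix (values below are hypothetical):

database, schema = "MY_DB", "MY_SCHEMA"
payload_name = "train_model"  # entrypoint stem or function __name__
submitted_name = f"{database}.{schema}.{payload_name}_"
print(submitted_name)  # MY_DB.MY_SCHEMA.train_model_ -> server appends a unique suffix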
snowflake/ml/model/__init__.py
CHANGED

@@ -1,3 +1,6 @@
+import sys
+import warnings
+
 from snowflake.ml.model._client.model.batch_inference_specs import (
     JobSpec,
     OutputSpec,

@@ -18,3 +21,19 @@ __all__ = [
     "SaveMode",
     "Volatility",
 ]
+
+_deprecation_warning_msg_for_3_9 = (
+    "Python 3.9 is deprecated in snowflake-ml-python. Please upgrade to Python 3.10 or greater."
+)
+
+warnings.filterwarnings(
+    "once",
+    message=_deprecation_warning_msg_for_3_9,
+)
+
+if sys.version_info.major == 3 and sys.version_info.minor == 9:
+    warnings.warn(
+        _deprecation_warning_msg_for_3_9,
+        category=DeprecationWarning,
+        stacklevel=2,
+    )
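For context, the "once" filter means the deprecation fires at most one time per process, however often the module is imported; a minimal standalone demonstration of that filter action:

import warnings

warnings.filterwarnings("once", message="demo deprecation")
for _ in range(3):
    warnings.warn("demo deprecation", DeprecationWarning, stacklevel=2)
# Only the first iteration emits the warning; the other two are suppressed.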

snowflake/ml/model/_client/model/batch_inference_specs.py
CHANGED

@@ -19,11 +19,74 @@ class SaveMode(str, Enum):


 class OutputSpec(BaseModel):
+    """Specification for batch inference output.
+
+    Defines where the inference results should be written and how to handle
+    existing files at the output location.
+
+    Attributes:
+        stage_location (str): The stage path where batch inference results will be saved.
+            This should be a full path including the stage with the @ prefix, for example
+            '@My_DB.PUBLIC.MY_STAGE/someth/path/'. A non-existent directory will be created.
+            Only Snowflake internal stages are supported at this moment.
+        mode (SaveMode): The save mode that determines behavior when files already exist
+            at the output location. Defaults to SaveMode.ERROR, which raises an error
+            if files exist. Can be set to SaveMode.OVERWRITE to replace existing files.
+
+    Example:
+        >>> output_spec = OutputSpec(
+        ...     stage_location="@My_DB.PUBLIC.MY_STAGE/someth/path/",
+        ...     mode=SaveMode.OVERWRITE
+        ... )
+    """
+
     stage_location: str
     mode: SaveMode = SaveMode.ERROR


 class JobSpec(BaseModel):
+    """Specification for batch inference job execution.
+
+    Defines the compute resources, job settings, and execution parameters
+    for running batch inference jobs in Snowflake.
+
+    Attributes:
+        image_repo (Optional[str]): Container image repository for the inference job.
+            If not specified, uses the default repository.
+        job_name (Optional[str]): Custom name for the batch inference job.
+            If not provided, a name is auto-generated in the form "BATCH_INFERENCE_<UUID>".
+        num_workers (Optional[int]): The number of workers that handle requests in parallel
+            within an instance of the service. Defaults to 2*vCPU+1 of the node for CPU-based
+            inference and 1 for GPU-based inference. For GPU-based inference, please review
+            best practices before adjusting this value.
+        function_name (Optional[str]): Name of the specific function to call for inference.
+            Required when the model has multiple inference functions.
+        force_rebuild (bool): Whether to force rebuilding the container image even if
+            it already exists. Defaults to False.
+        max_batch_rows (int): Maximum number of rows to process in a single batch.
+            Defaults to 1024. Larger values may improve throughput.
+        warehouse (Optional[str]): Snowflake warehouse to use for the batch inference job.
+            If not specified, uses the session's current warehouse.
+        cpu_requests (Optional[str]): The CPU limit for CPU-based inference. Can be an
+            integer, fractional, or string value. If None, attempts to utilize all the
+            vCPUs of the node.
+        memory_requests (Optional[str]): The memory limit for inference. Can be an integer
+            or a fractional value, but requires a unit (GiB, MiB). If None, attempts to
+            utilize all the memory of the node.
+        gpu_requests (Optional[str]): The GPU limit for GPU-based inference. Can be an
+            integer or string value. Uses CPU if None.
+        replicas (Optional[int]): Number of job replicas to run for high availability.
+            If not specified, defaults to 1 replica.
+
+    Example:
+        >>> job_spec = JobSpec(
+        ...     job_name="my_inference_job",
+        ...     num_workers=4,
+        ...     cpu_requests="2",
+        ...     memory_requests="8Gi",
+        ...     max_batch_rows=2048
+        ... )
+    """
+
     image_repo: Optional[str] = None
     job_name: Optional[str] = None
     num_workers: Optional[int] = None

snowflake/ml/model/_client/model/inference_engine_utils.py
CHANGED

@@ -6,13 +6,9 @@ from snowflake.ml.model._client.ops import service_ops
 def _get_inference_engine_args(
     experimental_options: Optional[dict[str, Any]],
 ) -> Optional[service_ops.InferenceEngineArgs]:
-
-    if not experimental_options:
+    if not experimental_options or "inference_engine" not in experimental_options:
         return None

-    if "inference_engine" not in experimental_options:
-        raise ValueError("inference_engine is required in experimental_options")
-
     return service_ops.InferenceEngineArgs(
         inference_engine=experimental_options["inference_engine"],
         inference_engine_args_override=experimental_options.get("inference_engine_args_override"),
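The behavioral change: experimental_options without an "inference_engine" key (e.g. only "autocapture") now yields None instead of raising ValueError. A self-contained sketch mirroring the new check (a plain dict stands in for InferenceEngineArgs):

from typing import Any, Optional

def get_engine_args_sketch(opts: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]:
    # Missing "inference_engine" is no longer an error; it simply means
    # "no custom engine requested".
    if not opts or "inference_engine" not in opts:
        return None
    return {
        "inference_engine": opts["inference_engine"],
        "inference_engine_args_override": opts.get("inference_engine_args_override"),
    }

assert get_engine_args_sketch({"autocapture": True}) is None  # previously raised ValueError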

snowflake/ml/model/_client/model/model_version_impl.py
CHANGED

@@ -7,6 +7,7 @@ from typing import Any, Callable, Optional, Union, overload

 import pandas as pd

+from snowflake import snowpark
 from snowflake.ml import jobs
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import sql_identifier
@@ -593,7 +594,8 @@ class ModelVersion(lineage_node.LineageNode):
             "job_spec",
         ],
     )
-
+    @snowpark._internal.utils.private_preview(version="1.18.0")
+    def run_batch(
         self,
         *,
         compute_pool: str,
@@ -601,6 +603,68 @@
         output_spec: batch_inference_specs.OutputSpec,
         job_spec: Optional[batch_inference_specs.JobSpec] = None,
     ) -> jobs.MLJob[Any]:
+        """Execute batch inference on datasets as an SPCS job.
+
+        Args:
+            compute_pool (str): Name of the compute pool to use for building the image containers and batch
+                inference execution.
+            input_spec (dataframe.DataFrame): Snowpark DataFrame containing the input data for inference.
+                The DataFrame should contain all required features for model prediction and passthrough columns.
+            output_spec (batch_inference_specs.OutputSpec): Configuration for where and how to save
+                the inference results. Specifies the stage location and file handling behavior.
+            job_spec (Optional[batch_inference_specs.JobSpec]): Optional configuration for job
+                execution parameters such as compute resources, worker counts, and job naming.
+                If None, default values will be used.
+
+        Returns:
+            jobs.MLJob[Any]: A batch inference job object that can be used to monitor progress and manage the job
+                lifecycle.
+
+        Raises:
+            ValueError: If warehouse is not set in job_spec and no current warehouse is available.
+            RuntimeError: If the input_spec cannot be processed or written to the staging location.
+
+        Example:
+            >>> # Prepare input data - Example 1: From a table
+            >>> input_df = session.table("my_input_table")
+            >>>
+            >>> # Prepare input data - Example 2: From a SQL query
+            >>> input_df = session.sql(
+            ...     "SELECT id, feature_1, feature_2 FROM feature_table WHERE feature_1 > 100"
+            ... )
+            >>>
+            >>> # Prepare input data - Example 3: From Parquet files in a stage
+            >>> input_df = session.read.option("pattern", ".*\\.parquet").parquet(
+            ...     "@my_stage/input_data/"
+            ... ).select("id", "feature_1", "feature_2")
+            >>>
+            >>> # Configure output location
+            >>> output_spec = OutputSpec(
+            ...     stage_location='@My_DB.PUBLIC.MY_STAGE/someth/path/',
+            ...     mode=SaveMode.OVERWRITE
+            ... )
+            >>>
+            >>> # Configure job parameters
+            >>> job_spec = JobSpec(
+            ...     job_name="my_batch_inference",
+            ...     num_workers=4,
+            ...     cpu_requests="2",
+            ...     memory_requests="8Gi"
+            ... )
+            >>>
+            >>> # Run batch inference
+            >>> job = model_version.run_batch(
+            ...     compute_pool="my_compute_pool",
+            ...     input_spec=input_df,
+            ...     output_spec=output_spec,
+            ...     job_spec=job_spec
+            ... )
+
+        Note:
+            This method is currently in private preview and requires snowflake-ml-python version 1.18.0 or later.
+            The input data is temporarily stored in the output stage location under /_temporary before
+            inference execution.
+        """
         statement_params = telemetry.get_statement_params(
             project=_TELEMETRY_PROJECT,
             subproject=_TELEMETRY_SUBPROJECT,
@@ -827,6 +891,51 @@
             version_name=sql_identifier.SqlIdentifier(version),
         )

+    def _can_run_on_gpu(
+        self,
+        statement_params: Optional[dict[str, Any]] = None,
+    ) -> bool:
+        """Check if the model has GPU runtime support.
+
+        Args:
+            statement_params: Optional dictionary of statement parameters to include
+                in the SQL command to fetch the model spec.
+
+        Returns:
+            True if the model has GPU runtime configured, False otherwise.
+        """
+        # Fetch model spec
+        model_spec = self._get_model_spec(statement_params)
+
+        # Check if runtimes section exists and has gpu runtime
+        runtimes = model_spec.get("runtimes", {})
+        return "gpu" in runtimes
+
+    def _throw_error_if_gpu_is_not_supported(
+        self,
+        gpu_requests: Optional[Union[str, int]] = None,
+        statement_params: Optional[dict[str, Any]] = None,
+    ) -> None:
+        """Raise an error if GPU resources are requested but the model lacks GPU runtime support.
+
+        Args:
+            gpu_requests: The GPU limit for GPU-based inference. Can be an integer, fractional,
+                or string value. Uses CPU if None.
+            statement_params: Optional dictionary of statement parameters to include
+                in the SQL command to fetch the model spec.
+
+        Raises:
+            ValueError: If the model does not have GPU runtime support.
+        """
+        if gpu_requests is not None and not self._can_run_on_gpu(statement_params):
+            raise ValueError(
+                f"GPU resources requested (gpu_requests={gpu_requests}), but the model "
+                f"{self.fully_qualified_model_name} version {self.version_name} does not have GPU runtime support. "
+                "Please ensure the model was logged with GPU runtime configuration or do not provide gpu_requests. "
+                "To log the model with GPU runtime configuration, provide `cuda_version` in the `options` while"
+                " calling the `log_model` function."
+            )
+
     def _check_huggingface_text_generation_model(
         self,
         statement_params: Optional[dict[str, Any]] = None,
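The GPU check reads the model spec parsed from the model manifest; the presence of a "gpu" entry under "runtimes" is what gates gpu_requests. A minimal sketch of that shape (the spec contents below are illustrative):

# A model logged with cuda_version carries a "gpu" runtime in its spec.
model_spec = {"runtimes": {"cpu": {"language": "PYTHON"}, "gpu": {"language": "PYTHON"}}}
can_run_on_gpu = "gpu" in model_spec.get("runtimes", {})
assert can_run_on_gpu  # gpu_requests would be accepted for this model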
@@ -926,9 +1035,10 @@
             When it is ``False``, this function executes the underlying service creation asynchronously
             and returns an :class:`AsyncJob`.
         experimental_options: Experimental options for the service creation with custom inference engine.
-            Currently,
+            Currently, `inference_engine`, `inference_engine_args_override`, and `autocapture` are supported.
             `inference_engine` is the name of the inference engine to use.
             `inference_engine_args_override` is a list of string arguments to pass to the inference engine.
+            `autocapture` is a boolean to enable or disable the inference table.
         """
         ...

@@ -984,9 +1094,10 @@
             When it is ``False``, this function executes the underlying service creation asynchronously
             and returns an :class:`AsyncJob`.
         experimental_options: Experimental options for the service creation with custom inference engine.
-            Currently,
+            Currently, `inference_engine`, `inference_engine_args_override`, and `autocapture` are supported.
             `inference_engine` is the name of the inference engine to use.
             `inference_engine_args_override` is a list of string arguments to pass to the inference engine.
+            `autocapture` is a boolean to enable or disable the inference table.
         """
         ...

@@ -1059,21 +1170,20 @@
             When it is False, this function executes the underlying service creation asynchronously
             and returns an AsyncJob.
         experimental_options: Experimental options for the service creation with custom inference engine.
-            Currently,
+            Currently, `inference_engine`, `inference_engine_args_override`, and `autocapture` are supported.
             `inference_engine` is the name of the inference engine to use.
             `inference_engine_args_override` is a list of string arguments to pass to the inference engine.
+            `autocapture` is a boolean to enable or disable the inference table.


         Raises:
-            ValueError: Illegal external access integration arguments
+            ValueError: Illegal external access integration arguments, or if GPU resources are requested
+                but the model does not have GPU runtime support.
             exceptions.SnowparkSQLException: if service already exists.

         Returns:
             If `block=True`, return result information about service creation from server.
             Otherwise, return the service creation AsyncJob.
-
-        Raises:
-            ValueError: Illegal external access integration arguments.
         """
         statement_params = telemetry.get_statement_params(
             project=_TELEMETRY_PROJECT,
@@ -1096,12 +1206,16 @@

         service_db_id, service_schema_id, service_id = sql_identifier.parse_fully_qualified_name(service_name)

-        #
-
-        self._check_huggingface_text_generation_model(statement_params)
+        # Validate GPU support if GPU resources are requested
+        self._throw_error_if_gpu_is_not_supported(gpu_requests, statement_params)

         inference_engine_args = inference_engine_utils._get_inference_engine_args(experimental_options)

+        # Check if model is HuggingFace text-generation before doing inference engine checks.
+        # Only validate if an inference engine is actually specified.
+        if inference_engine_args is not None:
+            self._check_huggingface_text_generation_model(statement_params)
+
         # Enrich inference engine args if inference engine is specified
         if inference_engine_args is not None:
             inference_engine_args = inference_engine_utils._enrich_inference_engine_args(
@@ -1109,6 +1223,9 @@
             gpu_requests,
         )

+        # Extract autocapture from experimental_options
+        autocapture = experimental_options.get("autocapture") if experimental_options else None
+
         from snowflake.ml.model import event_handler
         from snowflake.snowpark import exceptions

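A hedged sketch of how a caller reaches this code path; the parameter names follow the public create_service signature, model_version is assumed to be a ModelVersion handle, and the option key is as read above:

# "autocapture" rides along in experimental_options and toggles the inference
# table for the deployed service; no custom inference engine is required.
service = model_version.create_service(
    service_name="MY_DB.PUBLIC.MY_SERVICE",
    service_compute_pool="MY_POOL",
    image_repo="MY_DB.PUBLIC.MY_REPO",
    experimental_options={"autocapture": True},
)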
@@ -1148,6 +1265,7 @@
             statement_params=statement_params,
             progress_status=status,
             inference_engine_args=inference_engine_args,
+            autocapture=autocapture,
         )
         status.update(label="Model service created successfully", state="complete", expanded=False)
         return result

snowflake/ml/model/_client/ops/model_ops.py
CHANGED

@@ -515,10 +515,17 @@ class ModelOperator:
             statement_params=statement_params,
         )
         for r in res:
-
-
-
-
+            aliases_data = r[self._model_client.MODEL_VERSION_ALIASES_COL_NAME]
+            if aliases_data:
+                aliases_list = json.loads(aliases_data)
+
+                # Compare using Snowflake identifier semantics for exact match
+                for alias in aliases_list:
+                    if sql_identifier.SqlIdentifier(alias) == alias_name:
+                        return sql_identifier.SqlIdentifier(
+                            r[self._model_client.MODEL_VERSION_NAME_COL_NAME], case_sensitive=True
+                        )
+
         return None

     def get_tag_value(
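The exact-match comparison relies on Snowflake identifier semantics: unquoted identifiers fold to upper case, so an alias stored as FIRST should match a lookup for first. A small illustration using the internal helper named in the hunk (internal API, shown for exposition only):

from snowflake.ml._internal.utils import sql_identifier

# Unquoted identifiers are case-insensitive; SqlIdentifier normalizes them,
# so these two should resolve to the same identifier.
print(sql_identifier.SqlIdentifier("first") == sql_identifier.SqlIdentifier("FIRST"))  # expected: True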

snowflake/ml/model/_client/ops/service_ops.py
CHANGED

@@ -206,6 +206,8 @@ class ServiceOperator:
         hf_model_args: Optional[HFModelArgs] = None,
         # inference engine model
         inference_engine_args: Optional[InferenceEngineArgs] = None,
+        # inference table
+        autocapture: Optional[bool] = None,
     ) -> Union[str, async_job.AsyncJob]:

         # Generate operation ID for this deployment

@@ -261,6 +263,7 @@ class ServiceOperator:
             gpu=gpu_requests,
             num_workers=num_workers,
             max_batch_rows=max_batch_rows,
+            autocapture=autocapture,
         )
         if hf_model_args:
             # hf model

snowflake/ml/model/_client/service/model_deployment_spec.py
CHANGED

@@ -146,6 +146,7 @@ class ModelDeploymentSpec:
         gpu: Optional[Union[str, int]] = None,
         num_workers: Optional[int] = None,
         max_batch_rows: Optional[int] = None,
+        autocapture: Optional[bool] = None,
     ) -> "ModelDeploymentSpec":
         """Add service specification to the deployment spec.

@@ -161,6 +162,7 @@ class ModelDeploymentSpec:
             gpu: GPU requirement.
             num_workers: Number of workers.
             max_batch_rows: Maximum batch rows for inference.
+            autocapture: Whether to enable the inference table.

         Raises:
             ValueError: If a job spec already exists.

@@ -186,6 +188,7 @@ class ModelDeploymentSpec:
             compute_pool=inference_compute_pool_name.identifier(),
             ingress_enabled=ingress_enabled,
             max_instances=max_instances,
+            autocapture=autocapture,
             **self._inference_spec,
         )
         return self

snowflake/ml/model/_model_composer/model_manifest/model_manifest.py
CHANGED

@@ -87,7 +87,9 @@ class ModelManifest:
                 model_meta_schema.FunctionProperties.PARTITIONED.value, False
             ),
             wide_input=len(model_meta.signatures[target_method].inputs) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT,
-            options=model_method.get_model_method_options_from_options(
+            options=model_method.get_model_method_options_from_options(
+                options, target_method, model_meta.model_type
+            ),
         )

         self.methods.append(method)

snowflake/ml/model/_model_composer/model_method/model_method.py
CHANGED

@@ -32,7 +32,7 @@ class ModelMethodOptions(TypedDict):


 def get_model_method_options_from_options(
-    options: type_hints.ModelSaveOption, target_method: str
+    options: type_hints.ModelSaveOption, target_method: str, model_type: Optional[str] = None
 ) -> ModelMethodOptions:
     default_function_type = model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value
     method_option = options.get("method_options", {}).get(target_method, {})

@@ -42,6 +42,9 @@ def get_model_method_options_from_options(
         case_sensitive = utils.determine_explain_case_sensitive_from_method_options(
             options.get("method_options", {}), target_method
         )
+    elif model_type == "prophet":
+        # Prophet models always require TABLE_FUNCTION because they need the entire time series context
+        default_function_type = model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
     global_function_type = options.get("function_type", default_function_type)
     function_type = method_option.get("function_type", global_function_type)
     if function_type not in [function_type.value for function_type in model_manifest_schema.ModelMethodFunctionTypes]: