snowflake-ml-python 1.23.0__py3-none-any.whl → 1.25.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/platform_capabilities.py +0 -4
- snowflake/ml/_internal/utils/mixins.py +26 -1
- snowflake/ml/data/_internal/arrow_ingestor.py +5 -1
- snowflake/ml/data/data_connector.py +2 -2
- snowflake/ml/data/data_ingestor.py +2 -1
- snowflake/ml/experiment/_experiment_info.py +3 -3
- snowflake/ml/feature_store/__init__.py +2 -0
- snowflake/ml/feature_store/aggregation.py +367 -0
- snowflake/ml/feature_store/feature.py +366 -0
- snowflake/ml/feature_store/feature_store.py +234 -20
- snowflake/ml/feature_store/feature_view.py +189 -4
- snowflake/ml/feature_store/metadata_manager.py +425 -0
- snowflake/ml/feature_store/tile_sql_generator.py +1079 -0
- snowflake/ml/jobs/_interop/data_utils.py +8 -8
- snowflake/ml/jobs/_interop/dto_schema.py +52 -7
- snowflake/ml/jobs/_interop/protocols.py +124 -7
- snowflake/ml/jobs/_interop/utils.py +92 -33
- snowflake/ml/jobs/_utils/arg_protocol.py +7 -0
- snowflake/ml/jobs/_utils/constants.py +4 -0
- snowflake/ml/jobs/_utils/feature_flags.py +97 -13
- snowflake/ml/jobs/_utils/payload_utils.py +6 -40
- snowflake/ml/jobs/_utils/runtime_env_utils.py +12 -111
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +204 -27
- snowflake/ml/jobs/decorators.py +17 -22
- snowflake/ml/jobs/job.py +25 -10
- snowflake/ml/jobs/job_definition.py +100 -8
- snowflake/ml/model/__init__.py +4 -0
- snowflake/ml/model/_client/model/batch_inference_specs.py +38 -2
- snowflake/ml/model/_client/model/model_version_impl.py +56 -28
- snowflake/ml/model/_client/ops/model_ops.py +2 -8
- snowflake/ml/model/_client/ops/service_ops.py +6 -11
- snowflake/ml/model/_client/service/model_deployment_spec.py +3 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
- snowflake/ml/model/_client/sql/service.py +21 -29
- snowflake/ml/model/_model_composer/model_method/model_method.py +2 -1
- snowflake/ml/model/_packager/model_handlers/huggingface.py +20 -0
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +70 -14
- snowflake/ml/model/_signatures/utils.py +76 -1
- snowflake/ml/model/models/huggingface_pipeline.py +3 -0
- snowflake/ml/model/openai_signatures.py +154 -0
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +2 -3
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/METADATA +79 -2
- {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/RECORD +47 -44
- {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/WHEEL +1 -1
- snowflake/ml/jobs/_utils/function_payload_utils.py +0 -43
- snowflake/ml/jobs/_utils/spec_utils.py +0 -22
- {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/top_level.txt +0 -0
@@ -14,11 +14,14 @@ from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import identifier
 from snowflake.ml._internal.utils.mixins import SerializableSessionMixin
 from snowflake.ml.jobs import job as jb
+from snowflake.ml.jobs._interop import utils as interop_utils
 from snowflake.ml.jobs._utils import (
+    arg_protocol,
     constants,
     feature_flags,
     payload_utils,
     query_helper,
+    runtime_env_utils,
     types,
 )
 from snowflake.snowpark import context as sp_context
@@ -40,6 +43,8 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
         compute_pool: str,
         name: str,
         entrypoint_args: list[Any],
+        arg_protocol: Optional[arg_protocol.ArgProtocol] = arg_protocol.ArgProtocol.NONE,
+        default_args: Optional[list[Any]] = None,
         database: Optional[str] = None,
         schema: Optional[str] = None,
         session: Optional[snowpark.Session] = None,
@@ -49,12 +54,22 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
         self.spec_options = spec_options
         self.compute_pool = compute_pool
         self.session = session or sp_context.get_active_session()
-
-
+        resolved_database = database or self.session.get_current_database()
+        resolved_schema = schema or self.session.get_current_schema()
+        if resolved_database is None:
+            raise ValueError("Database must be specified either in the session context or as a parameter.")
+        if resolved_schema is None:
+            raise ValueError("Schema must be specified either in the session context or as a parameter.")
+        self.database = identifier.resolve_identifier(resolved_database)
+        self.schema = identifier.resolve_identifier(resolved_schema)
         self.job_definition_id = identifier.get_schema_level_object_identifier(self.database, self.schema, name)
         self.entrypoint_args = entrypoint_args
+        self.arg_protocol = arg_protocol
+        self.default_args = default_args

     def delete(self) -> None:
+        if self.session is None:
+            raise RuntimeError("Session is required to delete job definition")
         if self.stage_name:
             try:
                 self.session.sql(f"REMOVE {self.stage_name}/").collect()
@@ -62,9 +77,27 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
             except Exception as e:
                 logger.warning(f"Failed to clean up stage files for job definition {self.stage_name}: {e}")

-    def _prepare_arguments(self, *args: _Args.args, **kwargs: _Args.kwargs) -> list[Any]:
-
-
+    def _prepare_arguments(self, *args: _Args.args, **kwargs: _Args.kwargs) -> Optional[list[Any]]:
+        if self.arg_protocol == arg_protocol.ArgProtocol.NONE:
+            if len(kwargs) > 0:
+                raise ValueError(f"Keyword arguments are not supported with {self.arg_protocol}")
+            return list(args)
+        elif self.arg_protocol == arg_protocol.ArgProtocol.CLI:
+            return _combine_runtime_arguments(self.default_args, *args, **kwargs)
+        elif self.arg_protocol == arg_protocol.ArgProtocol.PICKLE:
+            if not args and not kwargs:
+                return []
+            uid = uuid4().hex[:8]
+            rel_path = f"{uid}/function_args"
+            file_path = f"{self.stage_name}/{constants.APP_STAGE_SUBPATH}/{rel_path}"
+            payload = interop_utils.save_result(
+                (args, kwargs), file_path, session=self.session, max_inline_size=interop_utils._MAX_INLINE_SIZE
+            )
+            if payload is not None:
+                return [f"--function_args={payload.decode('utf-8')}"]
+            return [f"--function_args={rel_path}"]
+        else:
+            raise ValueError(f"Invalid arg_protocol: {self.arg_protocol}")

     @telemetry.send_api_usage_telemetry(project=_PROJECT)
     def __call__(self, *args: _Args.args, **kwargs: _Args.kwargs) -> jb.MLJob[_ReturnValue]:
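
For orientation, a rough sketch of what each protocol turns a call's arguments into, assuming the definition's __call__ routes them through _prepare_arguments (the job definition handle and values are hypothetical):

    # ArgProtocol.NONE: positional args pass through unchanged; keyword args raise ValueError
    job_def("train.csv", 10)             # entrypoint args -> ["train.csv", 10]
    # ArgProtocol.CLI: default_args and call-time kwargs are flattened into --flag value pairs
    job_def("train.csv", epochs=20)      # with default_args=["--epochs", "10"] -> ["train.csv", "--epochs", "20"]
    # ArgProtocol.PICKLE: (args, kwargs) are serialized; the entrypoint receives a single --function_args reference
    job_def(payload_obj, threshold=0.5)  # -> ["--function_args=<inline payload or stage-relative path>"]
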
@@ -98,6 +131,7 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
             json.dumps(job_options_dict),
         ]
         query_template = "CALL SYSTEM$EXECUTE_ML_JOB(%s, %s, %s, %s)"
+        assert self.session is not None, "Session is required to generate MLJob SQL query"
         sql = self.session._conn._cursor._preprocess_pyformat_query(query_template, params)
         return sql

@@ -123,6 +157,7 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
         entrypoint: Optional[Union[str, list[str]]] = None,
         target_instances: int = 1,
         generate_suffix: bool = True,
+        arg_protocol: Optional[arg_protocol.ArgProtocol] = arg_protocol.ArgProtocol.NONE,
         **kwargs: Any,
     ) -> "MLJobDefinition[_Args, _ReturnValue]":
         # Use kwargs for less common optional parameters
@@ -142,6 +177,7 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
         )
         overwrite = kwargs.pop("overwrite", False)
         name = kwargs.pop("name", None)
+        default_args = kwargs.pop("default_args", None)
         # Warn if there are unknown kwargs
         if kwargs:
             logger.warning(f"Ignoring unknown kwargs: {kwargs.keys()}")
@@ -149,6 +185,11 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
         # Validate parameters
         if database and not schema:
             raise ValueError("Schema must be specified if database is specified.")
+
+        compute_pool = identifier.resolve_identifier(compute_pool)
+        if query_warehouse is not None:
+            query_warehouse = identifier.resolve_identifier(query_warehouse)
+
         if target_instances < 1:
             raise ValueError("target_instances must be greater than 0.")
         if not (0 < min_instances <= target_instances):
@@ -190,10 +231,11 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
             )
             raise

-        if runtime_environment is None and feature_flags.FeatureFlags.ENABLE_RUNTIME_VERSIONS.is_enabled(
+        if runtime_environment is None and feature_flags.FeatureFlags.ENABLE_RUNTIME_VERSIONS.is_enabled():
             # Pass a JSON object for runtime versions so it serializes as nested JSON in options
             runtime_environment = json.dumps({"pythonVersion": f"{sys.version_info.major}.{sys.version_info.minor}"})

+        runtime = runtime_env_utils.get_runtime_image(session, compute_pool, runtime_environment)
         combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})}
         entrypoint_args = [v.as_posix() if isinstance(v, PurePath) else v for v in uploaded_payload.entrypoint]
         spec_options = types.SpecOptions(
@@ -203,8 +245,8 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
             env_vars=combined_env_vars,
             enable_metrics=enable_metrics,
             spec_overrides=spec_overrides,
-            runtime=
-            enable_stage_mount_v2=feature_flags.FeatureFlags.ENABLE_STAGE_MOUNT_V2.is_enabled(
+            runtime=runtime,
+            enable_stage_mount_v2=feature_flags.FeatureFlags.ENABLE_STAGE_MOUNT_V2.is_enabled(),
         )

         job_options = types.JobOptions(
@@ -222,6 +264,8 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
             compute_pool=compute_pool,
             entrypoint_args=entrypoint_args,
             session=session,
+            arg_protocol=arg_protocol,
+            default_args=default_args,
             database=database,
             schema=schema,
             name=name,
@@ -230,3 +274,51 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):

 def _generate_suffix() -> str:
     return str(uuid4().hex)[:8]
+
+
+def _combine_runtime_arguments(
+    default_runtime_args: Optional[list[Any]] = None, *args: Any, **kwargs: Any
+) -> list[Any]:
+    """Merge default CLI arguments with runtime overrides into a flat argument list.
+
+    Parses `default_runtime_args` for flags (e.g., `--key value`) and merges them with
+    `kwargs`. Keyword arguments override defaults unless their value is None. Positional
+    arguments from both `default_args` and `*args` are preserved in order.
+
+    Args:
+        default_runtime_args: Optional list of default CLI arguments to parse for flags and positional args.
+        *args: Additional positional arguments to include in the output.
+        **kwargs: Keyword arguments that override default flags. Values of None are ignored.
+
+    Returns:
+        A list of CLI-style arguments: positional args followed by `--key value` pairs.
+    """
+    cli_args = list(args)
+    flags: dict[str, Any] = {}
+    if default_runtime_args:
+        i = 0
+        while i < len(default_runtime_args):
+            arg = default_runtime_args[i]
+            if isinstance(arg, str) and arg.startswith("--"):
+                key = arg[2:]
+                # Check if next arg is a value (not a flag)
+                if i + 1 < len(default_runtime_args):
+                    next_arg = default_runtime_args[i + 1]
+                    if not (isinstance(next_arg, str) and next_arg.startswith("--")):
+                        flags[key] = next_arg
+                        i += 2
+                        continue
+
+                flags[key] = None
+            else:
+                cli_args.append(arg)
+            i += 1
+    # Prioritize kwargs over default_args. Explicit None values in kwargs
+    # serve as overrides and are converted to the string "None" to match
+    # CLI flag conventions (--key=value)
+    # Downstream logic must handle the parsing of these string-based nulls.
+    for k, v in kwargs.items():
+        flags[k] = v
+    for k, v in flags.items():
+        cli_args.extend([f"--{k}", str(v)])
+    return cli_args
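
Based on the merge logic above, defaults and call-time overrides combine like this (values are illustrative):

    defaults = ["--epochs", "10", "train.csv"]
    _combine_runtime_arguments(defaults, "eval.csv", epochs=20, lr=None)
    # positional args come first, kwargs win over default flags, and None is stringified as "None":
    # -> ["eval.csv", "train.csv", "--epochs", "20", "--lr", "None"]
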
snowflake/ml/model/__init__.py
CHANGED
@@ -4,6 +4,8 @@ import warnings
 from snowflake.ml.model._client.model.batch_inference_specs import (
     ColumnHandlingOptions,
     FileEncoding,
+    InputFormat,
+    InputSpec,
     JobSpec,
     OutputSpec,
     SaveMode,
@@ -20,6 +22,8 @@ __all__ = [
     "ModelVersion",
     "ExportMode",
     "HuggingFacePipelineModel",
+    "InputSpec",
+    "InputFormat",
     "JobSpec",
     "OutputSpec",
     "SaveMode",
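
With these re-exports, the new input types should be importable from the package root rather than the private client module; a minimal sketch:

    from snowflake.ml.model import InputFormat, InputSpec
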
snowflake/ml/model/_client/model/batch_inference_specs.py
CHANGED
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Optional
+from typing import Any, Optional

 from pydantic import BaseModel
 from typing_extensions import TypedDict
@@ -19,6 +19,12 @@ class SaveMode(str, Enum):
     ERROR = "error"


+class InputFormat(str, Enum):
+    """The format of the input column data."""
+
+    FULL_STAGE_PATH = "full_stage_path"
+
+
 class FileEncoding(str, Enum):
     """The encoding of the file content that will be passed to the custom model."""

@@ -30,7 +36,37 @@
 class ColumnHandlingOptions(TypedDict):
     """Options for handling specific columns during run_batch for file I/O."""

-
+    input_format: InputFormat
+    convert_to: FileEncoding
+
+
+class InputSpec(BaseModel):
+    """Specification for batch inference input options.
+
+    Defines optional configuration for processing input data during batch inference.
+
+    Attributes:
+        params (Optional[dict[str, Any]]): Optional dictionary of model inference parameters
+            (e.g., temperature, top_k for LLMs). These are passed as keyword arguments to the
+            model's inference method. Defaults to None.
+        column_handling (Optional[dict[str, ColumnHandlingOptions]]): Optional dictionary
+            specifying how to handle specific columns during file I/O. Maps column names to their
+            input format and file encoding configuration.
+
+    Example:
+        >>> input_spec = InputSpec(
+        ...     params={"temperature": 0.7, "top_k": 50},
+        ...     column_handling={
+        ...         "image_col": {
+        ...             "input_format": InputFormat.FULL_STAGE_PATH,
+        ...             "convert_to": FileEncoding.BASE64
+        ...         }
+        ...     }
+        ... )
+    """
+
+    params: Optional[dict[str, Any]] = None
+    column_handling: Optional[dict[str, ColumnHandlingOptions]] = None


 class OutputSpec(BaseModel):
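
Since ColumnHandlingOptions is a TypedDict, a plain dict literal with these two keys type-checks; a small sketch (the column name is illustrative):

    handling: ColumnHandlingOptions = {
        "input_format": InputFormat.FULL_STAGE_PATH,
        "convert_to": FileEncoding.BASE64,
    }
    spec = InputSpec(params={"temperature": 0.2}, column_handling={"image_col": handling})
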
snowflake/ml/model/_client/model/model_version_impl.py
CHANGED
@@ -33,6 +33,12 @@ _BATCH_INFERENCE_TEMPORARY_FOLDER = "_temporary"
 VLLM_SUPPORTED_TASKS = [
     "text-generation",
     "image-text-to-text",
+    "video-text-to-text",
+    "audio-text-to-text",
+]
+VALID_OPENAI_SIGNATURES = [
+    openai_signatures.OPENAI_CHAT_SIGNATURE,
+    openai_signatures.OPENAI_CHAT_SIGNATURE_WITH_CONTENT_FORMAT_STRING,
 ]

@@ -661,13 +667,12 @@ class ModelVersion(lineage_node.LineageNode):
     @snowpark._internal.utils.private_preview(version="1.18.0")
     def run_batch(
         self,
+        X: dataframe.DataFrame,
         *,
         compute_pool: str,
-        input_spec:
+        input_spec: Optional[batch_inference_specs.InputSpec] = None,
         output_spec: batch_inference_specs.OutputSpec,
         job_spec: Optional[batch_inference_specs.JobSpec] = None,
-        params: Optional[dict[str, Any]] = None,
-        column_handling: Optional[dict[str, batch_inference_specs.ColumnHandlingOptions]] = None,
         inference_engine_options: Optional[dict[str, Any]] = None,
     ) -> job.MLJob[Any]:
         """Execute batch inference on datasets as an SPCS job.
@@ -675,19 +680,16 @@ class ModelVersion(lineage_node.LineageNode):
         Args:
             compute_pool (str): Name of the compute pool to use for building the image containers and batch
                 inference execution.
-
+            X (dataframe.DataFrame): Snowpark DataFrame containing the input data for inference.
                 The DataFrame should contain all required features for model prediction and passthrough columns.
             output_spec (batch_inference_specs.OutputSpec): Configuration for where and how to save
                 the inference results. Specifies the stage location and file handling behavior.
+            input_spec (Optional[batch_inference_specs.InputSpec]): Optional configuration for input
+                processing including model inference parameters and column handling options.
+                If None, default values will be used for params and column_handling.
             job_spec (Optional[batch_inference_specs.JobSpec]): Optional configuration for job
                 execution parameters such as compute resources, worker counts, and job naming.
                 If None, default values will be used.
-            params (Optional[dict[str, Any]]): Optional dictionary of model inference parameters
-                (e.g., temperature, top_k for LLMs). These are passed as keyword arguments to the
-                model's inference method. Defaults to None.
-            column_handling (Optional[dict[str, batch_inference_specs.FileEncoding]]): Optional dictionary
-                specifying how to handle specific columns during file I/O. Maps column names to their
-                file encoding configuration.
             inference_engine_options: Options for the service creation with custom inference engine.
                 Supports `engine` and `engine_args_override`.
                 `engine` is the type of the inference engine to use.
@@ -699,7 +701,7 @@ class ModelVersion(lineage_node.LineageNode):

         Raises:
             ValueError: If warehouse is not set in job_spec and no current warehouse is available.
-            RuntimeError: If the
+            RuntimeError: If the input data cannot be processed or written to the staging location.

         Example:
             >>> # Prepare input data - Example 1: From a table
@@ -732,10 +734,24 @@ class ModelVersion(lineage_node.LineageNode):
             >>> # Run batch inference
             >>> job = model_version.run_batch(
             ...     compute_pool="my_compute_pool",
-            ...
+            ...     X=input_df,
             ...     output_spec=output_spec,
             ...     job_spec=job_spec
             ... )
+            >>>
+            >>> # Run batch inference with InputSpec for additional options
+            >>> from snowflake.ml.model._client.model.batch_inference_specs import InputSpec, FileEncoding
+            >>> input_spec = InputSpec(
+            ...     params={"temperature": 0.7, "top_k": 50},
+            ...     column_handling={"image_col": {"encoding": FileEncoding.BASE64}}
+            ... )
+            >>> job = model_version.run_batch(
+            ...     compute_pool="my_compute_pool",
+            ...     X=input_df,
+            ...     output_spec=output_spec,
+            ...     input_spec=input_spec,
+            ...     job_spec=job_spec
+            ... )

         Note:
             This method is currently in private preview and requires Snowflake version 1.18.0 or later.
@@ -747,6 +763,13 @@ class ModelVersion(lineage_node.LineageNode):
             subproject=_TELEMETRY_SUBPROJECT,
         )

+        # Extract params and column_handling from input_spec if provided
+        if input_spec is None:
+            input_spec = batch_inference_specs.InputSpec()
+
+        params = input_spec.params
+        column_handling = input_spec.column_handling
+
         if job_spec is None:
             job_spec = batch_inference_specs.JobSpec()

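
Callers that previously passed params= or column_handling= directly to run_batch now wrap them in an InputSpec; a minimal sketch of the new call shape (DataFrame and spec names are illustrative):

    input_spec = InputSpec(params={"temperature": 0.7, "top_k": 50})
    job = mv.run_batch(X=input_df, compute_pool="my_pool", output_spec=output_spec, input_spec=input_spec)
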
@@ -772,10 +795,10 @@ class ModelVersion(lineage_node.LineageNode):
         self._service_ops._enforce_save_mode(output_spec.mode, output_stage_location)

         try:
-
+            X.write.copy_into_location(location=input_stage_location, file_format_type="parquet", header=True)
             # todo: be specific about the type of errors to provide better error messages.
         except Exception as e:
-            raise RuntimeError(f"Failed to process
+            raise RuntimeError(f"Failed to process input data: {e}")

         if job_spec.job_name is None:
             # Same as the MLJob ID generation logic with a different prefix
@@ -1123,16 +1146,11 @@ class ModelVersion(lineage_node.LineageNode):
             func_name: core.ModelSignature.from_dict(sig_dict) for func_name, sig_dict in signatures_dict.items()
         }

-        if deserialized_signatures not in
-            openai_signatures.OPENAI_CHAT_SIGNATURE,
-            openai_signatures.OPENAI_CHAT_SIGNATURE_WITH_CONTENT_FORMAT_STRING,
-        ]:
+        if deserialized_signatures not in VALID_OPENAI_SIGNATURES:
             raise ValueError(
-                "Inference engine requires the model to be logged with
-                "
+                "Inference engine requires the model to be logged with one of the following signatures: "
+                f"{VALID_OPENAI_SIGNATURES}. Please log the model again with one of these supported signatures."
                 f"Found signatures: {signatures_dict}. "
-                "Please log the model again with: signatures=openai_signatures.OPENAI_CHAT_SIGNATURE or "
-                "signatures=openai_signatures.OPENAI_CHAT_SIGNATURE_WITH_CONTENT_FORMAT_STRING"
             )

     @overload
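
For context, the check above passes only for models logged with one of the OpenAI-style chat signatures shipped in the new openai_signatures module; a minimal sketch of such a log_model call (registry handle and model object are illustrative):

    from snowflake.ml.model import openai_signatures
    from snowflake.ml.registry import Registry

    reg = Registry(session=session)
    mv = reg.log_model(
        chat_pipeline,
        model_name="chat_model",
        signatures=openai_signatures.OPENAI_CHAT_SIGNATURE,  # or OPENAI_CHAT_SIGNATURE_WITH_CONTENT_FORMAT_STRING
    )
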
@@ -1144,6 +1162,7 @@ class ModelVersion(lineage_node.LineageNode):
         service_compute_pool: str,
         image_repo: Optional[str] = None,
         ingress_enabled: bool = False,
+        min_instances: int = 0,
         max_instances: int = 1,
         cpu_requests: Optional[str] = None,
         memory_requests: Optional[str] = None,
@@ -1170,8 +1189,10 @@ class ModelVersion(lineage_node.LineageNode):
                 will be used.
             ingress_enabled: If true, creates an service endpoint associated with the service. User must have
                 BIND SERVICE ENDPOINT privilege on the account.
-
-
+            min_instances: The minimum number of instances for the inference service. The service will automatically
+                scale between min_instances and max_instances based on traffic and hardware utilization. If set to
+                0 (default), the service will automatically suspend after a period of inactivity.
+            max_instances: The maximum number of instances for the inference service.
             cpu_requests: The cpu limit for CPU based inference. Can be an integer, fractional or string values. If
                 None, we attempt to utilize all the vCPU of the node.
             memory_requests: The memory limit with for CPU based inference. Can be an integer or a fractional value, but
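
A minimal sketch of how the new autoscaling bounds might be passed when deploying a version (service and pool names are illustrative; other arguments keep their defaults):

    mv.create_service(
        service_name="my_inference_service",
        service_compute_pool="my_compute_pool",
        min_instances=0,   # 0 lets the service suspend after a period of inactivity
        max_instances=3,
    )
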
@@ -1207,6 +1228,7 @@ class ModelVersion(lineage_node.LineageNode):
         service_compute_pool: str,
         image_repo: Optional[str] = None,
         ingress_enabled: bool = False,
+        min_instances: int = 0,
         max_instances: int = 1,
         cpu_requests: Optional[str] = None,
         memory_requests: Optional[str] = None,
@@ -1233,8 +1255,10 @@ class ModelVersion(lineage_node.LineageNode):
                 will be used.
             ingress_enabled: If true, creates an service endpoint associated with the service. User must have
                 BIND SERVICE ENDPOINT privilege on the account.
-
-
+            min_instances: The minimum number of instances for the inference service. The service will automatically
+                scale between min_instances and max_instances based on traffic and hardware utilization. If set to
+                0 (default), the service will automatically suspend after a period of inactivity.
+            max_instances: The maximum number of instances for the inference service.
             cpu_requests: The cpu limit for CPU based inference. Can be an integer, fractional or string values. If
                 None, we attempt to utilize all the vCPU of the node.
             memory_requests: The memory limit with for CPU based inference. Can be an integer or a fractional value, but
@@ -1284,6 +1308,7 @@ class ModelVersion(lineage_node.LineageNode):
         service_compute_pool: str,
         image_repo: Optional[str] = None,
         ingress_enabled: bool = False,
+        min_instances: int = 0,
         max_instances: int = 1,
         cpu_requests: Optional[str] = None,
         memory_requests: Optional[str] = None,
@@ -1311,8 +1336,10 @@ class ModelVersion(lineage_node.LineageNode):
                 will be used.
             ingress_enabled: If true, creates an service endpoint associated with the service. User must have
                 BIND SERVICE ENDPOINT privilege on the account.
-
-
+            min_instances: The minimum number of instances for the inference service. The service will automatically
+                scale between min_instances and max_instances based on traffic and hardware utilization. If set to
+                0 (default), the service will automatically suspend after a period of inactivity.
+            max_instances: The maximum number of instances for the inference service.
             cpu_requests: The cpu limit for CPU based inference. Can be an integer, fractional or string values. If
                 None, we attempt to utilize all the vCPU of the node.
             memory_requests: The memory limit with for CPU based inference. Can be an integer or a fractional value, but
@@ -1402,6 +1429,7 @@ class ModelVersion(lineage_node.LineageNode):
             service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
             image_repo_name=image_repo,
             ingress_enabled=ingress_enabled,
+            min_instances=min_instances,
             max_instances=max_instances,
             cpu_requests=cpu_requests,
             memory_requests=memory_requests,
snowflake/ml/model/_client/ops/model_ops.py
CHANGED
@@ -10,7 +10,6 @@ from typing import Any, Literal, Optional, TypedDict, Union, cast, overload
 import yaml
 from typing_extensions import NotRequired

-from snowflake.ml._internal import platform_capabilities
 from snowflake.ml._internal.exceptions import error_codes, exceptions
 from snowflake.ml._internal.utils import formatting, identifier, sql_identifier, url
 from snowflake.ml.model import model_signature, type_hints
@@ -698,9 +697,6 @@ class ModelOperator:

         result: list[ServiceInfo] = []
         is_privatelink_connection = self._is_privatelink_connection()
-        is_autocapture_param_enabled = (
-            platform_capabilities.PlatformCapabilities.get_instance().is_inference_autocapture_enabled()
-        )

         for fully_qualified_service_name in fully_qualified_service_names:
             port: Optional[int] = None
@@ -742,10 +738,8 @@ class ModelOperator:
                 inference_endpoint=inference_endpoint,
                 internal_endpoint=f"http://{internal_dns}:{port}" if port is not None else None,
             )
-
-
-            autocapture_enabled = self._service_client.get_proxy_container_autocapture(service_description)
-            service_info["autocapture_enabled"] = autocapture_enabled
+            autocapture_enabled = self._service_client.is_autocapture_enabled(service_description)
+            service_info["autocapture_enabled"] = autocapture_enabled

             result.append(service_info)

snowflake/ml/model/_client/ops/service_ops.py
CHANGED
@@ -155,7 +155,6 @@ class ServiceOperator:
         self._model_deployment_spec = model_deployment_spec.ModelDeploymentSpec(
             workspace_path=pathlib.Path(self._workspace.name)
         )
-        self._inference_autocapture_enabled = pc.PlatformCapabilities.get_instance().is_inference_autocapture_enabled()

     def __eq__(self, __value: object) -> bool:
         if not isinstance(__value, ServiceOperator):
@@ -176,6 +175,7 @@ class ServiceOperator:
         service_compute_pool_name: sql_identifier.SqlIdentifier,
         image_repo_name: Optional[str],
         ingress_enabled: bool,
+        min_instances: int,
         max_instances: int,
         cpu_requests: Optional[str],
         memory_requests: Optional[str],
@@ -216,10 +216,6 @@ class ServiceOperator:
         progress_status.update("preparing deployment artifacts...")
         progress_status.increment()

-        # If autocapture param is disabled, don't allow create service with autocapture
-        if not self._inference_autocapture_enabled and autocapture:
-            raise ValueError("Invalid Argument: Autocapture feature is not supported.")
-
         if self._workspace:
             stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
         else:
@@ -246,6 +242,7 @@ class ServiceOperator:
             service_name=service_name,
             inference_compute_pool_name=service_compute_pool_name,
             ingress_enabled=ingress_enabled,
+            min_instances=min_instances,
             max_instances=max_instances,
             cpu=cpu_requests,
             memory=memory_requests,
@@ -834,15 +831,13 @@ class ServiceOperator:
         service_seen_before = False

         while True:
-            # Check if async job has
+            # Check if async job has completed
             if async_job.is_done():
                 try:
                     async_job.result()
-                    # Async job completed successfully
-
-
-                        f"Async job completed but we're still waiting for {service_name} to reach {target_status.value}"
-                    )
+                    # Async job completed successfully - deployment is done
+                    module_logger.debug(f"Async job completed successfully, returning from wait for {service_name}")
+                    return
                 except Exception as e:
                     raise RuntimeError(f"Service deployment failed: {e}")

snowflake/ml/model/_client/service/model_deployment_spec.py
CHANGED
@@ -140,6 +140,7 @@ class ModelDeploymentSpec:
         service_database_name: Optional[sql_identifier.SqlIdentifier] = None,
         service_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
         ingress_enabled: bool = True,
+        min_instances: int = 0,
         max_instances: int = 1,
         cpu: Optional[str] = None,
         memory: Optional[str] = None,
@@ -156,6 +157,7 @@ class ModelDeploymentSpec:
             service_database_name: Database name for the service.
             service_schema_name: Schema name for the service.
             ingress_enabled: Whether ingress is enabled.
+            min_instances: Minimum number of service instances.
             max_instances: Maximum number of service instances.
             cpu: CPU requirement.
             memory: Memory requirement.
@@ -187,6 +189,7 @@ class ModelDeploymentSpec:
             name=fq_service_name,
             compute_pool=inference_compute_pool_name.identifier(),
             ingress_enabled=ingress_enabled,
+            min_instances=min_instances,
             max_instances=max_instances,
             autocapture=autocapture,
             **self._inference_spec,