snowflake-ml-python 1.24.0__py3-none-any.whl → 1.25.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/utils/mixins.py +26 -1
- snowflake/ml/data/_internal/arrow_ingestor.py +5 -1
- snowflake/ml/data/data_connector.py +2 -2
- snowflake/ml/data/data_ingestor.py +2 -1
- snowflake/ml/experiment/_experiment_info.py +3 -3
- snowflake/ml/jobs/_interop/data_utils.py +8 -8
- snowflake/ml/jobs/_interop/dto_schema.py +52 -7
- snowflake/ml/jobs/_interop/protocols.py +124 -7
- snowflake/ml/jobs/_interop/utils.py +92 -33
- snowflake/ml/jobs/_utils/arg_protocol.py +7 -0
- snowflake/ml/jobs/_utils/constants.py +4 -0
- snowflake/ml/jobs/_utils/feature_flags.py +97 -13
- snowflake/ml/jobs/_utils/payload_utils.py +6 -40
- snowflake/ml/jobs/_utils/runtime_env_utils.py +12 -111
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +204 -27
- snowflake/ml/jobs/decorators.py +17 -22
- snowflake/ml/jobs/job.py +25 -10
- snowflake/ml/jobs/job_definition.py +100 -8
- snowflake/ml/model/_client/model/model_version_impl.py +25 -14
- snowflake/ml/model/_client/ops/service_ops.py +6 -6
- snowflake/ml/model/_client/service/model_deployment_spec.py +3 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
- snowflake/ml/model/models/huggingface_pipeline.py +3 -0
- snowflake/ml/model/openai_signatures.py +154 -0
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +2 -3
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.24.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/METADATA +41 -2
- {snowflake_ml_python-1.24.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/RECORD +31 -32
- {snowflake_ml_python-1.24.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/WHEEL +1 -1
- snowflake/ml/jobs/_utils/function_payload_utils.py +0 -43
- snowflake/ml/jobs/_utils/spec_utils.py +0 -22
- {snowflake_ml_python-1.24.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.24.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/decorators.py
CHANGED

@@ -1,13 +1,12 @@
 import copy
-import functools
 from typing import Any, Callable, Optional, TypeVar

 from typing_extensions import ParamSpec

 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
-from snowflake.ml.jobs import …
-from snowflake.ml.jobs._utils import …
+from snowflake.ml.jobs import job_definition as jd
+from snowflake.ml.jobs._utils import arg_protocol, constants

 _PROJECT = "MLJob"

@@ -25,7 +24,7 @@ def remote(
     external_access_integrations: Optional[list[str]] = None,
     session: Optional[snowpark.Session] = None,
     **kwargs: Any,
-) -> Callable[[Callable[_Args, _ReturnValue]], …
+) -> Callable[[Callable[_Args, _ReturnValue]], jd.MLJobDefinition[_Args, _ReturnValue]]:
     """
     Submit a job to the compute pool.

@@ -51,29 +50,25 @@ def remote(
     Decorator that dispatches invocations of the decorated function as remote jobs.
     """

-    def decorator(func: Callable[_Args, _ReturnValue]) -> …
+    def decorator(func: Callable[_Args, _ReturnValue]) -> jd.MLJobDefinition[_Args, _ReturnValue]:
         # Copy the function to avoid modifying the original
         # We need to modify the line number of the function to exclude the
         # decorator from the copied source code
         wrapped_func = copy.copy(func)
         wrapped_func.__code__ = wrapped_func.__code__.replace(co_firstlineno=func.__code__.co_firstlineno + 1)

-        …
-            assert isinstance(job, jb.MLJob), f"Unexpected job type: {type(job)}"
-            return job
-
-        return wrapper
+        setattr(wrapped_func, constants.IS_MLJOB_REMOTE_ATTR, True)
+        return jd.MLJobDefinition.register(
+            source=wrapped_func,
+            compute_pool=compute_pool,
+            stage_name=stage_name,
+            target_instances=target_instances,
+            pip_requirements=pip_requirements,
+            external_access_integrations=external_access_integrations,
+            session=session or snowpark.context.get_active_session(),
+            arg_protocol=arg_protocol.ArgProtocol.PICKLE,
+            generate_suffix=True,
+            **kwargs,
+        )

     return decorator
snowflake/ml/jobs/job.py
CHANGED

@@ -123,26 +123,41 @@ class MLJob(Generic[T], SerializableSessionMixin):

         return self._transform_path(result_path_str)

-    …
+    # After introducing ML Job definitions, we have additional stage mount for result path
+    # the result path is like @payload_stage/{job_definition_name}/{job_name}/mljob_result
+    @property
+    def _result_stage_path(self) -> Optional[str]:
+        volumes = self._service_spec["spec"]["volumes"]
+        stage_volume = next((v for v in volumes if v["name"] == constants.RESULT_VOLUME_NAME), None)
+        if stage_volume is None:
+            return self._stage_path
+        elif "stageConfig" in stage_volume:
+            return cast(str, stage_volume["stageConfig"]["name"])
+        else:
+            return cast(str, stage_volume["source"])
+
+    def _transform_path(
+        self,
+        path_str: str,
+    ) -> str:
         """Transform a local path within the container to a stage path."""
         path = stage_utils.resolve_path(path_str)
         if isinstance(path, stage_utils.StagePath):
-            # Stage paths need no transformation
             return path.as_posix()
         if not path.is_absolute():
-            …
-            return f"{self._stage_path}/{path.as_posix()}"
+            return f"{self._result_stage_path}/{path.as_posix()}"

-        # If result path is absolute, rebase it onto the stage mount path
-        # TODO: Rather than matching by name, use the longest mount path which matches
         volume_mounts = self._container_spec["volumeMounts"]
-        …
+        stage_volume = next((v for v in volume_mounts if v["name"] == constants.RESULT_VOLUME_NAME), None)
+        if stage_volume is None:
+            stage_volume = next(v for v in volume_mounts if v["name"] == constants.STAGE_VOLUME_NAME)
+        stage_mount_str = stage_volume["mountPath"]
         stage_mount = Path(stage_mount_str)
         try:
             relative_path = path.relative_to(stage_mount)
-            return f"{self.…
+            return f"{self._result_stage_path}/{relative_path.as_posix()}"
         except ValueError:
-            raise ValueError(f"Result …
+            raise ValueError(f"Result Path {path} is absolute, but should be relative to stage mount {stage_mount}")

@@ -279,7 +294,7 @@ class MLJob(Generic[T], SerializableSessionMixin):
         if self._result is None:
             self.wait(timeout)
         try:
-            self._result = interop_utils.…
+            self._result = interop_utils.load(
                 self._result_path, session=self._session, path_transform=self._transform_path
             )
         except Exception as e:
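The rebasing rule in `_transform_path` is easiest to see in isolation. A minimal standalone sketch of the assumed behavior (the stage path and mount point below are hypothetical): relative paths are joined onto the result stage path, while absolute paths must sit under the stage mount and are rebased onto it.

```python
from pathlib import Path

def rebase(path_str: str, result_stage_path: str, stage_mount: str) -> str:
    """Rebase a container-local result path onto a stage path."""
    path = Path(path_str)
    if not path.is_absolute():
        # Relative paths are interpreted relative to the result stage path
        return f"{result_stage_path}/{path.as_posix()}"
    try:
        # Absolute paths must live under the stage mount; strip the mount prefix
        return f"{result_stage_path}/{path.relative_to(stage_mount).as_posix()}"
    except ValueError:
        raise ValueError(f"Result Path {path} is absolute, but should be relative to stage mount {stage_mount}")

stage = "@payload_stage/my_def/my_job"   # hypothetical result stage path
mount = "/mnt/job_stage"                 # hypothetical container mount point
print(rebase("mljob_result", stage, mount))            # @payload_stage/my_def/my_job/mljob_result
print(rebase("/mnt/job_stage/out.pkl", stage, mount))  # @payload_stage/my_def/my_job/out.pkl
```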
snowflake/ml/jobs/job_definition.py
CHANGED

@@ -14,11 +14,14 @@ from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import identifier
 from snowflake.ml._internal.utils.mixins import SerializableSessionMixin
 from snowflake.ml.jobs import job as jb
+from snowflake.ml.jobs._interop import utils as interop_utils
 from snowflake.ml.jobs._utils import (
+    arg_protocol,
     constants,
     feature_flags,
     payload_utils,
     query_helper,
+    runtime_env_utils,
     types,
 )
 from snowflake.snowpark import context as sp_context

@@ -40,6 +43,8 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
         compute_pool: str,
         name: str,
         entrypoint_args: list[Any],
+        arg_protocol: Optional[arg_protocol.ArgProtocol] = arg_protocol.ArgProtocol.NONE,
+        default_args: Optional[list[Any]] = None,
         database: Optional[str] = None,
         schema: Optional[str] = None,
         session: Optional[snowpark.Session] = None,

@@ -49,12 +54,22 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
         self.spec_options = spec_options
         self.compute_pool = compute_pool
         self.session = session or sp_context.get_active_session()
-        …
-        …
+        resolved_database = database or self.session.get_current_database()
+        resolved_schema = schema or self.session.get_current_schema()
+        if resolved_database is None:
+            raise ValueError("Database must be specified either in the session context or as a parameter.")
+        if resolved_schema is None:
+            raise ValueError("Schema must be specified either in the session context or as a parameter.")
+        self.database = identifier.resolve_identifier(resolved_database)
+        self.schema = identifier.resolve_identifier(resolved_schema)
         self.job_definition_id = identifier.get_schema_level_object_identifier(self.database, self.schema, name)
         self.entrypoint_args = entrypoint_args
+        self.arg_protocol = arg_protocol
+        self.default_args = default_args

     def delete(self) -> None:
+        if self.session is None:
+            raise RuntimeError("Session is required to delete job definition")
         if self.stage_name:
             try:
                 self.session.sql(f"REMOVE {self.stage_name}/").collect()

@@ -62,9 +77,27 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
             except Exception as e:
                 logger.warning(f"Failed to clean up stage files for job definition {self.stage_name}: {e}")

-    def _prepare_arguments(self, *args: _Args.args, **kwargs: _Args.kwargs) -> list[Any]:
-        …
-        …
+    def _prepare_arguments(self, *args: _Args.args, **kwargs: _Args.kwargs) -> Optional[list[Any]]:
+        if self.arg_protocol == arg_protocol.ArgProtocol.NONE:
+            if len(kwargs) > 0:
+                raise ValueError(f"Keyword arguments are not supported with {self.arg_protocol}")
+            return list(args)
+        elif self.arg_protocol == arg_protocol.ArgProtocol.CLI:
+            return _combine_runtime_arguments(self.default_args, *args, **kwargs)
+        elif self.arg_protocol == arg_protocol.ArgProtocol.PICKLE:
+            if not args and not kwargs:
+                return []
+            uid = uuid4().hex[:8]
+            rel_path = f"{uid}/function_args"
+            file_path = f"{self.stage_name}/{constants.APP_STAGE_SUBPATH}/{rel_path}"
+            payload = interop_utils.save_result(
+                (args, kwargs), file_path, session=self.session, max_inline_size=interop_utils._MAX_INLINE_SIZE
+            )
+            if payload is not None:
+                return [f"--function_args={payload.decode('utf-8')}"]
+            return [f"--function_args={rel_path}"]
+        else:
+            raise ValueError(f"Invalid arg_protocol: {self.arg_protocol}")

     @telemetry.send_api_usage_telemetry(project=_PROJECT)
     def __call__(self, *args: _Args.args, **kwargs: _Args.kwargs) -> jb.MLJob[_ReturnValue]:

@@ -98,6 +131,7 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
             json.dumps(job_options_dict),
         ]
         query_template = "CALL SYSTEM$EXECUTE_ML_JOB(%s, %s, %s, %s)"
+        assert self.session is not None, "Session is required to generate MLJob SQL query"
         sql = self.session._conn._cursor._preprocess_pyformat_query(query_template, params)
         return sql

@@ -123,6 +157,7 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
         entrypoint: Optional[Union[str, list[str]]] = None,
         target_instances: int = 1,
         generate_suffix: bool = True,
+        arg_protocol: Optional[arg_protocol.ArgProtocol] = arg_protocol.ArgProtocol.NONE,
         **kwargs: Any,
     ) -> "MLJobDefinition[_Args, _ReturnValue]":
         # Use kwargs for less common optional parameters

@@ -142,6 +177,7 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
         )
         overwrite = kwargs.pop("overwrite", False)
         name = kwargs.pop("name", None)
+        default_args = kwargs.pop("default_args", None)
         # Warn if there are unknown kwargs
         if kwargs:
             logger.warning(f"Ignoring unknown kwargs: {kwargs.keys()}")

@@ -149,6 +185,11 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
         # Validate parameters
         if database and not schema:
             raise ValueError("Schema must be specified if database is specified.")
+
+        compute_pool = identifier.resolve_identifier(compute_pool)
+        if query_warehouse is not None:
+            query_warehouse = identifier.resolve_identifier(query_warehouse)
+
         if target_instances < 1:
             raise ValueError("target_instances must be greater than 0.")
         if not (0 < min_instances <= target_instances):

@@ -190,10 +231,11 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
             )
             raise

-        if runtime_environment is None and feature_flags.FeatureFlags.ENABLE_RUNTIME_VERSIONS.is_enabled(…
+        if runtime_environment is None and feature_flags.FeatureFlags.ENABLE_RUNTIME_VERSIONS.is_enabled():
             # Pass a JSON object for runtime versions so it serializes as nested JSON in options
             runtime_environment = json.dumps({"pythonVersion": f"{sys.version_info.major}.{sys.version_info.minor}"})

+        runtime = runtime_env_utils.get_runtime_image(session, compute_pool, runtime_environment)
         combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})}
         entrypoint_args = [v.as_posix() if isinstance(v, PurePath) else v for v in uploaded_payload.entrypoint]
         spec_options = types.SpecOptions(

@@ -203,8 +245,8 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
             env_vars=combined_env_vars,
             enable_metrics=enable_metrics,
             spec_overrides=spec_overrides,
-            runtime=…
-            enable_stage_mount_v2=feature_flags.FeatureFlags.ENABLE_STAGE_MOUNT_V2.is_enabled(…
+            runtime=runtime,
+            enable_stage_mount_v2=feature_flags.FeatureFlags.ENABLE_STAGE_MOUNT_V2.is_enabled(),
         )

         job_options = types.JobOptions(

@@ -222,6 +264,8 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
             compute_pool=compute_pool,
             entrypoint_args=entrypoint_args,
             session=session,
+            arg_protocol=arg_protocol,
+            default_args=default_args,
             database=database,
             schema=schema,
             name=name,

@@ -230,3 +274,51 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):

 def _generate_suffix() -> str:
     return str(uuid4().hex)[:8]
+
+
+def _combine_runtime_arguments(
+    default_runtime_args: Optional[list[Any]] = None, *args: Any, **kwargs: Any
+) -> list[Any]:
+    """Merge default CLI arguments with runtime overrides into a flat argument list.
+
+    Parses `default_runtime_args` for flags (e.g., `--key value`) and merges them with
+    `kwargs`. Keyword arguments override defaults unless their value is None. Positional
+    arguments from both `default_args` and `*args` are preserved in order.
+
+    Args:
+        default_runtime_args: Optional list of default CLI arguments to parse for flags and positional args.
+        *args: Additional positional arguments to include in the output.
+        **kwargs: Keyword arguments that override default flags. Values of None are ignored.
+
+    Returns:
+        A list of CLI-style arguments: positional args followed by `--key value` pairs.
+    """
+    cli_args = list(args)
+    flags: dict[str, Any] = {}
+    if default_runtime_args:
+        i = 0
+        while i < len(default_runtime_args):
+            arg = default_runtime_args[i]
+            if isinstance(arg, str) and arg.startswith("--"):
+                key = arg[2:]
+                # Check if next arg is a value (not a flag)
+                if i + 1 < len(default_runtime_args):
+                    next_arg = default_runtime_args[i + 1]
+                    if not (isinstance(next_arg, str) and next_arg.startswith("--")):
+                        flags[key] = next_arg
+                        i += 2
+                        continue
+
+                flags[key] = None
+            else:
+                cli_args.append(arg)
+            i += 1
+    # Prioritize kwargs over default_args. Explicit None values in kwargs
+    # serve as overrides and are converted to the string "None" to match
+    # CLI flag conventions (--key=value)
+    # Downstream logic must handle the parsing of these string-based nulls.
+    for k, v in kwargs.items():
+        flags[k] = v
+    for k, v in flags.items():
+        cli_args.extend([f"--{k}", str(v)])
+    return cli_args
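To make the merge rules of `_combine_runtime_arguments` concrete, here is a small standalone restatement of the same logic with a worked example (illustration only, not an import from the package):

```python
from typing import Any, Optional

def combine(defaults: Optional[list[Any]], *args: Any, **kwargs: Any) -> list[Any]:
    cli_args: list[Any] = list(args)
    flags: dict[str, Any] = {}
    defaults = defaults or []
    i = 0
    while i < len(defaults):
        arg = defaults[i]
        if isinstance(arg, str) and arg.startswith("--"):
            key = arg[2:]
            # Consume the next token as this flag's value unless it is itself a flag
            if i + 1 < len(defaults) and not (isinstance(defaults[i + 1], str) and defaults[i + 1].startswith("--")):
                flags[key] = defaults[i + 1]
                i += 2
                continue
            flags[key] = None
        else:
            cli_args.append(arg)  # positional defaults are preserved in order
        i += 1
    flags.update(kwargs)  # runtime kwargs win over parsed defaults
    for k, v in flags.items():
        cli_args.extend([f"--{k}", str(v)])
    return cli_args

# Positional defaults are kept, the --epochs default is overridden by the kwarg,
# and a bare default flag is emitted with the string value "None".
assert combine(["input.csv", "--epochs", "10", "--verbose"], "extra.csv", epochs=20) == [
    "extra.csv", "input.csv", "--epochs", "20", "--verbose", "None",
]
```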
snowflake/ml/model/_client/model/model_version_impl.py
CHANGED

@@ -33,6 +33,12 @@ _BATCH_INFERENCE_TEMPORARY_FOLDER = "_temporary"
 VLLM_SUPPORTED_TASKS = [
     "text-generation",
     "image-text-to-text",
+    "video-text-to-text",
+    "audio-text-to-text",
+]
+VALID_OPENAI_SIGNATURES = [
+    openai_signatures.OPENAI_CHAT_SIGNATURE,
+    openai_signatures.OPENAI_CHAT_SIGNATURE_WITH_CONTENT_FORMAT_STRING,
 ]


@@ -1140,16 +1146,11 @@ class ModelVersion(lineage_node.LineageNode):
             func_name: core.ModelSignature.from_dict(sig_dict) for func_name, sig_dict in signatures_dict.items()
         }

-        if deserialized_signatures not in …
-            openai_signatures.OPENAI_CHAT_SIGNATURE,
-            openai_signatures.OPENAI_CHAT_SIGNATURE_WITH_CONTENT_FORMAT_STRING,
-        ]:
+        if deserialized_signatures not in VALID_OPENAI_SIGNATURES:
             raise ValueError(
-                "Inference engine requires the model to be logged with …
-                "…
+                "Inference engine requires the model to be logged with one of the following signatures: "
+                f"{VALID_OPENAI_SIGNATURES}. Please log the model again with one of these supported signatures."
                 f"Found signatures: {signatures_dict}. "
-                "Please log the model again with: signatures=openai_signatures.OPENAI_CHAT_SIGNATURE or "
-                "signatures=openai_signatures.OPENAI_CHAT_SIGNATURE_WITH_CONTENT_FORMAT_STRING"
             )

     @overload

@@ -1161,6 +1162,7 @@ class ModelVersion(lineage_node.LineageNode):
         service_compute_pool: str,
         image_repo: Optional[str] = None,
         ingress_enabled: bool = False,
+        min_instances: int = 0,
         max_instances: int = 1,
         cpu_requests: Optional[str] = None,
         memory_requests: Optional[str] = None,

@@ -1187,8 +1189,10 @@ class ModelVersion(lineage_node.LineageNode):
                 will be used.
             ingress_enabled: If true, creates an service endpoint associated with the service. User must have
                 BIND SERVICE ENDPOINT privilege on the account.
-            …
-            …
+            min_instances: The minimum number of instances for the inference service. The service will automatically
+                scale between min_instances and max_instances based on traffic and hardware utilization. If set to
+                0 (default), the service will automatically suspend after a period of inactivity.
+            max_instances: The maximum number of instances for the inference service.
             cpu_requests: The cpu limit for CPU based inference. Can be an integer, fractional or string values. If
                 None, we attempt to utilize all the vCPU of the node.
             memory_requests: The memory limit with for CPU based inference. Can be an integer or a fractional value, but

@@ -1224,6 +1228,7 @@ class ModelVersion(lineage_node.LineageNode):
         service_compute_pool: str,
         image_repo: Optional[str] = None,
         ingress_enabled: bool = False,
+        min_instances: int = 0,
         max_instances: int = 1,
         cpu_requests: Optional[str] = None,
         memory_requests: Optional[str] = None,

@@ -1250,8 +1255,10 @@ class ModelVersion(lineage_node.LineageNode):
                 will be used.
             ingress_enabled: If true, creates an service endpoint associated with the service. User must have
                 BIND SERVICE ENDPOINT privilege on the account.
-            …
-            …
+            min_instances: The minimum number of instances for the inference service. The service will automatically
+                scale between min_instances and max_instances based on traffic and hardware utilization. If set to
+                0 (default), the service will automatically suspend after a period of inactivity.
+            max_instances: The maximum number of instances for the inference service.
             cpu_requests: The cpu limit for CPU based inference. Can be an integer, fractional or string values. If
                 None, we attempt to utilize all the vCPU of the node.
             memory_requests: The memory limit with for CPU based inference. Can be an integer or a fractional value, but

@@ -1301,6 +1308,7 @@ class ModelVersion(lineage_node.LineageNode):
         service_compute_pool: str,
         image_repo: Optional[str] = None,
         ingress_enabled: bool = False,
+        min_instances: int = 0,
         max_instances: int = 1,
         cpu_requests: Optional[str] = None,
         memory_requests: Optional[str] = None,

@@ -1328,8 +1336,10 @@ class ModelVersion(lineage_node.LineageNode):
                 will be used.
             ingress_enabled: If true, creates an service endpoint associated with the service. User must have
                 BIND SERVICE ENDPOINT privilege on the account.
-            …
-            …
+            min_instances: The minimum number of instances for the inference service. The service will automatically
+                scale between min_instances and max_instances based on traffic and hardware utilization. If set to
+                0 (default), the service will automatically suspend after a period of inactivity.
+            max_instances: The maximum number of instances for the inference service.
             cpu_requests: The cpu limit for CPU based inference. Can be an integer, fractional or string values. If
                 None, we attempt to utilize all the vCPU of the node.
             memory_requests: The memory limit with for CPU based inference. Can be an integer or a fractional value, but

@@ -1419,6 +1429,7 @@ class ModelVersion(lineage_node.LineageNode):
             service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
             image_repo_name=image_repo,
             ingress_enabled=ingress_enabled,
+            min_instances=min_instances,
             max_instances=max_instances,
             cpu_requests=cpu_requests,
             memory_requests=memory_requests,
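A hedged usage sketch of the new `min_instances` parameter, following the overloads above (the model version handle and identifiers are hypothetical):

```python
# mv is assumed to be a snowflake.ml.model.ModelVersion obtained from the registry
service = mv.create_service(
    service_name="MY_CHAT_SERVICE",       # hypothetical
    service_compute_pool="MY_GPU_POOL",   # hypothetical
    ingress_enabled=True,
    min_instances=0,   # new in 1.25.0: with 0, the service suspends when idle
    max_instances=3,   # autoscaling upper bound
)
```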
snowflake/ml/model/_client/ops/service_ops.py
CHANGED

@@ -175,6 +175,7 @@ class ServiceOperator:
         service_compute_pool_name: sql_identifier.SqlIdentifier,
         image_repo_name: Optional[str],
         ingress_enabled: bool,
+        min_instances: int,
         max_instances: int,
         cpu_requests: Optional[str],
         memory_requests: Optional[str],

@@ -241,6 +242,7 @@ class ServiceOperator:
             service_name=service_name,
             inference_compute_pool_name=service_compute_pool_name,
             ingress_enabled=ingress_enabled,
+            min_instances=min_instances,
             max_instances=max_instances,
             cpu=cpu_requests,
             memory=memory_requests,

@@ -829,15 +831,13 @@ class ServiceOperator:
         service_seen_before = False

         while True:
-            # Check if async job has …
+            # Check if async job has completed
             if async_job.is_done():
                 try:
                     async_job.result()
-                    # Async job completed successfully
-                    …
-                    …
-                        f"Async job completed but we're still waiting for {service_name} to reach {target_status.value}"
-                    )
+                    # Async job completed successfully - deployment is done
+                    module_logger.debug(f"Async job completed successfully, returning from wait for {service_name}")
+                    return
                 except Exception as e:
                     raise RuntimeError(f"Service deployment failed: {e}")
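The behavioral fix in the last hunk: once the async deployment job finishes successfully, the wait loop now returns instead of continuing to poll for a target service status. A minimal standalone sketch of that pattern (assumed shape; `async_job` stands in for a Snowpark AsyncJob, and the poll interval is hypothetical):

```python
import time

def wait_for_deployment(async_job, poll_interval_s: float = 5.0) -> None:
    """Poll an async deployment job; return on success, raise on failure."""
    while True:
        if async_job.is_done():
            try:
                async_job.result()  # re-raises if the deployment job failed
            except Exception as e:
                raise RuntimeError(f"Service deployment failed: {e}")
            return  # success: stop waiting (the 1.25.0 behavior)
        time.sleep(poll_interval_s)
```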
snowflake/ml/model/_client/service/model_deployment_spec.py
CHANGED

@@ -140,6 +140,7 @@ class ModelDeploymentSpec:
         service_database_name: Optional[sql_identifier.SqlIdentifier] = None,
         service_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
         ingress_enabled: bool = True,
+        min_instances: int = 0,
         max_instances: int = 1,
         cpu: Optional[str] = None,
         memory: Optional[str] = None,

@@ -156,6 +157,7 @@ class ModelDeploymentSpec:
             service_database_name: Database name for the service.
             service_schema_name: Schema name for the service.
             ingress_enabled: Whether ingress is enabled.
+            min_instances: Minimum number of service instances.
             max_instances: Maximum number of service instances.
             cpu: CPU requirement.
             memory: Memory requirement.

@@ -187,6 +189,7 @@ class ModelDeploymentSpec:
             name=fq_service_name,
             compute_pool=inference_compute_pool_name.identifier(),
             ingress_enabled=ingress_enabled,
+            min_instances=min_instances,
             max_instances=max_instances,
             autocapture=autocapture,
             **self._inference_spec,
snowflake/ml/model/models/huggingface_pipeline.py
CHANGED

@@ -105,6 +105,7 @@ class HuggingFacePipelineModel(huggingface.TransformersPipeline):
         image_repo: Optional[str] = None,
         image_build_compute_pool: Optional[str] = None,
         ingress_enabled: bool = False,
+        min_instances: int = 0,
         max_instances: int = 1,
         cpu_requests: Optional[str] = None,
         memory_requests: Optional[str] = None,

@@ -133,6 +134,7 @@ class HuggingFacePipelineModel(huggingface.TransformersPipeline):
             image_build_compute_pool: The name of the compute pool used to build the model inference image. It uses
                 the service compute pool if None.
             ingress_enabled: Whether ingress is enabled. Defaults to False.
+            min_instances: Minimum number of instances. Defaults to 0.
             max_instances: Maximum number of instances. Defaults to 1.
             cpu_requests: CPU requests configuration. Defaults to None.
             memory_requests: Memory requests configuration. Defaults to None.

@@ -225,6 +227,7 @@ class HuggingFacePipelineModel(huggingface.TransformersPipeline):
             service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
             image_repo_name=image_repo,
             ingress_enabled=ingress_enabled,
+            min_instances=min_instances,
             max_instances=max_instances,
             cpu_requests=cpu_requests,
             memory_requests=memory_requests,