snowflake-ml-python 1.7.4__py3-none-any.whl → 1.7.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +64 -21
- snowflake/ml/_internal/relax_version_strategy.py +16 -0
- snowflake/ml/_internal/telemetry.py +21 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +1 -1
- snowflake/ml/feature_store/feature_store.py +18 -0
- snowflake/ml/feature_store/feature_view.py +46 -1
- snowflake/ml/jobs/_utils/constants.py +7 -1
- snowflake/ml/jobs/_utils/payload_utils.py +139 -53
- snowflake/ml/jobs/_utils/spec_utils.py +5 -7
- snowflake/ml/jobs/decorators.py +5 -25
- snowflake/ml/jobs/job.py +4 -4
- snowflake/ml/model/_packager/model_env/model_env.py +45 -28
- snowflake/ml/model/_packager/model_handlers/_utils.py +8 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +16 -0
- snowflake/ml/model/_packager/model_handlers/keras.py +230 -0
- snowflake/ml/model/_packager/model_handlers/pytorch.py +1 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +28 -3
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +74 -21
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +27 -49
- snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2023_12_01.py +48 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +3 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -1
- snowflake/ml/model/_packager/model_task/model_task_utils.py +5 -1
- snowflake/ml/model/_signatures/core.py +2 -2
- snowflake/ml/model/_signatures/numpy_handler.py +5 -5
- snowflake/ml/model/_signatures/pandas_handler.py +9 -7
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/model_signature.py +8 -0
- snowflake/ml/model/type_hints.py +15 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +14 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +2 -2
- snowflake/ml/registry/registry.py +34 -4
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.7.5.dist-info}/METADATA +58 -25
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.7.5.dist-info}/RECORD +41 -38
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.7.5.dist-info}/WHEEL +1 -1
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.7.5.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.7.5.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/_utils/payload_utils.py
CHANGED
@@ -1,5 +1,8 @@
+import functools
 import inspect
 import io
+import itertools
+import pickle
 import sys
 import textwrap
 from pathlib import Path, PurePath
@@ -19,9 +22,11 @@ import cloudpickle as cp
 
 from snowflake import snowpark
 from snowflake.ml.jobs._utils import constants, types
+from snowflake.snowpark import exceptions as sp_exceptions
 from snowflake.snowpark._internal import code_generation
 
 _SUPPORTED_ARG_TYPES = {str, int, float}
+_SUPPORTED_ENTRYPOINT_EXTENSIONS = {".py"}
 _STARTUP_SCRIPT_PATH = PurePath("startup.sh")
 _STARTUP_SCRIPT_CODE = textwrap.dedent(
     f"""
@@ -69,12 +74,11 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
     shm_size=$(df --output=size --block-size=1 /dev/shm | tail -n 1)
 
     # Configure IP address and logging directory
-    eth0Ip=$(ifconfig eth0 | sed -En -e 's/.*inet ([0-9.]+).*/\1/p')
+    eth0Ip=$(ifconfig eth0 2>/dev/null | sed -En -e 's/.*inet ([0-9.]+).*/\1/p')
     log_dir="/tmp/ray"
 
-    # Check if eth0Ip is
-    if [
-    # This should never happen, but just in case ethOIp is not set, we should default to localhost
+    # Check if eth0Ip is a valid IP address and fall back to default if necessary
+    if [[ ! $eth0Ip =~ ^[0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+$ ]]; then
         eth0Ip="127.0.0.1"
     fi
 
@@ -120,6 +124,34 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
 ).strip()
 
 
+def _resolve_entrypoint(parent: Path, entrypoint: Optional[Path]) -> Path:
+    parent = parent.absolute()
+    if entrypoint is None:
+        if parent.is_file():
+            # Infer entrypoint from source
+            entrypoint = parent
+        else:
+            raise ValueError("entrypoint must be provided when source is a directory")
+    elif entrypoint.is_absolute():
+        # Absolute path - validate it's a subpath of source dir
+        if not entrypoint.is_relative_to(parent):
+            raise ValueError(f"Entrypoint must be a subpath of {parent}, got: {entrypoint})")
+    else:
+        # Relative path
+        if (abs_entrypoint := entrypoint.absolute()).is_relative_to(parent) and abs_entrypoint.is_file():
+            # Relative to working dir iff path is relative to source dir and exists
+            entrypoint = abs_entrypoint
+        else:
+            # Relative to source dir
+            entrypoint = parent.joinpath(entrypoint)
+    if not entrypoint.is_file():
+        raise FileNotFoundError(
+            "Entrypoint not found. Ensure the entrypoint is a valid file and is under"
+            f" the source directory (source={parent}, entrypoint={entrypoint})"
+        )
+    return entrypoint
+
+
 class JobPayload:
     def __init__(
         self,
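The new _resolve_entrypoint helper pins down how entrypoints are matched against the source directory: inferred when the source is a single file, required to be a subpath when absolute, and tried against the working directory before the source directory when relative. A minimal sketch of the same pathlib rules (throwaway directory; not the internal API):

    from pathlib import Path
    import tempfile

    with tempfile.TemporaryDirectory() as d:
        source = Path(d)
        script = source / "train.py"
        script.write_text("print('hello')\n")

        # Relative entrypoints resolve against the source directory.
        assert source.joinpath("train.py").is_file()

        # Absolute entrypoints must stay inside the source directory.
        assert script.absolute().is_relative_to(source.absolute())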
@@ -138,23 +170,23 @@ class JobPayload:
             # since we will generate the file from the serialized callable
             pass
         elif isinstance(self.source, Path):
-            # Validate
-
-
-
-
-
-
-
-            if not
-
-
-
-
-
-
-
+            # Validate source
+            source = self.source
+            if not source.exists():
+                raise FileNotFoundError(f"{source} does not exist")
+            source = source.absolute()
+
+            # Validate entrypoint
+            entrypoint = _resolve_entrypoint(source, self.entrypoint)
+            if entrypoint.suffix not in _SUPPORTED_ENTRYPOINT_EXTENSIONS:
+                raise ValueError(
+                    "Unsupported entrypoint type:"
+                    f" supported={','.join(_SUPPORTED_ENTRYPOINT_EXTENSIONS)} got={entrypoint.suffix}"
+                )
+
+            # Update fields with normalized values
+            self.source = source
+            self.entrypoint = entrypoint
         else:
             raise ValueError("Unsupported source type. Source must be a file, directory, or callable.")
 
@@ -168,12 +200,16 @@ class JobPayload:
         entrypoint = self.entrypoint or Path(constants.DEFAULT_ENTRYPOINT_PATH)
 
         # Create stage if necessary
-        stage_name = stage_path.parts[0]
-
-
-        "
-
-
+        stage_name = stage_path.parts[0].lstrip("@")
+        # Explicitly check if stage exists first since we may not have CREATE STAGE privilege
+        try:
+            session.sql(f"describe stage {stage_name}").collect()
+        except sp_exceptions.SnowparkSQLException:
+            session.sql(
+                f"create stage if not exists {stage_name}"
+                " encryption = ( type = 'SNOWFLAKE_SSE' )"
+                " comment = 'Created by snowflake.ml.jobs Python API'"
+            ).collect()
 
         # Upload payload to stage
         if not isinstance(source, Path):
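Stage bootstrap now probes with DESCRIBE before attempting CREATE, so callers holding only USAGE on an existing stage no longer need the CREATE STAGE privilege. A hedged sketch of the same pattern (stage name is a placeholder; session is an open snowflake.snowpark.Session):

    from snowflake.snowpark import exceptions as sp_exceptions

    def ensure_stage(session, stage_name: str = "MY_PAYLOAD_STAGE") -> None:
        # DESCRIBE succeeds if the stage exists and is accessible; only on
        # failure do we fall back to creating it.
        try:
            session.sql(f"describe stage {stage_name}").collect()
        except sp_exceptions.SnowparkSQLException:
            session.sql(f"create stage if not exists {stage_name}").collect()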
@@ -237,7 +273,7 @@ class JobPayload:
         )
 
 
-def get_parameter_type(param: inspect.Parameter) -> Optional[Type[object]]:
+def _get_parameter_type(param: inspect.Parameter) -> Optional[Type[object]]:
     # Unwrap Optional type annotations
     param_type = param.annotation
     if get_origin(param_type) is Union and len(get_args(param_type)) == 2 and type(None) in get_args(param_type):
@@ -249,7 +285,7 @@ def get_parameter_type(param: inspect.Parameter) -> Optional[Type[object]]:
     return cast(Type[object], param_type)
 
 
-def validate_parameter_type(param_type: Type[object], param_name: str) -> None:
+def _validate_parameter_type(param_type: Type[object], param_name: str) -> None:
     # Validate param_type is a supported type
     if param_type not in _SUPPORTED_ARG_TYPES:
         raise ValueError(
@@ -258,41 +294,60 @@ def validate_parameter_type(param_type: Type[object], param_name: str) -> None:
         )
 
 
-def
-
-    if any(
-        p.kind in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
-        for p in signature.parameters.values()
-    ):
-        raise NotImplementedError("Function must not have unpacking arguments (* or **)")
-
-    # Mirrored from Snowpark generate_python_code() function
-    # https://github.com/snowflakedb/snowpark-python/blob/main/src/snowflake/snowpark/_internal/udf_utils.py
+def _generate_source_code_comment(func: Callable[..., Any]) -> str:
+    """Generate a comment string containing the source code of a function for readability."""
     try:
-
-
-
+        if isinstance(func, functools.partial):
+            # Unwrap functools.partial and generate source code comment from the original function
+            comment = code_generation.generate_source_code(func.func)  # type: ignore[arg-type]
+            args = itertools.chain((repr(a) for a in func.args), (f"{k}={v!r}" for k, v in func.keywords.items()))
+
+            # Update invocation comment to show arguments passed via functools.partial
+            comment = comment.replace(
+                f"= {func.func.__name__}",
+                "= functools.partial({}({}))".format(
+                    func.func.__name__,
+                    ", ".join(args),
+                ),
+            )
+            return comment
+        else:
+            return code_generation.generate_source_code(func)  # type: ignore[arg-type]
     except Exception as exc:
         error_msg = f"Source code comment could not be generated for {func} due to error {exc}."
-
+        return code_generation.comment_source_code(error_msg)
 
-    func_name = "func"
-    func_code = f"""
-{source_code_comment}
-
-import pickle
-{func_name} = pickle.loads(bytes.fromhex('{cp.dumps(func).hex()}'))
-"""
 
+def _serialize_callable(func: Callable[..., Any]) -> bytes:
+    try:
+        func_bytes: bytes = cp.dumps(func)
+        return func_bytes
+    except pickle.PicklingError as e:
+        if isinstance(func, functools.partial):
+            # Try to find which part of the partial isn't serializable for better debuggability
+            objects = [
+                ("function", func.func),
+                *((f"positional arg {i}", a) for i, a in enumerate(func.args)),
+                *((f"keyword arg '{k}'", v) for k, v in func.keywords.items()),
+            ]
+            for name, obj in objects:
+                try:
+                    cp.dumps(obj)
+                except pickle.PicklingError:
+                    raise ValueError(f"Unable to serialize {name}: {obj}") from e
+        raise ValueError(f"Unable to serialize function: {func}") from e
+
+
+def _generate_param_handler_code(signature: inspect.Signature, output_name: str = "kwargs") -> str:
     # Generate argparse logic for argument handling (type coercion, default values, etc)
     argparse_code = ["import argparse", "", "parser = argparse.ArgumentParser()"]
     argparse_postproc = []
     for name, param in signature.parameters.items():
         opts = {}
 
-        param_type = get_parameter_type(param)
+        param_type = _get_parameter_type(param)
         if param_type is not None:
-            validate_parameter_type(param_type, name)
+            _validate_parameter_type(param_type, name)
             opts["type"] = param_type.__name__
 
         if param.default != inspect.Parameter.empty:
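generate_python_code embeds the cloudpickled callable as a hex literal, so the generated entrypoint needs nothing beyond the standard pickle module to rehydrate it. A toy round trip of that encoding (the add function is illustrative):

    import pickle

    import cloudpickle as cp

    def add(a: int, b: int) -> int:
        return a + b

    # Hex-encoding makes the pickled bytes safe to splice into generated source.
    payload_hex = cp.dumps(add).hex()
    restored = pickle.loads(bytes.fromhex(payload_hex))
    assert restored(2, 3) == 5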
@@ -324,6 +379,37 @@ import pickle
     )
     argparse_code.append("args = parser.parse_args()")
     param_code = "\n".join(argparse_code + argparse_postproc)
+    param_code += f"\n{output_name} = vars(args)"
+
+    return param_code
+
+
+def generate_python_code(func: Callable[..., Any], source_code_display: bool = False) -> str:
+    """Generate an entrypoint script from a Python function."""
+    signature = inspect.signature(func)
+    if any(
+        p.kind in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
+        for p in signature.parameters.values()
+    ):
+        raise NotImplementedError("Function must not have unpacking arguments (* or **)")
+
+    # Mirrored from Snowpark generate_python_code() function
+    # https://github.com/snowflakedb/snowpark-python/blob/main/src/snowflake/snowpark/_internal/udf_utils.py
+    source_code_comment = _generate_source_code_comment(func) if source_code_display else ""
+
+    func_name = "func"
+    func_code = f"""
+{source_code_comment}
+
+import pickle
+{func_name} = pickle.loads(bytes.fromhex('{_serialize_callable(func).hex()}'))
+"""
+
+    arg_dict_name = "kwargs"
+    if getattr(func, constants.IS_MLJOB_REMOTE_ATTR, None):
+        param_code = f"{arg_dict_name} = {{}}"
+    else:
+        param_code = _generate_param_handler_code(signature, arg_dict_name)
 
     return f"""
 ### Version guard to check compatibility across Python versions ###
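For a function such as def train(x: int, lr: float = 0.1), the emitted parameter handler is roughly the argparse scaffolding below (paraphrased sketch; the exact flags and coercions the generator emits may differ):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--x", type=int)
    parser.add_argument("--lr", type=float, default=0.1)
    args = parser.parse_args(["--x", "2"])

    # Mirrors the new `{output_name} = vars(args)` line appended to param_code.
    kwargs = vars(args)
    assert kwargs == {"x": 2, "lr": 0.1}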
@@ -348,5 +434,5 @@ if sys.version_info.major != {sys.version_info.major} or sys.version_info.minor != {sys.version_info.minor}:
 if __name__ == '__main__':
 {textwrap.indent(param_code, '    ')}
 
-    {func_name}(**
+    {func_name}(**{arg_dict_name})
 """
snowflake/ml/jobs/_utils/spec_utils.py
CHANGED
@@ -141,37 +141,35 @@ def generate_service_spec(
     )
 
     # Mount 30% of memory limit as a memory-backed volume
-    memory_volume_name = "dshm"
     memory_volume_size = min(
         ceil(image_spec.resource_limits.memory * constants.MEMORY_VOLUME_SIZE),
         image_spec.resource_requests.memory,
     )
     volume_mounts.append(
         {
-            "name": memory_volume_name,
+            "name": constants.MEMORY_VOLUME_NAME,
             "mountPath": "/dev/shm",
         }
     )
     volumes.append(
         {
-            "name": memory_volume_name,
+            "name": constants.MEMORY_VOLUME_NAME,
             "source": "memory",
             "size": f"{memory_volume_size}Gi",
         }
     )
 
     # Mount payload as volume
-    stage_mount = PurePath(
-    stage_volume_name = "stage-volume"
+    stage_mount = PurePath(constants.STAGE_VOLUME_MOUNT_PATH)
     volume_mounts.append(
         {
-            "name": stage_volume_name,
+            "name": constants.STAGE_VOLUME_NAME,
             "mountPath": stage_mount.as_posix(),
         }
     )
     volumes.append(
         {
-            "name": stage_volume_name,
+            "name": constants.STAGE_VOLUME_NAME,
             "source": payload.stage_path.as_posix(),
         }
     )
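The /dev/shm sizing rule is unchanged: mount the smaller of 30% of the memory limit and the full memory request. Illustrative arithmetic, assuming constants.MEMORY_VOLUME_SIZE is 0.3 (the constant's value is not shown in this diff):

    from math import ceil

    MEMORY_VOLUME_SIZE = 0.3  # assumed value of constants.MEMORY_VOLUME_SIZE
    limit_gi, request_gi = 64, 32
    size_gi = min(ceil(limit_gi * MEMORY_VOLUME_SIZE), request_gi)
    print(f"{size_gi}Gi")  # -> 20Gi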
snowflake/ml/jobs/decorators.py
CHANGED
@@ -1,6 +1,5 @@
 import copy
 import functools
-import inspect
 from typing import Callable, Dict, List, Optional, TypeVar
 
 from typing_extensions import ParamSpec
@@ -8,7 +7,7 @@ from typing_extensions import ParamSpec
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml.jobs import job as jb, manager as jm
-from snowflake.ml.jobs._utils import payload_utils
+from snowflake.ml.jobs._utils import constants
 
 _PROJECT = "MLJob"
 
@@ -50,31 +49,12 @@ def remote(
         wrapped_func = copy.copy(func)
         wrapped_func.__code__ = wrapped_func.__code__.replace(co_firstlineno=func.__code__.co_firstlineno + 1)
 
-        # Validate function arguments based on signature
-        signature = inspect.signature(func)
-        pos_arg_names = []
-        for name, param in signature.parameters.items():
-            param_type = payload_utils.get_parameter_type(param)
-            if param_type is not None:
-                payload_utils.validate_parameter_type(param_type, name)
-            if param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
-                pos_arg_names.append(name)
-
         @functools.wraps(func)
         def wrapper(*args: _Args.args, **kwargs: _Args.kwargs) -> jb.MLJob:
-
-
-                arg_name = pos_arg_names[i] if i < len(pos_arg_names) else f"args[{i}]"
-                payload_utils.validate_parameter_type(type(arg), arg_name)
-
-            # Validate keyword args
-            for k, v in kwargs.items():
-                payload_utils.validate_parameter_type(type(v), k)
-
-            arg_list = [str(v) for v in args] + [x for k, v in kwargs.items() for x in (f"--{k}", str(v))]
+            payload = functools.partial(func, *args, **kwargs)
+            setattr(payload, constants.IS_MLJOB_REMOTE_ATTR, True)
             job = jm._submit_job(
-                source=
-                args=arg_list,
+                source=payload,
                 stage_name=stage_name,
                 compute_pool=compute_pool,
                 pip_requirements=pip_requirements,
@@ -83,7 +63,7 @@ def remote(
                 env_vars=env_vars,
                 session=session,
             )
-            assert isinstance(job, jb.MLJob)
+            assert isinstance(job, jb.MLJob), f"Unexpected job type: {type(job)}"
             return job
 
         return wrapper
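With signature validation gone from the decorator, arguments now travel inside a pickled functools.partial instead of a --flag value string list. A hedged usage sketch (pool, stage, and table names are placeholders):

    from snowflake.ml.jobs import remote

    @remote(compute_pool="MY_COMPUTE_POOL", stage_name="payload_stage")
    def train(data_table: str, epochs: int = 10) -> None:
        print(f"training on {data_table} for {epochs} epochs")

    # Calling the wrapped function submits an MLJob; args/kwargs ride along in
    # the serialized partial rather than being re-parsed from argv.
    job = train("MY_DB.MY_SCHEMA.TRAINING_DATA", epochs=5)
    print(job.status)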
snowflake/ml/jobs/job.py
CHANGED
@@ -4,7 +4,7 @@ from typing import Any, List, Optional, cast
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml.jobs._utils import constants, types
-from snowflake.snowpark
+from snowflake.snowpark import context as sp_context
 
 _PROJECT = "MLJob"
 TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "INTERNAL_ERROR"}
@@ -13,7 +13,7 @@ TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "INTERNAL_ERROR"}
 class MLJob:
     def __init__(self, id: str, session: Optional[snowpark.Session] = None) -> None:
         self._id = id
-        self._session = session or get_active_session()
+        self._session = session or sp_context.get_active_session()
         self._status: types.JOB_STATUS = "PENDING"
 
     @property
@@ -79,7 +79,7 @@ class MLJob:
         return self.status
 
 
-@telemetry.send_api_usage_telemetry(project=_PROJECT)
+@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
 def _get_status(session: snowpark.Session, job_id: str) -> types.JOB_STATUS:
     """Retrieve job execution status."""
     # TODO: snowflake-snowpark-python<1.24.0 shows spurious error messages on
@@ -90,7 +90,7 @@ def _get_status(session: snowpark.Session, job_id: str) -> types.JOB_STATUS:
     return cast(types.JOB_STATUS, row["status"])
 
 
-@telemetry.send_api_usage_telemetry(project=_PROJECT)
+@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit"])
 def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1) -> str:
     """
     Retrieve the job's execution logs.
snowflake/ml/model/_packager/model_env/model_env.py
CHANGED
@@ -113,7 +113,33 @@ class ModelEnv:
         self._snowpark_ml_version = version.parse(snowpark_ml_version)
 
     def include_if_absent(self, pkgs: List[ModelDependency], check_local_version: bool = False) -> None:
-        """Append requirements into model env if absent.
+        """Append requirements into model env if absent. Depending on the environment, requirements may be added
+        to either the pip requirements or conda dependencies.
+
+        Args:
+            pkgs: A list of ModelDependency namedtuple to be appended.
+            check_local_version: Flag to indicate if it is required to pin to local version. Defaults to False.
+        """
+        if self.pip_requirements and not self.conda_dependencies and pkgs:
+            pip_pkg_reqs: List[str] = []
+            warnings.warn(
+                (
+                    "Dependencies specified from pip requirements."
+                    " This may prevent model deploying to Snowflake Warehouse."
+                ),
+                category=UserWarning,
+                stacklevel=2,
+            )
+            for conda_req_str, pip_name in pkgs:
+                _, conda_req = env_utils._validate_conda_dependency_string(conda_req_str)
+                pip_req = requirements.Requirement(f"{pip_name}{conda_req.specifier}")
+                pip_pkg_reqs.append(str(pip_req))
+            self._include_if_absent_pip(pip_pkg_reqs, check_local_version)
+        else:
+            self._include_if_absent_conda(pkgs, check_local_version)
+
+    def _include_if_absent_conda(self, pkgs: List[ModelDependency], check_local_version: bool = False) -> None:
+        """Append requirements into model env conda dependencies if absent.
 
         Args:
             pkgs: A list of ModelDependency namedtuple to be appended.
@@ -134,8 +160,8 @@ class ModelEnv:
             if show_warning_message:
                 warnings.warn(
                     (
-                        f"Basic dependency {req_to_add.name} specified from
-
+                        f"Basic dependency {req_to_add.name} specified from pip requirements."
+                        " This may prevent model deploying to Snowflake Warehouse."
                     ),
                     category=UserWarning,
                     stacklevel=2,
@@ -157,11 +183,11 @@ class ModelEnv:
                 stacklevel=2,
             )
 
-    def
-        """Append pip requirements into model env if absent.
+    def _include_if_absent_pip(self, pkgs: List[str], check_local_version: bool = False) -> None:
+        """Append pip requirements into model env pip requirements if absent.
 
         Args:
-            pkgs: A list of
+            pkgs: A list of strings to be appended to pip environment.
             check_local_version: Flag to indicate if it is required to pin to local version. Defaults to False.
         """
 
@@ -187,25 +213,6 @@ class ModelEnv:
             self._conda_dependencies[channel].remove(spec)
 
     def generate_env_for_cuda(self) -> None:
-        if self.cuda_version is None:
-            return
-
-        cuda_spec = env_utils.find_dep_spec(
-            self._conda_dependencies, self._pip_requirements, conda_pkg_name="cuda", remove_spec=False
-        )
-        if cuda_spec and not cuda_spec.specifier.contains(self.cuda_version):
-            raise ValueError(
-                "The CUDA requirement you specified in your conda dependencies or pip requirements is"
-                " conflicting with CUDA version required. Please do not specify CUDA dependency using conda"
-                " dependencies or pip requirements."
-            )
-
-        if not cuda_spec:
-            self.include_if_absent(
-                [ModelDependency(requirement=f"nvidia::cuda=={self.cuda_version}.*", pip_name="cuda")],
-                check_local_version=False,
-            )
-
         xgboost_spec = env_utils.find_dep_spec(
             self._conda_dependencies, self._pip_requirements, conda_pkg_name="xgboost", remove_spec=True
         )
@@ -236,7 +243,7 @@ class ModelEnv:
             check_local_version=False,
         )
 
-        self.
+        self._include_if_absent_pip(["bitsandbytes>=0.41.0"], check_local_version=False)
 
     def relax_version(self) -> None:
         """Relax the version requirements for both conda dependencies and pip requirements.
@@ -252,7 +259,9 @@ class ModelEnv:
         self._pip_requirements = list(map(env_utils.relax_requirement_version, self._pip_requirements))
 
     def load_from_conda_file(self, conda_env_path: pathlib.Path) -> None:
-        conda_dependencies_dict, pip_requirements_list, python_version = env_utils.load_conda_env_file(
+        conda_dependencies_dict, pip_requirements_list, python_version, cuda_version = env_utils.load_conda_env_file(
+            conda_env_path
+        )
 
         for channel, channel_dependencies in conda_dependencies_dict.items():
             if channel != env_utils.DEFAULT_CHANNEL_NAME:
@@ -310,6 +319,9 @@ class ModelEnv:
         if python_version:
             self.python_version = python_version
 
+        if cuda_version:
+            self.cuda_version = cuda_version
+
     def load_from_pip_file(self, pip_requirements_path: pathlib.Path) -> None:
         pip_requirements_list = env_utils.load_requirements_file(pip_requirements_path)
 
@@ -342,12 +354,17 @@ class ModelEnv:
         self.snowpark_ml_version = env_dict["snowpark_ml_version"]
 
     def save_as_dict(
-        self,
+        self,
+        base_dir: pathlib.Path,
+        default_channel_override: str = env_utils.SNOWFLAKE_CONDA_CHANNEL_URL,
+        is_gpu: Optional[bool] = False,
     ) -> model_meta_schema.ModelEnvDict:
+        cuda_version = self.cuda_version if is_gpu else None
         env_utils.save_conda_env_file(
             pathlib.Path(base_dir / self.conda_env_rel_path),
             self._conda_dependencies,
             self.python_version,
+            cuda_version,
             default_channel_override=default_channel_override,
         )
         env_utils.save_requirements_file(
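The pip fallback in include_if_absent reattaches each conda spec's version constraint to the pip distribution name. The same remapping with packaging directly (package names illustrative; internally the conda string is parsed via env_utils._validate_conda_dependency_string):

    from packaging.requirements import Requirement

    conda_req = Requirement("pytorch==2.1.*")  # conda package name + specifier
    pip_name = "torch"                         # ModelDependency.pip_name
    pip_req = Requirement(f"{pip_name}{conda_req.specifier}")
    print(pip_req)  # torch==2.1.*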
snowflake/ml/model/_packager/model_handlers/_utils.py
CHANGED
@@ -39,7 +39,7 @@ def _is_callable(model: model_types.SupportedModelType, method_name: str) -> bool:
 
 
 def get_truncated_sample_data(
-    sample_input_data: model_types.SupportedDataType, length: int = 100
+    sample_input_data: model_types.SupportedDataType, length: int = 100, is_for_modeling_model: bool = False
 ) -> model_types.SupportedLocalDataType:
     trunc_sample_input = model_signature._truncate_data(sample_input_data, length=length)
     local_sample_input: model_types.SupportedLocalDataType = None
@@ -47,6 +47,8 @@ def get_truncated_sample_data(
         # Added because of Any from missing stubs.
         trunc_sample_input = cast(SnowparkDataFrame, trunc_sample_input)
         local_sample_input = snowpark_handler.SnowparkDataFrameHandler.convert_to_df(trunc_sample_input)
+        if is_for_modeling_model:
+            local_sample_input.columns = trunc_sample_input.columns
     else:
         local_sample_input = trunc_sample_input
     return local_sample_input
@@ -58,13 +60,15 @@ def validate_signature(
     target_methods: Iterable[str],
     sample_input_data: Optional[model_types.SupportedDataType],
     get_prediction_fn: Callable[[str, model_types.SupportedLocalDataType], model_types.SupportedLocalDataType],
+    is_for_modeling_model: bool = False,
 ) -> model_meta.ModelMetadata:
     if model_meta.signatures:
         validate_target_methods(model, list(model_meta.signatures.keys()))
         if sample_input_data is not None:
-            local_sample_input = get_truncated_sample_data(
+            local_sample_input = get_truncated_sample_data(
+                sample_input_data, is_for_modeling_model=is_for_modeling_model
+            )
         for target_method in model_meta.signatures.keys():
-
             model_signature_inst = model_meta.signatures.get(target_method)
             if model_signature_inst is not None:
                 # strict validation the input signature
@@ -77,7 +81,7 @@ def validate_signature(
         assert (
             sample_input_data is not None
         ), "Model signature and sample input are None at the same time. This should not happen with local model."
-        local_sample_input = get_truncated_sample_data(sample_input_data)
+        local_sample_input = get_truncated_sample_data(sample_input_data, is_for_modeling_model=is_for_modeling_model)
        for target_method in target_methods:
             predictions_df = get_prediction_fn(target_method, local_sample_input)
             sig = model_signature.infer_signature(
snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py
CHANGED
@@ -146,6 +146,10 @@ class HuggingFacePipelineHandler(
         framework = getattr(model, "framework", None)
         batch_size = getattr(model, "batch_size", None)
 
+        has_tokenizer = getattr(model, "tokenizer", None) is not None
+        has_feature_extractor = getattr(model, "feature_extractor", None) is not None
+        has_image_preprocessor = getattr(model, "image_preprocessor", None) is not None
+
         if type_utils.LazyType("transformers.Pipeline").isinstance(model):
             params = {
                 **model._preprocess_params,  # type:ignore[attr-defined]
@@ -234,6 +238,9 @@ class HuggingFacePipelineHandler(
                 {
                     "task": task,
                     "batch_size": batch_size if batch_size is not None else 1,
+                    "has_tokenizer": has_tokenizer,
+                    "has_feature_extractor": has_feature_extractor,
+                    "has_image_preprocessor": has_image_preprocessor,
                 }
             ),
         )
@@ -308,6 +315,14 @@ class HuggingFacePipelineHandler(
         if os.path.isdir(model_blob_file_or_dir_path):
             import transformers
 
+            additional_pipeline_params = {}
+            if model_blob_options.get("has_tokenizer", False):
+                additional_pipeline_params["tokenizer"] = model_blob_file_or_dir_path
+            if model_blob_options.get("has_feature_extractor", False):
+                additional_pipeline_params["feature_extractor"] = model_blob_file_or_dir_path
+            if model_blob_options.get("has_image_preprocessor", False):
+                additional_pipeline_params["image_preprocessor"] = model_blob_file_or_dir_path
+
             with open(
                 os.path.join(
                     model_blob_file_or_dir_path,
@@ -324,6 +339,7 @@ class HuggingFacePipelineHandler(
                 model=model_blob_file_or_dir_path,
                 trust_remote_code=True,
                 torch_dtype="auto",
+                **additional_pipeline_params,
                 **device_config,
             )
 
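On load, each component recorded at save time is pointed back at the artifact directory when the pipeline is rebuilt. A sketch of the resulting transformers.pipeline call (path and task are placeholders; only flags recorded as true are forwarded):

    import transformers

    model_dir = "/tmp/model_blob"  # placeholder artifact directory
    additional_pipeline_params = {"tokenizer": model_dir}  # when has_tokenizer was saved

    pipe = transformers.pipeline(
        task="text-classification",
        model=model_dir,
        trust_remote_code=True,
        **additional_pipeline_params,
    )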