snowflake-ml-python 1.7.4__py3-none-any.whl → 1.7.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. snowflake/ml/_internal/env_utils.py +64 -21
  2. snowflake/ml/_internal/relax_version_strategy.py +16 -0
  3. snowflake/ml/_internal/telemetry.py +21 -0
  4. snowflake/ml/data/_internal/arrow_ingestor.py +1 -1
  5. snowflake/ml/feature_store/feature_store.py +18 -0
  6. snowflake/ml/feature_store/feature_view.py +46 -1
  7. snowflake/ml/jobs/_utils/constants.py +7 -1
  8. snowflake/ml/jobs/_utils/payload_utils.py +139 -53
  9. snowflake/ml/jobs/_utils/spec_utils.py +5 -7
  10. snowflake/ml/jobs/decorators.py +5 -25
  11. snowflake/ml/jobs/job.py +4 -4
  12. snowflake/ml/model/_packager/model_env/model_env.py +45 -28
  13. snowflake/ml/model/_packager/model_handlers/_utils.py +8 -4
  14. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +16 -0
  15. snowflake/ml/model/_packager/model_handlers/keras.py +230 -0
  16. snowflake/ml/model/_packager/model_handlers/pytorch.py +1 -0
  17. snowflake/ml/model/_packager/model_handlers/sklearn.py +28 -3
  18. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +74 -21
  19. snowflake/ml/model/_packager/model_handlers/tensorflow.py +27 -49
  20. snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2023_12_01.py +48 -0
  21. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -1
  22. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +3 -0
  23. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  24. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -1
  25. snowflake/ml/model/_packager/model_task/model_task_utils.py +5 -1
  26. snowflake/ml/model/_signatures/core.py +2 -2
  27. snowflake/ml/model/_signatures/numpy_handler.py +5 -5
  28. snowflake/ml/model/_signatures/pandas_handler.py +9 -7
  29. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  30. snowflake/ml/model/model_signature.py +8 -0
  31. snowflake/ml/model/type_hints.py +15 -0
  32. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +14 -1
  33. snowflake/ml/modeling/pipeline/pipeline.py +18 -1
  34. snowflake/ml/modeling/preprocessing/polynomial_features.py +2 -2
  35. snowflake/ml/registry/registry.py +34 -4
  36. snowflake/ml/version.py +1 -1
  37. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.7.5.dist-info}/METADATA +58 -25
  38. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.7.5.dist-info}/RECORD +41 -38
  39. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.7.5.dist-info}/WHEEL +1 -1
  40. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.7.5.dist-info}/LICENSE.txt +0 -0
  41. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.7.5.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/_utils/payload_utils.py CHANGED
@@ -1,5 +1,8 @@
+ import functools
  import inspect
  import io
+ import itertools
+ import pickle
  import sys
  import textwrap
  from pathlib import Path, PurePath
@@ -19,9 +22,11 @@ import cloudpickle as cp

  from snowflake import snowpark
  from snowflake.ml.jobs._utils import constants, types
+ from snowflake.snowpark import exceptions as sp_exceptions
  from snowflake.snowpark._internal import code_generation

  _SUPPORTED_ARG_TYPES = {str, int, float}
+ _SUPPORTED_ENTRYPOINT_EXTENSIONS = {".py"}
  _STARTUP_SCRIPT_PATH = PurePath("startup.sh")
  _STARTUP_SCRIPT_CODE = textwrap.dedent(
      f"""
@@ -69,12 +74,11 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
  shm_size=$(df --output=size --block-size=1 /dev/shm | tail -n 1)

  # Configure IP address and logging directory
- eth0Ip=$(ifconfig eth0 | sed -En -e 's/.*inet ([0-9.]+).*/\1/p')
+ eth0Ip=$(ifconfig eth0 2>/dev/null | sed -En -e 's/.*inet ([0-9.]+).*/\1/p')
  log_dir="/tmp/ray"

- # Check if eth0Ip is empty and set default if necessary
- if [ -z "$eth0Ip" ]; then
- # This should never happen, but just in case ethOIp is not set, we should default to localhost
+ # Check if eth0Ip is a valid IP address and fall back to default if necessary
+ if [[ ! $eth0Ip =~ ^[0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+$ ]]; then
  eth0Ip="127.0.0.1"
  fi

@@ -120,6 +124,34 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
  ).strip()


+ def _resolve_entrypoint(parent: Path, entrypoint: Optional[Path]) -> Path:
+     parent = parent.absolute()
+     if entrypoint is None:
+         if parent.is_file():
+             # Infer entrypoint from source
+             entrypoint = parent
+         else:
+             raise ValueError("entrypoint must be provided when source is a directory")
+     elif entrypoint.is_absolute():
+         # Absolute path - validate it's a subpath of source dir
+         if not entrypoint.is_relative_to(parent):
+             raise ValueError(f"Entrypoint must be a subpath of {parent}, got: {entrypoint})")
+     else:
+         # Relative path
+         if (abs_entrypoint := entrypoint.absolute()).is_relative_to(parent) and abs_entrypoint.is_file():
+             # Relative to working dir iff path is relative to source dir and exists
+             entrypoint = abs_entrypoint
+         else:
+             # Relative to source dir
+             entrypoint = parent.joinpath(entrypoint)
+     if not entrypoint.is_file():
+         raise FileNotFoundError(
+             "Entrypoint not found. Ensure the entrypoint is a valid file and is under"
+             f" the source directory (source={parent}, entrypoint={entrypoint})"
+         )
+     return entrypoint
+
+
  class JobPayload:
      def __init__(
          self,
@@ -138,23 +170,23 @@ class JobPayload:
              # since we will generate the file from the serialized callable
              pass
          elif isinstance(self.source, Path):
-             # Validate self.source and self.entrypoint for files
-             if not self.source.exists():
-                 raise FileNotFoundError(f"{self.source} does not exist")
-             if self.entrypoint is None:
-                 if self.source.is_file():
-                     self.entrypoint = self.source
-                 else:
-                     raise ValueError("entrypoint must be provided when source is a directory")
-             if not self.entrypoint.is_file():
-                 # Check if self.entrypoint is a valid relative path
-                 self.entrypoint = self.source.joinpath(self.entrypoint)
-                 if not self.entrypoint.is_file():
-                     raise FileNotFoundError(f"File {self.entrypoint} does not exist")
-             if not self.entrypoint.is_relative_to(self.source):
-                 raise ValueError(f"{self.entrypoint} must be a subpath of {self.source}")
-             if self.entrypoint.suffix != ".py":
-                 raise NotImplementedError("Only Python entrypoints are supported currently")
+             # Validate source
+             source = self.source
+             if not source.exists():
+                 raise FileNotFoundError(f"{source} does not exist")
+             source = source.absolute()
+
+             # Validate entrypoint
+             entrypoint = _resolve_entrypoint(source, self.entrypoint)
+             if entrypoint.suffix not in _SUPPORTED_ENTRYPOINT_EXTENSIONS:
+                 raise ValueError(
+                     "Unsupported entrypoint type:"
+                     f" supported={','.join(_SUPPORTED_ENTRYPOINT_EXTENSIONS)} got={entrypoint.suffix}"
+                 )
+
+             # Update fields with normalized values
+             self.source = source
+             self.entrypoint = entrypoint
          else:
              raise ValueError("Unsupported source type. Source must be a file, directory, or callable.")

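The net effect of the new entrypoint handling is that relative entrypoints are resolved against the source directory and anything outside it is rejected. A minimal sketch, assuming the internal helper is importable from snowflake.ml.jobs._utils.payload_utils as shown in the diff above:

import tempfile
from pathlib import Path

from snowflake.ml.jobs._utils.payload_utils import _resolve_entrypoint  # internal helper, per the diff above

with tempfile.TemporaryDirectory() as tmp:
    source = Path(tmp)
    (source / "train.py").write_text("print('hello')\n")

    # A relative entrypoint is resolved against the source directory.
    resolved = _resolve_entrypoint(source, Path("train.py"))
    assert resolved == source.absolute() / "train.py"

    # An absolute entrypoint outside the source directory is rejected.
    try:
        _resolve_entrypoint(source, source.parent / "outside.py")
    except ValueError as exc:
        print(exc)  # Entrypoint must be a subpath of ...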
@@ -168,12 +200,16 @@
          entrypoint = self.entrypoint or Path(constants.DEFAULT_ENTRYPOINT_PATH)

          # Create stage if necessary
-         stage_name = stage_path.parts[0]
-         session.sql(
-             f"create stage if not exists {stage_name.lstrip('@')}"
-             " encryption = ( type = 'SNOWFLAKE_SSE' )"
-             " comment = 'Created by snowflake.ml.jobs Python API'"
-         ).collect()
+         stage_name = stage_path.parts[0].lstrip("@")
+         # Explicitly check if stage exists first since we may not have CREATE STAGE privilege
+         try:
+             session.sql(f"describe stage {stage_name}").collect()
+         except sp_exceptions.SnowparkSQLException:
+             session.sql(
+                 f"create stage if not exists {stage_name}"
+                 " encryption = ( type = 'SNOWFLAKE_SSE' )"
+                 " comment = 'Created by snowflake.ml.jobs Python API'"
+             ).collect()

          # Upload payload to stage
          if not isinstance(source, Path):
@@ -237,7 +273,7 @@
          )


- def get_parameter_type(param: inspect.Parameter) -> Optional[Type[object]]:
+ def _get_parameter_type(param: inspect.Parameter) -> Optional[Type[object]]:
      # Unwrap Optional type annotations
      param_type = param.annotation
      if get_origin(param_type) is Union and len(get_args(param_type)) == 2 and type(None) in get_args(param_type):
@@ -249,7 +285,7 @@ def get_parameter_type(param: inspect.Parameter) -> Optional[Type[object]]:
      return cast(Type[object], param_type)


- def validate_parameter_type(param_type: Type[object], param_name: str) -> None:
+ def _validate_parameter_type(param_type: Type[object], param_name: str) -> None:
      # Validate param_type is a supported type
      if param_type not in _SUPPORTED_ARG_TYPES:
          raise ValueError(
@@ -258,41 +294,60 @@ def validate_parameter_type(param_type: Type[object], param_name: str) -> None:
          )


- def generate_python_code(func: Callable[..., Any], source_code_display: bool = False) -> str:
-     signature = inspect.signature(func)
-     if any(
-         p.kind in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
-         for p in signature.parameters.values()
-     ):
-         raise NotImplementedError("Function must not have unpacking arguments (* or **)")
-
-     # Mirrored from Snowpark generate_python_code() function
-     # https://github.com/snowflakedb/snowpark-python/blob/main/src/snowflake/snowpark/_internal/udf_utils.py
+ def _generate_source_code_comment(func: Callable[..., Any]) -> str:
+     """Generate a comment string containing the source code of a function for readability."""
      try:
-         source_code_comment = (
-             code_generation.generate_source_code(func) if source_code_display else "" # type: ignore[arg-type]
-         )
+         if isinstance(func, functools.partial):
+             # Unwrap functools.partial and generate source code comment from the original function
+             comment = code_generation.generate_source_code(func.func) # type: ignore[arg-type]
+             args = itertools.chain((repr(a) for a in func.args), (f"{k}={v!r}" for k, v in func.keywords.items()))
+
+             # Update invocation comment to show arguments passed via functools.partial
+             comment = comment.replace(
+                 f"= {func.func.__name__}",
+                 "= functools.partial({}({}))".format(
+                     func.func.__name__,
+                     ", ".join(args),
+                 ),
+             )
+             return comment
+         else:
+             return code_generation.generate_source_code(func) # type: ignore[arg-type]
      except Exception as exc:
          error_msg = f"Source code comment could not be generated for {func} due to error {exc}."
-         source_code_comment = code_generation.comment_source_code(error_msg)
+         return code_generation.comment_source_code(error_msg)

-     func_name = "func"
-     func_code = f"""
- {source_code_comment}
-
- import pickle
- {func_name} = pickle.loads(bytes.fromhex('{cp.dumps(func).hex()}'))
- """

+ def _serialize_callable(func: Callable[..., Any]) -> bytes:
+     try:
+         func_bytes: bytes = cp.dumps(func)
+         return func_bytes
+     except pickle.PicklingError as e:
+         if isinstance(func, functools.partial):
+             # Try to find which part of the partial isn't serializable for better debuggability
+             objects = [
+                 ("function", func.func),
+                 *((f"positional arg {i}", a) for i, a in enumerate(func.args)),
+                 *((f"keyword arg '{k}'", v) for k, v in func.keywords.items()),
+             ]
+             for name, obj in objects:
+                 try:
+                     cp.dumps(obj)
+                 except pickle.PicklingError:
+                     raise ValueError(f"Unable to serialize {name}: {obj}") from e
+         raise ValueError(f"Unable to serialize function: {func}") from e
+
+
+ def _generate_param_handler_code(signature: inspect.Signature, output_name: str = "kwargs") -> str:
      # Generate argparse logic for argument handling (type coercion, default values, etc)
      argparse_code = ["import argparse", "", "parser = argparse.ArgumentParser()"]
      argparse_postproc = []
      for name, param in signature.parameters.items():
          opts = {}

-         param_type = get_parameter_type(param)
+         param_type = _get_parameter_type(param)
          if param_type is not None:
-             validate_parameter_type(param_type, name)
+             _validate_parameter_type(param_type, name)
              opts["type"] = param_type.__name__

          if param.default != inspect.Parameter.empty:
@@ -324,6 +379,37 @@ import pickle
          )
      argparse_code.append("args = parser.parse_args()")
      param_code = "\n".join(argparse_code + argparse_postproc)
+     param_code += f"\n{output_name} = vars(args)"
+
+     return param_code
+
+
+ def generate_python_code(func: Callable[..., Any], source_code_display: bool = False) -> str:
+     """Generate an entrypoint script from a Python function."""
+     signature = inspect.signature(func)
+     if any(
+         p.kind in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
+         for p in signature.parameters.values()
+     ):
+         raise NotImplementedError("Function must not have unpacking arguments (* or **)")
+
+     # Mirrored from Snowpark generate_python_code() function
+     # https://github.com/snowflakedb/snowpark-python/blob/main/src/snowflake/snowpark/_internal/udf_utils.py
+     source_code_comment = _generate_source_code_comment(func) if source_code_display else ""
+
+     func_name = "func"
+     func_code = f"""
+ {source_code_comment}
+
+ import pickle
+ {func_name} = pickle.loads(bytes.fromhex('{_serialize_callable(func).hex()}'))
+ """
+
+     arg_dict_name = "kwargs"
+     if getattr(func, constants.IS_MLJOB_REMOTE_ATTR, None):
+         param_code = f"{arg_dict_name} = {{}}"
+     else:
+         param_code = _generate_param_handler_code(signature, arg_dict_name)

      return f"""
  ### Version guard to check compatibility across Python versions ###
@@ -348,5 +434,5 @@ if sys.version_info.major != {sys.version_info.major} or sys.version_info.minor
  if __name__ == '__main__':
  {textwrap.indent(param_code, ' ')}

- {func_name}(**vars(args))
+ {func_name}(**{arg_dict_name})
  """
snowflake/ml/jobs/_utils/spec_utils.py CHANGED
@@ -141,37 +141,35 @@ def generate_service_spec(
      )

      # Mount 30% of memory limit as a memory-backed volume
-     memory_volume_name = "dshm"
      memory_volume_size = min(
          ceil(image_spec.resource_limits.memory * constants.MEMORY_VOLUME_SIZE),
          image_spec.resource_requests.memory,
      )
      volume_mounts.append(
          {
-             "name": memory_volume_name,
+             "name": constants.MEMORY_VOLUME_NAME,
              "mountPath": "/dev/shm",
          }
      )
      volumes.append(
          {
-             "name": memory_volume_name,
+             "name": constants.MEMORY_VOLUME_NAME,
              "source": "memory",
              "size": f"{memory_volume_size}Gi",
          }
      )

      # Mount payload as volume
-     stage_mount = PurePath("/opt/app")
-     stage_volume_name = "stage-volume"
+     stage_mount = PurePath(constants.STAGE_VOLUME_MOUNT_PATH)
      volume_mounts.append(
          {
-             "name": stage_volume_name,
+             "name": constants.STAGE_VOLUME_NAME,
              "mountPath": stage_mount.as_posix(),
          }
      )
      volumes.append(
          {
-             "name": stage_volume_name,
+             "name": constants.STAGE_VOLUME_NAME,
              "source": payload.stage_path.as_posix(),
          }
      )
snowflake/ml/jobs/decorators.py CHANGED
@@ -1,6 +1,5 @@
  import copy
  import functools
- import inspect
  from typing import Callable, Dict, List, Optional, TypeVar

  from typing_extensions import ParamSpec
@@ -8,7 +7,7 @@ from typing_extensions import ParamSpec
  from snowflake import snowpark
  from snowflake.ml._internal import telemetry
  from snowflake.ml.jobs import job as jb, manager as jm
- from snowflake.ml.jobs._utils import payload_utils
+ from snowflake.ml.jobs._utils import constants

  _PROJECT = "MLJob"

@@ -50,31 +49,12 @@ def remote(
          wrapped_func = copy.copy(func)
          wrapped_func.__code__ = wrapped_func.__code__.replace(co_firstlineno=func.__code__.co_firstlineno + 1)

-         # Validate function arguments based on signature
-         signature = inspect.signature(func)
-         pos_arg_names = []
-         for name, param in signature.parameters.items():
-             param_type = payload_utils.get_parameter_type(param)
-             if param_type is not None:
-                 payload_utils.validate_parameter_type(param_type, name)
-             if param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
-                 pos_arg_names.append(name)
-
          @functools.wraps(func)
          def wrapper(*args: _Args.args, **kwargs: _Args.kwargs) -> jb.MLJob:
-             # Validate positional args
-             for i, arg in enumerate(args):
-                 arg_name = pos_arg_names[i] if i < len(pos_arg_names) else f"args[{i}]"
-                 payload_utils.validate_parameter_type(type(arg), arg_name)
-
-             # Validate keyword args
-             for k, v in kwargs.items():
-                 payload_utils.validate_parameter_type(type(v), k)
-
-             arg_list = [str(v) for v in args] + [x for k, v in kwargs.items() for x in (f"--{k}", str(v))]
+             payload = functools.partial(func, *args, **kwargs)
+             setattr(payload, constants.IS_MLJOB_REMOTE_ATTR, True)
              job = jm._submit_job(
-                 source=wrapped_func,
-                 args=arg_list,
+                 source=payload,
                  stage_name=stage_name,
                  compute_pool=compute_pool,
                  pip_requirements=pip_requirements,
@@ -83,7 +63,7 @@
                  env_vars=env_vars,
                  session=session,
              )
-             assert isinstance(job, jb.MLJob)
+             assert isinstance(job, jb.MLJob), f"Unexpected job type: {type(job)}"
              return job

          return wrapper
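Taken together, the decorator no longer validates or stringifies arguments up front; they are captured in a functools.partial that is tagged and serialized with the payload. A hedged usage sketch (the keyword arguments follow those visible in the diff above; the full public signature of remote() and its import path are assumptions, not shown in this diff):

from snowflake.ml.jobs import remote  # import path assumed; remote() is defined in decorators.py

@remote(compute_pool="MY_COMPUTE_POOL", stage_name="payload_stage")
def train(data_path: str, epochs: int = 10) -> None:
    print(f"Training on {data_path} for {epochs} epochs")

# Calling the wrapped function submits an MLJob; the arguments travel inside the
# pickled functools.partial instead of being re-validated and passed as
# command-line strings.
job = train("@my_stage/data/", epochs=5)
print(job.status)  # MLJob exposes a status property, per job.py below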
snowflake/ml/jobs/job.py CHANGED
@@ -4,7 +4,7 @@ from typing import Any, List, Optional, cast
  from snowflake import snowpark
  from snowflake.ml._internal import telemetry
  from snowflake.ml.jobs._utils import constants, types
- from snowflake.snowpark.context import get_active_session
+ from snowflake.snowpark import context as sp_context

  _PROJECT = "MLJob"
  TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "INTERNAL_ERROR"}
@@ -13,7 +13,7 @@ TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "INTERNAL_ERROR"}
  class MLJob:
      def __init__(self, id: str, session: Optional[snowpark.Session] = None) -> None:
          self._id = id
-         self._session = session or get_active_session()
+         self._session = session or sp_context.get_active_session()
          self._status: types.JOB_STATUS = "PENDING"

      @property
@@ -79,7 +79,7 @@
          return self.status


- @telemetry.send_api_usage_telemetry(project=_PROJECT)
+ @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
  def _get_status(session: snowpark.Session, job_id: str) -> types.JOB_STATUS:
      """Retrieve job execution status."""
      # TODO: snowflake-snowpark-python<1.24.0 shows spurious error messages on
@@ -90,7 +90,7 @@ def _get_status(session: snowpark.Session, job_id: str) -> types.JOB_STATUS:
      return cast(types.JOB_STATUS, row["status"])


- @telemetry.send_api_usage_telemetry(project=_PROJECT)
+ @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit"])
  def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1) -> str:
      """
      Retrieve the job's execution logs.
snowflake/ml/model/_packager/model_env/model_env.py CHANGED
@@ -113,7 +113,33 @@ class ModelEnv:
          self._snowpark_ml_version = version.parse(snowpark_ml_version)

      def include_if_absent(self, pkgs: List[ModelDependency], check_local_version: bool = False) -> None:
-         """Append requirements into model env if absent.
+         """Append requirements into model env if absent. Depending on the environment, requirements may be added
+         to either the pip requirements or conda dependencies.
+
+         Args:
+             pkgs: A list of ModelDependency namedtuple to be appended.
+             check_local_version: Flag to indicate if it is required to pin to local version. Defaults to False.
+         """
+         if self.pip_requirements and not self.conda_dependencies and pkgs:
+             pip_pkg_reqs: List[str] = []
+             warnings.warn(
+                 (
+                     "Dependencies specified from pip requirements."
+                     " This may prevent model deploying to Snowflake Warehouse."
+                 ),
+                 category=UserWarning,
+                 stacklevel=2,
+             )
+             for conda_req_str, pip_name in pkgs:
+                 _, conda_req = env_utils._validate_conda_dependency_string(conda_req_str)
+                 pip_req = requirements.Requirement(f"{pip_name}{conda_req.specifier}")
+                 pip_pkg_reqs.append(str(pip_req))
+             self._include_if_absent_pip(pip_pkg_reqs, check_local_version)
+         else:
+             self._include_if_absent_conda(pkgs, check_local_version)
+
+     def _include_if_absent_conda(self, pkgs: List[ModelDependency], check_local_version: bool = False) -> None:
+         """Append requirements into model env conda dependencies if absent.

          Args:
              pkgs: A list of ModelDependency namedtuple to be appended.
@@ -134,8 +160,8 @@
              if show_warning_message:
                  warnings.warn(
                      (
-                         f"Basic dependency {req_to_add.name} specified from PIP requirements."
-                         + " This may prevent model deploying to Snowflake Warehouse."
+                         f"Basic dependency {req_to_add.name} specified from pip requirements."
+                         " This may prevent model deploying to Snowflake Warehouse."
                      ),
                      category=UserWarning,
                      stacklevel=2,
@@ -157,11 +183,11 @@
                      stacklevel=2,
                  )

-     def include_if_absent_pip(self, pkgs: List[str], check_local_version: bool = False) -> None:
-         """Append pip requirements into model env if absent.
+     def _include_if_absent_pip(self, pkgs: List[str], check_local_version: bool = False) -> None:
+         """Append pip requirements into model env pip requirements if absent.

          Args:
-             pkgs: A list of string to be appended in pip requirement.
+             pkgs: A list of strings to be appended to pip environment.
              check_local_version: Flag to indicate if it is required to pin to local version. Defaults to False.
          """

@@ -187,25 +213,6 @@
                  self._conda_dependencies[channel].remove(spec)

      def generate_env_for_cuda(self) -> None:
-         if self.cuda_version is None:
-             return
-
-         cuda_spec = env_utils.find_dep_spec(
-             self._conda_dependencies, self._pip_requirements, conda_pkg_name="cuda", remove_spec=False
-         )
-         if cuda_spec and not cuda_spec.specifier.contains(self.cuda_version):
-             raise ValueError(
-                 "The CUDA requirement you specified in your conda dependencies or pip requirements is"
-                 " conflicting with CUDA version required. Please do not specify CUDA dependency using conda"
-                 " dependencies or pip requirements."
-             )
-
-         if not cuda_spec:
-             self.include_if_absent(
-                 [ModelDependency(requirement=f"nvidia::cuda=={self.cuda_version}.*", pip_name="cuda")],
-                 check_local_version=False,
-             )
-
          xgboost_spec = env_utils.find_dep_spec(
              self._conda_dependencies, self._pip_requirements, conda_pkg_name="xgboost", remove_spec=True
          )
@@ -236,7 +243,7 @@
              check_local_version=False,
          )

-         self.include_if_absent_pip(["bitsandbytes>=0.41.0"], check_local_version=False)
+         self._include_if_absent_pip(["bitsandbytes>=0.41.0"], check_local_version=False)

      def relax_version(self) -> None:
          """Relax the version requirements for both conda dependencies and pip requirements.
@@ -252,7 +259,9 @@
          self._pip_requirements = list(map(env_utils.relax_requirement_version, self._pip_requirements))

      def load_from_conda_file(self, conda_env_path: pathlib.Path) -> None:
-         conda_dependencies_dict, pip_requirements_list, python_version = env_utils.load_conda_env_file(conda_env_path)
+         conda_dependencies_dict, pip_requirements_list, python_version, cuda_version = env_utils.load_conda_env_file(
+             conda_env_path
+         )

          for channel, channel_dependencies in conda_dependencies_dict.items():
              if channel != env_utils.DEFAULT_CHANNEL_NAME:
@@ -310,6 +319,9 @@
          if python_version:
              self.python_version = python_version

+         if cuda_version:
+             self.cuda_version = cuda_version
+
      def load_from_pip_file(self, pip_requirements_path: pathlib.Path) -> None:
          pip_requirements_list = env_utils.load_requirements_file(pip_requirements_path)

@@ -342,12 +354,17 @@
          self.snowpark_ml_version = env_dict["snowpark_ml_version"]

      def save_as_dict(
-         self, base_dir: pathlib.Path, default_channel_override: str = env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
+         self,
+         base_dir: pathlib.Path,
+         default_channel_override: str = env_utils.SNOWFLAKE_CONDA_CHANNEL_URL,
+         is_gpu: Optional[bool] = False,
      ) -> model_meta_schema.ModelEnvDict:
+         cuda_version = self.cuda_version if is_gpu else None
          env_utils.save_conda_env_file(
              pathlib.Path(base_dir / self.conda_env_rel_path),
              self._conda_dependencies,
              self.python_version,
+             cuda_version,
              default_channel_override=default_channel_override,
          )
          env_utils.save_requirements_file(
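The behavioral change worth noting in ModelEnv.include_if_absent() is the dispatch: a pip-only environment keeps added dependencies as pip requirements (converted from the conda spec), otherwise they go to conda dependencies. A standalone sketch of that conversion, assuming a simple conda spec with no channel prefix (the real code parses it via env_utils._validate_conda_dependency_string and does not use this helper):

from packaging import requirements

def to_pip_requirement(conda_req_str: str, pip_name: str) -> str:
    # Simplified stand-in for the pip-only branch above: reuse the conda version
    # specifier but emit it under the package's pip name.
    conda_req = requirements.Requirement(conda_req_str)
    return str(requirements.Requirement(f"{pip_name}{conda_req.specifier}"))

print(to_pip_requirement("pytorch==2.1", "torch"))  # torch==2.1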
snowflake/ml/model/_packager/model_handlers/_utils.py CHANGED
@@ -39,7 +39,7 @@ def _is_callable(model: model_types.SupportedModelType, method_name: str) -> boo


  def get_truncated_sample_data(
-     sample_input_data: model_types.SupportedDataType, length: int = 100
+     sample_input_data: model_types.SupportedDataType, length: int = 100, is_for_modeling_model: bool = False
  ) -> model_types.SupportedLocalDataType:
      trunc_sample_input = model_signature._truncate_data(sample_input_data, length=length)
      local_sample_input: model_types.SupportedLocalDataType = None
@@ -47,6 +47,8 @@ def get_truncated_sample_data(
          # Added because of Any from missing stubs.
          trunc_sample_input = cast(SnowparkDataFrame, trunc_sample_input)
          local_sample_input = snowpark_handler.SnowparkDataFrameHandler.convert_to_df(trunc_sample_input)
+         if is_for_modeling_model:
+             local_sample_input.columns = trunc_sample_input.columns
      else:
          local_sample_input = trunc_sample_input
      return local_sample_input
@@ -58,13 +60,15 @@ def validate_signature(
      target_methods: Iterable[str],
      sample_input_data: Optional[model_types.SupportedDataType],
      get_prediction_fn: Callable[[str, model_types.SupportedLocalDataType], model_types.SupportedLocalDataType],
+     is_for_modeling_model: bool = False,
  ) -> model_meta.ModelMetadata:
      if model_meta.signatures:
          validate_target_methods(model, list(model_meta.signatures.keys()))
          if sample_input_data is not None:
-             local_sample_input = get_truncated_sample_data(sample_input_data)
+             local_sample_input = get_truncated_sample_data(
+                 sample_input_data, is_for_modeling_model=is_for_modeling_model
+             )
          for target_method in model_meta.signatures.keys():
-
              model_signature_inst = model_meta.signatures.get(target_method)
              if model_signature_inst is not None:
                  # strict validation the input signature
@@ -77,7 +81,7 @@
          assert (
              sample_input_data is not None
          ), "Model signature and sample input are None at the same time. This should not happen with local model."
-         local_sample_input = get_truncated_sample_data(sample_input_data)
+         local_sample_input = get_truncated_sample_data(sample_input_data, is_for_modeling_model=is_for_modeling_model)
          for target_method in target_methods:
              predictions_df = get_prediction_fn(target_method, local_sample_input)
              sig = model_signature.infer_signature(
snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py CHANGED
@@ -146,6 +146,10 @@ class HuggingFacePipelineHandler(
          framework = getattr(model, "framework", None)
          batch_size = getattr(model, "batch_size", None)

+         has_tokenizer = getattr(model, "tokenizer", None) is not None
+         has_feature_extractor = getattr(model, "feature_extractor", None) is not None
+         has_image_preprocessor = getattr(model, "image_preprocessor", None) is not None
+
          if type_utils.LazyType("transformers.Pipeline").isinstance(model):
              params = {
                  **model._preprocess_params, # type:ignore[attr-defined]
@@ -234,6 +238,9 @@
                  {
                      "task": task,
                      "batch_size": batch_size if batch_size is not None else 1,
+                     "has_tokenizer": has_tokenizer,
+                     "has_feature_extractor": has_feature_extractor,
+                     "has_image_preprocessor": has_image_preprocessor,
                  }
              ),
          )
@@ -308,6 +315,14 @@
          if os.path.isdir(model_blob_file_or_dir_path):
              import transformers

+             additional_pipeline_params = {}
+             if model_blob_options.get("has_tokenizer", False):
+                 additional_pipeline_params["tokenizer"] = model_blob_file_or_dir_path
+             if model_blob_options.get("has_feature_extractor", False):
+                 additional_pipeline_params["feature_extractor"] = model_blob_file_or_dir_path
+             if model_blob_options.get("has_image_preprocessor", False):
+                 additional_pipeline_params["image_preprocessor"] = model_blob_file_or_dir_path
+
              with open(
                  os.path.join(
                      model_blob_file_or_dir_path,
@@ -324,6 +339,7 @@
                  model=model_blob_file_or_dir_path,
                  trust_remote_code=True,
                  torch_dtype="auto",
+                 **additional_pipeline_params,
                  **device_config,
              )