snowflake-ml-python 1.7.2__py3-none-any.whl → 1.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. snowflake/cortex/__init__.py +16 -8
  2. snowflake/cortex/_classify_text.py +12 -1
  3. snowflake/cortex/_complete.py +82 -13
  4. snowflake/cortex/_embed_text_1024.py +9 -2
  5. snowflake/cortex/_embed_text_768.py +9 -2
  6. snowflake/cortex/_extract_answer.py +9 -2
  7. snowflake/cortex/_sentiment.py +9 -2
  8. snowflake/cortex/_summarize.py +9 -2
  9. snowflake/cortex/_translate.py +9 -2
  10. snowflake/ml/_internal/env_utils.py +7 -52
  11. snowflake/ml/_internal/utils/identifier.py +4 -2
  12. snowflake/ml/data/__init__.py +3 -0
  13. snowflake/ml/data/_internal/arrow_ingestor.py +4 -4
  14. snowflake/ml/data/data_connector.py +53 -11
  15. snowflake/ml/data/data_ingestor.py +2 -1
  16. snowflake/ml/data/torch_utils.py +18 -5
  17. snowflake/ml/feature_store/examples/example_helper.py +2 -1
  18. snowflake/ml/fileset/fileset.py +18 -18
  19. snowflake/ml/model/_client/model/model_version_impl.py +5 -3
  20. snowflake/ml/model/_client/ops/model_ops.py +2 -6
  21. snowflake/ml/model/_client/sql/model_version.py +11 -0
  22. snowflake/ml/model/_model_composer/model_composer.py +8 -3
  23. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -1
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  25. snowflake/ml/model/_model_composer/model_method/constants.py +1 -0
  26. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -0
  27. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +1 -1
  28. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +1 -1
  29. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +1 -1
  30. snowflake/ml/model/_model_composer/model_method/model_method.py +9 -1
  31. snowflake/ml/model/_model_composer/model_user_file/model_user_file.py +27 -0
  32. snowflake/ml/model/_packager/model_handlers/_utils.py +27 -2
  33. snowflake/ml/model/_packager/model_handlers/catboost.py +3 -3
  34. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +5 -1
  35. snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -3
  36. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +55 -20
  37. snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -10
  38. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +66 -28
  39. snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -17
  40. snowflake/ml/model/_packager/model_handlers/xgboost.py +3 -3
  41. snowflake/ml/model/_packager/model_meta/model_meta.py +3 -0
  42. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
  43. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
  44. snowflake/ml/model/_packager/model_task/model_task_utils.py +3 -2
  45. snowflake/ml/model/_signatures/pandas_handler.py +1 -1
  46. snowflake/ml/model/_signatures/snowpark_handler.py +8 -2
  47. snowflake/ml/model/type_hints.py +1 -0
  48. snowflake/ml/modeling/_internal/model_trainer_builder.py +0 -8
  49. snowflake/ml/modeling/_internal/model_transformer_builder.py +0 -13
  50. snowflake/ml/modeling/pipeline/pipeline.py +6 -176
  51. snowflake/ml/modeling/xgboost/xgb_classifier.py +161 -88
  52. snowflake/ml/modeling/xgboost/xgb_regressor.py +160 -85
  53. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +160 -85
  54. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +160 -85
  55. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +4 -4
  56. snowflake/ml/registry/_manager/model_manager.py +70 -33
  57. snowflake/ml/registry/registry.py +41 -22
  58. snowflake/ml/version.py +1 -1
  59. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.3.dist-info}/METADATA +38 -9
  60. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.3.dist-info}/RECORD +63 -67
  61. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.3.dist-info}/WHEEL +1 -1
  62. snowflake/ml/_internal/utils/retryable_http.py +0 -39
  63. snowflake/ml/fileset/parquet_parser.py +0 -170
  64. snowflake/ml/fileset/tf_dataset.py +0 -88
  65. snowflake/ml/fileset/torch_datapipe.py +0 -57
  66. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +0 -151
  67. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_trainer.py +0 -66
  68. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.3.dist-info}/LICENSE.txt +0 -0
  69. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.3.dist-info}/top_level.txt +0 -0
--- a/snowflake/ml/data/data_connector.py
+++ b/snowflake/ml/data/data_connector.py
@@ -1,5 +1,16 @@
 import os
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Type, TypeVar
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Sequence,
+    Type,
+    TypeVar,
+    cast,
+)
 
 import numpy.typing as npt
 from typing_extensions import deprecated
@@ -12,6 +23,7 @@ from snowflake.ml.modeling._internal.constants import (
     IN_ML_RUNTIME_ENV_VAR,
     USE_OPTIMIZED_DATA_INGESTOR,
 )
+from snowflake.snowpark import context as sf_context
 
 if TYPE_CHECKING:
     import pandas as pd
@@ -35,8 +47,10 @@ class DataConnector:
     def __init__(
         self,
         ingestor: data_ingestor.DataIngestor,
+        **kwargs: Any,
     ) -> None:
         self._ingestor = ingestor
+        self._kwargs = kwargs
 
     @classmethod
     @snowpark._internal.utils.private_preview(version="1.6.0")
@@ -44,20 +58,34 @@ class DataConnector:
         cls: Type[DataConnectorType],
         df: snowpark.DataFrame,
         ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> DataConnectorType:
         if len(df.queries["queries"]) != 1 or len(df.queries["post_actions"]) != 0:
             raise ValueError("DataFrames with multiple queries and/or post-actions not supported")
-        source = data_source.DataFrameInfo(df.queries["queries"][0])
-        assert df._session is not None
-        return cls.from_sources(df._session, [source], ingestor_class=ingestor_class, **kwargs)
+        return cast(
+            DataConnectorType,
+            cls.from_sql(df.queries["queries"][0], session=df._session, ingestor_class=ingestor_class, **kwargs),
+        )
+
+    @classmethod
+    @snowpark._internal.utils.private_preview(version="1.7.3")
+    def from_sql(
+        cls: Type[DataConnectorType],
+        query: str,
+        session: Optional[snowpark.Session] = None,
+        ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
+        **kwargs: Any,
+    ) -> DataConnectorType:
+        session = session or sf_context.get_active_session()
+        source = data_source.DataFrameInfo(query)
+        return cls.from_sources(session, [source], ingestor_class=ingestor_class, **kwargs)
 
     @classmethod
     def from_dataset(
         cls: Type[DataConnectorType],
         ds: "dataset.Dataset",
         ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> DataConnectorType:
         dsv = ds.selected_version
         assert dsv is not None
@@ -75,9 +103,9 @@ class DataConnector:
     def from_sources(
         cls: Type[DataConnectorType],
         session: snowpark.Session,
-        sources: List[data_source.DataSource],
+        sources: Sequence[data_source.DataSource],
         ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> DataConnectorType:
         ingestor_class = ingestor_class or cls.DEFAULT_INGESTOR_CLASS
         ingestor = ingestor_class.from_sources(session, sources)
@@ -130,7 +158,11 @@ class DataConnector:
         func_params_to_log=["batch_size", "shuffle", "drop_last_batch"],
     )
     def to_torch_datapipe(
-        self, *, batch_size: int, shuffle: bool = False, drop_last_batch: bool = True
+        self,
+        *,
+        batch_size: int,
+        shuffle: bool = False,
+        drop_last_batch: bool = True,
     ) -> "torch_data.IterDataPipe":  # type: ignore[type-arg]
         """Transform the Snowflake data into a ready-to-use Pytorch datapipe.
 
@@ -149,8 +181,13 @@
         """
         from snowflake.ml.data import torch_utils
 
+        expand_dims = self._kwargs.get("expand_dims", True)
         return torch_utils.TorchDataPipeWrapper(
-            self._ingestor, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last_batch
+            self._ingestor,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            drop_last=drop_last_batch,
+            expand_dims=expand_dims,
         )
 
     @telemetry.send_api_usage_telemetry(
@@ -179,8 +216,13 @@
         """
         from snowflake.ml.data import torch_utils
 
+        expand_dims = self._kwargs.get("expand_dims", True)
         return torch_utils.TorchDatasetWrapper(
-            self._ingestor, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last_batch
+            self._ingestor,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            drop_last=drop_last_batch,
+            expand_dims=expand_dims,
         )
 
     @telemetry.send_api_usage_telemetry(
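
Editor's note: the new from_sql constructor and the connector-level **kwargs are additive API. Below is a brief usage sketch based only on the signatures shown above; the session variable, table name, and batch size are illustrative assumptions, not taken from the diff.

    from snowflake.ml.data import data_connector

    # from_sql is private preview as of 1.7.3; when session is omitted it falls
    # back to the active Snowpark session.
    dc = data_connector.DataConnector.from_sql("SELECT * FROM MY_TABLE", session=session)
    pipe = dc.to_torch_datapipe(batch_size=32, shuffle=True, drop_last_batch=True)

Constructing a DataConnector directly with expand_dims=False (as the FileSet change further down does) keeps one-dimensional columns un-expanded in the torch conversions.
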
--- a/snowflake/ml/data/data_ingestor.py
+++ b/snowflake/ml/data/data_ingestor.py
@@ -6,6 +6,7 @@ from typing import (
     List,
     Optional,
     Protocol,
+    Sequence,
     Type,
     TypeVar,
 )
@@ -25,7 +26,7 @@ DataIngestorType = TypeVar("DataIngestorType", bound="DataIngestor")
 class DataIngestor(Protocol):
     @classmethod
     def from_sources(
-        cls: Type[DataIngestorType], session: snowpark.Session, sources: List[data_source.DataSource]
+        cls: Type[DataIngestorType], session: snowpark.Session, sources: Sequence[data_source.DataSource]
     ) -> DataIngestorType:
         raise NotImplementedError
 
--- a/snowflake/ml/data/torch_utils.py
+++ b/snowflake/ml/data/torch_utils.py
@@ -17,6 +17,7 @@ class TorchDatasetWrapper(torch.utils.data.IterableDataset[Dict[str, Any]]):
         batch_size: Optional[int],
         shuffle: bool = False,
         drop_last: bool = False,
+        expand_dims: bool = True,
     ) -> None:
         """Not intended for direct usage. Use DataConnector.to_torch_dataset() instead"""
         squeeze = False
@@ -29,6 +30,7 @@ class TorchDatasetWrapper(torch.utils.data.IterableDataset[Dict[str, Any]]):
         self._shuffle = shuffle
         self._drop_last = drop_last
         self._squeeze_outputs = squeeze
+        self._expand_dims = expand_dims
 
     def __iter__(self) -> Iterator[Dict[str, Union[npt.NDArray[Any], List[Any]]]]:
         max_idx = 0
@@ -47,7 +49,10 @@ class TorchDatasetWrapper(torch.utils.data.IterableDataset[Dict[str, Any]]):
         ):
             # Skip indices during multi-process data loading to prevent data duplication
             if counter == filter_idx:
-                yield {k: _preprocess_array(v, squeeze=self._squeeze_outputs) for k, v in batch.items()}
+                yield {
+                    k: _preprocess_array(v, squeeze=self._squeeze_outputs, expand_dims=self._expand_dims)
+                    for k, v in batch.items()
+                }
             if counter < max_idx:
                 counter += 1
             else:
@@ -58,13 +63,21 @@ class TorchDataPipeWrapper(TorchDatasetWrapper, torch.utils.data.IterDataPipe[Dict[str, Any]]):
     """Wrap a DataIngestor into a PyTorch IterDataPipe"""
 
     def __init__(
-        self, ingestor: data_ingestor.DataIngestor, *, batch_size: int, shuffle: bool = False, drop_last: bool = False
+        self,
+        ingestor: data_ingestor.DataIngestor,
+        *,
+        batch_size: int,
+        shuffle: bool = False,
+        drop_last: bool = False,
+        expand_dims: bool = True,
     ) -> None:
         """Not intended for direct usage. Use DataConnector.to_torch_datapipe() instead"""
-        super().__init__(ingestor, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
+        super().__init__(ingestor, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, expand_dims=expand_dims)
 
 
-def _preprocess_array(arr: npt.NDArray[Any], squeeze: bool = False) -> Union[npt.NDArray[Any], List[np.object_]]:
+def _preprocess_array(
+    arr: npt.NDArray[Any], squeeze: bool = False, expand_dims: bool = True
+) -> Union[npt.NDArray[Any], List[np.object_]]:
     """Preprocesses batch column values."""
     single_dimensional = arr.ndim < 2 and not arr.dtype == np.object_
 
@@ -73,7 +86,7 @@ def _preprocess_array(arr: npt.NDArray[Any], squeeze: bool = False) -> Union[npt
         arr = arr.squeeze(axis=0)
 
     # For single dimensional data,
-    if single_dimensional:
+    if single_dimensional and expand_dims:
         axis = 0 if arr.ndim == 0 else 1
         arr = np.expand_dims(arr, axis=axis)
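
Editor's note: for a concrete sense of what the new expand_dims flag controls, here is a small illustration (not from the package) of the branch above applied to a one-dimensional numeric column:

    import numpy as np

    col = np.array([1.0, 2.0, 3.0])            # a 1-D batch column

    expanded = np.expand_dims(col, axis=1)     # expand_dims=True  -> shape (3, 1)
    unchanged = col                            # expand_dims=False -> shape (3,), as FileSet now requests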
 
--- a/snowflake/ml/feature_store/examples/example_helper.py
+++ b/snowflake/ml/feature_store/examples/example_helper.py
@@ -45,8 +45,9 @@ class ExampleHelper:
         """Return a dataframe object about descriptions of all examples."""
         root_dir = Path(__file__).parent
         rows = []
+        hide_folders = ["citibike_trip_features", "source_data"]
         for f_name in os.listdir(root_dir):
-            if os.path.isdir(os.path.join(root_dir, f_name)) and f_name[0].isalpha() and f_name != "source_data":
+            if os.path.isdir(os.path.join(root_dir, f_name)) and f_name[0].isalpha() and f_name not in hide_folders:
                 source_file_path = root_dir.joinpath(f"{f_name}/source.yaml")
                 source_dict = self._read_yaml(str(source_file_path))
                 rows.append((f_name, source_dict["model_category"], source_dict["desc"], source_dict["label_columns"]))
--- a/snowflake/ml/fileset/fileset.py
+++ b/snowflake/ml/fileset/fileset.py
@@ -11,11 +11,9 @@ from snowflake.ml._internal.exceptions import (
     fileset_error_messages,
     fileset_errors,
 )
-from snowflake.ml._internal.utils import (
-    identifier,
-    import_utils,
-    snowpark_dataframe_utils,
-)
+from snowflake.ml._internal.utils import identifier, snowpark_dataframe_utils
+from snowflake.ml.data import data_connector
+from snowflake.ml.data._internal import arrow_ingestor
 from snowflake.ml.fileset import sfcfs
 from snowflake.snowpark import exceptions as snowpark_exceptions, functions
 
@@ -285,6 +283,16 @@ class FileSet:
         """Get the Snowflake absolute path to this FileSet directory."""
         return _fileset_absolute_path(self._target_stage_loc, self.name)
 
+    def _to_data_connector(self) -> data_connector.DataConnector:
+        self._fs.optimize_read(self._list_files())
+        ingester = arrow_ingestor.ArrowIngestor(
+            self._snowpark_session,
+            self._list_files(),
+            format="parquet",
+            filesystem=self._fs,
+        )
+        return data_connector.DataConnector(ingester, expand_dims=False)
+
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
     )
@@ -362,13 +370,9 @@ class FileSet:
         ----
         {'_COL_1':[10]}
         """
-        IterableWrapper, _ = import_utils.import_or_get_dummy("torchdata.datapipes.iter.IterableWrapper")
-        torch_datapipe_module, _ = import_utils.import_or_get_dummy("snowflake.ml.fileset.torch_datapipe")
-
-        self._fs.optimize_read(self._list_files())
-
-        input_dp = IterableWrapper(self._list_files())
-        return torch_datapipe_module.ReadAndParseParquet(input_dp, self._fs, batch_size, shuffle, drop_last_batch)
+        return self._to_data_connector().to_torch_datapipe(
+            batch_size=batch_size, shuffle=shuffle, drop_last_batch=drop_last_batch
+        )
 
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
@@ -402,12 +406,8 @@ class FileSet:
         ----
         {'_COL_1': <tf.Tensor: shape=(1,), dtype=int64, numpy=[10]>}
         """
-        tf_dataset_module, _ = import_utils.import_or_get_dummy("snowflake.ml.fileset.tf_dataset")
-
-        self._fs.optimize_read(self._list_files())
-
-        return tf_dataset_module.read_and_parse_parquet(
-            self._list_files(), self._fs, batch_size, shuffle, drop_last_batch
+        return self._to_data_connector().to_tf_dataset(
+            batch_size=batch_size, shuffle=shuffle, drop_last_batch=drop_last_batch
         )
 
--- a/snowflake/ml/model/_client/model/model_version_impl.py
+++ b/snowflake/ml/model/_client/model/model_version_impl.py
@@ -447,13 +447,15 @@ class ModelVersion(lineage_node.LineageNode):
         target_function_info = functions[0]
 
         if service_name:
+            database_name_id, schema_name_id, service_name_id = sql_identifier.parse_fully_qualified_name(service_name)
+
             return self._model_ops.invoke_method(
                 method_name=sql_identifier.SqlIdentifier(target_function_info["name"]),
                 signature=target_function_info["signature"],
                 X=X,
-                database_name=None,
-                schema_name=None,
-                service_name=sql_identifier.SqlIdentifier(service_name),
+                database_name=database_name_id,
+                schema_name=schema_name_id,
+                service_name=service_name_id,
                 strict_input_validation=strict_input_validation,
                 statement_params=statement_params,
             )
--- a/snowflake/ml/model/_client/ops/model_ops.py
+++ b/snowflake/ml/model/_client/ops/model_ops.py
@@ -168,14 +168,10 @@ class ModelOperator:
         schema_name: Optional[sql_identifier.SqlIdentifier],
         model_name: sql_identifier.SqlIdentifier,
         version_name: sql_identifier.SqlIdentifier,
+        model_exists: bool,
         statement_params: Optional[Dict[str, Any]] = None,
     ) -> None:
-        if self.validate_existence(
-            database_name=database_name,
-            schema_name=schema_name,
-            model_name=model_name,
-            statement_params=statement_params,
-        ):
+        if model_exists:
             return self._model_version_client.add_version_from_model_version(
                 source_database_name=source_database_name,
                 source_schema_name=source_schema_name,
--- a/snowflake/ml/model/_client/sql/model_version.py
+++ b/snowflake/ml/model/_client/sql/model_version.py
@@ -10,6 +10,7 @@ from snowflake.ml._internal.utils import (
     sql_identifier,
 )
 from snowflake.ml.model._client.sql import _base
+from snowflake.ml.model._model_composer.model_method import constants
 from snowflake.snowpark import dataframe, functions as F, row, types as spt
 from snowflake.snowpark._internal import utils as snowpark_utils
 
@@ -333,6 +334,11 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
 
         args_sql = ", ".join(args_sql_list)
 
+        wide_input = len(input_args) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
+        if wide_input:
+            input_args_sql = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
+            args_sql = f"object_construct_keep_null({input_args_sql})"
+
         sql = textwrap.dedent(
             f"""WITH {','.join(with_statements)}
             SELECT *,
@@ -412,6 +418,11 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
 
         args_sql = ", ".join(args_sql_list)
 
+        wide_input = len(input_args) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
+        if wide_input:
+            input_args_sql = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
+            args_sql = f"object_construct_keep_null({input_args_sql})"
+
         sql = textwrap.dedent(
             f"""WITH {','.join(with_statements)}
             SELECT *,
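Editor's note: the wide-input branch only triggers when the number of inputs exceeds SNOWPARK_UDF_INPUT_COL_LIMIT (500, defined in the new constants.py below). A simplified two-column sketch of the rewritten argument list; the real code works on SqlIdentifier objects and calls arg.identifier(), plain strings are used here only for illustration:

    input_args = ["FEATURE_A", "FEATURE_B"]   # stand-ins for SqlIdentifier inputs
    input_args_sql = ", ".join(f"'{arg}', {arg}" for arg in input_args)
    args_sql = f"object_construct_keep_null({input_args_sql})"
    # args_sql == "object_construct_keep_null('FEATURE_A', FEATURE_A, 'FEATURE_B', FEATURE_B)"
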
--- a/snowflake/ml/model/_model_composer/model_composer.py
+++ b/snowflake/ml/model/_model_composer/model_composer.py
@@ -88,6 +88,7 @@ class ModelComposer:
         pip_requirements: Optional[List[str]] = None,
         target_platforms: Optional[List[model_types.TargetPlatform]] = None,
         python_version: Optional[str] = None,
+        user_files: Optional[Dict[str, List[str]]] = None,
         ext_modules: Optional[List[ModuleType]] = None,
         code_paths: Optional[List[str]] = None,
         task: model_types.Task = model_types.Task.UNKNOWN,
@@ -97,9 +98,12 @@ class ModelComposer:
         options = model_types.BaseModelSaveOption()
 
         if not snowpark_utils.is_in_stored_procedure():  # type: ignore[no-untyped-call]
-            snowml_matched_versions = env_utils.get_matched_package_versions_in_snowflake_conda_channel(
-                req=requirements.Requirement(f"snowflake-ml-python=={snowml_env.VERSION}")
-            )
+            snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
+                self.session,
+                reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={snowml_env.VERSION}")],
+                python_version=python_version or snowml_env.PYTHON_VERSION,
+                statement_params=self._statement_params,
+            ).get(env_utils.SNOWPARK_ML_PKG_NAME, [])
 
             if len(snowml_matched_versions) < 1 and options.get("embed_local_ml_library", False) is False:
                 logging.info(
@@ -131,6 +135,7 @@ class ModelComposer:
             model_meta=self.packager.meta,
             model_rel_path=pathlib.PurePosixPath(ModelComposer.MODEL_DIR_REL_PATH),
             options=options,
+            user_files=user_files,
             data_sources=self._get_data_sources(model, sample_input_data),
             target_platforms=target_platforms,
         )
--- a/snowflake/ml/model/_model_composer/model_manifest/model_manifest.py
+++ b/snowflake/ml/model/_model_composer/model_manifest/model_manifest.py
@@ -2,7 +2,7 @@ import collections
 import logging
 import pathlib
 import warnings
-from typing import List, Optional, cast
+from typing import Dict, List, Optional, cast
 
 import yaml
 
@@ -11,9 +11,11 @@ from snowflake.ml.data import data_source
 from snowflake.ml.model import type_hints
 from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
 from snowflake.ml.model._model_composer.model_method import (
+    constants,
     function_generator,
     model_method,
 )
+from snowflake.ml.model._model_composer.model_user_file import model_user_file
 from snowflake.ml.model._packager.model_meta import (
     model_meta as model_meta_api,
     model_meta_schema,
@@ -30,9 +32,11 @@ class ModelManifest:
         workspace_path: A local path where model related files should be dumped to.
         runtimes: A list of ModelRuntime objects managing the runtimes and environment in the MODEL object.
         methods: A list of ModelMethod objects managing the method we registered to the MODEL object.
+        user_files: A list of ModelUserFile objects managing extra files uploaded to the workspace.
     """
 
     MANIFEST_FILE_REL_PATH = "MANIFEST.yml"
+    _ENABLE_USER_FILES = False
     _DEFAULT_RUNTIME_NAME = "python_runtime"
 
     def __init__(self, workspace_path: pathlib.Path) -> None:
@@ -42,6 +46,7 @@ class ModelManifest:
         self,
         model_meta: model_meta_api.ModelMetadata,
         model_rel_path: pathlib.PurePosixPath,
+        user_files: Optional[Dict[str, List[str]]] = None,
         options: Optional[type_hints.ModelSaveOption] = None,
         data_sources: Optional[List[data_source.DataSource]] = None,
         target_platforms: Optional[List[type_hints.TargetPlatform]] = None,
@@ -79,6 +84,7 @@ class ModelManifest:
 
         self.function_generator = function_generator.FunctionGenerator(model_dir_rel_path=model_rel_path)
         self.methods: List[model_method.ModelMethod] = []
+
         for target_method in model_meta.signatures.keys():
             method = model_method.ModelMethod(
                 model_meta=model_meta,
@@ -88,11 +94,21 @@ class ModelManifest:
                 is_partitioned_function=model_meta.function_properties.get(target_method, {}).get(
                     model_meta_schema.FunctionProperties.PARTITIONED.value, False
                 ),
+                wide_input=len(model_meta.signatures[target_method].inputs) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT,
                 options=model_method.get_model_method_options_from_options(options, target_method),
             )
 
             self.methods.append(method)
 
+        self.user_files: List[model_user_file.ModelUserFile] = []
+
+        if user_files is not None:
+            for subdirectory, paths in user_files.items():
+                for path in paths:
+                    self.user_files.append(
+                        model_user_file.ModelUserFile(pathlib.PurePosixPath(subdirectory), pathlib.Path(path))
+                    )
+
         method_name_counter = collections.Counter([method.method_name for method in self.methods])
         dup_method_names = [k for k, v in method_name_counter.items() if v > 1]
         if dup_method_names:
@@ -129,6 +145,9 @@ class ModelManifest:
             ],
         )
 
+        if self._ENABLE_USER_FILES:
+            manifest_dict["user_files"] = [user_file.save(self.workspace_path) for user_file in self.user_files]
+
         lineage_sources = self._extract_lineage_info(data_sources)
         if lineage_sources:
             manifest_dict["lineage_sources"] = lineage_sources
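Editor's note: the user_files argument threaded through ModelComposer.save and ModelManifest.save is a mapping from a target subdirectory to a list of local file paths. A sketch with hypothetical paths (note the manifest output is still gated off by _ENABLE_USER_FILES = False in this release):

    user_files = {
        "docs": ["/local/run/model_card.md"],        # lands under user_files/docs/
        "assets/lookup": ["/local/run/vocab.txt"],   # nested subdirectories expressed as POSIX paths
    }
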
--- a/snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py
+++ b/snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py
@@ -94,5 +94,6 @@ class ModelManifestDict(TypedDict):
     runtimes: Required[Dict[str, ModelRuntimeDict]]
     methods: Required[List[ModelMethodDict]]
     user_data: NotRequired[Dict[str, Any]]
+    user_files: NotRequired[List[str]]
     lineage_sources: NotRequired[List[LineageSourceDict]]
     target_platforms: NotRequired[List[str]]
--- /dev/null
+++ b/snowflake/ml/model/_model_composer/model_method/constants.py
@@ -0,0 +1 @@
+SNOWPARK_UDF_INPUT_COL_LIMIT = 500
--- a/snowflake/ml/model/_model_composer/model_method/function_generator.py
+++ b/snowflake/ml/model/_model_composer/model_method/function_generator.py
@@ -43,6 +43,7 @@ class FunctionGenerator:
         target_method: str,
         function_type: str,
         is_partitioned_function: bool = False,
+        wide_input: bool = False,
         options: Optional[FunctionGenerateOptions] = None,
     ) -> None:
         import importlib_resources
@@ -70,6 +71,7 @@ class FunctionGenerator:
             model_dir_name=self.model_dir_rel_path.name,
             target_method=target_method,
             max_batch_size=options.get("max_batch_size", None),
+            wide_input=wide_input,
             function_name=FunctionGenerator.FUNCTION_NAME,
         )
         with open(function_file_path, "w", encoding="utf-8") as f:
--- a/snowflake/ml/model/_model_composer/model_method/infer_function.py_template
+++ b/snowflake/ml/model/_model_composer/model_method/infer_function.py_template
@@ -43,7 +43,7 @@ dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
 
 
 # Actual function
-@vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE)
+@vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE, flatten_object_input={wide_input})
 def {function_name}(df: pd.DataFrame) -> dict:
     df.columns = input_cols
     input_df = df.astype(dtype=dtype_map)
--- a/snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template
+++ b/snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template
@@ -48,7 +48,7 @@ dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
 
 # Actual table function
 class {function_name}:
-    @vectorized(input=pd.DataFrame)
+    @vectorized(input=pd.DataFrame, flatten_object_input={wide_input})
     def end_partition(self, df: pd.DataFrame) -> pd.DataFrame:
         df.columns = input_cols
         input_df = df.astype(dtype=dtype_map)
--- a/snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template
+++ b/snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template
@@ -43,7 +43,7 @@ dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
 
 # Actual table function
 class {function_name}:
-    @vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE)
+    @vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE, flatten_object_input={wide_input})
     def process(self, df: pd.DataFrame) -> pd.DataFrame:
         df.columns = input_cols
         input_df = df.astype(dtype=dtype_map)
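Editor's note: the {wide_input} placeholder in all three templates is filled with a Python literal when the inference function file is generated. A minimal sketch of that substitution, assuming str.format-style rendering (which the wide_input=wide_input keyword in the FunctionGenerator call above suggests, but which is not shown verbatim in this diff):

    decorator_template = "@vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE, flatten_object_input={wide_input})"
    print(decorator_template.format(wide_input=True))
    # @vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE, flatten_object_input=True)
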
--- a/snowflake/ml/model/_model_composer/model_method/model_method.py
+++ b/snowflake/ml/model/_model_composer/model_method/model_method.py
@@ -7,7 +7,10 @@ from typing_extensions import NotRequired
 from snowflake.ml._internal.utils import sql_identifier
 from snowflake.ml.model import model_signature, type_hints
 from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
-from snowflake.ml.model._model_composer.model_method import function_generator
+from snowflake.ml.model._model_composer.model_method import (
+    constants,
+    function_generator,
+)
 from snowflake.ml.model._packager.model_meta import model_meta as model_meta_api
 from snowflake.snowpark._internal import type_utils
 
@@ -64,6 +67,7 @@ class ModelMethod:
         runtime_name: str,
         function_generator: function_generator.FunctionGenerator,
         is_partitioned_function: bool = False,
+        wide_input: bool = False,
         options: Optional[ModelMethodOptions] = None,
     ) -> None:
         self.model_meta = model_meta
@@ -71,6 +75,7 @@ class ModelMethod:
         self.function_generator = function_generator
         self.is_partitioned_function = is_partitioned_function
         self.runtime_name = runtime_name
+        self.wide_input = wide_input
         self.options = options or {}
         try:
             self.method_name = sql_identifier.SqlIdentifier(
@@ -114,12 +119,15 @@ class ModelMethod:
             self.target_method,
             self.function_type,
             self.is_partitioned_function,
+            self.wide_input,
             options=options,
         )
         input_list = [
             ModelMethod._get_method_arg_from_feature(ft, case_sensitive=self.options.get("case_sensitive", False))
             for ft in self.model_meta.signatures[self.target_method].inputs
         ]
+        if len(input_list) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT:
+            input_list = [{"name": "INPUT", "type": "OBJECT"}]
         input_name_counter = collections.Counter([input_info["name"] for input_info in input_list])
         dup_input_names = [k for k, v in input_name_counter.items() if v > 1]
         if dup_input_names:
--- /dev/null
+++ b/snowflake/ml/model/_model_composer/model_user_file/model_user_file.py
@@ -0,0 +1,27 @@
+import os
+import pathlib
+
+from snowflake.ml._internal import file_utils
+
+
+class ModelUserFile:
+    """Class representing a user provided file.
+
+    Attributes:
+        subdirectory_name: A local path where model related files should be dumped to.
+        local_path: A list of ModelRuntime objects managing the runtimes and environment in the MODEL object.
+    """
+
+    USER_FILES_DIR_REL_PATH = "user_files"
+
+    def __init__(self, subdirectory_name: pathlib.PurePosixPath, local_path: pathlib.Path) -> None:
+        self.subdirectory_name = subdirectory_name
+        self.local_path = local_path
+
+    def save(self, workspace_path: pathlib.Path) -> str:
+        user_files_path = workspace_path / ModelUserFile.USER_FILES_DIR_REL_PATH / self.subdirectory_name
+        user_files_path.mkdir(parents=True, exist_ok=True)
+
+        # copy the file to the workspace
+        file_utils.copy_file_or_tree(str(self.local_path), str(user_files_path))
+        return os.path.join(self.subdirectory_name, self.local_path.name)
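
Editor's note: a short sketch of how this new class lays files out in the model workspace, using hypothetical paths:

    import pathlib

    from snowflake.ml.model._model_composer.model_user_file import model_user_file

    uf = model_user_file.ModelUserFile(pathlib.PurePosixPath("docs"), pathlib.Path("/tmp/model_card.md"))
    # Copies /tmp/model_card.md into <workspace>/user_files/docs/ and returns "docs/model_card.md"
    relative_path = uf.save(pathlib.Path("/tmp/workspace"))
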
--- a/snowflake/ml/model/_packager/model_handlers/_utils.py
+++ b/snowflake/ml/model/_packager/model_handlers/_utils.py
@@ -1,7 +1,8 @@
 import json
 import os
+import pathlib
 import warnings
-from typing import Any, Callable, Iterable, List, Optional, Sequence, cast
+from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, cast
 
 import numpy as np
 import numpy.typing as npt
@@ -118,7 +119,7 @@ def get_explainability_supported_background(
     meta: model_meta.ModelMetadata,
     explain_target_method: Optional[str],
 ) -> pd.DataFrame:
-    if sample_input_data is None:
+    if sample_input_data is None or explain_target_method is None:
         return None
 
     if isinstance(sample_input_data, pd.DataFrame):
@@ -223,3 +224,27 @@ def get_explain_target_method(
         if method in target_methods_list:
             return method
     return None
+
+
+def save_transformers_config_with_auto_map(local_model_path: str) -> None:
+    import huggingface_hub
+
+    for f_path in pathlib.Path(local_model_path).iterdir():
+        if f_path.name in ["config.json", "tokenizer_config.json"]:
+            with open(f_path) as f:
+                config_dict = json.load(f)
+
+            # a. get repository and class_path from configs
+            auto_map_configs = cast(Dict[str, str], config_dict.get("auto_map", {}))
+            for config_name, config_value in auto_map_configs.items():
+                repository, _, class_path = config_value.rpartition("--")
+
+                # b. download required configs from hf hub
+                if repository:
+                    huggingface_hub.snapshot_download(repo_id=repository, local_dir=local_model_path)
+
+                # c. update config files
+                config_dict["auto_map"][config_name] = class_path
+
+            with open(f_path, "w") as f:
+                json.dump(config_dict, f)
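
Editor's note: the auto_map rewrite above splits each entry on the "--" separator used by Hugging Face remote-code configs. An illustration with a hypothetical repository and class path:

    config_value = "some-org/custom-bert--modeling_custom.CustomBertModel"
    repository, _, class_path = config_value.rpartition("--")
    # repository == "some-org/custom-bert"            -> downloaded next to the model via snapshot_download
    # class_path == "modeling_custom.CustomBertModel" -> written back to config["auto_map"]
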
--- a/snowflake/ml/model/_packager/model_handlers/catboost.py
+++ b/snowflake/ml/model/_packager/model_handlers/catboost.py
@@ -94,8 +94,8 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
             sample_input_data=sample_input_data,
             get_prediction_fn=get_prediction,
         )
-        model_task_and_output = model_task_utils.get_model_task_and_output_type(model)
-        model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output.task)
+        model_task_and_output = model_task_utils.resolve_model_task_and_output_type(model, model_meta.task)
+        model_meta.task = model_task_and_output.task
         if enable_explainability:
             explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
             model_meta = handlers_utils.add_explain_method_signature(
@@ -227,7 +227,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
            import shap
 
            explainer = shap.TreeExplainer(raw_model)
-           df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
+           df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer.shap_values(X))
            return model_signature_utils.rename_pandas_df(df, signature.outputs)
 
        if target_method == "explain":
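Editor's note: the explanation path now calls TreeExplainer.shap_values directly instead of reading .values off the Explanation object returned by explainer(X). A hedged sketch, assuming raw_model is a fitted CatBoost model and X is a background DataFrame:

    import shap

    explainer = shap.TreeExplainer(raw_model)
    values = explainer.shap_values(X)   # raw array(s), previously explainer(X).values
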