snowflake-ml-python 1.7.2__py3-none-any.whl → 1.7.3__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
- snowflake/cortex/__init__.py +16 -8
- snowflake/cortex/_classify_text.py +12 -1
- snowflake/cortex/_complete.py +82 -13
- snowflake/cortex/_embed_text_1024.py +9 -2
- snowflake/cortex/_embed_text_768.py +9 -2
- snowflake/cortex/_extract_answer.py +9 -2
- snowflake/cortex/_sentiment.py +9 -2
- snowflake/cortex/_summarize.py +9 -2
- snowflake/cortex/_translate.py +9 -2
- snowflake/ml/_internal/env_utils.py +7 -52
- snowflake/ml/_internal/utils/identifier.py +4 -2
- snowflake/ml/data/__init__.py +3 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +4 -4
- snowflake/ml/data/data_connector.py +53 -11
- snowflake/ml/data/data_ingestor.py +2 -1
- snowflake/ml/data/torch_utils.py +18 -5
- snowflake/ml/feature_store/examples/example_helper.py +2 -1
- snowflake/ml/fileset/fileset.py +18 -18
- snowflake/ml/model/_client/model/model_version_impl.py +5 -3
- snowflake/ml/model/_client/ops/model_ops.py +2 -6
- snowflake/ml/model/_client/sql/model_version.py +11 -0
- snowflake/ml/model/_model_composer/model_composer.py +8 -3
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_model_composer/model_method/constants.py +1 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -0
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +9 -1
- snowflake/ml/model/_model_composer/model_user_file/model_user_file.py +27 -0
- snowflake/ml/model/_packager/model_handlers/_utils.py +27 -2
- snowflake/ml/model/_packager/model_handlers/catboost.py +3 -3
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +5 -1
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +55 -20
- snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -10
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +66 -28
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -17
- snowflake/ml/model/_packager/model_handlers/xgboost.py +3 -3
- snowflake/ml/model/_packager/model_meta/model_meta.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
- snowflake/ml/model/_packager/model_task/model_task_utils.py +3 -2
- snowflake/ml/model/_signatures/pandas_handler.py +1 -1
- snowflake/ml/model/_signatures/snowpark_handler.py +8 -2
- snowflake/ml/model/type_hints.py +1 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +0 -8
- snowflake/ml/modeling/_internal/model_transformer_builder.py +0 -13
- snowflake/ml/modeling/pipeline/pipeline.py +6 -176
- snowflake/ml/modeling/xgboost/xgb_classifier.py +161 -88
- snowflake/ml/modeling/xgboost/xgb_regressor.py +160 -85
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +160 -85
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +160 -85
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +4 -4
- snowflake/ml/registry/_manager/model_manager.py +70 -33
- snowflake/ml/registry/registry.py +41 -22
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.3.dist-info}/METADATA +38 -9
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.3.dist-info}/RECORD +63 -67
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.3.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/utils/retryable_http.py +0 -39
- snowflake/ml/fileset/parquet_parser.py +0 -170
- snowflake/ml/fileset/tf_dataset.py +0 -88
- snowflake/ml/fileset/torch_datapipe.py +0 -57
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +0 -151
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_trainer.py +0 -66
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.3.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.3.dist-info}/top_level.txt +0 -0
snowflake/ml/data/data_connector.py CHANGED
@@ -1,5 +1,16 @@
 import os
-from typing import
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Sequence,
+    Type,
+    TypeVar,
+    cast,
+)
 
 import numpy.typing as npt
 from typing_extensions import deprecated
@@ -12,6 +23,7 @@ from snowflake.ml.modeling._internal.constants import (
     IN_ML_RUNTIME_ENV_VAR,
     USE_OPTIMIZED_DATA_INGESTOR,
 )
+from snowflake.snowpark import context as sf_context
 
 if TYPE_CHECKING:
     import pandas as pd
@@ -35,8 +47,10 @@ class DataConnector:
     def __init__(
         self,
         ingestor: data_ingestor.DataIngestor,
+        **kwargs: Any,
     ) -> None:
         self._ingestor = ingestor
+        self._kwargs = kwargs
 
     @classmethod
     @snowpark._internal.utils.private_preview(version="1.6.0")
@@ -44,20 +58,34 @@ class DataConnector:
         cls: Type[DataConnectorType],
         df: snowpark.DataFrame,
         ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> DataConnectorType:
         if len(df.queries["queries"]) != 1 or len(df.queries["post_actions"]) != 0:
             raise ValueError("DataFrames with multiple queries and/or post-actions not supported")
-
-
-
+        return cast(
+            DataConnectorType,
+            cls.from_sql(df.queries["queries"][0], session=df._session, ingestor_class=ingestor_class, **kwargs),
+        )
+
+    @classmethod
+    @snowpark._internal.utils.private_preview(version="1.7.3")
+    def from_sql(
+        cls: Type[DataConnectorType],
+        query: str,
+        session: Optional[snowpark.Session] = None,
+        ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
+        **kwargs: Any,
+    ) -> DataConnectorType:
+        session = session or sf_context.get_active_session()
+        source = data_source.DataFrameInfo(query)
+        return cls.from_sources(session, [source], ingestor_class=ingestor_class, **kwargs)
 
     @classmethod
     def from_dataset(
         cls: Type[DataConnectorType],
         ds: "dataset.Dataset",
         ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> DataConnectorType:
         dsv = ds.selected_version
         assert dsv is not None
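A usage sketch based on the signatures above: the new from_sql classmethod (private preview as of 1.7.3) builds a connector straight from a query string, and the DataFrame-based constructor now delegates to it. The table name and the active-session lookup are illustrative assumptions, not part of the package.

from snowflake.ml.data.data_connector import DataConnector
from snowflake.snowpark import context as sf_context

session = sf_context.get_active_session()  # assumes an active Snowpark session

# new in 1.7.3 (private preview): connector from a SQL query
connector = DataConnector.from_sql("SELECT * FROM SOME_TABLE", session=session)

# the DataFrame constructor now routes through from_sql using the frame's single query
connector_from_df = DataConnector.from_dataframe(session.table("SOME_TABLE"))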
@@ -75,9 +103,9 @@ class DataConnector:
     def from_sources(
         cls: Type[DataConnectorType],
         session: snowpark.Session,
-        sources:
+        sources: Sequence[data_source.DataSource],
         ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> DataConnectorType:
         ingestor_class = ingestor_class or cls.DEFAULT_INGESTOR_CLASS
         ingestor = ingestor_class.from_sources(session, sources)
@@ -130,7 +158,11 @@ class DataConnector:
         func_params_to_log=["batch_size", "shuffle", "drop_last_batch"],
     )
     def to_torch_datapipe(
-        self,
+        self,
+        *,
+        batch_size: int,
+        shuffle: bool = False,
+        drop_last_batch: bool = True,
     ) -> "torch_data.IterDataPipe":  # type: ignore[type-arg]
         """Transform the Snowflake data into a ready-to-use Pytorch datapipe.
 
@@ -149,8 +181,13 @@
         """
         from snowflake.ml.data import torch_utils
 
+        expand_dims = self._kwargs.get("expand_dims", True)
         return torch_utils.TorchDataPipeWrapper(
-            self._ingestor,
+            self._ingestor,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            drop_last=drop_last_batch,
+            expand_dims=expand_dims,
         )
 
     @telemetry.send_api_usage_telemetry(
@@ -179,8 +216,13 @@
         """
         from snowflake.ml.data import torch_utils
 
+        expand_dims = self._kwargs.get("expand_dims", True)
         return torch_utils.TorchDatasetWrapper(
-            self._ingestor,
+            self._ingestor,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            drop_last=drop_last_batch,
+            expand_dims=expand_dims,
         )
 
     @telemetry.send_api_usage_telemetry(
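Following on from the sketch above, the batching arguments of to_torch_datapipe (and the torch dataset path) are now keyword-only, and the new expand_dims behavior is controlled through the connector's constructor kwargs rather than these methods. A hedged sketch:

# batch_size is keyword-only; shuffle and drop_last_batch keep their defaults
pipe = connector.to_torch_datapipe(batch_size=32, shuffle=True)

# expand_dims is not an argument here; it is read from self._kwargs, so it must be
# supplied when the connector is constructed, e.g. DataConnector(ingestor, expand_dims=False)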
snowflake/ml/data/data_ingestor.py CHANGED
@@ -6,6 +6,7 @@ from typing import (
     List,
     Optional,
     Protocol,
+    Sequence,
     Type,
     TypeVar,
 )
@@ -25,7 +26,7 @@ DataIngestorType = TypeVar("DataIngestorType", bound="DataIngestor")
 class DataIngestor(Protocol):
     @classmethod
     def from_sources(
-        cls: Type[DataIngestorType], session: snowpark.Session, sources:
+        cls: Type[DataIngestorType], session: snowpark.Session, sources: Sequence[data_source.DataSource]
     ) -> DataIngestorType:
         raise NotImplementedError
 
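The protocol now accepts any Sequence of data sources instead of a concrete list type. A minimal illustrative stub of the changed classmethod (only this method is shown; the real protocol defines further members, and the class name here is made up):

from typing import Sequence, Type, TypeVar

from snowflake import snowpark
from snowflake.ml.data import data_source

_T = TypeVar("_T", bound="StubIngestor")


class StubIngestor:
    def __init__(self, session: snowpark.Session, sources: Sequence[data_source.DataSource]) -> None:
        self._session = session
        self._sources = list(sources)

    @classmethod
    def from_sources(cls: Type[_T], session: snowpark.Session, sources: Sequence[data_source.DataSource]) -> _T:
        # Sequence is all the updated protocol requires, so tuples now type-check too
        return cls(session, sources)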
snowflake/ml/data/torch_utils.py CHANGED
@@ -17,6 +17,7 @@ class TorchDatasetWrapper(torch.utils.data.IterableDataset[Dict[str, Any]]):
         batch_size: Optional[int],
         shuffle: bool = False,
         drop_last: bool = False,
+        expand_dims: bool = True,
     ) -> None:
         """Not intended for direct usage. Use DataConnector.to_torch_dataset() instead"""
         squeeze = False
@@ -29,6 +30,7 @@ class TorchDatasetWrapper(torch.utils.data.IterableDataset[Dict[str, Any]]):
         self._shuffle = shuffle
         self._drop_last = drop_last
         self._squeeze_outputs = squeeze
+        self._expand_dims = expand_dims
 
     def __iter__(self) -> Iterator[Dict[str, Union[npt.NDArray[Any], List[Any]]]]:
         max_idx = 0
@@ -47,7 +49,10 @@ class TorchDatasetWrapper(torch.utils.data.IterableDataset[Dict[str, Any]]):
         ):
             # Skip indices during multi-process data loading to prevent data duplication
             if counter == filter_idx:
-                yield {
+                yield {
+                    k: _preprocess_array(v, squeeze=self._squeeze_outputs, expand_dims=self._expand_dims)
+                    for k, v in batch.items()
+                }
             if counter < max_idx:
                 counter += 1
             else:
@@ -58,13 +63,21 @@ class TorchDataPipeWrapper(TorchDatasetWrapper, torch.utils.data.IterDataPipe[Di
     """Wrap a DataIngestor into a PyTorch IterDataPipe"""
 
     def __init__(
-        self,
+        self,
+        ingestor: data_ingestor.DataIngestor,
+        *,
+        batch_size: int,
+        shuffle: bool = False,
+        drop_last: bool = False,
+        expand_dims: bool = True,
     ) -> None:
         """Not intended for direct usage. Use DataConnector.to_torch_datapipe() instead"""
-        super().__init__(ingestor, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
+        super().__init__(ingestor, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, expand_dims=expand_dims)
 
 
-def _preprocess_array(
+def _preprocess_array(
+    arr: npt.NDArray[Any], squeeze: bool = False, expand_dims: bool = True
+) -> Union[npt.NDArray[Any], List[np.object_]]:
     """Preprocesses batch column values."""
     single_dimensional = arr.ndim < 2 and not arr.dtype == np.object_
 
@@ -73,7 +86,7 @@ def _preprocess_array(arr: npt.NDArray[Any], squeeze: bool = False) -> Union[npt
         arr = arr.squeeze(axis=0)
 
     # For single dimensional data,
-    if single_dimensional:
+    if single_dimensional and expand_dims:
         axis = 0 if arr.ndim == 0 else 1
         arr = np.expand_dims(arr, axis=axis)
 
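A small numpy sketch of what the new expand_dims flag controls (the array values are made up): with expand_dims=True, the default that preserves the 1.7.2 behavior, a one-dimensional column batch becomes a column vector; with expand_dims=False the flat shape is kept, which is what FileSet relies on below.

import numpy as np

col = np.array([1.0, 2.0, 3.0])         # a batch of a scalar column, ndim == 1

expanded = np.expand_dims(col, axis=1)  # expand_dims=True path -> shape (3, 1)
flat = col                              # expand_dims=False path -> shape (3,)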
snowflake/ml/feature_store/examples/example_helper.py CHANGED
@@ -45,8 +45,9 @@ class ExampleHelper:
         """Return a dataframe object about descriptions of all examples."""
         root_dir = Path(__file__).parent
         rows = []
+        hide_folders = ["citibike_trip_features", "source_data"]
         for f_name in os.listdir(root_dir):
-            if os.path.isdir(os.path.join(root_dir, f_name)) and f_name[0].isalpha() and f_name
+            if os.path.isdir(os.path.join(root_dir, f_name)) and f_name[0].isalpha() and f_name not in hide_folders:
                 source_file_path = root_dir.joinpath(f"{f_name}/source.yaml")
                 source_dict = self._read_yaml(str(source_file_path))
                 rows.append((f_name, source_dict["model_category"], source_dict["desc"], source_dict["label_columns"]))
snowflake/ml/fileset/fileset.py CHANGED
@@ -11,11 +11,9 @@ from snowflake.ml._internal.exceptions import (
     fileset_error_messages,
     fileset_errors,
 )
-from snowflake.ml._internal.utils import
-
-
-    snowpark_dataframe_utils,
-)
+from snowflake.ml._internal.utils import identifier, snowpark_dataframe_utils
+from snowflake.ml.data import data_connector
+from snowflake.ml.data._internal import arrow_ingestor
 from snowflake.ml.fileset import sfcfs
 from snowflake.snowpark import exceptions as snowpark_exceptions, functions
 
@@ -285,6 +283,16 @@ class FileSet:
         """Get the Snowflake absolute path to this FileSet directory."""
         return _fileset_absolute_path(self._target_stage_loc, self.name)
 
+    def _to_data_connector(self) -> data_connector.DataConnector:
+        self._fs.optimize_read(self._list_files())
+        ingester = arrow_ingestor.ArrowIngestor(
+            self._snowpark_session,
+            self._list_files(),
+            format="parquet",
+            filesystem=self._fs,
+        )
+        return data_connector.DataConnector(ingester, expand_dims=False)
+
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
     )
@@ -362,13 +370,9 @@ class FileSet:
         ----
         {'_COL_1':[10]}
         """
-
-
-
-        self._fs.optimize_read(self._list_files())
-
-        input_dp = IterableWrapper(self._list_files())
-        return torch_datapipe_module.ReadAndParseParquet(input_dp, self._fs, batch_size, shuffle, drop_last_batch)
+        return self._to_data_connector().to_torch_datapipe(
+            batch_size=batch_size, shuffle=shuffle, drop_last_batch=drop_last_batch
+        )
 
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
@@ -402,12 +406,8 @@ class FileSet:
         ----
         {'_COL_1': <tf.Tensor: shape=(1,), dtype=int64, numpy=[10]>}
         """
-
-
-        self._fs.optimize_read(self._list_files())
-
-        return tf_dataset_module.read_and_parse_parquet(
-            self._list_files(), self._fs, batch_size, shuffle, drop_last_batch
+        return self._to_data_connector().to_tf_dataset(
+            batch_size=batch_size, shuffle=shuffle, drop_last_batch=drop_last_batch
         )
 
     @telemetry.send_api_usage_telemetry(
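With these two hunks FileSet stops parsing parquet itself; both readers delegate to the shared DataConnector path shown earlier. Roughly, for an existing FileSet instance fs:

# fs.to_torch_datapipe(batch_size, shuffle, drop_last_batch) is now equivalent to
dp = fs._to_data_connector().to_torch_datapipe(batch_size=32, shuffle=False, drop_last_batch=True)

# and fs.to_tf_dataset(...) to
ds = fs._to_data_connector().to_tf_dataset(batch_size=32, shuffle=False, drop_last_batch=True)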
snowflake/ml/model/_client/model/model_version_impl.py CHANGED
@@ -447,13 +447,15 @@ class ModelVersion(lineage_node.LineageNode):
         target_function_info = functions[0]
 
         if service_name:
+            database_name_id, schema_name_id, service_name_id = sql_identifier.parse_fully_qualified_name(service_name)
+
             return self._model_ops.invoke_method(
                 method_name=sql_identifier.SqlIdentifier(target_function_info["name"]),
                 signature=target_function_info["signature"],
                 X=X,
-                database_name=
-                schema_name=
-                service_name=
+                database_name=database_name_id,
+                schema_name=schema_name_id,
+                service_name=service_name_id,
                 strict_input_validation=strict_input_validation,
                 statement_params=statement_params,
             )
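A short sketch of the parsing step that replaces the removed arguments (the identifier values are made up): parse_fully_qualified_name splits a dotted service name into its database, schema, and object parts, which are then forwarded to invoke_method.

from snowflake.ml._internal.utils import sql_identifier

db, schema, service = sql_identifier.parse_fully_qualified_name("MY_DB.MY_SCHEMA.MY_SERVICE")
# db and schema are presumably None when the name is not fully qualified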
snowflake/ml/model/_client/ops/model_ops.py CHANGED
@@ -168,14 +168,10 @@ class ModelOperator:
         schema_name: Optional[sql_identifier.SqlIdentifier],
         model_name: sql_identifier.SqlIdentifier,
         version_name: sql_identifier.SqlIdentifier,
+        model_exists: bool,
         statement_params: Optional[Dict[str, Any]] = None,
     ) -> None:
-        if
-            database_name=database_name,
-            schema_name=schema_name,
-            model_name=model_name,
-            statement_params=statement_params,
-        ):
+        if model_exists:
             return self._model_version_client.add_version_from_model_version(
                 source_database_name=source_database_name,
                 source_schema_name=source_schema_name,
snowflake/ml/model/_client/sql/model_version.py CHANGED
@@ -10,6 +10,7 @@ from snowflake.ml._internal.utils import (
     sql_identifier,
 )
 from snowflake.ml.model._client.sql import _base
+from snowflake.ml.model._model_composer.model_method import constants
 from snowflake.snowpark import dataframe, functions as F, row, types as spt
 from snowflake.snowpark._internal import utils as snowpark_utils
 
@@ -333,6 +334,11 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
 
         args_sql = ", ".join(args_sql_list)
 
+        wide_input = len(input_args) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
+        if wide_input:
+            input_args_sql = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
+            args_sql = f"object_construct_keep_null({input_args_sql})"
+
         sql = textwrap.dedent(
             f"""WITH {','.join(with_statements)}
             SELECT *,
@@ -412,6 +418,11 @@
 
         args_sql = ", ".join(args_sql_list)
 
+        wide_input = len(input_args) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
+        if wide_input:
+            input_args_sql = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
+            args_sql = f"object_construct_keep_null({input_args_sql})"
+
         sql = textwrap.dedent(
             f"""WITH {','.join(with_statements)}
             SELECT *,
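When a method has more than SNOWPARK_UDF_INPUT_COL_LIMIT (500) input columns, the generated call packs all columns into a single OBJECT argument via object_construct_keep_null instead of listing them positionally. A sketch of the resulting expression, with three illustrative columns standing in for the more-than-500 real ones (the diff uses arg.identifier(); plain strings are used here for brevity):

input_args = ["C1", "C2", "C3"]  # imagine more than 500 of these

# narrow case: positional arguments
args_sql = ", ".join(input_args)

# wide case: one OBJECT argument
input_args_sql = ", ".join(f"'{c}', {c}" for c in input_args)
args_sql = f"object_construct_keep_null({input_args_sql})"
# -> "object_construct_keep_null('C1', C1, 'C2', C2, 'C3', C3)"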
snowflake/ml/model/_model_composer/model_composer.py CHANGED
@@ -88,6 +88,7 @@ class ModelComposer:
         pip_requirements: Optional[List[str]] = None,
         target_platforms: Optional[List[model_types.TargetPlatform]] = None,
         python_version: Optional[str] = None,
+        user_files: Optional[Dict[str, List[str]]] = None,
         ext_modules: Optional[List[ModuleType]] = None,
         code_paths: Optional[List[str]] = None,
         task: model_types.Task = model_types.Task.UNKNOWN,
@@ -97,9 +98,12 @@ class ModelComposer:
         options = model_types.BaseModelSaveOption()
 
         if not snowpark_utils.is_in_stored_procedure():  # type: ignore[no-untyped-call]
-            snowml_matched_versions = env_utils.
-
-
+            snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
+                self.session,
+                reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={snowml_env.VERSION}")],
+                python_version=python_version or snowml_env.PYTHON_VERSION,
+                statement_params=self._statement_params,
+            ).get(env_utils.SNOWPARK_ML_PKG_NAME, [])
 
             if len(snowml_matched_versions) < 1 and options.get("embed_local_ml_library", False) is False:
                 logging.info(
@@ -131,6 +135,7 @@ class ModelComposer:
             model_meta=self.packager.meta,
             model_rel_path=pathlib.PurePosixPath(ModelComposer.MODEL_DIR_REL_PATH),
             options=options,
+            user_files=user_files,
            data_sources=self._get_data_sources(model, sample_input_data),
            target_platforms=target_platforms,
        )
snowflake/ml/model/_model_composer/model_manifest/model_manifest.py CHANGED
@@ -2,7 +2,7 @@ import collections
 import logging
 import pathlib
 import warnings
-from typing import List, Optional, cast
+from typing import Dict, List, Optional, cast
 
 import yaml
 
@@ -11,9 +11,11 @@ from snowflake.ml.data import data_source
 from snowflake.ml.model import type_hints
 from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
 from snowflake.ml.model._model_composer.model_method import (
+    constants,
     function_generator,
     model_method,
 )
+from snowflake.ml.model._model_composer.model_user_file import model_user_file
 from snowflake.ml.model._packager.model_meta import (
     model_meta as model_meta_api,
     model_meta_schema,
@@ -30,9 +32,11 @@ class ModelManifest:
         workspace_path: A local path where model related files should be dumped to.
         runtimes: A list of ModelRuntime objects managing the runtimes and environment in the MODEL object.
         methods: A list of ModelMethod objects managing the method we registered to the MODEL object.
+        user_files: A list of ModelUserFile objects managing extra files uploaded to the workspace.
     """
 
     MANIFEST_FILE_REL_PATH = "MANIFEST.yml"
+    _ENABLE_USER_FILES = False
     _DEFAULT_RUNTIME_NAME = "python_runtime"
 
     def __init__(self, workspace_path: pathlib.Path) -> None:
@@ -42,6 +46,7 @@ class ModelManifest:
         self,
         model_meta: model_meta_api.ModelMetadata,
         model_rel_path: pathlib.PurePosixPath,
+        user_files: Optional[Dict[str, List[str]]] = None,
         options: Optional[type_hints.ModelSaveOption] = None,
         data_sources: Optional[List[data_source.DataSource]] = None,
         target_platforms: Optional[List[type_hints.TargetPlatform]] = None,
@@ -79,6 +84,7 @@ class ModelManifest:
 
         self.function_generator = function_generator.FunctionGenerator(model_dir_rel_path=model_rel_path)
         self.methods: List[model_method.ModelMethod] = []
+
         for target_method in model_meta.signatures.keys():
             method = model_method.ModelMethod(
                 model_meta=model_meta,
@@ -88,11 +94,21 @@ class ModelManifest:
                 is_partitioned_function=model_meta.function_properties.get(target_method, {}).get(
                     model_meta_schema.FunctionProperties.PARTITIONED.value, False
                 ),
+                wide_input=len(model_meta.signatures[target_method].inputs) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT,
                 options=model_method.get_model_method_options_from_options(options, target_method),
             )
 
             self.methods.append(method)
 
+        self.user_files: List[model_user_file.ModelUserFile] = []
+
+        if user_files is not None:
+            for subdirectory, paths in user_files.items():
+                for path in paths:
+                    self.user_files.append(
+                        model_user_file.ModelUserFile(pathlib.PurePosixPath(subdirectory), pathlib.Path(path))
+                    )
+
         method_name_counter = collections.Counter([method.method_name for method in self.methods])
         dup_method_names = [k for k, v in method_name_counter.items() if v > 1]
         if dup_method_names:
@@ -129,6 +145,9 @@ class ModelManifest:
             ],
         )
 
+        if self._ENABLE_USER_FILES:
+            manifest_dict["user_files"] = [user_file.save(self.workspace_path) for user_file in self.user_files]
+
         lineage_sources = self._extract_lineage_info(data_sources)
         if lineage_sources:
             manifest_dict["lineage_sources"] = lineage_sources
snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py CHANGED
@@ -94,5 +94,6 @@ class ModelManifestDict(TypedDict):
     runtimes: Required[Dict[str, ModelRuntimeDict]]
     methods: Required[List[ModelMethodDict]]
     user_data: NotRequired[Dict[str, Any]]
+    user_files: NotRequired[List[str]]
     lineage_sources: NotRequired[List[LineageSourceDict]]
     target_platforms: NotRequired[List[str]]
snowflake/ml/model/_model_composer/model_method/constants.py ADDED
@@ -0,0 +1 @@
+SNOWPARK_UDF_INPUT_COL_LIMIT = 500
snowflake/ml/model/_model_composer/model_method/function_generator.py CHANGED
@@ -43,6 +43,7 @@ class FunctionGenerator:
         target_method: str,
         function_type: str,
         is_partitioned_function: bool = False,
+        wide_input: bool = False,
         options: Optional[FunctionGenerateOptions] = None,
     ) -> None:
         import importlib_resources
@@ -70,6 +71,7 @@
             model_dir_name=self.model_dir_rel_path.name,
             target_method=target_method,
             max_batch_size=options.get("max_batch_size", None),
+            wide_input=wide_input,
             function_name=FunctionGenerator.FUNCTION_NAME,
         )
         with open(function_file_path, "w", encoding="utf-8") as f:
snowflake/ml/model/_model_composer/model_method/infer_function.py_template CHANGED
@@ -43,7 +43,7 @@ dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
 
 
 # Actual function
-@vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE)
+@vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE, flatten_object_input={wide_input})
 def {function_name}(df: pd.DataFrame) -> dict:
     df.columns = input_cols
     input_df = df.astype(dtype=dtype_map)
snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template CHANGED
@@ -48,7 +48,7 @@ dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
 
 # Actual table function
 class {function_name}:
-    @vectorized(input=pd.DataFrame)
+    @vectorized(input=pd.DataFrame, flatten_object_input={wide_input})
     def end_partition(self, df: pd.DataFrame) -> pd.DataFrame:
         df.columns = input_cols
         input_df = df.astype(dtype=dtype_map)
snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template CHANGED
@@ -43,7 +43,7 @@ dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
 
 # Actual table function
 class {function_name}:
-    @vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE)
+    @vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE, flatten_object_input={wide_input})
     def process(self, df: pd.DataFrame) -> pd.DataFrame:
         df.columns = input_cols
         input_df = df.astype(dtype=dtype_map)
snowflake/ml/model/_model_composer/model_method/model_method.py CHANGED
@@ -7,7 +7,10 @@ from typing_extensions import NotRequired
 from snowflake.ml._internal.utils import sql_identifier
 from snowflake.ml.model import model_signature, type_hints
 from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
-from snowflake.ml.model._model_composer.model_method import
+from snowflake.ml.model._model_composer.model_method import (
+    constants,
+    function_generator,
+)
 from snowflake.ml.model._packager.model_meta import model_meta as model_meta_api
 from snowflake.snowpark._internal import type_utils
 
@@ -64,6 +67,7 @@ class ModelMethod:
         runtime_name: str,
         function_generator: function_generator.FunctionGenerator,
         is_partitioned_function: bool = False,
+        wide_input: bool = False,
         options: Optional[ModelMethodOptions] = None,
     ) -> None:
         self.model_meta = model_meta
@@ -71,6 +75,7 @@
         self.function_generator = function_generator
         self.is_partitioned_function = is_partitioned_function
         self.runtime_name = runtime_name
+        self.wide_input = wide_input
         self.options = options or {}
         try:
             self.method_name = sql_identifier.SqlIdentifier(
@@ -114,12 +119,15 @@
             self.target_method,
             self.function_type,
             self.is_partitioned_function,
+            self.wide_input,
             options=options,
         )
         input_list = [
             ModelMethod._get_method_arg_from_feature(ft, case_sensitive=self.options.get("case_sensitive", False))
             for ft in self.model_meta.signatures[self.target_method].inputs
         ]
+        if len(input_list) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT:
+            input_list = [{"name": "INPUT", "type": "OBJECT"}]
         input_name_counter = collections.Counter([input_info["name"] for input_info in input_list])
         dup_input_names = [k for k, v in input_name_counter.items() if v > 1]
         if dup_input_names:
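On the manifest side the same threshold collapses the declared inputs into a single OBJECT-typed argument, and the flag reaches the UDF templates as flatten_object_input. A condensed sketch of that fallback (the feature names and types are illustrative):

SNOWPARK_UDF_INPUT_COL_LIMIT = 500  # from model_method/constants.py

input_list = [{"name": f"F{i}", "type": "FLOAT"} for i in range(600)]
if len(input_list) > SNOWPARK_UDF_INPUT_COL_LIMIT:
    # the method is declared with one OBJECT input; the vectorized UDF unpacks it
    # because flatten_object_input=True is rendered into the template
    input_list = [{"name": "INPUT", "type": "OBJECT"}]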
snowflake/ml/model/_model_composer/model_user_file/model_user_file.py ADDED
@@ -0,0 +1,27 @@
+import os
+import pathlib
+
+from snowflake.ml._internal import file_utils
+
+
+class ModelUserFile:
+    """Class representing a user provided file.
+
+    Attributes:
+        subdirectory_name: A local path where model related files should be dumped to.
+        local_path: A list of ModelRuntime objects managing the runtimes and environment in the MODEL object.
+    """
+
+    USER_FILES_DIR_REL_PATH = "user_files"
+
+    def __init__(self, subdirectory_name: pathlib.PurePosixPath, local_path: pathlib.Path) -> None:
+        self.subdirectory_name = subdirectory_name
+        self.local_path = local_path
+
+    def save(self, workspace_path: pathlib.Path) -> str:
+        user_files_path = workspace_path / ModelUserFile.USER_FILES_DIR_REL_PATH / self.subdirectory_name
+        user_files_path.mkdir(parents=True, exist_ok=True)
+
+        # copy the file to the workspace
+        file_utils.copy_file_or_tree(str(self.local_path), str(user_files_path))
+        return os.path.join(self.subdirectory_name, self.local_path.name)
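The new ModelUserFile helper copies a local file into user_files/<subdirectory> inside the model workspace and returns the workspace-relative path; ModelComposer.save() accepts the files as a mapping of subdirectory to local paths, and writing them into MANIFEST.yml is still gated behind _ENABLE_USER_FILES = False. A sketch of the helper on its own (the paths are made up and assumed to exist):

import pathlib
import tempfile

from snowflake.ml.model._model_composer.model_user_file import model_user_file

workspace = pathlib.Path(tempfile.mkdtemp())
uf = model_user_file.ModelUserFile(pathlib.PurePosixPath("configs"), pathlib.Path("/tmp/settings.yaml"))

# copies /tmp/settings.yaml under <workspace>/user_files/configs/
# and returns "configs/settings.yaml"
rel_path = uf.save(workspace)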
snowflake/ml/model/_packager/model_handlers/_utils.py CHANGED
@@ -1,7 +1,8 @@
 import json
 import os
+import pathlib
 import warnings
-from typing import Any, Callable, Iterable, List, Optional, Sequence, cast
+from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, cast
 
 import numpy as np
 import numpy.typing as npt
@@ -118,7 +119,7 @@ def get_explainability_supported_background(
     meta: model_meta.ModelMetadata,
    explain_target_method: Optional[str],
 ) -> pd.DataFrame:
-    if sample_input_data is None:
+    if sample_input_data is None or explain_target_method is None:
         return None
 
     if isinstance(sample_input_data, pd.DataFrame):
@@ -223,3 +224,27 @@ def get_explain_target_method(
         if method in target_methods_list:
             return method
     return None
+
+
+def save_transformers_config_with_auto_map(local_model_path: str) -> None:
+    import huggingface_hub
+
+    for f_path in pathlib.Path(local_model_path).iterdir():
+        if f_path.name in ["config.json", "tokenizer_config.json"]:
+            with open(f_path) as f:
+                config_dict = json.load(f)
+
+            # a. get repository and class_path from configs
+            auto_map_configs = cast(Dict[str, str], config_dict.get("auto_map", {}))
+            for config_name, config_value in auto_map_configs.items():
+                repository, _, class_path = config_value.rpartition("--")
+
+                # b. download required configs from hf hub
+                if repository:
+                    huggingface_hub.snapshot_download(repo_id=repository, local_dir=local_model_path)
+
+                # c. update config files
+                config_dict["auto_map"][config_name] = class_path
+
+            with open(f_path, "w") as f:
+                json.dump(config_dict, f)
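The new save_transformers_config_with_auto_map helper rewrites Hugging Face auto_map entries of the form "<repo>--<module.Class>": it snapshots the referenced repository next to the model and keeps only the local class path in the config. A sketch of the transformation on a config fragment (the repo and class names are made up, and the download call is shown as a comment):

config_dict = {"auto_map": {"AutoModel": "some-org/custom-model--modeling_custom.CustomModel"}}

for name, value in config_dict["auto_map"].items():
    repository, _, class_path = value.rpartition("--")
    # huggingface_hub.snapshot_download(repo_id=repository, local_dir=local_model_path) runs here
    config_dict["auto_map"][name] = class_path

# config_dict["auto_map"] == {"AutoModel": "modeling_custom.CustomModel"}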
snowflake/ml/model/_packager/model_handlers/catboost.py CHANGED
@@ -94,8 +94,8 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
             sample_input_data=sample_input_data,
             get_prediction_fn=get_prediction,
         )
-        model_task_and_output = model_task_utils.
-        model_meta.task =
+        model_task_and_output = model_task_utils.resolve_model_task_and_output_type(model, model_meta.task)
+        model_meta.task = model_task_and_output.task
         if enable_explainability:
             explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
             model_meta = handlers_utils.add_explain_method_signature(
@@ -227,7 +227,7 @@
             import shap
 
             explainer = shap.TreeExplainer(raw_model)
-            df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X)
+            df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer.shap_values(X))
             return model_signature_utils.rename_pandas_df(df, signature.outputs)
 
         if target_method == "explain":