snowflake-ml-python 1.7.1__py3-none-any.whl → 1.7.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +16 -8
- snowflake/cortex/_classify_text.py +12 -1
- snowflake/cortex/_complete.py +82 -13
- snowflake/cortex/_embed_text_1024.py +9 -2
- snowflake/cortex/_embed_text_768.py +9 -2
- snowflake/cortex/_extract_answer.py +9 -2
- snowflake/cortex/_sentiment.py +9 -2
- snowflake/cortex/_summarize.py +9 -2
- snowflake/cortex/_translate.py +9 -2
- snowflake/ml/_internal/env_utils.py +7 -52
- snowflake/ml/_internal/utils/identifier.py +4 -2
- snowflake/ml/_internal/utils/jwt_generator.py +141 -0
- snowflake/ml/data/__init__.py +3 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +4 -4
- snowflake/ml/data/data_connector.py +53 -11
- snowflake/ml/data/data_ingestor.py +2 -1
- snowflake/ml/data/torch_utils.py +18 -5
- snowflake/ml/feature_store/examples/example_helper.py +2 -1
- snowflake/ml/fileset/fileset.py +18 -18
- snowflake/ml/model/_client/model/model_version_impl.py +24 -8
- snowflake/ml/model/_client/ops/model_ops.py +2 -6
- snowflake/ml/model/_client/ops/service_ops.py +12 -7
- snowflake/ml/model/_client/sql/model_version.py +11 -0
- snowflake/ml/model/_client/sql/stage.py +1 -1
- snowflake/ml/model/_model_composer/model_composer.py +8 -3
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_model_composer/model_method/constants.py +1 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -0
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +9 -1
- snowflake/ml/model/_model_composer/model_user_file/model_user_file.py +27 -0
- snowflake/ml/model/_packager/model_handlers/_utils.py +27 -2
- snowflake/ml/model/_packager/model_handlers/catboost.py +3 -3
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +5 -1
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +55 -20
- snowflake/ml/model/_packager/model_handlers/sklearn.py +10 -9
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +66 -28
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -17
- snowflake/ml/model/_packager/model_handlers/xgboost.py +3 -3
- snowflake/ml/model/_packager/model_meta/model_meta.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
- snowflake/ml/model/_packager/model_task/model_task_utils.py +3 -2
- snowflake/ml/model/_signatures/pandas_handler.py +1 -1
- snowflake/ml/model/_signatures/snowpark_handler.py +8 -2
- snowflake/ml/model/_signatures/utils.py +0 -1
- snowflake/ml/model/type_hints.py +1 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +0 -8
- snowflake/ml/modeling/_internal/model_transformer_builder.py +0 -13
- snowflake/ml/modeling/pipeline/pipeline.py +6 -176
- snowflake/ml/modeling/xgboost/xgb_classifier.py +161 -88
- snowflake/ml/modeling/xgboost/xgb_regressor.py +160 -85
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +160 -85
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +160 -85
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +5 -170
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +9 -9
- snowflake/ml/monitoring/entities/model_monitor_config.py +28 -2
- snowflake/ml/monitoring/model_monitor.py +26 -11
- snowflake/ml/registry/_manager/model_manager.py +70 -33
- snowflake/ml/registry/registry.py +53 -34
- snowflake/ml/utils/authentication.py +75 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/METADATA +120 -53
- {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/RECORD +71 -74
- {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/utils/retryable_http.py +0 -39
- snowflake/ml/fileset/parquet_parser.py +0 -170
- snowflake/ml/fileset/tf_dataset.py +0 -88
- snowflake/ml/fileset/torch_datapipe.py +0 -57
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +0 -151
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_trainer.py +0 -66
- snowflake/ml/monitoring/entities/output_score_type.py +0 -90
- {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/top_level.txt +0 -0
snowflake/ml/_internal/utils/jwt_generator.py
ADDED
@@ -0,0 +1,141 @@
+import base64
+import hashlib
+import logging
+from datetime import datetime, timedelta, timezone
+from typing import Optional
+
+import jwt
+from cryptography.hazmat.primitives import serialization
+from cryptography.hazmat.primitives.asymmetric import types
+
+logger = logging.getLogger(__name__)
+
+ISSUER = "iss"
+EXPIRE_TIME = "exp"
+ISSUE_TIME = "iat"
+SUBJECT = "sub"
+
+
+class JWTGenerator:
+    """
+    Creates and signs a JWT with the specified private key file, username, and account identifier. The JWTGenerator
+    keeps the generated token and only regenerates the token if a specified period of time has passed.
+    """
+
+    _DEFAULT_LIFETIME = timedelta(minutes=59)  # The tokens will have a 59-minute lifetime
+    _DEFAULT_RENEWAL_DELTA = timedelta(minutes=54)  # Tokens will be renewed after 54 minutes
+    ALGORITHM = "RS256"  # Tokens will be generated using RSA with SHA256
+
+    def __init__(
+        self,
+        account: str,
+        user: str,
+        private_key: types.PRIVATE_KEY_TYPES,
+        lifetime: Optional[timedelta] = None,
+        renewal_delay: Optional[timedelta] = None,
+    ) -> None:
+        """
+        Create a new JWTGenerator object.
+
+        Args:
+            account: The account identifier.
+            user: The username.
+            private_key: The private key used to sign the JWT.
+            lifetime: The lifetime of the token.
+            renewal_delay: The time before the token expires to renew it.
+        """
+
+        # Construct the fully qualified name of the user in uppercase.
+        self.account = JWTGenerator._prepare_account_name_for_jwt(account)
+        self.user = user.upper()
+        self.qualified_username = self.account + "." + self.user
+        self.private_key = private_key
+        self.public_key_fp = JWTGenerator._calculate_public_key_fingerprint(self.private_key)
+
+        self.issuer = self.qualified_username + "." + self.public_key_fp
+        self.lifetime = lifetime or JWTGenerator._DEFAULT_LIFETIME
+        self.renewal_delay = renewal_delay or JWTGenerator._DEFAULT_RENEWAL_DELTA
+        self.renew_time = datetime.now(timezone.utc)
+        self.token: Optional[str] = None
+
+        logger.info(
+            """Creating JWTGenerator with arguments
+            account : %s, user : %s, lifetime : %s, renewal_delay : %s""",
+            self.account,
+            self.user,
+            self.lifetime,
+            self.renewal_delay,
+        )
+
+    @staticmethod
+    def _prepare_account_name_for_jwt(raw_account: str) -> str:
+        account = raw_account
+        if ".global" not in account:
+            # Handle the general case.
+            idx = account.find(".")
+            if idx > 0:
+                account = account[0:idx]
+        else:
+            # Handle the replication case.
+            idx = account.find("-")
+            if idx > 0:
+                account = account[0:idx]
+        # Use uppercase for the account identifier.
+        return account.upper()
+
+    def get_token(self) -> str:
+        now = datetime.now(timezone.utc)  # Fetch the current time
+        if self.token is not None and self.renew_time > now:
+            return self.token
+
+        # If the token has expired or doesn't exist, regenerate the token.
+        logger.info(
+            "Generating a new token because the present time (%s) is later than the renewal time (%s)",
+            now,
+            self.renew_time,
+        )
+        # Calculate the next time we need to renew the token.
+        self.renew_time = now + self.renewal_delay
+
+        # Create our payload
+        payload = {
+            # Set the issuer to the fully qualified username concatenated with the public key fingerprint.
+            ISSUER: self.issuer,
+            # Set the subject to the fully qualified username.
+            SUBJECT: self.qualified_username,
+            # Set the issue time to now.
+            ISSUE_TIME: now,
+            # Set the expiration time, based on the lifetime specified for this object.
+            EXPIRE_TIME: now + self.lifetime,
+        }
+
+        # Regenerate the actual token
+        token = jwt.encode(payload, key=self.private_key, algorithm=JWTGenerator.ALGORITHM)
+        # If you are using a version of PyJWT prior to 2.0, jwt.encode returns a byte string instead of a string.
+        # If the token is a byte string, convert it to a string.
+        if isinstance(token, bytes):
+            token = token.decode("utf-8")
+        self.token = token
+        logger.info(
+            "Generated a JWT with the following payload: %s",
+            jwt.decode(self.token, key=self.private_key.public_key(), algorithms=[JWTGenerator.ALGORITHM]),
+        )
+
+        return token
+
+    @staticmethod
+    def _calculate_public_key_fingerprint(private_key: types.PRIVATE_KEY_TYPES) -> str:
+        # Get the raw bytes of public key.
+        public_key_raw = private_key.public_key().public_bytes(
+            serialization.Encoding.DER, serialization.PublicFormat.SubjectPublicKeyInfo
+        )
+
+        # Get the sha256 hash of the raw bytes.
+        sha256hash = hashlib.sha256()
+        sha256hash.update(public_key_raw)
+
+        # Base64-encode the value and prepend the prefix 'SHA256:'.
+        public_key_fp = "SHA256:" + base64.b64encode(sha256hash.digest()).decode("utf-8")
+        logger.info("Public key fingerprint is %s", public_key_fp)
+
+        return public_key_fp
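For context, a minimal sketch of how this helper could be exercised on its own, assuming a PEM-encoded RSA key on disk; the key file name and the account/user values below are illustrative, not part of the package:

from datetime import timedelta

from cryptography.hazmat.primitives import serialization

from snowflake.ml._internal.utils.jwt_generator import JWTGenerator

# Load the signing key with the standard cryptography API (hypothetical file name).
with open("rsa_key.p8", "rb") as key_file:
    private_key = serialization.load_pem_private_key(key_file.read(), password=None)

# Lifetime and renewal delay default to 59 and 54 minutes when omitted.
generator = JWTGenerator("MYORG-MYACCOUNT", "MY_USER", private_key, lifetime=timedelta(minutes=30))
token = generator.get_token()  # RS256-signed JWT, cached until the renewal delay elapses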
snowflake/ml/data/__init__.py
CHANGED
@@ -1,5 +1,8 @@
+from pkgutil import extend_path
+
 from .data_connector import DataConnector
 from .data_ingestor import DataIngestor, DataIngestorType
 from .data_source import DataFrameInfo, DatasetInfo, DataSource

 __all__ = ["DataConnector", "DataSource", "DataFrameInfo", "DatasetInfo", "DataIngestor", "DataIngestorType"]
+__path__ = extend_path(__path__, __name__)
snowflake/ml/data/_internal/arrow_ingestor.py
CHANGED
@@ -2,7 +2,7 @@ import collections
 import logging
 import os
 import time
-from typing import Any, Deque, Dict, Iterator, List, Optional, Union
+from typing import Any, Deque, Dict, Iterator, List, Optional, Sequence, Union

 import numpy as np
 import numpy.typing as npt
@@ -47,7 +47,7 @@ class ArrowIngestor(data_ingestor.DataIngestor):
     def __init__(
         self,
         session: snowpark.Session,
-        data_sources:
+        data_sources: Sequence[data_source.DataSource],
         format: Optional[str] = None,
         **kwargs: Any,
     ) -> None:
@@ -60,14 +60,14 @@ class ArrowIngestor(data_ingestor.DataIngestor):
             kwargs: Miscellaneous arguments passed to underlying PyArrow Dataset initializer.
         """
         self._session = session
-        self._data_sources = data_sources
+        self._data_sources = list(data_sources)
         self._format = format
         self._kwargs = kwargs

         self._schema: Optional[pa.Schema] = None

     @classmethod
-    def from_sources(cls, session: snowpark.Session, sources:
+    def from_sources(cls, session: snowpark.Session, sources: Sequence[data_source.DataSource]) -> "ArrowIngestor":
         return cls(session, sources)

     @property
snowflake/ml/data/data_connector.py
CHANGED
@@ -1,5 +1,16 @@
 import os
-from typing import
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Sequence,
+    Type,
+    TypeVar,
+    cast,
+)

 import numpy.typing as npt
 from typing_extensions import deprecated
@@ -12,6 +23,7 @@ from snowflake.ml.modeling._internal.constants import (
     IN_ML_RUNTIME_ENV_VAR,
     USE_OPTIMIZED_DATA_INGESTOR,
 )
+from snowflake.snowpark import context as sf_context

 if TYPE_CHECKING:
     import pandas as pd
@@ -35,8 +47,10 @@ class DataConnector:
     def __init__(
         self,
         ingestor: data_ingestor.DataIngestor,
+        **kwargs: Any,
     ) -> None:
         self._ingestor = ingestor
+        self._kwargs = kwargs

     @classmethod
     @snowpark._internal.utils.private_preview(version="1.6.0")
@@ -44,20 +58,34 @@ class DataConnector:
         cls: Type[DataConnectorType],
         df: snowpark.DataFrame,
         ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> DataConnectorType:
         if len(df.queries["queries"]) != 1 or len(df.queries["post_actions"]) != 0:
             raise ValueError("DataFrames with multiple queries and/or post-actions not supported")
-
-
-
+        return cast(
+            DataConnectorType,
+            cls.from_sql(df.queries["queries"][0], session=df._session, ingestor_class=ingestor_class, **kwargs),
+        )
+
+    @classmethod
+    @snowpark._internal.utils.private_preview(version="1.7.3")
+    def from_sql(
+        cls: Type[DataConnectorType],
+        query: str,
+        session: Optional[snowpark.Session] = None,
+        ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
+        **kwargs: Any,
+    ) -> DataConnectorType:
+        session = session or sf_context.get_active_session()
+        source = data_source.DataFrameInfo(query)
+        return cls.from_sources(session, [source], ingestor_class=ingestor_class, **kwargs)

     @classmethod
     def from_dataset(
         cls: Type[DataConnectorType],
         ds: "dataset.Dataset",
         ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> DataConnectorType:
         dsv = ds.selected_version
         assert dsv is not None
@@ -75,9 +103,9 @@ class DataConnector:
     def from_sources(
         cls: Type[DataConnectorType],
         session: snowpark.Session,
-        sources:
+        sources: Sequence[data_source.DataSource],
         ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> DataConnectorType:
         ingestor_class = ingestor_class or cls.DEFAULT_INGESTOR_CLASS
         ingestor = ingestor_class.from_sources(session, sources)
@@ -130,7 +158,11 @@ class DataConnector:
         func_params_to_log=["batch_size", "shuffle", "drop_last_batch"],
     )
     def to_torch_datapipe(
-        self,
+        self,
+        *,
+        batch_size: int,
+        shuffle: bool = False,
+        drop_last_batch: bool = True,
     ) -> "torch_data.IterDataPipe":  # type: ignore[type-arg]
         """Transform the Snowflake data into a ready-to-use Pytorch datapipe.

@@ -149,8 +181,13 @@ class DataConnector:
         """
         from snowflake.ml.data import torch_utils

+        expand_dims = self._kwargs.get("expand_dims", True)
         return torch_utils.TorchDataPipeWrapper(
-            self._ingestor,
+            self._ingestor,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            drop_last=drop_last_batch,
+            expand_dims=expand_dims,
         )

     @telemetry.send_api_usage_telemetry(
@@ -179,8 +216,13 @@ class DataConnector:
         """
         from snowflake.ml.data import torch_utils

+        expand_dims = self._kwargs.get("expand_dims", True)
         return torch_utils.TorchDatasetWrapper(
-            self._ingestor,
+            self._ingestor,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            drop_last=drop_last_batch,
+            expand_dims=expand_dims,
         )

     @telemetry.send_api_usage_telemetry(
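The new from_sql constructor (private preview) builds a connector straight from a query string, and from_dataframe now routes through it. A hedged usage sketch, assuming an existing Snowpark session and table; the names below are placeholders:

from snowflake.ml.data import DataConnector
from snowflake.snowpark import Session

session = Session.builder.configs(connection_parameters).create()  # connection_parameters assumed to be defined

# Build a connector from SQL text; the query becomes a DataFrameInfo data source.
dc = DataConnector.from_sql("SELECT * FROM MY_DB.MY_SCHEMA.TRAINING_DATA", session=session)

pdf = dc.to_pandas()                      # materialize as a pandas DataFrame
ds = dc.to_torch_dataset(batch_size=32)   # or stream batches into PyTorch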
snowflake/ml/data/data_ingestor.py
CHANGED
@@ -6,6 +6,7 @@ from typing import (
     List,
     Optional,
     Protocol,
+    Sequence,
     Type,
     TypeVar,
 )
@@ -25,7 +26,7 @@ DataIngestorType = TypeVar("DataIngestorType", bound="DataIngestor")
 class DataIngestor(Protocol):
     @classmethod
     def from_sources(
-        cls: Type[DataIngestorType], session: snowpark.Session, sources:
+        cls: Type[DataIngestorType], session: snowpark.Session, sources: Sequence[data_source.DataSource]
     ) -> DataIngestorType:
         raise NotImplementedError

snowflake/ml/data/torch_utils.py
CHANGED
@@ -17,6 +17,7 @@ class TorchDatasetWrapper(torch.utils.data.IterableDataset[Dict[str, Any]]):
         batch_size: Optional[int],
         shuffle: bool = False,
         drop_last: bool = False,
+        expand_dims: bool = True,
     ) -> None:
         """Not intended for direct usage. Use DataConnector.to_torch_dataset() instead"""
         squeeze = False
@@ -29,6 +30,7 @@ class TorchDatasetWrapper(torch.utils.data.IterableDataset[Dict[str, Any]]):
         self._shuffle = shuffle
         self._drop_last = drop_last
         self._squeeze_outputs = squeeze
+        self._expand_dims = expand_dims

     def __iter__(self) -> Iterator[Dict[str, Union[npt.NDArray[Any], List[Any]]]]:
         max_idx = 0
@@ -47,7 +49,10 @@ class TorchDatasetWrapper(torch.utils.data.IterableDataset[Dict[str, Any]]):
         ):
             # Skip indices during multi-process data loading to prevent data duplication
             if counter == filter_idx:
-                yield {
+                yield {
+                    k: _preprocess_array(v, squeeze=self._squeeze_outputs, expand_dims=self._expand_dims)
+                    for k, v in batch.items()
+                }
             if counter < max_idx:
                 counter += 1
             else:
@@ -58,13 +63,21 @@ class TorchDataPipeWrapper(TorchDatasetWrapper, torch.utils.data.IterDataPipe[Di
     """Wrap a DataIngestor into a PyTorch IterDataPipe"""

     def __init__(
-        self,
+        self,
+        ingestor: data_ingestor.DataIngestor,
+        *,
+        batch_size: int,
+        shuffle: bool = False,
+        drop_last: bool = False,
+        expand_dims: bool = True,
     ) -> None:
         """Not intended for direct usage. Use DataConnector.to_torch_datapipe() instead"""
-        super().__init__(ingestor, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
+        super().__init__(ingestor, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, expand_dims=expand_dims)


-def _preprocess_array(
+def _preprocess_array(
+    arr: npt.NDArray[Any], squeeze: bool = False, expand_dims: bool = True
+) -> Union[npt.NDArray[Any], List[np.object_]]:
     """Preprocesses batch column values."""
     single_dimensional = arr.ndim < 2 and not arr.dtype == np.object_

@@ -73,7 +86,7 @@ def _preprocess_array(arr: npt.NDArray[Any], squeeze: bool = False) -> Union[npt
     arr = arr.squeeze(axis=0)

     # For single dimensional data,
-    if single_dimensional:
+    if single_dimensional and expand_dims:
         axis = 0 if arr.ndim == 0 else 1
         arr = np.expand_dims(arr, axis=axis)

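Roughly, the new expand_dims flag controls whether a single-dimensional batch column gains a trailing axis: with the default True it does, while FileSet passes expand_dims=False to keep the legacy flat shape. A small illustration with plain numpy; the values are made up:

import numpy as np

col = np.array([10, 11, 12])            # one batch column of 3 rows, shape (3,)
expanded = np.expand_dims(col, axis=1)  # shape (3, 1): default DataConnector behavior
flat = col                              # shape (3,):   what expand_dims=False preserves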
snowflake/ml/feature_store/examples/example_helper.py
CHANGED
@@ -45,8 +45,9 @@ class ExampleHelper:
         """Return a dataframe object about descriptions of all examples."""
         root_dir = Path(__file__).parent
         rows = []
+        hide_folders = ["citibike_trip_features", "source_data"]
         for f_name in os.listdir(root_dir):
-            if os.path.isdir(os.path.join(root_dir, f_name)) and f_name[0].isalpha() and f_name
+            if os.path.isdir(os.path.join(root_dir, f_name)) and f_name[0].isalpha() and f_name not in hide_folders:
                 source_file_path = root_dir.joinpath(f"{f_name}/source.yaml")
                 source_dict = self._read_yaml(str(source_file_path))
                 rows.append((f_name, source_dict["model_category"], source_dict["desc"], source_dict["label_columns"]))
snowflake/ml/fileset/fileset.py
CHANGED
@@ -11,11 +11,9 @@ from snowflake.ml._internal.exceptions import (
     fileset_error_messages,
     fileset_errors,
 )
-from snowflake.ml._internal.utils import
-
-
-    snowpark_dataframe_utils,
-)
+from snowflake.ml._internal.utils import identifier, snowpark_dataframe_utils
+from snowflake.ml.data import data_connector
+from snowflake.ml.data._internal import arrow_ingestor
 from snowflake.ml.fileset import sfcfs
 from snowflake.snowpark import exceptions as snowpark_exceptions, functions

@@ -285,6 +283,16 @@ class FileSet:
         """Get the Snowflake absolute path to this FileSet directory."""
         return _fileset_absolute_path(self._target_stage_loc, self.name)

+    def _to_data_connector(self) -> data_connector.DataConnector:
+        self._fs.optimize_read(self._list_files())
+        ingester = arrow_ingestor.ArrowIngestor(
+            self._snowpark_session,
+            self._list_files(),
+            format="parquet",
+            filesystem=self._fs,
+        )
+        return data_connector.DataConnector(ingester, expand_dims=False)
+
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
     )
@@ -362,13 +370,9 @@ class FileSet:
         ----
         {'_COL_1':[10]}
         """
-
-
-
-        self._fs.optimize_read(self._list_files())
-
-        input_dp = IterableWrapper(self._list_files())
-        return torch_datapipe_module.ReadAndParseParquet(input_dp, self._fs, batch_size, shuffle, drop_last_batch)
+        return self._to_data_connector().to_torch_datapipe(
+            batch_size=batch_size, shuffle=shuffle, drop_last_batch=drop_last_batch
+        )

     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
@@ -402,12 +406,8 @@ class FileSet:
         ----
         {'_COL_1': <tf.Tensor: shape=(1,), dtype=int64, numpy=[10]>}
         """
-
-
-        self._fs.optimize_read(self._list_files())
-
-        return tf_dataset_module.read_and_parse_parquet(
-            self._list_files(), self._fs, batch_size, shuffle, drop_last_batch
+        return self._to_data_connector().to_tf_dataset(
+            batch_size=batch_size, shuffle=shuffle, drop_last_batch=drop_last_batch
         )

     @telemetry.send_api_usage_telemetry(
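FileSet keeps its public readers but now routes them through the shared DataConnector/ArrowIngestor path, with expand_dims=False so output shapes stay as before. A hedged sketch of the unchanged call surface, assuming fs is an existing FileSet (creation omitted):

# fs: an existing snowflake.ml.fileset.fileset.FileSet instance
datapipe = fs.to_torch_datapipe(batch_size=32, shuffle=True, drop_last_batch=True)
tf_ds = fs.to_tf_dataset(batch_size=32, shuffle=False, drop_last_batch=True)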
snowflake/ml/model/_client/model/model_version_impl.py
CHANGED
@@ -14,7 +14,7 @@ from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
 from snowflake.ml.model._model_composer import model_composer
 from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
 from snowflake.ml.model._packager.model_handlers import snowmlmodel
-from snowflake.snowpark import Session, dataframe
+from snowflake.snowpark import Session, async_job, dataframe

 _TELEMETRY_PROJECT = "MLOps"
 _TELEMETRY_SUBPROJECT = "ModelManagement"
@@ -447,13 +447,15 @@ class ModelVersion(lineage_node.LineageNode):
         target_function_info = functions[0]

         if service_name:
+            database_name_id, schema_name_id, service_name_id = sql_identifier.parse_fully_qualified_name(service_name)
+
             return self._model_ops.invoke_method(
                 method_name=sql_identifier.SqlIdentifier(target_function_info["name"]),
                 signature=target_function_info["signature"],
                 X=X,
-                database_name=
-                schema_name=
-                service_name=
+                database_name=database_name_id,
+                schema_name=schema_name_id,
+                service_name=service_name_id,
                 strict_input_validation=strict_input_validation,
                 statement_params=statement_params,
             )
@@ -631,7 +633,8 @@ class ModelVersion(lineage_node.LineageNode):
         max_batch_rows: Optional[int] = None,
         force_rebuild: bool = False,
         build_external_access_integration: Optional[str] = None,
-
+        block: bool = True,
+    ) -> Union[str, async_job.AsyncJob]:
         """Create an inference service with the given spec.

         Args:
@@ -659,6 +662,9 @@ class ModelVersion(lineage_node.LineageNode):
             force_rebuild: Whether to force a model inference image rebuild.
             build_external_access_integration: (Deprecated) The external access integration for image build. This is
                 usually permitting access to conda & PyPI repositories.
+            block: A bool value indicating whether this function will wait until the service is available.
+                When it is ``False``, this function executes the underlying service creation asynchronously
+                and returns an :class:`AsyncJob`.
         """
         ...

@@ -679,7 +685,8 @@ class ModelVersion(lineage_node.LineageNode):
         max_batch_rows: Optional[int] = None,
         force_rebuild: bool = False,
         build_external_access_integrations: Optional[List[str]] = None,
-
+        block: bool = True,
+    ) -> Union[str, async_job.AsyncJob]:
         """Create an inference service with the given spec.

         Args:
@@ -707,6 +714,9 @@ class ModelVersion(lineage_node.LineageNode):
             force_rebuild: Whether to force a model inference image rebuild.
             build_external_access_integrations: The external access integrations for image build. This is usually
                 permitting access to conda & PyPI repositories.
+            block: A bool value indicating whether this function will wait until the service is available.
+                When it is ``False``, this function executes the underlying service creation asynchronously
+                and returns an :class:`AsyncJob`.
         """
         ...

@@ -742,7 +752,8 @@ class ModelVersion(lineage_node.LineageNode):
         force_rebuild: bool = False,
         build_external_access_integration: Optional[str] = None,
         build_external_access_integrations: Optional[List[str]] = None,
-
+        block: bool = True,
+    ) -> Union[str, async_job.AsyncJob]:
         """Create an inference service with the given spec.

         Args:
@@ -772,12 +783,16 @@ class ModelVersion(lineage_node.LineageNode):
                 usually permitting access to conda & PyPI repositories.
             build_external_access_integrations: The external access integrations for image build. This is usually
                 permitting access to conda & PyPI repositories.
+            block: A bool value indicating whether this function will wait until the service is available.
+                When it is False, this function executes the underlying service creation asynchronously
+                and returns an AsyncJob.

         Raises:
             ValueError: Illegal external access integration arguments.

         Returns:
-
+            If `block=True`, return result information about service creation from server.
+            Otherwise, return the service creation AsyncJob.
         """
         statement_params = telemetry.get_statement_params(
             project=_TELEMETRY_PROJECT,
@@ -829,6 +844,7 @@ class ModelVersion(lineage_node.LineageNode):
                 if build_external_access_integrations is None
                 else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
             ),
+            block=block,
             statement_params=statement_params,
         )

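With the new block argument, create_service can return immediately instead of streaming build logs until the service is ready. A hedged sketch; the keyword arguments other than block are assumptions about the existing signature, and the object names are placeholders:

# mv: a ModelVersion obtained from the registry (lookup omitted)
job = mv.create_service(
    service_name="MY_DB.MY_SCHEMA.MY_INFERENCE_SERVICE",  # assumed parameter name
    service_compute_pool="MY_COMPUTE_POOL",               # assumed parameter name
    image_repo="MY_DB.MY_SCHEMA.MY_IMAGE_REPO",           # assumed parameter name
    block=False,        # return a Snowpark AsyncJob instead of waiting
)
result = job.result()   # block later, once the deployment should be done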
snowflake/ml/model/_client/ops/model_ops.py
CHANGED
@@ -168,14 +168,10 @@ class ModelOperator:
         schema_name: Optional[sql_identifier.SqlIdentifier],
         model_name: sql_identifier.SqlIdentifier,
         version_name: sql_identifier.SqlIdentifier,
+        model_exists: bool,
         statement_params: Optional[Dict[str, Any]] = None,
     ) -> None:
-        if
-            database_name=database_name,
-            schema_name=schema_name,
-            model_name=model_name,
-            statement_params=statement_params,
-        ):
+        if model_exists:
             return self._model_version_client.add_version_from_model_version(
                 source_database_name=source_database_name,
                 source_schema_name=source_schema_name,
snowflake/ml/model/_client/ops/service_ops.py
CHANGED
@@ -6,7 +6,7 @@ import re
 import tempfile
 import threading
 import time
-from typing import Any, Dict, List, Optional, Tuple, cast
+from typing import Any, Dict, List, Optional, Tuple, Union, cast

 from packaging import version

@@ -15,7 +15,7 @@ from snowflake.ml._internal import file_utils
 from snowflake.ml._internal.utils import service_logger, snowflake_env, sql_identifier
 from snowflake.ml.model._client.service import model_deployment_spec
 from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
-from snowflake.snowpark import exceptions, row, session
+from snowflake.snowpark import async_job, exceptions, row, session
 from snowflake.snowpark._internal import utils as snowpark_utils

 module_logger = service_logger.get_logger(__name__, service_logger.LogColor.GREY)
@@ -107,8 +107,9 @@ class ServiceOperator:
         max_batch_rows: Optional[int],
         force_rebuild: bool,
         build_external_access_integrations: Optional[List[sql_identifier.SqlIdentifier]],
+        block: bool,
         statement_params: Optional[Dict[str, Any]] = None,
-    ) -> str:
+    ) -> Union[str, async_job.AsyncJob]:

         # Fall back to the registry's database and schema if not provided
         database_name = database_name or self._database_name
@@ -204,11 +205,15 @@ class ServiceOperator:
         log_thread = self._start_service_log_streaming(
             async_job, services, model_inference_service_exists, force_rebuild, statement_params
         )
-        log_thread.join()

-
-
-
+        if block:
+            log_thread.join()
+
+            res = cast(str, cast(List[row.Row], async_job.result())[0][0])
+            module_logger.info(f"Inference service {service_name} deployment complete: {res}")
+            return res
+        else:
+            return async_job

     def _start_service_log_streaming(
         self,
snowflake/ml/model/_client/sql/model_version.py
CHANGED
@@ -10,6 +10,7 @@ from snowflake.ml._internal.utils import (
     sql_identifier,
 )
 from snowflake.ml.model._client.sql import _base
+from snowflake.ml.model._model_composer.model_method import constants
 from snowflake.snowpark import dataframe, functions as F, row, types as spt
 from snowflake.snowpark._internal import utils as snowpark_utils

@@ -333,6 +334,11 @@ class ModelVersionSQLClient(_base._BaseSQLClient):

         args_sql = ", ".join(args_sql_list)

+        wide_input = len(input_args) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
+        if wide_input:
+            input_args_sql = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
+            args_sql = f"object_construct_keep_null({input_args_sql})"
+
         sql = textwrap.dedent(
             f"""WITH {','.join(with_statements)}
             SELECT *,
@@ -412,6 +418,11 @@ class ModelVersionSQLClient(_base._BaseSQLClient):

         args_sql = ", ".join(args_sql_list)

+        wide_input = len(input_args) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
+        if wide_input:
+            input_args_sql = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
+            args_sql = f"object_construct_keep_null({input_args_sql})"
+
         sql = textwrap.dedent(
             f"""WITH {','.join(with_statements)}
             SELECT *,