snowflake-ml-python 1.7.0__py3-none-any.whl → 1.7.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -0
- snowflake/cortex/_complete.py +107 -64
- snowflake/cortex/_finetune.py +273 -0
- snowflake/cortex/_sse_client.py +91 -28
- snowflake/cortex/_util.py +30 -1
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/jwt_generator.py +141 -0
- snowflake/ml/data/__init__.py +5 -0
- snowflake/ml/model/_client/model/model_version_impl.py +26 -12
- snowflake/ml/model/_client/ops/model_ops.py +51 -30
- snowflake/ml/model/_client/ops/service_ops.py +25 -9
- snowflake/ml/model/_client/sql/model.py +0 -14
- snowflake/ml/model/_client/sql/service.py +25 -1
- snowflake/ml/model/_client/sql/stage.py +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
- snowflake/ml/model/_packager/model_env/model_env.py +12 -0
- snowflake/ml/model/_packager/model_handlers/_utils.py +1 -1
- snowflake/ml/model/_packager/model_handlers/catboost.py +1 -1
- snowflake/ml/model/_packager/model_handlers/custom.py +3 -1
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +2 -1
- snowflake/ml/model/_packager/model_handlers/sklearn.py +50 -1
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +1 -1
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +23 -6
- snowflake/ml/model/_packager/model_handlers/torchscript.py +14 -14
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +2 -3
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +5 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -10
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -9
- snowflake/ml/model/_packager/model_task/model_task_utils.py +1 -1
- snowflake/ml/model/_signatures/core.py +63 -16
- snowflake/ml/model/_signatures/pandas_handler.py +71 -27
- snowflake/ml/model/_signatures/pytorch_handler.py +2 -2
- snowflake/ml/model/_signatures/snowpark_handler.py +2 -1
- snowflake/ml/model/_signatures/tensorflow_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +4 -1
- snowflake/ml/model/model_signature.py +38 -9
- snowflake/ml/model/type_hints.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +2 -4
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +2 -4
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +148 -1200
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +114 -238
- snowflake/ml/monitoring/entities/model_monitor_config.py +38 -12
- snowflake/ml/monitoring/model_monitor.py +12 -86
- snowflake/ml/registry/registry.py +28 -40
- snowflake/ml/utils/authentication.py +75 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.2.dist-info}/METADATA +116 -52
- {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.2.dist-info}/RECORD +51 -49
- {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.2.dist-info}/WHEEL +1 -1
- snowflake/ml/monitoring/entities/model_monitor_interval.py +0 -46
- snowflake/ml/monitoring/entities/output_score_type.py +0 -90
- {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.2.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.2.dist-info}/top_level.txt +0 -0
snowflake/cortex/_sse_client.py
CHANGED
@@ -1,73 +1,125 @@
|
|
1
|
-
|
1
|
+
import json
|
2
|
+
from typing import Any, Iterator, Optional
|
2
3
|
|
3
|
-
|
4
|
+
_FIELD_SEPARATOR = ":"
|
4
5
|
|
5
6
|
|
6
7
|
class Event:
|
7
|
-
|
8
|
+
"""Representation of an event from the event stream."""
|
9
|
+
|
10
|
+
def __init__(
|
11
|
+
self,
|
12
|
+
id: Optional[str] = None,
|
13
|
+
event: str = "message",
|
14
|
+
data: str = "",
|
15
|
+
comment: Optional[str] = None,
|
16
|
+
retry: Optional[int] = None,
|
17
|
+
) -> None:
|
18
|
+
self.id = id
|
8
19
|
self.event = event
|
9
20
|
self.data = data
|
21
|
+
self.comment = comment
|
22
|
+
self.retry = retry
|
10
23
|
|
11
24
|
def __str__(self) -> str:
|
12
25
|
s = f"{self.event} event"
|
26
|
+
if self.id:
|
27
|
+
s += f" #{self.id}"
|
13
28
|
if self.data:
|
14
|
-
s +=
|
29
|
+
s += ", {} byte{}".format(len(self.data), "s" if len(self.data) else "")
|
15
30
|
else:
|
16
31
|
s += ", no data"
|
32
|
+
if self.comment:
|
33
|
+
s += f", comment: {self.comment}"
|
34
|
+
if self.retry:
|
35
|
+
s += f", retry in {self.retry}ms"
|
17
36
|
return s
|
18
37
|
|
19
38
|
|
39
|
+
# This is copied from the snowpy library:
|
40
|
+
# https://github.com/snowflakedb/snowpy/blob/main/libs/snowflake.core/src/snowflake/core/rest.py#L39
|
41
|
+
# TODO(SNOW-1750723) - Current there’s code duplication across snowflake-ml-python
|
42
|
+
# and snowpy library for Cortex REST API which was done to meet our GA timelines
|
43
|
+
# Once snowpy has a release with https://github.com/snowflakedb/snowpy/pull/679, we should
|
44
|
+
# remove the class here and directly refer from the snowflake.core package directly
|
20
45
|
class SSEClient:
|
21
|
-
def __init__(self,
|
46
|
+
def __init__(self, event_source: Any, char_enc: str = "utf-8") -> None:
|
47
|
+
self._event_source = event_source
|
48
|
+
self._char_enc = char_enc
|
22
49
|
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
lines = b""
|
28
|
-
for chunk in self.response:
|
50
|
+
def _read(self) -> Iterator[bytes]:
|
51
|
+
data = b""
|
52
|
+
for chunk in self._event_source:
|
29
53
|
for line in chunk.splitlines(True):
|
30
|
-
|
31
|
-
if
|
32
|
-
yield
|
33
|
-
|
34
|
-
if
|
35
|
-
yield
|
54
|
+
data += line
|
55
|
+
if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
|
56
|
+
yield data
|
57
|
+
data = b""
|
58
|
+
if data:
|
59
|
+
yield data
|
36
60
|
|
37
61
|
def events(self) -> Iterator[Event]:
|
38
|
-
|
62
|
+
content_type = self._event_source.headers.get("Content-Type")
|
63
|
+
# The check for empty content-type is present because it's being populated after
|
64
|
+
# the change in https://github.com/snowflakedb/snowflake/pull/217654.
|
65
|
+
# This can be removed once the above change makes it to prod or we move to snowpy
|
66
|
+
# for SSEClient implementation.
|
67
|
+
if content_type == "text/event-stream" or not content_type:
|
68
|
+
return self._handle_sse()
|
69
|
+
elif content_type == "application/json":
|
70
|
+
return self._handle_json()
|
71
|
+
else:
|
72
|
+
raise ValueError(f"Unknown Content-Type: {content_type}")
|
73
|
+
|
74
|
+
def _handle_sse(self) -> Iterator[Event]:
|
75
|
+
for chunk in self._read():
|
39
76
|
event = Event()
|
40
|
-
# splitlines() only uses \r and \n
|
41
|
-
for
|
77
|
+
# Split before decoding so splitlines() only uses \r and \n
|
78
|
+
for line_bytes in chunk.splitlines():
|
79
|
+
# Decode the line.
|
80
|
+
line = line_bytes.decode(self._char_enc)
|
42
81
|
|
43
|
-
|
82
|
+
# Lines starting with a separator are comments and are to be
|
83
|
+
# ignored.
|
84
|
+
if not line.strip() or line.startswith(_FIELD_SEPARATOR):
|
85
|
+
continue
|
44
86
|
|
45
|
-
data = line.split(
|
87
|
+
data = line.split(_FIELD_SEPARATOR, 1)
|
46
88
|
field = data[0]
|
47
89
|
|
90
|
+
# Ignore unknown fields.
|
91
|
+
if not hasattr(event, field):
|
92
|
+
continue
|
93
|
+
|
48
94
|
if len(data) > 1:
|
95
|
+
# From the spec:
|
49
96
|
# "If value starts with a single U+0020 SPACE character,
|
50
|
-
# remove it from value.
|
97
|
+
# remove it from value."
|
51
98
|
if data[1].startswith(" "):
|
52
99
|
value = data[1][1:]
|
53
100
|
else:
|
54
101
|
value = data[1]
|
55
102
|
else:
|
103
|
+
# If no value is present after the separator,
|
104
|
+
# assume an empty value.
|
56
105
|
value = ""
|
57
106
|
|
58
107
|
# The data field may come over multiple lines and their values
|
59
108
|
# are concatenated with each other.
|
109
|
+
current_value = getattr(event, field, "")
|
60
110
|
if field == "data":
|
61
|
-
|
62
|
-
|
63
|
-
|
111
|
+
new_value = current_value + value + "\n"
|
112
|
+
else:
|
113
|
+
new_value = value
|
114
|
+
setattr(event, field, new_value)
|
64
115
|
|
116
|
+
# Events with no data are not dispatched.
|
65
117
|
if not event.data:
|
66
118
|
continue
|
67
119
|
|
68
120
|
# If the data field ends with a newline, remove it.
|
69
121
|
if event.data.endswith("\n"):
|
70
|
-
event.data = event.data[0:-1]
|
122
|
+
event.data = event.data[0:-1]
|
71
123
|
|
72
124
|
# Empty event names default to 'message'
|
73
125
|
event.event = event.event or "message"
|
@@ -77,5 +129,16 @@ class SSEClient:
|
|
77
129
|
|
78
130
|
yield event
|
79
131
|
|
132
|
+
def _handle_json(self) -> Iterator[Event]:
|
133
|
+
data_list = json.loads(self._event_source.data.decode(self._char_enc))
|
134
|
+
for data in data_list:
|
135
|
+
yield Event(
|
136
|
+
id=data.get("id"),
|
137
|
+
event=data.get("event"),
|
138
|
+
data=data.get("data"),
|
139
|
+
comment=data.get("comment"),
|
140
|
+
retry=data.get("retry"),
|
141
|
+
)
|
142
|
+
|
80
143
|
def close(self) -> None:
|
81
|
-
self.
|
144
|
+
self._event_source.close()
|
snowflake/cortex/_util.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1
|
-
from typing import Dict, List, Optional, Union, cast
|
1
|
+
from typing import Any, Dict, List, Optional, Union, cast
|
2
2
|
|
3
3
|
from snowflake import snowpark
|
4
|
+
from snowflake.ml._internal.exceptions import error_codes, exceptions
|
5
|
+
from snowflake.ml._internal.utils import formatting
|
4
6
|
from snowflake.snowpark import context, functions
|
5
7
|
|
6
8
|
CORTEX_FUNCTIONS_TELEMETRY_PROJECT = "CortexFunctions"
|
@@ -64,3 +66,30 @@ def _call_sql_function_immediate(
|
|
64
66
|
empty_df = session.create_dataframe([snowpark.Row()])
|
65
67
|
df = empty_df.select(functions.builtin(function)(*lit_args))
|
66
68
|
return cast(str, df.collect()[0][0])
|
69
|
+
|
70
|
+
|
71
|
+
def call_sql_function_literals(function: str, session: Optional[snowpark.Session], *args: Any) -> str:
|
72
|
+
r"""Call a SQL function with only literal arguments.
|
73
|
+
|
74
|
+
This is useful for calling system functions.
|
75
|
+
|
76
|
+
Args:
|
77
|
+
function: The name of the function to be called.
|
78
|
+
session: The Snowpark session to use.
|
79
|
+
*args: The list of arguments
|
80
|
+
|
81
|
+
Returns:
|
82
|
+
String value that corresponds the the first cell in the dataframe.
|
83
|
+
|
84
|
+
Raises:
|
85
|
+
SnowflakeMLException: If no session is given and no active session exists.
|
86
|
+
"""
|
87
|
+
if session is None:
|
88
|
+
session = context.get_active_session()
|
89
|
+
if session is None:
|
90
|
+
raise exceptions.SnowflakeMLException(
|
91
|
+
error_code=error_codes.INVALID_SNOWPARK_SESSION,
|
92
|
+
)
|
93
|
+
|
94
|
+
function_arguments = ",".join(["NULL" if arg is None else formatting.format_value_for_select(arg) for arg in args])
|
95
|
+
return cast(str, session.sql(f"SELECT {function}({function_arguments})").collect()[0][0])
|
@@ -1,4 +1,4 @@
|
|
1
|
-
import
|
1
|
+
import importlib
|
2
2
|
from typing import Any, Generic, Type, TypeVar, Union, cast
|
3
3
|
|
4
4
|
import numpy as np
|
@@ -51,8 +51,8 @@ class LazyType(Generic[T]):
|
|
51
51
|
def get_class(self) -> Type[T]:
|
52
52
|
if self._runtime_class is None:
|
53
53
|
try:
|
54
|
-
m =
|
55
|
-
except
|
54
|
+
m = importlib.import_module(self.module)
|
55
|
+
except ModuleNotFoundError:
|
56
56
|
raise ValueError(f"Module {self.module} not imported.")
|
57
57
|
|
58
58
|
self._runtime_class = cast("Type[T]", getattr(m, self.qualname))
|
@@ -0,0 +1,141 @@
|
|
1
|
+
import base64
|
2
|
+
import hashlib
|
3
|
+
import logging
|
4
|
+
from datetime import datetime, timedelta, timezone
|
5
|
+
from typing import Optional
|
6
|
+
|
7
|
+
import jwt
|
8
|
+
from cryptography.hazmat.primitives import serialization
|
9
|
+
from cryptography.hazmat.primitives.asymmetric import types
|
10
|
+
|
11
|
+
logger = logging.getLogger(__name__)
|
12
|
+
|
13
|
+
ISSUER = "iss"
|
14
|
+
EXPIRE_TIME = "exp"
|
15
|
+
ISSUE_TIME = "iat"
|
16
|
+
SUBJECT = "sub"
|
17
|
+
|
18
|
+
|
19
|
+
class JWTGenerator:
|
20
|
+
"""
|
21
|
+
Creates and signs a JWT with the specified private key file, username, and account identifier. The JWTGenerator
|
22
|
+
keeps the generated token and only regenerates the token if a specified period of time has passed.
|
23
|
+
"""
|
24
|
+
|
25
|
+
_DEFAULT_LIFETIME = timedelta(minutes=59) # The tokens will have a 59-minute lifetime
|
26
|
+
_DEFAULT_RENEWAL_DELTA = timedelta(minutes=54) # Tokens will be renewed after 54 minutes
|
27
|
+
ALGORITHM = "RS256" # Tokens will be generated using RSA with SHA256
|
28
|
+
|
29
|
+
def __init__(
|
30
|
+
self,
|
31
|
+
account: str,
|
32
|
+
user: str,
|
33
|
+
private_key: types.PRIVATE_KEY_TYPES,
|
34
|
+
lifetime: Optional[timedelta] = None,
|
35
|
+
renewal_delay: Optional[timedelta] = None,
|
36
|
+
) -> None:
|
37
|
+
"""
|
38
|
+
Create a new JWTGenerator object.
|
39
|
+
|
40
|
+
Args:
|
41
|
+
account: The account identifier.
|
42
|
+
user: The username.
|
43
|
+
private_key: The private key used to sign the JWT.
|
44
|
+
lifetime: The lifetime of the token.
|
45
|
+
renewal_delay: The time before the token expires to renew it.
|
46
|
+
"""
|
47
|
+
|
48
|
+
# Construct the fully qualified name of the user in uppercase.
|
49
|
+
self.account = JWTGenerator._prepare_account_name_for_jwt(account)
|
50
|
+
self.user = user.upper()
|
51
|
+
self.qualified_username = self.account + "." + self.user
|
52
|
+
self.private_key = private_key
|
53
|
+
self.public_key_fp = JWTGenerator._calculate_public_key_fingerprint(self.private_key)
|
54
|
+
|
55
|
+
self.issuer = self.qualified_username + "." + self.public_key_fp
|
56
|
+
self.lifetime = lifetime or JWTGenerator._DEFAULT_LIFETIME
|
57
|
+
self.renewal_delay = renewal_delay or JWTGenerator._DEFAULT_RENEWAL_DELTA
|
58
|
+
self.renew_time = datetime.now(timezone.utc)
|
59
|
+
self.token: Optional[str] = None
|
60
|
+
|
61
|
+
logger.info(
|
62
|
+
"""Creating JWTGenerator with arguments
|
63
|
+
account : %s, user : %s, lifetime : %s, renewal_delay : %s""",
|
64
|
+
self.account,
|
65
|
+
self.user,
|
66
|
+
self.lifetime,
|
67
|
+
self.renewal_delay,
|
68
|
+
)
|
69
|
+
|
70
|
+
@staticmethod
|
71
|
+
def _prepare_account_name_for_jwt(raw_account: str) -> str:
|
72
|
+
account = raw_account
|
73
|
+
if ".global" not in account:
|
74
|
+
# Handle the general case.
|
75
|
+
idx = account.find(".")
|
76
|
+
if idx > 0:
|
77
|
+
account = account[0:idx]
|
78
|
+
else:
|
79
|
+
# Handle the replication case.
|
80
|
+
idx = account.find("-")
|
81
|
+
if idx > 0:
|
82
|
+
account = account[0:idx]
|
83
|
+
# Use uppercase for the account identifier.
|
84
|
+
return account.upper()
|
85
|
+
|
86
|
+
def get_token(self) -> str:
|
87
|
+
now = datetime.now(timezone.utc) # Fetch the current time
|
88
|
+
if self.token is not None and self.renew_time > now:
|
89
|
+
return self.token
|
90
|
+
|
91
|
+
# If the token has expired or doesn't exist, regenerate the token.
|
92
|
+
logger.info(
|
93
|
+
"Generating a new token because the present time (%s) is later than the renewal time (%s)",
|
94
|
+
now,
|
95
|
+
self.renew_time,
|
96
|
+
)
|
97
|
+
# Calculate the next time we need to renew the token.
|
98
|
+
self.renew_time = now + self.renewal_delay
|
99
|
+
|
100
|
+
# Create our payload
|
101
|
+
payload = {
|
102
|
+
# Set the issuer to the fully qualified username concatenated with the public key fingerprint.
|
103
|
+
ISSUER: self.issuer,
|
104
|
+
# Set the subject to the fully qualified username.
|
105
|
+
SUBJECT: self.qualified_username,
|
106
|
+
# Set the issue time to now.
|
107
|
+
ISSUE_TIME: now,
|
108
|
+
# Set the expiration time, based on the lifetime specified for this object.
|
109
|
+
EXPIRE_TIME: now + self.lifetime,
|
110
|
+
}
|
111
|
+
|
112
|
+
# Regenerate the actual token
|
113
|
+
token = jwt.encode(payload, key=self.private_key, algorithm=JWTGenerator.ALGORITHM)
|
114
|
+
# If you are using a version of PyJWT prior to 2.0, jwt.encode returns a byte string instead of a string.
|
115
|
+
# If the token is a byte string, convert it to a string.
|
116
|
+
if isinstance(token, bytes):
|
117
|
+
token = token.decode("utf-8")
|
118
|
+
self.token = token
|
119
|
+
logger.info(
|
120
|
+
"Generated a JWT with the following payload: %s",
|
121
|
+
jwt.decode(self.token, key=self.private_key.public_key(), algorithms=[JWTGenerator.ALGORITHM]),
|
122
|
+
)
|
123
|
+
|
124
|
+
return token
|
125
|
+
|
126
|
+
@staticmethod
|
127
|
+
def _calculate_public_key_fingerprint(private_key: types.PRIVATE_KEY_TYPES) -> str:
|
128
|
+
# Get the raw bytes of public key.
|
129
|
+
public_key_raw = private_key.public_key().public_bytes(
|
130
|
+
serialization.Encoding.DER, serialization.PublicFormat.SubjectPublicKeyInfo
|
131
|
+
)
|
132
|
+
|
133
|
+
# Get the sha256 hash of the raw bytes.
|
134
|
+
sha256hash = hashlib.sha256()
|
135
|
+
sha256hash.update(public_key_raw)
|
136
|
+
|
137
|
+
# Base64-encode the value and prepend the prefix 'SHA256:'.
|
138
|
+
public_key_fp = "SHA256:" + base64.b64encode(sha256hash.digest()).decode("utf-8")
|
139
|
+
logger.info("Public key fingerprint is %s", public_key_fp)
|
140
|
+
|
141
|
+
return public_key_fp
|
@@ -0,0 +1,5 @@
|
|
1
|
+
from .data_connector import DataConnector
|
2
|
+
from .data_ingestor import DataIngestor, DataIngestorType
|
3
|
+
from .data_source import DataFrameInfo, DatasetInfo, DataSource
|
4
|
+
|
5
|
+
__all__ = ["DataConnector", "DataSource", "DataFrameInfo", "DatasetInfo", "DataIngestor", "DataIngestorType"]
|
@@ -14,7 +14,7 @@ from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
|
|
14
14
|
from snowflake.ml.model._model_composer import model_composer
|
15
15
|
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
16
16
|
from snowflake.ml.model._packager.model_handlers import snowmlmodel
|
17
|
-
from snowflake.snowpark import Session, dataframe
|
17
|
+
from snowflake.snowpark import Session, async_job, dataframe
|
18
18
|
|
19
19
|
_TELEMETRY_PROJECT = "MLOps"
|
20
20
|
_TELEMETRY_SUBPROJECT = "ModelManagement"
|
@@ -631,7 +631,8 @@ class ModelVersion(lineage_node.LineageNode):
|
|
631
631
|
max_batch_rows: Optional[int] = None,
|
632
632
|
force_rebuild: bool = False,
|
633
633
|
build_external_access_integration: Optional[str] = None,
|
634
|
-
|
634
|
+
block: bool = True,
|
635
|
+
) -> Union[str, async_job.AsyncJob]:
|
635
636
|
"""Create an inference service with the given spec.
|
636
637
|
|
637
638
|
Args:
|
@@ -659,6 +660,9 @@ class ModelVersion(lineage_node.LineageNode):
|
|
659
660
|
force_rebuild: Whether to force a model inference image rebuild.
|
660
661
|
build_external_access_integration: (Deprecated) The external access integration for image build. This is
|
661
662
|
usually permitting access to conda & PyPI repositories.
|
663
|
+
block: A bool value indicating whether this function will wait until the service is available.
|
664
|
+
When it is ``False``, this function executes the underlying service creation asynchronously
|
665
|
+
and returns an :class:`AsyncJob`.
|
662
666
|
"""
|
663
667
|
...
|
664
668
|
|
@@ -679,7 +683,8 @@ class ModelVersion(lineage_node.LineageNode):
|
|
679
683
|
max_batch_rows: Optional[int] = None,
|
680
684
|
force_rebuild: bool = False,
|
681
685
|
build_external_access_integrations: Optional[List[str]] = None,
|
682
|
-
|
686
|
+
block: bool = True,
|
687
|
+
) -> Union[str, async_job.AsyncJob]:
|
683
688
|
"""Create an inference service with the given spec.
|
684
689
|
|
685
690
|
Args:
|
@@ -707,6 +712,9 @@ class ModelVersion(lineage_node.LineageNode):
|
|
707
712
|
force_rebuild: Whether to force a model inference image rebuild.
|
708
713
|
build_external_access_integrations: The external access integrations for image build. This is usually
|
709
714
|
permitting access to conda & PyPI repositories.
|
715
|
+
block: A bool value indicating whether this function will wait until the service is available.
|
716
|
+
When it is ``False``, this function executes the underlying service creation asynchronously
|
717
|
+
and returns an :class:`AsyncJob`.
|
710
718
|
"""
|
711
719
|
...
|
712
720
|
|
@@ -742,7 +750,8 @@ class ModelVersion(lineage_node.LineageNode):
|
|
742
750
|
force_rebuild: bool = False,
|
743
751
|
build_external_access_integration: Optional[str] = None,
|
744
752
|
build_external_access_integrations: Optional[List[str]] = None,
|
745
|
-
|
753
|
+
block: bool = True,
|
754
|
+
) -> Union[str, async_job.AsyncJob]:
|
746
755
|
"""Create an inference service with the given spec.
|
747
756
|
|
748
757
|
Args:
|
@@ -772,12 +781,16 @@ class ModelVersion(lineage_node.LineageNode):
|
|
772
781
|
usually permitting access to conda & PyPI repositories.
|
773
782
|
build_external_access_integrations: The external access integrations for image build. This is usually
|
774
783
|
permitting access to conda & PyPI repositories.
|
784
|
+
block: A bool value indicating whether this function will wait until the service is available.
|
785
|
+
When it is False, this function executes the underlying service creation asynchronously
|
786
|
+
and returns an AsyncJob.
|
775
787
|
|
776
788
|
Raises:
|
777
789
|
ValueError: Illegal external access integration arguments.
|
778
790
|
|
779
791
|
Returns:
|
780
|
-
|
792
|
+
If `block=True`, return result information about service creation from server.
|
793
|
+
Otherwise, return the service creation AsyncJob.
|
781
794
|
"""
|
782
795
|
statement_params = telemetry.get_statement_params(
|
783
796
|
project=_TELEMETRY_PROJECT,
|
@@ -829,6 +842,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
829
842
|
if build_external_access_integrations is None
|
830
843
|
else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
|
831
844
|
),
|
845
|
+
block=block,
|
832
846
|
statement_params=statement_params,
|
833
847
|
)
|
834
848
|
|
@@ -851,17 +865,13 @@ class ModelVersion(lineage_node.LineageNode):
|
|
851
865
|
)
|
852
866
|
|
853
867
|
return pd.DataFrame(
|
854
|
-
self._model_ops.
|
868
|
+
self._model_ops.show_services(
|
855
869
|
database_name=None,
|
856
870
|
schema_name=None,
|
857
871
|
model_name=self._model_name,
|
858
872
|
version_name=self._version_name,
|
859
873
|
statement_params=statement_params,
|
860
|
-
)
|
861
|
-
columns=[
|
862
|
-
self._model_ops.INFERENCE_SERVICE_NAME_COL_NAME,
|
863
|
-
self._model_ops.INFERENCE_SERVICE_ENDPOINT_COL_NAME,
|
864
|
-
],
|
874
|
+
)
|
865
875
|
)
|
866
876
|
|
867
877
|
@telemetry.send_api_usage_telemetry(
|
@@ -889,12 +899,16 @@ class ModelVersion(lineage_node.LineageNode):
|
|
889
899
|
project=_TELEMETRY_PROJECT,
|
890
900
|
subproject=_TELEMETRY_SUBPROJECT,
|
891
901
|
)
|
902
|
+
|
903
|
+
database_name_id, schema_name_id, service_name_id = sql_identifier.parse_fully_qualified_name(service_name)
|
892
904
|
self._model_ops.delete_service(
|
893
905
|
database_name=None,
|
894
906
|
schema_name=None,
|
895
907
|
model_name=self._model_name,
|
896
908
|
version_name=self._version_name,
|
897
|
-
|
909
|
+
service_database_name=database_name_id,
|
910
|
+
service_schema_name=schema_name_id,
|
911
|
+
service_name=service_name_id,
|
898
912
|
statement_params=statement_params,
|
899
913
|
)
|
900
914
|
|
@@ -3,7 +3,7 @@ import os
|
|
3
3
|
import pathlib
|
4
4
|
import tempfile
|
5
5
|
import warnings
|
6
|
-
from typing import Any, Dict, List, Literal, Optional, Union, cast, overload
|
6
|
+
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast, overload
|
7
7
|
|
8
8
|
import yaml
|
9
9
|
|
@@ -31,9 +31,14 @@ from snowflake.snowpark import dataframe, row, session
|
|
31
31
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
32
32
|
|
33
33
|
|
34
|
+
class ServiceInfo(TypedDict):
|
35
|
+
name: str
|
36
|
+
inference_endpoint: Optional[str]
|
37
|
+
|
38
|
+
|
34
39
|
class ModelOperator:
|
35
|
-
|
36
|
-
|
40
|
+
INFERENCE_SERVICE_ENDPOINT_NAME = "inference"
|
41
|
+
INGRESS_ENDPOINT_URL_SUFFIX = "snowflakecomputing.app"
|
37
42
|
|
38
43
|
def __init__(
|
39
44
|
self,
|
@@ -517,7 +522,7 @@ class ModelOperator:
|
|
517
522
|
statement_params=statement_params,
|
518
523
|
)
|
519
524
|
|
520
|
-
def
|
525
|
+
def show_services(
|
521
526
|
self,
|
522
527
|
*,
|
523
528
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
@@ -525,7 +530,7 @@ class ModelOperator:
|
|
525
530
|
model_name: sql_identifier.SqlIdentifier,
|
526
531
|
version_name: sql_identifier.SqlIdentifier,
|
527
532
|
statement_params: Optional[Dict[str, Any]] = None,
|
528
|
-
) ->
|
533
|
+
) -> List[ServiceInfo]:
|
529
534
|
res = self._model_client.show_versions(
|
530
535
|
database_name=database_name,
|
531
536
|
schema_name=schema_name,
|
@@ -546,21 +551,28 @@ class ModelOperator:
|
|
546
551
|
|
547
552
|
json_array = json.loads(res[0][service_col_name])
|
548
553
|
# TODO(sdas): Figure out a better way to filter out MODEL_BUILD_ services server side.
|
549
|
-
|
550
|
-
|
551
|
-
|
552
|
-
|
553
|
-
for
|
554
|
-
|
555
|
-
|
556
|
-
|
557
|
-
|
558
|
-
|
559
|
-
|
560
|
-
|
561
|
-
|
562
|
-
|
563
|
-
|
554
|
+
fully_qualified_service_names = [str(service) for service in json_array if "MODEL_BUILD_" not in service]
|
555
|
+
|
556
|
+
result = []
|
557
|
+
ingress_url: Optional[str] = None
|
558
|
+
for fully_qualified_service_name in fully_qualified_service_names:
|
559
|
+
db, schema, service_name = sql_identifier.parse_fully_qualified_name(fully_qualified_service_name)
|
560
|
+
for res_row in self._service_client.show_endpoints(
|
561
|
+
database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
|
562
|
+
):
|
563
|
+
if (
|
564
|
+
res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME]
|
565
|
+
== self.INFERENCE_SERVICE_ENDPOINT_NAME
|
566
|
+
and res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME] is not None
|
567
|
+
):
|
568
|
+
ingress_url = str(
|
569
|
+
res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME]
|
570
|
+
)
|
571
|
+
if not ingress_url.endswith(ModelOperator.INGRESS_ENDPOINT_URL_SUFFIX):
|
572
|
+
ingress_url = None
|
573
|
+
result.append(ServiceInfo(name=fully_qualified_service_name, inference_endpoint=ingress_url))
|
574
|
+
|
575
|
+
return result
|
564
576
|
|
565
577
|
def delete_service(
|
566
578
|
self,
|
@@ -569,33 +581,42 @@ class ModelOperator:
|
|
569
581
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
570
582
|
model_name: sql_identifier.SqlIdentifier,
|
571
583
|
version_name: sql_identifier.SqlIdentifier,
|
572
|
-
|
584
|
+
service_database_name: Optional[sql_identifier.SqlIdentifier],
|
585
|
+
service_schema_name: Optional[sql_identifier.SqlIdentifier],
|
586
|
+
service_name: sql_identifier.SqlIdentifier,
|
573
587
|
statement_params: Optional[Dict[str, Any]] = None,
|
574
588
|
) -> None:
|
575
|
-
services = self.
|
589
|
+
services = self.show_services(
|
576
590
|
database_name=database_name,
|
577
591
|
schema_name=schema_name,
|
578
592
|
model_name=model_name,
|
579
593
|
version_name=version_name,
|
580
594
|
statement_params=statement_params,
|
581
595
|
)
|
582
|
-
|
596
|
+
|
597
|
+
# Fall back to the model's database and schema.
|
598
|
+
# database_name or schema_name are set if the model is created or get using fully qualified name
|
599
|
+
# Otherwise, the model's database and schema are same as registry's database and schema, which are set in the
|
600
|
+
# self._model_client.
|
601
|
+
|
602
|
+
service_database_name = service_database_name or database_name or self._model_client._database_name
|
603
|
+
service_schema_name = service_schema_name or schema_name or self._model_client._schema_name
|
583
604
|
fully_qualified_service_name = sql_identifier.get_fully_qualified_name(
|
584
|
-
|
605
|
+
service_database_name, service_schema_name, service_name
|
585
606
|
)
|
586
607
|
|
587
|
-
|
588
|
-
|
589
|
-
if service == fully_qualified_service_name:
|
608
|
+
for service_info in services:
|
609
|
+
if service_info["name"] == fully_qualified_service_name:
|
590
610
|
self._service_client.drop_service(
|
591
|
-
database_name=
|
592
|
-
schema_name=
|
611
|
+
database_name=service_database_name,
|
612
|
+
schema_name=service_schema_name,
|
593
613
|
service_name=service_name,
|
594
614
|
statement_params=statement_params,
|
595
615
|
)
|
596
616
|
return
|
597
617
|
raise ValueError(
|
598
|
-
f"Service '{
|
618
|
+
f"Service '{fully_qualified_service_name}' does not exist "
|
619
|
+
"or unauthorized or not associated with this model version."
|
599
620
|
)
|
600
621
|
|
601
622
|
def get_model_version_manifest(
|