snowflake-ml-python 1.7.0__py3-none-any.whl → 1.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_complete.py +107 -64
  3. snowflake/cortex/_finetune.py +273 -0
  4. snowflake/cortex/_sse_client.py +91 -28
  5. snowflake/cortex/_util.py +30 -1
  6. snowflake/ml/_internal/type_utils.py +3 -3
  7. snowflake/ml/_internal/utils/jwt_generator.py +141 -0
  8. snowflake/ml/data/__init__.py +5 -0
  9. snowflake/ml/model/_client/model/model_version_impl.py +26 -12
  10. snowflake/ml/model/_client/ops/model_ops.py +51 -30
  11. snowflake/ml/model/_client/ops/service_ops.py +25 -9
  12. snowflake/ml/model/_client/sql/model.py +0 -14
  13. snowflake/ml/model/_client/sql/service.py +25 -1
  14. snowflake/ml/model/_client/sql/stage.py +1 -1
  15. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
  16. snowflake/ml/model/_packager/model_env/model_env.py +12 -0
  17. snowflake/ml/model/_packager/model_handlers/_utils.py +1 -1
  18. snowflake/ml/model/_packager/model_handlers/catboost.py +1 -1
  19. snowflake/ml/model/_packager/model_handlers/custom.py +3 -1
  20. snowflake/ml/model/_packager/model_handlers/lightgbm.py +2 -1
  21. snowflake/ml/model/_packager/model_handlers/sklearn.py +50 -1
  22. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +1 -1
  23. snowflake/ml/model/_packager/model_handlers/tensorflow.py +23 -6
  24. snowflake/ml/model/_packager/model_handlers/torchscript.py +14 -14
  25. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +2 -3
  26. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +5 -0
  27. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -10
  28. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -9
  29. snowflake/ml/model/_packager/model_task/model_task_utils.py +1 -1
  30. snowflake/ml/model/_signatures/core.py +63 -16
  31. snowflake/ml/model/_signatures/pandas_handler.py +71 -27
  32. snowflake/ml/model/_signatures/pytorch_handler.py +2 -2
  33. snowflake/ml/model/_signatures/snowpark_handler.py +2 -1
  34. snowflake/ml/model/_signatures/tensorflow_handler.py +2 -2
  35. snowflake/ml/model/_signatures/utils.py +4 -1
  36. snowflake/ml/model/model_signature.py +38 -9
  37. snowflake/ml/model/type_hints.py +1 -1
  38. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +2 -4
  39. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +2 -4
  40. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +148 -1200
  41. snowflake/ml/monitoring/_manager/model_monitor_manager.py +114 -238
  42. snowflake/ml/monitoring/entities/model_monitor_config.py +38 -12
  43. snowflake/ml/monitoring/model_monitor.py +12 -86
  44. snowflake/ml/registry/registry.py +28 -40
  45. snowflake/ml/utils/authentication.py +75 -0
  46. snowflake/ml/version.py +1 -1
  47. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.2.dist-info}/METADATA +116 -52
  48. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.2.dist-info}/RECORD +51 -49
  49. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.2.dist-info}/WHEEL +1 -1
  50. snowflake/ml/monitoring/entities/model_monitor_interval.py +0 -46
  51. snowflake/ml/monitoring/entities/output_score_type.py +0 -90
  52. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.2.dist-info}/LICENSE.txt +0 -0
  53. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.2.dist-info}/top_level.txt +0 -0
@@ -1,73 +1,125 @@
1
- from typing import Iterator, cast
1
+ import json
2
+ from typing import Any, Iterator, Optional
2
3
 
3
- import requests
4
+ _FIELD_SEPARATOR = ":"
4
5
 
5
6
 
6
7
  class Event:
7
- def __init__(self, event: str = "message", data: str = "") -> None:
8
+ """Representation of an event from the event stream."""
9
+
10
+ def __init__(
11
+ self,
12
+ id: Optional[str] = None,
13
+ event: str = "message",
14
+ data: str = "",
15
+ comment: Optional[str] = None,
16
+ retry: Optional[int] = None,
17
+ ) -> None:
18
+ self.id = id
8
19
  self.event = event
9
20
  self.data = data
21
+ self.comment = comment
22
+ self.retry = retry
10
23
 
11
24
  def __str__(self) -> str:
12
25
  s = f"{self.event} event"
26
+ if self.id:
27
+ s += f" #{self.id}"
13
28
  if self.data:
14
- s += f", {len(self.data)} bytes"
29
+ s += ", {} byte{}".format(len(self.data), "s" if len(self.data) else "")
15
30
  else:
16
31
  s += ", no data"
32
+ if self.comment:
33
+ s += f", comment: {self.comment}"
34
+ if self.retry:
35
+ s += f", retry in {self.retry}ms"
17
36
  return s
18
37
 
19
38
 
39
+ # This is copied from the snowpy library:
40
+ # https://github.com/snowflakedb/snowpy/blob/main/libs/snowflake.core/src/snowflake/core/rest.py#L39
41
+ # TODO(SNOW-1750723) - Currently there’s code duplication across snowflake-ml-python
42
+ # and snowpy library for Cortex REST API which was done to meet our GA timelines
43
+ # Once snowpy has a release with https://github.com/snowflakedb/snowpy/pull/679, we should
44
+ # remove the class here and refer to it from the snowflake.core package directly
20
45
  class SSEClient:
21
- def __init__(self, response: requests.Response) -> None:
46
+ def __init__(self, event_source: Any, char_enc: str = "utf-8") -> None:
47
+ self._event_source = event_source
48
+ self._char_enc = char_enc
22
49
 
23
- self.response = response
24
-
25
- def _read(self) -> Iterator[str]:
26
-
27
- lines = b""
28
- for chunk in self.response:
50
+ def _read(self) -> Iterator[bytes]:
51
+ data = b""
52
+ for chunk in self._event_source:
29
53
  for line in chunk.splitlines(True):
30
- lines += line
31
- if lines.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
32
- yield cast(str, lines)
33
- lines = b""
34
- if lines:
35
- yield cast(str, lines)
54
+ data += line
55
+ if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
56
+ yield data
57
+ data = b""
58
+ if data:
59
+ yield data
36
60
 
37
61
  def events(self) -> Iterator[Event]:
38
- for raw_event in self._read():
62
+ content_type = self._event_source.headers.get("Content-Type")
63
+ # An empty Content-Type is also accepted because the header is only populated after
64
+ # the change in https://github.com/snowflakedb/snowflake/pull/217654.
65
+ # This can be removed once the above change makes it to prod or we move to snowpy
66
+ # for SSEClient implementation.
67
+ if content_type == "text/event-stream" or not content_type:
68
+ return self._handle_sse()
69
+ elif content_type == "application/json":
70
+ return self._handle_json()
71
+ else:
72
+ raise ValueError(f"Unknown Content-Type: {content_type}")
73
+
74
+ def _handle_sse(self) -> Iterator[Event]:
75
+ for chunk in self._read():
39
76
  event = Event()
40
- # splitlines() only uses \r and \n
41
- for line in raw_event.splitlines():
77
+ # Split before decoding so splitlines() only uses \r and \n
78
+ for line_bytes in chunk.splitlines():
79
+ # Decode the line.
80
+ line = line_bytes.decode(self._char_enc)
42
81
 
43
- line = cast(bytes, line).decode("utf-8")
82
+ # Lines starting with a separator are comments and are to be
83
+ # ignored.
84
+ if not line.strip() or line.startswith(_FIELD_SEPARATOR):
85
+ continue
44
86
 
45
- data = line.split(":", 1)
87
+ data = line.split(_FIELD_SEPARATOR, 1)
46
88
  field = data[0]
47
89
 
90
+ # Ignore unknown fields.
91
+ if not hasattr(event, field):
92
+ continue
93
+
48
94
  if len(data) > 1:
95
+ # From the spec:
49
96
  # "If value starts with a single U+0020 SPACE character,
50
- # remove it from value. .strip() would remove all white spaces"
97
+ # remove it from value."
51
98
  if data[1].startswith(" "):
52
99
  value = data[1][1:]
53
100
  else:
54
101
  value = data[1]
55
102
  else:
103
+ # If no value is present after the separator,
104
+ # assume an empty value.
56
105
  value = ""
57
106
 
58
107
  # The data field may come over multiple lines and their values
59
108
  # are concatenated with each other.
109
+ current_value = getattr(event, field, "")
60
110
  if field == "data":
61
- event.data += value + "\n"
62
- elif field == "event":
63
- event.event = value
111
+ new_value = current_value + value + "\n"
112
+ else:
113
+ new_value = value
114
+ setattr(event, field, new_value)
64
115
 
116
+ # Events with no data are not dispatched.
65
117
  if not event.data:
66
118
  continue
67
119
 
68
120
  # If the data field ends with a newline, remove it.
69
121
  if event.data.endswith("\n"):
70
- event.data = event.data[0:-1] # Replace trailing newline - rstrip would remove multiple.
122
+ event.data = event.data[0:-1]
71
123
 
72
124
  # Empty event names default to 'message'
73
125
  event.event = event.event or "message"
@@ -77,5 +129,16 @@ class SSEClient:
77
129
 
78
130
  yield event
79
131
 
132
+ def _handle_json(self) -> Iterator[Event]:
133
+ data_list = json.loads(self._event_source.data.decode(self._char_enc))
134
+ for data in data_list:
135
+ yield Event(
136
+ id=data.get("id"),
137
+ event=data.get("event"),
138
+ data=data.get("data"),
139
+ comment=data.get("comment"),
140
+ retry=data.get("retry"),
141
+ )
142
+
80
143
  def close(self) -> None:
81
- self.response.close()
144
+ self._event_source.close()
snowflake/cortex/_util.py CHANGED
@@ -1,6 +1,8 @@
1
- from typing import Dict, List, Optional, Union, cast
1
+ from typing import Any, Dict, List, Optional, Union, cast
2
2
 
3
3
  from snowflake import snowpark
4
+ from snowflake.ml._internal.exceptions import error_codes, exceptions
5
+ from snowflake.ml._internal.utils import formatting
4
6
  from snowflake.snowpark import context, functions
5
7
 
6
8
  CORTEX_FUNCTIONS_TELEMETRY_PROJECT = "CortexFunctions"
@@ -64,3 +66,30 @@ def _call_sql_function_immediate(
64
66
  empty_df = session.create_dataframe([snowpark.Row()])
65
67
  df = empty_df.select(functions.builtin(function)(*lit_args))
66
68
  return cast(str, df.collect()[0][0])
69
+
70
+
71
+ def call_sql_function_literals(function: str, session: Optional[snowpark.Session], *args: Any) -> str:
72
+ r"""Call a SQL function with only literal arguments.
73
+
74
+ This is useful for calling system functions.
75
+
76
+ Args:
77
+ function: The name of the function to be called.
78
+ session: The Snowpark session to use.
79
+ *args: The list of arguments
80
+
81
+ Returns:
82
+ String value that corresponds to the first cell in the dataframe.
83
+
84
+ Raises:
85
+ SnowflakeMLException: If no session is given and no active session exists.
86
+ """
87
+ if session is None:
88
+ session = context.get_active_session()
89
+ if session is None:
90
+ raise exceptions.SnowflakeMLException(
91
+ error_code=error_codes.INVALID_SNOWPARK_SESSION,
92
+ )
93
+
94
+ function_arguments = ",".join(["NULL" if arg is None else formatting.format_value_for_select(arg) for arg in args])
95
+ return cast(str, session.sql(f"SELECT {function}({function_arguments})").collect()[0][0])
@@ -1,4 +1,4 @@
1
- import sys
1
+ import importlib
2
2
  from typing import Any, Generic, Type, TypeVar, Union, cast
3
3
 
4
4
  import numpy as np
@@ -51,8 +51,8 @@ class LazyType(Generic[T]):
51
51
  def get_class(self) -> Type[T]:
52
52
  if self._runtime_class is None:
53
53
  try:
54
- m = sys.modules[self.module]
55
- except KeyError:
54
+ m = importlib.import_module(self.module)
55
+ except ModuleNotFoundError:
56
56
  raise ValueError(f"Module {self.module} not imported.")
57
57
 
58
58
  self._runtime_class = cast("Type[T]", getattr(m, self.qualname))
@@ -0,0 +1,141 @@
1
+ import base64
2
+ import hashlib
3
+ import logging
4
+ from datetime import datetime, timedelta, timezone
5
+ from typing import Optional
6
+
7
+ import jwt
8
+ from cryptography.hazmat.primitives import serialization
9
+ from cryptography.hazmat.primitives.asymmetric import types
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ ISSUER = "iss"
14
+ EXPIRE_TIME = "exp"
15
+ ISSUE_TIME = "iat"
16
+ SUBJECT = "sub"
17
+
18
+
19
+ class JWTGenerator:
20
+ """
21
+ Creates and signs a JWT with the specified private key file, username, and account identifier. The JWTGenerator
22
+ keeps the generated token and only regenerates the token if a specified period of time has passed.
23
+ """
24
+
25
+ _DEFAULT_LIFETIME = timedelta(minutes=59) # The tokens will have a 59-minute lifetime
26
+ _DEFAULT_RENEWAL_DELTA = timedelta(minutes=54) # Tokens will be renewed after 54 minutes
27
+ ALGORITHM = "RS256" # Tokens will be generated using RSA with SHA256
28
+
29
+ def __init__(
30
+ self,
31
+ account: str,
32
+ user: str,
33
+ private_key: types.PRIVATE_KEY_TYPES,
34
+ lifetime: Optional[timedelta] = None,
35
+ renewal_delay: Optional[timedelta] = None,
36
+ ) -> None:
37
+ """
38
+ Create a new JWTGenerator object.
39
+
40
+ Args:
41
+ account: The account identifier.
42
+ user: The username.
43
+ private_key: The private key used to sign the JWT.
44
+ lifetime: The lifetime of the token.
45
+ renewal_delay: How long after a token is generated before it is renewed.
46
+ """
47
+
48
+ # Construct the fully qualified name of the user in uppercase.
49
+ self.account = JWTGenerator._prepare_account_name_for_jwt(account)
50
+ self.user = user.upper()
51
+ self.qualified_username = self.account + "." + self.user
52
+ self.private_key = private_key
53
+ self.public_key_fp = JWTGenerator._calculate_public_key_fingerprint(self.private_key)
54
+
55
+ self.issuer = self.qualified_username + "." + self.public_key_fp
56
+ self.lifetime = lifetime or JWTGenerator._DEFAULT_LIFETIME
57
+ self.renewal_delay = renewal_delay or JWTGenerator._DEFAULT_RENEWAL_DELTA
58
+ self.renew_time = datetime.now(timezone.utc)
59
+ self.token: Optional[str] = None
60
+
61
+ logger.info(
62
+ """Creating JWTGenerator with arguments
63
+ account : %s, user : %s, lifetime : %s, renewal_delay : %s""",
64
+ self.account,
65
+ self.user,
66
+ self.lifetime,
67
+ self.renewal_delay,
68
+ )
69
+
70
+ @staticmethod
71
+ def _prepare_account_name_for_jwt(raw_account: str) -> str:
72
+ account = raw_account
73
+ if ".global" not in account:
74
+ # Handle the general case.
75
+ idx = account.find(".")
76
+ if idx > 0:
77
+ account = account[0:idx]
78
+ else:
79
+ # Handle the replication case.
80
+ idx = account.find("-")
81
+ if idx > 0:
82
+ account = account[0:idx]
83
+ # Use uppercase for the account identifier.
84
+ return account.upper()
85
+
86
+ def get_token(self) -> str:
87
+ now = datetime.now(timezone.utc) # Fetch the current time
88
+ if self.token is not None and self.renew_time > now:
89
+ return self.token
90
+
91
+ # If the token has expired or doesn't exist, regenerate the token.
92
+ logger.info(
93
+ "Generating a new token because the present time (%s) is later than the renewal time (%s)",
94
+ now,
95
+ self.renew_time,
96
+ )
97
+ # Calculate the next time we need to renew the token.
98
+ self.renew_time = now + self.renewal_delay
99
+
100
+ # Create our payload
101
+ payload = {
102
+ # Set the issuer to the fully qualified username concatenated with the public key fingerprint.
103
+ ISSUER: self.issuer,
104
+ # Set the subject to the fully qualified username.
105
+ SUBJECT: self.qualified_username,
106
+ # Set the issue time to now.
107
+ ISSUE_TIME: now,
108
+ # Set the expiration time, based on the lifetime specified for this object.
109
+ EXPIRE_TIME: now + self.lifetime,
110
+ }
111
+
112
+ # Regenerate the actual token
113
+ token = jwt.encode(payload, key=self.private_key, algorithm=JWTGenerator.ALGORITHM)
114
+ # If you are using a version of PyJWT prior to 2.0, jwt.encode returns a byte string instead of a string.
115
+ # If the token is a byte string, convert it to a string.
116
+ if isinstance(token, bytes):
117
+ token = token.decode("utf-8")
118
+ self.token = token
119
+ logger.info(
120
+ "Generated a JWT with the following payload: %s",
121
+ jwt.decode(self.token, key=self.private_key.public_key(), algorithms=[JWTGenerator.ALGORITHM]),
122
+ )
123
+
124
+ return token
125
+
126
+ @staticmethod
127
+ def _calculate_public_key_fingerprint(private_key: types.PRIVATE_KEY_TYPES) -> str:
128
+ # Get the raw bytes of public key.
129
+ public_key_raw = private_key.public_key().public_bytes(
130
+ serialization.Encoding.DER, serialization.PublicFormat.SubjectPublicKeyInfo
131
+ )
132
+
133
+ # Get the sha256 hash of the raw bytes.
134
+ sha256hash = hashlib.sha256()
135
+ sha256hash.update(public_key_raw)
136
+
137
+ # Base64-encode the value and prepend the prefix 'SHA256:'.
138
+ public_key_fp = "SHA256:" + base64.b64encode(sha256hash.digest()).decode("utf-8")
139
+ logger.info("Public key fingerprint is %s", public_key_fp)
140
+
141
+ return public_key_fp
@@ -0,0 +1,5 @@
1
+ from .data_connector import DataConnector
2
+ from .data_ingestor import DataIngestor, DataIngestorType
3
+ from .data_source import DataFrameInfo, DatasetInfo, DataSource
4
+
5
+ __all__ = ["DataConnector", "DataSource", "DataFrameInfo", "DatasetInfo", "DataIngestor", "DataIngestorType"]
@@ -14,7 +14,7 @@ from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
14
14
  from snowflake.ml.model._model_composer import model_composer
15
15
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
16
16
  from snowflake.ml.model._packager.model_handlers import snowmlmodel
17
- from snowflake.snowpark import Session, dataframe
17
+ from snowflake.snowpark import Session, async_job, dataframe
18
18
 
19
19
  _TELEMETRY_PROJECT = "MLOps"
20
20
  _TELEMETRY_SUBPROJECT = "ModelManagement"
@@ -631,7 +631,8 @@ class ModelVersion(lineage_node.LineageNode):
631
631
  max_batch_rows: Optional[int] = None,
632
632
  force_rebuild: bool = False,
633
633
  build_external_access_integration: Optional[str] = None,
634
- ) -> str:
634
+ block: bool = True,
635
+ ) -> Union[str, async_job.AsyncJob]:
635
636
  """Create an inference service with the given spec.
636
637
 
637
638
  Args:
@@ -659,6 +660,9 @@ class ModelVersion(lineage_node.LineageNode):
659
660
  force_rebuild: Whether to force a model inference image rebuild.
660
661
  build_external_access_integration: (Deprecated) The external access integration for image build. This is
661
662
  usually permitting access to conda & PyPI repositories.
663
+ block: A bool value indicating whether this function will wait until the service is available.
664
+ When it is ``False``, this function executes the underlying service creation asynchronously
665
+ and returns an :class:`AsyncJob`.
662
666
  """
663
667
  ...
664
668
 
@@ -679,7 +683,8 @@ class ModelVersion(lineage_node.LineageNode):
679
683
  max_batch_rows: Optional[int] = None,
680
684
  force_rebuild: bool = False,
681
685
  build_external_access_integrations: Optional[List[str]] = None,
682
- ) -> str:
686
+ block: bool = True,
687
+ ) -> Union[str, async_job.AsyncJob]:
683
688
  """Create an inference service with the given spec.
684
689
 
685
690
  Args:
@@ -707,6 +712,9 @@ class ModelVersion(lineage_node.LineageNode):
707
712
  force_rebuild: Whether to force a model inference image rebuild.
708
713
  build_external_access_integrations: The external access integrations for image build. This is usually
709
714
  permitting access to conda & PyPI repositories.
715
+ block: A bool value indicating whether this function will wait until the service is available.
716
+ When it is ``False``, this function executes the underlying service creation asynchronously
717
+ and returns an :class:`AsyncJob`.
710
718
  """
711
719
  ...
712
720
 
@@ -742,7 +750,8 @@ class ModelVersion(lineage_node.LineageNode):
742
750
  force_rebuild: bool = False,
743
751
  build_external_access_integration: Optional[str] = None,
744
752
  build_external_access_integrations: Optional[List[str]] = None,
745
- ) -> str:
753
+ block: bool = True,
754
+ ) -> Union[str, async_job.AsyncJob]:
746
755
  """Create an inference service with the given spec.
747
756
 
748
757
  Args:
@@ -772,12 +781,16 @@ class ModelVersion(lineage_node.LineageNode):
772
781
  usually permitting access to conda & PyPI repositories.
773
782
  build_external_access_integrations: The external access integrations for image build. This is usually
774
783
  permitting access to conda & PyPI repositories.
784
+ block: A bool value indicating whether this function will wait until the service is available.
785
+ When it is False, this function executes the underlying service creation asynchronously
786
+ and returns an AsyncJob.
775
787
 
776
788
  Raises:
777
789
  ValueError: Illegal external access integration arguments.
778
790
 
779
791
  Returns:
780
- Result information about service creation from server.
792
+ If `block=True`, return result information about service creation from server.
793
+ Otherwise, return the service creation AsyncJob.
781
794
  """
782
795
  statement_params = telemetry.get_statement_params(
783
796
  project=_TELEMETRY_PROJECT,
@@ -829,6 +842,7 @@ class ModelVersion(lineage_node.LineageNode):
829
842
  if build_external_access_integrations is None
830
843
  else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
831
844
  ),
845
+ block=block,
832
846
  statement_params=statement_params,
833
847
  )
834
848
 
@@ -851,17 +865,13 @@ class ModelVersion(lineage_node.LineageNode):
851
865
  )
852
866
 
853
867
  return pd.DataFrame(
854
- self._model_ops.list_inference_services(
868
+ self._model_ops.show_services(
855
869
  database_name=None,
856
870
  schema_name=None,
857
871
  model_name=self._model_name,
858
872
  version_name=self._version_name,
859
873
  statement_params=statement_params,
860
- ),
861
- columns=[
862
- self._model_ops.INFERENCE_SERVICE_NAME_COL_NAME,
863
- self._model_ops.INFERENCE_SERVICE_ENDPOINT_COL_NAME,
864
- ],
874
+ )
865
875
  )
866
876
 
867
877
  @telemetry.send_api_usage_telemetry(
@@ -889,12 +899,16 @@ class ModelVersion(lineage_node.LineageNode):
889
899
  project=_TELEMETRY_PROJECT,
890
900
  subproject=_TELEMETRY_SUBPROJECT,
891
901
  )
902
+
903
+ database_name_id, schema_name_id, service_name_id = sql_identifier.parse_fully_qualified_name(service_name)
892
904
  self._model_ops.delete_service(
893
905
  database_name=None,
894
906
  schema_name=None,
895
907
  model_name=self._model_name,
896
908
  version_name=self._version_name,
897
- service_name=service_name,
909
+ service_database_name=database_name_id,
910
+ service_schema_name=schema_name_id,
911
+ service_name=service_name_id,
898
912
  statement_params=statement_params,
899
913
  )
900
914
 
@@ -3,7 +3,7 @@ import os
3
3
  import pathlib
4
4
  import tempfile
5
5
  import warnings
6
- from typing import Any, Dict, List, Literal, Optional, Union, cast, overload
6
+ from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast, overload
7
7
 
8
8
  import yaml
9
9
 
@@ -31,9 +31,14 @@ from snowflake.snowpark import dataframe, row, session
31
31
  from snowflake.snowpark._internal import utils as snowpark_utils
32
32
 
33
33
 
34
+ class ServiceInfo(TypedDict):
35
+ name: str
36
+ inference_endpoint: Optional[str]
37
+
38
+
34
39
  class ModelOperator:
35
- INFERENCE_SERVICE_NAME_COL_NAME = "service_name"
36
- INFERENCE_SERVICE_ENDPOINT_COL_NAME = "endpoints"
40
+ INFERENCE_SERVICE_ENDPOINT_NAME = "inference"
41
+ INGRESS_ENDPOINT_URL_SUFFIX = "snowflakecomputing.app"
37
42
 
38
43
  def __init__(
39
44
  self,
@@ -517,7 +522,7 @@ class ModelOperator:
517
522
  statement_params=statement_params,
518
523
  )
519
524
 
520
- def list_inference_services(
525
+ def show_services(
521
526
  self,
522
527
  *,
523
528
  database_name: Optional[sql_identifier.SqlIdentifier],
@@ -525,7 +530,7 @@ class ModelOperator:
525
530
  model_name: sql_identifier.SqlIdentifier,
526
531
  version_name: sql_identifier.SqlIdentifier,
527
532
  statement_params: Optional[Dict[str, Any]] = None,
528
- ) -> Dict[str, List[str]]:
533
+ ) -> List[ServiceInfo]:
529
534
  res = self._model_client.show_versions(
530
535
  database_name=database_name,
531
536
  schema_name=schema_name,
@@ -546,21 +551,28 @@ class ModelOperator:
546
551
 
547
552
  json_array = json.loads(res[0][service_col_name])
548
553
  # TODO(sdas): Figure out a better way to filter out MODEL_BUILD_ services server side.
549
- services = [str(service) for service in json_array if "MODEL_BUILD_" not in service]
550
- endpoint_col_name = self._model_client.MODEL_INFERENCE_SERVICE_ENDPOINT_COL_NAME
551
-
552
- services_col, endpoints_col = [], []
553
- for service in services:
554
- res = self._model_client.show_endpoints(service_name=service)
555
- endpoints = [endpoint[endpoint_col_name] for endpoint in res]
556
- for endpoint in endpoints:
557
- services_col.append(service)
558
- endpoints_col.append(endpoint)
559
-
560
- return {
561
- self.INFERENCE_SERVICE_NAME_COL_NAME: services_col,
562
- self.INFERENCE_SERVICE_ENDPOINT_COL_NAME: endpoints_col,
563
- }
554
+ fully_qualified_service_names = [str(service) for service in json_array if "MODEL_BUILD_" not in service]
555
+
556
+ result = []
557
+ ingress_url: Optional[str] = None
558
+ for fully_qualified_service_name in fully_qualified_service_names:
559
+ db, schema, service_name = sql_identifier.parse_fully_qualified_name(fully_qualified_service_name)
560
+ for res_row in self._service_client.show_endpoints(
561
+ database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
562
+ ):
563
+ if (
564
+ res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME]
565
+ == self.INFERENCE_SERVICE_ENDPOINT_NAME
566
+ and res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME] is not None
567
+ ):
568
+ ingress_url = str(
569
+ res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME]
570
+ )
571
+ if not ingress_url.endswith(ModelOperator.INGRESS_ENDPOINT_URL_SUFFIX):
572
+ ingress_url = None
573
+ result.append(ServiceInfo(name=fully_qualified_service_name, inference_endpoint=ingress_url))
574
+
575
+ return result
564
576
 
565
577
  def delete_service(
566
578
  self,
@@ -569,33 +581,42 @@ class ModelOperator:
569
581
  schema_name: Optional[sql_identifier.SqlIdentifier],
570
582
  model_name: sql_identifier.SqlIdentifier,
571
583
  version_name: sql_identifier.SqlIdentifier,
572
- service_name: str,
584
+ service_database_name: Optional[sql_identifier.SqlIdentifier],
585
+ service_schema_name: Optional[sql_identifier.SqlIdentifier],
586
+ service_name: sql_identifier.SqlIdentifier,
573
587
  statement_params: Optional[Dict[str, Any]] = None,
574
588
  ) -> None:
575
- services = self.list_inference_services(
589
+ services = self.show_services(
576
590
  database_name=database_name,
577
591
  schema_name=schema_name,
578
592
  model_name=model_name,
579
593
  version_name=version_name,
580
594
  statement_params=statement_params,
581
595
  )
582
- db, schema, service_name = sql_identifier.parse_fully_qualified_name(service_name)
596
+
597
+ # Fall back to the model's database and schema.
598
+ # database_name and schema_name are set if the model was created or retrieved using a fully qualified name.
599
+ # Otherwise, the model's database and schema are the same as the registry's database and schema, which are set in the
600
+ # self._model_client.
601
+
602
+ service_database_name = service_database_name or database_name or self._model_client._database_name
603
+ service_schema_name = service_schema_name or schema_name or self._model_client._schema_name
583
604
  fully_qualified_service_name = sql_identifier.get_fully_qualified_name(
584
- db, schema, service_name, self._session.get_current_database(), self._session.get_current_schema()
605
+ service_database_name, service_schema_name, service_name
585
606
  )
586
607
 
587
- service_col_name = self.INFERENCE_SERVICE_NAME_COL_NAME
588
- for service in services[service_col_name]:
589
- if service == fully_qualified_service_name:
608
+ for service_info in services:
609
+ if service_info["name"] == fully_qualified_service_name:
590
610
  self._service_client.drop_service(
591
- database_name=db,
592
- schema_name=schema,
611
+ database_name=service_database_name,
612
+ schema_name=service_schema_name,
593
613
  service_name=service_name,
594
614
  statement_params=statement_params,
595
615
  )
596
616
  return
597
617
  raise ValueError(
598
- f"Service '{service_name}' does not exist or unauthorized or not associated with this model version."
618
+ f"Service '{fully_qualified_service_name}' does not exist "
619
+ "or unauthorized or not associated with this model version."
599
620
  )
600
621
 
601
622
  def get_model_version_manifest(