snowflake-ml-python 1.7.0__py3-none-any.whl → 1.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -0
- snowflake/cortex/_complete.py +107 -64
- snowflake/cortex/_finetune.py +273 -0
- snowflake/cortex/_sse_client.py +91 -28
- snowflake/cortex/_util.py +30 -1
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/data/__init__.py +5 -0
- snowflake/ml/model/_client/model/model_version_impl.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +51 -30
- snowflake/ml/model/_client/ops/service_ops.py +13 -2
- snowflake/ml/model/_client/sql/model.py +0 -14
- snowflake/ml/model/_client/sql/service.py +25 -1
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
- snowflake/ml/model/_packager/model_env/model_env.py +12 -0
- snowflake/ml/model/_packager/model_handlers/_utils.py +1 -1
- snowflake/ml/model/_packager/model_handlers/catboost.py +1 -1
- snowflake/ml/model/_packager/model_handlers/custom.py +3 -1
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +2 -1
- snowflake/ml/model/_packager/model_handlers/sklearn.py +48 -1
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +1 -1
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +23 -6
- snowflake/ml/model/_packager/model_handlers/torchscript.py +14 -14
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +2 -3
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +5 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -10
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -9
- snowflake/ml/model/_packager/model_task/model_task_utils.py +1 -1
- snowflake/ml/model/_signatures/core.py +63 -16
- snowflake/ml/model/_signatures/pandas_handler.py +71 -27
- snowflake/ml/model/_signatures/pytorch_handler.py +2 -2
- snowflake/ml/model/_signatures/snowpark_handler.py +2 -1
- snowflake/ml/model/_signatures/tensorflow_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +4 -0
- snowflake/ml/model/model_signature.py +38 -9
- snowflake/ml/model/type_hints.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +2 -4
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +2 -4
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +158 -1045
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +106 -230
- snowflake/ml/monitoring/entities/model_monitor_config.py +10 -10
- snowflake/ml/monitoring/model_monitor.py +7 -96
- snowflake/ml/registry/registry.py +17 -29
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/METADATA +31 -5
- {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/RECORD +48 -47
- {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/WHEEL +1 -1
- snowflake/ml/monitoring/entities/model_monitor_interval.py +0 -46
- {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/top_level.txt +0 -0
snowflake/cortex/_sse_client.py
CHANGED
@@ -1,73 +1,125 @@
-
+import json
+from typing import Any, Iterator, Optional
 
-
+_FIELD_SEPARATOR = ":"
 
 
 class Event:
-
+    """Representation of an event from the event stream."""
+
+    def __init__(
+        self,
+        id: Optional[str] = None,
+        event: str = "message",
+        data: str = "",
+        comment: Optional[str] = None,
+        retry: Optional[int] = None,
+    ) -> None:
+        self.id = id
         self.event = event
         self.data = data
+        self.comment = comment
+        self.retry = retry
 
     def __str__(self) -> str:
         s = f"{self.event} event"
+        if self.id:
+            s += f" #{self.id}"
         if self.data:
-            s +=
+            s += ", {} byte{}".format(len(self.data), "s" if len(self.data) else "")
         else:
             s += ", no data"
+        if self.comment:
+            s += f", comment: {self.comment}"
+        if self.retry:
+            s += f", retry in {self.retry}ms"
         return s
 
 
+# This is copied from the snowpy library:
+# https://github.com/snowflakedb/snowpy/blob/main/libs/snowflake.core/src/snowflake/core/rest.py#L39
+# TODO(SNOW-1750723) - Current there's code duplication across snowflake-ml-python
+# and snowpy library for Cortex REST API which was done to meet our GA timelines
+# Once snowpy has a release with https://github.com/snowflakedb/snowpy/pull/679, we should
+# remove the class here and directly refer from the snowflake.core package directly
 class SSEClient:
-    def __init__(self,
+    def __init__(self, event_source: Any, char_enc: str = "utf-8") -> None:
+        self._event_source = event_source
+        self._char_enc = char_enc
 
-
-
-
-
-        lines = b""
-        for chunk in self.response:
+    def _read(self) -> Iterator[bytes]:
+        data = b""
+        for chunk in self._event_source:
             for line in chunk.splitlines(True):
-
-                if
-                    yield
-
-        if
-            yield
+                data += line
+                if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
+                    yield data
+                    data = b""
+        if data:
+            yield data
 
     def events(self) -> Iterator[Event]:
-
+        content_type = self._event_source.headers.get("Content-Type")
+        # The check for empty content-type is present because it's being populated after
+        # the change in https://github.com/snowflakedb/snowflake/pull/217654.
+        # This can be removed once the above change makes it to prod or we move to snowpy
+        # for SSEClient implementation.
+        if content_type == "text/event-stream" or not content_type:
+            return self._handle_sse()
+        elif content_type == "application/json":
+            return self._handle_json()
+        else:
+            raise ValueError(f"Unknown Content-Type: {content_type}")
+
+    def _handle_sse(self) -> Iterator[Event]:
+        for chunk in self._read():
             event = Event()
-            # splitlines() only uses \r and \n
-            for
+            # Split before decoding so splitlines() only uses \r and \n
+            for line_bytes in chunk.splitlines():
+                # Decode the line.
+                line = line_bytes.decode(self._char_enc)
 
-
+                # Lines starting with a separator are comments and are to be
+                # ignored.
+                if not line.strip() or line.startswith(_FIELD_SEPARATOR):
+                    continue
 
-                data = line.split(
+                data = line.split(_FIELD_SEPARATOR, 1)
                 field = data[0]
 
+                # Ignore unknown fields.
+                if not hasattr(event, field):
+                    continue
+
                 if len(data) > 1:
+                    # From the spec:
                     # "If value starts with a single U+0020 SPACE character,
-                    # remove it from value.
+                    # remove it from value."
                     if data[1].startswith(" "):
                         value = data[1][1:]
                     else:
                         value = data[1]
                 else:
+                    # If no value is present after the separator,
+                    # assume an empty value.
                     value = ""
 
                 # The data field may come over multiple lines and their values
                 # are concatenated with each other.
+                current_value = getattr(event, field, "")
                 if field == "data":
-
-
-
+                    new_value = current_value + value + "\n"
+                else:
+                    new_value = value
+                setattr(event, field, new_value)
 
+            # Events with no data are not dispatched.
             if not event.data:
                 continue
 
             # If the data field ends with a newline, remove it.
             if event.data.endswith("\n"):
-                event.data = event.data[0:-1]
+                event.data = event.data[0:-1]
 
             # Empty event names default to 'message'
             event.event = event.event or "message"
@@ -77,5 +129,16 @@ class SSEClient:
 
             yield event
 
+    def _handle_json(self) -> Iterator[Event]:
+        data_list = json.loads(self._event_source.data.decode(self._char_enc))
+        for data in data_list:
+            yield Event(
+                id=data.get("id"),
+                event=data.get("event"),
+                data=data.get("data"),
+                comment=data.get("comment"),
+                retry=data.get("retry"),
+            )
+
     def close(self) -> None:
-        self.
+        self._event_source.close()
snowflake/cortex/_util.py
CHANGED
@@ -1,6 +1,8 @@
-from typing import Dict, List, Optional, Union, cast
+from typing import Any, Dict, List, Optional, Union, cast
 
 from snowflake import snowpark
+from snowflake.ml._internal.exceptions import error_codes, exceptions
+from snowflake.ml._internal.utils import formatting
 from snowflake.snowpark import context, functions
 
 CORTEX_FUNCTIONS_TELEMETRY_PROJECT = "CortexFunctions"
@@ -64,3 +66,30 @@ def _call_sql_function_immediate(
     empty_df = session.create_dataframe([snowpark.Row()])
     df = empty_df.select(functions.builtin(function)(*lit_args))
     return cast(str, df.collect()[0][0])
+
+
+def call_sql_function_literals(function: str, session: Optional[snowpark.Session], *args: Any) -> str:
+    r"""Call a SQL function with only literal arguments.
+
+    This is useful for calling system functions.
+
+    Args:
+        function: The name of the function to be called.
+        session: The Snowpark session to use.
+        *args: The list of arguments
+
+    Returns:
+        String value that corresponds the the first cell in the dataframe.
+
+    Raises:
+        SnowflakeMLException: If no session is given and no active session exists.
+    """
+    if session is None:
+        session = context.get_active_session()
+    if session is None:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.INVALID_SNOWPARK_SESSION,
+        )
+
+    function_arguments = ",".join(["NULL" if arg is None else formatting.format_value_for_select(arg) for arg in args])
+    return cast(str, session.sql(f"SELECT {function}({function_arguments})").collect()[0][0])
snowflake/ml/_internal/type_utils.py
CHANGED
@@ -1,4 +1,4 @@
-import
+import importlib
 from typing import Any, Generic, Type, TypeVar, Union, cast
 
 import numpy as np
@@ -51,8 +51,8 @@ class LazyType(Generic[T]):
     def get_class(self) -> Type[T]:
         if self._runtime_class is None:
             try:
-                m =
-            except
+                m = importlib.import_module(self.module)
+            except ModuleNotFoundError:
                 raise ValueError(f"Module {self.module} not imported.")
 
             self._runtime_class = cast("Type[T]", getattr(m, self.qualname))
snowflake/ml/data/__init__.py
CHANGED
@@ -0,0 +1,5 @@
+from .data_connector import DataConnector
+from .data_ingestor import DataIngestor, DataIngestorType
+from .data_source import DataFrameInfo, DatasetInfo, DataSource
+
+__all__ = ["DataConnector", "DataSource", "DataFrameInfo", "DatasetInfo", "DataIngestor", "DataIngestorType"]
snowflake/ml/model/_client/model/model_version_impl.py
CHANGED
@@ -851,17 +851,13 @@ class ModelVersion(lineage_node.LineageNode):
         )
 
         return pd.DataFrame(
-            self._model_ops.
+            self._model_ops.show_services(
                 database_name=None,
                 schema_name=None,
                 model_name=self._model_name,
                 version_name=self._version_name,
                 statement_params=statement_params,
-            )
-            columns=[
-                self._model_ops.INFERENCE_SERVICE_NAME_COL_NAME,
-                self._model_ops.INFERENCE_SERVICE_ENDPOINT_COL_NAME,
-            ],
+            )
         )
 
     @telemetry.send_api_usage_telemetry(
@@ -889,12 +885,16 @@ class ModelVersion(lineage_node.LineageNode):
             project=_TELEMETRY_PROJECT,
             subproject=_TELEMETRY_SUBPROJECT,
         )
+
+        database_name_id, schema_name_id, service_name_id = sql_identifier.parse_fully_qualified_name(service_name)
         self._model_ops.delete_service(
             database_name=None,
             schema_name=None,
             model_name=self._model_name,
             version_name=self._version_name,
-
+            service_database_name=database_name_id,
+            service_schema_name=schema_name_id,
+            service_name=service_name_id,
             statement_params=statement_params,
         )
 
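The delete_service path above now splits the user-supplied service name before dispatching; a small, hypothetical illustration of that parsing step (the identifier is illustrative):

from snowflake.ml._internal.utils import sql_identifier

db, schema, name = sql_identifier.parse_fully_qualified_name("MY_DB.MY_SCHEMA.MY_SERVICE")
# Roughly: db -> MY_DB, schema -> MY_SCHEMA, name -> MY_SERVICE.
# A bare "MY_SERVICE" would presumably leave db and schema unset, deferring to
# the model's own database and schema (see the fallback logic in model_ops below).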
snowflake/ml/model/_client/ops/model_ops.py
CHANGED
@@ -3,7 +3,7 @@ import os
 import pathlib
 import tempfile
 import warnings
-from typing import Any, Dict, List, Literal, Optional, Union, cast, overload
+from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast, overload
 
 import yaml
 
@@ -31,9 +31,14 @@ from snowflake.snowpark import dataframe, row, session
 from snowflake.snowpark._internal import utils as snowpark_utils
 
 
+class ServiceInfo(TypedDict):
+    name: str
+    inference_endpoint: Optional[str]
+
+
 class ModelOperator:
-
-
+    INFERENCE_SERVICE_ENDPOINT_NAME = "inference"
+    INGRESS_ENDPOINT_URL_SUFFIX = "snowflakecomputing.app"
 
     def __init__(
         self,
@@ -517,7 +522,7 @@ class ModelOperator:
             statement_params=statement_params,
         )
 
-    def
+    def show_services(
         self,
         *,
         database_name: Optional[sql_identifier.SqlIdentifier],
@@ -525,7 +530,7 @@ class ModelOperator:
         model_name: sql_identifier.SqlIdentifier,
         version_name: sql_identifier.SqlIdentifier,
         statement_params: Optional[Dict[str, Any]] = None,
-    ) ->
+    ) -> List[ServiceInfo]:
         res = self._model_client.show_versions(
             database_name=database_name,
             schema_name=schema_name,
@@ -546,21 +551,28 @@ class ModelOperator:
 
         json_array = json.loads(res[0][service_col_name])
         # TODO(sdas): Figure out a better way to filter out MODEL_BUILD_ services server side.
-
-
-
-
-        for
-
-
-
-
-
-
-
-
-
-
+        fully_qualified_service_names = [str(service) for service in json_array if "MODEL_BUILD_" not in service]
+
+        result = []
+        ingress_url: Optional[str] = None
+        for fully_qualified_service_name in fully_qualified_service_names:
+            db, schema, service_name = sql_identifier.parse_fully_qualified_name(fully_qualified_service_name)
+            for res_row in self._service_client.show_endpoints(
+                database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
+            ):
+                if (
+                    res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME]
+                    == self.INFERENCE_SERVICE_ENDPOINT_NAME
+                    and res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME] is not None
+                ):
+                    ingress_url = str(
+                        res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME]
+                    )
+                    if not ingress_url.endswith(ModelOperator.INGRESS_ENDPOINT_URL_SUFFIX):
+                        ingress_url = None
+            result.append(ServiceInfo(name=fully_qualified_service_name, inference_endpoint=ingress_url))
+
+        return result
 
     def delete_service(
         self,
@@ -569,33 +581,42 @@ class ModelOperator:
         schema_name: Optional[sql_identifier.SqlIdentifier],
         model_name: sql_identifier.SqlIdentifier,
         version_name: sql_identifier.SqlIdentifier,
-
+        service_database_name: Optional[sql_identifier.SqlIdentifier],
+        service_schema_name: Optional[sql_identifier.SqlIdentifier],
+        service_name: sql_identifier.SqlIdentifier,
         statement_params: Optional[Dict[str, Any]] = None,
     ) -> None:
-        services = self.
+        services = self.show_services(
             database_name=database_name,
             schema_name=schema_name,
             model_name=model_name,
             version_name=version_name,
             statement_params=statement_params,
         )
-
+
+        # Fall back to the model's database and schema.
+        # database_name or schema_name are set if the model is created or get using fully qualified name
+        # Otherwise, the model's database and schema are same as registry's database and schema, which are set in the
+        # self._model_client.
+
+        service_database_name = service_database_name or database_name or self._model_client._database_name
+        service_schema_name = service_schema_name or schema_name or self._model_client._schema_name
         fully_qualified_service_name = sql_identifier.get_fully_qualified_name(
-
+            service_database_name, service_schema_name, service_name
         )
 
-
-
-            if service == fully_qualified_service_name:
+        for service_info in services:
+            if service_info["name"] == fully_qualified_service_name:
                 self._service_client.drop_service(
-                    database_name=
-                    schema_name=
+                    database_name=service_database_name,
+                    schema_name=service_schema_name,
                     service_name=service_name,
                     statement_params=statement_params,
                 )
                 return
         raise ValueError(
-            f"Service '{
+            f"Service '{fully_qualified_service_name}' does not exist "
+            "or unauthorized or not associated with this model version."
         )
 
     def get_model_version_manifest(
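The shape that show_services now returns is captured by the ServiceInfo TypedDict introduced above; a brief, self-contained illustration with made-up values:

from typing import List, Optional, TypedDict


class ServiceInfo(TypedDict):  # mirrors the TypedDict added in this diff
    name: str
    inference_endpoint: Optional[str]


services: List[ServiceInfo] = [
    {"name": "MY_DB.MY_SCHEMA.MY_SERVICE", "inference_endpoint": "org-acct.snowflakecomputing.app"},
    {"name": "MY_DB.MY_SCHEMA.BATCH_ONLY_SERVICE", "inference_endpoint": None},
]
for info in services:
    print(info["name"], "->", info["inference_endpoint"])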
snowflake/ml/model/_client/ops/service_ops.py
CHANGED
@@ -109,6 +109,17 @@ class ServiceOperator:
         build_external_access_integrations: Optional[List[sql_identifier.SqlIdentifier]],
         statement_params: Optional[Dict[str, Any]] = None,
     ) -> str:
+
+        # Fall back to the registry's database and schema if not provided
+        database_name = database_name or self._database_name
+        schema_name = schema_name or self._schema_name
+
+        # Fall back to the model's database and schema if not provided then to the registry's database and schema
+        service_database_name = service_database_name or database_name or self._database_name
+        service_schema_name = service_schema_name or schema_name or self._schema_name
+
+        image_repo_database_name = image_repo_database_name or database_name or self._database_name
+        image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name
         # create a temp stage
         stage_name = sql_identifier.SqlIdentifier(
             snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
@@ -130,8 +141,8 @@ class ServiceOperator:
             raise ValueError("External access integrations are required in Snowflake < 8.40.0.")
 
         self._model_deployment_spec.save(
-            database_name=database_name
-            schema_name=schema_name
+            database_name=database_name,
+            schema_name=schema_name,
             model_name=model_name,
             version_name=version_name,
             service_database_name=service_database_name,
snowflake/ml/model/_client/sql/model.py
CHANGED
@@ -17,8 +17,6 @@ class ModelSQLClient(_base._BaseSQLClient):
     MODEL_VERSION_ALIASES_COL_NAME = "aliases"
     MODEL_VERSION_INFERENCE_SERVICES_COL_NAME = "inference_services"
 
-    MODEL_INFERENCE_SERVICE_ENDPOINT_COL_NAME = "name"
-
     def show_models(
         self,
         *,
@@ -85,18 +83,6 @@ class ModelSQLClient(_base._BaseSQLClient):
 
         return res.validate()
 
-    def show_endpoints(
-        self,
-        *,
-        service_name: str,
-    ) -> List[row.Row]:
-        res = query_result_checker.SqlResultValidator(
-            self._session,
-            (f"SHOW ENDPOINTS IN SERVICE {service_name}"),
-        ).has_column(ModelSQLClient.MODEL_VERSION_NAME_COL_NAME, allow_empty=True)
-
-        return res.validate()
-
     def set_comment(
         self,
         *,
snowflake/ml/model/_client/sql/service.py
CHANGED
@@ -10,7 +10,7 @@ from snowflake.ml._internal.utils import (
     sql_identifier,
 )
 from snowflake.ml.model._client.sql import _base
-from snowflake.snowpark import dataframe, functions as F, types as spt
+from snowflake.snowpark import dataframe, functions as F, row, types as spt
 from snowflake.snowpark._internal import utils as snowpark_utils
 
 
@@ -26,6 +26,9 @@ class ServiceStatus(enum.Enum):
 
 
 class ServiceSQLClient(_base._BaseSQLClient):
+    MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME = "name"
+    MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME = "ingress_url"
+
     def build_model_container(
         self,
         *,
@@ -216,3 +219,24 @@ class ServiceSQLClient(_base._BaseSQLClient):
             f"DROP SERVICE {self.fully_qualified_object_name(database_name, schema_name, service_name)}",
             statement_params=statement_params,
         ).has_dimensions(expected_rows=1, expected_cols=1).validate()
+
+    def show_endpoints(
+        self,
+        *,
+        database_name: Optional[sql_identifier.SqlIdentifier],
+        schema_name: Optional[sql_identifier.SqlIdentifier],
+        service_name: sql_identifier.SqlIdentifier,
+        statement_params: Optional[Dict[str, Any]] = None,
+    ) -> List[row.Row]:
+        fully_qualified_service_name = self.fully_qualified_object_name(database_name, schema_name, service_name)
+        res = (
+            query_result_checker.SqlResultValidator(
+                self._session,
+                (f"SHOW ENDPOINTS IN SERVICE {fully_qualified_service_name}"),
+                statement_params=statement_params,
+            )
+            .has_column(ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME, allow_empty=True)
+            .has_column(ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME, allow_empty=True)
+        )
+
+        return res.validate()
snowflake/ml/model/_model_composer/model_method/infer_function.py_template
CHANGED
@@ -5,6 +5,7 @@ import sys
 
 import anyio
 import pandas as pd
+import numpy as np
 from _snowflake import vectorized
 
 from snowflake.ml.model._packager import model_packager
@@ -47,4 +48,4 @@ def {function_name}(df: pd.DataFrame) -> dict:
     df.columns = input_cols
     input_df = df.astype(dtype=dtype_map)
     predictions_df = runner(input_df[input_cols])
-    return predictions_df.to_dict("records")
+    return predictions_df.replace({{pd.NA: None, np.nan: None}}).to_dict("records")
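The template change above matters because NaN values survive to_dict("records") and are not JSON-friendly; a standalone pandas illustration (not part of the package):

import numpy as np
import pandas as pd

predictions_df = pd.DataFrame({"output": [1.0, np.nan]})

print(predictions_df.to_dict("records"))
# [{'output': 1.0}, {'output': nan}]

print(predictions_df.replace({pd.NA: None, np.nan: None}).to_dict("records"))
# [{'output': 1.0}, {'output': None}]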
snowflake/ml/model/_packager/model_env/model_env.py
CHANGED
@@ -174,6 +174,18 @@ class ModelEnv:
         except env_utils.DuplicateDependencyError:
             pass
 
+    def remove_if_present_conda(self, conda_pkgs: List[str]) -> None:
+        """Remove conda requirements from model env if present.
+
+        Args:
+            conda_pkgs: A list of package name to be removed from conda requirements.
+        """
+        for pkg_name in conda_pkgs:
+            spec_conda = env_utils._find_conda_dep_spec(self._conda_dependencies, pkg_name)
+            if spec_conda:
+                channel, spec = spec_conda
+                self._conda_dependencies[channel].remove(spec)
+
     def generate_env_for_cuda(self) -> None:
         if self.cuda_version is None:
             return
snowflake/ml/model/_packager/model_handlers/_utils.py
CHANGED
@@ -179,7 +179,7 @@ def convert_explanations_to_2D_df(
         return pd.DataFrame(explanations)
 
     if hasattr(model, "classes_"):
-        classes_list = [str(cl) for cl in model.classes_]
+        classes_list = [str(cl) for cl in model.classes_]
         len_classes = len(classes_list)
         if explanations.shape[2] != len_classes:
             raise ValueError(f"Model has {len_classes} classes but explanations have {explanations.shape[2]}")
snowflake/ml/model/_packager/model_handlers/catboost.py
CHANGED
@@ -95,7 +95,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
             get_prediction_fn=get_prediction,
         )
         model_task_and_output = model_task_utils.get_model_task_and_output_type(model)
-        model_meta.task = model_task_and_output.task
+        model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output.task)
         if enable_explainability:
             explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
             model_meta = handlers_utils.add_explain_method_signature(
snowflake/ml/model/_packager/model_handlers/custom.py
CHANGED
@@ -2,7 +2,7 @@ import inspect
 import os
 import pathlib
 import sys
-from typing import Dict, Optional, Type, final
+from typing import Dict, Optional, Type, cast, final
 
 import anyio
 import cloudpickle
@@ -108,6 +108,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
                 model_meta=model_meta,
                 model_blobs_dir_path=model_blobs_dir_path,
                 is_sub_model=True,
+                **cast(model_types.BaseModelSaveOption, kwargs),
             )
 
         # Make sure that the module where the model is defined get pickled by value as well.
@@ -175,6 +176,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
                 name=sub_model_name,
                 model_meta=model_meta,
                 model_blobs_dir_path=model_blobs_dir_path,
+                **cast(model_types.BaseModelLoadOption, kwargs),
             )
             models[sub_model_name] = sub_model
         reconstructed_context = custom_model.ModelContext(artifacts=artifacts, models=models)
snowflake/ml/model/_packager/model_handlers/lightgbm.py
CHANGED
@@ -196,13 +196,14 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
         with open(model_blob_file_path, "rb") as f:
             model = cloudpickle.load(f)
             assert isinstance(model, getattr(lightgbm, lightgbm_estimator_type))
+            assert isinstance(model, lightgbm.Booster) or isinstance(model, lightgbm.LGBMModel)
 
         return model
 
     @classmethod
     def convert_as_custom_model(
         cls,
-        raw_model: Union["lightgbm.Booster", "lightgbm.
+        raw_model: Union["lightgbm.Booster", "lightgbm.LGBMModel"],
         model_meta: model_meta_api.ModelMetadata,
         background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.LGBMModelLoadOptions],