snowflake-ml-python 1.7.0__py3-none-any.whl → 1.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_complete.py +107 -64
  3. snowflake/cortex/_finetune.py +273 -0
  4. snowflake/cortex/_sse_client.py +91 -28
  5. snowflake/cortex/_util.py +30 -1
  6. snowflake/ml/_internal/type_utils.py +3 -3
  7. snowflake/ml/data/__init__.py +5 -0
  8. snowflake/ml/model/_client/model/model_version_impl.py +7 -7
  9. snowflake/ml/model/_client/ops/model_ops.py +51 -30
  10. snowflake/ml/model/_client/ops/service_ops.py +13 -2
  11. snowflake/ml/model/_client/sql/model.py +0 -14
  12. snowflake/ml/model/_client/sql/service.py +25 -1
  13. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
  14. snowflake/ml/model/_packager/model_env/model_env.py +12 -0
  15. snowflake/ml/model/_packager/model_handlers/_utils.py +1 -1
  16. snowflake/ml/model/_packager/model_handlers/catboost.py +1 -1
  17. snowflake/ml/model/_packager/model_handlers/custom.py +3 -1
  18. snowflake/ml/model/_packager/model_handlers/lightgbm.py +2 -1
  19. snowflake/ml/model/_packager/model_handlers/sklearn.py +48 -1
  20. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +1 -1
  21. snowflake/ml/model/_packager/model_handlers/tensorflow.py +23 -6
  22. snowflake/ml/model/_packager/model_handlers/torchscript.py +14 -14
  23. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +2 -3
  24. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +5 -0
  25. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -10
  26. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -9
  27. snowflake/ml/model/_packager/model_task/model_task_utils.py +1 -1
  28. snowflake/ml/model/_signatures/core.py +63 -16
  29. snowflake/ml/model/_signatures/pandas_handler.py +71 -27
  30. snowflake/ml/model/_signatures/pytorch_handler.py +2 -2
  31. snowflake/ml/model/_signatures/snowpark_handler.py +2 -1
  32. snowflake/ml/model/_signatures/tensorflow_handler.py +2 -2
  33. snowflake/ml/model/_signatures/utils.py +4 -0
  34. snowflake/ml/model/model_signature.py +38 -9
  35. snowflake/ml/model/type_hints.py +1 -1
  36. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +2 -4
  37. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +2 -4
  38. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +158 -1045
  39. snowflake/ml/monitoring/_manager/model_monitor_manager.py +106 -230
  40. snowflake/ml/monitoring/entities/model_monitor_config.py +10 -10
  41. snowflake/ml/monitoring/model_monitor.py +7 -96
  42. snowflake/ml/registry/registry.py +17 -29
  43. snowflake/ml/version.py +1 -1
  44. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/METADATA +31 -5
  45. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/RECORD +48 -47
  46. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/WHEEL +1 -1
  47. snowflake/ml/monitoring/entities/model_monitor_interval.py +0 -46
  48. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/LICENSE.txt +0 -0
  49. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/top_level.txt +0 -0
snowflake/cortex/_sse_client.py CHANGED
@@ -1,73 +1,125 @@
-from typing import Iterator, cast
+import json
+from typing import Any, Iterator, Optional
 
-import requests
+_FIELD_SEPARATOR = ":"
 
 
 class Event:
-    def __init__(self, event: str = "message", data: str = "") -> None:
+    """Representation of an event from the event stream."""
+
+    def __init__(
+        self,
+        id: Optional[str] = None,
+        event: str = "message",
+        data: str = "",
+        comment: Optional[str] = None,
+        retry: Optional[int] = None,
+    ) -> None:
+        self.id = id
         self.event = event
         self.data = data
+        self.comment = comment
+        self.retry = retry
 
     def __str__(self) -> str:
         s = f"{self.event} event"
+        if self.id:
+            s += f" #{self.id}"
         if self.data:
-            s += f", {len(self.data)} bytes"
+            s += ", {} byte{}".format(len(self.data), "s" if len(self.data) else "")
         else:
            s += ", no data"
+        if self.comment:
+            s += f", comment: {self.comment}"
+        if self.retry:
+            s += f", retry in {self.retry}ms"
         return s
 
 
+# This is copied from the snowpy library:
+# https://github.com/snowflakedb/snowpy/blob/main/libs/snowflake.core/src/snowflake/core/rest.py#L39
+# TODO(SNOW-1750723) - Current there’s code duplication across snowflake-ml-python
+# and snowpy library for Cortex REST API which was done to meet our GA timelines
+# Once snowpy has a release with https://github.com/snowflakedb/snowpy/pull/679, we should
+# remove the class here and directly refer from the snowflake.core package directly
 class SSEClient:
-    def __init__(self, response: requests.Response) -> None:
+    def __init__(self, event_source: Any, char_enc: str = "utf-8") -> None:
+        self._event_source = event_source
+        self._char_enc = char_enc
 
-        self.response = response
-
-    def _read(self) -> Iterator[str]:
-
-        lines = b""
-        for chunk in self.response:
+    def _read(self) -> Iterator[bytes]:
+        data = b""
+        for chunk in self._event_source:
             for line in chunk.splitlines(True):
-                lines += line
-                if lines.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
-                    yield cast(str, lines)
-                    lines = b""
-        if lines:
-            yield cast(str, lines)
+                data += line
+                if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
+                    yield data
+                    data = b""
+        if data:
+            yield data
 
     def events(self) -> Iterator[Event]:
-        for raw_event in self._read():
+        content_type = self._event_source.headers.get("Content-Type")
+        # The check for empty content-type is present because it's being populated after
+        # the change in https://github.com/snowflakedb/snowflake/pull/217654.
+        # This can be removed once the above change makes it to prod or we move to snowpy
+        # for SSEClient implementation.
+        if content_type == "text/event-stream" or not content_type:
+            return self._handle_sse()
+        elif content_type == "application/json":
+            return self._handle_json()
+        else:
+            raise ValueError(f"Unknown Content-Type: {content_type}")
+
+    def _handle_sse(self) -> Iterator[Event]:
+        for chunk in self._read():
             event = Event()
-            # splitlines() only uses \r and \n
-            for line in raw_event.splitlines():
+            # Split before decoding so splitlines() only uses \r and \n
+            for line_bytes in chunk.splitlines():
+                # Decode the line.
+                line = line_bytes.decode(self._char_enc)
 
-                line = cast(bytes, line).decode("utf-8")
+                # Lines starting with a separator are comments and are to be
+                # ignored.
+                if not line.strip() or line.startswith(_FIELD_SEPARATOR):
+                    continue
 
-                data = line.split(":", 1)
+                data = line.split(_FIELD_SEPARATOR, 1)
                 field = data[0]
 
+                # Ignore unknown fields.
+                if not hasattr(event, field):
+                    continue
+
                 if len(data) > 1:
+                    # From the spec:
                     # "If value starts with a single U+0020 SPACE character,
-                    # remove it from value. .strip() would remove all white spaces"
+                    # remove it from value."
                     if data[1].startswith(" "):
                         value = data[1][1:]
                     else:
                         value = data[1]
                 else:
+                    # If no value is present after the separator,
+                    # assume an empty value.
                     value = ""
 
                 # The data field may come over multiple lines and their values
                 # are concatenated with each other.
+                current_value = getattr(event, field, "")
                 if field == "data":
-                    event.data += value + "\n"
-                elif field == "event":
-                    event.event = value
+                    new_value = current_value + value + "\n"
+                else:
+                    new_value = value
+                setattr(event, field, new_value)
 
+            # Events with no data are not dispatched.
             if not event.data:
                 continue
 
             # If the data field ends with a newline, remove it.
            if event.data.endswith("\n"):
-                event.data = event.data[0:-1]  # Replace trailing newline - rstrip would remove multiple.
+                event.data = event.data[0:-1]
 
             # Empty event names default to 'message'
             event.event = event.event or "message"
@@ -77,5 +129,16 @@ class SSEClient:
 
             yield event
 
+    def _handle_json(self) -> Iterator[Event]:
+        data_list = json.loads(self._event_source.data.decode(self._char_enc))
+        for data in data_list:
+            yield Event(
+                id=data.get("id"),
+                event=data.get("event"),
+                data=data.get("data"),
+                comment=data.get("comment"),
+                retry=data.get("retry"),
+            )
+
     def close(self) -> None:
-        self.response.close()
+        self._event_source.close()
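
The reworked client now dispatches on the response's Content-Type: a text/event-stream (or missing) header takes the incremental SSE path, while application/json is unpacked into Event objects in one pass. A minimal usage sketch, assuming only the interface the new constructor relies on; the FakeEventSource below is hypothetical, and any object that iterates byte chunks and exposes headers and close() fits:

    from snowflake.cortex._sse_client import SSEClient

    class FakeEventSource:
        # Hypothetical stand-in for an HTTP response object.
        headers = {"Content-Type": "text/event-stream"}

        def __iter__(self):
            # Byte chunks, each event terminated by a blank line per the SSE spec.
            yield b"event: message\ndata: hello\n\n"
            yield b"data: world\n\n"

        def close(self):
            pass

    client = SSEClient(FakeEventSource())
    for event in client.events():
        print(event.event, event.data)  # -> "message hello", then "message world"
    client.close()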
snowflake/cortex/_util.py CHANGED
@@ -1,6 +1,8 @@
-from typing import Dict, List, Optional, Union, cast
+from typing import Any, Dict, List, Optional, Union, cast
 
 from snowflake import snowpark
+from snowflake.ml._internal.exceptions import error_codes, exceptions
+from snowflake.ml._internal.utils import formatting
 from snowflake.snowpark import context, functions
 
 CORTEX_FUNCTIONS_TELEMETRY_PROJECT = "CortexFunctions"
@@ -64,3 +66,30 @@ def _call_sql_function_immediate(
     empty_df = session.create_dataframe([snowpark.Row()])
     df = empty_df.select(functions.builtin(function)(*lit_args))
     return cast(str, df.collect()[0][0])
+
+
+def call_sql_function_literals(function: str, session: Optional[snowpark.Session], *args: Any) -> str:
+    r"""Call a SQL function with only literal arguments.
+
+    This is useful for calling system functions.
+
+    Args:
+        function: The name of the function to be called.
+        session: The Snowpark session to use.
+        *args: The list of arguments
+
+    Returns:
+        String value that corresponds the the first cell in the dataframe.
+
+    Raises:
+        SnowflakeMLException: If no session is given and no active session exists.
+    """
+    if session is None:
+        session = context.get_active_session()
+    if session is None:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.INVALID_SNOWPARK_SESSION,
+        )
+
+    function_arguments = ",".join(["NULL" if arg is None else formatting.format_value_for_select(arg) for arg in args])
+    return cast(str, session.sql(f"SELECT {function}({function_arguments})").collect()[0][0])
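
A hedged usage sketch for the new helper. The system function named here is illustrative; the point is that each argument is rendered as a SQL literal and None becomes NULL:

    from snowflake.cortex._util import call_sql_function_literals

    # Assumes an active Snowpark session (pass one explicitly, or None to use it).
    # Builds and runs: SELECT SYSTEM$TYPEOF('hello')
    result = call_sql_function_literals("SYSTEM$TYPEOF", None, "hello")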
snowflake/ml/_internal/type_utils.py CHANGED
@@ -1,4 +1,4 @@
-import sys
+import importlib
 from typing import Any, Generic, Type, TypeVar, Union, cast
 
 import numpy as np
@@ -51,8 +51,8 @@ class LazyType(Generic[T]):
     def get_class(self) -> Type[T]:
         if self._runtime_class is None:
             try:
-                m = sys.modules[self.module]
-            except KeyError:
+                m = importlib.import_module(self.module)
+            except ModuleNotFoundError:
                 raise ValueError(f"Module {self.module} not imported.")
 
             self._runtime_class = cast("Type[T]", getattr(m, self.qualname))
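
The switch from a sys.modules lookup to importlib.import_module changes when resolution can succeed: the referenced module no longer has to be imported by the caller first. The difference in isolation:

    import importlib
    import sys

    sys.modules.pop("decimal", None)  # ensure the module is not loaded yet

    # Old lookup: sys.modules["decimal"] would raise KeyError here.
    m = importlib.import_module("decimal")  # new lookup: imports on demand
    cls = getattr(m, "Decimal")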
snowflake/ml/data/__init__.py CHANGED
@@ -0,0 +1,5 @@
+from .data_connector import DataConnector
+from .data_ingestor import DataIngestor, DataIngestorType
+from .data_source import DataFrameInfo, DatasetInfo, DataSource
+
+__all__ = ["DataConnector", "DataSource", "DataFrameInfo", "DatasetInfo", "DataIngestor", "DataIngestorType"]
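
The practical effect of the new initializer is that the re-exported names resolve from the package root:

    from snowflake.ml.data import DataConnector, DataIngestor, DataSource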
snowflake/ml/model/_client/model/model_version_impl.py CHANGED
@@ -851,17 +851,13 @@ class ModelVersion(lineage_node.LineageNode):
         )
 
         return pd.DataFrame(
-            self._model_ops.list_inference_services(
+            self._model_ops.show_services(
                 database_name=None,
                 schema_name=None,
                 model_name=self._model_name,
                 version_name=self._version_name,
                 statement_params=statement_params,
-            ),
-            columns=[
-                self._model_ops.INFERENCE_SERVICE_NAME_COL_NAME,
-                self._model_ops.INFERENCE_SERVICE_ENDPOINT_COL_NAME,
-            ],
+            )
         )
 
     @telemetry.send_api_usage_telemetry(
@@ -889,12 +885,16 @@ class ModelVersion(lineage_node.LineageNode):
             project=_TELEMETRY_PROJECT,
             subproject=_TELEMETRY_SUBPROJECT,
         )
+
+        database_name_id, schema_name_id, service_name_id = sql_identifier.parse_fully_qualified_name(service_name)
         self._model_ops.delete_service(
             database_name=None,
             schema_name=None,
             model_name=self._model_name,
             version_name=self._version_name,
-            service_name=service_name,
+            service_database_name=database_name_id,
+            service_schema_name=schema_name_id,
+            service_name=service_name_id,
             statement_params=statement_params,
         )
 
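A hedged sketch of the client-facing effect (mv stands for a ModelVersion handle obtained from the registry; the listing method is the public wrapper around the renamed show_services shown above): the services listing is now a DataFrame built directly from show_services() records, and a fully qualified name passed to delete_service is split into database, schema, and service parts before lookup:

    # mv: a snowflake.ml ModelVersion handle (obtained elsewhere).
    services_df = mv.list_services()  # DataFrame with "name" and "inference_endpoint" columns
    mv.delete_service("MY_DB.MY_SCHEMA.MY_SERVICE")  # parsed with parse_fully_qualified_name
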
snowflake/ml/model/_client/ops/model_ops.py CHANGED
@@ -3,7 +3,7 @@ import os
 import pathlib
 import tempfile
 import warnings
-from typing import Any, Dict, List, Literal, Optional, Union, cast, overload
+from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast, overload
 
 import yaml
 
@@ -31,9 +31,14 @@ from snowflake.snowpark import dataframe, row, session
 from snowflake.snowpark._internal import utils as snowpark_utils
 
 
+class ServiceInfo(TypedDict):
+    name: str
+    inference_endpoint: Optional[str]
+
+
 class ModelOperator:
-    INFERENCE_SERVICE_NAME_COL_NAME = "service_name"
-    INFERENCE_SERVICE_ENDPOINT_COL_NAME = "endpoints"
+    INFERENCE_SERVICE_ENDPOINT_NAME = "inference"
+    INGRESS_ENDPOINT_URL_SUFFIX = "snowflakecomputing.app"
 
     def __init__(
         self,
@@ -517,7 +522,7 @@ class ModelOperator:
             statement_params=statement_params,
         )
 
-    def list_inference_services(
+    def show_services(
         self,
         *,
         database_name: Optional[sql_identifier.SqlIdentifier],
@@ -525,7 +530,7 @@ class ModelOperator:
         model_name: sql_identifier.SqlIdentifier,
         version_name: sql_identifier.SqlIdentifier,
         statement_params: Optional[Dict[str, Any]] = None,
-    ) -> Dict[str, List[str]]:
+    ) -> List[ServiceInfo]:
         res = self._model_client.show_versions(
             database_name=database_name,
             schema_name=schema_name,
@@ -546,21 +551,28 @@ class ModelOperator:
 
         json_array = json.loads(res[0][service_col_name])
         # TODO(sdas): Figure out a better way to filter out MODEL_BUILD_ services server side.
-        services = [str(service) for service in json_array if "MODEL_BUILD_" not in service]
-        endpoint_col_name = self._model_client.MODEL_INFERENCE_SERVICE_ENDPOINT_COL_NAME
-
-        services_col, endpoints_col = [], []
-        for service in services:
-            res = self._model_client.show_endpoints(service_name=service)
-            endpoints = [endpoint[endpoint_col_name] for endpoint in res]
-            for endpoint in endpoints:
-                services_col.append(service)
-                endpoints_col.append(endpoint)
-
-        return {
-            self.INFERENCE_SERVICE_NAME_COL_NAME: services_col,
-            self.INFERENCE_SERVICE_ENDPOINT_COL_NAME: endpoints_col,
-        }
+        fully_qualified_service_names = [str(service) for service in json_array if "MODEL_BUILD_" not in service]
+
+        result = []
+        ingress_url: Optional[str] = None
+        for fully_qualified_service_name in fully_qualified_service_names:
+            db, schema, service_name = sql_identifier.parse_fully_qualified_name(fully_qualified_service_name)
+            for res_row in self._service_client.show_endpoints(
+                database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
+            ):
+                if (
+                    res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME]
+                    == self.INFERENCE_SERVICE_ENDPOINT_NAME
+                    and res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME] is not None
+                ):
+                    ingress_url = str(
+                        res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME]
+                    )
+                    if not ingress_url.endswith(ModelOperator.INGRESS_ENDPOINT_URL_SUFFIX):
+                        ingress_url = None
+            result.append(ServiceInfo(name=fully_qualified_service_name, inference_endpoint=ingress_url))
+
+        return result
 
     def delete_service(
         self,
@@ -569,33 +581,42 @@ class ModelOperator:
         schema_name: Optional[sql_identifier.SqlIdentifier],
         model_name: sql_identifier.SqlIdentifier,
         version_name: sql_identifier.SqlIdentifier,
-        service_name: str,
+        service_database_name: Optional[sql_identifier.SqlIdentifier],
+        service_schema_name: Optional[sql_identifier.SqlIdentifier],
+        service_name: sql_identifier.SqlIdentifier,
         statement_params: Optional[Dict[str, Any]] = None,
     ) -> None:
-        services = self.list_inference_services(
+        services = self.show_services(
             database_name=database_name,
             schema_name=schema_name,
             model_name=model_name,
             version_name=version_name,
             statement_params=statement_params,
         )
-        db, schema, service_name = sql_identifier.parse_fully_qualified_name(service_name)
+
+        # Fall back to the model's database and schema.
+        # database_name or schema_name are set if the model is created or get using fully qualified name
+        # Otherwise, the model's database and schema are same as registry's database and schema, which are set in the
+        # self._model_client.
+
+        service_database_name = service_database_name or database_name or self._model_client._database_name
+        service_schema_name = service_schema_name or schema_name or self._model_client._schema_name
         fully_qualified_service_name = sql_identifier.get_fully_qualified_name(
-            db, schema, service_name, self._session.get_current_database(), self._session.get_current_schema()
+            service_database_name, service_schema_name, service_name
         )
 
-        service_col_name = self.INFERENCE_SERVICE_NAME_COL_NAME
-        for service in services[service_col_name]:
-            if service == fully_qualified_service_name:
+        for service_info in services:
+            if service_info["name"] == fully_qualified_service_name:
                 self._service_client.drop_service(
-                    database_name=db,
-                    schema_name=schema,
+                    database_name=service_database_name,
+                    schema_name=service_schema_name,
                     service_name=service_name,
                     statement_params=statement_params,
                 )
                 return
         raise ValueError(
-            f"Service '{service_name}' does not exist or unauthorized or not associated with this model version."
+            f"Service '{fully_qualified_service_name}' does not exist "
+            "or unauthorized or not associated with this model version."
         )
 
     def get_model_version_manifest(
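
The endpoint filtering introduced above, restated as a self-contained sketch: only the endpoint named "inference" counts, and its ingress URL is kept only when it ends with the expected app domain:

    from typing import List, Mapping, Optional, TypedDict

    class ServiceInfo(TypedDict):
        name: str
        inference_endpoint: Optional[str]

    INGRESS_ENDPOINT_URL_SUFFIX = "snowflakecomputing.app"

    def to_service_info(service_name: str, endpoints: List[Mapping[str, Optional[str]]]) -> ServiceInfo:
        ingress_url: Optional[str] = None
        for ep in endpoints:  # each row carries "name" and "ingress_url" columns
            if ep["name"] == "inference" and ep["ingress_url"] is not None:
                ingress_url = str(ep["ingress_url"])
                if not ingress_url.endswith(INGRESS_ENDPOINT_URL_SUFFIX):
                    ingress_url = None
        return ServiceInfo(name=service_name, inference_endpoint=ingress_url)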
snowflake/ml/model/_client/ops/service_ops.py CHANGED
@@ -109,6 +109,17 @@ class ServiceOperator:
         build_external_access_integrations: Optional[List[sql_identifier.SqlIdentifier]],
         statement_params: Optional[Dict[str, Any]] = None,
     ) -> str:
+
+        # Fall back to the registry's database and schema if not provided
+        database_name = database_name or self._database_name
+        schema_name = schema_name or self._schema_name
+
+        # Fall back to the model's database and schema if not provided then to the registry's database and schema
+        service_database_name = service_database_name or database_name or self._database_name
+        service_schema_name = service_schema_name or schema_name or self._schema_name
+
+        image_repo_database_name = image_repo_database_name or database_name or self._database_name
+        image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name
         # create a temp stage
         stage_name = sql_identifier.SqlIdentifier(
             snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
@@ -130,8 +141,8 @@ class ServiceOperator:
             raise ValueError("External access integrations are required in Snowflake < 8.40.0.")
 
         self._model_deployment_spec.save(
-            database_name=database_name or self._database_name,
-            schema_name=schema_name or self._schema_name,
+            database_name=database_name,
+            schema_name=schema_name,
             model_name=model_name,
             version_name=version_name,
             service_database_name=service_database_name,
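
The added block is a plain or-chain fallback. In isolation (names illustrative):

    from typing import Optional

    def resolve(explicit: Optional[str], inherited: Optional[str], default: str) -> str:
        # First non-empty identifier wins: explicit > inherited > registry default.
        return explicit or inherited or default

    resolve(None, None, "REGISTRY_DB")        # -> 'REGISTRY_DB'
    resolve(None, "MODEL_DB", "REGISTRY_DB")  # -> 'MODEL_DB'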
snowflake/ml/model/_client/sql/model.py CHANGED
@@ -17,8 +17,6 @@ class ModelSQLClient(_base._BaseSQLClient):
     MODEL_VERSION_ALIASES_COL_NAME = "aliases"
     MODEL_VERSION_INFERENCE_SERVICES_COL_NAME = "inference_services"
 
-    MODEL_INFERENCE_SERVICE_ENDPOINT_COL_NAME = "name"
-
     def show_models(
         self,
         *,
@@ -85,18 +83,6 @@ class ModelSQLClient(_base._BaseSQLClient):
 
         return res.validate()
 
-    def show_endpoints(
-        self,
-        *,
-        service_name: str,
-    ) -> List[row.Row]:
-        res = query_result_checker.SqlResultValidator(
-            self._session,
-            (f"SHOW ENDPOINTS IN SERVICE {service_name}"),
-        ).has_column(ModelSQLClient.MODEL_VERSION_NAME_COL_NAME, allow_empty=True)
-
-        return res.validate()
-
     def set_comment(
         self,
         *,
snowflake/ml/model/_client/sql/service.py CHANGED
@@ -10,7 +10,7 @@ from snowflake.ml._internal.utils import (
     sql_identifier,
 )
 from snowflake.ml.model._client.sql import _base
-from snowflake.snowpark import dataframe, functions as F, types as spt
+from snowflake.snowpark import dataframe, functions as F, row, types as spt
 from snowflake.snowpark._internal import utils as snowpark_utils
 
 
@@ -26,6 +26,9 @@ class ServiceStatus(enum.Enum):
 
 
 class ServiceSQLClient(_base._BaseSQLClient):
+    MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME = "name"
+    MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME = "ingress_url"
+
     def build_model_container(
         self,
         *,
@@ -216,3 +219,24 @@ class ServiceSQLClient(_base._BaseSQLClient):
             f"DROP SERVICE {self.fully_qualified_object_name(database_name, schema_name, service_name)}",
             statement_params=statement_params,
         ).has_dimensions(expected_rows=1, expected_cols=1).validate()
+
+    def show_endpoints(
+        self,
+        *,
+        database_name: Optional[sql_identifier.SqlIdentifier],
+        schema_name: Optional[sql_identifier.SqlIdentifier],
+        service_name: sql_identifier.SqlIdentifier,
+        statement_params: Optional[Dict[str, Any]] = None,
+    ) -> List[row.Row]:
+        fully_qualified_service_name = self.fully_qualified_object_name(database_name, schema_name, service_name)
+        res = (
+            query_result_checker.SqlResultValidator(
+                self._session,
+                (f"SHOW ENDPOINTS IN SERVICE {fully_qualified_service_name}"),
+                statement_params=statement_params,
+            )
+            .has_column(ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME, allow_empty=True)
+            .has_column(ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME, allow_empty=True)
+        )
+
+        return res.validate()
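
What the new method runs and validates, sketched against a raw Snowpark session (identifiers illustrative; the SHOW ENDPOINTS output carries the name and ingress_url columns the validator checks for):

    # session: an existing snowpark.Session
    rows = session.sql("SHOW ENDPOINTS IN SERVICE MY_DB.MY_SCHEMA.MY_SERVICE").collect()
    for r in rows:
        print(r["name"], r["ingress_url"])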
snowflake/ml/model/_model_composer/model_method/infer_function.py_template CHANGED
@@ -5,6 +5,7 @@ import sys
 
 import anyio
 import pandas as pd
+import numpy as np
 from _snowflake import vectorized
 
 from snowflake.ml.model._packager import model_packager
@@ -47,4 +48,4 @@ def {function_name}(df: pd.DataFrame) -> dict:
     df.columns = input_cols
     input_df = df.astype(dtype=dtype_map)
     predictions_df = runner(input_df[input_cols])
-    return predictions_df.to_dict("records")
+    return predictions_df.replace({{pd.NA: None, np.nan: None}}).to_dict("records")
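
Why the template now scrubs missing values before to_dict("records"): NaN is not JSON-serializable, whereas None maps to null. A quick demonstration outside the template (the doubled braces above are template escaping):

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({"a": [1.0, np.nan]})
    df.to_dict("records")                                       # [{'a': 1.0}, {'a': nan}]
    df.replace({pd.NA: None, np.nan: None}).to_dict("records")  # [{'a': 1.0}, {'a': None}]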
snowflake/ml/model/_packager/model_env/model_env.py CHANGED
@@ -174,6 +174,18 @@ class ModelEnv:
         except env_utils.DuplicateDependencyError:
             pass
 
+    def remove_if_present_conda(self, conda_pkgs: List[str]) -> None:
+        """Remove conda requirements from model env if present.
+
+        Args:
+            conda_pkgs: A list of package name to be removed from conda requirements.
+        """
+        for pkg_name in conda_pkgs:
+            spec_conda = env_utils._find_conda_dep_spec(self._conda_dependencies, pkg_name)
+            if spec_conda:
+                channel, spec = spec_conda
+                self._conda_dependencies[channel].remove(spec)
+
     def generate_env_for_cuda(self) -> None:
         if self.cuda_version is None:
             return
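
A simplified sketch of the mutation the helper performs, assuming plain strings where the real code matches parsed requirement objects via env_utils._find_conda_dep_spec: conda requirements are grouped per channel, so removal must locate both the channel and the matching spec:

    from collections import defaultdict
    from typing import DefaultDict, List

    conda_dependencies: DefaultDict[str, List[str]] = defaultdict(list)
    conda_dependencies["conda-forge"].append("pytorch==2.1.0")

    def remove_if_present(pkg_name: str) -> None:
        for channel, specs in conda_dependencies.items():
            for spec in list(specs):
                if spec.split("==")[0] == pkg_name:
                    specs.remove(spec)

    remove_if_present("pytorch")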
snowflake/ml/model/_packager/model_handlers/_utils.py CHANGED
@@ -179,7 +179,7 @@ def convert_explanations_to_2D_df(
         return pd.DataFrame(explanations)
 
     if hasattr(model, "classes_"):
-        classes_list = [str(cl) for cl in model.classes_]  # type:ignore[union-attr]
+        classes_list = [str(cl) for cl in model.classes_]
         len_classes = len(classes_list)
         if explanations.shape[2] != len_classes:
             raise ValueError(f"Model has {len_classes} classes but explanations have {explanations.shape[2]}")
snowflake/ml/model/_packager/model_handlers/catboost.py CHANGED
@@ -95,7 +95,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
             get_prediction_fn=get_prediction,
         )
         model_task_and_output = model_task_utils.get_model_task_and_output_type(model)
-        model_meta.task = model_task_and_output.task
+        model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output.task)
         if enable_explainability:
             explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
             model_meta = handlers_utils.add_explain_method_signature(
snowflake/ml/model/_packager/model_handlers/custom.py CHANGED
@@ -2,7 +2,7 @@ import inspect
 import os
 import pathlib
 import sys
-from typing import Dict, Optional, Type, final
+from typing import Dict, Optional, Type, cast, final
 
 import anyio
 import cloudpickle
@@ -108,6 +108,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
                 model_meta=model_meta,
                 model_blobs_dir_path=model_blobs_dir_path,
                 is_sub_model=True,
+                **cast(model_types.BaseModelSaveOption, kwargs),
             )
 
         # Make sure that the module where the model is defined get pickled by value as well.
@@ -175,6 +176,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
                 name=sub_model_name,
                 model_meta=model_meta,
                 model_blobs_dir_path=model_blobs_dir_path,
+                **cast(model_types.BaseModelLoadOption, kwargs),
             )
             models[sub_model_name] = sub_model
         reconstructed_context = custom_model.ModelContext(artifacts=artifacts, models=models)
snowflake/ml/model/_packager/model_handlers/lightgbm.py CHANGED
@@ -196,13 +196,14 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
         with open(model_blob_file_path, "rb") as f:
             model = cloudpickle.load(f)
         assert isinstance(model, getattr(lightgbm, lightgbm_estimator_type))
+        assert isinstance(model, lightgbm.Booster) or isinstance(model, lightgbm.LGBMModel)
 
         return model
 
     @classmethod
     def convert_as_custom_model(
         cls,
-        raw_model: Union["lightgbm.Booster", "lightgbm.XGBModel"],
+        raw_model: Union["lightgbm.Booster", "lightgbm.LGBMModel"],
         model_meta: model_meta_api.ModelMetadata,
         background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.LGBMModelLoadOptions],