snowflake-ml-python 1.7.1__py3-none-any.whl → 1.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
snowflake/ml/_internal/utils/jwt_generator.py ADDED
@@ -0,0 +1,141 @@
+ import base64
+ import hashlib
+ import logging
+ from datetime import datetime, timedelta, timezone
+ from typing import Optional
+
+ import jwt
+ from cryptography.hazmat.primitives import serialization
+ from cryptography.hazmat.primitives.asymmetric import types
+
+ logger = logging.getLogger(__name__)
+
+ ISSUER = "iss"
+ EXPIRE_TIME = "exp"
+ ISSUE_TIME = "iat"
+ SUBJECT = "sub"
+
+
+ class JWTGenerator:
+     """
+     Creates and signs a JWT with the specified private key file, username, and account identifier. The JWTGenerator
+     keeps the generated token and only regenerates the token if a specified period of time has passed.
+     """
+
+     _DEFAULT_LIFETIME = timedelta(minutes=59)  # The tokens will have a 59-minute lifetime
+     _DEFAULT_RENEWAL_DELTA = timedelta(minutes=54)  # Tokens will be renewed after 54 minutes
+     ALGORITHM = "RS256"  # Tokens will be generated using RSA with SHA256
+
+     def __init__(
+         self,
+         account: str,
+         user: str,
+         private_key: types.PRIVATE_KEY_TYPES,
+         lifetime: Optional[timedelta] = None,
+         renewal_delay: Optional[timedelta] = None,
+     ) -> None:
+         """
+         Create a new JWTGenerator object.
+
+         Args:
+             account: The account identifier.
+             user: The username.
+             private_key: The private key used to sign the JWT.
+             lifetime: The lifetime of the token.
+             renewal_delay: The time before the token expires to renew it.
+         """
+
+         # Construct the fully qualified name of the user in uppercase.
+         self.account = JWTGenerator._prepare_account_name_for_jwt(account)
+         self.user = user.upper()
+         self.qualified_username = self.account + "." + self.user
+         self.private_key = private_key
+         self.public_key_fp = JWTGenerator._calculate_public_key_fingerprint(self.private_key)
+
+         self.issuer = self.qualified_username + "." + self.public_key_fp
+         self.lifetime = lifetime or JWTGenerator._DEFAULT_LIFETIME
+         self.renewal_delay = renewal_delay or JWTGenerator._DEFAULT_RENEWAL_DELTA
+         self.renew_time = datetime.now(timezone.utc)
+         self.token: Optional[str] = None
+
+         logger.info(
+             """Creating JWTGenerator with arguments
+             account : %s, user : %s, lifetime : %s, renewal_delay : %s""",
+             self.account,
+             self.user,
+             self.lifetime,
+             self.renewal_delay,
+         )
+
+     @staticmethod
+     def _prepare_account_name_for_jwt(raw_account: str) -> str:
+         account = raw_account
+         if ".global" not in account:
+             # Handle the general case.
+             idx = account.find(".")
+             if idx > 0:
+                 account = account[0:idx]
+         else:
+             # Handle the replication case.
+             idx = account.find("-")
+             if idx > 0:
+                 account = account[0:idx]
+         # Use uppercase for the account identifier.
+         return account.upper()
+
+     def get_token(self) -> str:
+         now = datetime.now(timezone.utc)  # Fetch the current time
+         if self.token is not None and self.renew_time > now:
+             return self.token
+
+         # If the token has expired or doesn't exist, regenerate the token.
+         logger.info(
+             "Generating a new token because the present time (%s) is later than the renewal time (%s)",
+             now,
+             self.renew_time,
+         )
+         # Calculate the next time we need to renew the token.
+         self.renew_time = now + self.renewal_delay
+
+         # Create our payload
+         payload = {
+             # Set the issuer to the fully qualified username concatenated with the public key fingerprint.
+             ISSUER: self.issuer,
+             # Set the subject to the fully qualified username.
+             SUBJECT: self.qualified_username,
+             # Set the issue time to now.
+             ISSUE_TIME: now,
+             # Set the expiration time, based on the lifetime specified for this object.
+             EXPIRE_TIME: now + self.lifetime,
+         }
+
+         # Regenerate the actual token
+         token = jwt.encode(payload, key=self.private_key, algorithm=JWTGenerator.ALGORITHM)
+         # If you are using a version of PyJWT prior to 2.0, jwt.encode returns a byte string instead of a string.
+         # If the token is a byte string, convert it to a string.
+         if isinstance(token, bytes):
+             token = token.decode("utf-8")
+         self.token = token
+         logger.info(
+             "Generated a JWT with the following payload: %s",
+             jwt.decode(self.token, key=self.private_key.public_key(), algorithms=[JWTGenerator.ALGORITHM]),
+         )
+
+         return token
+
+     @staticmethod
+     def _calculate_public_key_fingerprint(private_key: types.PRIVATE_KEY_TYPES) -> str:
+         # Get the raw bytes of public key.
+         public_key_raw = private_key.public_key().public_bytes(
+             serialization.Encoding.DER, serialization.PublicFormat.SubjectPublicKeyInfo
+         )
+
+         # Get the sha256 hash of the raw bytes.
+         sha256hash = hashlib.sha256()
+         sha256hash.update(public_key_raw)
+
+         # Base64-encode the value and prepend the prefix 'SHA256:'.
+         public_key_fp = "SHA256:" + base64.b64encode(sha256hash.digest()).decode("utf-8")
+         logger.info("Public key fingerprint is %s", public_key_fp)
+
+         return public_key_fp
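For orientation, here is a minimal usage sketch of the new `JWTGenerator` helper (not part of the diff); the key path, account identifier, and username below are hypothetical placeholders:

```python
# Hedged sketch: load a PKCS#8 RSA private key and mint a cached JWT.
from datetime import timedelta

from cryptography.hazmat.primitives import serialization

from snowflake.ml._internal.utils import jwt_generator

with open("rsa_key.p8", "rb") as f:  # hypothetical path to a private key file
    private_key = serialization.load_pem_private_key(f.read(), password=None)

generator = jwt_generator.JWTGenerator(
    account="myorg-myaccount",  # hypothetical account identifier
    user="ml_user",             # hypothetical username
    private_key=private_key,
    lifetime=timedelta(minutes=59),
    renewal_delay=timedelta(minutes=54),
)
token = generator.get_token()  # cached until the renewal delay elapses
```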
snowflake/ml/model/_client/model/model_version_impl.py CHANGED
@@ -14,7 +14,7 @@ from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
  from snowflake.ml.model._model_composer import model_composer
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
  from snowflake.ml.model._packager.model_handlers import snowmlmodel
- from snowflake.snowpark import Session, dataframe
+ from snowflake.snowpark import Session, async_job, dataframe
 
  _TELEMETRY_PROJECT = "MLOps"
  _TELEMETRY_SUBPROJECT = "ModelManagement"
@@ -631,7 +631,8 @@ class ModelVersion(lineage_node.LineageNode):
          max_batch_rows: Optional[int] = None,
          force_rebuild: bool = False,
          build_external_access_integration: Optional[str] = None,
-     ) -> str:
+         block: bool = True,
+     ) -> Union[str, async_job.AsyncJob]:
          """Create an inference service with the given spec.
 
          Args:
@@ -659,6 +660,9 @@ class ModelVersion(lineage_node.LineageNode):
              force_rebuild: Whether to force a model inference image rebuild.
              build_external_access_integration: (Deprecated) The external access integration for image build. This is
                  usually permitting access to conda & PyPI repositories.
+             block: A bool value indicating whether this function will wait until the service is available.
+                 When it is ``False``, this function executes the underlying service creation asynchronously
+                 and returns an :class:`AsyncJob`.
          """
          ...
 
@@ -679,7 +683,8 @@ class ModelVersion(lineage_node.LineageNode):
          max_batch_rows: Optional[int] = None,
          force_rebuild: bool = False,
          build_external_access_integrations: Optional[List[str]] = None,
-     ) -> str:
+         block: bool = True,
+     ) -> Union[str, async_job.AsyncJob]:
          """Create an inference service with the given spec.
 
          Args:
@@ -707,6 +712,9 @@ class ModelVersion(lineage_node.LineageNode):
              force_rebuild: Whether to force a model inference image rebuild.
              build_external_access_integrations: The external access integrations for image build. This is usually
                  permitting access to conda & PyPI repositories.
+             block: A bool value indicating whether this function will wait until the service is available.
+                 When it is ``False``, this function executes the underlying service creation asynchronously
+                 and returns an :class:`AsyncJob`.
          """
          ...
 
@@ -742,7 +750,8 @@ class ModelVersion(lineage_node.LineageNode):
          force_rebuild: bool = False,
          build_external_access_integration: Optional[str] = None,
          build_external_access_integrations: Optional[List[str]] = None,
-     ) -> str:
+         block: bool = True,
+     ) -> Union[str, async_job.AsyncJob]:
          """Create an inference service with the given spec.
 
          Args:
@@ -772,12 +781,16 @@ class ModelVersion(lineage_node.LineageNode):
                  usually permitting access to conda & PyPI repositories.
              build_external_access_integrations: The external access integrations for image build. This is usually
                  permitting access to conda & PyPI repositories.
+             block: A bool value indicating whether this function will wait until the service is available.
+                 When it is False, this function executes the underlying service creation asynchronously
+                 and returns an AsyncJob.
 
          Raises:
              ValueError: Illegal external access integration arguments.
 
          Returns:
-             Result information about service creation from server.
+             If `block=True`, return result information about service creation from server.
+             Otherwise, return the service creation AsyncJob.
          """
          statement_params = telemetry.get_statement_params(
              project=_TELEMETRY_PROJECT,
@@ -829,6 +842,7 @@ class ModelVersion(lineage_node.LineageNode):
                  if build_external_access_integrations is None
                  else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
              ),
+             block=block,
              statement_params=statement_params,
          )
 
snowflake/ml/model/_client/ops/service_ops.py CHANGED
@@ -6,7 +6,7 @@ import re
  import tempfile
  import threading
  import time
- from typing import Any, Dict, List, Optional, Tuple, cast
+ from typing import Any, Dict, List, Optional, Tuple, Union, cast
 
  from packaging import version
 
@@ -15,7 +15,7 @@ from snowflake.ml._internal import file_utils
  from snowflake.ml._internal.utils import service_logger, snowflake_env, sql_identifier
  from snowflake.ml.model._client.service import model_deployment_spec
  from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
- from snowflake.snowpark import exceptions, row, session
+ from snowflake.snowpark import async_job, exceptions, row, session
  from snowflake.snowpark._internal import utils as snowpark_utils
 
  module_logger = service_logger.get_logger(__name__, service_logger.LogColor.GREY)
@@ -107,8 +107,9 @@ class ServiceOperator:
          max_batch_rows: Optional[int],
          force_rebuild: bool,
          build_external_access_integrations: Optional[List[sql_identifier.SqlIdentifier]],
+         block: bool,
          statement_params: Optional[Dict[str, Any]] = None,
-     ) -> str:
+     ) -> Union[str, async_job.AsyncJob]:
 
          # Fall back to the registry's database and schema if not provided
          database_name = database_name or self._database_name
@@ -204,11 +205,15 @@ class ServiceOperator:
          log_thread = self._start_service_log_streaming(
              async_job, services, model_inference_service_exists, force_rebuild, statement_params
          )
-         log_thread.join()
 
-         res = cast(str, cast(List[row.Row], async_job.result())[0][0])
-         module_logger.info(f"Inference service {service_name} deployment complete: {res}")
-         return res
+         if block:
+             log_thread.join()
+
+             res = cast(str, cast(List[row.Row], async_job.result())[0][0])
+             module_logger.info(f"Inference service {service_name} deployment complete: {res}")
+             return res
+         else:
+             return async_job
 
      def _start_service_log_streaming(
          self,
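Taken together with the `ModelVersion.create_service` changes above, the new non-blocking path can be exercised roughly as follows. This is a hedged sketch, not package code: `registry` is assumed to be an existing `snowflake.ml.registry.Registry`, the model, service, compute pool, and image repo names are hypothetical, and parameters other than `block` are abbreviated.

```python
# Hedged sketch: asynchronous inference service creation via block=False.
mv = registry.get_model("MY_MODEL").version("V1")

job = mv.create_service(
    service_name="MY_INFERENCE_SERVICE",
    service_compute_pool="MY_COMPUTE_POOL",
    image_repo="MY_IMAGE_REPO",
    block=False,  # return immediately with a snowpark AsyncJob
)

# ... do other work while the image builds and the service starts ...

rows = job.result()  # blocks here until the underlying query finishes
```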
snowflake/ml/model/_client/sql/stage.py CHANGED
@@ -15,6 +15,6 @@ class StageSQLClient(_base._BaseSQLClient):
      ) -> None:
          query_result_checker.SqlResultValidator(
              self._session,
-             f"CREATE TEMPORARY STAGE {self.fully_qualified_object_name(database_name, schema_name, stage_name)}",
+             f"CREATE SCOPED TEMPORARY STAGE {self.fully_qualified_object_name(database_name, schema_name, stage_name)}",
              statement_params=statement_params,
          ).has_dimensions(expected_rows=1, expected_cols=1).validate()
snowflake/ml/model/_packager/model_handlers/sklearn.py CHANGED
@@ -164,6 +164,8 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
                  stacklevel=1,
              )
              enable_explainability = False
+         elif model_meta.task == model_types.Task.UNKNOWN:
+             enable_explainability = False
          else:
              enable_explainability = True
          if enable_explainability:
snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py CHANGED
@@ -1,2 +1,2 @@
- REQUIREMENTS = ['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<4', 'cachetools>=3.1.1,<6', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2022.11,<2024', 'importlib_resources>=6.1.1, <7', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pytimeparse>=1.1.8,<2', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2022.11,<2024', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2', 'sqlparse>=0.4,<1', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
- ALL_REQUIREMENTS=['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<4', 'cachetools>=3.1.1,<6', 'catboost>=1.2.0, <2', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2022.11,<2024', 'importlib_resources>=6.1.1, <7', 'lightgbm>=4.1.0, <5', 'mlflow>=2.1.0,<2.4', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pytimeparse>=1.1.8,<2', 'pytorch>=2.0.1,<2.3.0', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2022.11,<2024', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'sentence-transformers>=2.2.2,<3', 'sentencepiece>=0.1.95,<1', 'shap>=0.46.0,<1', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2', 'sqlparse>=0.4,<1', 'tensorflow>=2.10,<3', 'tokenizers>=0.10,<1', 'torchdata>=0.4,<1', 'transformers>=4.32.1,<5', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
+ REQUIREMENTS = ['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<4', 'cachetools>=3.1.1,<6', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2022.11,<2024', 'importlib_resources>=6.1.1, <7', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2022.11,<2024', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2', 'sqlparse>=0.4,<1', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
+ ALL_REQUIREMENTS=['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<4', 'cachetools>=3.1.1,<6', 'catboost>=1.2.0, <2', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2022.11,<2024', 'importlib_resources>=6.1.1, <7', 'lightgbm>=4.1.0, <5', 'mlflow>=2.1.0,<2.4', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pytorch>=2.0.1,<2.3.0', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2022.11,<2024', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'sentence-transformers>=2.2.2,<3', 'sentencepiece>=0.1.95,<1', 'shap>=0.46.0,<1', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2', 'sqlparse>=0.4,<1', 'tensorflow>=2.10,<3', 'tokenizers>=0.10,<1', 'torchdata>=0.4,<1', 'transformers>=4.32.1,<5', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
snowflake/ml/model/_signatures/utils.py CHANGED
@@ -118,7 +118,6 @@ def huggingface_pipeline_signature_auto_infer(task: str, params: Dict[str, Any])
          category=DeprecationWarning,
          stacklevel=1,
      )
-
      return core.ModelSignature(
          inputs=[
              core.FeatureSpec(name="user_inputs", dtype=core.DataType.STRING, shape=(-1,)),
snowflake/ml/monitoring/_client/model_monitor_sql_client.py CHANGED
@@ -1,6 +1,4 @@
- import typing
- from collections import Counter
- from typing import Any, Dict, List, Mapping, Optional, Set
+ from typing import Any, Dict, List, Mapping, Optional
 
  from snowflake import snowpark
  from snowflake.ml._internal.utils import (
@@ -10,27 +8,12 @@ from snowflake.ml._internal.utils import (
      table_manager,
  )
  from snowflake.ml.model._client.sql import _base
- from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
  from snowflake.snowpark import session, types
 
- SNOWML_MONITORING_METADATA_TABLE_NAME = "_SYSTEM_MONITORING_METADATA"
-
  MODEL_JSON_COL_NAME = "model"
  MODEL_JSON_MODEL_NAME_FIELD = "model_name"
  MODEL_JSON_VERSION_NAME_FIELD = "version_name"
 
- MONITOR_NAME_COL_NAME = "MONITOR_NAME"
- SOURCE_TABLE_NAME_COL_NAME = "SOURCE_TABLE_NAME"
- FQ_MODEL_NAME_COL_NAME = "FULLY_QUALIFIED_MODEL_NAME"
- VERSION_NAME_COL_NAME = "MODEL_VERSION_NAME"
- FUNCTION_NAME_COL_NAME = "FUNCTION_NAME"
- TASK_COL_NAME = "TASK"
- MONITORING_ENABLED_COL_NAME = "IS_ENABLED"
- TIMESTAMP_COL_NAME_COL_NAME = "TIMESTAMP_COLUMN_NAME"
- PREDICTION_COL_NAMES_COL_NAME = "PREDICTION_COLUMN_NAMES"
- LABEL_COL_NAMES_COL_NAME = "LABEL_COLUMN_NAMES"
- ID_COL_NAMES_COL_NAME = "ID_COLUMN_NAMES"
-
 
  def _build_sql_list_from_columns(columns: List[sql_identifier.SqlIdentifier]) -> str:
      sql_list = ", ".join([f"'{column}'" for column in columns])
@@ -146,19 +129,6 @@ class ModelMonitorSQLClient:
              .validate()
          )
 
-     def _validate_unique_columns(
-         self,
-         timestamp_column: sql_identifier.SqlIdentifier,
-         id_columns: List[sql_identifier.SqlIdentifier],
-         prediction_columns: List[sql_identifier.SqlIdentifier],
-         label_columns: List[sql_identifier.SqlIdentifier],
-     ) -> None:
-         all_columns = [*id_columns, *prediction_columns, *label_columns, timestamp_column]
-         num_all_columns = len(all_columns)
-         num_unique_columns = len(set(all_columns))
-         if num_all_columns != num_unique_columns:
-             raise ValueError("Column names must be unique across id, timestamp, prediction, and label columns.")
-
      def validate_existence_by_name(
          self,
         *,
@@ -244,125 +214,6 @@ class ModelMonitorSQLClient:
          if not all([column_name in source_column_schema for column_name in id_columns]):
              raise ValueError(f"ID column(s): {id_columns} do not exist in source.")
 
-     def _validate_timestamp_column_type(
-         self, table_schema: Mapping[str, types.DataType], timestamp_column: sql_identifier.SqlIdentifier
-     ) -> None:
-         """Ensures columns have the same type.
-
-         Args:
-             table_schema: Dictionary of column names and types in the source table.
-             timestamp_column: Name of the timestamp column.
-
-         Raises:
-             ValueError: If the timestamp column is not of type TimestampType.
-         """
-         if not isinstance(table_schema[timestamp_column], types.TimestampType):
-             raise ValueError(
-                 f"Timestamp column: {timestamp_column} must be TimestampType. "
-                 f"Found: {table_schema[timestamp_column]}"
-             )
-
-     def _validate_id_columns_types(
-         self, table_schema: Mapping[str, types.DataType], id_columns: List[sql_identifier.SqlIdentifier]
-     ) -> None:
-         """Ensures id columns have the correct type.
-
-         Args:
-             table_schema: Dictionary of column names and types in the source table.
-             id_columns: List of id column names.
-
-         Raises:
-             ValueError: If the id column is not of type StringType.
-         """
-         id_column_types = list({table_schema[column_name] for column_name in id_columns})
-         all_id_columns_string = all([isinstance(column_type, types.StringType) for column_type in id_column_types])
-         if not all_id_columns_string:
-             raise ValueError(f"Id columns must all be StringType. Found: {id_column_types}")
-
-     def _validate_prediction_columns_types(
-         self, table_schema: Mapping[str, types.DataType], prediction_columns: List[sql_identifier.SqlIdentifier]
-     ) -> None:
-         """Ensures prediction columns have the same type.
-
-         Args:
-             table_schema: Dictionary of column names and types in the source table.
-             prediction_columns: List of prediction column names.
-
-         Raises:
-             ValueError: If the prediction columns do not share the same type.
-         """
-
-         prediction_column_types = {table_schema[column_name] for column_name in prediction_columns}
-         if len(prediction_column_types) > 1:
-             raise ValueError(f"Prediction column types must be the same. Found: {prediction_column_types}")
-
-     def _validate_label_columns_types(
-         self,
-         table_schema: Mapping[str, types.DataType],
-         label_columns: List[sql_identifier.SqlIdentifier],
-     ) -> None:
-         """Ensures label columns have the same type, and the correct type for the score type.
-
-         Args:
-             table_schema: Dictionary of column names and types in the source table.
-             label_columns: List of label column names.
-
-         Raises:
-             ValueError: If the label columns do not share the same type.
-         """
-         label_column_types = {table_schema[column_name] for column_name in label_columns}
-         if len(label_column_types) > 1:
-             raise ValueError(f"Label column types must be the same. Found: {label_column_types}")
-
-     def _validate_column_types(
-         self,
-         *,
-         table_schema: Mapping[str, types.DataType],
-         timestamp_column: sql_identifier.SqlIdentifier,
-         id_columns: List[sql_identifier.SqlIdentifier],
-         prediction_columns: List[sql_identifier.SqlIdentifier],
-         label_columns: List[sql_identifier.SqlIdentifier],
-     ) -> None:
-         """Ensures columns have the expected type.
-
-         Args:
-             table_schema: Dictionary of column names and types in the source table.
-             timestamp_column: Name of the timestamp column.
-             id_columns: List of id column names.
-             prediction_columns: List of prediction column names.
-             label_columns: List of label column names.
-         """
-         self._validate_timestamp_column_type(table_schema, timestamp_column)
-         self._validate_id_columns_types(table_schema, id_columns)
-         self._validate_prediction_columns_types(table_schema, prediction_columns)
-         self._validate_label_columns_types(table_schema, label_columns)
-         # TODO(SNOW-1646693): Validate label makes sense with model task
-
-     def _validate_source_table_features_shape(
-         self,
-         table_schema: Mapping[str, types.DataType],
-         special_columns: Set[sql_identifier.SqlIdentifier],
-         model_function: model_manifest_schema.ModelFunctionInfo,
-     ) -> None:
-         table_schema_without_special_columns = {
-             k: v for k, v in table_schema.items() if sql_identifier.SqlIdentifier(k) not in special_columns
-         }
-         schema_column_types_to_count: typing.Counter[types.DataType] = Counter()
-         for column_type in table_schema_without_special_columns.values():
-             schema_column_types_to_count[column_type] += 1
-
-         inputs = model_function["signature"].inputs
-         function_input_types = [input.as_snowpark_type() for input in inputs]
-         function_input_types_to_count: typing.Counter[types.DataType] = Counter()
-         for function_input_type in function_input_types:
-             function_input_types_to_count[function_input_type] += 1
-
-         if function_input_types_to_count != schema_column_types_to_count:
-             raise ValueError(
-                 "Model function input types do not match the source table input columns types. "
-                 f"Model function expected: {inputs} but got {table_schema_without_special_columns}"
-             )
-
      def validate_source(
          self,
         *,
@@ -395,22 +246,6 @@ class ModelMonitorSQLClient:
              id_columns=id_columns,
          )
 
-     def delete_monitor_metadata(
-         self,
-         name: str,
-         statement_params: Optional[Dict[str, Any]] = None,
-     ) -> None:
-         """Delete the row in the metadata table corresponding to the given monitor name.
-
-         Args:
-             name: Name of the model monitor whose metadata should be deleted.
-             statement_params: Optional set of statement_params to include with query.
-         """
-         self._sql_client._session.sql(
-             f"""DELETE FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}
-             WHERE {MONITOR_NAME_COL_NAME} = '{name}'""",
-         ).collect(statement_params=statement_params)
-
      def _alter_monitor(
          self,
          operation: str,
snowflake/ml/monitoring/_manager/model_monitor_manager.py CHANGED
@@ -14,15 +14,6 @@ from snowflake.snowpark import session
  class ModelMonitorManager:
      """Class to manage internal operations for Model Monitor workflows."""
 
-     def _validate_task_from_model_version(
-         self,
-         model_version: model_version_impl.ModelVersion,
-     ) -> type_hints.Task:
-         task = model_version.get_model_task()
-         if task == type_hints.Task.UNKNOWN:
-             raise ValueError("Registry model must be logged with task in order to be monitored.")
-         return task
-
      def __init__(
          self,
          session: session.Session,
@@ -51,6 +42,15 @@ class ModelMonitorManager:
              schema_name=self._schema_name,
          )
 
+     def _validate_task_from_model_version(
+         self,
+         model_version: model_version_impl.ModelVersion,
+     ) -> type_hints.Task:
+         task = model_version.get_model_task()
+         if task == type_hints.Task.UNKNOWN:
+             raise ValueError("Registry model must be logged with task in order to be monitored.")
+         return task
+
      def _validate_model_function_from_model_version(
          self, function: str, model_version: model_version_impl.ModelVersion
      ) -> None:
snowflake/ml/monitoring/entities/model_monitor_config.py CHANGED
@@ -6,23 +6,49 @@ from snowflake.ml.model._client.model import model_version_impl
 
  @dataclass
  class ModelMonitorSourceConfig:
+     """Configuration for the source of data to be monitored."""
+
      source: str
+     """Name of table or view containing monitoring data."""
+
      timestamp_column: str
+     """Name of column in the source containing timestamp."""
+
      id_columns: List[str]
+     """List of columns in the source containing unique identifiers."""
+
      prediction_score_columns: Optional[List[str]] = None
+     """List of columns in the source containing prediction scores.
+     Can be regression scores for regression models and probability scores for classification models."""
+
      prediction_class_columns: Optional[List[str]] = None
+     """List of columns in the source containing prediction classes for classification models."""
+
      actual_score_columns: Optional[List[str]] = None
+     """List of columns in the source containing actual scores."""
+
      actual_class_columns: Optional[List[str]] = None
+     """List of columns in the source containing actual classes for classification models."""
+
      baseline: Optional[str] = None
+     """Name of table containing the baseline data."""
 
 
  @dataclass
  class ModelMonitorConfig:
+     """Configuration for the Model Monitor."""
+
      model_version: model_version_impl.ModelVersion
+     """Model version to monitor."""
 
-     # Python model function name
      model_function_name: str
+     """Function name in the model to monitor."""
+
      background_compute_warehouse_name: str
-     # TODO: Add support for pythonic notion of time.
+     """Name of the warehouse to use for background compute."""
+
      refresh_interval: str = "1 hour"
+     """Interval at which to refresh the monitoring data."""
+
      aggregation_window: str = "1 day"
+     """Window for aggregating monitoring data."""
snowflake/ml/monitoring/model_monitor.py CHANGED
@@ -1,5 +1,7 @@
+ from snowflake import snowpark
  from snowflake.ml._internal import telemetry
  from snowflake.ml._internal.utils import sql_identifier
+ from snowflake.ml.monitoring import model_monitor_version
  from snowflake.ml.monitoring._client import model_monitor_sql_client
 
 
@@ -9,13 +11,8 @@ class ModelMonitor:
      name: sql_identifier.SqlIdentifier
      _model_monitor_client: model_monitor_sql_client.ModelMonitorSQLClient
 
-     statement_params = telemetry.get_statement_params(
-         telemetry.TelemetryProject.MLOPS.value,
-         telemetry.TelemetrySubProject.MONITORING.value,
-     )
-
      def __init__(self) -> None:
-         raise RuntimeError("ModelMonitor's initializer is not meant to be used.")
+         raise RuntimeError("Model Monitor's initializer is not meant to be used.")
 
      @classmethod
      def _ref(
@@ -28,10 +25,28 @@ class ModelMonitor:
          self._model_monitor_client = model_monitor_client
          return self
 
+     @telemetry.send_api_usage_telemetry(
+         project=telemetry.TelemetryProject.MLOPS.value,
+         subproject=telemetry.TelemetrySubProject.MONITORING.value,
+     )
+     @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
      def suspend(self) -> None:
-         """Suspend pipeline for ModelMonitor"""
-         self._model_monitor_client.suspend_monitor(self.name, statement_params=self.statement_params)
-
+         """Suspend the Model Monitor"""
+         statement_params = telemetry.get_statement_params(
+             telemetry.TelemetryProject.MLOPS.value,
+             telemetry.TelemetrySubProject.MONITORING.value,
+         )
+         self._model_monitor_client.suspend_monitor(self.name, statement_params=statement_params)
+
+     @telemetry.send_api_usage_telemetry(
+         project=telemetry.TelemetryProject.MLOPS.value,
+         subproject=telemetry.TelemetrySubProject.MONITORING.value,
+     )
+     @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
      def resume(self) -> None:
-         """Resume pipeline for ModelMonitor"""
-         self._model_monitor_client.resume_monitor(self.name, statement_params=self.statement_params)
+         """Resume the Model Monitor"""
+         statement_params = telemetry.get_statement_params(
+             telemetry.TelemetryProject.MLOPS.value,
+             telemetry.TelemetrySubProject.MONITORING.value,
+         )
+         self._model_monitor_client.resume_monitor(self.name, statement_params=statement_params)
snowflake/ml/registry/registry.py CHANGED
@@ -388,15 +388,15 @@ class Registry:
          source_config: model_monitor_config.ModelMonitorSourceConfig,
          model_monitor_config: model_monitor_config.ModelMonitorConfig,
      ) -> model_monitor.ModelMonitor:
-         """Add a Model Monitor to the Registry
+         """Add a Model Monitor to the Registry.
 
          Args:
-             name: Name of Model Monitor to create
-             source_config: Configuration options of table for ModelMonitor.
-             model_monitor_config: Configuration options of ModelMonitor.
+             name: Name of Model Monitor to create.
+             source_config: Configuration options of table for Model Monitor.
+             model_monitor_config: Configuration options of Model Monitor.
 
          Returns:
-             The newly added ModelMonitor object.
+             The newly added Model Monitor object.
 
          Raises:
              ValueError: If monitoring is not enabled in the Registry.
@@ -407,16 +407,16 @@ class Registry:
 
      @overload
      def get_monitor(self, model_version: model_version_impl.ModelVersion) -> model_monitor.ModelMonitor:
-         """Get a Model Monitor on a ModelVersion from the Registry
+         """Get a Model Monitor on a Model Version from the Registry.
 
          Args:
-             model_version: ModelVersion for which to retrieve the ModelMonitor.
+             model_version: Model Version for which to retrieve the Model Monitor.
          """
          ...
 
      @overload
      def get_monitor(self, name: str) -> model_monitor.ModelMonitor:
-         """Get a Model Monitor from the Registry
+         """Get a Model Monitor by name from the Registry.
 
          Args:
              name: Name of Model Monitor to retrieve.
@@ -431,14 +431,14 @@ class Registry:
      def get_monitor(
          self, *, name: Optional[str] = None, model_version: Optional[model_version_impl.ModelVersion] = None
      ) -> model_monitor.ModelMonitor:
-         """Get a Model Monitor from the Registry
+         """Get a Model Monitor from the Registry.
 
          Args:
              name: Name of Model Monitor to retrieve.
-             model_version: ModelVersion for which to retrieve the ModelMonitor.
+             model_version: Model Version for which to retrieve the Model Monitor.
 
          Returns:
-             The fetched ModelMonitor.
+             The fetched Model Monitor.
 
          Raises:
              ValueError: If monitoring is not enabled in the Registry.
@@ -476,7 +476,7 @@ class Registry:
      )
      @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
      def delete_monitor(self, name: str) -> None:
-         """Delete a Model Monitor from the Registry
+         """Delete a Model Monitor by name from the Registry.
 
          Args:
              name: Name of the Model Monitor to delete.
snowflake/ml/utils/authentication.py ADDED
@@ -0,0 +1,75 @@
+ import http
+ import logging
+ from datetime import timedelta
+ from typing import Dict, Optional
+
+ import requests
+ from cryptography.hazmat.primitives.asymmetric import types
+ from requests import auth
+
+ from snowflake.ml._internal.utils import jwt_generator
+
+ logger = logging.getLogger(__name__)
+ _JWT_TOKEN_CACHE: Dict[str, Dict[int, str]] = {}
+
+
+ def get_jwt_token_generator(
+     account: str,
+     user: str,
+     private_key: types.PRIVATE_KEY_TYPES,
+     lifetime: Optional[timedelta] = None,
+     renewal_delay: Optional[timedelta] = None,
+ ) -> jwt_generator.JWTGenerator:
+     return jwt_generator.JWTGenerator(account, user, private_key, lifetime=lifetime, renewal_delay=renewal_delay)
+
+
+ def _get_snowflake_token_by_jwt(
+     jwt_token_generator: jwt_generator.JWTGenerator,
+     account: Optional[str] = None,
+     role: Optional[str] = None,
+     endpoint: Optional[str] = None,
+     snowflake_account_url: Optional[str] = None,
+ ) -> str:
+     scope_role = f"session:role:{role}" if role is not None else None
+     scope = " ".join(filter(None, [scope_role, endpoint]))
+     data = {
+         "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
+         "scope": scope or None,
+         "assertion": jwt_token_generator.get_token(),
+     }
+     account = account or jwt_token_generator.account
+     url = f"https://{account}.snowflakecomputing.com/oauth/token"
+     if snowflake_account_url:
+         url = f"{snowflake_account_url}/oauth/token"
+
+     cache_key = hash(frozenset(data.items()))
+     if url in _JWT_TOKEN_CACHE:
+         if cache_key in _JWT_TOKEN_CACHE[url]:
+             return _JWT_TOKEN_CACHE[url][cache_key]
+     else:
+         _JWT_TOKEN_CACHE[url] = {}
+
+     response = requests.post(url, data=data)
+     if response.status_code != http.HTTPStatus.OK:
+         raise RuntimeError(f"Failed to get snowflake token: {response.status_code} {response.content!r}")
+     auth_token = response.text
+     _JWT_TOKEN_CACHE[url][cache_key] = auth_token
+     return auth_token
+
+
+ class SnowflakeJWTTokenAuth(auth.AuthBase):
+     def __init__(
+         self,
+         jwt_token_generator: jwt_generator.JWTGenerator,
+         account: Optional[str] = None,
+         role: Optional[str] = None,
+         endpoint: Optional[str] = None,
+         snowflake_account_url: Optional[str] = None,
+     ) -> None:
+         self.snowflake_token = _get_snowflake_token_by_jwt(
+             jwt_token_generator, account, role, endpoint, snowflake_account_url
+         )
+
+     def __call__(self, r: requests.PreparedRequest) -> requests.PreparedRequest:
+         r.headers["Authorization"] = f'Snowflake Token="{self.snowflake_token}"'
+         return r
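A hedged sketch of driving this new auth hook from `requests` (not part of the diff); the key path, account, user, role, and endpoint URL are all hypothetical:

```python
# Hedged sketch: exchange a keypair JWT for a Snowflake token and call an
# inference endpoint. All names and URLs below are placeholders.
import requests
from cryptography.hazmat.primitives import serialization

from snowflake.ml.utils import authentication

with open("rsa_key.p8", "rb") as f:  # hypothetical private key path
    private_key = serialization.load_pem_private_key(f.read(), password=None)

jwt_gen = authentication.get_jwt_token_generator("myorg-myaccount", "ml_user", private_key)

resp = requests.post(
    "https://myservice.example.snowflakecomputing.app/predict",  # hypothetical endpoint
    json={"data": [[0, 1.0, 2.0]]},
    auth=authentication.SnowflakeJWTTokenAuth(
        jwt_gen,
        role="ML_ROLE",
        endpoint="myservice.example.snowflakecomputing.app",
    ),
)
```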
snowflake/ml/version.py CHANGED
@@ -1 +1 @@
- VERSION="1.7.1"
+ VERSION="1.7.2"
snowflake_ml_python-1.7.2.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: snowflake-ml-python
- Version: 1.7.1
+ Version: 1.7.2
  Summary: The machine learning client library that is used for interacting with Snowflake to build machine learning solutions.
  Author-email: "Snowflake, Inc" <support@snowflake.com>
  License:
@@ -232,61 +232,62 @@ Classifier: Topic :: Scientific/Engineering :: Information Analysis
  Requires-Python: <3.12,>=3.9
  Description-Content-Type: text/markdown
  License-File: LICENSE.txt
- Requires-Dist: absl-py <2,>=0.15
- Requires-Dist: anyio <4,>=3.5.0
- Requires-Dist: cachetools <6,>=3.1.1
- Requires-Dist: cloudpickle >=2.0.0
+ Requires-Dist: absl-py<2,>=0.15
+ Requires-Dist: anyio<4,>=3.5.0
+ Requires-Dist: cachetools<6,>=3.1.1
+ Requires-Dist: cloudpickle>=2.0.0
  Requires-Dist: cryptography
- Requires-Dist: fsspec[http] <2024,>=2022.11
- Requires-Dist: importlib-resources <7,>=6.1.1
- Requires-Dist: numpy <2,>=1.23
- Requires-Dist: packaging <25,>=20.9
- Requires-Dist: pandas <3,>=1.0.0
+ Requires-Dist: fsspec[http]<2024,>=2022.11
+ Requires-Dist: importlib_resources<7,>=6.1.1
+ Requires-Dist: numpy<2,>=1.23
+ Requires-Dist: packaging<25,>=20.9
+ Requires-Dist: pandas<3,>=1.0.0
  Requires-Dist: pyarrow
- Requires-Dist: pytimeparse <2,>=1.1.8
- Requires-Dist: pyyaml <7,>=6.0
- Requires-Dist: retrying <2,>=1.3.3
- Requires-Dist: s3fs <2024,>=2022.11
- Requires-Dist: scikit-learn <1.6,>=1.4
- Requires-Dist: scipy <2,>=1.9
- Requires-Dist: snowflake-connector-python[pandas] <4,>=3.5.0
- Requires-Dist: snowflake-snowpark-python <2,>=1.17.0
- Requires-Dist: sqlparse <1,>=0.4
- Requires-Dist: typing-extensions <5,>=4.1.0
- Requires-Dist: xgboost <3,>=1.7.3
+ Requires-Dist: pyjwt<3,>=2.0.0
+ Requires-Dist: pytimeparse<2,>=1.1.8
+ Requires-Dist: pyyaml<7,>=6.0
+ Requires-Dist: retrying<2,>=1.3.3
+ Requires-Dist: s3fs<2024,>=2022.11
+ Requires-Dist: scikit-learn<1.6,>=1.4
+ Requires-Dist: scipy<2,>=1.9
+ Requires-Dist: snowflake-connector-python[pandas]<4,>=3.5.0
+ Requires-Dist: snowflake-snowpark-python<2,>=1.17.0
+ Requires-Dist: sqlparse<1,>=0.4
+ Requires-Dist: typing-extensions<5,>=4.1.0
+ Requires-Dist: xgboost<3,>=1.7.3
  Provides-Extra: all
- Requires-Dist: catboost <2,>=1.2.0 ; extra == 'all'
- Requires-Dist: lightgbm <5,>=4.1.0 ; extra == 'all'
- Requires-Dist: mlflow <2.4,>=2.1.0 ; extra == 'all'
- Requires-Dist: peft <1,>=0.5.0 ; extra == 'all'
- Requires-Dist: sentence-transformers <3,>=2.2.2 ; extra == 'all'
- Requires-Dist: sentencepiece <1,>=0.1.95 ; extra == 'all'
- Requires-Dist: shap <1,>=0.46.0 ; extra == 'all'
- Requires-Dist: tensorflow <3,>=2.10 ; extra == 'all'
- Requires-Dist: tokenizers <1,>=0.10 ; extra == 'all'
- Requires-Dist: torch <2.3.0,>=2.0.1 ; extra == 'all'
- Requires-Dist: torchdata <1,>=0.4 ; extra == 'all'
- Requires-Dist: transformers <5,>=4.32.1 ; extra == 'all'
+ Requires-Dist: catboost<2,>=1.2.0; extra == "all"
+ Requires-Dist: lightgbm<5,>=4.1.0; extra == "all"
+ Requires-Dist: mlflow<2.4,>=2.1.0; extra == "all"
+ Requires-Dist: peft<1,>=0.5.0; extra == "all"
+ Requires-Dist: sentence-transformers<3,>=2.2.2; extra == "all"
+ Requires-Dist: sentencepiece<1,>=0.1.95; extra == "all"
+ Requires-Dist: shap<1,>=0.46.0; extra == "all"
+ Requires-Dist: tensorflow<3,>=2.10; extra == "all"
+ Requires-Dist: tokenizers<1,>=0.10; extra == "all"
+ Requires-Dist: torch<2.3.0,>=2.0.1; extra == "all"
+ Requires-Dist: torchdata<1,>=0.4; extra == "all"
+ Requires-Dist: transformers<5,>=4.32.1; extra == "all"
  Provides-Extra: catboost
- Requires-Dist: catboost <2,>=1.2.0 ; extra == 'catboost'
+ Requires-Dist: catboost<2,>=1.2.0; extra == "catboost"
  Provides-Extra: lightgbm
- Requires-Dist: lightgbm <5,>=4.1.0 ; extra == 'lightgbm'
+ Requires-Dist: lightgbm<5,>=4.1.0; extra == "lightgbm"
  Provides-Extra: llm
- Requires-Dist: peft <1,>=0.5.0 ; extra == 'llm'
+ Requires-Dist: peft<1,>=0.5.0; extra == "llm"
  Provides-Extra: mlflow
- Requires-Dist: mlflow <2.4,>=2.1.0 ; extra == 'mlflow'
+ Requires-Dist: mlflow<2.4,>=2.1.0; extra == "mlflow"
  Provides-Extra: shap
- Requires-Dist: shap <1,>=0.46.0 ; extra == 'shap'
+ Requires-Dist: shap<1,>=0.46.0; extra == "shap"
  Provides-Extra: tensorflow
- Requires-Dist: tensorflow <3,>=2.10 ; extra == 'tensorflow'
+ Requires-Dist: tensorflow<3,>=2.10; extra == "tensorflow"
  Provides-Extra: torch
- Requires-Dist: torch <2.3.0,>=2.0.1 ; extra == 'torch'
- Requires-Dist: torchdata <1,>=0.4 ; extra == 'torch'
+ Requires-Dist: torch<2.3.0,>=2.0.1; extra == "torch"
+ Requires-Dist: torchdata<1,>=0.4; extra == "torch"
  Provides-Extra: transformers
- Requires-Dist: sentence-transformers <3,>=2.2.2 ; extra == 'transformers'
- Requires-Dist: sentencepiece <1,>=0.1.95 ; extra == 'transformers'
- Requires-Dist: tokenizers <1,>=0.10 ; extra == 'transformers'
- Requires-Dist: transformers <5,>=4.32.1 ; extra == 'transformers'
+ Requires-Dist: sentence-transformers<3,>=2.2.2; extra == "transformers"
+ Requires-Dist: sentencepiece<1,>=0.1.95; extra == "transformers"
+ Requires-Dist: tokenizers<1,>=0.10; extra == "transformers"
+ Requires-Dist: transformers<5,>=4.32.1; extra == "transformers"
 
  # Snowpark ML
 
@@ -302,7 +303,7 @@ and deployment process, and includes two key components.
 
  ### Snowpark ML Development
 
- [Snowpark ML Development](https://docs.snowflake.com/en/developer-guide/snowpark-ml/index#snowpark-ml-development)
+ [Snowpark ML Development](https://docs.snowflake.com/en/developer-guide/snowpark-ml/index#ml-modeling)
  provides a collection of python APIs enabling efficient ML model development directly in Snowflake:
 
  1. Modeling API (`snowflake.ml.modeling`) for data preprocessing, feature engineering and model training in Snowflake.
@@ -316,14 +317,21 @@ their native data loader formats.
  1. FileSet API: FileSet provides a Python fsspec-compliant API for materializing data into a Snowflake internal stage
     from a query or Snowpark Dataframe along with a number of convenience APIs.
 
- ### Snowpark Model Management [Public Preview]
+ ### Snowflake MLOps
 
- [Snowpark Model Management](https://docs.snowflake.com/en/developer-guide/snowpark-ml/index#snowpark-ml-ops) complements
- the Snowpark ML Development API, and provides model management capabilities along with integrated deployment into Snowflake.
+ Snowflake MLOps contains a suite of tools and objects that support the ML development cycle. It complements
+ the Snowpark ML Development API, and provides end-to-end development-to-deployment capabilities within Snowflake.
  Currently, the API consists of:
 
- 1. Registry: A python API for managing models within Snowflake which also supports deployment of ML models into Snowflake
-    as native MODEL object running with Snowflake Warehouse.
+ 1. [Registry](https://docs.snowflake.com/en/developer-guide/snowpark-ml/index#snowflake-model-registry): A python API
+    that allows secure deployment and management of models in Snowflake, supporting models trained both inside and outside of
+    Snowflake.
+ 2. [Feature Store](https://docs.snowflake.com/en/developer-guide/snowpark-ml/index#snowflake-feature-store): A fully
+    integrated solution for defining, managing, storing and discovering ML features derived from your data. The
+    Snowflake Feature Store supports automated, incremental refresh from batch and streaming data sources, so that
+    feature pipelines need be defined only once to be continuously updated with new data.
+ 3. [Datasets](https://docs.snowflake.com/developer-guide/snowflake-ml/overview#snowflake-datasets): Datasets provide an
+    immutable, versioned snapshot of your data suitable for ingestion by your machine learning models.
 
  ## Getting started
 
@@ -371,9 +379,39 @@ conda install \
  Note that until a `snowflake-ml-python` package version is available in the official Snowflake conda channel, there may
  be compatibility issues. Server-side functionality that `snowflake-ml-python` depends on may not yet be released.
 
+ ### Verifying the package
+
+ 1. Install cosign.
+    This example uses the golang installation: [installing-cosign-with-go](https://edu.chainguard.dev/open-source/sigstore/cosign/how-to-install-cosign/#installing-cosign-with-go).
+ 1. Download the file from a repository such as [pypi](https://pypi.org/project/snowflake-ml-python/#files).
+ 1. Download the signature files from the [release tag](https://github.com/snowflakedb/snowflake-ml-python/releases/tag/1.7.0).
+ 1. Verify the signature on artifacts signed using the Jenkins job:
+
+    ```sh
+    cosign verify-blob snowflake_ml_python-1.7.0.tar.gz --key snowflake-ml-python-1.7.0.pub --signature resources.linux.snowflake_ml_python-1.7.0.tar.gz.sig
+
+    cosign verify-blob snowflake_ml_python-1.7.0.tar.gz --key snowflake-ml-python-1.7.0.pub --signature resources.linux.snowflake_ml_python-1.7.0
+    ```
+
+    NOTE: Version 1.7.0 is used as an example here. Please choose the latest version.
+
  # Release History
 
- ## 1.7.1
+ ## 1.7.2
+
+ ### Bug Fixes
+
+ - Model Explainability: Fix an issue where explainability was enabled for scikit-learn pipelines
+   whose task is UNKNOWN, causing a failure when later invoked.
+
+ ### Behavior Changes
+
+ ### New Features
+
+ - Registry: Support asynchronous model inference service creation with the `block` option
+   in `ModelVersion.create_service()`, which is set to `True` by default.
+
+ ## 1.7.1 (2024-11-05)
 
  ### Bug Fixes
 
snowflake_ml_python-1.7.2.dist-info/RECORD CHANGED
@@ -10,7 +10,7 @@ snowflake/cortex/_sse_client.py,sha256=sLYgqAfTOPADCnaWH2RWAJi8KbU_7gSRsTUDcDD5T
  snowflake/cortex/_summarize.py,sha256=bwpFBzBGmNQSoJqKs3IB5wASjAREnC5ZnViSuZK5IrU,1059
  snowflake/cortex/_translate.py,sha256=69YUps6mnhzVdubdU_H0IfUAlbBwF9OPemFEQ34P-ts,1404
  snowflake/cortex/_util.py,sha256=cwRGgrcUo3E05ZaIDT9436vXLQ7GfuBVAjR0QeQ2bDE,3320
- snowflake/ml/version.py,sha256=QyWKL6Zvq-VDoZgBZ32iGHIzxeVh0z4fIKkiZHSX7t4,16
+ snowflake/ml/version.py,sha256=wJaJaqPpO6Ic3Pl_5e81zlGKYqi1rf5q8V10jTUEDjA,16
  snowflake/ml/_internal/env.py,sha256=kCrJTRnqQ97VGUVI1cWUPD8HuBWeL5vOOtwUR0NB9Mg,161
  snowflake/ml/_internal/env_utils.py,sha256=J_jitp8jvDoC3a79EbMSDatFRYw-HiXaI9vR81bhtU8,28075
  snowflake/ml/_internal/file_utils.py,sha256=OyXHv-UcItiip1YgLnab6etonUQkYuyDtmplZA0CaoU,13622
@@ -36,6 +36,7 @@ snowflake/ml/_internal/utils/db_utils.py,sha256=HBAY0-XHzCP4ai5q3Yqd8O19Ar_Q9J3x
  snowflake/ml/_internal/utils/formatting.py,sha256=PswZ6Xas7sx3Ok1MBLoH2o7nfXOxaJqpUPg_UqXrQb8,3676
  snowflake/ml/_internal/utils/identifier.py,sha256=fUYXjXKXAkjLUZpomneMHo2wR4_ZNP4ak-5OJxeUS-g,12467
  snowflake/ml/_internal/utils/import_utils.py,sha256=iUIROZdiTGy73UCGpG0N-dKtK54H0ymNVge_QNQYY3A,3220
+ snowflake/ml/_internal/utils/jwt_generator.py,sha256=bj7Ltnw68WjRcxtV9t5xrTRvV5ETnvovB-o3Y8QWNBg,5357
  snowflake/ml/_internal/utils/parallelize.py,sha256=Q6_-P2t4DoYNO8DyC1kOl7H3qNL-bUK6EgtlQ_b5ThY,4534
  snowflake/ml/_internal/utils/pkg_version_utils.py,sha256=FwdLHFhxi3CAQQduGjFavEBmkD9Ra6ZTkt6Eub-WoSA,5168
  snowflake/ml/_internal/utils/query_result_checker.py,sha256=h1nbUImdB9lSNCON3uIA0xCm8_JrS-TE-jQXJJs9WfU,10668
@@ -101,17 +102,17 @@ snowflake/ml/model/custom_model.py,sha256=O60mjz2Vy8A0Rt3obq43zBT3BxkU7CIcN0AkHs
  snowflake/ml/model/model_signature.py,sha256=gZnZPs9zTCYkeFoiQzoGUQYZMydYjzH-4xPTzfqt4hU,30496
  snowflake/ml/model/type_hints.py,sha256=9GPwEuG6B6GSWOXdOy8B1Swz6yDngL865yEtJMd0v1U,8883
  snowflake/ml/model/_client/model/model_impl.py,sha256=pqjK8mSZIQJ_30tRWWFPIo8X35InSVoAunXlQNtSJEM,15369
- snowflake/ml/model/_client/model/model_version_impl.py,sha256=PTVqTkNm1adHUjTTWsUlnTSPiMQV-PZLEaj9UstICqk,39076
+ snowflake/ml/model/_client/model/model_version_impl.py,sha256=tGfSR4dF8okdBPeAu7yWVSLtwvnvhnJr9xalKbQZw5M,40144
  snowflake/ml/model/_client/ops/metadata_ops.py,sha256=7cGx8zYzye2_cvZnyGxoukPtT6Q-Kexd-s4yeZmpmj8,4890
  snowflake/ml/model/_client/ops/model_ops.py,sha256=didFBsjb7KJYV_586TUK4c9DudVQvjzlphEXJW0AnmY,43935
- snowflake/ml/model/_client/ops/service_ops.py,sha256=LLvRqBBwyJsjNfphN_VdH8O1aQEPNf97Wmco5dfLUN0,19093
+ snowflake/ml/model/_client/ops/service_ops.py,sha256=t_yLtHlAzHc28XDZ543yAALY5iVsRwVw4i9mtiPaXpQ,19237
  snowflake/ml/model/_client/service/model_deployment_spec.py,sha256=uyh5k_u8mVP5T4lf0jq8s2cFuiTsbV_nJL6z1Zum2rM,4456
  snowflake/ml/model/_client/service/model_deployment_spec_schema.py,sha256=eaulF6OFNuDfQz3oPYlDjP26Ww2jWWatm81dCbg602E,825
  snowflake/ml/model/_client/sql/_base.py,sha256=Qrm8M92g3MHb-QnSLUlbd8iVKCRxLhG_zr5M2qmXwJ8,1473
  snowflake/ml/model/_client/sql/model.py,sha256=o36oPq4aU9TwahqY2uODYvICxmj1orLztijJ0yMbWnM,5852
  snowflake/ml/model/_client/sql/model_version.py,sha256=hNMlmwN5JQngKuaeUYV2Bli73RMnHmVH01ABX9NBHFk,20686
  snowflake/ml/model/_client/sql/service.py,sha256=fvQRhRGU4FBeOBouIoQByTvfQg-qbEQKplCG99BPmL0,10408
- snowflake/ml/model/_client/sql/stage.py,sha256=hrCh9P9F4l5R0hLr2r-wLDIEc4XYHMFdX1wNRveMVt0,819
+ snowflake/ml/model/_client/sql/stage.py,sha256=165vyAtrScSQWJB8wLXKRUO1QvHTWDmPykeWOyxrDRg,826
  snowflake/ml/model/_client/sql/tag.py,sha256=pwwrcyPtSnkUfDzL3M8kqM0KSx7CaTtgty3HDhVC9vg,4345
  snowflake/ml/model/_model_composer/model_composer.py,sha256=535ElL3Kw8eoUjL7fHd-K20eDCBqvJFwowUx2_UOCl8,6712
  snowflake/ml/model/_model_composer/model_manifest/model_manifest.py,sha256=X6-cKLBZ1X2liIjWnyrd9efQaQhwIoxRSE90Zs0kAZo,7822
@@ -133,7 +134,7 @@ snowflake/ml/model/_packager/model_handlers/lightgbm.py,sha256=E0667G5FFfMssaXjk
  snowflake/ml/model/_packager/model_handlers/mlflow.py,sha256=A3HnCa065jtHsRM40ZxfLv5alk0RYhVmsU4Jt2klRwQ,9189
  snowflake/ml/model/_packager/model_handlers/pytorch.py,sha256=DDcf85xisPLT1PyXdmPrjJpIIepkdmWNXCOpT_dCncw,8294
  snowflake/ml/model/_packager/model_handlers/sentence_transformers.py,sha256=f21fJw2wPsXzzhv71Gi1eHctSlyJ6NAR1EQX5iUL5M8,9842
- snowflake/ml/model/_packager/model_handlers/sklearn.py,sha256=UbtqgOztM9racr_N-SPRymEpUwhZGKov5iv6dcbINy8,13995
+ snowflake/ml/model/_packager/model_handlers/sklearn.py,sha256=dwwETBdJJM3AVfl3R6VvvVOZQHgnwIuk9dUUCDOs-w0,14111
  snowflake/ml/model/_packager/model_handlers/snowmlmodel.py,sha256=uhsJ3zK24aavBRO5gNyxv8BHqU9n1TPUBYm1qHTuaxE,12176
  snowflake/ml/model/_packager/model_handlers/tensorflow.py,sha256=SkbnvkElK4UIMgygv9EK9f5hBxWZ2YDroymUC9uBsBk,9169
  snowflake/ml/model/_packager/model_handlers/torchscript.py,sha256=BIdRINO1xZ5uHrR9uA0vExWQymOryTaSpyAMpCCtz8U,8036
@@ -146,7 +147,7 @@ snowflake/ml/model/_packager/model_meta/model_meta_schema.py,sha256=5Sdh1_NCKycL
  snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py,sha256=SORlqpPbOeBg6dvJ3DidHeLVi0w9YF0Zv4tC0Kbc20g,1311
  snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py,sha256=nf6PWDH_gvX_OiS4A-G6BzyCLFEG4dASU0t5JTsijM4,1041
  snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py,sha256=qEPzdCw_FzExMbPuyFHupeWlYD88yejLdcmkPwjJzDk,2070
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py,sha256=5YHbTmgPdURGQDZwzmC7mlYSl8q_e7hzHJ-JyMXgDFY,1419
+ snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py,sha256=eR_qxEwsmzaeaRYH9K4wUAG7bhpqZvn07en2vfRV4c4,1459
  snowflake/ml/model/_packager/model_runtime/model_runtime.py,sha256=G52nrjzcZiWBJaed6Z1qKq-HjqtnG2MnywDdU9lPusg,5051
  snowflake/ml/model/_packager/model_task/model_task_utils.py,sha256=0aEUfg71bP5-RkwmzOJBe51yHxLRrtM17tUBoCiuMMk,6310
  snowflake/ml/model/_signatures/base_handler.py,sha256=WwBfe-83Y0m-HcDx1YSYCGwanIe0fb2MWhTeXc1IeJI,1304
@@ -157,7 +158,7 @@ snowflake/ml/model/_signatures/pandas_handler.py,sha256=ACv8egyiK2Sug8uhkQqMDGTT
  snowflake/ml/model/_signatures/pytorch_handler.py,sha256=yEU-V_WRjE8Q7NdHyghl0iYpMiIDzGaIR5Pd_ixB1Hk,4631
  snowflake/ml/model/_signatures/snowpark_handler.py,sha256=2_AY1ssucMICKSPeDjf3mV4WT5farKYdnYkHsvhHZ20,6066
  snowflake/ml/model/_signatures/tensorflow_handler.py,sha256=9bUbxtHpl4kEoFzeDJF87bQPb8RdLLm9OV23-aUyW3s,6114
- snowflake/ml/model/_signatures/utils.py,sha256=-RuAFPJn8JHh8QUMLAgMbgpuDvNLI6gVDeLf-lvUBxQ,13109
+ snowflake/ml/model/_signatures/utils.py,sha256=1E_mV1qdUuob8tjB8WaOEfuo2rmQ2FtOgTNyXZGzoJg,13108
  snowflake/ml/model/models/huggingface_pipeline.py,sha256=62GpPZxBheqCnFNxNOggiDE1y9Dhst-v6D4IkGLuDeQ,10221
  snowflake/ml/modeling/_internal/constants.py,sha256=aJGngY599w3KqN8cDZCYrjbWe6UwYIbgv0gx0Ukdtc0,105
  snowflake/ml/modeling/_internal/estimator_utils.py,sha256=mbMm8_5tQde_sQDwI8pS3ljHZ8maCHl2Shb5nQwLYac,11872
@@ -379,23 +380,23 @@ snowflake/ml/modeling/xgboost/xgb_classifier.py,sha256=2QEK6-NihXjKXO8Ue-fOZDyuc
  snowflake/ml/modeling/xgboost/xgb_regressor.py,sha256=ZorEmRohT2-AUdS8fK0xH8BdB8ENxvVMMDYy34Jzm1o,61703
  snowflake/ml/modeling/xgboost/xgbrf_classifier.py,sha256=67jh9RosrTeYCWsJbnJ6_MQICHeG22z-DMy8CegP8Vg,62383
  snowflake/ml/modeling/xgboost/xgbrf_regressor.py,sha256=7_ZwF_QvVqBrkFx_zgGgLXyxtbX26XrWWLozAF-EBB0,61908
- snowflake/ml/monitoring/model_monitor.py,sha256=p8FWMpr9O8ScL_y6wdrMUstlpA43gJ0Qiv2e8w-ADts,1374
+ snowflake/ml/monitoring/model_monitor.py,sha256=8vJf1YROmJgBLUtpaH-lGKSSJv9R7PxPaQnOdr_j5YE,2200
  snowflake/ml/monitoring/model_monitor_version.py,sha256=TlmDJZDE0lCVatRaBRgXIjzDF538nrMIc-zWj9MM_nk,46
  snowflake/ml/monitoring/shap.py,sha256=Dp9nYquPEZjxMTW62YYA9g9qUdmCEFxcSk7ejvOP7PE,3597
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py,sha256=EH2jTz6ctNxfXbOMGVbQTAgikUig5YmvSsX93cd9ZF8,20194
+ snowflake/ml/monitoring/_client/model_monitor_sql_client.py,sha256=Qr3L6bs84ID5_1TvY6wf5YK2kn3ZVZ-Havo242i3MiY,12710
  snowflake/ml/monitoring/_client/queries/record_count.ssql,sha256=Bd1uNMwhPKqPyrDd5ug8iY493t9KamJjrlo82OAfmjY,335
  snowflake/ml/monitoring/_client/queries/rmse.ssql,sha256=OEJiSStRz9-qKoZaFvmubtY_n0xMUjyVU2uiQHCp7KU,822
- snowflake/ml/monitoring/_manager/model_monitor_manager.py,sha256=gT3aYZsjD5wIRLdbe7fyyb5vICIxw9WWsK7H0hxbz9E,10314
- snowflake/ml/monitoring/entities/model_monitor_config.py,sha256=wvk0v9-VvhFAaNdpYXSqKdWj2Kx-KGjuWVkaCgL4MUc,825
- snowflake/ml/monitoring/entities/output_score_type.py,sha256=UJyS4z5hncRZ0agVNa6_X041RY9q3Us-6Bh3dPVAmEw,2982
+ snowflake/ml/monitoring/_manager/model_monitor_manager.py,sha256=_-vxqnHqohTHTrwfURjPXijyAeh1mTRdHCG436GaBik,10314
+ snowflake/ml/monitoring/entities/model_monitor_config.py,sha256=IxEiee1HfBXCQGzJOZbrDrvoV8J1tDNk43ygNuN00Io,1793
  snowflake/ml/registry/__init__.py,sha256=XdPQK9ejYkSJVrSQ7HD3jKQO0hKq2mC4bPCB6qrtH3U,76
- snowflake/ml/registry/registry.py,sha256=_G6Sm4Zi67iJJ3RUwz2XNYszPnrOtYF5bK8KeGtjubM,23793
+ snowflake/ml/registry/registry.py,sha256=5aBedBH8NiFkJJe1Pnggsrjnn0ixdg1oqtUHWyz3wsE,23824
  snowflake/ml/registry/_manager/model_manager.py,sha256=gFr1EqaMR2Eb4erwVz7fi7xK1G1YsFXz1PF5GvOR0pg,12131
+ snowflake/ml/utils/authentication.py,sha256=Wx1kVBZ9XBDuKkRHpPEB2pBxpiJepVLFAirDMx4m5Gk,2612
  snowflake/ml/utils/connection_params.py,sha256=JRpQppuWRk6bhdLzVDhMfz3Y6yInobFNLHmIBaXD7po,8005
  snowflake/ml/utils/sparse.py,sha256=XqDQkw39Ml6YIknswdkvFIwUwBk_GBXAbP8IACfPENg,3817
  snowflake/ml/utils/sql_client.py,sha256=z4Rhi7pQz3s9cyu_Uzfr3deCnrkCdFh9IYIvicsuwdc,692
- snowflake_ml_python-1.7.1.dist-info/LICENSE.txt,sha256=PdEp56Av5m3_kl21iFkVTX_EbHJKFGEdmYeIO1pL_Yk,11365
- snowflake_ml_python-1.7.1.dist-info/METADATA,sha256=s1vUDI47E0APJr53Bs6qTDV-fWFk3gLkW9yImkzM960,65547
- snowflake_ml_python-1.7.1.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
- snowflake_ml_python-1.7.1.dist-info/top_level.txt,sha256=TY0gFSHKDdZy3THb0FGomyikWQasEGldIR1O0HGOHVw,10
- snowflake_ml_python-1.7.1.dist-info/RECORD,,
+ snowflake_ml_python-1.7.2.dist-info/LICENSE.txt,sha256=PdEp56Av5m3_kl21iFkVTX_EbHJKFGEdmYeIO1pL_Yk,11365
+ snowflake_ml_python-1.7.2.dist-info/METADATA,sha256=GwZOHmNQAKaMDP3VeWIDWC-OMhPqldoJaYPrR-_iWGw,67429
+ snowflake_ml_python-1.7.2.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+ snowflake_ml_python-1.7.2.dist-info/top_level.txt,sha256=TY0gFSHKDdZy3THb0FGomyikWQasEGldIR1O0HGOHVw,10
+ snowflake_ml_python-1.7.2.dist-info/RECORD,,
snowflake_ml_python-1.7.2.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (75.3.0)
+ Generator: setuptools (75.6.0)
  Root-Is-Purelib: true
  Tag: py3-none-any
 
snowflake/ml/monitoring/entities/output_score_type.py REMOVED
@@ -1,90 +0,0 @@
- from __future__ import annotations
-
- from enum import Enum
- from typing import List, Mapping
-
- from snowflake.ml._internal.utils import sql_identifier
- from snowflake.ml.model import type_hints
- from snowflake.snowpark import types
-
- # Accepted data types for each OutputScoreType.
- REGRESSION_DATA_TYPES = (
-     types.ByteType,
-     types.ShortType,
-     types.IntegerType,
-     types.LongType,
-     types.FloatType,
-     types.DoubleType,
-     types.DecimalType,
- )
- CLASSIFICATION_DATA_TYPES = (
-     types.ByteType,
-     types.ShortType,
-     types.IntegerType,
-     types.BooleanType,
-     types.BinaryType,
- )
- PROBITS_DATA_TYPES = (
-     types.ByteType,
-     types.ShortType,
-     types.IntegerType,
-     types.LongType,
-     types.FloatType,
-     types.DoubleType,
-     types.DecimalType,
- )
-
-
- # OutputScoreType enum
- class OutputScoreType(Enum):
-     UNKNOWN = "UNKNOWN"
-     REGRESSION = "REGRESSION"
-     CLASSIFICATION = "CLASSIFICATION"
-     PROBITS = "PROBITS"
-
-     @classmethod
-     def deduce_score_type(
-         cls,
-         table_schema: Mapping[str, types.DataType],
-         prediction_columns: List[sql_identifier.SqlIdentifier],
-         task: type_hints.Task,
-     ) -> OutputScoreType:
-         """Find the score type for monitoring given a table schema and the task.
-
-         Args:
-             table_schema: Dictionary of column names and types in the source table.
-             prediction_columns: List of prediction columns.
-             task: Enum value for the task of the model.
-
-         Returns:
-             Enum value for the score type, informing monitoring table set up.
-
-         Raises:
-             ValueError: If prediction type fails to align with task.
-         """
-         # Already validated we have just one prediction column type
-         prediction_column_type = {table_schema[column_name] for column_name in prediction_columns}.pop()
-
-         if task == type_hints.Task.TABULAR_REGRESSION:
-             if isinstance(prediction_column_type, REGRESSION_DATA_TYPES):
-                 return OutputScoreType.REGRESSION
-             else:
-                 raise ValueError(
-                     f"Expected prediction column type to be one of {REGRESSION_DATA_TYPES} "
-                     f"for REGRESSION task. Found: {prediction_column_type}."
-                 )
-
-         elif task == type_hints.Task.TABULAR_BINARY_CLASSIFICATION:
-             if isinstance(prediction_column_type, CLASSIFICATION_DATA_TYPES):
-                 return OutputScoreType.CLASSIFICATION
-             elif isinstance(prediction_column_type, PROBITS_DATA_TYPES):
-                 return OutputScoreType.PROBITS
-             else:
-                 raise ValueError(
-                     f"Expected prediction column type to be one of {CLASSIFICATION_DATA_TYPES} "
-                     f"or one of {PROBITS_DATA_TYPES} for CLASSIFICATION task. "
-                     f"Found: {prediction_column_type}."
-                 )
-
-         else:
-             raise ValueError(f"Received unsupported task for model monitoring: {task}.")