snowflake-ml-python 1.7.1__py3-none-any.whl → 1.7.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/utils/jwt_generator.py +141 -0
- snowflake/ml/model/_client/model/model_version_impl.py +19 -5
- snowflake/ml/model/_client/ops/service_ops.py +12 -7
- snowflake/ml/model/_client/sql/stage.py +1 -1
- snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
- snowflake/ml/model/_signatures/utils.py +0 -1
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +1 -166
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +9 -9
- snowflake/ml/monitoring/entities/model_monitor_config.py +28 -2
- snowflake/ml/monitoring/model_monitor.py +26 -11
- snowflake/ml/registry/registry.py +12 -12
- snowflake/ml/utils/authentication.py +75 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.2.dist-info}/METADATA +90 -52
- {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.2.dist-info}/RECORD +19 -18
- {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.2.dist-info}/WHEEL +1 -1
- snowflake/ml/monitoring/entities/output_score_type.py +0 -90
- {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.2.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,141 @@
|
|
1
|
+
import base64
|
2
|
+
import hashlib
|
3
|
+
import logging
|
4
|
+
from datetime import datetime, timedelta, timezone
|
5
|
+
from typing import Optional
|
6
|
+
|
7
|
+
import jwt
|
8
|
+
from cryptography.hazmat.primitives import serialization
|
9
|
+
from cryptography.hazmat.primitives.asymmetric import types
|
10
|
+
|
11
|
+
logger = logging.getLogger(__name__)

# Registered JWT claim names (RFC 7519) used in the token payload.
ISSUER = "iss"
EXPIRE_TIME = "exp"
ISSUE_TIME = "iat"
SUBJECT = "sub"


class JWTGenerator:
    """
    Creates and signs a JWT with the specified private key, username, and account identifier. The JWTGenerator
    caches the generated token and only regenerates it once ``renewal_delay`` has elapsed since it was issued.
    """

    _DEFAULT_LIFETIME = timedelta(minutes=59)  # The tokens will have a 59-minute lifetime
    _DEFAULT_RENEWAL_DELTA = timedelta(minutes=54)  # Tokens will be renewed after 54 minutes
    ALGORITHM = "RS256"  # Tokens will be generated using RSA with SHA256

    def __init__(
        self,
        account: str,
        user: str,
        private_key: types.PRIVATE_KEY_TYPES,
        lifetime: Optional[timedelta] = None,
        renewal_delay: Optional[timedelta] = None,
    ) -> None:
        """
        Create a new JWTGenerator object.

        Args:
            account: The account identifier. It is normalized by ``_prepare_account_name_for_jwt`` (suffix
                stripped, uppercased) before being used in the claims.
            user: The username (uppercased for the claims).
            private_key: The private key used to sign the JWT. The SHA-256 fingerprint of its public key is
                embedded in the issuer claim.
            lifetime: How long each generated token is valid; the ``exp`` claim is issue time + lifetime.
                Defaults to 59 minutes when ``None``.
            renewal_delay: How long after issuing a token ``get_token`` keeps returning the cached token before
                generating a new one. Defaults to 54 minutes when ``None``. If it exceeds ``lifetime`` it is
                capped to ``lifetime`` so a cached token is never handed out after it has expired.
        """
        # Construct the fully qualified name of the user in uppercase.
        self.account = JWTGenerator._prepare_account_name_for_jwt(account)
        self.user = user.upper()
        self.qualified_username = self.account + "." + self.user
        self.private_key = private_key
        self.public_key_fp = JWTGenerator._calculate_public_key_fingerprint(self.private_key)

        self.issuer = self.qualified_username + "." + self.public_key_fp
        # Compare against None explicitly: timedelta(0) is falsy, so `value or default` would silently
        # replace an explicit zero with the default.
        self.lifetime = JWTGenerator._DEFAULT_LIFETIME if lifetime is None else lifetime
        self.renewal_delay = JWTGenerator._DEFAULT_RENEWAL_DELTA if renewal_delay is None else renewal_delay
        if self.renewal_delay > self.lifetime:
            # Otherwise get_token would keep returning a token whose "exp" has already passed.
            logger.warning(
                "renewal_delay (%s) is longer than the token lifetime (%s); capping renewal_delay to the lifetime",
                self.renewal_delay,
                self.lifetime,
            )
            self.renewal_delay = self.lifetime
        self.renew_time = datetime.now(timezone.utc)
        self.token: Optional[str] = None

        logger.info(
            """Creating JWTGenerator with arguments
                account : %s, user : %s, lifetime : %s, renewal_delay : %s""",
            self.account,
            self.user,
            self.lifetime,
            self.renewal_delay,
        )

    @staticmethod
    def _prepare_account_name_for_jwt(raw_account: str) -> str:
        """Reduce a raw account identifier to the uppercase account name used in the JWT claims.

        For non-global identifiers everything from the first "." on is dropped; for identifiers containing
        ".global" (matched case-insensitively) everything from the first "-" on is dropped.

        Args:
            raw_account: The account identifier as supplied by the caller.

        Returns:
            The normalized, uppercased account name.
        """
        account = raw_account
        # Account identifiers are case-insensitive, so detect ".global" regardless of case.
        if ".global" not in account.lower():
            # Handle the general case.
            idx = account.find(".")
            if idx > 0:
                account = account[0:idx]
        else:
            # Handle the replication case.
            idx = account.find("-")
            if idx > 0:
                account = account[0:idx]
        # Use uppercase for the account identifier.
        return account.upper()

    def get_token(self) -> str:
        """Return a signed JWT, generating a new one if none is cached or the renewal time has passed.

        Returns:
            The encoded JWT as a string.
        """
        now = datetime.now(timezone.utc)  # Fetch the current time
        if self.token is not None and self.renew_time > now:
            return self.token

        # If the token has expired or doesn't exist, regenerate the token.
        logger.info(
            "Generating a new token because the present time (%s) is later than the renewal time (%s)",
            now,
            self.renew_time,
        )
        # Calculate the next time we need to renew the token.
        self.renew_time = now + self.renewal_delay

        payload = {
            # Set the issuer to the fully qualified username concatenated with the public key fingerprint.
            ISSUER: self.issuer,
            # Set the subject to the fully qualified username.
            SUBJECT: self.qualified_username,
            # Set the issue time to now.
            ISSUE_TIME: now,
            # Set the expiration time, based on the lifetime specified for this object.
            EXPIRE_TIME: now + self.lifetime,
        }

        token = jwt.encode(payload, key=self.private_key, algorithm=JWTGenerator.ALGORITHM)
        # PyJWT < 2.0 returns bytes from jwt.encode; normalize to str.
        if isinstance(token, bytes):
            token = token.decode("utf-8")
        self.token = token
        # Log the claims we just signed rather than decoding the token again: decoding re-verifies the RSA
        # signature on every renewal and raises ExpiredSignatureError for very short lifetimes.
        logger.info("Generated a JWT with the following payload: %s", payload)

        return token

    @staticmethod
    def _calculate_public_key_fingerprint(private_key: types.PRIVATE_KEY_TYPES) -> str:
        """Return the "SHA256:"-prefixed, base64-encoded SHA-256 digest of the key's DER public key.

        Args:
            private_key: The private key whose public half is fingerprinted.

        Returns:
            The fingerprint string, e.g. "SHA256:<base64>".
        """
        # Raw DER bytes of the public key in SubjectPublicKeyInfo form.
        public_key_raw = private_key.public_key().public_bytes(
            serialization.Encoding.DER, serialization.PublicFormat.SubjectPublicKeyInfo
        )

        digest = hashlib.sha256(public_key_raw).digest()

        # Base64-encode the value and prepend the prefix 'SHA256:'.
        public_key_fp = "SHA256:" + base64.b64encode(digest).decode("utf-8")
        logger.info("Public key fingerprint is %s", public_key_fp)

        return public_key_fp
|
@@ -14,7 +14,7 @@ from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
|
|
14
14
|
from snowflake.ml.model._model_composer import model_composer
|
15
15
|
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
16
16
|
from snowflake.ml.model._packager.model_handlers import snowmlmodel
|
17
|
-
from snowflake.snowpark import Session, dataframe
|
17
|
+
from snowflake.snowpark import Session, async_job, dataframe
|
18
18
|
|
19
19
|
_TELEMETRY_PROJECT = "MLOps"
|
20
20
|
_TELEMETRY_SUBPROJECT = "ModelManagement"
|
@@ -631,7 +631,8 @@ class ModelVersion(lineage_node.LineageNode):
|
|
631
631
|
max_batch_rows: Optional[int] = None,
|
632
632
|
force_rebuild: bool = False,
|
633
633
|
build_external_access_integration: Optional[str] = None,
|
634
|
-
|
634
|
+
block: bool = True,
|
635
|
+
) -> Union[str, async_job.AsyncJob]:
|
635
636
|
"""Create an inference service with the given spec.
|
636
637
|
|
637
638
|
Args:
|
@@ -659,6 +660,9 @@ class ModelVersion(lineage_node.LineageNode):
|
|
659
660
|
force_rebuild: Whether to force a model inference image rebuild.
|
660
661
|
build_external_access_integration: (Deprecated) The external access integration for image build. This is
|
661
662
|
usually permitting access to conda & PyPI repositories.
|
663
|
+
block: A bool value indicating whether this function will wait until the service is available.
|
664
|
+
When it is ``False``, this function executes the underlying service creation asynchronously
|
665
|
+
and returns an :class:`AsyncJob`.
|
662
666
|
"""
|
663
667
|
...
|
664
668
|
|
@@ -679,7 +683,8 @@ class ModelVersion(lineage_node.LineageNode):
|
|
679
683
|
max_batch_rows: Optional[int] = None,
|
680
684
|
force_rebuild: bool = False,
|
681
685
|
build_external_access_integrations: Optional[List[str]] = None,
|
682
|
-
|
686
|
+
block: bool = True,
|
687
|
+
) -> Union[str, async_job.AsyncJob]:
|
683
688
|
"""Create an inference service with the given spec.
|
684
689
|
|
685
690
|
Args:
|
@@ -707,6 +712,9 @@ class ModelVersion(lineage_node.LineageNode):
|
|
707
712
|
force_rebuild: Whether to force a model inference image rebuild.
|
708
713
|
build_external_access_integrations: The external access integrations for image build. This is usually
|
709
714
|
permitting access to conda & PyPI repositories.
|
715
|
+
block: A bool value indicating whether this function will wait until the service is available.
|
716
|
+
When it is ``False``, this function executes the underlying service creation asynchronously
|
717
|
+
and returns an :class:`AsyncJob`.
|
710
718
|
"""
|
711
719
|
...
|
712
720
|
|
@@ -742,7 +750,8 @@ class ModelVersion(lineage_node.LineageNode):
|
|
742
750
|
force_rebuild: bool = False,
|
743
751
|
build_external_access_integration: Optional[str] = None,
|
744
752
|
build_external_access_integrations: Optional[List[str]] = None,
|
745
|
-
|
753
|
+
block: bool = True,
|
754
|
+
) -> Union[str, async_job.AsyncJob]:
|
746
755
|
"""Create an inference service with the given spec.
|
747
756
|
|
748
757
|
Args:
|
@@ -772,12 +781,16 @@ class ModelVersion(lineage_node.LineageNode):
|
|
772
781
|
usually permitting access to conda & PyPI repositories.
|
773
782
|
build_external_access_integrations: The external access integrations for image build. This is usually
|
774
783
|
permitting access to conda & PyPI repositories.
|
784
|
+
block: A bool value indicating whether this function will wait until the service is available.
|
785
|
+
When it is False, this function executes the underlying service creation asynchronously
|
786
|
+
and returns an AsyncJob.
|
775
787
|
|
776
788
|
Raises:
|
777
789
|
ValueError: Illegal external access integration arguments.
|
778
790
|
|
779
791
|
Returns:
|
780
|
-
|
792
|
+
If `block=True`, return result information about service creation from server.
|
793
|
+
Otherwise, return the service creation AsyncJob.
|
781
794
|
"""
|
782
795
|
statement_params = telemetry.get_statement_params(
|
783
796
|
project=_TELEMETRY_PROJECT,
|
@@ -829,6 +842,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
829
842
|
if build_external_access_integrations is None
|
830
843
|
else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
|
831
844
|
),
|
845
|
+
block=block,
|
832
846
|
statement_params=statement_params,
|
833
847
|
)
|
834
848
|
|
@@ -6,7 +6,7 @@ import re
|
|
6
6
|
import tempfile
|
7
7
|
import threading
|
8
8
|
import time
|
9
|
-
from typing import Any, Dict, List, Optional, Tuple, cast
|
9
|
+
from typing import Any, Dict, List, Optional, Tuple, Union, cast
|
10
10
|
|
11
11
|
from packaging import version
|
12
12
|
|
@@ -15,7 +15,7 @@ from snowflake.ml._internal import file_utils
|
|
15
15
|
from snowflake.ml._internal.utils import service_logger, snowflake_env, sql_identifier
|
16
16
|
from snowflake.ml.model._client.service import model_deployment_spec
|
17
17
|
from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
|
18
|
-
from snowflake.snowpark import exceptions, row, session
|
18
|
+
from snowflake.snowpark import async_job, exceptions, row, session
|
19
19
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
20
20
|
|
21
21
|
module_logger = service_logger.get_logger(__name__, service_logger.LogColor.GREY)
|
@@ -107,8 +107,9 @@ class ServiceOperator:
|
|
107
107
|
max_batch_rows: Optional[int],
|
108
108
|
force_rebuild: bool,
|
109
109
|
build_external_access_integrations: Optional[List[sql_identifier.SqlIdentifier]],
|
110
|
+
block: bool,
|
110
111
|
statement_params: Optional[Dict[str, Any]] = None,
|
111
|
-
) -> str:
|
112
|
+
) -> Union[str, async_job.AsyncJob]:
|
112
113
|
|
113
114
|
# Fall back to the registry's database and schema if not provided
|
114
115
|
database_name = database_name or self._database_name
|
@@ -204,11 +205,15 @@ class ServiceOperator:
|
|
204
205
|
log_thread = self._start_service_log_streaming(
|
205
206
|
async_job, services, model_inference_service_exists, force_rebuild, statement_params
|
206
207
|
)
|
207
|
-
log_thread.join()
|
208
208
|
|
209
|
-
|
210
|
-
|
211
|
-
|
209
|
+
if block:
|
210
|
+
log_thread.join()
|
211
|
+
|
212
|
+
res = cast(str, cast(List[row.Row], async_job.result())[0][0])
|
213
|
+
module_logger.info(f"Inference service {service_name} deployment complete: {res}")
|
214
|
+
return res
|
215
|
+
else:
|
216
|
+
return async_job
|
212
217
|
|
213
218
|
def _start_service_log_streaming(
|
214
219
|
self,
|
@@ -15,6 +15,6 @@ class StageSQLClient(_base._BaseSQLClient):
|
|
15
15
|
) -> None:
|
16
16
|
query_result_checker.SqlResultValidator(
|
17
17
|
self._session,
|
18
|
-
f"CREATE TEMPORARY STAGE {self.fully_qualified_object_name(database_name, schema_name, stage_name)}",
|
18
|
+
f"CREATE SCOPED TEMPORARY STAGE {self.fully_qualified_object_name(database_name, schema_name, stage_name)}",
|
19
19
|
statement_params=statement_params,
|
20
20
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
@@ -164,6 +164,8 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
164
164
|
stacklevel=1,
|
165
165
|
)
|
166
166
|
enable_explainability = False
|
167
|
+
elif model_meta.task == model_types.Task.UNKNOWN:
|
168
|
+
enable_explainability = False
|
167
169
|
else:
|
168
170
|
enable_explainability = True
|
169
171
|
if enable_explainability:
|
@@ -1,2 +1,2 @@
|
|
1
|
-
REQUIREMENTS = ['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<4', 'cachetools>=3.1.1,<6', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2022.11,<2024', 'importlib_resources>=6.1.1, <7', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pytimeparse>=1.1.8,<2', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2022.11,<2024', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2', 'sqlparse>=0.4,<1', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
|
2
|
-
ALL_REQUIREMENTS=['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<4', 'cachetools>=3.1.1,<6', 'catboost>=1.2.0, <2', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2022.11,<2024', 'importlib_resources>=6.1.1, <7', 'lightgbm>=4.1.0, <5', 'mlflow>=2.1.0,<2.4', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pytimeparse>=1.1.8,<2', 'pytorch>=2.0.1,<2.3.0', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2022.11,<2024', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'sentence-transformers>=2.2.2,<3', 'sentencepiece>=0.1.95,<1', 'shap>=0.46.0,<1', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2', 'sqlparse>=0.4,<1', 'tensorflow>=2.10,<3', 'tokenizers>=0.10,<1', 'torchdata>=0.4,<1', 'transformers>=4.32.1,<5', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
|
1
|
+
REQUIREMENTS = ['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<4', 'cachetools>=3.1.1,<6', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2022.11,<2024', 'importlib_resources>=6.1.1, <7', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2022.11,<2024', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2', 'sqlparse>=0.4,<1', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
|
2
|
+
ALL_REQUIREMENTS=['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<4', 'cachetools>=3.1.1,<6', 'catboost>=1.2.0, <2', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2022.11,<2024', 'importlib_resources>=6.1.1, <7', 'lightgbm>=4.1.0, <5', 'mlflow>=2.1.0,<2.4', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pytorch>=2.0.1,<2.3.0', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2022.11,<2024', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'sentence-transformers>=2.2.2,<3', 'sentencepiece>=0.1.95,<1', 'shap>=0.46.0,<1', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2', 'sqlparse>=0.4,<1', 'tensorflow>=2.10,<3', 'tokenizers>=0.10,<1', 'torchdata>=0.4,<1', 'transformers>=4.32.1,<5', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
|
@@ -118,7 +118,6 @@ def huggingface_pipeline_signature_auto_infer(task: str, params: Dict[str, Any])
|
|
118
118
|
category=DeprecationWarning,
|
119
119
|
stacklevel=1,
|
120
120
|
)
|
121
|
-
|
122
121
|
return core.ModelSignature(
|
123
122
|
inputs=[
|
124
123
|
core.FeatureSpec(name="user_inputs", dtype=core.DataType.STRING, shape=(-1,)),
|
@@ -1,6 +1,4 @@
|
|
1
|
-
import
|
2
|
-
from collections import Counter
|
3
|
-
from typing import Any, Dict, List, Mapping, Optional, Set
|
1
|
+
from typing import Any, Dict, List, Mapping, Optional
|
4
2
|
|
5
3
|
from snowflake import snowpark
|
6
4
|
from snowflake.ml._internal.utils import (
|
@@ -10,27 +8,12 @@ from snowflake.ml._internal.utils import (
|
|
10
8
|
table_manager,
|
11
9
|
)
|
12
10
|
from snowflake.ml.model._client.sql import _base
|
13
|
-
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
14
11
|
from snowflake.snowpark import session, types
|
15
12
|
|
16
|
-
SNOWML_MONITORING_METADATA_TABLE_NAME = "_SYSTEM_MONITORING_METADATA"
|
17
|
-
|
18
13
|
MODEL_JSON_COL_NAME = "model"
|
19
14
|
MODEL_JSON_MODEL_NAME_FIELD = "model_name"
|
20
15
|
MODEL_JSON_VERSION_NAME_FIELD = "version_name"
|
21
16
|
|
22
|
-
MONITOR_NAME_COL_NAME = "MONITOR_NAME"
|
23
|
-
SOURCE_TABLE_NAME_COL_NAME = "SOURCE_TABLE_NAME"
|
24
|
-
FQ_MODEL_NAME_COL_NAME = "FULLY_QUALIFIED_MODEL_NAME"
|
25
|
-
VERSION_NAME_COL_NAME = "MODEL_VERSION_NAME"
|
26
|
-
FUNCTION_NAME_COL_NAME = "FUNCTION_NAME"
|
27
|
-
TASK_COL_NAME = "TASK"
|
28
|
-
MONITORING_ENABLED_COL_NAME = "IS_ENABLED"
|
29
|
-
TIMESTAMP_COL_NAME_COL_NAME = "TIMESTAMP_COLUMN_NAME"
|
30
|
-
PREDICTION_COL_NAMES_COL_NAME = "PREDICTION_COLUMN_NAMES"
|
31
|
-
LABEL_COL_NAMES_COL_NAME = "LABEL_COLUMN_NAMES"
|
32
|
-
ID_COL_NAMES_COL_NAME = "ID_COLUMN_NAMES"
|
33
|
-
|
34
17
|
|
35
18
|
def _build_sql_list_from_columns(columns: List[sql_identifier.SqlIdentifier]) -> str:
|
36
19
|
sql_list = ", ".join([f"'{column}'" for column in columns])
|
@@ -146,19 +129,6 @@ class ModelMonitorSQLClient:
|
|
146
129
|
.validate()
|
147
130
|
)
|
148
131
|
|
149
|
-
def _validate_unique_columns(
|
150
|
-
self,
|
151
|
-
timestamp_column: sql_identifier.SqlIdentifier,
|
152
|
-
id_columns: List[sql_identifier.SqlIdentifier],
|
153
|
-
prediction_columns: List[sql_identifier.SqlIdentifier],
|
154
|
-
label_columns: List[sql_identifier.SqlIdentifier],
|
155
|
-
) -> None:
|
156
|
-
all_columns = [*id_columns, *prediction_columns, *label_columns, timestamp_column]
|
157
|
-
num_all_columns = len(all_columns)
|
158
|
-
num_unique_columns = len(set(all_columns))
|
159
|
-
if num_all_columns != num_unique_columns:
|
160
|
-
raise ValueError("Column names must be unique across id, timestamp, prediction, and label columns.")
|
161
|
-
|
162
132
|
def validate_existence_by_name(
|
163
133
|
self,
|
164
134
|
*,
|
@@ -244,125 +214,6 @@ class ModelMonitorSQLClient:
|
|
244
214
|
if not all([column_name in source_column_schema for column_name in id_columns]):
|
245
215
|
raise ValueError(f"ID column(s): {id_columns} do not exist in source.")
|
246
216
|
|
247
|
-
def _validate_timestamp_column_type(
|
248
|
-
self, table_schema: Mapping[str, types.DataType], timestamp_column: sql_identifier.SqlIdentifier
|
249
|
-
) -> None:
|
250
|
-
"""Ensures columns have the same type.
|
251
|
-
|
252
|
-
Args:
|
253
|
-
table_schema: Dictionary of column names and types in the source table.
|
254
|
-
timestamp_column: Name of the timestamp column.
|
255
|
-
|
256
|
-
Raises:
|
257
|
-
ValueError: If the timestamp column is not of type TimestampType.
|
258
|
-
"""
|
259
|
-
if not isinstance(table_schema[timestamp_column], types.TimestampType):
|
260
|
-
raise ValueError(
|
261
|
-
f"Timestamp column: {timestamp_column} must be TimestampType. "
|
262
|
-
f"Found: {table_schema[timestamp_column]}"
|
263
|
-
)
|
264
|
-
|
265
|
-
def _validate_id_columns_types(
|
266
|
-
self, table_schema: Mapping[str, types.DataType], id_columns: List[sql_identifier.SqlIdentifier]
|
267
|
-
) -> None:
|
268
|
-
"""Ensures id columns have the correct type.
|
269
|
-
|
270
|
-
Args:
|
271
|
-
table_schema: Dictionary of column names and types in the source table.
|
272
|
-
id_columns: List of id column names.
|
273
|
-
|
274
|
-
Raises:
|
275
|
-
ValueError: If the id column is not of type StringType.
|
276
|
-
"""
|
277
|
-
id_column_types = list({table_schema[column_name] for column_name in id_columns})
|
278
|
-
all_id_columns_string = all([isinstance(column_type, types.StringType) for column_type in id_column_types])
|
279
|
-
if not all_id_columns_string:
|
280
|
-
raise ValueError(f"Id columns must all be StringType. Found: {id_column_types}")
|
281
|
-
|
282
|
-
def _validate_prediction_columns_types(
|
283
|
-
self, table_schema: Mapping[str, types.DataType], prediction_columns: List[sql_identifier.SqlIdentifier]
|
284
|
-
) -> None:
|
285
|
-
"""Ensures prediction columns have the same type.
|
286
|
-
|
287
|
-
Args:
|
288
|
-
table_schema: Dictionary of column names and types in the source table.
|
289
|
-
prediction_columns: List of prediction column names.
|
290
|
-
|
291
|
-
Raises:
|
292
|
-
ValueError: If the prediction columns do not share the same type.
|
293
|
-
"""
|
294
|
-
|
295
|
-
prediction_column_types = {table_schema[column_name] for column_name in prediction_columns}
|
296
|
-
if len(prediction_column_types) > 1:
|
297
|
-
raise ValueError(f"Prediction column types must be the same. Found: {prediction_column_types}")
|
298
|
-
|
299
|
-
def _validate_label_columns_types(
|
300
|
-
self,
|
301
|
-
table_schema: Mapping[str, types.DataType],
|
302
|
-
label_columns: List[sql_identifier.SqlIdentifier],
|
303
|
-
) -> None:
|
304
|
-
"""Ensures label columns have the same type, and the correct type for the score type.
|
305
|
-
|
306
|
-
Args:
|
307
|
-
table_schema: Dictionary of column names and types in the source table.
|
308
|
-
label_columns: List of label column names.
|
309
|
-
|
310
|
-
Raises:
|
311
|
-
ValueError: If the label columns do not share the same type.
|
312
|
-
"""
|
313
|
-
label_column_types = {table_schema[column_name] for column_name in label_columns}
|
314
|
-
if len(label_column_types) > 1:
|
315
|
-
raise ValueError(f"Label column types must be the same. Found: {label_column_types}")
|
316
|
-
|
317
|
-
def _validate_column_types(
|
318
|
-
self,
|
319
|
-
*,
|
320
|
-
table_schema: Mapping[str, types.DataType],
|
321
|
-
timestamp_column: sql_identifier.SqlIdentifier,
|
322
|
-
id_columns: List[sql_identifier.SqlIdentifier],
|
323
|
-
prediction_columns: List[sql_identifier.SqlIdentifier],
|
324
|
-
label_columns: List[sql_identifier.SqlIdentifier],
|
325
|
-
) -> None:
|
326
|
-
"""Ensures columns have the expected type.
|
327
|
-
|
328
|
-
Args:
|
329
|
-
table_schema: Dictionary of column names and types in the source table.
|
330
|
-
timestamp_column: Name of the timestamp column.
|
331
|
-
id_columns: List of id column names.
|
332
|
-
prediction_columns: List of prediction column names.
|
333
|
-
label_columns: List of label column names.
|
334
|
-
"""
|
335
|
-
self._validate_timestamp_column_type(table_schema, timestamp_column)
|
336
|
-
self._validate_id_columns_types(table_schema, id_columns)
|
337
|
-
self._validate_prediction_columns_types(table_schema, prediction_columns)
|
338
|
-
self._validate_label_columns_types(table_schema, label_columns)
|
339
|
-
# TODO(SNOW-1646693): Validate label makes sense with model task
|
340
|
-
|
341
|
-
def _validate_source_table_features_shape(
|
342
|
-
self,
|
343
|
-
table_schema: Mapping[str, types.DataType],
|
344
|
-
special_columns: Set[sql_identifier.SqlIdentifier],
|
345
|
-
model_function: model_manifest_schema.ModelFunctionInfo,
|
346
|
-
) -> None:
|
347
|
-
table_schema_without_special_columns = {
|
348
|
-
k: v for k, v in table_schema.items() if sql_identifier.SqlIdentifier(k) not in special_columns
|
349
|
-
}
|
350
|
-
schema_column_types_to_count: typing.Counter[types.DataType] = Counter()
|
351
|
-
for column_type in table_schema_without_special_columns.values():
|
352
|
-
schema_column_types_to_count[column_type] += 1
|
353
|
-
|
354
|
-
inputs = model_function["signature"].inputs
|
355
|
-
function_input_types = [input.as_snowpark_type() for input in inputs]
|
356
|
-
function_input_types_to_count: typing.Counter[types.DataType] = Counter()
|
357
|
-
for function_input_type in function_input_types:
|
358
|
-
function_input_types_to_count[function_input_type] += 1
|
359
|
-
|
360
|
-
if function_input_types_to_count != schema_column_types_to_count:
|
361
|
-
raise ValueError(
|
362
|
-
"Model function input types do not match the source table input columns types. "
|
363
|
-
f"Model function expected: {inputs} but got {table_schema_without_special_columns}"
|
364
|
-
)
|
365
|
-
|
366
217
|
def validate_source(
|
367
218
|
self,
|
368
219
|
*,
|
@@ -395,22 +246,6 @@ class ModelMonitorSQLClient:
|
|
395
246
|
id_columns=id_columns,
|
396
247
|
)
|
397
248
|
|
398
|
-
def delete_monitor_metadata(
|
399
|
-
self,
|
400
|
-
name: str,
|
401
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
402
|
-
) -> None:
|
403
|
-
"""Delete the row in the metadata table corresponding to the given monitor name.
|
404
|
-
|
405
|
-
Args:
|
406
|
-
name: Name of the model monitor whose metadata should be deleted.
|
407
|
-
statement_params: Optional set of statement_params to include with query.
|
408
|
-
"""
|
409
|
-
self._sql_client._session.sql(
|
410
|
-
f"""DELETE FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}
|
411
|
-
WHERE {MONITOR_NAME_COL_NAME} = '{name}'""",
|
412
|
-
).collect(statement_params=statement_params)
|
413
|
-
|
414
249
|
def _alter_monitor(
|
415
250
|
self,
|
416
251
|
operation: str,
|
@@ -14,15 +14,6 @@ from snowflake.snowpark import session
|
|
14
14
|
class ModelMonitorManager:
|
15
15
|
"""Class to manage internal operations for Model Monitor workflows."""
|
16
16
|
|
17
|
-
def _validate_task_from_model_version(
|
18
|
-
self,
|
19
|
-
model_version: model_version_impl.ModelVersion,
|
20
|
-
) -> type_hints.Task:
|
21
|
-
task = model_version.get_model_task()
|
22
|
-
if task == type_hints.Task.UNKNOWN:
|
23
|
-
raise ValueError("Registry model must be logged with task in order to be monitored.")
|
24
|
-
return task
|
25
|
-
|
26
17
|
def __init__(
|
27
18
|
self,
|
28
19
|
session: session.Session,
|
@@ -51,6 +42,15 @@ class ModelMonitorManager:
|
|
51
42
|
schema_name=self._schema_name,
|
52
43
|
)
|
53
44
|
|
45
|
+
def _validate_task_from_model_version(
    self,
    model_version: model_version_impl.ModelVersion,
) -> type_hints.Task:
    """Look up the model version's task, rejecting models that were logged without one.

    Args:
        model_version: The registry model version to inspect.

    Returns:
        The task reported by ``model_version.get_model_task()``.

    Raises:
        ValueError: If the reported task is ``type_hints.Task.UNKNOWN``.
    """
    model_task = model_version.get_model_task()
    if model_task != type_hints.Task.UNKNOWN:
        return model_task
    raise ValueError("Registry model must be logged with task in order to be monitored.")
|
53
|
+
|
54
54
|
def _validate_model_function_from_model_version(
|
55
55
|
self, function: str, model_version: model_version_impl.ModelVersion
|
56
56
|
) -> None:
|
@@ -6,23 +6,49 @@ from snowflake.ml.model._client.model import model_version_impl
|
|
6
6
|
|
7
7
|
@dataclass
class ModelMonitorSourceConfig:
    """Configuration for the source of data to be monitored.

    ``source``, ``timestamp_column`` and ``id_columns`` are required; all other fields are optional and
    default to ``None``. Column fields name columns of the ``source`` table or view.
    """

    source: str
    """Name of the table or view containing monitoring data."""

    timestamp_column: str
    """Name of the column in the source containing the timestamp."""

    id_columns: List[str]
    """List of columns in the source containing unique identifiers."""

    prediction_score_columns: Optional[List[str]] = None
    """List of columns in the source containing prediction scores.
    Can be regression scores for regression models and probability scores for classification models."""

    prediction_class_columns: Optional[List[str]] = None
    """List of columns in the source containing prediction classes for classification models."""

    actual_score_columns: Optional[List[str]] = None
    """List of columns in the source containing actual scores."""

    actual_class_columns: Optional[List[str]] = None
    """List of columns in the source containing actual classes for classification models."""

    baseline: Optional[str] = None
    """Optional name of a table containing the baseline data."""
|
17
35
|
|
18
36
|
|
19
37
|
@dataclass
class ModelMonitorConfig:
    """Configuration for the Model Monitor.

    ``model_version``, ``model_function_name`` and ``background_compute_warehouse_name`` are required.
    ``refresh_interval`` and ``aggregation_window`` are interval strings defaulting to ``"1 hour"`` and
    ``"1 day"`` respectively.
    """

    model_version: model_version_impl.ModelVersion
    """Model version to monitor."""

    model_function_name: str
    """Name of the Python model function to monitor."""

    background_compute_warehouse_name: str
    """Name of the warehouse to use for background compute."""

    refresh_interval: str = "1 hour"
    """Interval at which to refresh the monitoring data."""

    aggregation_window: str = "1 day"
    """Window for aggregating monitoring data."""
|
@@ -1,5 +1,7 @@
|
|
1
|
+
from snowflake import snowpark
|
1
2
|
from snowflake.ml._internal import telemetry
|
2
3
|
from snowflake.ml._internal.utils import sql_identifier
|
4
|
+
from snowflake.ml.monitoring import model_monitor_version
|
3
5
|
from snowflake.ml.monitoring._client import model_monitor_sql_client
|
4
6
|
|
5
7
|
|
@@ -9,13 +11,8 @@ class ModelMonitor:
|
|
9
11
|
name: sql_identifier.SqlIdentifier
|
10
12
|
_model_monitor_client: model_monitor_sql_client.ModelMonitorSQLClient
|
11
13
|
|
12
|
-
statement_params = telemetry.get_statement_params(
|
13
|
-
telemetry.TelemetryProject.MLOPS.value,
|
14
|
-
telemetry.TelemetrySubProject.MONITORING.value,
|
15
|
-
)
|
16
|
-
|
17
14
|
def __init__(self) -> None:
|
18
|
-
raise RuntimeError("
|
15
|
+
raise RuntimeError("Model Monitor's initializer is not meant to be used.")
|
19
16
|
|
20
17
|
@classmethod
|
21
18
|
def _ref(
|
@@ -28,10 +25,28 @@ class ModelMonitor:
|
|
28
25
|
self._model_monitor_client = model_monitor_client
|
29
26
|
return self
|
30
27
|
|
28
|
+
@telemetry.send_api_usage_telemetry(
|
29
|
+
project=telemetry.TelemetryProject.MLOPS.value,
|
30
|
+
subproject=telemetry.TelemetrySubProject.MONITORING.value,
|
31
|
+
)
|
32
|
+
@snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
|
31
33
|
def suspend(self) -> None:
|
32
|
-
"""Suspend
|
33
|
-
|
34
|
-
|
34
|
+
"""Suspend the Model Monitor"""
|
35
|
+
statement_params = telemetry.get_statement_params(
|
36
|
+
telemetry.TelemetryProject.MLOPS.value,
|
37
|
+
telemetry.TelemetrySubProject.MONITORING.value,
|
38
|
+
)
|
39
|
+
self._model_monitor_client.suspend_monitor(self.name, statement_params=statement_params)
|
40
|
+
|
41
|
+
@telemetry.send_api_usage_telemetry(
|
42
|
+
project=telemetry.TelemetryProject.MLOPS.value,
|
43
|
+
subproject=telemetry.TelemetrySubProject.MONITORING.value,
|
44
|
+
)
|
45
|
+
@snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
|
35
46
|
def resume(self) -> None:
|
36
|
-
"""Resume
|
37
|
-
|
47
|
+
"""Resume the Model Monitor"""
|
48
|
+
statement_params = telemetry.get_statement_params(
|
49
|
+
telemetry.TelemetryProject.MLOPS.value,
|
50
|
+
telemetry.TelemetrySubProject.MONITORING.value,
|
51
|
+
)
|
52
|
+
self._model_monitor_client.resume_monitor(self.name, statement_params=statement_params)
|
@@ -388,15 +388,15 @@ class Registry:
|
|
388
388
|
source_config: model_monitor_config.ModelMonitorSourceConfig,
|
389
389
|
model_monitor_config: model_monitor_config.ModelMonitorConfig,
|
390
390
|
) -> model_monitor.ModelMonitor:
|
391
|
-
"""Add a Model Monitor to the Registry
|
391
|
+
"""Add a Model Monitor to the Registry.
|
392
392
|
|
393
393
|
Args:
|
394
|
-
name: Name of Model Monitor to create
|
395
|
-
source_config: Configuration options of table for
|
396
|
-
model_monitor_config: Configuration options of
|
394
|
+
name: Name of Model Monitor to create.
|
395
|
+
source_config: Configuration options of table for Model Monitor.
|
396
|
+
model_monitor_config: Configuration options of Model Monitor.
|
397
397
|
|
398
398
|
Returns:
|
399
|
-
The newly added
|
399
|
+
The newly added Model Monitor object.
|
400
400
|
|
401
401
|
Raises:
|
402
402
|
ValueError: If monitoring is not enabled in the Registry.
|
@@ -407,16 +407,16 @@ class Registry:
|
|
407
407
|
|
408
408
|
@overload
|
409
409
|
def get_monitor(self, model_version: model_version_impl.ModelVersion) -> model_monitor.ModelMonitor:
|
410
|
-
"""Get a Model Monitor on a
|
410
|
+
"""Get a Model Monitor on a Model Version from the Registry.
|
411
411
|
|
412
412
|
Args:
|
413
|
-
model_version:
|
413
|
+
model_version: Model Version for which to retrieve the Model Monitor.
|
414
414
|
"""
|
415
415
|
...
|
416
416
|
|
417
417
|
@overload
|
418
418
|
def get_monitor(self, name: str) -> model_monitor.ModelMonitor:
|
419
|
-
"""Get a Model Monitor from the Registry
|
419
|
+
"""Get a Model Monitor by name from the Registry.
|
420
420
|
|
421
421
|
Args:
|
422
422
|
name: Name of Model Monitor to retrieve.
|
@@ -431,14 +431,14 @@ class Registry:
|
|
431
431
|
def get_monitor(
|
432
432
|
self, *, name: Optional[str] = None, model_version: Optional[model_version_impl.ModelVersion] = None
|
433
433
|
) -> model_monitor.ModelMonitor:
|
434
|
-
"""Get a Model Monitor from the Registry
|
434
|
+
"""Get a Model Monitor from the Registry.
|
435
435
|
|
436
436
|
Args:
|
437
437
|
name: Name of Model Monitor to retrieve.
|
438
|
-
model_version:
|
438
|
+
model_version: Model Version for which to retrieve the Model Monitor.
|
439
439
|
|
440
440
|
Returns:
|
441
|
-
The fetched
|
441
|
+
The fetched Model Monitor.
|
442
442
|
|
443
443
|
Raises:
|
444
444
|
ValueError: If monitoring is not enabled in the Registry.
|
@@ -476,7 +476,7 @@ class Registry:
|
|
476
476
|
)
|
477
477
|
@snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
|
478
478
|
def delete_monitor(self, name: str) -> None:
|
479
|
-
"""Delete a Model Monitor from the Registry
|
479
|
+
"""Delete a Model Monitor by name from the Registry.
|
480
480
|
|
481
481
|
Args:
|
482
482
|
name: Name of the Model Monitor to delete.
|
@@ -0,0 +1,75 @@
|
|
1
|
+
import http
|
2
|
+
import logging
|
3
|
+
from datetime import timedelta
|
4
|
+
from typing import Dict, Optional
|
5
|
+
|
6
|
+
import requests
|
7
|
+
from cryptography.hazmat.primitives.asymmetric import types
|
8
|
+
from requests import auth
|
9
|
+
|
10
|
+
from snowflake.ml._internal.utils import jwt_generator
|
11
|
+
|
12
|
+
logger = logging.getLogger(__name__)
|
13
|
+
_JWT_TOKEN_CACHE: Dict[str, Dict[int, str]] = {}
|
14
|
+
|
15
|
+
|
16
|
+
def get_jwt_token_generator(
|
17
|
+
account: str,
|
18
|
+
user: str,
|
19
|
+
private_key: types.PRIVATE_KEY_TYPES,
|
20
|
+
lifetime: Optional[timedelta] = None,
|
21
|
+
renewal_delay: Optional[timedelta] = None,
|
22
|
+
) -> jwt_generator.JWTGenerator:
|
23
|
+
return jwt_generator.JWTGenerator(account, user, private_key, lifetime=lifetime, renewal_delay=renewal_delay)
|
24
|
+
|
25
|
+
|
26
|
+
def _get_snowflake_token_by_jwt(
|
27
|
+
jwt_token_generator: jwt_generator.JWTGenerator,
|
28
|
+
account: Optional[str] = None,
|
29
|
+
role: Optional[str] = None,
|
30
|
+
endpoint: Optional[str] = None,
|
31
|
+
snowflake_account_url: Optional[str] = None,
|
32
|
+
) -> str:
|
33
|
+
scope_role = f"session:role:{role}" if role is not None else None
|
34
|
+
scope = " ".join(filter(None, [scope_role, endpoint]))
|
35
|
+
data = {
|
36
|
+
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
|
37
|
+
"scope": scope or None,
|
38
|
+
"assertion": jwt_token_generator.get_token(),
|
39
|
+
}
|
40
|
+
account = account or jwt_token_generator.account
|
41
|
+
url = f"https://{account}.snowflakecomputing.com/oauth/token"
|
42
|
+
if snowflake_account_url:
|
43
|
+
url = f"{snowflake_account_url}/oauth/token"
|
44
|
+
|
45
|
+
cache_key = hash(frozenset(data.items()))
|
46
|
+
if url in _JWT_TOKEN_CACHE:
|
47
|
+
if cache_key in _JWT_TOKEN_CACHE[url]:
|
48
|
+
return _JWT_TOKEN_CACHE[url][cache_key]
|
49
|
+
else:
|
50
|
+
_JWT_TOKEN_CACHE[url] = {}
|
51
|
+
|
52
|
+
response = requests.post(url, data=data)
|
53
|
+
if response.status_code != http.HTTPStatus.OK:
|
54
|
+
raise RuntimeError(f"Failed to get snowflake token: {response.status_code} {response.content!r}")
|
55
|
+
auth_token = response.text
|
56
|
+
_JWT_TOKEN_CACHE[url][cache_key] = auth_token
|
57
|
+
return auth_token
|
58
|
+
|
59
|
+
|
60
|
+
class SnowflakeJWTTokenAuth(auth.AuthBase):
|
61
|
+
def __init__(
|
62
|
+
self,
|
63
|
+
jwt_token_generator: jwt_generator.JWTGenerator,
|
64
|
+
account: Optional[str] = None,
|
65
|
+
role: Optional[str] = None,
|
66
|
+
endpoint: Optional[str] = None,
|
67
|
+
snowflake_account_url: Optional[str] = None,
|
68
|
+
) -> None:
|
69
|
+
self.snowflake_token = _get_snowflake_token_by_jwt(
|
70
|
+
jwt_token_generator, account, role, endpoint, snowflake_account_url
|
71
|
+
)
|
72
|
+
|
73
|
+
def __call__(self, r: requests.PreparedRequest) -> requests.PreparedRequest:
|
74
|
+
r.headers["Authorization"] = f'Snowflake Token="{self.snowflake_token}"'
|
75
|
+
return r
|
snowflake/ml/version.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
VERSION="1.7.
|
1
|
+
VERSION="1.7.2"
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: snowflake-ml-python
|
3
|
-
Version: 1.7.
|
3
|
+
Version: 1.7.2
|
4
4
|
Summary: The machine learning client library that is used for interacting with Snowflake to build machine learning solutions.
|
5
5
|
Author-email: "Snowflake, Inc" <support@snowflake.com>
|
6
6
|
License:
|
@@ -232,61 +232,62 @@ Classifier: Topic :: Scientific/Engineering :: Information Analysis
|
|
232
232
|
Requires-Python: <3.12,>=3.9
|
233
233
|
Description-Content-Type: text/markdown
|
234
234
|
License-File: LICENSE.txt
|
235
|
-
Requires-Dist: absl-py
|
236
|
-
Requires-Dist: anyio
|
237
|
-
Requires-Dist: cachetools
|
238
|
-
Requires-Dist: cloudpickle
|
235
|
+
Requires-Dist: absl-py<2,>=0.15
|
236
|
+
Requires-Dist: anyio<4,>=3.5.0
|
237
|
+
Requires-Dist: cachetools<6,>=3.1.1
|
238
|
+
Requires-Dist: cloudpickle>=2.0.0
|
239
239
|
Requires-Dist: cryptography
|
240
|
-
Requires-Dist: fsspec[http]
|
241
|
-
Requires-Dist:
|
242
|
-
Requires-Dist: numpy
|
243
|
-
Requires-Dist: packaging
|
244
|
-
Requires-Dist: pandas
|
240
|
+
Requires-Dist: fsspec[http]<2024,>=2022.11
|
241
|
+
Requires-Dist: importlib_resources<7,>=6.1.1
|
242
|
+
Requires-Dist: numpy<2,>=1.23
|
243
|
+
Requires-Dist: packaging<25,>=20.9
|
244
|
+
Requires-Dist: pandas<3,>=1.0.0
|
245
245
|
Requires-Dist: pyarrow
|
246
|
-
Requires-Dist:
|
247
|
-
Requires-Dist:
|
248
|
-
Requires-Dist:
|
249
|
-
Requires-Dist:
|
250
|
-
Requires-Dist:
|
251
|
-
Requires-Dist:
|
252
|
-
Requires-Dist:
|
253
|
-
Requires-Dist: snowflake-
|
254
|
-
Requires-Dist:
|
255
|
-
Requires-Dist:
|
256
|
-
Requires-Dist:
|
246
|
+
Requires-Dist: pyjwt<3,>=2.0.0
|
247
|
+
Requires-Dist: pytimeparse<2,>=1.1.8
|
248
|
+
Requires-Dist: pyyaml<7,>=6.0
|
249
|
+
Requires-Dist: retrying<2,>=1.3.3
|
250
|
+
Requires-Dist: s3fs<2024,>=2022.11
|
251
|
+
Requires-Dist: scikit-learn<1.6,>=1.4
|
252
|
+
Requires-Dist: scipy<2,>=1.9
|
253
|
+
Requires-Dist: snowflake-connector-python[pandas]<4,>=3.5.0
|
254
|
+
Requires-Dist: snowflake-snowpark-python<2,>=1.17.0
|
255
|
+
Requires-Dist: sqlparse<1,>=0.4
|
256
|
+
Requires-Dist: typing-extensions<5,>=4.1.0
|
257
|
+
Requires-Dist: xgboost<3,>=1.7.3
|
257
258
|
Provides-Extra: all
|
258
|
-
Requires-Dist: catboost
|
259
|
-
Requires-Dist: lightgbm
|
260
|
-
Requires-Dist: mlflow
|
261
|
-
Requires-Dist: peft
|
262
|
-
Requires-Dist: sentence-transformers
|
263
|
-
Requires-Dist: sentencepiece
|
264
|
-
Requires-Dist: shap
|
265
|
-
Requires-Dist: tensorflow
|
266
|
-
Requires-Dist: tokenizers
|
267
|
-
Requires-Dist: torch
|
268
|
-
Requires-Dist: torchdata
|
269
|
-
Requires-Dist: transformers
|
259
|
+
Requires-Dist: catboost<2,>=1.2.0; extra == "all"
|
260
|
+
Requires-Dist: lightgbm<5,>=4.1.0; extra == "all"
|
261
|
+
Requires-Dist: mlflow<2.4,>=2.1.0; extra == "all"
|
262
|
+
Requires-Dist: peft<1,>=0.5.0; extra == "all"
|
263
|
+
Requires-Dist: sentence-transformers<3,>=2.2.2; extra == "all"
|
264
|
+
Requires-Dist: sentencepiece<1,>=0.1.95; extra == "all"
|
265
|
+
Requires-Dist: shap<1,>=0.46.0; extra == "all"
|
266
|
+
Requires-Dist: tensorflow<3,>=2.10; extra == "all"
|
267
|
+
Requires-Dist: tokenizers<1,>=0.10; extra == "all"
|
268
|
+
Requires-Dist: torch<2.3.0,>=2.0.1; extra == "all"
|
269
|
+
Requires-Dist: torchdata<1,>=0.4; extra == "all"
|
270
|
+
Requires-Dist: transformers<5,>=4.32.1; extra == "all"
|
270
271
|
Provides-Extra: catboost
|
271
|
-
Requires-Dist: catboost
|
272
|
+
Requires-Dist: catboost<2,>=1.2.0; extra == "catboost"
|
272
273
|
Provides-Extra: lightgbm
|
273
|
-
Requires-Dist: lightgbm
|
274
|
+
Requires-Dist: lightgbm<5,>=4.1.0; extra == "lightgbm"
|
274
275
|
Provides-Extra: llm
|
275
|
-
Requires-Dist: peft
|
276
|
+
Requires-Dist: peft<1,>=0.5.0; extra == "llm"
|
276
277
|
Provides-Extra: mlflow
|
277
|
-
Requires-Dist: mlflow
|
278
|
+
Requires-Dist: mlflow<2.4,>=2.1.0; extra == "mlflow"
|
278
279
|
Provides-Extra: shap
|
279
|
-
Requires-Dist: shap
|
280
|
+
Requires-Dist: shap<1,>=0.46.0; extra == "shap"
|
280
281
|
Provides-Extra: tensorflow
|
281
|
-
Requires-Dist: tensorflow
|
282
|
+
Requires-Dist: tensorflow<3,>=2.10; extra == "tensorflow"
|
282
283
|
Provides-Extra: torch
|
283
|
-
Requires-Dist: torch
|
284
|
-
Requires-Dist: torchdata
|
284
|
+
Requires-Dist: torch<2.3.0,>=2.0.1; extra == "torch"
|
285
|
+
Requires-Dist: torchdata<1,>=0.4; extra == "torch"
|
285
286
|
Provides-Extra: transformers
|
286
|
-
Requires-Dist: sentence-transformers
|
287
|
-
Requires-Dist: sentencepiece
|
288
|
-
Requires-Dist: tokenizers
|
289
|
-
Requires-Dist: transformers
|
287
|
+
Requires-Dist: sentence-transformers<3,>=2.2.2; extra == "transformers"
|
288
|
+
Requires-Dist: sentencepiece<1,>=0.1.95; extra == "transformers"
|
289
|
+
Requires-Dist: tokenizers<1,>=0.10; extra == "transformers"
|
290
|
+
Requires-Dist: transformers<5,>=4.32.1; extra == "transformers"
|
290
291
|
|
291
292
|
# Snowpark ML
|
292
293
|
|
@@ -302,7 +303,7 @@ and deployment process, and includes two key components.
|
|
302
303
|
|
303
304
|
### Snowpark ML Development
|
304
305
|
|
305
|
-
[Snowpark ML Development](https://docs.snowflake.com/en/developer-guide/snowpark-ml/index#
|
306
|
+
[Snowpark ML Development](https://docs.snowflake.com/en/developer-guide/snowpark-ml/index#ml-modeling)
|
306
307
|
provides a collection of python APIs enabling efficient ML model development directly in Snowflake:
|
307
308
|
|
308
309
|
1. Modeling API (`snowflake.ml.modeling`) for data preprocessing, feature engineering and model training in Snowflake.
|
@@ -316,14 +317,21 @@ their native data loader formats.
|
|
316
317
|
1. FileSet API: FileSet provides a Python fsspec-compliant API for materializing data into a Snowflake internal stage
|
317
318
|
from a query or Snowpark Dataframe along with a number of convenience APIs.
|
318
319
|
|
319
|
-
###
|
320
|
+
### Snowflake MLOps
|
320
321
|
|
321
|
-
|
322
|
-
the Snowpark ML Development API, and provides
|
322
|
+
Snowflake MLOps contains a suite of tools and objects that support the ML development cycle. It complements
|
323
|
+
the Snowpark ML Development API and provides end-to-end support, from development to deployment, within Snowflake.
|
323
324
|
Currently, the API consists of:
|
324
325
|
|
325
|
-
1. Registry: A python API
|
326
|
-
|
326
|
+
1. [Registry](https://docs.snowflake.com/en/developer-guide/snowpark-ml/index#snowflake-model-registry): A python API
|
327
|
+
allows secure deployment and management of models in Snowflake, supporting models trained both inside and outside of
|
328
|
+
Snowflake.
|
329
|
+
2. [Feature Store](https://docs.snowflake.com/en/developer-guide/snowpark-ml/index#snowflake-feature-store): A fully
|
330
|
+
integrated solution for defining, managing, storing and discovering ML features derived from your data. The
|
331
|
+
Snowflake Feature Store supports automated, incremental refresh from batch and streaming data sources, so that
|
332
|
+
feature pipelines need to be defined only once to be continuously updated with new data.
|
333
|
+
3. [Datasets](https://docs.snowflake.com/developer-guide/snowflake-ml/overview#snowflake-datasets): Datasets provide an
|
334
|
+
immutable, versioned snapshot of your data suitable for ingestion by your machine learning models.
|
327
335
|
|
328
336
|
## Getting started
|
329
337
|
|
@@ -371,9 +379,39 @@ conda install \
|
|
371
379
|
Note that until a `snowflake-ml-python` package version is available in the official Snowflake conda channel, there may
|
372
380
|
be compatibility issues. Server-side functionality that `snowflake-ml-python` depends on may not yet be released.
|
373
381
|
|
382
|
+
### Verifying the package
|
383
|
+
|
384
|
+
1. Install cosign.
|
385
|
+
This example uses the Go installation method: [installing-cosign-with-go](https://edu.chainguard.dev/open-source/sigstore/cosign/how-to-install-cosign/#installing-cosign-with-go).
|
386
|
+
1. Download the file from the repository like [pypi](https://pypi.org/project/snowflake-ml-python/#files).
|
387
|
+
1. Download the signature files from the [release tag](https://github.com/snowflakedb/snowflake-ml-python/releases/tag/1.7.0).
|
388
|
+
1. Verify the signature on projects signed using the Jenkins job:
|
389
|
+
|
390
|
+
```sh
|
391
|
+
cosign verify-blob snowflake_ml_python-1.7.0.tar.gz --key snowflake-ml-python-1.7.0.pub --signature resources.linux.snowflake_ml_python-1.7.0.tar.gz.sig
|
392
|
+
|
393
|
+
cosign verify-blob snowflake_ml_python-1.7.0.tar.gz --key snowflake-ml-python-1.7.0.pub --signature resources.linux.snowflake_ml_python-1.7.0
|
394
|
+
```
|
395
|
+
|
396
|
+
NOTE: Version 1.7.0 is used as an example here. Please choose the latest version.
|
397
|
+
|
374
398
|
# Release History
|
375
399
|
|
376
|
-
## 1.7.
|
400
|
+
## 1.7.2
|
401
|
+
|
402
|
+
### Bug Fixes
|
403
|
+
|
404
|
+
- Model Explainability: Fix issue where explain was enabled for a scikit-learn pipeline
|
405
|
+
whose task is UNKNOWN and fails later when invoked.
|
406
|
+
|
407
|
+
### Behavior Changes
|
408
|
+
|
409
|
+
### New Features
|
410
|
+
|
411
|
+
- Registry: Support asynchronous model inference service creation with the `block` option
|
412
|
+
in `ModelVersion.create_service()` set to True by default.
|
413
|
+
|
414
|
+
## 1.7.1 (2024-11-05)
|
377
415
|
|
378
416
|
### Bug Fixes
|
379
417
|
|
@@ -10,7 +10,7 @@ snowflake/cortex/_sse_client.py,sha256=sLYgqAfTOPADCnaWH2RWAJi8KbU_7gSRsTUDcDD5T
|
|
10
10
|
snowflake/cortex/_summarize.py,sha256=bwpFBzBGmNQSoJqKs3IB5wASjAREnC5ZnViSuZK5IrU,1059
|
11
11
|
snowflake/cortex/_translate.py,sha256=69YUps6mnhzVdubdU_H0IfUAlbBwF9OPemFEQ34P-ts,1404
|
12
12
|
snowflake/cortex/_util.py,sha256=cwRGgrcUo3E05ZaIDT9436vXLQ7GfuBVAjR0QeQ2bDE,3320
|
13
|
-
snowflake/ml/version.py,sha256=
|
13
|
+
snowflake/ml/version.py,sha256=wJaJaqPpO6Ic3Pl_5e81zlGKYqi1rf5q8V10jTUEDjA,16
|
14
14
|
snowflake/ml/_internal/env.py,sha256=kCrJTRnqQ97VGUVI1cWUPD8HuBWeL5vOOtwUR0NB9Mg,161
|
15
15
|
snowflake/ml/_internal/env_utils.py,sha256=J_jitp8jvDoC3a79EbMSDatFRYw-HiXaI9vR81bhtU8,28075
|
16
16
|
snowflake/ml/_internal/file_utils.py,sha256=OyXHv-UcItiip1YgLnab6etonUQkYuyDtmplZA0CaoU,13622
|
@@ -36,6 +36,7 @@ snowflake/ml/_internal/utils/db_utils.py,sha256=HBAY0-XHzCP4ai5q3Yqd8O19Ar_Q9J3x
|
|
36
36
|
snowflake/ml/_internal/utils/formatting.py,sha256=PswZ6Xas7sx3Ok1MBLoH2o7nfXOxaJqpUPg_UqXrQb8,3676
|
37
37
|
snowflake/ml/_internal/utils/identifier.py,sha256=fUYXjXKXAkjLUZpomneMHo2wR4_ZNP4ak-5OJxeUS-g,12467
|
38
38
|
snowflake/ml/_internal/utils/import_utils.py,sha256=iUIROZdiTGy73UCGpG0N-dKtK54H0ymNVge_QNQYY3A,3220
|
39
|
+
snowflake/ml/_internal/utils/jwt_generator.py,sha256=bj7Ltnw68WjRcxtV9t5xrTRvV5ETnvovB-o3Y8QWNBg,5357
|
39
40
|
snowflake/ml/_internal/utils/parallelize.py,sha256=Q6_-P2t4DoYNO8DyC1kOl7H3qNL-bUK6EgtlQ_b5ThY,4534
|
40
41
|
snowflake/ml/_internal/utils/pkg_version_utils.py,sha256=FwdLHFhxi3CAQQduGjFavEBmkD9Ra6ZTkt6Eub-WoSA,5168
|
41
42
|
snowflake/ml/_internal/utils/query_result_checker.py,sha256=h1nbUImdB9lSNCON3uIA0xCm8_JrS-TE-jQXJJs9WfU,10668
|
@@ -101,17 +102,17 @@ snowflake/ml/model/custom_model.py,sha256=O60mjz2Vy8A0Rt3obq43zBT3BxkU7CIcN0AkHs
|
|
101
102
|
snowflake/ml/model/model_signature.py,sha256=gZnZPs9zTCYkeFoiQzoGUQYZMydYjzH-4xPTzfqt4hU,30496
|
102
103
|
snowflake/ml/model/type_hints.py,sha256=9GPwEuG6B6GSWOXdOy8B1Swz6yDngL865yEtJMd0v1U,8883
|
103
104
|
snowflake/ml/model/_client/model/model_impl.py,sha256=pqjK8mSZIQJ_30tRWWFPIo8X35InSVoAunXlQNtSJEM,15369
|
104
|
-
snowflake/ml/model/_client/model/model_version_impl.py,sha256=
|
105
|
+
snowflake/ml/model/_client/model/model_version_impl.py,sha256=tGfSR4dF8okdBPeAu7yWVSLtwvnvhnJr9xalKbQZw5M,40144
|
105
106
|
snowflake/ml/model/_client/ops/metadata_ops.py,sha256=7cGx8zYzye2_cvZnyGxoukPtT6Q-Kexd-s4yeZmpmj8,4890
|
106
107
|
snowflake/ml/model/_client/ops/model_ops.py,sha256=didFBsjb7KJYV_586TUK4c9DudVQvjzlphEXJW0AnmY,43935
|
107
|
-
snowflake/ml/model/_client/ops/service_ops.py,sha256=
|
108
|
+
snowflake/ml/model/_client/ops/service_ops.py,sha256=t_yLtHlAzHc28XDZ543yAALY5iVsRwVw4i9mtiPaXpQ,19237
|
108
109
|
snowflake/ml/model/_client/service/model_deployment_spec.py,sha256=uyh5k_u8mVP5T4lf0jq8s2cFuiTsbV_nJL6z1Zum2rM,4456
|
109
110
|
snowflake/ml/model/_client/service/model_deployment_spec_schema.py,sha256=eaulF6OFNuDfQz3oPYlDjP26Ww2jWWatm81dCbg602E,825
|
110
111
|
snowflake/ml/model/_client/sql/_base.py,sha256=Qrm8M92g3MHb-QnSLUlbd8iVKCRxLhG_zr5M2qmXwJ8,1473
|
111
112
|
snowflake/ml/model/_client/sql/model.py,sha256=o36oPq4aU9TwahqY2uODYvICxmj1orLztijJ0yMbWnM,5852
|
112
113
|
snowflake/ml/model/_client/sql/model_version.py,sha256=hNMlmwN5JQngKuaeUYV2Bli73RMnHmVH01ABX9NBHFk,20686
|
113
114
|
snowflake/ml/model/_client/sql/service.py,sha256=fvQRhRGU4FBeOBouIoQByTvfQg-qbEQKplCG99BPmL0,10408
|
114
|
-
snowflake/ml/model/_client/sql/stage.py,sha256=
|
115
|
+
snowflake/ml/model/_client/sql/stage.py,sha256=165vyAtrScSQWJB8wLXKRUO1QvHTWDmPykeWOyxrDRg,826
|
115
116
|
snowflake/ml/model/_client/sql/tag.py,sha256=pwwrcyPtSnkUfDzL3M8kqM0KSx7CaTtgty3HDhVC9vg,4345
|
116
117
|
snowflake/ml/model/_model_composer/model_composer.py,sha256=535ElL3Kw8eoUjL7fHd-K20eDCBqvJFwowUx2_UOCl8,6712
|
117
118
|
snowflake/ml/model/_model_composer/model_manifest/model_manifest.py,sha256=X6-cKLBZ1X2liIjWnyrd9efQaQhwIoxRSE90Zs0kAZo,7822
|
@@ -133,7 +134,7 @@ snowflake/ml/model/_packager/model_handlers/lightgbm.py,sha256=E0667G5FFfMssaXjk
|
|
133
134
|
snowflake/ml/model/_packager/model_handlers/mlflow.py,sha256=A3HnCa065jtHsRM40ZxfLv5alk0RYhVmsU4Jt2klRwQ,9189
|
134
135
|
snowflake/ml/model/_packager/model_handlers/pytorch.py,sha256=DDcf85xisPLT1PyXdmPrjJpIIepkdmWNXCOpT_dCncw,8294
|
135
136
|
snowflake/ml/model/_packager/model_handlers/sentence_transformers.py,sha256=f21fJw2wPsXzzhv71Gi1eHctSlyJ6NAR1EQX5iUL5M8,9842
|
136
|
-
snowflake/ml/model/_packager/model_handlers/sklearn.py,sha256=
|
137
|
+
snowflake/ml/model/_packager/model_handlers/sklearn.py,sha256=dwwETBdJJM3AVfl3R6VvvVOZQHgnwIuk9dUUCDOs-w0,14111
|
137
138
|
snowflake/ml/model/_packager/model_handlers/snowmlmodel.py,sha256=uhsJ3zK24aavBRO5gNyxv8BHqU9n1TPUBYm1qHTuaxE,12176
|
138
139
|
snowflake/ml/model/_packager/model_handlers/tensorflow.py,sha256=SkbnvkElK4UIMgygv9EK9f5hBxWZ2YDroymUC9uBsBk,9169
|
139
140
|
snowflake/ml/model/_packager/model_handlers/torchscript.py,sha256=BIdRINO1xZ5uHrR9uA0vExWQymOryTaSpyAMpCCtz8U,8036
|
@@ -146,7 +147,7 @@ snowflake/ml/model/_packager/model_meta/model_meta_schema.py,sha256=5Sdh1_NCKycL
|
|
146
147
|
snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py,sha256=SORlqpPbOeBg6dvJ3DidHeLVi0w9YF0Zv4tC0Kbc20g,1311
|
147
148
|
snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py,sha256=nf6PWDH_gvX_OiS4A-G6BzyCLFEG4dASU0t5JTsijM4,1041
|
148
149
|
snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py,sha256=qEPzdCw_FzExMbPuyFHupeWlYD88yejLdcmkPwjJzDk,2070
|
149
|
-
snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py,sha256=
|
150
|
+
snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py,sha256=eR_qxEwsmzaeaRYH9K4wUAG7bhpqZvn07en2vfRV4c4,1459
|
150
151
|
snowflake/ml/model/_packager/model_runtime/model_runtime.py,sha256=G52nrjzcZiWBJaed6Z1qKq-HjqtnG2MnywDdU9lPusg,5051
|
151
152
|
snowflake/ml/model/_packager/model_task/model_task_utils.py,sha256=0aEUfg71bP5-RkwmzOJBe51yHxLRrtM17tUBoCiuMMk,6310
|
152
153
|
snowflake/ml/model/_signatures/base_handler.py,sha256=WwBfe-83Y0m-HcDx1YSYCGwanIe0fb2MWhTeXc1IeJI,1304
|
@@ -157,7 +158,7 @@ snowflake/ml/model/_signatures/pandas_handler.py,sha256=ACv8egyiK2Sug8uhkQqMDGTT
|
|
157
158
|
snowflake/ml/model/_signatures/pytorch_handler.py,sha256=yEU-V_WRjE8Q7NdHyghl0iYpMiIDzGaIR5Pd_ixB1Hk,4631
|
158
159
|
snowflake/ml/model/_signatures/snowpark_handler.py,sha256=2_AY1ssucMICKSPeDjf3mV4WT5farKYdnYkHsvhHZ20,6066
|
159
160
|
snowflake/ml/model/_signatures/tensorflow_handler.py,sha256=9bUbxtHpl4kEoFzeDJF87bQPb8RdLLm9OV23-aUyW3s,6114
|
160
|
-
snowflake/ml/model/_signatures/utils.py,sha256
|
161
|
+
snowflake/ml/model/_signatures/utils.py,sha256=1E_mV1qdUuob8tjB8WaOEfuo2rmQ2FtOgTNyXZGzoJg,13108
|
161
162
|
snowflake/ml/model/models/huggingface_pipeline.py,sha256=62GpPZxBheqCnFNxNOggiDE1y9Dhst-v6D4IkGLuDeQ,10221
|
162
163
|
snowflake/ml/modeling/_internal/constants.py,sha256=aJGngY599w3KqN8cDZCYrjbWe6UwYIbgv0gx0Ukdtc0,105
|
163
164
|
snowflake/ml/modeling/_internal/estimator_utils.py,sha256=mbMm8_5tQde_sQDwI8pS3ljHZ8maCHl2Shb5nQwLYac,11872
|
@@ -379,23 +380,23 @@ snowflake/ml/modeling/xgboost/xgb_classifier.py,sha256=2QEK6-NihXjKXO8Ue-fOZDyuc
|
|
379
380
|
snowflake/ml/modeling/xgboost/xgb_regressor.py,sha256=ZorEmRohT2-AUdS8fK0xH8BdB8ENxvVMMDYy34Jzm1o,61703
|
380
381
|
snowflake/ml/modeling/xgboost/xgbrf_classifier.py,sha256=67jh9RosrTeYCWsJbnJ6_MQICHeG22z-DMy8CegP8Vg,62383
|
381
382
|
snowflake/ml/modeling/xgboost/xgbrf_regressor.py,sha256=7_ZwF_QvVqBrkFx_zgGgLXyxtbX26XrWWLozAF-EBB0,61908
|
382
|
-
snowflake/ml/monitoring/model_monitor.py,sha256=
|
383
|
+
snowflake/ml/monitoring/model_monitor.py,sha256=8vJf1YROmJgBLUtpaH-lGKSSJv9R7PxPaQnOdr_j5YE,2200
|
383
384
|
snowflake/ml/monitoring/model_monitor_version.py,sha256=TlmDJZDE0lCVatRaBRgXIjzDF538nrMIc-zWj9MM_nk,46
|
384
385
|
snowflake/ml/monitoring/shap.py,sha256=Dp9nYquPEZjxMTW62YYA9g9qUdmCEFxcSk7ejvOP7PE,3597
|
385
|
-
snowflake/ml/monitoring/_client/model_monitor_sql_client.py,sha256=
|
386
|
+
snowflake/ml/monitoring/_client/model_monitor_sql_client.py,sha256=Qr3L6bs84ID5_1TvY6wf5YK2kn3ZVZ-Havo242i3MiY,12710
|
386
387
|
snowflake/ml/monitoring/_client/queries/record_count.ssql,sha256=Bd1uNMwhPKqPyrDd5ug8iY493t9KamJjrlo82OAfmjY,335
|
387
388
|
snowflake/ml/monitoring/_client/queries/rmse.ssql,sha256=OEJiSStRz9-qKoZaFvmubtY_n0xMUjyVU2uiQHCp7KU,822
|
388
|
-
snowflake/ml/monitoring/_manager/model_monitor_manager.py,sha256=
|
389
|
-
snowflake/ml/monitoring/entities/model_monitor_config.py,sha256=
|
390
|
-
snowflake/ml/monitoring/entities/output_score_type.py,sha256=UJyS4z5hncRZ0agVNa6_X041RY9q3Us-6Bh3dPVAmEw,2982
|
389
|
+
snowflake/ml/monitoring/_manager/model_monitor_manager.py,sha256=_-vxqnHqohTHTrwfURjPXijyAeh1mTRdHCG436GaBik,10314
|
390
|
+
snowflake/ml/monitoring/entities/model_monitor_config.py,sha256=IxEiee1HfBXCQGzJOZbrDrvoV8J1tDNk43ygNuN00Io,1793
|
391
391
|
snowflake/ml/registry/__init__.py,sha256=XdPQK9ejYkSJVrSQ7HD3jKQO0hKq2mC4bPCB6qrtH3U,76
|
392
|
-
snowflake/ml/registry/registry.py,sha256=
|
392
|
+
snowflake/ml/registry/registry.py,sha256=5aBedBH8NiFkJJe1Pnggsrjnn0ixdg1oqtUHWyz3wsE,23824
|
393
393
|
snowflake/ml/registry/_manager/model_manager.py,sha256=gFr1EqaMR2Eb4erwVz7fi7xK1G1YsFXz1PF5GvOR0pg,12131
|
394
|
+
snowflake/ml/utils/authentication.py,sha256=Wx1kVBZ9XBDuKkRHpPEB2pBxpiJepVLFAirDMx4m5Gk,2612
|
394
395
|
snowflake/ml/utils/connection_params.py,sha256=JRpQppuWRk6bhdLzVDhMfz3Y6yInobFNLHmIBaXD7po,8005
|
395
396
|
snowflake/ml/utils/sparse.py,sha256=XqDQkw39Ml6YIknswdkvFIwUwBk_GBXAbP8IACfPENg,3817
|
396
397
|
snowflake/ml/utils/sql_client.py,sha256=z4Rhi7pQz3s9cyu_Uzfr3deCnrkCdFh9IYIvicsuwdc,692
|
397
|
-
snowflake_ml_python-1.7.
|
398
|
-
snowflake_ml_python-1.7.
|
399
|
-
snowflake_ml_python-1.7.
|
400
|
-
snowflake_ml_python-1.7.
|
401
|
-
snowflake_ml_python-1.7.
|
398
|
+
snowflake_ml_python-1.7.2.dist-info/LICENSE.txt,sha256=PdEp56Av5m3_kl21iFkVTX_EbHJKFGEdmYeIO1pL_Yk,11365
|
399
|
+
snowflake_ml_python-1.7.2.dist-info/METADATA,sha256=GwZOHmNQAKaMDP3VeWIDWC-OMhPqldoJaYPrR-_iWGw,67429
|
400
|
+
snowflake_ml_python-1.7.2.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
401
|
+
snowflake_ml_python-1.7.2.dist-info/top_level.txt,sha256=TY0gFSHKDdZy3THb0FGomyikWQasEGldIR1O0HGOHVw,10
|
402
|
+
snowflake_ml_python-1.7.2.dist-info/RECORD,,
|
@@ -1,90 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
from enum import Enum
|
4
|
-
from typing import List, Mapping
|
5
|
-
|
6
|
-
from snowflake.ml._internal.utils import sql_identifier
|
7
|
-
from snowflake.ml.model import type_hints
|
8
|
-
from snowflake.snowpark import types
|
9
|
-
|
10
|
-
# Accepted data types for each OutputScoreType.
|
11
|
-
REGRESSION_DATA_TYPES = (
|
12
|
-
types.ByteType,
|
13
|
-
types.ShortType,
|
14
|
-
types.IntegerType,
|
15
|
-
types.LongType,
|
16
|
-
types.FloatType,
|
17
|
-
types.DoubleType,
|
18
|
-
types.DecimalType,
|
19
|
-
)
|
20
|
-
CLASSIFICATION_DATA_TYPES = (
|
21
|
-
types.ByteType,
|
22
|
-
types.ShortType,
|
23
|
-
types.IntegerType,
|
24
|
-
types.BooleanType,
|
25
|
-
types.BinaryType,
|
26
|
-
)
|
27
|
-
PROBITS_DATA_TYPES = (
|
28
|
-
types.ByteType,
|
29
|
-
types.ShortType,
|
30
|
-
types.IntegerType,
|
31
|
-
types.LongType,
|
32
|
-
types.FloatType,
|
33
|
-
types.DoubleType,
|
34
|
-
types.DecimalType,
|
35
|
-
)
|
36
|
-
|
37
|
-
|
38
|
-
# OutputScoreType enum
|
39
|
-
class OutputScoreType(Enum):
|
40
|
-
UNKNOWN = "UNKNOWN"
|
41
|
-
REGRESSION = "REGRESSION"
|
42
|
-
CLASSIFICATION = "CLASSIFICATION"
|
43
|
-
PROBITS = "PROBITS"
|
44
|
-
|
45
|
-
@classmethod
|
46
|
-
def deduce_score_type(
|
47
|
-
cls,
|
48
|
-
table_schema: Mapping[str, types.DataType],
|
49
|
-
prediction_columns: List[sql_identifier.SqlIdentifier],
|
50
|
-
task: type_hints.Task,
|
51
|
-
) -> OutputScoreType:
|
52
|
-
"""Find the score type for monitoring given a table schema and the task.
|
53
|
-
|
54
|
-
Args:
|
55
|
-
table_schema: Dictionary of column names and types in the source table.
|
56
|
-
prediction_columns: List of prediction columns.
|
57
|
-
task: Enum value for the task of the model.
|
58
|
-
|
59
|
-
Returns:
|
60
|
-
Enum value for the score type, informing monitoring table set up.
|
61
|
-
|
62
|
-
Raises:
|
63
|
-
ValueError: If prediction type fails to align with task.
|
64
|
-
"""
|
65
|
-
# Already validated we have just one prediction column type
|
66
|
-
prediction_column_type = {table_schema[column_name] for column_name in prediction_columns}.pop()
|
67
|
-
|
68
|
-
if task == type_hints.Task.TABULAR_REGRESSION:
|
69
|
-
if isinstance(prediction_column_type, REGRESSION_DATA_TYPES):
|
70
|
-
return OutputScoreType.REGRESSION
|
71
|
-
else:
|
72
|
-
raise ValueError(
|
73
|
-
f"Expected prediction column type to be one of {REGRESSION_DATA_TYPES} "
|
74
|
-
f"for REGRESSION task. Found: {prediction_column_type}."
|
75
|
-
)
|
76
|
-
|
77
|
-
elif task == type_hints.Task.TABULAR_BINARY_CLASSIFICATION:
|
78
|
-
if isinstance(prediction_column_type, CLASSIFICATION_DATA_TYPES):
|
79
|
-
return OutputScoreType.CLASSIFICATION
|
80
|
-
elif isinstance(prediction_column_type, PROBITS_DATA_TYPES):
|
81
|
-
return OutputScoreType.PROBITS
|
82
|
-
else:
|
83
|
-
raise ValueError(
|
84
|
-
f"Expected prediction column type to be one of {CLASSIFICATION_DATA_TYPES} "
|
85
|
-
f"or one of {PROBITS_DATA_TYPES} for CLASSIFICATION task. "
|
86
|
-
f"Found: {prediction_column_type}."
|
87
|
-
)
|
88
|
-
|
89
|
-
else:
|
90
|
-
raise ValueError(f"Received unsupported task for model monitoring: {task}.")
|
File without changes
|
File without changes
|