snowflake-ml-python 1.8.4__py3-none-any.whl → 1.8.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. snowflake/ml/_internal/telemetry.py +42 -16
  2. snowflake/ml/_internal/utils/connection_params.py +196 -0
  3. snowflake/ml/data/data_connector.py +1 -1
  4. snowflake/ml/jobs/__init__.py +2 -0
  5. snowflake/ml/jobs/_utils/constants.py +12 -2
  6. snowflake/ml/jobs/_utils/function_payload_utils.py +43 -0
  7. snowflake/ml/jobs/_utils/interop_utils.py +1 -1
  8. snowflake/ml/jobs/_utils/payload_utils.py +95 -39
  9. snowflake/ml/jobs/_utils/scripts/constants.py +22 -0
  10. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +67 -2
  11. snowflake/ml/jobs/_utils/spec_utils.py +30 -6
  12. snowflake/ml/jobs/_utils/stage_utils.py +119 -0
  13. snowflake/ml/jobs/_utils/types.py +5 -1
  14. snowflake/ml/jobs/decorators.py +10 -7
  15. snowflake/ml/jobs/job.py +176 -28
  16. snowflake/ml/jobs/manager.py +119 -26
  17. snowflake/ml/model/_client/model/model_impl.py +58 -0
  18. snowflake/ml/model/_client/model/model_version_impl.py +90 -0
  19. snowflake/ml/model/_client/ops/model_ops.py +6 -3
  20. snowflake/ml/model/_client/ops/service_ops.py +24 -7
  21. snowflake/ml/model/_client/service/model_deployment_spec.py +11 -0
  22. snowflake/ml/model/_client/sql/model_version.py +1 -1
  23. snowflake/ml/model/_client/sql/service.py +73 -28
  24. snowflake/ml/model/_client/sql/stage.py +5 -2
  25. snowflake/ml/model/_model_composer/model_composer.py +3 -1
  26. snowflake/ml/model/_packager/model_handlers/sklearn.py +1 -1
  27. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +103 -73
  28. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -2
  29. snowflake/ml/model/_signatures/core.py +24 -0
  30. snowflake/ml/monitoring/explain_visualize.py +160 -22
  31. snowflake/ml/monitoring/model_monitor.py +0 -4
  32. snowflake/ml/registry/registry.py +34 -14
  33. snowflake/ml/utils/connection_params.py +9 -3
  34. snowflake/ml/utils/html_utils.py +263 -0
  35. snowflake/ml/version.py +1 -1
  36. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/METADATA +40 -13
  37. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/RECORD +40 -37
  38. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/WHEEL +1 -1
  39. snowflake/ml/monitoring/model_monitor_version.py +0 -1
  40. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/licenses/LICENSE.txt +0 -0
  41. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/top_level.txt +0 -0
@@ -4,6 +4,7 @@ import enum
4
4
  import functools
5
5
  import inspect
6
6
  import operator
7
+ import os
7
8
  import sys
8
9
  import time
9
10
  import traceback
@@ -13,7 +14,7 @@ from typing import Any, Callable, Iterable, Mapping, Optional, TypeVar, Union, c
13
14
  from typing_extensions import ParamSpec
14
15
 
15
16
  from snowflake import connector
16
- from snowflake.connector import telemetry as connector_telemetry, time_util
17
+ from snowflake.connector import connect, telemetry as connector_telemetry, time_util
17
18
  from snowflake.ml import version as snowml_version
18
19
  from snowflake.ml._internal import env
19
20
  from snowflake.ml._internal.exceptions import (
@@ -37,6 +38,37 @@ _Args = ParamSpec("_Args")
37
38
  _ReturnValue = TypeVar("_ReturnValue")
38
39
 
39
40
 
41
+ def _get_login_token() -> Union[str, bytes]:
42
+ with open("/snowflake/session/token") as f:
43
+ return f.read()
44
+
45
+
46
+ def _get_snowflake_connection() -> Optional[connector.SnowflakeConnection]:
47
+ conn = None
48
+ if os.getenv("SNOWFLAKE_HOST") is not None and os.getenv("SNOWFLAKE_ACCOUNT") is not None:
49
+ try:
50
+ conn = connect(
51
+ host=os.getenv("SNOWFLAKE_HOST"),
52
+ account=os.getenv("SNOWFLAKE_ACCOUNT"),
53
+ token=_get_login_token(),
54
+ authenticator="oauth",
55
+ )
56
+ except Exception:
57
+ # Failed to get a new SnowflakeConnection in SPCS. Fall back to using the active session.
58
+ # This will work in some cases once SPCS enables multiple authentication modes, and users select any auth.
59
+ pass
60
+
61
+ if conn is None:
62
+ try:
63
+ active_session = next(iter(session._get_active_sessions()))
64
+ conn = active_session._conn._conn if active_session.telemetry_enabled else None
65
+ except snowpark_exceptions.SnowparkSessionException:
66
+ # Failed to get an active session. No connection available.
67
+ pass
68
+
69
+ return conn
70
+
71
+
40
72
  @enum.unique
41
73
  class TelemetryProject(enum.Enum):
42
74
  MLOPS = "MLOps"
@@ -378,13 +410,14 @@ def send_custom_usage(
378
410
  data: Optional[dict[str, Any]] = None,
379
411
  **kwargs: Any,
380
412
  ) -> None:
381
- active_session = next(iter(session._get_active_sessions()))
382
- assert active_session, "Missing active session object"
413
+ conn = _get_snowflake_connection()
383
414
 
384
- client = _SourceTelemetryClient(conn=active_session._conn._conn, project=project, subproject=subproject)
385
- common_metrics = client._create_basic_telemetry_data(telemetry_type=telemetry_type)
386
- data = {**common_metrics, TelemetryField.KEY_DATA.value: data, **kwargs}
387
- client._send(msg=data)
415
+ # Send telemetry if Snowflake connection is available.
416
+ if conn is not None:
417
+ client = _SourceTelemetryClient(conn=conn, project=project, subproject=subproject)
418
+ common_metrics = client._create_basic_telemetry_data(telemetry_type=telemetry_type)
419
+ data = {**common_metrics, TelemetryField.KEY_DATA.value: data, **kwargs}
420
+ client._send(msg=data)
388
421
 
389
422
 
390
423
  def send_api_usage_telemetry(
@@ -501,7 +534,6 @@ def send_api_usage_telemetry(
501
534
  return update_stmt_params_if_snowpark_df(result, statement_params)
502
535
 
503
536
  # prioritize `conn_attr_name` over the active session
504
- telemetry_enabled = True
505
537
  if conn_attr_name:
506
538
  # raise AttributeError if conn attribute does not exist in `self`
507
539
  conn = operator.attrgetter(conn_attr_name)(args[0])
@@ -509,16 +541,10 @@ def send_api_usage_telemetry(
509
541
  raise TypeError(
510
542
  f"Expected a conn object of type {' or '.join(_CONNECTION_TYPES.keys())} but got {type(conn)}"
511
543
  )
512
- # get an active session
513
544
  else:
514
- try:
515
- active_session = next(iter(session._get_active_sessions()))
516
- conn = active_session._conn._conn
517
- telemetry_enabled = active_session.telemetry_enabled
518
- except snowpark_exceptions.SnowparkSessionException:
519
- conn = None
545
+ conn = _get_snowflake_connection()
520
546
 
521
- if conn is None or not telemetry_enabled:
547
+ if conn is None:
522
548
  # Telemetry not enabled, just execute without our additional telemetry logic
523
549
  try:
524
550
  return ctx.run(execute_func_with_statement_params)
@@ -0,0 +1,196 @@
1
+ import configparser
2
+ import os
3
+ from typing import Optional, Union
4
+
5
+ from absl import logging
6
+ from cryptography.hazmat import backends
7
+ from cryptography.hazmat.primitives import serialization
8
+
9
+ _DEFAULT_CONNECTION_FILE = "~/.snowsql/config"
10
+
11
+
12
+ def _read_token(token_file: str = "") -> str:
13
+ """
14
+ Reads token from environment or file provided.
15
+
16
+ First tries to read the token from environment variable
17
+ (`SNOWFLAKE_TOKEN`) followed by the token file.
18
+ Both the options are tried out in SnowServices.
19
+
20
+ Args:
21
+ token_file: File from which token needs to be read. Optional.
22
+
23
+ Returns:
24
+ the token.
25
+ """
26
+ token = os.getenv("SNOWFLAKE_TOKEN", "")
27
+ if token:
28
+ return token
29
+ if token_file and os.path.exists(token_file):
30
+ with open(token_file) as f:
31
+ token = f.read()
32
+ return token
33
+
34
+
35
+ _ENCRYPTED_PKCS8_PK_HEADER = b"-----BEGIN ENCRYPTED PRIVATE KEY-----"
36
+ _UNENCRYPTED_PKCS8_PK_HEADER = b"-----BEGIN PRIVATE KEY-----"
37
+
38
+
39
+ def _load_pem_to_der(private_key_path: str) -> bytes:
40
+ """Given a private key file path (in PEM format), decode key data into DER format."""
41
+ with open(private_key_path, "rb") as f:
42
+ private_key_pem = f.read()
43
+ private_key_passphrase: Optional[str] = os.getenv("SNOWFLAKE_PRIVATE_KEY_PASSPHRASE", None)
44
+
45
+ # Only PKCS#8 format keys are accepted. However, openssl
46
+ # transparently handles both PKCS#8 and PKCS#1 formats (by some fallback
47
+ # logic) and there is no function to distinguish between them. From
48
+ # reading the openssl source code, it apparently also relies on the header
49
+ # to determine whether the given bytes are in PKCS#8 format or not.
50
+ if not private_key_pem.startswith(_ENCRYPTED_PKCS8_PK_HEADER) and not private_key_pem.startswith(
51
+ _UNENCRYPTED_PKCS8_PK_HEADER
52
+ ):
53
+ raise Exception("Private key provided is not in PKCS#8 format. Please use correct format.")
54
+
55
+ if private_key_pem.startswith(_ENCRYPTED_PKCS8_PK_HEADER) and private_key_passphrase is None:
56
+ raise Exception(
57
+ "Private key is encrypted but passphrase could not be found. "
58
+ "Please set SNOWFLAKE_PRIVATE_KEY_PASSPHRASE env variable."
59
+ )
60
+
61
+ if private_key_pem.startswith(_UNENCRYPTED_PKCS8_PK_HEADER):
62
+ private_key_passphrase = None
63
+
64
+ private_key = serialization.load_pem_private_key(
65
+ private_key_pem,
66
+ str.encode(private_key_passphrase) if private_key_passphrase is not None else private_key_passphrase,
67
+ backends.default_backend(),
68
+ )
69
+
70
+ return private_key.private_bytes(
71
+ encoding=serialization.Encoding.DER,
72
+ format=serialization.PrivateFormat.PKCS8,
73
+ encryption_algorithm=serialization.NoEncryption(),
74
+ )
75
+
76
+
77
+ def _connection_properties_from_env() -> dict[str, str]:
78
+ """Returns a dict with all possible login related env variables."""
79
+ sf_conn_prop = {
80
+ # Mandatory fields
81
+ "account": os.environ["SNOWFLAKE_ACCOUNT"],
82
+ "database": os.environ["SNOWFLAKE_DATABASE"],
83
+ # With a default value
84
+ "token_file": os.getenv("SNOWFLAKE_TOKEN_FILE", "/snowflake/session/token"),
85
+ "ssl": os.getenv("SNOWFLAKE_SSL", "on"),
86
+ "protocol": os.getenv("SNOWFLAKE_PROTOCOL", "https"),
87
+ }
88
+ # With empty default value
89
+ for key, env_var in {
90
+ "user": "SNOWFLAKE_USER",
91
+ "authenticator": "SNOWFLAKE_AUTHENTICATOR",
92
+ "password": "SNOWFLAKE_PASSWORD",
93
+ "host": "SNOWFLAKE_HOST",
94
+ "port": "SNOWFLAKE_PORT",
95
+ "schema": "SNOWFLAKE_SCHEMA",
96
+ "warehouse": "SNOWFLAKE_WAREHOUSE",
97
+ "private_key_path": "SNOWFLAKE_PRIVATE_KEY_PATH",
98
+ }.items():
99
+ value = os.getenv(env_var, "")
100
+ if value:
101
+ sf_conn_prop[key] = value
102
+ return sf_conn_prop
103
+
104
+
105
+ def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -> dict[str, str]:
106
+ """Loads the dictionary from snowsql config file."""
107
+ snowsql_config_file = login_file if login_file else os.path.expanduser(_DEFAULT_CONNECTION_FILE)
108
+ if not os.path.exists(snowsql_config_file):
109
+ logging.error(f"Connection name given but snowsql config file is not found at: {snowsql_config_file}")
110
+ raise Exception("Snowflake SnowSQL config not found.")
111
+
112
+ config = configparser.ConfigParser(inline_comment_prefixes="#")
113
+
114
+ snowflake_connection_name = os.getenv("SNOWFLAKE_CONNECTION_NAME")
115
+ if snowflake_connection_name is not None:
116
+ connection_name = snowflake_connection_name
117
+
118
+ if connection_name:
119
+ if not connection_name.startswith("connections."):
120
+ connection_name = "connections." + connection_name
121
+ else:
122
+ # See https://docs.snowflake.com/en/user-guide/snowsql-start.html#configuring-default-connection-settings
123
+ connection_name = "connections"
124
+
125
+ logging.info(f"Reading {snowsql_config_file} for connection parameters defined as {connection_name}")
126
+ config.read(snowsql_config_file)
127
+ conn_params = dict(config[connection_name])
128
+ # Remap names to appropriate args in Python Connector API
129
+ # Note: "dbname" should become "database"
130
+ conn_params = {k.replace("name", ""): v.strip('"') for k, v in conn_params.items()}
131
+ if "db" in conn_params:
132
+ conn_params["database"] = conn_params["db"]
133
+ del conn_params["db"]
134
+ return conn_params
135
+
136
+
137
+ def SnowflakeLoginOptions(connection_name: str = "", login_file: Optional[str] = None) -> dict[str, Union[str, bytes]]:
138
+ """Returns a dict that can be used directly into snowflake python connector or Snowpark session config.
139
+
140
+ NOTE: Token/Auth information is sideloaded in all cases above, if provided, in the following order:
141
+ 1. If SNOWFLAKE_TOKEN is defined in the environment, it will be used.
142
+ 2. If SNOWFLAKE_TOKEN_FILE is defined in the environment and a file matching its value is found, the content of
143
+ that file will be used.
144
+
145
+ If token is found, username, password will be reset and 'authenticator' will be set to 'oauth'.
146
+
147
+ Python Connector:
148
+ >> ctx = snowflake.connector.connect(**(SnowflakeLoginOptions()))
149
+
150
+ Snowpark Session:
151
+ >> session = Session.builder.configs(SnowflakeLoginOptions()).create()
152
+
153
+ Usage Note:
154
+ Ideally one should have a snowsql config file. Read more here:
155
+ https://docs.snowflake.com/en/user-guide/snowsql-start.html#configuring-default-connection-settings
156
+
157
+ If snowsql config file does not exist, it tries auth from env variables.
158
+
159
+ Args:
160
+ connection_name: Name of the connection to look for inside the config file. If environment variable
161
+ SNOWFLAKE_CONNECTION_NAME is provided, it will override the input connection_name.
162
+ login_file: If provided, this is used as config file instead of default one (_DEFAULT_CONNECTION_FILE).
163
+
164
+ Returns:
165
+ A dict with connection parameters.
166
+
167
+ Raises:
168
+ Exception: if neither a login (config) file nor the SNOWFLAKE_ACCOUNT environment variable is present.
169
+ """
170
+ conn_prop: dict[str, Union[str, bytes]] = {}
171
+ login_file = login_file or os.path.expanduser(_DEFAULT_CONNECTION_FILE)
172
+ # If login file exists, use this exclusively.
173
+ if os.path.exists(login_file):
174
+ conn_prop = {**(_load_from_snowsql_config_file(connection_name, login_file))}
175
+ else:
176
+ # If SNOWFLAKE_ACCOUNT is set in the environment, assume everything
177
+ # comes from the environment. Mixing sources is not allowed.
178
+ account = os.getenv("SNOWFLAKE_ACCOUNT", "")
179
+ if account:
180
+ conn_prop = {**_connection_properties_from_env()}
181
+ else:
182
+ raise Exception("Snowflake credential is neither set in env nor a login file was provided.")
183
+
184
+ # Token, if specified, is always side-loaded in all cases.
185
+ token = _read_token(str(conn_prop["token_file"]) if "token_file" in conn_prop else "")
186
+ if token:
187
+ conn_prop["token"] = token
188
+ if "authenticator" not in conn_prop or conn_prop["authenticator"]:
189
+ conn_prop["authenticator"] = "oauth"
190
+ elif "private_key_path" in conn_prop and "private_key" not in conn_prop:
191
+ conn_prop["private_key"] = _load_pem_to_der(str(conn_prop["private_key_path"]))
192
+
193
+ if "ssl" in conn_prop and conn_prop["ssl"].lower() == "off":
194
+ conn_prop["protocol"] = "http"
195
+
196
+ return conn_prop
@@ -249,7 +249,7 @@ class DataConnector:
249
249
 
250
250
  # Switch to use Runtime's Data Ingester if running in ML runtime
251
251
  # Fail silently if the data ingester is not found
252
- if env.IN_ML_RUNTIME and os.getenv(env.USE_OPTIMIZED_DATA_INGESTOR):
252
+ if env.IN_ML_RUNTIME and os.getenv(env.USE_OPTIMIZED_DATA_INGESTOR, "").lower() in ("true", "1"):
253
253
  try:
254
254
  from runtime_external_entities import get_ingester_class
255
255
 
@@ -7,6 +7,7 @@ from snowflake.ml.jobs.manager import (
7
7
  list_jobs,
8
8
  submit_directory,
9
9
  submit_file,
10
+ submit_from_stage,
10
11
  )
11
12
 
12
13
  __all__ = [
@@ -18,4 +19,5 @@ __all__ = [
18
19
  "delete_job",
19
20
  "MLJob",
20
21
  "JOB_STATUS",
22
+ "submit_from_stage",
21
23
  ]
@@ -5,6 +5,8 @@ from snowflake.ml.jobs._utils.types import ComputeResources
5
5
  DEFAULT_CONTAINER_NAME = "main"
6
6
  PAYLOAD_DIR_ENV_VAR = "MLRS_PAYLOAD_DIR"
7
7
  RESULT_PATH_ENV_VAR = "MLRS_RESULT_PATH"
8
+ MIN_INSTANCES_ENV_VAR = "MLRS_MIN_INSTANCES"
9
+ RUNTIME_IMAGE_TAG_ENV_VAR = "MLRS_CONTAINER_IMAGE_TAG"
8
10
  MEMORY_VOLUME_NAME = "dshm"
9
11
  STAGE_VOLUME_NAME = "stage-volume"
10
12
  STAGE_VOLUME_MOUNT_PATH = "/mnt/app"
@@ -13,7 +15,7 @@ STAGE_VOLUME_MOUNT_PATH = "/mnt/app"
13
15
  DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
14
16
  DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks"
15
17
  DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks"
16
- DEFAULT_IMAGE_TAG = "1.2.3"
18
+ DEFAULT_IMAGE_TAG = "1.4.2"
17
19
  DEFAULT_ENTRYPOINT_PATH = "func.py"
18
20
 
19
21
  # Percent of container memory to allocate for /dev/shm volume
@@ -37,16 +39,24 @@ RAY_PORTS = {
37
39
  # Node health check configuration
38
40
  # TODO(SNOW-1937020): Revisit the health check configuration
39
41
  ML_RUNTIME_HEALTH_CHECK_PORT = "5001"
42
+ ENABLE_HEALTH_CHECKS_ENV_VAR = "ENABLE_HEALTH_CHECKS"
40
43
  ENABLE_HEALTH_CHECKS = "false"
41
44
 
42
45
  # Job status polling constants
43
46
  JOB_POLL_INITIAL_DELAY_SECONDS = 0.1
44
- JOB_POLL_MAX_DELAY_SECONDS = 1
47
+ JOB_POLL_MAX_DELAY_SECONDS = 30
45
48
 
46
49
  # Magic attributes
47
50
  IS_MLJOB_REMOTE_ATTR = "_is_mljob_remote_callable"
48
51
  RESULT_PATH_DEFAULT_VALUE = "mljob_result.pkl"
49
52
 
53
+ # Log start and end messages
54
+ LOG_START_MSG = "--------------------------------\nML job started\n--------------------------------"
55
+ LOG_END_MSG = "--------------------------------\nML job finished\n--------------------------------"
56
+
57
+ # Default setting for verbose logging in get_log function
58
+ DEFAULT_VERBOSE_LOG = False
59
+
50
60
  # Compute pool resource information
51
61
  # TODO: Query Snowflake for resource information instead of relying on this hardcoded
52
62
  # table from https://docs.snowflake.com/en/sql-reference/sql/create-compute-pool
@@ -0,0 +1,43 @@
1
+ import inspect
2
+ from typing import Any, Callable, Optional
3
+
4
+ from snowflake import snowpark
5
+ from snowflake.snowpark import context as sp_context
6
+
7
+
8
+ class FunctionPayload:
9
+ def __init__(
10
+ self,
11
+ func: Callable[..., Any],
12
+ session: Optional[snowpark.Session] = None,
13
+ session_argument: str = "",
14
+ *args: Any,
15
+ **kwargs: Any
16
+ ) -> None:
17
+ self.function = func
18
+ self.args = args
19
+ self.kwargs = kwargs
20
+ self._session = session
21
+ self._session_argument = session_argument
22
+
23
+ @property
24
+ def session(self) -> Optional[snowpark.Session]:
25
+ return self._session
26
+
27
+ def __getstate__(self) -> dict[str, Any]:
28
+ """Customize pickling to exclude session."""
29
+ state = self.__dict__.copy()
30
+ state["_session"] = None
31
+ return state
32
+
33
+ def __setstate__(self, state: dict[str, Any]) -> None:
34
+ """Restore session from context during unpickling."""
35
+ self.__dict__.update(state)
36
+ self._session = sp_context.get_active_session()
37
+
38
+ def __call__(self) -> Any:
39
+ sig = inspect.signature(self.function)
40
+ bound = sig.bind_partial(*self.args, **self.kwargs)
41
+ bound.arguments[self._session_argument] = self._session
42
+
43
+ return self.function(*bound.args, **bound.kwargs)
@@ -80,7 +80,7 @@ def fetch_result(session: snowpark.Session, result_path: str) -> ExecutionResult
80
80
  # TODO: Check if file exists
81
81
  with session.file.get_stream(result_path) as result_stream:
82
82
  return ExecutionResult.from_dict(pickle.load(result_stream))
83
- except (sp_exceptions.SnowparkSQLException, TypeError, pickle.UnpicklingError):
83
+ except (sp_exceptions.SnowparkSQLException, pickle.UnpicklingError, TypeError, ImportError):
84
84
  # Fall back to JSON result if loading pickled result fails for any reason
85
85
  result_json_path = os.path.splitext(result_path)[0] + ".json"
86
86
  with session.file.get_stream(result_json_path) as result_stream: