snowflake-ml-python 1.8.5__py3-none-any.whl → 1.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/telemetry.py +6 -9
- snowflake/ml/_internal/utils/connection_params.py +196 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/mixins.py +61 -0
- snowflake/ml/jobs/__init__.py +2 -0
- snowflake/ml/jobs/_utils/constants.py +3 -2
- snowflake/ml/jobs/_utils/function_payload_utils.py +43 -0
- snowflake/ml/jobs/_utils/interop_utils.py +63 -4
- snowflake/ml/jobs/_utils/payload_utils.py +89 -40
- snowflake/ml/jobs/_utils/query_helper.py +9 -0
- snowflake/ml/jobs/_utils/scripts/constants.py +19 -3
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -26
- snowflake/ml/jobs/_utils/spec_utils.py +29 -5
- snowflake/ml/jobs/_utils/stage_utils.py +119 -0
- snowflake/ml/jobs/_utils/types.py +5 -1
- snowflake/ml/jobs/decorators.py +20 -28
- snowflake/ml/jobs/job.py +197 -61
- snowflake/ml/jobs/manager.py +253 -121
- snowflake/ml/model/_client/model/model_impl.py +58 -0
- snowflake/ml/model/_client/model/model_version_impl.py +90 -0
- snowflake/ml/model/_client/ops/model_ops.py +18 -6
- snowflake/ml/model/_client/ops/service_ops.py +23 -6
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -0
- snowflake/ml/model/_client/sql/service.py +68 -20
- snowflake/ml/model/_client/sql/stage.py +5 -2
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -10
- snowflake/ml/model/_packager/model_env/model_env.py +35 -27
- snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +103 -73
- snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
- snowflake/ml/model/_signatures/core.py +24 -0
- snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
- snowflake/ml/model/target_platform.py +11 -0
- snowflake/ml/model/task.py +9 -0
- snowflake/ml/model/type_hints.py +5 -13
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
- snowflake/ml/monitoring/explain_visualize.py +2 -2
- snowflake/ml/monitoring/model_monitor.py +0 -4
- snowflake/ml/registry/_manager/model_manager.py +30 -15
- snowflake/ml/registry/registry.py +144 -47
- snowflake/ml/utils/connection_params.py +1 -1
- snowflake/ml/utils/html_utils.py +263 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/METADATA +64 -19
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/RECORD +48 -41
- snowflake/ml/monitoring/model_monitor_version.py +0 -1
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/top_level.txt +0 -0
@@ -411,16 +411,13 @@ def send_custom_usage(
|
|
411
411
|
**kwargs: Any,
|
412
412
|
) -> None:
|
413
413
|
conn = _get_snowflake_connection()
|
414
|
-
if conn is None:
|
415
|
-
raise ValueError(
|
416
|
-
"""Snowflake connection is required to send custom telemetry. This means there
|
417
|
-
must be at least one active session, or that telemetry is being sent from within an SPCS service."""
|
418
|
-
)
|
419
414
|
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
415
|
+
# Send telemetry if Snowflake connection is available.
|
416
|
+
if conn is not None:
|
417
|
+
client = _SourceTelemetryClient(conn=conn, project=project, subproject=subproject)
|
418
|
+
common_metrics = client._create_basic_telemetry_data(telemetry_type=telemetry_type)
|
419
|
+
data = {**common_metrics, TelemetryField.KEY_DATA.value: data, **kwargs}
|
420
|
+
client._send(msg=data)
|
424
421
|
|
425
422
|
|
426
423
|
def send_api_usage_telemetry(
|
@@ -0,0 +1,196 @@
|
|
1
|
+
import configparser
|
2
|
+
import os
|
3
|
+
from typing import Optional, Union
|
4
|
+
|
5
|
+
from absl import logging
|
6
|
+
from cryptography.hazmat import backends
|
7
|
+
from cryptography.hazmat.primitives import serialization
|
8
|
+
|
9
|
+
_DEFAULT_CONNECTION_FILE = "~/.snowsql/config"


def _read_token(token_file: str = "") -> str:
    """Return a Snowflake OAuth token, or an empty string when none is available.

    Sources are consulted in this order:
      1. The ``SNOWFLAKE_TOKEN`` environment variable, if it is set to a non-empty value.
      2. The contents of ``token_file``, if a path was given and that path exists.

    Args:
        token_file: Optional path of a file holding the token.

    Returns:
        The token. File contents are returned verbatim (no whitespace stripping); an empty
        string is returned when neither source yields anything.
    """
    env_token = os.getenv("SNOWFLAKE_TOKEN", "")
    if env_token:
        return env_token

    # Fall back to the token file only when one was named and is present on disk.
    if not token_file or not os.path.exists(token_file):
        return ""

    with open(token_file) as handle:
        return handle.read()
|
33
|
+
|
34
|
+
|
35
|
+
_ENCRYPTED_PKCS8_PK_HEADER = b"-----BEGIN ENCRYPTED PRIVATE KEY-----"
_UNENCRYPTED_PKCS8_PK_HEADER = b"-----BEGIN PRIVATE KEY-----"


def _load_pem_to_der(private_key_path: str) -> bytes:
    """Read a PKCS#8 PEM private key from disk and return it as unencrypted DER bytes.

    An encrypted key is decrypted with the passphrase taken from the
    ``SNOWFLAKE_PRIVATE_KEY_PASSPHRASE`` environment variable.

    Args:
        private_key_path: Path of the PEM file.

    Returns:
        The private key serialized as DER / PKCS#8 with no encryption.

    Raises:
        Exception: If the file is not a PKCS#8 PEM key, or it is encrypted and no passphrase
            environment variable is set.
    """
    with open(private_key_path, "rb") as key_file:
        pem_bytes = key_file.read()
    passphrase: Optional[str] = os.getenv("SNOWFLAKE_PRIVATE_KEY_PASSPHRASE", None)

    # Only PKCS#8 keys are accepted. OpenSSL silently accepts both PKCS#8 and PKCS#1 and offers
    # no API to tell them apart; its own source also decides based on the PEM header, so the
    # header is checked here.
    is_encrypted = pem_bytes.startswith(_ENCRYPTED_PKCS8_PK_HEADER)
    is_plain = pem_bytes.startswith(_UNENCRYPTED_PKCS8_PK_HEADER)

    if not (is_encrypted or is_plain):
        raise Exception("Private key provided is not in PKCS#8 format. Please use correct format.")

    if is_encrypted and passphrase is None:
        raise Exception(
            "Private key is encrypted but passphrase could not be found. "
            "Please set SNOWFLAKE_PRIVATE_KEY_PASSPHRASE env variable."
        )

    # An unencrypted key must be loaded without a password, even if the env variable is set.
    password = None if is_plain or passphrase is None else passphrase.encode()

    key = serialization.load_pem_private_key(pem_bytes, password, backends.default_backend())
    return key.private_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
|
75
|
+
|
76
|
+
|
77
|
+
def _connection_properties_from_env() -> dict[str, str]:
    """Build Snowflake connection properties from ``SNOWFLAKE_*`` environment variables.

    Returns:
        A dict that always contains:
          * ``account`` and ``database``, read from ``SNOWFLAKE_ACCOUNT`` and ``SNOWFLAKE_DATABASE``;
          * ``token_file``, ``ssl`` and ``protocol``, falling back to ``/snowflake/session/token``,
            ``on`` and ``https`` when their variables are unset.
        It also contains ``user``, ``authenticator``, ``password``, ``host``, ``port``, ``schema``,
        ``warehouse`` and ``private_key_path``, each only when its variable is set to a non-empty value.

    Raises:
        KeyError: If ``SNOWFLAKE_ACCOUNT`` or ``SNOWFLAKE_DATABASE`` is not set.
    """
    environ = os.environ

    # Required settings: plain indexing raises KeyError when a variable is missing.
    props: dict[str, str] = {
        "account": environ["SNOWFLAKE_ACCOUNT"],
        "database": environ["SNOWFLAKE_DATABASE"],
    }

    # Settings that always appear, with a fallback when the variable is unset.
    for prop_name, var_name, fallback in (
        ("token_file", "SNOWFLAKE_TOKEN_FILE", "/snowflake/session/token"),
        ("ssl", "SNOWFLAKE_SSL", "on"),
        ("protocol", "SNOWFLAKE_PROTOCOL", "https"),
    ):
        props[prop_name] = environ.get(var_name, fallback)

    # Optional settings: omitted entirely when unset or empty.
    optional_vars = (
        ("user", "SNOWFLAKE_USER"),
        ("authenticator", "SNOWFLAKE_AUTHENTICATOR"),
        ("password", "SNOWFLAKE_PASSWORD"),
        ("host", "SNOWFLAKE_HOST"),
        ("port", "SNOWFLAKE_PORT"),
        ("schema", "SNOWFLAKE_SCHEMA"),
        ("warehouse", "SNOWFLAKE_WAREHOUSE"),
        ("private_key_path", "SNOWFLAKE_PRIVATE_KEY_PATH"),
    )
    props.update(
        (prop_name, environ[var_name]) for prop_name, var_name in optional_vars if environ.get(var_name, "")
    )
    return props
|
103
|
+
|
104
|
+
|
105
|
+
def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -> dict[str, str]:
    """Load the parameters of one connection from a SnowSQL config file.

    Args:
        connection_name: Connection to read. A ``connections.`` prefix is added when missing, and an
            empty name selects the default ``[connections]`` section. When the
            ``SNOWFLAKE_CONNECTION_NAME`` environment variable is set, it overrides this argument.
        login_file: Path of the config file. When empty, ``_DEFAULT_CONNECTION_FILE`` (user-expanded)
            is used.

    Returns:
        The section's options mapped to Python connector argument names:
          * surrounding double quotes are stripped from values;
          * a trailing ``name`` is dropped from keys (``accountname`` -> ``account``,
            ``username`` -> ``user``, ``dbname`` -> ``db``);
          * ``db`` is renamed to ``database``.

    Raises:
        Exception: If the config file does not exist.
        KeyError: If the requested connection section is not in the config file.
    """
    snowsql_config_file = login_file if login_file else os.path.expanduser(_DEFAULT_CONNECTION_FILE)
    if not os.path.exists(snowsql_config_file):
        logging.error(f"Connection name given but snowsql config file is not found at: {snowsql_config_file}")
        raise Exception("Snowflake SnowSQL config not found.")

    # NOTE(review): ConfigParser applies %-interpolation when values are read, so a literal '%'
    # (e.g. in a password) raises InterpolationSyntaxError. Confirm whether interpolation=None is
    # wanted before changing it, because it alters how '%%' is read.
    config = configparser.ConfigParser(inline_comment_prefixes=("#",))

    # The environment variable, when set, takes precedence over the argument.
    snowflake_connection_name = os.getenv("SNOWFLAKE_CONNECTION_NAME")
    if snowflake_connection_name is not None:
        connection_name = snowflake_connection_name

    if connection_name:
        if not connection_name.startswith("connections."):
            connection_name = "connections." + connection_name
    else:
        # See https://docs.snowflake.com/en/user-guide/snowsql-start.html#configuring-default-connection-settings
        connection_name = "connections"

    logging.info(f"Reading {snowsql_config_file} for connection parameters defined as {connection_name}")
    config.read(snowsql_config_file)
    if not config.has_section(connection_name):
        # Still a KeyError (as indexing the parser would raise), but with the file and section
        # named so the misconfiguration is obvious.
        raise KeyError(f"Connection '{connection_name}' not found in snowsql config file: {snowsql_config_file}")
    conn_params = dict(config[connection_name])

    # Remap SnowSQL option names to Python connector arguments by dropping a trailing "name".
    # Only the suffix is removed, so other keys that merely contain "name" stay intact.
    conn_params = {k.removesuffix("name"): v.strip('"') for k, v in conn_params.items()}
    if "db" in conn_params:
        conn_params["database"] = conn_params.pop("db")
    return conn_params
|
135
|
+
|
136
|
+
|
137
|
+
def SnowflakeLoginOptions(connection_name: str = "", login_file: Optional[str] = None) -> dict[str, Union[str, bytes]]:
    """Returns a dict that can be used directly into snowflake python connector or Snowpark session config.

    Connection settings come from exactly one source; the two are never mixed:
      * the SnowSQL config file (``login_file``, or ``_DEFAULT_CONNECTION_FILE`` when not given), if it
        exists. A ``login_file`` that does not exist is not an error; the environment is tried instead;
      * otherwise ``SNOWFLAKE_*`` environment variables, provided ``SNOWFLAKE_ACCOUNT`` is non-empty.

    Authentication material is then side-loaded on top of those settings:
      1. A token is looked up, first in the ``SNOWFLAKE_TOKEN`` environment variable, then in the file
         named by the ``token_file`` setting. That setting comes from the config file, or from
         ``SNOWFLAKE_TOKEN_FILE`` (or its default) when reading the environment. A token found this
         way is stored under ``token``, and ``authenticator`` is set to ``'oauth'`` unless it is
         present but empty. Other settings such as user and password are left untouched.
      2. Otherwise, if ``private_key_path`` is set and ``private_key`` is not, that PEM key is loaded and
         stored as DER bytes under ``private_key``.
    Finally, when ``ssl`` is ``"off"`` (case-insensitive), ``protocol`` is forced to ``"http"``.

    Python Connector:
        >> ctx = snowflake.connector.connect(**(SnowflakeLoginOptions()))

    Snowpark Session:
        >> session = Session.builder.configs(SnowflakeLoginOptions()).create()

    Usage Note:
        Ideally one should have a snowsql config file. Read more here:
        https://docs.snowflake.com/en/user-guide/snowsql-start.html#configuring-default-connection-settings

    Args:
        connection_name: Name of the connection to look for inside the config file. If environment variable
            SNOWFLAKE_CONNECTION_NAME is provided, it will override the input connection_name.
        login_file: If provided, this is used as config file instead of default one (_DEFAULT_CONNECTION_FILE).

    Returns:
        A dict with connection parameters.

    Raises:
        Exception: if none of config file and environment variable are present.
    """
    config_path = login_file or os.path.expanduser(_DEFAULT_CONNECTION_FILE)

    conn_prop: dict[str, Union[str, bytes]]
    if os.path.exists(config_path):
        # A config file, when present, is used exclusively.
        conn_prop = dict(_load_from_snowsql_config_file(connection_name, config_path))
    elif os.getenv("SNOWFLAKE_ACCOUNT", ""):
        # No config file: everything comes from the environment.
        conn_prop = dict(_connection_properties_from_env())
    else:
        raise Exception("Snowflake credential is neither set in env nor a login file was provided.")

    # A token, whenever one can be found, takes precedence over key-pair authentication.
    token_file = conn_prop.get("token_file")
    token = _read_token("" if token_file is None else str(token_file))
    if token:
        conn_prop["token"] = token
        # Absent (defaulted here to a truthy value) or non-empty -> overwrite; present-but-empty -> keep.
        if conn_prop.get("authenticator", "oauth"):
            conn_prop["authenticator"] = "oauth"
    elif "private_key_path" in conn_prop and "private_key" not in conn_prop:
        conn_prop["private_key"] = _load_pem_to_der(str(conn_prop["private_key_path"]))

    ssl_setting = conn_prop.get("ssl")
    if ssl_setting is not None and ssl_setting.lower() == "off":
        conn_prop["protocol"] = "http"

    return conn_prop
|
@@ -240,7 +240,7 @@ def get_schema_level_object_identifier(
|
|
240
240
|
"""
|
241
241
|
|
242
242
|
for identifier in (db, schema, object_name):
|
243
|
-
if identifier is not None and SF_IDENTIFIER_RE.
|
243
|
+
if identifier is not None and SF_IDENTIFIER_RE.fullmatch(identifier) is None:
|
244
244
|
raise ValueError(f"Invalid identifier {identifier}")
|
245
245
|
|
246
246
|
if others is None:
|
@@ -0,0 +1,61 @@
|
|
1
|
+
from typing import Any, Optional
|
2
|
+
|
3
|
+
from snowflake.ml._internal.utils import identifier
|
4
|
+
from snowflake.snowpark import session
|
5
|
+
|
6
|
+
|
7
|
+
class SerializableSessionMixin:
    """Mixin that lets objects holding a Snowpark session in ``self._session`` be pickled.

    The session itself is never serialized. Pickling records the session's current account,
    role, database and schema. Unpickling re-attaches the first active session whose four values
    match, compared after ``identifier.resolve_identifier`` normalization.
    """

    def __getstate__(self) -> dict[str, Any]:
        """Return a copy of ``__dict__`` with ``_session`` cleared and session metadata added.

        Metadata collection is best effort. If one of the lookups raises, the exception is swallowed
        and only the values read before the failure are recorded (possibly none).
        """
        state = dict(self.__dict__)

        live_session = getattr(self, "_session", None)
        if live_session is not None:
            try:
                for state_key, getter_name in (
                    ("__session-account__", "get_current_account"),
                    ("__session-role__", "get_current_role"),
                    ("__session-database__", "get_current_database"),
                    ("__session-schema__", "get_current_schema"),
                ):
                    state[state_key] = getattr(live_session, getter_name)()
            except Exception:
                pass

        # The live session object is not picklable; it is re-resolved in __setstate__.
        state["_session"] = None
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        """Restore attributes from ``state`` and re-attach a matching active Snowpark session.

        The session metadata keys are removed from ``state`` (which is mutated) before the remaining
        entries are copied onto the instance.

        Raises:
            RuntimeError: If ``state`` carries no saved account, or no active session matches the
                saved account, role, database and schema.
        """
        saved_account, saved_role, saved_database, saved_schema = (
            state.pop(state_key, None)
            for state_key in (
                "__session-account__",
                "__session-role__",
                "__session-database__",
                "__session-schema__",
            )
        )
        self.__dict__.update(state)

        if saved_account is not None:

            def _same_identifier(saved: Optional[str], current: Optional[str]) -> bool:
                # Compare after resolve_identifier normalization; None only equals None.
                lhs = None if saved is None else identifier.resolve_identifier(saved)
                rhs = None if current is None else identifier.resolve_identifier(current)
                return lhs == rhs

            wanted = (saved_account, saved_role, saved_database, saved_schema)
            for candidate in session._get_active_sessions():
                try:
                    observed = (
                        candidate.get_current_account(),
                        candidate.get_current_role(),
                        candidate.get_current_database(),
                        candidate.get_current_schema(),
                    )
                    if all(_same_identifier(w, o) for w, o in zip(wanted, observed)):
                        self._session = candidate
                        return
                except Exception:
                    # Skip sessions whose metadata cannot be read or compared.
                    continue

        # No matching session found or no metadata available
        raise RuntimeError("No active Snowpark session available. Please create a session.")
|
snowflake/ml/jobs/__init__.py
CHANGED
@@ -6,6 +6,7 @@ DEFAULT_CONTAINER_NAME = "main"
|
|
6
6
|
PAYLOAD_DIR_ENV_VAR = "MLRS_PAYLOAD_DIR"
|
7
7
|
RESULT_PATH_ENV_VAR = "MLRS_RESULT_PATH"
|
8
8
|
MIN_INSTANCES_ENV_VAR = "MLRS_MIN_INSTANCES"
|
9
|
+
RUNTIME_IMAGE_TAG_ENV_VAR = "MLRS_CONTAINER_IMAGE_TAG"
|
9
10
|
MEMORY_VOLUME_NAME = "dshm"
|
10
11
|
STAGE_VOLUME_NAME = "stage-volume"
|
11
12
|
STAGE_VOLUME_MOUNT_PATH = "/mnt/app"
|
@@ -14,7 +15,7 @@ STAGE_VOLUME_MOUNT_PATH = "/mnt/app"
|
|
14
15
|
DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
|
15
16
|
DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks"
|
16
17
|
DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks"
|
17
|
-
DEFAULT_IMAGE_TAG = "1.
|
18
|
+
DEFAULT_IMAGE_TAG = "1.5.0"
|
18
19
|
DEFAULT_ENTRYPOINT_PATH = "func.py"
|
19
20
|
|
20
21
|
# Percent of container memory to allocate for /dev/shm volume
|
@@ -43,7 +44,7 @@ ENABLE_HEALTH_CHECKS = "false"
|
|
43
44
|
|
44
45
|
# Job status polling constants
|
45
46
|
JOB_POLL_INITIAL_DELAY_SECONDS = 0.1
|
46
|
-
JOB_POLL_MAX_DELAY_SECONDS =
|
47
|
+
JOB_POLL_MAX_DELAY_SECONDS = 30
|
47
48
|
|
48
49
|
# Magic attributes
|
49
50
|
IS_MLJOB_REMOTE_ATTR = "_is_mljob_remote_callable"
|
@@ -0,0 +1,43 @@
|
|
1
|
+
import inspect
|
2
|
+
from typing import Any, Callable, Optional
|
3
|
+
|
4
|
+
from snowflake import snowpark
|
5
|
+
from snowflake.snowpark import context as sp_context
|
6
|
+
|
7
|
+
|
8
|
+
class FunctionPayload:
    """A picklable bundle of a callable together with the arguments to invoke it with.

    The Snowpark session is excluded from the pickled state. On unpickling, the session returned by
    ``sp_context.get_active_session()`` is attached instead. Calling the payload invokes the wrapped
    callable with the stored arguments, with the attached session supplied as the parameter named
    by ``session_argument``.
    """

    def __init__(
        self,
        func: Callable[..., Any],
        session: Optional[snowpark.Session] = None,
        session_argument: str = "",
        *args: Any,
        **kwargs: Any
    ) -> None:
        self.function = func
        self.args = args
        self.kwargs = kwargs
        self._session = session
        self._session_argument = session_argument

    @property
    def session(self) -> Optional[snowpark.Session]:
        """The Snowpark session that will be injected on call (may be ``None``)."""
        return self._session

    def __getstate__(self) -> dict[str, Any]:
        """Return the instance state for pickling, with the session replaced by ``None``."""
        return {**self.__dict__, "_session": None}

    def __setstate__(self, state: dict[str, Any]) -> None:
        """Restore the pickled state, then attach the currently active Snowpark session."""
        self.__dict__.update(state)
        self._session = sp_context.get_active_session()

    def __call__(self) -> Any:
        """Invoke the wrapped callable with the stored arguments plus the session."""
        binding = inspect.signature(self.function).bind_partial(*self.args, **self.kwargs)
        # Overrides any value supplied for that parameter. BoundArguments.args/kwargs are derived
        # from the signature's parameters, so a name that is not a parameter of the callable
        # (e.g. the default "") is simply not passed.
        binding.arguments[self._session_argument] = self._session
        return self.function(*binding.args, **binding.kwargs)
|
@@ -75,16 +75,75 @@ def fetch_result(session: snowpark.Session, result_path: str) -> ExecutionResult
|
|
75
75
|
|
76
76
|
Returns:
|
77
77
|
A dictionary containing the execution result if available, None otherwise.
|
78
|
+
|
79
|
+
Raises:
|
80
|
+
RuntimeError: If both pickle and JSON result retrieval fail.
|
78
81
|
"""
|
79
82
|
try:
|
80
83
|
# TODO: Check if file exists
|
81
84
|
with session.file.get_stream(result_path) as result_stream:
|
82
85
|
return ExecutionResult.from_dict(pickle.load(result_stream))
|
83
|
-
except (
|
86
|
+
except (
|
87
|
+
sp_exceptions.SnowparkSQLException,
|
88
|
+
pickle.UnpicklingError,
|
89
|
+
TypeError,
|
90
|
+
ImportError,
|
91
|
+
AttributeError,
|
92
|
+
MemoryError,
|
93
|
+
) as pickle_error:
|
84
94
|
# Fall back to JSON result if loading pickled result fails for any reason
|
85
|
-
|
86
|
-
|
87
|
-
|
95
|
+
try:
|
96
|
+
result_json_path = os.path.splitext(result_path)[0] + ".json"
|
97
|
+
with session.file.get_stream(result_json_path) as result_stream:
|
98
|
+
return ExecutionResult.from_dict(json.load(result_stream))
|
99
|
+
except Exception as json_error:
|
100
|
+
# Both pickle and JSON failed - provide helpful error message
|
101
|
+
raise RuntimeError(_fetch_result_error_message(pickle_error, result_path, json_error)) from pickle_error
|
102
|
+
|
103
|
+
|
104
|
+
def _fetch_result_error_message(error: Exception, result_path: str, json_error: Optional[Exception] = None) -> str:
    """Build a user-facing message explaining why a job result could not be retrieved.

    The message is chosen from the type of ``error`` (and, for some types, from its text).

    Args:
        error: Exception raised while loading the pickled result.
        result_path: Path of the pickled result; included in the "no result file found" message.
        json_error: Exception raised by the JSON fallback, if it was attempted and failed. It is only
            reported by the generic fallback message.

    Returns:
        A human-readable message starting with "Failed to retrieve job result:".
    """
    # Package import issues
    if isinstance(error, ImportError):
        return f"Failed to retrieve job result: Package not installed in your local environment. Error: {str(error)}"

    # Package versions differ between runtime and local environment
    if isinstance(error, AttributeError):
        return f"Failed to retrieve job result: Package version mismatch. Error: {str(error)}"

    # Serialization issues
    if isinstance(error, TypeError):
        return f"Failed to retrieve job result: Non-serializable objects were returned. Error: {str(error)}"

    # Python version pickling incompatibility
    if isinstance(error, pickle.UnpicklingError) and "protocol" in str(error).lower():
        # TODO: Update this once we support different Python versions
        client_version = f"Python {sys.version_info.major}.{sys.version_info.minor}"
        runtime_version = "Python 3.10"
        # Both version strings already carry the "Python" prefix.
        return (
            f"Failed to retrieve job result: Python version mismatch - job ran on {runtime_version}, "
            f"local environment using {client_version}. Error: {str(error)}"
        )

    # File access issues
    if isinstance(error, sp_exceptions.SnowparkSQLException):
        error_text = str(error).lower()
        if "not found" in error_text or "does not exist" in error_text:
            return (
                f"Failed to retrieve job result: No result file found at {result_path}. Check job.get_logs() "
                f"for execution errors. Error: {str(error)}"
            )
        return f"Failed to retrieve job result: Cannot access result file. Error: {str(error)}"

    if isinstance(error, MemoryError):
        return f"Failed to retrieve job result: Result too large for memory. Error: {str(error)}"

    # Generic fallback
    base_message = f"Failed to retrieve job result: {str(error)}"
    if json_error is not None:
        base_message += f" (JSON fallback also failed: {str(json_error)})"
    return base_message
|
88
147
|
|
89
148
|
|
90
149
|
def load_exception(exc_type_name: str, exc_value: Union[Exception, str], exc_tb: str) -> Exception:
|