snowflake-ml-python 1.8.4__py3-none-any.whl → 1.8.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/telemetry.py +42 -16
- snowflake/ml/_internal/utils/connection_params.py +196 -0
- snowflake/ml/data/data_connector.py +1 -1
- snowflake/ml/jobs/__init__.py +2 -0
- snowflake/ml/jobs/_utils/constants.py +12 -2
- snowflake/ml/jobs/_utils/function_payload_utils.py +43 -0
- snowflake/ml/jobs/_utils/interop_utils.py +1 -1
- snowflake/ml/jobs/_utils/payload_utils.py +95 -39
- snowflake/ml/jobs/_utils/scripts/constants.py +22 -0
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +67 -2
- snowflake/ml/jobs/_utils/spec_utils.py +30 -6
- snowflake/ml/jobs/_utils/stage_utils.py +119 -0
- snowflake/ml/jobs/_utils/types.py +5 -1
- snowflake/ml/jobs/decorators.py +10 -7
- snowflake/ml/jobs/job.py +176 -28
- snowflake/ml/jobs/manager.py +119 -26
- snowflake/ml/model/_client/model/model_impl.py +58 -0
- snowflake/ml/model/_client/model/model_version_impl.py +90 -0
- snowflake/ml/model/_client/ops/model_ops.py +6 -3
- snowflake/ml/model/_client/ops/service_ops.py +24 -7
- snowflake/ml/model/_client/service/model_deployment_spec.py +11 -0
- snowflake/ml/model/_client/sql/model_version.py +1 -1
- snowflake/ml/model/_client/sql/service.py +73 -28
- snowflake/ml/model/_client/sql/stage.py +5 -2
- snowflake/ml/model/_model_composer/model_composer.py +3 -1
- snowflake/ml/model/_packager/model_handlers/sklearn.py +1 -1
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +103 -73
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -2
- snowflake/ml/model/_signatures/core.py +24 -0
- snowflake/ml/monitoring/explain_visualize.py +160 -22
- snowflake/ml/monitoring/model_monitor.py +0 -4
- snowflake/ml/registry/registry.py +34 -14
- snowflake/ml/utils/connection_params.py +9 -3
- snowflake/ml/utils/html_utils.py +263 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/METADATA +40 -13
- {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/RECORD +40 -37
- {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/WHEEL +1 -1
- snowflake/ml/monitoring/model_monitor_version.py +0 -1
- {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/top_level.txt +0 -0
@@ -4,6 +4,7 @@ import enum
|
|
4
4
|
import functools
|
5
5
|
import inspect
|
6
6
|
import operator
|
7
|
+
import os
|
7
8
|
import sys
|
8
9
|
import time
|
9
10
|
import traceback
|
@@ -13,7 +14,7 @@ from typing import Any, Callable, Iterable, Mapping, Optional, TypeVar, Union, c
|
|
13
14
|
from typing_extensions import ParamSpec
|
14
15
|
|
15
16
|
from snowflake import connector
|
16
|
-
from snowflake.connector import telemetry as connector_telemetry, time_util
|
17
|
+
from snowflake.connector import connect, telemetry as connector_telemetry, time_util
|
17
18
|
from snowflake.ml import version as snowml_version
|
18
19
|
from snowflake.ml._internal import env
|
19
20
|
from snowflake.ml._internal.exceptions import (
|
@@ -37,6 +38,37 @@ _Args = ParamSpec("_Args")
|
|
37
38
|
_ReturnValue = TypeVar("_ReturnValue")
|
38
39
|
|
39
40
|
|
41
|
+
def _get_login_token() -> Union[str, bytes]:
    """Return the full contents of the session token file.

    The token is read in text mode from ``/snowflake/session/token`` and
    returned exactly as stored. Any error raised while opening or reading the
    file propagates to the caller.
    """
    with open("/snowflake/session/token") as token_file:
        token = token_file.read()
    return token
|
44
|
+
|
45
|
+
|
46
|
+
def _get_snowflake_connection() -> Optional[connector.SnowflakeConnection]:
    """Locate a Snowflake connection to use for sending telemetry.

    Two sources are tried in order:

    1. When both ``SNOWFLAKE_HOST`` and ``SNOWFLAKE_ACCOUNT`` are set, a new
       OAuth connection is opened with the token from ``_get_login_token()``.
    2. Otherwise, or if step 1 raised, the connection of an active Snowpark
       session is used, provided that session has telemetry enabled.

    Returns:
        A connection, or ``None`` when neither source yields one.
    """
    host = os.getenv("SNOWFLAKE_HOST")
    account = os.getenv("SNOWFLAKE_ACCOUNT")

    connection = None
    if not (host is None or account is None):
        try:
            connection = connect(
                host=host,
                account=account,
                token=_get_login_token(),
                authenticator="oauth",
            )
        except Exception:
            # Could not open a fresh SnowflakeConnection in SPCS; deliberately
            # best-effort, so fall through to the active session below.
            # This will work in some cases once SPCS enables multiple
            # authentication modes, and users select any auth.
            pass

    if connection is not None:
        return connection

    try:
        current_session = next(iter(session._get_active_sessions()))
        if current_session.telemetry_enabled:
            return current_session._conn._conn
        return None
    except snowpark_exceptions.SnowparkSessionException:
        # No active session, hence no connection available.
        return None
|
70
|
+
|
71
|
+
|
40
72
|
@enum.unique
|
41
73
|
class TelemetryProject(enum.Enum):
|
42
74
|
MLOPS = "MLOps"
|
@@ -378,13 +410,14 @@ def send_custom_usage(
|
|
378
410
|
data: Optional[dict[str, Any]] = None,
|
379
411
|
**kwargs: Any,
|
380
412
|
) -> None:
|
381
|
-
|
382
|
-
assert active_session, "Missing active session object"
|
413
|
+
conn = _get_snowflake_connection()
|
383
414
|
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
415
|
+
# Send telemetry if Snowflake connection is available.
|
416
|
+
if conn is not None:
|
417
|
+
client = _SourceTelemetryClient(conn=conn, project=project, subproject=subproject)
|
418
|
+
common_metrics = client._create_basic_telemetry_data(telemetry_type=telemetry_type)
|
419
|
+
data = {**common_metrics, TelemetryField.KEY_DATA.value: data, **kwargs}
|
420
|
+
client._send(msg=data)
|
388
421
|
|
389
422
|
|
390
423
|
def send_api_usage_telemetry(
|
@@ -501,7 +534,6 @@ def send_api_usage_telemetry(
|
|
501
534
|
return update_stmt_params_if_snowpark_df(result, statement_params)
|
502
535
|
|
503
536
|
# prioritize `conn_attr_name` over the active session
|
504
|
-
telemetry_enabled = True
|
505
537
|
if conn_attr_name:
|
506
538
|
# raise AttributeError if conn attribute does not exist in `self`
|
507
539
|
conn = operator.attrgetter(conn_attr_name)(args[0])
|
@@ -509,16 +541,10 @@ def send_api_usage_telemetry(
|
|
509
541
|
raise TypeError(
|
510
542
|
f"Expected a conn object of type {' or '.join(_CONNECTION_TYPES.keys())} but got {type(conn)}"
|
511
543
|
)
|
512
|
-
# get an active session
|
513
544
|
else:
|
514
|
-
|
515
|
-
active_session = next(iter(session._get_active_sessions()))
|
516
|
-
conn = active_session._conn._conn
|
517
|
-
telemetry_enabled = active_session.telemetry_enabled
|
518
|
-
except snowpark_exceptions.SnowparkSessionException:
|
519
|
-
conn = None
|
545
|
+
conn = _get_snowflake_connection()
|
520
546
|
|
521
|
-
if conn is None
|
547
|
+
if conn is None:
|
522
548
|
# Telemetry not enabled, just execute without our additional telemetry logic
|
523
549
|
try:
|
524
550
|
return ctx.run(execute_func_with_statement_params)
|
@@ -0,0 +1,196 @@
|
|
1
|
+
import configparser
|
2
|
+
import os
|
3
|
+
from typing import Optional, Union
|
4
|
+
|
5
|
+
from absl import logging
|
6
|
+
from cryptography.hazmat import backends
|
7
|
+
from cryptography.hazmat.primitives import serialization
|
8
|
+
|
9
|
+
_DEFAULT_CONNECTION_FILE = "~/.snowsql/config"
|
10
|
+
|
11
|
+
|
12
|
+
def _read_token(token_file: str = "") -> str:
|
13
|
+
"""
|
14
|
+
Reads token from environment or file provided.
|
15
|
+
|
16
|
+
First tries to read the token from environment variable
|
17
|
+
(`SNOWFLAKE_TOKEN`) followed by the token file.
|
18
|
+
Both the options are tried out in SnowServices.
|
19
|
+
|
20
|
+
Args:
|
21
|
+
token_file: File from which token needs to be read. Optional.
|
22
|
+
|
23
|
+
Returns:
|
24
|
+
the token.
|
25
|
+
"""
|
26
|
+
token = os.getenv("SNOWFLAKE_TOKEN", "")
|
27
|
+
if token:
|
28
|
+
return token
|
29
|
+
if token_file and os.path.exists(token_file):
|
30
|
+
with open(token_file) as f:
|
31
|
+
token = f.read()
|
32
|
+
return token
|
33
|
+
|
34
|
+
|
35
|
+
_ENCRYPTED_PKCS8_PK_HEADER = b"-----BEGIN ENCRYPTED PRIVATE KEY-----"
|
36
|
+
_UNENCRYPTED_PKCS8_PK_HEADER = b"-----BEGIN PRIVATE KEY-----"
|
37
|
+
|
38
|
+
|
39
|
+
def _load_pem_to_der(private_key_path: str) -> bytes:
    """Read a PEM-encoded PKCS#8 private key file and return it as unencrypted DER bytes.

    The passphrase for an encrypted key is taken from the
    ``SNOWFLAKE_PRIVATE_KEY_PASSPHRASE`` environment variable; it is ignored
    for an unencrypted key.

    Args:
        private_key_path: Path of the PEM private key file.

    Returns:
        The key serialized as DER / PKCS#8 without encryption.

    Raises:
        Exception: If the file does not start with a PKCS#8 PEM header, or the
            key is encrypted and no passphrase is set in the environment.
    """
    with open(private_key_path, "rb") as key_file:
        pem_bytes = key_file.read()
    passphrase: Optional[str] = os.getenv("SNOWFLAKE_PRIVATE_KEY_PASSPHRASE", None)

    # Only PKCS#8 keys are accepted. OpenSSL transparently handles both PKCS#8
    # and PKCS#1 (through fallback logic) and there is no function to tell
    # them apart; judging by its source code it also relies on the PEM header
    # to decide whether the given bytes are PKCS#8, so the same check is used.
    is_encrypted = pem_bytes.startswith(_ENCRYPTED_PKCS8_PK_HEADER)
    is_unencrypted = pem_bytes.startswith(_UNENCRYPTED_PKCS8_PK_HEADER)

    if not (is_encrypted or is_unencrypted):
        raise Exception("Private key provided is not in PKCS#8 format. Please use correct format.")

    if is_encrypted and passphrase is None:
        raise Exception(
            "Private key is encrypted but passphrase could not be found. "
            "Please set SNOWFLAKE_PRIVATE_KEY_PASSPHRASE env variable."
        )

    # An unencrypted key must be loaded without a password, even if a
    # passphrase happens to be set in the environment.
    if is_unencrypted or passphrase is None:
        password = None
    else:
        password = passphrase.encode()

    loaded_key = serialization.load_pem_private_key(
        pem_bytes,
        password,
        backends.default_backend(),
    )
    return loaded_key.private_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
|
75
|
+
|
76
|
+
|
77
|
+
def _connection_properties_from_env() -> dict[str, str]:
|
78
|
+
"""Returns a dict with all possible login related env variables."""
|
79
|
+
sf_conn_prop = {
|
80
|
+
# Mandatory fields
|
81
|
+
"account": os.environ["SNOWFLAKE_ACCOUNT"],
|
82
|
+
"database": os.environ["SNOWFLAKE_DATABASE"],
|
83
|
+
# With a default value
|
84
|
+
"token_file": os.getenv("SNOWFLAKE_TOKEN_FILE", "/snowflake/session/token"),
|
85
|
+
"ssl": os.getenv("SNOWFLAKE_SSL", "on"),
|
86
|
+
"protocol": os.getenv("SNOWFLAKE_PROTOCOL", "https"),
|
87
|
+
}
|
88
|
+
# With empty default value
|
89
|
+
for key, env_var in {
|
90
|
+
"user": "SNOWFLAKE_USER",
|
91
|
+
"authenticator": "SNOWFLAKE_AUTHENTICATOR",
|
92
|
+
"password": "SNOWFLAKE_PASSWORD",
|
93
|
+
"host": "SNOWFLAKE_HOST",
|
94
|
+
"port": "SNOWFLAKE_PORT",
|
95
|
+
"schema": "SNOWFLAKE_SCHEMA",
|
96
|
+
"warehouse": "SNOWFLAKE_WAREHOUSE",
|
97
|
+
"private_key_path": "SNOWFLAKE_PRIVATE_KEY_PATH",
|
98
|
+
}.items():
|
99
|
+
value = os.getenv(env_var, "")
|
100
|
+
if value:
|
101
|
+
sf_conn_prop[key] = value
|
102
|
+
return sf_conn_prop
|
103
|
+
|
104
|
+
|
105
|
+
def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -> dict[str, str]:
|
106
|
+
"""Loads the dictionary from snowsql config file."""
|
107
|
+
snowsql_config_file = login_file if login_file else os.path.expanduser(_DEFAULT_CONNECTION_FILE)
|
108
|
+
if not os.path.exists(snowsql_config_file):
|
109
|
+
logging.error(f"Connection name given but snowsql config file is not found at: {snowsql_config_file}")
|
110
|
+
raise Exception("Snowflake SnowSQL config not found.")
|
111
|
+
|
112
|
+
config = configparser.ConfigParser(inline_comment_prefixes="#")
|
113
|
+
|
114
|
+
snowflake_connection_name = os.getenv("SNOWFLAKE_CONNECTION_NAME")
|
115
|
+
if snowflake_connection_name is not None:
|
116
|
+
connection_name = snowflake_connection_name
|
117
|
+
|
118
|
+
if connection_name:
|
119
|
+
if not connection_name.startswith("connections."):
|
120
|
+
connection_name = "connections." + connection_name
|
121
|
+
else:
|
122
|
+
# See https://docs.snowflake.com/en/user-guide/snowsql-start.html#configuring-default-connection-settings
|
123
|
+
connection_name = "connections"
|
124
|
+
|
125
|
+
logging.info(f"Reading {snowsql_config_file} for connection parameters defined as {connection_name}")
|
126
|
+
config.read(snowsql_config_file)
|
127
|
+
conn_params = dict(config[connection_name])
|
128
|
+
# Remap names to appropriate args in Python Connector API
|
129
|
+
# Note: "dbname" should become "database"
|
130
|
+
conn_params = {k.replace("name", ""): v.strip('"') for k, v in conn_params.items()}
|
131
|
+
if "db" in conn_params:
|
132
|
+
conn_params["database"] = conn_params["db"]
|
133
|
+
del conn_params["db"]
|
134
|
+
return conn_params
|
135
|
+
|
136
|
+
|
137
|
+
def SnowflakeLoginOptions(connection_name: str = "", login_file: Optional[str] = None) -> dict[str, Union[str, bytes]]:
    """Build connection parameters for the Snowflake Python connector or a Snowpark session.

    Parameters come from exactly one source: the SnowSQL config file when it
    exists, otherwise environment variables (selected by the presence of
    ``SNOWFLAKE_ACCOUNT``). The two sources are never mixed.

    A token is then side-loaded regardless of the source:
        1. ``SNOWFLAKE_TOKEN`` from the environment, if set.
        2. Otherwise the contents of the ``token_file`` property, if that file exists.

    When a token is found it is stored under ``token`` and ``authenticator``
    is set to ``oauth``, unless ``authenticator`` is present with an empty
    value. When no token is found and ``private_key_path`` is set without
    ``private_key``, the key file is loaded into ``private_key`` as DER bytes.
    Finally, an ``ssl`` value of ``off`` (case-insensitive) forces
    ``protocol`` to ``http``.

    Python Connector:
    >> ctx = snowflake.connector.connect(**(SnowflakeLoginOptions()))

    Snowpark Session:
    >> session = Session.builder.configs(SnowflakeLoginOptions()).create()

    Usage Note:
        Ideally one should have a snowsql config file. Read more here:
        https://docs.snowflake.com/en/user-guide/snowsql-start.html#configuring-default-connection-settings

    Args:
        connection_name: Name of the connection to look for inside the config file. If environment variable
            SNOWFLAKE_CONNECTION_NAME is provided, it will override the input connection_name.
        login_file: If provided, this is used as config file instead of default one (_DEFAULT_CONNECTION_FILE).

    Returns:
        A dict with connection parameters.

    Raises:
        Exception: if neither the config file nor the environment variables are present.
    """
    login_file = login_file or os.path.expanduser(_DEFAULT_CONNECTION_FILE)

    conn_prop: dict[str, Union[str, bytes]] = {}
    if os.path.exists(login_file):
        # A login file, when present, is the exclusive source.
        conn_prop.update(_load_from_snowsql_config_file(connection_name, login_file))
    elif os.getenv("SNOWFLAKE_ACCOUNT", ""):
        # SNOWFLAKE_ACCOUNT is set, so everything is assumed to come from the
        # environment. Mixing sources is not allowed.
        conn_prop.update(_connection_properties_from_env())
    else:
        raise Exception("Snowflake credential is neither set in env nor a login file was provided.")

    # Token, if specified, is always side-loaded in all cases.
    token_path = str(conn_prop["token_file"]) if "token_file" in conn_prop else ""
    token = _read_token(token_path)
    if token:
        conn_prop["token"] = token
        explicitly_blank = "authenticator" in conn_prop and not conn_prop["authenticator"]
        if not explicitly_blank:
            conn_prop["authenticator"] = "oauth"
    elif "private_key_path" in conn_prop and "private_key" not in conn_prop:
        key_path = str(conn_prop["private_key_path"])
        conn_prop["private_key"] = _load_pem_to_der(key_path)

    if "ssl" in conn_prop and conn_prop["ssl"].lower() == "off":
        conn_prop["protocol"] = "http"

    return conn_prop
|
@@ -249,7 +249,7 @@ class DataConnector:
|
|
249
249
|
|
250
250
|
# Switch to use Runtime's Data Ingester if running in ML runtime
|
251
251
|
# Fail silently if the data ingester is not found
|
252
|
-
if env.IN_ML_RUNTIME and os.getenv(env.USE_OPTIMIZED_DATA_INGESTOR):
|
252
|
+
if env.IN_ML_RUNTIME and os.getenv(env.USE_OPTIMIZED_DATA_INGESTOR, "").lower() in ("true", "1"):
|
253
253
|
try:
|
254
254
|
from runtime_external_entities import get_ingester_class
|
255
255
|
|
snowflake/ml/jobs/__init__.py
CHANGED
@@ -5,6 +5,8 @@ from snowflake.ml.jobs._utils.types import ComputeResources
|
|
5
5
|
DEFAULT_CONTAINER_NAME = "main"
|
6
6
|
PAYLOAD_DIR_ENV_VAR = "MLRS_PAYLOAD_DIR"
|
7
7
|
RESULT_PATH_ENV_VAR = "MLRS_RESULT_PATH"
|
8
|
+
MIN_INSTANCES_ENV_VAR = "MLRS_MIN_INSTANCES"
|
9
|
+
RUNTIME_IMAGE_TAG_ENV_VAR = "MLRS_CONTAINER_IMAGE_TAG"
|
8
10
|
MEMORY_VOLUME_NAME = "dshm"
|
9
11
|
STAGE_VOLUME_NAME = "stage-volume"
|
10
12
|
STAGE_VOLUME_MOUNT_PATH = "/mnt/app"
|
@@ -13,7 +15,7 @@ STAGE_VOLUME_MOUNT_PATH = "/mnt/app"
|
|
13
15
|
DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
|
14
16
|
DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks"
|
15
17
|
DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks"
|
16
|
-
DEFAULT_IMAGE_TAG = "1.2
|
18
|
+
DEFAULT_IMAGE_TAG = "1.4.2"
|
17
19
|
DEFAULT_ENTRYPOINT_PATH = "func.py"
|
18
20
|
|
19
21
|
# Percent of container memory to allocate for /dev/shm volume
|
@@ -37,16 +39,24 @@ RAY_PORTS = {
|
|
37
39
|
# Node health check configuration
|
38
40
|
# TODO(SNOW-1937020): Revisit the health check configuration
|
39
41
|
ML_RUNTIME_HEALTH_CHECK_PORT = "5001"
|
42
|
+
ENABLE_HEALTH_CHECKS_ENV_VAR = "ENABLE_HEALTH_CHECKS"
|
40
43
|
ENABLE_HEALTH_CHECKS = "false"
|
41
44
|
|
42
45
|
# Job status polling constants
|
43
46
|
JOB_POLL_INITIAL_DELAY_SECONDS = 0.1
|
44
|
-
JOB_POLL_MAX_DELAY_SECONDS =
|
47
|
+
JOB_POLL_MAX_DELAY_SECONDS = 30
|
45
48
|
|
46
49
|
# Magic attributes
|
47
50
|
IS_MLJOB_REMOTE_ATTR = "_is_mljob_remote_callable"
|
48
51
|
RESULT_PATH_DEFAULT_VALUE = "mljob_result.pkl"
|
49
52
|
|
53
|
+
# Log start and end messages
|
54
|
+
LOG_START_MSG = "--------------------------------\nML job started\n--------------------------------"
|
55
|
+
LOG_END_MSG = "--------------------------------\nML job finished\n--------------------------------"
|
56
|
+
|
57
|
+
# Default setting for verbose logging in get_log function
|
58
|
+
DEFAULT_VERBOSE_LOG = False
|
59
|
+
|
50
60
|
# Compute pool resource information
|
51
61
|
# TODO: Query Snowflake for resource information instead of relying on this hardcoded
|
52
62
|
# table from https://docs.snowflake.com/en/sql-reference/sql/create-compute-pool
|
@@ -0,0 +1,43 @@
|
|
1
|
+
import inspect
|
2
|
+
from typing import Any, Callable, Optional
|
3
|
+
|
4
|
+
from snowflake import snowpark
|
5
|
+
from snowflake.snowpark import context as sp_context
|
6
|
+
|
7
|
+
|
8
|
+
class FunctionPayload:
    """A callable that bundles a function, its arguments and a Snowpark session.

    Calling the payload invokes the wrapped function with the stored
    arguments, passing the session under the parameter name given by
    ``session_argument``. The session itself is never pickled; it is replaced
    by the active session when the payload is unpickled.
    """

    def __init__(
        self,
        func: Callable[..., Any],
        session: Optional[snowpark.Session] = None,
        session_argument: str = "",
        *args: Any,
        **kwargs: Any
    ) -> None:
        """Store the function, the session and the call arguments.

        Args:
            func: The function to invoke when the payload is called.
            session: Session to pass to ``func``.
            session_argument: Name of the ``func`` parameter that receives the session.
            *args: Positional arguments for ``func``.
            **kwargs: Keyword arguments for ``func``.
        """
        self.function = func
        self.args = args
        self.kwargs = kwargs
        self._session = session
        self._session_argument = session_argument

    @property
    def session(self) -> Optional[snowpark.Session]:
        """The session that is passed to the wrapped function."""
        return self._session

    def __getstate__(self) -> dict[str, Any]:
        """Return the instance state for pickling, with the session blanked out."""
        return {**self.__dict__, "_session": None}

    def __setstate__(self, state: dict[str, Any]) -> None:
        """Restore pickled state, then attach the currently active session."""
        vars(self).update(state)
        self._session = sp_context.get_active_session()

    def __call__(self) -> Any:
        """Invoke the wrapped function with the stored arguments plus the session."""
        signature = inspect.signature(self.function)
        call_arguments = signature.bind_partial(*self.args, **self.kwargs)
        # Supply the session under the configured parameter name.
        call_arguments.arguments[self._session_argument] = self._session
        return self.function(*call_arguments.args, **call_arguments.kwargs)
|
@@ -80,7 +80,7 @@ def fetch_result(session: snowpark.Session, result_path: str) -> ExecutionResult
|
|
80
80
|
# TODO: Check if file exists
|
81
81
|
with session.file.get_stream(result_path) as result_stream:
|
82
82
|
return ExecutionResult.from_dict(pickle.load(result_stream))
|
83
|
-
except (sp_exceptions.SnowparkSQLException, TypeError,
|
83
|
+
except (sp_exceptions.SnowparkSQLException, pickle.UnpicklingError, TypeError, ImportError):
|
84
84
|
# Fall back to JSON result if loading pickled result fails for any reason
|
85
85
|
result_json_path = os.path.splitext(result_path)[0] + ".json"
|
86
86
|
with session.file.get_stream(result_json_path) as result_stream:
|