snowflake-ml-python 1.18.0__py3-none-any.whl → 1.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +16 -0
- snowflake/ml/_internal/telemetry.py +56 -7
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +1 -7
- snowflake/ml/experiment/_entities/run.py +15 -0
- snowflake/ml/experiment/experiment_tracking.py +61 -73
- snowflake/ml/feature_store/access_manager.py +1 -0
- snowflake/ml/feature_store/feature_store.py +86 -31
- snowflake/ml/feature_store/feature_view.py +12 -6
- snowflake/ml/fileset/stage_fs.py +12 -1
- snowflake/ml/jobs/_utils/feature_flags.py +1 -0
- snowflake/ml/jobs/_utils/payload_utils.py +6 -1
- snowflake/ml/jobs/_utils/spec_utils.py +12 -3
- snowflake/ml/jobs/job.py +8 -3
- snowflake/ml/jobs/manager.py +19 -6
- snowflake/ml/model/_client/model/inference_engine_utils.py +8 -4
- snowflake/ml/model/_client/model/model_version_impl.py +45 -17
- snowflake/ml/model/_client/ops/model_ops.py +11 -4
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
- snowflake/ml/model/models/huggingface_pipeline.py +6 -7
- snowflake/ml/monitoring/explain_visualize.py +3 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.18.0.dist-info → snowflake_ml_python-1.20.0.dist-info}/METADATA +68 -5
- {snowflake_ml_python-1.18.0.dist-info → snowflake_ml_python-1.20.0.dist-info}/RECORD +26 -26
- {snowflake_ml_python-1.18.0.dist-info → snowflake_ml_python-1.20.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.18.0.dist-info → snowflake_ml_python-1.20.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.18.0.dist-info → snowflake_ml_python-1.20.0.dist-info}/top_level.txt +0 -0
|
@@ -16,6 +16,7 @@ from snowflake.ml import version as snowml_version
|
|
|
16
16
|
from snowflake.ml._internal import env as snowml_env, relax_version_strategy
|
|
17
17
|
from snowflake.ml._internal.utils import query_result_checker
|
|
18
18
|
from snowflake.snowpark import context, exceptions, session
|
|
19
|
+
from snowflake.snowpark._internal import utils as snowpark_utils
|
|
19
20
|
|
|
20
21
|
|
|
21
22
|
class CONDA_OS(Enum):
|
|
@@ -38,6 +39,21 @@ SNOWPARK_ML_PKG_NAME = "snowflake-ml-python"
|
|
|
38
39
|
SNOWFLAKE_CONDA_CHANNEL_URL = "https://repo.anaconda.com/pkgs/snowflake"
|
|
39
40
|
|
|
40
41
|
|
|
42
|
+
def get_execution_context() -> str:
    """Report which kind of environment the current process is running in.

    The stored-procedure check is evaluated first and wins over the ML-runtime check.

    Returns:
        str: ``"SPROC"`` when running inside a stored procedure, ``"SPCS"`` when running
        in the SPCS ML runtime, and ``"EXTERNAL"`` for any other environment.
    """
    in_stored_procedure = snowpark_utils.is_in_stored_procedure()  # type: ignore[no-untyped-call]
    if in_stored_procedure:
        return "SPROC"
    return "SPCS" if snowml_env.IN_ML_RUNTIME else "EXTERNAL"
|
|
55
|
+
|
|
56
|
+
|
|
41
57
|
def _validate_pip_requirement_string(req_str: str) -> requirements.Requirement:
|
|
42
58
|
"""Validate the input pip requirement string according to PEP 508.
|
|
43
59
|
|
|
@@ -16,7 +16,7 @@ from typing_extensions import ParamSpec
|
|
|
16
16
|
from snowflake import connector
|
|
17
17
|
from snowflake.connector import connect, telemetry as connector_telemetry, time_util
|
|
18
18
|
from snowflake.ml import version as snowml_version
|
|
19
|
-
from snowflake.ml._internal import env
|
|
19
|
+
from snowflake.ml._internal import env, env_utils
|
|
20
20
|
from snowflake.ml._internal.exceptions import (
|
|
21
21
|
error_codes,
|
|
22
22
|
exceptions as snowml_exceptions,
|
|
@@ -37,6 +37,22 @@ _CONNECTION_TYPES = {
|
|
|
37
37
|
_Args = ParamSpec("_Args")
|
|
38
38
|
_ReturnValue = TypeVar("_ReturnValue")
|
|
39
39
|
|
|
40
|
+
_conn: Optional[connector.SnowflakeConnection] = None
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def clear_cached_conn() -> None:
    """Clear the cached Snowflake connection. Primarily for testing purposes.

    If a connection is cached and still reports itself as valid, it is closed first.
    The cache is reset even when ``close()`` raises. Otherwise a half-closed connection
    would remain cached after a failed close. Any exception from ``close()`` still
    propagates to the caller.
    """
    global _conn
    try:
        if _conn is not None and _conn.is_valid():
            _conn.close()
    finally:
        # Always forget the cached connection, even if closing it failed.
        _conn = None
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def get_cached_conn() -> Optional[connector.SnowflakeConnection]:
    """Return the module-level cached Snowflake connection. Primarily for testing purposes.

    Returns:
        The cached connection object as-is (its validity is not checked), or None when
        nothing is cached.
    """
    # Reading a module global requires no ``global`` declaration.
    cached_connection = _conn
    return cached_connection
|
|
55
|
+
|
|
40
56
|
|
|
41
57
|
def _get_login_token() -> Union[str, bytes]:
|
|
42
58
|
with open("/snowflake/session/token") as f:
|
|
@@ -44,7 +60,11 @@ def _get_login_token() -> Union[str, bytes]:
|
|
|
44
60
|
|
|
45
61
|
|
|
46
62
|
def _get_snowflake_connection() -> Optional[connector.SnowflakeConnection]:
|
|
47
|
-
|
|
63
|
+
global _conn
|
|
64
|
+
if _conn is not None and _conn.is_valid():
|
|
65
|
+
return _conn
|
|
66
|
+
|
|
67
|
+
conn: Optional[connector.SnowflakeConnection] = None
|
|
48
68
|
if os.getenv("SNOWFLAKE_HOST") is not None and os.getenv("SNOWFLAKE_ACCOUNT") is not None:
|
|
49
69
|
try:
|
|
50
70
|
conn = connect(
|
|
@@ -66,6 +86,13 @@ def _get_snowflake_connection() -> Optional[connector.SnowflakeConnection]:
|
|
|
66
86
|
# Failed to get an active session. No connection available.
|
|
67
87
|
pass
|
|
68
88
|
|
|
89
|
+
# cache the connection if it's a SnowflakeConnection. there is a behavior at runtime where it could be a
|
|
90
|
+
# StoredProcConnection perhaps incorrect type hinting somewhere
|
|
91
|
+
if isinstance(conn, connector.SnowflakeConnection):
|
|
92
|
+
# if _conn was expired, we need to copy telemetry data to new connection
|
|
93
|
+
if _conn is not None and conn is not None:
|
|
94
|
+
conn._telemetry._log_batch.extend(_conn._telemetry._log_batch)
|
|
95
|
+
_conn = conn
|
|
69
96
|
return conn
|
|
70
97
|
|
|
71
98
|
|
|
@@ -113,6 +140,13 @@ class TelemetryField(enum.Enum):
|
|
|
113
140
|
FUNC_CAT_USAGE = "usage"
|
|
114
141
|
|
|
115
142
|
|
|
143
|
+
@enum.unique
class CustomTagKey(enum.Enum):
    """Well-known key names used inside telemetry ``custom_tags`` dictionaries."""

    # Key under which the detected execution context (EXTERNAL, SPCS or SPROC) is recorded.
    EXECUTION_CONTEXT = "execution_context"
|
|
148
|
+
|
|
149
|
+
|
|
116
150
|
class _TelemetrySourceType(enum.Enum):
|
|
117
151
|
# Automatically inferred telemetry/statement parameters
|
|
118
152
|
AUTO_TELEMETRY = "SNOWML_AUTO_TELEMETRY"
|
|
@@ -441,6 +475,7 @@ def send_api_usage_telemetry(
|
|
|
441
475
|
sfqids_extractor: Optional[Callable[..., list[str]]] = None,
|
|
442
476
|
subproject_extractor: Optional[Callable[[Any], str]] = None,
|
|
443
477
|
custom_tags: Optional[dict[str, Union[bool, int, str, float]]] = None,
|
|
478
|
+
log_execution_context: bool = True,
|
|
444
479
|
) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, _ReturnValue]]:
|
|
445
480
|
"""
|
|
446
481
|
Decorator that sends API usage telemetry and adds function usage statement parameters to the dataframe returned by
|
|
@@ -455,6 +490,8 @@ def send_api_usage_telemetry(
|
|
|
455
490
|
sfqids_extractor: Extract sfqids from `self`.
|
|
456
491
|
subproject_extractor: Extract subproject at runtime from `self`.
|
|
457
492
|
custom_tags: Custom tags.
|
|
493
|
+
log_execution_context: If True, automatically detect and log execution context
|
|
494
|
+
(EXTERNAL, SPCS, or SPROC) in custom_tags.
|
|
458
495
|
|
|
459
496
|
Returns:
|
|
460
497
|
Decorator that sends function usage telemetry for any call to the decorated function.
|
|
@@ -495,6 +532,11 @@ def send_api_usage_telemetry(
|
|
|
495
532
|
if subproject_extractor is not None:
|
|
496
533
|
subproject_name = subproject_extractor(args[0])
|
|
497
534
|
|
|
535
|
+
# Add execution context if enabled
|
|
536
|
+
final_custom_tags = {**custom_tags} if custom_tags is not None else {}
|
|
537
|
+
if log_execution_context:
|
|
538
|
+
final_custom_tags[CustomTagKey.EXECUTION_CONTEXT.value] = env_utils.get_execution_context()
|
|
539
|
+
|
|
498
540
|
statement_params = get_function_usage_statement_params(
|
|
499
541
|
project=project,
|
|
500
542
|
subproject=subproject_name,
|
|
@@ -502,7 +544,7 @@ def send_api_usage_telemetry(
|
|
|
502
544
|
function_name=_get_full_func_name(func),
|
|
503
545
|
function_parameters=params,
|
|
504
546
|
api_calls=api_calls,
|
|
505
|
-
custom_tags=
|
|
547
|
+
custom_tags=final_custom_tags,
|
|
506
548
|
)
|
|
507
549
|
|
|
508
550
|
def update_stmt_params_if_snowpark_df(obj: _ReturnValue, statement_params: dict[str, Any]) -> _ReturnValue:
|
|
@@ -538,7 +580,10 @@ def send_api_usage_telemetry(
|
|
|
538
580
|
if conn_attr_name:
|
|
539
581
|
# raise AttributeError if conn attribute does not exist in `self`
|
|
540
582
|
conn = operator.attrgetter(conn_attr_name)(args[0])
|
|
541
|
-
if not isinstance(
|
|
583
|
+
if not isinstance(
|
|
584
|
+
conn,
|
|
585
|
+
_CONNECTION_TYPES.get(type(conn).__name__, connector.SnowflakeConnection),
|
|
586
|
+
):
|
|
542
587
|
raise TypeError(
|
|
543
588
|
f"Expected a conn object of type {' or '.join(_CONNECTION_TYPES.keys())} but got {type(conn)}"
|
|
544
589
|
)
|
|
@@ -560,7 +605,7 @@ def send_api_usage_telemetry(
|
|
|
560
605
|
func_params=params,
|
|
561
606
|
api_calls=api_calls,
|
|
562
607
|
sfqids=sfqids,
|
|
563
|
-
custom_tags=
|
|
608
|
+
custom_tags=final_custom_tags,
|
|
564
609
|
)
|
|
565
610
|
try:
|
|
566
611
|
return ctx.run(execute_func_with_statement_params)
|
|
@@ -571,7 +616,8 @@ def send_api_usage_telemetry(
|
|
|
571
616
|
raise
|
|
572
617
|
if isinstance(e, snowpark_exceptions.SnowparkClientException):
|
|
573
618
|
me = snowml_exceptions.SnowflakeMLException(
|
|
574
|
-
error_code=error_codes.INTERNAL_SNOWPARK_ERROR,
|
|
619
|
+
error_code=error_codes.INTERNAL_SNOWPARK_ERROR,
|
|
620
|
+
original_exception=e,
|
|
575
621
|
)
|
|
576
622
|
else:
|
|
577
623
|
me = snowml_exceptions.SnowflakeMLException(
|
|
@@ -627,7 +673,10 @@ def _get_full_func_name(func: Callable[..., Any]) -> str:
|
|
|
627
673
|
|
|
628
674
|
|
|
629
675
|
def _get_func_params(
|
|
630
|
-
func: Callable[..., Any],
|
|
676
|
+
func: Callable[..., Any],
|
|
677
|
+
func_params_to_log: Optional[Iterable[str]],
|
|
678
|
+
args: Any,
|
|
679
|
+
kwargs: Any,
|
|
631
680
|
) -> dict[str, Any]:
|
|
632
681
|
"""
|
|
633
682
|
Get function parameters.
|
|
@@ -128,13 +128,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
|
|
|
128
128
|
run_name: sql_identifier.SqlIdentifier,
|
|
129
129
|
artifact_path: str,
|
|
130
130
|
) -> list[artifact.ArtifactInfo]:
|
|
131
|
-
results = (
|
|
132
|
-
query_result_checker.SqlResultValidator(
|
|
133
|
-
self._session, f"LIST {self._build_snow_uri(experiment_name, run_name, artifact_path)}"
|
|
134
|
-
)
|
|
135
|
-
.has_dimensions(expected_cols=4)
|
|
136
|
-
.validate()
|
|
137
|
-
)
|
|
131
|
+
results = self._session.sql(f"LIST {self._build_snow_uri(experiment_name, run_name, artifact_path)}").collect()
|
|
138
132
|
return [
|
|
139
133
|
artifact.ArtifactInfo(
|
|
140
134
|
name=str(result.name).removeprefix(f"/versions/{run_name}/"),
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import types
|
|
2
|
+
import warnings
|
|
2
3
|
from typing import TYPE_CHECKING, Optional
|
|
3
4
|
|
|
4
5
|
from snowflake.ml._internal.utils import sql_identifier
|
|
@@ -7,6 +8,8 @@ from snowflake.ml.experiment import _experiment_info as experiment_info
|
|
|
7
8
|
if TYPE_CHECKING:
|
|
8
9
|
from snowflake.ml.experiment import experiment_tracking
|
|
9
10
|
|
|
11
|
+
METADATA_SIZE_WARNING_MESSAGE = "It is likely that no further metrics or parameters will be logged for this run."
|
|
12
|
+
|
|
10
13
|
|
|
11
14
|
class Run:
|
|
12
15
|
def __init__(
|
|
@@ -20,6 +23,9 @@ class Run:
|
|
|
20
23
|
self.experiment_name = experiment_name
|
|
21
24
|
self.name = run_name
|
|
22
25
|
|
|
26
|
+
# Whether we've already shown the user a warning about exceeding the run metadata size limit.
|
|
27
|
+
self._warned_about_metadata_size = False
|
|
28
|
+
|
|
23
29
|
self._patcher = experiment_info.ExperimentInfoPatcher(
|
|
24
30
|
experiment_info=self._get_experiment_info(),
|
|
25
31
|
)
|
|
@@ -45,3 +51,12 @@ class Run:
|
|
|
45
51
|
),
|
|
46
52
|
run_name=self.name.identifier(),
|
|
47
53
|
)
|
|
54
|
+
|
|
55
|
+
def _warn_about_run_metadata_size(self, sql_error_msg: str) -> None:
    """Emit a RuntimeWarning about exceeding the run metadata size limit, at most once per run.

    Once a warning has been issued, ``_warned_about_metadata_size`` is set and later
    calls do nothing.

    Args:
        sql_error_msg: The SQL error message, used as the prefix of the warning text.
    """
    if self._warned_about_metadata_size:
        return
    # stacklevel=2 attributes the warning to the caller of this helper.
    warnings.warn(
        f"{sql_error_msg}. {METADATA_SIZE_WARNING_MESSAGE}",
        RuntimeWarning,
        stacklevel=2,
    )
    self._warned_about_metadata_size = True
|
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
import functools
|
|
2
2
|
import json
|
|
3
3
|
import sys
|
|
4
|
-
from typing import Any,
|
|
4
|
+
from typing import Any, Optional, Union
|
|
5
5
|
from urllib.parse import quote
|
|
6
6
|
|
|
7
7
|
from snowflake import snowpark
|
|
8
8
|
from snowflake.ml import model as ml_model, registry
|
|
9
9
|
from snowflake.ml._internal.human_readable_id import hrid_generator
|
|
10
|
-
from snowflake.ml._internal.utils import
|
|
10
|
+
from snowflake.ml._internal.utils import connection_params, sql_identifier
|
|
11
11
|
from snowflake.ml.experiment import (
|
|
12
12
|
_entities as entities,
|
|
13
13
|
_experiment_info as experiment_info,
|
|
@@ -21,34 +21,12 @@ from snowflake.ml.utils import sql_client as sql_client_utils
|
|
|
21
21
|
|
|
22
22
|
DEFAULT_EXPERIMENT_NAME = sql_identifier.SqlIdentifier("DEFAULT")
|
|
23
23
|
|
|
24
|
-
P = ParamSpec("P")
|
|
25
|
-
T = TypeVar("T")
|
|
26
24
|
|
|
27
|
-
|
|
28
|
-
def _restore_session(
|
|
29
|
-
func: Callable[Concatenate["ExperimentTracking", P], T],
|
|
30
|
-
) -> Callable[Concatenate["ExperimentTracking", P], T]:
|
|
31
|
-
@functools.wraps(func)
|
|
32
|
-
def wrapper(self: "ExperimentTracking", /, *args: P.args, **kwargs: P.kwargs) -> T:
|
|
33
|
-
if self._session is None:
|
|
34
|
-
if self._session_state is None:
|
|
35
|
-
raise RuntimeError(
|
|
36
|
-
f"Session is not set before calling {func.__name__}, and there is no session state to restore from"
|
|
37
|
-
)
|
|
38
|
-
self._set_session(self._session_state)
|
|
39
|
-
if self._session is None:
|
|
40
|
-
raise RuntimeError(f"Failed to restore session before calling {func.__name__}")
|
|
41
|
-
return func(self, *args, **kwargs)
|
|
42
|
-
|
|
43
|
-
return wrapper
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
25
|
+
class ExperimentTracking:
|
|
47
26
|
"""
|
|
48
27
|
Class to manage experiments in Snowflake.
|
|
49
28
|
"""
|
|
50
29
|
|
|
51
|
-
@snowpark._internal.utils.private_preview(version="1.9.1")
|
|
52
30
|
def __init__(
|
|
53
31
|
self,
|
|
54
32
|
session: snowpark.Session,
|
|
@@ -93,10 +71,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
93
71
|
database_name=self._database_name,
|
|
94
72
|
schema_name=self._schema_name,
|
|
95
73
|
)
|
|
96
|
-
self._session
|
|
97
|
-
# Used to store information about the session if the session could not be restored during unpickling
|
|
98
|
-
# _session_state is None if and only if _session is not None
|
|
99
|
-
self._session_state: Optional[mixins._SessionState] = None
|
|
74
|
+
self._session = session
|
|
100
75
|
|
|
101
76
|
# The experiment in context
|
|
102
77
|
self._experiment: Optional[entities.Experiment] = None
|
|
@@ -104,35 +79,40 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
104
79
|
self._run: Optional[entities.Run] = None
|
|
105
80
|
|
|
106
81
|
def __getstate__(self) -> dict[str, Any]:
    """Return a picklable snapshot of this tracker's state.

    The snapshot starts from the parent's ``__getstate__`` when ``super()`` provides one
    (``object`` does from Python 3.11), and from ``self.__dict__`` otherwise. The session,
    SQL client and registry are then set to None because they cannot be pickled.

    Returns:
        A shallow copy of the instance state with the unpicklable attributes set to None.
    """
    if hasattr(super(), "__getstate__"):
        base_state = super().__getstate__()  # type: ignore[misc] # object.__getstate__ appears in 3.11
    else:
        base_state = self.__dict__
    # Copy so that clearing attributes below does not modify the live instance.
    state = dict(base_state)
    for unpicklable_attr in ("_session", "_sql_client", "_registry"):
        state[unpicklable_attr] = None
    return state
|
|
112
94
|
|
|
113
|
-
def
|
|
114
|
-
|
|
115
|
-
super().
|
|
116
|
-
assert self._session is not None
|
|
117
|
-
except (snowpark.exceptions.SnowparkSessionException, AssertionError):
|
|
118
|
-
# If session was not set, store the session state
|
|
119
|
-
self._session = None
|
|
120
|
-
self._session_state = session_state
|
|
95
|
+
def __setstate__(self, state: dict[str, Any]) -> None:
    """Restore a tracker from pickled state and reconnect to Snowflake.

    Instance attributes are restored first. This uses the parent's ``__setstate__`` when
    one exists; otherwise ``self.__dict__`` is updated directly. A session is then obtained
    from ``SnowflakeLoginOptions()`` with ``client_session_keep_alive`` enabled. Finally,
    the SQL client and registry are recreated against the restored database and schema
    names.

    Args:
        state: The state dictionary previously produced by ``__getstate__``.
    """
    if not hasattr(super(), "__setstate__"):
        self.__dict__.update(state)
    else:
        super().__setstate__(state)  # type: ignore[misc]

    # Rebuild the attributes that were cleared before pickling.
    login_options: dict[str, Any] = connection_params.SnowflakeLoginOptions()
    login_options["client_session_keep_alive"] = True  # Needed for long-running training jobs
    restored_session = snowpark.Session.builder.configs(login_options).getOrCreate()
    self._session = restored_session
    self._sql_client = sql_client.ExperimentTrackingSQLClient(
        session=restored_session,
        database_name=self._database_name,
        schema_name=self._schema_name,
    )
    self._registry = registry.Registry(
        session=restored_session,
        database_name=self._database_name,
        schema_name=self._schema_name,
    )
|
|
134
115
|
|
|
135
|
-
@_restore_session
|
|
136
116
|
def set_experiment(
|
|
137
117
|
self,
|
|
138
118
|
experiment_name: str,
|
|
@@ -157,7 +137,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
157
137
|
self._run = None
|
|
158
138
|
return self._experiment
|
|
159
139
|
|
|
160
|
-
@_restore_session
|
|
161
140
|
def delete_experiment(
|
|
162
141
|
self,
|
|
163
142
|
experiment_name: str,
|
|
@@ -174,10 +153,8 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
174
153
|
self._run = None
|
|
175
154
|
|
|
176
155
|
@functools.wraps(registry.Registry.log_model)
|
|
177
|
-
@_restore_session
|
|
178
156
|
def log_model(
|
|
179
157
|
self,
|
|
180
|
-
/, # self needs to be a positional argument to stop mypy from complaining
|
|
181
158
|
model: Union[type_hints.SupportedModelType, ml_model.ModelVersion],
|
|
182
159
|
*,
|
|
183
160
|
model_name: str,
|
|
@@ -187,7 +164,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
187
164
|
with experiment_info.ExperimentInfoPatcher(experiment_info=run._get_experiment_info()):
|
|
188
165
|
return self._registry.log_model(model, model_name=model_name, **kwargs)
|
|
189
166
|
|
|
190
|
-
@_restore_session
|
|
191
167
|
def start_run(
|
|
192
168
|
self,
|
|
193
169
|
run_name: Optional[str] = None,
|
|
@@ -229,7 +205,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
229
205
|
self._run = entities.Run(experiment_tracking=self, experiment_name=experiment.name, run_name=run_name)
|
|
230
206
|
return self._run
|
|
231
207
|
|
|
232
|
-
@_restore_session
|
|
233
208
|
def end_run(self, run_name: Optional[str] = None) -> None:
|
|
234
209
|
"""
|
|
235
210
|
End the current run if no run name is provided. Otherwise, the specified run is ended.
|
|
@@ -259,7 +234,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
259
234
|
self._run = None
|
|
260
235
|
self._print_urls(experiment_name=experiment_name, run_name=run_name)
|
|
261
236
|
|
|
262
|
-
@_restore_session
|
|
263
237
|
def delete_run(
|
|
264
238
|
self,
|
|
265
239
|
run_name: str,
|
|
@@ -298,7 +272,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
298
272
|
"""
|
|
299
273
|
self.log_metrics(metrics={key: value}, step=step)
|
|
300
274
|
|
|
301
|
-
@_restore_session
|
|
302
275
|
def log_metrics(
|
|
303
276
|
self,
|
|
304
277
|
metrics: dict[str, float],
|
|
@@ -310,16 +283,26 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
310
283
|
Args:
|
|
311
284
|
metrics: Dictionary containing metric keys and float values.
|
|
312
285
|
step: The step of the metrics. Defaults to 0.
|
|
286
|
+
|
|
287
|
+
Raises:
|
|
288
|
+
snowpark.exceptions.SnowparkSQLException: If logging metrics fails due to Snowflake SQL errors,
|
|
289
|
+
except for run metadata size limit errors which will issue a warning instead of raising.
|
|
313
290
|
"""
|
|
314
291
|
run = self._get_or_start_run()
|
|
315
292
|
metrics_list = []
|
|
316
293
|
for key, value in metrics.items():
|
|
317
294
|
metrics_list.append(entities.Metric(key, value, step))
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
295
|
+
try:
|
|
296
|
+
self._sql_client.modify_run_add_metrics(
|
|
297
|
+
experiment_name=run.experiment_name,
|
|
298
|
+
run_name=run.name,
|
|
299
|
+
metrics=json.dumps([metric.to_dict() for metric in metrics_list]),
|
|
300
|
+
)
|
|
301
|
+
except snowpark.exceptions.SnowparkSQLException as e:
|
|
302
|
+
if e.sql_error_code == 400003: # EXPERIMENT_RUN_PROPERTY_SIZE_LIMIT_EXCEEDED
|
|
303
|
+
run._warn_about_run_metadata_size(e.message)
|
|
304
|
+
else:
|
|
305
|
+
raise
|
|
323
306
|
|
|
324
307
|
def log_param(
|
|
325
308
|
self,
|
|
@@ -335,7 +318,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
335
318
|
"""
|
|
336
319
|
self.log_params({key: value})
|
|
337
320
|
|
|
338
|
-
@_restore_session
|
|
339
321
|
def log_params(
|
|
340
322
|
self,
|
|
341
323
|
params: dict[str, Any],
|
|
@@ -346,18 +328,27 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
346
328
|
Args:
|
|
347
329
|
params: Dictionary containing parameter keys and values. Values can be of any type, but will be converted
|
|
348
330
|
to string.
|
|
331
|
+
|
|
332
|
+
Raises:
|
|
333
|
+
snowpark.exceptions.SnowparkSQLException: If logging parameters fails due to Snowflake SQL errors,
|
|
334
|
+
except for run metadata size limit errors which will issue a warning instead of raising.
|
|
349
335
|
"""
|
|
350
336
|
run = self._get_or_start_run()
|
|
351
337
|
params_list = []
|
|
352
338
|
for key, value in params.items():
|
|
353
339
|
params_list.append(entities.Param(key, str(value)))
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
340
|
+
try:
|
|
341
|
+
self._sql_client.modify_run_add_params(
|
|
342
|
+
experiment_name=run.experiment_name,
|
|
343
|
+
run_name=run.name,
|
|
344
|
+
params=json.dumps([param.to_dict() for param in params_list]),
|
|
345
|
+
)
|
|
346
|
+
except snowpark.exceptions.SnowparkSQLException as e:
|
|
347
|
+
if e.sql_error_code == 400003: # EXPERIMENT_RUN_PROPERTY_SIZE_LIMIT_EXCEEDED
|
|
348
|
+
run._warn_about_run_metadata_size(e.message)
|
|
349
|
+
else:
|
|
350
|
+
raise
|
|
359
351
|
|
|
360
|
-
@_restore_session
|
|
361
352
|
def log_artifact(
|
|
362
353
|
self,
|
|
363
354
|
local_path: str,
|
|
@@ -381,7 +372,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
381
372
|
file_path=file_path,
|
|
382
373
|
)
|
|
383
374
|
|
|
384
|
-
@_restore_session
|
|
385
375
|
def list_artifacts(
|
|
386
376
|
self,
|
|
387
377
|
run_name: str,
|
|
@@ -410,7 +400,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
410
400
|
artifact_path=artifact_path or "",
|
|
411
401
|
)
|
|
412
402
|
|
|
413
|
-
@_restore_session
|
|
414
403
|
def download_artifacts(
|
|
415
404
|
self,
|
|
416
405
|
run_name: str,
|
|
@@ -452,7 +441,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
452
441
|
return self._run
|
|
453
442
|
return self.start_run()
|
|
454
443
|
|
|
455
|
-
@_restore_session
|
|
456
444
|
def _generate_run_name(self, experiment: entities.Experiment) -> sql_identifier.SqlIdentifier:
|
|
457
445
|
generator = hrid_generator.HRID16()
|
|
458
446
|
existing_runs = self._sql_client.show_runs_in_experiment(experiment_name=experiment.name)
|
|
@@ -202,6 +202,7 @@ def _configure_role_hierarchy(
|
|
|
202
202
|
session.sql(f"GRANT ROLE {producer_role} TO ROLE {session.get_current_role()}").collect()
|
|
203
203
|
|
|
204
204
|
if consumer_role is not None:
|
|
205
|
+
# Create CONSUMER and grant it to PRODUCER to build hierarchy
|
|
205
206
|
consumer_role = SqlIdentifier(consumer_role)
|
|
206
207
|
session.sql(f"CREATE ROLE IF NOT EXISTS {consumer_role}").collect()
|
|
207
208
|
session.sql(f"GRANT ROLE {consumer_role} TO ROLE {producer_role}").collect()
|