snowflake-ml-python 1.17.0__py3-none-any.whl → 1.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/telemetry.py +3 -2
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +18 -19
- snowflake/ml/experiment/callback/keras.py +3 -0
- snowflake/ml/experiment/callback/lightgbm.py +3 -0
- snowflake/ml/experiment/callback/xgboost.py +3 -0
- snowflake/ml/experiment/experiment_tracking.py +50 -70
- snowflake/ml/feature_store/feature_store.py +299 -69
- snowflake/ml/feature_store/feature_view.py +12 -6
- snowflake/ml/fileset/stage_fs.py +12 -1
- snowflake/ml/jobs/_utils/constants.py +12 -1
- snowflake/ml/jobs/_utils/payload_utils.py +7 -1
- snowflake/ml/jobs/_utils/stage_utils.py +4 -0
- snowflake/ml/jobs/_utils/types.py +5 -0
- snowflake/ml/jobs/job.py +19 -5
- snowflake/ml/jobs/manager.py +18 -7
- snowflake/ml/model/__init__.py +19 -0
- snowflake/ml/model/_client/model/batch_inference_specs.py +63 -0
- snowflake/ml/model/_client/model/inference_engine_utils.py +1 -5
- snowflake/ml/model/_client/model/model_version_impl.py +129 -11
- snowflake/ml/model/_client/ops/model_ops.py +11 -4
- snowflake/ml/model/_client/ops/service_ops.py +3 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +3 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +4 -1
- snowflake/ml/model/_packager/model_handlers/_utils.py +70 -0
- snowflake/ml/model/_packager/model_handlers/prophet.py +566 -0
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
- snowflake/ml/model/type_hints.py +16 -0
- snowflake/ml/modeling/metrics/metrics_utils.py +9 -2
- snowflake/ml/monitoring/explain_visualize.py +3 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.19.0.dist-info}/METADATA +50 -4
- {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.19.0.dist-info}/RECORD +38 -37
- {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.19.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.19.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.19.0.dist-info}/top_level.txt +0 -0
snowflake/ml/_internal/telemetry.py

@@ -73,6 +73,7 @@ def _get_snowflake_connection() -> Optional[connector.SnowflakeConnection]:
 class TelemetryProject(enum.Enum):
     MLOPS = "MLOps"
     MODELING = "ModelDevelopment"
+    EXPERIMENT_TRACKING = "ExperimentTracking"
     # TODO: Update with remaining projects.
 
 
@@ -464,14 +465,14 @@ def send_api_usage_telemetry(
 
         # noqa: DAR402
     """
-    start_time = time.perf_counter()
-
    if subproject is not None and subproject_extractor is not None:
        raise ValueError("Specifying both subproject and subproject_extractor is not allowed")

    def decorator(func: Callable[_Args, _ReturnValue]) -> Callable[_Args, _ReturnValue]:
        @functools.wraps(func)
        def wrap(*args: Any, **kwargs: Any) -> _ReturnValue:
+            start_time = time.perf_counter()
+
             params = _get_func_params(func, func_params_to_log, args, kwargs) if func_params_to_log else None
 
             api_calls: list[Union[dict[str, Union[Callable[..., Any], str]], Callable[..., Any], str]] = []
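
The second hunk moves start_time from the body of send_api_usage_telemetry into the generated wrap function, so the timer now starts on every decorated call rather than once when the decorator is applied. A minimal sketch of the corrected pattern, not the snowflake-ml implementation (where the elapsed time ends up in the telemetry payload is an assumption):

# Minimal sketch: starting the clock inside the wrapper measures each call's
# own latency, whereas the old placement measured time since decoration.
import functools
import time
from typing import Any, Callable


def timed(func: Callable[..., Any]) -> Callable[..., Any]:
    @functools.wraps(func)
    def wrap(*args: Any, **kwargs: Any) -> Any:
        start_time = time.perf_counter()  # per-call start, as in the new wrap() body
        try:
            return func(*args, **kwargs)
        finally:
            print(f"{func.__name__} took {time.perf_counter() - start_time:.3f}s")

    return wrap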
snowflake/ml/experiment/_client/experiment_tracking_sql_client.py

@@ -1,17 +1,17 @@
 from typing import Optional
 
+from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import query_result_checker, sql_identifier
 from snowflake.ml.experiment._client import artifact
 from snowflake.ml.model._client.sql import _base
 from snowflake.ml.utils import sql_client
 from snowflake.snowpark import file_operation, row, session
 
+RUN_NAME_COL_NAME = "name"
+RUN_METADATA_COL_NAME = "metadata"
 
-class ExperimentTrackingSQLClient(_base._BaseSQLClient):
-
-    RUN_NAME_COL_NAME = "name"
-    RUN_METADATA_COL_NAME = "metadata"
 
+class ExperimentTrackingSQLClient(_base._BaseSQLClient):
     def __init__(
         self,
         session: session.Session,
@@ -28,6 +28,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
         """
         super().__init__(session, database_name=database_name, schema_name=schema_name)
 
+    @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
     def create_experiment(
         self,
         experiment_name: sql_identifier.SqlIdentifier,
@@ -39,24 +40,21 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
             self._session, f"CREATE EXPERIMENT {if_not_exists_sql} {experiment_fqn}"
         ).has_dimensions(expected_rows=1, expected_cols=1).validate()
 
+    @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
     def drop_experiment(self, *, experiment_name: sql_identifier.SqlIdentifier) -> None:
         experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
         query_result_checker.SqlResultValidator(self._session, f"DROP EXPERIMENT {experiment_fqn}").has_dimensions(
             expected_rows=1, expected_cols=1
         ).validate()
 
-    def add_run(
-        self,
-        *,
-        experiment_name: sql_identifier.SqlIdentifier,
-        run_name: sql_identifier.SqlIdentifier,
-        live: bool = True,
-    ) -> None:
+    @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
+    def add_run(self, *, experiment_name: sql_identifier.SqlIdentifier, run_name: sql_identifier.SqlIdentifier) -> None:
         experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
         query_result_checker.SqlResultValidator(
-            self._session, f"ALTER EXPERIMENT {experiment_fqn} ADD
+            self._session, f"ALTER EXPERIMENT {experiment_fqn} ADD RUN {run_name}"
         ).has_dimensions(expected_rows=1, expected_cols=1).validate()
 
+    @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
     def commit_run(
         self,
         *,
@@ -68,6 +66,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
             self._session, f"ALTER EXPERIMENT {experiment_fqn} COMMIT RUN {run_name}"
         ).has_dimensions(expected_rows=1, expected_cols=1).validate()
 
+    @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
     def drop_run(
         self, *, experiment_name: sql_identifier.SqlIdentifier, run_name: sql_identifier.SqlIdentifier
     ) -> None:
@@ -76,6 +75,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
             self._session, f"ALTER EXPERIMENT {experiment_fqn} DROP RUN {run_name}"
         ).has_dimensions(expected_rows=1, expected_cols=1).validate()
 
+    @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
     def modify_run_add_metrics(
         self,
         *,
@@ -89,6 +89,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
             f"ALTER EXPERIMENT {experiment_fqn} MODIFY RUN {run_name} ADD METRICS=$${metrics}$$",
         ).has_dimensions(expected_rows=1, expected_cols=1).validate()
 
+    @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
     def modify_run_add_params(
         self,
         *,
@@ -102,6 +103,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
             f"ALTER EXPERIMENT {experiment_fqn} MODIFY RUN {run_name} ADD PARAMETERS=$${params}$$",
         ).has_dimensions(expected_rows=1, expected_cols=1).validate()
 
+    @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
     def put_artifact(
         self,
         *,
@@ -118,6 +120,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
             auto_compress=auto_compress,
         )[0]
 
+    @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
     def list_artifacts(
         self,
         *,
@@ -125,13 +128,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
         run_name: sql_identifier.SqlIdentifier,
         artifact_path: str,
     ) -> list[artifact.ArtifactInfo]:
-        results = (
-            query_result_checker.SqlResultValidator(
-                self._session, f"LIST {self._build_snow_uri(experiment_name, run_name, artifact_path)}"
-            )
-            .has_dimensions(expected_cols=4)
-            .validate()
-        )
+        results = self._session.sql(f"LIST {self._build_snow_uri(experiment_name, run_name, artifact_path)}").collect()
         return [
             artifact.ArtifactInfo(
                 name=str(result.name).removeprefix(f"/versions/{run_name}/"),
@@ -142,6 +139,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
             for result in results
         ]
 
+    @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
     def get_artifact(
         self,
         *,
@@ -155,6 +153,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
             target_directory=target_path,
         )[0]
 
+    @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
     def show_runs_in_experiment(
         self, *, experiment_name: sql_identifier.SqlIdentifier, like: Optional[str] = None
     ) -> list[row.Row]:
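
RUN_NAME_COL_NAME and RUN_METADATA_COL_NAME move from class attributes to module-level constants, and every public method of ExperimentTrackingSQLClient now reports usage under the new EXPERIMENT_TRACKING telemetry project. A caller-side sketch of the relocated constants, mirroring how start_run reads run metadata later in this diff (the helper function itself is illustrative, not part of the package):

# Caller-side sketch: the column constants are now imported from the module
# rather than read off the ExperimentTrackingSQLClient class.
import json

from snowflake.ml._internal.utils import sql_identifier
from snowflake.ml.experiment._client import experiment_tracking_sql_client as sql_client


def first_run_status(
    client: sql_client.ExperimentTrackingSQLClient,
    experiment_name: sql_identifier.SqlIdentifier,
    run_name: str,
) -> str:
    rows = client.show_runs_in_experiment(experiment_name=experiment_name, like=run_name)
    if not rows:
        raise ValueError(f"no run named {run_name}")
    # Each row carries a "name" column and a JSON "metadata" column whose
    # "status" field is what start_run() checks before resuming a run.
    return str(json.loads(rows[0][sql_client.RUN_METADATA_COL_NAME])["status"])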
snowflake/ml/experiment/callback/keras.py

@@ -20,6 +20,7 @@ class SnowflakeKerasCallback(keras.callbacks.Callback):
         log_params: bool = True,
         log_every_n_epochs: int = 1,
         model_name: Optional[str] = None,
+        version_name: Optional[str] = None,
         model_signature: Optional["ModelSignature"] = None,
     ) -> None:
         self._experiment_tracking = experiment_tracking
@@ -30,6 +31,7 @@ class SnowflakeKerasCallback(keras.callbacks.Callback):
             raise ValueError("`log_every_n_epochs` must be positive.")
         self.log_every_n_epochs = log_every_n_epochs
         self.model_name = model_name
+        self.version_name = version_name
         self.model_signature = model_signature
 
     def on_train_begin(self, logs: Optional[dict[str, Any]] = None) -> None:
@@ -59,5 +61,6 @@ class SnowflakeKerasCallback(keras.callbacks.Callback):
             self._experiment_tracking.log_model(  # type: ignore[call-arg]
                 model=self.model,
                 model_name=model_name,
+                version_name=self.version_name,
                 signatures={"predict": self.model_signature},
             )
snowflake/ml/experiment/callback/lightgbm.py

@@ -17,6 +17,7 @@ class SnowflakeLightgbmCallback(lgb.callback._RecordEvaluationCallback):
         log_params: bool = True,
         log_every_n_epochs: int = 1,
         model_name: Optional[str] = None,
+        version_name: Optional[str] = None,
         model_signature: Optional["ModelSignature"] = None,
     ) -> None:
         self._experiment_tracking = experiment_tracking
@@ -27,6 +28,7 @@ class SnowflakeLightgbmCallback(lgb.callback._RecordEvaluationCallback):
             raise ValueError("`log_every_n_epochs` must be positive.")
         self.log_every_n_epochs = log_every_n_epochs
         self.model_name = model_name
+        self.version_name = version_name
         self.model_signature = model_signature
 
         super().__init__(eval_result={})
@@ -50,6 +52,7 @@ class SnowflakeLightgbmCallback(lgb.callback._RecordEvaluationCallback):
             self._experiment_tracking.log_model(  # type: ignore[call-arg]
                 model=env.model,
                 model_name=model_name,
+                version_name=self.version_name,
                 signatures={"predict": self.model_signature},
             )
         else:
snowflake/ml/experiment/callback/xgboost.py

@@ -20,6 +20,7 @@ class SnowflakeXgboostCallback(xgb.callback.TrainingCallback):
         log_params: bool = True,
         log_every_n_epochs: int = 1,
         model_name: Optional[str] = None,
+        version_name: Optional[str] = None,
         model_signature: Optional["ModelSignature"] = None,
     ) -> None:
         self._experiment_tracking = experiment_tracking
@@ -30,6 +31,7 @@ class SnowflakeXgboostCallback(xgb.callback.TrainingCallback):
             raise ValueError("`log_every_n_epochs` must be positive.")
         self.log_every_n_epochs = log_every_n_epochs
         self.model_name = model_name
+        self.version_name = version_name
         self.model_signature = model_signature
 
     def before_training(self, model: xgb.Booster) -> xgb.Booster:
@@ -61,6 +63,7 @@ class SnowflakeXgboostCallback(xgb.callback.TrainingCallback):
             self._experiment_tracking.log_model(  # type: ignore[call-arg]
                 model=model,
                 model_name=model_name,
+                version_name=self.version_name,
                 signatures={"predict": self.model_signature},
             )
 
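
All three callbacks (Keras, LightGBM, XGBoost) gain an optional version_name that is stored on the callback and forwarded to ExperimentTracking.log_model, so a specific registry version name can be supplied instead of an auto-generated one. A sketch of the XGBoost variant; only experiment_tracking, model_name, version_name, and model_signature are taken from this diff, the data and training setup are illustrative:

# Sketch only: the callback constructor arguments shown are the ones visible
# in this diff; exp_tracking and signature are assumed to come from earlier setup.
import numpy as np
import xgboost as xgb

from snowflake.ml.experiment.callback.xgboost import SnowflakeXgboostCallback


def train_with_tracking(exp_tracking, signature) -> xgb.Booster:
    # exp_tracking: an ExperimentTracking instance with an active run
    # signature: the ModelSignature used for the model's "predict" method
    callback = SnowflakeXgboostCallback(
        experiment_tracking=exp_tracking,
        model_name="XGB_CLASSIFIER",
        version_name="V_2024_06",  # new in this release: pin the logged model version
        model_signature=signature,
    )
    X = np.random.rand(100, 4)
    y = (X[:, 0] > 0.5).astype(int)
    dtrain = xgb.DMatrix(X, label=y)
    # The callback logs metrics during training and, when training finishes,
    # logs the booster under model_name/version_name.
    return xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=20, callbacks=[callback])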
snowflake/ml/experiment/experiment_tracking.py

@@ -1,13 +1,13 @@
 import functools
 import json
 import sys
-from typing import Any,
+from typing import Any, Optional, Union
 from urllib.parse import quote
 
 from snowflake import snowpark
 from snowflake.ml import model as ml_model, registry
 from snowflake.ml._internal.human_readable_id import hrid_generator
-from snowflake.ml._internal.utils import
+from snowflake.ml._internal.utils import connection_params, sql_identifier
 from snowflake.ml.experiment import (
     _entities as entities,
     _experiment_info as experiment_info,
@@ -21,34 +21,12 @@ from snowflake.ml.utils import sql_client as sql_client_utils
 
 DEFAULT_EXPERIMENT_NAME = sql_identifier.SqlIdentifier("DEFAULT")
 
-P = ParamSpec("P")
-T = TypeVar("T")
 
-
-def _restore_session(
-    func: Callable[Concatenate["ExperimentTracking", P], T],
-) -> Callable[Concatenate["ExperimentTracking", P], T]:
-    @functools.wraps(func)
-    def wrapper(self: "ExperimentTracking", /, *args: P.args, **kwargs: P.kwargs) -> T:
-        if self._session is None:
-            if self._session_state is None:
-                raise RuntimeError(
-                    f"Session is not set before calling {func.__name__}, and there is no session state to restore from"
-                )
-            self._set_session(self._session_state)
-        if self._session is None:
-            raise RuntimeError(f"Failed to restore session before calling {func.__name__}")
-        return func(self, *args, **kwargs)
-
-    return wrapper
-
-
-class ExperimentTracking(mixins.SerializableSessionMixin):
+class ExperimentTracking:
     """
     Class to manage experiments in Snowflake.
     """
 
-    @snowpark._internal.utils.private_preview(version="1.9.1")
     def __init__(
         self,
         session: snowpark.Session,
@@ -93,10 +71,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
             database_name=self._database_name,
             schema_name=self._schema_name,
         )
-        self._session
-        # Used to store information about the session if the session could not be restored during unpickling
-        # _session_state is None if and only if _session is not None
-        self._session_state: Optional[mixins._SessionState] = None
+        self._session = session
 
         # The experiment in context
         self._experiment: Optional[entities.Experiment] = None
@@ -104,35 +79,40 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
         self._run: Optional[entities.Run] = None
 
     def __getstate__(self) -> dict[str, Any]:
-
+        parent_state = (
+            super().__getstate__()  # type: ignore[misc]  # object.__getstate__ appears in 3.11
+            if hasattr(super(), "__getstate__")
+            else self.__dict__
+        )
+        state = dict(parent_state)  # Create a copy so we can safely modify the state
+
         # Remove unpicklable attributes
+        state["_session"] = None
         state["_sql_client"] = None
         state["_registry"] = None
         return state
 
-    def
-
-        super().
-            assert self._session is not None
-        except (snowpark.exceptions.SnowparkSessionException, AssertionError):
-            # If session was not set, store the session state
-            self._session = None
-            self._session_state = session_state
+    def __setstate__(self, state: dict[str, Any]) -> None:
+        if hasattr(super(), "__setstate__"):
+            super().__setstate__(state)  # type: ignore[misc]
         else:
-
-
-
-
-
-
-
-            self.
-
-
-
-
+            self.__dict__.update(state)
+
+        # Restore unpicklable attributes
+        options: dict[str, Any] = connection_params.SnowflakeLoginOptions()
+        options["client_session_keep_alive"] = True  # Needed for long-running training jobs
+        self._session = snowpark.Session.builder.configs(options).getOrCreate()
+        self._sql_client = sql_client.ExperimentTrackingSQLClient(
+            session=self._session,
+            database_name=self._database_name,
+            schema_name=self._schema_name,
+        )
+        self._registry = registry.Registry(
+            session=self._session,
+            database_name=self._database_name,
+            schema_name=self._schema_name,
+        )
 
-    @_restore_session
     def set_experiment(
         self,
         experiment_name: str,
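
ExperimentTracking drops the SerializableSessionMixin/_restore_session machinery in favor of explicit __getstate__/__setstate__: pickling nulls out the session, SQL client, and registry, and unpickling rebuilds them from SnowflakeLoginOptions() with client_session_keep_alive enabled. A rough sketch of a round trip under the new scheme (the constructor call and the availability of local connection configuration are assumptions):

# Rough sketch of the new pickling behaviour.
import pickle

from snowflake.ml.experiment.experiment_tracking import ExperimentTracking
from snowflake.snowpark import Session

et = ExperimentTracking(Session.builder.getOrCreate())

blob = pickle.dumps(et)        # __getstate__ nulls out _session, _sql_client, _registry
restored = pickle.loads(blob)  # __setstate__ opens a fresh Session from
                               # SnowflakeLoginOptions() (client_session_keep_alive=True)
                               # and rebuilds the SQL client and registry handle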
@@ -157,7 +137,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
         self._run = None
         return self._experiment
 
-    @_restore_session
     def delete_experiment(
         self,
         experiment_name: str,
@@ -174,10 +153,8 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
         self._run = None
 
     @functools.wraps(registry.Registry.log_model)
-    @_restore_session
     def log_model(
         self,
-        /,  # self needs to be a positional argument to stop mypy from complaining
         model: Union[type_hints.SupportedModelType, ml_model.ModelVersion],
         *,
         model_name: str,
@@ -187,29 +164,40 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
         with experiment_info.ExperimentInfoPatcher(experiment_info=run._get_experiment_info()):
             return self._registry.log_model(model, model_name=model_name, **kwargs)
 
-    @_restore_session
     def start_run(
         self,
         run_name: Optional[str] = None,
     ) -> entities.Run:
         """
-        Start a new run.
+        Start a new run. If a run name of an existing run is provided, resumes the run if it is running.
 
         Args:
             run_name: The name of the run. If None, a default name will be generated.
 
         Returns:
-            Run: The run that was started.
+            Run: The run that was started or resumed.
 
         Raises:
-            RuntimeError: If a run is already active.
+            RuntimeError: If a run is already active. If a run with the same name exists but is not running.
         """
         if self._run:
             raise RuntimeError("A run is already active. Please end the current run before starting a new one.")
         experiment = self._get_or_set_experiment()
-
-
-
+
+        if run_name is None:
+            run_name = self._generate_run_name(experiment)
+        elif runs := self._sql_client.show_runs_in_experiment(experiment_name=experiment.name, like=run_name):
+            if "RUNNING" != json.loads(runs[0][sql_client.RUN_METADATA_COL_NAME])["status"]:
+                raise RuntimeError(f"Run {run_name} exists but cannot be resumed as it is no longer running.")
+            else:
+                self._run = entities.Run(
+                    experiment_tracking=self,
+                    experiment_name=experiment.name,
+                    run_name=sql_identifier.SqlIdentifier(run_name),
+                )
+                return self._run
+
+        run_name = sql_identifier.SqlIdentifier(run_name)
         self._sql_client.add_run(
             experiment_name=experiment.name,
             run_name=run_name,
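
start_run now doubles as a resume: when the given name matches an existing run, the run is reattached only if its metadata status is still RUNNING, otherwise a RuntimeError is raised. A hedged usage sketch (the constructor call and the experiment/run names are illustrative; set_experiment, start_run, and end_run are the methods shown in this diff):

# Illustrative sketch of the new resume semantics.
from snowflake.ml.experiment.experiment_tracking import ExperimentTracking
from snowflake.snowpark import Session

et = ExperimentTracking(Session.builder.getOrCreate())
et.set_experiment("CHURN_MODEL")

et.start_run("nightly_training")  # creates the run; it stays resumable while it is running
# A second process (e.g. one that unpickled this ExperimentTracking) can call
# start_run("nightly_training") while the run is live and gets the same run back.
et.end_run()

et.start_run()                    # no name given: an HRID-style default name is generated
et.end_run()

# Once a run has ended it can no longer be resumed:
# et.start_run("nightly_training") raises
# RuntimeError: Run nightly_training exists but cannot be resumed as it is no longer running.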
@@ -217,7 +205,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
         self._run = entities.Run(experiment_tracking=self, experiment_name=experiment.name, run_name=run_name)
         return self._run
 
-    @_restore_session
     def end_run(self, run_name: Optional[str] = None) -> None:
         """
         End the current run if no run name is provided. Otherwise, the specified run is ended.
@@ -247,7 +234,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
         self._run = None
         self._print_urls(experiment_name=experiment_name, run_name=run_name)
 
-    @_restore_session
     def delete_run(
         self,
         run_name: str,
@@ -286,7 +272,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
         """
         self.log_metrics(metrics={key: value}, step=step)
 
-    @_restore_session
     def log_metrics(
         self,
         metrics: dict[str, float],
@@ -323,7 +308,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
         """
         self.log_params({key: value})
 
-    @_restore_session
     def log_params(
         self,
         params: dict[str, Any],
@@ -345,7 +329,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
             params=json.dumps([param.to_dict() for param in params_list]),
         )
 
-    @_restore_session
     def log_artifact(
         self,
         local_path: str,
@@ -369,7 +352,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
             file_path=file_path,
         )
 
-    @_restore_session
     def list_artifacts(
         self,
         run_name: str,
@@ -398,7 +380,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
             artifact_path=artifact_path or "",
         )
 
-    @_restore_session
     def download_artifacts(
         self,
         run_name: str,
@@ -440,11 +421,10 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
             return self._run
         return self.start_run()
 
-    @_restore_session
     def _generate_run_name(self, experiment: entities.Experiment) -> sql_identifier.SqlIdentifier:
         generator = hrid_generator.HRID16()
         existing_runs = self._sql_client.show_runs_in_experiment(experiment_name=experiment.name)
-        existing_run_names = [row[sql_client.
+        existing_run_names = [row[sql_client.RUN_NAME_COL_NAME] for row in existing_runs]
         for _ in range(1000):
             run_name = generator.generate()[1]
             if run_name not in existing_run_names: