snowflake-ml-python 1.17.0__py3-none-any.whl → 1.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. snowflake/ml/_internal/telemetry.py +3 -2
  2. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +18 -19
  3. snowflake/ml/experiment/callback/keras.py +3 -0
  4. snowflake/ml/experiment/callback/lightgbm.py +3 -0
  5. snowflake/ml/experiment/callback/xgboost.py +3 -0
  6. snowflake/ml/experiment/experiment_tracking.py +50 -70
  7. snowflake/ml/feature_store/feature_store.py +299 -69
  8. snowflake/ml/feature_store/feature_view.py +12 -6
  9. snowflake/ml/fileset/stage_fs.py +12 -1
  10. snowflake/ml/jobs/_utils/constants.py +12 -1
  11. snowflake/ml/jobs/_utils/payload_utils.py +7 -1
  12. snowflake/ml/jobs/_utils/stage_utils.py +4 -0
  13. snowflake/ml/jobs/_utils/types.py +5 -0
  14. snowflake/ml/jobs/job.py +19 -5
  15. snowflake/ml/jobs/manager.py +18 -7
  16. snowflake/ml/model/__init__.py +19 -0
  17. snowflake/ml/model/_client/model/batch_inference_specs.py +63 -0
  18. snowflake/ml/model/_client/model/inference_engine_utils.py +1 -5
  19. snowflake/ml/model/_client/model/model_version_impl.py +129 -11
  20. snowflake/ml/model/_client/ops/model_ops.py +11 -4
  21. snowflake/ml/model/_client/ops/service_ops.py +3 -0
  22. snowflake/ml/model/_client/service/model_deployment_spec.py +3 -0
  23. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -1
  25. snowflake/ml/model/_model_composer/model_method/model_method.py +4 -1
  26. snowflake/ml/model/_packager/model_handlers/_utils.py +70 -0
  27. snowflake/ml/model/_packager/model_handlers/prophet.py +566 -0
  28. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -0
  29. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  30. snowflake/ml/model/type_hints.py +16 -0
  31. snowflake/ml/modeling/metrics/metrics_utils.py +9 -2
  32. snowflake/ml/monitoring/explain_visualize.py +3 -1
  33. snowflake/ml/version.py +1 -1
  34. {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.19.0.dist-info}/METADATA +50 -4
  35. {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.19.0.dist-info}/RECORD +38 -37
  36. {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.19.0.dist-info}/WHEEL +0 -0
  37. {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.19.0.dist-info}/licenses/LICENSE.txt +0 -0
  38. {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.19.0.dist-info}/top_level.txt +0 -0
@@ -73,6 +73,7 @@ def _get_snowflake_connection() -> Optional[connector.SnowflakeConnection]:
73
73
  class TelemetryProject(enum.Enum):
74
74
  MLOPS = "MLOps"
75
75
  MODELING = "ModelDevelopment"
76
+ EXPERIMENT_TRACKING = "ExperimentTracking"
76
77
  # TODO: Update with remaining projects.
77
78
 
78
79
 
@@ -464,14 +465,14 @@ def send_api_usage_telemetry(
464
465
 
465
466
  # noqa: DAR402
466
467
  """
467
- start_time = time.perf_counter()
468
-
469
468
  if subproject is not None and subproject_extractor is not None:
470
469
  raise ValueError("Specifying both subproject and subproject_extractor is not allowed")
471
470
 
472
471
  def decorator(func: Callable[_Args, _ReturnValue]) -> Callable[_Args, _ReturnValue]:
473
472
  @functools.wraps(func)
474
473
  def wrap(*args: Any, **kwargs: Any) -> _ReturnValue:
474
+ start_time = time.perf_counter()
475
+
475
476
  params = _get_func_params(func, func_params_to_log, args, kwargs) if func_params_to_log else None
476
477
 
477
478
  api_calls: list[Union[dict[str, Union[Callable[..., Any], str]], Callable[..., Any], str]] = []
@@ -1,17 +1,17 @@
1
1
  from typing import Optional
2
2
 
3
+ from snowflake.ml._internal import telemetry
3
4
  from snowflake.ml._internal.utils import query_result_checker, sql_identifier
4
5
  from snowflake.ml.experiment._client import artifact
5
6
  from snowflake.ml.model._client.sql import _base
6
7
  from snowflake.ml.utils import sql_client
7
8
  from snowflake.snowpark import file_operation, row, session
8
9
 
10
+ RUN_NAME_COL_NAME = "name"
11
+ RUN_METADATA_COL_NAME = "metadata"
9
12
 
10
- class ExperimentTrackingSQLClient(_base._BaseSQLClient):
11
-
12
- RUN_NAME_COL_NAME = "name"
13
- RUN_METADATA_COL_NAME = "metadata"
14
13
 
14
+ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
15
15
  def __init__(
16
16
  self,
17
17
  session: session.Session,
@@ -28,6 +28,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
28
28
  """
29
29
  super().__init__(session, database_name=database_name, schema_name=schema_name)
30
30
 
31
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
31
32
  def create_experiment(
32
33
  self,
33
34
  experiment_name: sql_identifier.SqlIdentifier,
@@ -39,24 +40,21 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
39
40
  self._session, f"CREATE EXPERIMENT {if_not_exists_sql} {experiment_fqn}"
40
41
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
41
42
 
43
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
42
44
  def drop_experiment(self, *, experiment_name: sql_identifier.SqlIdentifier) -> None:
43
45
  experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
44
46
  query_result_checker.SqlResultValidator(self._session, f"DROP EXPERIMENT {experiment_fqn}").has_dimensions(
45
47
  expected_rows=1, expected_cols=1
46
48
  ).validate()
47
49
 
48
- def add_run(
49
- self,
50
- *,
51
- experiment_name: sql_identifier.SqlIdentifier,
52
- run_name: sql_identifier.SqlIdentifier,
53
- live: bool = True,
54
- ) -> None:
50
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
51
+ def add_run(self, *, experiment_name: sql_identifier.SqlIdentifier, run_name: sql_identifier.SqlIdentifier) -> None:
55
52
  experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
56
53
  query_result_checker.SqlResultValidator(
57
- self._session, f"ALTER EXPERIMENT {experiment_fqn} ADD {'LIVE' if live else ''} RUN {run_name}"
54
+ self._session, f"ALTER EXPERIMENT {experiment_fqn} ADD RUN {run_name}"
58
55
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
59
56
 
57
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
60
58
  def commit_run(
61
59
  self,
62
60
  *,
@@ -68,6 +66,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
68
66
  self._session, f"ALTER EXPERIMENT {experiment_fqn} COMMIT RUN {run_name}"
69
67
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
70
68
 
69
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
71
70
  def drop_run(
72
71
  self, *, experiment_name: sql_identifier.SqlIdentifier, run_name: sql_identifier.SqlIdentifier
73
72
  ) -> None:
@@ -76,6 +75,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
76
75
  self._session, f"ALTER EXPERIMENT {experiment_fqn} DROP RUN {run_name}"
77
76
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
78
77
 
78
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
79
79
  def modify_run_add_metrics(
80
80
  self,
81
81
  *,
@@ -89,6 +89,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
89
89
  f"ALTER EXPERIMENT {experiment_fqn} MODIFY RUN {run_name} ADD METRICS=$${metrics}$$",
90
90
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
91
91
 
92
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
92
93
  def modify_run_add_params(
93
94
  self,
94
95
  *,
@@ -102,6 +103,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
102
103
  f"ALTER EXPERIMENT {experiment_fqn} MODIFY RUN {run_name} ADD PARAMETERS=$${params}$$",
103
104
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
104
105
 
106
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
105
107
  def put_artifact(
106
108
  self,
107
109
  *,
@@ -118,6 +120,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
118
120
  auto_compress=auto_compress,
119
121
  )[0]
120
122
 
123
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
121
124
  def list_artifacts(
122
125
  self,
123
126
  *,
@@ -125,13 +128,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
125
128
  run_name: sql_identifier.SqlIdentifier,
126
129
  artifact_path: str,
127
130
  ) -> list[artifact.ArtifactInfo]:
128
- results = (
129
- query_result_checker.SqlResultValidator(
130
- self._session, f"LIST {self._build_snow_uri(experiment_name, run_name, artifact_path)}"
131
- )
132
- .has_dimensions(expected_cols=4)
133
- .validate()
134
- )
131
+ results = self._session.sql(f"LIST {self._build_snow_uri(experiment_name, run_name, artifact_path)}").collect()
135
132
  return [
136
133
  artifact.ArtifactInfo(
137
134
  name=str(result.name).removeprefix(f"/versions/{run_name}/"),
@@ -142,6 +139,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
142
139
  for result in results
143
140
  ]
144
141
 
142
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
145
143
  def get_artifact(
146
144
  self,
147
145
  *,
@@ -155,6 +153,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
155
153
  target_directory=target_path,
156
154
  )[0]
157
155
 
156
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
158
157
  def show_runs_in_experiment(
159
158
  self, *, experiment_name: sql_identifier.SqlIdentifier, like: Optional[str] = None
160
159
  ) -> list[row.Row]:
@@ -20,6 +20,7 @@ class SnowflakeKerasCallback(keras.callbacks.Callback):
20
20
  log_params: bool = True,
21
21
  log_every_n_epochs: int = 1,
22
22
  model_name: Optional[str] = None,
23
+ version_name: Optional[str] = None,
23
24
  model_signature: Optional["ModelSignature"] = None,
24
25
  ) -> None:
25
26
  self._experiment_tracking = experiment_tracking
@@ -30,6 +31,7 @@ class SnowflakeKerasCallback(keras.callbacks.Callback):
30
31
  raise ValueError("`log_every_n_epochs` must be positive.")
31
32
  self.log_every_n_epochs = log_every_n_epochs
32
33
  self.model_name = model_name
34
+ self.version_name = version_name
33
35
  self.model_signature = model_signature
34
36
 
35
37
  def on_train_begin(self, logs: Optional[dict[str, Any]] = None) -> None:
@@ -59,5 +61,6 @@ class SnowflakeKerasCallback(keras.callbacks.Callback):
59
61
  self._experiment_tracking.log_model( # type: ignore[call-arg]
60
62
  model=self.model,
61
63
  model_name=model_name,
64
+ version_name=self.version_name,
62
65
  signatures={"predict": self.model_signature},
63
66
  )
@@ -17,6 +17,7 @@ class SnowflakeLightgbmCallback(lgb.callback._RecordEvaluationCallback):
17
17
  log_params: bool = True,
18
18
  log_every_n_epochs: int = 1,
19
19
  model_name: Optional[str] = None,
20
+ version_name: Optional[str] = None,
20
21
  model_signature: Optional["ModelSignature"] = None,
21
22
  ) -> None:
22
23
  self._experiment_tracking = experiment_tracking
@@ -27,6 +28,7 @@ class SnowflakeLightgbmCallback(lgb.callback._RecordEvaluationCallback):
27
28
  raise ValueError("`log_every_n_epochs` must be positive.")
28
29
  self.log_every_n_epochs = log_every_n_epochs
29
30
  self.model_name = model_name
31
+ self.version_name = version_name
30
32
  self.model_signature = model_signature
31
33
 
32
34
  super().__init__(eval_result={})
@@ -50,6 +52,7 @@ class SnowflakeLightgbmCallback(lgb.callback._RecordEvaluationCallback):
50
52
  self._experiment_tracking.log_model( # type: ignore[call-arg]
51
53
  model=env.model,
52
54
  model_name=model_name,
55
+ version_name=self.version_name,
53
56
  signatures={"predict": self.model_signature},
54
57
  )
55
58
  else:
@@ -20,6 +20,7 @@ class SnowflakeXgboostCallback(xgb.callback.TrainingCallback):
20
20
  log_params: bool = True,
21
21
  log_every_n_epochs: int = 1,
22
22
  model_name: Optional[str] = None,
23
+ version_name: Optional[str] = None,
23
24
  model_signature: Optional["ModelSignature"] = None,
24
25
  ) -> None:
25
26
  self._experiment_tracking = experiment_tracking
@@ -30,6 +31,7 @@ class SnowflakeXgboostCallback(xgb.callback.TrainingCallback):
30
31
  raise ValueError("`log_every_n_epochs` must be positive.")
31
32
  self.log_every_n_epochs = log_every_n_epochs
32
33
  self.model_name = model_name
34
+ self.version_name = version_name
33
35
  self.model_signature = model_signature
34
36
 
35
37
  def before_training(self, model: xgb.Booster) -> xgb.Booster:
@@ -61,6 +63,7 @@ class SnowflakeXgboostCallback(xgb.callback.TrainingCallback):
61
63
  self._experiment_tracking.log_model( # type: ignore[call-arg]
62
64
  model=model,
63
65
  model_name=model_name,
66
+ version_name=self.version_name,
64
67
  signatures={"predict": self.model_signature},
65
68
  )
66
69
 
@@ -1,13 +1,13 @@
1
1
  import functools
2
2
  import json
3
3
  import sys
4
- from typing import Any, Callable, Concatenate, Optional, ParamSpec, TypeVar, Union
4
+ from typing import Any, Optional, Union
5
5
  from urllib.parse import quote
6
6
 
7
7
  from snowflake import snowpark
8
8
  from snowflake.ml import model as ml_model, registry
9
9
  from snowflake.ml._internal.human_readable_id import hrid_generator
10
- from snowflake.ml._internal.utils import mixins, sql_identifier
10
+ from snowflake.ml._internal.utils import connection_params, sql_identifier
11
11
  from snowflake.ml.experiment import (
12
12
  _entities as entities,
13
13
  _experiment_info as experiment_info,
@@ -21,34 +21,12 @@ from snowflake.ml.utils import sql_client as sql_client_utils
21
21
 
22
22
  DEFAULT_EXPERIMENT_NAME = sql_identifier.SqlIdentifier("DEFAULT")
23
23
 
24
- P = ParamSpec("P")
25
- T = TypeVar("T")
26
24
 
27
-
28
- def _restore_session(
29
- func: Callable[Concatenate["ExperimentTracking", P], T],
30
- ) -> Callable[Concatenate["ExperimentTracking", P], T]:
31
- @functools.wraps(func)
32
- def wrapper(self: "ExperimentTracking", /, *args: P.args, **kwargs: P.kwargs) -> T:
33
- if self._session is None:
34
- if self._session_state is None:
35
- raise RuntimeError(
36
- f"Session is not set before calling {func.__name__}, and there is no session state to restore from"
37
- )
38
- self._set_session(self._session_state)
39
- if self._session is None:
40
- raise RuntimeError(f"Failed to restore session before calling {func.__name__}")
41
- return func(self, *args, **kwargs)
42
-
43
- return wrapper
44
-
45
-
46
- class ExperimentTracking(mixins.SerializableSessionMixin):
25
+ class ExperimentTracking:
47
26
  """
48
27
  Class to manage experiments in Snowflake.
49
28
  """
50
29
 
51
- @snowpark._internal.utils.private_preview(version="1.9.1")
52
30
  def __init__(
53
31
  self,
54
32
  session: snowpark.Session,
@@ -93,10 +71,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
93
71
  database_name=self._database_name,
94
72
  schema_name=self._schema_name,
95
73
  )
96
- self._session: Optional[snowpark.Session] = session
97
- # Used to store information about the session if the session could not be restored during unpickling
98
- # _session_state is None if and only if _session is not None
99
- self._session_state: Optional[mixins._SessionState] = None
74
+ self._session = session
100
75
 
101
76
  # The experiment in context
102
77
  self._experiment: Optional[entities.Experiment] = None
@@ -104,35 +79,40 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
104
79
  self._run: Optional[entities.Run] = None
105
80
 
106
81
  def __getstate__(self) -> dict[str, Any]:
107
- state = super().__getstate__()
82
+ parent_state = (
83
+ super().__getstate__() # type: ignore[misc] # object.__getstate__ appears in 3.11
84
+ if hasattr(super(), "__getstate__")
85
+ else self.__dict__
86
+ )
87
+ state = dict(parent_state) # Create a copy so we can safely modify the state
88
+
108
89
  # Remove unpicklable attributes
90
+ state["_session"] = None
109
91
  state["_sql_client"] = None
110
92
  state["_registry"] = None
111
93
  return state
112
94
 
113
- def _set_session(self, session_state: mixins._SessionState) -> None:
114
- try:
115
- super()._set_session(session_state)
116
- assert self._session is not None
117
- except (snowpark.exceptions.SnowparkSessionException, AssertionError):
118
- # If session was not set, store the session state
119
- self._session = None
120
- self._session_state = session_state
95
+ def __setstate__(self, state: dict[str, Any]) -> None:
96
+ if hasattr(super(), "__setstate__"):
97
+ super().__setstate__(state) # type: ignore[misc]
121
98
  else:
122
- # If session was set, clear the session state, and reinitialize the SQL client and registry
123
- self._session_state = None
124
- self._sql_client = sql_client.ExperimentTrackingSQLClient(
125
- session=self._session,
126
- database_name=self._database_name,
127
- schema_name=self._schema_name,
128
- )
129
- self._registry = registry.Registry(
130
- session=self._session,
131
- database_name=self._database_name,
132
- schema_name=self._schema_name,
133
- )
99
+ self.__dict__.update(state)
100
+
101
+ # Restore unpicklable attributes
102
+ options: dict[str, Any] = connection_params.SnowflakeLoginOptions()
103
+ options["client_session_keep_alive"] = True # Needed for long-running training jobs
104
+ self._session = snowpark.Session.builder.configs(options).getOrCreate()
105
+ self._sql_client = sql_client.ExperimentTrackingSQLClient(
106
+ session=self._session,
107
+ database_name=self._database_name,
108
+ schema_name=self._schema_name,
109
+ )
110
+ self._registry = registry.Registry(
111
+ session=self._session,
112
+ database_name=self._database_name,
113
+ schema_name=self._schema_name,
114
+ )
134
115
 
135
- @_restore_session
136
116
  def set_experiment(
137
117
  self,
138
118
  experiment_name: str,
@@ -157,7 +137,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
157
137
  self._run = None
158
138
  return self._experiment
159
139
 
160
- @_restore_session
161
140
  def delete_experiment(
162
141
  self,
163
142
  experiment_name: str,
@@ -174,10 +153,8 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
174
153
  self._run = None
175
154
 
176
155
  @functools.wraps(registry.Registry.log_model)
177
- @_restore_session
178
156
  def log_model(
179
157
  self,
180
- /, # self needs to be a positional argument to stop mypy from complaining
181
158
  model: Union[type_hints.SupportedModelType, ml_model.ModelVersion],
182
159
  *,
183
160
  model_name: str,
@@ -187,29 +164,40 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
187
164
  with experiment_info.ExperimentInfoPatcher(experiment_info=run._get_experiment_info()):
188
165
  return self._registry.log_model(model, model_name=model_name, **kwargs)
189
166
 
190
- @_restore_session
191
167
  def start_run(
192
168
  self,
193
169
  run_name: Optional[str] = None,
194
170
  ) -> entities.Run:
195
171
  """
196
- Start a new run.
172
+ Start a new run. If a run name of an existing run is provided, resumes the run if it is running.
197
173
 
198
174
  Args:
199
175
  run_name: The name of the run. If None, a default name will be generated.
200
176
 
201
177
  Returns:
202
- Run: The run that was started.
178
+ Run: The run that was started or resumed.
203
179
 
204
180
  Raises:
205
- RuntimeError: If a run is already active.
181
+ RuntimeError: If a run is already active. If a run with the same name exists but is not running.
206
182
  """
207
183
  if self._run:
208
184
  raise RuntimeError("A run is already active. Please end the current run before starting a new one.")
209
185
  experiment = self._get_or_set_experiment()
210
- run_name = (
211
- sql_identifier.SqlIdentifier(run_name) if run_name is not None else self._generate_run_name(experiment)
212
- )
186
+
187
+ if run_name is None:
188
+ run_name = self._generate_run_name(experiment)
189
+ elif runs := self._sql_client.show_runs_in_experiment(experiment_name=experiment.name, like=run_name):
190
+ if "RUNNING" != json.loads(runs[0][sql_client.RUN_METADATA_COL_NAME])["status"]:
191
+ raise RuntimeError(f"Run {run_name} exists but cannot be resumed as it is no longer running.")
192
+ else:
193
+ self._run = entities.Run(
194
+ experiment_tracking=self,
195
+ experiment_name=experiment.name,
196
+ run_name=sql_identifier.SqlIdentifier(run_name),
197
+ )
198
+ return self._run
199
+
200
+ run_name = sql_identifier.SqlIdentifier(run_name)
213
201
  self._sql_client.add_run(
214
202
  experiment_name=experiment.name,
215
203
  run_name=run_name,
@@ -217,7 +205,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
217
205
  self._run = entities.Run(experiment_tracking=self, experiment_name=experiment.name, run_name=run_name)
218
206
  return self._run
219
207
 
220
- @_restore_session
221
208
  def end_run(self, run_name: Optional[str] = None) -> None:
222
209
  """
223
210
  End the current run if no run name is provided. Otherwise, the specified run is ended.
@@ -247,7 +234,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
247
234
  self._run = None
248
235
  self._print_urls(experiment_name=experiment_name, run_name=run_name)
249
236
 
250
- @_restore_session
251
237
  def delete_run(
252
238
  self,
253
239
  run_name: str,
@@ -286,7 +272,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
286
272
  """
287
273
  self.log_metrics(metrics={key: value}, step=step)
288
274
 
289
- @_restore_session
290
275
  def log_metrics(
291
276
  self,
292
277
  metrics: dict[str, float],
@@ -323,7 +308,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
323
308
  """
324
309
  self.log_params({key: value})
325
310
 
326
- @_restore_session
327
311
  def log_params(
328
312
  self,
329
313
  params: dict[str, Any],
@@ -345,7 +329,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
345
329
  params=json.dumps([param.to_dict() for param in params_list]),
346
330
  )
347
331
 
348
- @_restore_session
349
332
  def log_artifact(
350
333
  self,
351
334
  local_path: str,
@@ -369,7 +352,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
369
352
  file_path=file_path,
370
353
  )
371
354
 
372
- @_restore_session
373
355
  def list_artifacts(
374
356
  self,
375
357
  run_name: str,
@@ -398,7 +380,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
398
380
  artifact_path=artifact_path or "",
399
381
  )
400
382
 
401
- @_restore_session
402
383
  def download_artifacts(
403
384
  self,
404
385
  run_name: str,
@@ -440,11 +421,10 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
440
421
  return self._run
441
422
  return self.start_run()
442
423
 
443
- @_restore_session
444
424
  def _generate_run_name(self, experiment: entities.Experiment) -> sql_identifier.SqlIdentifier:
445
425
  generator = hrid_generator.HRID16()
446
426
  existing_runs = self._sql_client.show_runs_in_experiment(experiment_name=experiment.name)
447
- existing_run_names = [row[sql_client.ExperimentTrackingSQLClient.RUN_NAME_COL_NAME] for row in existing_runs]
427
+ existing_run_names = [row[sql_client.RUN_NAME_COL_NAME] for row in existing_runs]
448
428
  for _ in range(1000):
449
429
  run_name = generator.generate()[1]
450
430
  if run_name not in existing_run_names: